搜索系统 基础教程

搜索 query 分析

搜索系统 索引教程

搜索系统 高级教程

搜索系统 排序层

搜索系统 笔记

搜索系统 排序层


搜索系统的排序层又称为精排层,主要是基于离线训练好的模型,结合模型所用的特征给每个 item 进行打分,然后根据分数进行降序排序。搜索排序模型在被训练时,有指定的目标,根据目的可以有单个目标,主要有预测点击率(ctr)、预测转化率(cvr)、预测停留时长(dwell time)等,也可以多目标进行训练。根据笔者的相关实际经验,下面主要列出了 xgboost、lr 及 深度学习框架下的模型的工程实践。

样本构造

排序层主要用到离线模型,该离线模型基于在线特征打印客户端相关行为埋点进行 join 之后生成的样本进行训练的。

在线特征打印

在调用排序层时,除了打分还要进行在线特征的打印(若初次没有模型时,只有在线特征打印),作为后续关联行为后,作为样本。各每个带排序的 item 进行在线特征打印,即一个 item 一行特征打印记录,设计可参考如下结构:

{
  "scene": "sports_search",                     //  场景标识
  "request_key": "${request_id}_${item_id}",    //  关联 key
  "rank_model_name": "XGB_SPORTS_CTR_A2",       //  模型名称
  "rank_model_score": 0.06298167354,            //  模型最终得分
  "rank_model_score_map": {                     //  模型分数映射
    "XGB_SPORTS_CTR_A2-SCORE": 0.06298167354,   //  具体模型得分
    "XGB_SPORTS_CTR_A2-V": "v_2019-08-12_04:10" //  具体模型版本
  },
  "context": {                                  //  上下文特征
    "timestamp": 1554480000,
    "os_sys": "android",
    "os_version": "9.0",
    "app_version": "8.18.8",
    "query_str": "天气",
    ...
  },
  "user": {                                     //  用户特征
    "uid": 119911,
    "age": 26,
    "sex": 1,
    "query_str_list": [{
      "query_str": "王者荣耀",
      "time": 1575944297
    },{
      "query_str": "凯打法",
      "time": 1575944291
    }],
    "click_list": [{
      "item_id": 123456,
      "time": 1575944297
    },{
      "item_id": 654321,
      "time": 1575944291
    }],
    ...
  },
  "item": {                                     //  物品特征
    "item_id": 13579,
    "ctr": 0.03592834,
    ...
  }
}

如上结构为笔者实际项目中的特征打印结构,特征主要有 CUI 三大部分组成,分别对应上下文(Context)、用户(User)和物品(Item),其他还包括场景标识、关联 key、模型名称、模型版本,模型打分及子模型相关映射信息(若存在子模型)等等;其中关联 key 作为特征与 label 关联时用,模型相关信息主要作为线上线下一致性校验所用。

特征打印的特征都从何而来?

上下文特征主要从请求的参数构成,而用户特征和物品特征主要从离线算好或者实时计算好的特征库中获取。

特征库的技术选型主要是 Redis,根据业务数据量及性能的要求可以选择单机或者集群版本。Redis 相关的内容可以参考 Redis 集群方案

客户端埋点

按漏斗思维和用户的行为路径将埋点进行大致分类拆解,示例如下:

埋点事件分类
具体事件 大类别 描述
曝光事件 曝光 曝光
点击事件 点击 点击
图文浏览 浏览 图文浏览
视频播放 浏览 视频播放
音频播放 浏览 音频播放
点赞事件 功能 点赞按钮
点踩事件 功能 点踩(不喜欢)按钮
收藏事件 功能 点收藏
分享事件 功能 分享
关注事件 功能 关注
取关事件 功能 取消关注
列表页停留时长事件 时长 客户端列表页停留时长
详情页停留时长事件 时长 客户端详情页停留时长

除了上述列出的典型事件外,还有一些其他事件,如刷新事件、通知事件等。

值得注意的是,上述事件的埋点信息中,需要带有在线特征打印时设计的 request_key 字段,作为后续构造样本时关联 key 所用

关联之后的样本

将如上的在线特征打印和客户端相关行为埋点进行关联后的文件示例如下:

1       {"scene":"sports_search","request_key":"6a0c763e2984b873bda82ecd9f27184c_9873481",
"rank_model_name":"XGB_SPORTS_CTR_A2","rank_model_score":0.06298167354,"rank_model_score_map":
{"XGB_SPORTS_CTR_A2-SCORE":0.06298167354,"XGB_SPORTS_CTR_A2-V":"v_2019-08-12_04:10"},"context":
{"timestamp":1554480000,"os_sys":"android","os_version":"9.0","app_version":"8.18.8"},"user":
{"uid":119911,"age":26,"sex":1,"click_list":[{"item_id":123456,"time":1575944297},
{"item_id":654321,"time":1575944291}]},"item":{"item_id":13579,"ctr":0.03592834}}
0       {"scene":"sports_search","request_key":"ddc986b218f82099e1dca76e2a5ea873_8169847",
"rank_model_name":"XGB_SPORTS_CTR_A2","rank_model_score":0.05298167354,"rank_model_score_map":
{"XGB_SPORTS_CTR_A2-SCORE":0.05298167354,"XGB_SPORTS_CTR_A2-V":"v_2019-08-12_04:10"},"context":
{"timestamp":1554490000,"os_sys":"android","os_version":"9.0","app_version":"8.18.8"},"user":
{"uid":119912,"age":27,"sex":2,"click_list":[{"item_id":1234567,"time":1575985297},
{"item_id":7654321,"time":1575984291}]},"item":{"item_id":13879,"ctr":0.05592834}}

如上所示,每行的一个值是对应的 label,如 1 是点击,0 是未点击。

构造 libsvm 文件

这里简单介绍以下 libsvm 数据格式,具体如下:

[label] [index1]:[value1] [index2]:[value2] …
[label] [index1]:[value1] [index2]:[value2] …

label:目标值,就是说 class(属于哪一类),就是你要分类的种类,通常是一些整数。

index:是有顺序的索引,通常是连续的整数。就是指特征编号,必须按照升序排列。

value:就是特征值,用来 train 的数据,通常是一堆实数组成。

下一步是需要将上述的关联之后的样本转换为 libsvm 数据格式,下面给出以 java 实现的转换代码示例:

package demo;

import ch.hsr.geohash.GeoHash;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;

import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Writer;
import java.time.Instant;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/**
 * 特征处理
 **/
public class FeatureProcessDemo {

    //  构造的特征 key 与 特征 index 的映射
    public Map<String, Long> allFeatureIndexMap = new HashMap<>(50000);

    public static void main(String[] args) throws IOException {

        //  带有 label 的原始样本
        String originalFileName = "/ml/original_label_data";
        //  commons-io
        List<String> lines = IOUtils.readLines(new FileReader(originalFileName));

        if (CollectionUtils.isEmpty(lines)) {
            return;
        }

        //  libsvm 格式的样本
        String libsvmFileName = "/ml/libsvm_label_data";
        FeatureProcessDemo featureProcess = new FeatureProcessDemo();
        for (String line : lines) {
            String[] lineArr = line.split("\t");
            if (ArrayUtils.isNotEmpty(lineArr)) {
                //  特征处理
                String lineFeature = featureProcess.parseLineForOffline(lineArr[1]);
                if (StringUtils.isBlank(lineFeature)) {
                    // TODO: 2020-02-03 warn log
                    continue;
                }
                try {
                    Writer writer = new FileWriter(libsvmFileName, true);
                    IOUtils.writeLines(Lists.newArrayList(lineArr[0] + " " + lineFeature), "", writer);
                } catch (Exception e) {
                    // TODO: append error log
                }

            }

        }

        if (!featureProcess.allFeatureIndexMap.isEmpty()) {
            //  保存特征 key 与 index 映射
            String featureIndexFileName = "/ml/feature_index_data";
            String featureIndexData = JSON.toJSONString(featureProcess.allFeatureIndexMap);
            IOUtils.write(featureIndexData, new FileWriter(featureIndexFileName));
        }

    }

    public String parseLineForOffline(String lineJson) {

        JSONObject lineObject = JSON.parseObject(lineJson);

        Map<String, Float> contextInfoMap = parseContextInfo(lineObject);
        Map<String, Float> userInfoMap = parseUserInfo(lineObject);
        Map<String, Float> itemInfoMap = parseItemInfo(lineObject);

        Map<String, Float> featureMap = new HashMap<>();
        featureMap.putAll(contextInfoMap);
        featureMap.putAll(userInfoMap);
        featureMap.putAll(itemInfoMap);

        if (MapUtils.isNotEmpty(featureMap)) {

            //  每行的特征 index 与 value 的映射
            Map<Long, Float> lineFeatureMap = new TreeMap<>();

            for (Map.Entry<String, Float> entry : featureMap.entrySet()) {
                String key = entry.getKey();
                Float value = entry.getValue();

                Long index = allFeatureIndexMap.get(key);
                if (index == null) {
                    index = Long.valueOf(allFeatureIndexMap.size());
                    //  构造的特征 key 与 特征 index 的映射
                    allFeatureIndexMap.put(key, index);
                }
                lineFeatureMap.put(index, value);
            }

            StringBuilder stringBuilder = new StringBuilder();
            for (Map.Entry<Long, Float> entry : lineFeatureMap.entrySet()) {
                Long index = entry.getKey();
                Float value = entry.getValue();

                stringBuilder.append(index).append(":").append(value).append(" ");
            }

            return stringBuilder.toString().trim();
        }

        return null;
    }

    /**
     * 处理上下文特征
     */
    public static Map<String, Float> parseContextInfo(JSONObject lineObject) {

        Map<String, Float> map = new HashMap<>();

        JSONObject contextObj = lineObject.getJSONObject("context");

        if (MapUtils.isNotEmpty(contextObj)) {

            /**
             * 独热处理
             */

            Integer cityId = contextObj.getInteger("city_id");
            if (cityId != null) {
                map.put("c_city_id" + "_" + cityId, 1f);
            }

            String osSys = contextObj.getString("os_sys");
            if (StringUtils.isNotBlank(osSys)) {
                map.put("c_os_sys" + "_" + osSys, 1f);
            }

            String osVersion = contextObj.getString("os_version");
            if (StringUtils.isNotBlank(osVersion)) {
                map.put("c_os_version" + "_" + osSys, 1f);
            }

            String appVersion = contextObj.getString("app_version");
            if (StringUtils.isNotBlank(appVersion)) {
                map.put("c_app_version" + "_" + osSys, 1f);
            }

            // TODO: other and so on

            /**
             * geohash 处理
             */
            Double latitude = contextObj.getDouble("latitude");
            Double longitude = contextObj.getDouble("longitude");
            if (latitude != null && longitude != null) {
                GeoHash geoHash = GeoHash.withCharacterPrecision(latitude, longitude, 5);
                map.put("c_geo_hash" + "_" + geoHash.toBase32(), 1f);
            }

            /**
             * 时间戳分裂成多个特征
             */
            Integer timestamp = contextObj.getInteger("timestamp");
            if (timestamp != null) {
                Instant instant = Instant.ofEpochSecond(timestamp);
                ZonedDateTime zonedDateTime = instant.atZone(ZoneId.systemDefault());
                map.put("c_day_of_week", Float.valueOf(zonedDateTime.getDayOfWeek().getValue() - 1));
                int hourOfDay = zonedDateTime.getHour();
                map.put("c_hour_of_day", Float.valueOf(hourOfDay));
                int minutes10OfHour = zonedDateTime.getMinute() / 10;
                map.put("c_minutes10_of_hour", Float.valueOf(minutes10OfHour));
                map.put("c_minutes10_of_day", Float.valueOf(hourOfDay * 6 + minutes10OfHour));
            }

            /**
             * query 分类特征
             */
            Integer queryClassification = contextObj.getInteger("query_classification");
            if (queryClassification != null) {
                map.put("c_query_classification", Float.valueOf(queryClassification));
            }

        }

        return map;
    }

    /**
     * 处理用户特征
     */
    public static Map<String, Float> parseUserInfo(JSONObject lineObject) {

        Map<String, Float> map = new HashMap<>();

        JSONObject userObj = lineObject.getJSONObject("user");

        if (MapUtils.isNotEmpty(userObj)) {

            Long uid = userObj.getLong("uid");
            if (uid != null) {
                map.put("u_uid", Float.valueOf(uid));
            }

            Integer age = userObj.getInteger("age");
            if (age != null) {
                map.put("u_age", Float.valueOf(age));
            }

            Integer sex = userObj.getInteger("sex");
            if (sex != null) {
                map.put("u_sex", Float.valueOf(sex));
            }

            Integer lastActive = userObj.getInteger("last_active");
            if (lastActive != null) {
                map.put("u_last_active", Float.valueOf(lastActive));
            }

            // TODO: 2020-02-02 other and so on

            /**
             * 对数变换
             */
            Integer registerDay = userObj.getInteger("register_day");
            if (registerDay != null && registerDay > 0) {
                map.put("u_register_day", (float) Math.log(registerDay + 1d));
            }

            JSONArray queryStrList = userObj.getJSONArray("query_str_list");
            if (CollectionUtils.isNotEmpty(queryStrList)) {
                for (int i = 0; i < queryStrList.size(); i++) {
                    JSONObject queryStrObj = queryStrList.getJSONObject(i);
                    String queryStr = queryStrObj.getString("query_str");
                    if (StringUtils.isNotBlank(queryStr) && queryStr.length() <= 15) {
                        map.put("u_query_str_" + queryStr, 1f);
                    }
                }
            }

            JSONArray clickList = userObj.getJSONArray("click_list");
            if (CollectionUtils.isNotEmpty(clickList)) {
                for (int i = 0; i < clickList.size(); i++) {
                    JSONObject clickObj = clickList.getJSONObject(i);
                    Long itemId = clickObj.getLong("item_id");
                    if (itemId != null) {
                        map.put("u_click_" + itemId, 1f);
                    }
                }
            }

        }

        return map;
    }

    /**
     * 处理 item 特征
     */
    public static Map<String, Float> parseItemInfo(JSONObject lineObject) {

        Map<String, Float> map = new HashMap<>();

        JSONObject itemObj = lineObject.getJSONObject("item");

        if (MapUtils.isNotEmpty(itemObj)) {

            Long itemId = itemObj.getLong("item_id");
            if (itemId != null) {
                map.put("i_item_id_" + itemId, 1f);
            }

            Integer itemType = itemObj.getInteger("item_type");
            if (itemType != null) {
                map.put("i_item_type", Float.valueOf(itemType));
            }

            Double ctr = itemObj.getDouble("ctr");
            if (ctr != null) {
                map.put("i_ctr", (float) (ctr + 1));
            }

            Double ctr15 = itemObj.getDouble("ctr_15");
            if (ctr15 != null) {
                map.put("i_ctr15", (float) (ctr15 + 1));
            }

            Integer pv = itemObj.getInteger("pv");
            if (pv != null && pv > 0) {
                map.put("i_pv", (float) Math.log(pv + 1d));
            }

            Integer like = itemObj.getInteger("like");
            if (like != null && like > 0) {
                map.put("i_like", (float) Math.log(like + 1d));
            }

            Integer click = itemObj.getInteger("click");
            if (click != null && click > 0) {
                map.put("i_click", (float) Math.log(click + 1d));
            }

            // TODO: other and so on
        }

        return map;
    }

}

XGBoost

XGBoost 支持除了 libsvm 格式的数据外,还支持 numpy 数组形式及 XGBoost 的二进制的缓存文件等,这里只列出 libsvm 格式的示例,更详细的 XGBoost 的内容,参考机器学习之 XGBoost

离线训练

下面给出 XGBoost 的实战示例:

import xgboost as xgb
from sklearn.metrics import roc_auc_score
from xgboost import plot_importance
from matplotlib import pyplot as plt

params = {
    'booster': 'gbtree',  # 采用 gbt 树模型
    'objective': 'binary:logistic',  # 二分类的逻辑回归问题,输出为概率
    'eval_metric': 'auc',  # 校验数据所需要的评价指标
    'eta': 0.1,  # 如同学习率
    'gamma': 0.1,  # 用于控制是否后剪枝的参数,越大越保守,一般 0.1、0.2
    'max_depth': 6,  # 构建树的深度,越大越容易过拟合
    'alpha': 1,  # L1 正则化参数,参数越大,越不容易过拟合
    'lambda': 2,  # L2正则化参数,参数越大,越不容易过拟合
    'subsample': 1.0,  # 用于训练模型的子样本占整个样本集合的比例
    'colsample_bytree': 1.0,  # 在建立树时对特征采样的比例
    'silent': 0,  # 设置成 1,则没有运行信息输出,最好是设置为 0
    'seed': 1000,  # 随机数的种子
    'nthread': 4,  # XGBoost 运行时的线程数
    'n_jobs': 15,
    'scale_pos_weight': 1  # 取值大于0,在类别样本偏斜时,有助于快速收敛
}

# 对样本集进行训练
train_file = '/ml/libsvm_label_data'
d_train = xgb.DMatrix(train_file)
num_rounds = 150
model = xgb.train(params, d_train, num_rounds)

# 对测试集进行预测
test_file = '/ml/libsvm_label_data_test'
d_test = xgb.DMatrix(test_file)
test_pred = model.predict(d_test)

# auc
auc = roc_auc_score(d_test.get_label(), test_pred)
print('auc', auc)

# save model
model.save_model('/ml/xgb.model')

# 显示重要特征
plot_importance(model)
plt.show()