推荐系统 教程

推荐系统 召回

推荐系统 笔记

基于 kv(如 redis)召回源,基于召回策略配置的多路策略 java 实现

推荐系统 笔记 推荐系统 笔记


推荐系统的策略召回往往有很多种大策略,如个性化的策略(协同过滤等等)、探索策略(新内容等)、冷启动、地域热门、兴趣召回等等。

背景

如上的召回策略的数据基于 kv 数据库存储,这里以 redis 为例;在实际的推荐系统中,一次的召回由多个子策略组成,这需要我们根据策略 id 配置相应子策略的召回比例,不足的子策略 item 个数,需要分摊给其他子策略来补充,若还不够数,还需通过打底数据来补充。

配置设计

根据 AB 分流策略,命中的流量策略采用如下召回策略配置:

{
  "strategy_id": 188, //  召回策略 id
  "news": [ //  新闻类
    {
      "key": "personalized_{uid}",  //  个性化
      "id": "2",
      "ratio": 0.35
    },
    {
      "key": "explore", //  探索类
      "id": "0",
      "ratio": 0.05
    },
    {
      "key": "cold_start",  //  冷启动
      "id": "100",
      "ratio": 0.1
    },
    {
      "key": "follow",  //  关注
      "id": "87",
      "ratio": 0.15
    },
    {
      "key": "hot", //  热门
      "id": "66",
      "ratio": 0.35
    }
  ],
  "video": [  //  视频
    {
      "key": "personalized_{uid}",
      "id": "10",
      "ratio": 0.5
    },
    {
      "key": "personalized_{uid}",
      "id": "51",
      "ratio": 0.1
    },
    {
      "key": "hot",
      "id": "48",
      "ratio": 0.35
    },
    {
      "key": "live",  //  直播
      "id": "43",
      "ratio": 0.05
    }
  ],
  "photo": [  //  图片
    {
      "key": "hot",
      "id": "25",
      "ratio": 0.75
    },
    {
      "key": "explore",
      "id": "102",
      "ratio": 0.25
    }
  ]
}

上图配置中的 key 代表一类召回策略,对应 redis 哈希结构的 hash_key,id 为子策略,对应 redis 哈希结构的 field,ratio 表示召回一路中的条数占比。

多路召回策略实现

召回配置信息类,如下:

import lombok.Data;

/**
 * 召回配置信息 VO
 **/
@Data
public class RatioBo {

    private String key;
    private String id;
    private float ratio;

}

召回配置的实际召回条数调节、获取配置信息,采用召回配置信息包装类,如下:

import java.util.*;

/**
 * 召回配置信息包装类-统一处理
 **/
public class RatioBoWrapper {

    private static final int fetchLimit = 50;

    private Map<String, RatioBo> subStrategyRatioBoMap = new HashMap<>();
    private Set<String> fetchedSubStrategy = new HashSet<>();
    private int fetchCount = 0;

    /**
     * 配置初始化
     */
    public void init(List<RatioBo> ratioBoList) {
        for (RatioBo ratioBo : ratioBoList) {
            subStrategyRatioBoMap.put(ratioBo.getId(), ratioBo);
        }
    }

    /**
     * 归一化比例
     */
    public void refreshRatioBo() {
        float totalLeftRatio = 0f;
        for (Map.Entry<String, RatioBo> entry : subStrategyRatioBoMap.entrySet()) {
            String key = entry.getKey();
            if (fetchedSubStrategy.contains(key)) {
                continue;
            }
            RatioBo ratioBo = entry.getValue();
            float ratio = ratioBo.getRatio();
            totalLeftRatio += ratio;
        }

        if (totalLeftRatio > 0) {
            for (Map.Entry<String, RatioBo> entry : subStrategyRatioBoMap.entrySet()) {
                String key = entry.getKey();
                if (fetchedSubStrategy.contains(key)) {
                    continue;
                }
                RatioBo ratioBo = entry.getValue();
                float ratio = ratioBo.getRatio();
                ratioBo.setRatio(ratio / totalLeftRatio);
            }
        }

    }

    /**
     * 拉取一个配置
     */
    public RatioBo fetchRatioBo() {
        if (fetchCount >= fetchLimit) {// 容错保护
            return null;
        }
        try {
            for (Map.Entry<String, RatioBo> entry : subStrategyRatioBoMap.entrySet()) {
                String subStrategyId = entry.getKey();
                if (fetchedSubStrategy.contains(subStrategyId)) {
                    continue;
                }
                RatioBo ratioBo = entry.getValue();
                fetchedSubStrategy.add(subStrategyId);
                return ratioBo;
            }

            return null;
        } finally {
            fetchCount++;
        }

    }

}

融合上述的两个类,召回 java 实现样例如下:

public List<RecallBo> recallCandidateInfo(RecallContext recallContext, int contentType, int reqSize, List<RatioBo> ratioBoList) {

        FeedRecRecallContext feedRecRecallContext = (FeedRecRecallContext) recallContext;
        HomeFeedRequest homeFeedRequest = feedRecRecallContext.getHomeFeedRequest();
        long uid = homeFeedRequest.getUid();

        //  曝光去重
        Set<Long> recommendedItemSet = recommendedItemSet(uid, contentType, feedRecRecallContext.getFilterDuplicateDays());

        //  全局黑名单
        Set<Long> globalBlackListSet = localCache.fetchGlobalBlackListSet(contentType);

        //  个人负反馈
        Set<Long> personalDegenerativeFeedbackItemSet = personalDegenerativeFeedbackItemSet(uid, contentType, feedRecRecallContext.getFilterDuplicateDays());

        RatioBoWrapper ratioBoWrapper = new RatioBoWrapper();

        /**
         * 初始比例归一化
         */
        ratioBoWrapper.init(ratioBoList);
        ratioBoWrapper.refreshRatioBo();

        List<RecallBo> recallBoList = new ArrayList<>();

        int leftTotalSize = reqSize;
        RatioBo ratioBo;
        while ((ratioBo = ratioBoWrapper.fetchRatioBo()) != null) {
            float ratio = ratioBo.getRatio();
            String subRecallIdStr = ratioBo.getId();
            String key = ratioBo.getKey();
            if (key.contains("{uid}")) {//  个性化
                key = key.replace("{uid}", String.valueOf(uid));
            }
            int wantSize = (int) Math.ceil((leftTotalSize * ratio));
            if (wantSize <= 0) {//  最少召回1条
                wantSize = 1;
            }
            List<RecallBo> originalList;
            if (key.equals("hot")) {
                originalList = localCache.getBkDataNew().get(subRecallIdStr);
            } else if (key.equals("cold_start")) {
                originalList = localCache.getCsDataNew().get(subRecallIdStr);
            } else if (key.equals("explore")) {
                originalList = localCache.getExpDataNew().get(subRecallIdStr);
            } else {
                Object object = recallRedisTemplate.opsForHash().get(key, subRecallIdStr);
                if (object == null) {
                    continue;
                }
                String result = (String) object;
                originalList = convertRecallBoFromString(Integer.parseInt(subRecallIdStr), result);
            }

            if (CollectionUtils.isEmpty(originalList)) {
                continue;
            }

            int fetchedSize = 0;
            for (RecallBo originalBo : originalList) {
                Long id = originalBo.getId();
                if (recommendedItemSet.contains(id)) {
                    continue;
                }
                if (globalBlackListSet.contains(id)) {
                    continue;
                }
                if (personalDegenerativeFeedbackItemSet.contains(id)) {
                    continue;
                }
                fetchedSize++;
                /**
                 * 如果能保证线程安全(确保后续对 RecallBo 不会 write 操作)的话,可以去掉深拷贝
                 */
                RecallBo deepCopy = new Cloner().deepClone(originalBo);
                recallBoList.add(deepCopy);
                if (fetchedSize >= wantSize) {
                    break;
                }
            }

            leftTotalSize -= fetchedSize;
            if (leftTotalSize <= 0) {
                break;
            }

            //  归一化剩下的比例
            ratioBoWrapper.refreshRatioBo();
        }

        if (recallBoList.size() < reqSize) {//  打底召回
            int left = reqSize - recallBoList.size();
            String subRecallIdStr = RedisConstants.ES_CONTENT_TYPE_ID_MAP.get(contentType);
            if (StringUtils.isNotBlank(subRecallIdStr)) {
                Set<RecallBo> esDataSet = localCache.getEsDataNew().get(subRecallIdStr);
                if (CollectionUtils.isNotEmpty(esDataSet)) {
                    int count = 0;
                    for (RecallBo originalBo : esDataSet) {
                        Long id = originalBo.getId();
                        if (recommendedItemSet.contains(id)) {
                            continue;
                        }
                        if (globalBlackListSet.contains(id)) {
                            continue;
                        }
                        if (personalDegenerativeFeedbackItemSet.contains(id)) {
                            continue;
                        }
                        count++;
                        /**
                         * 如果能保证线程安全(确保后续对 RecallBo 不会 write 操作)的话,可以去掉深拷贝
                         */
                        RecallBo deepCopy = new Cloner().deepClone(originalBo);
                        recallBoList.add(deepCopy);
                        if (count >= left) {
                            break;
                        }
                    }
                }
            }
        }

        return recallBoList;

    }