0%

双塔模型

双塔模型

本笔记为王树森发表在youtube频道推荐系统的视频课程学习随笔,笔记中图片亦截取自王树森视频课件,原始资料出处为王树森youtube视频。其中有我从其他途径得到的一些内容和我自己的见解,未必全部都准确正确。

双塔模型因为两个垂直网络象形而得名, 这两根塔分别是用户网络和物品网络。

结构

双塔模型因为两个垂直网络象形而得名, 这两根塔分别是用户网络和物品网络。实际的训练过程中,可能存在不止一根“塔”,但是只有两类。

用户网络

用户网络将用户ID做Embedding, 离散特征做Embedding, 剩余的一些连续特征先做归一化和分桶等处理方式然后concatenate到一起输送给神经网络。得到用户的隐向量表征。向少分量的离散向量例如性别之类的,就不用做embeding,直接使用onehot即可。

用户网络

物品网络

物品网络和用户网络基本一样,同样的需要把物品ID做Embedding。

我个人感觉把ID做Embedding好没有道理,本质上物品的ID,它并不直接包含任何特征,可能仅仅是为了缩小输入规模所以做embeding?

物品网络
双塔模型

双塔模型的训练

  • Pointwise: 独立看待每个正样本、负样本,做简单的二元分类
  • Pairwise:训练时使用一个用户特征向量,一正一负共两个物品特征向量。
  • Listwise:训练时使用一个用户特征向量,一正多负多个物品特征向量。

正负样本的选择

  • 正样本: 正样本的选择其实很明确,用户点击(点赞收藏转发)过的肯定是正样本。
  • 负样本:负样本按级别分几类,简单负样本,困难负样本等。

下面先从总体概览的角度上看一下样本到底有哪些,是怎么来的。

以全部的库存物品作为样本的全集,那么:

  • 没有被召回的物品,大概率就是不感兴趣的了,为负样本。

  • 被召回的物品,我们需要再分情况讨论:

    • 被曝光的且被用户点击/点赞/收藏 的,自然为正样本,这个无疑是无可置疑的。
    • 被曝光的,但是用户未点击的。
    • 未被曝光的(通过召回但却被粗排精排过滤掉的)。

Pointwise 训练

把找回看作二元分类任务。

  • 对于正样本,鼓励 \(cos(a, b)\) 接近+1。
  • 对于负样本,鼓励 \(cos(a, b)\) 接近-1。
  • 控制正负样本的数量为 1:2或者1:3 (经验数值)。

Pairwise 训练

输入是一个三元组,输入两个样本数据(一正一副),一个用户数据。 用户数据的网络输出分别与两个物品样本数据的网络输出做cos距离。这种三元训练与人脸识别三元损失中是一样的。

这里需要提一下,图中绘制了两个物品网络,实际上它们是参数共享的,应为一个结构,样本在其中传输了两次。

Pairwise训练

基本想法:鼓励\(cos(a, b^+)\) 大于\(cos(a, b^-)\)

使用cos距离时,输出值约接近1则表明两个向量越接近,而-1则表明越远。

  • 如果\(cos(a, b^+)\)大于\(cos(a, b^-) + m\),则没有损失。
  • 否则, 损失等于 \(cos(a, b^-) + m - cos(a, b^+)\)

Triplet hinge loss损失:

\[ L(a, b^+, b^-) = \max(0, cos(a, b^-) + m - cos(a, b^+))) \]

其中,向量a是用户向量,向量\(b^+\)\(b^-\)分别是物品的正负样本向量。

训练过程即使用梯度下降算法对损失函数最小化。

Triplet logistic loss损失:

\[ L(a, b^+, b^-) = \log(1+\exp[\sigma\cdot(\cos(a, b^-)-\cos(a, b^+))]) \]

Listwise训练

Listwise方法训练时,每次取一个正样本和多个负样本。

  • 一条数据包含:

    • 一个用户,特征向量记作\(a\)
    • 一个正样本,特征向量记作 \(b^+\)
    • 多个负样本,特征向量记作\(b^-_1,\cdots,b^-_n\)
  • 鼓励\(\cos(a, b^+)\) 尽量大。

Listwise

图示已经很清晰的揭示了Listwise训练的原理, 让正样本对应输出的概率接近于1,而所有负样本输出的概率接近于0,使用交叉熵损失即可进行训练。

总结

  • 用户塔和物品塔各输出一个向量。

  • 两个向量的余弦相似度作为兴趣的预估值。

  • 三种训练方式:

    • Pointwise: 每次一个用户、一个物品(可正可负)。
    • Pairwise: 每次一个用户、一个正样本、一个负样本。
    • Listwise:每次一个用户、一个正样本、多个负样本。

一个错误的网络结构示例

先上一个反例图,然后再分析其中原因。

wrong_tower

上图中,明显的,在我们获取了用户和物品的特征向量之后,又将它们拼接通入了一个神经网络获得输出。但从是否work的角度来说是没有问题的,但是结合工程实际就不行了,表现在:如果我们获得了一个用户的特征向量,想要获得最近邻的物品,我们需要把每个物品的特征向量都拿来通过神经网络计算兴趣值,往往工程中物品数量是个巨大的数字,这种做法无法满足实时性的要求。

所以,我们在召回模型中,必须能够通过用户特征可以快速的获得最近邻的物品,那么我们必须满足不能在查表之后再经过神经网络了。