「机器学习论文解读」:推荐算法-xDeepFM

一起看看CIN的巧妙之处

Posted by Wenqian on April 11, 2020

论文简介

这篇文章发表于KDD 18,对DeepFM的FM端进行了改进(也可以理解为是对现有的显式交叉特征构建+DNN结构的这些算法的一种改进)。其最大的贡献在于提出了CIN(Compressed Interation Network)来显式地生成向量维度(vector-wise)的交叉特征。

向量维度(vector-wise)

什么是向量维度?

在介绍CIN之前,我们需要先搞清楚作者说的vector-wise是什么。以我个人的理解,向量维度的交叉特征指的是在向量和向量之间生成交叉特征,这里的向量一般是指embedding向量或者隐向量。与之相对的,bit-wise则是指在上述向量的元素之间生成交叉特征。

具体来说,FM可以被认为是向量维度的特征交叉,因为特征间的交叉其实就是隐向量的内积;而DCN的交叉网络可以理解为是比特维度的特征交叉,因为那里面的交叉的单位是embedding向量(或者后续中间向量)的某一项(DCN用的是外积)。另外,像FNN直接将embedding的结果输入全连接层,因此也是典型的比特维度的特征交叉。

为什么要生成向量维度的交叉特征?

我个人理解是向量维度相比比特维度粒度更粗,由于不同向量代表不同特征,因此生成的无用交叉特征数量会更少,可解释性更强;相比比特维度可能没那么容易过拟合(虽然在某些场景下还是很容易过拟合)。

CIN的结构

我们先来看一下CIN的整体结构是什么样的。

img

一个直观的感觉就是:有点像RNN,只不过输入永远是X0。下面的X0是原始特征经过embedding之后的结果,其中m表示原始特征数,D表示每个特征的embedding向量的维度。从上面也可以看到每一层的结果都是由上一层的结果和x0做交叉得到的。这些结果经过Sum pooling后被拼接到一起,因此可以得到k+1维度下的所有交叉特征(类似于DCN中的xl)。另外可以注意到,每一层的宽度是不一样的(即图中的Hk),那么Xk是到底是怎么和X0做交叉的呢?其原理就是下面的公式: img

h表示Hk中的某一行,o符号表示Hadamard product,即对位乘法。例如<a1, a2, a3>o<b1, b2, b3> = <a1b1, a2b2, a3b3>。W是我们要学习的参数矩阵。通过上面的式子我们可以看到,前一层的结果其实通过和X0进行交叉后变成了一个向量(维度还是D),并且由于对位乘法的特点,第n维(0<n<=d)的结果只包含Xk-1和X0的第n维特征的交叉。对于这一个维度来说,其结果可以看成是Xk-1和X0的对应维度的外积。放大到整个向量,那么可以理解为是矩阵粒度的“外积”,即作者所说的向量维度的交叉。通过作者给出的这张图可能能看得更清楚一些: img

我们也可以把上图中的Z^k+1看成是一个D通道的图像,那么W矩阵就可以看成是一种特殊的滤波器。每个滤波器滤波后的结果(也就是前面说的那个向量)就对应于CV中的特征图(feature map)。所以Xk可以看做是Hk个不同的特征图的集合。这里把原始的Hk-1*m个向量压缩成了Hk个向量(维度为D),所以CIN里面有“Compressed”,即对空间进行了压缩。下面这个图可能更直观一些: img

CIN相关分析

复杂度分析

从空间复杂度来说,CIN的参数量小于DNN,是一个可接受的量级。如果采用矩阵分解对W矩阵进行拆分,那么空间复杂度会更低。不过CIN的时间复杂度高于DNN,这也是其最大的一个问题。

交叉项参数个数分析

对于一个传统的k阶多项式,其拥有O(m^k)个参数,不过CIN只有km^3个参数。因此和DCN或者FM一样,这里其实也用到了参数共享的思想

和DeepFM的关系

当CIN的深度以及feature map的个数(H1)都设置为1时,xDeepFM其实就是DeepFM的一种泛化(generalization),只不过在DeepFM中,FM层的单元是直接连接到输出的,而这里会引入参数(即下面的Wcin): img

如果不考虑DNN部分,并且同时使用一个constant sum filter(即直接把输入相加,而不引入任何参数)对特征图的结果进行累加,那么其实xDeepFM就变成了一个传统的FM。这是因为对位乘+sum其实就是内积。

CIN每层的Hk如何设置

Hk在CIN的结构里有很大的灵活度,每一层都可以不一样。但这也带来了一个问题,那就是怎么调这个参数。从作者的实验来看,这东西还是和数据集有关,也就是很玄。有的数据集上可能越多越好,而有的则可能多了也没什么提升。不过需要注意的是,CIN的空间复杂度(假如不用矩阵分解)和时间复杂度都和H呈平方关系,所以H还是要尽量小一些

整体结构

最后我们再来看一下整个网络的结构:

img

原始输入经过embedding后分别输入CIN和DNN,得到两者的结果后与原始输入一起放入逻辑回归中得到y的概率。可以看到,这个结构和DeepFM很类似,区别只在于CIN部分。这也说明CIN是这篇文章的精髓所在。