Spark ML 15.协同过滤


1.协同过滤

算法介绍:

协同过滤常备用于推荐系统。 这类技术目标在于填充“用户 - 商品”联系矩阵中的缺失项。 Spark.ml目前支持基于模型的协同过滤, 其中用户和商品以少量的潜在因子来描述, 用以预测缺失项。 Spark.ml使用交替最小二乘(ALS) 算法来学习这些潜在因子。

注意基于 DataFrame 的ALS接口目前仅支持整数类型的用户和商品编号。

显示与隐式反馈

基于矩阵分解的协同过滤的标准方法中, “用户 - 商品” 矩阵中的条目是用户给予商品的显示偏好, 例如, 用户给电影评级。 然而在现实世界中使用时, 我们常常只能访问隐式反馈(如意见、点击、购买、喜欢以及分享等等) ,在spark.ml中我们使用 “隐式反馈数据集的协同过滤” 来处理这类数据。本质上来说它不是直接对评分矩阵进行建模, 而是将数据当做数值来看待, 这些数值代表用户行为的观察值(如点击次数, 用户观看一部电影的持续时间)。 这些数值被用来衡量用户偏好观察值的置信水平, 而不是显式地给商品一个评分。 然后, 模型用来寻找可以预测用户对商品预期偏好的潜在因子。

正则化参数

我们调整正则化参数regParam 来解决用户在更新用户因子时产生新评分或者商品更新商品因子时受到的新评分带来的最小二乘问题。 这个方法叫做“ALS-WR” 它降低regParam 对数据集规模的依赖, 所以我们可以从部分子集中学习到的最佳参数应用到整个数据集中时获得同样的性能。

2. 参数说明

参数名称 类型 说明
userCol 字符串 用户列名
itemCol 字符串 商品编号列名
predictionCol 字符串 预测结果列名
ratingCol 字符串 评分列名
implicitPrefs 布尔型 特征列名
alpha 双精度 隐式偏好中的alpha参数(非负)
checkpointInterval 整数 设置检查点间隔(>=1),或不设置检查点(-1)
maxIter 整数 迭代次数(>=0)
nonnegative 布尔 是否需要非负约束
numItemBlocks 整数 商品数目(正数)
numUserBlocks 整数 用户数目(正数)
rank 整数 分解矩阵的排名(正数)
regParam 双精度 正则化参数(>=0)
seed 长整型 随机种子

3. 使用示例

调用示例:

下面的例子中, 我们从 MovieLens dataset 读入评分数据, 每一行包括用户, 电影、评分以及时间戳。 我么你默认其排序是显示的来训练ALS模型。 我们通过预测评分的均方根误差来评价推荐模型。 如果评分矩阵来自其他信息来源, 也可将implicitPrefs设置为true 来获得更好的结果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession

object ALS extends App{
val conf = new SparkConf().setAppName("a_ALS")
//设置master local[4] 指定本地模式开启模拟worker线程数
conf.setMaster("local[4]")
//创建sparkContext文件
val sc = new SparkContext(conf)
val spark = SparkSession.builder().getOrCreate()
sc.setLogLevel("Error")

import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS
import spark.implicits._
case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long)
def parseRating(str: String): Rating = {
val fields = str.split("::")
assert(fields.size == 4)
Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong)
}

val ratings = spark.read.textFile("D:\\data\\sample_movielens_ratings.txt")
.map(parseRating)
.toDF()
ratings.show(3)
/
+------+-------+------+----------+
|userId|movieId|rating| timestamp|
+------+-------+------+----------+
| 0| 2| 3.0|1424380312|
| 0| 3| 1.0|1424380312|
| 0| 5| 2.0|1424380312|
+------+-------+------+----------+
/
val Array(training, test) = ratings.randomSplit(Array(0.8, 0.2))

// Build the recommendation model using ALS on the training data
val als = new ALS()
.setMaxIter(5)
.setRegParam(0.01)
.setUserCol("userId")
.setItemCol("movieId")
.setRatingCol("rating")
val model = als.fit(training)

// Evaluate the model by computing the RMSE on the test data
val predictions = model.transform(test)

val evaluator = new RegressionEvaluator()
.setMetricName("rmse")
.setLabelCol("rating")
.setPredictionCol("prediction")
val rmse = evaluator.evaluate(predictions)
println(s"Root-mean-square error = $rmse")
//Root-mean-square error = 1.7213189727198661
}

文章作者: hnbian
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 hnbian !
评论
  目录