1. Logistic Regression
Logistic regression (Logistic Regression) is a popular method for predicting a categorical response. It is a special case of Generalized Linear Models that predicts the probability of each outcome class. In spark.ml, logistic regression can predict a binary outcome with binomial logistic regression, or a multiclass outcome with multinomial logistic regression. The family parameter selects between the two algorithms; when it is left unset, Spark infers the appropriate variant on its own.
Multinomial logistic regression can also be used for binary classification by setting the family parameter to "multinomial". It then produces two sets of coefficients and two intercepts. When a LogisticRegressionModel is fit on a dataset with a constant nonzero column, Spark MLlib outputs zero coefficients for that column. This behavior matches R's glmnet but differs from LIBSVM.
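The "auto" selection rule described above can be sketched as a tiny plain-Scala helper. The object and function names here are hypothetical illustrations; the real selection happens inside Spark's LogisticRegression implementation.

```scala
// Hypothetical sketch of Spark's "auto" family selection rule:
// with one or two label classes it trains a binomial model,
// otherwise a multinomial (softmax) model.
object FamilySelection {
  def selectFamily(numClasses: Int): String =
    if (numClasses == 1 || numClasses == 2) "binomial"
    else "multinomial"
}
```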
2. Binomial Logistic Regression
2.1 Algorithm Overview
Binomial logistic regression (Binomial logistic regression) predicts the probability of a binary outcome. It is a special case of Generalized Linear Models (广义线性模型).
It is a linear model whose loss function is the logistic loss: $L(w;x,y):=\log(1 + \exp(-yw^Tx))$
For a binary classification problem, the algorithm produces a binary logistic regression model. Given a new data point, denoted by $x$, the model predicts through the logistic function:
$f(z) = \frac{1}{1+e^{-z}}$
where $z=w^Tx$. By default, if $f(w^Tx) > 0.5$ the outcome is positive, otherwise negative. Unlike linear SVMs, the raw output of logistic regression has a probabilistic interpretation (the probability that $x$ is positive).
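The decision rule above can be sketched in a few lines of plain Scala (no Spark involved); the names here are illustrative, not Spark API.

```scala
// Sketch of the binary decision rule: f(z) = 1 / (1 + e^{-z}) with z = w^T x;
// predict positive (1.0) when f(z) exceeds the threshold (0.5 by default).
object LogisticSketch {
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  def predict(w: Array[Double], x: Array[Double], threshold: Double = 0.5): Double = {
    val z = w.zip(x).map { case (wi, xi) => wi * xi }.sum // z = w^T x
    if (sigmoid(z) > threshold) 1.0 else 0.0
  }
}
```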
2.2 Parameters
Parameter | Type | Description |
---|---|---|
elasticNetParam | Double | elastic net mixing parameter, in range [0, 1] |
family | String | name of the family, describing the label distribution used in the model; default "auto". auto: selects the family automatically based on the number of classes: if numClasses == 1 || numClasses == 2, set to "binomial", otherwise set to "multinomial". binomial: binary logistic regression. multinomial: multinomial logistic (softmax) regression without pivoting |
featuresCol | String | features column name |
fitIntercept | Boolean | whether to fit an intercept term |
labelCol | String | label column name |
maxIter | Integer | maximum number of iterations (>= 0) |
predictionCol | String | prediction column name |
probabilityCol | String | column name for predicted class conditional probabilities |
regParam | Double | regularization parameter (>= 0) |
standardization | Boolean | whether to standardize the training features before fitting the model |
threshold | Double | threshold in binary classification prediction, in range [0, 1] |
thresholds | Double array | thresholds in multiclass prediction, used to adjust the probability of predicting each class |
tol | Double | convergence tolerance of the iterative algorithm |
weightCol | String | instance weight column name |
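How regParam and elasticNetParam interact follows the standard elastic-net formulation: with mixing parameter $\alpha$ and regularization strength $\lambda$, the penalty is $\lambda(\alpha\|w\|_1 + \frac{1-\alpha}{2}\|w\|_2^2)$. A minimal plain-Scala sketch of that penalty (illustrative names, not Spark API):

```scala
// Sketch of the elastic-net penalty controlled by regParam (lambda) and
// elasticNetParam (alpha): alpha = 0 gives pure L2 (ridge),
// alpha = 1 gives pure L1 (lasso).
object ElasticNetSketch {
  def penalty(w: Array[Double], lambda: Double, alpha: Double): Double = {
    val l1 = w.map(math.abs).sum          // ||w||_1
    val l2 = w.map(v => v * v).sum        // ||w||_2^2
    lambda * (alpha * l1 + (1.0 - alpha) / 2.0 * l2)
  }
}
```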
2.3 Code Example
package hnbian.spark.ml.algorithms.classification
import hnbian.spark.utils.{FileUtils, SparkUtils}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
/**
* @author hnbian
* @Description Binomial logistic regression prediction
* @Date 2019/1/3 15:33
**/
object LogisticRegression extends App{
val spark = SparkUtils.getSparkSession("LogisticRegression",4)
import spark.implicits._
val filePath = FileUtils.getFilePath("sample_libsvm_data.txt")
val training = spark.read.format("libsvm").load(filePath)
val lr = new LogisticRegression()
.setMaxIter(10) // maximum number of iterations
.setRegParam(0.3) // regularization parameter
// Elastic-net mixing ratio (default 0.0). Regularization generally uses two norms: L1 (lasso) and L2 (ridge). L1 is typically used to sparsify features; L2 guards against overfitting. This parameter sets the proportion of the L1 norm; the default 0.0 uses L2 only.
.setElasticNetParam(0.8)
//.setFamily("binomial")
.setThreshold(1.0)
// Fit the model
val lrModel = lr.fit(training)
// Print the coefficients and intercept of the logistic regression model
println(s"Coefficients: ${lrModel.coefficients} \nIntercept: ${lrModel.intercept}")
//Coefficients: (692,[244,263,272,300,301,328,350,351,378,379,405,406,407,428,433,434,455,456,461,462,483,484,489,490,496,511,512,517,539,540,568],[-7.35398352418814E-5,-9.102738505589432E-5,-1.9467430546904216E-4,-2.030064247348659E-4,-3.1476183314865005E-5,-6.842977602660699E-5,1.5883626898237813E-5,1.4023497091369702E-5,3.5432047524968963E-4,1.1443272898170924E-4,1.0016712383666388E-4,6.014109303795469E-4,2.8402481791227693E-4,-1.1541084736508769E-4,3.8599688631290956E-4,6.350195574241061E-4,-1.1506412384575594E-4,-1.5271865864986703E-4,2.8049338089942207E-4,6.070117471191611E-4,-2.0084596632474318E-4,-1.4210755792901163E-4,2.739010341160889E-4,2.7730456244968185E-4,-9.838027027269304E-5,-3.808522443517673E-4,-2.5315198008554816E-4,2.7747714770754383E-4,-2.443619763919179E-4,-0.0015394744687597863,-2.3073328411331095E-4])
//Intercept: 0.224563159612503
/**
* Obtain a summary of the model over the training set.
* Throws an exception if `trainingSummary == None` or if this is a multiclass model.
*/
val trainingSummary = lrModel.binarySummary
// Objective value of the loss at each iteration
val objectiveHistory = trainingSummary.objectiveHistory
objectiveHistory.foreach(loss => println(loss))
/**
* objectiveHistory:
* 0.6833149135741672
* 0.6662875751473734
* 0.6217068546034618
* 0.6127265245887887
* 0.6060347986802872
* 0.6031750687571563
* 0.5969621534836272
* 0.5940743031983121
* 0.5906089243339023
* 0.5894724576491043
* 0.5882187775729587
*/
// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
val roc = trainingSummary.roc
roc.show()
/**
* +---+--------------------+
* |FPR| TPR|
* +---+--------------------+
* |0.0| 0.0|
* |0.0|0.017543859649122806|
* |0.0| 0.03508771929824561|
* |0.0| 0.05263157894736842|
* |0.0| 0.07017543859649122|
* |0.0| 0.3157894736842105|
* |0.0| 0.3333333333333333|
* +---+--------------------+
*/
println(s"areaUnderROC: ${trainingSummary.areaUnderROC}")
// Set the model threshold to maximize the F-measure
val fMeasure = trainingSummary.fMeasureByThreshold
fMeasure.show(false)
/**
* +------------------+--------------------+
* |threshold |F-Measure |
* +------------------+--------------------+
* |0.7845860015371141|0.034482758620689655|
* |0.784319334416892 |0.06779661016949151 |
* |0.784297609251013 |0.1 |
* |0.7842531051133191|0.13114754098360656 |
* |0.7788060694625323|0.45945945945945943 |
* |0.7783754276111222|0.4799999999999999 |
* |0.7771658291080573|0.5 |
* |0.7769914303593917|0.5194805194805194 |
* +------------------+--------------------+
*/
import org.apache.spark.sql.functions.max
val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
println(s"maxFMeasure: ${maxFMeasure}")
//maxFMeasure: 1.0
// Find the threshold that attains the maximum F-measure
val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure)
.select("threshold").head().getDouble(0)
// Set the model threshold
lrModel.setThreshold(bestThreshold)
// Make predictions with the model
val predictions = lrModel.transform(training)
// Print the predictions
predictions.select("label","prediction","rawPrediction","probability").show(20,false)
/**
* +-----+----------+------------------------------------------+----------------------------------------+
* |label|prediction|rawPrediction |probability |
* +-----+----------+------------------------------------------+----------------------------------------+
* |1.0 |1.0 |[-1.2006427604100238,1.2006427604100238] |[0.23136089273325805,0.7686391072667419]|
* |1.0 |1.0 |[-0.9725809559312681,0.9725809559312681] |[0.27436636195954833,0.7256336380404517]|
* |1.0 |1.0 |[-1.0780500487239726,1.0780500487239726] |[0.2538752041717863,0.7461247958282137] |
* |1.0 |1.0 |[-1.08453337526972,1.08453337526972] |[0.2526490765347195,0.7473509234652804] |
* |0.0 |0.0 |[0.7376543954891006,-0.7376543954891006] |[0.6764827243160596,0.32351727568394034]|
* |1.0 |1.0 |[-1.2286964884747311,1.2286964884747311] |[0.2264096521620533,0.7735903478379467] |
* |1.0 |1.0 |[-1.259664579572604,1.259664579572604] |[0.22103163838285006,0.77896836161715] |
* |1.0 |1.0 |[-1.2371063245185787,1.2371063245185787] |[0.22494007343582254,0.7750599265641774]|
* |0.0 |0.0 |[0.738396178597871,-0.738396178597871] |[0.6766450451466368,0.32335495485336313]|
* |1.0 |1.0 |[-1.2123284339889662,1.2123284339889662] |[0.2292893207049596,0.7707106792950403] |
* |1.0 |0.0 |[-0.23508568050538953,0.23508568050538953]|[0.4414977605721645,0.5585022394278355] |
* +-----+----------+------------------------------------------+----------------------------------------+
*/
// Model evaluation
val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("label")
.setPredictionCol("prediction")
.setMetricName("accuracy")
// Compute the accuracy
val accuracy = evaluator.evaluate(predictions)
// Print the accuracy
println(s"Accuracy: ${accuracy}")
// Print the test error
println("Test Error = " + (1.0 - accuracy))
//Test Error = 0.010000000000000009
spark.stop()
}
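The F-measure used above for threshold tuning is the harmonic mean of precision and recall. A minimal plain-Scala sketch of that metric (illustrative, not the Spark API):

```scala
// F1 score (harmonic mean of precision and recall), the quantity that
// fMeasureByThreshold reports per candidate threshold.
object FMeasureSketch {
  def f1(precision: Double, recall: Double): Double =
    if (precision + recall == 0.0) 0.0
    else 2.0 * precision * recall / (precision + recall)
}
```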
3. Multinomial Logistic Regression
3.1 Algorithm Overview
Multinomial logistic regression (Multinomial logistic regression)
supports multiclass classification via multinomial logistic (softmax) regression.
In multinomial logistic regression, the algorithm produces K sets of coefficients, i.e. a K×J matrix, where K is the number of outcome classes and J is the number of features. If the model is fit with an intercept term, a length-K vector of intercepts is also available.
The multinomial coefficients are available as coefficientMatrix and the intercepts as interceptVector.
The coefficients and intercept methods are not supported on a logistic regression model trained with the multinomial family; use coefficientMatrix and interceptVector instead.
The conditional probability of an outcome class k ∈ 1, 2, …, K is modeled with the softmax function.
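Concretely, writing $\beta_k$ for the k-th row of the coefficient matrix (with any intercept folded in), the softmax model assigns class probabilities as

$P(Y=k \mid x, \beta) = \frac{e^{\beta_k \cdot x}}{\sum_{k'=1}^{K} e^{\beta_{k'} \cdot x}}$

so the K probabilities are nonnegative and sum to 1, and the predicted class is the one with the largest $\beta_k \cdot x$.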
3.2 Code Example
package hnbian.spark.ml.algorithms.classification
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
/**
* @author hnbian
* @Description Multinomial logistic regression
* @Date 2019/1/3 17:53
**/
object LogisticRegressionMulticlass extends App {
import hnbian.spark.utils.SparkUtils
val spark = SparkUtils.getSparkSession("LogisticRegressionMulticlass",4)
import hnbian.spark.utils.FileUtils
val filePath = FileUtils.getFilePath("sample_multiclass_classification_data.txt")
val training = spark
.read
.format("libsvm")
.load(filePath)
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(0.3)
.setElasticNetParam(0.8)
val lrModel = lr.fit(training)
// Print the coefficients and intercept for multinomial logistic regression
println(s"Coefficients: \n${lrModel.coefficientMatrix}")
/**
* Coefficients:
* 3 x 4 CSCMatrix
* (1,2) -0.7803943459681859
* (0,3) 0.3176483191238039
* (1,3) -0.3769611423403096
*/
println(s"Intercepts: ${lrModel.interceptVector}")
//Intercepts: [0.05165231659832854,-0.12391224990853622,0.07225993331020768]
// Make predictions with the model
val predictions = lrModel.transform(training)
predictions.show(false)
predictions.select("label","prediction","probability","rawPrediction","features").show(false)
/**
* +-----+----------+------------------------------------------------------------+-----------------------------------------------------------------+-----------------------------------------------------------+
* |label|prediction|probability |rawPrediction |features |
* +-----+----------+------------------------------------------------------------+-----------------------------------------------------------------+-----------------------------------------------------------+
* |1.0 |1.0 |[0.19824091021950388,0.5380629504663061,0.26369613931419] |[-0.21305451012206839,0.7854380421234279,0.07225993331020768] |(4,[0,1,2,3],[-0.222222,0.5,-0.762712,-0.833333]) |
* |1.0 |1.0 |[0.18250386256254247,0.568221142106847,0.24927499533061048] |[-0.23952541514793146,0.8962139249724502,0.07225993331020768] |(4,[0,1,2,3],[-0.555556,0.25,-0.864407,-0.916667]) |
* |1.0 |1.0 |[0.18980556250236028,0.5577188309398853,0.25247560655775453]|[-0.21305451012206839,0.8648002451366628,0.07225993331020768] |(4,[0,1,2,3],[-0.722222,-0.166667,-0.864407,-0.833333]) |
* |1.0 |1.0 |[0.19632523546632502,0.5355216883905428,0.2681530761431321] |[-0.23952541514793146,0.7639433264856103,0.07225993331020768] |(4,[0,1,2,3],[-0.722222,0.166667,-0.694915,-0.916667]) |
* |0.0 |0.0 |[0.4375039818343829,0.18146938948218602,0.38102662868343107]|[0.21047647616023052,-0.6695223444410741,0.07225993331020768] |(4,[0,1,2,3],[0.166667,-0.416667,0.457627,0.5]) |
* |1.0 |1.0 |[0.18250386256254247,0.568221142106847,0.24927499533061048] |[-0.23952541514793146,0.8962139249724502,0.07225993331020768] |(4,[0,2,3],[-0.833333,-0.864407,-0.916667]) |
* |2.0 |0.0 * |[0.37581775428218006,0.2505614888540287,0.37362075686379126]|[0.07812299927036823,-0.32727697566780806,0.07225993331020768] |(4,[0,1,2,3],[-1.32455E-7,-0.166667,0.220339,0.0833333]) |
* |2.0 |2.0 |[0.3510273915379525,0.2906363236580725,0.3583362848039749] |[0.05165230377890003,-0.1371392165046517,0.07225993331020768] |(4,[0,1,2,3],[-1.32455E-7,-0.333333,0.0169491,-4.03573E-8])|
* |1.0 |1.0 |[0.17808226409449213,0.5721574660564578,0.24976026984905003]|[-0.2659960025254754,0.9011726399131196,0.07225993331020768] |(4,[0,1,2,3],[-0.5,0.75,-0.830508,-1.0]) |
* |0.0 |0.0 |[0.44258017540583633,0.16163302558940845,0.3957867990047552]|[0.18400588878268653,-0.8232872551325279,0.07225993331020768] |(4,[0,2,3],[0.611111,0.694915,0.416667]) |
* |0.0 |0.0 |[0.4444230148660403,0.17863559375693347,0.37694139137702626]|[0.23694706353777445,-0.6744818397760896,0.07225993331020768] |(4,[0,1,2,3],[0.222222,-0.166667,0.423729,0.583333]) |
* |1.0 |1.0 |[0.1753920693035666,0.5786206574472668,0.24598727324916655] |[-0.2659960025254754,0.9276272278470951,0.07225993331020768] |(4,[0,1,2,3],[-0.722222,-0.166667,-0.864407,-1.0]) |
* |1.0 |1.0 |[0.18250386256254247,0.568221142106847,0.24927499533061048] |[-0.23952541514793146,0.8962139249724502,0.07225993331020768] |(4,[0,1,2,3],[-0.5,0.166667,-0.864407,-0.916667]) |
* |2.0 |2.0 |[0.3537112464509239,0.2852127321272451,0.36107602142183104] |[0.05165230377890003,-0.16359325816258505,0.07225993331020768] |(4,[0,1,2,3],[-0.222222,-0.333333,0.0508474,-4.03573E-8]) |
* |2.0 |2.0 |[0.32360705108265925,0.31874480283179474,0.3576481460855459]|[-0.027759763182622438,-0.0428989461327082,0.07225993331020768] |(4,[0,1,2,3],[-0.0555556,-0.833333,0.0169491,-0.25]) |
* |2.0 |2.0 |[0.33909561029444296,0.30546298762039154,0.3554414020851655]|[0.025181633926288853,-0.07927185213629911,0.07225993331020768] |(4,[0,1,2,3],[-0.166667,-0.416667,-0.0169491,-0.0833333]) |
* |1.0 |1.0 |[0.17976563656243563,0.5746994055166108,0.24553495792095353]|[-0.23952541514793146,0.9226677325120797,0.07225993331020768] |(4,[0,2,3],[-0.944444,-0.898305,-0.916667]) |
* |2.0 |2.0 |[0.3299437131426226,0.3149308100595172,0.35512547679786005] |[-0.0012891758050784866,-0.04785828538885445,0.07225993331020768]|(4,[0,1,2,3],[-0.277778,-0.583333,-0.0169491,-0.166667]) |
* |0.0 |0.0 |[0.39691355784123494,0.21880137373710423,0.3842850684216609]|[0.10459380900173557,-0.4909603605077465,0.07225993331020768] |(4,[0,1,2,3],[0.111111,-0.333333,0.38983,0.166667]) |
* |2.0 |2.0 |[0.3471868571075156,0.28889046288880366,0.36392268000368067]|[0.025181633926288853,-0.1586338990706646,0.07225993331020768] |(4,[0,1,2,3],[-0.222222,-0.166667,0.0847457,-0.0833333]) |
* +-----+----------+------------------------------------------------------------+-----------------------------------------------------------------+-----------------------------------------------------------+
*/
// Model evaluation
val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("label")
.setPredictionCol("prediction")
.setMetricName("accuracy")
// Compute the accuracy
val accuracy = evaluator.evaluate(predictions)
// Print the accuracy
println(s"Accuracy: ${accuracy}")
//Accuracy: 0.82
// Print the test error
println("Test Error = " + (1.0 - accuracy))
//Test Error = 0.18000000000000005
spark.stop()
}