Spark ML 9.分类算法 1


1. 逻辑回归

逻辑回归(Logistic Regression)是预测分类的流程方法,它是 广义线性模型 的一个特例来预测结果分类的可能性。在spar.ml 逻辑回归中可以使用二项式逻辑回归来预测二进制结果,也可以通过多项式逻辑回归来预测多类结果。使用参数在这两种算法之间进行选择,或者不进行设置,Spark将自行推断选择合适的算法。

通过将family 参数设置为“multinomial” 可以将多项逻辑回归用于二进制分类。它将产生两组系数和两个截距。

当在具有常量非零列的数据集上对LogisticRegressionModel进行拟合时,Spark MLlib为常数非零列输出零系数。此行为与R glmnet 相同,但与LIBSVM不同。

2. 二项式逻辑回归

2.1 算法介绍

  • 二项式逻辑回归(Binomial logistic regression)

  • 线性最小二乘法是回归问题的最常见公式。它是 Generalized Linear models(广义线性模型) 的一个特殊应用以预测结果概率

  • 它是一个线性模型如下列方程所示, 其中损失函数为逻辑损失: $L(w;x,y):=log(1 + exp(-yw^Tx))$

  • 对于二分类问题,算法产出一个二值逻辑回归模型,给定一个新数据,由 x 表示,则模型通过下列逻辑方程来预测:

    $f(z) = \frac{1}{1+e^{-z}}$

    其中 $z=w^Tx$ ,默认情况下,如果 $f(w^Tx)> 0.5$ , 结果为正,否则为负。和线性 SVMs 不同,逻辑回归的原始输出有概率解释(x为正的概率)

2.2 参数列表

参数名称 类型 说明
elasticNetParam 双精度 弹性网格混合参数, 范围[0,1]
family 用于系列名称的参数,它是模型中使用的标签分布的描述默认为“auto”
auto:根据类的数量自动选择系列:
如果numClasses == 1 || numClasses == 2,则设置为“binomial”。
否则,设置为“multinomial” binomial:二元逻辑回归。
multinomial:没有旋转的多项logistic(softmax)回归
featuresCol 字符串 特征列名
fitlntercept 布尔值 是否训练拦截对象
labelCol 字符串 标签列名
maxlter 整数 最多迭代次数(>=0)
predictionCol 字符串 预测结果列名
probabilityCol 字符串 用以预测类别条件概率的列名
regParam 双精度 正则化参数 (>=0)
standardization 布尔值 训练模型前是否需要对训练特征进行标准化处理
threshold 双精度 二分类预测的阈值, 范围[0,1]
thresholds 双精度 多分类预测的阈值, 以调整预测结果在各个类别的概率
tol 双精度 迭代算法的收敛性
weightCol 字符串 列权重

2.3 代码示例

package hnbian.spark.ml.algorithms.classification

import hnbian.spark.utils.{FileUtils, SparkUtils}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

/**
  * @author hnbian
  * @ Description 逻辑回归二项式预测
  * @ Date 2019/1/3 15:33
  **/
object LogisticRegression extends App{
  val spark = SparkUtils.getSparkSession("LogisticRegression",4)
  import spark.implicits._

  val filePath = FileUtils.getFilePath("sample_libsvm_data.txt")

  val training = spark.read.format("libsvm").load(filePath)
  val lr = new LogisticRegression()
    .setMaxIter(10) //最大迭代次数
    .setRegParam(0.3) //正则化参数
    .setElasticNetParam(0.8) //正则化范式比(默认0.0),正则化一般有两种范式:L1(Lasso)和L2(Ridge)。L1一般用于特征的稀疏化,L2一般用于防止过拟合。这里的参数即设置L1范式的占比,默认0.0即只使用L2范式
    //.setFamily("binomial")
    .setThreshold(1.0)

  // Fit the model
  val lrModel = lr.fit(training)

  // 打印逻辑回归的系数和截距
  println(s"Coefficients: ${lrModel.coefficients} \nIntercept: ${lrModel.intercept}")
  //Coefficients: (692,[244,263,272,300,301,328,350,351,378,379,405,406,407,428,433,434,455,456,461,462,483,484,489,490,496,511,512,517,539,540,568],[-7.35398352418814E-5,-9.102738505589432E-5,-1.9467430546904216E-4,-2.030064247348659E-4,-3.1476183314865005E-5,-6.842977602660699E-5,1.5883626898237813E-5,1.4023497091369702E-5,3.5432047524968963E-4,1.1443272898170924E-4,1.0016712383666388E-4,6.014109303795469E-4,2.8402481791227693E-4,-1.1541084736508769E-4,3.8599688631290956E-4,6.350195574241061E-4,-1.1506412384575594E-4,-1.5271865864986703E-4,2.8049338089942207E-4,6.070117471191611E-4,-2.0084596632474318E-4,-1.4210755792901163E-4,2.739010341160889E-4,2.7730456244968185E-4,-9.838027027269304E-5,-3.808522443517673E-4,-2.5315198008554816E-4,2.7747714770754383E-4,-2.443619763919179E-4,-0.0015394744687597863,-2.3073328411331095E-4])
  //Intercept: 0.224563159612503

  /**
    * 获取训练集上的模型摘要。 如果`trainingSummary == None`或它是多类模型,则抛出异常。
    */
  val trainingSummary = lrModel.binarySummary

  // 获得每次迭代的目标
  val objectiveHistory = trainingSummary.objectiveHistory
  objectiveHistory.foreach(loss => println(loss))
  /**
    * objectiveHistory:
    * 0.6833149135741672
    * 0.6662875751473734
    * 0.6217068546034618
    * 0.6127265245887887
    * 0.6060347986802872
    * 0.6031750687571563
    * 0.5969621534836272
    * 0.5940743031983121
    * 0.5906089243339023
    * 0.5894724576491043
    * 0.5882187775729587
    */

  // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
  val roc = trainingSummary.roc
  roc.show()
  /**
    * +---+--------------------+
    * |FPR|                 TPR|
    * +---+--------------------+
    * |0.0|                 0.0|
    * |0.0|0.017543859649122806|
    * |0.0| 0.03508771929824561|
    * |0.0| 0.05263157894736842|
    * |0.0| 0.07017543859649122|
    * |0.0|  0.3157894736842105|
    * |0.0|  0.3333333333333333|
    * +---+--------------------+
    */
  println(s"areaUnderROC: ${trainingSummary.areaUnderROC}")

  // 设置模型阈值以最大化F-Measure
  val fMeasure = trainingSummary.fMeasureByThreshold
  fMeasure.show(false)
  /**
    * +------------------+--------------------+
    * |threshold         |F-Measure           |
    * +------------------+--------------------+
    * |0.7845860015371141|0.034482758620689655|
    * |0.784319334416892 |0.06779661016949151 |
    * |0.784297609251013 |0.1                 |
    * |0.7842531051133191|0.13114754098360656 |
    * |0.7788060694625323|0.45945945945945943 |
    * |0.7783754276111222|0.4799999999999999  |
    * |0.7771658291080573|0.5                 |
    * |0.7769914303593917|0.5194805194805194  |
    * +------------------+--------------------+
    */

  import org.apache.spark.sql.functions.max
  val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
  println(s"maxFMeasure: ${maxFMeasure}")
  //maxFMeasure: 1.0

  //获取最大阈值
  val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure)
    .select("threshold").head().getDouble(0)
  //设置阈值
  lrModel.setThreshold(bestThreshold)
  //通过模型得到预测结果
  val predictions = lrModel.transform(training)
  //打印预测结果
  predictions.select("label","prediction","rawPrediction","probability").show(20,false)

  /**
    * +-----+----------+------------------------------------------+----------------------------------------+
    * |label|prediction|rawPrediction                             |probability                             |
    * +-----+----------+------------------------------------------+----------------------------------------+
    * |1.0  |1.0       |[-1.2006427604100238,1.2006427604100238]  |[0.23136089273325805,0.7686391072667419]|
    * |1.0  |1.0       |[-0.9725809559312681,0.9725809559312681]  |[0.27436636195954833,0.7256336380404517]|
    * |1.0  |1.0       |[-1.0780500487239726,1.0780500487239726]  |[0.2538752041717863,0.7461247958282137] |
    * |1.0  |1.0       |[-1.08453337526972,1.08453337526972]      |[0.2526490765347195,0.7473509234652804] |
    * |0.0  |0.0       |[0.7376543954891006,-0.7376543954891006]  |[0.6764827243160596,0.32351727568394034]|
    * |1.0  |1.0       |[-1.2286964884747311,1.2286964884747311]  |[0.2264096521620533,0.7735903478379467] |
    * |1.0  |1.0       |[-1.259664579572604,1.259664579572604]    |[0.22103163838285006,0.77896836161715]  |
    * |1.0  |1.0       |[-1.2371063245185787,1.2371063245185787]  |[0.22494007343582254,0.7750599265641774]|
    * |0.0  |0.0       |[0.738396178597871,-0.738396178597871]    |[0.6766450451466368,0.32335495485336313]|
    * |1.0  |1.0       |[-1.2123284339889662,1.2123284339889662]  |[0.2292893207049596,0.7707106792950403] |
    * |1.0  |0.0       |[-0.23508568050538953,0.23508568050538953]|[0.4414977605721645,0.5585022394278355] |
    * +-----+----------+------------------------------------------+----------------------------------------+
    */

  //模型评估

  val evaluator = new MulticlassClassificationEvaluator()
    .setLabelCol("label")
    .setPredictionCol("prediction")
    .setMetricName("accuracy")
  //计算错误率
  val accuracy = evaluator.evaluate(predictions)
  //打印准确率
  println(s"准确率:${accuracy}")
  //打印错误率
  println("Test Error = " + (1.0 - accuracy))
  //Test Error = 0.010000000000000009

  spark.stop()
}

3 多项式逻辑回归

3.1 算法说明

  • 多项式逻辑回归 (Multinomial logistic regression)

  • 通过多项Logistic(softmax)回归支持多类分类。

  • 在多项Logistic回归中,该算法产生K个系数集,或K×J矩阵,其中K是结果类的数量,J是特征数。 如果算法与截距项拟合,则截距的长度K向量是可用的。

  • 多项式系数可用作系数矩阵,截距可作为interceptVector使用。

  • 不支持用多项式族训练的逻辑回归模型的系数和截距方法。 改用系数矩阵和interceptVector。

  • 使用softmax函数对结果类k∈1,2,…,K的条件概率进行建模。

3.2 代码示例

package hnbian.spark.ml.algorithms.classification

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

/**
  * @author hnbian
  * @ Description 多项式逻辑回归
  * @ Date 2019/1/3 17:53
  **/
object LogisticRegressionMulticlass extends App {

  import hnbian.spark.utils.SparkUtils
  val spark = SparkUtils.getSparkSession("LogisticRegressionMulticlass",4)

  import hnbian.spark.utils.FileUtils
  val  filePath = FileUtils.getFilePath("sample_multiclass_classification_data.txt")

  val training = spark
    .read
    .format("libsvm")
    .load(filePath)

  val lr = new LogisticRegression()
    .setMaxIter(10)
    .setRegParam(0.3)
    .setElasticNetParam(0.8)

  val lrModel = lr.fit(training)

  // Print the coefficients and intercept for multinomial logistic regression
  println(s"Coefficients: \n${lrModel.coefficientMatrix}")
  /**
    * Coefficients:
    * 3 x 4 CSCMatrix
    * (1,2) -0.7803943459681859
    * (0,3) 0.3176483191238039
    * (1,3) -0.3769611423403096
    */
  println(s"Intercepts: ${lrModel.interceptVector}")
  //Intercepts: [0.05165231659832854,-0.12391224990853622,0.07225993331020768]
  //通过模型预测结果
  val predictions = lrModel.transform(training)

  predictions.show(false)
  predictions.select("label","prediction","probability","rawPrediction","features").show(false)
  /**
    * +-----+----------+------------------------------------------------------------+-----------------------------------------------------------------+-----------------------------------------------------------+
    * |label|prediction|probability                                                 |rawPrediction                                                    |features                                                   |
    * +-----+----------+------------------------------------------------------------+-----------------------------------------------------------------+-----------------------------------------------------------+
    * |1.0  |1.0       |[0.19824091021950388,0.5380629504663061,0.26369613931419]   |[-0.21305451012206839,0.7854380421234279,0.07225993331020768]    |(4,[0,1,2,3],[-0.222222,0.5,-0.762712,-0.833333])          |
    * |1.0  |1.0       |[0.18250386256254247,0.568221142106847,0.24927499533061048] |[-0.23952541514793146,0.8962139249724502,0.07225993331020768]    |(4,[0,1,2,3],[-0.555556,0.25,-0.864407,-0.916667])         |
    * |1.0  |1.0       |[0.18980556250236028,0.5577188309398853,0.25247560655775453]|[-0.21305451012206839,0.8648002451366628,0.07225993331020768]    |(4,[0,1,2,3],[-0.722222,-0.166667,-0.864407,-0.833333])    |
    * |1.0  |1.0       |[0.19632523546632502,0.5355216883905428,0.2681530761431321] |[-0.23952541514793146,0.7639433264856103,0.07225993331020768]    |(4,[0,1,2,3],[-0.722222,0.166667,-0.694915,-0.916667])     |
    * |0.0  |0.0       |[0.4375039818343829,0.18146938948218602,0.38102662868343107]|[0.21047647616023052,-0.6695223444410741,0.07225993331020768]    |(4,[0,1,2,3],[0.166667,-0.416667,0.457627,0.5])            |
    * |1.0  |1.0       |[0.18250386256254247,0.568221142106847,0.24927499533061048] |[-0.23952541514793146,0.8962139249724502,0.07225993331020768]    |(4,[0,2,3],[-0.833333,-0.864407,-0.916667])                |
    * |2.0  |0.0 *     |[0.37581775428218006,0.2505614888540287,0.37362075686379126]|[0.07812299927036823,-0.32727697566780806,0.07225993331020768]   |(4,[0,1,2,3],[-1.32455E-7,-0.166667,0.220339,0.0833333])   |
    * |2.0  |2.0       |[0.3510273915379525,0.2906363236580725,0.3583362848039749]  |[0.05165230377890003,-0.1371392165046517,0.07225993331020768]    |(4,[0,1,2,3],[-1.32455E-7,-0.333333,0.0169491,-4.03573E-8])|
    * |1.0  |1.0       |[0.17808226409449213,0.5721574660564578,0.24976026984905003]|[-0.2659960025254754,0.9011726399131196,0.07225993331020768]     |(4,[0,1,2,3],[-0.5,0.75,-0.830508,-1.0])                   |
    * |0.0  |0.0       |[0.44258017540583633,0.16163302558940845,0.3957867990047552]|[0.18400588878268653,-0.8232872551325279,0.07225993331020768]    |(4,[0,2,3],[0.611111,0.694915,0.416667])                   |
    * |0.0  |0.0       |[0.4444230148660403,0.17863559375693347,0.37694139137702626]|[0.23694706353777445,-0.6744818397760896,0.07225993331020768]    |(4,[0,1,2,3],[0.222222,-0.166667,0.423729,0.583333])       |
    * |1.0  |1.0       |[0.1753920693035666,0.5786206574472668,0.24598727324916655] |[-0.2659960025254754,0.9276272278470951,0.07225993331020768]     |(4,[0,1,2,3],[-0.722222,-0.166667,-0.864407,-1.0])         |
    * |1.0  |1.0       |[0.18250386256254247,0.568221142106847,0.24927499533061048] |[-0.23952541514793146,0.8962139249724502,0.07225993331020768]    |(4,[0,1,2,3],[-0.5,0.166667,-0.864407,-0.916667])          |
    * |2.0  |2.0       |[0.3537112464509239,0.2852127321272451,0.36107602142183104] |[0.05165230377890003,-0.16359325816258505,0.07225993331020768]   |(4,[0,1,2,3],[-0.222222,-0.333333,0.0508474,-4.03573E-8])  |
    * |2.0  |2.0       |[0.32360705108265925,0.31874480283179474,0.3576481460855459]|[-0.027759763182622438,-0.0428989461327082,0.07225993331020768]  |(4,[0,1,2,3],[-0.0555556,-0.833333,0.0169491,-0.25])       |
    * |2.0  |2.0       |[0.33909561029444296,0.30546298762039154,0.3554414020851655]|[0.025181633926288853,-0.07927185213629911,0.07225993331020768]  |(4,[0,1,2,3],[-0.166667,-0.416667,-0.0169491,-0.0833333])  |
    * |1.0  |1.0       |[0.17976563656243563,0.5746994055166108,0.24553495792095353]|[-0.23952541514793146,0.9226677325120797,0.07225993331020768]    |(4,[0,2,3],[-0.944444,-0.898305,-0.916667])                |
    * |2.0  |2.0       |[0.3299437131426226,0.3149308100595172,0.35512547679786005] |[-0.0012891758050784866,-0.04785828538885445,0.07225993331020768]|(4,[0,1,2,3],[-0.277778,-0.583333,-0.0169491,-0.166667])   |
    * |0.0  |0.0       |[0.39691355784123494,0.21880137373710423,0.3842850684216609]|[0.10459380900173557,-0.4909603605077465,0.07225993331020768]    |(4,[0,1,2,3],[0.111111,-0.333333,0.38983,0.166667])        |
    * |2.0  |2.0       |[0.3471868571075156,0.28889046288880366,0.36392268000368067]|[0.025181633926288853,-0.1586338990706646,0.07225993331020768]   |(4,[0,1,2,3],[-0.222222,-0.166667,0.0847457,-0.0833333])   |
    * +-----+----------+------------------------------------------------------------+-----------------------------------------------------------------+-----------------------------------------------------------+
    */

  //模型评估

  val evaluator = new MulticlassClassificationEvaluator()
    .setLabelCol("label")
    .setPredictionCol("prediction")
    .setMetricName("accuracy")
  //计算错误率
  val accuracy = evaluator.evaluate(predictions)
  //打印准确率
  println(s"准确率:${accuracy}")
  //准确率:0.82
  //打印错误率
  println("Test Error = " + (1.0 - accuracy))
  //Test Error = 0.18000000000000005

  spark.stop()

}

文章作者: hnbian
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 hnbian !
评论
 上一篇
Spark ML 10.分类算法 2 Spark ML 10.分类算法 2
1. 决策树1.1 算法简介决策树以及其继承算法是机器学习分类和回归问题中非常流行的算法,因其易解释性、可处理类别特征、易扩展到多分类问题、不需特征缩放等性质被广泛使用。决策树模式呈树形结构,其中: 每个内部节点 代表一个属性上的测试 每
2019-01-15
下一篇 
Spark ML 8. 特征选择 Spark ML 8. 特征选择
1. 介绍 特征选择(Feature Selection)指的是在特征向量中选择出那些“优秀”的特征,组成新的、更“精简”的特征向量的过程。 特征选择在高维数据分析中十分常用,可以剔除掉“冗余”和“无关”的特征,提升学习器的性能。特征选择
2019-01-05
  目录