Note: this article is part of a model deployment plan.

Dependencies

  • Spark 2.2
  • jpmml-sparkml 1.3.8
  • Scala 2.11

Steps

  1. Train a decision tree model with Spark ML
  2. Verify that the PMML prints correctly to the console, then write it to a file

Let's get started.

Maven dependencies

<properties>
  <maven.compiler.source>1.8</maven.compiler.source>
  <maven.compiler.target>1.8</maven.compiler.target>
  <encoding>UTF-8</encoding>
  <scala.version>2.11.12</scala.version>
  <scala.compat.version>2.11</scala.compat.version>
  <spark-core.version>2.2.0</spark-core.version>
  <jpmml-sparkml.version>1.3.8</jpmml-sparkml.version>
</properties>

<!-- JPMML integration -->
<dependency>
  <groupId>org.jpmml</groupId>
  <artifactId>jpmml-sparkml</artifactId>
  <version>${jpmml-sparkml.version}</version>
</dependency>

<!-- Spark MLlib -->
<dependency>
  <groupId>org.apache.spark</groupId>
  <artifactId>spark-mllib_${scala.compat.version}</artifactId>
  <version>${spark-core.version}</version>
</dependency>

<!-- Spark core and Spark SQL -->
<dependency>
  <groupId>org.apache.spark</groupId>
  <artifactId>spark-core_${scala.compat.version}</artifactId>
  <version>${spark-core.version}</version>
</dependency>
<dependency>
  <groupId>org.apache.spark</groupId>
  <artifactId>spark-sql_${scala.compat.version}</artifactId>
  <version>${spark-core.version}</version>
</dependency>

Model training & export

package org.houqian.spark.jpmml

import java.io.File
import javax.xml.transform.stream.StreamResult

import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.sql.SparkSession
import org.jpmml.model.JAXBUtil
import org.jpmml.sparkml.PMMLBuilder

/**
  * @author : houqian
  * @version : 1.0
  * @since : 2018-08-30
  */
object Pipeline {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder
      .appName("Pipeline")
      .master("local[4]")
      .getOrCreate()

    // Load the training set.
    // Note: without .option("inferSchema", "true") every column is read as a string,
    // which is why the PMML output further down lists Petal_Width as a categorical string field.
    val irisData = spark
      .read
      .format("csv")
      .option("header", "true")
      .load("file:///Users/houqian/repo/github/data-notebook/src/main/resources/Iris.csv")

    irisData.show()

    // Feature selection: predict Species from all remaining columns
    val formula = new RFormula().setFormula("Species ~ .")

    // Configure the decision tree classifier to use the columns produced by RFormula
    val classifier = new DecisionTreeClassifier()
      .setLabelCol(formula.getLabelCol)
      .setFeaturesCol(formula.getFeaturesCol)

    // Assemble the pipeline
    val pipeline = new Pipeline().setStages(Array[PipelineStage](formula, classifier))

    // Train
    val pipelineModel = pipeline.fit(irisData)

    // Convert the fitted pipeline to PMML
    val schema = irisData.schema
    val pmml = new PMMLBuilder(schema, pipelineModel).build

    // Print the PMML to the console as a stream
    JAXBUtil.marshalPMML(pmml, new StreamResult(System.out))

    // Write the PMML to a file
    new PMMLBuilder(schema, pipelineModel).buildFile(new File("/Users/houqian/repo/github/data-notebook/src/main/resources/pipeline.pmml"))

    spark.stop()
  }
}
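
The CSV reader above does not request schema inference, so Spark loads every column as a string. The following is a minimal sketch of the same load with inference turned on, under which the numeric Iris columns should be typed as doubles rather than strings. The object name `InferSchemaSketch` and the relative path `Iris.csv` are placeholders for illustration, not part of the original project.

```scala
import org.apache.spark.sql.SparkSession

object InferSchemaSketch {

  // Reader options: keep the header row and let Spark infer column types
  val csvOptions: Map[String, String] = Map(
    "header" -> "true",
    "inferSchema" -> "true"
  )

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder
      .appName("InferSchemaSketch")
      .master("local[2]")
      .getOrCreate()

    // "Iris.csv" is a placeholder path (assumption); point it at your own copy
    val irisData = spark
      .read
      .format("csv")
      .options(csvOptions)
      .load("Iris.csv")

    // Print the inferred column types to check them before training
    irisData.printSchema()

    spark.stop()
  }
}
```

Checking `printSchema()` before fitting the pipeline is a cheap way to confirm whether features will be treated as numeric or categorical.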

Run it; the console prints:

<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" xmlns:data="http://jpmml.org/jpmml-model/InlineTable" version="4.3">
  <Header>
    <Application name="JPMML-SparkML" version="1.5.3"/>
    <Timestamp>2018-08-30T14:03:52Z</Timestamp>
  </Header>
  <DataDictionary>
    <DataField name="Species" optype="categorical" dataType="string">
      <Value value="versicolor"/>
      <Value value="virginica"/>
      <Value value="setosa"/>
    </DataField>
    <DataField name="Petal_Width" optype="categorical" dataType="string">
      <Value value="0.2"/>
      <Value value="1.3"/>
      <Value value="1.5"/>
      <Value value="1.8"/>
      <Value value="2.3"/>
      <Value value="1.4"/>
      <Value value="0.4"/>
      <Value value="1"/>
      <Value value="0.3"/>
      <Value value="2.1"/>
      <Value value="2"/>
      <Value value="0.1"/>
      <Value value="1.9"/>
      <Value value="1.2"/>
      <Value value="1.6"/>
      <Value value="2.4"/>
      <Value value="1.1"/>
      <Value value="2.5"/>
      <Value value="2.2"/>
      <Value value="1.7"/>
      <Value value="0.5"/>
      <Value value="0.6"/>
    </DataField>
  </DataDictionary>
  <TreeModel functionName="classification" missingValueStrategy="nullPrediction" splitCharacteristic="multiSplit">
    <MiningSchema>
      <MiningField name="Species" usageType="target"/>
      <MiningField name="Petal_Width"/>
    </MiningSchema>
    <Output>
      <OutputField name="pmml(prediction)" optype="categorical" dataType="string" feature="predictedValue" isFinalResult="false"/>
      <OutputField name="prediction" optype="categorical" dataType="double" feature="transformedValue">
        <MapValues outputColumn="data:output" dataType="double">
          <FieldColumnPair field="pmml(prediction)" column="data:input"/>
          <InlineTable>
            <row>
              <data:input>versicolor</data:input>
              <data:output>0</data:output>
            </row>
            <row>
              <data:input>virginica</data:input>
              <data:output>1</data:output>
            </row>
            <row>
              <data:input>setosa</data:input>
              <data:output>2</data:output>
            </row>
          </InlineTable>
        </MapValues>
      </OutputField>
      <OutputField name="probability(versicolor)" optype="continuous" dataType="double" feature="probability" value="versicolor"/>
      <OutputField name="probability(virginica)" optype="continuous" dataType="double" feature="probability" value="virginica"/>
      <OutputField name="probability(setosa)" optype="continuous" dataType="double" feature="probability" value="setosa"/>
    </Output>
    <Node>
      <True/>
      <Node score="setosa" recordCount="29">
        <SimplePredicate field="Petal_Width" operator="equal" value="0.2"/>
        <ScoreDistribution value="versicolor" recordCount="0.0"/>
        <ScoreDistribution value="virginica" recordCount="0.0"/>
        <ScoreDistribution value="setosa" recordCount="29.0"/>
      </Node>
      <Node score="versicolor" recordCount="13">
        <SimplePredicate field="Petal_Width" operator="equal" value="1.3"/>
        <ScoreDistribution value="versicolor" recordCount="13.0"/>
        <ScoreDistribution value="virginica" recordCount="0.0"/>
        <ScoreDistribution value="setosa" recordCount="0.0"/>
      </Node>
      <Node score="setosa" recordCount="7">
        <SimplePredicate field="Petal_Width" operator="equal" value="0.4"/>
        <ScoreDistribution value="versicolor" recordCount="0.0"/>
        <ScoreDistribution value="virginica" recordCount="0.0"/>
        <ScoreDistribution value="setosa" recordCount="7.0"/>
      </Node>
      <Node score="setosa" recordCount="7">
        <SimplePredicate field="Petal_Width" operator="equal" value="0.3"/>
        <ScoreDistribution value="versicolor" recordCount="0.0"/>
        <ScoreDistribution value="virginica" recordCount="0.0"/>
        <ScoreDistribution value="setosa" recordCount="7.0"/>
      </Node>
      <Node score="setosa" recordCount="5">
        <SimplePredicate field="Petal_Width" operator="equal" value="0.1"/>
        <ScoreDistribution value="versicolor" recordCount="0.0"/>
        <ScoreDistribution value="virginica" recordCount="0.0"/>
        <ScoreDistribution value="setosa" recordCount="5.0"/>
      </Node>
      <Node score="virginica" recordCount="89">
        <True/>
        <ScoreDistribution value="versicolor" recordCount="37.0"/>
        <ScoreDistribution value="virginica" recordCount="50.0"/>
        <ScoreDistribution value="setosa" recordCount="2.0"/>
      </Node>
    </Node>
  </TreeModel>
</PMML>
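
To make the exported tree easier to read, here is a small hand-written sketch that mirrors its child nodes in plain Scala: each `SimplePredicate` compares `Petal_Width` (as a string) with one value, and the last node with `<True/>` acts as the catch-all. The object and member names are made up for illustration. This is not how a PMML evaluator is implemented, and it ignores missing inputs, which the model maps to a null prediction.

```scala
object TreeWalkSketch {

  // (Petal_Width value, score) pairs copied from the SimplePredicate nodes, in document order
  val equalityNodes: Seq[(String, String)] = Seq(
    "0.2" -> "setosa",
    "1.3" -> "versicolor",
    "0.4" -> "setosa",
    "0.3" -> "setosa",
    "0.1" -> "setosa"
  )

  // Score of the final child node, whose predicate is <True/>
  val fallbackScore: String = "virginica"

  // Return the score of the first child node whose predicate matches the input
  def predict(petalWidth: String): String =
    equalityNodes
      .collectFirst { case (value, score) if value == petalWidth => score }
      .getOrElse(fallbackScore)

  def main(args: Array[String]): Unit = {
    Seq("0.2", "1.3", "2.5").foreach { w =>
      println(s"Petal_Width=$w -> ${predict(w)}")
    }
  }
}
```

Because the comparison is on strings, an input such as "0.20" would not match the "0.2" node in this sketch and would fall through to the catch-all.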

The PMML file was generated successfully:

[Screenshot: image-20180830161456730.png]
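
As a quick sanity check on a generated file, the data dictionary can be listed with the XML parser that ships with the JDK. The sketch below is only an illustration: `PmmlInspectSketch` and `dataFields` are hypothetical names, and the embedded sample is a trimmed-down fragment of the output shown earlier, not the full file.

```scala
import java.io.ByteArrayInputStream
import java.nio.charset.StandardCharsets
import javax.xml.parsers.DocumentBuilderFactory

object PmmlInspectSketch {

  // Return (name, optype) for every DataField element in the given PMML text
  def dataFields(xml: String): Seq[(String, String)] = {
    val builder = DocumentBuilderFactory.newInstance().newDocumentBuilder()
    val doc = builder.parse(new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8)))
    val nodes = doc.getElementsByTagName("DataField")
    (0 until nodes.getLength).map { i =>
      val attrs = nodes.item(i).getAttributes
      (attrs.getNamedItem("name").getNodeValue, attrs.getNamedItem("optype").getNodeValue)
    }
  }

  def main(args: Array[String]): Unit = {
    // Trimmed sample based on the console output; a real run would read pipeline.pmml instead
    val sample =
      """<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
        |  <DataDictionary>
        |    <DataField name="Species" optype="categorical" dataType="string"/>
        |    <DataField name="Petal_Width" optype="categorical" dataType="string"/>
        |  </DataDictionary>
        |</PMML>""".stripMargin

    dataFields(sample).foreach { case (name, optype) => println(s"$name: $optype") }
  }
}
```

Seeing a numeric feature reported as `categorical` here is a hint that the training data was loaded with string columns.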

References

  1. https://openscoring.io/blog/2018/07/09/converting_sparkml_pipeline_pmml/

  2. https://github.com/jpmml/jpmml-sparkml

  3. The Iris.csv training set used here: https://github.com/jpmml/jpmml-sparkml/blob/1.3.X/src/test/resources/csv/Iris.csv
