package org.houqian.spark.jpmml
import java.io.File

import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.sql.SparkSession
import org.jpmml.model.JAXBUtil
/**
 * Example job: fits a Spark ML pipeline (RFormula + decision-tree classifier) on an
 * Iris CSV file and exports the fitted pipeline as PMML using JPMML-SparkML.
 *
 * The PMML document is printed to standard output and also written to a file.
 */
object Pipeline {

  /** Input CSV used when no first command-line argument is given. */
  private val DefaultInputPath: String =
    "file:///Users/houqian/repo/github/data-notebook/src/main/resources/Iris.csv"

  /** Output PMML file used when no second command-line argument is given. */
  private val DefaultOutputPath: String =
    "/Users/houqian/repo/github/data-notebook/src/main/resources/pipeline.pmml"

  /**
   * Runs the train-and-export job on a local Spark session.
   *
   * @param args optional overrides: `args(0)` is the input CSV location and `args(1)`
   *             is the output PMML file path; either falls back to its default when
   *             absent, so calling with no arguments keeps the original paths.
   */
  def main(args: Array[String]): Unit = {
    import javax.xml.transform.stream.StreamResult
    import org.jpmml.sparkml.PMMLBuilder

    val inputPath = args.headOption.getOrElse(DefaultInputPath)
    val outputPath = args.drop(1).headOption.getOrElse(DefaultOutputPath)

    val spark = SparkSession
      .builder
      .appName("Pipeline")
      .master("local[4]")
      .getOrCreate()

    try {
      val irisData = spark
        .read
        .format("csv")
        .option("header", "true")
        // Without schema inference every CSV column is read as a string, and RFormula
        // would then treat the numeric measurements as categorical features.
        .option("inferSchema", "true")
        .load(inputPath)

      irisData.show()

      // "Species ~ ." : predict Species from all remaining columns.
      val formula = new RFormula().setFormula("Species ~ .")

      // Wire the classifier to the label/feature columns produced by the formula stage.
      val classifier = new DecisionTreeClassifier()
        .setLabelCol(formula.getLabelCol)
        .setFeaturesCol(formula.getFeaturesCol)

      // `Pipeline` in type position is the imported Spark ML class; this enclosing
      // object only introduces a term named `Pipeline`.
      val pipeline = new Pipeline().setStages(Array[PipelineStage](formula, classifier))
      val pipelineModel = pipeline.fit(irisData)

      // One builder serves both the stdout dump and the file export.
      val pmmlBuilder = new PMMLBuilder(irisData.schema, pipelineModel)

      val pmml = pmmlBuilder.build
      JAXBUtil.marshalPMML(pmml, new StreamResult(System.out))

      pmmlBuilder.buildFile(new File(outputPath))
    } finally {
      // Release the local Spark context even when training or export fails.
      spark.stop()
    }
  }
}
|