1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
   | package org.houqian.spark.jpmml
  import java.io.File
  import org.apache.spark.ml.classification.DecisionTreeClassifier import org.apache.spark.ml.feature.RFormula import org.apache.spark.ml.{Pipeline, PipelineStage} import org.apache.spark.sql.SparkSession import org.jpmml.model.JAXBUtil
 
 
 
 
 
  /**
   * Trains a decision-tree classifier on the Iris CSV data with a Spark ML pipeline
   * (an `RFormula` with formula `Species ~ .` followed by a `DecisionTreeClassifier`),
   * prints the fitted pipeline as PMML to stdout and also writes it to a file.
   *
   * Usage: `Pipeline [inputCsv] [outputPmml]`. Both arguments are optional; when
   * omitted, the paths this job originally hard-coded are used.
   */
  object Pipeline {

    /** CSV input used when no first argument is supplied. */
    private val DefaultInputPath =
      "file:///Users/houqian/repo/github/data-notebook/src/main/resources/Iris.csv"

    /** PMML output file used when no second argument is supplied. */
    private val DefaultOutputPath =
      "/Users/houqian/repo/github/data-notebook/src/main/resources/pipeline.pmml"

    /**
     * Entry point.
     *
     * @param args optional `args(0)` = input CSV path/URI, `args(1)` = output PMML path
     */
    def main(args: Array[String]): Unit = {
      import javax.xml.transform.stream.StreamResult
      import org.jpmml.sparkml.PMMLBuilder

      val inputPath = args.lift(0).getOrElse(DefaultInputPath)
      val outputPath = args.lift(1).getOrElse(DefaultOutputPath)

      val spark = SparkSession
        .builder
        .appName("Pipeline")
        .master("local[4]")
        .getOrCreate()

      try {
        val irisData = spark
          .read
          .format("csv")
          .option("header", "true")
          // Without schema inference every CSV column is a string, and RFormula would
          // treat the numeric measurements as categorical terms (one-hot encoding them).
          .option("inferSchema", "true")
          .load(inputPath)
        irisData.show()

        val formula = new RFormula().setFormula("Species ~ .")
        val classifier = new DecisionTreeClassifier()
          .setLabelCol(formula.getLabelCol)
          .setFeaturesCol(formula.getFeaturesCol)
        // `Pipeline` in a type position is Spark's class from the file-level import;
        // this enclosing object only occupies the term namespace.
        val pipelineModel = new Pipeline()
          .setStages(Array[PipelineStage](formula, classifier))
          .fit(irisData)

        // The PMML converter is given the training schema together with the fitted model.
        val pmmlBuilder = new PMMLBuilder(irisData.schema, pipelineModel)
        JAXBUtil.marshalPMML(pmmlBuilder.build, new StreamResult(System.out))
        pmmlBuilder.buildFile(new File(outputPath))
      } finally {
        // Release the local Spark context whether training/export succeeds or fails.
        spark.stop()
      }
    }
  }
   |