#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast, TYPE_CHECKING

from pyspark import keyword_only, since
from pyspark.ml.base import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
from pyspark.ml.util import (
    MLReadable,
    MLWritable,
    JavaMLWriter,
    JavaMLReader,
    DefaultParamsReader,
    DefaultParamsWriter,
    MLWriter,
    MLReader,
    JavaMLReadable,
    JavaMLWritable,
)
from pyspark.ml.wrapper import JavaParams
from pyspark.ml.common import inherit_doc
from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame

if TYPE_CHECKING:
    from pyspark.ml._typing import ParamMap, PipelineStage
    from py4j.java_gateway import JavaObject
    from pyspark.core.context import SparkContext

@inherit_doc
class Pipeline(Estimator["PipelineModel"], MLReadable["Pipeline"], MLWritable):
    """
    A simple pipeline, which acts as an estimator. A Pipeline consists
    of a sequence of stages, each of which is either an
    :py:class:`Estimator` or a :py:class:`Transformer`. When
    :py:meth:`Pipeline.fit` is called, the stages are executed in
    order. If a stage is an :py:class:`Estimator`, its
    :py:meth:`Estimator.fit` method will be called on the input
    dataset to fit a model. Then the model, which is a transformer,
    will be used to transform the dataset as the input to the next
    stage. If a stage is a :py:class:`Transformer`, its
    :py:meth:`Transformer.transform` method will be called to produce
    the dataset for the next stage. The fitted model from a
    :py:class:`Pipeline` is a :py:class:`PipelineModel`, which
    consists of fitted models and transformers, corresponding to the
    pipeline stages. If stages is an empty list, the pipeline acts as an
    identity transformer.

    .. versionadded:: 1.3.0
    """

    stages: Param[List["PipelineStage"]] = Param(
        Params._dummy(), "stages", "a list of pipeline stages"
    )

    _input_kwargs: Dict[str, Any]

    @keyword_only
    def __init__(self, *, stages: Optional[List["PipelineStage"]] = None):
        """
        __init__(self, \\*, stages=None)
        """
        super(Pipeline, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    def setStages(self, value: List["PipelineStage"]) -> "Pipeline":
        """
        Set pipeline stages.

        .. versionadded:: 1.3.0

        Parameters
        ----------
        value : list of :py:class:`pyspark.ml.Transformer`
            or :py:class:`pyspark.ml.Estimator`

        Returns
        -------
        :py:class:`Pipeline`
            the pipeline instance
        """
        return self._set(stages=value)

[docs]@since("1.3.0")defgetStages(self)->List["PipelineStage"]:""" Get pipeline stages. """returnself.getOrDefault(self.stages)
    @keyword_only
    @since("1.3.0")
    def setParams(self, *, stages: Optional[List["PipelineStage"]] = None) -> "Pipeline":
        """
        setParams(self, \\*, stages=None)
        Sets params for Pipeline.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _fit(self, dataset: DataFrame) -> "PipelineModel":
        stages = self.getStages()
        for stage in stages:
            if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)):
                raise TypeError("Cannot recognize a pipeline stage of type %s." % type(stage))
        # Stages after the last Estimator never need to transform the training
        # data, so track where the last Estimator sits in the stage list.
        indexOfLastEstimator = -1
        for i, stage in enumerate(stages):
            if isinstance(stage, Estimator):
                indexOfLastEstimator = i
        transformers: List[Transformer] = []
        for i, stage in enumerate(stages):
            if i <= indexOfLastEstimator:
                if isinstance(stage, Transformer):
                    transformers.append(stage)
                    dataset = stage.transform(dataset)
                else:  # must be an Estimator
                    model = stage.fit(dataset)
                    transformers.append(model)
                    if i < indexOfLastEstimator:
                        dataset = model.transform(dataset)
            else:
                # Trailing Transformers are carried into the model unchanged.
                transformers.append(cast(Transformer, stage))
        return PipelineModel(transformers)

    def copy(self, extra: Optional["ParamMap"] = None) -> "Pipeline":
        """
        Creates a copy of this instance.

        .. versionadded:: 1.4.0

        Parameters
        ----------
        extra : dict, optional
            extra parameters

        Returns
        -------
        :py:class:`Pipeline`
            new instance
        """
        if extra is None:
            extra = dict()
        that = Params.copy(self, extra)
        stages = [stage.copy(extra) for stage in that.getStages()]
        return that.setStages(stages)

[docs]@since("2.0.0")defwrite(self)->MLWriter:"""Returns an MLWriter instance for this ML instance."""allStagesAreJava=PipelineSharedReadWrite.checkStagesForJava(self.getStages())ifallStagesAreJava:returnJavaMLWriter(self)# type: ignore[arg-type]returnPipelineWriter(self)
    @classmethod
    @since("2.0.0")
    def read(cls) -> "PipelineReader":
        """Returns an MLReader instance for this class."""
        return PipelineReader(cls)

    @classmethod
    def _from_java(cls, java_stage: "JavaObject") -> "Pipeline":
        """
        Given a Java Pipeline, create and return a Python wrapper of it.
        Used for ML persistence.
        """
        # Create a new instance of this stage.
        py_stage = cls()
        # Load information from java_stage to the instance.
        py_stages: List["PipelineStage"] = [
            JavaParams._from_java(s) for s in java_stage.getStages()
        ]
        py_stage.setStages(py_stages)
        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self) -> "JavaObject":
        """
        Transfer this instance to a Java Pipeline.  Used for ML persistence.

        Returns
        -------
        py4j.java_gateway.JavaObject
            Java object equivalent to this instance.
        """
        from pyspark.core.context import SparkContext

        gateway = SparkContext._gateway
        assert gateway is not None and SparkContext._jvm is not None

        cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage
        java_stages = gateway.new_array(cls, len(self.getStages()))
        for idx, stage in enumerate(self.getStages()):
            java_stages[idx] = cast(JavaParams, stage)._to_java()

        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
        _java_obj.setStages(java_stages)
        return _java_obj

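# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the upstream module): a minimal
# example of chaining Transformers and a final Estimator with Pipeline.fit(),
# assuming an active SparkSession and the standard Tokenizer, HashingTF, and
# LogisticRegression stages from pyspark.ml. The helper function and its toy
# data are hypothetical and exist only so importing this module stays
# side-effect free.
def _pipeline_usage_sketch(spark: SparkSession) -> DataFrame:
    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.feature import HashingTF, Tokenizer

    training = spark.createDataFrame(
        [(0, "a b c d e spark", 1.0), (1, "b d", 0.0)],
        ["id", "text", "label"],
    )
    tokenizer = Tokenizer(inputCol="text", outputCol="words")
    hashing_tf = HashingTF(inputCol="words", outputCol="features")
    lr = LogisticRegression(maxIter=10)
    # The two Transformers are applied in order during fit(); the final
    # Estimator is then fit on their output and the whole chain is returned
    # as a PipelineModel.
    pipeline = Pipeline(stages=[tokenizer, hashing_tf, lr])
    model = pipeline.fit(training)
    return model.transform(training)
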
@inherit_doc
class PipelineWriter(MLWriter):
    """
    (Private) Specialization of :py:class:`MLWriter` for :py:class:`Pipeline` types
    """

    def __init__(self, instance: Pipeline):
        super(PipelineWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path: str) -> None:
        stages = self.instance.getStages()
        PipelineSharedReadWrite.validateStages(stages)
        PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sparkSession, path)


@inherit_doc
class PipelineReader(MLReader[Pipeline]):
    """
    (Private) Specialization of :py:class:`MLReader` for :py:class:`Pipeline` types
    """

    def __init__(self, cls: Type[Pipeline]):
        super(PipelineReader, self).__init__()
        self.cls = cls

    def load(self, path: str) -> Pipeline:
        metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
        if "language" not in metadata["paramMap"] or metadata["paramMap"]["language"] != "Python":
            return JavaMLReader(cast(Type["JavaMLReadable[Pipeline]"], self.cls)).load(path)
        else:
            uid, stages = PipelineSharedReadWrite.load(metadata, self.sparkSession, path)
            return Pipeline(stages=stages)._resetUid(uid)


@inherit_doc
class PipelineModelWriter(MLWriter):
    """
    (Private) Specialization of :py:class:`MLWriter` for :py:class:`PipelineModel` types
    """

    def __init__(self, instance: "PipelineModel"):
        super(PipelineModelWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path: str) -> None:
        stages = self.instance.stages
        PipelineSharedReadWrite.validateStages(cast(List["PipelineStage"], stages))
        PipelineSharedReadWrite.saveImpl(
            self.instance, cast(List["PipelineStage"], stages), self.sparkSession, path
        )


@inherit_doc
class PipelineModelReader(MLReader["PipelineModel"]):
    """
    (Private) Specialization of :py:class:`MLReader` for :py:class:`PipelineModel` types
    """

    def __init__(self, cls: Type["PipelineModel"]):
        super(PipelineModelReader, self).__init__()
        self.cls = cls

    def load(self, path: str) -> "PipelineModel":
        metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
        if "language" not in metadata["paramMap"] or metadata["paramMap"]["language"] != "Python":
            return JavaMLReader(cast(Type["JavaMLReadable[PipelineModel]"], self.cls)).load(path)
        else:
            uid, stages = PipelineSharedReadWrite.load(metadata, self.sparkSession, path)
            return PipelineModel(stages=cast(List[Transformer], stages))._resetUid(uid)

@inherit_doc
class PipelineModel(Model, MLReadable["PipelineModel"], MLWritable):
    """
    Represents a compiled pipeline with transformers and fitted models.

    .. versionadded:: 1.3.0
    """

    def __init__(self, stages: List[Transformer]):
        super(PipelineModel, self).__init__()
        self.stages = stages

    def _transform(self, dataset: DataFrame) -> DataFrame:
        for t in self.stages:
            dataset = t.transform(dataset)
        return dataset

    def copy(self, extra: Optional["ParamMap"] = None) -> "PipelineModel":
        """
        Creates a copy of this instance.

        .. versionadded:: 1.4.0

        :param extra: extra parameters
        :returns: new instance
        """
        if extra is None:
            extra = dict()
        stages = [stage.copy(extra) for stage in self.stages]
        return PipelineModel(stages)

[docs]@since("2.0.0")defwrite(self)->MLWriter:"""Returns an MLWriter instance for this ML instance."""allStagesAreJava=PipelineSharedReadWrite.checkStagesForJava(cast(List["PipelineStage"],self.stages))ifallStagesAreJava:returnJavaMLWriter(self)# type: ignore[arg-type]returnPipelineModelWriter(self)
    @classmethod
    @since("2.0.0")
    def read(cls) -> PipelineModelReader:
        """Returns an MLReader instance for this class."""
        return PipelineModelReader(cls)

    @classmethod
    def _from_java(cls, java_stage: "JavaObject") -> "PipelineModel":
        """
        Given a Java PipelineModel, create and return a Python wrapper of it.
        Used for ML persistence.
        """
        # Load information from java_stage to the instance.
        py_stages: List[Transformer] = [JavaParams._from_java(s) for s in java_stage.stages()]
        # Create a new instance of this stage.
        py_stage = cls(py_stages)
        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self) -> "JavaObject":
        """
        Transfer this instance to a Java PipelineModel.  Used for ML persistence.

        :return: Java object equivalent to this instance.
        """
        from pyspark.core.context import SparkContext

        gateway = SparkContext._gateway
        assert gateway is not None and SparkContext._jvm is not None

        cls = SparkContext._jvm.org.apache.spark.ml.Transformer
        java_stages = gateway.new_array(cls, len(self.stages))
        for idx, stage in enumerate(self.stages):
            java_stages[idx] = cast(JavaParams, stage)._to_java()

        _java_obj = JavaParams._new_java_obj(
            "org.apache.spark.ml.PipelineModel", self.uid, java_stages
        )
        return _java_obj

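# ---------------------------------------------------------------------------
# Illustrative persistence sketch (not part of the upstream module): a fitted
# PipelineModel is saved through write() above, which dispatches to
# JavaMLWriter when every stage is Java-backed and to PipelineModelWriter
# otherwise, and is reloaded via MLReadable.load (shorthand for
# read().load(path)). The `model` and `path` arguments are assumed to come
# from the caller; the helper name is hypothetical.
def _pipeline_model_persistence_sketch(model: PipelineModel, path: str) -> PipelineModel:
    # overwrite() lets the sketch be re-run against an existing directory.
    model.write().overwrite().save(path)
    return PipelineModel.load(path)
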
@inherit_doc
class PipelineSharedReadWrite:
    """
    Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between
    :py:class:`Pipeline` and :py:class:`PipelineModel`

    .. versionadded:: 2.3.0
    """

    @staticmethod
    def checkStagesForJava(stages: List["PipelineStage"]) -> bool:
        return all(isinstance(stage, JavaMLWritable) for stage in stages)

    @staticmethod
    def validateStages(stages: List["PipelineStage"]) -> None:
        """
        Check that all stages are Writable
        """
        for stage in stages:
            if not isinstance(stage, MLWritable):
                raise ValueError(
                    "Pipeline write will fail on this pipeline "
                    "because stage %s of type %s is not MLWritable" % (stage.uid, type(stage))
                )

    @staticmethod
    def saveImpl(
        instance: Union[Pipeline, PipelineModel],
        stages: List["PipelineStage"],
        sc: Union["SparkContext", SparkSession],
        path: str,
    ) -> None:
        """
        Save metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
        - save metadata to path/metadata
        - save stages to stages/IDX_UID
        """
        stageUids = [stage.uid for stage in stages]
        jsonParams = {"stageUids": stageUids, "language": "Python"}
        DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams)
        stagesDir = os.path.join(path, "stages")
        for index, stage in enumerate(stages):
            cast(MLWritable, stage).write().save(
                PipelineSharedReadWrite.getStagePath(stage.uid, index, len(stages), stagesDir)
            )

    @staticmethod
    def load(
        metadata: Dict[str, Any],
        sc: Union["SparkContext", SparkSession],
        path: str,
    ) -> Tuple[str, List["PipelineStage"]]:
        """
        Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`

        Returns
        -------
        tuple
            (UID, list of stages)
        """
        stagesDir = os.path.join(path, "stages")
        stageUids = metadata["paramMap"]["stageUids"]
        stages = []
        for index, stageUid in enumerate(stageUids):
            stagePath = PipelineSharedReadWrite.getStagePath(
                stageUid, index, len(stageUids), stagesDir
            )
            stage: "PipelineStage" = DefaultParamsReader.loadParamsInstance(stagePath, sc)
            stages.append(stage)
        return (metadata["uid"], stages)

    @staticmethod
    def getStagePath(stageUid: str, stageIdx: int, numStages: int, stagesDir: str) -> str:
        """
        Get path for saving the given stage.
        """
        stageIdxDigits = len(str(numStages))
        stageDir = str(stageIdx).zfill(stageIdxDigits) + "_" + stageUid
        stagePath = os.path.join(stagesDir, stageDir)
        return stagePath
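
# ---------------------------------------------------------------------------
# Illustrative layout sketch (not part of the upstream module): getStagePath
# zero-pads the stage index to the width of the stage count and joins it with
# the stage uid, so stages are written under <path>/stages/IDX_UID. The uid
# and directory below are made up for illustration.
if __name__ == "__main__":
    example = PipelineSharedReadWrite.getStagePath(
        "Tokenizer_abc123", 0, 12, "/tmp/pipeline_demo/stages"
    )
    print(example)  # /tmp/pipeline_demo/stages/00_Tokenizer_abc123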