#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Callable, TYPE_CHECKING

if TYPE_CHECKING:
    from pyspark.mllib._typing import C, JavaObjectOrPickleDump

import py4j.protocol
from py4j.protocol import Py4JJavaError
from py4j.java_gateway import JavaObject
from py4j.java_collections import JavaArray, JavaList

import pyspark.core.context
from pyspark import RDD, SparkContext
from pyspark.serializers import CPickleSerializer, AutoBatchedSerializer
from pyspark.sql import DataFrame, SparkSession


# Hack for support float('inf') in Py4j
_old_smart_decode = py4j.protocol.smart_decode

# py4j renders Python floats with str(); map the special values to the
# spellings the JVM side understands.
_float_str_mapping = {
    "nan": "NaN",
    "inf": "Infinity",
    "-inf": "-Infinity",
}


def _new_smart_decode(obj: Any) -> str:
    """Decode like py4j's smart_decode, but translate non-finite floats
    (nan/inf/-inf) into their Java string spellings."""
    if isinstance(obj, float):
        s = str(obj)
        return _float_str_mapping.get(s, s)
    return _old_smart_decode(obj)


# Monkey-patch py4j so every float crossing the gateway goes through the
# non-finite-aware decoder above.
py4j.protocol.smart_decode = _new_smart_decode


# Simple class names (JVM side) whose instances SerDe.dumps() can pickle
# back to Python; used by _java2py below.
_picklable_classes = [
    "LinkedList",
    "SparseVector",
    "DenseVector",
    "DenseMatrix",
    "Rating",
    "LabeledPoint",
]


# this will call the MLlib version of pythonToJava()
def _to_java_object_rdd(rdd: RDD) -> JavaObject:
    """Return a JavaRDD of Object by unpickling

    It will convert each Python object into Java object by Pickle, whenever
    the RDD is serialized in batch or not.
    """
    # Re-serialize with an auto-batched pickle serializer so the JVM-side
    # SerDe can unpickle elements uniformly (batched flag = True).
    rdd = rdd._reserialize(AutoBatchedSerializer(CPickleSerializer()))
    assert rdd.ctx._jvm is not None
    return rdd.ctx._jvm.org.apache.spark.mllib.api.python.SerDe.pythonToJava(rdd._jrdd, True)


def _py2java(sc: SparkContext, obj: Any) -> JavaObject:
    """Convert Python object into Java"""
    if isinstance(obj, RDD):
        obj = _to_java_object_rdd(obj)
    elif isinstance(obj, DataFrame):
        obj = obj._jdf
    elif isinstance(obj, SparkContext):
        obj = obj._jsc
    elif isinstance(obj, list):
        # Convert element-wise; lists may mix convertible types.
        obj = [_py2java(sc, x) for x in obj]
    elif isinstance(obj, JavaObject):
        pass
    elif isinstance(obj, (int, float, bool, bytes, str)):
        # py4j converts primitives natively.
        pass
    else:
        # Fall back to pickling the object and unpickling it on the JVM side.
        data = bytearray(CPickleSerializer().dumps(obj))
        assert sc._jvm is not None
        obj = sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(data)
    return obj


def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "bytes") -> Any:
    """Convert a Java object (or a pickle byte dump) returned through py4j
    into the corresponding Python object."""
    if isinstance(r, JavaObject):
        clsName = r.getClass().getSimpleName()
        # convert RDD into JavaRDD
        if clsName != "JavaRDD" and clsName.endswith("RDD"):
            r = r.toJavaRDD()
            clsName = "JavaRDD"
        assert sc._jvm is not None
        if clsName == "JavaRDD":
            jrdd = sc._jvm.org.apache.spark.mllib.api.python.SerDe.javaToPython(r)
            return RDD(jrdd, sc)

        if clsName == "Dataset":
            return DataFrame(r, SparkSession._getActiveSessionOrCreate())

        if clsName in _picklable_classes:
            r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r)
        elif isinstance(r, (JavaArray, JavaList)):
            try:
                r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r)
            except Py4JJavaError:
                pass  # not pickable

    if isinstance(r, (bytearray, bytes)):
        r = CPickleSerializer().loads(bytes(r), encoding=encoding)
    return r


def callJavaFunc(
    sc: pyspark.core.context.SparkContext,
    func: Callable[..., "JavaObjectOrPickleDump"],
    *args: Any,
) -> Any:
    """Call Java Function"""
    # Convert each argument to its Java form, then translate the result back.
    java_args = [_py2java(sc, a) for a in args]
    return _java2py(sc, func(*java_args))


def callMLlibFunc(name: str, *args: Any) -> Any:
    """Call API in PythonMLLibAPI"""
    sc = SparkContext.getOrCreate()
    assert sc._jvm is not None
    api = getattr(sc._jvm.PythonMLLibAPI(), name)
    return callJavaFunc(sc, api, *args)


class JavaModelWrapper:
    """
    Wrapper for the model in JVM
    """

    def __init__(self, java_model: JavaObject):
        self._sc = SparkContext.getOrCreate()
        self._java_model = java_model

    def __del__(self) -> None:
        # Detach the JVM-side model so py4j can release it when the Python
        # wrapper is garbage-collected.
        assert self._sc._gateway is not None
        self._sc._gateway.detach(self._java_model)

    def call(self, name: str, *a: Any) -> Any:
        """Call method of java_model"""
        return callJavaFunc(self._sc, getattr(self._java_model, name), *a)


def inherit_doc(cls: "C") -> "C":
    """
    A decorator that makes a class inherit documentation from its parents.
    """
    for name, func in vars(cls).items():
        # only inherit docstring for public functions
        if name.startswith("_"):
            continue
        if not func.__doc__:
            for parent in cls.__bases__:
                parent_func = getattr(parent, name, None)
                if parent_func and getattr(parent_func, "__doc__", None):
                    func.__doc__ = parent_func.__doc__
                    break
    return cls