#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import gc
import os
import sys
from tempfile import NamedTemporaryFile
import threading
import pickle
from typing import (
    overload,
    Any,
    Callable,
    Dict,
    Generic,
    IO,
    Iterator,
    Optional,
    Tuple,
    TypeVar,
    TYPE_CHECKING,
    Union,
)
from typing.io import BinaryIO  # type: ignore[import]

from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import ChunkedStream, pickle_protocol
from pyspark.util import print_exec

if TYPE_CHECKING:
    from pyspark import SparkContext


__all__ = ["Broadcast"]

T = TypeVar("T")


# Holds broadcasted data received from Java, keyed by its id.
_broadcastRegistry: Dict[int, "Broadcast[Any]"] = {}


def _from_id(bid: int) -> "Broadcast[Any]":
    from pyspark.broadcast import _broadcastRegistry

    if bid not in _broadcastRegistry:
        raise RuntimeError("Broadcast variable '%s' not loaded!" % bid)
    return _broadcastRegistry[bid]
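

# On executors, the worker entry point is expected to populate ``_broadcastRegistry``
# (keyed by broadcast id) before task closures are unpickled, so that ``_from_id``
# above can resolve references. A minimal sketch of what that registration might look
# like, with hypothetical ``bid``/``path``/``sock`` values (the real wiring lives in
# the worker code, not in this module):
#
#   from pyspark.broadcast import Broadcast, _broadcastRegistry
#   _broadcastRegistry[bid] = Broadcast(path=path)        # plain, file-backed case
#   _broadcastRegistry[bid] = Broadcast(sock_file=sock)    # encrypted case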


class Broadcast(Generic[T]):
    """
    A broadcast variable created with :meth:`SparkContext.broadcast`.
    Access its value through :attr:`value`.

    Examples
    --------
    >>> b = spark.sparkContext.broadcast([1, 2, 3, 4, 5])
    >>> b.value
    [1, 2, 3, 4, 5]
    >>> spark.sparkContext.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
    [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
    >>> b.unpersist()

    >>> large_broadcast = spark.sparkContext.broadcast(range(10000))
    """

    @overload  # On driver
    def __init__(
        self: "Broadcast[T]",
        sc: "SparkContext",
        value: T,
        pickle_registry: "BroadcastPickleRegistry",
    ):
        ...

    @overload  # On worker without decryption server
    def __init__(self: "Broadcast[Any]", *, path: str):
        ...

    @overload  # On worker with decryption server
    def __init__(self: "Broadcast[Any]", *, sock_file: str):
        ...

    def __init__(
        self,
        sc: Optional["SparkContext"] = None,
        value: Optional[T] = None,
        pickle_registry: Optional["BroadcastPickleRegistry"] = None,
        path: Optional[str] = None,
        sock_file: Optional[BinaryIO] = None,
    ):
        """
        Should not be called directly by users -- use :meth:`SparkContext.broadcast`
        instead.
        """
        if sc is not None:
            # we're on the driver. We want the pickled data to end up in a file (maybe encrypted)
            f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
            self._path = f.name
            self._sc: Optional["SparkContext"] = sc
            assert sc._jvm is not None
            self._python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)
            broadcast_out: Union[ChunkedStream, IO[bytes]]
            if sc._encryption_enabled:
                # with encryption, we ask the jvm to do the encryption for us, we send it data
                # over a socket
                port, auth_secret = self._python_broadcast.setupEncryptionServer()
                (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret)
                broadcast_out = ChunkedStream(encryption_sock_file, 8192)
            else:
                # no encryption, we can just write pickled data directly to the file from python
                broadcast_out = f
            self.dump(value, broadcast_out)  # type: ignore[arg-type]
            if sc._encryption_enabled:
                self._python_broadcast.waitTillDataReceived()
            self._jbroadcast = sc._jsc.broadcast(self._python_broadcast)
            self._pickle_registry = pickle_registry
        else:
            # we're on an executor
            self._jbroadcast = None
            self._sc = None
            self._python_broadcast = None
            if sock_file is not None:
                # the jvm is doing decryption for us. Read the value
                # immediately from the sock_file
                self._value = self.load(sock_file)
            else:
                # the jvm just dumps the pickled data in path -- we'll unpickle lazily when
                # the value is requested
                assert path is not None
                self._path = path
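
    # The overloads above correspond to three construction sites. A rough sketch of the
    # call shapes (``sc``, ``value``, ``registry``, ``path`` and ``sock`` are placeholders;
    # user code should go through ``SparkContext.broadcast`` instead of calling these):
    #
    #   Broadcast(sc, value, registry)   # driver: pickles ``value`` to a temp file
    #   Broadcast(path=path)             # executor: lazily unpickles from ``path``
    #   Broadcast(sock_file=sock)        # executor: reads decrypted data eagerly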

    def dump(self, value: T, f: BinaryIO) -> None:
        """
        Write a pickled representation of value to the open file or socket.
        The pickle protocol used is ``pickle.HIGHEST_PROTOCOL``.

        Parameters
        ----------
        value : T
            Value to write.
        f : :class:`BinaryIO`
            File or socket where the pickled value will be stored.

        Examples
        --------
        >>> import os
        >>> import tempfile

        >>> b = spark.sparkContext.broadcast([1, 2, 3, 4, 5])

        Write a pickled representation of `b` to the open temp file.

        >>> with tempfile.TemporaryDirectory() as d:
        ...     path = os.path.join(d, "test.txt")
        ...     with open(path, "wb") as f:
        ...         b.dump(b.value, f)
        """
        try:
            pickle.dump(value, f, pickle_protocol)
        except pickle.PickleError:
            raise
        except Exception as e:
            msg = "Could not serialize broadcast: %s: %s" % (e.__class__.__name__, str(e))
            print_exec(sys.stderr)
            raise pickle.PicklingError(msg)
        f.close()

    def load_from_path(self, path: str) -> T:
        """
        Read the pickled representation of an object from the open file and
        return the reconstituted object hierarchy specified therein.

        Parameters
        ----------
        path : str
            File path from which the pickled value is read.

        Returns
        -------
        T
            The object hierarchy specified therein reconstituted
            from the pickled representation of an object.

        Examples
        --------
        >>> import os
        >>> import tempfile

        >>> b = spark.sparkContext.broadcast([1, 2, 3, 4, 5])
        >>> c = spark.sparkContext.broadcast(1)

        Read the pickled representation of the value from the temp file.

        >>> with tempfile.TemporaryDirectory() as d:
        ...     path = os.path.join(d, "test.txt")
        ...     with open(path, "wb") as f:
        ...         b.dump(b.value, f)
        ...     c.load_from_path(path)
        [1, 2, 3, 4, 5]
        """
        with open(path, "rb", 1 << 20) as f:
            return self.load(f)

    def load(self, file: BinaryIO) -> T:
        """
        Read a pickled representation of value from the open file or socket.

        Parameters
        ----------
        file : :class:`BinaryIO`
            File or socket from which the pickled value will be read.

        Returns
        -------
        T
            The object hierarchy specified therein reconstituted
            from the pickled representation of an object.

        Examples
        --------
        >>> import os
        >>> import tempfile

        >>> b = spark.sparkContext.broadcast([1, 2, 3, 4, 5])
        >>> c = spark.sparkContext.broadcast(1)

        Read the pickled representation of value from the open temp file.

        >>> with tempfile.TemporaryDirectory() as d:
        ...     path = os.path.join(d, "test.txt")
        ...     with open(path, "wb") as f:
        ...         b.dump(b.value, f)
        ...     with open(path, "rb") as f:
        ...         c.load(f)
        [1, 2, 3, 4, 5]
        """
        # Temporarily disable the cyclic garbage collector while unpickling;
        # it is re-enabled in the ``finally`` block.
        gc.disable()
        try:
            return pickle.load(file)
        finally:
            gc.enable()

    @property
    def value(self) -> T:
        """Return the broadcasted value"""
        if not hasattr(self, "_value") and self._path is not None:
            # we only need to decrypt it here when encryption is enabled and
            # if it's on the driver, since executor decryption is handled already
            if self._sc is not None and self._sc._encryption_enabled:
                port, auth_secret = self._python_broadcast.setupDecryptionServer()
                (decrypted_sock_file, _) = local_connect_and_auth(port, auth_secret)
                self._python_broadcast.waitTillBroadcastDataSent()
                return self.load(decrypted_sock_file)
            else:
                self._value = self.load_from_path(self._path)
        return self._value
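
    # A minimal usage sketch for the property above (``sc`` and ``rdd`` are assumed to be
    # an existing SparkContext and RDD): the broadcast object is captured in the closure,
    # and each task reads ``b.value`` locally instead of shipping the data with every task.
    #
    #   b = sc.broadcast({"a": 1, "b": 2})
    #   rdd.map(lambda k: b.value.get(k, 0)).collect()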

    def unpersist(self, blocking: bool = False) -> None:
        """
        Delete cached copies of this broadcast on the executors. If the
        broadcast is used after this is called, it will need to be
        re-sent to each executor.

        Parameters
        ----------
        blocking : bool, optional, default False
            Whether to block until unpersisting has completed.

        Examples
        --------
        >>> b = spark.sparkContext.broadcast([1, 2, 3, 4, 5])

        Delete cached copies of this broadcast on the executors.

        >>> b.unpersist()
        """
        if self._jbroadcast is None:
            raise RuntimeError("Broadcast can only be unpersisted in driver")
        self._jbroadcast.unpersist(blocking)

    def destroy(self, blocking: bool = False) -> None:
        """
        Destroy all data and metadata related to this broadcast variable.
        Use this with caution; once a broadcast variable has been destroyed,
        it cannot be used again.

        .. versionchanged:: 3.0.0
           Added optional argument `blocking` to specify whether to block until all
           blocks are deleted.

        Parameters
        ----------
        blocking : bool, optional, default False
            Whether to block until unpersisting has completed.

        Examples
        --------
        >>> b = spark.sparkContext.broadcast([1, 2, 3, 4, 5])

        Destroy all data and metadata related to this broadcast variable.

        >>> b.destroy()
        """
        if self._jbroadcast is None:
            raise RuntimeError("Broadcast can only be destroyed in driver")
        self._jbroadcast.destroy(blocking)
        os.unlink(self._path)
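
    # A short sketch of how ``unpersist`` and ``destroy`` differ in practice (assuming an
    # active ``spark`` session, as in the doctests above):
    #
    #   b = spark.sparkContext.broadcast([1, 2, 3])
    #   b.unpersist()   # executors drop their cached copies; b can still be used and
    #                   # will be re-sent on the next use
    #   b.destroy()     # all data and metadata are released; further use of b fails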

    def __reduce__(self) -> Tuple[Callable[[int], "Broadcast[T]"], Tuple[int]]:
        if self._jbroadcast is None:
            raise RuntimeError("Broadcast can only be serialized in driver")
        assert self._pickle_registry is not None
        self._pickle_registry.add(self)
        return _from_id, (self._jbroadcast.id(),)
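

# ``__reduce__`` above means that pickling a Broadcast on the driver captures only
# ``(_from_id, (<broadcast id>,))`` rather than the value itself; when a task closure is
# unpickled on an executor, ``_from_id`` looks the id up in ``_broadcastRegistry``.
# A rough illustration (``b`` is assumed to be a driver-side broadcast variable):
#
#   import pickle
#   func, args = b.__reduce__()   # (_from_id, (<broadcast id>,))
#   payload = pickle.dumps(b)     # small: carries only the id, not the value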


class BroadcastPickleRegistry(threading.local):
    """Thread-local registry for broadcast variables that have been pickled"""

    def __init__(self) -> None:
        self.__dict__.setdefault("_registry", set())

    def __iter__(self) -> Iterator[Broadcast[Any]]:
        for bcast in self._registry:
            yield bcast

    def add(self, bcast: Broadcast[Any]) -> None:
        self._registry.add(bcast)

    def clear(self) -> None:
        self._registry.clear()


def _test() -> None:
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.broadcast

    globs = pyspark.broadcast.__dict__.copy()
    spark = SparkSession.builder.master("local[4]").appName("broadcast tests").getOrCreate()
    globs["spark"] = spark
    (failure_count, test_count) = doctest.testmod(pyspark.broadcast, globs=globs)
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()
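

# To exercise this module's doctests directly (assuming pyspark is importable and a local
# Spark installation is available, since ``_test`` starts a real SparkSession), something
# like the following should work:
#
#   python -m pyspark.broadcast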