@@ -11,26 +11,22 @@ case class SparkConfig(spark: SparkSession, partitionNum: Int)
11
11
12
12
object SparkConfig {
13
13
14
- var spark : SparkSession = _
15
-
16
- var partitionNum : String = _
17
-
18
14
def getSpark (configs : Configs , defaultAppName : String = " algorithm" ): SparkConfig = {
19
15
val sparkConfigs = configs.sparkConfig.map
20
16
val session = SparkSession .builder
21
17
.appName(defaultAppName)
22
18
.config(" spark.serializer" , " org.apache.spark.serializer.KryoSerializer" )
23
19
24
- for (key <- sparkConfigs.keySet) {
25
- session.config(key, sparkConfigs(key) )
20
+ sparkConfigs.foreach { case (key, value) =>
21
+ session.config(key, value )
26
22
}
27
- partitionNum = sparkConfigs.getOrElse(" spark.app.partitionNum" , " 0" )
23
+ val partitionNum = sparkConfigs.getOrElse(" spark.app.partitionNum" , " 0" )
28
24
val spark = session.getOrCreate()
29
25
validate(spark.version, " 2.4.*" )
30
26
SparkConfig (spark, partitionNum.toInt)
31
27
}
32
28
33
- def validate (sparkVersion : String , supportedVersions : String * ): Unit = {
29
+ private def validate (sparkVersion : String , supportedVersions : String * ): Unit = {
34
30
if (sparkVersion != " UNKNOWN" && ! supportedVersions.exists(sparkVersion.matches)) {
35
31
throw new RuntimeException (
36
32
s """ Your current spark version ${sparkVersion} is not supported by the current NebulaGraph Algorithm.
0 commit comments