Commit 656d01c

Cast partitionNum to Int (#91)
1 parent 371ffbb commit 656d01c

4 files changed: +15 -20 lines changed

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/Main.scala

+1 -1
@@ -76,7 +76,7 @@ object Main {
     */
   private[this] def createDataSource(spark: SparkSession,
                                      configs: Configs,
-                                     partitionNum: String): DataFrame = {
+                                     partitionNum: Int): DataFrame = {
    val dataSource = DataReader.make(configs)
    dataSource.read(spark, configs, partitionNum)
  }
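With this change Main no longer carries the raw String; the single String-to-Int cast moves into SparkConfig (next file) and the Int flows straight through. A sketch of the resulting call shape, assuming the surrounding wiring in Main that this diff does not show:

```scala
// Hypothetical wiring, not part of the commit: SparkConfig.getSpark now
// yields partitionNum as an Int, so it reaches createDataSource unchanged.
val sparkConfig = SparkConfig.getSpark(configs)
val df = createDataSource(sparkConfig.spark, configs, sparkConfig.partitionNum)
```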

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/SparkConfig.scala

+2 -2
@@ -7,7 +7,7 @@ package com.vesoft.nebula.algorithm.config
 
 import org.apache.spark.sql.SparkSession
 
-case class SparkConfig(spark: SparkSession, partitionNum: String)
+case class SparkConfig(spark: SparkSession, partitionNum: Int)
 
 object SparkConfig {
 
@@ -27,7 +27,7 @@ object SparkConfig {
       partitionNum = sparkConfigs.getOrElse("spark.app.partitionNum", "0")
    val spark = session.getOrCreate()
    validate(spark.version, "2.4.*")
-    SparkConfig(spark, partitionNum)
+    SparkConfig(spark, partitionNum.toInt)
  }
 
  def validate(sparkVersion: String, supportedVersions: String*): Unit = {
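The raw value still arrives as a String (the Spark config map is string-valued, with "0" as the default), so the cast now happens exactly once at load time. A minimal, dependency-free sketch of the same pattern; the sample map entry is an assumption for illustration, not from the commit:

```scala
// Mirrors the getOrElse-then-toInt pattern from the hunk above.
object PartitionNumSketch {
  def main(args: Array[String]): Unit = {
    val sparkConfigs = Map("spark.app.partitionNum" -> "100") // assumed sample input
    val partitionNum = sparkConfigs.getOrElse("spark.app.partitionNum", "0").toInt
    println(partitionNum) // 100; a non-numeric value fails fast with NumberFormatException
  }
}
```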

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala

+11 -16
@@ -14,7 +14,7 @@ import scala.collection.mutable.ListBuffer
 
 abstract class DataReader {
   val tpe: ReaderType
-  def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame
+  def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame
 }
 object DataReader {
   def make(configs: Configs): DataReader = {
@@ -32,12 +32,11 @@ object DataReader {
 
 class NebulaReader extends DataReader {
   override val tpe: ReaderType = ReaderType.nebula
-  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
+  override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = {
     val metaAddress = configs.nebulaConfig.readConfigEntry.address
     val space = configs.nebulaConfig.readConfigEntry.space
     val labels = configs.nebulaConfig.readConfigEntry.labels
     val weights = configs.nebulaConfig.readConfigEntry.weightCols
-    val partition = partitionNum.toInt
 
     val config =
       NebulaConnectionConfig
@@ -60,7 +59,7 @@ class NebulaReader extends DataReader {
           .withLabel(labels(i))
           .withNoColumn(noColumn)
           .withReturnCols(returnCols.toList)
-          .withPartitionNum(partition)
+          .withPartitionNum(partitionNum)
           .build()
       if (dataset == null) {
         dataset = spark.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF()
@@ -85,13 +84,12 @@ final class NebulaNgqlReader extends NebulaReader {
 
   override val tpe: ReaderType = ReaderType.nebulaNgql
 
-  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
+  override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = {
     val metaAddress = configs.nebulaConfig.readConfigEntry.address
     val graphAddress = configs.nebulaConfig.readConfigEntry.graphAddress
     val space = configs.nebulaConfig.readConfigEntry.space
     val labels = configs.nebulaConfig.readConfigEntry.labels
     val weights = configs.nebulaConfig.readConfigEntry.weightCols
-    val partition = partitionNum.toInt
     val ngql = configs.nebulaConfig.readConfigEntry.ngql
 
     val config =
@@ -112,7 +110,7 @@ final class NebulaNgqlReader extends NebulaReader {
         .builder()
         .withSpace(space)
         .withLabel(labels(i))
-        .withPartitionNum(partition)
+        .withPartitionNum(partitionNum)
         .withNgql(ngql)
         .build()
       if (dataset == null) {
@@ -137,13 +135,11 @@ final class NebulaNgqlReader
 
 final class CsvReader extends DataReader {
   override val tpe: ReaderType = ReaderType.csv
-  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
+  override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = {
     val delimiter = configs.localConfigEntry.delimiter
     val header = configs.localConfigEntry.header
     val localPath = configs.localConfigEntry.filePath
 
-    val partition = partitionNum.toInt
-
     val data =
       spark.read
         .option("header", header)
@@ -157,18 +153,17 @@ final class CsvReader extends DataReader {
     } else {
       data.select(src, dst)
     }
-    if (partition != 0) {
-      data.repartition(partition)
+    if (partitionNum != 0) {
+      data.repartition(partitionNum)
     }
     data
   }
 }
 final class JsonReader extends DataReader {
   override val tpe: ReaderType = ReaderType.json
-  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
+  override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = {
     val localPath = configs.localConfigEntry.filePath
     val data = spark.read.json(localPath)
-    val partition = partitionNum.toInt
 
     val weight = configs.localConfigEntry.weight
     val src = configs.localConfigEntry.srcId
@@ -178,8 +173,8 @@ final class JsonReader extends DataReader {
     } else {
       data.select(src, dst)
     }
-    if (partition != 0) {
-      data.repartition(partition)
+    if (partitionNum != 0) {
+      data.repartition(partitionNum)
     }
     data
   }
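One note for anyone adapting the repartition guard above: Spark DataFrames are immutable, so `data.repartition(partitionNum)` returns a new DataFrame rather than modifying `data` in place. A hedged sketch, not part of this commit, that captures the result so the repartitioning actually takes effect:

```scala
import org.apache.spark.sql.DataFrame

// Zero keeps Spark's default partitioning; any other value is applied
// by returning the repartitioned DataFrame instead of discarding it.
def withPartitions(df: DataFrame, partitionNum: Int): DataFrame =
  if (partitionNum != 0) df.repartition(partitionNum) else df
```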

nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/config/ConfigSuite.scala

+1 -1
@@ -46,7 +46,7 @@ class ConfigSuite {
     assert(sparkConfig.map.size == 3)
 
     val spark = SparkConfig.getSpark(configs)
-    assert(spark.partitionNum.toInt == 100)
+    assert(spark.partitionNum == 100)
   }
 
   @Test
