@@ -14,7 +14,7 @@ import scala.collection.mutable.ListBuffer

 abstract class DataReader {
   val tpe: ReaderType
-  def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame
+  def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame
 }
 object DataReader {
   def make(configs: Configs): DataReader = {
@@ -32,12 +32,11 @@ object DataReader {

 class NebulaReader extends DataReader {
   override val tpe: ReaderType = ReaderType.nebula
-  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
+  override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = {
     val metaAddress = configs.nebulaConfig.readConfigEntry.address
     val space = configs.nebulaConfig.readConfigEntry.space
     val labels = configs.nebulaConfig.readConfigEntry.labels
     val weights = configs.nebulaConfig.readConfigEntry.weightCols
-    val partition = partitionNum.toInt

     val config =
       NebulaConnectionConfig
@@ -60,7 +59,7 @@ class NebulaReader extends DataReader {
         .withLabel(labels(i))
         .withNoColumn(noColumn)
         .withReturnCols(returnCols.toList)
-        .withPartitionNum(partition)
+        .withPartitionNum(partitionNum)
         .build()
       if (dataset == null) {
         dataset = spark.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF()
@@ -85,13 +84,12 @@ final class NebulaNgqlReader extends NebulaReader {

   override val tpe: ReaderType = ReaderType.nebulaNgql

-  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
+  override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = {
     val metaAddress = configs.nebulaConfig.readConfigEntry.address
     val graphAddress = configs.nebulaConfig.readConfigEntry.graphAddress
     val space = configs.nebulaConfig.readConfigEntry.space
     val labels = configs.nebulaConfig.readConfigEntry.labels
     val weights = configs.nebulaConfig.readConfigEntry.weightCols
-    val partition = partitionNum.toInt
     val ngql = configs.nebulaConfig.readConfigEntry.ngql

     val config =
@@ -112,7 +110,7 @@ final class NebulaNgqlReader extends NebulaReader {
         .builder()
         .withSpace(space)
         .withLabel(labels(i))
-        .withPartitionNum(partition)
+        .withPartitionNum(partitionNum)
         .withNgql(ngql)
         .build()
       if (dataset == null) {
@@ -137,13 +135,11 @@ final class NebulaNgqlReader extends NebulaReader {

 final class CsvReader extends DataReader {
   override val tpe: ReaderType = ReaderType.csv
-  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
+  override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = {
     val delimiter = configs.localConfigEntry.delimiter
     val header = configs.localConfigEntry.header
     val localPath = configs.localConfigEntry.filePath

-    val partition = partitionNum.toInt
-
     val data =
       spark.read
         .option("header", header)
@@ -157,18 +153,17 @@ final class CsvReader extends DataReader {
     } else {
       data.select(src, dst)
     }
-    if (partition != 0) {
-      data.repartition(partition)
+    if (partitionNum != 0) {
+      data.repartition(partitionNum)
     }
     data
   }
 }
 final class JsonReader extends DataReader {
   override val tpe: ReaderType = ReaderType.json
-  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
+  override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = {
     val localPath = configs.localConfigEntry.filePath
     val data = spark.read.json(localPath)
-    val partition = partitionNum.toInt

     val weight = configs.localConfigEntry.weight
     val src = configs.localConfigEntry.srcId
@@ -178,8 +173,8 @@ final class JsonReader extends DataReader {
     } else {
       data.select(src, dst)
     }
-    if (partition != 0) {
-      data.repartition(partition)
+    if (partitionNum != 0) {
+      data.repartition(partitionNum)
    }
     data
   }
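
Taken together, the diff tightens the `read` contract: callers now pass the partition count as an `Int`, and each reader's local `partitionNum.toInt` parsing disappears. A minimal caller sketch under the new signature (assuming a `SparkSession` and `Configs` are already constructed in scope; the value 10 is an arbitrary example, not from this commit):

  // Sketch only: `spark` and `configs` are assumed to exist in the caller's scope.
  val reader: DataReader = DataReader.make(configs)

  // Pass the partition count directly as an Int; per the guards shown above in
  // CsvReader and JsonReader, a value of 0 skips the repartition step entirely.
  val edges: DataFrame = reader.read(spark, configs, 10)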