Skip to content

Commit 0dc354c

Browse files
authored
cherry pick fix data structure for nebula datasource (#33)
1 parent 92940e6 commit 0dc354c

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala

+2
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,8 @@ case class NebulaReadConfigEntry(address: String = "",
197197
space: String = "",
198198
labels: List[String] = List(),
199199
weightCols: List[String] = List()) {
200+
assert(weightCols.isEmpty || labels.size == weightCols.size,
201+
"weightCols must be empty or has the same amount values with labels")
200202
override def toString: String = {
201203
s"NebulaReadConfigEntry: " +
202204
s"{address: $address, space: $space, labels: ${labels.mkString(",")}, " +

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,11 @@ class NebulaReader(spark: SparkSession, configs: Configs, partitionNum: String)
4848
.withReturnCols(returnCols.toList)
4949
.withPartitionNum(partition)
5050
.build()
51-
if (dataset == null) {
52-
dataset = spark.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF()
53-
} else {
54-
dataset = dataset.union(spark.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF())
51+
var df = spark.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF()
52+
if (weights.nonEmpty) {
53+
df = df.select("_srcId", "_dstId", weights(i))
5554
}
55+
dataset = if (dataset == null) df else dataset.union(df)
5656
}
5757
dataset
5858
}

0 commit comments

Comments
 (0)