Commit 21564be

support id mapping for algos and add examples (#68)

* support id mapping for algos
* update config
* add example for graph with string id

1 parent 00efbcd · commit 21564be

19 files changed: +165 -52 lines

README-CN.md (+2 -4)

````diff
@@ -51,9 +51,6 @@
 ```
 ${SPARK_HOME}/bin/spark-submit --master <mode> --class com.vesoft.nebula.algorithm.Main nebula-algorithm-3.0-SNAPSHOT.jar -p application.conf
 ```
-* Usage restriction
-
-  The Nebula Algorithm package does not automatically encode string ids, so when running an algorithm this way the source and target of every edge must be an integer (the vid_type of the Nebula space may be String, but the data itself must be integral).
 * Option 2: call the nebula-algorithm algorithm interface
 
 The `lib` library of `nebula-algorithm` provides 10+ common graph algorithms, which can be invoked programmatically.
@@ -75,7 +72,8 @@
 val prConfig = new PRConfig(5, 1.0)
 val prResult = PageRankAlgo.apply(spark, data, prConfig, false)
 ```
-* If your vertex ids are of type String, refer to the PageRank [Example](https://github.com/vesoft-inc/nebula-algorithm/blob/master/example/src/main/scala/com/vesoft/nebula/algorithm/PageRankExample.scala).
+* If your vertex ids are of type String, refer to the [examples](https://github.com/vesoft-inc/nebula-algorithm/tree/master/example/src/main/scala/com/vesoft/nebula/algorithm/DegreeStaticExample.scala).
+
 The example performs id conversion: it encodes String ids into Long ids and, in the algorithm result, decodes the Long ids back to the original String ids.
 
 See the [test examples](https://github.com/vesoft-inc/nebula-algorithm/tree/master/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib) for how to invoke the other algorithms.
````

README.md (+1 -5)

````diff
@@ -60,10 +60,6 @@
 ${SPARK_HOME}/bin/spark-submit --master <mode> --class com.vesoft.nebula.algorithm.Main nebula-algorithm-3.0-SNAPSHOT.jar -p application.conf
 ```
 
-* Limitation
-
-  Because the Nebula Algorithm jar does not encode string ids, the source and target of edges must be of type Int during algorithm execution (the `vid_type` in the Nebula space could be String, while the data must be of type Int).
-
 * Option2: Call nebula-algorithm interface
 
 Now there are 10+ algorithms provided in `lib` from `nebula-algorithm`, which could be invoked in a programming fashion as below:
@@ -87,7 +83,7 @@
 val prResult = PageRankAlgo.apply(spark, data, prConfig, false)
 ```
 
-If your vertex ids are Strings, see the [PageRank Example](https://github.com/vesoft-inc/nebula-algorithm/blob/master/example/src/main/scala/com/vesoft/nebula/algorithm/PageRankExample.scala) for how to encode and decode them.
+If your vertex ids are Strings, set `encodeId = true` in the algo config; see the [examples](https://github.com/vesoft-inc/nebula-algorithm/tree/master/example/src/main/scala/com/vesoft/nebula/algorithm/DegreeStaticExample.scala).
 
 For examples of other algorithms, see [examples](https://github.com/vesoft-inc/nebula-algorithm/tree/master/example/src/main/scala/com/vesoft/nebula/algorithm)
 > Note: The first column of DataFrame in the application represents the source vertices, the second represents the target vertices and the third represents edges' weight.
````
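
The new calling convention can be sketched roughly as follows (a minimal sketch, not part of the commit: the CSV path and object name are illustrative, and it assumes the nebula-algorithm 3.0-SNAPSHOT jar and Spark are on the classpath):

```scala
import com.vesoft.nebula.algorithm.config.DegreeStaticConfig
import com.vesoft.nebula.algorithm.lib.DegreeStaticAlgo
import org.apache.spark.sql.{DataFrame, SparkSession}

object StringIdDegreeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").getOrCreate()

    // Edge list whose first two columns hold String vertex ids.
    // "string_edge.csv" is a placeholder path.
    val edges: DataFrame = spark.read.option("header", true).csv("string_edge.csv")

    // encodeId = true makes the algorithm map String ids to Long ids
    // internally and decode them back to Strings in the result.
    val config = DegreeStaticConfig(true)
    DegreeStaticAlgo.apply(spark, edges, config).show()
  }
}
```

The same pattern applies to the other algorithms whose configs gained an `encodeId` flag in this commit (e.g. `TriangleConfig`, `JaccardConfig`, `HanpConfig`).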

example/pom.xml (+1 -1)

```diff
@@ -182,7 +182,7 @@
     <dependency>
       <groupId>com.vesoft</groupId>
       <artifactId>nebula-algorithm</artifactId>
-      <version>3.0.0</version>
+      <version>3.0-SNAPSHOT</version>
     </dependency>
   </dependencies>
 </project>
```

example/src/main/scala/com/vesoft/nebula/algorithm/DegreeStaticExample.scala (+11 -3)

```diff
@@ -6,7 +6,8 @@
 package com.vesoft.nebula.algorithm
 
 import com.facebook.thrift.protocol.TCompactProtocol
-import com.vesoft.nebula.algorithm.lib.{DegreeStaticAlgo}
+import com.vesoft.nebula.algorithm.config.DegreeStaticConfig
+import com.vesoft.nebula.algorithm.lib.DegreeStaticAlgo
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{DataFrame, SparkSession}
 
@@ -22,15 +23,22 @@
       .config(sparkConf)
       .getOrCreate()
 
-    // val csvDF = ReadData.readCsvData(spark)
     // val nebulaDF = ReadData.readNebulaData(spark)
     val journalDF = ReadData.readLiveJournalData(spark)
-
     degree(spark, journalDF)
+
+    val csvDF = ReadData.readStringCsvData(spark)
+    degreeForStringId(spark, csvDF)
   }
 
   def degree(spark: SparkSession, df: DataFrame): Unit = {
     val degree = DegreeStaticAlgo.apply(spark, df)
     degree.show()
   }
+
+  def degreeForStringId(spark: SparkSession, df: DataFrame): Unit = {
+    val degreeConfig = new DegreeStaticConfig(true)
+    val degree = DegreeStaticAlgo.apply(spark, df, degreeConfig)
+    degree.show()
+  }
 }
```

nebula-algorithm/src/main/resources/application.conf (+7 -2)

```diff
@@ -113,7 +113,9 @@
   }
 
   # Vertex degree statistics parameter
-  degreestatic: {}
+  degreestatic: {
+    encodeId: false
+  }
 
   # KCore parameter
   kcore: {
@@ -123,7 +125,9 @@
   }
 
   # Trianglecount parameter
-  trianglecount: {}
+  trianglecount: {
+    encodeId: false
+  }
 
   # graphTriangleCount parameter
   graphtrianglecount: {}
@@ -189,6 +193,7 @@
   # JaccardAlgo parameter
   jaccard: {
     tol: 1.0
+    encodeId: false
   }
 }
 }
```
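
For the spark-submit path, the same switch lives in these `application.conf` blocks; enabling it for a string-id graph is a one-line change (sketch; the commit's defaults stay `false`):

```
degreestatic: {
  encodeId: true
}
```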

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/AlgoConfig.scala (+24 -15)

```diff
@@ -5,6 +5,7 @@
 
 package com.vesoft.nebula.algorithm.config
 
+import com.vesoft.nebula.algorithm.config.JaccardConfig.encodeId
 import org.apache.spark.graphx.VertexId
 
 case class PRConfig(maxIter: Int, resetProb: Double, encodeId: Boolean = false)
@@ -110,24 +111,30 @@
 /**
   * degree static
   */
-case class DegreeStaticConfig(degree: Boolean,
-                              inDegree: Boolean,
-                              outDegree: Boolean,
-                              encodeId: Boolean = false)
+case class DegreeStaticConfig(encodeId: Boolean = false)
 
 object DegreeStaticConfig {
-  var degree: Boolean    = false
-  var inDegree: Boolean  = false
-  var outDegree: Boolean = false
-  var encodeId: Boolean  = false
+  var encodeId: Boolean = false
 
   def getDegreeStaticConfig(configs: Configs): DegreeStaticConfig = {
     val degreeConfig = configs.algorithmConfig.map
-    degree = ConfigUtil.getOrElseBoolean(degreeConfig, "algorithm.degreestatic.degree", false)
-    inDegree = ConfigUtil.getOrElseBoolean(degreeConfig, "algorithm.degreestatic.indegree", false)
-    outDegree = ConfigUtil.getOrElseBoolean(degreeConfig, "algorithm.degreestatic.outdegree", false)
     encodeId = ConfigUtil.getOrElseBoolean(degreeConfig, "algorithm.degreestatic.encodeId", false)
-    DegreeStaticConfig(degree, inDegree, outDegree, encodeId)
+    DegreeStaticConfig(encodeId)
+  }
+}
+
+/**
+  * graph triangle count
+  */
+case class TriangleConfig(encodeId: Boolean = false)
+
+object TriangleConfig {
+  var encodeId: Boolean = false
+  def getTriangleConfig(configs: Configs): TriangleConfig = {
+    val triangleConfig = configs.algorithmConfig.map
+    encodeId =
+      ConfigUtil.getOrElseBoolean(triangleConfig, "algorithm.trianglecount.encodeId", false)
+    TriangleConfig(encodeId)
   }
 }
 
@@ -321,14 +328,16 @@
 /**
   * Jaccard
   */
-case class JaccardConfig(tol: Double)
+case class JaccardConfig(tol: Double, encodeId: Boolean = false)
 
 object JaccardConfig {
-  var tol: Double = _
+  var tol: Double       = _
+  var encodeId: Boolean = false
   def getJaccardConfig(configs: Configs): JaccardConfig = {
     val jaccardConfig = configs.algorithmConfig.map
     tol = jaccardConfig("algorithm.jaccard.tol").toDouble
-    JaccardConfig(tol)
+    encodeId = ConfigUtil.getOrElseBoolean(jaccardConfig, "algorithm.jaccard.encodeId", false)
+    JaccardConfig(tol, encodeId)
   }
 }
 
```

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/BfsAlgo.scala (+2 -2)

```diff
@@ -50,9 +50,9 @@
       .orderBy(col(AlgoConstants.BFS_RESULT_COL))
 
     if (bfsConfig.encodeId) {
-      DecodeUtil.convertAlgoId2StringId(algoResult, encodeIdDf)
+      DecodeUtil.convertAlgoId2StringId(algoResult, encodeIdDf).coalesce(1)
     } else {
-      algoResult
+      algoResult.coalesce(1)
     }
   }
 
```
nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/DegreeStaticAlgo.scala (+18 -5)

```diff
@@ -5,8 +5,8 @@
 
 package com.vesoft.nebula.algorithm.lib
 
-import com.vesoft.nebula.algorithm.config.AlgoConstants
-import com.vesoft.nebula.algorithm.utils.NebulaUtil
+import com.vesoft.nebula.algorithm.config.{AlgoConstants, DegreeStaticConfig}
+import com.vesoft.nebula.algorithm.utils.{DecodeUtil, NebulaUtil}
 import org.apache.log4j.Logger
 import org.apache.spark.graphx.{Graph, VertexRDD}
 import org.apache.spark.rdd.RDD
@@ -22,9 +22,18 @@
   /**
     * run the degree static algorithm for nebula graph
     */
-  def apply(spark: SparkSession, dataset: Dataset[Row]): DataFrame = {
+  def apply(spark: SparkSession,
+            dataset: Dataset[Row],
+            degreeConfig: DegreeStaticConfig = new DegreeStaticConfig): DataFrame = {
+    var encodeIdDf: DataFrame = null
 
-    val graph: Graph[None.type, Double] = NebulaUtil.loadInitGraph(dataset, false)
+    val graph: Graph[None.type, Double] = if (degreeConfig.encodeId) {
+      val (data, encodeId) = DecodeUtil.convertStringId2LongId(dataset, false)
+      encodeIdDf = encodeId
+      NebulaUtil.loadInitGraph(data, false)
+    } else {
+      NebulaUtil.loadInitGraph(dataset, false)
+    }
 
     val degreeResultRDD = execute(graph)
 
@@ -38,7 +47,11 @@
     val algoResult = spark.sqlContext
       .createDataFrame(degreeResultRDD, schema)
 
-    algoResult
+    if (degreeConfig.encodeId) {
+      DecodeUtil.convertAlgoId2StringId(algoResult, encodeIdDf)
+    } else {
+      algoResult
+    }
   }
 
   def execute(graph: Graph[None.type, Double]): RDD[Row] = {
```

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/JaccardAlgo.scala (+19 -2)

```diff
@@ -6,7 +6,9 @@
 package com.vesoft.nebula.algorithm.lib
 
 import com.vesoft.nebula.algorithm.config.JaccardConfig
+import com.vesoft.nebula.algorithm.utils.{DecodeUtil, NebulaUtil}
 import org.apache.log4j.Logger
+import org.apache.spark.graphx.Graph
 import org.apache.spark.ml.feature.{
   CountVectorizer,
   CountVectorizerModel,
@@ -29,7 +31,16 @@
     */
   def apply(spark: SparkSession, dataset: Dataset[Row], jaccardConfig: JaccardConfig): DataFrame = {
 
-    val jaccardResult: RDD[Row] = execute(spark, dataset, jaccardConfig.tol)
+    var encodeIdDf: DataFrame = null
+    var data: DataFrame       = dataset
+
+    if (jaccardConfig.encodeId) {
+      val (encodeData, encodeId) = DecodeUtil.convertStringId2LongId(dataset, false)
+      encodeIdDf = encodeId
+      data = encodeData
+    }
+
+    val jaccardResult: RDD[Row] = execute(spark, data, jaccardConfig.tol)
 
     val schema = StructType(
       List(
@@ -38,7 +49,13 @@
         StructField("similarity", DoubleType, nullable = true)
       ))
     val algoResult = spark.sqlContext.createDataFrame(jaccardResult, schema)
-    algoResult
+
+    if (jaccardConfig.encodeId) {
+      DecodeUtil.convertIds2String(algoResult, encodeIdDf, "srcId", "dstId")
+    } else {
+      algoResult
+    }
+
   }
 
   def execute(spark: SparkSession, dataset: Dataset[Row], tol: Double): RDD[Row] = {
```

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/TriangleCountAlgo.scala (+19 -5)

```diff
@@ -5,8 +5,8 @@
 
 package com.vesoft.nebula.algorithm.lib
 
-import com.vesoft.nebula.algorithm.config.AlgoConstants
-import com.vesoft.nebula.algorithm.utils.NebulaUtil
+import com.vesoft.nebula.algorithm.config.{AlgoConstants, TriangleConfig}
+import com.vesoft.nebula.algorithm.utils.{DecodeUtil, NebulaUtil}
 import org.apache.log4j.Logger
 import org.apache.spark.graphx.{Graph, VertexRDD}
 import org.apache.spark.graphx.lib.TriangleCount
@@ -24,9 +24,19 @@
     *
     * compute each vertex's triangle count
     */
-  def apply(spark: SparkSession, dataset: Dataset[Row]): DataFrame = {
+  def apply(spark: SparkSession,
+            dataset: Dataset[Row],
+            triangleConfig: TriangleConfig = new TriangleConfig): DataFrame = {
 
-    val graph: Graph[None.type, Double] = NebulaUtil.loadInitGraph(dataset, false)
+    var encodeIdDf: DataFrame = null
+
+    val graph: Graph[None.type, Double] = if (triangleConfig.encodeId) {
+      val (data, encodeId) = DecodeUtil.convertStringId2LongId(dataset, false)
+      encodeIdDf = encodeId
+      NebulaUtil.loadInitGraph(data, false)
+    } else {
+      NebulaUtil.loadInitGraph(dataset, false)
+    }
 
     val triangleResultRDD = execute(graph)
 
@@ -38,7 +48,11 @@
     val algoResult = spark.sqlContext
       .createDataFrame(triangleResultRDD, schema)
 
-    algoResult
+    if (triangleConfig.encodeId) {
+      DecodeUtil.convertAlgoId2StringId(algoResult, encodeIdDf)
+    } else {
+      algoResult
+    }
   }
 
   def execute(graph: Graph[None.type, Double]): RDD[Row] = {
```

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/utils/DecodeUtil.scala (+17)

```diff
@@ -73,4 +73,21 @@
       .drop(algoProp)
       .withColumnRenamed(ORIGIN_ID_COL, algoProp)
   }
+
+  def convertIds2String(dataframe: DataFrame,
+                        encodeId: DataFrame,
+                        srcCol: String,
+                        dstCol: String): DataFrame = {
+    encodeId
+      .join(dataframe)
+      .where(col(ENCODE_ID_COL) === col(srcCol))
+      .drop(ENCODE_ID_COL)
+      .drop(srcCol)
+      .withColumnRenamed(ORIGIN_ID_COL, srcCol)
+      .join(encodeId)
+      .where(col(dstCol) === col(ENCODE_ID_COL))
+      .drop(ENCODE_ID_COL)
+      .drop(dstCol)
+      .withColumnRenamed(ORIGIN_ID_COL, dstCol)
+  }
 }
```
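
The new helper completes the encode/decode round trip used by the algorithms in this commit; it can be sketched as follows (illustrative only: `edges` and `algoResult` are placeholder DataFrames, the `srcId`/`dstId` column names follow JaccardAlgo, and the second argument to `convertStringId2LongId` is passed as `false` exactly as in the calls above):

```scala
import com.vesoft.nebula.algorithm.utils.DecodeUtil
import org.apache.spark.sql.DataFrame

def roundTripSketch(edges: DataFrame, run: DataFrame => DataFrame): DataFrame = {
  // Encode: replace String vertex ids with generated Long ids, and keep
  // the (encodedId, originalId) mapping DataFrame for later decoding.
  val (encoded, idMapping) = DecodeUtil.convertStringId2LongId(edges, false)

  // Run any GraphX-based algorithm over the Long-id edge list.
  val algoResult: DataFrame = run(encoded)

  // Decode: a per-vertex result goes through convertAlgoId2StringId;
  // a result keyed by a vertex pair (e.g. Jaccard similarity) uses the
  // new convertIds2String, which joins the mapping once per id column.
  DecodeUtil.convertIds2String(algoResult, idMapping, "srcId", "dstId")
}
```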

nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/DegreeStaticAlgoSuite.scala (+7 -1)

```diff
@@ -5,13 +5,15 @@
 
 package com.vesoft.nebula.algorithm.lib
 
+import com.vesoft.nebula.algorithm.config.DegreeStaticConfig
 import org.apache.spark.sql.SparkSession
 import org.junit.Test
 
 class DegreeStaticAlgoSuite {
   @Test
   def degreeStaticAlgoSuite(): Unit = {
-    val spark = SparkSession.builder().master("local").getOrCreate()
+    val spark =
+      SparkSession.builder().master("local").config("spark.sql.shuffle.partitions", 5).getOrCreate()
     val data = spark.read.option("header", true).csv("src/test/resources/edge.csv")
     val result = DegreeStaticAlgo.apply(spark, data)
     assert(result.count() == 4)
@@ -20,5 +22,9 @@
       assert(row.get(2).toString.toInt == 4)
       assert(row.get(3).toString.toInt == 4)
     })
+
+    val config       = DegreeStaticConfig(true)
+    val encodeResult = DegreeStaticAlgo.apply(spark, data, config)
+    assert(encodeResult.count() == 4)
   }
 }
```

nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/HanpSuite.scala (+5 -1)

```diff
@@ -14,10 +14,14 @@
 class HanpSuite {
   @Test
   def hanpSuite() = {
-    val spark = SparkSession.builder().master("local").getOrCreate()
+    val spark =
+      SparkSession.builder().master("local").config("spark.sql.shuffle.partitions", 5).getOrCreate()
     val data = spark.read.option("header", true).csv("src/test/resources/edge.csv")
     val hanpConfig = new HanpConfig(0.1, 10, 1.0)
     val result = HanpAlgo.apply(spark, data, hanpConfig, false)
     assert(result.count() == 4)
+
+    val encodeHanpConfig = new HanpConfig(0.1, 10, 1.0, true)
+    assert(HanpAlgo.apply(spark, data, encodeHanpConfig, false).count() == 4)
   }
 }
```
