
Commit 00efbcd

support algorithm for string id for jar-tool approach (#67)

* add map string id function
* support algorithm for string id for jar-tool approach
* config the default encodeId as false
* remove the return df
* add test
1 parent 192ee24 commit 00efbcd

19 files changed (+483 -103 lines)
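
The core of the change is mapping string vertex ids to Long ids before an algorithm runs, and keeping the mapping so results can be translated back afterwards. A minimal sketch of that idea in plain Spark SQL, using the same dense_rank approach the PageRankExample below relies on (the column and variable names here are illustrative, not the jar tool's internal API):

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, dense_rank}

// Build an (id, encodedId) mapping over every distinct string id and
// rewrite the edge list so both endpoints carry the encoded Long id.
def encodeStringIds(edges: DataFrame): (DataFrame, DataFrame) = {
  val ids = edges
    .select(col("src").as("id"))
    .union(edges.select(col("dst").as("id")))
    .distinct()
  val mapping =
    ids.withColumn("encodedId", dense_rank().over(Window.orderBy("id")).cast("long"))

  val srcMap = mapping.withColumnRenamed("id", "src").withColumnRenamed("encodedId", "encodedSrc")
  val dstMap = mapping.withColumnRenamed("id", "dst").withColumnRenamed("encodedId", "encodedDst")
  val encodedEdges = edges
    .join(srcMap, Seq("src"))
    .join(dstMap, Seq("dst"))
    .select(col("encodedSrc"), col("encodedDst"))

  // The mapping is returned alongside the encoded edges so the caller can
  // join it back onto the algorithm's result later.
  (encodedEdges, mapping)
}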

example/pom.xml

+1-1
@@ -182,7 +182,7 @@
        <dependency>
            <groupId>com.vesoft</groupId>
            <artifactId>nebula-algorithm</artifactId>
-           <version>${project.version}</version>
+           <version>3.0.0</version>
        </dependency>
    </dependencies>
</project>
example/src/main/scala/com/vesoft/nebula/algorithm/DeepQueryTest.scala

+127
@@ -0,0 +1,127 @@
/* Copyright (c) 2022 vesoft inc. All rights reserved.
 *
 * This source code is licensed under Apache 2.0 License.
 */

package com.vesoft.nebula.algorithm

import com.vesoft.nebula.connector.connector.NebulaDataFrameReader
import com.facebook.thrift.protocol.TCompactProtocol
import com.vesoft.nebula.connector.{NebulaConnectionConfig, ReadNebulaConfig}
import org.apache.log4j.Logger
import org.apache.spark.SparkConf
import org.apache.spark.graphx.{Edge, EdgeDirection, EdgeTriplet, Graph, Pregel, VertexId}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Encoder, SparkSession}

import scala.collection.mutable

object DeepQueryTest {
  private val LOGGER = Logger.getLogger(this.getClass)

  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf()
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val spark = SparkSession
      .builder()
      .config(sparkConf)
      .getOrCreate()
    val iter = args(0).toInt
    val id   = args(1).toInt

    query(spark, iter, id)
  }

  def readNebulaData(spark: SparkSession): DataFrame = {

    val config =
      NebulaConnectionConfig
        .builder()
        .withMetaAddress("192.168.15.5:9559")
        .withTimeout(6000)
        .withConenctionRetry(2)
        .build()
    val nebulaReadEdgeConfig: ReadNebulaConfig = ReadNebulaConfig
      .builder()
      .withSpace("twitter")
      .withLabel("FOLLOW")
      .withNoColumn(true)
      .withLimit(20000)
      .withPartitionNum(120)
      .build()
    val df: DataFrame =
      spark.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF()
    df
  }

  def deepQuery(df: DataFrame,
                maxIterations: Int,
                startId: Int): Graph[mutable.HashSet[Int], Double] = {
    implicit val encoder: Encoder[Edge[Double]] = org.apache.spark.sql.Encoders.kryo[Edge[Double]]
    val edges: RDD[Edge[Double]] = df
      .map(row => {
        Edge(row.get(0).toString.toLong, row.get(1).toString.toLong, 1.0)
      })(encoder)
      .rdd

    val graph = Graph.fromEdges(edges, None)

    val queryGraph = graph.mapVertices { (vid, _) =>
      mutable.HashSet[Int](vid.toInt)
    }
    queryGraph.cache()
    queryGraph.numVertices
    queryGraph.numEdges
    df.unpersist()

    def sendMessage(edge: EdgeTriplet[mutable.HashSet[Int], Double])
      : Iterator[(VertexId, mutable.HashSet[Int])] = {
      val (smallSet, largeSet) = if (edge.srcAttr.size < edge.dstAttr.size) {
        (edge.srcAttr, edge.dstAttr)
      } else {
        (edge.dstAttr, edge.srcAttr)
      }

      if (smallSet.size == maxIterations) {
        Iterator.empty
      } else {
        val newNeighbors =
          (for (id <- smallSet; neighbor <- largeSet if neighbor != id) yield neighbor)
        Iterator((edge.dstId, newNeighbors))
      }
    }

    val initialMessage = mutable.HashSet[Int]()

    val pregelGraph = Pregel(queryGraph, initialMessage, maxIterations, EdgeDirection.Both)(
      vprog = (id, attr, msg) => attr ++ msg,
      sendMsg = sendMessage,
      mergeMsg = (a, b) => {
        val setResult = a ++ b
        setResult
      }
    )
    pregelGraph.cache()
    pregelGraph.numVertices
    pregelGraph.numEdges
    queryGraph.unpersist()
    pregelGraph
  }

  def query(spark: SparkSession, maxIter: Int, startId: Int): Unit = {
    val start = System.currentTimeMillis()
    val df    = readNebulaData(spark)
    df.cache()
    df.count()
    println(s"read data cost time ${(System.currentTimeMillis() - start)}")

    val startQuery = System.currentTimeMillis()
    val graph      = deepQuery(df, maxIter, startId)

    val endQuery = System.currentTimeMillis()
    val num      = graph.vertices.filter(row => row._2.contains(startId)).count()
    val end      = System.currentTimeMillis()
    println(s"query cost: ${endQuery - startQuery}")
    println(s"count: ${num}, cost: ${end - endQuery}")
  }
}
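
Taken together, DeepQueryTest reads the FOLLOW edge list from NebulaGraph, turns it into a GraphX graph whose vertex attribute is a set of vertex ids, and lets Pregel grow that set for at most maxIterations supersteps; the printed count is the number of vertices whose accumulated set contains startId. A minimal way to drive it from spark-shell, assuming the example jar is on the classpath (the iteration count and start id below are illustrative):

import com.vesoft.nebula.algorithm.DeepQueryTest

// run 3 supersteps and count the vertices that have collected id 100
DeepQueryTest.query(spark, 3, 100)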

example/src/main/scala/com/vesoft/nebula/algorithm/PageRankExample.scala

+5-3
@@ -6,11 +6,11 @@
package com.vesoft.nebula.algorithm

import com.facebook.thrift.protocol.TCompactProtocol
-import com.vesoft.nebula.algorithm.config.{CcConfig, PRConfig}
-import com.vesoft.nebula.algorithm.lib.{PageRankAlgo, StronglyConnectedComponentsAlgo}
+import com.vesoft.nebula.algorithm.config.PRConfig
+import com.vesoft.nebula.algorithm.lib.PageRankAlgo
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Window
-import org.apache.spark.sql.functions.{col, dense_rank}
+import org.apache.spark.sql.functions.{col, dense_rank, monotonically_increasing_id}
import org.apache.spark.sql.{DataFrame, SparkSession}

object PageRankExample {
@@ -69,6 +69,8 @@ object PageRankExample
    // encode id to Long type using dense_rank, the encodeId has two columns: id, encodedId
    // then you need to save the encodeId to convert back for the algorithm's result.
    val encodeId = idDF.withColumn("encodedId", dense_rank().over(Window.orderBy("id")))
+   // using function monotonically_increasing_id(), please refer https://spark.apache.org/docs/3.0.2/api/java/org/apache/spark/sql/functions.html#monotonically_increasing_id--
+   // val encodeId = idDF.withColumn("encodedId", monotonically_increasing_id())
    encodeId.write.option("header", true).csv("file:///tmp/encodeId.csv")
    encodeId.show()
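
After PageRank has run on the encoded ids, the saved mapping can be joined back to recover the original string ids. A minimal sketch, assuming the PageRank result exposes the vertex id as _id and the score as pagerank (both column names, and the prResult variable, are assumptions here):

// prResult: DataFrame returned by PageRankAlgo on the encoded edge list
// encodeId: the (id, encodedId) mapping written to /tmp/encodeId.csv above
val decoded = prResult
  .join(encodeId, prResult("_id") === encodeId("encodedId"))
  .select(col("id"), col("pagerank"))
decoded.show()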

nebula-algorithm/src/main/resources/application.conf

+11
@@ -83,23 +83,27 @@
  pagerank: {
      maxIter: 10
      resetProb: 0.15 # default 0.15
+     encodeId:false # if your data has string type id, please config encodeId as true.
  }

  # Louvain parameter
  louvain: {
      maxIter: 20
      internalIter: 10
      tol: 0.5
+     encodeId:false
  }

  # connected component parameter.
  connectedcomponent: {
      maxIter: 20
+     encodeId:false
  }

  # LabelPropagation parameter
  labelpropagation: {
      maxIter: 20
+     encodeId:false
  }

  # ShortestPaths parameter
@@ -115,6 +119,7 @@
  kcore:{
      maxIter:10
      degree:1
+     encodeId:false
  }

  # Trianglecount parameter
@@ -126,13 +131,15 @@
  # Betweenness centrality parameter. maxIter parameter means the max times of iterations.
  betweenness:{
      maxIter:5
+     encodeId:false
  }

  # Clustering Coefficient parameter. The type parameter has two choice, local or global
  # local type will compute the clustering coefficient for each vertex, and print the average coefficient for graph.
  # global type just compute the graph's clustering coefficient.
  clusteringcoefficient:{
      type: local
+     encodeId:false
  }

  # ClosenessAlgo parameter
@@ -142,19 +149,22 @@
  bfs:{
      maxIter:5
      root:"10"
+     encodeId:false
  }

  # DFS parameter
  dfs:{
      maxIter:5
      root:"10"
+     encodeId:false
  }

  # HanpAlgo parameter
  hanp:{
      hopAttenuation:0.1
      maxIter:10
      preference:1.0
+     encodeId:false
  }

  #Node2vecAlgo parameter
@@ -173,6 +183,7 @@
      degree: 30,
      embSeparate: ",",
      modelPath: "hdfs://127.0.0.1:9000/model"
+     encodeId:false
  }

  # JaccardAlgo parameter
