-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathserver.py
83 lines (73 loc) · 3.03 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import json
import os
import time
import traceback
import grpc
from concurrent import futures
from threading import Thread
import raftos
from state_machine import NodeStateMachine # Import the corrected state machine
from rag import RAG
from utils import calculate_similarity, get_other_nodes
import service_pb2
import service_pb2_grpc
class QueryService(service_pb2_grpc.QueryServiceServicer):
    """gRPC servicer that routes ``Query`` RPCs through the local Raft node.

    When the local node reports itself as leader, the query is wrapped in a
    JSON command and handed to the node's ``apply_log``. Otherwise the
    request is relayed to the address returned by ``get_leader_address``.
    """

    def __init__(self, raft_node):
        """Keep a handle on the Raft node and create a ``RAG`` instance.

        Args:
            raft_node: Object providing ``is_leader()``, ``apply_log()`` and
                ``get_leader_address()``, which ``Query`` calls.
        """
        self.raft_node = raft_node
        self.rag = RAG()

    def Query(self, request, context):
        """Handle one ``Query`` RPC.

        Args:
            request: Incoming message; only its ``query`` field is read.
            context: gRPC servicer context (not used by this handler).

        Returns:
            A ``service_pb2.QueryResponse``. On the leader its ``response``
            is whatever ``apply_log`` returned; on a follower it is the
            leader's reply, or ``"Leader not known"`` when no leader address
            is available. Any exception is reported as ``"Error: <message>"``
            instead of propagating.
        """
        print(f"Received query: {request.query}")
        try:
            if not self.raft_node.is_leader():
                # Follower path: relay the untouched request to the leader.
                target = self.raft_node.get_leader_address()
                if not target:
                    return service_pb2.QueryResponse(response="Leader not known")
                print(f"Forwarding query to leader at {target}")
                with grpc.insecure_channel(target) as channel:
                    return service_pb2_grpc.QueryServiceStub(channel).Query(request)

            # Leader path: the command travels through Raft as a JSON string.
            serialized = json.dumps(
                {"type": "query", "data": {"query": request.query}}
            )
            outcome = self.raft_node.apply_log(serialized, True)
            print(f"Result from Raft: {outcome}")
            return service_pb2.QueryResponse(response=outcome)
        except Exception as exc:
            # RPC boundary: log the failure and answer with an error text
            # rather than letting the exception escape the handler.
            print(f"Error during Query: {exc}")
            traceback.print_exc()
            return service_pb2.QueryResponse(response=f"Error: {exc}")
def serve():
    """Start the Raft node and the gRPC query server, then block.

    Reads ``RAFT_ID`` and ``RAFT_PORT`` from the environment, builds the
    state machine and Raft node, runs the node's ``run_forever`` on a daemon
    thread, waits five seconds, and then serves ``QueryService`` on port
    50051 until the gRPC server terminates.

    Raises:
        TypeError: If ``RAFT_PORT`` is not set (``int(None)``).
        ValueError: If ``RAFT_PORT`` is not a valid integer string.
    """
    node_id = os.getenv("RAFT_ID")
    raft_port = int(os.getenv("RAFT_PORT"))
    peers = get_other_nodes(node_id)

    # Build the replicated state machine and the Raft node that drives it.
    machine = NodeStateMachine(node_id)
    raft_node = raftos.RaftNode(node_id, raft_port, machine, peers)
    raftos.configure_logging(node_id, '/app/log')

    # Daemon thread, so it does not keep the process alive on its own.
    Thread(target=raft_node.run_forever, daemon=True).start()

    # Fixed pause before accepting queries; adjust this delay as needed.
    time.sleep(5)

    # gRPC front end backed by a 10-worker thread pool.
    pool = futures.ThreadPoolExecutor(max_workers=10)
    server = grpc.server(pool)
    service_pb2_grpc.add_QueryServiceServicer_to_server(
        QueryService(raft_node), server
    )
    server.add_insecure_port("[::]:50051")
    server.start()
    print(f"gRPC server started on port 50051 for {node_id}")
    server.wait_for_termination()
# Script entry point: start the node only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    serve()