-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
137 lines (113 loc) · 4.72 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import api_keys
import config
from config import MAX_DISCUSSION_ROUNDS
import atexit
import random
from models import (
gpt4o_chat,
gemini_chat,
grok_chat,
deepseek_chat,
claude_chat,
summarize_discussion
)
def get_result_values(result):
    """
    Extract 'model', 'contribution' and 'vote' from a model result.

    Supports both dict-style results and object-style results (e.g.
    Pydantic models) that expose ``model``/``contribution``/``vote``
    attributes.

    Args:
        result: A dict or attribute-bearing object returned by a model
            chat function.

    Returns:
        Tuple of (model, contribution, vote). Missing fields default to
        the result's class name, "" and False respectively.
    """
    # BUG FIX: the old try/except fallback called result.get() again,
    # which re-raised AttributeError for non-dict results — exactly the
    # case the fallback existed for. Branch explicitly on the type instead.
    if isinstance(result, dict):
        model = result.get("model", result.__class__.__name__)
        contribution = result.get("contribution", "")
        vote = result.get("vote", False)
    else:
        model = getattr(result, "model", result.__class__.__name__)
        contribution = getattr(result, "contribution", "")
        vote = getattr(result, "vote", False)
    return model, contribution, vote
def cleanup():
    """Best-effort teardown of gRPC async resources at program exit.

    Silently does nothing when grpc is not installed or the shutdown
    hook is absent in the installed grpc version.
    """
    try:
        import grpc
    except ImportError:
        # grpc not installed — nothing to clean up.
        return
    try:
        grpc.experimental.aio.shutdown_asyncio_engine()
    except AttributeError:
        # Installed grpc version does not expose this hook; ignore.
        pass
# Make sure lingering gRPC connections are torn down on interpreter exit.
atexit.register(cleanup)
def randomize_model_order(model_functions):
    """
    Return the model chat functions in a fresh random order.

    Args:
        model_functions: List of model chat functions.

    Returns:
        A new list containing the same functions, randomly ordered;
        the input list is left unmodified.
    """
    # random.sample with k == len(...) yields a shuffled copy in one call.
    return random.sample(model_functions, k=len(model_functions))
def main():
    """
    Run an interactive multi-model round-table discussion.

    Prompts the user for a topic, then runs up to
    config.MAX_DISCUSSION_ROUNDS rounds. In each round every model
    (queried in a freshly randomized order to avoid positional bias)
    contributes to the discussion and votes on whether to continue.
    A strict majority of True votes is needed for another round.
    Finally the full transcript and a summary are printed.
    """
    # Accumulates {"model": ..., "content": ...} entries across all rounds;
    # passed back to every model as shared context.
    discussion_context = []
    topic = input("Enter the discussion topic/message/question: ")
    current_round = 0
    continue_discussion = True
    while continue_discussion and current_round < config.MAX_DISCUSSION_ROUNDS:
        print(f"\n--- Discussion Round {current_round + 1} ---")
        round_votes = []
        # New speaking order each round so no model always goes first.
        current_round_models = randomize_model_order([
            gpt4o_chat,
            gemini_chat,
            grok_chat,
            deepseek_chat,
            claude_chat
        ])
        for model_fn in current_round_models:
            # Each model sees the topic plus the full running context.
            result = model_fn(topic, context_messages=discussion_context)
            if isinstance(result, dict):
                model_name = result.get("model", "Unknown Model")
                contribution = result.get("contribution", "No contribution provided")
                vote = result.get("vote", False)
            else:
                # Fallback for object-style results (e.g. Pydantic models);
                # derive a display name from the function name if missing.
                model_name = getattr(result, "model", model_fn.__name__.replace("_chat", "").capitalize())
                contribution = getattr(result, "contribution", "No contribution provided")
                vote = getattr(result, "vote", False)
            print(f"\n{model_name} contributed:\n{contribution}")
            print(f"Vote for further discussion: {vote}")
            discussion_context.append({"model": model_name, "content": contribution})
            round_votes.append(vote)
        # A True vote means the model wants another round.
        true_votes = sum(1 for vote in round_votes if vote)
        false_votes = len(round_votes) - true_votes
        print(f"\nRound {current_round + 1} votes -> Further discussion: {true_votes}, Stop: {false_votes}")
        # Continue only on a strict majority of True votes.
        if true_votes > false_votes:
            print("Proceeding to the next discussion round...\n")
            continue_discussion = True
        else:
            print("Ending discussion based on votes.\n")
            continue_discussion = False
        current_round += 1
    # Discussion over: dump the full transcript, then summarize it.
    print("\n--- Final Discussion Transcript ---")
    for idx, entry in enumerate(discussion_context, 1):
        print(f"{idx}. {entry}")
    final_summary = summarize_discussion(discussion_context)
    print("\n--- Final Summary ---")
    print(final_summary)
if __name__ == "__main__":
    main()