-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathserver.py
337 lines (268 loc) · 13.3 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
import argparse
import datetime
import random
import os
import subprocess
import sys
import shutil
import pathlib
import pytz
import time
import boto3
import botocore
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
from helper import *
from ssl import *
from socket import *
# ---------------------------------------------------------------------------
# Module-level configuration
# ---------------------------------------------------------------------------
# Shared S3 client used by every request handler (created at import time).
S3 = boto3.client('s3', region_name='us-east-2')
# Bucket read for the client's encrypted data files and encrypted symmetric keys.
DATA_BUCKET: str = "gatk-amd-genomics-test-data"
# Bucket that receives the files a client script leaves in its result directory.
RESULT_BUCKET: str = "gatk-amd-genomics-result"
# Per-connection scratch directory for client files; removed when a connection ends.
CLIENT_FS_BASE: str = os.path.expanduser("~/client")
# File names, inside the secrets directory, of the RSA keypair used to
# decrypt the per-file symmetric keys.
RSA_PRIVATE_FILE: str = "rsa_priv.pem"
RSA_PUBLIC_FILE: str = "rsa_pub.pem"

# Runtime switches, overwritten by main() from command-line flags.
# SECURE=False (--insecure) skips SEV-SNP attestation.
SECURE: bool = True
# TESTING=True (--testing) skips the IfModifiedSince check when fetching keys.
TESTING: bool = False
def send_self_cert(socket, self_cert_path):
    """Send the raw bytes of the server's self-signed certificate file.

    run_server calls this on the plain connection, before the TLS
    handshake, so the client can obtain the certificate up front.
    """
    cert_bytes = pathlib.Path(self_cert_path).read_bytes()
    send_message(socket, cert_bytes)
def get_rsa_filename(infix: str):
    """Build the PEM file name ``rsa_<infix>.pem`` for an RSA key file."""
    return "".join(("rsa_", infix, ".pem"))
def decrypt_symmetric_key(s3_sym_key_file, secrets_dir):
    """Fetch an RSA-encrypted symmetric key from S3 and return it decrypted.

    Steps:
    1. Download ``s3_sym_key_file`` from DATA_BUCKET to
       ``<secrets_dir>/encrypted.txt``.
    2. Decrypt it with ``openssl pkeyutl -decrypt``, using the server's RSA
       private key, into ``<secrets_dir>/decrypted.txt``.
    3. Return the first line of the plaintext, stripped of surrounding
       whitespace.

    Unless TESTING is set, the download is conditional on the object having
    been modified after the RSA public key file was last written. A key
    encrypted under an older public key is therefore rejected.

    Raises:
        Exception: if S3 reports the key object was not modified since the
            public key was written (stale key), or on any other S3 error.
            The original botocore error is chained as ``__cause__``.
        subprocess.CalledProcessError: if openssl fails to decrypt.
    """
    encrypted_path = os.path.join(secrets_dir, "encrypted.txt")
    decrypted_path = os.path.join(secrets_dir, "decrypted.txt")
    public_path = os.path.join(secrets_dir, RSA_PUBLIC_FILE)
    private_path = os.path.join(secrets_dir, RSA_PRIVATE_FILE)

    request = {"Bucket": DATA_BUCKET, "Key": s3_sym_key_file}
    if not TESTING:
        # Only accept key objects uploaded after the current public key was
        # written. An aware UTC datetime names the same instant as before.
        request["IfModifiedSince"] = datetime.datetime.fromtimestamp(
            os.path.getmtime(public_path), tz=datetime.timezone.utc
        )
    try:
        response = S3.get_object(**request)
    except botocore.exceptions.ClientError as e:
        code = e.response.get("Error", {}).get("Code")
        # A failed IfModifiedSince precondition is reported as HTTP 304
        # ("Not Modified"). InvalidObjectState is unrelated: it means the
        # object is archived.
        if code in ("304", "NotModified"):
            raise Exception("Symmetric key encrypted with previous version of RSA public key") from e
        raise Exception("Unexpected exception while fetching symmetric key") from e

    with open(encrypted_path, "wb") as f:
        f.write(response['Body'].read())
    # check=True: without it a failed decryption would silently leave a stale
    # decrypted.txt from an earlier request to be read below.
    subprocess.run(
        ["openssl", "pkeyutl", "-decrypt", "-inkey", private_path,
         "-in", encrypted_path, "-out", decrypted_path],
        check=True,
    )
    # NOTE(review): the plaintext key stays on disk in decrypted.txt;
    # consider removing it after reading.
    with open(decrypted_path, "r") as f:
        return f.readline().strip()
def generate_rsa_key():
    """Create a new 2048-bit RSA private key with public exponent 65537.

    The keypair is used to decrypt the symmetric keys clients upload. The
    caller derives the public half via ``.public_key()``.
    """
    return rsa.generate_private_key(key_size=2048, public_exponent=65537)
def generate_private_key(key_path):
    """Create the TLS private key at ``key_path`` with openssl, unless it exists.

    An existing file is left untouched.

    Raises:
        subprocess.CalledProcessError: if openssl exits non-zero. Without
            ``check=True`` the failure would only surface later, as a
            confusing certificate or TLS error.
    """
    if os.path.exists(key_path):
        return
    subprocess.run(["openssl", "genpkey", "-algorithm", "RSA", "-out", key_path], check=True)
def generate_self_signed_cert(key_path, cert_path, common_name, days=30):
    """Create a self-signed X.509 certificate at ``cert_path``, unless it exists.

    Args:
        key_path: private key used to sign the certificate.
        cert_path: output PEM path; an existing file is left untouched.
        common_name: value placed in the subject as ``/CN=<common_name>``.
        days: validity period. 30 matches openssl's default for
            ``req -x509``, so existing callers get the same certificate.

    Raises:
        subprocess.CalledProcessError: if openssl exits non-zero.
    """
    if os.path.exists(cert_path):
        return
    subprocess.run(
        ["openssl", "req", "-new", "-x509", "-key", key_path, "-out", cert_path,
         "-subj", "/CN=" + common_name, "-days", str(days)],
        check=True,
    )
def generate_certificates(snpguest: str):
    """Write the attestation certificates (PEM) into ``./certs`` using snpguest.

    Runs ``sudo <snpguest> certificates PEM ./certs``. The directory is
    relative to the current working directory and is created if missing.

    Returns:
        The certificate directory path (``"./certs"``).

    Raises:
        Exception: if the snpguest command fails. The message includes its
            stderr, and the CalledProcessError is chained.
    """
    cert_dir = "./certs"
    os.makedirs(cert_dir, exist_ok=True)
    try:
        # Argument list instead of a shell string: the executable path comes
        # from the command line and must not be re-parsed by a shell.
        subprocess.run(["sudo", snpguest, "certificates", "PEM", cert_dir],
                       check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        stderr = e.stderr.decode(errors="replace").strip() if e.stderr else ""
        detail = f": {stderr}" if stderr else ""
        raise Exception(f"Failed to generate certificates: {e}{detail}") from e
    return cert_dir
def generate_attestation_report(snpguest: str):
    """Produce an SEV-SNP attestation report file using snpguest.

    Runs ``sudo <snpguest> report report.bin request-file.txt --random`` in
    the current working directory.

    Returns:
        The report file name (``"report.bin"``, relative to the cwd).

    Raises:
        Exception: if the snpguest command fails. The message includes its
            stderr, and the CalledProcessError is chained.
    """
    report_file = "report.bin"
    try:
        # Argument list instead of a shell string (see generate_certificates).
        subprocess.run(["sudo", snpguest, "report", report_file, "request-file.txt", "--random"],
                       check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        stderr = e.stderr.decode(errors="replace").strip() if e.stderr else ""
        detail = f": {stderr}" if stderr else ""
        raise Exception(f"Failed to generate attestation report: {e}{detail}") from e
    return report_file
def _send_attestation(client_ssock, snpguest):
    """Send the attestation report, then each certificate file, then a terminator.

    Wire order: the report bytes; for each file in the cert directory, its
    name followed by its contents; finally a ``"\\r\\n"`` message.
    """
    start_time = time.time()
    report_file = generate_attestation_report(snpguest)
    with open(report_file, "rb") as f:
        send_message(client_ssock, f.read())
    cert_dir = generate_certificates(snpguest)
    for cert_file in os.listdir(cert_dir):
        with open(os.path.join(cert_dir, cert_file), "rb") as f:
            cert_content = f.read()
        send_message(client_ssock, cert_file.encode())
        send_message(client_ssock, cert_content)
    send_message(client_ssock, "\r\n".encode())
    print(f"Time to send attestation report and certificates: {time.time() - start_time} seconds")


def _fetch_and_decrypt_data(list_path, secrets_dir):
    """Download and gpg-decrypt every DATA_BUCKET key listed in ``list_path``.

    ``list_path`` holds one S3 key per line; blank lines are skipped. Each
    object is saved under its key name (relative to the cwd). It is then
    decrypted to the same name minus its last four characters (presumably a
    ``.gpg`` suffix; TODO confirm), using the symmetric key referenced by
    the object's ``symmetric-key`` metadata entry.
    """
    start_time = time.time()
    with open(list_path, "r") as f:
        data_files = [line.strip() for line in f if line.strip()]
    for data_file in data_files:
        response = S3.get_object(Bucket=DATA_BUCKET, Key=data_file)
        with open(data_file, "wb") as f:
            f.write(response['Body'].read())
        symmetric_key = decrypt_symmetric_key(response['Metadata']['symmetric-key'], secrets_dir)
        # Argument list, so the key and file names are never re-parsed by a
        # shell. NOTE(review): the passphrase is still visible in the process
        # table via argv; consider gpg's --passphrase-fd.
        subprocess.run(
            ["gpg", "--batch", "--output", data_file[:-4],
             "--passphrase", symmetric_key, "--decrypt", data_file],
            check=True,
        )
    print(f"Finished reading and decrypting data files in {list_path}")
    print(f"Time to fetch and decrypt data: {time.time() - start_time} seconds")


def _new_result_prefix():
    """Pick a random ``result-<n>`` prefix with no objects under it in RESULT_BUCKET."""
    while True:
        s3_dir = "result-" + str(random.randint(0, sys.maxsize * 2 + 1))
        # Uploads are stored as "<prefix>/<file>", so any key under
        # "<prefix>/" means the prefix is already taken.
        listing = S3.list_objects(Bucket=RESULT_BUCKET, Prefix=s3_dir + "/", MaxKeys=1)
        if "Contents" not in listing:
            return s3_dir


def _run_script_and_upload(script_path, result_dir):
    """Run the client's bash script, then upload the files in ``result_dir`` to S3.

    Only regular files directly inside ``result_dir`` are uploaded.

    Returns:
        The new RESULT_BUCKET prefix the files were stored under.
    """
    create_dirs([result_dir])
    print(f"Running client script: {script_path}")
    start_time = time.time()
    # Mark the script executable (as the original `chmod +x` did) and run it
    # with no arguments.
    os.chmod(script_path, os.stat(script_path).st_mode | 0o111)
    subprocess.run(["bash", script_path], check=True, capture_output=True)
    print(f"Finished running script {script_path}")
    print(f"Time to run client script: {time.time() - start_time} seconds")

    start_time = time.time()
    s3_dir = _new_result_prefix()
    for filename in os.listdir(result_dir):
        path = os.path.join(result_dir, filename)
        if os.path.isfile(path):
            with open(path, "rb") as f:
                # S3 keys always use "/", regardless of the host path separator.
                S3.put_object(Body=f.read(), Bucket=RESULT_BUCKET, Key=f"{s3_dir}/{filename}")
    print(f"Time to upload results to S3: {time.time() - start_time} seconds")
    return s3_dir


def handle_client_connection(client_ssock, snpguest, secrets_dir):
    """Serve one TLS client: attest (if SECURE), then process its commands.

    Commands are whitespace-separated text messages:
      * ``RSA <x>``: reply with the RSA public key PEM.
      * ``DATA <name>``: receive a file listing S3 keys, then fetch and
        decrypt each key.
      * ``SCRIPT <name> <dir>``: receive a bash script, run it, upload
        ``<dir>`` to S3 and reply with the S3 prefix.

    Any unknown or too-short message ends the session. Exceptions are
    printed, not raised. Client files live in CLIENT_FS_BASE, which is
    always deleted afterwards. The caller's working directory is restored,
    so relative paths (snpguest, ./certs, report.bin) resolve the same way
    on the next connection.
    """
    original_cwd = os.getcwd()
    create_dirs([CLIENT_FS_BASE])
    try:
        if SECURE:
            # AMD SEV-SNP attestation
            _send_attestation(client_ssock, snpguest)
        os.chdir(CLIENT_FS_BASE)
        while True:
            cmd = receive_message(client_ssock).decode().split()
            if len(cmd) < 2 or cmd[0] not in ("DATA", "SCRIPT", "RSA"):
                break
            if cmd[0] == "SCRIPT" and len(cmd) < 3:
                print("SCRIPT command is missing the result directory argument")
                break
            file_path = ''
            if cmd[0] in ("DATA", "SCRIPT"):
                # NOTE(review): cmd[1] is client-controlled and not confined to
                # CLIENT_FS_BASE. SCRIPT already runs arbitrary client code, so
                # confining only this path would not add isolation.
                file_path = os.path.join(CLIENT_FS_BASE, cmd[1])
                file_contents = receive_message(client_ssock)
                with open(file_path, "wb") as f:
                    f.write(file_contents)
            if cmd[0] == "RSA":
                # share rsa public key
                with open(os.path.join(secrets_dir, RSA_PUBLIC_FILE), 'rb') as f:
                    send_message(client_ssock, f.read())
            elif cmd[0] == "DATA":
                _fetch_and_decrypt_data(file_path, secrets_dir)
            else:  # SCRIPT
                s3_dir = _run_script_and_upload(file_path, cmd[2])
                send_message(client_ssock, s3_dir.encode())
                print(f"Uploaded results to {s3_dir}")
    except Exception as e:
        print(e)
    finally:
        # Runs even on KeyboardInterrupt: leave the per-client directory
        # before deleting it, and restore the caller's working directory.
        os.chdir(original_cwd)
        shutil.rmtree(pathlib.Path(CLIENT_FS_BASE))
def run_server(snpguest: str, key_path: str, self_cert_path: str, secrets_dir: str, common_name: str, port: int = 8080):
    """Accept TLS clients on ``port`` forever, serving one connection at a time.

    For each connection:
    1. Send the self-signed certificate over the plain socket.
    2. Wrap the socket in TLS using ``self_cert_path``/``key_path``.
    3. Hand it to handle_client_connection.

    Per-connection errors are printed and the loop continues.
    KeyboardInterrupt propagates. Both the per-connection and the listening
    sockets are closed on every exit path.

    Args:
        port: TCP port to listen on (default 8080). 0 picks an ephemeral
            port; the chosen port is printed as SERVER_PORT.
    """
    with socket(AF_INET, SOCK_STREAM) as server_sock:
        # Allow an immediate restart without waiting out TIME_WAIT.
        server_sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
        server_sock.bind(('', port))
        server_sock.listen(10)
        print(f"SERVER_HOST={common_name}")
        print(f"SERVER_PORT={server_sock.getsockname()[1]}")
        context = SSLContext(PROTOCOL_TLS_SERVER)
        context.load_cert_chain(self_cert_path, key_path)
        while True:
            connection = None
            try:
                connection, _ = server_sock.accept()
                # send self-signed certificate before the TLS handshake
                send_self_cert(connection, self_cert_path)
                # wrap_socket takes over the fd; the `with` closes the TLS socket.
                with context.wrap_socket(connection, server_side=True) as client_ssock:
                    handle_client_connection(client_ssock, snpguest, secrets_dir)
            except KeyboardInterrupt:
                raise
            except Exception as e:
                print(e)
            finally:
                # Covers failures before wrap_socket; harmless once detached.
                if connection is not None:
                    connection.close()
def _ensure_snpguest(snpguest):
    """Return an absolute path to the snpguest executable, building it if needed.

    If ``snpguest`` is falsy, clone https://github.com/virtee/snpguest into
    ./snpguest and run ``cargo build -r``; each step is skipped if already
    done. If a path is given, it must be an existing file. Exits the process
    with status 1 on failure.
    """
    if not snpguest:
        try:
            # fetch and build snpguest from source
            if not os.path.isdir("./snpguest"):
                subprocess.run('git clone https://github.com/virtee/snpguest.git', shell=True, capture_output=True, check=True)
            if not os.path.isfile("./snpguest/target/release/snpguest"):
                subprocess.run('cargo build -r', shell=True, capture_output=True, check=True, cwd="./snpguest/")
        except subprocess.CalledProcessError as e:
            print(f"Failed to fetch and build snpguest from source: {e}")
            sys.exit(1)
        snpguest = "./snpguest/target/release/snpguest"
    elif not os.path.isfile(snpguest):
        print(f"Cannot find file {snpguest}.")
        sys.exit(1)
    # Absolute, so later working-directory changes cannot break the path.
    return os.path.abspath(snpguest)


def _ensure_rsa_keypair(secrets_dir):
    """Create the RSA keypair in ``secrets_dir`` unless both key files exist.

    The private key is written as unencrypted TraditionalOpenSSL PEM,
    created with mode 0o600. The public key is written as
    SubjectPublicKeyInfo PEM. It is written last because its mtime is the
    freshness reference in decrypt_symmetric_key.
    """
    private_path = os.path.join(secrets_dir, RSA_PRIVATE_FILE)
    public_path = os.path.join(secrets_dir, RSA_PUBLIC_FILE)
    if os.path.isfile(private_path) and os.path.isfile(public_path):
        return
    private_key = generate_rsa_key()
    private_pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption()
    )
    public_pem = private_key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo
    )
    # Owner-only permissions on creation: this key decrypts every client's
    # symmetric keys. (The mode does not change an already-existing file.)
    fd = os.open(private_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'wb') as f:
        f.write(private_pem)
    with open(public_path, 'wb') as f:
        f.write(public_pem)


def main():
    """Parse command-line flags, prepare keys and certificates, then run the server.

    Any unexpected exception is printed and the process exits with status 1.
    """
    try:
        parser = argparse.ArgumentParser()
        parser.add_argument('-sg', '--snpguest', default=None, help="Location of the snpguest utility executable (default: fetches and builds snpguest from source)")
        parser.add_argument('-s', '--secrets_dir', default="~/secrets", help="Directory holding the TLS and RSA key material (default: ~/secrets)")
        parser.add_argument('-kf', '--key_file', default="server.key", help="Private key file (default: server.key)")
        parser.add_argument('-cf', '--cert_file', default="server.pem", help="Self-signed certificate file (default: server.pem)")
        parser.add_argument('-cn', '--common_name', default=gethostname(), help=f"Common name for generating self-signed certificate (default: {gethostname()})")
        parser.add_argument('-is', '--insecure', action='store_true', help="Flag for running server outside of a trusted execution environment")
        parser.add_argument('-t', '--testing', action='store_true', help="Flag for running server for testing")
        args = parser.parse_args()

        global TESTING, SECURE
        TESTING = args.testing
        if args.insecure:
            SECURE = False

        # Absolute: the client handler chdir()s, which would break a relative path.
        secrets_dir = os.path.abspath(os.path.expanduser(args.secrets_dir))
        # Owner-only when newly created: the directory holds private keys.
        os.makedirs(secrets_dir, mode=0o700, exist_ok=True)

        # generate private key and certificate for ssl
        key_path = os.path.join(secrets_dir, args.key_file)
        cert_path = os.path.join(secrets_dir, args.cert_file)
        generate_private_key(key_path)
        generate_self_signed_cert(key_path, cert_path, args.common_name)

        # snpguest is only used for attestation, which is skipped in insecure mode.
        if SECURE:
            args.snpguest = _ensure_snpguest(args.snpguest)

        _ensure_rsa_keypair(secrets_dir)
        run_server(args.snpguest, key_path, cert_path, secrets_dir, args.common_name)
    except Exception as e:
        print(f"Unexpected error occurred: {e}")
        sys.exit(1)
if __name__ == "__main__":
    # Script entry point (e.g. `python server.py --insecure`); nothing runs on import.
    main()