def create_retrieve_snap(self):
    """Build a BranchMessage that asks a branch to return a snapshot.

    The snapshot id placed in the request is read from the module-level
    ``init_counter``; it is only read here, never reassigned.

    Returns:
        A ``bank_pb2.BranchMessage`` whose ``retrieve_snapshot`` field holds
        a ``RetrieveSnapshot`` with ``snapshot_id == init_counter``.
    """
    request = bank_pb2.RetrieveSnapshot(snapshot_id=init_counter)
    envelope = bank_pb2.BranchMessage()
    envelope.retrieve_snapshot.CopyFrom(request)
    return envelope
def _recv_exact(sock, size):
    """Read exactly *size* bytes from *sock*, looping over short reads.

    ``socket.recv(n)`` may legally return fewer than ``n`` bytes, so a single
    call is not enough to read a fixed-size frame.

    Raises:
        IOError: if the peer closes the connection before *size* bytes arrive.
    """
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            raise IOError(
                'connection closed with {} of {} bytes still unread'.format(
                    remaining, size))
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)


def get_snapshot(snapshot_id, connections, money):
    """Retrieve one snapshot from every branch and check the global total.

    A ``RetrieveSnapshot`` request is sent to each connection via
    ``message_socket``. Each reply is read as a 2-byte native-order unsigned
    short length (``struct`` format ``'H'``) followed by that many bytes of a
    serialized ``BranchMessage``. For every branch one line is printed with
    its balance and the recorded incoming channel amounts, followed by the
    grand total.

    Args:
        snapshot_id: Identifier of the snapshot to retrieve.
        connections: Sequence of entries where ``entry[0]`` is a connected
            socket and ``entry[1]`` is the branch name. ``channel_state[i]``
            of a reply is reported as the channel from ``connections[i][1]``.
        money: Expected sum of all balances plus all in-flight amounts.

    Raises:
        AssertionError: if a reply is not a ``return_snapshot`` message or
            the computed total differs from *money*.
        IOError: if a branch closes its socket in the middle of a reply.
    """
    retrieve = bank_pb2.RetrieveSnapshot()
    retrieve.snapshot_id = snapshot_id
    message = bank_pb2.BranchMessage()
    message.retrieve_snapshot.MergeFrom(retrieve)

    total = 0
    print('snapshot_id: {}'.format(snapshot_id))
    for conn in connections:
        branch_sock = conn[0]
        branch_name = conn[1]

        # Message server
        message_socket(branch_sock, message)

        # Read returned snapshot: length header first, then the whole body.
        header = _recv_exact(branch_sock, 2)
        data_size = struct.unpack('H', header)[0]
        data = _recv_exact(branch_sock, data_size)

        rec_message = bank_pb2.BranchMessage()
        rec_message.ParseFromString(data)
        reply_kind = rec_message.WhichOneof('branch_message')
        assert reply_kind == 'return_snapshot', \
            'expected return_snapshot from {}, got {}'.format(
                branch_name, reply_kind)
        local_snapshot = rec_message.return_snapshot.local_snapshot

        parts = ['{}: {}, '.format(branch_name, local_snapshot.balance)]
        total += local_snapshot.balance

        # The branch's own slot in channel_state is not a real channel.
        own_index = get_branch_index(branch_name)
        for i, amt in enumerate(local_snapshot.channel_state):
            if i == own_index:
                continue
            total += amt
            parts.append(
                '{}->{}: {}, '.format(connections[i][1], branch_name, amt))
        print(''.join(parts))

    print('Total: {}'.format(total))
    print('')
    assert total == money, \
        'snapshot {} total {} != expected {}'.format(snapshot_id, total, money)
# Fragment: start a snapshot at one randomly chosen branch, wait, then ask
# every branch in BRANCH_LIST for its recorded snapshot.
# NOTE(review): the enclosing definition and the rest of the `if` body below
# are outside this view.

# Choose the branch that will initiate the snapshot.
random_entry = random.choice(BRANCH_LIST)
print "Initiated snapshot message to %s\n" % str(
    random_entry["name"])
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect((random_entry['ip'], random_entry['port']))
initSnap = bank_pb2.InitSnapshot()
# The new snapshot uses the id following the current `snap_id`.
initSnap.snapshot_id = snap_id + 1
msg = bank_pb2.BranchMessage()
msg.init_snapshot.CopyFrom(initSnap)
# The protobuf message object is sent pickled, not via SerializeToString().
s.sendall(pickle.dumps(msg))
s.close()
# Fixed 4-second pause before retrieval -- presumably to let the snapshot
# complete; nothing here verifies that it has.
time.sleep(4)
retrieveSnap = bank_pb2.RetrieveSnapshot()
retrieveSnap.snapshot_id = snap_id + 1
msg1 = bank_pb2.BranchMessage()
msg1.retrieve_snapshot.CopyFrom(retrieveSnap)
for b in BRANCH_LIST:
    # One fresh connection per branch for the retrieve request.
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.connect((b['ip'], b['port']))
    s.sendall(pickle.dumps(msg1))
    # NOTE(review): a single recv(1024) can return a partial reply, and
    # pickle.loads on bytes from the network can execute arbitrary code if
    # the peer is not trusted.
    data = pickle.loads(s.recv(1024))
    if data.WhichOneof('branch_message') == 'return_snapshot':
        # Scratch containers; their use is beyond the visible fragment.
        d = {}
        f = {}
        l = []
        to_branch = b['name']
# Fragment: send InitSnapshot to `randomBranch`, wait, then retrieve the
# snapshot from every branch and format the channel states.
# NOTE(review): `InitSnapshotMsg`, `randomBranch`, `branhList`, `branches`
# and `snapshot_num` are defined outside this view.
InitSnapshotMsg.snapshot_id = int(snapshot_num)
branchsocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
branchsocket.connect((randomBranch['ip'], randomBranch['port']))
branchMessage = bank_pb2.BranchMessage()
branchMessage.init_snapshot.CopyFrom(InitSnapshotMsg)
# The protobuf message object is sent pickled, not via SerializeToString().
branchsocket.sendall(pickle.dumps(branchMessage))
branchsocket.close()
# Fixed 10-second pause before retrieval -- presumably to let the snapshot
# complete; nothing here verifies that it has.
time.sleep(10)
totalCurrentBalance = 0
for branchobj in branhList:
    # Copy the current branch's fields onto the shared `branches` object.
    branches.name = branchobj['name']
    branches.ip = branchobj['ip']
    branches.port = branchobj['port']
    retriveSnapshotMessage = bank_pb2.RetrieveSnapshot()
    retriveSnapshotMessage.snapshot_id = int(snapshot_num)
    branchMessage1 = bank_pb2.BranchMessage()
    branchMessage1.retrieve_snapshot.CopyFrom(retriveSnapshotMessage)
    branchsocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    branchsocket.connect((branchobj['ip'], branchobj['port']))
    branchsocket.sendall(pickle.dumps(branchMessage1))
    # NOTE(review): a single recv(1024) can return a partial reply, and
    # pickle.loads on bytes from the network can execute arbitrary code if
    # the peer is not trusted.
    data = pickle.loads(branchsocket.recv(1024))
    branchsocket.close()
    channel_state = ""
    # Every branch except the one being queried, in branhList order; these
    # are paired positionally with the reply's channel_state entries.
    newList = []
    for item in branhList:
        if item['name'] != branchobj['name']:
            newList.append(item)
    for branch , channelState in zip(newList, data.return_snapshot.local_snapshot.channel_state) :
        # Append "<sender> -> <this branch> : <amount> " for each channel.
        channel_state = channel_state + str(branch['name']) + " -> " + str(branchobj['name']) + " : " + str(channelState) + " "
def _send_to_branch(address, payload, expect_reply=False):
    """Open a TCP connection to *address*, send *payload*, optionally read a reply.

    The socket is always closed, including when connect/send/recv raises.

    Returns:
        Up to 1024 bytes read from the branch when *expect_reply* is true,
        otherwise ``None``.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.connect(address)
        # sendall() keeps writing until the whole payload is out; send() may
        # stop after a partial write.
        s.sendall(payload)
        if expect_reply:
            # NOTE(review): the wire format has no length framing, so a reply
            # larger than 1024 bytes or split across segments is truncated.
            return s.recv(1024)
        return None
    finally:
        s.close()


def main():
    """Controller entry point: initialise branches, then snapshot forever.

    Command line:
        sys.argv[1]: total amount of money, split evenly (integer division)
            across all branches.
        sys.argv[2]: path to a text file with one ``name ip port`` entry per
            line, separated by whitespace. Blank lines are ignored.

    After sending ``InitBranch`` to every branch, the function loops forever:
    it sends ``InitSnapshot`` to a random branch, waits, sends
    ``RetrieveSnapshot`` to every branch and prints each returned snapshot.
    """
    branch_names = []
    branch_ip = []
    branch_port = []
    with open(sys.argv[2], 'r') as branch_file:
        for line in branch_file:
            columns = line.split()
            if not columns:
                continue  # tolerate blank / whitespace-only lines
            branch_names.append(columns[0])
            branch_ip.append(columns[1])
            branch_port.append(int(columns[2]))

    if not branch_names:
        sys.exit('No branches listed in ' + sys.argv[2])

    # Build the InitBranch message directly inside the envelope.
    branch_Message = bank_pb2.BranchMessage()
    branch_Message.init_branch.balance = int(sys.argv[1]) // len(branch_names)
    for name, ip, port in zip(branch_names, branch_ip, branch_port):
        entry = branch_Message.init_branch.all_branches.add()
        entry.name = name
        entry.ip = ip
        entry.port = port

    init_payload = branch_Message.SerializeToString()
    for ip, port in zip(branch_ip, branch_port):
        _send_to_branch((ip, port), init_payload)

    reply = bank_pb2.BranchMessage()
    snapshot_count = 1
    while True:
        sleep(1)
        initiator = random.randint(0, len(branch_names) - 1)
        # The same envelope object is reused for each message kind, exactly
        # as before; only the submessage being assigned changes.
        branch_Message.init_snapshot.snapshot_id = snapshot_count
        _send_to_branch((branch_ip[initiator], branch_port[initiator]),
                        branch_Message.SerializeToString())

        sleep(2)
        branch_Message.retrieve_snapshot.snapshot_id = snapshot_count
        retrieve_payload = branch_Message.SerializeToString()
        for i, name in enumerate(branch_names):
            # channel_state entries are paired positionally with every other
            # branch, in file order.
            senders = [other for j, other in enumerate(branch_names) if j != i]
            data = _send_to_branch((branch_ip[i], branch_port[i]),
                                   retrieve_payload, expect_reply=True)
            reply.ParseFromString(data)
            snapshot = reply.return_snapshot.local_snapshot
            print("\nSnapshot ID: %s" % snapshot.snapshot_id)
            print("%s: %s" % (name, snapshot.balance))
            for sender, amount in zip(senders, snapshot.channel_state):
                print(sender + "->" + name + ":" + str(amount))
        snapshot_count += 1
def SendRetriveSnapshotMessage(self, snapshot_num):
    """Ask every branch for snapshot *snapshot_num* and print the results.

    For each entry in ``self.list`` (objects with ``name``, ``ip`` and
    ``port``) a ``RetrieveSnapshot`` request is sent, framed as a 4-byte
    big-endian length followed by the serialized ``BranchMessage``. The reply
    uses the same framing. One line per branch is printed with its balance
    and incoming channel amounts, followed by the overall total.

    Args:
        snapshot_num: Snapshot identifier; converted with ``int()``.

    Raises:
        SystemExit: after printing an error if any network or decode step
            fails.
    """

    def recv_exact(conn, size):
        # recv() may return fewer bytes than requested; loop until the whole
        # frame has arrived or the peer closes the connection.
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = conn.recv(remaining)
            if not chunk:
                raise IOError("connection closed mid-message")
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    print("---------------------------------------")
    print("snapshot_id:" + str(snapshot_num))

    # The request is identical for every branch, so encode it once.
    retrieve = bank_pb2.RetrieveSnapshot()
    retrieve.snapshot_id = int(snapshot_num)
    request = bank_pb2.BranchMessage()
    request.retrieve_snapshot.CopyFrom(retrieve)
    encoded = request.SerializeToString()
    frame = pack('>I', len(encoded)) + encoded

    total = 0
    for item in self.list:
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                sock.connect((item.ip, int(item.port)))
                sock.sendall(frame)
                message_length, = unpack('>I', recv_exact(sock, 4))
                payload = recv_exact(sock, message_length)
            finally:
                sock.close()
            pb_message = bank_pb2.BranchMessage()
            pb_message.ParseFromString(payload)
        except Exception:
            print("EXCEPTION ! Socket exception in SendRetriveSnapshotMessage")
            # NOTE(review): exit status 0 is kept for compatibility even
            # though this is a failure path -- confirm with callers.
            sys.exit(0)

        snapshot = pb_message.return_snapshot.local_snapshot
        # channel_state entries pair positionally with every other branch.
        senders = [other for other in self.list if other.name != item.name]
        parts = []
        for sender, chanstat in zip(senders, snapshot.channel_state):
            parts.append(str(sender.name) + " -> " + str(item.name) +
                         " : " + str(chanstat) + " ")
            total = total + int(chanstat)
        print(item.name + " : " + str(snapshot.balance) + " , " +
              "".join(parts))
        total = total + snapshot.balance

    print("Total Balance " + str(total))
# Fragment: deliver the already-built init-snapshot `msg` to branch `dest`,
# wait, then send a retrieve-snapshot request to every branch.
# NOTE(review): `msg`, `dest`, `b_ips`, `b_ports`, `b_counter` and
# `s_counter` are defined outside this view, and the loop body below is cut
# off after its `except` clause.
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_address = (b_ips[dest], b_ports[dest])
# connect() is outside the try block, so a connection failure propagates.
sock.connect(server_address)
try:
    #print 'Init Snaphot message sent to ' + b_names[dest]
    sock.sendall(msg.SerializeToString())
except:
    # Bare except: any error while sending is reported and then ignored.
    print 'Controller could not send the init snapshot message'
finally:
    # print 'Socket closed by the controller'
    sock.close()
# Fixed 2-second pause before retrieval -- presumably to let the snapshot
# complete; nothing here verifies that it has.
time.sleep(2)
# Send retrieve snapshot
r_msg = bank_pb2.RetrieveSnapshot()
r_msg.snapshot_id = s_counter
# `msg` is rebound here to the retrieve request.
msg = bank_pb2.BranchMessage()
msg.retrieve_snapshot.CopyFrom(r_msg)
for i in range(0, b_counter):
    # One fresh connection per branch index 0..b_counter-1.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server_address = (b_ips[i], b_ports[i])
    sock.connect(server_address)
    try:
        #print 'Retrieve message sent'
        sock.sendall(msg.SerializeToString())
    except:
        print 'Controller could not send retrieve message'
def start_snapshotting(branch_sockets):
    """Periodically trigger a snapshot and print what every branch recorded.

    Runs forever. Each round waits ``SNAPSHOT_INTERVAL`` seconds, sends an
    ``InitSnapshot`` to one randomly chosen branch, waits
    ``SNAPSHOT_RETRIEVE_INTERVAL`` seconds, then sends ``RetrieveSnapshot``
    to every branch over its already-open socket and prints the reply.

    Args:
        branch_sockets: Sequence of entries where ``entry[0]`` is a connected
            socket and ``entry[1]`` is the branch name (a string).

    Raises:
        ValueError: if *branch_sockets* is empty.
    """
    if not branch_sockets:
        raise ValueError("start_snapshotting needs at least one branch socket")

    no_sockets = len(branch_sockets)
    snapshot_id = 1
    while True:
        # Wait before sending init_snapshot
        time.sleep(SNAPSHOT_INTERVAL)

        # Create init_snapshot message
        init_snapshot_msg = bank_pb2.InitSnapshot()
        init_snapshot_msg.snapshot_id = snapshot_id
        pb_msg = bank_pb2.BranchMessage()
        pb_msg.init_snapshot.CopyFrom(init_snapshot_msg)

        # Select a random branch and send init_snapshot. sendall() is used
        # because send() may transmit only part of the message.
        victim = branch_sockets[randint(0, no_sockets - 1)]
        victim[0].sendall(pb_msg.SerializeToString())
        print("\nSent snapshot msg " + str(snapshot_id) + " to " +
              str(victim[1]))

        # Wait before retrieving snapshot
        time.sleep(SNAPSHOT_RETRIEVE_INTERVAL)

        # Create retrieve_snapshot message; it is the same for every branch,
        # so serialize it once.
        retrieve_snapshot_msg = bank_pb2.RetrieveSnapshot()
        retrieve_snapshot_msg.snapshot_id = snapshot_id
        pb_msg = bank_pb2.BranchMessage()
        pb_msg.retrieve_snapshot.CopyFrom(retrieve_snapshot_msg)
        retrieve_payload = pb_msg.SerializeToString()
        print("\nsnapshot_id: " + str(snapshot_id))

        # Send retrieve_snapshot to all branches and display the replies
        for branch in branch_sockets:
            branch[0].sendall(retrieve_payload)

            # Get the reply - return_snapshot
            # NOTE(review): a single recv() assumes the reply fits in
            # MAX_BUFFER_SIZE and arrives in one piece; the protocol shown
            # here has no length framing to do better.
            incoming_msg_from_wire = branch[0].recv(MAX_BUFFER_SIZE)
            if not incoming_msg_from_wire:
                print("ERROR! the branch " + branch[1] +
                      " doesn't have snapshot: " + str(snapshot_id))
                continue

            pb_msg_ret = bank_pb2.BranchMessage()
            pb_msg_ret.ParseFromString(incoming_msg_from_wire)

            # Error handling
            if not pb_msg_ret.HasField("return_snapshot"):
                print("ERROR! the branch " + branch[1] +
                      " returned some other message : " + str(pb_msg_ret))
                continue

            # Display on the screen in required format. channel_state entries
            # pair positionally with every other branch, in list order.
            local_snapshot = pb_msg_ret.return_snapshot.local_snapshot
            pieces = [str(branch[1]) + ": " + str(local_snapshot.balance) +
                      ", "]
            branch_names = [x[1] for x in branch_sockets if x[1] != branch[1]]
            for br_name, channel_state in zip(branch_names,
                                              local_snapshot.channel_state):
                pieces.append(br_name + "->" + branch[1] + ": " +
                              str(channel_state) + ", ")
            print("".join(pieces))

        snapshot_id += 1