def create_retrieve_snap(self):
    """Wrap a RetrieveSnapshot request for snapshot ``init_counter`` in a BranchMessage.

    ``init_counter`` is a module-level global that is only read here, so no
    ``global`` declaration is required. ``self`` is not used.

    Returns:
        A new ``bank_pb2.BranchMessage`` whose ``retrieve_snapshot`` field
        carries ``snapshot_id = init_counter``.
    """
    request = bank_pb2.RetrieveSnapshot()
    request.snapshot_id = init_counter
    wrapper = bank_pb2.BranchMessage()
    wrapper.retrieve_snapshot.CopyFrom(request)
    return wrapper
예제 #2
0
def get_snapshot(snapshot_id, connections, money):
    """Retrieve snapshot *snapshot_id* from every branch, print it, and verify the total.

    Args:
        snapshot_id: id of the snapshot to request.
        connections: sequence of branch connections; element ``[0]`` of each
            entry is the socket and element ``[1]`` the branch name. Channel
            index ``i`` in a reply is labelled with ``connections[i][1]``.
        money: expected sum of every balance plus every in-flight channel amount.

    Each reply is read as a 2-byte ``struct`` ``'H'`` length (native byte
    order) followed by exactly that many bytes of serialized BranchMessage.

    Raises:
        AssertionError: a branch replied with something other than
            ``return_snapshot``, or the snapshot total differs from *money*.
        EOFError: a branch closed its connection before the full reply arrived.
    """
    def recv_exact(conn, size):
        # A TCP recv() may return fewer bytes than requested, so loop until
        # exactly `size` bytes have been collected.
        chunks = []
        while size > 0:
            chunk = conn.recv(size)
            if not chunk:
                raise EOFError('connection closed before the full reply arrived')
            chunks.append(chunk)
            size -= len(chunk)
        return b''.join(chunks)

    retrieve = bank_pb2.RetrieveSnapshot()
    retrieve.snapshot_id = snapshot_id
    message = bank_pb2.BranchMessage()
    message.retrieve_snapshot.MergeFrom(retrieve)
    total = 0
    print('snapshot_id: {}'.format(snapshot_id))
    for entry in connections:
        sock, name = entry[0], entry[1]
        # Ask the branch for its local snapshot.
        message_socket(sock, message)
        # Read the length-prefixed reply.
        data_size = struct.unpack('H', recv_exact(sock, 2))[0]
        reply = bank_pb2.BranchMessage()
        reply.ParseFromString(recv_exact(sock, data_size))
        # Explicit raise (not `assert`) so the check survives `python -O`.
        kind = reply.WhichOneof('branch_message')
        if kind != 'return_snapshot':
            raise AssertionError(
                '{} replied with {!r} instead of return_snapshot'.format(name, kind))
        local_snapshot = reply.return_snapshot.local_snapshot
        parts = ['{}: {}, '.format(name, local_snapshot.balance)]
        total += local_snapshot.balance
        own_index = get_branch_index(name)
        for i, amt in enumerate(local_snapshot.channel_state):
            # Skip the entry at this branch's own index.
            if i != own_index:
                total += amt
                parts.append('{}->{}: {}, '.format(connections[i][1], name, amt))
        print(''.join(parts))
    print('Total: {}'.format(total))
    print('')
    if total != money:
        raise AssertionError(
            'snapshot total {} != expected {}'.format(total, money))
예제 #3
0
            # Fragment: the enclosing function and the rest of this loop are not
            # visible here. Asks a random branch in BRANCH_LIST to start snapshot
            # snap_id + 1, waits, then requests that snapshot from every branch.
            random_entry = random.choice(BRANCH_LIST)
            print "Initiated snapshot message to %s\n" % str(
                random_entry["name"])
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect((random_entry['ip'], random_entry['port']))
            initSnap = bank_pb2.InitSnapshot()
            initSnap.snapshot_id = snap_id + 1
            msg = bank_pb2.BranchMessage()
            msg.init_snapshot.CopyFrom(initSnap)
            # NOTE(review): the protobuf message is sent pickled rather than via
            # SerializeToString(); the receiving branch must unpickle it.
            s.sendall(pickle.dumps(msg))
            s.close()

            # Fixed delay, presumably to let the snapshot complete before it is
            # retrieved -- TODO confirm 4 s is enough.
            time.sleep(4)

            retrieveSnap = bank_pb2.RetrieveSnapshot()
            retrieveSnap.snapshot_id = snap_id + 1
            msg1 = bank_pb2.BranchMessage()
            msg1.retrieve_snapshot.CopyFrom(retrieveSnap)

            # Request the same snapshot from every branch.
            for b in BRANCH_LIST:
                s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                s.connect((b['ip'], b['port']))
                s.sendall(pickle.dumps(msg1))
                # NOTE(review): pickle.loads on bytes from the network can run
                # arbitrary code, and one recv(1024) may return a partial reply.
                data = pickle.loads(s.recv(1024))

                if data.WhichOneof('branch_message') == 'return_snapshot':
                    d = {}
                    f = {}
                    l = []
                    to_branch = b['name']
예제 #4
0
		# Fragment: the enclosing function is not visible here; names such as
		# InitSnapshotMsg, randomBranch, branhList and branches come from code
		# outside this view. Sends init_snapshot to randomBranch, waits 10 s,
		# then requests snapshot snapshot_num from every branch in branhList.
		InitSnapshotMsg.snapshot_id = int(snapshot_num)
		branchsocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            	branchsocket.connect((randomBranch['ip'], randomBranch['port']))
		branchMessage = bank_pb2.BranchMessage()
                branchMessage.init_snapshot.CopyFrom(InitSnapshotMsg)
		branchsocket.sendall(pickle.dumps(branchMessage))
		branchsocket.close()
		time.sleep(10)


                totalCurrentBalance = 0
		for branchobj in branhList:
                	branches.name = branchobj['name']
                        branches.ip = branchobj['ip']
                        branches.port = branchobj['port']
			retriveSnapshotMessage = bank_pb2.RetrieveSnapshot()
	        	retriveSnapshotMessage.snapshot_id = int(snapshot_num)
			branchMessage1 = bank_pb2.BranchMessage()
                	branchMessage1.retrieve_snapshot.CopyFrom(retriveSnapshotMessage)
			branchsocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            		branchsocket.connect((branchobj['ip'], branchobj['port']))
			branchsocket.sendall(pickle.dumps(branchMessage1))
			# NOTE(review): pickle.loads on network bytes can run arbitrary code,
			# and a single recv(1024) may return a partial reply.
			data = pickle.loads(branchsocket.recv(1024))		
			branchsocket.close()
			channel_state = ""
			# Every other branch, in branhList order; paired below with the
			# returned channel_state entries.
			newList = []
			for item in branhList:
				if item['name'] != branchobj['name']:
					newList.append(item)
			for branch , channelState in zip(newList, data.return_snapshot.local_snapshot.channel_state) :
					channel_state = channel_state + str(branch['name']) + " -> " + str(branchobj['name']) + " : " + str(channelState) + " "
def main():
    """Drive the bank branches as the snapshot controller.

    Command line:
        sys.argv[1] -- total amount of money, split evenly (integer division)
                       across all branches.
        sys.argv[2] -- branch list file, one ``<name> <ip> <port>`` per line.

    Sends an ``init_branch`` message (balance plus the full branch list) to
    every branch. Then loops forever: each cycle tells one randomly chosen
    branch to start snapshot ``snapshot_count``, waits, asks every branch for
    its copy of that snapshot, and prints each branch's balance and incoming
    channel states.
    """
    branch_names = []
    branch_ip = []
    branch_port = []
    snapshot_count = 1

    # `with` closes the file even if a malformed line raises.
    with open(sys.argv[2], 'r') as branch_file:
        for line in branch_file:
            columns = line.split()
            if not columns:
                continue  # tolerate blank lines, e.g. a trailing empty line
            branch_names.append(columns[0])
            branch_ip.append(columns[1])
            branch_port.append(columns[2])

    init_msg = bank_pb2.BranchMessage()
    # Floor division keeps the balance an integer on both Python 2 and 3.
    init_msg.init_branch.balance = int(sys.argv[1]) // len(branch_names)
    for name, ip, port in zip(branch_names, branch_ip, branch_port):
        entry = init_msg.init_branch.all_branches.add()
        entry.name = name
        entry.ip = ip
        entry.port = int(port)
    init_payload = init_msg.SerializeToString()

    for ip, port in zip(branch_ip, branch_port):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            s.connect((ip, int(port)))
            # sendall() retries until every byte is written; send() may stop short.
            s.sendall(init_payload)
        finally:
            s.close()

    while True:
        sleep(1)
        target = random.randint(0, len(branch_names) - 1)
        # Fresh message per request, so no stale oneof field is carried over.
        init_snapshot_msg = bank_pb2.BranchMessage()
        init_snapshot_msg.init_snapshot.snapshot_id = snapshot_count
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            s.connect((branch_ip[target], int(branch_port[target])))
            s.sendall(init_snapshot_msg.SerializeToString())
        finally:
            s.close()

        sleep(2)
        retrieve_msg = bank_pb2.BranchMessage()
        retrieve_msg.retrieve_snapshot.snapshot_id = snapshot_count
        retrieve_payload = retrieve_msg.SerializeToString()
        for i, name in enumerate(branch_names):
            # Channel states are labelled with the other branches, in file order.
            others = branch_names[:i] + branch_names[i + 1:]
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                s.connect((branch_ip[i], int(branch_port[i])))
                s.sendall(retrieve_payload)
                # NOTE(review): one recv(1024) assumes the whole reply arrives in
                # a single read and fits in 1024 bytes; there is no length prefix.
                data = s.recv(1024)
            finally:
                s.close()
            reply = bank_pb2.BranchMessage()
            reply.ParseFromString(data)
            local = reply.return_snapshot.local_snapshot
            print("\nSnapshot ID: " + str(local.snapshot_id))
            print(name + ": " + str(local.balance))
            for sender, amount in zip(others, local.channel_state):
                print(sender + "->" + name + ":" + str(amount))

        snapshot_count += 1
예제 #6
0
	def SendRetriveSnapshotMessage(self, snapshot_num):
		"""Fetch snapshot *snapshot_num* from every branch in ``self.list`` and print it.

		Each entry of ``self.list`` must expose ``name``, ``ip`` and ``port``.
		For each branch this opens a TCP connection and sends a ``retrieve_snapshot``
		BranchMessage framed by a 4-byte big-endian length (``'>I'``). It reads a
		reply framed the same way and prints the branch balance, followed by the
		channel state from every other branch. Finally it prints the sum of every
		balance and channel amount. If any network or decode step fails, it prints
		an error and exits the process with status 1.
		"""
		def recv_exact(conn, size):
			# recv() may return fewer bytes than asked for; keep reading until
			# exactly `size` bytes have arrived or the peer closes the socket.
			chunks = []
			while size > 0:
				chunk = conn.recv(size)
				if not chunk:
					raise EOFError("peer closed connection mid-message")
				chunks.append(chunk)
				size -= len(chunk)
			return b"".join(chunks)

		total = 0
		print("---------------------------------------")
		print("snapshot_id:" + str(snapshot_num))

		# The request is identical for every branch, so build and frame it once.
		retrieve = bank_pb2.RetrieveSnapshot()
		retrieve.snapshot_id = int(snapshot_num)
		request_msg = bank_pb2.BranchMessage()
		request_msg.retrieve_snapshot.CopyFrom(retrieve)
		request = request_msg.SerializeToString()
		frame = pack('>I', len(request)) + request

		for branch in self.list:
			sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
			try:
				sock.connect((branch.ip, int(branch.port)))
				sock.sendall(frame)
				reply_length, = unpack('>I', recv_exact(sock, 4))
				reply = bank_pb2.BranchMessage()
				reply.ParseFromString(recv_exact(sock, reply_length))
			except Exception:
				# Any socket, framing or decode failure aborts the whole run;
				# a non-zero status signals the failure to the caller.
				print("EXCEPTION ! Socket exception in SendRetriveSnapshotMessage")
				sys.exit(1)
			finally:
				sock.close()

			local = reply.return_snapshot.local_snapshot
			# channel_state entries are paired with self.list in order, skipping
			# this branch (assumed to match the replying branch's ordering).
			senders = [other for other in self.list if other.name != branch.name]
			channel_state = ""
			for sender, amount in zip(senders, local.channel_state):
				channel_state += str(sender.name) + " -> " + str(branch.name) + " : " + str(amount) + " "
				total += int(amount)

			print(branch.name + " : " + str(local.balance) + " , " + channel_state)
			total += local.balance

		print("Total Balance " + str(total))
예제 #7
0
        # Fragment: the enclosing function is not visible here. It sends the
        # already-built `msg` (an init snapshot message, per the error text
        # below) to branch `dest` and waits 2 s. It then sends retrieve_snapshot
        # for snapshot `s_counter` to each of the first `b_counter` branches.
        # NOTE(review): connect() is outside the try below. A refused connection
        # raises instead of printing the error, and `sock` is not closed on
        # that path.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server_address = (b_ips[dest], b_ports[dest])
        sock.connect(server_address)

        try:
            #print 'Init Snaphot message sent to ' + b_names[dest]
            sock.sendall(msg.SerializeToString())
        except:
            # NOTE(review): a bare except also swallows KeyboardInterrupt/SystemExit.
            print 'Controller could not send the init snapshot message'
        finally:
            # print 'Socket closed by the controller'
            sock.close()

        time.sleep(2)
        # Send retrieve snapshot
        r_msg = bank_pb2.RetrieveSnapshot()
        r_msg.snapshot_id = s_counter

        msg = bank_pb2.BranchMessage()
        msg.retrieve_snapshot.CopyFrom(r_msg)

        # The rest of this loop is not visible; confirm each `sock` is
        # eventually closed.
        for i in range(0, b_counter):
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            server_address = (b_ips[i], b_ports[i])
            sock.connect(server_address)

            try:
                #print 'Retrieve message sent'
                sock.sendall(msg.SerializeToString())
            except:
                print 'Controller could not send retrieve message'
def start_snapshotting(branch_sockets):
    """Periodically take and print global snapshots of all branches. Never returns.

    *branch_sockets* is a sequence of connected branches; element ``[0]`` of
    each entry is the socket and element ``[1]`` the branch name. The function
    loops forever, running one round per snapshot id (starting at 1):

    1. Sleep ``SNAPSHOT_INTERVAL``, then send ``init_snapshot`` to one randomly
       chosen branch.
    2. Sleep ``SNAPSHOT_RETRIEVE_INTERVAL``.
    3. Send ``retrieve_snapshot`` to every branch and print each branch's
       balance and the channel state from every other branch.
    """
    no_sockets = len(branch_sockets)

    snapshot_id = 1

    while True:
        # Wait before sending init_snapshot
        time.sleep(SNAPSHOT_INTERVAL)

        init_msg = bank_pb2.BranchMessage()
        init_msg.init_snapshot.snapshot_id = snapshot_id

        # Select a random branch and send init_snapshot. sendall() keeps
        # writing until the whole payload is sent; send() may stop short.
        victim = branch_sockets[randint(0, no_sockets - 1)]
        victim[0].sendall(init_msg.SerializeToString())

        print("\nSent snapshot msg " + str(snapshot_id) + " to " + str(victim[1]))

        # Wait before retrieving snapshot
        time.sleep(SNAPSHOT_RETRIEVE_INTERVAL)

        # The retrieve request is the same for every branch; serialize it once.
        retrieve_msg = bank_pb2.BranchMessage()
        retrieve_msg.retrieve_snapshot.snapshot_id = snapshot_id
        retrieve_payload = retrieve_msg.SerializeToString()

        print("\nsnapshot_id: " + str(snapshot_id))
        for branch in branch_sockets:
            sock, name = branch[0], str(branch[1])
            sock.sendall(retrieve_payload)

            # NOTE(review): a single recv() assumes the entire reply arrives at
            # once and fits in MAX_BUFFER_SIZE; the protocol has no length prefix.
            reply_bytes = sock.recv(MAX_BUFFER_SIZE)
            if not reply_bytes:
                print("ERROR! the branch " + name + " doesn't have snapshot: " + str(snapshot_id))
                continue

            reply = bank_pb2.BranchMessage()
            reply.ParseFromString(reply_bytes)

            # Error handling
            if not reply.HasField("return_snapshot"):
                print("ERROR! the branch " + name + " returned some other message : " + str(reply))
                continue

            # Display on the screen in required format: balance, then the
            # channel state from every other branch, in branch_sockets order.
            local = reply.return_snapshot.local_snapshot
            others = [str(x[1]) for x in branch_sockets if x[1] != branch[1]]
            parts = [name + ": " + str(local.balance) + ", "]
            for other, amount in zip(others, local.channel_state):
                parts.append(other + "->" + name + ": " + str(amount) + ", ")
            print("".join(parts))

        snapshot_id += 1