示例#1
0
    def handle_retrieve_snapshot(cls, incoming_message):
        """Build the ReturnSnapshot reply for a RetrieveSnapshot request.

        The requested ID is read from
        ``incoming_message.retrieve_snapshot.snapshot_id`` and looked up in
        ``cls.current_snapshots``.  The stored entry is walked as a mapping:
        the value under the key ``"local"`` is reported as the balance and
        every other value is appended to ``channel_state`` in the mapping's
        iteration order.

        Returns:
            A ``bank_pb2.BranchMessage`` whose ``return_snapshot`` field is
            populated, or ``None`` (after printing an error) when the
            requested ID has no recorded snapshot.
        """
        snapshot_id = incoming_message.retrieve_snapshot.snapshot_id

        if snapshot_id not in cls.current_snapshots:
            # Asked to retrieve a snapshot which doesn't exist.  The message
            # is built as one expression so that print(...) produces the same
            # output on Python 2 and also parses on Python 3.
            print("ERROR! Asked to retrieve a snapshot ID (" + str(
                snapshot_id) + ") which doesn't exist!")
            return None

        # Look the entry up once instead of on every loop iteration
        recorded = cls.current_snapshots[snapshot_id]

        # Create the return_snapshot object and local_snapshot object
        pb_msg = bank_pb2.BranchMessage()
        return_snapshot_obj = bank_pb2.ReturnSnapshot()
        local_snapshot_obj = return_snapshot_obj.LocalSnapshot()

        local_snapshot_obj.snapshot_id = snapshot_id

        # Populate the local_snapshot object: "local" holds the balance,
        # every other key holds one channel's recorded state
        for state, value in recorded.items():
            if state == "local":
                local_snapshot_obj.balance = value
            else:
                local_snapshot_obj.channel_state.append(value)

        return_snapshot_obj.local_snapshot.CopyFrom(local_snapshot_obj)
        pb_msg.return_snapshot.CopyFrom(return_snapshot_obj)

        return pb_msg
示例#2
0
文件: Branch.py 项目: atskae/cs557
    def sendReturnSnapshot(self):
        """Send the snapshot identified by ``self.snapshot_id`` to the controller.

        The reply carries the balance stored on
        ``self.snapshots[self.snapshot_id]`` and one channel amount per known
        branch, obtained through ``getChannel(name)`` and listed in the sorted
        order of the branch names.  The finished BranchMessage is printed and
        handed to ``self.sendMessageToController``.
        """
        reply = bank_pb2.ReturnSnapshot()
        reply.local_snapshot.snapshot_id = self.snapshot_id
        reply.local_snapshot.balance = self.snapshots[
            self.snapshot_id].balance

        # Channel amounts are reported in sorted sender-name order
        sender_names = sorted(
            self.branches[port].name for port in self.branches.keys())
        amounts = [
            self.snapshots[self.snapshot_id].getChannel(name)
            for name in sender_names
        ]
        reply.local_snapshot.channel_state.extend(amounts)

        envelope = bank_pb2.BranchMessage()
        envelope.return_snapshot.CopyFrom(reply)

        print('ReturnSnapshot to send to Controller', envelope)

        # Send to Controller
        self.sendMessageToController(envelope)
示例#3
0
 def handle_snapshot(self, message, initial=True):
     """Record the local state for ``message.snapshot_id`` and send markers.

     Args:
         message: the triggering message.  ``snapshot_id`` is always read;
             ``branch_name`` is read only when ``initial`` is false.
         initial: true when the snapshot is being started here; false when
             it was triggered by a marker arriving from ``message.branch_name``.

     Side effects:
         * ``self.snapshots[snapshot_id]`` is set to ``(ReturnSnapshot, True)``
           holding the current balance and an all-zero ``channel_state``.
         * ``self.channel_states[snapshot_id]`` is set to a list of zeros, one
           per entry of ``self.branches``; when ``initial`` is false the slot
           of the sending branch is set to ``None``.
         * A marker is sent through ``self.message_socket`` on every entry of
           ``self.sockets``.
     """
     print('snapshot_id: {}'.format(message.snapshot_id))
     if not initial:
         print('From: {}'.format(message.branch_name))
     # Set marker
     marker = bank_pb2.Marker()
     marker.branch_name = self.name
     marker.snapshot_id = message.snapshot_id
     # Create local snapshot to send later
     stored_snapshot = bank_pb2.ReturnSnapshot()
     stored_snapshot.local_snapshot.snapshot_id = message.snapshot_id
     # Set balance of local snapshot and default value channel_states
     branch_count = len(self.branches)
     stored_snapshot.local_snapshot.balance = self.balance
     stored_snapshot.local_snapshot.channel_state[:] = [0] * branch_count
     print('Initial channel_state')
     print(stored_snapshot.local_snapshot.channel_state)
     # Store local snapshot which means it has seen snapshot_id
     self.snapshots[message.snapshot_id] = (stored_snapshot, True)
     self.channel_states[message.snapshot_id] = [0] * branch_count
     if not initial:
         incoming_channel_index = self.get_branch_index(message.branch_name)
         # Set incoming channel as empty for snapshot
         self.channel_states[
             message.snapshot_id][incoming_channel_index] = None
     new_message = bank_pb2.BranchMessage()
     new_message.marker.MergeFrom(marker)
     # Each entry is indexed as [0] = socket handed to message_socket and
     # [1] = an object with acquire()/release() guarding that socket.
     for sock in self.sockets:
         sock[1].acquire()
         try:
             self.message_socket(sock[0], new_message)
         finally:
             # Release even when the send raises; otherwise every later
             # sender on this socket would block forever.
             sock[1].release()
# One lock shared by every call of receiver().  A lock created inside the
# handler is private to that call and therefore excludes nobody.
_RECEIVER_STATE_LOCK = threading.Lock()


def _start_recording(snapshot_id):
    """Flag every RECORD entry as recording and store the balance for *snapshot_id*."""
    global curr_snap_id
    with _RECEIVER_STATE_LOCK:
        for src in RECORD:
            RECORD[src] = True
        curr_snap_id = snapshot_id
        if curr_snap_id not in global_snapshot:
            global_snapshot[curr_snap_id] = {}
        global_snapshot[curr_snap_id]['balance'] = balance


def _broadcast_markers(snapshot_id):
    """Send a Marker for *snapshot_id* over every socket in branch_soc_list."""
    for banks in branch_soc_list:
        marker = bank_pb2.Marker()
        marker.src_branch = str(port)
        marker.dst_branch = str(banks)
        marker.snapshot_id = snapshot_id
        send_message = bank_pb2.BranchMessage()
        send_message.marker.CopyFrom(marker)
        print("Sending marker to %s for snapshot %s" % (banks, snapshot_id))
        branch_soc_list[banks].sendall(send_message.SerializeToString())


def receiver(client_socket):
    """Read BranchMessages from *client_socket* and act on each one.

    Every ``recv(1024)`` result is parsed as one ``bank_pb2.BranchMessage``
    and dispatched on its ``branch_message`` oneof:

    * ``init_branch``: store the initial balance, register and connect the
      branches, then start ``sleep_thread`` as a daemon thread.
    * ``transfer``: add the money to ``balance``, or, while the source
      channel is flagged in ``RECORD``, store it in ``global_snapshot``.
    * ``init_snapshot``: start recording on all channels, store the balance
      and send markers to every connected branch.
    * ``marker``: a marker for a new snapshot ID starts a local snapshot;
      a marker for the current ID stops recording on the sender's channel.
    * ``retrieve_snapshot``: reply on *client_socket* with a ReturnSnapshot.

    The function returns when the peer closes the connection.
    """
    global balance
    global branches_all
    global branch_soc_list
    global curr_snap_id
    global this_branch
    global marker_count
    global RECORD

    while 1:
        received_message = client_socket.recv(1024)
        if not received_message:
            # recv() yields an empty string once the peer has closed the
            # connection; parsing it gives an empty message that matches no
            # branch, so without this check the loop would spin forever.
            return
        data = bank_pb2.BranchMessage()
        data.ParseFromString(received_message)
        kind = data.WhichOneof('branch_message')

        if kind == 'init_branch':
            with _RECEIVER_STATE_LOCK:
                balance = data.init_branch.balance
            branches_add(data.init_branch.all_branches)
            connect_branches()
            sleep(5)
            print("calling sleep thread")
            try:
                th = Thread(target=sleep_thread, args=())
                th.daemon = True
                th.start()
            except KeyboardInterrupt:
                exit()

        elif kind == 'transfer':
            if not RECORD[int(data.transfer.src_branch)]:
                with _RECEIVER_STATE_LOCK:
                    balance += data.transfer.money
                    print("Balance in the bank after getting transferred money: %s" % (balance,))
            else:
                print("Recording %s -> %s %s" % (
                    data.transfer.src_branch, this_branch,
                    data.transfer.money))
                # NOTE(review): a recorded transfer replaces any earlier one
                # from the same source and is never added to `balance` in
                # this function -- confirm that this is intended.
                with _RECEIVER_STATE_LOCK:
                    if curr_snap_id not in global_snapshot:
                        global_snapshot[curr_snap_id] = {}
                    global_snapshot[curr_snap_id][
                        data.transfer.src_branch] = data.transfer.money

        elif kind == 'init_snapshot':
            print("Received init_snapshot for %s with snap_id %s" % (
                this_branch, data.init_snapshot.snapshot_id))
            _start_recording(data.init_snapshot.snapshot_id)
            _broadcast_markers(curr_snap_id)

        elif kind == 'marker':
            sender = data.marker.src_branch
            if data.marker.snapshot_id != curr_snap_id:
                # First marker seen for this snapshot ID
                print("Received marker from %s and snap_id is %s" % (
                    sender, data.marker.snapshot_id))
                _start_recording(data.marker.snapshot_id)
                _broadcast_markers(curr_snap_id)
                # Stop recording on the channel the marker arrived on.
                # RECORD is indexed by int(src_branch) everywhere else in
                # this function, so the same key form is used here.
                with _RECEIVER_STATE_LOCK:
                    RECORD[int(sender)] = False
            elif marker_count < len(branches_all):
                print("Reply marker from %s for snapshot %s" % (
                    sender, data.marker.snapshot_id))
                marker_count += 1
                with _RECEIVER_STATE_LOCK:
                    RECORD[int(sender)] = False
                    if marker_count == len(branches_all):
                        print("Received markers from all. Snapshot created")
                        marker_count = 0

        elif kind == 'retrieve_snapshot':
            print("Received retrieve_snapshot for %s" % (this_branch,))
            requested_id = data.retrieve_snapshot.snapshot_id
            snapshot = bank_pb2.ReturnSnapshot()
            snapshot.local_snapshot.snapshot_id = requested_id
            for key2, val2 in global_snapshot.get(requested_id, {}).items():
                if key2 != 'balance':
                    # NOTE(review): the channel entry is the decimal
                    # concatenation of source key and amount; presumably the
                    # controller splits it again -- verify against it.
                    snapshot.local_snapshot.channel_state.append(
                        int(str(key2) + str(val2)))
                else:
                    snapshot.local_snapshot.balance = val2
            send_message = bank_pb2.BranchMessage()
            send_message.return_snapshot.CopyFrom(snapshot)
            client_socket.sendall(send_message.SerializeToString())
示例#5
0
    def bankhandle(self, data, branchNameIn, clientsocket, timelimit):
        """Handle one decoded BranchMessage for this branch.

        Args:
            data: message exposing ``HasField`` and the sub-messages
                ``init_branch``, ``transfer``, ``init_snapshot``, ``marker``
                and ``retrieve_snapshot``.
            branchNameIn: name of this branch; stored in the module global
                ``branchName``.
            clientsocket: socket on which a retrieve_snapshot reply is sent.
            timelimit: transfer interval; divided by 1000.0 and stored in the
                module global ``timeLimit`` on init_branch.

        Channel state is kept in ``self.markerMsgChnlState`` under the key
        ``(snapshot_id, branch_name)`` as a ``(recording, amount)`` pair, and
        the balance captured for a snapshot in ``self.markerMsgBalance``.
        """
        global TOTAL_BALANCE
        global branhList
        global moneyTransfer
        global branchName
        global timeLimit
        branchName = branchNameIn

        if data.HasField("init_branch"):
            TOTAL_BALANCE = data.init_branch.balance
            for branch in data.init_branch.all_branches:
                if branch.name != branchName:
                    branhList.append(branch)
            logger.debug("Branch List initialized .....")
            logger.debug(branhList)
            moneyTransfer = True
            timeLimit = (timelimit / 1000.0)
            thread = Thread(target=self.MoneyTransfer)
            thread.daemon = True
            thread.start()

        elif data.HasField("transfer"):
            logger.debug("Transfer Message Recieved.....")
            logger.debug("if incoming message need to be recorded or Store it in current branch balance")
            # Only the most recently started snapshot is consulted.
            channel_key = None
            if self.snapshotList:
                channel_key = (self.snapshotList[-1], data.transfer.src_branch)
            # The original also tested `self.markerMsgBalance != 0`; that
            # compares a dict with 0 and is always true, so it is omitted.
            if channel_key is not None and self.markerMsgChnlState[channel_key][0]:
                logger.debug("Recording Transfer balance in marker message channel state")
                # Add to what is already recorded: several transfers can
                # arrive on one channel before its marker does, and
                # overwriting would drop all but the last one.
                already_recorded = self.markerMsgChnlState[channel_key][1]
                self.markerMsgChnlState[channel_key] = (
                    True, already_recorded + int(data.transfer.money))
            else:
                with self.critical_section_lock:
                    TOTAL_BALANCE = TOTAL_BALANCE + data.transfer.money
            print("branchBalance Updated to...." + str(TOTAL_BALANCE))

        elif data.HasField("init_snapshot"):
            time.sleep(2)
            moneyTransfer = False
            logger.debug("Recording initial snapshot")
            snapshot_id = data.init_snapshot.snapshot_id
            self.snapshotList.append(snapshot_id)
            self.markerMsgBalance[snapshot_id] = TOTAL_BALANCE
            for branch in branhList:
                logger.debug("Started recording on all the incoming channels")
                self.markerMsgChnlState[snapshot_id, branch.name] = (True, 0)

            logger.debug("Sending Markers to all the channels")
            # The original wrote Thread(target=self.SendMarkers(id)), which
            # runs SendMarkers right here and then starts a thread with
            # target=None.  The direct call keeps that effective behaviour
            # (markers are sent before the next message is handled) and
            # matches the marker branch below.
            self.SendMarkers(snapshot_id)

        elif data.HasField("marker"):
            moneyTransfer = False
            logger.debug("Marker message received")
            snapshot_id = data.marker.snapshot_id
            sender = data.marker.src_branch
            if snapshot_id not in self.snapshotList:
                # First marker for this snapshot: capture the balance, treat
                # the sender's channel as empty, record on all the others.
                self.snapshotList.append(snapshot_id)
                self.markerMsgBalance[snapshot_id] = TOTAL_BALANCE
                self.markerMsgChnlState[snapshot_id, sender] = (False, 0)
                for branch in branhList:
                    if branch.name != sender:
                        self.markerMsgChnlState[snapshot_id, branch.name] = (True, 0)
                self.SendMarkers(snapshot_id)
            else:
                # Later marker: stop recording on the sender's channel but
                # keep the amount recorded so far.
                amount = self.markerMsgChnlState[snapshot_id, sender][1]
                self.markerMsgChnlState[snapshot_id, sender] = (False, amount)
            moneyTransfer = True

        elif data.HasField("retrieve_snapshot"):
            logger.debug("Retriving snapshot")
            snapshot_id = data.retrieve_snapshot.snapshot_id
            returnSnapshotMessage = bank_pb2.ReturnSnapshot.LocalSnapshot()
            returnSnapshotMessage.snapshot_id = int(snapshot_id)
            returnSnapshotMessage.balance = int(self.markerMsgBalance[snapshot_id])
            for returnBranch in branhList:
                amount = self.markerMsgChnlState[snapshot_id, returnBranch.name][1]
                # Recorded in-flight money is credited to the balance only
                # now.  NOTE(review): retrieving the same snapshot twice
                # credits it twice -- confirm the controller asks only once.
                with self.critical_section_lock:
                    TOTAL_BALANCE = TOTAL_BALANCE + amount
                returnSnapshotMessage.channel_state.append(int(amount))

            branchMessage = bank_pb2.ReturnSnapshot()
            branchMessage.local_snapshot.CopyFrom(returnSnapshotMessage)
            branchmessage = bank_pb2.BranchMessage()
            branchmessage.return_snapshot.CopyFrom(branchMessage)
            logger.debug("Returning Snapshot : " + str(branchmessage))
            # NOTE(review): the reply is pickled rather than sent with
            # SerializeToString(); the receiver must unpickle it, which is
            # unsafe on untrusted input.  Left unchanged because the peer's
            # decoding is not visible here.
            clientsocket.sendall(pickle.dumps(branchmessage))
                            s.close()

                        isCapturing = True
                elif MARKER_MSG == len(BRANCH_LIST):
                    isCapturing = False
                    MARKER_MSG = 1
                    isUpdate = False
                    print "\n\nSnapShots captured : %s \n\n" % str(
                        SNAPSHOTS[data.marker.snapshot_id])
                else:
                    MARKER_MSG = MARKER_MSG + 1

                if isUpdate:
                    SNAPSHOTS[currentSnapId][sys.argv[1]] = BRANCH_BALANCE
            elif data.WhichOneof('branch_message') == 'retrieve_snapshot':
                snap_shot_obj = bank_pb2.ReturnSnapshot()
                snap_shot_obj.local_snapshot.snapshot_id = data.retrieve_snapshot.snapshot_id
                snap_shot_obj.local_snapshot.balance = SNAPSHOTS[
                    data.retrieve_snapshot.snapshot_id][str(sys.argv[1])]
                for k, v in SNAPSHOTS[
                        data.retrieve_snapshot.snapshot_id].iteritems():
                    if k != str(sys.argv[1]):
                        if len(v) == 0:
                            snap_shot_obj.local_snapshot.channel_state.append(
                                0)
                        else:
                            for val in v:
                                snap_shot_obj.local_snapshot.channel_state.append(
                                    val)
                msg = bank_pb2.BranchMessage()
                msg.return_snapshot.CopyFrom(snap_shot_obj)
    def recieve_transfer_message(self, data):
        """Parse *data* as a serialized BranchMessage and act on it.

        * ``transfer``: add the money to ``self.curr_balance`` and, while the
          sender's channel is being recorded, also pass it to
          ``Channel_States.add_balance``.
        * ``init_snapshot``: store the current balance under this branch's
          name, send markers, and start recording on every channel.
        * ``marker``: the first marker stores the balance, sends markers and
          starts recording on every channel except the sender's; any later
          marker stops recording on the sender's channel and stores the value
          returned by ``Channel_States.stop_listen``.
        * ``retrieve_snapshot``: send the recorded state over
          ``self.controller_connect[0]``, then reset the snapshot bookkeeping
          and increment ``self.snapshot_id``.

        ``self.lock`` is always released in a ``finally`` clause so that an
        exception raised while it is held (for example a ``KeyError`` for an
        unknown branch name) cannot leave it locked.
        """
        rec_branch_message = bank_pb2.BranchMessage()
        rec_branch_message.ParseFromString(data)

        if rec_branch_message.HasField("transfer"):
            recieved_money = rec_branch_message.transfer.money
            sender = rec_branch_message.transfer.branch_name
            print("Recieved Money: " + str(recieved_money))
            self.lock.acquire()
            try:
                print("Curr Balance " + str(self.curr_balance))
                self.curr_balance += recieved_money
                if self.recording_state[sender]:  # If recording
                    Channel_States.add_balance(sender, recieved_money)
                print("Updated Curr Balance " + str(self.curr_balance))
            finally:
                self.lock.release()

        elif rec_branch_message.HasField("init_snapshot"):
            self.recorded_states[self.branch_name] = self.curr_balance
            self.send_markers(self.snapshot_id)  # send markers to all others
            for branch in self.recording_state:
                Channel_States.start_listen(branch)
                self.recording_state[branch] = True  # Enable recording for the branch
            self.isFirstMarker = False

        elif rec_branch_message.HasField("marker"):
            sender = rec_branch_message.marker.branch_name

            if self.isFirstMarker:
                print("Recieved 1st marker from " + sender)
                self.recorded_states[self.branch_name] = self.curr_balance
                self.send_markers(self.snapshot_id)
                # Record on every channel except the one the marker came from
                for branch in self.recording_state:
                    if branch != sender:
                        self.lock.acquire()
                        try:
                            Channel_States.start_listen(branch)
                        finally:
                            self.lock.release()
                        self.recording_state[branch] = True
                self.isFirstMarker = False

            else:
                self.lock.acquire()
                try:
                    print("Recieved 2nd marker from " + sender)
                    value = Channel_States.stop_listen(sender)
                    self.recording_state[sender] = False
                    self.recorded_states[sender] = value
                finally:
                    self.lock.release()

        elif rec_branch_message.HasField("retrieve_snapshot"):
            local = bank_pb2.ReturnSnapshot().LocalSnapshot()
            # NOTE(review): the reply carries this branch's own snapshot
            # counter, not the ID from the request -- confirm they agree.
            local.snapshot_id = self.snapshot_id
            local.balance = self.recorded_states[self.branch_name]
            for branch in self.recorded_states:
                if branch != self.branch_name:
                    local.channel_state.append(self.recorded_states[branch])
            ret = bank_pb2.ReturnSnapshot()
            ret.local_snapshot.CopyFrom(local)
            branchmsg = bank_pb2.BranchMessage()
            branchmsg.return_snapshot.CopyFrom(ret)

            # Reset all snapshot bookkeeping for the next round
            self.snapshot_id += 1
            self.isFirstMarker = True
            for branch in self.recording_state:
                self.recording_state[branch] = False  # init recording state
            self.recorded_states = {}
            Channel_States.channel_states = {}

            print("Got Retrieve msg!")
            self.controller_connect[0].sendall(branchmsg.SerializeToString())
示例#8
0
def _send_branch_message(address, message):
    """Send *message* serialized over a fresh TCP connection to *address*, then close it."""
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.connect(address)
        s.sendall(message.SerializeToString())
    finally:
        s.close()


def _build_marker_message(handler):
    """Return a BranchMessage holding a Marker for the handler's current snapshot."""
    m = bank_pb2.Marker()
    m.snapshot_id = handler.current_snapshot_id
    m.branch_name = handler.name
    wrapper = bank_pb2.BranchMessage()
    wrapper.marker.CopyFrom(m)
    return wrapper


def _on_transfer(handler, msg):
    """Record the transfer if its channel is being recorded, else credit the balance."""
    money = msg.transfer.money
    sender = msg.transfer.branch_name
    if sender in handler.b_recording:
        handler.incoming_names.append(sender)
        handler.incoming_balances.append(money)
    else:
        with balance_lock:
            handler.balance = handler.balance + money


def _on_init_snapshot(handler, msg):
    """Capture the local balance, start recording all channels and send markers."""
    handler.markers_going = True
    handler.state_balance = handler.balance
    handler.current_snapshot_id = msg.init_snapshot.snapshot_id

    for i in range(0, handler.b_counter):
        if handler.b_names[i] == handler.name:
            continue
        handler.b_recording.append(handler.b_names[i])
        addr = (handler.b_ips[i], handler.b_ports[i])
        try:
            _send_branch_message(addr, _build_marker_message(handler))
        except Exception as e:
            print(handler.name + ' could not send the marker message')
            print(e)
    handler.markers_going = False


def _on_marker(handler, msg):
    """Handle a marker: start a snapshot on the first one, stop recording on later ones."""
    sender = msg.marker.branch_name

    # Not first marker message: stop recording the sender's channel
    if handler.current_snapshot_id == msg.marker.snapshot_id:
        if sender in handler.b_recording:
            handler.b_recording.remove(sender)
        return

    # First marker message
    handler.markers_going = True
    handler.current_snapshot_id = msg.marker.snapshot_id
    handler.state_balance = handler.balance

    # Set marker's channel to 0 and start recording other channels
    handler.incoming_names.append(sender)
    handler.incoming_balances.append(0)

    for i in range(0, handler.b_counter):
        n = handler.b_names[i]
        if n == handler.name:
            continue
        if n != sender:
            handler.b_recording.append(n)
        # Send markers
        addr = (handler.b_ips[i], handler.b_ports[i])
        try:
            _send_branch_message(addr, _build_marker_message(handler))
        except Exception:
            print(handler.name + ' could not send the marker message from first')
    handler.markers_going = False


def _on_retrieve_snapshot(handler, msg, controller_address):
    """Send the recorded snapshot to *controller_address* and reset the recording state."""
    if msg.retrieve_snapshot.snapshot_id != handler.current_snapshot_id:
        print('Error in retrieve snapshot - ids do not match')
        return

    r = bank_pb2.ReturnSnapshot()
    r.local_snapshot.snapshot_id = handler.current_snapshot_id
    r.local_snapshot.balance = handler.state_balance

    # Sum the recorded amounts per sender, keeping first-seen order
    names = []
    balances = []
    for index, name in enumerate(handler.incoming_names):
        if name not in names:
            names.append(name)
            balances.append(handler.incoming_balances[index])
        else:
            ind = names.index(name)
            balances[ind] = balances[ind] + handler.incoming_balances[index]

    # Populate channel states
    for amount in balances:
        r.local_snapshot.channel_state.append(amount)

    reply = bank_pb2.BranchMessage()
    reply.return_snapshot.CopyFrom(r)

    # connect() is inside the try as well: a refused connection is reported
    # like a failed send instead of terminating the listener.
    try:
        _send_branch_message(controller_address, reply)
    except Exception:
        print('Could not send response message')

    # Recorded in-flight money is credited to the balance only now
    with balance_lock:
        for amount in balances:
            handler.balance = handler.balance + amount

    handler.incoming_names = []
    handler.incoming_balances = []
    handler.b_recording = []


def thread_listener(sock, handler):
    """Accept connections on *sock* and apply the received messages to *handler*.

    Phase 1 waits for the init message.  Every chunk received on a connection
    is parsed as a BranchMessage whose ``init_branch`` supplies the balance
    and the branch list.  The last entry of ``all_branches`` is removed from
    the handler's lists and its ip/port are kept as the address that
    retrieve_snapshot replies are sent to.  Once a connection that delivered
    data has closed, ``handler.start_connecting`` is set to True.

    Phase 2 runs forever: each accepted connection is read once
    (``recv(1024)``), the message is dispatched by type (transfer,
    init_snapshot, marker, retrieve_snapshot) and the connection is closed.
    """
    init_received = False
    c_ip = ''
    c_port = 0

    # Listen for init message
    while True:
        connection, _ = sock.accept()
        try:
            while True:
                data = connection.recv(1024)
                if not data:
                    break

                init_received = True
                msg = bank_pb2.BranchMessage()
                msg.ParseFromString(data)
                handler.balance = msg.init_branch.balance

                for i in msg.init_branch.all_branches:
                    handler.b_counter = handler.b_counter + 1
                    handler.b_names.append(i.name)
                    handler.b_ips.append(i.ip)
                    handler.b_ports.append(i.port)

                # Take the last entry off the lists and remember its address
                handler.b_counter = handler.b_counter - 1
                c_ip = handler.b_ips[handler.b_counter]
                c_port = handler.b_ports[handler.b_counter]
                handler.b_names.pop()
                handler.b_ips.pop()
                handler.b_ports.pop()
        finally:
            connection.close()
        if init_received:
            handler.start_connecting = True
            break

    # Listen for other messages: one message per accepted connection
    while True:
        connection, _ = sock.accept()
        try:
            data = connection.recv(1024)
            if data:
                msg = bank_pb2.BranchMessage()
                msg.ParseFromString(data)

                if msg.HasField('transfer'):
                    _on_transfer(handler, msg)
                elif msg.HasField('init_snapshot'):
                    _on_init_snapshot(handler, msg)
                elif msg.HasField('marker'):
                    _on_marker(handler, msg)
                elif msg.HasField('retrieve_snapshot'):
                    _on_retrieve_snapshot(handler, msg, (c_ip, c_port))
                else:
                    print('Incoming data has wrong protobuf format')
        finally:
            connection.close()