def testOutgoingConcurrentPutAndTwoGetters(self):
    """Two getter threads each read positions 1 and 2 while items are put.

    Both getters are started before anything is put, so this relies on
    `get()` not returning until the requested position is available
    (presumably it blocks; confirm against CommChannel).
    """
    channel = comm_channel.CommChannel()
    first_result = {"outgoing": []}
    second_result = {"outgoing": []}

    def fetch_first_two(sink):
        # Read positions 1 and 2 in order, recording what each get returns.
        for pos in (1, 2):
            sink["outgoing"].append(channel.get(pos))

    first_getter = threading.Thread(
        target=fetch_first_two, args=(first_result,)
    )
    first_getter.start()
    second_getter = threading.Thread(
        target=fetch_first_two, args=(second_result,)
    )
    second_getter.start()

    for item in ("A", "B"):
        channel.put(item)

    first_getter.join()
    second_getter.join()

    for result in (first_result, second_result):
        self.assertEqual("A", result["outgoing"][0][0])
        # The second tuple element for the first get may be 1 or 2,
        # presumably depending on whether "B" had been put yet.
        self.assertIn(result["outgoing"][0][1], [1, 2])
        self.assertEqual(("B", 2), result["outgoing"][1])
def __init__(self, receive_port):
    """Create the shared channels and state, then initialize the servicer.

    Builds the incoming queue, the outgoing comm channel, the run-state
    tracker, the tensor store and the source manager. The first four of
    these are pre-bound into the stream-handler constructor that is handed
    to the base servicer.

    Args:
      receive_port: The port at which to receive data from the TensorFlow
        debugger. Forwarded unchanged to the base initializer.
    """
    # NOTE(review): a base initializer is invoked twice -- here through
    # super() with the bare handler class, and again at the end of this
    # method with the curried constructor. Presumably the later call
    # supersedes this one; confirm against EventListenerBaseServicer
    # before removing either call.
    super(InteractiveDebuggerDataServer, self).__init__(
        receive_port, InteractiveDebuggerDataStreamHandler
    )

    self._incoming_channel = queue.Queue()
    self._outgoing_channel = comm_channel_lib.CommChannel()
    # The lambda defers reading self.breakpoints until it is called.
    self._run_states = RunStates(breakpoints_func=lambda: self.breakpoints)
    self._tensor_store = tensor_store_lib.TensorStore()
    self._source_manager = SourceManager()

    # Pre-bind the shared objects as the handler's leading positional
    # arguments; whoever constructs a handler supplies only the rest.
    curried_handler_constructor = functools.partial(
        InteractiveDebuggerDataStreamHandler,
        self._incoming_channel,
        self._outgoing_channel,
        self._run_states,
        self._tensor_store,
    )
    grpc_debug_server.EventListenerBaseServicer.__init__(
        self, receive_port, curried_handler_constructor
    )
def testOutgoingConcurrentPutAndTwoGetters(self):
    """Two getter threads each read positions 1 and 2 while items are put.

    Both getters are started before anything is put, so this relies on
    `get()` not returning until the requested position is available
    (presumably it blocks; confirm against CommChannel).
    """
    channel = comm_channel.CommChannel()
    first_result = {'outgoing': []}
    second_result = {'outgoing': []}

    def fetch_first_two(sink):
        # Read positions 1 and 2 in order, recording what each get returns.
        for pos in (1, 2):
            sink['outgoing'].append(channel.get(pos))

    first_getter = threading.Thread(
        target=fetch_first_two, args=(first_result,)
    )
    first_getter.start()
    second_getter = threading.Thread(
        target=fetch_first_two, args=(second_result,)
    )
    second_getter.start()

    for item in ('A', 'B'):
        channel.put(item)

    first_getter.join()
    second_getter.join()

    for result in (first_result, second_result):
        self.assertEqual('A', result['outgoing'][0][0])
        # The second tuple element for the first get may be 1 or 2,
        # presumably depending on whether 'B' had been put yet.
        self.assertIn(result['outgoing'][0][1], [1, 2])
        self.assertEqual(('B', 2), result['outgoing'][1])
def testOutgoingSerialPutTwoGetOne(self):
    """Put three items serially, then read them back by 1-based position.

    Despite the method name, three items are put and three are read. Every
    get is expected to return the item paired with 3.
    """
    channel = comm_channel.CommChannel()
    items = ("A", "B", "C")
    for item in items:
        channel.put(item)

    # All puts finished before the first get, so the paired number is 3
    # for every position.
    for pos, item in enumerate(items, start=1):
        self.assertEqual((item, 3), channel.get(pos))
def testOutgoingSerialPutTwoGetOne(self):
    """Put three items serially, then read them back by 1-based position.

    Despite the method name, three items are put and three are read. Every
    get is expected to return the item paired with 3.
    """
    channel = comm_channel.CommChannel()
    items = ('A', 'B', 'C')
    for item in items:
        channel.put(item)

    # All puts finished before the first get, so the paired number is 3
    # for every position.
    for pos, item in enumerate(items, start=1):
        self.assertEqual((item, 3), channel.get(pos))
def testOutgoingConcurrentPutAndOneGetter(self):
    """One getter thread reads positions 1 and 2 while the test puts items.

    The getter is started before anything is put, so this relies on
    `get()` not returning until the requested position is available
    (presumably it blocks; confirm against CommChannel).
    """
    channel = comm_channel.CommChannel()
    result = {"outgoing": []}

    def fetch_first_two():
        # Read positions 1 and 2 in order, recording what each get returns.
        for pos in (1, 2):
            result["outgoing"].append(channel.get(pos))

    getter = threading.Thread(target=fetch_first_two)
    getter.start()

    for item in ("A", "B"):
        channel.put(item)
    getter.join()

    self.assertEqual("A", result["outgoing"][0][0])
    # The second tuple element for the first get may be 1 or 2,
    # presumably depending on whether "B" had been put yet.
    self.assertIn(result["outgoing"][0][1], [1, 2])
    self.assertEqual(("B", 2), result["outgoing"][1])
def testOutgoingConcurrentPutAndOneGetter(self):
    """One getter thread reads positions 1 and 2 while the test puts items.

    The getter is started before anything is put, so this relies on
    `get()` not returning until the requested position is available
    (presumably it blocks; confirm against CommChannel).
    """
    channel = comm_channel.CommChannel()
    result = {'outgoing': []}

    def fetch_first_two():
        # Read positions 1 and 2 in order, recording what each get returns.
        for pos in (1, 2):
            result['outgoing'].append(channel.get(pos))

    getter = threading.Thread(target=fetch_first_two)
    getter.start()

    for item in ('A', 'B'):
        channel.put(item)
    getter.join()

    self.assertEqual('A', result['outgoing'][0][0])
    # The second tuple element for the first get may be 1 or 2,
    # presumably depending on whether 'B' had been put yet.
    self.assertIn(result['outgoing'][0][1], [1, 2])
    self.assertEqual(('B', 2), result['outgoing'][1])
def testOutgoingSerialPutOneAndGetOne(self):
    """Put a single item and read it back from position 1."""
    channel = comm_channel.CommChannel()
    channel.put("A")

    fetched = channel.get(1)

    # With exactly one item put, the get returns it paired with 1.
    self.assertEqual(("A", 1), fetched)
def testGetOutgoingWithInvalidPosLeadsToAssertionError(self):
    """Check that get() rejects positions 0 and -1.

    Despite the method name, the exception expected here is ValueError,
    not AssertionError.
    """
    channel = comm_channel.CommChannel()
    for invalid_pos in (0, -1):
        self.assertRaises(ValueError, channel.get, invalid_pos)