def testBasic(self):
  """Round-trips a scalar through a sendrecv.Channel for each tested dtype.

  The channel sends from the first listed device to the last one. For every
  dtype a fresh graph is built, the send and receive ops are run together in
  a single session step, and the received value is compared to the sent one.
  """
  devices = _ListDevices(_Target())
  print("\n".join(devices))
  sender = devices[0]
  recver = devices[-1]
  shape = []  # Scalar payload.

  def _BuildAndRun(dtype):
    """Builds a send/recv graph for `dtype`; returns (sent, received)."""
    # For a real dtype (tf.float32) the numpy cast keeps only the real part.
    payload = np.array(3.1415 + 2j).astype(dtype.as_numpy_dtype)
    graph = tf.Graph()
    with graph.as_default():
      channel = sendrecv.Channel(dtype, shape, sender, recver, "test")
      with tf.device(sender):
        send_op = channel.Send(tf.constant(payload))
      with tf.device(recver):
        recv_val = channel.Recv()
    with tf.Session(_Target(), graph=graph) as sess:
      # Only the received value matters; the send op's output is discarded.
      received = sess.run([send_op, recv_val])[1]
    return payload, received

  for dtype in (tf.float32, tf.complex64):
    sent, received = _BuildAndRun(dtype)
    self.assertAllClose(sent, received)
def SendRecv(graph, dtype):
  """Builds a channel round trip in `graph` with Send/Recv wrapped in Defuns.

  Uses `shape`, `sender` and `recver` from the enclosing scope.

  Args:
    graph: the tf.Graph in which the channel and function calls are built.
    dtype: the tf.DType of the value sent over the channel.

  Returns:
    A (send_op, recv_val, to_send) tuple: the output of the sending Defun call,
    the output of the receiving Defun call, and the numpy value being sent.
  """
  numpy_dtype = dtype.as_numpy_dtype
  expected = np.array(3.1415 + 2j).astype(numpy_dtype)
  with graph.as_default():
    channel = sendrecv.Channel(dtype, shape, sender, recver, "test")

    with tf.device(sender):

      @function.Defun()
      def Send():
        channel.Send(tf.constant(expected))
        # Constant placeholder output for the function call.
        return 1.0

      send_op = Send()

    with tf.device(recver):

      @function.Defun()
      def Recv():
        return channel.Recv()

      recv_val = Recv()

  return send_op, recv_val, expected
def SendRecv(graph, dtype):
  """Builds a channel round trip in `graph` via py_utils.CallDefun.

  Uses `shape`, `sender` and `recver` from the enclosing scope.

  Args:
    graph: the tf.Graph in which the channel and function calls are built.
    dtype: the tf.DType of the value sent over the channel.

  Returns:
    A (send_op, recv_val, to_send) tuple: the output of the sending
    CallDefun, the output of the receiving CallDefun, and the numpy value
    being sent.
  """
  expected = np.array(3.1415 + 2j).astype(dtype.as_numpy_dtype)
  with graph.as_default():
    channel = sendrecv.Channel(dtype, shape, sender, recver, "test")

    # py_utils.CallDefun requires non-empty inputs, so each call below is
    # given an unused scalar argument created under the same device scope.
    with tf.device(sender):

      def Send(_):
        channel.Send(tf.constant(expected))
        return tf.convert_to_tensor(1.0)

      send_op = py_utils.CallDefun(Send, tf.convert_to_tensor(0))

    with tf.device(recver):

      def Recv(_):
        return channel.Recv()

      recv_val = py_utils.CallDefun(Recv, tf.convert_to_tensor(0))

  return send_op, recv_val, expected