def test_matmul(self):
    """Sparse (COO identity) @ dense (diag(0..9)) must match numpy's dense product."""
    dense = np.eye(10)
    scaled = np.eye(10)
    # Overwrite the diagonal with 0, 1, ..., 9 (stored as float64, like the eye).
    np.fill_diagonal(scaled, np.arange(10))
    sparse_tensor = tensor.SparseMatrix(
        scipy.sparse.coo_matrix(dense), info=GraphInfo('a'))
    dense_tensor = tensor.Constant(scaled, info=GraphInfo('b'))
    product = sparse_tensor @ dense_tensor
    with self.test_session() as sess:
        result = sess.run(product.data)
    # The evaluated result is a plain array, so the comparison needs no session.
    np.testing.assert_almost_equal(result, dense @ scaled)
def test_assign_add(self):
    """assign_add(10) on a variable initialised to 1 yields 11 and updates the variable."""
    counter = tensor.Variable(GraphInfo('x', 'scope', False), initializer=1)
    bumped = counter.assign_add(10)
    # Order matters: read the variable, then run the add op, then re-read the variable.
    checks = ((counter, 1), (bumped, 11), (counter, 11))
    with self.variables_initialized_test_session() as sess:
        for target, expected in checks:
            assert sess.run(target.data) == expected
def create_simple_info(self, name='x'):
    """Return a GraphInfo for ``name`` with scope ``None`` and the third argument False."""
    # The third positional flag is presumably "reuse" (cf. create_reuseable_info); confirm.
    scope, reuse = None, False
    return GraphInfo(name, scope, reuse)
def test_empty_scope(self):
    """Entering the variable scope of a GraphInfo with scope '' leaves the TF scope name empty."""
    empty_info = GraphInfo('x', '', False)
    with empty_info.variable_scope():
        current_name = tf.get_variable_scope().name
        assert current_name == ''
def test_child(self):
    """child_scope('y') on a GraphInfo named 'x' yields the name 'x/y'."""
    self.assertNameEqual(GraphInfo('x', None, False).child_scope('y'), 'x/y')
def test_auto_scope(self):
    """A GraphInfo named 'x' built with scope None reports its scope as 'x'."""
    self.assertEqual(GraphInfo('x', None, False).scope, 'x')
def test_update(self):
    """update(name=...) with name 'x' / 'y' produces an info named 'x/y'."""
    original = GraphInfo('x', None, False)
    # GraphInfo names support '/' to build a nested name.
    child_name = original.name / 'y'
    updated = original.update(name=child_name)
    self.assertNameEqual(updated, 'x/y')
def create_reuseable_info(self, name='x'):
    """Return a GraphInfo for ``name`` with scope ``None`` and the third argument True."""
    # The third positional flag is presumably "reuse", given this helper's name; confirm.
    scope, reuse = None, True
    return GraphInfo(name, scope, reuse)
def test_construct_by_graph_info_value(self):
    """A Variable built from a GraphInfo with initializer=0 evaluates to 0."""
    info = GraphInfo('x', 'scope', False)
    zero_var = tensor.Variable(info, initializer=0)
    with self.variables_initialized_test_session() as sess:
        value = sess.run(zero_var.data)
        assert value == 0
def test_construct_by_graph_info_name(self):
    """The underlying TF tensor of a Variable in 'scope' named 'x' is 'scope/x:0'."""
    expected = 'scope/x:0'
    var = tensor.Variable(GraphInfo('x', 'scope', False), initializer=0)
    assert var.data.name == expected
def test_basic(self):
    """A Variable initialised from a float32 array evaluates back to the same values."""
    expected = [1.0, 2.0, 3.0]
    initial = np.array(expected, np.float32)
    x = tensor.Variable(GraphInfo('x'), initializer=initial)
    # The session object itself is not needed: x.eval() runs inside the managed session.
    with self.variables_initialized_test_session():
        self.assertAllEqual(x.eval(), expected)
def test_construct_with_none_name(self):
    """A Tensor wrapped with GraphInfo(None, scope, False) gets name 'scope/a' and keeps the scope."""
    with tf.variable_scope('scope') as scope:
        const_a = tf.constant(1, name='a')
        wrapped = tensor.Tensor(const_a, GraphInfo(None, scope, False))
        info = wrapped.info
        # With name=None the expected name matches the constant's scoped name.
        self.assertNameEqual(info, 'scope/a')
        assert info.scope == scope
def test_input_info(self):
    """A Graph built from GraphInfo('g', 'g_scope') exposes name 'g' and scope 'g_scope'."""
    graph = Graph(GraphInfo('g', 'g_scope'))
    self.assertNameEqual(graph, 'g')
    info = graph.info
    self.assertNameEqual(info, 'g')
    self.assertNameEqual(info.scope, 'g_scope')