def test_to_representation_for_type_with_nested_odict_constant(self):
    """A nested OrderedDict constant converts to a named Struct."""
    value = collections.OrderedDict([('a', 10), ('b', [20, 30])])
    type_spec = collections.OrderedDict([
        ('a', np.int32),
        ('b', [np.int32, np.int32]),
    ])
    converted = executor.to_representation_for_type(value, type_spec)
    self.assertIsInstance(converted, structure.Struct)
    # Named outer fields, unnamed inner pair.
    self.assertEqual(str(converted), '<a=10,b=<20,30>>')
def test_to_representation_for_type_with_noarg_to_int32_comp(self):
    """A no-argument XLA computation yielding int32 converts to a callable."""
    xla_builder = xla_client.XlaBuilder('comp')
    # The XLA computation still declares a parameter, shaped as an empty tuple.
    empty_shape = xla_client.shape_from_pyval(())
    xla_client.ops.Parameter(xla_builder, 0, empty_shape)
    xla_client.ops.Constant(xla_builder, np.int32(10))
    computation = xla_builder.build()
    fn_type = computation_types.FunctionType(None, np.int32)
    proto = xla_serialization.create_xla_tff_computation(
        computation, [], fn_type)
    fn = executor.to_representation_for_type(proto, fn_type, self._backend)
    self.assertTrue(callable(fn))
    self.assertEqual(fn(), 10)
def test_to_representation_for_type_with_2xint32_to_int32_comp(self):
    """An XLA computation summing two int32 inputs is callable on a Struct."""
    xla_builder = xla_client.XlaBuilder('comp')
    int32_zero = np.array(0, dtype=np.int32)
    pair_shape = xla_client.shape_from_pyval((int32_zero, int32_zero))
    arg = xla_client.ops.Parameter(xla_builder, 0, pair_shape)
    first = xla_client.ops.GetTupleElement(arg, 0)
    second = xla_client.ops.GetTupleElement(arg, 1)
    xla_client.ops.Add(first, second)
    computation = xla_builder.build()
    fn_type = computation_types.FunctionType((np.int32, np.int32), np.int32)
    proto = xla_serialization.create_xla_tff_computation(
        computation, [0, 1], fn_type)
    fn = executor.to_representation_for_type(proto, fn_type, self._backend)
    self.assertTrue(callable(fn))
    # The two inputs are passed as an unnamed two-element Struct.
    pair = structure.Struct([(None, np.int32(20)), (None, np.int32(30))])
    self.assertEqual(fn(pair), 50)
def test_to_representation_for_type_with_noarg_to_2xint32_comp(self):
    """A no-argument XLA computation yielding a named int32 pair is callable.

    The TFF function type has no parameter, so no parameter tensor indexes
    are passed to `create_xla_tff_computation`.
    """
    builder = xla_client.XlaBuilder('comp')
    # The XLA computation still declares a parameter, shaped as an empty tuple.
    xla_client.ops.Parameter(builder, 0, xla_client.shape_from_pyval(tuple()))
    xla_client.ops.Tuple(builder, [
        xla_client.ops.Constant(builder, np.int32(10)),
        xla_client.ops.Constant(builder, np.int32(20)),
    ])
    xla_comp = builder.build()
    comp_type = computation_types.FunctionType(
        None, computation_types.StructType([('a', np.int32), ('b', np.int32)]))
    # Empty index list: there is no TFF parameter to bind. The previous
    # `[0, 1]` looked copied from a two-argument test and named tensors that
    # do not exist here.
    comp_pb = xla_serialization.create_xla_tff_computation(
        xla_comp, [], comp_type)
    rep = executor.to_representation_for_type(comp_pb, comp_type, self._backend)
    self.assertTrue(callable(rep))
    result = rep()
    self.assertEqual(str(result), '<a=10,b=20>')
def test_to_representation_for_type_with_2xint32_list_constant(self):
    """A two-element list constant converts to an unnamed Struct."""
    pair_type = (np.int32, np.int32)
    converted = executor.to_representation_for_type([10, 20], pair_type)
    self.assertIsInstance(converted, structure.Struct)
    self.assertEqual(str(converted), '<10,20>')
def test_to_representation_for_type_with_int32_constant(self):
    """A scalar int32 constant converts to a value equal to the input."""
    self.assertEqual(executor.to_representation_for_type(10, np.int32), 10)