Example #1
0
 def test_to_representation_for_type_with_nested_odict_constant(self):
     """A nested `OrderedDict` constant converts to a `Struct` with named fields."""
     odict_value = collections.OrderedDict([
         ('a', 10),
         ('b', [20, 30]),
     ])
     odict_type = collections.OrderedDict([
         ('a', np.int32),
         ('b', [np.int32, np.int32]),
     ])
     representation = executor.to_representation_for_type(
         odict_value, odict_type)
     self.assertIsInstance(representation, structure.Struct)
     # The inner list is rendered as an unnamed nested struct.
     self.assertEqual(str(representation), '<a=10,b=<20,30>>')
Example #2
0
 def test_to_representation_for_type_with_noarg_to_int32_comp(self):
     """A no-argument XLA computation returning int32 10 becomes a callable."""
     xla_builder = xla_client.XlaBuilder('comp')
     empty_shape = xla_client.shape_from_pyval(tuple())
     # The XLA computation still declares parameter 0 with an empty-tuple
     # shape even though the TFF type below has no parameter (presumably the
     # executor's calling convention expects this slot -- TODO confirm).
     xla_client.ops.Parameter(xla_builder, 0, empty_shape)
     xla_client.ops.Constant(xla_builder, np.int32(10))
     computation = xla_builder.build()

     fn_type = computation_types.FunctionType(None, np.int32)
     # No parameter tensors to bind, hence the empty index list.
     serialized = xla_serialization.create_xla_tff_computation(
         computation, [], fn_type)
     fn = executor.to_representation_for_type(
         serialized, fn_type, self._backend)

     self.assertTrue(callable(fn))
     self.assertEqual(fn(), 10)
Example #3
0
 def test_to_representation_for_type_with_2xint32_to_int32_comp(self):
     """An XLA computation adding two int32 inputs becomes a callable."""
     xla_builder = xla_client.XlaBuilder('comp')
     int32_zero = np.array(0, dtype=np.int32)
     # The parameter shape is a 2-tuple of int32 scalars; the zero values
     # only serve to describe the shape.
     pair_shape = xla_client.shape_from_pyval(tuple([int32_zero] * 2))
     arg_tuple = xla_client.ops.Parameter(xla_builder, 0, pair_shape)
     lhs = xla_client.ops.GetTupleElement(arg_tuple, 0)
     rhs = xla_client.ops.GetTupleElement(arg_tuple, 1)
     xla_client.ops.Add(lhs, rhs)
     computation = xla_builder.build()

     fn_type = computation_types.FunctionType((np.int32, np.int32),
                                              np.int32)
     # Bind the two flattened parameter tensors to XLA tensor indexes 0 and 1.
     serialized = xla_serialization.create_xla_tff_computation(
         computation, [0, 1], fn_type)
     fn = executor.to_representation_for_type(
         serialized, fn_type, self._backend)

     self.assertTrue(callable(fn))
     args = structure.Struct([(None, np.int32(20)), (None, np.int32(30))])
     self.assertEqual(fn(args), 50)
Example #4
0
 def test_to_representation_for_type_with_noarg_to_2xint32_comp(self):
     """A no-argument XLA computation returning a named int32 pair is callable.

     The computation returns the tuple (10, 20), typed as the named struct
     `<a=int32,b=int32>`, so invoking the representation yields `<a=10,b=20>`.
     """
     builder = xla_client.XlaBuilder('comp')
     xla_client.ops.Parameter(builder, 0,
                              xla_client.shape_from_pyval(tuple()))
     xla_client.ops.Tuple(builder, [
         xla_client.ops.Constant(builder, np.int32(10)),
         xla_client.ops.Constant(builder, np.int32(20))
     ])
     xla_comp = builder.build()
     comp_type = computation_types.FunctionType(
         None,
         computation_types.StructType([('a', np.int32), ('b', np.int32)]))
     # `tensor_indexes` binds the flattened *parameter* tensors. This
     # computation has no parameter, so no indexes are passed. This matches
     # the other no-arg test; `[0, 1]` here was a copy-paste leftover from
     # the two-parameter test.
     comp_pb = xla_serialization.create_xla_tff_computation(
         xla_comp, [], comp_type)
     rep = executor.to_representation_for_type(comp_pb, comp_type,
                                               self._backend)
     self.assertTrue(callable(rep))
     result = rep()
     self.assertEqual(str(result), '<a=10,b=20>')
Example #5
0
 def test_to_representation_for_type_with_2xint32_list_constant(self):
     """A two-element list constant converts to an unnamed two-field `Struct`."""
     pair_type = (np.int32, np.int32)
     representation = executor.to_representation_for_type([10, 20], pair_type)
     self.assertIsInstance(representation, structure.Struct)
     self.assertEqual(str(representation), '<10,20>')
Example #6
0
 def test_to_representation_for_type_with_int32_constant(self):
     """A scalar int32 constant converts to a value equal to the input."""
     expected = 10
     self.assertEqual(
         executor.to_representation_for_type(expected, np.int32), expected)