Пример #1
0
 def verify(values, name=None, indices=None):
     """Round-trip a tensor through its protobuf form and check the result.

     Builds a ``Tensor`` from the arguments, serializes it into an
     ``elasticdl_pb2.Tensor``, deserializes that protobuf into a fresh
     ``Tensor`` and asserts that the fresh tensor matches the inputs.

     NOTE: ``self`` is not a parameter of this helper; it is expected to be
     provided by the enclosing test method's scope.

     Args:
         values: Values handed to ``Tensor``; compared with the deserialized
             tensor's ``values``.
         name: Optional tensor name; compared only when truthy.
         indices: Optional indices; compared only when not ``None``.
     """
     tensor = Tensor(values, indices, name)
     pb = elasticdl_pb2.Tensor()
     serialize_tensor(tensor, pb)
     tensor_new = Tensor()
     deserialize_tensor_pb(pb, tensor_new)
     np.testing.assert_array_equal(values, tensor_new.values)
     if indices is not None:
         np.testing.assert_array_equal(indices, tensor_new.indices)
     if name:
         # Check the deserialized tensor: `tensor` was built from `name`
         # directly, so comparing against it could never fail.
         self.assertEqual(name, tensor_new.name)
Пример #2
0
 def _restore_params_from_pb(self, tensors_pb):
     """Load model parameters from an iterable of tensor protobufs.

     A protobuf that carries indices is treated as embedding-parameter data
     and is written into the existing ``self.embedding_params[name]`` entry
     via its ``set`` method. Any other protobuf is converted to an ndarray
     and stored as a new trainable ``tf.Variable`` under
     ``self.non_embedding_params[name]``.

     Args:
         tensors_pb: Iterable of tensor protobufs exposing ``name`` and
             ``indices``.
     """
     for tensor_pb in tensors_pb:
         # Read the name up front, before the protobuf is handed to a helper.
         param_name = tensor_pb.name

         if tensor_pb.indices:
             # Only protobufs of embedding parameters carry indices.
             restored = Tensor()
             deserialize_tensor_pb(tensor_pb, restored)
             table = self.embedding_params[param_name]
             table.set(restored.indices, restored.values)
             continue

         # The variable is keyed by the protobuf name instead of being given
         # a name itself: `tf.Variable` rewrites a name such as "somename"
         # to "somename:0", so `tf.Variable.name` must not be relied on on
         # the PS side.
         dense_values = tensor_pb_to_ndarray(tensor_pb)
         self.non_embedding_params[param_name] = tf.Variable(
             initial_value=dense_values, trainable=True
         )
Пример #3
0
    def test_deserialize_tensor_pb(self):
        """Exercise ``deserialize_tensor_pb`` on invalid and valid protobufs.

        Covers, in order: a protobuf with no dimensions, an empty array, an
        invalid dtype, a shape containing a zero dimension, a content size
        that disagrees with the shape, and every 2-D factorization of 12
        elements both with and without indices.
        """
        pb = elasticdl_pb2.Tensor()
        tensor = Tensor()
        # No dim defined, should raise.
        self.assertRaises(ValueError, deserialize_tensor_pb, pb, tensor)

        # Empty array, should be ok.
        pb.dim.append(0)
        pb.content = b""
        pb.dtype = tensor_dtype_pb2.DT_FLOAT32
        deserialize_tensor_pb(pb, tensor)
        np.testing.assert_array_equal(np.array([], dtype=np.float32),
                                      tensor.values)

        # Wrong type, should raise
        del pb.dim[:]
        pb.dim.append(0)
        pb.content = b""
        pb.dtype = tensor_dtype_pb2.DT_INVALID
        self.assertRaises(ValueError, deserialize_tensor_pb, pb, tensor)

        # Pathological case, one of the dimensions is 0.
        del pb.dim[:]
        pb.dim.extend([2, 0, 1, 9])
        pb.content = b""
        pb.dtype = tensor_dtype_pb2.DT_FLOAT32
        deserialize_tensor_pb(pb, tensor)
        np.testing.assert_array_equal(
            np.ndarray(shape=[2, 0, 1, 9], dtype=np.float32), tensor.values)

        # Wrong content size, should raise
        del pb.dim[:]
        pb.dim.append(11)
        pb.content = b"\0" * (4 * 12)
        pb.dtype = tensor_dtype_pb2.DT_FLOAT32
        self.assertRaises(ValueError, deserialize_tensor_pb, pb, tensor)

        # Compatible dimensions, should be ok.
        for m in (1, 2, 3, 4, 6, 12):
            for with_indices in [True, False]:
                # Reset every repeated field that this iteration populates so
                # that indices from an earlier pass cannot leak into this one.
                del pb.dim[:]
                del pb.indices[:]
                pb.content = b"\0" * (4 * 12)
                pb.dim.extend([m, 12 // m])
                if with_indices:
                    pb.indices.extend([0] * m)
                pb.dtype = tensor_dtype_pb2.DT_FLOAT32
                deserialize_tensor_pb(pb, tensor)
                self.assertEqual((m, 12 // m), tensor.values.shape)
                self.assertIsInstance(tensor.values, np.ndarray)
                if with_indices:
                    # Indices were supplied, so they must come back as an
                    # ndarray; a dropped (None) value fails here.
                    self.assertIsInstance(tensor.indices, np.ndarray)
                elif tensor.indices is not None:
                    self.assertIsInstance(tensor.indices, np.ndarray)