def testErrorIndicesMultiDimensional(self):
   indices = [
       constant_op.constant([0, 4, 7]), constant_op.constant([[1, 6, 2, 3, 5]])
   ]
   data = [
       constant_op.constant([[0, 40, 70]]),
       constant_op.constant([10, 60, 20, 30, 50])
   ]
   with self.assertRaises(ValueError):
     data_flow_ops.dynamic_stitch(indices, data)
 def testErrorDataDimSizeMismatch(self):
   indices = [
       constant_op.constant([0, 4, 5]), constant_op.constant([1, 6, 2, 3])
   ]
   data = [
       constant_op.constant([[0], [40], [70]]),
       constant_op.constant([[10, 11], [60, 61], [20, 21], [30, 31]])
   ]
   with self.assertRaises(ValueError):
     data_flow_ops.dynamic_stitch(indices, data)
 def testErrorDataAndIndicesSizeMismatch(self):
   indices = [
       constant_op.constant([0, 4, 7]), constant_op.constant([1, 6, 2, 3, 5])
   ]
   data = [
       constant_op.constant([0, 40, 70]),
       constant_op.constant([10, 60, 20, 30])
   ]
   with self.assertRaises(ValueError):
     data_flow_ops.dynamic_stitch(indices, data)
 def testErrorDataDimSizeMismatch(self):
     indices = [
         constant_op.constant([0, 4, 5]),
         constant_op.constant([1, 6, 2, 3])
     ]
     data = [
         constant_op.constant([[0], [40], [70]]),
         constant_op.constant([[10, 11], [60, 61], [20, 21], [30, 31]])
     ]
     with self.assertRaises(ValueError):
         data_flow_ops.dynamic_stitch(indices, data)
 def testErrorDataAndIndicesSizeMismatch(self):
     indices = [
         constant_op.constant([0, 4, 7]),
         constant_op.constant([1, 6, 2, 3, 5])
     ]
     data = [
         constant_op.constant([0, 40, 70]),
         constant_op.constant([10, 60, 20, 30])
     ]
     with self.assertRaises(ValueError):
         data_flow_ops.dynamic_stitch(indices, data)
 def testErrorIndicesMultiDimensional(self):
     indices = [
         constant_op.constant([0, 4, 7]),
         constant_op.constant([[1, 6, 2, 3, 5]])
     ]
     data = [
         constant_op.constant([[0, 40, 70]]),
         constant_op.constant([10, 60, 20, 30, 50])
     ]
     with self.assertRaises(ValueError):
         data_flow_ops.dynamic_stitch(indices, data)
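These error tests pin down dynamic_stitch's contract: indices[i].shape must be a prefix of data[i].shape, and every data input must agree on the trailing per-slice shape. For contrast, a minimal happy-path sketch using the public tf.dynamic_stitch API (assumed here to behave like the internal data_flow_ops.dynamic_stitch the tests call; eager execution assumed):

import tensorflow as tf

indices = [tf.constant([0, 4, 7]), tf.constant([1, 6, 2, 3, 5])]
data = [tf.constant([0, 40, 70]), tf.constant([10, 60, 20, 30, 50])]
# Each value data[i][j] lands at output position indices[i][j], so the
# merged vector comes out ordered by index: [0, 10, 20, ..., 70].
merged = tf.dynamic_stitch(indices, data)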
 def testHigherRankGPU(self):
     with self.test_session() as sess:
         indices = [
             constant_op.constant(6),
             constant_op.constant([4, 1]),
             constant_op.constant([[5, 2], [0, 3]])
         ]
         data = [
             constant_op.constant([61, 62], dtype=dtypes.float32),
             constant_op.constant([[41, 42], [11, 12]],
                                  dtype=dtypes.float32),
             constant_op.constant(
                 [[[51, 52], [21, 22]], [[1, 2], [31, 32]]],
                 dtype=dtypes.float32)
         ]
         stitched_t = data_flow_ops.dynamic_stitch(indices, data)
         stitched_val = stitched_t.eval()
         correct = 10 * np.arange(7)[:, None] + [1.0, 2.0]
         self.assertAllEqual(correct, stitched_val)
         self.assertEqual([7, 2], stitched_t.get_shape().as_list())
         # Test gradients
         stitched_grad = 7 * stitched_val
         grads = gradients_impl.gradients(stitched_t, indices + data,
                                          stitched_grad)
         self.assertEqual(grads[:3],
                          [None] * 3)  # Indices have no gradients
         for datum, grad in zip(data, sess.run(grads[3:])):
             self.assertAllEqual(7.0 * datum.eval(), grad)
Example #8
    def DynamicStitchGrads(op, grad):
        num_values = len(op.inputs) // 2
        indices_grad = [None] * num_values

        def AsInt32(x):
            return (x if op.inputs[0].dtype == dtypes.int32 else
                    math_ops.cast(x, dtypes.int32))

        idxs = [AsInt32(array_ops.reshape(op.inputs[i], (-1,)))
                for i in range(num_values)]
        if isinstance(grad, ops.IndexedSlices):
            output_shape = array_ops.shape(op.outputs[0])
            output_rows = output_shape[0]
            grad = math_ops.unsorted_segment_sum(grad.values, grad.indices,
                                                 output_rows)

        values_grad = []
        zeros = array_ops.zeros_like(grad)
        idx_zeros = [zeros[:array_ops.shape(x)[0]] for x in idxs]
        grad_range = math_ops.range(array_ops.shape(grad)[0])
        for i in range(num_values):
            if i == num_values - 1:
                v_grad = grad
            else:
                v_grad = data_flow_ops.dynamic_stitch(
                    [grad_range] + idxs[i + 1:], [grad] + idx_zeros[i + 1:])
            v_grad = array_ops.gather(v_grad, AsInt32(op.inputs[i]))
            values_grad += [v_grad]

        return indices_grad + values_grad
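A hedged reading of DynamicStitchGrads above: the indices inputs get no gradient, and each data input only receives gradient at output positions it actually wins, because positions claimed by a later input are stitched over with zeros before the gather. The NumPy sketch below mirrors that intent; it is an illustration of the logic, not the gradient TensorFlow registers.

import numpy as np

def stitch_value_grads(idx_lists, upstream_grad):
    # data[i] receives upstream gradient only where no later input
    # overwrites the same output position.
    grads = []
    for i, idx in enumerate(idx_lists):
        masked = upstream_grad.copy()
        for later_idx in idx_lists[i + 1:]:
            masked[later_idx] = 0.0
        grads.append(masked[idx])
    return grads

# Inputs 0 and 1 both write position 1; input 1 wins, so input 0 gets no
# gradient there.
print(stitch_value_grads([np.array([0, 1]), np.array([1, 2])],
                         np.array([10.0, 20.0, 30.0])))
# [array([10., 0.]), array([20., 30.])]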
  def lookup(self, keys, name=None):
    if keys.dtype != self._key_dtype:
      raise TypeError('Signature mismatch. Keys must be dtype %s, got %s.' %
                      (self._key_dtype, keys.dtype))
    self._check_keys(keys)
    num_shards = self._num_shards
    if num_shards == 1:
      return self._table_shards[0].lookup(keys, name=name)

    shard_indices = self._shard_indices(keys)
    # TODO(andreasst): support 'keys' that are not vectors
    key_shards = data_flow_ops.dynamic_partition(keys, shard_indices,
                                                 num_shards)
    value_shards = [
        self._table_shards[i].lookup(key_shards[i], name=name)
        for i in range(num_shards)
    ]

    num_keys = keys.get_shape().dims[0]
    original_indices = math_ops.range(num_keys)
    partitioned_indices = data_flow_ops.dynamic_partition(original_indices,
                                                          shard_indices,
                                                          num_shards)
    result = data_flow_ops.dynamic_stitch(partitioned_indices, value_shards)
    result.set_shape(
        tensor_shape.TensorShape([num_keys]).concatenate(self._value_shape))
    return result
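The pattern in lookup is: partition the keys by shard, run each shard's lookup, then dynamic_stitch the per-shard values back into the original key order using the partitioned positions. A self-contained sketch of the same round trip, with a toy modulo shard rule and a squaring stand-in for the per-shard table lookup (both are assumptions for illustration, not the sharded-table API):

import tensorflow as tf

keys = tf.constant([7, 2, 9, 4])
num_shards = 2
shard_indices = tf.cast(keys % num_shards, tf.int32)  # toy shard rule
key_shards = tf.dynamic_partition(keys, shard_indices, num_shards)
value_shards = [k * k for k in key_shards]  # stand-in for a table lookup
original_indices = tf.range(tf.size(keys))
partitioned_indices = tf.dynamic_partition(
    original_indices, shard_indices, num_shards)
values = tf.dynamic_stitch(partitioned_indices, value_shards)
# values lines up with keys again: [49, 4, 81, 16]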
Example #10
    def DynamicStitchGrads(op, grad):
        num_values = len(op.inputs) // 2
        indices_grad = [None] * num_values

        def AsInt32(x):
            return (x if op.inputs[0].dtype == dtypes.int32 else math_ops.cast(
                x, dtypes.int32))

        idxs = [
            AsInt32(array_ops.reshape(op.inputs[i], (-1, )))
            for i in range(num_values)
        ]
        if isinstance(grad, ops.IndexedSlices):
            output_shape = array_ops.shape(op.outputs[0])
            output_rows = output_shape[0]
            grad = math_ops.unsorted_segment_sum(grad.values, grad.indices,
                                                 output_rows)

        values_grad = []
        zeros = array_ops.zeros_like(grad)
        idx_zeros = [zeros[:array_ops.shape(x)[0]] for x in idxs]
        grad_range = math_ops.range(array_ops.shape(grad)[0])
        for i in range(num_values):
            if i == num_values - 1:
                v_grad = grad
            else:
                v_grad = data_flow_ops.dynamic_stitch(
                    [grad_range] + idxs[i + 1:], [grad] + idx_zeros[i + 1:])
            v_grad = array_ops.gather(v_grad, AsInt32(op.inputs[i]))
            values_grad += [v_grad]

        return indices_grad + values_grad
 def testHigherRankGPU(self):
   with self.cached_session() as sess:
     indices = [
         constant_op.constant(6),
         constant_op.constant([4, 1]),
         constant_op.constant([[5, 2], [0, 3]])
     ]
     data = [
         constant_op.constant([61, 62], dtype=dtypes.float32),
         constant_op.constant([[41, 42], [11, 12]], dtype=dtypes.float32),
         constant_op.constant(
             [[[51, 52], [21, 22]], [[1, 2], [31, 32]]], dtype=dtypes.float32)
     ]
     stitched_t = data_flow_ops.dynamic_stitch(indices, data)
     stitched_val = self.evaluate(stitched_t)
     correct = 10 * np.arange(7)[:, None] + [1.0, 2.0]
     self.assertAllEqual(correct, stitched_val)
     self.assertEqual([7, 2], stitched_t.get_shape().as_list())
     # Test gradients
     stitched_grad = 7 * stitched_val
     grads = gradients_impl.gradients(stitched_t, indices + data,
                                      stitched_grad)
     self.assertEqual(grads[:3], [None] * 3)  # Indices have no gradients
     for datum, grad in zip(data, sess.run(grads[3:])):
       self.assertAllEqual(7.0 * self.evaluate(datum), grad)
    def lookup(self, keys, name=None):
        """Looks up `keys` in a table, outputs the corresponding values."""
        if keys.dtype.base_dtype != self._key_dtype:
            raise TypeError(
                'Signature mismatch. Keys must be dtype %s, got %s.' %
                (self._key_dtype, keys.dtype))
        self._check_keys(keys)
        num_shards = self._num_shards
        if num_shards == 1:
            return self._table_shards[0].lookup(keys, name=name)

        shard_indices = self._shard_indices(keys)
        # TODO(andreasst): support 'keys' that are not vectors
        key_shards = data_flow_ops.dynamic_partition(keys, shard_indices,
                                                     num_shards)
        value_shards = [
            self._table_shards[i].lookup(key_shards[i], name=name)
            for i in range(num_shards)
        ]

        num_keys = keys.get_shape().dims[0]
        original_indices = math_ops.range(num_keys)
        partitioned_indices = data_flow_ops.dynamic_partition(
            original_indices, shard_indices, num_shards)
        result = data_flow_ops.dynamic_stitch(partitioned_indices,
                                              value_shards)
        result.set_shape(
            tensor_shape.TensorShape([num_keys]).concatenate(
                self._value_shape))
        return result
Example #13
    def lookup(self, keys, name=None):
        if keys.dtype != self._key_dtype:
            raise TypeError(
                'Signature mismatch. Keys must be dtype %s, got %s.' %
                (self._key_dtype, keys.dtype))
        num_shards = self._num_shards
        if num_shards == 1:
            return self._table_shards[0].lookup(keys, name=name)

        shard_indices = self._shard_indices(keys)
        # TODO(andreasst): support 'keys' that are not vectors
        key_shards = data_flow_ops.dynamic_partition(keys, shard_indices,
                                                     num_shards)
        value_shards = [
            self._table_shards[i].lookup(key_shards[i], name=name)
            for i in range(num_shards)
        ]

        original_indices = math_ops.range(array_ops.size(keys))
        partitioned_indices = data_flow_ops.dynamic_partition(
            original_indices, shard_indices, num_shards)
        result = data_flow_ops.dynamic_stitch(partitioned_indices,
                                              value_shards)
        result.set_shape(keys.get_shape().concatenate(self._value_shape))
        return result
Example #14
 def testScalarGPU(self):
   indices = [constant_op.constant(0), constant_op.constant(1)]
   data = [constant_op.constant(40.0), constant_op.constant(60.0)]
   for step in -1, 1:
     stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data)
     stitched_val = self.evaluate(stitched_t)
     self.assertAllEqual([40.0, 60.0][::step], stitched_val)
     # Dimension 0 is max(flatten(indices))+1.
     self.assertEqual([2], stitched_t.get_shape().as_list())
Example #15
 def testPinRequiredOpsOnCPU(self):
     with ops.Graph().as_default() as g, g.device(graph_util.pin_variables_on_cpu):
         const_a = constant_op.constant(5.0)
         const_b = constant_op.constant(10.0)
         add_c = const_a + const_b
          var_v = state_ops.variable_op([], dtype=dtypes.float32)
         assign_c_to_v = state_ops.assign(var_v, add_c)
         dynamic_stitch_int_result = data_flow_ops.dynamic_stitch([[0, 1, 2], [2, 3]], [[12, 23, 34], [1, 2]])
         dynamic_stitch_float_result = data_flow_ops.dynamic_stitch(
             [[0, 1, 2], [2, 3]], [[12.0, 23.0, 34.0], [1.0, 2.0]]
         )
          # Non-variable ops should not specify a device
         self.assertEqual(const_a.device, None)
         self.assertEqual(const_b.device, None)
         self.assertEqual(add_c.device, None)
         # Variable ops specify a device
         self.assertEqual(var_v.device, "/device:CPU:0")
         self.assertEqual(assign_c_to_v.device, "/device:CPU:0")
 def testScalarGPU(self):
   indices = [constant_op.constant(0), constant_op.constant(1)]
   data = [constant_op.constant(40.0), constant_op.constant(60.0)]
   for step in -1, 1:
     stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data)
     stitched_val = self.evaluate(stitched_t)
     self.assertAllEqual([40.0, 60.0][::step], stitched_val)
     # Dimension 0 is max(flatten(indices))+1.
     self.assertEqual([2], stitched_t.get_shape().as_list())
Example #17
    def ScatterUpdateGrads(op, grad):
        _, indices, updates = op.inputs

        grad_range = math_ops.range(array_ops.shape(grad)[0])
        var_grad = data_flow_ops.dynamic_stitch(
            [grad_range, indices], [grad, array_ops.zeros_like(updates)])

        updates_grad = array_ops.gather(grad, indices)

        return var_grad, None, updates_grad
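In ScatterUpdateGrads, the rows of the variable that scatter_update overwrote get no gradient (they are stitched over with zeros), while the upstream gradient for those rows is gathered into the gradient for updates. A NumPy sketch of that behaviour; again a reading of the code above, not the registered gradient:

import numpy as np

def scatter_update_value_grads(indices, upstream_grad):
    var_grad = upstream_grad.copy()
    var_grad[indices] = 0.0                # overwritten rows get no gradient
    updates_grad = upstream_grad[indices]  # their gradient flows to updates
    return var_grad, updates_grad

g = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
var_grad, updates_grad = scatter_update_value_grads(np.array([1]), g)
# var_grad has row 1 zeroed; updates_grad equals g[[1]]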
Example #18
 def testPinRequiredOpsOnCPU(self):
     with ops.Graph().as_default() as g, g.device(
             graph_util.pin_variables_on_cpu):
         const_a = constant_op.constant(5.0)
         const_b = constant_op.constant(10.0)
         add_c = const_a + const_b
         var_v = state_ops.variable_op([], dtype=dtypes.float32)
         assign_c_to_v = state_ops.assign(var_v, add_c)
         dynamic_stitch_int_result = data_flow_ops.dynamic_stitch(
             [[0, 1, 2], [2, 3]], [[12, 23, 34], [1, 2]])
         dynamic_stitch_float_result = data_flow_ops.dynamic_stitch(
             [[0, 1, 2], [2, 3]], [[12.0, 23.0, 34.0], [1.0, 2.0]])
          # Non-variable ops should not specify a device
         self.assertDeviceEqual(const_a.device, None)
         self.assertDeviceEqual(const_b.device, None)
         self.assertDeviceEqual(add_c.device, None)
         # Variable ops specify a device
         self.assertDeviceEqual(var_v.device, "/device:CPU:0")
         self.assertDeviceEqual(assign_c_to_v.device, "/device:CPU:0")
Example #19
 def testSumGradArgs(self):
   with self.test_session(use_gpu=False):
     indices = [
         ops.convert_to_tensor([0, 1, 2, 3]), ops.convert_to_tensor([2, 3])
     ]
     values = [
         ops.convert_to_tensor([2, 3, 5, 7]), ops.convert_to_tensor([1, 1])
     ]
     self.assertAllEqual(
         data_flow_ops.dynamic_stitch(indices, values).eval(), [2, 3, 1, 1])
Example #20
 def testInt32Gpu(self):
   with self.test_session(use_gpu=True):
     indices = [
         ops.convert_to_tensor([0, 1, 2]), ops.convert_to_tensor([2, 3])
     ]
     values = [
         ops.convert_to_tensor([12, 23, 34]), ops.convert_to_tensor([1, 2])
     ]
     self.assertAllEqual(
         data_flow_ops.dynamic_stitch(indices, values).eval(), [12, 23, 1, 2])
Example #21
 def testPinToCpu(self):
   with ops.Graph().as_default() as g, g.device(graph_util.pin_to_cpu):
     const_a = constant_op.constant(5.0)
     const_b = constant_op.constant(10.0)
     add_c = const_a + const_b
      var_v = state_ops.variable_op([], dtype=dtypes.float32)
     assign_c_to_v = state_ops.assign(var_v, add_c)
     const_string = constant_op.constant("on a cpu")
     dynamic_stitch_int_result = data_flow_ops.dynamic_stitch(
         [[0, 1, 2], [2, 3]], [[12, 23, 34], [1, 2]])
     dynamic_stitch_float_result = data_flow_ops.dynamic_stitch(
         [[0, 1, 2], [2, 3]], [[12.0, 23.0, 34.0], [1.0, 2.0]])
   self.assertEqual(const_a.device, "/device:CPU:0")
   self.assertEqual(const_b.device, "/device:CPU:0")
   self.assertEqual(add_c.device, "/device:CPU:0")
   self.assertEqual(var_v.device, "/device:CPU:0")
   self.assertEqual(assign_c_to_v.device, "/device:CPU:0")
   self.assertEqual(const_string.device, "/device:CPU:0")
   self.assertEqual(dynamic_stitch_int_result.device, "/device:CPU:0")
   self.assertEqual(dynamic_stitch_float_result.device, "/device:CPU:0")
Example #22
 def testStitchOrder(self):
   with self.cached_session():
     indices = []
     np_values = []
     values = []
     for _ in range(10):
       indices.extend([ops.convert_to_tensor(np.arange(100).astype(np.int32))])
       np_values.extend([np.random.uniform(size=100)])
       values.extend([ops.convert_to_tensor(np_values[-1])])
     stitched = data_flow_ops.dynamic_stitch(indices, values).eval()
   self.assertAllEqual(np_values[-1], stitched)
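testStitchOrder relies on the documented merge order of dynamic_stitch: when the same index appears in several inputs, the value from the later input wins, which is why only np_values[-1] survives the stitch. A minimal illustration with the public API (eager execution assumed):

import tensorflow as tf

indices = [tf.constant([0, 1]), tf.constant([0, 1])]
data = [tf.constant([1.0, 1.0]), tf.constant([2.0, 2.0])]
print(tf.dynamic_stitch(indices, data))  # [2.0, 2.0] -- the last writer wins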
Example #23
 def testPinToCpu(self):
   with ops.Graph().as_default() as g, g.device(graph_util.pin_to_cpu):
     const_a = constant_op.constant(5.0)
     const_b = constant_op.constant(10.0)
     add_c = const_a + const_b
     var_v = state_ops.variable_op([], dtype=dtypes.float32)
     assign_c_to_v = state_ops.assign(var_v, add_c)
     const_string = constant_op.constant("on a cpu")
     dynamic_stitch_int_result = data_flow_ops.dynamic_stitch(
         [[0, 1, 2], [2, 3]], [[12, 23, 34], [1, 2]])
     dynamic_stitch_float_result = data_flow_ops.dynamic_stitch(
         [[0, 1, 2], [2, 3]], [[12.0, 23.0, 34.0], [1.0, 2.0]])
   self.assertDeviceEqual(const_a.device, "/device:CPU:0")
   self.assertDeviceEqual(const_b.device, "/device:CPU:0")
   self.assertDeviceEqual(add_c.device, "/device:CPU:0")
   self.assertDeviceEqual(var_v.device, "/device:CPU:0")
   self.assertDeviceEqual(assign_c_to_v.device, "/device:CPU:0")
   self.assertDeviceEqual(const_string.device, "/device:CPU:0")
   self.assertDeviceEqual(dynamic_stitch_int_result.device, "/device:CPU:0")
   self.assertDeviceEqual(dynamic_stitch_float_result.device, "/device:CPU:0")
Example #24
 def testStitchOrder(self):
   with self.test_session():
     indices = []
     np_values = []
     values = []
     for _ in range(10):
       indices.extend([ops.convert_to_tensor(np.arange(100).astype(np.int32))])
       np_values.extend([np.random.uniform(size=100)])
       values.extend([ops.convert_to_tensor(np_values[-1])])
     stitched = data_flow_ops.dynamic_stitch(indices, values).eval()
   self.assertAllEqual(np_values[-1], stitched)
 def testOneListOneDimensional(self):
     with self.test_session():
         indices = [constant_op.constant([1, 6, 2, 3, 5, 0, 4, 7])]
         data = [constant_op.constant([10, 60, 20, 30, 50, 0, 40, 70])]
         stitched_t = data_flow_ops.dynamic_stitch(indices, data)
         stitched_val = stitched_t.eval()
         self.assertAllEqual([0, 10, 20, 30, 40, 50, 60, 70], stitched_val)
         # Dimension 0 is determined by the max index in indices, so we
         # can only infer that the output is a vector of some unknown
         # length.
         self.assertEqual([None], stitched_t.get_shape().as_list())
 def testOneListOneDimensional(self):
   with self.test_session():
     indices = [constant_op.constant([1, 6, 2, 3, 5, 0, 4, 7])]
     data = [constant_op.constant([10, 60, 20, 30, 50, 0, 40, 70])]
     stitched_t = data_flow_ops.dynamic_stitch(indices, data)
     stitched_val = stitched_t.eval()
     self.assertAllEqual([0, 10, 20, 30, 40, 50, 60, 70], stitched_val)
     # Dimension 0 is determined by the max index in indices, so we
     # can only infer that the output is a vector of some unknown
     # length.
     self.assertEqual([None], stitched_t.get_shape().as_list())
 def testScalar(self):
   with self.test_session():
     indices = [constant_op.constant(0), constant_op.constant(1)]
     data = [constant_op.constant(40), constant_op.constant(60)]
     for step in -1, 1:
       stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data)
       stitched_val = stitched_t.eval()
       self.assertAllEqual([40, 60][::step], stitched_val)
       # Dimension 0 is determined by the max index in indices, so we
       # can only infer that the output is a vector of some unknown
       # length.
       self.assertEqual([None], stitched_t.get_shape().as_list())
Example #28
 def testCint32Gpu(self):
   with self.session():
     indices = [
         ops.convert_to_tensor([0, 1, 2]),
         ops.convert_to_tensor([2, 3])
     ]
     values = [
         ops.convert_to_tensor([12, 23, 34]),
         ops.convert_to_tensor([1, 2])
     ]
     self.assertAllEqual(
         data_flow_ops.dynamic_stitch(indices, values), [12, 23, 1, 2])
Example #29
def _ReductionGradAssist(op):
    """Reduction grads have much in common, so factor the commonality out."""
    inp = op.inputs[0]  # Example:
    input_shape = array_ops.shape(inp)  # [2, 3, 5, 7]
    input_rank = array_ops.rank(inp)  # 4
    indices = op.inputs[1]  # [1, 2]
    indices_shape = array_ops.shape(indices)  # [2]
    new_output_shape = data_flow_ops.dynamic_stitch(   # [2, 1, 1, 7]
        [math_ops.range(input_rank),                   # [0, 1, 2, 3]
         indices],                                     # [1, 2]
        [input_shape,                                  # [2, 3, 5, 7]
         array_ops.fill(indices_shape, 1)])            # [1, 1]
    return inp, new_output_shape, input_shape
Example #30
 def testSumGradArgs(self):
   with self.session(use_gpu=False):
     indices = [
         ops.convert_to_tensor([0, 1, 2, 3]),
         ops.convert_to_tensor([2, 3])
     ]
     values = [
         ops.convert_to_tensor([2, 3, 5, 7]),
         ops.convert_to_tensor([1, 1])
     ]
     self.assertAllEqual(
         data_flow_ops.dynamic_stitch(indices, values).eval(), [2, 3, 1, 1])
Example #31
 def testInt32Cpu(self):
   with self.session(use_gpu=False):
     indices = [
         ops.convert_to_tensor([0, 1, 2]),
         ops.convert_to_tensor([2, 3])
     ]
     values = [
         ops.convert_to_tensor([12, 23, 34]),
         ops.convert_to_tensor([1, 2])
     ]
     self.assertAllEqual(
         data_flow_ops.dynamic_stitch(indices, values).eval(), [12, 23, 1, 2])
 def testScalar(self):
     with self.test_session():
         indices = [constant_op.constant(0), constant_op.constant(1)]
         data = [constant_op.constant(40), constant_op.constant(60)]
         for step in -1, 1:
             stitched_t = data_flow_ops.dynamic_stitch(
                 indices[::step], data)
             stitched_val = stitched_t.eval()
             self.assertAllEqual([40, 60][::step], stitched_val)
             # Dimension 0 is determined by the max index in indices, so we
             # can only infer that the output is a vector of some unknown
             # length.
             self.assertEqual([None], stitched_t.get_shape().as_list())
Example #33
def _ReductionGradAssist(op):
  """Reduction grads have much in common, so factor the commonality out."""
  inp = op.inputs[0]                                # Example:
  input_shape = array_ops.shape(inp)                # [2, 3, 5, 7]
  input_rank = array_ops.rank(inp)                  # 4
  indices = op.inputs[1]                            # [1, 2]
  indices_shape = array_ops.shape(indices)          # [2]
  new_output_shape = data_flow_ops.dynamic_stitch(  # [2, 1, 1, 7]
      [math_ops.range(input_rank),                  # [0, 1, 2, 3]
       indices],                                    # [1, 2]
      [input_shape,                                 # [2, 3, 5, 7]
       array_ops.fill(indices_shape, 1)])           # [1, 1]
  return inp, new_output_shape, input_shape
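The stitch in _ReductionGradAssist overwrites the reduced axes of the input shape with 1, producing the shape the incoming gradient is reshaped to before broadcasting. A standalone sketch of just that shape computation with public tf ops, using the example shapes from the comments above:

import tensorflow as tf

input_shape = tf.constant([2, 3, 5, 7])
reduction_axes = tf.constant([1, 2])
new_output_shape = tf.dynamic_stitch(
    [tf.range(tf.size(input_shape)), reduction_axes],
    [input_shape, tf.fill(tf.shape(reduction_axes), 1)])
# new_output_shape == [2, 1, 1, 7]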
Example #34
def _DynamicPartitionGrads(op, *grads):
  """Gradients for DynamicPartition."""
  data = op.inputs[0]
  indices = op.inputs[1]
  num_partitions = op.get_attr("num_partitions")

  prefix_shape = array_ops.shape(indices)
  original_indices = array_ops.reshape(
      math_ops.range(math_ops.reduce_prod(prefix_shape)), prefix_shape)
  partitioned_indices = data_flow_ops.dynamic_partition(
      original_indices, indices, num_partitions)
  reconstructed = data_flow_ops.dynamic_stitch(partitioned_indices, grads)
  reconstructed = array_ops.reshape(reconstructed, array_ops.shape(data))
  return [reconstructed, None]
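_DynamicPartitionGrads inverts the partition: it stitches the per-partition gradients back together, keyed by each element's original flat position, and reshapes to the data's shape. A small end-to-end sketch with all-ones upstream gradients (an assumption chosen for illustration):

import tensorflow as tf

data = tf.constant([10.0, 20.0, 30.0, 40.0])
partitions = tf.constant([0, 1, 0, 1])
parts = tf.dynamic_partition(data, partitions, 2)   # [10, 30] and [20, 40]
grads = [tf.ones_like(p) for p in parts]            # pretend upstream grads
original_indices = tf.range(tf.size(data))
partitioned_indices = tf.dynamic_partition(original_indices, partitions, 2)
data_grad = tf.dynamic_stitch(partitioned_indices, grads)
# data_grad has data's shape: [1.0, 1.0, 1.0, 1.0]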
Example #35
def _DynamicPartitionGrads(op, *grads):
    """Gradients for DynamicPartition."""
    data = op.inputs[0]
    indices = op.inputs[1]
    num_partitions = op.get_attr("num_partitions")

    prefix_shape = array_ops.shape(indices)
    original_indices = array_ops.reshape(
        math_ops.range(math_ops.reduce_prod(prefix_shape)), prefix_shape)
    partitioned_indices = data_flow_ops.dynamic_partition(
        original_indices, indices, num_partitions)
    reconstructed = data_flow_ops.dynamic_stitch(partitioned_indices, grads)
    reconstructed = array_ops.reshape(reconstructed, array_ops.shape(data))
    return [reconstructed, None]
Example #36
    def sample(self, n_samples):
        n_samples = int(n_samples)

        if self.is_train:
            cat_probs = self._cat(n_samples)  # n x c
            agg_mu = tf.reduce_sum(tf.expand_dims(cat_probs, 2) * self._mu,
                                   axis=1)  # n x d
            agg_var = tf.reduce_sum(tf.expand_dims(cat_probs, 2) * self._var,
                                    axis=1)  # n x d

            raw = tf.random_normal([n_samples, self.dim])
            ret = agg_mu + tf.sqrt(agg_var) * raw  # n x d

            #cat_probs = self._cat(n_samples)  # n x c
            #samples_class = [None for _ in range(self.n_components)]
            #for c in range(self.n_components):
            #    raw = tf.random_normal([n_samples, self.dim])
            #    samples_class_c = self._mu[c] + raw * tf.sqrt(self._var[c]) #tf.matmul(raw, tf.transpose(self._scale[c]))
            #    samples_class[c] = samples_class_c
            #samples_class = tf.stack(samples_class) # c x n x d
            #ret = tf.reduce_sum(tf.expand_dims(cat_probs, 2) * tf.transpose(samples_class, [1,0,2]), axis=1)
        else:
            cat_samples = self._cat.sample(n_samples)  # n x 1

            samples_raw_indices = array_ops.reshape(
                math_ops.range(0, n_samples),
                cat_samples.get_shape().as_list())

            partitioned_samples_indices = data_flow_ops.dynamic_partition(
                data=samples_raw_indices,
                partitions=cat_samples,
                num_partitions=self.n_components)

            samples_class = [None for _ in range(self.n_components)]
            for c in range(self.n_components):
                n_class = array_ops.size(partitioned_samples_indices[c])
                raw = tf.random_normal([n_class, self.dim])
                samples_class_c = self._mu[c] + raw * tf.sqrt(self._var[c])
                samples_class[c] = samples_class_c

            # Stitch back together the samples across the components.
            ret = data_flow_ops.dynamic_stitch(
                indices=partitioned_samples_indices, data=samples_class)
            ret.set_shape((int(n_samples), self.dim))
        return ret
 def testSimpleTwoDimensional(self):
   with self.test_session():
     indices = [
         constant_op.constant([0, 4, 7]), constant_op.constant([1, 6]),
         constant_op.constant([2, 3, 5])
     ]
     data = [
         constant_op.constant([[0, 1], [40, 41], [70, 71]]),
         constant_op.constant([[10, 11], [60, 61]]),
         constant_op.constant([[20, 21], [30, 31], [50, 51]])
     ]
     stitched_t = data_flow_ops.dynamic_stitch(indices, data)
     stitched_val = stitched_t.eval()
     self.assertAllEqual([[0, 1], [10, 11], [20, 21], [30, 31], [40, 41],
                          [50, 51], [60, 61], [70, 71]], stitched_val)
     # Dimension 0 is determined by the max index in indices, so we
     # can only infer that the output is a matrix with 2 columns and
     # some unknown number of rows.
     self.assertEqual([None, 2], stitched_t.get_shape().as_list())
  def _AssertDynamicStitchResultIs(self, indices, data, expected):
    with self.test_session() as session:
      index_placeholders = [
          array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in indices
      ]
      data_placeholders = [
          array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in data
      ]
      with self.test_scope():
        output = data_flow_ops.dynamic_stitch(index_placeholders,
                                              data_placeholders)

      feed_dict = {}
      for placeholder, value in zip(index_placeholders, indices):
        feed_dict[placeholder] = value
      for placeholder, value in zip(data_placeholders, data):
        feed_dict[placeholder] = value
      result = session.run(output, feed_dict=feed_dict)
      self.assertAllClose(expected, result, rtol=1e-3)
 def testSimpleTwoDimensional(self):
     with self.test_session():
         indices = [
             constant_op.constant([0, 4, 7]),
             constant_op.constant([1, 6]),
             constant_op.constant([2, 3, 5])
         ]
         data = [
             constant_op.constant([[0, 1], [40, 41], [70, 71]]),
             constant_op.constant([[10, 11], [60, 61]]),
             constant_op.constant([[20, 21], [30, 31], [50, 51]])
         ]
         stitched_t = data_flow_ops.dynamic_stitch(indices, data)
         stitched_val = stitched_t.eval()
         self.assertAllEqual([[0, 1], [10, 11], [20, 21], [30, 31],
                              [40, 41], [50, 51], [60, 61], [70, 71]],
                             stitched_val)
         # Dimension 0 is determined by the max index in indices, so we
         # can only infer that the output is a matrix with 2 columns and
         # some unknown number of rows.
         self.assertEqual([None, 2], stitched_t.get_shape().as_list())
    def lookup(self, keys, name=None):
        if keys.dtype.base_dtype != self._key_dtype:
            raise TypeError(
                "Signature mismatch. Keys must be dtype %s, got %s." %
                (self._key_dtype, keys.dtype))
        self._check_keys(keys)
        num_shards = self._num_shards
        if num_shards == 1:
            return self._table_shards[0].lookup(keys, name=name)

        shard_indices = self._shard_indices(keys)
        key_shards = data_flow_ops.dynamic_partition(keys, shard_indices,
                                                     num_shards)
        value_shards = [
            self._table_shards[i].lookup(key_shards[i], name=name)
            for i in range(num_shards)
        ]

        num_keys = array_ops.shape(keys)[0]
        original_indices = math_ops.range(num_keys)
        partitioned_indices = data_flow_ops.dynamic_partition(
            original_indices, shard_indices, num_shards)
        return data_flow_ops.dynamic_stitch(partitioned_indices, value_shards)
Example #41
    def minimize(self, global_step=None, name=None):
        """Add operations to train a linear model by minimizing the loss function.

    Args:
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.

    Returns:
      An Operation that updates the variables passed in the constructor.
    """
        # Technically, the op depends on a lot more than the variables,
        # but we'll keep the list short.
        with name_scope(name, 'sdca/minimize'):
            sparse_example_indices = []
            sparse_feature_indices = []
            sparse_features_values = []
            for sf in self._examples['sparse_features']:
                sparse_example_indices.append(sf.example_indices)
                sparse_feature_indices.append(sf.feature_indices)
                # If feature values are missing, sdca assumes a value of 1.0f.
                if sf.feature_values is not None:
                    sparse_features_values.append(sf.feature_values)

            # pylint: disable=protected-access
            example_ids_hashed = gen_sdca_ops.sdca_fprint(
                internal_convert_to_tensor(self._examples['example_ids']))
            # pylint: enable=protected-access
            example_state_data = self._hashtable.lookup(example_ids_hashed)
            # Solver returns example_state_update, new delta sparse_feature_weights
            # and delta dense_feature_weights.

            sparse_weights = []
            sparse_indices = []
            # If we have partitioned variables, keep a few dictionaries of Tensors
            # around that we need for the assign_add after the op call to
            # gen_sdca_ops.sdca_optimizer().  These are keyed because we may have a
            # mix of partitioned and un-partitioned variables.
            num_partitions_by_var = {}
            p_assignments_by_var = {}
            gather_ids_by_var = {}
            for v_num, (w, i) in enumerate(
                    zip(self._slots['unshrinked_sparse_features_weights'],
                        sparse_feature_indices)):
                # Append the sparse_indices (in full-variable space).
                sparse_idx = math_ops.cast(
                    array_ops.unique(math_ops.cast(i, dtypes.int32))[0],
                    dtypes.int64)
                sparse_indices.append(sparse_idx)
                if isinstance(w, list) or isinstance(
                        w, var_ops.PartitionedVariable):
                    num_partitions = len(w)
                    flat_ids = array_ops.reshape(sparse_idx, [-1])
                    # We use div partitioning, which is easiest to support downstream.
                    # Compute num_total_ids as the sum of dim-0 of w, then assign
                    # to partitions based on a constant number of ids per partition.
                    # Optimize if we already know the full shape statically.
                    dim_0_size = self._get_first_dimension_size_statically(
                        w, num_partitions)

                    if tensor_shape.dimension_value(dim_0_size):
                        num_total_ids = constant_op.constant(
                            tensor_shape.dimension_value(dim_0_size),
                            flat_ids.dtype)
                    else:
                        dim_0_sizes = []
                        for p in range(num_partitions):
                            if tensor_shape.dimension_value(
                                    w[p].shape[0]) is not None:
                                dim_0_sizes.append(
                                    tensor_shape.dimension_value(
                                        w[p].shape[0]))
                            else:
                                with ops.colocate_with(w[p]):
                                    dim_0_sizes.append(
                                        array_ops.shape(w[p])[0])
                        num_total_ids = math_ops.reduce_sum(
                            math_ops.cast(array_ops.stack(dim_0_sizes),
                                          flat_ids.dtype))
                    ids_per_partition = num_total_ids // num_partitions
                    extras = num_total_ids % num_partitions

                    p_assignments = math_ops.maximum(
                        flat_ids // (ids_per_partition + 1),
                        (flat_ids - extras) // ids_per_partition)

                    # Emulate a conditional using a boolean indicator tensor
                    new_ids = array_ops.where(
                        p_assignments < extras,
                        flat_ids % (ids_per_partition + 1),
                        (flat_ids - extras) % ids_per_partition)

                    # Cast partition assignments to int32 for use in dynamic_partition.
                    # There really should not be more than 2^32 partitions.
                    p_assignments = math_ops.cast(p_assignments, dtypes.int32)
                    # Partition list of ids based on assignments into num_partitions
                    # separate lists.
                    gather_ids = data_flow_ops.dynamic_partition(
                        new_ids, p_assignments, num_partitions)
                    # Add these into the dictionaries for use in the later update.
                    num_partitions_by_var[v_num] = num_partitions
                    p_assignments_by_var[v_num] = p_assignments
                    gather_ids_by_var[v_num] = gather_ids

                    # Gather the weights from each partition.
                    partition_gathered_weights = []
                    for p in range(num_partitions):
                        with ops.colocate_with(w[p]):
                            partition_gathered_weights.append(
                                array_ops.gather(w[p], gather_ids[p]))

                    # Stitch the weights back together in the same order they were before
                    # we dynamic_partitioned them.
                    condition_indices = data_flow_ops.dynamic_partition(
                        math_ops.range(array_ops.shape(new_ids)[0]),
                        p_assignments, num_partitions)
                    batch_gathered_weights = data_flow_ops.dynamic_stitch(
                        condition_indices, partition_gathered_weights)
                else:
                    w_as_tensor = internal_convert_to_tensor(w)
                    with ops.device(w_as_tensor.device):
                        batch_gathered_weights = array_ops.gather(
                            w_as_tensor, sparse_idx)
                sparse_weights.append(batch_gathered_weights)

            # pylint: disable=protected-access
            if compat.forward_compatible(year=2018, month=10, day=30):
                esu, sfw, dfw = gen_sdca_ops.sdca_optimizer_v2(
                    sparse_example_indices,
                    sparse_feature_indices,
                    sparse_features_values,
                    self._convert_n_to_tensor(
                        self._examples['dense_features']),
                    internal_convert_to_tensor(
                        self._examples['example_weights']),
                    internal_convert_to_tensor(
                        self._examples['example_labels']),
                    sparse_indices,
                    sparse_weights,
                    self._convert_n_to_tensor(
                        self._slots['unshrinked_dense_features_weights']),
                    example_state_data,
                    loss_type=self._options['loss_type'],
                    l1=self._options['symmetric_l1_regularization'],
                    l2=self._symmetric_l2_regularization(),
                    num_loss_partitions=self._num_loss_partitions(),
                    num_inner_iterations=1,
                    adaptive=self._adaptive())
            else:
                esu, sfw, dfw = gen_sdca_ops.sdca_optimizer(
                    sparse_example_indices,
                    sparse_feature_indices,
                    sparse_features_values,
                    self._convert_n_to_tensor(
                        self._examples['dense_features']),
                    internal_convert_to_tensor(
                        self._examples['example_weights']),
                    internal_convert_to_tensor(
                        self._examples['example_labels']),
                    sparse_indices,
                    sparse_weights,
                    self._convert_n_to_tensor(
                        self._slots['unshrinked_dense_features_weights']),
                    example_state_data,
                    loss_type=self._options['loss_type'],
                    l1=self._options['symmetric_l1_regularization'],
                    l2=self._symmetric_l2_regularization(),
                    num_loss_partitions=self._num_loss_partitions(),
                    num_inner_iterations=1,
                    adaptative=self._adaptive())
            # pylint: enable=protected-access

            with ops.control_dependencies([esu]):
                update_ops = [self._hashtable.insert(example_ids_hashed, esu)]
                # Update the weights before the proximal step.
                for v_num, (w, i, u) in enumerate(
                        zip(self._slots['unshrinked_sparse_features_weights'],
                            sparse_indices, sfw)):
                    if (isinstance(w, var_ops.PartitionedVariable)
                            or isinstance(w, list)):
                        update_ops += self._get_partitioned_update_ops(
                            v_num, num_partitions_by_var, p_assignments_by_var,
                            gather_ids_by_var, w, u, p_assignments,
                            num_partitions)
                    else:
                        update_ops.append(state_ops.scatter_add(w, i, u))
                for w, u in zip(
                        self._slots['unshrinked_dense_features_weights'], dfw):
                    if (isinstance(w, var_ops.PartitionedVariable)
                            or isinstance(w, list)):
                        split_updates = array_ops.split(
                            u,
                            num_or_size_splits=[
                                v.shape.as_list()[0] for v in w
                            ])
                        for v, split_update in zip(w, split_updates):
                            update_ops.append(
                                state_ops.assign_add(v, split_update))
                    else:
                        update_ops.append(state_ops.assign_add(w, u))
            if not global_step:
                return control_flow_ops.group(*update_ops)
            with ops.control_dependencies(update_ops):
                return state_ops.assign_add(global_step, 1, name=name).op
Example #42
    def _sample_n(self, n, seed=None):
        if self._use_static_graph:
            # This sampling approach is almost the same as the approach used by
            # `MixtureSameFamily`. The differences are due to having a list of
            # `Distribution` objects rather than a single object, and maintaining
            # random seed management that is consistent with the non-static code path.
            samples = []
            cat_samples = self.cat.sample(n, seed=seed)
            for c in range(self.num_components):
                seed = distribution_util.gen_new_seed(seed, "mixture")
                samples.append(self.components[c].sample(n, seed=seed))
            x = array_ops.stack(samples, -self._static_event_shape.ndims -
                                1)  # [n, B, k, E]
            npdt = x.dtype.as_numpy_dtype
            mask = array_ops.one_hot(
                indices=cat_samples,  # [n, B]
                depth=self._num_components,  # == k
                on_value=np.ones([], dtype=npdt),
                off_value=np.zeros([], dtype=npdt))  # [n, B, k]
            mask = distribution_utils.pad_mixture_dimensions(
                mask, self, self._cat,
                self._static_event_shape.ndims)  # [n, B, k, [1]*e]
            return math_ops.reduce_sum(
                x * mask,
                axis=-1 - self._static_event_shape.ndims)  # [n, B, E]

        with ops.control_dependencies(self._assertions):
            n = ops.convert_to_tensor(n, name="n")
            static_n = tensor_util.constant_value(n)
            n = int(static_n) if static_n is not None else n
            cat_samples = self.cat.sample(n, seed=seed)

            static_samples_shape = cat_samples.get_shape()
            if static_samples_shape.is_fully_defined():
                samples_shape = static_samples_shape.as_list()
                samples_size = static_samples_shape.num_elements()
            else:
                samples_shape = array_ops.shape(cat_samples)
                samples_size = array_ops.size(cat_samples)
            static_batch_shape = self.batch_shape
            if static_batch_shape.is_fully_defined():
                batch_shape = static_batch_shape.as_list()
                batch_size = static_batch_shape.num_elements()
            else:
                batch_shape = self.batch_shape_tensor()
                batch_size = math_ops.reduce_prod(batch_shape)
            static_event_shape = self.event_shape
            if static_event_shape.is_fully_defined():
                event_shape = np.array(static_event_shape.as_list(),
                                       dtype=np.int32)
            else:
                event_shape = self.event_shape_tensor()

            # Get indices into the raw cat sampling tensor. We will
            # need these to stitch sample values back out after sampling
            # within the component partitions.
            samples_raw_indices = array_ops.reshape(
                math_ops.range(0, samples_size), samples_shape)

            # Partition the raw indices so that we can use
            # dynamic_stitch later to reconstruct the samples from the
            # known partitions.
            partitioned_samples_indices = data_flow_ops.dynamic_partition(
                data=samples_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)

            # Copy the batch indices n times, as we will need to know
            # these to pull out the appropriate rows within the
            # component partitions.
            batch_raw_indices = array_ops.reshape(
                array_ops.tile(math_ops.range(0, batch_size), [n]),
                samples_shape)

            # Explanation of the dynamic partitioning below:
            #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
            # Suppose partitions are:
            #     [1 1 0 0 1 1]
            # After partitioning, batch indices are cut as:
            #     [batch_indices[x] for x in 2, 3]
            #     [batch_indices[x] for x in 0, 1, 4, 5]
            # i.e.
            #     [1 1] and [0 0 0 0]
            # Now we sample n=2 from part 0 and n=4 from part 1.
            # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
            # and for part 1 we want samples from batch entries 0, 0, 0, 0
            #   (samples 0, 1, 2, 3).
            partitioned_batch_indices = data_flow_ops.dynamic_partition(
                data=batch_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)
            samples_class = [None for _ in range(self.num_components)]

            for c in range(self.num_components):
                n_class = array_ops.size(partitioned_samples_indices[c])
                seed = distribution_util.gen_new_seed(seed, "mixture")
                samples_class_c = self.components[c].sample(n_class, seed=seed)

                # Pull out the correct batch entries from each index.
                # To do this, we may have to flatten the batch shape.

                # For sample s, batch element b of component c, we get the
                # partitioned batch indices from
                # partitioned_batch_indices[c]; and shift each element by
                # the sample index. The final lookup can be thought of as
                # a matrix gather along locations (s, b) in
                # samples_class_c where the n_class rows correspond to
                # samples within this component and the batch_size columns
                # correspond to batch elements within the component.
                #
                # Thus the lookup index is
                #   lookup[c, i] = batch_size * s[i] + b[c, i]
                # for i = 0 ... n_class[c] - 1.
                lookup_partitioned_batch_indices = (
                    batch_size * math_ops.range(n_class) +
                    partitioned_batch_indices[c])
                samples_class_c = array_ops.reshape(
                    samples_class_c,
                    array_ops.concat([[n_class * batch_size], event_shape], 0))
                samples_class_c = array_ops.gather(
                    samples_class_c,
                    lookup_partitioned_batch_indices,
                    name="samples_class_c_gather")
                samples_class[c] = samples_class_c

            # Stitch back together the samples across the components.
            lhs_flat_ret = data_flow_ops.dynamic_stitch(
                indices=partitioned_samples_indices, data=samples_class)
            # Reshape back to proper sample, batch, and event shape.
            ret = array_ops.reshape(
                lhs_flat_ret,
                array_ops.concat(
                    [samples_shape, self.event_shape_tensor()], 0))
            ret.set_shape(
                tensor_shape.TensorShape(static_samples_shape).concatenate(
                    self.event_shape))
            return ret
def _stitch(values, indices):
    if len(values) == 1:
        return values[0]
    with ops.colocate_with(indices[0], ignore_existing=True):
        all_values = data_flow_ops.dynamic_stitch(indices, values)
    return all_values
Example #44
def embedding_lookup(params, ids, partition_strategy="mod", name=None,
                     validate_indices=True, max_norm=None):
  """Looks up `ids` in a list of embedding tensors.

  This function is used to perform parallel lookups on the list of
  tensors in `params`.  It is a generalization of
  [`tf.gather()`](../../api_docs/python/array_ops.md#gather), where `params` is
  interpreted as a partitioning of a large embedding tensor.  `params` may be
  a `PartitionedVariable` as returned by using `tf.get_variable()` with a
  partitioner.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between
  the elements of `params` according to the `partition_strategy`.
  In all strategies, if the id space does not evenly divide the number of
  partitions, each of the first `(max_id + 1) % len(params)` partitions will
  be assigned one more id.

  If `partition_strategy` is `"mod"`, we assign each id to partition
  `p = id % len(params)`. For instance,
  13 ids are split across 5 partitions as:
  `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`

  If `partition_strategy` is `"div"`, we assign ids to partitions in a
  contiguous manner. In this case, 13 ids are split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`

  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

  Args:
    params: A list of tensors with the same type and which can be concatenated
      along dimension 0. Alternatively, a `PartitionedVariable`, created by
      partitioning along dimension 0.  Each element must be appropriately sized
      for the given `partition_strategy`.
    ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked
      up in `params`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`.
    name: A name for the operation (optional).
    validate_indices: Whether or not to validate gather indices.
    max_norm: If not None, embedding values are l2-normalized to the value of
     max_norm.

  Returns:
    A `Tensor` with the same type as the tensors in `params`.

  Raises:
    ValueError: If `params` is empty.
  """
  if params is None or params == []:  # pylint: disable=g-explicit-bool-comparison
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  def maybe_normalize(x):
    if max_norm is not None:
      if x.get_shape().ndims is not None:
        ndims = x.get_shape().ndims
      else:
        ndims = array_ops.size(array_ops.shape(x))
      return clip_ops.clip_by_norm(x, max_norm, axes=list(range(1, ndims)))
    return x
  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    np = len(params)  # Number of partitions
    params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    if np == 1:
      with ops.colocate_with(params[0]):
        # TODO(apassos): implement the sharded version as well.
        if isinstance(params[0], resource_variable_ops.ResourceVariable):
          ret = params[0].sparse_read(ids, name=name)
        else:
          ret = array_ops.gather(params[0], ids, name=name,
                                 validate_indices=validate_indices)
      return maybe_normalize(ret)
    else:
      ids = ops.convert_to_tensor(ids, name="ids")
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = params[0].get_shape()[0]
        for p in xrange(1, np):
          dim_0_size += params[p].get_shape()[0]
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            if params[p].get_shape()[0].value is not None:
              dim_0_sizes.append(params[p].get_shape()[0].value)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.pack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(
            flat_ids // (ids_per_partition + 1),
            (flat_ids - extras) // ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        is_in_first_extras_partitions = math_ops.cast(
            p_assignments < extras, flat_ids.dtype)
        new_ids = (
            is_in_first_extras_partitions * (
                flat_ids % (ids_per_partition + 1)) +
            (1 - is_in_first_extras_partitions) * (
                (flat_ids - extras) % ids_per_partition))
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result.append(array_ops.gather(
              params[p], gather_ids[p],
              validate_indices=validate_indices))
      # Stitch these back together
      ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
                                         name=name)
      # Reshape to reverse the flattening of ids.
      element_shape = params[0].get_shape()[1:]
      for p in params[1:]:
        element_shape = element_shape.merge_with(p.get_shape()[1:])
      if element_shape.is_fully_defined():
        ret = array_ops.reshape(ret, array_ops.concat(0, [
            array_ops.shape(ids), element_shape]))
      else:
        # It's important that we compute params[0].shape on the right device
        # to avoid data motion.
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        ret = array_ops.reshape(ret, array_ops.concat(0, [
            array_ops.shape(ids), array_ops.slice(params_shape, [1], [-1])]))
      # output shape = ids.shape + params[*].shape[1:]
      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters.
      ret.set_shape(ids.get_shape().concatenate(element_shape))
      return maybe_normalize(ret)
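The docstring's split of 13 ids across 5 partitions can be reproduced directly. The NumPy sketch below mirrors the "mod" rule and the "div" arithmetic used in the function body (illustration only, no TensorFlow ops):

import numpy as np

ids = np.arange(13)
num_partitions = 5

# "mod": id goes to partition id % num_partitions.
mod_split = [ids[ids % num_partitions == p].tolist()
             for p in range(num_partitions)]
# [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]

# "div": contiguous blocks; the first `extras` partitions hold one extra id.
ids_per_partition, extras = divmod(len(ids), num_partitions)
p_assignments = np.maximum(ids // (ids_per_partition + 1),
                           (ids - extras) // ids_per_partition)
div_split = [ids[p_assignments == p].tolist() for p in range(num_partitions)]
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]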
Example #45
    def _sample_n(self, n, seed=None):
        with ops.control_dependencies(self._assertions):
            n = ops.convert_to_tensor(n, name="n")
            static_n = tensor_util.constant_value(n)
            n = int(static_n) if static_n is not None else n
            pi_samples = self.pi.sample(n, seed=seed)

            static_samples_shape = pi_samples.get_shape()
            if static_samples_shape.is_fully_defined():
                samples_shape = static_samples_shape.as_list()
                samples_size = static_samples_shape.num_elements()
            else:
                samples_shape = array_ops.shape(pi_samples)
                samples_size = array_ops.size(pi_samples)
            static_batch_shape = self.batch_shape
            if static_batch_shape.is_fully_defined():
                batch_shape = static_batch_shape.as_list()
                batch_size = static_batch_shape.num_elements()
            else:
                batch_shape = self.batch_shape_tensor()
                batch_size = math_ops.reduce_prod(batch_shape)
            static_event_shape = self.event_shape
            if static_event_shape.is_fully_defined():
                event_shape = np.array(static_event_shape.as_list(),
                                       dtype=np.int32)
            else:
                event_shape = self.event_shape_tensor()

            # Get indices into the raw pi sampling tensor. We will
            # need these to stitch sample values back out after sampling
            # within the component partitions.
            samples_raw_indices = array_ops.reshape(
                math_ops.range(0, samples_size), samples_shape)

            # Partition the raw indices so that we can use
            # dynamic_stitch later to reconstruct the samples from the
            # known partitions.
            partitioned_samples_indices = data_flow_ops.dynamic_partition(
                data=samples_raw_indices,
                partitions=pi_samples,
                num_partitions=self.num_dist)

            # Copy the batch indices n times, as we will need to know
            # these to pull out the appropriate rows within the
            # component partitions.
            batch_raw_indices = array_ops.reshape(
                array_ops.tile(math_ops.range(0, batch_size), [n]),
                samples_shape)

            # Explanation of the dynamic partitioning below:
            #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
            # Suppose partitions are:
            #     [1 1 0 0 1 1]
            # After partitioning, batch indices are cut as:
            #     [batch_indices[x] for x in 2, 3]
            #     [batch_indices[x] for x in 0, 1, 4, 5]
            # i.e.
            #     [1 1] and [0 0 0 0]
            # Now we sample n=2 from part 0 and n=4 from part 1.
            # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
            # and for part 1 we want samples from batch entries 0, 0, 0, 0
            #   (samples 0, 1, 2, 3).
            partitioned_batch_indices = data_flow_ops.dynamic_partition(
                data=batch_raw_indices,
                partitions=pi_samples,
                num_partitions=self.num_dist)
            samples_class = [None for _ in range(self.num_dist)]

            for c in range(self.num_dist):
                n_class = array_ops.size(partitioned_samples_indices[c])
                seed = distribution_util.gen_new_seed(seed, "ZeroInflated")
                samples_class_c = self.dist[c].sample(n_class, seed=seed)

                # Pull out the correct batch entries from each index.
                # To do this, we may have to flatten the batch shape.

                # For sample s, batch element b of component c, we get the
                # partitioned batch indices from
                # partitioned_batch_indices[c]; and shift each element by
                # the sample index. The final lookup can be thought of as
                # a matrix gather along locations (s, b) in
                # samples_class_c where the n_class rows correspond to
                # samples within this component and the batch_size columns
                # correspond to batch elements within the component.
                #
                # Thus the lookup index is
                #   lookup[c, i] = batch_size * s[i] + b[c, i]
                # for i = 0 ... n_class[c] - 1.
                lookup_partitioned_batch_indices = (
                    batch_size * math_ops.range(n_class) +
                    partitioned_batch_indices[c])
                samples_class_c = array_ops.reshape(
                    samples_class_c,
                    array_ops.concat([[n_class * batch_size], event_shape], 0))
                samples_class_c = array_ops.gather(
                    samples_class_c,
                    lookup_partitioned_batch_indices,
                    name="samples_class_c_gather")
                samples_class[c] = samples_class_c

            # Stitch back together the samples across the dist.
            lhs_flat_ret = data_flow_ops.dynamic_stitch(
                indices=partitioned_samples_indices, data=samples_class)
            # Reshape back to proper sample, batch, and event shape.
            ret = array_ops.reshape(
                lhs_flat_ret,
                array_ops.concat(
                    [samples_shape, self.event_shape_tensor()], 0))
            ret.set_shape(
                tensor_shape.TensorShape(static_samples_shape).concatenate(
                    self.event_shape))
            return ret
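The lookup index comment above (lookup = batch_size * s + b) can be checked with a small NumPy sketch; the sizes below are made up for illustration:

import numpy as np

# Toy sizes: 3 samples per component, batch of 2, event size 4.
n_class, batch_size, event_size = 3, 2, 4
samples = np.arange(n_class * batch_size * event_size).reshape(
    n_class, batch_size, event_size)

# Flattening [n_class, batch_size, ...] to [n_class * batch_size, ...] puts
# (sample s, batch b) at flat row batch_size * s + b.
flat = samples.reshape(n_class * batch_size, event_size)
s, b = 2, 1
assert (flat[batch_size * s + b] == samples[s, b]).all()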
Example No. 46
  def minimize(self, global_step=None, name=None):
    """Add operations to train a linear model by minimizing the loss function.

    Args:
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.

    Returns:
      An Operation that updates the variables passed in the constructor.
    """
    # Technically, the op depends on a lot more than the variables,
    # but we'll keep the list short.
    with name_scope(name, 'sdca/minimize'):
      sparse_example_indices = []
      sparse_feature_indices = []
      sparse_features_values = []
      for sf in self._examples['sparse_features']:
        sparse_example_indices.append(sf.example_indices)
        sparse_feature_indices.append(sf.feature_indices)
        # If feature values are missing, sdca assumes a value of 1.0f.
        if sf.feature_values is not None:
          sparse_features_values.append(sf.feature_values)

      # pylint: disable=protected-access
      example_ids_hashed = gen_sdca_ops.sdca_fprint(
          internal_convert_to_tensor(self._examples['example_ids']))
      # pylint: enable=protected-access
      example_state_data = self._hashtable.lookup(example_ids_hashed)
      # Solver returns example_state_update, new delta sparse_feature_weights
      # and delta dense_feature_weights.

      sparse_weights = []
      sparse_indices = []
      # If we have partitioned variables, keep a few dictionaries of Tensors
      # around that we need for the assign_add after the op call to
      # gen_sdca_ops.sdca_optimizer().  These are keyed because we may have a
      # mix of partitioned and un-partitioned variables.
      num_partitions_by_var = {}
      p_assignments_by_var = {}
      gather_ids_by_var = {}
      for v_num, (w, i) in enumerate(
          zip(self._slots['unshrinked_sparse_features_weights'],
              sparse_feature_indices)):
        # Append the sparse_indices (in full-variable space).
        sparse_idx = math_ops.cast(
            array_ops.unique(math_ops.cast(i, dtypes.int32))[0],
            dtypes.int64)
        sparse_indices.append(sparse_idx)
        if isinstance(w, list) or isinstance(w, var_ops.PartitionedVariable):
          num_partitions = len(w)
          flat_ids = array_ops.reshape(sparse_idx, [-1])
          # We use div partitioning, which is easiest to support downstream.
          # Compute num_total_ids as the sum of dim-0 of w, then assign
          # to partitions based on a constant number of ids per partition.
          # Optimize if we already know the full shape statically.
          dim_0_size = self._get_first_dimension_size_statically(
              w, num_partitions)

          if tensor_shape.dimension_value(dim_0_size):
            num_total_ids = constant_op.constant(
                tensor_shape.dimension_value(dim_0_size),
                flat_ids.dtype)
          else:
            dim_0_sizes = []
            for p in range(num_partitions):
              if tensor_shape.dimension_value(w[p].shape[0]) is not None:
                dim_0_sizes.append(tensor_shape.dimension_value(w[p].shape[0]))
              else:
                with ops.colocate_with(w[p]):
                  dim_0_sizes.append(array_ops.shape(w[p])[0])
            num_total_ids = math_ops.reduce_sum(
                math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
          ids_per_partition = num_total_ids // num_partitions
          extras = num_total_ids % num_partitions

          p_assignments = math_ops.maximum(
              flat_ids // (ids_per_partition + 1),
              (flat_ids - extras) // ids_per_partition)

          # Emulate a conditional using a boolean indicator tensor
          new_ids = array_ops.where(p_assignments < extras,
                                    flat_ids % (ids_per_partition + 1),
                                    (flat_ids - extras) % ids_per_partition)

          # Cast partition assignments to int32 for use in dynamic_partition.
          # There really should not be more than 2^32 partitions.
          p_assignments = math_ops.cast(p_assignments, dtypes.int32)
          # Partition list of ids based on assignments into num_partitions
          # separate lists.
          gather_ids = data_flow_ops.dynamic_partition(new_ids,
                                                       p_assignments,
                                                       num_partitions)
          # Add these into the dictionaries for use in the later update.
          num_partitions_by_var[v_num] = num_partitions
          p_assignments_by_var[v_num] = p_assignments
          gather_ids_by_var[v_num] = gather_ids

          # Gather the weights from each partition.
          partition_gathered_weights = []
          for p in range(num_partitions):
            with ops.colocate_with(w[p]):
              partition_gathered_weights.append(
                  array_ops.gather(w[p], gather_ids[p]))

          # Stitch the weights back together in the same order they were before
          # we dynamic_partitioned them.
          condition_indices = data_flow_ops.dynamic_partition(
              math_ops.range(array_ops.shape(new_ids)[0]),
              p_assignments, num_partitions)
          batch_gathered_weights = data_flow_ops.dynamic_stitch(
              condition_indices, partition_gathered_weights)
        else:
          w_as_tensor = internal_convert_to_tensor(w)
          with ops.device(w_as_tensor.device):
            batch_gathered_weights = array_ops.gather(
                w_as_tensor, sparse_idx)
        sparse_weights.append(batch_gathered_weights)

      # pylint: disable=protected-access
      if compat.forward_compatible(year=2018, month=10, day=30):
        esu, sfw, dfw = gen_sdca_ops.sdca_optimizer_v2(
            sparse_example_indices,
            sparse_feature_indices,
            sparse_features_values,
            self._convert_n_to_tensor(self._examples['dense_features']),
            internal_convert_to_tensor(self._examples['example_weights']),
            internal_convert_to_tensor(self._examples['example_labels']),
            sparse_indices,
            sparse_weights,
            self._convert_n_to_tensor(self._slots[
                'unshrinked_dense_features_weights']),
            example_state_data,
            loss_type=self._options['loss_type'],
            l1=self._options['symmetric_l1_regularization'],
            l2=self._symmetric_l2_regularization(),
            num_loss_partitions=self._num_loss_partitions(),
            num_inner_iterations=1,
            adaptive=self._adaptive())
      else:
        esu, sfw, dfw = gen_sdca_ops.sdca_optimizer(
            sparse_example_indices,
            sparse_feature_indices,
            sparse_features_values,
            self._convert_n_to_tensor(self._examples['dense_features']),
            internal_convert_to_tensor(self._examples['example_weights']),
            internal_convert_to_tensor(self._examples['example_labels']),
            sparse_indices,
            sparse_weights,
            self._convert_n_to_tensor(self._slots[
                'unshrinked_dense_features_weights']),
            example_state_data,
            loss_type=self._options['loss_type'],
            l1=self._options['symmetric_l1_regularization'],
            l2=self._symmetric_l2_regularization(),
            num_loss_partitions=self._num_loss_partitions(),
            num_inner_iterations=1,
            adaptative=self._adaptive())
      # pylint: enable=protected-access

      with ops.control_dependencies([esu]):
        update_ops = [self._hashtable.insert(example_ids_hashed, esu)]
        # Update the weights before the proximal step.
        for v_num, (w, i, u) in enumerate(
            zip(self._slots['unshrinked_sparse_features_weights'],
                sparse_indices, sfw)):
          if (isinstance(w, var_ops.PartitionedVariable) or
              isinstance(w, list)):
            update_ops += self._get_partitioned_update_ops(
                v_num, num_partitions_by_var, p_assignments_by_var,
                gather_ids_by_var, w, u, p_assignments, num_partitions)
          else:
            update_ops.append(state_ops.scatter_add(w, i, u))
        for w, u in zip(self._slots['unshrinked_dense_features_weights'], dfw):
          if (isinstance(w, var_ops.PartitionedVariable) or
              isinstance(w, list)):
            split_updates = array_ops.split(
                u, num_or_size_splits=[v.shape.as_list()[0] for v in w])
            for v, split_update in zip(w, split_updates):
              update_ops.append(state_ops.assign_add(v, split_update))
          else:
            update_ops.append(state_ops.assign_add(w, u))
      if not global_step:
        return control_flow_ops.group(*update_ops)
      with ops.control_dependencies(update_ops):
        return state_ops.assign_add(global_step, 1, name=name).op
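The condition_indices / dynamic_stitch pair above undoes the dynamic_partition so the gathered weights line up with the original order of new_ids. A standalone sketch of that pattern using the public tf.dynamic_partition / tf.dynamic_stitch ops (toy data, with a stand-in for the per-partition gather):

import tensorflow as tf

ids = tf.constant([5, 0, 3, 2], dtype=tf.int32)        # toy ids
p_assignments = ids % 2                                 # toy partition assignment
partitioned_ids = tf.dynamic_partition(ids, p_assignments, 2)
# Remember which original position each id came from.
condition_indices = tf.dynamic_partition(
    tf.range(tf.shape(ids)[0]), p_assignments, 2)
# Stand-in for the per-partition gather: multiply within each partition.
partition_results = [10 * p for p in partitioned_ids]
restored = tf.dynamic_stitch(condition_indices, partition_results)
# restored == [50, 0, 30, 20], i.e. results in the original order of `ids`.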
Example No. 47
def one_dimensional_calibration_layer(uncalibrated_tensor,
                                      num_keypoints,
                                      signal_name,
                                      keypoints_initializers=None,
                                      keypoints_initializer_fns=None,
                                      bound=False,
                                      monotonic=None,
                                      missing_input_value=None,
                                      missing_output_value=None,
                                      **regularizer_amounts):
    """Creates a calibration layer for one single continuous signal.

  Returns a calibrated tensor of the uncalibrated continuous signal and a list
  of projections ops.

  Args:
    uncalibrated_tensor: Tensor of shape [batch_size] of one single signal.
    num_keypoints: Number of keypoints to use.
    signal_name: (Required) Used as a suffix to the variable names.
    keypoints_initializers: For evaluation or inference (or when resuming
      training from a checkpoint) the values will be loaded from disk, so they
      don't need to be given -- but in this case num_keypoints needs to be
      accurate. Two tensors of shape [num_keypoints]. See
      load_keypoints_from_quantiles or uniform_keypoints_for_signal on how to
      generate these (module keypoints_initialization).
    keypoints_initializer_fns: Like keypoints_initializers but using lambda
      initializers. They should be compatible with tf.get_variable. If this is
      set, then keypoints_initializers must be None.
    bound: boolean whether output of calibration must be bound. Alternatively
      a dict mapping feature name to boundness.
    monotonic: whether calibration has to be kept monotonic: None or 0 means
      no monotonicity. Positive or negative values mean increasing or decreasing
      monotonicity respectively. Alternatively a dict mapping feature name
      to monotonic.
    missing_input_value: If set, and if the input has this value it is assumed
      to be missing and the output will either be calibrated to some value
      between `[calibration_output_min, calibration_output_max]` or set to a
      fixed value set by missing_output_value. Limitation: it only works for
      scalars.
    missing_output_value: Requires missing_input_value also to be set. If set
      it will convert missing input to this value.
    **regularizer_amounts: Keyword args of regularization amounts passed to
      regularizers.calibrator_regularization(). Keyword names should be among
      supported regularizers.CALIBRATOR_REGULARIZERS and values should be
      float.

  Returns:
    A tuple of:
    * calibrated tensor of shape [batchsize]
    * None or projection ops, that must be applied at each
      step (or every so many steps) to project the model to a feasible space:
      used for bounding the outputs or for imposing monotonicity.
    * None or a regularization loss, if regularization is configured.

  Raises:
    ValueError: if dtypes are incompatible.
    ValueError: if keypoints_initializers and keypoints_initializer_fns are both
      set.
  """
    if (keypoints_initializers is not None
            and keypoints_initializer_fns is not None):
        raise ValueError(
            'keypoints_initializers and keypoints_initializer_fns '
            'cannot both be set.')
    with variable_scope.variable_scope('pwl_calibration'):
        # Sanity checks.
        if uncalibrated_tensor.get_shape().ndims != 1:
            raise ValueError(
                'one_dimensional_calibration_layer can only be used for a single '
                'signal, so uncalibrated shape must be of form (batchsize), got %s'
                % uncalibrated_tensor.get_shape())
        if missing_output_value is not None and missing_input_value is None:
            raise ValueError(
                'missing_output_value can only be set if a missing_input_value is '
                'also set, missing_input_value=None, missing_output_values=%s'
                % missing_output_value)

        # Create variables: only uses initializer if they are given.
        kp_in_name = signal_name + '_keypoints_inputs'
        kp_out_name = signal_name + '_keypoints_outputs'
        missing_out_calibrated_name = signal_name + '_calibrated_missing_output'

        if keypoints_initializers is not None:
            kp_in, kp_out = keypoints_initializers[0], keypoints_initializers[
                1]
            if (uncalibrated_tensor.dtype != kp_in.dtype
                    or uncalibrated_tensor.dtype != kp_out.dtype):
                raise ValueError(
                    'incompatible types for signal \'%s\': uncalibrated=%s, '
                    'keypoints_initializers[input=%s, output=%s]' %
                    (signal_name, uncalibrated_tensor.dtype, kp_in.dtype,
                     kp_out.dtype))
            tools.assert_shape(kp_in, [num_keypoints],
                               'keypoints_initializers[input]')
            tools.assert_shape(kp_out, [num_keypoints],
                               'keypoints_initializers[output]')
            keypoints_inputs = variable_scope.get_variable(kp_in_name,
                                                           initializer=kp_in)
            keypoints_outputs = variable_scope.get_variable(kp_out_name,
                                                            initializer=kp_out)

            if missing_input_value is not None:
                # Value to be taken by missing features.
                if missing_output_value is not None:
                    missing_out_calibrated = constant_op.constant(
                        missing_output_value, dtype=uncalibrated_tensor.dtype)
                else:
                    # Learned missing value, initialized by the first value of kp_out.
                    missing_out_calibrated = variable_scope.get_variable(
                        missing_out_calibrated_name, initializer=kp_out[0])
        elif keypoints_initializer_fns is not None:
            kp_in, kp_out = keypoints_initializer_fns[
                0], keypoints_initializer_fns[1]
            keypoints_inputs = variable_scope.get_variable(
                kp_in_name, shape=[num_keypoints], initializer=kp_in)
            keypoints_outputs = variable_scope.get_variable(
                kp_out_name, shape=[num_keypoints], initializer=kp_out)

            if missing_input_value is not None:
                # Value to be taken by missing features.
                if missing_output_value is not None:
                    missing_out_calibrated = constant_op.constant(
                        missing_output_value, dtype=uncalibrated_tensor.dtype)
                else:
                    # Learned missing value, initialized by the first value of kp_out.
                    def first_kp_out(*args, **kwargs):
                        return kp_out(*args, **kwargs)[0]

                    missing_out_calibrated = variable_scope.get_variable(
                        missing_out_calibrated_name,
                        shape=[],
                        initializer=first_kp_out)
        else:
            # When loading a model, no initializer.
            keypoints_inputs = variable_scope.get_variable(
                kp_in_name,
                shape=[num_keypoints],
                dtype=uncalibrated_tensor.dtype)
            keypoints_outputs = variable_scope.get_variable(
                kp_out_name,
                shape=[num_keypoints],
                dtype=uncalibrated_tensor.dtype)
            if missing_input_value:
                if missing_output_value:
                    missing_out_calibrated = constant_op.constant(
                        missing_output_value, dtype=uncalibrated_tensor.dtype)
                else:
                    missing_out_calibrated = variable_scope.get_variable(
                        missing_out_calibrated_name,
                        shape=[],
                        dtype=uncalibrated_tensor.dtype)

        # Split missing values from normal values.
        # FutureWork: move handling of missing values to C++ land.
        if missing_input_value is not None:
            missing_mask = math_ops.equal(
                uncalibrated_tensor, constant_op.constant(missing_input_value))
            mask_indices = math_ops.range(
                array_ops.shape(uncalibrated_tensor)[0])
            mask_indices = data_flow_ops.dynamic_partition(
                mask_indices, math_ops.cast(missing_mask, dtypes.int32), 2)
            (uncalibrated_tensor,
             missing_values) = data_flow_ops.dynamic_partition(
                 uncalibrated_tensor, math_ops.cast(missing_mask,
                                                    dtypes.int32), 2)

            # Assign value to missing_values.
            missing_values = array_ops.ones_like(missing_values)
            missing_values *= missing_out_calibrated

        # Dense implementation.
        interpolation = pwl_calibration_ops.pwl_indexing_calibrator(
            uncalibrated_tensor, keypoints_inputs)
        calibrated = math_ops.reduce_sum(interpolation * keypoints_outputs, 1)
        projection_ops = None

        # Re-join missing values.
        if missing_input_value is not None:
            calibrated = data_flow_ops.dynamic_stitch(
                mask_indices, [calibrated, missing_values])

        # Boundness.
        projected_keypoints_outputs = None
        if bound:
            bound_min_name = signal_name + '_bound_min'
            bound_max_name = signal_name + '_bound_max'
            # Set bound_min/max from min/max values initialized.
            if keypoints_initializers is not None:
                # Store bound_min and bound_max in variables because their values (from
                # kp_out) are only available during train (when keypoints_initializers
                # is available). During inference the value is not available. Storing
                # them in variables makes them available during inference.
                bound_min = variable_scope.get_variable(
                    bound_min_name,
                    dtype=uncalibrated_tensor.dtype,
                    initializer=math_ops.reduce_min(kp_out))
                bound_max = variable_scope.get_variable(
                    bound_max_name,
                    dtype=uncalibrated_tensor.dtype,
                    initializer=math_ops.reduce_max(kp_out))
            elif keypoints_initializer_fns is not None:
                # Store bound_min and bound_max in variables because their values (from
                # kp_out) are only available during train (when keypoints_initializers
                # is available). During inference the value is not available. Storing
                # them in variables makes them available during inference.
                def min_kp_out(*args, **kwargs):
                    return math_ops.reduce_min(kp_out(*args, **kwargs))

                def max_kp_out(*args, **kwargs):
                    return math_ops.reduce_max(kp_out(*args, **kwargs))

                bound_min = variable_scope.get_variable(
                    bound_min_name,
                    dtype=uncalibrated_tensor.dtype,
                    shape=[],
                    initializer=min_kp_out)
                bound_max = variable_scope.get_variable(
                    bound_max_name,
                    dtype=uncalibrated_tensor.dtype,
                    shape=[],
                    initializer=max_kp_out)
            else:
                # No need to initialize, since presumably their values will be read
                # from some checkpoint.
                bound_min = variable_scope.get_variable(
                    bound_min_name, dtype=uncalibrated_tensor.dtype, shape=[])
                bound_max = variable_scope.get_variable(
                    bound_max_name, dtype=uncalibrated_tensor.dtype, shape=[])
            projected_keypoints_outputs = math_ops.minimum(
                math_ops.maximum(keypoints_outputs, bound_min), bound_max)

        # Monotonicity.
        if monotonic:
            # First a soft-enforcement: might not break indirect constraints.
            if projected_keypoints_outputs is None:
                projected_keypoints_outputs = keypoints_outputs
            projected_keypoints_outputs = pwl_calibration_ops.monotonic_projection(
                increasing=bool(monotonic > 0),
                values=projected_keypoints_outputs,
                name='project_calibration_to_monotonic')

        # Make assign_add op to projected output.
        if projected_keypoints_outputs is not None:
            constrained_diff = projected_keypoints_outputs - keypoints_outputs
            projection_ops = state_ops.assign_add(keypoints_outputs,
                                                  constrained_diff,
                                                  use_locking=None,
                                                  name='project_feasible')
            if (bound and missing_input_value is not None
                    and missing_output_value is None):
                # Include op bounding calibrated missing value.
                projected_missing_out_calibrated = math_ops.minimum(
                    math_ops.maximum(missing_out_calibrated, bound_min),
                    bound_max)
                projected_missing_out_calibrated_diff = (
                    projected_missing_out_calibrated - missing_out_calibrated)
                projected_missing_out_calibrated_op = state_ops.assign_add(
                    missing_out_calibrated,
                    projected_missing_out_calibrated_diff,
                    use_locking=None,
                    name='project_missing_calibration_to_bounds')
                projection_ops = control_flow_ops.group(
                    projection_ops, projected_missing_out_calibrated_op)

        # Regularization
        regularization = regularizers.calibrator_regularization(
            keypoints_outputs,
            name=signal_name + '_calibrator_regularization',
            **regularizer_amounts)
    return calibrated, projection_ops, regularization
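The missing-value handling above is a split/compute/re-join pattern: dynamic_partition separates missing entries from present ones, the calibration runs only on the present ones, and dynamic_stitch puts both groups back into their original slots. A minimal sketch of just that pattern (the calibration and the missing output below are stand-ins, not the lattice ops):

import tensorflow as tf

missing_input_value = -1.0
x = tf.constant([0.2, -1.0, 0.7, -1.0], dtype=tf.float32)

missing_mask = tf.cast(tf.equal(x, missing_input_value), tf.int32)
mask_indices = tf.dynamic_partition(tf.range(tf.shape(x)[0]), missing_mask, 2)
present, missing = tf.dynamic_partition(x, missing_mask, 2)

calibrated_present = 2.0 * present                      # stand-in calibration
calibrated_missing = tf.fill(tf.shape(missing), 9.0)    # stand-in missing output

calibrated = tf.dynamic_stitch(mask_indices,
                               [calibrated_present, calibrated_missing])
# calibrated == [0.4, 9.0, 1.4, 9.0], back in the original element order.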
Example No. 48
  def _sample_n(self, n, seed=None):
    with ops.control_dependencies(self._assertions):
      n = ops.convert_to_tensor(n, name="n")
      static_n = tensor_util.constant_value(n)
      n = int(static_n) if static_n is not None else n
      cat_samples = self.cat.sample(n, seed=seed)

      static_samples_shape = cat_samples.get_shape()
      if static_samples_shape.is_fully_defined():
        samples_shape = static_samples_shape.as_list()
        samples_size = static_samples_shape.num_elements()
      else:
        samples_shape = array_ops.shape(cat_samples)
        samples_size = array_ops.size(cat_samples)
      static_batch_shape = self.get_batch_shape()
      if static_batch_shape.is_fully_defined():
        batch_shape = static_batch_shape.as_list()
        batch_size = static_batch_shape.num_elements()
      else:
        batch_shape = self.batch_shape()
        batch_size = math_ops.reduce_prod(batch_shape)
      static_event_shape = self.get_event_shape()
      if static_event_shape.is_fully_defined():
        event_shape = np.array(static_event_shape.as_list(), dtype=np.int32)
      else:
        event_shape = self.event_shape()

      # Get indices into the raw cat sampling tensor.  We will
      # need these to stitch sample values back out after sampling
      # within the component partitions.
      samples_raw_indices = array_ops.reshape(
          math_ops.range(0, samples_size), samples_shape)

      # Partition the raw indices so that we can use
      # dynamic_stitch later to reconstruct the samples from the
      # known partitions.
      partitioned_samples_indices = data_flow_ops.dynamic_partition(
          data=samples_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)

      # Copy the batch indices n times, as we will need to know
      # these to pull out the appropriate rows within the
      # component partitions.
      batch_raw_indices = array_ops.reshape(
          array_ops.tile(math_ops.range(0, batch_size), [n]), samples_shape)

      # Explanation of the dynamic partitioning below:
      #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
      # Suppose partitions are:
      #     [1 1 0 0 1 1]
      # After partitioning, batch indices are cut as:
      #     [batch_indices[x] for x in 2, 3]
      #     [batch_indices[x] for x in 0, 1, 4, 5]
      # i.e.
      #     [1 1] and [0 0 0 0]
      # Now we sample n=2 from part 0 and n=4 from part 1.
      # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
      # and for part 1 we want samples from batch entries 0, 0, 0, 0
      #   (samples 0, 1, 2, 3).
      partitioned_batch_indices = data_flow_ops.dynamic_partition(
          data=batch_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)
      samples_class = [None for _ in range(self.num_components)]

      for c in range(self.num_components):
        n_class = array_ops.size(partitioned_samples_indices[c])
        seed = distribution_util.gen_new_seed(seed, "mixture")
        samples_class_c = self.components[c].sample(n_class, seed=seed)

        # Pull out the correct batch entries from each index.
        # To do this, we may have to flatten the batch shape.

        # For sample s, batch element b of component c, we get the
        # partitioned batch indices from
        # partitioned_batch_indices[c]; and shift each element by
        # the sample index.  The final lookup can be thought of as
        # a matrix gather along locations (s, b) in
        # samples_class_c where the n_class rows correspond to
        # samples within this component and the batch_size columns
        # correspond to batch elements within the component.
        #
        # Thus the lookup index is
        #   lookup[c, i] = batch_size * s[i] + b[c, i]
        # for i = 0 ... n_class[c] - 1.
        lookup_partitioned_batch_indices = (
            batch_size * math_ops.range(n_class) +
            partitioned_batch_indices[c])
        samples_class_c = array_ops.reshape(
            samples_class_c,
            array_ops.concat(([n_class * batch_size], event_shape), 0))
        samples_class_c = array_ops.gather(
            samples_class_c, lookup_partitioned_batch_indices,
            name="samples_class_c_gather")
        samples_class[c] = samples_class_c

      # Stitch back together the samples across the components.
      lhs_flat_ret = data_flow_ops.dynamic_stitch(
          indices=partitioned_samples_indices, data=samples_class)
      # Reshape back to proper sample, batch, and event shape.
      ret = array_ops.reshape(lhs_flat_ret,
                              array_ops.concat((samples_shape,
                                                self.event_shape()), 0))
      ret.set_shape(
          tensor_shape.TensorShape(static_samples_shape).concatenate(
              self.get_event_shape()))
      return ret
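The stitch at the end can be checked in isolation: with the comment's partitions [1 1 0 0 1 1], flat positions 2 and 3 go to component 0, positions 0, 1, 4, 5 go to component 1, and dynamic_stitch sends the per-component draws back to those positions. A toy sketch with stand-in sample values:

import tensorflow as tf

cat_samples = tf.constant([1, 1, 0, 0, 1, 1], dtype=tf.int32)
flat_positions = tf.range(6)
partitioned_samples_indices = tf.dynamic_partition(
    flat_positions, cat_samples, 2)   # [[2, 3], [0, 1, 4, 5]]
# Stand-ins for the per-component draws (100s from component 0, 200s from 1).
samples_class = [tf.constant([100, 101]),
                 tf.constant([200, 201, 202, 203])]
stitched = tf.dynamic_stitch(partitioned_samples_indices, samples_class)
# stitched == [200, 201, 100, 101, 202, 203]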
Example No. 49
def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is
  done through the `transform_fn` argument. If provided, the function is
  applied to each partitioned tensor of retrieved embeddings, colocated
  with the embeddings. This function will be called with a single `Tensor`
  argument of the same type as the `params` tensor and should return a
  `Tensor`. The shape of the argument will be the same as `params` except
  for the size of the first dimension. The first dimension of the result's
  shape must be the same size as the argument's.

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: An optional function to apply to each retrieved embedding.
      If max_norm is provided, transform_fn is applied to the norm-limited
      embeddings.

  Returns:
    See embedding_lookup for details.
  Raises:
    ValueError: If `params` is empty.
  """
  if params is None or params in ((), []):
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    ids = ops.convert_to_tensor(ids, name="ids")
    if np == 1 and (not transform_fn or ids.get_shape().ndims == 1):
      with ops.colocate_with(params[0]):
        result = _clip(_gather(params[0], ids, name=name), ids, max_norm)
        if transform_fn:
          result = transform_fn(result)
        return result
    else:
      # Flatten the ids. There are two cases where we need to do this.
      # - There is more than one params tensor.
      # - There is a transform_fn and ids is not statically known to be 1-D.
      #   We must flatten in this case because transform_fn expects a flat
      #   tensor of embeddings.
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = params[0].get_shape()[0]
        for p in xrange(1, np):
          dim_0_size += params[p].get_shape()[0]
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            if params[p].get_shape()[0].value is not None:
              dim_0_sizes.append(params[p].get_shape()[0].value)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(
            flat_ids // (ids_per_partition + 1),
            (flat_ids - extras) // ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        is_in_first_extras_partitions = math_ops.cast(p_assignments < extras,
                                                      flat_ids.dtype)
        new_ids = (is_in_first_extras_partitions * (flat_ids %
                                                    (ids_per_partition + 1)) +
                   (1 - is_in_first_extras_partitions) *
                   ((flat_ids - extras) % ids_per_partition))
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        pids = gather_ids[p]
        with ops.colocate_with(params[p]):
          result = _gather(params[p], pids)
          if transform_fn:
            # If transform_fn is provided, the clip_by_norm precedes
            # the transform and hence must be co-located. See below
            # for the counterpart if transform_fn is not provided.
            result = transform_fn(_clip(result, pids, max_norm))
        partitioned_result.append(result)
      # Stitch these back together
      ret = data_flow_ops.dynamic_stitch(
          pindices, partitioned_result, name=name)

      # Determine the static element shape.
      if transform_fn is None:
        element_shape_s = params[0].get_shape()[1:]
        for p in params[1:]:
          element_shape_s = element_shape_s.merge_with(p.get_shape()[1:])
      else:
        element_shape_s = ret.get_shape()[1:]

      # Compute the dynamic element shape.
      if element_shape_s.is_fully_defined():
        element_shape_d = element_shape_s
      elif transform_fn is None:
        # It's important that we compute params[0].shape on the right device
        # to avoid data motion.
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        element_shape_d = params_shape[1:]
      else:
        element_shape_d = array_ops.shape(ret)[1:]

      # Reshape to reverse the flattening of ids.
      ret = array_ops.reshape(ret,
                              array_ops.concat(
                                  [array_ops.shape(ids), element_shape_d], 0))

      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters
      # (in the case that transform_fn is None).
      ret.set_shape(ids.get_shape().concatenate(element_shape_s))
      if not transform_fn:
        # If transform_fn was provided, the clip_by_norm was done above.
        ret = _clip(ret, ids, max_norm)
      return ret
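As a hedged illustration of the transform_fn contract (any function that keeps the first dimension of the retrieved embedding block unchanged), a hypothetical row-wise L2 normalization would qualify; this is not a transform shipped with the library:

import tensorflow as tf

def l2_normalize_rows(embeddings):
  # Keeps the first dimension unchanged, as transform_fn must.
  return tf.math.l2_normalize(embeddings, axis=-1)

block = tf.constant([[3.0, 4.0], [0.0, 2.0]])
print(l2_normalize_rows(block))  # [[0.6, 0.8], [0.0, 1.0]]; shape preserved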
Example No. 50
def embedding_lookup(params,
                     ids,
                     partition_strategy="mod",
                     name=None,
                     validate_indices=True,
                     max_norm=None):
    """Looks up `ids` in a list of embedding tensors.

  This function is used to perform parallel lookups on the list of
  tensors in `params`.  It is a generalization of
  [`tf.gather()`](../../api_docs/python/array_ops.md#gather), where `params` is
  interpreted as a partitioning of a large embedding tensor.  `params` may be
  a `PartitionedVariable` as returned by using `tf.get_variable()` with a
  partitioner.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between
  the elements of `params` according to the `partition_strategy`.
  In all strategies, if the id space does not evenly divide the number of
  partitions, each of the first `(max_id + 1) % len(params)` partitions will
  be assigned one more id.

  If `partition_strategy` is `"mod"`, we assign each id to partition
  `p = id % len(params)`. For instance,
  13 ids are split across 5 partitions as:
  `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`

  If `partition_strategy` is `"div"`, we assign ids to partitions in a
  contiguous manner. In this case, 13 ids are split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`

  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

  Args:
    params: A single tensor representing the complete embedding tensor,
      or a list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors.  Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked
      up in `params`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`.
    name: A name for the operation (optional).
    validate_indices: Whether or not to validate gather indices.
    max_norm: If not None, embedding values are l2-normalized to the value of
     max_norm.

  Returns:
    A `Tensor` with the same type as the tensors in `params`.

  Raises:
    ValueError: If `params` is empty.
  """
    if params is None or params == []:  # pylint: disable=g-explicit-bool-comparison
        raise ValueError("Need at least one param")
    if isinstance(params, variables.PartitionedVariable):
        params = list(params)  # Iterate to get the underlying Variables.
    if not isinstance(params, list):
        params = [params]

    def maybe_normalize(x):
        if max_norm is not None:
            if x.get_shape().ndims is not None:
                ndims = x.get_shape().ndims
            else:
                ndims = array_ops.size(array_ops.shape(x))
            return clip_ops.clip_by_norm(x,
                                         max_norm,
                                         axes=list(range(1, ndims)))
        return x

    with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
        np = len(params)  # Number of partitions
        params = ops.convert_n_to_tensor_or_indexed_slices(params,
                                                           name="params")
        if np == 1:
            with ops.colocate_with(params[0]):
                # TODO(apassos): implement the sharded version as well.
                if isinstance(params[0],
                              resource_variable_ops.ResourceVariable):
                    ret = params[0].sparse_read(ids, name=name)
                else:
                    ret = array_ops.gather(params[0],
                                           ids,
                                           name=name,
                                           validate_indices=validate_indices)
            return maybe_normalize(ret)
        else:
            ids = ops.convert_to_tensor(ids, name="ids")
            flat_ids = array_ops.reshape(ids, [-1])
            original_indices = math_ops.range(array_ops.size(flat_ids))

            # Create p_assignments and set new_ids depending on the strategy.
            if partition_strategy == "mod":
                p_assignments = flat_ids % np
                new_ids = flat_ids // np
            elif partition_strategy == "div":
                # Compute num_total_ids as the sum of dim-0 of params, then assign to
                # partitions based on a constant number of ids per partition. Optimize
                # if we already know the full shape statically.
                dim_0_size = params[0].get_shape()[0]
                for p in xrange(1, np):
                    dim_0_size += params[p].get_shape()[0]
                if dim_0_size.value:
                    num_total_ids = constant_op.constant(
                        dim_0_size.value, flat_ids.dtype)
                else:
                    dim_0_sizes = []
                    for p in xrange(np):
                        if params[p].get_shape()[0].value is not None:
                            dim_0_sizes.append(params[p].get_shape()[0].value)
                        else:
                            with ops.colocate_with(params[p]):
                                dim_0_sizes.append(
                                    array_ops.shape(params[p])[0])
                    num_total_ids = math_ops.reduce_sum(
                        math_ops.cast(array_ops.stack(dim_0_sizes),
                                      flat_ids.dtype))
                ids_per_partition = num_total_ids // np
                extras = num_total_ids % np

                p_assignments = math_ops.maximum(
                    flat_ids // (ids_per_partition + 1),
                    (flat_ids - extras) // ids_per_partition)

                # Emulate a conditional using a boolean indicator tensor
                is_in_first_extras_partitions = math_ops.cast(
                    p_assignments < extras, flat_ids.dtype)
                new_ids = (is_in_first_extras_partitions *
                           (flat_ids % (ids_per_partition + 1)) +
                           (1 - is_in_first_extras_partitions) *
                           ((flat_ids - extras) % ids_per_partition))
            else:
                raise ValueError("Unrecognized partition strategy: " +
                                 partition_strategy)

            # Cast partition assignments to int32 for use in dynamic_partition.
            # There really should not be more than 2^32 partitions.
            p_assignments = math_ops.cast(p_assignments, dtypes.int32)
            # Partition list of ids based on assignments into np separate lists
            gather_ids = data_flow_ops.dynamic_partition(
                new_ids, p_assignments, np)
            # Similarly, partition the original indices.
            pindices = data_flow_ops.dynamic_partition(original_indices,
                                                       p_assignments, np)
            # Do np separate lookups, finding embeddings for plist[p] in params[p]
            partitioned_result = []
            for p in xrange(np):
                with ops.colocate_with(params[p]):
                    partitioned_result.append(
                        array_ops.gather(params[p],
                                         gather_ids[p],
                                         validate_indices=validate_indices))
            # Stitch these back together
            ret = data_flow_ops.dynamic_stitch(pindices,
                                               partitioned_result,
                                               name=name)
            # Reshape to reverse the flattening of ids.
            element_shape = params[0].get_shape()[1:]
            for p in params[1:]:
                element_shape = element_shape.merge_with(p.get_shape()[1:])
            if element_shape.is_fully_defined():
                ret = array_ops.reshape(
                    ret,
                    array_ops.concat_v2([array_ops.shape(ids), element_shape],
                                        0))
            else:
                # It's important that we compute params[0].shape on the right device
                # to avoid data motion.
                with ops.colocate_with(params[0]):
                    params_shape = array_ops.shape(params[0])
                ret = array_ops.reshape(
                    ret,
                    array_ops.concat_v2([
                        array_ops.shape(ids),
                        array_ops.slice(params_shape, [1], [-1])
                    ], 0))
            # output shape = ids.shape + params[*].shape[1:]
            # Normally the reshape is sufficient, but setting shape explicitly
            # teaches shape inference that params[1:].get_shape() matters.
            ret.set_shape(ids.get_shape().concatenate(element_shape))
            return maybe_normalize(ret)
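The docstring's partitioning example can be reproduced with a few lines of plain Python (13 ids, 5 partitions):

ids, n = list(range(13)), 5

# "mod": id i goes to partition i % n.
mod_split = [[i for i in ids if i % n == p] for p in range(n)]
# [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]

# "div": contiguous blocks; the first (13 % 5) = 3 partitions get one extra id.
ids_per_partition, extras = len(ids) // n, len(ids) % n
div_split, start = [], 0
for p in range(n):
    size = ids_per_partition + (1 if p < extras else 0)
    div_split.append(ids[start:start + size])
    start += size
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]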
Example No. 51
def embedding_lookup(params, ids, name=None):
  """Looks up `ids` in a list of embedding tensors.

  This function is used to perform parallel lookups on the list of
  tensors in `params`.  It is a generalization of
  [`tf.gather()`](../../api_docs/python/array_ops.md#gather), where `params` is
  interpreted as a partition of a larger embedding tensor.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between
  the elements of `params` by computing `p = id % len(params)`, and is
  then used to look up the slice `params[p][id // len(params), ...]`.

  The results of the lookup are then concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

  Args:
    params: A list of tensors with the same shape and type.
    ids: A `Tensor` with type `int32` containing the ids to be looked
      up in `params`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as the tensors in `params`.

  Raises:
    ValueError: If `params` is empty.
  """
  if not isinstance(params, list):
    params = [params]
  with ops.op_scope(params + [ids], name, "embedding_lookup") as name:
    if not params:
      raise ValueError("Need at least one param")
    np = len(params)  # Number of partitions
    params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    if np == 1:
      with ops.device(params[0].device):
        return array_ops.gather(params[0], ids, name=name)
    else:
      ids = ops.convert_to_tensor(ids, name="ids")
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(0, array_ops.size(flat_ids))
      # Compute flat_ids % partitions for each id
      ids_mod_p = flat_ids % np
      if ids_mod_p.dtype != types.int32:
        ids_mod_p = math_ops.cast(ids_mod_p, types.int32)
      # Partition single list of ids based on ids % np into np separate lists
      plist = data_flow_ops.dynamic_partition(flat_ids, ids_mod_p, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices, ids_mod_p,
                                                 np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        # TODO(agarwal): handle device allocations here and later in the
        # colocate code.
        gather_ids = plist[p] // np
        with ops.device(params[p].device):
          partitioned_result.append(array_ops.gather(params[p], gather_ids))
      # Stitch these back together
      ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
                                         name=name)
      # Reshape to reverse the flattening of ids.
      # It's important that we compute params[0].shape on the right device
      # to avoid data motion.
      with ops.device(params[0].device):
        params_shape = array_ops.shape(params[0])
      ret = array_ops.reshape(ret, array_ops.concat(0, [
          array_ops.shape(ids), array_ops.slice(params_shape, [1], [-1])]))
      # output shape = ids.shape + params[*].shape[1:]
      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters.
      element_shape = params[0].get_shape()[1:]
      for p in params[1:]:
        element_shape = element_shape.merge_with(p.get_shape()[1:])
      ret.set_shape(ids.get_shape().concatenate(element_shape))
      return ret
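The mod-strategy lookup above boils down to: id i lives in partition i % np at row i // np, which is what the docstring's params[p][id // len(params), ...] rule says. A quick check in plain Python for a few toy ids:

np_parts = 3
for i in [7, 2, 9, 4]:
    print(f"id {i} -> partition {i % np_parts}, row {i // np_parts}")
# id 7 -> partition 1, row 2
# id 2 -> partition 2, row 0
# id 9 -> partition 0, row 3
# id 4 -> partition 1, row 1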
Example No. 52
def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is
  done through the `transform_fn` argument. If provided, the function is
  applied to each partitioned tensor of retrieved embeddings, colocated
  with the embeddings. This function will be called with a single `Tensor`
  argument of the same type as the `params` tensor and should return a
  `Tensor`. The shape of the argument will be the same as `params` except
  for the size of the first dimension. The first dimension of the result's
  shape must be the same size as the argument's.

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: An optional function to apply to each retrieved embedding.

  Returns:
    See embedding_lookup for details.
  Raises:
    ValueError: If `params` is empty.
  """
  if params is None or params in ((), []):
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    ids = ops.convert_to_tensor(ids, name="ids")
    if np == 1 and (transform_fn is None or ids.get_shape().ndims == 1):
      with ops.colocate_with(params[0]):
        result = _gather_and_clip(params[0], ids, max_norm, name=name)
        if transform_fn is not None:
          result = transform_fn(result)
        return result
    else:
      # Flatten the ids. There are two cases where we need to do this.
      # - There is more than one params tensor.
      # - There is a transform_fn and ids is not statically known to be 1-D.
      #   We must flatten in this case because transform_fn expects a flat
      #   tensor of embeddings.
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = params[0].get_shape()[0]
        for p in xrange(1, np):
          dim_0_size += params[p].get_shape()[0]
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            if params[p].get_shape()[0].value is not None:
              dim_0_sizes.append(params[p].get_shape()[0].value)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(
            flat_ids // (ids_per_partition + 1),
            (flat_ids - extras) // ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        is_in_first_extras_partitions = math_ops.cast(p_assignments < extras,
                                                      flat_ids.dtype)
        new_ids = (is_in_first_extras_partitions * (flat_ids %
                                                    (ids_per_partition + 1)) +
                   (1 - is_in_first_extras_partitions) *
                   ((flat_ids - extras) % ids_per_partition))
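        # Worked example (added comment): with num_total_ids = 10 and np = 3,
        # ids_per_partition = 3 and extras = 1, so partition 0 holds ids 0-3
        # and partitions 1 and 2 hold three ids each.  For flat_id = 4:
        #   p_assignment = max(4 // 4, (4 - 1) // 3) = 1
        #   new_id       = (4 - 1) % 3 = 0        (first row of partition 1)
        # For flat_id = 2:
        #   p_assignment = max(2 // 4, (2 - 1) // 3) = 0
        #   new_id       = 2 % 4 = 2              (third row of partition 0)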
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          result = _gather_and_clip(params[p], gather_ids[p], max_norm)
          if transform_fn is not None:
            result = transform_fn(result)
          partitioned_result.append(result)
      # Stitch these back together
      ret = data_flow_ops.dynamic_stitch(
          pindices, partitioned_result, name=name)

      # Determine the static element shape.
      if transform_fn is None:
        element_shape_s = params[0].get_shape()[1:]
        for p in params[1:]:
          element_shape_s = element_shape_s.merge_with(p.get_shape()[1:])
      else:
        element_shape_s = ret.get_shape()[1:]

      # Compute the dynamic element shape.
      if element_shape_s.is_fully_defined():
        element_shape_d = element_shape_s
      elif transform_fn is None:
        # It's important that we compute params[0].shape on the right device
        # to avoid data motion.
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        element_shape_d = params_shape[1:]
      else:
        element_shape_d = array_ops.shape(ret)[1:]

      # Reshape to reverse the flattening of ids.
      ret = array_ops.reshape(ret,
                              array_ops.concat(
                                  [array_ops.shape(ids), element_shape_d], 0))

      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters
      # (in the case that transform_fn is None).
      ret.set_shape(ids.get_shape().concatenate(element_shape_s))
      return ret
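The sketch below (added for illustration; it uses the public TF 1.x API and
made-up values rather than the internal modules above) walks the same
partition / per-partition gather / dynamic_stitch round trip that the "mod"
branch of _embedding_lookup_and_transform performs.

import numpy as np
import tensorflow as tf

# Two partitions of an 8 x 3 embedding table under the "mod" strategy:
# partition 0 maps to global rows 0, 2, 4, 6 and partition 1 to rows 1, 3, 5, 7.
params = [
    tf.constant(np.arange(0, 12, dtype=np.float32).reshape(4, 3)),
    tf.constant(np.arange(100, 112, dtype=np.float32).reshape(4, 3)),
]
ids = tf.constant([5, 0, 3, 3, 7])
num_partitions = len(params)

flat_ids = tf.reshape(ids, [-1])
original_indices = tf.range(tf.size(flat_ids))

# "mod" strategy: global id i lives in partition i % num_partitions
# at local row i // num_partitions.
p_assignments = tf.cast(flat_ids % num_partitions, tf.int32)
new_ids = flat_ids // num_partitions

gather_ids = tf.dynamic_partition(new_ids, p_assignments, num_partitions)
pindices = tf.dynamic_partition(original_indices, p_assignments, num_partitions)
partitioned_result = [tf.gather(p, g) for p, g in zip(params, gather_ids)]

# dynamic_stitch restores the original id order.
ret = tf.dynamic_stitch(pindices, partitioned_result)

with tf.Session() as sess:
  print(sess.run(ret))  # rows for ids 5, 0, 3, 3, 7, in that order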
Example #53
  def _sample_n(self, n, seed=None):
    if self._use_static_graph:
      # This sampling approach is almost the same as the approach used by
      # `MixtureSameFamily`. The differences are due to having a list of
      # `Distribution` objects rather than a single object, and maintaining
      # random seed management that is consistent with the non-static code path.
      samples = []
      cat_samples = self.cat.sample(n, seed=seed)
      for c in range(self.num_components):
        seed = distribution_util.gen_new_seed(seed, "mixture")
        samples.append(self.components[c].sample(n, seed=seed))
      x = array_ops.stack(
          samples, -self._static_event_shape.ndims - 1)     # [n, B, k, E]
      npdt = x.dtype.as_numpy_dtype
      mask = array_ops.one_hot(
          indices=cat_samples,                              # [n, B]
          depth=self._num_components,                       # == k
          on_value=np.ones([], dtype=npdt),
          off_value=np.zeros([], dtype=npdt))               # [n, B, k]
      mask = distribution_utils.pad_mixture_dimensions(
          mask, self, self._cat,
          self._static_event_shape.ndims)                   # [n, B, k, [1]*e]
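      # Shape bookkeeping (added comment): e.g. with n = 2 samples, a scalar
      # batch shape, k = 3 components and event shape E = [2], x above is
      # [2, 3, 2], the one-hot mask is [2, 3] and is padded to [2, 3, 1];
      # the reduce_sum below contracts the component axis, leaving
      # [2, 2] = [n] + E.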
      return math_ops.reduce_sum(
          x * mask,
          axis=-1 - self._static_event_shape.ndims)         # [n, B, E]

    with ops.control_dependencies(self._assertions):
      n = ops.convert_to_tensor(n, name="n")
      static_n = tensor_util.constant_value(n)
      n = int(static_n) if static_n is not None else n
      cat_samples = self.cat.sample(n, seed=seed)

      static_samples_shape = cat_samples.get_shape()
      if static_samples_shape.is_fully_defined():
        samples_shape = static_samples_shape.as_list()
        samples_size = static_samples_shape.num_elements()
      else:
        samples_shape = array_ops.shape(cat_samples)
        samples_size = array_ops.size(cat_samples)
      static_batch_shape = self.batch_shape
      if static_batch_shape.is_fully_defined():
        batch_shape = static_batch_shape.as_list()
        batch_size = static_batch_shape.num_elements()
      else:
        batch_shape = self.batch_shape_tensor()
        batch_size = math_ops.reduce_prod(batch_shape)
      static_event_shape = self.event_shape
      if static_event_shape.is_fully_defined():
        event_shape = np.array(static_event_shape.as_list(), dtype=np.int32)
      else:
        event_shape = self.event_shape_tensor()

      # Get indices into the raw cat sampling tensor. We will
      # need these to stitch sample values back out after sampling
      # within the component partitions.
      samples_raw_indices = array_ops.reshape(
          math_ops.range(0, samples_size), samples_shape)

      # Partition the raw indices so that we can use
      # dynamic_stitch later to reconstruct the samples from the
      # known partitions.
      partitioned_samples_indices = data_flow_ops.dynamic_partition(
          data=samples_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)

      # Copy the batch indices n times, as we will need to know
      # these to pull out the appropriate rows within the
      # component partitions.
      batch_raw_indices = array_ops.reshape(
          array_ops.tile(math_ops.range(0, batch_size), [n]), samples_shape)

      # Explanation of the dynamic partitioning below:
      #   Suppose the flattened batch indices are [0, 1, 0, 1, 0, 1]
      #   and the cat samples (partitions) are    [1, 1, 0, 0, 1, 1].
      # dynamic_partition groups the batch indices by partition value:
      #     part 0 gets [batch_indices[x] for x in (2, 3)]       i.e. [0, 1]
      #     part 1 gets [batch_indices[x] for x in (0, 1, 4, 5)] i.e. [0, 1, 0, 1]
      # Now we sample n=2 from part 0 and n=4 from part 1.
      # For part 0 we want samples for batch entries 0, 1 (its samples 0, 1),
      # and for part 1 we want samples for batch entries 0, 1, 0, 1
      #   (its samples 0, 1, 2, 3).
      partitioned_batch_indices = data_flow_ops.dynamic_partition(
          data=batch_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)
      samples_class = [None for _ in range(self.num_components)]

      for c in range(self.num_components):
        n_class = array_ops.size(partitioned_samples_indices[c])
        seed = distribution_util.gen_new_seed(seed, "mixture")
        samples_class_c = self.components[c].sample(n_class, seed=seed)

        # Pull out the correct batch entries from each index.
        # To do this, we may have to flatten the batch shape.

        # For sample s, batch element b of component c, we get the
        # partitioned batch indices from
        # partitioned_batch_indices[c]; and shift each element by
        # the sample index. The final lookup can be thought of as
        # a matrix gather along locations (s, b) in
        # samples_class_c where the n_class rows correspond to
        # samples within this component and the batch_size columns
        # correspond to batch elements within the component.
        #
        # Thus the lookup index is
        #   lookup[c, i] = batch_size * s[i] + b[c, i]
        # for i = 0 ... n_class[c] - 1.
        lookup_partitioned_batch_indices = (
            batch_size * math_ops.range(n_class) +
            partitioned_batch_indices[c])
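        # Worked example (added comment): with batch_size = 2, n_class = 3 and
        # partitioned_batch_indices[c] = [1, 0, 1], the lookup indices are
        # 2 * [0, 1, 2] + [1, 0, 1] = [1, 2, 5], i.e. (sample 0, batch 1),
        # (sample 1, batch 0) and (sample 2, batch 1) in the reshaped
        # [n_class * batch_size, ...] tensor built below.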
        samples_class_c = array_ops.reshape(
            samples_class_c,
            array_ops.concat([[n_class * batch_size], event_shape], 0))
        samples_class_c = array_ops.gather(
            samples_class_c, lookup_partitioned_batch_indices,
            name="samples_class_c_gather")
        samples_class[c] = samples_class_c

      # Stitch back together the samples across the components.
      lhs_flat_ret = data_flow_ops.dynamic_stitch(
          indices=partitioned_samples_indices, data=samples_class)
      # Reshape back to proper sample, batch, and event shape.
      ret = array_ops.reshape(lhs_flat_ret,
                              array_ops.concat([samples_shape,
                                                self.event_shape_tensor()], 0))
      ret.set_shape(
          tensor_shape.TensorShape(static_samples_shape).concatenate(
              self.event_shape))
      return ret
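A minimal sketch (added for illustration; public TF 1.x API, made-up values) of
the index round trip used in the non-static path above: partition the flat
sample positions by the categorical draws, produce per-component values, then
dynamic_stitch them back into the original sample order.

import tensorflow as tf

num_components = 2
cat_samples = tf.constant([1, 0, 1, 1, 0])             # component chosen per draw
samples_raw_indices = tf.range(tf.size(cat_samples))   # [0, 1, 2, 3, 4]

partitioned_indices = tf.dynamic_partition(
    samples_raw_indices, cat_samples, num_components)

# Stand-ins for self.components[c].sample(n_class): component 0 emits small
# values, component 1 adds 10, so the stitched order is easy to verify.
samples_class = [
    tf.cast(partitioned_indices[c], tf.float32) + 10.0 * c
    for c in range(num_components)
]

stitched = tf.dynamic_stitch(partitioned_indices, samples_class)
with tf.Session() as sess:
  print(sess.run(stitched))  # [10., 1., 12., 13., 4.]

Positions 0, 2 and 3 of the stitched result came from component 1 and positions
1 and 4 from component 0, matching cat_samples.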