def lookup(self, keys, name=None):
  if keys.dtype != self._key_dtype:
    raise TypeError('Signature mismatch. Keys must be dtype %s, got %s.' %
                    (self._key_dtype, keys.dtype))
  self._check_keys(keys)
  num_shards = self._num_shards
  if num_shards == 1:
    return self._table_shards[0].lookup(keys, name=name)

  shard_indices = self._shard_indices(keys)
  # TODO(andreasst): support 'keys' that are not vectors
  key_shards = data_flow_ops.dynamic_partition(keys, shard_indices,
                                               num_shards)
  value_shards = [
      self._table_shards[i].lookup(key_shards[i], name=name)
      for i in range(num_shards)
  ]

  num_keys = keys.get_shape().dims[0]
  original_indices = math_ops.range(num_keys)
  partitioned_indices = data_flow_ops.dynamic_partition(original_indices,
                                                        shard_indices,
                                                        num_shards)
  result = data_flow_ops.dynamic_stitch(partitioned_indices, value_shards)
  result.set_shape(
      tensor_shape.TensorShape([num_keys]).concatenate(self._value_shape))
  return result
def scatter_update(cls, factor, indices, values, sharding_func, name=None):
  """Helper function for doing sharded scatter update."""
  assert isinstance(factor, list)
  if len(factor) == 1:
    with ops.colocate_with(factor[0]):
      # TODO(agarwal): assign instead of scatter update for full batch update.
      return state_ops.scatter_update(factor[0], indices, values,
                                      name=name).op
  else:
    num_shards = len(factor)
    assignments, new_ids = sharding_func(indices)
    assert assignments is not None
    assignments = math_ops.cast(assignments, dtypes.int32)
    sharded_ids = data_flow_ops.dynamic_partition(new_ids, assignments,
                                                  num_shards)
    sharded_values = data_flow_ops.dynamic_partition(values, assignments,
                                                     num_shards)
    updates = []
    for i in xrange(num_shards):
      updates.append(
          state_ops.scatter_update(factor[i], sharded_ids[i],
                                   sharded_values[i]))
    return control_flow_ops.group(*updates, name=name)
def insert(self, keys, values, name=None):
  num_shards = self._num_shards
  if num_shards == 1:
    return self._table_shards[0].insert(keys, values, name=name)

  shard_indices = self._shard_indices(keys)
  # TODO(andreasst): support 'keys' that are not vectors
  key_shards = data_flow_ops.dynamic_partition(keys, shard_indices,
                                               num_shards)
  value_shards = data_flow_ops.dynamic_partition(values, shard_indices,
                                                 num_shards)
  return_values = [
      self._table_shards[i].insert(key_shards[i], value_shards[i], name=name)
      for i in range(num_shards)
  ]
  return control_flow_ops.group(*return_values)
def _make_per_class_queues(data, labels, num_classes, queue_capacity,
                           threads_per_queue):
  """Creates per-class-queues based on data and labels."""
  # Create one queue per class.
  queues = []
  per_data_shape = data.get_shape().with_rank_at_least(1)[1:]
  per_data_shape.assert_is_fully_defined()
  for i in range(num_classes):
    q = data_flow_ops.FIFOQueue(capacity=queue_capacity,
                                shapes=per_data_shape,
                                dtypes=[data.dtype],
                                name='stratified_sample_class%d_queue' % i)
    logging_ops.scalar_summary('queue/stratified_sample_class%d' % i,
                               q.size())
    queues.append(q)

  # Partition tensors according to labels.
  partitions = data_flow_ops.dynamic_partition(data, labels, num_classes)

  # Enqueue each tensor on the per-class-queue.
  for i in range(num_classes):
    enqueue_op = queues[i].enqueue_many(partitions[i])
    queue_runner.add_queue_runner(queue_runner.QueueRunner(
        queues[i], [enqueue_op] * threads_per_queue))

  return queues
def testRaggedSegmentStack(self,
                           data,
                           partitions,
                           num_partitions,
                           expected,
                           data_ragged_rank=None,
                           segment_ids_ragged_rank=None,
                           expected_ragged_rank=None):
  for seg_dtype in [dtypes.int32, dtypes.int64]:
    data_tensor = ragged_factory_ops.constant(
        data, row_splits_dtype=seg_dtype, ragged_rank=data_ragged_rank)
    segment_ids_tensor = ragged_factory_ops.constant(
        partitions,
        dtype=seg_dtype,
        row_splits_dtype=seg_dtype,
        ragged_rank=segment_ids_ragged_rank)
    expected_tensor = ragged_factory_ops.constant(
        expected,
        row_splits_dtype=seg_dtype,
        ragged_rank=expected_ragged_rank)
    result = ragged_array_ops.stack_dynamic_partitions(
        data_tensor, segment_ids_tensor, num_partitions)
    self.assertAllEqual(result, expected_tensor)

    # Check that it's equivalent to tf.stack(dynamic_partition(...)),
    # where applicable.
    if (data_ragged_rank == 0 and segment_ids_ragged_rank == 0 and
        seg_dtype == dtypes.int32):
      equiv = ragged_concat_ops.stack(
          data_flow_ops.dynamic_partition(data_tensor, segment_ids_tensor,
                                          num_partitions))
      self.assertAllEqual(result, self.evaluate(equiv).to_list())
def testScalarIndexOutOfRange(self):
  with self.test_session() as sess:
    bad = 17
    data = np.zeros(5)
    partitions = data_flow_ops.dynamic_partition(data, bad, num_partitions=7)
    with self.assertRaisesOpError(r"partitions = 17 is not in \[0, 7\)"):
      sess.run(partitions)
def _make_per_class_queues(tensor_list, labels, num_classes, queue_capacity,
                           threads_per_queue):
  """Creates per-class-queues based on data and labels."""
  # Create one queue per class.
  queues = []
  data_shapes = []
  data_dtypes = []
  for data_tensor in tensor_list:
    per_data_shape = data_tensor.get_shape().with_rank_at_least(1)[1:]
    per_data_shape.assert_is_fully_defined()
    data_shapes.append(per_data_shape)
    data_dtypes.append(data_tensor.dtype)

  for i in range(num_classes):
    q = data_flow_ops.FIFOQueue(
        capacity=queue_capacity,
        shapes=data_shapes,
        dtypes=data_dtypes,
        name="stratified_sample_class%d_queue" % i)
    logging_ops.scalar_summary(
        "queue/%s/stratified_sample_class%d" % (q.name, i), q.size())
    queues.append(q)

  # Partition tensors according to labels. `partitions` is a list of lists,
  # of size num_classes X len(tensor_list). The number of tensors in
  # partition `i` should be the same for all tensors.
  all_partitions = [
      data_flow_ops.dynamic_partition(data, labels, num_classes)
      for data in tensor_list
  ]
  partitions = [[cur_partition[i] for cur_partition in all_partitions]
                for i in range(num_classes)]

  # Enqueue each tensor on the per-class-queue.
  for i in range(num_classes):
    enqueue_op = queues[i].enqueue_many(partitions[i])
    queue_runner.add_queue_runner(
        queue_runner.QueueRunner(queues[i],
                                 [enqueue_op] * threads_per_queue))

  return queues
def testHigherRank(self):
  np.random.seed(7)
  with self.test_session(use_gpu=True) as sess:
    for n in 2, 3:
      for shape in (4,), (4, 5), (4, 5, 2):
        partitions = np.random.randint(n, size=np.prod(shape)).reshape(shape)
        for extra_shape in (), (6,), (6, 7):
          data = np.random.randn(*(shape + extra_shape))
          partitions_t = constant_op.constant(partitions, dtype=dtypes.int32)
          data_t = constant_op.constant(data)
          outputs = data_flow_ops.dynamic_partition(
              data_t, partitions_t, num_partitions=n)
          self.assertEqual(n, len(outputs))
          outputs_val = sess.run(outputs)
          for i, output in enumerate(outputs_val):
            self.assertAllEqual(output, data[partitions == i])

          # Test gradients
          outputs_grad = [7 * output for output in outputs_val]
          grads = gradients_impl.gradients(outputs, [data_t, partitions_t],
                                           outputs_grad)
          self.assertEqual(grads[1], None)  # Partitions has no gradients
          self.assertAllEqual(7 * data, sess.run(grads[0]))
def testScalarIndexOutOfRange(self):
  # GPU kernels don't throw exceptions.
  with self.cached_session(use_gpu=False):
    bad = 17
    data = np.zeros(5)
    partitions = data_flow_ops.dynamic_partition(data, bad, num_partitions=7)
    with self.assertRaisesOpError(r"partitions = 17 is not in \[0, 7\)"):
      self.evaluate(partitions)
def testErrorIndexOutOfRange(self):
  with self.test_session() as sess:
    data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8],
                                 [9, 10, 11], [12, 13, 14]])
    indices = constant_op.constant([0, 2, 99, 2, 2])
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=4)
    with self.assertRaisesOpError(r"partitions\[2\] = 99 is not in \[0, 4\)"):
      sess.run(partitions)
def insert(self, keys, values, name=None):
  """Inserts `keys` in a table."""
  self._check_keys(keys)
  num_shards = self._num_shards
  if num_shards == 1:
    return self._table_shards[0].insert(keys, values, name=name)

  shard_indices = self._shard_indices(keys)
  key_shards = data_flow_ops.dynamic_partition(keys, shard_indices,
                                               num_shards)
  value_shards = data_flow_ops.dynamic_partition(values, shard_indices,
                                                 num_shards)
  return_values = [
      self._table_shards[i].insert(key_shards[i], value_shards[i], name=name)
      for i in range(num_shards)
  ]
  return control_flow_ops.group(*return_values)
def testCUBBug(self):
  x = constant_op.constant(np.random.randn(3072))
  inds = [0]*189 + [1]*184 + [2]*184 + [3]*191 + [4]*192 + [5]*195 + [6]*195
  inds += [7]*195 + [8]*188 + [9]*195 + [10]*188 + [11]*202 + [12]*194
  inds += [13]*194 + [14]*194 + [15]*192
  self.assertEqual(len(inds), x.shape[0])
  partitioned = data_flow_ops.dynamic_partition(x, inds, 16)
  with self.test_session() as sess:
    res = sess.run(partitioned)
  self.assertEqual(res[-1].shape[0], 192)
def testMultiGPU(self):
  device_list = config.list_logical_devices("GPU")
  results = []
  for device in device_list:
    with ops.device(device.name):
      data = constant_op.constant(np.zeros((1000,)))
      partitions = constant_op.constant(np.arange(1000, dtype=np.int32) % 10)
      result = data_flow_ops.dynamic_partition(data, partitions, 10)
      results.append(self.evaluate(result))
  if device_list:
    self.assertAllEqual(results, np.zeros((len(device_list), 10, 100)))
def _decompose_indexed_slices(self, indexed_slices):
  """Decompose a global `IndexedSlices` into a list of per-variable ones."""
  per_var_indices, partition_assignments = self._decompose_indices(
      indexed_slices.indices)
  per_var_values = data_flow_ops.dynamic_partition(
      indexed_slices.values, partition_assignments, len(self._variables))
  return [
      indexed_slices_lib.IndexedSlices(
          values=per_var_values[i], indices=per_var_indices[i])
      for i in range(len(self._variables))
  ]
def testErrorIndexOutOfRange(self):
  # GPU kernels don't throw exceptions.
  with self.cached_session(use_gpu=False):
    data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8],
                                 [9, 10, 11], [12, 13, 14]])
    indices = constant_op.constant([0, 2, 99, 2, 2])
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=4)
    with self.assertRaisesOpError(r"partitions\[2\] = 99 is not in \[0, 4\)"):
      self.evaluate(partitions)
def testEmptyPartitions(self):
  data_list = []
  indices_list = []
  with self.test_session(use_gpu=True) as sess:
    data = constant_op.constant(data_list, dtype=dtypes.float32)
    indices = constant_op.constant(indices_list, dtype=dtypes.int32)
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=2)
    partition_vals = sess.run(partitions)

  self.assertAllEqual([], partition_vals[0])
  self.assertAllEqual([], partition_vals[1])
def testSimpleComplex(self):
  data_list = [1 + 2j, 3 + 4j, 5 + 6j, 7 + 8j]
  indices_list = [1, 0, 1, 0]
  with self.test_session(use_gpu=True) as sess:
    data = constant_op.constant(data_list, dtype=dtypes.complex64)
    indices = constant_op.constant(indices_list, dtype=dtypes.int32)
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=2)
    partition_vals = sess.run(partitions)

  self.assertAllEqual([3 + 4j, 7 + 8j], partition_vals[0])
  self.assertAllEqual([1 + 2j, 5 + 6j], partition_vals[1])
def testEmptyPartitions(self):
  data_list = []
  indices_list = []
  with self.test_session(use_gpu=True) as sess:
    data = constant_op.constant(data_list, dtype=dtypes.float32)
    indices = constant_op.constant(indices_list, dtype=dtypes.int32)
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=2)
    partition_vals = sess.run(partitions)

  self.assertEqual(2, len(partition_vals))
  self.assertAllEqual([], partition_vals[0])
  self.assertAllEqual([], partition_vals[1])
def _partition(data, partition_index, shard_num):
  """Shards keys into `shard_num` partitions.

  Args:
    data: keys or values, usually the IDs of dynamic features.
    partition_index: partition indices.
    shard_num: number of partitions.

  Returns:
    A pair of tensor lists: (partition results, partition indices).
  """
  if shard_num <= 1:
    return [data], None
  with ops.colocate_with(data, ignore_existing=True):
    partitions = data_flow_ops.dynamic_partition(data, partition_index,
                                                 shard_num)
    indices = data_flow_ops.dynamic_partition(
        math_ops.range(array_ops.shape(data)[0]),
        math_ops.cast(partition_index, dtypes.int32), shard_num)
  return partitions, indices
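# A standalone sketch of the sharding pattern `_partition` implements, using
# the public TF2 eager API. The ids and the `% 3` routing rule below are
# illustrative assumptions, not values from the source.
import tensorflow as tf

ids = tf.constant([10, 11, 12, 13, 14, 15], dtype=tf.int64)
partition_index = tf.cast(ids % 3, tf.int32)  # route each id to a shard

# Keys routed to each of the 3 shards.
partitions = tf.dynamic_partition(ids, partition_index, 3)
# Original positions routed the same way, so per-shard results can later be
# stitched back into the input order with tf.dynamic_stitch.
indices = tf.dynamic_partition(
    tf.range(tf.shape(ids)[0]), partition_index, 3)

print([p.numpy() for p in partitions])  # [[12 15], [10 13], [11 14]]
print([i.numpy() for i in indices])     # [[2 5], [0 3], [1 4]]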
def testSimpleComplex(self):
  data_list = [1 + 2j, 3 + 4j, 5 + 6j, 7 + 8j]
  indices_list = [1, 0, 1, 0]
  with self.test_session(use_gpu=True) as sess:
    data = constant_op.constant(data_list, dtype=dtypes.complex64)
    indices = constant_op.constant(indices_list, dtype=dtypes.int32)
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=2)
    partition_vals = sess.run(partitions)

  self.assertEqual(2, len(partition_vals))
  self.assertAllEqual([3 + 4j, 7 + 8j], partition_vals[0])
  self.assertAllEqual([1 + 2j, 5 + 6j], partition_vals[1])
def scatter_update(cls, factor, indices, values, sharding_func):
  """Helper function for doing sharded scatter update."""
  assert isinstance(factor, list)
  if len(factor) == 1:
    with ops.colocate_with(factor[0]):
      # TODO(agarwal): assign instead of scatter update for full batch update.
      return state_ops.scatter_update(factor[0], indices, values).op
  else:
    num_shards = len(factor)
    assignments, new_ids = sharding_func(indices)
    assert assignments is not None
    assignments = math_ops.cast(assignments, dtypes.int32)
    sharded_ids = data_flow_ops.dynamic_partition(new_ids, assignments,
                                                  num_shards)
    sharded_values = data_flow_ops.dynamic_partition(values, assignments,
                                                     num_shards)
    updates = []
    for i in xrange(num_shards):
      updates.append(
          state_ops.scatter_update(factor[i], sharded_ids[i],
                                   sharded_values[i]))
    return control_flow_ops.group(*updates)
def _DynamicPartitionGrads(op, *grads):
  """Gradients for DynamicPartition."""
  data = op.inputs[0]
  indices = op.inputs[1]
  num_partitions = op.get_attr("num_partitions")

  prefix_shape = array_ops.shape(indices)
  original_indices = array_ops.reshape(
      math_ops.range(math_ops.reduce_prod(prefix_shape)), prefix_shape)
  partitioned_indices = data_flow_ops.dynamic_partition(
      original_indices, indices, num_partitions)
  reconstructed = data_flow_ops.dynamic_stitch(partitioned_indices, grads)
  reconstructed = array_ops.reshape(reconstructed, array_ops.shape(data))
  return [reconstructed, None]
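# The gradient above relies on `dynamic_stitch` being the inverse of
# `dynamic_partition`: partitioning a `range` of flat positions with the same
# `indices` records where every element went, so stitching the upstream
# gradients by those positions rebuilds a tensor aligned with `data`. A small
# TF2 eager sketch of that round trip, with illustrative values:
import tensorflow as tf

data = tf.constant([0., 10., 20., 30., 40.])
parts = tf.constant([1, 0, 1, 0, 1])

pieces = tf.dynamic_partition(data, parts, 2)            # [[10 30], [0 20 40]]
positions = tf.dynamic_partition(tf.range(5), parts, 2)  # [[1 3], [0 2 4]]

# Stitching the pieces back by their recorded positions recovers `data`; the
# gradient does exactly this with the upstream grads in place of `pieces`.
rebuilt = tf.dynamic_stitch(positions, pieces)
assert (rebuilt.numpy() == data.numpy()).all()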
def lookup(self, keys, name=None):
  if keys.dtype.base_dtype != self._key_dtype:
    raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                    (self._key_dtype, keys.dtype))
  self._check_keys(keys)
  num_shards = self._num_shards
  if num_shards == 1:
    return self._table_shards[0].lookup(keys, name=name)

  shard_indices = self._shard_indices(keys)
  key_shards = data_flow_ops.dynamic_partition(keys, shard_indices,
                                               num_shards)
  value_shards = [
      self._table_shards[i].lookup(key_shards[i], name=name)
      for i in range(num_shards)
  ]

  num_keys = array_ops.shape(keys)[0]
  original_indices = math_ops.range(num_keys)
  partitioned_indices = data_flow_ops.dynamic_partition(original_indices,
                                                        shard_indices,
                                                        num_shards)
  return data_flow_ops.dynamic_stitch(partitioned_indices, value_shards)
def testEmptyDataTwoDimensional(self):
  data_list = [[], []]
  indices_list = [0, 1]
  with self.test_session(use_gpu=True) as sess:
    data = constant_op.constant(data_list, dtype=dtypes.float32)
    indices = constant_op.constant(indices_list, dtype=dtypes.int32)
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=3)
    partition_vals = sess.run(partitions)

  self.assertAllEqual([[]], partition_vals[0])
  self.assertAllEqual([[]], partition_vals[1])
  self.assertAllEqual(np.array([], dtype=np.float64).reshape(0, 0),
                      partition_vals[2])
def testHigherRankIndexOutOfRange(self):
  with self.test_session() as sess:
    shape = (2, 3)
    indices = array_ops.placeholder(shape=shape, dtype=np.int32)
    data = np.zeros(shape + (5,))
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=7)
    for i in xrange(2):
      for j in xrange(3):
        bad = np.zeros(shape, dtype=np.int32)
        bad[i, j] = 17
        with self.assertRaisesOpError(
            r"partitions\[%d,%d\] = 17 is not in \[0, 7\)" % (i, j)):
          sess.run(partitions, feed_dict={indices: bad})
def testEmptyDataTwoDimensional(self):
  data_list = [[], []]
  indices_list = [0, 1]
  with self.test_session(use_gpu=True) as sess:
    data = constant_op.constant(data_list, dtype=dtypes.float32)
    indices = constant_op.constant(indices_list, dtype=dtypes.int32)
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=3)
    partition_vals = sess.run(partitions)

  self.assertEqual(3, len(partition_vals))
  self.assertAllEqual([[]], partition_vals[0])
  self.assertAllEqual([[]], partition_vals[1])
  self.assertAllEqual(np.array([], dtype=np.float64).reshape(0, 0),
                      partition_vals[2])
def testEmptyParts(self):
  data_list = [1, 2, 3, 4]
  indices_list = [1, 3, 1, 3]
  with self.session(use_gpu=True) as sess:
    data = constant_op.constant(data_list, dtype=dtypes.float32)
    indices = constant_op.constant(indices_list, dtype=dtypes.int32)
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=4)
    partition_vals = sess.run(partitions)

  self.assertEqual(4, len(partition_vals))
  self.assertAllEqual([], partition_vals[0])
  self.assertAllEqual([1, 3], partition_vals[1])
  self.assertAllEqual([], partition_vals[2])
  self.assertAllEqual([2, 4], partition_vals[3])
def testEmptyParts(self):
  data_list = [1, 2, 3, 4]
  indices_list = [1, 3, 1, 3]
  with self.session():
    data = constant_op.constant(data_list, dtype=dtypes.float32)
    indices = constant_op.constant(indices_list, dtype=dtypes.int32)
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=4)
    partition_vals = self.evaluate(partitions)

  self.assertEqual(4, len(partition_vals))
  self.assertAllEqual([], partition_vals[0])
  self.assertAllEqual([1, 3], partition_vals[1])
  self.assertAllEqual([], partition_vals[2])
  self.assertAllEqual([2, 4], partition_vals[3])
def testLargeOneDimensional(self):
  num = 100000
  data_list = [x for x in range(num)]
  indices_list = [x % 2 for x in range(num)]
  part1 = [x for x in range(num) if x % 2 == 0]
  part2 = [x for x in range(num) if x % 2 == 1]
  with self.test_session(use_gpu=True) as sess:
    data = constant_op.constant(data_list, dtype=dtypes.float32)
    indices = constant_op.constant(indices_list, dtype=dtypes.int32)
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=2)
    partition_vals = sess.run(partitions)

  self.assertAllEqual(part1, partition_vals[0])
  self.assertAllEqual(part2, partition_vals[1])
def testLargeOneDimensional(self):
  num = 100000
  data_list = [x for x in range(num)]
  indices_list = [x % 2 for x in range(num)]
  part1 = [x for x in range(num) if x % 2 == 0]
  part2 = [x for x in range(num) if x % 2 == 1]
  with self.test_session(use_gpu=True) as sess:
    data = constant_op.constant(data_list, dtype=dtypes.float32)
    indices = constant_op.constant(indices_list, dtype=dtypes.int32)
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=2)
    partition_vals = sess.run(partitions)

  self.assertEqual(2, len(partition_vals))
  self.assertAllEqual(part1, partition_vals[0])
  self.assertAllEqual(part2, partition_vals[1])
def _get_partitioned_update_ops(self, v_num, num_partitions_by_var,
                                p_assignments_by_var, gather_ids_by_var,
                                weights, full_update, p_assignments,
                                num_partitions):
  """Get updates for partitioned variables."""
  num_partitions = num_partitions_by_var[v_num]
  p_assignments = p_assignments_by_var[v_num]
  gather_ids = gather_ids_by_var[v_num]
  updates = data_flow_ops.dynamic_partition(full_update, p_assignments,
                                            num_partitions)
  update_ops = []
  for p in range(num_partitions):
    with ops.colocate_with(weights[p]):
      result = state_ops.scatter_add(weights[p], gather_ids[p], updates[p])
    update_ops.append(result)
  return update_ops
def sample(self, n_samples):
  n_samples = int(n_samples)
  if self.is_train:
    cat_probs = self._cat(n_samples)  # n x c
    agg_mu = tf.reduce_sum(
        tf.expand_dims(cat_probs, 2) * self._mu, axis=1)  # n x d
    agg_var = tf.reduce_sum(
        tf.expand_dims(cat_probs, 2) * self._var, axis=1)  # n x d
    raw = tf.random_normal([n_samples, self.dim])
    ret = agg_mu + tf.sqrt(agg_var) * raw  # n x d
    # cat_probs = self._cat(n_samples)  # n x c
    # samples_class = [None for _ in range(self.n_components)]
    # for c in range(self.n_components):
    #   raw = tf.random_normal([n_samples, self.dim])
    #   samples_class_c = self._mu[c] + raw * tf.sqrt(self._var[c])
    #   # tf.matmul(raw, tf.transpose(self._scale[c]))
    #   samples_class[c] = samples_class_c
    # samples_class = tf.stack(samples_class)  # c x n x d
    # ret = tf.reduce_sum(
    #     tf.expand_dims(cat_probs, 2) *
    #     tf.transpose(samples_class, [1, 0, 2]), axis=1)
  else:
    cat_samples = self._cat.sample(n_samples)  # n x 1
    samples_raw_indices = array_ops.reshape(
        math_ops.range(0, n_samples), cat_samples.get_shape().as_list())
    partitioned_samples_indices = data_flow_ops.dynamic_partition(
        data=samples_raw_indices,
        partitions=cat_samples,
        num_partitions=self.n_components)
    samples_class = [None for _ in range(self.n_components)]
    for c in range(self.n_components):
      n_class = array_ops.size(partitioned_samples_indices[c])
      raw = tf.random_normal([n_class, self.dim])
      samples_class_c = self._mu[c] + raw * tf.sqrt(self._var[c])
      samples_class[c] = samples_class_c
    # Stitch back together the samples across the components.
    ret = data_flow_ops.dynamic_stitch(
        indices=partitioned_samples_indices, data=samples_class)
    ret.set_shape((int(n_samples), self.dim))
  return ret
def testSimpleOneDimensional(self):
  with self.test_session() as sess:
    data = constant_op.constant([0, 13, 2, 39, 4, 17])
    indices = constant_op.constant([0, 0, 2, 3, 2, 1])
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=4)
    partition_vals = sess.run(partitions)

  self.assertAllEqual([0, 13], partition_vals[0])
  self.assertAllEqual([17], partition_vals[1])
  self.assertAllEqual([2, 4], partition_vals[2])
  self.assertAllEqual([39], partition_vals[3])
  # Vector data input to DynamicPartition results in
  # `num_partitions` vectors of unknown length.
  self.assertEqual([None], partitions[0].get_shape().as_list())
  self.assertEqual([None], partitions[1].get_shape().as_list())
  self.assertEqual([None], partitions[2].get_shape().as_list())
  self.assertEqual([None], partitions[3].get_shape().as_list())
def testScalarPartitions(self):
  data_list = [10, 13, 12, 11]
  with self.test_session(use_gpu=True) as sess:
    data = constant_op.constant(data_list, dtype=dtypes.float64)
    indices = 3
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=4)
    partition_vals = sess.run(partitions)

  self.assertEqual(4, len(partition_vals))
  self.assertAllEqual(np.array([], dtype=np.float64).reshape(-1, 4),
                      partition_vals[0])
  self.assertAllEqual(np.array([], dtype=np.float64).reshape(-1, 4),
                      partition_vals[1])
  self.assertAllEqual(np.array([], dtype=np.float64).reshape(-1, 4),
                      partition_vals[2])
  self.assertAllEqual(
      np.array([10, 13, 12, 11], dtype=np.float64).reshape(-1, 4),
      partition_vals[3])
def testGPUTooManyParts(self):
  # This test only makes sense on the GPU. There we do not check
  # for errors. In this case, we should discard all indices that are
  # not in [0, num_partitions).
  if not test.is_gpu_available():
    return
  data_list = [1, 2, 3, 4, 5, 6]
  indices_list = [6, 5, 4, 3, 1, 0]
  with self.test_session(use_gpu=True) as sess:
    data = constant_op.constant(data_list, dtype=dtypes.float32)
    indices = constant_op.constant(indices_list, dtype=dtypes.int32)
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=2)
    partition_vals = sess.run(partitions)

  self.assertEqual(2, len(partition_vals))
  self.assertAllEqual([6], partition_vals[0])
  self.assertAllEqual([5], partition_vals[1])
def testGPUAllIndicesBig(self):
  # This test only makes sense on the GPU. There we do not check
  # for errors. In this case, we should discard all the values
  # and have an empty output.
  if not test.is_gpu_available():
    return
  data_list = [1.1, 2.1, 3.1, 4.1, 5.1, 6.1]
  indices_list = [90, 70, 60, 100, 110, 40]
  with self.test_session(use_gpu=True) as sess:
    data = constant_op.constant(data_list, dtype=dtypes.float32)
    indices = constant_op.constant(indices_list, dtype=dtypes.int32)
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=40)
    partition_vals = sess.run(partitions)

  self.assertEqual(40, len(partition_vals))
  for i in range(40):
    self.assertAllEqual([], partition_vals[i])
def testSimpleTwoDimensional(self):
  with self.test_session() as sess:
    data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8],
                                 [9, 10, 11], [12, 13, 14], [15, 16, 17]])
    indices = constant_op.constant([0, 0, 2, 3, 2, 1])
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=4)
    partition_vals = sess.run(partitions)

  self.assertAllEqual([[0, 1, 2], [3, 4, 5]], partition_vals[0])
  self.assertAllEqual([[15, 16, 17]], partition_vals[1])
  self.assertAllEqual([[6, 7, 8], [12, 13, 14]], partition_vals[2])
  self.assertAllEqual([[9, 10, 11]], partition_vals[3])
  # Matrix data input to DynamicPartition results in
  # `num_partitions` matrices with an unknown number of rows, and 3 columns.
  self.assertEqual([None, 3], partitions[0].get_shape().as_list())
  self.assertEqual([None, 3], partitions[1].get_shape().as_list())
  self.assertEqual([None, 3], partitions[2].get_shape().as_list())
  self.assertEqual([None, 3], partitions[3].get_shape().as_list())
def testLargeTwoDimensional(self):
  rows = 100000
  cols = 100
  data_list = [None] * rows
  for i in range(rows):
    data_list[i] = [i for _ in range(cols)]
  num_partitions = 97
  indices_list = [(i ** 2) % num_partitions for i in range(rows)]
  parts = [[] for _ in range(num_partitions)]
  for i in range(rows):
    parts[(i ** 2) % num_partitions].append(data_list[i])
  with self.test_session(use_gpu=True) as sess:
    data = constant_op.constant(data_list, dtype=dtypes.float32)
    indices = constant_op.constant(indices_list, dtype=dtypes.int32)
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=num_partitions)
    partition_vals = sess.run(partitions)

  for i in range(num_partitions):
    # reshape because of empty parts
    parts_np = np.array(parts[i], dtype=np.float64).reshape(-1, cols)
    self.assertAllEqual(parts_np, partition_vals[i])
def testGPUPartsTooLarge(self):
  # This test only makes sense on the GPU. There we do not check
  # for errors. In this case, we should discard all the values
  # whose indices are not in [0, num_partitions).
  if not test.is_gpu_available():
    return
  data_list = [1, 2, 3, 4, 5, 6]
  indices_list = [10, 11, 2, 12, 0, 1000]
  with self.test_session(use_gpu=True) as sess:
    data = constant_op.constant(data_list, dtype=dtypes.float32)
    indices = constant_op.constant(indices_list, dtype=dtypes.int32)
    partitions = data_flow_ops.dynamic_partition(
        data, indices, num_partitions=5)
    partition_vals = sess.run(partitions)

  self.assertEqual(5, len(partition_vals))
  self.assertAllEqual([5], partition_vals[0])
  self.assertAllEqual([], partition_vals[1])
  self.assertAllEqual([3], partition_vals[2])
  self.assertAllEqual([], partition_vals[3])
  self.assertAllEqual([], partition_vals[4])
def _embedding_lookup_with_distributed_aggregation(params,
                                                   ids,
                                                   partition_strategy="mod",
                                                   name=None,
                                                   max_norm=None,
                                                   weights=None,
                                                   idx=None,
                                                   segment_ids=None):
  """Lookup helper for embedding_lookup_sparse_with_distributed_aggregation."""
  if params is None or params == []:  # pylint: disable=g-explicit-bool-comparison
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  def maybe_normalize(x):
    if max_norm is not None:
      if x.get_shape().ndims is not None:
        ndims = x.get_shape().ndims
      else:
        ndims = array_ops.size(array_ops.shape(x))
      return clip_ops.clip_by_norm(x, max_norm, axes=list(range(1, ndims)))
    return x

  with ops.name_scope(name, "embedding_lookup_with_distributed_aggregation",
                      params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.ResourceVariable)
        for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    if np == 1:
      with ops.colocate_with(params[0]):
        ret = maybe_normalize(_do_gather(params[0], ids))
        ignore_weights = weights is None
        if not ignore_weights:
          if weights.dtype != ret.dtype:
            weights = math_ops.cast(weights, ret.dtype)
          # Reshape to allow broadcast
          ones = array_ops.fill(
              array_ops.expand_dims(array_ops.rank(ret) - 1, 0), 1)
          bcast_weights_shape = array_ops.concat(
              [array_ops.shape(weights), ones], 0)
          orig_weights_shape = weights.get_shape()
          weights = array_ops.reshape(weights, bcast_weights_shape)
          # Set weights shape after reshape
          if ret.get_shape().ndims is not None:
            weights.set_shape(
                orig_weights_shape.concatenate(
                    [1 for _ in range(ret.get_shape().ndims - 1)]))
          ret *= weights
          return math_ops.segment_sum(ret, segment_ids, name=name)
        else:
          return math_ops.sparse_segment_sum(ret, idx, segment_ids, name=name)
    else:
      ids = ops.convert_to_tensor(ids, name="ids")
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = params[0].get_shape()[0]
        for p in xrange(1, np):
          dim_0_size += params[p].get_shape()[0]
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value,
                                               flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            if params[p].get_shape()[0].value is not None:
              dim_0_sizes.append(params[p].get_shape()[0].value)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(
            flat_ids // (ids_per_partition + 1),
            (flat_ids - extras) // ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        is_in_first_extras_partitions = math_ops.cast(p_assignments < extras,
                                                      flat_ids.dtype)
        new_ids = (
            is_in_first_extras_partitions * (flat_ids %
                                             (ids_per_partition + 1)) +
            (1 - is_in_first_extras_partitions) * (
                (flat_ids - extras) % ids_per_partition))
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result.append(_do_gather(params[p], gather_ids[p]))

      ignore_weights = weights is None
      if not ignore_weights:
        # Partition weights according to pindices.
        partitioned_weight = []
        for p in xrange(np):
          partitioned_weight.append(array_ops.gather(weights, pindices[p]))

      # Reshape each partition result.
      element_shape = params[0].get_shape()[1:]
      for p in params[1:]:
        element_shape = element_shape.merge_with(p.get_shape()[1:])
      if element_shape.is_fully_defined():
        for p in xrange(np):
          with ops.colocate_with(params[p]):
            partitioned_result[p] = array_ops.reshape(
                partitioned_result[p],
                array_ops.concat(
                    [array_ops.shape(pindices[p]), element_shape], 0))
      else:
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        for p in xrange(np):
          with ops.colocate_with(params[p]):
            partitioned_result[p] = array_ops.reshape(
                partitioned_result[p],
                array_ops.concat([
                    array_ops.shape(pindices[p]),
                    array_ops.slice(params_shape, [1], [-1])
                ], 0))

      # Normalize each partition result.
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result[p] = maybe_normalize(partitioned_result[p])
      if not ignore_weights:
        # Multiply each partition result with partition weights.
        for p in xrange(np):
          with ops.colocate_with(params[p]):
            if partitioned_weight[p].dtype != partitioned_result[p].dtype:
              partitioned_weight[p] = math_ops.cast(
                  partitioned_weight[p], partitioned_result[p].dtype)
            # Reshape partition weights.
            ones = array_ops.fill(
                array_ops.expand_dims(
                    array_ops.rank(partitioned_result[p]) - 1, 0), 1)
            bcast_weights_shape = array_ops.concat(
                [array_ops.shape(partitioned_weight[p]), ones], 0)
            orig_weights_shape = partitioned_weight[p].get_shape()
            partitioned_weight[p] = array_ops.reshape(
                partitioned_weight[p], bcast_weights_shape)
            if partitioned_result[p].get_shape().ndims is not None:
              partitioned_weight[p].set_shape(
                  orig_weights_shape.concatenate([
                      1
                      for _ in range(
                          partitioned_result[p].get_shape().ndims - 1)
                  ]))
            partitioned_result[p] *= partitioned_weight[p]

      partitioned_segment_ids = []
      for p in xrange(np):
        if not ignore_weights:
          # Partition segment_ids according to pindices.
          p_segment_ids = array_ops.gather(segment_ids, pindices[p])
          # Number the p_segment_ids to meet segment_sum's requirements. Note
          # that unique_p_segment_ids contains unique segment ids of this
          # partition and these ids' order is unchanged.
          unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
              p_segment_ids)
          partitioned_segment_ids.append(unique_p_segment_ids)
          # segment_sum this partition's result.
          with ops.colocate_with(params[p]):
            partitioned_result[p] = math_ops.segment_sum(
                partitioned_result[p], unique_p_segment_idx)
        else:
          # When ignoring weights, we need the indices of the elements in
          # idx and segment_ids.
          _, exclude_idx = array_ops.setdiff1d(idx, pindices[p])
          all_idx = math_ops.range(array_ops.shape(idx)[0])
          _, include_idx = array_ops.setdiff1d(all_idx, exclude_idx)
          # Gather segment_ids and idx according to those indices.
          p_segment_ids = array_ops.gather(segment_ids, include_idx)
          p_idx = array_ops.gather(idx, include_idx)
          # Number the p_segment_ids, same as the weighted case above.
          unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
              p_segment_ids)
          _, unique_p_idx_idx = array_ops.unique(p_idx)
          partitioned_segment_ids.append(unique_p_segment_ids)
          with ops.colocate_with(params[p]):
            partitioned_result[p] = math_ops.sparse_segment_sum(
                partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx)

      # Concat each partition's segment_ids and result for final segment_sum.
      concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0)
      concat_partitioned_result = array_ops.concat(partitioned_result, 0)
      return math_ops.unsorted_segment_sum(
          concat_partitioned_result,
          concat_segment_ids,
          math_ops.reduce_max(concat_segment_ids) + 1,
          name=name)
def _sample_n(self, n, seed=None):
  with ops.control_dependencies(self._assertions):
    n = ops.convert_to_tensor(n, name="n")
    static_n = tensor_util.constant_value(n)
    n = int(static_n) if static_n is not None else n
    cat_samples = self.cat.sample(n, seed=seed)

    static_samples_shape = cat_samples.get_shape()
    if static_samples_shape.is_fully_defined():
      samples_shape = static_samples_shape.as_list()
      samples_size = static_samples_shape.num_elements()
    else:
      samples_shape = array_ops.shape(cat_samples)
      samples_size = array_ops.size(cat_samples)
    static_batch_shape = self.get_batch_shape()
    if static_batch_shape.is_fully_defined():
      batch_shape = static_batch_shape.as_list()
      batch_size = static_batch_shape.num_elements()
    else:
      batch_shape = self.batch_shape()
      batch_size = math_ops.reduce_prod(batch_shape)
    static_event_shape = self.get_event_shape()
    if static_event_shape.is_fully_defined():
      event_shape = np.array(static_event_shape.as_list(), dtype=np.int32)
    else:
      event_shape = self.event_shape()

    # Get indices into the raw cat sampling tensor. We will
    # need these to stitch sample values back out after sampling
    # within the component partitions.
    samples_raw_indices = array_ops.reshape(
        math_ops.range(0, samples_size), samples_shape)

    # Partition the raw indices so that we can use
    # dynamic_stitch later to reconstruct the samples from the
    # known partitions.
    partitioned_samples_indices = data_flow_ops.dynamic_partition(
        data=samples_raw_indices,
        partitions=cat_samples,
        num_partitions=self.num_components)

    # Copy the batch indices n times, as we will need to know
    # these to pull out the appropriate rows within the
    # component partitions.
    batch_raw_indices = array_ops.reshape(
        array_ops.tile(math_ops.range(0, batch_size), [n]), samples_shape)

    # Explanation of the dynamic partitioning below:
    #   batch indices are, e.g., [0, 1, 0, 1, 0, 1]
    #   Suppose partitions are:
    #       [1 1 0 0 1 1]
    #   After partitioning, batch indices are cut as:
    #       [batch_indices[x] for x in 2, 3]
    #       [batch_indices[x] for x in 0, 1, 4, 5]
    #   i.e.
    #       [0 1] and [0 1 0 1]
    #   Now we sample n=2 from part 0 and n=4 from part 1.
    #   For part 0 we want samples from batch entries 0, 1 (samples 0, 1),
    #   and for part 1 we want samples from batch entries 0, 1, 0, 1
    #   (samples 0, 1, 2, 3).
    partitioned_batch_indices = data_flow_ops.dynamic_partition(
        data=batch_raw_indices,
        partitions=cat_samples,
        num_partitions=self.num_components)
    samples_class = [None for _ in range(self.num_components)]

    for c in range(self.num_components):
      n_class = array_ops.size(partitioned_samples_indices[c])
      seed = distribution_util.gen_new_seed(seed, "mixture")
      samples_class_c = self.components[c].sample(n_class, seed=seed)

      # Pull out the correct batch entries from each index.
      # To do this, we may have to flatten the batch shape.

      # For sample s, batch element b of component c, we get the
      # partitioned batch indices from
      # partitioned_batch_indices[c]; and shift each element by
      # the sample index. The final lookup can be thought of as
      # a matrix gather along locations (s, b) in
      # samples_class_c where the n_class rows correspond to
      # samples within this component and the batch_size columns
      # correspond to batch elements within the component.
      #
      # Thus the lookup index is
      #   lookup[c, i] = batch_size * s[i] + b[c, i]
      # for i = 0 ... n_class[c] - 1.
      lookup_partitioned_batch_indices = (
          batch_size * math_ops.range(n_class) +
          partitioned_batch_indices[c])
      samples_class_c = array_ops.reshape(
          samples_class_c,
          array_ops.concat(([n_class * batch_size], event_shape), 0))
      samples_class_c = array_ops.gather(
          samples_class_c, lookup_partitioned_batch_indices,
          name="samples_class_c_gather")
      samples_class[c] = samples_class_c

    # Stitch back together the samples across the components.
    lhs_flat_ret = data_flow_ops.dynamic_stitch(
        indices=partitioned_samples_indices, data=samples_class)
    # Reshape back to proper sample, batch, and event shape.
    ret = array_ops.reshape(
        lhs_flat_ret,
        array_ops.concat((samples_shape, self.event_shape()), 0))
    ret.set_shape(
        tensor_shape.TensorShape(static_samples_shape).concatenate(
            self.get_event_shape()))
    return ret
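# To make the batch-index bookkeeping above concrete, here is the partitioning
# step in isolation (TF2 eager; a standalone sketch with illustrative values,
# not part of the original class):
import tensorflow as tf

# n = 3 draws for a batch of 2 distributions; batch indices repeat per draw.
batch_raw_indices = tf.constant([0, 1, 0, 1, 0, 1])
# Which mixture component each of the 6 draws came from.
cat_samples = tf.constant([1, 1, 0, 0, 1, 1])

partitioned_batch_indices = tf.dynamic_partition(
    batch_raw_indices, cat_samples, 2)
# partitioned_batch_indices[0] -> [0, 1]       (draws 2, 3)
# partitioned_batch_indices[1] -> [0, 1, 0, 1] (draws 0, 1, 4, 5)
# Component 0 therefore samples once for each of batch entries 0 and 1;
# component 1 samples twice for each.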
def minimize(self, global_step=None, name=None):
  """Add operations to train a linear model by minimizing the loss function.

  Args:
    global_step: Optional `Variable` to increment by one after the
      variables have been updated.
    name: Optional name for the returned operation.

  Returns:
    An Operation that updates the variables passed in the constructor.
  """
  # Technically, the op depends on a lot more than the variables,
  # but we'll keep the list short.
  with name_scope(name, 'sdca/minimize'):
    sparse_example_indices = []
    sparse_feature_indices = []
    sparse_features_values = []
    for sf in self._examples['sparse_features']:
      sparse_example_indices.append(sf.example_indices)
      sparse_feature_indices.append(sf.feature_indices)
      # If feature values are missing, sdca assumes a value of 1.0f.
      if sf.feature_values is not None:
        sparse_features_values.append(sf.feature_values)

    # pylint: disable=protected-access
    example_ids_hashed = gen_sdca_ops.sdca_fprint(
        internal_convert_to_tensor(self._examples['example_ids']))
    # pylint: enable=protected-access
    example_state_data = self._hashtable.lookup(example_ids_hashed)
    # Solver returns example_state_update, new delta sparse_feature_weights
    # and delta dense_feature_weights.
    sparse_weights = []
    sparse_indices = []
    # If we have partitioned variables, keep a few dictionaries of Tensors
    # around that we need for the assign_add after the op call to
    # gen_sdca_ops.sdca_optimizer(). These are keyed because we may have a
    # mix of partitioned and un-partitioned variables.
    num_partitions_by_var = {}
    p_assignments_by_var = {}
    gather_ids_by_var = {}
    for v_num, (w, i) in enumerate(
        zip(self._slots['unshrinked_sparse_features_weights'],
            sparse_feature_indices)):
      # Append the sparse_indices (in full-variable space).
      sparse_idx = math_ops.cast(
          array_ops.unique(math_ops.cast(i, dtypes.int32))[0],
          dtypes.int64)
      sparse_indices.append(sparse_idx)
      if isinstance(w, list) or isinstance(w, var_ops.PartitionedVariable):
        num_partitions = len(w)
        flat_ids = array_ops.reshape(sparse_idx, [-1])
        # We use div partitioning, which is easiest to support downstream.
        # Compute num_total_ids as the sum of dim-0 of w, then assign
        # to partitions based on a constant number of ids per partition.
        # Optimize if we already know the full shape statically.
        dim_0_size = self._get_first_dimension_size_statically(
            w, num_partitions)

        if tensor_shape.dimension_value(dim_0_size):
          num_total_ids = constant_op.constant(
              tensor_shape.dimension_value(dim_0_size), flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in range(num_partitions):
            if tensor_shape.dimension_value(w[p].shape[0]) is not None:
              dim_0_sizes.append(tensor_shape.dimension_value(w[p].shape[0]))
            else:
              with ops.colocate_with(w[p]):
                dim_0_sizes.append(array_ops.shape(w[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // num_partitions
        extras = num_total_ids % num_partitions

        p_assignments = math_ops.maximum(
            flat_ids // (ids_per_partition + 1),
            (flat_ids - extras) // ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        new_ids = array_ops.where(p_assignments < extras,
                                  flat_ids % (ids_per_partition + 1),
                                  (flat_ids - extras) % ids_per_partition)

        # Cast partition assignments to int32 for use in dynamic_partition.
        # There really should not be more than 2^32 partitions.
        p_assignments = math_ops.cast(p_assignments, dtypes.int32)
        # Partition list of ids based on assignments into num_partitions
        # separate lists.
        gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments,
                                                     num_partitions)
        # Add these into the dictionaries for use in the later update.
        num_partitions_by_var[v_num] = num_partitions
        p_assignments_by_var[v_num] = p_assignments
        gather_ids_by_var[v_num] = gather_ids

        # Gather the weights from each partition.
        partition_gathered_weights = []
        for p in range(num_partitions):
          with ops.colocate_with(w[p]):
            partition_gathered_weights.append(
                array_ops.gather(w[p], gather_ids[p]))

        # Stitch the weights back together in the same order they were before
        # we dynamic_partitioned them.
        condition_indices = data_flow_ops.dynamic_partition(
            math_ops.range(array_ops.shape(new_ids)[0]),
            p_assignments, num_partitions)
        batch_gathered_weights = data_flow_ops.dynamic_stitch(
            condition_indices, partition_gathered_weights)
      else:
        w_as_tensor = internal_convert_to_tensor(w)
        with ops.device(w_as_tensor.device):
          batch_gathered_weights = array_ops.gather(
              w_as_tensor, sparse_idx)
      sparse_weights.append(batch_gathered_weights)

    # pylint: disable=protected-access
    if compat.forward_compatible(year=2018, month=10, day=30):
      esu, sfw, dfw = gen_sdca_ops.sdca_optimizer_v2(
          sparse_example_indices,
          sparse_feature_indices,
          sparse_features_values,
          self._convert_n_to_tensor(self._examples['dense_features']),
          internal_convert_to_tensor(self._examples['example_weights']),
          internal_convert_to_tensor(self._examples['example_labels']),
          sparse_indices,
          sparse_weights,
          self._convert_n_to_tensor(
              self._slots['unshrinked_dense_features_weights']),
          example_state_data,
          loss_type=self._options['loss_type'],
          l1=self._options['symmetric_l1_regularization'],
          l2=self._symmetric_l2_regularization(),
          num_loss_partitions=self._num_loss_partitions(),
          num_inner_iterations=1,
          adaptive=self._adaptive())
    else:
      esu, sfw, dfw = gen_sdca_ops.sdca_optimizer(
          sparse_example_indices,
          sparse_feature_indices,
          sparse_features_values,
          self._convert_n_to_tensor(self._examples['dense_features']),
          internal_convert_to_tensor(self._examples['example_weights']),
          internal_convert_to_tensor(self._examples['example_labels']),
          sparse_indices,
          sparse_weights,
          self._convert_n_to_tensor(
              self._slots['unshrinked_dense_features_weights']),
          example_state_data,
          loss_type=self._options['loss_type'],
          l1=self._options['symmetric_l1_regularization'],
          l2=self._symmetric_l2_regularization(),
          num_loss_partitions=self._num_loss_partitions(),
          num_inner_iterations=1,
          adaptative=self._adaptive())
    # pylint: enable=protected-access

    with ops.control_dependencies([esu]):
      update_ops = [self._hashtable.insert(example_ids_hashed, esu)]
      # Update the weights before the proximal step.
      for v_num, (w, i, u) in enumerate(
          zip(self._slots['unshrinked_sparse_features_weights'],
              sparse_indices, sfw)):
        if (isinstance(w, var_ops.PartitionedVariable) or
            isinstance(w, list)):
          update_ops += self._get_partitioned_update_ops(
              v_num, num_partitions_by_var, p_assignments_by_var,
              gather_ids_by_var, w, u, p_assignments, num_partitions)
        else:
          update_ops.append(state_ops.scatter_add(w, i, u))
      for w, u in zip(self._slots['unshrinked_dense_features_weights'], dfw):
        if (isinstance(w, var_ops.PartitionedVariable) or
            isinstance(w, list)):
          split_updates = array_ops.split(
              u, num_or_size_splits=[v.shape.as_list()[0] for v in w])
          for v, split_update in zip(w, split_updates):
            update_ops.append(state_ops.assign_add(v, split_update))
        else:
          update_ops.append(state_ops.assign_add(w, u))
      if not global_step:
        return control_flow_ops.group(*update_ops)
      with ops.control_dependencies(update_ops):
        return state_ops.assign_add(global_step, 1, name=name).op
def embedding_lookup(params, ids, partition_strategy="mod", name=None,
                     validate_indices=True, max_norm=None):
  """Looks up `ids` in a list of embedding tensors.

  This function is used to perform parallel lookups on the list of
  tensors in `params`.  It is a generalization of
  [`tf.gather()`](../../api_docs/python/array_ops.md#gather), where `params` is
  interpreted as a partitioning of a large embedding tensor.  `params` may be
  a `PartitionedVariable` as returned by using `tf.get_variable()` with a
  partitioner.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between
  the elements of `params` according to the `partition_strategy`.
  In all strategies, if the id space does not evenly divide the number of
  partitions, each of the first `(max_id + 1) % len(params)` partitions will
  be assigned one more id.

  If `partition_strategy` is `"mod"`, we assign each id to partition
  `p = id % len(params)`. For instance,
  13 ids are split across 5 partitions as:
  `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`

  If `partition_strategy` is `"div"`, we assign ids to partitions in a
  contiguous manner. In this case, 13 ids are split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`

  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

  Args:
    params: A list of tensors with the same type and which can be concatenated
      along dimension 0. Alternatively, a `PartitionedVariable`, created by
      partitioning along dimension 0.  Each element must be appropriately
      sized for the given `partition_strategy`.
    ids: A `Tensor` with type `int32` or `int64` containing the ids to be
      looked up in `params`.
    partition_strategy: A string specifying the partitioning strategy,
      relevant if `len(params) > 1`. Currently `"div"` and `"mod"` are
      supported. Default is `"mod"`.
    name: A name for the operation (optional).
    validate_indices: Whether or not to validate gather indices.
    max_norm: If not None, embedding values are l2-normalized to the value of
      max_norm.

  Returns:
    A `Tensor` with the same type as the tensors in `params`.

  Raises:
    ValueError: If `params` is empty.
  """
  if params is None or params == []:  # pylint: disable=g-explicit-bool-comparison
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  def maybe_normalize(x):
    if max_norm is not None:
      if x.get_shape().ndims is not None:
        ndims = x.get_shape().ndims
      else:
        ndims = array_ops.size(array_ops.shape(x))
      return clip_ops.clip_by_norm(x, max_norm, axes=list(range(1, ndims)))
    return x

  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    np = len(params)  # Number of partitions
    params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    if np == 1:
      with ops.colocate_with(params[0]):
        # TODO(apassos): implement the sharded version as well.
        if isinstance(params[0], resource_variable_ops.ResourceVariable):
          ret = params[0].sparse_read(ids, name=name)
        else:
          ret = array_ops.gather(params[0], ids, name=name,
                                 validate_indices=validate_indices)
      return maybe_normalize(ret)
    else:
      ids = ops.convert_to_tensor(ids, name="ids")
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = params[0].get_shape()[0]
        for p in xrange(1, np):
          dim_0_size += params[p].get_shape()[0]
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value,
                                               flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            if params[p].get_shape()[0].value is not None:
              dim_0_sizes.append(params[p].get_shape()[0].value)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.pack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(
            flat_ids // (ids_per_partition + 1),
            (flat_ids - extras) // ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        is_in_first_extras_partitions = math_ops.cast(
            p_assignments < extras, flat_ids.dtype)
        new_ids = (
            is_in_first_extras_partitions * (
                flat_ids % (ids_per_partition + 1)) +
            (1 - is_in_first_extras_partitions) * (
                (flat_ids - extras) % ids_per_partition))
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result.append(
              array_ops.gather(params[p], gather_ids[p],
                               validate_indices=validate_indices))
      # Stitch these back together
      ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
                                         name=name)
      # Reshape to reverse the flattening of ids.
      element_shape = params[0].get_shape()[1:]
      for p in params[1:]:
        element_shape = element_shape.merge_with(p.get_shape()[1:])
      if element_shape.is_fully_defined():
        ret = array_ops.reshape(
            ret, array_ops.concat(0, [array_ops.shape(ids), element_shape]))
      else:
        # It's important that we compute params[0].shape on the right device
        # to avoid data motion.
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        ret = array_ops.reshape(
            ret,
            array_ops.concat(0, [
                array_ops.shape(ids),
                array_ops.slice(params_shape, [1], [-1])]))
      # output shape = ids.shape + params[*].shape[1:]
      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters.
      ret.set_shape(ids.get_shape().concatenate(element_shape))
      return maybe_normalize(ret)
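# As a concrete check of the two strategies described in the docstring above,
# here is how 13 ids fall into 5 partitions (TF2 eager; a standalone sketch,
# not part of the original module):
import tensorflow as tf

ids = tf.range(13)
num = 5

# "mod": p = id % num -> [[0 5 10], [1 6 11], [2 7 12], [3 8], [4 9]]
mod_parts = tf.dynamic_partition(ids, tf.cast(ids % num, tf.int32), num)

# "div": contiguous ranges; the first 13 % 5 = 3 partitions get one extra id
#        -> [[0 1 2], [3 4 5], [6 7 8], [9 10], [11 12]]
ids_per, extras = 13 // num, 13 % num
div_assign = tf.maximum(ids // (ids_per + 1), (ids - extras) // ids_per)
div_parts = tf.dynamic_partition(ids, tf.cast(div_assign, tf.int32), num)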
def testErrorWrongDimsIndices(self):
  data = constant_op.constant([[0], [1], [2]])
  indices = constant_op.constant([[0], [0]])
  with self.assertRaises(ValueError):
    data_flow_ops.dynamic_partition(data, indices, num_partitions=4)
def embedding_lookup(params, ids, name=None):
  """Looks up `ids` in a list of embedding tensors.

  This function is used to perform parallel lookups on the list of
  tensors in `params`.  It is a generalization of
  [`tf.gather()`](../../api_docs/python/array_ops.md#gather), where `params`
  is interpreted as a partition of a larger embedding tensor.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between
  the elements of `params` by computing `p = id % len(params)`, and is then
  used to look up the slice `params[p][id // len(params), ...]`.

  The results of the lookup are then concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

  Args:
    params: A list of tensors with the same shape and type.
    ids: A `Tensor` with type `int32` containing the ids to be looked
      up in `params`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as the tensors in `params`.

  Raises:
    ValueError: If `params` is empty.
  """
  if not isinstance(params, list):
    params = [params]
  with ops.op_scope(params + [ids], name, "embedding_lookup") as name:
    if not params:
      raise ValueError("Need at least one param")
    np = len(params)  # Number of partitions
    params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    if np == 1:
      with ops.device(params[0].device):
        return array_ops.gather(params[0], ids, name=name)
    else:
      ids = ops.convert_to_tensor(ids, name="ids")
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(0, array_ops.size(flat_ids))
      # Compute flat_ids % partitions for each id
      ids_mod_p = flat_ids % np
      if ids_mod_p.dtype != types.int32:
        ids_mod_p = math_ops.cast(ids_mod_p, types.int32)
      # Partition single list of ids based on ids % np into np separate lists
      plist = data_flow_ops.dynamic_partition(flat_ids, ids_mod_p, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices, ids_mod_p,
                                                 np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        # TODO(agarwal): handle device allocations here and later in the
        # colocate code.
        gather_ids = plist[p] // np
        with ops.device(params[p].device):
          partitioned_result.append(array_ops.gather(params[p], gather_ids))
      # Stitch these back together
      ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
                                         name=name)
      # Reshape to reverse the flattening of ids.
      # It's important that we compute params[0].shape on the right device
      # to avoid data motion.
      with ops.device(params[0].device):
        params_shape = array_ops.shape(params[0])
      ret = array_ops.reshape(ret, array_ops.concat(0, [
          array_ops.shape(ids),
          array_ops.slice(params_shape, [1], [-1])]))
      # output shape = ids.shape + params[*].shape[1:]
      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters.
      element_shape = params[0].get_shape()[1:]
      for p in params[1:]:
        element_shape = element_shape.merge_with(p.get_shape()[1:])
      ret.set_shape(ids.get_shape().concatenate(element_shape))
      return ret
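# A TF2 eager sketch of the mod-sharded lookup this function implements:
# id -> params[id % np][id // np]. The 3 x (3 x 2) toy tables and the ids are
# illustrative assumptions, not values from the source.
import tensorflow as tf

params = [
    tf.reshape(tf.range(6, dtype=tf.float32), [3, 2]) + 100.0 * p
    for p in range(3)
]
ids = tf.constant([4, 0, 7])

p_assign = tf.cast(ids % 3, tf.int32)
plist = tf.dynamic_partition(ids, p_assign, 3)
pindices = tf.dynamic_partition(tf.range(tf.size(ids)), p_assign, 3)
# Per-shard gathers: shard p looks up row id // 3 of its slice.
shard_lookups = [tf.gather(params[p], plist[p] // 3) for p in range(3)]
# Stitch the per-shard rows back into the order of `ids`.
ret = tf.dynamic_stitch(pindices, shard_lookups)
print(ret.numpy())  # rows for ids 4, 0, 7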