def output(self, Q, input_state_1, input_state_2, **kwargs):
    """Select the values of ``Q`` located at the given state coordinates.

    For every sample ``i`` and every state ``j`` of that sample, this picks
    ``Q[i, input_state_1[i, j], input_state_2[i, j]]``. That pairing assumes
    ``tf_utils.repeat`` repeats each element in place (as ``np.repeat``
    does) and ``tf_utils.flatten`` flattens row-major. TODO confirm.

    Parameters
    ----------
    Q : tensor
        Tensor with at least three axes. Axis 0 is the sample axis; axes 1
        and 2 are indexed by the two state coordinates.
        NOTE(review): presumably (n_samples, height, width, n_filters) —
        confirm.
    input_state_1, input_state_2 : tensor
        Coordinates along Q's axes 1 and 2 respectively. Axis 1 of
        ``input_state_1`` is read as the number of states per sample, and
        both tensors are flattened and cast to int32. Assumed shape
        (n_samples, n_states) with the same sample count as ``Q``.
        TODO confirm against callers.
    **kwargs
        Unused.

    Returns
    -------
    tensor
        Shape ``(n_samples * n_states,) + Q.shape[3:]``.
    """
    with tf.name_scope("Q-output"):
        # Several starting points can be evaluated at once, so the number
        # of states per sample is read from the state tensor at run time.
        n_states = tf.shape(input_state_1)[1]
        n_samples = tf.shape(Q)[0]

        indices = tf.stack(
            [
                # Sample index for every state: each sample index is
                # repeated once per state belonging to that sample.
                tf_utils.repeat(tf.range(n_samples), n_states),
                # The two state tensors hold the grid coordinates that
                # address Q's second and third axes.
                tf.cast(tf_utils.flatten(input_state_1), tf.int32),
                tf.cast(tf_utils.flatten(input_state_2), tf.int32),
            ],
            axis=1,
        )

        # Each row of ``indices`` has three components, so gather_nd
        # consumes Q's first three axes: the result has
        # n_samples * n_states rows and Q.shape[3:] trailing dimensions
        # (not Q.shape[1], as was previously stated).
        return tf.gather_nd(Q, indices)
def test_repeat(self):
    """Check that ``tf_utils.repeat`` repeats every element along each axis.

    With repeat factors ``(2, 3)``, each entry of a 2x2 matrix expands into
    a 2x3 tile of itself, giving a 4x6 result.
    """
    source = np.array([[1, 2], [3, 4]])
    result = self.eval(tf_utils.repeat(source, (2, 3)))

    # Each row appears twice and each value three times within a row:
    # elements are repeated in place rather than the whole matrix tiled.
    target = np.array([
        [1, 1, 1, 2, 2, 2],
        [1, 1, 1, 2, 2, 2],
        [3, 3, 3, 4, 4, 4],
        [3, 3, 3, 4, 4, 4],
    ])
    np.testing.assert_array_equal(result, target)
def output(self, input_value, **kwargs):
    """Upscale ``input_value`` by repeating its values along the inner axes.

    The input is converted to a float32 tensor and its static shape is
    checked with ``self.fail_if_shape_invalid`` before any repetition.
    The repeat factors are ``as_tuple(1, self.scale, 1)``, presumably
    leaving the first and last axes unchanged.
    NOTE(review): presumably the input is (batch, height, width, channels)
    and ``self.scale`` is (height_scale, width_scale) — confirm.

    ``**kwargs`` is unused.
    """
    tensor = tf.convert_to_tensor(input_value, dtype=tf.float32)
    self.fail_if_shape_invalid(tensor.shape)

    # Factors of 1 on the outer axes keep them untouched; only the
    # axes covered by ``self.scale`` are enlarged.
    repeats = as_tuple(1, self.scale, 1)
    return tf_utils.repeat(tensor, repeats)