def output(self, Q, input_state_1, input_state_2, **kwargs):
    """Gather Q-values for a batch of grid-coordinate states.

    Each state is a pair of coordinates (x from ``input_state_1``,
    y from ``input_state_2``) pointing to a cell on a grid.  For every
    sample in the ``Q`` batch we look up the entries that correspond to
    every state in the state batch, so direction can be predicted from
    several starting points at once.

    Returns a matrix with ``n_samples * n_states`` rows and
    ``n_filters`` (i.e. ``Q.shape[1]``) columns.
    """
    with tf.name_scope("Q-output"):
        # The number of states per sample sets how many times each
        # sample index has to be repeated below.
        n_states = tf.shape(input_state_1)[1]
        sample_ids = tf_utils.repeat(tf.range(tf.shape(Q)[0]), n_states)
        xs = tf.cast(tf_utils.flatten(input_state_1), tf.int32)
        ys = tf.cast(tf_utils.flatten(input_state_2), tf.int32)
        # One (sample, x, y) triple per row of the gathered output.
        indices = tf.stack([sample_ids, xs, ys], axis=1)
        return tf.gather_nd(Q, indices)
def loss_function(expected, predicted):
    """Categorical cross-entropy between true labels and probabilities.

    ``expected`` holds the true class indices (flattened to one index
    per sample); ``predicted`` holds one row of class probabilities per
    sample.  Probabilities are clipped away from 0 and 1 so that the
    logarithm stays finite in 32-bit floats.
    """
    epsilon = 1e-7  # for 32-bit float
    clipped = tf.clip_by_value(predicted, epsilon, 1.0 - epsilon)
    labels = tf.cast(tf_utils.flatten(expected), tf.int32)
    # Select, for every row, the log-probability of its true class.
    rows = tf.range(tf.size(labels))
    picked = tf.gather_nd(tf.log(clipped), tf.stack([rows, labels], axis=1))
    return -tf.reduce_mean(picked)
def test_flatten(in_shape, out_shape):
    """Flattening an array of ``in_shape`` must produce ``out_shape``."""
    data = np.random.random(in_shape)
    flattened = tf_utils.tensorflow_eval(tf_utils.flatten(data))
    assert flattened.shape == out_shape