def train_step(self, inputs, state, calc_intrinsic_reward=True):
    """
    Args:
        inputs (tuple): observation
        state (tuple): empty tuple ()
        calc_intrinsic_reward (bool): if False, only return the losses
    Returns:
        TrainStep:
            outputs: empty tuple ()
            state: empty tuple ()
            info: RNDInfo
    """
    observation, _ = inputs
    if self._observation_normalizer is not None:
        observation = self._observation_normalizer.normalize(observation)

    pred_embedding, _ = self._predictor_net(observation)
    target_embedding, _ = self._target_net(observation)

    loss = 0.5 * tf.reduce_mean(
        tf.square(pred_embedding - tf.stop_gradient(target_embedding)),
        axis=-1)

    intrinsic_reward = ()
    if calc_intrinsic_reward:
        intrinsic_reward = tf.stop_gradient(loss)
        intrinsic_reward = self._reward_normalizer.normalize(intrinsic_reward)

    return AlgorithmStep(
        outputs=(),
        state=(),
        info=RNDInfo(reward=intrinsic_reward, loss=LossInfo(loss=loss)))
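
# A minimal standalone sketch of the RND computation above, assuming eager
# TensorFlow and two toy dense layers standing in for the predictor network
# and the fixed, randomly initialized target network.
import tensorflow as tf

observation = tf.random.normal([4, 8])                # batch of 4 observations
predictor = tf.keras.layers.Dense(16)                 # trained to match target
target = tf.keras.layers.Dense(16, trainable=False)   # stays fixed

pred_embedding = predictor(observation)
target_embedding = target(observation)

# Per-sample prediction error; stop_gradient keeps the target network fixed.
loss = 0.5 * tf.reduce_mean(
    tf.square(pred_embedding - tf.stop_gradient(target_embedding)), axis=-1)
intrinsic_reward = tf.stop_gradient(loss)             # shape [4]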
def __call__(self, training_info: TrainingInfo, value):
    """Calculate the actor-critic loss.

    The first dimension of all the tensors is the time dimension and the
    second dimension is the batch dimension.

    Args:
        training_info (TrainingInfo): training_info collected by
            (On/Off)PolicyDriver. All tensors in training_info are
            time-major.
        value (tf.Tensor): the time-major tensor for the value at each time
            step.
    Returns:
        loss_info (LossInfo): with loss_info.extra being ActorCriticLossInfo.
    """
    returns, advantages = self._calc_returns_and_advantages(
        training_info, value)

    def _summary():
        with tf.name_scope('ActorCriticLoss'):
            tf.summary.scalar("values", tf.reduce_mean(value))
            tf.summary.scalar("returns", tf.reduce_mean(returns))
            tf.summary.scalar("advantages/mean", tf.reduce_mean(advantages))
            tf.summary.histogram("advantages/value", advantages)
            tf.summary.scalar("explained_variance_of_return_by_value",
                              common.explained_variance(value, returns))

    if self._debug_summaries:
        common.run_if(common.should_record_summaries(), _summary)

    if self._normalize_advantages:
        advantages = _normalize_advantages(advantages, axes=(0, 1))
    if self._advantage_clip:
        advantages = tf.clip_by_value(advantages, -self._advantage_clip,
                                      self._advantage_clip)

    pg_loss = self._pg_loss(training_info, tf.stop_gradient(advantages))
    td_loss = self._td_error_loss_fn(tf.stop_gradient(returns), value)
    loss = pg_loss + self._td_loss_weight * td_loss

    entropy_loss = ()
    if self._entropy_regularization is not None:
        entropy, entropy_for_gradient = dist_utils.entropy_with_fallback(
            training_info.action_distribution, self._action_spec)
        entropy_loss = -entropy
        loss -= self._entropy_regularization * entropy_for_gradient

    return LossInfo(
        loss=loss,
        extra=ActorCriticLossInfo(
            td_loss=td_loss, pg_loss=pg_loss, entropy_loss=entropy_loss))
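
# A sketch of the advantage normalization and clipping step above, assuming
# (this is an assumption, not the source) that _normalize_advantages
# standardizes over both the time and batch axes; the clip value is
# illustrative only.
import tensorflow as tf

advantages = tf.random.normal([10, 32])  # [time, batch]
mean, var = tf.nn.moments(advantages, axes=[0, 1], keepdims=True)
advantages = (advantages - mean) * tf.math.rsqrt(var + 1e-8)
advantages = tf.clip_by_value(advantages, -2.0, 2.0)  # hypothetical clip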
def train_step(self,
               time_step: ActionTimeStep,
               state,
               calc_intrinsic_reward=True):
    """
    Args:
        time_step (ActionTimeStep): input time_step data for ICM
        state (Tensor): state for ICM (previous observation)
        calc_intrinsic_reward (bool): if False, only return the losses
    Returns:
        TrainStep:
            outputs: empty tuple ()
            state: observation
            info (ICMInfo):
    """
    feature = time_step.observation
    prev_action = time_step.prev_action
    if self._encoding_net is not None:
        feature, _ = self._encoding_net(feature)
    prev_feature = state
    prev_action = self._encode_action(prev_action)

    forward_pred, _ = self._forward_net(
        inputs=[tf.stop_gradient(prev_feature), prev_action])
    forward_loss = 0.5 * tf.reduce_mean(
        tf.square(tf.stop_gradient(feature) - forward_pred), axis=-1)

    action_pred, _ = self._inverse_net(inputs=[prev_feature, feature])
    if tensor_spec.is_discrete(self._action_spec):
        inverse_loss = tf.nn.softmax_cross_entropy_with_logits(
            labels=prev_action, logits=action_pred)
    else:
        inverse_loss = 0.5 * tf.reduce_mean(
            tf.square(prev_action - action_pred), axis=-1)

    intrinsic_reward = ()
    if calc_intrinsic_reward:
        intrinsic_reward = tf.stop_gradient(forward_loss)
        intrinsic_reward = self._reward_normalizer.normalize(intrinsic_reward)

    return AlgorithmStep(
        outputs=(),
        state=feature,
        info=ICMInfo(
            reward=intrinsic_reward,
            loss=LossInfo(
                loss=forward_loss + inverse_loss,
                extra=dict(
                    forward_loss=forward_loss, inverse_loss=inverse_loss))))
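
# A standalone sketch of the two ICM losses above, assuming eager TensorFlow,
# toy dense layers for the forward/inverse models, and one-hot encoded
# discrete actions.
import tensorflow as tf

prev_feature = tf.random.normal([4, 32])
feature = tf.random.normal([4, 32])
prev_action = tf.one_hot([0, 2, 1, 3], depth=4)

forward_net = tf.keras.layers.Dense(32)
inverse_net = tf.keras.layers.Dense(4)

# Forward model: predict the next feature; gradients stop at both features.
forward_pred = forward_net(
    tf.concat([tf.stop_gradient(prev_feature), prev_action], axis=-1))
forward_loss = 0.5 * tf.reduce_mean(
    tf.square(tf.stop_gradient(feature) - forward_pred), axis=-1)

# Inverse model: predict the action from two consecutive features.
action_pred = inverse_net(tf.concat([prev_feature, feature], axis=-1))
inverse_loss = tf.nn.softmax_cross_entropy_with_logits(
    labels=prev_action, logits=action_pred)

intrinsic_reward = tf.stop_gradient(forward_loss)  # curiosity bonus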
def gradient_descent_step(loss,
                          variables,
                          stop_grads,
                          allow_grads_to_batch_norm_vars,
                          learning_rate,
                          get_update_ops=True):
    """Returns the updated vars after one step of gradient descent."""
    grads = tf.gradients(loss, variables)

    if stop_grads:
        grads = [tf.stop_gradient(dv) for dv in grads]

    def _apply_grads(variables, grads):
        """Applies gradients using SGD on a list of variables."""
        v_new, update_ops = [], []
        for (v, dv) in zip(variables, grads):
            if (not allow_grads_to_batch_norm_vars and
                    ('offset' in v.name or 'scale' in v.name)):
                updated_value = v  # no update.
            else:
                updated_value = v - learning_rate * dv  # gradient descent update.
                if get_update_ops:
                    update_ops.append(tf.assign(v, updated_value))
            v_new.append(updated_value)
        return v_new, update_ops

    updated_vars, update_ops = _apply_grads(variables, grads)
    return {'updated_vars': updated_vars, 'update_ops': update_ops}
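
# A hypothetical usage sketch for gradient_descent_step, assuming TF1 graph
# mode (tf.compat.v1), since the helper relies on tf.gradients and tf.assign.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

w = tf.get_variable('w', initializer=[1.0, -2.0])
loss = tf.reduce_sum(tf.square(w))
result = gradient_descent_step(
    loss, [w],
    stop_grads=False,
    allow_grads_to_batch_norm_vars=False,
    learning_rate=0.1,
    get_update_ops=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(result['update_ops'])  # applies w <- w - 0.1 * dloss/dw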
def _build_train_op(self):
    """Builds the training op for Rainbow.

    Returns:
      train_op: An op performing one step of training.
    """
    replay_action_one_hot = tf.one_hot(
        self._replay.actions, self.num_actions, 1., 0., name='action_one_hot')
    replay_chosen_q = tf.reduce_sum(
        self._replay_qs * replay_action_one_hot,
        reduction_indices=1,
        name='replay_chosen_q')
    target = tf.stop_gradient(self._build_target_q_op())
    loss = tf.losses.huber_loss(
        target, replay_chosen_q, reduction=tf.losses.Reduction.NONE)
    update_priorities_op = self._replay.tf_set_priority(
        self._replay.indices, tf.sqrt(loss + 1e-10))
    target_priorities = self._replay.tf_get_priority(self._replay.indices)
    target_priorities = tf.math.add(target_priorities, 1e-10)
    target_priorities = 1.0 / tf.sqrt(target_priorities)
    target_priorities /= tf.reduce_max(target_priorities)
    weighted_loss = target_priorities * loss
    with tf.control_dependencies([update_priorities_op]):
        return self.optimizer.minimize(
            tf.reduce_mean(weighted_loss)), weighted_loss
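
# A standalone arithmetic sketch of the prioritized-replay weighting above:
# new priorities are sqrt(loss), and the importance weights are the
# normalized inverse square roots of the priorities fetched from the buffer.
# The tensor values here are illustrative only.
import tensorflow as tf

loss = tf.constant([0.5, 0.1, 2.0])      # per-transition Huber losses
new_priorities = tf.sqrt(loss + 1e-10)   # what tf_set_priority would store
fetched = tf.constant([0.7, 0.3, 1.4])   # what tf_get_priority would return
weights = 1.0 / tf.sqrt(fetched + 1e-10)
weights /= tf.reduce_max(weights)        # largest weight is 1.0
weighted_loss = weights * loss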
def safety_critic_loss(tf_agent,
                       safety_critic,
                       time_steps,
                       actions,
                       next_time_steps,
                       safety_rewards,
                       weights=None):
    """Returns a critic loss with safety."""
    next_actions, next_log_pis = tf_agent._actions_and_log_probs(  # pylint: disable=protected-access
        next_time_steps)
    del next_log_pis
    target_input = (next_time_steps.observation[0], next_actions[0])
    target_q_values, unused_network_state1 = safety_critic(
        target_input, next_time_steps.step_type[0])
    target_q_values = tf.nn.sigmoid(target_q_values)
    safety_rewards = tf.to_float(safety_rewards)
    td_targets = tf.stop_gradient(safety_rewards + (1 - safety_rewards) *
                                  next_time_steps.discount * target_q_values)
    td_targets = tf.squeeze(td_targets)

    pred_input = (time_steps.observation[0], actions[0])
    pred_td_targets, unused_network_state1 = safety_critic(
        pred_input, time_steps.step_type[0])
    loss = tf.losses.sigmoid_cross_entropy(td_targets, pred_td_targets)

    if weights is not None:
        loss *= tf.to_float(tf.squeeze(weights))

    # Take the mean across the batch.
    loss = tf.reduce_mean(input_tensor=loss)
    return loss
def _build_train_op(self):
    """Builds the training op for Rainbow.

    Returns:
      train_op: An op performing one step of training.
    """
    target_distribution = tf.stop_gradient(self._build_target_distribution())

    # size of indices: batch_size x 1.
    indices = tf.range(tf.shape(self._replay_logits)[0])[:, None]
    # size of reshaped_actions: batch_size x 2.
    reshaped_actions = tf.concat([indices, self._replay.actions[:, None]], 1)
    # For each element of the batch, fetch the logits for its selected action.
    chosen_action_logits = tf.gather_nd(self._replay_logits, reshaped_actions)

    loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=target_distribution, logits=chosen_action_logits)

    optimizer = tf.train.AdamOptimizer(
        learning_rate=self.learning_rate, epsilon=self.optimizer_epsilon)

    update_priorities_op = self._replay.tf_set_priority(
        self._replay.indices, tf.sqrt(loss + 1e-10))

    target_priorities = self._replay.tf_get_priority(self._replay.indices)
    target_priorities = tf.math.add(target_priorities, 1e-10)
    target_priorities = 1.0 / tf.sqrt(target_priorities)
    target_priorities /= tf.reduce_max(target_priorities)

    weighted_loss = target_priorities * loss

    with tf.control_dependencies([update_priorities_op]):
        return optimizer.minimize(
            tf.reduce_mean(weighted_loss)), weighted_loss
def actor_loss_fn(dqda, action):
    if self._dqda_clipping:
        dqda = tf.clip_by_value(dqda, -self._dqda_clipping,
                                self._dqda_clipping)
    loss = 0.5 * losses.element_wise_squared_loss(
        tf.stop_gradient(dqda + action), action)
    loss = tf.reduce_sum(loss, axis=list(range(1, len(loss.shape))))
    return loss
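
# A standalone sketch of the DPG "dqda trick" used above: the squared error
# against stop_gradient(dqda + action) has gradient -dqda with respect to the
# action, so minimizing this loss ascends the critic's value.
import tensorflow as tf

action = tf.Variable([[0.2, -0.4]])
dqda = tf.constant([[1.0, 0.5]])   # dQ/da supplied by the critic
with tf.GradientTape() as tape:
    loss = 0.5 * tf.reduce_sum(
        tf.square(tf.stop_gradient(dqda + action) - action))
grad = tape.gradient(loss, action)  # equals -dqda at the current action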
def train_step(self,
               time_step: ActionTimeStep,
               state,
               calc_intrinsic_reward=True):
    """
    Args:
        time_step (ActionTimeStep): input time_step data
        state (tuple): empty tuple ()
        calc_intrinsic_reward (bool): if False, only return the losses
    Returns:
        TrainStep:
            outputs: empty tuple ()
            state: empty tuple ()
            info: ICMInfo
    """
    observation = time_step.observation

    if self._stacked_frames:
        # Assuming stacking in the last dim, we only keep the last frame.
        observation = observation[..., -1:]

    if self._observation_normalizer is not None:
        observation = self._observation_normalizer.normalize(observation)

    if self._encoder_net is not None:
        observation = tf.stop_gradient(self._encoder_net(observation)[0])

    pred_embedding, _ = self._predictor_net(observation)
    target_embedding, _ = self._target_net(observation)

    loss = tf.reduce_sum(
        tf.square(pred_embedding - tf.stop_gradient(target_embedding)),
        axis=-1)

    intrinsic_reward = ()
    if calc_intrinsic_reward:
        intrinsic_reward = tf.stop_gradient(loss)
        if self._reward_normalizer:
            intrinsic_reward = self._reward_normalizer.normalize(
                intrinsic_reward, clip_value=self._reward_clip_value)

    return AlgorithmStep(
        outputs=(),
        state=(),
        info=ICMInfo(reward=intrinsic_reward, loss=LossInfo(loss=loss)))
def _four_layer_convnet(inputs,
                        is_training,
                        scope,
                        weight_decay,
                        reuse=tf.AUTO_REUSE,
                        params=None,
                        moments=None,
                        depth_multiplier=1.0,
                        backprop_through_moments=True,
                        use_bounded_activation=False,
                        keep_spatial_dims=False):
    """A four-layer-convnet architecture."""
    layer = tf.stop_gradient(inputs)
    model_params_keys, model_params_vars = [], []
    moments_keys, moments_vars = [], []

    with tf.variable_scope(scope, reuse=reuse):
        for i in range(4):
            with tf.variable_scope('layer_{}'.format(i), reuse=reuse):
                depth = int(64 * depth_multiplier)
                layer, conv_bn_params, conv_bn_moments = conv_bn(
                    layer, [3, 3],
                    depth,
                    stride=1,
                    weight_decay=weight_decay,
                    params=params,
                    moments=moments,
                    is_training=is_training,
                    backprop_through_moments=backprop_through_moments)
                model_params_keys.extend(conv_bn_params.keys())
                model_params_vars.extend(conv_bn_params.values())
                moments_keys.extend(conv_bn_moments.keys())
                moments_vars.extend(conv_bn_moments.values())

            if use_bounded_activation:
                layer = tf.nn.relu6(layer)
            else:
                layer = tf.nn.relu(layer)
            layer = tf.layers.max_pooling2d(layer, [2, 2], 2)
            logging.info('Output of block %d: %s', i, layer.shape)

        model_params = collections.OrderedDict(
            zip(model_params_keys, model_params_vars))
        moments = collections.OrderedDict(zip(moments_keys, moments_vars))
        if not keep_spatial_dims:
            layer = tf.layers.flatten(layer)
        return_dict = {
            'embeddings': layer,
            'params': model_params,
            'moments': moments
        }
        return return_dict
def train_step(self, inputs, state):
    """
    Args:
        inputs (tuple): observation and previous action
        state (Tensor): previous feature, used as input to the forward and
            inverse models
    Returns:
        TrainStep:
            outputs: intrinsic reward
            state: feature
            info: LossInfo, with extra being ICMLossInfo
    """
    feature, prev_action = inputs
    if self._encoding_net is not None:
        feature, _ = self._encoding_net(feature)
    prev_feature = state
    prev_action = self._encode_action(prev_action)

    forward_pred, _ = self._forward_net(
        inputs=[tf.stop_gradient(prev_feature), prev_action])
    forward_loss = 0.5 * tf.reduce_mean(
        tf.square(tf.stop_gradient(feature) - forward_pred), axis=-1)

    action_pred, _ = self._inverse_net(inputs=[prev_feature, feature])
    if tensor_spec.is_discrete(self._action_spec):
        inverse_loss = tf.nn.softmax_cross_entropy_with_logits(
            labels=prev_action, logits=action_pred)
    else:
        inverse_loss = 0.5 * tf.reduce_mean(
            tf.square(prev_action - action_pred), axis=-1)

    intrinsic_reward = tf.stop_gradient(forward_loss)
    intrinsic_reward = self._reward_normalizer.normalize(intrinsic_reward)

    return AlgorithmStep(
        outputs=intrinsic_reward,
        state=feature,
        info=LossInfo(
            loss=forward_loss + inverse_loss,
            extra=ICMLossInfo(
                forward_loss=forward_loss, inverse_loss=inverse_loss)))
def train_step(self,
               time_step: ActionTimeStep,
               state,
               calc_intrinsic_reward=True):
    """
    Args:
        time_step (ActionTimeStep): input time_step data, where the
            observation is the skill-augmented observation
        state (Tensor): state for DIAYN (previous skill)
        calc_intrinsic_reward (bool): if False, only return the losses
    Returns:
        TrainStep:
            outputs: empty tuple ()
            state: skill
            info (DIAYNInfo):
    """
    observations_aug = time_step.observation
    step_type = time_step.step_type
    observation, skill = observations_aug
    prev_skill = state

    # Default to the raw observation when no encoding network is given.
    feature = observation
    if self._encoding_net is not None:
        feature, _ = self._encoding_net(observation)

    skill_pred, _ = self._discriminator_net(inputs=feature)
    skill_discriminate_loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=prev_skill, logits=skill_pred)

    valid_masks = tf.cast(tf.not_equal(step_type, StepType.FIRST), tf.float32)
    skill_discriminate_loss = skill_discriminate_loss * valid_masks

    intrinsic_reward = ()
    if calc_intrinsic_reward:
        # use negative cross-entropy as reward
        # neglect neg-prior term as it is constant
        intrinsic_reward = tf.stop_gradient(-skill_discriminate_loss)
        intrinsic_reward = self._reward_normalizer.normalize(intrinsic_reward)

    return AlgorithmStep(
        outputs=(),
        state=skill,
        info=DIAYNInfo(
            reward=intrinsic_reward,
            loss=LossInfo(
                loss=skill_discriminate_loss,
                extra=dict(
                    skill_discriminate_loss=skill_discriminate_loss))))
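
# A standalone sketch of the DIAYN discriminator reward above, assuming a toy
# dense discriminator over one-hot skills.
import tensorflow as tf

num_skills = 4
prev_skill = tf.one_hot([1, 3], depth=num_skills)
feature = tf.random.normal([2, 16])
discriminator = tf.keras.layers.Dense(num_skills)

logits = discriminator(feature)
discriminate_loss = tf.nn.softmax_cross_entropy_with_logits(
    labels=prev_skill, logits=logits)
# Negative cross-entropy as the intrinsic reward; the constant log-prior
# term is omitted, as in the function above.
intrinsic_reward = tf.stop_gradient(-discriminate_loss)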
def _build_train_op(self):
    """Builds a training op.

    Returns:
      train_op: An op performing one step of training.
    """
    replay_action_one_hot = tf.one_hot(
        self._replay.actions, self.num_actions, 1., 0., name='action_one_hot')
    replay_chosen_q = tf.reduce_sum(
        self._replay_qs * replay_action_one_hot,
        reduction_indices=1,
        name='replay_chosen_q')
    target = tf.stop_gradient(self._build_target_q_op())
    loss = tf.losses.huber_loss(
        target, replay_chosen_q, reduction=tf.losses.Reduction.NONE)
    return self.optimizer.minimize(tf.reduce_mean(loss))
def _build_train_op(self):
    """Builds a training op.

    Returns:
      An op performing one step of training from replay data.
    """
    # click_indicator: [B, S]
    # q_values: [B, A]
    # actions: [B, S]
    # slate_q_values: [B, S]
    # replay_click_q: [B]
    click_indicator = self._replay.rewards[:, :, self._click_response_index]
    slate_q_values = tf.compat.v1.batch_gather(
        self._replay_net_outputs.q_values,
        tf.cast(self._replay.actions, dtype=tf.int32))
    # Only get the Q from the clicked document.
    replay_click_q = tf.reduce_sum(
        input_tensor=slate_q_values * click_indicator,
        axis=1,
        name='replay_click_q')

    target = tf.stop_gradient(self._build_target_q_op())

    clicked = tf.reduce_sum(input_tensor=click_indicator, axis=1)
    clicked_indices = tf.squeeze(
        tf.compat.v1.where(tf.equal(clicked, 1)), axis=1)
    # clicked_indices is a vector and tf.gather selects the batch dimension.
    q_clicked = tf.gather(replay_click_q, clicked_indices)
    target_clicked = tf.gather(target, clicked_indices)

    def get_train_op():
        loss = tf.reduce_mean(
            input_tensor=tf.square(q_clicked - target_clicked))
        if self.summary_writer is not None:
            with tf.compat.v1.variable_scope('Losses'):
                tf.compat.v1.summary.scalar('Loss', loss)
        return loss

    loss = tf.cond(
        pred=tf.greater(tf.reduce_sum(input_tensor=clicked), 0),
        true_fn=get_train_op,
        false_fn=lambda: tf.constant(0.),
        name='')

    return self.optimizer.minimize(loss)
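
# A standalone sketch of how the clicked document's Q-value is selected
# above, using tf.gather with batch_dims in place of tf.compat.v1.batch_gather;
# tensor values are illustrative only.
import tensorflow as tf

q_values = tf.constant([[0.1, 0.4, 0.2],   # [B=2, A=3] per-document Q-values
                        [0.3, 0.0, 0.9]])
slates = tf.constant([[2, 0], [1, 2]])     # [B, S] recommended doc indices
click = tf.constant([[0., 1.], [1., 0.]])  # [B, S] click indicator
slate_q_values = tf.gather(q_values, slates, batch_dims=1)      # [B, S]
replay_click_q = tf.reduce_sum(slate_q_values * click, axis=1)  # [B]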
def classification_loss_fn(logits, labels, num_valid_voxels=None, weights=1.0):
    """Semantic segmentation cross entropy loss."""
    logits_rank = len(logits.get_shape().as_list())
    labels_rank = len(labels.get_shape().as_list())
    if logits_rank != labels_rank:
        raise ValueError('Logits and labels should have the same rank.')
    if logits_rank != 2 and logits_rank != 3:
        raise ValueError(
            'Logits and labels should have either 2 or 3 dimensions.')
    if logits_rank == 2:
        if num_valid_voxels is not None:
            raise ValueError(
                '`num_valid_voxels` should be None if not using batched '
                'logits.')
    elif logits_rank == 3:
        if num_valid_voxels is None:
            raise ValueError(
                '`num_valid_voxels` cannot be None if using batched logits.')
    if logits_rank == 3:
        if (isinstance(weights, tf.Tensor) and
                len(weights.get_shape().as_list()) == 3):
            use_weights = True
        else:
            use_weights = False
        batch_size = logits.get_shape().as_list()[0]
        logits_list = []
        labels_list = []
        weights_list = []
        for i in range(batch_size):
            num_valid_voxels_i = num_valid_voxels[i]
            logits_list.append(logits[i, 0:num_valid_voxels_i, :])
            labels_list.append(labels[i, 0:num_valid_voxels_i, :])
            if use_weights:
                weights_list.append(weights[i, 0:num_valid_voxels_i, :])
        logits = tf.concat(logits_list, axis=0)
        labels = tf.concat(labels_list, axis=0)
        if use_weights:
            weights = tf.concat(weights_list, axis=0)
    weights = tf.convert_to_tensor(weights, dtype=tf.float32)
    if labels.get_shape().as_list()[-1] == 1:
        num_classes = logits.get_shape().as_list()[-1]
        labels = tf.one_hot(tf.reshape(labels, shape=[-1]), num_classes)
    losses = tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.stop_gradient(labels), logits=logits)
    return tf.reduce_mean(losses * tf.reshape(weights, [-1]))
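
# A hypothetical usage sketch for the rank-2 (non-batched) path of
# classification_loss_fn, assuming one-hot labels so the internal tf.one_hot
# conversion is skipped and the default scalar weight broadcasts.
import tensorflow as tf

logits = tf.random.normal([5, 3])              # 5 voxels, 3 classes
labels = tf.one_hot([0, 2, 1, 1, 0], depth=3)
loss = classification_loss_fn(logits, labels)  # scalar mean cross-entropy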
def policy_loss(self):
    with tf.name_scope('policy_loss'):
        log_probs = [
            dist.log_prob(self.input_actions[name])
            for name, dist in self.model.policy.items()
        ]
        log_probs = tf.stack(log_probs, axis=-1)
        log_probs = log_probs * tf.gather(
            self.function_args_mask, self.input_actions['function_id'])
        advantage = self.input_returns - self.model.value
        policy_loss = -tf.reduce_mean(
            tf.reduce_sum(log_probs, axis=-1) *
            tf.stop_gradient(advantage)) * self.policy_factor
        tf.summary.scalar('policy_loss', policy_loss, family='losses')
        return policy_loss
def optimizer_update(iterate_collection, iteration_idx, objective_fn,
                     update_fn, get_params_fn, first_order, clip_grad_norm):
    """Returns the next iterate in the optimization of objective_fn wrt variables.

    Args:
      iterate_collection: A (potentially structured) container of tf.Tensors
        corresponding to the state of the current iterate.
      iteration_idx: An int Tensor; the iteration number.
      objective_fn: Callable that takes in variables and produces the value of
        the objective function.
      update_fn: Callable that takes in the gradient of the objective function
        and the current iterate and produces the next iterate.
      get_params_fn: Callable that takes in an iterate and returns the
        parameters (variables) of that iterate to differentiate with respect
        to.
      first_order: If True, prevent the computation of higher order gradients.
      clip_grad_norm: If not None, gradient dimensions are independently
        clipped to lie in the interval [-clip_grad_norm, clip_grad_norm].
    """
    variables = [get_params_fn(iterate) for iterate in iterate_collection]

    if tf.executing_eagerly():
        with tf.GradientTape(persistent=True) as g:
            g.watch(variables)
            loss = objective_fn(variables, iteration_idx)
        grads = g.gradient(loss, variables)
    else:
        loss = objective_fn(variables, iteration_idx)
        grads = tf.gradients(ys=loss, xs=variables)

    if clip_grad_norm:
        grads = [
            tf.clip_by_value(grad, -1 * clip_grad_norm, clip_grad_norm)
            for grad in grads
        ]

    if first_order:
        grads = [tf.stop_gradient(dv) for dv in grads]

    return [
        update_fn(i=iteration_idx, grad=dv, state=s)
        for (s, dv) in zip(iterate_collection, grads)
    ]
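
# A hypothetical usage sketch for optimizer_update in eager mode, assuming
# iterates are bare tensors (so get_params_fn is the identity) and update_fn
# is plain SGD with an illustrative step size.
import tensorflow as tf

def sgd_update(i, grad, state):
    return state - 0.1 * grad

def objective(variables, iteration_idx):
    return tf.add_n([tf.reduce_sum(tf.square(v)) for v in variables])

next_iterates = optimizer_update(
    iterate_collection=[tf.constant([1.0, -2.0])],
    iteration_idx=tf.constant(0),
    objective_fn=objective,
    update_fn=sgd_update,
    get_params_fn=lambda iterate: iterate,
    first_order=True,
    clip_grad_norm=None)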
def _build_train_op(self):
    """Builds a training op.

    Returns:
      train_op: An op performing one step of training from replay data.
    """
    replay_next_target_value = tf.reduce_max(
        self._replay_next_target_net_outputs.q_values, 1)
    replay_target_value = tf.reduce_max(
        self._replay_target_net_outputs.q_values, 1)

    replay_action_one_hot = tf.one_hot(
        self._replay.actions, self.num_actions, 1., 0., name='action_one_hot')
    replay_chosen_q = tf.reduce_sum(
        self._replay_net_outputs.q_values * replay_action_one_hot,
        axis=1,
        name='replay_chosen_q')
    replay_target_chosen_q = tf.reduce_sum(
        self._replay_target_net_outputs.q_values * replay_action_one_hot,
        axis=1,
        name='replay_chosen_q')

    augmented_rewards = self._replay.rewards - self.alpha * (
        replay_target_value - replay_target_chosen_q)

    target = (
        augmented_rewards + self.cumulative_gamma * replay_next_target_value *
        (1. - tf.cast(self._replay.terminals, tf.float32)))
    target = tf.stop_gradient(target)

    loss = tf.losses.huber_loss(
        target, replay_chosen_q, reduction=tf.losses.Reduction.NONE)
    if self.summary_writer is not None:
        with tf.variable_scope('Losses'):
            tf.summary.scalar('HuberLoss', tf.reduce_mean(loss))
    return self.optimizer.minimize(tf.reduce_mean(loss))
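
# A standalone arithmetic sketch of the advantage-augmented target above:
# the reward is penalized by alpha times the gap between the target net's
# best Q-value and the Q-value of the action actually taken. The constants
# and tensor values are illustrative only.
import tensorflow as tf

rewards = tf.constant([1.0, 0.0])
terminals = tf.constant([0.0, 1.0])
alpha, cumulative_gamma = 0.9, 0.99
replay_target_value = tf.constant([2.0, 1.5])       # max_a Q_target(s, a)
replay_target_chosen_q = tf.constant([1.8, 1.0])    # Q_target(s, a_taken)
replay_next_target_value = tf.constant([2.1, 1.7])  # max_a Q_target(s', a)

augmented_rewards = rewards - alpha * (
    replay_target_value - replay_target_chosen_q)
target = tf.stop_gradient(
    augmented_rewards +
    cumulative_gamma * replay_next_target_value * (1. - terminals))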
def learn_metric(self, verbose=False):
    """Approximate the bisimulation metric by learning.

    Args:
      verbose: bool, whether to print verbose messages.
    """
    summary_writer = tf.summary.FileWriter(self.base_dir)
    global_step = tf.Variable(0, trainable=False)
    inc_global_step_op = tf.assign_add(global_step, 1)
    bisim_horizon = 0.0
    bisim_horizon_discount_value = 1.0
    if self.use_decayed_learning_rate:
        learning_rate = tf.train.exponential_decay(
            self.starting_learning_rate,
            global_step,
            self.num_iterations,
            self.learning_rate_decay,
            staircase=self.staircase)
    else:
        learning_rate = self.starting_learning_rate
    tf.summary.scalar('Learning/LearningRate', learning_rate)
    optimizer = tf.train.AdamOptimizer(
        learning_rate=learning_rate, epsilon=self.epsilon)
    train_op = self._build_train_op(optimizer)
    sync_op = self._build_sync_op()
    eval_op = tf.stop_gradient(self._build_eval_metric())
    eval_states = []
    # Build the evaluation tensor.
    for state in range(self.num_states):
        row, col = self.inverse_index_states[state]
        # We make the evaluation states at the center of each grid cell.
        eval_states.append([row + 0.5, col + 0.5])
    eval_states = np.array(eval_states, dtype=np.float64)
    normalized_bisim_metric = (
        self.bisim_metric / np.linalg.norm(self.bisim_metric))
    metric_errors = []
    average_metric_errors = []
    normalized_metric_errors = []
    average_normalized_metric_errors = []
    saver = tf.train.Saver(max_to_keep=3)
    with tf.Session() as sess:
        summary_writer.add_graph(graph=tf.get_default_graph())
        sess.run(tf.global_variables_initializer())
        merged_summaries = tf.summary.merge_all()
        for i in range(self.num_iterations):
            sampled_states = np.random.randint(
                self.num_states, size=(self.batch_size,))
            sampled_actions = np.random.randint(4, size=(self.batch_size,))
            if self.add_noise:
                sampled_noise = np.clip(
                    np.random.normal(0, 0.1, size=(self.batch_size, 2)),
                    -0.3, 0.3)
            sampled_action_names = [self.actions[x] for x in sampled_actions]
            next_states = [
                self.next_states[a][s]
                for s, a in zip(sampled_states, sampled_action_names)
            ]
            rewards = np.array([
                self.rewards[a][s]
                for s, a in zip(sampled_states, sampled_action_names)
            ])
            states = np.array(
                [self.inverse_index_states[x] for x in sampled_states])
            next_states = np.array(
                [self.inverse_index_states[x] for x in next_states])
            states = states.astype(np.float64)
            states += 0.5  # Place points in center of grid.
            next_states = next_states.astype(np.float64)
            next_states += 0.5
            if self.add_noise:
                states += sampled_noise
                next_states += sampled_noise
            _, summary = sess.run(
                [train_op, merged_summaries],
                feed_dict={
                    self.s1_ph: states,
                    self.s2_ph: next_states,
                    self.action_ph: sampled_actions,
                    self.rewards_ph: rewards,
                    self.bisim_horizon_ph: bisim_horizon,
                    self.eval_states_ph: eval_states
                })
            summary_writer.add_summary(summary, i)
            if self.double_period_halfway and i > self.num_iterations / 2.:
                self.target_update_period *= 2
                self.double_period_halfway = False
            if i % self.target_update_period == 0:
                bisim_horizon = 1.0 - bisim_horizon_discount_value
                bisim_horizon_discount_value *= self.bisim_horizon_discount
                sess.run(sync_op)
            # Now compute difference with exact metric.
            self.learned_distance = sess.run(
                eval_op, feed_dict={self.eval_states_ph: eval_states})
            self.learned_distance = np.reshape(
                self.learned_distance, (self.num_states, self.num_states))
            metric_difference = np.max(
                abs(self.learned_distance - self.bisim_metric))
            average_metric_difference = np.mean(
                abs(self.learned_distance - self.bisim_metric))
            normalized_learned_distance = (
                self.learned_distance / np.linalg.norm(self.learned_distance))
            normalized_metric_difference = np.max(
                abs(normalized_learned_distance - normalized_bisim_metric))
            average_normalized_metric_difference = np.mean(
                abs(normalized_learned_distance - normalized_bisim_metric))
            error_summary = tf.Summary(value=[
                tf.Summary.Value(
                    tag='Approx/Error', simple_value=metric_difference),
                tf.Summary.Value(
                    tag='Approx/AvgError',
                    simple_value=average_metric_difference),
                tf.Summary.Value(
                    tag='Approx/NormalizedError',
                    simple_value=normalized_metric_difference),
                tf.Summary.Value(
                    tag='Approx/AvgNormalizedError',
                    simple_value=average_normalized_metric_difference),
            ])
            summary_writer.add_summary(error_summary, i)
            sess.run(inc_global_step_op)
            if i % 100 == 0:
                # Collect statistics every 100 steps.
                metric_errors.append(metric_difference)
                average_metric_errors.append(average_metric_difference)
                normalized_metric_errors.append(normalized_metric_difference)
                average_normalized_metric_errors.append(
                    average_normalized_metric_difference)
                saver.save(
                    sess,
                    os.path.join(self.base_dir, 'tf_ckpt'),
                    global_step=i)
            if self.debug and i % 100 == 0:
                self.pretty_print_metric(metric_type='learned')
                print('Iteration: {}'.format(i))
                print('Metric difference: {}'.format(metric_difference))
                print('Normalized metric difference: {}'.format(
                    normalized_metric_difference))
        if self.add_noise:
            # Finally, if we have noise, we draw a bunch of samples to get
            # estimates of the distances between states.
            sampled_distances = {}
            for _ in range(self.total_final_samples):
                eval_states = []
                for state in range(self.num_states):
                    row, col = self.inverse_index_states[state]
                    # We make the evaluation states at the center of each
                    # grid cell.
                    eval_states.append([row + 0.5, col + 0.5])
                eval_states = np.array(eval_states, dtype=np.float64)
                eval_noise = np.clip(
                    np.random.normal(0, 0.1, size=(self.num_states, 2)),
                    -0.3, 0.3)
                eval_states += eval_noise
                distance_samples = sess.run(
                    eval_op, feed_dict={self.eval_states_ph: eval_states})
                distance_samples = np.reshape(
                    distance_samples, (self.num_states, self.num_states))
                for s1 in range(self.num_states):
                    for s2 in range(self.num_states):
                        sampled_distances[(tuple(eval_states[s1]),
                                           tuple(eval_states[s2]))] = (
                                               distance_samples[s1, s2])
        else:
            # Otherwise we just use the last evaluation metric.
            sampled_distances = self.learned_distance
    learned_statistics = {
        'num_iterations': self.num_iterations,
        'metric_errors': metric_errors,
        'average_metric_errors': average_metric_errors,
        'normalized_metric_errors': normalized_metric_errors,
        'average_normalized_metric_errors': average_normalized_metric_errors,
        'learned_distances': sampled_distances,
    }
    self.statistics['learned'] = learned_statistics
    if verbose:
        self.pretty_print_metric(metric_type='learned')
def bn(x, params=None, moments=None, backprop_through_moments=True):
    """Batch normalization.

    The usage should be as follows: If x is the support images, moments
    should be None so that they are computed from the support set examples.
    On the other hand, if x is the query images, the moments argument should
    be used in order to pass in the mean and var that were computed from the
    support set.

    Args:
      x: inputs.
      params: None or a dict containing the values of the offset and scale
        params.
      moments: None or a dict containing the values of the mean and var to
        use for batch normalization.
      backprop_through_moments: Whether to allow gradients to flow through
        the given support set moments. Only applies to non-transductive
        batch norm.

    Returns:
      output: The result of applying batch normalization to the input.
      params: The updated params.
      moments: The updated moments.
    """
    params_keys, params_vars, moments_keys, moments_vars = [], [], [], []

    with tf.variable_scope('batch_norm'):
        scope_name = tf.get_variable_scope().name
        if moments is None:
            # If not provided, compute the mean and var of the current batch.
            mean, var = tf.nn.moments(
                x, axes=list(range(len(x.shape) - 1)), keep_dims=True)
        else:
            if backprop_through_moments:
                mean = moments[scope_name + '/mean']
                var = moments[scope_name + '/var']
            else:
                # This variant does not yield good results.
                mean = tf.stop_gradient(moments[scope_name + '/mean'])
                var = tf.stop_gradient(moments[scope_name + '/var'])

        moments_keys += [scope_name + '/mean']
        moments_vars += [mean]
        moments_keys += [scope_name + '/var']
        moments_vars += [var]

        if params is None:
            offset = tf.get_variable(
                'offset',
                shape=mean.get_shape().as_list(),
                initializer=tf.initializers.zeros())
            scale = tf.get_variable(
                'scale',
                shape=var.get_shape().as_list(),
                initializer=tf.initializers.ones())
        else:
            offset = params[scope_name + '/offset']
            scale = params[scope_name + '/scale']

        params_keys += [scope_name + '/offset']
        params_vars += [offset]
        params_keys += [scope_name + '/scale']
        params_vars += [scale]

        output = tf.nn.batch_normalization(x, mean, var, offset, scale,
                                           0.00001)
        params = collections.OrderedDict(zip(params_keys, params_vars))
        moments = collections.OrderedDict(zip(moments_keys, moments_vars))
        return output, params, moments
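
# A minimal sketch of the support/query moment sharing that bn() implements,
# using bare tf.nn.moments in place of the variable-managed version above.
import tensorflow as tf

support = tf.random.normal([8, 5, 5, 3])
query = tf.random.normal([20, 5, 5, 3])

# Support pass: moments come from the support batch itself.
mean, var = tf.nn.moments(support, axes=[0, 1, 2], keepdims=True)
support_out = tf.nn.batch_normalization(support, mean, var, None, None, 1e-5)

# Query pass: reuse the support moments; wrapping them in stop_gradient
# corresponds to backprop_through_moments=False.
query_out = tf.nn.batch_normalization(
    query, tf.stop_gradient(mean), tf.stop_gradient(var), None, None, 1e-5)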
def bn(x,
       params=None,
       moments=None,
       backprop_through_moments=True,
       use_ema=False,
       is_training=True,
       ema_epsilon=.9):
    """Batch normalization.

    The usage should be as follows: If x is the support images, moments
    should be None so that they are computed from the support set examples.
    On the other hand, if x is the query images, the moments argument should
    be used in order to pass in the mean and var that were computed from the
    support set.

    Args:
      x: inputs.
      params: None or a dict containing the values of the offset and scale
        params.
      moments: None or a dict containing the values of the mean and var to
        use for batch normalization.
      backprop_through_moments: Whether to allow gradients to flow through
        the given support set moments. Only applies to non-transductive
        batch norm.
      use_ema: apply moving averages of batch norm statistics, or update
        them, depending on whether we are training or testing. Note that
        passing moments will override this setting, and result in neither
        updating nor using ema statistics. This is important to make sure
        that episodic learners don't update ema statistics a second time
        when processing queries.
      is_training: if use_ema=True, this determines whether to apply the
        moving averages, or update them.
      ema_epsilon: if updating moving averages, use this value for the
        exponential moving averages.

    Returns:
      output: The result of applying batch normalization to the input.
      params: The updated params.
      moments: The updated moments.
    """
    params_keys, params_vars, moments_keys, moments_vars = [], [], [], []

    with tf.variable_scope('batch_norm'):
        scope_name = tf.get_variable_scope().name

        if use_ema:
            ema_shape = [1, 1, 1, x.get_shape().as_list()[-1]]
            mean_ema = tf.get_variable(
                'mean_ema',
                shape=ema_shape,
                initializer=tf.initializers.zeros(),
                trainable=False)
            var_ema = tf.get_variable(
                'var_ema',
                shape=ema_shape,
                initializer=tf.initializers.ones(),
                trainable=False)

        if moments is not None:
            if backprop_through_moments:
                mean = moments[scope_name + '/mean']
                var = moments[scope_name + '/var']
            else:
                # This variant does not yield good results.
                mean = tf.stop_gradient(moments[scope_name + '/mean'])
                var = tf.stop_gradient(moments[scope_name + '/var'])
        elif use_ema and not is_training:
            mean = mean_ema
            var = var_ema
        else:
            # If not provided, compute the mean and var of the current batch.
            replica_ctx = tf.distribute.get_replica_context()
            if replica_ctx:
                # from third_party/tensorflow/python/keras/layers/normalization_v2.py
                axes = list(range(len(x.shape) - 1))
                local_sum = tf.reduce_sum(x, axis=axes, keepdims=True)
                local_squared_sum = tf.reduce_sum(
                    tf.square(x), axis=axes, keepdims=True)
                batch_size = tf.cast(tf.shape(x)[0], tf.float32)
                x_sum, x_squared_sum, global_batch_size = (
                    replica_ctx.all_reduce(
                        'sum', [local_sum, local_squared_sum, batch_size]))

                axes_vals = [(tf.shape(x))[i] for i in range(1, len(axes))]
                multiplier = tf.cast(tf.reduce_prod(axes_vals), tf.float32)
                multiplier = multiplier * global_batch_size

                mean = x_sum / multiplier
                x_squared_mean = x_squared_sum / multiplier
                # var = E(x^2) - E(x)^2
                var = x_squared_mean - tf.square(mean)
            else:
                mean, var = tf.nn.moments(
                    x, axes=list(range(len(x.shape) - 1)), keep_dims=True)

        # Only update ema's if training and we computed the moments in the
        # current call. Note: at test time for episodic learners, ema's may
        # be passed from the support set to the query set, even if it's not
        # really needed.
        if use_ema and is_training and moments is None:
            replica_ctx = tf.distribute.get_replica_context()
            mean_upd = tf.assign(
                mean_ema, mean_ema * ema_epsilon + mean * (1.0 - ema_epsilon))
            var_upd = tf.assign(
                var_ema, var_ema * ema_epsilon + var * (1.0 - ema_epsilon))
            updates = tf.group([mean_upd, var_upd])
            if replica_ctx:
                tf.add_to_collection(
                    tf.GraphKeys.UPDATE_OPS,
                    tf.cond(
                        tf.equal(replica_ctx.replica_id_in_sync_group, 0),
                        lambda: updates, tf.no_op))
            else:
                tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, updates)

        moments_keys += [scope_name + '/mean']
        moments_vars += [mean]
        moments_keys += [scope_name + '/var']
        moments_vars += [var]

        if params is None:
            offset = tf.get_variable(
                'offset',
                shape=mean.get_shape().as_list(),
                initializer=tf.initializers.zeros())
            scale = tf.get_variable(
                'scale',
                shape=var.get_shape().as_list(),
                initializer=tf.initializers.ones())
        else:
            offset = params[scope_name + '/offset']
            scale = params[scope_name + '/scale']

        params_keys += [scope_name + '/offset']
        params_vars += [offset]
        params_keys += [scope_name + '/scale']
        params_vars += [scale]

        output = tf.nn.batch_normalization(x, mean, var, offset, scale,
                                           0.00001)
        params = collections.OrderedDict(zip(params_keys, params_vars))
        moments = collections.OrderedDict(zip(moments_keys, moments_vars))
        return output, params, moments
def relationnet_convnet(inputs,
                        is_training,
                        weight_decay,
                        params=None,
                        moments=None,
                        depth_multiplier=1.0,
                        reuse=tf.AUTO_REUSE,
                        scope='relationnet_convnet',
                        backprop_through_moments=True,
                        use_bounded_activation=False,
                        keep_spatial_dims=False):
    """A 4-layer-convnet architecture for RelationNet embedding.

    This is almost like the `four_layer_convnet` embedding function except
    for the following differences: (1) no padding for the first 3 layers,
    (2) no maxpool on the last (4th) layer, and (3) no flatten.

    Paper: https://arxiv.org/abs/1711.06025
    Code:
    https://github.com/floodsung/LearningToCompare_FSL/blob/master/miniimagenet/miniimagenet_train_few_shot.py

    Args:
      inputs: Tensors of shape [None, ] + image shape, e.g. [15, 84, 84, 3].
      is_training: Whether we are in the training phase.
      weight_decay: float, scaling constant for L2 weight decay on weight
        variables.
      params: None will create new params (or reuse from scope), otherwise an
        ordered dict of convolutional kernels and biases such that
        params['kernel_0'] stores the kernel of the first convolutional
        layer, etc.
      moments: A dict of the means and vars of the different layers to use
        for batch normalization. If not provided, the mean and var are
        computed based on the given inputs.
      depth_multiplier: The depth multiplier for the convnet channels.
      reuse: Whether to reuse the network's weights.
      scope: An optional scope for the tf operations.
      backprop_through_moments: Whether to allow gradients to flow through
        the given support set moments. Only applies to non-transductive
        batch norm.
      use_bounded_activation: Whether to enable bounded activation. This is
        useful for post-training quantization.
      keep_spatial_dims: bool, if True the spatial dimensions are kept.

    Returns:
      A 2D Tensor, where each row is the embedding of an input in inputs.
    """
    layer = tf.stop_gradient(inputs)
    model_params_keys, model_params_vars = [], []
    moments_keys, moments_vars = [], []

    with tf.variable_scope(scope, reuse=reuse):
        for i in range(4):
            with tf.variable_scope('layer_{}'.format(i), reuse=reuse):
                depth = int(64 * depth_multiplier)
                # The original implementation had VALID padding for the first
                # two layers that are followed by pooling. The rest (last
                # two) had `SAME` padding. In our setting, to avoid OOM, we
                # pool (and apply VALID padding) to the first three layers,
                # and use SAME padding only in the last one.
                layer, conv_bn_params, conv_bn_moments = conv_bn(
                    layer, [3, 3],
                    depth,
                    stride=1,
                    weight_decay=weight_decay,
                    padding='VALID' if i < 3 else 'SAME',
                    params=params,
                    moments=moments,
                    is_training=is_training,
                    backprop_through_moments=backprop_through_moments)
                model_params_keys.extend(conv_bn_params.keys())
                model_params_vars.extend(conv_bn_params.values())
                moments_keys.extend(conv_bn_moments.keys())
                moments_vars.extend(conv_bn_moments.values())

            layer = relu(layer, use_bounded_activation=use_bounded_activation)
            if i < 3:
                layer = tf.layers.max_pooling2d(layer, [2, 2], 2)
            tf.logging.info('Output of block %d: %s' % (i, layer.shape))

        model_params = collections.OrderedDict(
            zip(model_params_keys, model_params_vars))
        moments = collections.OrderedDict(zip(moments_keys, moments_vars))
        if not keep_spatial_dims:
            layer = tf.layers.flatten(layer)
        return_dict = {
            'embeddings': layer,
            'params': model_params,
            'moments': moments
        }
        return return_dict
def safety_critic_loss(time_steps,
                       actions,
                       next_time_steps,
                       safety_rewards,
                       get_action,
                       global_step,
                       critic_network=None,
                       target_network=None,
                       target_safety=None,
                       safety_gamma=0.45,
                       loss_fn='bce',
                       metrics=None,
                       debug_summaries=False):
    """Computes the critic loss for SAC training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      safety_rewards: Task-agnostic rewards for safety. 1 is unsafe, 0 is
        safe.
      get_action: Callable mapping timesteps to the current policy's actions.
      global_step: The training step, used when emitting summaries.
      critic_network: The safety critic network being trained.
      target_network: The target safety critic network.
      target_safety: Safety threshold used to binarize predictions for
        non-AUC metrics.
      safety_gamma: Discount factor for the safety critic targets.
      loss_fn: 'bce', 'mse', or a callable loss function.
      metrics: Optional list of tf.keras metrics to update.
      debug_summaries: Whether to emit debug summaries.

    Returns:
      safe_critic_loss: A scalar critic loss.
    """
    with tf.name_scope('safety_critic_loss'):
        next_actions = get_action(next_time_steps)
        target_input = (next_time_steps.observation, next_actions)
        target_q_values, _ = target_network(target_input,
                                            next_time_steps.step_type)
        target_q_values = tf.nn.sigmoid(target_q_values)
        td_targets = tf.stop_gradient(
            safety_rewards + (1 - safety_rewards) * safety_gamma *
            next_time_steps.discount * target_q_values)
        if loss_fn == 'bce' or loss_fn == tf.keras.losses.binary_crossentropy:
            td_targets = tf.nn.sigmoid(td_targets)

        pred_input = (time_steps.observation, actions)
        pred_td_targets, _ = critic_network(
            pred_input, time_steps.step_type, training=True)
        pred_td_targets = tf.nn.sigmoid(pred_td_targets)

        # Loss fns: binary_crossentropy/squared_difference
        if loss_fn == 'mse':
            sc_loss = tf.math.squared_difference(td_targets, pred_td_targets)
        elif loss_fn == 'bce' or loss_fn is None:
            sc_loss = tf.keras.losses.binary_crossentropy(
                td_targets, pred_td_targets)
        elif loss_fn is not None:
            sc_loss = loss_fn(td_targets, pred_td_targets)

        if metrics:
            for metric in metrics:
                if isinstance(metric, tf.keras.metrics.AUC):
                    metric.update_state(safety_rewards, pred_td_targets)
                else:
                    rew_pred = tf.greater_equal(pred_td_targets,
                                                target_safety)
                    metric.update_state(safety_rewards, rew_pred)

        if debug_summaries:
            pred_td_targets = tf.nn.sigmoid(pred_td_targets)
            td_errors = td_targets - pred_td_targets
            common.generate_tensor_summaries('safety_td_errors', td_errors,
                                             global_step)
            common.generate_tensor_summaries('safety_td_targets', td_targets,
                                             global_step)
            common.generate_tensor_summaries('safety_pred_td_targets',
                                             pred_td_targets, global_step)

        return sc_loss
def _resnet(x,
            is_training,
            weight_decay,
            scope,
            reuse=tf.AUTO_REUSE,
            params=None,
            moments=None,
            backprop_through_moments=True,
            use_bounded_activation=False,
            blocks=(2, 2, 2, 2),
            max_stride=None,
            deeplab_alignment=True,
            keep_spatial_dims=False):
    """A ResNet network; ResNet18 by default."""
    x = tf.stop_gradient(x)
    params_keys, params_vars = [], []
    moments_keys, moments_vars = [], []
    assert max_stride in [None, 4, 8, 16, 32], (
        'max_stride must be 4, 8, 16, 32, or None')

    with tf.variable_scope(scope, reuse=reuse):
        # We use DeepLab feature alignment rule to determine the input size.
        # Since the image size in the meta-dataset pipeline is a multiplier
        # of 42, e.g., [42, 84, 168], we align them to the closest sizes that
        # conform to the alignment rule and at the same time are larger. They
        # are [65, 97, 193] respectively. The aligned image size for 224 used
        # in the ResNet work is 225.
        #
        # References:
        # 1. ResNet https://arxiv.org/abs/1512.03385
        # 2. DeepLab https://arxiv.org/abs/1606.00915
        if deeplab_alignment:
            size = tf.cast(tf.shape(x)[1], tf.float32)
            aligned_size = tf.cast(tf.ceil(size / 32.0), tf.int32) * 32 + 1
            x = tf.image.resize_bilinear(
                x, size=[aligned_size, aligned_size], align_corners=True)

        with tf.variable_scope('conv1'):
            x, conv_bn_params, conv_bn_moments = conv_bn(
                x, [7, 7],
                64,
                2,
                weight_decay,
                params=params,
                moments=moments,
                is_training=is_training,
                backprop_through_moments=backprop_through_moments)
            params_keys.extend(conv_bn_params.keys())
            params_vars.extend(conv_bn_params.values())
            moments_keys.extend(conv_bn_moments.keys())
            moments_vars.extend(conv_bn_moments.values())
            x = relu(x, use_bounded_activation=use_bounded_activation)

        def _bottleneck(x,
                        i,
                        depth,
                        stride,
                        params,
                        moments,
                        net_stride=1,
                        net_rate=1):
            """Wrapper for bottleneck."""
            input_rate = net_rate
            output_rate = input_rate
            if i == 0:
                if max_stride and stride * net_stride > max_stride:
                    output_stride = 1
                    output_rate *= stride
                else:
                    output_stride = stride
            else:
                output_stride = 1
            use_project = True if i == 0 else False

            x, bottleneck_params, bottleneck_moments = bottleneck(
                x, (depth, depth),
                output_stride,
                weight_decay,
                params=params,
                moments=moments,
                input_rate=input_rate,
                output_rate=output_rate,
                use_project=use_project,
                is_training=is_training,
                backprop_through_moments=backprop_through_moments)
            net_stride *= output_stride
            return (x, bottleneck_params, bottleneck_moments, net_stride,
                    output_rate)

        net_stride = 4
        net_rate = 1

        with tf.variable_scope('conv2_x'):
            x = tf.nn.max_pool(
                x, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
            for i in range(blocks[0]):
                with tf.variable_scope('bottleneck_%d' % i):
                    x, bottleneck_params, bottleneck_moments, net_stride, net_rate = _bottleneck(
                        x, i, 64, 1, params, moments, net_stride, net_rate)
                    params_keys.extend(bottleneck_params.keys())
                    params_vars.extend(bottleneck_params.values())
                    moments_keys.extend(bottleneck_moments.keys())
                    moments_vars.extend(bottleneck_moments.values())

        with tf.variable_scope('conv3_x'):
            for i in range(blocks[1]):
                with tf.variable_scope('bottleneck_%d' % i):
                    x, bottleneck_params, bottleneck_moments, net_stride, net_rate = _bottleneck(
                        x, i, 128, 2, params, moments, net_stride, net_rate)
                    params_keys.extend(bottleneck_params.keys())
                    params_vars.extend(bottleneck_params.values())
                    moments_keys.extend(bottleneck_moments.keys())
                    moments_vars.extend(bottleneck_moments.values())

        with tf.variable_scope('conv4_x'):
            for i in range(blocks[2]):
                with tf.variable_scope('bottleneck_%d' % i):
                    x, bottleneck_params, bottleneck_moments, net_stride, net_rate = _bottleneck(
                        x, i, 256, 2, params, moments, net_stride, net_rate)
                    params_keys.extend(bottleneck_params.keys())
                    params_vars.extend(bottleneck_params.values())
                    moments_keys.extend(bottleneck_moments.keys())
                    moments_vars.extend(bottleneck_moments.values())

        with tf.variable_scope('conv5_x'):
            for i in range(blocks[3]):
                with tf.variable_scope('bottleneck_%d' % i):
                    x, bottleneck_params, bottleneck_moments, net_stride, net_rate = _bottleneck(
                        x, i, 512, 2, params, moments, net_stride, net_rate)
                    params_keys.extend(bottleneck_params.keys())
                    params_vars.extend(bottleneck_params.values())
                    moments_keys.extend(bottleneck_moments.keys())
                    moments_vars.extend(bottleneck_moments.values())

        if not keep_spatial_dims:
            # x.shape: [?, 1, 1, 512]
            x = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
            x = tf.reshape(x, [-1, 512])

        params = collections.OrderedDict(zip(params_keys, params_vars))
        moments = collections.OrderedDict(zip(moments_keys, moments_vars))
        return_dict = {'embeddings': x, 'params': params, 'moments': moments}
        return return_dict
def _resnet(x,
            is_training,
            weight_decay,
            scope,
            reuse=tf.AUTO_REUSE,
            params=None,
            moments=None,
            backprop_through_moments=True,
            use_bounded_activation=False,
            keep_spatial_dims=False):
    """A ResNet18 network."""
    # `is_training` will be used when start to use moving {var, mean} in
    # batch normalization. This refers to 'meta-training'.
    del is_training
    x = tf.stop_gradient(x)
    params_keys, params_vars = [], []
    moments_keys, moments_vars = [], []

    with tf.variable_scope(scope, reuse=reuse):
        # We use DeepLab feature alignment rule to determine the input size.
        # Since the image size in the meta-dataset pipeline is a multiplier
        # of 42, e.g., [42, 84, 168], we align them to the closest sizes that
        # conform to the alignment rule and at the same time are larger. They
        # are [65, 97, 193] respectively. The aligned image size for 224 used
        # in the ResNet work is 225.
        #
        # References:
        # 1. ResNet https://arxiv.org/abs/1512.03385
        # 2. DeepLab https://arxiv.org/abs/1606.00915
        size = tf.cast(tf.shape(x)[1], tf.float32)
        aligned_size = tf.cast(tf.ceil(size / 32.0), tf.int32) * 32 + 1
        x = tf.image.resize_bilinear(
            x, size=[aligned_size, aligned_size], align_corners=True)

        with tf.variable_scope('conv1'):
            x, conv_bn_params, conv_bn_moments = conv_bn(
                x, [7, 7],
                64,
                2,
                weight_decay,
                params=params,
                moments=moments,
                backprop_through_moments=backprop_through_moments)
            params_keys.extend(conv_bn_params.keys())
            params_vars.extend(conv_bn_params.values())
            moments_keys.extend(conv_bn_moments.keys())
            moments_vars.extend(conv_bn_moments.values())
            x = relu(x, use_bounded_activation=use_bounded_activation)

        def _bottleneck(x, i, depth, stride, params, moments):
            """Wrapper for bottleneck."""
            output_stride = stride if i == 0 else 1
            use_project = True if i == 0 else False
            x, bottleneck_params, bottleneck_moments = bottleneck(
                x, (depth, depth),
                output_stride,
                weight_decay,
                params=params,
                moments=moments,
                use_project=use_project,
                backprop_through_moments=backprop_through_moments)
            return x, bottleneck_params, bottleneck_moments

        with tf.variable_scope('conv2_x'):
            x = tf.nn.max_pool(
                x, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
            for i in range(2):
                with tf.variable_scope('bottleneck_%d' % i):
                    x, bottleneck_params, bottleneck_moments = _bottleneck(
                        x, i, 64, 1, params, moments)
                    params_keys.extend(bottleneck_params.keys())
                    params_vars.extend(bottleneck_params.values())
                    moments_keys.extend(bottleneck_moments.keys())
                    moments_vars.extend(bottleneck_moments.values())

        with tf.variable_scope('conv3_x'):
            for i in range(2):
                with tf.variable_scope('bottleneck_%d' % i):
                    x, bottleneck_params, bottleneck_moments = _bottleneck(
                        x, i, 128, 2, params, moments)
                    params_keys.extend(bottleneck_params.keys())
                    params_vars.extend(bottleneck_params.values())
                    moments_keys.extend(bottleneck_moments.keys())
                    moments_vars.extend(bottleneck_moments.values())

        with tf.variable_scope('conv4_x'):
            for i in range(2):
                with tf.variable_scope('bottleneck_%d' % i):
                    x, bottleneck_params, bottleneck_moments = _bottleneck(
                        x, i, 256, 2, params, moments)
                    params_keys.extend(bottleneck_params.keys())
                    params_vars.extend(bottleneck_params.values())
                    moments_keys.extend(bottleneck_moments.keys())
                    moments_vars.extend(bottleneck_moments.values())

        with tf.variable_scope('conv5_x'):
            for i in range(2):
                with tf.variable_scope('bottleneck_%d' % i):
                    x, bottleneck_params, bottleneck_moments = _bottleneck(
                        x, i, 512, 2, params, moments)
                    params_keys.extend(bottleneck_params.keys())
                    params_vars.extend(bottleneck_params.values())
                    moments_keys.extend(bottleneck_moments.keys())
                    moments_vars.extend(bottleneck_moments.values())

        x = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        # x.shape: [?, 1, 1, 512]
        if not keep_spatial_dims:
            x = tf.reshape(x, [-1, 512])

        params = collections.OrderedDict(zip(params_keys, params_vars))
        moments = collections.OrderedDict(zip(moments_keys, moments_vars))
        return_dict = {'embeddings': x, 'params': params, 'moments': moments}
        return return_dict
def _build_train_op(self, optimizer):
    """Build the TensorFlow graph used to learn the bisimulation metric.

    Args:
      optimizer: a tf.train optimizer.
    Returns:
      A TensorFlow op to minimize the bisimulation loss.
    """
    self.online_network = tf.make_template('Online', self._network_template)
    self.target_network = tf.make_template('Target', self._network_template)
    self.s1_ph = tf.placeholder(
        tf.float64, (self.batch_size, 2), name='s1_ph')
    self.s2_ph = tf.placeholder(
        tf.float64, (self.batch_size, 2), name='s2_ph')
    self.s1_online_distances = self.online_network(
        self._concat_states(self.s1_ph))
    self.s1_target_distances = self.target_network(
        self._concat_states(self.s1_ph))
    self.s2_target_distances = self.target_network(
        self._concat_states(self.s2_ph))
    self.action_ph = tf.placeholder(tf.int32, (self.batch_size,))
    self.rewards_ph = tf.placeholder(tf.float64, (self.batch_size,))
    # We use an expanding horizon for computing the distances.
    self.bisim_horizon_ph = tf.placeholder(tf.float64, ())
    # bisimulation_target_1 = rew_diff + gamma * next_distance.
    bisimulation_target_1 = tf.stop_gradient(self._build_bisimulation_target())
    # bisimulation_target_2 = curr_distance.
    bisimulation_target_2 = tf.stop_gradient(self.s1_target_distances)
    # We slowly taper in the maximum according to the bisim horizon.
    bisimulation_target = tf.maximum(
        bisimulation_target_1, bisimulation_target_2 * self.bisim_horizon_ph)
    # We zero-out diagonal entries, since those are estimating the distance
    # between a state and itself, which we know to be 0.
    diagonal_mask = 1.0 - tf.diag(tf.ones(self.batch_size, dtype=tf.float64))
    diagonal_mask = tf.reshape(diagonal_mask, (self.batch_size**2, 1))
    bisimulation_target *= diagonal_mask
    bisimulation_estimate = self.s1_online_distances
    # We start with a mask that includes everything.
    loss_mask = tf.ones(tf.shape(bisimulation_estimate))
    # We have to enforce that states being compared are done only using the
    # same action.
    indicators = self.action_ph
    indicators = tf.cast(indicators, tf.float64)
    # indicators will initially have shape [batch_size], we first tile it:
    square_ids = tf.tile([indicators], [self.batch_size, 1])
    # We subtract square_ids from its transpose:
    square_ids = square_ids - tf.transpose(square_ids)
    # At this point all zero-entries are the ones with equal IDs.
    # Now we would like to convert the zeros in this matrix to 1s, and make
    # everything else a 0. We can do this with the following operation:
    loss_mask = 1 - tf.abs(tf.sign(square_ids))
    # Now reshape to match the shapes of the estimate and target.
    loss_mask = tf.reshape(loss_mask, (self.batch_size**2, 1))
    larger_targets = bisimulation_target - bisimulation_estimate
    larger_targets_count = tf.reduce_sum(
        tf.cast(larger_targets > 0., tf.float64))
    tf.summary.scalar('Learning/LargerTargets', larger_targets_count)
    tf.summary.scalar('Learning/NumUpdates', tf.count_nonzero(loss_mask))
    tf.summary.scalar('Learning/BisimHorizon', self.bisim_horizon_ph)
    bisimulation_loss = tf.losses.mean_squared_error(
        bisimulation_target, bisimulation_estimate, weights=loss_mask)
    tf.summary.scalar('Learning/loss', bisimulation_loss)
    # Plot average distance between sampled representations.
    average_distance = tf.reduce_mean(bisimulation_estimate)
    tf.summary.scalar('Approx/AverageDistance', average_distance)
    return optimizer.minimize(bisimulation_loss)
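
# A standalone sketch of the action-matching mask constructed above: entry
# (i, j) ends up 1 exactly when sampled actions i and j are equal.
import tensorflow as tf

actions = tf.constant([0, 2, 0, 1], dtype=tf.float64)
square_ids = tf.tile([actions], [4, 1])
square_ids = square_ids - tf.transpose(square_ids)  # zero where actions match
loss_mask = 1 - tf.abs(tf.sign(square_ids))         # 1 where actions match
loss_mask = tf.reshape(loss_mask, (16, 1))          # align with the estimates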