def _actor_train_step(self, exp: Experience, state: DdpgActorState):
    """One DDPG actor training step.

    Computes the actor action for the experience, obtains dQ/da from the
    critic, and builds a surrogate loss whose gradient w.r.t. the action
    equals -dQ/da (the deterministic policy gradient direction).

    Args:
        exp (Experience): experience batch providing observation/step_type.
        state (DdpgActorState): recurrent states for actor and critic nets.

    Returns:
        PolicyStep: action, updated DdpgActorState, and LossInfo whose
        ``loss`` is the summed per-structure actor loss.
    """
    action, actor_state = self._actor_network(exp.observation,
                                              exp.step_type,
                                              network_state=state.actor)
    # Watch only `action` so the tape yields dQ/da, not gradients w.r.t.
    # network variables.
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(action)
        q_value, critic_state = self._critic_network(
            (exp.observation, action), network_state=state.critic)
    dqda = tape.gradient(q_value, action)

    def actor_loss_fn(dqda, action):
        # Optionally clip each gradient component to [-clip, clip].
        if self._dqda_clipping:
            dqda = tf.clip_by_value(dqda, -self._dqda_clipping,
                                    self._dqda_clipping)
        # Surrogate: 0.5 * (stop_grad(dqda + a) - a)^2. Its gradient
        # w.r.t. `a` is -dqda, so minimizing moves `a` along dQ/da.
        loss = 0.5 * losses.element_wise_squared_loss(
            tf.stop_gradient(dqda + action), action)
        # Sum over all non-batch dimensions, keeping the batch axis.
        loss = tf.reduce_sum(loss, axis=list(range(1, len(loss.shape))))
        return loss

    actor_loss = tf.nest.map_structure(actor_loss_fn, dqda, action)
    state = DdpgActorState(actor=actor_state, critic=critic_state)
    info = LossInfo(loss=tf.add_n(tf.nest.flatten(actor_loss)),
                    extra=actor_loss)
    return PolicyStep(action=action, state=state, info=info)
def test_estimated_entropy(self, assume_reparametrization):
    """Checks ``dist_utils.estimated_entropy`` against the analytic entropy.

    Draws a random Normal distribution, compares the sampled entropy
    estimate (and its gradient w.r.t. ``scale``) with the closed-form
    values, and — when reparametrization is NOT assumed — additionally
    checks that differentiating ``est_entropy`` directly (instead of
    ``est_entropy_for_gradient``) gives a near-zero, i.e. wrong, gradient.

    Args:
        assume_reparametrization (bool): forwarded to
            ``dist_utils.estimated_entropy``.
    """
    logging.info("assume_reparametrization=%s" % assume_reparametrization)
    num_samples = 1000000
    # Deterministic per-call seeds so the test is reproducible.
    seed_stream = tfp.distributions.SeedStream(
        seed=1, salt='test_estimated_entropy')
    batch_shape = (2, )
    loc = tf.random.normal(shape=batch_shape, seed=seed_stream())
    scale = tf.abs(tf.random.normal(shape=batch_shape, seed=seed_stream()))
    # persistent=True: the tape is queried up to three times below.
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(scale)
        dist = tfp.distributions.Normal(loc=loc, scale=scale)
        analytic_entropy = dist.entropy()
        est_entropy, est_entropy_for_gradient = dist_utils.estimated_entropy(
            dist=dist,
            seed=seed_stream(),
            assume_reparametrization=assume_reparametrization,
            num_samples=num_samples)
    analytic_grad = tape.gradient(analytic_entropy, scale)
    est_grad = tape.gradient(est_entropy_for_gradient, scale)
    logging.info("scale=%s" % scale)
    logging.info("analytic_entropy=%s" % analytic_entropy)
    logging.info("estimated_entropy=%s" % est_entropy)
    self.assertArrayAlmostEqual(analytic_entropy, est_entropy, 5e-2)
    logging.info("analytic_entropy_grad=%s" % analytic_grad)
    logging.info("estimated_entropy_grad=%s" % est_grad)
    self.assertArrayAlmostEqual(analytic_grad, est_grad, 5e-2)
    if not assume_reparametrization:
        # Without reparametrization, the naive gradient of the entropy
        # estimate w.r.t. scale should be (incorrectly) close to zero.
        est_grad_wrong = tape.gradient(est_entropy, scale)
        logging.info("estimated_entropy_grad_wrong=%s", est_grad_wrong)
        self.assertLess(tf.reduce_max(tf.abs(est_grad_wrong)), 5e-2)
def train_step(tf_agent, safety_critic, batch, safety_rewards, optimizer):
    """Helper function for creating a train step.

    Rewrites the transition rewards with task-safety rewards, computes a
    class-balanced safety-critic loss, and applies one optimizer step to
    the safety critic's trainable variables.

    Args:
        tf_agent: agent providing ``_experience_to_transitions``.
        safety_critic: network whose ``trainable_variables`` are optimized.
        batch: ``(rb_data, buf_info)`` pair from a replay buffer;
            ``buf_info.ids`` indexes into ``safety_rewards``.
        safety_rewards: array of safety labels; presumably binary 0/1 —
            TODO confirm against the caller.
        optimizer: tf.keras optimizer used to apply the gradients.

    Returns:
        The (scalar) safety critic loss tensor.
    """
    rb_data, buf_info = batch
    # Gather the safety rewards matching the sampled buffer ids.
    safe_rew = tf.gather(safety_rewards, buf_info.ids, axis=1)
    time_steps, actions, next_time_steps = tf_agent._experience_to_transitions(  # pylint: disable=protected-access
        rb_data)
    # Replace task rewards with safety rewards (current vs. next step).
    time_steps = time_steps._replace(reward=safe_rew[:, :-1])  # pylint: disable=protected-access
    next_time_steps = next_time_steps._replace(reward=safe_rew[:, 1:])
    # NOTE(review): this divides the total by shape[1] only; if
    # safety_rewards is 2-D the intended fraction of failures would be
    # sum / (shape[0] * shape[1]) — confirm the expected shape.
    fail_pct = safety_rewards.sum() / safety_rewards.shape[1]
    # Inverse-frequency weighting: up-weight the rarer class so positive
    # (failure) and negative examples contribute equally on average.
    loss_weight = 0.5 / ((next_time_steps.reward) * fail_pct +
                         (1 - next_time_steps.reward) * (1 - fail_pct))
    trainable_safety_variables = safety_critic.trainable_variables
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        assert trainable_safety_variables, ('No trainable safety critic variables'
                                            ' to optimize.')
        tape.watch(trainable_safety_variables)
        loss = safety_critic_loss(
            tf_agent,
            safety_critic,
            time_steps,
            actions,
            next_time_steps,
            safety_rewards=next_time_steps.reward,
            weights=loss_weight)
    tf.debugging.check_numerics(loss, 'Critic loss is inf or nan.')
    safety_critic_grads = tape.gradient(loss, trainable_safety_variables)
    grads_and_vars = list(zip(safety_critic_grads, trainable_safety_variables))
    optimizer.apply_gradients(grads_and_vars)
    return loss
def step(batch_theta, batch_psi):
    """One FactorVAE-style training step: VAE update then discriminator update.

    Args:
        batch_theta: batch used for the VAE (encoder/decoder) objective.
        batch_psi: batch used for the discriminator objective.

    Returns:
        dict: metric results plus the VAE loss under key "loss".
    """
    # --- VAE (theta) update ---
    with tf.GradientTape() as tape:
        z_mean, z_log_var = self.encode(batch_theta)
        z = self.sample(z_mean, z_log_var, training=True)
        p_z = self.discriminator(z)
        x_mean, x_log_var = self.decode(z)
        loss_theta = self.objective(batch_theta, x_mean, x_log_var, z_mean,
                                    z_log_var, p_z)
        tf.debugging.check_numerics(loss_theta, "loss is invalid")
    # Discriminator weights are assigned as not trainable in init
    grad_theta = tape.gradient(loss_theta, self.trainable_variables)
    optimizer.apply_gradients(zip(grad_theta, self.trainable_variables))
    # Updating Discriminator
    with tf.GradientTape() as tape:
        z_mean, z_log_var = self.encode(batch_psi)
        z = self.sample(z_mean, z_log_var, training=True)
        # permute_dims runs in eager via py_function; restore the static shape.
        z_permuted = tf.py_function(self.permute_dims, inp=[z], Tout=tf.float32)
        z_permuted.set_shape(z.shape)
        # BUG FIX: the discriminator logits for the un-permuted latents must
        # be computed under THIS tape. Previously the stale `p_z` from the
        # first tape was reused, so tape.gradient(loss_psi, ...) received no
        # gradient through p_z (ops recorded on another tape are invisible
        # here) and the discriminator only learned from the permuted half.
        p_z = self.discriminator(z)
        p_permuted = self.discriminator(z_permuted)
        loss_psi = discriminator_loss(p_z, p_permuted)
    grad_psi = tape.gradient(loss_psi, self.discriminator_net.variables)
    optimizer_discriminator.apply_gradients(
        zip(grad_psi, self.discriminator_net.variables))
    logs = {m.name: m.result() for m in self.metrics}
    logs["loss"] = loss_theta
    return logs
def _iter(self, time_step, policy_state):
    """One training iteration.

    Unrolls ``_train_loop_body`` for ``_train_interval`` steps under a
    persistent gradient tape, then hands the collected training info and
    the tape to the algorithm for the gradient update.

    Args:
        time_step: current environment time step (batched).
        policy_state: current policy/network state.

    Returns:
        list: ``[next_time_step, next_state]`` to seed the next iteration.
    """
    counter = tf.zeros((), tf.int32)
    batch_size = self._env.batch_size

    def create_ta(s):
        # One TensorArray per spec leaf: [train_interval, batch, *spec.shape].
        return tf.TensorArray(dtype=s.dtype,
                              size=self._train_interval,
                              element_shape=tf.TensorShape(
                                  [batch_size]).concatenate(s.shape))

    # Distributions are stored as their parameter tensors inside the loop
    # and converted back to distributions after stacking.
    training_info_ta = tf.nest.map_structure(
        create_ta,
        self._training_info_spec._replace(
            info=nest_utils.to_distribution_param_spec(
                self._training_info_spec.info)))
    # persistent=True: the tape is consumed later by train_complete().
    with tf.GradientTape(watch_accessed_variables=False,
                         persistent=True) as tape:
        tape.watch(self._trainable_variables)
        [counter, next_time_step, next_state, training_info_ta
         ] = tf.while_loop(cond=lambda *_: True,
                           body=self._train_loop_body,
                           loop_vars=[
                               counter, time_step, policy_state,
                               training_info_ta
                           ],
                           back_prop=True,
                           parallel_iterations=1,
                           maximum_iterations=self._train_interval,
                           name='iter_loop')
    training_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                          training_info_ta)
    training_info = nest_utils.params_to_distributions(
        training_info, self._training_info_spec)
    loss_info, grads_and_vars = self._algorithm.train_complete(
        tape, training_info)
    # Release the persistent tape's resources explicitly.
    del tape
    self._algorithm.summarize_train(training_info, loss_info, grads_and_vars)
    self._algorithm.summarize_metrics()
    common.get_global_counter().assign_add(1)
    return [next_time_step, next_state]
def step(batch):
    """Run a single VAE optimization step on `batch` and return metric logs."""
    with tf.GradientTape() as tape:
        mu, log_var = self.encode(batch)
        latent = self.sample(mu, log_var, training=True)
        recon_mean, recon_log_var = self.decode(latent)
        total_loss = self.objective(batch, recon_mean, recon_log_var, latent,
                                    mu, log_var)
        tf.debugging.check_numerics(total_loss, "Loss is not valid")
    train_vars = self.trainable_variables
    gradients = tape.gradient(total_loss, train_vars)
    optimizer.apply_gradients(zip(gradients, train_vars))
    # Collect current metric values and attach the step loss.
    logs = {metric.name: metric.result() for metric in self.metrics}
    logs["loss"] = total_loss
    return logs
def optimizer_update(iterate_collection, iteration_idx, objective_fn,
                     update_fn, get_params_fn, first_order, clip_grad_norm):
    """Returns the next iterate in the optimization of objective_fn wrt variables.

    Args:
      iterate_collection: A (potentially structured) container of tf.Tensors
        corresponding to the state of the current iterate.
      iteration_idx: An int Tensor; the iteration number.
      objective_fn: Callable that takes in variables and produces the value of
        the objective function.
      update_fn: Callable that takes in the gradient of the objective function
        and the current iterate and produces the next iterate.
      get_params_fn: Callable that takes in an iterate and produces the
        variables (parameters) to differentiate the objective with respect to.
      first_order: If True, prevent the computation of higher order gradients.
      clip_grad_norm: If not None, gradient dimensions are independently
        clipped to lie in the interval [-clip_grad_norm, clip_grad_norm].

    Returns:
      A list with one element per entry of `iterate_collection`: the result of
      applying `update_fn` to that iterate and its (possibly clipped) gradient.
    """
    variables = [get_params_fn(iterate) for iterate in iterate_collection]
    if tf.executing_eagerly():
        # Eager mode: record the objective on a tape. persistent=True so the
        # tape survives being queried after the `with` block.
        with tf.GradientTape(persistent=True) as g:
            g.watch(variables)
            loss = objective_fn(variables, iteration_idx)
        grads = g.gradient(loss, variables)
    else:
        # Graph mode: use symbolic gradients.
        loss = objective_fn(variables, iteration_idx)
        grads = tf.gradients(ys=loss, xs=variables)
    if clip_grad_norm:
        # Element-wise value clipping (not norm clipping, despite the name).
        grads = [
            tf.clip_by_value(grad, -1 * clip_grad_norm, clip_grad_norm)
            for grad in grads
        ]
    if first_order:
        # Cut the gradient graph so no second-order terms are backpropagated.
        grads = [tf.stop_gradient(dv) for dv in grads]
    return [
        update_fn(i=iteration_idx, grad=dv, state=s)
        for (s, dv) in zip(iterate_collection, grads)
    ]
def _iter(self, time_step, policy_state):
    """One training iteration.

    Unrolls the rollout loop for ``_train_interval`` steps under a
    persistent tape, appends one final step (whose handling depends on
    ``_final_step_mode``), and performs the gradient update via
    ``train_complete``.

    Args:
        time_step: current environment time step (batched).
        policy_state: current policy/network state.

    Returns:
        tuple: ``(next_time_step, next_state)`` for the next iteration.
    """
    counter = tf.zeros((), tf.int32)
    batch_size = self._env.batch_size

    def create_ta(s):
        # size is train_interval + 1: the loop fills train_interval entries
        # and the final training info is written at index `counter` below.
        return tf.TensorArray(dtype=s.dtype,
                              size=self._train_interval + 1,
                              element_shape=tf.TensorShape(
                                  [batch_size]).concatenate(s.shape))

    training_info_ta = tf.nest.map_structure(create_ta,
                                             self._training_info_spec)
    # persistent=True: the tape is reentered below and later consumed by
    # train_complete().
    with tf.GradientTape(watch_accessed_variables=False,
                         persistent=True) as tape:
        tape.watch(self._trainable_variables)
        [counter, time_step, policy_state, training_info_ta
         ] = tf.while_loop(cond=lambda *_: True,
                           body=self._train_loop_body,
                           loop_vars=[
                               counter, time_step, policy_state,
                               training_info_ta
                           ],
                           back_prop=True,
                           parallel_iterations=1,
                           maximum_iterations=self._train_interval,
                           name='iter_loop')
    if self._final_step_mode == OnPolicyDriver.FINAL_STEP_SKIP:
        # Take a real environment step; the transition is recorded but the
        # resulting time step is not trained on this iteration.
        next_time_step, policy_step, action = self._step(
            time_step, policy_state)
        next_state = policy_step.state
    else:
        # Evaluate the policy on the current time step without stepping the
        # environment, so the same step is revisited next iteration.
        policy_step = common.algorithm_step(self._algorithm.rollout,
                                            self._observation_transformer,
                                            time_step, policy_state)
        action = common.sample_action_distribution(policy_step.action)
        next_time_step = time_step
        next_state = policy_state
    action_distribution_param = common.get_distribution_params(
        policy_step.action)
    final_training_info = make_training_info(
        action_distribution=action_distribution_param,
        action=action,
        reward=time_step.reward,
        discount=time_step.discount,
        step_type=time_step.step_type,
        info=policy_step.info)
    # Re-enter the tape so the final write/stack are recorded for backprop.
    with tape:
        training_info_ta = tf.nest.map_structure(
            lambda ta, x: ta.write(counter, x), training_info_ta,
            final_training_info)
        training_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                              training_info_ta)
        action_distribution = nested_distributions_from_specs(
            self._algorithm.action_distribution_spec,
            training_info.action_distribution)
        training_info = training_info._replace(
            action_distribution=action_distribution)
    loss_info, grads_and_vars = self._algorithm.train_complete(
        tape, training_info)
    # Release the persistent tape's resources explicitly.
    del tape
    self._training_summary(training_info, loss_info, grads_and_vars)
    self._train_step_counter.assign_add(1)
    return next_time_step, next_state
def train_step(exp,
               safe_rew,
               tf_agent,
               sc_net=None,
               target_sc_net=None,
               global_step=None,
               weights=None,
               target_update=None,
               metrics=None,
               optimizer=None,
               alpha=2.,
               target_safety=None,
               debug_summaries=False):
    """One safety-critic pretraining step.

    Computes the safety-critic loss plus an `alpha`-weighted Lagrangian-style
    penalty on the mean predicted safety vs. the agent's target, applies one
    optimizer step, and updates the target safety critic.

    Args:
        exp: experience batch convertible via ``experience_to_transitions``.
        safe_rew: safety reward labels for the batch.
        tf_agent: agent providing defaults for the networks, optimizer and
            gradient application (accessed via protected attributes).
        sc_net / target_sc_net: safety critic and its target network;
            default to the agent's.
        global_step: passed through to ``safety_critic_loss``.
        weights: optional per-example loss weights.
        target_update: callable updating the target network; defaults to the
            agent's updater.
        metrics / debug_summaries: passed through to ``safety_critic_loss``.
        optimizer: defaults to the agent's safety critic optimizer.
        alpha: weight of the lambda (safety-target) penalty term.
        target_safety: passed to ``safety_critic_loss`` (note: the lam_loss
            below uses ``tf_agent._target_safety``, not this argument).

    Returns:
        tuple: ``(total_loss, sc_loss_raw, lam_loss)``.
    """
    sc_net = sc_net or tf_agent._safety_critic_network
    target_sc_net = target_sc_net or tf_agent._target_safety_critic_network
    target_update = target_update or tf_agent._update_target_safety_critic
    optimizer = optimizer or tf_agent._safety_critic_optimizer
    # Sample actions from the agent's policy, discarding the log-probs.
    get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0]
    time_steps, actions, next_time_steps = experience_to_transitions(exp)
    # update safety critic
    trainable_safety_variables = sc_net.trainable_variables
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        assert trainable_safety_variables, (
            'No trainable safety critic variables'
            ' to optimize.')
        tape.watch(trainable_safety_variables)
        sc_loss = safety_critic_loss(time_steps,
                                     actions,
                                     next_time_steps,
                                     safe_rew,
                                     get_action,
                                     global_step,
                                     critic_network=sc_net,
                                     target_network=target_sc_net,
                                     target_safety=target_safety,
                                     metrics=metrics,
                                     debug_summaries=debug_summaries)
        # Unweighted mean loss, reported alongside the weighted one.
        sc_loss_raw = tf.reduce_mean(sc_loss)
        if weights is not None:
            sc_loss *= weights
        # Take the mean across the batch.
        sc_loss = tf.reduce_mean(sc_loss)
        q_safe = train_utils.eval_safety(sc_net, get_action, time_steps)
        # Penalize mean predicted safety exceeding the agent's target.
        lam_loss = tf.reduce_mean(q_safe - tf_agent._target_safety)
        total_loss = sc_loss + alpha * lam_loss
    tf.debugging.check_numerics(sc_loss, 'Critic loss is inf or nan.')
    safety_critic_grads = tape.gradient(total_loss, trainable_safety_variables)
    tf_agent._apply_gradients(safety_critic_grads, trainable_safety_variables,
                              optimizer)
    # update target safety critic independently of target critic during pretraining
    target_update()
    return total_loss, sc_loss_raw, lam_loss
def grad(model, inputs, targets):
    """Evaluate `loss` for the model on (inputs, targets) and return it
    together with its gradients w.r.t. the model's trainable variables."""
    with tf.GradientTape() as tape:
        current_loss = loss(model, inputs, targets)
    gradients = tape.gradient(current_loss, model.trainable_variables)
    return current_loss, gradients
def train(hparams, num_epoch, tuning):
    """Train the ResNet classifier, evaluate it, and optionally visualize.

    Runs `num_epoch` epochs of training with per-epoch validation, then a
    final pass over the test set, logging accuracies and confusion-matrix
    images to TensorBoard. Checkpointing and the final class-activation-map
    visualization are skipped when `tuning` is True.

    Args:
        hparams: dict with at least 'HP_BS' (batch size) and 'HP_LR'
            (learning rate).
        num_epoch (int): number of training epochs.
        tuning (bool): if True, skip checkpoints and visualization.

    Returns:
        Final test-set accuracy (scalar tensor).
    """
    log_dir = './results/'
    test_batch_size = 8
    # Load dataset
    training_set, valid_set = make_dataset(BATCH_SIZE=hparams['HP_BS'],
                                           file_name='train_tf_record',
                                           split=True)
    test_set = make_dataset(BATCH_SIZE=test_batch_size,
                            file_name='test_tf_record',
                            split=False)
    class_names = ['NRDR', 'RDR']
    # Model
    model = ResNet()
    # set optimizer
    optimizer = tf.keras.optimizers.Adam(learning_rate=hparams['HP_LR'])
    # set metrics
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    valid_accuracy = tf.keras.metrics.Accuracy()
    valid_con_mat = ConfusionMatrix(num_class=2)
    test_accuracy = tf.keras.metrics.Accuracy()
    test_con_mat = ConfusionMatrix(num_class=2)
    # Save Checkpoint
    if not tuning:
        ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                   optimizer=optimizer,
                                   net=model)
        manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=5)
    # Set up summary writers
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tb_log_dir = log_dir + current_time + '/train'
    summary_writer = tf.summary.create_file_writer(tb_log_dir)
    # Restore Checkpoint
    if not tuning:
        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            logging.info('Restored from {}'.format(manager.latest_checkpoint))
        else:
            logging.info('Initializing from scratch.')

    @tf.function
    def train_step(train_img, train_label):
        # Optimize the model
        loss_value, grads = grad(model, train_img, train_label)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        # Re-run the forward pass to record the training accuracy.
        train_pred, _ = model(train_img)
        train_label = tf.expand_dims(train_label, axis=1)
        train_accuracy.update_state(train_label, train_pred)

    for epoch in range(num_epoch):
        begin = time()
        # Training loop
        for train_img, train_label, train_name in training_set:
            train_img = data_augmentation(train_img)
            train_step(train_img, train_label)
        with summary_writer.as_default():
            tf.summary.scalar('Train Accuracy',
                              train_accuracy.result(),
                              step=epoch)
        # Validation loop: images are scaled to [0, 1] (training images are
        # presumably scaled inside data_augmentation — confirm).
        for valid_img, valid_label, _ in valid_set:
            valid_img = tf.cast(valid_img, tf.float32)
            valid_img = valid_img / 255.0
            valid_pred, _ = model(valid_img, training=False)
            valid_pred = tf.cast(tf.argmax(valid_pred, axis=1),
                                 dtype=tf.int64)
            valid_con_mat.update_state(valid_label, valid_pred)
            valid_accuracy.update_state(valid_label, valid_pred)
        # Log the confusion matrix as an image summary
        cm_valid = valid_con_mat.result()
        figure = plot_confusion_matrix(cm_valid, class_names=class_names)
        cm_valid_image = plot_to_image(figure)
        with summary_writer.as_default():
            tf.summary.scalar('Valid Accuracy',
                              valid_accuracy.result(),
                              step=epoch)
            tf.summary.image('Valid ConfusionMatrix',
                             cm_valid_image,
                             step=epoch)
        end = time()
        logging.info(
            "Epoch {:d} Training Accuracy: {:.3%} Validation Accuracy: {:.3%} Time:{:.5}s"
            .format(epoch + 1, train_accuracy.result(),
                    valid_accuracy.result(), (end - begin)))
        # Reset per-epoch metric state.
        train_accuracy.reset_states()
        valid_accuracy.reset_states()
        valid_con_mat.reset_states()
        if not tuning:
            if int(ckpt.step) % 5 == 0:
                save_path = manager.save()
                logging.info('Saved checkpoint for epoch {}: {}'.format(
                    int(ckpt.step), save_path))
            ckpt.step.assign_add(1)

    # Final evaluation on the test set.
    for test_img, test_label, _ in test_set:
        test_img = tf.cast(test_img, tf.float32)
        test_img = test_img / 255.0
        test_pred, _ = model(test_img, training=False)
        test_pred = tf.cast(tf.argmax(test_pred, axis=1), dtype=tf.int64)
        test_accuracy.update_state(test_label, test_pred)
        test_con_mat.update_state(test_label, test_pred)
    cm_test = test_con_mat.result()
    # Log the confusion matrix as an image summary
    figure = plot_confusion_matrix(cm_test, class_names=class_names)
    cm_test_image = plot_to_image(figure)
    with summary_writer.as_default():
        tf.summary.scalar('Test Accuracy', test_accuracy.result(), step=epoch)
        tf.summary.image('Test ConfusionMatrix', cm_test_image, step=epoch)
    logging.info("Trained finished. Final Accuracy in test set: {:.3%}".format(
        test_accuracy.result()))
    # Visualization
    if not tuning:
        # Class-activation-map style visualization for one test image:
        # gradients of the max class score w.r.t. the conv feature map.
        for vis_img, vis_label, vis_name in test_set:
            vis_label = vis_label[0]
            vis_name = vis_name[0]
            vis_img = tf.cast(vis_img[0], tf.float32)
            vis_img = tf.expand_dims(vis_img, axis=0)
            vis_img = vis_img / 255.0
            with tf.GradientTape() as tape:
                vis_pred, conv_output = model(vis_img, training=False)
                pred_label = tf.argmax(vis_pred, axis=-1)
                vis_pred = tf.reduce_max(vis_pred, axis=-1)
            grad_1 = tape.gradient(vis_pred, conv_output)
            # NOTE(review): mean over spatial axes is further divided by
            # grad_1.shape[1] — looks like an extra normalization; confirm
            # it is intentional.
            weight = tf.reduce_mean(grad_1, axis=[1, 2]) / grad_1.shape[1]
            act_map0 = tf.nn.relu(
                tf.reduce_sum(weight * conv_output, axis=-1))
            act_map0 = tf.squeeze(tf.image.resize(tf.expand_dims(act_map0,
                                                                 axis=-1),
                                                  (256, 256),
                                                  antialias=True),
                                  axis=-1)
            plot_map(vis_img, act_map0, vis_pred, pred_label, vis_label,
                     vis_name)
            # Only visualize the first batch.
            break
    return test_accuracy.result()