Example #1
    def _actor_train_step(self, exp: Experience, state: DdpgActorState):
        action, actor_state = self._actor_network(exp.observation,
                                                  exp.step_type,
                                                  network_state=state.actor)

        # Watch only the action: we need dQ/da for the deterministic policy
        # gradient, not gradients w.r.t. the critic's variables.
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(action)
            q_value, critic_state = self._critic_network(
                (exp.observation, action), network_state=state.critic)

        dqda = tape.gradient(q_value, action)

        def actor_loss_fn(dqda, action):
            if self._dqda_clipping:
                dqda = tf.clip_by_value(dqda, -self._dqda_clipping,
                                        self._dqda_clipping)
            loss = 0.5 * losses.element_wise_squared_loss(
                tf.stop_gradient(dqda + action), action)
            loss = tf.reduce_sum(loss, axis=list(range(1, len(loss.shape))))
            return loss

        actor_loss = tf.nest.map_structure(actor_loss_fn, dqda, action)
        state = DdpgActorState(actor=actor_state, critic=critic_state)
        info = LossInfo(loss=tf.add_n(tf.nest.flatten(actor_loss)),
                        extra=actor_loss)
        return PolicyStep(action=action, state=state, info=info)
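
The surrogate loss in actor_loss_fn relies on the identity that, for a fixed dqda, the gradient of 0.5 * (stop_gradient(dqda + action) - action)**2 with respect to action is exactly -dqda, so minimizing the loss performs gradient ascent on the Q-value. A minimal, self-contained check of that identity (the tensor values are illustrative only):

import tensorflow as tf

action = tf.constant([0.3, -0.7])
dqda = tf.constant([1.5, -2.0])  # stand-in for the gradient of Q w.r.t. action

with tf.GradientTape() as tape:
    tape.watch(action)
    target = tf.stop_gradient(dqda + action)
    loss = tf.reduce_sum(0.5 * tf.square(target - action))

# d(loss)/d(action) = -(target - action) = -dqda
print(tape.gradient(loss, action))  # prints [-1.5, 2.0]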
Example #2
    def test_estimated_entropy(self, assume_reparametrization):
        logging.info("assume_reparametrization=%s" % assume_reparametrization)
        num_samples = 1000000
        seed_stream = tfp.distributions.SeedStream(
            seed=1, salt='test_estimated_entropy')
        batch_shape = (2, )
        loc = tf.random.normal(shape=batch_shape, seed=seed_stream())
        scale = tf.abs(tf.random.normal(shape=batch_shape, seed=seed_stream()))

        with tf.GradientTape(persistent=True) as tape:
            tape.watch(scale)
            dist = tfp.distributions.Normal(loc=loc, scale=scale)
            analytic_entropy = dist.entropy()
            est_entropy, est_entropy_for_gradient = dist_utils.estimated_entropy(
                dist=dist,
                seed=seed_stream(),
                assume_reparametrization=assume_reparametrization,
                num_samples=num_samples)

        analytic_grad = tape.gradient(analytic_entropy, scale)
        est_grad = tape.gradient(est_entropy_for_gradient, scale)
        logging.info("scale=%s" % scale)
        logging.info("analytic_entropy=%s" % analytic_entropy)
        logging.info("estimated_entropy=%s" % est_entropy)
        self.assertArrayAlmostEqual(analytic_entropy, est_entropy, 5e-2)

        logging.info("analytic_entropy_grad=%s" % analytic_grad)
        logging.info("estimated_entropy_grad=%s" % est_grad)
        self.assertArrayAlmostEqual(analytic_grad, est_grad, 5e-2)
        if not assume_reparametrization:
            est_grad_wrong = tape.gradient(est_entropy, scale)
            logging.info("estimated_entropy_grad_wrong=%s", est_grad_wrong)
            self.assertLess(tf.reduce_max(tf.abs(est_grad_wrong)), 5e-2)
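
The contract exercised by this test is that dist_utils.estimated_entropy returns both a Monte-Carlo entropy estimate and a surrogate whose gradient matches the analytic entropy gradient; without reparametrization the surrogate must add a score-function term, since the naive estimate carries almost no gradient. A minimal sketch of the reparameterized case only, as an illustration of the idea rather than the dist_utils implementation:

import tensorflow as tf
import tensorflow_probability as tfp

def mc_entropy(dist, num_samples, seed=None):
    # With a reparameterized distribution, gradients flow through the
    # samples, so -mean(log_prob(x)) is differentiable in the parameters.
    x = dist.sample(num_samples, seed=seed)
    return -tf.reduce_mean(dist.log_prob(x), axis=0)

scale = tf.Variable(0.5)
with tf.GradientTape() as tape:
    dist = tfp.distributions.Normal(loc=0., scale=scale)
    est = mc_entropy(dist, num_samples=1000000, seed=0)
# The analytic entropy is log(scale) + 0.5 * log(2 * pi * e), so the true
# gradient w.r.t. scale is 1 / scale = 2.
print(tape.gradient(est, scale))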
Example #3
def train_step(tf_agent, safety_critic, batch, safety_rewards, optimizer):
  """Helper function for creating a train step."""
  rb_data, buf_info = batch
  safe_rew = tf.gather(safety_rewards, buf_info.ids, axis=1)

  time_steps, actions, next_time_steps = tf_agent._experience_to_transitions(  # pylint: disable=protected-access
      rb_data)
  time_steps = time_steps._replace(reward=safe_rew[:, :-1])  # pylint: disable=protected-access
  next_time_steps = next_time_steps._replace(reward=safe_rew[:, 1:])
  fail_pct = safety_rewards.sum() / safety_rewards.shape[1]
  loss_weight = 0.5 / ((next_time_steps.reward) * fail_pct +
                       (1 - next_time_steps.reward) * (1 - fail_pct))
  trainable_safety_variables = safety_critic.trainable_variables
  with tf.GradientTape(watch_accessed_variables=False) as tape:
    assert trainable_safety_variables, ('No trainable safety critic variables'
                                        ' to optimize.')
    tape.watch(trainable_safety_variables)
    loss = safety_critic_loss(
        tf_agent,
        safety_critic,
        time_steps,
        actions,
        next_time_steps,
        safety_rewards=next_time_steps.reward,
        weights=loss_weight)

    tf.debugging.check_numerics(loss, 'Critic loss is inf or nan.')

  # Compute and apply gradients outside the tape context; the tape only
  # needs to record the forward pass that produced the loss.
  safety_critic_grads = tape.gradient(loss, trainable_safety_variables)
  grads_and_vars = list(zip(safety_critic_grads, trainable_safety_variables))
  optimizer.apply_gradients(grads_and_vars)
  return loss
Example #4
        def step(batch_theta, batch_psi):
            with tf.GradientTape() as tape:
                z_mean, z_log_var = self.encode(batch_theta)
                z = self.sample(z_mean, z_log_var, training=True)

                p_z = self.discriminator(z)

                x_mean, x_log_var = self.decode(z)

                loss_theta = self.objective(batch_theta, x_mean, x_log_var,
                                            z_mean, z_log_var, p_z)
                tf.debugging.check_numerics(loss_theta, "loss is invalid")

            # Discriminator weights are assigned as not trainable in init
            grad_theta = tape.gradient(loss_theta, self.trainable_variables)
            optimizer.apply_gradients(zip(grad_theta,
                                          self.trainable_variables))

            # Updating Discriminator
            with tf.GradientTape() as tape:
                z_mean, z_log_var = self.encode(batch_psi)
                z = self.sample(z_mean, z_log_var, training=True)

                z_permuted = tf.py_function(self.permute_dims,
                                            inp=[z],
                                            Tout=tf.float32)
                z_permuted.set_shape(z.shape)

                # Recompute p_z under this tape so discriminator gradients
                # flow through the unpermuted branch as well.
                p_z = self.discriminator(z)
                p_permuted = self.discriminator(z_permuted)

                loss_psi = discriminator_loss(p_z, p_permuted)

            grad_psi = tape.gradient(loss_psi,
                                     self.discriminator_net.variables)
            optimizer_discriminator.apply_gradients(
                zip(grad_psi, self.discriminator_net.variables))

            logs = {m.name: m.result() for m in self.metrics}
            logs["loss"] = loss_theta

            return logs
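
loss_psi above is a FactorVAE-style discriminator objective: the discriminator learns to tell genuine latent codes from codes whose dimensions were shuffled independently across the batch. The example's self.permute_dims is not shown; a typical implementation (a hypothetical sketch) looks like this:

import tensorflow as tf

def permute_dims(z):
    # Shuffle each latent dimension independently across the batch. This
    # preserves the marginals while destroying dependencies between
    # dimensions, giving samples from the product of the marginals.
    columns = [tf.random.shuffle(z[:, i]) for i in range(z.shape[1])]
    return tf.stack(columns, axis=1)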
Example #5
    def _iter(self, time_step, policy_state):
        """One training iteration."""
        counter = tf.zeros((), tf.int32)
        batch_size = self._env.batch_size

        def create_ta(s):
            return tf.TensorArray(dtype=s.dtype,
                                  size=self._train_interval,
                                  element_shape=tf.TensorShape(
                                      [batch_size]).concatenate(s.shape))

        training_info_ta = tf.nest.map_structure(
            create_ta,
            self._training_info_spec._replace(
                info=nest_utils.to_distribution_param_spec(
                    self._training_info_spec.info)))

        # The tape is persistent because train_complete() uses it again to
        # compute gradients after the rollout loop has finished.
        with tf.GradientTape(watch_accessed_variables=False,
                             persistent=True) as tape:
            tape.watch(self._trainable_variables)
            [counter, next_time_step, next_state, training_info_ta
             ] = tf.while_loop(cond=lambda *_: True,
                               body=self._train_loop_body,
                               loop_vars=[
                                   counter, time_step, policy_state,
                                   training_info_ta
                               ],
                               back_prop=True,
                               parallel_iterations=1,
                               maximum_iterations=self._train_interval,
                               name='iter_loop')

            training_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                                  training_info_ta)

            training_info = nest_utils.params_to_distributions(
                training_info, self._training_info_spec)

        loss_info, grads_and_vars = self._algorithm.train_complete(
            tape, training_info)

        del tape

        self._algorithm.summarize_train(training_info, loss_info,
                                        grads_and_vars)
        self._algorithm.summarize_metrics()

        common.get_global_counter().assign_add(1)

        return [next_time_step, next_state]
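
The TensorArray machinery above accumulates per-step training info inside the while loop so it can be stacked into [train_interval, batch, ...] tensors once the rollout finishes. The core write-then-stack pattern, in isolation:

import tensorflow as tf

ta = tf.TensorArray(dtype=tf.float32, size=3, element_shape=tf.TensorShape([2]))

def body(i, ta):
    # Each iteration writes one [batch]-shaped slice at index i.
    return i + 1, ta.write(i, tf.fill([2], tf.cast(i, tf.float32)))

_, ta = tf.while_loop(lambda i, _: i < 3, body, [0, ta])
print(ta.stack())  # shape [3, 2]: one row per loop iteration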
Example #6
        def step(batch):
            with tf.GradientTape() as tape:
                z_mean, z_log_var = self.encode(batch)
                z = self.sample(z_mean, z_log_var, training=True)

                x_mean, x_log_var = self.decode(z)

                loss = self.objective(batch, x_mean, x_log_var, z, z_mean,
                                      z_log_var)

            tf.debugging.check_numerics(loss, "Loss is not valid")

            grad = tape.gradient(loss, self.trainable_variables)
            optimizer.apply_gradients(zip(grad, self.trainable_variables))
            logs = {m.name: m.result() for m in self.metrics}
            logs["loss"] = loss

            return logs
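
The heavy lifting happens in self.objective, which is not shown. A minimal sketch of a negative-ELBO objective with a Gaussian decoder that such a method might compute (an assumption for illustration, not the example's actual loss):

import tensorflow as tf

def negative_elbo(x, x_mean, x_log_var, z_mean, z_log_var):
    # Gaussian reconstruction log-likelihood, summed over data dimensions
    # (up to an additive constant).
    recon = -0.5 * tf.reduce_sum(
        x_log_var + tf.square(x - x_mean) / tf.exp(x_log_var), axis=-1)
    # Analytic KL(q(z|x) || N(0, I)), summed over latent dimensions.
    kl = -0.5 * tf.reduce_sum(
        1. + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1)
    # Minimizing KL - reconstruction maximizes the ELBO.
    return tf.reduce_mean(kl - recon)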
Example #7
def optimizer_update(iterate_collection, iteration_idx, objective_fn,
                     update_fn, get_params_fn, first_order, clip_grad_norm):
    """Returns the next iterate in the optimization of objective_fn wrt variables.

  Args:
    iterate_collection: A (potentially structured) container of tf.Tensors
      corresponding to the state of the current iterate.
    iteration_idx: An int Tensor; the iteration number.
    objective_fn: Callable that takes in variables and produces the value of the
      objective function.
    update_fn: Callable that takes in the gradient of the objective function and
      the current iterate and produces the next iterate.
    get_params_fn: Callable that takes in an iterate and returns the
      parameters of that iterate that the objective is differentiated with
      respect to.
    first_order: If True, prevent the computation of higher order gradients.
    clip_grad_norm: If not None, gradient dimensions are independently clipped
      to lie in the interval [-clip_grad_norm, clip_grad_norm].

  Returns:
    A list containing the next iterate for each element of iterate_collection.
  """
    variables = [get_params_fn(iterate) for iterate in iterate_collection]

    if tf.executing_eagerly():
        with tf.GradientTape(persistent=True) as g:
            g.watch(variables)
            loss = objective_fn(variables, iteration_idx)
        grads = g.gradient(loss, variables)
    else:
        loss = objective_fn(variables, iteration_idx)
        grads = tf.gradients(ys=loss, xs=variables)

    if clip_grad_norm:
        grads = [
            tf.clip_by_value(grad, -1 * clip_grad_norm, clip_grad_norm)
            for grad in grads
        ]

    if first_order:
        grads = [tf.stop_gradient(dv) for dv in grads]

    return [
        update_fn(i=iteration_idx, grad=dv, state=s)
        for (s, dv) in zip(iterate_collection, grads)
    ]
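
A minimal sketch of how optimizer_update might be invoked with plain SGD, where each iterate is just a parameter tensor (objective_fn, sgd_update, and the identity get_params_fn here are hypothetical stand-ins):

import tensorflow as tf

def objective_fn(variables, iteration_idx):
    del iteration_idx  # unused by this toy objective
    return tf.add_n([tf.reduce_sum(tf.square(v)) for v in variables])

def sgd_update(i, grad, state):
    del i
    return state - 0.1 * grad  # the iterate carries no state beyond the params

iterates = [tf.constant([1.0, -2.0]), tf.constant([0.5])]
iterates = optimizer_update(
    iterate_collection=iterates,
    iteration_idx=tf.constant(0),
    objective_fn=objective_fn,
    update_fn=sgd_update,
    get_params_fn=lambda iterate: iterate,
    first_order=True,
    clip_grad_norm=None)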
Example #8
    def _iter(self, time_step, policy_state):
        """One training iteration."""
        counter = tf.zeros((), tf.int32)
        batch_size = self._env.batch_size

        def create_ta(s):
            return tf.TensorArray(dtype=s.dtype,
                                  size=self._train_interval + 1,
                                  element_shape=tf.TensorShape(
                                      [batch_size]).concatenate(s.shape))

        training_info_ta = tf.nest.map_structure(create_ta,
                                                 self._training_info_spec)

        with tf.GradientTape(watch_accessed_variables=False,
                             persistent=True) as tape:
            tape.watch(self._trainable_variables)
            [counter, time_step, policy_state, training_info_ta
             ] = tf.while_loop(cond=lambda *_: True,
                               body=self._train_loop_body,
                               loop_vars=[
                                   counter, time_step, policy_state,
                                   training_info_ta
                               ],
                               back_prop=True,
                               parallel_iterations=1,
                               maximum_iterations=self._train_interval,
                               name='iter_loop')

        if self._final_step_mode == OnPolicyDriver.FINAL_STEP_SKIP:
            next_time_step, policy_step, action = self._step(
                time_step, policy_state)
            next_state = policy_step.state
        else:
            policy_step = common.algorithm_step(self._algorithm.rollout,
                                                self._observation_transformer,
                                                time_step, policy_state)
            action = common.sample_action_distribution(policy_step.action)
            next_time_step = time_step
            next_state = policy_state

        action_distribution_param = common.get_distribution_params(
            policy_step.action)

        final_training_info = make_training_info(
            action_distribution=action_distribution_param,
            action=action,
            reward=time_step.reward,
            discount=time_step.discount,
            step_type=time_step.step_type,
            info=policy_step.info)

        with tape:
            training_info_ta = tf.nest.map_structure(
                lambda ta, x: ta.write(counter, x), training_info_ta,
                final_training_info)
            training_info = tf.nest.map_structure(lambda ta: ta.stack(),
                                                  training_info_ta)

            action_distribution = nested_distributions_from_specs(
                self._algorithm.action_distribution_spec,
                training_info.action_distribution)

            training_info = training_info._replace(
                action_distribution=action_distribution)

        loss_info, grads_and_vars = self._algorithm.train_complete(
            tape, training_info)

        del tape

        self._training_summary(training_info, loss_info, grads_and_vars)

        self._train_step_counter.assign_add(1)

        return next_time_step, next_state
Example #9
def train_step(exp,
               safe_rew,
               tf_agent,
               sc_net=None,
               target_sc_net=None,
               global_step=None,
               weights=None,
               target_update=None,
               metrics=None,
               optimizer=None,
               alpha=2.,
               target_safety=None,
               debug_summaries=False):
    sc_net = sc_net or tf_agent._safety_critic_network
    target_sc_net = target_sc_net or tf_agent._target_safety_critic_network
    target_update = target_update or tf_agent._update_target_safety_critic
    optimizer = optimizer or tf_agent._safety_critic_optimizer
    get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0]

    time_steps, actions, next_time_steps = experience_to_transitions(exp)

    # update safety critic
    trainable_safety_variables = sc_net.trainable_variables
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        assert trainable_safety_variables, (
            'No trainable safety critic variables'
            ' to optimize.')
        tape.watch(trainable_safety_variables)
        sc_loss = safety_critic_loss(time_steps,
                                     actions,
                                     next_time_steps,
                                     safe_rew,
                                     get_action,
                                     global_step,
                                     critic_network=sc_net,
                                     target_network=target_sc_net,
                                     target_safety=target_safety,
                                     metrics=metrics,
                                     debug_summaries=debug_summaries)

        sc_loss_raw = tf.reduce_mean(sc_loss)

        if weights is not None:
            sc_loss *= weights

        # Take the mean across the batch.
        sc_loss = tf.reduce_mean(sc_loss)

        q_safe = train_utils.eval_safety(sc_net, get_action, time_steps)
        lam_loss = tf.reduce_mean(q_safe - tf_agent._target_safety)
        total_loss = sc_loss + alpha * lam_loss

        tf.debugging.check_numerics(sc_loss, 'Critic loss is inf or nan.')

    # Compute and apply gradients outside the tape context; only the forward
    # pass needs to be recorded.
    safety_critic_grads = tape.gradient(total_loss,
                                        trainable_safety_variables)
    tf_agent._apply_gradients(safety_critic_grads,
                              trainable_safety_variables, optimizer)

    # update target safety critic independently of target critic during pretraining
    target_update()

    return total_loss, sc_loss_raw, lam_loss
Example #10
def grad(model, inputs, targets):
    with tf.GradientTape() as tape:
        loss_value = loss(model, inputs, targets)
    return loss_value, tape.gradient(loss_value, model.trainable_variables)
def train(hparams, num_epoch, tuning):

    log_dir = './results/'
    test_batch_size = 8
    # Load dataset
    training_set, valid_set = make_dataset(BATCH_SIZE=hparams['HP_BS'],
                                           file_name='train_tf_record',
                                           split=True)
    test_set = make_dataset(BATCH_SIZE=test_batch_size,
                            file_name='test_tf_record',
                            split=False)
    class_names = ['NRDR', 'RDR']

    # Model
    model = ResNet()

    # set optimizer
    optimizer = tf.keras.optimizers.Adam(learning_rate=hparams['HP_LR'])
    # set metrics
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    valid_accuracy = tf.keras.metrics.Accuracy()
    valid_con_mat = ConfusionMatrix(num_class=2)
    test_accuracy = tf.keras.metrics.Accuracy()
    test_con_mat = ConfusionMatrix(num_class=2)

    # Save Checkpoint
    if not tuning:
        ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                   optimizer=optimizer,
                                   net=model)
        manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=5)

    # Set up summary writers
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tb_log_dir = log_dir + current_time + '/train'
    summary_writer = tf.summary.create_file_writer(tb_log_dir)

    # Restore Checkpoint
    if not tuning:
        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            logging.info('Restored from {}'.format(manager.latest_checkpoint))
        else:
            logging.info('Initializing from scratch.')

    @tf.function
    def train_step(train_img, train_label):
        # Optimize the model
        loss_value, grads = grad(model, train_img, train_label)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_pred, _ = model(train_img)
        train_label = tf.expand_dims(train_label, axis=1)
        train_accuracy.update_state(train_label, train_pred)

    for epoch in range(num_epoch):

        begin = time()

        # Training loop
        for train_img, train_label, train_name in training_set:
            train_img = data_augmentation(train_img)
            train_step(train_img, train_label)

        with summary_writer.as_default():
            tf.summary.scalar('Train Accuracy',
                              train_accuracy.result(),
                              step=epoch)

        for valid_img, valid_label, _ in valid_set:
            valid_img = tf.cast(valid_img, tf.float32)
            valid_img = valid_img / 255.0
            valid_pred, _ = model(valid_img, training=False)
            valid_pred = tf.cast(tf.argmax(valid_pred, axis=1), dtype=tf.int64)
            valid_con_mat.update_state(valid_label, valid_pred)
            valid_accuracy.update_state(valid_label, valid_pred)

        # Log the confusion matrix as an image summary
        cm_valid = valid_con_mat.result()
        figure = plot_confusion_matrix(cm_valid, class_names=class_names)
        cm_valid_image = plot_to_image(figure)

        with summary_writer.as_default():
            tf.summary.scalar('Valid Accuracy',
                              valid_accuracy.result(),
                              step=epoch)
            tf.summary.image('Valid ConfusionMatrix',
                             cm_valid_image,
                             step=epoch)

        end = time()
        logging.info(
            "Epoch {:d} Training Accuracy: {:.3%} Validation Accuracy: {:.3%} Time:{:.5}s"
            .format(epoch + 1, train_accuracy.result(),
                    valid_accuracy.result(), (end - begin)))
        train_accuracy.reset_states()
        valid_accuracy.reset_states()
        valid_con_mat.reset_states()
        if not tuning:
            if int(ckpt.step) % 5 == 0:
                save_path = manager.save()
                logging.info('Saved checkpoint for epoch {}: {}'.format(
                    int(ckpt.step), save_path))
            ckpt.step.assign_add(1)

    for test_img, test_label, _ in test_set:
        test_img = tf.cast(test_img, tf.float32)
        test_img = test_img / 255.0
        test_pred, _ = model(test_img, training=False)
        test_pred = tf.cast(tf.argmax(test_pred, axis=1), dtype=tf.int64)
        test_accuracy.update_state(test_label, test_pred)
        test_con_mat.update_state(test_label, test_pred)

    cm_test = test_con_mat.result()
    # Log the confusion matrix as an image summary
    figure = plot_confusion_matrix(cm_test, class_names=class_names)
    cm_test_image = plot_to_image(figure)
    with summary_writer.as_default():
        tf.summary.scalar('Test Accuracy', test_accuracy.result(), step=epoch)
        tf.summary.image('Test ConfusionMatrix', cm_test_image, step=epoch)

    logging.info("Trained finished. Final Accuracy in test set: {:.3%}".format(
        test_accuracy.result()))

    # Visualization
    if not tuning:
        for vis_img, vis_label, vis_name in test_set:
            vis_label = vis_label[0]
            vis_name = vis_name[0]
            vis_img = tf.cast(vis_img[0], tf.float32)
            vis_img = tf.expand_dims(vis_img, axis=0)
            vis_img = vis_img / 255.0
            with tf.GradientTape() as tape:
                vis_pred, conv_output = model(vis_img, training=False)
                pred_label = tf.argmax(vis_pred, axis=-1)
                vis_pred = tf.reduce_max(vis_pred, axis=-1)
            # Grad-CAM style map: weight the conv features by the gradient of
            # the top class score, computed outside the tape context.
            grad_1 = tape.gradient(vis_pred, conv_output)
            weight = tf.reduce_mean(grad_1, axis=[1, 2]) / grad_1.shape[1]
            act_map0 = tf.nn.relu(tf.reduce_sum(weight * conv_output, axis=-1))
            act_map0 = tf.squeeze(tf.image.resize(tf.expand_dims(act_map0,
                                                                 axis=-1),
                                                  (256, 256),
                                                  antialias=True),
                                  axis=-1)
            plot_map(vis_img, act_map0, vis_pred, pred_label, vis_label,
                     vis_name)
            break

    return test_accuracy.result()
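
A sketch of how train might be called, assuming hparams carries the two keys read above; tuning=False additionally enables checkpointing and the final activation-map visualization:

hparams = {'HP_BS': 16, 'HP_LR': 1e-4}
final_test_accuracy = train(hparams, num_epoch=50, tuning=False)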