Ejemplo n.º 1
0
    def _train_policy(self, dataset):
        """
        Fit the model-based policy on transitions from ``dataset``.

        Performs ``self._training_epochs`` passes over the dataset, drawing
        minibatches of size ``self._training_batch_size`` in random order via
        ``dataset.random_iterator``, and logs the first and last loss values.
        """
        timeit.start('train policy')

        # One train_step per minibatch; keep every loss so the first/last
        # entries can be reported below.
        losses = [
            self._policy.train_step(states, actions, next_states)
            for _ in range(self._training_epochs)
            for states, actions, next_states, _, _ in dataset.random_iterator(
                self._training_batch_size)
        ]

        logger.record_tabular('TrainingLossStart', losses[0])
        logger.record_tabular('TrainingLossFinal', losses[-1])

        timeit.stop('train policy')

        logger.dump_tabular(print_func=logger.info)
Ejemplo n.º 2
0
 def _log(self, dataset):
     """Flush logging and timing output for one iteration, then restart the
     'total' timer for the next one."""
     timeit.stop('total')
     dataset.log()
     logger.dump_tabular(print_func=logger.info)
     logger.debug('')
     # Emit the timing breakdown line by line at debug level.
     timing_report = str(timeit)
     for report_line in timing_report.split('\n'):
         logger.debug(report_line)
     timeit.reset()
     timeit.start('total')
Ejemplo n.º 3
0
    def log(self, write_table_header=False):
        """Record the current episode's training/evaluation statistics.

        Writes all tabular values to the logger, appends timing information
        from gtimer, dumps the table, and saves the training state.

        :param write_table_header: whether to re-emit the table header when
            dumping (passed through to ``logger.dump_tabular``).
        """
        logger.log("Logging data in directory: %s" % logger.get_snapshot_dir())

        # Scalar training statistics, recorded in a fixed column order.
        for key, value in [
                ("Episode", self.num_episodes),
                ("Accumulated Training Steps", self.num_train_interactions),
                ("Policy Error", self.logging_policies_error),
                ("Q-Value Error", self.logging_qvalues_error),
                ("V-Value Error", self.logging_vvalues_error),
                ("Alpha", np_ify(self.log_alpha.exp()).item()),
                ("Entropy", np_ify(self.logging_entropy.mean(dim=(0, )))),
        ]:
            logger.record_tabular(key, value)

        # Per-dimension action mean/std, averaged over dim 0.
        act_mean = np_ify(self.logging_mean.mean(dim=(0, )))
        act_std = np_ify(self.logging_std.mean(dim=(0, )))
        for act_idx in range(self.action_dim):
            logger.record_tabular("Mean Action %02d" % act_idx,
                                  act_mean[act_idx])
            logger.record_tabular("Std Action %02d" % act_idx,
                                  act_std[act_idx])

        # Evaluation Stats to plot
        for key, value in [
                ("Test Rewards Mean",
                 np_ify(self.logging_eval_rewards.mean())),
                ("Test Rewards Std", np_ify(self.logging_eval_rewards.std())),
                ("Test Returns Mean",
                 np_ify(self.logging_eval_returns.mean())),
                ("Test Returns Std", np_ify(self.logging_eval_returns.std())),
        ]:
            logger.record_tabular(key, value)

        # Add the previous times to the logger (last gtimer stamp per phase).
        times_itrs = gt.get_times().stamps.itrs
        train_time = times_itrs.get('train', [0])[-1]
        sample_time = times_itrs.get('sample', [0])[-1]
        eval_time = times_itrs.get('eval', [0])[-1]
        logger.record_tabular('Train Time (s)', train_time)
        logger.record_tabular('(Previous) Eval Time (s)', eval_time)
        logger.record_tabular('Sample Time (s)', sample_time)
        logger.record_tabular('Epoch Time (s)',
                              train_time + sample_time + eval_time)
        logger.record_tabular('Total Train Time (s)', gt.get_times().total)

        # Dump the logger data
        logger.dump_tabular(with_prefix=False,
                            with_timestamp=False,
                            write_header=write_table_header)
        # Save pytorch models
        self.save_training_state()
        logger.log("----")
Ejemplo n.º 4
0
    def _log(self, dataset):
        """Dump dataset statistics and timing info, then reset the timer."""
        # Pause the running 'total' timer while reporting.
        timeit.stop('total')

        # Emit dataset statistics through the tabular logger.
        dataset.log()
        logger.dump_tabular(print_func=logger.info)

        # Write the timing breakdown line by line at debug level.
        logger.debug('')
        for timing_line in str(timeit).split('\n'):
            logger.debug(timing_line)

        # Restart timing for the next iteration.
        timeit.reset()
        timeit.start('total')
Ejemplo n.º 5
0
    # Load buffer
    # Populate an offline replay buffer from the saved dataset on disk.
    # NOTE(review): bootstrap_dim=4 presumably matches the ensemble size used
    # when the buffer was generated — confirm against data-collection code.
    replay_buffer = utils.ReplayBuffer()
    if args.env_name == 'Multigoal-v0':
        # The multigoal/point-mass buffer has a dedicated loader that also
        # applies a distance-based cost coefficient.
        replay_buffer.load_point_mass(buffer_name,
                                      bootstrap_dim=4,
                                      dist_cost_coeff=0.01)
    else:
        replay_buffer.load(buffer_name, bootstrap_dim=4)

    # Evaluation returns collected after each training chunk (saved to disk).
    evaluations = []

    # NOTE(review): episode_num/done appear unused in this fragment — they may
    # be used later in the (unseen) rest of the function; verify before
    # removing.
    episode_num = 0
    done = True

    # Train in chunks of eval_freq iterations, evaluating after each chunk.
    training_iters = 0
    while training_iters < args.max_timesteps:
        # NOTE(review): pol_vals is never used below — confirm whether the
        # return value is needed or the assignment can be dropped.
        pol_vals = policy.train(replay_buffer, iterations=int(args.eval_freq))

        # Evaluate the current policy and persist the running result list.
        ret_eval, var_ret, median_ret = evaluate_policy(policy)
        evaluations.append(ret_eval)
        np.save("./results/" + file_name, evaluations)

        training_iters += args.eval_freq
        print("Training iterations: " + str(training_iters))
        logger.record_tabular('Training Epochs',
                              int(training_iters // int(args.eval_freq)))
        logger.record_tabular('AverageReturn', ret_eval)
        logger.record_tabular('VarianceReturn', var_ret)
        logger.record_tabular('MedianReturn', median_ret)
        logger.dump_tabular()
Ejemplo n.º 6
0
def main(args):
    """Train a joint energy-based model / classifier.

    Depending on the loss weights in ``args``, the objective combines:
      * ``pyxce``       -- label-smoothed cross entropy, log p(y|x)
      * ``pxsgld``      -- SGLD-based loss for log p(x)
      * ``pxcontrast``  -- contrastive loss for log p(x)
      * ``pxysgld``     -- SGLD-based loss for log p(x|y)
      * ``pxycontrast`` -- contrastive loss for log p(x|y)

    Periodically plots samples, evaluates on the validation/test sets, and
    writes checkpoints.  Raises RuntimeError if the SGLD loss diverges.
    """
    # Setup datasets
    dload_train, dload_train_labeled, dload_valid, dload_test = get_data(args)

    # Model and buffer
    sample_q = get_sample_q(args)
    f, replay_buffer = get_model_and_buffer(args, sample_q)

    # Setup Optimizer (only the classifier head when clf_only is set)
    params = f.class_output.parameters() if args.clf_only else f.parameters()
    if args.optimizer == "adam":
        optim = torch.optim.Adam(params,
                                 lr=args.lr,
                                 betas=[0.9, 0.999],
                                 weight_decay=args.weight_decay)
    else:
        optim = torch.optim.SGD(params,
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=args.weight_decay)

    best_valid_acc = 0.0  # tracked as a plain float so it is always loggable
    cur_iter = 0
    for epoch in range(args.start_epoch, args.n_epochs):

        # Step-decay the learning rate at the configured epochs.
        if epoch in args.decay_epochs:
            for param_group in optim.param_groups:
                new_lr = param_group["lr"] * args.decay_rate
                param_group["lr"] = new_lr

        # Load data
        for i, (x_p_d, _) in tqdm(enumerate(dload_train)):
            # Linear learning-rate warmup over the first warmup_iters steps.
            if cur_iter <= args.warmup_iters:
                lr = args.lr * cur_iter / float(args.warmup_iters)
                for param_group in optim.param_groups:
                    param_group["lr"] = lr

            x_p_d = x_p_d.to(device)
            x_lab, y_lab = dload_train_labeled.__next__()
            x_lab, y_lab = x_lab.to(device), y_lab.to(device)

            # Label smoothing
            dist = smooth_one_hot(y_lab, args.n_classes, args.smoothing)

            L = 0.0

            # log p(y|x) cross entropy loss
            if args.pyxce > 0:
                logits = f.classify(x_lab)
                l_pyxce = KHotCrossEntropyLoss()(logits, dist)
                if cur_iter % args.print_every == 0:
                    acc = (logits.max(1)[1] == y_lab).float().mean()
                    print("p(y|x)CE {}:{:>d} loss={:>14.9f}, acc={:>14.9f}".
                          format(epoch, cur_iter, l_pyxce.item(), acc.item()))
                    logger.record_dict({
                        "l_pyxce": l_pyxce.cpu().data.item(),
                        "acc_pyxce": acc.item()
                    })
                L += args.pyxce * l_pyxce

            # log p(x) using sgld
            if args.pxsgld > 0:
                if args.class_cond_p_x_sample:
                    assert not args.uncond, "can only draw class-conditional samples if EBM is class-cond"
                    y_q = torch.randint(0, args.n_classes,
                                        (args.sgld_batch_size, )).to(device)
                    x_q = sample_q(f, replay_buffer, y=y_q)
                else:
                    x_q = sample_q(f, replay_buffer)  # sample from log-sumexp
                fp_all = f(x_p_d)
                fq_all = f(x_q)
                fp = fp_all.mean()
                fq = fq_all.mean()
                l_pxsgld = -(fp - fq)
                if cur_iter % args.print_every == 0:
                    print(
                        "p(x)SGLD | {}:{:>d} loss={:>14.9f} f(x_p_d)={:>14.9f} f(x_q)={:>14.9f}"
                        .format(epoch, i, l_pxsgld, fp, fq))
                    logger.record_dict(
                        {"l_pxsgld": l_pxsgld.cpu().data.item()})
                L += args.pxsgld * l_pxsgld

            # log p(x) using contrastive learning
            if args.pxcontrast > 0:
                # ones like dist to use all indexes
                ones_dist = torch.ones_like(dist).to(device)
                output, target, ce_output, neg_num = f.joint(img=x_lab,
                                                             dist=ones_dist)
                l_pxcontrast = nn.CrossEntropyLoss(reduction="mean")(output,
                                                                     target)
                if cur_iter % args.print_every == 0:
                    acc = (ce_output.max(1)[1] == y_lab).float().mean()
                    print(
                        "p(x)Contrast {}:{:>d} loss={:>14.9f}, acc={:>14.9f}".
                        format(epoch, cur_iter, l_pxcontrast.item(),
                               acc.item()))
                    logger.record_dict({
                        "l_pxcontrast": l_pxcontrast.cpu().data.item(),
                        "acc_pxcontrast": acc.item()
                    })
                # BUG FIX: this branch previously scaled by args.pxycontrast
                # (copy-paste from the p(x|y) branch); use its own weight.
                L += args.pxcontrast * l_pxcontrast

            # log p(x|y) using sgld
            if args.pxysgld > 0:
                x_q_lab = sample_q(f, replay_buffer, y=y_lab)
                fp, fq = f(x_lab).mean(), f(x_q_lab).mean()
                l_pxysgld = -(fp - fq)
                if cur_iter % args.print_every == 0:
                    print(
                        "p(x|y)SGLD | {}:{:>d} loss={:>14.9f} f(x_p_d)={:>14.9f} f(x_q)={:>14.9f}"
                        .format(epoch, i, l_pxysgld.item(), fp, fq))
                    logger.record_dict(
                        {"l_pxysgld": l_pxysgld.cpu().data.item()})
                # BUG FIX: this branch previously scaled by args.pxsgld
                # (copy-paste from the p(x) branch); use its own weight.
                L += args.pxysgld * l_pxysgld

            # log p(x|y) using contrastive learning
            if args.pxycontrast > 0:
                output, target, ce_output, neg_num = f.joint(img=x_lab,
                                                             dist=dist)
                l_pxycontrast = nn.CrossEntropyLoss(reduction="mean")(output,
                                                                      target)
                if cur_iter % args.print_every == 0:
                    acc = (ce_output.max(1)[1] == y_lab).float().mean()
                    print(
                        "p(x|y)Contrast {}:{:>d} loss={:>14.9f}, acc={:>14.9f}"
                        .format(epoch, cur_iter, l_pxycontrast.item(),
                                acc.item()))
                    logger.record_dict({
                        "l_pxycontrast": l_pxycontrast.cpu().data.item(),
                        "acc_pxycontrast": acc.item()
                    })
                L += args.pxycontrast * l_pxycontrast

            # SGLD training of log q(x) may diverge.  Print restart info both
            # to the redirected log streams and to the real stdout/stderr,
            # then abort.  (raise instead of assert so the check survives -O.)
            if L.abs().item() > 1e8:
                restart_info = [
                    "restart epoch: {}".format(epoch),
                    "save dir: {}".format(args.log_dir),
                    "id: {}".format(args.id),
                    "steps: {}".format(args.n_steps),
                    "seed: {}".format(args.seed),
                    "exp prefix: {}".format(args.exp_prefix),
                ]
                for info_line in restart_info:
                    print(info_line)
                sys.stdout = sys.__stdout__
                sys.stderr = sys.__stderr__
                for info_line in restart_info:
                    print(info_line)
                raise RuntimeError("loss exploded (|L| > 1e8)")

            optim.zero_grad()
            L.backward()
            optim.step()
            cur_iter += 1

        # Periodically plot unconditional and/or class-conditional samples.
        if epoch % args.plot_every == 0:
            if args.plot_uncond:
                if args.class_cond_p_x_sample:
                    assert not args.uncond, "can only draw class-conditional samples if EBM is class-cond"
                    y_q = torch.randint(0, args.n_classes,
                                        (args.sgld_batch_size, )).to(device)
                    x_q = sample_q(f, replay_buffer, y=y_q)
                    plot(
                        "{}/x_q_{}_{:>06d}.png".format(args.log_dir, epoch, i),
                        x_q)
                    if args.plot_contrast:
                        x_q = sample_q(f, replay_buffer, y=y_q, contrast=True)
                        plot(
                            "{}/contrast_x_q_{}_{:>06d}.png".format(
                                args.log_dir, epoch, i), x_q)
                else:
                    x_q = sample_q(f, replay_buffer)
                    plot(
                        "{}/x_q_{}_{:>06d}.png".format(args.log_dir, epoch, i),
                        x_q)
                    if args.plot_contrast:
                        x_q = sample_q(f, replay_buffer, contrast=True)
                        plot(
                            "{}/contrast_x_q_{}_{:>06d}.png".format(
                                args.log_dir, epoch, i), x_q)
            if args.plot_cond:  # generate class-conditional samples
                # One sample row per class: [0..C-1] repeated C times each.
                y = torch.arange(0, args.n_classes)[None].repeat(
                    args.n_classes,
                    1).transpose(1, 0).contiguous().view(-1).to(device)
                x_q_y = sample_q(f, replay_buffer, y=y)
                plot("{}/x_q_y{}_{:>06d}.png".format(args.log_dir, epoch, i),
                     x_q_y)
                if args.plot_contrast:
                    y = torch.arange(0, args.n_classes)[None].repeat(
                        args.n_classes,
                        1).transpose(1, 0).contiguous().view(-1).to(device)
                    x_q_y = sample_q(f, replay_buffer, y=y, contrast=True)
                    plot(
                        "{}/contrast_x_q_y_{}_{:>06d}.png".format(
                            args.log_dir, epoch, i), x_q_y)

        if args.ckpt_every > 0 and epoch % args.ckpt_every == 0:
            checkpoint(f, replay_buffer, f"ckpt_{epoch}.pt", args)

        if epoch % args.eval_every == 0:
            # Validation set
            # BUG FIX: use distinct names so the validation accuracy is not
            # clobbered by the test-set result before it is printed/logged.
            valid_correct, val_loss = eval_classification(f, dload_valid)
            if valid_correct > best_valid_acc:
                # Store as float so "Best Valid" is loggable without
                # .detach() even before the first improvement.
                best_valid_acc = float(valid_correct)
                print("Best Valid!: {}".format(valid_correct))
                checkpoint(f, replay_buffer, "best_valid_ckpt.pt", args)
            # Test set
            test_correct, test_loss = eval_classification(f, dload_test)
            print("Epoch {}: Valid Loss {}, Valid Acc {}".format(
                epoch, val_loss, valid_correct))
            print("Epoch {}: Test Loss {}, Test Acc {}".format(
                epoch, test_loss, test_correct))
            # Back to train mode (eval_classification presumably switches f
            # to eval mode — TODO confirm).
            f.train()
            logger.record_dict({
                "Epoch": epoch,
                "Valid Loss": val_loss,
                "Valid Acc": valid_correct.detach().cpu().numpy(),
                "Test Loss": test_loss,
                "Test Acc": test_correct.detach().cpu().numpy(),
                "Best Valid": best_valid_acc,
                "Loss": L.cpu().data.item(),
            })
        checkpoint(f, replay_buffer, "last_ckpt.pt", args)

        logger.dump_tabular()