def train(epoch, vae_dir):
    """ One training epoch """
    model.train()
    model_p.train()
    controller.train()
    train_loss = []
    for batch_idx, [data, action, pre] in enumerate(train_loader):
        #torch.autograd.set_detect_anomaly(True)
        data = data.cuda()
        action = action.cuda()
        pre = pre.cuda()
        optimizer.zero_grad()
        optimizer_p.zero_grad()
        optimizer_a.zero_grad()
        recon_c, mu_c, logvar_c = model(data)
        loss_c = loss_function(recon_c, data, mu_c, logvar_c)
        recon_f, mu_f, logvar_f = model(pre)
        loss_f = loss_function(recon_f, pre, mu_f, logvar_f)
        recon_p, mu_p, logvar_p = model_p(torch.cat([data, action], dim=1))
        loss_p = loss_function(recon_p, pre, mu_p, logvar_p)
        mu, logsigma = mu_c.detach().cuda(), logvar_c.detach().cuda()
        sigma = torch.exp(logsigma / 2.0)
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        z = z.cuda().view(data.shape[0], -1).detach()
        action_p = controller(z)
        #print(action[:,:,0,0])
        loss_a = F.mse_loss(action_p, action[:, :3, 11, 11], reduction='mean')
        loss = loss_c + loss_f + loss_p + loss_a
        loss.backward()

        #print(loss.item())
        train_loss.append(loss.item())
        optimizer.step()
        optimizer_p.step()
        optimizer_a.step()
        ground = data[0, ...].data.cpu().numpy().astype('float32')
        ground = np.reshape(ground, [3, 64, 64])
        vis.image(
            ground,
            opts=dict(title='ground!', caption='ground.'),
            win=current_window,
        )
        image = recon_c[0, ...].data.cpu().numpy().astype('float32')
        image = np.reshape(image, [3, 64, 64])
        vis.image(
            image,
            opts=dict(title='Reconstruction!', caption='Reconstruction.'),
            win=recon_window,
        )
        image = np.sum(ground, axis=0)
        image = (image < np.mean(image)).astype('float32')
        vis.image(
            image,
            opts=dict(title='mask!', caption='mask.'),
            win=mask_window,
        )
        ground = pre[0, ...].data.cpu().numpy().astype('float32')
        ground = np.reshape(ground, [3, 64, 64])
        vis.image(
            ground,
            opts=dict(title='future!', caption='ground.'),
            win=future_window,
        )
        image = recon_p[0, ...].data.cpu().numpy().astype('float32')
        image = np.reshape(image, [3, 64, 64])
        vis.image(
            image,
            opts=dict(title='prediction!', caption='prediction.'),
            win=pre_window,
        )
        vis.line(X=torch.ones(1).cpu() * batch_idx + torch.ones(1).cpu() *
                 (epoch - trained - 1) * args.batch_size,
                 Y=loss.item() * torch.ones(1).cpu(),
                 win=loss_window,
                 update='append')
        vis.line(X=torch.ones(1).cpu() * batch_idx + torch.ones(1).cpu() *
                 (epoch - trained - 1) * args.batch_size,
                 Y=loss_c.item() * torch.ones(1).cpu(),
                 win=lossc_window,
                 update='append')
        vis.line(X=torch.ones(1).cpu() * batch_idx + torch.ones(1).cpu() *
                 (epoch - trained - 1) * args.batch_size,
                 Y=loss_a.item() * torch.ones(1).cpu(),
                 win=lossa_window,
                 update='append')
        vis.line(X=torch.ones(1).cpu() * batch_idx + torch.ones(1).cpu() *
                 (epoch - trained - 1) * args.batch_size,
                 Y=loss_p.item() * torch.ones(1).cpu(),
                 win=lossp_window,
                 update='append')
        if batch_idx % 1 == 0:
            print(
                'Train Epoch: {} [{}/{} ({:.0f}%)]  Loss_c: {:.4f}  Loss_f: {:.4f}  Loss_p: {:.4f}  Loss_a: {:.4f}'
                .format(epoch, batch_idx * len(data),
                        len(train_loader.dataset),
                        100. * batch_idx / len(train_loader),
                        loss_c.item(), loss_f.item(), loss_p.item(),
                        loss_a.item()))
        if batch_idx % 1000 == 0:
            best_filename = join(vae_dir, 'best.pkl')
            filename_vae = join(vae_dir,
                                'vae_checkpoint_' + str(epoch) + '.pkl')
            filename_pre = join(vae_dir,
                                'pre_checkpoint_' + str(epoch) + '.pkl')
            filename_control = join(
                vae_dir, 'control_checkpoint_' + str(epoch) + '.pkl')
            # is_best = not cur_best or test_loss < cur_best
            # if is_best:
            #     cur_best = test_loss
            is_best = False
            save_checkpoint(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, is_best, filename_vae, best_filename)
            save_checkpoint(
                {
                    'epoch': epoch,
                    'state_dict': model_p.state_dict(),
                    'optimizer': optimizer_p.state_dict(),
                }, is_best, filename_pre, best_filename)
            save_checkpoint(
                {
                    'epoch': epoch,
                    'state_dict': controller.state_dict(),
                    'optimizer': optimizer_a.state_dict(),
                }, is_best, filename_control, best_filename)
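
The loss_function called above is not shown in this snippet; the sketch below is one minimal, assumed VAE objective (pixel-wise reconstruction plus the closed-form KL term), not the author's exact implementation.

import torch
import torch.nn.functional as F

def loss_function(recon_x, x, mu, logvar):
    """Assumed VAE loss: reconstruction error plus the analytic KL divergence
    between N(mu, sigma^2) and the standard normal prior."""
    # Reconstruction term; BCE assumes inputs scaled to [0, 1].
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL(q(z|x) || N(0, I)) in closed form.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld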
Example No. 2

train = partial(data_pass, train=True, include_reward=args.include_reward)
test = partial(data_pass, train=False, include_reward=args.include_reward)

cur_best = None
for e in range(epochs):
    train(e)
    test_loss = test(e)
    scheduler.step(test_loss)
    earlystopping.step(test_loss)

    is_best = not cur_best or test_loss < cur_best
    if is_best:
        cur_best = test_loss
    checkpoint_fname = join(rnn_dir, 'checkpoint.tar')
    save_checkpoint(
        {
            "state_dict": mdrnn.state_dict(),
            "optimizer": optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'earlystopping': earlystopping.state_dict(),
            "precision": test_loss,
            "epoch": e
        }, is_best, checkpoint_fname, rnn_file)

    if earlystopping.stop:
        print(
            "End of Training because of early stopping at epoch {}".format(e))
        break
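
save_checkpoint is used throughout these examples but never defined; a plausible minimal helper in the same spirit (assumed: it serialises the state and mirrors it to the best file when is_best is set) is sketched below.

import shutil
import torch

def save_checkpoint(state, is_best, filename, best_filename):
    """Assumed helper: persist a checkpoint and copy it to best_filename
    when the current model is the best seen so far."""
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)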
Example No. 3
    test_loss = test()
    scheduler.step(test_loss)
    earlystopping.step(test_loss)

    # checkpointing
    best_filename = join(vae_dir, 'best.tar')
    filename = join(vae_dir, 'checkpoint.tar')
    is_best = not cur_best or test_loss < cur_best
    if is_best:
        cur_best = test_loss

    save_checkpoint(
        {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'precision': test_loss,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'earlystopping': earlystopping.state_dict()
        }, is_best, filename, best_filename)

    if not args.nosamples:
        with torch.no_grad():
            sample = torch.randn(RED_SIZE, LSIZE).to(device)
            sample = model.decoder(sample).cpu()
            ch1 = sample.view(64, 3, RED_SIZE, RED_SIZE)[:, 0, :, :]
            ch2 = sample.view(64, 3, RED_SIZE, RED_SIZE)[:, 1, :, :]
            ch3 = sample.view(64, 3, RED_SIZE, RED_SIZE)[:, 2, :, :]
            ch1 = torch.stack([ch1, ch1, ch1], 1)
            ch2 = torch.stack([ch2, ch2, ch2], 1)
            ch3 = torch.stack([ch3, ch3, ch3], 1)
    # scheduler.step(test_loss)
    # earlystopping.step(test_loss)

    # checkpointing
    best_filename = join(vae_dir, 'best.pkl')
    filename_vae = join(vae_dir, 'vae_checkpoint_' + str(epoch) + '.pkl')
    filename_pre = join(vae_dir, 'pre_checkpoint_' + str(epoch) + '.pkl')
    filename_control = join(vae_dir,
                            'control_checkpoint_' + str(epoch) + '.pkl')
    # is_best = not cur_best or test_loss < cur_best
    # if is_best:
    #     cur_best = test_loss
    is_best = False
    save_checkpoint(
        {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, is_best, filename_vae, best_filename)
    save_checkpoint(
        {
            'epoch': epoch,
            'state_dict': model_p.state_dict(),
            'optimizer': optimizer_p.state_dict(),
        }, is_best, filename_pre, best_filename)
    save_checkpoint(
        {
            'epoch': epoch,
            'state_dict': controller.state_dict(),
            'optimizer': optimizer_a.state_dict(),
        }, is_best, filename_control, best_filename)
EvalAttack = config.create_evaluation_attack_method(DEVICE)

now_epoch = 0

if args.auto_continue:
    args.resume = os.path.join(config.model_dir, 'last.checkpoint')
if args.resume is not None and os.path.isfile(args.resume):
    now_epoch = load_checkpoint(args.resume, net, optimizer, lr_scheduler)

while True:
    if now_epoch > config.num_epochs:
        break
    now_epoch = now_epoch + 1

    descrip_str = 'Training epoch:{}/{} -- lr:{}'.format(
        now_epoch, config.num_epochs, lr_scheduler.get_lr()[0])
    acc, yofoacc = train_one_epoch(net, ds_train, optimizer, criterion, LayerOneTrainer, config.K,
                    DEVICE, descrip_str)
    tb_train_dic = {'Acc':acc, 'YofoAcc':yofoacc}
    print(tb_train_dic)
    writer.add_scalars('Train', tb_train_dic, now_epoch)
    if config.val_interval > 0 and now_epoch % config.val_interval == 0:
        acc, advacc = eval_one_epoch(net, ds_val, DEVICE, EvalAttack)
        tb_val_dic = {'Acc': acc, 'AdvAcc': advacc}
        writer.add_scalars('Val', tb_val_dic, now_epoch)

    lr_scheduler.step()
    lyaer_one_optimizer_lr_scheduler.step()
    save_checkpoint(now_epoch, net, optimizer, lr_scheduler,
                    file_name = os.path.join(config.model_dir, 'epoch.checkpoint'))
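
load_checkpoint above is expected to restore the network, optimizer and LR scheduler and return the epoch to resume from; the sketch below is a minimal version under that assumption (the checkpoint keys are guesses, not the project's actual format).

import torch

def load_checkpoint(file_name, net, optimizer, lr_scheduler):
    """Assumed helper: restore training state and return the stored epoch.
    Key names ('state_dict', 'optimizer', ...) are illustrative."""
    checkpoint = torch.load(file_name, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    return checkpoint['epoch']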
Example No. 6
                dim=1,
                keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    val_acc = 100. * correct / len(val_loader.dataset)
    is_best = val_acc > best_val_acc
    if is_best:
        best_val_acc = val_acc

    summary_writer.add_scalar('Validation/Acc_Trans', val_acc, epoch)
    tqdm.write('[%d/%d] Validation Accuracy %.3f' %
               (epoch + 1, args.epochs, val_acc))

    save_checkpoint(
        {
            'epoch': epoch,
            'iteration': iteration,
            'best_val_acc': best_val_acc,
            'shared_latent': shared_latent.state_dict(),
            'encoder_1': encoder_1.state_dict(),
            'encoder_2': encoder_2.state_dict(),
            'decoder_1': decoder_1.state_dict(),
            'decoder_2': decoder_2.state_dict(),
            'discriminator_1': discriminator_1.state_dict(),
            'discriminator_2': discriminator_2.state_dict(),
            'optimizerD': optimizerD.state_dict(),
            'optimizerG': optimizerG.state_dict(),
        },
        is_best=is_best,
        filename=os.path.join(args.save, 'checkpoint.pth'))
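
The checkpoint above bundles several sub-networks and two optimizers; resuming would mirror the same keys. A hedged sketch follows (key names taken from the dict above, everything else assumed).

import torch

def resume_from_checkpoint(path, modules, optimizers, device='cpu'):
    """Assumed counterpart of the save above: `modules` and `optimizers` map
    checkpoint keys (e.g. 'encoder_1', 'optimizerD') to the live objects."""
    ckpt = torch.load(path, map_location=device)
    for key, module in modules.items():
        module.load_state_dict(ckpt[key])
    for key, opt in optimizers.items():
        opt.load_state_dict(ckpt[key])
    return ckpt['epoch'], ckpt['iteration'], ckpt['best_val_acc']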
Example No. 7
def m_model_train_proc(rnn_dir,
                       model,
                       v_model,
                       dataset_train,
                       dataset_test,
                       optimizer,
                       scheduler,
                       earlystopping,
                       skip_train=False,
                       max_train_epochs=50):
    step_log('3-3. m_model_train_proc START!!')
    train_loader = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=M_BATCH_SIZE,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset_test,
                                              batch_size=M_BATCH_SIZE)

    # check rnn dir exists, if not, create it
    if not os.path.exists(rnn_dir):
        os.mkdir(rnn_dir)

    rnn_file = os.path.join(rnn_dir, 'best.tar')
    if os.path.exists(rnn_file):
        state = torch.load(rnn_file)
        print("Reloading model at epoch {}"
              ", with test error {}".format(state['epoch'],
                                            state['precision']))
        model.load_state_dict(state['state_dict'])
        optimizer.load_state_dict(state['optimizer'])
        scheduler.load_state_dict(state['scheduler'])
        earlystopping.load_state_dict(state['earlystopping'])

    if skip_train:
        return  # pipaek: for when you want to skip improving the model via training

    def data_pass(epoch, train):  # pylint: disable=too-many-locals
        """ One pass through the data """
        if train:
            model.train()
            loader = train_loader
        else:
            model.eval()
            loader = test_loader

        loader.dataset.load_next_buffer()

        cum_loss = 0
        cum_gmm = 0
        cum_bce = 0
        cum_mse = 0

        pbar = tqdm(total=len(loader.dataset), desc="Epoch {}".format(epoch))
        for i, data in enumerate(loader):
            obs, action, reward, terminal, next_obs = [
                arr.to(device) for arr in data
            ]

            # transform obs
            latent_obs, latent_next_obs = to_latent(obs, next_obs)

            if train:
                losses = get_loss(latent_obs, action, reward, terminal,
                                  latent_next_obs)

                optimizer.zero_grad()
                losses['loss'].backward()
                optimizer.step()
            else:
                with torch.no_grad():
                    losses = get_loss(latent_obs, action, reward, terminal,
                                      latent_next_obs)

            cum_loss += losses['loss'].item()
            cum_gmm += losses['gmm'].item()
            cum_bce += losses['bce'].item()
            cum_mse += losses['mse'].item()

            pbar.set_postfix_str("loss={loss:10.6f} bce={bce:10.6f} "
                                 "gmm={gmm:10.6f} mse={mse:10.6f}".format(
                                     loss=cum_loss / (i + 1),
                                     bce=cum_bce / (i + 1),
                                     gmm=cum_gmm / LSIZE / (i + 1),
                                     mse=cum_mse / (i + 1)))
            pbar.update(M_BATCH_SIZE)
        pbar.close()
        return cum_loss * M_BATCH_SIZE / len(loader.dataset)

    def to_latent(obs, next_obs):
        """ Transform observations to latent space.

        :args obs: 5D torch tensor (BSIZE, SEQ_LEN, ASIZE, SIZE, SIZE)
        :args next_obs: 5D torch tensor (BSIZE, SEQ_LEN, ASIZE, SIZE, SIZE)

        :returns: (latent_obs, latent_next_obs)
            - latent_obs: 4D torch tensor (BSIZE, SEQ_LEN, LSIZE)
            - next_latent_obs: 4D torch tensor (BSIZE, SEQ_LEN, LSIZE)
        """
        with torch.no_grad():
            obs, next_obs = [
                F.upsample(x.view(-1, 3, SIZE, SIZE),
                           size=RED_SIZE,
                           mode='bilinear',
                           align_corners=True) for x in (obs, next_obs)
            ]

            (obs_mu, obs_logsigma), (next_obs_mu, next_obs_logsigma) = [
                v_model(x)[1:] for x in (obs, next_obs)
            ]

            latent_obs, latent_next_obs = [
                (x_mu + x_logsigma.exp() * torch.randn_like(x_mu)).view(
                    M_BATCH_SIZE, M_SEQ_LEN, LSIZE)
                for x_mu, x_logsigma in [(
                    obs_mu, obs_logsigma), (next_obs_mu, next_obs_logsigma)]
            ]
        return latent_obs, latent_next_obs

    def get_loss(latent_obs, action, reward, terminal, latent_next_obs):
        """ Compute losses.

        The loss that is computed is:
        (GMMLoss(latent_next_obs, GMMPredicted) + MSE(reward, predicted_reward) +
             BCE(terminal, logit_terminal)) / (LSIZE + 2)
        The LSIZE + 2 factor is here to counteract the fact that the GMMLoss scales
        approximately linearly with LSIZE. All losses are averaged over both the
        batch and sequence dimensions (the first two dimensions).

        :args latent_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
        :args action: (BSIZE, SEQ_LEN, ASIZE) torch tensor
        :args reward: (BSIZE, SEQ_LEN) torch tensor
        :args latent_next_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor

        :returns: dictionary of losses, containing the gmm, the mse, the bce and
            the averaged loss.
        """
        latent_obs, action, \
        reward, terminal, \
        latent_next_obs = [arr.transpose(1, 0)
                           for arr in [latent_obs, action,
                                       reward, terminal,
                                       latent_next_obs]]
        mus, sigmas, logpi, rs, ds = model(action, latent_obs)
        gmm = gmm_loss(latent_next_obs, mus, sigmas, logpi)
        bce = F.binary_cross_entropy_with_logits(ds, terminal)
        mse = F.mse_loss(rs, reward)
        loss = (gmm + bce + mse) / (LSIZE + 2)
        return dict(gmm=gmm, bce=bce, mse=mse, loss=loss)

    def gmm_loss(batch, mus, sigmas, logpi, reduce=True):  # pylint: disable=too-many-arguments
        """ Computes the gmm loss.

        Compute minus the log probability of batch under the GMM model described
        by mus, sigmas, pi. Precisely, with bs1, bs2, ... the sizes of the batch
        dimensions (several batch dimension are useful when you have both a batch
        axis and a time step axis), gs the number of mixtures and fs the number of
        features.

        :args batch: (bs1, bs2, *, fs) torch tensor
        :args mus: (bs1, bs2, *, gs, fs) torch tensor
        :args sigmas: (bs1, bs2, *, gs, fs) torch tensor
        :args logpi: (bs1, bs2, *, gs) torch tensor
        :args reduce: if not reduce, the mean in the following formula is omitted

        :returns:
        loss(batch) = - mean_{i1=0..bs1, i2=0..bs2, ...} log(
            sum_{k=1..gs} pi[i1, i2, ..., k] * N(
                batch[i1, i2, ..., :] | mus[i1, i2, ..., k, :], sigmas[i1, i2, ..., k, :]))

        NOTE: The loss is not reduced along the feature dimension (i.e. it should scale ~linearly
        with fs).
        """
        batch = batch.unsqueeze(-2)
        normal_dist = Normal(mus, sigmas)
        g_log_probs = normal_dist.log_prob(batch)
        g_log_probs = logpi + torch.sum(g_log_probs, dim=-1)
        max_log_probs = torch.max(g_log_probs, dim=-1, keepdim=True)[0]
        g_log_probs = g_log_probs - max_log_probs

        g_probs = torch.exp(g_log_probs)
        probs = torch.sum(g_probs, dim=-1)

        log_prob = max_log_probs.squeeze() + torch.log(probs)
        if reduce:
            return -torch.mean(log_prob)
        return -log_prob

    train = partial(data_pass, train=True)
    test = partial(data_pass, train=False)

    cur_best = None  # pipaek: this should probably be moved up here
    for e in range(max_train_epochs):
        #cur_best = None   # pipaek: isn't this a bug??
        train(e)
        test_loss = test(e)
        scheduler.step(test_loss)
        earlystopping.step(test_loss)

        is_best = not cur_best or test_loss < cur_best
        if is_best:
            cur_best = test_loss
        checkpoint_fname = os.path.join(rnn_dir, 'checkpoint.tar')
        save_checkpoint(
            {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'earlystopping': earlystopping.state_dict(),
                "precision": test_loss,
                "epoch": e
            }, is_best, checkpoint_fname, rnn_file)

        if earlystopping.stop:
            print(
                "End of Training because of early stopping at epoch {}".format(
                    e))
            break
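
The gmm_loss docstring above spells out the expected tensor shapes; the short standalone check below recomputes the same negative log-likelihood with torch.distributions on made-up sizes (all names and dimensions here are illustrative only).

import torch
from torch.distributions.normal import Normal

bs, gs, fs = 4, 5, 32                  # batch, mixture components, features
batch = torch.randn(bs, fs)
mus = torch.randn(bs, gs, fs)
sigmas = torch.rand(bs, gs, fs) + 0.1  # strictly positive scales
logpi = torch.log_softmax(torch.randn(bs, gs), dim=-1)

# (bs, 1, fs) broadcast against (bs, gs, fs), summed over features -> (bs, gs)
log_probs = Normal(mus, sigmas).log_prob(batch.unsqueeze(-2)).sum(dim=-1)
nll = -torch.logsumexp(logpi + log_probs, dim=-1).mean()
print(nll)  # matches gmm_loss(batch, mus, sigmas, logpi) up to numerics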
Example No. 8
def train(model, train_loader, valid_loader, args):
    """Trainer function for PointNet
    """
    if args.webhook != '':
        slack_message("Training started in %s (%s)" %
                      (args.logdir, misc.gethostname()),
                      url=args.webhook)

    # Set device
    assert args.cuda < 0 or torch.cuda.is_available()
    device_tag = "cpu" if args.cuda == -1 else "cuda:%d" % args.cuda
    device = torch.device(device_tag)

    # Set model and label weights
    model = model.to(device)
    model.labelweights = torch.tensor(train_loader.dataset.labelweights,
                                      device=device,
                                      requires_grad=False)

    # Set optimizer (default SGD with momentum)
    if args.use_adam:
        optimizer = optim.AdamW(model.parameters(),
                                lr=args.lr,
                                weight_decay=args.weight_decay)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay,
                              momentum=args.momentum,
                              nesterov=True)

    # Get current state
    model_object = model.module if isinstance(model,
                                              torch.nn.DataParallel) else model
    module_file = misc.sys.modules[model_object.__class__.__module__].__file__
    state = misc.persistence(args, module_file=module_file, main_file=__file__)
    init_epoch = state["epoch"]

    if state["model_state_dict"]:
        logging.info("Loading pre-trained model from %s" % args.model_path)
        model.load_state_dict(state["model_state_dict"])

    if state["optimizer_state_dict"]:
        optimizer.load_state_dict(state["optimizer_state_dict"])

    lr_scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=np.arange(args.decay_step, args.epochs,
                             args.decay_step).tolist(),
        gamma=args.lr_decay,
        last_epoch=init_epoch - 1)

    def train_one_epoch():
        iterations = tqdm(train_loader,
                          unit='batch',
                          leave=False,
                          disable=args.headless)
        ep_sum = run_one_epoch(model,
                               iterations,
                               "train",
                               optimizer=optimizer,
                               loss_update_interval=10)

        summary = {"Loss": np.mean(ep_sum["losses"])}
        return summary

    def eval_one_epoch():
        iterations = tqdm(valid_loader,
                          unit='batch',
                          leave=False,
                          desc="Validation",
                          disable=args.headless)
        ep_sum = run_one_epoch(model,
                               iterations,
                               "test",
                               get_locals=True,
                               loss_update_interval=-1)

        # preds = ep_sum["logits"].argmax(axis=-2)
        preds = [l[0].argmax(axis=0) for l in ep_sum["logits"]]
        labels = [l[0] for l in ep_sum["labels"]]
        summary = get_segmentation_metrics(labels, preds)
        summary["Loss"] = float(np.mean(ep_sum["losses"]))
        return summary

    # Train for multiple epochs
    tensorboard = SummaryWriter(log_dir=misc.join_path(args.logdir, "logs"))
    tqdm_epochs = tqdm(range(init_epoch, args.epochs),
                       total=args.epochs,
                       initial=init_epoch,
                       unit='epoch',
                       desc="Progress",
                       disable=args.headless)
    logging.info("Training started.")
    for e in tqdm_epochs:
        train_summary = train_one_epoch()
        valid_summary = eval_one_epoch()
        # valid_summary={"Loss/validation":0}
        train_summary["Learning Rate"] = lr_scheduler.get_last_lr()[-1]

        train_summary = {
            f"Train/{k}": train_summary.pop(k)
            for k in list(train_summary.keys())
        }
        valid_summary = {
            f"Validation/{k}": valid_summary.pop(k)
            for k in list(valid_summary.keys())
        }
        summary = {**train_summary, **valid_summary}

        if args.print_summary:
            tqdm_epochs.clear()
            logging.info("Epoch %d summary:\n%s\n" %
                         (e + 1, misc.json.dumps((summary), indent=2)))

        # Update learning rate and save checkpoint
        lr_scheduler.step()
        misc.save_checkpoint(
            args.logdir, {
                "epoch": e + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": summary["Validation/Loss"],
                "summary": summary
            })

        # Write summary
        for name, val in summary.items():
            if "IoU per Class" in name: continue
            tensorboard.add_scalar(name, val, global_step=e + 1)

    if args.webhook != '':
        slack_message("Training finished in %s (%s)" %
                      (args.logdir, misc.gethostname()),
                      url=args.webhook)
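
The summary-prefixing step above (Train/…, Validation/…) is easy to check in isolation; the pop-while-iterating comprehension is safe because the key list is materialised first, and is equivalent to a plain items() comprehension, as this small illustrative snippet shows.

# Illustrative values only.
train_summary = {"Loss": 0.42, "Learning Rate": 1e-3}
prefixed = {f"Train/{k}": v for k, v in train_summary.items()}
print(prefixed)  # {'Train/Loss': 0.42, 'Train/Learning Rate': 0.001}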
Example No. 9
def train_mdrnn(logdir, traindir, epochs=10, testdir=None):
    BSIZE = 80  # maybe should change this back to their initial one of 16
    noreload = False  # best model is not reloaded if specified
    SEQ_LEN = 32
    epochs = int(epochs)

    testdir = testdir if testdir else traindir
    cuda = torch.cuda.is_available()

    torch.manual_seed(123)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


    # Loading VAE
    vae_file = join(logdir, 'vae', 'best.tar')
    assert exists(vae_file), "No trained VAE in the logdir..."
    state = torch.load(vae_file)
    print("Loading VAE at epoch {} "
          "with test error {}".format(
              state['epoch'], state['precision']))

    vae = VAE(3, LSIZE).to(device)
    vae.load_state_dict(state['state_dict'])

    # Loading model
    rnn_dir = join(logdir, 'mdrnn')
    rnn_file = join(rnn_dir, 'best.tar')

    if not exists(rnn_dir):
        mkdir(rnn_dir)

    mdrnn = MDRNN(LSIZE, ASIZE, RSIZE, 5)
    mdrnn.to(device)
    optimizer = torch.optim.RMSprop(mdrnn.parameters(), lr=1e-3, alpha=.9)
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
    earlystopping = EarlyStopping('min', patience=30)


    if exists(rnn_file) and not noreload:
        rnn_state = torch.load(rnn_file)
        print("Loading MDRNN at epoch {} "
              "with test error {}".format(
                  rnn_state["epoch"], rnn_state["precision"]))
        mdrnn.load_state_dict(rnn_state["state_dict"])
        optimizer.load_state_dict(rnn_state["optimizer"])
        scheduler.load_state_dict(rnn_state['scheduler'])
        earlystopping.load_state_dict(rnn_state['earlystopping'])


    # Data Loading
    transform = transforms.Lambda(
        lambda x: np.transpose(x, (0, 3, 1, 2)) / 255)
    train_loader = DataLoader(
        RolloutSequenceDataset(traindir, SEQ_LEN, transform, buffer_size=30),
        batch_size=BSIZE, num_workers=8, shuffle=True)
    test_loader = DataLoader(
        RolloutSequenceDataset(testdir, SEQ_LEN, transform, train=False, buffer_size=10),
        batch_size=BSIZE, num_workers=8)

    def to_latent(obs, next_obs):
        """ Transform observations to latent space.

        :args obs: 5D torch tensor (BSIZE, SEQ_LEN, ASIZE, SIZE, SIZE)
        :args next_obs: 5D torch tensor (BSIZE, SEQ_LEN, ASIZE, SIZE, SIZE)

        :returns: (latent_obs, latent_next_obs)
            - latent_obs: 4D torch tensor (BSIZE, SEQ_LEN, LSIZE)
            - next_latent_obs: 4D torch tensor (BSIZE, SEQ_LEN, LSIZE)
        """
        with torch.no_grad():
            obs, next_obs = [
                f.upsample(x.view(-1, 3, SIZE, SIZE), size=RED_SIZE,
                           mode='bilinear', align_corners=True)
                for x in (obs, next_obs)]

            (obs_mu, obs_logsigma), (next_obs_mu, next_obs_logsigma) = [
                vae(x)[1:] for x in (obs, next_obs)]

            latent_obs, latent_next_obs = [
                (x_mu + x_logsigma.exp() * torch.randn_like(x_mu)).view(BSIZE, SEQ_LEN, LSIZE)
                for x_mu, x_logsigma in
                [(obs_mu, obs_logsigma), (next_obs_mu, next_obs_logsigma)]]
        return latent_obs, latent_next_obs

    def get_loss(latent_obs, action, reward, terminal, latent_next_obs):
        """ Compute losses.

        The loss that is computed is:
        (GMMLoss(latent_next_obs, GMMPredicted) + MSE(reward, predicted_reward) +
             BCE(terminal, logit_terminal)) / (LSIZE + 2)
        The LSIZE + 2 factor is here to counteract the fact that the GMMLoss scales
        approximately linearly with LSIZE. All losses are averaged over both the
        batch and sequence dimensions (the first two dimensions).

        :args latent_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
        :args action: (BSIZE, SEQ_LEN, ASIZE) torch tensor
        :args reward: (BSIZE, SEQ_LEN) torch tensor
        :args latent_next_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor

        :returns: dictionary of losses, containing the gmm, the mse, the bce and
            the averaged loss.
        """
        latent_obs, action,\
            reward, terminal,\
            latent_next_obs = [arr.transpose(1, 0)
                               for arr in [latent_obs, action,
                                           reward, terminal,
                                           latent_next_obs]]
        mus, sigmas, logpi, rs, ds = mdrnn(action, latent_obs)
        gmm = gmm_loss(latent_next_obs, mus, sigmas, logpi)
        bce = f.binary_cross_entropy_with_logits(ds, terminal)
        mse = f.mse_loss(rs, reward)
        loss = (gmm + bce + mse) / (LSIZE + 2)
        return dict(gmm=gmm, bce=bce, mse=mse, loss=loss)


    def data_pass(epoch, train): # pylint: disable=too-many-locals
        """ One pass through the data """
        if train:
            mdrnn.train()
            loader = train_loader
        else:
            mdrnn.eval()
            loader = test_loader

        loader.dataset.load_next_buffer()

        cum_loss = 0
        cum_gmm = 0
        cum_bce = 0
        cum_mse = 0

        pbar = tqdm(total=len(loader.dataset), desc="Epoch {}".format(epoch))
        for i, data in enumerate(loader):
            obs, action, reward, terminal, next_obs = [arr.to(device) for arr in data]

            # transform obs
            latent_obs, latent_next_obs = to_latent(obs, next_obs)

            if train:
                losses = get_loss(latent_obs, action, reward,
                                  terminal, latent_next_obs)

                optimizer.zero_grad()
                losses['loss'].backward()
                optimizer.step()
            else:
                with torch.no_grad():
                    losses = get_loss(latent_obs, action, reward,
                                      terminal, latent_next_obs)

            cum_loss += losses['loss'].item()
            cum_gmm += losses['gmm'].item()
            cum_bce += losses['bce'].item()
            cum_mse += losses['mse'].item()

            pbar.set_postfix_str("loss={loss:10.6f} bce={bce:10.6f} "
                                 "gmm={gmm:10.6f} mse={mse:10.6f}".format(
                                     loss=cum_loss / (i + 1), bce=cum_bce / (i + 1),
                                     gmm=cum_gmm / LSIZE / (i + 1), mse=cum_mse / (i + 1)))
            pbar.update(BSIZE)
        pbar.close()
        return cum_loss * BSIZE / len(loader.dataset)

    train = partial(data_pass, train=True)
    test = partial(data_pass, train=False)

    cur_best = None
    for e in range(epochs):
        train(e)
        test_loss = test(e)
        scheduler.step(test_loss)
        earlystopping.step(test_loss)

        is_best = not cur_best or test_loss < cur_best
        if is_best:
            cur_best = test_loss
        checkpoint_fname = join(rnn_dir, 'checkpoint.tar')
        save_checkpoint({
            "state_dict": mdrnn.state_dict(),
            "optimizer": optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'earlystopping': earlystopping.state_dict(),
            "precision": test_loss,
            "epoch": e}, is_best, checkpoint_fname,
                        rnn_file)

        if earlystopping.stop:
            print("End of Training because of early stopping at epoch {}".format(e))
            break
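
EarlyStopping('min', patience=30) is constructed above but its class is not shown in any of these snippets; the following minimal sketch of the assumed interface (step, stop, state_dict, load_state_dict) is one plausible implementation, not the original.

class EarlyStopping:
    """Assumed minimal early-stopping tracker used by the loops above."""

    def __init__(self, mode='min', patience=30):
        assert mode in ('min', 'max')
        self.mode, self.patience = mode, patience
        self.best, self.num_bad_epochs, self.stop = None, 0, False

    def step(self, metric):
        better = (self.best is None or
                  (metric < self.best if self.mode == 'min' else metric > self.best))
        if better:
            self.best, self.num_bad_epochs = metric, 0
        else:
            self.num_bad_epochs += 1
        self.stop = self.num_bad_epochs > self.patience

    def state_dict(self):
        return {'best': self.best, 'num_bad_epochs': self.num_bad_epochs}

    def load_state_dict(self, state):
        self.best = state['best']
        self.num_bad_epochs = state['num_bad_epochs']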
Example No. 10
def v_model_train_proc(vae_dir,
                       model,
                       dataset_train,
                       dataset_test,
                       optimizer,
                       scheduler,
                       earlystopping,
                       skip_train=False,
                       max_train_epochs=1000):
    step_log('2-3. v_model_train_proc START!!')
    train_loader = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=V_BATCH_SIZE,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset_test,
                                              batch_size=V_BATCH_SIZE,
                                              shuffle=True)

    # check vae dir exists, if not, create it
    if not os.path.exists(vae_dir):
        os.mkdir(vae_dir)

    # sample image dir for each epoch
    sample_dir = os.path.join(vae_dir, 'samples')
    if not os.path.exists(sample_dir):
        os.mkdir(sample_dir)

    reload_file = os.path.join(vae_dir, 'best.tar')
    if os.path.exists(reload_file):
        state = torch.load(reload_file)
        print("Reloading model at epoch {}"
              ", with test error {}".format(state['epoch'],
                                            state['precision']))
        model.load_state_dict(state['state_dict'])
        optimizer.load_state_dict(state['optimizer'])
        scheduler.load_state_dict(state['scheduler'])
        earlystopping.load_state_dict(state['earlystopping'])

    if skip_train:
        return  # pipaek: for when you want to skip improving the model via training

    cur_best = None

    for epoch in range(1, max_train_epochs + 1):
        vae_train(epoch, model, dataset_train, train_loader, optimizer)
        test_loss = vae_test(model, dataset_test, test_loader)
        scheduler.step(test_loss)
        earlystopping.step(test_loss)

        # checkpointing
        best_filename = os.path.join(vae_dir, 'best.tar')
        filename = os.path.join(vae_dir, 'checkpoint.tar')
        is_best = not cur_best or test_loss < cur_best
        if is_best:
            cur_best = test_loss

        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'precision': test_loss,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'earlystopping': earlystopping.state_dict()
            }, is_best, filename, best_filename)

        #if not args.nosamples:
        with torch.no_grad():
            sample = torch.randn(RED_SIZE, LSIZE).to(device)
            sample = model.decoder(sample).cpu()
            save_image(
                sample.view(64, 3, RED_SIZE, RED_SIZE),
                os.path.join(sample_dir, 'sample_' + str(epoch) + '.png'))

        if earlystopping.stop:
            print(
                "End of Training because of early stopping at epoch {}".format(
                    epoch))
            break
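
vae_train and vae_test are called above but not defined here; as an illustration, a minimal vae_test under the usual assumptions (the model returns (recon, mu, logvar); MSE reconstruction plus KL) could look like the sketch below. The exact loss must match whatever the original vae_train used.

import torch
import torch.nn.functional as F

def vae_test(model, dataset_test, test_loader):
    """Assumed evaluation pass: average VAE loss (MSE + KL) over the test set.
    dataset_test is kept only to match the call site above."""
    device = next(model.parameters()).device
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            recon, mu, logvar = model(data)
            mse = F.mse_loss(recon, data, reduction='sum')
            kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            test_loss += (mse + kld).item()
    return test_loss / len(test_loader.dataset)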
Example No. 11
def train(model,
          dataloader,
          dl_type,
          optimizer,
          loss_fn,
          params,
          model_dir,
          reshape,
          restore_file=None):
    """
    Train the model on `num_steps` batches

    Parameters
    ----------
    model
    dataloader
    dl_type
    optimizer
    loss_fn
    params
    model_dir
    reshape
    restore_file

    Returns
    -------

    """

    random_vector_for_generation = torch.randn(
        torch.Size([params.num_examples_to_generate,
                    params.latent_dim])).cuda()

    logging.info("\nTraining started.\n")
    # Add a tensorboardX SummaryWriter to log training; logs will be saved under model_dir
    run = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    log_dir = create_log_dir(model_dir, run)
    img_dir = os.path.join(log_dir, 'generated_images')
    create_dir(img_dir)
    with SummaryWriter(log_dir) as writer:
        # set model to training mode
        model.train()

        # reload weights from restore_file if specified
        if restore_file is not None:
            restore_path = os.path.join(model_dir,
                                        restore_file + '.pth.tar')
            logging.info("Restoring parameters from {}".format(restore_path))
            load_checkpoint(restore_path, model, optimizer)

        # number of iterations
        iterations = 0
        # Use tqdm progress bar for number of epochs
        for epoch in tqdm(range(params.num_epochs),
                          desc="Epochs: ",
                          leave=True):
            # Track the progress of the training batches
            training_progressor = trange(len(dataloader), desc="Loss")
            for i in training_progressor:
                iterations += 1
                # Fetch next batch of training samples
                if dl_type == 'orl_face':
                    train_batch = next(iter(dataloader))
                else:
                    train_batch, _ = next(iter(dataloader))

                # move to GPU if available
                if params.cuda:
                    train_batch = train_batch.cuda()

                # compute model output and loss
                if reshape:
                    train_batch = train_batch.view(-1, params.input_dim)

                X_reconstructed, mu, logvar, z = model(train_batch)
                losses = loss_fn(train_batch, X_reconstructed, mu, logvar, z,
                                 params.use_mse)
                loss = losses['loss']

                # clear previous gradients, compute gradients of all variables wrt loss
                optimizer.zero_grad()
                loss.backward()

                # performs updates using calculated gradients
                optimizer.step()

                # Evaluate model parameters only once in a while
                if (i + 1) % params.save_summary_steps == 0:
                    # Log values and gradients of the model parameters (histogram summary)
                    for tag, value in model.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram(tag,
                                             value.cpu().data.numpy(),
                                             iterations)
                        writer.add_histogram(tag + '/grad',
                                             value.grad.cpu().data.numpy(),
                                             iterations)

                # Compute the loss for each iteration
                summary_batch = losses
                # log loss and/or other metrics to the writer
                for tag, value in summary_batch.items():
                    writer.add_scalar(tag, value.item(), iterations)

                # update the average loss
                training_progressor.set_description("VAE (Loss=%g)" %
                                                    round(loss.item(), 4))

            # generate images for gif
            if epoch % 1 == 0:
                generate_and_save_images(model, epoch,
                                         random_vector_for_generation, img_dir)

            # Save weights
            if epoch % 10 == 0:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optim_dict': optimizer.state_dict()
                    },
                    is_best=True,
                    checkpoint=log_dir,
                    datetime=run)

        logging.info("\n\nTraining Completed.\n\n")

        logging.info(
            "Creating gif of images generated with gaussian latent vectors.\n")
        generate_gif(img_dir, writer)
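
generate_and_save_images and generate_gif are not shown; a plausible sketch of the former is below. It assumes the model exposes a decode(z) method and that torchvision is available; both are assumptions about code not included here.

import os
import torch
from torchvision.utils import save_image

def generate_and_save_images(model, epoch, latent_vectors, img_dir):
    """Assumed helper: decode a fixed latent batch and save it as an image grid.
    `model.decode` is an assumption about the VAE's interface; reshape the output
    to (B, C, H, W) first if the decoder returns flat vectors."""
    model.eval()
    with torch.no_grad():
        samples = model.decode(latent_vectors).cpu()
    save_image(samples,
               os.path.join(img_dir, 'image_at_epoch_{:04d}.png'.format(epoch)),
               nrow=8)
    model.train()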
Example No. 12
def trainFAN(args):
    global best_acc
    global best_auc

    if not os.path.exists(args.checkpoint):
        os.makedirs(args.checkpoint)

    print(
        "==> Creating model '{}-{}', stacks={} modules={}, hgdepths={} feats={} classes={}"
        .format(args.netType, args.pointType, args.nStack, args.nModules,
                args.nHgDepth, args.nFeats, args.nclasses))

    print("=> Models will be saved at: {}".format(args.checkpoint))

    # nStack=4, nModules=1, nHgDepth=4, num_feats=256, num_classes=68
    model = models.__dict__[args.netType](nStack=args.nStack,
                                          nModules=args.nModules,
                                          nHgDepth=args.nHgDepth,
                                          num_feats=args.nFeats,
                                          num_classes=args.nclasses)

    model = torch.nn.DataParallel(model).cuda()

    criterion = torch.nn.MSELoss(reduction='mean').cuda()

    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    title = args.checkpoint.split('/')[-1] + ' on ' + args.data.split('/')[-1]

    Loader = get_loader(args.data)

    val_loader = torch.utils.data.DataLoader(Loader(args, 'A'),
                                             batch_size=args.val_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                            title=title,
                            resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                            title=title)
            logger.set_names([
                'Epoch', 'LR', 'Train Loss', 'Valid Loss', 'Train Acc',
                'Val Acc', 'AUC'
            ])
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names([
            'Epoch', 'LR', 'Train Loss', 'Valid Loss', 'Train Acc', 'Val Acc',
            'AUC'
        ])

    cudnn.benchmark = True
    print('=> Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / (1024. * 1024)))

    if args.evaluation:
        print('=> Evaluation only')
        D = args.data.split('/')[-1]
        save_dir = os.path.join(args.checkpoint, D)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        loss, acc, predictions, auc = validate(val_loader, model, criterion,
                                               args.netType, args.debug,
                                               args.flip)
        save_pred(predictions, checkpoint=save_dir)
        return

    train_loader = torch.utils.data.DataLoader(Loader(args, 'train'),
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        print('=> Epoch: %d | LR %.8f' % (epoch + 1, lr))
        sys.stdout.flush()

        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, args.netType, args.debug,
                                      args.flip)
        # do not save predictions in model file
        valid_loss, valid_acc, predictions, valid_auc = validate(
            val_loader, model, criterion, args.netType, args.debug, args.flip)

        logger.append([
            int(epoch + 1), lr, train_loss, valid_loss, train_acc, valid_acc,
            valid_auc
        ])

        is_best = valid_auc >= best_auc
        best_auc = max(valid_auc, best_auc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'netType': args.netType,
                'state_dict': model.state_dict(),
                'best_acc': best_auc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            predictions,
            checkpoint=args.checkpoint,
            filename='checkpoint.pth.tar',
            snapshot=args.snapshot)

    logger.close()
    logger.plot(['AUC'])
    savefig(os.path.join(args.checkpoint, 'log.eps'))
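
adjust_learning_rate(optimizer, epoch, lr, schedule, gamma) is used above but not defined; a common step-decay implementation consistent with the call site (it returns the possibly-updated lr) is sketched below as an assumption.

def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
    """Assumed step schedule: multiply lr by gamma whenever epoch hits a milestone."""
    if epoch in schedule:
        lr *= gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    return lr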
Example No. 13
def train(configs):
    train_dataset = LianjiaCornerDataset(
        data_dir='/local-scratch/cjc/Lianjia-inverse-cad/FloorPlotter/data/Lianjia_corner',
        phase='train',
        augmentation=configs.augmentation)  # use rotation for augmentation

    train_loader = DataLoader(train_dataset,
                              batch_size=configs.batch_size,
                              shuffle=True)

    model = CornerEdgeNet(num_input_channel=5,
                          base_pretrained=False,
                          bin_size=36,
                          im_size=256,
                          configs=configs)
    model.double()

    criterion = nn.BCEWithLogitsLoss(reduction='none')
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=configs.lr,
                                 weight_decay=configs.decay_rate)
    scheduler = StepLR(optimizer, step_size=configs.lr_step, gamma=0.1)

    num_parameters = count_parameters(model)
    print('total number of trainable parameters is: {}'.format(num_parameters))

    trainer = CornerTrainer(model=model,
                            train_loader=train_loader,
                            val_loader=None,
                            criterion=criterion,
                            optimizer=optimizer,
                            configs=configs)
    start_epoch = 0

    if configs.resume:
        if os.path.isfile(configs.model_path):
            print("=> loading checkpoint '{}'".format(configs.model_path))
            checkpoint = torch.load(configs.model_path)
            model.load_state_dict(checkpoint['state_dict'])
            start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            if configs.use_cuda:
                transfer_optimizer_to_gpu(optimizer)
            print('=> loaded checkpoint {} (epoch {})'.format(
                configs.model_path, start_epoch))
        else:
            print('no checkpoint found at {}'.format(configs.model_path))

    if configs.use_cuda:
        model.cuda()
        criterion.cuda()

    ckpt_save_path = os.path.join(configs.exp_dir, 'checkpoints')
    if not os.path.exists(ckpt_save_path):
        os.makedirs(ckpt_save_path)

    for epoch_num in range(start_epoch, start_epoch + configs.num_epochs):
        scheduler.step()
        trainer.train_epoch(epoch_num)

        if epoch_num % configs.val_interval == 0:
            # valid_loss, valid_acc, _, _, _ = trainer.validate()

            # is_best = valid_acc > best_acc
            # best_acc = max(valid_acc, best_acc)
            save_checkpoint(
                {
                    'epoch': epoch_num + 1,
                    # 'best_acc': best_acc,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                },
                is_best=False,
                checkpoint=ckpt_save_path,
                filename='checkpoint_corner_edge_{}.pth.tar'.format(epoch_num))
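
transfer_optimizer_to_gpu is needed because optimizer state loaded from a CPU checkpoint keeps its momentum buffers on the CPU; a minimal sketch of such a helper (assumed, not the original) simply moves every tensor in the optimizer state to CUDA.

import torch

def transfer_optimizer_to_gpu(optimizer):
    """Assumed helper: move all tensors in the optimizer state onto the GPU."""
    for state in optimizer.state.values():
        for key, value in state.items():
            if torch.is_tensor(value):
                state[key] = value.cuda()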
Example No. 14
def main(args):
    global best_acc
    global best_auc

    if not os.path.exists(args.checkpoint):
        os.makedirs(args.checkpoint)

    print("==> Creating model '{}-{}', stacks={}, blocks={}, feats={}".format(
        args.netType, args.pointType, args.nStacks, args.nModules, args.nFeats))

    print("=> Models will be saved at: {}".format(args.checkpoint))

    model = models.__dict__[args.netType](
        num_stacks=args.nStacks,
        num_blocks=args.nModules,
        num_feats=args.nFeats,
        use_se=args.use_se,
        use_attention=args.use_attention,
        num_classes=68)

    model = torch.nn.DataParallel(model).cuda()

    criterion = torch.nn.MSELoss(reduction='mean').cuda()

    optimizer = torch.optim.RMSprop(
        model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    title = args.checkpoint.split('/')[-1] + ' on ' + args.data.split('/')[-1]

    Loader = get_loader(args.data)

    val_loader = torch.utils.data.DataLoader(
        Loader(args, 'A'),
        batch_size=args.val_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> Loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
            logger.set_names(['Epoch', 'LR', 'Train Loss', 'Valid Loss', 'Train Acc', 'Val Acc', 'AUC'])
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Valid Loss', 'Train Acc', 'Val Acc', 'AUC'])

    cudnn.benchmark = True
    print('=> Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / (1024. * 1024)))

    if args.evaluation:
        print('=> Evaluation only')
        D = args.data.split('/')[-1]
        save_dir = os.path.join(args.checkpoint, D)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        loss, acc, predictions, auc = validate(val_loader, model, criterion, args.netType,
                                                        args.debug, args.flip)
        save_pred(predictions, checkpoint=save_dir)
        return

    train_loader = torch.utils.data.DataLoader(
        Loader(args, 'train'),
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma)
        print('=> Epoch: %d | LR %.8f' % (epoch + 1, lr))

        train_loss, train_acc = train(train_loader, model, criterion, optimizer, args.netType,
                                      args.debug, args.flip)
        # do not save predictions in model file
        valid_loss, valid_acc, predictions, valid_auc = validate(val_loader, model, criterion, args.netType,
                                                      args.debug, args.flip)

        logger.append([int(epoch + 1), lr, train_loss, valid_loss, train_acc, valid_acc, valid_auc])

        is_best = valid_auc >= best_auc
        best_auc = max(valid_auc, best_auc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'netType': args.netType,
                'state_dict': model.state_dict(),
                'best_acc': best_auc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            predictions,
            checkpoint=args.checkpoint)

    logger.close()
    logger.plot(['AUC'])
    savefig(os.path.join(args.checkpoint, 'log.eps'))
Example No. 15
def main():

    args = parser.parse_args()
    torch.manual_seed(args.seed)

    if not osp.exists(osp.join('checkpoints', 'Tracking')):
        os.makedirs(osp.join('checkpoints', 'Tracking'))
    vis = Dashboard(server='http://localhost', port=8097, env='Tracking')

    model = tracking_v1.TrackModel(pretrained=True).cuda()
    # checkpoint = torch.load('checkpoints/Tracking/latest.pth.tar')
    # model.load_state_dict(checkpoint['state_dict'])
    # model = tracking.TrackModel(pretrained=True)
    model = model.cuda()
    # grad = viz.create_viz('main', model, env = 'Tracking')
    # grad.regis_weight_ratio_plot('critic.fc2', 'weight', 'g/w')

    # feature_extractor network, the same learning rate
    # optimizer = torch.optim.Adam([{'params': model.fc.parameters()},
    optimizer = torch.optim.Adam([
        {'params': model.feature_extractor.parameters()},
        {'params': model.actor.parameters()},
        {'params': model.critic.parameters()},
        {'params': model.rnn.parameters()}],
        lr=args.lr)
    best_loss = 100
    train_loss1 = {}
    train_loss2 = {}
    train_loss = {}
    val_loss = {}
    val_loss1 = {}
    val_loss2 = {}
    train_reward = {}
    val_reward = {}

    model.critic.fc2.register_backward_hook(lambda module,
                                            grad_input,
                                            grad_output: grad_output)
    model.actor.fc1.weight.register_hook(hook_print)
    for epoch in range(args.num_epochs):

        reward_train, loss_train, loss1_train, loss2_train = train(args=args,
                                                                   model=model,
                                                                   optimizer=optimizer,
                                                                   video_train=video_train)
        reward_val, loss_val, loss1_val, loss2_val = test(args=args,
                                                          model=model,
                                                          video_val=video_val)

        adjust_learning_rate(optimizer=optimizer, lr=args.lr, epoch=epoch)

        train_loss[epoch] = loss_train[0, 0]
        train_loss1[epoch] = loss1_train[0, 0]
        train_loss2[epoch] = loss2_train[0, 0]
        train_reward[epoch] = reward_train
        val_loss[epoch] = loss_val[0, 0]
        val_loss1[epoch] = loss1_val[0, 0]
        val_loss2[epoch] = loss2_val[0, 0]
        val_reward[epoch] = reward_val

        # for visualization
        vis.draw(train_data=train_loss, val_data=val_loss, datatype='loss')
        vis.draw(train_data=train_loss1, val_data=val_loss1, datatype='Loss1')
        vis.draw(train_data=train_loss2, val_data=val_loss2, datatype='Loss2')
        vis.draw(train_data=train_reward, val_data=val_reward, datatype='rewards')

        print('Train', 'epoch:', epoch, 'rewards:{%.6f}' % reward_train,
              'loss:{%.6f}' % loss_train)
        print('validation', 'epoch:', epoch, 'rewards:{%.6f}' % reward_val,
              'loss:{%.6f}' % loss_val)

        if best_loss > loss_val[0, 0]:
            best_loss = loss_val[0, 0]
            is_best = True
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_loss1': best_loss,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'train_rewards': train_reward,
                'val_rewards': val_reward,
            }, is_best,
                filename='epoch_{}.pth.tar'.format(epoch + 1),
                dir=os.path.join('checkpoints', 'Tracking'), epoch=epoch)
Example No. 16
        pbar.update(BSIZE)
    pbar.close()
    return cum_loss * BSIZE / len(loader.dataset)

train = partial(data_pass, train=True)
test = partial(data_pass, train=False)

cur_best = None
for e in range(epochs):
    train(e)
    test_loss = test(e)
    scheduler.step(test_loss)
    earlystopping.step(test_loss)

    is_best = not cur_best or test_loss < cur_best
    if is_best:
        cur_best = test_loss
    checkpoint_fname = join(rnn_dir, 'checkpoint.tar')
    save_checkpoint({
        "state_dict": mdrnn.state_dict(),
        "optimizer": optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'earlystopping': earlystopping.state_dict(),
        "precision": test_loss,
        "epoch": e}, is_best, checkpoint_fname,
                    rnn_file)

    if earlystopping.stop:
        print("End of Training because of early stopping at epoch {}".format(e))
        break
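
# `save_checkpoint` is defined elsewhere; below is a minimal sketch of the
# variant this loop expects (positional state / is_best / filename /
# best-filename arguments): it saves the state dict to `filename` and, when the
# loss improved, also to the "best" file.
import torch  # already imported earlier in the full script


def save_checkpoint(state, is_best, filename, best_filename):
    """Save `state` to `filename`; duplicate it to `best_filename` if is_best."""
    torch.save(state, filename)
    if is_best:
        torch.save(state, best_filename)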
Example #17
def train(
    model_id,
    sequences_per_img=5,
    batch_size=10,
    resnet_conv_feature_size=2048,
    start_from=None,
    input_json_file_name=None,
    input_label_h5_file_name=None,
    label_smoothing=0,
    structure_loss_weight=1,
    train_sample_method="sample",
    train_beam_size=1,
    struc_use_logsoftmax=True,
    train_sample_n=5,
    structure_loss_type="seqnll",
    optimizer_type=NOAM,
    noamopt_factor=1,
    noamopt_warmup=20000,
    core_optimizer="sgd",
    learning_rate=0.0005,
    optimizer_alpha=0.9,
    optimizer_beta=0.999,
    optimizer_epsilon=1e-8,
    weight_decay=0,
    load_best_score=True,
    max_epochs=50,
    scheduled_sampling_start=-1,
    scheduled_sampling_increase_every=5,
    scheduled_sampling_increase_prob=0.05,
    scheduled_sampling_max_prob=0.25,
    self_critical_after=-1,
    structure_after=-1,
    cached_tokens="coco-train-idxs",
    grad_clip_value=0.1,
    grad_clip_mode=CLIP_VALUE,
    log_loss_iterations=25,
    save_every_epoch=True,
    save_checkpoint_iterations=3000,
    save_history_ckpt=True,
    eval_language_model=True,
):
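    """Training driver for the Transformer captioning model.

    Builds the data loader, model, loss wrapper and optimizer, optionally
    resumes from `start_from`, then runs the batch loop with periodic
    validation and checkpointing until `max_epochs` is reached or an
    exception/keyboard interrupt stops training.
    """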

    #
    # File names
    info_file_name = (
        join(start_from, "infos_" + model_id + ".pkl") if start_from is not None else ""
    )
    history_file_name = (
        join(start_from, "histories_" + model_id + ".pkl")
        if start_from is not None
        else ""
    )
    model_file_name = join(start_from, "model.pth") if start_from is not None else ""
    optimizer_file_name = (
        join(start_from, "optimizer.pth") if start_from is not None else ""
    )

    #
    # Load data
    loader = DataLoader(
        sequences_per_img,
        batch_size=batch_size,
        use_fc=True,
        use_att=True,
        use_box=0,
        norm_att_feat=0,
        norm_box_feat=0,
        input_json_file_name=input_json_file_name,
        input_label_h5_file_name=input_label_h5_file_name,
    )
    vocab_size = loader.vocab_size
    seq_length = loader.seq_length

    #
    # Initialize training info
    infos = {
        "iter": 0,
        "epoch": 0,
        "loader_state_dict": None,
        "vocab": loader.get_vocab(),
    }

    #
    # Load existing state training information, if there is any
    if start_from is not None and isfile(info_file_name):
        #
        with open(info_file_name, "rb") as f:
            # Resume previously saved training info (iteration, epoch, loader state)
            infos = pickle_load(f)

    #
    # Create data logger
    histories = defaultdict(dict)
    if start_from is not None and isfile(history_file_name):
        with open(history_file_name, "rb") as f:
            histories.update(pickle_load(f))

    # tensorboard logger
    tb_summary_writer = SummaryWriter(checkpoint_path)

    #
    # Create our model
    vocab = loader.get_vocab()
    model = Transformer(
        vocab_size, resnet_conv_feature_size=resnet_conv_feature_size
    ).cuda()

    #
    # Load pretrained weights:
    if start_from is not None and isfile(model_file_name):
        model.load_state_dict(torch_load(model_file_name))

    #
    # Wrap the generation model with the loss function (used for training).
    # This allows the loss to be computed separately on each DataParallel replica.
    lw_model = LossWrapper(
        model,
        label_smoothing=label_smoothing,
        structure_loss_weight=structure_loss_weight,
        train_sample_method=train_sample_method,
        train_beam_size=train_beam_size,
        struc_use_logsoftmax=struc_use_logsoftmax,
        train_sample_n=train_sample_n,
        structure_loss_type=structure_loss_type,
    )
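    # As used in the training loop below, the wrapper's forward returns a dict
    # containing at least "loss", plus "lm_loss"/"struc_loss" when the structure
    # loss is active and "reward" when self-critical training is active.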

    #
    # Wrap with dataparallel
    dp_model = DataParallel(model)
    dp_lw_model = DataParallel(lw_model)

    #
    #  Build optimizer
    if optimizer_type == NOAM:
        optimizer = get_std_opt(model, factor=noamopt_factor, warmup=noamopt_warmup)
    elif optimizer_type == REDUCE_LR:
        optimizer = build_optimizer(
            model.parameters(),
            core_optimizer=core_optimizer,
            learning_rate=learning_rate,
            optimizer_alpha=optimizer_alpha,
            optimizer_beta=optimizer_beta,
            optimizer_epsilon=optimizer_epsilon,
            weight_decay=weight_decay,
        )
        optimizer = ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        raise (
            Exception("Only supports NoamOpt and ReduceLROnPlateau optimization types")
        )
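
    # Hedged aside (not this repository's code): `get_std_opt` above wraps the
    # optimizer in the Noam schedule from "Attention Is All You Need": linear
    # warm-up followed by inverse-square-root decay. A sketch of that rate,
    # assuming a model width of 512:
    def _noam_rate_sketch(step, d_model=512, factor=noamopt_factor,
                          warmup=noamopt_warmup):
        step = max(step, 1)
        return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)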

    #
    # Load the optimizer
    if start_from is not None and isfile(optimizer_file_name):
        optimizer.load_state_dict(torch_load(optimizer_file_name))

    #
    # Prepare for training
    iteration = infos["iter"]
    epoch = infos["epoch"]
    #
    # For backward compatibility with older checkpoints
    if "iterators" in infos:
        infos["loader_state_dict"] = {
            split: {
                "index_list": infos["split_ix"][split],
                "iter_counter": infos["iterators"][split],
            }
            for split in ["train", "val", "test"]
        }
    loader.load_state_dict(infos["loader_state_dict"])
    best_val_score = infos.get("best_val_score", None) if load_best_score else None
    if optimizer_type == NOAM:
        optimizer._step = iteration
    #
    # Ensure the model is in training mode
    dp_lw_model.train()
    epoch_done = True

    #
    # Start training
    try:
        while True:
            #
            # Check max epochs
            if epoch >= max_epochs and max_epochs != -1:
                break

            #
            # Update end of epoch data
            if epoch_done:
                #
                # Assign the scheduled sampling prob
                if epoch > scheduled_sampling_start and scheduled_sampling_start >= 0:
                    frac = (
                        epoch - scheduled_sampling_start
                    ) // scheduled_sampling_increase_every
                    ss_prob = min(
                        scheduled_sampling_increase_prob * frac,
                        scheduled_sampling_max_prob,
                    )
                    model.ss_prob = ss_prob

                #
                # If start self critical training
                if self_critical_after != -1 and epoch >= self_critical_after:
                    sc_flag = True
                    init_scorer(cached_tokens)
                else:
                    sc_flag = False

                #
                # If start structure loss training
                if structure_after != -1 and epoch >= structure_after:
                    struc_flag = True
                    init_scorer(cached_tokens)
                else:
                    struc_flag = False

                #
                # End epoch update
                epoch_done = False
            #
            # Compute time to load data
            start = time.time()
            data = loader.get_batch("train")
            load_data_time = time.time() - start
            print(f"Time to load data: {load_data_time} seconds")

            ########################
            # SYNC
            ########################
            synchronize()

            #
            # Time the forward/backward pass for this batch
            start = time.time()

            #
            # Make sure data is in GPU memory
            tmp = [
                data["fc_feats"],
                data["att_feats"],
                data["labels"],
                data["masks"],
                data["att_masks"],
            ]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            #
            # Reset gradient
            optimizer.zero_grad()

            #
            # Evaluate model
            model_out = dp_lw_model(
                fc_feats,
                att_feats,
                labels,
                masks,
                att_masks,
                data["gts"],
                torch_arange(0, len(data["gts"])),
                sc_flag,
                struc_flag,
            )

            #
            # Average loss over training batch
            loss = model_out["loss"].mean()

            #
            # Compute gradient
            loss.backward()

            #
            # Clip gradient
            if grad_clip_value != 0:
                gradient_clipping_functions[grad_clip_mode](
                    model.parameters(), grad_clip_value
                )
            #
            # Update
            optimizer.step()
            train_loss = loss.item()
            end = time.time()

            ########################
            # SYNC
            ########################
            synchronize()

            #
            # Output status
            if struc_flag:
                print(
                    "iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}".format(
                        iteration,
                        epoch,
                        train_loss,
                        model_out["lm_loss"].mean().item(),
                        model_out["struc_loss"].mean().item(),
                        end - start,
                    )
                )
            elif not sc_flag:
                print(
                    "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}".format(
                        iteration, epoch, train_loss, end - start
                    )
                )
            else:
                print(
                    "iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}".format(
                        iteration, epoch, model_out["reward"].mean(), end - start
                    )
                )

            #
            # Update the iteration and epoch
            iteration += 1
            if data["bounds"]["wrapped"]:
                epoch += 1
                epoch_done = True

            #
            # Write the training loss summary
            if iteration % log_loss_iterations == 0:

                tb_summary_writer.add_scalar("train_loss", train_loss, iteration)

                if optimizer_type == NOAM:
                    current_lr = optimizer.rate()
                elif optimizer_type == REDUCE_LR:
                    current_lr = optimizer.current_lr

                tb_summary_writer.add_scalar("learning_rate", current_lr, iteration)
                tb_summary_writer.add_scalar(
                    "scheduled_sampling_prob", model.ss_prob, iteration
                )

                if sc_flag:
                    tb_summary_writer.add_scalar(
                        "avg_reward", model_out["reward"].mean(), iteration
                    )
                elif struc_flag:
                    tb_summary_writer.add_scalar(
                        "lm_loss", model_out["lm_loss"].mean().item(), iteration
                    )
                    tb_summary_writer.add_scalar(
                        "struc_loss", model_out["struc_loss"].mean().item(), iteration
                    )
                    tb_summary_writer.add_scalar(
                        "reward", model_out["reward"].mean().item(), iteration
                    )
                    tb_summary_writer.add_scalar(
                        "reward_var", model_out["reward"].var(1).mean(), iteration
                    )

                histories["loss_history"][iteration] = (
                    train_loss if not sc_flag else model_out["reward"].mean()
                )
                histories["lr_history"][iteration] = current_lr
                histories["ss_prob_history"][iteration] = model.ss_prob

            #
            # Update infos
            infos["iter"] = iteration
            infos["epoch"] = epoch
            infos["loader_state_dict"] = loader.state_dict()

            #
            # Evaluate on the validation set and save the model
            if (
                iteration % save_checkpoint_iterations == 0 and not save_every_epoch
            ) or (epoch_done and save_every_epoch):
                #
                # Evaluate model on Validation set of COCO
                eval_kwargs = {"split": "val", "dataset": input_json_file_name}
                val_loss, predictions, lang_stats = eval_split(
                    dp_model,
                    lw_model.crit,
                    loader,
                    verbose=True,
                    verbose_beam=False,
                    verbose_loss=True,
                    num_images=-1,
                    split="val",
                    lang_eval=False,
                    dataset="coco",
                    beam_size=1,
                    sample_n=1,
                    remove_bad_endings=False,
                    dump_path=False,
                    dump_images=False,
                    job_id="FUN_TIME",
                )

                #
                # Reduce the learning rate if the objective has not improved
                if optimizer_type == REDUCE_LR:
                    if "CIDEr" in lang_stats:
                        optimizer.scheduler_step(-lang_stats["CIDEr"])
                    else:
                        optimizer.scheduler_step(val_loss)

                #
                # Write validation result into summary
                tb_summary_writer.add_scalar("validation loss", val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        tb_summary_writer.add_scalar(k, v, iteration)

                histories["val_result_history"][iteration] = {
                    "loss": val_loss,
                    "lang_stats": lang_stats,
                    "predictions": predictions,
                }

                #
                # Save model if is improving on validation result
                if eval_language_model and lang_stats is not None:
                    current_score = lang_stats["CIDEr"]
                else:
                    current_score = -val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                #
                # Dump miscellaneous information
                infos["best_val_score"] = best_val_score

                #
                # Save checkpoints. Only the most recent checkpoint keeps the
                # histories, and it is overwritten each time.
                save_checkpoint(
                    model,
                    infos,
                    optimizer,
                    checkpoint_dir=checkpoint_path,
                    histories=histories,
                    append="RECENT",
                )
                if save_history_ckpt:
                    save_checkpoint(
                        model,
                        infos,
                        optimizer,
                        checkpoint_dir=checkpoint_path,
                        append=str(epoch) if save_every_epoch else str(iteration),
                    )
                if best_flag:
                    save_checkpoint(
                        model,
                        infos,
                        optimizer,
                        checkpoint_dir=checkpoint_path,
                        append="BEST",
                    )

    except (RuntimeError, KeyboardInterrupt):
        print(f'{BAR("=", 20)}Save checkpoint on exception...')
        save_checkpoint(
            model, infos, optimizer, checkpoint_dir=checkpoint_path, append="EXCEPTION"
        )
        print(f'...checkpoint saved.{BAR("=", 20)}')
        stack_trace = format_exc()
        print(stack_trace)
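
#
# Hedged appendix (an assumption, not code from this repository): the
# `gradient_clipping_functions` table used in the training loop above is
# presumably a module-level dispatch from clip-mode constants to the standard
# torch clipping utilities, along these lines (the CLIP_NORM constant and the
# string values are illustrative):
from torch.nn.utils import clip_grad_norm_, clip_grad_value_

CLIP_VALUE = "clip_value"
CLIP_NORM = "clip_norm"

gradient_clipping_functions = {
    CLIP_VALUE: clip_grad_value_,  # clamp each gradient element to [-v, v]
    CLIP_NORM: clip_grad_norm_,    # rescale so the total gradient norm <= v
}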