Example No. 1
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    all_data = torch.load(args.data_file)
    x_train, x_val, x_test = all_data
    y_size = 1
    y_train = torch.zeros(x_train.size(0), y_size)
    y_val = torch.zeros(x_val.size(0), y_size)
    y_test = torch.zeros(x_test.size(0), y_size)
    train = torch.utils.data.TensorDataset(x_train, y_train)
    val = torch.utils.data.TensorDataset(x_val, y_val)
    test = torch.utils.data.TensorDataset(x_test, y_test)

    train_loader = torch.utils.data.DataLoader(train,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val,
                                             batch_size=args.batch_size,
                                             shuffle=True)
    test_loader = torch.utils.data.DataLoader(test,
                                              batch_size=args.batch_size,
                                              shuffle=True)
    print('Train data: %d batches' % len(train_loader))
    print('Val data: %d batches' % len(val_loader))
    print('Test data: %d batches' % len(test_loader))
    if args.slurm == 0:
        cuda.set_device(args.gpu)
    if args.model == 'autoreg':
        args.latent_feature_map = 0
    if args.train_from == '':
        model = CNNVAE(img_size=args.img_size,
                       latent_dim=args.latent_dim,
                       enc_layers=args.enc_layers,
                       dec_kernel_size=args.dec_kernel_size,
                       dec_layers=args.dec_layers,
                       latent_feature_map=args.latent_feature_map)
    else:
        print('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']
    print("model architecture")
    print(model)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 betas=(0.9, 0.999))

    model.cuda()
    model.train()

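    # Inner-loop objective handed to OptimN2N below: reconstruction NLL plus the
    # beta-weighted KL term, i.e. the (annealed) negative ELBO for one batch.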
    def variational_loss(input, img, model, z=None):
        mean, logvar = input
        z_samples = model._reparameterize(mean, logvar, z)
        preds = model._dec_forward(img, z_samples)
        nll = utils.log_bernoulli_loss(preds, img)
        kl = utils.kl_loss_diag(mean, logvar)
        return nll + args.beta * kl

    update_params = list(model.dec.parameters())
    meta_optimizer = OptimN2N(variational_loss,
                              model,
                              update_params,
                              eps=args.eps,
                              lr=[args.svi_lr1, args.svi_lr2],
                              iters=args.svi_steps,
                              momentum=args.momentum,
                              acc_param_grads=args.train_n2n == 1,
                              max_grad_norm=args.svi_max_grad_norm)
    epoch = 0
    t = 0
    best_val_nll = 1e5
    best_epoch = 0
    loss_stats = []
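    # KL annealing: beta starts at 0.1 and is increased towards 1 over
    # args.warmup epochs via the min(...) update inside the training loop.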
    if args.warmup == 0:
        args.beta = 1.
    else:
        args.beta = 0.1

    if args.test == 1:
        args.beta = 1
        eval(test_loader, model, meta_optimizer)
        exit()

    while epoch < args.num_epochs:
        start_time = time.time()
        epoch += 1
        print('Starting epoch %d' % epoch)
        train_nll_vae = 0.
        train_nll_autoreg = 0.
        train_kl_vae = 0.
        train_nll_svi = 0.
        train_kl_svi = 0.
        num_examples = 0
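        # Each batch is handled by one of four modes: 'autoreg' (decoder only,
        # no latent), 'svi' (variational parameters optimized from scratch by
        # meta_optimizer), 'vae' (amortized inference), or 'savae'
        # (amortized init followed by SVI refinement).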
        for b, datum in enumerate(train_loader):
            if args.warmup > 0:
                args.beta = min(
                    1, args.beta + 1. / (args.warmup * len(train_loader)))
            img, _ = datum
            img = torch.bernoulli(img)
            batch_size = img.size(0)
            img = Variable(img.cuda())
            t += 1
            optimizer.zero_grad()
            if args.model == 'autoreg':
                preds = model._dec_forward(img, None)
                nll_autoreg = utils.log_bernoulli_loss(preds, img)
                train_nll_autoreg += nll_autoreg.data[0] * batch_size
                nll_autoreg.backward()
            elif args.model == 'svi':
                mean_svi = Variable(
                    0.1 * torch.zeros(batch_size, args.latent_dim).cuda(),
                    requires_grad=True)
                logvar_svi = Variable(
                    0.1 * torch.zeros(batch_size, args.latent_dim).cuda(),
                    requires_grad=True)
                var_params_svi = meta_optimizer.forward(
                    [mean_svi, logvar_svi], img, t % args.print_every == 0)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model._reparameterize(mean_svi_final.detach(),
                                                  logvar_svi_final.detach())
                preds = model._dec_forward(img, z_samples)
                nll_svi = utils.log_bernoulli_loss(preds, img)
                train_nll_svi += nll_svi.data[0] * batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                train_kl_svi += kl_svi.data[0] * batch_size
                var_loss = nll_svi + args.beta * kl_svi
                var_loss.backward()
            else:
                mean, logvar = model._enc_forward(img)
                z_samples = model._reparameterize(mean, logvar)
                preds = model._dec_forward(img, z_samples)
                nll_vae = utils.log_bernoulli_loss(preds, img)
                train_nll_vae += nll_vae.data[0] * batch_size
                kl_vae = utils.kl_loss_diag(mean, logvar)
                train_kl_vae += kl_vae.data[0] * batch_size
                if args.model == 'vae':
                    vae_loss = nll_vae + args.beta * kl_vae
                    vae_loss.backward(retain_graph=True)

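                # SA-VAE: start SVI from the encoder's (mean, logvar), refine with
                # meta_optimizer, train the decoder on the refined posterior, and
                # push gradients back into the encoder either through the unrolled
                # SVI steps (train_n2n == 1) or via a KL / VAE surrogate loss.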
                if args.model == 'savae':
                    var_params = torch.cat([mean, logvar], 1)
                    mean_svi = Variable(mean.data, requires_grad=True)
                    logvar_svi = Variable(logvar.data, requires_grad=True)

                    var_params_svi = meta_optimizer.forward(
                        [mean_svi, logvar_svi], img, t % args.print_every == 0)
                    mean_svi_final, logvar_svi_final = var_params_svi
                    z_samples = model._reparameterize(mean_svi_final,
                                                      logvar_svi_final)
                    preds = model._dec_forward(img, z_samples)
                    nll_svi = utils.log_bernoulli_loss(preds, img)
                    train_nll_svi += nll_svi.data[0] * batch_size
                    kl_svi = utils.kl_loss_diag(mean_svi_final,
                                                logvar_svi_final)
                    train_kl_svi += kl_svi.data[0] * batch_size
                    var_loss = nll_svi + args.beta * kl_svi
                    var_loss.backward(retain_graph=True)
                    if args.train_n2n == 0:
                        if args.train_kl == 1:
                            mean_final = mean_svi_final.detach()
                            logvar_final = logvar_svi_final.detach()
                            kl_init_final = utils.kl_loss(
                                mean, logvar, mean_final, logvar_final)
                            kl_init_final.backward(retain_graph=True)
                        else:
                            vae_loss = nll_vae + args.beta * kl_vae
                            var_param_grads = torch.autograd.grad(
                                vae_loss, [mean, logvar], retain_graph=True)
                            var_param_grads = torch.cat(var_param_grads, 1)
                            var_params.backward(var_param_grads,
                                                retain_graph=True)
                    else:
                        var_param_grads = meta_optimizer.backward(
                            [mean_svi_final.grad, logvar_svi_final.grad],
                            t % args.print_every == 0)
                        var_param_grads = torch.cat(var_param_grads, 1)
                        var_params.backward(var_param_grads)
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm(model.parameters(),
                                              args.max_grad_norm)
            optimizer.step()
            num_examples += batch_size
            if t % args.print_every == 0:
                param_norm = sum([p.norm()**2
                                  for p in model.parameters()]).data[0]**0.5
                print(
                    'Iters: %d, Epoch: %d, Batch: %d/%d, LR: %.4f, TrainARNLL: %.2f, TrainVAE_NLL: %.2f, TrainVAE_KL: %.4f, TrainVAE_NLLBnd: %.2f, TrainSVI_NLL: %.2f, TrainSVI_KL: %.4f, TrainSVI_NLLBnd: %.2f, |Param|: %.4f, BestValPerf: %.2f, BestEpoch: %d, Beta: %.3f, Throughput: %.2f examples/sec'
                    %
                    (t, epoch, b + 1, len(train_loader), args.lr,
                     train_nll_autoreg / num_examples,
                     train_nll_vae / num_examples, train_kl_vae / num_examples,
                     (train_nll_vae + train_kl_vae) / num_examples,
                     train_nll_svi / num_examples, train_kl_svi / num_examples,
                     (train_nll_svi + train_kl_svi) / num_examples, param_norm,
                     best_val_nll, best_epoch, args.beta, num_examples /
                     (time.time() - start_time)))
        print('--------------------------------')
        print('Checking validation perf...')
        val_nll = eval(val_loader, model, meta_optimizer)
        loss_stats.append(val_nll)
        if val_nll < best_val_nll:
            best_val_nll = val_nll
            best_epoch = epoch
            checkpoint = {
                'args': args.__dict__,
                'model': model,
                'optimizer': optimizer,
                'loss_stats': loss_stats
            }
            print('Saving checkpoint to %s' % args.checkpoint_path)
            torch.save(checkpoint, args.checkpoint_path)
Example No. 2
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_data = Dataset(args.train_file)
    val_data = Dataset(args.val_file)
    test_data = Dataset(args.test_file)
    train_sents = train_data.batch_size.sum()
    vocab_size = int(train_data.vocab_size)
    logger.info('Train data: %d batches' % len(train_data))
    logger.info('Val data: %d batches' % len(val_data))
    logger.info('Test data: %d batches' % len(test_data))
    logger.info('Word vocab size: %d' % vocab_size)

    checkpoint_dir = args.checkpoint_dir
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    suffix = "%s_%s.pt" % (args.model, 'cyc')
    checkpoint_path = os.path.join(checkpoint_dir, suffix)

    if args.slurm == 0:
        cuda.set_device(args.gpu)
    if args.train_from == '':
        model = RNNVAE(vocab_size=vocab_size,
                       enc_word_dim=args.enc_word_dim,
                       enc_h_dim=args.enc_h_dim,
                       enc_num_layers=args.enc_num_layers,
                       dec_word_dim=args.dec_word_dim,
                       dec_h_dim=args.dec_h_dim,
                       dec_num_layers=args.dec_num_layers,
                       dec_dropout=args.dec_dropout,
                       latent_dim=args.latent_dim,
                       mode=args.model)
        for param in model.parameters():
            param.data.uniform_(-0.1, 0.1)
    else:
        logger.info('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']

    logger.info("model architecture")
    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    if args.warmup == 0:
        args.beta = 1.
    else:
        args.beta = 0.1

    criterion = nn.NLLLoss()
    model.cuda()
    criterion.cuda()
    model.train()

    def variational_loss(input, sents, model, z=None):
        mean, logvar = input
        z_samples = model._reparameterize(mean, logvar, z)
        preds = model._dec_forward(sents, z_samples)
        nll = sum([
            criterion(preds[:, l], sents[:, l + 1])
            for l in range(preds.size(1))
        ])
        kl = utils.kl_loss_diag(mean, logvar)
        return nll + args.beta * kl

    update_params = list(model.dec.parameters())
    meta_optimizer = OptimN2N(variational_loss,
                              model,
                              update_params,
                              eps=args.eps,
                              lr=[args.svi_lr1, args.svi_lr2],
                              iters=args.svi_steps,
                              momentum=args.momentum,
                              acc_param_grads=args.train_n2n == 1,
                              max_grad_norm=args.svi_max_grad_norm)
    if args.test == 1:
        args.beta = 1
        test_data = Dataset(args.test_file)
        eval(test_data, model, meta_optimizer)
        exit()

    t = 0
    best_val_nll = 1e5
    best_epoch = 0
    val_stats = []
    epoch = 0
    while epoch < args.num_epochs:
        start_time = time.time()
        epoch += 1
        logger.info('Starting epoch %d' % epoch)
        train_nll_vae = 0.
        train_nll_autoreg = 0.
        train_kl_vae = 0.
        train_nll_svi = 0.
        train_kl_svi = 0.
        train_kl_init_final = 0.
        num_sents = 0
        num_words = 0
        b = 0

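        # Cyclical schedule: cosine-decay the learning rate within each cycle of
        # args.cycle epochs and restart KL annealing (beta back to 0.1) at the
        # start of every cycle.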
        tmp = float((epoch - 1) % args.cycle) / args.cycle
        cur_lr = args.lr * 0.5 * (1 + np.cos(tmp * np.pi))
        for param_group in optimizer.param_groups:
            param_group['lr'] = cur_lr

        if (epoch - 1) % args.cycle == 0:
            args.beta = 0.1
            logger.info('KL annealing restart')

        for i in np.random.permutation(len(train_data)):
            if args.warmup > 0:
                args.beta = min(
                    1, args.beta + 1. / (args.warmup * len(train_data)))

            sents, length, batch_size = train_data[i]
            if args.gpu >= 0:
                sents = sents.cuda()
            b += 1

            optimizer.zero_grad()
            if args.model == 'autoreg':
                preds = model._dec_forward(sents, None, True)
                nll_autoreg = sum([
                    criterion(preds[:, l], sents[:, l + 1])
                    for l in range(length)
                ])
                train_nll_autoreg += nll_autoreg.data[0] * batch_size
                nll_autoreg.backward()
            elif args.model == 'svi':
                mean_svi = Variable(
                    0.1 * torch.zeros(batch_size, args.latent_dim).cuda(),
                    requires_grad=True)
                logvar_svi = Variable(
                    0.1 * torch.zeros(batch_size, args.latent_dim).cuda(),
                    requires_grad=True)
                var_params_svi = meta_optimizer.forward(
                    [mean_svi, logvar_svi], sents, b % args.print_every == 0)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model._reparameterize(mean_svi_final.detach(),
                                                  logvar_svi_final.detach())
                preds = model._dec_forward(sents, z_samples)
                nll_svi = sum([
                    criterion(preds[:, l], sents[:, l + 1])
                    for l in range(length)
                ])
                train_nll_svi += nll_svi.data[0] * batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                train_kl_svi += kl_svi.data[0] * batch_size
                var_loss = nll_svi + args.beta * kl_svi
                var_loss.backward(retain_graph=True)
            else:
                mean, logvar = model._enc_forward(sents)
                z_samples = model._reparameterize(mean, logvar)
                preds = model._dec_forward(sents, z_samples)
                nll_vae = sum([
                    criterion(preds[:, l], sents[:, l + 1])
                    for l in range(length)
                ])
                train_nll_vae += nll_vae.data[0] * batch_size
                kl_vae = utils.kl_loss_diag(mean, logvar)
                train_kl_vae += kl_vae.data[0] * batch_size
                if args.model == 'vae':
                    vae_loss = nll_vae + args.beta * kl_vae
                    vae_loss.backward(retain_graph=True)
                if args.model == 'savae':
                    var_params = torch.cat([mean, logvar], 1)
                    mean_svi = Variable(mean.data, requires_grad=True)
                    logvar_svi = Variable(logvar.data, requires_grad=True)
                    var_params_svi = meta_optimizer.forward(
                        [mean_svi, logvar_svi], sents,
                        b % args.print_every == 0)
                    mean_svi_final, logvar_svi_final = var_params_svi
                    z_samples = model._reparameterize(mean_svi_final,
                                                      logvar_svi_final)
                    preds = model._dec_forward(sents, z_samples)
                    nll_svi = sum([
                        criterion(preds[:, l], sents[:, l + 1])
                        for l in range(length)
                    ])
                    train_nll_svi += nll_svi.data[0] * batch_size
                    kl_svi = utils.kl_loss_diag(mean_svi_final,
                                                logvar_svi_final)
                    train_kl_svi += kl_svi.data[0] * batch_size
                    var_loss = nll_svi + args.beta * kl_svi
                    var_loss.backward(retain_graph=True)
                    if args.train_n2n == 0:
                        if args.train_kl == 1:
                            mean_final = mean_svi_final.detach()
                            logvar_final = logvar_svi_final.detach()
                            kl_init_final = utils.kl_loss(
                                mean, logvar, mean_final, logvar_final)
                            train_kl_init_final += kl_init_final.data[
                                0] * batch_size
                            kl_init_final.backward(retain_graph=True)
                        else:
                            vae_loss = nll_vae + args.beta * kl_vae
                            var_param_grads = torch.autograd.grad(
                                vae_loss, [mean, logvar], retain_graph=True)
                            var_param_grads = torch.cat(var_param_grads, 1)
                            var_params.backward(var_param_grads,
                                                retain_graph=True)
                    else:
                        var_param_grads = meta_optimizer.backward(
                            [mean_svi_final.grad, logvar_svi_final.grad],
                            b % args.print_every == 0)
                        var_param_grads = torch.cat(var_param_grads, 1)
                        var_params.backward(var_param_grads)
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm(model.parameters(),
                                              args.max_grad_norm)
            optimizer.step()
            num_sents += batch_size
            num_words += batch_size * length

            if b % args.print_every == 0:
                param_norm = sum([p.norm()**2
                                  for p in model.parameters()]).data[0]**0.5
                logger.info(
                    'Iters: %d, Epoch: %d, Batch: %d/%d, LR: %.4f, TrainARNLL: %.4f, TrainARPPL: %.2f, TrainVAE_NLL: %.4f, TrainVAE_REC: %.4f, TrainVAE_KL: %.4f, TrainVAE_PPL: %.2f, TrainSVI_NLL: %.2f, TrainSVI_REC: %.2f, TrainSVI_KL: %.4f, TrainSVI_PPL: %.2f, KLInitFinal: %.2f, |Param|: %.4f, BestValPerf: %.2f, BestEpoch: %d, Beta: %.4f, Throughput: %.2f examples/sec'
                    % (t, epoch, b + 1, len(train_data), cur_lr,
                       train_nll_autoreg / num_sents,
                       np.exp(train_nll_autoreg / num_words),
                       (train_nll_vae + train_kl_vae) / num_sents,
                       train_nll_vae / num_sents, train_kl_vae / num_sents,
                       np.exp((train_nll_vae + train_kl_vae) / num_words),
                       (train_nll_svi + train_kl_svi) / num_sents,
                       train_nll_svi / num_sents, train_kl_svi / num_sents,
                       np.exp((train_nll_svi + train_kl_svi) / num_words),
                       train_kl_init_final / num_sents, param_norm,
                       best_val_nll, best_epoch, args.beta, num_sents /
                       (time.time() - start_time)))

        epoch_train_time = time.time() - start_time
        logger.info('Time Elapsed: %.1fs' % epoch_train_time)

        logger.info('--------------------------------')
        logger.info('Checking validation perf...')
        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('Mode', 'Val')
        logger.record_tabular('LR', cur_lr)
        logger.record_tabular('Epoch Train Time', epoch_train_time)
        val_nll = eval(val_data, model, meta_optimizer)
        val_stats.append(val_nll)

        logger.info('--------------------------------')
        logger.info('Checking test perf...')
        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('Mode', 'Test')
        logger.record_tabular('LR', cur_lr)
        logger.record_tabular('Epoch Train Time', epoch_train_time)
        test_nll = eval(test_data, model, meta_optimizer)

        if val_nll < best_val_nll:
            best_val_nll = val_nll
            best_epoch = epoch
            model.cpu()
            checkpoint = {
                'args': args.__dict__,
                'model': model,
                'val_stats': val_stats
            }
            logger.info('Save checkpoint to %s' % checkpoint_path)
            torch.save(checkpoint, checkpoint_path)
            model.cuda()
        else:
            if epoch >= args.min_epochs:
                args.decay = 1
Example No. 3
def train(config):
    base_network = network.ResNetFc('ResNet50', use_bottleneck=True, bottleneck_dim=config["bottleneck_dim"], new_cls=True, class_num=config["class_num"])
    ad_net = network.AdversarialNetwork(config["bottleneck_dim"], config["hidden_dim"])

    base_network = base_network.cuda()
    ad_net = ad_net.cuda()

    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    source_path = ImageList(open(config["s_path"]).readlines(), transform=preprocess.image_train(resize_size=256, crop_size=224))
    target_path = ImageList(open(config["t_path"]).readlines(), transform=preprocess.image_train(resize_size=256, crop_size=224))
    test_path   = ImageList(open(config["t_path"]).readlines(), transform=preprocess.image_test(resize_size=256, crop_size=224))

    source_loader = DataLoader(source_path, batch_size=config["train_bs"], shuffle=True, num_workers=0, drop_last=True)
    target_loader = DataLoader(target_path, batch_size=config["train_bs"], shuffle=True, num_workers=0, drop_last=True)
    test_loader   = DataLoader(test_path, batch_size=config["test_bs"], shuffle=True, num_workers=0, drop_last=True)

    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config["gpus"].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
        base_network = nn.DataParallel(base_network, device_ids=[int(i) for i in gpus])


    len_train_source = len(source_loader)
    len_train_target = len(target_loader)

    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    best_model_path = None

    for i in trange(config["iterations"], leave=False):
        if i % config["test_interval"] == config["test_interval"] - 1:
            base_network.train(False)
            temp_acc = image_classification_test(test_loader, base_network)
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = copy.deepcopy(temp_model)
                best_iter = i
                if best_model_path and osp.exists(best_model_path):
                    try:
                        os.remove(best_model_path)
                    except:
                        pass
                best_model_path = osp.join(config["output_path"], "iter_{:05d}.pth.tar".format(best_iter))
                torch.save(best_model, best_model_path)
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str+"\n")
            config["out_file"].flush()
            # print("cut_loss: ", cut_loss.item())
            print("mix_loss: ", mix_loss.item())
            print(log_str)

        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(source_loader)
        if i % len_train_target == 0:
            iter_target = iter(target_loader)

        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source, inputs_target, labels_source = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()
        labels_src_one_hot = torch.nn.functional.one_hot(labels_source, config["class_num"]).float()

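        # Mixup between source and target images produces interpolated inputs and
        # soft labels (presumably combining the source one-hot labels with the
        # network's target predictions); the mixed outputs are later matched to
        # these labels with a KL consistency loss (mix_loss).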
        # inputs_cut, labels_cut = cutmix(base_network, inputs_source, labels_src_one_hot, inputs_target, config["alpha"], config["class_num"])
        inputs_mix, labels_mix = mixup(base_network, inputs_source, labels_src_one_hot, inputs_target, config["alpha"], config["class_num"], config["temperature"])

        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)
        # features_cut,    outputs_cut    = base_network(inputs_cut)
        features_mix,    outputs_mix    = base_network(inputs_mix)

        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)

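        # DANN transfer loss: the adversarial network ad_net tries to distinguish
        # source features from target features, while mix_loss acts as an
        # additional consistency regularizer on the mixed batch.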
        if config["method"] == 'DANN':
            transfer_loss = loss.DANN(features, ad_net)
            # cut_loss = utils.kl_loss(outputs_cut, labels_cut.detach())
            mix_loss = utils.kl_loss(outputs_mix, labels_mix.detach())
        else:
            raise ValueError('Method cannot be recognized.')

        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        total_loss = transfer_loss + classifier_loss + (5*mix_loss)
        total_loss.backward()
        optimizer.step()
    torch.save(best_model, osp.join(config["output_path"], "best_model.pth.tar"))
    print("Training Finished! Best Accuracy: ", best_acc)
    return best_acc
Example No. 4
SAVE_PATH = "trained_models/VAE_jet_L12_BE2.dat"

N_EPOCHS = 10000
optimizer = optim.Adam(net.parameters())
rms_loss = []
kldiv_loss = []

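# Standard beta-VAE style objective: reconstruction error ("rms") plus a
# beta-weighted KL divergence computed from the encoder's mu / log_sigma.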
for epoch in range(N_EPOCHS):
    epoch_rms_loss = []
    epoch_kldiv_loss = []
    for minibatch in dataloader:
        inputs, outputs = minibatch
        optimizer.zero_grad()
        pred = net(inputs)
        kl = beta * kl_loss(net.mu, net.log_sigma)
        rms = target_loss(pred, outputs)
        loss = rms + kl
        loss.backward()
        optimizer.step()

        epoch_rms_loss.append(rms.item())
        epoch_kldiv_loss.append(kl.item())

    kldiv_loss.append(np.mean(epoch_kldiv_loss))
    rms_loss.append(np.mean(epoch_rms_loss))
    print("Epoch %d -- rms error %f -- kl loss %f" %
          (epoch + 1, rms_loss[-1], kldiv_loss[-1]))

torch.save(net.state_dict(), SAVE_PATH)
print("Model saved to %s" % SAVE_PATH)
Example No. 5
def train(batch_size,
          epochs,
          model,
          dataset,
          valid_size=5000,
          label_smoothing=0.0,
          gpu='cuda:0'):

    device = torch.device(gpu if torch.cuda.is_available() else 'cpu')

    assert model in ['resnet18', 'resnet101', 'densenet121', 'densenet169']
    assert dataset in ['cifar10', 'cifar100']

    print("batch_size =", batch_size)
    print("epochs =", epochs)
    print("model =", model)
    print("data set =", dataset)
    print("label_smoothing =", label_smoothing)

    if dataset == 'cifar100':
        num_classes = 100
        mean = [0.5071, 0.4867, 0.4408]
        std = [0.2675, 0.2565, 0.2761]

        train_set = datasets.CIFAR100(
            '../data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
                transforms.RandomHorizontalFlip(),
                # transforms.RandomRotation(15),
                transforms.ToTensor(),
                transforms.RandomErasing(p=0.5),
                transforms.Normalize(mean=mean, std=std)
            ]))
        valid_set = datasets.CIFAR100('../data',
                                      train=True,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize(mean=mean,
                                                               std=std)
                                      ]))
        train_indices = torch.load('./train_indices_cifar100.pth')
        valid_indices = torch.load('./valid_indices_cifar100.pth')
    elif dataset == 'cifar10':  # cifar10
        num_classes = 10
        mean = [0.4914, 0.48216, 0.44653]
        std = [0.2470, 0.2435, 0.26159]

        train_set = datasets.CIFAR10(
            '../data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
                transforms.RandomHorizontalFlip(),
                # transforms.RandomRotation(15),
                transforms.ToTensor(),
                transforms.RandomErasing(p=0.5),
                transforms.Normalize(mean=mean, std=std)
            ]))
        valid_set = datasets.CIFAR10('../data',
                                     train=True,
                                     transform=transforms.Compose([
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=mean,
                                                              std=std)
                                     ]))
        train_indices = torch.load('./train_indices_cifar10.pth')
        valid_indices = torch.load('./valid_indices_cifar10.pth')

    # indices = torch.randperm(len(train_set))
    # train_indices = indices[:len(indices) - valid_size]
    # valid_indices = indices[len(indices) - valid_size:]
    # torch.save(train_indices, './train_indices_' + dataset + '.pth')
    # torch.save(valid_indices, './valid_indices_' + dataset + '.pth')

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        sampler=SubsetRandomSampler(train_indices))
    valid_loader = torch.utils.data.DataLoader(
        valid_set,
        batch_size=batch_size,
        sampler=SubsetRandomSampler(valid_indices))

    net = BayesianNet(num_classes=num_classes, model=model).to(device)
    net.apply(xavier_normal_init)

    # optimizer_net = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-6)
    optimizer_net = optim.AdamW(net.parameters(), lr=0.01)
    lr_scheduler_net = optim.lr_scheduler.ReduceLROnPlateau(optimizer_net,
                                                            patience=10,
                                                            factor=0.1)

    train_losses = []
    train_accuracies = []
    valid_losses = []
    valid_accuracies = []

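    # Per epoch: train with cross-entropy plus a label_smoothing-weighted
    # kl_loss(logits) regularizer (its exact definition lives elsewhere),
    # step ReduceLROnPlateau on the training loss, evaluate on the held-out
    # split, and checkpoint whenever the validation loss hits a new minimum.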
    for e in range(epochs):
        net.train()

        epoch_train_loss = []
        epoch_train_acc = []
        is_best = False

        print("lr =", optimizer_net.param_groups[0]['lr'])
        for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
            data, target = data.to(device), target.to(device)
            optimizer_net.zero_grad()
            logits = net(data)
            xent = F.cross_entropy(logits, target)
            kll = kl_loss(logits)
            loss = xent + label_smoothing * kll
            loss.backward()
            epoch_train_loss.append(loss.item())
            epoch_train_acc.append(accuracy(logits, target))
            optimizer_net.step()

        epoch_train_loss = np.mean(epoch_train_loss)
        epoch_train_acc = np.mean(epoch_train_acc)
        lr_scheduler_net.step(epoch_train_loss)

        net.eval()
        epoch_valid_loss = []
        epoch_valid_acc = []

        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(tqdm(valid_loader)):
                data, target = data.to(device), target.to(device)
                logits = net(data)
                loss = F.cross_entropy(logits, target)
                epoch_valid_loss.append(loss.item())
                epoch_valid_acc.append(accuracy(logits, target))

        epoch_valid_loss = np.mean(epoch_valid_loss)
        epoch_valid_acc = np.mean(epoch_valid_acc)

        print(
            "Epoch {:d}: loss: {:4f}, acc: {:4f}, val_loss: {:4f}, val_acc: {:4f}"
            .format(
                e,
                epoch_train_loss,
                epoch_train_acc,
                epoch_valid_loss,
                epoch_valid_acc,
            ))

        # save epoch losses
        train_losses.append(epoch_train_loss)
        train_accuracies.append(epoch_train_acc)
        valid_losses.append(epoch_valid_loss)
        valid_accuracies.append(epoch_valid_acc)

        if valid_losses[-1] <= np.min(valid_losses):
            is_best = True

        if is_best:
            filename = f"../snapshots/{model}_best.pth.tar"
            print("Saving best weights so far with val_loss: {:4f}".format(
                valid_losses[-1]))
            torch.save(
                {
                    'epoch': e,
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer_net.state_dict(),
                    'train_losses': train_losses,
                    'train_accs': train_accuracies,
                    'val_losses': valid_losses,
                    'val_accs': valid_accuracies,
                }, filename)

        if e == epochs - 1:
            filename = f"../snapshots/{model}_{e}.pth.tar"
            print("Saving weights at epoch {:d}".format(e))
            torch.save(
                {
                    'epoch': e,
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer_net.state_dict(),
                    'train_losses': train_losses,
                    'train_accs': train_accuracies,
                    'val_losses': valid_losses,
                    'val_accs': valid_accuracies,
                }, filename)
Example No. 6
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_data = Dataset(args.train_file)
    val_data = Dataset(args.val_file)
    train_sents = train_data.batch_size.sum()
    vocab_size = int(train_data.vocab_size)
    print('Train data: %d batches' % len(train_data))
    print('Val data: %d batches' % len(val_data))
    print('Word vocab size: %d' % vocab_size)
    if args.slurm == 0:
        # cuda.set_device(args.gpu)
        gpu_id = 0
        device = torch.device(
            f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
    if args.train_from == '':
        model = RNNVAE(vocab_size=vocab_size,
                       enc_word_dim=args.enc_word_dim,
                       enc_h_dim=args.enc_h_dim,
                       enc_num_layers=args.enc_num_layers,
                       dec_word_dim=args.dec_word_dim,
                       dec_h_dim=args.dec_h_dim,
                       dec_num_layers=args.dec_num_layers,
                       dec_dropout=args.dec_dropout,
                       latent_dim=args.latent_dim,
                       mode=args.model)
        for param in model.parameters():
            param.data.uniform_(-0.1, 0.1)
    else:
        print('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']

    print("model architecture")
    print(model)

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

    if args.warmup == 0:
        args.beta = 1.
    else:
        args.beta = args.kl_start

    criterion = nn.NLLLoss(reduction='none')
    # criterion = nn.NLLLoss()
    # model.cuda()
    # criterion.cuda()
    # model = torch.nn.DataParallel(net, device_ids=[0, 1])
    model.to(device)
    criterion.to(device)
    model.train()

    def variational_loss(input, sents, model, z=None):
        mean, logvar = input
        z_samples = model._reparameterize(mean, logvar, z)
        preds = model._dec_forward(sents, z_samples)
        nll = sum([
            criterion(preds[:, l], sents[:, l + 1])
            for l in range(preds.size(1))
        ])
        kl = utils.kl_loss_diag(mean, logvar)
        return nll + args.beta * kl

    update_params = list(model.dec.parameters())
    meta_optimizer = OptimN2N(variational_loss,
                              model,
                              update_params,
                              eps=args.eps,
                              lr=[args.svi_lr1, args.svi_lr2],
                              iters=args.svi_steps,
                              momentum=args.momentum,
                              acc_param_grads=args.train_n2n == 1,
                              max_grad_norm=args.svi_max_grad_norm)
    if args.test == 1:
        args.beta = 1
        test_data = Dataset(args.test_file)
        eval(args, test_data, model, meta_optimizer, device)
        exit()

    t = 0
    best_val_nll = 1e5
    best_epoch = 0
    not_improved = 0
    val_stats = []
    epoch = 0
    while epoch < args.num_epochs:
        start_time = time.time()
        epoch += 1
        print('Starting epoch %d' % epoch)
        train_nll_vae = 0.
        train_nll_autoreg = 0.
        train_kl_vae = 0.
        train_nll_svi = 0.
        train_kl_svi = 0.
        train_kl_init_final = 0.
        num_sents = 0
        num_words = 0
        b = 0

        for i in np.random.permutation(len(train_data)):
            if args.warmup > 0:
                args.beta = min(
                    1, args.beta + 1. / (args.warmup * len(train_data)))

            sents, length, batch_size = train_data[i]
            length = length.item()
            batch_size = batch_size.item()

            if args.gpu >= 0:
                # sents = sents.cuda()
                sents = sents.to(device)
                # batch_size = batch_size.to(device)
            b += 1

            optimizer.zero_grad()
            if args.model == 'autoreg':
                preds = model._dec_forward(sents, None, True)
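                # Vectorized sequence NLL (here and in the branches below):
                # flatten predictions to (batch*len, vocab), score them against
                # the sentence shifted by one token, then sum over tokens per
                # sentence and average over the batch.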
                tgt = sents[:, 1:].contiguous()
                nll_autoreg = criterion(preds.view(-1, preds.size(2)),
                                        tgt.view(-1)).view(preds.size(0),
                                                           -1).sum(-1).mean(0)
                # nll_autoreg = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
                train_nll_autoreg += nll_autoreg.item() * batch_size
                # train_nll_autoreg += nll_autoreg.data[0]*batch_size #old
                nll_autoreg.backward()
            elif args.model == 'svi':
                # mean_svi = Variable(0.1*torch.zeros(batch_size, args.latent_dim).cuda(), requires_grad = True)
                # logvar_svi = Variable(0.1*torch.zeros(batch_size, args.latent_dim).cuda(), requires_grad = True)
                mean_svi = Variable(
                    0.1 * torch.zeros(batch_size, args.latent_dim).to(device),
                    requires_grad=True)
                logvar_svi = Variable(
                    0.1 * torch.zeros(batch_size, args.latent_dim).to(device),
                    requires_grad=True)
                var_params_svi = meta_optimizer.forward(
                    [mean_svi, logvar_svi], sents, b % args.print_every == 0)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model._reparameterize(mean_svi_final.detach(),
                                                  logvar_svi_final.detach())
                preds = model._dec_forward(sents, z_samples)
                tgt = sents[:, 1:].contiguous()
                nll_svi = criterion(preds.view(-1, preds.size(2)),
                                    tgt.view(-1)).view(preds.size(0),
                                                       -1).sum(-1).mean(0)
                # nll_svi = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
                train_nll_svi += nll_svi.item() * batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                train_kl_svi += kl_svi.item() * batch_size
                var_loss = nll_svi + args.beta * kl_svi
                var_loss.backward(retain_graph=True)
            else:
                mean, logvar = model._enc_forward(sents)
                z_samples = model._reparameterize(mean, logvar)
                preds = model._dec_forward(sents, z_samples)
                tgt = sents[:, 1:].contiguous()
                nll_vae = criterion(preds.view(-1, preds.size(2)),
                                    tgt.view(-1)).view(preds.size(0),
                                                       -1).sum(-1).mean(0)
                # nll_vae = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
                # train_nll_vae += nll_vae.data[0]*batch_size#old
                train_nll_vae += nll_vae.item() * batch_size
                kl_vae = utils.kl_loss_diag(mean, logvar)
                # train_kl_vae += kl_vae.data[0]*batch_size#old
                train_kl_vae += kl_vae.item() * batch_size
                if args.model == 'vae':
                    vae_loss = nll_vae + args.beta * kl_vae
                    vae_loss.backward(retain_graph=True)
                if args.model == 'savae':
                    var_params = torch.cat([mean, logvar], 1)
                    mean_svi = Variable(mean.data, requires_grad=True)
                    logvar_svi = Variable(logvar.data, requires_grad=True)
                    var_params_svi = meta_optimizer.forward(
                        [mean_svi, logvar_svi], sents,
                        b % args.print_every == 0)
                    mean_svi_final, logvar_svi_final = var_params_svi
                    z_samples = model._reparameterize(mean_svi_final,
                                                      logvar_svi_final)
                    preds = model._dec_forward(sents, z_samples)
                    tgt = sents[:, 1:].contiguous()
                    nll_svi = criterion(preds.view(-1, preds.size(2)),
                                        tgt.view(-1)).view(preds.size(0),
                                                           -1).sum(-1).mean(0)
                    # nll_svi = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
                    train_nll_svi += nll_svi.item() * batch_size
                    kl_svi = utils.kl_loss_diag(mean_svi_final,
                                                logvar_svi_final)
                    train_kl_svi += kl_svi.item() * batch_size
                    var_loss = nll_svi + args.beta * kl_svi
                    var_loss.backward(retain_graph=True)
                    if args.train_n2n == 0:
                        if args.train_kl == 1:
                            mean_final = mean_svi_final.detach()
                            logvar_final = logvar_svi_final.detach()
                            kl_init_final = utils.kl_loss(
                                mean, logvar, mean_final, logvar_final)
                            train_kl_init_final += kl_init_final.item() * batch_size
                            kl_init_final.backward(retain_graph=True)
                        else:
                            vae_loss = nll_vae + args.beta * kl_vae
                            var_param_grads = torch.autograd.grad(
                                vae_loss, [mean, logvar], retain_graph=True)
                            var_param_grads = torch.cat(var_param_grads, 1)
                            var_params.backward(var_param_grads,
                                                retain_graph=True)
                    else:
                        var_param_grads = meta_optimizer.backward(
                            [mean_svi_final.grad, logvar_svi_final.grad],
                            b % args.print_every == 0)
                        var_param_grads = torch.cat(var_param_grads, 1)
                        var_params.backward(var_param_grads)
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
            optimizer.step()
            num_sents += batch_size
            num_words += batch_size * length
            # num_sents = num_sents.item()
            # num_words = num_words.item()
            if b % args.print_every == 0:
                param_norm = sum([p.norm()**2
                                  for p in model.parameters()]).item()**0.5
                print(
                    'Iters: %d, Epoch: %d, Batch: %d/%d, LR: %.4f, TrainARPPL: %.2f, TrainVAE_PPL: %.2f, TrainVAE_KL: %.4f, TrainVAE_PPLBnd: %.2f, TrainSVI_PPL: %.2f, TrainSVI_KL: %.4f, TrainSVI_PPLBnd: %.2f, KLInitFinal: %.2f, |Param|: %.4f, BestValPerf: %.2f, BestEpoch: %d, Beta: %.4f, Throughput: %.2f examples/sec'
                    %
                    (t, epoch, b + 1, len(train_data), args.lr,
                     np.exp(train_nll_autoreg / num_words),
                     np.exp(
                         train_nll_vae / num_words), train_kl_vae / num_sents,
                     np.exp((train_nll_vae + train_kl_vae) / num_words),
                     np.exp(
                         train_nll_svi / num_words), train_kl_svi / num_sents,
                     np.exp((train_nll_svi + train_kl_svi) / num_words),
                     train_kl_init_final / num_sents, param_norm, best_val_nll,
                     best_epoch, args.beta, num_sents /
                     (time.time() - start_time)))

        print('--------------------------------')
        print('Checking validation perf...')
        val_nll = eval(args, val_data, model, meta_optimizer, device)
        val_stats.append(val_nll)

        # if val_elbo > self.best_val_elbo:
        #     self.not_improved = 0
        #     self.best_val_elbo = val_elbo
        # else:
        #     self.not_improved += 1
        #     if self.not_improved % 5 == 0:
        #         self.current_lr = self.current_lr * self.config.options.lr_decay
        #         print(f'New LR {self.current_lr}')
        #         model.optimizer = torch.optim.SGD(model.parameters(), lr=self.current_lr)
        #         model.enc_optimizer = torch.optim.SGD(model.parameters(), lr=self.current_lr)
        #         model.dec_optimizer = torch.optim.SGD(model.parameters(), lr=self.current_lr)

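        # Keep only the best checkpoint on disk (the previous best file is
        # removed), and decay the learning rate by args.lr_decay after every
        # 5 consecutive epochs without a validation improvement.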
        if val_nll < best_val_nll:
            not_improved = 0
            best_save = '{}_{}.pt'.format(args.checkpoint_path, best_val_nll)
            if os.path.exists(best_save):
                os.remove(best_save)

            best_val_nll = val_nll
            best_epoch = epoch
            model.cpu()
            checkpoint = {
                'args': args.__dict__,
                'model': model,
                'val_stats': val_stats
            }
            best_save = '{}_{}.pt'.format(args.checkpoint_path, best_val_nll)
            print('Saving checkpoint to %s' % best_save)
            torch.save(checkpoint, best_save)

            # model.cuda()
            model.to(device)
        else:
            not_improved += 1
            if not_improved % 5 == 0:
                not_improved = 0
                args.lr = args.lr * args.lr_decay
                print(f'New LR: {args.lr}')
                for param_group in optimizer.param_groups:
                    param_group['lr'] = args.lr