Пример #1
0
def train_eval(device, args, model, growth_model, itr, best_loss, logger, full_data):
    """Evaluate the model once, record the result and checkpoint on improvement.

    The model is switched to eval mode (and left in it), the loss returned by
    ``compute_loss`` is logged together with the NFE count, a row
    ``(itr, test_loss)`` is appended to ``train_eval.csv`` inside
    ``args.save``, and both state dicts are written to ``checkpt.pth`` when
    the loss is strictly below ``best_loss``.

    Args:
        device: Forwarded unchanged to ``compute_loss``.
        args: Run configuration; ``args.save`` is the output directory.
        model: Model under evaluation; its ``state_dict()`` is checkpointed.
        growth_model: Forwarded to ``compute_loss``; its ``state_dict()`` is
            checkpointed alongside ``model``.
        itr: Current iteration number (used in the log line and the CSV row).
        best_loss: Lowest test loss seen so far, as a plain number.
        logger: Receives the summary line; also forwarded to ``compute_loss``.
        full_data: Forwarded unchanged to ``compute_loss``.

    Returns:
        The best loss after this evaluation: ``test_loss.item()`` when it
        improved on ``best_loss``, otherwise ``best_loss`` unchanged.
        Rebinding the parameter is invisible to the caller, so the value is
        returned explicitly; callers that ignore it behave as before.
    """
    import csv

    model.eval()
    test_loss = compute_loss(device, args, model, growth_model, logger, full_data)
    test_nfe = count_nfe(model)
    log_message = "[TEST] Iter {:04d} | Test Loss {:.6f} |" " NFE {:.0f}".format(
        itr, test_loss, test_nfe
    )
    logger.info(log_message)

    utils.makedirs(args.save)
    # newline="" is what the csv module asks for; without it, platforms that
    # translate line endings produce an extra blank line after every row.
    with open(os.path.join(args.save, "train_eval.csv"), "a", newline="") as f:
        # NOTE(review): the row stores the loss object itself (its str()),
        # not test_loss.item(); kept as-is so the existing file format does
        # not change.
        csv.writer(f).writerow((itr, test_loss))

    if test_loss.item() < best_loss:
        best_loss = test_loss.item()
        torch.save(
            {
                # 'args': args,
                "state_dict": model.state_dict(),
                "growth_state_dict": growth_model.state_dict(),
            },
            os.path.join(args.save, "checkpt.pth"),
        )

    return best_loss
Пример #2
0
def compare_with_DV_particle_method(args, model, dim, batch_size=None):
    """Push Gaussian samples through ``model`` and return the NFE-normalised diff.

    Draws ``batch_size`` standard-normal samples of width ``dim`` (falling
    back to ``args.batch_size``), runs the model with a zero-initialised
    one-element accumulator over ``args.time_length``, saves the transported
    samples to ``output/DVP_output_gaussian_mixture.pt`` and returns the first
    entry of the accumulated output divided by the number of function
    evaluations reported by ``count_nfe``.
    """
    n_samples = args.batch_size if batch_size is None else batch_size

    samples = torch.randn([n_samples, dim], dtype=torch.float32, device=device)
    diff_init = torch.zeros(1, dtype=torch.float32, device=device)

    transported, diff_final = model(
        samples, diff_init, integration_times=args.time_length
    )

    n_evals = count_nfe(model)
    # The transported samples are written to a fixed, relative path.
    torch.save(transported, 'output/DVP_output_gaussian_mixture.pt')

    return diff_final[0] / n_evals
Пример #3
0
def compute_loss_wgf(args, model, dim, batch_size=None):
    """Return the model's ``wgf_reg`` output divided by the NFE count.

    Samples ``batch_size`` (default ``args.batch_size``) standard-normal
    points of width ``dim``, evaluates their log-probability (summed over
    dimension 1) and score with the ``standard_normal_*`` helpers, and feeds
    them to ``model`` together with a zero-initialised ``wgf_reg`` state.
    Only the fourth model output is used.
    """
    n_samples = args.batch_size if batch_size is None else batch_size

    z = torch.randn(n_samples, dim, dtype=torch.float32, device=device)

    # Initial values of the extra states integrated along with z.
    initial_state = dict(
        logpz=standard_normal_logprob(z).sum(1, keepdim=True).to(z),
        score=standard_normal_score(z).to(z),
        wgf_reg=torch.tensor(0, device=device),
    )
    _, _, _, wgf_reg = model(z, **initial_state)

    n_evals = count_nfe(model)
    return wgf_reg / n_evals
Пример #4
0
def score_error_wgf(args, model, batch_size=None):
    """Return the model's ``score_error`` output divided by the NFE count.

    Samples ``batch_size`` (default ``args.batch_size``) two-dimensional
    standard-normal points and runs ``model`` with zero/identity initial
    values for the auxiliary states. Only the last of the seven model
    outputs is used.
    """
    # TODO: should have an input specifying the data dimension. Now it is fixed to 2
    data_dim = 2
    n_samples = args.batch_size if batch_size is None else batch_size

    z = torch.randn(n_samples, data_dim, dtype=torch.float32, device=device)

    # Initial values of the extra states integrated along with z.
    initial_state = dict(
        logpz=standard_normal_logprob(z).sum(1, keepdim=True).to(z),
        score=standard_normal_score(z).to(z),
        wgf_reg=torch.tensor(0, device=device),
        mu_0=torch.zeros(data_dim, dtype=torch.float32, device=device),
        sigma_half_0=torch.eye(data_dim, dtype=torch.float32, device=device),
        score_error_0=torch.zeros(1, dtype=torch.float32, device=device),
    )
    outputs = model(z, **initial_state)
    _, _, _, _, _, _, score_error = outputs

    n_evals = count_nfe(model)
    return score_error / n_evals
Пример #5
0
def compute_likelihood(args, model, batch_size=None):
    """Return the negative mean of ``gaussian_logprob`` over model outputs.

    Samples ``batch_size`` (default ``args.batch_size``) two-dimensional
    standard-normal points, transports them with ``model`` and scores the
    first model output with ``gaussian_logprob`` (summed over dimension 1).
    """
    # TODO: should have an input specifying the data dimension. Now it is fixed to 2
    n_samples = args.batch_size if batch_size is None else batch_size
    z = torch.randn(n_samples, 2, dtype=torch.float32, device=device)

    base_logp = standard_normal_logprob(z).sum(1, keepdim=True).to(z)
    base_score = standard_normal_score(z).to(z)
    reg_init = torch.tensor(0, device=device)

    outputs = model(z, logpz=base_logp, score=base_score, wgf_reg=reg_init)
    x, _logp_x, _score_x, _wgf_reg = outputs

    # The NFE count is queried as before, but its value is not used here.
    _ = count_nfe(model)

    target_logp = gaussian_logprob(x).sum(1, keepdim=True).to(z)
    return -torch.mean(target_logp)
Пример #6
0
def compare_with_Gaussian(args, model, dim, batch_size=None):
    """Run ``model`` from a standard-normal start and return ``diff / NFE``.

    The model receives ``batch_size`` (default ``args.batch_size``) samples
    of width ``dim`` plus three initial states: a zero vector of length
    ``dim``, a ``dim`` x ``dim`` identity matrix and a one-element zero
    accumulator. The first entry of the fourth output, divided by the number
    of function evaluations, is returned.
    """
    n_samples = args.batch_size if batch_size is None else batch_size
    tensor_opts = dict(dtype=torch.float32, device=device)

    samples = torch.randn([n_samples, dim], **tensor_opts)
    mean_init = torch.zeros(dim, **tensor_opts)
    sigma_half_init = torch.eye(dim, **tensor_opts)
    diff_init = torch.zeros(1, **tensor_opts)

    outputs = model(
        samples,
        mean_init,
        sigma_half_init,
        diff_init,
        integration_times=args.time_length,
    )
    _x_t, _mu_t, _sigma_half_t, diff_final = outputs

    n_evals = count_nfe(model)
    return diff_final[0] / n_evals
Пример #7
0
    # Report the model size, then evaluate on the test split in eval mode.
    logger.info("Number of trainable parameters: {}".format(nWeights))
    logger.info('Evaluating model on test set.')
    model.eval()

    # NOTE(review): presumably selects the exact ("brute_force") divergence
    # computation for evaluation; confirm in override_divergence_fn.
    override_divergence_fn(model, "brute_force")

    bInverse = True  # check one batch for inverse error, for speed

    with torch.no_grad():
        # Meters that accumulate the loss and the NFE count over all batches.
        test_loss = utils.AverageMeter()
        test_nfe = utils.AverageMeter()
        for itr, x in enumerate(
                batch_iter(data.tst.x, batch_size=test_batch_size)):

            x = cvt(x)
            # The batch size x.shape[0] is passed as the second argument of
            # the loss update (presumably a weight; verify in AverageMeter).
            test_loss.update(compute_loss(x, model).item(), x.shape[0])
            test_nfe.update(count_nfe(model))

            if bInverse:  # check the inverse error
                z = model(x, reverse=False)  # push forward
                xpred = model(z, reverse=True)  # inverse
                logger.info('inverse norm for first batch: ')
                # Norm of the reconstruction error divided by the batch size.
                logger.info(torch.norm(xpred - x).item() / x.shape[0])
                bInverse = False

            # itr is zero-based, so the first batch reports 0.00%.
            logger.info('Progress: {:.2f}%'.format(
                100. * itr / (data.tst.x.shape[0] / test_batch_size)))
        # Summary line; itr here is the index of the last processed batch.
        log_message = '[TEST] Iter {:06d} | Test Loss {:.6f} | NFE {:.0f}'.format(
            itr, test_loss.avg, test_nfe.avg)
        logger.info(log_message)
Пример #8
0
def train(epoch, train_loader, model, opt, args, logger):
    """Train ``model`` for one epoch and return the per-batch losses.

    For every batch the data is optionally moved to the GPU and binarised,
    reshaped to ``args.input_size``, passed through the model (with the
    target when ``args.conditional`` is set) and optimised against
    ``calculate_loss``. When ``'cnf'`` appears in ``args.flow`` the NFE
    count is read before and after the backward pass to split it into
    forward and backward evaluations. Progress is logged every
    ``args.log_interval`` batches and a summary line is logged at the end.

    Returns:
        A numpy array of length ``len(train_loader)`` holding
        ``loss.item()`` for each batch.
    """
    model.train()

    n_batches = len(train_loader)
    train_loss = np.zeros(n_batches)
    train_bpd = np.zeros(n_batches)
    num_data = 0

    # Warm-up coefficient: grows linearly with the epoch, capped at max_beta.
    beta = min([(epoch * 1.) / max([args.warmup, 1.]), args.max_beta])
    logger.info('beta = {:5.4f}'.format(beta))

    uses_cnf = 'cnf' in args.flow
    end = time.time()

    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data = data.cuda()
            target = target.cuda()
        if args.dynamic_binarization:
            data = torch.bernoulli(data)
        data = data.view(-1, *args.input_size)

        opt.zero_grad()

        model_inputs = (data, target) if args.conditional else (data,)
        x_mean, z_mu, z_var, ldj, z0, zk = model(*model_inputs)

        if uses_cnf:
            # Evaluations spent so far belong to the forward pass.
            f_nfe = count_nfe(model)

        loss, rec, kl, bpd = calculate_loss(
            x_mean, data, z_mu, z_var, z0, zk, ldj, args, beta=beta)
        loss.backward()

        if uses_cnf:
            # Anything counted beyond the forward total came from backward.
            t_nfe = count_nfe(model)
            b_nfe = t_nfe - f_nfe

        train_loss[batch_idx] = loss.item()
        train_bpd[batch_idx] = bpd

        opt.step()

        rec = rec.item()
        kl = kl.item()
        num_data += len(data)

        batch_time = time.time() - end
        end = time.time()

        if batch_idx % args.log_interval != 0:
            continue

        perc = 100. * batch_idx / len(train_loader)
        n_total = len(train_loader.sampler)
        if args.input_type == 'binary':
            template = (
                'Epoch {:3d} [{:5d}/{:5d} ({:2.0f}%)] | Time {:.3f} | Loss {:11.6f} | '
                'Rec {:11.6f} | KL {:11.6f}')
            log_msg = template.format(epoch, num_data, n_total, perc,
                                      batch_time, loss.item(), rec, kl)
        else:
            template = 'Epoch {:3d} [{:5d}/{:5d} ({:2.0f}%)] | Time {:.3f} | Loss {:11.6f} | Bits/dim {:8.6f}'
            head = template.format(epoch, num_data, n_total, perc,
                                   batch_time, loss.item(), bpd)
            tail = '\trec: {:11.3f}\tkl: {:11.6f}\tvar: {}'.format(
                rec, kl, torch.mean(torch.mean(z_var, dim=0)))
            log_msg = head + tail
        if uses_cnf:
            log_msg += ' | NFE Forward {} | NFE Backward {}'.format(
                f_nfe, b_nfe)
        logger.info(log_msg)

    # End-of-epoch summary; bits/dim is reported for non-binary input only.
    mean_loss = train_loss.sum() / len(train_loader)
    if args.input_type == 'binary':
        logger.info('====> Epoch: {:3d} Average train loss: {:.4f}'.format(
            epoch, mean_loss))
    else:
        mean_bpd = train_bpd.sum() / len(train_loader)
        logger.info(
            '====> Epoch: {:3d} Average train loss: {:.4f}, average bpd: {:.4f}'
            .format(epoch, mean_loss, mean_bpd))

    return train_loss
Пример #9
0
                # Add the regularization term to the objective.
                loss = loss + reg_loss
            # Penalise the value reported by count_total_time, weighted by
            # args.time_penalty.
            total_time = count_total_time(model)
            loss = loss + total_time * args.time_penalty

            loss.backward()
            # Clip gradients in place; the returned norm is tracked below.
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)

            optimizer.step()

            # Runs after the optimizer step, only when spectral norm is enabled.
            if args.spectral_norm:
                spectral_norm_power_iteration(model, args.spectral_norm_niter)

            # Update the meters for this iteration.
            time_meter.update(time.time() - start)
            loss_meter.update(loss.item())
            steps_meter.update(count_nfe(model))
            grad_meter.update(grad_norm)
            tt_meter.update(total_time)

            if itr % args.log_freq == 0:
                # Each pair is printed as the meter's val followed by (avg).
                log_message = (
                    "Iter {:04d} | Time {:.4f}({:.4f}) | Bit/dim {:.4f}({:.4f}) | "
                    "Steps {:.0f}({:.2f}) | Grad Norm {:.4f}({:.4f}) | Total Time {:.2f}({:.2f})"
                    .format(itr, time_meter.val, time_meter.avg,
                            loss_meter.val, loss_meter.avg, steps_meter.val,
                            steps_meter.avg, grad_meter.val, grad_meter.avg,
                            tt_meter.val, tt_meter.avg))
                if regularization_coeffs:
                    log_message = append_regularization_to_log(
                        log_message, regularization_fns, reg_states)
                logger.info(log_message)
Пример #10
0
def train():
    """Train a one-dimensional tabular model, validating and plotting as it goes.

    Builds the model from the module-level ``args``, optimises it with Adam
    for ``args.niters`` iterations, logs running statistics on every
    iteration, evaluates every ``args.val_freq`` iterations (and on the last
    one), saving a checkpoint whenever the test loss improves, and every
    ``args.viz_freq`` iterations saves a figure comparing ``data_density``
    with ``model_density`` on a grid over [-10, 10].
    """
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    logger.info(model)
    n_params = count_parameters(model)
    logger.info("Number of trainable parameters: {}".format(n_params))

    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Five independent running-average meters built with the same 0.93 argument.
    time_meter, loss_meter, nfef_meter, nfeb_meter, tt_meter = (
        utils.RunningAverageMeter(0.93) for _ in range(5))

    train_template = (
        'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})'
        ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})')
    test_template = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'

    end = time.time()
    best_loss = float('inf')
    model.train()
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()

        loss = compute_loss(args, model)
        loss_meter.update(loss.item())

        # Read the counters before backward so forward NFE can be separated.
        total_time = count_total_time(model)
        nfe_forward = count_nfe(model)

        loss.backward()
        optimizer.step()

        nfe_backward = count_nfe(model) - nfe_forward
        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)

        time_meter.update(time.time() - end)
        tt_meter.update(total_time)

        logger.info(train_template.format(
            itr,
            time_meter.val, time_meter.avg,
            loss_meter.val, loss_meter.avg,
            nfef_meter.val, nfef_meter.avg,
            nfeb_meter.val, nfeb_meter.avg,
            tt_meter.val, tt_meter.avg))

        is_last_iteration = itr == args.niters
        if itr % args.val_freq == 0 or is_last_iteration:
            with torch.no_grad():
                model.eval()
                test_loss = compute_loss(
                    args, model, batch_size=args.test_batch_size)
                test_nfe = count_nfe(model)
                logger.info(test_template.format(itr, test_loss, test_nfe))

                # Keep only the checkpoint with the lowest test loss so far.
                current = test_loss.item()
                if current < best_loss:
                    best_loss = test_loss.item()
                    utils.makedirs(args.save)
                    checkpoint = {
                        'args': args,
                        'state_dict': model.state_dict(),
                    }
                    torch.save(checkpoint,
                               os.path.join(args.save, 'checkpt.pth'))
                model.train()

        if itr % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()

                grid = torch.linspace(-10, 10, 10000).view(-1, 1)
                grid_np = grid.view(-1).cpu().numpy()

                # Both helpers' outputs are exponentiated before plotting.
                true_vals = data_density(grid)
                plt.plot(grid_np,
                         true_vals.view(-1).exp().cpu().numpy(),
                         label='True')

                model_vals = model_density(grid, model)
                plt.plot(grid.view(-1).cpu().numpy(),
                         model_vals.view(-1).exp().cpu().numpy(),
                         label='Model')

                fig_dir = os.path.join(args.save, 'figs')
                utils.makedirs(fig_dir)
                plt.savefig(
                    os.path.join(args.save, 'figs', '{:06d}.jpg'.format(itr)))
                plt.close()

                model.train()

        # Restart the timer so validation/plotting time is excluded.
        end = time.time()

    logger.info('Training has finished.')
Пример #11
0
def main():
    """Train and/or validate the model, with CSV logging and checkpointing.

    Steps, in order:
      1. When ``write_log`` is set, create ``args.save``, open the logger and
         dump the arguments to ``args.yaml``.
      2. When ``args.distributed`` is set, initialise the process group.
      3. Load the dataset, build the model and the optimizer, and optionally
         restore both from ``args.resume``.
      4. For each epoch: train over ``train_loader`` (unless
         ``args.validate``), save ``checkpt.pth`` on local rank 0, and every
         ``args.val_freq`` epochs validate on ``test_loader``, copy the
         checkpoint to ``best.pth`` on improvement and save a sample image.

    Relies on module-level ``args`` and ``write_log`` and on helpers defined
    elsewhere in the project.

    Raises:
        ValueError: if the model returns NaN or inf bits/dim during training.
    """
    #os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                                  filepath=os.path.abspath(__file__))

        logger.info(args)

        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log: logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend,
                                       init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(),
                                       rank=env_rank())
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log:
            logger.info("Distributed: success (%d/%d)" %
                        (args.local_rank, distributed.get_world_size()))

    # get device
    # device = torch.device("cuda:%d"%torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
    # NOTE(review): the device is pinned to CPU here, while the resume and
    # distributed-sync paths below still call .cuda(); confirm this is intended.
    device = "cpu"
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = [
        'itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time',
        'grad_norm'
    ]
    testcolumns = [
        'wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time',
        'transport_cost'
    ]

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns)
    # model = model.cuda()
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns,
                                                     regularization_fns)

    # A fresh run starts both CSV logs with just their header row.
    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log: logger.info(model)
    if write_log:
        logger.info("Number of trainable parameters: {}".format(
            count_parameters(model)))
    if write_log:
        logger.info('Iters per train epoch: {}'.format(len(train_loader)))
    if write_log: logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    # NOTE(review): any other args.optimizer value leaves `optimizer` unbound.
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay,
                              momentum=0.9,
                              nesterov=False)

    # restore parameters
    if args.resume is not None:
        checkpt = torch.load(
            args.resume,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization.
    # NOTE(review): fixed_z exists only when write_log is set, yet the
    # checkpoint below reads it whenever args.local_rank == 0; presumably the
    # two conditions coincide — verify where write_log is defined.
    if write_log:
        fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
    else:
        # Recover counters from the CSV logs stored next to the checkpoint.
        chkdir = os.path.dirname(args.resume)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] +
                          1)  # not exactly correct

    if args.distributed:
        if write_log: logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, args.num_epochs + 1):
        if not args.validate:
            model.train()

            with open(trainlog, 'a') as f:
                if write_log: csvlogger = csv.DictWriter(f, traincolumns)

                for _, (x, y) in enumerate(train_loader):
                    start = time.time()
                    update_lr(optimizer, itr)
                    optimizer.zero_grad()

                    # cast data and move to device
                    x = add_noise(cvt(x), nbits=args.nbits)
                    #x = x.clamp_(min=0, max=1)
                    # compute loss
                    bpd, (x, z), reg_states = compute_bits_per_dim(x, model)
                    if np.isnan(bpd.data.item()):
                        raise ValueError('model returned nan during training')
                    elif np.isinf(bpd.data.item()):
                        raise ValueError('model returned inf during training')

                    loss = bpd
                    if regularization_coeffs:
                        reg_loss = sum(reg_state * coeff
                                       for reg_state, coeff in zip(
                                           reg_states, regularization_coeffs)
                                       if coeff != 0)
                        loss = loss + reg_loss
                    total_time = count_total_time(model)

                    loss.backward()
                    nfe_opt = count_nfe(model)
                    if write_log: steps_meter.update(nfe_opt)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.max_grad_norm)

                    optimizer.step()

                    itr_time = time.time() - start
                    wall_clock += itr_time

                    batch_size = x.size(0)
                    # Leading 1. is summed with the other entries by
                    # sum_tensor and becomes total_gpus, the divisor used
                    # for every per-process average below.
                    metrics = torch.tensor([
                        1., batch_size,
                        loss.item(),
                        bpd.item(), nfe_opt, grad_norm, *reg_states
                    ]).float()

                    rv = tuple(torch.tensor(0.) for r in reg_states)

                    total_gpus, batch_total, r_loss, r_bpd, r_nfe, r_grad_norm, *rv = dist_utils.sum_tensor(
                        metrics).cpu().numpy()

                    if write_log:
                        time_meter.update(itr_time)
                        bpd_meter.update(r_bpd / total_gpus)
                        loss_meter.update(r_loss / total_gpus)
                        grad_meter.update(r_grad_norm / total_gpus)
                        tt_meter.update(total_time)

                        fmt = '{:.4f}'
                        logdict = {
                            'itr': itr,
                            'wall': fmt.format(wall_clock),
                            'itr_time': fmt.format(itr_time),
                            'loss': fmt.format(r_loss / total_gpus),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'total_time': fmt.format(total_time),
                            'fe': r_nfe / total_gpus,
                            'grad_norm': fmt.format(r_grad_norm / total_gpus),
                        }
                        if regularization_coeffs:
                            rv = tuple(v_ / total_gpus for v_ in rv)
                            logdict = append_regularization_csv_dict(
                                logdict, regularization_fns, rv)
                        csvlogger.writerow(logdict)

                        if itr % args.log_freq == 0:
                            log_message = (
                                "Itr {:06d} | Wall {:.3e}({:.2f}) | "
                                "Time/Itr {:.2f}({:.2f}) | BPD {:.2f}({:.2f}) | "
                                "Loss {:.2f}({:.2f}) | "
                                "FE {:.0f}({:.0f}) | Grad Norm {:.3e}({:.3e}) | "
                                "TT {:.2f}({:.2f})".format(
                                    itr, wall_clock, wall_clock / (itr + 1),
                                    time_meter.val, time_meter.avg,
                                    bpd_meter.val, bpd_meter.avg,
                                    loss_meter.val, loss_meter.avg,
                                    steps_meter.val, steps_meter.avg,
                                    grad_meter.val, grad_meter.avg,
                                    tt_meter.val, tt_meter.avg))
                            if regularization_coeffs:
                                log_message = append_regularization_to_log(
                                    log_message, regularization_fns, rv)
                            logger.info(log_message)

                    itr += 1

        # compute test loss
        model.eval()
        # The latest state is saved every epoch by local rank 0.
        if args.local_rank == 0:
            utils.makedirs(args.save)
            torch.save(
                {
                    "args":
                    args,
                    "state_dict":
                    model.module.state_dict()
                    if torch.cuda.is_available() else model.state_dict(),
                    "optim_state_dict":
                    optimizer.state_dict(),
                    "fixed_z":
                    fixed_z.cpu()
                }, os.path.join(args.save, "checkpt.pth"))
        if epoch % args.val_freq == 0 or args.validate:
            with open(testlog, 'a') as f:
                if write_log: csvlogger = csv.DictWriter(f, testcolumns)
                with torch.no_grad():
                    start = time.time()
                    if write_log: logger.info("validating...")

                    lossmean = 0.
                    meandist = 0.
                    steps = 0
                    tt = 0.
                    for i, (x, y) in enumerate(test_loader):
                        sh = x.shape
                        x = shift(cvt(x), nbits=args.nbits)
                        loss, (x, z), _ = compute_bits_per_dim(x, model)
                        dist = (x.view(x.size(0), -1) -
                                z).pow(2).mean(dim=-1).mean()
                        # Incremental means over batches:
                        # new = i/(i+1) * old + value/(i+1).
                        # The transport-cost update previously had `dist` and
                        # `meandist` swapped, which gave 0 for a single batch
                        # and a wrong value otherwise; it now matches the
                        # other three running means.
                        meandist = i / (i + 1) * meandist + dist / (i + 1)
                        lossmean = i / (i + 1) * lossmean + loss / (i + 1)

                        tt = i / (i + 1) * tt + count_total_time(model) / (i +
                                                                           1)
                        steps = i / (i + 1) * steps + count_nfe(model) / (i +
                                                                          1)

                    loss = lossmean.item()
                    metrics = torch.tensor([1., loss, meandist, steps]).float()

                    total_gpus, r_bpd, r_mdist, r_steps = dist_utils.sum_tensor(
                        metrics).cpu().numpy()
                    eval_time = time.time() - start

                    if write_log:
                        fmt = '{:.4f}'
                        logdict = {
                            'epoch': epoch,
                            'eval_time': fmt.format(eval_time),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'wall': fmt.format(wall_clock),
                            'total_time': fmt.format(tt),
                            'transport_cost': fmt.format(r_mdist / total_gpus),
                            'fe': '{:.2f}'.format(r_steps / total_gpus)
                        }

                        csvlogger.writerow(logdict)

                        logger.info(
                            "Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}, Steps {:.4f}, TT {:.2f}, Transport Cost {:.2e}"
                            .format(epoch, eval_time, r_bpd / total_gpus,
                                    r_steps / total_gpus, tt,
                                    r_mdist / total_gpus))

                    loss = r_bpd / total_gpus

                    # Promote the checkpoint saved above to best.pth.
                    if loss < best_loss and args.local_rank == 0:
                        best_loss = loss
                        shutil.copyfile(os.path.join(args.save, "checkpt.pth"),
                                        os.path.join(args.save, "best.pth"))

            # visualize samples and density
            if write_log:
                with torch.no_grad():
                    fig_filename = os.path.join(args.save, "figs",
                                                "{:04d}.jpg".format(epoch))
                    utils.makedirs(os.path.dirname(fig_filename))
                    generated_samples, _, _ = model(fixed_z, reverse=True)
                    generated_samples = generated_samples.view(-1, *data_shape)
                    # Side length of a near-square image grid.
                    nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
                    save_image(unshift(generated_samples, nbits=args.nbits),
                               fig_filename,
                               nrow=nb)
            # Validation-only runs stop after the first evaluation.
            if args.validate:
                break
Пример #12
0
def train(epoch,
          train_loader,
          model,
          opt,
          args,
          logger,
          nfef_meter=None,
          nfeb_meter=None):
    """Train ``model`` for one epoch, tracking forward/backward NFE for CNFs.

    Each batch is optionally moved to the GPU and binarised, reshaped to
    ``args.input_size``, run through the model with ``is_eval=False`` and
    the current ``epoch``, and optimised against ``calculate_loss``. When
    ``'cnf'`` appears in ``args.flow`` the forward and backward function
    evaluation counts are pushed into ``nfef_meter`` / ``nfeb_meter``, which
    must then be provided by the caller.

    Returns:
        The per-batch loss array alone when ``'cnf'`` is not in
        ``args.flow``; otherwise the tuple
        ``(train_loss, nfef_meter, nfeb_meter)``.
    """
    model.train()

    n_batches = len(train_loader)
    train_loss = np.zeros(n_batches)
    train_bpd = np.zeros(n_batches)
    num_data = 0

    # Warm-up coefficient: grows linearly with the epoch, capped at max_beta.
    beta = min([(epoch * 1.) / max([args.warmup, 1.]), args.max_beta])
    logger.info('beta = {:5.4f}'.format(beta))

    uses_cnf = 'cnf' in args.flow
    end = time.time()

    for batch_idx, (data, _) in enumerate(train_loader):
        if args.cuda:
            data = data.cuda()
        if args.dynamic_binarization:
            data = torch.bernoulli(data)
        data = data.view(-1, *args.input_size)

        opt.zero_grad()
        outputs = model(data, is_eval=False, epoch=epoch)
        x_mean, z_mu, z_var, ldj, z0, zk = outputs

        if uses_cnf:
            # Evaluations spent so far belong to the forward pass.
            f_nfe = count_nfe(model)

        loss, rec, kl, bpd = calculate_loss(
            x_mean, data, z_mu, z_var, z0, zk, ldj, args, beta=beta)
        loss.backward()

        if uses_cnf:
            # Anything counted beyond the forward total came from backward.
            t_nfe = count_nfe(model)
            b_nfe = t_nfe - f_nfe

            nfef_meter.update(f_nfe)
            nfeb_meter.update(b_nfe)

        train_loss[batch_idx] = loss.item()
        train_bpd[batch_idx] = bpd

        opt.step()

        rec = rec.item()
        kl = kl.item()
        num_data += len(data)

        batch_time = time.time() - end
        end = time.time()

        if batch_idx % args.log_interval != 0:
            continue

        perc = 100. * batch_idx / len(train_loader)
        n_total = len(train_loader.sampler)
        if args.input_type == 'binary':
            template = (
                'Epoch {:3d} [{:5d}/{:5d} ({:2.0f}%)] | Time {:.3f} | Loss {:11.6f} | '
                'Rec {:11.6f} | KL {:11.6f}')
            log_msg = template.format(epoch, num_data, n_total, perc,
                                      batch_time, loss.item(), rec, kl)
        else:
            template = 'Epoch {:3d} [{:5d}/{:5d} ({:2.0f}%)] | Time {:.3f} | Loss {:11.6f} | Bits/dim {:8.6f}'
            head = template.format(epoch, num_data, n_total, perc,
                                   batch_time, loss.item(), bpd)
            tail = '\trec: {:11.3f}\tkl: {:11.6f}'.format(rec, kl)
            log_msg = head + tail
        if uses_cnf:
            log_msg += ' | NFE Forward {:.0f}({:.1f}) | NFE Backward {:.0f}({:.1f})'.format(
                f_nfe, nfef_meter.avg, b_nfe, nfeb_meter.avg)
        logger.info(log_msg)

    # End-of-epoch summary; bits/dim is reported for non-binary input only.
    mean_loss = train_loss.sum() / len(train_loader)
    if args.input_type == 'binary':
        logger.info('====> Epoch: {:3d} Average train loss: {:.4f}'.format(
            epoch, mean_loss))
    else:
        mean_bpd = train_bpd.sum() / len(train_loader)
        logger.info(
            '====> Epoch: {:3d} Average train loss: {:.4f}, average bpd: {:.4f}'
            .format(epoch, mean_loss, mean_bpd))

    if 'cnf' in args.flow:
        return train_loss, nfef_meter, nfeb_meter
    return train_loss
Пример #13
0
def train(args, model, growth_model):
    """Optimise ``model`` for ``args.niters`` iterations with periodic eval and plots.

    Only ``model.parameters()`` are handed to the optimizer. ``growth_model``
    is put in eval mode, passed through to ``compute_loss`` and stored in the
    checkpoint, but it is never stepped here.

    Besides its arguments the function reads these module-level names:
    ``logger``, ``device``, ``regularization_coeffs``, ``regularization_fns``,
    ``timepoints``, ``int_tps`` and ``viz_sampler``.

    Args:
        args: Run configuration; reads ``lr``, ``weight_decay``, ``niters``,
            ``spectral_norm``, ``val_freq``, ``viz_freq`` and ``save``.
        model: Network being trained.
        growth_model: Auxiliary network forwarded to ``compute_loss``.

    Returns:
        None. Progress goes to ``logger``; the checkpoint and figures are
        written under ``args.save``.
    """
    logger.info(model)
    n_params = count_parameters(model)
    logger.info("Number of trainable parameters: {}".format(n_params))

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr, weight_decay=args.weight_decay)

    # One running-average meter per quantity shown in the per-iteration log line.
    time_meter, loss_meter, nfef_meter, nfeb_meter, tt_meter = (
        utils.RunningAverageMeter(0.93) for _ in range(5)
    )

    iter_template = (
        'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})'
        ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'
    )
    test_template = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'

    end = time.time()
    best_loss = float('inf')
    model.train()
    growth_model.eval()
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()

        # --- training step ---
        if args.spectral_norm:
            spectral_norm_power_iteration(model, 1)

        loss = compute_loss(args, model, growth_model)
        # The meter tracks the loss before any regularization is added.
        loss_meter.update(loss.item())

        regularized = len(regularization_coeffs) > 0
        if regularized:
            reg_states = get_regularization(model, regularization_coeffs)
            penalty = 0
            for state, coeff in zip(reg_states, regularization_coeffs):
                if coeff != 0:
                    penalty = penalty + state * coeff
            loss = loss + penalty

        # Counters are read before backward() so the backward share can be
        # obtained by subtraction afterwards.
        cnf_time = count_total_time(model)
        nfe_fwd = count_nfe(model)

        loss.backward()
        optimizer.step()

        # --- bookkeeping ---
        nfe_bwd = count_nfe(model) - nfe_fwd
        nfef_meter.update(nfe_fwd)
        nfeb_meter.update(nfe_bwd)
        time_meter.update(time.time() - end)
        tt_meter.update(cnf_time)

        stats = (
            itr,
            time_meter.val, time_meter.avg,
            loss_meter.val, loss_meter.avg,
            nfef_meter.val, nfef_meter.avg,
            nfeb_meter.val, nfeb_meter.avg,
            tt_meter.val, tt_meter.avg,
        )
        log_message = iter_template.format(*stats)
        if regularized:
            log_message = append_regularization_to_log(log_message, regularization_fns, reg_states)
        logger.info(log_message)

        # --- periodic evaluation (always runs on the final iteration) ---
        if itr % args.val_freq == 0 or itr == args.niters:
            with torch.no_grad():
                model.eval()
                growth_model.eval()
                test_loss = compute_loss(args, model, growth_model)
                test_nfe = count_nfe(model)
                logger.info(test_template.format(itr, test_loss, test_nfe))

                test_value = test_loss.item()
                if test_value < best_loss:
                    # Keep only the best-scoring weights on disk.
                    best_loss = test_value
                    utils.makedirs(args.save)
                    snapshot = {
                        'args': args,
                        'state_dict': model.state_dict(),
                        'growth_state_dict': growth_model.state_dict(),
                    }
                    torch.save(snapshot, os.path.join(args.save, 'checkpt.pth'))
                model.train()

        # --- periodic visualisation: one figure per entry of `timepoints` ---
        if itr % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()
                for idx, tp in enumerate(timepoints):
                    p_samples = viz_sampler(tp)
                    # Transforms use the integration times up to and including idx.
                    sample_fn, density_fn = get_transforms(model, int_tps[:idx + 1])
                    plt.figure(figsize=(9, 3))
                    visualize_transform(
                        p_samples,
                        torch.randn,
                        standard_normal_logprob,
                        transform=sample_fn,
                        inverse_transform=density_fn,
                        samples=True,
                        npts=100,
                        device=device,
                    )
                    fig_filename = os.path.join(
                        args.save, 'figs', '{:04d}_{:01d}.jpg'.format(itr, idx)
                    )
                    utils.makedirs(os.path.dirname(fig_filename))
                    plt.savefig(fig_filename)
                    plt.close()
                model.train()

        # Restart the wall-clock timer so eval/plot time is not billed to the
        # next iteration.
        end = time.time()
    logger.info('Training has finished.')
Пример #14
0
                x = cvt(x)
                loss = compute_loss(x, model)
                loss_meter.update(loss.item())

                if len(regularization_coeffs) > 0:
                    reg_states = get_regularization(model,
                                                    regularization_coeffs)
                    reg_loss = sum(reg_state * coeff
                                   for reg_state, coeff in zip(
                                       reg_states, regularization_coeffs)
                                   if coeff != 0)
                    loss = loss + reg_loss

                total_time = count_total_time(model)
                nfe_forward = count_nfe(model)

                loss.backward()
                optimizer.step()

                nfe_total = count_nfe(model)
                nfe_backward = nfe_total - nfe_forward
                nfef_meter.update(nfe_forward)
                nfeb_meter.update(nfe_backward)

                time_meter.update(time.time() - end)
                tt_meter.update(total_time)

                if itr % args.log_freq == 0:
                    log_message = (
                        'Iter {:06d} | Epoch {:.2f} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | '
Пример #15
0
def train(epoch, train_loader, model, opt, args, wandb):
    """Run one training epoch and return the mean per-batch loss.

    Per-batch metrics are sent to ``wandb.log``. The NFE entries (``'nfe'``
    and ``'nbe'``) are included only when ``'cnf' in args.flow``, because the
    counters are only read for those flows. Logging them unconditionally
    raised ``UnboundLocalError`` for every other flow type.

    Args:
        epoch: Current epoch number; used only for the warm-up coefficient
            ``beta = min(epoch / max(args.warmup, 1), args.max_beta)``.
        train_loader: Sized iterable yielding ``(data, _)`` batches.
        model: Called as ``model(data)``; must return the 6-tuple
            ``(x_mean, z_mu, z_var, ldj, z0, zk)``.
        opt: Optimizer; ``zero_grad()`` and ``step()`` are called once per batch.
        args: Configuration; reads ``warmup``, ``max_beta``, ``cuda``,
            ``dynamic_binarization``, ``input_size`` and ``flow``.
        wandb: Object exposing ``log(dict)``.

    Returns:
        Sum of the per-batch losses divided by ``len(train_loader)``.
    """
    model.train()
    num_batches = len(train_loader)
    train_loss = np.zeros(num_batches)

    # Warm-up coefficient forwarded to calculate_loss as `beta`.
    beta = min(epoch * 1. / max(args.warmup, 1.), args.max_beta)

    # Loop-invariant: NFE counters are only read for CNF flows.
    uses_cnf = 'cnf' in args.flow

    for batch_idx, (data, _) in enumerate(train_loader):
        if args.cuda:
            data = data.cuda()

        if args.dynamic_binarization:
            data = torch.bernoulli(data)

        data = data.view(-1, *args.input_size)

        opt.zero_grad()
        x_mean, z_mu, z_var, ldj, z0, zk = model(data)

        if uses_cnf:
            # Counter value after the forward pass.
            f_nfe = count_nfe(model)

        loss, _rec, _kl, bpd = calculate_loss(x_mean,
                                              data,
                                              z_mu,
                                              z_var,
                                              z0,
                                              zk,
                                              ldj,
                                              args,
                                              beta=beta)

        loss.backward()

        loss_value = loss.item()
        train_loss[batch_idx] = loss_value

        metrics = {
            'train_loss': loss_value,
            'train_bpd': bpd,
        }
        if uses_cnf:
            # 'nfe' is the counter read after backward(); 'nbe' is its
            # increase since the post-forward reading.
            t_nfe = count_nfe(model)
            metrics['nfe'] = t_nfe
            metrics['nbe'] = t_nfe - f_nfe
        wandb.log(metrics)

        opt.step()

    return train_loss.sum() / num_batches
Пример #16
0
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()
        if args.spectral_norm: spectral_norm_power_iteration(model, 1)

        loss = compute_loss(args, model)
        loss_meter.update(loss.item())

        if len(regularization_coeffs) > 0:
            reg_states = get_regularization(model, regularization_coeffs)
            reg_loss = sum(
                reg_state * coeff for reg_state, coeff in zip(reg_states, regularization_coeffs) if coeff != 0
            )
            loss = loss + reg_loss

        total_time = count_total_time(model)
        nfe_forward = count_nfe(model)

        loss.backward()
        optimizer.step()

        nfe_total = count_nfe(model)
        nfe_backward = nfe_total - nfe_forward
        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)

        time_meter.update(time.time() - end)
        tt_meter.update(total_time)

        log_message = (
            'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})'
            ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format(
Пример #17
0
def train(
    device, args, model, growth_model, regularization_coeffs, regularization_fns, logger
):
    """Train ``model`` for ``args.niters`` iterations with eval, plots and checkpoints.

    Only ``model.parameters()`` are optimised. ``growth_model`` is switched to
    eval mode once, then only forwarded to ``compute_loss`` / ``train_eval``
    and saved in the periodic checkpoints.

    Args:
        device: Target device for the training data tensor; also forwarded to
            ``compute_loss``, ``train_eval`` and ``visualize``.
        args: Run configuration; reads ``lr``, ``weight_decay``, ``data``,
            ``leaveout_timepoint``, ``niters``, ``spectral_norm``,
            ``val_freq``, ``viz_freq``, ``save_freq`` and ``save``.
        model: Network being trained.
        growth_model: Auxiliary network, never stepped here.
        regularization_coeffs: Sized sequence of weights, paired positionally
            with the values returned by ``get_regularization``; zero weights
            are skipped.
        regularization_fns: Passed to ``append_regularization_to_log`` only.
        logger: Receives the per-iteration and final messages.

    Returns:
        None.
    """
    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    # One running-average meter per tracked quantity.
    time_meter, loss_meter, nfef_meter, nfeb_meter, tt_meter = (
        utils.RunningAverageMeter(0.93) for _ in range(5)
    )

    # Training set: every sample whose time label differs from the held-out one.
    observations = args.data.get_data()
    selector = args.data.get_times() != args.leaveout_timepoint
    full_data = torch.from_numpy(observations[selector])
    full_data = full_data.type(torch.float32).to(device)

    iter_template = (
        "Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) |"
        " NFE Forward {:.0f}({:.1f})"
        " | NFE Backward {:.0f}({:.1f})"
    )

    # NOTE(review): best_loss is never reassigned in this function, so
    # train_eval is always handed +inf unless it propagates the update by some
    # other means - confirm against train_eval.
    best_loss = float("inf")
    growth_model.eval()
    end = time.time()
    for itr in range(1, args.niters + 1):
        # Re-enter train mode every iteration; eval/visualisation below may
        # have changed it.
        model.train()
        optimizer.zero_grad()

        # --- training step ---
        if args.spectral_norm:
            spectral_norm_power_iteration(model, 1)

        loss = compute_loss(device, args, model, growth_model, logger, full_data)
        # The meter tracks the loss before any regularization is added.
        loss_meter.update(loss.item())

        regularized = len(regularization_coeffs) > 0
        if regularized:
            reg_states = get_regularization(model, regularization_coeffs)
            penalty = 0
            for state, coeff in zip(reg_states, regularization_coeffs):
                if coeff != 0:
                    penalty = penalty + state * coeff
            loss = loss + penalty

        # Read the counters before backward() so the backward share can be
        # obtained by subtraction afterwards.
        cnf_time = count_total_time(model)
        nfe_fwd = count_nfe(model)

        loss.backward()
        optimizer.step()

        # --- bookkeeping ---
        nfe_bwd = count_nfe(model) - nfe_fwd
        nfef_meter.update(nfe_fwd)
        nfeb_meter.update(nfe_bwd)
        time_meter.update(time.time() - end)
        tt_meter.update(cnf_time)

        stats = (
            itr,
            time_meter.val, time_meter.avg,
            loss_meter.val, loss_meter.avg,
            nfef_meter.val, nfef_meter.avg,
            nfeb_meter.val, nfeb_meter.avg,
        )
        log_message = iter_template.format(*stats)
        if regularized:
            log_message = append_regularization_to_log(
                log_message, regularization_fns, reg_states
            )
        logger.info(log_message)

        # --- periodic evaluation (always runs on the final iteration) ---
        if itr % args.val_freq == 0 or itr == args.niters:
            with torch.no_grad():
                train_eval(
                    device, args, model, growth_model, itr, best_loss, logger, full_data
                )

        # --- periodic visualisation, skipped when the first shape entry exceeds 2 ---
        if itr % args.viz_freq == 0:
            plottable = not (args.data.get_shape()[0] > 2)
            if plottable:
                with torch.no_grad():
                    visualize(device, args, model, itr)
            else:
                logger.warning("Skipping vis as data dimension is >2")

        # --- periodic checkpoint, independent of the eval result ---
        if itr % args.save_freq == 0:
            snapshot = {
                "state_dict": model.state_dict(),
                "growth_state_dict": growth_model.state_dict(),
            }
            utils.save_checkpoint(snapshot, args.save, epoch=itr)

        # Restart the timer so eval/plot/save time is not billed to the next
        # iteration.
        end = time.time()
    logger.info("Training has finished.")