def train(model, trainD, evalD, checkpt=None):
    global ndecs
    optim = torch.optim.Adam(model.parameters(),
                             lr=args.lr,
                             betas=(0.9, 0.99),
                             weight_decay=args.wd)
    #  sch = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.nepochs * trainD.N)

    if checkpt is not None:
        optim.load_state_dict(checkpt['optim'])
        ndecs = checkpt['ndecs']

    batch_time = utils.RunningAverageMeter(0.98)
    cg_meter = utils.RunningAverageMeter(0.98)
    gnorm_meter = utils.RunningAverageMeter(0.98)
    train_est_meter = utils.RunningAverageMeter(0.98**args.train_est_freq)

    best_logp = -float('inf')
    itr = 0 if checkpt is None else checkpt['iters']
    n_vals_without_improvement = 0
    model.train()
    while True:
        if itr >= args.nepochs * math.ceil(trainD.N / args.batch_size):
            break
        if 0 < args.early_stopping < n_vals_without_improvement:
            break
        for x in batch_iter(trainD.x, shuffle=True):
            if 0 < args.early_stopping < n_vals_without_improvement:
                break
            end = time.time()
            optim.zero_grad()

            x = cvt(x)
            train_est = [0] if itr % args.train_est_freq == 0 else None
            loss = -model.logp(x, extra=train_est).mean()
            if train_est is not None:
                train_est = train_est[0].mean().detach().item()

            if loss != loss:
                raise ValueError('NaN encountered @ training logp!')

            loss.backward()

            if args.clip_grad == 0:
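                # args.clip_grad == 0 is treated as "measure only": compute the global gradient
                # norm without clipping (clip_grad_norm_ in the else-branch both clips and
                # returns the pre-clipping norm).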
                parameters = [
                    p for p in model.parameters() if p.grad is not None
                ]
                grad_norm = torch.norm(
                    torch.stack([
                        torch.norm(p.grad.detach(), 2.0) for p in parameters
                    ]), 2.0)
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.clip_grad)

            optim.step()
            #  sch.step()

            gnorm_meter.update(float(grad_norm))
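            # flows.CG_ITERS_TRACER is presumably a module-level list that the flow layers
            # append their conjugate-gradient iteration counts to; summed and cleared each step.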
            cg_meter.update(sum(flows.CG_ITERS_TRACER))
            flows.CG_ITERS_TRACER.clear()
            batch_time.update(time.time() - end)
            if train_est is not None:
                train_est_meter.update(train_est)

            del loss
            gc.collect()
            torch.clear_autocast_cache()

            if itr % args.log_freq == 0:
                log_message = (
                    'Iter {:06d} | Epoch {:.2f} | Time {batch_time.val:.3f} | '
                    'GradNorm {gnorm_meter.avg:.2f} | CG iters {cg_meter.val} ({cg_meter.avg:.2f}) | '
                    'Train logp {train_logp.val:.6f} ({train_logp.avg:.6f})'.
                    format(itr,
                           float(itr) / (trainD.N / float(args.batch_size)),
                           batch_time=batch_time,
                           gnorm_meter=gnorm_meter,
                           cg_meter=cg_meter,
                           train_logp=train_est_meter))
                logger.info(log_message)

            # Validation loop.
            if itr % args.val_freq == 0:
                with eval_ctx(model, bruteforce=args.brute_val):
                    val_logp = utils.AverageMeter()
                    with tqdm(total=evalD.N) as pbar:
                        # noinspection PyAssignmentToLoopOrWithParameter
                        for x in batch_iter(evalD.x,
                                            batch_size=args.val_batch_size):
                            x = cvt(x)
                            val_logp.update(
                                model.logp(x).mean().item(), x.size(0))
                            pbar.update(x.size(0))
                    if val_logp.avg > best_logp:
                        best_logp = val_logp.avg
                        utils.makedirs(args.save)
                        torch.save(
                            {
                                'args': args,
                                'model': model.state_dict(),
                                'optim': optim.state_dict(),
                                'iters': itr + 1,
                                'ndecs': ndecs,
                            }, save_path)
                        n_vals_without_improvement = 0
                    else:
                        n_vals_without_improvement += 1
                        update_lr(optim, n_vals_without_improvement)

                    log_message = ('[VAL] Iter {:06d} | Val logp {:.6f} | '
                                   'NoImproveEpochs {:02d}/{:02d}'.format(
                                       itr, val_logp.avg,
                                       n_vals_without_improvement,
                                       args.early_stopping))
                    logger.info(log_message)

            itr += 1

    logger.info('Training has finished, yielding the best model...')
    best_checkpt = torch.load(save_path)
    model.load_state_dict(best_checkpt['model'])
    return model
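
The loop above relies on an `update_lr` helper and a global `ndecs` counter that are not shown in this snippet. A minimal sketch of what such a helper could look like; the patience threshold and the factor-of-10 decay are assumptions, not the original implementation:

def update_lr(optimizer, n_vals_without_improvement):
    # Hypothetical: after every `args.patience` validations without improvement, decay the
    # learning rate by 10x and record the number of decays in the global `ndecs`, which
    # train() checkpoints and restores.
    global ndecs
    if n_vals_without_improvement > 0 and n_vals_without_improvement % args.patience == 0:
        ndecs += 1
        for group in optimizer.param_groups:
            group['lr'] = args.lr / (10 ** ndecs)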
Example #2
def train():

    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(
        count_parameters(model)))

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
    best_loss = float('inf')
    model.train()
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()

        loss = compute_loss(args, model)
        loss_meter.update(loss.item())

        total_time = count_total_time(model)
        nfe_forward = count_nfe(model)

        loss.backward()
        optimizer.step()

        nfe_total = count_nfe(model)
        nfe_backward = nfe_total - nfe_forward
        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)

        time_meter.update(time.time() - end)
        tt_meter.update(total_time)

        log_message = (
            'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})'
            ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format(
                itr, time_meter.val, time_meter.avg, loss_meter.val,
                loss_meter.avg, nfef_meter.val, nfef_meter.avg, nfeb_meter.val,
                nfeb_meter.avg, tt_meter.val, tt_meter.avg))
        logger.info(log_message)

        if itr % args.val_freq == 0 or itr == args.niters:
            with torch.no_grad():
                model.eval()
                test_loss = compute_loss(args,
                                         model,
                                         batch_size=args.test_batch_size)
                test_nfe = count_nfe(model)
                log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(
                    itr, test_loss, test_nfe)
                logger.info(log_message)

                if test_loss.item() < best_loss:
                    best_loss = test_loss.item()
                    utils.makedirs(args.save)
                    torch.save(
                        {
                            'args': args,
                            'state_dict': model.state_dict(),
                        }, os.path.join(args.save, 'checkpt.pth'))
                model.train()

        if itr % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()

                xx = torch.linspace(-10, 10, 10000).view(-1, 1)
                true_p = data_density(xx)
                plt.plot(xx.view(-1).cpu().numpy(),
                         true_p.view(-1).exp().cpu().numpy(),
                         label='True')

                true_p = model_density(xx, model)
                plt.plot(xx.view(-1).cpu().numpy(),
                         true_p.view(-1).exp().cpu().numpy(),
                         label='Model')

                utils.makedirs(os.path.join(args.save, 'figs'))
                plt.savefig(
                    os.path.join(args.save, 'figs', '{:06d}.jpg'.format(itr)))
                plt.close()

                model.train()

        end = time.time()

    logger.info('Training has finished.')
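
The toy script above assumes `compute_loss` (and `data_density`/`model_density` for plotting) that are not shown. A rough sketch of the negative-log-likelihood part, following the change-of-variables pattern that also appears, commented out, in the point-cloud example further down; `sample_data` and the exact model call signature are assumptions:

import math

import torch


def standard_normal_logprob(z):
    # log N(z; 0, I), elementwise
    return -0.5 * math.log(2 * math.pi) - z.pow(2) / 2


def compute_loss(args, model, batch_size=None):
    batch_size = batch_size or args.batch_size
    x = sample_data(batch_size).view(-1, 1)      # assumed 1-D data sampler
    zero = torch.zeros(x.shape[0], 1).to(x)
    z, delta_logp = model(x, zero)               # flow maps x -> z and accumulates log|det|
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)
    logpx = logpz - delta_logp
    return -torch.mean(logpx)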
    # (the construction of `trainset` is cut off in this snippet; presumably it mirrors the testset below)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    testset = torchvision.datasets.MNIST(
        root='data',
        train=False,
        download=True,
        transform=transforms.Compose(transform_mnist + transform_test +
                                     transform_mnist2))
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)
    args.num_classes = 10

batch_time = utils.RunningAverageMeter(0.97)
loss_meter = utils.RunningAverageMeter(0.97)
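
Nearly every snippet in this listing uses `utils.RunningAverageMeter`. A sketch consistent with how it is used here (`.val` holds the latest value, `.avg` an exponential moving average with the given momentum); the real utils module may differ in details:

class RunningAverageMeter(object):
    """Tracks the most recent value and an exponential moving average of it."""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val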


def update_lipschitz(model):
    with torch.no_grad():
        for m in model.modules():
            if isinstance(m, base_layers.SpectralNormConv2d) or isinstance(
                    m, base_layers.SpectralNormLinear):
                m.compute_weight(update=True)
            if isinstance(m, base_layers.InducedNormConv2d) or isinstance(
                    m, base_layers.InducedNormLinear):
                m.compute_weight(update=True)
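
`spectral_norm_power_iteration(model, n)` is called in several snippets below but never shown. A sketch modeled on `update_lipschitz` above, reusing the same `base_layers` classes; which layers it should actually touch is an assumption:

def spectral_norm_power_iteration(model, n_power_iterations=1):
    # Hypothetical: run a few extra power-iteration steps on every spectral-norm layer
    # so the per-layer Lipschitz estimates stay tight during training.
    with torch.no_grad():
        for _ in range(n_power_iterations):
            for m in model.modules():
                if isinstance(m, (base_layers.SpectralNormConv2d,
                                  base_layers.SpectralNormLinear)):
                    m.compute_weight(update=True)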


def train(args, model, device, train_loader, optimizer, epoch, ema):
    #genGen = Generator(16)
    #genGen = Generator(128)

    genGen = Generator(8)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    #optimizerGen = optim.SGD(genGen.parameters(), lr=0.1, momentum=0.9)
    optimizerGen = optim.Adam(genGen.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)

    logpz_meter = utils.RunningAverageMeter(0.93)
    delta_logp_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
    best_loss = float('inf')

    model.train()

    #for itr in range(1, args.niters2 + 1):
    #for itr in range(1, args.niters + 1):
    for itr in range(1, 2):
Example #5
def train(
    device, args, model, growth_model, regularization_coeffs, regularization_fns, logger
):
    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    full_data = (
        torch.from_numpy(
            args.data.get_data()[args.data.get_times() != args.leaveout_timepoint]
        )
        .type(torch.float32)
        .to(device)
    )

    best_loss = float("inf")
    growth_model.eval()
    end = time.time()
    for itr in range(1, args.niters + 1):
        model.train()
        optimizer.zero_grad()

        # Train
        if args.spectral_norm:
            spectral_norm_power_iteration(model, 1)

        loss = compute_loss(device, args, model, growth_model, logger, full_data)
        loss_meter.update(loss.item())

        if len(regularization_coeffs) > 0:
            # Only regularize on the last timepoint
            reg_states = get_regularization(model, regularization_coeffs)
            reg_loss = sum(
                reg_state * coeff
                for reg_state, coeff in zip(reg_states, regularization_coeffs)
                if coeff != 0
            )
            loss = loss + reg_loss
        total_time = count_total_time(model)
        nfe_forward = count_nfe(model)

        loss.backward()
        optimizer.step()

        # Eval
        nfe_total = count_nfe(model)
        nfe_backward = nfe_total - nfe_forward
        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)
        time_meter.update(time.time() - end)
        tt_meter.update(total_time)

        log_message = (
            "Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) |"
            " NFE Forward {:.0f}({:.1f})"
            " | NFE Backward {:.0f}({:.1f})".format(
                itr,
                time_meter.val,
                time_meter.avg,
                loss_meter.val,
                loss_meter.avg,
                nfef_meter.val,
                nfef_meter.avg,
                nfeb_meter.val,
                nfeb_meter.avg,
            )
        )
        if len(regularization_coeffs) > 0:
            log_message = append_regularization_to_log(
                log_message, regularization_fns, reg_states
            )
        logger.info(log_message)

        if itr % args.val_freq == 0 or itr == args.niters:
            with torch.no_grad():
                train_eval(
                    device, args, model, growth_model, itr, best_loss, logger, full_data
                )

        if itr % args.viz_freq == 0:
            if args.data.get_shape()[0] > 2:
                logger.warning("Skipping vis as data dimension is >2")
            else:
                with torch.no_grad():
                    visualize(device, args, model, itr)
        if itr % args.save_freq == 0:
            utils.save_checkpoint(
                {
                    # 'args': args,
                    "state_dict": model.state_dict(),
                    "growth_state_dict": growth_model.state_dict(),
                },
                args.save,
                epoch=itr,
            )
        end = time.time()
    logger.info("Training has finished.")
Example #6
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model).cuda()

    # For visualization.
    fixed_z = cvt(torch.randn(100, *data_shape))

    time_meter = utils.RunningAverageMeter(0.97)
    loss_meter = utils.RunningAverageMeter(0.97)
    steps_meter = utils.RunningAverageMeter(0.97)
    grad_meter = utils.RunningAverageMeter(0.97)
    tt_meter = utils.RunningAverageMeter(0.97)

    if args.spectral_norm and not args.resume:
        spectral_norm_power_iteration(model, 500)

    best_loss = float("inf")
    itr = 0
    for epoch in range(args.begin_epoch, args.num_epochs + 1):
        model.train()
        train_loader = get_train_loader(train_set, epoch)
        for _, (x, y) in enumerate(train_loader):
            start = time.time()
def main():
    # os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))

        logger.info(args)

        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log: logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(), rank=env_rank())
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log: logger.info("Distributed: success (%d/%d)" % (args.local_rank, distributed.get_world_size()))
        device = torch.device("cuda:%d" % torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
    else:
        device = torch.cuda.current_device()  #

    # import pdb; pdb.set_trace()
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = ['itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time', 'grad_norm']
    testcolumns = ['wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time', 'transport_cost']

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns).cuda()
    if args.distributed: model = dist_utils.DDP(model,
                                                device_ids=[args.local_rank],
                                                output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns, regularization_fns)

    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log: logger.info(model)
    if write_log: logger.info("Number of trainable parameters: {}".format(count_parameters(model)))
    if write_log: logger.info('Iters per train epoch: {}'.format(len(train_loader)))
    if write_log: logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9,
                              nesterov=False)

    # restore parameters
    # import pdb; pdb.set_trace()
    if args.resume is not None:
        # import pdb; pdb.set_trace()
        print('resume from checkpoint')
        checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization.
    if write_log: fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
        chkdir = args.save
        '''
    elif args.resume and args.validate:
        chkdir = os.path.dirname(args.resume)
        wall_clock = 0
        itr = 0
        best_loss = 0.0
        begin_epoch = 0
        '''
    else:
        chkdir = os.path.dirname(args.resume)
        filename = os.path.join(chkdir, 'test.csv')
        print(filename)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        # import pdb; pdb.set_trace()
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] + 1)  # not exactly correct

    if args.distributed:
        if write_log: logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, begin_epoch + 1):
        # compute test loss
        print('Evaluating')
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            # import pdb; pdb.set_trace()
            if hasattr(model, 'module'):
                _state = model.module.state_dict()
            else:
                _state = model.state_dict()
            torch.save({
                "args": args,
                "state_dict": _state,  # model.module.state_dict() if torch.cuda.is_available() else model.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
                "fixed_z": fixed_z.cpu()
            }, os.path.join(args.save, "checkpt_%d.pth" % epoch))

        # save real and generate with different temperatures
        fig_num = 64
        if True:  # args.save_real:
            for i, (x, y) in enumerate(test_loader):
                if i < 100:
                    pass
                elif i == 100:
                    real = x.size(0)
                else:
                    break
            if x.shape[0] > fig_num:
                x = x[:fig_num, ...]
            # import pdb; pdb.set_trace()
            fig_filename = os.path.join(chkdir, "real.jpg")
            save_image(x.float() / 255.0, fig_filename, nrow=8)

        if True:  # args.generate:
            print('\nGenerating images... ')
            fixed_z = cvt(torch.randn(fig_num, *data_shape))
            nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
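            # Scaling the fixed noise by t < 1 samples from a reduced-temperature prior
            # (z ~ N(0, t^2 I)), trading sample diversity for visual fidelity.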
            for t in [1.0, 0.99, 0.98, 0.97, 0.96, 0.95, 0.93, 0.92, 0.90, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6]:
                # visualize samples and density
                fig_filename = os.path.join(chkdir, "generated-T%g.jpg" % t)
                utils.makedirs(os.path.dirname(fig_filename))
                generated_samples = model(t * fixed_z, reverse=True)
                x = unshift(generated_samples[0].view(-1, *data_shape), 8)
                save_image(x, fig_filename, nrow=nb)
Example #8
def main():
    lipschitz_constants = []
    start_time = time.time()
    # entropy_avg_meter = AverageValueMeter()
    # latent_nats_avg_meter = AverageValueMeter()
    point_nats_avg_meter = AverageValueMeter()

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    logpz_meter = utils.RunningAverageMeter(0.93)
    delta_logp_meter = utils.RunningAverageMeter(0.93)
    end = time.time()

    for epoch in range(args.begin_epoch, args.epochs):
        train_loss, train_count = 0, 0
        train_logpz, train_delta_logp = 0, 0
        for bidx, data in enumerate(train_loader):
            idx_batch, tr_batch, te_batch = data['idx'], data[
                'train_points'], data['test_points']
            step = bidx + len(train_loader) * epoch
            model.train()
            if args.random_rotate:
                tr_batch, _, _ = apply_random_rotation(
                    tr_batch, rot_axis=train_loader.dataset.gravity_axis)
            # use toy model
            # inputs = tr_batch.view(-1, args.tr_max_sample_points*3).to(device)
            # zero = torch.zeros(inputs.shape[0], 1).to(inputs)
            # # transform to z
            # z, delta_logp = model(inputs, zero)
            #
            # #compute log p(z)
            # logpz = standard_normal_logprob(z).sum(1, keepdim=True)
            # logpx = logpz - delta_logp
            # loss = -torch.mean(logpx)/args.tr_max_sample_points/3/np.log(2)
            #
            inputs = tr_batch.to(device)
            loss, logpz, delta_logp = model(inputs, step, writer)

            loss_meter.update(loss.item())
            logpz_meter.update(torch.mean(logpz).item())
            delta_logp_meter.update(torch.mean(-delta_logp).item())

            train_count += 1
            train_loss += loss.item()
            train_logpz += torch.mean(logpz).item()
            train_delta_logp += torch.mean(-delta_logp).item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #update_lipschitz(model, 5)
            update_lipschitz(model)

        #scheduler.step()
        lipschitz_constants.append(get_lipschitz_constants(model))
        logger.info('Lipsh: {}'.format(pretty_repr(lipschitz_constants[-1])))
        writer.add_scalar('avg_train_loss', train_loss / train_count, epoch)
        writer.add_scalar('avg_lopz', train_logpz / train_count, epoch)
        writer.add_scalar('avg_neg_deltalogz', train_delta_logp / train_count,
                          epoch)

        # print("Epoch %d Time [%3.2fs]  Likelihood Loss  %2.5f"
        #               % (epoch, time.time() - start_time, train_loss / train_count))

        time_meter.update(time.time() - end)
        logger.info(
            'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f})'
            ' | Logp(z) {:.6f}({:.6f}) | DeltaLogp {:.6f}({:.6f})'.format(
                epoch, time_meter.val, time_meter.avg, loss_meter.val,
                loss_meter.avg, logpz_meter.val, logpz_meter.avg,
                delta_logp_meter.val, delta_logp_meter.avg))

        # generate samples
        if epoch % args.val_freq == 0:
            print('Start testing the model at epoch {}'.format(epoch))
            model.eval()
            with torch.no_grad():
                _, samples = model.sample(args.val_batchsize,
                                          args.tr_max_sample_points,
                                          truncate_std=None,
                                          gpu=device)
                #samples = model.inverse(torch.randn(args.val_batchsize, args.tr_max_sample_points*3).to(device))
            test_path = os.path.join('checkpoints', args.save, 'test_results/')
            if not os.path.isdir(test_path):
                os.mkdir(test_path)
            elif epoch == 0:
                files = glob.glob(test_path + '*.npy')
                for f in files:
                    os.remove(f)
            np.save(os.path.join(test_path, 'samples_' + str(epoch) + '.npy'),
                    samples.detach().cpu().numpy())

            # save the recent model (should save the best one)
            torch.save(
                model.state_dict(),
                os.path.join('checkpoints', args.save, 'models/model.t7'))
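
The epoch loop above also calls `get_lipschitz_constants` and `pretty_repr`, which are not shown. A guess at the former, collecting whatever per-layer Lipschitz estimate the normalized layers expose; the `scale` attribute name is an assumption:

def get_lipschitz_constants(model):
    # Hypothetical: gather per-layer Lipschitz/scaling estimates for logging.
    lipschitz = []
    for m in model.modules():
        if hasattr(m, 'scale'):
            lipschitz.append(float(m.scale))
    return lipschitz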
Example #9
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)
        if "iter" in checkpt.keys():
            itr = checkpt["iter"]
        if "last_epoch" in checkpt.keys():
            args.begin_epoch = checkpt["last_epoch"] + 1

    if torch.cuda.is_available() and not args.use_cpu:
        aug_model = torch.nn.DataParallel(aug_model).cuda()

    # For visualization.

    time_meter = utils.RunningAverageMeter(RUNNINGAVE_PARAM)
    loss_meter = utils.RunningAverageMeter(RUNNINGAVE_PARAM)
    steps_meter = utils.RunningAverageMeter(RUNNINGAVE_PARAM)
    grad_meter = utils.RunningAverageMeter(RUNNINGAVE_PARAM)
    tt_meter = utils.RunningAverageMeter(RUNNINGAVE_PARAM)

    best_loss = float("inf")
    for epoch in range(args.begin_epoch, args.num_epochs + 1):
        aug_model.train()
        for temp_idx, x in enumerate(train_loader):
            ## x is a tuple of (values, times, stdv, masks)
            start = time.time()
            optimizer.zero_grad()

            # cast data and move to device
            x = tuple(map(cvt, x))  # materialize: map() returns a lazy iterator in Python 3
Example #10
def main():
    # parse command line arguments
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('-d', '--dataset', default='celeba', type=str, help='dataset name',
        choices=['celeba'])
    parser.add_argument('-dist', default='normal', type=str, choices=['normal', 'laplace', 'flow'])
    parser.add_argument('-n', '--num-epochs', default=50, type=int, help='number of training epochs')
    parser.add_argument('-b', '--batch-size', default=2048, type=int, help='batch size')
    parser.add_argument('-l', '--learning-rate', default=1e-3, type=float, help='learning rate')
    parser.add_argument('-z', '--latent-dim', default=100, type=int, help='size of latent dimension')
    parser.add_argument('--beta', default=1, type=float, help='ELBO penalty term')
    parser.add_argument('--beta_sens', default=20, type=float, help='Relative importance of predicting sensitive attributes')
    #parser.add_argument('--sens_idx', default=[13, 15, 20], type=list, help='Relative importance of predicting sensitive attributes')
    parser.add_argument('--tcvae', action='store_true')
    parser.add_argument('--exclude-mutinfo', action='store_true')
    parser.add_argument('--beta-anneal', action='store_true')
    parser.add_argument('--lambda-anneal', action='store_true')
    parser.add_argument('--mss', action='store_true', help='use the improved minibatch estimator')
    parser.add_argument('--conv', action='store_true')
    parser.add_argument('--clf_samps', action='store_true')
    parser.add_argument('--clf_means', action='store_false', dest='clf_samps')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--visdom', action='store_true', help='whether plotting in visdom is desired')
    parser.add_argument('--save', default='betatcvae-celeba')
    parser.add_argument('--log_freq', default=200, type=int, help='num iterations per log')
    parser.add_argument('--audit', action='store_true',
            help='after each epoch, audit the repr wrt fair clf task')
    args = parser.parse_args()
    print(args)
    
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    writer = SummaryWriter(args.save)
    writer.add_text('args', json.dumps(vars(args), sort_keys=True, indent=4))

    log_file = os.path.join(args.save, 'train.log')
    if os.path.exists(log_file):
        os.remove(log_file)

    print(vars(args))
    print(vars(args), file=open(log_file, 'w'))

    torch.cuda.set_device(args.gpu)

    # data loader
    loaders = setup_data_loaders(args, use_cuda=True)

    # setup the VAE
    if args.dist == 'normal':
        prior_dist = dist.Normal()
        q_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
        q_dist = dist.Laplace()
    elif args.dist == 'flow':
        prior_dist = FactorialNormalizingFlow(dim=args.latent_dim, nsteps=32)
        q_dist = dist.Normal()

    x_dist = dist.Normal() if args.dataset == 'celeba' else dist.Bernoulli()
    a_dist = dist.Bernoulli()
    vae = SensVAE(z_dim=args.latent_dim, use_cuda=True, prior_dist=prior_dist, 
            q_dist=q_dist, include_mutinfo=not args.exclude_mutinfo, 
            tcvae=args.tcvae, conv=args.conv, mss=args.mss, 
            n_chan=3 if args.dataset == 'celeba' else 1, sens_idx=SENS_IDX,
            x_dist=x_dist, a_dist=a_dist, clf_samps=args.clf_samps)

    if args.audit:
        audit_label_fn = get_label_fn(
                dict(data=dict(name='celeba', label_fn='H'))
                )
        audit_repr_fns = dict()
        audit_attr_fns = dict()
        audit_models = dict()
        audit_train_metrics = dict()
        audit_validation_metrics = dict()
        for attr_fn_name in CELEBA_SENS_IDX.keys():
            model = MLPClassifier(args.latent_dim, 1000, 2)
            model.cuda()
            audit_models[attr_fn_name] = model
            audit_repr_fns[attr_fn_name] = get_repr_fn(
                dict(data=dict(
                    name='celeba', repr_fn='remove_all', attr_fn=attr_fn_name))
                )
            audit_attr_fns[attr_fn_name] = get_attr_fn(
                dict(data=dict(name='celeba', attr_fn=attr_fn_name))
                )

    # setup the optimizer
    optimizer = optim.Adam(vae.parameters(), lr=args.learning_rate)
    if args.audit:
        Adam = optim.Adam
        audit_optimizers = dict()
        for k, v in audit_models.items():
            audit_optimizers[k] = Adam(v.parameters(), lr=args.learning_rate)


    # setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom(env=args.save, port=3776)

    train_elbo = []
    train_tc = []

    # training loop
    dataset_size = len(loaders['train'].dataset)
    num_iterations = len(loaders['train']) * args.num_epochs
    iteration = 0
    # initialize loss accumulator
    elbo_running_mean = utils.RunningAverageMeter()
    tc_running_mean = utils.RunningAverageMeter()
    clf_acc_meters = {'clf_acc{}'.format(s): utils.RunningAverageMeter() for s in vae.sens_idx}

    val_elbo_running_mean = utils.RunningAverageMeter()
    val_tc_running_mean = utils.RunningAverageMeter()
    val_clf_acc_meters = {'val_clf_acc{}'.format(s): utils.RunningAverageMeter() for s in vae.sens_idx}


    while iteration < num_iterations:
        bar = tqdm(range(len(loaders['train'])))
        for i, (x, a) in enumerate(loaders['train']):
            bar.update()
            iteration += 1
            batch_time = time.time()
            vae.train()
            #anneal_kl(args, vae, iteration)  # TODO try annealing beta/beta_sens
            vae.beta = args.beta
            vae.beta_sens = args.beta_sens
            optimizer.zero_grad()
            # transfer to GPU
            x = x.cuda(non_blocking=True)  # `async` is a reserved word in Python 3.7+; use non_blocking
            a = a.float()
            a = a.cuda(non_blocking=True)
            # wrap the mini-batch in a PyTorch Variable
            x = Variable(x)
            a = Variable(a)
            # do ELBO gradient and accumulate loss
            obj, elbo, metrics = vae.elbo(x, a, dataset_size)
            if utils.isnan(obj).any():
                raise ValueError('NaN spotted in objective.')
            obj.mean().mul(-1).backward()
            elbo_running_mean.update(elbo.mean().data.item())
            tc_running_mean.update(metrics['tc'])
            for (s, meter), (_, acc) in zip(clf_acc_meters.items(), metrics.items()):
                clf_acc_meters[s].update(acc.data.item())
            optimizer.step()

            if args.audit:
                for model in audit_models.values():
                    model.train()
                # now re-encode x and take a step to train each audit classifier
                for opt in audit_optimizers.values():
                    opt.zero_grad()
                with torch.no_grad():
                    zs, z_params = vae.encode(x)
                    if args.clf_samps:
                        z = zs
                    else:
                        z_mu = z_params.select(-1, 0)
                        z = z_mu
                    a_all = a
                for subgroup, model in audit_models.items():
                    # noise out sensitive dims of latent code
                    z_ = z.clone()
                    a_all_ = a_all.clone()
                    # subsample to just sens attr of interest for this subgroup
                    a_ = audit_attr_fns[subgroup](a_all_)
                    # noise out sensitive dims for this subgroup
                    z_ = audit_repr_fns[subgroup](z_, None, None)
                    y_ = audit_label_fn(a_all_).long()

                    loss, _, metrics = model(z_, y_, a_)
                    loss.backward()
                    audit_optimizers[subgroup].step()
                    metrics_dict = {}
                    metrics_dict.update(loss=loss.detach().item())
                    for k, v in metrics.items():
                        if v.numel() > 1:
                            k += '-avg'
                            v = v.float().mean()
                        metrics_dict.update({k:v.detach().item()})
                    audit_train_metrics[subgroup] = metrics_dict

            # report training diagnostics
            if iteration % args.log_freq == 0:
                if args.audit:
                    for subgroup, metrics in audit_train_metrics.items():
                        for metric_name, metric_value in metrics.items():
                            writer.add_scalar(
                                    '{}/{}'.format(subgroup, metric_name),
                                    metric_value, iteration)

                train_elbo.append(elbo_running_mean.avg)
                writer.add_scalar('train_elbo', elbo_running_mean.avg, iteration)
                train_tc.append(tc_running_mean.avg)
                writer.add_scalar('train_tc', tc_running_mean.avg, iteration)
                msg = '[iteration %03d] time: %.2f \tbeta %.2f \tlambda %.2f training ELBO: %.4f (%.4f) training TC %.4f (%.4f)' % (
                    iteration, time.time() - batch_time, vae.beta, vae.lamb,
                    elbo_running_mean.val, elbo_running_mean.avg,
                    tc_running_mean.val, tc_running_mean.avg)
                for k, v in clf_acc_meters.items():
                    msg += ' {}: {:.2f}'.format(k, v.avg)
                    writer.add_scalar(k, v.avg, iteration)
                print(msg)
                print(msg, file=open(log_file, 'a'))

                vae.eval()
                ################################################################
                # evaluate validation metrics on vae and auditors
                for x, a in loaders['validation']:
                    # transfer to GPU
                    x = x.cuda(non_blocking=True)
                    a = a.float()
                    a = a.cuda(non_blocking=True)
                    # wrap the mini-batch in a PyTorch Variable
                    x = Variable(x)
                    a = Variable(a)
                    # do ELBO gradient and accumulate loss
                    obj, elbo, metrics = vae.elbo(x, a, dataset_size)
                    if utils.isnan(obj).any():
                        raise ValueError('NaN spotted in objective.')
                    #
                    val_elbo_running_mean.update(elbo.mean().data.item())
                    val_tc_running_mean.update(metrics['tc'])
                    for (s, meter), (_, acc) in zip(
                            val_clf_acc_meters.items(), metrics.items()):
                        val_clf_acc_meters[s].update(acc.data.item())

                if args.audit:
                    for model in audit_models.values():
                        model.eval()
                    with torch.no_grad():
                        zs, z_params = vae.encode(x)
                        if args.clf_samps:
                            z = zs
                        else:
                            z_mu = z_params.select(-1, 0)
                            z = z_mu
                        a_all = a
                    for subgroup, model in audit_models.items():
                        # noise out sensitive dims of latent code
                        z_ = z.clone()
                        a_all_ = a_all.clone()
                        # subsample to just sens attr of interest for this subgroup
                        a_ = audit_attr_fns[subgroup](a_all_)
                        # noise out sensitive dims for this subgroup
                        z_ = audit_repr_fns[subgroup](z_, None, None)
                        y_ = audit_label_fn(a_all_).long()

                        loss, _, metrics = model(z_, y_, a_)
                        loss.backward()
                        audit_optimizers[subgroup].step()
                        metrics_dict = {}
                        metrics_dict.update(val_loss=loss.detach().item())
                        for k, v in metrics.items():
                            k = 'val_' + k  # denote a validation metric
                            if v.numel() > 1:
                                k += '-avg'
                                v = v.float().mean()
                            metrics_dict.update({k:v.detach().item()})
                        audit_validation_metrics[subgroup] = metrics_dict

                # after iterating through validation set, write summaries
                for subgroup, metrics in audit_validation_metrics.items():
                    for metric_name, metric_value in metrics.items():
                        writer.add_scalar(
                                '{}/{}'.format(subgroup, metric_name),
                                metric_value, iteration)
                writer.add_scalar('val_elbo', val_elbo_running_mean.avg, iteration)
                writer.add_scalar('val_tc', val_tc_running_mean.avg, iteration)
                for k, v in val_clf_acc_meters.items():
                    writer.add_scalar(k, v.avg, iteration)

                ################################################################
                # finally, plot training and test ELBOs
                if args.visdom:
                    display_samples(vae, x, vis)
                    plot_elbo(train_elbo, vis)
                    plot_tc(train_tc, vis)

                utils.save_checkpoint({
                    'state_dict': vae.state_dict(),
                    'args': args}, args.save, iteration // len(loaders['train']))
                eval('plot_vs_gt_' + args.dataset)(vae, loaders['train'].dataset,
                    os.path.join(args.save, 'gt_vs_latent_{:05d}.png'.format(iteration)))

    # Report statistics after training
    vae.eval()
    utils.save_checkpoint({
        'state_dict': vae.state_dict(),
        'args': args}, args.save, 0)
    dataset_loader = DataLoader(loaders['train'].dataset, batch_size=1000, num_workers=1, shuffle=False)
    if False:
        logpx, dependence, information, dimwise_kl, analytical_cond_kl, marginal_entropies, joint_entropy = \
            elbo_decomposition(vae, dataset_loader)
        torch.save({
            'logpx': logpx,
            'dependence': dependence,
            'information': information,
            'dimwise_kl': dimwise_kl,
            'analytical_cond_kl': analytical_cond_kl,
            'marginal_entropies': marginal_entropies,
            'joint_entropy': joint_entropy
        }, os.path.join(args.save, 'elbo_decomposition.pth'))
    eval('plot_vs_gt_' + args.dataset)(vae, dataset_loader.dataset, os.path.join(args.save, 'gt_vs_latent.png'))

    for file in [open(os.path.join(args.save, 'done'), 'w'), sys.stdout]:
        print('done', file=file)

    return vae
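
`utils.save_checkpoint(state, save_dir, epoch)` is used throughout this listing. A minimal sketch of such a helper; the file-name pattern is an assumption:

import os

import torch


def save_checkpoint(state, save_dir, epoch):
    # Hypothetical: serialize the checkpoint dict under the save directory.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(state, os.path.join(save_dir, 'checkpt-%04d.pth' % epoch))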
Example #11
def train(args, model, growth_model):
    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    #optimizer = optim.Adam(set(model.parameters()) | set(growth_model.parameters()), 
    optimizer = optim.Adam(model.parameters(), 
                           lr=args.lr, weight_decay=args.weight_decay)
    #growth_optimizer = optim.Adam(growth_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
    best_loss = float('inf')
    model.train()
    growth_model.eval()
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()
        #growth_optimizer.zero_grad()

        ### Train
        if args.spectral_norm: spectral_norm_power_iteration(model, 1)
        #if args.spectral_norm: spectral_norm_power_iteration(growth_model, 1)

        loss = compute_loss(args, model, growth_model)
        loss_meter.update(loss.item())

        if len(regularization_coeffs) > 0:
            # Only regularize on the last timepoint
            reg_states = get_regularization(model, regularization_coeffs)
            reg_loss = sum(
                reg_state * coeff for reg_state, coeff in zip(reg_states, regularization_coeffs) if coeff != 0
            )
            loss = loss + reg_loss

        #if len(growth_regularization_coeffs) > 0:
        #    growth_reg_states = get_regularization(growth_model, growth_regularization_coeffs)
        #    reg_loss = sum(
        #        reg_state * coeff for reg_state, coeff in zip(growth_reg_states, growth_regularization_coeffs) if coeff != 0
        #    )
        #    loss2 = loss2 + reg_loss

        total_time = count_total_time(model)
        nfe_forward = count_nfe(model)

        loss.backward()
        #loss2.backward()
        optimizer.step()
        #growth_optimizer.step()

        ### Eval
        nfe_total = count_nfe(model)
        nfe_backward = nfe_total - nfe_forward
        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)
        time_meter.update(time.time() - end)
        tt_meter.update(total_time)

        log_message = (
            'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})'
            ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format(
                itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg, nfef_meter.val, nfef_meter.avg,
                nfeb_meter.val, nfeb_meter.avg, tt_meter.val, tt_meter.avg
            )
        )
        if len(regularization_coeffs) > 0:
            log_message = append_regularization_to_log(log_message, regularization_fns, reg_states)

        logger.info(log_message)

        if itr % args.val_freq == 0 or itr == args.niters:
            with torch.no_grad():
                model.eval()
                growth_model.eval()
                test_loss = compute_loss(args, model, growth_model)
                test_nfe = count_nfe(model)
                log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(itr, test_loss, test_nfe)
                logger.info(log_message)

                if test_loss.item() < best_loss:
                    best_loss = test_loss.item()
                    utils.makedirs(args.save)
                    torch.save({
                        'args': args,
                        'state_dict': model.state_dict(),
                        'growth_state_dict': growth_model.state_dict(),
                    }, os.path.join(args.save, 'checkpt.pth'))
                model.train()

        if itr % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()
                for i, tp in enumerate(timepoints):
                    p_samples = viz_sampler(tp)
                    sample_fn, density_fn = get_transforms(model, int_tps[:i+1])
                    #growth_sample_fn, growth_density_fn = get_transforms(growth_model, int_tps[:i+1])
                    plt.figure(figsize=(9, 3))
                    visualize_transform(
                        p_samples, torch.randn, standard_normal_logprob, transform=sample_fn, inverse_transform=density_fn,
                        samples=True, npts=100, device=device
                    )
                    fig_filename = os.path.join(args.save, 'figs', '{:04d}_{:01d}.jpg'.format(itr, i))
                    utils.makedirs(os.path.dirname(fig_filename))
                    plt.savefig(fig_filename)
                    plt.close()

                    #visualize_transform(
                    #    p_samples, torch.rand, uniform_logprob, transform=growth_sample_fn, 
                    #    inverse_transform=growth_density_fn,
                    #    samples=True, npts=800, device=device
                    #)

                    #fig_filename = os.path.join(args.save, 'growth_figs', '{:04d}_{:01d}.jpg'.format(itr, i))
                    #utils.makedirs(os.path.dirname(fig_filename))
                    #plt.savefig(fig_filename)
                    #plt.close()
                model.train()

        """
        if itr % args.viz_freq_growth == 0:
            with torch.no_grad():
                growth_model.eval()
                # Visualize growth transform
                growth_filename = os.path.join(args.save, 'growth', '{:04d}.jpg'.format(itr))
                utils.makedirs(os.path.dirname(growth_filename))
                visualize_growth(growth_model, data, labels, npts=200, device=device)
                plt.savefig(growth_filename)
                plt.close()
                growth_model.train()
        """

        end = time.time()
    logger.info('Training has finished.')
Example #12
    return loss


if __name__ == '__main__':

    model = construct_model().to(device)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(
        count_parameters(model)))

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.98)
    loss_meter = utils.RunningAverageMeter(0.98)

    end = time.time()
    best_loss = float('inf')
    model.train()
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()

        loss = compute_loss(args, model)
        loss_meter.update(loss.item())

        loss.backward()
        optimizer.step()

        time_meter.update(time.time() - end)
Example #13
def main():
    # parse command line arguments
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('-d',
                        '--dataset',
                        default='faces',
                        type=str,
                        help='dataset name',
                        choices=['shapes', 'faces'])
    parser.add_argument('-dist',
                        default='normal',
                        type=str,
                        choices=['normal', 'laplace', 'flow'])
    parser.add_argument('-x_dist',
                        default='normal',
                        type=str,
                        choices=['normal', 'bernoulli'])
    parser.add_argument('-n',
                        '--num-epochs',
                        default=50,
                        type=int,
                        help='number of training epochs')
    parser.add_argument('-b',
                        '--batch-size',
                        default=2048,
                        type=int,
                        help='batch size')
    parser.add_argument('-l',
                        '--learning-rate',
                        default=1e-3,
                        type=float,
                        help='learning rate')
    parser.add_argument('-z',
                        '--latent-dim',
                        default=10,
                        type=int,
                        help='size of latent dimension')
    parser.add_argument('--beta',
                        default=1,
                        type=float,
                        help='ELBO penalty term')
    parser.add_argument('--tcvae', action='store_true')
    parser.add_argument('--exclude-mutinfo', action='store_false')
    parser.add_argument('--beta-anneal', action='store_true')
    parser.add_argument('--lambda-anneal', action='store_true')
    parser.add_argument('--mss',
                        action='store_true',
                        help='use the improved minibatch estimator')
    parser.add_argument('--conv', action='store_true')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--visdom',
                        action='store_true',
                        help='whether plotting in visdom is desired')
    parser.add_argument('--save', default='test2')
    parser.add_argument('--log_freq',
                        default=50,
                        type=int,
                        help='num iterations per log')
    parser.add_argument(
        '-problem',
        default='Climate_ORNL',
        type=str,
        choices=['HEP_SL', 'Climate_ORNL', 'Climate_C', 'Nuclear_Physics'])
    parser.add_argument('--VIB', action='store_true', help='VIB regression')
    parser.add_argument('--UQ',
                        action='store_true',
                        help='Uncertainty Quantification - likelihood')
    parser.add_argument('-name_S',
                        '--name_save',
                        default=[],
                        type=str,
                        help='name to save file')
    parser.add_argument('--classification', action='store_true')
    parser.add_argument('--Func_reg', action='store_true')

    args = parser.parse_args()

    torch.cuda.set_device(args.gpu)

    # data loader
    train_loader = setup_data_loaders(args, use_cuda=True)

    # setup the VAE
    if args.dist == 'normal':
        prior_dist = dist.Normal()
        q_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
        q_dist = dist.Laplace()
    elif args.dist == 'flow':
        prior_dist = FactorialNormalizingFlow(dim=args.latent_dim, nsteps=32)
        q_dist = dist.Normal()

    # setup the likelihood distribution
    if args.x_dist == 'normal':
        x_dist = dist.Normal()
    elif args.x_dist == 'bernoulli':
        x_dist = dist.Bernoulli()
    else:
        raise ValueError('x_dist can be Normal or Bernoulli')

    vae = VAE(z_dim=args.latent_dim,
              beta=args.beta,
              use_cuda=True,
              prior_dist=prior_dist,
              q_dist=q_dist,
              x_dist=x_dist,
              x_dist_name=args.x_dist,
              include_mutinfo=not args.exclude_mutinfo,
              tcvae=args.tcvae,
              conv=args.conv,
              mss=args.mss,
              problem=args.problem,
              VIB=args.VIB,
              UQ=args.UQ,
              classification=args.classification)

    if (args.Func_reg):
        args.latent_dim2 = 4
        args.beta2 = 0.0
        prior_dist2 = dist.Normal()
        q_dist2 = dist.Normal()
        x_dist2 = dist.Normal()
        args.x_dist2 = dist.Normal()
        args.tcvae2 = False
        args.conv2 = False
        args.problem2 = 'Climate_ORNL'
        args.VIB2 = True
        args.UQ2 = False
        args.classification2 = False

        vae2 = VAE(z_dim=args.latent_dim2,
                   beta=args.beta2,
                   use_cuda=True,
                   prior_dist=prior_dist2,
                   q_dist=q_dist2,
                   x_dist=x_dist2,
                   x_dist_name=args.x_dist2,
                   include_mutinfo=not args.exclude_mutinfo,
                   tcvae=args.tcvae2,
                   conv=args.conv2,
                   mss=args.mss,
                   problem=args.problem2,
                   VIB=args.VIB2,
                   UQ=args.UQ2,
                   classification=args.classification2)

    # setup the optimizer
    #optimizer = optim.Adam(vae.parameters(), lr=args.learning_rate)
    if (args.Func_reg):
        params = list(vae.parameters()) + list(vae2.parameters())
        optimizer = optim.RMSprop(params, lr=args.learning_rate)
    else:
        optimizer = optim.RMSprop(vae.parameters(), lr=args.learning_rate)
    # setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom(env=args.save, port=4500)

    train_elbo = []
    train_rmse = []
    train_mae = []
    train_elbo1 = []
    train_elbo2 = []
    train_elbo3 = []
    train_elbo4 = []
    train_rmse2 = []
    train_mae2 = []
    # training loop
    dataset_size = len(train_loader.dataset)
    num_iterations = len(train_loader) * args.num_epochs
    print("num_iteration", len(train_loader), args.num_epochs)
    iteration = 0
    print("likelihood function", args.x_dist, x_dist)

    train_iter = iter(train_loader)
    images = next(train_iter)  # iterator.next() was removed in Python 3

    img_max = train_loader.dataset.__getmax__()

    # initialize loss accumulator
    elbo_running_mean = utils.RunningAverageMeter()
    elbo_running_rmse = utils.RunningAverageMeter()
    elbo_running_mae = utils.RunningAverageMeter()
    elbo_running_mean1 = utils.RunningAverageMeter()
    elbo_running_mean2 = utils.RunningAverageMeter()
    elbo_running_mean3 = utils.RunningAverageMeter()
    elbo_running_mean4 = utils.RunningAverageMeter()
    elbo_running_rmse2 = utils.RunningAverageMeter()
    elbo_running_mae2 = utils.RunningAverageMeter()
    #plot the data to visualize

    x_test = train_loader.dataset.imgs_test
    x_train = train_loader.dataset.imgs

    def count_parameters(model):
        trainable = sum(p.numel() for p in model.parameters()
                        if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        return (trainable, total)

    while iteration < num_iterations:
        for i, xy in enumerate(train_loader):
            iteration += 1
            batch_time = time.time()
            vae.train()
            #anneal_kl(args, vae, iteration)
            optimizer.zero_grad()
            # transfer to GPU
            if (args.problem == 'HEP_SL'):
                x = xy[0]
                x = x.float()
                x = x.cuda()
                x = Variable(x)

                y = xy[1]
                y = y.cuda()
                y = Variable(y)

                label = xy[2]
                label = label.cuda()
                label = Variable(label)

            # Get the Training Objective
            obj, elbo, x_mean_pred, z_params1, _, _ = vae.elbo(
                x, y, label, dataset_size)
            if utils.isnan(obj).any():
                raise ValueError('NaN spotted in objective.')

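            # maximize the objective by backpropagating its negated mean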
            obj.mean().mul(-1).backward()
            elbo_running_mean.update(elbo.mean().data)
            optimizer.step()

            # report training diagnostics
            if iteration % args.log_freq == 0:
                train_elbo.append(elbo_running_mean.avg)

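                # A: model predictions, B: targets (class labels, regression targets y, or the inputs x)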
                if (args.VIB):
                    if not args.classification:
                        if (args.UQ):
                            A = x_mean_pred.cpu().data.numpy()[:, :, 0]
                        else:
                            A = x_mean_pred.cpu().data.numpy()
                        B = y.cpu().data.numpy()
                    else:
                        A = x_mean_pred.cpu().data.numpy()
                        B = label.cpu().data.numpy()
                else:
                    A = x_mean_pred.cpu().data.numpy()
                    B = x.cpu().data.numpy()

                rmse = np.sqrt((np.square(A - B)).mean(axis=None))
                mae = np.abs(A - B).mean(axis=None)

                elbo_running_rmse.update(rmse)
                elbo_running_mae.update(mae)

                train_rmse.append(elbo_running_rmse.avg)
                train_mae.append(elbo_running_mae.avg)

                print(
                    '[iteration %03d] time: %.2f \tbeta %.2f \tlambda %.2f training ELBO: %.4f (%.4f) RMSE: %.4f (%.4f) MAE: %.4f (%.4f)'
                    % (iteration, time.time() - batch_time, vae.beta, vae.lamb,
                       elbo_running_mean.val, elbo_running_mean.avg,
                       elbo_running_rmse.val, elbo_running_rmse.avg,
                       elbo_running_mae.val, elbo_running_mae.avg))

                utils.save_checkpoint(
                    {
                        'state_dict': vae.state_dict(),
                        'args': args
                    }, args.save, 0)

                print("max pred:", np.max(A), "max input:", np.max(B),
                      "min pred:", np.min(A), "min input:", np.min(B))

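    # test tensors below are only prepared for the HEP_SL problem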
    if (args.problem == 'HEP_SL'):
        x_test = train_loader.dataset.imgs_test
        x_test = x_test.cuda()
        y_test = train_loader.dataset.lens_p_test
        y_test = y_test.cuda()
        label_test = train_loader.dataset.label_test
        label_test = label_test.cuda()

    utils.save_checkpoint({
        'state_dict': vae.state_dict(),
        'args': args
    }, args.save, 0)
    name_save = args.name_save

    Viz_plot.Convergence_plot(train_elbo, train_rmse, train_mae, name_save,
                              args.save)
    Viz_plot.display_samples_pred_mlp(vae, x_test, y_test, label_test,
                                      args.problem, args.VIB, name_save,
                                      args.UQ, args.classification, args.save,
                                      img_max)

    # Report statistics after training
    vae.eval()
    return vae
Example #14
    wandb.config.checkpoint_id = checkpoint_id
    wandb.config.update(args)

if args.manual_seed is None:
    args.manual_seed = random.randint(1, 100000)
random.seed(args.manual_seed)
torch.manual_seed(args.manual_seed)
np.random.seed(args.manual_seed)

if args.cuda:
    # gpu device number
    torch.cuda.set_device(args.gpu_num)

kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {}

nfef_meter = utils.RunningAverageMeter(0.93)
nfeb_meter = utils.RunningAverageMeter(0.93)


def run(args, kwargs):
    # ==================================================================================================================
    # SNAPSHOTS
    # ==================================================================================================================
    args.model_signature = str(datetime.datetime.now())[0:19].replace(' ', '_')
    args.model_signature = args.model_signature.replace(':', '_')

    snapshots_path = os.path.join(args.out_dir, 'vae_' + args.dataset + '_')
    snap_dir = snapshots_path + args.flow

    if args.flow != 'no_flow':
        snap_dir += '_' + 'num_flows_' + str(args.num_flows)
Example #15
def main():
    # parse command line arguments
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument(
        '-d',
        '--dataset',
        default='shapes',
        type=str,
        help='dataset name',
        choices=['shapes', 'faces', 'celeba', 'cars3d', '3dchairs'])
    parser.add_argument('-dist',
                        default='normal',
                        type=str,
                        choices=['normal', 'lpnorm', 'lpnested'])
    parser.add_argument('-n',
                        '--num-epochs',
                        default=50,
                        type=int,
                        help='number of training epochs')
    parser.add_argument(
        '--num-iterations',
        default=0,
        type=int,
        help='number of iterations (overrides number of epochs if >0)')
    parser.add_argument('-b',
                        '--batch-size',
                        default=2048,
                        type=int,
                        help='batch size')
    parser.add_argument('-l',
                        '--learning-rate',
                        default=1e-3,
                        type=float,
                        help='learning rate')
    parser.add_argument('-z',
                        '--latent-dim',
                        default=10,
                        type=int,
                        help='size of latent dimension')
    parser.add_argument('-p',
                        '--pnorm',
                        default=4.0 / 3.0,
                        type=float,
                        help='p value of the Lp-norm')
    parser.add_argument(
        '--pnested',
        default='',
        type=str,
        help=
        'nested list representation of the Lp-nested prior, e.g. [2.1, [ [2.2, [ [1.0], [1.0], [1.0], [1.0] ] ], [2.2, [ [1.0], [1.0], [1.0], [1.0] ] ], [2.2, [ [1.0], [1.0], [1.0], [1.0] ] ] ] ]'
    )
    parser.add_argument(
        '--isa',
        default='',
        type=str,
        help=
        'shorthand notation of ISA Lp-nested norm, e.g. [2.1, [(2.2, 4), (2.2, 4), (2.2, 4)]]'
    )
    parser.add_argument('--p0', default=2.0, type=float, help='p0 of ISA')
    parser.add_argument('--p1', default=2.1, type=float, help='p1 of ISA')
    parser.add_argument('--n1', default=6, type=int, help='n1 of ISA')
    parser.add_argument('--p2', default=2.1, type=float, help='p2 of ISA')
    parser.add_argument('--n2', default=6, type=int, help='n2 of ISA')
    parser.add_argument('--p3', default=2.1, type=float, help='p3 of ISA')
    parser.add_argument('--n3', default=6, type=int, help='n3 of ISA')
    parser.add_argument('--scale',
                        default=1.0,
                        type=float,
                        help='scale of LpNested distribution')
    parser.add_argument('--q-dist',
                        default='normal',
                        type=str,
                        choices=['normal', 'laplace'])
    parser.add_argument('--x-dist',
                        default='bernoulli',
                        type=str,
                        choices=['normal', 'bernoulli'])
    parser.add_argument('--beta',
                        default=1,
                        type=float,
                        help='ELBO penalty term')
    parser.add_argument('--tcvae', action='store_true')
    parser.add_argument('--exclude-mutinfo', action='store_true')
    parser.add_argument('--beta-anneal', action='store_true')
    parser.add_argument('--lambda-anneal', action='store_true')
    parser.add_argument('--mss',
                        action='store_true',
                        help='use the improved minibatch estimator')
    parser.add_argument('--conv', action='store_true')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--visdom',
                        action='store_true',
                        help='whether plotting in visdom is desired')
    parser.add_argument('--save', default='test1')
    parser.add_argument('--id', default='1')
    parser.add_argument(
        '--seed',
        default=-1,
        type=int,
        help=
        'seed for pytorch and numpy random number generator to allow reproducibility (default/-1: use random seed)'
    )
    parser.add_argument('--log_freq',
                        default=200,
                        type=int,
                        help='num iterations per log')
    parser.add_argument('--use-mse-loss', action='store_true')
    parser.add_argument('--mse-sigma',
                        default=0.01,
                        type=float,
                        help='sigma of mean squared error loss')
    parser.add_argument('--dip', action='store_true', help='use DIP-VAE')
    parser.add_argument('--dip-type',
                        default=1,
                        type=int,
                        help='DIP type (1 or 2)')
    parser.add_argument('--lambda-od',
                        default=2.0,
                        type=float,
                        help='DIP: lambda weight off-diagonal')
    parser.add_argument('--clip',
                        default=0.0,
                        type=float,
                        help='Gradient clipping (0 disabled)')
    parser.add_argument('--test', action='store_true', help='run test')
    parser.add_argument(
        '--trainingsetsize',
        default=0,
        type=int,
        help='Subsample the trainingset (0 use original training data)')
    args = parser.parse_args()

    # initialize seeds for reproducibility
    if not args.seed == -1:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.gpu != -1:
        print('Using CUDA device {}'.format(args.gpu))
        torch.cuda.set_device(args.gpu)
        use_cuda = True
    else:
        print('CUDA disabled')
        use_cuda = False

    # data loader
    train_loader = setup_data_loaders(args.dataset,
                                      args.batch_size,
                                      use_cuda=use_cuda,
                                      len_subset=args.trainingsetsize)

    # setup the VAE
    if args.dist == 'normal':
        prior_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
    elif args.dist == 'lpnested':
        if not args.isa == '':
            pnested = parseISA(ast.literal_eval(args.isa))
        elif not args.pnested == '':
            pnested = ast.literal_eval(args.pnested)
        else:
            pnested = parseISA([
                args.p0,
                [(args.p1, args.n1), (args.p2, args.n2), (args.p3, args.n3)]
            ])

        print('using Lp-nested prior, pnested = ({}) {}'.format(
            type(pnested), pnested))
        prior_dist = LpNestedAdapter(p=pnested, scale=args.scale)
        args.latent_dim = prior_dist.dimz()
        print('using Lp-nested prior, changed latent dimension to {}'.format(
            args.latent_dim))
    elif args.dist == 'lpnorm':
        prior_dist = LpNestedAdapter(p=[args.pnorm, [[1.0]] * args.latent_dim],
                                     scale=args.scale)

    if args.q_dist == 'normal':
        q_dist = dist.Normal()
    elif args.q_dist == 'laplace':
        q_dist = dist.Laplace()

    if args.x_dist == 'normal':
        x_dist = dist.Normal(sigma=args.mse_sigma)
    elif args.x_dist == 'bernoulli':
        x_dist = dist.Bernoulli()

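    # DIP-VAE-I weights the diagonal penalty 10x the off-diagonal weight; DIP-VAE-II uses lambda_od for both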
    if args.dip_type == 1:
        lambda_d = 10.0 * args.lambda_od
    else:
        lambda_d = args.lambda_od

    vae = VAE(z_dim=args.latent_dim,
              use_cuda=use_cuda,
              prior_dist=prior_dist,
              q_dist=q_dist,
              x_dist=x_dist,
              include_mutinfo=not args.exclude_mutinfo,
              tcvae=args.tcvae,
              conv=args.conv,
              mss=args.mss,
              dataset=args.dataset,
              mse_sigma=args.mse_sigma,
              DIP=args.dip,
              DIP_type=args.dip_type,
              lambda_od=args.lambda_od,
              lambda_d=lambda_d)

    # setup the optimizer
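    # give the prior's own parameters (e.g. learnable Lp-nested exponents) a separate, smaller learning rate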
    optimizer = optim.Adam([{
        'params': vae.parameters()
    }, {
        'params': prior_dist.parameters(),
        'lr': 5e-4
    }],
                           lr=args.learning_rate)

    # setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom(env=args.save, port=4500)

    train_elbo = []

    # training loop
    dataset_size = len(train_loader.dataset)
    if args.num_iterations == 0:
        num_iterations = len(train_loader) * args.num_epochs
    else:
        num_iterations = args.num_iterations
    iteration = 0
    obj_best_snapshot = float('-inf')
    best_checkpoint_updated = False

    trainingcurve_filename = os.path.join(args.save, 'trainingcurve.csv')
    if not os.path.exists(trainingcurve_filename):
        with open(trainingcurve_filename, 'w') as fd:
            fd.write(
                'iteration,num_iterations,time,elbo_running_mean_val,elbo_running_mean_avg\n'
            )

    # initialize loss accumulator
    elbo_running_mean = utils.RunningAverageMeter()
    nan_detected = False
    while iteration < num_iterations and not nan_detected:
        for i, x in enumerate(train_loader):
            iteration += 1
            batch_time = time.time()
            vae.train()
            anneal_kl(args, vae, iteration)
            optimizer.zero_grad()
            # transfer to GPU
            if use_cuda:
                x = x.cuda(non_blocking=True)  # 'async' was renamed to 'non_blocking'
            # wrap the mini-batch in a PyTorch Variable
            x = Variable(x)
            # do ELBO gradient and accumulate loss
            #with autograd.detect_anomaly():
            obj, elbo, logpx = vae.elbo(prior_dist,
                                        x,
                                        dataset_size,
                                        use_mse_loss=args.use_mse_loss,
                                        mse_sigma=args.mse_sigma)
            if utils.isnan(obj).any():
                print('NaN spotted in objective.')
                print('lpnested: {}'.format(prior_dist.prior.p))
                if iteration > 1:
                    # gradients are only populated after the first backward pass
                    print("gradient abs max {}".format(
                        max([g.abs().max() for g in gradients])))
                #raise ValueError('NaN spotted in objective.')
                nan_detected = True
                break
            elbo_running_mean.update(elbo.mean().item())

            # save checkpoint of best ELBO
            if obj.mean().item() > obj_best_snapshot:
                obj_best_snapshot = obj.mean().item()
                best_checkpoint = {
                    'state_dict': vae.state_dict(),
                    'state_dict_prior_dist': prior_dist.state_dict(),
                    'args': args,
                    'iteration': iteration,
                    'obj': obj_best_snapshot,
                    'elbo': elbo.mean().item(),
                    'logpx': logpx.mean().item()
                }
                best_checkpoint_updated = True

            #with autograd.detect_anomaly():
            obj.mean().mul(-1).backward()

            gradients = list(
                filter(lambda p: p.grad is not None, vae.parameters()))

            if args.clip > 0:
                torch.nn.utils.clip_grad_norm_(vae.parameters(), args.clip)

            optimizer.step()

            # report training diagnostics
            if iteration % args.log_freq == 0:
                train_elbo.append(elbo_running_mean.avg)
                time_ = time.time() - batch_time
                print(
                    '[iteration %03d/%03d] time: %.2f \tbeta %.2f \tlambda %.2f \tobj %.4f \tlogpx %.4f training ELBO: %.4f (%.4f)'
                    % (iteration, num_iterations, time_, vae.beta, vae.lamb,
                       obj.mean().item(), logpx.mean().item(),
                       elbo_running_mean.val, elbo_running_mean.avg))

                p0, p1list = backwardsParseISA(prior_dist.prior.p)
                print('lpnested: {}, {}'.format(p0, p1list))
                print("gradient abs max {}".format(
                    max([g.abs().max() for g in gradients])))

                with open(os.path.join(args.save, 'trainingcurve.csv'),
                          'a') as fd:
                    fd.write('{},{},{},{},{}\n'.format(iteration,
                                                       num_iterations, time_,
                                                       elbo_running_mean.val,
                                                       elbo_running_mean.avg))

                if best_checkpoint_updated:
                    print(
                        'Update best checkpoint [iteration %03d] training ELBO: %.4f'
                        % (best_checkpoint['iteration'],
                           best_checkpoint['elbo']))
                    utils.save_checkpoint(best_checkpoint, args.save, 0)
                    best_checkpoint_updated = False

                vae.eval()
                prior_dist.eval()

                # plot training and test ELBOs
                if args.visdom:
                    if args.dataset == 'celeba':
                        num_channels = 3
                    else:
                        num_channels = 1
                    display_samples(vae, prior_dist, x, vis, num_channels)
                    plot_elbo(train_elbo, vis)

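                # every 10 log intervals, also save a full checkpoint (model, prior, optimizer, RNG state) for resuming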
                if iteration % (10 * args.log_freq) == 0:
                    utils.save_checkpoint(
                        {
                            'state_dict': vae.state_dict(),
                            'state_dict_prior_dist': prior_dist.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'args': args,
                            'iteration': iteration,
                            'obj': obj.mean().item(),
                            'torch_random_state': torch.get_rng_state(),
                            'numpy_random_state': np.random.get_state()
                        },
                        args.save,
                        prefix='latest-optimizer-model-')
                    if not args.dataset == 'celeba' and not args.dataset == '3dchairs':
                        eval('plot_vs_gt_' + args.dataset)(
                            vae, train_loader.dataset,
                            os.path.join(
                                args.save,
                                'gt_vs_latent_{:05d}.png'.format(iteration)))

    # Report statistics of best snapshot after training
    vae.load_state_dict(best_checkpoint['state_dict'])
    prior_dist.load_state_dict(best_checkpoint['state_dict_prior_dist'])

    vae.eval()
    prior_dist.eval()

    if args.dataset == 'shapes':
        data_set = dset.Shapes()
    elif args.dataset == 'faces':
        data_set = dset.Faces()
    elif args.dataset == 'cars3d':
        data_set = dset.Cars3d()
    elif args.dataset == 'celeba':
        data_set = dset.CelebA()
    elif args.dataset == '3dchairs':
        data_set = dset.Chairs()
    else:
        raise ValueError('Unknown dataset ' + str(args.dataset))

    print("loaded dataset {} of size {}".format(args.dataset, len(data_set)))

    dataset_loader = DataLoader(data_set,
                                batch_size=1000,
                                num_workers=0,
                                shuffle=False)

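    # decompose the best snapshot's ELBO into log-likelihood, total correlation (dependence), mutual information, and dimension-wise KL terms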
    logpx, dependence, information, dimwise_kl, analytical_cond_kl, elbo_marginal_entropies, elbo_joint_entropy = \
        elbo_decomposition(vae, prior_dist, dataset_loader)
    torch.save(
        {
            'args': args,
            'logpx': logpx,
            'dependence': dependence,
            'information': information,
            'dimwise_kl': dimwise_kl,
            'analytical_cond_kl': analytical_cond_kl,
            'marginal_entropies': elbo_marginal_entropies,
            'joint_entropy': elbo_joint_entropy
        }, os.path.join(args.save, 'elbo_decomposition.pth'))
    print('logpx: {:.2f}'.format(logpx))
    if not args.dataset == 'celeba' and not args.dataset == '3dchairs':
        eval('plot_vs_gt_' + args.dataset)(vae, dataset_loader.dataset,
                                           os.path.join(
                                               args.save, 'gt_vs_latent.png'))

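        # MIG (mutual information gap) disentanglement metric against the dataset's ground-truth factors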
        metric, metric_marginal_entropies, metric_cond_entropies = eval(
            'disentanglement_metrics.mutual_info_metric_' + args.dataset)(
                vae, dataset_loader.dataset)
        torch.save(
            {
                'args': args,
                'metric': metric,
                'marginal_entropies': metric_marginal_entropies,
                'cond_entropies': metric_cond_entropies,
            }, os.path.join(args.save, 'disentanglement_metric.pth'))
        print('MIG: {:.2f}'.format(metric))

        if args.dist == 'lpnested':
            p0, p1list = backwardsParseISA(prior_dist.prior.p)
            print('p0: {}'.format(p0))
            print('p1: {}'.format(p1list))
            torch.save(
                {
                    'args': args,
                    'logpx': logpx,
                    'dependence': dependence,
                    'information': information,
                    'dimwise_kl': dimwise_kl,
                    'analytical_cond_kl': analytical_cond_kl,
                    'elbo_marginal_entropies': elbo_marginal_entropies,
                    'elbo_joint_entropy': elbo_joint_entropy,
                    'metric': metric,
                    'metric_marginal_entropies': metric_marginal_entropies,
                    'metric_cond_entropies': metric_cond_entropies,
                    'p0': p0,
                    'p1': p1list
                }, os.path.join(args.save, 'combined_data.pth'))
        else:
            torch.save(
                {
                    'args': args,
                    'logpx': logpx,
                    'dependence': dependence,
                    'information': information,
                    'dimwise_kl': dimwise_kl,
                    'analytical_cond_kl': analytical_cond_kl,
                    'elbo_marginal_entropies': elbo_marginal_entropies,
                    'elbo_joint_entropy': elbo_joint_entropy,
                    'metric': metric,
                    'metric_marginal_entropies': metric_marginal_entropies,
                    'metric_cond_entropies': metric_cond_entropies,
                }, os.path.join(args.save, 'combined_data.pth'))

        if args.dist == 'lpnested':
            if args.dataset == 'shapes':
                eval('plot_vs_gt_' + args.dataset)(
                    vae,
                    dataset_loader.dataset,
                    os.path.join(args.save, 'gt_vs_grouped_latent.png'),
                    eval_subspaces=True)

                metric_subspaces, metric_marginal_entropies_subspaces, metric_cond_entropies_subspaces = eval(
                    'disentanglement_metrics.mutual_info_metric_' +
                    args.dataset)(vae,
                                  dataset_loader.dataset,
                                  eval_subspaces=True)
                torch.save(
                    {
                        'args': args,
                        'metric': metric_subspaces,
                        'marginal_entropies':
                        metric_marginal_entropies_subspaces,
                        'cond_entropies': metric_cond_entropies_subspaces,
                    },
                    os.path.join(args.save,
                                 'disentanglement_metric_subspaces.pth'))
                print('MIG grouped by subspaces: {:.2f}'.format(
                    metric_subspaces))

                torch.save(
                    {
                        'args': args,
                        'logpx': logpx,
                        'dependence': dependence,
                        'information': information,
                        'dimwise_kl': dimwise_kl,
                        'analytical_cond_kl': analytical_cond_kl,
                        'elbo_marginal_entropies': elbo_marginal_entropies,
                        'elbo_joint_entropy': elbo_joint_entropy,
                        'metric': metric,
                        'metric_marginal_entropies': metric_marginal_entropies,
                        'metric_cond_entropies': metric_cond_entropies,
                        'metric_subspaces': metric_subspaces,
                        'metric_marginal_entropies_subspaces':
                        metric_marginal_entropies_subspaces,
                        'metric_cond_entropies_subspaces':
                        metric_cond_entropies_subspaces,
                        'p0': p0,
                        'p1': p1list
                    }, os.path.join(args.save, 'combined_data.pth'))

    return vae
Example #16
def main(rank, world_size, args):
    setup(rank, world_size, args.port)

    # setup logger
    if rank == 0:
        utils.makedirs(args.save)
        logger = utils.get_logger(os.path.join(args.save, "logs"))

    def mprint(msg):
        if rank == 0:
            logger.info(msg)

    mprint(args)

    device = torch.device(
        f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')

    if device.type == 'cuda':
        mprint('Found {} CUDA devices.'.format(torch.cuda.device_count()))
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            mprint('{} \t Memory: {:.2f}GB'.format(
                props.name, props.total_memory / (1024**3)))
    else:
        mprint('WARNING: Using device {}'.format(device))

    np.random.seed(args.seed + rank)
    torch.manual_seed(args.seed + rank)
    if device.type == 'cuda':
        torch.cuda.manual_seed(args.seed + rank)

    mprint('Loading dataset {}'.format(args.data))
    # Dataset and hyperparameters
    if args.data == 'cifar10':
        im_dim = 3

        transform_train = transforms.Compose([
            transforms.Resize(args.imagesize),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            add_noise if args.add_noise else identity,
        ])
        transform_test = transforms.Compose([
            transforms.Resize(args.imagesize),
            transforms.ToTensor(),
            add_noise if args.add_noise else identity,
        ])

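        # NOTE: despite args.data == 'cifar10', this branch loads the SVHN train/test splits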
        init_layer = flows.LogitTransform(0.05)
        train_set = vdsets.SVHN(args.dataroot,
                                download=True,
                                split="train",
                                transform=transform_train)
        sampler = torch.utils.data.distributed.DistributedSampler(train_set)
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batchsize,
            sampler=sampler,
        )
        test_loader = torch.utils.data.DataLoader(
            vdsets.SVHN(args.dataroot,
                        download=True,
                        split="test",
                        transform=transform_test),
            batch_size=args.val_batchsize,
            shuffle=False,
        )

    elif args.data == 'mnist':
        im_dim = 1
        init_layer = flows.LogitTransform(1e-6)
        train_set = datasets.MNIST(
            args.dataroot,
            train=True,
            transform=transforms.Compose([
                transforms.Resize(args.imagesize),
                transforms.ToTensor(),
                add_noise if args.add_noise else identity,
            ]))
        sampler = torch.utils.data.distributed.DistributedSampler(train_set)
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batchsize,
            sampler=sampler,
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.dataroot,
                           train=False,
                           transform=transforms.Compose([
                               transforms.Resize(args.imagesize),
                               transforms.ToTensor(),
                               add_noise if args.add_noise else identity,
                           ])),
            batch_size=args.val_batchsize,
            shuffle=False,
        )
    else:
        raise Exception(f'dataset not one of mnist / cifar10, got {args.data}')

    mprint('Dataset loaded.')
    mprint('Creating model.')

    input_size = (args.batchsize, im_dim, args.imagesize, args.imagesize)

    model = MultiscaleFlow(
        input_size,
        block_fn=partial(cpflow_block_fn,
                         block_type=args.block_type,
                         dimh=args.dimh,
                         num_hidden_layers=args.num_hidden_layers,
                         icnn_version=args.icnn,
                         num_pooling=args.num_pooling),
        n_blocks=list(map(int, args.nblocks.split('-'))),
        factor_out=args.factor_out,
        init_layer=init_layer,
        actnorm=args.actnorm,
        fc_end=args.fc_end,
        glow=args.glow,
    )
    model.to(device)

    model = DDP(model, device_ids=[rank], find_unused_parameters=True)
    ema = utils.ExponentialMovingAverage(model)

    mprint(model)
    mprint('EMA: {}'.format(ema))

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           betas=(0.9, 0.99),
                           weight_decay=args.wd)

    # Saving and resuming
    best_test_bpd = math.inf
    begin_epoch = 0

    most_recent_path = os.path.join(args.save, 'models', 'most_recent.pth')
    checkpt_exists = os.path.exists(most_recent_path)
    if checkpt_exists:
        mprint(f"Resuming from {most_recent_path}")

        # deal with data-dependent initialization like actnorm.
        with torch.no_grad():
            x = torch.rand(8, *input_size[1:]).to(device)
            model(x)

        checkpt = torch.load(most_recent_path)
        begin_epoch = checkpt["epoch"] + 1

        model.module.load_state_dict(checkpt["state_dict"])
        ema.set(checkpt['ema'])
        optimizer.load_state_dict(checkpt["opt_state_dict"])
    elif args.resume:
        mprint(f"Resuming from {args.resume}")

        # deal with data-dependent initialization like actnorm.
        with torch.no_grad():
            x = torch.rand(8, *input_size[1:]).to(device)
            model(x)

        checkpt = torch.load(args.resume)
        begin_epoch = checkpt["epoch"] + 1

        model.module.load_state_dict(checkpt["state_dict"])
        ema.set(checkpt['ema'])
        optimizer.load_state_dict(checkpt["opt_state_dict"])

    mprint(optimizer)

    batch_time = utils.RunningAverageMeter(0.97)
    bpd_meter = utils.RunningAverageMeter(0.97)
    gnorm_meter = utils.RunningAverageMeter(0.97)
    cg_meter = utils.RunningAverageMeter(0.97)
    hnorm_meter = utils.RunningAverageMeter(0.97)

    update_lr(optimizer, 0, args)

    # for visualization
    fixed_x = next(iter(train_loader))[0][:8].to(device)
    fixed_z = torch.randn(8,
                          im_dim * args.imagesize * args.imagesize).to(fixed_x)
    if rank == 0:
        utils.makedirs(os.path.join(args.save, 'figs'))
        # visualize(model, fixed_x, fixed_z, os.path.join(args.save, 'figs', 'init.png'))
    for epoch in range(begin_epoch, args.nepochs):
        sampler.set_epoch(epoch)
        flows.CG_ITERS_TRACER.clear()
        flows.HESS_NORM_TRACER.clear()
        mprint('Current LR {}'.format(optimizer.param_groups[0]['lr']))
        train(epoch, train_loader, model, optimizer, bpd_meter, gnorm_meter,
              cg_meter, hnorm_meter, batch_time, ema, device, mprint,
              world_size, args)
        val_time, test_bpd = validate(epoch, model, test_loader, ema, device)
        mprint(
            'Epoch: [{0}]\tTime {1:.2f} | Test bits/dim {test_bpd:.4f}'.format(
                epoch, val_time, test_bpd=test_bpd))

        if rank == 0:
            utils.makedirs(os.path.join(args.save, 'figs'))
            visualize(model, fixed_x, fixed_z,
                      os.path.join(args.save, 'figs', f'{epoch}.png'))

            utils.makedirs(os.path.join(args.save, "models"))
            if test_bpd < best_test_bpd:
                best_test_bpd = test_bpd
                torch.save(
                    {
                        'epoch': epoch,
                        'state_dict': model.module.state_dict(),
                        'opt_state_dict': optimizer.state_dict(),
                        'args': args,
                        'ema': ema,
                        'test_bpd': test_bpd,
                    }, os.path.join(args.save, 'models', 'best_model.pth'))

        if rank == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': model.module.state_dict(),
                    'opt_state_dict': optimizer.state_dict(),
                    'args': args,
                    'ema': ema,
                    'test_bpd': test_bpd,
                }, os.path.join(args.save, 'models', 'most_recent.pth'))

    cleanup()
Example #17
def main():
    args = get_config()

    # prepare path
    utils.clear_dir(args.tmp_path)
    utils.clear_dir(args.save_path)
    print('[*] Clear path')

    # choose device
    if args.gpu >= 0:
        device = torch.device('cuda:%d' % args.gpu)
        print('[*] Choose cuda:%d as device' % args.gpu)
    else:
        device = torch.device('cpu')
        print('[*] Choose cpu as device')

    # load training set and evaluation set
    train_set = PolyMRDataset(size=args.image_size, type=args.dataset)
    train_loader = DataLoader(dataset=train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=0)
    valid_set = PolyMRDataset(size=args.image_size,
                              type=args.dataset,
                              set='val')
    valid_loader = DataLoader(dataset=valid_set,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=0)
    print('[*] Load datasets')
    img_size = train_set.size
    dataset_size_train = len(train_set)
    dataset_size_valid = len(valid_set)

    # setup the model
    vae = ReasonNet(args, device)
    vae = vae.to(device)
    print('[*] Load model')

    # setup the optimizer
    optimizer = optim.Adam(vae.parameters(), lr=args.learning_rate)

    # training loop
    iteration = 0
    # initialize loss accumulator
    elbo_running_mean = utils.RunningAverageMeter()
    # record best elbo and epoch
    best_elbo = -1e6
    best_epoch = -1
    for epoch in range(args.num_epochs):
        vae.train()
        for i, (x, t) in enumerate(train_loader):
            optimizer.zero_grad()
            x = x.view(-1, 1, img_size, img_size).to(device)
            t = t.view(-1, 1, img_size, img_size).to(device)
            obj, recon = vae(x, t, dataset_size_train)
            obj.mul(-1).backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), 10.0)
            optimizer.step()
            iteration += 1

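        # validate on the held-out set and checkpoint whenever the average ELBO improves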
        vae.eval()
        elbo_running_mean.reset()
        with torch.no_grad():
            for i, (x, t) in enumerate(valid_loader):
                x = x.view(-1, 1, img_size, img_size).to(device)
                t = t.view(-1, 1, img_size, img_size).to(device)
                obj, elbo, recon = vae.evaluate(x, t, dataset_size_valid)
                elbo_running_mean.update(elbo)

            avg_elbo = elbo_running_mean.get_avg()['elbo']
            if avg_elbo > best_elbo:
                best_epoch = epoch
                utils.save_checkpoint(
                    {
                        'state_dict': vae.state_dict(),
                        'args': args,
                        'epoch': epoch
                    }, args.save_path)
                best_elbo = avg_elbo

            if epoch % args.log_freq == 0:
                elbo_running_mean.log(epoch, args.num_epochs, best_epoch)
Example #18
def main():
    #os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                                  filepath=os.path.abspath(__file__))

        logger.info(args)

        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log: logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend,
                                       init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(),
                                       rank=env_rank())
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log:
            logger.info("Distributed: success (%d/%d)" %
                        (args.local_rank, distributed.get_world_size()))

    # get device
    # device = torch.device("cuda:%d"%torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
    device = "cpu"
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = [
        'itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time',
        'grad_norm'
    ]
    testcolumns = [
        'wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time',
        'transport_cost'
    ]

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns)
    # model = model.cuda()
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns,
                                                     regularization_fns)

    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log: logger.info(model)
    if write_log:
        logger.info("Number of trainable parameters: {}".format(
            count_parameters(model)))
    if write_log:
        logger.info('Iters per train epoch: {}'.format(len(train_loader)))
    if write_log: logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay,
                              momentum=0.9,
                              nesterov=False)

    # restore parameters
    if args.resume is not None:
        checkpt = torch.load(
            args.resume,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization.
    if write_log:
        fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
    else:
        chkdir = os.path.dirname(args.resume)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] +
                          1)  # not exactly correct

    if args.distributed:
        if write_log: logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, args.num_epochs + 1):
        if not args.validate:
            model.train()

            with open(trainlog, 'a') as f:
                if write_log: csvlogger = csv.DictWriter(f, traincolumns)

                for _, (x, y) in enumerate(train_loader):
                    start = time.time()
                    update_lr(optimizer, itr)
                    optimizer.zero_grad()

                    # cast data and move to device
                    x = add_noise(cvt(x), nbits=args.nbits)
                    #x = x.clamp_(min=0, max=1)
                    # compute loss
                    bpd, (x, z), reg_states = compute_bits_per_dim(x, model)
                    if np.isnan(bpd.data.item()):
                        raise ValueError('model returned nan during training')
                    elif np.isinf(bpd.data.item()):
                        raise ValueError('model returned inf during training')

                    loss = bpd
                    if regularization_coeffs:
                        reg_loss = sum(reg_state * coeff
                                       for reg_state, coeff in zip(
                                           reg_states, regularization_coeffs)
                                       if coeff != 0)
                        loss = loss + reg_loss
                    total_time = count_total_time(model)

                    loss.backward()
                    nfe_opt = count_nfe(model)
                    if write_log: steps_meter.update(nfe_opt)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.max_grad_norm)

                    optimizer.step()

                    itr_time = time.time() - start
                    wall_clock += itr_time

                    batch_size = x.size(0)
                    metrics = torch.tensor([
                        1., batch_size,
                        loss.item(),
                        bpd.item(), nfe_opt, grad_norm, *reg_states
                    ]).float()

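                    # sum per-process metrics across workers; per-GPU averages are recovered by dividing by total_gpus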
                    rv = tuple(torch.tensor(0.) for r in reg_states)

                    total_gpus, batch_total, r_loss, r_bpd, r_nfe, r_grad_norm, *rv = dist_utils.sum_tensor(
                        metrics).cpu().numpy()

                    if write_log:
                        time_meter.update(itr_time)
                        bpd_meter.update(r_bpd / total_gpus)
                        loss_meter.update(r_loss / total_gpus)
                        grad_meter.update(r_grad_norm / total_gpus)
                        tt_meter.update(total_time)

                        fmt = '{:.4f}'
                        logdict = {
                            'itr': itr,
                            'wall': fmt.format(wall_clock),
                            'itr_time': fmt.format(itr_time),
                            'loss': fmt.format(r_loss / total_gpus),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'total_time': fmt.format(total_time),
                            'fe': r_nfe / total_gpus,
                            'grad_norm': fmt.format(r_grad_norm / total_gpus),
                        }
                        if regularization_coeffs:
                            rv = tuple(v_ / total_gpus for v_ in rv)
                            logdict = append_regularization_csv_dict(
                                logdict, regularization_fns, rv)
                        csvlogger.writerow(logdict)

                        if itr % args.log_freq == 0:
                            log_message = (
                                "Itr {:06d} | Wall {:.3e}({:.2f}) | "
                                "Time/Itr {:.2f}({:.2f}) | BPD {:.2f}({:.2f}) | "
                                "Loss {:.2f}({:.2f}) | "
                                "FE {:.0f}({:.0f}) | Grad Norm {:.3e}({:.3e}) | "
                                "TT {:.2f}({:.2f})".format(
                                    itr, wall_clock, wall_clock / (itr + 1),
                                    time_meter.val, time_meter.avg,
                                    bpd_meter.val, bpd_meter.avg,
                                    loss_meter.val, loss_meter.avg,
                                    steps_meter.val, steps_meter.avg,
                                    grad_meter.val, grad_meter.avg,
                                    tt_meter.val, tt_meter.avg))
                            if regularization_coeffs:
                                log_message = append_regularization_to_log(
                                    log_message, regularization_fns, rv)
                            logger.info(log_message)

                    itr += 1

        # compute test loss
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            torch.save(
                {
                    "args":
                    args,
                    "state_dict":
                    model.module.state_dict()
                    if torch.cuda.is_available() else model.state_dict(),
                    "optim_state_dict":
                    optimizer.state_dict(),
                    "fixed_z":
                    fixed_z.cpu()
                }, os.path.join(args.save, "checkpt.pth"))
        if epoch % args.val_freq == 0 or args.validate:
            with open(testlog, 'a') as f:
                if write_log: csvlogger = csv.DictWriter(f, testcolumns)
                with torch.no_grad():
                    start = time.time()
                    if write_log: logger.info("validating...")

                    lossmean = 0.
                    meandist = 0.
                    steps = 0
                    tt = 0.
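                    # running means over validation batches: bits/dim, transport cost, NFE, and integration time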
                    for i, (x, y) in enumerate(test_loader):
                        sh = x.shape
                        x = shift(cvt(x), nbits=args.nbits)
                        loss, (x, z), _ = compute_bits_per_dim(x, model)
                        dist = (x.view(x.size(0), -1) -
                                z).pow(2).mean(dim=-1).mean()
                        meandist = i / (i + 1) * dist + meandist / (i + 1)
                        lossmean = i / (i + 1) * lossmean + loss / (i + 1)

                        tt = i / (i + 1) * tt + count_total_time(model) / (i +
                                                                           1)
                        steps = i / (i + 1) * steps + count_nfe(model) / (i +
                                                                          1)

                    loss = lossmean.item()
                    metrics = torch.tensor([1., loss, meandist, steps]).float()

                    total_gpus, r_bpd, r_mdist, r_steps = dist_utils.sum_tensor(
                        metrics).cpu().numpy()
                    eval_time = time.time() - start

                    if write_log:
                        fmt = '{:.4f}'
                        logdict = {
                            'epoch': epoch,
                            'eval_time': fmt.format(eval_time),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'wall': fmt.format(wall_clock),
                            'total_time': fmt.format(tt),
                            'transport_cost': fmt.format(r_mdist / total_gpus),
                            'fe': '{:.2f}'.format(r_steps / total_gpus)
                        }

                        csvlogger.writerow(logdict)

                        logger.info(
                            "Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}, Steps {:.4f}, TT {:.2f}, Transport Cost {:.2e}"
                            .format(epoch, eval_time, r_bpd / total_gpus,
                                    r_steps / total_gpus, tt,
                                    r_mdist / total_gpus))

                    loss = r_bpd / total_gpus

                    if loss < best_loss and args.local_rank == 0:
                        best_loss = loss
                        shutil.copyfile(os.path.join(args.save, "checkpt.pth"),
                                        os.path.join(args.save, "best.pth"))

            # visualize samples and density
            if write_log:
                with torch.no_grad():
                    fig_filename = os.path.join(args.save, "figs",
                                                "{:04d}.jpg".format(epoch))
                    utils.makedirs(os.path.dirname(fig_filename))
                    generated_samples, _, _ = model(fixed_z, reverse=True)
                    generated_samples = generated_samples.view(-1, *data_shape)
                    nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
                    save_image(unshift(generated_samples, nbits=args.nbits),
                               fig_filename,
                               nrow=nb)
            if args.validate:
                break
Example #19
def main():
    # parse command line arguments
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('-d', '--dataset', default='shapes', type=str, help='dataset name',
        choices=['shapes', 'faces'])
    parser.add_argument('-dist', default='normal', type=str, choices=['normal', 'laplace', 'flow'])
    parser.add_argument('-n', '--num-epochs', default=50, type=int, help='number of training epochs')
    parser.add_argument('-b', '--batch-size', default=2048, type=int, help='batch size')
    parser.add_argument('-l', '--learning-rate', default=1e-3, type=float, help='learning rate')
    parser.add_argument('-z', '--latent-dim', default=10, type=int, help='size of latent dimension')
    parser.add_argument('--beta', default=1, type=float, help='ELBO penalty term')
    parser.add_argument('--tcvae', action='store_true')
    parser.add_argument('--exclude-mutinfo', action='store_true')
    parser.add_argument('--beta-anneal', action='store_true')
    parser.add_argument('--lambda-anneal', action='store_true')
    parser.add_argument('--mss', action='store_true', help='use the improved minibatch estimator')
    parser.add_argument('--conv', action='store_true')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--visdom', action='store_true', help='whether plotting in visdom is desired')
    parser.add_argument('--save', default='test1')
    parser.add_argument('--log_freq', default=200, type=int, help='num iterations per log')
    args = parser.parse_args()

    # torch.cuda.set_device(args.gpu)

    # data loader
    train_loader = setup_data_loaders(args, use_cuda=True)

    # setup the VAE
    if args.dist == 'normal':
        prior_dist = dist.Normal()
        q_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
        q_dist = dist.Laplace()
    elif args.dist == 'flow':
        prior_dist = FactorialNormalizingFlow(dim=args.latent_dim, nsteps=32)
        q_dist = dist.Normal()

    vae = VAE(z_dim=args.latent_dim, use_cuda=True, prior_dist=prior_dist, q_dist=q_dist,
        include_mutinfo=not args.exclude_mutinfo, tcvae=args.tcvae, conv=args.conv, mss=args.mss)

    # setup the optimizer
    optimizer = optim.Adam(vae.parameters(), lr=args.learning_rate)

    # setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom(env=args.save, port=4500)

    train_elbo = []

    # training loop
    dataset_size = len(train_loader.dataset)
    num_iterations = len(train_loader) * args.num_epochs
    iteration = 0
    # initialize loss accumulator
    elbo_running_mean = utils.RunningAverageMeter()
    while iteration < num_iterations:
        for i, x in enumerate(train_loader):
            iteration += 1
            batch_time = time.time()
            vae.train()
            anneal_kl(args, vae, iteration)
            optimizer.zero_grad()
            # transfer to GPU
            x = x.cuda(non_blocking=True)  # 'async' is a reserved word in Python 3.7+
            # wrap the mini-batch in a PyTorch Variable
            x = Variable(x)
            # do ELBO gradient and accumulate loss
            obj, elbo = vae.elbo(x, dataset_size)
            if utils.isnan(obj).any():
                raise ValueError('NaN spotted in objective.')
            obj.mean().mul(-1).backward()
            print("obj value: ", obj.mean().mul(-1).cpu())
            elbo_running_mean.update(elbo.mean().item())
            optimizer.step()

            # report training diagnostics
            if iteration % args.log_freq == 0:
                train_elbo.append(elbo_running_mean.avg)
                print('[iteration %03d] time: %.2f \tbeta %.2f \tlambda %.2f training ELBO: %.4f (%.4f)' % (
                    iteration, time.time() - batch_time, vae.beta, vae.lamb,
                    elbo_running_mean.val, elbo_running_mean.avg))

                vae.eval()

                # plot training and test ELBOs
                if args.visdom:
                    display_samples(vae, x, vis)
                    plot_elbo(train_elbo, vis)

                utils.save_checkpoint({
                    'state_dict': vae.state_dict(),
                    'args': args}, args.save, 0)
                eval('plot_vs_gt_' + args.dataset)(vae, train_loader.dataset,
                    os.path.join(args.save, 'gt_vs_latent_{:05d}.png'.format(iteration)))

    # Report statistics after training
    vae.eval()
    utils.save_checkpoint({
        'state_dict': vae.state_dict(),
        'args': args}, args.save, 0)
    dataset_loader = DataLoader(train_loader.dataset, batch_size=10, num_workers=1, shuffle=False)
    logpx, dependence, information, dimwise_kl, analytical_cond_kl, marginal_entropies, joint_entropy = \
        elbo_decomposition(vae, dataset_loader)
    torch.save({
        'logpx': logpx,
        'dependence': dependence,
        'information': information,
        'dimwise_kl': dimwise_kl,
        'analytical_cond_kl': analytical_cond_kl,
        'marginal_entropies': marginal_entropies,
        'joint_entropy': joint_entropy
    }, os.path.join(args.save, 'elbo_decomposition.pth'))
    eval('plot_vs_gt_' + args.dataset)(vae, dataset_loader.dataset, os.path.join(args.save, 'gt_vs_latent.png'))
    return vae
Example #20
def run(args, logger, train_loader, validation_loader, data_shape):

    # get device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns)

    if args.spectral_norm: add_spectral_norm(model, logger)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(
        count_parameters(model)))

    # optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)

    start_epoch = 0

    # restore parameters
    if args.resume is not None:
        checkpt = torch.load(args.resume,
                             map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)
        args = checkpt["args"]
        start_epoch = checkpt["epoch"] + 1
        logger.info("Resuming at epoch {} with args {}.".format(
            start_epoch, args))

    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model).cuda()

    time_meter = utils.RunningAverageMeter(0.97)
    loss_meter = utils.RunningAverageMeter(0.97)
    steps_meter = utils.RunningAverageMeter(0.97)
    grad_meter = utils.RunningAverageMeter(0.97)
    tt_meter = utils.RunningAverageMeter(0.97)

    if args.spectral_norm and not args.resume:
        spectral_norm_power_iteration(model, 500)

    best_loss = float("inf")

    itr = 0
    train_loader_break = 10000
    validation_loader_break = 5000
    break_train = int(train_loader_break / args.batch_size)
    break_validation = int(validation_loader_break / args.batch_size)

    for epoch in range(start_epoch, args.num_epochs):
        logger.info("Epoch {}/{}".format(epoch, args.num_epochs))

        model.train()

        for idx_count, data in enumerate(train_loader):
            if idx_count > break_train:
                break
            if args.heterogen:
                x = extract(data, args)
            else:
                x, y = extract(data, args)

            start = time.time()
            update_lr(args, optimizer, itr)
            optimizer.zero_grad()

            if not args.conv:
                x = x.view(x.shape[0], -1)

            # cast data and move to device
            x = cvt(x)

            # compute loss
            loss = compute_bits_per_dim(x, model)
            if regularization_coeffs:
                reg_states = get_regularization(model, regularization_coeffs)
                reg_loss = sum(reg_state * coeff for reg_state, coeff in zip(
                    reg_states, regularization_coeffs) if coeff != 0)
                loss = loss + reg_loss
            total_time = count_total_time(model)
            loss = loss + total_time * args.time_penalty

            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)

            optimizer.step()

            if args.spectral_norm:
                spectral_norm_power_iteration(model, args.spectral_norm_niter)

            time_meter.update(time.time() - start)
            loss_meter.update(loss.item())
            steps_meter.update(count_nfe(model))
            grad_meter.update(float(grad_norm))
            tt_meter.update(total_time)

            if itr % args.log_freq == 0:
                write_bits_dim(args, epoch, loss_meter.avg, time_meter.avg,
                               steps_meter.avg)

                log_message = (
                    "Iter {:04d} | Time {:.4f}({:.4f}) | Bit/dim {:.4f}({:.4f}) | "
                    "Steps {:.0f}({:.2f}) | Grad Norm {:.4f}({:.4f}) | Total Time {:.2f}({:.2f})"
                    .format(itr, time_meter.val, time_meter.avg,
                            loss_meter.val, loss_meter.avg, steps_meter.val,
                            steps_meter.avg, grad_meter.val, grad_meter.avg,
                            tt_meter.val, tt_meter.avg))
                if regularization_coeffs:
                    log_message = append_regularization_to_log(
                        log_message, regularization_fns, reg_states)
                logger.info(log_message)

            itr += 1

        # Evaluate and save model
        if args.evaluate:
            if epoch % args.val_freq == 0:
                model.eval()
                with torch.no_grad():
                    start = time.time()
                    logger.info("validating...")

                    losses = []
                    losses_vec_recon_images = []
                    losses_vec_images_recon_images = []

                    for idx_count, data in enumerate(validation_loader):
                        if idx_count > break_validation:
                            break
                        if args.heterogen:
                            x = extract(data, args)
                        else:
                            x, y = extract(data, args)

                        if not args.conv:
                            x = x.view(x.shape[0], -1)

                        x = cvt(x)

                        zero = torch.zeros(x.shape[0], 1).to(x)
                        z, delta_logp = model(x, zero)  # run model forward

                        recon_images = model(z, reverse=True)  # only consumed by the commented-out resnet check below
                        loss = compute_bits_per_dim(x, model)
                        losses.append(loss.item())

                    #      if args.data == "piv" and args.heterogen == False:
                    #          loss_vec_recon_images, loss_vec_images_recon_images = resnet_pretrained.run(args, logger,
                    #                  recon_images, x, y, data_shape)
                    #          losses_vec_recon_images.append(loss_vec_recon_images.item())
                    #          losses_vec_images_recon_images.append(loss_vec_images_recon_images.item())
                    #
                    #
                    #  if args.data == "piv" and args.heterogen == False:
                    #      logger.info("Loss vector reconstructed images {}, Loss vector images reconstructed images {}".format(np.mean(losses_vec_recon_images),
                    #  np.mean(losses_vec_images_recon_images)))

                    loss = np.mean(losses)
                    logger.info(
                        "Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}".format(
                            epoch,
                            time.time() - start, loss))
                    if loss < best_loss:
                        best_loss = loss
                        torch.save(
                            {
                                "args": args,
                                "epoch": epoch,
                                "state_dict": model.module.state_dict()
                                if torch.cuda.is_available() else model.state_dict(),
                                "optim_state_dict": optimizer.state_dict(),
                            }, os.path.join(args.save, "checkpt.pth"))
                        logger.info("Saving model at epoch {}.".format(epoch))

            # visualize samples and density
            evaluation.save_recon_images(args, model, validation_loader,
                                         data_shape, logger)

            evaluation.save_fixed_z_image(args, model, data_shape, logger)
            #  evaluation.tsne(args, model, data_shape, logger)

            if args.data == "piv":
                evaluation.tsne_x(args, model, validation_loader, data_shape,
                                  logger)

                evaluation.save_2D_manifold(args, model, data_shape,
                                            validation_loader)
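Both the training and validation loops in this example call compute_bits_per_dim(x, model). A sketch of what that helper typically computes for a flow of this kind, assuming the model maps x to a latent z with accumulated delta_logp (as in the validation forward pass above), a standard-normal prior, and 8-bit image data:

import math
import torch

def compute_bits_per_dim(x, model):
    zero = torch.zeros(x.shape[0], 1).to(x)
    z, delta_logp = model(x, zero)                       # run model forward
    # log-density of z under a standard normal prior
    logpz = (-0.5 * math.log(2 * math.pi) - 0.5 * z.pow(2)).view(z.shape[0], -1).sum(1, keepdim=True)
    logpx = logpz - delta_logp                           # change of variables
    dims = x.nelement() / x.shape[0]                     # dimensions per sample
    # bits/dim; log(256) converts the density back to the discrete 8-bit scale
    return -(logpx / dims - math.log(256)).mean() / math.log(2)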
Example #21
    return loss


if __name__ == '__main__':

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, 2, regularization_fns).to(device)
    if args.spectral_norm: add_spectral_norm(model)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
    best_loss = float('inf')
    model.train()
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()
        if args.spectral_norm: spectral_norm_power_iteration(model, 1)

        loss = compute_loss(args, model)
        loss_meter.update(loss.item())
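The listing is cut off here. Following the pattern of the image-training loop in Example #20, such an iteration would typically continue roughly as sketched below (assumed continuation, not part of the original snippet):

        total_time = count_total_time(model)
        nfe_forward = count_nfe(model)

        loss.backward()
        optimizer.step()

        nfe_backward = count_nfe(model) - nfe_forward
        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)
        tt_meter.update(total_time)
        time_meter.update(time.time() - end)
        end = time.time()

        if itr % args.log_freq == 0:
            logger.info(
                'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | '
                'NFE-F {:.0f}({:.1f}) | NFE-B {:.0f}({:.1f})'.format(
                    itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg,
                    nfef_meter.val, nfef_meter.avg, nfeb_meter.val, nfeb_meter.avg))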
Example #22
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    if torch.cuda.is_available():
        #model = torch.nn.DataParallel(model).cuda()
        model = model.cuda()

    # For visualization.
    fixed_z = cvt(torch.randn(100, *data_shape))

    time_meter = utils.RunningAverageMeter(0.97)
    loss_meter = utils.RunningAverageMeter(0.97)
    nfef_meter = utils.RunningAverageMeter(0.97)
    nfeb_meter = utils.RunningAverageMeter(0.97)
    grad_meter = utils.RunningAverageMeter(0.97)
    tt_meter = utils.RunningAverageMeter(0.97)

    if args.spectral_norm and not args.resume: spectral_norm_power_iteration(model, 500)

    best_loss = float("inf")
    itr = 0
    train_time = 0
    for epoch in range(args.begin_epoch, args.num_epochs + 1):
        model.train()
        train_loader = get_train_loader(train_set, epoch)
        if not args.evaluate:
Example #23
def compute_p_grads(model):
    scales = 0.
    nlayers = 0
    for m in model.modules():
        if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear):
            scales = scales + m.compute_one_iter()
            nlayers += 1
    # average the one-iteration estimates over layers and backprop into the learnable domains
    scales.mul(1 / nlayers).backward()
    for m in model.modules():
        if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear):
            # discard NaN gradients on the norm-domain parameter
            if m.domain.grad is not None and torch.isnan(m.domain.grad).any():
                m.domain.grad = None
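
compute_p_grads backpropagates the averaged power-iteration estimate, which is how the learnable norm domains of the induced-norm layers receive gradients; it is meant to run alongside the regular backward pass, before optimizer.step(). A hedged usage sketch (the loss and optimizer names below are assumptions, not from the original):

# hypothetical training step
optimizer.zero_grad()
loss = compute_loss(x, model)   # whatever objective is in use
loss.backward()
compute_p_grads(model)          # accumulates grads for the InducedNorm* domain parameters
optimizer.step()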


batch_time = utils.RunningAverageMeter(0.97)
bpd_meter = utils.RunningAverageMeter(0.97)
logpz_meter = utils.RunningAverageMeter(0.97)
deltalogp_meter = utils.RunningAverageMeter(0.97)
firmom_meter = utils.RunningAverageMeter(0.97)
secmom_meter = utils.RunningAverageMeter(0.97)
gnorm_meter = utils.RunningAverageMeter(0.97)
ce_meter = utils.RunningAverageMeter(0.97)


def train(epoch, model):

    model = parallelize(model)
    model.train()

    total = 0
Example #24
def run(args, logger, train_loader, validation_loader, data_shape):

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = module.HouseholderSylvesterVAE(args, data_shape)
    #  model = module.OrthogonalSylvesterVAE(args, data_shape)

    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           'min',
                                                           factor=0.2,
                                                           patience=5,
                                                           min_lr=1e-8)

    start_epoch = 0

    # restore parameters
    if args.resume is not None:
        checkpt = torch.load(args.resume,
                             map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpt["state_dict"])
        optimizer.load_state_dict(checkpt["optim_state_dict"])
        args = checkpt["args"]
        start_epoch = checkpt["epoch"] + 1
        logger.info("Resuming at epoch {} with args {}.".format(
            start_epoch, args))

    time_meter = utils.RunningAverageMeter(0.97)

    beta = args.beta
    train_loader_break = 500000
    break_train = int(train_loader_break / args.batch_size)
    break_training = 50

    best_loss = float("inf")
    itr = 0
    for epoch in range(start_epoch, args.num_epochs):
        logger.info('Epoch: {}/{} \tBeta: {}'.format(epoch, args.num_epochs,
                                                     beta))

        model.train()
        num_data = 0
        end = time.time()

        for idx_count, data in enumerate(train_loader):
            #  if idx_count > break_training:
            #      break
            if args.data == 'piv':
                x_, y_ = data['ComImages'], data['AllGenDetails']

                if args.heterogen:
                    x = torch.zeros([x_.size(0), 4, 32, 32])
                    x[:, :2, :, :] = x_
                    for idx in range(x_.size(0)):
                        u_vector = torch.zeros([1, 32, 32])
                        u_vector.fill_(y_[idx][0] / 20 * 0.5 + 0.5)

                        v_vector = torch.zeros([1, 32, 32])
                        v_vector.fill_(y_[idx][1] / 20 * 0.5 + 0.5)

                        x[idx, 2, :, :] = u_vector
                        x[idx, 3, :, :] = v_vector

                else:
                    x = x_
                    y = y_

            elif args.data == 'mnist' and args.heterogen:
                x_, y_ = data

                x = torch.zeros([x_.size(0), 2, 28, 28])
                x[:, :1, :, :] = x_
                for idx in range(x_.size(0)):
                    labels = torch.zeros([1, 28, 28])
                    labels.fill_(y_[idx] / 10)

                    x[idx, 1, :, :] = labels

            elif args.data == 'cifar10' and args.heterogen:
                x_, y_ = data

                x = torch.zeros([x_.size(0), 4, 32, 32])
                x[:, :3, :, :] = x_
                for idx in range(x_.size(0)):
                    labels = torch.zeros([1, 32, 32])
                    labels.fill_(y_[idx])

                    x[idx, 3, :, :] = labels

            else:
                x, y = data

            x = x.to(device)

            start = time.time()
            optimizer.zero_grad()

            recon_images, z_mu, z_var, ldj, z0, z_k = model(x)

            loss, rec, kl = loss_function.binary_loss_function(
                recon_images, x, z_mu, z_var, z0, z_k, ldj, beta)

            loss.backward()

            optimizer.step()

            rec = rec.item()
            kl = kl.item()
            num_data += len(x)

            batch_time = time.time() - end
            end = time.time()

            if itr % args.log_freq == 0:
                log_message = (
                    "Epoch {:03d} |  [{:5d}/{:5d} ({:2.0f}%)] | Time {:.3f} | Loss: {:11.6f} |"
                    "rec:{:11.6f} | kl: {:11.6f}".format(
                        epoch, num_data, len(train_loader.sampler),
                        100. * idx_count / len(train_loader), batch_time,
                        loss.item(), rec, kl))
                logger.info(log_message)

            itr += 1

        scheduler.step(loss.item())

        # Evaluate and save model
        if args.evaluate:
            if epoch % args.val_freq == 0:
                model.eval()
                with torch.no_grad():
                    start = time.time()
                    logger.info("validating...")

                    losses_vec_recon_images = []
                    losses_vec_images_recon_images = []
                    losses = []

                    for idx_count, data in enumerate(validation_loader):

                        if idx_count > break_training:
                            break

                        if args.data == 'piv':
                            x_, y_ = data['ComImages'], data['AllGenDetails']

                            if args.heterogen:
                                x = torch.zeros([x_.size(0), 4, 32, 32])
                                x[:, :2, :, :] = x_
                                for idx in range(x_.size(0)):
                                    u_vector = torch.zeros([1, 32, 32])
                                    u_vector.fill_(y_[idx][0] / 20 * 0.5 + 0.5)

                                    v_vector = torch.zeros([1, 32, 32])
                                    v_vector.fill_(y_[idx][1] / 20 * 0.5 + 0.5)

                                    x[idx, 2, :, :] = u_vector
                                    x[idx, 3, :, :] = v_vector

                            else:
                                x = x_
                                y = y_

                        elif args.data == 'mnist' and args.heterogen:
                            x_, y_ = data

                            x = torch.zeros([x_.size(0), 2, 28, 28])
                            x[:, :1, :, :] = x_
                            for idx in range(x_.size(0)):
                                labels = torch.zeros([1, 28, 28])
                                labels.fill_(y_[idx] / 10)

                                x[idx, 1, :, :] = labels

                        elif args.data == 'cifar10' and args.heterogen:
                            x_, y_ = data

                            x = torch.zeros([x_.size(0), 4, 32, 32])
                            x[:, :3, :, :] = x_
                            for idx in range(x_.size(0)):
                                labels = torch.zeros([1, 32, 32])
                                labels.fill_(y_[idx])

                                x[idx, 3, :, :] = labels
                        else:
                            x, y = data

                        x = x.to(device)

                        recon_images, z_mu, z_var, ldj, z0, z_k = model(x)
                        loss, rec, kl = loss_function.binary_loss_function(
                            recon_images, x, z_mu, z_var, z0, z_k, ldj, beta)
                        losses.append(loss.item())

                        if args.data == "piv" and args.heterogen == False:
                            loss_vec_recon_images, loss_vec_images_recon_images = resnet_pretrained.run(
                                args, logger, recon_images, x, y, data_shape)
                            losses_vec_recon_images.append(
                                loss_vec_recon_images.item())
                            losses_vec_images_recon_images.append(
                                loss_vec_images_recon_images.item())

                    if args.data == "piv" and args.heterogen == False:
                        logger.info(
                            "Loss vector reconstructed images {}, Loss vector images reconstructed images {}"
                            .format(np.mean(losses_vec_recon_images),
                                    np.mean(losses_vec_images_recon_images)))

                    loss = np.mean(losses)
                    logger.info(
                        "Epoch {:04d} | Time {:.4f} | Loss {:.4f}".format(
                            epoch,
                            time.time() - start, loss))
                    if loss < best_loss:
                        best_loss = loss
                        utils.makedirs(args.save)
                        torch.save(
                            {
                                "args": args,
                                "epoch": epoch,
                                "state_dict": model.state_dict(),
                                "optim_state_dict": optimizer.state_dict(),
                            }, os.path.join(args.save, "checkpt.pth"))
                        logger.info("Saving model at epoch {}.".format(epoch))

            if beta < 1:
                beta += 0.01

            # Evaluation
            evaluation.save_recon_images(args, model, validation_loader,
                                         data_shape, logger)
            evaluation.save_fixed_z_image(args, model, data_shape, logger)