Example #1
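# Loads a trained 1-d tabular CNF, sweeps the ODE solver tolerance, and plots
# how far the numerically integrated density deviates from integrating to 1.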
def evaluate():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    checkpt = torch.load(os.path.join(args.save, 'checkpt.pth'))
    model.load_state_dict(checkpt['state_dict'])
    model.to(device)

    tols = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
    errors = []
    with torch.no_grad():
        for tol in tols:
            args.rtol = tol
            args.atol = tol
            set_cnf_options(args, model)

            xx = torch.linspace(-15, 15, 500000).view(-1, 1).to(device)
            prob_xx = model_density(xx, model).double().view(-1).cpu()
            xx = xx.double().cpu().view(-1)
            dxx = torch.log(xx[1:] - xx[:-1])
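            # left Riemann sum of the density, computed in log space for stability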
            num_integral = torch.logsumexp(prob_xx[:-1] + dxx, 0).exp()
            errors.append(float(torch.abs(num_integral - 1.)))

            print(errors[-1])

    plt.figure(figsize=(5, 3))
    plt.plot(tols, errors, linewidth=3, marker='o', markersize=7)
    # plt.plot([-1, 0.2], [-1, 0.2], '--', color='grey', linewidth=1)
    plt.xscale("log", nonposx='clip')
    # plt.yscale("log", nonposy='clip')
    plt.xlabel('Solver Tolerance', fontsize=17)
    plt.ylabel(r'$| 1 - \int p(x) |$', fontsize=17)
    plt.tight_layout()
    plt.savefig('ode_solver_error_vs_tol.pdf')
Example #2
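    # Restores a 2-d tabular CNF from a checkpoint (weights mapped to CPU
    # first) and draws a batch of toy-data samples for evaluation.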
    def get_ckpt_model_and_data(args):
        # Load checkpoint.
        checkpt = torch.load(args.checkpt,
                             map_location=lambda storage, loc: storage)
        ckpt_args = checkpt['args']
        state_dict = checkpt['state_dict']

        # Construct model and restore checkpoint.
        regularization_fns, regularization_coeffs = create_regularization_fns(
            ckpt_args)
        model = build_model_tabular(ckpt_args, 2,
                                    regularization_fns).to(device)
        if ckpt_args.spectral_norm: add_spectral_norm(model)
        set_cnf_options(ckpt_args, model)

        model.load_state_dict(state_dict)
        model.to(device)

        print(model)
        print("Number of trainable parameters: {}".format(
            count_parameters(model)))

        # Load samples from dataset
        data_samples = toy_data.inf_train_gen(ckpt_args.data, batch_size=2000)

        return model, data_samples
Example #3
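# Renders the model density at a sequence of intermediate integration times,
# saving one frame per time and stitching the frames into a video.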
def visualize_times():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    checkpt = torch.load(os.path.join(args.save, 'checkpt.pth'))
    model.load_state_dict(checkpt['state_dict'])
    model.to(device)

    viz_times = torch.linspace(0., args.time_length, args.ntimes)
    with torch.no_grad():
        for i, t in enumerate(tqdm(viz_times[1:])):
            model.eval()
            set_cnf_options(args, model)
            xx = torch.linspace(-10, 10, 10000).view(-1, 1)

            # Evaluate the model log-density at time t. NOTE: this assumes
            # model.chain holds a single CNF block; z is not propagated
            # through successive blocks.
            generated_p = 0
            for cnf in model.chain:
                xx = xx.to(device)
                z, delta_logp = cnf(xx,
                                    torch.zeros_like(xx),
                                    integration_times=torch.Tensor([0, t]))
                generated_p = standard_normal_logprob(z) - delta_logp

            plt.plot(xx.view(-1).cpu().numpy(),
                     generated_p.view(-1).exp().cpu().numpy(),
                     label='Model')

            utils.makedirs(os.path.join(args.save, 'test_times', 'figs'))
            plt.savefig(
                os.path.join(args.save, 'test_times', 'figs',
                             '{:04d}.jpg'.format(i)))
            plt.close()
    trajectory_to_video(os.path.join(args.save, 'test_times', 'figs'))
Example #4
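# Evaluates the 1-d model density at each intermediate time and renders the
# whole evolution as a single (position, time) pcolormesh heatmap.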
def visualize_evolution():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    checkpt = torch.load(os.path.join(args.save, 'checkpt.pth'))
    model.load_state_dict(checkpt['state_dict'])
    model.to(device)

    viz_times = torch.linspace(0., args.time_length, args.ntimes)
    viz_times_np = viz_times[1:].detach().cpu().numpy()
    xx = torch.linspace(-5, 5, args.num_particles).view(-1, 1)
    xx_np = xx.detach().cpu().numpy()
    # (time, position) coordinate grids for the pcolormesh below
    xs, ys = np.meshgrid(xx_np.ravel(), viz_times_np)
    all_evolutions = np.zeros((args.num_particles, args.ntimes - 1))
    with torch.no_grad():
        for i, t in enumerate(tqdm(viz_times[1:])):
            model.eval()
            set_cnf_options(args, model)
            # Evaluate the model log-density at time t (assumes a single
            # CNF block in model.chain, as in visualize_times above).
            generated_p = 0
            for cnf in model.chain:
                xx = xx.to(device)
                z, delta_logp = cnf(xx,
                                    torch.zeros_like(xx),
                                    integration_times=torch.Tensor([0, t]))
                generated_p = standard_normal_logprob(z) - delta_logp

            generated_p = generated_p.detach()
            cur_evolution = generated_p.view(-1).exp().cpu().numpy()
            all_evolutions[:, i] = cur_evolution
        plt.figure(dpi=1200)
        plt.clf()
        all_evolutions = all_evolutions.astype('float32')
        # xs, ys have shape (ntimes - 1, num_particles); transpose the
        # (num_particles, ntimes - 1) evolution array to match
        plt.pcolormesh(xs, ys, all_evolutions.transpose())

        utils.makedirs(os.path.join(args.save, 'test_times', 'figs'))
        plt.savefig(
            os.path.join(args.save, 'test_times', 'figs', 'evolution.jpg'))
        plt.close()
Example #5
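# Transports a fixed set of particles through the flow to each intermediate
# time and plots their trajectories over test time.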
def visualize_particle_flow():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    checkpt = torch.load(os.path.join(args.save, 'checkpt.pth'))
    model.load_state_dict(checkpt['state_dict'])
    model.to(device)

    viz_times = torch.linspace(0., args.time_length, args.ntimes)
    xx = torch.linspace(-5, 5, args.num_particles).view(-1, 1)
    zs = []
    with torch.no_grad():
        for i, t in enumerate(tqdm(viz_times[1:])):
            model.eval()
            set_cnf_options(args, model)

            # Transport the particles to time t (assumes a single CNF block
            # in model.chain); only the pushed-forward positions z are kept.
            for cnf in model.chain:
                xx = xx.to(device)
                z, delta_logp = cnf(xx,
                                    torch.zeros_like(xx),
                                    integration_times=torch.Tensor([0, t]))

            zs.append(z.cpu().numpy())

    zs = np.array(zs).reshape(args.ntimes - 1, args.num_particles)
    viz_t = viz_times[1:].numpy()
    plt.figure(dpi=1200)
    plt.clf()
    with sns.color_palette("Blues_d"):
        plt.plot(viz_t, zs)
        plt.xlabel("Test Time")
        utils.makedirs(os.path.join(args.save, 'test_times', 'figs'))
        plt.savefig(
            os.path.join(args.save, 'test_times', 'figs',
                         'particle_trajectory.jpg'))
        plt.close()
Example #6
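# Builds a multiscale ODENVP image model; BATCH_SIZE and the odenvp module
# are assumed to be defined in the surrounding script.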
def create_model(args, data_shape):
    hidden_dims = tuple(map(int, args.dims.split(",")))

    model = odenvp.ODENVP(
        (BATCH_SIZE, *data_shape),
        n_blocks=args.num_blocks,
        intermediate_dims=hidden_dims,
        nonlinearity=args.nonlinearity,
        alpha=args.alpha,
        cnf_kwargs={
            "T": args.time_length,
            "train_T": args.train_T
        },
    )
    if args.spectral_norm: add_spectral_norm(model)
    set_cnf_options(args, model)
    return model
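# Rebuilds a 5-d tabular CNF from pickled args ('args.pkl') and a saved state
# dict ('model_10000.pt'); `scale` and `fraction` are unused here.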
def gen_model(scale=10, fraction=0.5):
    # build the normalizing-flow model from a previous fit
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args = pkl.load(open('args.pkl', 'rb'))
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, 5,
                                regularization_fns).to(device)  #.cuda()
    if args.spectral_norm: add_spectral_norm(model)
    set_cnf_options(args, model)
    model.load_state_dict(torch.load('model_10000.pt'))

    return model
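# Evaluation-only main(): restores an image CNF from a checkpoint, saves a
# grid of real test images, and generates samples at several temperatures.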
def main():
    # os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))

        logger.info(args)

        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log: logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(), rank=env_rank())
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log: logger.info("Distributed: success (%d/%d)" % (args.local_rank, distributed.get_world_size()))
        device = torch.device("cuda:%d" % torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
    else:
        # torch.cuda.current_device() returns a bare int and fails on
        # CPU-only machines; build a torch.device with a CPU fallback instead
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = ['itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time', 'grad_norm']
    testcolumns = ['wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time', 'transport_cost']

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
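    # NOTE: this main() comes from a script whose create_model also accepts
    # regularization_fns (unlike the two-argument create_model above).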
    model = create_model(args, data_shape, regularization_fns).cuda()
    if args.distributed: model = dist_utils.DDP(model,
                                                device_ids=[args.local_rank],
                                                output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns, regularization_fns)

    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log: logger.info(model)
    if write_log: logger.info("Number of trainable parameters: {}".format(count_parameters(model)))
    if write_log: logger.info('Iters per train epoch: {}'.format(len(train_loader)))
    if write_log: logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9,
                              nesterov=False)

    # restore parameters
    if args.resume is not None:
        print('resume from checkpoint')
        checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization.
    if write_log: fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
        chkdir = args.save
    else:
        chkdir = os.path.dirname(args.resume)
        filename = os.path.join(chkdir, 'test.csv')
        print(filename)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] + 1)  # not exactly correct

    if args.distributed:
        if write_log: logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, begin_epoch + 1):  # single evaluation pass
        # compute test loss
        print('Evaluating')
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            if hasattr(model, 'module'):
                _state = model.module.state_dict()
            else:
                _state = model.state_dict()
            torch.save({
                "args": args,
                "state_dict": _state,  # model.module.state_dict() if torch.cuda.is_available() else model.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
                "fixed_z": fixed_z.cpu()
            }, os.path.join(args.save, "checkpt_%d.pth" % epoch))

        # save real and generate with different temperatures
        fig_num = 64
        if True:  # args.save_real:
            # Walk past the first 100 batches so the "real" image grid is not
            # drawn from the very first batch; x holds the last batch seen.
            for i, (x, y) in enumerate(test_loader):
                if i > 100:
                    break
            if x.shape[0] > fig_num:
                x = x[:fig_num, ...]
            fig_filename = os.path.join(chkdir, "real.jpg")
            save_image(x.float() / 255.0, fig_filename, nrow=8)

        if True:  # args.generate:
            print('\nGenerating images... ')
            fixed_z = cvt(torch.randn(fig_num, *data_shape))
            nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
            for t in [1.0, 0.99, 0.98, 0.97, 0.96, 0.95, 0.93, 0.92, 0.90,
                      0.85, 0.8, 0.75, 0.7, 0.65, 0.6]:
                # visualize samples and density
                fig_filename = os.path.join(chkdir, "generated-T%g.jpg" % t)
                utils.makedirs(os.path.dirname(fig_filename))
                generated_samples = model(t * fixed_z, reverse=True)
                x = unshift(generated_samples[0].view(-1, *data_shape), 8)
                save_image(x, fig_filename, nrow=nb)
Example #9
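    # Fragment: builds an augmented CNF for evaluation, restores weights from
    # args.resume, and wraps the model in DataParallel when a GPU is available.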
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if args.use_cpu:
        device = torch.device("cpu")
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    test_loader = get_dataset(args, args.test_batch_size)

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    aug_model = build_augmented_model_tabular(
        args,
        args.aug_size + args.effective_shape,
        regularization_fns=regularization_fns,
    )
    set_cnf_options(args, aug_model)
    logger.info(aug_model)

    # restore parameters
    itr = 0
    if args.resume is not None:
        checkpt = torch.load(args.resume,
                             map_location=lambda storage, loc: storage)
        aug_model.load_state_dict(checkpt["state_dict"])

    if torch.cuda.is_available() and not args.use_cpu:
        aug_model = torch.nn.DataParallel(aug_model).cuda()

    best_loss = float("inf")
    aug_model.eval()
    with torch.no_grad():
Example #10
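        # Fragment of a training loop that anneals the ODE solver tolerances
        # (atol/rtol) each iteration before taking an optimizer step.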
        best_loss = float('inf')
        itr = 0
        n_vals_without_improvement = 0
        end = time.time()
        model.train()
        while True:
            if args.early_stopping > 0 and n_vals_without_improvement > args.early_stopping:
                break

            for x in batch_iter(data.trn.x, shuffle=True):
                if args.early_stopping > 0 and n_vals_without_improvement > args.early_stopping:
                    break

                atol, rtol = update_tolerances(args, itr, decay_factors)
                set_cnf_options(args, atol, rtol, model)

                optimizer.zero_grad()

                x = cvt(x)
                loss = compute_loss(x, model)
                loss_meter.update(loss.item())

                if len(regularization_coeffs) > 0:
                    reg_states = get_regularization(model,
                                                    regularization_coeffs)
                    reg_loss = sum(reg_state * coeff
                                   for reg_state, coeff in zip(
                                       reg_states, regularization_coeffs)
                                   if coeff != 0)
Example #11
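# Full training loop for a 1-d tabular CNF: tracks forward/backward NFE
# (number of function evaluations), checkpoints on the best test loss, and
# periodically plots the true density against the model density.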
def train():

    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(
        count_parameters(model)))

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
    best_loss = float('inf')
    model.train()
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()

        loss = compute_loss(args, model)
        loss_meter.update(loss.item())

        total_time = count_total_time(model)
        nfe_forward = count_nfe(model)

        loss.backward()
        optimizer.step()

        nfe_total = count_nfe(model)
        nfe_backward = nfe_total - nfe_forward
        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)

        time_meter.update(time.time() - end)
        tt_meter.update(total_time)

        log_message = (
            'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})'
            ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format(
                itr, time_meter.val, time_meter.avg, loss_meter.val,
                loss_meter.avg, nfef_meter.val, nfef_meter.avg, nfeb_meter.val,
                nfeb_meter.avg, tt_meter.val, tt_meter.avg))
        logger.info(log_message)

        if itr % args.val_freq == 0 or itr == args.niters:
            with torch.no_grad():
                model.eval()
                test_loss = compute_loss(args,
                                         model,
                                         batch_size=args.test_batch_size)
                test_nfe = count_nfe(model)
                log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(
                    itr, test_loss, test_nfe)
                logger.info(log_message)

                if test_loss.item() < best_loss:
                    best_loss = test_loss.item()
                    utils.makedirs(args.save)
                    torch.save(
                        {
                            'args': args,
                            'state_dict': model.state_dict(),
                        }, os.path.join(args.save, 'checkpt.pth'))
                model.train()

        if itr % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()

                xx = torch.linspace(-10, 10, 10000).view(-1, 1)
                true_p = data_density(xx)
                plt.plot(xx.view(-1).cpu().numpy(),
                         true_p.view(-1).exp().cpu().numpy(),
                         label='True')

                model_p = model_density(xx, model)
                plt.plot(xx.view(-1).cpu().numpy(),
                         model_p.view(-1).exp().cpu().numpy(),
                         label='Model')
                plt.legend()

                utils.makedirs(os.path.join(args.save, 'figs'))
                plt.savefig(
                    os.path.join(args.save, 'figs', '{:06d}.jpg'.format(itr)))
                plt.close()

                model.train()

        end = time.time()

    logger.info('Training has finished.')
Example #12
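# Trains or evaluates a CNF over single-cell data (TrajectoryNet-style),
# optionally alongside a pretrained growth model.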
def main(args):
    # logger
    utils.makedirs(args.save)
    logger = utils.get_logger(
        logpath=os.path.join(args.save, "logs"),
        filepath=os.path.abspath(__file__),
        # use `not`, not `~`: bitwise NOT of a Python bool is always truthy
        displaying=not args.no_display_loss,
    )

    if args.layer_type == "blend":
        logger.info("!! Setting time_scale from None to 1.0 for Blend layers.")
        args.time_scale = 1.0

    logger.info(args)

    device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu"
    )
    if args.use_cpu:
        device = torch.device("cpu")

    args.data = dataset.SCData.factory(args.dataset, args.max_dim)

    args.timepoints = args.data.get_unique_times()
    # Use maximum timepoint to establish integration_times
    # as some timepoints may be left out for validation etc.
    args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, args.data.get_shape()[0], regularization_fns).to(
        device
    )
    growth_model = None
    if args.use_growth:
        growth_model_path = None
        if args.leaveout_timepoint == -1:
            growth_model_path = (
                "../data/externel/growth_model_v2.ckpt"
            )
        elif args.leaveout_timepoint in [1, 2, 3]:
            assert args.max_dim == 5
            growth_model_path = (
                "../data/growth/model_%d"
                % args.leaveout_timepoint
            )
        else:
            print("WARNING: Cannot use growth with this timepoint")
        # load only if a growth model path was selected for this timepoint
        if growth_model_path is not None:
            growth_model = torch.load(growth_model_path, map_location=device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    if args.test:
        state_dict = torch.load(args.save + "/checkpt.pth", map_location=device)
        model.load_state_dict(state_dict["state_dict"])
        # if "growth_state_dict" not in state_dict:
        #    print("error growth model note in save")
        #    growth_model = None
        # else:
        #    checkpt = torch.load(args.save + "/checkpt.pth", map_location=device)
        #    growth_model.load_state_dict(checkpt["growth_state_dict"])
        # TODO can we load the arguments from the save?
        # eval_utils.generate_samples(
        #    device, args, model, growth_model, timepoint=args.leaveout_timepoint
        # )
        # with torch.no_grad():
        #    evaluate(device, args, model, growth_model)
    #    exit()
    else:
        logger.info(model)
        n_param = count_parameters(model)
        logger.info("Number of trainable parameters: {}".format(n_param))

        train(
            device,
            args,
            model,
            growth_model,
            regularization_coeffs,
            regularization_fns,
            logger,
        )

    if args.data.data.shape[1] == 2:
        plot_output(device, args, model)
Example #13
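# VAE-with-flows experiment runner: selects a flow variant (planar, IAF,
# Sylvester, CNF, ...), trains with early stopping, and evaluates the best
# snapshot on the validation and test sets.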
def run(args, kwargs):
    # ==================================================================================================================
    # SNAPSHOTS
    # ==================================================================================================================
    args.model_signature = str(datetime.datetime.now())[0:19].replace(' ', '_')
    args.model_signature = args.model_signature.replace(':', '_')

    if args.automatic_saving:
        path = '{}/{}/{}/{}/{}/{}/{}/{}/{}/'.format(args.solver, args.dataset,
                                                    args.layer_type, args.atol,
                                                    args.rtol, args.atol_start,
                                                    args.rtol_start,
                                                    args.warmup_steps,
                                                    args.manual_seed)
    else:
        path = 'test/'

    args.snap_dir = os.path.join(args.out_dir, path)

    # logger (utils.makedirs creates the snapshot dir if it doesn't exist)
    utils.makedirs(args.snap_dir)
    logger = utils.get_logger(logpath=os.path.join(args.snap_dir, 'logs'),
                              filepath=os.path.abspath(__file__))

    logger.info(args)

    # SAVING
    torch.save(args, args.snap_dir + 'config.config')

    # ==================================================================================================================
    # LOAD DATA
    # ==================================================================================================================
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    if not args.evaluate:

        nfef_meter = utils.AverageMeter()
        nfeb_meter = utils.AverageMeter()

        # ==============================================================================================================
        # SELECT MODEL
        # ==============================================================================================================
        # flow parameters and architecture choice are passed on to model through args

        if args.flow == 'no_flow':
            model = VAE.VAE(args)
        elif args.flow == 'planar':
            model = VAE.PlanarVAE(args)
        elif args.flow == 'iaf':
            model = VAE.IAFVAE(args)
        elif args.flow == 'orthogonal':
            model = VAE.OrthogonalSylvesterVAE(args)
        elif args.flow == 'householder':
            model = VAE.HouseholderSylvesterVAE(args)
        elif args.flow == 'triangular':
            model = VAE.TriangularSylvesterVAE(args)
        elif args.flow == 'cnf':
            model = CNFVAE.CNFVAE(args)
        elif args.flow == 'cnf_bias':
            model = CNFVAE.AmortizedBiasCNFVAE(args)
        elif args.flow == 'cnf_hyper':
            model = CNFVAE.HypernetCNFVAE(args)
        elif args.flow == 'cnf_lyper':
            model = CNFVAE.LypernetCNFVAE(args)
        elif args.flow == 'cnf_rank':
            model = CNFVAE.AmortizedLowRankCNFVAE(args)
        else:
            raise ValueError('Invalid flow choice')

        if args.retrain_encoder:
            logger.info(f"Initializing decoder from {args.model_path}")
            dec_model = torch.load(args.model_path)
            dec_sd = {}
            for k, v in dec_model.state_dict().items():
                if 'p_x' in k:
                    dec_sd[k] = v
            model.load_state_dict(dec_sd, strict=False)

        if args.cuda:
            logger.info("Model on GPU")
            model.cuda()

        logger.info(model)
        logger.info("Number of trainable parameters: {}".format(
            count_parameters(model)))

        if args.retrain_encoder:
            parameters = []
            logger.info('Optimizing over:')
            for name, param in model.named_parameters():
                if 'p_x' not in name:
                    logger.info(name)
                    parameters.append(param)
        else:
            parameters = model.parameters()

        optimizer = optim.Adamax(parameters, lr=args.learning_rate, eps=1.e-7)

        # ==================================================================================================================
        # TRAINING
        # ==================================================================================================================
        train_loss = []
        val_loss = []

        # for early stopping
        best_loss = np.inf
        best_bpd = np.inf
        e = 0
        epoch = 0

        train_times = []

        for epoch in range(1, args.epochs + 1):
            atol, rtol = update_tolerances(args, epoch, decay_factors)
            set_cnf_options(args, atol, rtol, model)

            t_start = time.time()

            if 'cnf' not in args.flow:
                tr_loss = train(epoch, train_loader, model, optimizer, args,
                                logger)
            else:
                tr_loss, nfef_meter, nfeb_meter = train(
                    epoch, train_loader, model, optimizer, args, logger,
                    nfef_meter, nfeb_meter)

            train_loss.append(tr_loss)
            train_times.append(time.time() - t_start)
            logger.info('One training epoch took %.2f seconds' %
                        (time.time() - t_start))

            v_loss, v_bpd = evaluate(val_loader,
                                     model,
                                     args,
                                     logger,
                                     epoch=epoch)

            val_loss.append(v_loss)

            # early-stopping
            if v_loss < best_loss:
                e = 0
                best_loss = v_loss
                if args.input_type != 'binary':
                    best_bpd = v_bpd
                logger.info('->model saved<-')
                torch.save(model, args.snap_dir + 'model.model')
                # torch.save(model, snap_dir + args.flow + '_' + args.architecture + '.model')

            elif (args.early_stopping_epochs > 0) and (epoch >= args.warmup):
                e += 1
                if e > args.early_stopping_epochs:
                    break

            if args.input_type == 'binary':
                logger.info(
                    '--> Early stopping: {}/{} (BEST: loss {:.4f})\n'.format(
                        e, args.early_stopping_epochs, best_loss))

            else:
                logger.info(
                    '--> Early stopping: {}/{} (BEST: loss {:.4f}, bpd {:.4f})\n'
                    .format(e, args.early_stopping_epochs, best_loss,
                            best_bpd))

            if math.isnan(v_loss):
                raise ValueError('NaN encountered!')

        train_loss = np.hstack(train_loss)
        val_loss = np.array(val_loss)

        plot_training_curve(train_loss,
                            val_loss,
                            fname=args.snap_dir + '/training_curve.pdf')

        # training time per epoch
        train_times = np.array(train_times)
        mean_train_time = np.mean(train_times)
        std_train_time = np.std(train_times, ddof=1)
        logger.info('Average train time per epoch: %.2f +/- %.2f' %
                    (mean_train_time, std_train_time))

        # ==================================================================================================================
        # EVALUATION
        # ==================================================================================================================

        logger.info(args)
        logger.info('Stopped after %d epochs' % epoch)
        logger.info('Average train time per epoch: %.2f +/- %.2f' %
                    (mean_train_time, std_train_time))

        final_model = torch.load(args.snap_dir + 'model.model')
        validation_loss, validation_bpd = evaluate(val_loader, final_model,
                                                   args, logger)

    else:
        validation_loss = "N/A"
        validation_bpd = "N/A"
        logger.info(f"Loading model from {args.model_path}")
        final_model = torch.load(args.model_path)

    test_loss, test_bpd = evaluate(test_loader,
                                   final_model,
                                   args,
                                   logger,
                                   testing=True)

    logger.info(
        'FINAL EVALUATION ON VALIDATION SET. ELBO (VAL): {:.4f}'.format(
            validation_loss))
Example #14
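# Semi-supervised CIFAR-10 training (MixMatch-style) combined with a CNF and
# a Gaussian-mixture prior; metrics are logged to TensorBoard and a Logger.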
def main():
    global best_acc

    if not os.path.isdir(args.out):
        mkdir_p(args.out)

    # Data
    print('==> Preparing cifar10')
    transform_train = transforms.Compose([
        dataset.RandomPadandCrop(32),
        dataset.RandomFlip(),
        dataset.ToTensor(),
    ])

    transform_val = transforms.Compose([
        dataset.ToTensor(),
    ])

    train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10(
        '/home/fengchan/stor/dataset/original-data/cifar10',
        args.n_labeled,
        transform_train=transform_train,
        transform_val=transform_val)
    labeled_trainloader = data.DataLoader(train_labeled_set,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=0,
                                          drop_last=True)
    unlabeled_trainloader = data.DataLoader(train_unlabeled_set,
                                            batch_size=args.batch_size,
                                            shuffle=True,
                                            num_workers=0,
                                            drop_last=True)
    val_loader = data.DataLoader(val_set,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=0)
    test_loader = data.DataLoader(test_set,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0)

    # Model
    print("==> creating WRN-28-2")

    def create_model(ema=False):
        model = models.WideResNet(num_classes=num_classes)
        model = model.cuda()

        if ema:
            for param in model.parameters():
                param.detach_()

        return model

    data_shape = [3, 32, 32]

    regularization_fns, regularization_coeffs = create_regularization_fns(args)

    def create_cnf():
        # generate cnf
        # cnf = create_cnf_model_1(args, data_shape, regularization_fns=None)
        # cnf = create_cnf_model(args, data_shape, regularization_fns=regularization_fns)
        cnf = create_nf_model(args, data_shape, regularization_fns=None)
        cnf = cnf.cuda() if use_cuda else cnf
        return cnf

    model = create_model()
    ema_model = create_model(ema=True)
    cnf = create_cnf()

    if args.spectral_norm:
        add_spectral_norm(cnf, logger)
        set_cnf_options(args, cnf)

    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    train_criterion = SemiLoss()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    #CNF
    cnf_optimizer = optim.Adam(cnf.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)

    ema_optimizer = WeightEMA(model, ema_model, alpha=args.ema_decay)
    start_epoch = 0

    # Resume
    #generate prior
    means = generate_gaussian_means(num_classes, data_shape, seed=num_classes)
    title = 'noisy-cifar-10'
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint directory found!'
        args.out = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        ema_model.load_state_dict(checkpoint['ema_state_dict'])
        cnf.load_state_dict(checkpoint['cnf_state_dict'])
        means = checkpoint['means']
        cnf_optimizer.load_state_dict(checkpoint['cnf_optimizer'])
        optimizer.load_state_dict(checkpoint['optimizer'])

        logger = Logger(os.path.join(args.out, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        logger = Logger(os.path.join(args.out, 'log.txt'), title=title)
        logger.set_names([
            'Train Loss', 'Train Loss X', 'Train Loss U', 'Train loss NLL X',
            'Train loss NLL U', 'Train loss mixed X', 'Valid Loss',
            'Valid Acc.', 'Test Loss', 'Test Acc.'
        ])

    means = means.cuda() if use_cuda else means
    prior = SSLGaussMixture(means, device='cuda' if use_cuda else 'cpu')

    writer = SummaryWriter(args.out)
    step = 0
    test_accs = []
    # Train and val
    for epoch in range(start_epoch, args.epochs):

        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, state['lr']))

        train_loss, train_loss_x, train_loss_u, train_loss_nll_x, train_loss_nll_u, train_loss_mixed_x = train(
            labeled_trainloader, unlabeled_trainloader, model, cnf, prior,
            cnf_optimizer, optimizer, ema_optimizer, train_criterion, epoch,
            use_cuda)
        _, train_acc = validate(labeled_trainloader,
                                ema_model,
                                criterion,
                                epoch,
                                use_cuda,
                                mode='Train Stats')
        val_loss, val_acc = validate(val_loader,
                                     ema_model,
                                     criterion,
                                     epoch,
                                     use_cuda,
                                     mode='Valid Stats')
        test_loss, test_acc = validate(test_loader,
                                       ema_model,
                                       criterion,
                                       epoch,
                                       use_cuda,
                                       mode='Test Stats ')

        step = args.train_iteration * (epoch + 1)

        writer.add_scalar('losses/train_loss', train_loss, step)
        writer.add_scalar('losses/train_loss_nll_x', train_loss_nll_x, step)
        writer.add_scalar('losses/train_loss_nll_u', train_loss_nll_u, step)
        writer.add_scalar('losses/train_loss_mixed_x', train_loss_mixed_x,
                          step)
        writer.add_scalar('losses/valid_loss', val_loss, step)
        writer.add_scalar('losses/test_loss', test_loss, step)

        writer.add_scalar('accuracy/train_acc', train_acc, step)
        writer.add_scalar('accuracy/val_acc', val_acc, step)
        writer.add_scalar('accuracy/test_acc', test_acc, step)

        # append logger file
        logger.append([
            train_loss, train_loss_x, train_loss_u, train_loss_nll_x,
            train_loss_nll_u, train_loss_mixed_x, val_loss, val_acc, test_loss,
            test_acc
        ])

        # save model
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'cnf_state_dict': cnf.state_dict(),
                'means': means,
                'ema_state_dict': ema_model.state_dict(),
                'acc': val_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'cnf_optimizer': cnf_optimizer.state_dict(),
            }, is_best)
        test_accs.append(test_acc)
    logger.close()
    writer.close()

    print('Best acc:')
    print(best_acc)

    print('Mean acc:')
    print(np.mean(test_accs[-20:]))
def build_augmented_model_tabular(args, dims, regularization_fns=None):
    """
    The function used for creating conditional Continuous Normlizing Flow
    with augmented neural ODE

    Parameters:
        args: arguments used to create conditional CNF. Check args parser for details.
        dims: dimension of the input. Currently only allow 1-d input.
        regularization_fns: regularizations applied to the ODE function

    Returns:
        a ctfp model based on augmened neural ode
    """
    hidden_dims = tuple(map(int, args.dims.split(",")))
    if args.aug_hidden_dims is not None:
        aug_hidden_dims = tuple(map(int, args.aug_hidden_dims.split(",")))
    else:
        aug_hidden_dims = None

    def build_cnf():
        diffeq = layers.AugODEnet(
            hidden_dims=hidden_dims,
            input_shape=(dims, ),
            effective_shape=args.effective_shape,
            strides=None,
            conv=False,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
            aug_dim=args.aug_dim,
            aug_mapping=args.aug_mapping,
            aug_hidden_dims=args.aug_hidden_dims,
        )
        odefunc = layers.AugODEfunc(
            diffeq=diffeq,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
            effective_shape=args.effective_shape,
        )
        cnf = layers.CNF(
            odefunc=odefunc,
            T=args.time_length,
            train_T=args.train_T,
            regularization_fns=regularization_fns,
            solver=args.solver,
            rtol=args.rtol,
            atol=args.atol,
        )
        return cnf

    chain = [build_cnf() for _ in range(args.num_blocks)]
    if args.batch_norm:
        bn_layers = [
            layers.MovingBatchNorm1d(dims,
                                     bn_lag=args.bn_lag,
                                     effective_shape=args.effective_shape)
            for _ in range(args.num_blocks)
        ]
        bn_chain = [
            layers.MovingBatchNorm1d(dims,
                                     bn_lag=args.bn_lag,
                                     effective_shape=args.effective_shape)
        ]
        for a, b in zip(chain, bn_layers):
            bn_chain.append(a)
            bn_chain.append(b)
        chain = bn_chain
    model = layers.SequentialFlow(chain)
    set_cnf_options(args, model)

    return model
Example #16
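# CPU variant of an image-CNF training script: logs training/test metrics to
# CSV, periodically validates (bits/dim, NFE, transport cost), and saves
# sample grids decoded from a fixed latent batch.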
def main():
    #os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                                  filepath=os.path.abspath(__file__))

        logger.info(args)

        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log: logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend,
                                       init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(),
                                       rank=env_rank())
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log:
            logger.info("Distributed: success (%d/%d)" %
                        (args.local_rank, distributed.get_world_size()))

    # get device (this variant runs on CPU; the CUDA line is kept for reference)
    # device = torch.device("cuda:%d" % torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
    device = "cpu"
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = [
        'itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time',
        'grad_norm'
    ]
    testcolumns = [
        'wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time',
        'transport_cost'
    ]

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns)
    # model = model.cuda()
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns,
                                                     regularization_fns)

    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log: logger.info(model)
    if write_log:
        logger.info("Number of trainable parameters: {}".format(
            count_parameters(model)))
    if write_log:
        logger.info('Iters per train epoch: {}'.format(len(train_loader)))
    if write_log: logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay,
                              momentum=0.9,
                              nesterov=False)

    # restore parameters
    if args.resume is not None:
        checkpt = torch.load(
            args.resume,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization.
    if write_log:
        fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
    else:
        chkdir = os.path.dirname(args.resume)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] +
                          1)  # not exactly correct

    if args.distributed:
        if write_log: logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, args.num_epochs + 1):
        if not args.validate:
            model.train()

            with open(trainlog, 'a') as f:
                if write_log: csvlogger = csv.DictWriter(f, traincolumns)

                for _, (x, y) in enumerate(train_loader):
                    start = time.time()
                    update_lr(optimizer, itr)
                    optimizer.zero_grad()

                    # cast data and move to device
                    x = add_noise(cvt(x), nbits=args.nbits)
                    #x = x.clamp_(min=0, max=1)
                    # compute loss
                    bpd, (x, z), reg_states = compute_bits_per_dim(x, model)
                    if np.isnan(bpd.data.item()):
                        raise ValueError('model returned nan during training')
                    elif np.isinf(bpd.data.item()):
                        raise ValueError('model returned inf during training')

                    loss = bpd
                    if regularization_coeffs:
                        reg_loss = sum(reg_state * coeff
                                       for reg_state, coeff in zip(
                                           reg_states, regularization_coeffs)
                                       if coeff != 0)
                        loss = loss + reg_loss
                    total_time = count_total_time(model)

                    loss.backward()
                    nfe_opt = count_nfe(model)
                    if write_log: steps_meter.update(nfe_opt)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.max_grad_norm)

                    optimizer.step()

                    itr_time = time.time() - start
                    wall_clock += itr_time

                    batch_size = x.size(0)
                    metrics = torch.tensor([
                        1., batch_size,
                        loss.item(),
                        bpd.item(), nfe_opt, grad_norm, *reg_states
                    ]).float()

                    total_gpus, batch_total, r_loss, r_bpd, r_nfe, r_grad_norm, *rv = dist_utils.sum_tensor(
                        metrics).cpu().numpy()

                    if write_log:
                        time_meter.update(itr_time)
                        bpd_meter.update(r_bpd / total_gpus)
                        loss_meter.update(r_loss / total_gpus)
                        grad_meter.update(r_grad_norm / total_gpus)
                        tt_meter.update(total_time)

                        fmt = '{:.4f}'
                        logdict = {
                            'itr': itr,
                            'wall': fmt.format(wall_clock),
                            'itr_time': fmt.format(itr_time),
                            'loss': fmt.format(r_loss / total_gpus),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'total_time': fmt.format(total_time),
                            'fe': r_nfe / total_gpus,
                            'grad_norm': fmt.format(r_grad_norm / total_gpus),
                        }
                        if regularization_coeffs:
                            rv = tuple(v_ / total_gpus for v_ in rv)
                            logdict = append_regularization_csv_dict(
                                logdict, regularization_fns, rv)
                        csvlogger.writerow(logdict)

                        if itr % args.log_freq == 0:
                            log_message = (
                                "Itr {:06d} | Wall {:.3e}({:.2f}) | "
                                "Time/Itr {:.2f}({:.2f}) | BPD {:.2f}({:.2f}) | "
                                "Loss {:.2f}({:.2f}) | "
                                "FE {:.0f}({:.0f}) | Grad Norm {:.3e}({:.3e}) | "
                                "TT {:.2f}({:.2f})".format(
                                    itr, wall_clock, wall_clock / (itr + 1),
                                    time_meter.val, time_meter.avg,
                                    bpd_meter.val, bpd_meter.avg,
                                    loss_meter.val, loss_meter.avg,
                                    steps_meter.val, steps_meter.avg,
                                    grad_meter.val, grad_meter.avg,
                                    tt_meter.val, tt_meter.avg))
                            if regularization_coeffs:
                                log_message = append_regularization_to_log(
                                    log_message, regularization_fns, rv)
                            logger.info(log_message)

                    itr += 1

        # compute test loss
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            torch.save(
                {
                    "args":
                    args,
                    "state_dict":
                    model.module.state_dict()
                    if torch.cuda.is_available() else model.state_dict(),
                    "optim_state_dict":
                    optimizer.state_dict(),
                    "fixed_z":
                    fixed_z.cpu()
                }, os.path.join(args.save, "checkpt.pth"))
        if epoch % args.val_freq == 0 or args.validate:
            with open(testlog, 'a') as f:
                if write_log: csvlogger = csv.DictWriter(f, testcolumns)
                with torch.no_grad():
                    start = time.time()
                    if write_log: logger.info("validating...")

                    lossmean = 0.
                    meandist = 0.
                    steps = 0
                    tt = 0.
                    for i, (x, y) in enumerate(test_loader):
                        x = shift(cvt(x), nbits=args.nbits)
                        loss, (x, z), _ = compute_bits_per_dim(x, model)
                        dist = (x.view(x.size(0), -1) -
                                z).pow(2).mean(dim=-1).mean()
                        # running averages over test batches
                        meandist = i / (i + 1) * dist + meandist / (i + 1)
                        lossmean = i / (i + 1) * lossmean + loss / (i + 1)

                        tt = (i / (i + 1) * tt +
                              count_total_time(model) / (i + 1))
                        steps = (i / (i + 1) * steps +
                                 count_nfe(model) / (i + 1))

                    loss = lossmean.item()
                    metrics = torch.tensor([1., loss, meandist, steps]).float()

                    total_gpus, r_bpd, r_mdist, r_steps = dist_utils.sum_tensor(
                        metrics).cpu().numpy()
                    eval_time = time.time() - start

                    if write_log:
                        fmt = '{:.4f}'
                        logdict = {
                            'epoch': epoch,
                            'eval_time': fmt.format(eval_time),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'wall': fmt.format(wall_clock),
                            'total_time': fmt.format(tt),
                            'transport_cost': fmt.format(r_mdist / total_gpus),
                            'fe': '{:.2f}'.format(r_steps / total_gpus)
                        }

                        csvlogger.writerow(logdict)

                        logger.info(
                            "Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}, Steps {:.4f}, TT {:.2f}, Transport Cost {:.2e}"
                            .format(epoch, eval_time, r_bpd / total_gpus,
                                    r_steps / total_gpus, tt,
                                    r_mdist / total_gpus))

                    loss = r_bpd / total_gpus

                    if loss < best_loss and args.local_rank == 0:
                        best_loss = loss
                        shutil.copyfile(os.path.join(args.save, "checkpt.pth"),
                                        os.path.join(args.save, "best.pth"))

            # visualize samples and density
            if write_log:
                with torch.no_grad():
                    fig_filename = os.path.join(args.save, "figs",
                                                "{:04d}.jpg".format(epoch))
                    utils.makedirs(os.path.dirname(fig_filename))
                    generated_samples, _, _ = model(fixed_z, reverse=True)
                    generated_samples = generated_samples.view(-1, *data_shape)
                    nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
                    save_image(unshift(generated_samples, nbits=args.nbits),
                               fig_filename,
                               nrow=nb)
            if args.validate:
                break
Example #17
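    # Fragment: body of a loss helper (negative mean log-likelihood under the
    # flow); the enclosing `def` line is not part of the snippet.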
    x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    x = torch.from_numpy(x).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)
    z, change = model(x, zero)

    logpx = standard_normal_logprob(z).sum(1, keepdim=True) - change
    loss = -torch.mean(logpx)
    return loss


if __name__ == '__main__':

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, 2, regularization_fns).to(device)
    if args.spectral_norm: add_spectral_norm(model)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
    best_loss = float('inf')
    model.train()
Example #18
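# Evaluation script: restores a trained CNF over single-cell data, integrates
# the end-time samples backwards through the flow, and (past the early exit)
# computes Kantorovich/EMD losses.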
def main(args):
    device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    if args.use_cpu:
        device = torch.device("cpu")

    data = dataset.SCData.factory(args.dataset, args)

    args.timepoints = data.get_unique_times()

    # Use maximum timepoint to establish integration_times
    # as some timepoints may be left out for validation etc.
    args.int_tps = (np.arange(max(args.timepoints) + 1) +
                    1.0) * args.time_scale

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args,
                                data.get_shape()[0],
                                regularization_fns).to(device)
    if args.use_growth:
        growth_model_path = data.get_growth_net_path()
        #growth_model_path = "/home/atong/TrajectoryNet/data/externel/growth_model_v2.ckpt"
        growth_model = torch.load(growth_model_path, map_location=device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    state_dict = torch.load(args.save + "/checkpt.pth", map_location=device)
    model.load_state_dict(state_dict["state_dict"])

    #plot_output(device, args, model, data)
    #exit()
    # get_trajectory_samples(device, model, data)

    args.data = data
    args.timepoints = args.data.get_unique_times()
    args.int_tps = (np.arange(max(args.timepoints) + 1) +
                    1.0) * args.time_scale

    print('integrating backwards')
    #end_time_data = data.data_dict[args.embedding_name]
    end_time_data = data.get_data()[args.data.get_times() == np.max(
        args.data.get_times())]
    #np.random.permutation(end_time_data)
    #rand_idx = np.random.randint(end_time_data.shape[0], size=5000)
    #end_time_data = end_time_data[rand_idx,:]
    integrate_backwards(end_time_data,
                        model,
                        args.save,
                        ntimes=100,
                        device=device)
    exit()  # NOTE: everything below is unreachable while this early exit is in place
    losses_list = []
    #for factor in np.linspace(0.05, 0.95, 19):
    #for factor in np.linspace(0.91, 0.99, 9):
    if args.dataset == 'CHAFFER':  # Do timepoint adjustment
        print('adjusting_timepoints')
        lt = args.leaveout_timepoint
        if lt == 1:
            factor = 0.95  # overrides an earlier fitted value, 0.6799872494335812
        elif lt == 2:
            factor = 0.01  # overrides an earlier fitted value, 0.2905983814032348
        else:
            raise RuntimeError('Unknown timepoint %d' %
                               args.leaveout_timepoint)
        args.int_tps[lt] = (
            1 - factor) * args.int_tps[lt - 1] + factor * args.int_tps[lt + 1]
    losses = eval_utils.evaluate_kantorovich_v2(device, args, model)
    losses_list.append(losses)
    print(np.array(losses_list))
    np.save(os.path.join(args.save, 'emd_list'), np.array(losses_list))