Example #1
    def __init__(self, args):
        super(CNFVAE, self).__init__(args)

        # CNF model
        self.cnf = build_model_tabular(args, args.z_size)

        if args.cuda:
            self.cuda()
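All of these examples follow the same FFJORD pattern: build the CNF with build_model_tabular, apply the ODE-solver settings with set_cnf_options, and score data by flowing it to a standard-normal base distribution and subtracting the accumulated log-determinant. The sketch below condenses that pattern from Examples #6 and #12; it assumes the FFJORD repository's train_misc helpers are importable, and make_flow / log_density are hypothetical wrapper names rather than functions from any source above.

import math
import torch
from train_misc import build_model_tabular, set_cnf_options  # FFJORD helpers

def standard_normal_logprob(z):
    # elementwise log N(z; 0, I), as defined in the FFJORD training scripts
    return -0.5 * math.log(2 * math.pi) - z.pow(2) / 2

def make_flow(args, dims):
    # hypothetical wrapper: build the CNF, then apply the solver options from args
    model = build_model_tabular(args, dims)
    set_cnf_options(args, model)
    return model

def log_density(model, x):
    # start the accumulated log-det-Jacobian at zero for every sample
    zero = torch.zeros(x.shape[0], 1).to(x)
    z, delta_logp = model(x, zero)
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)
    # change of variables: log p(x) = log p(z) - delta_logp
    return logpz - delta_logp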
Example #2
def visualize_evolution():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    checkpt = torch.load(os.path.join(args.save, 'checkpt.pth'))
    model.load_state_dict(checkpt['state_dict'])
    model.to(device)

    viz_times = torch.linspace(0., args.time_length, args.ntimes)
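    # t = 0 is skipped below (viz_times[1:]): at time zero the flow is the identity map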
    errors = []
    viz_times_np = viz_times[1:].detach().cpu().numpy()
    xx = torch.linspace(-5, 5, args.num_particles).view(-1, 1)
    xx_np = xx.detach().cpu().numpy()
    xs, ys = np.meshgrid(xx_np, viz_times_np)
    all_evolutions = np.zeros((args.num_particles, args.ntimes - 1))
    with torch.no_grad():
        for i, t in enumerate(tqdm(viz_times[1:])):
            model.eval()
            set_cnf_options(args, model)
            generated_p = 0
            for cnf in model.chain:
                xx = xx.to(device)
                z, delta_logp = cnf(xx,
                                    torch.zeros_like(xx),
                                    integration_times=torch.Tensor([0, t]))
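                # instantaneous change of variables: log p_x(x) = log p_z(z) - delta_logp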
                generated_p = standard_normal_logprob(z) - delta_logp

            generated_p = generated_p.detach()
            cur_evolution = generated_p.view(-1).exp().cpu().numpy()

            all_evolutions[:, i] = cur_evolution
        plt.figure(dpi=1200)
        plt.clf()
        all_evolutions = all_evolutions.astype('float32')
        print(xs.shape)
        print(ys.shape)
        print(all_evolutions.shape)
        plt.pcolormesh(xs, ys, all_evolutions.transpose())

        utils.makedirs(os.path.join(args.save, 'test_times', 'figs'))
        plt.savefig(
            os.path.join(args.save, 'test_times', 'figs', 'evolution.jpg'))
        plt.close()
Example #3
def visualize_particle_flow():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    checkpt = torch.load(os.path.join(args.save, 'checkpt.pth'))
    model.load_state_dict(checkpt['state_dict'])
    model.to(device)

    viz_times = torch.linspace(0., args.time_length, args.ntimes)
    errors = []
    xx = torch.linspace(-5, 5, args.num_particles).view(-1, 1)
    zs = []
    with torch.no_grad():
        for i, t in enumerate(tqdm(viz_times[1:])):
            model.eval()
            set_cnf_options(args, model)

            for cnf in model.chain:
                xx = xx.to(device)
                z, delta_logp = cnf(xx,
                                    torch.zeros_like(xx),
                                    integration_times=torch.Tensor([0, t]))

            zs.append(z.cpu().numpy())

    zs = np.array(zs).reshape(args.ntimes - 1, args.num_particles)
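    # zs[i, j] is the position of particle j after integrating to viz_times[i + 1]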
    viz_t = viz_times[1:].numpy()
    plt.figure(dpi=1200)
    plt.clf()
    with sns.color_palette("Blues_d"):
        plt.plot(viz_t, zs)
        plt.xlabel("Test Time")
        utils.makedirs(os.path.join(args.save, 'test_times', 'figs'))
        plt.savefig(
            os.path.join(args.save, 'test_times', 'figs',
                         'particle_trajectory.jpg'))
        plt.close()
Example #4
def gen_model(scale=10, fraction=0.5):
    #build normalizing flow model from previous fit
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args = pkl.load(open('args.pkl', 'rb'))
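    # rebuild the flow exactly as it was configured during training, using the pickled args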
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, 5, regularization_fns).to(device)
    if args.spectral_norm: add_spectral_norm(model)
    set_cnf_options(args, model)
    model.load_state_dict(torch.load('model_10000.pt'))

    return model
Example #5
    def get_ckpt_model_and_data(args):
        # Load checkpoint.
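        # map_location=lambda storage, loc: storage loads every tensor onto the CPU,
        # regardless of the device it was saved from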
        checkpt = torch.load(args.checkpt, map_location=lambda storage, loc: storage)
        ckpt_args = checkpt['args']
        state_dict = checkpt['state_dict']

        # Construct model and restore checkpoint.
        regularization_fns, regularization_coeffs = create_regularization_fns(ckpt_args)
        model = build_model_tabular(ckpt_args, 2, regularization_fns).to(device)
        if ckpt_args.spectral_norm: add_spectral_norm(model)
        set_cnf_options(ckpt_args, model)

        model.load_state_dict(state_dict)
        model.to(device)

        print(model)
        print("Number of trainable parameters: {}".format(count_parameters(model)))

        # Load samples from dataset
        data_samples = toy_data.inf_train_gen(ckpt_args.data, batch_size=2000)

        return model, data_samples
Example #6
def compute_loss(args, model, batch_size=args.batch_size):

    x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    x = torch.from_numpy(x).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)
    z, change = model(x, zero)

    logpx = standard_normal_logprob(z).sum(1, keepdim=True) - change
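    # negative log-likelihood of the batch under the flow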
    loss = -torch.mean(logpx)
    return loss


if __name__ == '__main__':

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, 2, regularization_fns).to(device)
    if args.spectral_norm: add_spectral_norm(model)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
Example #7
if __name__ == '__main__':

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    logger.info('Using {} GPUs.'.format(torch.cuda.device_count()))

    data = load_data(args.data)
    data.trn.x = torch.from_numpy(data.trn.x)
    data.val.x = torch.from_numpy(data.val.x)
    data.tst.x = torch.from_numpy(data.tst.x)

    args.dims = '-'.join([str(args.hdim_factor * data.n_dims)] * args.nhidden)
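    # e.g. with hdim_factor=10, data.n_dims=8, nhidden=3 this yields dims="80-80-80"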

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, data.n_dims,
                                regularization_fns).to(device)
    set_cnf_options(args, model)

    for k in model.state_dict().keys():
        logger.info(k)

    if args.resume is not None:
        checkpt = torch.load(args.resume)

        # Backwards compatibility with an older version of the code.
        # TODO: remove upon release.
        filtered_state_dict = {}
        for k, v in checkpt['state_dict'].items():
            if 'diffeq.diffeq' not in k:
                filtered_state_dict[k.replace('module.', '')] = v
        model.load_state_dict(filtered_state_dict)
Example #8

if __name__ == '__main__':
    # only a single block of diffeq is supported now
    assert args.num_blocks == 1
    centers = DEFAULT_CENTERS.to(device)
    dim = centers.shape[1]
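    # the convection field is the negative score (gradient of the log-density) of the mixture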
    convection = lambda x: -gaussian_mixture_score(x, centers)

    writer = SummaryWriter('out/wgf/gaussian')

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    regularization_fns = None  # drop the regularization hooks; the model is built without them
    model = build_model_tabular(args=args,
                                dims=dim,
                                convection=convection,
                                regularization_fns=regularization_fns,
                                exp_decay=args.exp_decay).to(device)

    model_validate_dvp = build_model_compare_DVP(
        args=args,
        convection=convection,
        mollifier=IsometricGaussianMollifier(args.mollifier_sigma_square),
        diffeq=model.chain[0].odefunc.diffeq).to(device)

    if args.spectral_norm: add_spectral_norm(model)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(
        count_parameters(model)))
Example #9
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    data = load_data(args.data)
    data.trn.x = torch.from_numpy(data.trn.x)
    data.val.x = torch.from_numpy(data.val.x)
    data.tst.x = torch.from_numpy(data.tst.x)

    args.dims = '-'.join([str(args.hdim_factor * data.n_dims)] * args.nhidden)

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model, cnfs = build_model_tabular(
        args,
        data.n_dims,
        regularization_fns,
        return_intermediate_points=args.return_inter_points)
    model = model.to(device)
    set_cnf_options(args, model)

    if args.resume is not None:
        checkpt = torch.load(args.resume)

        # Backwards compatibility with an older version of the code.
        # TODO: remove upon release.
        filtered_state_dict = {}
        for k, v in checkpt['state_dict'].items():
            if 'diffeq.diffeq' not in k:
                filtered_state_dict[k.replace('module.', '')] = v
        model.load_state_dict(filtered_state_dict)
Example #10
def train():

    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(
        count_parameters(model)))

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
    best_loss = float('inf')
    model.train()
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()

        loss = compute_loss(args, model)
        loss_meter.update(loss.item())

        total_time = count_total_time(model)
        nfe_forward = count_nfe(model)
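        # snapshot the forward-pass NFE now; backward() will add more function evaluations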

        loss.backward()
        optimizer.step()

        nfe_total = count_nfe(model)
        nfe_backward = nfe_total - nfe_forward
        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)

        time_meter.update(time.time() - end)
        tt_meter.update(total_time)

        log_message = (
            'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})'
            ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format(
                itr, time_meter.val, time_meter.avg, loss_meter.val,
                loss_meter.avg, nfef_meter.val, nfef_meter.avg, nfeb_meter.val,
                nfeb_meter.avg, tt_meter.val, tt_meter.avg))
        logger.info(log_message)

        if itr % args.val_freq == 0 or itr == args.niters:
            with torch.no_grad():
                model.eval()
                test_loss = compute_loss(args,
                                         model,
                                         batch_size=args.test_batch_size)
                test_nfe = count_nfe(model)
                log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(
                    itr, test_loss, test_nfe)
                logger.info(log_message)

                if test_loss.item() < best_loss:
                    best_loss = test_loss.item()
                    utils.makedirs(args.save)
                    torch.save(
                        {
                            'args': args,
                            'state_dict': model.state_dict(),
                        }, os.path.join(args.save, 'checkpt.pth'))
                model.train()

        if itr % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()

                xx = torch.linspace(-10, 10, 10000).view(-1, 1)
                true_p = data_density(xx)
                plt.plot(xx.view(-1).cpu().numpy(),
                         true_p.view(-1).exp().cpu().numpy(),
                         label='True')

                model_p = model_density(xx, model)
                plt.plot(xx.view(-1).cpu().numpy(),
                         model_p.view(-1).exp().cpu().numpy(),
                         label='Model')

                utils.makedirs(os.path.join(args.save, 'figs'))
                plt.savefig(
                    os.path.join(args.save, 'figs', '{:06d}.jpg'.format(itr)))
                plt.close()

                model.train()

        end = time.time()

    logger.info('Training has finished.')
Example #11
    def density_fn(x, logpx=None):
        if logpx is not None:
            return model(x, logpx, reverse=False)
        else:
            return model(x, reverse=False)

    return sample_fn, density_fn


if __name__ == '__main__':

    if args.discrete:
        model = construct_discrete_model().to(device)
        model.load_state_dict(torch.load(args.checkpt)['state_dict'])
    else:
        model = build_model_tabular(args, 2).to(device)

        sd = torch.load(args.checkpt)['state_dict']
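        # older checkpoints saved the ODE function under a doubly nested 'odefunc.odefunc' key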
        fixed_sd = {}
        for k, v in sd.items():
            fixed_sd[k.replace('odefunc.odefunc', 'odefunc')] = v
        model.load_state_dict(fixed_sd)

    print(model)
    print("Number of trainable parameters: {}".format(count_parameters(model)))

    model.eval()
    p_samples = toy_data.inf_train_gen(args.data, batch_size=800**2)

    with torch.no_grad():
        sample_fn, density_fn = get_transforms(model)
Example #12
    # transform to z
    z, delta_logp = model(x, zero)

    # compute log q(z)
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)

    logpx = logpz - delta_logp
    loss = -torch.mean(logpx)
    return loss


if __name__ == '__main__':

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    input_dim = 2 if args.data != 'HDline' else 16
    model = build_model_tabular(args, input_dim, regularization_fns).to(device)
    if args.spectral_norm: add_spectral_norm(model)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(
        count_parameters(model)))

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
Example #13
def main(args):
    # logger
    utils.makedirs(args.save)
    logger = utils.get_logger(
        logpath=os.path.join(args.save, "logs"),
        filepath=os.path.abspath(__file__),
        displaying=not args.no_display_loss,
    )

    if args.layer_type == "blend":
        logger.info("!! Setting time_scale from None to 1.0 for Blend layers.")
        args.time_scale = 1.0

    logger.info(args)

    device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu"
    )
    if args.use_cpu:
        device = torch.device("cpu")

    args.data = dataset.SCData.factory(args.dataset, args.max_dim)

    args.timepoints = args.data.get_unique_times()
    # Use maximum timepoint to establish integration_times
    # as some timepoints may be left out for validation etc.
    args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, args.data.get_shape()[0], regularization_fns).to(
        device
    )
    growth_model = None
    if args.use_growth:
        if args.leaveout_timepoint == -1:
            growth_model_path = "../data/externel/growth_model_v2.ckpt"
        elif args.leaveout_timepoint in [1, 2, 3]:
            assert args.max_dim == 5
            growth_model_path = ("../data/growth/model_%d" %
                                 args.leaveout_timepoint)
        else:
            raise RuntimeError("Cannot use growth with timepoint %d" %
                               args.leaveout_timepoint)
        growth_model = torch.load(growth_model_path, map_location=device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    if args.test:
        state_dict = torch.load(args.save + "/checkpt.pth", map_location=device)
        model.load_state_dict(state_dict["state_dict"])
    else:
        logger.info(model)
        n_param = count_parameters(model)
        logger.info("Number of trainable parameters: {}".format(n_param))

        train(
            device,
            args,
            model,
            growth_model,
            regularization_coeffs,
            regularization_fns,
            logger,
        )

    if args.data.data.shape[1] == 2:
        plot_output(device, args, model)
Example #14
def main(args):
    device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    if args.use_cpu:
        device = torch.device("cpu")

    data = dataset.SCData.factory(args.dataset, args)

    args.timepoints = data.get_unique_times()

    # Use maximum timepoint to establish integration_times
    # as some timepoints may be left out for validation etc.
    args.int_tps = (np.arange(max(args.timepoints) + 1) +
                    1.0) * args.time_scale

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args,
                                data.get_shape()[0],
                                regularization_fns).to(device)
    if args.use_growth:
        growth_model_path = data.get_growth_net_path()
        growth_model = torch.load(growth_model_path, map_location=device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    state_dict = torch.load(args.save + "/checkpt.pth", map_location=device)
    model.load_state_dict(state_dict["state_dict"])

    args.data = data
    args.timepoints = args.data.get_unique_times()
    args.int_tps = (np.arange(max(args.timepoints) + 1) +
                    1.0) * args.time_scale

    print('integrating backwards')
    end_time_data = data.get_data()[args.data.get_times() == np.max(
        args.data.get_times())]
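    # integrate the final-timepoint cells backwards through the learned flow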
    integrate_backwards(end_time_data,
                        model,
                        args.save,
                        ntimes=100,
                        device=device)
    exit()  # early exit: the evaluation code below runs only if this call is removed
    losses_list = []
    if args.dataset == 'CHAFFER':  # Do timepoint adjustment
        print('adjusting_timepoints')
        lt = args.leaveout_timepoint
        if lt == 1:
            factor = 0.95  # overrides the earlier value 0.6799872494335812
        elif lt == 2:
            factor = 0.01  # overrides the earlier value 0.2905983814032348
        else:
            raise RuntimeError('Unknown timepoint %d' %
                               args.leaveout_timepoint)
        args.int_tps[lt] = (
            1 - factor) * args.int_tps[lt - 1] + factor * args.int_tps[lt + 1]
    losses = eval_utils.evaluate_kantorovich_v2(device, args, model)
    losses_list.append(losses)
    print(np.array(losses_list))
    np.save(os.path.join(args.save, 'emd_list'), np.array(losses_list))