Example #1
def get_ckpt_model_and_data(args):
    # Load checkpoint.
    checkpt = torch.load(args.checkpt,
                         map_location=lambda storage, loc: storage)
    ckpt_args = checkpt['args']
    state_dict = checkpt['state_dict']

    # Construct model and restore checkpoint.
    regularization_fns, regularization_coeffs = create_regularization_fns(
        ckpt_args)
    model = build_model_tabular(ckpt_args, 2,
                                regularization_fns).to(device)
    if ckpt_args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(ckpt_args, model)

    model.load_state_dict(state_dict)
    model.to(device)

    print(model)
    print("Number of trainable parameters: {}".format(
        count_parameters(model)))

    # Load samples from the dataset.
    data_samples = toy_data.inf_train_gen(ckpt_args.data, batch_size=2000)

    return model, data_samples
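A hypothetical call, assuming device is defined globally (as the function body expects) and that the checkpoint file stores the 'args' and 'state_dict' keys read above:

import torch
from argparse import Namespace

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hypothetical checkpoint path; the file must contain 'args' and 'state_dict'.
args = Namespace(checkpt="checkpt.pth")
model, data_samples = get_ckpt_model_and_data(args)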
Example #2
def gen_model(scale=10, fraction=0.5):
    # Rebuild the normalizing-flow model from a previous fit.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args = pkl.load(open('args.pkl', 'rb'))
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, 5,
                                regularization_fns).to(device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)
    model.load_state_dict(torch.load('model_10000.pt'))
    return model
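A minimal sampling sketch for the restored flow, assuming the FFJORD-style module returned by build_model_tabular supports reverse evaluation via model(z, reverse=True):

import torch

model = gen_model()
model.eval()
device = next(model.parameters()).device

with torch.no_grad():
    # Draw base-distribution noise (5-D, matching the model above) and
    # push it backwards through the flow to obtain data-space samples.
    z = torch.randn(1000, 5, device=device)
    x = model(z, reverse=True)
print(x.shape)  # expected: torch.Size([1000, 5])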
Example #3
def create_model(args, data_shape):
    hidden_dims = tuple(map(int, args.dims.split(",")))

    model = odenvp.ODENVP(
        (BATCH_SIZE, *data_shape),
        n_blocks=args.num_blocks,
        intermediate_dims=hidden_dims,
        nonlinearity=args.nonlinearity,
        alpha=args.alpha,
        cnf_kwargs={
            "T": args.time_length,
            "train_T": args.train_T
        },
    )
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)
    return model
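create_model parses args.dims, a comma-separated string of hidden widths; a quick illustration of that parsing step:

dims = "64,64,64"
hidden_dims = tuple(map(int, dims.split(",")))
print(hidden_dims)  # (64, 64, 64)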
Example #4
File: train_toy.py  Project: jwubz123/5470-
def compute_loss(args, model, batch_size=None):
    if batch_size is None:
        batch_size = args.batch_size

    x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    x = torch.from_numpy(x).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)
    z, change = model(x, zero)

    logpx = standard_normal_logprob(z).sum(1, keepdim=True) - change
    loss = -torch.mean(logpx)
    return loss


if __name__ == '__main__':

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, 2, regularization_fns).to(device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
    best_loss = float('inf')
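The excerpt stops just before the training loop; a minimal sketch of one optimization step built on compute_loss (a hypothetical completion, assuming an args.niters field alongside the meters set up above):

    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()
        loss = compute_loss(args, model)
        loss_meter.update(loss.item())
        loss.backward()
        optimizer.step()

        time_meter.update(time.time() - end)
        end = time.time()

        if loss.item() < best_loss:
            best_loss = loss.item()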
Example #5

if __name__ == "__main__":

    # get device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_set, test_loader, data_shape = get_dataset(args)

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns)

    if args.spectral_norm:
        add_spectral_norm(model, logger)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(
        count_parameters(model)))

    # optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)

    # restore parameters
    if args.resume is not None:
        checkpt = torch.load(args.resume,
                             map_location=lambda storage, loc: storage)
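The example is truncated mid-restore; loading typically continues along these lines (a sketch, assuming FFJORD-style checkpoints with 'state_dict' and an optional 'optim_state_dict' entry):

        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt:
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Move optimizer state onto the active device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)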
Example #6
def main():
    global best_acc

    if not os.path.isdir(args.out):
        mkdir_p(args.out)

    # Data
    print('==> Preparing cifar10')
    transform_train = transforms.Compose([
        dataset.RandomPadandCrop(32),
        dataset.RandomFlip(),
        dataset.ToTensor(),
    ])

    transform_val = transforms.Compose([
        dataset.ToTensor(),
    ])

    train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10(
        '/home/fengchan/stor/dataset/original-data/cifar10',
        args.n_labeled,
        transform_train=transform_train,
        transform_val=transform_val)
    labeled_trainloader = data.DataLoader(train_labeled_set,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=0,
                                          drop_last=True)
    unlabeled_trainloader = data.DataLoader(train_unlabeled_set,
                                            batch_size=args.batch_size,
                                            shuffle=True,
                                            num_workers=0,
                                            drop_last=True)
    val_loader = data.DataLoader(val_set,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=0)
    test_loader = data.DataLoader(test_set,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0)

    # Model
    print("==> creating WRN-28-2")

    def create_model(ema=False):
        model = models.WideResNet(num_classes=num_classes)
        model = model.cuda()

        if ema:
            for param in model.parameters():
                param.detach_()

        return model

    data_shape = [3, 32, 32]

    regularization_fns, regularization_coeffs = create_regularization_fns(args)

    def create_cnf():
        # generate cnf
        # cnf = create_cnf_model_1(args, data_shape, regularization_fns=None)
        # cnf = create_cnf_model(args, data_shape, regularization_fns=regularization_fns)
        cnf = create_nf_model(args, data_shape, regularization_fns=None)
        cnf = cnf.cuda() if use_cuda else cnf
        return cnf

    model = create_model()
    ema_model = create_model(ema=True)
    cnf = create_cnf()

    if args.spectral_norm:
        add_spectral_norm(cnf, logger)
        set_cnf_options(args, cnf)

    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    train_criterion = SemiLoss()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    #CNF
    cnf_optimizer = optim.Adam(cnf.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)

    ema_optimizer = WeightEMA(model, ema_model, alpha=args.ema_decay)
    start_epoch = 0

    # Resume
    #generate prior
    means = generate_gaussian_means(num_classes, data_shape, seed=num_classes)
    title = 'noisy-cifar-10'
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint directory found!'
        args.out = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        ema_model.load_state_dict(checkpoint['ema_state_dict'])
        cnf.load_state_dict(checkpoint['cnf_state_dict'])
        means = checkpoint['means']
        cnf_optimizer.load_state_dict(checkpoint['cnf_optimizer'])
        optimizer.load_state_dict(checkpoint['optimizer'])

        logger = Logger(os.path.join(args.out, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        logger = Logger(os.path.join(args.out, 'log.txt'), title=title)
        logger.set_names([
            'Train Loss', 'Train Loss X', 'Train Loss U', 'Train loss NLL X',
            'Train loss NLL U', 'Train loss mixed X', 'Valid Loss',
            'Valid Acc.', 'Test Loss', 'Test Acc.'
        ])

    means = means.cuda() if use_cuda else means
    prior = SSLGaussMixture(means, device='cuda' if use_cuda else 'cpu')

    writer = SummaryWriter(args.out)
    step = 0
    test_accs = []
    # Train and val
    for epoch in range(start_epoch, args.epochs):

        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, state['lr']))

        train_loss, train_loss_x, train_loss_u, train_loss_nll_x, train_loss_nll_u, train_loss_mixed_x = train(
            labeled_trainloader, unlabeled_trainloader, model, cnf, prior,
            cnf_optimizer, optimizer, ema_optimizer, train_criterion, epoch,
            use_cuda)
        _, train_acc = validate(labeled_trainloader,
                                ema_model,
                                criterion,
                                epoch,
                                use_cuda,
                                mode='Train Stats')
        val_loss, val_acc = validate(val_loader,
                                     ema_model,
                                     criterion,
                                     epoch,
                                     use_cuda,
                                     mode='Valid Stats')
        test_loss, test_acc = validate(test_loader,
                                       ema_model,
                                       criterion,
                                       epoch,
                                       use_cuda,
                                       mode='Test Stats ')

        step = args.train_iteration * (epoch + 1)

        writer.add_scalar('losses/train_loss', train_loss, step)
        writer.add_scalar('losses/train_loss_nll_x', train_loss_nll_x, step)
        writer.add_scalar('losses/train_loss_nll_u', train_loss_nll_u, step)
        writer.add_scalar('losses/train_loss_mixed_x', train_loss_mixed_x,
                          step)
        writer.add_scalar('losses/valid_loss', val_loss, step)
        writer.add_scalar('losses/test_loss', test_loss, step)

        writer.add_scalar('accuracy/train_acc', train_acc, step)
        writer.add_scalar('accuracy/val_acc', val_acc, step)
        writer.add_scalar('accuracy/test_acc', test_acc, step)

        # append logger file
        logger.append([
            train_loss, train_loss_x, train_loss_u, train_loss_nll_x,
            train_loss_nll_u, train_loss_mixed_x, val_loss, val_acc, test_loss,
            test_acc
        ])

        # save model
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'cnf_state_dict': cnf.state_dict(),
                'means': means,
                'ema_state_dict': ema_model.state_dict(),
                'acc': val_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'cnf_optimizer': cnf_optimizer.state_dict(),
            }, is_best)
        test_accs.append(test_acc)
    logger.close()
    writer.close()

    print('Best acc:')
    print(best_acc)

    print('Mean acc:')
    print(np.mean(test_accs[-20:]))
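Example #6 depends on a WeightEMA helper for the teacher model; a minimal sketch of such an exponential-moving-average updater (illustrative of the usual MixMatch-style pattern, not necessarily this project's exact class):

class WeightEMA(object):
    """Keeps ema_model as an exponential moving average of model."""

    def __init__(self, model, ema_model, alpha=0.999):
        self.params = list(model.parameters())
        self.ema_params = list(ema_model.parameters())
        self.alpha = alpha

    def step(self):
        # ema <- alpha * ema + (1 - alpha) * current
        for param, ema_param in zip(self.params, self.ema_params):
            ema_param.data.mul_(self.alpha)
            ema_param.data.add_(param.data * (1.0 - self.alpha))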
Example #7
def main(args):
    # logger
    print(args.no_display_loss)
    utils.makedirs(args.save)
    logger = utils.get_logger(
        logpath=os.path.join(args.save, "logs"),
        filepath=os.path.abspath(__file__),
        displaying=not args.no_display_loss,
    )

    if args.layer_type == "blend":
        logger.info("!! Setting time_scale from None to 1.0 for Blend layers.")
        args.time_scale = 1.0

    logger.info(args)

    device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu"
    )
    if args.use_cpu:
        device = torch.device("cpu")

    args.data = dataset.SCData.factory(args.dataset, args.max_dim)

    args.timepoints = args.data.get_unique_times()
    # Use maximum timepoint to establish integration_times
    # as some timepoints may be left out for validation etc.
    args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, args.data.get_shape()[0], regularization_fns).to(
        device
    )
    growth_model = None
    if args.use_growth:
        if args.leaveout_timepoint == -1:
            growth_model_path = "../data/externel/growth_model_v2.ckpt"
        elif args.leaveout_timepoint in [1, 2, 3]:
            assert args.max_dim == 5
            growth_model_path = (
                "../data/growth/model_%d" % args.leaveout_timepoint
            )
        else:
            raise ValueError(
                "Cannot use growth with timepoint %d"
                % args.leaveout_timepoint
            )
        growth_model = torch.load(growth_model_path, map_location=device)

    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    if args.test:
        state_dict = torch.load(args.save + "/checkpt.pth", map_location=device)
        model.load_state_dict(state_dict["state_dict"])
        # Optionally restore the growth model here from
        # state_dict["growth_state_dict"] when it is present.
        # TODO: can we load the arguments from the save?
    else:
        logger.info(model)
        n_param = count_parameters(model)
        logger.info("Number of trainable parameters: {}".format(n_param))

        train(
            device,
            args,
            model,
            growth_model,
            regularization_coeffs,
            regularization_fns,
            logger,
        )

    if args.data.data.shape[1] == 2:
        plot_output(device, args, model)
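A quick numeric illustration of the integration-times computation that Examples #7 and #8 share:

import numpy as np

timepoints = np.array([0, 1, 2, 3])
time_scale = 0.5
# Each observed timepoint t maps to integration time (t + 1) * time_scale.
int_tps = (np.arange(max(timepoints) + 1) + 1.0) * time_scale
print(int_tps)  # [0.5 1.  1.5 2. ]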
Example #8
def main(args):
    device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    if args.use_cpu:
        device = torch.device("cpu")

    data = dataset.SCData.factory(args.dataset, args)

    args.timepoints = data.get_unique_times()

    # Use maximum timepoint to establish integration_times
    # as some timepoints may be left out for validation etc.
    args.int_tps = (np.arange(max(args.timepoints) + 1) +
                    1.0) * args.time_scale

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args,
                                data.get_shape()[0],
                                regularization_fns).to(device)
    if args.use_growth:
        growth_model_path = data.get_growth_net_path()
        #growth_model_path = "/home/atong/TrajectoryNet/data/externel/growth_model_v2.ckpt"
        growth_model = torch.load(growth_model_path, map_location=device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    state_dict = torch.load(args.save + "/checkpt.pth", map_location=device)
    model.load_state_dict(state_dict["state_dict"])


    args.data = data
    args.timepoints = args.data.get_unique_times()
    args.int_tps = (np.arange(max(args.timepoints) + 1) +
                    1.0) * args.time_scale

    print('integrating backwards')
    end_time_data = data.get_data()[args.data.get_times() == np.max(
        args.data.get_times())]
    integrate_backwards(end_time_data,
                        model,
                        args.save,
                        ntimes=100,
                        device=device)
    exit()
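    # NOTE: exit() above ends the run here; the remainder of this function
    # is unreachable as written and kept only for reference.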
    losses_list = []
    if args.dataset == 'CHAFFER':  # Do timepoint adjustment
        print('adjusting_timepoints')
        lt = args.leaveout_timepoint
        if lt == 1:
            factor = 0.95  # overrides an earlier estimate of 0.6799872494335812
        elif lt == 2:
            factor = 0.01  # overrides an earlier estimate of 0.2905983814032348
        else:
            raise RuntimeError('Unknown timepoint %d' %
                               args.leaveout_timepoint)
        args.int_tps[lt] = (
            1 - factor) * args.int_tps[lt - 1] + factor * args.int_tps[lt + 1]
    losses = eval_utils.evaluate_kantorovich_v2(device, args, model)
    losses_list.append(losses)
    print(np.array(losses_list))
    np.save(os.path.join(args.save, 'emd_list'), np.array(losses_list))
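For reference, the leave-out adjustment above linearly interpolates the held-out integration time between its neighbors; a worked example with made-up numbers:

import numpy as np

int_tps = np.array([1.0, 2.0, 3.0])
lt, factor = 1, 0.95
int_tps[lt] = (1 - factor) * int_tps[lt - 1] + factor * int_tps[lt + 1]
print(int_tps)  # [1.  2.9 3. ]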