コード例 #1
0
def main(args):
    """Train a Glow model on the dataset selected by *args*.

    Expects args to provide: seed, generate_samples, dataset, batch_size,
    amt_flow_steps, amt_levels, norm_method, num_epochs, resume, lr and
    n_samples.  Resumable checkpoints live in ``new_checkpoints/``.
    """
    # we're probably only be using 1 GPU, so this should be fine
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"running on {device}")

    # Seed every RNG source (python, numpy, torch CPU + all GPUs) for
    # reproducible runs.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    global best_loss
    if args.generate_samples:
        print("generating samples")

    # load data (example for CIFAR-10 training)
    train_set, test_set = get_dataloader(args.dataset, args.batch_size)

    input_channels = channels_from_dataset(args.dataset)
    print(f"amount of  input channels: {input_channels}")

    # instantiate model
    net = Glow(in_channels=input_channels,
               depth=args.amt_flow_steps, levels=args.amt_levels, use_normalization=args.norm_method)

    # code for rosalinty model
    # net = RosGlow(input_channels, args.amt_flow_steps, args.amt_levels)

    net = net.to(device)

    print(f"training for {args.num_epochs} epochs.")

    start_epoch = 0
    if args.resume:
        # Checkpoints are written to new_checkpoints/; keep the message in
        # sync with the actual load path below.
        checkpoint_path = f"new_checkpoints/best_{args.dataset.lower()}.pth.tar"
        print(f"resuming from checkpoint found in {checkpoint_path}.")
        # Raise an explicit error if no checkpoint directory is found
        # (an assert would be stripped when running with -O).
        if not os.path.isdir("new_checkpoints"):
            raise FileNotFoundError(
                "checkpoint directory 'new_checkpoints' not found")
        checkpoint = torch.load(checkpoint_path)
        net.load_state_dict(checkpoint["model"])
        # best_loss was already declared global at the top of this function.
        best_loss = checkpoint["test_loss"]
        start_epoch = checkpoint["epoch"]

    loss_function = FlowNLL().to(device)
    optimizer = optim.Adam(net.parameters(), lr=float(args.lr))
    # scheduler found in code, no mention in paper
    # scheduler = sched.LambdaLR(
    #     optimizer, lambda s: min(1., s / args.warmup_iters))

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        print(f"training epoch {epoch}")
        train(net, train_set, device, optimizer, loss_function, epoch)
        # Evaluate (and optionally sample) every 10 epochs.
        if epoch % 10 == 0:
            print(f"testing epoch {epoch}")
            test(net, test_set, device, loss_function, epoch, args.generate_samples,
                 args.amt_levels, args.dataset, args.n_samples)
コード例 #2
0
def main(args, kwargs):
    """Run Glow reconstruction experiments on an MNIST test batch.

    Loads a trained model from ``args.output_folder``/``args.model_name``
    (hyper-parameters from the adjacent ``hparams.json``), compares three
    reconstruction variants of unconditional samples, reconstructs a real
    data batch, and plots a reconstruction-evolution figure.

    ``kwargs`` is accepted for interface compatibility but unused here.
    """
    output_folder = args.output_folder
    model_name = args.model_name

    # Hyper-parameters are stored next to the checkpoint.
    with open(os.path.join(output_folder, 'hparams.json')) as json_file:
        hparams = json.load(json_file)

    image_shape, num_classes, _, test_mnist = get_MNIST(False, hparams['dataroot'], hparams['download'])
    test_loader = data.DataLoader(test_mnist, batch_size=32,
                                  shuffle=False, num_workers=6,
                                  drop_last=False)
    # Grab one fixed batch for the data-reconstruction experiment.
    x, y = next(iter(test_loader))
    # NOTE(review): `device` is assumed to be a module-level global — confirm.
    x = x.to(device)

    # Older checkpoints may lack the 'logittransform'/'sn' flags; default
    # them to False.
    model = Glow(image_shape, hparams['hidden_channels'], hparams['K'], hparams['L'], hparams['actnorm_scale'],
                 hparams['flow_permutation'], hparams['flow_coupling'], hparams['LU_decomposed'], num_classes,
                 hparams['learn_top'], hparams['y_condition'],
                 hparams.get('logittransform', False), hparams.get('sn', False))

    model.load_state_dict(torch.load(os.path.join(output_folder, model_name)))
    model.set_actnorm_init()

    model = model.to(device)
    model = model.eval()

    with torch.no_grad():
        # Unconditional samples; the model caches its split z's in _last_z.
        images = model(y_onehot=None, temperature=1, batch_size=32, reverse=True).cpu()
        better_dup_images = model(y_onehot=None, temperature=1, z=model._last_z, reverse=True, use_last_split=True).cpu()
        dup_images = model(y_onehot=None, temperature=1, z=model._last_z, reverse=True).cpu()
        # NOTE(review): identical call to dup_images above — presumably meant
        # to differ (e.g. resampled splits); confirm intent.
        worse_dup_images = model(y_onehot=None, temperature=1, z=model._last_z, reverse=True).cpu()

    # Mean (over the batch) of the summed per-image squared error.
    l2_err = torch.pow((images - dup_images).view(images.shape[0], -1), 2).sum(-1).mean()
    better_l2_err = torch.pow((images - better_dup_images).view(images.shape[0], -1), 2).sum(-1).mean()
    worse_l2_err = torch.pow((images - worse_dup_images).view(images.shape[0], -1), 2).sum(-1).mean()
    print(l2_err, better_l2_err, worse_l2_err)
    plot_imgs([images, dup_images, better_dup_images, worse_dup_images], '_recons')

    # Reconstruct real data: encode to z, then decode with the stored splits.
    with torch.no_grad():
        z, nll, y_logits = model(x, None)
        better_dup_images = model(y_onehot=None, temperature=1, z=z, reverse=True, use_last_split=True).cpu()

    plot_imgs([x, better_dup_images], '_data_recons2')

    fpath = os.path.join(output_folder, '_recon_evoluation.png')
    pad = run_recon_evolution(model, x, fpath)
コード例 #3
0
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         fresh):
    """Train a Glow model with pytorch-ignite.

    Wires up: train/test DataLoaders, an Adamax optimizer, a training
    Engine with gradient clipping, an evaluation Engine run after every
    epoch, per-epoch checkpointing of model+optimizer, a progress bar with
    running-average metrics, data-dependent initialization (ActNorm) from
    the first ``n_init_batches`` batches, and optional resume from
    ``saved_model``/``saved_optimizer``.

    NOTE(review): ``warmup_steps`` and ``fresh`` are not used anywhere in
    this body — confirm whether they are consumed elsewhere.
    """

    # Fall back to CPU when CUDA is unavailable or explicitly disabled.
    device = 'cpu' if (not torch.cuda.is_available() or not cuda) else 'cuda:0'

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=eval_batch_size,
                                  shuffle=False,
                                  num_workers=n_workers,
                                  drop_last=False)

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition)

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    def step(engine, batch):
        # One training iteration: forward, loss, backward, clip, step.
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses['total_loss'].backward()

        # Optional element-wise value clipping and/or global norm clipping.
        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        # One evaluation iteration; per-sample losses (reduction='none')
        # so the Loss metric can average them itself.
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    trainer = Engine(step)
    # Keep the two most recent model+optimizer checkpoints, saved each epoch.
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         'glow',
                                         save_interval=1,
                                         n_saved=2,
                                         require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    # The dummy torch.empty target satisfies the Loss metric's (y_pred, y)
    # interface; only the mean of the per-sample losses is used.
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x:
         (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(
             evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(
            trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x:
             (x['nll'], torch.empty(x['nll'].shape[0]))).attach(
                 evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        # Checkpoint files are named ..._<epoch><ext>; recover the epoch
        # number from the filename to resume counters.
        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        # Data-dependent initialization (e.g. ActNorm): run one forward
        # pass over the first n_init_batches batches before training.
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        # Run the evaluator on the test set and print its metrics.
        evaluator.run(test_loader)
        metrics = evaluator.state.metrics

        losses = ', '.join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    # Average per-iteration wall-clock time, reset each epoch.
    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]'
        )
        timer.reset()

    trainer.run(train_loader, epochs)
コード例 #4
0
    except (KeyboardInterrupt, SystemExit):
        check_save(model_single, optimizer, args, z_sample, i, save=True)
        raise


if __name__ == '__main__':
    # Script entry point: parse CLI args, build the Glow model, optionally
    # resume model+optimizer state from a checkpoint, then train.
    # NOTE(review): `parser`, `device`, `Glow` and `train` are assumed to be
    # defined earlier in this file — confirm.
    args = parser.parse_args()
    if len(args.load_path) > 0:
        # Checkpoint files end in ..._<iteration>.<3-char ext>; strip the
        # extension and recover the starting iteration from the filename.
        args.startiter = int(args.load_path[:-3].split('_')[-1])
    print(args)

    # Single-channel Glow; built on CPU first so the checkpoint can be
    # loaded with map_location before moving to the target device.
    model_single = Glow(
        1, args.n_flow, args.n_block, affine=args.affine, conv_lu=not args.no_lu
    ).cpu()
    if len(args.load_path) > 0:
        # map_location keeps tensors on CPU regardless of where they were saved.
        model_single.load_state_dict(torch.load(args.load_path,  map_location=lambda storage, loc: storage))

        model_single.initialize()
        gc.collect()
        torch.cuda.empty_cache()
    model = model_single
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    if len(args.load_path) > 0:
        # The optimizer state lives as optimizer.pth next to the checkpoint.
        optim_path = '/'.join(args.load_path.split('/')[:-1])
        optimizer.load_state_dict(torch.load(os.path.join(optim_path, 'optimizer.pth'),  map_location=lambda storage, loc: storage))
        gc.collect()
        torch.cuda.empty_cache()

    train(args, model, optimizer)
コード例 #5
0
        hparams = json.load(json_file)

    ds = check_dataset(args.dataset, args.dataroot, True, args.download)
    ds2 = check_dataset(args.dataset2, args.dataroot, True, args.download)
    image_shape, num_classes, train_dataset, test_dataset = ds
    image_shape2, num_classes2, train_dataset_2, test_dataset_2 = ds2

    model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
                 hparams['L'], hparams['actnorm_scale'],
                 hparams['flow_permutation'], hparams['flow_coupling'],
                 hparams['LU_decomposed'], num_classes, hparams['learn_top'],
                 hparams['y_condition'])

    dic = torch.load(checkpoint_path)
    if 'model' in dic.keys():
        model.load_state_dict(dic["model"])
    else:
        model.load_state_dict(dic)
    model.set_actnorm_init()

    model = model.to(device)
    model = model.eval()

    if args.optim_type == "ADAM":
        optim_default = partial(optim.Adam, lr=args.lr_test)
    elif args.optim_type == "SGD":
        optim_default = partial(optim.SGD,
                                lr=args.lr_test,
                                momentum=args.momentum)

    if args.limited_data is not None:
コード例 #6
0
ファイル: train.py プロジェクト: HugoSenetaire/Glow-PyTorch-1
def main(
    dataset,
    dataset2,
    dataroot,
    download,
    augment,
    batch_size,
    eval_batch_size,
    nlls_batch_size,
    epochs,
    nb_step,
    saved_model,
    seed,
    hidden_channels,
    K,
    L,
    actnorm_scale,
    flow_permutation,
    flow_coupling,
    LU_decomposed,
    learn_top,
    y_condition,
    y_weight,
    max_grad_clip,
    max_grad_norm,
    lr,
    lr_test,
    n_workers,
    cuda,
    n_init_batches,
    output_dir,
    saved_optimizer,
    warmup,
    every_epoch,
):
    """Train a Glow model with pytorch-ignite, with a second dataset held
    for likelihood comparison.

    In addition to the standard training loop (DataLoaders, Adamax, LR
    warmup schedule, gradient clipping, per-epoch evaluation and
    checkpointing, ActNorm data-dependent init, optional resume), this
    collects ``nlls_batch_size`` samples from each of the two test sets
    into ``data1``/``data2``.

    NOTE(review): data1/data2 (and nb_step, lr_test, every_epoch) are only
    consumed by the commented-out ``eval_likelihood`` handler at the bottom
    — confirm whether that handler should be re-enabled.
    """

    # Fall back to CPU when CUDA is unavailable or explicitly disabled.
    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    ds2 = check_dataset(dataset2, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds
    image_shape2, num_classes2, train_dataset_2, test_dataset_2 = ds2

    # Both datasets must share an image shape to go through the same model.
    assert(image_shape == image_shape2)
    data1 = []
    data2 = []
    for k in range(nlls_batch_size):
        dataaux, targetaux = test_dataset[k]
        data1.append(dataaux)
        dataaux, targetaux = test_dataset_2[k]
        data2.append(dataaux)


    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
        drop_last=True,
    )
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=n_workers,
        drop_last=False,
    )

    model = Glow(
        image_shape,
        hidden_channels,
        K,
        L,
        actnorm_scale,
        flow_permutation,
        flow_coupling,
        LU_decomposed,
        num_classes,
        learn_top,
        y_condition,
    )

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    # Linear LR warmup over the first `warmup` epochs, then constant.
    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    def step(engine, batch):
        # One training iteration: forward, loss, backward, clip, step.
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses["total_loss"].backward()

        # Optional element-wise value clipping and/or global norm clipping.
        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        # One evaluation iteration; per-sample losses (reduction="none")
        # so the Loss metric can average them itself.
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(
                    nll, y_logits, y_weight, y, multi_class, reduction="none"
                )
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction="none")

        return losses

    trainer = Engine(step)
    # Keep the two most recent model+optimizer checkpoints.
    checkpoint_handler = ModelCheckpoint(
        output_dir, "glow", n_saved=2, require_empty=False
    )

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {"model": model, "optimizer": optimizer},
    )

    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(
        trainer, "total_loss"
    )

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    # The dummy torch.empty target satisfies the Loss metric's (y_pred, y)
    # interface; only the mean of the per-sample losses is used.
    Loss(
        lambda x, y: torch.mean(x),
        output_transform=lambda x: (
            x["total_loss"],
            torch.empty(x["total_loss"].shape[0]),
        ),
    ).attach(evaluator, "total_loss")

    if y_condition:
        monitoring_metrics.extend(["nll"])
        RunningAverage(output_transform=lambda x: x["nll"]).attach(trainer, "nll")

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(
            lambda x, y: torch.mean(x),
            output_transform=lambda x: (x["nll"], torch.empty(x["nll"].shape[0])),
        ).attach(evaluator, "nll")

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        # Checkpoints here are dicts with 'model'/'opt' entries.
        model.load_state_dict(torch.load(saved_model)['model'])
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer)['opt'])

        file_name, ext = os.path.splitext(saved_model)
        # NOTE(review): dividing by 1e3 yields a float epoch (filename suffix
        # presumably counts iterations, 1000 per epoch) — confirm, since
        # engine.state.epoch is normally an int.
        resume_epoch = int(file_name.split("_")[-1])/1e3

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        # Data-dependent initialization (e.g. ActNorm): run one forward
        # pass over the first n_init_batches batches before training.
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            print(train_loader)
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        # Evaluate on the test set, advance the warmup schedule, and print.
        evaluator.run(test_loader)

        scheduler.step()
        metrics = evaluator.state.metrics

        losses = ", ".join([f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f"Validation Results - Epoch: {engine.state.epoch} {losses}")

    # Average per-iteration wall-clock time, reset each epoch.
    timer = Timer(average=True)
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]"
        )
        timer.reset()

    # @trainer.on(Events.EPOCH_COMPLETED)
    # def eval_likelihood(engine):
    #     global_nlls(output_dir, engine.state.epoch, data1, data2, model, dataset1_name = dataset, dataset2_name = dataset2, nb_step = nb_step, every_epoch = every_epoch, optim_default = partial(optim.SGD, lr=1e-5, momentum = 0.))


    trainer.run(train_loader, epochs)
コード例 #7
0
def main(args):
    """Generate images from trained (possibly class-/domain-conditional)
    Glow checkpoints, one run per seed in ``args.seed``.

    For each seed, loads the 'best' checkpoint from
    ``args.experiment_folder/<seed>/`` plus its hparams.json, then samples
    according to which conditioning mode the hparams enable
    (y_condition / d_condition / yd_condition / unconditional) and writes
    per-sample PNGs and an image grid.
    """
    seed_list = [int(item) for item in args.seed.split(',')]

    for seed in seed_list:

        device = torch.device("cuda")

        experiment_folder = args.experiment_folder + '/' + str(seed) + '/'
        print(experiment_folder)
        #model_name = 'glow_checkpoint_'+ str(args.chk)+'.pth'
        # Pick the checkpoint file whose name contains 'best'.
        # NOTE(review): model_name stays unbound (NameError below) if no such
        # file exists, and the last match wins if several do — confirm.
        for thing in os.listdir(experiment_folder):
            if 'best' in thing:
                model_name = thing
        print(model_name)

        # Seed every RNG source and force deterministic cuDNN for
        # reproducible sampling.
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch.random.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

        with open(experiment_folder + 'hparams.json') as json_file:
            hparams = json.load(json_file)

        # Class/domain counts per conditioning mode (malaria cells:
        # 2 classes = Uninfected/Parasitized, 10 domains).
        image_shape = (32, 32, 3)
        if hparams['y_condition']:
            num_classes = 2
            num_domains = 0
        elif hparams['d_condition']:
            num_classes = 10
            num_domains = 0
        elif hparams['yd_condition']:
            num_classes = 2
            num_domains = 10
        else:
            num_classes = 2
            num_domains = 0

        model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
                     hparams['L'], hparams['actnorm_scale'],
                     hparams['flow_permutation'], hparams['flow_coupling'],
                     hparams['LU_decomposed'], num_classes, num_domains,
                     hparams['learn_top'], hparams['y_condition'],
                     hparams['extra_condition'], hparams['sp_condition'],
                     hparams['d_condition'], hparams['yd_condition'])
        print('loading model')
        model.load_state_dict(torch.load(experiment_folder + model_name))
        model.set_actnorm_init()

        model = model.to(device)

        model = model.eval()

        if hparams['y_condition']:
            print('y_condition')

            def sample(model, temp=args.temperature):
                # Sample 1000 images per class: interleave one-hot rows so
                # even rows are class 0 and odd rows are class 1.
                with torch.no_grad():
                    if hparams['y_condition']:
                        print("extra", hparams['extra_condition'])
                        y = torch.eye(num_classes)
                        y = torch.cat(1000 * [y])
                        print(y.size())
                        y_0 = y[::2, :].to(
                            device)  # number hardcoded in model for now
                        y_1 = y[1::2, :].to(device)
                        print(y_0.size())
                        print(y_0)
                        print(y_1)
                        print(y_1.size())
                        images0 = model(z=None,
                                        y_onehot=y_0,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                        images1 = model(z=None,
                                        y_onehot=y_1,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                return images0, images1

            images0, images1 = sample(model)

            # Save every sample individually, then a 64-image grid per class.
            os.makedirs(experiment_folder + 'generations/Uninfected',
                        exist_ok=True)
            os.makedirs(experiment_folder + 'generations/Parasitized',
                        exist_ok=True)
            for i in range(images0.size(0)):
                torchvision.utils.save_image(
                    images0[i, :, :, :], experiment_folder +
                    'generations/Uninfected/sample_{}.png'.format(i))
                torchvision.utils.save_image(
                    images1[i, :, :, :], experiment_folder +
                    'generations/Parasitized/sample_{}.png'.format(i))
            images_concat0 = torchvision.utils.make_grid(images0[:64, :, :, :],
                                                         nrow=int(64**0.5),
                                                         padding=2,
                                                         pad_value=255)
            torchvision.utils.save_image(images_concat0,
                                         experiment_folder + '/uninfected.png')
            images_concat1 = torchvision.utils.make_grid(images1[:64, :, :, :],
                                                         nrow=int(64**0.5),
                                                         padding=2,
                                                         pad_value=255)
            torchvision.utils.save_image(
                images_concat1, experiment_folder + '/parasitized.png')

        elif hparams['d_condition']:
            print('d_cond')

            def sample_d(model, idx, batch_size=1000, temp=args.temperature):
                # Sample a batch conditioned on domain `idx` via a one-hot
                # domain vector.
                with torch.no_grad():
                    if hparams['d_condition']:

                        # NOTE(review): y_0.to(device) is not assigned — it is
                        # a no-op discard, harmless only because the tensor is
                        # already created on 'cuda:0'; confirm.
                        y_0 = torch.zeros([batch_size, 10], device='cuda:0')
                        y_0[:, idx] = torch.ones(batch_size)
                        y_0.to(device)
                        print(y_0)

                        # y_1 = torch.zeros([batch_size, 201], device='cuda:0')
                        # y_1[:, 157] = torch.ones(batch_size)
                        # y_1.to(device)
                        # y = torch.eye(num_classes)
                        # y = torch.cat(1000 * [y])
                        # print(y.size())
                        # y_0 = y[::2, :].to(device)  # number hardcoded in model for now
                        # y_1 = y[1::2, :].to(device)
                        # print(y_0.size())
                        # print(y_0)
                        # print(y_1)
                        # print(y_1.size())

                        images0 = model(z=None,
                                        y_onehot=y_0,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                        # images1 = model(z=None, y_onehot=y_1, temperature=1.0, reverse=True, batch_size=1000)
                return images0

            # One sampling run per blood-smear slide domain.
            for idx, dom in enumerate(["C116P77ThinF", "C132P93ThinF", "C137P98ThinF", "C180P141NThinF", "C182P143NThinF", \
                             "C184P145ThinF", "C39P4thinF", 'C59P20thinF', "C68P29N", "C99P60ThinF"]):

                images0 = sample_d(model, idx)

                os.makedirs(experiment_folder + 'generations/' + dom +
                            '/Uninfected/',
                            exist_ok=True)
                os.makedirs(experiment_folder + 'generations/' + dom +
                            '/Parasitized/',
                            exist_ok=True)
                # os.makedirs(experiment_folder + 'generations/C59P20thinF/Uninfected/', exist_ok=True)
                # os.makedirs(experiment_folder + 'generations/C59P20thinF/Parasitized/', exist_ok=True)
                for i in range(images0.size(0)):
                    torchvision.utils.save_image(
                        images0[i, :, :, :],
                        experiment_folder + 'generations/' + dom +
                        '/Uninfected/sample_{}.png'.format(i))
                    #torchvision.utils.save_image(images1[i, :, :, :], experiment_folder + 'generations/C59P20thinF/Parasitized/sample_{}.png'.format(i))
                images_concat0 = torchvision.utils.make_grid(
                    images0[:25, :, :, :],
                    nrow=int(25**0.5),
                    padding=2,
                    pad_value=255)
                torchvision.utils.save_image(images_concat0,
                                             experiment_folder + dom + '.png')
                # images_concat1 = torchvision.utils.make_grid(images1[:64,:,:,:], nrow=int(64 ** 0.5), padding=2, pad_value=255)
                # torchvision.utils.save_image(images_concat1, experiment_folder + 'C59P20thinF.png')

        elif hparams['yd_condition']:

            def sample_YD(model, idx, batch_size=1000, temp=args.temperature):
                # Joint class+domain conditioning: first 2 slots encode the
                # class, slots 2..11 encode the domain.
                with torch.no_grad():
                    if hparams['yd_condition']:
                        y_0 = torch.zeros([batch_size, 12], device='cuda:0')
                        y_0[:, 0] = torch.ones(batch_size)
                        y_0[:, idx + 2] = torch.ones(batch_size)
                        y_0.to(device)
                        print(y_0)

                        y_1 = torch.zeros([batch_size, 12], device='cuda:0')
                        y_1[:, 1] = torch.ones(batch_size)
                        y_1[:, idx + 2] = torch.ones(batch_size)
                        y_1.to(device)
                        print(y_1)

                        images0 = model(z=None,
                                        y_onehot=y_0,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                        images1 = model(z=None,
                                        y_onehot=y_1,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                return images0, images1

            def sample_DD(model, idx, batch_size=1000, temp=args.temperature):
                # Alternate 20-slot encoding (defined but not called below).
                with torch.no_grad():
                    if hparams['yd_condition']:
                        y_1 = torch.zeros([batch_size, 20], device='cuda:0')
                        y_1[:, idx] = torch.ones(batch_size)
                        y_1.to(device)
                        print(y_1)

                        y_0 = torch.zeros([batch_size, 20], device='cuda:0')
                        y_0[:, idx + 10] = torch.ones(batch_size)
                        y_0.to(device)
                        print(y_0)

                        images0 = model(z=None,
                                        y_onehot=y_0,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                        images1 = model(z=None,
                                        y_onehot=y_1,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                return images0, images1

            for idx, dom in enumerate(
                    ["C116P77ThinF", "C132P93ThinF", "C137P98ThinF", "C180P141NThinF", "C182P143NThinF", \
                     "C184P145ThinF", "C39P4thinF", 'C59P20thinF', "C68P29N", "C99P60ThinF"]):

                images0, images1 = sample_YD(model, idx)

                os.makedirs(experiment_folder + 'generations/' + dom +
                            '/Uninfected/',
                            exist_ok=True)
                os.makedirs(experiment_folder + 'generations/' + dom +
                            '/Parasitized/',
                            exist_ok=True)
                for i in range(images0.size(0)):
                    torchvision.utils.save_image(
                        images0[i, :, :, :],
                        experiment_folder + 'generations/' + dom +
                        '/Uninfected/sample_{}.png'.format(i))
                    torchvision.utils.save_image(
                        images1[i, :, :, :],
                        experiment_folder + 'generations/' + dom +
                        '/Parasitized/sample_{}.png'.format(i))
                images_concat0 = torchvision.utils.make_grid(
                    images0[:64, :, :, :],
                    nrow=int(64**0.5),
                    padding=2,
                    pad_value=255)
                torchvision.utils.save_image(
                    images_concat0, experiment_folder + dom +
                    str(args.temperature) + '_uninfected.png')
                images_concat1 = torchvision.utils.make_grid(
                    images1[:64, :, :, :],
                    nrow=int(64**0.5),
                    padding=2,
                    pad_value=255)
                torchvision.utils.save_image(
                    images_concat1, experiment_folder + dom +
                    str(args.temperature) + '_parasitized.png')

        else:

            def sample(model, temp=args.temperature):
                # Unconditional sampling.
                with torch.no_grad():
                    images = model(z=None,
                                   y_onehot=None,
                                   temperature=temp,
                                   reverse=True,
                                   batch_size=1000)

                return images

            images = sample(model)

            os.makedirs('unconditioned/' + str(seed) + '/generations/' +
                        experiment_folder[:-3],
                        exist_ok=True)
            for i in range(images.size(0)):
                torchvision.utils.save_image(
                    images[i, :, :, :],
                    'unconditioned/' + str(seed) + '/generations/' +
                    experiment_folder[:-2] + 'sample_{}.png'.format(i))

            images_concat = torchvision.utils.make_grid(images[:64, :, :, :],
                                                        nrow=int(64**0.5),
                                                        padding=2,
                                                        pad_value=255)
            torchvision.utils.save_image(
                images_concat, 'unconditioned/' + str(seed) + '/' +
                experiment_folder[:-3] + '.png')
コード例 #8
0
ファイル: sample_mnist.py プロジェクト: wangkua1/Glow-PyTorch
# Load a trained MNIST Glow checkpoint and its hparams, ready for sampling.
# NOTE(review): `output_folder` and `device` are assumed to be defined
# earlier in this file — confirm.
model_name = 'glow_model_1.pth'

with open(os.path.join(output_folder, 'hparams.json')) as json_file:
    hparams = json.load(json_file)

image_shape, num_classes, _, test_mnist = get_MNIST(False, hparams['dataroot'],
                                                    hparams['download'])

# Older checkpoints may predate the 'logittransform' flag; default it off.
model = Glow(
    image_shape, hparams['hidden_channels'], hparams['K'], hparams['L'],
    hparams['actnorm_scale'], hparams['flow_permutation'],
    hparams['flow_coupling'], hparams['LU_decomposed'], num_classes,
    hparams['learn_top'], hparams['y_condition'],
    False if 'logittransform' not in hparams else hparams['logittransform'])

model.load_state_dict(torch.load(os.path.join(output_folder, model_name)))
# Mark ActNorm layers as initialized so loading doesn't re-trigger
# data-dependent init.
model.set_actnorm_init()

model = model.to(device)

model = model.eval()

def sample(model):
    """Draw 32 unconditional Glow samples and return them on the CPU.

    Requires the model to have been trained without class conditioning
    (``hparams['y_condition']`` must be falsy).
    """
    with torch.no_grad():
        # Unconditional generation only: no one-hot labels are supplied.
        assert not hparams['y_condition']
        batch = model(y_onehot=None, temperature=1, reverse=True,
                      batch_size=32)
    return batch.cpu()
コード例 #9
0
# Override the data root recorded in the checkpoint's hparams so both
# datasets are loaded from the local '../mutual-information' directory.
hparams['dataroot'] = '../mutual-information'

# Load the test splits of CIFAR-10 (in-distribution) and SVHN (OOD).
# NOTE(review): image_shape/num_classes are overwritten by the SVHN call —
# presumably both datasets share 32x32x3 / 10 classes, so this is harmless.
image_shape, num_classes, _, test_cifar = get_CIFAR10(hparams['augment'],
                                                      hparams['dataroot'],
                                                      hparams['download'])
image_shape, num_classes, _, test_svhn = get_SVHN(hparams['augment'],
                                                  hparams['dataroot'],
                                                  hparams['download'])

# Rebuild the Glow architecture from the stored hyperparameters.
model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
             hparams['L'], hparams['actnorm_scale'],
             hparams['flow_permutation'], hparams['flow_coupling'],
             hparams['LU_decomposed'], num_classes, hparams['learn_top'],
             hparams['y_condition'])

# NOTE(review): plain string concatenation — assumes output_folder ends with
# a path separator; os.path.join would be safer.
model.load_state_dict(torch.load(output_folder + model_name))
# Prevent re-running ActNorm's data-dependent initialization over the
# loaded statistics.
model.set_actnorm_init()

model = model.to(device)

# Inference only.
model = model.eval()


def sample(model):
    """Build the (optional) conditioning labels for sampling.

    NOTE(review): this function appears truncated by the page scrape — the
    actual call to ``model(...)`` and the return statement are missing, so as
    written it builds ``y`` and returns None.
    """
    with torch.no_grad():
        if hparams['y_condition']:
            # One label per class, tiled to cover the batch, then cut to 32.
            y = torch.eye(num_classes)
            # NOTE(review): Tensor.repeat on a 2-D tensor with a single
            # repeat count raises at runtime; likely intended
            # ``y.repeat(batch_size // num_classes + 1, 1)`` — confirm.
            y = y.repeat(batch_size // num_classes + 1)
            y = y[:32, :].to(device)  # number hardcoded in model for now
        else:
            # Unconditional sampling: no labels.
            y = None
コード例 #10
0
ファイル: gen3d.py プロジェクト: supri-a/RockFlow
def main(args):
    """Generate pseudo-3D volumes from a trained 2-D Glow model.

    Loads a checkpoint from ``results/<args.name>``, then builds a volume by
    linearly interpolating between consecutive latent samples and decoding
    every interpolated latent into one slice. Slices are optionally median
    filtered / thresholded and written out as videos along the three axes.

    Args:
        args: parsed CLI namespace; fields used here include name, model,
            temperature, steps, step_modality, n_modalities, binary_data,
            med_filt, save_binary, save_name, iter, seed (unused).
    """
    # torch.manual_seed(args.seed)

    # Test loading and sampling
    output_folder = os.path.join('results', args.name)

    # Hyperparameters the checkpoint was trained with.
    with open(os.path.join(output_folder, 'hparams.json')) as json_file:
        hparams = json.load(json_file)

    device = "cpu" if not torch.cuda.is_available() else "cuda:0"
    # One generated slice is (patch_size, patch_size, n_modalities);
    # each channel is one imaging modality.
    image_shape = (hparams['patch_size'], hparams['patch_size'],
                   args.n_modalities)
    # Unconditional model: a single dummy class.
    num_classes = 1

    print('Loading model...')
    model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
                 hparams['L'], hparams['actnorm_scale'],
                 hparams['flow_permutation'], hparams['flow_coupling'],
                 hparams['LU_decomposed'], num_classes, hparams['learn_top'],
                 hparams['y_condition'])

    # Checkpoint stores a dict; model weights live under the 'model' key.
    model_chkpt = torch.load(
        os.path.join(output_folder, 'checkpoints', args.model))
    model.load_state_dict(model_chkpt['model'])
    # Keep the loaded ActNorm statistics (skip data-dependent re-init).
    model.set_actnorm_init()
    model = model.to(device)

    # Build images
    model.eval()
    temperature = args.temperature

    if args.steps is None:  # automatically calculate step size if no step size

        fig_dir = os.path.join(output_folder, 'stepnum_results')
        if not os.path.exists(fig_dir):
            os.mkdir(fig_dir)

        print('No step size entered')

        # Create sample of images to estimate chord length
        with torch.no_grad():
            mean, logs = model.prior(None, None)
            z = gaussian_sample(mean, logs, temperature)
            images_raw = model(z=z, temperature=temperature, reverse=True)
        # Guard against numerical blow-ups in the decoded samples, then clamp
        # to the model's expected value range [-0.5, 0.5].
        images_raw[torch.isnan(images_raw)] = 0.5
        images_raw[torch.isinf(images_raw)] = 0.5
        images_raw = torch.clamp(images_raw, -0.5, 0.5)

        # Keep one modality channel and reorder axes for the 2-point
        # correlation code below.
        images_out = np.transpose(
            np.squeeze(images_raw[:, args.step_modality, :, :].cpu().numpy()),
            (1, 0, 2))

        # Threshold images and compute covariances
        if args.binary_data:
            # Data already centered around 0 — split at 0 directly.
            thresh = 0
        else:
            thresh = threshold_otsu(images_out)
        images_bin = np.greater(images_out, thresh)
        x_cov = two_point_correlation(images_bin, 0)
        y_cov = two_point_correlation(images_bin, 1)

        # Compute chord length
        # Average the x- and y-direction correlations, fit a line through the
        # origin to the first N lags, and take |S2(0)/S2'(0)| as the
        # characteristic length — this becomes the interpolation step count.
        # NOTE(review): assumes the fitted slope S20 is nonzero — confirm.
        cov_avg = np.mean(np.mean(np.concatenate((x_cov, y_cov), axis=2),
                                  axis=0),
                          axis=0)
        N = 5
        S20, _ = curve_fit(straight_line_at_origin(cov_avg[0]), range(0, N),
                           cov_avg[0:N])
        l_pore = np.abs(cov_avg[0] / S20)
        steps = int(l_pore)
        print('Calculated step size: {}'.format(steps))

    else:
        print('Using user-entered step size {}...'.format(args.steps))
        steps = args.steps

    # Build desired number of volumes
    for iter_vol in range(args.iter):
        if args.iter == 1:
            stack_dir = os.path.join(output_folder, 'image_stacks',
                                     args.save_name)
            print('Sampling images, saving to {}...'.format(args.save_name))
        else:
            # Multiple volumes: suffix the save name with a zero-padded index.
            stack_dir = os.path.join(
                output_folder, 'image_stacks',
                args.save_name + '_' + str(iter_vol).zfill(3))
            print('Sampling images, saving to {}_'.format(args.save_name) +
                  str(iter_vol).zfill(3) + '...')
        if not os.path.exists(stack_dir):
            os.makedirs(stack_dir)

        with torch.no_grad():
            mean, logs = model.prior(None, None)
            # Interpolation weights 1 -> 0 over `steps` slices, broadcastable
            # against a latent tensor.
            alpha = 1 - torch.reshape(torch.linspace(0, 1, steps=steps),
                                      (-1, 1, 1, 1))
            alpha = alpha.to(device)

            # Enough anchor latents so that `steps` interpolants between each
            # consecutive pair cover patch_size slices.
            num_imgs = int(np.ceil(hparams['patch_size'] / steps) + 1)
            z = gaussian_sample(mean, logs, temperature)[:num_imgs, ...]
            # Linear interpolation between consecutive anchor latents; decoded
            # slices then vary smoothly along the stacking axis.
            z = torch.cat([
                alpha * z[i, ...] + (1 - alpha) * z[i + 1, ...]
                for i in range(num_imgs - 1)
            ])
            z = z[:hparams['patch_size'], ...]

            images_raw = model(z=z, temperature=temperature, reverse=True)

        # Same NaN/Inf guard + clamp as in the step-estimation branch.
        images_raw[torch.isnan(images_raw)] = 0.5
        images_raw[torch.isinf(images_raw)] = 0.5
        images_raw = torch.clamp(images_raw, -0.5, 0.5)

        # apply median filter to output
        if args.med_filt is not None or args.binary_data:
            for m in range(args.n_modalities):
                if args.binary_data:
                    SE = ball(1)
                else:
                    SE = ball(args.med_filt)
                images_np = np.squeeze(images_raw[:, m, :, :].cpu().numpy())
                images_filt = median_filter(images_np, footprint=SE)

                # Erode binary images
                if args.binary_data:
                    # Re-binarize, erode by one voxel, and map back to the
                    # [-0.5, 0.5] value range.
                    images_filt = np.greater(images_filt, 0)
                    SE = ball(1)
                    images_filt = 1.0 * binary_erosion(images_filt,
                                                       selem=SE) - 0.5

                images_raw[:, m, :, :] = torch.tensor(images_filt,
                                                      device=device)

        # Three orthogonal viewing directions of the same volume: as-is, and
        # with the stacking axis swapped against H and W respectively.
        images1 = postprocess(images_raw).cpu()
        images2 = postprocess(torch.transpose(images_raw, 0, 2)).cpu()
        images3 = postprocess(torch.transpose(images_raw, 0, 3)).cpu()

        # apply Otsu thresholding to output
        if args.save_binary and not args.binary_data:
            # Threshold computed once on view 1 and reused for all views so
            # the three orientations stay consistent.
            thresh = threshold_otsu(images1.numpy())
            images1[images1 < thresh] = 0
            images1[images1 > thresh] = 255
            images2[images2 < thresh] = 0
            images2[images2 > thresh] = 255
            images3[images3 < thresh] = 0
            images3[images3 > thresh] = 255

        # # erode binary images by 1 px to correct for training image transformation
        # if args.binary_data:
        #     images1 = np.greater(images1.numpy(), 127)
        #     images2 = np.greater(images2.numpy(), 127)
        #     images3 = np.greater(images3.numpy(), 127)

        #     images1 = 255*torch.tensor(1.0*np.expand_dims(binary_erosion(np.squeeze(images1), selem=np.ones((1,2,2))), 1))
        #     images2 = 255*torch.tensor(1.0*np.expand_dims(binary_erosion(np.squeeze(images2), selem=np.ones((2,1,2))), 1))
        #     images3 = 255*torch.tensor(1.0*np.expand_dims(binary_erosion(np.squeeze(images3), selem=np.ones((2,2,1))), 1))

        # save video for each modality
        for m in range(args.n_modalities):
            if args.n_modalities > 1:
                save_dir = os.path.join(stack_dir, 'modality{}'.format(m))
            else:
                save_dir = stack_dir

            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            # One video per orientation: xy (as generated), xz, yz.
            write_video(images1[:, m, :, :], 'xy', hparams, save_dir)
            write_video(images2[:, m, :, :], 'xz', hparams, save_dir)
            write_video(images3[:, m, :, :], 'yz', hparams, save_dir)

    print('Finished!')
コード例 #11
0
def main(
    dataset,
    dataroot,
    download,
    augment,
    batch_size,
    eval_batch_size,
    epochs,
    saved_model,
    seed,
    hidden_channels,
    K,
    L,
    actnorm_scale,
    flow_permutation,
    flow_coupling,
    LU_decomposed,
    learn_top,
    y_condition,
    y_weight,
    max_grad_clip,
    max_grad_norm,
    lr,
    n_workers,
    cuda,
    n_init_batches,
    output_dir,
    saved_optimizer,
    warmup,
    classifier_weight
):
    """Train a Glow model on CelebA attributes with ignite + wandb logging.

    Args:
        dataset: name of the dataset/wandb project.
        dataroot: root folder of the CelebA data.
        batch_size / eval_batch_size: training / evaluation batch sizes.
        epochs: number of training epochs.
        saved_model: optional checkpoint path to warm-start from.
        hidden_channels, K, L, actnorm_scale, flow_permutation,
        flow_coupling, LU_decomposed, learn_top, y_condition: Glow
            architecture hyperparameters.
        y_weight: weight of the classification loss when y-conditioned.
        max_grad_clip / max_grad_norm: gradient clipping thresholds
            (<= 0 disables the corresponding clip).
        lr: Adamax learning rate; warmup: epochs of linear LR warm-up.
        n_init_batches: batches used for ActNorm data-dependent init.
        output_dir: directory for ignite checkpoints.

    Note: download, augment, eval_batch_size, n_workers, saved_optimizer
    and classifier_weight are currently accepted but unused in this body.
    """

    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"
    # BUG FIX: body previously referenced the undefined name `args`
    # (args.dataset / args.dataroot / args.batch_size) which raised
    # NameError — all values are available as function parameters.
    wandb.init(project=dataset)

    check_manual_seed(seed)

    image_shape = (64, 64, 3)
    # if dataset == "task1": num_classes = 24
    # else : num_classes = 40

    # CelebA has 40 binary attribute labels.
    num_classes = 40

    # Note: unsupported for now — multi_class is set but never consumed here.
    multi_class = True

    dataset_train = CelebALoader(root_folder=dataroot)
    train_loader = DataLoader(dataset_train,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True)  # NOTE: n_workers is not passed here

    model = Glow(
        image_shape,
        hidden_channels,
        K,
        L,
        actnorm_scale,
        flow_permutation,
        flow_coupling,
        LU_decomposed,
        num_classes,
        learn_top,
        y_condition,
    )

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    # Linear LR warm-up over `warmup` epochs, then constant.
    # NOTE(review): scheduler.step() is never called in this body, so the
    # warm-up has no effect as written — confirm whether this is intended.
    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    wandb.watch(model)

    def step(engine, batch):
        """One optimization step: forward, NLL (+ optional y) loss, backward."""
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)
        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            # x: (B, 3, 64, 64); y: (B, num_classes); z: (B, 48, 8, 8)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses["total_loss"].backward()

        # Optional gradient clipping (value and/or norm).
        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    trainer = Engine(step)
    # n_saved=None keeps every epoch's checkpoint on disk.
    checkpoint_handler = ModelCheckpoint(
        output_dir, "glow", n_saved=None, require_empty=False
    )

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {"model": model, "optimizer": optimizer},
    )

    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(
        trainer, "total_loss"
    )

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # Warm-start from a checkpoint; keep its ActNorm statistics.
    if saved_model:
        model.load_state_dict(torch.load(saved_model, map_location="cpu")['model'])
        model.set_actnorm_init()

    @trainer.on(Events.STARTED)
    def init(engine):
        """Data-dependent ActNorm initialization on the first few batches."""
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            # A single forward pass initializes the ActNorm layers.
            model(init_batches, init_targets)

    trainer.run(train_loader, epochs)
コード例 #12
0
ファイル: analyze.py プロジェクト: wangkua1/Glow-PyTorch
def main(dataset, dataroot, download, augment, n_workers, eval_batch_size,
         output_dir, db, glow_path, ckpt_name):
    """Run NaN/condition-number/BPD/reconstruction analysis on a Glow model.

    Analyzes one batch of real test data, one batch of model samples, and one
    batch from each configured OOD distribution; writes per-distribution
    statistics to ``results_<ckpt_name>.json`` in ``output_dir``.

    Note: ``db`` is accepted for interface compatibility but unused here.
    """
    (image_shape, num_classes, train_dataset,
     test_dataset) = check_dataset(dataset, dataroot, augment, download)

    test_loader = data.DataLoader(test_dataset, batch_size=eval_batch_size,
                                  shuffle=False, num_workers=n_workers,
                                  drop_last=False)

    # One batch of real (in-distribution) test images.
    # Idiom fix: next(iter(...)) instead of __iter__().__next__().
    x = next(iter(test_loader))[0].to(device)

    # OOD data
    ood_distributions = ['gaussian']
    # ood_distributions = ['gaussian', 'rademacher', 'texture3', 'svhn','tinyimagenet','lsun']
    # Build the OOD preprocessing pipeline in one shot (was appended piecemeal
    # to an empty Compose).
    tr = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        one_to_three_channels,
        preprocess,
    ])
    # (name, tensor) pairs, one batch per OOD distribution.
    # Fix: comprehension variable renamed so it no longer shadows the real
    # data batch `x` defined above.
    ood_tensors = [(out_name, torch.stack([tr(img) for img in load_ood_data({
                                  'name': out_name,
                                  'ood_scale': 1,
                                  'n_anom': eval_batch_size,
                                })]).to(device)
                        ) for out_name in ood_distributions]

    if 'sd' in glow_path:
        # State-dict checkpoint: rebuild the architecture from hparams.json
        # stored next to the weights, then load.
        with open(os.path.join(os.path.dirname(glow_path), 'hparams.json'), 'r') as f:
            model_kwargs = json.load(f)
        model = Glow(
                (32, 32, 3),
                model_kwargs['hidden_channels'],
                model_kwargs['K'],
                model_kwargs['L'],
                model_kwargs['actnorm_scale'],
                model_kwargs['flow_permutation'],
                model_kwargs['flow_coupling'],
                model_kwargs['LU_decomposed'],
                10,
                model_kwargs['learn_top'],
                model_kwargs['y_condition'],
                model_kwargs['logittransform'],
                model_kwargs['sn'],
                model_kwargs['affine_eps'],
                model_kwargs['no_actnorm'],
                model_kwargs['affine_scale_eps'],
                model_kwargs['actnorm_max_scale'],
                model_kwargs['no_conv_actnorm'],
                model_kwargs['affine_max_scale'],
                model_kwargs['actnorm_eps'],
                model_kwargs['no_split']
            )
        model.load_state_dict(torch.load(glow_path))
        model.set_actnorm_init()
    else:
        # Whole-model pickle checkpoint.
        model = torch.load(glow_path)
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        samples = generate_from_noise(model, eval_batch_size, clamp=False,
                                      guard_nans=False)
    stats = OrderedDict()
    # Fix: loop variable renamed from `x` (which clobbered the data batch).
    for name, batch in [('data', x), ('samples', samples)] + ood_tensors:
        p_pxs, p_ims, cn, dlogdet, bpd, pad = run_analysis(
            batch, model,
            os.path.join(output_dir, f'recon_{ckpt_name}_{name}.jpeg'))

        stats[f"{name}-percent-pixels-nans"] = p_pxs
        stats[f"{name}-percent-imgs-nans"] = p_ims
        stats[f"{name}-cn"] = cn
        stats[f"{name}-dlogdet"] = dlogdet
        stats[f"{name}-bpd"] = bpd
        stats[f"{name}-recon-err"] = pad

        # Rewritten after every distribution so partial results survive a
        # crash mid-analysis (intentional, kept from the original).
        with open(os.path.join(output_dir, f'results_{ckpt_name}.json'), 'w') as fp:
            json.dump(stats, fp, indent=4)
コード例 #13
0
def main(
    dataset,
    augment,
    batch_size,
    eval_batch_size,
    epochs,
    saved_model,
    seed,
    hidden_channels,
    K,
    L,
    actnorm_scale,
    flow_permutation,
    flow_coupling,
    LU_decomposed,
    learn_top,
    y_condition,
    extra_condition,
    sp_condition,
    d_condition,
    yd_condition,
    y_weight,
    d_weight,
    max_grad_clip,
    max_grad_norm,
    lr,
    n_workers,
    cuda,
    n_init_batches,
    output_dir,
    missing,
    saved_optimizer,
    warmup,
):
    """Train a Glow model with optional label (y), domain (d) or joint (yd)
    conditioning, using ignite for the train/eval loop.

    Exactly one of y_condition / d_condition / yd_condition selects the
    conditioning mode; it determines both the loss (compute_loss_y /
    compute_loss_yd / compute_loss) and the class/domain counts below.
    Checkpoints every epoch to ``output_dir`` and additionally keeps the
    best model by validation total_loss.
    """
    print(output_dir)
    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"
    print(device)
    check_manual_seed(seed)
    print("augmenting?", augment)
    train_dataset, test_dataset = check_dataset(dataset, augment, missing)
    image_shape = (32, 32, 3)

    multi_class = False

    # Class/domain counts per conditioning mode.
    # NOTE(review): the exact meaning of 2 classes / 10 domains is dataset
    # specific — confirm against check_dataset's output.
    if yd_condition:
        num_classes = 2
        num_domains = 10
        #num_classes = 10+2
        #multi_class=True
    elif d_condition:
        num_classes = 10
        num_domains = 0
    else:
        num_classes = 2
        num_domains = 0
    #print("num classes", num_classes)

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
        drop_last=True,
    )
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=n_workers,
        drop_last=False,
    )

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 num_domains, learn_top, y_condition, extra_condition,
                 sp_condition, d_condition, yd_condition)

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    # Linear LR warm-up over `warmup` epochs; stepped once per epoch in
    # the evaluate() handler below.
    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lr_lambda=lr_lambda)

    def step(engine, batch):
        # One training step: forward, mode-dependent loss, backward,
        # optional gradient clipping, optimizer update.
        model.train()
        optimizer.zero_grad()

        x, y, d, yd = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits, spare = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        elif d_condition:
            d = d.to(device)
            z, nll, d_logits, spare = model(x, d)

            losses = compute_loss_y(nll, d_logits, d_weight, d, multi_class)
        elif yd_condition:
            y, d, yd = y.to(device), d.to(device), yd.to(device)
            z, nll, y_logits, d_logits = model(x, yd)
            losses = compute_loss_yd(nll, y_logits, y_weight, y, d_logits,
                                     d_weight, d)
        else:
            print("none")
            z, nll, y_logits, spare = model(x, None)
            losses = compute_loss(nll)

        losses["total_loss"].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        # Evaluation step: same mode-dependent losses as step(), but with
        # reduction="none" so the Loss metric can average per-sample values.
        model.eval()

        x, y, d, yd = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits, none_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction="none")
            elif d_condition:
                d = d.to(device)
                z, nll, d_logits, non_logits = model(x, d)
                losses = compute_loss_y(nll,
                                        d_logits,
                                        d_weight,
                                        d,
                                        multi_class,
                                        reduction="none")
            elif yd_condition:
                y, d, yd = y.to(device), d.to(device), yd.to(device)
                z, nll, y_logits, d_logits = model(x, yd)
                losses = compute_loss_yd(nll,
                                         y_logits,
                                         y_weight,
                                         y,
                                         d_logits,
                                         d_weight,
                                         d,
                                         reduction="none")
            else:

                z, nll, y_logits, d_logits = model(x, None)
                losses = compute_loss(nll, reduction="none")
        #print(losses, "losssssess")
        return losses

    trainer = Engine(step)
    # Keep the two most recent epoch checkpoints.
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         "glow",
                                         save_interval=1,
                                         n_saved=2,
                                         require_empty=False)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {
            "model": model,
            "optimizer": optimizer
        },
    )

    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(
        trainer, "total_loss")

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    # Loss metric averages the per-sample losses returned by eval_step;
    # the dummy second element only satisfies the (y_pred, y) interface.
    Loss(
        lambda x, y: torch.mean(x),
        output_transform=lambda x: (
            x["total_loss"],
            torch.empty(x["total_loss"].shape[0]),
        ),
    ).attach(evaluator, "total_loss")

    if y_condition or d_condition or yd_condition:
        # Track the raw NLL separately when any conditioning loss is added.
        monitoring_metrics.extend(["nll"])
        RunningAverage(output_transform=lambda x: x["nll"]).attach(
            trainer, "nll")

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(
            lambda x, y: torch.mean(x),
            output_transform=lambda x:
            (x["nll"], torch.empty(x["nll"].shape[0])),
        ).attach(evaluator, "nll")

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        # Preserve the checkpoint's ActNorm statistics.
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        # Resume epoch counter from the checkpoint filename suffix
        # (e.g. "glow_model_12.pth" -> epoch 12).
        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split("_")[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        # Data-dependent ActNorm initialization: one forward pass over the
        # first n_init_batches batches before training starts.
        # NOTE(review): this also runs when resuming from a checkpoint —
        # harmless only because set_actnorm_init() was called above.
        model.train()

        init_batches = []
        init_targets = []
        init_domains = []
        init_yds = []

        with torch.no_grad():
            for batch, target, domain, yd in islice(train_loader, None,
                                                    n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)
                init_domains.append(domain)
                init_yds.append(yd)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            # Condition the init pass the same way training will be run.
            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
                model(init_batches, init_targets)
            elif d_condition:
                init_domains = torch.cat(init_domains).to(device)
                model(init_batches, init_domains)
            elif yd_condition:
                init_yds = torch.cat(init_yds).to(device)
                model(init_batches, init_yds)
            else:
                init_targets = None
                model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        # Validate after each epoch and advance the LR warm-up schedule.
        evaluator.run(test_loader)
        #print("done")
        scheduler.step()
        metrics = evaluator.state.metrics

        losses = ", ".join(
            [f"{key}: {value:.8f}" for key, value in metrics.items()])

        print(f"Validation Results - Epoch: {engine.state.epoch} {losses}")

    def score_function(engine):
        # Higher score = better for ModelCheckpoint, so negate the loss.
        val_loss = engine.state.metrics['total_loss']

        return -val_loss

    name = "best_"

    # Keep only the single best model by validation loss.
    val_handler = ModelCheckpoint(output_dir,
                                  name,
                                  score_function=score_function,
                                  score_name="val_loss",
                                  n_saved=1,
                                  require_empty=False)

    evaluator.add_event_handler(
        Events.EPOCH_COMPLETED,
        val_handler,
        {"model": model},
    )

    # Average per-batch wall time, reported at the end of each epoch.
    timer = Timer(average=True)
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]"
        )
        timer.reset()

    trainer.run(train_loader, epochs)
コード例 #14
0
ファイル: train_gan.py プロジェクト: wangkua1/Glow-PyTorch
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         warmup, fresh, logittransform, gan, disc_lr):
    """Train a Glow model (or, when ``gan`` is set, a small GAN) with pytorch-ignite.

    Builds data loaders, model and optimizer(s), wires up ignite engines and
    handlers (checkpointing, running-average metrics, progress bar, timing),
    optionally resumes from ``saved_model``/``saved_optimizer``, then runs the
    training loop for ``epochs`` epochs over ``train_loader``.
    """

    device = 'cpu' if (not torch.cuda.is_available() or not cuda) else 'cuda:0'

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=eval_batch_size,
                                  shuffle=False,
                                  num_workers=n_workers,
                                  drop_last=False)

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition, logittransform)

    model = model.to(device)

    if gan:
        # Debug
        # In GAN mode the Glow model built above is discarded and replaced by
        # a small dedicated generator/discriminator pair.
        model = mine.Generator(32, 1).to(device)

        optimizer = optim.Adam(model.parameters(),
                               lr=lr,
                               betas=(.5, .99),
                               weight_decay=0)
        discriminator = mine.Discriminator(image_shape[-1])
        discriminator = discriminator.to(device)
        D_optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                        discriminator.parameters()),
                                 lr=disc_lr,
                                 betas=(.5, .99),
                                 weight_decay=0)
    else:
        optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    # lr_lambda = lambda epoch: lr * min(1., epoch+1 / warmup)
    # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    # NOTE(review): `i` is never used anywhere below.
    i = 0

    def step(engine, batch):
        """One maximum-likelihood training step; returns the loss dict."""
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses['total_loss'].backward()

        # Optional gradient clipping (by value and/or by norm).
        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def gan_step(engine, batch):
        """One GAN training step: a discriminator update followed by a generator update.

        Also saves sample/data image grids every 50 iterations.
        NOTE(review): `iter_ind` is initialised to -1 on the first call, so
        the first iteration is numbered -1, not 0.
        """
        assert not y_condition
        if 'iter_ind' in dir(engine):
            engine.iter_ind += 1
        else:
            engine.iter_ind = -1
        losses = {}
        model.train()
        discriminator.train()

        x, y = batch
        x = x.to(device)

        # def generate_from_noise(batch_size):
        #     _, c2, h, w  = model.prior_h.shape
        #     c = c2 // 2
        #     zshape = (batch_size, c, h, w)
        #     randz  = torch.autograd.Variable(torch.randn(zshape), requires_grad=True).to(device)
        #     images = model(z= randz, y_onehot=None, temperature=1, reverse=True,batch_size=batch_size)
        #     return images

        def generate_from_noise(batch_size):
            # Sample latent noise of fixed shape (B, 32, 1, 1) and decode it.
            # NOTE(review): the `/ 2` halves the generator output — presumably
            # a range adjustment; confirm against the data preprocessing.
            zshape = (batch_size, 32, 1, 1)
            randz = torch.randn(zshape).to(device)
            images = model(randz)
            return images / 2

        def run_noised_disc(discriminator, x):
            # Apply uniform dequantization noise before scoring with the disc.
            x = uniform_binning_correction(x)[0]
            return discriminator(x)

        # Train Disc
        fake = generate_from_noise(x.size(0))

        D_real_scores = run_noised_disc(discriminator, x.detach())
        D_fake_scores = run_noised_disc(discriminator, fake.detach())

        ones_target = torch.ones((x.size(0), 1), device=x.device)
        zeros_target = torch.zeros((x.size(0), 1), device=x.device)

        # D_real_accuracy = torch.sum(torch.round(F.sigmoid(D_real_scores)) == ones_target).float() / ones_target.size(0)
        # D_fake_accuracy = torch.sum(torch.round(F.sigmoid(D_fake_scores)) == zeros_target).float() / zeros_target.size(0)

        D_real_loss = F.binary_cross_entropy_with_logits(
            D_real_scores, ones_target)
        D_fake_loss = F.binary_cross_entropy_with_logits(
            D_fake_scores, zeros_target)

        # Standard non-saturating GAN discriminator loss plus a gradient
        # penalty term weighted by 10.
        D_loss = (D_real_loss + D_fake_loss) / 2
        gp = gradient_penalty(x.detach(), fake.detach(),
                              lambda _x: run_noised_disc(discriminator, _x))
        D_loss_plus_gp = D_loss + 10 * gp
        D_optimizer.zero_grad()
        D_loss_plus_gp.backward()
        D_optimizer.step()

        # Train generator
        fake = generate_from_noise(x.size(0))
        G_loss = F.binary_cross_entropy_with_logits(
            run_noised_disc(discriminator, fake),
            torch.ones((x.size(0), 1), device=x.device))
        losses['total_loss'] = G_loss

        # G-step
        optimizer.zero_grad()
        losses['total_loss'].backward()
        # NOTE(review): `gnorm` is computed but never used (debug leftover).
        params = list(model.parameters())
        gnorm = [p.grad.norm() for p in params]
        optimizer.step()
        # if max_grad_clip > 0:
        #     torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        # if max_grad_norm > 0:
        #     torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        # Periodically dump grids of generated samples and (dequantized) data.
        if engine.iter_ind % 50 == 0:
            grid = make_grid((postprocess(fake.detach().cpu())[:30]),
                             nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.savefig(
                os.path.join(output_dir, f'sample_{engine.iter_ind}.png'))

            grid = make_grid(
                (postprocess(uniform_binning_correction(x)[0].cpu())[:30]),
                nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.savefig(os.path.join(output_dir,
                                     f'data_{engine.iter_ind}.png'))

        return losses

    def eval_step(engine, batch):
        """One validation step; returns per-example (unreduced) losses."""
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    if gan:
        trainer = Engine(gan_step)
    else:
        trainer = Engine(step)
    # Checkpoint model + optimizer after every epoch, keeping the last two.
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         'glow',
                                         save_interval=1,
                                         n_saved=2,
                                         require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x:
         (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(
             evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(
            trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x:
             (x['nll'], torch.empty(x['nll'].shape[0]))).attach(
                 evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        # The epoch number is recovered from the checkpoint filename suffix
        # (e.g. "glow_model_12.pth" -> 12).
        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    # @trainer.on(Events.STARTED)
    # def init(engine):
    #     model.train()

    #     init_batches = []
    #     init_targets = []

    #     with torch.no_grad():
    #         for batch, target in islice(train_loader, None,
    #                                     n_init_batches):
    #             init_batches.append(batch)
    #             init_targets.append(target)

    #         init_batches = torch.cat(init_batches).to(device)

    #         assert init_batches.shape[0] == n_init_batches * batch_size

    #         if y_condition:
    #             init_targets = torch.cat(init_targets).to(device)
    #         else:
    #             init_targets = None

    #         model(init_batches, init_targets)

    # @trainer.on(Events.EPOCH_COMPLETED)
    # def evaluate(engine):
    #     evaluator.run(test_loader)

    #     # scheduler.step()
    #     metrics = evaluator.state.metrics

    #     losses = ', '.join([f"{key}: {value:.2f}" for key, value in metrics.items()])

    #     myprint(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    # Track the average per-iteration time within each epoch.
    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]'
        )
        timer.reset()

    trainer.run(train_loader, epochs)
Code example #15
0
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         warmup, fresh, logittransform, gan, disc_lr, sn, flowgan, eval_every,
         ld_on_samples, weight_gan, weight_prior, weight_logdet,
         jac_reg_lambda, affine_eps, no_warm_up, optim_name, clamp, svd_every,
         eval_only, no_actnorm, affine_scale_eps, actnorm_max_scale,
         no_conv_actnorm, affine_max_scale, actnorm_eps, init_sample, no_split,
         disc_arch, weight_entropy_reg, db):
    """Train a flow-GAN: a Glow generator with an adversarial discriminator.

    The generator loss is a weighted combination of GAN loss, prior
    likelihood, log-determinant, an entropy regularizer on samples, and an
    optional Jacobian regularizer. Periodically logs FID, reconstruction
    PAD, BPD and Jacobian SVD statistics to CSV files in ``output_dir``.

    NOTE(review): this function references several names it never defines —
    ``device`` (e.g. ``model.to(device)``) and ``args`` (e.g.
    ``args.batch_size``) must exist at module scope; verify before reuse.
    """

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=eval_batch_size,
                                  shuffle=False,
                                  num_workers=n_workers,
                                  drop_last=False)
    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition, logittransform, sn, affine_eps,
                 no_actnorm, affine_scale_eps, actnorm_max_scale,
                 no_conv_actnorm, affine_max_scale, actnorm_eps, no_split)

    model = model.to(device)

    # Select the discriminator architecture.
    if disc_arch == 'mine':
        discriminator = mine.Discriminator(image_shape[-1])
    elif disc_arch == 'biggan':
        discriminator = cgan_models.Discriminator(
            image_channels=image_shape[-1], conditional_D=False)
    elif disc_arch == 'dcgan':
        discriminator = DCGANDiscriminator(image_shape[0], 64, image_shape[-1])
    elif disc_arch == 'inv':
        discriminator = InvDiscriminator(
            image_shape, hidden_channels, K, L, actnorm_scale,
            flow_permutation, flow_coupling, LU_decomposed, num_classes,
            learn_top, y_condition, logittransform, sn, affine_eps, no_actnorm,
            affine_scale_eps, actnorm_max_scale, no_conv_actnorm,
            affine_max_scale, actnorm_eps, no_split)

    discriminator = discriminator.to(device)
    D_optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                    discriminator.parameters()),
                             lr=disc_lr,
                             betas=(.5, .99),
                             weight_decay=0)
    if optim_name == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=lr,
                               betas=(.5, .99),
                               weight_decay=0)
    elif optim_name == 'adamax':
        optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    # Linear LR warm-up over the first `warmup` epochs (unless disabled).
    if not no_warm_up:
        lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lr_lambda=lr_lambda)

    # Per-iteration metric logs written as CSV under output_dir.
    iteration_fieldnames = [
        'global_iteration', 'fid', 'sample_pad', 'train_bpd', 'eval_bpd',
        'pad', 'batch_real_acc', 'batch_fake_acc', 'batch_acc'
    ]
    iteration_logger = CSVLogger(fieldnames=iteration_fieldnames,
                                 filename=os.path.join(output_dir,
                                                       'iteration_log.csv'))
    iteration_fieldnames = [
        'global_iteration', 'condition_num', 'max_sv', 'min_sv',
        'inverse_condition_num', 'inverse_max_sv', 'inverse_min_sv'
    ]
    svd_logger = CSVLogger(fieldnames=iteration_fieldnames,
                           filename=os.path.join(output_dir, 'svd_log.csv'))

    #
    # Fixed pool of real test images used for FID, plus a fixed batch used
    # for reconstruction evaluation.
    test_iter = test_loader.__iter__()
    N_inception = 1000
    x_real_inception = torch.cat([
        test_iter.__next__()[0].to(device)
        for _ in range(N_inception // args.batch_size + 1)
    ], 0)[:N_inception]
    x_real_inception = x_real_inception + .5
    x_for_recon = test_iter.__next__()[0].to(device)

    def gan_step(engine, batch):
        """One combined discriminator + generator update, plus periodic evaluation.

        NOTE(review): `iter_ind` is initialised to -1 on the first call.
        """
        assert not y_condition
        if 'iter_ind' in dir(engine):
            engine.iter_ind += 1
        else:
            engine.iter_ind = -1
        losses = {}
        model.train()
        discriminator.train()

        x, y = batch
        x = x.to(device)

        def run_noised_disc(discriminator, x):
            # Apply uniform dequantization noise before scoring with the disc.
            x = uniform_binning_correction(x)[0]
            return discriminator(x)

        real_acc = fake_acc = acc = 0
        if weight_gan > 0:
            # --- Discriminator update ---
            fake = generate_from_noise(model, x.size(0), clamp=clamp)

            D_real_scores = run_noised_disc(discriminator, x.detach())
            D_fake_scores = run_noised_disc(discriminator, fake.detach())

            ones_target = torch.ones((x.size(0), 1), device=x.device)
            zeros_target = torch.zeros((x.size(0), 1), device=x.device)

            D_real_accuracy = torch.sum(
                torch.round(F.sigmoid(D_real_scores)) ==
                ones_target).float() / ones_target.size(0)
            D_fake_accuracy = torch.sum(
                torch.round(F.sigmoid(D_fake_scores)) ==
                zeros_target).float() / zeros_target.size(0)

            D_real_loss = F.binary_cross_entropy_with_logits(
                D_real_scores, ones_target)
            D_fake_loss = F.binary_cross_entropy_with_logits(
                D_fake_scores, zeros_target)

            # Discriminator loss plus gradient penalty weighted by 10.
            D_loss = (D_real_loss + D_fake_loss) / 2
            gp = gradient_penalty(
                x.detach(), fake.detach(),
                lambda _x: run_noised_disc(discriminator, _x))
            D_loss_plus_gp = D_loss + 10 * gp
            D_optimizer.zero_grad()
            D_loss_plus_gp.backward()
            D_optimizer.step()

            # Train generator
            fake = generate_from_noise(model,
                                       x.size(0),
                                       clamp=clamp,
                                       guard_nans=False)
            G_loss = F.binary_cross_entropy_with_logits(
                run_noised_disc(discriminator, fake),
                torch.ones((x.size(0), 1), device=x.device))

            # Trace
            real_acc = D_real_accuracy.item()
            fake_acc = D_fake_accuracy.item()
            acc = .5 * (D_fake_accuracy.item() + D_real_accuracy.item())

        z, nll, y_logits, (prior, logdet) = model.forward(x,
                                                          None,
                                                          return_details=True)
        train_bpd = nll.mean().item()

        # Assemble the weighted generator objective term by term.
        loss = 0
        if weight_gan > 0:
            loss = loss + weight_gan * G_loss
        if weight_prior > 0:
            loss = loss + weight_prior * -prior.mean()
        if weight_logdet > 0:
            loss = loss + weight_logdet * -logdet.mean()

        if weight_entropy_reg > 0:
            # NOTE(review): `fake` is only defined when weight_gan > 0 above;
            # with weight_entropy_reg > 0 and weight_gan <= 0 this raises
            # NameError — confirm the intended configuration space.
            _, _, _, (sample_prior,
                      sample_logdet) = model.forward(fake,
                                                     None,
                                                     return_details=True)
            # notice this is actually "decreasing" sample likelihood.
            loss = loss + weight_entropy_reg * (sample_prior.mean() +
                                                sample_logdet.mean())
        # Jac Reg
        if jac_reg_lambda > 0:
            # Sample
            x_samples = generate_from_noise(model,
                                            args.batch_size,
                                            clamp=clamp).detach()
            x_samples.requires_grad_()
            z = model.forward(x_samples, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1)
                for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            sample_foward_jac = compute_jacobian_regularizer(x_samples,
                                                             all_z,
                                                             n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            randz = torch.randn(zshape).to(device)
            randz = torch.autograd.Variable(randz, requires_grad=True)
            images = model(z=randz,
                           y_onehot=None,
                           temperature=1,
                           reverse=True,
                           batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [randz] + other_zs
            sample_inverse_jac = compute_jacobian_regularizer_manyinputs(
                all_z, images, n_proj=1)

            # Data
            x.requires_grad_()
            z = model.forward(x, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1)
                for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            data_foward_jac = compute_jacobian_regularizer(x, all_z, n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            z.requires_grad_()
            images = model(z=z,
                           y_onehot=None,
                           temperature=1,
                           reverse=True,
                           batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [z] + other_zs
            data_inverse_jac = compute_jacobian_regularizer_manyinputs(
                all_z, images, n_proj=1)

            # loss = loss + jac_reg_lambda * (sample_foward_jac + sample_inverse_jac )
            loss = loss + jac_reg_lambda * (sample_foward_jac +
                                            sample_inverse_jac +
                                            data_foward_jac + data_inverse_jac)

        if not eval_only:
            optimizer.zero_grad()
            loss.backward()
            if not db:
                assert max_grad_clip == max_grad_norm == 0
            if max_grad_clip > 0:
                torch.nn.utils.clip_grad_value_(model.parameters(),
                                                max_grad_clip)
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_grad_norm)

            # Replace NaN gradient with 0
            for p in model.parameters():
                if p.requires_grad and p.grad is not None:
                    g = p.grad.data
                    g[g != g] = 0

            optimizer.step()

        # Every 100 iterations: save a grid of samples (titled by NaN status).
        if engine.iter_ind % 100 == 0:
            with torch.no_grad():
                fake = generate_from_noise(model, x.size(0), clamp=clamp)
                z = model.forward(fake, None, return_details=True)[0]
            print("Z max min")
            print(z.max().item(), z.min().item())
            if (fake != fake).float().sum() > 0:
                title = 'NaNs'
            else:
                title = "Good"
            grid = make_grid((postprocess(fake.detach().cpu(), dataset)[:30]),
                             nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.title(title)
            plt.savefig(
                os.path.join(output_dir, f'sample_{engine.iter_ind}.png'))

        # Every `eval_every` iterations: checkpoint (at round-number iters),
        # evaluate reconstruction PAD, FID and eval BPD, and log to CSV.
        if engine.iter_ind % eval_every == 0:

            def check_all_zero_except_leading(x):
                # True for 10, 200, 3000, ... (only the leading digit nonzero).
                return x % 10**np.floor(np.log10(x)) == 0

            if engine.iter_ind == 0 or check_all_zero_except_leading(
                    engine.iter_ind):
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f'ckpt_sd_{engine.iter_ind}.pt'))

            model.eval()

            with torch.no_grad():
                # Plot recon
                fpath = os.path.join(output_dir, '_recon',
                                     f'recon_{engine.iter_ind}.png')
                sample_pad = run_recon_evolution(
                    model,
                    generate_from_noise(model, args.batch_size,
                                        clamp=clamp).detach(), fpath)
                print(
                    f"Iter: {engine.iter_ind}, Recon Sample PAD: {sample_pad}")

                pad = run_recon_evolution(model, x_for_recon, fpath)
                print(f"Iter: {engine.iter_ind}, Recon PAD: {pad}")
                pad = pad.item()
                sample_pad = sample_pad.item()

                # Inception score
                sample = torch.cat([
                    generate_from_noise(model, args.batch_size, clamp=clamp)
                    for _ in range(N_inception // args.batch_size + 1)
                ], 0)[:N_inception]
                sample = sample + .5

                if (sample != sample).float().sum() > 0:
                    print("Sample NaNs")
                    # NOTE(review): bare `raise` outside an except block
                    # raises RuntimeError("No active exception to re-raise").
                    raise
                else:
                    fid = run_fid(x_real_inception.clamp_(0, 1),
                                  sample.clamp_(0, 1))
                    print(f'fid: {fid}, global_iter: {engine.iter_ind}')

                # Eval BPD
                eval_bpd = np.mean([
                    model.forward(x.to(device), None,
                                  return_details=True)[1].mean().item()
                    for x, _ in test_loader
                ])

                stats_dict = {
                    'global_iteration': engine.iter_ind,
                    'fid': fid,
                    'train_bpd': train_bpd,
                    'pad': pad,
                    'eval_bpd': eval_bpd,
                    'sample_pad': sample_pad,
                    'batch_real_acc': real_acc,
                    'batch_fake_acc': fake_acc,
                    'batch_acc': acc
                }
                iteration_logger.writerow(stats_dict)
                plot_csv(iteration_logger.filename)
            model.train()

        # NOTE(review): due to operator precedence this evaluates as
        # `iter_ind + (2 % svd_every) == 0`, not `(iter_ind + 2) % svd_every`;
        # for svd_every > 2 the condition only holds when iter_ind == -2,
        # so the SVD logging below likely never runs as intended.
        if engine.iter_ind + 2 % svd_every == 0:
            model.eval()
            svd_dict = {}
            ret = utils.computeSVDjacobian(x_for_recon, model)
            D_for, D_inv = ret['D_for'], ret['D_inv']
            cn = float(D_for.max() / D_for.min())
            cn_inv = float(D_inv.max() / D_inv.min())
            svd_dict['global_iteration'] = engine.iter_ind
            svd_dict['condition_num'] = cn
            svd_dict['max_sv'] = float(D_for.max())
            svd_dict['min_sv'] = float(D_for.min())
            svd_dict['inverse_condition_num'] = cn_inv
            svd_dict['inverse_max_sv'] = float(D_inv.max())
            svd_dict['inverse_min_sv'] = float(D_inv.min())
            svd_logger.writerow(svd_dict)
            # plot_utils.plot_stability_stats(output_dir)
            # plot_utils.plot_individual_figures(output_dir, 'svd_log.csv')
            model.train()
            if eval_only:
                sys.exit()

        # Dummy
        losses['total_loss'] = torch.mean(nll).item()
        return losses

    def eval_step(engine, batch):
        """One validation step; returns per-example (unreduced) losses."""
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    trainer = Engine(gan_step)
    # else:
    #     trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         'glow',
                                         save_interval=5,
                                         n_saved=1,
                                         require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x:
         (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(
             evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(
            trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x:
             (x['nll'], torch.empty(x['nll'].shape[0]))).attach(
                 evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        print("Loading...")
        print(saved_model)
        loaded = torch.load(saved_model)
        # if 'Glow' in str(type(loaded)):
        #     model  = loaded
        # else:
        #     raise
        # # if 'Glow' in str(type(loaded)):
        # #     loaded  = loaded.state_dict()
        model.load_state_dict(loaded)
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        # The epoch number is recovered from the checkpoint filename suffix.
        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        """Data-dependent ActNorm initialization (skipped when resuming)."""
        if saved_model:
            return
        model.train()
        print("Initializing Actnorm...")
        init_batches = []
        init_targets = []

        if n_init_batches == 0:
            model.set_actnorm_init()
            return
        with torch.no_grad():
            if init_sample:
                # Initialize from generated samples instead of data batches.
                generate_from_noise(model,
                                    args.batch_size * args.n_init_batches)
            else:
                for batch, target in islice(train_loader, None,
                                            n_init_batches):
                    init_batches.append(batch)
                    init_targets.append(target)

                init_batches = torch.cat(init_batches).to(device)

                assert init_batches.shape[0] == n_init_batches * batch_size

                if y_condition:
                    init_targets = torch.cat(init_targets).to(device)
                else:
                    init_targets = None

                model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        """Run validation after each epoch, step the warm-up scheduler, log metrics."""
        evaluator.run(test_loader)
        if not no_warm_up:
            scheduler.step()
        metrics = evaluator.state.metrics

        losses = ', '.join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    # Track the average per-iteration time within each epoch.
    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]'
        )
        timer.reset()

    trainer.run(train_loader, epochs)
Code example #16
0
        images = model(z= randz, y_onehot=None, temperature=1, reverse=True,batch_size=batch_size)   
        return images

    # ipdb.set_trace()
    iteration_fieldnames = ['global_iteration', 'fid']
    iteration_logger = CSVLogger(fieldnames=iteration_fieldnames,
                             filename=os.path.join(output_dir, 'eval_log.csv'))



    # for idx in tqdm(np.arange(10,100,10)):
    for _ in range(1):
        idx = 100
        model(xs[:10].cuda(), None) # this is to initialize the u,v buffer in SpectralNorm blah...otherwise state_dicts  don't match... 
        saved_model = os.path.join(exp_dir, f"glow_model_{idx}.pth")
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()
        # Sample
        with torch.no_grad():
            fake = torch.cat([generate_from_noise(100) for _ in range(20)],0 )
        x_is = 2*fake
        x_is = x_is.repeat(1,3,1,1).detach()
        # I have no clue why samples can contain nan....but it does...
        def _replace_nan_with_k_inplace(x, k):
            mask = x != x
            x[mask] = k
        _replace_nan_with_k_inplace(x_is, -1)
        with torch.no_grad():
            issf, _, _, acts_fake = inception_score(x_is, cuda=True, batch_size=32, resize=True, splits=10, return_preds=True)
        idxs_ = np.argsort(np.abs(acts_fake).sum(-1))[:1800] # filter the ones with super large values
        acts_fake = acts_fake[idxs_]
    num_classes = 40
    Batch_Size = 4
    dataset_test = CelebALoader(
        root_folder=hparams['dataroot']
    )  #'/home/yellow/deep-learning-and-practice/hw7/dataset/task_2/'
    test_loader = DataLoader(dataset_test,
                             batch_size=Batch_Size,
                             shuffle=False,
                             drop_last=True)
    model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
                 hparams['L'], hparams['actnorm_scale'],
                 hparams['flow_permutation'], hparams['flow_coupling'],
                 hparams['LU_decomposed'], num_classes, hparams['learn_top'],
                 hparams['y_condition'])

    model.load_state_dict(
        torch.load(output_folder + model_name, map_location="cpu")['model'])
    model.set_actnorm_init()
    model = model.to(device)
    model = model.eval()

    # attribute_list = [8] # Black_Hair
    attribute_list = [20, 31, 33]  # Male, Smiling, Wavy_Hair, 24z    No_Beard
    # attribute_list = [11, 26, 31, 8, 6, 7] # Brown_Hair, Pale_Skin, Smiling, Black_Hair, Big_Lips, Big_Nose
    # attribute_list = [i for i in range(40)]
    N = 8
    z_pos_list = [torch.Tensor([]).cuda() for i in range(len(attribute_list))]
    z_neg_list = [torch.Tensor([]).cuda() for i in range(len(attribute_list))]

    z_input_img = None
    with torch.no_grad():
        for i, (x, y) in enumerate(test_loader):
Code example #18
0
# Restore a trained Glow model for evaluation on MNIST.
# Load the hyper-parameters that were saved alongside the checkpoint so the
# architecture below matches the one that was trained.
with open(output_folder + 'hparams.json') as json_file:
    hparams = json.load(json_file)

# Evaluation data: MNIST test split as 32x32 single-channel images.
test_mnist = train.MyMNIST(train=False, download=False)
image_shape = (32, 32, 1)
num_classes = 10
batch_size = 512

model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
             hparams['L'], hparams['actnorm_scale'],
             hparams['flow_permutation'], hparams['flow_coupling'],
             hparams['LU_decomposed'], num_classes, hparams['learn_top'],
             hparams['y_condition'])

model.load_state_dict(torch.load(latest_model_path))
# Presumably marks ActNorm layers as already initialized so the restored
# weights are not overwritten by data-dependent init — TODO confirm.
model.set_actnorm_init()

model = model.to(device)

model = model.eval()


def sample(model):
    with torch.no_grad():
        if hparams['y_condition']:
            y = torch.eye(num_classes)
            y = y.repeat(batch_size // num_classes + 1)
            y = y[:32, :].to(device)  # number hardcoded in model for now
        else:
            y = None