Example #1
    def forward_and_reverse_output_shape(self,
                                         in_channel,
                                         data,
                                         levels=3,
                                         depth=4):
        glow = Glow(in_channel, levels, depth)
        z, logdet, eps = glow(data)
        height, width = data.shape[2], data.shape[3]
        """
            cifar example:
            Level = 3
            initial shape -> [4, 3, 32, 32]
            iter 1 -> z: [4, 12, 16, 16] because of squeeze from outside the loop
            iter 2 -> z: [4, 24, 8, 8] because of squeeze + split
            iter 3 -> z: [4, 48, 4, 4] because of squeeze + split
        """
        assert list(z.shape) == [4, in_channel * 4 * 2**(levels - 1), 4, 4]
        assert list(logdet.shape) == [4]  # because batch_size = 4
        assert len(
            eps
        ) == levels - 1  # because L = 3 and split is executed at every level < L, i.e. 2 times in total

        factor = 1
        for e in eps:
            factor *= 2
            # example: first eps -> from iter 1 take z shape and divide channel by 2: [4, 12/2, 16, 16]
            assert list(e.shape) == [
                4, in_channel * factor, height // factor, width // factor
            ]
        """
            In total depth * levels = 4 * 3 = 12, so we got 12 instances of actnorm, inconv and affinecoupling
            Actnorm = 2 trainable parameters
            Invconv = 3 trainable parameter
            Affinecoupling = 6 trainable parameters (got 3 conv layers, each layer has weight + bias, so for all layers combined we get 6 in total)
            Zeroconv = 4 (2 conv layers, each with weight + bias)
            
            12 * (2+3+6) + 4= 136
        """
        assert len(list(
            glow.parameters())) == (levels * depth) * (2 + 3 + 6) + 4
        for param in glow.parameters():
            assert param.requires_grad

        # reverse
        # For CIFAR we expect z with levels=3 to be of shape [4, 48, 4, 4]
        z = glow.reverse(z, eps)

        assert list(z.shape) == [4, 3, 32, 32]
Example #2
def main(args):
    # we're probably only going to be using 1 GPU, so this should be fine
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"running on {device}")
    # set random seed for all
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    global best_loss
    if args.generate_samples:
        print("generating samples")
    # load data
    # example for CIFAR-10 training:
    train_set, test_set = get_dataloader(args.dataset, args.batch_size)

    input_channels = channels_from_dataset(args.dataset)
    print(f"amount of  input channels: {input_channels}")
    # instantiate model
    # # baby network to make sure training script works
    net = Glow(in_channels=input_channels,
               depth=args.amt_flow_steps, levels=args.amt_levels, use_normalization=args.norm_method)

    # code for rosalinty model
    # net = RosGlow(input_channels, args.amt_flow_steps, args.amt_levels)

    net = net.to(device)

    print(f"training for {args.num_epochs} epochs.")

    start_epoch = 0
    # load a checkpoint if resuming
    if args.resume:
        print(f"resuming from checkpoint found in new_checkpoints/best_{args.dataset.lower()}.pth.tar.")
        # raise an error if no checkpoint directory is found
        assert os.path.isdir("new_checkpoints")
        checkpoint = torch.load(f"new_checkpoints/best_{args.dataset.lower()}.pth.tar")
        net.load_state_dict(checkpoint["model"])
        best_loss = checkpoint["test_loss"]
        start_epoch = checkpoint["epoch"]
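        # Assumed checkpoint layout (a sketch matching the keys read above, not
        # verified against the saving code):
        # torch.save({"model": net.state_dict(), "test_loss": best_loss,
        #             "epoch": epoch}, f"new_checkpoints/best_{args.dataset.lower()}.pth.tar")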

    loss_function = FlowNLL().to(device)
    optimizer = optim.Adam(net.parameters(), lr=float(args.lr))
    # scheduler found in code, no mention in paper
    # scheduler = sched.LambdaLR(
    #     optimizer, lambda s: min(1., s / args.warmup_iters))
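    # If the warmup scheduler above were enabled, it would also need to be
    # stepped once per optimizer update (a sketch, assuming train() exposes the
    # batch loop): call scheduler.step() right after optimizer.step().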

    # should we add a resume function here?

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        print(f"training epoch {epoch}")
        train(net, train_set, device, optimizer, loss_function, epoch)
        # how often do we want to test?
        if epoch % 10 == 0:  # test every 10 epochs
            print(f"testing epoch {epoch}")
            test(net, test_set, device, loss_function, epoch, args.generate_samples,
                 args.amt_levels, args.dataset, args.n_samples)
Example #3
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         warmup, fresh, logittransform, gan, disc_lr, sn, flowgan, eval_every,
         ld_on_samples, weight_gan, weight_prior, weight_logdet,
         jac_reg_lambda, affine_eps, no_warm_up, optim_name, clamp, svd_every,
         eval_only, no_actnorm, affine_scale_eps, actnorm_max_scale,
         no_conv_actnorm, affine_max_scale, actnorm_eps, init_sample, no_split,
         disc_arch, weight_entropy_reg, db):

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=eval_batch_size,
                                  shuffle=False,
                                  num_workers=n_workers,
                                  drop_last=False)
    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition, logittransform, sn, affine_eps,
                 no_actnorm, affine_scale_eps, actnorm_max_scale,
                 no_conv_actnorm, affine_max_scale, actnorm_eps, no_split)

    model = model.to(device)

    if disc_arch == 'mine':
        discriminator = mine.Discriminator(image_shape[-1])
    elif disc_arch == 'biggan':
        discriminator = cgan_models.Discriminator(
            image_channels=image_shape[-1], conditional_D=False)
    elif disc_arch == 'dcgan':
        discriminator = DCGANDiscriminator(image_shape[0], 64, image_shape[-1])
    elif disc_arch == 'inv':
        discriminator = InvDiscriminator(
            image_shape, hidden_channels, K, L, actnorm_scale,
            flow_permutation, flow_coupling, LU_decomposed, num_classes,
            learn_top, y_condition, logittransform, sn, affine_eps, no_actnorm,
            affine_scale_eps, actnorm_max_scale, no_conv_actnorm,
            affine_max_scale, actnorm_eps, no_split)

    discriminator = discriminator.to(device)
    D_optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                    discriminator.parameters()),
                             lr=disc_lr,
                             betas=(.5, .99),
                             weight_decay=0)
    if optim_name == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=lr,
                               betas=(.5, .99),
                               weight_decay=0)
    elif optim_name == 'adamax':
        optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    if not no_warm_up:
        lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lr_lambda=lr_lambda)

    iteration_fieldnames = [
        'global_iteration', 'fid', 'sample_pad', 'train_bpd', 'eval_bpd',
        'pad', 'batch_real_acc', 'batch_fake_acc', 'batch_acc'
    ]
    iteration_logger = CSVLogger(fieldnames=iteration_fieldnames,
                                 filename=os.path.join(output_dir,
                                                       'iteration_log.csv'))
    iteration_fieldnames = [
        'global_iteration', 'condition_num', 'max_sv', 'min_sv',
        'inverse_condition_num', 'inverse_max_sv', 'inverse_min_sv'
    ]
    svd_logger = CSVLogger(fieldnames=iteration_fieldnames,
                           filename=os.path.join(output_dir, 'svd_log.csv'))

    #
    test_iter = iter(test_loader)
    N_inception = 1000
    x_real_inception = torch.cat([
        next(test_iter)[0].to(device)
        for _ in range(N_inception // args.batch_size + 1)
    ], 0)[:N_inception]
    x_real_inception = x_real_inception + .5
    x_for_recon = next(test_iter)[0].to(device)

    def gan_step(engine, batch):
        assert not y_condition
        if 'iter_ind' in dir(engine):
            engine.iter_ind += 1
        else:
            engine.iter_ind = -1
        losses = {}
        model.train()
        discriminator.train()

        x, y = batch
        x = x.to(device)

        def run_noised_disc(discriminator, x):
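            # uniform_binning_correction presumably adds uniform dequantization
            # noise (the [0] keeps only the noised tensor), so the discriminator
            # sees the same kind of continuous inputs as the flow does.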
            x = uniform_binning_correction(x)[0]
            return discriminator(x)

        real_acc = fake_acc = acc = 0
        if weight_gan > 0:
            fake = generate_from_noise(model, x.size(0), clamp=clamp)

            D_real_scores = run_noised_disc(discriminator, x.detach())
            D_fake_scores = run_noised_disc(discriminator, fake.detach())

            ones_target = torch.ones((x.size(0), 1), device=x.device)
            zeros_target = torch.zeros((x.size(0), 1), device=x.device)

            D_real_accuracy = torch.sum(
                torch.round(torch.sigmoid(D_real_scores)) ==
                ones_target).float() / ones_target.size(0)
            D_fake_accuracy = torch.sum(
                torch.round(torch.sigmoid(D_fake_scores)) ==
                zeros_target).float() / zeros_target.size(0)

            D_real_loss = F.binary_cross_entropy_with_logits(
                D_real_scores, ones_target)
            D_fake_loss = F.binary_cross_entropy_with_logits(
                D_fake_scores, zeros_target)

            D_loss = (D_real_loss + D_fake_loss) / 2
            gp = gradient_penalty(
                x.detach(), fake.detach(),
                lambda _x: run_noised_disc(discriminator, _x))
            D_loss_plus_gp = D_loss + 10 * gp
            D_optimizer.zero_grad()
            D_loss_plus_gp.backward()
            D_optimizer.step()

            # Train generator
            fake = generate_from_noise(model,
                                       x.size(0),
                                       clamp=clamp,
                                       guard_nans=False)
            G_loss = F.binary_cross_entropy_with_logits(
                run_noised_disc(discriminator, fake),
                torch.ones((x.size(0), 1), device=x.device))

            # Trace
            real_acc = D_real_accuracy.item()
            fake_acc = D_fake_accuracy.item()
            acc = .5 * (D_fake_accuracy.item() + D_real_accuracy.item())

        z, nll, y_logits, (prior, logdet) = model.forward(x,
                                                          None,
                                                          return_details=True)
        train_bpd = nll.mean().item()

        loss = 0
        if weight_gan > 0:
            loss = loss + weight_gan * G_loss
        if weight_prior > 0:
            loss = loss + weight_prior * -prior.mean()
        if weight_logdet > 0:
            loss = loss + weight_logdet * -logdet.mean()

        if weight_entropy_reg > 0:
            _, _, _, (sample_prior,
                      sample_logdet) = model.forward(fake,
                                                     None,
                                                     return_details=True)
            # notice this is actually "decreasing" sample likelihood.
            loss = loss + weight_entropy_reg * (sample_prior.mean() +
                                                sample_logdet.mean())
        # Jac Reg
        if jac_reg_lambda > 0:
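            # Jacobian regularization (a sketch of the intent): compute_jacobian_regularizer
            # appears to estimate the Frobenius norm of the Jacobian via random
            # projections (n_proj=1), here for both the forward map x -> z and the
            # inverse map z -> x, on generated samples as well as on data, so the
            # penalty below encourages the flow to stay well-conditioned in both
            # directions.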
            # Sample
            x_samples = generate_from_noise(model,
                                            args.batch_size,
                                            clamp=clamp).detach()
            x_samples.requires_grad_()
            z = model.forward(x_samples, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1)
                for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            sample_forward_jac = compute_jacobian_regularizer(x_samples,
                                                              all_z,
                                                              n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            randz = torch.randn(zshape, device=device, requires_grad=True)
            images = model(z=randz,
                           y_onehot=None,
                           temperature=1,
                           reverse=True,
                           batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [randz] + other_zs
            sample_inverse_jac = compute_jacobian_regularizer_manyinputs(
                all_z, images, n_proj=1)

            # Data
            x.requires_grad_()
            z = model.forward(x, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1)
                for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            data_forward_jac = compute_jacobian_regularizer(x, all_z, n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            z.requires_grad_()
            images = model(z=z,
                           y_onehot=None,
                           temperature=1,
                           reverse=True,
                           batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [z] + other_zs
            data_inverse_jac = compute_jacobian_regularizer_manyinputs(
                all_z, images, n_proj=1)

            # loss = loss + jac_reg_lambda * (sample_forward_jac + sample_inverse_jac)
            loss = loss + jac_reg_lambda * (sample_forward_jac +
                                            sample_inverse_jac +
                                            data_forward_jac + data_inverse_jac)

        if not eval_only:
            optimizer.zero_grad()
            loss.backward()
            if not db:
                assert max_grad_clip == max_grad_norm == 0
            if max_grad_clip > 0:
                torch.nn.utils.clip_grad_value_(model.parameters(),
                                                max_grad_clip)
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_grad_norm)

            # Replace NaN gradient with 0
            for p in model.parameters():
                if p.requires_grad and p.grad is not None:
                    g = p.grad.data
                    g[g != g] = 0

            optimizer.step()

        if engine.iter_ind % 100 == 0:
            with torch.no_grad():
                fake = generate_from_noise(model, x.size(0), clamp=clamp)
                z = model.forward(fake, None, return_details=True)[0]
            print("Z max min")
            print(z.max().item(), z.min().item())
            if (fake != fake).float().sum() > 0:
                title = 'NaNs'
            else:
                title = "Good"
            grid = make_grid((postprocess(fake.detach().cpu(), dataset)[:30]),
                             nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.title(title)
            plt.savefig(
                os.path.join(output_dir, f'sample_{engine.iter_ind}.png'))

        if engine.iter_ind % eval_every == 0:

            def check_all_zero_except_leading(x):
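                # True when only the leading digit of x is non-zero
                # (1..9, 10, 20, ..., 100, 200, ...), i.e. checkpoints are kept
                # on a roughly log-spaced schedule of iterations.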
                return x % 10**np.floor(np.log10(x)) == 0

            if engine.iter_ind == 0 or check_all_zero_except_leading(
                    engine.iter_ind):
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f'ckpt_sd_{engine.iter_ind}.pt'))

            model.eval()

            with torch.no_grad():
                # Plot recon
                fpath = os.path.join(output_dir, '_recon',
                                     f'recon_{engine.iter_ind}.png')
                sample_pad = run_recon_evolution(
                    model,
                    generate_from_noise(model, args.batch_size,
                                        clamp=clamp).detach(), fpath)
                print(
                    f"Iter: {engine.iter_ind}, Recon Sample PAD: {sample_pad}")

                pad = run_recon_evolution(model, x_for_recon, fpath)
                print(f"Iter: {engine.iter_ind}, Recon PAD: {pad}")
                pad = pad.item()
                sample_pad = sample_pad.item()

                # Inception score
                sample = torch.cat([
                    generate_from_noise(model, args.batch_size, clamp=clamp)
                    for _ in range(N_inception // args.batch_size + 1)
                ], 0)[:N_inception]
                sample = sample + .5

                if (sample != sample).float().sum() > 0:
                    print("Sample NaNs")
                    raise RuntimeError("NaNs in generated samples")
                else:
                    fid = run_fid(x_real_inception.clamp_(0, 1),
                                  sample.clamp_(0, 1))
                    print(f'fid: {fid}, global_iter: {engine.iter_ind}')

                # Eval BPD
                eval_bpd = np.mean([
                    model.forward(x.to(device), None,
                                  return_details=True)[1].mean().item()
                    for x, _ in test_loader
                ])

                stats_dict = {
                    'global_iteration': engine.iter_ind,
                    'fid': fid,
                    'train_bpd': train_bpd,
                    'pad': pad,
                    'eval_bpd': eval_bpd,
                    'sample_pad': sample_pad,
                    'batch_real_acc': real_acc,
                    'batch_fake_acc': fake_acc,
                    'batch_acc': acc
                }
                iteration_logger.writerow(stats_dict)
                plot_csv(iteration_logger.filename)
            model.train()

        if (engine.iter_ind + 2) % svd_every == 0:
            model.eval()
            svd_dict = {}
            ret = utils.computeSVDjacobian(x_for_recon, model)
            D_for, D_inv = ret['D_for'], ret['D_inv']
            cn = float(D_for.max() / D_for.min())
            cn_inv = float(D_inv.max() / D_inv.min())
            svd_dict['global_iteration'] = engine.iter_ind
            svd_dict['condition_num'] = cn
            svd_dict['max_sv'] = float(D_for.max())
            svd_dict['min_sv'] = float(D_for.min())
            svd_dict['inverse_condition_num'] = cn_inv
            svd_dict['inverse_max_sv'] = float(D_inv.max())
            svd_dict['inverse_min_sv'] = float(D_inv.min())
            svd_logger.writerow(svd_dict)
            # plot_utils.plot_stability_stats(output_dir)
            # plot_utils.plot_individual_figures(output_dir, 'svd_log.csv')
            model.train()
            if eval_only:
                sys.exit()

        # Dummy
        losses['total_loss'] = torch.mean(nll).item()
        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    trainer = Engine(gan_step)
    # else:
    #     trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         'glow',
                                         save_interval=5,
                                         n_saved=1,
                                         require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x:
         (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(
             evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(
            trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x:
             (x['nll'], torch.empty(x['nll'].shape[0]))).attach(
                 evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        print("Loading...")
        print(saved_model)
        loaded = torch.load(saved_model)
        # if 'Glow' in str(type(loaded)):
        #     model  = loaded
        # else:
        #     raise
        # # if 'Glow' in str(type(loaded)):
        # #     loaded  = loaded.state_dict()
        model.load_state_dict(loaded)
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        if saved_model:
            return
        model.train()
        print("Initializing Actnorm...")
        init_batches = []
        init_targets = []

        if n_init_batches == 0:
            model.set_actnorm_init()
            return
        with torch.no_grad():
            if init_sample:
                generate_from_noise(model,
                                    args.batch_size * args.n_init_batches)
            else:
                for batch, target in islice(train_loader, None,
                                            n_init_batches):
                    init_batches.append(batch)
                    init_targets.append(target)

                init_batches = torch.cat(init_batches).to(device)

                assert init_batches.shape[0] == n_init_batches * batch_size

                if y_condition:
                    init_targets = torch.cat(init_targets).to(device)
                else:
                    init_targets = None

                model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)
        if not no_warm_up:
            scheduler.step()
        metrics = evaluator.state.metrics

        losses = ', '.join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]'
        )
        timer.reset()

    trainer.run(train_loader, epochs)
Example #4
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         fresh):

    device = 'cpu' if (not torch.cuda.is_available() or not cuda) else 'cuda:0'

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=eval_batch_size,
                                  shuffle=False,
                                  num_workers=n_workers,
                                  drop_last=False)

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition)

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses['total_loss'].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         'glow',
                                         save_interval=1,
                                         n_saved=2,
                                         require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x:
         (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(
             evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(
            trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x:
             (x['nll'], torch.empty(x['nll'].shape[0]))).attach(
                 evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()
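        # Data-dependent ActNorm initialization: run n_init_batches of real data
        # through the model once (without gradients) so the ActNorm layers can
        # set their scale and bias from activation statistics before training.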

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)
        metrics = evaluator.state.metrics

        losses = ', '.join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]'
        )
        timer.reset()

    trainer.run(train_loader, epochs)
Example #5
                    for (
                            lpv,
                            ldv,
                            lptv,
                            ldtv,
                    ) in zip(log_p_val, logdet_val, log_p_train_val,
                             logdet_train_val):
                        print(
                            args.delta,
                            lpv.item(),
                            ldv.item(),
                            lptv.item(),
                            ldtv.item(),
                            file=f_ll,
                        )
                f_ll.close()
    f_train_loss.close()
    f_test_loss.close()


if __name__ == "__main__":
    args = parser.parse_args()
    print(string_args(args))
    device = args.device
    model = Glow(
        args.n_channels,
        args.n_flow,
        args.n_block,
        affine=args.affine,
        conv_lu=not args.no_lu,
    )
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    train(args, model, optimizer)
Example #6
def main(
    dataset,
    dataset2,
    dataroot,
    download,
    augment,
    batch_size,
    eval_batch_size,
    nlls_batch_size,
    epochs,
    nb_step,
    saved_model,
    seed,
    hidden_channels,
    K,
    L,
    actnorm_scale,
    flow_permutation,
    flow_coupling,
    LU_decomposed,
    learn_top,
    y_condition,
    y_weight,
    max_grad_clip,
    max_grad_norm,
    lr,
    lr_test,
    n_workers,
    cuda,
    n_init_batches,
    output_dir,
    saved_optimizer,
    warmup,
    every_epoch,
):

    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    ds2 = check_dataset(dataset2, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds
    image_shape2, num_classes2, train_dataset_2, test_dataset_2 = ds2

    assert image_shape == image_shape2
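    # data1/data2 collect the first nlls_batch_size test images from each of the
    # two datasets; they are intended for the (currently commented-out)
    # eval_likelihood hook near the end of this function, which compares
    # per-sample likelihoods across the two datasets.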
    data1 = []
    data2 = []
    for k in range(nlls_batch_size):
        dataaux, targetaux = test_dataset[k]
        data1.append(dataaux)
        dataaux, targetaux = test_dataset_2[k]
        data2.append(dataaux)


    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
        drop_last=True,
    )
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=n_workers,
        drop_last=False,
    )

    model = Glow(
        image_shape,
        hidden_channels,
        K,
        L,
        actnorm_scale,
        flow_permutation,
        flow_coupling,
        LU_decomposed,
        num_classes,
        learn_top,
        y_condition,
    )

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses["total_loss"].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(
                    nll, y_logits, y_weight, y, multi_class, reduction="none"
                )
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction="none")

        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(
        output_dir, "glow", n_saved=2, require_empty=False
    )

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {"model": model, "optimizer": optimizer},
    )

    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(
        trainer, "total_loss"
    )

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(
        lambda x, y: torch.mean(x),
        output_transform=lambda x: (
            x["total_loss"],
            torch.empty(x["total_loss"].shape[0]),
        ),
    ).attach(evaluator, "total_loss")

    if y_condition:
        monitoring_metrics.extend(["nll"])
        RunningAverage(output_transform=lambda x: x["nll"]).attach(trainer, "nll")

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(
            lambda x, y: torch.mean(x),
            output_transform=lambda x: (x["nll"], torch.empty(x["nll"].shape[0])),
        ).attach(evaluator, "nll")

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model)['model'])
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer)['opt'])

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split("_")[-1])/1e3

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            print(train_loader)
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)

        scheduler.step()
        metrics = evaluator.state.metrics

        losses = ", ".join([f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f"Validation Results - Epoch: {engine.state.epoch} {losses}")

    timer = Timer(average=True)
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]"
        )
        timer.reset()

    # @trainer.on(Events.EPOCH_COMPLETED)
    # def eval_likelihood(engine):
    #     global_nlls(output_dir, engine.state.epoch, data1, data2, model, dataset1_name = dataset, dataset2_name = dataset2, nb_step = nb_step, every_epoch = every_epoch, optim_default = partial(optim.SGD, lr=1e-5, momentum = 0.))


    trainer.run(train_loader, epochs)
Example #7
def main(
    dataset,
    dataroot,
    download,
    augment,
    batch_size,
    eval_batch_size,
    epochs,
    saved_model,
    seed,
    hidden_channels,
    K,
    L,
    actnorm_scale,
    flow_permutation,
    flow_coupling,
    LU_decomposed,
    learn_top,
    y_condition,
    y_weight,
    max_grad_clip,
    max_grad_norm,
    lr,
    n_workers,
    cuda,
    n_init_batches,
    output_dir,
    saved_optimizer,
    warmup,
    classifier_weight
):

    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"
    wandb.init(project=args.dataset)

    check_manual_seed(seed)

    image_shape = (64,64,3)
    # if args.dataset == "task1": num_classes = 24
    # else : num_classes = 40

    num_classes = 40
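    # 40 presumably corresponds to the 40 binary CelebA attribute labels used as
    # the conditioning vector for this task.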

    # Note: unsupported for now
    multi_class = True  # it's True, but this variable isn't currently used


    # if args.dataset == "task1":
    #     dataset_train = CLEVRDataset(root_folder=args.dataroot,img_folder=args.dataroot+'images/')
    #     train_loader = DataLoader(dataset_train,batch_size=args.batch_size,shuffle=True,drop_last=True)
    # else :
    #     dataset_train = CelebALoader(root_folder=args.dataroot) #'/home/arg/courses/machine_learning/homework/deep_learning_and_practice/Lab7/dataset/task_2/'
    #     train_loader = DataLoader(dataset_train,batch_size=args.batch_size,shuffle=True,drop_last=True)

    dataset_train = CelebALoader(root_folder=args.dataroot) #'/home/arg/courses/machine_learning/homework/deep_learning_and_practice/Lab7/dataset/task_2/'
    train_loader = DataLoader(dataset_train,batch_size=args.batch_size,shuffle=True,drop_last=True)    


    model = Glow(
        image_shape,
        hidden_channels,
        K,
        L,
        actnorm_scale,
        flow_permutation,
        flow_coupling,
        LU_decomposed,
        num_classes,
        learn_top,
        y_condition,
    )

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    wandb.watch(model)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)
        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            ### x: torch.Size([batchsize, 3, 64, 64]); y: torch.Size([batchsize, 24]); z: torch.Size([batchsize, 48, 8, 8])
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses["total_loss"].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()


        return losses


    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(
        output_dir, "glow", n_saved=None, require_empty=False
    )
    ### n_saved (Optional[int]) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept.

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {"model": model, "optimizer": optimizer},
    )

    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(
        trainer, "total_loss"
    )


    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)


    if saved_model:
        model.load_state_dict(torch.load(saved_model, map_location="cpu")['model'])
        model.set_actnorm_init()

    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)




    # evaluator = evaluation_model(args.classifier_weight)
    # @trainer.on(Events.EPOCH_COMPLETED)
    # def evaluate(engine):
    #     if args.dataset == "task1":
    #         model.eval()
    #         with torch.no_grad():
    #             test_conditions = get_test_conditions(args.dataroot).cuda()
    #             predict_x = postprocess(model(y_onehot=test_conditions, temperature=1, reverse=True)).float()
    #             score = evaluator.eval(predict_x, test_conditions)
    #             save_image(predict_x.float(), args.output_dir+f"/Epoch{engine.state.epoch}_score{score:.3f}.png", normalize=True)

    #             test_conditions = get_new_test_conditions(args.dataroot).cuda()
    #             predict_x = postprocess(model(y_onehot=test_conditions, temperature=1, reverse=True)).float()
    #             newscore = evaluator.eval(predict_x.float(), test_conditions)
    #             save_image(predict_x.float(), args.output_dir+f"/Epoch{engine.state.epoch}_newscore{newscore:.3f}.png", normalize=True)

    #             print(f"Iter: {engine.state.iteration}  score:{score:.3f} newscore:{newscore:.3f} ")
    #             wandb.log({"score": score, "new_score": newscore})




    trainer.run(train_loader, epochs)
Example #8
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True)
image_size = train_dataset[0][0].size()
print('size of train data: %d' % len(train_dataset))
print('size of test data: %d' % len(test_dataset))
print('image size: %s' % str(image_size))

# Model
print('==> Model')
model = Glow(image_size,
             args.channels_h,
             args.K,
             args.L,
             save_memory=args.save_memory).to(device)
#print(model)
optimizer = torch.optim.Adam(model.parameters(),
                             lr=args.lr,
                             weight_decay=args.weight_decay)


def train(epoch):
    # warmup
    lr = min(args.lr * epoch / 10, args.lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
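    # Linear warmup: the learning rate is scaled by epoch / 10 for the first 10
    # epochs and then held constant at args.lr.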

    model.train()
    sum_loss = 0
    count = 0
    for iteration, batch in enumerate(train_loader, 1):
        batch = batch[0].to(device)
Example #9
n_sample = 4
temp = 0.7
n_bits = 5
n_bins = 2.**n_bits
img_channels = 1

model = Glow(img_channels, n_flow, n_block, affine=affine)

z_sample = []
z_shapes = calc_z_shapes(img_channels, img_size, n_flow, n_block)
for z in z_shapes:
    z_new = torch.randn(n_sample, *z) * temp
    z_sample.append(z_new.to('cuda'))
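# z_sample holds one fixed latent tensor per Glow block, drawn at temperature
# 0.7; in this style of training script it would typically be decoded with
# something like model.reverse(z_sample) to visualize samples as training
# progresses (that call is not shown in this snippet).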

model.to('cuda')
optimizer = Adam(model.parameters(), lr=1e-4)

plot = False
i = 0
total_loss = []

for i in range(100):
    for image, _ in tqdm(train_loader):
        optimizer.zero_grad()
        image = image.to('cuda')
        log_p, logdet, out = model(image + torch.rand_like(image) / n_bins)
        loss, log_p, log_det = calc_loss(log_p, logdet, img_size, img_channels,
                                         n_bins)
        loss.backward()
        optimizer.step()
        writer.add_scalar('loss', loss.cpu().item(), i)
Example #10
def main(
    dataset,
    augment,
    batch_size,
    eval_batch_size,
    epochs,
    saved_model,
    seed,
    hidden_channels,
    K,
    L,
    actnorm_scale,
    flow_permutation,
    flow_coupling,
    LU_decomposed,
    learn_top,
    y_condition,
    extra_condition,
    sp_condition,
    d_condition,
    yd_condition,
    y_weight,
    d_weight,
    max_grad_clip,
    max_grad_norm,
    lr,
    n_workers,
    cuda,
    n_init_batches,
    output_dir,
    missing,
    saved_optimizer,
    warmup,
):

    print(output_dir)
    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"
    print(device)
    check_manual_seed(seed)
    print("augmenting?", augment)
    train_dataset, test_dataset = check_dataset(dataset, augment, missing)
    image_shape = (32, 32, 3)

    multi_class = False

    if yd_condition:
        num_classes = 2
        num_domains = 10
        #num_classes = 10+2
        #multi_class=True
    elif d_condition:
        num_classes = 10
        num_domains = 0
    else:
        num_classes = 2
        num_domains = 0
    #print("num classes", num_classes)

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
        drop_last=True,
    )
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=n_workers,
        drop_last=False,
    )

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 num_domains, learn_top, y_condition, extra_condition,
                 sp_condition, d_condition, yd_condition)

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lr_lambda=lr_lambda)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y, d, yd = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits, spare = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        elif d_condition:
            d = d.to(device)
            z, nll, d_logits, spare = model(x, d)

            losses = compute_loss_y(nll, d_logits, d_weight, d, multi_class)
        elif yd_condition:
            y, d, yd = y.to(device), d.to(device), yd.to(device)
            z, nll, y_logits, d_logits = model(x, yd)
            losses = compute_loss_yd(nll, y_logits, y_weight, y, d_logits,
                                     d_weight, d)
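            # joint conditioning: the model is conditioned on the combined
            # class/domain label yd, and compute_loss_yd presumably adds both a
            # class term (weighted by y_weight) and a domain term (weighted by
            # d_weight) on top of the NLL.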
        else:
            print("none")
            z, nll, y_logits, spare = model(x, None)
            losses = compute_loss(nll)

        losses["total_loss"].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y, d, yd = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits, none_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction="none")
            elif d_condition:
                d = d.to(device)
                z, nll, d_logits, non_logits = model(x, d)
                losses = compute_loss_y(nll,
                                        d_logits,
                                        d_weight,
                                        d,
                                        multi_class,
                                        reduction="none")
            elif yd_condition:
                y, d, yd = y.to(device), d.to(device), yd.to(device)
                z, nll, y_logits, d_logits = model(x, yd)
                losses = compute_loss_yd(nll,
                                         y_logits,
                                         y_weight,
                                         y,
                                         d_logits,
                                         d_weight,
                                         d,
                                         reduction="none")
            else:

                z, nll, y_logits, d_logits = model(x, None)
                losses = compute_loss(nll, reduction="none")
        #print(losses, "losssssess")
        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         "glow",
                                         save_interval=1,
                                         n_saved=2,
                                         require_empty=False)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {
            "model": model,
            "optimizer": optimizer
        },
    )

    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(
        trainer, "total_loss")

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(
        lambda x, y: torch.mean(x),
        output_transform=lambda x: (
            x["total_loss"],
            torch.empty(x["total_loss"].shape[0]),
        ),
    ).attach(evaluator, "total_loss")

    if y_condition or d_condition or yd_condition:
        monitoring_metrics.extend(["nll"])
        RunningAverage(output_transform=lambda x: x["nll"]).attach(
            trainer, "nll")

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(
            lambda x, y: torch.mean(x),
            output_transform=lambda x:
            (x["nll"], torch.empty(x["nll"].shape[0])),
        ).attach(evaluator, "nll")

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split("_")[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()

        init_batches = []
        init_targets = []
        init_domains = []
        init_yds = []

        with torch.no_grad():
            for batch, target, domain, yd in islice(train_loader, None,
                                                    n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)
                init_domains.append(domain)
                init_yds.append(yd)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
                model(init_batches, init_targets)
            elif d_condition:
                init_domains = torch.cat(init_domains).to(device)
                model(init_batches, init_domains)
            elif yd_condition:
                init_yds = torch.cat(init_yds).to(device)
                model(init_batches, init_yds)
            else:
                init_targets = None
                model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)
        #print("done")
        scheduler.step()
        metrics = evaluator.state.metrics

        losses = ", ".join(
            [f"{key}: {value:.8f}" for key, value in metrics.items()])

        print(f"Validation Results - Epoch: {engine.state.epoch} {losses}")

    def score_function(engine):
        val_loss = engine.state.metrics['total_loss']

        return -val_loss
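    # ModelCheckpoint keeps the checkpoints with the highest score, so returning
    # the negative validation loss makes it save the lowest-loss model.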

    name = "best_"

    val_handler = ModelCheckpoint(output_dir,
                                  name,
                                  score_function=score_function,
                                  score_name="val_loss",
                                  n_saved=1,
                                  require_empty=False)

    evaluator.add_event_handler(
        Events.EPOCH_COMPLETED,
        val_handler,
        {"model": model},
    )

    timer = Timer(average=True)
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]"
        )
        timer.reset()

    trainer.run(train_loader, epochs)
Example #11
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         warmup, fresh, logittransform, gan, disc_lr):

    device = 'cpu' if (not torch.cuda.is_available() or not cuda) else 'cuda:0'

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=eval_batch_size,
                                  shuffle=False,
                                  num_workers=n_workers,
                                  drop_last=False)

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition, logittransform)

    model = model.to(device)

    if gan:
        # Debug: swap the Glow model for a small generator from the `mine` module
        model = mine.Generator(32, 1).to(device)

        optimizer = optim.Adam(model.parameters(),
                               lr=lr,
                               betas=(.5, .99),
                               weight_decay=0)
        discriminator = mine.Discriminator(image_shape[-1])
        discriminator = discriminator.to(device)
        D_optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                        discriminator.parameters()),
                                 lr=disc_lr,
                                 betas=(.5, .99),
                                 weight_decay=0)
    else:
        optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    # lr_lambda = lambda epoch: lr * min(1., (epoch + 1) / warmup)
    # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    i = 0

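    # Maximum-likelihood training step: forward pass to the NLL (optionally
    # class-conditional), backpropagate, clip gradients if configured, update.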
    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses['total_loss'].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

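    # Adversarial training step: first update the discriminator on real vs.
    # generated samples (plus a gradient penalty), then update the generator
    # to fool the updated discriminator.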
    def gan_step(engine, batch):
        assert not y_condition
        # Track a per-engine iteration counter (used below for periodic sample
        # dumps); it starts at -1 on the first call and increments afterwards.
        if hasattr(engine, 'iter_ind'):
            engine.iter_ind += 1
        else:
            engine.iter_ind = -1
        losses = {}
        model.train()
        discriminator.train()

        x, y = batch
        x = x.to(device)

        # def generate_from_noise(batch_size):
        #     _, c2, h, w  = model.prior_h.shape
        #     c = c2 // 2
        #     zshape = (batch_size, c, h, w)
        #     randz  = torch.autograd.Variable(torch.randn(zshape), requires_grad=True).to(device)
        #     images = model(z= randz, y_onehot=None, temperature=1, reverse=True,batch_size=batch_size)
        #     return images

        def generate_from_noise(batch_size):

            zshape = (batch_size, 32, 1, 1)
            randz = torch.randn(zshape).to(device)
            images = model(randz)
            return images / 2

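        # Apply uniform binning (dequantization) noise to the input before it
        # is scored by the discriminator.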
        def run_noised_disc(discriminator, x):
            x = uniform_binning_correction(x)[0]
            return discriminator(x)

        # Train Disc
        fake = generate_from_noise(x.size(0))

        D_real_scores = run_noised_disc(discriminator, x.detach())
        D_fake_scores = run_noised_disc(discriminator, fake.detach())

        ones_target = torch.ones((x.size(0), 1), device=x.device)
        zeros_target = torch.zeros((x.size(0), 1), device=x.device)

        # D_real_accuracy = torch.sum(torch.round(F.sigmoid(D_real_scores)) == ones_target).float() / ones_target.size(0)
        # D_fake_accuracy = torch.sum(torch.round(F.sigmoid(D_fake_scores)) == zeros_target).float() / zeros_target.size(0)

        D_real_loss = F.binary_cross_entropy_with_logits(
            D_real_scores, ones_target)
        D_fake_loss = F.binary_cross_entropy_with_logits(
            D_fake_scores, zeros_target)

        D_loss = (D_real_loss + D_fake_loss) / 2
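        # Regularize the discriminator with a gradient penalty on top of the
        # standard BCE loss (weighted by 10 below).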
        gp = gradient_penalty(x.detach(), fake.detach(),
                              lambda _x: run_noised_disc(discriminator, _x))
        D_loss_plus_gp = D_loss + 10 * gp
        D_optimizer.zero_grad()
        D_loss_plus_gp.backward()
        D_optimizer.step()

        # Train generator
        fake = generate_from_noise(x.size(0))
        G_loss = F.binary_cross_entropy_with_logits(
            run_noised_disc(discriminator, fake),
            torch.ones((x.size(0), 1), device=x.device))
        losses['total_loss'] = G_loss

        # G-step
        optimizer.zero_grad()
        losses['total_loss'].backward()
        params = list(model.parameters())
        gnorm = [p.grad.norm() for p in params]
        optimizer.step()
        # if max_grad_clip > 0:
        #     torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        # if max_grad_norm > 0:
        #     torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        if engine.iter_ind % 50 == 0:
            grid = make_grid((postprocess(fake.detach().cpu())[:30]),
                             nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.savefig(
                os.path.join(output_dir, f'sample_{engine.iter_ind}.png'))

            grid = make_grid(
                (postprocess(uniform_binning_correction(x)[0].cpu())[:30]),
                nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.savefig(os.path.join(output_dir,
                                     f'data_{engine.iter_ind}.png'))

        return losses

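    # Evaluation step: compute per-sample losses (reduction='none') under
    # torch.no_grad(); the attached Loss metric averages them over the test set.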
    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    if gan:
        trainer = Engine(gan_step)
    else:
        trainer = Engine(step)
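    # Checkpoint the model and optimizer after every epoch under the 'glow'
    # prefix, keeping the two most recent checkpoints.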
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         'glow',
                                         save_interval=1,
                                         n_saved=2,
                                         require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')

    evaluator = Engine(eval_step)

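    # ignite's Loss metric expects (prediction, target) pairs; the dummy
    # torch.empty tensor only carries the batch size used for averaging.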
    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x:
         (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(
             evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(
            trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x:
             (x['nll'], torch.empty(x['nll'].shape[0]))).attach(
                 evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    # @trainer.on(Events.STARTED)
    # def init(engine):
    #     model.train()

    #     init_batches = []
    #     init_targets = []

    #     with torch.no_grad():
    #         for batch, target in islice(train_loader, None,
    #                                     n_init_batches):
    #             init_batches.append(batch)
    #             init_targets.append(target)

    #         init_batches = torch.cat(init_batches).to(device)

    #         assert init_batches.shape[0] == n_init_batches * batch_size

    #         if y_condition:
    #             init_targets = torch.cat(init_targets).to(device)
    #         else:
    #             init_targets = None

    #         model(init_batches, init_targets)

    # @trainer.on(Events.EPOCH_COMPLETED)
    # def evaluate(engine):
    #     evaluator.run(test_loader)

    #     # scheduler.step()
    #     metrics = evaluator.state.metrics

    #     losses = ', '.join([f"{key}: {value:.2f}" for key, value in metrics.items()])

    #     myprint(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]'
        )
        timer.reset()

    trainer.run(train_loader, epochs)
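
A minimal sketch of how this main() might be invoked; every value below is an illustrative assumption rather than a default taken from the original repository.

if __name__ == "__main__":
    # Hypothetical invocation of Example #11's main(); all values are assumed
    # placeholders for a CIFAR-10-style run, not documented defaults.
    main(dataset="cifar10", dataroot="./data", download=True, augment=True,
         batch_size=64, eval_batch_size=256, epochs=100, saved_model="",
         seed=0, hidden_channels=512, K=32, L=3, actnorm_scale=1.0,
         flow_permutation="invconv", flow_coupling="affine",
         LU_decomposed=True, learn_top=True, y_condition=False, y_weight=0.01,
         max_grad_clip=0, max_grad_norm=0, lr=5e-4, n_workers=4, cuda=True,
         n_init_batches=8, warmup_steps=4000, output_dir="output/",
         saved_optimizer="", warmup=5, fresh=True, logittransform=False,
         gan=False, disc_lr=1e-4)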