Пример #1
0
def main():
    """Sample images from a trained Glow model and display them forever.

    Relies on module-level ``args`` (``gpu_device``, ``snapshot_path``,
    ``temperature``) and module-level imports (``np``, ``cupy``, ``cuda``,
    ``chainer``, ``plt``, ``Glow``, ``Hyperparameters``, ``make_uint8``).
    """
    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    hyperparams = Hyperparameters(args.snapshot_path)
    hyperparams.print()

    # number of discrete pixel bins implied by the quantization bit depth
    num_bins_x = 2.0**hyperparams.num_bits_x

    encoder = Glow(hyperparams, hdf5_path=args.snapshot_path)
    if using_gpu:
        encoder.to_gpu()

    # BUG FIX: the original used ``with A and B as decoder`` which enters
    # only ONE context manager (the truthy result of ``A and B``), so
    # no_backprop_mode() was never actually activated. The tuple form
    # enters both.
    with chainer.no_backprop_mode(), encoder.reverse() as decoder:
        while True:
            # latent sample z ~ N(0, temperature), shape (1, 3, H, W)
            z = xp.random.normal(0,
                                 args.temperature,
                                 size=(
                                     1,
                                     3,
                                 ) + hyperparams.image_size).astype("float32")

            x, _ = decoder.reverse_step(z)
            x_img = make_uint8(x.data[0], num_bins_x)
            plt.imshow(x_img, interpolation="none")
            plt.pause(.01)
Пример #2
0
def get_model(path, using_gpu):
    """Load a trained Glow encoder and its hyperparameters from *path*.

    Returns a ``(encoder, num_bins_x, hyperparams)`` tuple; the encoder is
    moved to the GPU when *using_gpu* is truthy.
    """
    print(path)
    hyperparams = Hyperparameters(path)
    hyperparams.print()

    encoder = Glow(hyperparams, hdf5_path=path)
    if using_gpu:
        encoder.to_gpu()

    # number of discrete pixel bins implied by the quantization bit depth
    num_bins_x = float(2**hyperparams.num_bits_x)

    return encoder, num_bins_x, hyperparams
Пример #3
0
    def forward_and_reverse_output_shape(self,
                                         in_channel,
                                         data,
                                         levels=3,
                                         depth=4):
        """Check Glow's forward shapes, parameter count, and invertibility.

        Builds ``Glow(in_channel, levels, depth)``, runs *data* through it,
        asserts the shapes of ``z``/``logdet``/``eps``, counts trainable
        parameters, and finally checks that ``reverse`` restores the input
        shape. Assumes *data* is NCHW with batch size 4 and 32x32 spatial
        size (CIFAR-like) — the asserts hard-code those values.
        """
        glow = Glow(in_channel, levels, depth)
        z, logdet, eps = glow(data)
        height, width = data.shape[2], data.shape[3]
        """
            cifar example:
            Level = 3
            initial shape -> [4, 3, 32, 32]
            iter 1 -> z: [4, 12, 16, 16] because of squeeze from outside the loop
            iter 2 -> z: [4, 24, 8, 8] because of squeeze + split
            iter 3 -> z: [4, 48, 4, 4] because of squeeze + split
        """
        # final z: one squeeze (x4 channels) plus a channel doubling per
        # remaining level; spatial size halves each level down to 4x4
        assert list(z.shape) == [4, in_channel * 4 * 2**(levels - 1), 4, 4]
        assert list(logdet.shape) == [4]  # because batch_size = 4
        assert len(
            eps
        ) == levels - 1  # because L = 3 and split is executed whenever < L, i.e 2 times in total

        # each split's eps halves the channel count of that level's z;
        # height/factor is a float but compares equal to the int dim
        factor = 1
        for e in eps:
            factor *= 2
            # example: first eps -> from iter 1 take z shape and divide channel by 2: [4, 12/2, 16, 16]
            assert list(e.shape) == [
                4, in_channel * factor, height / factor, width / factor
            ]
        """
            In total depth * levels = 4 * 3 = 12, so we got 12 instances of actnorm, inconv and affinecoupling
            Actnorm = 2 trainable parameters
            Invconv = 3 trainable parameter
            Affinecoupling = 6 trainable parameters (got 3 conv layers, each layer has weight + bias, so for all layers combined we get 6 in total)
            Zeroconv = 4 (2 conv layers, each with weight + bias)
            
            12 * (2+3+6) + 4= 136
        """
        assert len(list(
            glow.parameters())) == (levels * depth) * (2 + 3 + 6) + 4
        for param in glow.parameters():
            assert param.requires_grad

        # reverse
        # For cifar we expect z with level=3 to be of shape [4,48,4,4]
        z = glow.reverse(z, eps)

        # round trip must restore the original (CIFAR-sized) input shape
        assert list(z.shape) == [4, 3, 32, 32]
Пример #4
0
 def test_generate_sample(self):
     """Smoke test: build a small Glow on CPU and sample a batch from it."""
     cifar_channels = 3
     num_levels = 4
     flow_depth = 8
     glow = Glow(cifar_channels, num_levels, flow_depth)
     reference = torch.randn((2, 3, 32, 32))
     sample = generate(glow, 2, 'cpu', reference.shape, num_levels)
     print(sample.shape)
Пример #5
0
def main():
    """Sample Glow images at several temperatures and display them forever.

    Relies on module-level ``args`` (``gpu_device``, ``snapshot_path``) and
    module-level imports (``np``, ``cupy``, ``cuda``, ``chainer``, ``plt``,
    ``Glow``, ``Hyperparameters``, ``make_uint8``).
    """
    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    hyperparams = Hyperparameters(args.snapshot_path)
    hyperparams.print()

    # number of discrete pixel bins implied by the quantization bit depth
    num_bins_x = 2.0**hyperparams.num_bits_x

    encoder = Glow(hyperparams, hdf5_path=args.snapshot_path)
    if using_gpu:
        encoder.to_gpu()

    # one subplot per sampling temperature
    temperatures = [0.0, 0.25, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    total = len(temperatures)
    fig = plt.figure(figsize=(total * 4, 4))
    subplots = []
    for n in range(total):
        subplot = fig.add_subplot(1, total, n + 1)
        subplots.append(subplot)

    # BUG FIX: the original used ``with A and B as decoder`` which enters
    # only ONE context manager (the truthy result of ``A and B``), so
    # no_backprop_mode() was never actually activated. The tuple form
    # enters both.
    with chainer.no_backprop_mode(), encoder.reverse() as decoder:
        while True:
            # sample one latent per temperature on the CPU, then move the
            # whole batch to the GPU at once if needed
            z_batch = []
            for temperature in temperatures:
                z = np.random.normal(0,
                                     temperature,
                                     size=(3, ) +
                                     hyperparams.image_size).astype("float32")
                z_batch.append(z)
            z_batch = np.asanyarray(z_batch)
            if using_gpu:
                z_batch = cuda.to_gpu(z_batch)
            x, _ = decoder.reverse_step(z_batch)
            for n, (temperature,
                    subplot) in enumerate(zip(temperatures, subplots)):
                x_img = make_uint8(x.data[n], num_bins_x)
                subplot.imshow(x_img, interpolation="none")
                subplot.set_title("temperature={}".format(temperature))
            plt.pause(.01)
Пример #6
0
def main(args,kwargs):
    """Compare Glow reconstruction strategies on an MNIST test batch.

    Loads a trained Glow from ``args.output_folder``/``args.model_name``,
    samples a batch, re-decodes the same latents with and without
    ``use_last_split``, prints the L2 errors, and saves comparison plots.
    ``kwargs`` is accepted but unused in this body. Uses the module-level
    ``device``.
    """
    output_folder = args.output_folder
    model_name = args.model_name

    with open(os.path.join(output_folder,'hparams.json')) as json_file:  
        hparams = json.load(json_file)
      
    image_shape, num_classes, _, test_mnist = get_MNIST(False, hparams['dataroot'], hparams['download'])
    test_loader = data.DataLoader(test_mnist, batch_size=32,
                                      shuffle=False, num_workers=6,
                                      drop_last=False)
    # grab one fixed batch from the (unshuffled) test set
    x, y = test_loader.__iter__().__next__()
    x = x.to(device)

    # ``logittransform`` / ``sn`` default to False for older hparams files
    model = Glow(image_shape, hparams['hidden_channels'], hparams['K'], hparams['L'], hparams['actnorm_scale'],
                 hparams['flow_permutation'], hparams['flow_coupling'], hparams['LU_decomposed'], num_classes,
                 hparams['learn_top'], hparams['y_condition'], False if 'logittransform' not in hparams else hparams['logittransform'],False if 'sn' not in hparams else hparams['sn'])

    model.load_state_dict(torch.load(os.path.join(output_folder, model_name)))
    model.set_actnorm_init()

    model = model.to(device)
    model = model.eval()



    with torch.no_grad():
        # sample a batch, then re-decode the SAME latents (model._last_z)
        # with and without reusing the last split tensors; call order
        # matters because _last_z is state left by the previous call
        images = model(y_onehot=None, temperature=1, batch_size=32, reverse=True).cpu()  
        better_dup_images = model(y_onehot=None, temperature=1, z=model._last_z, reverse=True, use_last_split=True).cpu()   
        dup_images = model(y_onehot=None, temperature=1, z=model._last_z, reverse=True).cpu()   
        worse_dup_images = model(y_onehot=None, temperature=1, z=model._last_z, reverse=True).cpu()   

    # per-variant squared reconstruction error against the sampled images
    l2_err =  torch.pow((images - dup_images).view(images.shape[0], -1), 2).sum(-1).mean()
    better_l2_err =  torch.pow((images - better_dup_images).view(images.shape[0], -1), 2).sum(-1).mean()
    worse_l2_err =  torch.pow((images - worse_dup_images).view(images.shape[0], -1), 2).sum(-1).mean()
    print(l2_err, better_l2_err, worse_l2_err)
    plot_imgs([images, dup_images, better_dup_images, worse_dup_images], '_recons')

    # now encode real data and reconstruct from its latents
    with torch.no_grad():
        z, nll, y_logits = model(x, None)
        better_dup_images = model(y_onehot=None, temperature=1, z=z, reverse=True, use_last_split=True).cpu()   

    plot_imgs([x, better_dup_images], '_data_recons2')

    fpath = os.path.join(output_folder, '_recon_evoluation.png')
    pad = run_recon_evolution(model, x, fpath)
Пример #7
0
def main(args):
    """Train a Glow model on the selected dataset, optionally resuming.

    Seeds all RNGs from ``args.seed``, builds the model/optimizer, and runs
    ``train``/``test`` epochs. Tracks the best test loss in the module-level
    ``best_loss`` global.
    """
    # we're probably only be using 1 GPU, so this should be fine
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"running on {device}")
    # set random seed for all RNG sources we touch
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    global best_loss
    if args.generate_samples:
        print("generating samples")
    # load data
    # example for CIFAR-10 training:
    train_set, test_set = get_dataloader(args.dataset, args.batch_size)

    input_channels = channels_from_dataset(args.dataset)
    print(f"amount of  input channels: {input_channels}")
    # instantiate model
    net = Glow(in_channels=input_channels,
               depth=args.amt_flow_steps, levels=args.amt_levels, use_normalization=args.norm_method)

    # code for rosalinty model
    # net = RosGlow(input_channels, args.amt_flow_steps, args.amt_levels)

    net = net.to(device)

    print(f"training for {args.num_epochs} epochs.")

    start_epoch = 0
    if args.resume:
        print(f"resuming from checkpoint found in checkpoints/best_{args.dataset.lower()}.pth.tar.")
        # raise error if no checkpoint directory is found
        assert os.path.isdir("new_checkpoints")
        checkpoint = torch.load(f"new_checkpoints/best_{args.dataset.lower()}.pth.tar")
        net.load_state_dict(checkpoint["model"])
        # NOTE: the redundant second ``global best_loss`` declaration that
        # used to sit here was removed — the one at the top of the function
        # already covers this assignment.
        best_loss = checkpoint["test_loss"]
        start_epoch = checkpoint["epoch"]

    loss_function = FlowNLL().to(device)
    optimizer = optim.Adam(net.parameters(), lr=float(args.lr))
    # scheduler found in code, no mention in paper
    # scheduler = sched.LambdaLR(
    #     optimizer, lambda s: min(1., s / args.warmup_iters))

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        print(f"training epoch {epoch}")
        train(net, train_set, device, optimizer, loss_function, epoch)
        # test every 10 epochs
        if (epoch % 10 == 0):
            print(f"testing epoch {epoch}")
            test(net, test_set, device, loss_function, epoch, args.generate_samples,
                 args.amt_levels, args.dataset, args.n_samples)
Пример #8
0
def main():
    """Continuously reconstruct MNIST digits with a trained Glow model,
    showing each input next to its round-trip reconstruction.

    Relies on module-level ``args`` and imports (``np``, ``cupy``, ``cuda``,
    ``chainer``, ``plt``, ``glow``, ``Glow``, ``Hyperparameters``,
    ``preprocess``, ``to_gpu``, ``make_uint8``, ``tabulate``).
    """
    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    hyperparams = Hyperparameters(args.snapshot_path)
    hyperparams.print()

    # number of discrete pixel bins implied by the quantization bit depth
    num_bins_x = 2.0**hyperparams.num_bits_x
    image_size = (28, 28)

    # load MNIST, scale to [0, 255], and broadcast to the channel count the
    # model was trained with
    images = chainer.datasets.mnist.get_mnist(withlabel=False)[0]
    images = 255.0 * np.asarray(images).reshape((-1, ) + image_size + (1, ))
    if hyperparams.num_image_channels != 1:
        images = np.broadcast_to(images, (images.shape[0], ) + image_size +
                                 (hyperparams.num_image_channels, ))
    images = preprocess(images, hyperparams.num_bits_x)

    dataset = glow.dataset.Dataset(images)
    iterator = glow.dataset.Iterator(dataset, batch_size=1)

    print(tabulate([["#image", len(dataset)]]))

    encoder = Glow(hyperparams, hdf5_path=args.snapshot_path)
    if using_gpu:
        encoder.to_gpu()

    fig = plt.figure(figsize=(8, 4))
    left = fig.add_subplot(1, 2, 1)
    right = fig.add_subplot(1, 2, 2)

    # BUG FIX: the original used ``with A and B as decoder`` which enters
    # only ONE context manager (the truthy result of ``A and B``), so
    # no_backprop_mode() was never actually activated. The tuple form
    # enters both.
    with chainer.no_backprop_mode(), encoder.reverse() as decoder:
        while True:
            for data_indices in iterator:
                x = to_gpu(dataset[data_indices])
                # dequantization noise, matching training-time preprocessing
                x += xp.random.uniform(0, 1.0 / num_bins_x, size=x.shape)
                factorized_z_distribution, _ = encoder.forward_step(x)

                # keep only the latents; drop the per-level mean/ln_var
                factorized_z = []
                for (zi, mean, ln_var) in factorized_z_distribution:
                    factorized_z.append(zi)

                rev_x, _ = decoder.reverse_step(factorized_z)

                x_img = make_uint8(x[0], num_bins_x)
                rev_x_img = make_uint8(rev_x.data[0], num_bins_x)

                left.imshow(x_img, interpolation="none")
                right.imshow(rev_x_img, interpolation="none")

                plt.pause(.01)
Пример #9
0
def main():
    """Encode a directory of images with a trained Glow model and print,
    for every factorized level, the mean of the predicted means and the
    mean predicted variance.

    Relies on module-level ``args`` and imports (``np``, ``cupy``, ``cuda``,
    ``chainer``, ``glow``, ``Glow``, ``Hyperparameters``, ``preprocess``,
    ``to_gpu``, ``Path``, ``Image``, ``tabulate``).
    """
    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    hyperparams = Hyperparameters(args.snapshot_path)
    hyperparams.print()

    # number of discrete pixel bins implied by the quantization bit depth
    num_bins_x = 2.0**hyperparams.num_bits_x

    assert args.dataset_format in ["png", "npy"]

    files = Path(args.dataset_path).glob("*.{}".format(args.dataset_format))
    if args.dataset_format == "png":
        images = []
        for filepath in files:
            image = np.array(Image.open(filepath)).astype("float32")
            image = preprocess(image, hyperparams.num_bits_x)
            images.append(image)
        assert len(images) > 0
        images = np.asanyarray(images)
    elif args.dataset_format == "npy":
        images = []
        for filepath in files:
            array = np.load(filepath).astype("float32")
            array = preprocess(array, hyperparams.num_bits_x)
            images.append(array)
        assert len(images) > 0
        num_files = len(images)
        images = np.asanyarray(images)
        # each .npy file contains a stack of images; flatten to one batch axis
        images = images.reshape((num_files * images.shape[1], ) +
                                images.shape[2:])
    else:
        raise NotImplementedError

    dataset = glow.dataset.Dataset(images)
    iterator = glow.dataset.Iterator(dataset, batch_size=1)

    print(tabulate([["#image", len(dataset)]]))

    encoder = Glow(hyperparams, hdf5_path=args.snapshot_path)
    if using_gpu:
        encoder.to_gpu()

    # BUG FIX: the original used ``with A and B as decoder`` which enters
    # only ONE context manager, so no_backprop_mode() never took effect.
    # ``decoder`` itself is unused below; the reverse() context is kept for
    # parity with the original behavior.
    with chainer.no_backprop_mode(), encoder.reverse() as decoder:
        for data_indices in iterator:
            print("data:", data_indices)
            x = to_gpu(dataset[data_indices])
            # dequantization noise, matching training-time preprocessing
            x += xp.random.uniform(0, 1.0 / num_bins_x, size=x.shape)
            factorized_z_distribution, _ = encoder.forward_step(x)

            for (_, mean, ln_var) in factorized_z_distribution:
                print(xp.mean(mean.data), xp.mean(xp.exp(ln_var.data)))
Пример #10
0
def main():
    """Visualize how each factorized latent level of a trained Glow model
    affects the decoded image.

    Column 0 shows a straight decode of a base latent; column 1 re-samples
    every level except the last; columns 1..levels zero out one level at a
    time. Relies on module-level ``args`` and imports (``np``, ``cupy``,
    ``cuda``, ``chainer``, ``plt``, ``Glow``, ``Hyperparameters``,
    ``make_uint8``).
    """
    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    hyperparams = Hyperparameters(args.snapshot_path)
    hyperparams.print()

    # number of discrete pixel bins implied by the quantization bit depth
    num_bins_x = 2.0**hyperparams.num_bits_x

    encoder = Glow(hyperparams, hdf5_path=args.snapshot_path)
    if using_gpu:
        encoder.to_gpu()

    # one subplot for the base sample plus one per level
    total = hyperparams.levels + 1
    fig = plt.figure(figsize=(4 * total, 4))
    subplots = []
    for n in range(total):
        subplot = fig.add_subplot(1, total, n + 1)
        subplots.append(subplot)

    def reverse_step(z, sampling=True):
        # Decode a (possibly already factorized) latent back to image space
        # by running every block's reverse step in reverse order.
        if isinstance(z, list):
            factorized_z = z
        else:
            factorized_z = encoder.factor_z(z)

        assert len(factorized_z) == len(encoder.blocks)

        out = None
        sum_logdet = 0

        for block, zi in zip(encoder.blocks[::-1], factorized_z[::-1]):
            out, logdet = block.reverse_step(
                out,
                gaussian_eps=zi,
                squeeze_factor=encoder.hyperparams.squeeze_factor,
                sampling=sampling)
            sum_logdet += logdet

        return out, sum_logdet

    # BUG FIX: the original used ``with A and B as decoder`` which enters
    # only ONE context manager, so no_backprop_mode() never took effect.
    with chainer.no_backprop_mode(), encoder.reverse() as decoder:
        while True:
            # BUG FIX: numpy's random.normal() has no ``dtype`` keyword
            # (only cupy's does), so the original crashed on the CPU path.
            # Sampling in the default dtype and casting works for both.
            base_z = xp.random.normal(0,
                                      args.temperature,
                                      size=(
                                          1,
                                          3,
                                      ) + hyperparams.image_size).astype(
                                          "float32")
            factorized_z = encoder.factor_z(base_z)

            # column 0: straight decode of the base latent
            rev_x, _ = decoder.reverse_step(factorized_z)
            rev_x_img = make_uint8(rev_x.data[0], num_bins_x)
            subplots[0].imshow(rev_x_img, interpolation="none")

            # column 1: re-sample every level except the last
            z = xp.copy(base_z)
            factorized_z = encoder.factor_z(z)
            for n in range(hyperparams.levels - 1):
                factorized_z[n] = xp.random.normal(
                    0, args.temperature,
                    size=factorized_z[n].shape).astype("float32")
            rev_x, _ = decoder.reverse_step(factorized_z)
            rev_x_img = make_uint8(rev_x.data[0], num_bins_x)
            subplots[1].imshow(rev_x_img, interpolation="none")

            # columns 1..levels: zero out one level at a time, decoding the
            # zeroed level deterministically and sampling the others
            for n in range(hyperparams.levels):
                z = xp.copy(base_z)
                factorized_z = encoder.factor_z(z)
                # NOTE(review): this random draw is immediately overwritten
                # by zeros below; it is kept from the original (it still
                # advances the RNG stream) but looks like experiment leftover.
                factorized_z[n] = xp.random.normal(
                    0, args.temperature,
                    size=factorized_z[n].shape).astype("float32")
                factorized_z[n] = xp.zeros_like(factorized_z[n])
                out = None
                for k, (block, zi) in enumerate(
                        zip(encoder.blocks[::-1], factorized_z[::-1])):
                    # only the zeroed level is decoded without sampling
                    sampling = False if k == hyperparams.levels - n - 1 else True
                    out, _ = block.reverse_step(
                        out,
                        gaussian_eps=zi,
                        squeeze_factor=encoder.hyperparams.squeeze_factor,
                        sampling=sampling)
                rev_x = out

                rev_x_img = make_uint8(rev_x.data[0], num_bins_x)
                subplots[n + 1].imshow(rev_x_img, interpolation="none")
            plt.pause(.01)
Пример #11
0
                    f'Loss: {loss:.5f}; logP: {log_p:.5f}; logdet: {log_det:.5f}; lr: {warmup_lr:.7f}'
                )
                check_save(model_single, optimizer, args, z_sample, i)
    except (KeyboardInterrupt, SystemExit):
        check_save(model_single, optimizer, args, z_sample, i, save=True)
        raise


if __name__ == '__main__':
    args = parser.parse_args()
    # the resume iteration is encoded in the checkpoint filename: ..._<iter>.pt
    if len(args.load_path) > 0:
        args.startiter = int(args.load_path[:-3].split('_')[-1])
    print(args)

    # build a single-channel Glow on CPU, optionally restoring weights
    model_single = Glow(
        1, args.n_flow, args.n_block, affine=args.affine, conv_lu=not args.no_lu
    ).cpu()
    if len(args.load_path) > 0:
        # map_location keeps CPU-saved tensors on CPU regardless of device
        model_single.load_state_dict(torch.load(args.load_path,  map_location=lambda storage, loc: storage))
        
        model_single.initialize()
        gc.collect()
        torch.cuda.empty_cache()
    model = model_single
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    if len(args.load_path) > 0:
        # optimizer state is stored next to the model checkpoint
        optim_path = '/'.join(args.load_path.split('/')[:-1])
        optimizer.load_state_dict(torch.load(os.path.join(optim_path, 'optimizer.pth'),  map_location=lambda storage, loc: storage))
        gc.collect()
    model_path = args.model_path

    checkpoint_path = os.path.join(model_path, args.checkpoint)
    params_path = os.path.join(model_path, 'hparams.json')

    with open(os.path.join(model_path, 'hparams.json')) as json_file:
        hparams = json.load(json_file)

    # two datasets (e.g. in-distribution vs. OOD); only shapes/classes of the
    # first are used for the evaluation model below
    ds = check_dataset(args.dataset, args.dataroot, True, args.download)
    ds2 = check_dataset(args.dataset2, args.dataroot, True, args.download)
    image_shape, num_classes, train_dataset, test_dataset = ds
    image_shape2, num_classes2, train_dataset_2, test_dataset_2 = ds2

    # NOTE(review): this second Glow silently REPLACES the ``model`` built
    # from model_single above; the earlier model/optimizer setup appears
    # unused from here on — verify before relying on either.
    model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
                 hparams['L'], hparams['actnorm_scale'],
                 hparams['flow_permutation'], hparams['flow_coupling'],
                 hparams['LU_decomposed'], num_classes, hparams['learn_top'],
                 hparams['y_condition'])

    # checkpoints may be saved either as {'model': state_dict} or bare
    dic = torch.load(checkpoint_path)
    if 'model' in dic.keys():
        model.load_state_dict(dic["model"])
    else:
        model.load_state_dict(dic)
    model.set_actnorm_init()

    model = model.to(device)
    model = model.eval()

    if args.optim_type == "ADAM":
        optim_default = partial(optim.Adam, lr=args.lr_test)
Пример #13
0
def main(args):
    """Generate 3D image stacks by interpolating Glow samples along depth.

    Loads a trained Glow from ``results/<args.name>``, optionally estimates
    an interpolation step size from a two-point-correlation chord length,
    then builds ``args.iter`` volumes by linearly interpolating between
    sampled latents and writes xy/xz/yz videos per modality.
    """
    # torch.manual_seed(args.seed)

    # Test loading and sampling
    output_folder = os.path.join('results', args.name)

    with open(os.path.join(output_folder, 'hparams.json')) as json_file:
        hparams = json.load(json_file)

    device = "cpu" if not torch.cuda.is_available() else "cuda:0"
    image_shape = (hparams['patch_size'], hparams['patch_size'],
                   args.n_modalities)
    num_classes = 1

    print('Loading model...')
    model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
                 hparams['L'], hparams['actnorm_scale'],
                 hparams['flow_permutation'], hparams['flow_coupling'],
                 hparams['LU_decomposed'], num_classes, hparams['learn_top'],
                 hparams['y_condition'])

    model_chkpt = torch.load(
        os.path.join(output_folder, 'checkpoints', args.model))
    model.load_state_dict(model_chkpt['model'])
    model.set_actnorm_init()
    model = model.to(device)

    # Build images
    model.eval()
    temperature = args.temperature

    if args.steps is None:  # automatically calculate step size if no step size

        fig_dir = os.path.join(output_folder, 'stepnum_results')
        if not os.path.exists(fig_dir):
            os.mkdir(fig_dir)

        print('No step size entered')

        # Create sample of images to estimate chord length
        with torch.no_grad():
            mean, logs = model.prior(None, None)
            z = gaussian_sample(mean, logs, temperature)
            images_raw = model(z=z, temperature=temperature, reverse=True)
        # sanitize decoder output before thresholding
        images_raw[torch.isnan(images_raw)] = 0.5
        images_raw[torch.isinf(images_raw)] = 0.5
        images_raw = torch.clamp(images_raw, -0.5, 0.5)

        images_out = np.transpose(
            np.squeeze(images_raw[:, args.step_modality, :, :].cpu().numpy()),
            (1, 0, 2))

        # Threshold images and compute covariances
        if args.binary_data:
            thresh = 0
        else:
            thresh = threshold_otsu(images_out)
        images_bin = np.greater(images_out, thresh)
        x_cov = two_point_correlation(images_bin, 0)
        y_cov = two_point_correlation(images_bin, 1)

        # Compute chord length from the slope of the covariance at the origin
        cov_avg = np.mean(np.mean(np.concatenate((x_cov, y_cov), axis=2),
                                  axis=0),
                          axis=0)
        N = 5
        S20, _ = curve_fit(straight_line_at_origin(cov_avg[0]), range(0, N),
                           cov_avg[0:N])
        l_pore = np.abs(cov_avg[0] / S20)
        steps = int(l_pore)
        print('Calculated step size: {}'.format(steps))

    else:
        print('Using user-entered step size {}...'.format(args.steps))
        steps = args.steps

    # Build desired number of volumes
    for iter_vol in range(args.iter):
        if args.iter == 1:
            stack_dir = os.path.join(output_folder, 'image_stacks',
                                     args.save_name)
            print('Sampling images, saving to {}...'.format(args.save_name))
        else:
            # suffix each volume with a zero-padded index when building many
            stack_dir = os.path.join(
                output_folder, 'image_stacks',
                args.save_name + '_' + str(iter_vol).zfill(3))
            print('Sampling images, saving to {}_'.format(args.save_name) +
                  str(iter_vol).zfill(3) + '...')
        if not os.path.exists(stack_dir):
            os.makedirs(stack_dir)

        with torch.no_grad():
            mean, logs = model.prior(None, None)
            # interpolation weights from 1 down to 0 across ``steps`` slices
            alpha = 1 - torch.reshape(torch.linspace(0, 1, steps=steps),
                                      (-1, 1, 1, 1))
            alpha = alpha.to(device)

            # sample enough anchor latents, then lerp between consecutive
            # anchors to fill a patch_size-deep stack
            num_imgs = int(np.ceil(hparams['patch_size'] / steps) + 1)
            z = gaussian_sample(mean, logs, temperature)[:num_imgs, ...]
            z = torch.cat([
                alpha * z[i, ...] + (1 - alpha) * z[i + 1, ...]
                for i in range(num_imgs - 1)
            ])
            z = z[:hparams['patch_size'], ...]

            images_raw = model(z=z, temperature=temperature, reverse=True)

        # sanitize decoder output
        images_raw[torch.isnan(images_raw)] = 0.5
        images_raw[torch.isinf(images_raw)] = 0.5
        images_raw = torch.clamp(images_raw, -0.5, 0.5)

        # apply median filter to output
        if args.med_filt is not None or args.binary_data:
            for m in range(args.n_modalities):
                if args.binary_data:
                    SE = ball(1)
                else:
                    SE = ball(args.med_filt)
                images_np = np.squeeze(images_raw[:, m, :, :].cpu().numpy())
                images_filt = median_filter(images_np, footprint=SE)

                # Erode binary images
                if args.binary_data:
                    images_filt = np.greater(images_filt, 0)
                    SE = ball(1)
                    images_filt = 1.0 * binary_erosion(images_filt,
                                                       selem=SE) - 0.5

                images_raw[:, m, :, :] = torch.tensor(images_filt,
                                                      device=device)

        # three orthogonal views of the same volume
        images1 = postprocess(images_raw).cpu()
        images2 = postprocess(torch.transpose(images_raw, 0, 2)).cpu()
        images3 = postprocess(torch.transpose(images_raw, 0, 3)).cpu()

        # apply Otsu thresholding to output
        if args.save_binary and not args.binary_data:
            thresh = threshold_otsu(images1.numpy())
            images1[images1 < thresh] = 0
            images1[images1 > thresh] = 255
            images2[images2 < thresh] = 0
            images2[images2 > thresh] = 255
            images3[images3 < thresh] = 0
            images3[images3 > thresh] = 255

        # save video for each modality
        for m in range(args.n_modalities):
            if args.n_modalities > 1:
                save_dir = os.path.join(stack_dir, 'modality{}'.format(m))
            else:
                save_dir = stack_dir

            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            write_video(images1[:, m, :, :], 'xy', hparams, save_dir)
            write_video(images2[:, m, :, :], 'xz', hparams, save_dir)
            write_video(images3[:, m, :, :], 'yz', hparams, save_dir)

    print('Finished!')
Пример #14
0
def main():
    """Train a Glow model on a directory of ``png`` or ``npy`` images.

    Configuration comes from the module-level ``args``; hyperparameters and
    periodic model snapshots are written to ``args.snapshot_path``.
    """
    # BUG FIX: was ``try: os.mkdir(...) / except: pass`` — the bare except
    # swallowed every exception (including SystemExit/KeyboardInterrupt)
    # and hid real OS errors; makedirs(exist_ok=True) is the idiomatic form.
    os.makedirs(args.snapshot_path, exist_ok=True)

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    # number of discrete pixel bins implied by the quantization bit depth
    num_bins_x = 2**args.num_bits_x

    assert args.dataset_format in ["png", "npy"]

    # Get datasets:
    if True:
        files = Path(args.dataset_path).glob("*.{}".format(
            args.dataset_format))
        if args.dataset_format == "png":
            images = []
            for filepath in files:
                image = np.array(Image.open(filepath)).astype("float32")
                image = preprocess(image, args.num_bits_x)
                images.append(image)
            assert len(images) > 0
            images = np.asanyarray(images)
        elif args.dataset_format == "npy":
            images = []
            for filepath in files:
                array = np.load(filepath).astype("float32")
                array = preprocess(array, args.num_bits_x)
                images.append(array)
            assert len(images) > 0
            num_files = len(images)
            images = np.asanyarray(images)
            # each .npy file holds a stack of images; flatten to one batch axis
            images = images.reshape((num_files * images.shape[1], ) +
                                    images.shape[2:])
        else:
            raise NotImplementedError

    # Print dataset information
    if True:
        x_mean = np.mean(images)
        x_var = np.var(images)

        dataset = glow.dataset.Dataset(images)
        iterator = glow.dataset.Iterator(dataset, batch_size=args.batch_size)

        print(
            tabulate([
                ["#", len(dataset)],
                ["mean", x_mean],
                ["var", x_var],
            ]))

    # Hyperparameters' info
    if True:
        hyperparams = Hyperparameters()
        hyperparams.levels = args.levels
        hyperparams.depth_per_level = args.depth_per_level
        hyperparams.nn_hidden_channels = args.nn_hidden_channels
        hyperparams.image_size = images.shape[2:]
        hyperparams.num_bits_x = args.num_bits_x
        hyperparams.lu_decomposition = args.lu_decomposition
        hyperparams.squeeze_factor = args.squeeze_factor
        hyperparams.save(args.snapshot_path)
        hyperparams.print()

    encoder = Glow(hyperparams, hdf5_path=args.snapshot_path)
    if using_gpu:
        encoder.to_gpu()

    optimizer = Optimizer(encoder)

    # Data dependent initialization of actnorm on a single batch
    if encoder.need_initialize:
        for batch_index, data_indices in enumerate(iterator):
            x = to_gpu(dataset[data_indices])
            encoder.initialize_actnorm_weights(x)
            break

    current_training_step = 0
    num_pixels = 3 * hyperparams.image_size[0] * hyperparams.image_size[1]

    # Training loop
    for iteration in range(args.total_iteration):
        sum_loss = 0
        sum_nll = 0
        sum_kld = 0
        start_time = time.time()

        for batch_index, data_indices in enumerate(iterator):
            x = to_gpu(dataset[data_indices])
            # dequantization noise
            x += xp.random.uniform(0, 1.0 / num_bins_x, size=x.shape)

            # normalizer to report the loss in bits per pixel
            denom = math.log(2.0) * num_pixels

            factorized_z_distribution, logdet = encoder.forward_step(x)

            # discretization correction for the quantized input
            logdet -= math.log(num_bins_x) * num_pixels

            kld = 0
            negative_log_likelihood = 0
            factor_z = []
            for (zi, mean, ln_var) in factorized_z_distribution:
                negative_log_likelihood += cf.gaussian_nll(zi, mean, ln_var)
                if args.regularize_z:
                    kld += cf.gaussian_kl_divergence(mean, ln_var)
                factor_z.append(zi.data.reshape(zi.shape[0], -1))
            factor_z = xp.concatenate(factor_z, axis=1)
            # NLL of the merged latent under the standard-normal prior
            negative_log_likelihood += cf.gaussian_nll(
                factor_z, xp.zeros(factor_z.shape, dtype='float32'),
                xp.zeros(factor_z.shape, dtype='float32'))
            loss = (negative_log_likelihood + kld) / args.batch_size - logdet
            loss = loss / denom

            encoder.cleargrads()
            loss.backward()
            optimizer.update(current_training_step)

            current_training_step += 1

            sum_loss += _float(loss)
            sum_nll += _float(negative_log_likelihood) / args.batch_size
            sum_kld += _float(kld) / args.batch_size
            printr(
                "Iteration {}: Batch {} / {} - loss: {:.8f} - nll: {:.8f} - kld: {:.8f} - log_det: {:.8f}\n"
                .format(
                    iteration + 1, batch_index + 1, len(iterator),
                    _float(loss),
                    _float(negative_log_likelihood) / args.batch_size / denom,
                    _float(kld) / args.batch_size,
                    _float(logdet) / denom))

            # periodic snapshot within an iteration
            if (batch_index + 1) % 100 == 0:
                encoder.save(args.snapshot_path)

        mean_log_likelihood = -sum_nll / len(iterator)
        mean_kld = sum_kld / len(iterator)
        elapsed_time = time.time() - start_time
        print(
            "\033[2KIteration {} - loss: {:.5f} - log_likelihood: {:.5f} - kld: {:.5f} - elapsed_time: {:.3f} min\n"
            .format(iteration + 1, sum_loss / len(iterator),
                    mean_log_likelihood, mean_kld, elapsed_time / 60))
        encoder.save(args.snapshot_path)
Пример #15
0
def main():
    """Visualize latent-space interpolation between pairs of MNIST images.

    Encodes batches of two MNIST images with a trained Glow snapshot,
    linearly interpolates between their latent codes in ``args.num_steps``
    steps, decodes each intermediate code, and displays the two originals
    plus the decoded interpolation path side by side. Loops forever and
    relies on the module-level ``args`` namespace.
    """
    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    hyperparams = Hyperparameters(args.snapshot_path)
    hyperparams.print()

    # Number of discrete intensity levels per pixel after quantizing to
    # num_bits_x bits (used for dequantization noise and uint8 conversion).
    num_bins_x = 2.0**hyperparams.num_bits_x
    image_size = (28, 28)

    images = chainer.datasets.mnist.get_mnist(withlabel=False)[0]
    images = 255.0 * np.asarray(images).reshape((-1, ) + image_size + (1, ))
    if hyperparams.num_image_channels != 1:
        # Replicate the grayscale channel to match the channel count the
        # snapshot was trained with.
        images = np.broadcast_to(images, (images.shape[0], ) + image_size +
                                 (hyperparams.num_image_channels, ))
    images = preprocess(images, hyperparams.num_bits_x)

    dataset = glow.dataset.Dataset(images)
    # Batch size 2: each batch supplies the two endpoints of one interpolation.
    iterator = glow.dataset.Iterator(dataset, batch_size=2)

    print(tabulate([["#image", len(dataset)]]))

    encoder = Glow(hyperparams, hdf5_path=args.snapshot_path)
    if using_gpu:
        encoder.to_gpu()

    # One subplot per interpolation step plus the two source images.
    total = args.num_steps + 2
    fig = plt.figure(figsize=(4 * total, 4))
    subplots = []
    for n in range(total):
        subplot = fig.add_subplot(1, total, n + 1)
        subplots.append(subplot)

    # BUG FIX: the original used
    #     with chainer.no_backprop_mode() and encoder.reverse() as decoder:
    # `and` evaluates to its right operand, so only `encoder.reverse()` was
    # entered as a context manager and backprop was never actually disabled.
    # Separating the managers with a comma enters both.
    with chainer.no_backprop_mode(), encoder.reverse() as decoder:
        while True:
            for data_indices in iterator:
                x = to_gpu(dataset[data_indices])
                # Dequantization noise: uniform over one quantization bin.
                x += xp.random.uniform(0, 1.0 / num_bins_x, size=x.shape)
                factorized_z_distribution, _ = encoder.forward_step(x)

                factorized_z = []
                for (zi, mean, ln_var) in factorized_z_distribution:
                    factorized_z.append(zi)

                z = encoder.merge_factorized_z(factorized_z)
                z_start = z[0]
                z_end = z[1]

                z_batch = [z_start]
                # Guard the step divisor so num_steps == 1 does not divide
                # by zero (the original divided by num_steps - 1 directly).
                step_divisor = max(args.num_steps - 1, 1)
                for n in range(args.num_steps):
                    ratio = n / step_divisor
                    z_interp = ratio * z_end + (1.0 - ratio) * z_start
                    z_batch.append(args.temperature * z_interp)
                z_batch.append(z_end)
                z_batch = xp.stack(z_batch)

                rev_x_batch, _ = decoder.reverse_step(z_batch)
                for n in range(args.num_steps):
                    rev_x_img = make_uint8(rev_x_batch.data[n + 1], num_bins_x)
                    subplots[n + 1].imshow(rev_x_img, interpolation="none")

                x_start_img = make_uint8(x[0], num_bins_x)
                subplots[0].imshow(x_start_img, interpolation="none")

                x_end_img = make_uint8(x[-1], num_bins_x)
                subplots[-1].imshow(x_end_img, interpolation="none")

                plt.pause(.01)
Пример #16
0
def main():
    """Optimize an additive perturbation `b` (and mask `m`) against a trained Glow model.

    Loads one image, reports its likelihood under the snapshot, then runs
    gradient ascent/descent on the parameters of the nested `eps` chain,
    periodically dumping trajectories (z, b, losses, reconstructions) to
    ``args.ckpt``.  Relies on the module-level ``args`` namespace.
    """
    try:
        os.mkdir(args.ckpt)
    # NOTE(review): bare except silently swallows every error from mkdir
    # (permissions, bad path), not just "directory exists".
    except:
        pass

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    hyperparams = Hyperparameters(args.snapshot_path)
    hyperparams.print()

    # Quantization levels per pixel; total pixel count for the bits/dim
    # correction term subtracted from the log-determinant below.
    num_bins_x = 2.0**hyperparams.num_bits_x
    num_pixels = 3 * hyperparams.image_size[0] * hyperparams.image_size[1]

    encoder = Glow(hyperparams, hdf5_path=args.snapshot_path)
    if using_gpu:
        encoder.to_gpu()

    # Load picture
    x = np.array(Image.open(args.img)).astype('float32')
    x = preprocess(x, hyperparams.num_bits_x)

    # Add batch axis, then dequantization noise uniform over one bin.
    x = to_gpu(xp.expand_dims(x, axis=0))
    x += xp.random.uniform(0, 1.0 / num_bins_x, size=x.shape)

    if True:
        # Print this image info:
        # NOTE(review): perturbation shape (1, 3, 128, 128) is hard-coded;
        # presumably hyperparams.image_size is 128x128 — confirm.
        b = xp.zeros((1, 3, 128, 128))
        z, fw_ldt = encoder.forward_step(x, b)
        # Discretization correction: convert to the likelihood of the
        # quantized data.
        fw_ldt -= math.log(num_bins_x) * num_pixels

        logpZ = 0
        ez = []
        factor_z = []
        for (zi, mean, ln_var) in z:
            factor_z.append(zi.data)
            # NLL of each factored-out latent under its predicted Gaussian.
            logpZ += cf.gaussian_nll(zi, mean, ln_var)
            ez.append(zi.data.reshape(-1, ))

        # Flattened latent; NLL under a standard normal for comparison.
        ez = np.concatenate(ez)
        logpZ2 = cf.gaussian_nll(ez, xp.zeros(ez.shape),
                                 xp.zeros(ez.shape)).data

        print(fw_ldt, logpZ, logpZ2)
        # Sanity check: invert the encoder on the unmodified latents and
        # save the reconstruction.
        with encoder.reverse() as decoder:
            rx, _ = decoder.reverse_step(factor_z)
            rx_img = make_uint8(rx.data[0], num_bins_x)
            rx_img = Image.fromarray(rx_img)
            rx_img.save(args.t + 'ori_revx.png')

        # NOTE(review): `.get()` is a cupy-only method — this line (and the
        # `.get()`/`cupy.asnumpy` calls below) will fail when running on CPU
        # where xp is numpy.
        np.save(args.t + 'ori_z.npy', ez.get())

    # Construct epsilon
    class eps(chainer.Chain):
        # Learnable perturbation `b` (full image resolution, hard-coded
        # 1x3x128x128) and mask `m` (3x8x8); both optimized jointly.
        def __init__(self, shape, glow_encoder):
            super().__init__()
            self.encoder = glow_encoder

            with self.init_scope():
                # NOTE(review): `shape` argument is ignored; sizes are
                # hard-coded instead — confirm this is intentional.
                self.b = chainer.Parameter(initializers.Zero(),
                                           (1, 3, 128, 128))
                self.m = chainer.Parameter(initializers.One(), (3, 8, 8))

        def forward(self, x):
            # b = cf.tanh(self.b) * 0.5
            b = self.b

            # Not sure if implementation is wrong
            # (original author's note — m is computed but unused below).
            m = cf.softplus(self.m)
            # m = cf.repeat(m, 8, axis=2)
            # m = cf.repeat(m, 8, axis=1)
            # m = cf.repeat(m, 16, axis=2)
            # m = cf.repeat(m, 16, axis=1)

            # b = b * m
            # cur_x = cf.add(x, b)
            # cur_x = cf.clip(cur_x, -0.5,0.5)

            z = []
            # The perturbation is passed into the encoder itself rather
            # than added to x here.
            zs, logdet = self.encoder.forward_step(x, b)
            for (zi, mean, ln_var) in zs:
                z.append(zi)

            z = merge_factorized_z(z)

            # return z, zs, logdet, cf.batch_l2_norm_squared(b), xp.tanh(self.b.data*1), cur_x, m
            # L1 norm of b is the perturbation-size penalty.
            return z, zs, logdet, xp.sum(xp.abs(b.data)), self.b.data * 1, m, x

        def save(self, path):
            filename = 'loss_model.hdf5'
            self.save_parameter(path, filename, self)

        def save_parameter(self, path, filename, params):
            # Write to a temp file first, then rename: atomic on POSIX, so a
            # crash mid-save cannot leave a truncated snapshot.
            tmp_filename = str(uuid.uuid4())
            tmp_filepath = os.path.join(path, tmp_filename)
            save_hdf5(tmp_filepath, params)
            os.rename(tmp_filepath, os.path.join(path, filename))

    epsilon = eps(x.shape, encoder)
    if using_gpu:
        epsilon.to_gpu()

    # optimizer = Optimizer(epsilon)
    optimizer = optimizers.Adam().setup(epsilon)
    # optimizer = optimizers.SGD().setup(epsilon)
    # Separate learning rates for the perturbation and the mask.
    epsilon.b.update_rule.hyperparam.lr = 0.01
    epsilon.m.update_rule.hyperparam.lr = 0.1
    print('init finish')

    training_step = 0

    # Trajectory buffers, flushed to .npy files every 100 iterations.
    z_s = []
    b_s = []
    loss_s = []
    logpZ_s = []
    logDet_s = []
    m_s = []
    j = 0

    for iteration in range(args.total_iteration):
        epsilon.cleargrads()
        z, zs, fw_ldt, b_norm, b, m, cur_x = epsilon.forward(x)

        # Same discretization correction as above.
        fw_ldt -= math.log(num_bins_x) * num_pixels

        logpZ1 = 0
        factor_z = []
        for (zi, mean, ln_var) in zs:
            factor_z.append(zi.data)
            logpZ1 += cf.gaussian_nll(zi, mean, ln_var)

        # NLL of the merged latent under a standard normal (detached:
        # `.data` means no gradient flows through this term).
        logpZ2 = cf.gaussian_nll(z, xp.zeros(z.shape), xp.zeros(z.shape)).data
        # logpZ2 = cf.gaussian_nll(z, np.mean(z), np.log(np.var(z))).data

        # Average of the two latent NLL estimates.
        logpZ = (logpZ2 + logpZ1) / 2
        # Minimize perturbation size while minimizing the image NLL
        # (logpZ - fw_ldt is the negative log-likelihood of x).
        loss = b_norm + (logpZ - fw_ldt)

        loss.backward()
        optimizer.update()
        training_step += 1

        # NOTE(review): `.get()` / `cupy.asnumpy` below assume GPU arrays;
        # these fail on the CPU path.
        z_s.append(z.get())
        b_s.append(cupy.asnumpy(b))
        m_s.append(cupy.asnumpy(m.data))
        loss_s.append(_float(loss))
        logpZ_s.append(_float(logpZ))
        logDet_s.append(_float(fw_ldt))

        printr(
            "Iteration {}: loss: {:.6f} - b_norm: {:.6f} - logpZ: {:.6f} - logpZ1: {:.6f} - logpZ2: {:.6f} - log_det: {:.6f} - logpX: {:.6f}\n"
            .format(iteration + 1, _float(loss), _float(b_norm), _float(logpZ),
                    _float(logpZ1), _float(logpZ2), _float(fw_ldt),
                    _float(logpZ) - _float(fw_ldt)))

        if iteration % 100 == 99:
            # Flush trajectory buffers for this 100-iteration window.
            np.save(args.ckpt + '/' + str(j) + 'z.npy', z_s)
            np.save(args.ckpt + '/' + str(j) + 'b.npy', b_s)
            np.save(args.ckpt + '/' + str(j) + 'loss.npy', loss_s)
            np.save(args.ckpt + '/' + str(j) + 'logpZ.npy', logpZ_s)
            np.save(args.ckpt + '/' + str(j) + 'logDet.npy', logDet_s)
            # cur_x = make_uint8(cur_x[0].data, num_bins_x)
            # np.save(args.ckpt + '/'+str(j)+'image.npy', cur_x)
            np.save(args.ckpt + '/' + str(j) + 'm.npy', m_s)

            # Decode the current latents back to image space and save the
            # reconstruction of the perturbed input.
            with encoder.reverse() as decoder:
                rx, _ = decoder.reverse_step(factor_z)
                rx_img = make_uint8(rx.data[0], num_bins_x)
                np.save(args.ckpt + '/' + str(j) + 'res.npy', rx_img)
            z_s = []
            b_s = []
            loss_s = []
            logpZ_s = []
            logDet_s = []
            m_s = []
            j += 1
            epsilon.save(args.ckpt)
Пример #17
0
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         warmup, fresh, logittransform, gan, disc_lr, sn, flowgan, eval_every,
         ld_on_samples, weight_gan, weight_prior, weight_logdet,
         jac_reg_lambda, affine_eps, no_warm_up, optim_name, clamp, svd_every,
         eval_only, no_actnorm, affine_scale_eps, actnorm_max_scale,
         no_conv_actnorm, affine_max_scale, actnorm_eps, init_sample, no_split,
         disc_arch, weight_entropy_reg, db):
    """Train a Glow generator adversarially against a discriminator.

    Builds the Glow model and a discriminator (architecture chosen by
    ``disc_arch``), then runs an ignite training loop that combines GAN,
    prior-likelihood, log-det, entropy-regularization, and Jacobian
    regularization losses.  Periodically saves sample grids, evaluates
    FID / bits-per-dim / reconstruction PAD, and logs Jacobian SVD
    statistics to CSV.  Relies on module-level ``device`` and ``args``.
    """
    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=eval_batch_size,
                                  shuffle=False,
                                  num_workers=n_workers,
                                  drop_last=False)
    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition, logittransform, sn, affine_eps,
                 no_actnorm, affine_scale_eps, actnorm_max_scale,
                 no_conv_actnorm, affine_max_scale, actnorm_eps, no_split)

    model = model.to(device)

    # Discriminator architecture is selected by name; all variants take the
    # image channel count from image_shape[-1].
    if disc_arch == 'mine':
        discriminator = mine.Discriminator(image_shape[-1])
    elif disc_arch == 'biggan':
        discriminator = cgan_models.Discriminator(
            image_channels=image_shape[-1], conditional_D=False)
    elif disc_arch == 'dcgan':
        discriminator = DCGANDiscriminator(image_shape[0], 64, image_shape[-1])
    elif disc_arch == 'inv':
        discriminator = InvDiscriminator(
            image_shape, hidden_channels, K, L, actnorm_scale,
            flow_permutation, flow_coupling, LU_decomposed, num_classes,
            learn_top, y_condition, logittransform, sn, affine_eps, no_actnorm,
            affine_scale_eps, actnorm_max_scale, no_conv_actnorm,
            affine_max_scale, actnorm_eps, no_split)

    discriminator = discriminator.to(device)
    D_optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                    discriminator.parameters()),
                             lr=disc_lr,
                             betas=(.5, .99),
                             weight_decay=0)
    if optim_name == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=lr,
                               betas=(.5, .99),
                               weight_decay=0)
    elif optim_name == 'adamax':
        optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    if not no_warm_up:
        # Linear LR warm-up over the first `warmup` epochs.
        lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lr_lambda=lr_lambda)

    iteration_fieldnames = [
        'global_iteration', 'fid', 'sample_pad', 'train_bpd', 'eval_bpd',
        'pad', 'batch_real_acc', 'batch_fake_acc', 'batch_acc'
    ]
    iteration_logger = CSVLogger(fieldnames=iteration_fieldnames,
                                 filename=os.path.join(output_dir,
                                                       'iteration_log.csv'))
    iteration_fieldnames = [
        'global_iteration', 'condition_num', 'max_sv', 'min_sv',
        'inverse_condition_num', 'inverse_max_sv', 'inverse_min_sv'
    ]
    svd_logger = CSVLogger(fieldnames=iteration_fieldnames,
                           filename=os.path.join(output_dir, 'svd_log.csv'))

    # Fixed pool of real test images for FID, and one batch for recon plots.
    # NOTE(review): uses args.batch_size here while the function takes a
    # batch_size parameter — presumably the same value; confirm.
    test_iter = iter(test_loader)
    N_inception = 1000
    x_real_inception = torch.cat([
        next(test_iter)[0].to(device)
        for _ in range(N_inception // args.batch_size + 1)
    ], 0)[:N_inception]
    x_real_inception = x_real_inception + .5
    x_for_recon = next(test_iter)[0].to(device)

    def gan_step(engine, batch):
        """One combined discriminator + generator training step."""
        assert not y_condition
        # Track a per-engine iteration counter (first call leaves it at -1).
        if 'iter_ind' in dir(engine):
            engine.iter_ind += 1
        else:
            engine.iter_ind = -1
        losses = {}
        model.train()
        discriminator.train()

        x, y = batch
        x = x.to(device)

        def run_noised_disc(discriminator, x):
            # Dequantize inputs before scoring so real and fake images are
            # treated consistently.
            x = uniform_binning_correction(x)[0]
            return discriminator(x)

        real_acc = fake_acc = acc = 0
        if weight_gan > 0:
            # --- Discriminator update ---
            fake = generate_from_noise(model, x.size(0), clamp=clamp)

            D_real_scores = run_noised_disc(discriminator, x.detach())
            D_fake_scores = run_noised_disc(discriminator, fake.detach())

            ones_target = torch.ones((x.size(0), 1), device=x.device)
            zeros_target = torch.zeros((x.size(0), 1), device=x.device)

            D_real_accuracy = torch.sum(
                torch.round(torch.sigmoid(D_real_scores)) ==
                ones_target).float() / ones_target.size(0)
            D_fake_accuracy = torch.sum(
                torch.round(torch.sigmoid(D_fake_scores)) ==
                zeros_target).float() / zeros_target.size(0)

            D_real_loss = F.binary_cross_entropy_with_logits(
                D_real_scores, ones_target)
            D_fake_loss = F.binary_cross_entropy_with_logits(
                D_fake_scores, zeros_target)

            D_loss = (D_real_loss + D_fake_loss) / 2
            # Gradient penalty on interpolates between real and fake.
            gp = gradient_penalty(
                x.detach(), fake.detach(),
                lambda _x: run_noised_disc(discriminator, _x))
            D_loss_plus_gp = D_loss + 10 * gp
            D_optimizer.zero_grad()
            D_loss_plus_gp.backward()
            D_optimizer.step()

            # Train generator
            fake = generate_from_noise(model,
                                       x.size(0),
                                       clamp=clamp,
                                       guard_nans=False)
            G_loss = F.binary_cross_entropy_with_logits(
                run_noised_disc(discriminator, fake),
                torch.ones((x.size(0), 1), device=x.device))

            # Trace
            real_acc = D_real_accuracy.item()
            fake_acc = D_fake_accuracy.item()
            acc = .5 * (D_fake_accuracy.item() + D_real_accuracy.item())

        # Likelihood terms on real data.
        z, nll, y_logits, (prior, logdet) = model.forward(x,
                                                          None,
                                                          return_details=True)
        train_bpd = nll.mean().item()

        loss = 0
        if weight_gan > 0:
            loss = loss + weight_gan * G_loss
        if weight_prior > 0:
            loss = loss + weight_prior * -prior.mean()
        if weight_logdet > 0:
            loss = loss + weight_logdet * -logdet.mean()

        if weight_entropy_reg > 0:
            _, _, _, (sample_prior,
                      sample_logdet) = model.forward(fake,
                                                     None,
                                                     return_details=True)
            # notice this is actually "decreasing" sample likelihood.
            loss = loss + weight_entropy_reg * (sample_prior.mean() +
                                                sample_logdet.mean())
        # Jac Reg
        if jac_reg_lambda > 0:
            # Sample
            x_samples = generate_from_noise(model,
                                            args.batch_size,
                                            clamp=clamp).detach()
            x_samples.requires_grad_()
            z = model.forward(x_samples, None, return_details=True)[0]
            # Collect the z's split off at intermediate flow levels so the
            # regularizer sees the full latent.
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1)
                for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            sample_foward_jac = compute_jacobian_regularizer(x_samples,
                                                             all_z,
                                                             n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            randz = torch.randn(zshape).to(device)
            randz = torch.autograd.Variable(randz, requires_grad=True)
            images = model(z=randz,
                           y_onehot=None,
                           temperature=1,
                           reverse=True,
                           batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [randz] + other_zs
            sample_inverse_jac = compute_jacobian_regularizer_manyinputs(
                all_z, images, n_proj=1)

            # Data
            x.requires_grad_()
            z = model.forward(x, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1)
                for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            data_foward_jac = compute_jacobian_regularizer(x, all_z, n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            z.requires_grad_()
            images = model(z=z,
                           y_onehot=None,
                           temperature=1,
                           reverse=True,
                           batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [z] + other_zs
            data_inverse_jac = compute_jacobian_regularizer_manyinputs(
                all_z, images, n_proj=1)

            # loss = loss + jac_reg_lambda * (sample_foward_jac + sample_inverse_jac )
            loss = loss + jac_reg_lambda * (sample_foward_jac +
                                            sample_inverse_jac +
                                            data_foward_jac + data_inverse_jac)

        if not eval_only:
            optimizer.zero_grad()
            loss.backward()
            if not db:
                assert max_grad_clip == max_grad_norm == 0
            if max_grad_clip > 0:
                torch.nn.utils.clip_grad_value_(model.parameters(),
                                                max_grad_clip)
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_grad_norm)

            # Replace NaN gradient with 0
            for p in model.parameters():
                if p.requires_grad and p.grad is not None:
                    g = p.grad.data
                    g[g != g] = 0

            optimizer.step()

        if engine.iter_ind % 100 == 0:
            # Save a grid of current samples; flag NaNs in the title.
            with torch.no_grad():
                fake = generate_from_noise(model, x.size(0), clamp=clamp)
                z = model.forward(fake, None, return_details=True)[0]
            print("Z max min")
            print(z.max().item(), z.min().item())
            if (fake != fake).float().sum() > 0:
                title = 'NaNs'
            else:
                title = "Good"
            grid = make_grid((postprocess(fake.detach().cpu(), dataset)[:30]),
                             nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.title(title)
            plt.savefig(
                os.path.join(output_dir, f'sample_{engine.iter_ind}.png'))

        if engine.iter_ind % eval_every == 0:

            def check_all_zero_except_leading(x):
                # True for 10, 200, 3000, ... — checkpoint on round iteration
                # numbers only.
                return x % 10**np.floor(np.log10(x)) == 0

            if engine.iter_ind == 0 or check_all_zero_except_leading(
                    engine.iter_ind):
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f'ckpt_sd_{engine.iter_ind}.pt'))

            model.eval()

            with torch.no_grad():
                # Plot recon
                fpath = os.path.join(output_dir, '_recon',
                                     f'recon_{engine.iter_ind}.png')
                sample_pad = run_recon_evolution(
                    model,
                    generate_from_noise(model, args.batch_size,
                                        clamp=clamp).detach(), fpath)
                print(
                    f"Iter: {engine.iter_ind}, Recon Sample PAD: {sample_pad}")

                pad = run_recon_evolution(model, x_for_recon, fpath)
                print(f"Iter: {engine.iter_ind}, Recon PAD: {pad}")
                pad = pad.item()
                sample_pad = sample_pad.item()

                # Inception score
                sample = torch.cat([
                    generate_from_noise(model, args.batch_size, clamp=clamp)
                    for _ in range(N_inception // args.batch_size + 1)
                ], 0)[:N_inception]
                sample = sample + .5

                if (sample != sample).float().sum() > 0:
                    print("Sample NaNs")
                    # BUG FIX: the original used a bare `raise` with no active
                    # exception, which raises an uninformative
                    # RuntimeError("No active exception to re-raise").
                    raise RuntimeError("NaNs in generated samples")
                else:
                    fid = run_fid(x_real_inception.clamp_(0, 1),
                                  sample.clamp_(0, 1))
                    print(f'fid: {fid}, global_iter: {engine.iter_ind}')

                # Eval BPD
                eval_bpd = np.mean([
                    model.forward(x.to(device), None,
                                  return_details=True)[1].mean().item()
                    for x, _ in test_loader
                ])

                stats_dict = {
                    'global_iteration': engine.iter_ind,
                    'fid': fid,
                    'train_bpd': train_bpd,
                    'pad': pad,
                    'eval_bpd': eval_bpd,
                    'sample_pad': sample_pad,
                    'batch_real_acc': real_acc,
                    'batch_fake_acc': fake_acc,
                    'batch_acc': acc
                }
                iteration_logger.writerow(stats_dict)
                plot_csv(iteration_logger.filename)
            model.train()

        # BUG FIX: the original condition was
        #     engine.iter_ind + 2 % svd_every == 0
        # `%` binds tighter than `+`, so it compared
        # iter_ind + (2 % svd_every) with 0 and the SVD logging almost never
        # ran on its intended schedule.  Parenthesize the sum.
        if (engine.iter_ind + 2) % svd_every == 0:
            model.eval()
            svd_dict = {}
            ret = utils.computeSVDjacobian(x_for_recon, model)
            D_for, D_inv = ret['D_for'], ret['D_inv']
            cn = float(D_for.max() / D_for.min())
            cn_inv = float(D_inv.max() / D_inv.min())
            svd_dict['global_iteration'] = engine.iter_ind
            svd_dict['condition_num'] = cn
            svd_dict['max_sv'] = float(D_for.max())
            svd_dict['min_sv'] = float(D_for.min())
            svd_dict['inverse_condition_num'] = cn_inv
            svd_dict['inverse_max_sv'] = float(D_inv.max())
            svd_dict['inverse_min_sv'] = float(D_inv.min())
            svd_logger.writerow(svd_dict)
            # plot_utils.plot_stability_stats(output_dir)
            # plot_utils.plot_individual_figures(output_dir, 'svd_log.csv')
            model.train()
            if eval_only:
                sys.exit()

        # Dummy
        losses['total_loss'] = torch.mean(nll).item()
        return losses

    def eval_step(engine, batch):
        """Evaluation step: NLL (optionally class-conditional) on one batch."""
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    trainer = Engine(gan_step)
    # else:
    #     trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         'glow',
                                         save_interval=5,
                                         n_saved=1,
                                         require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x:
         (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(
             evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(
            trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x:
             (x['nll'], torch.empty(x['nll'].shape[0]))).attach(
                 evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        print("Loading...")
        print(saved_model)
        loaded = torch.load(saved_model)
        # if 'Glow' in str(type(loaded)):
        #     model  = loaded
        # else:
        #     raise
        # # if 'Glow' in str(type(loaded)):
        # #     loaded  = loaded.state_dict()
        model.load_state_dict(loaded)
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        # Resume the epoch counter from the checkpoint filename suffix.
        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        """Data-dependent ActNorm initialization before training starts."""
        if saved_model:
            return
        model.train()
        print("Initializing Actnorm...")
        init_batches = []
        init_targets = []

        if n_init_batches == 0:
            model.set_actnorm_init()
            return
        with torch.no_grad():
            if init_sample:
                # Initialize from model samples instead of data batches.
                generate_from_noise(model,
                                    args.batch_size * args.n_init_batches)
            else:
                for batch, target in islice(train_loader, None,
                                            n_init_batches):
                    init_batches.append(batch)
                    init_targets.append(target)

                init_batches = torch.cat(init_batches).to(device)

                assert init_batches.shape[0] == n_init_batches * batch_size

                if y_condition:
                    init_targets = torch.cat(init_targets).to(device)
                else:
                    init_targets = None

                model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)
        if not no_warm_up:
            scheduler.step()
        metrics = evaluator.state.metrics

        losses = ', '.join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]'
        )
        timer.reset()

    trainer.run(train_loader, epochs)
Пример #18
0
    )
    parser.add_argument("--affine",
                        action="store_true",
                        help="use affine coupling instead of additive")
    parser.add_argument("--n_flow",
                        default=32,
                        type=int,
                        help="number of flows in each block")
    parser.add_argument("--n_bits", default=5, type=int, help="number of bits")
    parser.add_argument("--img_size", default=64, type=int)
    parser.add_argument("--batch_size", default=18, type=int)
    args = parser.parse_args()

    model_single = Glow(3,
                        args.n_flow,
                        args.n_block,
                        affine=args.affine,
                        conv_lu=not args.no_lu)
    model = torch.nn.DataParallel(model_single)
    model = model.to(device)
    ckt = torch.load('savemodel/checkpoint_18.tar')
    model.load_state_dict(ckt['net'])

    dataset = CelebALoader('./', args.img_size)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True)

    with_attr_cnt = 0
    wo_attr_cnt = 0
    with_attr_z = None
Пример #19
0
def main():
    """Optimise a per-pixel additive perturbation ``b`` (plus a coarse 8x8
    block mask ``m``) on a single input image so that the perturbed image
    attains high likelihood under a pre-trained Glow encoder.

    Reads configuration from the module-level ``args`` namespace
    (``snapshot_path``, ``img``, ``ckpt``, ``gpu_device``,
    ``total_iteration``) and periodically dumps training statistics and the
    current perturbed image into ``args.ckpt``.
    """
    # Create the checkpoint directory if missing. This replaces a bare
    # ``try: os.mkdir(...) / except: pass`` which silently swallowed *all*
    # errors (e.g. permission failures), not just "directory exists".
    os.makedirs(args.ckpt, exist_ok=True)

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy  # switch the array backend to CuPy on GPU

    hyperparams = Hyperparameters(args.snapshot_path)
    hyperparams.print()

    # Number of discrete bins per channel and the pixel count; both feed the
    # dequantisation correction of the log-determinant further below.
    num_bins_x = 2.0**hyperparams.num_bits_x
    num_pixels = 3 * hyperparams.image_size[0] * hyperparams.image_size[1]

    encoder = Glow(hyperparams, hdf5_path=args.snapshot_path)
    if using_gpu:
        encoder.to_gpu()

    # Load picture and map it into the model's input range.
    x = np.array(Image.open(args.img)).astype('float32')
    x = preprocess(x, hyperparams.num_bits_x)

    x = to_gpu(xp.expand_dims(x, axis=0))
    # Uniform dequantisation noise in [0, 1/num_bins_x).
    x += xp.random.uniform(0, 1.0 / num_bins_x, size=x.shape)

    # Construct epsilon: a small chain holding the learnable perturbation
    # ``b`` (initialised to 0) and mask ``m`` (initialised to 1) alongside
    # the frozen Glow encoder.
    class eps(chainer.Chain):
        def __init__(self, shape, glow_encoder):
            super().__init__()
            self.encoder = glow_encoder

            with self.init_scope():
                self.b = chainer.Parameter(initializers.Zero(), shape)
                self.m = chainer.Parameter(initializers.One(), shape)

        def modify_mask(self):
            """Average the mask over non-overlapping 8x8 tiles (in place) so
            every tile shares one non-negative value.

            NOTE(review): the 8x8 tiling assumes a 64x64 spatial extent --
            confirm against ``hyperparams.image_size``.
            """
            mask = self.m.data
            for i_idx in range(8):
                for j_idx in range(8):
                    mean = xp.mean((xp.sum(mask[:, :, i_idx * 8:i_idx * 8 + 8,
                                                j_idx * 8:j_idx * 8 + 8])))
                    mask[:, :, i_idx * 8:i_idx * 8 + 8,
                         j_idx * 8:j_idx * 8 + 8] = mean

            mask = xp.abs(mask)
            print(type(mask), type(self.m), type(self.m.data))

            self.m.data = mask

        def forward(self, x):
            """Add the perturbation to ``x``, clip to the model's input
            range, and encode; returns the merged latent, the per-level
            latents, the log-determinant, the L1 norm of ``b``, a tanh view
            of ``b``, the mask parameter, and the perturbed input."""
            # b_ = cf.tanh(self.b)
            b_ = self.b
            # Not sure if implementation is wrong
            self.modify_mask()
            # m = cf.repeat(m, 8, axis=2)
            # m = cf.repeat(m, 8, axis=1)
            # m = cf.repeat(m, 16, axis=2)
            # m = cf.repeat(m, 16, axis=1)
            # b = b * m
            x_ = cf.add(x, b_)
            x_ = cf.clip(x_, -0.5, 0.5)

            z = []
            zs, logdet = self.encoder.forward_step(x_)
            for (zi, mean, ln_var) in zs:
                z.append(zi)

            z = merge_factorized_z(z)

            # return z, zs, logdet, cf.batch_l2_norm_squared(b), xp.tanh(self.b.data*1), cur_x, m
            return z, zs, logdet, xp.sum(xp.abs(b_.data)), xp.tanh(
                self.b.data * 1), self.m, x_

        def save(self, path):
            """Persist the current parameters as ``loss_model.hdf5``."""
            filename = 'loss_model.hdf5'
            self.save_parameter(path, filename, self)

        def save_parameter(self, path, filename, params):
            # Write atomically: dump to a uuid-named temp file, then rename.
            tmp_filename = str(uuid.uuid4())
            tmp_filepath = os.path.join(path, tmp_filename)
            save_hdf5(tmp_filepath, params)
            os.rename(tmp_filepath, os.path.join(path, filename))

    epsilon = eps(x.shape, encoder)
    if using_gpu:
        epsilon.to_gpu()

    # Plain SGD with separate learning rates for perturbation and mask.
    # optimizer = Optimizer(epsilon)
    # optimizer = optimizers.Adam(alpha=0.0005).setup(epsilon)
    optimizer = optimizers.SGD().setup(epsilon)
    epsilon.b.update_rule.hyperparam.lr = 0.0001
    epsilon.m.update_rule.hyperparam.lr = 0.1
    print('init finish')

    training_step = 0

    # Per-iteration histories; flushed to disk and reset every 100 steps.
    z_s = []
    b_s = []
    loss_s = []
    logpZ_s = []
    logDet_s = []
    m_s = []
    j = 0

    for iteration in range(args.total_iteration):

        # z, zs, logdet, xp.sum(xp.abs(b_.data)), xp.tanh(self.b.data * 1), self.m, x_

        z, zs, fw_ldt, b_norm, b, m, cur_x = epsilon.forward(x)

        epsilon.cleargrads()
        # Dequantisation correction for the change of variables.
        fw_ldt -= math.log(num_bins_x) * num_pixels

        # logpZ1: NLL of each factored-out latent under its learned prior;
        # logpZ2: NLL of the merged latent under a standard normal.
        logpZ1 = 0
        factor_z = []
        for (zi, mean, ln_var) in zs:
            factor_z.append(zi.data.reshape(zi.shape[0], -1))
            logpZ1 += cf.gaussian_nll(zi, mean, ln_var)
        factor_z = xp.concatenate(factor_z, axis=1)
        logpZ2 = cf.gaussian_nll(z, xp.zeros(z.shape), xp.zeros(z.shape)).data
        # logpZ2 = cf.gaussian_nll(z, np.mean(z), np.log(np.var(z))).data

        logpZ = (logpZ2 * 1 + logpZ1 * 1)
        # L1 penalty on the perturbation plus the negative log-likelihood.
        loss = 10 * b_norm + (logpZ - fw_ldt)

        loss.backward()
        optimizer.update()
        training_step += 1

        z_s.append(z.get())
        b_s.append(cupy.asnumpy(b))
        m_s.append(cupy.asnumpy(m.data))
        loss_s.append(_float(loss))
        logpZ_s.append(_float(logpZ))
        logDet_s.append(_float(fw_ldt))

        printr(
            "Iteration {}: loss: {:.6f} - b_norm: {:.6f} - logpZ: {:.6f} - logpZ1: {:.6f} - logpZ2: {:.6f} - log_det: {:.6f} - logpX: {:.6f}\n"
            .format(iteration + 1, _float(loss), _float(b_norm), _float(logpZ),
                    _float(logpZ1), _float(logpZ2), _float(fw_ldt),
                    _float(logpZ) - _float(fw_ldt)))

        if iteration % 100 == 99:
            # Periodic checkpoint: dump histories, the current perturbed
            # image, and a parameter snapshot, then reset the buffers.
            print(cur_x.shape)
            np.save(args.ckpt + '/' + str(j) + 'z.npy', z_s)
            np.save(args.ckpt + '/' + str(j) + 'b.npy', b_s)
            np.save(args.ckpt + '/' + str(j) + 'loss.npy', loss_s)
            np.save(args.ckpt + '/' + str(j) + 'logpZ.npy', logpZ_s)
            np.save(args.ckpt + '/' + str(j) + 'logDet.npy', logDet_s)

            cur_x = make_uint8(cur_x[0].data, num_bins_x)
            np.save(args.ckpt + '/' + str(j) + 'image.npy', cur_x)

            x_PIL = Image.fromarray(cur_x)
            x_PIL.save("./mask_imgs/trained.jpg")
            np.save(args.ckpt + '/' + str(j) + 'm.npy', m_s)

            # with encoder.reverse() as decoder:
            #     rx, _ = decoder.reverse_step(factor_z)
            #     rx_img = make_uint8(rx.data[0], num_bins_x)
            #     np.save(args.ckpt + '/'+str(j)+'res.npy', rx_img)
            z_s = []
            b_s = []
            loss_s = []
            logpZ_s = []
            logDet_s = []
            m_s = []
            j += 1
            epsilon.save(args.ckpt)
Пример #20
0
def main(
    dataset,
    dataset2,
    dataroot,
    download,
    augment,
    batch_size,
    eval_batch_size,
    nlls_batch_size,
    epochs,
    nb_step,
    saved_model,
    seed,
    hidden_channels,
    K,
    L,
    actnorm_scale,
    flow_permutation,
    flow_coupling,
    LU_decomposed,
    learn_top,
    y_condition,
    y_weight,
    max_grad_clip,
    max_grad_norm,
    lr,
    lr_test,
    n_workers,
    cuda,
    n_init_batches,
    output_dir,
    saved_optimizer,
    warmup,
    every_epoch,
):
    """Train a Glow model on ``dataset`` using pytorch-ignite.

    ``dataset2`` must have the same image shape as ``dataset``; the first
    ``nlls_batch_size`` test samples of each are collected into ``data1`` /
    ``data2`` (currently only consumed by the commented-out
    ``eval_likelihood`` handler at the bottom). Checkpoints go to
    ``output_dir``; training resumes from ``saved_model`` /
    ``saved_optimizer`` when given.
    """

    # Fall back to CPU when CUDA is unavailable or explicitly disabled.
    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"

    check_manual_seed(seed)

    # Each check_dataset call yields (image_shape, num_classes, train, test).
    ds = check_dataset(dataset, dataroot, augment, download)
    ds2 = check_dataset(dataset2, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds
    image_shape2, num_classes2, train_dataset_2, test_dataset_2 = ds2

    assert(image_shape == image_shape2)
    # Collect fixed evaluation samples from both test sets (targets unused).
    data1 = []
    data2 = []
    for k in range(nlls_batch_size):
        dataaux, targetaux = test_dataset[k]
        data1.append(dataaux)
        dataaux, targetaux = test_dataset_2[k]
        data2.append(dataaux)


    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
        drop_last=True,
    )
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=n_workers,
        drop_last=False,
    )

    model = Glow(
        image_shape,
        hidden_channels,
        K,
        L,
        actnorm_scale,
        flow_permutation,
        flow_coupling,
        LU_decomposed,
        num_classes,
        learn_top,
        y_condition,
    )

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    # Linear learning-rate warmup over the first ``warmup`` epochs.
    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    def step(engine, batch):
        """One optimisation step: forward, NLL (optionally class-conditional)
        loss, backward, gradient clipping, optimiser update."""
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses["total_loss"].backward()

        # Optional per-value clip and/or global-norm clip of gradients.
        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        """Evaluation step: per-sample (unreduced) losses, no gradients."""
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(
                    nll, y_logits, y_weight, y, multi_class, reduction="none"
                )
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction="none")

        return losses

    trainer = Engine(step)
    # Keep the two most recent {model, optimizer} checkpoints.
    checkpoint_handler = ModelCheckpoint(
        output_dir, "glow", n_saved=2, require_empty=False
    )

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {"model": model, "optimizer": optimizer},
    )

    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(
        trainer, "total_loss"
    )

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(
        lambda x, y: torch.mean(x),
        output_transform=lambda x: (
            x["total_loss"],
            torch.empty(x["total_loss"].shape[0]),
        ),
    ).attach(evaluator, "total_loss")

    if y_condition:
        monitoring_metrics.extend(["nll"])
        RunningAverage(output_transform=lambda x: x["nll"]).attach(trainer, "nll")

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(
            lambda x, y: torch.mean(x),
            output_transform=lambda x: (x["nll"], torch.empty(x["nll"].shape[0])),
        ).attach(evaluator, "nll")

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model)['model'])
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer)['opt'])

        # NOTE(review): the checkpoint filename suffix is divided by 1e3 to
        # recover the epoch -- presumably the suffix encodes thousandths of
        # an epoch (unlike the sibling trainer below which uses it raw);
        # confirm against the checkpoint naming scheme.
        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split("_")[-1])/1e3

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        # Data-dependent initialisation: run ``n_init_batches`` batches
        # through the model once (no gradients) before training starts.
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            print(train_loader)
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        # Validate on the test set, advance the LR schedule, print metrics.
        evaluator.run(test_loader)

        scheduler.step()
        metrics = evaluator.state.metrics

        losses = ", ".join([f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f"Validation Results - Epoch: {engine.state.epoch} {losses}")

    # Average per-iteration wall time, reset each epoch by print_times.
    timer = Timer(average=True)
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]"
        )
        timer.reset()

    # @trainer.on(Events.EPOCH_COMPLETED)
    # def eval_likelihood(engine):
    #     global_nlls(output_dir, engine.state.epoch, data1, data2, model, dataset1_name = dataset, dataset2_name = dataset2, nb_step = nb_step, every_epoch = every_epoch, optim_default = partial(optim.SGD, lr=1e-5, momentum = 0.))


    trainer.run(train_loader, epochs)
Пример #21
0
def main(args):
    """Sample images from pre-trained (possibly conditional) Glow models.

    For every seed in the comma-separated ``args.seed``, loads the "best"
    checkpoint and ``hparams.json`` from
    ``args.experiment_folder/<seed>/``, rebuilds the Glow model, and writes
    generated samples (individual PNGs plus grids) whose layout depends on
    which of ``y_condition`` / ``d_condition`` / ``yd_condition`` is set in
    the hyperparameters. Sampling temperature comes from
    ``args.temperature``.
    """
    seed_list = [int(item) for item in args.seed.split(',')]

    for seed in seed_list:

        device = torch.device("cuda")

        experiment_folder = args.experiment_folder + '/' + str(seed) + '/'
        print(experiment_folder)
        #model_name = 'glow_checkpoint_'+ str(args.chk)+'.pth'
        # Pick the checkpoint whose filename contains 'best'.
        # NOTE(review): if no such file exists, ``model_name`` is unbound
        # and the print below raises NameError -- confirm folders always
        # contain a 'best' checkpoint.
        for thing in os.listdir(experiment_folder):
            if 'best' in thing:
                model_name = thing
        print(model_name)

        # Seed every RNG source for reproducible sampling.
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch.random.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

        with open(experiment_folder + 'hparams.json') as json_file:
            hparams = json.load(json_file)

        # Class/domain counts depend on the conditioning mode the model was
        # trained with.
        image_shape = (32, 32, 3)
        if hparams['y_condition']:
            num_classes = 2
            num_domains = 0
        elif hparams['d_condition']:
            num_classes = 10
            num_domains = 0
        elif hparams['yd_condition']:
            num_classes = 2
            num_domains = 10
        else:
            num_classes = 2
            num_domains = 0

        model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
                     hparams['L'], hparams['actnorm_scale'],
                     hparams['flow_permutation'], hparams['flow_coupling'],
                     hparams['LU_decomposed'], num_classes, num_domains,
                     hparams['learn_top'], hparams['y_condition'],
                     hparams['extra_condition'], hparams['sp_condition'],
                     hparams['d_condition'], hparams['yd_condition'])
        print('loading model')
        model.load_state_dict(torch.load(experiment_folder + model_name))
        model.set_actnorm_init()

        model = model.to(device)

        model = model.eval()

        if hparams['y_condition']:
            print('y_condition')

            def sample(model, temp=args.temperature):
                """Sample 1000 images per class (class 0 / class 1) by
                reversing the flow with one-hot class conditioning."""
                with torch.no_grad():
                    if hparams['y_condition']:
                        print("extra", hparams['extra_condition'])
                        y = torch.eye(num_classes)
                        y = torch.cat(1000 * [y])
                        print(y.size())
                        # Alternate rows of the stacked identity split into
                        # the two class one-hots.
                        y_0 = y[::2, :].to(
                            device)  # number hardcoded in model for now
                        y_1 = y[1::2, :].to(device)
                        print(y_0.size())
                        print(y_0)
                        print(y_1)
                        print(y_1.size())
                        images0 = model(z=None,
                                        y_onehot=y_0,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                        images1 = model(z=None,
                                        y_onehot=y_1,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                return images0, images1

            images0, images1 = sample(model)

            # Per-class sample dumps plus an 8x8 preview grid for each class.
            os.makedirs(experiment_folder + 'generations/Uninfected',
                        exist_ok=True)
            os.makedirs(experiment_folder + 'generations/Parasitized',
                        exist_ok=True)
            for i in range(images0.size(0)):
                torchvision.utils.save_image(
                    images0[i, :, :, :], experiment_folder +
                    'generations/Uninfected/sample_{}.png'.format(i))
                torchvision.utils.save_image(
                    images1[i, :, :, :], experiment_folder +
                    'generations/Parasitized/sample_{}.png'.format(i))
            images_concat0 = torchvision.utils.make_grid(images0[:64, :, :, :],
                                                         nrow=int(64**0.5),
                                                         padding=2,
                                                         pad_value=255)
            torchvision.utils.save_image(images_concat0,
                                         experiment_folder + '/uninfected.png')
            images_concat1 = torchvision.utils.make_grid(images1[:64, :, :, :],
                                                         nrow=int(64**0.5),
                                                         padding=2,
                                                         pad_value=255)
            torchvision.utils.save_image(
                images_concat1, experiment_folder + '/parasitized.png')

        elif hparams['d_condition']:
            print('d_cond')

            def sample_d(model, idx, batch_size=1000, temp=args.temperature):
                """Sample ``batch_size`` images conditioned on domain
                ``idx`` (one-hot over the 10 domains)."""
                with torch.no_grad():
                    if hparams['d_condition']:

                        y_0 = torch.zeros([batch_size, 10], device='cuda:0')
                        y_0[:, idx] = torch.ones(batch_size)
                        y_0.to(device)
                        print(y_0)

                        # y_1 = torch.zeros([batch_size, 201], device='cuda:0')
                        # y_1[:, 157] = torch.ones(batch_size)
                        # y_1.to(device)
                        # y = torch.eye(num_classes)
                        # y = torch.cat(1000 * [y])
                        # print(y.size())
                        # y_0 = y[::2, :].to(device)  # number hardcoded in model for now
                        # y_1 = y[1::2, :].to(device)
                        # print(y_0.size())
                        # print(y_0)
                        # print(y_1)
                        # print(y_1.size())

                        images0 = model(z=None,
                                        y_onehot=y_0,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                        # images1 = model(z=None, y_onehot=y_1, temperature=1.0, reverse=True, batch_size=1000)
                return images0

            # One sampling pass per named malaria-slide domain.
            for idx, dom in enumerate(["C116P77ThinF", "C132P93ThinF", "C137P98ThinF", "C180P141NThinF", "C182P143NThinF", \
                             "C184P145ThinF", "C39P4thinF", 'C59P20thinF', "C68P29N", "C99P60ThinF"]):

                images0 = sample_d(model, idx)

                os.makedirs(experiment_folder + 'generations/' + dom +
                            '/Uninfected/',
                            exist_ok=True)
                os.makedirs(experiment_folder + 'generations/' + dom +
                            '/Parasitized/',
                            exist_ok=True)
                # os.makedirs(experiment_folder + 'generations/C59P20thinF/Uninfected/', exist_ok=True)
                # os.makedirs(experiment_folder + 'generations/C59P20thinF/Parasitized/', exist_ok=True)
                for i in range(images0.size(0)):
                    torchvision.utils.save_image(
                        images0[i, :, :, :],
                        experiment_folder + 'generations/' + dom +
                        '/Uninfected/sample_{}.png'.format(i))
                    #torchvision.utils.save_image(images1[i, :, :, :], experiment_folder + 'generations/C59P20thinF/Parasitized/sample_{}.png'.format(i))
                images_concat0 = torchvision.utils.make_grid(
                    images0[:25, :, :, :],
                    nrow=int(25**0.5),
                    padding=2,
                    pad_value=255)
                torchvision.utils.save_image(images_concat0,
                                             experiment_folder + dom + '.png')
                # images_concat1 = torchvision.utils.make_grid(images1[:64,:,:,:], nrow=int(64 ** 0.5), padding=2, pad_value=255)
                # torchvision.utils.save_image(images_concat1, experiment_folder + 'C59P20thinF.png')

        elif hparams['yd_condition']:

            def sample_YD(model, idx, batch_size=1000, temp=args.temperature):
                """Sample images jointly conditioned on class (dims 0-1) and
                domain ``idx`` (dims 2-11) of a 12-dim one-hot pair."""
                with torch.no_grad():
                    if hparams['yd_condition']:
                        y_0 = torch.zeros([batch_size, 12], device='cuda:0')
                        y_0[:, 0] = torch.ones(batch_size)
                        y_0[:, idx + 2] = torch.ones(batch_size)
                        y_0.to(device)
                        print(y_0)

                        y_1 = torch.zeros([batch_size, 12], device='cuda:0')
                        y_1[:, 1] = torch.ones(batch_size)
                        y_1[:, idx + 2] = torch.ones(batch_size)
                        y_1.to(device)
                        print(y_1)

                        images0 = model(z=None,
                                        y_onehot=y_0,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                        images1 = model(z=None,
                                        y_onehot=y_1,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                return images0, images1

            def sample_DD(model, idx, batch_size=1000, temp=args.temperature):
                """Variant of sample_YD using a 20-dim conditioning vector
                (domain ``idx`` vs ``idx + 10``); currently unused below."""
                with torch.no_grad():
                    if hparams['yd_condition']:
                        y_1 = torch.zeros([batch_size, 20], device='cuda:0')
                        y_1[:, idx] = torch.ones(batch_size)
                        y_1.to(device)
                        print(y_1)

                        y_0 = torch.zeros([batch_size, 20], device='cuda:0')
                        y_0[:, idx + 10] = torch.ones(batch_size)
                        y_0.to(device)
                        print(y_0)

                        images0 = model(z=None,
                                        y_onehot=y_0,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                        images1 = model(z=None,
                                        y_onehot=y_1,
                                        temperature=temp,
                                        reverse=True,
                                        batch_size=1000)
                return images0, images1

            for idx, dom in enumerate(
                    ["C116P77ThinF", "C132P93ThinF", "C137P98ThinF", "C180P141NThinF", "C182P143NThinF", \
                     "C184P145ThinF", "C39P4thinF", 'C59P20thinF', "C68P29N", "C99P60ThinF"]):

                images0, images1 = sample_YD(model, idx)

                os.makedirs(experiment_folder + 'generations/' + dom +
                            '/Uninfected/',
                            exist_ok=True)
                os.makedirs(experiment_folder + 'generations/' + dom +
                            '/Parasitized/',
                            exist_ok=True)
                for i in range(images0.size(0)):
                    torchvision.utils.save_image(
                        images0[i, :, :, :],
                        experiment_folder + 'generations/' + dom +
                        '/Uninfected/sample_{}.png'.format(i))
                    torchvision.utils.save_image(
                        images1[i, :, :, :],
                        experiment_folder + 'generations/' + dom +
                        '/Parasitized/sample_{}.png'.format(i))
                images_concat0 = torchvision.utils.make_grid(
                    images0[:64, :, :, :],
                    nrow=int(64**0.5),
                    padding=2,
                    pad_value=255)
                torchvision.utils.save_image(
                    images_concat0, experiment_folder + dom +
                    str(args.temperature) + '_uninfected.png')
                images_concat1 = torchvision.utils.make_grid(
                    images1[:64, :, :, :],
                    nrow=int(64**0.5),
                    padding=2,
                    pad_value=255)
                torchvision.utils.save_image(
                    images_concat1, experiment_folder + dom +
                    str(args.temperature) + '_parasitized.png')

        else:

            def sample(model, temp=args.temperature):
                """Unconditional sampling: 1000 images at the requested
                temperature."""
                with torch.no_grad():
                    images = model(z=None,
                                   y_onehot=None,
                                   temperature=temp,
                                   reverse=True,
                                   batch_size=1000)

                return images

            images = sample(model)

            os.makedirs('unconditioned/' + str(seed) + '/generations/' +
                        experiment_folder[:-3],
                        exist_ok=True)
            for i in range(images.size(0)):
                torchvision.utils.save_image(
                    images[i, :, :, :],
                    'unconditioned/' + str(seed) + '/generations/' +
                    experiment_folder[:-2] + 'sample_{}.png'.format(i))

            images_concat = torchvision.utils.make_grid(images[:64, :, :, :],
                                                        nrow=int(64**0.5),
                                                        padding=2,
                                                        pad_value=255)
            torchvision.utils.save_image(
                images_concat, 'unconditioned/' + str(seed) + '/' +
                experiment_folder[:-3] + '.png')
Пример #22
0
    # ipdb.set_trace()
    with torch.no_grad():
        iss, _, _, acts_real = inception_score(x_is, cuda=True, batch_size=32, resize=True, splits=10, return_preds=True)
    print(iss)
    # Model samples
    coupling = 'affine'
    sn_ = 0
    loss_ = 'mle'
    # for (coupling, sn_) in product(['affine','additive'], [0,1]):
    exp_dir = f'/scratch/gobi2/wangkuan/glow/rebuttal-guess/{loss_}-128-32-{coupling}-{sn_}'
    with open(os.path.join(exp_dir, 'hparams.json'), 'r') as f:
        params = json.load(f)
    locals().update(params)
    
    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition,logittransform,sn)
    model = model.to(device)
    # ipdb.set_trace()

    def generate_from_noise(batch_size):
        _, c2, h, w  = model.prior_h.shape
        c = c2 // 2
        zshape = (batch_size, c, h, w)
        randz  = torch.randn(zshape).to(device)
        randz  = torch.autograd.Variable(randz, requires_grad=True)
        images = model(z= randz, y_onehot=None, temperature=1, reverse=True,batch_size=batch_size)   
        return images

    # ipdb.set_trace()
    iteration_fieldnames = ['global_iteration', 'fid']
Пример #23
0
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         fresh):
    """Train a Glow model on ``dataset`` with pytorch-ignite.

    Builds train/test loaders, runs data-dependent actnorm initialisation on
    the first ``n_init_batches`` batches, trains for ``epochs`` epochs with
    Adamax and optional gradient clipping, evaluates each epoch, and
    checkpoints {model, optimizer} to ``output_dir``. Resumes from
    ``saved_model`` / ``saved_optimizer`` when provided.
    """

    # Fall back to CPU when CUDA is unavailable or explicitly disabled.
    device = 'cpu' if (not torch.cuda.is_available() or not cuda) else 'cuda:0'

    check_manual_seed(seed)

    # check_dataset yields (image_shape, num_classes, train, test).
    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=eval_batch_size,
                                  shuffle=False,
                                  num_workers=n_workers,
                                  drop_last=False)

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition)

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    def step(engine, batch):
        """One optimisation step: forward, NLL (optionally class-conditional)
        loss, backward, gradient clipping, optimiser update."""
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses['total_loss'].backward()

        # Optional per-value clip and/or global-norm clip of gradients.
        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        """Evaluation step: per-sample (unreduced) losses, no gradients."""
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    trainer = Engine(step)
    # Save every epoch, keeping the two most recent checkpoints.
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         'glow',
                                         save_interval=1,
                                         n_saved=2,
                                         require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x:
         (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(
             evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(
            trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x:
             (x['nll'], torch.empty(x['nll'].shape[0]))).attach(
                 evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        # Recover the epoch number from the checkpoint filename suffix so
        # training resumes where it left off.
        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        # Data-dependent initialisation: run ``n_init_batches`` batches
        # through the model once (no gradients) before training starts.
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        # Validate on the test set and print the aggregated metrics.
        evaluator.run(test_loader)
        metrics = evaluator.state.metrics

        losses = ', '.join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    # Average per-iteration wall time, reset each epoch by print_times.
    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]'
        )
        timer.reset()

    trainer.run(train_loader, epochs)
Пример #24
0
def main():
    """Encode every dataset image with Glow, then decode it back.

    For each image the script records the original image, its flattened
    latent code, the forward log-determinant, two negative-log-likelihood
    estimates and the decoder reconstruction.  After 100 images all buffers
    are dumped to ``<i>/*.npy`` and the function returns.
    """
    import os

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    hyperparams = Hyperparameters(args.snapshot_path)
    hyperparams.print()

    # Number of discrete values per channel after num_bits_x quantization.
    num_bins_x = 2.0**hyperparams.num_bits_x
    num_pixels = 3 * hyperparams.image_size[0] * hyperparams.image_size[1]

    # Load the dataset: individual .png images or batched .npy arrays.
    assert args.dataset_format in ["png", "npy"]
    files = Path(args.dataset_path).glob("*.{}".format(args.dataset_format))
    if args.dataset_format == "png":
        images = []
        for filepath in files:
            image = np.array(Image.open(filepath)).astype("float32")
            image = preprocess(image, hyperparams.num_bits_x)
            images.append(image)
        assert len(images) > 0
        images = np.asanyarray(images)
    else:  # "npy": each file holds a batch; flatten into one image array
        images = []
        for filepath in files:
            array = np.load(filepath).astype("float32")
            array = preprocess(array, hyperparams.num_bits_x)
            images.append(array)
        assert len(images) > 0
        num_files = len(images)
        images = np.asanyarray(images)
        images = images.reshape((num_files * images.shape[1], ) +
                                images.shape[2:])

    dataset = glow.dataset.Dataset(images)
    iterator = glow.dataset.Iterator(dataset, batch_size=1)

    print(tabulate([["#image", len(dataset)]]))

    encoder = Glow(hyperparams, hdf5_path=args.snapshot_path)
    if using_gpu:
        encoder.to_gpu()

    ori_x = []
    enc_z = []
    rev_x = []
    fw_logdet = []
    logpZ = []
    logpZ2 = []
    i = 0

    # BUG FIX: the original used ``with A() and B() as decoder`` which only
    # enters ``B()`` (``and`` merely evaluates A's truthiness), so backprop
    # was never actually disabled.  Enter both context managers explicitly.
    with chainer.no_backprop_mode(), encoder.reverse() as decoder:
        for data_indices in iterator:
            i += 1

            x = to_gpu(dataset[data_indices])  # 1x3x64x64
            # Dequantize: add uniform noise spanning one quantization bin.
            x += xp.random.uniform(0, 1.0 / num_bins_x, size=x.shape)
            x_img = make_uint8(x[0], num_bins_x)
            ori_x.append(x_img)

            factorized_z_distribution, fw_ldt = encoder.forward_step(x)
            # Discrete-likelihood correction: subtract log bin-width/pixel.
            fw_ldt -= math.log(num_bins_x) * num_pixels
            fw_logdet.append(cupy.asnumpy(fw_ldt.data))

            factor_z = []
            ez = []
            nll = 0
            for (zi, mean, ln_var) in factorized_z_distribution:
                nll += cf.gaussian_nll(zi, mean, ln_var)
                factor_z.append(zi.data)
                ez.append(zi.data.reshape(-1, ))

            ez = np.concatenate(ez)
            # NOTE(review): ``ez.get()`` and ``cupy.asnumpy`` assume GPU
            # arrays; this path presumably always runs with a GPU — confirm.
            enc_z.append(ez.get())
            logpZ.append(cupy.asnumpy(nll.data))
            # Second NLL estimate under a single fitted scalar Gaussian.
            logpZ2.append(
                cupy.asnumpy(
                    cf.gaussian_nll(ez, np.mean(ez), np.log(np.var(ez))).data))

            rx, _ = decoder.reverse_step(factor_z)
            rx_img = make_uint8(rx.data[0], num_bins_x)
            rev_x.append(rx_img)

            if i % 100 == 0:
                # BUG FIX: np.save does not create directories, so writing
                # to "<i>/..." crashed; create the output directory first.
                out_dir = str(i)
                os.makedirs(out_dir, exist_ok=True)
                np.save(os.path.join(out_dir, 'ori_x.npy'), ori_x)
                np.save(os.path.join(out_dir, 'fw_logdet.npy'),
                        np.array(fw_logdet))
                np.save(os.path.join(out_dir, 'enc_z.npy'), enc_z)
                np.save(os.path.join(out_dir, 'logpZ.npy'), np.array(logpZ))
                np.save(os.path.join(out_dir, 'logpZ2.npy'), np.array(logpZ2))
                np.save(os.path.join(out_dir, 'rev_x.npy'), rev_x)
                # The original cleared all buffers and then returned; the
                # clearing was dead code, so only the return is kept.
                return
Пример #25
0
                print(
                    ind,
                    args.delta,
                    log_p[i].item(),
                    logdet[i].item(),
                    train_labels[ind].item(),
                    file=f,
                )
            if ind >= 9999:
                break
    f.close()


if __name__ == "__main__":
    # Parse CLI options and echo them for the run log.
    args = parser.parse_args()
    print(string_args(args))
    device = args.device

    # Rebuild the Glow architecture, restore the trained weights, and run
    # the evaluation entry point.
    model_single = Glow(args.n_channels,
                        args.n_flow,
                        args.n_block,
                        affine=args.affine,
                        conv_lu=not args.no_lu)
    model = model_single
    model.load_state_dict(torch.load(args.model_path))
    model = model.to(device)
    test(args, model)
Пример #26
0
                        model_single.reverse(z_sample).cpu().data,
                        f'sample/{str(i + 1).zfill(6)}.png',
                        normalize=True,
                        nrow=10,
                        range=(-0.5, 0.5),
                    )

            if i % 10000 == 0:
                torch.save(model.state_dict(),
                           f'checkpoint/model_{str(i + 1).zfill(6)}.pt')
                torch.save(optimizer.state_dict(),
                           f'checkpoint/optim_{str(i + 1).zfill(6)}.pt')


if __name__ == '__main__':
    # NOTE(review): this guard appears to fuse two different scripts — a
    # torch Glow training entry point followed by a conditional-Glow /
    # CelebA attribute setup.  Several names used below (device,
    # output_folder, model_name, CelebALoader, DataLoader, json) must be
    # defined elsewhere in the file — confirm before running.
    args = parser.parse_args()
    print(args)

    # Unconditional Glow over 3-channel images, optionally without the
    # LU-decomposed invertible 1x1 convolutions.
    model_single = Glow(3,
                        args.n_flow,
                        args.n_block,
                        affine=args.affine,
                        conv_lu=not args.no_lu)
    # Wrap for multi-GPU data parallelism.
    model = nn.DataParallel(model_single)
    # model = model_single
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    train(args, model, optimizer)
    # --- second fragment: load a trained conditional Glow for evaluation ---
    with open(output_folder + 'hparams.json') as json_file:
        hparams = json.load(json_file)

    image_shape = (64, 64, 3)
    # 40 matches the CelebA attribute count referenced in the comments below.
    num_classes = 40
    Batch_Size = 4
    dataset_test = CelebALoader(
        root_folder=hparams['dataroot']
    )  #'/home/yellow/deep-learning-and-practice/hw7/dataset/task_2/'
    test_loader = DataLoader(dataset_test,
                             batch_size=Batch_Size,
                             shuffle=False,
                             drop_last=True)
    model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
                 hparams['L'], hparams['actnorm_scale'],
                 hparams['flow_permutation'], hparams['flow_coupling'],
                 hparams['LU_decomposed'], num_classes, hparams['learn_top'],
                 hparams['y_condition'])

    # Load weights on CPU first, mark ActNorm layers as already initialized
    # (skips data-dependent init), then move to the device in eval mode.
    model.load_state_dict(
        torch.load(output_folder + model_name, map_location="cpu")['model'])
    model.set_actnorm_init()
    model = model.to(device)
    model = model.eval()

    # Indices of the CelebA attributes to manipulate.
    # attribute_list = [8] # Black_Hair
    attribute_list = [20, 31, 33]  # Male, Smiling, Wavy_Hair, 24z    No_Beard
    # attribute_list = [11, 26, 31, 8, 6, 7] # Brown_Hair, Pale_Skin, Smiling, Black_Hair, Big_Lips, Big_Nose
    # attribute_list = [i for i in range(40)]
    N = 8
    # One empty GPU tensor per selected attribute, to accumulate latents.
    z_pos_list = [torch.Tensor([]).cuda() for i in range(len(attribute_list))]
Пример #28
0
def main():
    """Visualize how each factor-out level contributes to Glow samples.

    Each refresh draws one latent and decodes it as if the model had
    1..levels-1 levels, plotting every partial reconstruction next to the
    full-depth decode so the effect of each split level can be compared.
    """
    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    hyperparams = Hyperparameters(args.snapshot_path)
    hyperparams.print()

    # Number of discrete values per channel after num_bits_x quantization.
    num_bins_x = 2.0**hyperparams.num_bits_x

    encoder = Glow(hyperparams, hdf5_path=args.snapshot_path)
    if using_gpu:
        encoder.to_gpu()

    # One subplot per level, side by side.
    total = hyperparams.levels
    fig = plt.figure(figsize=(4 * total, 4))
    subplots = []
    for n in range(total):
        subplot = fig.add_subplot(1, total, n + 1)
        subplots.append(subplot)

    def sample_z(seed):
        # Re-seeding makes every decode within one refresh use the same
        # latent, so the subplots are directly comparable.
        xp.random.seed(seed)
        # BUG FIX: numpy's random.normal has no ``dtype`` argument (only
        # cupy's does), so the original raised TypeError on CPU; cast the
        # result instead, matching the other sampling scripts in this file.
        return xp.random.normal(
            0,
            args.temperature,
            size=(
                1,
                3,
            ) + hyperparams.image_size).astype("float32")

    # BUG FIX: the original ``with A() and B() as decoder`` only entered
    # ``B()``; enter both context managers so backprop really is disabled.
    with chainer.no_backprop_mode(), encoder.reverse() as decoder:
        while True:
            seed = int(time.time())

            for level in range(1, hyperparams.levels):
                z = sample_z(seed)
                factorized_z = glow.nn.functions.factor_z(
                    z, level + 1, squeeze_factor=hyperparams.squeeze_factor)

                # Start from the deepest factor and run the reverse blocks
                # down to the requested level.
                out = glow.nn.functions.unsqueeze(
                    factorized_z.pop(-1),
                    factor=hyperparams.squeeze_factor,
                    module=xp)
                for n, zi in enumerate(factorized_z[::-1]):
                    block = encoder.blocks[level - n - 1]
                    out, _ = block.reverse_step(
                        out,
                        gaussian_eps=zi,
                        squeeze_factor=hyperparams.squeeze_factor)
                rev_x = out
                rev_x_img = make_uint8(rev_x.data[0], num_bins_x)
                subplot = subplots[level - 1]
                subplot.imshow(rev_x_img, interpolation="none")
                subplot.set_title("level = {}".format(level))

            # Full-depth decode with the same seed, for comparison.
            z = sample_z(seed)
            factorized_z = encoder.factor_z(z)
            rev_x, _ = decoder.reverse_step(factorized_z)
            rev_x_img = make_uint8(rev_x.data[0], num_bins_x)
            subplot = subplots[-1]
            subplot.imshow(rev_x_img, interpolation="none")
            subplot.set_title("level = {}".format(hyperparams.levels))

            plt.pause(.01)
Пример #29
0
                            ldtv,
                    ) in zip(log_p_val, logdet_val, log_p_train_val,
                             logdet_train_val):
                        print(
                            args.delta,
                            lpv.item(),
                            ldv.item(),
                            lptv.item(),
                            ldtv.item(),
                            file=f_ll,
                        )
                f_ll.close()
    f_train_loss.close()
    f_test_loss.close()


if __name__ == "__main__":
    # Parse CLI options and echo them for the run log.
    args = parser.parse_args()
    print(string_args(args))
    device = args.device

    # Build the flow model on the target device and train it with Adam.
    model = Glow(args.n_channels,
                 args.n_flow,
                 args.n_block,
                 affine=args.affine,
                 conv_lu=not args.no_lu).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    train(args, model, optimizer)
Пример #30
0
def main():
    """Train a Glow model on MNIST, saving snapshots after each iteration."""
    # Best-effort creation of the snapshot directory.
    # NOTE(review): the bare ``except: pass`` also hides real failures such
    # as permission errors — os.makedirs(..., exist_ok=True) would be safer.
    try:
        os.mkdir(args.snapshot_path)
    except:
        pass

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    # Number of discrete values per channel after num_bits_x quantization.
    num_bins_x = 2**args.num_bits_x

    image_size = (28, 28)

    # Scale MNIST to [0, 255] (chainer's loader presumably returns [0, 1]
    # floats — confirm) and reshape to NHWC with a single channel.
    images = chainer.datasets.mnist.get_mnist(withlabel=False)[0]
    images = 255.0 * np.asarray(images).reshape((-1, ) + image_size + (1, ))
    if args.num_channels != 1:
        # Replicate the grayscale channel to the requested channel count.
        images = np.broadcast_to(
            images, (images.shape[0], ) + image_size + (args.num_channels, ))
    images = preprocess(images, args.num_bits_x)

    x_mean = np.mean(images)
    x_var = np.var(images)

    dataset = glow.dataset.Dataset(images)
    iterator = glow.dataset.Iterator(dataset, batch_size=args.batch_size)

    print(tabulate([
        ["#", len(dataset)],
        ["mean", x_mean],
        ["var", x_var],
    ]))

    # Persist the hyperparameters next to the weights so that the sampling
    # and evaluation scripts can reload the same configuration.
    hyperparams = Hyperparameters(args.snapshot_path)
    hyperparams.levels = args.levels
    hyperparams.depth_per_level = args.depth_per_level
    hyperparams.nn_hidden_channels = args.nn_hidden_channels
    hyperparams.image_size = image_size
    hyperparams.num_bits_x = args.num_bits_x
    hyperparams.lu_decomposition = args.lu_decomposition
    hyperparams.num_image_channels = args.num_channels
    hyperparams.save(args.snapshot_path)
    hyperparams.print()

    encoder = Glow(hyperparams, hdf5_path=args.snapshot_path)
    if using_gpu:
        encoder.to_gpu()

    optimizer = Optimizer(encoder)

    # Data dependent initialization
    # A single batch is used to initialize the ActNorm weights.
    if encoder.need_initialize:
        for batch_index, data_indices in enumerate(iterator):
            x = to_gpu(dataset[data_indices])
            encoder.initialize_actnorm_weights(
                x, reduce_memory=args.reduce_memory)
            break

    current_training_step = 0
    num_pixels = args.num_channels * hyperparams.image_size[0] * hyperparams.image_size[1]

    # Training loop
    for iteration in range(args.total_iteration):
        sum_loss = 0
        sum_nll = 0
        start_time = time.time()

        for batch_index, data_indices in enumerate(iterator):
            x = to_gpu(dataset[data_indices])
            # Dequantize: add uniform noise spanning one quantization bin.
            x += xp.random.uniform(0, 1.0 / num_bins_x, size=x.shape)

            # Normalizer converting the loss from nats to bits per pixel.
            denom = math.log(2.0) * num_pixels

            factorized_z_distribution, logdet = encoder.forward_step(
                x, reduce_memory=args.reduce_memory)

            # Discrete-likelihood correction: subtract log bin-width/pixel.
            logdet -= math.log(num_bins_x) * num_pixels

            # Sum the Gaussian NLL of every factored-out latent.
            negative_log_likelihood = 0
            for (zi, mean, ln_var) in factorized_z_distribution:
                negative_log_likelihood += cf.gaussian_nll(zi, mean, ln_var)

            # Batch-averaged loss in bits per pixel.
            loss = (negative_log_likelihood / args.batch_size - logdet) / denom

            encoder.cleargrads()
            loss.backward()
            optimizer.update(current_training_step)
            current_training_step += 1

            sum_loss += float(loss.data)
            sum_nll += float(negative_log_likelihood.data) / args.batch_size
            printr(
                "Iteration {}: Batch {} / {} - loss: {:.8f} - nll: {:.8f} - log_det: {:.8f}".
                format(
                    iteration + 1, batch_index + 1, len(iterator),
                    float(loss.data),
                    float(negative_log_likelihood.data) / args.batch_size /
                    denom,
                    float(logdet.data) / denom))

        log_likelihood = -sum_nll / len(iterator)
        elapsed_time = time.time() - start_time
        # "\033[2K" clears the progress line printed by printr above.
        print(
            "\033[2KIteration {} - loss: {:.5f} - log_likelihood: {:.5f} - elapsed_time: {:.3f} min".
            format(iteration + 1, sum_loss / len(iterator), log_likelihood,
                elapsed_time / 60))
        encoder.save(args.snapshot_path)