Ejemplo n.º 1
0
def plot_cat_traversal(model: infogan.InfoGAN, nrow, cat_mapping=None):
    """Plot generator samples while sweeping the categorical latent code.

    ``nrow`` latent vectors are sampled once and reused for every category;
    group ``d`` (``nrow`` consecutive samples) has only the one-hot entry
    ``model.cat_idx[order[d]]`` set, where ``order`` is described below.

    Args:
        model: InfoGAN exposing ``cat_dim``, ``cat_idx``, ``sample_latent``
            and ``gen``.
        nrow: Number of latent samples drawn (and passed to ``plot_grid`` as
            ``nrow``).
        cat_mapping: Optional array. When given, ``order`` is
            ``np.argsort(cat_mapping)``; otherwise ``order`` is
            ``0..cat_dim-1``. (If ``cat_mapping[k]`` is the label matched to
            code ``k`` and it is a permutation, groups then appear in label
            order — assumption, confirm against the caller.)
    """
    n_cat = model.cat_dim
    if cat_mapping is None:
        order = np.arange(n_cat)
    else:
        order = np.argsort(cat_mapping)

    latent = model.sample_latent(nrow).repeat(n_cat, 1)
    # Clear the whole categorical block, then set exactly one entry per group.
    latent[:, model.cat_idx] = 0
    for group in range(n_cat):
        start = group * nrow
        latent[start:start + nrow, model.cat_idx[order[group]]] = 1
    samples = model.gen(latent).detach()

    fig, axs = plot_grid(samples, nrow=nrow, figsize=(n_cat, nrow),
                         gridspec_kw=dict(wspace=0, hspace=0))
    # plt.suptitle(f"$c_1$: Categorical ({cat_dim})")
    first_ax = axs[0, 0]
    last_ax = axs[-1, 0]
    for ax in (first_ax, last_ax):
        _prep_ax(ax)
    tick_kw = dict(ha='center', va='bottom', size=_TICK_LABEL_SIZE)
    first_ax.set_xlabel('$(1)$', **tick_kw)
    last_ax.set_xlabel(f'$({n_cat})$', **tick_kw)

    top = first_ax.get_position().y1
    fig.text(.5, top, '$c_1$', ha='center', va='bottom', size=_VAR_LABEL_SIZE)
Ejemplo n.º 2
0
def plot_cont_traversal(model: infogan.InfoGAN, c, nrow, nstep=9):
    """Plot generator samples while sweeping one continuous code over [-2, 2].

    ``nrow`` latent vectors are sampled once and repeated ``nstep`` times;
    the ``k``-th group of ``nrow`` consecutive samples has continuous code
    ``c`` set to the ``k``-th of ``nstep`` evenly spaced values.

    Args:
        model: InfoGAN exposing ``cont_idx``, ``device``, ``sample_latent``
            and ``gen``.
        c: Index into ``model.cont_idx``; labelled ``c_{c + 2}`` in the plot.
        nrow: Number of latent samples drawn (passed to ``plot_grid``).
        nstep: Number of values in the sweep.
    """
    sweep = torch.linspace(-2, 2, nstep).to(model.device)
    latent = model.sample_latent(nrow).repeat(nstep, 1)
    # Row k of ``latent`` belongs to group k // nrow, so it gets sweep[k // nrow].
    latent[:, model.cont_idx[c]] = sweep.repeat_interleave(nrow)
    samples = model.gen(latent).detach()

    fig, axs = plot_grid(samples, nrow=nrow, figsize=(nstep, nrow),
                         gridspec_kw=dict(wspace=0, hspace=0))
    # plt.suptitle(f"$c_{{{c + 2}}}$: Continuous (-2 to 2)")

    first_ax = axs[0, 0]
    last_ax = axs[-1, 0]
    for ax in (first_ax, last_ax):
        _prep_ax(ax)
    tick_kw = dict(ha='center', va='bottom', size=_TICK_LABEL_SIZE)
    first_ax.set_xlabel(f'${sweep[0]:+g}$', **tick_kw)
    last_ax.set_xlabel(f'${sweep[-1]:+g}$', **tick_kw)

    top = first_ax.get_position().y1
    fig.text(.5, top, f'$c_{{{c + 2}}}$', ha='center', va='bottom',
             size=_VAR_LABEL_SIZE)
Ejemplo n.º 3
0
def get_cat_mapping(model: infogan.InfoGAN, data_loader: DataLoader,
                    num_classes=10):
    """Match each categorical latent code to the dataset label it predicts most.

    Real batches are encoded with ``model.dis`` followed by ``model.rec``.
    The arg-max of the categorical logits is tallied against the true labels
    in a ``num_classes x num_classes`` confusion matrix. Rows are labels and
    columns are predicted codes.

    Args:
        model: InfoGAN whose ``dis``/``rec`` networks encode the data.
        data_loader: Iterable of ``(data, labels)`` batches. ``data`` is moved
            to ``model.device``, given a channel axis via ``unsqueeze(1)`` and
            scaled by 1/255 (presumably uint8 images of shape (batch, H, W) —
            confirm against the loader).
        num_classes: Number of labels and of categorical codes. Defaults to
            10, the original hard-coded value.

    Returns:
        NumPy integer array of length ``num_classes``; entry ``k`` is the
        label that co-occurs most often with predicted code ``k``.
    """
    eye = torch.eye(num_classes)
    confusion = torch.zeros(num_classes, num_classes)
    # Pure inference: don't build an autograd graph for every batch.
    with torch.no_grad():
        for data, labels in data_loader:
            real_data = data.to(model.device).unsqueeze(1).float() / 255.
            cat_logits = model.rec(model.dis(real_data)[1])[0]
            predicted = cat_logits.cpu().argmax(1)
            # One-hot(labels)^T @ one-hot(predicted) counts (label, code) pairs.
            confusion += eye[labels.long()].t() @ eye[predicted]
    return confusion.argmax(0).numpy()
Ejemplo n.º 4
0
def get_cat_mapping(model: infogan.InfoGAN, data_loader: DataLoader,
                    num_classes=10):
    """Match each categorical latent code to the dataset label it predicts most.

    Real batches are encoded with ``model.dis`` followed by ``model.rec``.
    The arg-max of the categorical logits is tallied against the true labels
    in a ``num_classes x num_classes`` confusion matrix. Rows are labels and
    columns are predicted codes.

    Args:
        model: InfoGAN whose ``dis``/``rec`` networks encode the data.
        data_loader: Iterable of ``(data, labels)`` batches. ``data`` is moved
            to ``model.device``, given a channel axis via ``unsqueeze(1)`` and
            scaled by 1/255 (presumably uint8 images of shape (batch, H, W) —
            confirm against the loader).
        num_classes: Number of labels and of categorical codes. Defaults to
            10, the original hard-coded value.

    Returns:
        NumPy integer array of length ``num_classes``; entry ``k`` is the
        label that co-occurs most often with predicted code ``k``.
    """
    eye = torch.eye(num_classes)
    confusion = torch.zeros(num_classes, num_classes)
    # Pure inference: don't build an autograd graph for every batch.
    with torch.no_grad():
        for data, labels in data_loader:
            real_data = data.to(model.device).unsqueeze(1).float() / 255.
            cat_logits = model.rec(model.dis(real_data)[1])[0]
            predicted = cat_logits.cpu().argmax(1)
            # One-hot(labels)^T @ one-hot(predicted) counts (label, code) pairs.
            confusion += eye[labels.long()].t() @ eye[predicted]
    return confusion.argmax(0).numpy()
Ejemplo n.º 5
0
def load_gan(spec, device='cuda'):
    """Build an InfoGAN for ``spec`` and restore it from its checkpoint.

    Args:
        spec: Setup spec string. It is parsed by
            ``spec_util.parse_setup_spec`` for the latent dimensions and also
            used as the checkpoint sub-directory under ``CHECKPOINT_ROOT``.
        device: Device (string or ``torch.device``) the ``Trainer`` wrapper is
            moved to before loading. Defaults to ``'cuda'`` as before; pass
            ``'cpu'`` on machines without a GPU. (Presumably this also moves
            ``gan`` if ``Trainer`` registers it as a submodule — confirm.)

    Returns:
        The restored ``InfoGAN`` in eval mode.
    """
    # Only the latent dimensions are needed; the other fields are ignored.
    _, latent_dims, _ = spec_util.parse_setup_spec(spec)
    checkpoint_dir = os.path.join(CHECKPOINT_ROOT, spec)
    gan = InfoGAN(*latent_dims)
    # The checkpoint is restored through the Trainer wrapper around ``gan``.
    trainer = Trainer(gan).to(torch.device(device))
    load_checkpoint(trainer, checkpoint_dir)
    gan.eval()
    return gan
Ejemplo n.º 6
0
def load_gan(spec, device='cuda'):
    """Build an InfoGAN for ``spec`` and restore it from its checkpoint.

    Args:
        spec: Setup spec string. It is parsed by
            ``spec_util.parse_setup_spec`` for the latent dimensions and also
            used as the checkpoint sub-directory under ``CHECKPOINT_ROOT``.
        device: Device (string or ``torch.device``) the ``Trainer`` wrapper is
            moved to before loading. Defaults to ``'cuda'`` as before; pass
            ``'cpu'`` on machines without a GPU. (Presumably this also moves
            ``gan`` if ``Trainer`` registers it as a submodule — confirm.)

    Returns:
        The restored ``InfoGAN`` in eval mode.
    """
    # Only the latent dimensions are needed; the other fields are ignored.
    _, latent_dims, _ = spec_util.parse_setup_spec(spec)
    checkpoint_dir = os.path.join(CHECKPOINT_ROOT, spec)
    gan = InfoGAN(*latent_dims)
    # The checkpoint is restored through the Trainer wrapper around ``gan``.
    trainer = Trainer(gan).to(torch.device(device))
    load_checkpoint(trainer, checkpoint_dir)
    gan.eval()
    return gan
Ejemplo n.º 7
0
def plot_bin_traversal(model: infogan.InfoGAN, nrow, ncol=5):
    """Plot one grid per binary code, showing each sample with that bit at 0 and 1.

    ``nrow * ncol`` latent vectors are sampled once and arranged as
    ``(ncol, 2, nrow, latent_dim)``; index 0/1 along the second axis holds the
    copy whose current bit is forced to 0/1. All other binary entries keep
    their sampled values. For each binary code ``b`` a ``plot_grid`` is drawn
    with a ``c_{cont_dim + b + 2}`` suptitle (figures are not shown here).

    Args:
        model: InfoGAN exposing ``bin_dim``, ``bin_idx``, ``cont_dim``,
            ``sample_latent`` and ``gen``.
        nrow: Rows per grid (passed to ``plot_grid``).
        ncol: Number of sample groups; each contributes two columns.
    """
    base = model.sample_latent(nrow * ncol).view(ncol, 1, nrow, -1)
    latent = base.repeat(1, 2, 1, 1)
    sampled_bits = latent[..., model.bin_idx].clone()
    n_flat = int(np.prod(latent.shape[:-1]))
    for b in range(model.bin_dim):
        # Undo the previous iteration's override before forcing the next bit.
        latent[..., model.bin_idx] = sampled_bits
        target = model.bin_idx[b]
        latent[:, 0, :, target] = 0
        latent[:, 1, :, target] = 1
        generated = model.gen(latent.view(n_flat, -1)).detach()
        plot_grid(generated,
                  nrow=nrow,
                  figsize=(2 * ncol, nrow),
                  gridspec_kw=dict(wspace=0, hspace=0))
        plt.suptitle(f"$c_{{{model.cont_dim + b + 2}}}$: Binary (columns: 0, 1)")
Ejemplo n.º 8
0
def plot_cont_cont_traversal(model: infogan.InfoGAN, c1, c2, nstep=9):
    """Plot a 2-D traversal of two continuous latent codes over [-1.5, 1.5].

    One latent vector is sampled and copied ``nstep ** 2`` times. Each copy
    receives one ``(c1, c2)`` combination of ``nstep`` evenly spaced values.
    ``c1`` is labelled across the top of the grid and ``c2`` down its left
    side; the latent layout below is arranged to match those labels.

    Args:
        model: InfoGAN exposing ``cont_idx``, ``device``, ``sample_latent``
            and ``gen``.
        c1: Index into ``model.cont_idx`` for the horizontally labelled code
            (shown as ``c_{c1 + 2}``).
        c2: Index into ``model.cont_idx`` for the vertically labelled code
            (shown as ``c_{c2 + 2}``).
        nstep: Number of values per code.
    """
    values = torch.linspace(-1.5, 1.5, nstep).to(model.device)
    latent = model.sample_latent(1).repeat(nstep ** 2, 1)
    # Sample i gets c1 = values[i // nstep] and c2 = values[i % nstep].
    # Assumes plot_grid(nrow=nstep) turns each run of nstep consecutive
    # samples into one grid column (confirm against plot_grid). Then c1
    # changes from column to column, matching the c1 label above the grid,
    # and c2 changes down each column, matching the c2 label on the left.
    latent[:, model.cont_idx[c1]] = values.repeat_interleave(nstep)
    latent[:, model.cont_idx[c2]] = values.repeat(nstep)
    samples = model.gen(latent).detach()
    fig, axs = plot_grid(samples,
                         nrow=nstep,
                         figsize=(nstep, nstep),
                         gridspec_kw=dict(wspace=0, hspace=0))
    # plt.suptitle(rf"$c_{{{c1 + 2}}} \times c_{{{c2 + 2}}}$: Continuous (-1.5 to 1.5)")

    for i in [(0, 0), (0, -1), (-1, 0)]:
        _prep_ax(axs[i])
    x_tick_kw = dict(ha='center', va='bottom', size=_TICK_LABEL_SIZE)
    y_tick_kw = dict(ha='right', va='center', rotation=0, size=_TICK_LABEL_SIZE)
    axs[0, 0].set_xlabel(f'${values[0]:+g}$', **x_tick_kw)
    axs[-1, 0].set_xlabel(f'${values[-1]:+g}$', **x_tick_kw)
    axs[0, 0].set_ylabel(f'${values[0]:+g}$', **y_tick_kw)
    axs[0, -1].set_ylabel(f'${values[-1]:+g}$', **y_tick_kw)

    corner = axs[0, 0].get_position()
    fig.text(.5,
             corner.y1,
             f'$c_{{{c1 + 2}}}$',
             ha='center',
             va='bottom',
             size=_VAR_LABEL_SIZE)
    fig.text(corner.x0,
             .5,
             f'$c_{{{c2 + 2}}}$',
             ha='right',
             va='center',
             size=_VAR_LABEL_SIZE)
Ejemplo n.º 9
0
def plot_bin_traversal(model: infogan.InfoGAN, nrow, ncol=5):
    """Plot one grid per binary code, showing each sample with that bit at 0 and 1.

    ``nrow * ncol`` latent vectors are sampled once and arranged as
    ``(ncol, 2, nrow, latent_dim)``; index 0/1 along the second axis holds the
    copy whose current bit is forced to 0/1. All other binary entries keep
    their sampled values. For each binary code ``b`` a ``plot_grid`` is drawn
    with a ``c_{cont_dim + b + 2}`` suptitle (figures are not shown here).

    Args:
        model: InfoGAN exposing ``bin_dim``, ``bin_idx``, ``cont_dim``,
            ``sample_latent`` and ``gen``.
        nrow: Rows per grid (passed to ``plot_grid``).
        ncol: Number of sample groups; each contributes two columns.
    """
    base = model.sample_latent(nrow * ncol).view(ncol, 1, nrow, -1)
    latent = base.repeat(1, 2, 1, 1)
    sampled_bits = latent[..., model.bin_idx].clone()
    n_flat = int(np.prod(latent.shape[:-1]))
    for b in range(model.bin_dim):
        # Undo the previous iteration's override before forcing the next bit.
        latent[..., model.bin_idx] = sampled_bits
        target = model.bin_idx[b]
        latent[:, 0, :, target] = 0
        latent[:, 1, :, target] = 1
        generated = model.gen(latent.view(n_flat, -1)).detach()
        plot_grid(generated,
                  nrow=nrow,
                  figsize=(2 * ncol, nrow),
                  gridspec_kw=dict(wspace=0, hspace=0))
        plt.suptitle(f"$c_{{{model.cont_dim + b + 2}}}$: Binary (columns: 0, 1)")
Ejemplo n.º 10
0
def test(model: infogan.InfoGAN, cat_mapping=None):
    """Switch ``model`` to eval mode and show a gallery of its outputs.

    First shows 64 random samples. It then shows, one figure at a time, the
    categorical traversal (if ``cat_dim > 0``), one traversal per continuous
    code, and the binary traversals (if ``bin_dim > 0``), all with 5 rows.

    Args:
        model: InfoGAN to visualize; callable with a sample count.
        cat_mapping: Forwarded to ``infogan_util.plot_cat_traversal`` to order
            the categorical codes.
    """
    model.eval()

    generated = model(64).detach()
    plot_grid(generated, figsize=(8, 8),
              gridspec_kw=dict(wspace=.1, hspace=.1))
    plt.show()

    rows = 5
    if model.cat_dim > 0:
        infogan_util.plot_cat_traversal(model, rows, cat_mapping)
        plt.show()
    # range() is empty when there are no continuous codes.
    for code in range(model.cont_dim):
        infogan_util.plot_cont_traversal(model, code, rows)
        plt.show()
    if model.bin_dim > 0:
        infogan_util.plot_bin_traversal(model, rows)
        plt.show()
Ejemplo n.º 11
0
def plot_cont_traversal(model: infogan.InfoGAN, c, nrow, nstep=9):
    """Plot generator samples while sweeping one continuous code over [-2, 2].

    ``nrow`` latent vectors are sampled once and repeated ``nstep`` times;
    the ``k``-th group of ``nrow`` consecutive samples has continuous code
    ``c`` set to the ``k``-th of ``nstep`` evenly spaced values.

    Args:
        model: InfoGAN exposing ``cont_idx``, ``device``, ``sample_latent``
            and ``gen``.
        c: Index into ``model.cont_idx``; labelled ``c_{c + 2}`` in the plot.
        nrow: Number of latent samples drawn (passed to ``plot_grid``).
        nstep: Number of values in the sweep.
    """
    sweep = torch.linspace(-2, 2, nstep).to(model.device)
    latent = model.sample_latent(nrow).repeat(nstep, 1)
    # Row k of ``latent`` belongs to group k // nrow, so it gets sweep[k // nrow].
    latent[:, model.cont_idx[c]] = sweep.repeat_interleave(nrow)
    samples = model.gen(latent).detach()

    fig, axs = plot_grid(samples, nrow=nrow, figsize=(nstep, nrow),
                         gridspec_kw=dict(wspace=0, hspace=0))
    # plt.suptitle(f"$c_{{{c + 2}}}$: Continuous (-2 to 2)")

    first_ax = axs[0, 0]
    last_ax = axs[-1, 0]
    for ax in (first_ax, last_ax):
        _prep_ax(ax)
    tick_kw = dict(ha='center', va='bottom', size=_TICK_LABEL_SIZE)
    first_ax.set_xlabel(f'${sweep[0]:+g}$', **tick_kw)
    last_ax.set_xlabel(f'${sweep[-1]:+g}$', **tick_kw)

    top = first_ax.get_position().y1
    fig.text(.5, top, f'$c_{{{c + 2}}}$', ha='center', va='bottom',
             size=_VAR_LABEL_SIZE)
Ejemplo n.º 12
0
def plot_cat_traversal(model: infogan.InfoGAN, nrow, cat_mapping=None):
    """Plot generator samples while sweeping the categorical latent code.

    ``nrow`` latent vectors are sampled once and reused for every category;
    group ``d`` (``nrow`` consecutive samples) has only the one-hot entry
    ``model.cat_idx[order[d]]`` set, where ``order`` is described below.

    Args:
        model: InfoGAN exposing ``cat_dim``, ``cat_idx``, ``sample_latent``
            and ``gen``.
        nrow: Number of latent samples drawn (and passed to ``plot_grid`` as
            ``nrow``).
        cat_mapping: Optional array. When given, ``order`` is
            ``np.argsort(cat_mapping)``; otherwise ``order`` is
            ``0..cat_dim-1``. (If ``cat_mapping[k]`` is the label matched to
            code ``k`` and it is a permutation, groups then appear in label
            order — assumption, confirm against the caller.)
    """
    n_cat = model.cat_dim
    if cat_mapping is None:
        order = np.arange(n_cat)
    else:
        order = np.argsort(cat_mapping)

    latent = model.sample_latent(nrow).repeat(n_cat, 1)
    # Clear the whole categorical block, then set exactly one entry per group.
    latent[:, model.cat_idx] = 0
    for group in range(n_cat):
        start = group * nrow
        latent[start:start + nrow, model.cat_idx[order[group]]] = 1
    samples = model.gen(latent).detach()

    fig, axs = plot_grid(samples, nrow=nrow, figsize=(n_cat, nrow),
                         gridspec_kw=dict(wspace=0, hspace=0))
    # plt.suptitle(f"$c_1$: Categorical ({cat_dim})")
    first_ax = axs[0, 0]
    last_ax = axs[-1, 0]
    for ax in (first_ax, last_ax):
        _prep_ax(ax)
    tick_kw = dict(ha='center', va='bottom', size=_TICK_LABEL_SIZE)
    first_ax.set_xlabel('$(1)$', **tick_kw)
    last_ax.set_xlabel(f'$({n_cat})$', **tick_kw)

    top = first_ax.get_position().y1
    fig.text(.5, top, '$c_1$', ha='center', va='bottom', size=_VAR_LABEL_SIZE)
Ejemplo n.º 13
0
def test(model: infogan.InfoGAN, cat_mapping=None):
    """Switch ``model`` to eval mode and show a gallery of its outputs.

    First shows 64 random samples. It then shows, one figure at a time, the
    categorical traversal (if ``cat_dim > 0``), one traversal per continuous
    code, and the binary traversals (if ``bin_dim > 0``), all with 5 rows.

    Args:
        model: InfoGAN to visualize; callable with a sample count.
        cat_mapping: Forwarded to ``infogan_util.plot_cat_traversal`` to order
            the categorical codes.
    """
    model.eval()

    generated = model(64).detach()
    plot_grid(generated, figsize=(8, 8),
              gridspec_kw=dict(wspace=.1, hspace=.1))
    plt.show()

    rows = 5
    if model.cat_dim > 0:
        infogan_util.plot_cat_traversal(model, rows, cat_mapping)
        plt.show()
    # range() is empty when there are no continuous codes.
    for code in range(model.cont_dim):
        infogan_util.plot_cont_traversal(model, code, rows)
        plt.show()
    if model.bin_dim > 0:
        infogan_util.plot_bin_traversal(model, rows)
        plt.show()
Ejemplo n.º 14
0
def plot_cont_cont_traversal(model: infogan.InfoGAN, c1, c2, nstep=9):
    """Plot a 2-D traversal of two continuous latent codes over [-1.5, 1.5].

    One latent vector is sampled and copied ``nstep ** 2`` times. Each copy
    receives one ``(c1, c2)`` combination of ``nstep`` evenly spaced values.
    ``c1`` is labelled across the top of the grid and ``c2`` down its left
    side; the latent layout below is arranged to match those labels.

    Args:
        model: InfoGAN exposing ``cont_idx``, ``device``, ``sample_latent``
            and ``gen``.
        c1: Index into ``model.cont_idx`` for the horizontally labelled code
            (shown as ``c_{c1 + 2}``).
        c2: Index into ``model.cont_idx`` for the vertically labelled code
            (shown as ``c_{c2 + 2}``).
        nstep: Number of values per code.
    """
    values = torch.linspace(-1.5, 1.5, nstep).to(model.device)
    latent = model.sample_latent(1).repeat(nstep ** 2, 1)
    # Sample i gets c1 = values[i // nstep] and c2 = values[i % nstep].
    # Assumes plot_grid(nrow=nstep) turns each run of nstep consecutive
    # samples into one grid column (confirm against plot_grid). Then c1
    # changes from column to column, matching the c1 label above the grid,
    # and c2 changes down each column, matching the c2 label on the left.
    latent[:, model.cont_idx[c1]] = values.repeat_interleave(nstep)
    latent[:, model.cont_idx[c2]] = values.repeat(nstep)
    samples = model.gen(latent).detach()
    fig, axs = plot_grid(samples, nrow=nstep, figsize=(nstep, nstep),
                         gridspec_kw=dict(wspace=0, hspace=0))
    # plt.suptitle(rf"$c_{{{c1 + 2}}} \times c_{{{c2 + 2}}}$: Continuous (-1.5 to 1.5)")

    for i in [(0, 0), (0, -1), (-1, 0)]:
        _prep_ax(axs[i])
    x_tick_kw = dict(ha='center', va='bottom', size=_TICK_LABEL_SIZE)
    y_tick_kw = dict(ha='right', va='center', rotation=0, size=_TICK_LABEL_SIZE)
    axs[0, 0].set_xlabel(f'${values[0]:+g}$', **x_tick_kw)
    axs[-1, 0].set_xlabel(f'${values[-1]:+g}$', **x_tick_kw)
    axs[0, 0].set_ylabel(f'${values[0]:+g}$', **y_tick_kw)
    axs[0, -1].set_ylabel(f'${values[-1]:+g}$', **y_tick_kw)

    corner = axs[0, 0].get_position()
    fig.text(.5, corner.y1, f'$c_{{{c1 + 2}}}$', ha='center', va='bottom', size=_VAR_LABEL_SIZE)
    fig.text(corner.x0, .5, f'$c_{{{c2 + 2}}}$', ha='right', va='center', size=_VAR_LABEL_SIZE)
Ejemplo n.º 15
0
def main(args=None):
    """Command-line entry point: build a GAN, optionally train it, optionally visualize.

    Args:
        args: Argument list passed to ``parse_args``; defaults to
            ``sys.argv[1:]``.

    Raises:
        ValueError: If ``args.type`` is not one of DCGAN, WGAN, CGAN, InfoGAN.
        Exception: If neither pretrained weights (``args.model``) nor
            training (``args.train``) were requested.
    """
    # Parse arguments
    if args is None:
        args = sys.argv[1:]
    args = parse_args(args)

    # Restrict visible GPUs before the TensorFlow session is created.
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    keras.backend.tensorflow_backend.set_session(get_session())

    # Load appropriate model
    if args.type == 'DCGAN':  # Deep Convolutional GAN
        model = DCGAN(args)
    elif args.type == 'WGAN':  # Wasserstein GAN
        model = WGAN(args)
    elif args.type == 'CGAN':  # Conditional GAN
        model = CGAN(args)
    elif args.type == 'InfoGAN':  # InfoGAN
        model = InfoGAN(args)
    else:
        # Fail fast instead of hitting an unbound ``model`` further down.
        raise ValueError('Unknown model type: {!r}'.format(args.type))

    # Load pre-trained weights
    if args.model:
        model.load_weights(args.model)
    elif not args.train:
        raise Exception('Please specify path to pretrained model')

    # Load MNIST data, pre-train D for a couple of iterations, then train the model
    if args.train:
        X_train, y_train, _, _, _ = import_mnist()
        model.pre_train(X_train, y_train)
        model.train(X_train,
                    bs=args.batch_size,
                    nb_epoch=args.nb_epochs,
                    nb_iter=X_train.shape[0] // args.batch_size,
                    y_train=y_train,
                    save_path=args.save_path)

    # (Optional) Visualize results
    if args.visualize:
        model.visualize()
Ejemplo n.º 16
0
def _filter_normalized_direction(W):
    """Return a Gaussian random array shaped like ``W``, rescaled slice-by-slice.

    Norms are taken over all axes except the last, so every slice
    ``[..., k]`` of the result has the same L2 norm as ``W[..., k]``.
    """
    direction = np.random.randn(*W.shape)
    direction /= np.linalg.norm(direction.reshape(-1, direction.shape[-1]), axis=0)
    direction *= np.linalg.norm(W.reshape(-1, W.shape[-1]), axis=0)
    return direction


def main(args=None):
    """Plot the generator-loss landscape of a GAN on a random 2-D weight slice.

    Two random directions (per-slice rescaled to the weights' norms) are drawn
    for selected generator layers. The generator weights are then set to
    ``W + x * A + y * B`` on a grid of ``(x, y)`` in [-3, 3]^2, and
    ``model.eval_gen_loss`` is recorded at each point. The loss grid is saved
    with ``np.save`` and its contour plot is written under ``figures/``. The
    original weights are restored afterwards.

    Args:
        args: Argument list passed to ``parse_args``; defaults to
            ``sys.argv[1:]``.

    Raises:
        ValueError: If ``args.type`` is not one of DCGAN, WGAN, CGAN, InfoGAN.
        Exception: If neither pretrained weights nor training were requested.
    """
    # Parse arguments
    if args is None:
        args = sys.argv[1:]
    args = parse_args(args)

    # Restrict visible GPUs before the TensorFlow session is created.
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    keras.backend.tensorflow_backend.set_session(get_session())

    # Load appropriate model
    if args.type == 'DCGAN':  # Deep Convolutional GAN
        model = DCGAN(args)
    elif args.type == 'WGAN':  # Wasserstein GAN
        model = WGAN(args)
    elif args.type == 'CGAN':  # Conditional GAN
        model = CGAN(args)
    elif args.type == 'InfoGAN':  # InfoGAN
        model = InfoGAN(args)
    else:
        # Fail fast instead of hitting an unbound ``model`` further down.
        raise ValueError('Unknown model type: {!r}'.format(args.type))

    # Load pre-trained weights
    if args.model:
        model.load_weights(args.model)
    elif not args.train:
        raise Exception('Please specify path to pretrained model')

    # # Load MNIST Data, pre-train D for a couple of iterations and train model
    # if args.train:
    #     X_train, y_train, _, _, N = import_mnist()
    #     model.pre_train(X_train, y_train)
    #     model.train(X_train,
    #         bs=args.batch_size,
    #         nb_epoch=args.nb_epochs,
    #         nb_iter=2,
    #         y_train=y_train,
    #         save_path=args.save_path)

    # (Optional) Visualize results
    # if args.visualize:
    #     model.visualize()

    # Indices into model.G.layers whose kernels are perturbed.
    layers = [0, 4, 7, 9, 11]
    n_x = 50
    n_y = 50
    n_samples = 500

    old_params = []
    A = []
    B = []
    for layer_idx in layers:
        W, b = model.G.layers[layer_idx].get_weights()
        old_params.append((W, b))
        A.append(_filter_normalized_direction(W))
        B.append(_filter_normalized_direction(W))

    xs = np.linspace(-3, 3, n_x)
    ys = np.linspace(-3, 3, n_y)
    # loss[i, j] is the generator loss at (xs[i], ys[j]).
    loss = np.zeros((n_x, n_y))

    for i, x in enumerate(xs):
        for j, y in enumerate(ys):
            for A_W, B_W, (W, b), layer_idx in zip(A, B, old_params, layers):
                model.G.layers[layer_idx].set_weights((W + x * A_W + y * B_W, b))
            print((i, j))
            loss[i, j] = model.eval_gen_loss(n_samples)

    # Put the unperturbed weights back.
    for (W, b), layer_idx in zip(old_params, layers):
        model.G.layers[layer_idx].set_weights((W, b))

    print('Done!')

    np.save(
        '{}_loss_n_samples_{}_xlarge_epoch_100'.format(args.type, n_samples),
        loss)
    # contour expects Z[row, col] at (X[col], Y[row]); loss is indexed [x, y],
    # so it must be transposed.
    plt.contour(xs, ys, loss.T)
    # plt.show()
    os.makedirs('figures', exist_ok=True)
    plt.savefig(
        'figures/{}_landscape_n_samples_{}_xlarge_epoch_100.png'.format(
            args.type, n_samples))
Ejemplo n.º 17
0
def encode(gan: infogan.InfoGAN, x):
    """Encode ``x`` with ``gan.dis`` then ``gan.rec`` without tracking gradients.

    Args:
        gan: InfoGAN whose ``dis`` returns a 2-tuple (its second element is
            fed to ``rec``) and whose ``rec`` returns four outputs.
        x: Input batch accepted by ``gan.dis``.

    Returns:
        ``(cat_logits, cont_mean, cont_logvar, bin_logit)`` as produced by
        ``gan.rec``.
    """
    with torch.no_grad():
        _, features = gan.dis(x)
        cat, mean, logvar, binary = gan.rec(features)
    return cat, mean, logvar, binary