Example #1
    def __init__(self, args):
        args.deterministic = True
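        # pick encoders by dataset: small conv nets for mnist/dsprites,
        # VGG/ResNet backbones otherwise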
        if args.dataset in ["mnist", "dsprites"]:
            content_encoder = models.content_encoder(
                args.g_dims, nc=args.num_channels).cuda()
            position_encoder = models.pose_encoder(
                args.z_dims,
                nc=args.num_channels,
                normalize=args.normalize_position).cuda()
        else:
            content_encoder = vgg_64.encoder(args.g_dims,
                                             nc=args.num_channels).cuda()
            position_encoder = resnet_64.pose_encoder(
                args.z_dims, nc=args.num_channels).cuda()

        if args.dataset == "mpi3d_real":
            decoder = vgg_64.drnet_decoder(args.g_dims,
                                           args.z_dims,
                                           nc=args.num_channels).cuda()
        else:
            decoder = models.decoder(args.g_dims,
                                     args.z_dims,
                                     nc=args.num_channels,
                                     skips=args.skips).cuda()

        self.content_frames = 1
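        # optionally aggregate content over several input frames with an LSTM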
        if args.content_lstm:
            content_encoder = models.content_encoder_lstm(
                args.g_dims, content_encoder, args.batch_size)
            self.content_frames = args.input_frames

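        # discriminator on pose codes, presumably the DRNet-style scene discriminator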
        discriminator = models.scene_discriminator(args.z_dims).cuda()
        nets = {
            "content_encoder": content_encoder,
            "position_encoder": position_encoder,
            "decoder": decoder,
            "discriminator": discriminator,
        }

        self.encoder_decoder_parameters = itertools.chain(
            content_encoder.parameters(),
            position_encoder.parameters(),
            decoder.parameters(),
        )

        encoder_decoder_optim = torch.optim.Adam(
            self.encoder_decoder_parameters,
            lr=args.lr,
            betas=(args.beta1, 0.999))

        discriminator_optim = torch.optim.Adam(discriminator.parameters(),
                                               lr=args.lr,
                                               betas=(args.beta1, 0.999))

        optims = {
            "encoder_decoder_optim": encoder_decoder_optim,
            "discriminator_optim": discriminator_optim,
        }

        super().__init__(nets, optims, args)
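
A minimal sketch of how this constructor might be driven, assuming an
argparse-style namespace; the attribute names come from the snippet above,
while the values and the DrnetTrainer class name are hypothetical:

from argparse import Namespace

args = Namespace(
    dataset="dsprites", num_channels=1, g_dims=128, z_dims=10,
    normalize_position=False, skips=False, content_lstm=False,
    input_frames=8, batch_size=100, lr=0.002, beta1=0.5,
)
trainer = DrnetTrainer(args)  # hypothetical name for the class defining this __init__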
Example #2
    def __init__(self, args):
        args.deterministic = True
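        # restore pretrained encoders/decoder (checkpoint keys match the nets dict in Example #1)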
        encoder_checkpoint = torch.load(args.encoder_checkpoint)
        if args.dataset in ["mnist", "dsprites"]:
            Ec = models.content_encoder(args.g_dims,
                                        nc=args.num_channels).cuda()
            Ep = models.pose_encoder(args.z_dims, nc=args.num_channels).cuda()
        else:
            Ec = vgg_64.encoder(args.g_dims, nc=args.num_channels).cuda()
            Ep = resnet_64.pose_encoder(args.z_dims,
                                        nc=args.num_channels).cuda()

        if args.dataset == "mpi3d_real":
            D = vgg_64.drnet_decoder(args.g_dims,
                                     args.z_dims,
                                     nc=args.num_channels).cuda()
        else:
            D = models.decoder(args.g_dims,
                               args.z_dims,
                               nc=args.num_channels,
                               skips=args.skips).cuda()

        Ep.load_state_dict(encoder_checkpoint["position_encoder"])
        Ec.load_state_dict(encoder_checkpoint["content_encoder"])
        D.load_state_dict(encoder_checkpoint["decoder"])
        self.Ep = nn.DataParallel(Ep)
        self.Ec = nn.DataParallel(Ec)
        self.D = nn.DataParallel(D)
        self.Ep.train()
        self.Ec.train()
        self.D.train()

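        # LSTM that, given content and pose codes, presumably predicts the next pose code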
        lstm_model = lstm(args.g_dims + args.z_dims, args.z_dims,
                          args.rnn_size, args.rnn_layers,
                          args.batch_size).cuda()
        nets = {"lstm": lstm_model}

        lstm_optim = torch.optim.Adam(lstm_model.parameters(),
                                      lr=args.lr,
                                      betas=(args.beta1, 0.999))

        optims = {"lstm_optim": lstm_optim}

        super().__init__(nets, optims, args)
Example #3
def generate(args):
    # Load model
    nn.load_parameters(args.model_load_path)

    # Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Input
    b, c, h, w = 1, 3, args.image_size, args.image_size
    x_real_a = nn.Variable([b, c, h, w])
    x_real_b = nn.Variable([b, c, h, w])
    one = nn.Variable.from_numpy_array(np.ones((1, 1, 1, 1)) * 0.5)
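    # note: `one` is defined here but never used below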

    # Model
    maps = args.maps
    # content/style (domain A)
    x_content_a = content_encoder(x_real_a, maps, name="content-encoder-a")
    x_style_a = style_encoder(x_real_a, maps, name="style-encoder-a")
    # content/style (domain B)
    x_content_b = content_encoder(x_real_b, maps, name="content-encoder-b")
    x_style_b = style_encoder(x_real_b, maps, name="style-encoder-b")
    # cross-domain generation (B -> A): decode content from B with a style code for A
    z_style_a = F.randn(
        shape=x_style_a.shape) if not args.example_guided else x_style_a
    z_style_a = z_style_a.apply(persistent=True)
    x_fake_a = decoder(x_content_b, z_style_a, name="decoder-a")
    # cross-domain generation (A -> B): decode content from A with a style code for B
    z_style_b = F.randn(
        shape=x_style_b.shape) if not args.example_guided else x_style_b
    z_style_b = z_style_b.apply(persistent=True)
    x_fake_b = decoder(x_content_a, z_style_b, name="decoder-b")

    # Monitor
    suffix = "Stochastic" if not args.example_guided else "Example-guided"
    monitor = Monitor(args.monitor_path)
    monitor_image_a = MonitorImage("Fake Image B to A {} Valid".format(suffix),
                                   monitor,
                                   interval=1)
    monitor_image_b = MonitorImage("Fake Image A to B {} Valid".format(suffix),
                                   monitor,
                                   interval=1)

    # DataIterator
    di_a = munit_data_iterator(args.img_path_a, args.batch_size)
    di_b = munit_data_iterator(args.img_path_b, args.batch_size)

    # Generate all
    # generate (A -> B)
    if args.example_guided:
        x_real_b.d = di_b.next()[0]
    for i in range(di_a.size):
        x_real_a.d = di_a.next()[0]
        images = []
        images.append(x_real_a.d.copy())
        for _ in range(args.num_repeats):
            x_fake_b.forward(clear_buffer=True)
            images.append(x_fake_b.d.copy())
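        # tile input and generated samples side by side along the width axis (NCHW)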
        monitor_image_b.add(i, np.concatenate(images, axis=3))

    # generate (B -> A)
    if args.example_guided:
        x_real_a.d = di_a.next()[0]
    for i in range(di_b.size):
        x_real_b.d = di_b.next()[0]
        images = []
        images.append(x_real_b.d.copy())
        for _ in range(args.num_repeats):
            x_fake_a.forward(clear_buffer=True)
            images.append(x_fake_a.d.copy())
        monitor_image_a.add(i, np.concatenate(images, axis=3))
Example #4
                         shuffle=True,
                         num_workers=5,
                         drop_last=True)


def get_data_batch():
    while True:
        for seq in data_loader:
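            # in-place channel reorder, channels-last -> channels-first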
            seq[1].transpose_(2, 3).transpose_(1, 2)
            yield seq


data_generator = get_data_batch()

if args.dataset == "dsprites":
    Ec = models.content_encoder(args.g_dims, nc=args.num_channels).cuda()
    Ep = models.pose_encoder(args.z_dims, nc=args.num_channels).cuda()
elif args.dataset == "mpi3d_real":
    Ec = vgg_64.encoder(args.g_dims, nc=args.num_channels).cuda()
    Ep = resnet_64.pose_encoder(args.z_dims, nc=args.num_channels).cuda()

checkpoint = torch.load(args.checkpoint)

Ec.load_state_dict(checkpoint["content_encoder"])
Ep.load_state_dict(checkpoint["position_encoder"])
Ec.eval()
Ep.eval()

latent_c = None
latent_p = None
factors_p = None
Example #5
def interpolate(args):
    # Load model
    nn.load_parameters(args.model_load_path)

    # Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Input
    b, c, h, w = 1, 3, args.image_size, args.image_size
    x_real_a = nn.Variable([b, c, h, w])
    x_real_b = nn.Variable([b, c, h, w])
    one = nn.Variable.from_numpy_array(np.ones((1, 1, 1, 1)) * 0.5)

    # Model
    maps = args.maps
    # content/style (domain A)
    x_content_a = content_encoder(x_real_a, maps, name="content-encoder-a")
    x_style_a = style_encoder(x_real_a, maps, name="style-encoder-a")
    # content/style (domain B)
    x_content_b = content_encoder(x_real_b, maps, name="content-encoder-b")
    x_style_b = style_encoder(x_real_b, maps, name="style-encoder-b")
    # cross-domain generation (B -> A): decode content from B with a style code for A
    z_style_a = nn.Variable(
        x_style_a.shape) if not args.example_guided else x_style_a
    z_style_a = z_style_a.apply(persistent=True)
    x_fake_a = decoder(x_content_b, z_style_a, name="decoder-a")
    # cross-domain generation (A -> B): decode content from A with a style code for B
    z_style_b = nn.Variable(
        x_style_b.shape) if not args.example_guided else x_style_b
    z_style_b = z_style_b.apply(persistent=True)
    x_fake_b = decoder(x_content_a, z_style_b, name="decoder-b")

    # Monitor
    def file_names(path):
        # strip the "_AB.jpg" suffix; str.rstrip removes a character set, not a suffix
        return path.split("/")[-1].replace("_AB.jpg", "")

    suffix = "Stochastic" if not args.example_guided else "Example-guided"
    monitor = Monitor(args.monitor_path)
    monitor_image_tile_a = MonitorImageTile(
        "Fake Image Tile {} B to A {} Interpolation".format(
            "-".join([file_names(path) for path in args.img_files_b]), suffix),
        monitor,
        interval=1,
        num_images=len(args.img_files_b))
    monitor_image_tile_b = MonitorImageTile(
        "Fake Image Tile {} A to B {} Interpolation".format(
            "-".join([file_names(path) for path in args.img_files_a]), suffix),
        monitor,
        interval=1,
        num_images=len(args.img_files_a))

    # DataIterator
    di_a = munit_data_iterator(args.img_files_a, b, shuffle=False)
    di_b = munit_data_iterator(args.img_files_b, b, shuffle=False)
    rng = np.random.RandomState(args.seed)

    # Interpolate (A -> B)
    z_data_0 = [rng.randn(*z_style_a.shape) for j in range(di_a.size)]
    z_data_1 = [rng.randn(*z_style_a.shape) for j in range(di_a.size)]
    for i in range(args.num_repeats):
        r = 1.0 * i / args.num_repeats
        images = []
        for j in range(di_a.size):
            x_data_a = di_a.next()[0]
            x_real_a.d = x_data_a
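            # linearly interpolate between the two sampled style codes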
            z_style_b.d = z_data_0[j] * (1.0 - r) + z_data_1[j] * r
            x_fake_b.forward(clear_buffer=True)
            cmp_image = np.concatenate([x_data_a, x_fake_b.d.copy()], axis=3)
            images.append(cmp_image)
        images = np.concatenate(images)
        monitor_image_tile_b.add(i, images)

    # Interpolate (B -> A)
    z_data_0 = [rng.randn(*z_style_b.shape) for j in range(di_b.size)]
    z_data_1 = [rng.randn(*z_style_b.shape) for j in range(di_b.size)]
    for i in range(args.num_repeats):
        r = 1.0 * i / args.num_repeats
        images = []
        for j in range(di_b.size):
            x_data_b = di_b.next()[0]
            x_real_b.d = x_data_b
            z_style_a.d = z_data_0[j] * (1.0 - r) + z_data_1[j] * r
            x_fake_a.forward(clear_buffer=True)
            cmp_image = np.concatenate([x_data_b, x_fake_a.d.copy()], axis=3)
            images.append(cmp_image)
        images = np.concatenate(images)
        monitor_image_tile_a.add(i, images)
Example #6
def train(args):
    # Create Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Input
    b, c, h, w = args.batch_size, 3, args.image_size, args.image_size
    x_real_a = nn.Variable([b, c, h, w])
    x_real_b = nn.Variable([b, c, h, w])

    # Model
    # fix the seed so that every device starts from identical model parameters
    np.random.seed(412)
    maps = args.maps
    # within-domain reconstruction (domain A)
    x_content_a = content_encoder(x_real_a, maps, name="content-encoder-a")
    x_style_a = style_encoder(x_real_a, maps, name="style-encoder-a")
    x_recon_a = decoder(x_content_a, x_style_a, name="decoder-a")
    # within-domain reconstruction (domain B)
    x_content_b = content_encoder(x_real_b, maps, name="content-encoder-b")
    x_style_b = style_encoder(x_real_b, maps, name="style-encoder-b")
    x_recon_b = decoder(x_content_b, x_style_b, name="decoder-b")
    # cross-domain generation (B -> A) and re-encoding of its content and style
    z_style_a = F.randn(shape=x_style_a.shape)
    x_fake_a = decoder(x_content_b, z_style_a, name="decoder-a")
    x_content_rec_b = content_encoder(x_fake_a, maps, name="content-encoder-a")
    x_style_rec_a = style_encoder(x_fake_a, maps, name="style-encoder-a")
    # cross-domain generation (A -> B) and re-encoding of its content and style
    z_style_b = F.randn(shape=x_style_b.shape)
    x_fake_b = decoder(x_content_a, z_style_b, name="decoder-b")
    x_content_rec_a = content_encoder(x_fake_b, maps, name="content-encoder-b")
    x_style_rec_b = style_encoder(x_fake_b, maps, name="style-encoder-b")
    # discriminate real and fake images (domains A and B)
    p_x_fake_a_list = discriminators(x_fake_a)
    p_x_real_a_list = discriminators(x_real_a)
    p_x_fake_b_list = discriminators(x_fake_b)
    p_x_real_b_list = discriminators(x_real_b)

    # Loss
    # within-domain reconstruction
    loss_recon_x_a = recon_loss(x_recon_a, x_real_a).apply(persistent=True)
    loss_recon_x_b = recon_loss(x_recon_b, x_real_b).apply(persistent=True)
    # content and style reconstruction
    loss_recon_x_style_a = recon_loss(x_style_rec_a,
                                      z_style_a).apply(persistent=True)
    loss_recon_x_content_b = recon_loss(x_content_rec_b,
                                        x_content_b).apply(persistent=True)
    loss_recon_x_style_b = recon_loss(x_style_rec_b,
                                      z_style_b).apply(persistent=True)
    loss_recon_x_content_a = recon_loss(x_content_rec_a,
                                        x_content_a).apply(persistent=True)

    # adversarial

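    # helper to sum the per-scale LSGAN losses via reduce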
    def f(x, y):
        return x + y

    loss_gen_a = reduce(f, [lsgan_loss(p_f)
                            for p_f in p_x_fake_a_list]).apply(persistent=True)
    loss_dis_a = reduce(f, [
        lsgan_loss(p_f, p_r)
        for p_f, p_r in zip(p_x_fake_a_list, p_x_real_a_list)
    ]).apply(persistent=True)
    loss_gen_b = reduce(f, [lsgan_loss(p_f)
                            for p_f in p_x_fake_b_list]).apply(persistent=True)
    loss_dis_b = reduce(f, [
        lsgan_loss(p_f, p_r)
        for p_f, p_r in zip(p_x_fake_b_list, p_x_real_b_list)
    ]).apply(persistent=True)
    # loss for generator-related models
    loss_gen = loss_gen_a + loss_gen_b \
        + args.lambda_x * (loss_recon_x_a + loss_recon_x_b) \
        + args.lambda_c * (loss_recon_x_content_a + loss_recon_x_content_b) \
        + args.lambda_s * (loss_recon_x_style_a + loss_recon_x_style_b)
    # loss for discriminators
    loss_dis = loss_dis_a + loss_dis_b

    # Solver
    lr_g, lr_d, beta1, beta2 = args.lr_g, args.lr_d, args.beta1, args.beta2
    # solver for generator-related models
    solver_gen = S.Adam(lr_g, beta1, beta2)
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
    solver_gen.set_parameters(params_gen)
    # solver for discriminators
    solver_dis = S.Adam(lr_d, beta1, beta2)
    with nn.parameter_scope("discriminators"):
        params_dis = nn.get_parameters()
    solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    # time
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    # reconstruction
    monitor_loss_recon_x_a = MonitorSeries("Recon Loss Image A",
                                           monitor,
                                           interval=10)
    monitor_loss_recon_x_content_b = MonitorSeries("Recon Loss Content B",
                                                   monitor,
                                                   interval=10)
    monitor_loss_recon_x_style_a = MonitorSeries("Recon Loss Style A",
                                                 monitor,
                                                 interval=10)
    monitor_loss_recon_x_b = MonitorSeries("Recon Loss Image B",
                                           monitor,
                                           interval=10)
    monitor_loss_recon_x_content_a = MonitorSeries("Recon Loss Content A",
                                                   monitor,
                                                   interval=10)
    monitor_loss_recon_x_style_b = MonitorSeries("Recon Loss Style B",
                                                 monitor,
                                                 interval=10)
    # adversarial
    monitor_loss_gen_a = MonitorSeries("Gen Loss A", monitor, interval=10)
    monitor_loss_dis_a = MonitorSeries("Dis Loss A", monitor, interval=10)
    monitor_loss_gen_b = MonitorSeries("Gen Loss B", monitor, interval=10)
    monitor_loss_dis_b = MonitorSeries("Dis Loss B", monitor, interval=10)
    monitor_losses = [
        # reconstruction
        (monitor_loss_recon_x_a, loss_recon_x_a),
        (monitor_loss_recon_x_content_b, loss_recon_x_content_b),
        (monitor_loss_recon_x_style_a, loss_recon_x_style_a),
        (monitor_loss_recon_x_b, loss_recon_x_b),
        (monitor_loss_recon_x_content_a, loss_recon_x_content_a),
        (monitor_loss_recon_x_style_b, loss_recon_x_style_b),
        # adversarial
        (monitor_loss_gen_a, loss_gen_a),
        (monitor_loss_dis_a, loss_dis_a),
        (monitor_loss_gen_b, loss_gen_b),
        (monitor_loss_dis_b, loss_dis_b)
    ]
    # image
    monitor_image_a = MonitorImage("Fake Image B to A Train",
                                   monitor,
                                   interval=1)
    monitor_image_b = MonitorImage("Fake Image A to B Train",
                                   monitor,
                                   interval=1)
    monitor_images = [
        (monitor_image_a, x_fake_a),
        (monitor_image_b, x_fake_b),
    ]

    # DataIterator
    rng_a = np.random.RandomState(device_id)
    rng_b = np.random.RandomState(device_id + n_devices)
    di_a = munit_data_iterator(args.img_path_a, args.batch_size, rng=rng_a)
    di_b = munit_data_iterator(args.img_path_b, args.batch_size, rng=rng_b)

    # Train
    for i in range(args.max_iter // n_devices):
        ii = i * n_devices
        # Train generator-related models
        x_data_a, x_data_b = di_a.next()[0], di_b.next()[0]
        x_real_a.d, x_real_b.d = x_data_a, x_data_b
        solver_gen.zero_grad()
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        comm.all_reduce([w.grad for w in params_gen.values()])
        solver_gen.weight_decay(args.weight_decay_rate)
        solver_gen.update()

        # Train discriminators
        x_data_a, x_data_b = di_a.next()[0], di_b.next()[0]
        x_real_a.d, x_real_b.d = x_data_a, x_data_b
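        # stop gradients at the fake images so the discriminator update
        # does not backprop into the generators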
        x_fake_a.need_grad, x_fake_b.need_grad = False, False
        solver_dis.zero_grad()
        loss_dis.forward(clear_no_need_grad=True)
        loss_dis.backward(clear_buffer=True)
        comm.all_reduce([w.grad for w in params_dis.values()])
        solver_dis.weight_decay(args.weight_decay_rate)
        solver_dis.update()
        x_fake_a.need_grad, x_fake_b.need_grad = True, True

        # LR schedule
        if (i + 1) % (args.lr_decay_at_every // n_devices) == 0:
            lr_d = solver_dis.learning_rate() * args.lr_decay_rate
            lr_g = solver_gen.learning_rate() * args.lr_decay_rate
            solver_dis.set_learning_rate(lr_d)
            solver_gen.set_learning_rate(lr_g)

        if mpi_local_rank == 0:
            # Monitor
            monitor_time.add(ii)
            for mon, loss in monitor_losses:
                mon.add(ii, loss.d)
            # Save
            if (i + 1) % (args.model_save_interval // n_devices) == 0:
                for mon, x in monitor_images:
                    mon.add(ii, x.d)
                nn.save_parameters(
                    os.path.join(args.monitor_path,
                                 "param_{:05d}.h5".format(i)))

    if mpi_local_rank == 0:
        # Monitor
        for mon, loss in monitor_losses:
            mon.add(ii, loss.d)
        # Save
        for mon, x in monitor_images:
            mon.add(ii, x.d)
        nn.save_parameters(
            os.path.join(args.monitor_path, "param_{:05d}.h5".format(i)))