# Example #1 (score: 0)
def main():
    """Train the Inpaint model under NSML, with a local (non-NSML) fallback.

    Builds the model and Adam optimizer, binds NSML save/load hooks, then
    runs the training loop: each step feeds the masked image plus mask to
    the model, composes the prediction with the input, and minimizes an L1
    loss against the ground truth. Periodically dumps sample images, runs a
    local evaluation, and checkpoints once per epoch.

    Relies on module-level globals defined elsewhere in the file:
    `use_nsml`, `dir_data_root`, `nsml`, and helpers such as `get_args`,
    `Inpaint`, `bind_nsml`, `data_loader_with_split`, `compose`,
    `l1_loss`, and `local_eval`.
    """
    args = get_args()

    # Use the first CUDA device when available, otherwise fall back to CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = Inpaint()
    model = model.to(device)
    optim = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    save, load = bind_nsml(model, optim)
    # NSML convention: pause == 1 means the session should wait for remote
    # commands (e.g. submit/load) instead of starting training.
    # NOTE: locals() is handed to NSML here, so local variable names are
    # part of the observable behavior of this function.
    if args.pause == 1:
        nsml.paused(scope=locals())

    if args.mode == 'train':
        path_train = os.path.join(dir_data_root, 'train')
        path_train_data = os.path.join(dir_data_root, 'train', 'train_data')
        tr_loader, val_loader = data_loader_with_split(path_train, batch_size=args.batch_size)

        postfix = dict()  # metrics shown on the progress bar / reported to NSML
        total_step = 0    # global step counter across all epochs
        for epoch in trange(args.num_epochs, disable=use_nsml):
            pbar = tqdm(enumerate(tr_loader), total=len(tr_loader), disable=use_nsml)
            # Each batch unpacks to (id, masked input, mask, ground truth);
            # the first element is unused here — presumably a sample id,
            # TODO confirm against data_loader_with_split.
            for step, (_, x_input, mask, x_GT) in pbar:
                total_step += 1
                x_GT = x_GT.to(device)
                x_input = x_input.to(device)
                mask = mask.to(device)
                # The model consumes the image and its mask stacked along the
                # channel dimension.
                x_mask = torch.cat([x_input, mask], dim=1)

                model.zero_grad()
                x_hat = model(x_mask)
                # Presumably blends prediction and input according to the
                # mask — confirm against compose() elsewhere in the file.
                x_composed = compose(x_input, x_hat, mask)
                loss = l1_loss(x_composed, x_GT)
                loss.backward()
                optim.step()
                postfix['loss'] = loss.item()

                if use_nsml:
                    postfix['epoch'] = epoch
                    postfix['step_'] = step
                    postfix['total_step'] = total_step
                    postfix['steps_per_epoch'] = len(tr_loader)

                if step % args.eval_every == 0:
                    # Dump the current batch as images for visual inspection
                    # (overwritten every eval).
                    vutils.save_image(x_GT, 'x_GT.png', normalize=True)
                    vutils.save_image(x_input, 'x_input.png', normalize=True)
                    vutils.save_image(x_hat, 'x_hat.png', normalize=True)
                    vutils.save_image(mask, 'mask.png', normalize=True)
                    metric_eval = local_eval(model, val_loader, path_train_data)
                    postfix['metric_eval'] = metric_eval
                if use_nsml:
                    if step % args.print_every == 0:
                        print(postfix)
                    nsml.report(**postfix, scope=locals(), step=total_step)
                else:
                    pbar.set_postfix(postfix)
            # Checkpoint once per epoch: via NSML when running there,
            # otherwise via the locally bound save() hook.
            if use_nsml:
                nsml.save(epoch)
            else:
                save(epoch)
# Example #2 (score: 0)
    def _setup_loss_graph(self, s_output_tbi, s_target_tbi, s_step_size):
        """
        Connect a loss function to the graph.

        Only the trailing `s_step_size` time steps of the
        (time, batch, input) tensors contribute to the loss
        (see data.py for an explanation of the slicing).

        Returns the symbolic loss selected by self._options['loss_type']
        ('l2', 'l1', or 'huber'; 'huber' also reads
        self._options['huber_delta']).

        Raises:
            ValueError: if 'loss_type' is not one of the supported values.
        """
        s_sliced_output_tbi = s_output_tbi[-s_step_size :]
        s_sliced_target_tbi = s_target_tbi[-s_step_size :]

        loss_type = self._options['loss_type']
        if loss_type == 'l2':
            return l2_loss(s_sliced_output_tbi, s_sliced_target_tbi)
        if loss_type == 'l1':
            return l1_loss(s_sliced_output_tbi, s_sliced_target_tbi)
        if loss_type == 'huber':
            delta = self._options['huber_delta']
            return huber_loss(s_sliced_output_tbi, s_sliced_target_tbi, delta)

        # Previously this was `assert False` followed by a constant-zero
        # fallback: under `python -O` the assert is stripped and an invalid
        # option would silently train against a zero loss. Fail loudly.
        raise ValueError('Invalid loss_type option: {!r}'.format(loss_type))
# Example #3 (score: 0)
def training_procedure(FLAGS):
    """Train a style/class VAE on paired MNIST with a reverse-cycle loss.

    Two updates per iteration:
      A. Auto-encoder step — encode a pair of images sharing a class,
         decode each with the *other* image's class latent (cross
         reconstruction), with a KL penalty on the style latents.
      B. Reverse-cycle step — decode one random style vector with two
         different class latents, re-encode both decodings, and pull the
         recovered style latents together with an L1 loss (encoder only).

    Losses are appended to FLAGS.log_file and TensorBoard; encoder/decoder
    weights and sample image grids are saved every 5 epochs.
    """
    # model definition
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
    """
    variable definition
    """

    # Pre-allocated input buffers, refilled in place via copy_() each
    # iteration (legacy pre-0.4 PyTorch Variable-era pattern).
    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)

    # Buffer for the random style vector used in the reverse cycle.
    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)
    """
    loss definitions
    """
    # NOTE(review): defined (and moved to GPU below) but never applied
    # anywhere in this function.
    cross_entropy_loss = nn.CrossEntropyLoss()
    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()

        cross_entropy_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        style_latent_space = style_latent_space.cuda()
    """
    optimizer and scheduler definition
    """
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    # The reverse cycle updates only the encoder; decoder outputs are
    # detached in step B below.
    reverse_cycle_optimizer = optim.Adam(list(encoder.parameters()),
                                         lr=FLAGS.initial_learning_rate,
                                         betas=(FLAGS.beta_1, FLAGS.beta_2))

    # divide the learning rate by a factor of 10 after 80 epochs
    auto_encoder_scheduler = optim.lr_scheduler.StepLR(auto_encoder_optimizer,
                                                       step_size=80,
                                                       gamma=0.1)
    reverse_cycle_scheduler = optim.lr_scheduler.StepLR(
        reverse_cycle_optimizer, step_size=80, gamma=0.1)
    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tKL_divergence_loss\tReverse_cycle_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist',
                                download=True,
                                train=True,
                                transform=transform_config)
    # cycle() makes the loader endless so next() never raises StopIteration.
    loader = cycle(
        DataLoader(paired_mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )

        # update the learning rate scheduler
        # NOTE(review): stepping schedulers at the *start* of the epoch is
        # the pre-1.1 PyTorch convention; on >= 1.1 this effectively skips
        # the initial learning rate for the first epoch — confirm intended.
        auto_encoder_scheduler.step()
        reverse_cycle_scheduler.step()

        for iteration in range(int(len(paired_mnist) / FLAGS.batch_size)):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(
                Variable(X_1))
            style_latent_space_1 = reparameterize(training=True,
                                                  mu=style_mu_1,
                                                  logvar=style_logvar_1)

            # Standard Gaussian KL: -0.5 * sum(1 + logvar - mu^2 - exp(logvar)),
            # scaled by the coefficient and normalized per pixel below.
            kl_divergence_loss_1 = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) -
                                 style_logvar_1.exp()))
            kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels *
                                     FLAGS.image_size * FLAGS.image_size)
            # retain_graph: the encoder graph is reused by the
            # reconstruction losses below.
            kl_divergence_loss_1.backward(retain_graph=True)

            style_mu_2, style_logvar_2, class_latent_space_2 = encoder(
                Variable(X_2))
            style_latent_space_2 = reparameterize(training=True,
                                                  mu=style_mu_2,
                                                  logvar=style_logvar_2)

            kl_divergence_loss_2 = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + style_logvar_2 - style_mu_2.pow(2) -
                                 style_logvar_2.exp()))
            kl_divergence_loss_2 /= (FLAGS.batch_size * FLAGS.num_channels *
                                     FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_2.backward(retain_graph=True)

            # Cross reconstruction: each image's style paired with the
            # *other* image's class latent (the pair shares a class).
            reconstructed_X_1 = decoder(style_latent_space_1,
                                        class_latent_space_2)
            reconstructed_X_2 = decoder(style_latent_space_2,
                                        class_latent_space_1)

            reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(
                reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(
                reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

            # Un-scaled totals kept only for logging; gradients above used
            # the scaled versions.
            reconstruction_error = (
                reconstruction_error_1 +
                reconstruction_error_2) / FLAGS.reconstruction_coef
            kl_divergence_error = (kl_divergence_loss_1 + kl_divergence_loss_2
                                   ) / FLAGS.kl_divergence_coef

            auto_encoder_optimizer.step()

            # B. reverse cycle
            image_batch_1, _, __ = next(loader)
            image_batch_2, _, __ = next(loader)

            reverse_cycle_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            # One random style vector shared by both decodings.
            style_latent_space.normal_(0., 1.)

            _, __, class_latent_space_1 = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))

            # detach(): only the encoder is updated in this phase.
            reconstructed_X_1 = decoder(Variable(style_latent_space),
                                        class_latent_space_1.detach())
            reconstructed_X_2 = decoder(Variable(style_latent_space),
                                        class_latent_space_2.detach())

            style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
            style_latent_space_1 = reparameterize(training=False,
                                                  mu=style_mu_1,
                                                  logvar=style_logvar_1)

            style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
            style_latent_space_2 = reparameterize(training=False,
                                                  mu=style_mu_2,
                                                  logvar=style_logvar_2)

            # Both decodings used the same style, so their recovered style
            # latents should match.
            reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(
                style_latent_space_1, style_latent_space_2)
            reverse_cycle_loss.backward()
            # Division happens after backward: gradients use the scaled
            # loss; the normalized value is for logging only.
            reverse_cycle_loss /= FLAGS.reverse_cycle_coef

            reverse_cycle_optimizer.step()

            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                # .data.storage().tolist()[0] is a legacy way to pull out a
                # scalar; NOTE(review): .item() is the modern equivalent.
                print('Reconstruction loss: ' +
                      str(reconstruction_error.data.storage().tolist()[0]))
                print('KL-Divergence loss: ' +
                      str(kl_divergence_error.data.storage().tolist()[0]))
                print('Reverse cycle loss: ' +
                      str(reverse_cycle_loss.data.storage().tolist()[0]))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                    epoch, iteration,
                    reconstruction_error.data.storage().tolist()[0],
                    kl_divergence_error.data.storage().tolist()[0],
                    reverse_cycle_loss.data.storage().tolist()[0]))

            # write to tensorboard
            writer.add_scalar(
                'Reconstruction loss',
                reconstruction_error.data.storage().tolist()[0],
                epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) +
                iteration)
            writer.add_scalar(
                'KL-Divergence loss',
                kl_divergence_error.data.storage().tolist()[0],
                epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) +
                iteration)
            writer.add_scalar(
                'Reverse cycle loss',
                reverse_cycle_loss.data.storage().tolist()[0],
                epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) +
                iteration)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.decoder_save))
            """
            save reconstructed images and style swapped image generations to check progress
            """
            image_batch_1, image_batch_2, _ = next(loader)
            image_batch_3, _, __ = next(loader)

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)
            X_3.copy_(image_batch_3)

            style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))
            style_mu_3, style_logvar_3, _ = encoder(Variable(X_3))

            style_latent_space_1 = reparameterize(training=False,
                                                  mu=style_mu_1,
                                                  logvar=style_logvar_1)
            style_latent_space_3 = reparameterize(training=False,
                                                  mu=style_mu_3,
                                                  logvar=style_logvar_3)

            reconstructed_X_1_2 = decoder(style_latent_space_1,
                                          class_latent_space_2)
            reconstructed_X_3_2 = decoder(style_latent_space_3,
                                          class_latent_space_2)

            # save input image batch
            image_batch = np.transpose(X_1.cpu().numpy(), (0, 2, 3, 1))
            # Tile the single channel to RGB for display — assumes
            # num_channels == 1 (MNIST); TODO confirm.
            image_batch = np.concatenate(
                (image_batch, image_batch, image_batch), axis=3)
            imshow_grid(image_batch, name=str(epoch) + '_original', save=True)

            # save reconstructed batch
            reconstructed_x = np.transpose(
                reconstructed_X_1_2.cpu().data.numpy(), (0, 2, 3, 1))
            reconstructed_x = np.concatenate(
                (reconstructed_x, reconstructed_x, reconstructed_x), axis=3)
            imshow_grid(reconstructed_x,
                        name=str(epoch) + '_target',
                        save=True)

            style_batch = np.transpose(X_3.cpu().numpy(), (0, 2, 3, 1))
            style_batch = np.concatenate(
                (style_batch, style_batch, style_batch), axis=3)
            imshow_grid(style_batch, name=str(epoch) + '_style', save=True)

            # save style swapped reconstructed batch
            reconstructed_style = np.transpose(
                reconstructed_X_3_2.cpu().data.numpy(), (0, 2, 3, 1))
            reconstructed_style = np.concatenate(
                (reconstructed_style, reconstructed_style,
                 reconstructed_style),
                axis=3)
            imshow_grid(reconstructed_style,
                        name=str(epoch) + '_style_target',
                        save=True)
# Example #4 (score: 0)
def training_procedure(FLAGS):
    """Train a style/class VAE on paired CIFAR with optional GAN losses.

    Per iteration:
      A.  Auto-encoder step — cross reconstruction of a class-sharing image
          pair with KL penalties on the style latents.
      A-1. If FLAGS.forward_gan: adversarial generator/discriminator updates
          on the forward reconstructions.
      B.  Reverse-cycle step — decode one random style with two class
          latents, re-encode, and pull the recovered styles together (L1,
          encoder-only update).
      B-1. If FLAGS.reverse_gan: adversarial updates on the reverse-cycle
          reconstructions.

    Losses go to FLAGS.log_file and TensorBoard; weights and sample image
    grids (from a fixed visualization batch) are saved every 5 epochs.
    """
    # model definition
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        raise Exception('This is not implemented')
        # NOTE(review): dead code — unreachable after the raise above.
        encoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    """
    variable definition
    """

    # Pre-allocated input buffers, refilled in place via copy_() each
    # iteration (legacy pre-0.4 PyTorch Variable-era pattern).
    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)

    # Buffer for the random style vector used in the reverse cycle.
    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    """
    loss definitions
    """
    # NOTE(review): cross_entropy_loss is defined (and moved to GPU below)
    # but never applied anywhere in this function.
    cross_entropy_loss = nn.CrossEntropyLoss()
    adversarial_loss = nn.BCELoss()

    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()
        adversarial_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        style_latent_space = style_latent_space.cuda()

    """
    optimizer and scheduler definition
    """
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    # Reverse cycle updates only the encoder (decoder outputs are detached
    # in step B); GAN phases update decoder and discriminator separately.
    reverse_cycle_optimizer = optim.Adam(
        list(encoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    generator_optimizer = optim.Adam(
        list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    discriminator_optimizer = optim.Adam(
        list(discriminator.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    # divide the learning rate by a factor of 10 after 80 epochs
    auto_encoder_scheduler = optim.lr_scheduler.StepLR(auto_encoder_optimizer, step_size=80, gamma=0.1)
    reverse_cycle_scheduler = optim.lr_scheduler.StepLR(reverse_cycle_optimizer, step_size=80, gamma=0.1)
    generator_scheduler = optim.lr_scheduler.StepLR(generator_optimizer, step_size=80, gamma=0.1)
    discriminator_scheduler = optim.lr_scheduler.StepLR(discriminator_optimizer, step_size=80, gamma=0.1)

    # Used later to define discriminator ground truths
    Tensor = torch.cuda.FloatTensor if FLAGS.cuda else torch.FloatTensor

    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            # Header columns vary with which GAN phases are enabled.
            headers = ['Epoch', 'Iteration', 'Reconstruction_loss', 'KL_divergence_loss', 'Reverse_cycle_loss']

            if FLAGS.forward_gan:
              headers.extend(['Generator_forward_loss', 'Discriminator_forward_loss'])

            if FLAGS.reverse_gan:
              headers.extend(['Generator_reverse_loss', 'Discriminator_reverse_loss'])

            log.write('\t'.join(headers) + '\n')

    # load data set and create data loader instance
    print('Loading CIFAR paired dataset...')
    paired_cifar = CIFAR_Paired(root='cifar', download=True, train=True, transform=transform_config)
    # cycle() makes the loader endless so next() never raises StopIteration.
    loader = cycle(DataLoader(paired_cifar, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True))

    # Save a batch of images to use for visualization
    image_sample_1, image_sample_2, _ = next(loader)
    image_sample_3, _, _ = next(loader)

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        # update the learning rate scheduler
        # NOTE(review): stepping schedulers at the *start* of the epoch is
        # the pre-1.1 PyTorch convention; on >= 1.1 this effectively skips
        # the initial learning rate for the first epoch — confirm intended.
        auto_encoder_scheduler.step()
        reverse_cycle_scheduler.step()
        generator_scheduler.step()
        discriminator_scheduler.step()

        for iteration in range(int(len(paired_cifar) / FLAGS.batch_size)):
            # Adversarial ground truths
            valid = Variable(Tensor(FLAGS.batch_size, 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(FLAGS.batch_size, 1).fill_(0.0), requires_grad=False)

            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(Variable(X_1))
            style_latent_space_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

            # Standard Gaussian KL, scaled by the coefficient and normalized
            # per pixel below.
            kl_divergence_loss_1 = FLAGS.kl_divergence_coef * (
                - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
            )
            kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            # retain_graph: the encoder graph is reused by the
            # reconstruction losses below.
            kl_divergence_loss_1.backward(retain_graph=True)

            style_mu_2, style_logvar_2, class_latent_space_2 = encoder(Variable(X_2))
            style_latent_space_2 = reparameterize(training=True, mu=style_mu_2, logvar=style_logvar_2)

            kl_divergence_loss_2 = FLAGS.kl_divergence_coef * (
                - 0.5 * torch.sum(1 + style_logvar_2 - style_mu_2.pow(2) - style_logvar_2.exp())
            )
            kl_divergence_loss_2 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_2.backward(retain_graph=True)

            # Cross reconstruction: each image's style paired with the
            # *other* image's class latent (the pair shares a class).
            reconstructed_X_1 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_2 = decoder(style_latent_space_2, class_latent_space_1)

            reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

            # Un-scaled totals kept only for logging; gradients above used
            # the scaled versions.
            reconstruction_error = (reconstruction_error_1 + reconstruction_error_2) / FLAGS.reconstruction_coef
            kl_divergence_error = (kl_divergence_loss_1 + kl_divergence_loss_2) / FLAGS.kl_divergence_coef

            auto_encoder_optimizer.step()

            # A-1. Discriminator training during forward cycle
            if FLAGS.forward_gan:
              # Training generator
              generator_optimizer.zero_grad()

              # NOTE(review): Variable(...) does not detach here; only the
              # decoder's params are stepped because of the optimizer's
              # param list — confirm gradient flow is as intended.
              g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid)
              g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid)

              gen_f_loss = (g_loss_1 + g_loss_2) / 2.0
              gen_f_loss.backward()

              generator_optimizer.step()

              # Training discriminator
              discriminator_optimizer.zero_grad()

              real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid)
              real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid)
              fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake)
              fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake)

              dis_f_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0
              dis_f_loss.backward()

              discriminator_optimizer.step()

            # B. reverse cycle
            image_batch_1, _, __ = next(loader)
            image_batch_2, _, __ = next(loader)

            reverse_cycle_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            # One random style vector shared by both decodings.
            style_latent_space.normal_(0., 1.)

            _, __, class_latent_space_1 = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))

            # detach(): only the encoder is updated in this phase.
            reconstructed_X_1 = decoder(Variable(style_latent_space), class_latent_space_1.detach())
            reconstructed_X_2 = decoder(Variable(style_latent_space), class_latent_space_2.detach())

            style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

            style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
            style_latent_space_2 = reparameterize(training=False, mu=style_mu_2, logvar=style_logvar_2)

            # Both decodings used the same style, so their recovered style
            # latents should match.
            reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(style_latent_space_1, style_latent_space_2)
            reverse_cycle_loss.backward()
            # Division happens after backward: gradients use the scaled
            # loss; the normalized value is for logging only.
            reverse_cycle_loss /= FLAGS.reverse_cycle_coef

            reverse_cycle_optimizer.step()

            # B-1. Discriminator training during reverse cycle
            if FLAGS.reverse_gan:
              # Training generator
              generator_optimizer.zero_grad()

              g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid)
              g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid)

              gen_r_loss = (g_loss_1 + g_loss_2) / 2.0
              gen_r_loss.backward()

              generator_optimizer.step()

              # Training discriminator
              discriminator_optimizer.zero_grad()

              real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid)
              real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid)
              fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake)
              fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake)

              dis_r_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0
              dis_r_loss.backward()

              discriminator_optimizer.step()

            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                # .data.storage().tolist()[0] is a legacy way to pull out a
                # scalar; NOTE(review): .item() is the modern equivalent.
                print('Reconstruction loss: ' + str(reconstruction_error.data.storage().tolist()[0]))
                print('KL-Divergence loss: ' + str(kl_divergence_error.data.storage().tolist()[0]))
                print('Reverse cycle loss: ' + str(reverse_cycle_loss.data.storage().tolist()[0]))

                if FLAGS.forward_gan:
                  print('Generator F loss: ' + str(gen_f_loss.data.storage().tolist()[0]))
                  print('Discriminator F loss: ' + str(dis_f_loss.data.storage().tolist()[0]))

                if FLAGS.reverse_gan:
                  print('Generator R loss: ' + str(gen_r_loss.data.storage().tolist()[0]))
                  print('Discriminator R loss: ' + str(dis_r_loss.data.storage().tolist()[0]))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                row = []

                row.append(epoch)
                row.append(iteration)
                row.append(reconstruction_error.data.storage().tolist()[0])
                row.append(kl_divergence_error.data.storage().tolist()[0])
                row.append(reverse_cycle_loss.data.storage().tolist()[0])

                if FLAGS.forward_gan:
                  row.append(gen_f_loss.data.storage().tolist()[0])
                  row.append(dis_f_loss.data.storage().tolist()[0])

                if FLAGS.reverse_gan:
                  row.append(gen_r_loss.data.storage().tolist()[0])
                  row.append(dis_r_loss.data.storage().tolist()[0])

                row = [str(x) for x in row]
                log.write('\t'.join(row) + '\n')

            # write to tensorboard
            writer.add_scalar('Reconstruction loss', reconstruction_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('KL-Divergence loss', kl_divergence_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('Reverse cycle loss', reverse_cycle_loss.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

            if FLAGS.forward_gan:
              writer.add_scalar('Generator F loss', gen_f_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
              writer.add_scalar('Discriminator F loss', dis_f_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

            if FLAGS.reverse_gan:
              writer.add_scalar('Generator R loss', gen_r_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
              writer.add_scalar('Discriminator R loss', dis_r_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))

            """
            save reconstructed images and style swapped image generations to check progress
            """

            # Reuse the fixed visualization batch captured before training
            # so progress images are comparable across epochs.
            X_1.copy_(image_sample_1)
            X_2.copy_(image_sample_2)
            X_3.copy_(image_sample_3)

            style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))
            style_mu_3, style_logvar_3, _ = encoder(Variable(X_3))

            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)
            style_latent_space_3 = reparameterize(training=False, mu=style_mu_3, logvar=style_logvar_3)

            reconstructed_X_1_2 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_3_2 = decoder(style_latent_space_3, class_latent_space_2)

            # save input image batch
            image_batch = np.transpose(X_1.cpu().numpy(), (0, 2, 3, 1))
            # Tile grayscale to 3 channels for display only when needed.
            if FLAGS.num_channels == 1:
              image_batch = np.concatenate((image_batch, image_batch, image_batch), axis=3)
            imshow_grid(image_batch, name=str(epoch) + '_original', save=True)

            # save reconstructed batch
            reconstructed_x = np.transpose(reconstructed_X_1_2.cpu().data.numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              reconstructed_x = np.concatenate((reconstructed_x, reconstructed_x, reconstructed_x), axis=3)
            imshow_grid(reconstructed_x, name=str(epoch) + '_target', save=True)

            style_batch = np.transpose(X_3.cpu().numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              style_batch = np.concatenate((style_batch, style_batch, style_batch), axis=3)
            imshow_grid(style_batch, name=str(epoch) + '_style', save=True)

            # save style swapped reconstructed batch
            reconstructed_style = np.transpose(reconstructed_X_3_2.cpu().data.numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              reconstructed_style = np.concatenate((reconstructed_style, reconstructed_style, reconstructed_style), axis=3)
            imshow_grid(reconstructed_style, name=str(epoch) + '_style_target', save=True)
Пример #5
0
    def build_model(self):
        """Build the CycleGAN graph: two generators (A2B, B2A), two
        discriminators, and the cycle / identity / least-squares adversarial
        losses for both directions.

        Side effects: defines placeholders for real, fake (replay-buffer) and
        test samples, the loss tensors ``self.generator_loss`` and
        ``self.discriminator_loss``, and the per-optimizer variable lists
        ``self.generator_vars`` / ``self.discriminator_vars``.

        NOTE(review): statement order matters — each variable scope is first
        created with ``reuse=False`` and only afterwards re-entered with
        ``reuse=True``; do not reorder these calls.
        """

        # Placeholders for real training samples
        self.input_A_real = tf.placeholder(tf.float32,
                                           shape=self.input_shape,
                                           name='input_A_real')
        self.input_B_real = tf.placeholder(tf.float32,
                                           shape=self.input_shape,
                                           name='input_B_real')
        # Placeholders for fake generated samples
        # (presumably fed from a generated-sample history buffer — confirm
        # against the training loop)
        self.input_A_fake = tf.placeholder(tf.float32,
                                           shape=self.input_shape,
                                           name='input_A_fake')
        self.input_B_fake = tf.placeholder(tf.float32,
                                           shape=self.input_shape,
                                           name='input_B_fake')
        # Placeholder for test samples
        self.input_A_test = tf.placeholder(tf.float32,
                                           shape=self.input_shape,
                                           name='input_A_test')
        self.input_B_test = tf.placeholder(tf.float32,
                                           shape=self.input_shape,
                                           name='input_B_test')

        # Forward cycle: A -> B -> A (creates both generator scopes)
        self.generation_B = self.generator(inputs=self.input_A_real,
                                           reuse=False,
                                           scope_name='generator_A2B')
        self.cycle_A = self.generator(inputs=self.generation_B,
                                      reuse=False,
                                      scope_name='generator_B2A')

        # Backward cycle: B -> A -> B (reuses the scopes created above)
        self.generation_A = self.generator(inputs=self.input_B_real,
                                           reuse=True,
                                           scope_name='generator_B2A')
        self.cycle_B = self.generator(inputs=self.generation_A,
                                      reuse=True,
                                      scope_name='generator_A2B')

        # Identity mappings: feeding a sample through the "wrong-direction"
        # generator should leave it unchanged (identity regularization)
        self.generation_A_identity = self.generator(inputs=self.input_A_real,
                                                    reuse=True,
                                                    scope_name='generator_B2A')
        self.generation_B_identity = self.generator(inputs=self.input_B_real,
                                                    reuse=True,
                                                    scope_name='generator_A2B')

        # Discriminator outputs on freshly generated samples (creates the
        # discriminator scopes)
        self.discrimination_A_fake = self.discriminator(
            inputs=self.generation_A,
            reuse=False,
            scope_name='discriminator_A')
        self.discrimination_B_fake = self.discriminator(
            inputs=self.generation_B,
            reuse=False,
            scope_name='discriminator_B')

        # Cycle loss
        self.cycle_loss = l1_loss(y=self.input_A_real,
                                  y_hat=self.cycle_A) + l1_loss(
                                      y=self.input_B_real, y_hat=self.cycle_B)

        # Identity loss
        self.identity_loss = l1_loss(
            y=self.input_A_real, y_hat=self.generation_A_identity) + l1_loss(
                y=self.input_B_real, y_hat=self.generation_B_identity)

        # Place holder for lambda_cycle and lambda_identity
        self.lambda_cycle = tf.placeholder(tf.float32,
                                           None,
                                           name='lambda_cycle')
        self.lambda_identity = tf.placeholder(tf.float32,
                                              None,
                                              name='lambda_identity')

        # Generator loss
        # Generator wants to fool discriminator
        # (LSGAN objective: push discriminator output towards 1 on fakes)
        self.generator_loss_A2B = l2_loss(y=tf.ones_like(
            self.discrimination_B_fake),
                                          y_hat=self.discrimination_B_fake)
        self.generator_loss_B2A = l2_loss(y=tf.ones_like(
            self.discrimination_A_fake),
                                          y_hat=self.discrimination_A_fake)

        # Merge the two generators and the cycle loss
        self.generator_loss = self.generator_loss_A2B + self.generator_loss_B2A + self.lambda_cycle * self.cycle_loss + self.lambda_identity * self.identity_loss

        # Discriminator loss
        # Note: the discriminator sees fakes via the input_*_fake
        # placeholders (history buffer), not the live generator outputs.
        self.discrimination_input_A_real = self.discriminator(
            inputs=self.input_A_real, reuse=True, scope_name='discriminator_A')
        self.discrimination_input_B_real = self.discriminator(
            inputs=self.input_B_real, reuse=True, scope_name='discriminator_B')
        self.discrimination_input_A_fake = self.discriminator(
            inputs=self.input_A_fake, reuse=True, scope_name='discriminator_A')
        self.discrimination_input_B_fake = self.discriminator(
            inputs=self.input_B_fake, reuse=True, scope_name='discriminator_B')

        # Discriminator wants to classify real and fake correctly
        self.discriminator_loss_input_A_real = l2_loss(
            y=tf.ones_like(self.discrimination_input_A_real),
            y_hat=self.discrimination_input_A_real)
        self.discriminator_loss_input_A_fake = l2_loss(
            y=tf.zeros_like(self.discrimination_input_A_fake),
            y_hat=self.discrimination_input_A_fake)
        self.discriminator_loss_A = (self.discriminator_loss_input_A_real +
                                     self.discriminator_loss_input_A_fake) / 2

        self.discriminator_loss_input_B_real = l2_loss(
            y=tf.ones_like(self.discrimination_input_B_real),
            y_hat=self.discrimination_input_B_real)
        self.discriminator_loss_input_B_fake = l2_loss(
            y=tf.zeros_like(self.discrimination_input_B_fake),
            y_hat=self.discrimination_input_B_fake)
        self.discriminator_loss_B = (self.discriminator_loss_input_B_real +
                                     self.discriminator_loss_input_B_fake) / 2

        # Merge the two discriminators into one
        self.discriminator_loss = self.discriminator_loss_A + self.discriminator_loss_B

        # Categorize variables because we have to optimize the two sets of the variables separately
        trainable_variables = tf.trainable_variables()
        self.discriminator_vars = [
            var for var in trainable_variables if 'discriminator' in var.name
        ]
        self.generator_vars = [
            var for var in trainable_variables if 'generator' in var.name
        ]
        #for var in t_vars: print(var.name)

        # Reserved for test
        self.generation_B_test = self.generator(inputs=self.input_A_test,
                                                reuse=True,
                                                scope_name='generator_A2B')
        self.generation_A_test = self.generator(inputs=self.input_B_test,
                                                reuse=True,
                                                scope_name='generator_B2A')
Пример #6
0
def local_eval(model, test_loader, path_GT):
    """Evaluate *model* on *test_loader* and return the L1 loss against the
    ground-truth images stored under *path_GT*.

    The loss is also printed for quick inspection during training.
    """
    filenames, predictions = _infer(model, None, test_loader=test_loader)
    ground_truth = read_prediction_gt(path_GT, filenames)
    score = float(l1_loss(predictions, ground_truth))
    print('local_eval', score)
    return score
Пример #7
0
    def build_model(self):
        """Build the pitch-conversion CycleGAN graph.

        Generators predict LDDMM momenta conditioned on MFC features; the
        actual pitch contours are obtained by shooting the momenta through
        ``forward_tan`` with a fixed kernel. Discriminators score *pairs* of
        contours (source, candidate). Losses: cycle consistency (l1),
        adversarial terms (l1 against one/zero targets), and a smoothness
        penalty on the momenta via ``self.first_order_diff_mat``.

        NOTE(review): relies on ``self.first_order_diff_mat``,
        ``self.pitch_shape`` and ``self.mfc_shape`` being defined elsewhere
        (presumably in ``__init__``) — confirm against the class definition.
        NOTE(review): statement order matters — scopes are created with
        ``reuse=False`` before any ``reuse=True`` call; do not reorder.
        """

        # Placeholders for real training samples
        self.pitch_A_real = tf.placeholder(tf.float32, \
                            shape=self.pitch_shape, name='pitch_A_real')
        self.pitch_B_real = tf.placeholder(tf.float32, \
                            shape=self.pitch_shape, name='pitch_B_real')

        # MFC (mel-frequency cepstrum) conditioning features for each domain
        self.mfc_A = tf.placeholder(tf.float32, \
                            shape=self.mfc_shape, name='mfc_A_real')
        self.mfc_B = tf.placeholder(tf.float32, \
                            shape=self.mfc_shape, name='mfc_B_real')

        # Placeholders for fake generated samples
        self.pitch_A_fake = tf.placeholder(tf.float32, \
                            shape=self.pitch_shape, name='pitch_A_fake')
        self.pitch_B_fake = tf.placeholder(tf.float32, \
                            shape=self.pitch_shape, name='pitch_B_fake')

        # Placeholder for test samples
        self.pitch_A_test = tf.placeholder(tf.float32, \
                            shape=self.pitch_shape, name='pitch_A_test')
        self.mfc_A_test = tf.placeholder(tf.float32, \
                            shape=self.mfc_shape, name='mfc_A_test')

        self.pitch_B_test = tf.placeholder(tf.float32, \
                            shape=self.pitch_shape, name='pitch_B_test')
        self.mfc_B_test = tf.placeholder(tf.float32, \
                            shape=self.mfc_shape, name='mfc_B_test')

        # Place holder for lambda_cycle and lambda_identity
        self.lambda_cycle = tf.placeholder(tf.float32, None, \
                                name='lambda_cycle')
        self.lambda_momenta = tf.placeholder(tf.float32, \
                                None, name='lambda_momenta')

        # Create the kernel for lddmm
        # (fixed constants [6, 50] — meaning not visible here; presumably
        # kernel width parameters for forward_tan — confirm)
        self.kernel = tf.expand_dims(tf.constant([6,50], \
                            dtype=tf.float32), axis=0)

        # Generate pitch from A to B
        # generator emits momenta; forward_tan shoots them to get the contour
        self.momentum_A2B = self.generator(input_pitch=self.pitch_A_real, \
                            input_mfc=self.mfc_A, \
                            reuse=False, scope_name='generator_A2B')
        self.generation_A2B = forward_tan(x=self.pitch_A_real, \
                            p=self.momentum_A2B, kernel=self.kernel)
        self.momentum_cycle_A2A = self.generator(input_pitch=self.generation_A2B, \
                            input_mfc=self.mfc_B, \
                            reuse=False, scope_name='generator_B2A')
        self.cycle_A2A = forward_tan(x=self.generation_A2B, \
                            p=self.momentum_cycle_A2A, kernel=self.kernel)

        # Generate pitch from B to A
        self.momentum_B2A = self.generator(input_pitch=self.pitch_B_real, \
                            input_mfc=self.mfc_B, \
                            reuse=True, scope_name='generator_B2A')
        self.generation_B2A = forward_tan(x=self.pitch_B_real, \
                            p=self.momentum_B2A, kernel=self.kernel)
        self.momentum_cycle_B2B = self.generator(input_pitch=self.generation_B2A, \
                            input_mfc=self.mfc_A, \
                            reuse=True, scope_name='generator_A2B')
        self.cycle_B2B = forward_tan(x=self.generation_B2A, \
                            p=self.momentum_cycle_B2B, kernel=self.kernel)

        # Generator Discriminator Loss
        # discriminators score (source, candidate) pairs
        self.discrimination_B_fake \
                = self.discriminator(input1=self.pitch_A_real, \
                    input2=self.generation_A2B, reuse=False, \
                    scope_name='discriminator_A')
        self.discrimination_A_fake \
                = self.discriminator(input1=self.pitch_B_real, \
                    input2=self.generation_B2A, reuse=False, \
                    scope_name='discriminator_B')

        # Cycle loss
        self.cycle_loss = (l1_loss(y=self.pitch_A_real, y_hat=self.cycle_A2A) \
                        + l1_loss(y=self.pitch_B_real, y_hat=self.cycle_B2B)) / 2.0

        # Generator loss
        # Generator wants to fool discriminator
        self.generator_loss_A2B \
            = l1_loss(y=tf.ones_like(self.discrimination_B_fake), \
                y_hat=self.discrimination_B_fake)
        self.generator_loss_B2A \
            = l1_loss(y=tf.ones_like(self.discrimination_A_fake), \
                y_hat=self.discrimination_A_fake)
        self.gen_disc_loss = (self.generator_loss_A2B \
                                + self.generator_loss_B2A) / 2.0

        # Smoothness penalty: squared first-order differences of the momenta
        # (both the direct and the cycle momenta) for the A->B direction
        self.momentum_loss_A2B \
            = tf.reduce_sum(tf.square(tf.matmul(self.first_order_diff_mat, \
                tf.reshape(self.momentum_A2B, [-1,1])))) \
                + tf.reduce_sum(tf.square(tf.matmul(self.first_order_diff_mat, \
                tf.reshape(self.momentum_cycle_A2A, [-1,1]))))

        self.momentum_loss_B2A \
            = tf.reduce_sum(tf.square(tf.matmul(self.first_order_diff_mat, \
                tf.reshape(self.momentum_B2A, [-1,1])))) \
                + tf.reduce_sum(tf.square(tf.matmul(self.first_order_diff_mat, \
                tf.reshape(self.momentum_cycle_B2B, [-1,1]))))

        self.momenta_loss = (self.momentum_loss_A2B +
                             self.momentum_loss_B2A) / 2.0

        # Merge the two generators, the cycle loss and vector field regularization
        # convex combination weighted by the lambda placeholders
        self.generator_loss = (1 - self.lambda_cycle - self.lambda_momenta) \
                                * self.gen_disc_loss \
                                + self.lambda_cycle * self.cycle_loss \
                                + self.lambda_momenta * self.momenta_loss

        # Compute the discriminator probability for pair of inputs
        self.discrimination_input_A_real_B_fake \
            = self.discriminator(input1=self.pitch_A_real, \
                input2=self.pitch_B_fake, reuse=True, \
                scope_name='discriminator_A')
        self.discrimination_input_A_fake_B_real \
            = self.discriminator(input1=self.pitch_A_fake, \
                input2=self.pitch_B_real, reuse=True, \
                scope_name='discriminator_A')

        self.discrimination_input_B_real_A_fake \
            = self.discriminator(input1=self.pitch_B_real, \
                input2=self.pitch_A_fake, reuse=True, \
                scope_name='discriminator_B')
        self.discrimination_input_B_fake_A_real \
            = self.discriminator(input1=self.pitch_B_fake, \
                input2=self.pitch_A_real, reuse = True, \
                scope_name='discriminator_B')

        # Compute discriminator loss for backprop
        # targets: 0 for (real, fake) pairs, 1 for (fake, real) pairs
        self.discriminator_loss_input_A_real \
            = l1_loss(y=tf.zeros_like(self.discrimination_input_A_real_B_fake), \
                            y_hat=self.discrimination_input_A_real_B_fake)
        self.discriminator_loss_input_A_fake \
            = l1_loss(y=tf.ones_like(self.discrimination_input_A_fake_B_real), \
                            y_hat = self.discrimination_input_A_fake_B_real)
        self.discriminator_loss_A = (self.discriminator_loss_input_A_real \
                                     + self.discriminator_loss_input_A_fake) / 2.0

        self.discriminator_loss_input_B_real \
            = l1_loss(y=tf.zeros_like(self.discrimination_input_B_real_A_fake), \
                            y_hat=self.discrimination_input_B_real_A_fake)
        self.discriminator_loss_input_B_fake \
            = l1_loss(y=tf.ones_like(self.discrimination_input_B_fake_A_real), \
                            y_hat = self.discrimination_input_B_fake_A_real)
        self.discriminator_loss_B = (self.discriminator_loss_input_B_real \
                                     + self.discriminator_loss_input_B_fake) / 2.0

        # Merge the two discriminators into one
        self.discriminator_loss = (self.discriminator_loss_A \
                                    + self.discriminator_loss_B) / 2.0

        # Categorize variables to optimize the two sets separately
        trainable_variables = tf.trainable_variables()
        self.discriminator_vars = [var for var in trainable_variables \
                                   if 'discriminator' in var.name]
        self.generator_vars = [var for var in trainable_variables \
                               if 'generator' in var.name]

        # Reserved for test
        self.momentum_A2B_test = self.generator(input_pitch=self.pitch_A_test, \
                                    input_mfc=self.mfc_A_test, \
                                    reuse=True, scope_name='generator_A2B')
        self.generation_A2B_test = forward_tan(x=self.pitch_A_test, \
                                    p=self.momentum_A2B_test, kernel=self.kernel)

        self.momentum_B2A_test = self.generator(input_pitch=self.pitch_B_test, \
                                    input_mfc=self.mfc_B_test, \
                                    reuse=True, scope_name='generator_B2A')
        self.generation_B2A_test = forward_tan(x=self.pitch_B_test, \
                                    p=self.momentum_B2A_test, kernel=self.kernel)
Пример #8
0
    def build_model(self):
        """Build a CycleGAN graph with a BEGAN-style discriminator objective.

        Generators and cycle/identity losses follow the standard CycleGAN
        layout; the adversarial terms are l1 reconstruction-style losses and
        the discriminator loss is balanced by the ``k_t_A`` / ``k_t_B``
        control placeholders (BEGAN equilibrium scheme — the gamma/lambda_k
        placeholders are defined here but consumed by the training loop).

        NOTE(review): ``input_A_fake`` / ``input_B_fake`` placeholders are
        created but never used below — the discriminator fake terms use the
        live ``generation_A`` / ``generation_B`` instead. Possibly
        intentional (no history buffer), confirm against the training loop.
        NOTE(review): statement order matters — scopes are created with
        ``reuse=False`` before any ``reuse=True`` call; do not reorder.
        """

        # Placeholders for real training samples
        self.input_A_real = tf.placeholder(tf.float32,
                                           shape=self.input_shape,
                                           name='input_A_real')
        self.input_B_real = tf.placeholder(tf.float32,
                                           shape=self.input_shape,
                                           name='input_B_real')

        self.input_A_fake = tf.placeholder(tf.float32,
                                           shape=self.input_shape,
                                           name='input_A_fake')  # already-generated sample
        self.input_B_fake = tf.placeholder(tf.float32,
                                           shape=self.input_shape,
                                           name='input_B_fake')

        # Placeholders for test samples
        self.input_A_test = tf.placeholder(tf.float32,
                                           shape=self.input_shape,
                                           name='input_A_test')
        self.input_B_test = tf.placeholder(tf.float32,
                                           shape=self.input_shape,
                                           name='input_B_test')

        # Forward cycle: A -> B -> A (creates both generator scopes)
        self.generation_B = self.generator(inputs=self.input_A_real,
                                           reuse=False,
                                           scope_name='generator_A2B')
        self.cycle_A = self.generator(inputs=self.generation_B,
                                      reuse=False,
                                      scope_name='generator_B2A')

        # Backward cycle: B -> A -> B
        self.generation_A = self.generator(inputs=self.input_B_real,
                                           reuse=True,
                                           scope_name='generator_B2A')
        self.cycle_B = self.generator(inputs=self.generation_A,
                                      reuse=True,
                                      scope_name='generator_A2B')

        # Identity mappings for the identity-loss regularizer
        self.generation_A_identity = self.generator(inputs=self.input_A_real,
                                                    reuse=True,
                                                    scope_name='generator_B2A')
        self.generation_B_identity = self.generator(inputs=self.input_B_real,
                                                    reuse=True,
                                                    scope_name='generator_A2B')

        # Discriminator outputs on generated samples (creates the scopes)
        self.discrimination_A_fake = self.discriminator(
            inputs=self.generation_A,
            reuse=False,
            scope_name='discriminator_A')
        self.discrimination_B_fake = self.discriminator(
            inputs=self.generation_B,
            reuse=False,
            scope_name='discriminator_B')

        # Cycle-consistency loss
        self.cycle_loss = l1_loss(y=self.input_A_real,
                                  y_hat=self.cycle_A) + l1_loss(
                                      y=self.input_B_real, y_hat=self.cycle_B)

        # Identity loss
        self.identity_loss = l1_loss(
            y=self.input_A_real, y_hat=self.generation_A_identity) + l1_loss(
                y=self.input_B_real, y_hat=self.generation_B_identity)

        # Loss-weight placeholders (fed by the training loop)
        self.lambda_cycle = tf.placeholder(tf.float32,
                                           None,
                                           name='lambda_cycle')
        self.lambda_identity = tf.placeholder(tf.float32,
                                              None,
                                              name='lambda_identity')

        # BEGAN-style generator loss: discriminator output should
        # reconstruct the generated sample
        self.generator_loss_B2A = l1_loss(y=self.discrimination_A_fake,
                                          y_hat=self.generation_A)
        self.generator_loss_A2B = l1_loss(y=self.discrimination_B_fake,
                                          y_hat=self.generation_B)

        self.generator_loss = self.generator_loss_A2B + self.generator_loss_B2A + self.lambda_cycle * self.cycle_loss + self.lambda_identity * self.identity_loss

        self.discrimination_input_A_real = self.discriminator(
            inputs=self.input_A_real, reuse=True, scope_name='discriminator_A')
        self.discrimination_input_B_real = self.discriminator(
            inputs=self.input_B_real, reuse=True, scope_name='discriminator_B')
        self.discrimination_input_A_fake = self.discriminator(
            inputs=self.generation_A, reuse=True, scope_name='discriminator_A')
        self.discrimination_input_B_fake = self.discriminator(
            inputs=self.generation_B, reuse=True, scope_name='discriminator_B')

        # BEGAN control knobs: k_t balances real/fake terms; gamma and
        # lambda_k drive the k_t update in the training loop (not used in
        # this graph directly)
        self.k_t_A = tf.placeholder(tf.float32, None, name='k_t_A')
        self.k_t_B = tf.placeholder(tf.float32, None, name='k_t_B')
        self.gamma_A = tf.placeholder(tf.float32, None, name='gamma_A')
        self.gamma_B = tf.placeholder(tf.float32, None, name='gamma_B')
        self.lambda_k_A = tf.placeholder(tf.float32, None, name='lambda_k_A')
        self.lambda_k_B = tf.placeholder(tf.float32, None, name='lambda_k_B')

        # Discriminator wants to classify real and fake correctly
        # (BEGAN: reconstruction error on real minus k_t * error on fake)
        self.discriminator_loss_input_A_real = l1_loss(
            y=self.discrimination_input_A_real, y_hat=self.input_A_real)
        self.discriminator_loss_input_A_fake = l1_loss(
            y=self.discrimination_input_A_fake, y_hat=self.generation_A)

        self.discriminator_loss_A = self.discriminator_loss_input_A_real - (
            self.k_t_A * self.discriminator_loss_input_A_fake)

        self.discriminator_loss_input_B_real = l1_loss(
            y=self.discrimination_input_B_real, y_hat=self.input_B_real)
        self.discriminator_loss_input_B_fake = l1_loss(
            y=self.discrimination_input_B_fake, y_hat=self.generation_B)

        self.discriminator_loss_B = self.discriminator_loss_input_B_real - (
            self.k_t_B * self.discriminator_loss_input_B_fake)

        # Merge the two discriminators into one
        self.discriminator_loss = self.discriminator_loss_A + self.discriminator_loss_B

        # Split trainable variables so the two optimizers update disjoint sets
        trainable_variables = tf.trainable_variables()
        self.discriminator_vars = [
            var for var in trainable_variables if 'discriminator' in var.name
        ]
        self.generator_vars = [
            var for var in trainable_variables if 'generator' in var.name
        ]

        # Reserved for test
        self.generation_B_test = self.generator(inputs=self.input_A_test,
                                                reuse=True,
                                                scope_name='generator_A2B')
        self.generation_A_test = self.generator(inputs=self.input_B_test,
                                                reuse=True,
                                                scope_name='generator_B2A')
Пример #9
0
    def build_model(self):
        """Build a CycleGAN graph with a cross-entropy (``celoss``)
        adversarial objective and a custom merged loss.

        Standard CycleGAN wiring (two generators, two discriminators, cycle
        and identity losses); the final ``self.generator_loss`` /
        ``self.discriminator_loss`` are assembled from the hand-tuned
        ``A_loss``/``B_loss``/``Ag_loss``/``Bg_loss`` terms below rather
        than the earlier ``generator_loss_A2B``/``B2A`` tensors.

        NOTE(review): statement order matters — scopes are created with
        ``reuse=False`` before any ``reuse=True`` call; do not reorder.
        """

        # Placeholders for real training samples
        self.input_A_real = tf.placeholder(tf.float32,
                                           shape=self.input_shape,
                                           name='input_A_real')
        self.input_B_real = tf.placeholder(tf.float32,
                                           shape=self.input_shape,
                                           name='input_B_real')
        # Placeholders for fake generated samples
        self.input_A_fake = tf.placeholder(tf.float32,
                                           shape=self.input_shape,
                                           name='input_A_fake')
        self.input_B_fake = tf.placeholder(tf.float32,
                                           shape=self.input_shape,
                                           name='input_B_fake')
        # Placeholder for test samples
        self.input_A_test = tf.placeholder(tf.float32,
                                           shape=self.input_shape,
                                           name='input_A_test')
        self.input_B_test = tf.placeholder(tf.float32,
                                           shape=self.input_shape,
                                           name='input_B_test')

        # Forward cycle: A -> B -> A (creates both generator scopes)
        self.generation_B = self.generator(inputs=self.input_A_real,
                                           reuse=False,
                                           scope_name='generator_A2B')
        self.cycle_A = self.generator(inputs=self.generation_B,
                                      reuse=False,
                                      scope_name='generator_B2A')

        # Backward cycle: B -> A -> B
        self.generation_A = self.generator(inputs=self.input_B_real,
                                           reuse=True,
                                           scope_name='generator_B2A')
        self.cycle_B = self.generator(inputs=self.generation_A,
                                      reuse=True,
                                      scope_name='generator_A2B')

        # Identity mappings for the identity-loss regularizer
        self.generation_A_identity = self.generator(inputs=self.input_A_real,
                                                    reuse=True,
                                                    scope_name='generator_B2A')
        self.generation_B_identity = self.generator(inputs=self.input_B_real,
                                                    reuse=True,
                                                    scope_name='generator_A2B')

        # Discriminator outputs on generated samples (creates the scopes)
        self.discrimination_A_fake = self.discriminator(
            inputs=self.generation_A,
            reuse=False,
            scope_name='discriminator_A')
        self.discrimination_B_fake = self.discriminator(
            inputs=self.generation_B,
            reuse=False,
            scope_name='discriminator_B')

        # Cycle loss
        self.cycle_loss = l1_loss(y=self.input_A_real,
                                  y_hat=self.cycle_A) + l1_loss(
                                      y=self.input_B_real, y_hat=self.cycle_B)

        # Identity loss
        self.identity_loss = l1_loss(
            y=self.input_A_real, y_hat=self.generation_A_identity) + l1_loss(
                y=self.input_B_real, y_hat=self.generation_B_identity)

        # Place holder for lambda_cycle and lambda_identity
        self.lambda_cycle = tf.placeholder(tf.float32,
                                           None,
                                           name='lambda_cycle')
        self.lambda_identity = tf.placeholder(tf.float32,
                                              None,
                                              name='lambda_identity')

        # Generator loss
        # Generator wants to fool the discriminator
        self.generator_loss_A2B = celoss(
            tf.ones_like(self.discrimination_B_fake),
            self.discrimination_B_fake)
        self.generator_loss_B2A = celoss(
            tf.ones_like(self.discrimination_A_fake),
            self.discrimination_A_fake)
        # Discriminator loss
        # fakes come from the input_*_fake placeholders (history buffer)
        self.discrimination_input_A_real = self.discriminator(
            inputs=self.input_A_real, reuse=True, scope_name='discriminator_A')
        self.discrimination_input_B_real = self.discriminator(
            inputs=self.input_B_real, reuse=True, scope_name='discriminator_B')
        self.discrimination_input_A_fake = self.discriminator(
            inputs=self.input_A_fake, reuse=True, scope_name='discriminator_A')
        self.discrimination_input_B_fake = self.discriminator(
            inputs=self.input_B_fake, reuse=True, scope_name='discriminator_B')

        # Discriminator wants to classify real and fake correctly
        self.discriminator_loss_input_A_real = celoss(
            tf.ones_like(self.discrimination_input_A_real),
            self.discrimination_input_A_real)
        self.discriminator_loss_input_A_fake = celoss(
            tf.zeros_like(self.discrimination_input_A_fake),
            self.discrimination_input_A_fake)
        self.discriminator_loss_A = self.discriminator_loss_input_A_fake + self.discriminator_loss_input_A_real

        self.discriminator_loss_input_B_real = celoss(
            tf.ones_like(self.discrimination_input_B_real),
            self.discrimination_input_B_real)
        self.discriminator_loss_input_B_fake = celoss(
            tf.zeros_like(self.discrimination_input_B_fake),
            self.discrimination_input_B_fake)
        self.discriminator_loss_B = self.discriminator_loss_input_B_fake + self.discriminator_loss_input_B_real

        # Custom losses
        # Total generator loss
        self.A_loss = tf.reduce_mean(tf.abs(self.cycle_A - self.input_A_real))
        self.B_loss = tf.reduce_mean(tf.abs(self.cycle_B - self.input_B_real))
        # NOTE(review): argument order here (logits-then-labels) is the
        # reverse of every other celoss call above (labels-then-logits) —
        # verify against the celoss signature; likely a swapped-args bug.
        self.Ag_loss = celoss(self.discrimination_A_fake,
                              tf.ones_like(self.discrimination_A_fake))
        self.Bg_loss = celoss(self.discrimination_B_fake,
                              tf.ones_like(self.discrimination_B_fake))
        self.generator_loss = self.Ag_loss + 20. * self.B_loss + self.Bg_loss + 20. * self.A_loss + self.identity_loss
        #self.generator_loss = self.generator_loss_A2B + 200.*self.B_loss + self.generator_loss_B2A + 200.*self.A_loss
        #self.generator_loss = self.generator_loss_A2B + 20.*self.B_loss + self.generator_loss_B2A + 20.*self.A_loss
        # Total discriminator loss
        #if self.train_step %5 ==0:
        self.Ad_loss = self.discriminator_loss_input_A_real + self.discriminator_loss_input_A_fake
        self.Bd_loss = self.discriminator_loss_input_B_real + self.discriminator_loss_input_B_fake
        self.discriminator_loss = self.Ad_loss + self.Bd_loss
        # Merge the two discriminators into one
        #self.discriminator_loss = self.discriminator_loss_A + self.discriminator_loss_B

        # Categorize variables because we have to optimize the two sets of the variables separately
        trainable_variables = tf.trainable_variables()
        self.discriminator_vars = [
            var for var in trainable_variables if 'discriminator' in var.name
        ]
        self.generator_vars = [
            var for var in trainable_variables if 'generator' in var.name
        ]
        #for var in t_vars: print(var.name)

        # Reserved for test
        self.generation_B_test = self.generator(inputs=self.input_A_test,
                                                reuse=True,
                                                scope_name='generator_A2B')
        self.generation_A_test = self.generator(inputs=self.input_B_test,
                                                reuse=True,
                                                scope_name='generator_B2A')
def train_step(inputs):
    """Run one optimization step of the two-step-adversarial CycleGAN model.

    ``inputs`` is a two-element sequence ``(real_A, real_B)``.  The global
    ``model`` is expected to return 14 tensors: the cross-domain generations,
    cycle reconstructions, identity mappings, and the real/fake logits of the
    four discriminators (A, B, A_dot, B_dot).  Gradients are recorded on two
    independent tapes, applied via the global ``generator_optimizer`` and
    ``discriminator_optimizer``, and the scalar losses are fed into the
    ``gen_loss`` / ``disc_loss`` metric trackers.
    """
    with tf.GradientTape() as tape_g, tf.GradientTape() as tape_d:
        real_A, real_B = inputs[0], inputs[1]

        # Forward pass; unpack the model's 14 outputs positionally.
        outs = model(inputs)
        (gen_A, gen_B, cyc_A, cyc_B, ident_A, ident_B,
         d_A_real, d_A_fake, d_B_real, d_B_fake,
         d_A_dot_real, d_A_dot_fake,
         d_B_dot_real, d_B_dot_fake) = outs[:14]

        # Cycle-consistency and identity reconstruction terms (L1).
        cycle_loss = l1_loss(real_A, cyc_A) + l1_loss(real_B, cyc_B)
        identity_loss = l1_loss(real_A, ident_A) + l1_loss(real_B, ident_B)

        # Least-squares adversarial terms: generators push fake logits to 1.
        g_A2B = l2_loss(tf.ones_like(d_B_fake), d_B_fake)
        g_B2A = l2_loss(tf.ones_like(d_A_fake), d_A_fake)
        g_two_step_A = l2_loss(tf.ones_like(d_A_dot_fake), d_A_dot_fake)
        g_two_step_B = l2_loss(tf.ones_like(d_B_dot_fake), d_B_dot_fake)

        generator_loss = (g_A2B + g_B2A + g_two_step_A + g_two_step_B
                          + hp.lambda_cycle * cycle_loss
                          + hp.lambda_identity * identity_loss)

        def lsgan_d_loss(real_logits, fake_logits):
            # Discriminator targets: 1 for real, 0 for fake; halved average.
            return (l2_loss(tf.ones_like(real_logits), real_logits)
                    + l2_loss(tf.zeros_like(fake_logits), fake_logits)) / 2

        discriminator_loss = (lsgan_d_loss(d_A_real, d_A_fake)
                              + lsgan_d_loss(d_B_real, d_B_fake)
                              + lsgan_d_loss(d_A_dot_real, d_A_dot_fake)
                              + lsgan_d_loss(d_B_dot_real, d_B_dot_fake))

    # The two generators and the four discriminators are optimized with
    # separate variable sets and separate optimizers.
    gen_vars = (model.generatorA2B.trainable_variables
                + model.generatorB2A.trainable_variables)
    dis_vars = (model.discriminator_A.trainable_variables
                + model.discriminator_B.trainable_variables
                + model.discriminator_A_dot.trainable_variables
                + model.discriminator_B_dot.trainable_variables)

    generator_optimizer.apply_gradients(
        zip(tape_g.gradient(generator_loss, sources=gen_vars), gen_vars))
    discriminator_optimizer.apply_gradients(
        zip(tape_d.gradient(discriminator_loss, sources=dis_vars), dis_vars))

    # Feed the step losses into the running metric accumulators.
    gen_loss(generator_loss)
    disc_loss(discriminator_loss)
    def build_model(self):
        """Construct the TF1 static graph for the encoder-decoder-generator
        pipeline mapping speaker-A features to speaker-B features.

        Builds feed placeholders, the training forward pass
        (encoder -> decoder -> generator), three L1 losses plus their
        runtime-weighted sum, the per-scope variable partitions, and a
        `reuse=True` copy of the pipeline for inference.

        NOTE(review): uses TF1 `tf.placeholder` and variable-scope reuse;
        not runnable under TF2 eager mode without `tf.compat.v1`.
        """

        # Placeholders for training samples
        # (shapes come from self.input_pitch_shape / self.input_mfc_shape,
        # which are set outside this method — presumably in __init__.)
        self.input_pitch_A = tf.placeholder(tf.float32,
                                            shape=self.input_pitch_shape,
                                            name='input_pitch_A')
        self.input_pitch_B = tf.placeholder(tf.float32,
                                            shape=self.input_pitch_shape,
                                            name='input_pitch_B')
        # Target momenta for supervising the encoder output.
        self.input_momenta_A2B = tf.placeholder(tf.float32,
                                                shape=self.input_pitch_shape,
                                                name='input_moment_A2B')
        self.input_mfc_A = tf.placeholder(tf.float32,
                                          shape=self.input_mfc_shape,
                                          name='input_mfc_A')
        self.input_mfc_B = tf.placeholder(tf.float32,
                                          shape=self.input_mfc_shape,
                                          name='input_mfc_B')

        # Placeholders for test samples
        self.input_mfc_test = tf.placeholder(tf.float32,
                                             shape=self.input_mfc_shape,
                                             name='input_mfc_test')
        self.input_pitch_test = tf.placeholder(tf.float32,
                                               shape=self.input_pitch_shape,
                                               name='input_pitch_test')

        # Generate momenta and pitch B
        # Training forward pass: mfc_A + pitch_A -> momenta -> pitch_B -> mfc_B.
        # First use of each scope, hence reuse=False.
        self.generation_momenta_A2B = self.encoder(input_mfc=self.input_mfc_A, \
                                        input_pitch=self.input_pitch_A, reuse=False, \
                                        scope_name='encoder')
        self.generation_pitch_B = self.decoder(input_momenta=self.generation_momenta_A2B, \
                                    input_pitch=self.input_pitch_A, reuse=False, \
                                    scope_name='decoder')
        self.generation_mfc_B = self.generator(input_mfc=self.input_mfc_A, \
                                    input_pitch=self.generation_pitch_B, \
                                    num_mfc=self.num_mfc_features, \
                                    training=True, reuse=False, \
                                    scope_name='generator')

        # Encoder loss: L1 between predicted and supplied target momenta.
        self.encoder_loss = l1_loss(y=self.input_momenta_A2B,
                                    y_hat=self.generation_momenta_A2B)

        # Decoder loss: L1 between predicted and real speaker-B pitch.
        self.decoder_loss = l1_loss(y=self.input_pitch_B,
                                    y_hat=self.generation_pitch_B)

        # Generator loss: L1 between predicted and real speaker-B MFCs.
        self.generator_loss = l1_loss(y=self.input_mfc_B,
                                      y_hat=self.generation_mfc_B)

        # Place holder for lambda_encoder and lambda_decoder
        # (scalar loss weights fed at run time).
        self.lambda_encoder = tf.placeholder(tf.float32,
                                             None,
                                             name='lambda_encoder')
        self.lambda_decoder = tf.placeholder(tf.float32,
                                             None,
                                             name='lambda_decoder')
        self.lambda_generator = tf.placeholder(tf.float32,
                                               None,
                                               name='lambda_generator')

        # Merge the encoder-decoder-generator
        # into a single jointly optimized weighted objective.
        self.encoder_decoder_loss = self.lambda_encoder * self.encoder_loss \
                                + self.lambda_decoder * self.decoder_loss \
                                + self.lambda_generator * self.generator_loss

        # Categorize variables because we have to optimize the two sets of the variables separately
        # (membership is decided by substring match on the scope name).
        trainable_variables = tf.trainable_variables()
        self.encoder_vars = [
            var for var in trainable_variables if 'encoder' in var.name
        ]
        self.decoder_vars = [
            var for var in trainable_variables if 'decoder' in var.name
        ]
        self.generator_vars = [
            var for var in trainable_variables if 'generator' in var.name
        ]
        #for var in t_vars: print(var.name)

        # Reserved for test
        # Inference copy of the pipeline sharing weights via reuse=True;
        # generator runs with training=False.
        self.momenta_A2B_test = self.encoder(input_mfc=self.input_mfc_test, \
                                input_pitch=self.input_pitch_test, \
                                reuse=True, scope_name='encoder')
        self.pitch_B_test = self.decoder(input_momenta=self.momenta_A2B_test, \
                                input_pitch=self.input_pitch_test, reuse=True, \
                                scope_name='decoder')
        self.mfc_B_test = self.generator(input_mfc=self.input_mfc_test, \
                                input_pitch=self.pitch_B_test, \
                                num_mfc=self.num_mfc_features, training=False, \
                                reuse=True, scope_name='generator')
Пример #12
0
    def build_model(self):
        """Build the TF1 CycleGAN graph: input placeholders, both
        generators, both discriminators, the LSGAN / cycle / identity
        losses, and the variable partitions used by the two optimizers.
        """

        def make_input(name):
            # Every audio-feature input shares the same static shape.
            return tf.placeholder(tf.float32, shape=self.input_shape,
                                  name=name)

        # Real training samples, externally generated fakes (fed back for
        # the discriminator update), and held-out test samples.
        self.input_A_real = make_input('input_A_real')
        self.input_B_real = make_input('input_B_real')
        self.input_A_fake = make_input('input_A_fake')
        self.input_B_fake = make_input('input_B_fake')
        self.input_A_test = make_input('input_A_test')
        self.input_B_test = make_input('input_B_test')

        G, D = self.generator, self.discriminator

        # Forward cycle A -> B -> A (first use of each scope: reuse=False).
        self.generation_B = G(inputs=self.input_A_real, reuse=False,
                              scope_name='generator_A2B')
        self.cycle_A = G(inputs=self.generation_B, reuse=False,
                         scope_name='generator_B2A')

        # Backward cycle B -> A -> B (scopes already exist: reuse=True).
        self.generation_A = G(inputs=self.input_B_real, reuse=True,
                              scope_name='generator_B2A')
        self.cycle_B = G(inputs=self.generation_A, reuse=True,
                         scope_name='generator_A2B')

        # Identity mappings: each domain through its own target generator.
        self.generation_A_identity = G(inputs=self.input_A_real, reuse=True,
                                       scope_name='generator_B2A')
        self.generation_B_identity = G(inputs=self.input_B_real, reuse=True,
                                       scope_name='generator_A2B')

        # Discriminator scores for the freshly generated fakes.
        self.discrimination_A_fake = D(inputs=self.generation_A, reuse=False,
                                       scope_name='discriminator_A')
        self.discrimination_B_fake = D(inputs=self.generation_B, reuse=False,
                                       scope_name='discriminator_B')

        # Cycle-consistency loss (L1 on both reconstructions).
        self.cycle_loss = (l1_loss(y=self.input_A_real, y_hat=self.cycle_A)
                           + l1_loss(y=self.input_B_real,
                                     y_hat=self.cycle_B))

        # Identity loss (L1 on both identity mappings).
        self.identity_loss = (
            l1_loss(y=self.input_A_real, y_hat=self.generation_A_identity)
            + l1_loss(y=self.input_B_real, y_hat=self.generation_B_identity))

        # Scalar loss weights, fed at run time.
        self.lambda_cycle = tf.placeholder(tf.float32, None,
                                           name='lambda_cycle')
        self.lambda_identity = tf.placeholder(tf.float32, None,
                                              name='lambda_identity')

        # Generators are rewarded when their fakes are classified as real.
        self.generator_loss_A2B = l2_loss(
            y=tf.ones_like(self.discrimination_B_fake),
            y_hat=self.discrimination_B_fake)
        self.generator_loss_B2A = l2_loss(
            y=tf.ones_like(self.discrimination_A_fake),
            y_hat=self.discrimination_A_fake)

        # Combined generator objective.
        self.generator_loss = (self.generator_loss_A2B
                               + self.generator_loss_B2A
                               + self.lambda_cycle * self.cycle_loss
                               + self.lambda_identity * self.identity_loss)

        # Discriminator passes over real samples and externally fed fakes.
        self.discrimination_input_A_real = D(inputs=self.input_A_real,
                                             reuse=True,
                                             scope_name='discriminator_A')
        self.discrimination_input_B_real = D(inputs=self.input_B_real,
                                             reuse=True,
                                             scope_name='discriminator_B')
        self.discrimination_input_A_fake = D(inputs=self.input_A_fake,
                                             reuse=True,
                                             scope_name='discriminator_A')
        self.discrimination_input_B_fake = D(inputs=self.input_B_fake,
                                             reuse=True,
                                             scope_name='discriminator_B')

        # LSGAN discriminator targets: real -> 1, fake -> 0.
        self.discriminator_loss_input_A_real = l2_loss(
            y=tf.ones_like(self.discrimination_input_A_real),
            y_hat=self.discrimination_input_A_real)
        self.discriminator_loss_input_A_fake = l2_loss(
            y=tf.zeros_like(self.discrimination_input_A_fake),
            y_hat=self.discrimination_input_A_fake)
        self.discriminator_loss_A = (
            self.discriminator_loss_input_A_real
            + self.discriminator_loss_input_A_fake) / 2

        self.discriminator_loss_input_B_real = l2_loss(
            y=tf.ones_like(self.discrimination_input_B_real),
            y_hat=self.discrimination_input_B_real)
        self.discriminator_loss_input_B_fake = l2_loss(
            y=tf.zeros_like(self.discrimination_input_B_fake),
            y_hat=self.discrimination_input_B_fake)
        self.discriminator_loss_B = (
            self.discriminator_loss_input_B_real
            + self.discriminator_loss_input_B_fake) / 2

        # Single objective covering both discriminators.
        self.discriminator_loss = (self.discriminator_loss_A
                                   + self.discriminator_loss_B)

        # Partition trainables so the two optimizers touch disjoint sets
        # (membership decided by substring match on the scope name).
        all_vars = tf.trainable_variables()
        self.discriminator_vars = [v for v in all_vars
                                   if 'discriminator' in v.name]
        self.generator_vars = [v for v in all_vars if 'generator' in v.name]

        # Inference-only paths sharing the trained generator weights.
        self.generation_B_test = G(inputs=self.input_A_test, reuse=True,
                                   scope_name='generator_A2B')
        self.generation_A_test = G(inputs=self.input_B_test, reuse=True,
                                   scope_name='generator_B2A')