Exemple #1
0
    def __init__(self, args):
        """Create the generator, encoder and discriminator with their optimizers.

        ``args`` must expose ``z_dim``, ``decay_rate``, ``learning_rate``,
        ``model_name`` and ``batch_size``; each is copied onto the instance.
        All three networks are moved to the GPU with ``.cuda()``.
        """
        # Copy the hyper-parameters this object keeps, in a fixed order.
        for field in ('z_dim', 'decay_rate', 'learning_rate', 'model_name',
                      'batch_size'):
            setattr(self, field, getattr(args, field))

        # Networks (generator and encoder are both sized by z_dim).
        self.Generator = Generator(self.z_dim).cuda()
        self.Encoder = Encoder(self.z_dim).cuda()
        self.Discriminator = Discriminator().cuda()

        # Generator and encoder are updated jointly by one Adam optimizer;
        # the discriminator gets its own with the same settings.
        adam_betas = (0.5, 0.999)
        joint_params = [
            *self.Generator.parameters(),
            *self.Encoder.parameters(),
        ]
        self.optimizer_G_E = torch.optim.Adam(joint_params,
                                              lr=self.learning_rate,
                                              betas=adam_betas)
        self.optimizer_D = torch.optim.Adam(self.Discriminator.parameters(),
                                            lr=self.learning_rate,
                                            betas=adam_betas)

        # Apply the weight initialiser to every sub-module of each network.
        for network in (self.Generator, self.Encoder, self.Discriminator):
            network.apply(weights_init)
Exemple #2
0
def build_models(args, device='cuda'):
    """Instantiate every sub-network on ``device`` and return them by name.

    Args:
        args: namespace providing ``encoder_cnn_model``, ``encoder_hid_size``
            and ``disable_grl``.
        device: target passed to each model's ``.to()``.

    Returns:
        dict with keys ``encodercnn``, ``encoder``, ``crossviewdecodercnn``,
        ``crossviewdecoder``, ``reconstructiondecoder`` and ``viewclassifier``.
    """
    rgb_cnn = CNN(input_shape=RGB_INPUT_SHAPE,
                  model_name=args.encoder_cnn_model).to(device)

    encoder = Encoder(input_shape=rgb_cnn.out_size,
                      encoder_block='convbilstm',
                      hidden_size=args.encoder_hid_size).to(device)

    depth_cnn = CNN(input_shape=DEPTH_INPUT_SHAPE,
                    model_name=args.encoder_cnn_model,
                    input_channel=1).to(device)

    # The cross-view decoder input has three times the depth CNN's first
    # output dimension; the remaining dimensions are unchanged.
    leading_dim, *trailing_dims = depth_cnn.out_size
    cross_view_in_size = torch.Size([leading_dim * 3, *trailing_dims])
    cross_view_decoder = CrossViewDecoder(
        input_shape=cross_view_in_size).to(device)

    reconstruction_decoder = ReconstructionDecoder(
        input_shape=encoder.out_size[1:]).to(device)

    # The classifier consumes the flattened encoder output (all dims but the
    # first); gradient reversal is on unless explicitly disabled.
    flat_encoder_size = reduce(operator.mul, encoder.out_size[1:])
    view_classifier = ViewClassifier(
        input_size=flat_encoder_size,
        num_classes=5,
        reverse=(not args.disable_grl)).to(device)

    return {
        'encodercnn': rgb_cnn,
        'encoder': encoder,
        'crossviewdecodercnn': depth_cnn,
        'crossviewdecoder': cross_view_decoder,
        'reconstructiondecoder': reconstruction_decoder,
        'viewclassifier': view_classifier,
    }
Exemple #3
0
    def __init__(self,
                 in_size: int,
                 ts_size: int = 100,
                 latent_dim: int = 20,
                 lr: float = 0.0005,
                 weight_decay: float = 1e-6,
                 iterations_critic: int = 5,
                 gamma: float = 10,
                 weighted: bool = True,
                 use_gru=False):
        """Store hyper-parameters and build the encoder, generator and critics.

        ``ts_size`` and ``use_gru`` are only forwarded to the sub-networks;
        the remaining arguments are kept as attributes of the same name.
        """
        super(TadGAN, self).__init__()

        self.in_size, self.latent_dim = in_size, latent_dim
        self.lr, self.weight_decay = lr, weight_decay
        self.iterations_critic = iterations_critic
        self.gamma = gamma
        self.weighted = weighted

        # Subset of the settings that is reported to the logger below.
        self.hparams = dict(lr=self.lr,
                            weight_decay=self.weight_decay,
                            iterations_critic=self.iterations_critic,
                            gamma=self.gamma)

        # Sub-networks, created in a fixed order.
        self.encoder = Encoder(in_size,
                               ts_size=ts_size,
                               out_size=self.latent_dim,
                               batch_first=True,
                               use_gru=use_gru)
        self.generator = Generator(use_gru=use_gru)
        self.critic_x = CriticX(in_size=in_size)
        self.critic_z = CriticZ()

        for network in (self.encoder, self.generator, self.critic_x,
                        self.critic_z):
            network.apply(init_weights)

        # Report hyper-parameters only when a logger is attached.
        logger = self.logger
        if logger is not None:
            logger.log_hyperparams(self.hparams)

        # Accumulators, initially empty.
        self.y_hat, self.index, self.critic = [], [], []
Exemple #4
0
    def __init__(self,
                 z_dim=50,
                 hidden_dim=400,
                 enc_kernel1=5,
                 enc_kernel2=5,
                 use_cuda=False):
        """Build the encoder/decoder pair and optionally move the model to GPU.

        ``z_dim`` and ``use_cuda`` are also kept as attributes.
        """
        super(VAE, self).__init__()
        self.z_dim, self.use_cuda = z_dim, use_cuda

        # Encoder and decoder share the latent and hidden sizes; the two
        # kernel arguments go to the encoder only.
        encoder_args = (z_dim, hidden_dim, enc_kernel1, enc_kernel2)
        self.encoder = Encoder(*encoder_args)
        self.decoder = Decoder(z_dim, hidden_dim)

        # cuda() on the whole module moves the encoder and decoder
        # parameters into GPU memory.
        if self.use_cuda:
            self.cuda()
Exemple #5
0
 def __init__(self,
              latents_sizes,
              latents_names,
              img_dim=4096,
              label_dim=114,
              latent_dim=200,
              use_CUDA=False):
     """Build the encoder/decoder pair and remember the configuration.

     Every argument is stored as an attribute of the same name; the model
     is moved to the GPU when ``use_CUDA`` is true.
     """
     super(VAE, self).__init__()
     # Encoder and decoder take the same three sizes.
     sizes = (img_dim, label_dim, latent_dim)
     self.encoder = Encoder(*sizes)
     self.decoder = Decoder(*sizes)
     self.img_dim, self.label_dim, self.latent_dim = sizes
     self.latents_sizes, self.latents_names = latents_sizes, latents_names
     self.use_CUDA = use_CUDA
     if self.use_CUDA:
         self.cuda()
Exemple #6
0
    def __init__(self, alpha=1., beta=1., gamma=0.1):
        """Store the three coefficients and create the encoder and discriminators."""
        super().__init__()
        self.alpha, self.beta, self.gamma = alpha, beta, gamma

        # Sub-networks are assigned inside init_scope().
        with self.init_scope():
            self.encoder = Encoder()
            self.local_disc = LocalDiscriminator()
            self.global_disc = GlobalDiscriminator()
            self.prior_disc = PriorDiscriminator()
Exemple #7
0
    def __init__(self, embedding, output_size=2):
        """Create the decoder's layers and its embedded encoder.

        Args:
            embedding: forwarded as the first argument of ``Encoder``.
            output_size: output width of the ``positions`` linear layer.
        """
        super(Decoder, self).__init__()
        self.output_size = output_size
        hidden = config.hidden_size

        # ``positions`` takes an input of twice the hidden size.
        self.positions = nn.Linear(2 * hidden, output_size)
        self.encoder = Encoder(
            embedding,
            config.batch_size,
            hidden,
            config.num_encoder_layers,
            config.encoder_bidirectional,
        )
        self.softmax = nn.Softmax(dim=1)
        self.qn_linear = nn.Linear(hidden, hidden)
        self.criterion = nn.MSELoss()
Exemple #8
0
    def __init__(self,
                 input_type='image',
                 representation_type='image',
                 output_type=['image'],
                 s_type='classes',
                 input_dim=104,
                 representation_dim=8,
                 output_dim=[1],
                 s_dim=1,
                 problem='privacy',
                 beta=1.0,
                 gamma=1.0,
                 prior_type='Gaussian'):
        """Store the problem configuration and build the encoder and decoder.

        ``self.param`` is ``gamma`` for the ``'privacy'`` problem and ``beta``
        for any other value of ``problem``.
        """
        super(VPAF, self).__init__()

        self.problem = problem
        # Only one of the two weights is kept, chosen by the problem.
        self.param = beta if self.problem != 'privacy' else gamma

        # Type descriptors for input, representation, outputs and s.
        self.input_type, self.representation_type = (input_type,
                                                     representation_type)
        self.output_type, self.output_dim = output_type, output_dim
        self.s_type, self.prior_type = s_type, prior_type

        self.encoder = Encoder(input_type, representation_type, input_dim,
                               representation_dim)
        self.decoder = Decoder(representation_type,
                               output_type,
                               representation_dim,
                               output_dim,
                               s_dim=s_dim)
    def __init__(self, env, tau=0.1, gamma=0.9, epsilon=1.0):
        """Build the encoder, online/target DQNs and optimizer; restore a checkpoint if present.

        Args:
            env: environment object; must provide ``get_obs()`` and, for the
                one-hot encoding, ``all_questions`` and ``held_out_questions``.
            tau: stored as ``self.tau``.
            gamma: stored as ``self.gamma``.
            epsilon: initial ``self.epsilon``; replaced by the checkpointed
                value when ``MODEL_FILE`` exists.
        """
        self.env = env
        self.tau = tau
        self.gamma = gamma
        self.embedding_size = 30
        self.hidden_size = 30
        self.obs_shape = self.env.get_obs().shape
        self.action_shape = 40 // 5

        # The encoder type is selected by the module-level ``args``.
        if args.encoding == "onehot":
            self.encoder = OneHot(
                args.bins,
                self.env.all_questions + self.env.held_out_questions,
                self.hidden_size).to(DEVICE)
        else:
            self.encoder = Encoder(self.embedding_size,
                                   self.hidden_size).to(DEVICE)

        # Both networks are constructed with the same encoder instance.
        self.model = DQN(self.obs_shape, self.action_shape,
                         self.encoder).to(DEVICE)
        self.target_model = DQN(self.obs_shape, self.action_shape,
                                self.encoder).to(DEVICE)
        self.optimizer = torch.optim.Adam(self.model.parameters())
        self.epsilon = epsilon

        # Resume from a previous run when a checkpoint file is present.
        if os.path.exists(MODEL_FILE):
            checkpoint = torch.load(MODEL_FILE)
            self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.target_model.load_state_dict(
                checkpoint['target_model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.epsilon = checkpoint['epsilon']

        # Hard copy model parameters to target model parameters.
        # The previous loop zipped (model, target_model) but named the pair
        # (target_param, param), so it overwrote the online model with the
        # target's weights -- the reverse of the intended direction.
        # NOTE(review): this also replaces target weights restored from the
        # checkpoint above; confirm that is desired when resuming.
        for param, target_param in zip(self.model.parameters(),
                                       self.target_model.parameters()):
            target_param.data.copy_(param.data)
Exemple #10
0
def main(FLAGS):
    """Encode the MNIST test set and save a 2-D t-SNE of the class latents.

    The third output of the encoder is collected for every test image, the
    results are ordered by digit label (0-9), flattened, projected with
    t-SNE and written to ``s_2d.npz`` under the key ``s_2d``.
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # Optionally restore both networks from the checkpoints directory.
    if FLAGS.load_saved:
        encoder_path = os.path.join('checkpoints', FLAGS.encoder_save)
        encoder.load_state_dict(torch.load(encoder_path))
        decoder_path = os.path.join('checkpoints', FLAGS.decoder_save)
        decoder.load_state_dict(torch.load(decoder_path))

    device = 'cuda:0'
    decoder.to(device)
    encoder.to(device)

    tsne = TSNE(2)

    test_set = datasets.MNIST(root='mnist',
                              download=True,
                              train=False,
                              transform=transform_config)
    mnist = DataLoader(test_set)

    # Group the class latents by digit label.
    latents_by_digit = {}
    with torch.no_grad():
        for index, (image, label) in enumerate(mnist):
            digit = int(label)
            print(index, digit)
            _, _, class_latent = encoder(image.to(device))
            if digit not in latents_by_digit:
                latents_by_digit[digit] = []
            latents_by_digit[digit].append(class_latent)

    # Concatenate in label order 0..9, one row per sample.
    ordered_latents = [
        latent for digit in range(10) for latent in latents_by_digit[digit]
    ]
    stacked = torch.cat(ordered_latents)
    flattened = stacked.view(stacked.shape[0], -1).cpu()

    s_2d = tsne.fit_transform(flattened)

    np.savez('s_2d.npz', s_2d=s_2d)
Exemple #11
0
    def __init__(self, env, tau=0.05, gamma=0.9, epsilon=1.0):
        """Build the encoder, the online and target DQNs, and the optimizer.

        Args:
            env: environment object; must provide ``get_obs()``.
            tau: stored as ``self.tau``.
            gamma: stored as ``self.gamma``.
            epsilon: stored as ``self.epsilon``.
        """
        super().__init__()
        self.env = env
        self.tau = tau
        self.gamma = gamma
        self.embedding_size = 64
        self.hidden_size = 64
        self.obs_shape = self.env.get_obs().shape
        self.action_shape = 40 // 5
        self.encoder = Encoder(self.embedding_size,
                               self.hidden_size).to(DEVICE)

        # Both networks are constructed with the same encoder instance.
        self.model = DQN(self.obs_shape, self.action_shape,
                         self.encoder).to(DEVICE)
        self.target_model = DQN(self.obs_shape, self.action_shape,
                                self.encoder).to(DEVICE)
        # The target network is kept in eval mode.
        self.target_model.eval()
        self.optimizer = torch.optim.Adam(self.model.parameters())
        self.epsilon = epsilon

        # Hard copy model parameters to target model parameters.
        # The previous loop zipped (model, target_model) but named the pair
        # (target_param, param), so it overwrote the online model with the
        # target's weights -- the reverse of the intended direction.
        for param, target_param in zip(self.model.parameters(),
                                       self.target_model.parameters()):
            target_param.data.copy_(param.data)
Exemple #12
0
class VPAF(torch.nn.Module):
    """Encoder/decoder model trained on an I(X;Y) bound plus a weighted output-entropy bound.

    ``problem`` selects the mode: ``'privacy'`` uses ``gamma`` as the loss
    weight, any other value is treated as fairness and uses ``beta``.
    """

    def __init__(self,
                 input_type='image',
                 representation_type='image',
                 output_type=['image'],
                 s_type='classes',
                 input_dim=104,
                 representation_dim=8,
                 output_dim=[1],
                 s_dim=1,
                 problem='privacy',
                 beta=1.0,
                 gamma=1.0,
                 prior_type='Gaussian'):
        """Store the configuration and build the encoder and decoder.

        ``output_type`` and ``output_dim`` are parallel sequences (one entry
        per output head). NOTE: their defaults are shared list objects and
        are stored by reference, so they must never be mutated in place.
        """
        super(VPAF, self).__init__()

        self.problem = problem
        # Weight of the output-entropy term in the training loss.
        self.param = gamma if self.problem == 'privacy' else beta
        self.input_type = input_type
        self.representation_type = representation_type
        self.output_type = output_type
        self.output_dim = output_dim
        self.s_type = s_type
        self.prior_type = prior_type

        self.encoder = Encoder(input_type, representation_type, input_dim,
                               representation_dim)
        self.decoder = Decoder(representation_type,
                               output_type,
                               representation_dim,
                               output_dim,
                               s_dim=s_dim)

    def get_IXY_ub(self, y_mean, mode='Gaussian'):
        """Return the I(X;Y) bound, in bits, for a batch of representation means.

        With ``mode == 'Gaussian'`` this is a closed-form KL term summed over
        the batch, using ``self.encoder.y_logvar_theta`` as log-variance; any
        other mode delegates to ``KDE_IXY_estimation``.
        """
        if mode == 'Gaussian':
            logvar = self.encoder.y_logvar_theta
            Dkl = -0.5 * torch.sum(1.0 + logvar - torch.pow(y_mean, 2) -
                                   torch.exp(logvar))
            IXY_ub = Dkl / math.log(2)  # in bits
        else:  # MoG
            IXY_ub = KDE_IXY_estimation(self.encoder.y_logvar_theta, y_mean)
            IXY_ub /= math.log(2)  # in bits

        return IXY_ub

    def get_H_output_given_SY_ub(self, decoder_output, t):
        """Return the summed output-entropy bound, in bits, over all output heads.

        Heads are described by the parallel ``self.output_type`` /
        ``self.output_dim`` sequences and are read from consecutive column
        ranges of ``decoder_output`` and ``t``. A 1-D ``t`` is treated as a
        single column.
        """
        if len(t.shape) == 1:
            t = t.view(-1, 1)

        H_output_given_SY_ub = 0
        dim_start_out = 0  # next unread column of decoder_output
        dim_start_t = 0  # next unread column of t
        reg_start = 0  # next unread entry of decoder.out_logvar_phi

        for output_type_, output_dim_ in zip(self.output_type,
                                             self.output_dim):

            if output_type_ == 'classes':
                # output_dim_ logits against one column of class indices.
                so = dim_start_out
                eo = dim_start_out + output_dim_
                st = dim_start_t
                et = dim_start_t + 1
                CE = torch.nn.functional.cross_entropy(
                    decoder_output[:, so:eo],
                    t[:, st:et].long().view(-1),
                    reduction='sum')
            elif output_type_ == 'binary':
                # One logit against one target column.
                so = dim_start_out
                eo = dim_start_out + 1
                st = dim_start_t
                et = dim_start_t + 1
                CE = torch.nn.functional.binary_cross_entropy_with_logits(
                    decoder_output[:, so:eo].view(-1),
                    t[:, st:et].view(-1),
                    reduction='sum')
            elif output_type_ == 'image':
                # The whole tensors are compared; column offsets are reset.
                eo = et = 0
                CE = torch.nn.functional.binary_cross_entropy(decoder_output,
                                                              t,
                                                              reduction='sum')
            else:  # regression
                # Gaussian negative log-likelihood with the decoder's
                # log-variance parameters for this head.
                so = dim_start_out
                eo = dim_start_out + output_dim_
                st = dim_start_t
                et = dim_start_t + output_dim_
                sr = reg_start
                er = reg_start + output_dim_
                reg_start = er
                logvar = self.decoder.out_logvar_phi[sr:er]
                sq_err = torch.pow(decoder_output[:, so:eo] - t[:, st:et], 2)
                CE = 0.5 * torch.sum(
                    math.log(2 * math.pi) + logvar + sq_err /
                    (torch.exp(logvar) + 1e-10))

            H_output_given_SY_ub += CE / math.log(2)  # in bits

            dim_start_out = eo
            dim_start_t = et

        return H_output_given_SY_ub

    def evaluate_privacy(self, dataloader, device, N, batch_size, figs_dir,
                         verbose):
        """Evaluate both bounds over ``dataloader`` and write diagnostic figures.

        Returns ``(IXY_ub, H_X_given_SY_ub)``, each divided by ``N``.
        On the second batch, image grids are saved (all-image configuration)
        and UMAP embeddings are plotted (image representations).

        ``batch_size`` is kept for interface compatibility only: the batch
        length is taken from the tensors so a short batch reshapes correctly.
        """
        IXY_ub = 0
        H_X_given_SY_ub = 0
        all_images = (self.input_type == 'image'
                      and self.representation_type == 'image'
                      and 'image' in self.output_type)

        with torch.no_grad():
            for it, (x, t, s) in enumerate(dataloader):

                x = x.to(device).float()
                t = t.to(device).float()
                s = s.to(device).float()

                y, y_mean = self.encoder(x)
                output = self.decoder(y, s)

                if all_images and it == 1:
                    # Save the first 96 inputs, representations and outputs.
                    for tensor, fname in ((x, 'x.eps'), (y_mean, 'y.eps'),
                                          (output, 'x_hat.eps')):
                        torchvision.utils.save_image(
                            tensor[:12 * 8],
                            os.path.join(figs_dir, fname),
                            nrow=12)

                IXY_ub += self.get_IXY_ub(y_mean, self.prior_type)
                H_X_given_SY_ub += self.get_H_output_given_SY_ub(output, t)

                if self.representation_type != 'image':
                    continue

                # Use the real batch length: the second batch can be shorter
                # than the nominal batch_size.
                n_batch = x.shape[0]
                classes = self.s_type == 'classes'

                if it == 0 and classes:
                    # Fit supervised reducers on the first batch.
                    reducer_y = umap.UMAP(random_state=0)
                    reducer_y.fit(y_mean.cpu().view(n_batch, -1), y=s.cpu())
                    reducer_x = umap.UMAP(random_state=0)
                    reducer_x.fit(x.cpu().view(n_batch, -1), y=s.cpu())
                if it == 1:
                    y_flat = y_mean.cpu().view(n_batch, -1)
                    x_flat = x.cpu().view(n_batch, -1)
                    if classes:
                        # Project the second batch with the reducers fitted
                        # on the first batch.
                        embedding_s_y = reducer_y.transform(y_flat)
                        embedding_s_x = reducer_x.transform(x_flat)
                    # Fresh unsupervised reducers for the same batch.
                    reducer_y = umap.UMAP(random_state=0)
                    reducer_x = umap.UMAP(random_state=0)
                    embedding_y = reducer_y.fit_transform(y_flat)
                    embedding_x = reducer_x.fit_transform(x_flat)
                    if classes:
                        labels = s.cpu().view(-1).long()
                        plot_embeddings(embedding_y, embedding_s_y, labels,
                                        figs_dir, 'y')
                        plot_embeddings(embedding_x, embedding_s_x, labels,
                                        figs_dir, 'x')
                    else:
                        plot_embeddings(embedding_y, embedding_y, -1,
                                        figs_dir, 'y')
                        # Previously embedding_y was plotted under the 'x'
                        # tag and embedding_x was never used.
                        plot_embeddings(embedding_x, embedding_x, -1,
                                        figs_dir, 'x')

        IXY_ub /= N
        H_X_given_SY_ub /= N
        if verbose:
            print(f'IXY: {IXY_ub.item()}')
            print(f'HX_given_SY: {H_X_given_SY_ub.item()}')
        return IXY_ub, H_X_given_SY_ub

    def evaluate_fairness(self, dataloader, device, N, target_vals,
                          H_T_given_S, verbose):
        """Evaluate the fairness quantities over ``dataloader``.

        Returns ``(IXY_ub, IYT_given_S_lb)`` where the second value is
        ``H_T_given_S`` minus the per-sample output-entropy bound. Accuracy
        from ``metrics.get_accuracy`` is only printed, not returned.
        """
        IXY_ub = 0
        H_T_given_SY_ub = 0
        accuracy = 0

        with torch.no_grad():
            for it, (x, t, s) in enumerate(dataloader):

                x = x.to(device).float()
                t = t.to(device).float()
                s = s.to(device).float()

                y, y_mean = self.encoder(x)
                output = self.decoder(y, s)

                IXY_ub += self.get_IXY_ub(y_mean, self.prior_type)
                H_T_given_SY_ub += self.get_H_output_given_SY_ub(output, t)
                # Weight by batch length so the division by N is a mean.
                accuracy += metrics.get_accuracy(output, t,
                                                 target_vals) * len(x)

        IXY_ub /= N
        H_T_given_SY_ub /= N
        accuracy /= N
        IYT_given_S_lb = H_T_given_S - H_T_given_SY_ub.item()
        if verbose:
            # The raw bound used to be printed unconditionally; it now
            # follows the verbose flag like every other message here.
            print(H_T_given_SY_ub)
            print(f'I(X;Y) = {IXY_ub.item()}')
            print(f'I(Y;T|S) = {IYT_given_S_lb}')
            print(f'Accuracy (network): {accuracy}')
        return IXY_ub, IYT_given_S_lb

    def evaluate(self, dataset, verbose, figs_dir):
        """Evaluate on ``dataset`` with the routine matching ``self.problem``.

        Batches of up to 2048 shuffled samples are used. Returns the pair
        produced by ``evaluate_privacy`` or ``evaluate_fairness``.
        """
        device = 'cuda' if next(self.encoder.parameters()).is_cuda else 'cpu'
        batch_size = min(2048, len(dataset))
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=batch_size,
                                                 shuffle=True)
        if self.problem == 'privacy':
            return self.evaluate_privacy(dataloader, device, len(dataset),
                                         batch_size, figs_dir, verbose)

        # fairness
        H_T_given_S = get_conditional_entropy(dataset.targets,
                                              dataset.hidden,
                                              dataset.target_vals,
                                              dataset.hidden_vals)
        return self.evaluate_fairness(dataloader, device, len(dataset),
                                      dataset.target_vals, H_T_given_S,
                                      verbose)

    def train_step(self, batch_size, learning_rate, dataloader, optimizer,
                   verbose):
        """Run one optimisation pass over ``dataloader``.

        ``batch_size``, ``learning_rate`` and ``verbose`` are accepted for
        interface compatibility and are not used here.
        """
        device = 'cuda' if next(self.encoder.parameters()).is_cuda else 'cpu'

        for x, t, s in progressbar(dataloader):

            x = x.to(device).float()
            t = t.to(device).float()
            s = s.to(device).float()

            optimizer.zero_grad()
            y, y_mean = self.encoder(x)
            output = self.decoder(y, s)
            IXY_ub = self.get_IXY_ub(y_mean, self.prior_type)
            H_output_given_SY_ub = self.get_H_output_given_SY_ub(output, t)
            loss = IXY_ub + self.param * H_output_given_SY_ub

            loss.backward()
            optimizer.step()

    def fit(self,
            dataset_train,
            dataset_val,
            epochs=1000,
            learning_rate=1e-3,
            batch_size=1024,
            eval_rate=15,
            verbose=True,
            logs_dir='../results/logs/',
            figs_dir='../results/images/'):
        """Train with Adam for ``epochs`` epochs, evaluating every ``eval_rate`` epochs.

        The training set is always evaluated; the validation set is evaluated
        as well for any problem other than ``'privacy'``. ``logs_dir`` is
        accepted but not used.
        """
        dataloader = torch.utils.data.DataLoader(dataset_train,
                                                 batch_size=batch_size,
                                                 shuffle=True)

        params = list(self.encoder.parameters()) + list(
            self.decoder.parameters())
        optimizer = torch.optim.Adam(params, lr=learning_rate)

        for epoch in range(epochs):
            print(f'Epoch # {epoch+1}')
            self.train_step(batch_size, learning_rate, dataloader, optimizer,
                            verbose)

            # Evaluate on the last epoch of every eval_rate-sized window.
            if epoch % eval_rate != eval_rate - 1:
                continue

            if verbose:
                print('Evaluating TRAIN')
            self.evaluate(dataset_train, verbose, figs_dir)
            if self.problem != 'privacy':  # fairness
                if verbose:
                    print('Evaluating VALIDATION/TEST')
                self.evaluate(dataset_val, verbose, figs_dir)
Exemple #13
0
def training_procedure(FLAGS):
    """Train the encoder/decoder pair on MNIST and checkpoint it periodically.

    Each iteration back-propagates three losses (style KL, class KL on the
    grouped statistics, and reconstruction MSE) before a single optimizer
    step. Losses are appended to ``FLAGS.log_file`` and written to
    TensorBoard every iteration; both networks are saved to ``checkpoints/``
    every 5 epochs and after the final epoch.

    Args:
        FLAGS: namespace with the model sizes, optimizer settings, loss
            coefficients, epoch range, file names and the ``cuda`` /
            ``load_saved`` switches read below.
    """
    # model definition
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    # variable definition: a reusable input buffer filled once per iteration
    X = torch.FloatTensor(FLAGS.batch_size, 1, FLAGS.image_size,
                          FLAGS.image_size)

    # add option to run on GPU
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()

        X = X.cuda()

    # optimizer definition: one Adam instance drives both networks
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    # training
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    os.makedirs('checkpoints', exist_ok=True)

    # load_saved is false when training is started from 0th iteration;
    # only then is the log file truncated and given its header.
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST dataset...')
    mnist = datasets.MNIST(root='mnist',
                           download=True,
                           train=True,
                           transform=transform_config)
    loader = cycle(
        DataLoader(mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    # Loop invariants.
    iterations_per_epoch = len(mnist) // FLAGS.batch_size
    kl_normalizer = (FLAGS.batch_size * FLAGS.num_channels *
                     FLAGS.image_size * FLAGS.image_size)

    # initialize summary writer; closed in ``finally`` so pending events are
    # flushed even if training is interrupted.
    writer = SummaryWriter()
    try:
        for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
            print('')
            print(
                'Epoch #' + str(epoch) +
                '..........................................................................'
            )

            for iteration in range(iterations_per_epoch):
                # load a mini-batch
                image_batch, labels_batch = next(loader)

                # set zero_grad for the optimizer
                auto_encoder_optimizer.zero_grad()

                X.copy_(image_batch)

                style_mu, style_logvar, class_mu, class_logvar = encoder(
                    Variable(X))
                # The group statistics are computed from ``.data`` tensors.
                grouped_mu, grouped_logvar = accumulate_group_evidence(
                    class_mu.data, class_logvar.data, labels_batch,
                    FLAGS.cuda)

                # kl-divergence error for style latent space
                style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                    -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                     style_logvar.exp()))
                style_kl_divergence_loss /= kl_normalizer
                style_kl_divergence_loss.backward(retain_graph=True)

                # kl-divergence error for class latent space
                class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                    -0.5 * torch.sum(1 + grouped_logvar - grouped_mu.pow(2) -
                                     grouped_logvar.exp()))
                class_kl_divergence_loss /= kl_normalizer
                class_kl_divergence_loss.backward(retain_graph=True)

                # reconstruct samples
                # sampling from group mu and logvar for each image in
                # mini-batch differently makes the decoder consider class
                # latent embeddings as random noise and ignore them
                style_latent_embeddings = reparameterize(training=True,
                                                         mu=style_mu,
                                                         logvar=style_logvar)
                class_latent_embeddings = group_wise_reparameterize(
                    training=True,
                    mu=grouped_mu,
                    logvar=grouped_logvar,
                    labels_batch=labels_batch,
                    cuda=FLAGS.cuda)

                reconstructed_images = decoder(style_latent_embeddings,
                                               class_latent_embeddings)

                reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                    reconstructed_images, Variable(X))
                reconstruction_error.backward()

                auto_encoder_optimizer.step()

                # Read each scalar loss once as a Python float.
                reconstruction_value = reconstruction_error.item()
                style_kl_value = style_kl_divergence_loss.item()
                class_kl_value = class_kl_divergence_loss.item()

                if (iteration + 1) % 50 == 0:
                    print('')
                    print('Epoch #' + str(epoch))
                    print('Iteration #' + str(iteration))

                    print('')
                    print('Reconstruction loss: ' + str(reconstruction_value))
                    print('Style KL-Divergence loss: ' + str(style_kl_value))
                    print('Class KL-Divergence loss: ' + str(class_kl_value))

                # write to log
                with open(FLAGS.log_file, 'a') as log:
                    log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                        epoch, iteration, reconstruction_value,
                        style_kl_value, class_kl_value))

                # write to tensorboard
                global_step = epoch * (iterations_per_epoch + 1) + iteration
                writer.add_scalar('Reconstruction loss', reconstruction_value,
                                  global_step)
                writer.add_scalar('Style KL-Divergence loss', style_kl_value,
                                  global_step)
                writer.add_scalar('Class KL-Divergence loss', class_kl_value,
                                  global_step)

            # save checkpoints after every 5 epochs
            if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
                torch.save(encoder.state_dict(),
                           os.path.join('checkpoints', FLAGS.encoder_save))
                torch.save(decoder.state_dict(),
                           os.path.join('checkpoints', FLAGS.decoder_save))
    finally:
        writer.close()
Exemple #14
0
import os
import sys

# Make the parent directory and every directory below it importable.
sys.path.append("..")
sys.path.extend(
    os.path.join(parent, child)
    for parent, children, _ in os.walk("../")
    for child in children
)

from _config import NNConfig
from networks import CNN, LSTM, Encoder, Decoder

# Create the configuration object and display it.
nnconfig = NNConfig()
nnconfig.show()

# Build the two named layers and display each.
cnn = CNN("cnn_layer1")
lstm = LSTM("lstm_layer1")
cnn.show()
lstm.show()

# The encoder is built from the CNN, the decoder from the LSTM.
encoder = Encoder(cnn)
decoder = Decoder(lstm)
Exemple #15
0
    # sigma_p_inv: (n_dim, n_frames, n_frames), det_p: (d)
    # sigma_q: (batch_size, n_dim, n_frames, n_frames), mu_q: (batch_size, d, nlen)

    l1 = torch.einsum('kij,mkji->mk', sigma_p_inv,
                      sigma_q)  # tr(sigma_p_inv sigma_q)
    l2 = torch.einsum('mki,mki->mk', mu_p - mu_q,
                      torch.einsum('kij,mkj->mki', sigma_p_inv,
                                   mu_p - mu_q))  # <mu_q, sigma_p_inv, mu_q>
    loss = torch.sum(l1 + l2 + torch.log(det_p) - torch.log(det_q), dim=1)
    return loss


# Script entry point: build the networks, optionally restore them, and set
# up the loss. (Only the beginning of this block is visible here.)
if (__name__ == '__main__'):

    # model definition: both networks get the weights_init initialiser
    # applied to every sub-module
    encoder = Encoder()
    encoder.apply(weights_init)

    decoder = Decoder()
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true; the state dicts are read
    # from the 'checkpoints' directory and replace the fresh initialisation
    if LOAD_SAVED:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', ENCODER_SAVE)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', DECODER_SAVE)))

    # loss definition: mean-squared-error criterion
    mse_loss = nn.MSELoss()
Exemple #16
0
class CGAN(ConditionalGenerativeModel):
    def __init__(
        self,
        x_dim,
        y_dim,
        z_dim,
        gen_architecture,
        adversarial_architecture,
        folder="./CGAN",
        append_y_at_every_layer=None,
        is_patchgan=False,
        is_wasserstein=False,
        aux_architecture=None,
    ):
        """Assemble the conditional GAN graph from the given architectures.

        The architectures are handed to the parent constructor, the head of
        the adversarial network is finalised, generator and critic are
        built, the real/fake adversarial inputs are wired up and, when
        ``aux_architecture`` is given, an auxiliary encoder is attached to
        the generator output.

        Args:
            x_dim: Expected generator output shape without the batch axis;
                asserted against the built generator output.
            y_dim: Forwarded to the parent constructor (presumably the shape
                of the condition -- confirm in ConditionalGenerativeModel).
            z_dim: Forwarded to the parent constructor (presumably the shape
                of the noise input -- confirm in ConditionalGenerativeModel).
            gen_architecture: Layer list of the generator. The kwargs dict
                of its last entry gets ``"name"`` set to ``"Output"``.
            adversarial_architecture: Layer list of the adversarial network
                with entries of the form ``[layer, kwargs]``.
            folder: Forwarded to the parent constructor.
            append_y_at_every_layer: Forwarded to the parent constructor.
            is_patchgan: If True the last adversarial layer is only
                validated (1 filter, matching activation). Otherwise a
                flatten layer and a 1-unit dense "Output" layer are appended.
            is_wasserstein: Selects ``tf.identity`` instead of
                ``tf.nn.sigmoid`` as final adversarial activation.
            aux_architecture: Optional layer list of the auxiliary encoder;
                passing one enables the cycle-consistency branch.
        """
        # The flag has to exist before the parent constructor runs, exactly
        # as in the sequential version of this setup.
        self._is_cycle_consistent = aux_architecture is not None
        architectures = [gen_architecture, adversarial_architecture]
        if self._is_cycle_consistent:
            architectures.append(aux_architecture)
        super(CGAN, self).__init__(
            x_dim=x_dim,
            y_dim=y_dim,
            z_dim=z_dim,
            architectures=architectures,
            folder=folder,
            append_y_at_every_layer=append_y_at_every_layer)

        self._gen_architecture = self._architectures[0]
        self._adversarial_architecture = self._architectures[1]
        self._is_patchgan = is_patchgan
        self._is_wasserstein = is_wasserstein
        self._is_feature_matching = False

        ################# Define architecture
        if self._is_patchgan:
            # PatchGAN: the user-supplied last layer is the output layer, so
            # it is checked instead of extended.
            last_layer_kwargs = self._adversarial_architecture[-1][-1]
            f_xy = last_layer_kwargs["filters"]
            assert f_xy == 1, "If is PatchGAN, last layer of adversarial_XY needs 1 filter. Given: {}.".format(
                f_xy)

            a_xy = last_layer_kwargs["activation"]
            if self._is_wasserstein:
                assert a_xy == tf.identity, "If is PatchGAN, last layer of adversarial needs tf.identity. Given: {}.".format(
                    a_xy)
            else:
                assert a_xy == tf.nn.sigmoid, "If is PatchGAN, last layer of adversarial needs tf.nn.sigmoid. Given: {}.".format(
                    a_xy)
        else:
            # Dense head: flatten, then a single output unit whose
            # activation depends on the Wasserstein flag.
            self._adversarial_architecture.append(
                [tf.layers.flatten, {"name": "Flatten"}])
            output_activation = (tf.identity
                                 if self._is_wasserstein else tf.nn.sigmoid)
            self._adversarial_architecture.append([
                logged_dense, {
                    "units": 1,
                    "activation": output_activation,
                    "name": "Output"
                }
            ])
        self._gen_architecture[-1][1]["name"] = "Output"

        self._generator = ConditionalGenerator(self._gen_architecture,
                                               name="Generator")
        self._adversarial = Critic(self._adversarial_architecture,
                                   name="Adversarial")
        self._nets = [self._generator, self._adversarial]

        ################# Connect inputs and networks
        self._output_gen = self._generator.generate_net(
            self._mod_Z_input,
            append_elements_at_every_layer=self._append_at_every_layer,
            tf_trainflag=self._is_training)

        with tf.name_scope("InputsAdversarial"):
            if len(self._x_dim) != 1:
                # Multi-dimensional samples go through the image helper.
                self._input_real = image_condition_concat(
                    inputs=self._X_input, condition=self._Y_input, name="real")
                self._input_fake = image_condition_concat(
                    inputs=self._output_gen,
                    condition=self._Y_input,
                    name="fake")
            else:
                # One-dimensional samples are concatenated with the
                # condition along axis 1.
                self._input_real = tf.concat(
                    axis=1, values=[self._X_input, self._Y_input], name="real")
                self._input_fake = tf.concat(
                    axis=1,
                    values=[self._output_gen, self._Y_input],
                    name="fake")

        self._output_adversarial_real = self._adversarial.generate_net(
            self._input_real, tf_trainflag=self._is_training)
        self._output_adversarial_fake = self._adversarial.generate_net(
            self._input_fake, tf_trainflag=self._is_training)

        assert self._output_gen.get_shape()[1:] == x_dim, (
            "Output of generator is {}, but x_dim is {}.".format(
                self._output_gen.get_shape(), x_dim))

        ################# Auxiliary network for cycle consistency
        if self._is_cycle_consistent:
            self._auxiliary = Encoder(self._architectures[2], name="Auxiliary")
            self._output_auxiliary = self._auxiliary.generate_net(
                self._output_gen, tf_trainflag=self._is_training)
            # The auxiliary output is compared element-wise with the noise
            # input in the loss, so both shapes have to agree.
            aux_shape = self._output_auxiliary.get_shape().as_list()
            mod_z_shape = self._mod_Z_input.get_shape().as_list()
            assert aux_shape == mod_z_shape, (
                "Wrong shape for auxiliary vs. mod Z: {} vs {}.".format(
                    self._output_auxiliary.get_shape(),
                    self._mod_Z_input.get_shape()))
            self._nets.append(self._auxiliary)

        ################# Finalize
        self._init_folders()
        self._verify_init()

        if self._is_patchgan:
            print("PATCHGAN chosen with output: {}.".format(
                self._output_adversarial_real.shape))

    def compile(self,
                loss,
                logged_images=None,
                logged_labels=None,
                learning_rate=0.0005,
                learning_rate_gen=None,
                learning_rate_adversarial=None,
                optimizer=tf.train.RMSPropOptimizer,
                feature_matching=False,
                label_smoothing=1):
        """Define the losses and create the training ops of all networks.

        Args:
            loss: Name of the loss, forwarded to ``_define_loss``. Must be
                "wasserstein" exactly when the model was constructed with
                ``is_wasserstein=True``.
            logged_images: Forwarded to ``_summarise``.
            logged_labels: Forwarded to ``_summarise``.
            learning_rate: Fallback learning rate for both networks.
            learning_rate_gen: Learning rate of the generator (and of the
                auxiliary optimizer); defaults to ``learning_rate``.
            learning_rate_adversarial: Learning rate of the adversarial
                network; defaults to ``learning_rate``.
            optimizer: Optimizer class, instantiated once per network with a
                ``learning_rate`` keyword.
            feature_matching: Forwarded to ``_define_loss``.
            label_smoothing: Forwarded to ``_define_loss``.

        Raises:
            ValueError: If ``loss`` and the constructor's ``is_wasserstein``
                flag disagree.
        """
        wants_wasserstein = loss == "wasserstein"
        if self._is_wasserstein and not wants_wasserstein:
            raise ValueError(
                "If is_wasserstein is true in Constructor, loss needs to be wasserstein."
            )
        if wants_wasserstein and not self._is_wasserstein:
            raise ValueError(
                "If loss is wasserstein, is_wasserstein needs to be true in constructor."
            )

        # Network-specific learning rates fall back to the shared one.
        lr_gen = learning_rate if learning_rate_gen is None else learning_rate_gen
        lr_adversarial = (learning_rate if learning_rate_adversarial is None
                          else learning_rate_adversarial)

        self._define_loss(loss, feature_matching, label_smoothing)
        with tf.name_scope("Optimizer"):
            gen_optimizer = optimizer(learning_rate=lr_gen)
            self._gen_optimizer = gen_optimizer.minimize(
                self._gen_loss,
                var_list=self._get_vars("Generator"),
                name="Generator")

            adversarial_optimizer = optimizer(learning_rate=lr_adversarial)
            self._adversarial_optimizer = adversarial_optimizer.minimize(
                self._adversarial_loss,
                var_list=self._get_vars("Adversarial"),
                name="Adversarial")

            if self._is_cycle_consistent:
                # The auxiliary loss updates generator and auxiliary
                # variables together.
                aux_optimizer = optimizer(learning_rate=lr_gen)
                aux_variables = (self._get_vars(scope="Generator") +
                                 self._get_vars(scope="Auxiliary"))
                self._aux_optimizer = aux_optimizer.minimize(
                    self._aux_loss, var_list=aux_variables, name="Auxiliary")

            # Kept for the gradient statistics in _check_tf_variables.
            self._gen_grads_and_vars = gen_optimizer.compute_gradients(
                self._gen_loss)
            self._adversarial_grads_and_vars = adversarial_optimizer.compute_gradients(
                self._adversarial_loss)

        self._summarise(logged_images=logged_images,
                        logged_labels=logged_labels)

    def _define_loss(self, loss, feature_matching, label_smoothing):
        """Create the generator, adversarial and (optional) auxiliary loss.

        Args:
            loss: One of "cross-entropy", "L1", "L2", "wasserstein", "KL".
            feature_matching: If truthy, the generator loss is replaced by
                the mean squared difference between the adversarial outputs
                taken at ``return_idx=-2`` for real and generated input.
            label_smoothing: Factor multiplied onto the all-ones target.

        Raises:
            ValueError: If ``loss`` is not one of the supported names.
        """
        possible_losses = ["cross-entropy", "L1", "L2", "wasserstein", "KL"]
        eps = 1e-7
        out_real = self._output_adversarial_real
        out_fake = self._output_adversarial_fake

        def smoothed_ones(tensor):
            # Target "real" scaled by the label smoothing factor.
            return tf.ones_like(tensor) * label_smoothing

        def probability_to_logit(probability):
            # log(p / (1 - p)), with eps keeping both the division and the
            # logarithm away from zero.
            return tf.math.log(probability / (1 + eps - probability) + eps)

        def adversarial_cross_entropy():
            # Real logits against smoothed ones, fake logits against zeros.
            real_term = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=smoothed_ones(self._logits_real),
                logits=self._logits_real)
            fake_term = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.zeros_like(self._logits_fake),
                logits=self._logits_fake)
            return tf.reduce_mean(real_term + fake_term)

        if loss in ("cross-entropy", "KL"):
            # Both variants work on logits recovered from the adversarial
            # outputs and share the adversarial loss; only the generator
            # objective differs.
            self._logits_real = probability_to_logit(out_real)
            self._logits_fake = probability_to_logit(out_fake)
            if loss == "cross-entropy":
                self._gen_loss = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=smoothed_ones(self._logits_fake),
                        logits=self._logits_fake))
            else:
                self._gen_loss = -tf.reduce_mean(self._logits_fake)
            self._adversarial_loss = adversarial_cross_entropy()

        elif loss in ("L1", "L2"):
            distance = tf.abs if loss == "L1" else tf.square
            self._gen_loss = tf.reduce_mean(
                distance(out_fake - smoothed_ones(out_fake)))
            self._adversarial_loss = (tf.reduce_mean(
                distance(out_real - smoothed_ones(out_real)) +
                distance(out_fake))) / 2.0

        elif loss == "wasserstein":
            self._gen_loss = -tf.reduce_mean(out_fake)
            self._adversarial_loss = (
                -(tf.reduce_mean(out_real) - tf.reduce_mean(out_fake)) +
                10 * self._define_gradient_penalty())

        else:
            raise ValueError(
                "Loss not implemented. Choose from {}. Given: {}.".format(
                    possible_losses, loss))

        if feature_matching:
            # Overrides the generator loss chosen above.
            self._is_feature_matching = True
            features_real = self._adversarial.generate_net(
                self._input_real,
                tf_trainflag=self._is_training,
                return_idx=-2)
            features_fake = self._adversarial.generate_net(
                self._input_fake,
                tf_trainflag=self._is_training,
                return_idx=-2)
            self._gen_loss = tf.reduce_mean(
                tf.square(features_real - features_fake))

        if self._is_cycle_consistent:
            # Mean absolute difference between the noise input and its
            # reconstruction by the auxiliary network.
            self._aux_loss = tf.reduce_mean(
                tf.abs(self._mod_Z_input - self._output_auxiliary))
            self._gen_loss = self._gen_loss + self._aux_loss

        with tf.name_scope("Loss"):
            tf.summary.scalar("Generator_Loss", self._gen_loss)
            tf.summary.scalar("Adversarial_Loss", self._adversarial_loss)
            if self._is_cycle_consistent:
                tf.summary.scalar("Auxiliary_Loss", self._aux_loss)

    def _define_gradient_penalty(self):
        """Build the gradient penalty term of the Wasserstein loss.

        Points are sampled between the real and the generated adversarial
        inputs, the adversarial output is differentiated with respect to
        them, and the squared deviation of each sample's gradient norm from
        1 is averaged over the batch. The result is stored in
        ``self._gradient_penalty`` and logged as a scalar summary.

        Returns:
            The scalar gradient penalty tensor (unweighted; the caller
            applies the penalty coefficient).
        """
        alpha = tf.random_uniform(shape=tf.shape(self._input_real),
                                  minval=0.,
                                  maxval=1.)
        differences = self._input_fake - self._input_real
        interpolates = self._input_real + (alpha * differences)
        gradients = tf.gradients(self._adversarial.generate_net(interpolates),
                                 [interpolates])[0]
        # The norm has to be taken per sample: reduce over every axis except
        # the leading batch axis. Summing over all axes would collapse the
        # whole batch into a single norm and turn the mean below into a
        # no-op. Relies on the static rank of the input being known.
        per_sample_axes = list(range(1, gradients.get_shape().ndims))
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), axis=per_sample_axes))
        with tf.name_scope("Loss"):
            self._gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
            tf.summary.scalar("Gradient_penalty", self._gradient_penalty)
        return self._gradient_penalty

    def train(self,
              x_train,
              y_train,
              x_test=None,
              y_test=None,
              epochs=100,
              batch_size=64,
              gen_steps=1,
              adversarial_steps=5,
              log_step=3,
              batch_log_step=None,
              steps=None,
              gpu_options=None):
        """Run the alternating adversarial / generator training loop.

        Args:
            x_train: Training samples; ``len(x_train)`` defines how many
                examples make up one epoch.
            y_train: Conditions belonging to ``x_train``.
            x_test: Optional test samples, forwarded to
                ``_set_up_test_train_sample``.
            y_test: Optional test conditions, forwarded to the same helper.
            epochs: Number of epochs.
            batch_size: Batch size requested from the training set; also the
                amount by which the per-epoch example counter advances after
                every ``_optimize`` call.
            gen_steps: Generator updates per ``_optimize`` call.
            adversarial_steps: Adversarial updates per ``_optimize`` call.
            log_step: Forwarded to ``_set_up_training``.
            batch_log_step: If not None, ``_log`` is additionally called
                every ``batch_log_step`` batches within an epoch.
            steps: If not None, overrides the step counts: ``gen_steps``
                becomes 1 and ``adversarial_steps`` becomes ``steps``.
            gpu_options: Forwarded to ``_set_up_training``.

        Raises:
            GeneratorExit: If the adversarial or generator batch loss is NaN.
        """
        if steps is not None:
            gen_steps = 1
            adversarial_steps = steps
        self._set_up_training(log_step=log_step, gpu_options=gpu_options)
        self._set_up_test_train_sample(x_train, y_train, x_test, y_test)
        self._log_results(epoch=0, epoch_time=0)
        nr_batches = np.floor(len(x_train) / batch_size)

        # Counters consumed by _check_tf_variables.
        self._dominating_adversarial = 0
        self._gen_out_zero = 0
        for epoch in range(epochs):
            adversarial_loss_epoch = 0
            gen_loss_epoch = 0
            aux_loss_epoch = 0
            # time.clock() no longer exists on Python >= 3.8.
            start = time.perf_counter()
            trained_examples = 0
            ii = 0

            while trained_examples < len(x_train):
                adversarial_loss_batch, gen_loss_batch, aux_loss_batch = self._optimize(
                    self._trainset, batch_size, adversarial_steps, gen_steps)
                trained_examples += batch_size

                if np.isnan(adversarial_loss_batch) or np.isnan(
                        gen_loss_batch):
                    print("adversarialLoss / GenLoss: ",
                          adversarial_loss_batch, gen_loss_batch)
                    # The last batch is only available if it was stored on
                    # the instance; without it the output dump is skipped so
                    # that the NaN error below is still the one reported.
                    batch_x = getattr(self, "current_batch_x", None)
                    batch_y = getattr(self, "current_batch_y", None)
                    if batch_x is not None and batch_y is not None:
                        oar, oaf = self._sess.run(
                            [
                                self._output_adversarial_real,
                                self._output_adversarial_fake
                            ],
                            feed_dict={
                                self._X_input: batch_x,
                                self._Y_input: batch_y,
                                self._Z_input: self._Z_noise,
                                self._is_training: True
                            })
                        print(oar)
                        print(oaf)
                        print(np.max(oar))
                        print(np.max(oaf))

                    # self._check_tf_variables(ii, nr_batches)
                    raise GeneratorExit("Nan found.")

                if (batch_log_step is not None) and (ii % batch_log_step == 0):
                    batch_train_time = (time.perf_counter() - start) / 60
                    self._log(int(epoch * nr_batches + ii), batch_train_time)

                adversarial_loss_epoch += adversarial_loss_batch
                gen_loss_epoch += gen_loss_batch
                aux_loss_epoch += aux_loss_batch
                ii += 1

            # Elapsed time of the epoch in minutes.
            epoch_train_time = (time.perf_counter() - start) / 60
            adversarial_loss_epoch = np.round(adversarial_loss_epoch, 2)
            gen_loss_epoch = np.round(gen_loss_epoch, 2)

            print("Epoch {}: Adversarial: {}.".format(epoch + 1,
                                                      adversarial_loss_epoch))
            print("\t\t\tGenerator: {}.".format(gen_loss_epoch))
            print("\t\t\tEncoder: {}.".format(aux_loss_epoch))

            if self._log_step is not None:
                self._log(epoch + 1, epoch_train_time)

            # self._check_tf_variables(epoch, epochs)

    def _optimize(self, dataset, batch_size, adversarial_steps, gen_steps):
        """Run one round of adversarial updates followed by generator updates.

        Args:
            dataset: Object providing ``get_next_batch(batch_size)``, which
                returns a pair ``(batch_x, batch_y)``.
            batch_size: Size of every requested batch.
            adversarial_steps: Number of adversarial updates, each on a newly
                drawn batch. Must be at least 1, because the generator
                updates reuse the last drawn batch.
            gen_steps: Number of generator (and, if cycle consistent,
                auxiliary) updates. Must be at least 1, otherwise no
                generator loss exists to return.

        Returns:
            Tuple ``(adversarial_loss, gen_loss, aux_loss)`` holding the
            losses of the last update of each kind; ``aux_loss`` is 0 when
            the model is not cycle consistent.
        """
        for _ in range(adversarial_steps):
            current_batch_x, current_batch_y = dataset.get_next_batch(
                batch_size)
            # Keep the latest batch on the instance: the NaN diagnostics in
            # train() read it back through these attributes.
            self.current_batch_x = current_batch_x
            self.current_batch_y = current_batch_y
            self._Z_noise = self.sample_noise(n=len(current_batch_x))
            _, adversarial_loss_batch = self._sess.run(
                [self._adversarial_optimizer, self._adversarial_loss],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Y_input: current_batch_y,
                    self._Z_input: self._Z_noise,
                    self._is_training: True
                })

        aux_loss_batch = 0
        for _ in range(gen_steps):
            Z_noise = self._generator.sample_noise(n=len(current_batch_x))
            if not self._is_feature_matching:
                _, gen_loss_batch = self._sess.run(
                    [self._gen_optimizer, self._gen_loss],
                    feed_dict={
                        self._Z_input: Z_noise,
                        self._Y_input: current_batch_y,
                        self._is_training: True
                    })
            else:
                # The feature matching loss compares real and generated
                # samples, so the real batch is fed as well; the noise of
                # the last adversarial step is reused here.
                _, gen_loss_batch = self._sess.run(
                    [self._gen_optimizer, self._gen_loss],
                    feed_dict={
                        self._X_input: current_batch_x,
                        self._Y_input: current_batch_y,
                        self._Z_input: self._Z_noise,
                        self._is_training: True
                    })
            if self._is_cycle_consistent:
                _, aux_loss_batch = self._sess.run(
                    [self._aux_optimizer, self._aux_loss],
                    feed_dict={
                        self._Z_input: Z_noise,
                        self._Y_input: current_batch_y,
                        self._is_training: True
                    })

        return adversarial_loss_batch, gen_loss_batch, aux_loss_batch

    def predict(self, inpt_x, inpt_y):
        """Return the adversarial network's prediction for ``(inpt_x, inpt_y)``.

        The samples and conditions are first combined into the adversarial
        network's "real" input tensor, which is then evaluated and handed to
        the adversarial network's own ``predict``.

        Args:
            inpt_x: Values fed into the sample placeholder.
            inpt_y: Values fed into the condition placeholder.

        Returns:
            Whatever ``self._adversarial.predict`` returns for the combined
            input.
        """
        feed = {
            self._X_input: inpt_x,
            self._Y_input: inpt_y,
            self._is_training: True
        }
        adversarial_input = self._sess.run(self._input_real, feed_dict=feed)
        return self._adversarial.predict(adversarial_input, self._sess)

    def _check_tf_variables(self, batch_nr, nr_batches):
        """Print weight, gradient and output statistics and stop degenerate runs.

        All tensors are evaluated on the stored test set with the training
        flag switched off. Besides printing, two counters on the instance
        are maintained: consecutive calls with a dominating adversarial
        network and consecutive calls with a near-zero generator output.

        Args:
            batch_nr: Current position, only used in the printed header.
            nr_batches: Total count, only used in the printed header.

        Raises:
            GeneratorExit: After 5 consecutive calls in which the mean
                adversarial output is above 0.99 on real and below 0.01 on
                generated input, or after 50 consecutive calls in which the
                largest generator output is below 0.05.
        """
        Z_noise = self._generator.sample_noise(n=len(self._x_test))
        feed_dict = {
            self._X_input: self._x_test,
            self._Y_input: self._y_test,
            self._Z_input: Z_noise,
            self._is_training: False
        }

        # A gradient entry is None for a variable the loss does not depend
        # on; such entries cannot be evaluated and are skipped. All
        # remaining gradients of one loss are fetched in a single run.
        gen_grad_tensors = [
            gen_gv[0] for gen_gv in self._gen_grads_and_vars
            if gen_gv[0] is not None
        ]
        adversarial_grad_tensors = [
            adversarial_gv[0]
            for adversarial_gv in self._adversarial_grads_and_vars
            if adversarial_gv[0] is not None
        ]
        gen_grads = self._sess.run(gen_grad_tensors, feed_dict=feed_dict)
        adversarial_grads = self._sess.run(adversarial_grad_tensors,
                                           feed_dict=feed_dict)

        gen_grads_maxis = [np.max(gv) for gv in gen_grads]
        gen_grads_means = [np.mean(gv) for gv in gen_grads]
        gen_grads_minis = [np.min(gv) for gv in gen_grads]
        adversarial_grads_maxis = [np.max(dv) for dv in adversarial_grads]
        adversarial_grads_means = [np.mean(dv) for dv in adversarial_grads]
        adversarial_grads_minis = [np.min(dv) for dv in adversarial_grads]

        real_logits, fake_logits, gen_out = self._sess.run(
            [
                self._output_adversarial_real, self._output_adversarial_fake,
                self._output_gen
            ],
            feed_dict=feed_dict)
        real_logits = np.mean(real_logits)
        fake_logits = np.mean(fake_logits)

        # Plain lists on purpose: the parameter arrays have different
        # shapes, and wrapping them in np.array would build a ragged array,
        # which recent NumPy versions reject.
        gen_varsis = [
            x.eval(session=self._sess)
            for x in self._generator.get_network_params()
        ]
        adversarial_varsis = [
            x.eval(session=self._sess)
            for x in self._adversarial.get_network_params()
        ]
        gen_maxis = np.array([np.max(x) for x in gen_varsis])
        adversarial_maxis = np.array([np.max(x) for x in adversarial_varsis])
        gen_means = np.array([np.mean(x) for x in gen_varsis])
        adversarial_means = np.array([np.mean(x) for x in adversarial_varsis])
        gen_minis = np.array([np.min(x) for x in gen_varsis])
        adversarial_minis = np.array([np.min(x) for x in adversarial_varsis])

        print(batch_nr, "/", nr_batches, ":")
        print("adversarialReal / adversarialFake: ", real_logits, fake_logits)
        print("GenWeight Max / Mean / Min: ", np.max(gen_maxis),
              np.mean(gen_means), np.min(gen_minis))
        print("GenGrads Max / Mean / Min: ", np.max(gen_grads_maxis),
              np.mean(gen_grads_means), np.min(gen_grads_minis))
        print("adversarialWeight Max / Mean / Min: ",
              np.max(adversarial_maxis), np.mean(adversarial_means),
              np.min(adversarial_minis))
        print("adversarialGrads Max / Mean / Min: ",
              np.max(adversarial_grads_maxis),
              np.mean(adversarial_grads_means),
              np.min(adversarial_grads_minis))
        print("GenOut Max / Mean / Min: ", np.max(gen_out), np.mean(gen_out),
              np.min(gen_out))
        print("\n")

        # Abort once the adversarial network has separated real from
        # generated input almost perfectly five times in a row.
        if real_logits > 0.99 and fake_logits < 0.01:
            self._dominating_adversarial += 1
            if self._dominating_adversarial == 5:
                raise GeneratorExit("Dominating adversarialriminator!")
        else:
            self._dominating_adversarial = 0

        # Abort once the generator output has stayed below 0.05 for fifty
        # calls in a row.
        print(np.max(gen_out))
        print(np.max(gen_out) < 0.05)
        if np.max(gen_out) < 0.05:
            self._gen_out_zero += 1
            print(self._gen_out_zero)
            if self._gen_out_zero == 50:
                raise GeneratorExit("Generator outputs zeros")
        else:
            self._gen_out_zero = 0
        print(self._gen_out_zero)
Exemple #17
0
def training_procedure(FLAGS):
    """
    model definition
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        raise Exception('This is not implemented')
        encoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    """
    variable definition
    """

    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)

    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    """
    loss definitions
    """
    cross_entropy_loss = nn.CrossEntropyLoss()
    adversarial_loss = nn.BCELoss()

    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()
        adversarial_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        style_latent_space = style_latent_space.cuda()

    """
    optimizer and scheduler definition
    """
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    reverse_cycle_optimizer = optim.Adam(
        list(encoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    generator_optimizer = optim.Adam(
        list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    discriminator_optimizer = optim.Adam(
        list(discriminator.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    # divide the learning rate by a factor of 10 after 80 epochs
    auto_encoder_scheduler = optim.lr_scheduler.StepLR(auto_encoder_optimizer, step_size=80, gamma=0.1)
    reverse_cycle_scheduler = optim.lr_scheduler.StepLR(reverse_cycle_optimizer, step_size=80, gamma=0.1)
    generator_scheduler = optim.lr_scheduler.StepLR(generator_optimizer, step_size=80, gamma=0.1)
    discriminator_scheduler = optim.lr_scheduler.StepLR(discriminator_optimizer, step_size=80, gamma=0.1)

    # Used later to define discriminator ground truths
    Tensor = torch.cuda.FloatTensor if FLAGS.cuda else torch.FloatTensor

    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            headers = ['Epoch', 'Iteration', 'Reconstruction_loss', 'KL_divergence_loss', 'Reverse_cycle_loss']

            if FLAGS.forward_gan:
              headers.extend(['Generator_forward_loss', 'Discriminator_forward_loss'])

            if FLAGS.reverse_gan:
              headers.extend(['Generator_reverse_loss', 'Discriminator_reverse_loss'])

            log.write('\t'.join(headers) + '\n')

    # load data set and create data loader instance
    print('Loading CIFAR paired dataset...')
    paired_cifar = CIFAR_Paired(root='cifar', download=True, train=True, transform=transform_config)
    loader = cycle(DataLoader(paired_cifar, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True))

    # Save a batch of images to use for visualization
    image_sample_1, image_sample_2, _ = next(loader)
    image_sample_3, _, _ = next(loader)

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        # update the learning rate scheduler
        auto_encoder_scheduler.step()
        reverse_cycle_scheduler.step()
        generator_scheduler.step()
        discriminator_scheduler.step()

        for iteration in range(int(len(paired_cifar) / FLAGS.batch_size)):
            # Adversarial ground truths
            valid = Variable(Tensor(FLAGS.batch_size, 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(FLAGS.batch_size, 1).fill_(0.0), requires_grad=False)

            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(Variable(X_1))
            style_latent_space_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

            kl_divergence_loss_1 = FLAGS.kl_divergence_coef * (
                - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
            )
            kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_1.backward(retain_graph=True)

            style_mu_2, style_logvar_2, class_latent_space_2 = encoder(Variable(X_2))
            style_latent_space_2 = reparameterize(training=True, mu=style_mu_2, logvar=style_logvar_2)

            kl_divergence_loss_2 = FLAGS.kl_divergence_coef * (
                - 0.5 * torch.sum(1 + style_logvar_2 - style_mu_2.pow(2) - style_logvar_2.exp())
            )
            kl_divergence_loss_2 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_2.backward(retain_graph=True)

            reconstructed_X_1 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_2 = decoder(style_latent_space_2, class_latent_space_1)

            reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

            reconstruction_error = (reconstruction_error_1 + reconstruction_error_2) / FLAGS.reconstruction_coef
            kl_divergence_error = (kl_divergence_loss_1 + kl_divergence_loss_2) / FLAGS.kl_divergence_coef

            auto_encoder_optimizer.step()

            # A-1. Discriminator training during forward cycle
            if FLAGS.forward_gan:
              # Training generator
              generator_optimizer.zero_grad()

              g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid)
              g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid)

              gen_f_loss = (g_loss_1 + g_loss_2) / 2.0
              gen_f_loss.backward()

              generator_optimizer.step()

              # Training discriminator
              discriminator_optimizer.zero_grad()

              real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid)
              real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid)
              fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake)
              fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake)

              dis_f_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0
              dis_f_loss.backward()

              discriminator_optimizer.step()

            # B. reverse cycle
            image_batch_1, _, __ = next(loader)
            image_batch_2, _, __ = next(loader)

            reverse_cycle_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_latent_space.normal_(0., 1.)

            _, __, class_latent_space_1 = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))

            reconstructed_X_1 = decoder(Variable(style_latent_space), class_latent_space_1.detach())
            reconstructed_X_2 = decoder(Variable(style_latent_space), class_latent_space_2.detach())

            style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

            style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
            style_latent_space_2 = reparameterize(training=False, mu=style_mu_2, logvar=style_logvar_2)

            reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(style_latent_space_1, style_latent_space_2)
            reverse_cycle_loss.backward()
            reverse_cycle_loss /= FLAGS.reverse_cycle_coef

            reverse_cycle_optimizer.step()

            # B-1. Discriminator training during reverse cycle
            if FLAGS.reverse_gan:
              # Training generator
              generator_optimizer.zero_grad()

              g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid)
              g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid)

              gen_r_loss = (g_loss_1 + g_loss_2) / 2.0
              gen_r_loss.backward()

              generator_optimizer.step()

              # Training discriminator
              discriminator_optimizer.zero_grad()

              real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid)
              real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid)
              fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake)
              fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake)

              dis_r_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0
              dis_r_loss.backward()

              discriminator_optimizer.step()

            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_error.data.storage().tolist()[0]))
                print('KL-Divergence loss: ' + str(kl_divergence_error.data.storage().tolist()[0]))
                print('Reverse cycle loss: ' + str(reverse_cycle_loss.data.storage().tolist()[0]))

                if FLAGS.forward_gan:
                  print('Generator F loss: ' + str(gen_f_loss.data.storage().tolist()[0]))
                  print('Discriminator F loss: ' + str(dis_f_loss.data.storage().tolist()[0]))

                if FLAGS.reverse_gan:
                  print('Generator R loss: ' + str(gen_r_loss.data.storage().tolist()[0]))
                  print('Discriminator R loss: ' + str(dis_r_loss.data.storage().tolist()[0]))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                row = []

                row.append(epoch)
                row.append(iteration)
                row.append(reconstruction_error.data.storage().tolist()[0])
                row.append(kl_divergence_error.data.storage().tolist()[0])
                row.append(reverse_cycle_loss.data.storage().tolist()[0])

                if FLAGS.forward_gan:
                  row.append(gen_f_loss.data.storage().tolist()[0])
                  row.append(dis_f_loss.data.storage().tolist()[0])

                if FLAGS.reverse_gan:
                  row.append(gen_r_loss.data.storage().tolist()[0])
                  row.append(dis_r_loss.data.storage().tolist()[0])

                row = [str(x) for x in row]
                log.write('\t'.join(row) + '\n')

            # write to tensorboard
            writer.add_scalar('Reconstruction loss', reconstruction_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('KL-Divergence loss', kl_divergence_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('Reverse cycle loss', reverse_cycle_loss.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

            if FLAGS.forward_gan:
              writer.add_scalar('Generator F loss', gen_f_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
              writer.add_scalar('Discriminator F loss', dis_f_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

            if FLAGS.reverse_gan:
              writer.add_scalar('Generator R loss', gen_r_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
              writer.add_scalar('Discriminator R loss', dis_r_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))

            """
            save reconstructed images and style swapped image generations to check progress
            """

            X_1.copy_(image_sample_1)
            X_2.copy_(image_sample_2)
            X_3.copy_(image_sample_3)

            style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))
            style_mu_3, style_logvar_3, _ = encoder(Variable(X_3))

            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)
            style_latent_space_3 = reparameterize(training=False, mu=style_mu_3, logvar=style_logvar_3)

            reconstructed_X_1_2 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_3_2 = decoder(style_latent_space_3, class_latent_space_2)

            # save input image batch
            image_batch = np.transpose(X_1.cpu().numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              image_batch = np.concatenate((image_batch, image_batch, image_batch), axis=3)
            imshow_grid(image_batch, name=str(epoch) + '_original', save=True)

            # save reconstructed batch
            reconstructed_x = np.transpose(reconstructed_X_1_2.cpu().data.numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              reconstructed_x = np.concatenate((reconstructed_x, reconstructed_x, reconstructed_x), axis=3)
            imshow_grid(reconstructed_x, name=str(epoch) + '_target', save=True)

            style_batch = np.transpose(X_3.cpu().numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              style_batch = np.concatenate((style_batch, style_batch, style_batch), axis=3)
            imshow_grid(style_batch, name=str(epoch) + '_style', save=True)

            # save style swapped reconstructed batch
            reconstructed_style = np.transpose(reconstructed_X_3_2.cpu().data.numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              reconstructed_style = np.concatenate((reconstructed_style, reconstructed_style, reconstructed_style), axis=3)
            imshow_grid(reconstructed_style, name=str(epoch) + '_style_target', save=True)
Exemple #18
0
class VAEGAN(GenerativeModel):
    """VAE-GAN assembled from an encoder, a generator and a discriminator.

    The encoder has two heads ("Mean" and "Std"), the generator decodes either
    the re-parameterised encoding or the raw latent input, and the
    discriminator scores real inputs and both kinds of generated samples.

    ``_X_input``, ``_Z_input``, ``_is_training``, ``_sess`` and ``_trainset``
    are used but not created in this class -- presumably the base class and
    its set-up helpers provide them (confirm against ``GenerativeModel``).
    """

    def __init__(self,
                 x_dim,
                 z_dim,
                 enc_architecture,
                 gen_architecture,
                 disc_architecture,
                 folder="./VAEGAN"):
        """Build the networks and connect them to the input tensors.

        Args:
            x_dim: Shape of one data sample; the generator output (without
                the batch axis) is asserted to equal it.
            z_dim: Number of units of the encoder's "Mean" and "Std" heads.
            enc_architecture: Layer specification of the encoder trunk.
            gen_architecture: Layer specification of the generator; the
                ``"name"`` of its last layer is overwritten with "Output".
            disc_architecture: Layer specification of the discriminator; a
                flatten layer and a 1-unit sigmoid dense layer are appended.
            folder: Forwarded to the base-class constructor.

        Raises:
            AssertionError: If the generator output shape differs from
                ``x_dim``.
        """
        super(VAEGAN, self).__init__(
            x_dim, z_dim,
            [enc_architecture, gen_architecture, disc_architecture], folder)

        self._enc_architecture = self._architectures[0]
        self._gen_architecture = self._architectures[1]
        self._disc_architecture = self._architectures[2]

        ################# Define architecture
        last_layer_mean = [
            logged_dense, {
                "units": z_dim,
                "activation": tf.identity,
                "name": "Mean"
            }
        ]
        self._encoder_mean = Encoder(self._enc_architecture +
                                     [last_layer_mean],
                                     name="Encoder")
        last_layer_std = [
            logged_dense, {
                "units": z_dim,
                "activation": tf.identity,
                "name": "Std"
            }
        ]
        # NOTE(review): both heads are created with name="Encoder"; whether
        # they share trunk variables depends on how Encoder scopes its
        # layers -- confirm.
        self._encoder_std = Encoder(self._enc_architecture + [last_layer_std],
                                    name="Encoder")

        self._gen_architecture[-1][1]["name"] = "Output"
        self._generator = Generator(self._gen_architecture, name="Generator")

        self._disc_architecture.append(
            [tf.layers.flatten, {
                "name": "Flatten"
            }])
        self._disc_architecture.append([
            logged_dense, {
                "units": 1,
                "activation": tf.nn.sigmoid,
                "name": "Output"
            }
        ])
        self._discriminator = Discriminator(self._disc_architecture,
                                            name="Discriminator")

        self._nets = [self._encoder_mean, self._generator, self._discriminator]

        ################# Connect inputs and networks
        self._mean_layer = self._encoder_mean.generate_net(self._X_input)
        self._std_layer = self._encoder_std.generate_net(self._X_input)

        # Re-parameterisation: the "Std" head is used as a log-variance, so
        # exp(0.5 * head) is the standard deviation scaling the noise input.
        self._output_enc_with_noise = self._mean_layer + tf.exp(
            0.5 * self._std_layer) * self._Z_input

        self._output_gen = self._generator.generate_net(
            self._output_enc_with_noise)
        self._output_gen_from_encoding = self._generator.generate_net(
            self._Z_input)

        assert self._output_gen.get_shape()[1:] == x_dim, (
            "Generator output must have shape of x_dim. Given: {}. Expected: {}."
            .format(self._output_gen.get_shape(), x_dim))

        self._output_disc_real = self._discriminator.generate_net(
            self._X_input)
        self._output_disc_fake_from_real = self._discriminator.generate_net(
            self._output_gen)
        self._output_disc_fake_from_latent = self._discriminator.generate_net(
            self._output_gen_from_encoding)

        ################# Finalize
        self._init_folders()
        self._verify_init()

    def compile(self,
                learning_rate=0.0001,
                optimizer=tf.train.AdamOptimizer,
                label_smoothing=1,
                gamma=1):
        """Define the losses and one optimizer op per network.

        Args:
            learning_rate: Learning rate handed to every optimizer instance.
            optimizer: Optimizer class; instantiated three times.
            label_smoothing: Factor the "real" labels (ones) are scaled by.
            gamma: Forwarded to ``_define_loss`` (not used there).
        """
        self._define_loss(label_smoothing=label_smoothing, gamma=gamma)
        with tf.name_scope("Optimizer"):
            enc_optimizer = optimizer(learning_rate=learning_rate)
            self._enc_optimizer = enc_optimizer.minimize(
                self._enc_loss,
                var_list=self._get_vars("Encoder"),
                name="Encoder")
            gen_optimizer = optimizer(learning_rate=learning_rate)
            self._gen_optimizer = gen_optimizer.minimize(
                self._gen_loss,
                var_list=self._get_vars("Generator"),
                name="Generator")
            disc_optimizer = optimizer(learning_rate=learning_rate)
            self._disc_optimizer = disc_optimizer.minimize(
                self._disc_loss,
                var_list=self._get_vars("Discriminator"),
                name="Discriminator")
        self._summarise()

    def _define_loss(self, label_smoothing, gamma):
        """Create the KL, feature-matching and adversarial loss tensors.

        Sets ``_enc_loss`` (KL + feature matching), ``_gen_loss`` (feature
        matching + adversarial) and ``_disc_loss`` and registers a scalar
        summary for each. ``gamma`` is accepted but currently unused.
        """
        def get_labels_one(tensor):
            return tf.ones_like(tensor) * label_smoothing

        # Keeps the logit reconstruction below finite when the sigmoid
        # output saturates at 0 or 1.
        eps = 1e-7
        ## Kullback-Leibler divergence
        self._KLdiv = 0.5 * (tf.square(self._mean_layer) +
                             tf.exp(self._std_layer) - self._std_layer - 1)
        self._KLdiv = tf.reduce_mean(self._KLdiv)

        ## Feature matching loss
        # return_idx=-2 requests the layer before the final output layer
        # (assumed from the argument name -- confirm in Discriminator).
        otp_disc_real = self._discriminator.generate_net(
            self._X_input, tf_trainflag=self._is_training, return_idx=-2)
        otp_disc_fake = self._discriminator.generate_net(
            self._output_gen, tf_trainflag=self._is_training, return_idx=-2)
        self._feature_loss = tf.reduce_mean(
            tf.square(otp_disc_real - otp_disc_fake))

        ## Discriminator loss
        # The discriminator ends in a sigmoid, so logits are recovered as
        # log(p / (1 - p)) for use with sigmoid_cross_entropy_with_logits.
        self._logits_real = tf.math.log(self._output_disc_real /
                                        (1 + eps - self._output_disc_real) +
                                        eps)
        self._logits_fake_from_real = tf.math.log(
            self._output_disc_fake_from_real /
            (1 + eps - self._output_disc_fake_from_real) + eps)
        self._logits_fake_from_latent = tf.math.log(
            self._output_disc_fake_from_latent /
            (1 + eps - self._output_disc_fake_from_latent) + eps)
        self._generator_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=get_labels_one(self._logits_fake_from_real),
                logits=self._logits_fake_from_real) +
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=get_labels_one(self._logits_fake_from_latent),
                logits=self._logits_fake_from_latent))
        self._discriminator_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=get_labels_one(
                self._logits_real),
                                                    logits=self._logits_real) +
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.zeros_like(self._logits_fake_from_real),
                logits=self._logits_fake_from_real) +
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.zeros_like(self._logits_fake_from_latent),
                logits=self._logits_fake_from_latent))

        with tf.name_scope("Loss"):

            self._enc_loss = self._KLdiv + self._feature_loss
            self._gen_loss = self._feature_loss + self._generator_loss
            self._disc_loss = self._discriminator_loss

            tf.summary.scalar("Encoder", self._enc_loss)
            tf.summary.scalar("Generator", self._gen_loss)
            tf.summary.scalar("Discriminator", self._disc_loss)

    def train(self,
              x_train,
              x_test,
              epochs=100,
              batch_size=64,
              disc_steps=5,
              gen_steps=1,
              log_step=3):
        """Run the training loop and print the summed losses per epoch.

        Args:
            x_train: Training data; its length bounds one epoch.
            x_test: Test data, handed to ``_set_up_test_train_sample``.
            epochs: Number of epochs.
            batch_size: Batch size; also the amount ``trained_examples``
                advances per ``_optimize`` call.
            disc_steps: Discriminator updates per ``_optimize`` call.
            gen_steps: Generator/encoder updates per ``_optimize`` call.
            log_step: Forwarded to ``_set_up_training``; when not ``None``,
                ``_log`` is called after every epoch.
        """
        self._set_up_training(log_step=log_step)
        self._set_up_test_train_sample(x_train, x_test)
        for epoch in range(epochs):
            disc_loss_epoch = 0
            gen_loss_epoch = 0
            enc_loss_epoch = 0
            # time.clock() was removed in Python 3.8; perf_counter() is the
            # documented replacement for measuring elapsed time.
            start = time.perf_counter()
            trained_examples = 0
            while trained_examples < len(x_train):
                disc_loss_batch, gen_loss_batch, enc_loss_batch = self._optimize(
                    self._trainset, batch_size, disc_steps, gen_steps)
                trained_examples += batch_size
                disc_loss_epoch += disc_loss_batch
                gen_loss_epoch += gen_loss_batch
                enc_loss_epoch += enc_loss_batch

            # Elapsed time of the epoch in minutes.
            epoch_train_time = (time.perf_counter() - start) / 60
            disc_loss_epoch = np.round(disc_loss_epoch, 2)
            gen_loss_epoch = np.round(gen_loss_epoch, 2)
            enc_loss_epoch = np.round(enc_loss_epoch, 2)

            print("Epoch {}: D: {}; G: {}; E: {}.".format(
                epoch, disc_loss_epoch, gen_loss_epoch, enc_loss_epoch))

            if log_step is not None:
                self._log(epoch, epoch_train_time)

    def _optimize(self, dataset, batch_size, disc_steps, gen_steps):
        """Run the discriminator steps, then the generator/encoder steps.

        Every discriminator step draws a fresh batch; the generator and
        encoder steps reuse the last batch drawn.

        Returns:
            Tuple ``(disc_loss, gen_loss, enc_loss)`` holding the loss of the
            last step of each kind, or 0 for a kind that ran no step.
        """
        # Defaults keep the method well-defined when a step count is 0.
        current_batch_x = None
        disc_loss_batch = gen_loss_batch = enc_loss_batch = 0

        for _ in range(disc_steps):
            current_batch_x = dataset.get_next_batch(batch_size)
            Z_noise = self._generator.sample_noise(n=len(current_batch_x))
            _, disc_loss_batch = self._sess.run(
                [self._disc_optimizer, self._disc_loss],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Z_input: Z_noise
                })

        for _ in range(gen_steps):
            if current_batch_x is None:
                # No discriminator step ran, so no batch has been drawn yet.
                current_batch_x = dataset.get_next_batch(batch_size)
            Z_noise = self._generator.sample_noise(n=len(current_batch_x))
            _, gen_loss_batch = self._sess.run(
                [self._gen_optimizer, self._gen_loss],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Z_input: Z_noise
                })
            _, enc_loss_batch = self._sess.run(
                [self._enc_optimizer, self._enc_loss],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Z_input: Z_noise
                })

        return disc_loss_batch, gen_loss_batch, enc_loss_batch
Exemple #19
0
    def __init__(self, hyperparameters):
        """Create the networks, their optimizers/schedulers and initial state.

        ``hyperparameters`` is indexed like a mapping; the keys read here are
        'lr', 'input_dim_a', 'gen' (including 'style_dim'), 'beta1', 'beta2',
        'weight_decay', 'init' and, optionally, 'vgg_w' / 'vgg_model_path'.
        """
        super(LSGANs_Trainer, self).__init__()
        lr = hyperparameters['lr']

        # Networks
        self.encoder = Encoder(hyperparameters['input_dim_a'],
                               hyperparameters['gen'])
        self.decoder = Decoder(hyperparameters['input_dim_a'],
                               hyperparameters['gen'])
        self.dis_a = Discriminator()
        self.dis_b = Discriminator()
        self.interp_net_ab = Interpolator()
        self.interp_net_ba = Interpolator()
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Optimizers: one Adam per network, all with identical settings and
        # restricted to the parameters that require gradients.
        betas = (hyperparameters['beta1'], hyperparameters['beta2'])
        weight_decay = hyperparameters['weight_decay']

        def make_adam(network):
            trainable = [
                param for param in network.parameters() if param.requires_grad
            ]
            return torch.optim.Adam(trainable,
                                    lr=lr,
                                    betas=betas,
                                    weight_decay=weight_decay)

        self.enc_opt = make_adam(self.encoder)
        self.dec_opt = make_adam(self.decoder)
        self.dis_a_opt = make_adam(self.dis_a)
        self.dis_b_opt = make_adam(self.dis_b)
        self.interp_ab_opt = make_adam(self.interp_net_ab)
        self.interp_ba_opt = make_adam(self.interp_net_ba)

        # Learning-rate schedulers, one per optimizer
        self.enc_scheduler = get_scheduler(self.enc_opt, hyperparameters)
        self.dec_scheduler = get_scheduler(self.dec_opt, hyperparameters)
        self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters)
        self.dis_b_scheduler = get_scheduler(self.dis_b_opt, hyperparameters)
        self.interp_ab_scheduler = get_scheduler(self.interp_ab_opt,
                                                 hyperparameters)
        self.interp_ba_scheduler = get_scheduler(self.interp_ba_opt,
                                                 hyperparameters)

        # Weight initialisation: configured scheme for the whole trainer
        # first, then the discriminators are re-initialised with 'gaussian'.
        self.apply(weights_init(hyperparameters['init']))
        for discriminator in (self.dis_a, self.dis_b):
            discriminator.apply(weights_init('gaussian'))

        # Optional frozen VGG, loaded only when a positive weight is set
        if 'vgg_w' in hyperparameters and hyperparameters['vgg_w'] > 0:
            vgg_dir = hyperparameters['vgg_model_path'] + '/models'
            self.vgg = load_vgg16(vgg_dir)
            self.vgg.eval()
            for vgg_param in self.vgg.parameters():
                vgg_param.requires_grad = False

        self.total_loss = 0
        self.best_iter = 0
        self.perceptural_loss = Perceptural_loss()
Exemple #20
0
class BiGAN(object):
    """Bidirectional GAN: generator, encoder and a joint (x, z) discriminator.

    The generator and encoder share one Adam optimizer; the discriminator
    has its own.
    """

    def __init__(self, args):
        """Create the networks, optimizers and initial weights.

        Args:
            args: Object exposing ``z_dim``, ``decay_rate``,
                ``learning_rate``, ``model_name`` and ``batch_size``.
        """
        self.z_dim = args.z_dim
        self.decay_rate = args.decay_rate
        self.learning_rate = args.learning_rate
        self.model_name = args.model_name
        self.batch_size = args.batch_size

        #initialize networks
        self.Generator = Generator(self.z_dim).cuda()
        self.Encoder = Encoder(self.z_dim).cuda()
        self.Discriminator = Discriminator().cuda()

        #set optimizers for all networks
        self.optimizer_G_E = torch.optim.Adam(
            list(self.Generator.parameters()) +
            list(self.Encoder.parameters()),
            lr=self.learning_rate,
            betas=(0.5, 0.999))

        self.optimizer_D = torch.optim.Adam(self.Discriminator.parameters(),
                                            lr=self.learning_rate,
                                            betas=(0.5, 0.999))

        #initialize network weights
        self.Generator.apply(weights_init)
        self.Encoder.apply(weights_init)
        self.Discriminator.apply(weights_init)

    def train(self, data):
        """Run one discriminator update followed by one generator/encoder update.

        The previous version could not run: it read ``z_real``, ``x_fake``
        and ``z_fake`` as bare names, stored ``nn.BCELoss()`` criterion
        objects in the loss attributes and called ``.backward()`` on them,
        and sampled the noise on the CPU.

        Args:
            data: Batch of real samples, on the same device as the networks.

        After the call ``self.D_loss`` and ``self.G_E_loss`` hold the scalar
        loss tensors of this step.
        """
        self.Generator.train()
        self.Encoder.train()
        self.Discriminator.train()

        # BCELoss expects probabilities, i.e. it assumes the discriminator
        # ends in a sigmoid -- TODO confirm against Discriminator.
        criterion = nn.BCELoss()

        #get fake z_data for generator (sampled on the device of the data)
        self.z_fake = torch.randn((self.batch_size, self.z_dim),
                                  device=data.device)

        #send fake z_data through generator to get fake x_data
        self.x_fake = self.Generator(self.z_fake)

        #send real data through encoder to get real z_data
        self.z_real = self.Encoder(data)

        # ---- discriminator step ----
        # Inputs are detached so this loss only reaches the discriminator.
        self.optimizer_D.zero_grad()

        #send real x and z data into discriminator
        self.out_real = self.Discriminator(data, self.z_real.detach())

        #send fake x and z data into discriminator
        self.out_fake = self.Discriminator(self.x_fake.detach(), self.z_fake)

        #compute discriminator loss: real pairs -> 1, fake pairs -> 0
        self.D_loss = (
            criterion(self.out_real, torch.ones_like(self.out_real)) +
            criterion(self.out_fake, torch.zeros_like(self.out_fake)))

        #compute discriminator gradiants and backpropogate
        self.D_loss.backward()
        self.optimizer_D.step()

        # ---- generator / encoder step ----
        # A second forward pass through the updated discriminator is needed:
        # its parameters were just modified in place, so the graph of the
        # first pass can no longer be back-propagated.
        self.optimizer_G_E.zero_grad()
        ge_out_real = self.Discriminator(data, self.z_real)
        ge_out_fake = self.Discriminator(self.x_fake, self.z_fake)

        #compute generator/encoder loss: same pairs with swapped labels
        self.G_E_loss = (
            criterion(ge_out_real, torch.zeros_like(ge_out_real)) +
            criterion(ge_out_fake, torch.ones_like(ge_out_fake)))

        #compute generator/encoder gradiants and backpropogate
        # (gradients this leaves on the discriminator are cleared by
        # optimizer_D.zero_grad() at the start of the next call)
        self.G_E_loss.backward()
        self.optimizer_G_E.step()
Exemple #21
0
	def __init__(self, args, lr=0.1, latent_dim=8, lambda_latent=0.5,
					lambda_kl= 0.001, lambda_recon= 10, is_train = True,  ):
		"""Build the graph: placeholders, networks, losses, train ops, summaries.

		Args:
			args: Object exposing ``batch_size`` and ``img_size``.
			lr: Not used here; the learning rate is fed through the
				``learning_rate`` placeholder.
			latent_dim: Size of the latent code ``z``.
			lambda_latent: Weight of the latent reconstruction cost.
			lambda_kl: Weight of the KL-divergence cost.
			lambda_recon: Weight of the image reconstruction cost.
			is_train: Not used here; the mode is fed through the
				``is_training`` placeholder.
		"""
		## Parameters 
		self.batch_size = args.batch_size
		self.latent_dim = latent_dim
		self.image_size = args.img_size
		self.lambda_kl = lambda_kl 
		self.lambda_recon = lambda_recon
		self.lambda_latent = lambda_latent
		self.is_train = tf.placeholder(tf.bool, name= 'is_training')
		self.lr = tf.placeholder(tf.float32, name='learning_rate')
		self.A = tf.placeholder(tf.float32, [self.batch_size, self.image_size,
								self.image_size, 3], name= 'A') 
		self.B = tf.placeholder(tf.float32, [self.batch_size, self.image_size,
								self.image_size, 3], name= 'B')
		self.z = tf.placeholder(tf.float32, [self.batch_size, self.latent_dim], 
								name= 'z')

		## Augmentation
		def aug_img(image):
			# Enlarge by aug_strength pixels, random-crop back to the
			# original size, then randomly flip each image horizontally.
			aug_strength = 30
			aug_size = self.image_size + aug_strength
			image_resized = tf.image.resize_images(image, [aug_size, aug_size])
			image_cropped = tf.random_crop(image_resized, [self.batch_size, self.image_size,
								self.image_size, 3])
			## work-around as tf-flip doesn't support 4D-batch
			image_flipped = tf.map_fn(lambda image_iter: tf.image.random_flip_left_right(image_iter), image_cropped)
			return image_flipped
		# Augmentation is applied only while the is_training placeholder is True.
		A = tf.cond(self.is_train,
					 lambda: aug_img(self.A), lambda: self.A)
		B = tf.cond(self.is_train, 
					lambda: aug_img(self.B), lambda: self.B)
		## Generator
		with tf.variable_scope('generator'):
			Gen = Generator(self.image_size, self.is_train)

		## Discriminator
		with tf.variable_scope('discriminator'):
			Disc = Discriminator(self.image_size, self.is_train)

		## Encoder
		with tf.variable_scope('encoder'):
			Enc = Encoder(self.image_size, self.is_train, self.latent_dim)

		## cVAE-GAN
		with tf.variable_scope('encoder'):
			z_enc, z_enc_mu, z_enc_log_sigma = Enc(B)
		
		with tf.variable_scope('generator'):
			self.B_hat_enc = Gen(A, z_enc)

		## cLR-GAN 
		with tf.variable_scope('generator', reuse=True):
			self.B_hat = Gen(A, self.z)
		with tf.variable_scope('encoder', reuse= True):
			z_hat, z_hat_mu, z_hat_log_sigma = Enc(self.B_hat)

		## Disc
		with tf.variable_scope('discriminator'):
			self.real = Disc(B)
		with tf.variable_scope('discriminator', reuse=True):
			self.fake = Disc(self.B_hat)
			self.fake_enc = Disc(self.B_hat_enc)

		## losses
		# Least-squares GAN terms: real outputs are pulled towards 0.9,
		# fake outputs towards 0.
		self.vae_gan_cost = tf.reduce_mean(tf.squared_difference(self.real, 0.9)) + \
						tf.reduce_mean(tf.square(self.fake_enc))
		self.recon_img_cost = tf.reduce_mean(tf.abs(B - self.B_hat_enc))
		self.gan_cost = tf.reduce_mean(tf.squared_difference(self.real, 0.9)) + \
					tf.reduce_mean(tf.square(self.fake))
		self.recon_latent_cost = tf.reduce_mean(tf.abs(self.z-z_hat))
		self.kl_div_cost =  -0.5*tf.reduce_mean(1 + 2*z_enc_log_sigma - z_enc_mu**2 -\
							tf.exp(2* z_enc_log_sigma))
		self.vec_cost = [self.vae_gan_cost, self.recon_img_cost, self.gan_cost, self.recon_latent_cost, 
						self.kl_div_cost]
		# The discriminator minimises `cost`, generator and encoder minimise
		# `-cost`. Terms the generator/encoder must minimise therefore carry
		# a negative weight here. The KL term previously had a positive
		# weight, so the encoder was maximising the KL divergence.
		weight_vec = [1, -self.lambda_recon, 1, -self.lambda_latent, -self.lambda_kl]

		self.cost = tf.reduce_sum([cost_term * weight for cost_term, weight in zip(self.vec_cost, weight_vec)])

		## Optimizers
		update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
		self.optim_gen = tf.train.AdamOptimizer(self.lr, beta1=0.5)
		self.optim_disc = tf.train.AdamOptimizer(self.lr, beta1=0.5)
		self.optim_enc = tf.train.AdamOptimizer(self.lr, beta1=0.5)
		
		## Collecting the trainable variables
		# NOTE(review): the 'bc_gan/...' prefixes assume this constructor
		# runs inside an enclosing 'bc_gan' variable scope -- confirm at the
		# call site, otherwise these collections are empty.
		gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='bc_gan/generator')
		disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='bc_gan/discriminator')
		enc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='bc_gan/encoder')

		## Defining the training operation 
		# Ops from the UPDATE_OPS collection run before each training step.
		with tf.control_dependencies(update_ops):
			self.train_op_gen = self.optim_gen.minimize(-self.cost, var_list=gen_vars)
			self.train_op_disc = self.optim_disc.minimize(self.cost, var_list= disc_vars)
			self.train_op_enc = self.optim_enc.minimize(-self.cost, var_list=enc_vars)

		## Joining the training ops 
		self.train_ops = [self.train_op_gen, self.train_op_disc, self.train_op_enc]
		## Summary Create
		def summary_create(self):
			## Image summaries
			tf.summary.image('A', self.A[0:1])
			tf.summary.image('B', self.B[0:1])
			tf.summary.image('B^', self.B_hat[0:1])
			tf.summary.image('B^-enc', self.B_hat_enc[0:1])
			## GEN - DISC summaries - min max game  
			tf.summary.scalar('fake', tf.reduce_mean(self.fake))
			tf.summary.scalar('fake_enc', tf.reduce_mean(self.fake_enc))
			tf.summary.scalar('real', tf.reduce_mean(self.real))
			tf.summary.scalar('learning_rate', self.lr)
			## cost summaries		
			tf.summary.scalar('cost_vae_gan', self.vae_gan_cost)
			tf.summary.scalar('cost_recon_img', self.recon_img_cost)
			tf.summary.scalar('cost_gan_cost', self.gan_cost)
			tf.summary.scalar('cost_recon_latent', self.recon_latent_cost)
			tf.summary.scalar('cost_kl_div', self.kl_div_cost)
			tf.summary.scalar('cost_final', self.cost)
			## Merge Summaries
			self.merge_op = tf.summary.merge_all()

		summary_create(self)
Exemple #22
0
    def __init__(self,
                 x_dim,
                 y_dim,
                 z_dim,
                 enc_architecture,
                 gen_architecture,
                 adversarial_architecture,
                 folder="./CVAEGAN",
                 is_patchgan=False,
                 is_wasserstein=False):
        """Build encoder, generator and adversarial nets and wire their inputs.

        Args:
            x_dim: Forwarded to the base class together with ``y_dim``.
            y_dim: Expected shape (without batch axis) of the generator
                output.
            z_dim: Size of the latent placeholder and of the encoder's
                "Mean" and "Std" heads.
            enc_architecture: Layer specification of the encoder trunk.
            gen_architecture: Layer specification of the generator; the
                ``"name"`` of its last layer is overwritten with "Output".
            adversarial_architecture: Layer specification of the adversarial
                network. Unless ``is_patchgan`` is set, a flatten layer and a
                1-unit dense output layer are appended.
            folder: Forwarded to the base-class constructor.
            is_patchgan: If true, the given last adversarial layer is
                validated (1 filter, matching activation) instead of
                appending a dense output.
            is_wasserstein: Selects ``tf.identity`` instead of
                ``tf.nn.sigmoid`` as the adversarial output activation.

        Raises:
            AssertionError: If the PatchGAN last layer is misconfigured or
                the generator output shape differs from ``y_dim``.
        """
        super(CVAEGAN, self).__init__(
            x_dim, y_dim,
            [enc_architecture, gen_architecture, adversarial_architecture],
            folder)

        self._z_dim = z_dim
        with tf.name_scope("Inputs"):
            self._Z_input = tf.placeholder(tf.float32,
                                           shape=[None, z_dim],
                                           name="z")

        self._enc_architecture = self._architectures[0]
        self._gen_architecture = self._architectures[1]
        self._adv_architecture = self._architectures[2]

        self._is_patchgan = is_patchgan
        self._is_wasserstein = is_wasserstein
        self._is_feature_matching = False

        ################# Define architecture
        if self._is_patchgan:
            f_xy = self._adv_architecture[-1][-1]["filters"]
            assert f_xy == 1, "If is PatchGAN, last layer of adversarial_XY needs 1 filter. Given: {}.".format(
                f_xy)

            a_xy = self._adv_architecture[-1][-1]["activation"]
            if self._is_wasserstein:
                assert a_xy == tf.identity, "If is PatchGAN, last layer of adversarial needs tf.identity. Given: {}.".format(
                    a_xy)
            else:
                assert a_xy == tf.nn.sigmoid, "If is PatchGAN, last layer of adversarial needs tf.nn.sigmoid. Given: {}.".format(
                    a_xy)
        else:
            self._adv_architecture.append(
                [tf.layers.flatten, {
                    "name": "Flatten"
                }])
            # Linear output for the Wasserstein variant, sigmoid otherwise.
            output_activation = (tf.identity
                                 if self._is_wasserstein else tf.nn.sigmoid)
            self._adv_architecture.append([
                logged_dense, {
                    "units": 1,
                    "activation": output_activation,
                    "name": "Output"
                }
            ])

        last_layers_mean = [
            [tf.layers.flatten, {
                "name": "flatten"
            }],
            [
                logged_dense, {
                    "units": z_dim,
                    "activation": tf.identity,
                    "name": "Mean"
                }
            ],
        ]
        self._encoder_mean = Encoder(self._enc_architecture + last_layers_mean,
                                     name="Encoder")
        last_layers_std = [
            [tf.layers.flatten, {
                "name": "flatten"
            }],
            [
                logged_dense, {
                    "units": z_dim,
                    "activation": tf.identity,
                    "name": "Std"
                }
            ],
        ]
        # NOTE(review): both heads are created with name="Encoder"; whether
        # they share trunk variables depends on how Encoder scopes its
        # layers -- confirm.
        self._encoder_std = Encoder(self._enc_architecture + last_layers_std,
                                    name="Encoder")

        self._gen_architecture[-1][1]["name"] = "Output"
        self._generator = Generator(self._gen_architecture, name="Generator")

        self._adversarial = Discriminator(self._adv_architecture,
                                          name="Adversarial")

        self._nets = [self._encoder_mean, self._generator, self._adversarial]

        ################# Connect inputs and networks
        self._mean_layer = self._encoder_mean.generate_net(self._Y_input)
        self._std_layer = self._encoder_std.generate_net(self._Y_input)

        # Re-parameterisation: the "Std" head is used as a log-variance, so
        # exp(0.5 * head) is the standard deviation scaling the noise input.
        self._output_enc_with_noise = self._mean_layer + tf.exp(
            0.5 * self._std_layer) * self._Z_input
        with tf.name_scope("Inputs"):
            self._gen_input = image_condition_concat(
                inputs=self._X_input,
                condition=self._output_enc_with_noise,
                name="mod_z_real")
            self._gen_input_from_encoding = image_condition_concat(
                inputs=self._X_input, condition=self._Z_input, name="mod_z")
        self._output_gen = self._generator.generate_net(self._gen_input)
        self._output_gen_from_encoding = self._generator.generate_net(
            self._gen_input_from_encoding)
        self._generator._input_dim = z_dim

        # The message now reports y_dim, the value actually compared
        # against (it previously printed x_dim).
        assert self._output_gen.get_shape()[1:] == y_dim, (
            "Generator output must have shape of y_dim. Given: {}. Expected: {}."
            .format(self._output_gen.get_shape(), y_dim))

        with tf.name_scope("InputsAdversarial"):
            # The adversarial net sees the (real or generated) Y image
            # concatenated with the X input along axis 3.
            self._input_real = tf.concat(values=[self._Y_input, self._X_input],
                                         axis=3)
            self._input_fake_from_real = tf.concat(
                values=[self._output_gen, self._X_input], axis=3)
            self._input_fake_from_latent = tf.concat(
                values=[self._output_gen_from_encoding, self._X_input], axis=3)

        self._output_adv_real = self._adversarial.generate_net(
            self._input_real)
        self._output_adv_fake_from_real = self._adversarial.generate_net(
            self._input_fake_from_real)
        self._output_adv_fake_from_latent = self._adversarial.generate_net(
            self._input_fake_from_latent)

        ################# Finalize
        self._init_folders()
        self._verify_init()

        # Label placeholders shaped like the adversarial outputs.
        self._output_label_real = tf.placeholder(
            tf.float32, shape=self._output_adv_real.shape, name="label_real")
        self._output_label_fake = tf.placeholder(
            tf.float32,
            shape=self._output_adv_fake_from_real.shape,
            name="label_fake")

        if self._is_patchgan:
            print("PATCHGAN chosen with output: {}.".format(
                self._output_adv_real.shape))
Exemple #23
0
class CVAEGAN(Image2ImageGenerativeModel):
    def __init__(self,
                 x_dim,
                 y_dim,
                 z_dim,
                 enc_architecture,
                 gen_architecture,
                 adversarial_architecture,
                 folder="./CVAEGAN",
                 is_patchgan=False,
                 is_wasserstein=False):
        """Assemble the encoder, generator and adversarial networks.

        Every architecture is a list of ``[layer_function, kwargs]`` pairs.

        Args:
            x_dim: Forwarded to the parent constructor; presumably the shape
                of ``self._X_input`` (TODO confirm against the parent class).
            y_dim: Forwarded to the parent constructor. The generator output
                (without the batch axis) must have exactly this shape.
            z_dim: Width of the ``z`` placeholder and of both encoder heads.
            enc_architecture: Encoder body. A flatten layer plus a dense head
                with ``z_dim`` units is added twice: once named "Mean" and
                once named "Std".
            gen_architecture: Generator layers; the last layer is renamed to
                "Output" in place.
            adversarial_architecture: Adversarial layers. Unless
                ``is_patchgan`` is set, a flatten layer and a one-unit dense
                "Output" layer are appended in place.
            folder: Forwarded to the parent constructor.
            is_patchgan: If true, the adversarial architecture is used as
                given and its last layer must have one filter and the
                activation matching ``is_wasserstein``.
            is_wasserstein: Selects ``tf.identity`` instead of
                ``tf.nn.sigmoid`` as the adversarial output activation.

        Raises:
            AssertionError: If the PatchGAN constraints on the last
                adversarial layer are violated, or if the generator output
                shape differs from ``y_dim``.
        """
        super(CVAEGAN, self).__init__(
            x_dim, y_dim,
            [enc_architecture, gen_architecture, adversarial_architecture],
            folder)

        self._z_dim = z_dim
        with tf.name_scope("Inputs"):
            self._Z_input = tf.placeholder(tf.float32,
                                           shape=[None, z_dim],
                                           name="z")

        # self._architectures is not assigned in this method; presumably the
        # parent constructor stores the list passed above (TODO confirm).
        self._enc_architecture = self._architectures[0]
        self._gen_architecture = self._architectures[1]
        self._adv_architecture = self._architectures[2]

        self._is_patchgan = is_patchgan
        self._is_wasserstein = is_wasserstein
        self._is_feature_matching = False

        ################# Define architecture
        if self._is_patchgan:
            # A PatchGAN keeps its spatial output, so only validate the
            # user-supplied last layer instead of appending a dense head.
            f_xy = self._adv_architecture[-1][-1]["filters"]
            assert f_xy == 1, "If is PatchGAN, last layer of adversarial_XY needs 1 filter. Given: {}.".format(
                f_xy)

            a_xy = self._adv_architecture[-1][-1]["activation"]
            if self._is_wasserstein:
                assert a_xy == tf.identity, "If is PatchGAN, last layer of adversarial needs tf.identity. Given: {}.".format(
                    a_xy)
            else:
                assert a_xy == tf.nn.sigmoid, "If is PatchGAN, last layer of adversarial needs tf.nn.sigmoid. Given: {}.".format(
                    a_xy)
        else:
            self._adv_architecture.append(
                [tf.layers.flatten, {
                    "name": "Flatten"
                }])
            # Wasserstein critics output an unbounded score, all other losses
            # expect a probability.
            output_activation = (tf.identity
                                 if self._is_wasserstein else tf.nn.sigmoid)
            self._adv_architecture.append([
                logged_dense, {
                    "units": 1,
                    "activation": output_activation,
                    "name": "Output"
                }
            ])

        last_layers_mean = [[tf.layers.flatten, {
            "name": "flatten"
        }],
                            [
                                logged_dense, {
                                    "units": z_dim,
                                    "activation": tf.identity,
                                    "name": "Mean"
                                }
                            ]]
        self._encoder_mean = Encoder(self._enc_architecture + last_layers_mean,
                                     name="Encoder")
        last_layers_std = [[tf.layers.flatten, {
            "name": "flatten"
        }],
                           [
                               logged_dense, {
                                   "units": z_dim,
                                   "activation": tf.identity,
                                   "name": "Std"
                               }
                           ]]
        self._encoder_std = Encoder(self._enc_architecture + last_layers_std,
                                    name="Encoder")

        self._gen_architecture[-1][1]["name"] = "Output"
        self._generator = Generator(self._gen_architecture, name="Generator")

        self._adversarial = Discriminator(self._adv_architecture,
                                          name="Adversarial")

        # NOTE(review): self._encoder_std is not listed here although it is a
        # separate Encoder instance; confirm this is intended.
        self._nets = [self._encoder_mean, self._generator, self._adversarial]

        ################# Connect inputs and networks
        self._mean_layer = self._encoder_mean.generate_net(self._Y_input)
        self._std_layer = self._encoder_std.generate_net(self._Y_input)

        # Reparameterisation: the "Std" head is used as a log-variance, hence
        # exp(0.5 * value) scales the noise fed through the z placeholder.
        self._output_enc_with_noise = self._mean_layer + tf.exp(
            0.5 * self._std_layer) * self._Z_input
        with tf.name_scope("Inputs"):
            self._gen_input = image_condition_concat(
                inputs=self._X_input,
                condition=self._output_enc_with_noise,
                name="mod_z_real")
            self._gen_input_from_encoding = image_condition_concat(
                inputs=self._X_input, condition=self._Z_input, name="mod_z")
        self._output_gen = self._generator.generate_net(self._gen_input)
        self._output_gen_from_encoding = self._generator.generate_net(
            self._gen_input_from_encoding)
        self._generator._input_dim = z_dim

        # The message reports y_dim, the shape the check actually compares
        # against.
        assert self._output_gen.get_shape()[1:] == y_dim, (
            "Generator output must have shape of y_dim. Given: {}. Expected: {}."
            .format(self._output_gen.get_shape(), y_dim))

        with tf.name_scope("InputsAdversarial"):
            self._input_real = tf.concat(values=[self._Y_input, self._X_input],
                                         axis=3)
            self._input_fake_from_real = tf.concat(
                values=[self._output_gen, self._X_input], axis=3)
            self._input_fake_from_latent = tf.concat(
                values=[self._output_gen_from_encoding, self._X_input], axis=3)

        self._output_adv_real = self._adversarial.generate_net(
            self._input_real)
        self._output_adv_fake_from_real = self._adversarial.generate_net(
            self._input_fake_from_real)
        self._output_adv_fake_from_latent = self._adversarial.generate_net(
            self._input_fake_from_latent)

        ################# Finalize
        self._init_folders()
        self._verify_init()

        # Label placeholders share the adversarial output shape so labels can
        # be fed per output element.
        self._output_label_real = tf.placeholder(
            tf.float32, shape=self._output_adv_real.shape, name="label_real")
        self._output_label_fake = tf.placeholder(
            tf.float32,
            shape=self._output_adv_fake_from_real.shape,
            name="label_fake")

        if self._is_patchgan:
            print("PATCHGAN chosen with output: {}.".format(
                self._output_adv_real.shape))

    def compile(self,
                loss,
                optimizer,
                learning_rate=None,
                learning_rate_enc=None,
                learning_rate_gen=None,
                learning_rate_adv=None,
                label_smoothing=1,
                lmbda_kl=0.1,
                lmbda_y=1,
                feature_matching=False,
                random_labeling=0):

        if self._is_wasserstein and loss != "wasserstein":
            raise ValueError(
                "If is_wasserstein is true in Constructor, loss needs to be wasserstein."
            )
        if not self._is_wasserstein and loss == "wasserstein":
            raise ValueError(
                "If loss is wasserstein, is_wasserstein needs to be true in constructor."
            )

        if np.all([
                lr is None for lr in [
                    learning_rate, learning_rate_enc, learning_rate_gen,
                    learning_rate_adv
                ]
        ]):
            raise ValueError("Need learning_rate.")
        if learning_rate is not None and learning_rate_enc is None:
            learning_rate_enc = learning_rate
        if learning_rate is not None and learning_rate_gen is None:
            learning_rate_gen = learning_rate
        if learning_rate is not None and learning_rate_adv is None:
            learning_rate_adv = learning_rate

        self._define_loss(loss=loss,
                          label_smoothing=label_smoothing,
                          lmbda_kl=lmbda_kl,
                          lmbda_y=lmbda_y,
                          feature_matching=feature_matching,
                          random_labeling=random_labeling)
        with tf.name_scope("Optimizer"):
            self._enc_optimizer = optimizer(learning_rate=learning_rate_enc)
            self._enc_optimizer_op = self._enc_optimizer.minimize(
                self._enc_loss,
                var_list=self._get_vars("Encoder"),
                name="Encoder")
            self._gen_optimizer = optimizer(learning_rate=learning_rate_gen)
            self._gen_oprimizer_op = self._gen_optimizer.minimize(
                self._gen_loss,
                var_list=self._get_vars("Generator"),
                name="Generator")
            self._adv_optimizer = optimizer(learning_rate=learning_rate_adv)
            self._adv_optimizer_op = self._adv_optimizer.minimize(
                self._adv_loss,
                var_list=self._get_vars("Adversarial"),
                name="Adversarial")
        self._summarise()

    def _define_loss(self, loss, label_smoothing, lmbda_kl, lmbda_y,
                     feature_matching, random_labeling):
        """Create the encoder, generator and adversarial loss tensors.

        Args:
            loss: One of "cross-entropy", "L2", "wasserstein" or "KL".
            label_smoothing: Factor multiplied onto the real-label
                placeholder; also stored on the instance.
            lmbda_kl: Weight of the Kullback-Leibler term.
            lmbda_y: Weight of the L1 reconstruction term.
            feature_matching: If true, the generator loss is replaced by the
                mean squared difference between the adversarial outputs
                taken at ``return_idx=-2`` for real and generated inputs.
            random_labeling: Only stored on the instance here.

        Raises:
            ValueError: If ``loss`` is not one of the supported names.
        """
        possible_losses = ["cross-entropy", "L2", "wasserstein", "KL"]
        eps = 1e-6

        def real_targets():
            # Smoothed labels for real samples.
            return tf.math.multiply(self._output_label_real, label_smoothing)

        def fake_targets():
            return self._output_label_fake

        def to_logits(probability):
            # Inverse sigmoid; eps keeps the division and the log finite.
            return tf.math.log(probability / (1 + eps - probability) + eps)

        def cross_entropy(labels, logits):
            return tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
                                                           logits=logits)

        def build_logits():
            self._logits_real = to_logits(self._output_adv_real)
            self._logits_fake_from_real = to_logits(
                self._output_adv_fake_from_real)
            self._logits_fake_from_latent = to_logits(
                self._output_adv_fake_from_latent)

        self._label_smoothing = label_smoothing
        self._random_labeling = random_labeling

        ## Kullback-Leibler divergence
        kl_elementwise = 0.5 * (tf.square(self._mean_layer) +
                                tf.exp(self._std_layer) - self._std_layer - 1)
        self._KLdiv = kl_elementwise
        self._KLdiv = lmbda_kl * tf.reduce_mean(self._KLdiv)

        ## L1 loss
        absolute_error = tf.abs(self._Y_input - self._output_gen)
        self._recon_loss = lmbda_y * tf.reduce_mean(absolute_error)

        ## Adversarial loss
        if loss == "cross-entropy":
            build_logits()
            gen_from_real = cross_entropy(
                tf.ones_like(self._logits_fake_from_real),
                self._logits_fake_from_real)
            gen_from_latent = cross_entropy(
                tf.ones_like(self._logits_fake_from_latent),
                self._logits_fake_from_latent)
            self._generator_loss = tf.reduce_mean(gen_from_real +
                                                  gen_from_latent)

            adv_real = cross_entropy(real_targets(), self._logits_real)
            adv_from_real = cross_entropy(fake_targets(),
                                          self._logits_fake_from_real)
            adv_from_latent = cross_entropy(fake_targets(),
                                            self._logits_fake_from_latent)
            self._adversarial_loss = tf.reduce_mean(adv_real + adv_from_real +
                                                    adv_from_latent)
        elif loss == "L2":
            gen_from_real = tf.square(
                self._output_adv_fake_from_real -
                tf.ones_like(self._output_adv_fake_from_real))
            gen_from_latent = tf.square(
                self._output_adv_fake_from_latent -
                tf.ones_like(self._output_adv_fake_from_latent))
            self._generator_loss = tf.reduce_mean(gen_from_real +
                                                  gen_from_latent) / 2

            adv_real = tf.square(self._output_adv_real - real_targets())
            adv_from_real = tf.square(self._output_adv_fake_from_real -
                                      fake_targets())
            adv_from_latent = tf.square(self._output_adv_fake_from_latent -
                                        fake_targets())
            self._adversarial_loss = (tf.reduce_mean(
                adv_real + adv_from_real + adv_from_latent)) / 3.0
        elif loss == "wasserstein":
            self._generator_loss = -tf.reduce_mean(
                self._output_adv_fake_from_real) - tf.reduce_mean(
                    self._output_adv_fake_from_latent)
            negative_critic_gap = -(
                tf.reduce_mean(self._output_adv_real) -
                tf.reduce_mean(self._output_adv_fake_from_real) -
                tf.reduce_mean(self._output_adv_fake_from_latent))
            self._adversarial_loss = (negative_critic_gap +
                                      10 * self._define_gradient_penalty())
        elif loss == "KL":
            build_logits()
            self._generator_loss = (
                -tf.reduce_mean(self._logits_fake_from_real) -
                tf.reduce_mean(self._logits_fake_from_latent)) / 2

            adv_real = 0.5 * cross_entropy(real_targets(), self._logits_real)
            adv_from_real = 0.25 * cross_entropy(fake_targets(),
                                                 self._logits_fake_from_real)
            adv_from_latent = 0.25 * cross_entropy(
                fake_targets(), self._logits_fake_from_latent)
            self._adversarial_loss = tf.reduce_mean(adv_real + adv_from_real +
                                                    adv_from_latent)
        else:
            raise ValueError(
                "Loss not implemented. Choose from {}. Given: {}.".format(
                    possible_losses, loss))

        if feature_matching:
            # Overrides whichever generator loss was selected above.
            self._is_feature_matching = True
            features_real = self._adversarial.generate_net(
                self._input_real,
                tf_trainflag=self._is_training,
                return_idx=-2)
            features_fake = self._adversarial.generate_net(
                self._input_fake_from_real,
                tf_trainflag=self._is_training,
                return_idx=-2)
            self._generator_loss = tf.reduce_mean(
                tf.square(features_real - features_fake))

        with tf.name_scope("Loss"):

            self._enc_loss = self._KLdiv + self._recon_loss + self._generator_loss
            self._gen_loss = self._recon_loss + self._generator_loss
            self._adv_loss = self._adversarial_loss

            for summary_name, summary_tensor in (
                ("Kullback-Leibler", self._KLdiv),
                ("Reconstruction", self._recon_loss),
                ("Vanilla_Generator", self._generator_loss),
                ("Encoder", self._enc_loss),
                ("Generator", self._gen_loss),
                ("Adversarial", self._adv_loss),
            ):
                tf.summary.scalar(summary_name, summary_tensor)

    def _define_gradient_penalty(self):
        """Build the gradient penalty term used by the Wasserstein loss.

        Points are sampled on the straight line between each real input and
        the matching generated input (``_input_fake_from_real``) using one
        random coefficient per sample. The penalty is the batch mean of
        ``(||grad||_2 - 1)^2``, where the norm is taken per sample over all
        non-batch axes. The result is stored in ``self._gradient_penalty``
        and written to a "Gradient_penalty" summary.

        Returns:
            Scalar tensor holding the gradient penalty.
        """
        rank = self._input_real.get_shape().ndims
        # One coefficient per sample, broadcast over the remaining axes, so
        # every interpolate lies on the segment between its real/fake pair.
        alpha_shape = tf.concat(
            [tf.shape(self._input_real)[:1],
             tf.ones([rank - 1], dtype=tf.int32)],
            axis=0)
        alpha = tf.random_uniform(shape=alpha_shape, minval=0., maxval=1.)
        differences = self._input_fake_from_real - self._input_real
        interpolates = self._input_real + (alpha * differences)
        gradients = tf.gradients(self._adversarial.generate_net(interpolates),
                                 [interpolates])[0]
        # Norm per sample: summing over the batch axis as well would yield a
        # single norm for the whole batch and make the mean below a no-op.
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), axis=list(range(1, rank))))
        with tf.name_scope("Loss"):
            self._gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
            tf.summary.scalar("Gradient_penalty", self._gradient_penalty)
        return self._gradient_penalty

    def train(self,
              x_train,
              y_train,
              x_test,
              y_test,
              epochs=100,
              batch_size=64,
              adv_steps=5,
              gen_steps=1,
              log_step=3,
              gpu_options=None,
              batch_log_step=None):
        """Run the training loop.

        Args:
            x_train: Training inputs; its length defines how many examples
                make up one epoch.
            y_train: Training targets.
            x_test: Test inputs.
            y_test: Test targets.
            epochs: Number of epochs to run.
            batch_size: Batch size, stored as ``self.batch_size``.
            adv_steps: Adversarial updates per call to ``_optimize``.
            gen_steps: Generator/encoder updates per call to ``_optimize``.
            log_step: Log every ``log_step`` epochs (only when
                ``batch_log_step`` is None); also forwarded to
                ``_set_up_training``.
            gpu_options: Forwarded to ``_set_up_training``.
            batch_log_step: If given, log every ``batch_log_step`` batches
                instead of per epoch.
        """
        # time.clock() was removed in Python 3.8; prefer perf_counter and
        # only fall back to clock on interpreters that lack it.
        timer = getattr(time, "perf_counter", None) or time.clock

        self._set_up_training(log_step=log_step, gpu_options=gpu_options)
        self._set_up_test_train_sample(x_train=x_train,
                                       y_train=y_train,
                                       x_test=x_test,
                                       y_test=y_test)
        self._z_test = self._generator.sample_noise(n=len(self._x_test))
        self.batch_size = batch_size
        self._prepare_monitoring()
        self._log_results(epoch=0, epoch_time=0)

        for epoch in range(epochs):
            adv_loss_epoch = 0
            gen_loss_epoch = 0
            enc_loss_epoch = 0
            start = timer()
            trained_examples = 0
            batch_nr = 0

            while trained_examples < len(x_train):
                batch_train_start = timer()
                adv_loss_batch, gen_loss_batch, enc_loss_batch = self._optimize(
                    self._trainset, adv_steps, gen_steps)
                trained_examples += self.batch_size
                adv_loss_epoch += adv_loss_batch
                gen_loss_epoch += gen_loss_batch
                enc_loss_epoch += enc_loss_batch
                self._total_train_time += (timer() - batch_train_start)

                if (batch_log_step is not None) and (batch_nr % batch_log_step
                                                     == 0):
                    self._count_batches += batch_log_step
                    # Minutes elapsed since the start of the current epoch.
                    batch_train_time = (timer() - start) / 60
                    self._log(self._count_batches, batch_train_time)
                batch_nr += 1

            epoch_train_time = (timer() - start) / 60
            adv_loss_epoch = np.round(adv_loss_epoch, 2)
            gen_loss_epoch = np.round(gen_loss_epoch, 2)
            enc_loss_epoch = np.round(enc_loss_epoch, 2)

            print("\nEpoch {}: D: {}; G: {}; E: {}.".format(
                epoch, adv_loss_epoch, gen_loss_epoch, enc_loss_epoch))

            if batch_log_step is None and (log_step
                                           is not None) and (epoch % log_step
                                                             == 0):
                self._log(epoch + 1, epoch_train_time)

    def _optimize(self, dataset, adv_steps, gen_steps):
        for i in range(adv_steps):
            current_batch_x, current_batch_y = dataset.get_next_batch(
                self.batch_size)
            Z_noise = self._generator.sample_noise(n=len(current_batch_x))
            _, adv_loss_batch = self._sess.run(
                [self._adv_optimizer_op, self._adv_loss],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Y_input: current_batch_y,
                    self._Z_input: Z_noise,
                    self._is_training: True,
                    self._output_label_real:
                    self.get_random_label(is_real=True),
                    self._output_label_fake:
                    self.get_random_label(is_real=False)
                })

        for i in range(gen_steps):
            Z_noise = self._generator.sample_noise(n=len(current_batch_x))
            _, gen_loss_batch = self._sess.run(
                [self._gen_oprimizer_op, self._gen_loss],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Y_input: current_batch_y,
                    self._Z_input: Z_noise,
                    self._is_training: True
                })
            _, enc_loss_batch = self._sess.run(
                [self._enc_optimizer_op, self._enc_loss],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Y_input: current_batch_y,
                    self._Z_input: Z_noise,
                    self._is_training: True
                })

        return adv_loss_batch, gen_loss_batch, enc_loss_batch

    def _log_results(self, epoch, epoch_time):
        """Write summaries, sample plots and a model checkpoint for `epoch`.

        The merged summaries are evaluated twice: on the test set (written
        through ``self._writer1``) and on the first ``len(self._x_test)``
        training examples (written through ``self._writer2``). Both runs
        feed ``self._z_test`` as noise and run with the training flag off.

        Args:
            epoch: Step index used for the summaries, the plot file name and
                the saved model.
            epoch_time: Value fed into the ``_epoch_time`` placeholder.
        """

        def run_summaries(x_data, y_data):
            # Fresh random labels are drawn for every evaluation.
            feed = {
                self._X_input: x_data,
                self._Y_input: y_data,
                self._Z_input: self._z_test,
                self._epoch_time: epoch_time,
                self._is_training: False,
                self._epoch_nr: epoch,
            }
            feed[self._output_label_real] = self.get_random_label(
                is_real=True, size=self._nr_test)
            feed[self._output_label_fake] = self.get_random_label(
                is_real=False, size=self._nr_test)
            return self._sess.run(self._merged_summaries, feed_dict=feed)

        test_summary = run_summaries(self._x_test, self._y_test)
        self._writer1.add_summary(test_summary, epoch)

        nr_test = len(self._x_test)
        train_x = self._trainset.get_xdata()[:nr_test]
        train_y = self._trainset.get_ydata()[:nr_test]
        train_summary = run_summaries(train_x, train_y)
        self._writer2.add_summary(train_summary, epoch)

        if self._image_shape is not None:
            sample_path = "{}/GeneratedSamples/result_{}.png".format(
                self._folder, epoch)
            self.plot_samples(inpt_x=self._x_test[:10],
                              inpt_y=self._y_test[:10],
                              sess=self._sess,
                              image_shape=self._image_shape,
                              epoch=epoch,
                              path=sample_path)

        self.save_model(epoch)

        # evaluate is optional: it is only invoked when the attribute exists
        # and is callable.
        extra_logger = getattr(self, "evaluate", None)
        if callable(extra_logger):
            extra_logger(true=self._x_test,
                         condition=self._y_test,
                         epoch=epoch)
        print("Logged.")

    def plot_samples(self, inpt_x, inpt_y, sess, image_shape, epoch, path):
        outpt_xy = sess.run(self._output_gen_from_encoding,
                            feed_dict={
                                self._X_input: inpt_x,
                                self._Z_input: self._z_test[:len(inpt_x)],
                                self._is_training: False
                            })

        image_matrix = np.array([[
            x.reshape(self._x_dim[0], self._x_dim[1]),
            y.reshape(self._y_dim[0], self._y_dim[1]),
            np.zeros(shape=(self._x_dim[0], self._x_dim[1])),
            xy.reshape(self._y_dim[0], self._y_dim[1])
        ] for x, y, xy in zip(inpt_x, inpt_y, outpt_xy)])
        self._generator.build_generated_samples(
            image_matrix,
            column_titles=["True X", "True Y", "", "Gen_XY"],
            epoch=epoch,
            path=path)

    def _prepare_monitoring(self):
        self._total_train_time = 0
        self._total_log_time = 0
        self._count_batches = 0
        self._batches = []

        self._max_allowed_failed_checks = 20
        self._enc_grads_and_vars = self._enc_optimizer.compute_gradients(
            self._enc_loss, var_list=self._get_vars("Encoder"))
        self._gen_grads_and_vars = self._gen_optimizer.compute_gradients(
            self._gen_loss, var_list=self._get_vars("Generator"))
        self._adv_grads_and_vars = self._adv_optimizer.compute_gradients(
            self._adv_loss, var_list=self._get_vars("Adversarial"))

        self._monitor_dict = {
            "Gradients": [[
                self._enc_grads_and_vars, self._gen_grads_and_vars,
                self._adv_grads_and_vars
            ], ["Encoder", "Generator", "Adversarial"],
                          [[] for i in range(9)]],
            "Losses":
            [[
                self._enc_loss, self._gen_loss, self._adversarial_loss,
                self._generator_loss, self._recon_loss, self._KLdiv
            ],
             [
                 "Encoder (V+R+K)", "Generator (V+R)", "Adversarial",
                 "Vanilla_Generator", "Reconstruction", "Kullback-Leibler"
             ], [[] for i in range(6)]],
            "Output Adversarial": [[
                self._output_adv_fake_from_real,
                self._output_adv_fake_from_latent, self._output_adv_real
            ], ["Fake_from_real", "Fake_from_latent", "Real"],
                                   [[] for i in range(3)], [np.mean]]
        }

        self._check_dict = {
            "Dominating Discriminator": {
                "Tensors":
                [self._output_adv_real, self._output_adv_fake_from_real],
                "OPonTensors": [np.mean, np.mean],
                "Relation": [">", "<"],
                "Threshold": [
                    self._label_smoothing * 0.95,
                    (1 - self._label_smoothing) * 1.05
                ],
                "TensorRelation":
                np.logical_and
            },
            "Generator outputs zeros": {
                "Tensors": [
                    self._output_gen_from_encoding,
                    self._output_gen_from_encoding
                ],
                "OPonTensors": [np.max, np.min],
                "Relation": ["<", ">"],
                "Threshold": [0.05, 0.95],
                "TensorRelation":
                np.logical_or
            }
        }
        self._check_count = [0 for key in self._check_dict]

        if not os.path.exists(self._folder + "/Evaluation"):
            pos.mkdir(self._folder + "/Evaluation")
        os.mkdir(self._folder + "/Evaluation/Cells")
        os.mkdir(self._folder + "/Evaluation/CenterOfMassX")
        os.mkdir(self._folder + "/Evaluation/CenterOfMassY")
        os.mkdir(self._folder + "/Evaluation/Energy")
        os.mkdir(self._folder + "/Evaluation/MaxEnergy")
        os.mkdir(self._folder + "/Evaluation/StdEnergy")

    def evaluate(self, true, condition, epoch):
        print("Batch ", epoch)
        log_start = time.clock()
        self._batches.append(epoch)

        fake = self._sess.run(self._output_gen_from_encoding,
                              feed_dict={
                                  self._X_input: self._x_test,
                                  self._Z_input: self._z_test,
                                  self._is_training: False
                              })
        true = self._y_test.reshape(
            [-1, self._image_shape[0], self._image_shape[1]])
        fake = fake.reshape([-1, self._image_shape[0], self._image_shape[1]])
        build_histogram(true=true,
                        fake=fake,
                        function=get_energies,
                        name="Energy",
                        epoch=epoch,
                        folder=self._folder)
        build_histogram(true=true,
                        fake=fake,
                        function=get_number_of_activated_cells,
                        name="Cells",
                        epoch=epoch,
                        folder=self._folder,
                        threshold=6 / 6120)
        build_histogram(true=true,
                        fake=fake,
                        function=get_max_energy,
                        name="MaxEnergy",
                        epoch=epoch,
                        folder=self._folder)
        build_histogram(true=true,
                        fake=fake,
                        function=get_center_of_mass_x,
                        name="CenterOfMassX",
                        epoch=epoch,
                        folder=self._folder,
                        image_shape=self._image_shape)
        build_histogram(true=true,
                        fake=fake,
                        function=get_center_of_mass_y,
                        name="CenterOfMassY",
                        epoch=epoch,
                        folder=self._folder,
                        image_shape=self._image_shape)
        build_histogram(true=true,
                        fake=fake,
                        function=get_std_energy,
                        name="StdEnergy",
                        epoch=epoch,
                        folder=self._folder)

        fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(18, 10))
        axs = np.ravel(axs)
        if "Gradients" in self._monitor_dict:
            colors = ["green", "blue", "red"]
            axy_min = np.inf
            axy_max = -np.inf
            for go, gradient_ops in enumerate(
                    self._monitor_dict["Gradients"][0]):
                grads = [
                    self._sess.run(
                        gv[0],
                        feed_dict={
                            self._X_input:
                            self._x_test,
                            self._Y_input:
                            self._y_test,
                            self._Z_input:
                            self._z_test,
                            self._is_training:
                            False,
                            self._output_label_real:
                            self.get_random_label(is_real=True,
                                                  size=self._nr_test),
                            self._output_label_fake:
                            self.get_random_label(is_real=False,
                                                  size=self._nr_test)
                        }) for gv in gradient_ops
                ]

                for op_idx, op in enumerate([np.mean, np.max, np.min]):
                    self._monitor_dict["Gradients"][2][go * 3 + op_idx].append(
                        op([op(grad) for grad in grads]))
                    vals = self._monitor_dict["Gradients"][2][go * 3 + op_idx]
                    if op_idx == 0:
                        axs[0].plot(
                            self._batches,
                            vals,
                            label=self._monitor_dict["Gradients"][1][go],
                            color=colors[go])
                    else:
                        axs[0].plot(self._batches,
                                    vals,
                                    linewidth=0.5,
                                    linestyle="--",
                                    color=colors[go])
                        upper = np.mean(vals)
                        lower = np.mean(vals)
                        if upper > axy_max:
                            axy_max = upper
                        if lower < axy_min:
                            axy_min = lower
        axs[0].set_title("Gradients")
        axs[0].legend()
        axs[0].set_ylim([axy_min, axy_max])

        current_batch_x, current_batch_y = self._trainset.get_next_batch(
            self.batch_size)
        Z_noise = self._generator.sample_noise(n=len(current_batch_x))

        colors = [
            "green", "blue", "red", "orange", "purple", "brown", "gray",
            "pink", "cyan", "olive"
        ]
        for k, key in enumerate(self._monitor_dict):
            if key == "Gradients":
                continue
            key_results = self._sess.run(
                self._monitor_dict[key][0],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Y_input: current_batch_y,
                    self._Z_input: Z_noise,
                    self._is_training: True,
                    self._output_label_real:
                    self.get_random_label(is_real=True),
                    self._output_label_fake:
                    self.get_random_label(is_real=False)
                })
            for kr, key_result in enumerate(key_results):
                try:
                    self._monitor_dict[key][2][kr].append(
                        self._monitor_dict[key][3][0](key_result))
                except IndexError:
                    self._monitor_dict[key][2][kr].append(key_result)
                axs[k].plot(self._batches,
                            self._monitor_dict[key][2][kr],
                            label=self._monitor_dict[key][1][kr],
                            color=colors[kr])
            axs[k].legend()
            axs[k].set_title(key)
            print("; ".join([
                "{}: {}".format(name, round(float(val[-1]), 5)) for name, val
                in zip(self._monitor_dict[key][1], self._monitor_dict[key][2])
            ]))

        gen_samples = self._sess.run(
            [self._output_gen_from_encoding],
            feed_dict={
                self._X_input: current_batch_x,
                self._Z_input: Z_noise,
                self._is_training: False
            })
        axs[-1].hist([np.ravel(gen_samples),
                      np.ravel(current_batch_y)],
                     label=["Generated", "True"])
        axs[-1].set_title("Pixel distribution")
        axs[-1].legend()

        for check_idx, check_key in enumerate(self._check_dict):
            result_bools_of_check = []
            check = self._check_dict[check_key]
            for tensor_idx in range(len(check["Tensors"])):
                tensor_ = self._sess.run(check["Tensors"][tensor_idx],
                                         feed_dict={
                                             self._X_input: self._x_test,
                                             self._Y_input: self._y_test,
                                             self._Z_input: self._z_test,
                                             self._is_training: False
                                         })
                tensor_op = check["OPonTensors"][tensor_idx](tensor_)
                if eval(
                        str(tensor_op) + check["Relation"][tensor_idx] +
                        str(check["Threshold"][tensor_idx])):
                    result_bools_of_check.append(True)
                else:
                    result_bools_of_check.append(False)
            if (tensor_idx > 0 and check["TensorRelation"](
                    *result_bools_of_check)) or (result_bools_of_check[0]):
                self._check_count[check_idx] += 1
                if self._check_count[
                        check_idx] == self._max_allowed_failed_checks:
                    raise GeneratorExit(check_key)
            else:
                self._check_count[check_idx] = 0

        self._total_log_time += (time.clock() - log_start)
        fig.suptitle("Train {} / Log {} / Fails {}".format(
            np.round(self._total_train_time, 2),
            np.round(self._total_log_time, 2), self._check_count))

        plt.savefig(self._folder + "/TrainStatistics.png")
        plt.close("all")

    def get_random_label(self, is_real, size=None):
        """Build a label array for the adversarial output, with optional noise.

        The array has ``size`` rows (``self.batch_size`` when ``size`` is
        None) and the trailing dimensions of ``self._output_adv_real``.

        Args:
            is_real: If true the base label is 1, otherwise it is 0.
            size: Number of rows; defaults to ``self.batch_size``.

        Returns:
            A NumPy array of ones in which each entry was independently reset
            to 0 with probability ``self._random_labeling`` (only when that
            probability is positive); the whole array is complemented
            (``1 - labels``) when ``is_real`` is false.
        """
        n_rows = self.batch_size if size is None else size
        target_shape = [n_rows] + self._output_adv_real.shape.as_list()[1:]

        labels = np.ones(shape=target_shape)
        if self._random_labeling > 0:
            # Bernoulli draw per entry: True marks a label to be flipped.
            flip = np.random.binomial(
                n=1, p=self._random_labeling, size=target_shape) == 1
            labels[flip] = 0

        return labels if is_real else 1 - labels
Exemple #24
0
    def __init__(
        self,
        x_dim,
        y_dim,
        z_dim,
        gen_architecture,
        adversarial_architecture,
        folder="./CGAN",
        append_y_at_every_layer=None,
        is_patchgan=False,
        is_wasserstein=False,
        aux_architecture=None,
    ):
        """Build the generator, the adversarial network and, optionally, an
        auxiliary encoder, and connect them to the input placeholders.

        Args:
            x_dim: Forwarded to the base class; also compared against the
                generator output shape (without its first dimension).
            y_dim: Forwarded to the base class.
            z_dim: Forwarded to the base class.
            gen_architecture: Layer list of the generator. The keyword dict
                of its last layer gets the name "Output".
            adversarial_architecture: Layer list of the adversarial network.
                Unless ``is_patchgan`` is set, a flatten layer and a one-unit
                dense "Output" layer are appended to it.
            folder: Forwarded to the base class.
            append_y_at_every_layer: Forwarded to the base class.
            is_patchgan: If true, no output layers are appended; instead the
                last given adversarial layer must have 1 filter and the
                activation matching ``is_wasserstein``.
            is_wasserstein: Selects ``tf.identity`` (true) or
                ``tf.nn.sigmoid`` (false) as the adversarial output
                activation.
            aux_architecture: If given, an auxiliary ``Encoder`` is built on
                top of the generator output.

        Raises:
            AssertionError: If the PatchGAN constraints are violated, if the
                generator output shape differs from ``x_dim``, or if the
                auxiliary output shape differs from the modified Z input.
        """
        self._is_cycle_consistent = aux_architecture is not None
        architectures = [gen_architecture, adversarial_architecture]
        if self._is_cycle_consistent:
            architectures.append(aux_architecture)
        super(CGAN, self).__init__(
            x_dim=x_dim,
            y_dim=y_dim,
            z_dim=z_dim,
            architectures=architectures,
            folder=folder,
            append_y_at_every_layer=append_y_at_every_layer)

        self._is_patchgan = is_patchgan
        self._is_wasserstein = is_wasserstein
        self._is_feature_matching = False
        self._gen_architecture = self._architectures[0]
        self._adversarial_architecture = self._architectures[1]

        ################# Define architecture
        if self._is_patchgan:
            # PatchGAN: the caller supplies the output layer, only check it.
            final_layer_kwargs = self._adversarial_architecture[-1][-1]

            f_xy = final_layer_kwargs["filters"]
            assert f_xy == 1, "If is PatchGAN, last layer of adversarial_XY needs 1 filter. Given: {}.".format(
                f_xy)

            a_xy = final_layer_kwargs["activation"]
            if self._is_wasserstein:
                assert a_xy == tf.identity, "If is PatchGAN, last layer of adversarial needs tf.identity. Given: {}.".format(
                    a_xy)
            else:
                assert a_xy == tf.nn.sigmoid, "If is PatchGAN, last layer of adversarial needs tf.nn.sigmoid. Given: {}.".format(
                    a_xy)
        else:
            # Otherwise finish the network with flatten + single-unit dense.
            output_activation = (tf.identity
                                 if self._is_wasserstein else tf.nn.sigmoid)
            self._adversarial_architecture.append(
                [tf.layers.flatten, {"name": "Flatten"}])
            self._adversarial_architecture.append([
                logged_dense, {
                    "units": 1,
                    "activation": output_activation,
                    "name": "Output"
                }
            ])
        self._gen_architecture[-1][1]["name"] = "Output"

        self._generator = ConditionalGenerator(self._gen_architecture,
                                               name="Generator")
        self._adversarial = Critic(self._adversarial_architecture,
                                   name="Adversarial")
        self._nets = [self._generator, self._adversarial]

        ################# Connect inputs and networks
        self._output_gen = self._generator.generate_net(
            self._mod_Z_input,
            append_elements_at_every_layer=self._append_at_every_layer,
            tf_trainflag=self._is_training)

        with tf.name_scope("InputsAdversarial"):
            x_is_flat = len(self._x_dim) == 1
            if x_is_flat:
                # One-dimensional x: plain concatenation along axis 1.
                self._input_real = tf.concat(
                    axis=1, values=[self._X_input, self._Y_input], name="real")
                self._input_fake = tf.concat(
                    axis=1,
                    values=[self._output_gen, self._Y_input],
                    name="fake")
            else:
                self._input_real = image_condition_concat(
                    inputs=self._X_input, condition=self._Y_input, name="real")
                self._input_fake = image_condition_concat(
                    inputs=self._output_gen,
                    condition=self._Y_input,
                    name="fake")

        self._output_adversarial_real = self._adversarial.generate_net(
            self._input_real, tf_trainflag=self._is_training)
        self._output_adversarial_fake = self._adversarial.generate_net(
            self._input_fake, tf_trainflag=self._is_training)

        assert self._output_gen.get_shape()[1:] == x_dim, (
            "Output of generator is {}, but x_dim is {}.".format(
                self._output_gen.get_shape(), x_dim))

        ################# Auxiliary network for cycle consistency
        if self._is_cycle_consistent:
            self._auxiliary = Encoder(self._architectures[2], name="Auxiliary")
            self._output_auxiliary = self._auxiliary.generate_net(
                self._output_gen, tf_trainflag=self._is_training)
            aux_shape = self._output_auxiliary.get_shape().as_list()
            mod_z_shape = self._mod_Z_input.get_shape().as_list()
            assert aux_shape == mod_z_shape, (
                "Wrong shape for auxiliary vs. mod Z: {} vs {}.".format(
                    self._output_auxiliary.get_shape(),
                    self._mod_Z_input.get_shape()))
            self._nets.append(self._auxiliary)

        ################# Finalize
        self._init_folders()
        self._verify_init()

        if self._is_patchgan:
            print("PATCHGAN chosen with output: {}.".format(
                self._output_adversarial_real.shape))
Exemple #25
0
class LSGANs_Trainer(nn.Module):
    """Trainer bundling one encoder/decoder pair, two discriminators
    (domains ``a`` and ``b``) and two style interpolators.

    Every network owns an Adam optimizer (``<prefix>_opt``) and a
    learning-rate scheduler (``<prefix>_scheduler``) where ``<prefix>`` is one
    of ``enc``, ``dec``, ``dis_a``, ``dis_b``, ``interp_ab``, ``interp_ba``.
    ``dis_update`` / ``gen_update`` perform one optimisation step each;
    ``save_*`` / ``resume*`` persist and restore network weights, optimizer
    states, ``best_iter`` and ``total_loss``.
    """

    # (optimizer/scheduler prefix, network attribute, checkpoint file prefix),
    # listed in the order in which the checkpoint files are written.
    _PARTS = (
        ('enc', 'encoder', 'encoder'),
        ('dec', 'decoder', 'decoder'),
        ('interp_ab', 'interp_net_ab', 'interp_ab'),
        ('interp_ba', 'interp_net_ba', 'interp_ba'),
        ('dis_a', 'dis_a', 'dis_a'),
        ('dis_b', 'dis_b', 'dis_b'),
    )

    def __init__(self, hyperparameters):
        """Create networks, optimizers and schedulers from ``hyperparameters``.

        Keys read here: ``input_dim_a``, ``gen`` (including
        ``gen['style_dim']``), ``lr``, ``beta1``, ``beta2``, ``weight_decay``,
        ``init`` and, optionally, ``vgg_w`` / ``vgg_model_path``.
        """
        super(LSGANs_Trainer, self).__init__()
        # Initiate the networks
        self.encoder = Encoder(hyperparameters['input_dim_a'],
                               hyperparameters['gen'])
        self.decoder = Decoder(hyperparameters['input_dim_a'],
                               hyperparameters['gen'])
        self.dis_a = Discriminator()
        self.dis_b = Discriminator()
        self.interp_net_ab = Interpolator()
        self.interp_net_ba = Interpolator()
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Setup the optimizers (one per network, identical settings)
        self.enc_opt = self._build_optimizer(self.encoder, hyperparameters)
        self.dec_opt = self._build_optimizer(self.decoder, hyperparameters)
        self.dis_a_opt = self._build_optimizer(self.dis_a, hyperparameters)
        self.dis_b_opt = self._build_optimizer(self.dis_b, hyperparameters)
        self.interp_ab_opt = self._build_optimizer(self.interp_net_ab,
                                                   hyperparameters)
        self.interp_ba_opt = self._build_optimizer(self.interp_net_ba,
                                                   hyperparameters)
        self._build_schedulers(hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed.  This (and the perceptual loss below) is
        # attached after ``self.apply`` so the initializer does not touch it.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        self.total_loss = 0
        self.best_iter = 0
        self.perceptural_loss = Perceptural_loss()

    @staticmethod
    def _build_optimizer(module, hyperparameters):
        """Return an Adam optimizer over the trainable parameters of ``module``."""
        trainable = [p for p in module.parameters() if p.requires_grad]
        return torch.optim.Adam(
            trainable,
            lr=hyperparameters['lr'],
            betas=(hyperparameters['beta1'], hyperparameters['beta2']),
            weight_decay=hyperparameters['weight_decay'])

    def _build_schedulers(self, hyperparameters, iterations=None):
        """(Re)create ``<prefix>_scheduler`` for every optimizer.

        ``iterations`` is forwarded to ``get_scheduler`` as third argument
        only when it is given (i.e. when resuming).
        """
        for prefix, _, _ in self._PARTS:
            optimizer = getattr(self, prefix + '_opt')
            if iterations is None:
                scheduler = get_scheduler(optimizer, hyperparameters)
            else:
                scheduler = get_scheduler(optimizer, hyperparameters,
                                          iterations)
            setattr(self, prefix + '_scheduler', scheduler)

    def recon_criterion(self, input, target):
        """Return the mean absolute difference between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate ``x_a`` to domain b and ``x_b`` to domain a.

        The module is switched to eval mode for the pass and back to train
        mode afterwards.  The interpolation input is ``self.v`` when it has
        been set (by the caller or by a previous update/sample call);
        otherwise a tensor of ones shaped like the style code is used.

        Returns:
            Tuple ``(x_ab, x_ba)``.
        """
        self.eval()

        c_a, s_a_fake = self.encoder(x_a)
        c_b, s_b_fake = self.encoder(x_b)

        v = getattr(self, 'v', None)
        if v is None:
            v = torch.ones(s_a_fake.size())

        # decode (cross domain)
        s_ab_interp = self.interp_net_ab(s_a_fake, s_b_fake, v)
        s_ba_interp = self.interp_net_ba(s_b_fake, s_a_fake, v)
        x_ba = self.decoder(c_b, s_ba_interp)
        x_ab = self.decoder(c_a, s_ab_interp)
        self.train()
        return x_ab, x_ba

    def zero_grad(self):
        """Clear the gradients held by all six optimizers."""
        for prefix, _, _ in self._PARTS:
            getattr(self, prefix + '_opt').zero_grad()

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one discriminator step on the batches ``x_a`` and ``x_b``.

        The loss is the mean discriminator output on translated images minus
        the mean output on real images, plus each discriminator's gradient
        penalty, all weighted by ``hyperparameters['gan_w']``.  Only the two
        discriminator optimizers are stepped.
        """
        self.zero_grad()

        # encode
        c_a, s_a = self.encoder(x_a)
        c_b, s_b = self.encoder(x_b)

        # decode (cross domain)
        self.v = torch.ones(s_a.size())
        s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
        s_b_interp = self.interp_net_ab(s_a, s_b, self.v)
        x_ba = self.decoder(c_b, s_a_interp)
        x_ab = self.decoder(c_a, s_b_interp)

        x_a_feature = self.dis_a(x_a)
        x_ba_feature = self.dis_a(x_ba)
        x_b_feature = self.dis_b(x_b)
        x_ab_feature = self.dis_b(x_ab)
        self.loss_dis_a = (x_ba_feature - x_a_feature).mean()
        self.loss_dis_b = (x_ab_feature - x_b_feature).mean()

        # gradient penalty
        self.loss_dis_a_gp = self.dis_a.calculate_gradient_penalty(x_ba, x_a)
        self.loss_dis_b_gp = self.dis_b.calculate_gradient_penalty(x_ab, x_b)

        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
                              hyperparameters['gan_w'] * self.loss_dis_b + \
                              hyperparameters['gan_w'] * self.loss_dis_a_gp + \
                              hyperparameters['gan_w'] * self.loss_dis_b_gp

        self.loss_dis_total.backward()
        self.total_loss += self.loss_dis_total.item()
        self.dis_a_opt.step()
        self.dis_b_opt.step()

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one step for encoder, decoder and both interpolators.

        The total loss combines adversarial, image/style/content
        reconstruction, cycle reconstruction and perceptual terms, weighted
        by ``gan_w``, ``recon_x_w``, ``recon_s_w``, ``recon_c_w``,
        ``recon_x_cyc_w`` and ``vgg_w`` (treated as 0 when absent).  The
        cycle images are only decoded when ``recon_x_cyc_w > 0``; perceptual
        terms on them are skipped otherwise.
        """
        self.zero_grad()

        use_cycle = hyperparameters['recon_x_cyc_w'] > 0
        # 'vgg_w' is optional in __init__, so it is optional here as well.
        vgg_w = (hyperparameters['vgg_w']
                 if 'vgg_w' in hyperparameters.keys() else 0)
        use_vgg = vgg_w > 0

        # encode
        c_a, s_a = self.encoder(x_a)
        c_b, s_b = self.encoder(x_b)

        # decode (within domain)
        x_a_recon = self.decoder(c_a, s_a)
        x_b_recon = self.decoder(c_b, s_b)

        # decode (cross domain)
        self.v = torch.ones(s_a.size())
        s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
        s_b_interp = self.interp_net_ab(s_a, s_b, self.v)
        x_ba = self.decoder(c_b, s_a_interp)
        x_ab = self.decoder(c_a, s_b_interp)

        # encode again
        c_b_recon, s_a_recon = self.encoder(x_ba)
        c_a_recon, s_b_recon = self.encoder(x_ab)

        # decode again
        x_aa = self.decoder(c_a_recon, s_a) if use_cycle else None
        x_bb = self.decoder(c_b_recon, s_b) if use_cycle else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aa, x_a) if use_cycle else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bb, x_b) if use_cycle else 0

        # perceptual loss
        self.loss_gen_vgg_a = self.perceptural_loss(
            x_a_recon, x_a) if use_vgg else 0
        self.loss_gen_vgg_b = self.perceptural_loss(
            x_b_recon, x_b) if use_vgg else 0

        # x_aa / x_bb are None without the cycle term, so guard on both flags.
        self.loss_gen_vgg_aa = self.perceptural_loss(
            x_aa, x_a) if (use_vgg and use_cycle) else 0
        self.loss_gen_vgg_bb = self.perceptural_loss(
            x_bb, x_b) if (use_vgg and use_cycle) else 0

        # GAN loss
        x_ba_feature = self.dis_a(x_ba)
        x_ab_feature = self.dis_b(x_ab)
        self.loss_gen_adv_a = -x_ba_feature.mean()
        self.loss_gen_adv_b = -x_ab_feature.mean()

        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              vgg_w * self.loss_gen_vgg_aa + \
                              vgg_w * self.loss_gen_vgg_bb + \
                              vgg_w * self.loss_gen_vgg_a + \
                              vgg_w * self.loss_gen_vgg_b

        self.loss_gen_total.backward()
        self.total_loss += self.loss_gen_total.item()
        self.dec_opt.step()
        self.enc_opt.step()
        self.interp_ab_opt.step()
        self.interp_ba_opt.step()

    def sample(self, x_a, x_b):
        """Produce reconstructions and translations for each image pair.

        Images are processed one at a time in eval mode; train mode is
        restored before returning.

        Returns:
            ``(x_a, x_a_recon, x_ab, x_aa, x_b, x_b_recon, x_ba, x_bb)`` where
            every generated entry is the concatenation of the per-image
            decoder outputs.
        """
        self.eval()
        x_a_recon, x_b_recon, x_ab, x_ba, x_aa, x_bb = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a = self.encoder(x_a[i].unsqueeze(0))
            c_b, s_b = self.encoder(x_b[i].unsqueeze(0))
            x_a_recon.append(self.decoder(c_a, s_a))
            x_b_recon.append(self.decoder(c_b, s_b))

            self.v = torch.ones(s_a.size())
            s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
            s_b_interp = self.interp_net_ab(s_a, s_b, self.v)

            x_ab_i = self.decoder(c_a, s_b_interp)
            x_ba_i = self.decoder(c_b, s_a_interp)

            c_a_recon, _ = self.encoder(x_ab_i)
            c_b_recon, _ = self.encoder(x_ba_i)

            # The style codes already carry the batch dimension (they are
            # passed to the decoder unchanged above), so no extra unsqueeze.
            x_ab.append(x_ab_i)
            x_ba.append(x_ba_i)
            x_aa.append(self.decoder(c_a_recon, s_a))
            x_bb.append(self.decoder(c_b_recon, s_b))

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ab, x_aa = torch.cat(x_ab), torch.cat(x_aa)
        x_ba, x_bb = torch.cat(x_ba), torch.cat(x_bb)

        self.train()

        return x_a, x_a_recon, x_ab, x_aa, x_b, x_b_recon, x_ba, x_bb

    def update_learning_rate(self):
        """Advance every scheduler that is not ``None`` by one step."""
        for prefix, _, _ in self._PARTS:
            scheduler = getattr(self, prefix + '_scheduler')
            if scheduler is not None:
                scheduler.step()

    def _restore_training_state(self, optimizer_path, hyperparameters):
        """Load optimizer states and counters, then rebuild the schedulers.

        Returns:
            Tuple ``(best_iter, total_loss)`` read from the checkpoint.
        """
        state_dict = torch.load(optimizer_path)
        for prefix, _, _ in self._PARTS:
            opt_name = prefix + '_opt'
            getattr(self, opt_name).load_state_dict(state_dict[opt_name])

        self.best_iter = state_dict['best_iter']
        self.total_loss = state_dict['total_loss']

        # Reinitialize schedulers
        self._build_schedulers(hyperparameters, self.best_iter)
        print('Resume from iteration %d' % self.best_iter)
        return self.best_iter, self.total_loss

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks and training state from ``checkpoint_dir``.

        Network files are located with ``get_model``; the optimizer states
        come from ``optimizer.pt``.

        Returns:
            Tuple ``(best_iter, total_loss)``.
        """
        for _, net_attr, file_prefix in self._PARTS:
            model_name = get_model(checkpoint_dir, file_prefix)
            getattr(self, net_attr).load_state_dict(torch.load(model_name))

        return self._restore_training_state(
            os.path.join(checkpoint_dir, 'optimizer.pt'), hyperparameters)

    def resume_iter(self, checkpoint_dir, surfix, hyperparameters):
        """Restore the checkpoint whose file names end in ``surfix + '.pt'``.

        Interpolator weights are read from ``interp_ab<surfix>.pt`` and
        ``interp_ba<surfix>.pt``.  When one of these files is missing, the
        combined ``interp<surfix>.pt`` file (a dict with keys ``'ab'`` and
        ``'ba'``) is used for it instead.

        Returns:
            Tuple ``(best_iter, total_loss)``.
        """
        def part_path(file_prefix):
            return os.path.join(checkpoint_dir, file_prefix + surfix + '.pt')

        combined_interp = None
        for _, net_attr, file_prefix in self._PARTS:
            path = part_path(file_prefix)
            is_interp = file_prefix in ('interp_ab', 'interp_ba')
            if is_interp and not os.path.isfile(path):
                if combined_interp is None:
                    combined_interp = torch.load(part_path('interp'))
                state_dict = combined_interp[file_prefix.split('_')[1]]
            else:
                state_dict = torch.load(path)
            getattr(self, net_attr).load_state_dict(state_dict)

        return self._restore_training_state(part_path('optimizer'),
                                            hyperparameters)

    def _save_snapshot(self, snapshot_dir, tag, optimizer_filename):
        """Write ``<prefix>_<tag>.pt`` per network plus the optimizer file."""
        for _, net_attr, file_prefix in self._PARTS:
            torch.save(
                getattr(self, net_attr).state_dict(),
                os.path.join(snapshot_dir, '%s_%s.pt' % (file_prefix, tag)))

        training_state = {
            prefix + '_opt': getattr(self, prefix + '_opt').state_dict()
            for prefix, _, _ in self._PARTS
        }
        training_state['best_iter'] = self.best_iter
        training_state['total_loss'] = self.total_loss
        torch.save(training_state,
                   os.path.join(snapshot_dir, optimizer_filename))

    def save_better_model(self, snapshot_dir):
        """Replace the contents of ``snapshot_dir`` with the current model.

        Existing regular files in the directory are deleted first.  Network
        files are tagged with ``total_loss`` (four decimals); the optimizer
        states go to ``optimizer.pt``.
        """
        # remove sub_optimal models
        for f in glob.glob(snapshot_dir + '/*'):
            if os.path.isfile(f):
                os.remove(f)
        self._save_snapshot(snapshot_dir, '%.4f' % (self.total_loss),
                            'optimizer.pt')

    def save_at_iter(self, snapshot_dir, iterations):
        """Save a checkpoint tagged with ``iterations + 1`` (8 digits)."""
        tag = '%08d' % (iterations + 1)
        self._save_snapshot(snapshot_dir, tag, 'optimizer_%s.pt' % tag)
# Remaining command-line options (the parser itself is created earlier).
parser.add_argument("--print_interval", type=int, default=100, help="interval of loss printing")
parser.add_argument("--dataroot", default="", help="path to dataset")
parser.add_argument("--dataset", default="cifar10", help="folder | cifar10 | mnist")
parser.add_argument("--abnormal_class", default="airplane", help="Anomaly class idx for mnist and cifar datasets")
parser.add_argument("--out", default="ckpts", help="checkpoint directory")
parser.add_argument("--device", default="cuda", help="device: cuda | cpu")
parser.add_argument("--G_path", default="ckpts/G_epoch19.pt", help="path to trained state dict of generator")
parser.add_argument("--D_path", default="ckpts/D_epoch19.pt", help="path to trained state dict of discriminator")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

generator = Generator(dim = 64, zdim=opt.latent_dim, nc=opt.channels)
discriminator = Discriminator(dim = 64, zdim=opt.latent_dim, nc=opt.channels,out_feat=True)
encoder = Encoder(dim = 64, zdim=opt.latent_dim, nc=opt.channels)

# Generator and discriminator start from saved weights; the encoder does not.
# map_location remaps the stored tensors onto the requested device, so a
# checkpoint written on a GPU can also be loaded with --device cpu.
generator.load_state_dict(torch.load(opt.G_path, map_location=opt.device))
discriminator.load_state_dict(torch.load(opt.D_path, map_location=opt.device))
generator.to(opt.device)
encoder.to(opt.device)
discriminator.to(opt.device)

encoder.train()
discriminator.train()

dataloader = load_data(opt)

# The generator is put in eval mode; encoder and discriminator stay in train mode.
generator.eval()

# Accept "cuda" as well as indexed forms such as "cuda:0".
Tensor = torch.cuda.FloatTensor if str(opt.device).startswith('cuda') else torch.FloatTensor
Exemple #27
0
    def __init__(self,
                 x_dim,
                 z_dim,
                 enc_architecture,
                 gen_architecture,
                 disc_architecture,
                 folder="./VAEGAN"):
        """Build the encoder, generator and discriminator networks and wire them together.

        Args:
            x_dim: Expected shape of one generated sample; compared against
                the generator output shape (batch axis excluded).
            z_dim: Number of units of the "Mean" and "Std" encoder heads.
            enc_architecture: List of ``[layer_function, kwargs_dict]`` pairs
                describing the encoder body; a dense head is appended to it.
            gen_architecture: Same format, for the generator. The last
                layer's ``"name"`` kwarg is overwritten with ``"Output"``.
            disc_architecture: Same format, for the discriminator. A flatten
                layer and a 1-unit sigmoid dense layer are appended.
            folder: Forwarded unchanged to the parent constructor.

        Raises:
            AssertionError: If the generator output shape does not equal
                ``x_dim``.
        """
        super(VAEGAN, self).__init__(
            x_dim, z_dim,
            [enc_architecture, gen_architecture, disc_architecture], folder)

        # Work on the lists stored by the parent constructor.
        # NOTE(review): if the parent keeps references rather than copies,
        # the in-place edits below also modify the caller's lists - confirm.
        self._enc_architecture = self._architectures[0]
        self._gen_architecture = self._architectures[1]
        self._disc_architecture = self._architectures[2]

        ################# Define architecture
        # Two encoders share the same body and differ only in their final
        # dense head ("Mean" vs "Std"), both with identity activation.
        last_layer_mean = [
            logged_dense, {
                "units": z_dim,
                "activation": tf.identity,
                "name": "Mean"
            }
        ]
        self._encoder_mean = Encoder(self._enc_architecture +
                                     [last_layer_mean],
                                     name="Encoder")
        last_layer_std = [
            logged_dense, {
                "units": z_dim,
                "activation": tf.identity,
                "name": "Std"
            }
        ]
        # NOTE(review): this encoder reuses the name "Encoder" of the mean
        # encoder; presumably intended so both live in one scope - confirm.
        self._encoder_std = Encoder(self._enc_architecture + [last_layer_std],
                                    name="Encoder")

        # Rename the generator's final layer in place.
        self._gen_architecture[-1][1]["name"] = "Output"
        self._generator = Generator(self._gen_architecture, name="Generator")

        # Discriminator tail: flatten, then a single sigmoid unit.
        self._disc_architecture.append(
            [tf.layers.flatten, {
                "name": "Flatten"
            }])
        self._disc_architecture.append([
            logged_dense, {
                "units": 1,
                "activation": tf.nn.sigmoid,
                "name": "Output"
            }
        ])
        self._discriminator = Discriminator(self._disc_architecture,
                                            name="Discriminator")

        # NOTE(review): self._encoder_std is not part of this list - confirm
        # that whatever consumes self._nets does not need it.
        self._nets = [self._encoder_mean, self._generator, self._discriminator]

        ################# Connect inputs and networks
        # self._X_input and self._Z_input are not created in this method;
        # presumably they are set up by the parent constructor.
        self._mean_layer = self._encoder_mean.generate_net(self._X_input)
        self._std_layer = self._encoder_std.generate_net(self._X_input)

        # mean + exp(0.5 * s) * z: exp(0.5 * s) equals sqrt(exp(s)), so the
        # "Std" head output is used as if it were a log-variance.
        self._output_enc_with_noise = self._mean_layer + tf.exp(
            0.5 * self._std_layer) * self._Z_input

        # Despite the attribute names, _output_gen is the generator applied
        # to the encoder output, while _output_gen_from_encoding is the
        # generator applied directly to self._Z_input.
        self._output_gen = self._generator.generate_net(
            self._output_enc_with_noise)
        self._output_gen_from_encoding = self._generator.generate_net(
            self._Z_input)

        assert self._output_gen.get_shape()[1:] == x_dim, (
            "Generator output must have shape of x_dim. Given: {}. Expected: {}."
            .format(self._output_gen.get_shape(), x_dim))

        # The discriminator is applied to three inputs: the real input, the
        # generator output built from the encoder, and the generator output
        # built directly from self._Z_input.
        self._output_disc_real = self._discriminator.generate_net(
            self._X_input)
        self._output_disc_fake_from_real = self._discriminator.generate_net(
            self._output_gen)
        self._output_disc_fake_from_latent = self._discriminator.generate_net(
            self._output_gen_from_encoding)

        ################# Finalize
        self._init_folders()
        self._verify_init()
    summand = px_cond * p_s * p_g / q_s / q_g
    print(summand)

    likelihood = torch.sum(summand) / summand.size(0)
    



FLAGS = parser.parse_args()

if __name__ == '__main__':
    # model definitions
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)

    # The identity map_location keeps every stored tensor on the CPU while
    # deserializing; the networks are moved to the GPU explicitly afterwards.
    encoder.load_state_dict(
        torch.load(os.path.join('checkpoints', FLAGS.encoder_save), map_location=lambda storage, loc: storage))
    decoder.load_state_dict(
        torch.load(os.path.join('checkpoints', FLAGS.decoder_save), map_location=lambda storage, loc: storage))

    encoder.cuda()
    decoder.cuda()

    # exist_ok avoids the check-then-create race of a separate
    # os.path.exists() test and is a no-op when the directory is present.
    os.makedirs('reconstructed_images', exist_ok=True)

    # load data set and create data loader instance
    '''
Exemple #29
0
def training_procedure(FLAGS):
    """Train the style/class encoder and decoder on the paired MNIST dataset.

    Every iteration performs two parameter updates:

    A. Auto-encoder step: both images of a pair are encoded, their class
       latents are swapped before decoding, and the KL-divergence and
       reconstruction losses update the encoder and the decoder.
    B. Reverse-cycle step: the first images of two independently drawn
       batches are decoded with one shared random style vector, and only the
       encoder is updated so that the styles it recovers from the two
       reconstructions agree (L1 loss).

    Losses are printed every 10 iterations and, on every iteration, appended
    to ``FLAGS.log_file`` and written to TensorBoard. Every 5 epochs, and
    after the final epoch, both state dicts are saved under ``checkpoints/``
    and four image grids are passed to ``imshow_grid`` with ``save=True``.

    Args:
        FLAGS: Namespace providing ``style_dim``, ``class_dim``,
            ``load_saved``, ``encoder_save``, ``decoder_save``,
            ``batch_size``, ``num_channels``, ``image_size``, ``cuda``,
            ``initial_learning_rate``, ``beta_1``, ``beta_2``, ``log_file``,
            ``start_epoch``, ``end_epoch``, ``kl_divergence_coef``,
            ``reconstruction_coef`` and ``reverse_cycle_coef``.
    """
    # ---- model definition ----
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    # ---- variable definition ----
    # Input buffers are allocated once and refilled in place with copy_().
    image_shape = (FLAGS.batch_size, FLAGS.num_channels,
                   FLAGS.image_size, FLAGS.image_size)
    X_1 = torch.FloatTensor(*image_shape)
    X_2 = torch.FloatTensor(*image_shape)
    X_3 = torch.FloatTensor(*image_shape)

    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    # ---- optional GPU placement (done before the optimizers are built) ----
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        style_latent_space = style_latent_space.cuda()

    # ---- optimizer and scheduler definition ----
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    # The reverse-cycle step only updates the encoder.
    reverse_cycle_optimizer = optim.Adam(list(encoder.parameters()),
                                         lr=FLAGS.initial_learning_rate,
                                         betas=(FLAGS.beta_1, FLAGS.beta_2))

    # divide the learning rate by a factor of 10 after 80 epochs
    auto_encoder_scheduler = optim.lr_scheduler.StepLR(auto_encoder_optimizer,
                                                       step_size=80,
                                                       gamma=0.1)
    reverse_cycle_scheduler = optim.lr_scheduler.StepLR(
        reverse_cycle_optimizer, step_size=80, gamma=0.1)

    # ---- training ----
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    # exist_ok makes directory creation race-free and idempotent.
    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs('reconstructed_images', exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tKL_divergence_loss\tReverse_cycle_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist',
                                download=True,
                                train=True,
                                transform=transform_config)
    loader = cycle(
        DataLoader(paired_mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    # initialize summary writer
    writer = SummaryWriter()

    # Loop invariants, computed once instead of on every iteration.
    iterations_per_epoch = len(paired_mnist) // FLAGS.batch_size
    kl_normalizer = (FLAGS.batch_size * FLAGS.num_channels *
                     FLAGS.image_size * FLAGS.image_size)

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )

        # update the learning rate scheduler
        # NOTE(review): stepping at the start of the epoch is the pre-1.1
        # PyTorch convention; newer releases expect scheduler.step() after
        # optimizer.step(). Left as-is so the LR schedule is unchanged -
        # confirm against the PyTorch version in use.
        auto_encoder_scheduler.step()
        reverse_cycle_scheduler.step()

        for iteration in range(iterations_per_epoch):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(
                Variable(X_1))
            style_latent_space_1 = reparameterize(training=True,
                                                  mu=style_mu_1,
                                                  logvar=style_logvar_1)

            kl_divergence_loss_1 = _kl_divergence_loss(
                style_mu_1, style_logvar_1, FLAGS.kl_divergence_coef,
                kl_normalizer)
            kl_divergence_loss_1.backward(retain_graph=True)

            style_mu_2, style_logvar_2, class_latent_space_2 = encoder(
                Variable(X_2))
            style_latent_space_2 = reparameterize(training=True,
                                                  mu=style_mu_2,
                                                  logvar=style_logvar_2)

            kl_divergence_loss_2 = _kl_divergence_loss(
                style_mu_2, style_logvar_2, FLAGS.kl_divergence_coef,
                kl_normalizer)
            kl_divergence_loss_2.backward(retain_graph=True)

            # Each image is rebuilt from its own style and the class latent
            # of the other image of the pair.
            reconstructed_X_1 = decoder(style_latent_space_1,
                                        class_latent_space_2)
            reconstructed_X_2 = decoder(style_latent_space_2,
                                        class_latent_space_1)

            reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(
                reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(
                reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

            # Un-weighted totals, used for reporting only.
            reconstruction_error = (
                reconstruction_error_1 +
                reconstruction_error_2) / FLAGS.reconstruction_coef
            kl_divergence_error = (kl_divergence_loss_1 + kl_divergence_loss_2
                                   ) / FLAGS.kl_divergence_coef

            auto_encoder_optimizer.step()

            # B. reverse cycle
            image_batch_1, _, __ = next(loader)
            image_batch_2, _, __ = next(loader)

            reverse_cycle_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_latent_space.normal_(0., 1.)

            _, __, class_latent_space_1 = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))

            # Class latents are detached, so the loss reaches the encoder
            # only through the re-encoding of the reconstructions.
            reconstructed_X_1 = decoder(Variable(style_latent_space),
                                        class_latent_space_1.detach())
            reconstructed_X_2 = decoder(Variable(style_latent_space),
                                        class_latent_space_2.detach())

            style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
            style_latent_space_1 = reparameterize(training=False,
                                                  mu=style_mu_1,
                                                  logvar=style_logvar_1)

            style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
            style_latent_space_2 = reparameterize(training=False,
                                                  mu=style_mu_2,
                                                  logvar=style_logvar_2)

            reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(
                style_latent_space_1, style_latent_space_2)
            reverse_cycle_loss.backward()
            # Un-weighted value, used for reporting only.
            reverse_cycle_error = reverse_cycle_loss / FLAGS.reverse_cycle_coef

            reverse_cycle_optimizer.step()

            # Extract each reported scalar once and reuse it for the console,
            # the log file and TensorBoard.
            reconstruction_value = _scalar_value(reconstruction_error)
            kl_divergence_value = _scalar_value(kl_divergence_error)
            reverse_cycle_value = _scalar_value(reverse_cycle_error)

            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_value))
                print('KL-Divergence loss: ' + str(kl_divergence_value))
                print('Reverse cycle loss: ' + str(reverse_cycle_value))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                    epoch, iteration, reconstruction_value,
                    kl_divergence_value, reverse_cycle_value))

            # write to tensorboard
            global_step = epoch * (iterations_per_epoch + 1) + iteration
            writer.add_scalar('Reconstruction loss', reconstruction_value,
                              global_step)
            writer.add_scalar('KL-Divergence loss', kl_divergence_value,
                              global_step)
            writer.add_scalar('Reverse cycle loss', reverse_cycle_value,
                              global_step)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.decoder_save))

            # save reconstructed images and style swapped image generations
            # to check progress
            image_batch_1, image_batch_2, _ = next(loader)
            image_batch_3, _, __ = next(loader)

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)
            X_3.copy_(image_batch_3)

            style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))
            style_mu_3, style_logvar_3, _ = encoder(Variable(X_3))

            style_latent_space_1 = reparameterize(training=False,
                                                  mu=style_mu_1,
                                                  logvar=style_logvar_1)
            style_latent_space_3 = reparameterize(training=False,
                                                  mu=style_mu_3,
                                                  logvar=style_logvar_3)

            reconstructed_X_1_2 = decoder(style_latent_space_1,
                                          class_latent_space_2)
            reconstructed_X_3_2 = decoder(style_latent_space_3,
                                          class_latent_space_2)

            # save input image batch
            _save_image_grid(X_1.cpu().numpy(), str(epoch) + '_original')

            # save reconstructed batch
            _save_image_grid(reconstructed_X_1_2.cpu().data.numpy(),
                             str(epoch) + '_target')

            # save the batch that provides the swapped-in style
            _save_image_grid(X_3.cpu().numpy(), str(epoch) + '_style')

            # save style swapped reconstructed batch
            _save_image_grid(reconstructed_X_3_2.cpu().data.numpy(),
                             str(epoch) + '_style_target')


def _kl_divergence_loss(mu, logvar, coef, normalizer):
    """Return ``coef * (-0.5 * sum(1 + logvar - mu^2 - exp(logvar)))`` divided by ``normalizer``."""
    kl_sum = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return coef * kl_sum / normalizer


def _scalar_value(tensor):
    """Return the first element of ``tensor`` as a plain Python number."""
    return tensor.data.storage().tolist()[0]


def _save_image_grid(batch, name):
    """Move axis 1 of ``batch`` last, tile it three times and save the grid via ``imshow_grid``."""
    # assumes batch is a 4-D channels-first array - TODO confirm
    images = np.transpose(batch, (0, 2, 3, 1))
    images = np.concatenate((images, images, images), axis=3)
    imshow_grid(images, name=name, save=True)
Exemple #30
0
    def __init__(self, hparams):
        super().__init__()

        # output path
        self.output_path = "outputs/"

        # self.logger.log_hyperparams(hparams)  # log hyperparameters
        self.hparams = hparams

        self.files = ["../data.csv"]
        self.n = len(self.files)

        # load data
        self.data = [torch.from_numpy(np.genfromtxt(file, delimiter=',').transpose()[1:, 1:]).float() for file in self.files]

        self.datasets = [torch.utils.data.TensorDataset(data) for data in self.data]

        self.test_size = 60

        self.train_dataset, self.test_dataset = zip(*(
            torch.utils.data.random_split(dataset, (len(dataset) - self.test_size, self.test_size))
            for dataset in self.datasets))

        input_dim = 3000

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # distance matrix
        print("computing distance matrices")
        self.distance_matrix_train = self.distance_matrix(torch.stack(list(self.UnTensorDataset(self.train_dataset[0]))))
        if not self.hparams.correlation_distance_loss:
            self.distance_total_train = self.distance_matrix_train.sum(1)
            # self.distance_matrix_train_norm = self.distance_matrix_train / self.distance_total_train.unsqueeze(1)
            self.distance_matrix_train_norm = self.distance_matrix_train / self.distance_total_train.sum()
        print("done")

        # define VAEs
        self.E = [Encoder(input_dim, self.hparams.latent_dim, self.hparams.hypersphere).to(device) for _ in range(self.n)]  # hparams available for activation and dropout

        self.G = [Generator(self.hparams.latent_dim, input_dim).to(device) for _ in range(self.n)]

        # share weights
        if self.hparams.share_weights:
            for E in self.E[1:]:
                E.s1 = self.E[0].s1
                E.s2m = self.E[0].s2m
                E.s2v = self.E[0].s2v
            for G in self.G[1:]:
                G.s1 = self.G[0].s1
                G.s2 = self.G[0].s2

        # define discriminators
        if self.hparams.separate_D:
            self.D = [[Discriminator(input_dim).to(device) if i != j else None
                       for j in range(self.n)] for i in range(self.n)]
        else:
            self.D = [[Discriminator(input_dim).to(device) for _ in range(self.n)]] * self.n

        # named modules to make pytorch lightning happy
        self.E0 = self.E[0]
        self.G0 = self.G[0]

        # hyperspherical distribution
        self.p_z = HypersphericalUniform(self.hparams.latent_dim - 1, device=device) \
            if self.hparams.hypersphere else None

        # cache
        self.prev_g_loss = None
        self.current_z = self.forward(torch.stack(list(self.UnTensorDataset(self.train_dataset[0]))).to(device), first=True).z_a.detach()