Example #1
    def loss(self, KL=True):
        probs = self.prior_prob(self.input)
        if KL:
            self.loss = generator_loss(self.output, self.input, self.target_function(self.output), self.target_integral, probs)
        else:
            self.loss = tf.losses.mean_squared_error(self.target_function(self.input), self.output)
        return self.loss
Example #2
    def validate(self, netG, netsD, text_encoder, image_encoder):
        batch_size = self.batch_size
        nz = self.opts.GAN.Z_DIM
        real_labels, fake_labels, match_labels = self.prepare_labels()

        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))

        noise, fixed_noise = noise.to(self.device), fixed_noise.to(self.device)

        val_batches = len(self.val_loader)
        netG.eval()
        for i in range(len(netsD)):
            netsD[i].eval()

        inception_scorer = InceptionScore(val_batches, batch_size, val_batches)
        total_loss = []
        with torch.no_grad():
            for step, data in enumerate(self.val_loader):
                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                imgs, captions, class_ids, input_mask = prepare_data(
                    data, self.device)

                words_embs, sent_emb = self.text_encoder_forward(
                    text_encoder, captions, input_mask)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)
                errG_total, G_logs = generator_loss(netsD, image_encoder,
                                                    fake_imgs, real_labels,
                                                    words_embs, sent_emb,
                                                    match_labels, class_ids,
                                                    self.opts)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                total_loss.append(errG_total.data.item())
                inception_scorer.predict(fake_imgs[-1], step)

        netG.train()
        for i in range(len(netsD)):
            netsD[i].train()

        m, s = inception_scorer.get_ic_score()
        return m, s, sum(total_loss) / val_batches
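
The validate() above relies on KL_loss(mu, logvar), which is not defined in this snippet. Below is a minimal sketch of a compatible implementation, assuming mu and logvar parameterize a diagonal Gaussian and the term is the usual conditioning-augmentation KL divergence against a unit Gaussian; only the call signature comes from the code above, the body is an assumption.

import torch

def KL_loss(mu, logvar):
    # KL( N(mu, sigma^2) || N(0, I) ) for a diagonal Gaussian, averaged over all elements
    kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    return torch.mean(kld)
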
Example #3
    def train(x, y):
        G.optim.zero_grad()
        D.optim.zero_grad()
        # How many chunks to split x and y into?
        x = torch.split(x, config.batch_size)

        y = torch.split(y, config.batch_size)
        counter = 0

        # Optionally toggle D and G's "requires_grad"

        toggle_grad(D, True)
        toggle_grad(G, False)

        for step_index in range(config.num_D_steps):
            z_.sample_()
            y_.sample_()
            D_fake, D_real = GD(z_[:config.batch_size],
                                y_[:config.batch_size],
                                x[counter],
                                y[counter],
                                train_G=False)

            D_loss_real, D_loss_fake = losses.discriminator_loss(
                D_fake, D_real)
            D_loss = (D_loss_real + D_loss_fake)
            D_loss.backward()
            # counter += 1
            D.optim.step()

        # Optionally toggle "requires_grad"
        toggle_grad(D, False)
        toggle_grad(G, True)

        # Zero G's gradients by default before training G, for safety
        G.optim.zero_grad()

        z_.sample_()
        y_.sample_()
        D_fake = GD(z_, y_, train_G=True)
        G_loss = losses.generator_loss(D_fake)

        G_loss.backward()
        G.optim.step()

        out = {
            'G_loss': float(G_loss.item()),
            'D_loss_real': float(D_loss_real.item()),
            'D_loss_fake': float(D_loss_fake.item())
        }
        # Return G's loss and the components of D's loss.
        return out
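
This BigGAN-style train closure (and the similar ones later in this section) calls toggle_grad, losses.discriminator_loss, and losses.generator_loss without showing them. A minimal sketch under the assumption that the standard hinge formulation is used; treat the bodies as illustrative rather than any repository's exact code.

import torch
import torch.nn.functional as F

def toggle_grad(model, on_or_off):
    # Enable or disable gradient tracking for every parameter of a model
    for param in model.parameters():
        param.requires_grad = on_or_off

def discriminator_loss(dis_fake, dis_real):
    # Hinge loss for D, returned as separate real/fake components
    loss_real = torch.mean(F.relu(1.0 - dis_real))
    loss_fake = torch.mean(F.relu(1.0 + dis_fake))
    return loss_real, loss_fake

def generator_loss(dis_fake):
    # Hinge (non-saturating) loss for G
    return -torch.mean(dis_fake)
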
Example #4
    diagonals = np.reshape(np.repeat(np.arange(0.0, 1.0, 0.01), dist_dim),
                           (-1, dist_dim))
    plt.plot(np.arange(0.0, 1.0, 0.01), np.exp(h.nn.predict(diagonals)), '.r',
             np.arange(0.0, 1.0, 0.01), tf.keras.backend.eval(f(diagonals)),
             '.g')
    plt.show()

    h_G_z = hG_out  #sess.run(reg_out, {reg_in: z})
    integral_f = tf.Variable(sess.run(tf_integrate(tf.exp(h.output), 1.0),
                                      {h.input: xs}),
                             trainable=False)
    print(tf.keras.backend.eval(integral_f))
    p_z = tf.reduce_prod(normal.prob(gen_in), (-1), keepdims=True)

    gen_loss = generator_loss(gen_out, gen_in, h_G_z, integral_f, p_z)
    gen_opt = tf.train.AdamOptimizer(0.0005, 0.9)
    #gen_opt = tf.train.GradientDescentOptimizer(0.001)
    grads_and_vars = gen_opt.compute_gradients(gen_loss, var_list=[gen_vars])
    grads = [x[0] for x in grads_and_vars]
    vars = [x[1] for x in grads_and_vars]
    grads, grad_norm = tf.clip_by_global_norm(grads, 2)
    gen_train = gen_opt.apply_gradients(zip(grads, vars))
    sess.run(tf.global_variables_initializer())
    h.nn.load_weights(reg_file)

    for e in range(1, 5):

        batches = zip(np.reshape(zs, (-1, 1 * 512, dist_dim)),
                      np.reshape(probs, (-1, 1 * 512, dist_dim)))
Example #5
    def optimize(self, data, current_step):
        if config.use_apex:
            from apex import amp
        losses_dict = OrderedDict()
        for param in self.discriminator.parameters():
            param.requires_grad = True
        for param in self.generator_params:
            param.requires_grad = True

        pseudo_labels_a, embeddings_a = self.encoder(data['image_a'])
        pseudo_labels_b, embeddings_b = self.encoder(data['image_b'])

        if config.use_mixing:
            num_0d_units = 1 if config.size_0d_unit > 0 else 0
            random = np.random.randint(
                2,
                size=[
                    num_0d_units + config.num_1d_units + config.num_2d_units,
                    config.batch_size, 1, 1
                ]).tolist()
            random_tensor = torch.tensor(random,
                                         dtype=torch.float,
                                         requires_grad=False).to("cuda")

            normalized_embeddings_from_a_mix = self.rotate(embeddings_a,
                                                           pseudo_labels_a,
                                                           random_tensor,
                                                           inverse=True)
            embeddings_a_to_mix = self.rotate(normalized_embeddings_from_a_mix,
                                              pseudo_labels_b, random_tensor)
            image_mix_hat = self.decoder(embeddings_a_to_mix)

            pseudo_labels_mix_hat, embeddings_mix_hat = self.encoder(
                image_mix_hat)

        # a -> b
        normalized_embeddings_from_a = self.rotate(embeddings_a,
                                                   pseudo_labels_a,
                                                   inverse=True)
        embeddings_a_to_b = self.rotate(normalized_embeddings_from_a,
                                        pseudo_labels_b)
        image_b_hat = self.decoder(embeddings_a_to_b)

        # optimize discriminator
        real = self.discriminator(data['image_b'])
        fake = self.discriminator(image_b_hat.detach())

        losses_dict['discriminator'] = losses.discriminator_loss(real=real,
                                                                 fake=fake)
        losses_dict['generator'] = losses.generator_loss(fake=fake)
        discriminator_loss = losses_dict[
            'discriminator'] * config.coeff_discriminator_loss
        # Warm up period for generator losses
        losses_dict['discrim_coeff'] = torch.tensor(
            max(min(1.0, current_step / 20000.0), 0.0))

        self.discriminator_optimizer.zero_grad()
        if config.use_apex:
            with amp.scale_loss(discriminator_loss,
                                self.discriminator_optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            discriminator_loss.backward()
        self.discriminator_optimizer.step()

        for param in self.discriminator.parameters():
            param.requires_grad = False

        # for generator update

        losses_dict['l1'] = losses.reconstruction_l1_loss(x=data['image_b'],
                                                          x_hat=image_b_hat)
        total_loss = losses_dict['l1'] * config.coeff_l1_loss
        if not config.semi_supervised:
            losses_dict['gaze_a'] = (losses.gaze_angular_loss(
                y=data['gaze_a'],
                y_hat=pseudo_labels_a[-1]) + losses.gaze_angular_loss(
                    y=data['gaze_b'], y_hat=pseudo_labels_b[-1])) / 2
            losses_dict['head_a'] = (losses.gaze_angular_loss(
                y=data['head_a'],
                y_hat=pseudo_labels_a[-2]) + losses.gaze_angular_loss(
                    y=data['head_b'], y_hat=pseudo_labels_b[-2])) / 2
        else:
            losses_dict['gaze_a'] = losses.gaze_angular_loss(
                y=data['gaze_a'], y_hat=pseudo_labels_a[-1])
            losses_dict['head_a'] = losses.gaze_angular_loss(
                y=data['head_a'], y_hat=pseudo_labels_a[-2])

            losses_dict['gaze_a_unlabeled'] = losses.gaze_angular_loss(
                y=data['gaze_b'], y_hat=pseudo_labels_b[-1])
            losses_dict['head_a_unlabeled'] = losses.gaze_angular_loss(
                y=data['head_b'], y_hat=pseudo_labels_b[-2])

        total_loss += (losses_dict['gaze_a'] +
                       losses_dict['head_a']) * config.coeff_gaze_loss

        fake = self.discriminator(image_b_hat)
        generator_loss = losses.generator_loss(fake=fake)
        total_loss += generator_loss * config.coeff_discriminator_loss * losses_dict[
            'discrim_coeff']

        if config.coeff_embedding_consistency_loss != 0:
            normalized_embeddings_from_a = self.rotate(embeddings_a,
                                                       pseudo_labels_a,
                                                       inverse=True)
            normalized_embeddings_from_b = self.rotate(embeddings_b,
                                                       pseudo_labels_b,
                                                       inverse=True)
            flattened_normalized_embeddings_from_a = torch.cat([
                e.reshape(e.shape[0], -1) for e in normalized_embeddings_from_a
            ],
                                                               dim=1)
            flattened_normalized_embeddings_from_b = torch.cat([
                e.reshape(e.shape[0], -1) for e in normalized_embeddings_from_b
            ],
                                                               dim=1)
            losses_dict['embedding_consistency'] = (1.0 - torch.mean(
                F.cosine_similarity(flattened_normalized_embeddings_from_a,
                                    flattened_normalized_embeddings_from_b,
                                    dim=-1)))
            total_loss += losses_dict[
                'embedding_consistency'] * config.coeff_embedding_consistency_loss

        if config.coeff_disentangle_embedding_loss != 0:
            assert config.use_mixing is True
            flattened_before_c = torch.cat(
                [e.reshape(e.shape[0], -1) for e in embeddings_a_to_mix],
                dim=1)
            flattened_after_c = torch.cat(
                [e.reshape(e.shape[0], -1) for e in embeddings_mix_hat], dim=1)
            losses_dict['embedding_disentangle'] = (1.0 - torch.mean(
                F.cosine_similarity(
                    flattened_before_c, flattened_after_c, dim=-1)))
            total_loss += losses_dict[
                'embedding_disentangle'] * config.coeff_disentangle_embedding_loss
        if config.coeff_disentangle_pseudo_label_loss != 0:
            assert config.use_mixing is True
            losses_dict['label_disentangle'] = 0
            pseudo_labels_a_b_mix = []
            for i in range(len(pseudo_labels_a)):  # pseudo code
                if pseudo_labels_b[i] is not None:
                    pseudo_labels_a_b_mix.append(
                        pseudo_labels_b[i] * random_tensor[i].squeeze(-1) +
                        pseudo_labels_a[i] *
                        (1 - random_tensor[i].squeeze(-1)))
                else:
                    pseudo_labels_a_b_mix.append(None)

            for y, y_hat in zip(pseudo_labels_a_b_mix[-2:],
                                pseudo_labels_mix_hat[-2:]):
                if y is not None:
                    losses_dict[
                        'label_disentangle'] += losses.gaze_angular_loss(
                            y, y_hat)
            total_loss += losses_dict[
                'label_disentangle'] * config.coeff_disentangle_pseudo_label_loss

        feature_h, gaze_h, head_h = self.GazeHeadNet_train(image_b_hat, True)
        feature_t, gaze_t, head_t = self.GazeHeadNet_train(
            data['image_b'], True)
        losses_dict['redirection_feature_loss'] = 0
        for i in range(len(feature_h)):
            losses_dict['redirection_feature_loss'] += nn.functional.mse_loss(
                feature_h[i], feature_t[i].detach())
        total_loss += losses_dict[
            'redirection_feature_loss'] * config.coeff_redirection_feature_loss
        losses_dict['gaze_redirection'] = losses.gaze_angular_loss(
            y=gaze_t.detach(), y_hat=gaze_h)
        total_loss += losses_dict[
            'gaze_redirection'] * config.coeff_redirection_gaze_loss
        losses_dict['head_redirection'] = losses.gaze_angular_loss(
            y=head_t.detach(), y_hat=head_h)
        total_loss += losses_dict[
            'head_redirection'] * config.coeff_redirection_gaze_loss
        self.generator_optimizer.zero_grad()
        if config.use_apex:
            with amp.scale_loss(total_loss,
                                [self.generator_optimizer]) as scaled_loss:
                scaled_loss.backward()
        else:
            total_loss.backward()
        self.generator_optimizer.step()

        return losses_dict, image_b_hat
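
optimize() uses losses.gaze_angular_loss for the gaze and head terms. A minimal sketch of an angular error between pitch-yaw predictions, assuming y and y_hat are (batch, 2) tensors of pitch and yaw in radians; the helper pitchyaw_to_vector is introduced here purely for illustration and is not part of the original code.

import math
import torch
import torch.nn.functional as F

def pitchyaw_to_vector(pitchyaw):
    # Convert (pitch, yaw) angles in radians into unit 3D direction vectors
    pitch, yaw = pitchyaw[:, 0], pitchyaw[:, 1]
    return torch.stack([
        torch.cos(pitch) * torch.sin(yaw),
        torch.sin(pitch),
        torch.cos(pitch) * torch.cos(yaw),
    ], dim=1)

def gaze_angular_loss(y, y_hat):
    # Mean angular error (in degrees) between predicted and target directions
    v, v_hat = pitchyaw_to_vector(y), pitchyaw_to_vector(y_hat)
    cos_sim = F.cosine_similarity(v, v_hat, dim=1).clamp(-1.0 + 1e-7, 1.0 - 1e-7)
    return torch.mean(torch.acos(cos_sim)) * 180.0 / math.pi
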
Example #6
    def train(self):
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models(
        )
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = self.opts.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))

        noise, fixed_noise = noise.to(self.device), fixed_noise.to(self.device)

        gen_iterations = 0

        lr_schedulers = []
        if self.use_lr_scheduler:
            for i in range(len(optimizersD)):
                lr_scheduler = LambdaLR(optimizersD[i],
                                        lr_lambda=lambda epoch: 0.998**epoch)

                for m in range(start_epoch):
                    lr_scheduler.step()
                lr_schedulers.append(lr_scheduler)

        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.train_loader)
            step = 0

            for i in range(len(lr_schedulers)):
                lr_schedulers[i].step()

            while step < self.num_batches:
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                data = next(data_iter)
                imgs, captions, class_ids, captions_mask = prepare_data(
                    data, self.device)

                words_embs, sent_emb = self.text_encoder_forward(
                    text_encoder, captions, captions_mask)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.data.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, class_ids, self.opts)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.data.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                if gen_iterations % 10 == 0:
                    print("Epoch: " + str(epoch) + " Step: " + str(step) +
                          " " + D_logs + '\n' + G_logs)
                # save images
                if gen_iterations % 300 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          epoch,
                                          step,
                                          name='average')
                    load_params(netG, backup_para)

            is_mean, is_std, error_G_val = self.validate(
                netG, netsD, text_encoder, image_encoder)
            self.val_logger.write("{} {} {}\n".format(epoch, is_mean, is_std))
            self.val_logger.flush()

            self.losses_logger.write("{} {} {}\n".format(
                epoch, errG_total.data.item(), error_G_val))
            self.losses_logger.flush()

            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches,
                   errD_total.data.item(), errG_total.data.item(),
                   end_t - start_t))

            print("IS: {} {}".format(is_mean, is_std))
            if epoch % self.opts.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
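
The training loop above (and Example #9 below) keeps a moving average of the generator weights via copy_G_params and load_params, updated in place with avg_p.mul_(0.999).add_(0.001, p.data). A minimal sketch of what these two helpers typically do; the bodies are assumptions consistent with how they are called.

from copy import deepcopy

def copy_G_params(model):
    # Detached copies of the generator parameters, used as the moving-average state
    return deepcopy([p.data for p in model.parameters()])

def load_params(model, new_params):
    # Overwrite the model's parameters in place, e.g. to swap in the averaged weights
    for p, new_p in zip(model.parameters(), new_params):
        p.data.copy_(new_p)
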
Example #7
    def train(x, y):
        G.optim.zero_grad()
        D.optim.zero_grad()
        # How many chunks to split x and y into?
        x = torch.split(x, config['batch_size'])
        y = torch.split(y, config['batch_size'])
        counter = 0
        lambda_D = config['lambda_D']
        lambda_G = config['lambda_G']

        # Optionally toggle D and G's "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, True)
            utils.toggle_grad(G, False)

        for step_index in range(config['num_D_steps']):
            # If accumulating gradients, loop multiple times before an optimizer step
            D.optim.zero_grad()
            for accumulation_index in range(config['num_D_accumulations']):
                z_.sample_()
                y_.sample_()
                D_scores, D_scores_rotate90, D_scores_rotate180, D_scores_rotate270, \
                D_scores_croptl, D_scores_croptr, D_scores_cropbl, D_scores_cropbr, \
                D_scores_translation, D_scores_cutout = GD(z_[:config['batch_size']], y_[:config['batch_size']],
                              x[counter], y[counter], train_G=False, policy=config['DiffAugment'],
                              CR=config['CR'] > 0, CR_augment=config['CR_augment'])

                D_loss_CR = 0
                if config['CR'] > 0:
                    D_fake, D_real, D_real_aug = D_scores
                    D_loss_CR = torch.mean(
                        (D_real_aug - D_real)**2) * config['CR']
                else:
                    D_fake, D_real = D_scores
                    # rotation
                    D_fake_rotate90, D_real_rotate90 = D_scores_rotate90
                    D_fake_rotate180, D_real_rotate180 = D_scores_rotate180
                    D_fake_rotate270, D_real_rotate270 = D_scores_rotate270
                    # cropping
                    D_fake_croptl, D_real_croptl = D_scores_croptl
                    D_fake_croptr, D_real_croptr = D_scores_croptr
                    D_fake_cropbl, D_real_cropbl = D_scores_cropbl
                    D_fake_cropbr, D_real_cropbr = D_scores_cropbr
                    # translation & cutout
                    D_fake_translation, D_real_translation = D_scores_translation
                    D_fake_cutout, D_real_cutout = D_scores_cutout

                # Compute components of D's loss, average them, and divide by
                # the number of gradient accumulations
                D_loss_real, D_loss_fake = losses.discriminator_loss(
                    D_fake, D_real)
                # rotation
                D_loss_real_rotate90, D_loss_fake_rotate90 = losses.discriminator_loss(
                    D_fake_rotate90, D_real_rotate90)
                D_loss_real_rotate180, D_loss_fake_rotate180 = losses.discriminator_loss(
                    D_fake_rotate180, D_real_rotate180)
                D_loss_real_rotate270, D_loss_fake_rotate270 = losses.discriminator_loss(
                    D_fake_rotate270, D_real_rotate270)
                # cropping
                D_loss_real_croptl, D_loss_fake_croptl = losses.discriminator_loss(
                    D_fake_croptl, D_real_croptl)
                D_loss_real_croptr, D_loss_fake_croptr = losses.discriminator_loss(
                    D_fake_croptr, D_real_croptr)
                D_loss_real_cropbl, D_loss_fake_cropbl = losses.discriminator_loss(
                    D_fake_cropbl, D_real_cropbl)
                D_loss_real_cropbr, D_loss_fake_cropbr = losses.discriminator_loss(
                    D_fake_cropbr, D_real_cropbr)
                # translation and cutout
                D_loss_real_translation, D_loss_fake_translation = losses.discriminator_loss(
                    D_fake_translation, D_real_translation)
                D_loss_real_cutout, D_loss_fake_cutout = losses.discriminator_loss(
                    D_fake_cutout, D_real_cutout)

                D_loss = D_loss_real + D_loss_fake + D_loss_CR
                # rotation
                D_loss_rotate90 = D_loss_real_rotate90 + D_loss_fake_rotate90
                D_loss_rotate180 = D_loss_real_rotate180 + D_loss_fake_rotate180
                D_loss_rotate270 = D_loss_real_rotate270 + D_loss_fake_rotate270
                # cropping
                D_loss_croptl = D_loss_real_croptl + D_loss_fake_croptl
                D_loss_croptr = D_loss_real_croptr + D_loss_fake_croptr
                D_loss_cropbl = D_loss_real_cropbl + D_loss_fake_cropbl
                D_loss_cropbr = D_loss_real_cropbr + D_loss_fake_cropbr
                # translation and cutout
                D_loss_translation = D_loss_real_translation + D_loss_fake_translation
                D_loss_cutout = D_loss_real_cutout + D_loss_fake_cutout

                D_loss = D_loss + lambda_D/4*(D_loss + D_loss_rotate90 + D_loss_rotate180 + D_loss_rotate270) \
                                + lambda_D/5*(D_loss + D_loss_croptl + D_loss_croptr + D_loss_cropbl + D_loss_cropbr) \
                                + lambda_D/2*(D_loss + D_loss_translation) \
                                + lambda_D/2*(D_loss + D_loss_cutout)

                D_loss = D_loss / float(config['num_D_accumulations'])
                D_loss.backward(retain_graph=True)
                counter += 1

            # Optionally apply ortho reg in D
            if config['D_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in D.
                print('using modified ortho reg in D')
                utils.ortho(D, config['D_ortho'])

            D.optim.step()

        # Optionally toggle "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, False)
            utils.toggle_grad(G, True)

        # Zero G's gradients by default before training G, for safety
        G.optim.zero_grad()

        if not config['fix_G']:
            # If accumulating gradients, loop multiple times
            for accumulation_index in range(config['num_G_accumulations']):
                z_.sample_()
                y_.sample_()
                D_fake, D_fake_rotate90, D_fake_rotate180, D_fake_rotate270, \
                        D_fake_croptl, D_fake_croptr, D_fake_cropbl, D_fake_cropbr, D_fake_translation, D_fake_cutout = GD(z_, y_, train_G=True, policy=config['DiffAugment'])

                G_loss_rotate0 = losses.generator_loss(D_fake) / float(
                    config['num_G_accumulations'])
                # rotation
                G_loss_rotate90 = losses.generator_loss(
                    D_fake_rotate90) / float(config['num_G_accumulations'])
                G_loss_rotate180 = losses.generator_loss(
                    D_fake_rotate180) / float(config['num_G_accumulations'])
                G_loss_rotate270 = losses.generator_loss(
                    D_fake_rotate270) / float(config['num_G_accumulations'])
                # cropping
                G_loss_croptl = losses.generator_loss(D_fake_croptl) / float(
                    config['num_G_accumulations'])
                G_loss_croptr = losses.generator_loss(D_fake_croptr) / float(
                    config['num_G_accumulations'])
                G_loss_cropbl = losses.generator_loss(D_fake_cropbl) / float(
                    config['num_G_accumulations'])
                G_loss_cropbr = losses.generator_loss(D_fake_cropbr) / float(
                    config['num_G_accumulations'])
                # translation and cutout
                G_loss_translation = losses.generator_loss(
                    D_fake_translation) / float(config['num_G_accumulations'])
                G_loss_cutout = losses.generator_loss(D_fake_cutout) / float(
                    config['num_G_accumulations'])

                G_loss = G_loss_rotate0 + lambda_G/4.*(G_loss_rotate0 + G_loss_rotate90 + G_loss_rotate180 + G_loss_rotate270) \
                                        + lambda_G/5.*(G_loss_rotate0 + G_loss_croptl + G_loss_croptr + G_loss_cropbl + G_loss_cropbr) \
                                        + lambda_G/2.*(G_loss_rotate0 + G_loss_translation) \
                                        + lambda_G/2.*(G_loss_rotate0 + G_loss_cutout)

                G_loss.backward()

            # Optionally apply modified ortho reg in G
            if config['G_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in G
                print('using modified ortho reg in G')
                # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
                utils.ortho(
                    G,
                    config['G_ortho'],
                    blacklist=[param for param in G.shared.parameters()])
            G.optim.step()

            # If we have an ema, update it, regardless of if we test with it or not
            if config['ema']:
                ema.update(state_dict['itr'])

        out = {
            'G_loss': float(G_loss.item()) if not config['fix_G'] else 0,
            'D_loss_real': float(D_loss_real.item()),
            'D_loss_fake': float(D_loss_fake.item()),
        }
        if config['CR'] > 0:
            out['D_loss_CR'] = float(D_loss_CR.item())
        # Return G's loss and the components of D's loss.
        return out
Example #8
def train(data, epochs, batch_size=1, gen_lr=5e-6, disc_lr=5e-7, epoch_offset=0):
    generator = Generator(input_shape=[None,None,2])
    discriminator = Discriminator(input_shape=[None,None,1])

    generator_optimizer = tf.keras.optimizers.Adam(gen_lr)
    discriminator_optimizer = tf.keras.optimizers.Adam(disc_lr)

    model_name = data['training'].origin+'_2_any'
    checkpoint_prefix = os.path.join(CHECKPOINT_DIR, model_name)
    if(not os.path.isdir(checkpoint_prefix)):
        os.makedirs(checkpoint_prefix)
    else:
        if(os.path.isfile(os.path.join(checkpoint_prefix, 'generator.h5'))):
            generator.load_weights(os.path.join(checkpoint_prefix, 'generator.h5'), by_name=True)
            print('Generator weights restored from ' + checkpoint_prefix)

        if(os.path.isfile(os.path.join(checkpoint_prefix, 'discriminator.h5'))):
            discriminator.load_weights(os.path.join(checkpoint_prefix, 'discriminator.h5'), by_name=True)
            print('Discriminator weights restored from ' + checkpoint_prefix)

    # Get the number of batches in the training set
    epoch_size = data['training'].__len__()

    print()
    print("Started training with the following parameters: ")
    print("\tCheckpoints: \t", checkpoint_prefix)
    print("\tEpochs: \t", epochs)
    print("\tgen_lr: \t", gen_lr)
    print("\tdisc_lr: \t", disc_lr)
    print("\tBatchSize: \t", batch_size)
    print("\tnBatches: \t", epoch_size)
    print()

    # Precompute the test input and target for validation
    audio_input = load_audio(os.path.join(TEST_AUDIOS_PATH, data['training'].origin+'.wav'))
    mag_input, phase = forward_transform(audio_input)
    mag_input = amplitude_to_db(mag_input)
    test_input = slice_magnitude(mag_input, mag_input.shape[0])
    test_input = (test_input * 2) - 1

    test_inputs = []
    test_targets = []

    for t in data['training'].target:
        audio_target = load_audio(os.path.join(TEST_AUDIOS_PATH, t+'.wav'))
        mag_target, _ = forward_transform(audio_target)
        mag_target = amplitude_to_db(mag_target)
        test_target = slice_magnitude(mag_target, mag_target.shape[0])
        test_target = (test_target * 2) - 1

        test_target_perm = test_target[np.random.permutation(test_target.shape[0]),:,:,:]
        test_inputs.append(np.concatenate([test_input, test_target_perm], axis=3))
        test_targets.append(test_target)

    gen_mae_list, gen_mae_val_list  = [], []
    gen_loss_list, gen_loss_val_list  = [], []
    disc_loss_list, disc_loss_val_list  = [], []
    for epoch in range(epochs):
        gen_mae_total, gen_mae_val_total = 0, 0
        gen_loss_total, gen_loss_val_total = 0, 0
        disc_loss_total, disc_loss_val_total = 0, 0

        print('Epoch {}/{}'.format((epoch+1)+epoch_offset, epochs+epoch_offset))
        progbar = tf.keras.utils.Progbar(epoch_size)
        for i in range(epoch_size):
            # Get the data from the DataGenerator
            input_image, target = data['training'].__getitem__(i) 
            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                # Generate a fake image
                gen_output = generator(input_image, training=True)
                
                # Train the discriminator
                disc_real_output = discriminator([input_image[:,:,:,0:1], target], training=True)
                disc_generated_output = discriminator([input_image[:,:,:,0:1], gen_output], training=True)
                
                # Compute the losses
                gen_mae = l1_loss(target, gen_output)
                gen_loss = generator_loss(disc_generated_output, gen_mae)
                disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
                
                # Compute the gradients
                generator_gradients = gen_tape.gradient(gen_loss,generator.trainable_variables)
                discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
                
                # Apply the gradients
                generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
                discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))

                # Update the progress bar
                gen_mae = gen_mae.numpy()
                gen_loss = gen_loss.numpy()
                disc_loss = disc_loss.numpy()
                
                gen_mae_total += gen_mae
                gen_loss_total += gen_loss
                disc_loss_total += disc_loss

                progbar.add(1, values=[
                                        ("gen_mae", gen_mae), 
                                        ("gen_loss", gen_loss), 
                                        ("disc_loss", disc_loss)
                                    ])

        gen_mae_list.append(gen_mae_total/epoch_size)
        gen_mae_val_list.append(gen_mae_val_total/epoch_size)
        gen_loss_list.append(gen_loss_total/epoch_size)
        gen_loss_val_list.append(gen_loss_val_total/epoch_size)
        disc_loss_list.append(disc_loss_total/epoch_size)
        disc_loss_val_list.append(disc_loss_val_total/epoch_size)

        history = pd.DataFrame({
                                    'gen_mae': gen_mae_list, 
                                    'gen_mae_val': gen_mae_val_list, 
                                    'gen_loss': gen_loss_list,
                                    'gen_loss_val': gen_loss_val_list,
                                    'disc_loss': disc_loss_list,
                                    'disc_loss_val': disc_loss_val_list
                                })
        write_csv(history, os.path.join(checkpoint_prefix, 'history.csv'))

        epoch_output = os.path.join(OUTPUT_PATH, model_name, str((epoch+1)+epoch_offset).zfill(3))
        init_directory(epoch_output)

        # Generate audios and save spectrograms for the entire audios
        for j in range(len(data['training'].target)):
            prediction = generator(test_inputs[j], training=False)
            prediction = (prediction + 1) / 2
            generate_images(prediction, (test_inputs[j] + 1) / 2, (test_targets[j] + 1) / 2, os.path.join(epoch_output, 'spectrogram_'+data['training'].target[j]))
            generate_audio(prediction, phase, os.path.join(epoch_output, 'audio_'+data['training'].target[j]+'.wav'))
        print('Epoch outputs saved in ' + epoch_output)

        # Save the weights
        generator.save_weights(os.path.join(checkpoint_prefix, 'generator.h5'))
        discriminator.save_weights(os.path.join(checkpoint_prefix, 'discriminator.h5'))
        print('Weights saved in ' + checkpoint_prefix)

        # Callback at the end of the epoch for the DataGenerator
        data['training'].on_epoch_end()
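
This pix2pix-style loop calls l1_loss, generator_loss(disc_generated_output, gen_mae), and discriminator_loss(disc_real_output, disc_generated_output). A minimal sketch of compatible TF2 definitions, assuming the discriminator outputs raw logits; the L1 weight (100.0) is an illustrative default, not a value taken from the source.

import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def l1_loss(target, gen_output):
    # Mean absolute error between target and generated spectrogram slices
    return tf.reduce_mean(tf.abs(target - gen_output))

def generator_loss(disc_generated_output, gen_mae, l1_weight=100.0):
    # Adversarial term (fool D on generated patches) plus weighted L1 reconstruction
    adv_loss = bce(tf.ones_like(disc_generated_output), disc_generated_output)
    return adv_loss + l1_weight * gen_mae

def discriminator_loss(disc_real_output, disc_generated_output):
    # Real patches pushed towards 1, generated patches towards 0
    real_loss = bce(tf.ones_like(disc_real_output), disc_real_output)
    fake_loss = bce(tf.zeros_like(disc_generated_output), disc_generated_output)
    return real_loss + fake_loss
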
Example #9
    def train(self):
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels, fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # save images
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, mask, image_encoder,
                                          captions, cap_lens, epoch, name='average')
                    load_params(netG, backup_para)
                    #
                    # self.save_img_results(netG, fixed_noise, sent_emb,
                    #                       words_embs, mask, image_encoder,
                    #                       captions, cap_lens,
                    #                       epoch, name='current')
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.item(), errG_total.item(),
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
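
Examples #6 and #9 compute the per-scale discriminator loss through discriminator_loss(netD, imgs, fake_imgs, sent_emb, real_labels, fake_labels). A minimal sketch of an AttnGAN-style conditional plus unconditional loss; the way netD is called here (one call returning both logit sets) is a simplifying assumption, not the exact interface of either repository.

import torch.nn as nn

def discriminator_loss(netD, real_imgs, fake_imgs, conditions, real_labels, fake_labels):
    # Assumed interface: netD(images, conditions) -> (unconditional logits, conditional logits)
    criterion = nn.BCEWithLogitsLoss()
    real_logits, cond_real_logits = netD(real_imgs, conditions)
    fake_logits, cond_fake_logits = netD(fake_imgs.detach(), conditions)
    errD_real = criterion(real_logits, real_labels) + criterion(cond_real_logits, real_labels)
    errD_fake = criterion(fake_logits, fake_labels) + criterion(cond_fake_logits, fake_labels)
    return (errD_real + errD_fake) / 2.0
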
Example #10
    def train(x, y, stage):
        G.optim.zero_grad()
        D.optim.zero_grad()
        M.optim.zero_grad()  # yaxing
        # How many chunks to split x and y into?
        x = torch.split(x, config['batch_size'])
        y = torch.split(y, config['batch_size'])
        counter = 0

        # Optionally toggle D and G's "requires_grad"
        if config['toggle_grads']:  # yaxing: here it is True
            utils.toggle_grad(D, True)
            utils.toggle_grad(G, False)
            utils.toggle_grad(M, False)  # yaxing

        for step_index in range(config['num_D_steps']):
            # If accumulating gradients, loop multiple times before an optimizer step
            D.optim.zero_grad()
            for accumulation_index in range(config['num_D_accumulations']):
                z_.sample_()
                y_.sample_()
                # yaxing: set gy and dy equal to 0, since we do not know the label
                D_fake, D_real = GD(z_[:config['batch_size']],
                                    y_[:config['batch_size']],
                                    x[counter],
                                    y[counter],
                                    train_G=False,
                                    split_D=config['split_D'])

                # Compute components of D's loss, average them, and divide by
                # the number of gradient accumulations
                D_loss_real, D_loss_fake = losses.discriminator_loss(
                    D_fake, D_real)
                D_loss = (D_loss_real + D_loss_fake) / float(
                    config['num_D_accumulations'])
                D_loss.backward()
                counter += 1

            # Optionally apply ortho reg in D
            if config['D_ortho'] > 0.0:  # yaxing: here it is 0.0
                # Debug print to indicate we're using ortho reg in D.
                print('using modified ortho reg in D')
                utils.ortho(D, config['D_ortho'])

            D.optim.step()

        # Optionally toggle "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, False)
            if stage == 1:
                utils.toggle_grad(G, False)  # yaxing
            else:
                utils.toggle_grad(G, True)  # yaxing
            utils.toggle_grad(M, True)  # yaxing

        # Zero G's gradients by default before training G, for safety
        G.optim.zero_grad()
        M.optim.zero_grad()  # yaxing

        # If accumulating gradients, loop multiple times
        for accumulation_index in range(
                config['num_G_accumulations']):  # yaxing: here it is 1
            z_.sample_()
            y_.sample_()
            #D_fake = GD(z_, y_, train_G=True, split_D=config['split_D'])
            # yaxing: set gy and dy equal to 0, since we do not know the label
            D_fake, M_regu = GD(z_,
                                y_,
                                train_G=True,
                                split_D=config['split_D'],
                                train_M=True,
                                M_regu=True)
            #G_loss = losses.generator_loss(D_fake) / float(config['num_G_accumulations'])
            M_loss = losses.generator_loss(D_fake, M_regu) / float(
                config['num_G_accumulations'])
            #pdb.set_trace()
            #G_loss.backward()
            M_loss.backward()

        # Optionally apply modified ortho reg in G
        if config['G_ortho'] > 0.0:  # yaxing: here it is 0.0
            # Debug print to indicate we're using ortho reg in G
            print('using modified ortho reg in G')
            # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
            utils.ortho(G,
                        config['G_ortho'],
                        blacklist=[param for param in G.shared.parameters()])
        if stage == 2:
            G.optim.step()
        M.optim.step()

        # If we have an ema, update it, regardless of if we test with it or not
        if config['ema']:
            ema.update(state_dict['itr'])

        #out = {'G_loss': float(G_loss.item()),
        out = {
            'G_loss': float(M_loss.item()),
            'D_loss_real': float(D_loss_real.item()),
            'D_loss_fake': float(D_loss_fake.item())
        }
        # Return G's loss and the components of D's loss.
        return out
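
Several of the BigGAN-style loops (Examples #7, #10, #11) finish a G step with ema.update(state_dict['itr']). A minimal sketch of a moving-average helper with that update(itr) signature; the class body, decay, and start_itr defaults are assumptions, and the object would be constructed elsewhere (e.g. ema = EMA(G, G_ema)).

import torch

class EMA(object):
    # Keep an exponential moving average of `source` parameters inside `target`
    def __init__(self, source, target, decay=0.9999, start_itr=0):
        self.source, self.target = source, target
        self.decay, self.start_itr = decay, start_itr

    def update(self, itr=None):
        # Before start_itr just copy the weights; afterwards blend them
        decay = 0.0 if (itr is not None and itr < self.start_itr) else self.decay
        with torch.no_grad():
            for p_src, p_tgt in zip(self.source.parameters(), self.target.parameters()):
                p_tgt.copy_(p_tgt * decay + p_src * (1.0 - decay))
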
Example #11
  def train(x, y):
    G.optim.zero_grad()
    D.optim.zero_grad()
    # How many chunks to split x and y into?
    x = torch.split(x, config['batch_size'])
    y = torch.split(y, config['batch_size'])
    counter = 0
    
    # Optionally toggle D and G's "requires_grad"

    utils.toggle_grad(D, True)
    utils.toggle_grad(G, False)
      
    for step_index in range(config['num_D_steps']):
      # If accumulating gradients, loop multiple times before an optimizer step
      for accumulation_index in range(config['num_D_accumulations']):
        z_.sample_()
        y_.sample_()
        D_fake, D_real, mi, c_cls = GD(z_[:config['batch_size']], y_[:config['batch_size']],
                            x[counter], y[counter], train_G=False, 
                            split_D=config['split_D'])
         
        # Compute components of D's loss, average them, and divide by 
        # the number of gradient accumulations
        D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
        C_loss = 0
        if config['loss_type'] == 'Twin_AC':
            C_loss += F.cross_entropy(c_cls[D_fake.shape[0]:], y[counter]) + F.cross_entropy(mi[:D_fake.shape[0]], y_)
        if config['loss_type'] == 'Twin_AC_M':
            C_loss += hinge_multi(c_cls[D_fake.shape[0]:], y[counter]) + hinge_multi(mi[:D_fake.shape[0]], y_)
        if config['loss_type'] == 'AC':
            C_loss += F.cross_entropy(c_cls[D_fake.shape[0]:], y[counter])
        D_loss = (D_loss_real + D_loss_fake + C_loss*config['AC_weight']) / float(config['num_D_accumulations'])
        D_loss.backward()
        counter += 1
        
      # Optionally apply ortho reg in D
      if config['D_ortho'] > 0.0:
        # Debug print to indicate we're using ortho reg in D.
        print('using modified ortho reg in D')
        utils.ortho(D, config['D_ortho'])
      
      D.optim.step()
    
    # Optionally toggle "requires_grad"
    utils.toggle_grad(D, False)
    utils.toggle_grad(G, True)
      
    # Zero G's gradients by default before training G, for safety
    G.optim.zero_grad()
    for step_index in range(config['num_G_steps']):
        for accumulation_index in range(config['num_G_accumulations']):
            z_.sample_()
            y_.sample_()
            D_fake, G_z, mi, c_cls = GD(z_, y_, train_G=True, split_D=config['split_D'], return_G_z=True)
            C_loss = 0
            MI_loss = 0
            if config['loss_type'] == 'AC' or config['loss_type'] == 'Twin_AC':
                C_loss = F.cross_entropy(c_cls, y_)
                if config['loss_type'] == 'Twin_AC':
                    MI_loss = F.cross_entropy(mi, y_)
            if config['loss_type'] == 'Twin_AC_M':
                C_loss = hinge_multi(c_cls, y_,hinge=False)
                MI_loss = hinge_multi(mi, y_, hinge=False)

            G_loss = losses.generator_loss(D_fake) / float(config['num_G_accumulations'])
            C_loss = C_loss / float(config['num_G_accumulations'])
            MI_loss = MI_loss / float(config['num_G_accumulations'])
            (G_loss + (C_loss - MI_loss)*config['AC_weight']).backward()

        # Optionally apply modified ortho reg in G
        if config['G_ortho'] > 0.0:
            print('using modified ortho reg in G')  # Debug print to indicate we're using ortho reg in G
            # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
            utils.ortho(G, config['G_ortho'],
                        blacklist=[param for param in G.shared.parameters()])
        G.optim.step()
    
    # If we have an ema, update it, regardless of if we test with it or not
    if config['ema']:
      ema.update(state_dict['itr'])
    
    out = {'G_loss': float(G_loss.item()), 
            'D_loss_real': float(D_loss_real.item()),
            'D_loss_fake': float(D_loss_fake.item()),
            'C_loss': C_loss,
            'MI_loss': MI_loss}
    # Return G's loss and the components of D's loss.
    return out
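
The Twin_AC_M branches above call hinge_multi(logits, labels, hinge=...), which is not shown. One plausible reading, sketched below, is a hinge penalty on the logit of the target class; the actual project may define it differently, so treat this purely as an assumption that matches the call sites.

import torch
import torch.nn.functional as F

def hinge_multi(logits, labels, hinge=True):
    # Logit of the target class for each sample
    target_logit = logits.gather(1, labels.view(-1, 1)).squeeze(1)
    if hinge:
        # D-side objective: penalize target-class logits below a margin of 1
        return torch.mean(F.relu(1.0 - target_logit))
    # G-side objective (hinge=False): simply push the target-class logit up
    return -torch.mean(target_logit)
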
Example #12
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
  net_g, net_d = nets
  optim_g, optim_d = optims
  scheduler_g, scheduler_d = schedulers
  train_loader, eval_loader = loaders
  if writers is not None:
    writer, writer_eval = writers

  train_loader.batch_sampler.set_epoch(epoch)
  global global_step

  net_g.train()
  net_d.train()
  for batch_idx, (spec, spec_lengths, y, y_lengths) in enumerate(train_loader):
    spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)
    y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True)


    with autocast(enabled=hps.train.fp16_run):
      mel = spec_to_mel_torch(
          spec, 
          hps.data.filter_length, 
          hps.data.n_mel_channels, 
          hps.data.sampling_rate,
          hps.data.mel_fmin, 
          hps.data.mel_fmax)
#       print('check',mel.shape)
      y_hat, ids_slice, x_mask, z_mask,\
      (z, z_p, m_p, logs_p, m_q, logs_q) = net_g(mel, spec_lengths, spec, spec_lengths)
#       print('check',log_det_j_sum.shape, m_p.shape)

      y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
      y_hat_mel = mel_spectrogram_torch(
          y_hat.squeeze(1), 
          hps.data.filter_length, 
          hps.data.n_mel_channels, 
          hps.data.sampling_rate, 
          hps.data.hop_length, 
          hps.data.win_length, 
          hps.data.mel_fmin, 
          hps.data.mel_fmax
      )

      y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice 
      
      # NDA is effective?
      batch_size = y.size(0)
      y_jig1 = y.view(batch_size, 4, -1)
      rand_idx = torch.randperm(4)
      y_jig2 = y_jig1[:, rand_idx, :]
      y_jigsaw = y_jig2.view(batch_size, 1, -1)
#             print(rand_idx)
      check_idx = torch.tensor([0, 1, 2, 3])
      if (rand_idx == check_idx).sum() == 4:
          y_jigsaw = y_hat
    
      y_negative = 0.75*y_hat + 0.25*y_jigsaw
    
    
      # Discriminator
      y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_negative.detach())
    
    
    
      with autocast(enabled=False):
        loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
        loss_disc_all = loss_disc
    optim_d.zero_grad()
    scaler.scale(loss_disc_all).backward()
    scaler.unscale_(optim_d)
    grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
    scaler.step(optim_d)

    with autocast(enabled=hps.train.fp16_run):
      # Generator
      y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
      with autocast(enabled=False):
        loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
        loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl

        loss_fm = feature_loss(fmap_r, fmap_g)
        loss_gen, losses_gen = generator_loss(y_d_hat_g)
        loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
    optim_g.zero_grad()
    scaler.scale(loss_gen_all).backward()
    scaler.unscale_(optim_g)
    grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
    scaler.step(optim_g)
    scaler.update()

    if rank==0:
      if global_step % hps.train.log_interval == 0:
        lr = optim_g.param_groups[0]['lr']
        losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
        logger.info('Train Epoch: {} [{:.0f}%]'.format(
          epoch,
          100. * batch_idx / len(train_loader)))
        logger.info([x.item() for x in losses] + [global_step, lr])
        
        scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}
        scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl})

        scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
        scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
        scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
        image_dict = { 
            "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
            "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), 
            "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
        }
        utils.summarize(
          writer=writer,
          global_step=global_step, 
          images=image_dict,
          scalars=scalar_dict)

      if global_step % hps.train.eval_interval == 0:
        evaluate(hps, net_g, eval_loader, writer_eval)
        utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
        utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
    global_step += 1
  
  if rank == 0:
    logger.info('====> Epoch: {}'.format(epoch))
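
Example #12 follows the VITS training recipe, where feature_loss, generator_loss, and discriminator_loss operate on lists of per-discriminator outputs and the adversarial terms use a least-squares formulation. A minimal sketch in that style, consistent with how the return values are unpacked above; treat it as illustrative.

import torch

def feature_loss(fmap_r, fmap_g):
    # L1 distance between real and generated discriminator feature maps
    loss = 0.0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl.detach() - gl))
    return loss * 2

def generator_loss(disc_outputs):
    # Least-squares loss per sub-discriminator, plus per-head values for logging
    gen_losses = [torch.mean((1.0 - dg) ** 2) for dg in disc_outputs]
    return sum(gen_losses), gen_losses

def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    # Real outputs pushed towards 1, generated outputs towards 0
    loss, r_losses, g_losses = 0.0, [], []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1.0 - dr) ** 2)
        g_loss = torch.mean(dg ** 2)
        loss += r_loss + g_loss
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())
    return loss, r_losses, g_losses
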
Example #13
    def train(x, y):
        G.optim.zero_grad()
        D.optim.zero_grad()
        # How many chunks to split x and y into?
        x = torch.split(x, config['batch_size'])
        y = torch.split(y, config['batch_size'])
        counter = 0

        # Optionally toggle D and G's "requires_grad"

        utils.toggle_grad(D, True)
        utils.toggle_grad(G, False)

        for step_index in range(config['num_D_steps']):
            # If accumulating gradients, loop multiple times before an optimizer step
            for accumulation_index in range(config['num_D_accumulations']):
                z_.sample_()
                y_.sample_()
                D_fake, D_real, mi, c_cls, G_z = GD(z_[:config['batch_size']],
                                                    y_[:config['batch_size']],
                                                    x[counter],
                                                    y[counter],
                                                    train_G=False,
                                                    split_D=config['split_D'],
                                                    return_G_z=True)

                D_loss_real, D_loss_fake = losses.discriminator_loss(
                    D_fake, D_real)

                # Compute components of D's loss, average them, and divide by
                # the number of gradient accumulations
                C_loss = 0
                if config['loss_type'] == 'Twin_AC':
                    C_loss += F.cross_entropy(c_cls[D_fake.shape[0]:],
                                              y[counter]) + F.cross_entropy(
                                                  mi[:D_fake.shape[0]], y_)
                if config['loss_type'] == 'AC':
                    C_loss += F.cross_entropy(c_cls[D_fake.shape[0]:],
                                              y[counter])
                if config['Pac']:
                    T_img = x[counter].view(-1, 4 * x[counter].size()[1],
                                            x[counter].size()[2],
                                            x[counter].size()[3])
                    F_img = G_z.view(-1, 4 * G_z.size()[1],
                                     G_z.size()[2],
                                     G_z.size()[3])
                    pack_img = torch.cat([T_img, F_img], dim=0)
                    pack_out, _, _ = D(pack_img, pack=True)
                    D_real_pac = pack_out[:T_img.size()[0]]
                    D_fake_pac = pack_out[T_img.size()[0]:]
                    D_loss_real_pac, D_loss_fake_pac = losses.discriminator_loss(
                        D_fake_pac, D_real_pac)
                    D_loss_real += D_loss_real_pac
                    D_loss_fake += D_loss_fake_pac
                D_loss = (D_loss_real + D_loss_fake +
                          C_loss * config['AC_weight']) / float(
                              config['num_D_accumulations'])
                D_loss.backward()
                counter += 1

            D.optim.step()

        # Optionally toggle "requires_grad"
        utils.toggle_grad(D, False)
        utils.toggle_grad(G, True)

        # Zero G's gradients by default before training G, for safety
        G.optim.zero_grad()
        for step_index in range(config['num_G_steps']):
            for accumulation_index in range(config['num_G_accumulations']):
                z_.sample_()
                y_.sample_()
                D_fake, G_z, mi, c_cls = GD(z_,
                                            y_,
                                            train_G=True,
                                            split_D=config['split_D'],
                                            return_G_z=True)
                C_loss = 0
                MI_loss = 0
                G_loss = losses.generator_loss(D_fake)
                if config['loss_type'] == 'AC' or config[
                        'loss_type'] == 'Twin_AC':
                    C_loss = F.cross_entropy(c_cls, y_)
                    if config['loss_type'] == 'Twin_AC':
                        MI_loss = F.cross_entropy(mi, y_)

                if config['Pac']:
                    F_img = G_z.view(-1, 4 * G_z.size()[1],
                                     G_z.size()[2],
                                     G_z.size()[3])
                    D_fake_pac, _, _ = D(F_img, pack=True)
                    G_loss_pac = losses.generator_loss(D_fake_pac)
                    G_loss += G_loss_pac

                G_loss = G_loss / float(config['num_G_accumulations'])
                C_loss = C_loss / float(config['num_G_accumulations'])
                MI_loss = MI_loss / float(config['num_G_accumulations'])
                (G_loss + (C_loss - MI_loss) * config['AC_weight']).backward()

            G.optim.step()

        # If we have an ema, update it, regardless of if we test with it or not
        if config['ema']:
            ema.update(state_dict['itr'])

        out = {
            'G_loss': float(G_loss.item()),
            'D_loss_real': float(D_loss_real.item()),
            'D_loss_fake': float(D_loss_fake.item()),
            'C_loss': C_loss,
            'MI_loss': MI_loss
        }
        # Return G's loss and the components of D's loss.
        return out
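Example #13 and several of the BigGAN-derived snippets below call losses.generator_loss and losses.discriminator_loss without including them. In BigGAN-style code these are normally the hinge losses; the sketch below is an assumption about that omitted losses module, not its verbatim contents (argument order and extra options vary between the projects below; the call in Example #13 passes the fake scores first).

    import torch
    import torch.nn.functional as F

    def discriminator_loss(dis_fake, dis_real):
        # Hinge loss for the discriminator, returned as separate real/fake
        # terms so callers can weight and log them independently.
        loss_real = torch.mean(F.relu(1.0 - dis_real))
        loss_fake = torch.mean(F.relu(1.0 + dis_fake))
        return loss_real, loss_fake

    def generator_loss(dis_fake):
        # Hinge generator loss: raise the discriminator's score on fakes.
        return -torch.mean(dis_fake)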
Example #14
    def build_model(self):
        self.build_data_loader()

        x_lr, x_hr = self.inputs

        g_fake = self.generator(x_lr)

        # PatchGAN-wise
        d_fake = self.discriminator(g_fake)
        d_real = self.discriminator(x_hr, reuse=True)

        # losses
        self.d_adv_loss = discriminator_loss(self.gan_type, d_real, d_fake, use_ra=self.use_ra)
        self.g_adv_loss = generator_loss(self.gan_type, d_real, d_fake, use_ra=self.use_ra)

        gp = 0.
        if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan':
            gp = self.gradient_penalty(real=x_hr, fake=g_fake)

        self.d_loss = self.d_adv_loss + gp
        self.rec_loss = tf.reduce_mean(tf.abs(g_fake - x_hr))
        self.g_loss = self.weight_adv_loss * self.g_adv_loss + self.weight_rec_loss * self.rec_loss

        if self.use_perceptual_loss:
            x_real = tf.image.resize_images(x_hr, size=(224, 224), align_corners=False)
            x_fake = tf.image.resize_images(g_fake, size=(224, 224), align_corners=False)

            vgg19_real = self.build_vgg19_model(x_real)
            vgg19_fake = self.build_vgg19_model(x_fake, reuse=True)

            self.perceptual_loss = tf.reduce_mean(tf.square(vgg19_real - vgg19_fake))

            self.g_loss += self.weight_perceptual_loss * self.perceptual_loss

        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if "discriminator" in var.name]
        g_vars = [var for var in t_vars if "generator" in var.name]

        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            d_opt = tf.train.AdamOptimizer(self.d_lr, beta1=self.beta1, beta2=self.beta2)
            d_grads, d_vars = zip(*d_opt.compute_gradients(self.d_loss, var_list=d_vars))
            d_grads = [tf.clip_by_norm(grad, self.grad_clip_norm) for grad in d_grads]
            self.d_opt = d_opt.apply_gradients(zip(d_grads, d_vars))

            g_opt = tf.train.AdamOptimizer(self.g_lr, beta1=self.beta1, beta2=self.beta2)
            g_grads, g_vars = zip(*g_opt.compute_gradients(self.g_loss, var_list=g_vars))
            g_grads = [tf.clip_by_norm(grad, self.grad_clip_norm) for grad in g_grads]
            self.g_opt = g_opt.apply_gradients(zip(g_grads, g_vars))

            g_rec_opt = tf.train.AdamOptimizer(self.g_lr, beta1=self.beta1, beta2=self.beta2)
            g_rec_grads, g_rec_vars = zip(*g_rec_opt.compute_gradients(self.rec_loss, var_list=g_vars))
            g_rec_grads = [tf.clip_by_norm(grad, self.grad_clip_norm) for grad in g_rec_grads]
            self.g_rec_opt = g_rec_opt.apply_gradients(zip(g_rec_grads, g_rec_vars))

        # summaries
        tf.summary.scalar("loss/d_adv_loss", self.d_adv_loss)
        tf.summary.scalar("loss/g_adv_loss", self.g_adv_loss)
        tf.summary.scalar("loss/rec_loss", self.rec_loss)
        tf.summary.scalar("loss/g_loss", self.g_loss)
        if self.use_perceptual_loss:
            tf.summary.scalar("loss/perceptual_loss", self.perceptual_loss)

        tf.summary.image("real/x_lr", x_lr, max_outputs=1)
        tf.summary.image("real/x_hr", x_hr, max_outputs=1)
        tf.summary.image("fake/gen", g_fake, max_outputs=1)

        self.merged = tf.summary.merge_all()
        self.saver = tf.train.Saver(max_to_keep=5)
        self.best_saver = tf.train.Saver(max_to_keep=1)
        self.writer = tf.summary.FileWriter(self.checkpoint_dir, self.sess.graph)
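Example #14 forwards a gan_type string and a use_ra flag into generator_loss and discriminator_loss, which suggests the relativistic-average formulation popularized by ESRGAN. The sketch below shows one plausible TensorFlow 1.x implementation of those helpers (hinge and standard sigmoid variants only); the signatures mirror the call sites above, but the body is an assumption rather than the project's actual losses code.

    import tensorflow as tf

    def _relativistic(d_real, d_fake):
        # Compare each score with the mean score of the opposite set.
        return d_real - tf.reduce_mean(d_fake), d_fake - tf.reduce_mean(d_real)

    def discriminator_loss(gan_type, d_real, d_fake, use_ra=False):
        if use_ra:
            d_real, d_fake = _relativistic(d_real, d_fake)
        if 'hinge' in gan_type:
            return tf.reduce_mean(tf.nn.relu(1. - d_real)) + \
                   tf.reduce_mean(tf.nn.relu(1. + d_fake))
        # standard GAN: sigmoid cross-entropy against real=1 / fake=0
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                   labels=tf.ones_like(d_real), logits=d_real)) + \
               tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                   labels=tf.zeros_like(d_fake), logits=d_fake))

    def generator_loss(gan_type, d_real, d_fake, use_ra=False):
        if use_ra:
            # Relativistic average: real and fake swap roles for the generator.
            d_real, d_fake = _relativistic(d_real, d_fake)
            if 'hinge' in gan_type:
                return tf.reduce_mean(tf.nn.relu(1. + d_real)) + \
                       tf.reduce_mean(tf.nn.relu(1. - d_fake))
            return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                       labels=tf.zeros_like(d_real), logits=d_real)) + \
                   tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                       labels=tf.ones_like(d_fake), logits=d_fake))
        if 'hinge' in gan_type:
            return -tf.reduce_mean(d_fake)
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                   labels=tf.ones_like(d_fake), logits=d_fake))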
Example #15
    def train(x, y, tensor_writer=None, iteration=None):
        print('Summation will be taken', config['D_hinge_loss_sum'],
              'D hinge loss')
        G.optim.zero_grad()
        D.optim.zero_grad()
        if config['no_Dv'] == False:
            Dv.optim.zero_grad()

        if tensor_writer != None and iteration % config[
                'log_results_every'] == 0:
            tensor_writer.add_video('Loaded Data', (x + 1) / 2, iteration)
            mean_pixel_val = torch.mean((x + 1) / 2, dim=[0, 1, 3, 4])
            tensor_writer.add_scalar(
                'Pixel vals/Mean Red Pixel values, real data',
                float(mean_pixel_val[0].item()), iteration)
            tensor_writer.add_scalar(
                'Pixel vals/Mean Green Pixel values, real data',
                float(mean_pixel_val[1].item()), iteration)
            tensor_writer.add_scalar(
                'Pixel vals/Mean Blue Pixel values, real data',
                float(mean_pixel_val[2].item()), iteration)

            y_text = []
            for yi in y:
                y_text.append(idx_to_classes[yi.item()])
            tensor_writer.add_text('Loaded Labels', ' | '.join(y_text),
                                   iteration)
        #Added by Xiaodan: prepare for avg pixel loss
        if config['no_avg_pixel_loss'] == False:
            mean_pixel_val_real = torch.mean((x + 1) / 2)
        # print('Range of loaded data:',x.min(),'--',x.max())
        # How many chunks to split x and y into?
        x = torch.split(x, config['batch_size'])
        y = torch.split(y, config['batch_size'])
        counter = 0

        # Optionally toggle D and G's "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, True)
            if config['no_Dv'] == False:
                utils.toggle_grad(Dv, True)
            utils.toggle_grad(G, False)

        for step_index in range(config['num_D_steps']):
            # If accumulating gradients, loop multiple times before an optimizer step
            D.optim.zero_grad()
            if config['no_Dv'] == False:
                Dv.optim.zero_grad()
            for accumulation_index in range(config['num_D_accumulations']):
                z_.sample_()
                y_.sample_()
                # print('z_ size in GAN tranining func:',z_.shape)
                # print('y_ size in GAN tranining func:',y_.shape)
                #xiaodan: D_fake, D_real [B*8,1]
                # print('hier and G_shared:',config['hier'],config['G_shared'])
                # print('Shape of z_[:config[batch_size]]:',z_[:config['batch_size']].shape)
                # print('config[batch_size]',config['batch_size'])
                if config['no_Dv'] == False:
                    D_fake, D_real, Dv_fake, Dv_real, G_z = GD(
                        z_[:config['batch_size']],
                        y_[:config['batch_size']],
                        x[counter],
                        y[counter],
                        train_G=False,
                        split_D=config['split_D'],
                        tensor_writer=tensor_writer,
                        iteration=iteration)
                else:
                    D_fake, D_real, G_z = GD(z_[:config['batch_size']],
                                             y_[:config['batch_size']],
                                             x[counter],
                                             y[counter],
                                             train_G=False,
                                             split_D=config['split_D'],
                                             tensor_writer=tensor_writer,
                                             iteration=iteration)
                # print('GD.k in train_fns line 49',GD.module.k) #GD.module because GD is now dataparallel class
                # D_fake & D_real shapes: [Bk,1], [Bk,1]
                # xiaodan: Make scores back to [B,k,1] for easier summation in discriminator_loss
                D_fake = D_fake.contiguous().view(-1, GD.module.k,
                                                  *D_fake.shape[1:])  #[B,k,1]
                D_real = D_real.contiguous().view(-1, GD.module.k,
                                                  *D_real.shape[1:])  #[B,k,1]
                if config['D_hinge_loss_sum'] == 'before':
                    D_fake = torch.sum(
                        D_fake, 1
                    )  #xiaodan: add k scores before doing hinge loss, according to the paper
                    D_real = torch.sum(D_real, 1)  #[B,1]
                # Compute components of D's loss, average them, and divide by
                # the number of gradient accumulations
                D_loss_real, D_loss_fake = losses.discriminator_loss(
                    D_fake, D_real, config['D_hinge_loss_sum'])

                # Dv_fake & Dv_real shapes: [BT*,1], [BT*,1] if T_into_B; [B,1], [B,1] if False
                if config['no_Dv'] == False:
                    # print('Dv_fake shape',Dv_fake.shape)
                    if config['T_into_B'] == True:
                        Dv_fake = Dv_fake.contiguous().view(
                            D_fake.shape[0], -1, *Dv_fake.shape[1:])  #[B,T*,1]
                        Dv_real = Dv_real.contiguous().view(
                            D_real.shape[0], -1, *Dv_real.shape[1:])  #[B,T*,1]
                        if config['Dv_hinge_loss_sum'] == 'before':
                            Dv_fake = torch.sum(
                                Dv_fake, 1
                            )  #xiaodan: add T* scores before doing hinge loss
                            Dv_real = torch.sum(Dv_real, 1)  #[B,1]
                        Dv_loss_real, Dv_loss_fake = losses.discriminator_loss(
                            Dv_fake, Dv_real, config['Dv_hinge_loss_sum'])
                    else:
                        #Xiaodan: If T_into_B is False, must use "before" for hinge loss.
                        Dv_loss_real, Dv_loss_fake = losses.discriminator_loss(
                            Dv_fake, Dv_real, 'before')
                    D_loss = (D_loss_real + D_loss_fake + Dv_loss_fake +
                              Dv_loss_real) / float(
                                  config['num_D_accumulations'])
                else:
                    D_loss = (D_loss_real + D_loss_fake) / float(
                        config['num_D_accumulations'])
                D_loss.backward()
                counter += 1

            # Optionally apply ortho reg in D
            if config['D_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in D.
                if config['no_Dv'] == False:
                    print('using modified ortho reg in D and Dv')
                    utils.ortho(Dv, config['D_ortho'])
                else:
                    print('using modified ortho reg in D')
                    utils.ortho(D, config['D_ortho'])

            D.optim.step()
            if config['no_Dv'] == False:
                Dv.optim.step()

        # Optionally toggle "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, False)
            if config['no_Dv'] == False:
                utils.toggle_grad(Dv, False)
            utils.toggle_grad(G, True)

        # Zero G's gradients by default before training G, for safety
        G.optim.zero_grad()

        # If accumulating gradients, loop multiple times
        for accumulation_index in range(config['num_G_accumulations']):
            z_.sample_()
            y_.sample_()
            # print('z_,y_ shapes before pass into GD:',z_.shape,y_.shape)
            if config['no_Dv'] == False:
                D_fake, Dv_fake, G_z = GD(z_,
                                          y_,
                                          train_G=True,
                                          split_D=config['split_D'],
                                          tensor_writer=tensor_writer,
                                          iteration=iteration)
            else:
                D_fake, G_z = GD(z_,
                                 y_,
                                 train_G=True,
                                 split_D=config['split_D'],
                                 tensor_writer=tensor_writer,
                                 iteration=iteration)

            D_fake = D_fake.contiguous().view(-1, GD.module.k,
                                              *D_fake.shape[1:])  #[B, k, 1]
            D_fake = torch.mean(
                D_fake,
                1)  # [B,1]  xiaodan: average k scores before doing hinge loss

            G_loss = config['D_loss_weight'] * losses.generator_loss(
                D_fake) / float(config['num_G_accumulations'])
            if config['no_Dv'] == False:
                if config['T_into_B'] == True:
                    Dv_fake = Dv_fake.contiguous().view(
                        D_fake.shape[0], -1, *Dv_fake.shape[1:])  #[B,T*,1]
                    Dv_fake = torch.mean(Dv_fake, 1)  # [B,1]
                G_loss += losses.generator_loss(Dv_fake) / float(
                    config['num_G_accumulations'])
            #Added by Xiaodan to take avg. pixel value into account as an additional loss
            # print(type(G_loss))
            if config['no_avg_pixel_loss'] == False:
                mean_pixel_val_fake = torch.mean((G_z + 1) / 2)
                mean_pixel_val_diff = abs(
                    float(mean_pixel_val_fake.item()) -
                    float(mean_pixel_val_real.item()))
                mean_pixel_loss = losses.avg_pixel_loss(
                    mean_pixel_val_diff,
                    config['avg_pixel_loss_weight']) / float(
                        config['num_G_accumulations'])
                if iteration >= config['pixel_loss_kicksin']:
                    G_loss += mean_pixel_loss
                else:
                    mean_pixel_loss = 0
            G_loss.backward()

        # Optionally apply modified ortho reg in G
        if config['G_ortho'] > 0.0:
            print('using modified ortho reg in G'
                  )  # Debug print to indicate we're using ortho reg in G
            # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
            utils.ortho(G,
                        config['G_ortho'],
                        blacklist=[param for param in G.shared.parameters()])
        if config['no_convgru'] == False:
            G_grad_gates = G.convgru.convgru.cell_list[
                0].conv_gates.weight.grad.abs().sum()
            G_grad_can = G.convgru.convgru.cell_list[
                0].conv_can.weight.grad.abs().sum()
            G_grad_first_layer = G.blocks[0][0].conv1.weight.grad.abs().sum()
            G_weight_gates = G.convgru.convgru.cell_list[
                0].conv_gates.weight.abs().mean()
            G_weight_can = G.convgru.convgru.cell_list[0].conv_can.weight.abs(
            ).mean()
            G_weight_first_layer = G.blocks[0][0].conv1.weight.abs().mean()
        G.optim.step()

        # If we have an ema, update it, regardless of if we test with it or not
        if config['ema']:
            ema.update(state_dict['itr'])
        if config['no_Dv'] == False:
            out = {
                'G_loss': float(G_loss.item()),
                'D_loss_real': float(D_loss_real.item()),
                'D_loss_fake': float(D_loss_fake.item()),
                'Dv_loss_real': float(Dv_loss_real.item()),
                'Dv_loss_fake': float(Dv_loss_fake.item())
            }
        else:
            out = {
                'G_loss': float(G_loss.item()),
                'D_loss_real': float(D_loss_real.item()),
                'D_loss_fake': float(D_loss_fake.item())
            }
        if tensor_writer != None and iteration % config[
                'log_results_every'] == 0:
            tensor_writer.add_video('Video Results', (G_z + 1) / 2, iteration)
            mean_pixel_val = torch.mean((G_z + 1) / 2, dim=[0, 1, 3, 4])
            tensor_writer.add_scalar(
                'Pixel vals/Mean Red Pixel values, fake data',
                float(mean_pixel_val[0].item()), iteration)
            tensor_writer.add_scalar(
                'Pixel vals/Mean Green Pixel values, fake data',
                float(mean_pixel_val[1].item()), iteration)
            tensor_writer.add_scalar(
                'Pixel vals/Mean Blue Pixel values, fake data',
                float(mean_pixel_val[2].item()), iteration)
            y_Gz_text = []
            for yi in y_:
                y_Gz_text.append(idx_to_classes[yi.item()])
            tensor_writer.add_text('Generated Labels', ' | '.join(y_Gz_text),
                                   iteration)

            # Return G's loss and the components of D's loss.
            if config['no_avg_pixel_loss'] == False:
                tensor_writer.add_scalar('Loss/avg_pixel_loss',
                                         mean_pixel_loss, iteration)
            tensor_writer.add_scalar('Loss/G_loss', out['G_loss'], iteration)
            tensor_writer.add_scalar('Loss/D_loss_real', out['D_loss_real'],
                                     iteration)
            tensor_writer.add_scalar('Loss/D_loss_fake', out['D_loss_fake'],
                                     iteration)
            if config['no_Dv'] == False:
                tensor_writer.add_scalar('Loss/Dv_loss_fake',
                                         out['Dv_loss_fake'], iteration)
                tensor_writer.add_scalar('Loss/Dv_loss_real',
                                         out['Dv_loss_real'], iteration)
            if config['no_convgru'] == False:
                tensor_writer.add_scalar('Gradient/G_grad_gates', G_grad_gates,
                                         iteration)
                tensor_writer.add_scalar('Gradient/G_grad_can', G_grad_can,
                                         iteration)
                tensor_writer.add_scalar('Gradient/G_grad_first_layer',
                                         G_grad_first_layer, iteration)

                tensor_writer.add_scalar('Weight/G_weight_gates',
                                         G_weight_gates, iteration)
                tensor_writer.add_scalar('Weight/G_weight_can', G_weight_can,
                                         iteration)
                tensor_writer.add_scalar('Weight/G_weight_first_layer',
                                         G_weight_first_layer, iteration)
        return out
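The snippet above also relies on a project-specific losses.avg_pixel_loss. From its usage (a scalar gap between the mean pixel value of real and generated frames, scaled into the generator objective), a plausible minimal form is the following; the name and signature are taken from the call site, while the body is an assumption.

    def avg_pixel_loss(mean_pixel_val_diff, weight):
        # Penalize the absolute difference between the mean pixel value of
        # real and generated clips, scaled by a configurable weight.
        return weight * mean_pixel_val_diff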
Example #16
    def train(x_s, y, yd):
        G.optim.zero_grad()
        D.optim.zero_grad()
        # How many chunks to split x and y into?
        y = y.long()
        yd = yd.long()
        x_s = torch.split(x_s, config['batch_size'])
        y = torch.split(y, config['batch_size'])
        yd = torch.split(yd, config['batch_size'])
        counter = 0

        # Optionally toggle D and G's "requires_grad"

        utils.toggle_grad(D, True)
        utils.toggle_grad(G, False)

        for step_index in range(config['num_D_steps']):
            # If accumulating gradients, loop multiple times before an optimizer step
            for accumulation_index in range(config['num_D_accumulations']):
                z_.sample_()
                y_.sample_()
                yd_.sample_()

                D_fake, D_real, mi, c_cls, mid, c_clsd, G_z = GD(
                    z_,
                    y_,
                    yd_,
                    x_s[counter],
                    y[counter],
                    yd[counter],
                    train_G=False,
                    split_D=config['split_D'],
                    return_G_z=True)

                D_loss_real, D_loss_fake = losses.discriminator_loss(
                    D_fake, D_real)

                C_loss = 0

                if config['AC']:
                    fake_mi = mi[:D_fake.shape[0]]
                    fake_cls = c_cls[:D_fake.shape[0]]
                    c_cls_rs = c_cls[D_fake.shape[0]:]

                    fake_mid = mid[:D_fake.shape[0]]
                    c_clsd = c_clsd[D_fake.shape[0]:]
                    # print(yd)
                    # print(yd_)

                    if config['loss_type'] == 'Twin_AC':
                        C_loss += F.cross_entropy(c_clsd, yd[counter]) + F.cross_entropy(fake_mid, yd_) + \
                                  0.5*F.cross_entropy(c_cls_rs[yd[counter]!=0], y[counter][yd[counter]!=0]) + 0.5*F.cross_entropy(fake_cls, y_) + 1.0*F.cross_entropy(fake_mi, y_)
                        # if state_dict['itr'] > 0000:
                        #     C_loss += 0.2*F.cross_entropy(c_cls_ft, y_[yd_!=0]) + 0.2*F.cross_entropy(fake_mi_t[yd_!=0], y_[yd_!=0])#F.cross_entropy(fake_mi[yd_ == 0], y_[yd_ == 0])

                    if config['loss_type'] == 'AC':
                        # The excerpt referenced undefined names (c_cls_fs, y_f_s)
                        # here; real-sample logits/labels are the most plausible intent.
                        C_loss += F.cross_entropy(
                            c_cls_rs, y[counter]) + F.cross_entropy(
                                c_clsd, yd[counter])

                # Compute components of D's loss, average them, and divide by
                # the number of gradient accumulations

                if config['Pac']:
                    # NOTE: x_t (the target-domain batch) is not an argument of
                    # this excerpt; it is assumed to come from the enclosing scope.
                    x_pack = torch.cat([x_s[counter], x_t[counter]], dim=0)
                    T_img = x_pack.view(-1, 4 * x_pack.size()[1],
                                        x_pack.size()[2],
                                        x_pack.size()[3])
                    F_img = G_z.view(-1, 4 * G_z.size()[1],
                                     G_z.size()[2],
                                     G_z.size()[3])
                    pack_img = torch.cat([T_img, F_img], dim=0)
                    pack_out, _, _ = D(pack_img, pack=True)
                    D_real_pac = pack_out[:T_img.size()[0]]
                    D_fake_pac = pack_out[T_img.size()[0]:]
                    D_loss_real_pac, D_loss_fake_pac = losses.discriminator_loss(
                        D_fake_pac, D_real_pac)
                    D_loss_real += D_loss_real_pac
                    D_loss_fake += D_loss_fake_pac

                D_loss = (D_loss_real + D_loss_fake +
                          C_loss * config['AC_weight']) / float(
                              config['num_D_accumulations'])
                D_loss.backward()
                counter += 1

            # Optionally apply ortho reg in D
            if config['D_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in D.
                print('using modified ortho reg in D')
                utils.ortho(D, config['D_ortho'])

            D.optim.step()

        # Optionally toggle "requires_grad"
        utils.toggle_grad(D, False)
        utils.toggle_grad(G, True)

        # Zero G's gradients by default before training G, for safety
        G.optim.zero_grad()
        for step_index in range(config['num_G_steps']):
            for accumulation_index in range(config['num_G_accumulations']):
                z_.sample_()
                y_.sample_()
                yd_.sample_()
                D_fake, mi, cls, mid, clsd, G_z = GD(z_,
                                                     y_,
                                                     yd_,
                                                     train_G=True,
                                                     split_D=config['split_D'],
                                                     return_G_z=True)

                C_loss = 0
                MI_loss = 0
                CD_loss = 0
                MID_loss = 0
                G_loss = losses.generator_loss(D_fake)
                if config['loss_type'] == 'AC' or config[
                        'loss_type'] == 'Twin_AC':
                    C_loss = 1.0 * F.cross_entropy(
                        cls,
                        y_)  #+ 0.5*F.cross_entropy(cls[yd_!=0], y_[yd_!=0])
                    CD_loss = F.cross_entropy(clsd, yd_)
                    if config['loss_type'] == 'Twin_AC':
                        MI_loss = 1.0 * F.cross_entropy(mi, y_)
                        # if state_dict['itr'] > 0000:
                        #     MI_loss += 0.5*F.cross_entropy(mi_t[yd_!=0], y_[yd_!=0])
                        MID_loss = F.cross_entropy(mid, yd_)

                if config['Pac']:
                    F_img = G_z.view(-1, 4 * G_z.size()[1],
                                     G_z.size()[2],
                                     G_z.size()[3])
                    D_fake_pac, _, _ = D(F_img, pack=True)
                    G_loss_pac = losses.generator_loss(D_fake_pac)
                    G_loss += G_loss_pac

                G_loss = G_loss / float(config['num_G_accumulations'])
                C_loss = C_loss / float(config['num_G_accumulations'])
                MI_loss = MI_loss / float(config['num_G_accumulations'])
                CD_loss = CD_loss / float(config['num_G_accumulations'])
                MID_loss = MID_loss / float(config['num_G_accumulations'])
                (G_loss + (C_loss - MI_loss + CD_loss - MID_loss) *
                 config['AC_weight']).backward()

            # Optionally apply modified ortho reg in G
            if config['G_ortho'] > 0.0:
                print('using modified ortho reg in G'
                      )  # Debug print to indicate we're using ortho reg in G
                # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
                utils.ortho(
                    G,
                    config['G_ortho'],
                    blacklist=[param for param in G.shared.parameters()])
            G.optim.step()

        # If we have an ema, update it, regardless of if we test with it or not
        if config['ema']:
            ema.update(state_dict['itr'])

        out = {
            'G_loss': float(G_loss.item()),
            'D_loss_real': float(D_loss_real.item()),
            'D_loss_fake': float(D_loss_fake.item()),
            'C_loss': C_loss,
            'MI_loss': MI_loss,
            'CD_loss': CD_loss,
            'MID_loss': MID_loss
        }
        # Return G's loss and the components of D's loss.
        return out
Example #17
    def train(x, y):
        G.optim.zero_grad()
        D.optim.zero_grad()
        # How many chunks to split x and y into?
        x = torch.split(x, config['batch_size'])
        y = torch.split(y, config['batch_size'])
        counter = 0

        # Optionally toggle D and G's "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, True)
            utils.toggle_grad(G, False)

        for step_index in range(config['num_D_steps']):
            # If accumulating gradients, loop multiple times before an optimizer step
            D.optim.zero_grad()
            for accumulation_index in range(config['num_D_accumulations']):
                z_.sample_()
                y_.sample_()
                D_scores = GD(z_[:config['batch_size']],
                              y_[:config['batch_size']],
                              x[counter],
                              y[counter],
                              train_G=False,
                              policy=config['DiffAugment'],
                              CR=config['CR'] > 0,
                              CR_augment=config['CR_augment'])

                D_loss_CR = 0
                if config['CR'] > 0:
                    # Consistency-regularization term is not implemented in this
                    # excerpt; the branch skips the rest of the accumulation step.
                    continue

                else:
                    D_fake, D_real = D_scores

                # Compute components of D's loss, average them, and divide by
                # the number of gradient accumulations
                D_loss_real, D_loss_fake = losses.discriminator_loss(
                    D_fake, D_real)
                D_loss = D_loss_real + D_loss_fake + D_loss_CR
                D_loss = D_loss / float(config['num_D_accumulations'])
                D_loss.backward()
                counter += 1

            # Optionally apply ortho reg in D
            if config['D_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in D.
                print('using modified ortho reg in D')
                utils.ortho(D, config['D_ortho'])

            D.optim.step()

        # Optionally toggle "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, False)
            utils.toggle_grad(G, True)

        # Zero G's gradients by default before training G, for safety
        G.optim.zero_grad()

        if not config['fix_G']:
            # If accumulating gradients, loop multiple times
            for accumulation_index in range(config['num_G_accumulations']):
                z_.sample_()
                y_.sample_()
                D_fake = GD(z_, y_, train_G=True, policy=config['DiffAugment'])
                G_loss = losses.generator_loss(D_fake) / float(
                    config['num_G_accumulations'])
                G_loss.backward()

            # Optionally apply modified ortho reg in G
            if config['G_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in G
                print('using modified ortho reg in G')
                # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
                utils.ortho(
                    G,
                    config['G_ortho'],
                    blacklist=[param for param in G.shared.parameters()])
            G.optim.step()

            # If we have an ema, update it, regardless of if we test with it or not
            if config['ema']:
                ema.update(state_dict['itr'])

        out = {
            'G_loss': float(G_loss.item()) if not config['fix_G'] else 0,
            'D_loss_real': float(D_loss_real.item()),
            'D_loss_fake': float(D_loss_fake.item()),
        }
        if config['CR'] > 0:
            # D_loss_CR may still be the Python int 0 (the CR branch above is a
            # stub), so avoid calling .item() on it.
            out['D_loss_CR'] = float(D_loss_CR)
        # Return G's loss and the components of D's loss.
        return out
Example #18
    def train(x, y):
        G.module.optim.zero_grad()
        D.module.optim.zero_grad()
        # How many chunks to split x and y into?
        x = torch.split(x, config.batch_size)
        y = torch.split(y, config.batch_size)
        counter = 0

        # Optionally toggle D and G's "requires_grad"

        toggle_grad(D, True)
        toggle_grad(G, False)

        for step_index in range(config.num_D_steps):
            z_.sample_()
            y_.sample_()
            D_fake, D_real, mi, c_cls = GD(z_[:config.batch_size],
                                           y_[:config.batch_size],
                                           x[counter],
                                           y[counter],
                                           train_G=False)

            D_loss_real, D_loss_fake = losses.discriminator_loss(
                D_fake, D_real)
            if config.loss_type == 'Twin_AC':
                D_loss = (D_loss_real + D_loss_fake) + config.C_w * (
                    F.cross_entropy(c_cls[D_fake.shape[0]:], y[counter]) +
                    F.cross_entropy(mi[:D_fake.shape[0]], y_))
            elif config.loss_type == 'AC':
                D_loss = (D_loss_real +
                          D_loss_fake) + config.C_w * F.cross_entropy(
                              c_cls[D_fake.shape[0]:], y[counter])
            else:
                D_loss = (D_loss_real + D_loss_fake)
            (D_loss).backward()
            counter += 1
            D.module.optim.step()

        # Optionally toggle "requires_grad"
        toggle_grad(D, False)
        toggle_grad(G, True)

        # Zero G's gradients by default before training G, for safety
        G.module.optim.zero_grad()

        for step_index in range(config.num_G_steps):
            z_.sample_()
            y_.sample_()
            D_fake, mi, c_cls = GD(z_[:config.batch_size],
                                   y_[:config.batch_size],
                                   train_G=True)  # D(fake_img, y_)
            G_loss = losses.generator_loss(D_fake)

            C_loss = 0
            MI_loss = 0
            if config.loss_type == 'Twin_AC':

                MI_loss = F.cross_entropy(mi, y_)
                C_loss = F.cross_entropy(c_cls, y_)

                ((G_loss - config.C_w * MI_loss +
                  config.C_w * C_loss)).backward()
            elif config.loss_type == 'AC':

                C_loss = F.cross_entropy(c_cls, y_)

                ((G_loss + config.C_w * C_loss)).backward()
            else:
                (G_loss).backward()

        G.module.optim.step()

        out = {
            'G_loss': G_loss,
            'D_loss_real': D_loss_real,
            'D_loss_fake': D_loss_fake,
            'C_loss': C_loss,
            'MI_loss': MI_loss
        }
        # Return G's loss and the components of D's loss.
        return out
Example #19
    def train(x, y):
        G.optim.zero_grad()
        D.optim.zero_grad()
        # How many chunks to split x and y into?
        x = torch.split(x, config['batch_size'])
        y = torch.split(y, config['batch_size'])
        counter = 0

        # Optionally toggle D and G's "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, True)
            utils.toggle_grad(G, False)

        for step_index in range(config['num_D_steps']):
            # If accumulating gradients, loop multiple times before an optimizer step
            D.optim.zero_grad()
            # The fake class label
            lossy = torch.LongTensor(config['batch_size'])
            lossy = lossy.cuda()
            lossy.data.fill_(
                config['n_classes'])  # index for fake just for loss
            for accumulation_index in range(config['num_D_accumulations']):
                z_.sample_()
                y_.sample_()

                D_fake, D_real = GD(z_[:config['batch_size']],
                                    y_[:config['batch_size']],
                                    x[counter],
                                    y[counter],
                                    train_G=False,
                                    split_D=config['split_D'])

                # Compute components of D's loss, average them, and divide by
                # the number of gradient accumulations
                if config['mh_csc_loss'] or config['mh_loss']:
                    D_loss_real = losses.crammer_singer_criterion(
                        D_real, y[counter])
                    D_loss_fake = losses.crammer_singer_criterion(
                        D_fake, lossy[:config['batch_size']])
                else:
                    D_loss_real, D_loss_fake = losses.discriminator_loss(
                        D_fake, D_real)
                D_loss = (D_loss_real + D_loss_fake) / float(
                    config['num_D_accumulations'])
                D_loss.backward()
                counter += 1

            # Optionally apply ortho reg in D
            if config['D_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in D.
                print('using modified ortho reg in D')
                utils.ortho(D, config['D_ortho'])

            D.optim.step()

        # Optionally toggle "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, False)
            utils.toggle_grad(G, True)

        # Zero G's gradients by default before training G, for safety
        G.optim.zero_grad()

        # If accumulating gradients, loop multiple times
        for accumulation_index in range(config['num_G_accumulations']):
            # reusing the same noise for CIFAR ...
            if config['resampling'] or (accumulation_index > 0):
                z_.sample_()
                y_.sample_()

            if config['fm_loss']:
                D_feat_fake, D_feat_real = GD(z_,
                                              y_,
                                              x[-1],
                                              None,
                                              train_G=True,
                                              split_D=config['split_D'],
                                              feat=True)
                fm_loss = torch.mean(
                    torch.abs(
                        torch.mean(D_feat_fake, 0) -
                        torch.mean(D_feat_real, 0)))
                G_loss = fm_loss
            else:
                D_fake = GD(z_, y_, train_G=True, split_D=config['split_D'])
                if config['mh_csc_loss']:
                    G_loss = losses.crammer_singer_complement_criterion(
                        D_fake, lossy[:config['batch_size']]) / float(
                            config['num_G_accumulations'])
                elif config['mh_loss']:
                    D_feat_fake, D_feat_real = GD(z_,
                                                  y_,
                                                  x[-1],
                                                  None,
                                                  train_G=True,
                                                  split_D=config['split_D'],
                                                  feat=True)
                    fm_loss = torch.mean(
                        torch.abs(
                            torch.mean(D_feat_fake, 0) -
                            torch.mean(D_feat_real, 0)))
                    oth_loss = losses.mh_loss(D_fake,
                                              y_[:config['batch_size']])
                    G_loss = (config['mh_fmloss_weight'] * fm_loss +
                              config['mh_loss_weight'] * oth_loss) / float(
                                  config['num_G_accumulations'])
                else:
                    G_loss = losses.generator_loss(D_fake) / float(
                        config['num_G_accumulations'])
            G_loss.backward()

        # Optionally apply modified ortho reg in G
        if config['G_ortho'] > 0.0:
            print('using modified ortho reg in G'
                  )  # Debug print to indicate we're using ortho reg in G
            # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
            utils.ortho(G,
                        config['G_ortho'],
                        blacklist=[param for param in G.shared.parameters()])
        G.optim.step()

        # If we have an ema, update it, regardless of if we test with it or not
        if config['ema']:
            ema.update(state_dict['itr'])

        out = {
            'G_loss': float(G_loss.item()),
            'D_loss_real': float(D_loss_real.item()),
            'D_loss_fake': float(D_loss_fake.item())
        }
        # Return G's loss and the components of D's loss.
        return out
Example #20
    def create_train_op(input, labels, params):
        assert labels is None
        reals, reals_class_id = input['reals']
        pp(['input', input])
        pp(['reals', reals])
        pp(['reals_class_id', reals_class_id])
        pp(['params', params])
        mdl = BigGAN.GAN()
        BigGAN.instance = mdl
        dim_z = mdl.gan.generator.dim_z
        nclasses = mdl.gan.discriminator.n_class
        N, H, W, C = reals.shape.as_list()
        fakes_z, fakes_class_id = utils.prepare_z_y(G_batch_size=N,
                                                    dim_z=dim_z,
                                                    nclasses=nclasses)
        reals_y = tf.one_hot(reals_class_id, nclasses)
        fakes_y = tf.one_hot(fakes_class_id, nclasses)
        fakes = mdl.gan.generator(fakes_z, fakes_y)
        reals_D = mdl.gan.discriminator(reals, reals_y)
        fakes_D = mdl.gan.discriminator(fakes, fakes_y)
        global_step = tflex.get_or_create_global_step()
        #inc_global_step = global_step.assign_add(1, read_value=False, name="inc_global_step")
        # G_vars = []
        # D_vars = []
        # for variable in tf.trainable_variables():
        #   if variable.name.startswith('Generator/'):
        #     G_vars.append(variable)
        #   elif variable.name.startswith('Discriminator/'):
        #     D_vars.append(variable)
        #   elif variable.name.startswith('linear/w'):
        #     G_vars.append(variable)
        #     D_vars.append(variable)
        #   else:
        #     import pdb; pdb.set_trace()
        #     assert False, "Unexpected trainable variable"
        T_vars = tf.trainable_variables()
        G_vars = [
            x for x in T_vars if x.name.startswith('Generator/')
            or x.name.startswith('linear/w:')
        ]
        D_vars = [
            x for x in T_vars if x.name.startswith('Discriminator/')
            or x.name.startswith('linear/w:')
        ]
        leftover_vars = [
            x for x in T_vars if x not in G_vars and x not in D_vars
        ]
        if len(leftover_vars) > 0:
            import pdb
            pdb.set_trace()
            raise ValueError("Unexpected trainable variables")
        # pp({
        #   "G_vars": G_vars,
        #   "D_vars": D_vars,
        #   "leftover_vars": leftover_vars,
        #   })
        if True:

            def should_train_variable(v):
                return True

            train_vars = [
                v for v in tf.trainable_variables() if should_train_variable(v)
            ]
            non_train_vars = [
                v for v in tf.trainable_variables()
                if not should_train_variable(v)
            ]
            other_vars = [
                v for v in tf.global_variables()
                if v not in train_vars and v not in non_train_vars
            ]
            local_vars = [v for v in tf.local_variables()]

            paramcount = lambda vs: sum(
                [np.prod(v.shape.as_list()) for v in vs])

            def logvars(variables, label, print_variables=False):
                if print_variables:
                    tf.logging.info("%s (%s parameters): %s", label,
                                    paramcount(variables), pps(variables))
                else:
                    tf.logging.info("%s (%s parameters)", label,
                                    paramcount(variables))
                return variables

            tf.logging.info(
                "Training %d parameters (%.2fM) out of %d parameters (%.2fM)" %
                (
                    paramcount(train_vars),
                    paramcount(train_vars) / (1024.0 * 1024.0),
                    paramcount(tf.trainable_variables()),
                    paramcount(tf.trainable_variables()) / (1024.0 * 1024.0),
                ))

            tf.logging.info("---------")
            tf.logging.info("Variable details:")
            logvars(train_vars, "trainable variables", print_variables=True)
            logvars(non_train_vars,
                    "non-trainable variables",
                    print_variables=True)
            logvars(other_vars, "other global variables", print_variables=True)
            logvars(local_vars, "other local variables", print_variables=True)

            tf.logging.info("---------")
            tf.logging.info("Variable summary:")
            logvars(train_vars, "trainable variables")
            logvars(non_train_vars, "non-trainable variables")
            logvars(other_vars, "other global variables")
            logvars(local_vars, "other local variables")

        G_loss = losses.generator_loss(fakes_D)
        D_loss_real, D_loss_fake = losses.discriminator_loss(reals_D, fakes_D)
        D_loss = D_loss_real + D_loss_fake
        #loss = tf.constant(0.0)
        loss = G_loss + D_loss
        optimizer = tf.train.AdamOptimizer()
        if params['use_tpu']:
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)
        #import pdb; pdb.set_trace()
        update_ops = tf.get_collection(
            tf.GraphKeys.UPDATE_OPS)  # To update batchnorm, if present
        pp(['tf.GraphKeys.UPDATE_OPS', update_ops])
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss,
                                          var_list=T_vars,
                                          global_step=global_step)
            return train_op, loss  #D_loss_real
Example #21
    def train(x, y):
        G.optim.zero_grad()
        D.optim.zero_grad()
        inner_iter_count = 0
        partial_test_input = 0
        # How many chunks to split x and y into?
        #x = torch.split(x, config['batch_size'])
        #y = torch.split(y, config['batch_size'])
        #print('x len{}'.format(len(x)))
        #print('y len{}'.format(len(y)))
        #assert len(x) == config['num_D_accumulations'] == len(y)
        #D_fake, D_real, G_fake, gy = None, None, None, None
        # Optionally toggle D and G's "require_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, True)
            utils.toggle_grad(G, False)

        for step_index in range(config['num_D_steps']):
            # If accumulating gradients, loop multiple times before an optimizer step
            D.optim.zero_grad()

            d_reals = None#[None for _ in x]
            g_fakes = None#[None for _ in x]
            #gys = [None for _ in x]
            #zs = [None for _ in x]
            #zs_.sample_()
            #ys_.sample_()
            #gy = ys_[:config['batch_size']]
            #z = zs_[:config['batch_size']].view(zs_.size(0), 9, 8, 8)[:, :5]
            if state_dict['epoch'] < 0:
                #for accumulation_index in range(config['num_D_accumulations']):  # doesn't mean anything right now
                # for fb_iter in range(config['num_feedback_iter']):
                # if fb_iter == 0:
                # z_ = zs_[:config['batch_size']]
                # gy = ys_[:config['batch_size']]
                # print('z_ shape {}'.format(z_.shape))
                # z_ = z_.view(zs_.size(0), 9, 8, 8)[:, :5]
                zs_.sample_()
                z_ = zs_[:config['batch_size']].view(zs_.size(0), 24, 8, 8)[:,20]  # [:, :5]
                #z_ = z_.view(z_.size(0), -1)

                # zs[accumulation_index] = z
                # z_ = torch.cat([z, torch.zeros(zs_.size(0), 4, 8, 8).cuda()], 1)

                ys_.sample_()
                gy = ys_[:config['batch_size']]
                # gys[accumulation_index] = gy.detach()
                # else:
                # D_real = D_real#.repeat(1,3,1,1)# * g_fakes[accumulation_index]
                # print('zs_ shape 0 {}'.format(zs_.shape))
                # print('\n\n\n\n')
                # print('r shape {}'.format(r.shape))
                # print('g fake shape {}'.format(g_fakes[accumulation_index].shape))
                # print('\n\n\n\n')
                # z_ = zs_[:config['batch_size']].view(zs_.size(0), 9, 8, 8)[:, :8]
                # G_fake = nn.AvgPool2d(4)(g_fakes[accumulation_index])
                # print('z shape 5 {}'.format(z_.shape))
                # z_=z_[:,:3]
                # print('z shape 10 {}'.format(z_.shape))

                # z_ = torch.cat([d_reals[accumulation_index], G_fake, zs[accumulation_index]], 1)
                # print('z shape 15 {}'.format(z_.shape))
                # gy = gys[accumulation_index]
                D_fake, D_real, G_fake = GD(z_,
                                            gy,
                                            x=x,#[accumulation_index],
                                            dy=y,#[accumulation_index],
                                            train_G=False,
                                            split_D=config['split_D'])
                #print('D shape {}'.format(D_fake.shape))
                #print('G fake shape {}'.format(nn.AvgPool2d(4)(G_fake).shape))
                #print('D real shape {}'.format(D_real.shape))
                #print('z shape {}'.format(z_.shape))

                if state_dict['itr'] % 1000 == 0: ##and accumulation_index == 6:
                    print('saving img')
                    torchvision.utils.save_image(x.float().cpu(),#[accumulation_index].float().cpu(),
                                                 '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_pre_xreal.jpg'.format(
                                                     time, state_dict['itr']),
                                                 nrow=int(D_fake.shape[0] ** 0.5), normalize=True)
                    torchvision.utils.save_image(D_fake.float().cpu(),
                                                 '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_pre_dfake.jpg'.format(
                                                     time, state_dict['itr']),
                                                 nrow=int(D_fake.shape[0] ** 0.5), normalize=True)
                    torchvision.utils.save_image(D_real.float().cpu(),
                                                 '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_pre_dreal.jpg'.format(
                                                     time, state_dict['itr']),
                                                 nrow=int(D_fake.shape[0] ** 0.5), normalize=True)

                # d_reals[accumulation_index] = D_real.detach()
                # g_fakes[accumulation_index] = G_fake.detach()

                # Compute components of D's loss, average them, and divide by
                # the number of gradient accumulations
                D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
                D_loss = (D_loss_real + D_loss_fake)# / float(config['num_D_accumulations'])
                D_loss.backward()
                # counter += 1

                # Optionally apply ortho reg in D
                if config['D_ortho'] > 0.0:
                    # Debug print to indicate we're using ortho reg in D.
                    print('using modified ortho reg in D')
                    utils.ortho(D, config['D_ortho'])

                D.optim.step()
                # D.optim.zero_grad()
                # Optionally toggle "requires_grad"
            else:
                for fb_iter in range(config['num_feedback_iter_D']):
                    #for accumulation_index in range(config['num_D_accumulations']): #doesn't mean anything right now
                    #for fb_iter in range(config['num_feedback_iter']):
                    zs_.sample_()
                    z_ = zs_[:config['batch_size']].view(zs_.size(0), 24, 32, 32)[:, :20]
                    ys_.sample_()
                    gy = ys_[:config['batch_size']]

                    if fb_iter == 0:
                        # z_ = zs_[:config['batch_size']]
                        # gy = ys_[:config['batch_size']]
                        #print('z_ shape {}'.format(z_.shape))
                        #z_ = z_.view(zs_.size(0), 9, 8, 8)[:, :5]

                        #zs_.sample_()
                        #z_ = zs_[:config['batch_size']].view(zs_.size(0), 24, 8, 8)[:, :20]
                        #zs[accumulation_index] = z_
                        #print('three channel x input train D shape before {}'.format(x[:, :3].shape))
                        #init_x = nn.AvgPool2d(4)(x[:, :3])
                        init_x = x[:, :3]

                        z_ = torch.cat([z_, init_x, torch.ones(zs_.size(0), 1, 32, 32).cuda()], 1)
                        #print('three channel x input train D shape after {}'.format(nn.AvgPool2d(4)(x[:, :3]).shape))

                        #ys_.sample_()
                        #gy = ys_[:config['batch_size']]
                        #gys[accumulation_index] = gy.detach()
                    else:
                        #D_real = D_real#.repeat(1,3,1,1)# * g_fakes[accumulation_index]
                        #print('zs_ shape 0 {}'.format(zs_.shape))
                        #print('\n\n\n\n')
                        #print('r shape {}'.format(r.shape))
                        #print('g fake shape {}'.format(g_fakes[accumulation_index].shape))
                        #print('\n\n\n\n')
                        #z_ = zs_[:config['batch_size']].view(zs_.size(0), 9, 8, 8)[:, :8]
                        g_fake = 0.1 * g_fake + 0.9 * init_x#[accumulation_index]
                        #print('z shape 5 {}'.format(z_.shape))
                        #z_=z_[:,:3]
                        # print('z shape 10 {}'.format(z_.shape))
                        # print('g fake shape 10 {}'.format(G_fake.shape))
                        # print('d real shape 10 {}'.format(d_reals.shape))
                        #z_ = torch.cat([zs[accumulation_index],d_reals[accumulation_index], G_fake,], 1)
                        z_ = torch.cat(
                            [z_, g_fake,
                             nn.functional.interpolate(d_reals, 32, mode='bilinear')], 1)
                    #z_ = z_.view(z_.size(0),-1)
                        #print('z shape 15 {}'.format(z_.shape))
                        #gy = gys[accumulation_index]
                    # if state_dict['itr'] % 42 == 0:
                    #     partial_test_input = partial_test_input + torch.cat([g_fakes, d_fakes])
                    D_fake, D_real, G_fake = GD(z_, gy, x=x, dy=y, train_G=False,
                                                split_D=config['split_D'])
                    if state_dict['itr'] % 1000 == 0:
                        print('saving img')
                        torchvision.utils.save_image(x.float().cpu(),#[accumulation_index].float().cpu(),
                                                     '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_fb{}_xreal.jpg'.format(
                                                         time, state_dict['itr'], fb_iter),
                                                     nrow=int(D_fake.shape[0] ** 0.5), normalize=True)
                        torchvision.utils.save_image(G_fake.float().cpu(),
                        '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_fb{}_Gfake_d.jpg'.format(
                            time,state_dict['itr'],fb_iter),nrow=int(D_fake.shape[0] ** 0.5),normalize=True)
                        if fb_iter > 1:
                            torchvision.utils.save_image(g_fake.float().cpu(),
                            '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_fb{}_gfake_d.jpg'.format(
                                time,state_dict['itr'],fb_iter),nrow=int(D_fake.shape[0] ** 0.5),normalize=True)


                    # Compute the components of D's loss (the optional feedback-
                    # enforcement terms are currently disabled).
                    D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
                    D_loss = D_loss_real + D_loss_fake

                    # Cache detached feedback signals for the next feedback iteration.
                    d_reals = D_real.detach()
                    g_fake = G_fake.detach()

                    D_loss.backward()

                    # Optionally apply ortho reg in D
                    if config['D_ortho'] > 0.0:
                        utils.ortho(D, config['D_ortho'])

                    D.optim.step()

            # Optionally toggle "requires_grad"

        if config['toggle_grads']:
            utils.toggle_grad(D, False)
            utils.toggle_grad(G, True)

        # Zero G's gradients by default before training G, for safety
        G.optim.zero_grad()

        # Feedback buffers for the generator updates below.
        d_fakes = None
        g_fakes = None
        if state_dict['epoch'] < 0:
            # Warm-up phase: train G once without any feedback inputs.
            zs_.sample_()
            z_ = zs_[:config['batch_size']].view(zs_.size(0), 24, 32, 32)[:, :20]
            ys_.sample_()
            gy = ys_
            z_ = z_.view(z_.size(0), -1)
            D_fake, G_z = GD(z=z_, gy=gy, train_G=True, split_D=config['split_D'], return_G_z=True)
            G_loss = losses.generator_loss(D_fake)# / float(config['num_G_accumulations'])
            G_loss.backward()

            if state_dict['itr'] % 1000 == 0:
                print('saving img')
                torchvision.utils.save_image(D_fake.float().cpu(),
                                             '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_pre_dfake.jpg'.format(
                                                 time,
                                                 state_dict['itr'],),
                                             nrow=int(D_fake.shape[0] ** 0.5),
                                             normalize=True)
                torchvision.utils.save_image(G_z.float().cpu(),
                                             '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_pre_G_z.jpg'.format(
                                                 time,
                                                 state_dict['itr'],),
                                             nrow=int(D_fake.shape[0] ** 0.5),
                                             normalize=True)

            # Optionally apply modified ortho reg in G
            if config['G_ortho'] > 0.0:
                print('using modified ortho reg in G')  # Debug print to indicate we're using ortho reg in G
                # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
                utils.ortho(G, config['G_ortho'],
                            blacklist=[param for param in G.shared.parameters()])
            G.optim.step()
            # G.optim.zero_grad()
        else:
            for fb_iter in range(config['num_feedback_iter']):
                zs_.sample_()
                z_ = zs_[:config['batch_size']].view(zs_.size(0), 24, 32, 32)[:, :20]
                ys_.sample_()
                gy = ys_

                if fb_iter <= 1:
                    # Early feedback iterations: condition on the real image plus a
                    # constant placeholder channel (no feedback available yet).
                    init_x = x[:, :3]
                    z_ = torch.cat([z_, init_x, torch.ones(zs_.size(0), 1, 32, 32).cuda()], 1)
                else:
                    # Later feedback iterations: blend the previous generator output with
                    # the real image and append the upsampled discriminator feedback map.
                    g_fake = 0.05 * g_fakes + 0.95 * init_x
                    d_fakes = nn.functional.interpolate(d_fakes, 32, mode='bilinear')
                    z_ = torch.cat([z_, g_fake, d_fakes], 1)
                    if ((not (state_dict['itr'] % config['save_every'])) or (not (state_dict['itr'] % config['test_every']))):
                        partial_test_input = partial_test_input + torch.cat([g_fake, d_fakes], 1)
                        inner_iter_count = inner_iter_count + 1
                D_fake, G_z = GD(z=z_, gy=gy, train_G=True, split_D=config['split_D'], return_G_z=True)

                # Generator loss (the optional feedback-enforcement term is currently disabled).
                G_loss = losses.generator_loss(D_fake)

                G_loss.backward()

                if state_dict['itr'] % 1000 == 0:
                    print('saving img')
                    torchvision.utils.save_image(G_z.float().cpu(),
                                               '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_fb{}_G_z.jpg'.format(time,
                                                   state_dict['itr'], fb_iter),
                                               nrow=int(D_fake.shape[0] ** 0.5),
                                               normalize=True)
                    if fb_iter > 1:
                        torchvision.utils.save_image(g_fake.float().cpu(),
                                               '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_fb{}_G_z_input.jpg'.format(time,
                                                   state_dict['itr'], fb_iter),
                                               nrow=int(D_fake.shape[0] ** 0.5),
                                               normalize=True)

                # Cache detached feedback signals for the next feedback iteration.
                g_fakes = G_z.detach()
                d_fakes = D_fake.detach()

                # Optionally apply modified ortho reg in G
                if config['G_ortho'] > 0.0:
                    print('using modified ortho reg in G') # Debug print to indicate we're using ortho reg in G
                    # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
                    utils.ortho(G, config['G_ortho'],
                                      blacklist=[param for param in G.shared.parameters()])
                G.optim.step()

        # If we have an ema, update it, regardless of if we test with it or not
        if config['ema']:
          ema.update(state_dict['itr'])

        out = {'G_loss': float(G_loss.item()),
                'D_loss_real': float(D_loss_real.item()),
                'D_loss_fake': float(D_loss_fake.item())}
        # Return G's loss and the components of D's loss.

        partial_test_input = partial_test_input / (inner_iter_count + 1e-9)
        return out, partial_test_input
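
A note on the loss helpers: the loop above calls losses.discriminator_loss and losses.generator_loss without showing them. In BigGAN-PyTorch-style training code these are typically the hinge losses; the sketch below is a minimal stand-in under that assumption, not the project's actual losses module.

import torch
import torch.nn.functional as F

def discriminator_loss(dis_fake, dis_real):
    # Hinge loss for D: push real scores above +1 and fake scores below -1.
    loss_real = torch.mean(F.relu(1. - dis_real))
    loss_fake = torch.mean(F.relu(1. + dis_fake))
    return loss_real, loss_fake

def generator_loss(dis_fake):
    # Hinge loss for G: raise D's score on generated samples.
    return -torch.mean(dis_fake)
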
Example #22
0
    def train(x, y):
        G.optim.zero_grad()
        D.optim.zero_grad()
        E.optim.zero_grad()

        # How many chunks to split x and y into?
        x = torch.split(x, config['batch_size'])
        y = torch.split(y, config['batch_size'])
        # print("inside fns", x)
        print("split - x {}".format(len(x)))
        counter = 0

        # Optionally toggle D and G's "require_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, True)
            utils.toggle_grad(G, False)
            utils.toggle_grad(E, False)
        # print("inside train fns: config['num_D_steps']", config['num_D_steps'])
        for step_index in range(config['num_D_steps']):
            # If accumulating gradients, loop multiple times before an optimizer step
            D.optim.zero_grad()
            # print("---------------------- counter {} ---------------".format(counter))
            # print("x[counter] {}; y[counter] {}".format(x[counter].shape, y[counter].shape))
            for accumulation_index in range(config['num_D_accumulations']):
                # Corner case for the last batch
                if counter >= len(x):
                    break
                D_fake, D_real = GDE(x[counter], y[counter], config, state_dict['itr'], img_pool, train_G=False,
                                    split_D=config['split_D'])
                # Compute components of D's loss, average them, and divide by
                # the number of gradient accumulations
                D_loss_real, D_loss_fake = losses.discriminator_loss( \
                    D_fake, D_real, config['clip'])
                D_loss = (D_loss_real + D_loss_fake) / \
                    float(config['num_D_accumulations'])
                print("D_loss: {}; D_fake {}, D_real {}".format(D_loss.item(), D_loss_fake.item(), D_loss_real.item()))
                D_loss.backward()
                counter += 1

            # Optionally apply ortho reg in D
            if config['D_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in D.
                print('using modified ortho reg in D')
                utils.ortho(D, config['D_ortho'])
            
            # stop gradient for testing purpose
            if config['stop_gradient']:
                print("!!! D is not optimized since you turn on `stop_gradient`!!!!!!")
            else:
                D.optim.step()

        # Optionally toggle "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, False)
            utils.toggle_grad(G, True)
            utils.toggle_grad(E, True)

        # Zero G/E's gradients by default before training G, for safety
        G.optim.zero_grad()
        E.optim.zero_grad()
        # If accumulating gradients, loop multiple times
        counter = 0 # reset counter for data split
        for accumulation_index in range(config['num_G_accumulations']):
            if counter >= len(x):
                break
            # print("---------------------- counter {} ---------------".format(counter))
            output = GDE(x[counter], y[counter], config, state_dict['itr'], img_pool, train_G=True, split_D=config['split_D'], return_G_z=True)
            D_fake = output[0]
            G_z = output[2]
            mu, log_var = output[3], output[4]
            if len(output) == 6:
                G_additional = output[5]
            # print("checkpoint==========================")
            G_loss = losses.generator_loss(
                D_fake) / float(config['num_G_accumulations'])
            VAE_recon_loss = losses.vae_recon_loss(G_z, x[counter])
            VAE_kld_loss = losses.vae_kld_loss(mu, log_var, config['clip'])
            GE_loss = G_loss + VAE_recon_loss * config['lambda_vae_recon'] + VAE_kld_loss * config['lambda_vae_kld']
            log_loss_str = f"GE_loss {GE_loss.item()}; VAE_recon_loss {VAE_recon_loss.item()}; VAE_kld_loss {VAE_kld_loss.item()}"

            # add G_additional loss
            if len(output) == 6:
                G_additional_loss = config['lambda_g_additional'] * G_additional.sum()
                GE_loss += G_additional_loss
                log_loss_str += f"G_additional {G_additional_loss.item()}"
            
            # print out loss
            print(log_loss_str)
            
            # optimization
            GE_loss.backward()
            counter += 1


        # Optionally apply modified ortho reg in G
        if config['G_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in G
            print('using modified ortho reg in G')
            # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
            utils.ortho(G, config['G_ortho'],
                        blacklist=[param for param in G.shared.parameters()])
        
        # stop gradient for testing purpose
        if config['stop_gradient']:
            print("!!! G and E is not optimized since you turn on `stop_gradient`!!!!!!")
        else:
            G.optim.step()
            E.optim.step()

        # If we have an ema, update it, regardless of if we test with it or not
        if config['ema']:
            ema.update(state_dict['itr'])

        out = {'G_loss': float(G_loss.item()),
               'D_loss_real': float(D_loss_real.item()),
               'D_loss_fake': float(D_loss_fake.item()),
               'VAE_recon_loss': float(VAE_recon_loss.item()),
               'VAE_KLD_loss': float(VAE_kld_loss.item())}
        # Return G's loss and the components of D's loss.
        return out
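
Example #22 adds losses.vae_recon_loss and losses.vae_kld_loss on top of the GAN loss, but their definitions are not shown. The sketch below is a plausible minimal version, assuming an L1 reconstruction term and the standard Gaussian KL divergence; how the clip argument is applied here is a guess, and the project's real implementation may differ.

import torch

def vae_recon_loss(recon_x, x):
    # L1 reconstruction error between the decoded image and the target.
    return torch.mean(torch.abs(recon_x - x))

def vae_kld_loss(mu, log_var, clip=None):
    # KL divergence between N(mu, exp(log_var)) and the standard normal prior.
    kld = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
    if clip is not None:
        kld = torch.clamp(kld, max=clip)  # assumed use of the clip argument
    return kld
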
def find_lr(data, batch_size=1, start_lr=1e-9, end_lr=1):
    generator = Generator(input_shape=[None,None,2])
    discriminator = Discriminator(input_shape=[None,None,1])

    generator_optimizer = tf.keras.optimizers.Adam(lr=start_lr)
    discriminator_optimizer = tf.keras.optimizers.Adam(lr=start_lr)

    model_name = data['training'].origin+'_2_any'
    checkpoint_prefix = os.path.join(CHECKPOINT_DIR, model_name)
    if(not os.path.isdir(checkpoint_prefix)):
        os.makedirs(checkpoint_prefix)

    epoch_size = data['training'].__len__()
    lr_mult = (end_lr / start_lr) ** (1 / epoch_size)

    lrs = []
    losses = {
        'gen_mae': [],
        'gen_loss': [],
        'disc_loss': []
    }
    best_losses = {
        'gen_mae': 1e9,
        'gen_loss': 1e9,
        'disc_loss': 1e9
    }

    print()
    print("Finding the optimal LR with the following parameters: ")
    print("\tCheckpoints: \t", checkpoint_prefix)
    print("\tEpochs: \t", 1)
    print("\tBatchSize: \t", batch_size)
    print("\tnBatches: \t", epoch_size)
    print()    

    print('Epoch {}/{}'.format(1, 1))
    progbar = tf.keras.utils.Progbar(epoch_size)
    for i in range(epoch_size):
        # Get the data from the DataGenerator
        input_image, target = data['training'].__getitem__(i) 
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            # Generate a fake image
            gen_output = generator(input_image, training=True)
            
            # Train the discriminator
            disc_real_output = discriminator([input_image[:,:,:,0:1], target], training=True)
            disc_generated_output = discriminator([input_image[:,:,:,0:1], gen_output], training=True)
            
            # Compute the losses
            gen_mae = l1_loss(target, gen_output)
            gen_loss = generator_loss(disc_generated_output, gen_mae)
            disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

            # Compute the gradients
            generator_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
            discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
            
            # Apply the gradients
            generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
            discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
            
            # Convert losses to numpy 
            gen_mae = gen_mae.numpy()
            gen_loss = gen_loss.numpy()
            disc_loss = disc_loss.numpy()
            
            # Update the progress bar
            progbar.add(1, values=[
                                    ("gen_mae", gen_mae), 
                                    ("gen_loss", gen_loss), 
                                    ("disc_loss", disc_loss)
                                ])
            
            # On batch end
            lr = tf.keras.backend.get_value(generator_optimizer.lr)
            lrs.append(lr)
            
            # Update the lr
            lr *= lr_mult
            tf.keras.backend.set_value(generator_optimizer.lr, lr)
            tf.keras.backend.set_value(discriminator_optimizer.lr, lr)
            
            # Update the losses
            losses['gen_mae'].append(gen_mae)
            losses['gen_loss'].append(gen_loss)
            losses['disc_loss'].append(disc_loss)
            
            # Update the best losses
            if(best_losses['gen_mae'] > gen_mae):
                best_losses['gen_mae'] = gen_mae
            if(best_losses['gen_loss'] > gen_loss):
                best_losses['gen_loss'] = gen_loss
            if(best_losses['disc_loss'] > disc_loss):
                best_losses['disc_loss'] = disc_loss
            if(gen_mae >= 100*best_losses['gen_mae'] or gen_loss >= 100*best_losses['gen_loss'] or disc_loss >= 100*best_losses['disc_loss']):
                break

    plot_loss_findlr(losses['gen_mae'], lrs, os.path.join(checkpoint_prefix, 'LRFinder_gen_mae.tiff'))
    plot_loss_findlr(losses['gen_loss'], lrs, os.path.join(checkpoint_prefix, 'LRFinder_gen_loss.tiff'))
    plot_loss_findlr(losses['disc_loss'], lrs, os.path.join(checkpoint_prefix, 'LRFinder_disc_loss.tiff'))

    print('Best losses:')
    print('gen_mae =', best_losses['gen_mae'])
    print('gen_loss =', best_losses['gen_loss'])
    print('disc_loss =', best_losses['disc_loss'])
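
find_lr sweeps the learning rate geometrically from start_lr to end_lr over exactly one epoch: each batch multiplies the LR by lr_mult = (end_lr / start_lr) ** (1 / epoch_size). A quick sanity check of that schedule (the numbers are illustrative, not taken from the code above):

import numpy as np

start_lr, end_lr, epoch_size = 1e-9, 1.0, 500  # illustrative values
lr_mult = (end_lr / start_lr) ** (1 / epoch_size)

# Applying lr_mult once per batch walks the LR from start_lr up to end_lr.
lrs = start_lr * lr_mult ** np.arange(epoch_size + 1)
assert np.isclose(lrs[-1], end_lr)
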
Example #24
0
    def generator_train_step(self, gen_prev_images, gen_next_images,
                             gen_prev_images_gt, gen_next_images_gt,
                             gen_images, gen_event_volume):
        gen_model = self.models_dict['gen']
        dis_model = self.models_dict['dis']
        if self.options.cycle_recons:
            e2i_model = self.models_dict['e2i']
        if self.options.cycle_flow:
            e2f_model = self.models_dict['e2f']

        losses = {}
        outputs = {}
        g_loss = 0.
        # Train generator.
        # Generator output.
        gen_fake_volume = gen_model(gen_images)

        if not self.options.no_train_gan:
            # Get discriminator prediction.
            classification = dis_model(gen_fake_volume[::-1], gen_images)
            # Compute GAN loss.
            g_loss += generator_loss("hinge", classification)
            losses['generator'] = g_loss

        cycle_loss = 0.
        # cycle consistency loss.
        if self.options.cycle_recons:
            e2i_input = torch.sum(gen_fake_volume[-1], dim=1, keepdim=True)
            e2i_input = torch.cat([e2i_input, gen_prev_images], dim=1)
            recons_image_list = e2i_model(e2i_input)

            reconstruction_loss = [self.image_loss(F.interpolate(r_img, gen_next_images.shape[2:]),
                                                   gen_next_images) \
                                   for r_img in recons_image_list]

            reconstruction_loss = torch.sum(torch.stack(
                [loss * 2. ** (i - len(reconstruction_loss) + 1) \
                 for i, loss in enumerate(reconstruction_loss)]))
            #    recons_image,
            cycle_loss += self.options.cycle_recons_weight * reconstruction_loss
            losses['cycle_reconstruction_loss'] = reconstruction_loss
            outputs['reconstructed_image'] = (recons_image_list[-1] + 1.) / 2.
        if self.options.cycle_flow:
            flow_mask = torch.sum(gen_event_volume, 1) > 0
            e2f_input = gen_fake_volume[-1]
            flow_output = e2f_model(e2f_input)
            photo_loss, smooth_loss, _ = multi_scale_flow_loss(
                gen_prev_images_gt,
                gen_next_images_gt,
                flow_output,
                flow_mask,
                second_order_smooth=False)
            flow_loss = photo_loss
            if not self.options.no_flow_smoothness:
                flow_loss += smooth_loss * self.options.smooth_weight
            cycle_loss += self.options.cycle_flow_weight * flow_loss
            losses['cycle_flow_loss'] = flow_loss
            outputs['cycle_flow'] = flow2rgb(flow_output[-1])
        g_loss += cycle_loss

        # Other outputs for visualization.
        outputs.update(
            gen_event_images(gen_fake_volume[-1], 'gen', self.device,
                             self.options.normalize_events))
        outputs.update(
            gen_event_images(gen_event_volume, 'raw', self.device,
                             self.options.normalize_events))
        outputs['gt_gray'] = (gen_next_images + 1.) / 2.
        outputs['gen_event_hist'] = gen_fake_volume[-1]
        outputs['raw_event_hist'] = gen_event_volume
        return g_loss, losses, outputs
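
Example #24 calls generator_loss("hinge", classification), where classification is whatever the discriminator returns for the generated event volumes. The helper itself is not shown; the sketch below is one consistent reading, assuming the discriminator may return either a single score tensor or a list of per-scale scores. It is an illustration, not the repository's implementation.

import torch.nn.functional as F

def generator_loss(mode, dis_outputs):
    # Accept a single discriminator output or a list of per-scale outputs.
    if not isinstance(dis_outputs, (list, tuple)):
        dis_outputs = [dis_outputs]
    loss = 0.
    for d_out in dis_outputs:
        if mode == "hinge":
            loss = loss - d_out.mean()
        elif mode == "vanilla":
            loss = loss + F.softplus(-d_out).mean()
        else:
            raise ValueError("unknown GAN loss mode: {}".format(mode))
    return loss / len(dis_outputs)
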
Example #25
0
    def train(x, y):
        G.optim.zero_grad()
        D.optim.zero_grad()
        # How many chunks to split x and y into?
        x = torch.split(x, config['batch_size'])
        y = torch.split(y, config['batch_size'])
        counter = 0

        # Optionally toggle D and G's "require_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, True)
            utils.toggle_grad(G, False)

        for step_index in range(config['num_D_steps']):
            # If accumulating gradients, loop multiple times before an
            # optimizer step
            D.optim.zero_grad()

            for accumulation_index in range(config['num_D_accumulations']):
                z_, y_ = sample()
                D_fake, D_real = GD(z_[:config['batch_size']], y_[:config['batch_size']],
                                    x[counter], y[counter], train_G=False,
                                    split_D=config['split_D'])
                # Compute components of D's loss, average them, and divide by
                # the number of gradient accumulations
                D_loss_real, D_loss_fake = losses.discriminator_loss(
                    D_fake, D_real)
                D_loss = (D_loss_real + D_loss_fake) / \
                    float(config['num_D_accumulations'])
                D_loss.backward()
                counter += 1

            # Optionally apply ortho reg in D
            if config['D_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in D.
                xm.master_print('using modified ortho reg in D')
                utils.ortho(D, config['D_ortho'])

            xm.optimizer_step(D.optim)

        # Optionally toggle "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, False)
            utils.toggle_grad(G, True)

        # Zero G's gradients by default before training G, for safety
        G.optim.zero_grad()

        # If accumulating gradients, loop multiple times
        for accumulation_index in range(config['num_G_accumulations']):

            z_, y_ = sample()
            D_fake = GD(z_, y_, train_G=True, split_D=config['split_D'])
            G_loss = losses.generator_loss(
                D_fake) / float(config['num_G_accumulations'])
            G_loss.backward()

        # Optionally apply modified ortho reg in G
        if config['G_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in G
            print('using modified ortho reg in G')
            # Don't ortho reg shared, it makes no sense. Really we should
            # blacklist any embeddings for this
            utils.ortho(G, config['G_ortho'],
                        blacklist=[param for param in G.shared.parameters()])
        xm.optimizer_step(G.optim)

        # If we have an ema, update it, regardless of if we test with it or not
        if config['ema']:
            ema.update(state_dict['itr'])

        out = {'G_loss': G_loss,
               'D_loss_real': D_loss_real,
               'D_loss_fake': D_loss_fake}
        # Return G's loss and the components of D's loss.
        return out
Example #26
0
def training_step_gen(generator_zebras, generator_horses, discriminator_zebras,
                      discriminator_horses, images_zebras, images_horses,
                      optimizer_zebras, optimizer_horses, lambda_factor):

    #clarification: generator_zebras generates zebra images from horses
    #Calculate the loss for both generators and update the weights
    with tf.GradientTape() as tape_horse, tf.GradientTape() as tape_zebra:

        #feed original images to generators
        fake_images_zebras = generator_zebras(images_horses)
        fake_images_horses = generator_horses(images_zebras)

        #get the assigned prediction from the discriminators
        fake_image_predictions_zebras = discriminator_zebras(
            fake_images_zebras)
        fake_image_predictions_horses = discriminator_horses(
            fake_images_horses)

        #calculate the adversarial generator loss:
        #did the discriminator recognize the images as generated?
        gen_loss_zebras = losses.generator_loss(fake_image_predictions_zebras)
        gen_loss_horses = losses.generator_loss(fake_image_predictions_horses)

        #pass the generated zebra images of generator_zebras to generator_horses
        #(to see if it produces horse images close to the original image)
        recreated_images_horses = generator_horses(fake_images_zebras)
        recreated_images_zebras = generator_zebras(fake_images_horses)

        #calculate cycle loss: the weighting factor lambda is set to 10
        #how much does the original image differ from the cycled image?
        cycle_loss_forward = losses.calc_cycle_loss(images_zebras,
                                                    recreated_images_zebras,
                                                    lambda_factor)
        cycle_loss_backward = losses.calc_cycle_loss(images_horses,
                                                     recreated_images_horses,
                                                     lambda_factor)
        total_cycle_loss = cycle_loss_forward + cycle_loss_backward

        #give images from their target domain to the generators
        # e.g. give zebra images to a zebra generator and then see if the output
        #images are close to original images -> identity loss
        same_images_reconstructed_zebras = generator_zebras(images_zebras)
        same_images_reconstructed_horses = generator_horses(images_horses)

        identity_loss_horses = losses.identity_loss(
            images_horses, same_images_reconstructed_horses, lambda_factor)
        identity_loss_zebras = losses.identity_loss(
            images_zebras, same_images_reconstructed_zebras, lambda_factor)

        # sum up the losses for each generator
        # this means the respective generator and identity loss (for their domain)
        # but also the complete cycle consistency loss!
        total_loss_zebras = gen_loss_zebras + total_cycle_loss + identity_loss_zebras
        total_loss_horses = gen_loss_horses + total_cycle_loss + identity_loss_horses

        #compute gradients of each generator's total loss (weights are applied below)
        gradients_zebras = tape_zebra.gradient(
            total_loss_zebras, generator_zebras.trainable_variables)
        gradients_horses = tape_horse.gradient(
            total_loss_horses, generator_horses.trainable_variables)

    #update weights
    optimizer_zebras.apply_gradients(
        zip(gradients_zebras, generator_zebras.trainable_variables))
    optimizer_horses.apply_gradients(
        zip(gradients_horses, generator_horses.trainable_variables))

    #return loss and generated images for the buffer
    return total_loss_zebras, total_loss_horses, fake_images_zebras, fake_images_horses
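
Example #26's losses.calc_cycle_loss and losses.identity_loss are not shown either. The sketch below follows the standard CycleGAN formulation (lambda-scaled L1 distances, with the identity term at half weight, as in the common TensorFlow CycleGAN tutorial); the project's own definitions may weight these differently.

import tensorflow as tf

def calc_cycle_loss(real_image, cycled_image, lambda_factor):
    # Cycle consistency: the twice-translated image should match the original.
    return lambda_factor * tf.reduce_mean(tf.abs(real_image - cycled_image))

def identity_loss(real_image, same_image, lambda_factor):
    # Identity: an image already from the target domain should pass through unchanged.
    return 0.5 * lambda_factor * tf.reduce_mean(tf.abs(real_image - same_image))
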