Example #1
def create_discriminator_criterion(args):
    # Build the discriminator, move it to the GPU, and wrap it in DataParallel
    # so its forward pass is split across devices.
    d = discriminator.Discriminator(outputs_size=1000, K=8).cuda()
    d = torch.nn.DataParallel(d)
    # Parameter group for the discriminator's optimizer, with its own learning rate.
    update_parameters = {'params': d.parameters(), 'lr': args.d_lr}
    # Wrap the discriminator in its adversarial criterion; parallelize the
    # criterion itself when more than one GPU is configured.
    discriminators_criterion = discriminatorLoss(d).cuda()
    if len(args.gpus) > 1:
        discriminators_criterion = torch.nn.DataParallel(discriminators_criterion,
                                                         device_ids=args.gpus)
    return discriminators_criterion, update_parameters
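
A minimal usage sketch, assuming a standard torch.optim optimizer (the surrounding training script is not shown in the original, so the wiring below is hypothetical): the returned parameter-group dict already carries its own learning rate, so it can be handed straight to the optimizer constructor.

import torch

# Hypothetical wiring: the criterion is used as the adversarial loss module,
# while the parameter group drives a separate discriminator optimizer.
discriminators_criterion, update_parameters = create_discriminator_criterion(args)
d_optimizer = torch.optim.SGD([update_parameters], lr=args.d_lr, momentum=0.9)
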
Example #2
    def train(self, epoch, saveModel=False):
        """
        Training logic for an epoch
        """
        self.model_G.train()
        self.model_D.train()
        start = time.time()
        if epoch > 10:
            # After epoch 10, take two generator steps per batch.
            self.G_iteration = 2

        for batch_idx, real in enumerate(self.data_loader):
            # Skip the final, smaller batch so every update sees a full batch.
            if len(real) != self.batch_size:
                break

            if self.with_cuda:
                real = real.cuda()
            real = Variable(real)  # pre-0.4 PyTorch wrapper (torch.autograd.Variable)

            # Take D_iteration discriminator steps per batch.
            for _ in range(self.D_iteration):
                self.optimizer_D.zero_grad()
                real_logits = self.model_D(real)

                noise = Variable(
                    self.random_generator((self.batch_size, 100), True))
                fake = self.model_G(noise).detach()  # prevent G from updating
                fake_logits = self.model_D(fake)

                d_loss = discriminatorLoss(real_logits, fake_logits)
                d_loss.backward()
                #clip_grad_norm(self.model_D.parameters(), 15)
                self.optimizer_D.step()

            # Take G_iteration generator steps per batch.
            for _ in range(self.G_iteration):
                self.optimizer_G.zero_grad()

                noise = Variable(
                    self.random_generator((self.batch_size, 100), True))
                fake = self.model_G(noise)
                fake_logits = self.model_D(fake)

                g_loss = generatorLoss(fake_logits)
                g_loss.backward()
                #clip_grad_norm(self.model_G.parameters(), 30)
                self.optimizer_G.step()

            if batch_idx % self.log_step == 0:
                info = self.get_training_info(
                    epoch=epoch,
                    batch_id=batch_idx,
                    batch_size=self.batch_size,
                    total_data_size=len(self.data_loader.dataset),
                    n_batch=len(self.data_loader),
                    d_loss=d_loss.data[0],  # .data[0] is the pre-0.4 spelling of .item()
                    g_loss=g_loss.data[0],
                    g_normD=self.get_gradient_norm(self.model_D),
                    g_normG=self.get_gradient_norm(self.model_G))
                print('\r', info, end='')  # original: end='\r'
        if saveModel:
            print()
            model_dir = "saved"
            print('Saving model', "{}/epoch{}_G.pt".format(model_dir, epoch))
            torch.save(self.model_G.state_dict(),
                       "{}/epoch{}_G.pt".format(model_dir, epoch))
            # torch.save(self.model_D.state_dict(), "{}/epoch{}_D.pt".format(model_dir, epoch))
        print()
        print("Training time: ", int(time.time() - a), 'seconds/epoch')
Example #3
# 5-D placeholder for the padded validation volumes.
X_val_tensor = tf.placeholder(tf.float32,
                              shape=[
                                  None, X_val_padded.shape[1],
                                  X_val_padded.shape[2], X_val_padded.shape[3],
                                  X_val_padded.shape[4]
                              ],
                              name='X_val')
is_train = tf.placeholder(tf.bool, name='is_train')

# Networks: the second argument presumably enables TF variable reuse, so the
# validation generator and the fake-branch discriminator share weights with
# their training-time counterparts.
Y_generated = network.getGenerator(X_tensor)
Y_val_generated = network.getGenerator(X_val_tensor, True)

D_logits_real = network.getDiscriminator(X_tensor, Y_tensor)
D_logits_fake = network.getDiscriminator(X_tensor, Y_generated, True)

# Losses and optimizer
D_loss = loss.discriminatorLoss(D_logits_real, D_logits_fake, labelSmoothing)
G_loss, G_gan, G_L1 = loss.generatorLoss(D_logits_fake, Y_generated, Y_tensor,
                                         L1_Weight)
optimizer = loss.getOptimizer(lr, beta1, D_loss, G_loss)
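
# loss.discriminatorLoss and loss.generatorLoss are project helpers not shown
# in this snippet. A sketch of what the discriminator side with one-sided
# label smoothing could look like in TF 1.x (an assumption, not the
# repository's actual code):
#
#   def discriminatorLoss(D_logits_real, D_logits_fake, labelSmoothing):
#       real = tf.nn.sigmoid_cross_entropy_with_logits(
#           labels=tf.fill(tf.shape(D_logits_real), 1.0 - labelSmoothing),
#           logits=D_logits_real)
#       fake = tf.nn.sigmoid_cross_entropy_with_logits(
#           labels=tf.zeros_like(D_logits_fake), logits=D_logits_fake)
#       return tf.reduce_mean(real) + tf.reduce_mean(fake)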

# Tensorboard
halfVolume = int(X_tensor.shape[2]) // 2  # middle slice of the volume, logged as a 2-D image
train_summaries = [
    tf.summary.scalar('D_loss', D_loss),
    tf.summary.scalar('G_loss', G_gan),
    tf.summary.scalar('L1_loss', G_L1),
    tf.summary.image('in', X_tensor[:, :, :, halfVolume], max_outputs=1),
    tf.summary.image('out', Y_generated[:, :, :, halfVolume], max_outputs=1),
    tf.summary.image('label', Y_tensor[:, :, :, halfVolume], max_outputs=1)
]
train_merged_summaries = tf.summary.merge(train_summaries)
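
# Hypothetical session wiring for the merged summaries (a sketch; the training
# loop itself is not part of this snippet):
#
#   writer = tf.summary.FileWriter(log_dir, tf.get_default_graph())
#   summ, _ = sess.run([train_merged_summaries, optimizer],
#                      feed_dict={X_tensor: x_batch, Y_tensor: y_batch,
#                                 is_train: True})
#   writer.add_summary(summ, global_step=step)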

# Validation metrics are fed in from outside the graph purely for logging.
rmseTensor = tf.placeholder(tf.float32, shape=())
ddrmseTensor = tf.placeholder(tf.float32, shape=())
val_summaries = [
    tf.summary.scalar('rmse', rmseTensor),