Example #1
class Solver(object):
    def __init__(self, face_data_loader, config):
        # Data loader
        self.face_data_loader = face_data_loader

        # Model parameters
        self.y_dim = config.y_dim
        self.num_layers = config.num_layers
        self.im_size = config.im_size
        self.g_first_dim = config.g_first_dim
        self.d_first_dim = config.d_first_dim
        self.enc_repeat_num = config.enc_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.d_train_repeat = config.d_train_repeat

        # Hyper-parameters
        self.lambda_cls = config.lambda_cls
        self.lambda_id = config.lambda_id
        self.lambda_bi = config.lambda_bi
        self.lambda_gp = config.lambda_gp
        self.enc_lr = config.enc_lr
        self.dec_lr = config.dec_lr
        self.d_lr = config.d_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Training settings
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.batch_size = config.batch_size
        self.trained_model = config.trained_model

        # Test settings
        self.test_model = config.test_model

        # Path
        self.log_path = config.log_path
        self.sample_path = config.sample_path
        self.model_path = config.model_path
        self.test_path = config.test_path

        # Step size
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step

        # Build the model and set up TensorBoard logging
        self.build_model()
        self.use_tensorboard()

        # Start with trained model
        if self.trained_model:
            self.load_trained_model()

    def build_model(self):
        # Define encoder-decoder (generator) and a discriminator
        self.Enc = Encoder(self.g_first_dim, self.enc_repeat_num)
        self.Dec = Decoder(self.g_first_dim)
        self.D = Discriminator(self.im_size, self.d_first_dim,
                               self.d_repeat_num)

        # Optimizers
        self.enc_optimizer = torch.optim.Adam(self.Enc.parameters(),
                                              self.enc_lr,
                                              [self.beta1, self.beta2])
        self.dec_optimizer = torch.optim.Adam(self.Dec.parameters(),
                                              self.dec_lr,
                                              [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.Enc.cuda()
            self.Dec.cuda()
            self.D.cuda()

    def load_trained_model(self):

        self.Enc.load_state_dict(
            torch.load(
                os.path.join(self.model_path,
                             '{}_Enc.pth'.format(self.trained_model))))
        self.Dec.load_state_dict(
            torch.load(
                os.path.join(self.model_path,
                             '{}_Dec.pth'.format(self.trained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.model_path,
                             '{}_D.pth'.format(self.trained_model))))
        print('loaded models (step: {})..!'.format(self.trained_model))

    def use_tensorboard(self):
        from tensorboard_logger import Logger
        self.logger = Logger(self.log_path)

    def update_lr(self, enc_lr, dec_lr, d_lr):
        for param_group in self.enc_optimizer.param_groups:
            param_group['lr'] = enc_lr
        for param_group in self.dec_optimizer.param_groups:
            param_group['lr'] = dec_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset(self):
        self.enc_optimizer.zero_grad()
        self.dec_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def calculate_accuracy(self, x, y):
        _, predicted = torch.max(x, dim=1)
        correct = (predicted == y).float()
        accuracy = torch.mean(correct) * 100.0
        return accuracy

    def denorm(self, x):
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def one_hot(self, labels, dim):
        """Convert label indices to one-hot vector"""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def train(self):
        """Train attribute-guided face image synthesis model"""
        self.data_loader = self.face_data_loader
        # The number of iterations for each epoch
        iters_per_epoch = len(self.data_loader)

        sample_x = []
        sample_l = []
        real_y = []
        for i, (images, landmark) in enumerate(self.data_loader):
            labels = images[1]
            sample_x.append(images[0])
            sample_l.append(landmark[0])
            real_y.append(labels)
            if i == 2:
                break

        # Sample inputs and desired domain labels for testing
        sample_x = torch.cat(sample_x, dim=0)
        sample_x = self.to_var(sample_x, volatile=True)
        sample_l = torch.cat(sample_l, dim=0)
        sample_l = self.to_var(sample_l, volatile=True)
        real_y = torch.cat(real_y, dim=0)

        sample_y_list = []
        for i in range(self.y_dim):
            sample_y = self.one_hot(
                torch.ones(sample_x.size(0)) * i, self.y_dim)
            sample_y_list.append(self.to_var(sample_y, volatile=True))

        # Learning rate for decaying
        d_lr = self.d_lr
        enc_lr = self.enc_lr
        dec_lr = self.dec_lr

        # Start with trained model
        if self.trained_model:
            start = int(self.trained_model.split('_')[0])
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (real_image, real_landmark) in enumerate(self.data_loader):
                # real_x: real image, real_l: conditional side image (landmark heatmap)
                real_x = real_image[0]
                real_label = real_image[1]
                real_l = real_landmark[0]

                # Sample fake labels randomly
                rand_idx = torch.randperm(real_label.size(0))
                fake_label = real_label[rand_idx]

                real_y = self.one_hot(real_label, self.y_dim)
                fake_y = self.one_hot(fake_label, self.y_dim)

                # Convert tensor to variable
                real_x = self.to_var(real_x)
                real_l = self.to_var(real_l)
                real_y = self.to_var(real_y)
                fake_y = self.to_var(fake_y)
                real_label = self.to_var(real_label)
                fake_label = self.to_var(fake_label)

                #================== Train Discriminator ================== #
                # Input images (original image + side image) are concatenated
                src_output, cls_output = self.D(torch.cat([real_x, real_l], 1))
                d_loss_real = -torch.mean(src_output)
                d_loss_cls = F.cross_entropy(cls_output, real_label)

                # Compute expression recognition accuracy on real images
                if (i + 1) % self.log_step == 0:
                    accuracy = self.calculate_accuracy(cls_output, real_label)
                    print('Recognition Acc: {:.2f}'.format(accuracy.data[0]))

                # Generate outputs and compute loss with fake generated images
                enc_feat = self.Enc(torch.cat([real_x, real_l], 1))
                fake_x, fake_l = self.Dec(enc_feat, fake_y)
                fake_x = Variable(fake_x.data)
                fake_l = Variable(fake_l.data)

                src_output, cls_output = self.D(torch.cat([fake_x, fake_l], 1))
                d_loss_fake = torch.mean(src_output)

                # Discriminator losses
                d_loss = self.lambda_cls * d_loss_cls + d_loss_real + d_loss_fake
                self.reset()
                d_loss.backward()
                self.d_optimizer.step()

                # Compute gradient penalty loss
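                # (WGAN-GP, Gulrajani et al. 2017: the critic is evaluated at
                # random interpolates of real/fake pairs and its gradient norm
                # is penalized for deviating from 1.)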
                real = torch.cat([real_x, real_l], 1)
                fake = torch.cat([fake_x, fake_l], 1)
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real)
                interpolated = Variable(alpha * real.data +
                                        (1 - alpha) * fake.data,
                                        requires_grad=True)
                output, cls_output = self.D(interpolated)

                grad = torch.autograd.grad(outputs=output,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               output.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Gradient penalty loss
                d_loss = self.lambda_gp * d_loss_gp
                self.reset()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging
                loss = {}
                loss['D/loss_real'] = d_loss_real.data[0]
                loss['D/loss_fake'] = d_loss_fake.data[0]
                loss['D/loss_cls'] = d_loss_cls.data[0]
                loss['D/loss_gp'] = d_loss_gp.data[0]

                # ================== Train Encoder-Decoder networks ================== #
                if (i + 1) % self.d_train_repeat == 0:

                    # Original-to-target and target-to-original domain
                    enc_feat = self.Enc(torch.cat([real_x, real_l], 1))
                    fake_x, fake_l = self.Dec(enc_feat, fake_y)
                    src_output, cls_output = self.D(
                        torch.cat([fake_x, fake_l], 1))
                    g_loss_fake = -torch.mean(src_output)

                    #rec_feat = self.Enc(fake_x)
                    rec_feat = self.Enc(torch.cat([fake_x, fake_l], 1))
                    rec_x, rec_l = self.Dec(rec_feat, real_y)

                    # Bidirectional loss of the images
                    g_loss_rec_x = torch.mean(torch.abs(real_x - rec_x))
                    g_loss_rec_l = torch.mean(torch.abs(real_l - rec_l))

                    # Bidirectional loss of the latent feature
                    g_loss_feature = torch.mean(torch.abs(enc_feat - rec_feat))

                    # Identity loss of the images
                    g_loss_identity_x = torch.mean(torch.abs(real_x - fake_x))
                    g_loss_identity_l = torch.mean(torch.abs(real_l - fake_l))

                    # attribute classification loss for the fake generated images
                    g_loss_cls = F.cross_entropy(cls_output, fake_label)

                    # Backward + optimize the generator (encoder-decoder)
                    # losses; the decoder is updated twice per encoder update
                    g_loss = (g_loss_fake
                              + self.lambda_bi * (g_loss_rec_x + g_loss_rec_l + g_loss_feature)
                              + self.lambda_id * (g_loss_identity_x + g_loss_identity_l)
                              + self.lambda_cls * g_loss_cls)
                    self.reset()
                    g_loss.backward()
                    self.enc_optimizer.step()
                    self.dec_optimizer.step()
                    self.dec_optimizer.step()

                    # Logging Generator losses
                    loss['G/loss_feature'] = g_loss_feature.data[0]
                    loss['G/loss_identity_x'] = g_loss_identity_x.data[0]
                    loss['G/loss_identity_l'] = g_loss_identity_l.data[0]
                    loss['G/loss_rec_x'] = g_loss_rec_x.data[0]
                    loss['G/loss_rec_l'] = g_loss_rec_l.data[0]
                    loss['G/loss_fake'] = g_loss_fake.data[0]
                    loss['G/loss_cls'] = g_loss_cls.data[0]

                # Print out log
                if (i + 1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e + 1, self.num_epochs, i + 1,
                        iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value,
                                                   e * iters_per_epoch + i + 1)

                # Synthesize images
                if (i + 1) % self.sample_step == 0:
                    fake_image_list = [sample_x]
                    for sample_y in sample_y_list:
                        enc_feat = self.Enc(torch.cat([sample_x, sample_l], 1))
                        sample_result, sample_landmark = self.Dec(
                            enc_feat, sample_y)
                        fake_image_list.append(sample_result)
                    fake_images = torch.cat(fake_image_list, dim=3)
                    save_image(self.denorm(fake_images.data),
                               os.path.join(
                                   self.sample_path,
                                   '{}_{}_fake.png'.format(e + 1, i + 1)),
                               nrow=1,
                               padding=0)
                    print('Generated images saved into {}..!'.format(
                        self.sample_path))

                # Save checkpoints
                if (i + 1) % self.model_save_step == 0:
                    torch.save(
                        self.Enc.state_dict(),
                        os.path.join(self.model_path,
                                     '{}_{}_Enc.pth'.format(e + 1, i + 1)))
                    torch.save(
                        self.Dec.state_dict(),
                        os.path.join(self.model_path,
                                     '{}_{}_Dec.pth'.format(e + 1, i + 1)))
                    torch.save(
                        self.D.state_dict(),
                        os.path.join(self.model_path,
                                     '{}_{}_D.pth'.format(e + 1, i + 1)))

            # Decay learning rate
            if (e + 1) > (self.num_epochs - self.num_epochs_decay):
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                enc_lr -= (self.enc_lr / float(self.num_epochs_decay))
                dec_lr -= (self.dec_lr / float(self.num_epochs_decay))
                self.update_lr(enc_lr, dec_lr, d_lr)
                print('Decayed learning rates to enc_lr: {}, dec_lr: {}, d_lr: {}.'
                      .format(enc_lr, dec_lr, d_lr))

    def test(self):
        """Generating face images owning target attributes (desired expressions) """
        # Load trained models
        Enc_path = os.path.join(self.model_path,
                                '{}_Enc.pth'.format(self.test_model))
        Dec_path = os.path.join(self.model_path,
                                '{}_Dec.pth'.format(self.test_model))
        self.Enc.load_state_dict(torch.load(Enc_path))
        self.Dec.load_state_dict(torch.load(Dec_path))
        self.Enc.eval()
        self.Dec.eval()

        data_loader = self.face_data_loader

        for i, (real_image, real_landmark) in enumerate(data_loader):
            org_c = real_image[1]
            real_x = real_image[0]
            real_l = real_landmark[0]
            real_x = self.to_var(real_x, volatile=True)
            real_l = self.to_var(real_l, volatile=True)

            target_y_list = []
            for j in range(self.y_dim):
                target_y = self.one_hot(
                    torch.ones(real_x.size(0)) * j, self.y_dim)
                target_y_list.append(self.to_var(target_y, volatile=True))

            # Target image generation
            fake_image_list = [real_x]
            for target_y in target_y_list:
                enc_feat = self.Enc(torch.cat([real_x, real_l], 1))
                sample_result, sample_landmark = self.Dec(enc_feat, target_y)
                fake_image_list.append(sample_result)
            fake_images = torch.cat(fake_image_list, dim=3)
            save_path = os.path.join(self.test_path,
                                     '{}_fake.png'.format(i + 1))
            save_image(self.denorm(fake_images.data),
                       save_path,
                       nrow=1,
                       padding=0)
            print('Generated images saved into "{}"..!'.format(save_path))
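The scalar_summary / histo_summary / image_summary calls in these examples match the widely copied custom TensorBoard logger class rather than the tensorboard_logger package on PyPI (whose Logger exposes log_value), so each project likely ships its own logger module. A minimal stand-in with the same scalar_summary signature, assuming PyTorch >= 1.1 for torch.utils.tensorboard, could look like this:

from torch.utils.tensorboard import SummaryWriter

class Logger(object):
    """Sketch of the logger interface assumed by the examples above."""

    def __init__(self, log_dir):
        # SummaryWriter creates log_dir and writes TensorBoard event files
        self.writer = SummaryWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        # Log one scalar (e.g. a loss term) under `tag` at global step `step`
        self.writer.add_scalar(tag, value, step)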
Example #2
        lr_scheduler.step()
        for i_batch, sampled_batch in enumerate(loader):
            data, target = sampled_batch

            if torch.cuda.is_available():
                data, target = Variable(data).cuda(), Variable(target).cuda()
            else:
                data, target = Variable(data), Variable(target)

            optimizer.zero_grad()
            pred = net(data)
            loss = loss_fn(pred, target.float())
            loss.backward()
            optimizer.step()
            logger.info('[epoch: {}, batch: {}] Training loss: {}'.format(
                epoch, i_batch, loss.data[0]))
            tb_logger.scalar_summary('loss', loss.data[0],
                                     epoch * niter_per_epoch + i_batch + 1)

        # (2) Log values and gradients of the parameters (histogram)
        for tag, value in net.named_parameters():
            tag = tag.replace('.', '/')
            tb_logger.histo_summary(tag, value.data.cpu().numpy(), epoch + 1)
            tb_logger.histo_summary(tag + '/grad',
                                    value.grad.data.cpu().numpy(), epoch + 1)

        if (epoch + 1) % 10 == 0:
            cp_path = opj(CHECKPOINTS_PATH, cur_time, 'model_%s' % epoch)
            mkdir_r(dirname(cp_path))
            torch.save(net.state_dict(), cp_path)
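Example #2 also logs parameter and gradient histograms via histo_summary. Continuing the Logger sketch above (same assumptions), that method could be a thin delegation to add_histogram:

    def histo_summary(self, tag, values, step):
        # `values` is a numpy array of parameter (or gradient) values
        self.writer.add_histogram(tag, values, step)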
Example #3
# Define loss function
criterion = nn.NLLLoss()

# Keep track of time elapsed and running averages
start = time.time()

# Set configuration for using Tensorboard
logger = Logger('graphs')

for step in range(step, final_steps + 1):

    # Get training data for this cycle
    inputs, targets, len_inputs, len_targets = train_corpus.next_batch()
    input_variable = Variable(torch.LongTensor(inputs), requires_grad=False)
    target_variable = Variable(torch.LongTensor(targets), requires_grad=False)

    if Config.use_cuda:
        input_variable = input_variable.cuda()
        target_variable = target_variable.cuda()

    # Run the train function
    loss = train(input_variable, len_inputs, target_variable, encoder,
                 decoder, encoder_optimizer, decoder_optimizer, criterion)

    # Keep track of loss
    logger.scalar_summary('loss', loss, step)

    if step % print_every == 0:
        print('%s: %s (%d %d%%)' %
              (step, time_since(start, 1. * step / final_steps), step,
               step / final_steps * 100))

    if step % save_every == 0:
        save_state(encoder, decoder, encoder_optimizer, decoder_optimizer, step)
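As long as the logger writes standard TensorBoard event files, the runs above can be inspected with the stock CLI, e.g. tensorboard --logdir graphs for this example's Logger('graphs').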
Example #4
    # Compute accuracy
    _, argmax = torch.max(outputs, 1)
    accuracy = (labels == argmax.squeeze()).float().mean()

    if (step + 1) % 100 == 0:
        print('Step [{}/{}], Loss: {:.4f}, Acc: {:.2f}'.format(
            step + 1, total_step, loss.item(), accuracy.item()))

        # ================================================================== #
        #                        Tensorboard Logging                         #
        # ================================================================== #

        # 1. Log scalar values (scalar summary)
        info = {'loss': loss.item(), 'accuracy': accuracy.item()}

        for tag, value in info.items():
            logger.scalar_summary(tag, value, step + 1)

        # 2. Log values and gradients of the parameters (histogram summary)
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            logger.histo_summary(tag, value.data.cpu().numpy(), step + 1)
            logger.histo_summary(tag + '/grad',
                                 value.grad.data.cpu().numpy(), step + 1)

        # 3. Log training images (image summary)
        info = {'images': images.view(-1, 28, 28)[:10].cpu().numpy()}
        # [:10]: take the first 10 images

        for tag, images in info.items():
            logger.image_summary(tag, images, step + 1)
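Example #4 additionally logs a batch of 28x28 grayscale images through image_summary. A sketch that continues the Logger class above (note the added channel axis, since add_images expects NCHW by default; numpy is imported as np):

    def image_summary(self, tag, images, step):
        # `images` is a numpy array of shape (N, H, W); insert a channel axis
        imgs = np.asarray(images)[:, None, :, :]
        self.writer.add_images(tag, imgs, step)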
Example #5
class Solver(object):
    def __init__(self, data_loaders, config):
        # Data loader
        self.data_loaders = data_loaders
        self.attrs = config.attrs

        # Model hyper-parameters
        self.c_dim = len(data_loaders['train'].dataset.class_names)
        self.c2_dim = config.c2_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.d_train_repeat = config.d_train_repeat

        # Hyper-parameters
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Training settings
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.batch_size = config.batch_size
        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model
        self.pretrained_model_path = config.pretrained_model_path

        # Test settings
        self.test_model = config.test_model

        # Path
        self.log_path = os.path.join(config.output_path, 'logs',
                                     config.output_name)
        self.sample_path = os.path.join(config.output_path, 'samples',
                                        config.output_name)
        self.model_save_path = os.path.join(config.output_path, 'models',
                                            config.output_name)
        self.result_path = os.path.join(config.output_path, 'results',
                                        config.output_name)

        # Step size
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step

        self.num_val_imgs = config.num_val_imgs

        # Build tensorboard if use
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def build_model(self):

        # self.G = UnetGenerator(3+self.c_dim)
        self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num,
                           self.image_size)
        self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim,
                               self.d_repeat_num)

        # Optimizers
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr,
                                            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            [self.beta1, self.beta2])

        # Print networks
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def print_network(self, model, name):
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print_logger.info('{} - {} - Number of parameters: {}'.format(
            name, model, num_params))

    def load_pretrained_model(self):
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.pretrained_model_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.pretrained_model_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print_logger.info('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def build_tensorboard(self):
        from tensorboard_logger import Logger
        self.logger = Logger(self.log_path)

    def update_lr(self, g_lr, d_lr):
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def to_var(self, x, grad=True):
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, requires_grad=grad)

    def denorm(self, x):
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def threshold(self, x):
        x = x.clone()
        x = (x >= 0.5).float()
        return x

    def compute_accuracy(self, x, y):
        x = F.sigmoid(x)
        predicted = self.threshold(x)
        correct = (predicted == y).float()
        accuracy = torch.mean(correct, dim=0) * 100.0
        return accuracy

    def one_hot(self, labels, dim):
        """Convert label indices to one-hot vector"""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def make_data_labels(self, real_c):
        """Generate domain labels for dataset for debugging/testing.
        """

        y = []
        for dim in range(self.c_dim):
            t = [0] * self.c_dim
            t[dim] = 1
            y.append(torch.FloatTensor(t))

        fixed_c_list = []

        for i in range(self.c_dim):
            fixed_c = real_c.clone()
            for c in fixed_c:
                c[:self.c_dim] = y[i]

            fixed_c_list.append(self.to_var(fixed_c, grad=False))

        return fixed_c_list

    def train(self):
        """Train StarGAN within a single dataset."""

        # The number of iterations per epoch
        data_loader = self.data_loaders['train']

        iters_per_epoch = len(data_loader)

        fixed_x = []
        real_c = []

        num_fixed_imgs = self.num_val_imgs
        for i in range(num_fixed_imgs):
            images, labels = self.data_loaders['val'].dataset[i]
            fixed_x.append(images.unsqueeze(0))
            real_c.append(labels.unsqueeze(0))

        # Fixed inputs and target domain labels for debugging
        fixed_x = torch.cat(fixed_x, dim=0)
        fixed_x = self.to_var(fixed_x, grad=False)
        real_c = torch.cat(real_c, dim=0)

        fixed_c_list = self.make_data_labels(real_c)

        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start with trained model if exists
        if self.pretrained_model:
            start = int(self.pretrained_model.split('_')[0])
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (real_x, real_label) in enumerate(data_loader):

                # Generate fake labels randomly (target domain labels)
                rand_idx = torch.randperm(real_label.size(0))
                fake_label = real_label[rand_idx]

                real_c = real_label.clone()
                fake_c = fake_label.clone()

                # Convert tensor to variable
                real_x = self.to_var(real_x)
                real_c = self.to_var(real_c)  # input for the generator
                fake_c = self.to_var(fake_c)
                # Same as real_c when dataset == 'CelebA'
                real_label = self.to_var(real_label)
                fake_label = self.to_var(fake_label)

                # ================== Train D ================== #

                # Compute loss with real images
                out_src, out_cls = self.D(real_x)
                d_loss_real = -torch.mean(out_src)

                d_loss_cls = F.binary_cross_entropy_with_logits(
                    out_cls, real_label, size_average=False) / real_x.size(0)

                # Compute classification accuracy of the discriminator
                if (i + 1) % (self.log_step * 10) == 0:
                    accuracies = self.compute_accuracy(out_cls, real_label)
                    log = [
                        "{}: {:.2f}".format(attr, acc)
                        for (attr, acc) in zip(data_loader.dataset.class_names,
                                               accuracies.data.cpu().numpy())
                    ]
                    print_logger.info('Discriminator Accuracy: {}'.format(log))

                # Compute loss with fake images
                fake_x = self.G(real_x, fake_c)
                fake_x = Variable(fake_x.data)
                out_src, out_cls = self.D(fake_x)
                d_loss_fake = torch.mean(out_src)

                # Backward + Optimize
                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                loss = {}
                loss['D/loss_real'] = d_loss_real.data.item()
                loss['D/loss_fake'] = d_loss_fake.data.item()
                loss['D/loss_cls'] = d_loss_cls.data.item()
                loss['D/loss'] = d_loss.data.item()

                # Compute gradient penalty
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data +
                                        (1 - alpha) * fake_x.data,
                                        requires_grad=True)
                out, out_cls = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging
                loss['D/loss_gp'] = d_loss_gp.data.item()

                # ================== Train G ================== #
                if (i + 1) % self.d_train_repeat == 0:

                    # Original-to-target and target-to-original domain
                    fake_x = self.G(real_x, fake_c)
                    rec_x = self.G(fake_x, real_c)

                    # Compute losses
                    out_src, out_cls = self.D(fake_x)
                    g_loss_fake = -torch.mean(out_src)
                    g_loss_rec = torch.mean(torch.abs(real_x - rec_x))

                    g_loss_l1 = torch.mean(torch.abs(real_x - fake_x))

                    g_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, fake_label,
                        size_average=False) / fake_x.size(0)

                    # Backward + Optimize
                    g_loss = (g_loss_fake + self.lambda_rec * g_loss_rec +
                              self.lambda_cls * g_loss_cls)  # + self.lambda_rec * g_loss_l1
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    loss['G/loss_fake'] = g_loss_fake.data.item()
                    loss['G/loss_rec'] = g_loss_rec.data.item()
                    loss['G/loss_cls'] = g_loss_cls.data.item()
                    # loss['G/loss_l1'] = g_loss_l1.data.item()
                    loss['G/loss'] = g_loss.data.item()

                # Print out log info
                if (i + 1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e + 1, self.num_epochs, i + 1,
                        iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print_logger.info(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(
                                tag, value, e * iters_per_epoch + i + 1)

                # Translate fixed images for debugging
                if (i + 1) % self.sample_step == 0:
                    fake_image_list = [fixed_x]

                    for fixed_c in fixed_c_list:
                        gen_imgs = self.G(fixed_x, fixed_c)
                        fake_image_list.append(gen_imgs)

                    # fake_images = torch.cat(fake_image_list, dim=3)
                    # save_image(self.denorm(fake_images.data.cpu()),
                    #     os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0)
                    # print_logger.info('Translated images and saved into {}..!'.format(self.sample_path))

                    if self.use_tensorboard:
                        tb_imgs = [t.unsqueeze(0) for t in fake_image_list]
                        tb_imgs = torch.cat(tb_imgs)
                        tb_imgs = tb_imgs.permute(1, 0, 2, 3, 4)
                        tb_imgs_list = torch.unbind(tb_imgs, dim=0)
                        tb_imgs_list = [
                            torch.cat(torch.unbind(t, dim=0), dim=2)
                            for t in tb_imgs_list
                        ]

                        self.logger.image_summary('fixed_imgs', tb_imgs_list,
                                                  e * iters_per_epoch + i + 1)

                # Save model checkpoints
                if (i + 1) % self.model_save_step == 0:
                    torch.save(
                        self.G.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_G.pth'.format(e + 1, i + 1)))
                    torch.save(
                        self.D.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_D.pth'.format(e + 1, i + 1)))

            # Decay learning rate
            if (e + 1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print_logger.info(
                    'Decay learning rate to g_lr: {}, d_lr: {}.'.format(
                        g_lr, d_lr))

    def test(self):
        """Facial attribute transfer on CelebA or facial expression synthesis on RaFD."""
        # Load trained parameters
        G_path = os.path.join(self.model_save_path,
                              '{}_G.pth'.format(self.test_model))
        self.G.load_state_dict(torch.load(G_path))
        self.G.eval()

        data_loader = self.data_loaders['test']

        for i, (real_x, org_c) in enumerate(data_loader):
            real_x = self.to_var(real_x, grad=False)
            target_c_list = self.make_data_labels(org_c)

            # Start translations
            fake_image_list = [real_x]
            for target_c in target_c_list:
                fake_image_list.append(self.G(real_x, target_c))
            fake_images = torch.cat(fake_image_list, dim=3)
            save_path = os.path.join(self.result_path,
                                     '{}_fake.png'.format(i + 1))
            save_image(self.denorm(fake_images.data),
                       save_path,
                       nrow=1,
                       padding=0)
            print_logger.info(
                'Translated test images and saved into "{}"..!'.format(
                    save_path))
Example #6
def main():
    global opt, best_studentprec1
    cudnn.benchmark = True

    opt = parser.parse_args()
    opt.logdir = opt.logdir + '/' + opt.name
    logger = Logger(opt.logdir)

    print(opt)
    best_studentprec1 = 0.0

    print('Loading models...')
    teacher = init.load_model(opt, 'teacher')
    student = init.load_model(opt, 'student')
    discriminator = init.load_model(opt, 'discriminator')
    teacher = init.setup(teacher, opt, 'teacher')
    student = init.setup(student, opt, 'student')
    discriminator = init.setup(discriminator, opt, 'discriminator')

    # TODO: write the code to classify it into the 11th class
    print(teacher)
    print(student)
    print(discriminator)

    advCriterion = nn.BCELoss().cuda()
    similarityCriterion = nn.L1Loss().cuda()
    derivativeCriterion = nn.SmoothL1Loss().cuda()
    discclassifyCriterion = nn.CrossEntropyLoss(size_average=True).cuda()

    studOptim = getOptim(opt, student, 'student')
    discrecOptim = getOptim(opt, discriminator, 'discriminator')

    trainer = train.Trainer(student, teacher, discriminator,
                            discclassifyCriterion, advCriterion,
                            similarityCriterion, derivativeCriterion,
                            studOptim, discrecOptim, opt, logger)
    validator = train.Validator(student, teacher, discriminator, opt, logger)

    # TODO: resume does not work as of now (model/optimizer are undefined here)
    if opt.resume:
        if os.path.isfile(opt.resume):
            model, optimizer, opt, best_prec1 = init.resumer(
                opt, model, optimizer)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    dataloader = init_data.load_data(opt)
    train_loader = dataloader.train_loader
    val_loader = dataloader.val_loader

    for epoch in range(opt.start_epoch, opt.epochs):
        utils.adjust_learning_rate(opt, studOptim, epoch)
        utils.adjust_learning_rate(opt, discrecOptim, epoch)
        print("Starting epoch number:", epoch + 1, "Learning rate:",
              studOptim.param_groups[0]["lr"])

        if not opt.testOnly:
            trainer.train(train_loader, epoch, opt)
        if opt.tensorboard:
            logger.scalar_summary('learning_rate', opt.lr, epoch)
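            # Note: opt.lr is the configured base rate; if adjust_learning_rate
            # mutates only the optimizer, studOptim.param_groups[0]['lr'] holds
            # the rate actually in use (as in the print above).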

        student_prec1 = validator.validate(val_loader, epoch, opt)
        best_studentprec1 = max(student_prec1, best_studentprec1)
        init.save_checkpoint(opt, teacher, student, discriminator, studOptim,
                             discrecOptim, student_prec1, epoch)

        print('Best accuracy: [{0:.3f}]\t'.format(best_studentprec1))