def __init__(self,
             device='cuda:0',
             log_dir='logs',
             gpu_ids=0,
             lr=0.0002,
             beta1=0.5,
             lambda_idt=5,
             lambda_A=10.0,
             lambda_B=10.0):
    """Build CycleGAN networks, losses and optimizers.

    Args:
        device: torch device string the networks are moved to.
        log_dir: directory for checkpoints/images (created if missing).
        gpu_ids: device id(s) for DataParallel; a bare int is accepted
            for backward compatibility and wrapped into a list.
        lr: Adam learning rate shared by all optimizers.
        beta1: Adam beta1 (beta2 fixed at 0.999).
        lambda_idt / lambda_A / lambda_B: loss-term weights.
    """
    self.lr = lr
    self.beta1 = beta1
    self.device = device

    # Two generators (A->B and B->A) and a discriminator per domain.
    self.netG_A = Generator().to(self.device)
    self.netG_B = Generator().to(self.device)
    self.netD_A = Discriminator().to(self.device)
    self.netD_B = Discriminator().to(self.device)

    print(torch.cuda.is_available())

    # multi-GPUs.  DataParallel expects an iterable of device ids; the
    # historical default here was the bare int 0, which is not iterable,
    # so normalize it before wrapping.
    if isinstance(gpu_ids, int):
        gpu_ids = [gpu_ids]
    self.netG_A = torch.nn.DataParallel(self.netG_A, gpu_ids)
    self.netG_B = torch.nn.DataParallel(self.netG_B, gpu_ids)
    self.netD_A = torch.nn.DataParallel(self.netD_A, gpu_ids)
    self.netD_B = torch.nn.DataParallel(self.netD_B, gpu_ids)
    print('will use gpus: {}'.format(gpu_ids))

    # Image pools hold a history of generated images so the
    # discriminators do not only see the latest fakes (CycleGAN trick).
    self.fake_A_pool = ImagePool(50)
    self.fake_B_pool = ImagePool(50)

    # set losses
    self.criterionGAN = GANLoss(self.device)
    self.criterionCycle = torch.nn.L1Loss()
    self.criterionIdt = torch.nn.L1Loss()

    # weights of loss function
    self.lambda_idt = lambda_idt
    self.lambda_A = lambda_A
    self.lambda_B = lambda_B

    # One optimizer for both generators, one per discriminator.
    self.optimizer_G = torch.optim.Adam(itertools.chain(
        self.netG_A.parameters(), self.netG_B.parameters()),
                                        lr=self.lr,
                                        betas=(self.beta1, 0.999))
    self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                          lr=self.lr,
                                          betas=(self.beta1, 0.999))
    self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                          lr=self.lr,
                                          betas=(self.beta1, 0.999))
    self.optimizers = [self.optimizer_G, self.optimizer_D_A,
                       self.optimizer_D_B]

    self.log_dir = log_dir
    if not os.path.exists(self.log_dir):
        os.makedirs(self.log_dir)
# Example #2
    def __init__(self, config):
        """Set up GAN trainer state from a config namespace.

        Expects config to carry: lr, beta1, beta2, dataset, dataset_dir,
        is_train, batch_size, num_thread, cuda, device_ids, load_model,
        start_epoch, pretrained_dir, pretrained_epoch.
        """
        self.generator = Generator()
        self.discriminator = Discriminator()

        print(self.generator)
        print(self.discriminator)
        self.bce_loss_fn = nn.BCELoss()
        self.mse_loss_fn = nn.MSELoss()

        # Only optimize generator params that require grad (some may be frozen).
        self.opt_g = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                             self.generator.parameters()),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))

        if config.dataset == 'grid':
            self.dataset = VaganDataset(config.dataset_dir,
                                        train=config.is_train)
        elif config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir,
                                      train=config.is_train)
        else:
            # Previously an unknown name left self.dataset unset and
            # failed later with an opaque AttributeError; fail fast instead.
            raise ValueError('unknown dataset: {}'.format(config.dataset))

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=config.num_thread,
                                      shuffle=True,
                                      drop_last=True)
        # Pull one batch up front (loader warm-up / sanity check —
        # presumably intentional; the batch is discarded).
        data_iter = iter(self.data_loader)
        next(data_iter)  # was data_iter.next(), Python 2 only
        # Fixed real/fake target labels for the BCE GAN losses.
        self.ones = Variable(torch.ones(config.batch_size),
                             requires_grad=False)
        self.zeros = Variable(torch.zeros(config.batch_size),
                              requires_grad=False)

        if config.cuda:
            device_ids = [int(i) for i in config.device_ids.split(',')]
            self.generator = nn.DataParallel(self.generator.cuda(),
                                             device_ids=device_ids)
            self.discriminator = nn.DataParallel(self.discriminator.cuda(),
                                                 device_ids=device_ids)
            self.bce_loss_fn = self.bce_loss_fn.cuda()
            self.mse_loss_fn = self.mse_loss_fn.cuda()
            self.ones = self.ones.cuda()
            self.zeros = self.zeros.cuda()

        self.config = config
        self.start_epoch = 0

        if config.load_model:
            self.start_epoch = config.start_epoch
            self.load(config.pretrained_dir, config.pretrained_epoch)
class CycleGAN(object):
    """CycleGAN trainer with an additional mask-consistency loss term."""

    def __init__(self, device='cuda:0', log_dir='logs', gpu_ids=0, lr=0.0002, beta1=0.5,
                 lambda_idt=5, lambda_A=10.0, lambda_B=10.0, lambda_mask=10.0):
        """Build generators, discriminators, losses and optimizers.

        Args:
            device: torch device string the networks are moved to.
            log_dir: checkpoint/image output directory (created if missing).
            gpu_ids: device id(s) for DataParallel; a bare int is accepted
                for backward compatibility and wrapped into a list.
            lr: Adam learning rate shared by all optimizers.
            beta1: Adam beta1 (beta2 fixed at 0.999).
            lambda_idt / lambda_A / lambda_B / lambda_mask: loss weights;
                lambda_mask == 0 disables the mask loss entirely.
        """
        self.lr = lr
        self.beta1 = beta1
        self.device = device

        self.netG_A = Generator().to(self.device)
        self.netG_B = Generator().to(self.device)
        self.netD_A = Discriminator().to(self.device)
        self.netD_B = Discriminator().to(self.device)

        print(torch.cuda.is_available())

        # multi-GPUs.  DataParallel wants an iterable of device ids; the
        # old default was the bare int 0, which is not iterable — wrap it.
        if isinstance(gpu_ids, int):
            gpu_ids = [gpu_ids]
        self.netG_A = torch.nn.DataParallel(self.netG_A, gpu_ids)
        self.netG_B = torch.nn.DataParallel(self.netG_B, gpu_ids)
        self.netD_A = torch.nn.DataParallel(self.netD_A, gpu_ids)
        self.netD_B = torch.nn.DataParallel(self.netD_B, gpu_ids)
        print('will use gpus: {}'.format(gpu_ids))

        # Histories of generated images shown to the discriminators.
        self.fake_A_pool = ImagePool(50)
        self.fake_B_pool = ImagePool(50)

        # set losses
        self.criterionGAN = GANLoss(self.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        self.criterionMask = MASKLoss(self.device)

        # weights of loss function
        self.lambda_idt = lambda_idt
        self.lambda_A = lambda_A
        self.lambda_B = lambda_B
        self.lambda_mask = lambda_mask

        # One optimizer for both generators, one per discriminator.
        self.optimizer_G = torch.optim.Adam(
            itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
            lr=self.lr,
            betas=(self.beta1, 0.999))
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=self.lr, betas=(self.beta1, 0.999))
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=self.lr, betas=(self.beta1, 0.999))
        self.optimizers = [self.optimizer_G, self.optimizer_D_A,
                           self.optimizer_D_B]

        self.log_dir = log_dir
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

    def set_input(self, input):
        self.real_A = input['A'].to(self.device)
        self.real_B = input['B'].to(self.device)
        self.real_A_mask = input['A_mask'].to(self.device)

    def backward_G(self, real_A, real_B, real_A_mask):
        """Compute all generator losses, backprop them, and return detached values.

        Builds identity, adversarial, cycle and (optional) mask losses for
        both generators, calls backward() on their weighted sum, and returns
        the individual loss values plus the generated fakes (all .data, i.e.
        detached) so the caller can feed the fakes to the discriminator steps.
        """
        # Identity losses: feeding a generator an image already in its
        # output domain should be (near) a no-op.
        idt_A = self.netG_A(real_B)
        loss_idt_A = self.criterionIdt(idt_A, real_B) * self.lambda_idt

        idt_B = self.netG_B(real_A)
        loss_idt_B = self.criterionIdt(idt_B, real_A) * self.lambda_idt

        # GAN loss D_A(G_A(A))
        # G_A tries to fool D_A as real
        fake_B = self.netG_A(real_A)
        pred_fake_B = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake_B, True)

        # GAN loss D_B(G_B(B))
        # G_B tries to fool D_B as real
        fake_A = self.netG_B(real_B)
        pred_fake_A = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake_A, True)

        # forward cycle loss
        # real_A => fake_B => rec_A is close to real_A
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, real_A) * self.lambda_A

        # backward cycle loss
        # real_B => fake_A => rec_B is close to real_B
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, real_B) * self.lambda_B

        # Mask-consistency loss; lambda_mask == 0 disables it.
        # NOTE(review): torch.tensor(0) is an integer tensor — it type-promotes
        # fine in the sum below, but torch.tensor(0.0) would be cleaner.
        if self.lambda_mask == 0:
            loss_mask = torch.tensor(0).to(self.device)
        else:
            loss_mask = self.criterionMask(real_A, fake_B, real_A_mask) * self.lambda_mask

        # combined loss — one backward pass populates both generators' grads
        loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B + loss_mask
        loss_G.backward()

        return loss_G_A.data, loss_G_B.data, loss_cycle_A.data, loss_cycle_B.data, \
               loss_idt_A.data, loss_idt_B.data, loss_mask.data, fake_A.data, fake_B.data

    def backward_D_A(self, real_B, fake_B):
        # work on fake_B from domain A

        # use image pool
        fake_B = self.fake_B_pool.query(fake_B)

        # real image is real
        pred_real = self.netD_A(real_B)
        loss_D_real = self.criterionGAN(pred_real, True)

        # fake image is fake
        # detach()
        pred_fake = self.netD_A(fake_B.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)

        # combined loss
        loss_D_A = (loss_D_real + loss_D_fake) * 0.5
        loss_D_A.backward()

        return loss_D_A.data

    def backward_D_B(self, real_A, fake_A):
        # work on fake_A from domain B

        fake_A = self.fake_A_pool.query(fake_A)

        # real image is real
        pred_real = self.netD_B(real_A)
        loss_D_real = self.criterionGAN(pred_real, True)

        # fake image is fake
        # detach()
        pred_fake = self.netD_B(fake_A.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)

        # combined loss
        loss_D_B = (loss_D_real + loss_D_fake) * 0.5
        loss_D_B.backward()

        return loss_D_B.data

    def optimize(self):
        """Run one optimization step for G_A/G_B, then D_A, then D_B.

        Requires set_input() to have been called first (reads real_A,
        real_B, real_A_mask).  Returns a 9-element numpy array of the
        detached loss values in the order logged by train().
        """
        # update Generator (G_A and G_B)
        self.optimizer_G.zero_grad()
        loss_G_A, loss_G_B, loss_cycle_A, loss_cycle_B, loss_idt_A, loss_idt_B, loss_mask, fake_A, fake_B \
            = self.backward_G(self.real_A, self.real_B, self.real_A_mask)
        self.optimizer_G.step()

        # update D_A on the fakes just produced (already detached by backward_G)
        self.optimizer_D_A.zero_grad()
        loss_D_A = self.backward_D_A(self.real_B, fake_B)
        self.optimizer_D_A.step()

        # update D_B
        self.optimizer_D_B.zero_grad()
        loss_D_B = self.backward_D_B(self.real_A, fake_A)
        self.optimizer_D_B.step()

        # Order matters: train() accumulates this array positionally.
        ret_loss = [
            loss_G_A, loss_D_A,
            loss_G_B, loss_D_B,
            loss_cycle_A, loss_cycle_B,
            loss_idt_A, loss_idt_B,
            loss_mask
        ]

        return np.array(ret_loss)

    def train(self, data_loader):
        running_loss = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
        time_list = []
        for batch_idx, data in enumerate(data_loader):

            t1 = time.perf_counter()
            self.set_input(data)
            losses = self.optimize()
            losses = losses.astype(np.float32)
            running_loss += losses

            t2 = time.perf_counter()
            get_processing_time = t2 - t1
            time_list.append(get_processing_time)

            if batch_idx % 50 == 0:
                print('batch: {} / {}, elapsed_time: {} sec'.format(batch_idx, len(data_loader), sum(time_list)))
                time_list = []

        running_loss /= len(data_loader)
        return running_loss

    def save_network(self, network, network_label, epoch_label):
        """Checkpoint *network* to '<log_dir>/<epoch>_net_<label>.pth'.

        The network is moved to CPU so the saved state dict is
        device-agnostic, then moved back to self.device.
        NOTE(review): networks here are DataParallel wrappers, so keys
        carry a 'module.' prefix — load_network must match; verify if
        loading outside this class.
        """
        save_filename = '{}_net_{}.pth'.format(epoch_label, network_label)
        save_path = os.path.join(self.log_dir, save_filename)

        torch.save(network.cpu().state_dict(), save_path)
        network.to(self.device)

    def load_network(self, network, network_label, epoch_label):
        """Load weights saved by save_network() into *network* in place.

        Checkpoints are written from CPU tensors, so no map_location is
        needed here.
        """
        load_filename = '{}_net_{}.pth'.format(epoch_label, network_label)
        load_path = os.path.join(self.log_dir, load_filename)
        network.load_state_dict(torch.load(load_path))

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label)
        self.save_network(self.netD_A, 'D_A', label)
        self.save_network(self.netG_B, 'G_B', label)
        self.save_network(self.netD_B, 'D_B', label)

    def load(self, label):
        self.load_network(self.netG_A, 'G_A', label)
        self.load_network(self.netD_A, 'D_A', label)
        self.load_network(self.netG_B, 'G_B', label)
        self.load_network(self.netD_B, 'D_B', label)

    def save_imgs(self, imgs, name_imgs, batch_size, epoch_label):
        """Save *imgs* as a PNG grid '<log_dir>/<epoch>_<name>.png'.

        The grid is laid out roughly square, capped at 16 tiles
        (nrow = 4) for large batches.  Images are assumed to be in
        [-1, 1] and are rescaled to [0, 1] for saving.
        """
        img_table_name = '{}_'.format(epoch_label) + name_imgs + '.png'
        save_path = os.path.join(self.log_dir, img_table_name)
        # Collapse the original duplicated branches: both differed only
        # in nrow, which is sqrt of the (capped) batch size.
        nrow = int(min(batch_size, 16) ** 0.5)
        # NOTE(review): torchvision >= 0.13 renamed `range` to
        # `value_range`; keep `range` only while pinned to an old version.
        utils.save_image(
            imgs,
            save_path,
            nrow=nrow,
            normalize=True,
            range=(-1, 1)
        )

    def generate_imgs(self, epoch_label, batch_size):
        real_A = self.real_A
        real_B = self.real_B
        fake_B = self.netG_A(real_A)
        fake_A = self.netG_B(real_B)

        self.save_imgs(real_A, 'real_A', batch_size, epoch_label)
        self.save_imgs(real_B, 'real_B', batch_size, epoch_label)
        self.save_imgs(fake_B, 'fake_B', batch_size, epoch_label)
        self.save_imgs(fake_A, 'fake_A', batch_size, epoch_label)
    def __init__(self, config):
        """Set up an audio-to-face GAN trainer with a frozen perceptual encoder.

        NOTE(review): this def sits textually inside the CycleGAN class
        above and shadows its __init__ — it looks like a misplaced paste
        from a different Trainer class; confirm intent and relocate.
        """
        self.generator = Generator()
        self.discriminator = Discriminator()
        # Pretrained embedding encoder used only for the perceptual loss.
        # NOTE(review): hard-coded absolute path — should come from config.
        self.encoder = Encoder()
        self.encoder.load_state_dict(
            torch.load(
                '/mnt/disk1/dat/lchen63/grid/model/model_embedding/encoder_6.pth'
            ))
        # Freeze the encoder; it is a fixed feature extractor.
        for param in self.encoder.parameters():
            param.requires_grad = False

        print(self.generator)
        print(self.discriminator)

        self.l1_loss_fn = nn.L1Loss()
        self.bce_loss_fn = nn.BCELoss()
        self.mse_loss_fn = nn.MSELoss()

        # Only optimize generator params that require grad.
        self.opt_g = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                             self.generator.parameters()),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))

        if config.dataset == 'grid':
            self.dataset = VaganDataset(config.dataset_dir,
                                        train=config.is_train)
        elif config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir,
                                      train=config.is_train)
        else:
            # Fail fast instead of an opaque AttributeError later.
            raise ValueError('unknown dataset: {}'.format(config.dataset))

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=4,
                                      shuffle=True,
                                      drop_last=True)
        # Loader warm-up; the fetched batch is discarded.
        data_iter = iter(self.data_loader)
        next(data_iter)  # was data_iter.next(), Python 2 only
        # Fixed real/fake target labels for the BCE GAN losses.
        self.ones = Variable(torch.ones(config.batch_size),
                             requires_grad=False)
        self.zeros = Variable(torch.zeros(config.batch_size),
                              requires_grad=False)

        if config.cuda:
            device_ids = [int(i) for i in config.device_ids.split(',')]
            self.generator = nn.DataParallel(self.generator.cuda(),
                                             device_ids=device_ids)
            self.discriminator = nn.DataParallel(self.discriminator.cuda(),
                                                 device_ids=device_ids)
            self.encoder = nn.DataParallel(self.encoder.cuda(),
                                           device_ids=device_ids)
            self.bce_loss_fn = self.bce_loss_fn.cuda()
            self.mse_loss_fn = self.mse_loss_fn.cuda()
            self.l1_loss_fn = self.l1_loss_fn.cuda()
            self.ones = self.ones.cuda()
            self.zeros = self.zeros.cuda()

        self.config = config
        self.start_epoch = 0
class Trainer():
    """GAN trainer: generator vs. discriminator plus a frozen perceptual encoder."""

    def __init__(self, config):
        """Build networks, losses, optimizers and the data loader from *config*."""
        self.generator = Generator()
        self.discriminator = Discriminator()
        # Pretrained embedding encoder used only for the perceptual loss.
        # NOTE(review): hard-coded absolute path — should come from config.
        self.encoder = Encoder()
        self.encoder.load_state_dict(
            torch.load(
                '/mnt/disk1/dat/lchen63/grid/model/model_embedding/encoder_6.pth'
            ))
        # Freeze the encoder; it is a fixed feature extractor.
        for param in self.encoder.parameters():
            param.requires_grad = False

        print(self.generator)
        print(self.discriminator)

        self.l1_loss_fn = nn.L1Loss()
        self.bce_loss_fn = nn.BCELoss()
        self.mse_loss_fn = nn.MSELoss()

        # Only optimize generator params that require grad.
        self.opt_g = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                             self.generator.parameters()),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))

        if config.dataset == 'grid':
            self.dataset = VaganDataset(config.dataset_dir,
                                        train=config.is_train)
        elif config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir,
                                      train=config.is_train)
        else:
            # Fail fast instead of an opaque AttributeError later.
            raise ValueError('unknown dataset: {}'.format(config.dataset))

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=4,
                                      shuffle=True,
                                      drop_last=True)
        # Loader warm-up; the fetched batch is discarded.
        data_iter = iter(self.data_loader)
        next(data_iter)  # was data_iter.next(), Python 2 only
        # Fixed real/fake target labels for the BCE GAN losses.
        self.ones = Variable(torch.ones(config.batch_size),
                             requires_grad=False)
        self.zeros = Variable(torch.zeros(config.batch_size),
                              requires_grad=False)

        if config.cuda:
            device_ids = [int(i) for i in config.device_ids.split(',')]
            self.generator = nn.DataParallel(self.generator.cuda(),
                                             device_ids=device_ids)
            self.discriminator = nn.DataParallel(self.discriminator.cuda(),
                                                 device_ids=device_ids)
            self.encoder = nn.DataParallel(self.encoder.cuda(),
                                           device_ids=device_ids)
            self.bce_loss_fn = self.bce_loss_fn.cuda()
            self.mse_loss_fn = self.mse_loss_fn.cuda()
            self.l1_loss_fn = self.l1_loss_fn.cuda()
            self.ones = self.ones.cuda()
            self.zeros = self.zeros.cuda()

        self.config = config
        self.start_epoch = 0

    def fit(self):
        """Main training loop: alternate discriminator/generator updates.

        Per batch: (1) update D on real / wrong-audio / fake triplets,
        (2) update G with an adversarial + perceptual (encoder-feature
        MSE) loss.  Logs to tensorboard-logger, dumps sample image grids
        three times per epoch and checkpoints every epoch.
        NOTE(review): `loss.data[0]` is pre-0.4 PyTorch; on modern
        versions this raises — use `.item()` if upgrading.
        """
        config = self.config
        configure("{}/".format(config.log_dir), flush_secs=5)

        num_steps_per_epoch = len(self.data_loader)
        # cc counts sample-image dumps across the whole run.
        cc = 0

        for epoch in range(self.start_epoch, config.max_epochs):
            for step, (example, real_im, landmarks, right_audio,
                       wrong_audio) in enumerate(self.data_loader):

                t1 = time.time()

                if config.cuda:
                    example = Variable(example).cuda()
                    landmarks = Variable(landmarks).cuda()
                    real_im = Variable(real_im).cuda()
                    right_audio = Variable(right_audio).cuda()
                    wrong_audio = Variable(wrong_audio).cuda()
                else:
                    example = Variable(example)
                    landmarks = Variable(landmarks)
                    real_im = Variable(real_im)
                    right_audio = Variable(right_audio)
                    wrong_audio = Variable(wrong_audio)

                fake_im = self.generator(example, right_audio)

                # Train the discriminator: real+right audio => real;
                # real+wrong audio and (detached) fake+right audio => fake.
                D_real = self.discriminator(real_im, right_audio)
                D_wrong = self.discriminator(real_im, wrong_audio)
                D_fake = self.discriminator(fake_im.detach(), right_audio)

                loss_real = self.bce_loss_fn(D_real, self.ones)
                loss_wrong = self.bce_loss_fn(D_wrong, self.zeros)
                loss_fake = self.bce_loss_fn(D_fake, self.zeros)

                loss_disc = loss_real + 0.5 * (loss_fake + loss_wrong)
                loss_disc.backward()
                self.opt_d.step()
                self._reset_gradients()

                # Train the generator
                # noise = Variable(torch.randn(config.batch_size, config.noise_size))
                # noise = noise.cuda() if config.cuda else noise

                # Regenerate (the previous graph was consumed by D's backward).
                fake_im = self.generator(example, right_audio)
                # Encoder features (index [1] of its output) for both images.
                fea_r = self.encoder(real_im)[1]
                fea_f = self.encoder(fake_im)[1]

                D_fake = self.discriminator(fake_im, right_audio)

                ############gan loss###################
                loss_gen1 = self.bce_loss_fn(D_fake, self.ones)

                #######gradient loss##############
                # f_gra_x = torch.abs(fake_im[:,:,:,:-1,:] -  fake_im[:,:,:,1:,:])
                # f_gra_y =  torch.abs(fake_im[:,:,:,:,:-1] -  fake_im[:,:,:,:,1:])
                # r_gra_x = torch.abs(real_im[:,:,:,:-1,:] -  real_im[:,:,:,1:,:])
                # r_gra_y =  torch.abs(real_im[:,:,:,:,:-1] -  real_im[:,:,:,:,1:])
                # loss_grad_x = self.l1_loss_fn(f_gra_x,r_gra_x)
                # loss_grad_y = self.l1_loss_fn(f_gra_y, r_gra_y)

                ######perceptual loss ##############

                loss_perceptual = self.mse_loss_fn(fea_f, fea_r)

                loss_gen = loss_gen1 + loss_perceptual
                loss_gen.backward()
                self.opt_g.step()
                self._reset_gradients()

                t2 = time.time()

                # Log every step (the `% 1` keeps the hook for a larger interval).
                if (step + 1) % 1 == 0 or (step + 1) == num_steps_per_epoch:
                    steps_remain = num_steps_per_epoch-step+1 + \
                        (config.max_epochs-epoch+1)*num_steps_per_epoch
                    eta = int((t2 - t1) * steps_remain)

                    print(
                        "[{}/{}][{}/{}] Loss_D: {:.4f}  Loss_G: {:.4f}, loss_perceptual: {: .4f},  ETA: {} second"
                        .format(epoch + 1, config.max_epochs, step + 1,
                                num_steps_per_epoch, loss_disc.data[0],
                                loss_gen1.data[0], loss_perceptual.data[0],
                                eta))
                    log_value('discriminator_loss', loss_disc.data[0],
                              step + num_steps_per_epoch * epoch)
                    log_value('generator_loss', loss_gen1.data[0],
                              step + num_steps_per_epoch * epoch)
                    log_value('perceptual_loss', 0.5 * loss_perceptual.data[0],
                              step + num_steps_per_epoch * epoch)
                # Dump sample grids ~3x/epoch.  NOTE(review): under Python 3
                # this is float modulo (true division) — works, but integer
                # `//` was presumably intended.
                if (step) % (num_steps_per_epoch / 3) == 0:
                    # Reshape (B, C, T, H, W) video tensors into B*16 frames.
                    fake_store = fake_im.data.permute(
                        0, 2, 1, 3,
                        4).contiguous().view(config.batch_size * 16, 3, 64, 64)
                    torchvision.utils.save_image(fake_store,
                                                 "{}fake_{}.png".format(
                                                     config.sample_dir, cc),
                                                 nrow=16,
                                                 normalize=True)
                    real_store = real_im.data.permute(
                        0, 2, 1, 3,
                        4).contiguous().view(config.batch_size * 16, 3, 64, 64)
                    torchvision.utils.save_image(real_store,
                                                 "{}real_{}.png".format(
                                                     config.sample_dir, cc),
                                                 nrow=16,
                                                 normalize=True)
                    cc += 1
            # Checkpoint every epoch.
            if epoch % 1 == 0:
                torch.save(
                    self.generator.state_dict(),
                    "{}/generator_{}.pth".format(config.model_dir, epoch))
                torch.save(
                    self.discriminator.state_dict(),
                    "{}/discriminator_{}.pth".format(config.model_dir, epoch))

    def load(self, directory):
        paths = glob.glob(os.path.join(directory, "*.pth"))
        gen_path = [path for path in paths if "generator" in path][0]
        disc_path = [path for path in paths if "discriminator" in path][0]
        # gen_state_dict = torch.load(gen_path)
        # new_gen_state_dict = OrderedDict()
        # for k, v in gen_state_dict.items():
        #     name = 'model.' + k
        #     new_gen_state_dict[name] = v
        # # load params
        # self.generator.load_state_dict(new_gen_state_dict)

        # disc_state_dict = torch.load(disc_path)
        # new_disc_state_dict = OrderedDict()
        # for k, v in disc_state_dict.items():
        #     name = 'model.' + k
        #     new_disc_state_dict[name] = v
        # # load params
        # self.discriminator.load_state_dict(new_disc_state_dict)

        self.generator.load_state_dict(torch.load(gen_path))
        self.discriminator.load_state_dict(torch.load(disc_path))

        self.start_epoch = int(gen_path.split(".")[0].split("_")[-1])
        print("Load pretrained [{}, {}]".format(gen_path, disc_path))

    def _reset_gradients(self):
        self.generator.zero_grad()
        self.discriminator.zero_grad()
        self.encoder.zero_grad()
# Example #6
class Trainer():
    """Plain GAN trainer: generator vs. discriminator, BCE adversarial loss only."""

    def __init__(self, config):
        """Build networks, losses, optimizers and the data loader from *config*."""
        self.generator = Generator()
        self.discriminator = Discriminator()

        print(self.generator)
        print(self.discriminator)
        self.bce_loss_fn = nn.BCELoss()
        self.mse_loss_fn = nn.MSELoss()

        # Only optimize generator params that require grad.
        self.opt_g = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                             self.generator.parameters()),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(),
                                      lr=config.lr,
                                      betas=(config.beta1, config.beta2))

        if config.dataset == 'grid':
            self.dataset = VaganDataset(config.dataset_dir,
                                        train=config.is_train)
        elif config.dataset == 'lrw':
            self.dataset = LRWdataset(config.dataset_dir,
                                      train=config.is_train)
        else:
            # Fail fast instead of an opaque AttributeError later.
            raise ValueError('unknown dataset: {}'.format(config.dataset))

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=config.batch_size,
                                      num_workers=config.num_thread,
                                      shuffle=True,
                                      drop_last=True)
        # Loader warm-up; the fetched batch is discarded.
        data_iter = iter(self.data_loader)
        next(data_iter)  # was data_iter.next(), Python 2 only
        # Fixed real/fake target labels for the BCE GAN losses.
        self.ones = Variable(torch.ones(config.batch_size),
                             requires_grad=False)
        self.zeros = Variable(torch.zeros(config.batch_size),
                              requires_grad=False)

        if config.cuda:
            device_ids = [int(i) for i in config.device_ids.split(',')]
            self.generator = nn.DataParallel(self.generator.cuda(),
                                             device_ids=device_ids)
            self.discriminator = nn.DataParallel(self.discriminator.cuda(),
                                                 device_ids=device_ids)
            self.bce_loss_fn = self.bce_loss_fn.cuda()
            self.mse_loss_fn = self.mse_loss_fn.cuda()
            self.ones = self.ones.cuda()
            self.zeros = self.zeros.cuda()

        self.config = config
        self.start_epoch = 0

        if config.load_model:
            self.start_epoch = config.start_epoch
            self.load(config.pretrained_dir, config.pretrained_epoch)

    def fit(self):
        """Main training loop: alternate discriminator/generator updates.

        Per batch: (1) update D on real / wrong-audio / fake triplets,
        (2) update G with the BCE adversarial loss.  Logs to
        tensorboard-logger, dumps sample image grids ~10x per epoch and
        checkpoints every epoch.
        NOTE(review): `loss.data[0]` is pre-0.4 PyTorch; on modern
        versions this raises — use `.item()` if upgrading.
        """
        config = self.config
        configure("{}".format(config.log_dir), flush_secs=5)

        num_steps_per_epoch = len(self.data_loader)
        # cc counts sample-image dumps across the whole run.
        cc = 0

        for epoch in range(self.start_epoch, config.max_epochs):
            for step, (example, real_im, landmarks, right_audio,
                       wrong_audio) in enumerate(self.data_loader):
                t1 = time.time()

                # landmarks are unpacked but unused in this variant.
                if config.cuda:
                    example = Variable(example).cuda()
                    real_im = Variable(real_im).cuda()
                    right_audio = Variable(right_audio).cuda()
                    wrong_audio = Variable(wrong_audio).cuda()
                else:
                    example = Variable(example)
                    real_im = Variable(real_im)
                    right_audio = Variable(right_audio)
                    wrong_audio = Variable(wrong_audio)

                fake_im = self.generator(example, right_audio)

                # Train the discriminator: real+right audio => real;
                # real+wrong audio and (detached) fake+right audio => fake.
                D_real = self.discriminator(real_im, right_audio)
                D_wrong = self.discriminator(real_im, wrong_audio)
                D_fake = self.discriminator(fake_im.detach(), right_audio)

                loss_real = self.bce_loss_fn(D_real, self.ones)
                loss_wrong = self.bce_loss_fn(D_wrong, self.zeros)
                loss_fake = self.bce_loss_fn(D_fake, self.zeros)

                loss_disc = loss_real + 0.5 * (loss_fake + loss_wrong)
                loss_disc.backward()
                self.opt_d.step()
                self._reset_gradients()

                # Train the generator (regenerate: the previous graph was
                # consumed by D's backward pass).
                fake_im = self.generator(example, right_audio)
                D_fake = self.discriminator(fake_im, right_audio)
                loss_gen = self.bce_loss_fn(D_fake, self.ones)

                loss_gen.backward()
                self.opt_g.step()
                self._reset_gradients()

                t2 = time.time()

                # Log every step (the `% 1` keeps the hook for a larger interval).
                if (step + 1) % 1 == 0 or (step + 1) == num_steps_per_epoch:
                    steps_remain = num_steps_per_epoch-step+1 + \
                        (config.max_epochs-epoch+1)*num_steps_per_epoch
                    eta = int((t2 - t1) * steps_remain)

                    print(
                        "[{}/{}][{}/{}] Loss_D: {:.4f}  Loss_G: {:.4f},  ETA: {} second"
                        .format(epoch + 1, config.max_epochs, step + 1,
                                num_steps_per_epoch, loss_disc.data[0],
                                loss_gen.data[0], eta))
                    log_value('discriminator_loss', loss_disc.data[0],
                              step + num_steps_per_epoch * epoch)
                    log_value('generator_loss', loss_gen.data[0],
                              step + num_steps_per_epoch * epoch)
                # Dump sample grids ~10x/epoch.  NOTE(review): under Python 3
                # this is float modulo (true division) — works, but integer
                # `//` was presumably intended.
                if (step) % (num_steps_per_epoch / 10) == 0:
                    # Reshape (B, C, T, H, W) video tensors into B*16 frames.
                    fake_store = fake_im.data.permute(
                        0, 2, 1, 3,
                        4).contiguous().view(config.batch_size * 16, 3, 64, 64)
                    torchvision.utils.save_image(fake_store,
                                                 "{}fake_{}.png".format(
                                                     config.sample_dir, cc),
                                                 nrow=16,
                                                 normalize=True)
                    real_store = real_im.data.permute(
                        0, 2, 1, 3,
                        4).contiguous().view(config.batch_size * 16, 3, 64, 64)
                    torchvision.utils.save_image(real_store,
                                                 "{}real_{}.png".format(
                                                     config.sample_dir, cc),
                                                 nrow=16,
                                                 normalize=True)
                    cc += 1
            # Checkpoint every epoch.
            if epoch % 1 == 0:
                torch.save(
                    self.generator.state_dict(),
                    "{}/generator_{}.pth".format(config.model_dir, epoch))
                torch.save(
                    self.discriminator.state_dict(),
                    "{}/discriminator_{}.pth".format(config.model_dir, epoch))

    def load(self, directory, epoch):
        gen_path = os.path.join(directory, 'generator_{}.pth'.format(epoch))
        disc_path = os.path.join(directory,
                                 'discriminator_{}.pth'.format(epoch))

        self.generator.load_state_dict(torch.load(gen_path))
        self.discriminator.load_state_dict(torch.load(disc_path))

        print("Load pretrained [{}, {}]".format(gen_path, disc_path))

    def _reset_gradients(self):
        self.generator.zero_grad()
        self.discriminator.zero_grad()