Beispiel #1
0
class tag2pix(object):
    """Trainer / tester for the tag2pix sketch-colorization GAN.

    Wires together a generator ``G`` (implementation chosen by
    ``args.model``), a discriminator ``D`` and a frozen pretrained
    SE-ResNeXt (``Pretrain_ResNeXT``) whose output on the sketch is fed to
    ``G`` as a feature tensor. ``train`` runs adversarial training with
    periodic visualization and checkpointing, ``test`` writes colorized PNGs,
    and ``save``/``load``/``load_test`` handle checkpoints.
    """

    def __init__(self, args):
        """Build networks, optimizers, losses and data loaders from ``args``.

        ``args`` is the parsed command-line namespace; every attribute read
        below must exist on it.

        Raises:
            Exception: if ``args.model`` names no known generator variant.
        """
        # Each --model value selects a different Generator implementation.
        if args.model == 'tag2pix':
            from network import Generator
        elif args.model == 'senet':
            from model.GD_senet import Generator
        elif args.model == 'resnext':
            from model.GD_resnext import Generator
        elif args.model == 'catconv':
            from model.GD_cat_conv import Generator
        elif args.model == 'catall':
            from model.GD_cat_all import Generator
        elif args.model == 'adain':
            from model.GD_adain import Generator
        elif args.model == 'seadain':
            from model.GD_seadain import Generator
        else:
            raise Exception('invalid model name: {}'.format(args.model))

        self.args = args
        self.epoch = args.epoch
        self.batch_size = args.batch_size

        self.gpu_mode = not args.cpu
        self.input_size = args.input_size
        self.color_revert = ColorSpace2RGB(args.color_space)
        self.layers = args.layers
        [self.cit_weight, self.cvt_weight] = args.cit_cvt_weight

        # Compare by value: `is not ""` tests object identity, which is
        # implementation-dependent for strings (SyntaxWarning on 3.8+).
        self.load_dump = (args.load != "")

        self.load_path = Path(args.load)

        self.l1_lambda = args.l1_lambda
        self.guide_beta = args.guide_beta
        self.adv_lambda = args.adv_lambda
        self.save_freq = args.save_freq

        self.two_step_epoch = args.two_step_epoch
        self.brightness_epoch = args.brightness_epoch
        self.save_all_epoch = args.save_all_epoch

        self.iv_dict, self.cv_dict, self.id_to_name = get_tag_dict(
            args.tag_dump)

        cvt_class_num = len(self.cv_dict.keys())
        cit_class_num = len(self.iv_dict.keys())
        self.class_num = cvt_class_num + cit_class_num

        self.start_epoch = 1

        #### load dataset
        if not args.test:
            self.train_data_loader, self.test_data_loader = get_dataset(args)
            # Each training run writes into its own timestamped directory.
            self.result_path = Path(args.result_dir) / time.strftime(
                '%y%m%d-%H%M%S', time.localtime())

            if not self.result_path.exists():
                self.result_path.mkdir()

            self.test_images = self.get_test_data(self.test_data_loader,
                                                  args.test_image_count)
        else:
            self.test_data_loader = get_dataset(args)
            self.result_path = Path(args.result_dir)

        ##### initialize network
        self.net_opt = {
            'guide': not args.no_guide,
            'relu': args.use_relu,
            'bn': not args.no_bn,
            'cit': not args.no_cit
        }

        if self.net_opt['cit']:
            self.Pretrain_ResNeXT = se_resnext_half(
                dump_path=args.pretrain_dump,
                num_classes=cit_class_num,
                input_channels=1)
        else:
            # Without the CIT branch the "feature extractor" is an identity.
            self.Pretrain_ResNeXT = nn.Sequential()

        self.G = Generator(input_size=args.input_size,
                           layers=args.layers,
                           cv_class_num=cvt_class_num,
                           iv_class_num=cit_class_num,
                           net_opt=self.net_opt)
        self.D = Discriminator(input_dim=3,
                               output_dim=1,
                               input_size=self.input_size,
                               cv_class_num=cvt_class_num,
                               iv_class_num=cit_class_num)

        # The pretrained extractor is always frozen; G and D only in test mode.
        for param in self.Pretrain_ResNeXT.parameters():
            param.requires_grad = False
        if args.test:
            for param in self.G.parameters():
                param.requires_grad = False
            for param in self.D.parameters():
                param.requires_grad = False

        self.Pretrain_ResNeXT = nn.DataParallel(self.Pretrain_ResNeXT)
        self.G = nn.DataParallel(self.G)
        self.D = nn.DataParallel(self.D)

        self.G_optimizer = optim.Adam(self.G.parameters(),
                                      lr=args.lrG,
                                      betas=(args.beta1, args.beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(),
                                      lr=args.lrD,
                                      betas=(args.beta1, args.beta2))

        self.BCE_loss = nn.BCELoss()
        self.CE_loss = nn.CrossEntropyLoss()
        self.L1Loss = nn.L1Loss()

        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        print("gpu mode: ", self.gpu_mode)
        print("device: ", self.device)
        print(torch.cuda.device_count(), "GPUS!")

        if self.gpu_mode:
            self.Pretrain_ResNeXT.to(self.device)
            self.G.to(self.device)
            self.D.to(self.device)
            self.BCE_loss.to(self.device)
            self.CE_loss.to(self.device)
            self.L1Loss.to(self.device)

    def _tag_loss(self, cit_pred, cvt_pred, iv_tag_, cv_tag_):
        """Return the weighted CIT + CVT tag-classification BCE loss.

        The CIT term is 0 when the CIT branch is disabled (``net_opt['cit']``).
        """
        cit_loss = self.BCE_loss(cit_pred,
                                 iv_tag_) if self.net_opt['cit'] else 0
        cvt_loss = self.BCE_loss(cvt_pred, cv_tag_)
        return self.cvt_weight * cvt_loss + self.cit_weight * cit_loss

    def train(self):
        """Run adversarial training from ``start_epoch`` to ``end_epoch``.

        With ``args.load`` set, the checkpoint is restored first and training
        continues for ``args.epoch`` more epochs. Each batch updates D and
        then G. After every epoch the fixed test images are rendered, the
        losses plotted, and a checkpoint saved according to
        ``save_all_epoch`` / ``save_freq``. The last epoch is always saved.
        """
        self.train_hist = {
            'D_loss': [],
            'G_loss': [],
            'per_epoch_time': [],
            'total_time': [],
        }

        # Fixed-size adversarial targets; only full batches are used below.
        # NOTE(review): assumes D's real/fake head outputs (batch_size, 1).
        self.y_real_ = torch.ones(self.batch_size, 1)
        self.y_fake_ = torch.zeros(self.batch_size, 1)

        if self.gpu_mode:
            self.y_real_ = self.y_real_.to(self.device)
            self.y_fake_ = self.y_fake_.to(self.device)

        if self.load_dump:
            # load() sets start_epoch and end_epoch from the checkpoint.
            self.load(self.load_path)
            print("continue training!!!!")
        else:
            self.end_epoch = self.epoch

        self.print_params()

        self.D.train()
        print('training start!!')
        start_time = time.time()

        for epoch in range(self.start_epoch, self.end_epoch + 1):
            print("EPOCH: {}".format(epoch))

            # visualize_results() leaves G in eval mode at the end of an epoch.
            self.G.train()
            epoch_start_time = time.time()

            if epoch == self.brightness_epoch:
                print('changing brightness ...')
                self.train_data_loader.dataset.enhance_brightness(
                    self.input_size)

            # Skip the trailing partial batch so targets match in shape.
            max_iter = len(self.train_data_loader.dataset) // self.batch_size
            # Tag-classification losses are used from two_step_epoch onwards
            # (from the start when two_step_epoch == 0).
            use_tag_loss = (self.two_step_epoch == 0
                            or epoch >= self.two_step_epoch)

            for iter_idx, (original_, sketch_, iv_tag_, cv_tag_) in enumerate(
                    tqdm(self.train_data_loader, ncols=80)):
                if iter_idx >= max_iter:
                    break

                if self.gpu_mode:
                    sketch_ = sketch_.to(self.device)
                    original_ = original_.to(self.device)
                    iv_tag_ = iv_tag_.to(self.device)
                    cv_tag_ = cv_tag_.to(self.device)

                # ---------------- update D network ----------------
                self.D_optimizer.zero_grad()

                with torch.no_grad():
                    feature_tensor = self.Pretrain_ResNeXT(sketch_)
                if self.gpu_mode:
                    feature_tensor = feature_tensor.to(self.device)

                D_real, CIT_real, CVT_real = self.D(original_)
                D_real_loss = self.BCE_loss(D_real, self.y_real_)

                G_f, _ = self.G(sketch_, feature_tensor, cv_tag_)
                if self.gpu_mode:
                    G_f = G_f.to(self.device)

                # Detached: the D step must not backpropagate into G (G's
                # grads are zeroed before its own step anyway).
                D_f_fake, CIT_f_fake, CVT_f_fake = self.D(G_f.detach())
                D_f_fake_loss = self.BCE_loss(D_f_fake, self.y_fake_)

                if use_tag_loss:
                    C_real_loss = self._tag_loss(CIT_real, CVT_real, iv_tag_,
                                                 cv_tag_)
                    C_f_fake_loss = self._tag_loss(CIT_f_fake, CVT_f_fake,
                                                   iv_tag_, cv_tag_)
                else:
                    C_real_loss = 0
                    C_f_fake_loss = 0

                D_loss = self.adv_lambda * (D_real_loss + D_f_fake_loss) + (
                    C_real_loss + C_f_fake_loss)

                self.train_hist['D_loss'].append(D_loss.item())

                D_loss.backward()
                self.D_optimizer.step()

                # ---------------- update G network ----------------
                self.G_optimizer.zero_grad()

                G_f, G_g = self.G(sketch_, feature_tensor, cv_tag_)
                if self.gpu_mode:
                    G_f, G_g = G_f.to(self.device), G_g.to(self.device)

                D_f_fake, CIT_f_fake, CVT_f_fake = self.D(G_f)
                # G is rewarded when D labels its output as real.
                D_f_fake_loss = self.BCE_loss(D_f_fake, self.y_real_)

                if use_tag_loss:
                    C_f_fake_loss = self._tag_loss(CIT_f_fake, CVT_f_fake,
                                                   iv_tag_, cv_tag_)
                else:
                    C_f_fake_loss = 0

                L1_D_f_fake_loss = self.L1Loss(G_f, original_)
                L1_D_g_fake_loss = self.L1Loss(
                    G_g, original_) if self.net_opt['guide'] else 0

                G_loss = (D_f_fake_loss + C_f_fake_loss) + \
                         (L1_D_f_fake_loss + L1_D_g_fake_loss * self.guide_beta) * self.l1_lambda

                self.train_hist['G_loss'].append(G_loss.item())

                G_loss.backward()
                self.G_optimizer.step()

                if ((iter_idx + 1) % 100) == 0:
                    print(
                        "Epoch: [{:2d}] [{:4d}/{:4d}] D_loss: {:.8f}, G_loss: {:.8f}"
                        .format(epoch, (iter_idx + 1), max_iter, D_loss.item(),
                                G_loss.item()))

            self.train_hist['per_epoch_time'].append(time.time() -
                                                     epoch_start_time)

            with torch.no_grad():
                self.visualize_results(epoch)
                utils.loss_plot(self.train_hist, self.result_path, epoch)

            if epoch >= self.save_all_epoch > 0:
                self.save(epoch)
            elif self.save_freq > 0 and epoch % self.save_freq == 0:
                self.save(epoch)

        print("Training finish!... save training results")

        # Save the last epoch unless the in-loop rules above already did.
        # The condition mirrors those rules exactly. It is skipped when no
        # epoch ran, since there is nothing new to save.
        if self.end_epoch >= self.start_epoch:
            last_epoch = self.end_epoch
            saved_in_loop = (last_epoch >= self.save_all_epoch > 0) or (
                self.save_freq > 0 and last_epoch % self.save_freq == 0)
            if not saved_in_loop:
                self.save(last_epoch)

        self.train_hist['total_time'].append(time.time() - start_time)
        print(
            "Avg one epoch time: {:.2f}, total {} epochs time: {:.2f}".format(
                np.mean(self.train_hist['per_epoch_time']), self.epoch,
                self.train_hist['total_time'][0]))

    def test(self):
        """Colorize every sketch in the test loader and write PNG files.

        Generator weights come from ``args.load``. Images go to
        ``result_dir/<checkpoint stem>/<index>.png``. If that name is taken,
        ``<index>_<i>.png`` is tried for i < 100.
        """
        self.load_test(self.args.load)

        self.D.eval()
        self.G.eval()

        result_path = self.result_path / self.load_path.stem
        if not result_path.exists():
            result_path.mkdir()

        with torch.no_grad():
            for sketch_, index_, _, cv_tag_ in tqdm(self.test_data_loader,
                                                    ncols=80):
                if self.gpu_mode:
                    sketch_ = sketch_.to(self.device)
                    cv_tag_ = cv_tag_.to(self.device)

                feature_tensor = self.Pretrain_ResNeXT(sketch_)
                if self.gpu_mode:
                    feature_tensor = feature_tensor.to(self.device)

                G_f, _ = self.G(sketch_, feature_tensor, cv_tag_)
                G_f = self.color_revert(G_f.cpu())

                for ind, result in zip(index_.cpu().numpy(), G_f):
                    save_path = result_path / f'{ind}.png'
                    if save_path.exists():
                        for i in range(100):
                            save_path = result_path / f'{ind}_{i}.png'
                            if not save_path.exists():
                                break
                    Image.fromarray(result).save(save_path)

    def visualize_results(self, epoch, fix=True):
        """Render G's outputs on the fixed test images for this epoch.

        Saves square grids of ``G_f`` and ``G_g`` (the two generator outputs)
        as ``tag2pix_epoch{epoch:03d}_G_f.png`` / ``..._G_g.png`` under
        ``result_path``. Leaves ``G`` in eval mode. ``fix`` is unused.
        """
        if not self.result_path.exists():
            self.result_path.mkdir()

        self.G.eval()

        original_, sketch_, iv_tag_, cv_tag_ = self.test_images
        image_frame_dim = int(np.ceil(np.sqrt(len(original_))))
        grid_cells = image_frame_dim * image_frame_dim

        with torch.no_grad():
            if self.gpu_mode:
                sketch_ = sketch_.to(self.device)
                cv_tag_ = cv_tag_.to(self.device)

            # Frozen extractor output is G's conditioning input.
            feature_tensor = self.Pretrain_ResNeXT(sketch_)
            if self.gpu_mode:
                feature_tensor = feature_tensor.to(self.device)

            G_f, G_g = self.G(sketch_, feature_tensor, cv_tag_)

            if self.gpu_mode:
                G_f = G_f.cpu()
                G_g = G_g.cpu()

            G_f = self.color_revert(G_f)
            G_g = self.color_revert(G_g)

        utils.save_images(
            G_f[:grid_cells, :, :, :],
            [image_frame_dim, image_frame_dim],
            self.result_path / 'tag2pix_epoch{:03d}_G_f.png'.format(epoch))
        utils.save_images(
            G_g[:grid_cells, :, :, :],
            [image_frame_dim, image_frame_dim],
            self.result_path / 'tag2pix_epoch{:03d}_G_g.png'.format(epoch))

    def save(self, save_epoch):
        """Write a checkpoint, the loss history and the run arguments.

        Files: ``arguments.txt``, ``tag2pix_{save_epoch}_epoch.pkl`` (G/D and
        optimizer states plus ``finish_epoch``) and
        ``tag2pix_{save_epoch}_history.pkl`` (pickled ``train_hist``).
        """
        if not self.result_path.exists():
            self.result_path.mkdir()

        with (self.result_path / 'arguments.txt').open('w') as f:
            f.write(pprint.pformat(self.args.__dict__))

        save_dir = self.result_path

        torch.save(
            {
                'G': self.G.state_dict(),
                'D': self.D.state_dict(),
                'G_optimizer': self.G_optimizer.state_dict(),
                'D_optimizer': self.D_optimizer.state_dict(),
                'finish_epoch': save_epoch,
                'result_path': str(save_dir)
            }, str(save_dir / 'tag2pix_{}_epoch.pkl'.format(save_epoch)))

        with (save_dir /
              'tag2pix_{}_history.pkl'.format(save_epoch)).open('wb') as f:
            pickle.dump(self.train_hist, f)

        print("============= save success =============")
        print("epoch from {} to {}".format(self.start_epoch, save_epoch))
        print("save result path is {}".format(str(self.result_path)))

    def load_test(self, checkpoint_path):
        """Load only the generator weights from a checkpoint (test mode)."""
        # map to CPU first: load_state_dict copies onto the params' device,
        # so GPU-saved checkpoints also load in --cpu mode.
        checkpoint = torch.load(str(checkpoint_path), map_location='cpu')
        self.G.load_state_dict(checkpoint['G'])

    def load(self, checkpoint_path):
        """Restore G/D and optimizer state to resume training.

        Training resumes at ``finish_epoch + 1`` and runs ``args.epoch`` more
        epochs. Sets ``start_epoch`` and ``end_epoch``, which ``train`` uses.
        ``finish_epoch`` is kept as an alias for ``end_epoch``.
        """
        # map to CPU first: load_state_dict (modules and optimizers) moves
        # tensors onto the matching parameter's device.
        checkpoint = torch.load(str(checkpoint_path), map_location='cpu')
        self.G.load_state_dict(checkpoint['G'])
        self.D.load_state_dict(checkpoint['D'])
        self.G_optimizer.load_state_dict(checkpoint['G_optimizer'])
        self.D_optimizer.load_state_dict(checkpoint['D_optimizer'])
        self.start_epoch = checkpoint['finish_epoch'] + 1

        self.end_epoch = self.args.epoch + self.start_epoch - 1
        self.finish_epoch = self.end_epoch

        print("============= load success =============")
        print("epoch start from {} to {}".format(self.start_epoch,
                                                 self.end_epoch))
        print("previous result path is {}".format(checkpoint['result_path']))

    def get_test_data(self, test_data_loader, count):
        """Collect at least ``count`` samples from ``test_data_loader``.

        Whole batches are taken, so more than ``count`` may be returned. Also
        writes the samples' tag names and grids of the originals and sketches
        under ``result_path``.

        Returns:
            (original_, sketch_, iv_tag_, cv_tag_) concatenated along dim 0.
        """
        test_count = 0
        original_, sketch_, iv_tag_, cv_tag_ = [], [], [], []
        for orig, sket, ivt, cvt in test_data_loader:
            original_.append(orig)
            sketch_.append(sket)
            iv_tag_.append(ivt)
            cv_tag_.append(cvt)

            test_count += len(orig)
            if test_count >= count:
                break

        original_ = torch.cat(original_, 0)
        sketch_ = torch.cat(sketch_, 0)
        iv_tag_ = torch.cat(iv_tag_, 0)
        cv_tag_ = torch.cat(cv_tag_, 0)

        self.save_tag_tensor_name(iv_tag_, cv_tag_,
                                  self.result_path / "test_image_tags.txt")

        image_frame_dim = int(np.ceil(np.sqrt(len(original_))))
        grid_cells = image_frame_dim * image_frame_dim

        if self.gpu_mode:
            original_ = original_.cpu()
        # NOTE(review): assumes NCHW sketches; transposed to NHWC for saving.
        sketch_np = sketch_.data.numpy().transpose(0, 2, 3, 1)
        original_np = self.color_revert(original_)

        utils.save_images(
            original_np[:grid_cells, :, :, :],
            [image_frame_dim, image_frame_dim],
            self.result_path / 'tag2pix_original.png')
        utils.save_images(
            sketch_np[:grid_cells, :, :, :],
            [image_frame_dim, image_frame_dim],
            self.result_path / 'tag2pix_sketch.png')

        return original_, sketch_, iv_tag_, cv_tag_

    def save_tag_tensor_name(self, iv_tensor, cv_tensor, save_file_path):
        '''Write, per sample, the names of the tags set in the tag tensors.

        iv_tensor, cv_tensor: batched binary tag tensors; row = sample,
        column = tag index from iv_dict / cv_dict respectively.
        The file has a "CIT tags" section then a "CVT tags" section, with one
        line per sample: "<row> : name1, name2, ".
        '''

        def write_section(f, tensor, tag_dict):
            index_to_id = {
                tag_index: tag_id
                for (tag_id, tag_index) in tag_dict.items()
            }
            for tensor_i, batch_unit in enumerate(tensor):
                names = ''.join(f"{self.id_to_name[index_to_id[i]]}, "
                                for i, is_tag in enumerate(batch_unit)
                                if is_tag)
                f.write(f'{tensor_i} : {names}\n')

        with open(save_file_path, 'w') as f:
            f.write("CIT tags\n")
            write_section(f, iv_tensor, self.iv_dict)
            f.write("\nCVT tags\n")
            write_section(f, cv_tensor, self.cv_dict)

    def print_params(self):
        """Print the parameter counts of G, D and the pretrained extractor."""
        g_count = sum(param.numel() for param in self.G.parameters())
        d_count = sum(param.numel() for param in self.D.parameters())
        pre_count = sum(
            param.numel() for param in self.Pretrain_ResNeXT.parameters())
        print(
            f'Parameter #: G - {g_count} / D - {d_count} / Pretrain - {pre_count}'
        )
			Disc_a = discriminator(data)

			optimizer.zero_grad()
			loss_classification = torch.FloatTensor([0])
			for cls in range(len(label)):
				loss_classification += F.binary_cross_entropy(torch.squeeze(Disc_a)[cls], label[cls].float())
			#loss_classification = criterion(Disc_a, label)
			loss = loss_classification
			loss.backward()
			optimizer.step()

			num_batches += 1
			total_clas_loss += loss_classification.data.item()
		avg_clas_loss = total_clas_loss / num_batches
		loss_classifier_list.append(avg_clas_loss)
	plot_clas_loss(loss_classifier_list, 'clas_loss.png')
	discriminator.eval()
	models.append(discriminator.state_dict())

	Disc_b = discriminator(torch.from_numpy(X_test).float())
	pred_b = torch.from_numpy(np.array([1 if i > 0.5 else 0 for i in Disc_b]))
	#pred_b = torch.max(F.softmax(Disc_b), 1)[1]
	test_label = torch.from_numpy(y_test)
	num_correct_b = 0
	num_correct_b += torch.eq(pred_b, test_label).sum().float().item()
	Acc_b = num_correct_b/len(test_label)
	scoreA.append(Acc_b)

print(np.mean(scoreA))

Beispiel #3
0
def train(args):
    """Train a CycleGAN-style model (two generators, two discriminators).

    Uses LSGAN (MSE) adversarial losses, an L1 cycle-consistency loss and,
    when ``args.use_identity`` is truthy, an L1 identity loss. Generated
    images for the discriminators come from 50-element replay buffers.
    Sample translations are written under ``visualization/`` and logged
    every 250 iterations. Per-epoch weights are saved under ``models/``.

    Args:
        args: namespace with ``gpu`` (device id, or None for CPU),
            ``dataroot`` and ``use_identity``.
    """
    # set the logger
    logger = Logger('./logs')

    # Default to CPU tensors; switch to CUDA only when a GPU id is given.
    # Both names must be bound on every path because they are used below.
    use_cuda = False
    dtype = torch.FloatTensor
    if args.gpu is not None:
        use_cuda = True
        dtype = torch.cuda.FloatTensor
        torch.cuda.set_device(args.gpu)
        print("Current device: %s" % torch.cuda.get_device_name(args.gpu))

    # define networks
    g_AtoB = Generator().type(dtype)
    g_BtoA = Generator().type(dtype)
    d_A = Discriminator().type(dtype)
    d_B = Discriminator().type(dtype)

    # optimizers
    optimizer_generators = Adam(
        list(g_AtoB.parameters()) + list(g_BtoA.parameters()), INITIAL_LR)
    optimizer_d_A = Adam(d_A.parameters(), INITIAL_LR)
    optimizer_d_B = Adam(d_B.parameters(), INITIAL_LR)

    # loss criterion
    criterion_mse = torch.nn.MSELoss()
    criterion_l1 = torch.nn.L1Loss()

    # get training data
    dataset_transform = transforms.Compose([
        transforms.Resize(int(IMAGE_SIZE * 1),
                          Image.BICUBIC),  # scale shortest side to image_size
        transforms.RandomCrop(
            (IMAGE_SIZE, IMAGE_SIZE)),  # random image_size crop
        transforms.ToTensor(),  # turn image from [0-255] to [0-1]
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5, 0.5))  # normalize
    ])
    dataloader = DataLoader(ImgPairDataset(args.dataroot, dataset_transform,
                                           'train'),
                            batch_size=BATCH_SIZE,
                            shuffle=True)

    # Fixed test samples displayed periodically. Each item is fetched ONCE:
    # the transform contains RandomCrop, so indexing again would produce a
    # different crop and the saved "original" would not match the input.
    test_dataset = ImgPairDataset(args.dataroot, dataset_transform, 'test')
    test_images_A = []
    test_images_B = []
    for i in range(NUM_TEST_SAMPLES):
        sample = test_dataset[i]
        test_images_A.append(sample['A'])
        test_images_B.append(sample['B'])

        os.makedirs('visualization/test_%d/%s/' % (i, 'B_inStyleofA'),
                    exist_ok=True)
        os.makedirs('visualization/test_%d/%s/' % (i, 'A_inStyleofB'),
                    exist_ok=True)

        utils.save_image('visualization/test_original_%s_%04d.png' % ('A', i),
                         sample['A'].data)
        utils.save_image('visualization/test_original_%s_%04d.png' % ('B', i),
                         sample['B'].data)
    test_data_A = torch.stack(test_images_A).type(dtype)
    test_data_B = torch.stack(test_images_B).type(dtype)

    # replay buffers
    replayBufferA = utils.ReplayBuffer(50)
    replayBufferB = utils.ReplayBuffer(50)

    # training loop
    step = 0
    for e in range(EPOCHS):
        startTime = time.time()
        for idx, batch in enumerate(dataloader):
            real_A = batch['A'].type(dtype)
            real_B = batch['B'].type(dtype)

            # some examples seem to have only 1 color channel instead of 3
            if real_A.shape[1] != 3 or real_B.shape[1] != 3:
                continue

            # -----------------
            #  train generators
            # -----------------
            optimizer_generators.zero_grad()
            utils.learning_rate_decay(INITIAL_LR, e, EPOCHS,
                                      optimizer_generators)

            # GAN loss
            fake_A = g_BtoA(real_B)
            disc_fake_A = d_A(fake_A)
            fake_B = g_AtoB(real_A)
            disc_fake_B = d_B(fake_B)

            # Store graph-free copies for the discriminator updates.
            replayBufferA.push(fake_A.detach().clone())
            replayBufferB.push(fake_B.detach().clone())

            # LSGAN targets (same dtype/device as the discriminator output).
            target_real = torch.ones_like(disc_fake_A)
            target_fake = torch.zeros_like(disc_fake_A)

            loss_gan_AtoB = criterion_mse(disc_fake_B, target_real)
            loss_gan_BtoA = criterion_mse(disc_fake_A, target_real)
            loss_gan = loss_gan_AtoB + loss_gan_BtoA

            # cyclic reconstruction loss
            cyclic_A = g_BtoA(fake_B)
            cyclic_B = g_AtoB(fake_A)
            loss_cyclic_AtoBtoA = criterion_l1(cyclic_A,
                                               real_A) * CYCLIC_WEIGHT
            loss_cyclic_BtoAtoB = criterion_l1(cyclic_B,
                                               real_B) * CYCLIC_WEIGHT
            loss_cyclic = loss_cyclic_AtoBtoA + loss_cyclic_BtoAtoB

            # identity loss (same truthiness test as the logging below, so
            # the .item() calls there always see tensors)
            loss_identity = 0
            loss_identity_A = 0
            loss_identity_B = 0
            if args.use_identity:
                identity_A = g_BtoA(real_A)
                identity_B = g_AtoB(real_B)
                loss_identity_A = criterion_l1(identity_A,
                                               real_A) * 0.5 * CYCLIC_WEIGHT
                loss_identity_B = criterion_l1(identity_B,
                                               real_B) * 0.5 * CYCLIC_WEIGHT
                loss_identity = loss_identity_A + loss_identity_B

            loss_generators = loss_gan + loss_cyclic + loss_identity
            loss_generators.backward()
            optimizer_generators.step()

            # -----------------
            #  train discriminators
            # -----------------
            optimizer_d_A.zero_grad()
            utils.learning_rate_decay(INITIAL_LR, e, EPOCHS, optimizer_d_A)

            fake_A = replayBufferA.sample(1).detach()
            disc_fake_A = d_A(fake_A)
            disc_real_A = d_A(real_A)
            loss_d_A = 0.5 * (criterion_mse(disc_real_A, target_real) +
                              criterion_mse(disc_fake_A, target_fake))

            loss_d_A.backward()
            optimizer_d_A.step()

            optimizer_d_B.zero_grad()
            utils.learning_rate_decay(INITIAL_LR, e, EPOCHS, optimizer_d_B)

            fake_B = replayBufferB.sample(1).detach()
            disc_fake_B = d_B(fake_B)
            disc_real_B = d_B(real_B)
            loss_d_B = 0.5 * (criterion_mse(disc_real_B, target_real) +
                              criterion_mse(disc_fake_B, target_fake))

            loss_d_B.backward()
            optimizer_d_B.step()

            # log info and save sample images
            if (idx % 250) == 0:
                # eval on some sample images; no autograd graph needed
                g_AtoB.eval()
                g_BtoA.eval()

                with torch.no_grad():
                    test_B_hat = g_AtoB(test_data_A).cpu()
                    test_A_hat = g_BtoA(test_data_B).cpu()

                for i in range(NUM_TEST_SAMPLES):
                    fileStrA = 'visualization/test_%d/%s/%03d_%04d.png' % (
                        i, 'B_inStyleofA', e, idx)
                    fileStrB = 'visualization/test_%d/%s/%03d_%04d.png' % (
                        i, 'A_inStyleofB', e, idx)
                    utils.save_image(fileStrA, test_A_hat[i].data)
                    utils.save_image(fileStrB, test_B_hat[i].data)

                g_AtoB.train()
                g_BtoA.train()

                endTime = time.time()
                timeForIntervalIterations = endTime - startTime
                startTime = endTime

                print(
                    'Epoch [{:3d}/{:3d}], Training [{:4d}/{:4d}], Time Spent (s): [{:4.4f}], Losses: [G_GAN: {:4.4f}][G_CYC: {:4.4f}][G_IDT: {:4.4f}][D_A: {:4.4f}][D_B: {:4.4f}]'
                    .format(e, EPOCHS, idx, len(dataloader),
                            timeForIntervalIterations, loss_gan.item(),
                            loss_cyclic.item(), float(loss_identity),
                            loss_d_A.item(), loss_d_B.item()))

                # tensorboard logging
                info = {
                    'loss_generators': loss_generators.item(),
                    'loss_gan_AtoB': loss_gan_AtoB.item(),
                    'loss_gan_BtoA': loss_gan_BtoA.item(),
                    'loss_cyclic_AtoBtoA': loss_cyclic_AtoBtoA.item(),
                    'loss_cyclic_BtoAtoB': loss_cyclic_BtoAtoB.item(),
                    'loss_cyclic': loss_cyclic.item(),
                    'loss_d_A': loss_d_A.item(),
                    'loss_d_B': loss_d_B.item(),
                    'lr_optimizer_generators':
                    optimizer_generators.param_groups[0]['lr'],
                    'lr_optimizer_d_A': optimizer_d_A.param_groups[0]['lr'],
                    'lr_optimizer_d_B': optimizer_d_B.param_groups[0]['lr'],
                }
                if args.use_identity:
                    info['loss_identity_A'] = loss_identity_A.item()
                    info['loss_identity_B'] = loss_identity_B.item()
                for tag, value in info.items():
                    logger.scalar_summary(tag, value, step)

                info = {
                    'test_A_hat':
                    test_A_hat.data.numpy().transpose(0, 2, 3, 1),
                    'test_B_hat':
                    test_B_hat.data.numpy().transpose(0, 2, 3, 1),
                }
                for tag, images in info.items():
                    logger.image_summary(tag, images, step)

            step += 1

        # Save after every epoch. state_dict() does not depend on train/eval
        # mode, so the modules are left in training mode. Switching them to
        # eval here would leave the discriminators in eval mode for all later
        # epochs. Weights are written from CPU copies.
        if use_cuda:
            g_AtoB.cpu()
            g_BtoA.cpu()
            d_A.cpu()
            d_B.cpu()

        os.makedirs("models", exist_ok=True)
        for net_name, net in (('g_AtoB', g_AtoB), ('g_BtoA', g_BtoA),
                              ('d_A', d_A), ('d_B', d_B)):
            torch.save(net.state_dict(),
                       "models/" + net_name + "_epoch_" + str(e) + ".model")

        if use_cuda:
            g_AtoB.cuda()
            g_BtoA.cuda()
            d_A.cuda()
            d_B.cuda()