Example #1
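All three examples assume roughly the same set of imports, sketched below. The project-local module paths (RFRNet, VGG16FeatureExtractor, load_ckpt/save_ckpt, Visualizer, gram_matrix) follow the RFR-Inpainting repository layout but are assumptions, not part of the original listing; the GOT_AMP flag used in Example #2 is assumed to come from a guarded NVIDIA apex import.

import os
import time
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid, save_image

# Project-local modules; paths are assumptions following the RFR-Inpainting layout.
from modules.RFRNet import RFRNet, VGG16FeatureExtractor
from utils.io import load_ckpt, save_ckpt
from utils.visualize import Visualizer  # used by Example #1 (assumed path)
from utils.style import gram_matrix     # used by Example #2 (assumed path)

# Optional mixed-precision support via NVIDIA apex (used by Example #2)
try:
    from apex import amp
    GOT_AMP = True
except ImportError:
    GOT_AMP = False
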
class RFRNetModel():
    def __init__(self):
        self.G = None
        self.lossNet = None
        self.iter = None
        self.optm_G = None
        self.device = None
        self.real_A = None
        self.real_B = None
        self.fake_B = None
        self.comp_B = None
        self.mask = None
        self.l1_loss_val = 0.0
        self.visual_names = ['real_A', 'real_B', 'mask', 'fake_B', 'comp_B']
        self.model_names = ['G']
        self.visualizer = Visualizer()

    def initialize_model(self, path=None, train=True):
        self.G = RFRNet()
        self.optm_G = optim.Adam(self.G.parameters(), lr=2e-4)
        self.print_networks(False)
        if train:
            self.lossNet = VGG16FeatureExtractor()
        # Resume from a checkpoint when one is available; otherwise start fresh.
        try:
            start_iter = load_ckpt(path, [('generator', self.G)],
                                   [('optimizer_G', self.optm_G)])
            if train:
                self.optm_G = optim.Adam(self.G.parameters(), lr=2e-4)
                print('Model Initialized, iter: ', start_iter)
                self.iter = start_iter
        except Exception:
            print('No trained model found, training from scratch')
            self.iter = 0

    def cuda(self):
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            print("Model moved to cuda")
            self.G.cuda()
            if self.lossNet is not None:
                self.lossNet.cuda()
        else:
            self.device = torch.device("cpu")

    def train(self, train_loader, save_path, finetune=False, iters=450000):
        #    writer = SummaryWriter(log_dir="log_info")
        self.G.train(finetune=finetune)
        if finetune:
            self.optm_G = optim.Adam(filter(lambda p: p.requires_grad,
                                            self.G.parameters()),
                                     lr=5e-5)
        print("Starting training from iteration:{:d}".format(self.iter))
        s_time = time.time()
        while self.iter < iters:
            for items in train_loader:
                gt_images, masks = self.__cuda__(*items)
                masked_images = gt_images * masks
                self.forward(masked_images, masks, gt_images)
                self.update_parameters()
                self.iter += 1

                if self.iter % 50 == 0:
                    e_time = time.time()
                    int_time = e_time - s_time
                    print("Iteration:%d, l1_loss:%.4f, time_taken:%.2f" %
                          (self.iter, self.l1_loss_val / 50, int_time))
                    s_time = time.time()
                    self.l1_loss_val = 0.0

                if self.iter % 40000 == 0:
                    if not os.path.exists('{:s}'.format(save_path)):
                        os.makedirs('{:s}'.format(save_path))
                    save_ckpt('{:s}/g_{:d}.pth'.format(save_path, self.iter),
                              [('generator', self.G)],
                              [('optimizer_G', self.optm_G)], self.iter)

                if self.iter % 200 == 0:
                    self.visualizer.display_current_results(
                        self.get_current_visuals(), self.iter, False)

        if not os.path.exists('{:s}'.format(save_path)):
            os.makedirs('{:s}'.format(save_path))
        # Always save a final checkpoint.
        save_ckpt('{:s}/g_{:s}.pth'.format(save_path, "final"),
                  [('generator', self.G)], [('optimizer_G', self.optm_G)],
                  self.iter)

    def test(self, test_loader, result_save_path):
        self.G.eval()
        for para in self.G.parameters():
            para.requires_grad = False
        count = 0
        for items in test_loader:
            gt_images, masks = self.__cuda__(*items)
            masked_images = gt_images * masks
            # Expand the single-channel mask to 3 channels to match RGB images
            masks = torch.cat([masks] * 3, dim=1)
            fake_B, mask = self.G(masked_images, masks)
            comp_B = fake_B * (1 - masks) + gt_images * masks
            if not os.path.exists('{:s}/results'.format(result_save_path)):
                os.makedirs('{:s}/results'.format(result_save_path))
            for k in range(comp_B.size(0)):
                count += 1
                grid = make_grid(comp_B[k:k + 1])
                file_path = '{:s}/results/img_{:d}.png'.format(
                    result_save_path, count)
                save_image(grid, file_path)

                grid = make_grid(masked_images[k:k + 1] + 1 - masks[k:k + 1])
                file_path = '{:s}/results/masked_img_{:d}.png'.format(
                    result_save_path, count)
                save_image(grid, file_path)

    def forward(self, masked_image, mask, gt_image):
        self.real_A = masked_image
        self.real_B = gt_image
        self.mask = mask
        fake_B, _ = self.G(masked_image, mask)
        self.fake_B = fake_B
        self.comp_B = self.fake_B * (1 - mask) + self.real_B * mask

    def update_parameters(self):
        self.update_G()
        self.update_D()

    def update_G(self):
        self.optm_G.zero_grad()
        loss_G = self.get_g_loss()
        loss_G.backward()
        self.optm_G.step()

    def update_D(self):
        return

    def get_g_loss(self):
        real_B = self.real_B
        fake_B = self.fake_B
        comp_B = self.comp_B

        real_B_feats = self.lossNet(real_B)
        fake_B_feats = self.lossNet(fake_B)
        comp_B_feats = self.lossNet(comp_B)

        tv_loss = self.TV_loss(comp_B * (1 - self.mask))
        style_loss = self.style_loss(real_B_feats,
                                     fake_B_feats) + self.style_loss(
                                         real_B_feats, comp_B_feats)
        perceptual_loss = self.perceptual_loss(
            real_B_feats, fake_B_feats) + self.perceptual_loss(
                real_B_feats, comp_B_feats)
        valid_loss = self.l1_loss(real_B, fake_B, self.mask)
        hole_loss = self.l1_loss(real_B, fake_B, (1 - self.mask))

        loss_G = (tv_loss * 0.1 + style_loss * 120 + perceptual_loss * 0.05 +
                  valid_loss * 1 + hole_loss * 6)

        self.l1_loss_val += valid_loss.detach() + hole_loss.detach()
        return loss_G

    def l1_loss(self, f1, f2, mask=1):
        return torch.mean(torch.abs(f1 - f2) * mask)

    def style_loss(self, A_feats, B_feats):
        assert len(A_feats) == len(B_feats), \
            "the two input feature lists must have the same length"
        loss_value = 0.0
        for i in range(len(A_feats)):
            A_feat = A_feats[i]
            B_feat = B_feats[i]
            _, c, h, w = A_feat.size()
            # Flatten the spatial dims and compare Gram matrices
            A_feat = A_feat.view(A_feat.size(0), A_feat.size(1),
                                 A_feat.size(2) * A_feat.size(3))
            B_feat = B_feat.view(B_feat.size(0), B_feat.size(1),
                                 B_feat.size(2) * B_feat.size(3))
            A_style = torch.matmul(A_feat, A_feat.transpose(2, 1))
            B_style = torch.matmul(B_feat, B_feat.transpose(2, 1))
            loss_value += torch.mean(
                torch.abs(A_style - B_style) / (c * h * w))
        return loss_value

    def TV_loss(self, x):
        h_x = x.size(2)
        w_x = x.size(3)
        h_tv = torch.mean(torch.abs(x[:, :, 1:, :] - x[:, :, :h_x - 1, :]))
        w_tv = torch.mean(torch.abs(x[:, :, :, 1:] - x[:, :, :, :w_x - 1]))
        return h_tv + w_tv

    def perceptual_loss(self, A_feats, B_feats):
        assert len(A_feats) == len(B_feats), \
            "the two input feature lists must have the same length"
        loss_value = 0.0
        for i in range(len(A_feats)):
            A_feat = A_feats[i]
            B_feat = B_feats[i]
            loss_value += torch.mean(torch.abs(A_feat - B_feat))
        return loss_value

    def __cuda__(self, *args):
        return (item.to(self.device) for item in args)

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                attr = getattr(self, name)
                if isinstance(attr, list):
                    for i in range(len(attr)):
                        visual_ret[name + ':' + str(i)] = attr[i]
                else:
                    visual_ret[name] = attr
        return visual_ret

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' %
                      (name, num_params / 1e6))
        print('-----------------------------------------------')
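
A minimal driving sketch for this class, assuming a dataset whose items unpack as (gt_image, mask) pairs with masks equal to 1 on known pixels and 0 inside the holes; the dataset name and checkpoint paths below are hypothetical.

# Hypothetical usage sketch; my_inpainting_dataset and the paths are assumptions.
from torch.utils.data import DataLoader

model = RFRNetModel()
model.initialize_model(path='checkpoints/g_40000.pth', train=True)
model.cuda()

train_loader = DataLoader(my_inpainting_dataset, batch_size=6, shuffle=True)
model.train(train_loader, save_path='checkpoints', finetune=False, iters=450000)
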
Example #2
class RFRNetModel():
    def __init__(self):
        self.loss_weights = {
            "tv": 0.1,
            "style": 120,
            "perceptual": 0.05,
            "valid": 1,
            "hole": 6
        }
        self.learning_rates = {"train": 2e-4, "finetune": 5e-5}
        self.save_freq = 10000

        self.G = None
        self.lossNet = None
        self.optm_G = None
        self.device = None

        self.iter = None
        self.real_A = None
        self.real_B = None
        self.fake_B = None
        self.comp_B = None
        self.l1_loss_val = 0.0

    def initialize_model(self, path=None, train=True):
        self.G = RFRNet()
        self.optm_G = optim.Adam(self.G.parameters(),
                                 lr=self.learning_rates["train"])
        if train:
            self.lossNet = VGG16FeatureExtractor()

        # Resume from a checkpoint when one is available; otherwise start fresh.
        try:
            start_iter = load_ckpt(path, [('generator', self.G)],
                                   [('optimizer_G', self.optm_G)])

            if train:
                self.optm_G = optim.Adam(self.G.parameters(),
                                         lr=self.learning_rates["train"])
                print('Model Initialized, iter: ', start_iter)
                self.iter = start_iter
        except Exception:
            print('No trained model found, training from scratch')
            self.iter = 0

        return self

    def cuda(self):
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            print("Model moved to cuda")

            self.G.cuda()
            if self.lossNet is not None:
                self.lossNet.cuda()
        else:
            self.device = torch.device("cpu")
        return self

    def multi_gpu(self):
        print(f'Multi-GPU training with {torch.cuda.device_count()} GPUs')
        self.G = nn.DataParallel(self.G)
        if self.lossNet is not None:
            self.lossNet = nn.DataParallel(self.lossNet)
        return self

    def train(self,
              train_loader,
              save_path,
              finetune=False,
              iters=450000,
              fp16=False,
              multi_gpu=True):
        writer = SummaryWriter()

        self.G.train(finetune=finetune)
        # Overwrite optimizer with a lower lr
        if finetune:
            self.optm_G = optim.Adam(filter(lambda p: p.requires_grad,
                                            self.G.parameters()),
                                     lr=self.learning_rates["finetune"])

        self.fp16 = fp16 and GOT_AMP
        if self.fp16:
            self.G, self.optm_G = amp.initialize(self.G,
                                                 self.optm_G,
                                                 opt_level="O1")
            if self.lossNet is not None:
                self.lossNet = amp.initialize(self.lossNet, opt_level="O1")

        if multi_gpu:
            self.multi_gpu()

        print("Starting training from iteration: {:d}, finetuning: {}".format(
            self.iter, finetune))
        s_time = time.time()
        while self.iter < iters:
            for items in train_loader:
                gt_images, masks = self.__cuda__(*items)
                masked_images = gt_images * masks

                self.forward(masked_images, masks, gt_images)
                self.update_parameters()

                for k, v in self.metrics["lossG"].items():
                    writer.add_scalar(f"lossG/{k}", v, global_step=self.iter)

                self.iter += 1

                if self.iter % 200 == 0:
                    e_time = time.time()
                    int_time = e_time - s_time
                    # Average the accumulated L1 loss over the 200-iteration window.
                    print("Iteration:%d, l1_loss:%.4f, time_taken:%.2f" %
                          (self.iter, self.l1_loss_val / 200, int_time))

                    writer.add_images("real_A",
                                      self.real_A,
                                      global_step=self.iter)
                    writer.add_images("mask", self.mask, global_step=self.iter)
                    writer.add_images("real_B",
                                      self.real_B,
                                      global_step=self.iter)
                    writer.add_images("fake_B",
                                      self.fake_B,
                                      global_step=self.iter)
                    writer.add_images("comp_B",
                                      self.comp_B,
                                      global_step=self.iter)

                    # Reset
                    s_time = time.time()
                    self.l1_loss_val = 0.0

                if self.iter % self.save_freq == 0:
                    if not os.path.exists('{:s}'.format(save_path)):
                        os.makedirs('{:s}'.format(save_path))
                    save_ckpt(
                        '{:s}/g_{:d}{}.pth'.format(
                            save_path, self.iter,
                            "_finetune" if finetune else ""),
                        [('generator', self.G)],
                        [('optimizer_G', self.optm_G)], self.iter)

        if not os.path.exists('{:s}'.format(save_path)):
            os.makedirs('{:s}'.format(save_path))
        # Always save a final checkpoint.
        save_ckpt(
            '{:s}/g_{:s}{}.pth'.format(save_path, "final",
                                       "_finetune" if finetune else ""),
            [('generator', self.G)], [('optimizer_G', self.optm_G)],
            self.iter)

    def test(self, test_loader, result_save_path):
        self.G.eval()

        for para in self.G.parameters():
            para.requires_grad = False

        count = 0
        for items in test_loader:
            gt_images, masks = self.__cuda__(*items)
            masked_images = gt_images * masks
            # print(f">>> masks.shape: {masks.shape}")
            if masks.size(1) == 1:
                masks = torch.cat([masks] * 3, dim=1)

            fake_B, mask = self.G(masked_images, masks)
            comp_B = fake_B * (1 - masks) + gt_images * masks

            if not os.path.exists('{:s}/results'.format(result_save_path)):
                os.makedirs('{:s}/results'.format(result_save_path))

            for k in range(comp_B.size(0)):
                count += 1
                grid = make_grid(comp_B[k:k + 1])
                file_path = '{:s}/results/img_{:d}.png'.format(
                    result_save_path, count)
                save_image(grid, file_path)

                grid = make_grid(masked_images[k:k + 1] + 1 - masks[k:k + 1])
                file_path = '{:s}/results/masked_img_{:d}.png'.format(
                    result_save_path, count)
                save_image(grid, file_path)

    def forward(self, masked_image, mask, gt_image):
        self.real_A = masked_image
        self.real_B = gt_image
        self.mask = mask
        fake_B, _ = self.G(masked_image, mask)
        self.fake_B = fake_B
        self.comp_B = self.fake_B * (1 - mask) + self.real_B * mask

    def update_parameters(self):
        self.update_G()
        self.update_D()

    def update_G(self):
        self.optm_G.zero_grad()
        loss_G = self.get_g_loss()
        if self.fp16:
            with amp.scale_loss(loss_G, self.optm_G) as scaled_loss:
                scaled_loss.backward()
        else:
            loss_G.backward()
        self.optm_G.step()

    def update_D(self):
        return

    def get_g_loss(self):
        real_B = self.real_B
        fake_B = self.fake_B
        comp_B = self.comp_B

        real_B_feats = self.lossNet(real_B)
        fake_B_feats = self.lossNet(fake_B)
        comp_B_feats = self.lossNet(comp_B)

        tv_loss = self.TV_loss(comp_B * (1 - self.mask))
        style_loss = self.style_loss(real_B_feats, fake_B_feats) \
            + self.style_loss(real_B_feats, comp_B_feats)
        perceptual_loss = self.perceptual_loss(real_B_feats, fake_B_feats) \
            + self.perceptual_loss(real_B_feats, comp_B_feats)
        valid_loss = self.l1_loss(real_B, fake_B, self.mask)
        hole_loss = self.l1_loss(real_B, fake_B, (1 - self.mask))

        loss_G = (tv_loss * self.loss_weights["tv"] +
                  style_loss * self.loss_weights["style"] +
                  perceptual_loss * self.loss_weights["perceptual"] +
                  valid_loss * self.loss_weights["valid"] +
                  hole_loss * self.loss_weights["hole"])

        self.l1_loss_val += valid_loss.detach() + hole_loss.detach()

        self.metrics = {
            "lossG": {
                "sum": loss_G.item(),
                "tv": tv_loss.item() * self.loss_weights["tv"],
                "style": style_loss.item() * self.loss_weights["style"],
                "perceptual":
                    perceptual_loss.item() * self.loss_weights["perceptual"],
                "valid": valid_loss.item() * self.loss_weights["valid"],
                "hole": hole_loss.item() * self.loss_weights["hole"],
            }
        }
        print(f"#{self.iter:08d} - lossG: {self.metrics['lossG']}")

        return loss_G

    @staticmethod
    def l1_loss(f1, f2, mask=1):
        return torch.mean(torch.abs(f1 - f2) * mask)

    @staticmethod
    def style_loss(A_feats, B_feats):
        assert len(A_feats) == len(B_feats), \
            "the two input feature lists must have the same length"

        loss_value = 0.0
        for i in range(len(A_feats)):
            A_feat = A_feats[i]
            B_feat = B_feats[i]

            # _, c, w, h = A_feat.size()
            # A_feat = A_feat.view(A_feat.size(0),
            #                      A_feat.size(1),
            #                      A_feat.size(2) * A_feat.size(3))
            # B_feat = B_feat.view(B_feat.size(0),
            #                      B_feat.size(1),
            #                      B_feat.size(2) * B_feat.size(3))
            # A_style = torch.matmul(A_feat, A_feat.transpose(2, 1))
            # B_style = torch.matmul(B_feat, B_feat.transpose(2, 1))
            # loss_value += torch.mean(torch.abs(A_style - B_style)/(c * w * h))

            # Avoid underflow when using mixed precision training
            gram_A = gram_matrix(A_feat)
            gram_B = gram_matrix(B_feat)
            loss_value += torch.mean(torch.abs(gram_A - gram_B))

        return loss_value

    @staticmethod
    def TV_loss(x):
        h_tv = torch.mean(torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]))
        w_tv = torch.mean(torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]))
        return h_tv + w_tv

    @staticmethod
    def perceptual_loss(A_feats, B_feats):
        assert len(A_feats) == len(B_feats), \
            "the two input feature lists must have the same length"

        loss_value = 0.0
        for i in range(len(A_feats)):
            A_feat = A_feats[i]
            B_feat = B_feats[i]
            loss_value += torch.mean(torch.abs(A_feat - B_feat))

        return loss_value

    def __cuda__(self, *args):
        return (item.to(self.device) for item in args)
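
Example #2 relies on a gram_matrix helper that the listing does not show. A plausible sketch, consistent with the commented-out computation it replaces (flatten the spatial dims, take the feature autocorrelation, normalize by c * h * w), is:

def gram_matrix(feat):
    # feat: (N, C, H, W) feature map -> (N, C, C) Gram matrix.
    n, c, h, w = feat.size()
    feat = feat.view(n, c, h * w)
    gram = torch.bmm(feat, feat.transpose(1, 2))
    # Normalize so the style loss keeps a small magnitude under mixed precision.
    return gram / (c * h * w)
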
Example #3
class RFRNetModel():
    def __init__(self):
        self.G = None
        self.lossNet = None
        self.iter = None
        self.optm_G = None
        self.device = None
        self.real_A = None
        self.real_B = None
        self.fake_B = None
        self.comp_B = None
        self.l1_loss_val = 0.0

    def initialize_model(self, path=None, train=True):
        self.G = RFRNet()
        self.optm_G = optim.Adam(self.G.parameters(), lr=2e-4)
        if train:
            self.lossNet = VGG16FeatureExtractor()
        # Resume from a checkpoint when one is available; otherwise start fresh.
        try:
            start_iter = load_ckpt(path, [('generator', self.G)],
                                   [('optimizer_G', self.optm_G)])
            if train:
                self.optm_G = optim.Adam(self.G.parameters(), lr=2e-4)
                print('Model Initialized, iter: ', start_iter)
                self.iter = start_iter
        except Exception:
            print('No trained model found, training from scratch')
            self.iter = 0

    def cuda(self):
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            print("Model moved to cuda")
            self.G.cuda()
            if self.lossNet is not None:
                self.lossNet.cuda()
        else:
            self.device = torch.device("cpu")

    def train(self,
              train_loader,
              save_path,
              finetune=False,
              iters=450000,
              batch_size=6,
              batch_preload_count=1):
        #    writer = SummaryWriter(log_dir="log_info")
        self.G.train(finetune=finetune)
        if finetune:
            self.optm_G = optim.Adam(filter(lambda p: p.requires_grad,
                                            self.G.parameters()),
                                     lr=5e-5)
        print("Starting training from iteration:{:d}".format(self.iter))
        s_time = time.time()

        while self.iter < iters:
            for items in train_loader:
                gt_image_batch, mask_batch, masked_image_batch = self.__cuda__(
                    *items)
                # print("New batch of %s elements" %(items[0].size()[0]))

                # Slice the preloaded mega-batch into sub-batches of batch_size.
                for batch_idx in range(batch_preload_count):
                    left = batch_idx * batch_size
                    right = min(left + batch_size, gt_image_batch.size(0))
                    gt_image = gt_image_batch[left:right]
                    mask = mask_batch[left:right]
                    masked_image = masked_image_batch[left:right]

                    if gt_image.size(0) == 0:
                        break

                    # print(len(train_loader), batch_idx, left, right, gt_image, mask, masked_image)
                    self.forward(masked_image, mask, gt_image)
                    self.update_parameters()
                    self.iter += 1

                    if self.iter % 50 == 0:
                        e_time = time.time()
                        int_time = e_time - s_time
                        print("Iteration:%d, l1_loss:%.4f, time_taken:%.2f" %
                              (self.iter, self.l1_loss_val / 50, int_time))
                        s_time = time.time()
                        self.l1_loss_val = 0.0

                    if self.iter % 40000 == 0:
                        if not os.path.exists('{:s}'.format(save_path)):
                            os.makedirs('{:s}'.format(save_path))
                        save_ckpt(
                            '{:s}/g_{:d}.pth'.format(save_path, self.iter),
                            [('generator', self.G)],
                            [('optimizer_G', self.optm_G)], self.iter)

                    if self.iter >= iters:
                        break

                if self.iter >= iters:
                    break

        print("Finished training iter %d. Saving model." % (self.iter))

        if not os.path.exists('{:s}'.format(save_path)):
            os.makedirs('{:s}'.format(save_path))
        # Save final checkpoint
        save_ckpt('{:s}/g_{:s}.pth'.format(save_path, "final"),
                  [('generator', self.G)], [('optimizer_G', self.optm_G)],
                  self.iter)
        save_ckpt('{:s}/g_{:s}_{:d}.pth'.format(save_path, "final", self.iter),
                  [('generator', self.G)], [('optimizer_G', self.optm_G)],
                  self.iter)

    def test(self, test_loader, result_save_path):
        self.G.eval()
        for para in self.G.parameters():
            para.requires_grad = False
        count = 0
        for items in test_loader:
            gt_images, masks, masked_images = self.__cuda__(*items)
            masks = torch.cat([masks] * 3, dim=1)
            fake_B, mask = self.G(masked_images, masks)
            comp_B = fake_B * (1 - masks) + gt_images * masks
            if not os.path.exists('{:s}/results'.format(result_save_path)):
                os.makedirs('{:s}/results'.format(result_save_path))
            for k in range(comp_B.size(0)):
                count += 1
                grid = make_grid(comp_B[k:k + 1])
                file_path = '{:s}/results/img_{:d}.png'.format(
                    result_save_path, count)
                save_image(grid, file_path)

                grid = make_grid(masked_images[k:k + 1] + 1 - masks[k:k + 1])
                file_path = '{:s}/results/masked_img_{:d}.png'.format(
                    result_save_path, count)
                save_image(grid, file_path)

    def forward(self, masked_image, mask, gt_image):
        self.real_A = masked_image
        self.real_B = gt_image
        self.mask = mask
        fake_B, _ = self.G(masked_image, mask)
        self.fake_B = fake_B
        self.comp_B = self.fake_B * (1 - mask) + self.real_B * mask

    def update_parameters(self):
        self.update_G()
        self.update_D()

    def update_G(self):
        self.optm_G.zero_grad()
        loss_G = self.get_g_loss()
        loss_G.backward()
        self.optm_G.step()

    def update_D(self):
        return

    def get_g_loss(self):
        real_B = self.real_B
        fake_B = self.fake_B
        comp_B = self.comp_B

        real_B_feats = self.lossNet(real_B)
        fake_B_feats = self.lossNet(fake_B)
        comp_B_feats = self.lossNet(comp_B)

        tv_loss = self.TV_loss(comp_B * (1 - self.mask))
        style_loss = self.style_loss(real_B_feats,
                                     fake_B_feats) + self.style_loss(
                                         real_B_feats, comp_B_feats)
        perceptual_loss = self.perceptual_loss(
            real_B_feats, fake_B_feats) + self.perceptual_loss(
                real_B_feats, comp_B_feats)
        valid_loss = self.l1_loss(real_B, fake_B, self.mask)
        hole_loss = self.l1_loss(real_B, fake_B, (1 - self.mask))

        loss_G = (tv_loss * 0.1 + style_loss * 120 + perceptual_loss * 0.05 +
                  valid_loss * 1 + hole_loss * 6)

        self.l1_loss_val += valid_loss.detach() + hole_loss.detach()
        return loss_G

    def l1_loss(self, f1, f2, mask=1):
        return torch.mean(torch.abs(f1 - f2) * mask)

    def style_loss(self, A_feats, B_feats):
        assert len(A_feats) == len(B_feats), \
            "the two input feature lists must have the same length"
        loss_value = 0.0
        for i in range(len(A_feats)):
            A_feat = A_feats[i]
            B_feat = B_feats[i]
            _, c, h, w = A_feat.size()
            # Flatten the spatial dims and compare Gram matrices
            A_feat = A_feat.view(A_feat.size(0), A_feat.size(1),
                                 A_feat.size(2) * A_feat.size(3))
            B_feat = B_feat.view(B_feat.size(0), B_feat.size(1),
                                 B_feat.size(2) * B_feat.size(3))
            A_style = torch.matmul(A_feat, A_feat.transpose(2, 1))
            B_style = torch.matmul(B_feat, B_feat.transpose(2, 1))
            loss_value += torch.mean(
                torch.abs(A_style - B_style) / (c * h * w))
        return loss_value

    def TV_loss(self, x):
        h_x = x.size(2)
        w_x = x.size(3)
        h_tv = torch.mean(torch.abs(x[:, :, 1:, :] - x[:, :, :h_x - 1, :]))
        w_tv = torch.mean(torch.abs(x[:, :, :, 1:] - x[:, :, :, :w_x - 1]))
        return h_tv + w_tv

    def perceptual_loss(self, A_feats, B_feats):
        assert len(A_feats) == len(B_feats), \
            "the two input feature lists must have the same length"
        loss_value = 0.0
        for i in range(len(A_feats)):
            A_feat = A_feats[i]
            B_feat = B_feats[i]
            loss_value += torch.mean(torch.abs(A_feat - B_feat))
        return loss_value

    def __cuda__(self, *args):
        return (item.to(self.device) for item in args)
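
Unlike the first two examples, this variant expects each loader item to unpack as a (gt_images, masks, masked_images) triplet and to deliver batch_preload_count * batch_size samples at once, which the inner training loop then slices into sub-batches. A hypothetical setup consistent with that contract (the dataset name is an assumption):

# Hypothetical setup; my_triplet_dataset is an assumption.
from torch.utils.data import DataLoader

batch_size = 6
batch_preload_count = 4

# Each dataset item is assumed to be (gt_image, mask, masked_image).
train_loader = DataLoader(my_triplet_dataset,
                          batch_size=batch_size * batch_preload_count,
                          shuffle=True)

model = RFRNetModel()
model.initialize_model(train=True)
model.cuda()
model.train(train_loader, 'checkpoints', batch_size=batch_size,
            batch_preload_count=batch_preload_count)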