Exemplo n.º 1
0
    def __init__(self, transfer, descriptor, optimizer, wh, ctx, logger):
        """Initialize the style-transfer trainer.

        Args:
            transfer: the transfer (stylization) network being trained.
            descriptor: the descriptor network used to extract features.
            optimizer: optimizer updating the transfer network's parameters.
            wh: size configuration.  # presumably (width, height) — TODO confirm
            ctx: execution context the models run on.  # presumably a device/mx context — verify against caller
            logger: logger used for training diagnostics.
        """
        self._init_metric()
        self.transfer = transfer
        self.descriptor = descriptor
        self.optimizer = optimizer
        # presumably Gram matrices of the style targets, populated elsewhere
        self.gram_targets = []
        self.wh = wh
        # losses
        self.content_loss = CustomMSE()
        self.style_loss = CustomSSE()
        self.tv_loss = TVLoss()
        self.frame_coherent_loss = LuminanceLoss()
        self.fm_coherent_loss = FeatureLoss()
        self.ctx = ctx
        self.logger = logger

        # wall-clock start of training; fixed the superfluous bare `return`
        # that previously ended this __init__
        self.train_tic = timer()
Exemplo n.º 2
0
def train(args):
    """Train an SRGAN super-resolution model in two phases.

    Phase 1 pre-trains the generator alone with a pixel-wise L2 loss for
    ``args.pre_train_epoch`` epochs.  Phase 2 fine-tunes generator and
    discriminator adversarially (perceptual VGG loss + adversarial loss +
    total-variation loss + L2) for ``args.fine_train_epoch`` epochs.
    Checkpoints are written periodically under ./model/.

    Args:
        args: namespace holding the hyper-parameters read below (dataset
            paths, scale, patch_size, batch_size, loss coefficients, ...).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose(
        [crop(args.scale, args.patch_size),
         augmentation()])
    dataset = mydata(GT_path=args.GT_path,
                     LR_path=args.LR_path,
                     in_memory=args.in_memory,
                     transform=transform)
    loader = DataLoader(dataset,
                        batch_size=args.batch_size,
                        shuffle=True,
                        num_workers=args.num_workers)

    generator = Generator(img_feat=3,
                          n_feats=64,
                          kernel_size=3,
                          num_block=args.res_num,
                          scale=args.scale)

    if args.fine_tuning:
        generator.load_state_dict(torch.load(args.generator_path))
        print("pre-trained model is loaded")
        print("path : %s" % (args.generator_path))

    generator = generator.to(device)
    generator.train()

    l2_loss = nn.MSELoss()
    g_optim = optim.Adam(generator.parameters(), lr=1e-4)

    pre_epoch = 0
    fine_epoch = 0

    #### Phase 1: train the generator alone using L2 loss
    while pre_epoch < args.pre_train_epoch:
        for tr_data in loader:
            gt = tr_data['GT'].to(device)
            lr = tr_data['LR'].to(device)

            output, _ = generator(lr)
            loss = l2_loss(gt, output)

            g_optim.zero_grad()
            loss.backward()
            g_optim.step()

        pre_epoch += 1

        if pre_epoch % 2 == 0:
            print(pre_epoch)
            print(loss.item())
            print('=========')

        if pre_epoch % 800 == 0:
            torch.save(generator.state_dict(),
                       './model/pre_trained_model_%03d.pt' % pre_epoch)

    #### Phase 2: adversarial fine-tuning with perceptual & adversarial loss
    vgg_net = vgg19().to(device)
    vgg_net = vgg_net.eval()  # feature extractor only; never trained

    discriminator = Discriminator(patch_size=args.patch_size * args.scale)
    discriminator = discriminator.to(device)
    discriminator.train()

    d_optim = optim.Adam(discriminator.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.StepLR(g_optim, step_size=2000, gamma=0.1)

    VGG_loss = perceptual_loss(vgg_net)
    cross_ent = nn.BCELoss()
    tv_loss = TVLoss()

    while fine_epoch < args.fine_train_epoch:

        for tr_data in loader:
            gt = tr_data['GT'].to(device)
            lr = tr_data['LR'].to(device)

            ## Training Discriminator
            output, _ = generator(lr)
            # detach(): the discriminator update must not build (or
            # backpropagate through) the generator's graph.
            fake_prob = discriminator(output.detach())
            real_prob = discriminator(gt)

            # BUG FIX: size labels from the actual batch, not from
            # args.batch_size — the last batch of an epoch can be smaller
            # (DataLoader is created without drop_last), which previously
            # crashed BCELoss with a shape mismatch.
            real_label = torch.ones_like(real_prob)
            fake_label = torch.zeros_like(fake_prob)

            d_loss_real = cross_ent(real_prob, real_label)
            d_loss_fake = cross_ent(fake_prob, fake_label)

            d_loss = d_loss_real + d_loss_fake

            d_optim.zero_grad()
            d_loss.backward()
            d_optim.step()

            ## Training Generator
            output, _ = generator(lr)
            fake_prob = discriminator(output)

            _percep_loss, hr_feat, sr_feat = VGG_loss((gt + 1.0) / 2.0,
                                                      (output + 1.0) / 2.0,
                                                      layer=args.feat_layer)

            L2_loss = l2_loss(output, gt)
            percep_loss = args.vgg_rescale_coeff * _percep_loss
            # the generator tries to make the discriminator answer "real"
            adversarial_loss = args.adv_coeff * cross_ent(
                fake_prob, torch.ones_like(fake_prob))
            total_variance_loss = args.tv_loss_coeff * tv_loss(
                args.vgg_rescale_coeff * (hr_feat - sr_feat)**2)

            g_loss = percep_loss + adversarial_loss + total_variance_loss + L2_loss

            g_optim.zero_grad()
            g_loss.backward()
            g_optim.step()

        # BUG FIX: scheduler.step() now runs after the epoch's optimizer
        # steps — stepping the scheduler before optimizer.step() is the
        # deprecated pre-1.1 PyTorch order and skips the first LR value.
        scheduler.step()

        fine_epoch += 1

        if fine_epoch % 2 == 0:
            print(fine_epoch)
            print(g_loss.item())
            print(d_loss.item())
            print('=========')

        if fine_epoch % 500 == 0:
            torch.save(generator.state_dict(),
                       './model/SRGAN_gene_%03d.pt' % fine_epoch)
            torch.save(discriminator.state_dict(),
                       './model/SRGAN_discrim_%03d.pt' % fine_epoch)
Exemplo n.º 3
0
    def get_style_model_and_losses(self, model_dict, style_img, content_img):
        """Build a truncated copy of ``self.cnn`` instrumented with loss modules.

        A TVLoss module is prepended, then ContentLoss / StyleLoss modules
        are inserted after the layers named in ``self.content_layers_default``
        and ``self.style_layers_default``.  The model is trimmed after the
        last inserted loss, target feature maps are captured from
        ``style_img`` and ``content_img``, and every loss module is switched
        into its "loss" mode.

        Returns:
            (model, style_losses, content_losses, tv_losses)
        """
        self.cnn = copy.deepcopy(self.cnn)
        self.cnn = self.cnn.to(self.device)

        # per-type counters indexing the names supplied in model_dict
        conv_seen = 0
        relu_seen = 0
        pool_seen = 0

        # loss modules inserted into the model, in order of insertion
        content_losses = []
        style_losses = []
        tv_losses = []

        model = nn.Sequential()
        tv_mod = TVLoss(1e-3)
        model.add_module(str(len(model)), tv_mod)  # registered under name "0"
        tv_losses.append(tv_mod)

        for layer in self.cnn.children():
            if isinstance(layer, nn.Conv2d):
                name = model_dict['conv'][conv_seen]
                conv_seen += 1
            elif isinstance(layer, nn.ReLU):
                name = model_dict['relu'][relu_seen]
                relu_seen += 1
                layer = nn.ReLU(inplace=True)
            elif isinstance(layer, nn.MaxPool2d):
                name = model_dict['pool'][pool_seen]
                # layer = nn.AvgPool2d(kernel_size=2, stride=2)
                pool_seen += 1
            elif isinstance(layer, nn.BatchNorm2d):
                name = f'bn_{conv_seen}'
            else:
                raise RuntimeError(
                    f'Unrecognized Layer: {layer.__class__.__name__}')

            model.add_module(name, layer)

            if name in self.content_layers_default:
                c_loss = ContentLoss()
                model.add_module(f'content_loss_{conv_seen}', c_loss)
                content_losses.append(c_loss)

            if name in self.style_layers_default:
                # feature maps from the style image will be captured here
                s_loss = StyleLoss()
                model.add_module(f'style_loss_{conv_seen}', s_loss)
                style_losses.append(s_loss)

        # trim off any layers after the last content/style loss module
        # (the leading TVLoss at index 0 is always kept)
        cut = len(model) - 1
        while cut > 0 and not isinstance(model[cut], (ContentLoss, StyleLoss)):
            cut -= 1
        model = model[:cut + 1]

        def _set_mode(modules, mode):
            # flip every loss module in `modules` into the given mode
            for m in modules:
                m.mode = mode

        # run the style image through to record the REAL style targets,
        # then the content image for the content targets
        _set_mode(style_losses, "capture")
        model(style_img)
        _set_mode(style_losses, "none")

        _set_mode(content_losses, "capture")
        model(content_img)

        # from now on every forward pass accumulates losses
        _set_mode(style_losses, "loss")
        _set_mode(content_losses, "loss")

        return model, style_losses, content_losses, tv_losses