示例#1
0
    def __init__(self, args, logger, dataloader, model, loss_all):
        """Set up the trainer: VGG19 feature extractor, Adam optimizer, StepLR.

        Parameters
        ----------
        args : namespace
            Parsed options; reads ``cpu``, ``num_gpu``, ``lr_rate``,
            ``lr_rate_lte``, ``beta1``, ``beta2``, ``eps``, ``decay`` and
            ``gamma``.
        logger : object
            Stored as ``self.logger``.
        dataloader : object
            Stored as ``self.dataloader``.
        model : nn.Module
            Network exposing ``MainNet`` and ``LTE`` sub-modules, possibly
            wrapped in ``nn.DataParallel`` / ``DistributedDataParallel``.
        loss_all : object
            Stored as ``self.loss_all``.
        """
        self.args = args
        self.logger = logger
        self.dataloader = dataloader
        self.model = model
        self.loss_all = loss_all
        self.device = torch.device('cpu') if args.cpu else torch.device('cuda')
        self.vgg19 = Vgg19.Vgg19(requires_grad=False).to(self.device)
        if (not self.args.cpu) and (self.args.num_gpu > 1):
            self.vgg19 = nn.DataParallel(self.vgg19, list(range(self.args.num_gpu)))

        # Unwrap based on the model's actual type rather than ``num_gpu == 1``:
        # the wrapping condition above also depends on ``args.cpu``, so e.g.
        # ``--cpu`` with ``num_gpu > 1`` (or ``num_gpu == 0``) leaves the model
        # unwrapped and ``self.model.module`` would raise AttributeError.
        wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
        net = self.model.module if isinstance(self.model, wrappers) else self.model

        # Materialised lists (not one-shot ``filter`` iterators) so that
        # ``self.params`` stays usable after the optimizer has consumed it.
        self.params = [
            {"params": [p for p in net.MainNet.parameters() if p.requires_grad],
             "lr": args.lr_rate},
            {"params": [p for p in net.LTE.parameters() if p.requires_grad],
             "lr": args.lr_rate_lte},
        ]
        self.optimizer = optim.Adam(self.params, betas=(args.beta1, args.beta2), eps=args.eps)
        self.scheduler = optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.args.decay, gamma=self.args.gamma)

        # Best validation scores seen so far and the epochs they occurred in.
        self.max_psnr = 0.
        self.max_psnr_epoch = 0
        self.max_ssim = 0.
        self.max_ssim_epoch = 0
示例#2
0
def train(config):
    """Optimise an image with Adam so it matches a content and a style image.

    Parameters
    ----------
    config : dict
        Reads ``content_img_path``, ``style_img_path``, ``height``,
        ``content_weight``, ``style_weight``, ``tv_weight``, ``iteration``,
        ``output_img_path``, ``saving_freq`` and ``img_format``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Vgg19().to(device).eval()

    content_image = prepare_img(config['content_img_path'], config['height'],
                                device)
    style_image = prepare_img(config['style_img_path'], config['height'],
                              device)

    # The targets are constants of the optimisation. Computing them under
    # no_grad keeps them free of autograd history, so a backward pass through
    # the loss never reaches whatever in ``model`` requires grad, nor re-walks
    # the (already freed) graph that produced the targets on later iterations.
    with torch.no_grad():
        content_targets, _ = model(content_image)
        _, style_targets = model(style_image)
        style_targets = [GramMatrix(s) for s in style_targets]

    # ``torch.autograd.Variable`` is deprecated; optimise an explicit leaf copy
    # of the content image so ``content_image`` itself is never updated.
    optimizing_img = content_image.detach().clone().requires_grad_(True)

    # Alternative tried previously: torch.optim.LBFGS((optimizing_img,),
    # max_iter=1000, line_search_fn='strong_wolfe').
    optimizer = torch.optim.Adam((optimizing_img, ), lr=1e1)

    tuning_step = make_tuning_step(content_targets, style_targets, model,
                                   optimizer, config['content_weight'],
                                   config['style_weight'], config['tv_weight'])

    for cnt in range(config['iteration']):
        total_loss, content_loss, style_loss, tv_loss = tuning_step(
            optimizing_img)
        print(
            f'Epochs: {cnt} | total_loss: {total_loss.item():12.4f} | content_loss: {content_loss.item() * config["content_weight"]:12.4f} | style_loss: {style_loss.item() * config["style_weight"]:12.4f}| tv_loss: {tv_loss.item() * config["tv_weight"]:12.4f}'
        )
        save_and_maybe_display(cnt,
                               optimizing_img,
                               config['output_img_path'],
                               config['saving_freq'],
                               config['iteration'],
                               config['img_format'],
                               should_display=True)
示例#3
0
    def __init__(self, config, outdir, modeldir, data_path, sketch_path,
                 ss_path):
        """Assemble the trainer that is used with ``WhiteBoxLossCalculator``.

        Builds the dataset, a generator, image/surface/texture discriminators
        (each paired with an optimizer by ``_setting_model_optim``), two guided
        filters and a VGG19 feature extractor; the filters and VGG19 are moved
        to CUDA here.

        Parameters
        ----------
        config : dict
            Needs ``"train"``, ``"dataset"``, ``"model"`` and ``"loss"`` sections.
        outdir, modeldir : path-like
            Kept as ``self.outdir`` and ``self.modeldir``.
        data_path, sketch_path, ss_path : path-like
            Passed straight to ``IllustDataset``.
        """
        self.train_config = config["train"]
        self.data_config = config["dataset"]
        model_config = config["model"]
        self.loss_config = config["loss"]

        self.outdir = outdir
        self.modeldir = modeldir

        # Positional dataset options, in the order IllustDataset expects them.
        dataset_keys = ("line_method", "extension", "train_size",
                        "valid_size", "color_space", "line_space")
        self.dataset = IllustDataset(
            data_path, sketch_path, ss_path,
            *(self.data_config[key] for key in dataset_keys))
        print(self.dataset)

        gen_cfg = model_config["generator"]
        generator = Generator(gen_cfg["in_ch"],
                              num_layers=gen_cfg["num_layers"],
                              attn_type=gen_cfg["attn_type"],
                              guide=gen_cfg["guide"])
        self.gen, self.gen_opt = self._setting_model_optim(generator, gen_cfg)
        self.guide = gen_cfg["guide"]

        img_cfg = model_config["image_dis"]
        self.i_dis, self.i_dis_opt = self._setting_model_optim(
            Discriminator(img_cfg["in_ch"], img_cfg["multi"]), img_cfg)

        surf_cfg = model_config["surface_dis"]
        self.s_dis, self.s_dis_opt = self._setting_model_optim(
            Discriminator(surf_cfg["in_ch"], surf_cfg["multi"]), surf_cfg)

        tex_cfg = model_config["texture_dis"]
        self.t_dis, self.t_dis_opt = self._setting_model_optim(
            Discriminator(tex_cfg["in_ch"], tex_cfg["multi"]), tex_cfg)

        self.guided_filter = GuidedFilter(r=5, eps=2e-1)
        self.guided_filter.cuda()
        self.out_guided_filter = GuidedFilter(r=1, eps=1e-2)
        self.out_guided_filter.cuda()

        # Feature extractor built with requires_grad=False, kept in eval mode.
        self.vgg = Vgg19(requires_grad=False)
        self.vgg.cuda()
        self.vgg.eval()

        self.lossfunc = WhiteBoxLossCalculator()
        self.visualizer = Visualizer(self.data_config["color_space"])
示例#4
0
    def __init__(self, config, outdir, modeldir, data_path, sketch_path,
                 ss_path):
        """Assemble the trainer used with ``LossCalculator``.

        Creates the dataset, a generator and a discriminator (each paired with
        an optimizer via ``_setting_model_optim``), a VGG19 feature extractor,
        an output guided filter and one ``ExponentialLR`` scheduler per
        optimizer.

        Parameters
        ----------
        config : dict
            Needs ``"train"``, ``"dataset"``, ``"model"`` and ``"loss"`` sections.
        outdir, modeldir : path-like
            Kept as ``self.outdir`` and ``self.modeldir``.
        data_path, sketch_path, ss_path : path-like
            Passed straight to ``IllustDataset``.
        """
        self.train_config = config["train"]
        self.data_config = config["dataset"]
        model_config = config["model"]
        self.loss_config = config["loss"]

        self.outdir = outdir
        self.modeldir = modeldir

        dataset_keys = ("line_method", "extension", "train_size",
                        "valid_size", "color_space", "line_space")
        self.dataset = IllustDataset(
            data_path, sketch_path, ss_path,
            *(self.data_config[key] for key in dataset_keys))
        print(self.dataset)

        gen_cfg = model_config["generator"]
        gen_options = {key: gen_cfg[key]
                       for key in ("base", "num_layers", "up_layers", "guide",
                                   "resnext", "encoder_type")}
        self.gen, self.gen_opt = self._setting_model_optim(
            Generator(gen_cfg["in_ch"], **gen_options), gen_cfg)
        self.guide = gen_cfg["guide"]

        dis_cfg = model_config["discriminator"]
        dis_options = {key: dis_cfg[key]
                       for key in ("base", "sn", "resnext", "patch")}
        self.dis, self.dis_opt = self._setting_model_optim(
            Discriminator(dis_cfg["in_ch"], dis_cfg["multi"], **dis_options),
            dis_cfg)

        # VGG19 truncated at layer "four", built with requires_grad=False.
        self.vgg = Vgg19(requires_grad=False, layer="four")
        self.vgg.cuda()
        self.vgg.eval()

        self.out_filter = GuidedFilter(r=1, eps=1e-2)
        self.out_filter.cuda()

        self.lossfunc = LossCalculator()
        self.visualizer = Visualizer(self.data_config["color_space"])

        # Both optimizers decay with the same per-step factor.
        decay = self.train_config["gamma"]
        self.scheduler_gen = torch.optim.lr_scheduler.ExponentialLR(
            self.gen_opt, decay)
        self.scheduler_dis = torch.optim.lr_scheduler.ExponentialLR(
            self.dis_opt, decay)
示例#5
0
    def __init__(
        self,
        config,
        outdir,
        modeldir,
        data_path,
        sketch_path,
    ):
        """Assemble the trainer used with ``Pix2pixHDCalculator``.

        Builds the dataset, a local enhancer and a global generator (input
        channel count depends on ``train.mask``), a discriminator, and a VGG19
        feature extractor on CUDA.

        Parameters
        ----------
        config : dict
            Needs ``"train"`` (including ``"mask"``), ``"dataset"``,
            ``"model"`` and ``"loss"`` sections.
        outdir, modeldir : path-like
            Kept as ``self.outdir`` and ``self.modeldir``.
        data_path, sketch_path : path-like
            Passed straight to ``IllustDataset``.
        """
        self.train_config = config["train"]
        self.data_config = config["dataset"]
        model_config = config["model"]
        self.loss_config = config["loss"]

        self.outdir = outdir
        self.modeldir = modeldir
        self.mask = self.train_config["mask"]

        dataset_keys = ("extension", "train_size", "valid_size",
                        "color_space", "line_space")
        self.dataset = IllustDataset(
            data_path, sketch_path,
            *(self.data_config[key] for key in dataset_keys))
        print(self.dataset)

        # With masking enabled the generators take 6 input channels, else 3.
        in_ch = 6 if self.mask else 3

        local_cfg = model_config["local_enhancer"]
        self.loc_gen, self.loc_gen_opt = self._setting_model_optim(
            LocalEnhancer(in_ch=in_ch, num_layers=local_cfg["num_layers"]),
            local_cfg)

        self.glo_gen, self.glo_gen_opt = self._setting_model_optim(
            GlobalGenerator(in_ch=in_ch), model_config["global_generator"])

        dis_cfg = model_config["discriminator"]
        self.dis, self.dis_opt = self._setting_model_optim(
            Discriminator(dis_cfg["in_ch"], dis_cfg["multi"]), dis_cfg)

        self.vgg = Vgg19(requires_grad=False)
        self.vgg.cuda()
        self.vgg.eval()

        self.lossfunc = Pix2pixHDCalculator()
        self.visualizer = Visualizer(self.data_config["color_space"])
示例#6
0
    def __init__(self, device, lambda_perceptual=10.0):
        """Build a VGG19-based perceptual (feature-matching) L1 loss.

        Parameters
        ----------
        device : torch.device
            Device the VGG19 network is moved to (GPU or CPU).

        lambda_perceptual : float
            Intended weight of the perceptual loss.
            NOTE(review): this value is not stored by ``__init__``; confirm
            whether ``forward`` is meant to use it.

        """
        super().__init__()
        self.vgg = Vgg19().to(device)
        self.criterion = nn.L1Loss()
        # Per-feature weights: 1/32, 1/16, 1/8, 1/4 and 1 (exact binary
        # fractions), presumably one per feature map returned by Vgg19.
        self.weights = [0.03125, 0.0625, 0.125, 0.25, 1.0]
示例#7
0
    def __init__(self, config, outdir, modeldir, data_path, sketch_path,
                 dist_path, pretrained_path):
        """Assemble the trainer used with ``VideoColorizeLossCalculator``.

        Loads a ``ColorTransformNetwork`` from ``pretrained_path`` (moved to
        CUDA and put in eval mode), then builds a temporal constraint
        generator, a discriminator and a temporal discriminator (each with an
        optimizer via ``_setting_model_optim``) plus a VGG19 feature extractor.

        Parameters
        ----------
        config : dict
            Needs ``"train"``, ``"dataset"``, ``"model"`` and ``"loss"`` sections.
        outdir, modeldir : path-like
            Kept as ``self.outdir`` and ``self.modeldir``.
        data_path, sketch_path, dist_path : path-like
            Passed straight to ``IllustDataset``.
        pretrained_path : path-like
            State dict file for the colour transform network.
        """
        self.train_config = config["train"]
        self.data_config = config["dataset"]
        model_config = config["model"]
        self.loss_config = config["loss"]

        self.outdir = outdir
        self.modeldir = modeldir

        dataset_keys = ("anime_dir", "extension", "train_size", "valid_size",
                        "scale", "frame_range")
        self.dataset = IllustDataset(
            data_path, sketch_path, dist_path,
            *(self.data_config[key] for key in dataset_keys))
        print(self.dataset)

        self.ctn = ColorTransformNetwork(
            layers=model_config["CTN"]["num_layers"])
        self.ctn.cuda()
        self.ctn.eval()
        self.ctn.load_state_dict(torch.load(pretrained_path))

        self.gen, self.gen_opt = self._setting_model_optim(
            TemporalConstraintNetwork(), model_config["TCN"])
        self.dis, self.dis_opt = self._setting_model_optim(
            Discriminator(), model_config["discriminator"])
        self.t_dis, self.t_dis_opt = self._setting_model_optim(
            TemporalDiscriminator(), model_config["temporal_discriminator"])

        self.vgg = Vgg19(requires_grad=False)
        self.vgg.cuda()
        self.vgg.eval()

        self.lossfunc = VideoColorizeLossCalculator()
        self.visualizer = Visualizer()
示例#8
0
    def __init__(
        self,
        config,
        outdir,
        outdir_fix,
        modeldir,
        data_path,
        sketch_path,
    ):
        """Assemble the trainer used with ``SPADELossCalculator``.

        Builds the dataset, a latent-conditioned generator, a discriminator
        (each with an optimizer via ``_setting_model_optim``) and a VGG19
        feature extractor on CUDA.

        Parameters
        ----------
        config : dict
            Needs ``"train"`` (including ``"latent_dim"``), ``"dataset"``,
            ``"model"`` and ``"loss"`` sections.
        outdir, outdir_fix, modeldir : path-like
            Kept as ``self.outdir``, ``self.outdir_fix`` and ``self.modeldir``.
        data_path, sketch_path : path-like
            Passed straight to ``BuildDataset``.
        """
        self.train_config = config["train"]
        self.data_config = config["dataset"]
        model_config = config["model"]
        self.loss_config = config["loss"]

        self.outdir = outdir
        self.outdir_fix = outdir_fix
        self.modeldir = modeldir

        dataset_keys = ("line_method", "extension", "train_size",
                        "valid_size", "color_space", "line_space")
        self.dataset = BuildDataset(
            data_path, sketch_path,
            *(self.data_config[key] for key in dataset_keys))
        print(self.dataset)

        latent_dim = self.train_config["latent_dim"]

        gen_cfg = model_config["generator"]
        self.gen, self.gen_opt = self._setting_model_optim(
            Generator(gen_cfg["in_ch"], latent_dim), gen_cfg)

        dis_cfg = model_config["discriminator"]
        self.dis, self.dis_opt = self._setting_model_optim(
            Discriminator(multi_patterns=dis_cfg["multi"]), dis_cfg)

        self.vgg = Vgg19(requires_grad=False)
        self.vgg.cuda()
        self.vgg.eval()

        self.lossfunc = SPADELossCalculator()
        self.visualizer = Visualizer()
        self.l_dim = latent_dim
示例#9
0
    def __init__(
        self,
        config,
        outdir,
        modeldir,
        data_path,
        sketch_path,
    ):
        """Assemble the trainer used with ``SCFTLossCalculator``.

        Builds the dataset (with source/target perturbation settings), a
        generator and a discriminator (each with an optimizer via
        ``_setting_model_optim``) and a VGG19 feature extractor on CUDA.

        Parameters
        ----------
        config : dict
            Needs ``"train"``, ``"dataset"``, ``"model"`` and ``"loss"`` sections.
        outdir, modeldir : path-like
            Kept as ``self.outdir`` and ``self.modeldir``.
        data_path, sketch_path : path-like
            Passed straight to ``IllustDataset``.
        """
        self.train_config = config["train"]
        self.data_config = config["dataset"]
        model_config = config["model"]
        self.loss_config = config["loss"]

        self.outdir = outdir
        self.modeldir = modeldir

        dataset_keys = ("line_method", "extension", "train_size",
                        "valid_size", "color_space", "line_space",
                        "src_perturbation", "tgt_perturbation")
        self.dataset = IllustDataset(
            data_path, sketch_path,
            *(self.data_config[key] for key in dataset_keys))
        print(self.dataset)

        self.gen, self.gen_opt = self._setting_model_optim(
            Generator(), model_config["generator"])
        self.dis, self.dis_opt = self._setting_model_optim(
            Discriminator(), model_config["discriminator"])

        self.vgg = Vgg19(requires_grad=False)
        self.vgg.cuda()
        self.vgg.eval()

        self.lossfunc = SCFTLossCalculator()
        self.visualizer = Visualizer(self.data_config["color_space"])
示例#10
0
def _next_batch(iterator, loader):
    """Return ``(item, iterator)`` for the next item, restarting from *loader* when exhausted.

    The returned iterator must replace the caller's, because a fresh one is
    created after ``StopIteration``. Uses the ``next()`` builtin: recent
    PyTorch ``DataLoader`` iterators no longer provide a ``.next()`` method.
    """
    try:
        return next(iterator), iterator
    except StopIteration:
        iterator = iter(loader)
        return next(iterator), iterator


def main():
    """Train SRNet (generator ``G`` plus discriminators ``D1``/``D2``) from ``cfg``.

    Resumes from ``cfg.ckpt_path`` when it exists. Every step updates the
    discriminators; every 5th step also updates the generator. Logs every
    ``cfg.write_log_interval`` steps, writes example predictions every
    ``cfg.gen_example_interval`` steps and saves a checkpoint named
    ``train_step-N.model`` (state after N steps) every
    ``cfg.save_ckpt_interval`` steps.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = str(cfg.gpu)

    train_name = get_train_name()

    print_log('Initializing SRNET', content_color = PrintColor['yellow'])

    train_data = datagen_srnet(cfg)
    train_data = DataLoader(dataset = train_data, batch_size = cfg.batch_size, shuffle = False, collate_fn = custom_collate,  pin_memory = True)

    trfms = To_tensor()
    example_data = example_dataset(transform = trfms)
    example_loader = DataLoader(dataset = example_data, batch_size = 1, shuffle = False)

    print_log('training start.', content_color = PrintColor['yellow'])

    G = Generator(in_channels = 3).cuda()
    D1 = Discriminator(in_channels = 6).cuda()
    D2 = Discriminator(in_channels = 6).cuda()
    vgg_features = Vgg19().cuda()
    # VGG only supplies features to the generator loss and no optimizer owns
    # its weights; freezing them stops gradients piling up there. Gradients
    # still flow through VGG back to the generator output.
    requires_grad(vgg_features, False)

    G_solver = torch.optim.Adam(G.parameters(), lr=cfg.learning_rate, betas = (cfg.beta1, cfg.beta2))
    D1_solver = torch.optim.Adam(D1.parameters(), lr=cfg.learning_rate, betas = (cfg.beta1, cfg.beta2))
    D2_solver = torch.optim.Adam(D2.parameters(), lr=cfg.learning_rate, betas = (cfg.beta1, cfg.beta2))
    # NOTE: MultiStepLR schedulers (milestones=[30, 200], gamma=0.5) were
    # prototyped for all three optimizers but are disabled.

    # Only torch.load can raise FileNotFoundError; keep the try body to it.
    try:
        checkpoint = torch.load(cfg.ckpt_path)
    except FileNotFoundError:
        print('checkpoint not found')
    else:
        G.load_state_dict(checkpoint['generator'])
        D1.load_state_dict(checkpoint['discriminator1'])
        D2.load_state_dict(checkpoint['discriminator2'])
        G_solver.load_state_dict(checkpoint['g_optimizer'])
        D1_solver.load_state_dict(checkpoint['d1_optimizer'])
        D2_solver.load_state_dict(checkpoint['d2_optimizer'])
        print('Resuming after loading...')

    # Default state: discriminators trainable, generator frozen. The generator
    # is unfrozen only for its own update below.
    requires_grad(G, False)
    requires_grad(D1, True)
    requires_grad(D2, True)

    trainiter = iter(train_data)
    example_iter = iter(example_loader)

    # Pads generator outputs by one column on the right and one row on top.
    K = torch.nn.ZeroPad2d((0, 1, 1, 0))

    # Undefined until the first generator update (step 5); logging earlier
    # than that used to raise NameError.
    g_loss = None

    for step in tqdm(range(cfg.max_iter)):

        D1_solver.zero_grad()
        D2_solver.zero_grad()

        batch, trainiter = _next_batch(trainiter, train_data)
        i_t, i_s, t_sk, t_t, t_b, t_f, mask_t = (t.cuda() for t in batch)
        labels = [t_sk, t_t, t_b, t_f]

        # ---- Discriminator update (generator frozen) ----
        _, _, o_b, o_f = (K(o) for o in G(i_t, i_s))

        o_db_true = D1(torch.cat((t_b, i_s), dim = 1))
        o_db_pred = D1(torch.cat((o_b, i_s), dim = 1))
        o_df_true = D2(torch.cat((t_f, i_t), dim = 1))
        o_df_pred = D2(torch.cat((o_f, i_t), dim = 1))

        db_loss = build_discriminator_loss(o_db_true, o_db_pred)
        df_loss = build_discriminator_loss(o_df_true, o_df_pred)
        db_loss.backward()
        df_loss.backward()

        D1_solver.step()
        D2_solver.step()

        # NOTE(review): clip_grad runs after the optimizer step, so it cannot
        # influence this step's gradients; if it clamps weights (WGAN-style)
        # that ordering is intended -- confirm against its definition.
        clip_grad(D1)
        clip_grad(D2)

        # ---- Generator update every 5th step (discriminators frozen) ----
        if (step + 1) % 5 == 0:
            requires_grad(G, True)
            requires_grad(D1, False)
            requires_grad(D2, False)

            G_solver.zero_grad()

            o_sk, o_t, o_b, o_f = (K(o) for o in G(i_t, i_s))
            o_db_pred = D1(torch.cat((o_b, i_s), dim = 1))
            o_df_pred = D2(torch.cat((o_f, i_t), dim = 1))
            out_vgg = vgg_features(torch.cat((t_f, o_f), dim = 0))

            out_g = [o_sk, o_t, o_b, o_f, mask_t]
            out_d = [o_db_pred, o_df_pred]
            g_loss, _ = build_generator_loss(out_g, out_d, out_vgg, labels)
            g_loss.backward()
            G_solver.step()

            requires_grad(G, False)
            requires_grad(D1, True)
            requires_grad(D2, True)

        if (step + 1) % cfg.write_log_interval == 0:
            gen_value = 'n/a' if g_loss is None else g_loss.item()
            print('Iter: {}/{} | Gen: {} | D_bg: {} | D_fus: {}'.format(step+1, cfg.max_iter, gen_value, db_loss.item(), df_loss.item()))

        if (step + 1) % cfg.gen_example_interval == 0:
            savedir = os.path.join(cfg.example_result_dir, train_name, 'iter-' + str(step+1).zfill(len(str(cfg.max_iter))))

            with torch.no_grad():
                inp, example_iter = _next_batch(example_iter, example_loader)
                ex_t = inp[0].cuda()
                ex_s = inp[1].cuda()
                name = str(inp[2][0])

                o_sk, o_t, o_b, o_f = (o.squeeze(0).to('cpu') for o in G(ex_t, ex_s))

                os.makedirs(savedir, exist_ok = True)

                # o_t/o_b/o_f are mapped via (x + 1) / 2 before conversion,
                # presumably from [-1, 1] to [0, 1]; o_sk is converted as-is.
                o_sk = F.to_pil_image(o_sk)
                o_t = F.to_pil_image((o_t + 1)/2)
                o_b = F.to_pil_image((o_b + 1)/2)
                o_f = F.to_pil_image((o_f + 1)/2)

                o_f.save(os.path.join(savedir, name + 'o_f.png'))
                o_sk.save(os.path.join(savedir, name + 'o_sk.png'))
                o_t.save(os.path.join(savedir, name + 'o_t.png'))
                o_b.save(os.path.join(savedir, name + 'o_b.png'))

        # Saved after this step's updates, so ``train_step-N`` holds the state
        # after exactly N steps (it used to be written before step N ran).
        if (step + 1) % cfg.save_ckpt_interval == 0:
            torch.save(
                {
                    'generator': G.state_dict(),
                    'discriminator1': D1.state_dict(),
                    'discriminator2': D2.state_dict(),
                    'g_optimizer': G_solver.state_dict(),
                    'd1_optimizer': D1_solver.state_dict(),
                    'd2_optimizer': D2_solver.state_dict(),
                },
                cfg.checkpoint_savedir+f'train_step-{step+1}.model',
            )
示例#11
0
    def __init__(
            self,
            # Training
            batch_size: int = 1,
            loss: str = 'rec',
            # Model
            model: str = 'TTSR',
            model__args: T.List = [],
            model__kwargs: T.Dict = {},
            optimizer: str = 'Adam',
            lr: T.Union[float, T.Dict[str, float]] = 1e-3,
            optimizer__args: T.List = [],
            optimizer__kwargs: T.Dict = {
                'betas': [0.9, 0.99],
                'weight_decay': 0
            },
            lr_scheduler: T.Optional[str] = None,
            lr_scheduler__args: T.List = [],
            lr_scheduler__kwargs: T.Dict = {},
            # Discriminator
            use_discriminator: bool = False,
            discriminator: str = 'Discriminator',
            discriminator__args: T.List = [],
            discriminator__kwargs: T.Dict = {},
            discriminator__optimizer: str = 'Adam',
            discriminator__lr: float = 1e-3,
            discriminator__optimizer__args: T.List = [],
            discriminator__optimizer__kwargs: T.Dict = {
                'betas': [0., 0.9],
                'weight_decay': 0
            },
            discriminator__frequency: T.Optional[int] = None,
            discriminator__lr_scheduler: T.Optional[str] = None,
            discriminator__lr_scheduler__args: T.List = [],
            discriminator__lr_scheduler__kwargs: T.Dict = {},
            add_vgg19: bool = False,
            *args,
            **kwargs):
        """Configure the model, optimizer, LR scheduler and optional discriminator.

        Names ending in ``__args`` / ``__kwargs`` are forwarded positionally /
        by keyword to the object named by the matching prefix, which is
        resolved by ``_parse_model``, ``_parse_optimizer`` or
        ``_parse_lr_scheduler``. The ``discriminator*`` arguments are used only
        when ``use_discriminator`` is true. ``add_vgg19`` attaches a ``Vgg19``
        built with ``requires_grad=False`` as ``self.vgg19``. Extra ``*args``
        and ``**kwargs`` are kept as meta information on the instance.
        """
        super().__init__()

        import copy

        # The container defaults in the signature are single objects shared by
        # every call. Work on private copies so nothing built here (e.g. the
        # optimizer's ``betas`` list) aliases them. NOTE(review): assuming
        # ``save_hyperparameters`` is PyTorch Lightning's, it reads these
        # arguments from this frame, so it records the copies as well.
        model__args = list(model__args)
        model__kwargs = dict(model__kwargs)
        optimizer__args = list(optimizer__args)
        optimizer__kwargs = copy.deepcopy(optimizer__kwargs)
        lr_scheduler__args = list(lr_scheduler__args)
        lr_scheduler__kwargs = dict(lr_scheduler__kwargs)
        discriminator__args = list(discriminator__args)
        discriminator__kwargs = dict(discriminator__kwargs)
        discriminator__optimizer__args = list(discriminator__optimizer__args)
        discriminator__optimizer__kwargs = copy.deepcopy(
            discriminator__optimizer__kwargs)
        discriminator__lr_scheduler__args = list(
            discriminator__lr_scheduler__args)
        discriminator__lr_scheduler__kwargs = dict(
            discriminator__lr_scheduler__kwargs)

        self.batch_size = batch_size
        self.loss = self._parse_loss(loss)

        self.model = self._parse_model(model, *model__args, **model__kwargs)
        self.optimizer = self._parse_optimizer(optimizer, self.model, lr,
                                               *optimizer__args,
                                               **optimizer__kwargs)
        self.lr_scheduler = self._parse_lr_scheduler(lr_scheduler,
                                                     self.optimizer,
                                                     *lr_scheduler__args,
                                                     **lr_scheduler__kwargs)

        self.use_discriminator = use_discriminator
        if self.use_discriminator:
            self.discriminator = self._parse_model(discriminator,
                                                   *discriminator__args,
                                                   **discriminator__kwargs)
            self.discriminator__optimizer = self._parse_optimizer(
                discriminator__optimizer, self.discriminator,
                discriminator__lr, *discriminator__optimizer__args,
                **discriminator__optimizer__kwargs)
            self.discriminator__frequency = discriminator__frequency
            self.discriminator__lr_scheduler = self._parse_lr_scheduler(
                discriminator__lr_scheduler, self.discriminator__optimizer,
                *discriminator__lr_scheduler__args,
                **discriminator__lr_scheduler__kwargs)

        if add_vgg19:
            self.vgg19 = Vgg19.Vgg19(requires_grad=False)

        # Basically some meta info to save within logger
        self.args = args
        self.kwargs = kwargs

        self.save_hyperparameters()
示例#12
0
    config = args.parse_args()

    num_classes = config.num_classes
    base_lr = config.lr
    cuda = config.cuda
    num_epochs = config.num_epochs
    print_iter = config.print_iter
    model_name = config.model_name
    prediction_file = config.prediction_file
    test_file = config.test_file
    batch = config.batch
    mode = config.mode

    # create model
    model = Vgg19(num_classes=num_classes)

    if mode == 'test':
        load_model(model_name, model)

    if cuda:
        model = model.cuda()

    if mode == 'train':
        # define loss function
        loss_fn = nn.CrossEntropyLoss()
        if cuda:
            loss_fn = loss_fn.cuda()

        # set optimizer
        optimizer = Adam(
示例#13
0
def main():
    """Run a trained SRNet generator over example inputs and save its outputs.

    Command-line options (defaults come from ``cfg``):
      --input_dir   directory containing ``xxx_i_s`` / ``xxx_i_t`` pairs
      --save_dir    directory the result images are written to
      --checkpoint  checkpoint file holding a ``'generator'`` state dict
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', help = 'Directory containing xxx_i_s and xxx_i_t with same prefix',
                        default = cfg.example_data_dir)
    parser.add_argument('--save_dir', help = 'Directory to save result', default = cfg.predict_result_dir)
    parser.add_argument('--checkpoint', help = 'ckpt', default = cfg.ckpt_path)
    args = parser.parse_args()

    # ``assert`` is stripped under ``python -O``; report missing options via
    # argparse so the user gets a usage message and a proper exit status.
    for option in ('input_dir', 'save_dir', 'checkpoint'):
        if getattr(args, option) is None:
            parser.error(f'--{option} is required (no default configured in cfg)')

    print_log('model compiling start.', content_color = PrintColor['yellow'])

    # Only the generator is used for prediction; the discriminators, VGG and
    # optimizer states stored in the checkpoint are not needed here.
    G = Generator(in_channels = 3).to(device)
    # map_location lets a GPU-saved checkpoint load when ``device`` is the CPU.
    checkpoint = torch.load(args.checkpoint, map_location = device)
    G.load_state_dict(checkpoint['generator'])

    trfms = To_tensor()
    example_data = example_dataset(data_dir= args.input_dir, transform = trfms)
    example_loader = DataLoader(dataset = example_data, batch_size = 1, shuffle = False)

    print_log('Model compiled.', content_color = PrintColor['yellow'])

    print_log('Predicting', content_color = PrintColor['yellow'])

    G.eval()
    os.makedirs(args.save_dir, exist_ok = True)

    with torch.no_grad():
        # Iterate the loader directly: one pass covers every example
        # (batch_size=1), and DataLoader iterators in recent PyTorch no longer
        # provide the ``.next()`` method used previously.
        for inp in tqdm(example_loader):
            i_t = inp[0].to(device)
            i_s = inp[1].to(device)
            name = str(inp[2][0])

            o_sk, o_t, o_b, o_f = G(i_t, i_s, (i_t.shape[2], i_t.shape[3]))

            o_sk = o_sk.squeeze(0).to('cpu')
            o_t = o_t.squeeze(0).to('cpu')
            o_b = o_b.squeeze(0).to('cpu')
            o_f = o_f.squeeze(0).to('cpu')

            # NOTE(review): all four outputs are converted but only o_f is
            # saved here -- confirm whether o_sk/o_t/o_b should be written too.
            o_sk = F.to_pil_image(o_sk)
            o_t = F.to_pil_image((o_t + 1)/2)
            o_b = F.to_pil_image((o_b + 1)/2)
            o_f = F.to_pil_image((o_f + 1)/2)

            o_f.save(os.path.join(args.save_dir, name + 'o_f.png'))
示例#14
0
    num_classes = config.num_classes
    base_lr = config.lr
    cuda = config.cuda
    num_epochs = config.num_epochs
    print_iter = config.print_iter
    model_name = config.model_name
    prediction_file = config.prediction_file
    best_prediction_file = config.best_prediction_file  #DBY
    batch = config.batch
    mode = config.mode

    # create model
    model = SiameseNetwork()
    #model = SiameseEfficientNet()
    model = Vgg19()

    if mode == 'test':
        load_model(model_name, model)

    if cuda:
        model = model.cuda()

    # Define 'best loss' - DBY
    best_loss = 0.1
    last_loss = 0
    if mode == 'train':
        # define loss function
        # loss_fn = nn.CrossEntropyLoss()
        # if cuda:
        #     loss_fn = loss_fn.cuda()
示例#15
0
    def __init__(self,
                 config,
                 outdir,
                 modeldir,
                 data_path,
                 sketch_path,
                 flat_path,
                 pretrain_path=None):
        """Assemble the trainer used with ``DecomposeLossCalculator``.

        Always builds a flat generator and its discriminator. When
        ``train.train_type == "multi"`` the flat generator is initialised from
        ``pretrain_path``, the dataset is ``DanbooruFacesDataset``, and a
        BicycleGAN generator, latent encoder, bicycle discriminator and colour
        fixer are created too. Otherwise the dataset is ``IllustDataset``.

        Parameters
        ----------
        config : dict
            Needs ``"train"`` (including ``"train_type"``), ``"dataset"``,
            ``"model"`` and ``"loss"`` sections.
        outdir, modeldir : path-like
            Kept as ``self.outdir`` and ``self.modeldir``.
        data_path, sketch_path : path-like
            Passed to the dataset.
        flat_path : path-like
            Passed to ``IllustDataset``; unused in ``"multi"`` mode.
        pretrain_path : path-like, optional
            Flat-generator state dict; required when ``train_type == "multi"``.

        Raises
        ------
        ValueError
            If ``train_type == "multi"`` and ``pretrain_path`` is None.
        """
        self.train_config = config["train"]
        self.data_config = config["dataset"]
        model_config = config["model"]
        self.loss_config = config["loss"]

        self.outdir = outdir
        self.modeldir = modeldir

        self.train_type = self.train_config["train_type"]
        is_multi = self.train_type == "multi"

        # Fail fast: otherwise torch.load(None) only fails after the dataset
        # and flat generator have already been built.
        if is_multi and pretrain_path is None:
            raise ValueError('pretrain_path is required when train_type is "multi"')

        # Options shared by both dataset classes, in their positional order.
        dataset_options = [self.data_config[key] for key in (
            "line_method", "extension", "train_size", "valid_size",
            "color_space", "line_space")]
        if is_multi:
            self.dataset = DanbooruFacesDataset(data_path, sketch_path,
                                                *dataset_options)
        else:
            self.dataset = IllustDataset(data_path, sketch_path, flat_path,
                                         *dataset_options)
        print(self.dataset)

        flat_cfg = model_config["flat_generator"]
        flat_gen = Generator(flat_cfg["in_ch"],
                             num_layers=flat_cfg["num_layers"],
                             attn_type=flat_cfg["attn_type"])
        self.flat_gen, self.flat_gen_opt = self._setting_model_optim(flat_gen,
                                                                     flat_cfg)

        if is_multi:
            self.flat_gen.load_state_dict(torch.load(pretrain_path))

        f_dis_cfg = model_config["flat_dis"]
        f_dis = Discriminator(f_dis_cfg["in_ch"], f_dis_cfg["multi"])
        self.f_dis, self.f_dis_opt = self._setting_model_optim(f_dis,
                                                               f_dis_cfg)

        if is_multi:
            bicycle_cfg = model_config["bicycle_gan"]
            bicycle_gen = BicycleGAN(bicycle_cfg["in_ch"],
                                     latent_dim=bicycle_cfg["l_dim"],
                                     num_layers=bicycle_cfg["num_layers"])
            self.b_gen, self.b_gen_opt = self._setting_model_optim(bicycle_gen,
                                                                   bicycle_cfg)

            enc_cfg = model_config["encoder"]
            latent_enc = LatentEncoder(enc_cfg["in_ch"],
                                       latent_dim=enc_cfg["l_dim"])
            self.l_enc, self.l_enc_opt = self._setting_model_optim(latent_enc,
                                                                   enc_cfg)

            b_dis_cfg = model_config["bicycle_dis"]
            b_dis = Discriminator(b_dis_cfg["in_ch"], b_dis_cfg["multi"])
            self.b_dis, self.b_dis_opt = self._setting_model_optim(b_dis,
                                                                   b_dis_cfg)

            fixer = ColorFixer()
            self.fix, self.fix_opt = self._setting_model_optim(fixer,
                                                               model_config["fixer"])

        self.vgg = Vgg19(requires_grad=False)
        self.vgg.cuda()
        self.vgg.eval()

        self.out_filter = GuidedFilter(r=1, eps=1e-2)
        self.out_filter.cuda()

        self.lossfunc = DecomposeLossCalculator()
        self.visualizer = Visualizer(self.data_config["color_space"])
示例#16
0
def train(epochs, interval, batchsize, validsize, data_path, sketch_path,
          extension, img_size, outdir, modeldir, gen_learning_rate,
          dis_learning_rate, beta1, beta2):
    """Adversarially train a Style2Paint generator against a Discriminator.

    Each iteration performs one discriminator step followed by one generator
    step. The generator loss is the adversarial loss plus ``10.0 *`` the
    content loss plus the VGG19-based style/perceptual loss. On iteration 1
    and every ``interval`` iterations after it, the generator's state dict is
    saved and validation outputs are sent to the visualizer.

    Args:
        epochs: Number of passes over the dataset.
        interval: Checkpoint/visualization period, in iterations (must be > 0).
        batchsize: Mini-batch size; the final incomplete batch is dropped.
        validsize: Passed to ``dataset.valid`` and to the visualizer.
        data_path: Forwarded to ``IllustDataset``.
        sketch_path: Forwarded to ``IllustDataset``.
        extension: Forwarded to ``IllustDataset``.
        img_size: Forwarded to ``LineCollator``.
        outdir: Output directory handed to the visualizer.
        modeldir: Directory where ``model_{iteration}.pt`` files are written.
        gen_learning_rate: Adam learning rate for the generator.
        dis_learning_rate: Adam learning rate for the discriminator.
        beta1: First Adam beta, shared by both optimizers.
        beta2: Second Adam beta, shared by both optimizers.
    """
    # Dataset Definition
    dataset = IllustDataset(data_path, sketch_path, extension)
    c_valid, l_valid = dataset.valid(validsize)
    print(dataset)
    collator = LineCollator(img_size)

    # Model & Optimizer Definition
    model = Style2Paint()
    model.cuda()
    model.train()
    gen_opt = torch.optim.Adam(model.parameters(),
                               lr=gen_learning_rate,
                               betas=(beta1, beta2))

    discriminator = Discriminator()
    discriminator.cuda()
    discriminator.train()
    dis_opt = torch.optim.Adam(discriminator.parameters(),
                               lr=dis_learning_rate,
                               betas=(beta1, beta2))

    # Built with requires_grad=False and kept in eval mode; it is only
    # consumed by the style/perceptual loss below.
    vgg = Vgg19(requires_grad=False)
    vgg.cuda()
    vgg.eval()

    # Loss function definition
    lossfunc = Style2paintsLossCalculator()

    # Visualizer definition
    visualizer = Visualizer()

    # One DataLoader is enough: with shuffle=True it draws a fresh
    # permutation every time it is iterated, i.e. once per epoch.
    dataloader = DataLoader(dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            collate_fn=collator,
                            drop_last=True)

    iteration = 0

    for _ in range(epochs):
        progress_bar = tqdm(dataloader)

        for jit, war, line in progress_bar:
            iteration += 1

            # Discriminator update. The generated batch is treated as a
            # constant here, so skip building the generator's autograd graph.
            with torch.no_grad():
                fake = model(line, war)
            loss = lossfunc.adversarial_disloss(discriminator, fake, jit)

            dis_opt.zero_grad()
            loss.backward()
            dis_opt.step()

            # Generator update (fresh forward pass, gradients enabled).
            y = model(line, war)
            loss = lossfunc.adversarial_genloss(discriminator, y)
            loss += 10.0 * lossfunc.content_loss(y, jit)
            loss += lossfunc.style_and_perceptual_loss(vgg, y, jit)

            gen_opt.zero_grad()
            loss.backward()
            gen_opt.step()

            # Fires on iterations 1, 1 + interval, 1 + 2*interval, ...
            # (The former `iteration % interval == 1` matched this for
            # interval > 1 but never fired when interval == 1.)
            if (iteration - 1) % interval == 0:
                torch.save(model.state_dict(),
                           f"{modeldir}/model_{iteration}.pt")

                # NOTE(review): the model stays in train() mode here, so any
                # BatchNorm/Dropout layers behave as in training (and BN
                # running stats see validation data). Confirm this is intended.
                with torch.no_grad():
                    y = model(l_valid, c_valid)

                c = c_valid.detach().cpu().numpy()
                l = l_valid.detach().cpu().numpy()
                y = y.detach().cpu().numpy()

                visualizer(l, c, y, outdir, iteration, validsize)

            print(f"iteration: {iteration} Loss: {loss.item()}")