def __init__(self, args, logger, dataloader, model, loss_all):
    """Set up training state: device, frozen perceptual VGG19, per-submodule
    optimizer parameter groups, LR scheduler, and best-metric bookkeeping.

    Parameters
    ----------
    args : argparse.Namespace
        Run configuration (cpu, num_gpu, lr_rate, lr_rate_lte, beta1, beta2,
        eps, decay, gamma, ...).
    logger : object
        Logging helper.
    dataloader : object
        Training data loader.
    model : nn.Module
        Network exposing `MainNet` and `LTE` submodules (under `.module`
        when wrapped in DataParallel).
    loss_all : object
        Collection of loss functions used during training.
    """
    self.args = args
    self.logger = logger
    self.dataloader = dataloader
    self.model = model
    self.loss_all = loss_all
    # Use the GPU unless CPU mode was explicitly requested.
    self.device = torch.device('cpu') if args.cpu else torch.device('cuda')
    # Frozen VGG19 used as a fixed feature extractor (perceptual loss backbone).
    self.vgg19 = Vgg19.Vgg19(requires_grad=False).to(self.device)
    if ((not self.args.cpu) and (self.args.num_gpu > 1)):
        self.vgg19 = nn.DataParallel(self.vgg19, list(range(self.args.num_gpu)))
    # Two parameter groups with separate learning rates: the main network and
    # the LTE submodule.  With DataParallel (num_gpu > 1) the submodules live
    # under `model.module`.
    self.params = [
        {"params": filter(lambda p: p.requires_grad,
                          self.model.MainNet.parameters() if args.num_gpu == 1
                          else self.model.module.MainNet.parameters()),
         "lr": args.lr_rate},
        {"params": filter(lambda p: p.requires_grad,
                          self.model.LTE.parameters() if args.num_gpu == 1
                          else self.model.module.LTE.parameters()),
         "lr": args.lr_rate_lte}
    ]
    self.optimizer = optim.Adam(self.params, betas=(args.beta1, args.beta2), eps=args.eps)
    # Step decay: multiply LR by `gamma` every `decay` scheduler steps.
    self.scheduler = optim.lr_scheduler.StepLR(
        self.optimizer, step_size=self.args.decay, gamma=self.args.gamma)
    # Track best validation metrics (and the epochs they occurred) for
    # checkpoint selection.
    self.max_psnr = 0.
    self.max_psnr_epoch = 0
    self.max_ssim = 0.
    self.max_ssim_epoch = 0
def train(config):
    """Run Gatys-style neural style transfer: optimize an image to match the
    content features of one image and the Gram-matrix style statistics of
    another.

    Parameters
    ----------
    config : dict
        Expects keys: 'content_img_path', 'style_img_path', 'height',
        'content_weight', 'style_weight', 'tv_weight', 'iteration',
        'output_img_path', 'saving_freq', 'img_format'.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Vgg19().to(device).eval()

    content_image = prepare_img(config['content_img_path'], config['height'], device)
    style_image = prepare_img(config['style_img_path'], config['height'], device)

    # Fixed optimization targets: content features and style Gram matrices.
    content_targets, _ = model(content_image)
    _, style_targets = model(style_image)
    style_targets = [GramMatrix(s) for s in style_targets]

    # Bug fix: the original wrapped `content_image` itself in the deprecated
    # `torch.autograd.Variable`, so the optimizer updated the content image's
    # own storage in place.  Optimize a detached copy instead.
    optimizing_img = content_image.clone().detach().requires_grad_(True)

    # optimizer = torch.optim.LBFGS((optimizing_img, ), max_iter=1000, line_search_fn='strong_wolfe')
    optimizer = torch.optim.Adam((optimizing_img, ), lr=1e1)
    tuning_step = make_tuning_step(content_targets, style_targets, model,
                                   optimizer, config['content_weight'],
                                   config['style_weight'], config['tv_weight'])

    for cnt in range(config['iteration']):
        total_loss, content_loss, style_loss, tv_loss = tuning_step(
            optimizing_img)
        print(
            f'Epochs: {cnt} | total_loss: {total_loss.item():12.4f} | content_loss: {content_loss.item() * config["content_weight"]:12.4f} | style_loss: {style_loss.item() * config["style_weight"]:12.4f}| tv_loss: {tv_loss.item() * config["tv_weight"]:12.4f}'
        )
        save_and_maybe_display(cnt, optimizing_img, config['output_img_path'],
                               config['saving_freq'], config['iteration'],
                               config['img_format'], should_display=True)
def __init__(self, config, outdir, modeldir, data_path, sketch_path, ss_path):
    """Set up dataset, generator, three discriminators, guided filters and
    loss machinery for white-box-style illustration colorization training.

    Parameters
    ----------
    config : dict
        Nested configuration with "train", "dataset", "model", "loss" sections.
    outdir : path-like
        Directory for visualization outputs.
    modeldir : path-like
        Directory for model checkpoints.
    data_path, sketch_path, ss_path : path-like
        Dataset locations (color images, line art, and a third source —
        presumably screentone/semantic data; confirm against IllustDataset).
    """
    self.train_config = config["train"]
    self.data_config = config["dataset"]
    model_config = config["model"]
    self.loss_config = config["loss"]
    self.outdir = outdir
    self.modeldir = modeldir
    self.dataset = IllustDataset(
        data_path, sketch_path, ss_path,
        self.data_config["line_method"],
        self.data_config["extension"],
        self.data_config["train_size"],
        self.data_config["valid_size"],
        self.data_config["color_space"],
        self.data_config["line_space"])
    print(self.dataset)
    gen = Generator(model_config["generator"]["in_ch"],
                    num_layers=model_config["generator"]["num_layers"],
                    attn_type=model_config["generator"]["attn_type"],
                    guide=model_config["generator"]["guide"])
    self.gen, self.gen_opt = self._setting_model_optim(
        gen, model_config["generator"])
    self.guide = model_config["generator"]["guide"]
    # Three discriminators, presumably one per representation (image /
    # surface / texture) as in white-box cartoonization — TODO confirm.
    i_dis = Discriminator(model_config["image_dis"]["in_ch"],
                          model_config["image_dis"]["multi"])
    self.i_dis, self.i_dis_opt = self._setting_model_optim(
        i_dis, model_config["image_dis"])
    s_dis = Discriminator(model_config["surface_dis"]["in_ch"],
                          model_config["surface_dis"]["multi"])
    self.s_dis, self.s_dis_opt = self._setting_model_optim(
        s_dis, model_config["surface_dis"])
    t_dis = Discriminator(model_config["texture_dis"]["in_ch"],
                          model_config["texture_dis"]["multi"])
    self.t_dis, self.t_dis_opt = self._setting_model_optim(
        t_dis, model_config["texture_dis"])
    # Two guided filters with different radius/eps: a strong one (r=5) and a
    # light output filter (r=1).
    self.guided_filter = GuidedFilter(r=5, eps=2e-1)
    self.guided_filter.cuda()
    self.out_guided_filter = GuidedFilter(r=1, eps=1e-2)
    self.out_guided_filter.cuda()
    # Frozen VGG19 feature extractor for perceptual losses.
    self.vgg = Vgg19(requires_grad=False)
    self.vgg.cuda()
    self.vgg.eval()
    self.lossfunc = WhiteBoxLossCalculator()
    self.visualizer = Visualizer(self.data_config["color_space"])
def __init__(self, config, outdir, modeldir, data_path, sketch_path, ss_path):
    """Set up dataset, generator/discriminator pair, frozen VGG, output
    guided filter and exponential LR schedulers for colorization training.

    Parameters
    ----------
    config : dict
        Nested configuration with "train", "dataset", "model", "loss" sections.
    outdir : path-like
        Directory for visualization outputs.
    modeldir : path-like
        Directory for model checkpoints.
    data_path, sketch_path, ss_path : path-like
        Dataset locations passed through to IllustDataset.
    """
    self.train_config = config["train"]
    self.data_config = config["dataset"]
    model_config = config["model"]
    self.loss_config = config["loss"]
    self.outdir = outdir
    self.modeldir = modeldir
    self.dataset = IllustDataset(
        data_path, sketch_path, ss_path,
        self.data_config["line_method"],
        self.data_config["extension"],
        self.data_config["train_size"],
        self.data_config["valid_size"],
        self.data_config["color_space"],
        self.data_config["line_space"])
    print(self.dataset)
    gen = Generator(model_config["generator"]["in_ch"],
                    base=model_config["generator"]["base"],
                    num_layers=model_config["generator"]["num_layers"],
                    up_layers=model_config["generator"]["up_layers"],
                    guide=model_config["generator"]["guide"],
                    resnext=model_config["generator"]["resnext"],
                    encoder_type=model_config["generator"]["encoder_type"])
    self.gen, self.gen_opt = self._setting_model_optim(
        gen, model_config["generator"])
    self.guide = model_config["generator"]["guide"]
    dis = Discriminator(model_config["discriminator"]["in_ch"],
                        model_config["discriminator"]["multi"],
                        base=model_config["discriminator"]["base"],
                        sn=model_config["discriminator"]["sn"],
                        resnext=model_config["discriminator"]["resnext"],
                        patch=model_config["discriminator"]["patch"])
    self.dis, self.dis_opt = self._setting_model_optim(
        dis, model_config["discriminator"])
    # Frozen VGG19 feature extractor for perceptual loss; layer="four"
    # presumably selects which feature level is returned — TODO confirm.
    self.vgg = Vgg19(requires_grad=False, layer="four")
    self.vgg.cuda()
    self.vgg.eval()
    # Light guided filter applied to network outputs.
    self.out_filter = GuidedFilter(r=1, eps=1e-2)
    self.out_filter.cuda()
    self.lossfunc = LossCalculator()
    self.visualizer = Visualizer(self.data_config["color_space"])
    # Exponential LR decay for both generator and discriminator.
    self.scheduler_gen = torch.optim.lr_scheduler.ExponentialLR(
        self.gen_opt, self.train_config["gamma"])
    self.scheduler_dis = torch.optim.lr_scheduler.ExponentialLR(
        self.dis_opt, self.train_config["gamma"])
def __init__(
        self,
        config,
        outdir,
        modeldir,
        data_path,
        sketch_path,
):
    """Set up dataset, coarse-to-fine generator pair (pix2pixHD-style local
    enhancer + global generator), discriminator, and loss machinery.

    Parameters
    ----------
    config : dict
        Nested configuration with "train", "dataset", "model", "loss" sections.
    outdir : path-like
        Directory for visualization outputs.
    modeldir : path-like
        Directory for model checkpoints.
    data_path, sketch_path : path-like
        Dataset locations passed through to IllustDataset.
    """
    self.train_config = config["train"]
    self.data_config = config["dataset"]
    model_config = config["model"]
    self.loss_config = config["loss"]
    self.outdir = outdir
    self.modeldir = modeldir
    # Whether an extra 3-channel input (presumably a mask image) is
    # concatenated with the base input — TODO confirm against the dataset.
    self.mask = self.train_config["mask"]
    self.dataset = IllustDataset(data_path, sketch_path,
                                 self.data_config["extension"],
                                 self.data_config["train_size"],
                                 self.data_config["valid_size"],
                                 self.data_config["color_space"],
                                 self.data_config["line_space"])
    print(self.dataset)
    # 6 input channels when masked, otherwise plain 3-channel input.
    if self.mask:
        in_ch = 6
    else:
        in_ch = 3
    loc_gen = LocalEnhancer(
        in_ch=in_ch,
        num_layers=model_config["local_enhancer"]["num_layers"])
    self.loc_gen, self.loc_gen_opt = self._setting_model_optim(
        loc_gen, model_config["local_enhancer"])
    glo_gen = GlobalGenerator(in_ch=in_ch)
    self.glo_gen, self.glo_gen_opt = self._setting_model_optim(
        glo_gen, model_config["global_generator"])
    dis = Discriminator(model_config["discriminator"]["in_ch"],
                        model_config["discriminator"]["multi"])
    self.dis, self.dis_opt = self._setting_model_optim(
        dis, model_config["discriminator"])
    # Frozen VGG19 feature extractor for perceptual losses.
    self.vgg = Vgg19(requires_grad=False)
    self.vgg.cuda()
    self.vgg.eval()
    self.lossfunc = Pix2pixHDCalculator()
    self.visualizer = Visualizer(self.data_config["color_space"])
def __init__(self, device, lambda_perceptual=10.0):
    """Perceptual Loss using VGG19.

    Parameters
    ----------
    device : torch.device
        gpu or cpu.
    lambda_perceptual : float
        weight of perceptual loss.
    """
    super(PerceptualLoss, self).__init__()
    # Bug fix: `lambda_perceptual` was accepted and documented but never
    # stored, so the configured weight could not take effect anywhere.
    self.lambda_perceptual = lambda_perceptual
    self.vgg = Vgg19().to(device)
    self.criterion = nn.L1Loss()
    # Per-layer weights: deeper VGG feature maps contribute more.
    self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
def __init__(self, config, outdir, modeldir, data_path, sketch_path,
             dist_path, pretrained_path):
    """Set up video colorization training: a pretrained color transform
    network (kept in eval mode), a trainable temporal constraint network,
    per-frame and temporal discriminators, and loss machinery.

    Parameters
    ----------
    config : dict
        Nested configuration with "train", "dataset", "model", "loss" sections.
    outdir : path-like
        Directory for visualization outputs.
    modeldir : path-like
        Directory for model checkpoints.
    data_path, sketch_path, dist_path : path-like
        Dataset locations passed through to IllustDataset.
    pretrained_path : path-like
        Checkpoint file with pretrained ColorTransformNetwork weights.
    """
    self.train_config = config["train"]
    self.data_config = config["dataset"]
    model_config = config["model"]
    self.loss_config = config["loss"]
    self.outdir = outdir
    self.modeldir = modeldir
    self.dataset = IllustDataset(
        data_path, sketch_path, dist_path,
        self.data_config["anime_dir"],
        self.data_config["extension"],
        self.data_config["train_size"],
        self.data_config["valid_size"],
        self.data_config["scale"],
        self.data_config["frame_range"])
    print(self.dataset)
    # CTN is loaded from a pretrained checkpoint and kept in eval mode;
    # only the TCN and discriminators get optimizers here.
    self.ctn = ColorTransformNetwork(
        layers=model_config["CTN"]["num_layers"])
    self.ctn.cuda()
    self.ctn.eval()
    weight = torch.load(pretrained_path)
    self.ctn.load_state_dict(weight)
    gen = TemporalConstraintNetwork()
    self.gen, self.gen_opt = self._setting_model_optim(
        gen, model_config["TCN"])
    # Per-frame discriminator plus a temporal discriminator (presumably over
    # frame sequences — TODO confirm).
    dis = Discriminator()
    self.dis, self.dis_opt = self._setting_model_optim(
        dis, model_config["discriminator"])
    t_dis = TemporalDiscriminator()
    self.t_dis, self.t_dis_opt = self._setting_model_optim(
        t_dis, model_config["temporal_discriminator"])
    # Frozen VGG19 feature extractor for perceptual losses.
    self.vgg = Vgg19(requires_grad=False)
    self.vgg.cuda()
    self.vgg.eval()
    self.lossfunc = VideoColorizeLossCalculator()
    self.visualizer = Visualizer()
def __init__(
        self,
        config,
        outdir,
        outdir_fix,
        modeldir,
        data_path,
        sketch_path,
):
    """Set up dataset, latent-conditioned generator, multi-pattern
    discriminator, frozen VGG and SPADE loss machinery.

    Parameters
    ----------
    config : dict
        Nested configuration with "train", "dataset", "model", "loss" sections.
    outdir : path-like
        Directory for visualization outputs.
    outdir_fix : path-like
        Second output directory (fixed-latent samples, presumably — confirm).
    modeldir : path-like
        Directory for model checkpoints.
    data_path, sketch_path : path-like
        Dataset locations passed through to BuildDataset.
    """
    self.train_config = config["train"]
    self.data_config = config["dataset"]
    model_config = config["model"]
    self.loss_config = config["loss"]
    self.outdir = outdir
    self.outdir_fix = outdir_fix
    self.modeldir = modeldir
    self.dataset = BuildDataset(
        data_path, sketch_path,
        self.data_config["line_method"],
        self.data_config["extension"],
        self.data_config["train_size"],
        self.data_config["valid_size"],
        self.data_config["color_space"],
        self.data_config["line_space"])
    print(self.dataset)
    # Generator is conditioned on a latent code of size `latent_dim`.
    gen = Generator(model_config["generator"]["in_ch"],
                    self.train_config["latent_dim"])
    self.gen, self.gen_opt = self._setting_model_optim(
        gen, model_config["generator"])
    dis = Discriminator(
        multi_patterns=model_config["discriminator"]["multi"])
    self.dis, self.dis_opt = self._setting_model_optim(
        dis, model_config["discriminator"])
    # Frozen VGG19 feature extractor for perceptual losses.
    self.vgg = Vgg19(requires_grad=False)
    self.vgg.cuda()
    self.vgg.eval()
    self.lossfunc = SPADELossCalculator()
    self.visualizer = Visualizer()
    # Latent dimension cached for sampling during training.
    self.l_dim = self.train_config["latent_dim"]
def __init__(
        self,
        config,
        outdir,
        modeldir,
        data_path,
        sketch_path,
):
    """Set up dataset (with source/target perturbation options), generator,
    discriminator, frozen VGG and SCFT loss machinery.

    Parameters
    ----------
    config : dict
        Nested configuration with "train", "dataset", "model", "loss" sections.
    outdir : path-like
        Directory for visualization outputs.
    modeldir : path-like
        Directory for model checkpoints.
    data_path, sketch_path : path-like
        Dataset locations passed through to IllustDataset.
    """
    self.train_config = config["train"]
    self.data_config = config["dataset"]
    model_config = config["model"]
    self.loss_config = config["loss"]
    self.outdir = outdir
    self.modeldir = modeldir
    # src/tgt perturbation settings presumably control reference-image
    # augmentation (as in SCFT training) — TODO confirm in IllustDataset.
    self.dataset = IllustDataset(
        data_path, sketch_path,
        self.data_config["line_method"],
        self.data_config["extension"],
        self.data_config["train_size"],
        self.data_config["valid_size"],
        self.data_config["color_space"],
        self.data_config["line_space"],
        self.data_config["src_perturbation"],
        self.data_config["tgt_perturbation"])
    print(self.dataset)
    gen = Generator()
    self.gen, self.gen_opt = self._setting_model_optim(
        gen, model_config["generator"])
    dis = Discriminator()
    self.dis, self.dis_opt = self._setting_model_optim(
        dis, model_config["discriminator"])
    # Frozen VGG19 feature extractor for perceptual losses.
    self.vgg = Vgg19(requires_grad=False)
    self.vgg.cuda()
    self.vgg.eval()
    self.lossfunc = SCFTLossCalculator()
    self.visualizer = Visualizer(self.data_config["color_space"])
def main():
    """SRNet training loop.

    Alternates a discriminator update every step with a generator update
    every 5th step, and periodically checkpoints, logs losses and renders
    example predictions.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = str(cfg.gpu)
    train_name = get_train_name()
    print_log('Initializing SRNET', content_color = PrintColor['yellow'])

    train_data = datagen_srnet(cfg)
    train_data = DataLoader(dataset = train_data,
                            batch_size = cfg.batch_size,
                            shuffle = False,
                            collate_fn = custom_collate,
                            pin_memory = True)
    trfms = To_tensor()
    example_data = example_dataset(transform = trfms)
    example_loader = DataLoader(dataset = example_data, batch_size = 1, shuffle = False)

    print_log('training start.', content_color = PrintColor['yellow'])

    G = Generator(in_channels = 3).cuda()
    D1 = Discriminator(in_channels = 6).cuda()
    D2 = Discriminator(in_channels = 6).cuda()
    vgg_features = Vgg19().cuda()

    G_solver = torch.optim.Adam(G.parameters(), lr=cfg.learning_rate, betas = (cfg.beta1, cfg.beta2))
    D1_solver = torch.optim.Adam(D1.parameters(), lr=cfg.learning_rate, betas = (cfg.beta1, cfg.beta2))
    D2_solver = torch.optim.Adam(D2.parameters(), lr=cfg.learning_rate, betas = (cfg.beta1, cfg.beta2))

    # Best-effort resume: missing checkpoint just starts training from scratch.
    try:
        checkpoint = torch.load(cfg.ckpt_path)
        G.load_state_dict(checkpoint['generator'])
        D1.load_state_dict(checkpoint['discriminator1'])
        D2.load_state_dict(checkpoint['discriminator2'])
        G_solver.load_state_dict(checkpoint['g_optimizer'])
        D1_solver.load_state_dict(checkpoint['d1_optimizer'])
        D2_solver.load_state_dict(checkpoint['d2_optimizer'])
        print('Resuming after loading...')
    except FileNotFoundError:
        print('checkpoint not found')

    # Discriminators train every step; the generator only every 5th step.
    requires_grad(G, False)
    requires_grad(D1, True)
    requires_grad(D2, True)

    trainiter = iter(train_data)
    example_iter = iter(example_loader)

    # Zero-pad (left=0, right=1, top=1, bottom=0) so generator outputs line up
    # with the target spatial size — assumes G outputs are one pixel short on
    # those sides; TODO confirm against the generator architecture.
    K = torch.nn.ZeroPad2d((0, 1, 1, 0))

    for step in tqdm(range(cfg.max_iter)):
        D1_solver.zero_grad()
        D2_solver.zero_grad()

        if (step + 1) % cfg.save_ckpt_interval == 0:
            torch.save(
                {
                    'generator': G.state_dict(),
                    'discriminator1': D1.state_dict(),
                    'discriminator2': D2.state_dict(),
                    'g_optimizer': G_solver.state_dict(),
                    'd1_optimizer': D1_solver.state_dict(),
                    'd2_optimizer': D2_solver.state_dict(),
                },
                cfg.checkpoint_savedir + f'train_step-{step+1}.model',
            )

        # Cycle the dataloader indefinitely.  Bug fix: Python 3 iterators have
        # no `.next()` method; use the builtin next().
        try:
            i_t, i_s, t_sk, t_t, t_b, t_f, mask_t = next(trainiter)
        except StopIteration:
            trainiter = iter(train_data)
            i_t, i_s, t_sk, t_t, t_b, t_f, mask_t = next(trainiter)

        i_t = i_t.cuda()
        i_s = i_s.cuda()
        t_sk = t_sk.cuda()
        t_t = t_t.cuda()
        t_b = t_b.cuda()
        t_f = t_f.cuda()
        mask_t = mask_t.cuda()

        labels = [t_sk, t_t, t_b, t_f]

        # ----- Discriminator update -----
        o_sk, o_t, o_b, o_f = G(i_t, i_s)
        o_sk = K(o_sk)
        o_t = K(o_t)
        o_b = K(o_b)
        o_f = K(o_f)

        # Conditional discriminator inputs: (background, source) for D1 and
        # (fused output, target text) for D2.
        i_db_true = torch.cat((t_b, i_s), dim = 1)
        i_db_pred = torch.cat((o_b, i_s), dim = 1)
        i_df_true = torch.cat((t_f, i_t), dim = 1)
        i_df_pred = torch.cat((o_f, i_t), dim = 1)

        o_db_true = D1(i_db_true)
        o_db_pred = D1(i_db_pred)
        o_df_true = D2(i_df_true)
        o_df_pred = D2(i_df_pred)

        db_loss = build_discriminator_loss(o_db_true, o_db_pred)
        df_loss = build_discriminator_loss(o_df_true, o_df_pred)

        db_loss.backward()
        df_loss.backward()
        # Bug fix: gradients must be clipped BEFORE the optimizer step; the
        # original clipped after stepping, which had no effect on the update.
        clip_grad(D1)
        clip_grad(D2)
        D1_solver.step()
        D2_solver.step()

        # ----- Generator update (every 5th step) -----
        if (step + 1) % 5 == 0:
            requires_grad(G, True)
            requires_grad(D1, False)
            requires_grad(D2, False)

            G_solver.zero_grad()

            o_sk, o_t, o_b, o_f = G(i_t, i_s)
            o_sk = K(o_sk)
            o_t = K(o_t)
            o_b = K(o_b)
            o_f = K(o_f)

            i_db_pred = torch.cat((o_b, i_s), dim = 1)
            i_df_pred = torch.cat((o_f, i_t), dim = 1)

            o_db_pred = D1(i_db_pred)
            o_df_pred = D2(i_df_pred)

            # VGG features of target and generated fused images, stacked on
            # the batch dimension for the perceptual terms of the G loss.
            # (The original also ran this forward pass, unused, in the
            # discriminator phase; that dead work is removed.)
            i_vgg = torch.cat((t_f, o_f), dim = 0)
            out_vgg = vgg_features(i_vgg)

            out_g = [o_sk, o_t, o_b, o_f, mask_t]
            out_d = [o_db_pred, o_df_pred]

            g_loss, detail = build_generator_loss(out_g, out_d, out_vgg, labels)
            g_loss.backward()
            G_solver.step()

            requires_grad(G, False)
            requires_grad(D1, True)
            requires_grad(D2, True)

        if (step + 1) % cfg.write_log_interval == 0:
            print('Iter: {}/{} | Gen: {} | D_bg: {} | D_fus: {}'.format(step+1, cfg.max_iter, g_loss.item(), db_loss.item(), df_loss.item()))

        if (step + 1) % cfg.gen_example_interval == 0:
            savedir = os.path.join(cfg.example_result_dir, train_name,
                                   'iter-' + str(step+1).zfill(len(str(cfg.max_iter))))
            with torch.no_grad():
                try:
                    inp = next(example_iter)
                except StopIteration:
                    example_iter = iter(example_loader)
                    inp = next(example_iter)

                i_t = inp[0].cuda()
                i_s = inp[1].cuda()
                name = str(inp[2][0])

                o_sk, o_t, o_b, o_f = G(i_t, i_s)
                o_sk = o_sk.squeeze(0).to('cpu')
                o_t = o_t.squeeze(0).to('cpu')
                o_b = o_b.squeeze(0).to('cpu')
                o_f = o_f.squeeze(0).to('cpu')

                os.makedirs(savedir, exist_ok=True)

                # (x + 1) / 2 maps the image outputs from [-1, 1] to [0, 1]
                # for PIL; the sketch output is used as-is.
                o_sk = F.to_pil_image(o_sk)
                o_t = F.to_pil_image((o_t + 1)/2)
                o_b = F.to_pil_image((o_b + 1)/2)
                o_f = F.to_pil_image((o_f + 1)/2)

                o_f.save(os.path.join(savedir, name + 'o_f.png'))
                o_sk.save(os.path.join(savedir, name + 'o_sk.png'))
                o_t.save(os.path.join(savedir, name + 'o_t.png'))
                o_b.save(os.path.join(savedir, name + 'o_b.png'))
def __init__(
        self
        # Training
        , batch_size: int = 1
        , loss: str = 'rec'
        # Model
        , model: str = 'TTSR'
        , model__args: T.List = []
        , model__kwargs: T.Dict = {}
        , optimizer: str = 'Adam'
        , lr: T.Union[float, T.Dict[str, float]] = 1e-3
        , optimizer__args: T.List = []
        , optimizer__kwargs: T.Dict = {
            'betas': [0.9, 0.99],
            'weight_decay': 0
        }
        , lr_scheduler: T.Optional[str] = None
        , lr_scheduler__args: T.List = []
        , lr_scheduler__kwargs: T.Dict = {}
        # Discriminator
        , use_discriminator: bool = False
        , discriminator: str = 'Discriminator'
        , discriminator__args: T.List = []
        , discriminator__kwargs: T.Dict = {}
        , discriminator__optimizer: str = 'Adam'
        , discriminator__lr: float = 1e-3
        , discriminator__optimizer__args: T.List = []
        , discriminator__optimizer__kwargs: T.Dict = {
            'betas': [0., 0.9],
            'weight_decay': 0
        }
        , discriminator__frequency: T.Optional[int] = None
        , discriminator__lr_scheduler: T.Optional[str] = None
        , discriminator__lr_scheduler__args: T.List = []
        , discriminator__lr_scheduler__kwargs: T.Dict = {}
        , add_vgg19: bool = False
        , *args, **kwargs):
    """Configure the training module from string specs.

    Model, optimizer and LR-scheduler are given as names plus args/kwargs and
    resolved by the `_parse_*` helpers; an optional GAN discriminator (with
    its own optimizer/scheduler/update frequency) and an optional frozen
    VGG19 can be attached.

    NOTE(review): the list/dict defaults are mutable default arguments —
    shared across calls if ever mutated.  Left unchanged here because
    `save_hyperparameters()` records the passed values and changing the
    defaults to None sentinels would alter the logged hyperparameters.
    """
    super().__init__()
    self.batch_size = batch_size
    self.loss = self._parse_loss(loss)
    self.model = self._parse_model(model, *model__args, **model__kwargs)
    self.optimizer = self._parse_optimizer(optimizer, self.model, lr,
                                           *optimizer__args,
                                           **optimizer__kwargs)
    self.lr_scheduler = self._parse_lr_scheduler(lr_scheduler,
                                                 self.optimizer,
                                                 *lr_scheduler__args,
                                                 **lr_scheduler__kwargs)
    self.use_discriminator = use_discriminator
    # Discriminator stack is only materialized when adversarial training is on.
    if self.use_discriminator:
        self.discriminator = self._parse_model(discriminator,
                                               *discriminator__args,
                                               **discriminator__kwargs)
        self.discriminator__optimizer = self._parse_optimizer(
            discriminator__optimizer, self.discriminator, discriminator__lr,
            *discriminator__optimizer__args,
            **discriminator__optimizer__kwargs)
        self.discriminator__frequency = discriminator__frequency
        self.discriminator__lr_scheduler = self._parse_lr_scheduler(
            discriminator__lr_scheduler, self.discriminator__optimizer,
            *discriminator__lr_scheduler__args,
            **discriminator__lr_scheduler__kwargs)
    # Frozen VGG19 (for perceptual losses, presumably — confirm at use site).
    if add_vgg19:
        self.vgg19 = Vgg19.Vgg19(requires_grad=False)
    # Basically some meta info to save within logger
    self.args = args
    self.kwargs = kwargs
    self.save_hyperparameters()
config = args.parse_args() num_classes = config.num_classes base_lr = config.lr cuda = config.cuda num_epochs = config.num_epochs print_iter = config.print_iter model_name = config.model_name prediction_file = config.prediction_file test_file = config.test_file batch = config.batch mode = config.mode # create model model = Vgg19(num_classes=num_classes) if mode == 'test': load_model(model_name, model) if cuda: model = model.cuda() if mode == 'train': # define loss function loss_fn = nn.CrossEntropyLoss() if cuda: loss_fn = loss_fn.cuda() # set optimizer optimizer = Adam(
def main():
    """Load a trained SRNet checkpoint and run inference over paired
    (i_s, i_t) examples from --input_dir, saving fused outputs to --save_dir.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', help = 'Directory containing xxx_i_s and xxx_i_t with same prefix', default = cfg.example_data_dir)
    parser.add_argument('--save_dir', help = 'Directory to save result', default = cfg.predict_result_dir)
    parser.add_argument('--checkpoint', help = 'ckpt', default = cfg.ckpt_path)
    args = parser.parse_args()

    assert args.input_dir is not None
    assert args.save_dir is not None
    assert args.checkpoint is not None

    print_log('model compiling start.', content_color = PrintColor['yellow'])

    G = Generator(in_channels = 3).to(device)
    D1 = Discriminator(in_channels = 6).to(device)
    D2 = Discriminator(in_channels = 6).to(device)
    vgg_features = Vgg19().to(device)

    G_solver = torch.optim.Adam(G.parameters(), lr=cfg.learning_rate, betas = (cfg.beta1, cfg.beta2))
    D1_solver = torch.optim.Adam(D1.parameters(), lr=cfg.learning_rate, betas = (cfg.beta1, cfg.beta2))
    D2_solver = torch.optim.Adam(D2.parameters(), lr=cfg.learning_rate, betas = (cfg.beta1, cfg.beta2))

    # Restore everything from the checkpoint; only G is actually used for
    # prediction, but loading the rest validates the checkpoint contents.
    checkpoint = torch.load(args.checkpoint)
    G.load_state_dict(checkpoint['generator'])
    D1.load_state_dict(checkpoint['discriminator1'])
    D2.load_state_dict(checkpoint['discriminator2'])
    G_solver.load_state_dict(checkpoint['g_optimizer'])
    D1_solver.load_state_dict(checkpoint['d1_optimizer'])
    D2_solver.load_state_dict(checkpoint['d2_optimizer'])

    trfms = To_tensor()
    example_data = example_dataset(data_dir= args.input_dir, transform = trfms)
    example_loader = DataLoader(dataset = example_data, batch_size = 1, shuffle = False)
    example_iter = iter(example_loader)

    print_log('Model compiled.', content_color = PrintColor['yellow'])
    print_log('Predicting', content_color = PrintColor['yellow'])

    G.eval()
    D1.eval()
    D2.eval()

    # Create the output directory once, up front (originally checked inside
    # the loop on every iteration).
    os.makedirs(args.save_dir, exist_ok=True)

    with torch.no_grad():
        for step in tqdm(range(len(example_data))):
            # Bug fix: Python 3 iterators have no `.next()` method; use the
            # builtin next().
            try:
                inp = next(example_iter)
            except StopIteration:
                example_iter = iter(example_loader)
                inp = next(example_iter)

            i_t = inp[0].to(device)
            i_s = inp[1].to(device)
            name = str(inp[2][0])

            o_sk, o_t, o_b, o_f = G(i_t, i_s, (i_t.shape[2], i_t.shape[3]))

            o_sk = o_sk.squeeze(0).detach().to('cpu')
            o_t = o_t.squeeze(0).detach().to('cpu')
            o_b = o_b.squeeze(0).detach().to('cpu')
            o_f = o_f.squeeze(0).detach().to('cpu')

            # (x + 1) / 2 maps image outputs from [-1, 1] to [0, 1] for PIL.
            o_sk = F.to_pil_image(o_sk)
            o_t = F.to_pil_image((o_t + 1)/2)
            o_b = F.to_pil_image((o_b + 1)/2)
            o_f = F.to_pil_image((o_f + 1)/2)

            o_f.save(os.path.join(args.save_dir, name + 'o_f.png'))
num_classes = config.num_classes base_lr = config.lr cuda = config.cuda num_epochs = config.num_epochs print_iter = config.print_iter model_name = config.model_name prediction_file = config.prediction_file best_prediction_file = config.best_prediction_file #DBY batch = config.batch mode = config.mode # create model model = SiameseNetwork() #model = SiameseEfficientNet() model = Vgg19() if mode == 'test': load_model(model_name, model) if cuda: model = model.cuda() # Define 'best loss' - DBY best_loss = 0.1 last_loss = 0 if mode == 'train': # define loss function # loss_fn = nn.CrossEntropyLoss() # if cuda: # loss_fn = loss_fn.cuda()
def __init__(self, config, outdir, modeldir, data_path, sketch_path,
             flat_path, pretrain_path=None):
    """Set up single- or multi-stage decompose-style colorization training.

    In "multi" mode the flat generator is initialized from pretrained
    weights and a BicycleGAN stack (generator, latent encoder,
    discriminator) plus a color fixer are also trained.

    Parameters
    ----------
    config : dict
        Nested configuration with "train", "dataset", "model", "loss" sections.
    outdir : path-like
        Directory for visualization outputs.
    modeldir : path-like
        Directory for model checkpoints.
    data_path, sketch_path, flat_path : path-like
        Dataset locations.
    pretrain_path : path-like, optional
        Pretrained flat-generator checkpoint; required when
        train_type == "multi" (torch.load(None) fails otherwise).
    """
    self.train_config = config["train"]
    self.data_config = config["dataset"]
    model_config = config["model"]
    self.loss_config = config["loss"]
    self.outdir = outdir
    self.modeldir = modeldir
    self.train_type = self.train_config["train_type"]
    # "multi" uses a different dataset class (faces) than the default
    # flat-illustration dataset.
    if self.train_type == "multi":
        self.dataset = DanbooruFacesDataset(data_path, sketch_path,
                                            self.data_config["line_method"],
                                            self.data_config["extension"],
                                            self.data_config["train_size"],
                                            self.data_config["valid_size"],
                                            self.data_config["color_space"],
                                            self.data_config["line_space"])
    else:
        self.dataset = IllustDataset(data_path, sketch_path, flat_path,
                                     self.data_config["line_method"],
                                     self.data_config["extension"],
                                     self.data_config["train_size"],
                                     self.data_config["valid_size"],
                                     self.data_config["color_space"],
                                     self.data_config["line_space"])
    print(self.dataset)
    flat_gen = Generator(model_config["flat_generator"]["in_ch"],
                         num_layers=model_config["flat_generator"]["num_layers"],
                         attn_type=model_config["flat_generator"]["attn_type"],
                         )
    self.flat_gen, self.flat_gen_opt = self._setting_model_optim(
        flat_gen, model_config["flat_generator"])
    # Multi-stage training starts the flat generator from pretrained weights.
    if self.train_type == "multi":
        weight = torch.load(pretrain_path)
        self.flat_gen.load_state_dict(weight)
    f_dis = Discriminator(model_config["flat_dis"]["in_ch"],
                          model_config["flat_dis"]["multi"])
    self.f_dis, self.f_dis_opt = self._setting_model_optim(
        f_dis, model_config["flat_dis"])
    # NOTE(review): the grouping of the BicycleGAN stack and the fixer under
    # this branch is reconstructed from a flattened source — confirm against
    # the original file.
    if self.train_type == "multi":
        bicycle_gen = BicycleGAN(model_config["bicycle_gan"]["in_ch"],
                                 latent_dim=model_config["bicycle_gan"]["l_dim"],
                                 num_layers=model_config["bicycle_gan"]["num_layers"])
        self.b_gen, self.b_gen_opt = self._setting_model_optim(
            bicycle_gen, model_config["bicycle_gan"])
        latent_enc = LatentEncoder(model_config["encoder"]["in_ch"],
                                   latent_dim=model_config["encoder"]["l_dim"])
        self.l_enc, self.l_enc_opt = self._setting_model_optim(
            latent_enc, model_config["encoder"])
        b_dis = Discriminator(model_config["bicycle_dis"]["in_ch"],
                              model_config["bicycle_dis"]["multi"])
        self.b_dis, self.b_dis_opt = self._setting_model_optim(
            b_dis, model_config["bicycle_dis"])
        fixer = ColorFixer()
        self.fix, self.fix_opt = self._setting_model_optim(
            fixer, model_config["fixer"])
    # Frozen VGG19 feature extractor for perceptual losses.
    self.vgg = Vgg19(requires_grad=False)
    self.vgg.cuda()
    self.vgg.eval()
    # Light guided filter applied to outputs.
    self.out_filter = GuidedFilter(r=1, eps=1e-2)
    self.out_filter.cuda()
    self.lossfunc = DecomposeLossCalculator()
    self.visualizer = Visualizer(self.data_config["color_space"])
def train(epochs, interval, batchsize, validsize,
          data_path, sketch_path, extension, img_size,
          outdir, modeldir,
          gen_learning_rate, dis_learning_rate,
          beta1, beta2):
    """Adversarial training loop for the Style2Paint colorization model.

    Parameters
    ----------
    epochs : int
        Number of passes over the dataset.
    interval : int
        Snapshot period (in iterations) for checkpoints and validation images.
    batchsize, validsize : int
        Training batch size and number of fixed validation samples.
    data_path, sketch_path : path-like
        Dataset locations; `extension` selects the image file extension.
    img_size : int
        Crop/resize size used by the collator.
    outdir, modeldir : path-like
        Output directories for visualizations and checkpoints.
    gen_learning_rate, dis_learning_rate, beta1, beta2 : float
        Adam hyperparameters for generator and discriminator.
    """
    # Dataset Definition
    dataset = IllustDataset(data_path, sketch_path, extension)
    c_valid, l_valid = dataset.valid(validsize)
    print(dataset)
    collator = LineCollator(img_size)

    # Model & Optimizer Definition
    model = Style2Paint()
    model.cuda()
    model.train()
    gen_opt = torch.optim.Adam(model.parameters(), lr=gen_learning_rate, betas=(beta1, beta2))

    discriminator = Discriminator()
    discriminator.cuda()
    discriminator.train()
    dis_opt = torch.optim.Adam(discriminator.parameters(), lr=dis_learning_rate, betas=(beta1, beta2))

    # Frozen VGG19 feature extractor for the style/perceptual terms.
    vgg = Vgg19(requires_grad=False)
    vgg.cuda()
    vgg.eval()

    # Loss function definition
    lossfunc = Style2paintsLossCalculator()

    # Visualizer definition
    visualizer = Visualizer()

    iteration = 0
    for epoch in range(epochs):
        dataloader = DataLoader(dataset,
                                batch_size=batchsize,
                                shuffle=True,
                                collate_fn=collator,
                                drop_last=True)
        progress_bar = tqdm(dataloader)

        for index, data in enumerate(progress_bar):
            iteration += 1
            jit, war, line = data

            # Discriminator update (generator output detached so only D learns).
            y = model(line, war)
            loss = lossfunc.adversarial_disloss(discriminator, y.detach(), jit)
            dis_opt.zero_grad()
            loss.backward()
            dis_opt.step()

            # Generator update: adversarial + weighted content +
            # style/perceptual terms.
            y = model(line, war)
            loss = lossfunc.adversarial_genloss(discriminator, y)
            loss += 10.0 * lossfunc.content_loss(y, jit)
            loss += lossfunc.style_and_perceptual_loss(vgg, y, jit)
            gen_opt.zero_grad()
            loss.backward()
            gen_opt.step()

            # `== 1` (not `== 0`) so the first snapshot happens at iteration 1.
            if iteration % interval == 1:
                torch.save(model.state_dict(), f"{modeldir}/model_{iteration}.pt")
                # NOTE(review): validation runs with the model still in train
                # mode; consider model.eval()/model.train() around this block.
                with torch.no_grad():
                    y = model(l_valid, c_valid)
                    c = c_valid.detach().cpu().numpy()
                    l = l_valid.detach().cpu().numpy()
                    y = y.detach().cpu().numpy()
                    visualizer(l, c, y, outdir, iteration, validsize)

            # Fix: use `.item()` instead of the deprecated `.data` so a plain
            # float (not a tensor repr) is logged.
            print(f"iteration: {iteration} Loss: {loss.item()}")