def main():
    """Restore a saved inpainting generator and run one validation pass.

    All settings come from the module-level ``config`` object:
    - dataset and mask file lists (``DATASET``, ``MASKDATASET``, ``DATA_FLIST``);
    - mask-generation settings (``RANDOM_BBOX_*``, ``RANDOM_FF_SETTING``);
    - image size (``IMG_SHAPES``);
    - the checkpoint file name under ``model_logs/`` (``MODEL_RESTORE``);
    - the generator architecture (``NETWORK_TYPE``);
    - loss weights and optimizer hyper-parameters.

    Only the generator weights are restored from the checkpoint. The
    discriminator keeps its freshly constructed parameters. A pretrained
    VGG16-BN on ``cuda1`` serves as the feature extractor for the perceptual
    and style losses. Everything is passed to ``validate`` together with the
    devices ``(cuda0, cuda1)``.

    Raises:
        ValueError: if ``config.NETWORK_TYPE`` is not ``"l2h_unet"`` or
            ``"sa_gated"``.
    """
    logger_init()
    dataset_type = config.DATASET

    # Dataset setting
    logger.info("Initialize the dataset...")
    val_mask_flists = {
        mask_type: config.DATA_FLIST[config.MASKDATASET][mask_type][1]
        for mask_type in ('val',)
    }
    val_dataset = InpaintDataset(
        config.DATA_FLIST[dataset_type][1],
        val_mask_flists,
        resize_shape=tuple(config.IMG_SHAPES),
        random_bbox_shape=config.RANDOM_BBOX_SHAPE,
        random_bbox_margin=config.RANDOM_BBOX_MARGIN,
        random_ff_setting=config.RANDOM_FF_SETTING,
    )
    # Validation always runs one sample at a time in a fixed order.
    val_loader = val_dataset.loader(batch_size=1, shuffle=False, num_workers=1)
    logger.info("Finish the dataset initialization.")

    # Define the Network Structure
    logger.info("Define the Network Structure and Losses")
    whole_model_path = 'model_logs/{}'.format(config.MODEL_RESTORE)
    logger.info("Loading pretrained models from {} ...".format(
        config.MODEL_RESTORE))
    checkpoint = torch.load(whole_model_path)
    netG_state_dict = checkpoint['netG_state_dict']

    if config.NETWORK_TYPE == "l2h_unet":
        netG = InpaintRUNNet(n_in_channel=config.N_CHANNEL)
        netG.load_state_dict(netG_state_dict)
    elif config.NETWORK_TYPE == 'sa_gated':
        netG = InpaintSANet()
        # Uses the project's tolerant loader instead of load_state_dict
        # (presumably skips mismatched keys -- confirm in the helper).
        load_consistent_state_dict(netG_state_dict, netG)
    else:
        # Fail here with a clear message instead of an unbound netG later.
        raise ValueError(
            "Unsupported NETWORK_TYPE {!r}; expected 'l2h_unet' or "
            "'sa_gated'".format(config.NETWORK_TYPE))

    # The discriminator's saved weights are not restored.
    netD = InpaintSADirciminator()
    # nn.Module.to() moves the module in place and returns it, so one call is
    # enough for both feature-based losses and the nets dict.
    netVGG = vgg16_bn(pretrained=True).to(cuda1)

    # Define loss
    recon_loss = ReconLoss(*config.L1_LOSS_ALPHA)
    gan_loss = SNGenLoss(config.GAN_LOSS_ALPHA)
    perc_loss = PerceptualLoss(weight=config.PERC_LOSS_ALPHA,
                               feat_extractors=netVGG)
    style_loss = StyleLoss(weight=config.STYLE_LOSS_ALPHA,
                           feat_extractors=netVGG)
    dis_loss = SNDisLoss()

    lr, decay = config.LEARNING_RATE, config.WEIGHT_DECAY
    optG = torch.optim.Adam(netG.parameters(), lr=lr, weight_decay=decay)
    # The discriminator uses 4x the generator learning rate.
    optD = torch.optim.Adam(netD.parameters(), lr=4 * lr, weight_decay=decay)

    nets = {"netG": netG, "netD": netD, "vgg": netVGG}
    losses = {
        "GANLoss": gan_loss,
        "ReconLoss": recon_loss,
        "StyleLoss": style_loss,
        "DLoss": dis_loss,
        "PercLoss": perc_loss,
    }
    opts = {
        "optG": optG,
        "optD": optD,
    }
    logger.info("Finish Define the Network Structure and Losses")

    # Start Validation (the literal 0 is presumably the epoch index -- confirm
    # against validate's signature).
    logger.info("Start Validation")
    validate(nets, losses, opts, val_loader, 0, config.NETWORK_TYPE,
             devices=(cuda0, cuda1))
def main(args):
    """Train the SA-gated inpainting GAN, validating and checkpointing each epoch.

    Args:
        args: parsed command-line namespace. Attributes read here:
            - ``logdir``;
            - ``train_image_list``, ``train_mask_list``;
            - ``val_image_list``, ``val_mask_list``;
            - ``img_shape``, ``batch_size``;
            - ``load_weights`` (empty string means start from scratch);
            - ``learning_rate``, ``epochs``;
            - ``result_dir``, ``gt_dir``.

    Side effects:
        Creates ``args.logdir`` if needed and writes into it:
        - ``loss.csv``, whose writer is handed to ``train``;
        - ``acc.csv``, with one ``[epoch, mse, ssim, score]`` row per epoch;
        - ``epoch_<n>_ckpt.pth.tar`` after every epoch;
        - ``best_ckpt.pth.tar`` whenever the validation score improves.
    """
    os.makedirs(args.logdir, exist_ok=True)

    # Dataset setting. Both datasets take their masks under the 'val' key;
    # assumed to be what InpaintDataset expects -- confirm.
    train_dataset = InpaintDataset(args.train_image_list,
                                   {'val': args.train_mask_list},
                                   mode='train', img_size=args.img_shape)
    train_loader = train_dataset.loader(batch_size=args.batch_size,
                                        shuffle=True, num_workers=4,
                                        pin_memory=True)
    val_dataset = InpaintDataset(args.val_image_list,
                                 {'val': args.val_mask_list},
                                 mode='val', img_size=args.img_shape)
    val_loader = val_dataset.loader(batch_size=1, shuffle=False, num_workers=1)

    # Define the Network Structure
    netG = InpaintSANet()
    netD = InpaintSADirciminator()
    netG.cuda()
    netD.cuda()

    if args.load_weights != '':
        checkpoint = torch.load(args.load_weights)
        netG_state_dict = checkpoint['netG_state_dict']
        netD_state_dict = checkpoint['netD_state_dict']
        load_consistent_state_dict(netG_state_dict, netG)
        netD.load_state_dict(netD_state_dict)

    # Define loss
    recon_loss = ReconLoss(1.2, 1.2, 1.2, 1.2)
    gan_loss = SNGenLoss(0.005)
    dis_loss = SNDisLoss()

    lr, decay = args.learning_rate, 0.0
    optG = torch.optim.Adam(netG.parameters(), lr=lr, weight_decay=decay)
    # The discriminator uses 4x the generator learning rate.
    optD = torch.optim.Adam(netD.parameters(), lr=4 * lr, weight_decay=decay)

    # score = 1 - mse/100 + ssim is negative whenever mse > 100 * (1 + ssim).
    # Starting from -inf guarantees the first epoch is saved as best.
    best_score = float('-inf')

    loss_path = os.path.join(args.logdir, 'loss.csv')
    acc_path = os.path.join(args.logdir, 'acc.csv')
    # The csv module requires newline='' on files it writes. The with-block
    # closes both logs even if training raises.
    with open(loss_path, 'w', newline='') as loss_file, \
            open(acc_path, 'w', newline='') as acc_file:
        loss_writer = csv.writer(loss_file, delimiter=',')
        acc_writer = csv.writer(acc_file, delimiter=',')

        # Start Training
        for epoch in range(1, args.epochs + 1):
            train(netG, netD, gan_loss, recon_loss, dis_loss, optG, optD,
                  train_loader, epoch, args.img_shape, loss_writer)

            # validate
            output_dir = os.path.join(args.result_dir, str(epoch))
            mse, ssim = validate(netG, val_loader, args.img_shape,
                                 output_dir, args.gt_dir)
            score = 1 - mse / 100 + ssim
            print('MSE: ', mse, ' SSIM:', ssim, ' SCORE:', score)
            acc_writer.writerow([epoch, mse, ssim, score])
            # Push this epoch's rows to disk so progress survives a crash.
            loss_file.flush()
            acc_file.flush()

            # Optimizer state is not saved, so a run resumed from these
            # checkpoints starts with fresh Adam state.
            saved_model = {
                'epoch': epoch,
                'netG_state_dict': netG.state_dict(),
                'netD_state_dict': netD.state_dict(),
            }
            torch.save(saved_model,
                       '{}/epoch_{}_ckpt.pth.tar'.format(args.logdir, epoch))
            if score > best_score:
                torch.save(saved_model,
                           '{}/best_ckpt.pth.tar'.format(args.logdir))
                best_score = score
                print('New best score at epoch', epoch)