def main(args):
    """Build the validation loader and restore generator weights from a checkpoint.

    Args:
        args: namespace providing ``source_dir`` (passed to ``InpaintDataset``
            with ``mode='val'``) and ``model_path`` (checkpoint file read with
            ``torch.load``).

    The checkpoint must be a mapping with both ``'netG_state_dict'`` and
    ``'netD_state_dict'`` entries (a missing key raises ``KeyError``). Only the
    generator entry is loaded, into a freshly constructed ``InpaintSANet``.

    NOTE(review): the validation call is disabled in this version, so the
    loader and the restored generator are built but never used, and the
    function returns ``None``. Also, a later ``def main()`` in this file
    rebinds the name ``main``, so this definition is shadowed after import.
    """
    # Validation data: one sample per batch, fixed order.
    dataset = InpaintDataset(args.source_dir, mode='val')
    val_loader = dataset.loader(batch_size=1, shuffle=False, num_workers=1)

    # Restore the network. torch.load unpickles the file, so only point
    # model_path at trusted checkpoints.
    checkpoint_path = args.model_path
    checkpoint = torch.load(checkpoint_path)
    # Both entries are looked up (generator first), but only the generator
    # weights are consumed below.
    generator_weights = checkpoint['netG_state_dict']
    discriminator_weights = checkpoint['netD_state_dict']

    generator = InpaintSANet()
    generator.load_state_dict(generator_weights)
def main():
    """Adversarially train the inpainting generator and discriminator.

    All settings come from the module-level ``config``. The steps are:

    1. Build the training loader, then a fixed validation set of up to
       ``config.STATIC_VIEW_SIZE`` single-sample batches whose images pass the
       ``size(1) == 3`` check.
    2. Create ``InpaintSANet`` / ``InpaintSADirciminator``. If
       ``config.MODEL_RESTORE`` is non-empty, restore both state dicts from
       ``model_logs/<MODEL_RESTORE>``.
    3. For 50 epochs: ``train`` then ``validate`` on ``cuda0``. After each
       epoch write ``<log_dir>/epoch_<n>_ckpt.pth.tar`` and overwrite
       ``<log_dir>/latest_ckpt.pth.tar``.

    Checkpoints hold ``'epoch'``, ``'netG_state_dict'`` and
    ``'netD_state_dict'``. Optimizer state is not saved, and a restore does
    not read ``'epoch'``. A resumed run therefore restarts the Adam moments
    and numbers its checkpoints from epoch 1 again.
    """
    logger_init()
    dataset_type = config.DATASET
    batch_size = config.BATCH_SIZE
    img_shape = tuple(config.IMG_SHAPES)

    # ---- Dataset setting -------------------------------------------------
    logger.info("Initialize the dataset...")
    train_dataset = InpaintDataset(
        config.DATA_FLIST[dataset_type][0],
        {mask_type: config.DATA_FLIST[config.MASKDATASET][mask_type][0]
         for mask_type in config.MASK_TYPES},
        resize_shape=img_shape,
        random_bbox_shape=config.RANDOM_BBOX_SHAPE,
        random_bbox_margin=config.RANDOM_BBOX_MARGIN,
        random_ff_setting=config.RANDOM_FF_SETTING)
    train_loader = train_dataset.loader(batch_size=batch_size, shuffle=True,
                                        num_workers=16, pin_memory=True)

    val_dataset = InpaintDataset(
        config.DATA_FLIST[dataset_type][1],
        {'val': config.DATA_FLIST[config.MASKDATASET]['val'][1]},
        resize_shape=img_shape,
        random_bbox_shape=config.RANDOM_BBOX_SHAPE,
        random_bbox_margin=config.RANDOM_BBOX_MARGIN,
        random_ff_setting=config.RANDOM_FF_SETTING)
    val_loader = val_dataset.loader(batch_size=1, shuffle=False, num_workers=1)

    # Materialize a fixed validation set once so every epoch is validated on
    # the same batches. Stop as soon as the quota is filled, rather than
    # fetching one extra batch before breaking.
    static_view_size = config.STATIC_VIEW_SIZE
    val_datas = []
    if static_view_size > 0:
        for data in val_loader:
            imgs = data[0]
            # Presumably imgs is (1, C, H, W), so this keeps only 3-channel
            # images — TODO confirm against InpaintDataset.
            if imgs.size(1) == 3:
                val_datas.append(data)
                if len(val_datas) >= static_view_size:
                    break
    logger.info("Finish the dataset initialization.")

    # ---- Network structure ----------------------------------------------
    logger.info("Define the Network Structure and Losses")
    netG = InpaintSANet()
    netD = InpaintSADirciminator()

    if config.MODEL_RESTORE != '':
        whole_model_path = 'model_logs/{}'.format(config.MODEL_RESTORE)
        logger.info("Loading pretrained models from {} ...".format(
            config.MODEL_RESTORE))
        # map_location lets a checkpoint saved from a GPU restore on a host
        # without that device. load_state_dict copies values into the existing
        # parameters, so this does not change where the modules live.
        # torch.load unpickles the file: only restore trusted checkpoints.
        nets = torch.load(whole_model_path, map_location=cpu0)
        netG_state_dict, netD_state_dict = (nets['netG_state_dict'],
                                            nets['netD_state_dict'])
        netG.load_state_dict(netG_state_dict)
        netD.load_state_dict(netD_state_dict)

    # ---- Losses and optimizers ------------------------------------------
    recon_loss = ReconLoss(*config.L1_LOSS_ALPHA)
    gan_loss = SNGenLoss(config.GAN_LOSS_ALPHA)
    dis_loss = SNDisLoss()
    lr, decay = config.LEARNING_RATE, config.WEIGHT_DECAY
    optG = torch.optim.Adam(netG.parameters(), lr=lr, weight_decay=decay)
    # The discriminator uses 4x the generator learning rate.
    optD = torch.optim.Adam(netD.parameters(), lr=4 * lr, weight_decay=decay)
    logger.info("Finish Define the Network Structure and Losses")

    # ---- Training loop --------------------------------------------------
    logger.info("Start Training...")
    epoch = 50
    for i in range(epoch):
        train(netG, netD, gan_loss, recon_loss, dis_loss, optG, optD,
              train_loader, i, device=cuda0, val_datas=val_datas)
        validate(netG, netD, gan_loss, recon_loss, dis_loss, optG, optD,
                 val_datas, i, device=cuda0)

        # NOTE(review): Module.to() moves the modules themselves (in place),
        # and nothing here moves them back to cuda0. This relies on
        # train()/validate() placing the nets on `device` — confirm.
        saved_model = {
            'epoch': i + 1,
            'netG_state_dict': netG.to(cpu0).state_dict(),
            'netD_state_dict': netD.to(cpu0).state_dict(),
        }
        torch.save(saved_model,
                   '{}/epoch_{}_ckpt.pth.tar'.format(log_dir, i + 1))
        torch.save(saved_model, '{}/latest_ckpt.pth.tar'.format(log_dir))