import sys

config.TASK = args.task
config.NAME = args.name
config.N_EPOCH = 150
config.DATA_SET = 512
config.keep_training = args.keep_training

# model
# Maps each accepted `args.model` value to the class that builds it.
MODEL_BUILDERS = {
    'deep_guided_filter': DeepGuidedFilter,
    'deep_guided_filter_advanced': DeepGuidedFilterAdvanced,
    'deep_conv_guided_filter': DeepGuidedFilterConvGF,
    'deep_conv_guided_filter_adv': DeepGuidedFilterGuidedMapConvGF,
}
if args.model not in MODEL_BUILDERS:
    # `exit` is a site-module helper for the interactive shell; use sys.exit.
    sys.stderr.write('Not a valid model!\n')
    sys.exit(-1)
config.model = MODEL_BUILDERS[args.model]()

# Put the model on the same device forward() moves its inputs to. The
# previous unconditional `.cuda()` (only for deep_guided_filter) ignored
# config.GPU, breaking CPU runs (GPU < 0) and non-default GPU indices.
if config.GPU >= 0:
    with torch.cuda.device(config.GPU):
        config.model = config.model.cuda()

# `== True` is kept deliberately: the argparse type of keep_training is
# not visible here, so a plain truthiness test could change behavior
# (e.g. for a non-empty string such as 'False').
if config.keep_training == True:
    if args.model == 'deep_guided_filter':
        # NOTE(review): the resume epoch (54) is hard-coded; confirm that
        # init_lr accepts this snapshot's state-dict layout.
        config.model.init_lr(
            os.path.join('checkpoints', config.TASK, config.NAME,
                         'snapshots/net_epoch_54.pth'))
    else:
        sys.stderr.write('Warning: keep_training is only implemented for '
                         'deep_guided_filter; ignoring it.\n')


def forward(imgs, config):
    """Run ``config.model`` on one batch and return ``(output, gt_hr)``.

    ``imgs`` is indexable; its first three items are unpacked as
    ``(x_hr, gt_hr, x_lr)`` and any further items are ignored (presumably
    high-res input, high-res target, low-res input -- confirm against the
    data loader). When ``config.GPU >= 0`` all three are moved to that
    CUDA device. The model is called as ``model(x_lr, x_hr)``; ``gt_hr``
    is returned unwrapped (not passed through ``Variable``).
    """
    x_hr, gt_hr, x_lr = imgs[:3]
    if config.GPU >= 0:
        with torch.cuda.device(config.GPU):
            x_hr, gt_hr, x_lr = x_hr.cuda(), gt_hr.cuda(), x_lr.cuda()
    return config.model(Variable(x_lr), Variable(x_hr)), gt_hr
parser.add_argument('--iter_size', type=int, default=100, help='TOTAL_ITER')
parser.add_argument('--model_id', type=int, default=0, help='MODEL_ID')
args = parser.parse_args()

# Output directory plus the run settings taken from the command line.
SAVE_FOLDER = 'time'
GPU, LOW_SIZE, FULL_SIZE = args.gpu, args.low_size, args.full_size
TOTAL_ITER, MODEL_ID = args.iter_size, args.model_id


def _call_network(model, imgs):
    """Call a full network as ``model(imgs[0], imgs[1])``."""
    return model(imgs[0], imgs[1])


def _call_filter_layer(model, imgs):
    """Call a bare filter layer as ``model(imgs[0], imgs[0], imgs[1])``."""
    return model(imgs[0], imgs[0], imgs[1])


# model - forward
# Entries are (name, (module, forward_fn)); modules are instantiated here,
# in list order.
model_forward = [
    ('deep_guided_filter', (DeepGuidedFilter(), _call_network)),
    ('deep_guided_filter_layer', (FastGuidedFilter(1, 1e-8), _call_filter_layer)),
    ('deep_guided_filter_advanced', (DeepGuidedFilterAdvanced(), _call_network)),
    ('deep_conv_guided_filter_layer', (ConvGuidedFilter(1, AdaptiveNorm), _call_filter_layer)),
    ('deep_conv_guided_filter', (DeepGuidedFilterConvGF(), _call_network)),
    ('deep_conv_guided_filter_adv', (DeepGuidedFilterGuidedMapConvGF(), _call_network)),
]

# mkdir
if not os.path.isdir(SAVE_FOLDER):
    os.makedirs(SAVE_FOLDER)

# prepare img: imgs[0] is 1x3xLOW_SIZExLOW_SIZE, imgs[1] is
# 1x3xFULL_SIZExFULL_SIZE, both uniform random (generated in that order).
imgs = [torch.rand((1, 3, side, side)) for side in (LOW_SIZE, FULL_SIZE)]
if GPU >= 0:
    with torch.cuda.device(GPU):
        imgs = [tensor.cuda() for tensor in imgs]
imgs = [Variable(tensor, requires_grad=False) for tensor in imgs]