Beispiel #1
0
    config.TASK = args.task
    config.NAME = args.name
    config.N_EPOCH = 150
    config.DATA_SET = 512
    config.keep_training = args.keep_training

    # model
    if args.model == 'deep_guided_filter':
        config.model = DeepGuidedFilter().cuda()
        if config.keep_training == True:
            config.model.init_lr(
                os.path.join('checkpoints', config.TASK, config.NAME,
                             'snapshots/net_epoch_54.pth'))
    elif args.model == 'deep_guided_filter_advanced':
        config.model = DeepGuidedFilterAdvanced()
    elif args.model == 'deep_conv_guided_filter':
        config.model = DeepGuidedFilterConvGF()
    elif args.model == 'deep_conv_guided_filter_adv':
        config.model = DeepGuidedFilterGuidedMapConvGF()
    else:
        print('Not a valid model!')
        exit(-1)

    def forward(imgs, config):
        """Run ``config.model`` on one sample and pair the output with its target.

        Args:
            imgs: indexable sequence whose first three items are unpacked as
                (x_hr, gt_hr, x_lr); any further items are ignored.
            config: object exposing ``GPU`` (CUDA device index; a negative
                value leaves the tensors where they are) and ``model``
                (the callable to run).

        Returns:
            ``(config.model(Variable(x_lr), Variable(x_hr)), gt_hr)``.
        """
        hr_input, hr_target, lr_input = imgs[:3]

        if config.GPU >= 0:
            # Select config.GPU as the current CUDA device so that the
            # argument-less .cuda() calls copy onto it.
            with torch.cuda.device(config.GPU):
                hr_input = hr_input.cuda()
                hr_target = hr_target.cuda()
                lr_input = lr_input.cuda()

        # The model receives the x_lr tensor first, then the x_hr tensor.
        output = config.model(Variable(lr_input), Variable(hr_input))
        return output, hr_target
# Remaining benchmark options. --gpu, --low_size and --full_size are read
# below, so they are presumably registered earlier in this parser setup.
parser.add_argument('--iter_size', default=100, type=int, help='TOTAL_ITER')
parser.add_argument('--model_id', default=0, type=int, help='MODEL_ID')
args = parser.parse_args()

# Module-level settings; everything except SAVE_FOLDER comes from the CLI.
SAVE_FOLDER = 'time'
GPU, LOW_SIZE, FULL_SIZE = args.gpu, args.low_size, args.full_size
TOTAL_ITER, MODEL_ID = args.iter_size, args.model_id

# model - forward
# Benchmark table of (name, (module, call)) entries. ``call(model, imgs)``
# runs one forward pass of ``model`` on ``imgs``. This is presumably the list
# built under "prepare img": imgs[0] is the LOW_SIZE tensor and imgs[1] is the
# FULL_SIZE tensor.
# The FastGuidedFilter / ConvGuidedFilter entries are called with three
# arguments, so imgs[0] is passed for both of the first two.
# NOTE(review): all six modules are instantiated here at import time, in this
# order, even though presumably only the MODEL_ID entry is used later.
model_forward = [
    ('deep_guided_filter',
     (DeepGuidedFilter(),
      lambda model, imgs: model(imgs[0], imgs[1]))),
    ('deep_guided_filter_layer',
     (FastGuidedFilter(1, 1e-8),
      lambda model, imgs: model(imgs[0], imgs[0], imgs[1]))),
    ('deep_guided_filter_advanced',
     (DeepGuidedFilterAdvanced(),
      lambda model, imgs: model(imgs[0], imgs[1]))),
    ('deep_conv_guided_filter_layer',
     (ConvGuidedFilter(1, AdaptiveNorm),
      lambda model, imgs: model(imgs[0], imgs[0], imgs[1]))),
    ('deep_conv_guided_filter',
     (DeepGuidedFilterConvGF(),
      lambda model, imgs: model(imgs[0], imgs[1]))),
    ('deep_conv_guided_filter_adv',
     (DeepGuidedFilterGuidedMapConvGF(),
      lambda model, imgs: model(imgs[0], imgs[1]))),
]

# mkdir
# Create the output folder. Attempting makedirs() first, and tolerating the
# error only when the folder now exists, avoids the check-then-create race of
# an isdir()/makedirs() pair. Any other failure is still raised: SAVE_FOLDER
# existing as a regular file, missing permissions, and so on.
try:
    os.makedirs(SAVE_FOLDER)
except OSError:
    if not os.path.isdir(SAVE_FOLDER):
        raise

# prepare img
# Two random input batches. When GPU >= 0 they are copied onto that CUDA
# device. Both are then wrapped as Variables with requires_grad=False.
imgs = [
    torch.rand((1, 3, LOW_SIZE, LOW_SIZE)),    # imgs[0]: LOW_SIZE x LOW_SIZE
    torch.rand((1, 3, FULL_SIZE, FULL_SIZE)),  # imgs[1]: FULL_SIZE x FULL_SIZE
]
if GPU >= 0:
    # .cuda() without an argument targets the device selected by this context.
    with torch.cuda.device(GPU):
        imgs = list(map(lambda tensor: tensor.cuda(), imgs))
imgs = [Variable(img, requires_grad=False) for img in imgs]