# Example #1
# 0
# Parse the command-line options; `parser` must already be defined above this line.
args = parser.parse_args()


def forward(imgs, gt, config):
    """Forward hook installed as ``config.forward``: run ``config.model``.

    Args:
        imgs: input passed unchanged to ``config.model``.
        gt: ground truth; unused here (presumably kept so the hook's
            signature matches what the caller passes).
        config: object whose ``model`` attribute is called.

    Returns:
        Whatever ``config.model(imgs, imgs)`` returns.
    """
    # NOTE(review): the same input is fed as BOTH model arguments (for
    # DeepGuidedFilter these are presumably the low- and high-resolution
    # images). Behavior kept as-is; confirm this is intended rather than a
    # missing downsample of the first argument.
    return config.model(imgs, imgs)


dataset_path = args.path_imgs[0]
dataset_smooth_path = args.path_smooth[0]

# Names of all regular files directly inside dataset_path (sub-directories
# are skipped). os.listdir returns entries in arbitrary order, so sort them
# to make the index -> image mapping handed to run() reproducible across runs.
image_list = sorted(
    f for f in listdir(dataset_path) if isfile(join(dataset_path, f))
)
NumImg = len(image_list)

# Configuration: a private copy of the defaults with this script's overrides.
config = copy.deepcopy(default_config)
config.N_EPOCH = 100
# model
config.model = DeepGuidedFilter()
config.forward = forward
config.clip = 0.01

# Run the attack for each image
for idx in tqdm(range(NumImg)):
    run(config, dataset_path, dataset_smooth_path, image_list, idx,
        args.adv_model)
# Example #2
# 0
# Benchmark options. `args.gpu` is read below, so a `--gpu` option is
# presumably registered on `parser` before this point.
parser.add_argument('--low_size', type=int, default=64, help='LOW_SIZE')
parser.add_argument('--full_size', type=int, default=2048, help='FULL_SIZE')
parser.add_argument('--iter_size', type=int, default=100, help='TOTAL_ITER')
parser.add_argument('--model_id', type=int, default=0, help='MODEL_ID')
args = parser.parse_args()

# Run-wide constants derived from the parsed options.
SAVE_FOLDER = 'time'
GPU, LOW_SIZE, FULL_SIZE, TOTAL_ITER, MODEL_ID = (
    args.gpu, args.low_size, args.full_size, args.iter_size, args.model_id,
)

def _forward_pair(model, imgs):
    """Call ``model(imgs[0], imgs[1])``."""
    return model(imgs[0], imgs[1])


def _forward_first_twice(model, imgs):
    """Call ``model(imgs[0], imgs[0], imgs[1])``: the first input is passed twice."""
    return model(imgs[0], imgs[0], imgs[1])


# (name, (model instance, forward function)) pairs. Presumably `imgs` is the
# [low-res, full-res] pair built below. Models are constructed in this order.
model_forward = [
    ('deep_guided_filter',
     (DeepGuidedFilter(), _forward_pair)),
    ('deep_guided_filter_layer',
     (FastGuidedFilter(1, 1e-8), _forward_first_twice)),
    ('deep_guided_filter_advanced',
     (DeepGuidedFilterAdvanced(), _forward_pair)),
    ('deep_conv_guided_filter_layer',
     (ConvGuidedFilter(1, AdaptiveNorm), _forward_first_twice)),
    ('deep_conv_guided_filter',
     (DeepGuidedFilterConvGF(), _forward_pair)),
    ('deep_conv_guided_filter_adv',
     (DeepGuidedFilterGuidedMapConvGF(), _forward_pair)),
]

# mkdir: create SAVE_FOLDER if needed. exist_ok=True avoids the race between a
# separate isdir() check and makedirs(); a non-directory at that path still raises.
os.makedirs(SAVE_FOLDER, exist_ok=True)

# prepare img: random tensors of shape (1, 3, side, side), first LOW_SIZE then FULL_SIZE.
imgs = [torch.rand(1, 3, side, side) for side in (LOW_SIZE, FULL_SIZE)]
if GPU >= 0:
    with torch.cuda.device(GPU):
# Example #3
# 0
# Collect input image paths from --img_path and/or the --img_list file
# (one path per line).
img_list = []
if args.img_path is not None:
    img_list.append(args.img_path)
if args.img_list is not None:
    with open(args.img_list) as f:
        # Skip blank lines (e.g. extra newlines at the end of the file) so
        # they do not turn into empty image paths.
        img_list.extend(path for path in (line.strip() for line in f) if path)
# Explicit check rather than `assert`, which is stripped under `python -O`.
if not img_list:
    print('No input images: pass --img_path and/or --img_list!')
    exit(-1)

# Save Folder: create it if needed. exist_ok=True avoids the race between a
# separate isdir() check and makedirs(); a non-directory at that path still raises.
os.makedirs(args.save_folder, exist_ok=True)

# Model: reject unknown choices first. 'guided_filter' and 'deep_guided_filter'
# both use DeepGuidedFilter.
if args.model not in ('guided_filter', 'deep_guided_filter',
                      'deep_guided_filter_advanced'):
    print('Not a valid model!')
    exit(-1)

if args.model == 'deep_guided_filter_advanced':
    model = DeepGuidedFilterAdvanced()
else:
    model = DeepGuidedFilter()

# Checkpoint prefix per model choice; weights are read from
# models/<task>/<prefix>_net_latest.pth.
model2name = dict(
    guided_filter='lr',
    deep_guided_filter='hr',
    deep_guided_filter_advanced='hr_ad',
)
model_path = os.path.join(
    'models',
    args.task,
    '%s_net_latest.pth' % model2name[args.model],
)

if args.model in ['deep_guided_filter', 'deep_guided_filter_advanced']:
# Example #4
# 0
# Command-line interface for the evaluation script.
parser = argparse.ArgumentParser(
    description='Evaluate Deep Guided Filtering Networks')
parser.add_argument('--task', type=str, default='l0_smooth', help='TASK')
parser.add_argument('--name', type=str, default='HR', help='NAME')
parser.add_argument('--model', type=str, default='deep_guided_filter',
                    help='model')
args = parser.parse_args()

# Evaluation config: a private copy of the defaults, then per-run overrides.
config = copy.deepcopy(default_config)

config.TASK, config.NAME, config.SET_NAME = args.task, args.name, 1024

# model: map each --model choice to its network class. 'guided_filter' and
# 'deep_guided_filter' share DeepGuidedFilter. Unknown choices abort the script.
_MODEL_CLASSES = {
    'guided_filter': DeepGuidedFilter,
    'deep_guided_filter': DeepGuidedFilter,
    'deep_guided_filter_advanced': DeepGuidedFilterAdvanced,
    'deep_conv_guided_filter': DeepGuidedFilterConvGF,
    'deep_conv_guided_filter_adv': DeepGuidedFilterGuidedMapConvGF,
}
if args.model not in _MODEL_CLASSES:
    print('Not a valid model!')
    exit(-1)
model = _MODEL_CLASSES[args.model]()

if args.model == 'guided_filter':
    model.init_lr(os.path.join(config.MODEL_PATH, config.TASK, 'LR', 'snapshots', 'net_latest.pth'))
else:
    model.load_state_dict(
        torch.load(