Example #1

from torch.utils.data import Dataset     # assumed base class for this dataset
from torchvision import transforms


class FlickrDataset2(Dataset):
    def __init__(self, path: str):
        super(FlickrDataset2, self).__init__()
        self._path = path
        self._class_mapping = {}            # class number -> class label
        self._images = {}
        self.read()
        self._class_mapping_inverted = {v: k for k, v in self._class_mapping.items()}   # class label -> class number

        self._class_groups = self._build_groups()
        self._images_path = list(self._images.keys())
        self.process_image_pipeline = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(256),
            transforms.CenterCrop(224),
            RotationTransform(90),
            ToRGB(),
            # transforms.Normalize(self.mean, self.std),
        ])
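
# The pipeline above uses two custom transforms that are not defined in this
# excerpt. Minimal sketches of plausible implementations follow; the exact
# behaviour of the project's own RotationTransform / ToRGB may differ.
import torchvision.transforms.functional as TF
from PIL import Image


class RotationTransform:
    """Rotate the input by a fixed angle (illustrative sketch)."""
    def __init__(self, angle: float):
        self.angle = angle

    def __call__(self, img):
        return TF.rotate(img, self.angle)


class ToRGB:
    """Force a 3-channel RGB input (illustrative sketch)."""
    def __call__(self, img):
        if isinstance(img, Image.Image):
            return img.convert('RGB')
        if img.dim() == 3 and img.size(0) == 1:   # single-channel tensor
            return img.repeat(3, 1, 1)
        return img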
Example #2

# output directories for reconstruction and style-transfer results
os.makedirs(os.path.join(args.name + '_results', 'Reconstruction'), exist_ok=True)
os.makedirs(os.path.join(args.name + '_results', 'Transfer'), exist_ok=True)

# edge-promoting
if not os.path.isdir(os.path.join('data', args.tgt_data, 'pair')):
    print('edge-promoting start!!')
    edge_promoting(os.path.join('data', args.tgt_data, 'train'),
                   os.path.join('data', args.tgt_data, 'pair'))
else:
    print('edge-promoting already done')
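
# edge_promoting is assumed to be the CartoonGAN-style edge-smoothing step:
# for each cartoon image, a copy with blurred edge regions is saved next to
# the original so the discriminator can learn to reject smooth-edged fakes.
# The function below is an illustrative sketch, not the project's implementation.
import cv2
import numpy as np

def edge_promoting_sketch(src_dir, dst_dir, size=256):
    os.makedirs(dst_dir, exist_ok=True)
    kernel = np.ones((5, 5), np.uint8)
    for n, name in enumerate(sorted(os.listdir(src_dir))):
        img = cv2.imread(os.path.join(src_dir, name))
        if img is None:
            continue
        img = cv2.resize(img, (size, size))
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        edges = cv2.dilate(cv2.Canny(gray, 100, 200), kernel)   # widened edge mask
        blurred = cv2.GaussianBlur(img, (5, 5), 0)
        smoothed = img.copy()
        smoothed[edges != 0] = blurred[edges != 0]               # blur only edge pixels
        pair = np.concatenate((img, smoothed), axis=1)           # [original | smoothed]
        cv2.imwrite(os.path.join(dst_dir, str(n) + '.png'), pair)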

# data_loader
src_transform = transforms.Compose([
    ToRGB(),
    transforms.Resize((args.input_size, args.input_size)),
    transforms.ToTensor(),
    RGBToBGR(),
    Zero(),
])

tgt_transform = transforms.Compose([
    ToRGB(),
    transforms.Resize(args.input_size),
    transforms.ToTensor(),
    RGBToBGR(),
    Zero(),
])
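
# RGBToBGR and Zero are custom transforms not shown in this excerpt. A plausible
# reading, sketched below: the pretrained generator expects BGR channel order,
# and Zero maps ToTensor's [0, 1] range to a zero-centred [-1, 1] range. Treat
# both as assumptions about the project's real transforms.
import torch

class RGBToBGR:
    """Reverse the channel order of a CHW tensor (illustrative sketch)."""
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return x[[2, 1, 0], :, :]

class Zero:
    """Rescale a [0, 1] tensor to [-1, 1] (assumed behaviour)."""
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0 - 1.0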

src_transform_test = transforms.Compose([
Example #3

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   # used by G.to(device) below
torch.backends.cudnn.benchmark = True

G = networks.Transformer()
if torch.cuda.is_available():
    G.load_state_dict(torch.load(args.pre_trained_model))
else:
    # cpu mode
    G.load_state_dict(
        torch.load(args.pre_trained_model,
                   map_location=lambda storage, loc: storage))

G.to(device)
G.eval()

src_transform = transforms.Compose([
    ToRGB(),
    RatioedResize(args.input_size),
    transforms.ToTensor(),
    RGBToBGR(),
    Zero(),
])
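
# RatioedResize is not defined in this excerpt; the sketch below assumes it
# resizes so the shorter side equals `size` while preserving aspect ratio.
from PIL import Image

class RatioedResize:
    """Aspect-ratio-preserving resize (illustrative sketch)."""
    def __init__(self, size: int):
        self.size = size

    def __call__(self, img: Image.Image) -> Image.Image:
        w, h = img.size
        scale = self.size / min(w, h)
        return img.resize((int(w * scale), int(h * scale)), Image.BICUBIC)
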
# utils.data_load(os.path.join('data', args.src_data), 'test', src_transform, 1, shuffle=True, drop_last=True)
image_src = utils.data_load(os.path.join(args.image_dir),
                            'test',
                            src_transform,
                            1,
                            shuffle=True,
                            drop_last=True)

with torch.no_grad():
    G.eval()
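
    # Hypothetical continuation of the excerpt: run the generator over the
    # loader and write the stylised outputs. The output path and the
    # BGR / [-1, 1] un-preprocessing mirror the transforms sketched above
    # and are assumptions, not the project's confirmed behaviour.
    from torchvision.utils import save_image
    os.makedirs('results', exist_ok=True)                      # illustrative output dir
    for n, (x, _) in enumerate(image_src):
        y = G(x.to(device))
        y = (y[:, [2, 1, 0], :, :] + 1.0) / 2.0                # BGR [-1, 1] -> RGB [0, 1]
        save_image(y, os.path.join('results', '%d_output.png' % n))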