コード例 #1
0
 def __init__(self, config: Config):
     """Resolve the effective experiment config, then initialise the parent.

     Precedence, lowest to highest: the base ``config``, the sub-section
     ``config[config.dataset]``, then options read from the command line.
     Also derives ``exp_name`` from the dataset, the hyper-parameter hash and
     the random seed, and seeds a private ``np.random.RandomState``.
     """
     cfg = config.overwrite(config[config.dataset]).overwrite(Config.read_from_cli())
     cfg.exp_name = f'zsl_{cfg.dataset}_{cfg.hp.compute_hash()}_{cfg.random_seed}'

     # Echo the hyper-parameters unless explicitly silenced.
     if not cfg.get('silent'):
         print(cfg.hp)

     self.random = np.random.RandomState(cfg.random_seed)
     super().__init__(cfg)
コード例 #2
0
def get_transform(config: Config) -> Callable:
    """Return the preprocessing pipeline for the dataset named ``config.name``.

    Every pipeline ends with a ``Normalize`` using mean 0.5 / std 0.5 per
    channel (one channel for 'mnist', three for the other datasets).

    Args:
        config: dataset config. Always reads ``name``; depending on the dataset
            also reads ``target_img_size``, ``concat_patches.enabled`` and
            ``concat_patches.ratio``.

    Returns:
        A ``transforms.Compose`` pipeline.

    Raises:
        NotImplementedError: if ``config.name`` matches none of the known datasets.
    """
    name = config.name

    if name == 'mnist':
        return transforms.Compose([
            transforms.Resize(config.target_img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    if name in {'cifar10', 'single_image'}:
        side = config.target_img_size
        return transforms.Compose([
            transforms.Resize((side, side)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

    crop_flip_datasets = {'ffhq_thumbs', 'celeba_thumbs', 'ffhq_256', 'ffhq_1024'}
    if name.startswith('lsun_') or name in crop_flip_datasets:
        patches_enabled = config.get('concat_patches.enabled')
        # Both variants start with the same crop + random flip.
        steps = [CenterCropToMin(), transforms.RandomHorizontalFlip()]
        if patches_enabled:
            # No ToTensor in this variant; presumably PatchConcatAndResize
            # already yields a tensor -- verify against its implementation.
            steps.append(PatchConcatAndResize(config.target_img_size,
                                              config.concat_patches.ratio))
        else:
            steps.extend([transforms.Resize(config.target_img_size),
                          transforms.ToTensor()])
        steps.append(transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))
        return transforms.Compose(steps)

    if name == 'imagenet_vs':
        return transforms.Compose([
            PadToSquare(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ])

    raise NotImplementedError(f'Unknown dataset: {config.name}')
コード例 #3
0
    def __init__(self, config: Config):
        """Build conv blocks -> adaptive avg-pool -> Flatten -> dense head -> Linear.

        Recognised ``config`` keys (defaults in parentheses):
            conv_sizes ([1, 8, 32, 64]): channel sizes; one ConvBlock per
                consecutive pair.
            dense_sizes (derived): widths of the dense head; one dense block per
                consecutive pair. When absent, defaults to
                ``[conv_sizes[-1] * pool_h * pool_w, 128]`` so the head matches
                the flattened pooled features, falling back to the legacy
                ``[576, 128]`` when that product cannot be computed here.
            use_bn, use_dropout, use_maxpool, use_skip_connection (False):
                forwarded to ConvBlock / ``_create_dense_block``.
            activation ('relu'): one of 'relu', 'selu', 'tanh', 'sigmoid'.
            adaptive_pool_size ((4, 4)): output size of ``nn.AdaptiveAvgPool2d``
                (an int or a pair).
            num_classes (10): width of the final ``nn.Linear``.

        Raises:
            NotImplementedError: if ``activation`` is not a supported name.
        """
        super().__init__()
        # Local import: functools.partial gives picklable activation factories
        # (a lambda stored on the module breaks pickle / torch.save(model)).
        import functools

        conv_sizes = config.get('conv_sizes', [1, 8, 32, 64])
        dense_sizes = config.get('dense_sizes', None)
        use_bn = config.get('use_bn', False)
        use_dropout = config.get('use_dropout', False)
        use_maxpool = config.get('use_maxpool', False)
        use_skip_connection = config.get('use_skip_connection', False)
        activation = config.get('activation', 'relu')
        adaptive_pool_size = config.get('adaptive_pool_size', (4, 4))
        num_classes = config.get('num_classes', 10)

        if dense_sizes is None:
            # The former hard-coded default [576, 128] equals 64 * 3 * 3 and so
            # disagrees with the default (4, 4) pool (64 * 4 * 4 = 1024).
            # Assumes each ConvBlock outputs conv_sizes[i + 1] channels --
            # TODO confirm against ConvBlock.
            if isinstance(adaptive_pool_size, int):
                pool_h = pool_w = adaptive_pool_size
            else:
                pool_h, pool_w = adaptive_pool_size
            if conv_sizes and pool_h is not None and pool_w is not None:
                dense_sizes = [conv_sizes[-1] * pool_h * pool_w, 128]
            else:
                # A None pool dimension keeps the input size, so the flattened
                # size is unknown here; keep the legacy default.
                dense_sizes = [576, 128]

        # Zero-argument factories producing a fresh activation module per call.
        activation_factories = {
            'relu': functools.partial(nn.ReLU, inplace=True),
            'selu': functools.partial(nn.SELU, inplace=True),
            'tanh': nn.Tanh,
            'sigmoid': nn.Sigmoid,
        }
        if activation not in activation_factories:
            raise NotImplementedError(
                f'Unknown activation function: {activation}')
        self.activation = activation_factories[activation]

        conv_body = nn.Sequential(*[
            ConvBlock(c_in, c_out, use_bn, use_skip_connection, use_maxpool,
                      self.activation)
            for c_in, c_out in zip(conv_sizes[:-1], conv_sizes[1:])
        ])
        dense_head = nn.Sequential(*[
            self._create_dense_block(d_in, d_out, use_dropout)
            for d_in, d_out in zip(dense_sizes[:-1], dense_sizes[1:])
        ])

        self.nn = nn.Sequential(conv_body,
                                nn.AdaptiveAvgPool2d(adaptive_pool_size),
                                Flatten(), dense_head,
                                nn.Linear(dense_sizes[-1], num_classes))
コード例 #4
0
def _load_train_test(loader, data_dir, *loader_args, test_split='test',
                     **loader_kwargs):
    """Call ``loader`` for the 'train' split, then for ``test_split``, with identical extra arguments."""
    ds_train = loader(data_dir, *loader_args, split='train', **loader_kwargs)
    ds_test = loader(data_dir, *loader_args, split=test_split, **loader_kwargs)
    return ds_train, ds_test


def load_data(
    config: Config,
    img_target_shape: Tuple[int, int] = None
) -> Tuple[ImageDataset, ImageDataset, np.ndarray]:
    """Load the train/test datasets (and class attributes) named by ``config.name``.

    Args:
        config: dataset config; reads ``name`` and ``dir``, plus ``input_type``
            for embedding datasets and ``in_memory`` (default False) for 'CUB'.
        img_target_shape: optional target image shape, forwarded to the
            CUB / AWA / SUN / TinyImageNet loaders; ``None`` is passed through.

    Returns:
        ``(ds_train, ds_test, class_attributes)``. ``class_attributes`` is a
        float32 array for 'CUB', 'CUB_EMBEDDINGS', 'AWA' and 'SUN', and
        ``None`` for every other dataset (despite the ``np.ndarray`` annotation).
        For 'SUN' and 'TinyImageNet' the returned test set is the 'val' split.

    Raises:
        NotImplementedError: if ``config.name`` is not a known dataset.
    """
    name = config.name

    if name == 'CUB':
        ds_train, ds_test = _load_train_test(
            cub.load_dataset, config.dir,
            target_shape=img_target_shape,
            in_memory=config.get('in_memory', False))
        class_attributes = cub.load_class_attributes(config.dir).astype(np.float32)
    elif name == 'CUB_EMBEDDINGS':
        # Precomputed features, but CUB's class attributes.
        ds_train, ds_test = _load_train_test(feats.load_dataset, config.dir,
                                             config.input_type)
        class_attributes = cub.load_class_attributes(config.dir).astype(np.float32)
    elif name == 'AWA':
        ds_train, ds_test = _load_train_test(awa.load_dataset, config.dir,
                                             target_shape=img_target_shape)
        class_attributes = awa.load_class_attributes(config.dir).astype(np.float32)
    elif name == 'SUN':
        ds_train, ds_test = _load_train_test(sun.load_dataset, config.dir,
                                             test_split='val',
                                             target_shape=img_target_shape)
        class_attributes = sun.load_class_attributes(config.dir).astype(np.float32)
    elif name == 'TinyImageNet':
        ds_train, ds_test = _load_train_test(tiny_imagenet.load_dataset, config.dir,
                                             test_split='val',
                                             target_shape=img_target_shape)
        class_attributes = None
    elif name in SIMPLE_LOADERS:
        ds_train, ds_test = _load_train_test(SIMPLE_LOADERS[name], config.dir)
        class_attributes = None
    elif name.endswith('EMBEDDINGS'):
        # Checked after the explicit names above so those keep priority.
        ds_train, ds_test = _load_train_test(feats.load_dataset, config.dir,
                                             config.input_type)
        class_attributes = None
    else:
        raise NotImplementedError(f'Unknown dataset: {config.name}')

    return ds_train, ds_test, class_attributes