Example #1
0
    def __init__(self,
                 in_planes: int,
                 planes: int,
                 stride: int = 1,
                 num_splits: int = 0):
        """Create the submodules of a basic residual block.

        Two 3x3 convolutions (each with its own norm layer) are registered,
        followed by a shortcut branch. The shortcut is an empty
        ``nn.Sequential`` when the input already has the right stride and
        channel count. Otherwise it is a 1x1 convolution plus a norm layer
        that projects the input to ``self.expansion * planes`` channels.

        Args:
            in_planes: Number of channels of the incoming feature map.
            planes: Number of output channels of both 3x3 convolutions.
            stride: Stride of the first 3x3 convolution and, when present,
                of the 1x1 projection shortcut.
            num_splits: Forwarded to ``get_norm_layer`` for every norm layer.
        """
        super(BasicBlock, self).__init__()

        # First stage; the only 3x3 conv that uses ``stride``.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = get_norm_layer(planes, num_splits)

        # Second stage; always stride 1.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = get_norm_layer(planes, num_splits)

        # Registered last so the submodule order (and the order in which
        # weight initialisation draws from the RNG) is unchanged.
        out_planes = self.expansion * planes
        if stride == 1 and in_planes == out_planes:
            # Shapes already match: identity shortcut.
            self.shortcut = nn.Sequential()
        else:
            # Project the input so it can be combined with the main branch.
            projection = nn.Conv2d(in_planes, out_planes, kernel_size=1,
                                   stride=stride, bias=False)
            self.shortcut = nn.Sequential(
                projection,
                get_norm_layer(out_planes, num_splits),
            )
Example #2
0
def _get_features_and_projections(resnet, num_ftrs, out_dim, num_splits):
    """Split a ResNet into a feature extractor and a new projection head.

    The last child of ``resnet`` (the layer exposing ``in_features``) is
    dropped. The remaining children are wrapped so that:

    - a norm layer over the 3 input channels is prepended,
    - a 1x1 Conv2d from that layer's ``in_features`` to ``num_ftrs`` and an
      adaptive average pool to 1x1 are appended.

    A 2-layer MLP (Linear -> ReLU -> Linear) from ``num_ftrs`` to ``out_dim``
    is created as the projection head.

    Args:
        resnet: Module whose last child exposes ``in_features``.
        num_ftrs: Number of channels produced by the feature extractor.
        out_dim: Output size of the projection head.
        num_splits: Forwarded to ``get_norm_layer`` for the input norm layer.

    Returns:
        Tuple ``(features, projection_head)`` of ``nn.Sequential`` modules.
    """
    children = list(resnet.children())
    # Width of the input to the removed final layer.
    head_in_features = children[-1].in_features

    # Layers are constructed in the same order as before (norm, conv, pool,
    # then the MLP) so random weight initialisation is unchanged.
    input_norm = get_norm_layer(3, num_splits)
    reduce_conv = nn.Conv2d(head_in_features, num_ftrs, 1)
    pool = nn.AdaptiveAvgPool2d(1)
    features = nn.Sequential(input_norm, *children[:-1], reduce_conv, pool)

    hidden = nn.Linear(num_ftrs, num_ftrs)
    activation = nn.ReLU()
    output = nn.Linear(num_ftrs, out_dim)
    projection_head = nn.Sequential(hidden, activation, output)

    return features, projection_head
Example #3
0
    def __init__(self,
                 block: nn.Module = BasicBlock,
                 layers: List[int] = None,
                 num_classes: int = 10,
                 width: float = 1.,
                 num_splits: int = 0):
        """Build a four-stage residual network.

        Layout: a 3x3, stride-1 stem convolution with a norm layer, then four
        stages created by ``self._make_layer`` with ``base``, ``2 * base``,
        ``4 * base`` and ``8 * base`` planes (stride 1, 2, 2, 2), and a
        final linear layer. Here ``base = int(64 * width)``.

        Args:
            block: Residual block class handed to ``_make_layer``. Its
                ``expansion`` attribute scales the input size of the final
                linear layer.
            layers: Number of blocks in each of the four stages. ``None``
                (the default) means ``[2, 2, 2, 2]``.
            num_classes: Output size of the final linear layer.
            width: Multiplier for the base channel count of 64; the product
                is truncated with ``int``.
            num_splits: Forwarded to ``get_norm_layer`` and ``_make_layer``.
        """
        super(ResNet, self).__init__()

        if layers is None:
            # Built per call: a list literal as the default would be one
            # object shared by every ResNet created without ``layers``.
            layers = [2, 2, 2, 2]

        base = int(64 * width)
        # Presumably the running input width consumed/advanced by
        # _make_layer -- TODO confirm against _make_layer.
        self.in_planes = base
        self.base = base

        # Submodules are created in this exact order on purpose: it fixes the
        # order of children() (callers take children()[-1].in_features, so
        # the linear layer must stay last) and the order in which weight
        # initialisation draws from the RNG.
        self.conv1 = nn.Conv2d(3,
                               self.base,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn1 = get_norm_layer(self.base, num_splits)
        self.layer1 = self._make_layer(block,
                                       self.base,
                                       layers[0],
                                       stride=1,
                                       num_splits=num_splits)
        self.layer2 = self._make_layer(block,
                                       self.base * 2,
                                       layers[1],
                                       stride=2,
                                       num_splits=num_splits)
        self.layer3 = self._make_layer(block,
                                       self.base * 4,
                                       layers[2],
                                       stride=2,
                                       num_splits=num_splits)
        self.layer4 = self._make_layer(block,
                                       self.base * 8,
                                       layers[3],
                                       stride=2,
                                       num_splits=num_splits)
        self.linear = nn.Linear(self.base * 8 * block.expansion, num_classes)
Example #4
0
def get_model_from_config(cfg,
                          is_cli_call: bool = False
                          ) -> SelfSupervisedEmbedding:
    """Build a ``_SimCLR`` embedding model from ``cfg`` and load its weights.

    Weights come from ``cfg['checkpoint']`` when it is set. That value may be
    a local path or an http(s) URL. Otherwise the checkpoint returned by
    ``get_ptmodel_from_config(cfg['model'])`` is downloaded.

    Args:
        cfg: Config mapping. Reads ``checkpoint`` and ``model`` (``name``,
            ``width``, ``num_ftrs``, ``out_dim``).
        is_cli_call: If True, a local checkpoint path is passed through
            ``fix_input_path`` before loading.

    Returns:
        A ``SelfSupervisedEmbedding`` wrapping the model, with no criterion,
        optimizer or dataloader attached.

    Raises:
        RuntimeError: If no checkpoint is configured and
            ``get_ptmodel_from_config`` does not return one.
    """
    checkpoint = cfg['checkpoint']
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    if not checkpoint:
        # No explicit checkpoint: look one up for the configured model.
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            msg = 'Cannot download checkpoint for key {} '.format(key)
            msg += 'because it does not exist!'
            raise RuntimeError(msg)
        state_dict = load_state_dict_from_url(
            checkpoint, map_location=device)['state_dict']
    elif (isinstance(checkpoint, str)
          and checkpoint.lower().startswith(('http://', 'https://'))):
        # An explicit remote checkpoint must be downloaded (mirrors the URL
        # handling in _train_cli). torch.load treats its argument as a local
        # file, and fix_input_path presumably targets local paths only.
        state_dict = load_state_dict_from_url(
            checkpoint, map_location=device)['state_dict']
    else:
        checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint
        state_dict = torch.load(checkpoint, map_location=device)['state_dict']

    # Backbone: input norm layer + all ResNet children except the last one,
    # then a 1x1 conv to num_ftrs channels and a global average pool.
    resnet = ResNetGenerator(cfg['model']['name'], cfg['model']['width'])
    last_conv_channels = list(resnet.children())[-1].in_features
    features = nn.Sequential(
        get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, cfg['model']['num_ftrs'], 1),
        nn.AdaptiveAvgPool2d(1),
    )

    model = _SimCLR(features,
                    num_ftrs=cfg['model']['num_ftrs'],
                    out_dim=cfg['model']['out_dim']).to(device)

    # The stored 'state_dict' entry could itself be None.
    if state_dict is not None:
        load_from_state_dict(model, state_dict)

    encoder = SelfSupervisedEmbedding(model, None, None, None)
    return encoder
Example #5
0
def _train_cli(cfg, is_cli_call: bool = True):
    """Train a ``_SimCLR`` embedding model as configured and return its path.

    Steps: optional seeding, device/GPU selection, loading an optional
    starting checkpoint (explicit path/URL, or the model zoo when
    ``cfg['pre_trained']`` is set), building the model, loss, optimizer and
    dataloader, and running ``SelfSupervisedEmbedding.train_embedding``.

    Side effects: mutates ``cfg`` (``trainer.weights_summary``,
    ``trainer.gpus``, ``loader.num_workers``, ``loader.batch_size``), may
    seed torch and set cuDNN flags, emits warnings, prints the checkpoint
    path and sets the environment variable named by
    ``cfg['environment_variable_names']['lightly_last_checkpoint_path']``.

    Args:
        cfg: Config mapping (reads ``input_dir``, optional ``seed``,
            ``trainer``, ``loader``, ``checkpoint``, ``pre_trained``,
            ``model``, ``criterion``, ``optimizer``, ``collate``,
            ``checkpoint_callback`` and ``environment_variable_names``).
        is_cli_call: If True, ``input_dir`` and an explicit ``checkpoint``
            are passed through ``fix_input_path``.

    Returns:
        ``encoder.checkpoint`` after training, i.e. the path reported as the
        best model.
    """

    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    # Optional reproducibility: seed CPU/CUDA RNGs and force deterministic
    # cuDNN kernels. This happens before the model is built so weight
    # initialisation is seeded too.
    if 'seed' in cfg.keys():
        seed = cfg['seed']
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # The config may carry the literal string "None" (presumably from
    # command-line/text config); turn it into a real None.
    if cfg["trainer"]["weights_summary"] == "None":
        cfg["trainer"]["weights_summary"] = None

    # ``device`` is only used below as map_location for checkpoint loading.
    # If GPUs were requested but CUDA is unavailable, train on CPU instead.
    if torch.cuda.is_available():
        device = 'cuda'
    elif cfg['trainer'] and cfg['trainer']['gpus']:
        device = 'cpu'
        cfg['trainer']['gpus'] = 0
    else:
        device = 'cpu'

    # More than one GPU -> pass strategy='ddp' to train_embedding.
    distributed_strategy = None
    if cfg['trainer']['gpus'] > 1:
        distributed_strategy = 'ddp'

    if cfg['loader']['batch_size'] < 64:
        msg = 'Training a self-supervised model with a small batch size: {}! '
        msg = msg.format(cfg['loader']['batch_size'])
        msg += 'Small batch size may harm embedding quality. '
        msg += 'You can specify the batch size via the loader key-word: '
        msg += 'loader.batch_size=BSZ'
        warnings.warn(msg)

    # determine the number of available cores
    # (a negative num_workers means "use cpu_count()")
    if cfg['loader']['num_workers'] < 0:
        cfg['loader']['num_workers'] = cpu_count()

    state_dict = None
    checkpoint = cfg['checkpoint']
    if cfg['pre_trained'] and not checkpoint:
        # if checkpoint wasn't specified explicitly and pre_trained is True
        # try to load the checkpoint from the model zoo
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            # Not fatal here: training continues from random weights.
            msg = 'Cannot download checkpoint for key {} '.format(key)
            msg += 'because it does not exist! '
            msg += 'Model will be trained from scratch.'
            warnings.warn(msg)
    elif checkpoint:
        checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint

    if checkpoint:
        # load the PyTorch state dictionary and map it to the current device
        if is_url(checkpoint):
            state_dict = load_state_dict_from_url(
                checkpoint, map_location=device)['state_dict']
        else:
            state_dict = torch.load(checkpoint,
                                    map_location=device)['state_dict']

    # load model
    # Backbone: input norm layer + all ResNet children except the last one,
    # then a 1x1 conv to num_ftrs channels and a global average pool.
    resnet = ResNetGenerator(cfg['model']['name'], cfg['model']['width'])
    last_conv_channels = list(resnet.children())[-1].in_features
    features = nn.Sequential(
        get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, cfg['model']['num_ftrs'], 1),
        nn.AdaptiveAvgPool2d(1),
    )

    model = _SimCLR(features,
                    num_ftrs=cfg['model']['num_ftrs'],
                    out_dim=cfg['model']['out_dim'])
    # state_dict stays None when no checkpoint was found/given.
    if state_dict is not None:
        load_from_state_dict(model, state_dict)

    criterion = NTXentLoss(**cfg['criterion'])
    optimizer = torch.optim.SGD(model.parameters(), **cfg['optimizer'])

    dataset = LightlyDataset(input_dir)

    # Never request a batch larger than the dataset itself.
    cfg['loader']['batch_size'] = min(cfg['loader']['batch_size'],
                                      len(dataset))

    collate_fn = ImageCollateFunction(**cfg['collate'])
    dataloader = torch.utils.data.DataLoader(dataset,
                                             **cfg['loader'],
                                             collate_fn=collate_fn)

    encoder = SelfSupervisedEmbedding(model, criterion, optimizer, dataloader)
    encoder.init_checkpoint_callback(**cfg['checkpoint_callback'])
    encoder.train_embedding(**cfg['trainer'], strategy=distributed_strategy)

    print(
        f'Best model is stored at: {bcolors.OKBLUE}{encoder.checkpoint}{bcolors.ENDC}'
    )
    # Publish the checkpoint path in the configured environment variable.
    os.environ[cfg['environment_variable_names']
               ['lightly_last_checkpoint_path']] = encoder.checkpoint
    return encoder.checkpoint
Example #6
0
def _embed_cli(cfg, is_cli_call=True):
    """Embed the images in ``cfg['input_dir']`` with a pretrained model.

    Images are resized to a square of ``cfg['collate']['input_size']``,
    converted to tensors and normalized with fixed per-channel mean/std. They
    are then passed through a ``SimCLR`` model. Its weights come from
    ``cfg['checkpoint']`` (a local path or an http(s) URL) or, if that is
    empty, from the checkpoint returned by
    ``get_ptmodel_from_config(cfg['model'])``.

    Side effects: sets ``torch.backends.cudnn.deterministic`` and
    ``benchmark`` globally, and overwrites ``drop_last``, ``shuffle`` and
    ``batch_size`` in ``cfg['loader']``.

    Args:
        cfg: Config mapping (reads ``checkpoint``, ``input_dir``,
            ``collate``, ``loader`` and ``model``).
        is_cli_call: If True, local paths go through ``fix_input_path`` and
            the embeddings are written to ``embeddings.csv`` in the current
            working directory.

    Returns:
        The path of the written CSV when ``is_cli_call`` is True. Otherwise,
        the ``(embeddings, labels, filenames)`` tuple returned by
        ``SelfSupervisedEmbedding.embed``.

    Raises:
        RuntimeError: If no checkpoint is configured and
            ``get_ptmodel_from_config`` does not return one.
    """

    checkpoint = cfg['checkpoint']

    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    # Deterministic cuDNN kernels so repeated runs are reproducible.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((cfg['collate']['input_size'],
                                       cfg['collate']['input_size'])),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])
    ])

    dataset = LightlyDataset(input_dir, transform=transform)

    # Every image must be embedded exactly once, in a stable order.
    cfg['loader']['drop_last'] = False
    cfg['loader']['shuffle'] = False
    cfg['loader']['batch_size'] = min(
        cfg['loader']['batch_size'],
        len(dataset)
    )
    dataloader = torch.utils.data.DataLoader(dataset, **cfg['loader'])

    # load the PyTorch state dictionary and map it to the current device
    state_dict = None
    if not checkpoint:
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            msg = 'Cannot download checkpoint for key {} '.format(key)
            msg += 'because it does not exist!'
            raise RuntimeError(msg)
        state_dict = load_state_dict_from_url(
            checkpoint, map_location=device
        )['state_dict']
    elif (isinstance(checkpoint, str)
          and checkpoint.lower().startswith(('http://', 'https://'))):
        # An explicit remote checkpoint must be downloaded (mirrors the URL
        # handling in _train_cli). torch.load treats its argument as a local
        # file, and fix_input_path presumably targets local paths only.
        state_dict = load_state_dict_from_url(
            checkpoint, map_location=device
        )['state_dict']
    else:
        checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint
        state_dict = torch.load(
            checkpoint, map_location=device
        )['state_dict']

    # load model
    # Backbone: input norm layer + all ResNet children except the last one,
    # then a 1x1 conv to num_ftrs channels and a global average pool.
    resnet = ResNetGenerator(cfg['model']['name'], cfg['model']['width'])
    last_conv_channels = list(resnet.children())[-1].in_features
    features = nn.Sequential(
        get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, cfg['model']['num_ftrs'], 1),
        nn.AdaptiveAvgPool2d(1),
    )

    # NOTE(review): _train_cli builds `_SimCLR` for the same config; confirm
    # that `SimCLR` here accepts the state-dict layout those checkpoints use.
    model = SimCLR(
        features,
        num_ftrs=cfg['model']['num_ftrs'],
        out_dim=cfg['model']['out_dim']
    ).to(device)

    # The stored 'state_dict' entry could itself be None.
    if state_dict is not None:
        load_from_state_dict(model, state_dict)

    encoder = SelfSupervisedEmbedding(model, None, None, None)
    embeddings, labels, filenames = encoder.embed(dataloader, device=device)

    if is_cli_call:
        path = os.path.join(os.getcwd(), 'embeddings.csv')
        save_embeddings(path, embeddings, labels, filenames)
        print(f'Embeddings are stored at {bcolors.OKBLUE}{path}{bcolors.ENDC}')
        return path

    return embeddings, labels, filenames