Exemplo n.º 1
0
    def test_tuple_input(self):
        """Verify SimCLR output shapes for single and tuple inputs,
        with and without returned backbone features."""
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        resnet = ResNetGenerator('resnet-18')
        model = SimCLR(
            get_backbone(resnet, num_ftrs=32), out_dim=128
        ).to(device)

        batch0 = torch.rand((self.batch_size, 3, 64, 64)).to(device)
        batch1 = torch.rand((self.batch_size, 3, 64, 64)).to(device)

        # single input -> single projection
        projection = model(batch0)
        self.assertEqual(projection.shape, (self.batch_size, 128))

        # single input, also return the backbone features
        projection, feats = model(batch0, return_features=True)
        self.assertEqual(projection.shape, (self.batch_size, 128))
        self.assertEqual(feats.shape, (self.batch_size, 32))

        # tuple input -> one projection per view
        proj0, proj1 = model(batch0, batch1)
        for proj in (proj0, proj1):
            self.assertEqual(proj.shape, (self.batch_size, 128))

        # tuple input with features for each view
        (proj0, feat0), (proj1, feat1) = model(
            batch0, batch1, return_features=True)
        for proj, feat in ((proj0, feat0), (proj1, feat1)):
            self.assertEqual(proj.shape, (self.batch_size, 128))
            self.assertEqual(feat.shape, (self.batch_size, 32))
Exemplo n.º 2
0
 def test_create_variations_gpu(self):
     """Instantiate every ResNet variant as a SimCLR model on the GPU.

     Previously the test silently passed (``else: pass``) on machines
     without CUDA, hiding the fact that nothing was exercised; it now
     reports an explicit skip instead.
     """
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
     if device != 'cuda':
         self.skipTest('CUDA not available')
     for model_name in self.resnet_variants:
         resnet = ResNetGenerator(model_name)
         model = SimCLR(get_backbone(resnet)).to(device)
         self.assertIsNotNone(model)
Exemplo n.º 3
0
    def test_feature_dim_configurable(self):
        """num_ftrs / out_dim must propagate to the backbone output and
        the projection head output respectively."""
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        dim_pairs = ((16, 64), (64, 256))  # (num_ftrs, out_dim)
        for model_name in self.resnet_variants:
            for num_ftrs, out_dim in dim_pairs:
                resnet = ResNetGenerator(model_name)
                model = SimCLR(
                    get_backbone(resnet, num_ftrs=num_ftrs),
                    num_ftrs=num_ftrs,
                    out_dim=out_dim,
                ).to(device)
                self.assertIsNotNone(model)

                # backbone feature vector has the configured dimension
                with torch.no_grad():
                    feats = model.backbone(self.input_tensor.to(device))
                self.assertEqual(feats.shape[1], num_ftrs)

                # projection head output has the configured dimension
                with torch.no_grad():
                    projection = model.projection_head(feats.squeeze())
                self.assertEqual(projection.shape[1], out_dim)
Exemplo n.º 4
0
    def test_variations_input_dimension(self):
        """The forward pass must work for multiple input resolutions."""
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        resolutions = ((32, 64), (64, 64))  # (width, height)
        for model_name in self.resnet_variants:
            for input_width, input_height in resolutions:
                resnet = ResNetGenerator(model_name)
                model = SimCLR(get_backbone(resnet, num_ftrs=32)).to(device)

                batch = torch.rand(
                    (self.batch_size, 3, input_height, input_width))
                with torch.no_grad():
                    out = model(batch.to(device))

                self.assertIsNotNone(model)
                self.assertIsNotNone(out)
Exemplo n.º 5
0
def _train_cli(cfg, is_cli_call=True):
    """Train a self-supervised SimCLR embedding model from a config dict.

    Args:
        cfg: Configuration dictionary with (at least) the keys input_dir,
            seed (optional), model, trainer, loader, collate, criterion,
            optimizer, checkpoint, checkpoint_callback and pre_trained.
        is_cli_call: If True, input and checkpoint paths are normalized
            with fix_input_path (they come from the command line).

    Returns:
        Path to the best checkpoint written during training.
    """
    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    # make runs reproducible if a seed was given
    if 'seed' in cfg.keys():
        seed = cfg['seed']
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # BUGFIX: `device` was previously left undefined when CUDA was not
    # available and cfg['trainer']['gpus'] was falsy, causing a NameError
    # at `map_location=device` when a checkpoint was loaded below.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
        if cfg['trainer'] and cfg['trainer']['gpus']:
            # gpus were requested but none are available
            cfg['trainer']['gpus'] = 0

    if cfg['loader']['batch_size'] < 64:
        msg = 'Training a self-supervised model with a small batch size: {}! '
        msg = msg.format(cfg['loader']['batch_size'])
        msg += 'Small batch size may harm embedding quality. '
        msg += 'You can specify the batch size via the loader key-word: '
        msg += 'loader.batch_size=BSZ'
        warnings.warn(msg)

    state_dict = None
    checkpoint = cfg['checkpoint']
    if cfg['pre_trained'] and not checkpoint:
        # if checkpoint wasn't specified explicitly and pre_trained is True
        # try to load the checkpoint from the model zoo
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            msg = 'Cannot download checkpoint for key {} '.format(key)
            msg += 'because it does not exist! '
            msg += 'Model will be trained from scratch.'
            warnings.warn(msg)
    elif checkpoint:
        checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint

    if checkpoint:
        # load the PyTorch state dictionary and map it to the current device
        if is_url(checkpoint):
            state_dict = load_state_dict_from_url(
                checkpoint, map_location=device)['state_dict']
        else:
            state_dict = torch.load(checkpoint,
                                    map_location=device)['state_dict']

    # backbone: norm layer + resnet trunk + 1x1 conv + global pooling
    resnet = ResNetGenerator(cfg['model']['name'], cfg['model']['width'])
    last_conv_channels = list(resnet.children())[-1].in_features
    features = nn.Sequential(
        get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, cfg['model']['num_ftrs'], 1),
        nn.AdaptiveAvgPool2d(1),
    )

    model = SimCLR(features,
                   num_ftrs=cfg['model']['num_ftrs'],
                   out_dim=cfg['model']['out_dim'])
    if state_dict is not None:
        load_from_state_dict(model, state_dict)

    criterion = NTXentLoss(**cfg['criterion'])
    optimizer = torch.optim.SGD(model.parameters(), **cfg['optimizer'])

    dataset = LightlyDataset(input_dir)

    # never request a batch larger than the dataset itself
    cfg['loader']['batch_size'] = min(cfg['loader']['batch_size'],
                                      len(dataset))

    collate_fn = ImageCollateFunction(**cfg['collate'])
    dataloader = torch.utils.data.DataLoader(dataset,
                                             **cfg['loader'],
                                             collate_fn=collate_fn)

    encoder = SelfSupervisedEmbedding(model, criterion, optimizer, dataloader)
    encoder.init_checkpoint_callback(**cfg['checkpoint_callback'])
    encoder.train_embedding(**cfg['trainer'])

    print('Best model is stored at: %s' % (encoder.checkpoint))
    return encoder.checkpoint
Exemplo n.º 6
0
 def test_create_variations_cpu(self):
     """Every ResNet variant should build a SimCLR model on the CPU."""
     for variant in self.resnet_variants:
         backbone = get_backbone(ResNetGenerator(variant))
         self.assertIsNotNone(SimCLR(backbone))
Exemplo n.º 7
0
def _embed_cli(cfg, is_cli_call=True):
    """Embed a dataset with a (pre-)trained SimCLR model.

    Loads a checkpoint (from disk, a URL, or the model zoo), builds the
    model described by ``cfg['model']``, and embeds every image found in
    ``cfg['input_dir']``. When called from the CLI the embeddings are
    written to ``embeddings.csv`` and the path is returned; otherwise the
    raw ``(embeddings, labels, filenames)`` triple is returned.
    """
    checkpoint = cfg['checkpoint']

    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    # deterministic cudnn behaviour for reproducible embeddings
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    size = (cfg['collate']['input_size'], cfg['collate']['input_size'])
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]),
    ])

    dataset = LightlyDataset(input_dir, transform=transform)

    # embed every sample exactly once, in dataset order
    cfg['loader']['drop_last'] = False
    cfg['loader']['shuffle'] = False
    cfg['loader']['batch_size'] = min(cfg['loader']['batch_size'],
                                      len(dataset))
    dataloader = torch.utils.data.DataLoader(dataset, **cfg['loader'])

    # load the PyTorch state dictionary and map it to the current device
    if checkpoint:
        checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint
        state_dict = torch.load(
            checkpoint, map_location=device
        )['state_dict']
    else:
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            msg = 'Cannot download checkpoint for key {} '.format(key)
            msg += 'because it does not exist!'
            raise RuntimeError(msg)
        state_dict = load_state_dict_from_url(
            checkpoint, map_location=device
        )['state_dict']

    # backbone: norm layer + resnet trunk + 1x1 conv + global pooling
    resnet = ResNetGenerator(cfg['model']['name'], cfg['model']['width'])
    last_conv_channels = list(resnet.children())[-1].in_features
    features = nn.Sequential(
        get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, cfg['model']['num_ftrs'], 1),
        nn.AdaptiveAvgPool2d(1),
    )

    model = SimCLR(
        features,
        num_ftrs=cfg['model']['num_ftrs'],
        out_dim=cfg['model']['out_dim']
    ).to(device)

    if state_dict is not None:
        load_from_state_dict(model, state_dict)

    encoder = SelfSupervisedEmbedding(model, None, None, None)
    embeddings, labels, filenames = encoder.embed(dataloader, device=device)

    if not is_cli_call:
        return embeddings, labels, filenames

    path = os.path.join(os.getcwd(), 'embeddings.csv')
    save_embeddings(path, embeddings, labels, filenames)
    print(f'Embeddings are stored at {bcolors.OKBLUE}{path}{bcolors.ENDC}')
    return path