def test_tuple_input(self):
    """Verify SimCLR output shapes for single, feature-returning, and tuple calls."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    resnet = ResNetGenerator('resnet-18')
    model = SimCLR(get_backbone(resnet, num_ftrs=32), out_dim=128).to(device)

    view0 = torch.rand((self.batch_size, 3, 64, 64)).to(device)
    view1 = torch.rand((self.batch_size, 3, 64, 64)).to(device)

    # single input -> one projection
    out = model(view0)
    self.assertEqual(out.shape, (self.batch_size, 128))

    # single input with backbone features
    out, features = model(view0, return_features=True)
    self.assertEqual(out.shape, (self.batch_size, 128))
    self.assertEqual(features.shape, (self.batch_size, 32))

    # two inputs -> a projection per view
    out0, out1 = model(view0, view1)
    self.assertEqual(out0.shape, (self.batch_size, 128))
    self.assertEqual(out1.shape, (self.batch_size, 128))

    # two inputs with backbone features per view
    (out0, f0), (out1, f1) = model(view0, view1, return_features=True)
    self.assertEqual(out0.shape, (self.batch_size, 128))
    self.assertEqual(out1.shape, (self.batch_size, 128))
    self.assertEqual(f0.shape, (self.batch_size, 32))
    self.assertEqual(f1.shape, (self.batch_size, 32))
def test_create_variations_gpu(self):
    """Build every ResNet variant as a SimCLR model on the GPU.

    Does nothing when CUDA is unavailable (the original code fell through
    a dead `else: pass` branch in that case; an early return is clearer).
    """
    if not torch.cuda.is_available():
        return
    for model_name in self.resnet_variants:
        resnet = ResNetGenerator(model_name)
        model = SimCLR(get_backbone(resnet)).to('cuda')
        self.assertIsNotNone(model)
def test_feature_dim_configurable(self):
    """num_ftrs / out_dim must control backbone and projection output widths."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    for model_name in self.resnet_variants:
        for num_ftrs, out_dim in zip([16, 64], [64, 256]):
            resnet = ResNetGenerator(model_name)
            model = SimCLR(
                get_backbone(resnet, num_ftrs=num_ftrs),
                num_ftrs=num_ftrs,
                out_dim=out_dim,
            ).to(device)

            # the backbone feature vector must have num_ftrs channels
            with torch.no_grad():
                backbone_out = model.backbone(self.input_tensor.to(device))
            self.assertEqual(backbone_out.shape[1], num_ftrs)

            # the projection head must map features to out_dim
            with torch.no_grad():
                projection_out = model.projection_head(backbone_out.squeeze())
            self.assertEqual(projection_out.shape[1], out_dim)

            self.assertIsNotNone(model)
def test_variations_input_dimension(self):
    """Every ResNet variant should run forward on non-square / varying input sizes."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    for model_name in self.resnet_variants:
        for width, height in zip([32, 64], [64, 64]):
            resnet = ResNetGenerator(model_name)
            model = SimCLR(get_backbone(resnet, num_ftrs=32)).to(device)
            batch = torch.rand((self.batch_size, 3, height, width))
            with torch.no_grad():
                out = model(batch.to(device))
            self.assertIsNotNone(model)
            self.assertIsNotNone(out)
def _train_cli(cfg, is_cli_call=True):
    """Train a self-supervised embedding model from a CLI configuration.

    Args:
        cfg: Configuration mapping with keys such as 'input_dir', 'model',
            'trainer', 'loader', 'criterion', 'optimizer', 'collate',
            'checkpoint', 'pre_trained' and 'checkpoint_callback'.
        is_cli_call: When True, user-supplied paths are normalized with
            fix_input_path.

    Returns:
        Path to the best checkpoint produced during training.
    """
    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    # make training reproducible when a seed is provided
    if 'seed' in cfg:
        seed = cfg['seed']
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    if torch.cuda.is_available():
        device = 'cuda'
    else:
        # BUG FIX: previously `device` was only assigned inside an
        # `elif cfg['trainer'] and cfg['trainer']['gpus']` branch, so it was
        # left undefined (NameError at torch.load below) whenever CUDA was
        # unavailable and no gpus were configured. Always fall back to CPU
        # and zero out the requested gpus where applicable.
        device = 'cpu'
        if cfg['trainer'] and cfg['trainer']['gpus']:
            cfg['trainer']['gpus'] = 0

    if cfg['loader']['batch_size'] < 64:
        msg = 'Training a self-supervised model with a small batch size: {}! '
        msg = msg.format(cfg['loader']['batch_size'])
        msg += 'Small batch size may harm embedding quality. '
        msg += 'You can specify the batch size via the loader key-word: '
        msg += 'loader.batch_size=BSZ'
        warnings.warn(msg)

    state_dict = None
    checkpoint = cfg['checkpoint']
    if cfg['pre_trained'] and not checkpoint:
        # if checkpoint wasn't specified explicitly and pre_trained is True
        # try to load the checkpoint from the model zoo
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            msg = 'Cannot download checkpoint for key {} '.format(key)
            msg += 'because it does not exist! '
            msg += 'Model will be trained from scratch.'
            warnings.warn(msg)
    elif checkpoint:
        checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint

    if checkpoint:
        # load the PyTorch state dictionary and map it to the current device
        if is_url(checkpoint):
            state_dict = load_state_dict_from_url(
                checkpoint, map_location=device)['state_dict']
        else:
            state_dict = torch.load(
                checkpoint, map_location=device)['state_dict']

    # build the backbone: strip the final fc layer of the generated resnet
    # and project to num_ftrs channels with a 1x1 conv + global pooling
    resnet = ResNetGenerator(cfg['model']['name'], cfg['model']['width'])
    last_conv_channels = list(resnet.children())[-1].in_features
    features = nn.Sequential(
        get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, cfg['model']['num_ftrs'], 1),
        nn.AdaptiveAvgPool2d(1),
    )
    model = SimCLR(features,
                   num_ftrs=cfg['model']['num_ftrs'],
                   out_dim=cfg['model']['out_dim'])
    if state_dict is not None:
        load_from_state_dict(model, state_dict)

    criterion = NTXentLoss(**cfg['criterion'])
    optimizer = torch.optim.SGD(model.parameters(), **cfg['optimizer'])

    dataset = LightlyDataset(input_dir)

    # never ask the loader for more samples than the dataset holds
    cfg['loader']['batch_size'] = min(cfg['loader']['batch_size'], len(dataset))

    collate_fn = ImageCollateFunction(**cfg['collate'])
    dataloader = torch.utils.data.DataLoader(dataset,
                                             **cfg['loader'],
                                             collate_fn=collate_fn)

    encoder = SelfSupervisedEmbedding(model, criterion, optimizer, dataloader)
    encoder.init_checkpoint_callback(**cfg['checkpoint_callback'])
    encoder.train_embedding(**cfg['trainer'])

    print('Best model is stored at: %s' % (encoder.checkpoint))
    return encoder.checkpoint
def test_create_variations_cpu(self):
    """Each ResNet variant should yield a valid SimCLR model on CPU."""
    for variant in self.resnet_variants:
        backbone = get_backbone(ResNetGenerator(variant))
        self.assertIsNotNone(SimCLR(backbone))
def _embed_cli(cfg, is_cli_call=True):
    """Embed a dataset with a (pre-)trained self-supervised model.

    Args:
        cfg: Configuration mapping with keys such as 'checkpoint',
            'input_dir', 'collate', 'loader' and 'model'.
        is_cli_call: When True, paths are normalized for CLI usage and the
            embeddings are written to a CSV file.

    Returns:
        The path to the embeddings CSV when called from the CLI, otherwise
        the tuple (embeddings, labels, filenames).

    Raises:
        RuntimeError: If no checkpoint was given and none could be found in
            the model zoo for the configured model.
    """
    checkpoint = cfg['checkpoint']
    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    # deterministic inference
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device('cuda') if torch.cuda.is_available() \
        else torch.device('cpu')

    # standard ImageNet normalization after resizing to the configured size
    size = cfg['collate']['input_size']
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((size, size)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]),
    ])

    dataset = LightlyDataset(input_dir, transform=transform)

    # embedding requires every sample exactly once, in order
    cfg['loader']['drop_last'] = False
    cfg['loader']['shuffle'] = False
    cfg['loader']['batch_size'] = min(cfg['loader']['batch_size'],
                                      len(dataset))
    dataloader = torch.utils.data.DataLoader(dataset, **cfg['loader'])

    # resolve the checkpoint and load its state dict onto the current device
    state_dict = None
    if not checkpoint:
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            msg = 'Cannot download checkpoint for key {} '.format(key)
            msg += 'because it does not exist!'
            raise RuntimeError(msg)
        state_dict = load_state_dict_from_url(
            checkpoint, map_location=device
        )['state_dict']
    else:
        if is_cli_call:
            checkpoint = fix_input_path(checkpoint)
        state_dict = torch.load(
            checkpoint, map_location=device
        )['state_dict']

    # build the backbone: strip the final fc layer of the generated resnet
    # and project to num_ftrs channels with a 1x1 conv + global pooling
    resnet = ResNetGenerator(cfg['model']['name'], cfg['model']['width'])
    last_conv_channels = list(resnet.children())[-1].in_features
    features = nn.Sequential(
        get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, cfg['model']['num_ftrs'], 1),
        nn.AdaptiveAvgPool2d(1),
    )
    model = SimCLR(
        features,
        num_ftrs=cfg['model']['num_ftrs'],
        out_dim=cfg['model']['out_dim']
    ).to(device)

    if state_dict is not None:
        load_from_state_dict(model, state_dict)

    encoder = SelfSupervisedEmbedding(model, None, None, None)
    embeddings, labels, filenames = encoder.embed(dataloader, device=device)

    if is_cli_call:
        path = os.path.join(os.getcwd(), 'embeddings.csv')
        save_embeddings(path, embeddings, labels, filenames)
        print(f'Embeddings are stored at {bcolors.OKBLUE}{path}{bcolors.ENDC}')
        return path

    return embeddings, labels, filenames