def __init__(self, in_planes: int, planes: int, stride: int = 1, num_splits: int = 0):
    """Create the layers of a residual basic block.

    Main path: a 3x3 convolution (carrying ``stride``) and a normalization
    layer, then a stride-1 3x3 convolution and a second normalization layer.
    Both convolutions have no bias. Normalization layers come from
    ``get_norm_layer`` and receive ``num_splits`` unchanged.

    Shortcut path: an empty ``nn.Sequential`` (identity) when input and output
    shapes already match. Otherwise it is a 1x1 strided convolution plus a
    normalization layer that projects to ``self.expansion * planes`` channels.

    Args:
        in_planes: Number of input channels.
        planes: Number of channels produced by the main-path convolutions.
        stride: Stride of the first convolution and of the projection shortcut.
        num_splits: Forwarded to ``get_norm_layer``.
    """
    super(BasicBlock, self).__init__()

    self.conv1 = nn.Conv2d(
        in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
    )
    self.bn1 = get_norm_layer(planes, num_splits)
    self.conv2 = nn.Conv2d(
        planes, planes, kernel_size=3, stride=1, padding=1, bias=False
    )
    self.bn2 = get_norm_layer(planes, num_splits)

    # Register the identity shortcut first so the submodule order is
    # conv1, bn1, conv2, bn2, shortcut regardless of which branch runs below.
    self.shortcut = nn.Sequential()
    out_planes = self.expansion * planes
    shapes_differ = stride != 1 or in_planes != out_planes
    if shapes_differ:
        # A projection is needed so the residual addition lines up spatially
        # (stride) and in channel count (out_planes).
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
            get_norm_layer(out_planes, num_splits),
        )
def _get_features_and_projections(resnet, num_ftrs, out_dim, num_splits):
    """Removes classification head from the ResNet and adds a projection head.

    The feature extractor is an ``nn.Sequential`` made of:
      - a normalization layer over the 3 input channels (``get_norm_layer``),
      - every child module of ``resnet`` except the last one,
      - a 1x1 ``Conv2d`` from the dropped layer's ``in_features`` to ``num_ftrs``,
      - ``AdaptiveAvgPool2d(1)``.

    The projection head is a 2-layer MLP:
    ``Linear(num_ftrs, num_ftrs) -> ReLU -> Linear(num_ftrs, out_dim)``.

    Args:
        resnet: Module whose last child exposes ``in_features`` (presumably an
            ``nn.Linear`` classifier — confirm for non-standard backbones).
        num_ftrs: Channel count produced by the feature extractor.
        out_dim: Output dimension of the projection head.
        num_splits: Forwarded to ``get_norm_layer`` for the input norm layer.

    Returns:
        Tuple ``(features, projection_head)``.
    """
    children = list(resnet.children())
    head = children[-1]  # IndexError for a childless module, as before
    backbone_layers = children[:-1]

    # Width of the discarded classifier's input = channels of the last feature map.
    last_conv_channels = head.in_features

    input_norm = get_norm_layer(3, num_splits)
    features = nn.Sequential(
        input_norm,
        *backbone_layers,
        nn.Conv2d(last_conv_channels, num_ftrs, 1),
        nn.AdaptiveAvgPool2d(1),
    )

    hidden = nn.Linear(num_ftrs, num_ftrs)
    activation = nn.ReLU()
    output = nn.Linear(num_ftrs, out_dim)
    projection_head = nn.Sequential(hidden, activation, output)

    return features, projection_head
def __init__(self,
             block: nn.Module = BasicBlock,
             layers: List[int] = [2, 2, 2, 2],
             num_classes: int = 10,
             width: float = 1.,
             num_splits: int = 0):
    """Build a width-scalable ResNet.

    Layout:
      - stem: 3x3 stride-1 ``Conv2d`` from 3 to ``base = int(64 * width)``
        channels (no bias), then a normalization layer;
      - four stages built by ``self._make_layer`` with widths
        ``base, 2*base, 4*base, 8*base``, strides ``1, 2, 2, 2`` and
        ``layers[0..3]`` blocks respectively;
      - ``linear``: ``nn.Linear(base * 8 * block.expansion, num_classes)``.

    Args:
        block: Residual block class handed to ``_make_layer``; its
            ``expansion`` attribute sizes the final linear layer.
        layers: Blocks per stage. Only indexed here, never mutated, so the
            list default is not modified across calls.
        num_classes: Output size of the final linear layer.
        width: Channel-width multiplier applied to the base of 64.
        num_splits: Forwarded to every normalization layer.
    """
    super(ResNet, self).__init__()
    self.in_planes = int(64 * width)
    self.base = int(64 * width)

    self.conv1 = nn.Conv2d(3, self.base, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn1 = get_norm_layer(self.base, num_splits)

    # (attribute name, width multiplier, stride) for each residual stage, in
    # construction order; _make_layer presumably advances self.in_planes.
    stage_specs = (
        ('layer1', 1, 1),
        ('layer2', 2, 2),
        ('layer3', 4, 2),
        ('layer4', 8, 2),
    )
    for stage_idx, (attr_name, multiplier, stage_stride) in enumerate(stage_specs):
        stage = self._make_layer(block, self.base * multiplier, layers[stage_idx],
                                 stride=stage_stride, num_splits=num_splits)
        # nn.Module.__setattr__ registers the stage as a named submodule.
        setattr(self, attr_name, stage)

    self.linear = nn.Linear(self.base * 8 * block.expansion, num_classes)
def get_model_from_config(cfg, is_cli_call: bool = False) -> SelfSupervisedEmbedding:
    """Build an inference-only ``SelfSupervisedEmbedding`` from ``cfg`` and load its weights.

    Checkpoint resolution:
      - ``cfg['checkpoint']`` empty: the checkpoint URL is looked up with
        ``get_ptmodel_from_config(cfg['model'])`` and downloaded.
      - ``cfg['checkpoint']`` is a URL (``is_url``): it is downloaded. This is
        the same handling ``_train_cli`` applies to explicit checkpoints.
      - otherwise it is treated as a local path, passed through
        ``fix_input_path`` when ``is_cli_call`` is set, and read with ``torch.load``.

    The model is the ResNet from ``ResNetGenerator(name, width)`` with its last
    child replaced by a 1x1 conv to ``num_ftrs`` plus global average pooling,
    preceded by an input normalization layer, wrapped in ``_SimCLR`` and moved
    to CUDA if available (else CPU).

    Args:
        cfg: Config mapping with ``checkpoint`` and ``model`` (``name``,
            ``width``, ``num_ftrs``, ``out_dim``) entries.
        is_cli_call: Whether a local checkpoint path should be passed through
            ``fix_input_path``.

    Returns:
        ``SelfSupervisedEmbedding`` with no criterion, optimizer or dataloader.

    Raises:
        RuntimeError: If no checkpoint is configured and none exists for the model key.
    """
    checkpoint = cfg['checkpoint']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if not checkpoint:
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            msg = 'Cannot download checkpoint for key {} '.format(key)
            msg += 'because it does not exist!'
            raise RuntimeError(msg)
        from_url = True
    elif is_url(checkpoint):
        # Explicit URL checkpoints must not be treated as file paths.
        from_url = True
    else:
        checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint
        from_url = False

    # Load the PyTorch state dictionary and map it to the current device.
    if from_url:
        state_dict = load_state_dict_from_url(
            checkpoint, map_location=device)['state_dict']
    else:
        state_dict = torch.load(checkpoint, map_location=device)['state_dict']

    # load model
    resnet = ResNetGenerator(cfg['model']['name'], cfg['model']['width'])
    # The last child is dropped; its in_features is the final feature width.
    last_conv_channels = list(resnet.children())[-1].in_features
    features = nn.Sequential(
        get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, cfg['model']['num_ftrs'], 1),
        nn.AdaptiveAvgPool2d(1),
    )

    model = _SimCLR(features,
                    num_ftrs=cfg['model']['num_ftrs'],
                    out_dim=cfg['model']['out_dim']).to(device)

    if state_dict is not None:
        load_from_state_dict(model, state_dict)

    encoder = SelfSupervisedEmbedding(model, None, None, None)
    return encoder
def _train_cli(cfg, is_cli_call=True):
    """Train a self-supervised SimCLR-style model as configured by ``cfg``.

    Steps:
      1. Resolve ``cfg['input_dir']`` (through ``fix_input_path`` for CLI calls).
      2. If ``cfg`` has a ``seed``, seed torch/CUDA and make cuDNN deterministic.
      3. Pick the device. Without CUDA, any requested ``trainer.gpus`` is reset
         to 0; with more than one GPU the ``'ddp'`` strategy is used.
      4. Warn for batch sizes below 64. A negative ``loader.num_workers`` is
         replaced by ``cpu_count()``.
      5. Resolve initial weights: an explicit ``cfg['checkpoint']`` (local path
         or URL), or, when ``pre_trained`` is set, the model-zoo checkpoint.
         A missing model-zoo checkpoint only warns and training starts from scratch.
      6. Build model, NT-Xent loss, SGD optimizer and dataloader, then train.

    Args:
        cfg: Config mapping. Mutated in place: ``trainer.weights_summary``,
            ``trainer.gpus``, ``loader.num_workers`` and ``loader.batch_size``
            may be rewritten.
        is_cli_call: Whether local paths should be passed through ``fix_input_path``.

    Returns:
        ``encoder.checkpoint`` (reported as the best model). It is also stored in
        the environment variable named by
        ``cfg['environment_variable_names']['lightly_last_checkpoint_path']``.
    """
    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    if 'seed' in cfg.keys():
        seed = cfg['seed']
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Map the literal string "None" in the config to the Python value None.
    if cfg["trainer"]["weights_summary"] == "None":
        cfg["trainer"]["weights_summary"] = None

    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
        # GPUs were requested but CUDA is unavailable: fall back to CPU training.
        if cfg['trainer'] and cfg['trainer']['gpus']:
            cfg['trainer']['gpus'] = 0

    distributed_strategy = None
    if cfg['trainer']['gpus'] > 1:
        distributed_strategy = 'ddp'

    if cfg['loader']['batch_size'] < 64:
        msg = 'Training a self-supervised model with a small batch size: {}! '
        msg = msg.format(cfg['loader']['batch_size'])
        msg += 'Small batch size may harm embedding quality. '
        msg += 'You can specify the batch size via the loader key-word: '
        msg += 'loader.batch_size=BSZ'
        warnings.warn(msg)

    # determine the number of available cores
    if cfg['loader']['num_workers'] < 0:
        cfg['loader']['num_workers'] = cpu_count()

    state_dict = None
    checkpoint = cfg['checkpoint']
    if cfg['pre_trained'] and not checkpoint:
        # if checkpoint wasn't specified explicitly and pre_trained is True
        # try to load the checkpoint from the model zoo
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            msg = 'Cannot download checkpoint for key {} '.format(key)
            msg += 'because it does not exist! '
            msg += 'Model will be trained from scratch.'
            warnings.warn(msg)
    elif checkpoint and not is_url(checkpoint):
        # Only local paths are fixed. Passing a URL through fix_input_path
        # would treat it as a relative file path and break the URL check below.
        checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint

    if checkpoint:
        # load the PyTorch state dictionary and map it to the current device
        if is_url(checkpoint):
            state_dict = load_state_dict_from_url(
                checkpoint, map_location=device)['state_dict']
        else:
            state_dict = torch.load(checkpoint, map_location=device)['state_dict']

    # load model
    resnet = ResNetGenerator(cfg['model']['name'], cfg['model']['width'])
    # The last child is dropped; its in_features is the final feature width.
    last_conv_channels = list(resnet.children())[-1].in_features
    features = nn.Sequential(
        get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, cfg['model']['num_ftrs'], 1),
        nn.AdaptiveAvgPool2d(1),
    )

    model = _SimCLR(features,
                    num_ftrs=cfg['model']['num_ftrs'],
                    out_dim=cfg['model']['out_dim'])

    if state_dict is not None:
        load_from_state_dict(model, state_dict)

    criterion = NTXentLoss(**cfg['criterion'])
    optimizer = torch.optim.SGD(model.parameters(), **cfg['optimizer'])

    dataset = LightlyDataset(input_dir)

    # A batch can never be larger than the dataset itself.
    cfg['loader']['batch_size'] = min(cfg['loader']['batch_size'], len(dataset))

    collate_fn = ImageCollateFunction(**cfg['collate'])
    dataloader = torch.utils.data.DataLoader(dataset,
                                             **cfg['loader'],
                                             collate_fn=collate_fn)

    encoder = SelfSupervisedEmbedding(model, criterion, optimizer, dataloader)
    encoder.init_checkpoint_callback(**cfg['checkpoint_callback'])
    encoder.train_embedding(**cfg['trainer'], strategy=distributed_strategy)

    print(
        f'Best model is stored at: {bcolors.OKBLUE}{encoder.checkpoint}{bcolors.ENDC}'
    )
    os.environ[
        cfg['environment_variable_names']['lightly_last_checkpoint_path']
    ] = encoder.checkpoint
    return encoder.checkpoint
def _embed_cli(cfg, is_cli_call=True):
    """Embed every image in ``cfg['input_dir']`` with a checkpointed model.

    Images are resized to ``collate.input_size`` squared, converted to tensors
    and normalized with fixed mean/std constants. cuDNN is switched to
    deterministic, non-benchmark mode.

    Checkpoint resolution:
      - ``cfg['checkpoint']`` empty: the checkpoint URL is looked up with
        ``get_ptmodel_from_config(cfg['model'])`` and downloaded.
      - ``cfg['checkpoint']`` is a URL (``is_url``): it is downloaded. This is
        the same handling ``_train_cli`` applies to explicit checkpoints.
      - otherwise it is treated as a local path, passed through
        ``fix_input_path`` for CLI calls, and read with ``torch.load``.

    Args:
        cfg: Config mapping. ``loader.drop_last`` and ``loader.shuffle`` are
            forced to False and ``loader.batch_size`` is capped at the dataset
            size (mutated in place).
        is_cli_call: If True, local paths are fixed and the embeddings are
            written to ``embeddings.csv`` in the current working directory.

    Returns:
        The CSV path when ``is_cli_call``; otherwise the tuple
        ``(embeddings, labels, filenames)`` returned by ``encoder.embed``.

    Raises:
        RuntimeError: If no checkpoint is configured and none exists for the model key.
    """
    checkpoint = cfg['checkpoint']

    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    input_size = cfg['collate']['input_size']
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((input_size, input_size)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])
    ])

    dataset = LightlyDataset(input_dir, transform=transform)

    # Every image must be embedded exactly once, in a stable order.
    cfg['loader']['drop_last'] = False
    cfg['loader']['shuffle'] = False
    cfg['loader']['batch_size'] = min(
        cfg['loader']['batch_size'],
        len(dataset)
    )
    dataloader = torch.utils.data.DataLoader(dataset, **cfg['loader'])

    if not checkpoint:
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            msg = 'Cannot download checkpoint for key {} '.format(key)
            msg += 'because it does not exist!'
            raise RuntimeError(msg)
        from_url = True
    elif is_url(checkpoint):
        # Explicit URL checkpoints must not be treated as file paths.
        from_url = True
    else:
        checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint
        from_url = False

    # load the PyTorch state dictionary and map it to the current device
    if from_url:
        state_dict = load_state_dict_from_url(
            checkpoint, map_location=device
        )['state_dict']
    else:
        state_dict = torch.load(
            checkpoint, map_location=device
        )['state_dict']

    # load model
    resnet = ResNetGenerator(cfg['model']['name'], cfg['model']['width'])
    # The last child is dropped; its in_features is the final feature width.
    last_conv_channels = list(resnet.children())[-1].in_features
    features = nn.Sequential(
        get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, cfg['model']['num_ftrs'], 1),
        nn.AdaptiveAvgPool2d(1),
    )

    # NOTE(review): training (_train_cli) and get_model_from_config build
    # `_SimCLR`, while this builds `SimCLR`. Confirm both expose the same
    # parameter names so trained checkpoints load here.
    model = SimCLR(
        features,
        num_ftrs=cfg['model']['num_ftrs'],
        out_dim=cfg['model']['out_dim']
    ).to(device)

    if state_dict is not None:
        load_from_state_dict(model, state_dict)

    encoder = SelfSupervisedEmbedding(model, None, None, None)
    embeddings, labels, filenames = encoder.embed(dataloader, device=device)

    if is_cli_call:
        path = os.path.join(os.getcwd(), 'embeddings.csv')
        save_embeddings(path, embeddings, labels, filenames)
        print(f'Embeddings are stored at {bcolors.OKBLUE}{path}{bcolors.ENDC}')
        return path

    return embeddings, labels, filenames