def _download_cli(cfg, is_cli_call=True):
    """Download the filenames of a tag and optionally the images themselves.

    Writes the filenames in the tag to ``<tag_name>.txt`` in the current
    working directory, then either downloads the full images from the API
    (output_dir only) or copies them from a local input directory
    (input_dir and output_dir).

    Args:
        cfg: Hydra/OmegaConf-style config with keys 'tag_name',
            'dataset_id', 'token', 'input_dir' and 'output_dir'.
        is_cli_call: Unused here; kept for signature parity with the
            other CLI entry points.
    """
    tag_name = cfg['tag_name']
    dataset_id = cfg['dataset_id']
    token = cfg['token']

    # Validate CLI arguments before touching the API.
    if not tag_name:
        print('Please specify a tag name')
        print('For help, try: lightly-download --help')
        return
    if not token or not dataset_id:
        print('Please specify your access token and dataset id')
        print('For help, try: lightly-download --help')
        return

    api_workflow_client = ApiWorkflowClient(
        token=token, dataset_id=dataset_id
    )

    # Resolve the tag name to its id.
    # NOTE(review): relies on the private helper _get_all_tags — a public
    # accessor would be preferable if the client exposes one.
    tag_name_id_dict = {
        tag.name: tag.id for tag in api_workflow_client._get_all_tags()
    }
    tag_id = tag_name_id_dict.get(tag_name)
    if tag_id is None:
        print(f'The specified tag {tag_name} does not exist.')
        return

    # Fetch the tag payload (contains the sample bitmask).
    tag_data = api_workflow_client.tags_api.get_tag_by_tag_id(
        dataset_id=dataset_id, tag_id=tag_id)

    # Decode the bitmask into indices and map them to server filenames.
    chosen_samples_ids = BitMask.from_hex(tag_data.bit_mask_data).to_indices()
    samples = [
        api_workflow_client.filenames_on_server[i]
        for i in chosen_samples_ids
    ]

    # Store the sample names in a .txt file, one filename per line.
    with open(cfg['tag_name'] + '.txt', 'w') as f:
        f.writelines(f'{item}\n' for item in samples)

    msg = 'The list of files in tag {} is stored at: '.format(cfg['tag_name'])
    msg += os.path.join(os.getcwd(), cfg['tag_name'] + '.txt')
    print(msg, flush=True)

    if not cfg['input_dir'] and cfg['output_dir']:
        # No local copy available: download full images from the API.
        output_dir = fix_input_path(cfg['output_dir'])
        api_workflow_client.download_dataset(output_dir, tag_name=tag_name)
    elif cfg['input_dir'] and cfg['output_dir']:
        input_dir = fix_input_path(cfg['input_dir'])
        output_dir = fix_input_path(cfg['output_dir'])
        print(f'Copying files from {input_dir} to {output_dir}.')

        # Copy only the tagged samples from the local input directory.
        dataset = data.LightlyDataset(input_dir=input_dir)
        dataset.dump(output_dir, samples)
def _upload_cli(cfg, is_cli_call=True):
    """Upload a dataset and/or embeddings to the Lightly platform.

    Either creates a new remote dataset (new_dataset_name) or attaches to
    an existing one (dataset_id), then uploads images from 'input_dir'
    and/or an embeddings CSV.

    Args:
        cfg: Config with keys 'input_dir', 'embeddings', 'dataset_id',
            'token', 'new_dataset_name', 'resize', 'upload', 'loader',
            'embedding_name'.
        is_cli_call: When True, user-supplied paths are normalized with
            fix_input_path.
    """
    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    path_to_embeddings = cfg['embeddings']
    if path_to_embeddings and is_cli_call:
        path_to_embeddings = fix_input_path(path_to_embeddings)

    dataset_id = cfg['dataset_id']
    token = cfg['token']
    new_dataset_name = cfg['new_dataset_name']

    # Collect all argument errors first so the user sees every problem at
    # once, and so no API call is made with an invalid configuration.
    cli_api_args_wrong = False
    if not token:
        print_as_warning('Please specify your access token.')
        cli_api_args_wrong = True

    dataset_id_ok = bool(dataset_id)
    new_dataset_name_ok = bool(new_dataset_name)
    create_new_dataset = new_dataset_name_ok and not dataset_id_ok
    use_existing_dataset = dataset_id_ok and not new_dataset_name_ok
    if not (create_new_dataset or use_existing_dataset):
        print_as_warning('Please specify either the dataset_id of an existing dataset or a new_dataset_name.')
        cli_api_args_wrong = True

    if cli_api_args_wrong:
        print_as_warning('For help, try: lightly-upload --help')
        return

    # BUGFIX: construct the client (and possibly create the remote dataset)
    # only AFTER validation — previously a dataset could be created via the
    # API even when the token check had already failed.
    if create_new_dataset:
        api_workflow_client = ApiWorkflowClient(token=token)
        api_workflow_client.create_dataset(dataset_name=new_dataset_name)
    else:
        api_workflow_client = ApiWorkflowClient(
            token=token, dataset_id=dataset_id)

    # A positive int resizes to a square; a sequence becomes a (h, w) tuple.
    size = cfg['resize']
    if not isinstance(size, int):
        size = tuple(size)
    transform = None
    if isinstance(size, tuple) or size > 0:
        transform = torchvision.transforms.Resize(size)

    if input_dir:
        mode = cfg['upload']
        dataset = LightlyDataset(input_dir=input_dir, transform=transform)
        api_workflow_client.upload_dataset(
            input=dataset, mode=mode,
            max_workers=cfg['loader']['num_workers']
        )
        print("Finished the upload of the dataset.")

    if path_to_embeddings:
        name = cfg['embedding_name']
        print("Starting upload of embeddings.")
        api_workflow_client.upload_embeddings(
            path_to_embeddings_csv=path_to_embeddings, name=name
        )
        print("Finished upload of embeddings.")

    if new_dataset_name_ok:
        print(f'The dataset_id of the newly created dataset is '
              f'{bcolors.OKBLUE}{api_workflow_client.dataset_id}{bcolors.ENDC}')
def _download_cli(cfg, is_cli_call=True):
    """Download the filenames of a tag and optionally copy the images.

    Writes the filenames in the tag to ``<tag_name>.txt``; when both
    'input_dir' and 'output_dir' are given, copies the tagged files from
    the input folder to the output folder, preserving subdirectories.

    Args:
        cfg: Config with keys 'tag_name', 'dataset_id', 'token',
            'input_dir' and 'output_dir'.
        is_cli_call: Unused here; kept for signature parity with the
            other CLI entry points.
    """
    tag_name = cfg['tag_name']
    dataset_id = cfg['dataset_id']
    token = cfg['token']

    if not tag_name:
        print('Please specify a tag name')
        print('For help, try: lightly-download --help')
        return
    if not token or not dataset_id:
        print('Please specify your access token and dataset id')
        print('For help, try: lightly-download --help')
        return

    # get all samples in the queried tag
    samples = get_samples_by_tag(tag_name,
                                 dataset_id,
                                 token,
                                 mode='list',
                                 filenames=None)

    # store sample names in a .txt file
    with open(cfg['tag_name'] + '.txt', 'w') as f:
        f.writelines(f'{item}\n' for item in samples)

    msg = 'The list of files in tag {} is stored at: '.format(cfg['tag_name'])
    msg += os.path.join(os.getcwd(), cfg['tag_name'] + '.txt')
    print(msg)

    if cfg['input_dir'] and cfg['output_dir']:
        # "name.jpg" -> "/name.jpg" to prevent bugs like this:
        # "path/to/1234.jpg" ends with both "234.jpg" and "1234.jpg"
        # (os.sep + s is the explicit spelling of the old
        # os.path.join(' ', s)[1:] trick)
        sample_suffixes = tuple(os.sep + s for s in samples)

        # copy all images from one folder to the other
        input_dir = fix_input_path(cfg['input_dir'])
        output_dir = fix_input_path(cfg['output_dir'])

        dataset = data.LightlyDataset(from_folder=input_dir)
        basenames = dataset.get_filenames()

        source_names = [os.path.join(input_dir, f) for f in basenames]
        target_names = [os.path.join(output_dir, f) for f in basenames]

        # only copy files which are in the tag; str.endswith accepts a
        # tuple of suffixes, so no inner loop is needed
        indices = [
            i for i, src in enumerate(source_names)
            if src.endswith(sample_suffixes)
        ]

        print(f'Copying files from {input_dir} to {output_dir}.')
        for i in tqdm(indices):
            dirname = os.path.dirname(target_names[i])
            os.makedirs(dirname, exist_ok=True)
            shutil.copy(source_names[i], target_names[i])
def _upload_cli(cfg, is_cli_call=True):
    """Upload images and/or embeddings to the Lightly platform.

    Attaches to an existing dataset via 'dataset_id' or creates a new one
    via 'new_dataset_name' (exactly one of the two must be given).
    """
    # Normalize user-supplied paths when invoked from the command line.
    input_dir = cfg['input_dir']
    path_to_embeddings = cfg['embeddings']
    if is_cli_call:
        if input_dir:
            input_dir = fix_input_path(input_dir)
        if path_to_embeddings:
            path_to_embeddings = fix_input_path(path_to_embeddings)

    dataset_id = cfg['dataset_id']
    token = cfg['token']
    new_dataset_name = cfg['new_dataset_name']

    # An access token is mandatory.
    if not token:
        warnings.warn('Please specify your access token. For help, try: lightly-upload --help')
        return

    has_dataset_id = bool(dataset_id) and len(dataset_id) > 0
    has_new_name = bool(new_dataset_name) and len(new_dataset_name) > 0

    # Exactly one of dataset_id / new_dataset_name selects the workflow.
    if has_new_name and not has_dataset_id:
        api_workflow_client = ApiWorkflowClient(token=token)
        api_workflow_client.create_dataset(dataset_name=new_dataset_name)
    elif has_dataset_id and not has_new_name:
        api_workflow_client = ApiWorkflowClient(token=token,
                                                dataset_id=dataset_id)
    else:
        warnings.warn('Please specify either the dataset_id of an existing dataset or a new_dataset_name. '
                      'For help, try: lightly-upload --help')
        return

    # Build the optional resize transform: square for an int, (h, w) tuple
    # otherwise; 0 or negative disables resizing.
    size = cfg['resize']
    if not isinstance(size, int):
        size = tuple(size)
    if isinstance(size, tuple) or size > 0:
        transform = torchvision.transforms.Resize(size)
    else:
        transform = None

    if input_dir:
        dataset = LightlyDataset(input_dir=input_dir, transform=transform)
        api_workflow_client.upload_dataset(
            input=dataset,
            mode=cfg['upload'],
            max_workers=cfg['loader']['num_workers'],
        )

    if path_to_embeddings:
        api_workflow_client.upload_embeddings(
            path_to_embeddings_csv=path_to_embeddings,
            name=cfg['embedding_name'],
        )
def _download_cli(cfg, is_cli_call=True):
    """Download the filenames in a tag; optionally fetch or copy the images.

    Writes the tag's filenames to ``<tag_name>.txt``, then either downloads
    the images from the API (output_dir only) or copies them from a local
    folder (input_dir and output_dir).
    """
    tag_name = str(cfg['tag_name'])
    dataset_id = str(cfg['dataset_id'])
    token = str(cfg['token'])

    # All three identifiers are required.
    if not (tag_name and token and dataset_id):
        print_as_warning('Please specify all of the parameters tag_name, token and dataset_id')
        print_as_warning('For help, try: lightly-download --help')
        return

    client = ApiWorkflowClient(token=token, dataset_id=dataset_id)

    # Resolve the tag and collect its filenames.
    tag = client.get_tag_by_name(tag_name)
    filenames_tag = client.get_filenames_in_tag(
        tag,
        exclude_parent_tag=cfg['exclude_parent_tag'],
    )

    # Persist the filenames, one per line.
    listing_name = tag_name + '.txt'
    with open(listing_name, 'w') as listing:
        listing.writelines(f'{name}\n' for name in filenames_tag)

    listing_path = os.path.join(os.getcwd(), listing_name)
    print(
        f'The list of files in tag {cfg["tag_name"]} is stored at: '
        f'{bcolors.OKBLUE}{listing_path}{bcolors.ENDC}',
        flush=True,
    )

    have_input = bool(cfg['input_dir'])
    have_output = bool(cfg['output_dir'])

    if have_output and not have_input:
        # No local source: pull the full images from the API.
        destination = fix_input_path(cfg['output_dir'])
        client.download_dataset(destination, tag_name=tag_name)
    elif have_input and have_output:
        source = fix_input_path(cfg['input_dir'])
        destination = fix_input_path(cfg['output_dir'])
        print(f'Copying files from {source} to {bcolors.OKBLUE}{destination}{bcolors.ENDC}.')

        # Dump only the tagged files from the local dataset.
        local_dataset = data.LightlyDataset(input_dir=source)
        local_dataset.dump(destination, filenames_tag)
def _download_cli(cfg, is_cli_call=True):
    """Write the filenames in a tag to disk and optionally copy the images.

    Stores the filenames in ``<tag_name>.txt``; when both 'input_dir' and
    'output_dir' are set, the tagged files are dumped from the input folder
    into the output folder.
    """
    tag_name = cfg['tag_name']
    dataset_id = cfg['dataset_id']
    token = cfg['token']

    # Guard clauses: a tag name plus credentials are required.
    if not tag_name:
        print('Please specify a tag name')
        print('For help, try: lightly-download --help')
        return
    if not (token and dataset_id):
        print('Please specify your access token and dataset id')
        print('For help, try: lightly-download --help')
        return

    # Fetch every sample filename belonging to the queried tag.
    samples = get_samples_by_tag(
        tag_name,
        dataset_id,
        token,
        mode='list',
        filenames=None,
    )

    # Persist the filenames, one per line.
    listing_name = cfg['tag_name'] + '.txt'
    with open(listing_name, 'w') as listing:
        listing.writelines('%s\n' % name for name in samples)

    location = os.path.join(os.getcwd(), listing_name)
    print('The list of files in tag {} is stored at: '.format(cfg['tag_name']) + location,
          flush=True)

    if cfg['input_dir'] and cfg['output_dir']:
        source = fix_input_path(cfg['input_dir'])
        destination = fix_input_path(cfg['output_dir'])
        print(f'Copying files from {source} to {destination}.')

        # Build a dataset over the source folder and dump the tagged files.
        local_dataset = data.LightlyDataset(from_folder=source)
        local_dataset.dump(destination, samples)
def _crop_cli(cfg, is_cli_call=True):
    """Crop images according to their YOLO bounding-box label files.

    For each image in 'input_dir' a YOLO label file with the same relative
    path (but a '.txt' extension) is read from 'label_dir'; the crops are
    written to 'output_dir'. Optional 'label_names_file' (YAML with a
    'names' list) provides class names for the output.

    Args:
        cfg: Config with keys 'input_dir', 'output_dir', 'label_dir',
            'label_names_file' and 'crop_padding'.
        is_cli_call: When True, user-supplied paths are normalized with
            fix_input_path.

    Returns:
        List of lists of cropped images, one inner list per input image.
    """
    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)
    output_dir = cfg['output_dir']
    if output_dir and is_cli_call:
        output_dir = fix_input_path(output_dir)
    label_dir = cfg['label_dir']
    if label_dir and is_cli_call:
        label_dir = fix_input_path(label_dir)

    label_names_file = cfg['label_names_file']
    if label_names_file and len(label_names_file) > 0:
        if is_cli_call:
            label_names_file = fix_input_path(label_names_file)
        with open(label_names_file, 'r') as file:
            label_names_file_dict = yaml.full_load(file)
        class_names = label_names_file_dict['names']
    else:
        class_names = None

    dataset = LightlyDataset(input_dir)

    class_indices_list_list: List[List[int]] = []
    bounding_boxes_list_list: List[List[BoundingBox]] = []

    # YOLO-specific: each image has a label file at the same relative path
    # with a '.txt' extension.
    for filename_image in dataset.get_filenames():
        filepath_image_base, image_extension = os.path.splitext(filename_image)
        # BUGFIX: the label path was previously derived with
        # str.replace(image_extension, '.txt'), which replaces the FIRST
        # occurrence of the extension anywhere in the path (wrong for paths
        # like 'clips.jpg/frame.jpg'). Swap only the final extension.
        filepath_label = os.path.join(label_dir, filepath_image_base + '.txt')
        class_indices, bounding_boxes = read_yolo_label_file(
            filepath_label, float(cfg['crop_padding']))
        class_indices_list_list.append(class_indices)
        bounding_boxes_list_list.append(bounding_boxes)

    cropped_images_list_list = \
        crop_dataset_by_bounding_boxes_and_save(dataset,
                                                output_dir,
                                                bounding_boxes_list_list,
                                                class_indices_list_list,
                                                class_names)

    print(
        f'Cropped images are stored at: {bcolors.OKBLUE}{output_dir}{bcolors.ENDC}'
    )
    return cropped_images_list_list
def _upload_cli(cfg, is_cli_call=True):
    """Upload images and/or embeddings to an existing Lightly dataset.

    Args:
        cfg: Config with keys 'input_dir', 'embeddings', 'dataset_id',
            'token', 'resize', 'upload', 'emb_upload_bsz' and
            'embedding_name'.
        is_cli_call: When True, user-supplied paths are normalized with
            fix_input_path.
    """
    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    path_to_embeddings = cfg['embeddings']
    if path_to_embeddings and is_cli_call:
        path_to_embeddings = fix_input_path(path_to_embeddings)

    dataset_id = cfg['dataset_id']
    token = cfg['token']

    # A positive int resizes to a square; a sequence becomes a (h, w) tuple.
    size = cfg['resize']
    if not isinstance(size, int):
        size = tuple(size)

    if not token or not dataset_id:
        print('Please specify your access token and dataset id.')
        print('For help, try: lightly-upload --help')
        return

    if input_dir:
        mode = cfg['upload']
        try:
            upload_images_from_folder(input_dir,
                                      dataset_id,
                                      token,
                                      mode=mode,
                                      size=size)
        except (ValueError, ConnectionRefusedError) as error:
            msg = f'Error: {error}'
            print(msg)
            # BUGFIX: exit with a non-zero status on failure — the previous
            # exit(0) reported success to the calling shell despite the error.
            exit(1)

    if path_to_embeddings:
        max_upload = cfg['emb_upload_bsz']
        upload_embeddings_from_csv(path_to_embeddings,
                                   dataset_id,
                                   token,
                                   max_upload=max_upload,
                                   embedding_name=cfg['embedding_name'])
def _upload_cli(cfg, is_cli_call=True):
    """Upload a local dataset and/or an embeddings CSV to the platform."""
    # Normalize user-supplied paths when invoked from the command line.
    input_dir = cfg['input_dir']
    path_to_embeddings = cfg['embeddings']
    if is_cli_call:
        if input_dir:
            input_dir = fix_input_path(input_dir)
        if path_to_embeddings:
            path_to_embeddings = fix_input_path(path_to_embeddings)

    dataset_id = cfg['dataset_id']
    token = cfg['token']

    # Build the optional resize transform: square for an int, (h, w) tuple
    # otherwise; 0 or negative disables resizing.
    size = cfg['resize']
    if not isinstance(size, int):
        size = tuple(size)
    if isinstance(size, tuple) or size > 0:
        transform = torchvision.transforms.Resize(size)
    else:
        transform = None

    # Credentials are mandatory.
    if not (token and dataset_id):
        print('Please specify your access token and dataset id.')
        print('For help, try: lightly-upload --help')
        return

    client = ApiWorkflowClient(token=token, dataset_id=dataset_id)

    if input_dir:
        dataset = LightlyDataset(input_dir=input_dir, transform=transform)
        client.upload_dataset(input=dataset, mode=cfg['upload'])

    if path_to_embeddings:
        client.upload_embeddings(
            path_to_embeddings_csv=path_to_embeddings,
            name=cfg['embedding_name'],
        )
def _embed_cli(cfg, is_cli_call=True):
    """Embed all images in 'input_dir' with the configured model.

    Returns the path to the saved embeddings CSV when called from the CLI,
    otherwise the raw (embeddings, labels, filenames) triple.
    """
    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    # Deterministic cuDNN behavior for reproducible embeddings.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Standard ImageNet preprocessing at the configured input size.
    side = cfg['collate']['input_size']
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((side, side)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ])

    dataset = LightlyDataset(input_dir, transform=transform)

    # Embedding must see every sample exactly once, in order.
    loader_cfg = cfg['loader']
    loader_cfg['drop_last'] = False
    loader_cfg['shuffle'] = False
    loader_cfg['batch_size'] = min(loader_cfg['batch_size'], len(dataset))
    # A negative worker count means "use all available cores".
    if loader_cfg['num_workers'] < 0:
        loader_cfg['num_workers'] = cpu_count()

    dataloader = torch.utils.data.DataLoader(dataset, **loader_cfg)

    encoder = get_model_from_config(cfg, is_cli_call)

    embeddings, labels, filenames = encoder.embed(dataloader, device=device)

    if not is_cli_call:
        return embeddings, labels, filenames

    # CLI path: persist the embeddings and expose the path to later commands.
    path = os.path.join(os.getcwd(), 'embeddings.csv')
    save_embeddings(path, embeddings, labels, filenames)
    print(f'Embeddings are stored at {bcolors.OKBLUE}{path}{bcolors.ENDC}')
    os.environ[cfg['environment_variable_names']
               ['lightly_last_embedding_path']] = path
    return path
def _upload_cli(cfg, is_cli_call=True):
    """Upload a dataset, embeddings and/or custom metadata to the platform.

    Exactly one of cfg['dataset_id'] (attach to an existing dataset) or
    cfg['new_dataset_name'] (create a new one) must be provided, together
    with an access token.

    Args:
        cfg: Config with keys 'input_dir', 'embeddings', 'dataset_id',
            'token', 'new_dataset_name', 'custom_metadata', 'loader',
            'resize', 'upload', 'embedding_name' and
            'environment_variable_names'.
        is_cli_call: When True, user-supplied paths are normalized with
            fix_input_path.
    """
    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    path_to_embeddings = cfg['embeddings']
    if path_to_embeddings and is_cli_call:
        path_to_embeddings = fix_input_path(path_to_embeddings)

    dataset_id = cfg['dataset_id']
    token = cfg['token']
    new_dataset_name = cfg['new_dataset_name']

    # Collect all argument problems before aborting so the user sees every
    # warning in one run.
    cli_api_args_wrong = False
    if not token:
        print_as_warning('Please specify your access token.')
        cli_api_args_wrong = True

    if dataset_id:
        if new_dataset_name:
            # Both identifiers given: ambiguous, reject.
            print_as_warning(
                'Please specify either the dataset_id of an existing dataset '
                'or a new_dataset_name, but not both.')
            cli_api_args_wrong = True
        else:
            api_workflow_client = \
                ApiWorkflowClient(token=token, dataset_id=dataset_id)
    else:
        if new_dataset_name:
            api_workflow_client = ApiWorkflowClient(token=token)
            api_workflow_client.create_dataset(dataset_name=new_dataset_name)
        else:
            # Neither identifier given: reject.
            print_as_warning(
                'Please specify either the dataset_id of an existing dataset '
                'or a new_dataset_name.')
            cli_api_args_wrong = True

    # delete the dataset_id as it might be an empty string
    # Use api_workflow_client.dataset_id instead
    del dataset_id

    if cli_api_args_wrong:
        print_as_warning('For help, try: lightly-upload --help')
        return

    # potentially load custom metadata (JSON file keyed per sample)
    custom_metadata = None
    if cfg['custom_metadata']:
        path_to_custom_metadata = fix_input_path(cfg['custom_metadata'])
        print('Loading custom metadata from '
              f'{bcolors.OKBLUE}{path_to_custom_metadata}{bcolors.ENDC}')
        with open(path_to_custom_metadata, 'r') as f:
            custom_metadata = json.load(f)

    # set the number of workers if unset
    if cfg['loader']['num_workers'] < 0:
        # set the number of workers to the number of CPUs available,
        # but minimum of 8 and maximum of 32
        num_workers = max(8, cpu_count())
        num_workers = min(32, num_workers)
        cfg['loader']['num_workers'] = num_workers

    # A positive int resizes to a square; a sequence becomes a (h, w) tuple;
    # 0 or negative disables resizing.
    size = cfg['resize']
    if not isinstance(size, int):
        size = tuple(size)
    transform = None
    if isinstance(size, tuple) or size > 0:
        transform = torchvision.transforms.Resize(size)

    if input_dir:
        mode = cfg['upload']
        dataset = LightlyDataset(input_dir=input_dir, transform=transform)
        # custom_metadata (if any) is uploaded alongside the images here.
        api_workflow_client.upload_dataset(
            input=dataset,
            mode=mode,
            max_workers=cfg['loader']['num_workers'],
            custom_metadata=custom_metadata,
        )
        print('Finished the upload of the dataset.')

    if path_to_embeddings:
        name = cfg['embedding_name']
        print('Starting upload of embeddings.')
        api_workflow_client.upload_embeddings(
            path_to_embeddings_csv=path_to_embeddings, name=name)
        print('Finished upload of embeddings.')

    if custom_metadata is not None and not input_dir:
        # upload custom metadata separately (no image upload happened above)
        api_workflow_client.upload_custom_metadata(
            custom_metadata,
            verbose=True,
            max_workers=cfg['loader']['num_workers'],
        )

    if new_dataset_name:
        # Surface the freshly created dataset id and export it for
        # follow-up CLI commands.
        print(
            f'The dataset_id of the newly created dataset is '
            f'{bcolors.OKBLUE}{api_workflow_client.dataset_id}{bcolors.ENDC}')
        os.environ[cfg['environment_variable_names']
                   ['lightly_last_dataset_id']] = api_workflow_client.dataset_id
def _download_cli(cfg, is_cli_call=True):
    """Download the filenames in a tag and optionally the images themselves.

    Writes the filenames to ``<tag_name>.txt``. With only 'output_dir' set,
    the full images are downloaded from the API; with both 'input_dir' and
    'output_dir' set, the tagged files are copied locally. When
    cfg['exclude_parent_tag'] is true, samples already present in the
    parent tag are removed via a server-side tag difference.

    Args:
        cfg: Config with keys 'tag_name', 'dataset_id', 'token',
            'exclude_parent_tag', 'input_dir' and 'output_dir'.
        is_cli_call: Unused here; kept for signature parity with the
            other CLI entry points.
    """
    tag_name = cfg['tag_name']
    dataset_id = cfg['dataset_id']
    token = cfg['token']

    if not tag_name or not token or not dataset_id:
        print_as_warning(
            'Please specify all of the parameters tag_name, token and dataset_id'
        )
        print_as_warning('For help, try: lightly-download --help')
        return

    api_workflow_client = ApiWorkflowClient(
        token=token, dataset_id=dataset_id
    )

    # Resolve the tag name to its id.
    # NOTE(review): relies on the private helper _get_all_tags — a public
    # accessor would be preferable if the client exposes one.
    tag_name_id_dict = {
        tag.name: tag.id for tag in api_workflow_client._get_all_tags()
    }
    tag_id = tag_name_id_dict.get(tag_name)
    if tag_id is None:
        warnings.warn(f'The specified tag {tag_name} does not exist.')
        return

    # Fetch the tag payload (contains the sample bitmask).
    tag_data: TagData = api_workflow_client.tags_api.get_tag_by_tag_id(
        dataset_id=dataset_id, tag_id=tag_id)

    if cfg["exclude_parent_tag"]:
        # Server-side difference: keep only samples NOT in the parent tag.
        parent_tag_id = tag_data.prev_tag_id
        tag_arithmetics_request = TagArithmeticsRequest(
            tag_id1=tag_data.id,
            tag_id2=parent_tag_id,
            operation=TagArithmeticsOperation.DIFFERENCE)
        bit_mask_response: TagBitMaskResponse \
            = api_workflow_client.tags_api.perform_tag_arithmetics(
                body=tag_arithmetics_request, dataset_id=dataset_id)
        bit_mask_data = bit_mask_response.bit_mask_data
    else:
        bit_mask_data = tag_data.bit_mask_data

    # Decode the bitmask into indices and map them to server filenames.
    chosen_samples_ids = BitMask.from_hex(bit_mask_data).to_indices()
    samples = [
        api_workflow_client.filenames_on_server[i]
        for i in chosen_samples_ids
    ]

    # Store the sample names in a .txt file, one filename per line.
    filename = cfg['tag_name'] + '.txt'
    with open(filename, 'w') as f:
        f.writelines(f'{item}\n' for item in samples)

    filepath = os.path.join(os.getcwd(), filename)
    msg = f'The list of files in tag {cfg["tag_name"]} is stored at: {bcolors.OKBLUE}{filepath}{bcolors.ENDC}'
    print(msg, flush=True)

    if not cfg['input_dir'] and cfg['output_dir']:
        # No local source: download full images from the API.
        output_dir = fix_input_path(cfg['output_dir'])
        api_workflow_client.download_dataset(output_dir, tag_name=tag_name)
    elif cfg['input_dir'] and cfg['output_dir']:
        input_dir = fix_input_path(cfg['input_dir'])
        output_dir = fix_input_path(cfg['output_dir'])
        print(
            f'Copying files from {input_dir} to {bcolors.OKBLUE}{output_dir}{bcolors.ENDC}.'
        )
        # Copy only the tagged samples from the local input directory.
        dataset = data.LightlyDataset(input_dir=input_dir)
        dataset.dump(output_dir, samples)
def _train_cli(cfg, is_cli_call=True):
    """Train a self-supervised SimCLR embedding model on 'input_dir'.

    Builds a ResNet-based SimCLR model (optionally initialized from a
    checkpoint or a pretrained model-zoo entry), trains it with NTXentLoss
    and SGD, and returns the path of the best checkpoint.

    Args:
        cfg: Config with keys 'input_dir', 'seed', 'trainer', 'loader',
            'checkpoint', 'pre_trained', 'model', 'criterion', 'optimizer',
            'collate', 'checkpoint_callback' and
            'environment_variable_names'.
        is_cli_call: When True, user-supplied paths are normalized with
            fix_input_path.

    Returns:
        Path to the best checkpoint produced during training.
    """
    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    # Seed everything for reproducible training runs.
    if 'seed' in cfg.keys():
        seed = cfg['seed']
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # The CLI passes the string "None"; the trainer expects the real None.
    if cfg["trainer"]["weights_summary"] == "None":
        cfg["trainer"]["weights_summary"] = None

    # Pick the device; if GPUs were requested but none are available,
    # fall back to CPU and zero out the gpu count for the trainer.
    if torch.cuda.is_available():
        device = 'cuda'
    elif cfg['trainer'] and cfg['trainer']['gpus']:
        device = 'cpu'
        cfg['trainer']['gpus'] = 0
    else:
        device = 'cpu'

    # Multi-GPU training uses distributed data parallel.
    distributed_strategy = None
    if cfg['trainer']['gpus'] > 1:
        distributed_strategy = 'ddp'

    if cfg['loader']['batch_size'] < 64:
        msg = 'Training a self-supervised model with a small batch size: {}! '
        msg = msg.format(cfg['loader']['batch_size'])
        msg += 'Small batch size may harm embedding quality. '
        msg += 'You can specify the batch size via the loader key-word: '
        msg += 'loader.batch_size=BSZ'
        warnings.warn(msg)

    # determine the number of available cores (negative means "auto")
    if cfg['loader']['num_workers'] < 0:
        cfg['loader']['num_workers'] = cpu_count()

    state_dict = None
    checkpoint = cfg['checkpoint']
    if cfg['pre_trained'] and not checkpoint:
        # if checkpoint wasn't specified explicitly and pre_trained is True
        # try to load the checkpoint from the model zoo
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            msg = 'Cannot download checkpoint for key {} '.format(key)
            msg += 'because it does not exist! '
            msg += 'Model will be trained from scratch.'
            warnings.warn(msg)
    elif checkpoint:
        checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint

    if checkpoint:
        # load the PyTorch state dictionary and map it to the current device
        if is_url(checkpoint):
            state_dict = load_state_dict_from_url(
                checkpoint, map_location=device
            )['state_dict']
        else:
            state_dict = torch.load(
                checkpoint, map_location=device
            )['state_dict']

    # load model: ResNet backbone (without its final layer) + 1x1 conv
    # projection to num_ftrs + global average pooling.
    resnet = ResNetGenerator(cfg['model']['name'], cfg['model']['width'])
    last_conv_channels = list(resnet.children())[-1].in_features
    features = nn.Sequential(
        get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, cfg['model']['num_ftrs'], 1),
        nn.AdaptiveAvgPool2d(1),
    )

    model = _SimCLR(
        features,
        num_ftrs=cfg['model']['num_ftrs'],
        out_dim=cfg['model']['out_dim']
    )
    if state_dict is not None:
        load_from_state_dict(model, state_dict)

    criterion = NTXentLoss(**cfg['criterion'])
    optimizer = torch.optim.SGD(model.parameters(), **cfg['optimizer'])

    dataset = LightlyDataset(input_dir)

    # Never ask for batches larger than the dataset itself.
    cfg['loader']['batch_size'] = min(
        cfg['loader']['batch_size'],
        len(dataset)
    )

    collate_fn = ImageCollateFunction(**cfg['collate'])
    dataloader = torch.utils.data.DataLoader(dataset,
                                             **cfg['loader'],
                                             collate_fn=collate_fn)

    encoder = SelfSupervisedEmbedding(model, criterion, optimizer, dataloader)
    encoder.init_checkpoint_callback(**cfg['checkpoint_callback'])
    encoder.train_embedding(**cfg['trainer'],
                            strategy=distributed_strategy)

    print(
        f'Best model is stored at: {bcolors.OKBLUE}{encoder.checkpoint}{bcolors.ENDC}'
    )
    # Export the checkpoint path for follow-up CLI commands.
    os.environ[
        cfg['environment_variable_names']['lightly_last_checkpoint_path']
    ] = encoder.checkpoint
    return encoder.checkpoint
def _train_cli(cfg, is_cli_call=True):
    """Train a self-supervised SimCLR model and return its best checkpoint.

    Loads (or downloads) an optional initial checkpoint, builds a
    ResNetSimCLR model, and trains it with NTXentLoss and SGD on the
    dataset described by 'data'/'root'/'input_dir'.

    Args:
        cfg: Config with keys 'data', 'download', 'root', 'input_dir',
            'seed', 'trainer', 'loader', 'checkpoint', 'pre_trained',
            'model', 'criterion', 'optimizer', 'collate' and
            'checkpoint_callback'.
        is_cli_call: When True, user-supplied paths are normalized with
            fix_input_path.

    Returns:
        Path to the best checkpoint produced during training.
    """
    data = cfg['data']
    download = cfg['download']
    root = cfg['root']
    if root and is_cli_call:
        root = fix_input_path(root)

    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    # Seed everything for reproducible training runs.
    if 'seed' in cfg.keys():
        seed = cfg['seed']
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Pick the device; if GPUs were requested but none are available,
    # fall back to CPU and zero out the gpu count for the trainer.
    if torch.cuda.is_available():
        device = 'cuda'
    elif cfg['trainer'] and cfg['trainer']['gpus']:
        device = 'cpu'
        cfg['trainer']['gpus'] = 0
    else:
        # BUGFIX: this branch was missing — with no GPU available and no
        # gpus requested, `device` was unbound and checkpoint loading below
        # raised a NameError.
        device = 'cpu'

    if cfg['loader']['batch_size'] < 64:
        msg = 'Training a self-supervised model with a small batch size: {}! '
        msg = msg.format(cfg['loader']['batch_size'])
        msg += 'Small batch size may harm embedding quality. '
        msg += 'You can specify the batch size via the loader key-word: '
        msg += 'loader.batch_size=BSZ'
        warnings.warn(msg)

    state_dict = None
    checkpoint = cfg['checkpoint']
    if cfg['pre_trained'] and not checkpoint:
        # if checkpoint wasn't specified explicitly and pre_trained is True
        # try to load the checkpoint from the model zoo
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            msg = 'Cannot download checkpoint for key {} '.format(key)
            msg += 'because it does not exist! '
            msg += 'Model will be trained from scratch.'
            warnings.warn(msg)
    elif checkpoint:
        checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint

    if checkpoint:
        # load the PyTorch state dictionary and map it to the current device
        if is_url(checkpoint):
            state_dict = load_state_dict_from_url(
                checkpoint, map_location=device
            )['state_dict']
        else:
            state_dict = torch.load(
                checkpoint, map_location=device
            )['state_dict']

    # load model
    model = ResNetSimCLR(**cfg['model'])
    if state_dict is not None:
        model.load_from_state_dict(state_dict)

    criterion = NTXentLoss(**cfg['criterion'])
    optimizer = torch.optim.SGD(model.parameters(), **cfg['optimizer'])

    dataset = LightlyDataset(root,
                             name=data,
                             train=True,
                             download=download,
                             from_folder=input_dir)

    # Never ask for batches larger than the dataset itself.
    cfg['loader']['batch_size'] = min(
        cfg['loader']['batch_size'],
        len(dataset)
    )

    collate_fn = ImageCollateFunction(**cfg['collate'])
    dataloader = torch.utils.data.DataLoader(dataset,
                                             **cfg['loader'],
                                             collate_fn=collate_fn)

    encoder = SelfSupervisedEmbedding(model, criterion, optimizer, dataloader)
    encoder.init_checkpoint_callback(**cfg['checkpoint_callback'])
    encoder = encoder.train_embedding(**cfg['trainer'])

    print('Best model is stored at: %s' % (encoder.checkpoint))
    return encoder.checkpoint
def _embed_cli(cfg, is_cli_call=True):
    """Embed a dataset with a (possibly pretrained) ResNetSimCLR model.

    Loads model weights from cfg['checkpoint'] or, if unset, from the
    model zoo; embeds every image in the dataset; and, when called from
    the CLI, saves the embeddings to 'embeddings.csv'.

    Args:
        cfg: Config with keys 'data', 'train' (optional), 'checkpoint',
            'download', 'root', 'input_dir', 'collate', 'loader' and
            'model'.
        is_cli_call: When True, paths are normalized and the result is
            written to disk instead of returned.

    Returns:
        Path of the saved CSV when is_cli_call, otherwise the triple
        (embeddings, labels, filenames).

    Raises:
        RuntimeError: If no checkpoint was given and none could be found
            in the model zoo for the configured model.
    """
    data = cfg['data']
    train = cfg.get('train', True)
    checkpoint = cfg['checkpoint']
    download = cfg['download']

    root = cfg['root']
    if root and is_cli_call:
        root = fix_input_path(root)

    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    # Deterministic cuDNN behavior for reproducible embeddings.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Standard ImageNet preprocessing at the configured input size.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(
            (cfg['collate']['input_size'], cfg['collate']['input_size'])),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    ])

    dataset = LightlyDataset(root,
                             name=data,
                             train=train,
                             download=download,
                             from_folder=input_dir,
                             transform=transform)

    # Embedding must see every sample exactly once, in order.
    cfg['loader']['drop_last'] = False
    cfg['loader']['shuffle'] = False
    cfg['loader']['batch_size'] = min(cfg['loader']['batch_size'],
                                      len(dataset))
    dataloader = torch.utils.data.DataLoader(dataset, **cfg['loader'])

    # load the PyTorch state dictionary and map it to the current device
    state_dict = None
    if not checkpoint:
        # No explicit checkpoint: fall back to the model zoo.
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            msg = 'Cannot download checkpoint for key {} '.format(key)
            msg += 'because it does not exist!'
            raise RuntimeError(msg)
        state_dict = load_state_dict_from_url(
            checkpoint, map_location=device)['state_dict']
    else:
        checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint
        state_dict = torch.load(checkpoint, map_location=device)['state_dict']

    model = ResNetSimCLR(**cfg['model']).to(device)
    if state_dict is not None:
        model.load_from_state_dict(state_dict)

    # Only embedding is needed; criterion/optimizer/dataloader are unused.
    encoder = SelfSupervisedEmbedding(model, None, None, None)
    embeddings, labels, filenames = encoder.embed(dataloader, device=device)

    if is_cli_call:
        path = os.path.join(os.getcwd(), 'embeddings.csv')
        save_embeddings(path, embeddings, labels, filenames)
        print('Embeddings are stored at %s' % (path))
        return path

    return embeddings, labels, filenames
def _embed_cli(cfg, is_cli_call=True):
    """Embed all images in 'input_dir' with a SimCLR model.

    Loads model weights from cfg['checkpoint'] or, if unset, from the
    model zoo; embeds every image; and, when called from the CLI, saves
    the embeddings to 'embeddings.csv'.

    Args:
        cfg: Config with keys 'checkpoint', 'input_dir', 'collate',
            'loader' and 'model'.
        is_cli_call: When True, paths are normalized and the result is
            written to disk instead of returned.

    Returns:
        Path of the saved CSV when is_cli_call, otherwise the triple
        (embeddings, labels, filenames).

    Raises:
        RuntimeError: If no checkpoint was given and none could be found
            in the model zoo for the configured model.
    """
    checkpoint = cfg['checkpoint']

    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    # Deterministic cuDNN behavior for reproducible embeddings.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Standard ImageNet preprocessing at the configured input size.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((cfg['collate']['input_size'],
                                       cfg['collate']['input_size'])),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])
    ])

    dataset = LightlyDataset(input_dir, transform=transform)

    # Embedding must see every sample exactly once, in order.
    cfg['loader']['drop_last'] = False
    cfg['loader']['shuffle'] = False
    cfg['loader']['batch_size'] = min(
        cfg['loader']['batch_size'],
        len(dataset)
    )

    dataloader = torch.utils.data.DataLoader(dataset, **cfg['loader'])

    # load the PyTorch state dictionary and map it to the current device
    state_dict = None
    if not checkpoint:
        # No explicit checkpoint: fall back to the model zoo.
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            msg = 'Cannot download checkpoint for key {} '.format(key)
            msg += 'because it does not exist!'
            raise RuntimeError(msg)
        state_dict = load_state_dict_from_url(
            checkpoint, map_location=device
        )['state_dict']
    else:
        checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint
        state_dict = torch.load(
            checkpoint, map_location=device
        )['state_dict']

    # load model: ResNet backbone (without its final layer) + 1x1 conv
    # projection to num_ftrs + global average pooling.
    resnet = ResNetGenerator(cfg['model']['name'], cfg['model']['width'])
    last_conv_channels = list(resnet.children())[-1].in_features
    features = nn.Sequential(
        get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, cfg['model']['num_ftrs'], 1),
        nn.AdaptiveAvgPool2d(1),
    )

    model = SimCLR(
        features,
        num_ftrs=cfg['model']['num_ftrs'],
        out_dim=cfg['model']['out_dim']
    ).to(device)

    if state_dict is not None:
        load_from_state_dict(model, state_dict)

    # Only embedding is needed; criterion/optimizer/dataloader are unused.
    encoder = SelfSupervisedEmbedding(model, None, None, None)
    embeddings, labels, filenames = encoder.embed(dataloader, device=device)

    if is_cli_call:
        path = os.path.join(os.getcwd(), 'embeddings.csv')
        save_embeddings(path, embeddings, labels, filenames)
        print(f'Embeddings are stored at {bcolors.OKBLUE}{path}{bcolors.ENDC}')
        return path

    return embeddings, labels, filenames