Пример #1
0
def _download_cli(cfg, is_cli_call=True):
    """List the files of a tag in a text file and optionally export the images.

    The filenames belonging to the tag are written, one per line, to
    ``<tag_name>.txt``. If cfg['output_dir'] is set, the images are either
    downloaded through the API client (no cfg['input_dir']) or dumped from a
    dataset built on cfg['input_dir'].

    Args:
        cfg: Mapping providing 'tag_name', 'dataset_id', 'token',
            'input_dir' and 'output_dir'.
        is_cli_call: Not used by this function.

    Returns:
        None.
    """
    tag_name = cfg['tag_name']
    dataset_id = cfg['dataset_id']
    token = cfg['token']

    # guard clauses: refuse to continue without the mandatory arguments
    if not tag_name:
        print('Please specify a tag name')
        print('For help, try: lightly-download --help')
        return
    if not (token and dataset_id):
        print('Please specify your access token and dataset id')
        print('For help, try: lightly-download --help')
        return

    client = ApiWorkflowClient(token=token, dataset_id=dataset_id)

    # resolve the tag name to its id
    ids_by_name = {tag.name: tag.id for tag in client._get_all_tags()}
    tag_id = ids_by_name.get(tag_name)
    if tag_id is None:
        print(f'The specified tag {tag_name} does not exist.')
        return

    # fetch the tag and decode its bit mask into sample indices
    tag_data = client.tags_api.get_tag_by_tag_id(dataset_id=dataset_id,
                                                 tag_id=tag_id)
    selected = BitMask.from_hex(tag_data.bit_mask_data).to_indices()
    samples = [client.filenames_on_server[index] for index in selected]

    # persist the filenames, one per line
    list_file = cfg['tag_name'] + '.txt'
    with open(list_file, 'w') as f:
        f.writelines("%s\n" % name for name in samples)

    print(
        'The list of files in tag {} is stored at: '.format(cfg['tag_name'])
        + os.path.join(os.getcwd(), list_file),
        flush=True,
    )

    wants_input = cfg['input_dir']
    wants_output = cfg['output_dir']
    if not wants_output:
        # nothing to export
        return

    if not wants_input:
        # no local copy given: download the full images from the api
        output_dir = fix_input_path(cfg['output_dir'])
        client.download_dataset(output_dir, tag_name=tag_name)
        return

    input_dir = fix_input_path(cfg['input_dir'])
    output_dir = fix_input_path(cfg['output_dir'])
    print(f'Copying files from {input_dir} to {output_dir}.')

    # build a dataset on the input directory and dump the tagged samples
    local_dataset = data.LightlyDataset(input_dir=input_dir)
    local_dataset.dump(output_dir, samples)
Пример #2
0
def _upload_cli(cfg, is_cli_call=True):
    """Upload a dataset and/or embeddings through the API workflow client.

    cfg['token'] must be set together with exactly one of cfg['dataset_id']
    (use an existing dataset) and cfg['new_dataset_name'] (create a dataset
    first). All of these are validated before any API client is created, so
    a call with incomplete arguments only prints warnings and returns.

    Args:
        cfg: Mapping providing 'input_dir', 'embeddings', 'dataset_id',
            'token', 'new_dataset_name' and 'resize'; 'upload' and
            cfg['loader']['num_workers'] are read when images are uploaded,
            'embedding_name' when embeddings are uploaded.
        is_cli_call: If True, 'input_dir' and 'embeddings' are passed
            through fix_input_path before use.

    Returns:
        None.
    """
    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    path_to_embeddings = cfg['embeddings']
    if path_to_embeddings and is_cli_call:
        path_to_embeddings = fix_input_path(path_to_embeddings)

    dataset_id = cfg['dataset_id']
    token = cfg['token']
    new_dataset_name = cfg['new_dataset_name']

    # Validate everything first: creating the client (and possibly a new
    # dataset) with a missing token would hit the API before the help text
    # could be shown.
    cli_api_args_wrong = False
    if not token:
        print_as_warning('Please specify your access token.')
        cli_api_args_wrong = True

    dataset_id_ok = bool(dataset_id)
    new_dataset_name_ok = bool(new_dataset_name)
    if dataset_id_ok == new_dataset_name_ok:
        # either both or neither were given
        print_as_warning('Please specify either the dataset_id of an existing dataset or a new_dataset_name.')
        cli_api_args_wrong = True

    if cli_api_args_wrong:
        print_as_warning('For help, try: lightly-upload --help')
        return

    if new_dataset_name_ok:
        api_workflow_client = ApiWorkflowClient(token=token)
        api_workflow_client.create_dataset(dataset_name=new_dataset_name)
    else:
        api_workflow_client = ApiWorkflowClient(token=token, dataset_id=dataset_id)

    # an int <= 0 disables resizing; any non-int value is used as a tuple
    size = cfg['resize']
    if not isinstance(size, int):
        size = tuple(size)
    transform = None
    if isinstance(size, tuple) or size > 0:
        transform = torchvision.transforms.Resize(size)

    if input_dir:
        mode = cfg['upload']
        dataset = LightlyDataset(input_dir=input_dir, transform=transform)
        api_workflow_client.upload_dataset(
            input=dataset, mode=mode, max_workers=cfg['loader']['num_workers']
        )
        print('Finished the upload of the dataset.')

    if path_to_embeddings:
        name = cfg['embedding_name']
        print("Starting upload of embeddings.")
        api_workflow_client.upload_embeddings(
            path_to_embeddings_csv=path_to_embeddings, name=name
        )
        print("Finished upload of embeddings.")

    if new_dataset_name_ok:
        print(f'The dataset_id of the newly created dataset is '
              f'{bcolors.OKBLUE}{api_workflow_client.dataset_id}{bcolors.ENDC}')
Пример #3
0
def _download_cli(cfg, is_cli_call=True):
    """Store the filenames of a tag in a text file and optionally copy them.

    The filenames returned by get_samples_by_tag are written, one per line,
    to ``<tag_name>.txt``. If both cfg['input_dir'] and cfg['output_dir']
    are set, every file of the input dataset whose path ends with one of
    the tagged filenames is copied to the same relative location below the
    output directory.

    Args:
        cfg: Mapping providing 'tag_name', 'dataset_id', 'token',
            'input_dir' and 'output_dir'.
        is_cli_call: Not used by this function.

    Returns:
        None.
    """
    tag_name = cfg['tag_name']
    dataset_id = cfg['dataset_id']
    token = cfg['token']

    if not tag_name:
        print('Please specify a tag name')
        print('For help, try: lightly-download --help')
        return
    if not (token and dataset_id):
        print('Please specify your access token and dataset id')
        print('For help, try: lightly-download --help')
        return

    # query the filenames of all samples in the tag
    samples = get_samples_by_tag(
        tag_name, dataset_id, token, mode='list', filenames=None)

    # write them to <tag_name>.txt, one per line
    list_file = cfg['tag_name'] + '.txt'
    with open(list_file, 'w') as f:
        f.writelines("%s\n" % name for name in samples)
    print('The list of files in tag {} is stored at: '.format(cfg['tag_name'])
          + os.path.join(os.getcwd(), list_file))

    if not (cfg['input_dir'] and cfg['output_dir']):
        return

    # Prefix every name with the path separator ("name.jpg" -> "/name.jpg")
    # so that suffix matching cannot confuse "234.jpg" with "1234.jpg".
    suffixes = tuple(os.path.join(' ', name)[1:] for name in samples)

    input_dir = fix_input_path(cfg['input_dir'])
    output_dir = fix_input_path(cfg['output_dir'])

    local_dataset = data.LightlyDataset(from_folder=input_dir)
    path_pairs = [
        (os.path.join(input_dir, name), os.path.join(output_dir, name))
        for name in local_dataset.get_filenames()
    ]

    # keep only the files which belong to the tag
    to_copy = [pair for pair in path_pairs if pair[0].endswith(suffixes)]

    print(f'Copying files from {input_dir} to {output_dir}.')
    for source, target in tqdm(to_copy):
        os.makedirs(os.path.dirname(target), exist_ok=True)
        shutil.copy(source, target)
Пример #4
0
def _upload_cli(cfg, is_cli_call=True):
    """Upload images and/or embeddings to an existing or a new dataset.

    A warning is issued and nothing is uploaded unless cfg['token'] is set
    and exactly one of cfg['dataset_id'] and cfg['new_dataset_name'] is
    given.

    Args:
        cfg: Mapping providing 'input_dir', 'embeddings', 'dataset_id',
            'token', 'new_dataset_name' and 'resize'; 'upload' and
            cfg['loader']['num_workers'] are read when images are uploaded,
            'embedding_name' when embeddings are uploaded.
        is_cli_call: If True, the two path options are passed through
            fix_input_path.

    Returns:
        None.
    """
    def _path_option(key):
        # read a path option; CLI invocations run it through fix_input_path
        value = cfg[key]
        if value and is_cli_call:
            return fix_input_path(value)
        return value

    input_dir = _path_option('input_dir')
    path_to_embeddings = _path_option('embeddings')

    dataset_id = cfg['dataset_id']
    token = cfg['token']
    new_dataset_name = cfg['new_dataset_name']

    if not token:
        warnings.warn('Please specify your access token. For help, try: lightly-upload --help')
        return

    has_dataset_id = dataset_id and len(dataset_id) > 0
    has_new_name = new_dataset_name and len(new_dataset_name) > 0
    if bool(has_dataset_id) == bool(has_new_name):
        # both or neither given: the target dataset is ambiguous
        warnings.warn('Please specify either the dataset_id of an existing dataset or a new_dataset_name. '
                      'For help, try: lightly-upload --help')
        return

    if has_new_name:
        client = ApiWorkflowClient(token=token)
        client.create_dataset(dataset_name=new_dataset_name)
    else:
        client = ApiWorkflowClient(token=token, dataset_id=dataset_id)

    # an int <= 0 means "do not resize"; non-int values become a tuple
    size = cfg['resize']
    if not isinstance(size, int):
        size = tuple(size)
    needs_resize = isinstance(size, tuple) or size > 0
    transform = torchvision.transforms.Resize(size) if needs_resize else None

    if input_dir:
        dataset = LightlyDataset(input_dir=input_dir, transform=transform)
        client.upload_dataset(
            input=dataset,
            mode=cfg['upload'],
            max_workers=cfg['loader']['num_workers'],
        )

    if path_to_embeddings:
        client.upload_embeddings(
            path_to_embeddings_csv=path_to_embeddings,
            name=cfg['embedding_name'],
        )
Пример #5
0
def _download_cli(cfg, is_cli_call=True):
    """List the files of a tag in a text file and optionally export the images.

    The filenames of the tag are written, one per line, to
    ``<tag_name>.txt``. If cfg['output_dir'] is set, the images are either
    downloaded through the API client (no cfg['input_dir']) or dumped from a
    dataset built on cfg['input_dir'].

    Args:
        cfg: Mapping providing 'tag_name', 'dataset_id', 'token',
            'exclude_parent_tag', 'input_dir' and 'output_dir'.
        is_cli_call: Not used by this function.

    Returns:
        None.
    """
    def _as_str(value):
        # str(None) is the truthy string 'None' and would slip past the
        # missing-parameter check below, so unset values map to ''.
        return '' if value is None else str(value)

    tag_name = _as_str(cfg['tag_name'])
    dataset_id = _as_str(cfg['dataset_id'])
    token = _as_str(cfg['token'])

    if not tag_name or not token or not dataset_id:
        print_as_warning('Please specify all of the parameters tag_name, token and dataset_id')
        print_as_warning('For help, try: lightly-download --help')
        return

    api_workflow_client = ApiWorkflowClient(
        token=token, dataset_id=dataset_id
    )

    # look up the tag and the filenames it contains
    tag_data = api_workflow_client.get_tag_by_name(tag_name)
    filenames_tag = api_workflow_client.get_filenames_in_tag(
        tag_data,
        exclude_parent_tag=cfg['exclude_parent_tag'],
    )

    # store sample names in a .txt file
    filename = tag_name + '.txt'
    with open(filename, 'w') as f:
        for item in filenames_tag:
            f.write("%s\n" % item)

    filepath = os.path.join(os.getcwd(), filename)
    msg = f'The list of files in tag {tag_name} is stored at: {bcolors.OKBLUE}{filepath}{bcolors.ENDC}'
    print(msg, flush=True)

    if not cfg['input_dir'] and cfg['output_dir']:
        # download full images from api
        output_dir = fix_input_path(cfg['output_dir'])
        api_workflow_client.download_dataset(output_dir, tag_name=tag_name)

    elif cfg['input_dir'] and cfg['output_dir']:
        input_dir = fix_input_path(cfg['input_dir'])
        output_dir = fix_input_path(cfg['output_dir'])
        print(f'Copying files from {input_dir} to {bcolors.OKBLUE}{output_dir}{bcolors.ENDC}.')

        # create a dataset from the input directory
        dataset = data.LightlyDataset(input_dir=input_dir)

        # dump the dataset in the output directory
        dataset.dump(output_dir, filenames_tag)
Пример #6
0
def _download_cli(cfg, is_cli_call=True):
    """Store the filenames of a tag in a text file and optionally dump them.

    The filenames returned by get_samples_by_tag are written, one per line,
    to ``<tag_name>.txt``. If both cfg['input_dir'] and cfg['output_dir']
    are set, a dataset is built on the input directory and the tagged
    samples are dumped into the output directory.

    Args:
        cfg: Mapping providing 'tag_name', 'dataset_id', 'token',
            'input_dir' and 'output_dir'.
        is_cli_call: Not used by this function.

    Returns:
        None.
    """
    tag_name = cfg['tag_name']
    dataset_id = cfg['dataset_id']
    token = cfg['token']

    if not tag_name:
        print('Please specify a tag name')
        print('For help, try: lightly-download --help')
        return
    if not (token and dataset_id):
        print('Please specify your access token and dataset id')
        print('For help, try: lightly-download --help')
        return

    # query the filenames of all samples in the tag
    samples = get_samples_by_tag(
        tag_name, dataset_id, token, mode='list', filenames=None)

    # write them to <tag_name>.txt, one per line
    list_file = cfg['tag_name'] + '.txt'
    with open(list_file, 'w') as f:
        f.writelines("%s\n" % name for name in samples)

    print(
        'The list of files in tag {} is stored at: '.format(cfg['tag_name'])
        + os.path.join(os.getcwd(), list_file),
        flush=True,
    )

    if not (cfg['input_dir'] and cfg['output_dir']):
        return

    input_dir = fix_input_path(cfg['input_dir'])
    output_dir = fix_input_path(cfg['output_dir'])
    print(f'Copying files from {input_dir} to {output_dir}.')

    # build a dataset on the input directory and dump the tagged samples
    local_dataset = data.LightlyDataset(from_folder=input_dir)
    local_dataset.dump(output_dir, samples)
Пример #7
0
def _crop_cli(cfg, is_cli_call=True):
    """Crop a dataset by the bounding boxes of its YOLO label files.

    For every image of the dataset built on cfg['input_dir'], the label file
    with the same relative path but a ``.txt`` extension is read from
    cfg['label_dir']. The resulting boxes and class indices are handed to
    crop_dataset_by_bounding_boxes_and_save together with cfg['output_dir'].

    Args:
        cfg: Mapping providing 'input_dir', 'output_dir', 'label_dir',
            'label_names_file' and 'crop_padding'.
        is_cli_call: If True, all path options are passed through
            fix_input_path before use.

    Returns:
        Whatever crop_dataset_by_bounding_boxes_and_save returns.
    """
    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)
    output_dir = cfg['output_dir']
    if output_dir and is_cli_call:
        output_dir = fix_input_path(output_dir)
    label_dir = cfg['label_dir']
    if label_dir and is_cli_call:
        label_dir = fix_input_path(label_dir)

    # optional yaml file whose 'names' entry holds the class names
    class_names = None
    label_names_file = cfg['label_names_file']
    if label_names_file:
        if is_cli_call:
            label_names_file = fix_input_path(label_names_file)
        # NOTE(review): yaml.safe_load would be the stricter choice if this
        # file can come from an untrusted source - confirm before changing.
        with open(label_names_file, 'r') as names_file:
            label_names_file_dict = yaml.full_load(names_file)
        class_names = label_names_file_dict['names']

    dataset = LightlyDataset(input_dir)

    class_indices_list_list: List[List[int]] = []
    bounding_boxes_list_list: List[List[BoundingBox]] = []

    # YOLO-Specific
    for filename_image in dataset.get_filenames():
        # Swap only the trailing extension for '.txt'. A str.replace on the
        # joined path would also rewrite the same text elsewhere in the path
        # and would break for images without any extension.
        filepath_image_base, _ = os.path.splitext(filename_image)
        filepath_label = os.path.join(label_dir, filepath_image_base + '.txt')
        class_indices, bounding_boxes = read_yolo_label_file(
            filepath_label, float(cfg['crop_padding']))
        class_indices_list_list.append(class_indices)
        bounding_boxes_list_list.append(bounding_boxes)

    cropped_images_list_list = crop_dataset_by_bounding_boxes_and_save(
        dataset,
        output_dir,
        bounding_boxes_list_list,
        class_indices_list_list,
        class_names,
    )

    print(
        f'Cropped images are stored at: {bcolors.OKBLUE}{output_dir}{bcolors.ENDC}'
    )
    return cropped_images_list_list
Пример #8
0
def _upload_cli(cfg, is_cli_call=True):
    """Upload images and/or embeddings to an existing dataset.

    Args:
        cfg: Mapping providing 'input_dir', 'embeddings', 'dataset_id',
            'token' and 'resize'; 'upload' is read when images are uploaded,
            'emb_upload_bsz' and 'embedding_name' when embeddings are
            uploaded.
        is_cli_call: If True, 'input_dir' and 'embeddings' are passed
            through fix_input_path before use.

    Returns:
        None.

    Raises:
        SystemExit: With status 1 if the image upload raises a ValueError
            or ConnectionRefusedError.
    """
    # local import keeps the block self-contained
    import sys

    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    path_to_embeddings = cfg['embeddings']
    if path_to_embeddings and is_cli_call:
        path_to_embeddings = fix_input_path(path_to_embeddings)

    dataset_id = cfg['dataset_id']
    token = cfg['token']

    # non-int sizes are converted to a tuple
    size = cfg['resize']
    if not isinstance(size, int):
        size = tuple(size)

    if not token or not dataset_id:
        print('Please specify your access token and dataset id.')
        print('For help, try: lightly-upload --help')
        return

    if input_dir:
        mode = cfg['upload']
        try:
            upload_images_from_folder(input_dir,
                                      dataset_id,
                                      token,
                                      mode=mode,
                                      size=size)
        except (ValueError, ConnectionRefusedError) as error:
            msg = f'Error: {error}'
            print(msg)
            # Exit with a non-zero status so that the calling shell or
            # script can detect the failed upload; sys.exit is used because
            # the bare exit() helper is only injected by the site module.
            sys.exit(1)

    if path_to_embeddings:
        max_upload = cfg['emb_upload_bsz']
        upload_embeddings_from_csv(path_to_embeddings,
                                   dataset_id,
                                   token,
                                   max_upload=max_upload,
                                   embedding_name=cfg['embedding_name'])
Пример #9
0
def _upload_cli(cfg, is_cli_call=True):
    """Upload images and/or embeddings to an existing dataset.

    Args:
        cfg: Mapping providing 'input_dir', 'embeddings', 'dataset_id',
            'token' and 'resize'; 'upload' is read when images are uploaded,
            'embedding_name' when embeddings are uploaded.
        is_cli_call: If True, the two path options are passed through
            fix_input_path.

    Returns:
        None.
    """
    def _path_option(key):
        # read a path option; CLI invocations run it through fix_input_path
        value = cfg[key]
        if value and is_cli_call:
            return fix_input_path(value)
        return value

    input_dir = _path_option('input_dir')
    path_to_embeddings = _path_option('embeddings')

    dataset_id = cfg['dataset_id']
    token = cfg['token']

    # an int <= 0 means "do not resize"; non-int values become a tuple
    size = cfg['resize']
    if not isinstance(size, int):
        size = tuple(size)
    needs_resize = isinstance(size, tuple) or size > 0
    transform = torchvision.transforms.Resize(size) if needs_resize else None

    if not (token and dataset_id):
        print('Please specify your access token and dataset id.')
        print('For help, try: lightly-upload --help')
        return

    client = ApiWorkflowClient(token=token, dataset_id=dataset_id)

    if input_dir:
        dataset = LightlyDataset(input_dir=input_dir, transform=transform)
        client.upload_dataset(input=dataset, mode=cfg['upload'])

    if path_to_embeddings:
        client.upload_embeddings(
            path_to_embeddings_csv=path_to_embeddings,
            name=cfg['embedding_name'],
        )
Пример #10
0
def _embed_cli(cfg, is_cli_call=True):
    """Embed all images of a directory with the model described by cfg.

    Args:
        cfg: Mapping providing 'input_dir', cfg['collate']['input_size'],
            the DataLoader keyword arguments in 'loader' and, for CLI calls,
            cfg['environment_variable_names']['lightly_last_embedding_path'].
            cfg['loader'] is modified in place.
        is_cli_call: If True, the input directory is passed through
            fix_input_path and the embeddings are saved to
            ``embeddings.csv`` in the current working directory.

    Returns:
        The path of the written csv file for CLI calls, otherwise the tuple
        (embeddings, labels, filenames) produced by the encoder.
    """
    input_dir = cfg['input_dir']
    if is_cli_call and input_dir:
        input_dir = fix_input_path(input_dir)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # resize to a square input, convert to tensor and normalize
    side = cfg['collate']['input_size']
    steps = [
        torchvision.transforms.Resize((side, side)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ]
    transform = torchvision.transforms.Compose(steps)

    dataset = LightlyDataset(input_dir, transform=transform)

    loader_cfg = cfg['loader']
    # every sample must be embedded exactly once and in dataset order
    loader_cfg['drop_last'] = False
    loader_cfg['shuffle'] = False
    # the batch size may not exceed the dataset size
    loader_cfg['batch_size'] = min(loader_cfg['batch_size'], len(dataset))
    # a negative worker count selects all available cores
    if loader_cfg['num_workers'] < 0:
        loader_cfg['num_workers'] = cpu_count()

    dataloader = torch.utils.data.DataLoader(dataset, **cfg['loader'])

    encoder = get_model_from_config(cfg, is_cli_call)
    embeddings, labels, filenames = encoder.embed(dataloader, device=device)

    if not is_cli_call:
        return embeddings, labels, filenames

    # CLI call: persist the embeddings and publish the csv location
    path = os.path.join(os.getcwd(), 'embeddings.csv')
    save_embeddings(path, embeddings, labels, filenames)
    print(f'Embeddings are stored at {bcolors.OKBLUE}{path}{bcolors.ENDC}')
    env_names = cfg['environment_variable_names']
    os.environ[env_names['lightly_last_embedding_path']] = path
    return path
Пример #11
0
def _upload_cli(cfg, is_cli_call=True):
    """Upload images, embeddings and/or custom metadata to a dataset.

    cfg['token'] must be set together with exactly one of cfg['dataset_id']
    (use an existing dataset) and cfg['new_dataset_name'] (create a dataset
    first). All of these are validated before any API client is created, so
    a call with incomplete arguments only prints warnings and returns.
    After the uploads, the id of the dataset is stored in the environment
    variable named by
    cfg['environment_variable_names']['lightly_last_dataset_id'].

    Args:
        cfg: Mapping providing 'input_dir', 'embeddings', 'dataset_id',
            'token', 'new_dataset_name', 'custom_metadata', 'resize',
            cfg['loader']['num_workers'] and 'environment_variable_names';
            'upload' is read when images are uploaded, 'embedding_name'
            when embeddings are uploaded. cfg['loader']['num_workers'] is
            overwritten in place when it is negative.
        is_cli_call: If True, 'input_dir' and 'embeddings' are passed
            through fix_input_path before use.

    Returns:
        None.
    """
    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    path_to_embeddings = cfg['embeddings']
    if path_to_embeddings and is_cli_call:
        path_to_embeddings = fix_input_path(path_to_embeddings)

    dataset_id = cfg['dataset_id']
    token = cfg['token']
    new_dataset_name = cfg['new_dataset_name']

    # Validate everything first: creating the client (and possibly a new
    # dataset) with a missing token would hit the API before the help text
    # could be shown.
    cli_api_args_wrong = False
    if not token:
        print_as_warning('Please specify your access token.')
        cli_api_args_wrong = True

    if dataset_id and new_dataset_name:
        print_as_warning(
            'Please specify either the dataset_id of an existing dataset '
            'or a new_dataset_name, but not both.')
        cli_api_args_wrong = True
    elif not dataset_id and not new_dataset_name:
        print_as_warning(
            'Please specify either the dataset_id of an existing dataset '
            'or a new_dataset_name.')
        cli_api_args_wrong = True

    if cli_api_args_wrong:
        print_as_warning('For help, try: lightly-upload --help')
        return

    if dataset_id:
        api_workflow_client = \
            ApiWorkflowClient(token=token, dataset_id=dataset_id)
    else:
        api_workflow_client = ApiWorkflowClient(token=token)
        api_workflow_client.create_dataset(dataset_name=new_dataset_name)
    # delete the dataset_id as it might be an empty string
    # Use api_workflow_client.dataset_id instead
    del dataset_id

    # potentially load custom metadata
    custom_metadata = None
    if cfg['custom_metadata']:
        path_to_custom_metadata = fix_input_path(cfg['custom_metadata'])
        print('Loading custom metadata from '
              f'{bcolors.OKBLUE}{path_to_custom_metadata}{bcolors.ENDC}')
        with open(path_to_custom_metadata, 'r') as f:
            custom_metadata = json.load(f)

    # set the number of workers if unset
    if cfg['loader']['num_workers'] < 0:
        # use the number of available CPUs, clamped to the range [8, 32]
        cfg['loader']['num_workers'] = min(32, max(8, cpu_count()))

    # an int <= 0 disables resizing; any non-int value is used as a tuple
    size = cfg['resize']
    if not isinstance(size, int):
        size = tuple(size)
    transform = None
    if isinstance(size, tuple) or size > 0:
        transform = torchvision.transforms.Resize(size)

    if input_dir:
        mode = cfg['upload']
        dataset = LightlyDataset(input_dir=input_dir, transform=transform)
        api_workflow_client.upload_dataset(
            input=dataset,
            mode=mode,
            max_workers=cfg['loader']['num_workers'],
            custom_metadata=custom_metadata,
        )
        print('Finished the upload of the dataset.')

    if path_to_embeddings:
        name = cfg['embedding_name']
        print('Starting upload of embeddings.')
        api_workflow_client.upload_embeddings(
            path_to_embeddings_csv=path_to_embeddings, name=name)
        print('Finished upload of embeddings.')

    if custom_metadata is not None and not input_dir:
        # no image upload carried the metadata along: upload it separately
        api_workflow_client.upload_custom_metadata(
            custom_metadata,
            verbose=True,
            max_workers=cfg['loader']['num_workers'],
        )

    if new_dataset_name:
        print(
            f'The dataset_id of the newly created dataset is '
            f'{bcolors.OKBLUE}{api_workflow_client.dataset_id}{bcolors.ENDC}')

    os.environ[cfg['environment_variable_names']
               ['lightly_last_dataset_id']] = api_workflow_client.dataset_id
Пример #12
0
def _download_cli(cfg, is_cli_call=True):
    """List the files of a tag in a text file and optionally export the images.

    The filenames of the tag are written, one per line, to
    ``<tag_name>.txt``. With cfg['exclude_parent_tag'] set, the bit mask of
    the tag is replaced by the DIFFERENCE between the tag and its previous
    tag. If cfg['output_dir'] is set, the images are either downloaded
    through the API client (no cfg['input_dir']) or dumped from a dataset
    built on cfg['input_dir'].

    Args:
        cfg: Mapping providing 'tag_name', 'dataset_id', 'token',
            'exclude_parent_tag', 'input_dir' and 'output_dir'.
        is_cli_call: Not used by this function.

    Returns:
        None.
    """
    tag_name = cfg['tag_name']
    dataset_id = cfg['dataset_id']
    token = cfg['token']

    if not (tag_name and token and dataset_id):
        print_as_warning(
            'Please specify all of the parameters tag_name, token and dataset_id'
        )
        print_as_warning('For help, try: lightly-download --help')
        return

    client = ApiWorkflowClient(token=token, dataset_id=dataset_id)

    # resolve the tag name to its id
    ids_by_name = {tag.name: tag.id for tag in client._get_all_tags()}
    tag_id = ids_by_name.get(tag_name)
    if tag_id is None:
        warnings.warn(f'The specified tag {tag_name} does not exist.')
        return

    tag_data: TagData = client.tags_api.get_tag_by_tag_id(
        dataset_id=dataset_id, tag_id=tag_id)

    if not cfg["exclude_parent_tag"]:
        bit_mask_data = tag_data.bit_mask_data
    else:
        # keep only what the tag adds on top of its previous tag
        difference_request = TagArithmeticsRequest(
            tag_id1=tag_data.id,
            tag_id2=tag_data.prev_tag_id,
            operation=TagArithmeticsOperation.DIFFERENCE)
        difference: TagBitMaskResponse = \
            client.tags_api.perform_tag_arithmetics(
                body=difference_request, dataset_id=dataset_id)
        bit_mask_data = difference.bit_mask_data

    # decode the bit mask into sample indices and map them to filenames
    selected = BitMask.from_hex(bit_mask_data).to_indices()
    samples = [client.filenames_on_server[index] for index in selected]

    # persist the filenames, one per line
    filename = cfg['tag_name'] + '.txt'
    with open(filename, 'w') as f:
        f.writelines("%s\n" % name for name in samples)

    filepath = os.path.join(os.getcwd(), filename)
    print(
        f'The list of files in tag {cfg["tag_name"]} is stored at: {bcolors.OKBLUE}{filepath}{bcolors.ENDC}',
        flush=True,
    )

    wants_input = cfg['input_dir']
    wants_output = cfg['output_dir']
    if not wants_output:
        # nothing to export
        return

    if not wants_input:
        # no local copy given: download the full images from the api
        output_dir = fix_input_path(cfg['output_dir'])
        client.download_dataset(output_dir, tag_name=tag_name)
        return

    input_dir = fix_input_path(cfg['input_dir'])
    output_dir = fix_input_path(cfg['output_dir'])
    print(
        f'Copying files from {input_dir} to {bcolors.OKBLUE}{output_dir}{bcolors.ENDC}.'
    )

    # build a dataset on the input directory and dump the tagged samples
    local_dataset = data.LightlyDataset(input_dir=input_dir)
    local_dataset.dump(output_dir, samples)
Пример #13
0
def _train_cli(cfg, is_cli_call=True):
    """Train a self-supervised model on the images of a directory.

    A `_SimCLR` model is built on top of a `ResNetGenerator` backbone,
    optionally initialised from a checkpoint, and trained through
    `SelfSupervisedEmbedding` with an `NTXentLoss` criterion and an SGD
    optimizer. The checkpoint path reported by the encoder is printed,
    stored in the environment variable named by
    cfg['environment_variable_names']['lightly_last_checkpoint_path'] and
    returned.

    Args:
        cfg: Mapping providing 'input_dir', optionally 'seed', 'trainer',
            'loader', 'checkpoint', 'pre_trained', 'model', 'criterion',
            'optimizer', 'collate', 'checkpoint_callback' and
            'environment_variable_names'. Several entries of cfg['trainer']
            and cfg['loader'] are modified in place.
        is_cli_call: If True, 'input_dir' and an explicitly given checkpoint
            path are passed through fix_input_path.

    Returns:
        The value of encoder.checkpoint after training.
    """

    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    # seed torch and switch cudnn to deterministic mode when a seed is given
    if 'seed' in cfg.keys():
        seed = cfg['seed']
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # the string "None" from the config stands for the None object
    if cfg["trainer"]["weights_summary"] == "None":
        cfg["trainer"]["weights_summary"] = None

    if torch.cuda.is_available():
        device = 'cuda'
    elif cfg['trainer'] and cfg['trainer']['gpus']:
        # gpus were requested but cuda is unavailable: fall back to the cpu
        device = 'cpu'
        cfg['trainer']['gpus'] = 0
    else:
        device = 'cpu'

    # more than one gpu selects the 'ddp' strategy
    distributed_strategy = None
    if cfg['trainer']['gpus'] > 1:
        distributed_strategy = 'ddp'

    if cfg['loader']['batch_size'] < 64:
        msg = 'Training a self-supervised model with a small batch size: {}! '
        msg = msg.format(cfg['loader']['batch_size'])
        msg += 'Small batch size may harm embedding quality. '
        msg += 'You can specify the batch size via the loader key-word: '
        msg += 'loader.batch_size=BSZ'
        warnings.warn(msg)

    # determine the number of available cores
    if cfg['loader']['num_workers'] < 0:
        cfg['loader']['num_workers'] = cpu_count()

    state_dict = None
    checkpoint = cfg['checkpoint']
    if cfg['pre_trained'] and not checkpoint:
        # if checkpoint wasn't specified explicitly and pre_trained is True
        # try to load the checkpoint from the model zoo
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            msg = 'Cannot download checkpoint for key {} '.format(key)
            msg += 'because it does not exist! '
            msg += 'Model will be trained from scratch.'
            warnings.warn(msg)
    elif checkpoint:
        checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint

    if checkpoint:
        # load the PyTorch state dictionary and map it to the current device
        if is_url(checkpoint):
            state_dict = load_state_dict_from_url(
                checkpoint, map_location=device)['state_dict']
        else:
            state_dict = torch.load(checkpoint,
                                    map_location=device)['state_dict']

    # load model
    resnet = ResNetGenerator(cfg['model']['name'], cfg['model']['width'])
    # the last child of the backbone is dropped below; its in_features give
    # the channel count feeding the 1x1 convolution that replaces it
    last_conv_channels = list(resnet.children())[-1].in_features
    features = nn.Sequential(
        get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, cfg['model']['num_ftrs'], 1),
        nn.AdaptiveAvgPool2d(1),
    )

    model = _SimCLR(features,
                    num_ftrs=cfg['model']['num_ftrs'],
                    out_dim=cfg['model']['out_dim'])
    if state_dict is not None:
        load_from_state_dict(model, state_dict)

    criterion = NTXentLoss(**cfg['criterion'])
    optimizer = torch.optim.SGD(model.parameters(), **cfg['optimizer'])

    dataset = LightlyDataset(input_dir)

    # the batch size may not exceed the dataset size
    cfg['loader']['batch_size'] = min(cfg['loader']['batch_size'],
                                      len(dataset))

    collate_fn = ImageCollateFunction(**cfg['collate'])
    dataloader = torch.utils.data.DataLoader(dataset,
                                             **cfg['loader'],
                                             collate_fn=collate_fn)

    encoder = SelfSupervisedEmbedding(model, criterion, optimizer, dataloader)
    encoder.init_checkpoint_callback(**cfg['checkpoint_callback'])
    encoder.train_embedding(**cfg['trainer'], strategy=distributed_strategy)

    print(
        f'Best model is stored at: {bcolors.OKBLUE}{encoder.checkpoint}{bcolors.ENDC}'
    )
    # publish the checkpoint location through the configured environment variable
    os.environ[cfg['environment_variable_names']
               ['lightly_last_checkpoint_path']] = encoder.checkpoint
    return encoder.checkpoint
Пример #14
0
def _train_cli(cfg, is_cli_call=True):
    """Train a self-supervised ResNetSimCLR model and return its checkpoint.

    The model is optionally initialised from a checkpoint (given explicitly
    or looked up through get_ptmodel_from_config when cfg['pre_trained'] is
    set) and trained through `SelfSupervisedEmbedding` with an `NTXentLoss`
    criterion and an SGD optimizer.

    Args:
        cfg: Mapping providing 'data', 'download', 'root', 'input_dir',
            optionally 'seed', 'trainer', 'loader', 'checkpoint',
            'pre_trained', 'model', 'criterion', 'optimizer', 'collate' and
            'checkpoint_callback'. cfg['trainer']['gpus'] and
            cfg['loader']['batch_size'] may be modified in place.
        is_cli_call: If True, 'root', 'input_dir' and an explicitly given
            checkpoint path are passed through fix_input_path.

    Returns:
        The value of encoder.checkpoint after training.
    """

    data = cfg['data']
    download = cfg['download']

    root = cfg['root']
    if root and is_cli_call:
        root = fix_input_path(root)

    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    # seed torch and switch cudnn to deterministic mode when a seed is given
    if 'seed' in cfg.keys():
        seed = cfg['seed']
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    if torch.cuda.is_available():
        device = 'cuda'
    else:
        # Always define the device: without this fallback `device` stayed
        # unbound whenever cuda was unavailable and no gpus were requested,
        # which broke loading a checkpoint further down.
        device = 'cpu'
        if cfg['trainer'] and cfg['trainer']['gpus']:
            # gpus were requested but cuda is unavailable
            cfg['trainer']['gpus'] = 0

    if cfg['loader']['batch_size'] < 64:
        msg = 'Training a self-supervised model with a small batch size: {}! '
        msg = msg.format(cfg['loader']['batch_size'])
        msg += 'Small batch size may harm embedding quality. '
        msg += 'You can specify the batch size via the loader key-word: '
        msg += 'loader.batch_size=BSZ'
        warnings.warn(msg)

    state_dict = None
    checkpoint = cfg['checkpoint']
    if cfg['pre_trained'] and not checkpoint:
        # if checkpoint wasn't specified explicitly and pre_trained is True
        # try to load the checkpoint from the model zoo
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            msg = 'Cannot download checkpoint for key {} '.format(key)
            msg += 'because it does not exist! '
            msg += 'Model will be trained from scratch.'
            warnings.warn(msg)
    elif checkpoint:
        checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint

    if checkpoint:
        # load the PyTorch state dictionary and map it to the current device
        if is_url(checkpoint):
            state_dict = load_state_dict_from_url(
                checkpoint, map_location=device
            )['state_dict']
        else:
            state_dict = torch.load(
                checkpoint, map_location=device
            )['state_dict']

    # load model
    model = ResNetSimCLR(**cfg['model'])
    if state_dict is not None:
        model.load_from_state_dict(state_dict)

    criterion = NTXentLoss(**cfg['criterion'])
    optimizer = torch.optim.SGD(model.parameters(), **cfg['optimizer'])

    dataset = LightlyDataset(root,
                             name=data, train=True, download=download,
                             from_folder=input_dir)

    # the batch size may not exceed the dataset size
    cfg['loader']['batch_size'] = min(
        cfg['loader']['batch_size'],
        len(dataset)
    )

    collate_fn = ImageCollateFunction(**cfg['collate'])
    dataloader = torch.utils.data.DataLoader(dataset,
                                             **cfg['loader'],
                                             collate_fn=collate_fn)

    encoder = SelfSupervisedEmbedding(model, criterion, optimizer, dataloader)
    encoder.init_checkpoint_callback(**cfg['checkpoint_callback'])
    encoder = encoder.train_embedding(**cfg['trainer'])

    print('Best model is stored at: %s' % (encoder.checkpoint))
    return encoder.checkpoint
Пример #15
0
def _embed_cli(cfg, is_cli_call=True):
    """Compute embeddings for a dataset with a ``ResNetSimCLR`` model.

    A ``LightlyDataset`` is built from the settings in ``cfg``, the model
    weights are read either from ``cfg['checkpoint']`` or, when that entry
    is empty, from the location returned by ``get_ptmodel_from_config``,
    and the dataset is passed through ``SelfSupervisedEmbedding.embed``.

    Args:
        cfg: Mapping-like configuration. Keys read here: ``data``,
            ``train`` (optional, defaults to ``True``), ``checkpoint``,
            ``download``, ``root``, ``input_dir``, ``collate.input_size``,
            ``loader`` and ``model``. Note that ``cfg['loader']`` is
            modified in place (``drop_last``, ``shuffle``, ``batch_size``).
        is_cli_call: When truthy, ``root``, ``input_dir`` and the
            checkpoint path go through ``fix_input_path`` and the result is
            written to ``embeddings.csv`` in the current working directory.

    Returns:
        The path of the written CSV file when ``is_cli_call`` is truthy,
        otherwise the tuple ``(embeddings, labels, filenames)``.

    Raises:
        RuntimeError: If no checkpoint is configured and
            ``get_ptmodel_from_config`` does not provide one either.
    """
    dataset_name = cfg['data']
    use_train_split = cfg.get('train', True)
    checkpoint = cfg['checkpoint']
    download = cfg['download']

    # paths are only rewritten when invoked from the command line
    root = cfg['root']
    if is_cli_call and root:
        root = fix_input_path(root)

    input_dir = cfg['input_dir']
    if is_cli_call and input_dir:
        input_dir = fix_input_path(input_dir)

    # request deterministic cuDNN kernels and switch off its auto-tuner
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # square resize followed by tensor conversion and channel normalization
    # (the constants look like the usual ImageNet statistics)
    input_size = cfg['collate']['input_size']
    preprocessing_steps = [
        torchvision.transforms.Resize((input_size, input_size)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ]
    transform = torchvision.transforms.Compose(preprocessing_steps)

    dataset = LightlyDataset(
        root,
        name=dataset_name,
        train=use_train_split,
        download=download,
        from_folder=input_dir,
        transform=transform,
    )

    # every sample must be embedded exactly once and in dataset order
    for flag in ('drop_last', 'shuffle'):
        cfg['loader'][flag] = False
    # never request a batch larger than the dataset itself
    num_samples = len(dataset)
    cfg['loader']['batch_size'] = min(cfg['loader']['batch_size'],
                                      num_samples)
    dataloader = torch.utils.data.DataLoader(dataset, **cfg['loader'])

    # load the PyTorch state dictionary and map it to the current device
    if checkpoint:
        # an explicit checkpoint is read from the local file system
        if is_cli_call:
            checkpoint = fix_input_path(checkpoint)
        state_dict = torch.load(checkpoint, map_location=device)['state_dict']
    else:
        # no checkpoint given: look one up for the model configuration
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            raise RuntimeError(
                'Cannot download checkpoint for key {} '.format(key)
                + 'because it does not exist!'
            )
        state_dict = load_state_dict_from_url(
            checkpoint, map_location=device)['state_dict']

    model = ResNetSimCLR(**cfg['model'])
    model = model.to(device)
    if state_dict is not None:
        model.load_from_state_dict(state_dict)

    # criterion, optimizer and training dataloader are passed as None
    encoder = SelfSupervisedEmbedding(model, None, None, None)
    embeddings, labels, filenames = encoder.embed(dataloader, device=device)

    if not is_cli_call:
        return embeddings, labels, filenames

    path = os.path.join(os.getcwd(), 'embeddings.csv')
    save_embeddings(path, embeddings, labels, filenames)
    print('Embeddings are stored at %s' % path)
    return path
Пример #16
0
def _embed_cli(cfg, is_cli_call=True):
    """Compute embeddings for an image folder with a ``SimCLR`` model.

    A ``LightlyDataset`` is created from ``cfg['input_dir']``, a ``SimCLR``
    model is assembled on top of the layers of a ``ResNetGenerator``, the
    weights are read either from ``cfg['checkpoint']`` or, when that entry
    is empty, from the location returned by ``get_ptmodel_from_config``,
    and the dataset is passed through ``SelfSupervisedEmbedding.embed``.

    Args:
        cfg: Mapping-like configuration. Keys read here: ``checkpoint``,
            ``input_dir``, ``collate.input_size``, ``loader`` and ``model``
            (``name``, ``width``, ``num_ftrs``, ``out_dim``). Note that
            ``cfg['loader']`` is modified in place (``drop_last``,
            ``shuffle``, ``batch_size``).
        is_cli_call: When truthy, ``input_dir`` and the checkpoint path go
            through ``fix_input_path`` and the result is written to
            ``embeddings.csv`` in the current working directory.

    Returns:
        The path of the written CSV file when ``is_cli_call`` is truthy,
        otherwise the tuple ``(embeddings, labels, filenames)``.

    Raises:
        RuntimeError: If no checkpoint is configured and
            ``get_ptmodel_from_config`` does not provide one either.
    """
    checkpoint = cfg['checkpoint']

    # the input path is only rewritten when invoked from the command line
    input_dir = cfg['input_dir']
    if is_cli_call and input_dir:
        input_dir = fix_input_path(input_dir)

    # request deterministic cuDNN kernels and switch off its auto-tuner
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # square resize followed by tensor conversion and channel normalization
    # (the constants look like the usual ImageNet statistics)
    input_size = cfg['collate']['input_size']
    preprocessing_steps = [
        torchvision.transforms.Resize((input_size, input_size)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]),
    ]
    transform = torchvision.transforms.Compose(preprocessing_steps)

    dataset = LightlyDataset(input_dir, transform=transform)

    # every sample must be embedded exactly once and in dataset order
    for flag in ('drop_last', 'shuffle'):
        cfg['loader'][flag] = False
    # never request a batch larger than the dataset itself
    num_samples = len(dataset)
    cfg['loader']['batch_size'] = min(cfg['loader']['batch_size'],
                                      num_samples)
    dataloader = torch.utils.data.DataLoader(dataset, **cfg['loader'])

    # load the PyTorch state dictionary and map it to the current device
    if checkpoint:
        # an explicit checkpoint is read from the local file system
        if is_cli_call:
            checkpoint = fix_input_path(checkpoint)
        state_dict = torch.load(checkpoint, map_location=device)['state_dict']
    else:
        # no checkpoint given: look one up for the model configuration
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            raise RuntimeError(
                'Cannot download checkpoint for key {} '.format(key)
                + 'because it does not exist!'
            )
        state_dict = load_state_dict_from_url(
            checkpoint, map_location=device)['state_dict']

    # build the backbone: every ResNet child except the last one, whose
    # `in_features` gives the channel count fed into the 1x1 convolution
    model_cfg = cfg['model']
    resnet = ResNetGenerator(model_cfg['name'], model_cfg['width'])
    resnet_layers = list(resnet.children())
    last_conv_channels = resnet_layers[-1].in_features
    features = nn.Sequential(
        get_norm_layer(3, 0),
        *resnet_layers[:-1],
        nn.Conv2d(last_conv_channels, model_cfg['num_ftrs'], 1),
        nn.AdaptiveAvgPool2d(1),
    )

    model = SimCLR(
        features,
        num_ftrs=model_cfg['num_ftrs'],
        out_dim=model_cfg['out_dim'],
    )
    model = model.to(device)

    if state_dict is not None:
        load_from_state_dict(model, state_dict)

    # criterion, optimizer and training dataloader are passed as None
    encoder = SelfSupervisedEmbedding(model, None, None, None)
    embeddings, labels, filenames = encoder.embed(dataloader, device=device)

    if not is_cli_call:
        return embeddings, labels, filenames

    path = os.path.join(os.getcwd(), 'embeddings.csv')
    save_embeddings(path, embeddings, labels, filenames)
    print(f'Embeddings are stored at {bcolors.OKBLUE}{path}{bcolors.ENDC}')
    return path