Exemplo n.º 1
0
def _train_cli(cfg, is_cli_call=True):
    """Train a self-supervised embedding model as described by ``cfg``.

    Args:
        cfg: Nested, dict-like configuration. Keys read here: ``input_dir``,
            optional ``seed``, ``trainer`` (``weights_summary``, ``gpus``),
            ``loader`` (``batch_size``, ``num_workers``), ``checkpoint``,
            ``pre_trained``, ``model`` (``name``, ``width``, ``num_ftrs``,
            ``out_dim``), ``criterion``, ``optimizer``, ``collate``,
            ``checkpoint_callback`` and ``environment_variable_names``.
            The ``trainer`` and ``loader`` sections are modified in place.
        is_cli_call: When true, ``input_dir`` and an explicitly configured
            checkpoint are passed through ``fix_input_path`` before use.

    Returns:
        The value of ``encoder.checkpoint`` after training. The same value
        is printed and written to the environment variable whose name is
        ``cfg['environment_variable_names']['lightly_last_checkpoint_path']``.
    """
    input_dir = cfg['input_dir']
    if is_cli_call and input_dir:
        input_dir = fix_input_path(input_dir)

    # Seed torch (CPU and CUDA) and pin cuDNN settings when a seed is given.
    if 'seed' in cfg.keys():
        random_seed = cfg['seed']
        torch.manual_seed(random_seed)
        torch.cuda.manual_seed(random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # The literal string "None" in the config stands for a real None.
    if cfg["trainer"]["weights_summary"] == "None":
        cfg["trainer"]["weights_summary"] = None

    # Pick the device; without CUDA any requested GPUs are reset to 0.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
        if cfg['trainer'] and cfg['trainer']['gpus']:
            cfg['trainer']['gpus'] = 0

    # More than one GPU selects the 'ddp' strategy.
    distributed_strategy = 'ddp' if cfg['trainer']['gpus'] > 1 else None

    requested_batch_size = cfg['loader']['batch_size']
    if requested_batch_size < 64:
        warnings.warn(
            'Training a self-supervised model with a small batch size: '
            f'{requested_batch_size}! '
            'Small batch size may harm embedding quality. '
            'You can specify the batch size via the loader key-word: '
            'loader.batch_size=BSZ'
        )

    # A negative worker count means "use every available core".
    if cfg['loader']['num_workers'] < 0:
        cfg['loader']['num_workers'] = cpu_count()

    checkpoint = cfg['checkpoint']
    if cfg['pre_trained'] and not checkpoint:
        # No explicit checkpoint but pre_trained is set: ask the model zoo.
        checkpoint, zoo_key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            warnings.warn(
                f'Cannot download checkpoint for key {zoo_key} '
                'because it does not exist! '
                'Model will be trained from scratch.'
            )
    elif checkpoint and is_cli_call:
        checkpoint = fix_input_path(checkpoint)

    # Load the state dictionary (from a URL or a local file) onto `device`.
    state_dict = None
    if checkpoint:
        if is_url(checkpoint):
            read_checkpoint = load_state_dict_from_url
        else:
            read_checkpoint = torch.load
        state_dict = read_checkpoint(
            checkpoint, map_location=device)['state_dict']

    # Build the backbone: every ResNet child except the last one, followed
    # by a 1x1 convolution to `num_ftrs` channels and global average pooling.
    resnet = ResNetGenerator(cfg['model']['name'], cfg['model']['width'])
    resnet_layers = list(resnet.children())
    last_conv_channels = resnet_layers[-1].in_features
    features = nn.Sequential(
        get_norm_layer(3, 0),
        *resnet_layers[:-1],
        nn.Conv2d(last_conv_channels, cfg['model']['num_ftrs'], 1),
        nn.AdaptiveAvgPool2d(1),
    )

    model = _SimCLR(
        features,
        num_ftrs=cfg['model']['num_ftrs'],
        out_dim=cfg['model']['out_dim'],
    )
    if state_dict is not None:
        load_from_state_dict(model, state_dict)

    criterion = NTXentLoss(**cfg['criterion'])
    optimizer = torch.optim.SGD(model.parameters(), **cfg['optimizer'])

    dataset = LightlyDataset(input_dir)

    # Never request a batch larger than the dataset itself.
    cfg['loader']['batch_size'] = min(
        cfg['loader']['batch_size'], len(dataset))

    collate_fn = ImageCollateFunction(**cfg['collate'])
    dataloader = torch.utils.data.DataLoader(
        dataset, **cfg['loader'], collate_fn=collate_fn)

    encoder = SelfSupervisedEmbedding(model, criterion, optimizer, dataloader)
    encoder.init_checkpoint_callback(**cfg['checkpoint_callback'])
    encoder.train_embedding(**cfg['trainer'], strategy=distributed_strategy)

    best_checkpoint = encoder.checkpoint
    print(
        f'Best model is stored at: {bcolors.OKBLUE}{best_checkpoint}{bcolors.ENDC}'
    )
    # Export the checkpoint path under the configured environment variable.
    env_var_name = (
        cfg['environment_variable_names']['lightly_last_checkpoint_path'])
    os.environ[env_var_name] = best_checkpoint
    return best_checkpoint
Exemplo n.º 2
0
def _train_cli(cfg, is_cli_call=True):
    """Train a self-supervised ResNet-SimCLR model as described by ``cfg``.

    Args:
        cfg: Nested, dict-like configuration. Keys read here: ``data``,
            ``download``, ``root``, ``input_dir``, optional ``seed``,
            ``trainer`` (``gpus``), ``loader`` (``batch_size``),
            ``checkpoint``, ``pre_trained``, ``model``, ``criterion``,
            ``optimizer``, ``collate`` and ``checkpoint_callback``.
            ``cfg['trainer']['gpus']`` and ``cfg['loader']['batch_size']``
            may be modified in place.
        is_cli_call: When true, ``root``, ``input_dir`` and an explicitly
            configured checkpoint are passed through ``fix_input_path``
            before use.

    Returns:
        The ``checkpoint`` attribute of the object returned by
        ``train_embedding``; the same value is printed.
    """
    data = cfg['data']
    download = cfg['download']

    root = cfg['root']
    if root and is_cli_call:
        root = fix_input_path(root)

    input_dir = cfg['input_dir']
    if input_dir and is_cli_call:
        input_dir = fix_input_path(input_dir)

    # Seed torch (CPU and CUDA) and pin cuDNN settings when a seed is given.
    if 'seed' in cfg.keys():
        seed = cfg['seed']
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Always bind `device`. Previously it stayed undefined when CUDA was
    # unavailable and no GPUs were requested, so loading a checkpoint
    # failed with UnboundLocalError on `map_location=device`.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
        if cfg['trainer'] and cfg['trainer']['gpus']:
            # GPUs were requested but none are usable: train on the CPU.
            cfg['trainer']['gpus'] = 0

    if cfg['loader']['batch_size'] < 64:
        msg = 'Training a self-supervised model with a small batch size: {}! '
        msg = msg.format(cfg['loader']['batch_size'])
        msg += 'Small batch size may harm embedding quality. '
        msg += 'You can specify the batch size via the loader key-word: '
        msg += 'loader.batch_size=BSZ'
        warnings.warn(msg)

    state_dict = None
    checkpoint = cfg['checkpoint']
    if cfg['pre_trained'] and not checkpoint:
        # if checkpoint wasn't specified explicitly and pre_trained is True
        # try to load the checkpoint from the model zoo
        checkpoint, key = get_ptmodel_from_config(cfg['model'])
        if not checkpoint:
            msg = 'Cannot download checkpoint for key {} '.format(key)
            msg += 'because it does not exist! '
            msg += 'Model will be trained from scratch.'
            warnings.warn(msg)
    elif checkpoint:
        checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint

    if checkpoint:
        # load the PyTorch state dictionary and map it to the current device
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        if is_url(checkpoint):
            state_dict = load_state_dict_from_url(
                checkpoint, map_location=device
            )['state_dict']
        else:
            state_dict = torch.load(
                checkpoint, map_location=device
            )['state_dict']

    # load model
    model = ResNetSimCLR(**cfg['model'])
    if state_dict is not None:
        model.load_from_state_dict(state_dict)

    criterion = NTXentLoss(**cfg['criterion'])
    optimizer = torch.optim.SGD(model.parameters(), **cfg['optimizer'])

    dataset = LightlyDataset(root,
                             name=data, train=True, download=download,
                             from_folder=input_dir)

    # Never request a batch larger than the dataset itself.
    cfg['loader']['batch_size'] = min(
        cfg['loader']['batch_size'],
        len(dataset)
    )

    collate_fn = ImageCollateFunction(**cfg['collate'])
    dataloader = torch.utils.data.DataLoader(dataset,
                                             **cfg['loader'],
                                             collate_fn=collate_fn)

    encoder = SelfSupervisedEmbedding(model, criterion, optimizer, dataloader)
    encoder.init_checkpoint_callback(**cfg['checkpoint_callback'])
    # NOTE(review): this relies on train_embedding returning the encoder
    # (or an object exposing `.checkpoint`) - confirm against its definition.
    encoder = encoder.train_embedding(**cfg['trainer'])

    print('Best model is stored at: %s' % (encoder.checkpoint))
    return encoder.checkpoint