Example #1
def process_components(
    model: RunnerModel,
    criterion: Criterion = None,
    optimizer: Optimizer = None,
    scheduler: Scheduler = None,
    distributed_params: Dict = None,
    device: Device = None,
) -> Tuple[RunnerModel, Criterion, Optimizer, Scheduler, Device]:
    """
    Returns the processed model, criterion, optimizer, scheduler and device.

    Args:
        model: torch model
        criterion: criterion function
        optimizer: optimizer
        scheduler: scheduler
        distributed_params (dict, optional): dict with the parameters
            for distributed and FP16 method
        device (Device, optional): device

    Returns:
        tuple with processed model, criterion, optimizer, scheduler and device.

    Raises:
        ValueError: if device is None and a TPU is available;
            to use a TPU you must manually move the
            model/optimizer/scheduler to the TPU device and pass
            that device to this function.
        NotImplementedError: if the model is neither an nn.Module nor
            a dict for multi-GPU; nn.ModuleDict is not supported
            for DataParallel yet.
    """
    distributed_params = distributed_params or {}
    distributed_params = copy.deepcopy(distributed_params)
    distributed_params.update(get_distributed_params())

    if device is None and IS_XLA_AVAILABLE:
        raise ValueError(
            "TPU device is available. "
            "Please move model, optimizer and scheduler (if present) "
            "to TPU device manualy and specify a device or "
            "use CPU device.")

    if device is None:
        device = get_device()
    elif isinstance(device, str):
        device = torch.device(device)

    is_apex_enabled = (distributed_params.get("apex", False)
                       and check_apex_available())

    is_amp_enabled = (distributed_params.get("amp", False)
                      and check_amp_available())

    if is_apex_enabled and is_amp_enabled:
        raise ValueError("Both NVidia Apex and Torch.Amp are enabled. "
                         "You must choose only one mixed precision backend")

    model: RunnerModel = maybe_recursive_call(model, "to", device=device)

    if check_ddp_wrapped(model):
        pass
    # distributed data parallel run (ddp) (with apex support)
    elif get_rank() >= 0:
        assert isinstance(
            model,
            nn.Module), "Distributed training is not available for KV model"

        local_rank = distributed_params.pop("local_rank", 0) or 0
        device = f"cuda:{local_rank}"
        model = maybe_recursive_call(model, "to", device=device)

        syncbn = distributed_params.pop("syncbn", False)

        if is_apex_enabled:
            import apex

            if syncbn:
                model = apex.parallel.convert_syncbn_model(model)

            model, optimizer = initialize_apex(model, optimizer,
                                               **distributed_params)
            model = apex.parallel.DistributedDataParallel(model)
        else:
            if syncbn:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank)
    # data parallel run (dp) (with apex support)
    else:
        is_data_parallel = (torch.cuda.device_count() > 1
                            and device.type != "cpu" and device.index is None)

        if is_apex_enabled and not is_data_parallel:
            model, optimizer = initialize_apex(model, optimizer,
                                               **distributed_params)

        elif not is_apex_enabled and is_data_parallel:
            if isinstance(model, nn.Module):
                model = nn.DataParallel(model)
            elif isinstance(model, dict):
                model = {k: nn.DataParallel(v) for k, v in model.items()}
            else:
                raise NotImplementedError()

        elif is_apex_enabled and is_data_parallel:
            model, optimizer = _wrap_into_data_parallel_with_apex(
                model, optimizer, distributed_params)

    model: RunnerModel = maybe_recursive_call(model, "to", device=device)

    return model, criterion, optimizer, scheduler, device
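
A minimal usage sketch for the variant above, assuming its imports (torch and the Catalyst utilities) are available; the toy model, the learning rate, and the distributed_params values are illustrative assumptions, not part of the original:

import torch
import torch.nn as nn

# Hypothetical two-layer classifier used only for illustration.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# process_components resolves the device, validates the mixed-precision
# backend choice, and wraps the model for DP/DDP when applicable.
model, criterion, optimizer, scheduler, device = process_components(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    distributed_params={"amp": True},  # assumes Torch AMP is available
)
print(device)  # e.g. device(type='cuda') when a GPU is visible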
Example #2
def process_components(
    model: Model,
    criterion: Criterion = None,
    optimizer: Optimizer = None,
    scheduler: Scheduler = None,
    distributed_params: Dict = None,
    device: Device = None,
) -> Tuple[Model, Criterion, Optimizer, Scheduler, Device]:
    """
    Returns the processed model, criterion, optimizer, scheduler and device.

    Args:
        model (Model): torch model
        criterion (Criterion): criterion function
        optimizer (Optimizer): optimizer
        scheduler (Scheduler): scheduler
        distributed_params (dict, optional): dict with the parameters
            for distributed and FP16 method
        device (Device, optional): device

    Returns:
        tuple with processed model, criterion, optimizer, scheduler and device.

    Raises:
        NotImplementedError: if the model is neither an nn.Module nor
            a dict for multi-GPU; nn.ModuleDict is not supported
            for DataParallel yet.
    """
    distributed_params = distributed_params or {}
    distributed_params = copy.deepcopy(distributed_params)
    distributed_params.update(get_distributed_params())

    if device is None:
        device = get_device()
    elif isinstance(device, str):
        device = torch.device(device)

    is_apex_available = (distributed_params.pop("apex", True)
                         and check_apex_available())

    model: Model = maybe_recursive_call(model, "to", device=device)

    if check_ddp_wrapped(model):
        pass
    # distributed data parallel run (ddp) (with apex support)
    elif get_rank() >= 0:
        assert isinstance(
            model,
            nn.Module), "Distributed training is not available for KV model"

        local_rank = distributed_params.pop("local_rank", 0) or 0
        device = f"cuda:{local_rank}"
        model = maybe_recursive_call(model, "to", device=device)

        syncbn = distributed_params.pop("syncbn", False)

        if is_apex_available:
            import apex

            model, optimizer = initialize_apex(model, optimizer,
                                               **distributed_params)
            model = apex.parallel.DistributedDataParallel(model)

            if syncbn:
                model = apex.parallel.convert_syncbn_model(model)
        else:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank)
    # data parallel run (dp) (with apex support)
    else:
        # apex issue https://github.com/deepset-ai/FARM/issues/210
        use_apex = (is_apex_available and torch.cuda.device_count() == 1) or (
            is_apex_available and torch.cuda.device_count() > 1
            and distributed_params.get("opt_level", "O0") == "O1")

        if use_apex:
            assert isinstance(
                model,
                nn.Module), "Apex training is not available for KV model"

            model, optimizer = initialize_apex(model, optimizer,
                                               **distributed_params)

        if (torch.cuda.device_count() > 1 and device.type != "cpu"
                and device.index is None):
            if isinstance(model, nn.Module):
                model = nn.DataParallel(model)
            elif isinstance(model, dict):
                model = {k: nn.DataParallel(v) for k, v in model.items()}
            else:
                raise NotImplementedError()

    model: Model = maybe_recursive_call(model, "to", device=device)

    return model, criterion, optimizer, scheduler, device
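
Unlike the DDP and Apex paths, the DataParallel path in both variants also accepts a key-value dict of models. A short sketch, assuming a multi-GPU host and two hypothetical sub-models:

import torch.nn as nn

# Hypothetical encoder/decoder pair; with several visible GPUs each
# entry is wrapped in nn.DataParallel independently, and every module
# is recursively moved to the resolved device.
model = {
    "encoder": nn.Linear(128, 64),
    "decoder": nn.Linear(64, 128),
}
model, _, _, _, device = process_components(model=model)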
Example #3
def main(args, _=None):
    """Run the ``catalyst-contrib text2embeddings`` script."""
    batch_size = args.batch_size
    num_workers = args.num_workers
    max_length = args.max_length
    pooling_groups = args.pooling.split(",")
    bert_level = args.bert_level

    if bert_level is not None:
        assert (args.output_hidden_states
                ), "You need hidden states output for level specification"

    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    if getattr(args, "in_huggingface", False):
        model_config = BertConfig.from_pretrained(args.in_huggingface)
        model_config.output_hidden_states = args.output_hidden_states
        model = BertModel.from_pretrained(args.in_huggingface,
                                          config=model_config)
        tokenizer = BertTokenizer.from_pretrained(args.in_huggingface)
    else:
        model_config = BertConfig.from_pretrained(args.in_config)
        model_config.output_hidden_states = args.output_hidden_states
        model = BertModel(config=model_config)
        tokenizer = BertTokenizer.from_pretrained(args.in_vocab)
    if getattr(args, "in_model", None) is not None:
        checkpoint = load_checkpoint(args.in_model)
        checkpoint = {"model_state_dict": checkpoint}
        unpack_checkpoint(checkpoint=checkpoint, model=model)

    model = model.eval()
    model, _, _, _, device = process_components(model=model)

    df = pd.read_csv(args.in_csv)
    df = df.dropna(subset=[args.txt_col])
    df.to_csv(f"{args.out_prefix}.df.csv", index=False)
    df = df.reset_index().drop("index", axis=1)
    df = list(df.to_dict("index").values())
    num_samples = len(df)

    open_fn = LambdaReader(
        input_key=args.txt_col,
        output_key=None,
        lambda_fn=partial(
            tokenize_text,
            strip=args.strip,
            lowercase=args.lowercase,
            remove_punctuation=args.remove_punctuation,
        ),
        tokenizer=tokenizer,
        max_length=max_length,
    )

    dataloader = get_loader(
        df,
        open_fn,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    features = {}
    dataloader = tqdm(dataloader) if args.verbose else dataloader
    with torch.no_grad():
        for idx, batch_input in enumerate(dataloader):
            batch_input = any2device(batch_input, device)
            batch_output = model(**batch_input)
            mask = (batch_input["attention_mask"].unsqueeze(-1)
                    if args.mask_for_max_length else None)

            if check_ddp_wrapped(model):
                # several GPUs: the config lives on the wrapped module
                hidden_size = model.module.config.hidden_size
                hidden_states = model.module.config.output_hidden_states
            else:
                # CPU or a single GPU
                hidden_size = model.config.hidden_size
                hidden_states = model.config.output_hidden_states

            batch_features = process_bert_output(
                bert_output=batch_output,
                hidden_size=hidden_size,
                output_hidden_states=hidden_states,
                pooling_groups=pooling_groups,
                mask=mask,
            )

            # create storage based on network output
            if idx == 0:
                for layer_name, layer_value in batch_features.items():
                    if bert_level is not None and bert_level != layer_name:
                        continue
                    layer_name = (layer_name if isinstance(layer_name, str)
                                  else f"{layer_name:02d}")
                    _, embedding_size = layer_value.shape
                    features[layer_name] = np.memmap(
                        f"{args.out_prefix}.{layer_name}.npy",
                        dtype=np.float32,
                        mode="w+",
                        shape=(num_samples, embedding_size),
                    )

            indices = np.arange(idx * batch_size,
                                min((idx + 1) * batch_size, num_samples))
            for layer_name, layer_value in batch_features.items():
                if bert_level is not None and bert_level != layer_name:
                    continue
                layer_name = (layer_name if isinstance(layer_name, str)
                              else f"{layer_name:02d}")
                features[layer_name][indices] = _detach(layer_value)

    if args.force_save:
        for key, mmap in features.items():
            mmap.flush()
            np.save(f"{args.out_prefix}.{key}.force.npy",
                    mmap,
                    allow_pickle=False)
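
A hedged invocation sketch for the script's entry point. Every attribute below is read somewhere in main(), but the concrete values, file paths, and the "avg" pooling name are assumptions for illustration:

from argparse import Namespace

# Hypothetical argument set; paths and the HuggingFace model id are
# placeholders, and the in_config/in_vocab branch arguments are omitted
# because in_huggingface takes precedence in main().
args = Namespace(
    in_huggingface="bert-base-uncased",
    in_model=None,
    in_csv="texts.csv",
    txt_col="text",
    out_prefix="embeddings",
    batch_size=32,
    num_workers=2,
    max_length=128,
    pooling="avg",  # assumed pooling group name
    bert_level=None,
    output_hidden_states=False,
    mask_for_max_length=False,
    strip=True,
    lowercase=True,
    remove_punctuation=False,
    seed=42,
    deterministic=True,
    benchmark=False,
    verbose=True,
    force_save=False,
)
main(args)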