Example #1
0
    def get_loaders(self, stage: str, **kwargs) -> tp.Dict[str, DataLoader]:
        """Build train/valid dataloaders for the cats-vs-dogs dataset.

        Args:
            stage: stage name used to look up ``data_params`` in
                ``self.stages_config`` and to select transforms.
            **kwargs: unused; kept for interface compatibility.

        Returns:
            Dict mapping ``"train"``/``"valid"`` to its ``DataLoader``.

        Raises:
            KeyError: if the ``DATA_PATH`` environment variable is not set.
        """
        loaders = dict()
        data_params = dict(self.stages_config[stage]["data_params"])
        # os.environ[...] (not os.getenv) so a missing DATA_PATH fails with a
        # clear KeyError instead of an opaque TypeError from Path(None).
        # Read it once instead of once per derived path.
        root = Path(os.environ["DATA_PATH"])
        data_path = (root / "data_cat_dogs").as_posix() + "/*"  # glob pattern
        tag_file_path = (root / "cat_dog_labeling.json").as_posix()
        train_data, valid_data, num_classes = get_cat_dogs_dataset(
            data_path, tag_file_path=tag_file_path
        )

        open_fn = get_reader(num_classes)
        for mode, part in [("train", train_data), ("valid", valid_data)]:
            data_transform = self.get_transforms(stage=stage, dataset=mode)
            loaders[mode] = utils.get_loader(
                part,
                open_fn=open_fn,
                dict_transform=data_transform,
                # shuffle / drop incomplete batches only while training
                shuffle=(mode == "train"),
                sampler=None,
                drop_last=(mode == "train"),
                **data_params,
            )

        return loaders
Example #2
0
    def get_loaders(self, stage: str, **kwargs) -> tp.Dict[str, DataLoader]:
        """Build train/valid dataloaders for the current stage.

        ``stage1`` uses torchvision CIFAR10 (downloaded on demand);
        ``stage2`` uses the cats-vs-dogs image folder dataset.

        Args:
            stage: stage name; also keys ``data_params`` in
                ``self.stages_config``.
            **kwargs: unused; kept for interface compatibility.

        Returns:
            Dict mapping ``"train"``/``"valid"`` to its ``DataLoader``
            (empty for unknown stages).

        Raises:
            KeyError: if the ``DATA_PATH`` environment variable is not set.
        """
        loaders = dict()
        data_params = dict(self.stages_config[stage]["data_params"])
        root = Path(os.environ["DATA_PATH"])

        if stage == "stage1":
            for mode in ["train", "valid"]:
                dataset = CIFAR10(
                    root=(root / "data_cifar").as_posix(),
                    train=(mode == "train"),
                    download=True,
                    transform=self.get_transforms(stage=stage, dataset=mode),
                )
                loaders[mode] = utils.get_loader(
                    dataset,
                    # CIFAR10 items are already (image, target) pairs;
                    # pass them through unchanged.
                    open_fn=lambda x: x,
                    dict_transform=lambda x: x,
                    shuffle=(mode == "train"),
                    sampler=None,
                    drop_last=(mode == "train"),
                    **data_params,
                )
        elif stage == "stage2":
            # BUG FIX: the original rebound ``data_path`` to a str and then
            # applied ``/`` to it on the next line, which raises TypeError at
            # runtime. Keep the Path root in its own variable instead.
            data_path = (root / "data_cat_dogs").as_posix() + "/*"
            tag_file_path = (root / "cat_dog_labeling.json").as_posix()
            train_data, valid_data, num_classes = get_cat_dogs_dataset(
                data_path, tag_file_path=tag_file_path)

            open_fn = get_reader(num_classes)
            for mode, part in [("train", train_data), ("valid", valid_data)]:
                data_transform = self.get_transforms(stage=stage, dataset=mode)
                loaders[mode] = utils.get_loader(
                    part,
                    open_fn=open_fn,
                    dict_transform=data_transform,
                    shuffle=(mode == "train"),
                    sampler=None,
                    drop_last=(mode == "train"),
                    **data_params,
                )

        return loaders
Example #3
0
def main(args, _=None):
    """Run the ``catalyst-data text2embeddings`` script.

    Reads texts from ``args.in_csv``, runs them through a BERT model and
    writes the pooled per-layer embeddings to disk-backed ``np.memmap``
    files named ``{args.out_prefix}.{layer}.npy``.

    Args:
        args: parsed CLI namespace (see the script's argument parser).
        _: ignored; kept for the script entry-point signature.
    """
    batch_size = args.batch_size
    num_workers = args.num_workers
    max_length = args.max_length
    pooling_groups = args.pooling.split(",")
    bert_level = args.bert_level

    # Selecting a single BERT layer only makes sense when the model is
    # configured to return all hidden states.
    if bert_level is not None:
        assert (args.output_hidden_states
                ), "You need hidden states output for level specification"

    utils.set_global_seed(args.seed)
    utils.prepare_cudnn(args.deterministic, args.benchmark)

    # Model/tokenizer source: either a HuggingFace hub name, or local
    # config + vocab files.
    if getattr(args, "in_huggingface", False):
        model_config = BertConfig.from_pretrained(args.in_huggingface)
        model_config.output_hidden_states = args.output_hidden_states
        model = BertModel.from_pretrained(args.in_huggingface,
                                          config=model_config)
        tokenizer = BertTokenizer.from_pretrained(args.in_huggingface)
    else:
        model_config = BertConfig.from_pretrained(args.in_config)
        model_config.output_hidden_states = args.output_hidden_states
        model = BertModel(config=model_config)
        tokenizer = BertTokenizer.from_pretrained(args.in_vocab)
    # Optionally overwrite the weights from a local checkpoint file.
    if getattr(args, "in_model", None) is not None:
        checkpoint = utils.load_checkpoint(args.in_model)
        checkpoint = {"model_state_dict": checkpoint}
        utils.unpack_checkpoint(checkpoint=checkpoint, model=model)

    model = model.eval()
    model, _, _, _, device = utils.process_components(model=model)

    # Drop rows without text, persist the cleaned frame next to the
    # embeddings, then convert it to a list of row dicts for the loader.
    df = pd.read_csv(args.in_csv)
    df = df.dropna(subset=[args.txt_col])
    df.to_csv(f"{args.out_prefix}.df.csv", index=False)
    df = df.reset_index().drop("index", axis=1)
    df = list(df.to_dict("index").values())
    num_samples = len(df)

    # Reader that tokenizes each row's text column into model inputs.
    open_fn = LambdaReader(
        input_key=args.txt_col,
        output_key=None,
        lambda_fn=partial(
            tokenize_text,
            strip=args.strip,
            lowercase=args.lowercase,
            remove_punctuation=args.remove_punctuation,
        ),
        tokenizer=tokenizer,
        max_length=max_length,
    )

    dataloader = utils.get_loader(
        df,
        open_fn,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    features = {}
    dataloader = tqdm(dataloader) if args.verbose else dataloader
    with torch.no_grad():
        for idx, batch_input in enumerate(dataloader):
            batch_input = utils.any2device(batch_input, device)
            batch_output = model(**batch_input)
            # Attention mask with a trailing singleton dim (broadcastable
            # over the hidden dim), used to exclude padding from pooling.
            mask = (batch_input["attention_mask"].unsqueeze(-1)
                    if args.mask_for_max_length else None)

            if utils.check_ddp_wrapped(model):
                # using several gpu
                hidden_size = model.module.config.hidden_size
                hidden_states = model.module.config.output_hidden_states

            else:
                # using cpu or one gpu
                hidden_size = model.config.hidden_size
                hidden_states = model.config.output_hidden_states

            batch_features = process_bert_output(
                bert_output=batch_output,
                hidden_size=hidden_size,
                output_hidden_states=hidden_states,
                pooling_groups=pooling_groups,
                mask=mask,
            )

            # create storage based on network output
            if idx == 0:
                for layer_name, layer_value in batch_features.items():
                    if bert_level is not None and bert_level != layer_name:
                        continue
                    # Integer layer ids become zero-padded strings so the
                    # output filenames are stable.
                    layer_name = (layer_name if isinstance(layer_name, str)
                                  else f"{layer_name:02d}")
                    _, embedding_size = layer_value.shape
                    # Disk-backed array: the full (num_samples, emb) matrix
                    # never has to fit in RAM.
                    features[layer_name] = np.memmap(
                        f"{args.out_prefix}.{layer_name}.npy",
                        dtype=np.float32,
                        mode="w+",
                        shape=(num_samples, embedding_size),
                    )

            # Row range covered by this batch (the last batch may be short).
            indices = np.arange(idx * batch_size,
                                min((idx + 1) * batch_size, num_samples))
            for layer_name2, layer_value2 in batch_features.items():
                if bert_level is not None and bert_level != layer_name2:
                    continue
                layer_name2 = (layer_name2 if isinstance(layer_name2, str) else
                               f"{layer_name2:02d}")
                features[layer_name2][indices] = _detach(layer_value2)

    # Optionally flush the memmaps and also save plain .npy copies.
    if args.force_save:
        for key, mmap in features.items():
            mmap.flush()
            np.save(f"{args.out_prefix}.{key}.force.npy",
                    mmap,
                    allow_pickle=False)