Exemple #1
0
 def predict_batch(self, batch):
     """Run the model on one batch and return its output.

     ``batch`` is moved to ``self.device`` first.  It must contain either
     the four model inputs alone, or those four followed by one extra
     trailing item (presumably the target -- TODO confirm), which is
     ignored here.

     Returns:
         Whatever ``self.model`` returns for the four input tensors.

     Raises:
         RuntimeError: if ``batch`` has neither 4 nor 5 items.
     """
     batch = any2device(batch, self.device)
     if len(batch) not in (4, 5):
         raise RuntimeError(
             f"Expected a batch of 4 or 5 items, got {len(batch)}"
         )
     (
         city_id_tensor,
         booker_country_tensor,
         device_class_tensor,
         affiliate_id_tensor,
         *_,  # optional trailing item is not needed for prediction
     ) = batch
     return self.model(
         city_id_tensor,
         booker_country_tensor,
         device_class_tensor,
         affiliate_id_tensor,
     )
Exemple #2
0
 def predict_batch(self, batch):
     """Run the model on one batch and return both of its outputs.

     ``batch`` is moved to ``self.device`` first.  It must contain either
     the 12 model inputs alone, or those 12 followed by two extra trailing
     items (presumably the targets -- TODO confirm), which are ignored.

     Returns:
         The two values produced by ``self.model``, as a tuple.

     Raises:
         RuntimeError: if ``batch`` has neither 12 nor 14 items.
     """
     batch = any2device(batch, self.device)
     if len(batch) not in (12, 14):
         raise RuntimeError(
             f"Expected a batch of 12 or 14 items, got {len(batch)}"
         )
     (
         city_id_tensor,
         device_class_tensor,
         affiliate_id_tensor,
         month_checkin_tensor,
         num_checkin_tensor,
         days_stay_tensor,
         days_move_tensor,
         hotel_country_tensor,
         num_visit_drop_duplicates_tensor,
         num_visit_tensor,
         num_visit_same_city_tensor,
         num_stay_consecutively_tensor,
         *_,  # trailing items, if any, are not used for prediction
     ) = batch
     out_s, out_h = self.model(
         city_id_tensor,
         device_class_tensor,
         affiliate_id_tensor,
         month_checkin_tensor,
         num_checkin_tensor,
         days_stay_tensor,
         days_move_tensor,
         hotel_country_tensor,
         num_visit_drop_duplicates_tensor,
         num_visit_tensor,
         num_visit_same_city_tensor,
         num_stay_consecutively_tensor,
     )
     return out_s, out_h
    def predict_batch(self, batch):
        """Feed one batch to the model and return the model's output.

        After being moved to ``self.device`` the batch must have 14 items
        (all passed to the model) or 15 items (the last one is dropped
        before the call).

        Raises:
            RuntimeError: for any other batch length.
        """
        batch = any2device(batch, self.device)
        size = len(batch)
        if size not in (14, 15):
            raise RuntimeError

        # A 15-item batch carries one extra trailing item the model
        # does not take as input.
        model_inputs = batch if size == 14 else batch[:-1]
        return self.model(*model_inputs)
Exemple #4
0
    def _get_optimizer(self, stage: str, model: Union[Model, Dict[str, Model]],
                       **params) -> Optimizer:
        """Build the optimizer for ``stage`` from config ``params``.

        Recognised control keys are popped from ``params`` before the
        remainder is handed to ``OPTIMIZERS.get_from_params``:

        * ``layerwise_params``, ``no_bias_weight_decay`` -- forwarded to
          ``utils.process_model_params``.
        * ``lr_linear_scaling`` -- when truthy, ``params["lr"]`` is set to
          ``lr * batch_size / base_batch_size``.
        * ``_model`` -- ``None`` (``model`` is a single module), a key, or a
          list/tuple of keys into ``model``.
        * ``load_from_previous_stage``, ``optimizer_key`` -- control
          restoring optimizer state from ``best_full.pth``.

        Args:
            stage: name of the stage being set up.
            model: a single model or a mapping of named models.
            **params: optimizer configuration.

        Returns:
            The constructed optimizer.

        Raises:
            ValueError: if ``_model`` is not ``None``, a string, a list or
                a tuple.
        """
        # @TODO 1: refactoring; this method is too long
        # @TODO 2: load state dicts for schedulers & criterion
        layerwise_params = params.pop("layerwise_params", OrderedDict())
        no_bias_weight_decay = params.pop("no_bias_weight_decay", True)

        # linear scaling rule from https://arxiv.org/pdf/1706.02677.pdf
        lr_scaling = 1.0
        scaling_cfg = params.pop("lr_linear_scaling", None)
        if scaling_cfg:
            data_cfg = dict(self.stages_config[stage]["data_params"])
            batch_size = data_cfg.get("batch_size")
            per_gpu_scaling = data_cfg.get("per_gpu_scaling", False)
            is_distributed = utils.get_rank() > -1
            if per_gpu_scaling and not is_distributed:
                # batch_size is per device here; scale to the total
                batch_size *= max(1, torch.cuda.device_count())

            base_lr = scaling_cfg.get("lr")
            lr_scaling = batch_size / scaling_cfg.get("base_batch_size", 256)
            params["lr"] = base_lr * lr_scaling  # scale default lr

        def _param_groups(module):
            # same processing options for every model that is selected
            return utils.process_model_params(
                module, layerwise_params, no_bias_weight_decay, lr_scaling
            )

        # getting model parameters
        model_key = params.pop("_model", None)
        if model_key is None:
            assert isinstance(
                model, nn.Module
            ), "model is key-value, but optimizer has no specified model"
            model_params = _param_groups(model)
        elif isinstance(model_key, str):
            model_params = _param_groups(model[model_key])
        elif isinstance(model_key, (list, tuple)):
            model_params = []
            for single_key in model_key:
                model_params.extend(_param_groups(model[single_key]))
        else:
            raise ValueError("unknown type of model_params")

        load_from_previous_stage = params.pop("load_from_previous_stage",
                                              False)
        optimizer_key = params.pop("optimizer_key", None)
        optimizer = OPTIMIZERS.get_from_params(**params, params=model_params)

        # nothing to restore for the first stage or when not requested
        if not load_from_previous_stage or self.stages.index(stage) == 0:
            return optimizer

        checkpoint = utils.load_checkpoint(
            f"{self.logdir}/checkpoints/best_full.pth"
        )
        if optimizer_key is None:
            restore_target = optimizer
        else:
            restore_target = {optimizer_key: optimizer}
        utils.unpack_checkpoint(checkpoint, optimizer=restore_target)

        # move optimizer to device
        device = utils.get_device()
        for group in model_params:
            # only the first parameter of each group is looked up
            group_state = optimizer.state[group["params"][0]]
            for state_name, state_value in group_state.items():
                group_state[state_name] = utils.any2device(state_value, device)

        # update optimizer params
        for option, option_value in params.items():
            for pg in optimizer.param_groups:
                pg[option] = option_value

        return optimizer
Exemple #5
0
 def _batch2device(self, batch: Mapping[str, Any], device):
     """Return ``batch`` after passing it through ``utils.any2device``.

     Args:
         batch: mapping of batch entries to move.
         device: target device, forwarded unchanged.
     """
     return utils.any2device(batch, device)
Exemple #6
0
def main(args, _=None):
    """Run the ``catalyst-data text2embeddings`` script.

    Loads a BERT model and tokenizer, reads ``args.in_csv``, drops rows
    whose ``args.txt_col`` is missing, writes the filtered table to
    ``{args.out_prefix}.df.csv`` and stores one ``float32`` memmap per
    feature returned by ``process_bert_output`` in
    ``{args.out_prefix}.{name}.npy``.

    Args:
        args: parsed command-line namespace.
        _: unused; the name is rebound internally.
    """
    batch_size = args.batch_size
    num_workers = args.num_workers
    max_length = args.max_length
    pooling_groups = args.pooling.split(",")

    utils.set_global_seed(args.seed)
    utils.prepare_cudnn(args.deterministic, args.benchmark)

    # NOTE(review): ``hasattr`` is true whenever the attribute exists, even
    # if its value is None -- confirm the argument parser omits unset options.
    if hasattr(args, "in_huggingface"):
        model_config = BertConfig.from_pretrained(args.in_huggingface)
        model_config.output_hidden_states = args.output_hidden_states
        model = BertModel.from_pretrained(args.in_huggingface,
                                          config=model_config)
        tokenizer = BertTokenizer.from_pretrained(args.in_huggingface)
    else:
        model_config = BertConfig.from_pretrained(args.in_config)
        model_config.output_hidden_states = args.output_hidden_states
        model = BertModel(config=model_config)
        tokenizer = BertTokenizer.from_pretrained(args.in_vocab)
    if hasattr(args, "in_model"):
        checkpoint = utils.load_checkpoint(args.in_model)
        checkpoint = {"model_state_dict": checkpoint}
        utils.unpack_checkpoint(checkpoint=checkpoint, model=model)

    model = model.eval()
    model, _, _, _, device = utils.process_components(model=model)

    df = pd.read_csv(args.in_csv)
    df = df.dropna(subset=[args.txt_col])
    df.to_csv(f"{args.out_prefix}.df.csv", index=False)
    # drop=True discards the old index directly; going through a temporary
    # "index" column would remove a real CSV column of that name instead
    df = df.reset_index(drop=True)
    df = list(df.to_dict("index").values())
    num_samples = len(df)

    open_fn = LambdaReader(
        input_key=args.txt_col,
        output_key=None,
        lambda_fn=partial(
            tokenize_text,
            strip=args.strip,
            lowercase=args.lowercase,
            remove_punctuation=args.remove_punctuation,
        ),
        tokenizer=tokenizer,
        max_length=max_length,
    )

    dataloader = utils.get_loader(
        df,
        open_fn,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    # The model config does not change between batches, so read it once.
    # A DDP-wrapped model (several gpus) exposes the config via ``.module``;
    # on cpu or a single gpu it is on the model itself.
    if utils.check_ddp_wrapped(model):
        bert_config = model.module.config
    else:
        bert_config = model.config
    hidden_size = bert_config.hidden_size
    hidden_states = bert_config.output_hidden_states

    features = {}
    dataloader = tqdm(dataloader) if args.verbose else dataloader
    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            batch = utils.any2device(batch, device)
            bert_output = model(**batch)
            mask = (batch["attention_mask"].unsqueeze(-1)
                    if args.mask_for_max_length else None)

            features_ = process_bert_output(
                bert_output=bert_output,
                hidden_size=hidden_size,
                output_hidden_states=hidden_states,
                pooling_groups=pooling_groups,
                mask=mask,
            )

            # rows of the output arrays covered by this batch
            indices = np.arange(idx * batch_size,
                                min((idx + 1) * batch_size, num_samples))
            for key, value in features_.items():
                # non-string keys are zero-padded to two digits
                name_ = key if isinstance(key, str) else f"{key:02d}"
                if idx == 0:
                    # create storage based on network output
                    _, embedding_size = value.shape
                    features[name_] = np.memmap(
                        f"{args.out_prefix}.{name_}.npy",
                        dtype=np.float32,
                        mode="w+",
                        shape=(num_samples, embedding_size),
                    )
                features[name_][indices] = _detach(value)
Exemple #7
0
def main(args, _=None):
    """Compute BERT embeddings for a CSV text column and save them.

    Builds a ``BertModel`` from ``args.in_config``, loads weights from
    ``args.in_model``, reads ``args.in_csv``, drops rows whose
    ``args.txt_col`` is missing and writes the filtered table to
    ``{args.out_prefix}.df.csv``.  The second model output is stored in
    ``{args.out_prefix}.class.npy``; in addition either every hidden state
    (``embeddings_XX``, when ``args.output_hidden_states`` is set) or the
    first model output (``last``) is pooled with ``LamaPooling`` and stored
    in its own ``float32`` memmap.

    Args:
        args: parsed command-line namespace.
        _: unused; the name is rebound internally.
    """
    batch_size = args.batch_size
    num_workers = args.num_workers
    max_length = args.max_length
    pooling_groups = args.pooling.split(",")

    utils.set_global_seed(args.seed)
    utils.prepare_cudnn(args.deterministic, args.benchmark)

    model_config = BertConfig.from_pretrained(args.in_config)
    model_config.output_hidden_states = args.output_hidden_states
    model = BertModel(config=model_config)

    checkpoint = utils.load_checkpoint(args.in_model)
    checkpoint = {"model_state_dict": checkpoint}
    utils.unpack_checkpoint(checkpoint=checkpoint, model=model)

    model = model.eval()
    model, _, _, _, device = utils.process_components(model=model)

    tokenizer = BertTokenizer.from_pretrained(args.in_vocab)

    df = pd.read_csv(args.in_csv)
    df = df.dropna(subset=[args.txt_col])
    df.to_csv(f"{args.out_prefix}.df.csv", index=False)
    # drop=True discards the old index directly; going through a temporary
    # "index" column would remove a real CSV column of that name instead
    df = df.reset_index(drop=True)
    df = list(df.to_dict("index").values())
    num_samples = len(df)

    open_fn = LambdaReader(
        input_key=args.txt_col,
        output_key=None,
        lambda_fn=get_features,
        tokenizer=tokenizer,
        max_length=max_length,
    )

    dataloader = utils.get_loader(
        df,
        open_fn,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    features = {}
    poolings = {}
    dataloader = tqdm(dataloader) if args.verbose else dataloader
    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            batch = utils.any2device(batch, device)
            features_ = model(**batch)

            # create storage based on network output
            # NOTE(review): pooled features are stored with the un-pooled
            # hidden width -- confirm LamaPooling's output width matches.
            if idx == 0:
                # class
                _, embedding_size = features_[1].shape
                features["class"] = np.memmap(
                    f"{args.out_prefix}.class.npy",
                    dtype=np.float32,
                    mode="w+",
                    shape=(num_samples, embedding_size),
                )
                if args.output_hidden_states:
                    # all embeddings
                    for i, feature_ in enumerate(features_[2]):
                        name_ = f"embeddings_{i + 1:02d}"
                        _, _, embedding_size = feature_.shape
                        poolings[name_] = LamaPooling(
                            features_in=embedding_size,
                            groups=pooling_groups,
                        )
                        features[name_] = np.memmap(
                            f"{args.out_prefix}.{name_}.npy",
                            dtype=np.float32,
                            mode="w+",
                            shape=(num_samples, embedding_size),
                        )
                else:
                    # last
                    _, _, embedding_size = features_[0].shape
                    poolings["last"] = LamaPooling(
                        features_in=embedding_size,
                        groups=pooling_groups,
                    )
                    features["last"] = np.memmap(
                        f"{args.out_prefix}.last.npy",
                        dtype=np.float32,
                        mode="w+",
                        shape=(num_samples, embedding_size),
                    )

            # rows of the output arrays covered by this batch
            indices = np.arange(idx * batch_size,
                                min((idx + 1) * batch_size, num_samples))
            features["class"][indices] = _detach(features_[1])
            if args.output_hidden_states:
                # all embeddings
                for i, feature_ in enumerate(features_[2]):
                    name_ = f"embeddings_{i + 1:02d}"
                    feature_ = poolings[name_](feature_)
                    features[name_][indices] = _detach(feature_)
            else:
                # The pooling is registered under the literal key "last";
                # ``name_`` is never assigned on this path.
                feature_ = poolings["last"](features_[0])
                features["last"][indices] = _detach(feature_)