Example #1
def get_db_config(cred: Optional[str]) -> Dict:
    """

    :param cred:
    :return:
    """
    if cred:
        cred_params = read_config_stream(cred)['odin_db']
        if 'host' not in cred_params:
            cred_params['host'] = cred_params['dbhost']
        if 'port' not in cred_params:
            cred_params['port'] = cred_params['dbport']
    else:
        cred_params = {}
        cred_params['backend'] = os.environ.get("ODIN_JOBS_BACKEND",
                                                "postgres")
        cred_params['host'] = os.environ.get("SQL_HOST", "127.0.0.1")
        cred_params['port'] = os.environ.get("DB_PORT", 5432)
        cred_params['user'] = os.environ.get("DB_USER")
        cred_params['passwd'] = os.environ.get("DB_PASS")
        cred_params['odin_root_user'] = os.environ.get("ODIN_ROOT_USER")
        cred_params['odin_root_passwd'] = os.environ.get("ODIN_ROOT_PASS")
    cred_params['db'] = ODIN_DB
    LOGGER.warning('%s %s %s %s', cred_params['user'], cred_params['db'],
                   cred_params['host'], cred_params['port'])
    return cred_params
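A minimal usage sketch (the environment values are illustrative): with no credential stream, the parameters fall back to environment variables.

# Illustrative only: exercise the environment-variable fallback path
import os

os.environ.setdefault("DB_USER", "odin")
os.environ.setdefault("DB_PASS", "secret")
params = get_db_config(None)            # reads SQL_HOST, DB_PORT, DB_USER, ...
# params = get_db_config(cred_stream)   # or pull the 'odin_db' section from a config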
Example #2
    def __init__(self, **kwargs):
        super().__init__()
        # You don't actually have to pass this if you are using the `load_bert_vocab` call from your
        # tokenizer.  In this case, a singleton variable will contain the vocab and it will be returned
        # by `load_bert_vocab`
        # If you trained your model with MEAD/Baseline, you will have a `*.json` file which you would
        # want to reference here
        vocab_file = kwargs.get('vocab_file')
        if vocab_file and os.path.exists(vocab_file):
            if vocab_file.endswith('.json'):
                self.vocab = read_config_stream(kwargs.get('vocab_file'))
            else:
                self.vocab = load_bert_vocab(kwargs.get('vocab_file'))
        else:
            self.vocab = kwargs.get('vocab', kwargs.get('known_vocab'))
            if self.vocab is None or isinstance(self.vocab,
                                                collections.Counter):
                self.vocab = load_bert_vocab(None)
        # When we reload, allows skipping restoration of these embeddings
        # If the embedding wasn't trained with token types, this allows us to add them later
        self.skippable = set(listify(kwargs.get('skip_restore_embeddings',
                                                [])))

        self.cls_index = self.vocab.get('[CLS]', self.vocab.get('<s>'))
        self.vsz = max(self.vocab.values()) + 1
        self.d_model = int(kwargs.get('dsz', kwargs.get('d_model', 768)))
        self.init_embed(**kwargs)
        self.proj_to_dsz = pytorch_linear(
            self.dsz, self.d_model) if self.dsz != self.d_model else _identity
        self.init_transformer(**kwargs)
        self.return_mask = kwargs.get('return_mask', False)
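The defining class isn't shown in this snippet, so here is a hedged construction sketch, with `TransformerEmbeddings` as a stand-in name for whichever subclass owns this `__init__`, of the three ways the vocab can be supplied:

# Hypothetical: `TransformerEmbeddings` is a placeholder class name
emb = TransformerEmbeddings(vocab_file='vocab.json')  # MEAD/Baseline JSON vocab via read_config_stream
emb = TransformerEmbeddings(vocab_file='vocab.txt')   # WordPiece file via load_bert_vocab
emb = TransformerEmbeddings()                         # falls back to the load_bert_vocab singleton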
Example #3
def create_transformer_lm(
        config_url: str,
        model_type: str) -> Tuple[TransformerMaskedLanguageModel, int]:
    config = read_config_stream(config_url)
    pdrop = config['attention_probs_dropout_prob']
    activation = config['hidden_act']
    d_model = config['hidden_size']
    d_ff = config['intermediate_size']
    layer_norm_eps = float(config['layer_norm_eps'])
    mxlen = config['max_position_embeddings']
    num_heads = config['num_attention_heads']
    num_layers = config['num_hidden_layers']
    pad = config['pad_token_id']
    if pad != 0 and pad != 1:
        raise Exception(f"Unexpected pad value {pad}")
    tt_vsz = config['type_vocab_size']
    vsz = config['vocab_size']

    if model_type == "bert" or model_type == "roberta":
        embeddings_type = "sum-layer-norm"
        transformer_type = "post-layer-norm"
    else:
        raise Exception(f"We don't support model type {model_type}")

    embeddings = {
        'x':
        LearnedPositionalLookupTableEmbeddings(vsz=vsz,
                                               dsz=d_model,
                                               mxlen=mxlen)
    }
    if model_type == "bert":
        # Only BERT uses token-type ('tt') embeddings
        embeddings['tt'] = LookupTableEmbeddings(vsz=tt_vsz, dsz=d_model)

    model = TransformerMaskedLanguageModel.create(
        embeddings,
        d_model=d_model,
        d_ff=d_ff,
        num_heads=num_heads,
        tgt_key='x',
        num_layers=num_layers,
        embeddings_dropout=pdrop,
        dropout=pdrop,
        activation=activation,
        transformer_type=transformer_type,
        embeddings_reduction=embeddings_type)
    return model, num_layers
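The keys read above (`hidden_size`, `num_attention_heads`, `pad_token_id`, ...) match a Hugging Face BERT-style config.json, so a hedged usage sketch looks like this (the path is illustrative):

# Illustrative: point at a locally downloaded BERT-style config.json
model, num_layers = create_transformer_lm("bert-base-uncased/config.json", "bert")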
Example #4
def get_db_config(cred: Optional[str]) -> Dict:
    """

    :param cred:
    :return:
    """
    if cred:
        cred_params = read_config_stream(cred)['jobs_db']

    else:
        cred_params = {}
        cred_params['backend'] = os.environ.get("ODIN_JOBS_BACKEND",
                                                "postgres")
        cred_params['host'] = os.environ.get("SQL_HOST", "127.0.0.1")
        cred_params['port'] = os.environ.get("DB_PORT", 5432)
        cred_params['user'] = os.environ.get("DB_USER")
        cred_params['passwd'] = os.environ.get("DB_PASS")
    cred_params['db'] = os.environ.get("DB_NAME", "jobs_db")
    return cred_params
Example #5
    def __init__(self, **kwargs):
        super().__init__()
        # You don't actually have to pass this if you are using the `load_bert_vocab` call from your
        # tokenizer.  In this case, a singleton variable will contain the vocab and it will be returned
        # by `load_bert_vocab`
        # If you trained your model with MEAD/Baseline, you will have a `*.json` file which you would
        # want to reference here
        vocab_file = kwargs.get('vocab_file')
        if vocab_file and vocab_file.endswith('.json'):
            self.vocab = read_config_stream(kwargs.get('vocab_file'))
        else:
            self.vocab = load_bert_vocab(kwargs.get('vocab_file'))
        self.cls_index = self.vocab['[CLS]']
        self.vsz = max(self.vocab.values()) + 1
        self.d_model = int(kwargs.get('dsz', kwargs.get('d_model', 768)))
        self.init_embed(**kwargs)
        self.proj_to_dsz = pytorch_linear(
            self.dsz, self.d_model) if self.dsz != self.d_model else _identity
        self.init_transformer(**kwargs)
Example #6
def create_transformer_lm(
        config_url: str) -> Tuple[TransformerMaskedLanguageModel, int]:
    config = read_config_stream(config_url)
    pdrop = config['attention_probs_dropout_prob']
    activation = config['hidden_act']
    d_model = config['hidden_size']
    d_ff = config['intermediate_size']
    layer_norm_eps = float(config['layer_norm_eps'])
    mxlen = config['max_position_embeddings']
    num_heads = config['num_attention_heads']
    num_layers = config['num_hidden_layers']
    pad = config['pad_token_id']
    if pad != 0:
        raise Exception(f"Unexpected pad value {pad}")
    if layer_norm_eps != 1e-12:
        raise Exception(
            f"Expected layer norm to be 1e-12, received {layer_norm_eps}")

    tt_vsz = config['type_vocab_size']
    vsz = config['vocab_size']
    embeddings = {
        'x':
        LearnedPositionalLookupTableEmbeddings(vsz=vsz,
                                               dsz=d_model,
                                               mxlen=mxlen),
        'tt':
        LookupTableEmbeddings(vsz=tt_vsz, dsz=d_model)
    }
    model = TransformerMaskedLanguageModel.create(
        embeddings,
        d_model=d_model,
        d_ff=d_ff,
        num_heads=num_heads,
        tgt_key='x',
        num_layers=num_layers,
        embeddings_dropout=pdrop,
        dropout=pdrop,
        activation=activation,
        layer_norms_after=True,
        embeddings_reduction='sum-layer-norm')
    return model, num_layers
Example #7
def create_transformer_lm_gpt2(
        config_url: str,
        model_type: str) -> Tuple[TransformerLanguageModel, int]:
    config = read_config_stream(config_url)
    pdrop = config['attn_pdrop']
    activation = config['activation_function']
    d_model = config['n_embd']
    d_ff = 4 * d_model
    layer_norm_eps = float(config['layer_norm_epsilon'])
    mxlen = config['n_ctx']
    num_heads = config['n_head']
    num_layers = config['n_layer']
    pad = 0
    if pad != 0 and pad != 1:
        raise Exception(f"Unexpected pad value {pad}")
    vsz = config['vocab_size']
    embeddings_type = "sum"
    transformer_type = "pre-layer-norm"
    embeddings = {
        'x':
        LearnedPositionalLookupTableEmbeddings(vsz=vsz,
                                               dsz=d_model,
                                               mxlen=mxlen)
    }
    model = TransformerLanguageModel.create(
        embeddings,
        d_model=d_model,
        d_ff=d_ff,
        num_heads=num_heads,
        tgt_key='x',
        num_layers=num_layers,
        embeddings_dropout=pdrop,
        dropout=pdrop,
        activation=activation,
        transformer_type=transformer_type,
        embeddings_reduction=embeddings_type,
        layer_norms_after=False,
        layer_norm_eps=layer_norm_eps,
        tie_weights=True)
    return model, num_layers
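A matching hedged sketch for the GPT-2 loader; the call shape is the same, but the config keys differ (`n_embd`, `n_head`, `n_layer` instead of BERT's names):

# Illustrative: point at a locally downloaded GPT-2-style config.json
model, num_layers = create_transformer_lm_gpt2("gpt2/config.json", "gpt2")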
Example #8
def test_read_config_stream_str(gold_data):
    input_ = json.dumps(gold_data)
    data = read_config_stream(input_)
    assert data == gold_data
Example #9
def test_read_config_stream_env(env, gold_data):
    data = read_config_stream(env)
    assert data == gold_data
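The `gold_data` and `env` fixtures are not shown in these snippets. A plausible sketch, assuming `read_config_stream` resolves `$NAME`-style references from the environment (the `--config` default of `$MEAD_CONFIG` in Example #11 suggests this convention):

import json
import pytest

@pytest.fixture
def gold_data():
    return {"a": 1, "b": {"c": 2}}

@pytest.fixture
def env(gold_data, monkeypatch):
    # Export the config as JSON and hand back the $VAR reference
    monkeypatch.setenv("TEST_CONFIG", json.dumps(gold_data))
    return "$TEST_CONFIG"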
Example #10
def test_read_config_stream_file():
    file_name = os.path.join(data_loc, "test_json.json")
    with mock.patch("eight_mile.utils.read_config_file") as read_patch:
        read_config_stream(file_name)
    read_patch.assert_called_once_with(file_name)
Example #11
def main():
    parser = argparse.ArgumentParser(description='Train a text classifier')
    parser.add_argument(
        '--config',
        help=
        'JSON/YML Configuration for an experiment: local file or remote URL',
        type=convert_path,
        default="$MEAD_CONFIG")
    parser.add_argument('--settings',
                        help='JSON/YML Configuration for mead',
                        default=DEFAULT_SETTINGS_LOC,
                        type=convert_path)
    parser.add_argument('--task_modules',
                        help='tasks to load, must be local',
                        default=[],
                        nargs='+',
                        required=False)
    parser.add_argument(
        '--datasets',
        help=
        'index of dataset labels: local file, remote URL or mead-ml/hub ref',
        type=convert_path)
    parser.add_argument(
        '--modules',
        help='modules to load: local files, remote URLs or mead-ml/hub refs',
        default=[],
        nargs='+',
        required=False)
    parser.add_argument('--mod_train_file', help='override the training set')
    parser.add_argument('--mod_valid_file', help='override the validation set')
    parser.add_argument('--mod_test_file', help='override the test set')
    parser.add_argument('--fit_func', help='override the fit function')
    parser.add_argument(
        '--embeddings',
        help='index of embeddings: local file, remote URL or mead-ml/hub ref',
        type=convert_path)
    parser.add_argument(
        '--vecs',
        help='index of vectorizers: local file, remote URL or mead-ml/hub ref',
        type=convert_path)
    parser.add_argument('--logging',
                        help='json file for logging',
                        default=DEFAULT_LOGGING_LOC,
                        type=convert_path)
    parser.add_argument('--task',
                        help='task to run',
                        choices=['classify', 'tagger', 'seq2seq', 'lm'])
    parser.add_argument('--gpus',
                        help='Number of GPUs (defaults to number available)',
                        type=int,
                        default=-1)
    parser.add_argument(
        '--basedir',
        help='Override the base directory where models are stored',
        type=str)
    parser.add_argument('--reporting', help='reporting hooks', nargs='+')
    parser.add_argument('--backend', help='The deep learning backend to use')
    parser.add_argument('--checkpoint',
                        help='Restart training from this checkpoint')
    parser.add_argument(
        '--prefer_eager',
        help="If running in TensorFlow, should we prefer eager model",
        type=str2bool)
    args, overrides = parser.parse_known_args()
    config_params = read_config_stream(args.config)
    config_params = parse_and_merge_overrides(config_params,
                                              overrides,
                                              pre='x')
    if args.basedir is not None:
        config_params['basedir'] = args.basedir

    # task_module overrides are not allowed via hub or HTTP, must be defined locally
    for task in args.task_modules:
        import_user_module(task)

    task_name = config_params.get(
        'task', 'classify') if args.task is None else args.task
    args.logging = read_config_stream(args.logging)
    configure_logger(args.logging,
                     config_params.get('basedir', './{}'.format(task_name)))

    try:
        args.settings = read_config_stream(args.settings)
    except Exception:
        logger.warning(
            'no mead-settings file was found at [{}]'.format(args.settings))
        args.settings = {}

    args.datasets = args.settings.get(
        'datasets', convert_path(
            DEFAULT_DATASETS_LOC)) if args.datasets is None else args.datasets
    args.datasets = read_config_stream(args.datasets)
    if args.mod_train_file or args.mod_valid_file or args.mod_test_file:
        logger.warning(
            'overriding the training/valid/test data with user-specified files'
            ' different from what was specified in the dataset index.  Creating a new key for this entry'
        )
        update_datasets(args.datasets, config_params, args.mod_train_file,
                        args.mod_valid_file, args.mod_test_file)

    args.embeddings = args.settings.get(
        'embeddings', convert_path(DEFAULT_EMBEDDINGS_LOC)
    ) if args.embeddings is None else args.embeddings
    args.embeddings = read_config_stream(args.embeddings)

    args.vecs = args.settings.get('vecs', convert_path(
        DEFAULT_VECTORIZERS_LOC)) if args.vecs is None else args.vecs
    args.vecs = read_config_stream(args.vecs)

    if args.gpus:
        # why does it go to model and not to train?
        config_params['train']['gpus'] = args.gpus
    if args.fit_func:
        config_params['train']['fit_func'] = args.fit_func
    if args.backend:
        config_params['backend'] = normalize_backend(args.backend)

    config_params['modules'] = list(
        set(chain(config_params.get('modules', []), args.modules)))

    cmd_hooks = args.reporting if args.reporting is not None else []
    config_hooks = config_params.get('reporting') if config_params.get(
        'reporting') is not None else []
    reporting = parse_extra_args(set(chain(cmd_hooks, config_hooks)),
                                 overrides)
    config_params['reporting'] = reporting

    logger.info('Task: [{}]'.format(task_name))

    task = mead.Task.get_task_specific(task_name, args.settings)

    task.read_config(config_params,
                     args.datasets,
                     args.vecs,
                     reporting_args=overrides,
                     prefer_eager=args.prefer_eager)
    task.initialize(args.embeddings)
    task.train(args.checkpoint)
Example #12
def main():
    parser = argparse.ArgumentParser(
        description='Encode a sentence as an embedding')
    parser.add_argument('--subword_model_file', help='Subword model file')
    parser.add_argument('--nctx', default=256, type=int)
    parser.add_argument('--batchsz', default=20, type=int)
    parser.add_argument('--vec_id',
                        default='bert-base-uncased',
                        help='Reference to a specific embedding type')
    parser.add_argument('--embed_id',
                        default='bert-base-uncased',
                        help='What type of embeddings to use')
    parser.add_argument('--file', required=True)
    parser.add_argument('--column', type=str)
    parser.add_argument('--output', default='embeddings.npz')
    parser.add_argument(
        '--pool',
        help=
        "Should a reduction be applied on the embeddings?  Only use if your embeddings aren't already pooled",
        type=str)
    parser.add_argument(
        '--embeddings',
        help='index of embeddings: local file, remote URL or mead-ml/hub ref',
        type=convert_path)
    parser.add_argument(
        '--vecs',
        help='index of vectorizers: local file, remote URL or mead-ml/hub ref',
        type=convert_path)
    parser.add_argument('--cuda', type=baseline.str2bool, default=True)
    parser.add_argument('--has_header', action="store_true")
    parser.add_argument(
        "--tokenizer_type",
        type=str,
        help="Optional tokenizer, default is to use string split")
    parser.add_argument(
        '--faiss_index',
        help="If provided, we will build a FAISS index and store it here")
    parser.add_argument(
        '--quoting',
        default=3,
        help='0=QUOTE_MINIMAL 1=QUOTE_ALL 2=QUOTE_NONNUMERIC 3=QUOTE_NONE',
        type=int)
    parser.add_argument('--sep', default='\t')
    parser.add_argument('--add_columns', nargs='+', default=[])

    args = parser.parse_args()

    if not args.has_header:
        if not args.column:
            args.column = 0
        if args.add_columns:
            args.add_columns = [int(c) for c in args.add_columns]
        column = int(args.column)

    else:
        column = args.column

    args.embeddings = convert_path(
        DEFAULT_EMBEDDINGS_LOC) if args.embeddings is None else args.embeddings
    args.embeddings = read_config_stream(args.embeddings)

    args.vecs = convert_path(
        DEFAULT_VECTORIZERS_LOC) if args.vecs is None else args.vecs

    vecs_index = read_config_stream(args.vecs)
    vecs_set = index_by_label(vecs_index)
    vec_params = vecs_set[args.vec_id]
    vec_params['mxlen'] = args.nctx

    if 'transform' in vec_params:
        vec_params['transform_fn'] = vec_params['transform']

    if 'transform_fn' in vec_params and isinstance(vec_params['transform_fn'],
                                                   str):
        vec_params['transform_fn'] = eval(vec_params['transform_fn'])
    tokenizer = create_tokenizer(args.tokenizer_type)
    vectorizer = create_vectorizer(**vec_params)
    if not isinstance(vectorizer, HasPredefinedVocab):
        raise Exception(
            "We currently require a vectorizer with a pre-defined vocab to run this script"
        )
    embeddings_index = read_config_stream(args.embeddings)
    embeddings_set = index_by_label(embeddings_index)
    embeddings_params = embeddings_set[args.embed_id]
    # If they dont want CUDA try and get the embedding loader to use CPU
    embeddings_params['cpu_placement'] = not args.cuda
    embeddings = load_embeddings_overlay(embeddings_set, embeddings_params,
                                         vectorizer.vocab)

    vocabs = {'x': embeddings['vocab']}
    embedder = embeddings['embeddings'].cpu()
    embedder.eval()
    if args.cuda:
        embedder = embedder.cuda()

    def _mean_pool(inputs, embeddings):
        mask = (inputs != 0)
        seq_lengths = mask.sum(1).unsqueeze(-1)
        return embeddings.sum(1) / seq_lengths

    def _zero_tok_pool(_, embeddings):
        pooled = embeddings[:, 0]
        return pooled

    def _max_pool(inputs, embeddings):
        mask = (inputs != 0)
        embeddings = embeddings.masked_fill(~mask.unsqueeze(-1), -1e8)
        return torch.max(embeddings, 1, False)[0]

    if args.pool:
        if args.pool == 'max':
            pool = _max_pool
        elif args.pool == 'zero' or args.pool == 'cls':
            pool = _zero_tok_pool
        else:
            pool = _mean_pool
    else:
        pool = lambda x, y: y

    def chunks(lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    df = pd.read_csv(args.file,
                     header='infer' if args.has_header else None,
                     sep=args.sep)
    col = df[column]
    batches = []
    as_list = col.tolist()
    extra_col_map = {}
    for extra_col in args.add_columns:
        if isinstance(extra_col, int):
            key = f'column_{extra_col}'
        else:
            key = extra_col
        extra_col_map[key] = df[extra_col].tolist()
    num_batches = math.ceil(len(as_list) / args.batchsz)
    pg = baseline.create_progress_bar(num_batches, name='tqdm')
    for i, batch in enumerate(chunks(as_list, args.batchsz)):
        pg.update()
        with torch.no_grad():
            vecs = []
            for line in batch:
                tokenized = tokenizer(line)
                vec, l = vectorizer.run(tokenized, vocabs['x'])
                vecs.append(vec)
            vecs = torch.tensor(np.stack(vecs))
            if args.cuda:
                vecs = vecs.cuda()
            embedding = embedder(vecs)
            pooled_batch = pool(vecs, embedding).cpu().numpy()
            batches += [x for x in pooled_batch]

    np.savez(args.output, embeddings=batches, text=as_list, **extra_col_map)
    if args.faiss_index:
        import faiss
        index = faiss.IndexFlatIP(batches[0].shape[-1])
        batches = np.stack(batches)
        faiss.normalize_L2(batches)
        index.add(batches)
        faiss.write_index(index, args.faiss_index)
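A hedged companion sketch for consuming the script's outputs: reload the saved `.npz` archive and query the optional FAISS index (the file names are whatever was passed as `--output` and `--faiss_index`):

import faiss
import numpy as np

archive = np.load('embeddings.npz')
embeddings, text = archive['embeddings'], archive['text']
index = faiss.read_index('embeddings.faiss')      # path given via --faiss_index
query = embeddings[:1].copy()
faiss.normalize_L2(query)                         # the index stores L2-normalized vectors
scores, neighbors = index.search(query, 5)
print([text[i] for i in neighbors[0]])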
Example #13
parser.add_argument(
    '--embeddings',
    help='index of embeddings: local file, remote URL or mead-ml/hub ref',
    type=convert_path)
parser.add_argument(
    '--vecs',
    help='index of vectorizers: local file, remote URL or mead-ml/hub ref',
    type=convert_path)
parser.add_argument('--cuda', type=baseline.str2bool, default=True)
parser.add_argument('--has_header', type=baseline.str2bool, default=True)
parser.add_argument('--sep', default='\t')

args = parser.parse_args()

args.embeddings = convert_path(
    DEFAULT_EMBEDDINGS_LOC) if args.embeddings is None else args.embeddings
args.embeddings = read_config_stream(args.embeddings)

args.vecs = convert_path(
    DEFAULT_VECTORIZERS_LOC) if args.vecs is None else args.vecs

vecs_index = read_config_stream(args.vecs)
vecs_set = index_by_label(vecs_index)
vec_params = vecs_set[args.vec_id]
vec_params['mxlen'] = args.nctx

if 'transform' in vec_params:
    vec_params['transform_fn'] = vec_params['transform']

if 'transform_fn' in vec_params and isinstance(vec_params['transform_fn'],
                                               str):
    vec_params['transform_fn'] = eval(vec_params['transform_fn'])
Example #14
def main():
    parser = argparse.ArgumentParser(description='Run senteval harness')
    parser.add_argument('--nctx', default=512, type=int)
    parser.add_argument("--module", default=None, help="Module containing custom tokenizers")
    parser.add_argument('--tasks', nargs="+", default=['sts', 'class', 'probe'])
    parser.add_argument('--batchsz', default=20, type=int)
    parser.add_argument('--tok', help='Optional tokenizer, e.g. "gpt2" or "basic". These can be defined in an extra module')
    parser.add_argument('--pool', help="Should a reduction be applied on the embeddings?  Only use if your embeddings aren't already pooled", type=str)
    parser.add_argument('--vec_id', help='Reference to a specific embedding type')
    parser.add_argument('--embed_id', help='What type of embeddings to use')
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    parser.add_argument('--max_len1d', type=int, default=100)
    parser.add_argument('--embeddings', help='index of embeddings: local file, remote URL or mead-ml/hub ref', type=convert_path)
    parser.add_argument('--vecs', help='index of vectorizers: local file, remote URL or mead-ml/hub ref', type=convert_path)
    parser.add_argument('--fast', help="Run fast, but not necessarily as accurate", action='store_true')
    parser.add_argument('--data', help="Path to senteval data",
                        default=os.path.expanduser("~/dev/work/SentEval/data"))
    args = parser.parse_args()

    if args.module:
        logger.warning("Loading custom user module %s for masking rules and tokenizers", args.module)
        baseline.import_user_module(args.module)


    tokenizer = create_tokenizer(args.tok) if args.tok else None

    args.embeddings = convert_path(DEFAULT_EMBEDDINGS_LOC) if args.embeddings is None else args.embeddings
    args.embeddings = read_config_stream(args.embeddings)

    args.vecs = convert_path(DEFAULT_VECTORIZERS_LOC) if args.vecs is None else args.vecs

    vecs_index = read_config_stream(args.vecs)
    vecs_set = index_by_label(vecs_index)
    vec_params = vecs_set[args.vec_id]
    vec_params['mxlen'] = args.nctx

    if 'transform' in vec_params:
        vec_params['transform_fn'] = vec_params['transform']

    if 'transform_fn' in vec_params and isinstance(vec_params['transform_fn'], str):
        vec_params['transform_fn'] = eval(vec_params['transform_fn'])

    vectorizer = create_vectorizer(**vec_params)
    if not isinstance(vectorizer, HasPredefinedVocab):
        raise Exception("We currently require a vectorizer with a pre-defined vocab to run this script")
    embeddings_index = read_config_stream(args.embeddings)
    embeddings_set = index_by_label(embeddings_index)
    embeddings_params = embeddings_set[args.embed_id]
    embeddings = load_embeddings_overlay(embeddings_set, embeddings_params, vectorizer.vocab)

    embedder = embeddings['embeddings']
    embedder.to(args.device).eval()

    def _mean_pool(inputs, embeddings):
        mask = (inputs != 0)
        seq_lengths = mask.sum(1).unsqueeze(-1)
        return embeddings.sum(1)/seq_lengths

    def _zero_tok_pool(_, embeddings):
        pooled = embeddings[:, 0]
        return pooled

    def _max_pool(inputs, embeddings):
        mask = (inputs != 0)
        embeddings = embeddings.masked_fill(~mask.unsqueeze(-1), -1e8)
        return torch.max(embeddings, 1, False)[0]

    if args.pool:
        if args.pool == 'max':
            pool = _max_pool
        elif args.pool == 'zero' or args.pool == 'cls':
            pool = _zero_tok_pool
        else:
            pool = _mean_pool
    else:
        pool = lambda x, y: y

    params_senteval = {'task_path': args.data, 'usepytorch': True, 'kfold': 10}
    params_senteval['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
                                     'tenacity': 5, 'epoch_size': 4}
    if args.fast:
        logging.info("Setting fast params")
        params_senteval['kfold'] = 5
        params_senteval['classifier']['epoch_size'] = 2
        params_senteval['classifier']['tenacity'] = 3
        params_senteval['classifier']['batch_size'] = 128

    # SentEval prepare and batcher
    def prepare(params, samples):
        max_sample = max(len(s) for s in samples)
        vectorizer.mxlen = min(args.nctx, max_sample + SUBWORD_EXTRA)
        logging.info('num_samples %d, mxlen set to %d', max_sample, vectorizer.mxlen)

    def batcher(params, batch):
        if not tokenizer:
            batch = [sent if sent != [] else ['.'] for sent in batch]
        else:
            batch = [tokenizer(' '.join(sent)) for sent in batch]

        vs = []
        for sent in batch:
            v, l = vectorizer.run(sent, vectorizer.vocab)
            vs.append(v)
        vs = np.stack(vs)
        with torch.no_grad():
            inputs = torch.tensor(vs, device=args.device)
            encoding = embedder(inputs)
            encoding = pool(inputs, encoding)
            encoding = encoding.cpu().numpy()
        return encoding

    se = senteval.engine.SE(params_senteval, batcher, prepare)
    transfer_tasks = []
    if 'sts' in args.tasks:
        transfer_tasks += ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'SICKRelatedness', 'STSBenchmark']
    if 'class' in args.tasks:
        transfer_tasks += ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
                           'SICKEntailment']
    if 'probe' in args.tasks:
        transfer_tasks += ['Length', 'WordContent', 'Depth', 'TopConstituents',
                           'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
                           'OddManOut', 'CoordinationInversion']

    results = se.eval(transfer_tasks)
    print(results)
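The three pooling strategies shared by Examples #12 and #14 are easy to sanity-check on a toy batch; a minimal sketch of the shapes involved (values are random):

import torch

inputs = torch.tensor([[5, 9, 0, 0]])       # one sequence; id 0 marks padding
embeddings = torch.randn(1, 4, 8)           # [batch, time, d_model]

mask = (inputs != 0)
mean_pooled = embeddings.sum(1) / mask.sum(1).unsqueeze(-1)                 # _mean_pool
cls_pooled = embeddings[:, 0]                                               # _zero_tok_pool
max_pooled = embeddings.masked_fill(~mask.unsqueeze(-1), -1e8).max(1)[0]    # _max_pool
assert mean_pooled.shape == cls_pooled.shape == max_pooled.shape == (1, 8)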