Example #1
import logging

# Imports assume the mead-baseline / eight_mile packages these snippets come from.
from eight_mile.utils import listify
from eight_mile.pytorch.serialize import load_tlm_npz, tlm_load_state_dict
from baseline.pytorch.lm import TransformerLanguageModel

logger = logging.getLogger(__name__)


def create_model(embeddings, d_model, d_ff, num_heads, num_layers, rpr_k, rpr_value_on, d_k, checkpoint_name, activation):
    # Normalize the relative-position window(s): an empty or non-positive
    # value disables relative attention; a single value is unwrapped.
    rpr_k = listify(rpr_k)
    if len(rpr_k) == 0 or rpr_k[0] < 1:
        rpr_k = None
    elif len(rpr_k) == 1:
        rpr_k = rpr_k[0]

    logger.info("Creating tied encoder decoder model")
    model = TransformerLanguageModel.create({'x': embeddings},
                                            hsz=d_model,
                                            d_ff=d_ff,
                                            tie_weights=True,
                                            dropout=0,
                                            gpu=False,
                                            num_heads=num_heads,
                                            layers=num_layers,
                                            rpr_k=rpr_k,
                                            rpr_value_on=rpr_value_on,
                                            d_k=d_k,
                                            activation=activation,
                                            src_keys=['x'], tgt_key='x')
    # Restore pretrained weights: .npz archives use the numpy loader,
    # anything else is treated as a PyTorch state-dict checkpoint.
    if checkpoint_name.endswith('npz'):
        load_tlm_npz(model, checkpoint_name)
    else:
        tlm_load_state_dict(model, checkpoint_name)
    model.eval()
    print(model)
    return model
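
To make the entry point concrete, here is a minimal usage sketch. The embeddings object, every hyper-parameter value, and the checkpoint path are illustrative placeholders, not values taken from the source.

# Illustrative only: `embed` stands for an embeddings object built
# elsewhere with mead-baseline; all values below are placeholders.
model = create_model(embed,
                     d_model=512,
                     d_ff=2048,
                     num_heads=8,
                     num_layers=8,
                     rpr_k=[8],           # a single window is unwrapped to a scalar
                     rpr_value_on=False,
                     d_k=None,
                     checkpoint_name='checkpoint.npz',  # .npz -> load_tlm_npz
                     activation='gelu')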
Example #2
    @classmethod
    def load(cls, embeddings, **kwargs):
        c = cls("tlm-words-embed-pooled-output", **kwargs)

        if embeddings.endswith('.bin'):
            # HuggingFace checkpoint, convert on the fly
            from eight_mile.pytorch.serialize import load_tlm_transformers_bin, BERT_HF_FT_LAYER_MAP
            unmatch = load_tlm_transformers_bin(
                c, embeddings, replace_layers=BERT_HF_FT_LAYER_MAP)
            if unmatch['missing'] or unmatch['unexpected']:
                raise Exception("Unable to load the HuggingFace checkpoint")
        # An .npz archive is detected by its zip MIME type; explicit .pth
        # checkpoints skip to the state-dict branch below.
        elif mime_type(embeddings) == 'application/zip' and not embeddings.endswith("pth"):
            keys_to_restore = set(list(c.embeddings.keys()))
            filtered_keys = keys_to_restore.difference(c.skippable)
            if not keys_to_restore:
                raise Exception("No keys to restore!")
            if len(filtered_keys) < len(keys_to_restore):
                logger.warning("Restoring only keys [%s]", ' '.join(filtered_keys))
            load_tlm_output_npz(c, embeddings, filtered_keys)
        else:
            map_location = 'cpu' if kwargs.get('cpu_placement') else None
            tlm_load_state_dict(c,
                                embeddings,
                                str_map={
                                    'model.embeddings.embeddings.0.': '',
                                    'model.output_layer': 'output_layer'
                                },
                                map_location=map_location)
        return c
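
A hedged usage sketch of the three-way dispatch above; `PooledTLMEmbeddings` is a hypothetical name for the class that defines this method, and the paths only illustrate which branch each checkpoint format reaches.

# `PooledTLMEmbeddings` is hypothetical; the paths show the branch taken.
e1 = PooledTLMEmbeddings.load('pytorch_model.bin')   # HuggingFace conversion
e2 = PooledTLMEmbeddings.load('checkpoint.npz')      # zip MIME type -> npz loader
e3 = PooledTLMEmbeddings.load('checkpoint.pth', cpu_placement=True)  # state dict on CPU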
Example #3
    @classmethod
    def load(cls, embeddings, **kwargs):
        c = cls("tlm-words-embed", **kwargs)

        if embeddings.endswith('.bin'):
            # HuggingFace checkpoint, convert on the fly
            from eight_mile.pytorch.serialize import load_tlm_transformers_bin, BERT_HF_FT_LAYER_MAP
            unmatch = load_tlm_transformers_bin(
                c, embeddings, replace_layers=BERT_HF_FT_LAYER_MAP)
            if unmatch['missing'] or unmatch['unexpected']:
                raise Exception("Unable to load the HuggingFace checkpoint")
        # Same dispatch as above, minus the key filtering: zip-typed .npz
        # archives go to load_tlm_npz, everything else is a state dict.
        elif mime_type(embeddings) == 'application/zip':
            load_tlm_npz(c, embeddings)
        else:
            tlm_load_state_dict(c, embeddings)
        return c
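
This variant differs from Example #2 only in the constructor tag ("tlm-words-embed" instead of the pooled-output one) and in restoring all keys on the .npz path rather than a filtered subset. A hedged sketch, again with a hypothetical class name:

# `TLMWordsEmbeddings` is hypothetical; same dispatch, no key filtering.
embed = TLMWordsEmbeddings.load('checkpoint.npz')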