def create_model(embeddings, d_model, d_ff, num_heads, num_layers, rpr_k, rpr_value_on, d_k, checkpoint_name, activation):
    # Normalize rpr_k: None disables relative position representations, a
    # scalar applies one window size to every layer, a list sets it per layer
    rpr_k = listify(rpr_k)
    if len(rpr_k) == 0 or rpr_k[0] < 1:
        rpr_k = None
    elif len(rpr_k) == 1:
        rpr_k = rpr_k[0]

    logger.info("Creating tied encoder decoder model")
    model = TransformerLanguageModel.create({'x': embeddings},
                                            hsz=d_model,
                                            d_ff=d_ff,
                                            tie_weights=True,
                                            dropout=0,
                                            gpu=False,
                                            num_heads=num_heads,
                                            layers=num_layers,
                                            rpr_k=rpr_k,
                                            rpr_value_on=rpr_value_on,
                                            d_k=d_k,
                                            activation=activation,
                                            src_keys=['x'],
                                            tgt_key='x')
    # Restore weights from either an NPZ checkpoint or a PyTorch state dict
    if checkpoint_name.endswith('npz'):
        load_tlm_npz(model, checkpoint_name)
    else:
        tlm_load_state_dict(model, checkpoint_name)
    model.eval()
    print(model)
    return model
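
# A minimal call sketch for create_model. The hyperparameter values and the
# checkpoint path below are illustrative assumptions, not values taken from a
# real run; `embeddings` is a pre-built baseline embeddings object:
#
#   lm = create_model(embeddings,
#                     d_model=512, d_ff=2048, num_heads=8, num_layers=8,
#                     rpr_k=[8], rpr_value_on=True, d_k=None,
#                     checkpoint_name='tlm-step-100000.npz',
#                     activation='gelu')
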
def load(cls, embeddings, **kwargs): c = cls("tlm-words-embed-pooled-output", **kwargs) if embeddings.endswith('.bin'): # HuggingFace checkpoint, convert on the fly from eight_mile.pytorch.serialize import load_tlm_transformers_bin, BERT_HF_FT_LAYER_MAP unmatch = load_tlm_transformers_bin( c, embeddings, replace_layers=BERT_HF_FT_LAYER_MAP) if unmatch['missing'] or unmatch['unexpected']: raise Exception("Unable to load the HuggingFace checkpoint") if mime_type(embeddings ) == 'application/zip' and not embeddings.endswith("pth"): keys_to_restore = set(list(c.embeddings.keys())) filtered_keys = keys_to_restore.difference(c.skippable) if not keys_to_restore: raise Exception("No keys to restore!") if len(filtered_keys) < len(keys_to_restore): logger.warning("Restoring only key [%s]", ' '.join(filtered_keys)) load_tlm_output_npz(c, embeddings, filtered_keys) else: map_location = 'cpu' if kwargs.get('cpu_placement') else None tlm_load_state_dict(c, embeddings, str_map={ 'model.embeddings.embeddings.0.': '', 'model.output_layer': 'output_layer' }, map_location=map_location) return c
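
# The str_map above renames checkpoint keys into this module's namespace as
# the state dict is restored; the concrete key names here are illustrative
# only, not taken from a real checkpoint:
#
#   'model.embeddings.embeddings.0.proj.weight' -> 'proj.weight'
#   'model.output_layer.weight'                 -> 'output_layer.weight'
#
# Passing cpu_placement=True forces map_location='cpu', which lets a
# checkpoint saved on GPU be restored on a CPU-only host.
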
def load(cls, embeddings, **kwargs): c = cls("tlm-words-embed", **kwargs) if embeddings.endswith('.bin'): # HuggingFace checkpoint, convert on the fly from eight_mile.pytorch.serialize import load_tlm_transformers_bin, BERT_HF_FT_LAYER_MAP unmatch = load_tlm_transformers_bin( c, embeddings, replace_layers=BERT_HF_FT_LAYER_MAP) if unmatch['missing'] or unmatch['unexpected']: raise Exception("Unable to load the HuggingFace checkpoint") if mime_type(embeddings) == 'application/zip': load_tlm_npz(c, embeddings) else: tlm_load_state_dict(c, embeddings) return c
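
# Usage sketch for the load classmethods above. The class name here is
# hypothetical; in the repo each load is a @classmethod on a registered
# embeddings model class, and the loader is dispatched on the file type:
#
#   emb = TLMWordsEmbedModel.load('pytorch_model.bin')   # HF checkpoint, converted
#   emb = TLMWordsEmbedModel.load('tlm.npz')             # NPZ checkpoint
#   emb = TLMWordsEmbedModel.load('tlm.pth')             # PyTorch state dict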