Example #1
def create_model(embeddings, d_model, d_ff, num_heads, num_layers, rpr_k, d_k,
                 activation, checkpoint_name, device):
    # Normalize the relative position representation window(s); disable if unset or < 1
    if len(rpr_k) == 0 or rpr_k[0] < 1:
        rpr_k = [None]
    else:
        rpr_k = listify(rpr_k)
    logger.info("Creating tied encoder decoder model")
    hps = {
        "dsz": d_model,
        "hsz": d_model,
        "d_ff": d_ff,
        "dropout": 0.0,
        "num_heads": num_heads,
        "layers": num_layers,
        "encoder_type": "transformer",
        "decoder_type": "transformer",
        "src_lengths_key": "x_lengths",
        "d_k": d_k,
        "activation": activation,
        "rpr_k": rpr_k
    }
    model = TiedEmbeddingsSeq2SeqModel({'x': embeddings}, None, **hps)
    # Load weights either from an .npz checkpoint or from a PyTorch state dict
    if checkpoint_name.endswith('npz'):
        load_transformer_seq2seq_npz(model, checkpoint_name)
    else:
        model.load_state_dict(
            torch.load(checkpoint_name, map_location=torch.device(device)))
    print(model)
    return model
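A minimal usage sketch for create_model above, assuming the surrounding script has already built an embeddings object; the hyperparameter values and checkpoint path below are hypothetical placeholders, not values from the original source:

# Hypothetical call to create_model; swap in real embeddings and a real checkpoint path.
model = create_model(
    embeddings,                    # pre-built embeddings object (placeholder)
    d_model=512,
    d_ff=2048,
    num_heads=8,
    num_layers=6,
    rpr_k=[8],                     # relative position window; [] or [0] disables it
    d_k=None,
    activation='relu',
    checkpoint_name='checkpoint-step-100000.npz',  # .npz goes through load_transformer_seq2seq_npz
    device='cpu',
)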
Example #2
    @classmethod
    def create(cls, src_embeddings, tgt_embedding, **kwargs):
        model = cls(src_embeddings, tgt_embedding, **kwargs)
        checkpoint_name = kwargs.get('checkpoint')
        if checkpoint_name is not None:
            # Load weights either from an .npz checkpoint or from a PyTorch state dict
            if checkpoint_name.endswith('npz'):
                load_transformer_seq2seq_npz(model, checkpoint_name)
            else:
                model.load_state_dict(torch.load(checkpoint_name))
        logger.info(model)
        return model
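A usage sketch for the create classmethod in Example #2, assuming it is defined on a seq2seq model class such as the TiedEmbeddingsSeq2SeqModel from Example #1; the embeddings objects and the checkpoint path are placeholders:

# Hypothetical call through the owning class; when 'checkpoint' is supplied the
# weights are loaded right after construction, as shown in the method above.
model = TiedEmbeddingsSeq2SeqModel.create(
    {'x': src_embeddings},         # source embeddings dict (placeholder)
    tgt_embedding,                 # target embeddings (placeholder)
    hsz=512,
    checkpoint='checkpoint-step-100000.npz',
)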
def reload_from_checkpoint(model_type, restart_from, restart_tick_type, model,
                           steps_per_epoch):
    if os.path.isdir(restart_from):
        restart_from, _ = find_latest_checkpoint(restart_from)
        print(f'Latest checkpoint: {restart_from}')
    vec = restart_from.split("-")
    try:
        step_num = int(vec[-1].split(".")[0])
    except ValueError:
        # The checkpoint file name does not end in a step/epoch number
        step_num = 0
    start_epoch = 0
    if restart_tick_type:
        tick_type = restart_tick_type
    else:
        tick_type = vec[-2]
    if restart_from.endswith('.npz'):
        # If it's a seq2seq, load either from a seq2seq checkpoint or from a TLM encoder
        if model_type == 'encoder-decoder':
            try:
                load_transformer_seq2seq_npz(model, restart_from)
            except Exception:
                print(
                    'Model file not recognized as seq2seq model, attempting to load as LM for encoder, reset step'
                )
                load_seq2seq_enc_from_tlm_npz(model, restart_from)
                step_num = 0
                tick_type = 'ignore'
        else:
            try:
                load_transformer_de_npz(model, restart_from)
            except Exception:
                # If it's a dual encoder with model.transformer and model.embeddings, we can
                # load the encoder directly from a Transformer Language Model
                print(
                    'Model file not recognized as a dual encoder model, attempting to load as LM for encoder, reset step'
                )
                load_tlm_npz(model, restart_from)
                step_num = 0
                tick_type = 'ignore'

    else:
        model.load_state_dict(torch.load(restart_from))
    if tick_type == 'epoch':
        start_epoch = step_num
        step_num = start_epoch * steps_per_epoch

    elif tick_type == 'step':
        start_epoch = step_num // steps_per_epoch
    else:
        logger.warning(
            f"The previous tick was {step_num} but command-line specifies to ignore, setting to 0"
        )
        step_num = 0
        start_epoch = 0
    return step_num, start_epoch
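A usage sketch for reload_from_checkpoint with hypothetical paths and sizes. For the step/epoch arithmetic at the end of the function: a step-tick checkpoint named, say, checkpoint-step-2500.pth with steps_per_epoch=1000 resumes at step_num=2500 and start_epoch=2.

# Hypothetical restart of a dual-encoder run; the directory, model object, and
# steps_per_epoch are placeholders.
step_num, start_epoch = reload_from_checkpoint(
    model_type='dual-encoder',     # anything other than 'encoder-decoder'
    restart_from='./checkpoints',  # a directory: the latest checkpoint inside is used
    restart_tick_type=None,        # infer 'step'/'epoch' from the checkpoint file name
    model=model,                   # an already-constructed model (placeholder)
    steps_per_epoch=1000,
)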