Example #1
def load_model():
    model_save_path = constant.load_model_path
    state = torch.load(model_save_path,
                       map_location=lambda storage, location: storage)
    constant.arg = state['config']
    load_settings()

    if constant.model == "LSTM":
        model = LstmModel(vocab=vocab,
                          embedding_size=constant.emb_dim,
                          hidden_size=constant.hidden_dim,
                          num_layers=constant.n_layers,
                          is_bidirectional=constant.bidirec,
                          input_dropout=constant.drop,
                          layer_dropout=constant.drop,
                          attentive=constant.attn)
    elif constant.model == "UTRS":
        model = UTransformer(vocab=vocab,
                             embedding_size=constant.emb_dim,
                             hidden_size=constant.hidden_dim,
                             num_layers=constant.hop,
                             num_heads=constant.heads,
                             total_key_depth=constant.depth,
                             total_value_depth=constant.depth,
                             filter_size=constant.filter,
                             act=constant.act)
    elif constant.model == "ELMO":
        model = ELMoEncoder(C=4)
    else:
        print("Model is not defined")
        exit(0)

    # load_state_dict copies the weights in place; it does not return the model
    model.load_state_dict(state['model'])
    return model
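A note on the fix above: nn.Module.load_state_dict copies the weights into the module in place and returns a report of missing/unexpected keys (or None in older PyTorch versions), not the module itself, so reassigning its result would lose the model object. A minimal self-contained check, independent of the project's constant module and vocab:

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
result = layer.load_state_dict(layer.state_dict())

# `result` is a keys report, not the module, which is why
# `model = model.load_state_dict(...)` would discard the model.
print(type(layer))   # <class 'torch.nn.modules.linear.Linear'>
print(result)        # keys report; exact repr varies by PyTorch version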
Example #2
def load_model():
    model_load_path = constant.load_model_path
    model_save_path = constant.save_path
    state = torch.load(model_load_path,
                       map_location=lambda storage, location: storage)
    arg = state['config']
    load_settings(arg)

    data_loaders_train, data_loaders_val, data_loaders_test, vocab = prepare_data_loaders(
        num_split=1,
        batch_size=constant.batch_size,
        hier=False,
        elmo=constant.elmo,
        dev_with_label=False,
        include_test=True)

    if constant.model == "LSTM":
        model = LstmModel(vocab=vocab,
                          embedding_size=constant.emb_dim,
                          hidden_size=constant.hidden_dim,
                          num_layers=constant.n_layers,
                          is_bidirectional=constant.bidirec,
                          input_dropout=constant.drop,
                          layer_dropout=constant.drop,
                          attentive=constant.attn)
    elif constant.model == "UTRS":
        model = UTransformer(vocab=vocab,
                             embedding_size=constant.emb_dim,
                             hidden_size=constant.hidden_dim,
                             num_layers=constant.hop,
                             num_heads=constant.heads,
                             total_key_depth=constant.depth,
                             total_value_depth=constant.depth,
                             filter_size=constant.filter,
                             act=constant.act)
    elif constant.model == "ELMO":
        model = ELMoEncoder(C=4)
    else:
        print("Model is not defined")
        exit(0)

    model.load_state_dict(state['model'])
    return model, data_loaders_test, vocab, model_save_path
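The map_location=lambda storage, location: storage argument used in both loaders forces every tensor in the checkpoint onto CPU, so a checkpoint saved from a GPU run can still be restored on a CPU-only machine. A minimal self-contained sketch of the pattern; the file name and the tiny Linear model are placeholders, not part of the project:

import torch
import torch.nn as nn

# Save a throwaway checkpoint in the same {'model': ..., 'config': ...} layout.
net = nn.Linear(4, 2)
torch.save({'model': net.state_dict(), 'config': {'hidden_dim': 2}}, 'tmp_checkpoint.pt')

# Restore it, mapping all storages to CPU regardless of where they were saved.
state = torch.load('tmp_checkpoint.pt',
                   map_location=lambda storage, location: storage)
net.load_state_dict(state['model'])
print(state['config'])   # {'hidden_dim': 2}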
Example #3
import csv

import torch
from flask import Flask

#from models import Zilpzalp
#from models import Hawk
from models import LstmModel
from utils import avg_score, maxwindow_score, get_top5_prediction

app = Flask(__name__)

# Instantiate the model
model = LstmModel(time_axis=216, freq_axis=256, no_classes=100)

# Load the model weights from the checkpoint
checkpoint_path = 'model/checkpoint_Lstm_29-03'
checkpoint = torch.load(checkpoint_path, map_location='cpu')
state = checkpoint['state_dict']
model.load_state_dict(state)
model.eval()

# Build the lookup dictionary with the species info
label_dict = {}
reader = csv.DictReader(open('model/top100_codes_translated.csv'))
for row in reader:
    label_dict[int(row['id1'])] = {
        'name': row['english'],
        'img_source': row['img_source'],
        'img_link': row['img_link'],
        'wiki_link': row['wiki_link'],
    }


@app.route("/")