Example #1
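Loads a previously trained model back from its serialization directory: the saved config.json reconstructs the architecture, and the TensorFlow checkpoint restores the weights.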
import json
import os

from tensorflow.keras import models  # assumption: `models` here is tf.keras.models


def load_pretrained_model(serialization_dir: str) -> models.Model:
    """
    Given a serialization directory, return the model loaded with its
    pretrained weights.
    """

    # Load Config
    config_path = os.path.join(serialization_dir, "config.json")
    model_path = os.path.join(serialization_dir, "model.ckpt.index")

    model_files_present = all(
        os.path.exists(path) for path in [config_path, model_path])
    if not model_files_present:
        raise Exception(
            f"Model files in serialization_dir ({serialization_dir}) "
            f"are missing. Cannot load the model.")

    # TensorFlow checkpoints are addressed by their prefix, so strip ".index".
    model_path = model_path.replace(".index", "")

    with open(config_path, "r") as file:
        config = json.load(file)

    # Load Model
    model_name = config.pop("type")
    if model_name == "basic":
        from model import MyBasicAttentiveBiGRU  # To prevent circular imports
        model = MyBasicAttentiveBiGRU(**config)
    elif model_name == "advanced":
        from model import MyAdvancedModel  # To prevent circular imports
        model = MyAdvancedModel(**config)
    else:
        raise Exception(f"model_name: {model_name} is not supported.")
    model.load_weights(model_path)

    return model
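A minimal usage sketch for the loader; the `serialization_dirs/basic` path is the one written by the training excerpt in Example #2:

model = load_pretrained_model(os.path.join("serialization_dirs", "basic"))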
Example #2
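This excerpt comes from the middle of a training script's main routine; read_instances, build_vocabulary, index_instances, load_glove_embeddings, train, and the MAX_TOKENS, VOCAB_SIZE, and GLOVE_COMMON_WORDS_PATH constants are assumed to be defined elsewhere in the project, with tensorflow imported as tf and optimizers taken from tf.keras.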
    train_instances = read_instances(args.data_file, MAX_TOKENS)
    print(f"\nReading Val Instances")
    val_instances = read_instances(args.val_file, MAX_TOKENS)

    with open(GLOVE_COMMON_WORDS_PATH) as file:
        glove_common_words = [line.strip() for line in file if line.strip()]

    vocab_token_to_id, vocab_id_to_token = build_vocabulary(train_instances, VOCAB_SIZE,
                                                            glove_common_words)

    train_instances = index_instances(train_instances, vocab_token_to_id)
    val_instances = index_instances(val_instances, vocab_token_to_id)

    vocab_size = len(vocab_token_to_id)
    config = {
        'vocab_size': vocab_size,
        'embed_dim': args.embed_dim,
        'training': True,
        'hidden_size': args.hidden_size,
    }
    model = MyBasicAttentiveBiGRU(**config)
    # Record the model type only after construction, so it is not passed to
    # the constructor but is still saved for load_pretrained_model.
    config['type'] = 'basic'

    optimizer = optimizers.Adam()

    embeddings = load_glove_embeddings(args.embed_file, args.embed_dim, vocab_id_to_token)
    # Initialize the model's embedding matrix with the pretrained GloVe vectors.
    model.embeddings.assign(tf.convert_to_tensor(embeddings))

    save_serialization_dir = os.path.join('serialization_dirs', 'basic')
    os.makedirs(save_serialization_dir, exist_ok=True)

    train_output = train(model, optimizer, train_instances, val_instances,
                         args.epochs, args.batch_size, save_serialization_dir)

    config_path = os.path.join(save_serialization_dir, "config.json")
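The excerpt ends just before the config is written out. A plausible continuation, assuming only the standard json module, would persist the dict so that load_pretrained_model in Example #1 can rebuild the model:

    # Hypothetical continuation: persist the config so load_pretrained_model
    # (Example #1) can rebuild the architecture before restoring the weights.
    with open(config_path, "w") as file:
        json.dump(config, file, indent=4)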