# Example No. 1
def make_tsv(
    config: ConfigSchema,
    checkpoint: str,
    entities_by_type: Dict[str, List[str]],
    relation_types: List[str],
    entities_tf: TextIO,
    relation_types_tf: TextIO,
) -> None:
    """Write entity and relation-type TSV exports for a trained model.

    The model is built from ``config``. If the checkpoint yields saved model
    weights, they are loaded non-strictly before anything is written.

    Args:
        config: configuration handed to ``make_model``.
        checkpoint: value passed to ``CheckpointManager`` (presumably the
            checkpoint directory path -- confirm against callers).
        entities_by_type: mapping from entity type to a list of entity
            names, forwarded unchanged to ``make_tsv_for_entities``.
        relation_types: relation-type names, forwarded unchanged to
            ``make_tsv_for_relation_types``.
        entities_tf: text stream that receives the entity TSV.
        relation_types_tf: text stream that receives the relation-type TSV.
    """
    print("Initializing model...")
    embedding_model = make_model(config)

    print("Loading model check point...")
    manager = CheckpointManager(checkpoint)
    weights, _ = manager.read_model()
    if weights is not None:
        # strict=False: tolerate key mismatches between checkpoint and model.
        embedding_model.load_state_dict(weights, strict=False)

    make_tsv_for_entities(
        embedding_model, manager, entities_by_type, entities_tf)
    make_tsv_for_relation_types(
        embedding_model, relation_types, relation_types_tf)
# Example No. 2
def make_tsv(
    config: ConfigSchema,
    entities_tf: TextIO,
    relation_types_tf: TextIO,
) -> None:
    """Export a trained model's entities and relation types as TSV.

    Entity and relation-type storages are both opened from
    ``config.entity_path``. The model is built from ``config``, and weights
    from ``config.checkpoint_path`` are loaded non-strictly when present.
    The two TSV streams are then written.

    Args:
        config: configuration supplying ``entity_path``, ``checkpoint_path``
            and whatever ``make_model`` reads.
        entities_tf: text stream that receives the entity TSV.
        relation_types_tf: text stream that receives the relation-type TSV.
    """
    print("Loading relation types and entities...")
    storage_root = config.entity_path
    entities = ENTITY_STORAGES.make_instance(storage_root)
    rel_types = RELATION_TYPE_STORAGES.make_instance(storage_root)

    print("Initializing model...")
    net = make_model(config)

    print("Loading model check point...")
    ckpt = CheckpointManager(config.checkpoint_path)
    saved_state, _ = ckpt.read_model()
    if saved_state is not None:
        # strict=False: tolerate key mismatches between checkpoint and model.
        net.load_state_dict(saved_state, strict=False)

    make_tsv_for_entities(net, ckpt, entities, entities_tf)
    make_tsv_for_relation_types(net, rel_types, relation_types_tf)
# Example No. 3
def make_tsv(config: ConfigSchema, entities_tf: TextIO,
             relation_types_tf: TextIO) -> None:
    """Export a trained model's entities and relation types as TSV.

    Entity and relation-type storages are opened from ``config.entity_path``.
    The model is built from ``config``, and weights stored under
    ``config.checkpoint_path`` are loaded non-strictly when present. The
    entity TSV is then written to ``entities_tf``.

    The relation-type TSV is written to ``relation_types_tf`` unless any
    configured relation uses the ``'linear'`` operator. In that case it is
    skipped with a warning, because (per the original author's note)
    ``make_tsv_for_relation_types`` raises an error for that operator.

    Args:
        config: configuration supplying ``entity_path``, ``checkpoint_path``,
            ``relations`` and whatever ``make_model`` reads.
        entities_tf: text stream that receives the entity TSV.
        relation_types_tf: text stream that receives the relation-type TSV;
            not written to when the export is skipped.
    """
    logging.info("Loading relation types and entities...")
    entity_storage = ENTITY_STORAGES.make_instance(config.entity_path)
    relation_type_storage = RELATION_TYPE_STORAGES.make_instance(
        config.entity_path)

    logging.info("Initializing model...")
    model = make_model(config)

    logging.info("Loading model check point...")
    checkpoint_manager = CheckpointManager(config.checkpoint_path)
    state_dict, _ = checkpoint_manager.read_model()
    if state_dict is not None:
        # strict=False: tolerate key mismatches between checkpoint and model.
        model.load_state_dict(state_dict, strict=False)

    make_tsv_for_entities(model, checkpoint_manager, entity_storage,
                          entities_tf)

    # Check every relation, not just relations[0]: a 'linear' operator on any
    # relation would make the relation-type export fail. any() also avoids an
    # IndexError when no relations are configured.
    has_linear_operator = any(
        relation.operator == 'linear' for relation in config.relations)
    if has_linear_operator:
        logging.warning(
            "Skipping relation-type TSV export: the 'linear' operator is "
            "not supported by make_tsv_for_relation_types.")
        return
    make_tsv_for_relation_types(model, relation_type_storage,
                                relation_types_tf)