def test_uses_named_inputs(self): inputs = { "question": "What kind of test succeeded on its first attempt?", "passage": "One time I was writing a unit test, and it succeeded on the first attempt." } archive = load_archive(self.FIXTURES_ROOT / 'bidaf' / 'serialization' / 'model.tar.gz') predictor = Predictor.from_archive(archive, 'machine-comprehension') result = predictor.predict_json(inputs) best_span = result.get("best_span") assert best_span is not None assert isinstance(best_span, list) assert len(best_span) == 2 assert all(isinstance(x, int) for x in best_span) assert best_span[0] <= best_span[1] best_span_str = result.get("best_span_str") assert isinstance(best_span_str, str) assert best_span_str != "" for probs_key in ("span_start_probs", "span_end_probs"): probs = result.get(probs_key) assert probs is not None assert all(isinstance(x, float) for x in probs) assert sum(probs) == approx(1.0)
def test_uses_named_inputs(self): inputs = { "document": "This is a single string document about a test. Sometimes it " "contains coreferent parts." } archive = load_archive(self.FIXTURES_ROOT / 'coref' / 'serialization' / 'model.tar.gz') predictor = Predictor.from_archive(archive, 'coreference-resolution') result = predictor.predict_json(inputs) document = result["document"] assert document == [ 'This', 'is', 'a', 'single', 'string', 'document', 'about', 'a', 'test', '.', 'Sometimes', 'it', 'contains', 'coreferent', 'parts', '.' ] clusters = result["clusters"] assert isinstance(clusters, list) for cluster in clusters: assert isinstance(cluster, list) for mention in cluster: # Spans should be integer indices. assert isinstance(mention[0], int) assert isinstance(mention[1], int) # Spans should be inside document. assert 0 < mention[0] <= len(document) assert 0 < mention[1] <= len(document)
def test_prediction_with_no_verbs(self):
    input1 = {"sentence": "Blah no verb sentence."}
    archive = load_archive(self.FIXTURES_ROOT / 'srl' / 'serialization' / 'model.tar.gz')
    predictor = Predictor.from_archive(archive, 'semantic-role-labeling')

    result = predictor.predict_json(input1)
    assert result == {'words': ['Blah', 'no', 'verb', 'sentence', '.'], 'verbs': []}

    input2 = {"sentence": "This sentence has a verb."}
    results = predictor.predict_batch_json([input1, input2])
    assert results[0] == {'words': ['Blah', 'no', 'verb', 'sentence', '.'], 'verbs': []}
    assert results[1] == {'words': ['This', 'sentence', 'has', 'a', 'verb', '.'],
                          'verbs': [{'verb': 'has',
                                     'description': 'This sentence has a verb .',
                                     'tags': ['O', 'O', 'O', 'O', 'O', 'O']}]}
def test_uses_named_inputs(self): inputs = { "sentence": "The squirrel wrote a unit test to make sure its nuts worked as designed." } archive = load_archive(self.FIXTURES_ROOT / 'srl' / 'serialization' / 'model.tar.gz') predictor = Predictor.from_archive(archive, 'semantic-role-labeling') result = predictor.predict_json(inputs) print(result) words = result.get("words") assert words == [ "The", "squirrel", "wrote", "a", "unit", "test", "to", "make", "sure", "its", "nuts", "worked", "as", "designed", "." ] num_words = len(words) verbs = result.get("verbs") assert verbs is not None assert isinstance(verbs, list) assert any(v["verb"] == "wrote" for v in verbs) assert any(v["verb"] == "make" for v in verbs) assert any(v["verb"] == "worked" for v in verbs) for verb in verbs: tags = verb.get("tags") assert tags is not None assert isinstance(tags, list) assert all(isinstance(tag, str) for tag in tags) assert len(tags) == num_words
def test_extra_files(self):
    serialization_dir = self.TEST_DIR / 'serialization'

    # Train a model
    train_model(self.params, serialization_dir=serialization_dir)

    # Archive model, and also archive the training data
    files_to_archive = {"train_data_path": str(self.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv')}
    archive_model(serialization_dir=serialization_dir, files_to_archive=files_to_archive)

    archive = load_archive(serialization_dir / 'model.tar.gz')
    params = archive.config

    # The train_data_path param should have been replaced with a temporary path
    # (which we don't know exactly, but we know what it ends with).
    assert params.get('train_data_path').endswith('/fta/train_data_path')

    # The validation data path should be the same though.
    assert params.get('validation_data_path') == str(self.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv')
def setUp(self):
    super().setUp()
    archive = load_archive(self.FIXTURES_ROOT / 'bidaf' / 'serialization' / 'model.tar.gz')
    self.bidaf_predictor = Predictor.from_archive(archive, 'machine-comprehension')
def test_archiving(self):
    # copy params, since they'll get consumed during training
    params_copy = copy.deepcopy(self.params.as_dict())

    # `train_model` should create an archive
    serialization_dir = self.TEST_DIR / 'archive_test'
    model = train_model(self.params, serialization_dir=serialization_dir)

    archive_path = serialization_dir / "model.tar.gz"

    # load from the archive
    archive = load_archive(archive_path)
    model2 = archive.model

    # check that model weights are the same
    keys = set(model.state_dict().keys())
    keys2 = set(model2.state_dict().keys())
    assert keys == keys2

    for key in keys:
        assert torch.equal(model.state_dict()[key], model2.state_dict()[key])

    # check that vocabularies are the same
    vocab = model.vocab
    vocab2 = model2.vocab
    assert vocab._token_to_index == vocab2._token_to_index  # pylint: disable=protected-access
    assert vocab._index_to_token == vocab2._index_to_token  # pylint: disable=protected-access

    # check that params are the same
    params2 = archive.config
    assert params2.as_dict() == params_copy
def test_from_archive_does_not_consume_params(self):
    archive = load_archive(self.FIXTURES_ROOT / 'bidaf' / 'serialization' / 'model.tar.gz')
    Predictor.from_archive(archive, 'machine-comprehension')

    # If it consumes the params, this second call will raise an exception
    Predictor.from_archive(archive, 'machine-comprehension')
def test_batch_prediction(self): inputs = [{ "sentence": "What kind of test succeeded on its first attempt?", }, { "sentence": "What kind of test succeeded on its first attempt at batch processing?", }] archive = load_archive(self.FIXTURES_ROOT / 'biaffine_dependency_parser' / 'serialization' / 'model.tar.gz') predictor = Predictor.from_archive(archive, 'biaffine-dependency-parser') results = predictor.predict_batch_json(inputs) assert len(results) == 2 for result in results: sequence_length = len(result.get("words")) predicted_heads = result.get("predicted_heads") assert len(predicted_heads) == sequence_length predicted_dependencies = result.get("predicted_dependencies") assert len(predicted_dependencies) == sequence_length assert isinstance(predicted_dependencies, list) assert all(isinstance(x, str) for x in predicted_dependencies)
def test_uses_named_inputs(self): inputs = { "premise": "I always write unit tests for my code.", "hypothesis": "One time I didn't write any unit tests for my code." } archive = load_archive(self.FIXTURES_ROOT / 'decomposable_attention' / 'serialization' / 'model.tar.gz') predictor = Predictor.from_archive(archive, 'textual-entailment') result = predictor.predict_json(inputs) # Label probs should be 3 floats that sum to one label_probs = result.get("label_probs") assert label_probs is not None assert isinstance(label_probs, list) assert len(label_probs) == 3 assert all(isinstance(x, float) for x in label_probs) assert all(x >= 0 for x in label_probs) assert sum(label_probs) == approx(1.0) # Logits should be 3 floats that softmax to label_probs label_logits = result.get("label_logits") assert label_logits is not None assert isinstance(label_logits, list) assert len(label_logits) == 3 assert all(isinstance(x, float) for x in label_logits) exps = [math.exp(x) for x in label_logits] sumexps = sum(exps) for e, p in zip(exps, label_probs): assert e / sumexps == approx(p)
def _get_predictor(args: argparse.Namespace) -> Predictor:
    check_for_gpu(args.cuda_device)
    archive = load_archive(args.archive_file,
                           weights_file=args.weights_file,
                           cuda_device=args.cuda_device,
                           overrides=args.overrides)

    return Predictor.from_archive(archive, args.predictor)
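# Usage sketch (not in the original source; all values below are illustrative
# assumptions). `_get_predictor` appears to back the `allennlp predict`
# subcommand, but it can also be exercised directly with a hand-built Namespace:
import argparse

args = argparse.Namespace(archive_file='model.tar.gz',
                          weights_file=None,
                          cuda_device=-1,
                          overrides='',
                          predictor='machine-comprehension')
predictor = _get_predictor(args)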
def __init__(self,
             vocab: Vocabulary,
             sentence_embedder: TextFieldEmbedder,
             action_embedding_dim: int,
             encoder: Seq2SeqEncoder,
             attention: Attention,
             beam_size: int,
             max_decoding_steps: int,
             max_num_finished_states: int = None,
             dropout: float = 0.0,
             normalize_beam_score_by_length: bool = False,
             checklist_cost_weight: float = 0.6,
             dynamic_cost_weight: Dict[str, Union[int, float]] = None,
             penalize_non_agenda_actions: bool = False,
             initial_mml_model_file: str = None) -> None:
    super(NlvrCoverageSemanticParser, self).__init__(vocab=vocab,
                                                     sentence_embedder=sentence_embedder,
                                                     action_embedding_dim=action_embedding_dim,
                                                     encoder=encoder,
                                                     dropout=dropout)
    self._agenda_coverage = Average()
    self._decoder_trainer: DecoderTrainer[Callable[[NlvrDecoderState], torch.Tensor]] = \
        ExpectedRiskMinimization(beam_size=beam_size,
                                 normalize_by_length=normalize_beam_score_by_length,
                                 max_decoding_steps=max_decoding_steps,
                                 max_num_finished_states=max_num_finished_states)

    # Instantiating an empty NlvrWorld just to get the number of terminals.
    self._terminal_productions = set(NlvrWorld([]).terminal_productions.values())
    self._decoder_step = NlvrDecoderStep(encoder_output_dim=self._encoder.get_output_dim(),
                                         action_embedding_dim=action_embedding_dim,
                                         input_attention=attention,
                                         dropout=dropout,
                                         use_coverage=True)
    self._checklist_cost_weight = checklist_cost_weight
    self._dynamic_cost_wait_epochs = None
    self._dynamic_cost_rate = None
    if dynamic_cost_weight:
        self._dynamic_cost_wait_epochs = dynamic_cost_weight["wait_num_epochs"]
        self._dynamic_cost_rate = dynamic_cost_weight["rate"]
    self._penalize_non_agenda_actions = penalize_non_agenda_actions
    self._last_epoch_in_forward: int = None
    # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've
    # copied a trained ERM model from a different machine and the original MML model that was
    # used to initialize it does not exist on the current machine. This may not be the best
    # solution for the problem.
    if initial_mml_model_file is not None:
        if os.path.isfile(initial_mml_model_file):
            archive = load_archive(initial_mml_model_file)
            self._initialize_weights_from_archive(archive)
        else:
            # A model file is passed, but it does not exist. This is expected to happen when
            # you're using a trained ERM model to decode. But it may also happen if the path to
            # the file is really just incorrect. So throwing a warning.
            logger.warning("MML model file for initializing weights is passed, but does not exist."
                           " This is fine if you're just decoding.")
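# Config sketch (an assumption based only on the keys read above, not from the
# original source): `dynamic_cost_weight` arrives as a plain dict, so a training
# config would pass something like
#
#     "dynamic_cost_weight": {"wait_num_epochs": 5, "rate": 0.1}
#
# The attribute names suggest the parser waits `wait_num_epochs` epochs and then
# adjusts the checklist cost weight by `rate` per epoch; the exact schedule
# lives in `forward`, not shown here.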
def main(args):
    # Running this file starts the simple service with the archive and predictor
    # specified on the command line. There's no good reason you'd want to do this
    # in production, except possibly to test changes to the stock HTML.
    parser = argparse.ArgumentParser(description='Serve up a simple model')

    parser.add_argument('--archive-path', type=str, required=True, help='path to trained archive file')
    parser.add_argument('--predictor', type=str, required=True, help='name of predictor')
    parser.add_argument('--static-dir', type=str, help='serve index.html from this directory')
    parser.add_argument('--title', type=str, help='change the default page title', default="AllenNLP Demo")
    parser.add_argument('--field-name', type=str, action='append',
                        help='field names to include in the demo')
    parser.add_argument('--port', type=int, default=8000, help='port to serve the demo on')
    parser.add_argument('--include-package', type=str, action='append', default=[],
                        help='additional packages to include')

    args = parser.parse_args(args)

    # Load modules
    for package_name in args.include_package:
        import_submodules(package_name)

    archive = load_archive(args.archive_path)
    predictor = Predictor.from_archive(archive, args.predictor)
    field_names = args.field_name

    app = make_app(predictor=predictor,
                   field_names=field_names,
                   static_dir=args.static_dir,
                   title=args.title)
    CORS(app)

    http_server = WSGIServer(('0.0.0.0', args.port), app)
    print(f"Model loaded, serving demo on port {args.port}")
    http_server.serve_forever()
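# Invocation sketch (the script name, archive path, and field names are
# illustrative assumptions; the flags match the parser defined above):
#
#     python server_simple.py \
#         --archive-path model.tar.gz \
#         --predictor machine-comprehension \
#         --field-name question --field-name passage \
#         --port 8000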
def test_batch_prediction(self): inputs = { "sentence": "The squirrel wrote a unit test to make sure its nuts worked as designed." } archive = load_archive(self.FIXTURES_ROOT / 'srl' / 'serialization' / 'model.tar.gz') predictor = Predictor.from_archive(archive, 'semantic-role-labeling') result = predictor.predict_batch_json([inputs, inputs]) assert result[0] == result[1]
def test_predictor_with_direct_parser(self):
    archive_dir = self.FIXTURES_ROOT / 'semantic_parsing' / 'nlvr_direct_semantic_parser' / 'serialization'
    archive = load_archive(os.path.join(archive_dir, 'model.tar.gz'))
    predictor = Predictor.from_archive(archive, 'nlvr-parser')

    result = predictor.predict_json(self.inputs)
    assert 'logical_form' in result
    assert 'denotations' in result
    # result['denotations'] is a list corresponding to k-best logical forms, where k is 1 by
    # default.
    assert len(result['denotations'][0]) == 2  # Because there are two worlds in the input.
def from_params(cls, vocab: Vocabulary, params: Params) -> 'BidafEnsemble':  # type: ignore
    # pylint: disable=arguments-differ
    if vocab:
        raise ConfigurationError("vocab should be None")

    submodels = []
    paths = params.pop("submodels")
    for path in paths:
        submodels.append(load_archive(path).model)

    return cls(submodels=submodels)
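# Config sketch (an assumption; the only parameter read above is "submodels",
# a list of archive paths, and "bidaf-ensemble" is assumed to be this class's
# registered name):
#
#     "model": {
#         "type": "bidaf-ensemble",
#         "submodels": ["run1/model.tar.gz", "run2/model.tar.gz"]
#     }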
def test_answer_present_with_batch_predict(self):
    inputs = [{"question": "Who is 18 years old?",
               "table": "Name\tAge\nShallan\t16\nKaladin\t18"}]
    archive_path = self.FIXTURES_ROOT / 'semantic_parsing' / 'wikitables' / 'serialization' / 'model.tar.gz'
    archive = load_archive(archive_path)
    predictor = Predictor.from_archive(archive, 'wikitables-parser')

    result = predictor.predict_batch_json(inputs)
    answer = result[0].get("answer")
    assert answer is not None
def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:
    # Disable some of the more verbose logging statements
    logging.getLogger('allennlp.common.params').disabled = True
    logging.getLogger('allennlp.nn.initializers').disabled = True
    logging.getLogger('allennlp.modules.token_embedders.embedding').setLevel(logging.INFO)

    # Load from archive
    archive = load_archive(args.archive_file, args.cuda_device, args.overrides, args.weights_file)
    config = archive.config
    prepare_environment(config)
    model = archive.model
    model.eval()

    # Load the evaluation data.
    # Try to use the validation dataset reader if there is one - otherwise fall back
    # to the default dataset_reader used for both training and validation.
    validation_dataset_reader_params = config.pop('validation_dataset_reader', None)
    if validation_dataset_reader_params is not None:
        dataset_reader = DatasetReader.from_params(validation_dataset_reader_params)
    else:
        dataset_reader = DatasetReader.from_params(config.pop('dataset_reader'))
    evaluation_data_path = args.input_file
    logger.info("Reading evaluation data from %s", evaluation_data_path)
    instances = dataset_reader.read(evaluation_data_path)

    iterator_params = config.pop("validation_iterator", None)
    if iterator_params is None:
        iterator_params = config.pop("iterator")
    iterator = DataIterator.from_params(iterator_params)
    iterator.index_with(model.vocab)

    metrics = evaluate(model, instances, iterator, args.cuda_device)

    logger.info("Finished evaluating.")
    logger.info("Metrics:")
    for key, metric in metrics.items():
        logger.info("%s: %s", key, metric)

    output_file = args.output_file
    if output_file:
        with open(output_file, "w") as file:
            json.dump(metrics, file, indent=4)
    return metrics
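# Invocation sketch (an assumption: this function appears to back the
# `allennlp evaluate` subcommand; paths below are illustrative):
#
#     allennlp evaluate model.tar.gz dev_data.jsonl \
#         --output-file metrics.json \
#         --cuda-device 0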
def from_path(cls, archive_path: str, predictor_name: str = None) -> 'Predictor':
    """
    Instantiate a :class:`Predictor` from an archive path.

    If you need more detailed configuration options, such as running the
    predictor on the GPU, please use `from_archive`.

    Parameters
    ----------
    archive_path
        The path to the archive.
    predictor_name
        The name of the predictor to use. If not given, the default predictor
        for the archived model type is used.

    Returns
    -------
    A Predictor instance.
    """
    return Predictor.from_archive(load_archive(archive_path), predictor_name)
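# Usage sketch (not part of the original source): the one-call construction
# path. The archive path, inputs, and the import location (allennlp 0.x) are
# illustrative assumptions.
from allennlp.predictors import Predictor

predictor = Predictor.from_path('model.tar.gz', 'machine-comprehension')
result = predictor.predict_json({
        "question": "What kind of test succeeded on its first attempt?",
        "passage": "One time I was writing a unit test, and it succeeded on the first attempt."
})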
def test_uses_named_inputs(self): inputs = { "source": "What kind of test succeeded on its first attempt?", } archive = load_archive(self.FIXTURES_ROOT / 'encoder_decoder' / 'simple_seq2seq' / 'serialization' / 'model.tar.gz') predictor = Predictor.from_archive(archive, 'simple_seq2seq') result = predictor.predict_json(inputs) predicted_tokens = result.get("predicted_tokens") assert predicted_tokens is not None assert isinstance(predicted_tokens, list) assert all(isinstance(x, str) for x in predicted_tokens)
def test_get_vocab_index_mapping(self):
    # pylint: disable=line-too-long
    mml_model_archive_file = (self.FIXTURES_ROOT / "semantic_parsing" /
                              "nlvr_direct_semantic_parser" / "serialization" / "model.tar.gz")
    archive = load_archive(mml_model_archive_file)
    mapping = self.model._get_vocab_index_mapping(archive.model.vocab)
    expected_mapping = [(i, i) for i in range(16)]
    assert mapping == expected_mapping

    new_vocab = Vocabulary()

    def copy_token_at_index(i):
        token = self.vocab.get_token_from_index(i, "tokens")
        new_vocab.add_token_to_namespace(token, "tokens")

    copy_token_at_index(5)
    copy_token_at_index(7)
    copy_token_at_index(10)

    mapping = self.model._get_vocab_index_mapping(new_vocab)
    # Mapping of indices from model vocabulary to new vocabulary. 0 and 1 are padding and unk
    # tokens.
    assert mapping == [(0, 0), (1, 1), (5, 2), (7, 3), (10, 4)]
def test_predictor_uses_dataset_reader_to_determine_pos_set(self):
    # pylint: disable=protected-access
    archive = load_archive(self.FIXTURES_ROOT / 'biaffine_dependency_parser' /
                           'serialization' / 'model.tar.gz')
    predictor = Predictor.from_archive(archive, 'biaffine-dependency-parser')

    inputs = {"sentence": "Dogs eat cats."}

    instance_with_ud_pos = predictor._json_to_instance(inputs)
    tags = instance_with_ud_pos.fields["pos_tags"].labels
    assert tags == ['NOUN', 'VERB', 'NOUN', 'PUNCT']

    predictor._dataset_reader.use_language_specific_pos = True
    instance_with_ptb_pos = predictor._json_to_instance(inputs)
    tags = instance_with_ptb_pos.fields["pos_tags"].labels
    assert tags == ['NNS', 'VBP', 'NNS', '.']
def test_uses_named_inputs(self): inputs = { "sentence": "What a great test sentence.", } archive = load_archive(self.FIXTURES_ROOT / 'constituency_parser' / 'serialization' / 'model.tar.gz') predictor = Predictor.from_archive(archive, 'constituency-parser') result = predictor.predict_json(inputs) assert len(result["spans"] ) == 21 # number of possible substrings of the sentence. assert len(result["class_probabilities"]) == 21 assert result["tokens"] == [ "What", "a", "great", "test", "sentence", "." ] assert isinstance(result["trees"], str) for class_distribution in result["class_probabilities"]: self.assertAlmostEqual(sum(class_distribution), 1.0, places=4)
def test_initialize_weights_from_archive(self):
    original_model_parameters = self.model.named_parameters()
    original_model_weights = {name: parameter.data.clone().numpy()
                              for name, parameter in original_model_parameters}
    # pylint: disable=line-too-long
    mml_model_archive_file = (self.FIXTURES_ROOT / "semantic_parsing" /
                              "nlvr_direct_semantic_parser" / "serialization" / "model.tar.gz")
    archive = load_archive(mml_model_archive_file)
    archived_model_parameters = archive.model.named_parameters()
    self.model._initialize_weights_from_archive(archive)
    changed_model_parameters = dict(self.model.named_parameters())
    for name, archived_parameter in archived_model_parameters:
        archived_weight = archived_parameter.data.numpy()
        original_weight = original_model_weights[name]
        changed_weight = changed_model_parameters[name].data.numpy()
        # We want to make sure that the weights in the original model have indeed been changed
        # after a call to ``_initialize_weights_from_archive``.
        with self.assertRaises(AssertionError, msg=f"{name} has not changed"):
            assert_almost_equal(original_weight, changed_weight)
        # This also includes the sentence token embedder. Those weights will be the same
        # because the two models have the same vocabulary.
        assert_almost_equal(archived_weight, changed_weight)
def test_batch_prediction(self):
    batch_inputs = [
            {
                    "premise": "I always write unit tests for my code.",
                    "hypothesis": "One time I didn't write any unit tests for my code."
            },
            {
                    "premise": "I also write batched unit tests for throughput!",
                    "hypothesis": "Batch tests are slower."
            },
    ]

    archive = load_archive(self.FIXTURES_ROOT / 'decomposable_attention' /
                           'serialization' / 'model.tar.gz')
    predictor = Predictor.from_archive(archive, 'textual-entailment')
    results = predictor.predict_batch_json(batch_inputs)
    assert len(results) == 2

    for result in results:
        # Logits should be 3 floats that softmax to label_probs
        label_logits = result.get("label_logits")
        # Label probs should be 3 floats that sum to one
        label_probs = result.get("label_probs")

        assert label_probs is not None
        assert isinstance(label_probs, list)
        assert len(label_probs) == 3
        assert all(isinstance(x, float) for x in label_probs)
        assert all(x >= 0 for x in label_probs)
        assert sum(label_probs) == approx(1.0)

        assert label_logits is not None
        assert isinstance(label_logits, list)
        assert len(label_logits) == 3
        assert all(isinstance(x, float) for x in label_logits)

        exps = [math.exp(x) for x in label_logits]
        sumexps = sum(exps)
        for e, p in zip(exps, label_probs):
            assert e / sumexps == approx(p)
def test_batch_prediction(self): inputs = [{ "question": "What kind of test succeeded on its first attempt?", "passage": "One time I was writing a unit test, and it succeeded on the first attempt." }, { "question": "What kind of test succeeded on its first attempt at batch processing?", "passage": "One time I was writing a unit test, and it always failed!" }] archive = load_archive(self.FIXTURES_ROOT / 'bidaf' / 'serialization' / 'model.tar.gz') predictor = Predictor.from_archive(archive, 'machine-comprehension') results = predictor.predict_batch_json(inputs) assert len(results) == 2 for result in results: best_span = result.get("best_span") best_span_str = result.get("best_span_str") start_probs = result.get("span_start_probs") end_probs = result.get("span_end_probs") assert best_span is not None assert isinstance(best_span, list) assert len(best_span) == 2 assert all(isinstance(x, int) for x in best_span) assert best_span[0] <= best_span[1] assert isinstance(best_span_str, str) assert best_span_str != "" for probs in (start_probs, end_probs): assert probs is not None assert all(isinstance(x, float) for x in probs) assert sum(probs) == approx(1.0)
def test_uses_named_inputs(self): inputs = { "question": "names", "table": "name\tdate\nmatt\t2017\npradeep\t2018" } archive_path = self.FIXTURES_ROOT / 'semantic_parsing' / 'wikitables' / 'serialization' / 'model.tar.gz' archive = load_archive(archive_path) predictor = Predictor.from_archive(archive, 'wikitables-parser') result = predictor.predict_json(inputs) action_sequence = result.get("best_action_sequence") if action_sequence: # We don't currently disallow endless loops in the decoder, and an untrained seq2seq # model will easily get itself into a loop. An endless loop isn't a finished logical # form, so decoding doesn't return any finished states, which means no actions. So, # sadly, we don't have a great test here. This is just testing that the predictor # runs, basically. assert len(action_sequence) > 1 assert all([isinstance(action, str) for action in action_sequence]) logical_form = result.get("logical_form") assert logical_form is not None
def test_batch_prediction(self): inputs = [{ "sentence": "What a great test sentence." }, { "sentence": "Here's another good, interesting one." }] archive = load_archive(self.FIXTURES_ROOT / 'constituency_parser' / 'serialization' / 'model.tar.gz') predictor = Predictor.from_archive(archive, 'constituency-parser') results = predictor.predict_batch_json(inputs) result = results[0] assert len(result["spans"] ) == 21 # number of possible substrings of the sentence. assert len(result["class_probabilities"]) == 21 assert result["tokens"] == [ "What", "a", "great", "test", "sentence", "." ] assert isinstance(result["trees"], str) for class_distribution in result["class_probabilities"]: self.assertAlmostEqual(sum(class_distribution), 1.0, places=4) result = results[1] assert len(result["spans"] ) == 36 # number of possible substrings of the sentence. assert len(result["class_probabilities"]) == 36 assert result["tokens"] == [ "Here", "'s", "another", "good", ",", "interesting", "one", "." ] assert isinstance(result["trees"], str) for class_distribution in result["class_probabilities"]: self.assertAlmostEqual(sum(class_distribution), 1.0, places=4)
def __init__(self,
             vocab: Vocabulary,
             question_embedder: TextFieldEmbedder,
             action_embedding_dim: int,
             encoder: Seq2SeqEncoder,
             entity_encoder: Seq2VecEncoder,
             attention: Attention,
             decoder_beam_size: int,
             decoder_num_finished_states: int,
             max_decoding_steps: int,
             mixture_feedforward: FeedForward = None,
             normalize_beam_score_by_length: bool = False,
             checklist_cost_weight: float = 0.6,
             use_neighbor_similarity_for_linking: bool = False,
             dropout: float = 0.0,
             num_linking_features: int = 10,
             rule_namespace: str = 'rule_labels',
             tables_directory: str = '/wikitables/',
             mml_model_file: str = None) -> None:
    use_similarity = use_neighbor_similarity_for_linking
    super().__init__(vocab=vocab,
                     question_embedder=question_embedder,
                     action_embedding_dim=action_embedding_dim,
                     encoder=encoder,
                     entity_encoder=entity_encoder,
                     max_decoding_steps=max_decoding_steps,
                     use_neighbor_similarity_for_linking=use_similarity,
                     dropout=dropout,
                     num_linking_features=num_linking_features,
                     rule_namespace=rule_namespace,
                     tables_directory=tables_directory)
    # Not sure why mypy needs a type annotation for this!
    self._decoder_trainer: ExpectedRiskMinimization = \
        ExpectedRiskMinimization(beam_size=decoder_beam_size,
                                 normalize_by_length=normalize_beam_score_by_length,
                                 max_decoding_steps=self._max_decoding_steps,
                                 max_num_finished_states=decoder_num_finished_states)
    unlinked_terminals_global_indices = []
    global_vocab = self.vocab.get_token_to_index_vocabulary(rule_namespace)
    for production, index in global_vocab.items():
        right_side = production.split(" -> ")[1]
        if right_side in types.COMMON_NAME_MAPPING:
            # This is a terminal production.
            unlinked_terminals_global_indices.append(index)
    self._num_unlinked_terminals = len(unlinked_terminals_global_indices)
    self._decoder_step = WikiTablesDecoderStep(encoder_output_dim=self._encoder.get_output_dim(),
                                               action_embedding_dim=action_embedding_dim,
                                               input_attention=attention,
                                               num_start_types=self._num_start_types,
                                               num_entity_types=self._num_entity_types,
                                               mixture_feedforward=mixture_feedforward,
                                               dropout=dropout,
                                               unlinked_terminal_indices=unlinked_terminals_global_indices)
    self._checklist_cost_weight = checklist_cost_weight
    self._agenda_coverage = Average()
    # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've
    # copied a trained ERM model from a different machine and the original MML model that was
    # used to initialize it does not exist on the current machine. This may not be the best
    # solution for the problem.
    if mml_model_file is not None:
        if os.path.isfile(mml_model_file):
            archive = load_archive(mml_model_file)
            self._initialize_weights_from_archive(archive)
        else:
            # A model file is passed, but it does not exist. This is expected to happen when
            # you're using a trained ERM model to decode. But it may also happen if the path to
            # the file is really just incorrect. So throwing a warning.
            logger.warning("MML model file for initializing weights is passed, but does not exist."
                           " This is fine if you're just decoding.")
def demo_model(archive_file: str, predictor_name: str) -> Predictor:
    archive = load_archive(archive_file)
    return Predictor.from_archive(archive, predictor_name)
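# Usage sketch (not part of the original source; the archive path and inputs
# are illustrative assumptions):
predictor = demo_model('model.tar.gz', 'machine-comprehension')
result = predictor.predict_json({"question": "Who wrote the test?",
                                 "passage": "The squirrel wrote the test."})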