def setUp(self):
    """Assemble an untrained attentional translator and read the ja test sources.

    Every embedding, hidden and attention size is the same ``dim``. Note that
    the embedders and the output softmax are sized with a literal
    ``vocab_size=100``, not with the sizes of the vocab files loaded here.
    Populates ``self.model`` and ``self.src_data``.
    """
    dim = 512
    events.clear()
    ParamManager.init_param_col()

    ja_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
    en_vocab = Vocab(vocab_file="examples/data/head.en.vocab")

    # Components are created in exactly the order the original single nested
    # constructor expression evaluated its arguments. Keeping that order keeps
    # parameter creation order unchanged.
    ja_reader = PlainTextReader(vocab=ja_vocab)
    en_reader = PlainTextReader(vocab=en_vocab)
    source_embedder = SimpleWordEmbedder(emb_dim=dim, vocab_size=100)
    bilstm_encoder = BiLSTMSeqTransducer(input_dim=dim, hidden_dim=dim)
    mlp_attender = MlpAttender(input_dim=dim, state_dim=dim, hidden_dim=dim)

    target_embedder = SimpleWordEmbedder(emb_dim=dim, vocab_size=100)
    decoder_rnn = UniLSTMSeqTransducer(input_dim=dim,
                                       hidden_dim=dim,
                                       decoder_input_dim=dim,
                                       yaml_path="model.decoder.rnn")
    # The transform input is twice the layer size: dim * 2.
    output_transform = NonLinear(input_dim=dim * 2, output_dim=dim)
    output_scorer = Softmax(input_dim=dim, vocab_size=100)
    state_bridge = CopyBridge(dec_dim=dim, dec_layers=1)
    rnn_decoder = AutoRegressiveDecoder(input_dim=dim,
                                        embedder=target_embedder,
                                        rnn=decoder_rnn,
                                        transform=output_transform,
                                        scorer=output_scorer,
                                        bridge=state_bridge)

    self.model = DefaultTranslator(src_reader=ja_reader,
                                   trg_reader=en_reader,
                                   src_embedder=source_embedder,
                                   encoder=bilstm_encoder,
                                   attender=mlp_attender,
                                   decoder=rnn_decoder)
    event_trigger.set_train(False)
    source_sents = self.model.src_reader.read_sents("examples/data/head.ja")
    self.src_data = list(source_sents)
def setUp(self):
    """Restore the serialized tiny ja-en model and read the ja/en test corpora.

    The saved experiment's ``train`` and ``status`` entries are overwritten
    with ``None`` before preloading. Presumably this keeps the training
    regimen and status from being rebuilt; confirm against LoadSerialized.
    Populates ``self.model``, ``self.src_data`` and ``self.trg_data``.
    """
    events.clear()
    ParamManager.init_param_col()

    # Load a pre-trained model (plain string: the old f-prefix had no fields).
    load_experiment = LoadSerialized(
        filename="test/data/tiny_jaen.model",
        overwrite=[
            {"path": "train", "val": None},
            {"path": "status", "val": None},
        ],
    )
    exp_dir = '.'
    exp_name = "decode"
    uninitialized_experiment = YamlPreloader.preload_obj(load_experiment,
                                                         exp_dir=exp_dir,
                                                         exp_name=exp_name)
    loaded_experiment = initialize_if_needed(uninitialized_experiment)
    ParamManager.populate()

    # Only the model is needed; readers (and their vocabs) are reached
    # through it below, so no separate vocab locals are kept.
    self.model = loaded_experiment.model
    event_trigger.set_train(False)
    self.src_data = list(
        self.model.src_reader.read_sents("test/data/head.ja"))
    self.trg_data = list(
        self.model.trg_reader.read_sents("test/data/head.en"))
def setUp(self):
    """Prepare one tokenized hypothesis/reference pair plus their word-id forms.

    ``self.hyp`` and ``self.ref`` each hold a single token list; the ``*_id``
    attributes are those tokens converted through a 4-word vocabulary.
    """
    events.clear()
    hyp_tokens = "the taro met the hanako".split()
    ref_tokens = "taro met hanako".split()
    self.hyp = [hyp_tokens]
    self.ref = [ref_tokens]
    word_vocab = Vocab(i2w=["the", "taro", "met", "hanako"])
    self.hyp_id = [word_vocab.convert(tok) for tok in hyp_tokens]
    self.ref_id = [word_vocab.convert(tok) for tok in ref_tokens]
def setUp(self):
    """Create ja (source) and en (target) plain-text readers and read both corpora.

    Populates ``self.src_reader``, ``self.trg_reader``, ``self.src_data`` and
    ``self.trg_data``.
    """
    events.clear()
    ParamManager.init_param_col()
    vocab_ja = Vocab(vocab_file="test/data/head.ja.vocab")
    vocab_en = Vocab(vocab_file="test/data/head.en.vocab")
    # Tuple elements are evaluated left to right: source reader first.
    self.src_reader, self.trg_reader = (PlainTextReader(vocab=vocab_ja),
                                        PlainTextReader(vocab=vocab_en))
    ja_sents = self.src_reader.read_sents("test/data/head.ja")
    self.src_data = list(ja_sents)
    en_sents = self.trg_reader.read_sents("test/data/head.en")
    self.trg_data = list(en_sents)
def setUp(self):
    """Reset serialization state and point the parameter collection at a scratch file.

    Registers ``xnmt.init_representer`` as the YAML representer for both dummy
    argument classes. Sets ``self.out_dir`` to ``test/tmp`` and
    ``self.model_file`` to ``test/tmp/saved.mod``.
    """
    events.clear()
    xnmt.resolved_serialize_params = {}
    for dummy_cls in (DummyArgClass, DummyArgClass2):
        yaml.add_representer(dummy_cls, xnmt.init_representer)

    scratch_dir = os.path.join("test", "tmp")
    self.out_dir = scratch_dir
    # A dummy file name inside scratch_dir is passed, presumably so that
    # make_parent_dir creates scratch_dir itself; confirm against utils.
    utils.make_parent_dir(os.path.join(scratch_dir, "asdf"))
    self.model_file = os.path.join(scratch_dir, "saved.mod")

    manager = param_collections.ParamManager
    manager.init_param_col()
    manager.param_col.model_file = self.model_file
def setUp(self):
    """Clear ``events`` before each test; no other fixture state is created."""
    events.clear()
def setUp(self):
    """Build a ja plain-text reader, run it once over the head corpus, then reset params.

    Populates ``self.input_reader``.
    """
    events.clear()
    ja_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
    self.input_reader = PlainTextReader(vocab=ja_vocab)
    # The sentences are read in full and then discarded. This is presumably
    # done for a side effect of reading; NOTE(review): confirm it is still needed.
    _ = list(self.input_reader.read_sents('examples/data/head.ja'))
    ParamManager.init_param_col()
def setUp(self):
    """Clear ``events``, then (re)initialize ParamManager's parameter collection."""
    events.clear()
    ParamManager.init_param_col()