示例#1
0
    def setUp(self):
        """Construct a small DefaultTranslator and read the Japanese sample sentences.

        Clears ``events``, starts a new parameter collection, builds the model
        with every layer dimension set to 512 and all embedder/scorer
        vocabulary sizes set to 100, calls ``event_trigger.set_train(False)``,
        and stores the sentences of ``examples/data/head.ja`` in
        ``self.src_data``.
        """
        dim = 512
        events.clear()
        ParamManager.init_param_col()

        ja_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
        en_vocab = Vocab(vocab_file="examples/data/head.en.vocab")

        # Sub-components are created one by one in the same order in which the
        # keyword arguments of the nested constructor call would be evaluated,
        # so the order of parameter creation is unchanged.
        source_reader = PlainTextReader(vocab=ja_vocab)
        target_reader = PlainTextReader(vocab=en_vocab)
        source_embedder = SimpleWordEmbedder(emb_dim=dim, vocab_size=100)
        bilstm_encoder = BiLSTMSeqTransducer(input_dim=dim, hidden_dim=dim)
        mlp_attender = MlpAttender(input_dim=dim, state_dim=dim, hidden_dim=dim)
        target_embedder = SimpleWordEmbedder(emb_dim=dim, vocab_size=100)
        decoder_rnn = UniLSTMSeqTransducer(input_dim=dim,
                                           hidden_dim=dim,
                                           decoder_input_dim=dim,
                                           yaml_path="model.decoder.rnn")
        output_transform = NonLinear(input_dim=dim * 2, output_dim=dim)
        output_scorer = Softmax(input_dim=dim, vocab_size=100)
        state_bridge = CopyBridge(dec_dim=dim, dec_layers=1)
        autoregressive_decoder = AutoRegressiveDecoder(input_dim=dim,
                                                       embedder=target_embedder,
                                                       rnn=decoder_rnn,
                                                       transform=output_transform,
                                                       scorer=output_scorer,
                                                       bridge=state_bridge)
        self.model = DefaultTranslator(src_reader=source_reader,
                                       trg_reader=target_reader,
                                       src_embedder=source_embedder,
                                       encoder=bilstm_encoder,
                                       attender=mlp_attender,
                                       decoder=autoregressive_decoder)
        event_trigger.set_train(False)

        self.src_data = list(
            self.model.src_reader.read_sents("examples/data/head.ja"))
示例#2
0
    def setUp(self):
        """Load the serialized ``tiny_jaen`` model and the ja/en test sentences.

        Deserializes ``test/data/tiny_jaen.model`` with its ``train`` and
        ``status`` entries overwritten to ``None``, initializes the experiment,
        populates the parameter collection, calls
        ``event_trigger.set_train(False)``, and reads ``test/data/head.ja`` /
        ``test/data/head.en`` through the model's own readers into
        ``self.src_data`` / ``self.trg_data``.
        """
        events.clear()
        ParamManager.init_param_col()

        # Load a pre-trained model, overriding the "train" and "status" entries
        # with None (presumably so only the model part is used — confirm).
        load_experiment = LoadSerialized(filename="test/data/tiny_jaen.model",
                                         overwrite=[
                                             {"path": "train", "val": None},
                                             {"path": "status", "val": None},
                                         ])
        exp_dir = '.'
        exp_name = "decode"
        uninitialized_experiment = YamlPreloader.preload_obj(load_experiment,
                                                             exp_dir=exp_dir,
                                                             exp_name=exp_name)
        loaded_experiment = initialize_if_needed(uninitialized_experiment)
        ParamManager.populate()

        # Pull out the part we need from the experiment.
        self.model = loaded_experiment.model

        event_trigger.set_train(False)

        self.src_data = list(
            self.model.src_reader.read_sents("test/data/head.ja"))
        self.trg_data = list(
            self.model.trg_reader.read_sents("test/data/head.en"))
示例#3
0
    def setUp(self):
        """Prepare a toy hypothesis/reference pair as word lists and as vocab ids."""
        events.clear()
        hyp_words = "the taro met the hanako".split()
        ref_words = "taro met hanako".split()
        self.hyp = [hyp_words]
        self.ref = [ref_words]

        toy_vocab = Vocab(i2w=["the", "taro", "met", "hanako"])
        # Convert each word to its id, in sentence order.
        self.hyp_id = [toy_vocab.convert(word) for word in self.hyp[0]]
        self.ref_id = [toy_vocab.convert(word) for word in self.ref[0]]
示例#4
0
    def setUp(self):
        """Create plain-text readers for the ja/en test data and read both sides.

        Resets ``events``, starts a new parameter collection, and stores the
        readers in ``self.src_reader`` / ``self.trg_reader`` and the sentences
        of ``test/data/head.ja`` / ``test/data/head.en`` in ``self.src_data`` /
        ``self.trg_data``.
        """
        events.clear()
        ParamManager.init_param_col()

        ja_vocab, en_vocab = (Vocab(vocab_file="test/data/head.ja.vocab"),
                              Vocab(vocab_file="test/data/head.en.vocab"))
        self.src_reader, self.trg_reader = (PlainTextReader(vocab=ja_vocab),
                                            PlainTextReader(vocab=en_vocab))
        self.src_data = list(self.src_reader.read_sents("test/data/head.ja"))
        self.trg_data = list(self.trg_reader.read_sents("test/data/head.en"))
示例#5
0
 def setUp(self):
     """Reset serialization state and point the parameter collection at a temp file.

     Clears ``events`` and ``xnmt.resolved_serialize_params``, registers
     ``xnmt.init_representer`` as the YAML representer for the two dummy
     argument classes, ensures ``test/tmp`` exists via ``utils.make_parent_dir``,
     and sets ``test/tmp/saved.mod`` as the model file of a fresh parameter
     collection.
     """
     events.clear()
     xnmt.resolved_serialize_params = {}
     for dummy_cls in (DummyArgClass, DummyArgClass2):
         yaml.add_representer(dummy_cls, xnmt.init_representer)

     tmp_dir = os.path.join("test", "tmp")
     self.out_dir = tmp_dir
     # A placeholder file name inside tmp_dir is passed so that its parent,
     # tmp_dir itself, is the directory being prepared.
     utils.make_parent_dir(os.path.join(tmp_dir, "asdf"))
     self.model_file = os.path.join(tmp_dir, "saved.mod")

     param_collections.ParamManager.init_param_col()
     param_collections.ParamManager.param_col.model_file = self.model_file
示例#6
0
 def setUp(self):
   """Reset the global ``events`` state before each test."""
   events.clear()
示例#7
0
 def setUp(self):
   """Build a Japanese plain-text reader, run it once over the sample data, then reset params."""
   events.clear()
   ja_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
   self.input_reader = PlainTextReader(vocab=ja_vocab)
   # Drain the generator fully; the sentences themselves are not kept.
   # NOTE(review): only the reader's side effects of a full pass survive —
   # confirm this is intended rather than a forgotten assignment.
   for _sent in self.input_reader.read_sents('examples/data/head.ja'):
     pass
   ParamManager.init_param_col()
示例#8
0
 def setUp(self):
   """Reset the global ``events`` state and start a fresh parameter collection."""
   events.clear()
   ParamManager.init_param_col()