def test_beam(self):
    """Smoke test: train an attention NMT model on one sentence pair,
    serialize it, reload it, and run beam-search decoding."""
    source_vocab = Vocabulary()
    target_vocab = Vocabulary()
    for tok in "</s> I am Philip You are a".split():
        source_vocab[tok]
    for tok in "</s> 私 は フィリップ です 1 2 3".split():
        target_vocab[tok]
    nmt = EncDecNMT(Args("attn"), source_vocab, target_vocab, optimizer=optimizers.SGD())
    model_out = "/tmp/model-nmt.temp"
    X, Y = source_vocab, target_vocab
    # Train on a single sentence pair.
    src = np.array([[X["I"], X["am"], X["Philip"]]], dtype=np.int32)
    trg = np.array([[Y["私"], Y["は"], Y["フィリップ"], Y["です"]]], dtype=np.int32)
    nmt.train(src, trg)
    # Serialize the trained model to disk.
    ModelSerializer(model_out).save(nmt)
    # Reload it from the saved file.
    reloaded = EncDecNMT(InitArgs(model_out))
    # NOTE(review): beam decoding runs on the original model, not `reloaded`,
    # and neither the reloaded model nor `k` is asserted — confirm whether
    # this is intentional (pure smoke test) or the reload should be exercised.
    k = nmt.classify(src, beam=10)
def test_NMT_2_read_write(self):
    """Round-trip test for both architectures: train one step, save the
    model, load it back, and assert the reloaded weights equal the originals.

    Fix: the original reused the name ``model`` both for the architecture
    string being iterated ("encdec"/"attn") and for the EncDecNMT instance,
    shadowing the loop variable; the loop variable is now ``model_type``.
    """
    for model_type in ["encdec", "attn"]:
        src_voc = Vocabulary()
        trg_voc = Vocabulary()
        for tok in "</s> I am Philip".split():
            src_voc[tok]
        for tok in "</s> 私 は フィリップ です".split():
            trg_voc[tok]
        model = EncDecNMT(Args(model_type), src_voc, trg_voc, optimizer=optimizers.SGD())
        # NOTE(review): assumes /tmp/nmt exists or ModelSerializer creates it — confirm.
        model_out = "/tmp/nmt/tmp"
        X, Y = src_voc, trg_voc
        # Train with 1 example
        src = np.array([[X["I"], X["am"], X["Philip"]]], dtype=np.int32)
        trg = np.array([[Y["私"], Y["は"], Y["フィリップ"], Y["です"]]], dtype=np.int32)
        model.train(src, trg)
        # Save
        serializer = ModelSerializer(model_out)
        serializer.save(model)
        # Load
        model1 = EncDecNMT(InitArgs(model_out))
        # Check: reloaded parameters must match the trained ones.
        self.assertModelEqual(model._model, model1._model)
# Last CLI flag: a positive eos_disc discounts the end-of-sentence symbol,
# biasing decoding toward longer output sentences.
parser.add_argument("--eos_disc", type=float, default=0.0, help="Give fraction positive discount to output longer sentence.")
args = parser.parse_args()
""" Sanity Check """
# --use_cpu overrides any requested GPU id.
if args.use_cpu: args.gpu = -1
# Beam search (beam > 1) is only allowed with batch size 1 when decoding a source file.
if args.src and args.batch != 1 and args.beam > 1: raise ValueError("Batched decoding does not support beam search.")
""" Begin Testing """
# Output stream for alignments (presumably a file path or stdout — confirm in UF.load_stream).
ao_fp = UF.load_stream(args.align_out)
# Options forwarded to the decoder.
decoding_options = {"gen_limit": args.gen_limit, "eos_disc": args.eos_disc, "beam": args.beam}
# Loading model
UF.trace("Setting up classifier")
model = EncDecNMT(args, use_gpu=args.gpu, collect_output=True)
SRC, TRG = model.get_vocabularies()
# Testing callbacks
def print_result(ctr, trg, TRG, src, SRC, fp=sys.stderr):
    # Print each (source sentence, decoded result) pair; ctr is the running
    # sentence index offset for this batch, trg.y holds the decoded outputs.
    for i, (sent, result) in enumerate(zip(src, trg.y)):
        print(ctr + i, file=fp)
        print("SRC:", SRC.str_rpr(sent), file=fp)
        print("TRG:", TRG.str_rpr(result), file=fp)
    fp.flush()
def onDecodingStart():
    # Callback invoked once before decoding begins.
    UF.trace("Decoding started.")
def onBatchUpdate(ctr, src, trg): # Decoding