Example #1
0
    def test_beam(self):
        """Train, save, and reload an attention NMT model, then beam-decode.

        Exercises the full round trip: vocabulary setup -> one training
        step -> serialization -> deserialization -> beam-search decoding
        on the RELOADED model.
        """
        src_voc = Vocabulary()
        trg_voc = Vocabulary()
        # Indexing a Vocabulary registers the token (presumably assigns an id)
        # — TODO confirm against Vocabulary.__getitem__.
        for tok in "</s> I am Philip You are a".split():
            src_voc[tok]
        for tok in "</s> 私 は フィリップ です 1 2 3".split():
            trg_voc[tok]
        model = EncDecNMT(Args("attn"), src_voc, trg_voc, optimizer=optimizers.SGD())

        model_out = "/tmp/model-nmt.temp"
        X, Y = src_voc, trg_voc

        # Train with 1 example
        src = np.array([[X["I"], X["am"], X["Philip"]]], dtype=np.int32)
        trg = np.array([[Y["私"], Y["は"], Y["フィリップ"], Y["です"]]], dtype=np.int32)

        model.train(src, trg)

        # Save
        serializer = ModelSerializer(model_out)
        serializer.save(model)

        # Load
        model1 = EncDecNMT(InitArgs(model_out))
        # FIX: beam decoding must run on the reloaded model (model1), not the
        # original; otherwise the deserialized model is never exercised.
        model1.classify(src, beam=10)
Example #2
0
 def test_NMT_2_read_write(self):
     """Save and reload each NMT architecture and check model equality."""
     # FIX: the loop variable used to be named `model`, which was then
     # shadowed by the EncDecNMT object below — use a distinct name.
     for model_type in ["encdec", "attn"]:
         src_voc = Vocabulary()
         trg_voc = Vocabulary()
         # Indexing a Vocabulary registers the token — TODO confirm
         # against Vocabulary.__getitem__.
         for tok in "</s> I am Philip".split():
             src_voc[tok]
         for tok in "</s> 私 は フィリップ です".split():
             trg_voc[tok]
         model = EncDecNMT(Args(model_type), src_voc, trg_voc, optimizer=optimizers.SGD())

         model_out = "/tmp/nmt/tmp"
         X, Y = src_voc, trg_voc

         # Train with 1 example
         src = np.array([[X["I"], X["am"], X["Philip"]]], dtype=np.int32)
         trg = np.array([[Y["私"], Y["は"], Y["フィリップ"], Y["です"]]], dtype=np.int32)

         model.train(src, trg)

         # Save
         serializer = ModelSerializer(model_out)
         serializer.save(model)

         # Load
         model1 = EncDecNMT(InitArgs(model_out))

         # Check: the reloaded model must equal the one that was saved.
         self.assertModelEqual(model._model, model1._model)
Example #3
0
File: nmt.py  Project: philip30/chainn
# --eos_disc: positive values discourage early </s>, yielding longer output.
parser.add_argument("--eos_disc", type=float, default=0.0, help="Give fraction positive discount to output longer sentence.")
args  = parser.parse_args()

""" Sanity Check """
# --use_cpu overrides any GPU selection (gpu id -1 presumably means CPU
# for the underlying framework — TODO confirm).
if args.use_cpu:
    args.gpu = -1
# Beam search is not supported together with batched decoding.
if args.src and args.batch != 1 and args.beam > 1:
    raise ValueError("Batched decoding does not support beam search.")

""" Begin Testing """
# Stream for alignment output (UF.load_stream semantics not visible here —
# presumably opens args.align_out for writing; verify in chainn.util).
ao_fp = UF.load_stream(args.align_out)
# Options forwarded to the decoder for every batch.
decoding_options = {"gen_limit": args.gen_limit, "eos_disc": args.eos_disc, "beam": args.beam}

# Loading model
UF.trace("Setting up classifier")
model    = EncDecNMT(args, use_gpu=args.gpu, collect_output=True)
SRC, TRG = model.get_vocabularies()

# Testing callbacks
def print_result(ctr, trg, TRG, src, SRC, fp=sys.stderr):
    """Write numbered source/target sentence pairs to *fp*.

    ctr is the index of the first sentence in this batch; each pair is
    printed as its global index, the source rendered through SRC.str_rpr,
    and the decoded target rendered through TRG.str_rpr.
    """
    for offset, (source_sent, hypothesis) in enumerate(zip(src, trg.y)):
        print(ctr + offset, file=fp)
        print("SRC:", SRC.str_rpr(source_sent), file=fp)
        print("TRG:", TRG.str_rpr(hypothesis), file=fp)
    # Flush so progress is visible immediately on stderr.
    fp.flush()

def onDecodingStart():
    # Callback fired once before the first batch is decoded; just logs.
    UF.trace("Decoding started.")

def onBatchUpdate(ctr, src, trg):
    # Decoding