def metaMain(modelFile=None,wids_de=None,wids_en=None):

    if modelFile is None:
        wids_de,wids_en,modelFile=main()
    
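    #Rebuild a DyNet model container and restore the LSTM builders and parameters
    #that main() saved with model.save(...).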
    model=dy.Model()
    encoder_params={}
    decoder_params={}
    (encoder,revcoder,decoder,encoder_params["lookup"],decoder_params["lookup"],decoder_params["R"],decoder_params["bias"])=model.load(modelFile)

    print "Reversing dictionaries"
    #Reverse dictionaries
    reverse_wids_en=reverseDictionary(wids_en)
    reverse_wids_de=reverseDictionary(wids_de)

    print "Reading Test Data"
    test_sentences_en=readData.read_corpus(wids_en,mode="test",update_dict=False,min_frequency=MIN_EN_FREQUENCY,language="en")
    test_sentences_de=readData.read_corpus(wids_de,mode="test",update_dict=False,min_frequency=MIN_DE_FREQUENCY,language="de")
    test_sentences=zip(test_sentences_de,test_sentences_en)

    print "Reading blind German"
    blind_sentences_de=readData.read_corpus(wids_de,mode="blind",update_dict=False,min_frequency=MIN_DE_FREQUENCY,language="de")

    testPerplexity=computePerplexity(model,encoder,revcoder,decoder,encoder_params,decoder_params,test_sentences)
    print "Test perplexity:",testPerplexity
    
    outFileName=modelFile+"_testOutput"+"_"+str(datetime.datetime.now())
    blindFileName=modelFile+"_blindOutput"+"_"+str(datetime.datetime.now())
    refFileName=modelFile+"_testRef"+"_"+str(datetime.datetime.now())
    outFile=open(outFileName,"w")
    refFile=open(refFileName,"w")
    blindFile=open(blindFileName,"w")
    bleuOutputFile=modelFile+"_BLEU"

    print "Decoding Test Sentences"
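    #Greedy-decode each German test sentence into English; write the hypothesis to
    #outFile and the gold English reference to refFile for external BLEU scoring.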
    for test_sentence in test_sentences:
        sentence_en_hat,interpreted_test_sentence_en_hat,loss=greedyDecode(model,encoder,revcoder,decoder,encoder_params,decoder_params,test_sentence[0],reverse_wids_en)
        #print interpreted_test_sentence_en_hat
        outFile.write(interpreted_test_sentence_en_hat+"\n")
        interpreted_test_sentence_en=" ".join([reverse_wids_en[x] for x in test_sentence[1]])
        refFile.write(interpreted_test_sentence_en+"\n")

    outFile.close()
    refFile.close()

    print "wrote test data"

    print "Decoding Blind Sentences"
    for blind_sentence_de in blind_sentences_de:
        sentence_en_hat,interpreted_test_sentence_en_hat,loss=greedyDecode(model,encoder,revcoder,decoder,encoder_params,decoder_params,blind_sentence_de,reverse_wids_en)
        blindFile.write(interpreted_test_sentence_en_hat+"\n")
 
    blindFile.close()
    print "Finished Decoding Blind Sentences"
    #print "Computing perplexity"
    #import shlex
    #import subprocess
    #subprocess.call(["perl","multi-bleu.perl","-lc",refFileName,"<",outFileName],stdout=stdout)
    print "Over"
    return modelFile
Example #2
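# Reload a saved DyNet encoder-decoder model and rebuild the vocabularies from the
# training corpora before reading the test data. modelFile (the checkpoint path) is
# assumed to be the first command-line argument.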
modelFile = str(sys.argv[1])  # assumed: the saved model path is the first CLI argument
outFile1Name = str(sys.argv[2])
outFile2Name = str(sys.argv[3])

wids_de = defaultdict(int)
wids_en = defaultdict(int)

model = dy.Model()
encoder_params = {}
decoder_params = {}
(encoder, revcoder, decoder, encoder_params["lookup"],
 decoder_params["lookup"], decoder_params["R"],
 decoder_params["bias"]) = model.load(modelFile)

train_sentences_en = readData.read_corpus(wids_en,
                                          mode="train",
                                          update_dict=True,
                                          min_frequency=bedb.MIN_EN_FREQUENCY,
                                          language="en")
train_sentences_de = readData.read_corpus(wids_de,
                                          mode="train",
                                          update_dict=True,
                                          min_frequency=bedb.MIN_DE_FREQUENCY,
                                          language="de")

reverse_wids_de = bedb.reverseDictionary(wids_de)
reverse_wids_en = bedb.reverseDictionary(wids_en)

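# The training sentences were only read to populate the vocabularies; drop the
# references to free memory.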
train_sentences_en = None
train_sentences_de = None

test_sentences_de = readData.read_corpus(wids_de,
                                         mode="test",
                                         update_dict=False,
                                         min_frequency=bedb.MIN_DE_FREQUENCY,
                                         language="de")
Example #3
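#Attention-based seq2seq in PyTorch: build vocabularies, bucket the data into
#padded batches, then either train (with per-epoch validation perplexity and
#checkpoints) or greedy-decode the test set, depending on cnfg.mode.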
def main():
    torch.manual_seed(1)
    random.seed(7867567)

    modelName=sys.argv[1]

    wids_src=defaultdict(lambda: len(wids_src))
    wids_tgt=defaultdict(lambda: len(wids_tgt))

    
    train_src=readData.read_corpus(wids_src,mode="train",update_dict=True,min_frequency=cnfg.min_src_frequency,language=srcLang)
    train_tgt=readData.read_corpus(wids_tgt,mode="train",update_dict=True,min_frequency=cnfg.min_tgt_frequency,language=tgtLang)

    valid_src=readData.read_corpus(wids_src,mode="valid",update_dict=False,min_frequency=cnfg.min_src_frequency,language=srcLang)
    valid_tgt=readData.read_corpus(wids_tgt,mode="valid",update_dict=False,min_frequency=cnfg.min_tgt_frequency,language=tgtLang)

    test_src=readData.read_corpus(wids_src,mode="test",update_dict=False,min_frequency=cnfg.min_src_frequency,language=srcLang)
    test_tgt=readData.read_corpus(wids_tgt,mode="test",update_dict=False,min_frequency=cnfg.min_tgt_frequency,language=tgtLang)



    train_src,train_tgt=train_src[:cnfg.max_train_sentences],train_tgt[:cnfg.max_train_sentences]
    print "src vocab size:",len(wids_src)
    print "tgt vocab size:",len(wids_tgt)
    print "training size:",len(train_src)
    print "valid size:",len(valid_src)

    train=zip(train_src,train_tgt) #zip(train_src,train_tgt)
    valid=zip(valid_src,valid_tgt) #zip(train_src,train_tgt)
    

    #train.sort(key=lambda x:-len(x[1]))
    #valid.sort(key=lambda x:-len(x[1]))

    train.sort(key=lambda x:len(x[0]))
    valid.sort(key=lambda x:len(x[0]))


    train_src,train_tgt=[x[0] for x in train],[x[1] for x in train]
    valid_src,valid_tgt=[x[0] for x in valid],[x[1] for x in valid]
    

    #NUM_TOKENS=sum([len(x) for x in train_tgt])

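    #Split into padded batches: sources are pre-padded and targets post-padded with
    #the garbage symbol; the returned masks flag real tokens and are summed later to
    #count target words for perplexity.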
    train_src_batches,train_src_masks=torch_utils.splitBatches(train=train_src,batch_size=cnfg.batch_size,padSymbol=cnfg.garbage,method="pre")
    train_tgt_batches,train_tgt_masks=torch_utils.splitBatches(train=train_tgt,batch_size=cnfg.batch_size,padSymbol=cnfg.garbage,method="post")
    valid_src_batches,valid_src_masks=torch_utils.splitBatches(train=valid_src,batch_size=cnfg.batch_size,padSymbol=cnfg.garbage,method="pre")
    valid_tgt_batches,valid_tgt_masks=torch_utils.splitBatches(train=valid_tgt,batch_size=cnfg.batch_size,padSymbol=cnfg.garbage,method="post")
    test_src_batches,test_src_masks=torch_utils.splitBatches(train=test_src,batch_size=1,padSymbol=cnfg.garbage,method="pre")
    test_tgt_batches,test_tgt_masks=torch_utils.splitBatches(train=test_tgt,batch_size=1,padSymbol=cnfg.garbage,method="post")


    #Dump useless references
    train=None
    valid=None
    #Sanity check
    assert (len(train_tgt_batches)==len(train_src_batches))
    assert (len(valid_tgt_batches)==len(valid_src_batches))
    assert (len(test_tgt_batches)==len(test_src_batches))

    print "Training Batches:",len(train_tgt_batches)
    print "Validation Batches:",len(valid_tgt_batches)
    print "Test Points:",len(test_src_batches)

    if cnfg.cudnnBenchmark:
        torch.backends.cudnn.benchmark=True
    #Declare model object
    print "Declaring Model, Loss, Optimizer"
    model=SeqToSeqAttn(cnfg,wids_src=wids_src,wids_tgt=wids_tgt)
    loss_function=nn.NLLLoss()
    if torch.cuda.is_available():
        model.cuda()
        loss_function=loss_function.cuda()
    optimizer=None
    if cnfg.optimizer_type=="SGD":
        optimizer=optim.SGD(model.parameters(),lr=0.05)
    elif cnfg.optimizer_type=="ADAM":
        optimizer=optim.Adam(model.parameters())

    if cnfg.mode=="trial":
        print "Running Sample Batch" 
        print "Source Batch Shape:",train_src_batches[30].shape
        print "Source Mask Shape:",train_src_masks[30].shape
        print "Target Batch Shape:",train_tgt_batches[30].shape
        print "Target Mask Shape:",train_tgt_masks[30].shape
        sample_src_batch=train_src_batches[30]
        sample_tgt_batch=train_tgt_batches[30]
        sample_mask=train_tgt_masks[30]
        sample_src_mask=train_src_masks[30]
        print datetime.datetime.now() 
        model.zero_grad()
        loss=model(sample_src_batch,sample_tgt_batch,sample_src_mask,sample_mask)
        print loss
        loss.backward()
        optimizer.step()
        print datetime.datetime.now()
        #print torch.backends.cudnn.benchmark
        #print torch.backends.cudnn.enabled
        print "Done Running Sample Batch"

    train_batches=zip(train_src_batches,train_tgt_batches,train_src_masks,train_tgt_masks)
    valid_batches=zip(valid_src_batches,valid_tgt_batches,valid_src_masks,valid_tgt_masks)

    train_src_batches,train_tgt_batches,train_src_masks,train_tgt_masks=None,None,None,None
    valid_src_batches,valid_tgt_batches,valid_src_masks,valid_tgt_masks=None,None,None,None
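    #Training: shuffle the batches each epoch, skip overlong or single-sentence
    #batches, and report validation perplexity plus save a checkpoint after every epoch.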
    if cnfg.mode=="train":
        print "Start Time:",datetime.datetime.now()     
        for epochId in range(cnfg.NUM_EPOCHS):
            random.shuffle(train_batches)
            for batchId,batch in enumerate(train_batches):
                src_batch,tgt_batch,src_mask,tgt_mask=batch[0],batch[1],batch[2],batch[3]
                batchLength=src_batch.shape[1]
                batchSize=src_batch.shape[0]
                #print "Batch Length:",batchLength
                if batchLength<cnfg.MAX_SEQ_LEN and batchSize>1:
                    model.zero_grad()
                    loss=model(src_batch,tgt_batch,src_mask,tgt_mask)
                    if cnfg.mem_optimize:
                        del src_batch,tgt_batch,src_mask,tgt_mask
                    loss.backward()
                    if cnfg.mem_optimize:
                        del loss
                    optimizer.step()               
                if batchId%cnfg.PRINT_STEP==0:
                    print "Batch No:",batchId," Time:",datetime.datetime.now()

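            #Validation pass in inference mode: accumulate the (optionally
            #token-weighted) loss and count target tokens via the masks.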
            totalValidationLoss=0.0
            NUM_TOKENS=0.0
            for batchId,batch in enumerate(valid_batches):
                src_batch,tgt_batch,src_mask,tgt_mask=batch[0],batch[1],batch[2],batch[3]
                model.zero_grad()
                loss=model(src_batch,tgt_batch,src_mask,tgt_mask,inference=True)
                if cnfg.normalizeLoss:
                    totalValidationLoss+=(loss.data.cpu().numpy())*np.sum(tgt_mask)
                else:
                    totalValidationLoss+=(loss.data.cpu().numpy())
                NUM_TOKENS+=np.sum(tgt_mask)
                if cnfg.mem_optimize:
                    del src_batch,tgt_batch,src_mask,tgt_mask,loss
            
            model.save_checkpoint(modelName+"_"+str(epochId),optimizer)

            perplexity=math.exp(totalValidationLoss/NUM_TOKENS)
            print "Epoch:",epochId," Total Validation Loss:",totalValidationLoss," Perplexity:",perplexity
        print "End Time:",datetime.datetime.now()

    elif cnfg.mode=="inference":
        model.load_from_checkpoint(modelName)
        #print " ".join([model.reverse_wids_src[x] for x in test_src_batches[1][0]])
        #print " ".join([model.reverse_wids_tgt[x] for x in test_tgt_batches[1][0]])
        #model(test_src_batches[1],test_tgt_batches[1],test_tgt_masks[1])
        model.decodeAll(test_src_batches,modelName,method="greedy",evalMethod="BLEU",suffix="test")
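# DyNet encoder-decoder training loop used by metaMain above: builds the vocabularies,
# trains on length-bucketed batches, and saves a model checkpoint after every epoch.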
def main():
    # Read in data
    wids_en=defaultdict(lambda: len(wids_en))
    wids_de=defaultdict(lambda: len(wids_de))

    train_sentences_en=readData.read_corpus(wids_en,mode="train",update_dict=True,min_frequency=MIN_EN_FREQUENCY,language="en")
    train_sentences_de=readData.read_corpus(wids_de,mode="train",update_dict=True,min_frequency=MIN_DE_FREQUENCY,language="de")

    enDictionaryFile="Models/"+"en-dict_"+str(MIN_EN_FREQUENCY)+".txt" 
    deDictionaryFile="Models/"+"de-dict_"+str(MIN_DE_FREQUENCY)+".txt"

    dicFile=open(enDictionaryFile,"w")
    print len(wids_en)
    for key in wids_en:
        dicFile.write(key+" "+str(wids_en[key])+"\n")
    dicFile.close()
    print "Writing EN"

    dicFile=open(deDictionaryFile,"w")
    print len(wids_de)
    for key in wids_de:
        dicFile.write(key+" "+str(wids_de[key])+"\n")
    dicFile.close()
    print "Writing DE"

    reverse_wids_en=reverseDictionary(wids_en)
    reverse_wids_de=reverseDictionary(wids_de)

    valid_sentences_en=readData.read_corpus(wids_en,mode="valid",update_dict=False,min_frequency=MIN_EN_FREQUENCY,language="en")
    valid_sentences_de=readData.read_corpus(wids_de,mode="valid",update_dict=False,min_frequency=MIN_DE_FREQUENCY,language="de")

    train_sentences=zip(train_sentences_de,train_sentences_en)
    valid_sentences=zip(valid_sentences_de,valid_sentences_en)

    for train_sentence in train_sentences[:10]:
        print "German:",[reverse_wids_de[x] for x in train_sentence[0]]
        print "English:",[reverse_wids_en[x] for x in train_sentence[1]]


    train_sentences=train_sentences[:MAX_TRAIN_SENTENCES]
    valid_sentences=valid_sentences

    print "Number of Training Sentences:",len(train_sentences)
    print "Number of Validation Sentences:",len(valid_sentences)


    VOCAB_SIZE_EN=len(wids_en)
    VOCAB_SIZE_DE=len(wids_de)

    random.shuffle(train_sentences)
    random.shuffle(valid_sentences)

    #Prepare batches
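    #Group training pairs by source-sentence length so each batch contains
    #sentences of a single length (no padding needed).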
    lengthMap={}
    for x in train_sentences:
        if len(x[0]) not in lengthMap:
            lengthMap[len(x[0])]=[]
        lengthMap[len(x[0])].append(x)

    print "Number of Different Lengths:",len(lengthMap)

    train_batches=[]

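    #Carve every same-length bucket into batches of at most BATCH_SIZE sentences.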
    for megaBatch in lengthMap.values():
        index=0
        while index<len(megaBatch):
            if index%BATCH_SIZE==0:
                batch=megaBatch[index:min(index+BATCH_SIZE,len(megaBatch))]
                train_batches.append(batch)
                index+=BATCH_SIZE

    print [len(batch) for batch in train_batches]
    print sum([len(batch) for batch in train_batches])

    #Free some memory. Dump useless references
    train_sentences=None
    train_sentences_en=None
    train_sentences_de=None

    #Specify model
    model=dy.Model()

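    #encoder/revcoder presumably read the source forward and reversed; the decoder
    #LSTM input is the previous target embedding concatenated with a HIDDEN_SIZE-dim
    #encoder context, hence input size EMB_SIZE+HIDDEN_SIZE.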
    encoder=dy.LSTMBuilder(LAYER_DEPTH,EMB_SIZE,HIDDEN_SIZE,model)
    revcoder=dy.LSTMBuilder(LAYER_DEPTH,EMB_SIZE,HIDDEN_SIZE,model)
    decoder=dy.LSTMBuilder(LAYER_DEPTH,EMB_SIZE+HIDDEN_SIZE,HIDDEN_SIZE,model)

    encoder_params={}
    encoder_params["lookup"]=model.add_lookup_parameters((VOCAB_SIZE_DE,EMB_SIZE))

    decoder_params={}
    decoder_params["lookup"]=model.add_lookup_parameters((VOCAB_SIZE_EN,EMB_SIZE))
    decoder_params["R"]=model.add_parameters((VOCAB_SIZE_EN,HIDDEN_SIZE))
    decoder_params["bias"]=model.add_parameters((VOCAB_SIZE_EN))

    trainer=dy.AdamTrainer(model)

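    #Progress counter: print an update roughly every 3200 training sentences.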
    totalSentences=0
    sentencesCovered=totalSentences/3200

    startTime=datetime.datetime.now()
    print "Start Time",startTime
    for epochId in xrange(NUM_EPOCHS):    
        random.shuffle(train_batches)
        for batchId,batch in enumerate(train_batches):
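            #Multi-sentence batches go through the batched loss; singleton batches
            #fall back to the single-example path.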
            if len(batch)>1:
                totalSentences+=len(batch)
                if totalSentences/3200>sentencesCovered:
                    sentencesCovered=totalSentences/3200
                    print "Sentences covered:",totalSentences,"Current Time",datetime.datetime.now()
                sentence_de=[sentence[0] for sentence in batch]
                sentence_en=[sentence[1] for sentence in batch]
                loss,words=do_one_batch(model,encoder,revcoder,decoder,encoder_params,decoder_params,sentence_de,sentence_en)
                loss.value()
                loss.backward()
                trainer.update()
            else:
                totalSentences+=1
                #print "Sentences covered:",totalSentences
                sentence=batch[0]
                sentence_de=sentence[0]
                sentence_en=sentence[1]
                loss,words=do_one_example(model,encoder,revcoder,decoder,encoder_params,decoder_params,sentence_de,sentence_en)
                loss.value()
                loss.backward()
                trainer.update()
            #if totalSentences%1000<20:
            #    print "Total Sentences Covered:",totalSentences

        
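        #End-of-epoch validation: per-word perplexity = exp(total loss / total words).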
        perplexity=0.0
        totalLoss=0.0
        totalWords=0.0
        for valid_sentence in valid_sentences:
            valid_sentence_de=valid_sentence[0]
            valid_sentence_en=valid_sentence[1]
            validLoss,words=do_one_example(model,encoder,revcoder,decoder,encoder_params,decoder_params,valid_sentence_de,valid_sentence_en)
            totalLoss+=float(validLoss.value())
            totalWords+=words
        print totalLoss
        print totalWords
        perplexity=math.exp(totalLoss/totalWords)
        print "Validation perplexity after epoch:",epochId,"Perplexity:",perplexity,"Time:",datetime.datetime.now()             
        
        trainer.update_epoch(1.0)
        
        #Save Model
        modelFile="Models/"+"barebones_enc_dec_batched"+"_"+str(datetime.datetime.now())+"_"+str(EMB_SIZE)+"_"+str(LAYER_DEPTH)+"_"+str(HIDDEN_SIZE)+"_"+str(MIN_EN_FREQUENCY)+"_"+str(MIN_DE_FREQUENCY)
        model.save(modelFile,[encoder,revcoder,decoder,encoder_params["lookup"],decoder_params["lookup"],decoder_params["R"],decoder_params["bias"]])

    return wids_de,wids_en,modelFile

if __name__=="__main__":
    modelFile=None

    if modelFile is not None:
        wids_en=defaultdict(lambda: len(wids_en))
        wids_de=defaultdict(lambda: len(wids_de))

        train_sentences_en=readData.read_corpus(wids_en,mode="train",update_dict=True,min_frequency=MIN_EN_FREQUENCY,language="en")
        train_sentences_de=readData.read_corpus(wids_de,mode="train",update_dict=True,min_frequency=MIN_DE_FREQUENCY,language="de")
        train_sentences_en=None
        train_sentences_de=None
        metaMain(modelFile=modelFile,wids_de=wids_de,wids_en=wids_en)

    else:
        modelFile=metaMain()
        print "Model File Name:",modelFile