batch_size = 64
chunk_num = 10
#train_iteration = 10
train_iteration = 12
display_fre = 50
half = 4  # data augmentation

# save the models
model_dir = "models_train"
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

## ======================================
# with data augmentation
train_dataset = TorchDataSet(train_list, batch_size, chunk_num, dimension)
# without data augmentation
dev_dataset = TorchDataSet(dev_list, batch_size, chunk_num, dimension)
logging.info('finished reading all train data')

# optimizer: SGD with momentum updates the parameters
train_module = LanNet(input_dim=dimension, hidden_dim=128, bn_dim=30,
                      output_dim=language_nums)
logging.info(train_module)
optimizer = torch.optim.SGD(train_module.parameters(), lr=learning_rate,
                            momentum=0.9)
# initialize the model
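# ----------------------------------------------------------------
# A minimal training-loop sketch, NOT part of the original script.
# It assumes TorchDataSet iteration yields (batch_frames,
# batch_labels) tensors and that LanNet.forward(frames, labels)
# returns (accuracy, loss); both are assumptions about interfaces
# defined elsewhere in this repo -- adjust to the real signatures.
for epoch in range(train_iteration):
    train_module.train()
    epoch_loss = 0.
    for step, (batch_frames, batch_labels) in enumerate(train_dataset):
        acc, loss = train_module(batch_frames, batch_labels)

        # standard SGD update: clear stale gradients, backprop, step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        if (step + 1) % display_fre == 0:
            logging.info('epoch %d, step %d: loss %.4f',
                         epoch, step + 1, loss.item())

    # checkpoint one model per epoch, matching the modelN.model
    # naming the inference script below loads
    torch.save(train_module.state_dict(),
               os.path.join(model_dir, 'model%d.model' % epoch))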
f = open("/result/result.txt", "w") #f.write("posterior: changsha, hebei, nanchang, shanghai, kejia, minnan\n") fangyan = np.array( ["minnan", "nanchang", "kejia", "changsha", "shanghai", "hebei"]) sentences = [] with open("./label_dev_list_fb.txt", "r") as s: for line in s.readlines(): sentences.append(line.strip().split("/")[-1].split()[0].replace( "fb", "pcm")) sentences = np.array(sentences) #print len(sentences) ## ====================================== dev_dataset = TorchDataSet(dev_list, batch_size, chunk_num, dimension) logging.info('finish reading all train data') train_module = LanNet(input_dim=dimension, hidden_dim=128, bn_dim=30, output_dim=language_nums) logging.info(train_module) train_module.load_state_dict( torch.load('/inference/models/model9.model', map_location=lambda storage, loc: storage)) train_module.eval() epoch_tic = time.time() dev_loss = 0. dev_acc = 0.