def _random_unk(ids, freq_list, unk_params):
    """Return a copy of ``ids`` with some entries randomly replaced by 0.

    Each id greater than 2 is replaced by 0 with probability
    ``z / (z + freq_list[id])``, where ``z`` is ``unk_params``. Rare ids are
    therefore dropped more often. Ids <= 2 are never replaced; presumably
    they are reserved/special indices (TODO confirm). The input sequence is
    not modified.

    One ``np.random.random()`` draw is consumed per eligible id, in order.
    """
    z = float(unk_params)  # avoid Python 2 integer division if an int is passed
    unked = list(ids)
    for i, idx in enumerate(ids):
        if idx <= 2:
            continue
        drop_prob = z / (z + freq_list[idx])
        if np.random.random() < drop_prob:
            unked[i] = 0
    return unked


def train(corpus, bigrams_dims, unigrams_dims, lstm_units, hidden_units,
          epochs, batch_size, train_data_file, dev_data_file,
          model_save_file, droprate, unk_params, alpha, beta):
    """Train a segmentation ``Network`` and save the best model by dev F-score.

    Args:
        corpus: Provides vocabulary/label sizes, the unigram and bigram
            frequency lists used for random UNKing, and
            ``gold_data_from_file``. It is also passed to the ``Segmenter``
            helpers as their feature map.
        bigrams_dims, unigrams_dims: Embedding sizes for bigrams and unigrams.
        lstm_units, hidden_units: Network layer sizes.
        epochs: Number of passes over the training data.
        batch_size: Number of sentences per parameter update.
        train_data_file: Gold training data, read via
            ``corpus.gold_data_from_file``.
        dev_data_file: Validation sentences, read via
            ``SegSentence.load_sentence_file``.
        model_save_file: Path the network is saved to whenever the dev
            F-score improves.
        droprate: Dropout rate forwarded to ``Network``.
        unk_params: Random-UNKing parameter ``z``. A token with frequency
            ``f`` is dropped to id 0 with probability ``z / (z + f)``.
        alpha, beta: Forwarded to ``Segmenter.exploration``.

    Returns:
        The network in its state after the final epoch. The best-scoring
        parameters are the ones written to ``model_save_file``.
    """
    start_time = time.time()
    fm = corpus  # the corpus doubles as the feature map for Segmenter helpers
    bigrams_size = corpus.total_bigrams()
    unigrams_size = corpus.total_unigrams()

    network = Network(
        bigrams_size=bigrams_size,
        unigrams_size=unigrams_size,
        bigrams_dims=bigrams_dims,
        unigrams_dims=unigrams_dims,
        lstm_units=lstm_units,
        hidden_units=hidden_units,
        label_size=fm.total_labels(),
        span_nums=fm.total_span_nums(),
        droprate=droprate,
    )
    network.init_params()

    print('Hidden units : {} ,per LSTM units : {}'.format(
        hidden_units,
        lstm_units,
    ))
    print('Embeddings: bigrams = {}, unigrams = {}'.format(
        (bigrams_size, bigrams_dims),
        (unigrams_size, unigrams_dims)))
    print('Dropout rate : {}'.format(droprate))
    print('Parameters initialized in [-0.01,0.01]')
    print('Random UNKing parameter z = {}'.format(unk_params))

    training_data = corpus.gold_data_from_file(train_data_file)
    num_batched = -(-len(training_data) // batch_size)  # ceiling division
    print('Loaded {} training sentences ({} batches of size {})!'.format(
        len(training_data),
        num_batched,
        batch_size,
    ))
    # Validate about four times per epoch (and always after the last batch).
    parse_every = -(-num_batched // 4)

    dev_sentences = SegSentence.load_sentence_file(dev_data_file)
    print('Loaded {} validation sentences!'.format(len(dev_sentences)))

    best_acc = FScore()
    for epoch in xrange(1, epochs + 1):
        print('............ epoch {} ............'.format(epoch))

        total_cost = 0.0
        total_states = 0
        training_acc = FScore()

        np.random.shuffle(training_data)

        for b in xrange(num_batched):
            batch = training_data[(b * batch_size):(b + 1) * batch_size]

            explore = [
                Segmenter.exploration(example, fm, network,
                                      alpha=alpha, beta=beta)
                for example in batch
            ]
            for (_, acc) in explore:
                training_acc += acc
            batch = [example for (example, _) in explore]

            dynet.renew_cg()
            network.prep_params()

            errors = []
            for example in batch:
                # Random UNKing is applied to fresh copies. Writing zeros into
                # example['unigrams'] / example['fwd_bigrams'] directly would
                # permanently corrupt training_data if exploration returns
                # examples that share these lists, and UNKs would then
                # accumulate epoch after epoch.
                unigrams = _random_unk(
                    example['unigrams'], fm.unigrams_freq_list, unk_params)
                fwd_bigrams = _random_unk(
                    example['fwd_bigrams'], fm.bigrams_freq_list, unk_params)

                fwd, back = network.evaluate_recurrent(
                    fwd_bigrams,
                    unigrams,
                )

                for (left, right), correct in example['label_data'].items():
                    scores = network.evaluate_labels(fwd, back, left, right)
                    probs = dynet.softmax(scores)
                    loss = -dynet.log(dynet.pick(probs, correct))
                    errors.append(loss)
                total_states += len(example['label_data'])

            # A batch with no labelled spans has nothing to backpropagate.
            if errors:
                batch_error = dynet.esum(errors)
                total_cost += batch_error.scalar_value()
                batch_error.backward()
                network.trainer.update()

            mean_cost = total_cost / total_states if total_states else 0.0
            print(
                '\rBatch {} Mean Cost {:.4f} [Train: {}]'.format(
                    b,
                    mean_cost,
                    training_acc,
                ),
                end='',
            )
            sys.stdout.flush()

            if ((b + 1) % parse_every) == 0 or b == (num_batched - 1):
                dev_acc = Segmenter.evaluate_corpus(
                    dev_sentences,
                    fm,
                    network,
                )
                print(' [Val: {}]'.format(dev_acc))

                if dev_acc.fscore() > best_acc.fscore():
                    best_acc = dev_acc
                    network.save(model_save_file)
                    print(' [saved model : {}]'.format(model_save_file))

        # Cumulative wall-clock time since training started.
        current_time = time.time()
        runmins = (current_time - start_time) / 60
        print(' Elapsed time: {:.2f}m'.format(runmins))

    return network