def train(lr):
    """Train an RNNLM on the training corpus with periodic reporting.

    Builds the train/validation/test ``TextIterator`` streams and a fresh
    ``RNNLM`` from module-level settings, then makes ``NEPOCH`` passes over
    the training stream.  The batch counter ``idx`` keeps running across
    epochs and triggers four periodic actions, governed by ``disp_freq``,
    ``save_freq``, ``valid_freq`` and ``test_freq``.

    Args:
        lr: Learning rate, forwarded unchanged to ``model.train``.

    Returns:
        ``-1`` as soon as a batch cost is NaN or infinite (training is
        abandoned); otherwise ``None`` after all epochs have run.
    """
    print('loading dataset...')
    train_data = TextIterator(train_datafile, n_batch=n_batch, maxlen=maxlen)
    valid_data = TextIterator(valid_datafile, n_batch=n_batch, maxlen=maxlen)
    test_data = TextIterator(test_datafile, n_batch=n_batch, maxlen=maxlen)

    print('building model...')
    model = RNNLM(n_input, n_hidden, vocabulary_size, rnn_cell, optimizer, p,
                  bptt)

    print('training start...')
    start = time.time()
    idx = 0
    # Running totals for the current display window.  Keeping a scalar sum
    # (instead of a list of per-batch arrays stacked with np.asarray) works
    # even when batches yield differently-shaped ``batch_nll`` values and
    # avoids holding every array until the next report.
    nll_sum = 0.0
    n_words = 0

    for epoch in xrange(NEPOCH):
        for x, x_mask, y, y_mask in train_data:
            idx += 1
            cost, batch_nll = model.train(x, x_mask, y, y_mask, lr)

            if np.isnan(cost) or np.isinf(cost):
                print('NaN Or Inf detected!')
                return -1

            nll_sum += np.sum(batch_nll)
            # The sum of the target mask is used as the normaliser for the
            # reported cost.
            n_words += np.sum(y_mask)

            if idx % disp_freq == 0:
                mean_nll = nll_sum / n_words
                logger.info('epoch: %d idx: %d cost: %f ppl: %f' %
                            (epoch, idx, mean_nll, np.exp(mean_nll)))
                nll_sum = 0.0
                n_words = 0

            if idx % save_freq == 0:
                # Elapsed seconds in the name keeps successive dumps apart.
                filename = './model/param_{}_bptt{}_{:.2f}.pkl'.format(
                    rnn_cell, bptt, (time.time() - start))
                logger.info('dumping...' + filename)
                save_model(filename, model)

            if idx % valid_freq == 0:
                logger.info('validing...')
                valid_cost = evaluate_ppl(valid_data, model)
                logger.info('validation cost: %f perplexity: %f' %
                            (valid_cost, np.exp(valid_cost)))

            if idx % test_freq == 0:
                logger.info('testing...')
                test_cost = evaluate_ppl(test_data, model)
                logger.info('test cost: %f perplexity: %f' %
                            (test_cost, np.exp(test_cost)))

    print("Finished. Time = " + str(time.time() - start))
def test():
    """Evaluate a saved model on the validation and test sets.

    Loads parameters from ``args.model_dir`` into a freshly built ``RNNLM``
    and prints, for each of the two sets, the value returned by
    ``evaluate`` together with its exponential (reported as perplexity).

    If ``args.model_dir`` is not an existing file, a "not found" message is
    printed and the function returns without evaluating: scoring a model
    whose parameters were never loaded would only report meaningless
    numbers.

    Returns:
        None.
    """
    # Fail fast, before paying for dataset and model construction.
    if not os.path.isfile(args.model_dir):
        print('%s not found' % (args.model_dir,))
        return

    test_data = TextIterator(test_datafile, n_batch=n_batch)
    valid_data = TextIterator(valid_datafile, n_batch=n_batch)
    model = RNNLM(n_input, n_hidden, vocabulary_size, rnn_cell, optimizer, p)

    print('loading pretrained model: %s' % (args.model_dir,))
    model = load_model(args.model_dir, model)

    mean_cost = evaluate(valid_data, model)
    print('valid cost: %s perplexity: %s' % (mean_cost, np.exp(mean_cost)))

    mean_cost = evaluate(test_data, model)
    print('test cost: %s perplexity: %s' % (mean_cost, np.exp(mean_cost)))
def test():
    """Evaluate a saved model on the validation and test sets.

    Loads parameters from ``args.model_dir`` into a freshly built ``RNNLM``
    and logs, for each of the two sets, the cost and second value returned
    by ``evaluate`` (logged as word error rate) plus the exponential of the
    cost (logged as perplexity).

    If ``args.model_dir`` is not an existing file, a "not found" message is
    printed and the function returns without evaluating: scoring a model
    whose parameters were never loaded would only report meaningless
    numbers.

    Returns:
        None.
    """
    # Fail fast, before paying for dataset and model construction.
    if not os.path.isfile(args.model_dir):
        print('%s not found' % (args.model_dir,))
        return

    test_data = TextIterator(test_datafile, n_batch=n_batch)
    valid_data = TextIterator(valid_datafile, n_batch=n_batch)
    model = RNNLM(n_input, n_hidden, vocabulary_size, rnn_cell, optimizer, p)

    print('loading pretrained model: %s' % (args.model_dir,))
    model = load_model(args.model_dir, model)

    valid_cost, wer = evaluate(valid_data, model)
    logger.info('validation cost: %f perplexity: %f,word_error_rate:%f' %
                (valid_cost, np.exp(valid_cost), wer))

    test_cost, wer = evaluate(test_data, model)
    logger.info('test cost: %f perplexity: %f,word_error_rate:%f' %
                (test_cost, np.exp(test_cost), wer))
def train(lr):
    """Train an RNNLM, optionally resuming from a checkpoint.

    Builds the train/validation/test ``TextIterator`` streams and an
    ``RNNLM`` from module-level settings.  If ``model_dir`` is an existing
    file its parameters are loaded into the model, and if ``goto_line`` is
    positive the training stream is advanced with
    ``train_data.goto_line(goto_line)``.  Training then makes ``NEPOCH``
    passes over the training stream; the batch counter ``idx`` starts at
    ``goto_line``, keeps running across epochs, and triggers the periodic
    actions governed by ``disp_freq``, ``save_freq``, ``valid_freq`` and
    ``test_freq``.

    Args:
        lr: Learning rate, forwarded unchanged to ``model.train``.

    Returns:
        ``-1`` as soon as a batch cost is NaN or infinite (training is
        abandoned); otherwise ``None`` after all epochs have run.
    """
    print('loading dataset...')
    train_data = TextIterator(train_datafile, n_batch=n_batch, maxlen=maxlen)
    valid_data = TextIterator(valid_datafile, n_batch=n_batch, maxlen=maxlen)
    test_data = TextIterator(test_datafile, n_batch=n_batch, maxlen=maxlen)

    print('building model...')
    model = RNNLM(n_input, n_hidden, vocabulary_size, rnn_cell, optimizer, p,
                  bptt)

    if os.path.isfile(model_dir):
        print('loading checkpoint parameters.... %s' % (model_dir,))
        model = load_model(model_dir, model)

    if goto_line > 0:
        train_data.goto_line(goto_line)
        print('goto line: %s' % (goto_line,))

    print('training start...')
    start = time.time()
    idx = goto_line

    for epoch in xrange(NEPOCH):
        # The reporting window restarts with every epoch, but ``idx`` does
        # not (and may start at ``goto_line``), so a window can contain
        # fewer than ``disp_freq`` batches.  Count the batches actually
        # accumulated and average over that, rather than over ``disp_freq``.
        window_cost = 0
        window_batches = 0

        for x, x_mask, y, y_mask in train_data:
            idx += 1
            cost = model.train(x, x_mask, y, y_mask, lr)

            if np.isnan(cost) or np.isinf(cost):
                print('NaN Or Inf detected!')
                return -1

            window_cost += cost
            window_batches += 1

            if idx % disp_freq == 0:
                mean_cost = window_cost / float(window_batches)
                logger.info('epoch: %d idx: %d cost: %f ppl: %f' %
                            (epoch, idx, mean_cost, np.exp(mean_cost)))
                window_cost = 0
                window_batches = 0

            if idx % save_freq == 0:
                logger.info('dumping...')
                # Elapsed seconds in the name keeps successive dumps apart.
                save_model(
                    './model/parameters_%.2f.pkl' % (time.time() - start),
                    model)

            if idx % valid_freq == 0:
                logger.info('validing...')
                valid_cost, wer = evaluate(valid_data, model)
                logger.info(
                    'validation cost: %f perplexity: %f,word_error_rate:%f' %
                    (valid_cost, np.exp(valid_cost), wer))

            if idx % test_freq == 0:
                logger.info('testing...')
                test_cost, wer = evaluate(test_data, model)
                logger.info(
                    'test cost: %f perplexity: %f,word_error_rate:%f' %
                    (test_cost, np.exp(test_cost), wer))

    print("Finished. Time = " + str(time.time() - start))