def wrapper_ngram(data=TREC, resplit=True, validate_ratio=0.2):
    """Run one n-gram network training job with a fixed hyper-parameter set.

    The splits and the matrix passed to the network as ``U`` come from
    ``prepare_datasets``; the training rows are shuffled once with a random
    permutation before being handed to ``train_ngram_net``.

    Args:
        data: Dataset selector forwarded to ``prepare_datasets``.
        resplit: Forwarded to ``prepare_datasets`` as ``resplit``.
        validate_ratio: Forwarded to ``prepare_datasets`` as
            ``validation_ratio``.

    Returns:
        Whatever ``train_ngram_net`` returns for this configuration
        (the original code binds it to the name ``test_accuracy``).
    """
    (train_x, train_y,
     validate_x, validate_y,
     test_x, test_y,
     embeddings, _mask) = prepare_datasets(
        data, resplit=resplit, validation_ratio=validate_ratio)

    # First axis length of one training sample, second axis of the matrix
    # that is later passed as ``U``.
    input_shape = (train_x[0].shape[0], embeddings.shape[1])
    # A single parenthesised expression prints the same text whether
    # ``print`` is the Python 2 statement or the Python 3 function.
    print("input data shape" + " " + str(input_shape))

    # Output size = number of distinct label values seen in the test labels.
    n_out = len(np.unique(test_y))

    # Shuffle inputs and labels of the training split with the same order.
    order = np.random.permutation(train_x.shape[0])
    datasets = (
        train_x[order], train_y[order],
        validate_x, validate_y,
        test_x, test_y,
    )

    hyper_params = {
        'U': embeddings,
        'datasets': datasets,
        'n_epochs': 10,
        'ngrams': (3, 2),
        'ngram_out': (150, 50),
        'non_static': False,
        'input_shape': input_shape,
        'concat_out': True,
        'n_kernels': (8, 16),
        'use_bias': False,
        'lr_rate': 0.02,
        'dropout': True,
        'dropout_rate': 0.5,
        'n_hidden': 600,
        'n_out': n_out,
        'ngram_activation': leaky_relu,
        'activation': leaky_relu,
        'batch_size': 50,
        'l2_ratio': 1e-5,
        'update_rule': 'adagrad',
        'skip_gram': False,
    }
    return train_ngram_net(**hyper_params)
def error_analysis(data=SST_SENT_POL):
    """Train a fixed n-gram configuration and print its misclassifications.

    The network is trained with ``predict=True`` and the value it returns is
    treated as a sequence of predicted labels. Every position where the
    prediction differs from ``validate_y`` is printed as
    ``<expected> <predicted> <raw tokens joined by spaces>``, and finally the
    ten most frequent ``"expected & predicted"`` pairs are printed.

    NOTE(review): ``prepare_datasets`` is called with ``validation_ratio=0.0``
    while the predictions are compared against the validation labels and raw
    validation text -- confirm that ``train_ngram_net(predict=True)`` returns
    predictions aligned with that split.

    Args:
        data: Dataset selector forwarded to ``prepare_datasets`` and
            ``load_raw_datasets``.

    Returns:
        None. Results are only printed.
    """
    from collections import Counter

    (train_x, train_y,
     validate_x, validate_y,
     test_x, test_y,
     embeddings, _mask) = prepare_datasets(
        data, resplit=False, validation_ratio=0.0)

    # First axis length of one training sample, second axis of the matrix
    # that is later passed as ``U``.
    input_shape = (train_x[0].shape[0], embeddings.shape[1])
    # A single parenthesised expression prints the same text whether
    # ``print`` is the Python 2 statement or the Python 3 function.
    print("input data shape" + " " + str(input_shape))

    # Output size = number of distinct label values seen in the test labels.
    n_out = len(np.unique(test_y))

    # Shuffle inputs and labels of the training split with the same order.
    order = np.random.permutation(train_x.shape[0])
    datasets = (
        train_x[order], train_y[order],
        validate_x, validate_y,
        test_x, test_y,
    )

    hyper_params = {
        'U': embeddings,
        'datasets': datasets,
        'n_epochs': 10,
        'ngrams': (1, 2),
        'ngram_out': (300, 250),
        'non_static': False,
        'input_shape': input_shape,
        'concat_out': False,
        'n_kernels': (4, 4),
        'use_bias': False,
        'lr_rate': 0.02,
        'dropout': True,
        'dropout_rate': 0.2,
        'n_hidden': 250,
        'n_out': n_out,
        'ngram_activation': leaky_relu,
        'activation': leaky_relu,
        'batch_size': 50,
        'l2_ratio': 1e-5,
        'update_rule': 'adagrad',
        'skip_gram': False,
        'predict': True,
    }
    best_prediction = train_ngram_net(**hyper_params)

    # Only the third of the six raw splits is used below.
    _, _, validate_raw, _, _, _ = load_raw_datasets(datasets=data)

    mismatches = []
    for idx, predicted in enumerate(best_prediction):
        expected = validate_y[idx]
        if predicted == expected:
            continue
        mismatches.append("%d & %d" % (expected, predicted))
        report_fields = [
            str(expected),
            str(predicted),
            " ".join(validate_raw[idx]),
        ]
        print(" ".join(report_fields))

    print(Counter(mismatches).most_common(10))