Пример #1
0
def main():
    """Fine-tune Alice by self-play against a frozen Bob (REINFORCE).

    Hyper-parameter defaults come from the project ``config`` module. Alice's
    model is loaded from ``--alice_model_file`` and wrapped in an ``RlAgent``;
    Bob's model is loaded from ``--bob_model_file`` and kept in eval mode. A
    ``Reinforce`` loop is then run over contexts from ``--context_file``, and
    Alice's model is finally saved to ``--output_model_file``.
    """
    parser = argparse.ArgumentParser(description='Reinforce')
    parser.add_argument('--data',
                        type=str,
                        default=config.data_dir,
                        help='location of the data corpus')
    parser.add_argument('--unk_threshold',
                        type=int,
                        default=config.unk_threshold,
                        help='minimum word frequency to be in dictionary')
    parser.add_argument('--alice_model_file',
                        type=str,
                        help='Alice model file')
    parser.add_argument('--bob_model_file', type=str, help='Bob model file')
    parser.add_argument('--output_model_file',
                        type=str,
                        help='output model file')
    parser.add_argument('--context_file', type=str, help='context file')
    parser.add_argument('--temperature',
                        type=float,
                        default=config.rl_temperature,
                        help='temperature')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=config.cuda,
                        help='use CUDA')
    parser.add_argument('--verbose',
                        action='store_true',
                        default=config.verbose,
                        help='print out conversations')
    parser.add_argument('--seed',
                        type=int,
                        default=config.seed,
                        help='random seed')
    parser.add_argument(
        '--score_threshold',
        type=int,
        default=config.rl_score_threshold,
        help='successful dialog should have more than score_threshold in score'
    )
    parser.add_argument('--log_file',
                        type=str,
                        default='',
                        help='log successful dialogs to file for training')
    parser.add_argument('--smart_bob',
                        action='store_true',
                        default=False,
                        help='make Bob smart again')
    parser.add_argument('--gamma',
                        type=float,
                        default=config.rl_gamma,
                        help='discount factor')
    parser.add_argument('--eps',
                        type=float,
                        default=config.rl_eps,
                        help='eps greedy')
    parser.add_argument('--nesterov',
                        action='store_true',
                        default=config.nesterov,
                        help='enable nesterov momentum')
    parser.add_argument('--momentum',
                        type=float,
                        default=config.rl_momentum,
                        help='momentum for sgd')
    parser.add_argument('--lr',
                        type=float,
                        default=config.rl_lr,
                        help='learning rate')
    parser.add_argument('--clip',
                        type=float,
                        default=config.rl_clip,
                        help='gradient clip')
    parser.add_argument('--rl_lr',
                        type=float,
                        default=config.rl_reinforcement_lr,
                        help='RL learning rate')
    parser.add_argument('--rl_clip',
                        type=float,
                        default=config.rl_reinforcement_clip,
                        help='RL gradient clip')
    parser.add_argument('--ref_text',
                        type=str,
                        help='file with the reference text')
    parser.add_argument('--bsz',
                        type=int,
                        default=config.rl_bsz,
                        help='batch size')
    parser.add_argument('--sv_train_freq',
                        type=int,
                        default=config.rl_sv_train_freq,
                        help='supervision train frequency')
    parser.add_argument('--nepoch',
                        type=int,
                        default=config.rl_nepoch,
                        help='number of epochs')
    parser.add_argument('--visual',
                        action='store_true',
                        default=config.plot_graphs,
                        help='plot graphs')
    parser.add_argument('--domain',
                        type=str,
                        default=config.domain,
                        help='domain for the dialogue')
    parser.add_argument('--reward',
                        type=str,
                        choices=['margin', 'fair', 'length'],
                        default='margin',
                        help='reward function')
    args = parser.parse_args()

    device_id = utils.use_cuda(args.cuda)
    logging.info("Starting training using pytorch version:%s" %
                 (str(torch.__version__)))
    if args.cuda:
        # Name the GPU that use_cuda() selected instead of always GPU 0;
        # get_device_name(None) falls back to the current device.
        cuda_status = ("enabled. Using device_id:" + str(device_id) +
                       " version:" + str(torch.version.cuda) + " on gpu:" +
                       torch.cuda.get_device_name(device_id))
    else:
        cuda_status = "disabled"
    logging.info("CUDA is %s" % cuda_status)
    # --seed used to be parsed but never applied; seed the RNGs the same way
    # the supervised training scripts do so the flag actually takes effect.
    utils.set_seed(args.seed)

    alice_model = utils.load_model(args.alice_model_file)
    # we don't want to use Dropout during RL
    alice_model.eval()
    # Alice is a RL based agent, meaning that she will be learning while selfplaying
    logging.info("Creating RlAgent from alice_model: %s" %
                 (args.alice_model_file))
    alice = RlAgent(alice_model, args, name='Alice')

    # we keep Bob frozen, i.e. we don't update his parameters
    logging.info("Creating Bob's (--smart_bob) LstmRolloutAgent" if args.smart_bob
                 else "Creating Bob's (not --smart_bob) LstmAgent")
    bob_ty = LstmRolloutAgent if args.smart_bob else LstmAgent
    bob_model = utils.load_model(args.bob_model_file)
    bob_model.eval()
    bob = bob_ty(bob_model, args, name='Bob')

    logging.info("Initializing communication dialogue between Alice and Bob")
    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    logging.info(
        "Building word corpus, requiring minimum word frequency of %d for dictionary"
        % (args.unk_threshold))
    corpus = data.WordCorpus(args.data, freq_cutoff=args.unk_threshold)
    # The engine wraps Alice's model (the one being updated); Bob's is not passed.
    engine = Engine(alice_model, args, device_id, verbose=False)

    logging.info("Starting Reinforcement Learning")
    reinforce = Reinforce(dialog, ctx_gen, args, engine, corpus, logger)
    reinforce.run()

    logging.info("Saving updated Alice model to %s" % (args.output_model_file))
    utils.save_model(alice.model, args.output_model_file)
Пример #2
0
def main():
    """Train a supervised DialogModel and save it to ``--model_file``.

    Builds a ``data.WordCorpus`` from ``--data``, constructs a ``DialogModel``
    from its dictionaries, trains it with ``Engine.train`` and prints the
    final selection perplexity (``exp`` of the selection loss) before saving
    the trained model with ``utils.save_model``.
    """
    parser = argparse.ArgumentParser(description='training script')
    parser.add_argument('--data',
                        type=str,
                        default='data/negotiate',
                        help='location of the data corpus')
    parser.add_argument('--nembed_word',
                        type=int,
                        default=256,
                        help='size of word embeddings')
    parser.add_argument('--nembed_ctx',
                        type=int,
                        default=64,
                        help='size of context embeddings')
    parser.add_argument(
        '--nhid_lang',
        type=int,
        default=256,
        help='size of the hidden state for the language module')
    parser.add_argument('--nhid_ctx',
                        type=int,
                        default=64,
                        help='size of the hidden state for the context module')
    parser.add_argument(
        '--nhid_strat',
        type=int,
        default=64,
        help='size of the hidden state for the strategy module')
    parser.add_argument(
        '--nhid_attn',
        type=int,
        default=64,
        help='size of the hidden state for the attention module')
    parser.add_argument(
        '--nhid_sel',
        type=int,
        default=64,
        help='size of the hidden state for the selection module')
    parser.add_argument('--lr',
                        type=float,
                        default=20.0,
                        help='initial learning rate')
    parser.add_argument('--min_lr',
                        type=float,
                        default=1e-5,
                        help='min threshold for learning rate annealing')
    parser.add_argument('--decay_rate',
                        type=float,
                        default=9.0,
                        help='decrease learning rate by this factor')
    parser.add_argument('--decay_every',
                        type=int,
                        default=1,
                        help='decrease learning rate after decay_every epochs')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.0,
                        help='momentum for sgd')
    parser.add_argument('--nesterov',
                        action='store_true',
                        default=False,
                        help='enable nesterov momentum')
    parser.add_argument('--clip',
                        type=float,
                        default=0.2,
                        help='gradient clipping')
    parser.add_argument('--dropout',
                        type=float,
                        default=0.5,
                        help='dropout rate in embedding layer')
    parser.add_argument('--init_range',
                        type=float,
                        default=0.1,
                        help='initialization range')
    parser.add_argument('--max_epoch',
                        type=int,
                        default=30,
                        help='max number of epochs')
    parser.add_argument('--bsz', type=int, default=25, help='batch size')
    parser.add_argument('--unk_threshold',
                        type=int,
                        default=20,
                        help='minimum word frequency to be in dictionary')
    parser.add_argument('--temperature',
                        type=float,
                        default=0.1,
                        help='temperature')
    parser.add_argument('--sel_weight',
                        type=float,
                        default=1.0,
                        help='selection weight')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='use CUDA')
    parser.add_argument('--model_file',
                        type=str,
                        default='',
                        help='path to save the final model')
    parser.add_argument('--visual',
                        action='store_true',
                        default=False,
                        help='plot graphs')
    parser.add_argument('--domain',
                        type=str,
                        default='object_division',
                        help='domain for the dialogue')
    parser.add_argument('--rnn_ctx_encoder',
                        action='store_true',
                        default=False,
                        help='whether to use RNN for encoding the context')
    args = parser.parse_args()

    device_id = utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    corpus = data.WordCorpus(args.data,
                             freq_cutoff=args.unk_threshold,
                             verbose=True)
    model = DialogModel(corpus.word_dict, corpus.item_dict,
                        corpus.context_dict, corpus.output_length, args,
                        device_id)
    # use_cuda() yields None when CUDA is not used; only then stay on CPU.
    if device_id is not None:
        model.cuda(device_id)
    engine = Engine(model, args, device_id, verbose=True)
    # Only the selection loss is reported; train/valid losses are ignored.
    _, _, select_loss = engine.train(corpus)
    print('final selectppl %.3f' % np.exp(select_loss))

    utils.save_model(engine.get_model(), args.model_file)
Пример #3
0
def main():
    """Fine-tune Alice by self-play against a frozen Bob (REINFORCE).

    Alice's model is loaded from ``--alice_model_file`` and wrapped in an
    ``RlAgent``; Bob's model is loaded from ``--bob_model_file`` and kept in
    eval mode. A ``Reinforce`` loop is run over contexts from
    ``--context_file`` and Alice's model is saved to ``--output_model_file``.
    """
    parser = argparse.ArgumentParser(description='Reinforce')
    parser.add_argument('--data',
                        type=str,
                        default='./data/negotiate',
                        help='location of the data corpus')
    parser.add_argument('--unk_threshold',
                        type=int,
                        default=20,
                        help='minimum word frequency to be in dictionary')
    parser.add_argument('--alice_model_file',
                        type=str,
                        help='Alice model file')
    parser.add_argument('--bob_model_file', type=str, help='Bob model file')
    parser.add_argument('--output_model_file',
                        type=str,
                        help='output model file')
    parser.add_argument('--context_file', type=str, help='context file')
    parser.add_argument('--temperature',
                        type=float,
                        default=1.0,
                        help='temperature')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='use CUDA')
    parser.add_argument('--verbose',
                        action='store_true',
                        default=False,
                        help='print out conversations')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument(
        '--score_threshold',
        type=int,
        default=6,
        help='successful dialog should have more than score_threshold in score'
    )
    parser.add_argument('--log_file',
                        type=str,
                        default='',
                        help='log successful dialogs to file for training')
    parser.add_argument('--smart_bob',
                        action='store_true',
                        default=False,
                        help='make Bob smart again')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.99,
                        help='discount factor')
    parser.add_argument('--eps', type=float, default=0.5, help='eps greedy')
    parser.add_argument('--nesterov',
                        action='store_true',
                        default=False,
                        help='enable nesterov momentum')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.0,
                        help='momentum for sgd')
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
    parser.add_argument('--clip',
                        type=float,
                        default=0.1,
                        help='gradient clip')
    parser.add_argument('--rl_lr',
                        type=float,
                        default=0.2,
                        help='RL learning rate')
    parser.add_argument('--rl_clip',
                        type=float,
                        default=0.1,
                        help='RL gradient clip')
    parser.add_argument('--ref_text',
                        type=str,
                        help='file with the reference text')
    parser.add_argument('--bsz', type=int, default=8, help='batch size')
    parser.add_argument('--sv_train_freq',
                        type=int,
                        default=-1,
                        help='supervision train frequency')
    parser.add_argument('--nepoch',
                        type=int,
                        default=4,
                        help='number of epochs')
    parser.add_argument('--visual',
                        action='store_true',
                        default=False,
                        help='plot graphs')
    parser.add_argument('--domain',
                        type=str,
                        default='object_division',
                        help='domain for the dialogue')
    args = parser.parse_args()

    device_id = utils.use_cuda(args.cuda)
    # --seed used to be parsed but never applied; seed the RNGs the same way
    # the supervised training scripts do so the flag actually takes effect.
    utils.set_seed(args.seed)

    alice_model = utils.load_model(args.alice_model_file)
    # we don't want to use Dropout during RL
    alice_model.eval()
    # Alice is a RL based agent, meaning that she will be learning while selfplaying
    alice = RlAgent(alice_model, args, name='Alice')

    # we keep Bob frozen, i.e. we don't update his parameters
    bob_ty = LstmRolloutAgent if args.smart_bob else LstmAgent
    bob_model = utils.load_model(args.bob_model_file)
    bob_model.eval()
    bob = bob_ty(bob_model, args, name='Bob')

    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    corpus = data.WordCorpus(args.data, freq_cutoff=args.unk_threshold)
    # The engine wraps Alice's model (the one being updated); Bob's is not passed.
    engine = Engine(alice_model, args, device_id, verbose=False)

    reinforce = Reinforce(dialog, ctx_gen, args, engine, corpus, logger)
    reinforce.run()

    utils.save_model(alice.model, args.output_model_file)
Пример #4
0
import tensorflow as tf
import numpy as np
import data
import domain
import sys
import random

from rl_agent import *
from sv_agent import *
from helpers import *
from utils import ContextGenerator

# load data same way as original repository
# (same data directory and freq_cutoff=20 as the PyTorch training scripts).
corpus = data.WordCorpus('end-to-end-negotiator/src/data/negotiate',
                         freq_cutoff=20,
                         verbose=True)
# 16 is presumably the batch size (other callers pass args.bsz here).
# train_dataset() returns a pair; only the first element (the batched
# training set) is kept, the second is discarded.
traindata = corpus.train_dataset(16)
trainset, _ = traindata

# initialize models
# NOTE(review): the second argument ('RL' / 'SV') looks like a variable-scope
# name — the "SV" scope is what the Saver below collects. TODO confirm.
rl_model = RLModel(corpus, 'RL')
sv_model = SVModel(corpus, 'SV')

# NOTE: TF1 graph-mode API (tf.Session, tf.train.Saver, tf.get_collection);
# under TensorFlow 2.x these are only available via tf.compat.v1.
with tf.Session() as sess:
    # Initialize every global variable first; the restore below then
    # overwrites only the "SV"-scoped ones with pretrained values.
    sess.run(tf.global_variables_initializer())

    # load pretrained SuperVised model
    # The Saver is restricted to globals in the "SV" scope, so RL variables
    # keep their fresh initialization.
    var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "SV")
    saver = tf.train.Saver(var_list=var_list)
    # Hard-coded checkpoint prefix, resolved relative to the working directory.
    saver.restore(sess, 'model/sv-named99/sv-named-99')
Пример #5
0
# Display names, paired positionally with `classifiers` below.
names = ["Nearest Neighbors", "Linear SVM", "RBF SVM",
         "Decision Tree", "Random Forest", "Neural Net", "AdaBoost",
         "Naive Bayes", "QDA"]

classifiers = [
    KNeighborsClassifier(3),
    SVC(kernel="linear", C=0.025),
    SVC(gamma=2, C=1),
    DecisionTreeClassifier(criterion='entropy'),
    RandomForestClassifier(criterion='entropy', max_depth=5, n_estimators=10),
    MLPClassifier(alpha=1),
    AdaBoostClassifier(),
    GaussianNB(),
    QuadraticDiscriminantAnalysis()]

# zip() would silently drop unmatched entries; fail loudly instead.
if len(names) != len(classifiers):
    raise ValueError('names (%d) and classifiers (%d) must have the same length'
                     % (len(names), len(classifiers)))

word_corpus = data.WordCorpus('data/negotiate/')
train_dataset = word_corpus.train
val_dataset = word_corpus.valid
test_dataset = word_corpus.test

# Vocabulary size, passed to parse_dataset (defined elsewhere) for featurization.
total_word_num = len(word_corpus.word_dict)

X_train, y_train = parse_dataset(train_dataset, total_word_num)
X_valid, y_valid = parse_dataset(val_dataset, total_word_num)
X_test, y_test = parse_dataset(test_dataset, total_word_num)


# iterate over classifiers: fit on train, score (mean accuracy) on test.
# Results are collected per classifier; previously `score` was overwritten
# on every iteration so only the last classifier's result survived.
scores = {}
for name, clf in zip(names, classifiers):
    clf.fit(X_train, y_train)
    score = clf.score(X_test, y_test)
    scores[name] = score
    print('%s: %.4f' % (name, score))
Пример #6
0
def main():
    """Train a supervised DialogModel with ``config`` defaults and save it.

    Builds a ``data.WordCorpus`` from ``--data``, constructs a ``DialogModel``
    from its dictionaries, trains it with ``Engine.train``, logs the final
    selection perplexity and writes the trained model's ``state_dict`` to
    ``--model_file``.

    Exits via ``parser.error`` before training if ``--model_file`` is empty,
    since the result could not be saved.
    """
    parser = argparse.ArgumentParser(description='training script')
    parser.add_argument('--data',
                        type=str,
                        default=config.data_dir,
                        help='location of the data corpus')
    parser.add_argument('--nembed_word',
                        type=int,
                        default=config.nembed_word,
                        help='size of word embeddings')
    parser.add_argument('--nembed_ctx',
                        type=int,
                        default=config.nembed_ctx,
                        help='size of context embeddings')
    parser.add_argument(
        '--nhid_lang',
        type=int,
        default=config.nhid_lang,
        help='size of the hidden state for the language module')
    parser.add_argument('--nhid_ctx',
                        type=int,
                        default=config.nhid_ctx,
                        help='size of the hidden state for the context module')
    parser.add_argument(
        '--nhid_strat',
        type=int,
        default=config.nhid_strat,
        help='size of the hidden state for the strategy module')
    parser.add_argument(
        '--nhid_attn',
        type=int,
        default=config.nhid_attn,
        help='size of the hidden state for the attention module')
    parser.add_argument(
        '--nhid_sel',
        type=int,
        default=config.nhid_sel,
        help='size of the hidden state for the selection module')
    parser.add_argument('--lr',
                        type=float,
                        default=config.lr,
                        help='initial learning rate')
    parser.add_argument('--min_lr',
                        type=float,
                        default=config.min_lr,
                        help='min threshold for learning rate annealing')
    parser.add_argument('--decay_rate',
                        type=float,
                        default=config.decay_rate,
                        help='decrease learning rate by this factor')
    parser.add_argument('--decay_every',
                        type=int,
                        default=config.decay_every,
                        help='decrease learning rate after decay_every epochs')
    parser.add_argument('--momentum',
                        type=float,
                        default=config.momentum,
                        help='momentum for sgd')
    parser.add_argument('--nesterov',
                        action='store_true',
                        default=config.nesterov,
                        help='enable nesterov momentum')
    parser.add_argument('--clip',
                        type=float,
                        default=config.clip,
                        help='gradient clipping')
    parser.add_argument('--dropout',
                        type=float,
                        default=config.dropout,
                        help='dropout rate in embedding layer')
    parser.add_argument('--init_range',
                        type=float,
                        default=config.init_range,
                        help='initialization range')
    parser.add_argument('--max_epoch',
                        type=int,
                        default=config.max_epoch,
                        help='max number of epochs')
    parser.add_argument('--bsz',
                        type=int,
                        default=config.bsz,
                        help='batch size')
    parser.add_argument('--unk_threshold',
                        type=int,
                        default=config.unk_threshold,
                        help='minimum word frequency to be in dictionary')
    parser.add_argument('--temperature',
                        type=float,
                        default=config.temperature,
                        help='temperature')
    parser.add_argument('--sel_weight',
                        type=float,
                        default=config.sel_weight,
                        help='selection weight')
    parser.add_argument('--seed',
                        type=int,
                        default=config.seed,
                        help='random seed')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=config.cuda,
                        help='use CUDA')
    parser.add_argument('--model_file',
                        type=str,
                        default='',
                        help='path to save the final model')
    parser.add_argument('--visual',
                        action='store_true',
                        default=config.plot_graphs,
                        help='plot graphs')
    parser.add_argument('--domain',
                        type=str,
                        default=config.domain,
                        help='domain for the dialogue')
    parser.add_argument('--rnn_ctx_encoder',
                        action='store_true',
                        default=config.rnn_ctx_encoder,
                        help='whether to use RNN for encoding the context')
    args = parser.parse_args()

    # Fail fast: the model is only written after the full training run, and
    # torch.save() cannot open an empty path — don't waste the run.
    if not args.model_file:
        parser.error('--model_file is required to save the trained model')

    device_id = utils.use_cuda(args.cuda)
    logging.info("Starting training using pytorch version:%s" %
                 (str(torch.__version__)))
    if args.cuda:
        # Name the GPU that use_cuda() selected instead of always GPU 0;
        # get_device_name(None) falls back to the current device.
        cuda_status = ("enabled. Using device_id:" + str(device_id) +
                       " version:" + str(torch.version.cuda) + " on gpu:" +
                       torch.cuda.get_device_name(device_id))
    else:
        cuda_status = "disabled"
    logging.info("CUDA is %s" % cuda_status)
    utils.set_seed(args.seed)

    logging.info(
        "Building word corpus, requiring minimum word frequency of %d for dictionary"
        % (args.unk_threshold))
    corpus = data.WordCorpus(args.data,
                             freq_cutoff=args.unk_threshold,
                             verbose=True)

    logging.info("Building RNN-based dialogue model from word corpus")
    model = DialogModel(corpus.word_dict, corpus.item_dict,
                        corpus.context_dict, corpus.output_length, args,
                        device_id)
    # use_cuda() yields None when CUDA is not used; only then stay on CPU.
    if device_id is not None:
        model.cuda(device_id)

    engine = Engine(model, args, device_id, verbose=True)
    logging.info("Training model")
    # Only the selection loss is reported; train/valid losses are ignored.
    _, _, select_loss = engine.train(corpus)
    logging.info('final select_ppl %.3f' % np.exp(select_loss))

    # Only the state_dict (parameters) is saved, not a pickled module, so a
    # loader must construct a DialogModel first and call load_state_dict().
    torch.save(engine.get_model().state_dict(), args.model_file)
Пример #7
0
def _build_test_corpus(args, corpus, freq_cutoff):
    """Return the corpus whose splits are used for evaluation.

    ``'full'`` reuses ``corpus`` itself. ``'uncorrelated'`` and
    ``'success_only'`` load ``train_/valid_/test_<choice>.txt`` from
    ``args.data`` with ``word_dict=corpus.word_dict`` (presumably so token ids
    match the model trained on ``corpus``).

    Raises:
        ValueError: if ``args.test_corpus`` is not one of the known choices.
    """
    if args.test_corpus == 'full':
        return corpus
    if args.test_corpus not in ('uncorrelated', 'success_only'):
        raise ValueError('unknown test corpus: %r' % (args.test_corpus,))
    # The split file names embed the choice name verbatim.
    suffix = args.test_corpus
    return data.WordCorpus(args.data,
                           train='train_%s.txt' % suffix,
                           valid='valid_%s.txt' % suffix,
                           test='test_%s.txt' % suffix,
                           freq_cutoff=freq_cutoff,
                           verbose=True,
                           word_dict=corpus.word_dict)


def _format_mean_std(values):
    """Format fractional ``values`` as a LaTeX cell 'mean \\pm std' in percent."""
    return '{:.2f} \\pm {:.1f}'.format(
        np.mean(values) * 100,
        np.std(values) * 100)


def main():
    """Train (or load) and evaluate a SelectModel over one or several seeds.

    Without ``--seed`` the experiment is repeated for seeds 0..9. For each seed
    a SelectModel is trained (or loaded from ``--model_file`` with
    ``--test_only``) and evaluated on the train/valid/test splits of the corpus
    selected by ``--test_corpus``. The model with the lowest validation loss
    across seeds is kept and, when training, saved to ``--model_file``.
    Finally one LaTeX-style row of ``mean \\pm std`` accuracy percentages
    (train & valid & test) is printed.
    """
    parser = argparse.ArgumentParser(description='training script')
    parser.add_argument('--data',
                        type=str,
                        default='data/onecommon',
                        help='location of the data corpus')
    parser.add_argument('--nembed_word',
                        type=int,
                        default=128,
                        help='size of word embeddings')
    parser.add_argument('--nembed_ctx',
                        type=int,
                        default=128,
                        help='size of context embeddings')
    parser.add_argument(
        '--nhid_lang',
        type=int,
        default=128,
        help='size of the hidden state for the language module')
    parser.add_argument(
        '--nhid_sel',
        type=int,
        default=128,
        help='size of the hidden state for the selection module')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='initial learning rate')
    parser.add_argument('--clip',
                        type=float,
                        default=0.1,
                        help='gradient clipping')
    parser.add_argument('--dropout',
                        type=float,
                        default=0.5,
                        help='dropout rate in embedding layer')
    parser.add_argument('--init_range',
                        type=float,
                        default=0.01,
                        help='initialization range')
    parser.add_argument('--max_epoch',
                        type=int,
                        default=40,
                        help='max number of epochs')
    parser.add_argument('--bsz', type=int, default=16, help='batch size')
    parser.add_argument('--unk_threshold',
                        type=int,
                        default=10,
                        help='minimum word frequency to be in dictionary')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='use CUDA')
    parser.add_argument('--model_file',
                        type=str,
                        default='tmp.th',
                        help='path to save the final model')
    parser.add_argument('--domain',
                        type=str,
                        default='one_common',
                        help='domain for the dialogue')
    parser.add_argument(
        '--rel_ctx_encoder',
        action='store_true',
        default=False,
        help='whether to use relational module for encoding the context')
    parser.add_argument('--rel_hidden',
                        type=int,
                        default=128,
                        help='size of relation module embeddings')
    parser.add_argument('--context_only',
                        action='store_true',
                        default=False,
                        help='train without dialogue embeddings')
    parser.add_argument('--test_corpus',
                        choices=['full', 'uncorrelated', 'success_only'],
                        default='full',
                        help='type of test corpus to use')
    parser.add_argument('--test_only',
                        action='store_true',
                        default=False,
                        help='use pretrained model for testing')
    args = parser.parse_args()

    # NOTE(review): --cuda is parsed but not consulted here; the device is
    # chosen purely from torch.cuda.is_available().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    # args.seed becomes a list of seeds to iterate over.
    if args.seed is None:
        print("running experiments with 10 different seeds")
        args.seed = list(range(10))
    else:
        args.seed = [args.seed]

    # Best model across seeds, judged by validation loss on the evaluation
    # corpus. Kept separate from the per-seed value returned by training,
    # which previously overwrote it and broke the cross-seed comparison.
    best_valid_loss = float('inf')
    best_model = None

    train_accuracies = np.array([])
    valid_accuracies = np.array([])
    test_accuracies = np.array([])

    # consider double count
    freq_cutoff = args.unk_threshold * 2
    sel_crit = nn.CrossEntropyLoss()

    for seed in args.seed:
        utils.set_seed(seed)

        # Rebuilt after seeding each iteration, as before.
        corpus = data.WordCorpus(args.data,
                                 freq_cutoff=freq_cutoff,
                                 verbose=True)
        test_corpus = _build_test_corpus(args, corpus, freq_cutoff)

        if args.test_only:
            model = utils.load_model(args.model_file)
        else:
            model = SelectModel(corpus.word_dict, corpus.output_length, args,
                                device)
            engine_select = SelectEngine(model.to(device), args, device,
                                         verbose=True)
            _, seed_valid_loss, best_model_state = engine_select.train(corpus)
            # Printed value is exp(loss), i.e. a perplexity-style number.
            print('best valid loss %.3f' % np.exp(seed_valid_loss))
            model.load_state_dict(best_model_state)
        # The evaluation batches below are created on `device`; make sure the
        # model (including a loaded one) lives there too. No-op if it already does.
        model = model.to(device)

        # Test Target Selection
        model.eval()

        trainset, _ = test_corpus.train_dataset(args.bsz, device=device)
        train_loss, train_accuracy, _ = get_result(model, trainset, sel_crit)

        validset, _ = test_corpus.valid_dataset(args.bsz, device=device)
        valid_loss, valid_accuracy, _ = get_result(model, validset, sel_crit)

        testset, _ = test_corpus.test_dataset(args.bsz, device=device)
        test_loss, test_accuracy, _ = get_result(model, testset, sel_crit)

        if best_model is None or valid_loss < best_valid_loss:
            best_model = model
            best_valid_loss = valid_loss
            # With --test_only the model was just loaded from model_file;
            # re-saving it would only overwrite the checkpoint with itself.
            if not args.test_only:
                utils.save_model(best_model, args.model_file)

        print('trainloss %.5f' % (train_loss))
        print('trainaccuracy {:.5f}'.format(train_accuracy))
        print('validloss %.5f' % (valid_loss))
        print('validaccuracy {:.5f}'.format(valid_accuracy))
        print('testloss %.5f' % (test_loss))
        print('testaccuracy {:.5f}'.format(test_accuracy))

        train_accuracies = np.append(train_accuracies, train_accuracy)
        valid_accuracies = np.append(valid_accuracies, valid_accuracy)
        test_accuracies = np.append(test_accuracies, test_accuracy)

    # print final results as one LaTeX table row: train & valid & test
    print(' & '.join(_format_mean_std(accs) for accs in
                     (train_accuracies, valid_accuracies, test_accuracies)))