예제 #1
0
# Evaluation-only run: strong supervision is disabled and train_mode is
# switched off before the model object is constructed from `config`.
config.strong_supervision = False

config.train_mode = False

print('Testing DMN ' + dmn_type + ' on babi task', config.babi_id)

# Create the model inside the 'DMN' variable scope. The implementation module
# is imported inside the matching branch, so only the selected variant is
# loaded.
with tf.variable_scope('DMN') as scope:
    if dmn_type == "original":
        from dmn_original import DMN
        model = DMN(config)
    elif dmn_type == "plus":
        from dmn_plus import DMN_PLUS
        model = DMN_PLUS(config)
    else:
        # Without this branch `model` is never bound and the script only
        # fails further down with an unrelated-looking error.
        raise ValueError(
            "unknown dmn_type {!r}; expected 'original' or 'plus'".format(
                dmn_type))

print('==> initializing variables')
init = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as session:
    session.run(init)

    # Overwrite the freshly initialised variables with the checkpoint stored
    # for this bAbI task id.
    print('==> restoring weights')
    saver.restore(session,
                  'weights/task' + str(model.config.babi_id) + '.weights')
    print('==> running DMN')
    # No train_op is passed here; only the returned accuracy is reported.
    test_loss, test_accuracy = model.run_epoch(session, model.test)
    print('')
    print('Test accuracy:', test_accuracy)
        # Training-loop fragment (Python 2 syntax). The enclosing block, and the
        # initial values of best_val_loss / best_overall_val_loss, are outside
        # this excerpt.

        # Optionally start from the checkpoint stored for this bAbI task id.
        if args.restore:
            print '==> restoring weights'
            saver.restore(
                session,
                'weights/task' + str(model.config.babi_id) + '.weights')

        print '==> starting training'
        for epoch in xrange(config.max_epochs):
            print 'Epoch {}'.format(epoch)
            # Timestamp taken at the start of the epoch; it is not read in the
            # visible lines.
            start = time.time()

            # Training pass: supplies the optimiser op and the summary writer.
            train_loss, train_accuracy = model.run_epoch(
                session,
                model.train,
                epoch,
                train_writer,
                train_op=model.train_step,
                train=True)
            # Validation pass: no train_op is supplied (presumably evaluation
            # only; confirm against run_epoch).
            valid_loss, valid_accuracy = model.run_epoch(session, model.valid)
            print 'Training loss: {}'.format(train_loss)
            print 'Validation loss: {}'.format(valid_loss)
            print 'Training accuracy: {}'.format(train_accuracy)
            # NOTE(review): "Vaildation" is misspelled in the emitted message.
            print 'Vaildation accuracy: {}'.format(valid_accuracy)

            # Track the best validation loss seen in this loop and the epoch
            # at which it occurred.
            if valid_loss < best_val_loss:
                best_val_loss = valid_loss
                best_val_epoch = epoch
                # best_overall_val_loss is defined outside this excerpt.
                # NOTE(review): the excerpt ends after the assignment below;
                # no saver.save call is visible despite the 'Saving weights'
                # message.
                if best_val_loss < best_overall_val_loss:
                    print 'Saving weights'
                    best_overall_val_loss = best_val_loss
# Evaluation-only run: strong supervision is disabled and train_mode is
# switched off before the model object is constructed from `config`.
config.strong_supervision = False

config.train_mode = False

print( 'Testing DMN ' + dmn_type + ' on babi task', config.babi_id)

# Create the model inside the 'DMN' variable scope. Each implementation is
# imported only in the branch that uses it.
with tf.variable_scope('DMN') as scope:
    if dmn_type == "original":
        from dmn_original import DMN
        model = DMN(config)
    elif dmn_type == "plus":
        from dmn_plus import DMN_PLUS
        model = DMN_PLUS(config)
    else:
        # Fail fast: otherwise `model` stays unbound and the failure surfaces
        # later with a less helpful message.
        raise ValueError(
            "unknown dmn_type {!r}; expected 'original' or 'plus'".format(
                dmn_type))

print('==> initializing variables')
init = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as session:
    session.run(init)

    # Load the checkpoint saved for this bAbI task id over the initial values.
    print('==> restoring weights')
    saver.restore(session, 'weights/task' + str(model.config.babi_id) + '.weights')

    print('==> running DMN')
    # No train_op is passed here; only the returned accuracy is reported.
    test_loss, test_accuracy = model.run_epoch(session, model.test)

    print('')
    print('Test accuracy:', test_accuracy)
        # Training-loop fragment; the enclosing block and the definition of
        # best_overall_val_loss are outside this excerpt.

        # Bookkeeping for the best validation result seen during this run.
        best_val_epoch = 0
        # Initialised here but not read anywhere in the visible lines.
        prev_epoch_loss = float('inf')
        best_val_loss = float('inf')
        best_val_accuracy = 0.0

        # Optionally start from the checkpoint stored for this bAbI task id.
        if args.restore:
            print('==> restoring weights')
            saver.restore(session, 'weights/task' + str(model.config.babi_id) + '.weights')

        print('==> starting training')
        for epoch in range(config.max_epochs):
            print('Epoch {}'.format(epoch))
            # Timestamp taken at the start of the epoch; it is not read in the
            # visible lines.
            start = time.time()

            # Training pass: supplies the optimiser op and the summary writer.
            train_loss, train_accuracy = model.run_epoch(
              session, model.train, epoch, train_writer,
              train_op=model.train_step, train=True)
            # Validation pass: no train_op is supplied (presumably evaluation
            # only; confirm against run_epoch).
            valid_loss, valid_accuracy = model.run_epoch(session, model.valid)
            print('Training loss: {}'.format(train_loss))
            print('Validation loss: {}'.format(valid_loss))
            print('Training accuracy: {}'.format(train_accuracy))
            # NOTE(review): "Vaildation" is misspelled in the emitted message.
            print('Vaildation accuracy: {}'.format(valid_accuracy))

            # Record a new per-run best whenever validation loss improves.
            if valid_loss < best_val_loss:
                best_val_loss = valid_loss
                best_val_epoch = epoch
                # Only write a checkpoint when the per-run best also beats
                # best_overall_val_loss (defined outside this excerpt). The
                # path is the same one the restore branch above reads.
                if best_val_loss < best_overall_val_loss:
                    print('Saving weights')
                    best_overall_val_loss = best_val_loss
                    best_val_accuracy = valid_accuracy
                    saver.save(session, 'weights/task' + str(model.config.babi_id) + '.weights')
예제 #5
0
# Rebuild the selected DMN graph, restore the weights saved for the
# configured bAbI task, and print what run_epoch returns for model.test.

# create model
with tf.variable_scope('DMN') as scope:
    if dmn_type == "original":
        from dmn_original import DMN
        model = DMN(config)
    elif dmn_type == "plus":
        from dmn_self_plus import DMN_PLUS
        model = DMN_PLUS(config)
    else:
        # Fail fast: otherwise `model` stays unbound and the failure surfaces
        # later with a less helpful message.
        raise ValueError(
            "unknown dmn_type {!r}; expected 'original' or 'plus'".format(
                dmn_type))

print('==> initializing variables')
init = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as session:
    session.run(init)
    # Load the checkpoint saved for this bAbI task id over the initial values.
    print('==> restoring weights')
    saver.restore(session, 'weights/task' + str(model.config.babi_id) + '.weights')

    print('==> running DMN')
    # The return value is printed as-is rather than unpacked into
    # (loss, accuracy). A stray bare identifier that used to sit here raised
    # NameError before run_epoch could be reached.
    answer = model.run_epoch(session, model.test)
    print(answer)