Example #1
def omniglot():

    sess = tf.InteractiveSession()

    input_ph = tf.placeholder(dtype=tf.float32, shape=(16,50,400))   #(batch_size, time, input_dim)
    target_ph = tf.placeholder(dtype=tf.int32, shape=(16,50))     #(batch_size, time)(label_indices)

    ##Global variables for Omniglot Problem
    nb_reads = 4
    controller_size = 200
    memory_shape = (128,40)
    nb_class = 5
    input_size = 20*20
    batch_size = 16
    nb_samples_per_class = 10

    #Load Data
    generator = OmniglotGenerator(data_folder='./data/omniglot', batch_size=batch_size, nb_samples=nb_class, nb_samples_per_class=nb_samples_per_class, max_rotation=0., max_shift=0., max_iter=None)
    output_var, output_var_flatten, params = memory_augmented_neural_network(input_ph, target_ph, batch_size=batch_size, nb_class=nb_class, memory_shape=memory_shape, controller_size=controller_size, input_size=input_size, nb_reads=nb_reads)

    print('Compiling the Model')
    

    with tf.variable_scope("Weights", reuse=True):
        W_key = tf.get_variable('W_key', shape=(nb_reads, controller_size, memory_shape[1]))
        b_key = tf.get_variable('b_key', shape=(nb_reads, memory_shape[1]))
        W_add = tf.get_variable('W_add', shape=(nb_reads, controller_size, memory_shape[1]))
        b_add = tf.get_variable('b_add', shape=(nb_reads, memory_shape[1]))
        W_sigma = tf.get_variable('W_sigma', shape=(nb_reads, controller_size, 1))
        b_sigma = tf.get_variable('b_sigma', shape=(nb_reads, 1))
        W_xh = tf.get_variable('W_xh', shape=(input_size + nb_class, 4 * controller_size))
        b_h = tf.get_variable('b_xh', shape=(4 * controller_size))
        W_o = tf.get_variable('W_o', shape=(controller_size + nb_reads * memory_shape[1], nb_class))
        b_o = tf.get_variable('b_o', shape=(nb_class))
        W_rh = tf.get_variable('W_rh', shape=(nb_reads * memory_shape[1], 4 * controller_size))
        W_hh = tf.get_variable('W_hh', shape=(controller_size, 4 * controller_size))
        gamma = tf.get_variable('gamma', shape=[1], initializer=tf.constant_initializer(0.95))

    params = [W_key, b_key, W_add, b_add, W_sigma, b_sigma, W_xh, W_rh, W_hh, b_h, W_o, b_o]
    
    #output_var = tf.cast(output_var, tf.int32)
    target_ph_oh = tf.one_hot(target_ph, depth=generator.nb_samples)
    print('Output, Target shapes: ',output_var.get_shape().as_list(), target_ph_oh.get_shape().as_list())
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=output_var, labels=target_ph_oh), name="cost")
    opt = tf.train.AdamOptimizer(learning_rate=1e-3)
    train_step = opt.minimize(cost, var_list=params)

    #train_step = tf.train.AdamOptimizer(1e-3).minimize(cost)
    accuracies = accuracy_instance(tf.argmax(output_var, axis=2), target_ph, batch_size=generator.batch_size)
    sum_out = tf.reduce_sum(tf.reshape(tf.one_hot(tf.argmax(output_var, axis=2), depth=generator.nb_samples), (-1, generator.nb_samples)), axis=0)

    print('Done')

    tf.summary.scalar('cost', cost)
    for i in range(generator.nb_samples_per_class):
        tf.summary.scalar('accuracy-'+str(i), accuracies[i])
    
    merged = tf.summary.merge_all()
    #writer = tf.summary.FileWriter('/tmp/tensorflow', graph=tf.get_default_graph())
    train_writer = tf.summary.FileWriter('/tmp/tensorflow/', sess.graph)

    t0 = time.time()
    all_scores, scores, accs = [],[],np.zeros(generator.nb_samples_per_class)


    sess.run(tf.global_variables_initializer())

    print('Training the model')



    try:
        for i, (batch_input, batch_output) in generator:
            feed_dict = {
                input_ph: batch_input,
                target_ph: batch_output
            }
            #print batch_input.shape, batch_output.shape
            train_step.run(feed_dict)
            score = cost.eval(feed_dict)
            acc = accuracies.eval(feed_dict)
            temp = sum_out.eval(feed_dict)
            summary = merged.eval(feed_dict)
            train_writer.add_summary(summary, i)
            print(i, ' ',temp)
            all_scores.append(score)
            scores.append(score)
            accs += acc
            if i>0 and not (i%100):
                print(accs / 100.0)
                print('Episode %05d: %.6f' % (i, np.mean(score)))
                scores, accs = [], np.zeros(generator.nb_samples_per_class)


    except KeyboardInterrupt:
        print(time.time() - t0)
        pass
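
Note on the loop above: train_step.run plus the four separate .eval calls each trigger their own forward pass over the same batch. A minimal sketch (assuming the tensors defined in this example) that fetches everything in one sess.run per step instead:

for i, (batch_input, batch_output) in generator:
    feed_dict = {input_ph: batch_input, target_ph: batch_output}
    # one session call per step: the op returns None, the tensors return their values
    _, score, acc, temp, summary = sess.run(
        [train_step, cost, accuracies, sum_out, merged], feed_dict=feed_dict)
    train_writer.add_summary(summary, i)
    all_scores.append(score)
    scores.append(score)
    accs += acc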
Example #2
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from MANN.Utils.Metrics import accuracy_instance
import tensorflow as tf
import numpy as np
import copy

x = [0, 0, 0, 0, 0] * 10
y = [0, 1, 2, 3, 4] * 10
np.random.shuffle(y)
x = np.append([x], [x], axis=0)
y = np.append([y], [y], axis=0)

p = tf.constant(x)
t = tf.constant(y)

sess = tf.InteractiveSession()

zz = accuracy_instance(p, t, batch_size=2)

sess.run(zz)

print(p[0].eval())
print(t[0].eval())

print(zz.eval())

print(tf.equal(p, t).eval())
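
accuracy_instance here is the per-instance accuracy from the MANN paper: predictions are grouped by how many times the true class has already appeared in the episode, so the k-th entry is the accuracy on the (k+1)-th presentation of a class. A rough NumPy sketch of that metric (just the idea, not the MANN.Utils.Metrics implementation, assuming integer arrays of shape (batch_size, time)):

import numpy as np

def accuracy_per_instance(preds, targets, nb_classes=5, nb_samples_per_class=10):
    # preds, targets: int arrays of shape (batch_size, time)
    batch_size, time = targets.shape
    acc = np.zeros(nb_samples_per_class)
    counts = np.zeros(nb_samples_per_class)
    for b in range(batch_size):
        seen = np.zeros(nb_classes, dtype=int)   # how often each class appeared so far
        for t in range(time):
            k = seen[targets[b, t]]              # 0 on the first presentation, 1 on the second, ...
            if k < nb_samples_per_class:
                acc[k] += float(preds[b, t] == targets[b, t])
                counts[k] += 1
            seen[targets[b, t]] += 1
    return acc / np.maximum(counts, 1)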
Example #3
def train(num_trials, report_interval):
    """Trains the DNC and periodically reports the loss."""

    input_var = tf.placeholder(shape=(None, FLAGS.batch_size,
                                      FLAGS.input_size),
                               dtype=tf.float32)
    target_var = tf.placeholder(shape=(None, FLAGS.batch_size), dtype=tf.int32)
    target_var_mod = tf.one_hot(indices=target_var, depth=FLAGS.nb_class)
    zero_first = tf.zeros(shape=[1, FLAGS.batch_size, FLAGS.nb_class],
                          dtype=tf.float32)
    target_shift = tf.concat(values=[zero_first, target_var_mod[:-1, :, :]],
                             axis=0)
    input_var_mod = tf.concat(values=[input_var, target_shift], axis=2)

    output_logits, _ = run_model(input_var_mod, FLAGS.nb_class)
    #output_var = tf.round(tf.sigmoid(output_logits))
    #print(state)

    generator = OmniglotGenerator(
        data_folder='./data/omniglot',
        batch_size=FLAGS.batch_size,
        nb_samples=FLAGS.nb_class,
        nb_samples_per_class=FLAGS.nb_samples_per_class,
        max_rotation=FLAGS.rot_max,
        max_shift=FLAGS.shift_max,
        max_iter=FLAGS.num_episodes)

    target_var_mod = tf.one_hot(indices=target_var, depth=generator.nb_samples)
    train_loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=output_logits,
                                                labels=target_var_mod))

    # Set up optimizer with global norm clipping.
    trainable_variables = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(
        tf.gradients(train_loss, trainable_variables), FLAGS.max_grad_norm)

    global_step = tf.get_variable(
        name="global_step",
        shape=[],
        dtype=tf.int64,
        initializer=tf.zeros_initializer(),
        trainable=False,
        collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.GLOBAL_STEP])

    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate,
                                       epsilon=FLAGS.optimizer_epsilon)
    train_step = optimizer.apply_gradients(zip(grads, trainable_variables),
                                           global_step=global_step)
    print(tf.argmax(output_logits, axis=2))
    print(target_var)
    accuracies = accuracy_instance(
        tf.argmax(output_logits, axis=2),
        target_var,
        nb_classes=FLAGS.nb_class,
        nb_samples_per_class=FLAGS.nb_samples_per_class,
        batch_size=generator.batch_size)
    sum_out = tf.reduce_sum(tf.reshape(
        tf.one_hot(tf.argmax(output_logits, axis=2),
                   depth=generator.nb_samples), (-1, generator.nb_samples)),
                            axis=0)
    #show_memory = state.access_state.memory

    saver = tf.train.Saver()

    if FLAGS.checkpoint_interval > 0:
        hooks = [
            tf.train.CheckpointSaverHook(checkpoint_dir=FLAGS.checkpoint_dir,
                                         save_steps=FLAGS.checkpoint_interval,
                                         saver=saver)
        ]
    else:
        hooks = []

    losses = []
    iters = []

    # Train.
    with tf.train.SingularMonitoredSession(
            hooks=hooks, checkpoint_dir=FLAGS.checkpoint_dir) as sess:

        print('Omniglot Task - DNC - concatenated_input - ', FLAGS.hidden_size,
              'hidden - ', FLAGS.learning_rate, 'lr')

        start_iteration = sess.run(global_step)
        accs = np.zeros(generator.nb_samples_per_class)
        losses = []

        for i, (batch_input, batch_output) in generator:

            #print(batch_input[0,0,:])
            feed_dict = {input_var: batch_input, target_var: batch_output}

            _, loss, acc, lab_distr = sess.run(
                [train_step, train_loss, accuracies, sum_out],
                feed_dict=feed_dict)

            accs += acc[0:FLAGS.nb_samples_per_class]
            losses.append(loss)
            if (i + 1) % report_interval == 0:
                print('\nEpisode %05d: %.6f' % (i + 1, np.mean(losses)))
                print('Labels Distribution: ', lab_distr)
                print('Accuracies:')
                print(accs / report_interval)
                losses, accs = [], np.zeros(generator.nb_samples_per_class)
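
The zero_first / target_shift concatenation above is the delayed-label trick: at each timestep the network sees the current image together with the label of the previous timestep. A small NumPy sketch of the same offset with illustrative shapes:

import numpy as np

time_steps, batch_size, nb_class = 4, 2, 5
targets = np.random.randint(nb_class, size=(time_steps, batch_size))
one_hot = np.eye(nb_class)[targets]                      # (time, batch, nb_class)
zero_first = np.zeros((1, batch_size, nb_class))
target_shift = np.concatenate([zero_first, one_hot[:-1]], axis=0)
# target_shift[t] now holds the label shown at step t-1 (all zeros at t=0),
# and would be concatenated to the image features along the last axis.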
Example #4
def omniglot():

    sess = tf.InteractiveSession()

    input_ph = tf.placeholder(dtype=tf.float32,
                              shape=(16, 50,
                                     400))  #(batch_size, time, input_dim)
    target_ph = tf.placeholder(dtype=tf.int32,
                               shape=(16,
                                      50))  #(batch_size, time)(label_indices)

    #Load Data
    generator = OmniglotGenerator(data_folder='./data/omniglot',
                                  batch_size=16,
                                  nb_samples=5,
                                  nb_samples_per_class=10,
                                  max_rotation=0.,
                                  max_shift=0.,
                                  max_iter=1000)
    output_var, output_var_flatten, params = memory_augmented_neural_network(
        input_ph,
        target_ph,
        batch_size=generator.batch_size,
        nb_class=generator.nb_samples,
        memory_shape=(128, 40),
        controller_size=200,
        input_size=20 * 20,
        nb_reads=4)

    print('Compiling the Model')

    output_var = tf.cast(output_var, tf.int32)
    target_ph_flatten = tf.one_hot(tf.reshape(target_ph, shape=[-1, 1]),
                                   depth=generator.nb_samples)
    #print '*******************------------>',target_ph.get_shape().as_list(),tf.argmax(output_var, axis=2).get_shape().as_list(), output_var.dtype
    cost = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=output_var_flatten,
                                                labels=target_ph_flatten))
    train_step = tf.train.AdamOptimizer(1e-3).minimize(cost)
    accuracies = accuracy_instance(tf.argmax(output_var, axis=2),
                                   target_ph,
                                   batch_size=generator.batch_size)

    print('Done')

    print('Training the model')

    t0 = time.time()
    all_scores, scores, accs = [], [], np.zeros(generator.nb_samples_per_class)

    sess.run(tf.global_variables_initializer())

    try:
        for i, (batch_input, batch_output) in generator:
            feed_dict = {input_ph: batch_input, target_ph: batch_output}
            train_step.run(feed_dict)
            score = cost.eval(feed_dict)
            acc = accuracies.eval(feed_dict)
            all_scores.append(score)
            scores.append(score)
            accs += acc

            if i > 0 and not (i % 20):
                print('Episode %05d: %.6f' % (i, np.mean(score)))
                print(accs / 20.)  # average over the 20 episodes since the last report
                scores, accs = [], np.zeros(generator.nb_samples_per_class)

    except KeyboardInterrupt:
        print(time.time() - t0)
        pass
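
One detail worth checking in this example: reshaping target_ph to [-1, 1] before tf.one_hot yields a (batch*time, 1, nb_class) tensor, while output_var_flatten is presumably (batch*time, nb_class), so the labels may need the extra axis dropped. A quick NumPy shape check under those assumed dimensions:

import numpy as np

batch_size, time, nb_class = 16, 50, 5
targets = np.random.randint(nb_class, size=(batch_size, time))
flat = targets.reshape(-1, 1)                        # (800, 1)
one_hot = np.eye(nb_class)[flat]                     # (800, 1, 5) -- extra axis from the [-1, 1] reshape
one_hot_ok = np.eye(nb_class)[targets.reshape(-1)]   # (800, 5) matches flattened logits
print(one_hot.shape, one_hot_ok.shape)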
Example #5
def omniglot():
    sess = tf.InteractiveSession()

    ##Global variables for Omniglot Problem
    nb_reads = 4  # "nb" = number of read heads: from the 128x40 external memory, 4 heads each read a 40-dim row, i.e. 4 * 40 values per step
    controller_size = 200  # hidden size of the LSTM controller; each read key of length memory_shape[1] is projected from it via W_key
    memory_shape = (128, 40)
    nb_class = 5  # each task will deal with 5 classes of samples
    input_size = 20 * 20
    batch_size = 16
    nb_samples_per_class = 10  # each class will have 10 samples

    input_ph = tf.placeholder(dtype=tf.float32,
                              shape=(batch_size, 50,
                                     400))  # (batch_size, time, input_dim)
    # 400 = input_size = 20*20: each Omniglot image is downsampled to 20x20 and flattened
    target_ph = tf.placeholder(dtype=tf.int32,
                               shape=(batch_size,
                                      50))  # (batch_size, time)(label_indices)
    # "time" is the episode length: nb_class * nb_samples_per_class = 5 * 10 = 50 presentations per episode

    # Load Data
    generator = OmniglotGenerator(
        data_folder='./data/omniglot/images_background',
        batch_size=batch_size,  # 16
        nb_classes=nb_class,  # 5 classes
        nb_samples_per_class=nb_samples_per_class,  # 10 samples per class
        max_rotation=0.,
        max_shift=0.,
        max_iter=None)

    output_var, output_var_flatten, params = memory_augmented_neural_network(
        input_ph,
        target_ph,
        batch_size=batch_size,
        nb_class=nb_class,
        memory_shape=memory_shape,
        # 128 memory slots (rows), each holding a 40-dimensional vector
        controller_size=controller_size,
        input_size=input_size,  # 400 here
        nb_reads=nb_reads)  # 4 read heads

    print('Compiling the Model')

    with tf.variable_scope("Weights", reuse=True):
        W_key = tf.get_variable('W_key',
                                shape=(nb_reads, controller_size,
                                       memory_shape[1]))
        b_key = tf.get_variable('b_key', shape=(nb_reads, memory_shape[1]))
        W_add = tf.get_variable('W_add',
                                shape=(nb_reads, controller_size,
                                       memory_shape[1]))
        b_add = tf.get_variable('b_add', shape=(nb_reads, memory_shape[1]))
        W_sigma = tf.get_variable('W_sigma',
                                  shape=(nb_reads, controller_size, 1))
        b_sigma = tf.get_variable('b_sigma', shape=(nb_reads, 1))
        W_xh = tf.get_variable('W_xh',
                               shape=(input_size + nb_class,
                                      4 * controller_size))
        b_h = tf.get_variable('b_xh', shape=(4 * controller_size))
        W_o = tf.get_variable('W_o',
                              shape=(controller_size +
                                     nb_reads * memory_shape[1], nb_class))
        b_o = tf.get_variable('b_o', shape=(nb_class))
        W_rh = tf.get_variable('W_rh',
                               shape=(nb_reads * memory_shape[1],
                                      4 * controller_size))
        W_hh = tf.get_variable('W_hh',
                               shape=(controller_size, 4 * controller_size))
        gamma = tf.get_variable('gamma',
                                shape=[1],
                                initializer=tf.constant_initializer(0.95))

    params = [
        W_key, b_key, W_add, b_add, W_sigma, b_sigma, W_xh, W_rh, W_hh, b_h,
        W_o, b_o
    ]

    # output_var = tf.cast(output_var, tf.int32)
    target_ph_oh = tf.one_hot(target_ph, depth=generator.nb_classes)
    print('Output, Target shapes: ',
          output_var.get_shape().as_list(),
          target_ph_oh.get_shape().as_list())
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
        logits=output_var, labels=target_ph_oh),
                          name="cost")
    opt = tf.train.AdamOptimizer(learning_rate=1e-3)
    train_step = opt.minimize(cost, var_list=params)

    # train_step = tf.train.AdamOptimizer(1e-3).minimize(cost)
    accuracies = accuracy_instance(tf.argmax(output_var, axis=2),
                                   target_ph,
                                   batch_size=generator.batch_size)
    sum_out = tf.reduce_sum(tf.reshape(
        tf.one_hot(tf.argmax(output_var, axis=2), depth=generator.nb_classes),
        (-1, generator.nb_classes)),
                            axis=0)

    print('Done')

    tf.summary.scalar('cost', cost)
    for i in range(generator.nb_samples_per_class):
        tf.summary.scalar('accuracy-' + str(i), accuracies[i])

    merged = tf.summary.merge_all()
    # writer = tf.summary.FileWriter('/tmp/tensorflow', graph=tf.get_default_graph())
    train_writer = tf.summary.FileWriter('./logs/', sess.graph)

    t0 = time.time()
    all_scores, scores, accs = [], [], np.zeros(generator.nb_samples_per_class)

    sess.run(tf.global_variables_initializer())

    print('Training the model')

    try:
        for i, (batch_input, batch_output) in tqdm(generator):
            feed_dict = {input_ph: batch_input, target_ph: batch_output}
            # print batch_input.shape, batch_output.shape
            train_step.run(feed_dict)
            score = cost.eval(feed_dict)
            acc = accuracies.eval(feed_dict)
            temp = sum_out.eval(feed_dict)
            summary = merged.eval(feed_dict)
            train_writer.add_summary(summary, i)
            # print(i, ' ', temp)
            all_scores.append(score)
            scores.append(score)
            accs += acc
            if i > 0 and not (i % 10):
                print(accs / 10.0)  # average over the 10 episodes since the last report
                print('Episode %05d: %.6f' % (i, np.mean(score)))
                scores, accs = [], np.zeros(generator.nb_samples_per_class)

    except KeyboardInterrupt:
        print(time.time() - t0)
        pass
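
Regarding the questions in the comments above: nb_reads = 4 means four read heads, each producing a key of length memory_shape[1] = 40 against the 128x40 memory; controller_size = 200 is the LSTM hidden size (hence the 4 * controller_size gate weights); and the 400-dimensional input is the flattened 20x20 image. A hedged NumPy sketch of one content-based read (cosine-similarity addressing as in NTM/MANN, not this project's exact code):

import numpy as np

def content_read(memory, key):
    # memory: (128, 40) external memory, key: (40,) produced by one read head
    mem_norm = memory / (np.linalg.norm(memory, axis=1, keepdims=True) + 1e-8)
    key_norm = key / (np.linalg.norm(key) + 1e-8)
    similarity = mem_norm @ key_norm                       # cosine similarity per memory row
    w_r = np.exp(similarity) / np.exp(similarity).sum()    # softmax read weights over 128 slots
    return w_r @ memory                                    # (40,) read vector

memory = np.random.randn(128, 40)
key = np.random.randn(40)
print(content_read(memory, key).shape)  # with 4 heads, 4 * 40 values are read per step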
Example #6
def omniglot():

    sess = tf.InteractiveSession()

    input_ph = tf.placeholder(dtype=tf.float32, shape=(16,50,400))   #(batch_size, time, input_dim)
    target_ph = tf.placeholder(dtype=tf.int32, shape=(16,50))     #(batch_size, time)(label_indices)

    ##Global variables for Omniglot Problem
    nb_reads = 4
    controller_size = 200
    memory_shape = (128,40)
    nb_class = 5
    input_size = 20*20
    batch_size = 16
    nb_samples_per_class = 10

    #Load Data
    generator = OmniglotGenerator(data_folder='./data/omniglot', batch_size=batch_size, nb_samples=nb_class, nb_samples_per_class=nb_samples_per_class, max_rotation=0., max_shift=0., max_iter=None)
    output_var, output_var_flatten, params = memory_augmented_neural_network(input_ph, target_ph, batch_size=batch_size, nb_class=nb_class, memory_shape=memory_shape, controller_size=controller_size, input_size=input_size, nb_reads=nb_reads)

    print('Compiling the Model')
    

    with tf.variable_scope("Weights", reuse=True):
        W_key = tf.get_variable('W_key', shape=(nb_reads, controller_size, memory_shape[1]))
        b_key = tf.get_variable('b_key', shape=(nb_reads, memory_shape[1]))
        W_add = tf.get_variable('W_add', shape=(nb_reads, controller_size, memory_shape[1]))
        b_add = tf.get_variable('b_add', shape=(nb_reads, memory_shape[1]))
        W_sigma = tf.get_variable('W_sigma', shape=(nb_reads, controller_size, 1))
        b_sigma = tf.get_variable('b_sigma', shape=(nb_reads, 1))
        W_xh = tf.get_variable('W_xh', shape=(input_size + nb_class, 4 * controller_size))
        b_h = tf.get_variable('b_xh', shape=(4 * controller_size))
        W_o = tf.get_variable('W_o', shape=(controller_size + nb_reads * memory_shape[1], nb_class))
        b_o = tf.get_variable('b_o', shape=(nb_class))
        W_rh = tf.get_variable('W_rh', shape=(nb_reads * memory_shape[1], 4 * controller_size))
        W_hh = tf.get_variable('W_hh', shape=(controller_size, 4 * controller_size))
        gamma = tf.get_variable('gamma', shape=[1], initializer=tf.constant_initializer(0.95))

    params = [W_key, b_key, W_add, b_add, W_sigma, b_sigma, W_xh, W_rh, W_hh, b_h, W_o, b_o]
    
    #output_var = tf.cast(output_var, tf.int32)
    target_ph_oh = tf.one_hot(target_ph, depth=generator.nb_samples)
    print('Output, Target shapes: ', output_var.get_shape().as_list(), target_ph_oh.get_shape().as_list())
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=output_var, labels=target_ph_oh), name="cost")
    opt = tf.train.AdamOptimizer(learning_rate=1e-3)
    train_step = opt.minimize(cost, var_list=params)

    #train_step = tf.train.AdamOptimizer(1e-3).minimize(cost)
    accuracies = accuracy_instance(tf.argmax(output_var, axis=2), target_ph, batch_size=generator.batch_size)
    sum_out = tf.reduce_sum(tf.reshape(tf.one_hot(tf.argmax(output_var, axis=2), depth=generator.nb_samples), (-1, generator.nb_samples)), axis=0)

    print('Done')

    tf.summary.scalar('cost', cost)
    for i in range(generator.nb_samples_per_class):
        tf.summary.scalar('accuracy-'+str(i), accuracies[i])
    
    merged = tf.summary.merge_all()
    #writer = tf.summary.FileWriter('/tmp/tensorflow', graph=tf.get_default_graph())
    train_writer = tf.summary.FileWriter('/tmp/tensorflow/', sess.graph)

    t0 = time.time()
    all_scores, scores, accs = [],[],np.zeros(generator.nb_samples_per_class)


    sess.run(tf.global_variables_initializer())

    print('Training the model')



    try:
        for i, (batch_input, batch_output) in generator:
            feed_dict = {
                input_ph: batch_input,
                target_ph: batch_output
            }
            #print batch_input.shape, batch_output.shape
            train_step.run(feed_dict)
            score = cost.eval(feed_dict)
            acc = accuracies.eval(feed_dict)
            temp = sum_out.eval(feed_dict)
            summary = merged.eval(feed_dict)
            train_writer.add_summary(summary, i)
            print(i, ' ', temp)
            all_scores.append(score)
            scores.append(score)
            accs += acc
            if i>0 and not (i%100):
                print(accs / 100.0)
                print('Episode %05d: %.6f' % (i, np.mean(score)))
                scores, accs = [], np.zeros(generator.nb_samples_per_class)


    except KeyboardInterrupt:
        print(time.time() - t0)
        pass
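
The gamma variable initialized to 0.95 is the usage-decay rate of the Least Recently Used Access module from the MANN paper: usage weights decay by gamma each step and are refreshed by the read and write weights, and writes interpolate between the previous read weights and the least-used slots via the sigma gate. A rough NumPy sketch of that update under those assumptions (not the project's implementation):

import numpy as np

def lrua_write_weights(w_r_prev, w_u_prev, sigma, nb_reads=4, gamma=0.95):
    # w_r_prev: (nb_reads, N) previous read weights, w_u_prev: (N,) usage weights
    w_lu_prev = np.zeros_like(w_u_prev)
    w_lu_prev[np.argsort(w_u_prev)[:nb_reads]] = 1.0        # the nb_reads least-used slots
    # write weights interpolate between previous reads and least-used slots
    w_w = sigma * w_r_prev + (1.0 - sigma) * w_lu_prev
    # usage decays by gamma and is refreshed by current reads and writes
    w_u = gamma * w_u_prev + w_r_prev.sum(axis=0) + w_w.sum(axis=0)
    return w_w, w_u

w_w, w_u = lrua_write_weights(np.random.rand(4, 128), np.random.rand(128), sigma=0.5)
print(w_w.shape, w_u.shape)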
Example #7
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='3'


from MANN.Utils.Metrics import accuracy_instance
import tensorflow as tf
import numpy as np
import copy

x = [0,0,0,0,0]*10
y = [0,1,2,3,4]*10
np.random.shuffle(y)
x = np.append([x],[x],axis=0)
y = np.append([y], [y], axis=0)

p = tf.constant(x)
t = tf.constant(y)

sess = tf.InteractiveSession()

zz = accuracy_instance(p, t, batch_size=2)

sess.run(zz)

print(p[0].eval())
print(t[0].eval())

print(zz.eval())

print(tf.equal(p, t).eval())
def cifar10(load=False, class_to_train=0):

    if (class_to_train == 0):
        load = False

    tf.reset_default_graph()
    sess = tf.InteractiveSession()

    ##Global variables for cifar10 Problem
    nb_reads = 4
    controller_size = 200
    memory_shape = (128, 40)
    nb_class = 1
    input_size = 32 * 32 * 3
    batch_size = 16
    which_class = class_to_train
    class_to_plot = 0
    nb_samples_per_class = 10

    input_ph = tf.placeholder(
        dtype=tf.float32,
        shape=(batch_size, nb_class * nb_samples_per_class,
               input_size))  #(batch_size, time, input_dim)
    target_ph = tf.placeholder(
        dtype=tf.int32,
        shape=(batch_size, nb_class *
               nb_samples_per_class))  #(batch_size, time)(label_indices)

    #Load Data
    generator = CifarGenerator(data_folder='./data/cifar-10',
                               batch_size=batch_size,
                               nb_classes=nb_class,
                               _class=which_class,
                               nb_samples_per_class=nb_samples_per_class,
                               max_iter=500)
    output_var, output_var_flatten, params = memory_augmented_neural_network(
        input_ph,
        target_ph,
        batch_size=batch_size,
        nb_class=nb_class,
        memory_shape=memory_shape,
        controller_size=controller_size,
        input_size=input_size,
        nb_reads=nb_reads)

    print('Compiling the Model')

    with tf.variable_scope("Weights", reuse=True):
        W_key = tf.get_variable('W_key',
                                shape=(nb_reads, controller_size,
                                       memory_shape[1]))
        b_key = tf.get_variable('b_key', shape=(nb_reads, memory_shape[1]))
        W_add = tf.get_variable('W_add',
                                shape=(nb_reads, controller_size,
                                       memory_shape[1]))
        b_add = tf.get_variable('b_add', shape=(nb_reads, memory_shape[1]))
        W_sigma = tf.get_variable('W_sigma',
                                  shape=(nb_reads, controller_size, 1))
        b_sigma = tf.get_variable('b_sigma', shape=(nb_reads, 1))
        W_xh = tf.get_variable('W_xh',
                               shape=(input_size + nb_class,
                                      4 * controller_size))
        b_h = tf.get_variable('b_xh', shape=(4 * controller_size))
        W_o = tf.get_variable('W_o',
                              shape=(controller_size +
                                     nb_reads * memory_shape[1], nb_class))
        b_o = tf.get_variable('b_o', shape=(nb_class))
        W_rh = tf.get_variable('W_rh',
                               shape=(nb_reads * memory_shape[1],
                                      4 * controller_size))
        W_hh = tf.get_variable('W_hh',
                               shape=(controller_size, 4 * controller_size))
        gamma = tf.get_variable('gamma',
                                shape=[1],
                                initializer=tf.constant_initializer(0.95))

    params = [
        W_key, b_key, W_add, b_add, W_sigma, b_sigma, W_xh, W_rh, W_hh, b_h,
        W_o, b_o
    ]

    #output_var = tf.cast(output_var, tf.int32)
    target_ph_oh = tf.one_hot(target_ph, depth=generator.nb_classes)
    print('Output, Target shapes: ',
          output_var.get_shape().as_list(),
          target_ph_oh.get_shape().as_list())
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
        logits=output_var, labels=target_ph_oh),
                          name="cost")
    opt = tf.train.AdamOptimizer(learning_rate=1e-3)
    train_step = opt.minimize(cost, var_list=params)

    #train_step = tf.train.AdamOptimizer(1e-3).minimize(cost)
    accuracies = accuracy_instance(tf.argmax(output_var, axis=2),
                                   target_ph,
                                   batch_size=generator.batch_size)
    sum_out = tf.reduce_sum(tf.reshape(
        tf.one_hot(tf.argmax(output_var, axis=2), depth=generator.nb_classes),
        (-1, generator.nb_classes)),
                            axis=0)

    print('Done')

    tf.summary.scalar('cost', cost)
    for i in range(generator.nb_samples_per_class):
        tf.summary.scalar('accuracy-' + str(i), accuracies[i])

    merged = tf.summary.merge_all()
    #writer = tf.summary.FileWriter('/tmp/tensorflow', graph=tf.get_default_graph())
    train_writer = tf.summary.FileWriter('./tmp/tensorflow/', sess.graph)

    t0 = time.time()
    all_scores, scores, accs = [], [], np.zeros(generator.nb_samples_per_class)

    saver = tf.train.Saver()

    if (load):
        ckpt = tf.train.get_checkpoint_state('./saved/')
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print("No Checkpoint found, setting load to false")
            load = False

    if (not load):
        sess.run(tf.global_variables_initializer())

    print('Training the model')

    try:
        for i, (batch_input, batch_output) in generator:

            if (batch_input.shape[0] != batch_size):
                break  # stop on an incomplete final batch

            feed_dict = {input_ph: batch_input, target_ph: batch_output}

            #print(batch_input.shape, batch_output.shape)
            train_step.run(feed_dict)
            score = cost.eval(feed_dict)
            temp = sum_out.eval(feed_dict)
            summary = merged.eval(feed_dict)
            train_writer.add_summary(summary, i)
            print(i, ' ', temp)
            all_scores.append(score)
            scores.append(score)

            test_gen = CifarGenerator(
                data_folder='./data/cifar-10',
                batch_size=batch_size,
                nb_classes=nb_class,
                _class=class_to_plot,
                nb_samples_per_class=nb_samples_per_class,
                max_iter=100)

            for j, (test_input, test_output) in test_gen:
                test_dict = {input_ph: test_input, target_ph: test_output}
                acc = accuracies.eval(test_dict)
                accs += acc

            accs /= 100.0

            if i > 0 and not (i % 100):
                print("Test Accuracy (class 0) : {}".format(accs / 100.0))
                print('Episode %05d: %.6f' % (i, np.mean(score)))
                scores, accs = [], np.zeros(generator.nb_samples_per_class)
                saver.save(sess, './saved/model.ckpt', global_step=i + 1)

    except KeyboardInterrupt:
        print(time.time() - t0)
        pass
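
Both OmniglotGenerator and CifarGenerator are iterated as (i, (batch_input, batch_output)) pairs, with batch_input of shape (batch_size, nb_classes * nb_samples_per_class, input_size) and integer labels of matching leading shape. A hedged sketch of a single-episode sampler over an in-memory lookup (the images_by_class dict is hypothetical, not the project's generator):

import numpy as np

def sample_episode(images_by_class, nb_classes=5, nb_samples_per_class=10):
    # images_by_class: dict {class_id: array of flattened images, shape (n, input_size)}
    chosen = np.random.choice(list(images_by_class.keys()), nb_classes, replace=False)
    inputs, labels = [], []
    for episode_label, cls in enumerate(chosen):
        idx = np.random.choice(len(images_by_class[cls]), nb_samples_per_class, replace=False)
        inputs.append(images_by_class[cls][idx])
        labels.extend([episode_label] * nb_samples_per_class)
    inputs = np.concatenate(inputs, axis=0)
    labels = np.asarray(labels)
    order = np.random.permutation(len(labels))   # interleave the classes within the episode
    return inputs[order], labels[order]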