예제 #1
0
def create_ntm(config, sess, **ntm_args):
    """Build an ``NTMCell`` and the ``NTM`` model that wraps it.

    With ``config.rand_hyper`` set, the controller size, memory size and
    optimiser settings come from a hyper-parameter dict. The dict is loaded
    via ``load_hyperparamters`` when ``config.is_test`` is true; otherwise it
    is produced by ``generate_hyperparams``. Without ``rand_hyper``, the
    controller is configured straight from ``config`` and ``NTM`` receives no
    explicit optimiser settings.

    Args:
        config: run configuration. Attributes read here are ``rand_hyper``,
            ``is_test``, ``input_dim``, ``output_dim``, ``write_head_size``,
            ``read_head_size``, ``is_LSTM_mode``, ``task``, ``min_size``,
            ``max_size`` and ``plan_length``. When ``rand_hyper`` is false,
            ``controller_layer_size`` and ``controller_dim`` are also read.
        sess: session object handed through to ``NTM``.
        **ntm_args: extra keyword arguments forwarded to ``NTM``. A
            ``scope`` entry, if present, replaces the default
            ``'NTM-<task>'`` scope name.

    Returns:
        ``(cell, ntm)``.
    """
    hyper_params = None
    if config.rand_hyper:
        if config.is_test:
            hyper_params = load_hyperparamters(config)
        else:
            hyper_params = generate_hyperparams(config)
        print(" [*] Hyperparameters: {}".format(hyper_params))
        cell = NTMCell(input_dim=config.input_dim,
                       output_dim=config.output_dim,
                       controller_layer_size=hyper_params["c_layer"],
                       controller_dim=hyper_params["c_dim"],
                       mem_size=hyper_params["mem_size"],
                       write_head_size=config.write_head_size,
                       read_head_size=config.read_head_size,
                       is_LSTM_mode=config.is_LSTM_mode)
    else:
        # mem_size is not passed on this path; NTMCell decides it.
        cell = NTMCell(input_dim=config.input_dim,
                       output_dim=config.output_dim,
                       controller_layer_size=config.controller_layer_size,
                       controller_dim=config.controller_dim,
                       write_head_size=config.write_head_size,
                       read_head_size=config.read_head_size,
                       is_LSTM_mode=config.is_LSTM_mode)

    scope = ntm_args.pop('scope', 'NTM-%s' % config.task)
    min_length, max_length = _episode_length_bounds(config)

    if hyper_params is None:
        ntm = NTM(cell, sess, min_length, max_length,
                  config.min_size, config.max_size,
                  scope=scope, **ntm_args)
    else:
        # Explicit keywords are listed before **ntm_args because Python 2
        # rejects keyword arguments that follow **kwargs in a call. Passing
        # one of these names in ntm_args too still raises TypeError.
        ntm = NTM(cell, sess, min_length, max_length,
                  config.min_size, config.max_size,
                  scope=scope,
                  lr=hyper_params["lr"],
                  momentum=hyper_params["momentum"],
                  decay=hyper_params["decay"],
                  beta=hyper_params["l2"],
                  **ntm_args)
    return cell, ntm


def _episode_length_bounds(config):
    """Return ``(min_length, max_length)`` of one episode.

    An episode is laid out as description + query + plan + answer. The
    bounds are derived from ``config.min_size``, ``config.max_size`` and
    ``config.plan_length``.
    """
    # Shortest: (min_size - 1) description steps, 1 query step, the plan,
    # and (min_size - 1) answer steps.
    min_length = ((config.min_size - 1) + 1 + config.plan_length +
                  (config.min_size - 1))
    # Longest: max_size * (max_size - 1) / 2 description steps, 1 query
    # step, the plan, and (max_size - 1) answer steps. int() keeps the
    # result integral under Python 3 true division.
    max_length = int(((config.max_size * (config.max_size - 1) / 2) + 1 +
                      config.plan_length + (config.max_size - 1)))
    return min_length, max_length
예제 #2
0
def inference(images_t, last_labels_t):
    """Build the recurrent classifier graph and return per-step logits.

    Args:
        images_t: image tensor whose static shape unpacks as
            ``(batch, time_steps, height, width)``.
        last_labels_t: label tensor whose static shape unpacks as
            ``(batch, time_steps, num_labels)``. It is fed to the RNN
            alongside the flattened images.

    Returns:
        Logits tensor reshaped to ``(-1, time_steps, num_labels)``.

    Raises:
        ValueError: if the module-level ``CELL_TYPE`` is neither ``'lstm'``
            nor ``'ntm'``.
    """
    _, time_steps, height, width = images_t.get_shape().as_list()
    _, _, num_labels = last_labels_t.get_shape().as_list()

    with tf.variable_scope("rnn"):
        # Flatten each image so it can be concatenated with the previous
        # label along the feature axis.
        images_t = tf.reshape(images_t, (-1, time_steps, height * width))
        # Pre-1.0 TensorFlow signature: tf.concat(axis, values).
        rnn_inputs_t = tf.concat(2, (images_t, last_labels_t))
        if CELL_TYPE == 'lstm':
            rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_STATE_SIZE)
        elif CELL_TYPE == 'ntm':
            # Call form prints the same text under Python 2 and Python 3.
            print('ntm')
            rnn_cell = NTMCell(memory_slots=128,
                               memory_width=40,
                               controller_size=LSTM_STATE_SIZE)
        else:
            raise ValueError("Unsupported CELL_TYPE: %r" % (CELL_TYPE,))
        # Only the per-step outputs are needed; the final state is unused.
        rnn_output_t, _ = tf.nn.dynamic_rnn(rnn_cell,
                                            rnn_inputs_t,
                                            time_major=False,
                                            dtype=tf.float32,
                                            swap_memory=False)

    with tf.variable_scope("fcout"):
        # One linear read-out shared by every time step.
        rnn_output_size = rnn_output_t.get_shape().as_list()[-1]
        W_t = tf.get_variable(
            "W", (rnn_output_size, num_labels),
            initializer=tf.random_normal_initializer(stddev=0.1))
        b_t = tf.get_variable("b", (num_labels,),
                              initializer=tf.constant_initializer(0.0))
        logits_t = tf.matmul(tf.reshape(rnn_output_t,
                                        (-1, rnn_output_size)), W_t) + b_t
        logits_t = tf.reshape(logits_t, (-1, time_steps, num_labels))

    return logits_t
def inference(images_t, last_labels_t):
    """Build the recurrent classifier graph and return per-step logits.

    Args:
        images_t: input tensor whose static shape unpacks as
            ``(batch, time_steps, width)``.
        last_labels_t: label tensor whose static shape unpacks as
            ``(batch, time_steps, num_labels)``. It is concatenated with
            ``images_t`` along the last axis to form the RNN input.

    Returns:
        Logits tensor reshaped to ``(-1, time_steps, num_labels)``.

    Raises:
        ValueError: if the module-level ``CELL_TYPE`` is neither ``'lstm'``
            nor ``'ntm'``.
    """
    _, time_steps, width = images_t.get_shape().as_list()
    _, _, num_labels = last_labels_t.get_shape().as_list()
    with tf.variable_scope("rnn"):
        images_t = tf.reshape(images_t, (-1, time_steps, width))
        # TensorFlow >= 1.0 signature: tf.concat(values, axis).
        rnn_inputs_t = tf.concat((images_t, last_labels_t), 2)
        if CELL_TYPE == 'lstm':
            rnn_cell = tf.contrib.rnn.LSTMCell(LSTM_STATE_SIZE,
                                               activation=tf.nn.tanh)
        elif CELL_TYPE == 'ntm':
            # Call form prints the same text under Python 2 and Python 3.
            print('ntm')
            rnn_cell = NTMCell(memory_slots=128,
                               memory_width=40,
                               controller_size=LSTM_STATE_SIZE)
        else:
            raise ValueError("Unsupported CELL_TYPE: %r" % (CELL_TYPE,))
        # dynamic_rnn unrolls the cell over the time axis; the final state
        # is not needed here.
        rnn_output_t, _ = tf.nn.dynamic_rnn(rnn_cell,
                                            rnn_inputs_t,
                                            time_major=False,
                                            dtype=tf.float32,
                                            swap_memory=False)

    # The read-out variables are created in the caller's current variable
    # scope (no "fcout" sub-scope here), so they are named plain "W" and
    # "b" relative to it.
    rnn_output_size = rnn_output_t.get_shape().as_list()[-1]
    W_t = tf.get_variable("W", (rnn_output_size, num_labels),
                          initializer=tf.random_normal_initializer(stddev=0.1))
    b_t = tf.get_variable("b", (num_labels,),
                          initializer=tf.constant_initializer(0.0))
    logits_t = tf.matmul(tf.reshape(rnn_output_t,
                                    (-1, rnn_output_size)), W_t) + b_t
    logits_t = tf.reshape(logits_t, (-1, time_steps, num_labels))

    return logits_t
  def testOps(self):
    '''
    Compare the TF addressing operations (convolution, gating, etc.)
    against the NumPy reference implementation, using a single batch
    element and a single time slice.
    '''

    mem_size = (self.N, self.M)
    # State layout used here: memory rows first, then the previous read
    # weighting, then the previous write weighting.
    memory_rows = self.initial_state[0:-2]
    prev_read_weights = self.initial_state[-2]
    prev_write_weights = self.initial_state[-1]
    tf_memory = tf.stack(memory_rows, axis=1)
    np_memory = np.stack(memory_rows, axis=1)

    # NumPy side: the controller output of batch item 0 at time step 0
    # yields the read/write head values for that single step.
    np_read_head, np_write_head = head_pieces(self.controller_output[0, 0, :],
                                              mem_size, self.S)
    first_memory = np_memory[0, :, :]
    # The last two write-head pieces are dropped before addressing; the same
    # slice is applied on the TF side below.
    expected_read = generate_address(np_read_head,
                                     prev_read_weights[0, :],
                                     first_memory,
                                     self.N,
                                     self.S)
    expected_write = generate_address(np_write_head[0:-2],
                                      prev_write_weights[0, :],
                                      first_memory,
                                      self.N,
                                      self.S)

    with self.test_session() as session:
      # NTMCell.head_pieces takes one time slice of the whole batch and
      # returns (write, read) head values for every batch item.
      tf_write_head, tf_read_head = NTMCell.head_pieces(
          self.controller_output[:, 0, :], mem_size, self.S)
      read_op = address_regression(tf_read_head,
                                   prev_read_weights,
                                   tf_memory,
                                   self.N,
                                   self.S)
      write_op = address_regression(tf_write_head[0:-2],
                                    prev_write_weights,
                                    tf_memory,
                                    self.N,
                                    self.S)

      actual_write = session.run(write_op)
      actual_read = session.run(read_op)

      self.assertEqual(len(actual_read), len(expected_read))
      self.assertEqual(len(actual_write), len(expected_write))

      # TF results carry a batch axis; only batch item 0 is compared.
      for idx, expected in enumerate(expected_read):
        self.assertArrayNear(actual_read[idx][0], expected, err=1e-8)
        self.assertArrayNear(actual_write[idx][0], expected_write[idx],
                             err=1e-8)
예제 #5
0
def predict_train(config, sess):
    """Train an NTM on the prediction task.

    Each epoch samples a length in ``[config.min_length, config.max_length]``
    and scales it by ``seq_scale``. It then generates an (input, target)
    pair with ``generate_predict_sequence`` and runs one optimisation step
    through ``sess``, the session connected to the TensorFlow runtime.

    Args:
        config: run configuration. Attributes read here are
            ``checkpoint_dir``, ``input_dim``, ``output_dim``,
            ``controller_layer_size``, ``write_head_size``,
            ``read_head_size``, ``min_length``, ``max_length`` and ``epoch``.
        sess: TensorFlow session.

    Returns:
        ``(cell, ntm)``.

    Raises:
        Exception: if ``config.checkpoint_dir`` is not an existing directory.
    """
    if not os.path.isdir(config.checkpoint_dir):
        raise Exception(" [!] Directory %s not found" % config.checkpoint_dir)

    # Sampled lengths are multiplied by this factor. The NTM is unrolled up
    # to the same scaled maximum so that every sampled length has an entry
    # in ntm.optims.
    seq_scale = 4

    # One-hot delimiter inputs marking the start (channel 0) and the end
    # (channel 1) of a sequence.
    start_symbol = np.zeros([config.input_dim], dtype=np.float32)
    start_symbol[0] = 1
    end_symbol = np.zeros([config.input_dim], dtype=np.float32)
    end_symbol[1] = 1

    # The NTM cell (memory, heads and controller) and the unrolled model.
    cell = NTMCell(input_dim=config.input_dim,
                   output_dim=config.output_dim,
                   controller_layer_size=config.controller_layer_size,
                   write_head_size=config.write_head_size,
                   read_head_size=config.read_head_size)
    ntm = NTM(cell, sess, config.min_length, config.max_length * seq_scale)

    print(" [*] Initialize all variables")
    tf.initialize_all_variables().run()
    print(" [*] Initialization finished")

    start_time = time.time()
    for idx in xrange(config.epoch):
        # Random length, scaled to the range the NTM was built for.
        seq_length = randint(config.min_length, config.max_length) * seq_scale
        # input_dim - 2: presumably leaves channels 0 and 1 for the
        # delimiter flags above.
        inc_seq, comp_seq = generate_predict_sequence(seq_length, config.input_dim - 2)

        # Bind each step's input/target vector to the matching entry of
        # ntm.inputs / ntm.true_outputs (presumably per-step placeholders).
        feed_dict = {input_: vec for vec, input_ in zip(inc_seq, ntm.inputs)}
        feed_dict.update(
            {true_output: vec for vec, true_output in zip(comp_seq, ntm.true_outputs)}
        )
        feed_dict.update({
            ntm.start_symbol: start_symbol,
            ntm.end_symbol: end_symbol
        })

        # One optimisation step for this length. The loss and global step
        # are fetched too; the step tags the checkpoints.
        _, cost, step = sess.run([ntm.optims[seq_length],
                                  ntm.get_loss(seq_length),
                                  ntm.global_step], feed_dict=feed_dict)

        # Checkpoint every 100 epochs under the 'copy' task name.
        if idx % 100 == 0:
            ntm.save(config.checkpoint_dir, 'copy', step)

        # print_interval is not defined in this function; it is expected at
        # module level.
        if idx % print_interval == 0:
            print("[%5d] %2d: %.2f (%.1fs)"
                  % (idx, seq_length, cost, time.time() - start_time))

    print("Training Copy task finished")
    return cell, ntm
예제 #6
0
def create_ntm(FLAGS, sess, **ntm_args):
    """Build an ``NTMCell`` and the ``NTM`` wrapping it from ``FLAGS``.

    Args:
        FLAGS: configuration. Attributes read are ``input_dim``,
            ``output_dim``, ``controller_layer_size``, ``write_head_size``,
            ``read_head_size``, ``min_length``, ``max_length``,
            ``test_max_length`` and ``task``.
        sess: session handed to ``NTM``.
        **ntm_args: extra keyword arguments for ``NTM``. A ``scope`` entry
            overrides the default ``'NTM-<task>'`` name instead of clashing
            with it.

    Returns:
        ``(cell, ntm)``.
    """
    cell = NTMCell(
        input_dim=FLAGS.input_dim,
        output_dim=FLAGS.output_dim,
        controller_layer_size=FLAGS.controller_layer_size,
        write_head_size=FLAGS.write_head_size,
        read_head_size=FLAGS.read_head_size)
    # Pop 'scope' so a caller-supplied value does not trigger
    # "got multiple values for keyword argument 'scope'".
    scope = ntm_args.pop('scope', 'NTM-%s' % FLAGS.task)
    ntm = NTM(
        cell, sess, FLAGS.min_length, FLAGS.max_length,
        test_max_length=FLAGS.test_max_length, scope=scope, **ntm_args)
    return cell, ntm
  def __init__(self, mem_size, session, num_heads=1, shift_range=3,
               name="NTM"):
    '''
    Build an NTM graph with no controller. The NTMCell is driven directly
    by a placeholder standing in for controller output, so its operations
    can be exercised on fake input.
    '''
    self.mem_size = mem_size
    self.shift_range = shift_range
    self.sess = session
    self.num_heads = num_heads

    _, bits_per_row = mem_size
    float_type = tf.float32
    # Width of one controller output vector: 4*M + 2*S + 6.
    controller_width = 4 * bits_per_row + 2 * shift_range + 6

    with tf.variable_scope(name):
      cell = NTMCell(mem_size=mem_size, num_shifts=shift_range)
      self.ntm_cell = cell

      # [batch_size, sequence_length, 4*M + 2*S + 6]
      controller_input = tf.placeholder(dtype=float_type,
                                        shape=(None, None, controller_width))
      self.feed_controller_input = controller_input

      # One [batch_size, size] placeholder per component of the cell state.
      self.feed_initial_state = tuple(
          tf.placeholder(dtype=float_type, shape=(None, size))
          for size in cell.state_size)

      reads, last_state = tf.nn.dynamic_rnn(
          cell=cell,
          initial_state=self.feed_initial_state,
          inputs=controller_input,
          dtype=float_type)
      self.ntm_reads = reads
      self.ntm_last_state = last_state

      self.write_head, self.read_head = cell.head_pieces(
          controller_input,
          mem_size=mem_size,
          num_shifts=shift_range,
          axis=2)
예제 #8
0
def copy_train(config):
    """Train an NTM on the copy task using the session in ``config.sess``.

    Each epoch draws a length in ``[config.min_length, config.max_length]``,
    generates a random sequence and trains the NTM to reproduce it. A
    checkpoint is written every 100 epochs.

    Returns:
        ``(cell, ntm)``.

    Raises:
        Exception: if ``config.checkpoint_dir`` is not an existing directory.
    """
    session = config.sess
    checkpoint_dir = config.checkpoint_dir

    if not os.path.isdir(checkpoint_dir):
        raise Exception(" [!] Directory %s not found" % checkpoint_dir)

    # One-hot delimiters: channel 0 flags the start, channel 1 the end.
    start_marker = np.zeros([config.input_dim], dtype=np.float32)
    start_marker[0] = 1
    end_marker = np.zeros([config.input_dim], dtype=np.float32)
    end_marker[1] = 1

    cell = NTMCell(input_dim=config.input_dim,
                   output_dim=config.output_dim,
                   controller_layer_size=config.controller_layer_size,
                   write_head_size=config.write_head_size,
                   read_head_size=config.read_head_size)
    ntm = NTM(cell, session, config.min_length, config.max_length)

    print(" [*] Initialize all variables")
    tf.initialize_all_variables().run()
    print(" [*] Initialization finished")

    start_time = time.time()
    for epoch_idx in xrange(config.epoch):
        seq_length = randint(config.min_length, config.max_length)
        sequence = generate_copy_sequence(seq_length, config.input_dim - 2)

        # The copy target is the input sequence itself.
        feed_dict = dict(zip(ntm.inputs, sequence))
        feed_dict.update(zip(ntm.true_outputs, sequence))
        feed_dict.update({
            ntm.start_symbol: start_marker,
            ntm.end_symbol: end_marker,
        })

        fetches = [ntm.optims[seq_length],
                   ntm.get_loss(seq_length),
                   ntm.global_step]
        _, cost, step = session.run(fetches, feed_dict=feed_dict)

        if epoch_idx % 100 == 0:
            ntm.save(checkpoint_dir, 'copy', step)

        # print_interval is expected at module level.
        if epoch_idx % print_interval == 0:
            elapsed = time.time() - start_time
            print("[%5d] %2d: %.2f (%.1fs)"
                  % (epoch_idx, seq_length, cost, elapsed))

    print("Training Copy task finished")
    return cell, ntm
예제 #9
0
    def _initialize(self, observation_t):
        """Create the recurrent cell, its zero state and the first Q-values.

        Args:
            observation_t: ``(image_t, last_label_t, _)``. The static shapes
                unpack as ``(batch, height, width)`` and
                ``(batch, num_labels)``.

        Raises:
            ValueError: if the module-level ``CELL_TYPE`` is neither
                ``'lstm'`` nor ``'ntm'``.
        """
        image_t, last_label_t, _ = observation_t
        # Dynamic batch size (tf.unpack is the pre-1.0 name of tf.unstack).
        self.batch_size_t = tf.unpack(tf.shape(image_t))[0]
        _, self.image_height, self.image_width = image_t.get_shape().as_list()
        _, self.num_actions = last_label_t.get_shape().as_list()
        self.num_actions += 1  # for "pay for label"

        with tf.variable_scope("rnn"):
            if CELL_TYPE == 'lstm':
                self.rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_STATE_SIZE)
            elif CELL_TYPE == 'ntm':
                # Call form prints the same text under Python 2 and Python 3.
                print('ntm')
                self.rnn_cell = NTMCell(memory_slots=128,
                                        memory_width=40,
                                        controller_size=LSTM_STATE_SIZE)
            else:
                # Otherwise self.rnn_cell would not be a cell and zero_state()
                # below would fail with an unhelpful AttributeError.
                raise ValueError("Unsupported CELL_TYPE: %r" % (CELL_TYPE,))
            self.rnn_state_t = self.rnn_cell.zero_state(
                self.batch_size_t, tf.float32)

        self.q_t = self._Q(observation_t)
        self.a_t = None
        self.initialized = True
예제 #10
0
def create_ntm(config, sess, **ntm_args):
    """Return ``(cell, ntm)``: an ``NTMCell`` configured from ``config``
    and the ``NTM`` built around it.

    ``ntm_args`` are forwarded to ``NTM``. A ``scope`` entry among them
    replaces the default ``'NTM-<task>'`` scope name.
    """
    cell_options = dict(
        input_dim=config.input_dim,
        output_dim=config.output_dim,
        controller_layer_size=config.controller_layer_size,
        controller_dim=config.controller_dim,
        write_head_size=config.write_head_size,
        read_head_size=config.read_head_size,
    )
    cell = NTMCell(**cell_options)

    default_scope = 'NTM-%s' % config.task
    scope = ntm_args.pop('scope', default_scope)
    ntm = NTM(cell, sess, config.min_length, config.max_length,
              test_max_length=config.test_max_length,
              scope=scope,
              **ntm_args)
    return cell, ntm
  def setUp(self):
    '''
    Draw random problem dimensions, build a biased initial NTM state and
    random controller output, then run the NumPy reference forward pass
    on batch item 0.
    '''

    # Parameter ranges.
    min_addresses = 5
    max_addresses = 10
    min_bits_per_address = 6
    max_bits_per_address = 12
    max_batch_size = 32
    min_batch_size = 10

    # np.random.randint excludes `high`, hence the "+ 1" on the inclusive
    # upper bounds below.
    self.N = np.random.randint(low=min_addresses, high=max_addresses + 1)
    self.M = np.random.randint(low=min_bits_per_address,
                               high=max_bits_per_address + 1)
    self.mem_size = (self.N, self.M)

    # Shift range S is drawn from [3, N - 1].
    min_shifts = 3
    max_shifts = self.N - 1

    self.S = np.random.randint(low=min_shifts, high=max_shifts + 1)
    self.shift_range = self.S
    # Inclusive upper bound, consistent with N, M and S above.
    self.batch_size = np.random.randint(low=min_batch_size,
                                        high=max_batch_size + 1)
    self.sequence_length = np.random.randint(low=3, high=max_addresses)
    # To reproduce a failure, pin N, M, S, batch_size and sequence_length
    # to fixed values here instead of drawing them.

    self.initial_state = NTMCell(self.mem_size,
                                 self.shift_range).bias_state(self.batch_size)

    # Controller output drawn uniformly from [-5, 5); the last axis is
    # 4*M + 2*S + 6 wide.
    self.controller_output = 10*np.random.rand(self.batch_size,
                                               self.sequence_length,
                                               4*self.M + 2*self.S + 6) - 5

    # NumPy reference output for a single sequence: only batch item 0 is
    # processed to completion.
    seq_initial_state = tuple([x[0, :] for x in self.initial_state])

    self.np_read_addresses, self.np_write_addresses, self.np_reads = \
      numpy_forward_pass(self.N,
                         self.M,
                         self.S,
                         seq_initial_state,
                         self.controller_output[0, :, :])
  def testHeadPieces(self):
    '''
    Check that the pieces (key, gate, shift, etc.) that
    NTMCell.head_pieces slices out of the controller output match the
    NumPy reference implementation.
    '''
    mem_size = (self.N, self.M)
    np_read_head, np_write_head = head_pieces(self.controller_output,
                                              mem_size,
                                              self.S)

    def assert_pieces_near(expected_pieces, actual_pieces):
      # assertArrayNear takes flat sequences of floats, so walk the two
      # leading axes and compare one trailing vector at a time.
      for piece_idx, expected in enumerate(expected_pieces):
        actual = actual_pieces[piece_idx]
        for row in range(expected.shape[0]):
          for col in range(expected.shape[1]):
            self.assertArrayNear(expected[row, col, :],
                                 actual[row, col, :],
                                 err=1e-8)

    with self.test_session() as session:
      write_op, read_op = NTMCell.head_pieces(self.controller_output,
                                              mem_size,
                                              self.S,
                                              axis=2)
      tf_write_head, tf_read_head = session.run([write_op, read_op])

      # Both implementations must yield the same number of pieces.
      self.assertEqual(len(tf_write_head), len(np_write_head))
      self.assertEqual(len(tf_read_head), len(np_read_head))

      # Read pieces first, then write pieces.
      assert_pieces_near(np_read_head, tf_read_head)
      assert_pieces_near(np_write_head, tf_write_head)
class NTM(object):
    """Neural Turing Machine: an LSTM controller driving an ``NTMCell``.

    The whole graph is built in ``__init__`` inside
    ``tf.variable_scope(name)``: placeholders, controller, NTM cell, sigmoid
    read-out and a gradient-clipped RMSProp training op.
    ``train_batch`` runs one optimisation step. ``run_once`` steps the model
    one time step per ``sess.run`` call for inspection.
    """
    def __init__(self,
                 mem_size,
                 input_size,
                 output_size,
                 session,
                 num_heads=1,
                 shift_range=3,
                 name="NTM"):
        """Build the graph.

        Args:
            mem_size: ``(N, M)``. N is used as the count of leading memory
                entries in the cell state; M is the width of each read
                vector.
            input_size: feature width of ``feed_in``.
            output_size: feature width of ``feed_out`` and of the
                predictions.
            session: session used by ``train_batch`` and ``run_once``.
            num_heads: currently ignored (``self.num_heads`` is fixed to 1).
            shift_range: S, the shift range given to ``NTMCell``.
            name: variable scope the graph is built in.
        """

        # NOTE(review): the num_heads argument is ignored. The controller's
        # reshape to head_nodes (not num_heads * head_nodes) only works for
        # a single head.
        self.num_heads = 1
        self.sess = session
        self.S = shift_range
        self.N, self.M = mem_size
        self.in_size = input_size
        self.out_size = output_size

        # NOTE(review): unused; controller() relies on its own
        # num_units=100 default instead.
        num_lstm_units = 100
        self.dt = tf.float32
        # parallel_iterations for both dynamic_rnn calls.
        self.pi = 64

        pi = self.pi
        dt = self.dt
        N = self.N
        M = self.M
        S = self.S
        num_heads = self.num_heads

        with tf.variable_scope(name):
            # Inputs and targets: [batch, time, features].
            self.feed_in = tf.placeholder(dtype=dt,
                                          shape=(None, None, input_size))

            self.feed_out = tf.placeholder(dtype=dt,
                                           shape=(None, None, output_size))

            self.feed_learning_rate = tf.placeholder(dtype=dt, shape=())

            batch_size = tf.shape(self.feed_in)[0]
            seq_length = tf.shape(self.feed_in)[1]

            # Raw head parameters per time step from the LSTM controller.
            head_raw = self.controller(self.feed_in, batch_size, seq_length)

            self.ntm_cell = NTMCell(mem_size=(N, M), shift_range=S)

            # Dict-style head pieces are exposed so run_once can fetch the
            # 'g' and 'shift' entries.
            self.write_head, self.read_head = NTMCell.head_pieces(
                head_raw, mem_size=(N, M), shift_range=S, axis=2, style='dict')

            # One [batch, size] placeholder per component of the cell state,
            # fed from ntm_cell.bias_state(...) in train_batch / run_once.
            self.ntm_init_state = tuple(
                [tf.placeholder(dtype=dt, shape=(None, s)) \
                for s in self.ntm_cell.state_size])

            self.ntm_reads, self.ntm_last_state = tf.nn.dynamic_rnn(
                cell=self.ntm_cell,
                initial_state=self.ntm_init_state,
                inputs=head_raw,
                dtype=dt,
                parallel_iterations=pi)

            # Started conversion to the multi-head output here, still have
            # lots to do.
            # State layout assumed: N memory entries, then num_heads read
            # weightings, then num_heads write weightings. In the single-head
            # case run_once reads state[-2] / state[-1] as read / write.
            self.w_read = self.ntm_last_state[N:N + num_heads]
            self.w_write = self.ntm_last_state[N + num_heads:N + 2 * num_heads]

            # NOTE(review): iterating self.ntm_reads only works if dynamic_rnn
            # returned a sequence of tensors (i.e. NTMCell's output is a
            # tuple); a single graph-mode Tensor is not iterable. The list is
            # then handed to tf.matmul below, which converts it to one
            # stacked tensor. TODO confirm the ranks match L ([M, output_size]).
            ntm_reads_flat = [tf.reshape(r, [-1, M]) for r in self.ntm_reads]

            # Linear read-out. These are unnamed tf.Variables, so their
            # auto-generated names depend on creation order.
            L = tf.Variable(tf.random_normal([M, output_size]))
            b_L = tf.Variable(tf.random_normal([
                output_size,
            ]))

            logits_flat = tf.matmul(ntm_reads_flat, L) + b_L
            targets_flat = tf.reshape(self.feed_out, [-1, output_size])

            # Mean element-wise sigmoid cross-entropy.
            self.error = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=targets_flat,
                                                        logits=logits_flat))

            self.predictions = tf.sigmoid(
                tf.reshape(logits_flat, [batch_size, seq_length, output_size]))

            optimizer = tf.train.RMSPropOptimizer(
                learning_rate=self.feed_learning_rate, momentum=0.9)

            # Clip every gradient element to [-10, 10] before applying.
            # NOTE(review): compute_gradients may return None gradients for
            # variables not reachable from self.error, and clip_by_value
            # cannot take None.
            grads_and_vars = optimizer.compute_gradients(self.error)
            capped_grads = [(tf.clip_by_value(grad, -10., 10.), var) \
                for grad, var in grads_and_vars]

            self.train_op = optimizer.apply_gradients(capped_grads)

    def controller(self, inputs, batch_size, seq_length, num_units=100):
        """LSTM controller mapping inputs to raw NTM head parameters.

        Args:
            inputs: [batch, time, input_size] tensor (``self.feed_in``).
            batch_size: dynamic batch size, used to reshape the output.
            seq_length: dynamic number of time steps, used to reshape the
                output.
            num_units: LSTM size.

        Returns:
            Tensor reshaped to ``[batch_size, seq_length, 4*M + 2*S + 6]``.

        Also sets ``self.lstm_cell``, the ``self.lstm_init_state``
        placeholders and ``self.lstm_last_state``.
        """
        N = self.N
        M = self.M
        S = self.S
        pi = self.pi
        dt = self.dt
        num_heads = self.num_heads

        self.lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_units=num_units,
                                                      forget_bias=1.0)

        # One placeholder per LSTM state component, fed with zeros by
        # train_batch / run_once.
        self.lstm_init_state = tuple([
            tf.placeholder(dtype=dt, shape=(None, s))
            for s in self.lstm_cell.state_size
        ])

        lstm_init_state = tf.contrib.rnn.LSTMStateTuple(
            self.lstm_init_state[0], self.lstm_init_state[1])

        lstm_out_raw, self.lstm_last_state = tf.nn.dynamic_rnn(
            cell=self.lstm_cell,
            initial_state=lstm_init_state,
            inputs=inputs,
            dtype=dt,
            parallel_iterations=pi)

        # Extra tanh squashing of the LSTM output, then flatten time into
        # the batch axis for the dense projection.
        lstm_out = tf.tanh(lstm_out_raw)
        lstm_out_flat = tf.reshape(lstm_out, [-1, num_units])

        # Parameters per head: 4*M + 2*S + 6.
        head_nodes = 4 * M + 2 * S + 6

        head_W = tf.Variable(tf.random_normal(
            [num_units, num_heads * head_nodes]),
                             name='head_W')
        head_b_W = tf.Variable(tf.random_normal([
            num_heads * head_nodes,
        ]),
                               name='head_b_W')

        # The reshape uses head_nodes rather than num_heads * head_nodes, so
        # this is only consistent for a single head.
        head_raw_flat = tf.matmul(lstm_out_flat, head_W) + head_b_W
        head_raw = tf.reshape(head_raw_flat,
                              [batch_size, seq_length, head_nodes])

        return head_raw

    def train_batch(self, batch_x, batch_y, learning_rate=1e-4):
        """Run one training step on a batch and return the error.

        NTM and LSTM states start fresh for every batch: the NTM from
        ``ntm_cell.bias_state`` and the LSTM from zeros.

        Args:
            batch_x: array fed to ``feed_in`` ([batch, time, input_size]).
            batch_y: array fed to ``feed_out`` ([batch, time, output_size]).
            learning_rate: value fed to ``feed_learning_rate``.

        Returns:
            ``self.error`` evaluated in the same run as the training op.
        """

        lr = learning_rate
        batch_size = batch_x.shape[0]
        ntm_init_state = self.ntm_cell.bias_state(batch_size)
        lstm_init_state = tuple(
            [np.zeros((batch_size, s)) for s in self.lstm_cell.state_size])

        fetches = [self.error, self.train_op]
        feeds = {
            self.feed_in: batch_x,
            self.feed_out: batch_y,
            self.feed_learning_rate: lr
        }

        for i in range(len(ntm_init_state)):
            feeds[self.ntm_init_state[i]] = ntm_init_state[i]

        for i in range(len(lstm_init_state)):
            feeds[self.lstm_init_state[i]] = lstm_init_state[i]

        error, _ = self.sess.run(fetches, feeds)

        return error

    def run_once(self, test_x):
        """Run the model over ``test_x`` one time step per ``sess.run`` call.

        NTM and LSTM states are carried from each step into the next. Only
        batch item 0 is recorded.

        Args:
            test_x: [batch, time, input_size] array.

        Returns:
            ``(output_b, w_read_b, w_write_b, g_read_b, g_write_b, s_read_b,
            s_write_b)``: the predictions, the read/write weightings, and
            the read/write 'g' and 'shift' head pieces for each time step.
        """
        batch_size = test_x.shape[0]
        num_seq = test_x.shape[1]
        # One [batch, 1, input_size] slice per time step.
        sequences = np.split(test_x, num_seq, axis=1)
        ntm_init_state = self.ntm_cell.bias_state(batch_size)
        lstm_init_state = tuple(
            [np.zeros((batch_size, s)) for s in self.lstm_cell.state_size])

        outputs = []
        w_read = []
        w_write = []
        g_read = []
        g_write = []
        s_read = []
        s_write = []

        for seq in sequences:
            fetches = [
                self.predictions, self.ntm_last_state, self.lstm_last_state,
                self.read_head, self.write_head
            ]
            feeds = {self.feed_in: seq}

            for i in range(len(ntm_init_state)):
                feeds[self.ntm_init_state[i]] = ntm_init_state[i]

            for i in range(len(lstm_init_state)):
                feeds[self.lstm_init_state[i]] = lstm_init_state[i]

            # The fetched final states become the next step's initial states.
            output, ntm_init_state, lstm_init_state, \
                read_head, write_head = self.sess.run(fetches, feeds)

            outputs.append(output[0].copy())
            w_read.append(ntm_init_state[-2][0].copy())
            w_write.append(ntm_init_state[-1][0].copy())
            g_read.append(read_head['g'][0, 0, :].copy())
            g_write.append(write_head['g'][0, 0, :].copy())
            s_read.append(read_head['shift'][0, 0, :].copy())
            s_write.append(write_head['shift'][0, 0, :].copy())

        # NOTE(review): np.squeeze drops every size-1 axis, including time
        # when there is a single step and features when output_size == 1.
        output_b = np.squeeze(np.array(outputs))
        w_read_b = np.array(w_read)
        w_write_b = np.array(w_write)
        g_read_b = np.array(g_read)
        g_write_b = np.array(g_write)
        s_read_b = np.array(s_read)
        s_write_b = np.array(s_write)

        #print(output_b.shape)

        return output_b, w_read_b, w_write_b, g_read_b, \
            g_write_b, s_read_b, s_write_b
예제 #14
0
class Agent:
    """Q-learning agent over a recurrent (LSTM or NTM) Q-network.

    The actions are the labels plus one extra "pay for label" action. The
    graph is built lazily on the first ``choose_action`` call.
    """
    def __init__(self, epsilon_t):
        # Total probability of taking a non-greedy action (see
        # choose_action).
        self.epsilon_t = epsilon_t
        self.initialized = False
        self.num_actions = None
        self.batch_size_t = None
        self.a_t = None
        self.q_t = None
        self.image_height = self.image_width = None
        self.rnn_cell = self.rnn_state_t = None

    def choose_action(self, observation_t):
        """Return a one-hot ``[batch, num_actions]`` action tensor.

        One of four action types is sampled per batch item:
        - with probability ``1 - epsilon_t``, the greedy action
          ``argmax(q_t)``;
        - with probability ``epsilon_t / 3`` each, the oracle's true label,
          a random wrong label, or the "pay for label" action.

        Args:
            observation_t: ``(image_t, last_label_t, oracle_label_t)``.
        """
        if not self.initialized:
            self._initialize(observation_t)
        _, _, oracle_label_t = observation_t

        with tf.variable_scope("action_selection"):
            a_max_t = tf.to_int32(tf.argmax(self.q_t, 1))
            # a_const0_t, a_const1_t and a_rand_t are not used by the active
            # policy. They are kept so the constructed graph stays unchanged.
            a_const0_t = tf.zeros_like(a_max_t)
            a_const1_t = tf.ones_like(a_max_t)
            a_rand_t = tf.random_uniform([self.batch_size_t],
                                         maxval=self.num_actions,
                                         dtype=tf.int32)

            a_true_t = tf.to_int32(tf.argmax(oracle_label_t, 1))
            # Assuming oracle_label_t is one-hot, log(1 - label) is -inf at
            # the true label and 0 elsewhere. This samples uniformly among
            # the wrong labels.
            a_wrong_t = tf.to_int32(
                tf.squeeze(tf.multinomial(tf.log(1 - oracle_label_t), 1), [1]))
            # The last action index is "pay for label".
            a_question_t = tf.to_int32(
                tf.ones_like(a_max_t) * (self.num_actions - 1))

            # One-hot action type per batch item, ordered
            # [greedy, true, wrong, question].
            a_type_t = tf.to_int32(
                tf.one_hot(
                    tf.squeeze(
                        tf.multinomial([[
                            tf.log(1 - self.epsilon_t),
                            tf.log(self.epsilon_t / 3.0),
                            tf.log(self.epsilon_t / 3.0),
                            tf.log(self.epsilon_t / 3.0)
                        ]], self.batch_size_t), [0]), 4))
            # Pick the candidate matching the sampled type, then one-hot it.
            self.a_t = tf.one_hot(
                tf.reduce_sum(
                    a_type_t * tf.pack(
                        [a_max_t, a_true_t, a_wrong_t, a_question_t], axis=1),
                    1), self.num_actions)

        return self.a_t

    def learn(self, reward_t, observation_t):
        """Return the per-item squared TD error for the last chosen action.

        The target is ``reward_t + 0.5 * max(softmax(Q(observation_t)))``,
        i.e. a discount factor of 0.5. Afterwards ``q_t`` is replaced by the
        softmaxed Q-values of ``observation_t``, and the next
        ``choose_action`` uses them.

        NOTE(review): the first ``q_t`` (from ``_initialize``) is raw,
        while later ones are softmax-normalised. Confirm this is intended.
        """
        q_new_t = tf.nn.softmax(self._Q(observation_t))
        qa_t = tf.reduce_sum(self.a_t * self.q_t,
                             1)  # extract q for the action we already took
        regret_t = tf.square(
            qa_t - (reward_t + 0.5 * tf.reduce_max(q_new_t, 1)))  # q-learning
        # Discount factors used in earlier runs (0.0 == bandit, i.e. target
        # is reward_t alone):
        # mnist_..._rl_002:0.0 discount factor (bandit)
        # mnist_..._rl_003:0.0 discount factor (bandit)
        # mnist_..._rl_004:0.5 discount factor
        # mnist_..._rl_005:0.5 discount factor
        # mnist_..._rl_006:1.0 discount factor
        # mnist_..._rl_007:0.8 discount factor
        # mnist_..._rl_008:0.2 discount factor
        # omniglot_..._rl_001:0.0 discount factor (bandit)
        # omniglot_..._rl_002:0.5 discount factor
        self.q_t = q_new_t
        return regret_t

    def _initialize(self, observation_t):
        """Create the recurrent cell, its zero state and the first Q-values.

        Raises:
            ValueError: if the module-level ``CELL_TYPE`` is neither
                ``'lstm'`` nor ``'ntm'``.
        """
        image_t, last_label_t, _ = observation_t
        # Dynamic batch size (tf.unpack is the pre-1.0 name of tf.unstack).
        self.batch_size_t = tf.unpack(tf.shape(image_t))[0]
        _, self.image_height, self.image_width = image_t.get_shape().as_list()
        _, self.num_actions = last_label_t.get_shape().as_list()
        self.num_actions += 1  # for "pay for label"

        with tf.variable_scope("rnn"):
            if CELL_TYPE == 'lstm':
                self.rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_STATE_SIZE)
            elif CELL_TYPE == 'ntm':
                # Call form prints the same text under Python 2 and Python 3.
                print('ntm')
                self.rnn_cell = NTMCell(memory_slots=128,
                                        memory_width=40,
                                        controller_size=LSTM_STATE_SIZE)
            else:
                # Otherwise self.rnn_cell stays None and zero_state() below
                # fails with an unhelpful AttributeError.
                raise ValueError("Unsupported CELL_TYPE: %r" % (CELL_TYPE,))
            self.rnn_state_t = self.rnn_cell.zero_state(
                self.batch_size_t, tf.float32)

        # Runs before `initialized` is set, so _Q creates (not reuses) the
        # variables here.
        self.q_t = self._Q(observation_t)
        self.a_t = None
        self.initialized = True

    def _Q(self, observation_t):
        """Advance the RNN one step and return Q-values ``[batch, num_actions]``.

        Updates ``self.rnn_state_t`` as a side effect. Variables are created
        on the first call (from ``_initialize``) and reused on later calls.
        """
        image_t, last_label_t, _ = observation_t

        scope = tf.get_variable_scope()
        if self.initialized:
            scope.reuse_variables()

        # NOTE(review): "rnn/RNN" presumably mirrors the scope names that
        # dynamic_rnn-based code produces, so variables line up. Confirm.
        with tf.variable_scope("rnn/RNN"):
            image_t = tf.reshape(
                image_t,
                (self.batch_size_t, self.image_height * self.image_width))
            # Pre-1.0 TensorFlow signature: tf.concat(axis, values).
            rnn_input_t = tf.concat(1, (image_t, last_label_t))
            rnn_output_t, self.rnn_state_t = self.rnn_cell(
                rnn_input_t, self.rnn_state_t)

        with tf.variable_scope("fcout"):
            rnn_output_size = rnn_output_t.get_shape().as_list()[-1]
            W_t = tf.get_variable(
                "W", (rnn_output_size, self.num_actions),
                initializer=tf.random_normal_initializer(stddev=0.1))
            b_t = tf.get_variable("b", (self.num_actions),
                                  initializer=tf.constant_initializer(0.0))
            q_t = tf.matmul(rnn_output_t, W_t) + b_t

        return q_t
class NTMRegression(object):
  '''
  A class that makes regression testing on the NTMCell easier.

  The NTMCell is driven directly by a placeholder of raw head parameters
  (there is no controller network), so its operations can be exercised with
  hand-crafted input and inspected one timestep at a time via `run`.
  '''
  def __init__(self, mem_size, session, num_heads=1, shift_range=3,
               name="NTM"):
    '''
    Just sets up an NTM without the controller. So all this will do is
    apply the NTM operations to some set of fake input.

    Arguments:
      mem_size - Tuple (N, M); its second entry M sets the width of the
        head-parameter input (see head_size below).
      session - TensorFlow session used by `run`.
      num_heads - Stored as self.num_heads; not otherwise used here.
      shift_range - Number of shift values; passed to NTMCell as num_shifts.
      name - Variable scope name for the graph built here.
    '''

    self.mem_size = mem_size
    self.shift_range = shift_range
    self.sess = session
    self.num_heads = num_heads

    (_, num_bits) = self.mem_size
    dt = tf.float32

    # Width of one timestep of raw head parameters: 4*M + 2*S + 6.
    head_size = 4*num_bits + 2*self.shift_range + 6

    with tf.variable_scope(name):

      self.ntm_cell = NTMCell(mem_size=self.mem_size,
                              num_shifts=self.shift_range)

      # [batch_size, sequence_length, 4*M + 2*S + 6]
      self.feed_controller_input = \
        tf.placeholder(dtype=dt,
                       shape=(None, None, head_size))

      # ([batch_size, ntm_cell.state_size[0]], ...)
      self.feed_initial_state = \
        tuple([tf.placeholder(dtype=dt, shape=(None, s))
               for s in self.ntm_cell.state_size])

      self.ntm_reads, self.ntm_last_state = \
        tf.nn.dynamic_rnn(cell=self.ntm_cell,
                          initial_state=self.feed_initial_state,
                          inputs=self.feed_controller_input, dtype=dt)

      # head_pieces is applied to the same raw input so the per-head
      # parameter tensors can be fetched for inspection.
      self.write_head, self.read_head = \
        self.ntm_cell.head_pieces(self.feed_controller_input,
                                  mem_size=self.mem_size,
                                  num_shifts=self.shift_range, axis=2)


  def run(self, controller_input, initial_state):
    '''
    Takes some controller input and initial state and spits out read/write
    addresses and values that are read from a memory matrix.

    The input is fed one timestep at a time. The final NTM state of each
    step becomes the initial state of the next.

    Arguments:
      controller_input - Numpy array [batch, T, 4*M + 2*S + 6] of raw head
        parameters.
      initial_state - Sequence of numpy arrays, one per placeholder in
        self.feed_initial_state.

    Returns:
      (read_addresses, write_addresses, sequence_reads), each laid out
      batch-first, time-second: [batch, T, ...]. The layout holds even when
      batch or T is 1.
    '''

    (_, seq_length, _) = controller_input.shape
    steps = np.split(controller_input, seq_length, axis=1)
    fetches = [self.ntm_reads, self.ntm_last_state]
    state = initial_state

    read_addresses = []
    write_addresses = []
    sequence_reads = []

    for step in steps:
      feeds = {self.feed_controller_input: step}
      feeds.update(zip(self.feed_initial_state, state))

      reads, state = self.sess.run(fetches, feeds)

      sequence_reads.append(reads)
      # State entries [-2] and [-1] are collected as the read and write
      # addresses respectively.
      read_addresses.append(state[-2])
      write_addresses.append(state[-1])

    # Each address is [batch, ...], so stack on a new time axis 1. Each step's
    # reads already carry a length-1 time axis at position 1 (batch-major
    # dynamic_rnn output), so concatenate along it. Neither needs np.squeeze,
    # which also dropped the batch/time axis when it had size 1 and made the
    # following transpose fail.
    read_addresses = np.stack(read_addresses, axis=1)
    write_addresses = np.stack(write_addresses, axis=1)
    sequence_reads = np.concatenate(sequence_reads, axis=1)

    return read_addresses, write_addresses, sequence_reads
예제 #16
0
    def __init__(self, mem_size, input_size, output_size, session,
                 num_heads=1, shift_range=3, name="NTM"):
        '''
        Builds the computation graph for the Neural Turing Machine.
        The tasks from the original paper call for the NTM to take in a
        sequence of arrays, and produce some output.
        Let B = batch size, T = sequence length, and L = array length, then
        a single input sequence is a matrix of size [TxL]. A batch of these
        input sequences has size [BxTxL].

        Arguments:
          mem_size - Tuple of integers corresponding to the number of storage
            locations and the dimension of each storage location (in the paper
            the memory matrix is NxM, mem_size refers to (N, M)).
          input_size - Integer number of elements in a single input vector
            (the value L).
          output_size - Integer number of elements in a single output vector.
          session - The TensorFlow session object that refers to the current
            computation graph.
          num_heads - Number of heads (future feature). Currently ignored:
            self.num_heads is always set to 1.
          shift_range - The integer number of shift values that the read/write
            heads can perform, which corresponds to the direction and magnitude
            of the allowable shifts.
            Shift ranges and corresponding available shift
            directions/magnitudes:
              3 => [-1, 0, 1]
              4 => [-2, -1, 0, 1]
              5 => [-2, -1, 0, 1, 2]
          name - A string name for the variable scope, for troubleshooting.

        Graph handles set here: feed_in [BxTxL], feed_out [BxTx output_size],
        feed_learning_rate (scalar), ntm_init_state (one placeholder per entry
        of ntm_cell.state_size), predictions, error and train_op.
        '''

        # Multi-head support is not implemented, so the num_heads argument is
        # deliberately ignored.
        self.num_heads = 1
        self.sess = session
        self.S = shift_range
        self.N, self.M = mem_size
        self.in_size = input_size
        self.out_size = output_size
        self.dt = tf.float32

        dt = self.dt
        N = self.N
        M = self.M
        S = self.S

        with tf.variable_scope(name):
            self.feed_in = tf.placeholder(
                dtype=dt, shape=(None, None, input_size))

            self.feed_out = tf.placeholder(
                dtype=dt, shape=(None, None, output_size))

            self.feed_learning_rate = tf.placeholder(dtype=dt, shape=())

            batch_size = tf.shape(self.feed_in)[0]
            seq_length = tf.shape(self.feed_in)[1]

            # The controller maps the input sequence to the raw head
            # parameters that drive the NTM cell.
            head_raw = self.controller(self.feed_in, batch_size, seq_length)

            self.ntm_cell = NTMCell(mem_size=(N, M), num_shifts=S)

            # Head parameters are also kept (via head_pieces_tuple_to_dict)
            # so they can be fetched for troubleshooting.
            write_head, read_head = NTMCell.head_pieces(
                head_raw, mem_size=(N, M), num_shifts=S, axis=2)

            self.write_head, self.read_head = \
                head_pieces_tuple_to_dict(write_head, read_head)

            # The initial NTM state is fed explicitly rather than built in
            # the graph, so callers supply it on every run.
            self.ntm_init_state = tuple(
                [tf.placeholder(dtype=dt, shape=(None, s))
                 for s in self.ntm_cell.state_size])

            self.ntm_reads, self.ntm_last_state = tf.nn.dynamic_rnn(
                cell=self.ntm_cell, initial_state=self.ntm_init_state,
                inputs=head_raw, dtype=dt)

            self.w_read = self.ntm_last_state[-2]
            self.w_write = self.ntm_last_state[-1]

            # Linear read-out from the M-wide memory reads to the outputs.
            ntm_reads_flat = tf.reshape(self.ntm_reads, [-1, M])

            L = tf.Variable(tf.random_normal([M, output_size]))
            b_L = tf.Variable(tf.random_normal([output_size,]))

            logits_flat = tf.matmul(ntm_reads_flat, L) + b_L
            targets_flat = tf.reshape(self.feed_out, [-1, output_size])

            self.error = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=targets_flat, logits=logits_flat))

            self.predictions = tf.sigmoid(
                tf.reshape(logits_flat, [batch_size, seq_length, output_size]))

            optimizer = tf.train.RMSPropOptimizer(
                learning_rate=self.feed_learning_rate, momentum=0.9)

            # compute_gradients covers every trainable variable in the graph.
            # It yields a None gradient for any variable the loss does not
            # depend on (e.g. another model built in the same graph).
            # clip_by_value cannot take None, and apply_gradients would skip
            # such variables anyway.
            grads_and_vars = optimizer.compute_gradients(self.error)
            capped_grads = [(tf.clip_by_value(grad, -10., 10.), var)
                            for grad, var in grads_and_vars
                            if grad is not None]

            self.train_op = optimizer.apply_gradients(capped_grads)
예제 #17
0
# Settings for the 'copy' training script below.
config = {
    'epoch': 100000,  # number of iterations of the training loop
    'input_dim': 7,  # passed to NTMCell as input_dim
    'output_dim': 7,  # passed to NTMCell as output_dim
    'length': 5,  # the model and checkpoint path use length * 2 + 2
    'controller_layer_size': 1,
    'write_head_size': 1,
    'read_head_size': 1,
    'checkpoint_dir': 'checkpoint'  # searched for a copy_<length*2+2> dir
}

if __name__ == "__main__":
    with tf.device('/cpu:0'), tf.Session() as sess:
        cell = NTMCell(input_dim=config['input_dim'],
                       output_dim=config['output_dim'],
                       controller_layer_size=config['controller_layer_size'],
                       write_head_size=config['write_head_size'],
                       read_head_size=config['read_head_size'],
                       controller_dim=32)
        ntm = NTM(cell, sess, config['length'] * 2 + 2)

        # Start from fresh variables unless a checkpoint directory for this
        # length already exists, in which case the 'copy' model is loaded.
        # NOTE(review): the path is built by string concatenation;
        # os.path.join would be more portable.
        if not os.path.isdir(config['checkpoint_dir'] + '/copy_' +
                             str(config['length'] * 2 + 2)):
            print(" [*] Initialize all variables")
            tf.global_variables_initializer().run()
            print(" [*] Initialization finished")
        else:
            ntm.load(config['checkpoint_dir'], 'copy')

        start_time = time.time()
        print('')
        # NOTE(review): the body of this training loop is missing in this
        # copy of the script (the next line is not indented), so as written
        # this is not valid Python.
        for idx in range(config['epoch']):
예제 #18
0
class NTM(object):
    '''
    Performs several operations relevant to the NTM:
      - builds the computation graph
      - trains the model
      - tests the model
    '''
    
    def __init__(self, mem_size, input_size, output_size, session,
                 num_heads=1, shift_range=3, name="NTM"):
        '''
        Builds the computation graph for the Neural Turing Machine.
        The tasks from the original paper call for the NTM to take in a
        sequence of arrays, and produce some output.
        Let B = batch size, T = sequence length, and L = array length, then
        a single input sequence is a matrix of size [TxL]. A batch of these
        input sequences has size [BxTxL].

        Arguments:
          mem_size - Tuple of integers corresponding to the number of storage
            locations and the dimension of each storage location (in the paper
            the memory matrix is NxM, mem_size refers to (N, M)).
          input_size - Integer number of elements in a single input vector
            (the value L).
          output_size - Integer number of elements in a single output vector.
          session - The TensorFlow session object that refers to the current
            computation graph.
          num_heads - Number of heads (future feature). Currently ignored:
            self.num_heads is always set to 1.
          shift_range - The integer number of shift values that the read/write
            heads can perform, which corresponds to the direction and magnitude
            of the allowable shifts.
            Shift ranges and corresponding available shift
            directions/magnitudes:
              3 => [-1, 0, 1]
              4 => [-2, -1, 0, 1]
              5 => [-2, -1, 0, 1, 2]
          name - A string name for the variable scope, for troubleshooting.

        Graph handles set here: feed_in [BxTxL], feed_out [BxTx output_size],
        feed_learning_rate (scalar), ntm_init_state (one placeholder per entry
        of ntm_cell.state_size), predictions, error and train_op.
        '''

        # Multi-head support is not implemented, so the num_heads argument is
        # deliberately ignored.
        self.num_heads = 1
        self.sess = session
        self.S = shift_range
        self.N, self.M = mem_size
        self.in_size = input_size
        self.out_size = output_size
        self.dt = tf.float32

        dt = self.dt
        N = self.N
        M = self.M
        S = self.S

        with tf.variable_scope(name):
            self.feed_in = tf.placeholder(
                dtype=dt, shape=(None, None, input_size))

            self.feed_out = tf.placeholder(
                dtype=dt, shape=(None, None, output_size))

            self.feed_learning_rate = tf.placeholder(dtype=dt, shape=())

            batch_size = tf.shape(self.feed_in)[0]
            seq_length = tf.shape(self.feed_in)[1]

            # The controller maps the input sequence to the raw head
            # parameters that drive the NTM cell.
            head_raw = self.controller(self.feed_in, batch_size, seq_length)

            self.ntm_cell = NTMCell(mem_size=(N, M), num_shifts=S)

            # Head parameters are also kept (via head_pieces_tuple_to_dict)
            # so they can be fetched for troubleshooting.
            write_head, read_head = NTMCell.head_pieces(
                head_raw, mem_size=(N, M), num_shifts=S, axis=2)

            self.write_head, self.read_head = \
                head_pieces_tuple_to_dict(write_head, read_head)

            # The initial NTM state is fed explicitly so that callers can
            # carry it across separate session runs.
            self.ntm_init_state = tuple(
                [tf.placeholder(dtype=dt, shape=(None, s))
                 for s in self.ntm_cell.state_size])

            self.ntm_reads, self.ntm_last_state = tf.nn.dynamic_rnn(
                cell=self.ntm_cell, initial_state=self.ntm_init_state,
                inputs=head_raw, dtype=dt)

            self.w_read = self.ntm_last_state[-2]
            self.w_write = self.ntm_last_state[-1]

            # Linear read-out from the M-wide memory reads to the outputs.
            ntm_reads_flat = tf.reshape(self.ntm_reads, [-1, M])

            L = tf.Variable(tf.random_normal([M, output_size]))
            b_L = tf.Variable(tf.random_normal([output_size,]))

            logits_flat = tf.matmul(ntm_reads_flat, L) + b_L
            targets_flat = tf.reshape(self.feed_out, [-1, output_size])

            self.error = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=targets_flat, logits=logits_flat))

            self.predictions = tf.sigmoid(
                tf.reshape(logits_flat, [batch_size, seq_length, output_size]))

            optimizer = tf.train.RMSPropOptimizer(
                learning_rate=self.feed_learning_rate, momentum=0.9)

            # compute_gradients covers every trainable variable in the graph.
            # It yields a None gradient for any variable the loss does not
            # depend on (e.g. another model built in the same graph).
            # clip_by_value cannot take None, and apply_gradients would skip
            # such variables anyway.
            grads_and_vars = optimizer.compute_gradients(self.error)
            capped_grads = [(tf.clip_by_value(grad, -10., 10.), var)
                            for grad, var in grads_and_vars
                            if grad is not None]

            self.train_op = optimizer.apply_gradients(capped_grads)

    def controller(self, inputs, batch_size, seq_length, num_units=100):
        '''
        Builds a single-layer LSTM controller whose output is projected to
        the raw head parameters consumed by the NTM cell. This method should
        only be utilized by the class.

        Side effects: sets self.lstm_cell, self.lstm_init_state (a tuple of
        placeholders, one per entry of the LSTM state) and
        self.lstm_last_state.

        Arguments:
          inputs - TF tensor containing data that is passed to the controller.
          batch_size - The number of sequences in a given training batch.
          seq_length - The length of the sequence being passed to the
            controller.
          num_units - The number of units inside of the LSTM controller.

        Returns:
          head_raw - Tensor [batch_size, seq_length, num_heads*(4*M + 2*S + 6)]
            holding the linear (un-activated) head parameters.
        '''

        M = self.M
        S = self.S
        dt = self.dt
        num_heads = self.num_heads

        self.lstm_cell = tf.contrib.rnn.BasicLSTMCell(
            num_units=num_units, forget_bias=1.0)

        # The initial LSTM state is fed explicitly so that callers can carry
        # it across separate session runs.
        self.lstm_init_state = tuple(
            [tf.placeholder(dtype=dt, shape=(None, s))
             for s in self.lstm_cell.state_size])

        lstm_init_state = tf.contrib.rnn.LSTMStateTuple(
            self.lstm_init_state[0], self.lstm_init_state[1])

        lstm_out_raw, self.lstm_last_state = tf.nn.dynamic_rnn(
            cell=self.lstm_cell, initial_state=lstm_init_state,
            inputs=inputs, dtype=dt)

        lstm_out = tf.tanh(lstm_out_raw)
        lstm_out_flat = tf.reshape(lstm_out, [-1, num_units])

        # The number of nodes per head on the controller's output is
        # determined by
        #   1. the number of allowable shifts
        #   2. the width of the columns in the memory matrix
        head_nodes = 4*M + 2*S + 6
        total_head_nodes = num_heads*head_nodes

        head_W = tf.Variable(
            tf.random_normal([num_units, total_head_nodes]), name='head_W')
        head_b_W = tf.Variable(
            tf.random_normal([total_head_nodes,]), name='head_b_W')

        head_raw_flat = tf.matmul(lstm_out_flat, head_W) + head_b_W
        # The last axis must match the projection width above, otherwise the
        # reshape fails whenever num_heads != 1.
        head_raw = tf.reshape(
            head_raw_flat, [batch_size, seq_length, total_head_nodes])

        return head_raw

    def train_batch(self, batch_x, batch_y, learning_rate=1e-4):
        '''
        Trains the model on a batch of inputs and their corresponding outputs.
        Returns the error that was obtained by training the NTM on the input
        sequence that is provided as an argument.

        Arguments:
          batch_x - The batch of input training sequences [BxTxL1]. Note that
            the first two dimensions (batch size and sequence length) of both
            batches MUST be the same. Numpy array.
          batch_y - The batch of output training sequences [BxTxL2]. The 
            output sequences are the desired outputs after the NTM has been
            presented with the training input, batch_x. Numpy array.

        Outputs:
          error - The amount of error (float)produced from this particular 
            training sequence. The error operation is defined in the
            constructor.
        '''

        lr = learning_rate
        batch_size = batch_x.shape[0]
        ntm_init_state = self.ntm_cell.bias_state(batch_size)
        lstm_init_state = tuple([np.zeros((batch_size, s)) \
            for s in self.lstm_cell.state_size])

        fetches = [self.error, self.train_op]
        feeds = {
            self.feed_in:batch_x,
            self.feed_out:batch_y,
            self.feed_learning_rate:lr
        }

        for i in range(len(ntm_init_state)):
            feeds[self.ntm_init_state[i]] = ntm_init_state[i]

        for i in range(len(lstm_init_state)):
            feeds[self.lstm_init_state[i]] = lstm_init_state[i]

        error, _ = self.sess.run(fetches, feeds)

        return error

    def run_once(self, test_x):
        '''
        Passes a single input sequence to the NTM, and produces an output
        according to what it's learned. Returns a tuple of items of interest
        for troubleshooting purposes (the read/write vectors and output).

        Arguments:
          test_x - A batch of input sequences [BxTxL1] that the NTM will use to
            produce a batch of output sequences [BxTxL2]. Numpy array.

        Outputs:
          output_b - A numpy array representing the output of the NTM after
            being presented with the input batch [BxTxL2].
          w_read_b - A numpy array of "read" locations that the NTM used.
            From the paper, write locations are normalized vectors that allow
            the NTM to focus on rows of the memory matrix.
          w_write_b - A numpy array of "write" locations that the NTM used.
          g_read_b - A numpy array of scalar values indicating whether the NTM
            used the previous read location or associative recall to determine
            the read location at each timestep.
          g_write_b - A numpy array of scalar values indicating whether the NTM
            used the previous write location or associative recall to determine
            the write location at each timestep.
          s_read_b - A numpy array of vectors describing the magnitude and
            direction of the shifting operation that was applied to the read
            head.
          s_write_b - A numpy array of vectors describing the magnitude and
            direction of the shifting operation that was applied to the write
            head.
        '''

        batch_size = test_x.shape[0]
        num_seq = test_x.shape[1]
        sequences = np.split(test_x, num_seq, axis=1)
        ntm_init_state = self.ntm_cell.bias_state(batch_size)
        lstm_init_state = tuple(
            [np.zeros((batch_size, s)) for s in self.lstm_cell.state_size])

        outputs = []
        w_read = []
        w_write = []
        g_read = []
        g_write = []
        s_read = []
        s_write = []

        for seq in sequences:
            fetches = [self.predictions, self.ntm_last_state,
                self.lstm_last_state, self.read_head, self.write_head]
            feeds = {self.feed_in: seq}

            for i in range(len(ntm_init_state)):
                feeds[self.ntm_init_state[i]] = ntm_init_state[i]

            for i in range(len(lstm_init_state)):
                feeds[self.lstm_init_state[i]] = lstm_init_state[i]

            output, ntm_init_state, lstm_init_state, \
                read_head, write_head = self.sess.run(fetches, feeds)

            outputs.append(output[0].copy())
            w_read.append(ntm_init_state[-2][0].copy())
            w_write.append(ntm_init_state[-1][0].copy())
            g_read.append(read_head['g'][0,0,:].copy())
            g_write.append(write_head['g'][0,0,:].copy())
            s_read.append(read_head['shift'][0,0,:].copy())
            s_write.append(write_head['shift'][0,0,:].copy())

        output_b = np.squeeze(np.array(outputs))
        w_read_b = np.array(w_read)
        w_write_b = np.array(w_write)
        g_read_b = np.array(g_read)
        g_write_b = np.array(g_write)
        s_read_b = np.array(s_read)
        s_write_b = np.array(s_write)

        #print(output_b.shape)

        return output_b, w_read_b, w_write_b, g_read_b, \
            g_write_b, s_read_b, s_write_b
    def __init__(self,
                 mem_size,
                 input_size,
                 output_size,
                 session,
                 num_heads=1,
                 shift_range=3,
                 name="NTM"):
        '''
        Builds the NTM computation graph (work-in-progress multi-head
        variant; see the comment below about the unfinished conversion).

        NOTE(review): this is a second `def __init__` in the same class body,
        so at class-creation time it replaces the `__init__` defined earlier
        in the class.

        Arguments:
          mem_size - Tuple (N, M): number of memory locations and the width of
            each location.
          input_size - Last dimension of the input placeholder feed_in.
          output_size - Last dimension of the target placeholder feed_out and
            of the predictions.
          session - TensorFlow session, stored as self.sess.
          num_heads - Ignored; self.num_heads is set to 1 below.
          shift_range - Passed to NTMCell and NTMCell.head_pieces as
            shift_range.
          name - Variable scope name for the graph.
        '''

        # The num_heads argument is not used; a single head is assumed.
        self.num_heads = 1
        self.sess = session
        self.S = shift_range
        self.N, self.M = mem_size
        self.in_size = input_size
        self.out_size = output_size

        # NOTE(review): unused local; the controller call below relies on its
        # own num_units default.
        num_lstm_units = 100
        self.dt = tf.float32
        # Passed to dynamic_rnn as parallel_iterations.
        self.pi = 64

        pi = self.pi
        dt = self.dt
        N = self.N
        M = self.M
        S = self.S
        num_heads = self.num_heads

        with tf.variable_scope(name):
            # Inputs [batch, time, input_size] and targets
            # [batch, time, output_size].
            self.feed_in = tf.placeholder(dtype=dt,
                                          shape=(None, None, input_size))

            self.feed_out = tf.placeholder(dtype=dt,
                                           shape=(None, None, output_size))

            self.feed_learning_rate = tf.placeholder(dtype=dt, shape=())

            batch_size = tf.shape(self.feed_in)[0]
            seq_length = tf.shape(self.feed_in)[1]

            head_raw = self.controller(self.feed_in, batch_size, seq_length)

            self.ntm_cell = NTMCell(mem_size=(N, M), shift_range=S)

            # style='dict' is requested here instead of converting the
            # head pieces afterwards.
            self.write_head, self.read_head = NTMCell.head_pieces(
                head_raw, mem_size=(N, M), shift_range=S, axis=2, style='dict')

            # One placeholder per entry of the NTM cell's state.
            self.ntm_init_state = tuple(
                [tf.placeholder(dtype=dt, shape=(None, s)) \
                for s in self.ntm_cell.state_size])

            self.ntm_reads, self.ntm_last_state = tf.nn.dynamic_rnn(
                cell=self.ntm_cell,
                initial_state=self.ntm_init_state,
                inputs=head_raw,
                dtype=dt,
                parallel_iterations=pi)

            # Started conversion to the multi-head output here, still have
            # lots to do.
            # NOTE(review): these slice the final-state tuple starting at
            # index N (the number of memory locations), unlike the [-2]/[-1]
            # indexing used elsewhere in this file; confirm against the
            # layout of NTMCell.state_size.
            self.w_read = self.ntm_last_state[N:N + num_heads]
            self.w_write = self.ntm_last_state[N + num_heads:N + 2 * num_heads]

            # NOTE(review): this iterates over self.ntm_reads and builds a
            # Python list, so tf.matmul below receives a list rather than a
            # single [-1, M] tensor. A graph-mode Tensor is not iterable, so
            # this likely fails unless ntm_reads is a tuple/list of tensors;
            # needs rework before use.
            ntm_reads_flat = [tf.reshape(r, [-1, M]) for r in self.ntm_reads]

            L = tf.Variable(tf.random_normal([M, output_size]))
            b_L = tf.Variable(tf.random_normal([
                output_size,
            ]))

            logits_flat = tf.matmul(ntm_reads_flat, L) + b_L
            targets_flat = tf.reshape(self.feed_out, [-1, output_size])

            self.error = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=targets_flat,
                                                        logits=logits_flat))

            self.predictions = tf.sigmoid(
                tf.reshape(logits_flat, [batch_size, seq_length, output_size]))

            optimizer = tf.train.RMSPropOptimizer(
                learning_rate=self.feed_learning_rate, momentum=0.9)

            # Gradients are clipped element-wise to [-10, 10].
            # NOTE(review): compute_gradients can return None gradients for
            # variables the loss does not depend on; clip_by_value(None)
            # would raise. Consider filtering them out.
            grads_and_vars = optimizer.compute_gradients(self.error)
            capped_grads = [(tf.clip_by_value(grad, -10., 10.), var) \
                for grad, var in grads_and_vars]

            self.train_op = optimizer.apply_gradients(capped_grads)