Example #1
def rnn(self, inputs):
    access_config = {
        "memory_size": self._memory_size,
        "word_size": self._word_size,
        "num_reads": self._num_reads,
        "num_writes": self._num_writes,
    }
    controller_config = {
        # "num_layers": self._num_layers,
        "hidden_size": self._layer_width,
        "initializers": {
            "w_gates": tf.variance_scaling_initializer(),
            "b_gates": tf.constant_initializer(0.1),
            "w_f_diag": tf.variance_scaling_initializer(),
            "w_i_diag": tf.variance_scaling_initializer(),
            "w_o_diag": tf.variance_scaling_initializer()
        },
        "use_peepholes": True
    }
    with tf.variable_scope("RNN"):
        dnc_core = dnc.DNC(access_config, controller_config,
                           self._layer_width, self._clip_value)
        initial_state = dnc_core.initial_state(tf.shape(inputs)[0])
        output_sequence, _ = tf.nn.dynamic_rnn(
            cell=dnc_core,
            inputs=inputs,
            parallel_iterations=self._parallel_iterations,
            swap_memory=True,
            initial_state=initial_state,
            sequence_length=self.seqlen)
        output = self.last_relevant(output_sequence, self.seqlen)
        return output
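The `last_relevant` helper called above is not shown in the snippet. A minimal sketch of what it presumably does, written here as a standalone function (the implementation is an assumption, not taken from the source): given batch-major output of shape [batch, time, width], gather each sequence's output at its final valid timestep.

def last_relevant(output, seqlen):
    # Flatten [batch, time, width] to [batch * time, width], then gather
    # the row corresponding to each sequence's last valid step.
    batch_size = tf.shape(output)[0]
    max_len = tf.shape(output)[1]
    out_width = int(output.get_shape()[2])
    index = tf.range(0, batch_size) * max_len + (seqlen - 1)
    flat = tf.reshape(output, [-1, out_width])
    return tf.gather(flat, index)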
Example #2
def train():
    """Trains the DNC and periodically reports the loss."""

    dataset = get_sample(PARAMS.batch_size, FLAGS.data_dir)
    output_size = get_word_space_size(FLAGS.data_dir)

    # wrap DNC recurrent cell to form complete model
    x = tf.keras.Input(shape=(None, output_size))
    dnc_cell = dnc.DNC(
        output_size,
        controller_units=PARAMS.units,
        memory_size=PARAMS.memory_size,
        word_size=PARAMS.word_size,
        num_read_heads=PARAMS.num_read_heads
    )
    dnc_initial_state = dnc_cell.get_initial_state(batch_size=PARAMS.batch_size)
    rnn = tf.keras.layers.RNN(dnc_cell, return_sequences=True)
    y = rnn(x, initial_state=dnc_initial_state)
    model = tf.keras.models.Model(x, y)

    learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
        PARAMS.learning_rate, 10000, 0.99, staircase=True
    )
    optimizer = tf.keras.optimizers.RMSprop(epsilon=PARAMS.optimizer_epsilon,
                                            momentum=0.9, learning_rate=learning_rate,
                                            clipnorm=PARAMS.max_grad_norm)
    os.makedirs(os.path.join(FLAGS.checkpoint_dir, 'model'), exist_ok=True)
    os.makedirs(os.path.join(FLAGS.checkpoint_dir, 'summaries'), exist_ok=True)

    logging.info("Starting training...")
    for step in range(1, FLAGS.num_training_steps + 1):
        x, y, _, y_mask = next(dataset)
        with tf.GradientTape() as tape:
            logits = model(x)
            loss = tf.reduce_mean(
                y_mask *
                tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits)
            )
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        if step % FLAGS.report_interval == 0:
            logging.info("Loss at step {:d}: {:.6f}".format(step, float(loss)))
        if step % FLAGS.checkpoint_interval == 0:
            model.save(os.path.join(FLAGS.checkpoint_dir, 'model', 'step{}.h5'.format(step)))

    model.save(os.path.join(FLAGS.checkpoint_dir, 'model', 'final.h5'))
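Because this loop runs eagerly, each step pays Python-level dispatch overhead. A hedged variant: the inner step can be wrapped in tf.function (standard TF2 API) for a typically large speedup. This sketch assumes the same `model` and `optimizer` objects as in the example above.

@tf.function
def train_step(x, y, y_mask):
    # Compiled train step: forward pass, masked cross-entropy, gradient update.
    with tf.GradientTape() as tape:
        logits = model(x)
        loss = tf.reduce_mean(
            y_mask * tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss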
Example #3
def get_dnc(output_size=3):
    """Builds a DNC core and its initial state."""
    access_config = {
        "memory_size": FLAGS.memory_size,
        "word_size": FLAGS.word_size,
        "num_reads": FLAGS.num_read_heads,
        "num_writes": FLAGS.num_write_heads,
    }
    controller_config = {
        "hidden_size": FLAGS.hidden_size,
    }
    clip_value = FLAGS.clip_value
    dnc_core = dnc.DNC(access_config, controller_config, output_size,
                       clip_value)
    initial_state = dnc_core.initial_state(FLAGS.batch_size)

    return dnc_core, initial_state
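A usage sketch (assumed, not from the source) for the pair returned by get_dnc, mirroring the unrolling pattern of the other TF1 examples in this list:

dnc_core, initial_state = get_dnc(output_size=3)
# input_sequence is assumed to be a [time, batch, feature] tensor.
output_sequence, final_state = tf.nn.dynamic_rnn(
    cell=dnc_core,
    inputs=input_sequence,
    time_major=True,
    initial_state=initial_state)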
Example #4
def run_model(input_sequence, output_size):
    """Runs model on input sequence."""

    memory_config = {
        "memory_size": FLAGS.memory_size,
        "word_size": FLAGS.word_size,
        "num_read_heads": FLAGS.num_read_heads,
    }

    dnc_core = dnc.DNC(output_size, controller_units=FLAGS.units, **memory_config)
    initial_state = dnc_core.get_initial_state(batch_size=FLAGS.batch_size)
    output_sequence, _ = tf.nn.dynamic_rnn(
        cell=dnc_core,
        inputs=input_sequence,
        time_major=True,
        initial_state=initial_state)

    return output_sequence
Example #5
def rnn(self, inputs):
    access_config = {
        "memory_size": self._memory_size,
        "word_size": self._word_size,
        "num_reads": self._num_reads,
        "num_writes": self._num_writes,
    }
    controller_config = {
        # "num_layers": self._num_layers,
        "hidden_size": self._layer_width,
        "initializers": {
            "w_gates": tf.variance_scaling_initializer(),
            "b_gates": tf.constant_initializer(0.1),
            "w_f_diag": tf.variance_scaling_initializer(),
            "w_i_diag": tf.variance_scaling_initializer(),
            "w_o_diag": tf.variance_scaling_initializer()
        },
        "use_peepholes": True
    }
    with tf.variable_scope("RNN"):
        dnc_core = dnc.DNC(access_config, controller_config,
                           self._layer_width, self._clip_value)
        initial_state = dnc_core.initial_state(tf.shape(inputs)[0])
        # transpose to time major: [time, batch, feature]
        # tm_inputs = tf.transpose(inputs, perm=[1, 0, 2])
        output_sequence, _ = tf.nn.dynamic_rnn(
            cell=dnc_core,
            inputs=inputs,
            initial_state=initial_state,
            # parallel_iterations=256,
            # dtype=tf.float32,  # If there is no initial_state, you must give a dtype
            # time_major=True,
            # swap_memory=True,
            sequence_length=self.seqlen)
        # layer = tf.concat(layer, 1)
        # restore to batch major: [batch, time, feature]
        # output_sequence = tf.transpose(output_sequence, perm=[1, 0, 2])
        output = self.last_relevant(output_sequence, self.seqlen)
        return output
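For reference, here is what the commented-out time-major path above would look like if enabled, inside the same method; the result is equivalent, only the tensor layout during the unroll differs (a sketch, not part of the source):

tm_inputs = tf.transpose(inputs, perm=[1, 0, 2])  # [time, batch, feature]
tm_outputs, _ = tf.nn.dynamic_rnn(
    cell=dnc_core,
    inputs=tm_inputs,
    time_major=True,
    initial_state=initial_state,
    sequence_length=self.seqlen)
output_sequence = tf.transpose(tm_outputs, perm=[1, 0, 2])  # back to [batch, time, feature]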
Example #6
def run_model(input_sequence, output_size):
  """Runs model on input sequence."""

  access_config = {
      "memory_size": FLAGS.memory_size,
      "word_size": FLAGS.word_size,
      "num_reads": FLAGS.num_read_heads,
      "num_writes": FLAGS.num_write_heads,
  }
  controller_config = {
      "hidden_size": FLAGS.hidden_size,
  }
  clip_value = FLAGS.clip_value

  dnc_core = dnc.DNC(access_config, controller_config, output_size, clip_value)
  initial_state = dnc_core.initial_state(FLAGS.batch_size)
  output_sequence, _ = tf.nn.dynamic_rnn(
      cell=dnc_core,
      inputs=input_sequence,
      time_major=True,
      initial_state=initial_state)

  return output_sequence
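run_model only builds the forward unroll. A hedged sketch of how its logits are commonly turned into a clipped-gradient training op in TF1; the loss choice, the `observations`/`targets` placeholders, and the extra flags below are assumptions, not from the source:

output_logits = run_model(observations, output_size)
loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=output_logits))
# Clip the global gradient norm before applying updates.
trainable = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(loss, trainable), FLAGS.max_grad_norm)
optimizer = tf.train.RMSPropOptimizer(FLAGS.learning_rate, epsilon=FLAGS.optimizer_epsilon)
train_step = optimizer.apply_gradients(zip(grads, trainable))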
Example #7
def run_model(input_sequence, output_size):
    """Runs model on input sequence."""

    memory_config = {
        "words_num": FLAGS.memory_size,
        "word_size": FLAGS.word_size,
        "read_heads_num": FLAGS.num_read_heads,
    }
    controller_config = {
        "hidden_size": FLAGS.hidden_size,
    }

    dnc_core = dnc.DNC(controller_config,
                       memory_config,
                       output_size,
                       classic_dnc_output=False)
    initial_state = dnc_core.initial_state(FLAGS.batch_size)
    output_sequence, _ = tf.nn.dynamic_rnn(cell=dnc_core,
                                           inputs=input_sequence,
                                           time_major=True,
                                           initial_state=initial_state)

    return output_sequence
Example #8
    def __init__(self, args):
        if args.label_type == 'one_hot':
            args.output_dim = args.n_classes
        elif args.label_type == 'five_hot':
            args.output_dim = 25

        self.x_image = tf.placeholder(dtype=tf.float32,
                                      shape=[
                                          args.batch_size, args.seq_length,
                                          args.image_width * args.image_height
                                      ])
        self.x_label = tf.placeholder(
            dtype=tf.float32,
            shape=[args.batch_size, args.seq_length, args.output_dim])
        self.y = tf.placeholder(
            dtype=tf.float32,
            shape=[args.batch_size, args.seq_length, args.output_dim])

        if args.model == 'LSTM':

            def rnn_cell(rnn_size):
                return tf.nn.rnn_cell.BasicLSTMCell(rnn_size)

            cell = tf.nn.rnn_cell.MultiRNNCell(
                [rnn_cell(args.rnn_size) for _ in range(args.rnn_num_layers)])
        elif args.model == 'NTM':
            import ntm.ntm_cell as ntm_cell
            cell = ntm_cell.NTMCell(args.rnn_size,
                                    args.memory_size,
                                    args.memory_vector_dim,
                                    read_head_num=args.read_head_num,
                                    write_head_num=args.write_head_num,
                                    addressing_mode='content_and_location',
                                    output_dim=args.output_dim)
        elif args.model == 'MANN':
            import ntm.mann_cell as mann_cell
            cell = mann_cell.MANNCell(args.rnn_size,
                                      args.memory_size,
                                      args.memory_vector_dim,
                                      head_num=args.read_head_num)
        elif args.model == 'MANN2':
            import ntm.mann_cell_2 as mann_cell
            cell = mann_cell.MANNCell(args.rnn_size,
                                      args.memory_size,
                                      args.memory_vector_dim,
                                      head_num=args.read_head_num)
        elif args.model == 'ACT':
            from tf_rnn_adaptive.act_wrapper import ACTWrapper

            def rnn_cell(rnn_size):
                return tf.nn.rnn_cell.BasicLSTMCell(rnn_size)

            # inner_cell = tf.nn.rnn_cell.MultiRNNCell([rnn_cell(args.rnn_size) for _ in range(args.rnn_num_layers)])
            inner_cell = rnn_cell(args.rnn_size)
            cell = ACTWrapper(inner_cell, ponder_limit=10)
        elif args.model == 'DNC':
            from dnc import dnc
            access_config = {
                "memory_size": args.memory_size,
                "word_size": args.memory_vector_dim,
                "num_reads": args.read_head_num,
                "num_writes": args.write_head_num,
            }
            controller_config = {
                "hidden_size": args.rnn_size,
            }
            clip_value = args.clip_value
            cell = dnc.DNC(access_config, controller_config, args.output_dim,
                           clip_value)
        elif args.model == 'ACT-DNC':
            from tf_rnn_adaptive.act_wrapper import ACTWrapper
            from dnc import dnc
            access_config = {
                "memory_size": args.memory_size,
                "word_size": args.memory_vector_dim,
                "num_reads": args.read_head_num,
                "num_writes": args.write_head_num,
            }
            controller_config = {
                "hidden_size": args.rnn_size,
            }
            clip_value = args.clip_value

            dnc_core = dnc.DNC(access_config, controller_config,
                               args.output_dim, clip_value)
            cell = ACTWrapper(dnc_core, ponder_limit=10)
        elif args.model == 'MY-ACT':
            from my_act.act_wrapper import ACTWrapper
            from dnc import dnc

            # main dnc
            with tf.variable_scope('act_wrapper'):
                access_config = {
                    "memory_size": args.memory_size,
                    "word_size": args.memory_vector_dim,
                    "num_reads": args.read_head_num,
                    "num_writes": args.write_head_num,
                }
                controller_config = {
                    "hidden_size": args.rnn_size,
                }
                clip_value = args.clip_value
                main_dnc = dnc.DNC(access_config,
                                   controller_config,
                                   args.output_dim,
                                   clip_value,
                                   name='main_dnc')

            # auxiliary dnc
            # use memory_vector_dim as aux's output size
            # so that info does not lose when communicating with main dnc
            with tf.variable_scope('act_wrapper'):
                aux_access_config = {
                    "memory_size": 8,
                    "word_size": args.memory_vector_dim,
                    "num_reads": 1,
                    "num_writes": 1
                }
                aux_controller_config = {
                    "hidden_size": args.rnn_size,
                }
                aux_dnc = dnc.DNC(aux_access_config,
                                  aux_controller_config,
                                  args.memory_vector_dim,
                                  clip_value,
                                  name='aux_dnc')

            cell = ACTWrapper(main_dnc,
                              aux_dnc,
                              ponder_limit=10,
                              divergence_type=args.divergence_loss)
        else:
            raise Exception('Unknown model: `{}`'.format(args.model))

        if args.model == 'ACT-DNC':
            state = dnc_core.initial_state(args.batch_size)
        elif args.model == 'DNC':
            state = cell.initial_state(args.batch_size)
        elif args.model == 'MY-ACT':
            state = main_dnc.initial_state(args.batch_size)
        else:
            state = cell.zero_state(args.batch_size, tf.float32)
        self.state_list = [state]  # For debugging
        self.o = []
        for t in range(args.seq_length):
            output, state = cell(
                tf.concat([self.x_image[:, t, :], self.x_label[:, t, :]],
                          axis=1), state)
            # output, state = cell(self.y[:, t, :], state)
            with tf.variable_scope("o2o", reuse=(t > 0)):
                o2o_w = tf.get_variable(
                    'o2o_w', [output.get_shape()[1], args.output_dim],
                    initializer=tf.random_uniform_initializer(minval=-0.1,
                                                              maxval=0.1))
                # initializer=tf.random_normal_initializer(mean=0.0, stddev=0.1))
                o2o_b = tf.get_variable(
                    'o2o_b', [args.output_dim],
                    initializer=tf.random_uniform_initializer(minval=-0.1,
                                                              maxval=0.1))
                # initializer=tf.random_normal_initializer(mean=0.0, stddev=0.1))
                output = tf.nn.xw_plus_b(output, o2o_w, o2o_b)
            if args.label_type == 'one_hot':
                output = tf.nn.softmax(output, axis=1)
            elif args.label_type == 'five_hot':
                output = tf.stack(
                    [tf.nn.softmax(o) for o in tf.split(output, 5, axis=1)],
                    axis=1)
            self.o.append(output)
            self.state_list.append(state)
        self.o = tf.stack(self.o, axis=1)
        self.state_list.append(state)

        eps = 1e-8
        if args.label_type == 'one_hot':
            self.learning_loss = -tf.reduce_mean(  # cross entropy function
                tf.reduce_sum(self.y * tf.log(self.o + eps), axis=[1, 2]))
        elif args.label_type == 'five_hot':
            self.learning_loss = -tf.reduce_mean(  # cross entropy function
                tf.reduce_sum(tf.stack(tf.split(self.y, 5, axis=2), axis=2) *
                              tf.log(self.o + eps),
                              axis=[1, 2, 3]))
        self.o = tf.reshape(self.o,
                            shape=[args.batch_size, args.seq_length, -1])
        self.learning_loss_summary = tf.summary.scalar('learning_loss',
                                                       self.learning_loss)
        """ ponder loss """
        if 'ACT' in args.model:
            time_penalty = 0.001
            self._ponder_loss = time_penalty * cell.get_ponder_cost(
                args.seq_length)
            self.learning_loss += self._ponder_loss
            self.ponder_steps = cell.get_ponder_steps(args.seq_length)
            self.mean_ponder_steps = tf.reduce_mean(self.ponder_steps)
        else:
            self.mean_ponder_steps = tf.constant(0)
        """ memory divergence loss """
        if args.model == 'MY-ACT':
            if args.divergence_loss is not None:
                divergence_penalty = 0.001
                self.learning_loss += cell.get_memory_divergence_loss(
                ) * divergence_penalty

        with tf.variable_scope('optimizer'):
            self.optimizer = tf.train.AdamOptimizer(
                learning_rate=args.learning_rate)
            # self.optimizer = tf.train.RMSPropOptimizer(
            #     learning_rate=args.learning_rate, momentum=0.9, decay=0.95
            # )
            # gvs = self.optimizer.compute_gradients(self.learning_loss)
            # capped_gvs = [(tf.clip_by_value(grad, -10., 10.), var) for grad, var in gvs]
            # self.train_op = self.optimizer.apply_gradients(gvs)
            self.train_op = self.optimizer.minimize(self.learning_loss)
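A minimal sketch (assumed; the wrapping class name and batch variables below are hypothetical) of driving the graph built in this __init__ for one training step:

model = Model(args)  # hypothetical class owning the __init__ above
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {
        model.x_image: image_batch,   # [batch, seq_length, width * height]
        model.x_label: label_batch,   # labels shifted one step back in time
        model.y: target_batch,        # [batch, seq_length, output_dim]
    }
    _, loss = sess.run([model.train_op, model.learning_loss], feed_dict=feed)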
Example #9
        memory_config = {
            "words_num": 256,
            "word_size": 64,
            "read_heads_num": 4,
        }
        controller_config = {
            "hidden_size": 256,
        }

        output_size = len(lexicon_dictionary)
        input_data = tf.placeholder(tf.float32, [None, 1, output_size])
        target_output = tf.placeholder(tf.float32, [None, 1, output_size])
        target_mask = tf.placeholder(tf.float32, [None, 1, 1])

        dnc_core = dnc.DNC(controller_config,
                           memory_config,
                           output_size,
                           classic_dnc_output=False)
        initial_state = dnc_core.initial_state(1)
        output_logits, _ = tf.nn.dynamic_rnn(cell=dnc_core,
                                             inputs=input_data,
                                             time_major=True,
                                             initial_state=initial_state)
        softmaxed = tf.nn.softmax(output_logits)

        saver = tf.train.Saver()
        if not os.path.exists(FLAGS.checkpoint_file + ".index"):
            raise RuntimeError("checkpoint not found: " + FLAGS.checkpoint_file)
        saver.restore(session, FLAGS.checkpoint_file)

        tasks_results = {}
        tasks_names = {}
Example #10
        memory_config = {
            "memory_size": 256,
            "word_size": 64,
            "num_read_heads": 4,
        }
        controller_config = {
            "units": 256,
        }

        output_size = len(lexicon_dictionary)
        input_data = tf.placeholder(tf.float32, [None, 1, output_size])
        target_output = tf.placeholder(tf.float32, [None, 1, output_size])
        target_mask = tf.placeholder(tf.float32, [None, 1, 1])

        dnc_core = dnc.DNC(output_size,
                           controller_units=controller_config["units"],
                           **memory_config)
        initial_state = dnc_core.get_initial_state(batch_size=1)
        output_logits, _ = tf.nn.dynamic_rnn(
            cell=dnc_core,
            inputs=input_data,
            time_major=True,
            initial_state=initial_state)
        softmaxed = tf.nn.softmax(output_logits)

        saver = tf.train.Saver()
        if not os.path.exists(FLAGS.checkpoint_file + ".index"):
            raise RuntimeError("checkpoint not found: " + FLAGS.checkpoint_file)
        saver.restore(session, FLAGS.checkpoint_file)

        tasks_results = {}
        tasks_names = {}
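Once the checkpoint is restored, decoding works the same way as in the TF1 variant above: run the `softmaxed` node and map argmax indices back through the lexicon. A sketch, assuming numpy is imported as np, that `lexicon_dictionary` maps words to indices, and that `encoded_input` is a [time, 1, output_size] array (all assumptions):

index_to_word = {idx: word for word, idx in lexicon_dictionary.items()}
probs = session.run(softmaxed, feed_dict={input_data: encoded_input})
predicted_ids = np.argmax(probs, axis=-1)
words = [index_to_word[int(i)] for i in predicted_ids.flatten()]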