def __init__(self, args):
        """Build the full TF1 training graph for the configured model.

        Creates the input placeholders, instantiates the recurrent cell named
        by ``args.model``, unrolls it for ``args.seq_length`` steps, projects
        every step's output to ``args.output_dim`` units and defines the
        cross-entropy loss, its summary and the training op.

        Args:
            args: configuration namespace.  Attributes read here:
                ``label_type``, ``n_classes`` (one_hot only), ``dataset_type``,
                ``batch_size``, ``seq_length``, ``image_width``,
                ``image_height``, ``sample_nframes`` (kinetics_video only),
                ``model`` plus the hyper-parameters of the selected cell,
                ``optimizer`` and ``learning_rate``.  ``args.output_dim`` is
                written as a side effect.

        Raises:
            ValueError: if ``label_type``, ``dataset_type`` or ``model`` is
                not one of the values handled below.  (Previously these cases
                surfaced later as an unrelated AttributeError /
                UnboundLocalError.)
        """
        # The width of the label / prediction vector depends on the encoding.
        if args.label_type == 'one_hot':
            args.output_dim = args.n_classes
        elif args.label_type == 'five_hot':
            # 25 units that are later split into 5 independent softmax groups.
            args.output_dim = 25
        else:
            raise ValueError('Unknown label_type: `{}`'.format(args.label_type))

        eprint("Creating Placeholders")
        # Per-time-step image dimensions; only the attributes needed by the
        # selected dataset are touched.
        if args.dataset_type == 'omniglot':
            image_dims = [args.image_width * args.image_height]
        elif args.dataset_type in ('kinetics_dynamic', 'kinetics_single_frame'):
            image_dims = [args.image_width, args.image_height, 3]
        elif args.dataset_type == 'kinetics_video':
            image_dims = [args.sample_nframes, args.image_width, args.image_height, 3]
        else:
            raise ValueError('Unknown dataset_type: `{}`'.format(args.dataset_type))
        self.x_image = tf.placeholder(dtype=tf.float32,
                                      shape=[args.batch_size, args.seq_length] + image_dims)

        # Boolean flag fed at run time; only the MANN cell receives it here.
        self.is_training = tf.placeholder(tf.bool)

        label_shape = [args.batch_size, args.seq_length, args.output_dim]
        self.x_label = tf.placeholder(dtype=tf.float32, shape=label_shape)
        self.y = tf.placeholder(dtype=tf.float32, shape=label_shape)

        eprint("Creating Cells")
        if args.model == 'LSTM':
            cell = tf.nn.rnn_cell.MultiRNNCell(
                [tf.nn.rnn_cell.BasicLSTMCell(args.rnn_size)
                 for _ in range(args.rnn_num_layers)])
        elif args.model == 'NTM':
            import ntm.ntm_cell as ntm_cell
            cell = ntm_cell.NTMCell(args.rnn_size, args.memory_size, args.memory_vector_dim,
                                    read_head_num=args.read_head_num,
                                    write_head_num=args.write_head_num,
                                    addressing_mode='content_and_location',
                                    output_dim=args.output_dim)
        elif args.model == 'MANN':
            import ntm.mann_cell as mann_cell
            cell = mann_cell.MANNCell(args.rnn_size, args.memory_size, args.memory_vector_dim,
                                      is_training=self.is_training,
                                      head_num=args.read_head_num, args=args)
        elif args.model == 'MANN2':
            import ntm.mann_cell_2 as mann_cell
            cell = mann_cell.MANNCell(args.rnn_size, args.memory_size, args.memory_vector_dim,
                                      head_num=args.read_head_num)
        else:
            raise ValueError('Unknown model: `{}`'.format(args.model))

        eprint("Looping through seq")
        # Start from the cell's zero state.
        state = cell.zero_state(args.batch_size, tf.float32)
        self.state_list = [state]   # For debugging. Keep track of previous states
        step_outputs = []
        for step in range(args.seq_length):
            eprint("Seq: [{}] Passing data into the cell".format(step))
            # Image and label are handed to the cell as separate arguments.
            # Indexing the time axis keeps every trailing image dimension,
            # whatever the dataset's rank is.
            # NOTE(review): the stock MultiRNNCell call signature is
            # (inputs, state, scope), so the 'LSTM' option presumably does not
            # work with this three-argument call -- confirm.
            output, state = cell(self.x_image[:, step], self.x_label[:, step, :], state)

            # Project the cell output to the number of classes / predictions.
            # The projection weights are shared across time steps.
            with tf.variable_scope("o2o", reuse=(step > 0)):
                o2o_w = tf.get_variable('o2o_w', [output.get_shape()[1], args.output_dim],
                                        initializer=tf.random_uniform_initializer(minval=-0.1, maxval=0.1))
                o2o_b = tf.get_variable('o2o_b', [args.output_dim],
                                        initializer=tf.random_uniform_initializer(minval=-0.1, maxval=0.1))
                output = tf.nn.xw_plus_b(output, o2o_w, o2o_b)
            if args.label_type == 'one_hot':
                output = tf.nn.softmax(output, dim=-1)
            else:
                # five_hot: one softmax per group of units, stacked on a new axis 1.
                output = tf.stack([tf.nn.softmax(o) for o in tf.split(output, 5, axis=1)], axis=1)
            step_outputs.append(output)
            self.state_list.append(state)
        self.o = tf.stack(step_outputs, axis=1)
        # NOTE(review): the final state was already appended inside the loop, so
        # it appears twice in state_list.  Kept as-is in case callers index the
        # list by position.
        self.state_list.append(state)

        eprint("Defining Loss")
        eps = 1e-8  # keeps tf.log away from log(0)
        if args.label_type == 'one_hot':
            self.learning_loss = -tf.reduce_mean(  # cross entropy function
                tf.reduce_sum(self.y * tf.log(self.o + eps), axis=[1, 2])
            )
        else:
            # Regroup the targets the same way the predictions were grouped.
            self.learning_loss = -tf.reduce_mean(  # cross entropy function
                tf.reduce_sum(tf.stack(tf.split(self.y, 5, axis=2), axis=2) * tf.log(self.o + eps), axis=[1, 2, 3])
            )
        # Flatten any per-group axis so self.o is [batch, seq, -1] for consumers.
        self.o = tf.reshape(self.o, shape=[args.batch_size, args.seq_length, -1])
        self.learning_loss_summary = tf.summary.scalar('learning_loss', self.learning_loss)

        eprint( "Total number of variables used ", np.sum([v.get_shape().num_elements() for v in tf.trainable_variables()]) )
        eprint("Defining optimizer")
        with tf.variable_scope('optimizer'):
            if args.optimizer == 'adam':
                self.optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
            elif args.optimizer == 'rms':
                self.optimizer = tf.train.RMSPropOptimizer(learning_rate=args.learning_rate)
            else:
                # Any other value falls back to plain gradient descent.
                self.optimizer = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
            self.train_op = self.optimizer.minimize(self.learning_loss)

        eprint("Finished Definining Model")
Exemplo n.º 2
0
    def __init__(self, args):
        """Build the TF1 graph: placeholders, cell, unrolled loop, loss, train op.

        Args:
            args: configuration namespace.  Attributes read here:
                ``label_type``, ``n_classes`` (one_hot only), ``batch_size``,
                ``seq_length``, ``image_width``, ``image_height``, ``model``
                plus the hyper-parameters of the selected cell, and
                ``learning_rate``.  ``args.output_dim`` is written as a side
                effect.

        Raises:
            ValueError: if ``label_type`` or ``model`` is not one of the
                values handled below.  (Previously these cases surfaced later
                as an unrelated AttributeError / UnboundLocalError.)
        """
        # The width of the label / prediction vector depends on the encoding.
        if args.label_type == 'one_hot':
            args.output_dim = args.n_classes
        elif args.label_type == 'five_hot':
            # 25 units that are later split into 5 independent softmax groups.
            args.output_dim = 25
        else:
            raise ValueError('Unknown label_type: `{}`'.format(args.label_type))

        # Flattened images, plus two label-shaped inputs: x_label is fed to the
        # cell together with the image, y is the training target.
        self.x_image = tf.placeholder(
            dtype=tf.float32,
            shape=[args.batch_size, args.seq_length, args.image_width * args.image_height])
        label_shape = [args.batch_size, args.seq_length, args.output_dim]
        self.x_label = tf.placeholder(dtype=tf.float32, shape=label_shape)
        self.y = tf.placeholder(dtype=tf.float32, shape=label_shape)

        if args.model == 'LSTM':
            cell = tf.contrib.rnn.MultiRNNCell(
                [tf.contrib.rnn.BasicLSTMCell(args.rnn_size)
                 for _ in range(args.rnn_num_layers)])
        elif args.model == 'NTM':
            import ntm.ntm_cell as ntm_cell
            cell = ntm_cell.NTMCell(args.rnn_size, args.memory_size, args.memory_vector_dim,
                                    read_head_num=args.read_head_num,
                                    write_head_num=args.write_head_num,
                                    addressing_mode='content_and_location',
                                    output_dim=args.output_dim)
        elif args.model == 'MANN':
            import ntm.mann_cell as mann_cell
            cell = mann_cell.MANNCell(args.rnn_size, args.memory_size, args.memory_vector_dim,
                                      head_num=args.read_head_num)
        elif args.model == 'MANN2':
            import ntm.mann_cell_2 as mann_cell
            cell = mann_cell.MANNCell(args.rnn_size, args.memory_size, args.memory_vector_dim,
                                      head_num=args.read_head_num)
        else:
            raise ValueError('Unknown model: `{}`'.format(args.model))

        state = cell.zero_state(args.batch_size, tf.float32)
        self.state_list = [state]  # For debugging
        step_outputs = []
        for step in range(args.seq_length):
            # The cell sees the image and the label input concatenated on the
            # feature axis.
            cell_input = tf.concat([self.x_image[:, step, :], self.x_label[:, step, :]], axis=1)
            output, state = cell(cell_input, state)
            # Shared projection from the cell output to output_dim units.
            with tf.variable_scope("o2o", reuse=(step > 0)):
                o2o_w = tf.get_variable('o2o_w', [output.get_shape()[1], args.output_dim],
                                        initializer=tf.random_uniform_initializer(minval=-0.1, maxval=0.1))
                o2o_b = tf.get_variable('o2o_b', [args.output_dim],
                                        initializer=tf.random_uniform_initializer(minval=-0.1, maxval=0.1))
                output = tf.nn.xw_plus_b(output, o2o_w, o2o_b)
            if args.label_type == 'one_hot':
                output = tf.nn.softmax(output, dim=1)
            else:
                # five_hot: one softmax per group of units, stacked on a new axis 1.
                output = tf.stack([tf.nn.softmax(o) for o in tf.split(output, 5, axis=1)], axis=1)
            step_outputs.append(output)
            self.state_list.append(state)
        self.o = tf.stack(step_outputs, axis=1)
        # NOTE(review): the final state was already appended inside the loop, so
        # it appears twice in state_list.  Kept as-is in case callers index the
        # list by position.
        self.state_list.append(state)

        eps = 1e-8  # keeps tf.log away from log(0)
        if args.label_type == 'one_hot':
            self.loss = -tf.reduce_mean(tf.reduce_sum(self.y * tf.log(self.o + eps), axis=[1, 2]))
        else:
            # Regroup the targets the same way the predictions were grouped.
            grouped_targets = tf.stack(tf.split(self.y, 5, axis=2), axis=2)
            self.loss = -tf.reduce_mean(
                tf.reduce_sum(grouped_targets * tf.log(self.o + eps), axis=[1, 2, 3]))
        # Flatten any per-group axis so self.o is [batch, seq, -1] for consumers.
        self.o = tf.reshape(self.o, shape=[args.batch_size, args.seq_length, -1])
        self.loss_summary = tf.summary.scalar('Loss', self.loss)

        with tf.variable_scope('optimizer'):
            self.optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
            self.train_op = self.optimizer.minimize(self.loss)
Exemplo n.º 3
0
    def __init__(self, args):
        """Build the TF1 graph for the LSTM / NTM / MANN / ACT / DNC variants.

        Creates the placeholders, instantiates the cell named by
        ``args.model``, unrolls it for ``args.seq_length`` steps, and defines
        the cross-entropy loss (plus ponder and memory-divergence penalties
        for the ACT-based models) and the Adam training op.

        Args:
            args: configuration namespace.  Attributes read here:
                ``label_type``, ``n_classes`` (one_hot only), ``batch_size``,
                ``seq_length``, ``image_width``, ``image_height``, ``model``
                plus the hyper-parameters of the selected cell
                (``clip_value`` for the DNC-based ones, ``divergence_loss``
                for MY-ACT), and ``learning_rate``.  ``args.output_dim`` is
                written as a side effect.

        Raises:
            ValueError: if ``label_type`` is not 'one_hot' or 'five_hot'.
                (Previously this surfaced later as an AttributeError on
                ``self.learning_loss``.)
            Exception: if ``model`` is not one of the supported names.
        """
        # The width of the label / prediction vector depends on the encoding.
        if args.label_type == 'one_hot':
            args.output_dim = args.n_classes
        elif args.label_type == 'five_hot':
            # 25 units that are later split into 5 independent softmax groups.
            args.output_dim = 25
        else:
            raise ValueError('Unknown label_type: `{}`'.format(args.label_type))

        self.x_image = tf.placeholder(
            dtype=tf.float32,
            shape=[args.batch_size, args.seq_length,
                   args.image_width * args.image_height])
        label_shape = [args.batch_size, args.seq_length, args.output_dim]
        self.x_label = tf.placeholder(dtype=tf.float32, shape=label_shape)
        self.y = tf.placeholder(dtype=tf.float32, shape=label_shape)

        def dnc_configs():
            """Return (access_config, controller_config) for the main DNC."""
            access = {
                "memory_size": args.memory_size,
                "word_size": args.memory_vector_dim,
                "num_reads": args.read_head_num,
                "num_writes": args.write_head_num,
            }
            controller = {
                "hidden_size": args.rnn_size,
            }
            return access, controller

        if args.model == 'LSTM':
            cell = tf.nn.rnn_cell.MultiRNNCell(
                [tf.nn.rnn_cell.BasicLSTMCell(args.rnn_size)
                 for _ in range(args.rnn_num_layers)])
        elif args.model == 'NTM':
            import ntm.ntm_cell as ntm_cell
            cell = ntm_cell.NTMCell(args.rnn_size,
                                    args.memory_size,
                                    args.memory_vector_dim,
                                    read_head_num=args.read_head_num,
                                    write_head_num=args.write_head_num,
                                    addressing_mode='content_and_location',
                                    output_dim=args.output_dim)
        elif args.model == 'MANN':
            import ntm.mann_cell as mann_cell
            cell = mann_cell.MANNCell(args.rnn_size,
                                      args.memory_size,
                                      args.memory_vector_dim,
                                      head_num=args.read_head_num)
        elif args.model == 'MANN2':
            import ntm.mann_cell_2 as mann_cell
            cell = mann_cell.MANNCell(args.rnn_size,
                                      args.memory_size,
                                      args.memory_vector_dim,
                                      head_num=args.read_head_num)
        elif args.model == 'ACT':
            from tf_rnn_adaptive.act_wrapper import ACTWrapper

            # A single LSTM layer is wrapped (rnn_num_layers is not used here).
            inner_cell = tf.nn.rnn_cell.BasicLSTMCell(args.rnn_size)
            cell = ACTWrapper(inner_cell, ponder_limit=10)
        elif args.model == 'DNC':
            from dnc import dnc
            access_config, controller_config = dnc_configs()
            clip_value = args.clip_value
            cell = dnc.DNC(access_config, controller_config, args.output_dim,
                           clip_value)
        elif args.model == 'ACT-DNC':
            from tf_rnn_adaptive.act_wrapper import ACTWrapper
            from dnc import dnc
            access_config, controller_config = dnc_configs()
            clip_value = args.clip_value

            dnc_core = dnc.DNC(access_config, controller_config,
                               args.output_dim, clip_value)
            cell = ACTWrapper(dnc_core, ponder_limit=10)
        elif args.model == 'MY-ACT':
            from my_act.act_wrapper import ACTWrapper
            from dnc import dnc

            # Main DNC.  The two modules are deliberately built in two separate
            # entries of the 'act_wrapper' scope, exactly as before.
            with tf.variable_scope('act_wrapper'):
                access_config, controller_config = dnc_configs()
                clip_value = args.clip_value
                main_dnc = dnc.DNC(access_config,
                                   controller_config,
                                   args.output_dim,
                                   clip_value,
                                   name='main_dnc')

            # Auxiliary DNC: uses memory_vector_dim as its output size so that
            # no information is lost when communicating with the main DNC.
            with tf.variable_scope('act_wrapper'):
                aux_access_config = {
                    "memory_size": 8,
                    "word_size": args.memory_vector_dim,
                    "num_reads": 1,
                    "num_writes": 1
                }
                aux_controller_config = {
                    "hidden_size": args.rnn_size,
                }
                aux_dnc = dnc.DNC(aux_access_config,
                                  aux_controller_config,
                                  args.memory_vector_dim,
                                  clip_value,
                                  name='aux_dnc')

            cell = ACTWrapper(main_dnc,
                              aux_dnc,
                              ponder_limit=10,
                              divergence_type=args.divergence_loss)
        else:
            raise Exception('Unknown model: `{}`'.format(args.model))

        # The DNC-based models take their initial state from the DNC module;
        # every other cell uses zero_state.
        if args.model == 'ACT-DNC':
            state = dnc_core.initial_state(args.batch_size)
        elif args.model == 'DNC':
            state = cell.initial_state(args.batch_size)
        elif args.model == 'MY-ACT':
            state = main_dnc.initial_state(args.batch_size)
        else:
            state = cell.zero_state(args.batch_size, tf.float32)
        self.state_list = [state]  # For debugging
        step_outputs = []
        for step in range(args.seq_length):
            # The cell sees the image and the label input concatenated on the
            # feature axis.
            cell_input = tf.concat(
                [self.x_image[:, step, :], self.x_label[:, step, :]], axis=1)
            output, state = cell(cell_input, state)
            # Shared projection from the cell output to output_dim units.
            with tf.variable_scope("o2o", reuse=(step > 0)):
                o2o_w = tf.get_variable(
                    'o2o_w', [output.get_shape()[1], args.output_dim],
                    initializer=tf.random_uniform_initializer(minval=-0.1,
                                                              maxval=0.1))
                o2o_b = tf.get_variable(
                    'o2o_b', [args.output_dim],
                    initializer=tf.random_uniform_initializer(minval=-0.1,
                                                              maxval=0.1))
                output = tf.nn.xw_plus_b(output, o2o_w, o2o_b)
            if args.label_type == 'one_hot':
                output = tf.nn.softmax(output, dim=1)
            else:
                # five_hot: one softmax per group of units, stacked on a new axis 1.
                output = tf.stack(
                    [tf.nn.softmax(o) for o in tf.split(output, 5, axis=1)],
                    axis=1)
            step_outputs.append(output)
            self.state_list.append(state)
        self.o = tf.stack(step_outputs, axis=1)
        # NOTE(review): the final state was already appended inside the loop, so
        # it appears twice in state_list.  Kept as-is in case callers index the
        # list by position.
        self.state_list.append(state)

        eps = 1e-8  # keeps tf.log away from log(0)
        if args.label_type == 'one_hot':
            self.learning_loss = -tf.reduce_mean(  # cross entropy function
                tf.reduce_sum(self.y * tf.log(self.o + eps), axis=[1, 2]))
        else:
            # Regroup the targets the same way the predictions were grouped.
            self.learning_loss = -tf.reduce_mean(  # cross entropy function
                tf.reduce_sum(tf.stack(tf.split(self.y, 5, axis=2), axis=2) *
                              tf.log(self.o + eps),
                              axis=[1, 2, 3]))
        # Flatten any per-group axis so self.o is [batch, seq, -1] for consumers.
        self.o = tf.reshape(self.o,
                            shape=[args.batch_size, args.seq_length, -1])
        # The summary is attached before the penalties below are added, so it
        # reports the cross-entropy term only.
        self.learning_loss_summary = tf.summary.scalar('learning_loss',
                                                       self.learning_loss)

        # Ponder loss: applies to every model whose name contains 'ACT'
        # ('ACT', 'ACT-DNC', 'MY-ACT').
        if 'ACT' in args.model:
            time_penalty = 0.001
            self._ponder_loss = time_penalty * cell.get_ponder_cost(
                args.seq_length)
            self.learning_loss += self._ponder_loss
            self.ponder_steps = cell.get_ponder_steps(args.seq_length)
            self.mean_ponder_steps = tf.reduce_mean(self.ponder_steps)
        else:
            self.mean_ponder_steps = tf.constant(0)

        # Memory divergence loss: MY-ACT only, and only when a divergence type
        # was configured.
        if args.model == 'MY-ACT' and args.divergence_loss is not None:
            divergence_penalty = 0.001
            self.learning_loss += cell.get_memory_divergence_loss(
            ) * divergence_penalty

        with tf.variable_scope('optimizer'):
            self.optimizer = tf.train.AdamOptimizer(
                learning_rate=args.learning_rate)
            self.train_op = self.optimizer.minimize(self.learning_loss)
Exemplo n.º 4
0
    def __init__(self, args):
        """Build a single-step TF1 graph with loss, streaming metrics and train op.

        The whole input sequence and the label input are concatenated and fed
        to the cell once (there is no per-time-step unrolling here).

        Args:
            args: configuration namespace.  Attributes read here:
                ``batch_size``, ``seq_length``, ``output_dim`` (must already
                be set by the caller), ``model`` plus the hyper-parameters of
                the selected cell, and ``learning_rate``.

        Raises:
            ValueError: if ``model`` is not 'LSTM', 'NTM' or 'MANN'.
                (Previously this surfaced as an UnboundLocalError on ``cell``.)
        """
        # Placeholder names are part of the graph's interface; keep them
        # byte-for-byte (including the "x_squences" spelling).
        self.x_data = tf.placeholder(dtype=tf.float32,
                                     shape=[args.batch_size, args.seq_length],
                                     name="x_squences")
        self.x_label = tf.placeholder(dtype=tf.float32,
                                      shape=[args.batch_size, args.output_dim],
                                      name="x_label")
        self.y = tf.placeholder(dtype=tf.float32,
                                shape=[args.batch_size, args.output_dim],
                                name="y")

        if args.model == 'LSTM':
            cell = tf.nn.rnn_cell.MultiRNNCell(
                [tf.nn.rnn_cell.BasicLSTMCell(args.rnn_size)
                 for _ in range(args.rnn_num_layers)])
        elif args.model == 'NTM':
            import ntm.ntm_cell as ntm_cell
            cell = ntm_cell.NTMCell(args.rnn_size,
                                    args.memory_size,
                                    args.memory_vector_dim,
                                    read_head_num=args.read_head_num,
                                    write_head_num=args.write_head_num,
                                    addressing_mode='content_and_location',
                                    output_dim=args.output_dim)
        elif args.model == 'MANN':
            import ntm.mann_cell as mann_cell
            cell = mann_cell.MANNCell(rnn_size=args.rnn_size,
                                      memory_size=args.memory_size,
                                      memory_vector_dim=args.memory_vector_dim,
                                      head_num=args.read_head_num,
                                      rnn_layers=args.rnn_num_layers)
        else:
            raise ValueError('Unknown model: `{}`'.format(args.model))

        state = cell.zero_state(args.batch_size, tf.float32)
        self.state_list = [state]  # For debugging
        # Sequence and label input are joined on the feature axis.
        cell_input = tf.concat([self.x_data, self.x_label], axis=1)
        output, state = cell(cell_input, state)
        # Projection from the cell output to output_dim units.
        with tf.variable_scope("o2o", reuse=tf.AUTO_REUSE):
            o2o_w = tf.get_variable(
                'o2o_w', [output.get_shape()[1], args.output_dim],
                initializer=tf.random_uniform_initializer(minval=-0.1,
                                                          maxval=0.1))
            o2o_b = tf.get_variable('o2o_b', [args.output_dim],
                                    initializer=tf.random_uniform_initializer(
                                        minval=-0.1, maxval=0.1))
            output = tf.nn.xw_plus_b(output, o2o_w, o2o_b)

        output = tf.nn.softmax(output)
        self.state_list.append(state)
        # xw_plus_b always yields a 2-D [batch, output_dim] tensor, so the
        # former tf.squeeze was a no-op except when batch_size or output_dim is
        # 1, where it removed an axis and broke the axis-1 reductions/argmax
        # below.  tf.identity keeps the exported tensor name "output" without
        # touching the shape.
        self.o = tf.identity(output, name="output")

        eps = 1e-8  # keeps tf.log away from log(0)
        self.learning_loss = -tf.reduce_mean(  # cross entropy function
            tf.reduce_sum(self.y * tf.log(self.o + eps), axis=[1]))
        true_class = tf.argmax(self.y, 1)
        predicted_class = tf.argmax(self.o, 1)
        self.accuracy, self.acc_op = tf.metrics.accuracy(
            labels=true_class,
            predictions=predicted_class,
            name="accuracy")
        # NOTE(review): recall_at_k documents `labels` as class ids, but a cast
        # of self.y is passed here (one-hot, judging by the argmax above) --
        # confirm this is the intended label format.  k is fixed at 100.
        self.recall, self.rec_op = tf.metrics.recall_at_k(
            labels=tf.cast(self.y, tf.int64),
            predictions=self.o,
            k=100)
        # NOTE(review): tf.metrics.precision casts its inputs to bool, so class
        # id 0 presumably counts as "negative" here -- confirm this is intended.
        self.precision, self.pre_op = tf.metrics.precision(
            labels=tf.argmax(self.y, 1),
            predictions=tf.argmax(self.o, 1),
            name="precision")

        tf.summary.scalar('learning_loss', self.learning_loss)
        tf.summary.scalar('Accuracy', self.accuracy)
        tf.summary.scalar('Recall_k', self.recall)
        tf.summary.scalar('precision', self.precision)
        self.merged_summary_op = tf.summary.merge_all()

        with tf.variable_scope('optimizer'):
            self.optimizer = tf.train.AdamOptimizer(
                learning_rate=args.learning_rate)
            self.train_op = self.optimizer.minimize(self.learning_loss,
                                                    name="train_op")
Exemplo n.º 5
0
    def __init__(self, args):
        """Build the TF1 graph: placeholders, cell, unrolled loop, loss, train op.

        Args:
            args: configuration namespace.  Attributes read here:
                ``label_type``, ``n_classes`` (one_hot only), ``batch_size``,
                ``seq_length``, ``image_width``, ``image_height``, ``model``
                plus the hyper-parameters of the selected cell, and
                ``learning_rate``.  ``args.output_dim`` is written as a side
                effect.

        Raises:
            ValueError: if ``label_type`` or ``model`` is not one of the
                values handled below.  (Previously these cases surfaced later
                as an unrelated AttributeError / UnboundLocalError.)
        """
        # The width of the label / prediction vector depends on the encoding.
        if args.label_type == 'one_hot':
            args.output_dim = args.n_classes
        elif args.label_type == 'five_hot':
            # 25 units that are later split into 5 independent softmax groups.
            args.output_dim = 25
        else:
            raise ValueError('Unknown label_type: `{}`'.format(args.label_type))

        # Flattened images, plus two label-shaped inputs: x_label is fed to the
        # cell together with the image, y is the training target.
        self.x_image = tf.placeholder(
            dtype=tf.float32,
            shape=[args.batch_size, args.seq_length,
                   args.image_width * args.image_height])
        label_shape = [args.batch_size, args.seq_length, args.output_dim]
        self.x_label = tf.placeholder(dtype=tf.float32, shape=label_shape)
        self.y = tf.placeholder(dtype=tf.float32, shape=label_shape)

        if args.model == 'LSTM':
            cell = tf.nn.rnn_cell.MultiRNNCell(
                [tf.nn.rnn_cell.BasicLSTMCell(args.rnn_size)
                 for _ in range(args.rnn_num_layers)])
        elif args.model == 'NTM':
            import ntm.ntm_cell as ntm_cell
            cell = ntm_cell.NTMCell(args.rnn_size,
                                    args.memory_size,
                                    args.memory_vector_dim,
                                    read_head_num=args.read_head_num,
                                    write_head_num=args.write_head_num,
                                    addressing_mode='content_and_location',
                                    output_dim=args.output_dim)
        elif args.model == 'MANN':
            import ntm.mann_cell as mann_cell
            cell = mann_cell.MANNCell(args.rnn_size,
                                      args.memory_size,
                                      args.memory_vector_dim,
                                      head_num=args.read_head_num)
        elif args.model == 'MANN2':
            import ntm.mann_cell_2 as mann_cell
            cell = mann_cell.MANNCell(args.rnn_size,
                                      args.memory_size,
                                      args.memory_vector_dim,
                                      head_num=args.read_head_num)
        else:
            raise ValueError('Unknown model: `{}`'.format(args.model))

        # Start from the cell's zero state.
        state = cell.zero_state(args.batch_size, tf.float32)
        self.state_list = [state]  # For debugging. Keep track of previous states
        step_outputs = []
        for step in range(args.seq_length):
            # The cell sees the image and the label input concatenated on the
            # feature axis.
            cell_input = tf.concat(
                [self.x_image[:, step, :], self.x_label[:, step, :]], axis=1)
            output, state = cell(cell_input, state)
            # Project the cell output to the number of classes / predictions.
            # The projection weights are shared across time steps.
            with tf.variable_scope("o2o", reuse=(step > 0)):
                o2o_w = tf.get_variable(
                    'o2o_w', [output.get_shape()[1], args.output_dim],
                    initializer=tf.random_uniform_initializer(minval=-0.1,
                                                              maxval=0.1))
                o2o_b = tf.get_variable(
                    'o2o_b', [args.output_dim],
                    initializer=tf.random_uniform_initializer(minval=-0.1,
                                                              maxval=0.1))
                output = tf.nn.xw_plus_b(output, o2o_w, o2o_b)
            if args.label_type == 'one_hot':
                output = tf.nn.softmax(output, dim=1)
            else:
                # five_hot: one softmax per group of units, stacked on a new axis 1.
                output = tf.stack(
                    [tf.nn.softmax(o) for o in tf.split(output, 5, axis=1)],
                    axis=1)
            step_outputs.append(output)
            self.state_list.append(state)
        self.o = tf.stack(step_outputs, axis=1)
        # NOTE(review): the final state was already appended inside the loop, so
        # it appears twice in state_list.  Kept as-is in case callers index the
        # list by position.
        self.state_list.append(state)

        eps = 1e-8  # keeps tf.log away from log(0)
        if args.label_type == 'one_hot':
            self.learning_loss = -tf.reduce_mean(  # cross entropy function
                tf.reduce_sum(self.y * tf.log(self.o + eps), axis=[1, 2]))
        else:
            # Regroup the targets the same way the predictions were grouped.
            self.learning_loss = -tf.reduce_mean(  # cross entropy function
                tf.reduce_sum(tf.stack(tf.split(self.y, 5, axis=2), axis=2) *
                              tf.log(self.o + eps),
                              axis=[1, 2, 3]))
        # Flatten any per-group axis so self.o is [batch, seq, -1] for consumers.
        self.o = tf.reshape(self.o,
                            shape=[args.batch_size, args.seq_length, -1])
        self.learning_loss_summary = tf.summary.scalar('learning_loss',
                                                       self.learning_loss)

        with tf.variable_scope('optimizer'):
            self.optimizer = tf.train.AdamOptimizer(
                learning_rate=args.learning_rate)
            self.train_op = self.optimizer.minimize(self.learning_loss)
Exemplo n.º 6
0
    def __init__(self, args):
        """Build the graph: placeholders, recurrent cell, unrolled loop, loss.

        The cell is unrolled for ``args.seq_length`` steps. At each step the
        cell input is the image slice concatenated with the label slice of
        the same step, and the cell output is passed through a shared linear
        layer followed by a softmax.

        Args:
            args: Configuration object. Always read: ``label_type``,
                ``model``, ``batch_size``, ``seq_length``, ``image_width``,
                ``image_height``, ``rnn_size`` and ``learning_rate``.
                Read depending on the branch taken: ``n_classes``
                (``'one_hot'``), ``rnn_num_layers`` (``'LSTM'``),
                ``memory_size``, ``memory_vector_dim``, ``read_head_num``
                (memory models) and ``write_head_num`` (``'NTM'``).
                ``args.output_dim`` is overwritten as a side effect.

        Raises:
            ValueError: If ``args.label_type`` is not ``'one_hot'`` or
                ``'five_hot'``, or ``args.model`` is not one of ``'LSTM'``,
                ``'NTM'``, ``'MANN'``, ``'MANN2'``.
        """
        if args.label_type == 'one_hot':
            args.output_dim = args.n_classes
        elif args.label_type == 'five_hot':
            # 25 outputs are later split into 5 equal groups, each with its
            # own softmax (see the per-step split below).
            args.output_dim = 25
        else:
            # Fail here rather than later on a missing `self.learning_loss`.
            raise ValueError(
                "Unsupported label_type: {!r} "
                "(expected 'one_hot' or 'five_hot')".format(args.label_type))

        self.x_image = tf.placeholder(
            dtype=tf.float32,
            shape=[
                args.batch_size, args.seq_length,
                args.image_width * args.image_height
            ])
        self.x_label = tf.placeholder(
            dtype=tf.float32,
            shape=[args.batch_size, args.seq_length, args.output_dim])
        self.y = tf.placeholder(
            dtype=tf.float32,
            shape=[args.batch_size, args.seq_length, args.output_dim])

        # The memory-cell modules are imported lazily so that only the
        # selected model's module has to be importable.
        if args.model == 'LSTM':

            def rnn_cell(rnn_size):
                return tf.nn.rnn_cell.BasicLSTMCell(rnn_size)

            cell = tf.nn.rnn_cell.MultiRNNCell(
                [rnn_cell(args.rnn_size) for _ in range(args.rnn_num_layers)])
        elif args.model == 'NTM':
            import ntm.ntm_cell as ntm_cell
            cell = ntm_cell.NTMCell(args.rnn_size,
                                    args.memory_size,
                                    args.memory_vector_dim,
                                    read_head_num=args.read_head_num,
                                    write_head_num=args.write_head_num,
                                    addressing_mode='content_and_location',
                                    output_dim=args.output_dim)
        elif args.model == 'MANN':
            # Only a head count is passed here; there is no separate
            # write-head argument as in the NTM branch.
            import ntm.mann_cell as mann_cell
            cell = mann_cell.MANNCell(args.rnn_size,
                                      args.memory_size,
                                      args.memory_vector_dim,
                                      head_num=args.read_head_num)
        elif args.model == 'MANN2':
            import ntm.mann_cell_2 as mann_cell
            cell = mann_cell.MANNCell(args.rnn_size,
                                      args.memory_size,
                                      args.memory_vector_dim,
                                      head_num=args.read_head_num)
        else:
            # Previously this fell through and crashed on the unbound `cell`.
            raise ValueError(
                "Unsupported model: {!r} (expected 'LSTM', 'NTM', 'MANN' "
                "or 'MANN2')".format(args.model))

        state = cell.zero_state(args.batch_size, tf.float32)
        # Every state is kept (initial state first) so it can be inspected.
        self.state_list = [state]
        self.o = []
        for t in range(args.seq_length):
            # NOTE(review): the label fed at step t is x_label[:, t, :];
            # presumably the caller supplies labels already shifted by one
            # step relative to the images -- confirm against the data loader.
            step_input = tf.concat(
                [self.x_image[:, t, :], self.x_label[:, t, :]], axis=1)
            output, state = cell(step_input, state)

            # One linear output layer shared across all time steps: the
            # variables are created at t == 0 and reused afterwards.
            with tf.variable_scope("o2o", reuse=(t > 0)):
                o2o_w = tf.get_variable(
                    'o2o_w', [output.get_shape()[1], args.output_dim],
                    initializer=tf.random_uniform_initializer(minval=-0.1,
                                                              maxval=0.1))
                o2o_b = tf.get_variable(
                    'o2o_b', [args.output_dim],
                    initializer=tf.random_uniform_initializer(minval=-0.1,
                                                              maxval=0.1))
                output = tf.nn.xw_plus_b(output, o2o_w, o2o_b)

            if args.label_type == 'one_hot':
                output = tf.nn.softmax(output, dim=1)
            elif args.label_type == 'five_hot':
                # Independent softmax over each of the 5 groups, stacked
                # into a (batch, 5, group_size) tensor.
                output = tf.stack(
                    [tf.nn.softmax(o) for o in tf.split(output, 5, axis=1)],
                    axis=1)
            self.o.append(output)
            self.state_list.append(state)
        # Stack the per-step outputs along a new time axis (axis 1).
        self.o = tf.stack(self.o, axis=1)
        # NOTE(review): the final state was already appended by the last
        # loop iteration, so it appears twice in `state_list`. Kept as-is
        # because callers may rely on the list length -- confirm.
        self.state_list.append(state)

        # Cross-entropy summed over time (and groups), averaged over the
        # batch. `eps` keeps log() finite when a probability is exactly 0.
        eps = 1e-8
        if args.label_type == 'one_hot':
            self.learning_loss = -tf.reduce_mean(
                tf.reduce_sum(self.y * tf.log(self.o + eps), axis=[1, 2]))
        elif args.label_type == 'five_hot':
            # Regroup the flat target into the same 5-way grouping as the
            # stacked softmax outputs before multiplying.
            self.learning_loss = -tf.reduce_mean(
                tf.reduce_sum(tf.stack(tf.split(self.y, 5, axis=2), axis=2) *
                              tf.log(self.o + eps),
                              axis=[1, 2, 3]))
        # Flatten any per-group axes so `self.o` is always
        # (batch_size, seq_length, -1), matching the shape of `self.y`.
        self.o = tf.reshape(self.o,
                            shape=[args.batch_size, args.seq_length, -1])
        self.learning_loss_summary = tf.summary.scalar('learning_loss',
                                                       self.learning_loss)

        with tf.variable_scope('optimizer'):
            # Plain Adam on the loss; no gradient clipping is applied.
            self.optimizer = tf.train.AdamOptimizer(
                learning_rate=args.learning_rate)
            self.train_op = self.optimizer.minimize(self.learning_loss)