Example #1
    def build_graph(self):
        """
        Builds the computational graph that evaluates the input data batches
        step by step.
        """

        self.unpacked_input_data = utility.unpack_into_tensorarray(
            self.input_data, 1, self.sequence_length)

        outputs = tf.TensorArray(tf.float32, self.sequence_length)
        free_gates = tf.TensorArray(tf.float32, self.sequence_length)
        allocation_gates = tf.TensorArray(tf.float32, self.sequence_length)
        write_gates = tf.TensorArray(tf.float32, self.sequence_length)
        read_weightings = tf.TensorArray(tf.float32, self.sequence_length)
        write_weightings = tf.TensorArray(tf.float32, self.sequence_length)
        usage_vectors = tf.TensorArray(tf.float32, self.sequence_length)

        if self.controller.has_recurrent_nn:
            controller_state = self.controller.get_state()
        else:
            controller_state = (tf.zeros(1), tf.zeros(1))
        memory_state = self.memory.init_memory()
        if not isinstance(controller_state, LSTMStateTuple):
            controller_state = LSTMStateTuple(controller_state[0],
                                              controller_state[1])
        final_results = None

        with tf.variable_scope("sequence_loop") as scope:
            time = tf.constant(0, dtype=tf.int32)

            final_results = tf.while_loop(
                cond=lambda time, *_: time < self.sequence_length,
                body=self._loop_body,
                loop_vars=(time, memory_state, outputs, free_gates,
                           allocation_gates, write_gates, read_weightings,
                           write_weightings, usage_vectors, controller_state),
                parallel_iterations=32,
                swap_memory=True)

        # final_results follows the loop_vars order: (time, memory_state,
        # outputs, free_gates, allocation_gates, write_gates, read_weightings,
        # write_weightings, usage_vectors, controller_state)
        dependencies = []
        if self.controller.has_recurrent_nn:
            dependencies.append(self.controller.update_state(final_results[9]))

        with tf.control_dependencies(dependencies):
            self.packed_output = utility.pack_into_tensor(final_results[2],
                                                          axis=1)
            self.packed_memory_view = {
                'free_gates': utility.pack_into_tensor(final_results[3], axis=1),
                'allocation_gates': utility.pack_into_tensor(final_results[4], axis=1),
                'write_gates': utility.pack_into_tensor(final_results[5], axis=1),
                'read_weightings': utility.pack_into_tensor(final_results[6], axis=1),
                'write_weightings': utility.pack_into_tensor(final_results[7], axis=1),
                'usage_vectors': utility.pack_into_tensor(final_results[8], axis=1)
            }
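
The utility.unpack_into_tensorarray / utility.pack_into_tensor helpers are not part of this snippet. As a rough standalone sketch of the same TensorArray pattern (the shapes and the per-step op below are assumptions chosen only to make the example runnable, not the project's actual helpers):

import tensorflow as tf

# Standalone sketch of the unpack -> while_loop -> pack pattern used above.
seq_len = 5
sequence = tf.placeholder(tf.float32, shape=[None, seq_len, 3])  # [batch, time, features]

# "unpack" along the time axis (axis=1) into a TensorArray of per-step tensors
input_ta = tf.TensorArray(tf.float32, seq_len)
input_ta = input_ta.unstack(tf.transpose(sequence, [1, 0, 2]))   # time-major entries

output_ta = tf.TensorArray(tf.float32, seq_len)

def body(t, out):
    # dummy per-step computation standing in for _loop_body / _step_op
    return t + 1, out.write(t, input_ta.read(t) * 2.0)

_, output_ta = tf.while_loop(lambda t, *_: t < seq_len, body,
                             (tf.constant(0), output_ta))

# "pack" the per-step results back into a [batch, time, features] tensor
packed = tf.transpose(output_ta.stack(), [1, 0, 2])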
Example #2
    def _loop_body(self, time, memory_state, outputs, free_gates,
                   allocation_gates, write_gates, read_weightings,
                   write_weightings, usage_vectors, controller_state):
        """
        the body of the DNC sequence processing loop

        Parameters:
        ----------
        time: Tensor
        outputs: TensorArray
        memory_state: Tuple
        free_gates: TensorArray
        allocation_gates: TensorArray
        write_gates: TensorArray
        read_weightings: TensorArray,
        write_weightings: TensorArray,
        usage_vectors: TensorArray,
        controller_state: Tuple

        Returns: Tuple containing all updated arguments
        """

        step_input = self.unpacked_input_data.read(time)

        output_list = self._step_op(step_input, memory_state, controller_state)

        # update memory parameters and the controller's recurrent state
        new_memory_state = tuple(output_list[0:7])
        new_controller_state = LSTMStateTuple(output_list[11], output_list[12])

        outputs = outputs.write(time, output_list[7])

        # collecting memory view for the current step
        free_gates = free_gates.write(time, output_list[8])
        allocation_gates = allocation_gates.write(time, output_list[9])
        write_gates = write_gates.write(time, output_list[10])
        read_weightings = read_weightings.write(time, output_list[5])
        write_weightings = write_weightings.write(time, output_list[4])
        usage_vectors = usage_vectors.write(time, output_list[1])

        return (time + 1, new_memory_state, outputs, free_gates,
                allocation_gates, write_gates, read_weightings,
                write_weightings, usage_vectors, new_controller_state)
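
_step_op itself is not shown here. Judging purely from the indices used in _loop_body above, its return list appears to be laid out as follows (an inference from this snippet, not a documented interface):

# Inferred layout of output_list, based only on how _loop_body indexes it:
#   output_list[0:7]   -> updated memory state tuple (becomes new_memory_state);
#                         within it, [1] usage vector, [4] write weighting and
#                         [5] read weightings are also logged separately
#   output_list[7]     -> the step's output (written to `outputs`)
#   output_list[8]     -> free gates
#   output_list[9]     -> allocation gate
#   output_list[10]    -> write gate
#   output_list[11:13] -> controller LSTM output and state (new_controller_state)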
Example #3
    def __call__(self, inputs, state, scope=None):
        output, new_state = self._cell(inputs, state, scope)

        def train():
            cell_update = nn_ops.dropout(
                state[0], self._cell_out_prob,
                seed=self._seed) + nn_ops.dropout(
                    new_state[0], 1 - self._cell_out_prob, seed=self._seed)
            state_update = nn_ops.dropout(
                state[1], self._state_out_prob,
                seed=self._seed) + nn_ops.dropout(
                    new_state[1], 1 - self._state_out_prob, seed=self._seed)
            return cell_update, state_update

        def test():
            cell_update = state[0] * self._cell_out_prob + new_state[0] * (
                1 - self._cell_out_prob)
            state_update = state[1] * self._state_out_prob + new_state[1] * (
                1 - self._state_out_prob)
            return cell_update, state_update

        cell_update, state_update = tf.cond(self._training, train, test)
        new_state_update = LSTMStateTuple(cell_update, state_update)
        return output, new_state_update
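
The wrapper above mixes the previous and the newly computed LSTM state: at test time it interpolates them deterministically with weights _cell_out_prob / _state_out_prob, while at training time it applies two dropout masks instead. A tiny numpy sketch of the two branches for a single state vector (the probability and shape below are made up for illustration):

import numpy as np

keep_prob = 0.7                                   # stands in for self._cell_out_prob
prev = np.ones(5, dtype=np.float32)               # previous cell state
new = np.full(5, 3.0, dtype=np.float32)           # freshly computed cell state

# test-time branch: deterministic interpolation of old and new state
test_update = prev * keep_prob + new * (1.0 - keep_prob)

# train-time branch: dropout-style stochastic mix; like tf.nn.dropout, each
# element survives with the given probability and survivors are rescaled
mask_prev = (np.random.rand(5) < keep_prob).astype(np.float32) / keep_prob
mask_new = (np.random.rand(5) < 1.0 - keep_prob).astype(np.float32) / (1.0 - keep_prob)
train_update = prev * mask_prev + new * mask_new

print(test_update)
print(train_update)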
Example #4
    def get_state(self):
        # returns an empty (zero) recurrent state
        return LSTMStateTuple(tf.zeros(1), tf.zeros(1))

    def get_state(self):
        # returns the controller's current output and cell state
        return LSTMStateTuple(self.output, self.state)
Example #6
    def __call__(self, inputs, state, scope=None):
        num_proj = self._num_units if self._num_proj is None else self._num_proj

        if self._state_is_tuple:
            (c_prev, m_prev) = state
        else:
            c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
            m_prev = array_ops.slice(state, [0, self._num_units],
                                     [-1, num_proj])

        dtype = inputs.dtype
        input_size = inputs.get_shape().with_rank(2)[1]
        if input_size.value is None:
            raise ValueError(
                "Could not infer input size from inputs.get_shape()[-1]")
        with vs.variable_scope(scope or "lstm_cell",
                               initializer=self._initializer) as unit_scope:
            #concat_w = _get_concat_variable("W", [input_size.value + num_proj, 4 * self._num_units], dtype, self._num_unit_shards)
            W_xh = vs.get_variable(
                "W_xh",
                shape=[input_size.value, 4 * self._num_units],
                dtype=dtype)
            W_hh = vs.get_variable(
                "W_hh",
                shape=[self._num_units, 4 * self._num_units],
                dtype=dtype)
            xh = math_ops.matmul(inputs, W_xh)
            hh = math_ops.matmul(m_prev, W_hh)
            # batch-normalize only the input-to-hidden contribution
            bn_xh = my_batch_norm(xh, self._training, recurrent=True)
            lstm_matrix = bn_xh + hh
            # i = input gate, j = new input, f = forget gate, o = output gate
            i, j, f, o = array_ops.split(lstm_matrix, 4, 1)
            # Diagonal connections
            if self._use_peepholes:
                with vs.variable_scope(unit_scope) as projection_scope:
                    if self._num_unit_shards is not None:
                        projection_scope.set_partitioner(None)
                    w_f_diag = vs.get_variable("w_f_diag",
                                               shape=[self._num_units],
                                               dtype=dtype)
                    w_i_diag = vs.get_variable("w_i_diag",
                                               shape=[self._num_units],
                                               dtype=dtype)
                    w_o_diag = vs.get_variable("w_o_diag",
                                               shape=[self._num_units],
                                               dtype=dtype)

            if self._use_peepholes:
                c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) *
                     c_prev +
                     sigmoid(i + w_i_diag * c_prev) * self._activation(j))
            else:
                c = (sigmoid(f + self._forget_bias) * c_prev +
                     sigmoid(i) * self._activation(j))

            if self._cell_clip is not None:
                c = clip_ops.clip_by_value(c, -self._cell_clip,
                                           self._cell_clip)

            if self._use_peepholes:
                m = sigmoid(o + w_o_diag * c) * self._activation(c)
            else:
                m = sigmoid(o) * self._activation(c)

            if self._num_proj is not None:
                with vs.variable_scope("projection") as proj_scope:
                    if self._num_proj_shards is not None:
                        proj_scope.set_partitioner(
                            partitioned_variables.fixed_size_partitioner(
                                self._num_proj_shards))
                    m = _linear(m, self._num_proj, bias=False, scope=scope)
                if self._proj_clip is not None:
                    m = clip_ops.clip_by_value(m, -self._proj_clip,
                                               self._proj_clip)
        new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else
                     array_ops.concat([c, m], 1))
        return m, new_state
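
my_batch_norm is not defined anywhere in this snippet. For orientation only, a minimal helper in this style could look like the sketch below; the real helper presumably also tracks population statistics for inference (via the training flag), which is omitted here, and the small initial gain follows common recurrent batch-norm practice rather than anything confirmed by this code.

import tensorflow as tf

def my_batch_norm_sketch(x, training, recurrent=False, epsilon=1e-5):
    # Illustrative stand-in for the undefined my_batch_norm above, NOT the
    # author's implementation. `training` is accepted for signature parity but
    # unused: a real helper would switch to population statistics at test time.
    size = x.get_shape()[-1].value
    with tf.variable_scope("batch_norm"):
        # small initial gain (0.1) as often suggested for recurrent batch norm
        gamma = tf.get_variable("gamma", [size],
                                initializer=tf.constant_initializer(0.1))
        # following common BN-LSTM practice, omit the learned shift for
        # recurrent terms
        beta = None if recurrent else tf.get_variable(
            "beta", [size], initializer=tf.zeros_initializer())
        mean, variance = tf.nn.moments(x, [0])
        return tf.nn.batch_normalization(x, mean, variance, beta, gamma, epsilon)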