示例#1
0
    def __init__(self,
                 access_config,
                 controller_config,
                 output_size,
                 name='components'):
        """Creates the sub-modules used by the DNC.

        Builds a two-layer stacked LSTM controller, a memory-access module, a
        bias-free linear layer and, when `FLAGS.is_input_embedder` is set, a
        64-unit linear+tanh input embedder.

        Args:
          access_config: dictionary of keyword arguments forwarded to
              `access.MemoryAccess`.
          controller_config: dictionary of keyword arguments forwarded to each
              `snt.LSTM` layer of the controller.
          output_size: number of output units of `self.output_linear`.
          name: module name (default 'components').
        """
        super(Components, self).__init__(name=name)

        # All sub-modules are constructed inside this module's variable scope.
        with self._enter_variable_scope():
            # Two stacked LSTM layers, each built from the same config.
            self.controller = snt.DeepRNN(
                [snt.LSTM(**controller_config) for _ in range(2)])
            self.access = access.MemoryAccess(**access_config)

            # Linear layer without bias producing `output_size` units.
            self.output_linear = snt.Linear(output_size=output_size,
                                            use_bias=False)

            # Optional embedder, enabled by a global command-line flag.
            if FLAGS.is_input_embedder:
                self.input_embedder = snt.Sequential(
                    [snt.Linear(output_size=64, use_bias=True), tf.nn.tanh])
示例#2
0
文件: dnc.py 项目: qoffee/dnc-1
  def __init__(self,
               access_config,
               controller_config,
               output_size,
               clip_value=None,
               name='dnc'):
    """Initializes the DNC core.

    Args:
      access_config: dictionary of keyword arguments forwarded to
          `access.MemoryAccess`.
      controller_config: dictionary of keyword arguments forwarded to the
          controller `snt.LSTM`.
      output_size: output dimension size of core.
      clip_value: clips controller and core output values to between
          `[-clip_value, clip_value]` if specified.
      name: module name (default 'dnc').
    """
    super(DNC, self).__init__(name=name)

    # Controller and memory access are created inside this module's scope.
    with self._enter_variable_scope():
      self._controller = snt.LSTM(**controller_config)
      self._access = access.MemoryAccess(**access_config)

    # Total number of elements in one access output (product of its dims).
    self._access_output_size = np.prod(self._access.output_size.as_list())
    # A falsy `clip_value` (e.g. None) is normalized to 0; presumably the
    # clipping code treats 0 as "no clipping" -- confirm in the caller.
    self._clip_value = clip_value or 0

    self._output_size = tf.TensorShape([output_size])
    self._state_size = DNCState(
        access_output=self._access_output_size,
        access_state=self._access.state_size,
        controller_state=self._controller.state_size)
示例#3
0
    def __init__(self,
                 access_config,
                 controller_config,
                 output_size,
                 name='components'):
        """Creates the sub-modules used by the DNC.

        Builds a `FLAGS.lstm_depth`-layer stacked LSTM controller with skip
        connections, a memory-access module and a bias-free linear layer.
        Global flags optionally add:
          * `FLAGS.is_input_embedder`: a 64-unit linear+tanh input embedder
            stored as `self.input_embedder`.
          * `FLAGS.is_variable_initial_states`: trainable, zero-initialized
            float32 variables mirroring the structure of the controller's
            batch-size-1 initial state, stored as
            `self.initial_controller_state`.

        Args:
          access_config: dictionary of keyword arguments forwarded to
              `access.MemoryAccess`.
          controller_config: dictionary of keyword arguments forwarded to each
              `snt.LSTM` layer of the controller.
          output_size: number of output units of `self.output_linear`.
          name: module name (default 'components').
        """
        super(Components, self).__init__(name=name)

        # All sub-modules are constructed inside this module's variable scope.
        with self._enter_variable_scope():
            self.controller = snt.DeepRNN(
                [snt.LSTM(**controller_config)
                 for _ in range(FLAGS.lstm_depth)],
                skip_connections=True)
            self.access = access.MemoryAccess(**access_config)

            # Linear layer without bias producing `output_size` units.
            self.output_linear = snt.Linear(output_size=output_size,
                                            use_bias=False)

            if FLAGS.is_input_embedder:
                self.input_embedder = snt.Sequential(
                    [snt.Linear(output_size=64, use_bias=True), tf.nn.tanh])

            if FLAGS.is_variable_initial_states:

                def trainable_zeros_like(tensor):
                    """Returns a trainable float32 zeros variable.

                    The variable has the static shape of `tensor` with its
                    leading dimension forced to 1.
                    """
                    shape = tensor.get_shape().as_list()
                    return tf.Variable(
                        initial_value=tf.zeros(shape=[1] + shape[1:]),
                        dtype=tf.float32,
                        trainable=True)

                # Build one learnable variable per tensor in the (possibly
                # nested) batch-1 initial controller state, keeping structure.
                self.initial_controller_state = (
                    tf.contrib.framework.nest.map_structure(
                        trainable_zeros_like,
                        self.controller.initial_state(1, tf.float32)))
示例#4
0
 def setUp(self):
     # Test fixture: a MemoryAccess module sized by the module-level test
     # constants, plus its initial state for a batch of BATCH_SIZE.
     memory_access = access.MemoryAccess(
         MEMORY_SIZE, WORD_SIZE, NUM_READS, NUM_WRITES)
     self.module = memory_access
     self.initial_state = memory_access.initial_state(BATCH_SIZE)