def __init__(self, config, output_size, clip_value=None, name='dnc'):
    """Initializes the DNC core.

    Args:
      config: configuration object. Its `memory_size`, `word_size`,
        `num_reads` and `num_writes` attributes are forwarded to the
        `MemoryAccess` module, and the object is kept as `self._config`.
      output_size: output dimension size of core; stored as
        `tf.TensorShape([output_size])`.
      clip_value: clips controller and core output values to between
        `[-clip_value, clip_value]` if specified. Any falsy value
        (including the default `None`) is stored as 0; presumably 0 means
        "no clipping" -- confirm against the code that applies the clip.
      name: module name (default 'dnc').
    """
    super(DNC, self).__init__(name=name)
    self._config = config
    self._access = MemoryAccess(
        memory_size=self._config.memory_size,
        word_size=self._config.word_size,
        num_reads=self._config.num_reads,
        num_writes=self._config.num_writes)
    # Total number of elements in the access module's output: the product
    # of all dimensions of `self._access.output_size`.
    self._access_output_size = np.prod(self._access.output_size.as_list())
    self._clip_value = clip_value or 0
    self._output_size = tf.TensorShape([output_size])
    # The state size covers only the access output and the access state;
    # no controller state is part of this core's declared state size.
    self._state_size = DNCState(
        access_output=self._access_output_size,
        access_state=self._access.state_size)
def zero_state(config, batch_size, dtype=np.float32):
    """Builds an all-zero NumPy `DNCState` for a batch of `batch_size`.

    Only `access_output` and `access_state` are populated here.

    Args:
      config: object exposing `num_reads` and `word_size`; it is also passed
        to `MemoryAccess.zero_state`.
      batch_size: leading dimension of the zero arrays.
      dtype: NumPy dtype of the arrays (default `np.float32`).

    Returns:
      A `DNCState` whose `access_output` is zeros of shape
      `(batch_size, num_reads, word_size)` and whose `access_state` is
      whatever `MemoryAccess.zero_state` returns for the same arguments.
    """
    read_dims = [batch_size, config.num_reads, config.word_size]
    zero_reads = np.zeros(read_dims, dtype=dtype)
    zero_memory = MemoryAccess.zero_state(config, batch_size, dtype)
    return DNCState(access_output=zero_reads, access_state=zero_memory)
def state_placeholder(config, dtype=tf.float32):
    """Creates graph placeholders mirroring a full DNC state.

    Every placeholder created here has an unspecified (`None`) leading
    dimension, so any batch size can be fed.

    Args:
      config: object exposing `num_reads`, `word_size` and
        `controller_h_size`; it is also passed to
        `MemoryAccess.state_placeholder`.
      dtype: element type of the placeholders (default `tf.float32`).

    Returns:
      A `DNCState` with:
        access_output: placeholder of shape `(None, num_reads, word_size)`.
        access_state: result of `MemoryAccess.state_placeholder`.
        controller_state: `LSTMStateTuple` of two placeholders, each of
          shape `(None, controller_h_size)`.
    """
    # Keep this creation order (reads, memory, then the two LSTM parts):
    # unnamed TF1 placeholders receive sequential default names.
    read_ph = tf.placeholder(
        dtype, shape=(None, config.num_reads, config.word_size))
    memory_ph = MemoryAccess.state_placeholder(config, dtype)
    cell_ph = tf.placeholder(dtype, shape=(None, config.controller_h_size))
    hidden_ph = tf.placeholder(dtype, shape=(None, config.controller_h_size))
    lstm_ph = tf.nn.rnn_cell.LSTMStateTuple(cell_ph, hidden_ph)
    return DNCState(
        access_output=read_ph,
        access_state=memory_ph,
        controller_state=lstm_ph,
    )
def state_placeholder(config, dtype=tf.float32):
    """Creates graph placeholders mirroring a DNC state without controller.

    Only `access_output` and `access_state` are populated here.

    Args:
      config: object exposing `num_reads` and `word_size`; it is also passed
        to `MemoryAccess.state_placeholder`.
      dtype: element type of the placeholders (default `tf.float32`).

    Returns:
      A `DNCState` whose `access_output` is a placeholder of shape
      `(None, num_reads, word_size)` (batch dimension left open) and whose
      `access_state` comes from `MemoryAccess.state_placeholder`.
    """
    # The read placeholder is created before the memory placeholders, as
    # unnamed TF1 placeholders receive sequential default names.
    reads_shape = (None, config.num_reads, config.word_size)
    read_ph = tf.placeholder(dtype, shape=reads_shape)
    memory_ph = MemoryAccess.state_placeholder(config, dtype)
    return DNCState(access_output=read_ph, access_state=memory_ph)
def zero_state(config, batch_size, dtype=np.float32):
    """Builds an all-zero NumPy `DNCState`, including the controller state.

    Args:
      config: object exposing `num_reads`, `word_size` and
        `controller_h_size`; it is also passed to `MemoryAccess.zero_state`.
      batch_size: leading dimension of every zero array.
      dtype: NumPy dtype of the arrays (default `np.float32`).

    Returns:
      A `DNCState` with:
        access_output: zeros of shape `(batch_size, num_reads, word_size)`.
        access_state: result of `MemoryAccess.zero_state`.
        controller_state: `LSTMStateTuple` of two zero arrays, each of shape
          `(batch_size, controller_h_size)`.
    """
    reads = np.zeros(
        [batch_size, config.num_reads, config.word_size], dtype=dtype)
    memory = MemoryAccess.zero_state(config, batch_size, dtype)
    hidden_dims = [batch_size, config.controller_h_size]
    cell = np.zeros(hidden_dims, dtype=dtype)
    hidden = np.zeros(hidden_dims, dtype=dtype)
    controller = tf.nn.rnn_cell.LSTMStateTuple(cell, hidden)
    return DNCState(
        access_output=reads,
        access_state=memory,
        controller_state=controller,
    )