def __init__(self, access_config, controller_config, output_size,
             clip_value=None, name='dnc'):
  """Initializes the DNC core.

  Builds the LSTM controller and memory-access sub-modules inside this
  module's variable scope, then records the flattened output/state sizes
  used by the core's `_build`.

  Args:
    access_config: dictionary of access module configurations, passed as
        keyword arguments to `access.MemoryAccess`.
    controller_config: dictionary of controller (LSTM) module
        configurations, passed as keyword arguments to `snt.LSTM`.
    output_size: output dimension size of core.
    clip_value: clips controller and core output values to between
        `[-clip_value, clip_value]` if specified; `None` (or 0) disables
        clipping.
    name: module name (default 'dnc').

  Raises:
    TypeError: if direct_input_size is not None for any access module
        other than KeyValueMemory.
  """
  super(DNC, self).__init__(name=name)

  with self._enter_variable_scope():
    # Create sub-modules inside the scope so their variables are
    # namespaced under `name`.
    self._controller = snt.LSTM(**controller_config)
    self._access = access.MemoryAccess(**access_config)

  # Flattened size of the access module's read output.
  self._access_output_size = np.prod(self._access.output_size.as_list())
  # `clip_value or 0` maps None to 0, which disables clipping downstream.
  self._clip_value = clip_value or 0

  # NOTE(review): the original assigned `self._output_size = output_size`
  # and then immediately overwrote it with the TensorShape below; the dead
  # first assignment has been removed.
  self._output_size = tf.TensorShape([output_size])
  self._state_size = DNCState(
      access_output=self._access_output_size,
      access_state=self._access.state_size,
      controller_state=self._controller.state_size)
def __init__(self, access_config, controller_config, output_size,
             clip_value=None, dropout=0.0, mode=None, batch_size=None,
             name='dnc'):
  """Initializes the DNC core with a multi-layer LSTM controller.

  Args:
    access_config: dictionary of access module configurations, passed as
        keyword arguments to `access.MemoryAccess`. Must contain
        'memory_size' and 'word_size' (used below to build the flattened
        state-size spec).
    controller_config: dictionary of controller configurations; must
        contain 'num_units' and 'num_layers' for the stacked LSTM.
    output_size: output dimension size of core.
    clip_value: clips controller and core output values to between
        `[-clip_value, clip_value]` if specified; `None` (or 0) disables
        clipping.
    dropout: input dropout rate applied to each controller LSTM cell;
        only active when `mode` is `tf.contrib.learn.ModeKeys.TRAIN`.
    mode: a `tf.contrib.learn.ModeKeys` value; any value other than TRAIN
        disables dropout.
    batch_size: batch size, stored on the instance for use by callers;
        not used directly in this constructor.
    name: module name (default 'dnc').

  Raises:
    TypeError: if direct_input_size is not None for any access module
        other than KeyValueMemory.
  """
  super(DNC, self).__init__(name=name)
  # Dropout is a training-only regularizer; force it off otherwise.
  self.dropout = dropout if mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0
  self.access_config = access_config

  with self._enter_variable_scope():

    def single_cell(num_units):
      """Builds one LSTM cell, wrapped with input dropout when training."""
      cell = tf.contrib.rnn.BasicLSTMCell(num_units, forget_bias=1.0)
      if self.dropout > 0.0:
        cell = tf.contrib.rnn.DropoutWrapper(
            cell=cell, input_keep_prob=(1.0 - self.dropout))
      return cell

    self._controller = tf.contrib.rnn.MultiRNNCell([
        single_cell(controller_config['num_units'])
        for _ in range(controller_config['num_layers'])
    ])
    self._access = access.MemoryAccess(**access_config)

  # Flattened size of the access module's read output.
  self._access_output_size = np.prod(self._access.output_size.as_list())
  # `clip_value or 0` maps None to 0, which disables clipping downstream.
  self._clip_value = clip_value or 0
  self._output_size = tf.TensorShape([output_size])

  self.batch_size = batch_size
  # State-size spec with memory and linkage represented as flat vectors
  # (memory_size * word_size and memory_size**2 entries respectively),
  # unlike the access module's own (matrix-shaped) `state_size`.
  memory_size = self.access_config['memory_size']
  word_size = self.access_config['word_size']
  self._state_size = DNCState(
      access_output=tf.TensorShape((word_size)),
      access_state=access.AccessState(
          memory=tf.TensorShape((memory_size * word_size)),
          read_weights=tf.TensorShape((memory_size)),
          write_weights=tf.TensorShape((memory_size)),
          linkage=TemporalLinkageState(
              link=tf.TensorShape((memory_size * memory_size)),
              precedence_weights=tf.TensorShape((memory_size))),
          usage=self._access.state_size.usage),
      controller_state=self._controller.state_size)
def setUp(self):
  """Builds the MemoryAccess module under test and its initial state."""
  module = access.MemoryAccess(
      MEMORY_SIZE, WORD_SIZE, NUM_READS, NUM_WRITES)
  self.module = module
  self.initial_state = module.initial_state(BATCH_SIZE)