def read(self, prev_interface_tuple, mem):
    """
    Read from the external memory using the read attention stored in the
    interface tuple.

    :param prev_interface_tuple: Tuple [previous read, previous write, prev usage, prev links];
        only the read attention (first element) is used.
    :param mem: the memory [batch_size, content_size, memory_size]
    :return: the read data, with the last two dimensions of the attention
        read merged into a single trailing dimension of size ``self.read_size``
        (presumably [batch_size, read_size] — confirm against callers).
    """
    read_attention, _, _, _ = prev_interface_tuple
    raw_read = Memory(mem).attention_read(read_attention)

    # Keep every leading dimension and collapse the trailing two into one
    # vector of length self.read_size.
    leading_dims = raw_read.size()[:-2]
    return raw_read.view(*leading_dims, self.read_size)
def edit_memory(self, interface_tuple, update_data, mem):
    """
    Edit the external memory (erase, then add) and return its new content.

    :param interface_tuple: Tuple [read attention, write attention, usage, links];
        only the write attention (second element) is used here.
    :param update_data: the parameters from the controllers [dictionary].
        Reads 'write_vectors' and 'erase_vectors', plus 'write_gate' when
        ``self.use_extra_write_gate`` is set.
    :param mem: the memory [batch_size, content_size, memory_size]
    :return: edited memory [batch_size, content_size, memory_size]
    """
    (_, write_attention, _, _) = interface_tuple

    add = update_data['write_vectors']
    erase = update_data['erase_vectors']

    # Optionally scale both the add and erase vectors by the write gate.
    # The gate is only looked up when it is actually used, so controllers
    # that do not produce it work when the extra gate is disabled.
    if self.use_extra_write_gate:
        write_gate = update_data['write_gate']
        add = add * write_gate
        erase = erase * write_gate

    memory = Memory(mem)
    # Erase is applied before add, both weighted by the write attention.
    memory.erase_weighted(erase, write_attention)
    memory.add_weighted(add, write_attention)
    return memory.content
def update_write(self, update_data, prev_interface_tuple, mem):
    """
    Update the write attention, switching between the NTM and DNC mechanisms.

    :param update_data: the parameters from the controllers [dictionary]
    :param prev_interface_tuple: Tuple [previous read, previous write, prev usage, prev links]
    :param mem: the memory of the previous step; wrapped in ``Memory`` here
    :return: a new ``InterfaceStateTuple`` with an updated usage and write
        attention; the read attention and links are carried over unchanged.
    """
    (prev_read_attention, prev_write_attention, prev_usage,
     prev_links) = prev_interface_tuple

    # Content-addressing parameters, used by both mechanisms.
    key = update_data['write_content_keys']
    strength = update_data['write_content_strengths']
    # Passed as the gate to update_weight in NTM mode and as the allocation
    # gate to update_write_weight in DNC mode.
    gate = update_data['allocation_gate']

    memory = Memory(mem)

    free_gate = update_data['free_gate']
    # Usage is recomputed in both modes and stored in the returned tuple.
    usage = self.mem_usage.calculate_usage(
        prev_write_attention, free_gate, prev_read_attention, prev_usage)

    if self.use_ntm_write:
        # NTM write: shift and sharpening parameters for location addressing.
        shift = update_data['shifts']
        sharp = update_data['sharpening']
        write_attention = self.update_weight(
            prev_write_attention, memory, strength, gate, key, shift, sharp)
    else:
        # DNC write: weighting computed from the updated usage, the
        # allocation gate and the write gate.
        write_gate = update_data['write_gate']
        write_attention = self.update_write_weight(
            usage, memory, gate, write_gate, key, strength)

    return InterfaceStateTuple(prev_read_attention, write_attention, usage,
                               prev_links)
def update_read(self, update_data, prev_interface_tuple, mem):
    """
    Update the read attention, switching between the NTM and DNC mechanisms.

    :param update_data: the parameters from the controllers [dictionary]
    :param prev_interface_tuple: Tuple [previous read, previous write, prev usage, prev links]
    :param mem: the memory; wrapped in ``Memory`` here
    :return: a new ``InterfaceStateTuple`` with an updated read attention and
        links; the write attention and usage are carried over unchanged.
        In NTM mode the links are carried over unchanged as well.
    """
    (prev_read_attention, prev_write_attention, prev_usage,
     prev_links) = prev_interface_tuple

    # Content-addressing parameters, used by both mechanisms.
    key = update_data['read_content_keys']
    strength = update_data['read_content_strengths']

    memory = Memory(mem)

    if self.use_ntm_read:
        # NTM read: shift/sharpening parameters plus a gate for update_weight.
        shift = update_data['shifts_read']
        sharp = update_data['sharpening_read']
        gate = update_data['read_mode_shift']
        read_attention = self.update_weight(
            prev_read_attention, memory, strength, gate, key, shift, sharp)
        links = prev_links
    else:
        # DNC read: refresh the temporal links from the write attention in
        # the incoming tuple, then compute the read weighting from them.
        read_mode = update_data['read_mode']
        links = self.temporal_linkage.calc_temporal_links(
            prev_write_attention, prev_links)
        read_attention = self.update_read_weight(
            links, memory, prev_read_attention, read_mode, key, strength)

    return InterfaceStateTuple(read_attention, prev_write_attention,
                               prev_usage, links)