示例#1
0
    def read(self, prev_interface_tuple, mem):
        """
        Read from the external memory using the read attention held in the
        interface tuple.

        :param prev_interface_tuple: Tuple [previous read, previous write, prev usage, prev links];
            only the read attention (first element) is used here.
        :param mem: the memory [batch_size, content_size, memory_size]
        :return: the result of ``Memory.attention_read`` with its last two
            dimensions collapsed into a single axis of length
            ``self.read_size`` (all leading dimensions are kept).

        """
        read_weights, _, _, _ = prev_interface_tuple

        read_vectors = Memory(mem).attention_read(read_weights)

        # Keep every leading dimension and merge the trailing two into one
        # axis of length self.read_size.
        # NOTE(review): assumes the product of the last two sizes equals
        # self.read_size -- confirm against the configuration.
        leading_dims = read_vectors.size()[:-2]
        return read_vectors.view(*leading_dims, self.read_size)
示例#2
0
    def edit_memory(self, interface_tuple, update_data, mem):
        """
        Apply an erase-then-add write to the external memory and return the
        resulting memory content.

        :param interface_tuple: Tuple [read, write, usage, links]; only the
            write attention (second element) is used here.
        :param update_data: the parameters from the controllers [dictionary];
            the keys 'write_gate', 'write_vectors' and 'erase_vectors' are read.
        :param mem: the memory [batch_size, content_size, memory_size]
        :return: edited memory [batch_size, content_size, memory_size]

        """
        _, write_weights, _, _ = interface_tuple

        # 'write_gate' is looked up unconditionally, even when it is unused.
        gate_value = update_data['write_gate']
        add_vectors = update_data['write_vectors']
        erase_vectors = update_data['erase_vectors']

        # When the extra write gate is enabled, both the add and the erase
        # vectors are scaled by it before being written.
        if self.use_extra_write_gate:
            add_vectors, erase_vectors = (add_vectors * gate_value,
                                          erase_vectors * gate_value)

        memory = Memory(mem)

        # Erase is applied before add, and both use the same write attention.
        memory.erase_weighted(erase_vectors, write_weights)
        memory.add_weighted(add_vectors, write_weights)

        return memory.content
示例#3
0
    def update_write(self, update_data, prev_interface_tuple, mem):
        """
        Compute a new usage vector and a new write attention. The NTM write
        mechanism is used when ``self.use_ntm_write`` is truthy, otherwise
        the DNC mechanism is used.

        :param update_data: the parameters from the controllers [dictionary]
        :param prev_interface_tuple: Tuple [previous read, previous write, prev usage, prev links]
        :param mem: the memory [batch_size, content_size, memory_size]
        :return: a new InterfaceStateTuple holding the updated usage and write
            attention; the read attention and links are passed through
            unchanged.

        """
        read_w, write_w, usage_prev, links_prev = prev_interface_tuple

        # Content-addressing parameters, plus the gate shared by both modes.
        content_key = update_data['write_content_keys']
        content_strength = update_data['write_content_strengths']
        interp_gate = update_data['allocation_gate']

        memory = Memory(mem)

        # Usage is recomputed in both modes so that it is always carried forward.
        free_gates = update_data['free_gate']
        new_usage = self.mem_usage.calculate_usage(write_w, free_gates,
                                                   read_w, usage_prev)

        if not self.use_ntm_write:
            # DNC path: the allocation gate and write gate are combined with
            # the freshly computed usage and content addressing.
            dnc_write_gate = update_data['write_gate']
            new_write_w = self.update_write_weight(new_usage, memory,
                                                   interp_gate, dnc_write_gate,
                                                   content_key,
                                                   content_strength)
        else:
            # NTM path: starts from the previous write attention and uses
            # shift and sharpening parameters.
            shift = update_data['shifts']
            sharpening = update_data['sharpening']
            new_write_w = self.update_weight(write_w, memory, content_strength,
                                             interp_gate, content_key, shift,
                                             sharpening)

        return InterfaceStateTuple(read_w, new_write_w, new_usage, links_prev)
示例#4
0
    def update_read(self, update_data, prev_interface_tuple, mem):
        """
        Compute a new read attention. The NTM read mechanism is used when
        ``self.use_ntm_read`` is truthy, otherwise the DNC mechanism is used.

        :param update_data: the parameters from the controllers [dictionary]
        :param prev_interface_tuple: Tuple [previous read, previous write, prev usage, prev links]
        :param mem: the memory [batch_size, content_size, memory_size]
        :return: a new InterfaceStateTuple holding the updated read attention
            and links; the write attention and usage are passed through
            unchanged (in NTM mode the links are passed through as well).

        """
        read_w, write_w, usage_prev, links_prev = prev_interface_tuple

        # Content-addressing parameters used by both mechanisms.
        content_key = update_data['read_content_keys']
        content_strength = update_data['read_content_strengths']

        memory = Memory(mem)

        if not self.use_ntm_read:
            # DNC path: refresh the temporal links from the write attention
            # in the incoming tuple, then compute the read weight from the
            # links, content addressing and the read mode.
            mode = update_data['read_mode']
            new_links = self.temporal_linkage.calc_temporal_links(
                write_w, links_prev)
            new_read_w = self.update_read_weight(new_links, memory, read_w,
                                                 mode, content_key,
                                                 content_strength)
        else:
            # NTM path: starts from the previous read attention and uses a
            # gate, shift and sharpening parameters; links are left untouched.
            shift = update_data['shifts_read']
            sharpening = update_data['sharpening_read']
            interp_gate = update_data['read_mode_shift']
            new_read_w = self.update_weight(read_w, memory, content_strength,
                                            interp_gate, content_key, shift,
                                            sharpening)
            new_links = links_prev

        return InterfaceStateTuple(new_read_w, write_w, usage_prev, new_links)