Example #1
0
File: ntm.py Project: clemkoa/ntm
 def __init__(self, vector_length, hidden_size, memory_size,
              lstm_controller=True):
     """Assemble the NTM: controller, external memory, one read head,
     one write head, and a linear read-out layer.

     The controller input is the task vector plus one delimiter bit,
     concatenated with the previous memory read of width memory_size[1].
     """
     super(NTM, self).__init__()
     controller_input_width = vector_length + 1 + memory_size[1]
     self.controller = Controller(lstm_controller, controller_input_width,
                                  hidden_size)
     self.memory = Memory(memory_size)
     # Both heads share a reference to the same memory object.
     self.read_head = ReadHead(self.memory, hidden_size)
     self.write_head = WriteHead(self.memory, hidden_size)
     # Read-out maps [controller output, read vector] -> output vector.
     self.fc = nn.Linear(hidden_size + memory_size[1], vector_length)
     nn.init.xavier_uniform_(self.fc.weight, gain=1)
     nn.init.normal_(self.fc.bias, std=0.01)
Example #2
0
def model(input_var, batch_size=1, size=1, num_units=100, memory_shape=(128, 20)):
    """Build an NTM network and return ``(output_layer, ntm_layer)``.

    The input is (batch, time, size + 1); the extra channel carries the
    delimiter bit. The output has the same shape, squashed by a sigmoid.
    """
    # Input layer; sequence length is left symbolic (None) and recovered
    # from the input variable itself for the final reshape.
    l_input = InputLayer((batch_size, None, size + 1), input_var=input_var)
    _, seqlen, _ = l_input.input_var.shape

    # External memory, initialised to a small constant and not learned.
    memory = Memory(memory_shape, name='memory',
                    memory_init=lasagne.init.Constant(1e-6),
                    learn_init=False)

    # Dense (feed-forward) controller with a single read.
    controller = DenseController(
        l_input, memory_shape=memory_shape, num_units=num_units,
        num_reads=1, nonlinearity=lasagne.nonlinearities.rectify,
        name='controller')

    write_head = WriteHead(
        controller, num_shifts=3, memory_shape=memory_shape, name='write',
        learn_init=False,
        nonlinearity_key=lasagne.nonlinearities.rectify,
        nonlinearity_add=lasagne.nonlinearities.rectify)
    read_head = ReadHead(
        controller, num_shifts=3, memory_shape=memory_shape, name='read',
        learn_init=False,
        nonlinearity_key=lasagne.nonlinearities.rectify)

    l_ntm = NTMLayer(l_input, memory=memory, controller=controller,
                     heads=[write_head, read_head])

    # Collapse time into the batch axis, apply the dense read-out,
    # then restore the (batch, time, features) shape.
    l_flat = ReshapeLayer(l_ntm, (-1, num_units))
    l_dense = DenseLayer(l_flat, num_units=size + 1,
                         nonlinearity=lasagne.nonlinearities.sigmoid,
                         name='dense')
    l_output = ReshapeLayer(l_dense, (batch_size, seqlen, size + 1))

    return l_output, l_ntm
Example #3
0
File: ntm.py Project: clemkoa/ntm
class NTM(nn.Module):
    """Neural Turing Machine: a controller coupled to an external memory
    through one read head and one write head."""

    def __init__(self,
                 vector_length,
                 hidden_size,
                 memory_size,
                 lstm_controller=True):
        """
        Args:
            vector_length: width of the input/output bit vectors.
            hidden_size: controller hidden-state size.
            memory_size: (rows, cols) shape of the external memory.
            lstm_controller: use an LSTM controller if True.
        """
        super(NTM, self).__init__()
        # Controller sees the input (plus one delimiter bit) concatenated
        # with the previous read vector of width memory_size[1].
        self.controller = Controller(lstm_controller,
                                     vector_length + 1 + memory_size[1],
                                     hidden_size)
        self.memory = Memory(memory_size)
        self.read_head = ReadHead(self.memory, hidden_size)
        self.write_head = WriteHead(self.memory, hidden_size)
        # Read-out maps [controller output, read vector] -> output vector.
        self.fc = nn.Linear(hidden_size + memory_size[1], vector_length)
        nn.init.xavier_uniform_(self.fc.weight, gain=1)
        nn.init.normal_(self.fc.bias, std=0.01)

    def get_initial_state(self, batch_size=1):
        """Reset the memory and return the initial recurrent state tuple
        ``(read, read_head_state, write_head_state, controller_state)``."""
        self.memory.reset(batch_size)
        controller_state = self.controller.get_initial_state(batch_size)
        read = self.memory.get_initial_read(batch_size)
        read_head_state = self.read_head.get_initial_state(batch_size)
        write_head_state = self.write_head.get_initial_state(batch_size)
        return (read, read_head_state, write_head_state, controller_state)

    def forward(self, x, previous_state):
        """Run one time step.

        Args:
            x: input for this step (batched along dim 0).
            previous_state: tuple produced by ``get_initial_state`` or by
                the previous ``forward`` call.

        Returns:
            (output, state): sigmoid-squashed output and the new state tuple.
        """
        previous_read, previous_read_head_state, previous_write_head_state, previous_controller_state = previous_state
        controller_input = torch.cat([x, previous_read], dim=1)
        controller_output, controller_state = self.controller(
            controller_input, previous_controller_state)
        # Read
        read_head_output, read_head_state = self.read_head(
            controller_output, previous_read_head_state)
        # Write
        write_head_state = self.write_head(controller_output,
                                           previous_write_head_state)
        fc_input = torch.cat((controller_output, read_head_output), dim=1)
        state = (read_head_output, read_head_state, write_head_state,
                 controller_state)
        # torch.sigmoid replaces F.sigmoid, deprecated since PyTorch 0.4.1.
        return torch.sigmoid(self.fc(fc_input)), state
Example #4
0
    def call(self, tm_input, states):
        """Run one Turing-machine step.

        ``states`` packs ``[prev_output, tm_state, read weights (flat),
        write weights (flat), memory (flat)]``; the previous output is
        ignored here and only re-emitted as part of the new state.
        """
        _, tm_state, read_w_flat, write_w_flat, mem_flat = states
        # Un-flatten the recurrent tensors back to their working shapes.
        read_w = Reshape((self.n_read_heads, self.N))(read_w_flat)
        write_w = Reshape((self.n_write_heads, self.N))(write_w_flat)
        mem = Memory(Reshape((self.N, self.M))(mem_flat))

        # Write first, then read from the memory returned by the write,
        # then let the controller consume the freshly read data.
        _, write_w, mem = self.write_head.call(tm_input, tm_state,
                                               [write_w, mem])
        head_data, read_w = self.read_head.call(tm_state, [read_w, mem])
        tm_output, tm_state = self.controller.call(tm_input, tm_state,
                                                   head_data)

        # Flatten everything again so the RNN machinery can carry it
        # as a flat state list.
        flat_read = Flatten()(read_w)
        flat_write = Flatten()(write_w)
        flat_mem = Flatten()(mem.tensor)
        return tm_output, [tm_output, tm_state, flat_read, flat_write,
                           flat_mem]