Example #1
 def default_net(self):
     # See dataloader documentation
     net = EncapsulatedNTM(self.params.sequence_width + 2, self.params.sequence_width + 1,
                           self.params.controller_size, self.params.controller_layers,
                           self.params.num_heads,
                           self.params.memory_n, self.params.memory_m,
                           self.cuda,
                           self.time)
     if self.cuda and torch.cuda.is_available():
         net.cuda()
     return net
Example #2
 def default_net(self):
     # We have 1 additional input for the delimiter which is passed on a
     # separate "control" channel
     net = EncapsulatedNTM(self.params.sequence_width + 1,
                           self.params.sequence_width,
                           self.params.controller_size,
                           self.params.controller_layers,
                           self.params.num_heads, self.params.memory_n,
                           self.params.memory_m, self.cuda, self.time)
     if self.cuda and torch.cuda.is_available():
         net.cuda()
     return net
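
The comment above explains the "+ 1" on the input width: the payload occupies sequence_width bits, and one extra "control" column carries a delimiter that is set only once the input ends, which is why the output width stays at sequence_width. A minimal sketch of how such a copy-task batch might be built; the exact layout is an assumption, not taken from the project's dataloader:

import torch

def make_copy_batch(batch_size, seq_len, sequence_width):
    # Random binary payload of shape (seq_len, batch, sequence_width).
    seq = torch.bernoulli(torch.full((seq_len, batch_size, sequence_width), 0.5))

    # One extra column is the "control" channel: zero during the payload,
    # one on a single trailing step that marks end-of-input.
    inp = torch.zeros(seq_len + 1, batch_size, sequence_width + 1)
    inp[:seq_len, :, :sequence_width] = seq
    inp[seq_len, :, sequence_width] = 1.0

    # The target is the payload alone, without the control channel.
    return inp, seq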
Example #3
 def default_net(self):
     # See dataloader documentation
     net = EncapsulatedNTM(self.params.sequence_width + 2, self.params.sequence_width + 1,
                           self.params.controller_size, self.params.controller_layers,
                           self.params.num_heads,
                           self.params.memory_n, self.params.memory_m)
     return net
Example #4
 def default_net(self):
     # Unlike the variants above, input and output widths are equal here:
     # no extra delimiter channel is added on top of sequence_width.
     net = EncapsulatedNTM(self.params.sequence_width, self.params.sequence_width,
                           self.params.controller_size, self.params.controller_layers,
                           self.params.num_heads,
                           self.params.memory_n, self.params.memory_m).to(self.params.device)
     return net
Example #5
class NTMEncoder(nn.Module):
    def __init__(self,
                 src_vocab_size,
                 d_word_vec,
                 d_ntm_input,
                 d_controller,
                 d_sent_enc,
                 n_layers,
                 n_heads,
                 n_slots,
                 m_depth,
                 embedding_matrix=None):
        super().__init__()
        if embedding_matrix is None:
            self.src_word_emb = nn.Embedding(src_vocab_size,
                                             d_word_vec,
                                             padding_idx=Constants.PAD)
        else:
            self.src_word_emb = nn.Embedding.from_pretrained(
                torch.from_numpy(embedding_matrix).type(
                    torch.cuda.FloatTensor),
                freeze=True)

            print("set pretrained word embeddings, size {}".format(
                self.src_word_emb.weight.size()))

        self.linear = nn.Linear(d_word_vec, d_ntm_input)
        self.ntm = EncapsulatedNTM(d_ntm_input, d_sent_enc, d_controller,
                                   n_layers, n_heads, n_slots, m_depth)

    def forward(self, src_seq, src_pos):
        """
        :param src_seq: source token index list
        :param src_pos: not used.
        :return:
        """
        batch_size, seq_len = src_seq.shape
        self.ntm.init_sequence(batch_size)

        embs = self.src_word_emb(src_seq)
        # Pass the training flag so dropout is disabled in eval mode.
        ntm_input = self.linear(F.dropout(embs, p=0.3, training=self.training))

        # Feed one embedded token per step; only the last step's output is
        # kept and returned (as a 1-tuple) as the sentence encoding.
        for t in range(seq_len):
            output, state = self.ntm(ntm_input[:, t, :])

        return output,
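
The forward method fixes the calling convention: init_sequence resets the memory once per batch, embedded tokens are fed one timestep at a time, and only the last step's output survives as the sentence encoding, returned as a 1-tuple. A minimal usage sketch; every size below is invented for illustration:

import torch

encoder = NTMEncoder(src_vocab_size=10000, d_word_vec=300, d_ntm_input=256,
                     d_controller=512, d_sent_enc=256, n_layers=1,
                     n_heads=1, n_slots=128, m_depth=64)

src_seq = torch.randint(1, 10000, (32, 20))  # (batch, seq_len) of token ids
sent_enc, = encoder(src_seq, src_pos=None)   # src_pos is ignored
print(sent_enc.shape)                        # expected: (32, d_sent_enc)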
Example #6
def init_model(args):

    params = Parameters()
    params = update_model_params(params, args.param)

    LOGGER.info(params)

    model = EncapsulatedNTM(params.image_size + params.classes, params.classes,
                            params.controller_size, params.controller_layers,
                            params.num_read_heads, params.num_write_heads,
                            params.memory_n, params.memory_m)

    # Can be collected by checkpoint:
    losses, accuracies = [], []
    accuracy_dict = {1: [], 2: [], 5: [], 10: []}
    start_epoch = 0

    ### LOADING PREVIOUS NETWORK ###
    if args.load_checkpoint:
        if os.path.isfile(args.load_checkpoint):
            print("=> loading checkpoint '{}'".format(args.load_checkpoint))
            checkpoint = torch.load(args.load_checkpoint)
            start_epoch = checkpoint['epoch']
            episode = checkpoint['episode']
            accuracy_dict = checkpoint['accuracy']
            accuracies = checkpoint['accuracies']
            losses = checkpoint['losses']
            # The model above was built with the fresh params; the checkpoint's
            # parameters must describe the same architecture for the
            # load_state_dict call below to succeed.
            params = checkpoint['parameters']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.load_checkpoint, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(
                args.load_checkpoint))

    return model, params, losses, accuracy_dict, accuracies, start_epoch
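
The load path above fixes the checkpoint schema: whatever writes the file must use the same keys that init_model reads back. A minimal sketch of a matching save routine; the function name and argument list are assumptions:

import torch

def save_checkpoint(path, model, params, epoch, episode,
                    losses, accuracies, accuracy_dict):
    # Keys mirror exactly what init_model reads back out of torch.load.
    torch.save({
        'epoch': epoch,
        'episode': episode,
        'accuracy': accuracy_dict,
        'accuracies': accuracies,
        'losses': losses,
        'parameters': params,
        'state_dict': model.state_dict(),
    }, path)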
Example #7
train_loader = dataloader(TOTAL_BATCHES, BATCH_SIZE, BYTE_WIDTH,
                          SEQUENCE_MIN_LEN, SEQUENCE_MAX_LEN)

criterion = nn.BCELoss()

#%%
'''
Train EncapsulatedNTM
Use LSTM Controller
'''

model = EncapsulatedNTM(BYTE_WIDTH + 1,
                        BYTE_WIDTH,
                        controller_size,
                        controller_layers,
                        num_heads,
                        memory_n,
                        memory_m,
                        controller_type='lstm')

optimizer = optim.RMSprop(model.parameters(),
                          momentum=rmsprop_momentum,
                          alpha=rmsprop_alpha,
                          lr=rmsprop_lr)
print('Total params of Model EncapsulatedNTM with LSTM controller:',
      model.calculate_num_params())
list_loss, list_cost, list_seq_num = train_model(model, criterion,
                                                 optimizer, train_loader)

saveCheckpoint(model, list_seq_num, list_loss, list_cost, path='ntm1')
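
train_model itself is project code that is not shown here. A single optimization step consistent with this setup (BCELoss on the outputs, RMSprop, and the init_sequence/per-step interface seen in Example #5) might look like the sketch below; the no-input read phase and the clipping value are assumptions:

import torch

def train_step(model, criterion, optimizer, inp, target):
    # inp: (inp_len, batch, BYTE_WIDTH + 1); target: (out_len, batch, BYTE_WIDTH)
    optimizer.zero_grad()
    model.init_sequence(inp.size(1))

    # Present the whole input sequence, one timestep at a time.
    for t in range(inp.size(0)):
        model(inp[t])

    # Read the answer back; assumes the model accepts a no-input step and
    # that its outputs already lie in (0, 1), as nn.BCELoss requires.
    outputs = [model()[0] for _ in range(target.size(0))]

    loss = criterion(torch.stack(outputs), target)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 10)  # assumed value
    optimizer.step()
    return loss.item()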