Example #1
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

# DataHelper, LSTM, ver_lstm and test_lstm are project helpers defined elsewhere.


def train_lstm():
    batch_size = 100
    num_layers = 3
    num_directions = 2
    embedding_size = 100
    hidden_size = 64
    learning_rate = 0.0001
    num_epochs = 5

    data_helper = DataHelper()
    train_text, train_labels, ver_text, ver_labels, test_text, test_labels = data_helper.get_data_and_labels()
    word_set = data_helper.get_word_set()
    vocab = data_helper.get_word_dict()
    words_length = len(word_set) + 2

    lstm = LSTM(words_length, embedding_size, hidden_size, num_layers, num_directions, batch_size)
    X = [[vocab[word] for word in sentence.split(' ')] for sentence in train_text]
    X_lengths = [len(sentence) for sentence in X]
    pad_token = vocab['<PAD>']
    longest_sent = max(X_lengths)
    b_size = len(X)
    padded_X = np.ones((b_size, longest_sent)) * pad_token
    for i, x_len in enumerate(X_lengths):
        sequence = X[i]
        padded_X[i, 0:x_len] = sequence[:x_len]

    x = torch.tensor(padded_X, dtype=torch.long)
    y = torch.tensor([int(i) for i in train_labels])
    dataset = Data.TensorDataset(x, y)
    loader = Data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2
    )

    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.Adam(lstm.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        for step, (batch_x, batch_y) in enumerate(loader):
            output = lstm(batch_x)
            # Batch accuracy over the actual batch size (the last batch may be smaller).
            pred = torch.argmax(output, dim=1)
            correct = (pred == batch_y).sum().item()

            loss = loss_func(output, batch_y)
            print('epoch: {0}, step: {1}, loss: {2:.4f}, train acc: {3:.4f}'.format(
                epoch, step, loss.item(), correct / len(batch_y)))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        ver_lstm(lstm, ver_text, ver_labels, vocab, batch_size)
    test_lstm(lstm, test_text, test_labels, vocab, batch_size)
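
The manual padding loop in Example #1 reimplements what torch.nn.utils.rnn.pad_sequence already provides. A minimal sketch of the same preprocessing, reusing the vocab dictionary and train_text list from above:

import torch
from torch.nn.utils.rnn import pad_sequence

# One LongTensor of word indices per sentence.
sequences = [torch.tensor([vocab[w] for w in sentence.split(' ')], dtype=torch.long)
             for sentence in train_text]

# Right-pads every sequence to the longest one in the set, exactly what the
# np.ones((b_size, longest_sent)) * pad_token loop does by hand.
x = pad_sequence(sequences, batch_first=True, padding_value=vocab['<PAD>'])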
Example #2
print("Tokenized data")

model = LSTM(a_vocab_size=len(a_to_index),
             b_vocab_size=len(b_to_index),
             padding_index=0,
             lstms_in_out=((5, 5), (5, 5)),
             linear_layers=(10, 5),
             out_size=1,
             hidden_activation=nn.ReLU,
             final_activation=None)
print("Model loaded.")
learningRate = 0.01
epochs = 50
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learningRate)
batch_size = 100
print("Starting training...")
stats = StatsManager("exp1.0000")

for epoch in range(epochs):
    random.shuffle(data)
    for batch in range(len(data) // batch_size - 1):
        print(".", end='')
        # Normalize the two input columns of this mini-batch
        batch_rows = data[batch * batch_size:(batch + 1) * batch_size]
        a_normalized, a_len = normalize([row[0] for row in batch_rows])
        b_normalized, b_len = normalize([row[1] for row in batch_rows])
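        # ------------------------------------------------------------------
        # The excerpt breaks off here. The lines below are only a hypothetical
        # sketch of how such a loop usually continues; the target column
        # (row[2]) and the model's two-input forward signature are assumptions,
        # not part of the original code.
        # ------------------------------------------------------------------
        targets = torch.tensor([row[2] for row in batch_rows], dtype=torch.float32)

        optimizer.zero_grad()                          # clear gradients from the previous step
        outputs = model(a_normalized, b_normalized)    # assumed forward signature
        loss = criterion(outputs.squeeze(-1), targets)
        loss.backward()
        optimizer.step()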
Example #3
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True  # CUDA only
    )

    # Network Definition + Optimizer + Scheduler
    model = LSTM(hidden_size=n_hidden1,
                 hidden_size2=n_hidden2,
                 num_securities=n_stocks,
                 dropout=0.2,
                 n_layers=2,
                 T=T)
    if use_cuda:
        model.cuda()
    optimizer = optim.RMSprop(model.parameters(),
                              lr=learning_rate,
                              weight_decay=0.0)
    scheduler_model = lr_scheduler.StepLR(optimizer, step_size=1, gamma=1.0)

    # loss function
    criterion = nn.MSELoss(reduction='mean')
    if use_cuda:
        criterion = criterion.cuda()
    # Store successive losses
    losses = []
    it = 0
    for i in range(max_epochs):
        loss_ = 0.
        # Store current predictions
        predicted = []
        gt = []
        # Go through training data set
Example #4
if os.path.isfile(checkpoint_file):
    print("Loading checkpoint...")
    lstm.load_state_dict(torch.load(checkpoint_file))

if use_cuda:
    lstm.cuda()

lstm.hidden = lstm.init_hidden(1)

# predictions = predict_batches(x_val, lstm, use_cuda=use_cuda)
# plt.plot(predictions.numpy().flatten())
# plt.plot(y_val.numpy().flatten())
# plt.show()

optimizer = optim.Adam(lstm.parameters(), lr=lr)

best_val_loss = 1000
for epoch in range(n_epochs):
    n_batches = x_train.shape[0]
    for i in range(n_batches):
        lstm.hidden = None
        input_batches = x_train[i]
        target_batches = y_train[i]
        train_loss = train(input_batches, target_batches, lstm, optimizer,
                           use_cuda)

    epoch_train_loss = evaluate(x_train, y_train, lstm)
    epoch_val_loss = evaluate(x_val, y_val, lstm)

    print("epoch %i/%i" % (epoch + 1, n_epochs))
Example #5
class Model(object):
    def __init__(self, args, device, rel2id, word_emb=None):
        lr = args.lr
        lr_decay = args.lr_decay
        self.cpu = torch.device('cpu')
        self.device = device
        self.args = args
        self.max_grad_norm = args.max_grad_norm
        if args.model == 'pa_lstm':
            self.model = PositionAwareLSTM(args, rel2id, word_emb)
        elif args.model == 'bgru':
            self.model = BGRU(args, rel2id, word_emb)
        elif args.model == 'cnn':
            self.model = CNN(args, rel2id, word_emb)
        elif args.model == 'pcnn':
            self.model = PCNN(args, rel2id, word_emb)
        elif args.model == 'lstm':
            self.model = LSTM(args, rel2id, word_emb)
        else:
            raise ValueError
        self.model.to(device)
        self.criterion = nn.CrossEntropyLoss()
        self.parameters = [
            p for p in self.model.parameters() if p.requires_grad
        ]
        # self.parameters = self.model.parameters()
        self.optimizer = torch.optim.SGD(self.parameters, lr)

    def update(self, batch):
        inputs = [p.to(self.device) for p in batch[:-1]]
        labels = batch[-1].to(self.device)
        self.model.train()
        logits = self.model(inputs)
        loss = self.criterion(logits, labels)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters, self.max_grad_norm)
        self.optimizer.step()
        return loss.item()

    def predict(self, batch):
        inputs = [p.to(self.device) for p in batch[:-1]]
        labels = batch[-1].to(self.device)
        logits = self.model(inputs)
        loss = self.criterion(logits, labels)
        pred = torch.argmax(logits, dim=1).to(self.cpu)
        # corrects = torch.eq(pred, labels)
        # acc_cnt = torch.sum(corrects, dim=-1)
        return pred, batch[-1], loss.item()

    def eval(self, dset, vocab=None, output_false_file=None):
        rel_labels = [''] * len(dset.rel2id)
        for label, id in dset.rel2id.items():
            rel_labels[id] = label
        self.model.eval()
        pred = []
        labels = []
        loss = 0.0
        for idx, batch in enumerate(tqdm(dset.batched_data)):
            pred_b, labels_b, loss_b = self.predict(batch)
            pred += pred_b.tolist()
            labels += labels_b.tolist()
            loss += loss_b
            if output_false_file is not None and vocab is not None:
                all_words, pos, ner, subj_pos, obj_pos, labels_ = batch
                all_words = all_words.tolist()
                labels_ = labels_.tolist()
                pred_batch = pred_b.tolist()
                for i, word_ids in enumerate(all_words):
                    # Compare this batch's labels and predictions; the pred/labels
                    # lists accumulated above are indexed across all batches.
                    if labels_[i] != pred_batch[i]:
                        length = 0
                        for wid in word_ids:
                            if wid != utils.PAD_ID:
                                length += 1
                        words = [vocab[wid] for wid in word_ids[:length]]
                        sentence = ' '.join(words)

                        subj_words = []
                        for sidx in range(length):
                            if subj_pos[i][sidx] == 0:
                                subj_words.append(words[sidx])
                        subj = '_'.join(subj_words)

                        obj_words = []
                        for oidx in range(length):
                            if obj_pos[i][oidx] == 0:
                                obj_words.append(words[oidx])
                        obj = '_'.join(obj_words)

                        output_false_file.write(
                            '%s\t%s\t%s\t%s\t%s\n' %
                            (sentence, subj, obj, rel_labels[pred_batch[i]],
                             rel_labels[labels_[i]]))

        loss /= len(dset.batched_data)
        return loss, utils.eval(pred, labels)

    def save(self, filename, epoch):
        params = {
            'model': self.model.state_dict(),
            'config': self.args,
            'epoch': epoch
        }
        try:
            torch.save(params, filename)
            print("model saved to {}".format(filename))
        except BaseException:
            print("[Warning: Saving failed... continuing anyway.]")

    def load(self, filename):
        params = torch.load(filename, map_location=self.device.type)
        self.model.load_state_dict(params['model'])
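
A short driver sketch for the wrapper above. The dataset objects (train_dset, dev_dset) with a batched_data attribute and the args.num_epoch field are assumptions inferred from how eval() iterates, not part of the original code:

# Hypothetical usage; dataset objects and args fields are assumed.
model = Model(args, device, train_dset.rel2id, word_emb=word_emb)

for epoch in range(args.num_epoch):
    train_loss = 0.0
    for batch in train_dset.batched_data:
        train_loss += model.update(batch)           # one optimizer step per batch
    dev_loss, dev_metrics = model.eval(dev_dset)    # returns (loss, utils.eval(pred, labels))
    model.save('checkpoint_epoch_%d.pt' % epoch, epoch)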
Example #6
def setup(config):
    if config.task.name == 'copy':
        task = CopyTask(
            batch_size=config.task.batch_size,
            min_len=config.task.min_len,
            max_len=config.task.max_len,
            bit_width=config.task.bit_width,
            seed=config.task.seed,
        )
    elif config.task.name == 'repeat':
        task = RepeatCopyTask(
            batch_size=config.task.batch_size,
            bit_width=config.task.bit_width,
            min_len=config.task.min_len,
            max_len=config.task.max_len,
            min_rep=config.task.min_rep,
            max_rep=config.task.max_rep,
            norm_max=config.task.norm_max,
            seed=config.task.seed,
        )
    elif config.task.name == 'recall':
        task = AssociativeRecallTask(
            batch_size=config.task.batch_size,
            bit_width=config.task.bit_width,
            item_len=config.task.item_len,
            min_cnt=config.task.min_cnt,
            max_cnt=config.task.max_cnt,
            seed=config.task.seed,
        )
    else:
        logging.error('Unknown task: %s', config.task.name)
        exit(1)

    torch.manual_seed(config.model.seed)
    if config.model.name == 'lstm':
        model = LSTM(
            n_inputs=task.full_input_width,
            n_outputs=task.full_output_width,
            n_hidden=config.model.n_hidden,
            n_layers=config.model.n_layers,
        )
    elif config.model.name == 'ntm':
        model = NTM(
            input_size=task.full_input_width,
            output_size=task.full_output_width,
            mem_word_length=config.model.mem_word_length,
            mem_cells_count=config.model.mem_cells_count,
            n_writes=config.model.n_writes,
            n_reads=config.model.n_reads,
            controller_n_hidden=config.model.controller_n_hidden,
            controller_n_layers=config.model.controller_n_layers,
            clip_value=config.model.clip_value,
        )
    elif config.model.name == 'dnc':
        model = DNC(
            input_size=task.full_input_width,
            output_size=task.full_output_width,
            cell_width=config.model.cell_width,
            n_cells=config.model.n_cells,
            n_reads=config.model.n_reads,
            controller_n_hidden=config.model.controller_n_hidden,
            controller_n_layers=config.model.controller_n_layers,
            clip_value=config.model.clip_value,
            masking=config.model.masking,
            mask_min=config.model.mask_min,
            dealloc=config.model.dealloc,
            diff_alloc=config.model.diff_alloc,
            links=config.model.links,
            links_sharpening=config.model.links_sharpening,
            normalization=config.model.normalization,
            dropout=config.model.dropout,
        )
    else:
        logging.error('Unknown model: %s', config.model.name)
        exit(1)

    if config.gpu and torch.cuda.is_available():
        model = model.cuda()

    # Setup optimizer
    if config.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=config.learning_rate,
                                    momentum=config.momentum)
    elif config.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=config.learning_rate,
                                        momentum=config.momentum)
    elif config.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config.learning_rate)
    else:
        raise ValueError('Unknown optimizer: %s' % config.optimizer)

    step = 0
    if config.load:
        logging.info('Restoring model from checkpoint')
        model, optimizer, task, step = utils.load_checkpoint(
            model,
            optimizer,
            task,
            config.load,
        )

    return model, optimizer, task, step
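
setup() only reads a handful of nested config attributes. A minimal sketch of a compatible config for the 'copy' task with the LSTM model, built with types.SimpleNamespace; the concrete values are illustrative, not taken from the original:

from types import SimpleNamespace

config = SimpleNamespace(
    task=SimpleNamespace(name='copy', batch_size=32, min_len=1, max_len=20,
                         bit_width=8, seed=0),
    model=SimpleNamespace(name='lstm', n_hidden=256, n_layers=1, seed=0),
    optimizer='rmsprop',
    learning_rate=1e-4,
    momentum=0.9,
    gpu=True,
    load=None,
)

model, optimizer, task, step = setup(config)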
Example #7
class Model(object):
    def __init__(self, args, device, rel2id, word_emb=None):
        lr = args.lr
        lr_decay = args.lr_decay
        self.cpu = torch.device('cpu')
        self.device = device
        self.args = args
        self.rel2id = rel2id
        self.max_grad_norm = args.max_grad_norm
        if args.model == 'pa_lstm':
            self.model = PositionAwareRNN(args, rel2id, word_emb)
        elif args.model == 'bgru':
            self.model = BGRU(args, rel2id, word_emb)
        elif args.model == 'cnn':
            self.model = CNN(args, rel2id, word_emb)
        elif args.model == 'pcnn':
            self.model = PCNN(args, rel2id, word_emb)
        elif args.model == 'lstm':
            self.model = LSTM(args, rel2id, word_emb)
        else:
            raise ValueError
        self.model.to(device)
        self.criterion = nn.CrossEntropyLoss()
        if args.fix_bias:
            self.model.flinear.bias.requires_grad = False
        self.parameters = [
            p for p in self.model.parameters() if p.requires_grad
        ]
        # self.parameters = self.model.parameters()
        self.optimizer = torch.optim.SGD(self.parameters, lr)
        self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                        'min',
                                                        patience=3,
                                                        factor=lr_decay)

    def update_lr(self, valid_loss):
        self.scheduler.step(valid_loss)

    def update(self, batch, penalty=False, weight=1.0):
        inputs = [p.to(self.device) for p in batch[:5]]
        labels = batch[5].to(self.device)
        self.model.train()
        logits = self.model(inputs)
        loss = self.criterion(logits, labels)
        # batch_ent = utils.calcEntropy(logits)
        # ent = torch.sum(batch_ent) / len(batch_ent)
        # if penalty:
        # 	loss = loss - ent*weight
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters, self.max_grad_norm)
        self.optimizer.step()
        return loss.item()

    def get_bias(self):
        return self.model.flinear.bias.data

    def set_bias(self, bias):
        self.model.flinear.bias.data = bias

    def predict(self, batch):
        inputs = [p.to(self.device) for p in batch[:5]]
        labels = batch[5].to(self.cpu)
        orig_idx = batch[6]
        logits = self.model(inputs).to(self.cpu)
        loss = self.criterion(logits, labels)
        pred = torch.argmax(logits, dim=1).to(self.cpu)
        # corrects = torch.eq(pred, labels)
        # acc_cnt = torch.sum(corrects, dim=-1)
        recover_idx = utils.recover_idx(orig_idx)
        logits = [logits[idx].tolist() for idx in recover_idx]
        pred = [pred[idx].item() for idx in recover_idx]
        labels = [labels[idx].item() for idx in recover_idx]
        return logits, pred, labels, loss.item()

    def eval(self,
             dset,
             vocab=None,
             output_false_file=None,
             output_label_file=None,
             weights=None):
        if weights is None:
            weights = [1.0] * len(dset.rel2id)

        rel_labels = [''] * len(dset.rel2id)
        for label, id in dset.rel2id.items():
            rel_labels[id] = label
        self.model.eval()
        pred = []
        labels = []
        loss = 0.0

        for idx, batch in enumerate(dset.batched_data):
            scores_b, pred_b, labels_b, loss_b = self.predict(batch)
            pred += pred_b
            labels += labels_b
            loss += loss_b

            if output_false_file is not None and vocab is not None:
                all_words, pos, ner, subj_pos, obj_pos, labels_, _ = batch
                all_words = all_words.tolist()
                output_false_file.write('\n')
                for i, word_ids in enumerate(all_words):
                    if labels[i] != pred[i]:
                        length = 0
                        for wid in word_ids:
                            if wid != utils.PAD_ID:
                                length += 1
                        words = [vocab[wid] for wid in word_ids[:length]]
                        sentence = ' '.join(words)

                        subj_words = []
                        for sidx in range(length):
                            if subj_pos[i][sidx] == 0:
                                subj_words.append(words[sidx])
                        subj = '_'.join(subj_words)

                        obj_words = []
                        for oidx in range(length):
                            if obj_pos[i][oidx] == 0:
                                obj_words.append(words[oidx])
                        obj = '_'.join(obj_words)

                        output_false_file.write(
                            '%s\t%s\t%s\t%s\t%s\n' %
                            (sentence, subj, obj, rel_labels[pred[i]],
                             rel_labels[labels[i]]))

        if output_label_file is not None and vocab is not None:
            output_label_file.write(json.dumps(pred) + '\n')
            output_label_file.write(json.dumps(labels) + '\n')

        loss /= len(dset.batched_data)
        return loss, utils.eval(pred, labels, weights)

    def TuneEntropyThres(self,
                         test_dset,
                         noneInd=utils.NO_RELATION,
                         ratio=0.2,
                         cvnum=100):
        """Tune the entropy-based threshold on the test set."""
        rel_labels = [''] * len(test_dset.rel2id)
        for label, id in test_dset.rel2id.items():
            rel_labels[id] = label
        self.model.eval()
        pred = []
        labels = []
        scores = []
        loss = 0.0
        for idx, batch in enumerate(test_dset.batched_data):
            scores_b, pred_b, labels_b, loss_b = self.predict(batch)
            pred += pred_b
            labels += labels_b
            scores += scores_b
            loss += loss_b
        loss /= len(test_dset.batched_data)

        # start tuning
        scores = torch.tensor(scores)
        f1score = 0.0
        recall = 0.0
        precision = 0.0

        pre_ind = utils.calcInd(scores)
        pre_entropy = utils.calcEntropy(scores)
        valSize = int(np.floor(ratio * len(pre_ind)))
        data = [[pre_ind[ind], pre_entropy[ind], labels[ind]]
                for ind in range(0, len(pre_ind))]

        for cvind in tqdm(range(cvnum)):
            random.shuffle(data)
            val = data[0:valSize]
            eva = data[valSize:]

            # find best threshold
            max_ent = max(val, key=lambda t: t[1])[1]
            min_ent = min(val, key=lambda t: t[1])[1]
            stepSize = (max_ent - min_ent) / 100
            thresholdList = [min_ent + ind * stepSize for ind in range(0, 100)]
            ofInterest = 0
            for ins in val:
                if ins[2] != noneInd:
                    ofInterest += 1
            bestThreshold = float('nan')
            bestF1 = float('-inf')
            for threshold in thresholdList:
                corrected = 0
                predicted = 0
                for ins in val:
                    if ins[1] < threshold and ins[0] != noneInd:
                        predicted += 1
                        if ins[0] == ins[2]:
                            corrected += 1
                curF1 = 2.0 * corrected / (ofInterest + predicted)
                if curF1 > bestF1:
                    bestF1 = curF1
                    bestThreshold = threshold
            ofInterest = 0
            corrected = 0
            predicted = 0
            for ins in eva:
                if ins[2] != noneInd:
                    ofInterest += 1
                if ins[1] < bestThreshold and ins[0] != noneInd:
                    predicted += 1
                    if ins[0] == ins[2]:
                        corrected += 1

            f1score += (2.0 * corrected / (ofInterest + predicted))
            recall += (1.0 * corrected / ofInterest)
            precision += (1.0 * corrected / (predicted + 0.00001))

        f1score /= cvnum
        recall /= cvnum
        precision /= cvnum

        return loss, f1score, recall, precision

    def TuneMaxThres(self,
                     test_dset,
                     noneInd=utils.NO_RELATION,
                     ratio=0.2,
                     cvnum=100):
        """Tune the max-probability threshold on the test set."""
        rel_labels = [''] * len(test_dset.rel2id)
        for label, id in test_dset.rel2id.items():
            rel_labels[id] = label
        self.model.eval()
        pred = []
        labels = []
        scores = []
        loss = 0.0
        for idx, batch in enumerate(test_dset.batched_data):
            scores_b, pred_b, labels_b, loss_b = self.predict(batch)
            pred += pred_b
            labels += labels_b
            scores += scores_b
            loss += loss_b
        loss /= len(test_dset.batched_data)

        # start tuning
        scores = torch.tensor(scores)
        f1score = 0.0
        recall = 0.0
        precision = 0.0

        pre_prob, pre_ind = torch.max(scores, 1)
        valSize = int(np.floor(ratio * len(pre_ind)))
        data = [[pre_ind[ind], pre_prob[ind], labels[ind]]
                for ind in range(0, len(pre_ind))]
        for cvind in tqdm(range(cvnum)):
            random.shuffle(data)
            val = data[0:valSize]
            eva = data[valSize:]

            # find best threshold
            max_ent = max(val, key=lambda t: t[1])[1]
            min_ent = min(val, key=lambda t: t[1])[1]
            stepSize = (max_ent - min_ent) / 100
            thresholdList = [min_ent + ind * stepSize for ind in range(0, 100)]
            ofInterest = 0
            for ins in val:
                if ins[2] != noneInd:
                    ofInterest += 1
            bestThreshold = float('nan')
            bestF1 = float('-inf')
            for threshold in thresholdList:
                corrected = 0
                predicted = 0
                for ins in val:
                    if ins[1] > threshold and ins[0] != noneInd:
                        predicted += 1
                        if ins[0] == ins[2]:
                            corrected += 1
                curF1 = 2.0 * corrected / (ofInterest + predicted)
                if curF1 > bestF1:
                    bestF1 = curF1
                    bestThreshold = threshold

            ofInterest = 0
            corrected = 0
            predicted = 0
            for ins in eva:
                if ins[2] != noneInd:
                    ofInterest += 1
                if ins[1] > bestThreshold and ins[0] != noneInd:
                    predicted += 1
                    if ins[0] == ins[2]:
                        corrected += 1
            f1score += (2.0 * corrected / (ofInterest + predicted))
            recall += (1.0 * corrected / ofInterest)
            precision += (1.0 * corrected / (predicted + 0.00001))

        f1score /= cvnum
        recall /= cvnum
        precision /= cvnum

        return loss, f1score, recall, precision

    def save(self, filename, epoch):
        params = {
            'model': self.model.state_dict(),
            'config': self.args,
            'epoch': epoch
        }
        try:
            torch.save(params, filename)
            print("Epoch {}, model saved to {}".format(epoch, filename))
        except BaseException:
            print("[Warning: Saving failed... continuing anyway.]")
        # json.dump(vars(self.args), open('%s.json' % filename, 'w'))

    def count_parameters(self):
        return sum(p.numel() for p in self.model.parameters()
                   if p.requires_grad)

    def load(self, filename):
        params = torch.load(filename)
        if type(params).__name__ == 'dict' and 'model' in params:
            self.model.load_state_dict(params['model'])
        else:
            self.model.load_state_dict(params)
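
The two tuning methods above lean on utils.calcInd and utils.calcEntropy, which are not shown. A plausible reading, assuming they return the argmax class and the Shannon entropy of the softmax distribution per example (an assumption about the project's utils, not its actual code):

import torch
import torch.nn.functional as F

def calc_ind(scores):
    # Highest-scoring relation index per example (assumed behavior of utils.calcInd).
    return torch.argmax(scores, dim=1)

def calc_entropy(scores):
    # Shannon entropy of the softmax distribution per example; lower entropy means
    # a more confident prediction, which is why TuneEntropyThres only keeps
    # predictions whose entropy falls below the tuned threshold.
    probs = F.softmax(scores, dim=1)
    return -(probs * torch.log(probs + 1e-12)).sum(dim=1)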
Example #8
def setup_model(config):
    # Load data
    if config.task.name == 'arithmetic':
        train_data = Arithmetic(
            batch_size=config.task.batch_size,
            min_len=config.task.min_len,
            max_len=config.task.max_len,
            task=config.task.task,
            seed=config.seed,
        )

        np.random.seed(config.seed)

        params = [20, 30, 40, 60]
        validation_data = []

        for length in params:
            example = train_data.gen_batch(batch_size=50,
                                           min_len=length,
                                           max_len=length,
                                           distribution=np.array([1]))
            validation_data.append((example, length))
        loss = Arithmetic.loss
    else:
        logging.error('Unknown task: %s', config.task.name)
        exit(1)

    # Setup model
    torch.manual_seed(config.seed)
    if config.model.name == 'lstm':
        model = LSTM(
            n_inputs=train_data.symbols_amount,
            n_outputs=train_data.symbols_amount,
            n_hidden=config.model.n_hidden,
            n_layers=config.model.n_layers,
        )
    elif config.model.name == 'ntm':
        model = NTM(input_size=train_data.symbols_amount,
                    output_size=train_data.symbols_amount,
                    mem_word_length=config.model.mem_word_length,
                    mem_cells_count=config.model.mem_cells_count,
                    n_writes=config.model.n_writes,
                    n_reads=config.model.n_reads,
                    controller_n_hidden=config.model.controller_n_hidden,
                    controller_n_layers=config.model.controller_n_layers,
                    controller=config.model.controller,
                    layer_sizes=config.model.layer_sizes,
                    controller_output=config.model.controller_output,
                    clip_value=config.model.clip_value,
                    dropout=config.model.dropout)
    elif config.model.name == 'dnc':
        model = DNC(
            input_size=train_data.symbols_amount,
            output_size=train_data.symbols_amount,
            n_cells=config.model.n_cells,
            cell_width=config.model.cell_width,
            n_reads=config.model.n_reads,
            controller_n_hidden=config.model.controller_n_hidden,
            controller_n_layers=config.model.controller_n_layers,
            clip_value=config.model.clip_value,
        )
    else:
        logging.error('Unknown model: %s', config.model.name)
        exit(1)

    if config.gpu and torch.cuda.is_available():
        model = model.cuda()

    logging.info('Loaded model')
    logging.info('Total number of parameters %d', model.calculate_num_params())

    # Setup optimizer
    if config.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=config.learning_rate,
                                    momentum=config.momentum)
    elif config.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=config.learning_rate,
                                        momentum=config.momentum)
    elif config.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config.learning_rate)
    else:
        raise ValueError('Unknown optimizer: %s' % config.optimizer)

    if config.scheduler is not None:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=config.scheduler.factor,
            patience=config.scheduler.patience,
            verbose=config.scheduler.verbose,
            threshold=config.scheduler.threshold,
        )
        optimizer = (optimizer, scheduler)

    if config.load:
        model, optimizer, train_data, step = utils.load_checkpoint(
            model,
            optimizer,
            train_data,
            config.load,
        )

    return model, optimizer, loss, train_data, validation_data
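
Note that when config.scheduler is set, setup_model() returns the optimizer and the ReduceLROnPlateau scheduler packed into one tuple. A short sketch of unpacking it on the caller side (the validation-loss value fed to the scheduler is an assumption):

model, optimizer, loss_fn, train_data, validation_data = setup_model(config)

scheduler = None
if isinstance(optimizer, tuple):
    optimizer, scheduler = optimizer   # a ReduceLROnPlateau scheduler was attached

# ... run training steps with `optimizer` as usual ...

if scheduler is not None:
    val_loss = 0.0  # placeholder: compute it from validation_data
    scheduler.step(val_loss)           # ReduceLROnPlateau steps on the monitored metric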