Example #1
import os

def fit(train_loader, val_loader, test_loader, model, loss_fn, optimizer, scheduler, n_epochs, cuda, log_interval, save_root, metrics=None,
        start_epoch=0):
    """
    Loaders, model, loss function and metrics should work together for a given task,
    i.e. The model should be able to process data output of loaders,
    loss function should process target output of loaders and outputs from the model
    Examples: Classification: batch loader, classification model, NLL loss, accuracy metric
    Siamese network: Siamese loader, siamese model, contrastive loss
    Online triplet learning: batch loader, embedding model, online triplet loss
    """
    if metrics is None:
        metrics = []

    if not os.path.exists(save_root):
        os.makedirs(save_root)

    for epoch in range(start_epoch, n_epochs):
        # Train stage
        train_loss, metrics = train_epoch(train_loader, model, loss_fn, optimizer, cuda, log_interval, metrics)
        scheduler.step()  # step the LR scheduler once per training epoch

        message = 'Epoch: {}/{}. Train set: Average loss: {:.4f}'.format(epoch + 1, n_epochs, train_loss)
        for metric in metrics:
            message += '\t{}: {}'.format(metric.name(), metric.value())
        # Validate and write predictions every epoch
        val_acc = val_epoch(val_loader, model, cuda)
        message += '\nEpoch: {}/{}. Validation set: Acc: {:.4f}'.format(epoch + 1, n_epochs, val_acc)

        prediction = test_epoch(test_loader, model, cuda)
        write_preds(prediction, config.pred_file + str(epoch + 1) + '.txt')

        if (epoch + 1) % 5 == 0:
            save_checkpoint(model.state_dict(), False, save_root, str(epoch))
        print(message)
    save_checkpoint(model.state_dict(), False, save_root, str(n_epochs))
    prediction = test_epoch(test_loader, model, cuda)
    write_preds(prediction, config.pred_file + str(n_epochs) + '.txt')
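
Taken together, fit() above is a standard train/validate/predict loop. A hypothetical driver is sketched below with synthetic data; make_loader, the model, and every hyperparameter here are illustrative, and the call assumes train_epoch, val_epoch, test_epoch, save_checkpoint, and config come from the original project.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def make_loader(n=256, d=20, n_classes=4, batch_size=32):
    # Random tensors stand in for a real dataset.
    X = torch.randn(n, d)
    y = torch.randint(0, n_classes, (n,))
    return DataLoader(TensorDataset(X, y), batch_size=batch_size, shuffle=True)

model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(),
                      nn.Linear(64, 4), nn.LogSoftmax(dim=1))
loss_fn = nn.NLLLoss()  # pairs with LogSoftmax outputs, as the docstring suggests
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

fit(make_loader(), make_loader(), make_loader(), model, loss_fn, optimizer,
    scheduler, n_epochs=20, cuda=torch.cuda.is_available(), log_interval=10,
    save_root='./checkpoints')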
Example #2
from copy import deepcopy

def pred_test_with_best_model(clf,
                              X_train,
                              y_train,
                              X_dev,
                              y_dev,
                              X_test,
                              write_to_file=True):
    # Merge train and dev so the final model is fit on all labeled data
    X = deepcopy(X_train)
    y = deepcopy(y_train)
    X.extend(X_dev)
    y.extend(y_dev)
    print(len(X), len(y))

    clf.fit(X, y)
    preds = clf.predict(X_test)

    if write_to_file:
        write_preds('../data/predictions_for_test.txt', preds)

    return preds
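
A hypothetical call to pred_test_with_best_model() with a scikit-learn classifier; the toy data and the LogisticRegression choice are illustrative, not from the original project.

from sklearn.linear_model import LogisticRegression

X_train = [[0.1, 0.2], [0.3, 0.1], [0.9, 0.8]]
y_train = [0, 0, 1]
X_dev = [[0.2, 0.3], [0.8, 0.9]]
y_dev = [0, 1]
X_test = [[0.15, 0.25], [0.85, 0.75]]

clf = LogisticRegression()
preds = pred_test_with_best_model(clf, X_train, y_train, X_dev, y_dev,
                                  X_test, write_to_file=False)
print(preds)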
Example #3
    def val_epoch(self, epoch, model, dev_dataloader, args=None, fold='dev'):
        y, yhat, yhat_raw, hids, losses = [], [], [], [], []
        model.eval()
        with torch.no_grad():
            for step, sample in enumerate(dev_dataloader):
                hadms, docs, labels, ordered_labels, doc_masks, doc_lengths, desc_vectors, code_set = sample
                if torch.cuda.is_available():
                    docs = docs.cuda()
                    labels = labels.cuda()
                    doc_lengths = doc_lengths.cuda()
                    doc_masks = doc_masks.cuda()

                logits, _, _, _ = model(docs, doc_masks, doc_lengths, adj=self.adj, leaf_idxs=self.leaf_idxs, code_desc=self.code_desc, code_set=None)
                loss = model.get_multilabel_loss(labels, logits)
                output = torch.sigmoid(logits).cpu().numpy()  # F.sigmoid is deprecated
                losses.append(loss.item())
                targets = labels.cpu().numpy()

                y.append(targets)
                yhat.append(np.round(output))
                yhat_raw.append(output)
                hids.extend(hadms)

        y = np.concatenate(y, axis=0)
        yhat = np.concatenate(yhat, axis=0)
        yhat_raw = np.concatenate(yhat_raw, axis=0)
        
        dicts = dev_dataloader.dataset.dicts
        ind2c = dicts['ind2c']
        # write the predictions for this epoch
        preds_file = utils.write_preds(yhat, args.save_dir, epoch, hids, fold, ind2c, yhat_raw)
        # compute evaluation metrics (top-k depends on the label-space size)
        k = 5 if args.Y == 50 else [8, 15]
        metrics = all_metrics(yhat, y, k=k, yhat_raw=yhat_raw)
        metrics['loss_%s' % fold] = np.mean(losses)
        return metrics
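
For reference, here is a sketch of the kind of multilabel metrics a helper like all_metrics() typically reports in this setup (micro-F1 over a binary label matrix, plus precision@k from raw scores); the actual implementation in the original repo may differ.

import numpy as np

def micro_f1(yhat, y):
    # Micro-averaged F1 over (n_examples, n_labels) binary matrices.
    tp = np.logical_and(yhat == 1, y == 1).sum()
    fp = np.logical_and(yhat == 1, y == 0).sum()
    fn = np.logical_and(yhat == 0, y == 1).sum()
    prec = tp / (tp + fp) if tp + fp > 0 else 0.0
    rec = tp / (tp + fn) if tp + fn > 0 else 0.0
    return 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0.0

def precision_at_k(yhat_raw, y, k):
    # Fraction of each example's top-k scored labels that are true labels.
    topk = np.argsort(yhat_raw, axis=1)[:, ::-1][:, :k]
    hits = np.take_along_axis(y, topk, axis=1).sum(axis=1)
    return float(np.mean(hits / k))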
Example #4
    def eval_epoch(self, final=False, save_predictions=False):
        """
        Evaluate the model on the test set.
        No backward computation is allowed.
        """
        t1 = time()
        output = {
            'tp': [],
            'fp': [],
            'fn': [],
            'tn': [],
            'loss': [],
            'preds': []
        }
        test_info = []

        self.model = self.model.eval()
        test_iter = self.iterator(self.data['test'],
                                  batch_size=self.params['batch'],
                                  shuffle_=False)
        for batch_idx, batch in enumerate(test_iter):
            batch = self.convert_batch(batch)

            with torch.no_grad():
                loss, stats, predictions, select = self.model(batch)

                output['loss'] += [loss.item()]
                output['tp'] += [stats['tp'].cpu().numpy()]
                output['fp'] += [stats['fp'].cpu().numpy()]
                output['fn'] += [stats['fn'].cpu().numpy()]
                output['tn'] += [stats['tn'].cpu().numpy()]
                output['preds'] += [predictions.cpu().numpy()]
                test_info += [
                    batch['info'][select[0].cpu().numpy(),
                                  select[1].cpu().numpy(),
                                  select[2].cpu().numpy()]
                ]
        t2 = time()

        # estimate performance
        if self.window:
            total_loss, scores = self.subdocs_performance(
                output['loss'], output['preds'], test_info)
        else:
            total_loss, scores = self.performance(output)

        if not final:
            self.test_res['loss'] += [total_loss]
            self.test_res['score'] += [scores[self.primary_metric]]
        print('            TEST  | LOSS = {:.05f}, '.format(total_loss),
              end="")
        print_results(scores, [], self.show_class, t2 - t1)
        print()

        if save_predictions:
            write_preds(output['preds'],
                        test_info,
                        self.preds_file,
                        map_=self.loader.index2rel)
            write_errors(output['preds'],
                         test_info,
                         self.preds_file,
                         map_=self.loader.index2rel)
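
One plausible reading of the self.performance(output) call above: aggregate the per-batch tp/fp/fn counts into micro-averaged precision, recall, and F1. The helper below is illustrative only; the original class may compute something richer.

import numpy as np

def performance_from_counts(output):
    # Sum the per-batch count arrays collected during eval_epoch().
    tp = sum(np.sum(a) for a in output['tp'])
    fp = sum(np.sum(a) for a in output['fp'])
    fn = sum(np.sum(a) for a in output['fn'])
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    total_loss = float(np.mean(output['loss']))
    return total_loss, {'precision': precision, 'recall': recall, 'f1': f1}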
Example #5
def main(args):
    logging.info('-' * 50)
    logging.info('Load data files..')

    if args.debug:
        logging.info('*' * 10 + ' Train')
        train_examples = utils.load_data(args.train_file, 5, relabeling=args.relabeling,
                                         remove_notfound=args.remove_notfound)
        logging.info('*' * 10 + ' Dev')
        dev_examples = utils.load_data(args.dev_file, 100, relabeling=args.relabeling,
                                       remove_notfound=False)
    elif args.cnn_train:
        logging.info('*' * 10 + ' Train')
        train_examples = utils.load_cnn_data(args.train_file, relabeling=args.relabeling, has_ids=args.train_has_ids)  # docs, qs, ans
        logging.info('*' * 10 + ' Dev')
        dev_examples = utils.load_cnn_data(args.dev_file, args.max_dev, relabeling=args.relabeling, has_ids=args.dev_has_ids)
    else:
        logging.info('*' * 10 + ' Train')
        train_examples = utils.load_data(args.train_file, relabeling=args.relabeling,
                                         remove_notfound=args.remove_notfound)  # docs, qs, ans
        logging.info('*' * 10 + ' Dev')
        dev_examples = utils.load_data(args.dev_file, args.max_dev, relabeling=args.relabeling,
                                       remove_notfound=False)

    args.num_train = len(train_examples[0])
    args.num_dev = len(dev_examples[0])

    logging.info('-' * 50)
    logging.info('Build dictionary..')
    word_dict = utils.build_dict(train_examples[0] + train_examples[1],  # + dev_examples[0] + dev_examples[1],
                                 max_words=args.max_words)  # docs+qs
    entity_markers = list(set([w for w in word_dict.keys()
                              if w.startswith('@entity')] + train_examples[2]))
    entity_markers = ['<unk_entity>'] + entity_markers
    entity_dict = {w: index for (index, w) in enumerate(entity_markers)}
    inv_entity_dict = {index: w for w, index in entity_dict.items()}
    assert len(entity_dict) == len(inv_entity_dict)
    logging.info('Entity markers: %d' % len(entity_dict))
    args.num_labels = len(entity_dict)

    logging.info('-' * 50)
    # Load embedding file
    embeddings = utils.gen_embeddings(word_dict, args.embedding_size, args.embedding_file)
    (args.vocab_size, args.embedding_size) = embeddings.shape
    logging.info('Compile functions..')
    train_fn, test_fn, params = build_fn(args, embeddings)
    logging.info('Done.')

    logging.info('-' * 50)
    logging.info(args)

    logging.info('-' * 50)
    logging.info('Initial test..')
    dev_x1, dev_x2, dev_l, dev_y, dev_ids = utils.vectorize(dev_examples, word_dict, entity_dict,
                                                   remove_notfound=False,
                                                   relabeling=args.relabeling)
    if dev_ids is not None:
        assert len(dev_y) == len(dev_ids)
    assert len(dev_x1) == args.num_dev
    all_dev = gen_examples(dev_x1, dev_x2, dev_l, dev_y, args.batch_size)
    dev_acc, dev_preds = eval_acc(test_fn, all_dev)

    if dev_ids is not None:
        assert len(dev_ids) == len(dev_preds) == len(dev_y)
        dev_preds_data = to_output_preds(dev_ids, dev_preds, inv_entity_dict, args.relabeling)
    logging.info('Dev accuracy: %.2f %%' % dev_acc)
    best_acc = dev_acc

    if args.log_file is not None:
        assert args.log_file.endswith(".log")
        run_name = args.log_file[:args.log_file.find(".log")]
        if dev_ids is not None:
            preds_file_name = run_name + ".preds"
            utils.write_preds(dev_preds_data, preds_file_name)
            utils.external_eval(preds_file_name,
                                run_name + ".preds.scores",
                                eval_data="test" if "test" in os.path.basename(args.dev_file) else "dev")
    if args.test_only:
        return

    if args.log_file is not None:
        utils.save_params(run_name + ".model", params, epoch=0, n_updates=0)

    # Training
    logging.info('-' * 50)
    logging.info('Start training..')
    train_x1, train_x2, train_l, train_y, train_ids = utils.vectorize(train_examples, word_dict, entity_dict,
                                                           remove_notfound=args.remove_notfound,
                                                           relabeling=args.relabeling)
    assert len(train_x1) == args.num_train
    start_time = time.time()
    n_updates = 0
    train_accs = []
    dev_accs = []
    all_train = gen_examples(train_x1, train_x2, train_l, train_y, args.batch_size)
    improved = []
    for epoch in range(args.num_epoches):
        ep_acc_improved = False
        np.random.shuffle(all_train)
        for idx, (mb_x1, mb_mask1, mb_x2, mb_mask2, mb_l, mb_y) in enumerate(all_train):
            logging.info('#Examples = %d, max_len = %d' % (len(mb_x1), mb_x1.shape[1]))
            train_loss = train_fn(mb_x1, mb_mask1, mb_x2, mb_mask2, mb_l, mb_y)
            logging.info('Epoch = %d, iter = %d (max = %d), loss = %.2f, elapsed time = %.2f (s)' %
                         (epoch, idx, len(all_train), train_loss, time.time() - start_time))
            n_updates += 1

            if n_updates % args.eval_iter == 0:
                samples = sorted(np.random.choice(args.num_train, min(args.num_train, args.num_dev),
                                                  replace=False))
                sample_train = gen_examples([train_x1[k] for k in samples],
                                            [train_x2[k] for k in samples],
                                            train_l[samples],
                                            [train_y[k] for k in samples],
                                            args.batch_size)
                train_acc, train_preds = eval_acc(test_fn, sample_train)
                train_accs.append(train_acc)
                logging.info('Train accuracy: %.2f %%' % train_acc)
                dev_acc, dev_preds = eval_acc(test_fn, all_dev)
                dev_accs.append(dev_acc)
                logging.info('Dev accuracy: %.2f %%' % dev_acc)
                utils.update_plot(args.eval_iter, train_accs, dev_accs, file_name=args.log_file + ".html")
                if dev_acc > best_acc:
                    ep_acc_improved = True
                    best_acc = dev_acc
                    logging.info('Best dev accuracy: epoch = %d, n_updates = %d, acc = %.2f %%'
                                 % (epoch, n_updates, dev_acc))
                    if args.log_file is not None:
                        utils.save_params(run_name + ".model", params, epoch=epoch, n_updates=n_updates)
                        if dev_ids is not None:
                            dev_preds_data = to_output_preds(dev_ids, dev_preds, inv_entity_dict, args.relabeling)
                            utils.write_preds(dev_preds_data, preds_file_name)
                            utils.external_eval(preds_file_name, run_name + ".preds.scores", eval_data="dev")
        improved.append(ep_acc_improved)
        # early stop once past 25 epochs with no dev improvement in the last 3
        if len(improved) > 25 and sum(improved[-3:]) == 0:
            break
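
The early-stopping rule at the end of the loop is worth isolating. A minimal equivalent helper, with the thresholds from the code above (names are illustrative, the logic is the same):

def should_stop(improved, min_epochs=25, patience=3):
    # improved[i] is True if dev accuracy improved during epoch i.
    return len(improved) > min_epochs and not any(improved[-patience:])

assert not should_stop([False, False, False])               # too early to stop
assert should_stop([True] * 23 + [False, False, False])     # 26 epochs, 3 flat
assert not should_stop([True] * 25 + [False, True, False])  # recent improvement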
Example #6
import numpy as np
from utils import write_preds

# Write the integers 0..9 to pred.txt, one value per line
a = np.arange(10)
filepath = 'pred.txt'
write_preds(a, filepath)
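
All six examples assume a project-local write_preds helper, and the argument order varies between projects (Example #2 passes the path first). A minimal sketch that would satisfy the call in Example #6, writing one prediction per line, might look like this; each project's real utils.write_preds may differ.

def write_preds(preds, filepath):
    # One prediction per line, stringified with str().
    with open(filepath, 'w') as f:
        for p in preds:
            f.write('{}\n'.format(p))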