Example #1
import os
from typing import Optional

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

# config, logger, device, Loss, ConfusionMatrix and beautify are
# project-level helpers assumed to be in scope for this snippet.


def train(model: nn.Module,
          dataset: Dataset,
          validate_data: Optional[Dataset] = None) -> None:
    loader = DataLoader(dataset,
                        batch_size=dataset.BATCH_SIZE,
                        shuffle=True)  # shuffle training samples each epoch

    optimizer = getattr(torch.optim,
                        config.TRAIN.OPTIMIZER)(model.parameters(),
                                                **config.TRAIN.OPTIM_PARAMS)
    overall_iter = 0
    evaluation = ConfusionMatrix(dataset.get_num_class())

    model.train()
    for epoch in range(config.TRAIN.NUM_EPOCHS):
        total_loss = 0
        for batch_idx, samples in enumerate(loader):
            images, target = device([samples['image'], samples['mask']],
                                    gpu=config.USE_GPU)
            outputs = model(images)['out']
            output_mask = outputs.argmax(1)

            batch_loss = Loss.cross_entropy2D(outputs, target, False)
            total_loss += batch_loss.item()
            overall_loss = total_loss / (batch_idx + 1)  # running mean of the batch losses
            evaluation.update(output_mask, target)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            if batch_idx % config.PRINT_BATCH_FREQ == 0:
                metrics = evaluation()
                logger.info(f'Train Epoch: {epoch}, Batch: {batch_idx}')
                logger.info(
                    f'Batch loss: {batch_loss.item():.6f}, Overall loss: {overall_loss:.6f}'
                )
                for met in beautify(metrics[0]):
                    logger.info(f'{met}')
                logger.info('Classwise IoU')
                for met in beautify(metrics[1]):
                    logger.info(f'{met}')
                logger.info("\n")

            overall_iter += 1
            if config.SAVE_ITER_FREQ and overall_iter % config.SAVE_ITER_FREQ == 0:
                torch.save(
                    model.state_dict(),
                    os.path.join(config.LOG_PATH,
                                 config.NAME + f"-iter={overall_iter}"))
Example #2
    for i_batch, sample_batched in enumerate(tqdm(dataloader_train)):
        # unpack the batch and move every tensor to the GPU
        images, lefts, rights, ages, genders, edus, apoes, labels = (
            Variable(sample_batched['mri']).cuda(),
            Variable(sample_batched['left']).cuda(),
            Variable(sample_batched['right']).cuda(),
            Variable(sample_batched['age']).cuda(),
            Variable(sample_batched['gender']).cuda(),
            Variable(sample_batched['edu']).cuda(),
            Variable(sample_batched['apoe']).cuda(),
            Variable(sample_batched['dx']).view(-1).cuda())
        # ===================forward====================
        outputs = classifier(images, lefts, rights)

        loss = criterion(outputs, labels)

        predictions = torch.argmax(outputs, dim=1)  # outputs are N * 2: 0 / 1
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
        metrics.update(labels.data.cpu().numpy(),
                       predictions.data.cpu().numpy())
    results_train = metrics.get_scores()
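    # return unused cached GPU memory to the device before evaluation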
    torch.cuda.empty_cache()
    metrics.reset()
    # =================== evaluation ========================
    if epoch % 1 == 0:  # evaluate every epoch (raise the modulus to evaluate less often)
        classifier.eval()
        with torch.no_grad():
            for i_batch, sample_batched in enumerate(tqdm(dataloader_val)):
                # unpack the batch and move every tensor to the GPU
                images, lefts, rights, ages, genders, edus, apoes, labels = (
                    Variable(sample_batched['mri']).cuda(),
                    Variable(sample_batched['left']).cuda(),
                    Variable(sample_batched['right']).cuda(),
                    Variable(sample_batched['age']).cuda(),
                    Variable(sample_batched['gender']).cuda(),
                    Variable(sample_batched['edu']).cuda(),
                    Variable(sample_batched['apoe']).cuda(),
                    Variable(sample_batched['dx']).view(-1).cuda())

Example #3
def predict_worker(proc_id,
                   output_file,
                   classes,
                   model_params,
                   batch_size,
                   que,
                   lock,
                   status_que,
                   gpu_id=0,
                   evaluate=True,
                   framework='mxnet'):
    """ get data from batch loader and make predictions, predictions will be saved in output_file
        if evaluate, will evaluate recall, precision, f1_score and recall_top5 """
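    # Protocol: the parent process pushes (im_names, batch_data, gt_list)
    # tuples onto `que` and finally a 'FINISH' sentinel; this worker reports
    # 'OK' or 'Error' on `status_que` once model loading has been attempted.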

    logging.info('Predictor #{}: Loading model...'.format(proc_id))
    model = load_model(proc_id,
                       model_params,
                       batch_size,
                       classes,
                       gpu_id,
                       framework=framework)
    if model is None:
        status_que.put('Error')
        raise ValueError('No model created! Exit')
    logging.info('Predictor #{}: Model loaded'.format(proc_id))
    status_que.put('OK')

    if evaluate:
        from metrics import F1, ConfusionMatrix, MisClassified, RecallTopK
        evaluator = F1(len(classes))
        misclassified = MisClassified(len(classes))
        cm = ConfusionMatrix(classes)
        recall_topk = RecallTopK(len(classes), top_k=5)

    f = open(output_file, 'w')
    batch_idx = 0
    logging.info('Predictor #{} starts'.format(proc_id))
    start = time.time()
    while True:
        # get a batch from data loader via a queue
        lock.acquire()
        batch = que.get()
        lock.release()
        if batch == 'FINISH':
            logging.info(
                'Predictor #{} has received all batches, exit'.format(proc_id))
            break

        # predict
        im_names, batch, gt_list = batch
        logging.debug('Predictor #{}: predict'.format(proc_id))
        pred, prob = model.predict(batch)
        pred_labels, top_probs = model.get_label_prob(top_k=5)

        # write prediction to file
        for im_name, label, top_prob in zip(im_names, pred_labels, top_probs):
            if im_name is None:
                continue
            top_prob = [str(p) for p in top_prob]
            f.write('{} labels:{} prob:{}\n'.format(im_name, ','.join(label),
                                                    ','.join(top_prob)))

        # update metrics if evaluation mode is set
        if evaluate:
            assert gt_list and gt_list[0] is not None
            top1_int = [p[0] for p in pred]
            assert len(top1_int) == len(gt_list), '{} != {}'.format(
                len(top1_int), len(gt_list))
            evaluator.update(top1_int, gt_list)
            misclassified.update(top1_int, gt_list, prob, im_names)
            cm.update(top1_int, gt_list)

            top5_int = [p[:5] for p in pred]
            assert len(top5_int) == len(gt_list), '{} != {}'.format(
                len(top5_int), len(gt_list))
            recall_topk.update(top5_int, gt_list)

        batch_idx += 1
        if batch_idx % 50 == 0:
            elapsed = time.time() - start
            logging.info(
                'Predictor #{}: Tested {} batches of {} images, elapsed {}s'.
                format(proc_id, batch_idx, batch_size, elapsed))

    # evaluation after prediction if set
    if evaluate:
        logging.info('Evaluating...')
        recall, precision, f1_score = evaluator.get()
        for rec, prec, f1, cls in zip(recall, precision, f1_score, classes):
            print(
                'Class {:<20}: recall: {:<12}, precision: {:<12}, f1 score: {:<12}'
                .format(cls, rec, prec, f1))
            f.write(
                'Class {:<20}: recall: {:<12}, precision: {:<12}, f1 score: {:<12}\n'
                .format(cls, rec, prec, f1))
        topk_recall = recall_topk.get()
        for rec, cls in zip(topk_recall, classes):
            print('Class {:<20}: recall-top-5: {:<12}'.format(cls, rec))
            f.write('Class {:<20}: recall-top-5: {:<12}\n'.format(cls, rec))

        fp_images, fn_images = misclassified.get()
        g = open(output_file + '.fp', 'w')
        for cls, fp_cls in zip(classes, fp_images):
            for fp in fp_cls:
                g.write('{} pred:{} prob:{} gt:{} prob:{}\n'.format(
                    fp[0], cls, fp[2], classes[fp[1]], fp[3]))
        g.close()
        g = open(output_file + '.fn', 'w')
        for cls, fn_cls in zip(classes, fn_images):
            for fn in fn_cls:
                g.write('{} gt:{} prob:{} pred:{} prob:{}\n'.format(
                    fn[0], cls, fn[3], classes[fn[1]], fn[2]))
        g.close()

        cm.normalize()
        plt_name = output_file + '_cm.jpg'
        cm.draw(plt_name)
    f.close()
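
predict_worker blocks on its queue until it receives the 'FINISH' sentinel, so it is meant to run in a separate process. Below is a minimal sketch of the wiring, assuming the real classes, model_params and batch source come from the surrounding project; feed_batches and the placeholder values are hypothetical.

import multiprocessing as mp


def feed_batches(que, batches):
    # Each element is an (im_names, batch_data, gt_list) tuple, matching
    # what predict_worker unpacks; the sentinel tells the worker to exit.
    for b in batches:
        que.put(b)
    que.put('FINISH')


if __name__ == '__main__':
    classes = ['class_a', 'class_b']   # hypothetical label set
    model_params = {}                  # hypothetical; whatever load_model expects
    batches = []                       # supplied by the real data loader

    que = mp.Queue(maxsize=32)         # bounded so the loader cannot run far ahead
    lock = mp.Lock()
    status_que = mp.Queue()
    worker = mp.Process(target=predict_worker,
                        args=(0, 'predictions.txt', classes, model_params,
                              32, que, lock, status_que))
    worker.start()
    if status_que.get() != 'OK':       # wait until the model has loaded
        raise RuntimeError('predictor failed to initialise')
    feed_batches(que, batches)
    worker.join()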