Esempio n. 1
0
def calculate_loss(seg, recon, depth, normals):
    """Blend the per-task losses into a single equally weighted scalar.

    Each argument is unpacked as a ``(prediction, target)`` pair. A
    binary-cross-entropy-with-logits term is computed for every pair, and
    the segmentation pair additionally contributes Dice and Jaccard terms
    evaluated on the sigmoid of the segmentation prediction.

    Args:
        seg: ``(seg_pred, seg_target)`` pair.
        recon: ``(recon_pred, recon_target)`` pair.
        depth: ``(depth_pred, depth_target)`` pair.
        normals: ``(normals_pred, normals_target)`` pair.

    Returns:
        The result of ``NOD.dict_of(loss, terms)``: the combined loss
        together with the collection of individual terms.
    """
    seg_pred, seg_target = seg
    recon_pred, recon_target = recon
    depth_pred, depth_target = depth
    normals_pred, normals_target = normals

    bce_with_logits = torch.nn.functional.binary_cross_entropy_with_logits

    # NOTE(review): NOD.dict_of presumably derives its keys from the names of
    # the variables passed to it, so the names below are kept as-is -- confirm
    # before renaming any of them.
    seg_bce_loss = bce_with_logits(seg_pred, seg_target)
    ae_bce_loss = bce_with_logits(recon_pred, recon_target)
    depth_bce_loss = bce_with_logits(depth_pred, depth_target)
    normals_bce_loss = bce_with_logits(normals_pred, normals_target)

    # The overlap-based terms are fed sigmoid outputs rather than raw values.
    seg_prob = torch.sigmoid(seg_pred)
    dice = dice_loss(seg_prob, seg_target, epsilon=1)
    jaccard = jaccard_loss(seg_prob, seg_target, epsilon=1)

    terms = NOD.dict_of(
        dice,
        jaccard,
        ae_bce_loss,
        seg_bce_loss,
        depth_bce_loss,
        normals_bce_loss,
    )

    # Every term carries the same weight: 1 / number of terms.
    term_weight = 1 / len(terms)
    loss = sum(term.mean() * term_weight for term in terms.as_list())

    return NOD.dict_of(loss, terms)
Esempio n. 2
0
def test_model(model,
               data_iterator,
               latest_model_path,
               num_columns=2,
               device='cpu'):
    """Run ``model`` on one batch and render a classification report.

    A single ``(inputs, labels)`` batch is pulled from ``data_iterator``,
    the model is evaluated on it without gradient tracking, and the arg-max
    over the last axis of the output is taken as the predicted class.
    Accuracy and per-class / weighted precision, recall, F-score and support
    are computed, one thumbnail entry is produced per sample, a confusion
    matrix is plotted, and everything is handed to ``generate_html`` and
    ``generate_pdf`` under the file name ``classification_report``.

    Args:
        model: Model to evaluate; switched to eval mode and moved to
            ``device``.
        data_iterator: Iterator yielding ``(inputs, labels)`` batches; only
            the next batch is consumed.
        latest_model_path: Value reported as ``model_name`` in the report.
        num_columns: Number of sample entries per row of the prediction grid.
        device: Device on which the forward pass is executed.

    Returns:
        None. The report is produced as a side effect.
    """
    model = model.eval().to(device)

    inputs, labels = next(data_iterator)
    inputs = inputs.to(device)
    labels = labels.to(device)
    with torch.no_grad():
        pred = model(inputs)

    # NumPy and scikit-learn need host memory: Tensor.numpy() raises for a
    # tensor living on a non-CPU device, so copy to the CPU first.
    y_pred = pred.detach().cpu().numpy()
    truth_labels = labels.detach().cpu().numpy()
    y_pred_max = np.argmax(y_pred, axis=-1)

    accuracy_w = accuracy_score(truth_labels, y_pred_max)
    precision_a, recall_a, fscore_a, support_a = precision_recall_fscore_support(
        truth_labels, y_pred_max)
    precision_w, recall_w, fscore_w, support_w = precision_recall_fscore_support(
        truth_labels, y_pred_max, average='weighted')

    input_images_rgb = [a_retransform(x) for x in inputs]

    cell_width = (800 / num_columns) - 6 - 6 * 2

    # NOTE(review): this draws random lines onto the current figure before the
    # thumbnails are rendered; it looks like debugging residue. Kept because
    # the figure handling of plt_html is not visible here -- confirm and drop.
    plt.plot(np.random.random((3, 3)))

    alphabet = string.ascii_lowercase
    class_names = np.array([*alphabet])

    entries = []
    for i, (image, predicted_idx, truth_idx) in enumerate(
            zip(input_images_rgb, y_pred_max, truth_labels)):
        plt.imshow(image)
        outcome = 'tp' if predicted_idx == truth_idx else 'fn'
        entries.append(
            ReportEntry(name=i,
                        figure=plt_html(format='jpg',
                                        size=[cell_width, cell_width]),
                        prediction=class_names[predicted_idx],
                        truth=class_names[truth_idx],
                        outcome=outcome))

    # Chunk the entries into rows of num_columns. A batch size that is not a
    # multiple of num_columns yields a shorter final row instead of indexing
    # past a grid that was sized with floor division.
    predictions = [entries[start:start + num_columns]
                   for start in range(0, len(entries), num_columns)]

    plot_confusion_matrix(y_pred_max, truth_labels, class_names)

    # NOTE(review): NOD.dict_of presumably keys its entries by the names of
    # the variables passed in, so the local names below must stay unchanged.
    title = 'Classification Report'
    model_name = latest_model_path
    confusion_matrix = plt_html(format='png', size=[800, 800])

    # Raw strings keep the LaTeX backslashes literal without relying on
    # Python tolerating an unrecognised escape sequence.
    accuracy = generate_math_html(r'\dfrac{tp+tn}{N}'), None, accuracy_w
    precision = generate_math_html(
        r'\dfrac{tp}{tp+fp}'), precision_a, precision_w
    recall = generate_math_html(r'\dfrac{tp}{tp+fn}'), recall_a, recall_w
    f1_score = generate_math_html(
        r'2*\dfrac{precision*recall}{precision+recall}'), fscore_a, fscore_w
    support = generate_math_html('N_{class_truth}'), support_a, support_w
    metrics = NOD.dict_of(accuracy, precision, f1_score, recall,
                          support).as_flat_tuples()

    bundle = NOD.dict_of(title, model_name, confusion_matrix, metrics,
                         predictions)

    file_name = title.lower().replace(" ", "_")

    generate_html(file_name, **bundle)
    generate_pdf(file_name)