def calculate_loss(seg, recon, depth, normals):
    """Combine the segmentation, reconstruction, depth and normals losses.

    Each argument is a ``(prediction_logits, target)`` pair. Every head is
    scored with ``binary_cross_entropy_with_logits``. The segmentation head
    additionally gets Dice and Jaccard losses (``epsilon=1``) computed on its
    sigmoid probabilities. Each of the six terms is reduced with ``.mean()``
    and weighted equally (``1 / number_of_terms``).

    Returns:
        ``NOD`` holding ``loss`` (the equally weighted sum) and ``terms``
        (the unweighted individual loss terms).
    """
    bce = torch.nn.functional.binary_cross_entropy_with_logits

    seg_pred, seg_target = seg
    recon_pred, recon_target = recon
    depth_pred, depth_target = depth
    normals_pred, normals_target = normals

    seg_bce_loss = bce(seg_pred, seg_target)
    ae_bce_loss = bce(recon_pred, recon_target)
    normals_bce_loss = bce(normals_pred, normals_target)
    depth_bce_loss = bce(depth_pred, depth_target)

    # The overlap losses work on probabilities, not logits.
    seg_probabilities = torch.sigmoid(seg_pred)
    dice = dice_loss(seg_probabilities, seg_target, epsilon=1)
    jaccard = jaccard_loss(seg_probabilities, seg_target, epsilon=1)

    # NOTE(review): NOD.dict_of appears to key its entries by the variable
    # names written at the call site, so these names are part of the output
    # contract. Do not rename them or split this call across lines.
    terms = NOD.dict_of(dice, jaccard, ae_bce_loss, seg_bce_loss, depth_bce_loss, normals_bce_loss)

    share = 1 / len(terms)
    loss = sum(term.mean() * share for term in terms.as_list())
    return NOD.dict_of(loss, terms)
def test_model(model, data_iterator, latest_model_path, num_columns=2, device='cpu'):
    """Evaluate ``model`` on one batch and emit a classification report.

    Takes a single ``(inputs, labels)`` batch from ``data_iterator`` and runs
    the model in eval mode under ``torch.no_grad()``. It then computes
    accuracy plus per-class and weighted precision, recall, F1 and support
    (via ``accuracy_score`` / ``precision_recall_fscore_support``). One
    thumbnail per sample is rendered into a ``num_columns``-wide grid and a
    confusion matrix is plotted. Everything is passed to ``generate_html``
    and ``generate_pdf`` under the file name ``"classification_report"``.

    Args:
        model: module whose output's last axis holds the class scores
            (the predicted class is the argmax over axis -1).
        data_iterator: iterator yielding ``(inputs, labels)`` tensor pairs.
            Only the next batch is consumed.
        latest_model_path: passed to the report as ``model_name``.
        num_columns: number of thumbnails per row in the prediction grid.
        device: device the model and the batch are moved to.

    Returns:
        None. The report is produced through the project helpers
        ``generate_html`` / ``generate_pdf``.

    Raises:
        IndexError: if a class index is >= 26, because classes are labelled
            with the lowercase letters ``'a'``..``'z'``.
    """
    model = model.eval().to(device)
    inputs, labels = next(data_iterator)
    inputs = inputs.to(device)
    labels = labels.to(device)

    with torch.no_grad():
        pred = model(inputs)

    # ``Tensor.numpy()`` only works on CPU tensors. Copy to host first so this
    # also works when ``device`` is a GPU, and feed the metric functions the
    # host-side NumPy labels rather than the on-device tensor.
    y_pred = pred.detach().cpu().numpy()
    y_pred_max = np.argmax(y_pred, axis=-1)
    truth_labels = labels.detach().cpu().numpy()

    accuracy_w = accuracy_score(truth_labels, y_pred_max)
    precision_a, recall_a, fscore_a, support_a = precision_recall_fscore_support(truth_labels, y_pred_max)
    precision_w, recall_w, fscore_w, support_w = precision_recall_fscore_support(truth_labels, y_pred_max, average='weighted')

    # Thumbnails are drawn with matplotlib, which needs host-side data.
    input_images_rgb = [a_retransform(x) for x in inputs.cpu()]

    # Pixel width of one grid cell: an 800 px row split into columns, minus
    # (presumably) padding/border from the HTML template. TODO confirm.
    cell_width = (800 / num_columns) - 6 - 6 * 2

    # Class index k is labelled with the k-th lowercase letter.
    class_names = np.array([*string.ascii_lowercase])

    samples = len(y_pred)
    # Round up so a final, partially filled row still has slots. Flooring
    # raised IndexError whenever samples % num_columns != 0. Unused trailing
    # slots stay None.
    num_rows = (samples + num_columns - 1) // num_columns
    predictions = [[None] * num_columns for _ in range(num_rows)]

    for i, (image, predicted_class, true_class) in enumerate(zip(input_images_rgb, y_pred_max, truth_labels)):
        plt.imshow(image)
        # Correct predictions are tagged 'tp', incorrect ones 'fn'.
        outcome = 'tp' if predicted_class == true_class else 'fn'
        predictions[i // num_columns][i % num_columns] = ReportEntry(
            name=i,
            figure=plt_html(format='jpg', size=[cell_width, cell_width]),
            prediction=class_names[predicted_class],
            truth=class_names[true_class],
            outcome=outcome,
        )

    plot_confusion_matrix(y_pred_max, truth_labels, class_names)

    title = 'Classification Report'
    model_name = latest_model_path
    confusion_matrix = plt_html(format='png', size=[800, 800])

    # Each metric is (formula HTML, per-class values or None, overall/weighted value).
    # Raw strings: '\d' is not a valid escape sequence in a normal literal.
    accuracy = generate_math_html(r'\dfrac{tp+tn}{N}'), None, accuracy_w
    precision = generate_math_html(r'\dfrac{tp}{tp+fp}'), precision_a, precision_w
    recall = generate_math_html(r'\dfrac{tp}{tp+fn}'), recall_a, recall_w
    f1_score = generate_math_html(r'2*\dfrac{precision*recall}{precision+recall}'), fscore_a, fscore_w
    support = generate_math_html('N_{class_truth}'), support_a, support_w

    # NOTE(review): NOD.dict_of appears to key its entries by the variable
    # names written at the call site, so these names are part of the report's
    # data contract. Do not rename them or split these calls across lines.
    metrics = NOD.dict_of(accuracy, precision, f1_score, recall, support).as_flat_tuples()
    bundle = NOD.dict_of(title, model_name, confusion_matrix, metrics, predictions)

    file_name = title.lower().replace(" ", "_")
    generate_html(file_name, **bundle)
    generate_pdf(file_name)