def test_single_image(model, img, label_list, uncertainty=False, device=None):
    """Classify a single image, print the prediction and save a probability plot.

    The image is resized to 28x28, converted with ``transforms.ToTensor`` and
    passed through ``model`` as a batch of one. When ``uncertainty`` is True
    the evidential formulation is used: ``alpha = relu_evidence(output) + 1``,
    class probabilities are ``alpha / sum(alpha)`` and the uncertainty mass
    ``num_classes / sum(alpha)`` is printed as well. Otherwise a softmax over
    the model output is used.

    A figure with the input image (left) and a per-class probability bar chart
    (right) is saved to ``./results/test_image.jpg``.

    Args:
        model: Classifier; its output is reduced with ``torch.max(output, 1)``
            and normalised over dim 1 (presumably shape ``(1, num_classes)``).
        img: Image accepted by ``transforms.Resize`` (presumably a PIL image —
            confirm with callers).
        label_list: Mapping from display label to class index. The predicted
            index is looked up among its values to find the label to print.
        uncertainty: Use the evidential (Dirichlet) output interpretation.
        device: Device to run on; ``get_device()`` is used when ``None``.
    """
    if device is None:
        device = get_device()
    num_classes = 10

    trans = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor()])
    img_tensor = trans(img).unsqueeze(0).to(device)

    # Inference only: no autograd graph is needed, and this keeps printed
    # tensors free of grad_fn noise.
    with torch.no_grad():
        output = model(img_tensor)
        _, preds = torch.max(output, 1)
        if uncertainty:
            evidence = relu_evidence(output)
            alpha = evidence + 1
            # Separate name so the boolean flag `uncertainty` is not rebound.
            u = num_classes / torch.sum(alpha, dim=1, keepdim=True)
            prob = alpha / torch.sum(alpha, dim=1, keepdim=True)
        else:
            prob = F.softmax(output, dim=1)

    prob = prob.flatten()
    pred_idx = preds.flatten()[0].item()
    # Reverse dict lookup: the label whose class index equals the prediction.
    label = list(label_list.keys())[list(label_list.values()).index(pred_idx)]

    print("Predict:", label)
    print("Probs:", prob)
    if uncertainty:
        print("Uncertainty:", u)

    labels = list(label_list.keys())

    # figsize goes to subplots directly; a preceding plt.figure() call would
    # only create an extra, unused empty figure.
    fig, axs = plt.subplots(1, 2, figsize=[6.2, 5],
                            gridspec_kw={"width_ratios": [1, 3]})

    axs[1].set_title("Classified as: {}".format(label))

    axs[0].imshow(img, cmap="gray")
    axs[0].axis("off")

    axs[1].bar(labels, prob.cpu().numpy(), width=0.5)
    axs[1].set_xlim([0, 9])
    axs[1].set_ylim([0, 1])
    axs[1].set_xlabel("Classes")
    axs[1].set_ylabel("Classification Probability")

    fig.tight_layout()

    fig.savefig("./results/test_image.jpg")
def train_model(model, dataloaders, num_classes, criterion, optimizer,
                scheduler=None, num_epochs=25, device=None, uncertainty=False):
    """Train ``model`` and return it loaded with the best validation weights.

    Every epoch runs a ``"train"`` phase (forward, backward, optimizer step)
    followed by a ``"val"`` phase (forward only, gradients disabled). The
    state dict with the highest validation accuracy seen so far is kept and
    restored into ``model`` before returning.

    Args:
        model: The network to train (modified in place).
        dataloaders: Mapping with ``"train"`` and ``"val"`` iterables yielding
            ``(inputs, labels)`` batches. Each must expose ``.dataset`` so that
            per-sample averages can be computed.
        num_classes: Number of classes; forwarded to the evidential criterion.
        criterion: Loss. Called as ``criterion(outputs, labels)`` normally, and
            as ``criterion(outputs, one_hot, epoch, num_classes, 10, device)``
            when ``uncertainty`` is True.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Optional LR scheduler, stepped once per training epoch.
        num_epochs: Number of train/val epochs.
        device: Device to run on; ``get_device()`` is used when ``None``.
        uncertainty: Use one-hot targets and the evidential criterion call.

    Returns:
        ``(model, (losses, accuracy))``. ``losses`` has the keys ``"loss"``,
        ``"phase"`` and ``"epoch"``; ``accuracy`` has ``"accuracy"``,
        ``"phase"`` and ``"epoch"``. Each holds parallel lists with one entry
        per phase per epoch.
    """
    since = time.time()

    if device is None:
        device = get_device()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    losses = {"loss": [], "phase": [], "epoch": []}
    accuracy = {"accuracy": [], "phase": [], "epoch": []}

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs - 1))
        print("-" * 10)

        # Each epoch has a training and validation phase
        for phase in ["train", "val"]:
            if phase == "train":
                print("Training...")
                model.train()  # Set model to training mode
            else:
                print("Validating...")
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Track autograd history only while training.
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    if uncertainty:
                        y = one_hot_embedding(labels).to(device)
                        loss = criterion(
                            outputs, y.float(), epoch, num_classes, 10, device)
                    else:
                        loss = criterion(outputs, labels)

                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # Statistics: the loss is weighted by batch size, so the epoch
                # value is a per-sample mean.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            if scheduler is not None and phase == "train":
                scheduler.step()

            dataset_size = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / dataset_size
            epoch_acc = running_corrects / dataset_size

            losses["loss"].append(epoch_loss)
            losses["phase"].append(phase)
            losses["epoch"].append(epoch)
            accuracy["accuracy"].append(epoch_acc)
            accuracy["epoch"].append(epoch)
            accuracy["phase"].append(phase)

            print("{} loss: {:.4f} acc: {:.4f}".format(
                phase.capitalize(), epoch_loss, epoch_acc))

            # Keep a deep copy of the weights with the best validation accuracy.
            if phase == "val" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print("Training complete in {:.0f}m {:.0f}s".format(
        time_elapsed // 60, time_elapsed % 60))
    print("Best val Acc: {:.4f}".format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    metrics = (losses, accuracy)

    return model, metrics
def rotating_image_classification(model, img, filename, label_list, uncertainty=False, threshold=0.5, device=None):
    """Classify rotated copies of ``img`` and plot how predictions change.

    The image is rotated through ``Ndeg`` evenly spaced angles from 0 to
    ``Mdeg`` degrees using ``rotate_img``. Each rotation is classified
    separately. A three-panel figure is saved to ``filename``:

    * top: the rotated images side by side (drawn inverted, ``1 - rimgs``);
    * middle: a one-row table of the predicted class at each rotation;
    * bottom: probability vs. angle for every class whose probability reached
      ``threshold`` at least once, plus the uncertainty mass
      ``num_classes / sum(alpha)`` when ``uncertainty`` is True.

    Args:
        model: Classifier producing one row of class scores per input.
        img: Tensor whose ``.numpy()[0]`` is rotated (presumably shape
            ``(1, 28, 28)`` — confirm with callers).
        filename: Output path for the saved figure.
        label_list: Only printed; not otherwise used here.
        uncertainty: Use the evidential (Dirichlet) interpretation of outputs.
        threshold: Minimum probability for a class to be plotted.
        device: Device to run on; ``get_device()`` is used when ``None``.
    """
    print(label_list)
    if device is None:
        device = get_device()
    num_classes = 10
    Mdeg = 180
    Ndeg = int(Mdeg / 10) + 1
    ldeg = []
    lp = []
    lu = []
    classifications = []

    # Per-class count of rotations whose probability reached `threshold`.
    scores = np.zeros((1, num_classes))
    rimgs = np.zeros((28, 28 * Ndeg))
    trans = transforms.ToTensor()

    for i, deg in enumerate(np.linspace(0, Mdeg, Ndeg)):
        nimg = rotate_img(img.numpy()[0], deg).reshape(28, 28)
        nimg = np.clip(a=nimg, a_min=0, a_max=1)
        rimgs[:, i * 28:(i + 1) * 28] = nimg

        img_tensor = trans(nimg).unsqueeze(0).to(device)

        # Inference only; also keeps the collected values free of autograd.
        with torch.no_grad():
            output = model(img_tensor)
            _, preds = torch.max(output, 1)
            if uncertainty:
                alpha = relu_evidence(output) + 1
                # Local name `u` so the boolean flag `uncertainty` is not
                # overwritten by a tensor (it is tested again after the loop).
                u = num_classes / torch.sum(alpha, dim=1, keepdim=True)
                prob = alpha / torch.sum(alpha, dim=1, keepdim=True)
                # Plain float: matplotlib cannot plot CUDA / grad tensors.
                lu.append(u.mean().item())
            else:
                prob = F.softmax(output, dim=1)

        prob = prob.flatten()
        classifications.append(preds.flatten()[0].item())
        scores += prob.cpu().numpy() >= threshold
        ldeg.append(deg)
        lp.append(prob.tolist())

    labels = np.arange(num_classes)[scores[0].astype(bool)]
    lp = np.array(lp)[:, labels]
    c = ["black", "blue", "red", "brown", "purple", "cyan"]
    marker = ["s", "^", "o"] * 2
    labels = labels.tolist()

    # figsize goes to subplots directly; a separate plt.figure() call would
    # only leave an unused empty figure behind.
    fig, axs = plt.subplots(3, figsize=[6.2, 5], gridspec_kw={"height_ratios": [4, 1, 12]})

    for i in range(len(labels)):
        # Cycle styles so more classes than styles does not raise IndexError.
        axs[2].plot(ldeg, lp[:, i], marker=marker[i % len(marker)], c=c[i % len(c)])

    if uncertainty:
        labels += ["uncertainty"]
        axs[2].plot(ldeg, lu, marker="<", c="red")

    print(classifications)

    axs[0].set_title("Rotated Image Classifications")
    axs[0].imshow(1 - rimgs, cmap="gray")
    axs[0].axis("off")
    plt.pause(0.001)

    axs[1].table(cellText=[classifications], bbox=[0, 1.2, 1, 1])
    axs[1].axis("off")

    axs[2].legend(labels)
    axs[2].set_xlim([0, Mdeg])
    axs[2].set_ylim([0, 1])
    axs[2].set_xlabel("Rotation Degree")
    axs[2].set_ylabel("Classification Probability")

    fig.savefig(filename)
def zoom_image_classification(model, img, filename, uncertainty=False, threshold=0.2, device=None):
    """Classify progressively "zoomed" versions of ``img`` and plot the results.

    ``zoom_images`` step sizes are drawn from a Dirichlet distribution scaled
    to 28 pixels. Their running sum ``ss`` is passed to ``reduce_img`` together
    with a random point ``[x, y]``. What ``reduce_img`` does with these is
    defined elsewhere; presumably it crops or rescales around that point —
    confirm. Each resulting 28x28 image is classified. A three-panel figure is
    saved to ``filename``:

    * top: the generated images side by side (drawn inverted);
    * middle: a one-row table of the predicted class per image;
    * bottom: probability vs. ``ss`` for every class whose probability reached
      ``threshold`` at least once, plus the uncertainty mass
      ``num_classes / sum(alpha)`` when ``uncertainty`` is True.

    The random draws use NumPy's global RNG, so results vary between calls.

    Args:
        model: Classifier producing one row of class scores per input.
        img: Source image passed to ``reduce_img``.
        filename: Output path for the saved figure.
        uncertainty: Use the evidential (Dirichlet) interpretation of outputs.
        threshold: Minimum probability for a class to be plotted.
        device: Device to run on; ``get_device()`` is used when ``None``.
    """
    if device is None:
        device = get_device()
    num_classes = 10
    zoom_images = 5
    ldeg = []
    lp = []
    lu = []
    classifications = []

    # Per-class count of images whose probability reached `threshold`.
    scores = np.zeros((1, num_classes))
    rimgs = np.zeros((28, 28 * zoom_images))
    image_sizes = (np.random.dirichlet([25, 10, 10, 10, 10]) * 28).astype(int)
    ss = 0
    x = int(np.random.rand() * 28)
    y = int(np.random.rand() * 28)
    trans = transforms.ToTensor()

    for i, image_size in enumerate(image_sizes):
        ss += image_size
        nimg = reduce_img(img, ss, [x, y])
        nimg = np.clip(a=nimg, a_min=0, a_max=1)
        rimgs[:, i * 28:(i + 1) * 28] = nimg

        img_tensor = trans(nimg).unsqueeze(0).to(device)

        # Inference only; also keeps the collected values free of autograd.
        with torch.no_grad():
            output = model(img_tensor)
            _, preds = torch.max(output, 1)
            if uncertainty:
                alpha = relu_evidence(output) + 1
                # Local name `u` so the boolean flag `uncertainty` is not
                # overwritten by a tensor (it is tested again after the loop).
                u = num_classes / torch.sum(alpha, dim=1, keepdim=True)
                prob = alpha / torch.sum(alpha, dim=1, keepdim=True)
                # Plain float: matplotlib cannot plot CUDA / grad tensors.
                lu.append(u.mean().item())
            else:
                prob = F.softmax(output, dim=1)

        prob = prob.flatten()
        classifications.append(preds.flatten()[0].item())
        scores += prob.cpu().numpy() >= threshold
        ldeg.append(ss)
        lp.append(prob.tolist())

    labels = np.arange(num_classes)[scores[0].astype(bool)]
    lp = np.array(lp)[:, labels]
    c = ["blue", "red", "brown", "purple", "cyan"]
    marker = ["s", "^", "o"] * 2
    labels = labels.tolist()

    # figsize goes to subplots directly; a separate plt.figure() call would
    # only leave an unused empty figure behind.
    fig, axs = plt.subplots(3, figsize=[6, 5], gridspec_kw={"height_ratios": [5, 1, 12]})

    for i in range(len(labels)):
        # Cycle styles: `c` has only 5 entries, so a 6th class used to raise.
        axs[2].plot(ldeg, lp[:, i], marker=marker[i % len(marker)], c=c[i % len(c)])

    if uncertainty:
        labels += ["uncertainty"]
        axs[2].plot(ldeg, lu, marker="<", c="black")

    axs[0].imshow(1 - rimgs, cmap="gray")
    axs[0].axis("off")

    axs[1].table(cellText=[classifications], bbox=[0, 1, 1, 1])
    axs[1].axis("off")

    axs[2].legend(labels)
    axs[2].set_xlim([8, 28])
    axs[2].set_ylim([0, 1])
    axs[2].set_xlabel("Zoom pixels")
    axs[2].set_ylabel("Classification Probability")

    fig.savefig(filename)