def criterion(logits_list, labels_list, img_size, n_class):
    """Compute the multi-scale segmentation loss and accuracy metrics.

    Args:
        logits_list: 6 logit tensors, one per scale; scale ``i`` has spatial
            size ``img_size // 2**i`` (full resolution down to 1/32).
        labels_list: 6 matching label pairs; ``label[0]`` is the loss target,
            ``label[1]`` the ground-truth class map used for the confusion
            matrix. (Assumed from the original indexing — TODO confirm.)
        img_size: side length of the full-resolution prediction.
        n_class: number of segmentation classes.

    Returns:
        (total_loss, meanIU, pixelAccuracy, meanAccuracy, classAccuracy),
        where the metrics come from ``calculate_Accuracy`` over the confusion
        matrix accumulated across all six scales.
    """
    total_loss = None
    confusion = np.zeros([n_class, n_class])

    for scale_idx, (logits, label) in enumerate(zip(logits_list, labels_list)):
        # Log-softmax over the class dimension; EPS guards log(0).
        log_prob = torch.log(softmax_2d(logits) + EPS)

        loss = lossfunc(log_prob, label[0])
        total_loss = loss if total_loss is None else total_loss + loss

        # Fix: use integer division — the original `img_size / 2**i` produces
        # floats on Python 3 and breaks reshape.
        side = img_size // (2 ** scale_idx)
        pred = np.argmax(log_prob.cpu().data.numpy(), 1).reshape((side, side))

        gt = np.asarray(label[1]).reshape(-1).astype(np.intp)
        pr = pred.reshape(-1).astype(np.intp)
        # Vectorized confusion-matrix accumulation; replaces the original
        # per-pixel Python loop (which also used Python-2-only `xrange`).
        np.add.at(confusion, (gt, pr), 1)

    meanIU, pixelAccuracy, meanAccuracy, classAccuracy = calculate_Accuracy(
        confusion)
    return total_loss, meanIU, pixelAccuracy, meanAccuracy, classAccuracy
# NOTE(review): this fragment is the per-epoch bookkeeping of a training loop
# whose `for epoch ...` header lies outside this chunk; the nesting below is
# reconstructed from the mangled source — confirm against the full file.

# Re-shuffle the mini-batch order each epoch.
np.random.shuffle(batch_idx)

# Step-decay the learning rates every 10 epochs (skipping epoch 0) and
# rebuild the SGD optimizer so the new rates take effect. `base_params`
# and `pre_params` (pretrained backbone, presumably) get separate LRs.
if epoch % 10 == 0 and epoch != 0:
    args.lr /= 10
    args.pre_lr /= 10
    optimizer = torch.optim.SGD([{
        'params': base_params
    }, {
        'params': pre_params,
        'lr': args.pre_lr
    }],
                                lr=args.lr,
                                momentum=0.9)

# Checkpoint the model and report metrics for the finished epoch.
if epoch != 0:
    torch.save(model.state_dict(), './models/%s.pth' % epoch)
    # `confusion` was accumulated during the epoch (elsewhere in the file).
    meanIU, pixelAccuracy, meanAccuracy, classAccuracy = calculate_Accuracy(
        confusion)
    # print '=========================== epoch : %s ==========================='%epoch
    # print('meanIOU: {:.2f} | pixelAccuracy: {:.2f} | meanAccuracy: {:.2f} | classAccuracy: {:.2f} |'.format(meanIU,
    #                                                                                                        pixelAccuracy,
    #                                                                                                        meanAccuracy,
    #                                                                                                        classAccuracy))
    # with open('./logs/log.txt', 'a+') as f:
    #     log_str = "%d\t\t%.4f\t%.4f\t%.4f\t%.4f\t" % (epoch, meanIU, pixelAccuracy, meanAccuracy, classAccuracy)
    #     f.writelines(str(log_str) + '\n')
    #
    # with open('./logs/confusion.pkl', 'a+') as f:
    #     pkl.dump(confusion, f)
    # print 'save the results success'

# Reset the confusion matrix for the next epoch.
confusion = np.zeros([args.n_class, args.n_class])