def test(model, loader, num_class=40):
    """Evaluate a classifier on ``loader`` and return its accuracies.

    Each batch is optionally rotated according to the module-level
    ``args.rot`` setting ('z' -> random rotation about the Z axis,
    'so3' -> random 3D rotation, anything else -> no rotation), moved to the
    GPU and classified.

    Args:
        model: network called as ``model(points)``; the first element of its
            return value is used as per-class scores.
        loader: iterable of ``(points, target)`` batches that supports
            ``len()``. Assumes ``points`` is (batch, n_points, channels) and
            ``target`` is (batch, 1) -- TODO confirm against the dataset.
        num_class: number of classes used to size the per-class accumulator.

    Returns:
        ``(instance_acc, class_acc)``: the mean of the per-batch accuracies,
        and the mean over classes of each class's average per-batch accuracy.
        Classes that never occur in ``loader`` are left out of the class
        average instead of turning it into NaN through a 0/0 division.
    """
    mean_correct = []
    # Column 0: sum of per-batch accuracies, column 1: number of batches in
    # which the class occurred, column 2: their ratio (filled in at the end).
    class_acc = np.zeros((num_class, 3))
    # Eval mode does not change between batches, so switch once up front.
    classifier = model.eval()

    for points, target in tqdm(loader, total=len(loader)):
        trot = None
        if args.rot == 'z':
            trot = RotateAxisAngle(angle=torch.rand(points.shape[0]) * 360, axis="Z", degrees=True)
        elif args.rot == 'so3':
            trot = Rotate(R=random_rotations(points.shape[0]))
        if trot is not None:
            points = trot.transform_points(points)

        target = target[:, 0]
        points = points.transpose(2, 1)
        points, target = points.cuda(), target.cuda()

        pred, _ = classifier(points)
        pred_choice = pred.data.max(1)[1]

        for cat in np.unique(target.cpu()):
            cat_mask = target == cat
            classacc = pred_choice[cat_mask].eq(target[cat_mask].long().data).cpu().sum()
            class_acc[cat, 0] += classacc.item() / float(points[cat_mask].size()[0])
            class_acc[cat, 1] += 1

        correct = pred_choice.eq(target.long().data).cpu().sum()
        mean_correct.append(correct.item() / float(points.size()[0]))

    # Only classes that were actually observed have a defined accuracy.
    seen = class_acc[:, 1] > 0
    class_acc[seen, 2] = class_acc[seen, 0] / class_acc[seen, 1]
    class_acc = np.mean(class_acc[seen, 2])
    instance_acc = np.mean(mean_correct)
    return instance_acc, class_acc
def test_specular_point_lights(self):
    """
    Specular term of a point light, checked for three camera placements:
    the initial camera position, a camera behind the point, and a camera
    rotated 30 degrees away from the reflection direction.
    """
    inv_sqrt2 = 1 / np.sqrt(2)
    light_color = torch.tensor([1, 0, 1], dtype=torch.float32)
    light_location = torch.tensor([-1, 1, 0], dtype=torch.float32)
    surface_point = torch.tensor([0, 0, 0], dtype=torch.float32)
    surface_normal = torch.tensor([0, 1, 0], dtype=torch.float32)

    lights = PointLights(
        specular_color=light_color[None, :], location=light_location[None, :]
    )

    def specular_at(camera):
        # Same point, normal and shininess for every case; only the camera moves.
        return lights.specular(
            points=surface_point[None, None, :],
            normals=surface_normal[None, None, :],
            camera_position=camera[None, :],
            shininess=torch.tensor(10),
        )

    # Case 1: initial camera position -> full specular colour expected.
    camera = torch.tensor([+inv_sqrt2, inv_sqrt2, 0], dtype=torch.float32)
    full_highlight = torch.tensor([1.0, 0.0, 1.0], dtype=torch.float32)
    full_highlight = full_highlight.view(-1, 1, 3)
    result = specular_at(camera)
    self.assertTrue(torch.allclose(result, full_highlight))

    # Case 2: camera behind the point -> no specular contribution expected.
    camera = torch.tensor([+inv_sqrt2, -inv_sqrt2, 0], dtype=torch.float32)
    no_highlight = torch.zeros_like(full_highlight)
    result = specular_at(camera)
    self.assertTrue(torch.allclose(result, no_highlight))

    # Case 3: camera direction 30 degrees from the reflection direction.
    camera = torch.tensor([+inv_sqrt2, inv_sqrt2, 0], dtype=torch.float32)
    camera = RotateAxisAngle(-30, axis="z").transform_points(camera[None, :])
    cos_30 = np.cos(30.0 * np.pi / 180)
    falloff_base = torch.tensor([cos_30, 0.0, cos_30], dtype=torch.float32)
    falloff_base = falloff_base.view(-1, 1, 3)
    result = specular_at(camera)
    # The expected value is the cosine raised to the shininess exponent (10).
    self.assertTrue(torch.allclose(result, falloff_base ** 10))
def rotate_mesh_around_axis(
    mesh, rot: list, renderer, dist: float = 3.5, save: str = "", device: str = ""
):
    """Rotate a mesh about the X, Y and Z axes, render it, and optionally save it.

    Args:
        mesh: mesh container; only the first mesh (index 0) is read and its
            padded vertices are replaced by the rotated ones.
        rot: three rotation angles ``[x, y, z]``, applied in that order.
            Presumably degrees (RotateAxisAngle's default) -- confirm.
        renderer: callable invoked as
            ``renderer(meshes_world=..., R=..., T=..., device=...)``.
        dist: camera distance passed to ``look_at_view_transform``; elevation
            and azimuth are fixed at 0.
        save: if non-empty, path the rotated mesh is written to with
            ``save_obj``.
        device: device for the transforms and the camera; when empty, falls
            back to ``torch.cuda.current_device()``.

    Returns:
        The mesh with rotated vertices.
    """
    if not device:
        device = torch.cuda.current_device()

    # Index explicitly so that a too-short ``rot`` still raises IndexError.
    angles = (rot[0], rot[1], rot[2])

    verts, _ = mesh.get_mesh_verts_faces(0)
    for angle, axis in zip(angles, ("X", "Y", "Z")):
        verts = RotateAxisAngle(angle, axis, device=device).transform_points(verts)
    mesh = mesh.update_padded(verts.unsqueeze(0))

    R, T = look_at_view_transform(dist=dist, elev=0, azim=0, device=device)
    image_ref = renderer(meshes_world=mesh, R=R, T=T, device=device)
    # Keep the first three channels only for display.
    image_ref = image_ref.cpu().numpy()[..., :3]

    plt.imshow(image_ref.squeeze())
    plt.show()

    if save:
        verts, faces = mesh.get_mesh_verts_faces(0)
        save_obj(save, verts, faces)
    return mesh
def main(args):
    """Train a part-segmentation model and checkpoint the best epoch.

    Creates ``./log/partseg/<log_dir or timestamp>/{checkpoints,logs}``, sets up
    file logging, builds the train/test ``PartNormalDataset`` loaders, imports
    the model module named by ``args.model`` and trains it for ``args.epoch``
    epochs. After every epoch the test split is evaluated (accuracy, per-class
    accuracy, class-average and instance-average mIoU) and the model is saved
    to ``checkpoints/best_model.pth`` whenever the instance-average mIoU is at
    least as good as the best seen so far.

    Args:
        args: parsed command-line namespace. Attributes read here: ``gpu``,
            ``log_dir``, ``model``, ``npoint``, ``normal``, ``batch_size``,
            ``optimizer``, ``learning_rate``, ``decay_rate``, ``lr_decay``,
            ``step_size``, ``epoch`` and ``rot``.
    """
    def log_string(message):
        logger.info(message)
        print(message)

    # ---- Hyper parameters ----
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # ---- Create directories ----
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    run_name = timestr if args.log_dir is None else args.log_dir
    experiment_dir = Path('./log/').joinpath('partseg').joinpath(run_name)
    experiment_dir.mkdir(parents=True, exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(parents=True, exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(parents=True, exist_ok=True)

    # ---- Logging ----
    # NOTE: the caller's ``args`` is used throughout; it is no longer
    # overwritten by a second parse_args() call halfway through setup.
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    # ---- Data ----
    root = 'data/shapenetcore_partanno_segmentation_benchmark_v0_normal/'

    TRAIN_DATASET = PartNormalDataset(root=root, npoints=args.npoint, split='trainval', normal_channel=args.normal)
    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=4)
    TEST_DATASET = PartNormalDataset(root=root, npoints=args.npoint, split='test', normal_channel=args.normal)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4)
    log_string("The number of training data is: %d" % len(TRAIN_DATASET))
    log_string("The number of test data is: %d" % len(TEST_DATASET))
    num_classes = 16
    num_part = 50

    # ---- Model ----
    MODEL = importlib.import_module(args.model)
    classifier = MODEL.get_model(args, num_part, normal_channel=args.normal).cuda()
    criterion = MODEL.get_loss().cuda()

    def weights_init(m):
        # NOTE: not applied by default; enable with classifier.apply(weights_init).
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            torch.nn.init.xavier_normal_(m.weight.data)
            torch.nn.init.constant_(m.bias.data, 0.0)
        elif classname.find('Linear') != -1:
            torch.nn.init.xavier_normal_(m.weight.data)
            torch.nn.init.constant_(m.bias.data, 0.0)

    # Resuming is best-effort: any failure falls back to training from
    # scratch, but unexpected failures are reported instead of being hidden.
    try:
        checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except FileNotFoundError:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0
    except Exception as err:
        log_string('Could not load existing model (%r), starting training from scratch...' % (err,))
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        optimizer = torch.optim.SGD(classifier.parameters(), lr=args.learning_rate, momentum=0.9)

    def bn_momentum_adjust(m, momentum):
        if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d):
            m.momentum = momentum

    LEARNING_RATE_CLIP = 1e-5
    MOMENTUM_ORIGINAL = 0.1
    MOMENTUM_DECCAY = 0.5
    MOMENTUM_DECCAY_STEP = args.step_size

    best_acc = 0
    global_epoch = 0
    best_class_avg_iou = 0
    best_inctance_avg_iou = 0

    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))

        # ---- Adjust learning rate and BN momentum (both step-decayed) ----
        lr = max(args.learning_rate * (args.lr_decay ** (epoch // args.step_size)), LEARNING_RATE_CLIP)
        log_string('Learning rate:%f' % lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        mean_correct = []
        momentum = MOMENTUM_ORIGINAL * (MOMENTUM_DECCAY ** (epoch // MOMENTUM_DECCAY_STEP))
        if momentum < 0.01:
            momentum = 0.01
        print('BN momentum updated to: %f' % momentum)
        classifier = classifier.apply(lambda x: bn_momentum_adjust(x, momentum))

        # ---- Train one epoch ----
        for i, data in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9):
            points, label, target = data

            trot = None
            if args.rot == 'z':
                trot = RotateAxisAngle(angle=torch.rand(points.shape[0]) * 360, axis="Z", degrees=True)
            elif args.rot == 'so3':
                trot = Rotate(R=random_rotations(points.shape[0]))
            if trot is not None:
                points = trot.transform_points(points)

            # Augmentation works on numpy arrays and only touches xyz.
            points = points.data.numpy()
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
            points = points.transpose(2, 1)

            optimizer.zero_grad()
            classifier = classifier.train()
            seg_pred, trans_feat = classifier(points, to_categorical(label, num_classes))
            seg_pred = seg_pred.contiguous().view(-1, num_part)
            target = target.view(-1, 1)[:, 0]
            pred_choice = seg_pred.data.max(1)[1]
            correct = pred_choice.eq(target.data).cpu().sum()
            # Divide by the number of points actually in this batch; the last
            # batch can be smaller than args.batch_size.
            mean_correct.append(correct.item() / float(target.numel()))
            loss = criterion(seg_pred, target, trans_feat)
            loss.backward()
            optimizer.step()

        train_instance_acc = np.mean(mean_correct)
        log_string('Train accuracy is: %.5f' % train_instance_acc)

        # ---- Evaluate on the test split ----
        with torch.no_grad():
            test_metrics = {}
            total_correct = 0
            total_seen = 0
            total_seen_class = [0 for _ in range(num_part)]
            total_correct_class = [0 for _ in range(num_part)]
            shape_ious = {cat: [] for cat in seg_classes.keys()}
            seg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}
            for cat in seg_classes.keys():
                for part_label in seg_classes[cat]:
                    seg_label_to_cat[part_label] = cat

            for batch_id, (points, label, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
                cur_batch_size, NUM_POINT, _ = points.size()

                trot = None
                if args.rot == 'z':
                    trot = RotateAxisAngle(angle=torch.rand(points.shape[0]) * 360, axis="Z", degrees=True)
                elif args.rot == 'so3':
                    trot = Rotate(R=random_rotations(points.shape[0]))
                if trot is not None:
                    points = trot.transform_points(points)

                points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
                points = points.transpose(2, 1)
                classifier = classifier.eval()
                seg_pred, _ = classifier(points, to_categorical(label, num_classes))
                cur_pred_val_logits = seg_pred.cpu().data.numpy()
                cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32)
                target = target.cpu().data.numpy()

                # Restrict the argmax to the parts of the shape's own category
                # (taken from the first point's ground-truth label).
                for i in range(cur_batch_size):
                    cat = seg_label_to_cat[target[i, 0]]
                    logits = cur_pred_val_logits[i, :, :]
                    cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]

                correct = np.sum(cur_pred_val == target)
                total_correct += correct
                total_seen += (cur_batch_size * NUM_POINT)

                for l in range(num_part):
                    total_seen_class[l] += np.sum(target == l)
                    total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l)))

                for i in range(cur_batch_size):
                    segp = cur_pred_val[i, :]
                    segl = target[i, :]
                    cat = seg_label_to_cat[segl[0]]
                    part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
                    for l in seg_classes[cat]:
                        if (np.sum(segl == l) == 0) and (
                                np.sum(segp == l) == 0):  # part is not present, no prediction as well
                            part_ious[l - seg_classes[cat][0]] = 1.0
                        else:
                            part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float(
                                np.sum((segl == l) | (segp == l)))
                    shape_ious[cat].append(np.mean(part_ious))

            all_shape_ious = []
            for cat in shape_ious.keys():
                for iou in shape_ious[cat]:
                    all_shape_ious.append(iou)
                shape_ious[cat] = np.mean(shape_ious[cat])
            mean_shape_ious = np.mean(list(shape_ious.values()))
            test_metrics['accuracy'] = total_correct / float(total_seen)
            # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
            test_metrics['class_avg_accuracy'] = np.mean(
                np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64))
            for cat in sorted(shape_ious.keys()):
                log_string('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat]))
            test_metrics['class_avg_iou'] = mean_shape_ious
            test_metrics['inctance_avg_iou'] = np.mean(all_shape_ious)

        log_string('Epoch %d test Accuracy: %f Class avg mIOU: %f Inctance avg mIOU: %f' % (
            epoch + 1, test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['inctance_avg_iou']))

        # ---- Checkpoint when instance-average mIoU is at least the best so far ----
        if (test_metrics['inctance_avg_iou'] >= best_inctance_avg_iou):
            logger.info('Save model...')
            savepath = str(checkpoints_dir) + '/best_model.pth'
            log_string('Saving at %s' % savepath)
            state = {
                'epoch': epoch,
                'train_acc': train_instance_acc,
                'test_acc': test_metrics['accuracy'],
                'class_avg_iou': test_metrics['class_avg_iou'],
                'inctance_avg_iou': test_metrics['inctance_avg_iou'],
                'model_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(state, savepath)
            log_string('Saving model....')

        if test_metrics['accuracy'] > best_acc:
            best_acc = test_metrics['accuracy']
        if test_metrics['class_avg_iou'] > best_class_avg_iou:
            best_class_avg_iou = test_metrics['class_avg_iou']
        if test_metrics['inctance_avg_iou'] > best_inctance_avg_iou:
            best_inctance_avg_iou = test_metrics['inctance_avg_iou']
        log_string('Best accuracy is: %.5f' % best_acc)
        log_string('Best class avg mIOU is: %.5f' % best_class_avg_iou)
        log_string('Best inctance avg mIOU is: %.5f' % best_inctance_avg_iou)
        global_epoch += 1
def main(args):
    """Evaluate a trained part-segmentation model on the test split.

    Loads ``log/partseg/<args.log_dir>/checkpoints/best_model.pth`` into the
    model named by ``args.model``, runs it over the test ``PartNormalDataset``
    (averaging the network output over ``args.num_votes`` forward passes per
    batch) and logs accuracy, class-average accuracy, per-category mIoU,
    class-average mIoU and instance-average mIoU to ``eval.txt`` in the
    experiment directory and to stdout.

    Args:
        args: parsed command-line namespace. Attributes read here: ``gpu``,
            ``log_dir``, ``model``, ``num_point``, ``normal``, ``batch_size``,
            ``num_votes`` and ``rot``.
    """
    def log_string(message):
        logger.info(message)
        print(message)

    # ---- Hyper parameters ----
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    experiment_dir = 'log/partseg/' + args.log_dir

    # ---- Logging ----
    # NOTE: the caller's ``args`` is used throughout; it is no longer
    # overwritten by a second parse_args() call after experiment_dir is built.
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    # ---- Data ----
    root = 'data/shapenetcore_partanno_segmentation_benchmark_v0_normal/'

    TEST_DATASET = PartNormalDataset(root=root, npoints=args.num_point, split='test', normal_channel=args.normal)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4)
    log_string("The number of test data is: %d" % len(TEST_DATASET))
    num_classes = 16
    num_part = 50

    # ---- Model ----
    MODEL = importlib.import_module(args.model)
    classifier = MODEL.get_model(args, num_part, normal_channel=args.normal).cuda()
    checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
    classifier.load_state_dict(checkpoint['model_state_dict'])
    # Eval mode is loop-invariant, so switch once before iterating.
    classifier = classifier.eval()

    with torch.no_grad():
        test_metrics = {}
        total_correct = 0
        total_seen = 0
        total_seen_class = [0 for _ in range(num_part)]
        total_correct_class = [0 for _ in range(num_part)]
        shape_ious = {cat: [] for cat in seg_classes.keys()}
        seg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}
        for cat in seg_classes.keys():
            for part_label in seg_classes[cat]:
                seg_label_to_cat[part_label] = cat

        for batch_id, (points, label, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
            cur_batch_size, NUM_POINT, _ = points.size()

            # Start from "no rotation" so that any other args.rot value skips
            # the transform instead of hitting an unbound ``trot``, and so the
            # 'z' rotation is applied just like the 'so3' one.
            trot = None
            if args.rot == 'z':
                trot = RotateAxisAngle(angle=torch.rand(points.shape[0]) * 360, axis="Z", degrees=True)
            elif args.rot == 'so3':
                trot = Rotate(R=random_rotations(points.shape[0]))
            if trot is not None:
                points = trot.transform_points(points)

            points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
            points = points.transpose(2, 1)

            # Average the network output over several forward passes.
            vote_pool = torch.zeros(target.size()[0], target.size()[1], num_part).cuda()
            for _ in range(args.num_votes):
                seg_pred, _ = classifier(points, to_categorical(label, num_classes))
                vote_pool += seg_pred
            seg_pred = vote_pool / args.num_votes

            cur_pred_val_logits = seg_pred.cpu().data.numpy()
            cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32)
            target = target.cpu().data.numpy()

            # Restrict the argmax to the parts of the shape's own category
            # (taken from the first point's ground-truth label).
            for i in range(cur_batch_size):
                cat = seg_label_to_cat[target[i, 0]]
                logits = cur_pred_val_logits[i, :, :]
                cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]

            correct = np.sum(cur_pred_val == target)
            total_correct += correct
            total_seen += (cur_batch_size * NUM_POINT)

            for l in range(num_part):
                total_seen_class[l] += np.sum(target == l)
                total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l)))

            for i in range(cur_batch_size):
                segp = cur_pred_val[i, :]
                segl = target[i, :]
                cat = seg_label_to_cat[segl[0]]
                part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
                for l in seg_classes[cat]:
                    if (np.sum(segl == l) == 0) and (
                            np.sum(segp == l) == 0):  # part is not present, no prediction as well
                        part_ious[l - seg_classes[cat][0]] = 1.0
                    else:
                        part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float(
                            np.sum((segl == l) | (segp == l)))
                shape_ious[cat].append(np.mean(part_ious))

        all_shape_ious = []
        for cat in shape_ious.keys():
            for iou in shape_ious[cat]:
                all_shape_ious.append(iou)
            shape_ious[cat] = np.mean(shape_ious[cat])
        mean_shape_ious = np.mean(list(shape_ious.values()))
        test_metrics['accuracy'] = total_correct / float(total_seen)
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        test_metrics['class_avg_accuracy'] = np.mean(
            np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64))
        for cat in sorted(shape_ious.keys()):
            log_string('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat]))
        test_metrics['class_avg_iou'] = mean_shape_ious
        test_metrics['inctance_avg_iou'] = np.mean(all_shape_ious)

    log_string('Accuracy is: %.5f' % test_metrics['accuracy'])
    log_string('Class avg accuracy is: %.5f' % test_metrics['class_avg_accuracy'])
    log_string('Class avg mIOU is: %.5f' % test_metrics['class_avg_iou'])
    log_string('Inctance avg mIOU is: %.5f' % test_metrics['inctance_avg_iou'])
def main(args):
    """Train a point-cloud classifier and checkpoint the best epoch.

    Creates ``./log/cls/<log_dir or timestamp>/{checkpoints,logs}``, sets up
    file logging, builds the train/test ``ModelNetDataLoader`` loaders, imports
    the model module named by ``args.model`` and trains it for ``args.epoch``
    epochs. After every epoch the model is evaluated with ``test`` and saved to
    ``checkpoints/best_model.pth`` whenever the instance accuracy is at least
    as good as the best seen so far.

    Args:
        args: parsed command-line namespace. Attributes read here: ``gpu``,
            ``log_dir``, ``model``, ``num_point``, ``normal``, ``batch_size``,
            ``optimizer``, ``learning_rate``, ``decay_rate``, ``epoch`` and
            ``rot``.
    """
    def log_string(message):
        logger.info(message)
        print(message)

    # ---- Hyper parameters ----
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # ---- Create directories ----
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    run_name = timestr if args.log_dir is None else args.log_dir
    experiment_dir = Path('./log/').joinpath('cls').joinpath(run_name)
    experiment_dir.mkdir(parents=True, exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(parents=True, exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(parents=True, exist_ok=True)

    # ---- Logging ----
    # NOTE: the caller's ``args`` is used throughout; it is no longer
    # overwritten by a second parse_args() call halfway through setup.
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    # ---- Data loading ----
    log_string('Load dataset ...')
    DATA_PATH = 'data/modelnet40_normal_resampled/'

    TRAIN_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='train',
                                       normal_channel=args.normal)
    TEST_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='test',
                                      normal_channel=args.normal)
    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=4)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # ---- Model loading ----
    num_class = 40
    MODEL = importlib.import_module(args.model)

    classifier = MODEL.get_model(args, num_class, normal_channel=args.normal).cuda()
    criterion = MODEL.get_loss().cuda()

    # Resuming is best-effort: any failure falls back to training from
    # scratch, but unexpected failures are reported instead of being hidden.
    try:
        checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except FileNotFoundError:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0
    except Exception as err:
        log_string('Could not load existing model (%r), starting training from scratch...' % (err,))
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        optimizer = torch.optim.SGD(
            classifier.parameters(),
            lr=args.learning_rate * 100,
            momentum=0.9,
            weight_decay=args.decay_rate
        )

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
    global_epoch = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0

    # ---- Training ----
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))

        # Reset per epoch so the logged train accuracy describes this epoch
        # only rather than a running average over all epochs so far.
        mean_correct = []
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9):
            points, target = data

            trot = None
            if args.rot == 'z':
                trot = RotateAxisAngle(angle=torch.rand(points.shape[0]) * 360, axis="Z", degrees=True)
            elif args.rot == 'so3':
                trot = Rotate(R=random_rotations(points.shape[0]))
            if trot is not None:
                points = trot.transform_points(points)

            # Augmentation works on numpy arrays; scale/shift only touch xyz.
            points = points.data.numpy()
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            target = target[:, 0]

            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()

            classifier = classifier.train()
            pred, trans_feat = classifier(points)
            loss = criterion(pred, target.long(), trans_feat)
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.long().data).cpu().sum()
            mean_correct.append(correct.item() / float(points.size()[0]))
            loss.backward()
            optimizer.step()

        # Step the LR schedule after the epoch's optimizer updates, which is
        # the order PyTorch documents; stepping first skips the initial value.
        scheduler.step()

        train_instance_acc = np.mean(mean_correct)
        log_string('Train Instance Accuracy: %f' % train_instance_acc)

        with torch.no_grad():
            instance_acc, class_acc = test(classifier.eval(), testDataLoader)

            # Decide before updating the running best so the save check below
            # does not depend on the already-updated value.
            is_best = instance_acc >= best_instance_acc
            if is_best:
                best_instance_acc = instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
            log_string('Best Instance Accuracy: %f, Class Accuracy: %f' % (best_instance_acc, best_class_acc))

            if is_best:
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'instance_acc': instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')