def train():
    """Train or evaluate FFB6D on one distributed (NCCL) process.

    Reads configuration from the module-level ``args`` / ``config`` objects.
    When ``args.eval_net`` is set, runs a single test-set evaluation pass;
    otherwise runs the full training loop with a cyclic LR schedule and a
    BN-momentum schedule, resuming from ``args.checkpoint`` when given.
    """
    print("local_rank:", args.local_rank)
    cudnn.benchmark = True
    if args.deterministic:
        # Trade speed for reproducibility: disable autotuner, fix the seed.
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='env://',
    )
    # NOTE(review): this re-seeds with 0 on every rank, overriding the
    # per-rank seed set in the deterministic branch above — confirm intended.
    torch.manual_seed(0)

    if not args.eval_net:
        train_ds = dataset_desc.Dataset('train')
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds)
        # shuffle=False because the DistributedSampler already shuffles.
        train_loader = torch.utils.data.DataLoader(
            train_ds, batch_size=config.mini_batch_size, shuffle=False,
            drop_last=True, num_workers=4, sampler=train_sampler,
            pin_memory=True
        )

        val_ds = dataset_desc.Dataset('test')
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_ds)
        val_loader = torch.utils.data.DataLoader(
            val_ds, batch_size=config.val_mini_batch_size, shuffle=False,
            drop_last=False, num_workers=4, sampler=val_sampler
        )
    else:
        test_ds = dataset_desc.Dataset('test')
        test_loader = torch.utils.data.DataLoader(
            test_ds, batch_size=config.test_mini_batch_size, shuffle=False,
            num_workers=20
        )

    rndla_cfg = ConfigRandLA
    model = FFB6D(
        n_classes=config.n_objects, n_pts=config.n_sample_points,
        rndla_cfg=rndla_cfg, n_kps=config.n_keypoints
    )
    model = convert_syncbn_model(model)
    device = torch.device('cuda:{}'.format(args.local_rank))
    print('local_rank:', args.local_rank)
    model.to(device)
    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )
    # Apex mixed-precision wrapping (opt_level e.g. "O0".."O3").
    opt_level = args.opt_level
    model, optimizer = amp.initialize(
        model, optimizer, opt_level=opt_level,
    )

    # Default resume state; `it` = -1 is the expected initial value for
    # `LambdaLR` and `BNMomentumScheduler`.
    it = -1
    best_loss = 1e10
    start_epoch = 1

    # Load status from checkpoint (path stored without its 8-char suffix).
    if args.checkpoint is not None:
        checkpoint_status = load_checkpoint(
            model, optimizer, filename=args.checkpoint[:-8]
        )
        if checkpoint_status is not None:
            it, start_epoch, best_loss = checkpoint_status
        if args.eval_net:
            # Evaluation without a loaded model is meaningless — fail early.
            # (Fixed message typo: "loadding" -> "loading".)
            assert checkpoint_status is not None, "Failed loading model."

    if not args.eval_net:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank,
            find_unused_parameters=True
        )
        clr_div = 6
        lr_scheduler = CyclicLR(
            optimizer, base_lr=1e-5, max_lr=1e-3,
            cycle_momentum=False,
            step_size_up=config.n_total_epoch * train_ds.minibatch_per_epoch // clr_div // args.gpus,
            step_size_down=config.n_total_epoch * train_ds.minibatch_per_epoch // clr_div // args.gpus,
            mode='triangular'
        )
    else:
        lr_scheduler = None

    # Exponentially decay BN momentum with the number of seen samples,
    # clipped from below at `bnm_clip`.
    bnm_lmbd = lambda it: max(
        args.bn_momentum * args.bn_decay ** (int(it * config.mini_batch_size / args.decay_step)),
        bnm_clip,
    )
    bnm_scheduler = pt_utils.BNMomentumScheduler(
        model, bn_lambda=bnm_lmbd, last_epoch=it
    )

    it = max(it, 0)  # expected initial value for `trainer.train`

    # Keep the loss modules on the compute device in BOTH branches.
    # (Previously the eval branch left them on CPU — inconsistent with the
    # train branch; `.to(device)` is a no-op for parameter-free modules.)
    model_fn = model_fn_decorator(
        FocalLoss(gamma=2).to(device),
        OFLoss().to(device),
        args.test,
    )

    checkpoint_fd = config.log_model_dir

    trainer = Trainer(
        model,
        model_fn,
        optimizer,
        checkpoint_name=os.path.join(checkpoint_fd, "FFB6D"),
        best_name=os.path.join(checkpoint_fd, "FFB6D_best"),
        lr_scheduler=lr_scheduler,
        bnm_scheduler=bnm_scheduler,
    )

    if args.eval_net:
        start = time.time()
        val_loss, res = trainer.eval_epoch(
            test_loader, is_test=True, test_pose=args.test_pose
        )
        end = time.time()
        print("\nUse time: ", end - start, 's')
    else:
        trainer.train(
            it, start_epoch, config.n_total_epoch, train_loader, None,
            val_loader, best_loss=best_loss,
            tot_iter=config.n_total_epoch * train_ds.minibatch_per_epoch // args.gpus,
            clr_div=clr_div
        )

        if start_epoch == config.n_total_epoch:
            _ = trainer.eval_epoch(val_loader)
def main(args):
    """Train the DSAE network: pose-delta regression + depth reconstruction.

    Builds the experiment directory under ``./log/dsae``, splits the dataset
    80/20 into train/validation, trains ``args.epoch`` epochs, and saves the
    best checkpoint whenever the combined validation error improves.
    """
    def log_string(msg):  # renamed from `str` — don't shadow the builtin
        logger.info(msg)
        print(msg)

    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    exp_dir = Path('./log/')
    exp_dir.mkdir(exist_ok=True)
    exp_dir = exp_dir.joinpath('dsae')
    exp_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        exp_dir = exp_dir.joinpath(timestr)
    else:
        exp_dir = exp_dir.joinpath(args.log_dir)
    exp_dir.mkdir(exist_ok=True)
    checkpoints_dir = exp_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = exp_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG'''
    # NOTE(review): this overwrites the `args` parameter with a fresh parse —
    # presumably intentional (CLI is the single source of truth); confirm.
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    '''DATA LOADING'''
    log_string('Load dataset ...')
    train_dataset, train_config = construct_dataset(is_train=True)
    # Random 80/20 train/validation split.
    train_set_size = int(len(train_dataset) * 0.8)
    valid_set_size = len(train_dataset) - train_set_size
    train_dataset, valid_dataset = torch.utils.data.random_split(
        train_dataset, [train_set_size, valid_set_size])
    trainDataLoader = DataLoader(dataset=train_dataset, batch_size=args.batch_size,
                                 shuffle=True, num_workers=4)
    validDataLoader = DataLoader(dataset=valid_dataset, batch_size=args.batch_size,
                                 shuffle=False, num_workers=4)

    '''MODEL LOADING'''
    out_channel = args.out_channel
    model = importlib.import_module(args.model)
    # Snapshot the sources used for this run into the experiment directory.
    shutil.copy('./models/%s.py' % args.model, str(exp_dir))
    shutil.copy('models/pointnet2_utils.py', str(exp_dir))
    shutil.copy('./train_dsae.py', str(exp_dir))

    network = model.get_model()
    criterion_rmse = RMSELoss()
    criterion_cos = torch.nn.CosineSimilarity(dim=1)
    criterion_bce = torch.nn.BCELoss()
    criterion_kptof = OFLoss()
    network.apply(inplace_relu)

    if not args.use_cpu:
        network = network.cuda()
        criterion_rmse = criterion_rmse.cuda()
        criterion_cos = criterion_cos.cuda()
        criterion_bce = criterion_bce.cuda()
        criterion_kptof = criterion_kptof.cuda()

    # Best-effort resume: a missing/incompatible checkpoint means fresh start.
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
    # propagate.
    # NOTE(review): the save path below is `best_model_e_<epoch>.pth`, so this
    # load of `best_model.pth` will never find a checkpoint saved by this
    # script — confirm which filename is intended.
    try:
        checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        network.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(network.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(network.parameters(), lr=0.01, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    global_epoch = 0
    global_step = 0
    best_rot_error = 99.9
    best_xyz_error = 99.9
    best_recon_error = 99.9

    '''TRANING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
        train_xyz_error = []
        train_rot_error = []
        train_recon_error = []
        network = network.train()

        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader), smoothing=0.9):
            optimizer.zero_grad()
            # rgbd (B,4,H,W); rgb_pair (B,6,H,W); depth_pair (B,2,H,W)
            rgbd = data[parameter.rgbd_image_key]
            rgb_pair = data[parameter.rgb_pair_key]
            depth_pair = data[parameter.depth_pair_key]
            delta_rot = data[parameter.delta_rot_key]
            delta_xyz = data[parameter.delta_xyz_key]
            if not args.use_cpu:
                delta_rot = delta_rot.cuda()
                delta_xyz = delta_xyz.cuda()
                rgbd = rgbd.cuda()
                rgb_pair = rgb_pair.cuda()
                depth_pair = depth_pair.cuda()
            # (Removed unused `rgb`/`depth` slices of `rgbd`.)

            delta_xyz_pred, delta_rot_pred, depth_pred = network(rgb_pair)

            # Translation loss: cosine direction term + RMSE magnitude term.
            loss_t = (1 - criterion_cos(delta_xyz_pred, delta_xyz)).mean() \
                + criterion_rmse(delta_xyz_pred, delta_xyz)
            loss_r = criterion_rmse(delta_rot_pred, delta_rot)
            loss_recon = criterion_rmse(depth_pred, depth_pair)
            loss = loss_t + loss_r + loss_recon

            loss.backward()
            optimizer.step()
            global_step += 1

            train_xyz_error.append(loss_t.item())
            train_rot_error.append(loss_r.item())
            train_recon_error.append(loss_recon.item())

        # Step the LR schedule AFTER the epoch's optimizer steps, per the
        # documented PyTorch ordering (stepping first skews the schedule).
        scheduler.step()

        train_xyz_error = sum(train_xyz_error) / len(train_xyz_error)
        train_rot_error = sum(train_rot_error) / len(train_rot_error)
        train_recon_error = sum(train_recon_error) / len(train_recon_error)
        log_string('Train Translation Error: %f' % train_xyz_error)
        log_string('Train Rotation Error: %f' % train_rot_error)
        log_string('Train Reconstruction Error: %f' % train_recon_error)

        with torch.no_grad():
            xyz_error, rot_error, recon_error = test(
                network.eval(), validDataLoader, out_channel,
                criterion_rmse, criterion_cos, criterion_bce, criterion_kptof)
            log_string(
                'Test Translation Error: %f, Rotation Error: %f, Reconstruction Error: %f'
                % (xyz_error, rot_error, recon_error))
            log_string(
                'Best Translation Error: %f, Rotation Error: %f, Reconstruction Error: %f'
                % (best_xyz_error, best_rot_error, best_recon_error))

            # "Best" = lowest combined validation error across all metrics.
            if (xyz_error + rot_error + recon_error) < (
                    best_xyz_error + best_rot_error + best_recon_error):
                best_xyz_error = xyz_error
                best_rot_error = rot_error
                best_recon_error = recon_error
                best_epoch = epoch + 1
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model_e_' + str(
                    best_epoch) + '.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'xyz_error': xyz_error,
                    'rot_error': rot_error,
                    'recon_error': recon_error,
                    'model_state_dict': network.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')
def main(args):
    """Train the PointNet2 keypoint network on point clouds.

    Builds the experiment directory under ``./log/kpts``, splits the dataset
    80/20 into train/validation, trains ``args.epoch`` epochs, and saves
    ``best_model.pth`` whenever the combined validation error improves.
    The active objective is keypoint-offset loss + confidence-mask loss;
    translation/rotation errors are computed for logging only.
    """
    def log_string(msg):  # renamed from `str` — don't shadow the builtin
        logger.info(msg)
        print(msg)

    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    exp_dir = Path('./log/')
    exp_dir.mkdir(exist_ok=True)
    exp_dir = exp_dir.joinpath('kpts')
    exp_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        exp_dir = exp_dir.joinpath(timestr)
    else:
        exp_dir = exp_dir.joinpath(args.log_dir)
    exp_dir.mkdir(exist_ok=True)
    checkpoints_dir = exp_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = exp_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG'''
    # NOTE(review): this overwrites the `args` parameter with a fresh parse —
    # presumably intentional (CLI is the single source of truth); confirm.
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    '''DATA LOADING'''
    log_string('Load dataset ...')
    train_dataset, train_config = construct_dataset(is_train=True)
    # Random 80/20 train/validation split.
    train_set_size = int(len(train_dataset) * 0.8)
    valid_set_size = len(train_dataset) - train_set_size
    train_dataset, valid_dataset = torch.utils.data.random_split(
        train_dataset, [train_set_size, valid_set_size])
    trainDataLoader = DataLoader(dataset=train_dataset, batch_size=args.batch_size,
                                 shuffle=True, num_workers=4)
    validDataLoader = DataLoader(dataset=valid_dataset, batch_size=args.batch_size,
                                 shuffle=False, num_workers=4)

    '''MODEL LOADING'''
    out_channel = args.out_channel
    model = importlib.import_module(args.model)
    # Snapshot the sources used for this run into the experiment directory.
    shutil.copy('./models/%s.py' % args.model, str(exp_dir))
    shutil.copy('models/pointnet2_utils.py', str(exp_dir))
    shutil.copy('./train_pointnet2_kpts.py', str(exp_dir))

    network = model.get_model()
    criterion_rmse = RMSELoss()
    criterion_cos = torch.nn.CosineSimilarity(dim=1)
    criterion_bce = torch.nn.BCELoss()
    criterion_kptof = OFLoss()
    network.apply(inplace_relu)

    if not args.use_cpu:
        network = network.cuda()
        criterion_rmse = criterion_rmse.cuda()
        criterion_cos = criterion_cos.cuda()
        criterion_bce = criterion_bce.cuda()
        criterion_kptof = criterion_kptof.cuda()

    # Best-effort resume: a missing/incompatible checkpoint means fresh start.
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
    # propagate.
    try:
        checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        network.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            network.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        optimizer = torch.optim.SGD(network.parameters(), lr=0.01, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)

    global_epoch = 0
    global_step = 0
    # (Removed unused best_heatmap_error / best_step_size_error — they were
    # only referenced by commented-out code.)
    best_rot_error = 99.9
    best_xyz_error = 99.9
    best_kptof_error = 99.9
    best_mask_error = 99.9

    '''TRANING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
        train_rot_error = []
        train_xyz_error = []
        train_kptof_error = []
        train_mask_error = []
        network = network.train()

        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader), smoothing=0.9):
            optimizer.zero_grad()
            points = data[parameter.pcd_key].numpy()
            # No geometric augmentation here: the network predicts metric 3D
            # keypoints, so scaling/shifting the cloud would corrupt targets.
            points = torch.Tensor(points)
            points = points.transpose(2, 1)  # (B, N, C) -> (B, C, N)
            heatmap_target = data[parameter.heatmap_key]
            delta_rot = data[parameter.delta_rot_key]
            delta_xyz = data[parameter.delta_xyz_key]
            kpt_of_gt = data[parameter.kpt_of_key]
            pcd_centroid = data[parameter.pcd_centroid_key]
            pcd_mean = data[parameter.pcd_mean_key]
            gripper_pose = data[parameter.gripper_pose_key]
            if not args.use_cpu:
                points = points.cuda()
                kpt_of_gt = kpt_of_gt.cuda()
                delta_rot = delta_rot.cuda()
                delta_xyz = delta_xyz.cuda()
                pcd_centroid = pcd_centroid.cuda()
                pcd_mean = pcd_mean.cuda()
                gripper_pose = gripper_pose.cuda()
                heatmap_target = heatmap_target.cuda()

            (kpt_of_pred, trans_of_pred, rot_of_pred, mean_kpt_pred,
             rot_mat_pred, confidence) = network(points)

            gripper_pos = gripper_pose[:, :3, 3]
            gripper_rot = gripper_pose[:, :3, :3]
            # De-normalize the predicted keypoint back to world coordinates.
            real_kpt_pred = (mean_kpt_pred * pcd_mean) + pcd_centroid
            real_kpt_pred = real_kpt_pred / 1000  # unit: mm to m
            real_trans_of_pred = (trans_of_pred * pcd_mean) / 1000  # unit: mm to m
            delta_trans_pred = real_kpt_pred - gripper_pos + real_trans_of_pred
            # Predicted rotation delta relative to the current gripper frame.
            delta_rot_pred = torch.bmm(torch.transpose(gripper_rot, 1, 2), rot_mat_pred)
            delta_rot_pred = torch.bmm(delta_rot_pred, rot_of_pred)

            # Loss computation. Only keypoint-offset + mask losses are
            # optimized; loss_t / loss_r are tracked for logging only.
            loss_kptof = criterion_kptof(kpt_of_pred, kpt_of_gt).sum()
            loss_t = (1 - criterion_cos(delta_trans_pred, delta_xyz)).mean() \
                + criterion_rmse(delta_trans_pred, delta_xyz)
            loss_r = criterion_rmse(delta_rot_pred, delta_rot)
            loss_mask = criterion_rmse(confidence, heatmap_target)
            loss = loss_kptof + loss_mask

            loss.backward()
            optimizer.step()
            global_step += 1

            train_kptof_error.append(loss_kptof.item())
            train_xyz_error.append(loss_t.item())
            train_rot_error.append(loss_r.item())
            train_mask_error.append(loss_mask.item())

        # Step the LR schedule AFTER the epoch's optimizer steps, per the
        # documented PyTorch ordering (stepping first skews the schedule).
        scheduler.step()

        train_kptof_error = sum(train_kptof_error) / len(train_kptof_error)
        train_xyz_error = sum(train_xyz_error) / len(train_xyz_error)
        train_rot_error = sum(train_rot_error) / len(train_rot_error)
        train_mask_error = sum(train_mask_error) / len(train_mask_error)
        log_string('Train Rotation Error: %f' % train_rot_error)
        log_string('Train Keypoint Offset Error: %f' % train_kptof_error)
        log_string('Train Translation Error: %f' % train_xyz_error)
        log_string('Train Mask Error: %f' % train_mask_error)

        with torch.no_grad():
            kptof_error, xyz_error, rot_error, mask_error = test(
                network.eval(), validDataLoader, out_channel,
                criterion_rmse, criterion_cos, criterion_bce, criterion_kptof)
            log_string('Test Keypoint offset Error: %f, Translation Error: %f, Rotation Error: %f, Mask Error: %f'
                       % (kptof_error, xyz_error, rot_error, mask_error))
            log_string('Best Keypoint offset Error: %f, Translation Error: %f, Rotation Error: %f, Mask Error: %f'
                       % (best_kptof_error, best_xyz_error, best_rot_error, best_mask_error))

            # "Best" = lowest combined validation error across all metrics.
            if (kptof_error + xyz_error + rot_error + mask_error) < (
                    best_kptof_error + best_xyz_error + best_rot_error + best_mask_error):
                best_kptof_error = kptof_error
                best_xyz_error = xyz_error
                best_rot_error = rot_error
                best_mask_error = mask_error
                best_epoch = epoch + 1
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'kptof_error': kptof_error,
                    'xyz_error': xyz_error,
                    'rot_error': rot_error,
                    'mask_error': mask_error,
                    'model_state_dict': network.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')