def main(args):
    """Train the reg_seg_heatmap_v3 network.

    From an input point cloud the model predicts a per-point heatmap, a
    6-DoF action (rotation as ortho-6D + translation direction) and a step
    size. After each epoch the model is evaluated on a held-out split and
    the checkpoint with the lowest summed validation error is saved.
    """

    def log_string(msg):
        # Mirror every message to both the run logfile and stdout.
        # (Parameter renamed from `str` to avoid shadowing the builtin.)
        logger.info(msg)
        print(msg)

    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    exp_dir = Path('./log/')
    exp_dir.mkdir(exist_ok=True)
    exp_dir = exp_dir.joinpath('reg_seg_heatmap_v3')
    exp_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        exp_dir = exp_dir.joinpath(timestr)
    else:
        exp_dir = exp_dir.joinpath(args.log_dir)
    exp_dir.mkdir(exist_ok=True)
    checkpoints_dir = exp_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = exp_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG'''
    # NOTE(review): re-parsing here clobbers the `args` object passed into
    # main() *after* args.log_dir was already consulted above — confirm
    # this is intended.
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    '''DATA LOADING'''
    log_string('Load dataset ...')
    # Construct the dataset
    train_dataset, train_config = construct_dataset(is_train=True)
    # Random split: 80% train / 20% validation
    train_set_size = int(len(train_dataset) * 0.8)
    valid_set_size = len(train_dataset) - train_set_size
    train_dataset, valid_dataset = torch.utils.data.random_split(
        train_dataset, [train_set_size, valid_set_size])
    # And the dataloader
    trainDataLoader = DataLoader(dataset=train_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=4)
    validDataLoader = DataLoader(dataset=valid_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=4)

    '''MODEL LOADING'''
    out_channel = args.out_channel
    model = importlib.import_module(args.model)
    # Snapshot model/training sources next to the logs for reproducibility.
    shutil.copy('./models/%s.py' % args.model, str(exp_dir))
    shutil.copy('models/pointnet2_utils.py', str(exp_dir))
    shutil.copy('./train_pointnet2_reg_seg_heatmap_stepsize.py', str(exp_dir))

    network = model.get_model(out_channel)
    criterion_rmse = RMSELoss()
    criterion_cos = torch.nn.CosineSimilarity(dim=1)
    criterion_bce = torch.nn.BCELoss()
    network.apply(inplace_relu)

    if not args.use_cpu:
        network = network.cuda()
        criterion_rmse = criterion_rmse.cuda()
        criterion_cos = criterion_cos.cuda()
        # Moved to CUDA as well for consistency with the other criteria
        # (BCELoss is parameter-free, so this is a no-op safety measure).
        criterion_bce = criterion_bce.cuda()

    # Resume from the best checkpoint if one exists; otherwise start fresh.
    try:
        checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        network.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except Exception:
        # No usable checkpoint (missing file, bad keys, shape mismatch, ...).
        # Narrowed from a bare `except:` so KeyboardInterrupt still propagates.
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(network.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(network.parameters(), lr=0.01, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)

    global_epoch = 0
    global_step = 0
    best_rot_error = 99.9
    best_xyz_error = 99.9
    best_heatmap_error = 99.9
    best_step_size_error = 99.9

    '''TRANING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))
        train_rot_error = []
        train_xyz_error = []
        train_heatmap_error = []
        train_step_size_error = []
        network = network.train()

        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader), smoothing=0.9):
            optimizer.zero_grad()
            # Point-cloud augmentation on the numpy side, then back to torch
            # with channels-first layout for the network.
            points = data[parameter.pcd_key].numpy()
            points = provider.normalize_data(points)
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            points = points.transpose(2, 1)

            heatmap_target = data[parameter.heatmap_key]
            # NOTE(review): segmentation_target is fetched but never used in
            # this loop — confirm whether a segmentation loss is missing.
            segmentation_target = data[parameter.segmentation_key]
            delta_rot = data[parameter.delta_rot_key]
            delta_xyz = data[parameter.delta_xyz_key]
            unit_delta_xyz = data[parameter.unit_delta_xyz_key]
            step_size = data[parameter.step_size_key]
            if not args.use_cpu:
                points = points.cuda()
                delta_rot = delta_rot.cuda()
                delta_xyz = delta_xyz.cuda()
                heatmap_target = heatmap_target.cuda()
                unit_delta_xyz = unit_delta_xyz.cuda()
                step_size = step_size.cuda()

            heatmap_pred, action_pred, step_size_pred = network(points)
            # Action head layout: first 6 values encode rotation (ortho-6D),
            # next 3 the translation direction.
            delta_rot_pred_6d = action_pred[:, 0:6]
            delta_rot_pred = compute_rotation_matrix_from_ortho6d(
                delta_rot_pred_6d, args.use_cpu)  # batch*3*3
            delta_xyz_pred = action_pred[:, 6:9].view(-1, 3)  # batch*3

            # loss computation
            loss_heatmap = criterion_rmse(heatmap_pred, heatmap_target)
            loss_r = criterion_rmse(delta_rot_pred, delta_rot)
            # Translation is supervised by direction only: cosine distance
            # against the unit displacement vector.
            loss_t = (1 - criterion_cos(delta_xyz_pred, unit_delta_xyz)).mean()
            loss_step_size = criterion_bce(step_size_pred, step_size)
            loss = loss_r + loss_t + loss_heatmap + loss_step_size
            loss.backward()
            optimizer.step()
            global_step += 1

            train_rot_error.append(loss_r.item())
            train_xyz_error.append(loss_t.item())
            train_heatmap_error.append(loss_heatmap.item())
            train_step_size_error.append(loss_step_size.item())

        # Step the LR schedule once per epoch AFTER the optimizer updates.
        # (It previously ran at the top of the epoch, which skews the decay
        # schedule by one epoch and triggers a PyTorch ordering warning.)
        scheduler.step()

        train_rot_error = sum(train_rot_error) / len(train_rot_error)
        train_xyz_error = sum(train_xyz_error) / len(train_xyz_error)
        train_heatmap_error = sum(train_heatmap_error) / len(train_heatmap_error)
        train_step_size_error = sum(train_step_size_error) / len(train_step_size_error)
        log_string('Train Rotation Error: %f' % train_rot_error)
        log_string('Train Translation Error: %f' % train_xyz_error)
        # BUGFIX: this line previously printed train_xyz_error.
        log_string('Train Heatmap Error: %f' % train_heatmap_error)
        log_string('Train Step size Error: %f' % train_step_size_error)

        with torch.no_grad():
            rot_error, xyz_error, heatmap_error, step_size_error = test(
                network.eval(), validDataLoader, out_channel, criterion_rmse,
                criterion_cos, criterion_bce)
            log_string(
                'Test Rotation Error: %f, Translation Error: %f, Heatmap Error: %f, Step size Error: %f'
                % (rot_error, xyz_error, heatmap_error, step_size_error))
            log_string(
                'Best Rotation Error: %f, Translation Error: %f, Heatmap Error: %f, Step size Error: %f'
                % (best_rot_error, best_xyz_error, best_heatmap_error,
                   best_step_size_error))
            # Keep the checkpoint with the lowest summed validation error.
            if (rot_error + xyz_error + heatmap_error + step_size_error) < (
                    best_rot_error + best_xyz_error + best_heatmap_error +
                    best_step_size_error):
                best_rot_error = rot_error
                best_xyz_error = xyz_error
                best_heatmap_error = heatmap_error
                best_step_size_error = step_size_error
                best_epoch = epoch + 1
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'rot_error': rot_error,
                    'xyz_error': xyz_error,
                    'heatmap_error': heatmap_error,
                    'step_size_error': step_size_error,
                    'model_state_dict': network.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
        global_epoch += 1
    logger.info('End of training...')
def main(args):
    """Train the DSAE (deep spatial autoencoder) network.

    From an RGB image pair the model predicts a translation delta, a
    rotation delta, and a reconstructed depth pair. After each epoch the
    model is validated and the best checkpoint (lowest summed validation
    error) is saved with the epoch number embedded in the filename.
    """

    def log_string(msg):
        # Mirror every message to both the run logfile and stdout.
        # (Parameter renamed from `str` to avoid shadowing the builtin.)
        logger.info(msg)
        print(msg)

    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    exp_dir = Path('./log/')
    exp_dir.mkdir(exist_ok=True)
    exp_dir = exp_dir.joinpath('dsae')
    exp_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        exp_dir = exp_dir.joinpath(timestr)
    else:
        exp_dir = exp_dir.joinpath(args.log_dir)
    exp_dir.mkdir(exist_ok=True)
    checkpoints_dir = exp_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = exp_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG'''
    # NOTE(review): re-parsing here clobbers the `args` object passed into
    # main() *after* args.log_dir was already consulted above — confirm
    # this is intended.
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    '''DATA LOADING'''
    log_string('Load dataset ...')
    # Construct the dataset
    train_dataset, train_config = construct_dataset(is_train=True)
    # Random split: 80% train / 20% validation
    train_set_size = int(len(train_dataset) * 0.8)
    valid_set_size = len(train_dataset) - train_set_size
    train_dataset, valid_dataset = torch.utils.data.random_split(
        train_dataset, [train_set_size, valid_set_size])
    # And the dataloader
    trainDataLoader = DataLoader(dataset=train_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=4)
    validDataLoader = DataLoader(dataset=valid_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=4)

    '''MODEL LOADING'''
    out_channel = args.out_channel
    model = importlib.import_module(args.model)
    # Snapshot model/training sources next to the logs for reproducibility.
    shutil.copy('./models/%s.py' % args.model, str(exp_dir))
    shutil.copy('models/pointnet2_utils.py', str(exp_dir))
    shutil.copy('./train_dsae.py', str(exp_dir))

    network = model.get_model()
    criterion_rmse = RMSELoss()
    criterion_cos = torch.nn.CosineSimilarity(dim=1)
    criterion_bce = torch.nn.BCELoss()
    criterion_kptof = OFLoss()
    network.apply(inplace_relu)

    if not args.use_cpu:
        network = network.cuda()
        criterion_rmse = criterion_rmse.cuda()
        criterion_cos = criterion_cos.cuda()
        criterion_bce = criterion_bce.cuda()
        criterion_kptof = criterion_kptof.cuda()

    # Resume from the best checkpoint if one exists; otherwise start fresh.
    try:
        checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        network.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except Exception:
        # No usable checkpoint (missing file, bad keys, shape mismatch, ...).
        # Narrowed from a bare `except:` so KeyboardInterrupt still propagates.
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(network.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(network.parameters(), lr=0.01, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    global_epoch = 0
    global_step = 0
    best_rot_error = 99.9
    best_xyz_error = 99.9
    best_recon_error = 99.9

    '''TRANING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))
        train_xyz_error = []
        train_rot_error = []
        train_recon_error = []
        network = network.train()

        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader), smoothing=0.9):
            optimizer.zero_grad()
            # NOTE(review): rgbd is fetched and moved to CUDA but only the
            # image pair feeds the network; the single-frame rgb/depth
            # slices present in an earlier revision were unused and have
            # been dropped. Expected layouts (per the original comments):
            #   rgb_pair (B,6,H,W), depth_pair (B,2,H,W) — TODO confirm.
            rgbd = data[parameter.rgbd_image_key]
            rgb_pair = data[parameter.rgb_pair_key]
            depth_pair = data[parameter.depth_pair_key]
            delta_rot = data[parameter.delta_rot_key]
            delta_xyz = data[parameter.delta_xyz_key]
            if not args.use_cpu:
                delta_rot = delta_rot.cuda()
                delta_xyz = delta_xyz.cuda()
                rgbd = rgbd.cuda()
                rgb_pair = rgb_pair.cuda()
                depth_pair = depth_pair.cuda()

            delta_xyz_pred, delta_rot_pred, depth_pred = network(rgb_pair)

            # loss computation: translation combines cosine-direction and
            # RMSE terms; depth reconstruction is RMSE against the pair.
            loss_t = (1 - criterion_cos(delta_xyz_pred, delta_xyz)).mean() + \
                criterion_rmse(delta_xyz_pred, delta_xyz)
            loss_r = criterion_rmse(delta_rot_pred, delta_rot)
            loss_recon = criterion_rmse(depth_pred, depth_pair)
            loss = loss_t + loss_r + loss_recon
            loss.backward()
            optimizer.step()
            global_step += 1

            train_xyz_error.append(loss_t.item())
            train_rot_error.append(loss_r.item())
            train_recon_error.append(loss_recon.item())

        # Step the LR schedule once per epoch AFTER the optimizer updates.
        # (It previously ran at the top of the epoch, which skews the decay
        # schedule by one epoch and triggers a PyTorch ordering warning.)
        scheduler.step()

        train_xyz_error = sum(train_xyz_error) / len(train_xyz_error)
        train_rot_error = sum(train_rot_error) / len(train_rot_error)
        train_recon_error = sum(train_recon_error) / len(train_recon_error)
        log_string('Train Translation Error: %f' % train_xyz_error)
        log_string('Train Rotation Error: %f' % train_rot_error)
        log_string('Train Reconstruction Error: %f' % train_recon_error)

        with torch.no_grad():
            xyz_error, rot_error, recon_error = test(
                network.eval(), validDataLoader, out_channel, criterion_rmse,
                criterion_cos, criterion_bce, criterion_kptof)
            log_string(
                'Test Translation Error: %f, Rotation Error: %f, Reconstruction Error: %f'
                % (xyz_error, rot_error, recon_error))
            log_string(
                'Best Translation Error: %f, Rotation Error: %f, Reconstruction Error: %f'
                % (best_xyz_error, best_rot_error, best_recon_error))
            # Keep the checkpoint with the lowest summed validation error.
            if (xyz_error + rot_error + recon_error) < (
                    best_xyz_error + best_rot_error + best_recon_error):
                best_xyz_error = xyz_error
                best_rot_error = rot_error
                best_recon_error = recon_error
                best_epoch = epoch + 1
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model_e_' + str(
                    best_epoch) + '.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'xyz_error': xyz_error,
                    'rot_error': rot_error,
                    'recon_error': recon_error,
                    'model_state_dict': network.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
        global_epoch += 1
    logger.info('End of training...')