# Example 1 (scraped snippet marker; original text: "Esempio n. 1 / 0")
def train():
    """Entry point for distributed training / evaluation of FFB6D.

    Reads configuration from the module-level ``args`` and ``config``
    objects.  One process per GPU is expected (``torch.distributed`` with
    the NCCL backend); ``args.local_rank`` selects this process's device.
    When ``args.eval_net`` is set, the model is only evaluated on the
    test split instead of being trained.
    """
    print("local_rank:", args.local_rank)
    cudnn.benchmark = True
    if args.deterministic:
        # Trade speed for reproducibility: disable cudnn autotuning and
        # seed per-rank.
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='env://',
    )
    # NOTE(review): this re-seeds with 0 unconditionally, overriding the
    # per-rank seed set above in deterministic mode — confirm intended.
    torch.manual_seed(0)

    if not args.eval_net:
        # Training: distributed samplers shard the data across ranks.
        train_ds = dataset_desc.Dataset('train')
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds)
        train_loader = torch.utils.data.DataLoader(
            train_ds, batch_size=config.mini_batch_size, shuffle=False,
            drop_last=True, num_workers=4, sampler=train_sampler, pin_memory=True
        )

        val_ds = dataset_desc.Dataset('test')
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_ds)
        val_loader = torch.utils.data.DataLoader(
            val_ds, batch_size=config.val_mini_batch_size, shuffle=False,
            drop_last=False, num_workers=4, sampler=val_sampler
        )
    else:
        # Evaluation: plain single-process DataLoader, no sampler.
        test_ds = dataset_desc.Dataset('test')
        test_loader = torch.utils.data.DataLoader(
            test_ds, batch_size=config.test_mini_batch_size, shuffle=False,
            num_workers=20
        )

    rndla_cfg = ConfigRandLA
    model = FFB6D(
        n_classes=config.n_objects, n_pts=config.n_sample_points, rndla_cfg=rndla_cfg,
        n_kps=config.n_keypoints
    )
    # Convert BatchNorm layers to synchronized BN for multi-GPU training.
    model = convert_syncbn_model(model)
    device = torch.device('cuda:{}'.format(args.local_rank))
    print('local_rank:', args.local_rank)
    model.to(device)
    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )
    # Mixed-precision setup (NVIDIA apex amp); opt_level e.g. 'O0'..'O3'.
    opt_level = args.opt_level
    model, optimizer = amp.initialize(
        model, optimizer, opt_level=opt_level,
    )

    # default value
    it = -1  # for the initialize value of `LambdaLR` and `BNMomentumScheduler`
    best_loss = 1e10
    start_epoch = 1

    # load status from checkpoint
    if args.checkpoint is not None:
        # NOTE(review): `[:-8]` strips an 8-char suffix (presumably
        # '.pth.tar') from the path — confirm against `load_checkpoint`.
        checkpoint_status = load_checkpoint(
            model, optimizer, filename=args.checkpoint[:-8]
        )
        if checkpoint_status is not None:
            it, start_epoch, best_loss = checkpoint_status
        if args.eval_net:
            assert checkpoint_status is not None, "Failed loadding model."

    if not args.eval_net:
        # Wrap in DDP *after* amp init and checkpoint loading so the state
        # dict keys match the un-wrapped model.
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank,
            find_unused_parameters=True
        )
        clr_div = 6
        # Cyclic LR: each half-cycle spans (total iters per rank) / clr_div.
        lr_scheduler = CyclicLR(
            optimizer, base_lr=1e-5, max_lr=1e-3,
            cycle_momentum=False,
            step_size_up=config.n_total_epoch * train_ds.minibatch_per_epoch // clr_div // args.gpus,
            step_size_down=config.n_total_epoch * train_ds.minibatch_per_epoch // clr_div // args.gpus,
            mode='triangular'
        )
    else:
        lr_scheduler = None

    # BN momentum decays exponentially with iterations, clipped at bnm_clip.
    bnm_lmbd = lambda it: max(
        args.bn_momentum * args.bn_decay ** (int(it * config.mini_batch_size / args.decay_step)),
        bnm_clip,
    )
    bnm_scheduler = pt_utils.BNMomentumScheduler(
        model, bn_lambda=bnm_lmbd, last_epoch=it
    )

    it = max(it, 0)  # for the initialize value of `trainer.train`

    if args.eval_net:
        model_fn = model_fn_decorator(
            FocalLoss(gamma=2), OFLoss(),
            args.test,
        )
    else:
        # Losses moved to this rank's device for training.
        model_fn = model_fn_decorator(
            FocalLoss(gamma=2).to(device), OFLoss().to(device),
            args.test,
        )

    checkpoint_fd = config.log_model_dir

    trainer = Trainer(
        model,
        model_fn,
        optimizer,
        checkpoint_name=os.path.join(checkpoint_fd, "FFB6D"),
        best_name=os.path.join(checkpoint_fd, "FFB6D_best"),
        lr_scheduler=lr_scheduler,
        bnm_scheduler=bnm_scheduler,
    )

    if args.eval_net:
        start = time.time()
        val_loss, res = trainer.eval_epoch(
            test_loader, is_test=True, test_pose=args.test_pose
        )
        end = time.time()
        print("\nUse time: ", end - start, 's')
    else:
        trainer.train(
            it, start_epoch, config.n_total_epoch, train_loader, None,
            val_loader, best_loss=best_loss,
            tot_iter=config.n_total_epoch * train_ds.minibatch_per_epoch // args.gpus,
            clr_div=clr_div
        )

        if start_epoch == config.n_total_epoch:
            # Resumed already at the final epoch: run one validation pass.
            _ = trainer.eval_epoch(val_loader)
def main(args):
    """Train the DSAE network on RGB-D image pairs.

    Creates the ``./log/dsae/<timestamp-or-name>/`` experiment tree, loads
    the dataset with an 80/20 train/validation random split, optionally
    resumes from ``<exp_dir>/checkpoints/best_model.pth``, and runs the
    training loop, saving a new checkpoint whenever the summed validation
    errors (translation + rotation + reconstruction) improve.

    Relies on module-level helpers: ``parse_args``, ``construct_dataset``,
    ``parameter`` (dataset keys), ``RMSELoss``, ``OFLoss``,
    ``inplace_relu`` and ``test``.
    """
    def log_string(msg):
        # Log to the experiment file and echo to stdout.
        logger.info(msg)
        print(msg)

    # --- create experiment directories: ./log/dsae/<timestamp-or-name>/ ---
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    exp_dir = Path('./log/')
    exp_dir.mkdir(exist_ok=True)
    exp_dir = exp_dir.joinpath('dsae')
    exp_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        exp_dir = exp_dir.joinpath(timestr)
    else:
        exp_dir = exp_dir.joinpath(args.log_dir)
    exp_dir.mkdir(exist_ok=True)
    checkpoints_dir = exp_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = exp_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    # --- logging ---
    # NOTE(review): this re-parses CLI args, discarding the `args` that was
    # passed in (already used for `log_dir` above) — confirm intended.
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    # --- data loading: 80/20 random train/validation split ---
    log_string('Load dataset ...')
    train_dataset, train_config = construct_dataset(is_train=True)
    train_set_size = int(len(train_dataset) * 0.8)
    valid_set_size = len(train_dataset) - train_set_size
    train_dataset, valid_dataset = torch.utils.data.random_split(
        train_dataset, [train_set_size, valid_set_size])
    trainDataLoader = DataLoader(dataset=train_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=4)
    validDataLoader = DataLoader(dataset=valid_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=4)

    # --- model loading; snapshot source files into the experiment dir ---
    out_channel = args.out_channel
    model = importlib.import_module(args.model)
    shutil.copy('./models/%s.py' % args.model, str(exp_dir))
    shutil.copy('models/pointnet2_utils.py', str(exp_dir))
    shutil.copy('./train_dsae.py', str(exp_dir))

    network = model.get_model()
    criterion_rmse = RMSELoss()
    criterion_cos = torch.nn.CosineSimilarity(dim=1)
    criterion_bce = torch.nn.BCELoss()
    criterion_kptof = OFLoss()
    network.apply(inplace_relu)

    if not args.use_cpu:
        network = network.cuda()
        criterion_rmse = criterion_rmse.cuda()
        criterion_cos = criterion_cos.cuda()
        criterion_bce = criterion_bce.cuda()
        criterion_kptof = criterion_kptof.cuda()

    # Best-effort resume.  `except Exception` (instead of a bare `except`)
    # keeps the fallback behavior but no longer swallows
    # KeyboardInterrupt/SystemExit.
    # NOTE(review): checkpoints below are saved as 'best_model_e_<n>.pth',
    # so this resume path never matches them — confirm whether resume is
    # expected to work.
    try:
        checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        network.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(network.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(network.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    # Halve the learning rate every 10 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=10,
                                                gamma=0.5)
    global_epoch = 0
    global_step = 0
    best_rot_error = 99.9
    best_xyz_error = 99.9
    best_recon_error = 99.9

    # --- training ---
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))
        train_xyz_error = []
        train_rot_error = []
        train_recon_error = []
        network = network.train()

        # NOTE(review): stepping the scheduler at the top of the epoch
        # (before any optimizer.step) skips the initial LR; PyTorch
        # recommends calling scheduler.step() after the epoch's updates.
        scheduler.step()
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            optimizer.zero_grad()
            rgbd = data[parameter.rgbd_image_key]
            rgb_pair = data[parameter.rgb_pair_key]
            depth_pair = data[parameter.depth_pair_key]
            delta_rot = data[parameter.delta_rot_key]
            delta_xyz = data[parameter.delta_xyz_key]

            if not args.use_cpu:
                delta_rot = delta_rot.cuda()
                delta_xyz = delta_xyz.cuda()
                rgbd = rgbd.cuda()
                rgb_pair = rgb_pair.cuda()
                depth_pair = depth_pair.cuda()
            # rgb_pair (B,6,H,W), depth_pair (B,2,H,W); the network consumes
            # the RGB pair and reconstructs the depth pair.
            delta_xyz_pred, delta_rot_pred, depth_pred = network(rgb_pair)
            # Translation loss: cosine direction term + RMSE magnitude term.
            loss_t = (1 - criterion_cos(delta_xyz_pred, delta_xyz)
                      ).mean() + criterion_rmse(delta_xyz_pred, delta_xyz)
            loss_r = criterion_rmse(delta_rot_pred, delta_rot)
            loss_recon = criterion_rmse(depth_pred, depth_pair)
            loss = loss_t + loss_r + loss_recon
            loss.backward()
            optimizer.step()
            global_step += 1

            train_xyz_error.append(loss_t.item())
            train_rot_error.append(loss_r.item())
            train_recon_error.append(loss_recon.item())

        # Epoch-mean training errors.
        train_xyz_error = sum(train_xyz_error) / len(train_xyz_error)
        train_rot_error = sum(train_rot_error) / len(train_rot_error)
        train_recon_error = sum(train_recon_error) / len(train_recon_error)

        log_string('Train Translation Error: %f' % train_xyz_error)
        log_string('Train Rotation Error: %f' % train_rot_error)
        log_string('Train Reconstruction Error: %f' % train_recon_error)
        with torch.no_grad():
            xyz_error, rot_error, recon_error = test(
                network.eval(), validDataLoader, out_channel, criterion_rmse,
                criterion_cos, criterion_bce, criterion_kptof)

            log_string(
                'Test Translation Error: %f, Rotation Error: %f, Reconstruction Error: %f'
                % (xyz_error, rot_error, recon_error))
            log_string(
                'Best Translation Error: %f, Rotation Error: %f, Reconstruction Error: %f'
                % (best_xyz_error, best_rot_error, best_recon_error))

            # Save whenever the summed validation errors improve.
            if (xyz_error + rot_error + recon_error) < (
                    best_xyz_error + best_rot_error + best_recon_error):
                best_xyz_error = xyz_error
                best_rot_error = rot_error
                best_recon_error = recon_error
                best_epoch = epoch + 1
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model_e_' + str(
                    best_epoch) + '.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'xyz_error': xyz_error,
                    'rot_error': rot_error,
                    'recon_error': recon_error,
                    'model_state_dict': network.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')
# Example 3 (scraped snippet marker; original text: "Esempio n. 3 / 0")
def main(args):
    """Train the PointNet2-based keypoint-offset network on point clouds.

    Creates the ``./log/kpts/<timestamp-or-name>/`` experiment tree, loads
    the dataset with an 80/20 train/validation random split, optionally
    resumes from ``<exp_dir>/checkpoints/best_model.pth``, and runs the
    training loop, saving ``best_model.pth`` whenever the summed validation
    errors (keypoint offset + translation + rotation + mask) improve.

    Relies on module-level helpers: ``parse_args``, ``construct_dataset``,
    ``parameter`` (dataset keys), ``RMSELoss``, ``OFLoss``,
    ``inplace_relu`` and ``test``.
    """
    def log_string(msg):
        # Log to the experiment file and echo to stdout.
        logger.info(msg)
        print(msg)

    # --- create experiment directories: ./log/kpts/<timestamp-or-name>/ ---
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    exp_dir = Path('./log/')
    exp_dir.mkdir(exist_ok=True)
    exp_dir = exp_dir.joinpath('kpts')
    exp_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        exp_dir = exp_dir.joinpath(timestr)
    else:
        exp_dir = exp_dir.joinpath(args.log_dir)
    exp_dir.mkdir(exist_ok=True)
    checkpoints_dir = exp_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = exp_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    # --- logging ---
    # NOTE(review): this re-parses CLI args, discarding the `args` that was
    # passed in (already used for `log_dir` above) — confirm intended.
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    # --- data loading: 80/20 random train/validation split ---
    log_string('Load dataset ...')
    train_dataset, train_config = construct_dataset(is_train=True)
    train_set_size = int(len(train_dataset) * 0.8)
    valid_set_size = len(train_dataset) - train_set_size
    train_dataset, valid_dataset = torch.utils.data.random_split(train_dataset, [train_set_size, valid_set_size])
    trainDataLoader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
    validDataLoader = DataLoader(dataset=valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # --- model loading; snapshot source files into the experiment dir ---
    out_channel = args.out_channel
    model = importlib.import_module(args.model)
    shutil.copy('./models/%s.py' % args.model, str(exp_dir))
    shutil.copy('models/pointnet2_utils.py', str(exp_dir))
    shutil.copy('./train_pointnet2_kpts.py', str(exp_dir))

    network = model.get_model()
    criterion_rmse = RMSELoss()
    criterion_cos = torch.nn.CosineSimilarity(dim=1)
    criterion_bce = torch.nn.BCELoss()
    criterion_kptof = OFLoss()
    network.apply(inplace_relu)

    if not args.use_cpu:
        network = network.cuda()
        criterion_rmse = criterion_rmse.cuda()
        criterion_cos = criterion_cos.cuda()
        criterion_bce = criterion_bce.cuda()
        criterion_kptof = criterion_kptof.cuda()

    # Best-effort resume.  `except Exception` (instead of a bare `except`)
    # keeps the fallback behavior but no longer swallows
    # KeyboardInterrupt/SystemExit.
    try:
        checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        network.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            network.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        optimizer = torch.optim.SGD(network.parameters(), lr=0.01, momentum=0.9)

    # Decay the learning rate by 0.7 every 20 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
    global_epoch = 0
    global_step = 0
    best_rot_error = 99.9
    best_xyz_error = 99.9
    best_kptof_error = 99.9
    best_mask_error = 99.9

    # --- training ---
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
        train_rot_error = []
        train_xyz_error = []
        train_kptof_error = []
        train_mask_error = []
        network = network.train()

        # NOTE(review): stepping the scheduler at the top of the epoch
        # (before any optimizer.step) skips the initial LR; PyTorch
        # recommends calling scheduler.step() after the epoch's updates.
        scheduler.step()
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9):
            optimizer.zero_grad()
            # 3-D keypoints must stay metrically meaningful, so no
            # point-cloud augmentation is applied here.
            points = data[parameter.pcd_key].numpy()
            points = torch.Tensor(points)
            points = points.transpose(2, 1)
            heatmap_target = data[parameter.heatmap_key]
            delta_rot = data[parameter.delta_rot_key]
            delta_xyz = data[parameter.delta_xyz_key]
            kpt_of_gt = data[parameter.kpt_of_key]
            pcd_centroid = data[parameter.pcd_centroid_key]
            pcd_mean = data[parameter.pcd_mean_key]
            gripper_pose = data[parameter.gripper_pose_key]

            if not args.use_cpu:
                points = points.cuda()
                kpt_of_gt = kpt_of_gt.cuda()
                delta_rot = delta_rot.cuda()
                delta_xyz = delta_xyz.cuda()
                pcd_centroid = pcd_centroid.cuda()
                pcd_mean = pcd_mean.cuda()
                gripper_pose = gripper_pose.cuda()
                heatmap_target = heatmap_target.cuda()
            kpt_of_pred, trans_of_pred, rot_of_pred, mean_kpt_pred, rot_mat_pred, confidence = network(points)
            gripper_pos = gripper_pose[:, :3, 3]
            gripper_rot = gripper_pose[:, :3, :3]
            # Undo the point-cloud normalization, then convert mm -> m.
            real_kpt_pred = (mean_kpt_pred * pcd_mean) + pcd_centroid
            real_kpt_pred = real_kpt_pred / 1000  # unit: mm to m
            real_trans_of_pred = (trans_of_pred * pcd_mean) / 1000  # unit: mm to m
            delta_trans_pred = real_kpt_pred - gripper_pos + real_trans_of_pred
            # Predicted rotation delta relative to the gripper frame.
            delta_rot_pred = torch.bmm(torch.transpose(gripper_rot, 1, 2), rot_mat_pred)
            delta_rot_pred = torch.bmm(delta_rot_pred, rot_of_pred)
            # Loss computation.
            loss_kptof = criterion_kptof(kpt_of_pred, kpt_of_gt).sum()
            # Translation loss: cosine direction term + RMSE magnitude term.
            loss_t = (1 - criterion_cos(delta_trans_pred, delta_xyz)).mean() + criterion_rmse(delta_trans_pred, delta_xyz)
            loss_r = criterion_rmse(delta_rot_pred, delta_rot)
            loss_mask = criterion_rmse(confidence, heatmap_target)
            # NOTE(review): only the keypoint-offset and mask losses are
            # back-propagated; loss_t / loss_r are tracked for logging only.
            loss = loss_kptof + loss_mask
            loss.backward()
            optimizer.step()
            global_step += 1

            train_kptof_error.append(loss_kptof.item())
            train_xyz_error.append(loss_t.item())
            train_rot_error.append(loss_r.item())
            train_mask_error.append(loss_mask.item())

        # Epoch-mean training errors.
        train_kptof_error = sum(train_kptof_error) / len(train_kptof_error)
        train_xyz_error = sum(train_xyz_error) / len(train_xyz_error)
        train_rot_error = sum(train_rot_error) / len(train_rot_error)
        train_mask_error = sum(train_mask_error) / len(train_mask_error)
        log_string('Train Rotation Error: %f' % train_rot_error)
        log_string('Train Keypoint Offset Error: %f' % train_kptof_error)
        log_string('Train Translation Error: %f' % train_xyz_error)
        log_string('Train Mask Error: %f' % train_mask_error)
        with torch.no_grad():
            kptof_error, xyz_error, rot_error, mask_error = test(network.eval(), validDataLoader, out_channel, criterion_rmse, criterion_cos, criterion_bce, criterion_kptof)

            log_string('Test Keypoint offset Error: %f, Translation Error: %f, Rotation Error: %f, Mask Error: %f' % (kptof_error, xyz_error, rot_error, mask_error))
            log_string('Best Keypoint offset Error: %f, Translation Error: %f, Rotation Error: %f, Mask Error: %f' % (best_kptof_error, best_xyz_error, best_rot_error, best_mask_error))
            # Save whenever the summed validation errors improve.
            if (kptof_error + xyz_error + rot_error + mask_error) < (
                    best_kptof_error + best_xyz_error + best_rot_error + best_mask_error):
                best_kptof_error = kptof_error
                best_xyz_error = xyz_error
                best_rot_error = rot_error
                best_mask_error = mask_error
                best_epoch = epoch + 1
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'kptof_error': kptof_error,
                    'xyz_error': xyz_error,
                    'rot_error': rot_error,
                    'mask_error': mask_error,
                    'model_state_dict': network.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')