Example #1
    def evaluate_pose_arp_2d(self, config, all_poses_est, all_poses_gt,
                             output_dir, logger):
        '''
        Evaluate the average re-projection 2D (ARP 2D) error.
        :param config: experiment config (provides TEST.test_iter and the camera intrinsics)
        :param all_poses_est: estimated poses, indexed as [class][iteration][sample]
        :param all_poses_gt: ground-truth poses, indexed as [class][0][sample]
        :param output_dir: directory the accuracy-vs-threshold plots are written to
        :param logger: logger used by print_and_log
        :return: None; results are logged and plots/pickles are written to output_dir
        '''
        print_and_log('evaluating pose average re-projection 2d error', logger)
        num_iter = config.TEST.test_iter
        K = config.dataset.INTRINSIC_MATRIX

        count_all = np.zeros((self.num_classes, ), dtype=np.float32)
        count_correct = {
            k: np.zeros((self.num_classes, num_iter), dtype=np.float32)
            for k in ['2', '5', '10', '20']
        }

        threshold_2 = np.zeros((self.num_classes, num_iter), dtype=np.float32)
        threshold_5 = np.zeros((self.num_classes, num_iter), dtype=np.float32)
        threshold_10 = np.zeros((self.num_classes, num_iter), dtype=np.float32)
        threshold_20 = np.zeros((self.num_classes, num_iter), dtype=np.float32)
        dx = 0.1
        threshold_mean = np.tile(
            np.arange(0, 50, dx).astype(np.float32),
            (self.num_classes, num_iter,
             1))  # (num_class, num_iter, num_thresh)
        num_thresh = threshold_mean.shape[-1]
        count_correct['mean'] = np.zeros(
            (self.num_classes, num_iter, num_thresh), dtype=np.float32)

        for i in range(self.num_classes):
            threshold_2[i, :] = 2
            threshold_5[i, :] = 5
            threshold_10[i, :] = 10
            threshold_20[i, :] = 20

        num_valid_class = 0
        for cls_idx, cls_name in enumerate(self.classes):
            if not (all_poses_est[cls_idx][0] and all_poses_gt[cls_idx][0]):
                continue
            num_valid_class += 1
            for iter_i in range(num_iter):
                curr_poses_gt = all_poses_gt[cls_idx][0]
                num = len(curr_poses_gt)
                curr_poses_est = all_poses_est[cls_idx][iter_i]

                for j in range(num):
                    if iter_i == 0:
                        count_all[cls_idx] += 1
                    RT = curr_poses_est[j]  # est pose
                    pose_gt = curr_poses_gt[j]  # gt pose

                    error_rotation = re(RT[:3, :3], pose_gt[:3, :3])
                    if cls_name == 'eggbox' and error_rotation > 90:
                        RT_z = np.array([[-1, 0, 0, 0], [0, -1, 0, 0],
                                         [0, 0, 1, 0]])
                        RT_sym = se3_mul(RT, RT_z)
                        error = arp_2d(RT_sym[:3, :3], RT_sym[:, 3],
                                       pose_gt[:3, :3], pose_gt[:, 3],
                                       self._points[cls_name], K)
                    else:
                        error = arp_2d(RT[:3, :3], RT[:, 3],
                                       pose_gt[:3, :3], pose_gt[:, 3],
                                       self._points[cls_name], K)

                    if error < threshold_2[cls_idx, iter_i]:
                        count_correct['2'][cls_idx, iter_i] += 1
                    if error < threshold_5[cls_idx, iter_i]:
                        count_correct['5'][cls_idx, iter_i] += 1
                    if error < threshold_10[cls_idx, iter_i]:
                        count_correct['10'][cls_idx, iter_i] += 1
                    if error < threshold_20[cls_idx, iter_i]:
                        count_correct['20'][cls_idx, iter_i] += 1
                    for thresh_i in range(num_thresh):
                        if error < threshold_mean[cls_idx, iter_i, thresh_i]:
                            count_correct['mean'][cls_idx, iter_i,
                                                  thresh_i] += 1
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        # store plot data
        plot_data = {}
        sum_acc_mean = np.zeros(num_iter)
        sum_acc_02 = np.zeros(num_iter)
        sum_acc_05 = np.zeros(num_iter)
        sum_acc_10 = np.zeros(num_iter)
        sum_acc_20 = np.zeros(num_iter)
        for cls_idx, cls_name in enumerate(self.classes):
            if count_all[cls_idx] == 0:
                continue
            plot_data[cls_name] = []
            for iter_i in range(num_iter):
                print_and_log("** {}, iter {} **".format(cls_name, iter_i + 1),
                              logger)
                from scipy.integrate import simps
                area = simps(count_correct['mean'][cls_idx, iter_i] /
                             float(count_all[cls_idx]),
                             dx=dx) / (50.0)
                acc_mean = area * 100
                sum_acc_mean[iter_i] += acc_mean
                acc_02 = 100 * float(count_correct['2'][cls_idx, iter_i]) / float(count_all[cls_idx])
                sum_acc_02[iter_i] += acc_02
                acc_05 = 100 * float(count_correct['5'][cls_idx, iter_i]) / float(count_all[cls_idx])
                sum_acc_05[iter_i] += acc_05
                acc_10 = 100 * float(count_correct['10'][cls_idx, iter_i]) / float(count_all[cls_idx])
                sum_acc_10[iter_i] += acc_10
                acc_20 = 100 * float(count_correct['20'][cls_idx, iter_i]) / float(count_all[cls_idx])
                sum_acc_20[iter_i] += acc_20

                fig = plt.figure()
                x_s = np.arange(0, 50, dx).astype(np.float32)
                y_s = 100 * count_correct['mean'][cls_idx, iter_i] / float(
                    count_all[cls_idx])
                plot_data[cls_name].append((x_s, y_s))
                plt.plot(x_s, y_s, '-')
                plt.xlim(0, 50)
                plt.ylim(0, 100)
                plt.grid(True)
                plt.xlabel("px")
                plt.ylabel("correctly estimated poses in %")
                plt.savefig(os.path.join(
                    output_dir,
                    'arp_2d_{}_iter{}.png'.format(cls_name, iter_i + 1)),
                            dpi=fig.dpi)

                print_and_log(
                    'threshold=[0, 50], area: {:.2f}'.format(acc_mean), logger)
                print_and_log(
                    'threshold=2, correct poses: {}, all poses: {}, accuracy: {:.2f}'
                    .format(count_correct['2'][cls_idx, iter_i],
                            count_all[cls_idx], acc_02), logger)
                print_and_log(
                    'threshold=5, correct poses: {}, all poses: {}, accuracy: {:.2f}'
                    .format(count_correct['5'][cls_idx, iter_i],
                            count_all[cls_idx], acc_05), logger)
                print_and_log(
                    'threshold=10, correct poses: {}, all poses: {}, accuracy: {:.2f}'
                    .format(count_correct['10'][cls_idx, iter_i],
                            count_all[cls_idx], acc_10), logger)
                print_and_log(
                    'threshold=20, correct poses: {}, all poses: {}, accuracy: {:.2f}'
                    .format(count_correct['20'][cls_idx, iter_i],
                            count_all[cls_idx], acc_20), logger)
                print_and_log(" ", logger)

        with open(os.path.join(output_dir, 'arp_2d_xys.pkl'), 'wb') as f:
            cPickle.dump(plot_data, f, protocol=2)
        print_and_log("=" * 30, logger)

        print(' ')
        # overall performance of arp 2d
        for iter_i in range(num_iter):
            print_and_log(
                "---------- arp 2d performance over {} classes -----------".
                format(num_valid_class), logger)
            print_and_log("** iter {} **".format(iter_i + 1), logger)

            print_and_log(
                'threshold=[0, 50], area: {:.2f}'.format(
                    sum_acc_mean[iter_i] / num_valid_class), logger)
            print_and_log(
                'threshold=2, mean accuracy: {:.2f}'.format(
                    sum_acc_02[iter_i] / num_valid_class), logger)
            print_and_log(
                'threshold=5, mean accuracy: {:.2f}'.format(
                    sum_acc_05[iter_i] / num_valid_class), logger)
            print_and_log(
                'threshold=10, mean accuracy: {:.2f}'.format(
                    sum_acc_10[iter_i] / num_valid_class), logger)
            print_and_log(
                'threshold=20, mean accuracy: {:.2f}'.format(
                    sum_acc_20[iter_i] / num_valid_class), logger)
            print_and_log(" ", logger)

        print_and_log("=" * 30, logger)
Example #2
    def evaluate_pose_add(self, config, all_poses_est, all_poses_gt,
                          output_dir, logger):
        '''
        Evaluate pose accuracy with the ADD metric (ADD-S/ADI for symmetric objects).
        :param config: experiment config (provides TEST.test_iter)
        :param all_poses_est: estimated poses, indexed as [class][iteration][sample]
        :param all_poses_gt: ground-truth poses, indexed as [class][0][sample]
        :param output_dir: directory the accuracy-vs-threshold plots are written to
        :param logger: logger used by print_and_log
        :return: None; results are logged and plots/pickles are written to output_dir
        '''
        print_and_log('evaluating pose add', logger)
        eval_method = 'add'
        num_iter = config.TEST.test_iter

        count_all = np.zeros((self.num_classes, ), dtype=np.float32)
        count_correct = {
            k: np.zeros((self.num_classes, num_iter), dtype=np.float32)
            for k in ['0.02', '0.05', '0.10']
        }

        threshold_002 = np.zeros((self.num_classes, num_iter),
                                 dtype=np.float32)
        threshold_005 = np.zeros((self.num_classes, num_iter),
                                 dtype=np.float32)
        threshold_010 = np.zeros((self.num_classes, num_iter),
                                 dtype=np.float32)
        dx = 0.0001
        threshold_mean = np.tile(
            np.arange(0, 0.1, dx).astype(np.float32),
            (self.num_classes, num_iter,
             1))  # (num_class, num_iter, num_thresh)
        num_thresh = threshold_mean.shape[-1]
        count_correct['mean'] = np.zeros(
            (self.num_classes, num_iter, num_thresh), dtype=np.float32)

        for i, cls_name in enumerate(self.classes):
            threshold_002[i, :] = 0.02 * self._diameters[cls_name]
            threshold_005[i, :] = 0.05 * self._diameters[cls_name]
            threshold_010[i, :] = 0.10 * self._diameters[cls_name]
            threshold_mean[i, :, :] *= self._diameters[cls_name]

        num_valid_class = 0
        for cls_idx, cls_name in enumerate(self.classes):
            if not (all_poses_est[cls_idx][0] and all_poses_gt[cls_idx][0]):
                continue
            num_valid_class += 1
            for iter_i in range(num_iter):
                curr_poses_gt = all_poses_gt[cls_idx][0]
                num = len(curr_poses_gt)
                curr_poses_est = all_poses_est[cls_idx][iter_i]

                for j in range(num):
                    if iter_i == 0:
                        count_all[cls_idx] += 1
                    RT = curr_poses_est[j]  # est pose
                    pose_gt = curr_poses_gt[j]  # gt pose
                    if cls_name == 'eggbox' or cls_name == 'glue' or cls_name == 'bowl' or cls_name == 'cup':
                        eval_method = 'adi'
                        error = adi(RT[:3, :3], RT[:, 3], pose_gt[:3, :3],
                                    pose_gt[:, 3], self._points[cls_name])
                    else:
                        error = add(RT[:3, :3], RT[:, 3], pose_gt[:3, :3],
                                    pose_gt[:, 3], self._points[cls_name])

                    if error < threshold_002[cls_idx, iter_i]:
                        count_correct['0.02'][cls_idx, iter_i] += 1
                    if error < threshold_005[cls_idx, iter_i]:
                        count_correct['0.05'][cls_idx, iter_i] += 1
                    if error < threshold_010[cls_idx, iter_i]:
                        count_correct['0.10'][cls_idx, iter_i] += 1
                    for thresh_i in range(num_thresh):
                        if error < threshold_mean[cls_idx, iter_i, thresh_i]:
                            count_correct['mean'][cls_idx, iter_i,
                                                  thresh_i] += 1

        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        plot_data = {}

        sum_acc_mean = np.zeros(num_iter)
        sum_acc_002 = np.zeros(num_iter)
        sum_acc_005 = np.zeros(num_iter)
        sum_acc_010 = np.zeros(num_iter)
        for cls_idx, cls_name in enumerate(self.classes):
            if count_all[cls_idx] == 0:
                continue
            plot_data[cls_name] = []
            for iter_i in range(num_iter):
                print_and_log("** {}, iter {} **".format(cls_name, iter_i + 1),
                              logger)
                from scipy.integrate import simps
                area = simps(count_correct['mean'][cls_idx, iter_i] /
                             float(count_all[cls_idx]),
                             dx=dx) / 0.1
                acc_mean = area * 100
                sum_acc_mean[iter_i] += acc_mean
                acc_002 = 100 * float(
                    count_correct['0.02'][cls_idx, iter_i]) / float(
                        count_all[cls_idx])
                sum_acc_002[iter_i] += acc_002
                acc_005 = 100 * float(
                    count_correct['0.05'][cls_idx, iter_i]) / float(
                        count_all[cls_idx])
                sum_acc_005[iter_i] += acc_005
                acc_010 = 100 * float(
                    count_correct['0.10'][cls_idx, iter_i]) / float(
                        count_all[cls_idx])
                sum_acc_010[iter_i] += acc_010

                fig = plt.figure()
                x_s = np.arange(0, 0.1, dx).astype(np.float32)
                y_s = count_correct['mean'][cls_idx, iter_i] / float(
                    count_all[cls_idx])
                plot_data[cls_name].append((x_s, y_s))
                plt.plot(x_s, y_s, '-')
                plt.xlim(0, 0.1)
                plt.ylim(0, 1)
                plt.xlabel("Average distance threshold in meter (symmetry)")
                plt.ylabel("accuracy")
                plt.savefig(os.path.join(
                    output_dir,
                    'acc_thres_{}_iter{}.png'.format(cls_name, iter_i + 1)),
                            dpi=fig.dpi)

                print_and_log(
                    'threshold=[0.0, 0.10], area: {:.2f}'.format(acc_mean),
                    logger)
                print_and_log(
                    'threshold=0.02, correct poses: {}, all poses: {}, accuracy: {:.2f}'
                    .format(count_correct['0.02'][cls_idx, iter_i],
                            count_all[cls_idx], acc_002), logger)
                print_and_log(
                    'threshold=0.05, correct poses: {}, all poses: {}, accuracy: {:.2f}'
                    .format(count_correct['0.05'][cls_idx, iter_i],
                            count_all[cls_idx], acc_005), logger)
                print_and_log(
                    'threshold=0.10, correct poses: {}, all poses: {}, accuracy: {:.2f}'
                    .format(count_correct['0.10'][cls_idx, iter_i],
                            count_all[cls_idx], acc_010), logger)
                print_and_log(" ", logger)

        with open(os.path.join(output_dir, '{}_xys.pkl'.format(eval_method)),
                  'wb') as f:
            cPickle.dump(plot_data, f, protocol=2)

        print_and_log("=" * 30, logger)

        print(' ')
        # overall performance of add
        for iter_i in range(num_iter):
            print_and_log(
                "---------- add performance over {} classes -----------".
                format(num_valid_class), logger)
            print_and_log("** iter {} **".format(iter_i + 1), logger)
            print_and_log(
                'threshold=[0.0, 0.10], area: {:.2f}'.format(
                    sum_acc_mean[iter_i] / num_valid_class), logger)
            print_and_log(
                'threshold=0.02, mean accuracy: {:.2f}'.format(
                    sum_acc_002[iter_i] / num_valid_class), logger)
            print_and_log(
                'threshold=0.05, mean accuracy: {:.2f}'.format(
                    sum_acc_005[iter_i] / num_valid_class), logger)
            print_and_log(
                'threshold=0.10, mean accuracy: {:.2f}'.format(
                    sum_acc_010[iter_i] / num_valid_class), logger)
            print(' ')

        print_and_log("=" * 30, logger)
Example #3
    def evaluate_pose(self, config, all_poses_est, all_poses_gt, logger):
        # evaluate and display
        print_and_log('evaluating pose', logger)
        rot_thresh_list = np.arange(1, 11, 1)
        trans_thresh_list = np.arange(0.01, 0.11, 0.01)
        num_metric = len(rot_thresh_list)
        num_iter = config.TEST.test_iter
        rot_acc = np.zeros((self.num_classes, num_iter, num_metric))
        trans_acc = np.zeros((self.num_classes, num_iter, num_metric))
        space_acc = np.zeros((self.num_classes, num_iter, num_metric))

        num_valid_class = 0
        for cls_idx, cls_name in enumerate(self.classes):
            if not (all_poses_est[cls_idx][0] and all_poses_gt[cls_idx][0]):
                continue
            num_valid_class += 1
            for iter_i in range(num_iter):
                curr_poses_gt = all_poses_gt[cls_idx][0]
                num = len(curr_poses_gt)
                curr_poses_est = all_poses_est[cls_idx][iter_i]

                cur_rot_rst = np.zeros((num, 1))
                cur_trans_rst = np.zeros((num, 1))

                for j in range(num):
                    r_dist_est, t_dist_est = calc_rt_dist_m(
                        curr_poses_est[j], curr_poses_gt[j])
                    if cls_name == 'eggbox' and r_dist_est > 90:
                        RT_z = np.array([[-1, 0, 0, 0], [0, -1, 0, 0],
                                         [0, 0, 1, 0]])
                        curr_pose_est_sym = se3_mul(curr_poses_est[j], RT_z)
                        r_dist_est, t_dist_est = calc_rt_dist_m(
                            curr_pose_est_sym, curr_poses_gt[j])
                    cur_rot_rst[j, 0] = r_dist_est
                    cur_trans_rst[j, 0] = t_dist_est

                for thresh_idx in range(num_metric):
                    rot_acc[cls_idx, iter_i, thresh_idx] = np.mean(
                        cur_rot_rst < rot_thresh_list[thresh_idx])
                    trans_acc[cls_idx, iter_i, thresh_idx] = np.mean(
                        cur_trans_rst < trans_thresh_list[thresh_idx])
                    space_acc[cls_idx, iter_i, thresh_idx] = np.mean(
                        np.logical_and(
                            cur_rot_rst < rot_thresh_list[thresh_idx],
                            cur_trans_rst < trans_thresh_list[thresh_idx]))

            show_list = [1, 4, 9]
            print_and_log("------------ {} -----------".format(cls_name),
                          logger)
            print_and_log(
                "{:>24}: {:>7}, {:>7}, {:>7}".format(
                    "[rot_thresh, trans_thresh", "RotAcc", "TraAcc", "SpcAcc"),
                logger)
            for iter_i in range(num_iter):
                print_and_log("** iter {} **".format(iter_i + 1), logger)
                print_and_log(
                    "{:<16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}".format(
                        'average_accuracy', '[{:>2}, {:>5.2f}]'.format(-1, -1),
                        np.mean(rot_acc[cls_idx, iter_i, :]) * 100,
                        np.mean(trans_acc[cls_idx, iter_i, :]) * 100,
                        np.mean(space_acc[cls_idx, iter_i, :]) * 100), logger)
                for i, show_idx in enumerate(show_list):
                    print_and_log(
                        "{:>16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}".format(
                            'average_accuracy', '[{:>2}, {:>5.2f}]'.format(
                                rot_thresh_list[show_idx],
                                trans_thresh_list[show_idx]),
                            rot_acc[cls_idx, iter_i, show_idx] * 100,
                            trans_acc[cls_idx, iter_i, show_idx] * 100,
                            space_acc[cls_idx, iter_i, show_idx] * 100),
                        logger)
        print(' ')
        # overall performance
        for iter_i in range(num_iter):
            show_list = [1, 4, 9]
            print_and_log(
                "---------- performance over {} classes -----------".format(
                    num_valid_class), logger)
            print_and_log("** iter {} **".format(iter_i + 1), logger)
            print_and_log(
                "{:>24}: {:>7}, {:>7}, {:>7}".format(
                    "[rot_thresh, trans_thresh", "RotAcc", "TraAcc", "SpcAcc"),
                logger)
            print_and_log(
                "{:<16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}".format(
                    'average_accuracy', '[{:>2}, {:>5.2f}]'.format(-1, -1),
                    np.sum(rot_acc[:, iter_i, :]) /
                    (num_valid_class * num_metric) * 100,
                    np.sum(trans_acc[:, iter_i, :]) /
                    (num_valid_class * num_metric) * 100,
                    np.sum(space_acc[:, iter_i, :]) /
                    (num_valid_class * num_metric) * 100), logger)
            for i, show_idx in enumerate(show_list):
                print_and_log(
                    "{:>16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}".format(
                        'average_accuracy', '[{:>2}, {:>5.2f}]'.format(
                            rot_thresh_list[show_idx],
                            trans_thresh_list[show_idx]),
                        np.sum(rot_acc[:, iter_i, show_idx]) /
                        num_valid_class * 100,
                        np.sum(trans_acc[:, iter_i, show_idx]) /
                        num_valid_class * 100,
                        np.sum(space_acc[:, iter_i, show_idx]) /
                        num_valid_class * 100), logger)
            print(' ')
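The rotation/translation distance helper `calc_rt_dist_m` and the transform composition `se3_mul` used above are also defined elsewhere. Plausible sketches, assuming 3x4 [R|t] pose matrices (names with `_sketch` are ours):

import numpy as np

def calc_rt_dist_m_sketch(pose_est, pose_gt):
    """Rotation error in degrees and translation error (in the pose unit,
    usually meters) between two 3x4 [R|t] poses; sketch of the external helper."""
    R_est, t_est = pose_est[:3, :3], pose_est[:3, 3]
    R_gt, t_gt = pose_gt[:3, :3], pose_gt[:3, 3]
    cos_angle = (np.trace(R_est.dot(R_gt.T)) - 1.0) / 2.0
    r_dist = np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0)))
    t_dist = np.linalg.norm(t_est - t_gt)
    return r_dist, t_dist

def se3_mul_sketch(RT1, RT2):
    """Compose two 3x4 rigid transforms (RT1 applied after RT2); sketch."""
    R = RT1[:3, :3].dot(RT2[:3, :3])
    t = RT1[:3, :3].dot(RT2[:3, 3]) + RT1[:3, 3]
    return np.hstack([R, t.reshape(3, 1)])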
Example #4
    def fit(
        self,
        train_data,
        eval_data=None,
        eval_metric="acc",
        epoch_end_callback=None,
        batch_end_callback=None,
        kvstore="local",
        optimizer="sgd",
        optimizer_params=(("learning_rate", 0.01),),
        eval_end_callback=None,
        eval_batch_end_callback=None,
        initializer=Uniform(0.01),
        arg_params=None,
        aux_params=None,
        allow_missing=False,
        force_rebind=False,
        force_init=False,
        begin_epoch=0,
        num_epoch=None,
        validation_metric=None,
        monitor=None,
        prefix=None,
    ):
        """Train the module parameters.

        Parameters
        ----------
        train_data : DataIter
        eval_data : DataIter
            If not `None`, it will be used as the validation set to evaluate
            the performance after each epoch.
        eval_metric : str or EvalMetric
            Default `'acc'`. The performance measure displayed during training.
        epoch_end_callback : function or list of function
            Each callback will be called with the current `epoch`, `symbol`, `arg_params`
            and `aux_params`.
        batch_end_callback : function or list of function
            Each callback will be called with a `BatchEndParam`.
        kvstore : str or KVStore
            Default `'local'`.
        optimizer : str or Optimizer
            Default `'sgd'`.
        optimizer_params : dict
            Default `(('learning_rate', 0.01),)`. The parameters for the optimizer constructor.
            The default value is not a `dict`, just to avoid pylint warning on dangerous
            default values.
        eval_end_callback : function or list of function
            These will be called at the end of each full evaluation, with the metrics over
            the entire evaluation set.
        eval_batch_end_callback : function or list of function
            These will be called at the end of each minibatch during evaluation
        initializer : Initializer
            Will be called to initialize the module parameters if not already initialized.
        arg_params : dict
            Default `None`. If not `None`, should be existing parameters from a trained
            model or loaded from a checkpoint (a previously saved model). In this case,
            the values here will be used to initialize the module parameters, unless
            they are already initialized by the user via a call to `init_params` or
            `fit`. `arg_params` takes precedence over `initializer`.
        aux_params : dict
            Default `None`. Similar to `arg_params`, except for auxiliary states.
        allow_missing : bool
            Default `False`. Indicates whether missing parameters are allowed when
            `arg_params` and `aux_params` are not `None`. If `True`, the missing
            parameters will be initialized via the `initializer`.
        force_rebind : bool
            Default `False`. Whether to force rebinding the executors if they are
            already bound.
        force_init : bool
            Default `False`. Indicates whether initialization should be forced even
            if the parameters are already initialized.
        begin_epoch : int
            Default `0`. The starting epoch. Usually, when resuming from a checkpoint
            saved at epoch N of a previous training phase, this should be N+1.
        num_epoch : int
            Number of epochs to run training.

        Examples
        --------
        An example of using fit for training::

            >>> # Assume training DataIter and validation DataIter are ready
            >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter,
            ...         optimizer_params={'learning_rate': 0.01, 'momentum': 0.9},
            ...         num_epoch=10)
        """
        assert num_epoch is not None, "please specify number of epochs"

        self.bind(
            data_shapes=train_data.provide_data,
            label_shapes=train_data.provide_label,
            for_training=True,
            force_rebind=force_rebind,
        )
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(
            initializer=initializer,
            arg_params=arg_params,
            aux_params=aux_params,
            allow_missing=allow_missing,
            force_init=force_init,
        )
        self.init_optimizer(
            kvstore=kvstore, optimizer=optimizer, optimizer_params=optimizer_params
        )

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        ################################################################################
        # training loop
        ################################################################################
        # epoch 0
        if epoch_end_callback is not None:
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)
            for callback in _as_list(epoch_end_callback):
                callback(-1, self.symbol, arg_params, aux_params)

        from lib.pair_matching.batch_updater_py_multi import batchUpdaterPyMulti

        config = self.config
        if config.TRAIN.TENSORBOARD_LOG:
            from mxboard import SummaryWriter

            tf_log_dir = os.path.join(
                os.path.dirname(prefix),
                "logs/{}".format(time.strftime("%Y-%m-%d-%H-%M")),
            )
            summ_writer = SummaryWriter(logdir=tf_log_dir)

        interBatchUpdater = batchUpdaterPyMulti(config, 480, 640)
        last_lr = 0
        cur_step = 0
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            for nbatch, data_batch in enumerate(train_data):
                if monitor is not None:
                    monitor.tic()
                # disp weights L2 norm
                cur_lr = self._curr_module._optimizer._get_lr(0)
                if nbatch % (4000 // train_data.batch_size) == 0:  # roughly every 4000 samples
                    all_params = self._curr_module.get_params()[0]
                    all_param_names = all_params.keys()
                    all_param_names = sorted(all_param_names)
                    print_and_log(prefix, self.logger)
                    weight_str = ""
                    for view_name in all_param_names:
                        weight_str += "{}: {} ".format(
                            view_name, nd.norm(all_params[view_name]).asnumpy()
                        )
                    print_and_log(weight_str, self.logger)
                    print_and_log(
                        "batch {}: lr: {}".format(nbatch, cur_lr), self.logger
                    )
                    if config.TRAIN.TENSORBOARD_LOG:
                        summ_writer.add_scalar(
                            tag="learning_rate", value=cur_lr, global_step=cur_step
                        )
                if cur_lr != last_lr:
                    print_and_log(
                        "batch {}: lr: {}".format(nbatch, cur_lr), self.logger
                    )
                    last_lr = cur_lr
                    if config.TRAIN.TENSORBOARD_LOG:
                        summ_writer.add_scalar(
                            tag="learning_rate", value=cur_lr, global_step=cur_step
                        )

                train_iter_size = config.network.TRAIN_ITER_SIZE
                for iter_idx in range(train_iter_size):
                    self.forward_backward(data_batch)
                    preds = self._curr_module.get_outputs(False)
                    self.update()
                    if iter_idx != train_iter_size - 1:
                        data_batch = interBatchUpdater.forward(
                            data_batch, preds, config
                        )
                cur_step += 1
                self.update_metric(eval_metric, data_batch.label)

                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(
                        epoch=epoch,
                        nbatch=nbatch,
                        eval_metric=eval_metric,
                        locals=locals(),
                    )
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                if config.TRAIN.TENSORBOARD_LOG:
                    for name, val in eval_metric.get_name_value():
                        summ_writer.add_scalar(
                            tag="BatchTrain-{}".format(name),
                            value=val,
                            global_step=cur_step,
                        )

            # one epoch of training is finished
            for name, val in eval_metric.get_name_value():
                self.logger.info("Epoch[%d] Train-%s=%f", epoch, name, val)
                if config.TRAIN.TENSORBOARD_LOG:
                    summ_writer.add_scalar(
                        tag="EpochTrain-{}".format(name), value=val, global_step=epoch
                    )

            toc = time.time()
            self.logger.info("Epoch[%d] Time cost=%.3f", epoch, (toc - tic))

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            # ----------------------------------------
            # evaluation on validation set
            if eval_data:
                res = self.score(
                    eval_data,
                    validation_metric,
                    score_end_callback=eval_end_callback,
                    batch_end_callback=eval_batch_end_callback,
                    epoch=epoch,
                )
                # TODO: pull this into default
                for name, val in res:
                    self.logger.info("Epoch[%d] Validation-%s=%f", epoch, name, val)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
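For context, a hypothetical call to the `fit()` method above might look like the following; `mod`, `train_iter` and `val_iter` are placeholders rather than names from the original code, and the learning-rate scheduler is only an illustrative choice:

# Hypothetical usage of fit(); placeholder names, not part of the original code.
import mxnet as mx

lr_scheduler = mx.lr_scheduler.FactorScheduler(step=20000, factor=0.1)
mod.fit(train_data=train_iter,
        eval_data=val_iter,
        eval_metric='acc',
        optimizer='sgd',
        optimizer_params={'learning_rate': 0.01,
                          'momentum': 0.9,
                          'lr_scheduler': lr_scheduler},
        num_epoch=10,
        prefix='./output/model')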
Example #5
def pred_eval(config,
              predictor,
              test_data,
              imdb_test,
              vis=False,
              ignore_cache=None,
              logger=None,
              pairdb=None):
    """
    wrapper for calculating offline validation for faster data analysis
    in this example, all threshold are set by hand
    :param predictor: Predictor
    :param test_data: data iterator, must be non-shuffle
    :param imdb_test: image database
    :param vis: controls visualization
    :param ignore_cache: ignore the saved cache file
    :param logger: the logger instance
    :return:
    """
    print(imdb_test.result_path)
    print('test iter size: ', config.TEST.test_iter)
    pose_err_file = os.path.join(
        imdb_test.result_path,
        imdb_test.name + '_pose_iter{}.pkl'.format(config.TEST.test_iter))
    if os.path.exists(pose_err_file) and not ignore_cache and not vis:
        with open(pose_err_file, 'rb') as fid:
            if six.PY3:
                [all_rot_err, all_trans_err, all_poses_est,
                 all_poses_gt] = cPickle.load(fid, encoding='latin1')
            else:
                [all_rot_err, all_trans_err, all_poses_est,
                 all_poses_gt] = cPickle.load(fid)
        imdb_test.evaluate_pose(config, all_poses_est, all_poses_gt, logger)
        pose_add_plots_dir = os.path.join(imdb_test.result_path, 'add_plots')
        mkdir_if_missing(pose_add_plots_dir)
        imdb_test.evaluate_pose_add(config,
                                    all_poses_est,
                                    all_poses_gt,
                                    output_dir=pose_add_plots_dir,
                                    logger=logger)
        pose_arp2d_plots_dir = os.path.join(imdb_test.result_path,
                                            'arp_2d_plots')
        mkdir_if_missing(pose_arp2d_plots_dir)
        imdb_test.evaluate_pose_arp_2d(config,
                                       all_poses_est,
                                       all_poses_gt,
                                       output_dir=pose_arp2d_plots_dir,
                                       logger=logger)
        return

    assert vis or not test_data.shuffle
    assert config.TEST.BATCH_PAIRS == 1
    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    num_pairs = len(pairdb)
    height = 480
    width = 640

    data_time, net_time, post_time = 0.0, 0.0, 0.0

    sum_EPE_all = 0.0
    num_inst_all = 0.0
    sum_EPE_viz = 0.0
    num_inst_viz = 0.0
    sum_EPE_vizbg = 0.0
    num_inst_vizbg = 0.0
    sum_PoseErr = [
        np.zeros((len(imdb_test.classes) + 1, 2))
        for batch_idx in range(config.TEST.test_iter)
    ]

    all_rot_err = [[[] for j in range(config.TEST.test_iter)]
                   for batch_idx in range(len(imdb_test.classes))
                   ]  # num_cls x test_iter
    all_trans_err = [[[] for j in range(config.TEST.test_iter)]
                     for batch_idx in range(len(imdb_test.classes))]

    all_poses_est = [[[] for j in range(config.TEST.test_iter)]
                     for batch_idx in range(len(imdb_test.classes))]
    all_poses_gt = [[[] for j in range(config.TEST.test_iter)]
                    for batch_idx in range(len(imdb_test.classes))]

    num_inst = np.zeros(len(imdb_test.classes) + 1)

    K = config.dataset.INTRINSIC_MATRIX
    if config.TEST.test_iter > 1 or config.TEST.VISUALIZE:
        print(
            "************* start setup render_glumpy environment... ******************"
        )
        if config.dataset.dataset.startswith('ModelNet'):
            from lib.render_glumpy.render_py_light_modelnet_multi import Render_Py_Light_ModelNet_Multi
            modelnet_root = config.modelnet_root
            texture_path = os.path.join(modelnet_root, 'gray_texture.png')

            model_path_list = [
                os.path.join(config.dataset.model_dir,
                             '{}.obj'.format(model_name))
                for model_name in config.dataset.class_name
            ]
            render_machine = Render_Py_Light_ModelNet_Multi(
                model_path_list,
                texture_path,
                K,
                width,
                height,
                config.dataset.ZNEAR,
                config.dataset.ZFAR,
                brightness_ratios=[0.7])
        else:
            render_machine = Render_Py(config.dataset.model_dir,
                                       config.dataset.class_name, K, width,
                                       height, config.dataset.ZNEAR,
                                       config.dataset.ZFAR)

        def render(render_machine, pose, cls_idx, K=None):
            if config.dataset.dataset.startswith('ModelNet'):
                idx = 2  # fixed index: the light position chosen below is deterministic
                # choose a light position (the random choice is disabled here)
                if idx % 6 == 0:
                    light_position = [1, 0, 1]
                elif idx % 6 == 1:
                    light_position = [1, 1, 1]
                elif idx % 6 == 2:
                    light_position = [0, 1, 1]
                elif idx % 6 == 3:
                    light_position = [-1, 1, 1]
                elif idx % 6 == 4:
                    light_position = [-1, 0, 1]
                elif idx % 6 == 5:
                    light_position = [0, 0, 1]
                else:
                    raise Exception("???")
                light_position = np.array(light_position) * 0.5
                # inverse yz
                light_position[0] += pose[0, 3]
                light_position[1] -= pose[1, 3]
                light_position[2] -= pose[2, 3]

                colors = np.array([1, 1, 1])  # white light
                intensity = np.random.uniform(0.9, 1.1, size=(3, ))
                colors_randk = 0
                light_intensity = colors[colors_randk] * intensity

                # randomly choose a render machine
                rm_randk = 0  # random.randint(0, len(brightness_ratios) - 1)
                rgb_gl, depth_gl = render_machine.render(cls_idx,
                                                         pose[:3, :3],
                                                         pose[:3, 3],
                                                         light_position,
                                                         light_intensity,
                                                         brightness_k=rm_randk,
                                                         r_type='mat')
                rgb_gl = rgb_gl.astype('uint8')
            else:
                rgb_gl, depth_gl = render_machine.render(cls_idx,
                                                         pose[:3, :3],
                                                         pose[:, 3],
                                                         r_type='mat',
                                                         K=K)
                rgb_gl = rgb_gl.astype('uint8')
            return rgb_gl, depth_gl

        print(
            "***************setup render_glumpy environment succeed ******************"
        )

    if config.TEST.PRECOMPUTED_ICP:
        print('precomputed_ICP')
        config.TEST.test_iter = 1
        all_rot_err = [[[] for j in range(1)]
                       for batch_idx in range(len(imdb_test.classes))]
        all_trans_err = [[[] for j in range(1)]
                         for batch_idx in range(len(imdb_test.classes))]

        all_poses_est = [[[] for j in range(1)]
                         for batch_idx in range(len(imdb_test.classes))]
        all_poses_gt = [[[] for j in range(1)]
                        for batch_idx in range(len(imdb_test.classes))]

        xy_trans_err = [[[] for j in range(1)]
                        for batch_idx in range(len(imdb_test.classes))]
        z_trans_err = [[[] for j in range(1)]
                       for batch_idx in range(len(imdb_test.classes))]
        for idx in range(len(pairdb)):
            pose_path = pairdb[idx]['depth_rendered'][:-10] + '-pose_icp.txt'
            pose_rendered_update = np.loadtxt(pose_path, skiprows=1)
            pose_real = pairdb[idx]['pose_observed']
            r_dist_est, t_dist_est = calc_rt_dist_m(pose_rendered_update,
                                                    pose_real)
            xy_dist = np.linalg.norm(pose_rendered_update[:2, -1] -
                                     pose_real[:2, -1])
            z_dist = np.linalg.norm(pose_rendered_update[-1, -1] -
                                    pose_real[-1, -1])
            print(
                "{}: r_dist_est: {}, t_dist_est: {}, xy_dist: {}, z_dist: {}".
                format(idx, r_dist_est, t_dist_est, xy_dist, z_dist))
            class_id = imdb_test.classes.index(pairdb[idx]['gt_class'])
            # store poses estimation and gt
            all_poses_est[class_id][0].append(pose_rendered_update)
            all_poses_gt[class_id][0].append(pairdb[idx]['pose_observed'])
            all_rot_err[class_id][0].append(r_dist_est)
            all_trans_err[class_id][0].append(t_dist_est)
            xy_trans_err[class_id][0].append(xy_dist)
            z_trans_err[class_id][0].append(z_dist)
        all_rot_err = np.array(all_rot_err)
        all_trans_err = np.array(all_trans_err)
        print("rot = {} +/- {}".format(np.mean(all_rot_err[class_id][0]),
                                       np.std(all_rot_err[class_id][0])))
        print("trans = {} +/- {}".format(np.mean(all_trans_err[class_id][0]),
                                         np.std(all_trans_err[class_id][0])))
        num_list = all_trans_err[class_id][0]
        print("xyz: {:.2f} +/- {:.2f}".format(
            np.mean(num_list) * 100,
            np.std(num_list) * 100))
        num_list = xy_trans_err[class_id][0]
        print("xy: {:.2f} +/- {:.2f}".format(
            np.mean(num_list) * 100,
            np.std(num_list) * 100))
        num_list = z_trans_err[class_id][0]
        print("z: {:.2f} +/- {:.2f}".format(
            np.mean(num_list) * 100,
            np.std(num_list) * 100))

        imdb_test.evaluate_pose(config, all_poses_est, all_poses_gt, logger)
        pose_add_plots_dir = os.path.join(imdb_test.result_path,
                                          'add_plots_precomputed_ICP')
        mkdir_if_missing(pose_add_plots_dir)
        imdb_test.evaluate_pose_add(config,
                                    all_poses_est,
                                    all_poses_gt,
                                    output_dir=pose_add_plots_dir,
                                    logger=logger)
        pose_arp2d_plots_dir = os.path.join(imdb_test.result_path,
                                            'arp_2d_plots_precomputed_ICP')
        mkdir_if_missing(pose_arp2d_plots_dir)
        imdb_test.evaluate_pose_arp_2d(config,
                                       all_poses_est,
                                       all_poses_gt,
                                       output_dir=pose_arp2d_plots_dir,
                                       logger=logger)
        return

    if config.TEST.BEFORE_ICP:
        print('before_ICP')
        config.TEST.test_iter = 1
        all_rot_err = [[[] for j in range(1)]
                       for batch_idx in range(len(imdb_test.classes))]
        all_trans_err = [[[] for j in range(1)]
                         for batch_idx in range(len(imdb_test.classes))]

        all_poses_est = [[[] for j in range(1)]
                         for batch_idx in range(len(imdb_test.classes))]
        all_poses_gt = [[[] for j in range(1)]
                        for batch_idx in range(len(imdb_test.classes))]

        xy_trans_err = [[[] for j in range(1)]
                        for batch_idx in range(len(imdb_test.classes))]
        z_trans_err = [[[] for j in range(1)]
                       for batch_idx in range(len(imdb_test.classes))]
        for idx in range(len(pairdb)):
            pose_path = pairdb[idx]['depth_rendered'][:-10] + '-pose.txt'
            pose_rendered_update = np.loadtxt(pose_path, skiprows=1)
            pose_real = pairdb[idx]['pose_observed']
            r_dist_est, t_dist_est = calc_rt_dist_m(pose_rendered_update,
                                                    pose_real)
            xy_dist = np.linalg.norm(pose_rendered_update[:2, -1] -
                                     pose_real[:2, -1])
            z_dist = np.linalg.norm(pose_rendered_update[-1, -1] -
                                    pose_real[-1, -1])
            class_id = imdb_test.classes.index(pairdb[idx]['gt_class'])
            # store poses estimation and gt
            all_poses_est[class_id][0].append(pose_rendered_update)
            all_poses_gt[class_id][0].append(pairdb[idx]['pose_observed'])
            all_rot_err[class_id][0].append(r_dist_est)
            all_trans_err[class_id][0].append(t_dist_est)
            xy_trans_err[class_id][0].append(xy_dist)
            z_trans_err[class_id][0].append(z_dist)

        all_trans_err = np.array(all_trans_err)
        imdb_test.evaluate_pose(config, all_poses_est, all_poses_gt, logger)
        pose_add_plots_dir = os.path.join(imdb_test.result_path,
                                          'add_plots_before_ICP')
        mkdir_if_missing(pose_add_plots_dir)
        imdb_test.evaluate_pose_add(config,
                                    all_poses_est,
                                    all_poses_gt,
                                    output_dir=pose_add_plots_dir,
                                    logger=logger)
        pose_arp2d_plots_dir = os.path.join(imdb_test.result_path,
                                            'arp_2d_plots_before_ICP')
        mkdir_if_missing(pose_arp2d_plots_dir)
        imdb_test.evaluate_pose_arp_2d(config,
                                       all_poses_est,
                                       all_poses_gt,
                                       output_dir=pose_arp2d_plots_dir,
                                       logger=logger)
        return

    # ------------------------------------------------------------------------------
    t_start = time.time()
    t = time.time()
    for idx, data_batch in enumerate(test_data):
        if np.sum(pairdb[idx]['pose_rendered']) == -12:  # NO POINT VALID IN INIT POSE
            print(idx)
            class_id = imdb_test.classes.index(pairdb[idx]['gt_class'])
            for pose_iter_idx in range(config.TEST.test_iter):
                all_poses_est[class_id][pose_iter_idx].append(
                    pairdb[idx]['pose_rendered'])
                all_poses_gt[class_id][pose_iter_idx].append(
                    pairdb[idx]['pose_observed'])

                r_dist = 1000
                t_dist = 1000
                all_rot_err[class_id][pose_iter_idx].append(r_dist)
                all_trans_err[class_id][pose_iter_idx].append(t_dist)
                sum_PoseErr[pose_iter_idx][class_id, :] += np.array(
                    [r_dist, t_dist])
                sum_PoseErr[pose_iter_idx][-1, :] += np.array([r_dist, t_dist])
                # post process
            if idx % 50 == 0:
                print_and_log(
                    'testing {}/{} data {:.4f}s net {:.4f}s calc_gt {:.4f}s'.
                    format((idx + 1), num_pairs,
                           data_time / (idx + 1) * test_data.batch_size,
                           net_time / (idx + 1) * test_data.batch_size,
                           post_time / (idx + 1) * test_data.batch_size),
                    logger)
            print("NO POINT_VALID IN rendered")
            continue
        data_time += time.time() - t

        t = time.time()

        pose_rendered = pairdb[idx]['pose_rendered']
        if np.sum(pose_rendered) == -12:
            print(idx)
            class_id = imdb_test.classes.index(pairdb[idx]['gt_class'])
            num_inst[class_id] += 1
            num_inst[-1] += 1
            for pose_iter_idx in range(config.TEST.test_iter):
                all_poses_est[class_id][pose_iter_idx].append(pose_rendered)
                all_poses_gt[class_id][pose_iter_idx].append(
                    pairdb[idx]['pose_observed'])

            # post process
            if idx % 50 == 0:
                print_and_log(
                    'testing {}/{} data {:.4f}s net {:.4f}s calc_gt {:.4f}s'.
                    format((idx + 1), num_pairs,
                           data_time / (idx + 1) * test_data.batch_size,
                           net_time / (idx + 1) * test_data.batch_size,
                           post_time / (idx + 1) * test_data.batch_size),
                    logger)

            t = time.time()
            continue

        output_all = predictor.predict(data_batch)
        net_time += time.time() - t

        t = time.time()
        rst_iter = []
        for output in output_all:
            cur_rst = {}
            cur_rst['se3'] = np.squeeze(
                output['se3_output'].asnumpy()).astype('float32')

            if not config.TEST.FAST_TEST and config.network.PRED_FLOW:
                cur_rst['flow'] = np.squeeze(
                    output['flow_est_crop_output'].asnumpy().transpose(
                        (2, 3, 1, 0))).astype('float16')
            else:
                cur_rst['flow'] = None
            if config.network.PRED_MASK and config.TEST.UPDATE_MASK not in [
                    'init', 'box_rendered'
            ]:
                mask_pred = np.squeeze(
                    output['mask_observed_pred_output'].asnumpy()).astype(
                        'float32')
                cur_rst['mask_pred'] = mask_pred

            rst_iter.append(cur_rst)

        post_time += time.time() - t
        sample_ratio = 1  # 0.01
        for batch_idx in range(0, test_data.batch_size):
            # if config.TEST.VISUALIZE and not (r_dist>15 and t_dist>0.05):
            #     continue # 3388, 5326
            # calculate the flow error --------------------------------------------
            t = time.time()
            if config.network.PRED_FLOW and not config.TEST.FAST_TEST:
                # evaluate optical flow
                flow_gt = par_generate_gt(config, pairdb[idx])
                if config.network.PRED_FLOW:
                    all_diff = calc_EPE_one_pair(rst_iter[batch_idx], flow_gt,
                                                 'flow')
                sum_EPE_all += all_diff['epe_all']
                num_inst_all += all_diff['num_all']
                sum_EPE_viz += all_diff['epe_viz']
                num_inst_viz += all_diff['num_viz']
                sum_EPE_vizbg += all_diff['epe_vizbg']
                num_inst_vizbg += all_diff['num_vizbg']

            # calculate the se3 error ---------------------------------------------
            # evaluate se3 estimation
            pose_rendered = pairdb[idx]['pose_rendered']
            class_id = imdb_test.classes.index(pairdb[idx]['gt_class'])
            num_inst[class_id] += 1
            num_inst[-1] += 1
            post_time += time.time() - t

            # iterative refine se3 estimation --------------------------------------------------
            for pose_iter_idx in range(config.TEST.test_iter):
                t = time.time()
                pose_rendered_update = RT_transform(pose_rendered,
                                                    rst_iter[0]['se3'][:-3],
                                                    rst_iter[0]['se3'][-3:],
                                                    config.dataset.trans_means,
                                                    config.dataset.trans_stds,
                                                    config.network.ROT_COORD)

                # calculate error
                r_dist, t_dist = calc_rt_dist_m(pose_rendered_update,
                                                pairdb[idx]['pose_observed'])

                # store poses estimation and gt
                all_poses_est[class_id][pose_iter_idx].append(
                    pose_rendered_update)
                all_poses_gt[class_id][pose_iter_idx].append(
                    pairdb[idx]['pose_observed'])

                all_rot_err[class_id][pose_iter_idx].append(r_dist)
                all_trans_err[class_id][pose_iter_idx].append(t_dist)
                sum_PoseErr[pose_iter_idx][class_id, :] += np.array(
                    [r_dist, t_dist])
                sum_PoseErr[pose_iter_idx][-1, :] += np.array([r_dist, t_dist])
                if config.TEST.VISUALIZE:
                    print("idx {}, iter {}: rError: {}, tError: {}".format(
                        idx + batch_idx, pose_iter_idx + 1, r_dist, t_dist))

                post_time += time.time() - t

                # # if more than one iteration
                if pose_iter_idx < (config.TEST.test_iter -
                                    1) or config.TEST.VISUALIZE:
                    t = time.time()
                    # get refined image
                    K_path = pairdb[idx]['image_observed'][:-10] + '-K.txt'
                    if os.path.exists(K_path):
                        K = np.loadtxt(K_path)
                    image_refined, depth_refined = render(
                        render_machine,
                        pose_rendered_update,
                        config.dataset.class_name.index(
                            pairdb[idx]['gt_class']),
                        K=K)
                    image_refined = image_refined[:, :, :3]

                    # update minibatch
                    update_package = [{
                        'image_rendered': image_refined,
                        'src_pose': pose_rendered_update
                    }]
                    if config.network.INPUT_DEPTH:
                        update_package[0]['depth_rendered'] = depth_refined
                    if config.network.INPUT_MASK:
                        mask_rendered_refined = np.zeros(depth_refined.shape)
                        mask_rendered_refined[depth_refined > 0.2] = 1
                        update_package[0][
                            'mask_rendered'] = mask_rendered_refined
                        if config.network.PRED_MASK:
                            # init, box_rendered, mask_rendered, box_real, mask_observed
                            if config.TEST.UPDATE_MASK == 'box_rendered':
                                input_names = [
                                    blob_name[0]
                                    for blob_name in data_batch.provide_data[0]
                                ]
                                update_package[0]['mask_observed'] = np.squeeze(
                                    data_batch.data[0][input_names.index(
                                        'mask_rendered')].asnumpy()[batch_idx])
                            elif config.TEST.UPDATE_MASK == 'init':
                                pass
                            else:
                                raise Exception(
                                    'Unknown UPDATE_MASK type: {}'.format(
                                        config.TEST.UPDATE_MASK))

                    pose_rendered = pose_rendered_update
                    data_batch = update_data_batch(config, data_batch,
                                                   update_package)

                    data_time += time.time() - t

                    # forward and get rst
                    if pose_iter_idx < config.TEST.test_iter - 1:
                        t = time.time()
                        output_all = predictor.predict(data_batch)
                        net_time += time.time() - t

                        t = time.time()
                        rst_iter = []
                        for output in output_all:
                            cur_rst = {}
                            if config.network.REGRESSOR_NUM == 1:
                                cur_rst['se3'] = np.squeeze(
                                    output['se3_output'].asnumpy()).astype(
                                        'float32')

                            if not config.TEST.FAST_TEST and config.network.PRED_FLOW:
                                cur_rst['flow'] = np.squeeze(
                                    output['flow_est_crop_output'].asnumpy().
                                    transpose((2, 3, 1, 0))).astype('float16')
                            else:
                                cur_rst['flow'] = None

                            if config.network.PRED_MASK and config.TEST.UPDATE_MASK not in [
                                    'init', 'box_rendered'
                            ]:
                                mask_pred = np.squeeze(
                                    output['mask_observed_pred_output'].
                                    asnumpy()).astype('float32')
                                cur_rst['mask_pred'] = mask_pred

                            rst_iter.append(cur_rst)
                        # accumulate post-processing time once per forward pass
                        post_time += time.time() - t

        # post process
        if idx % 50 == 0:
            print_and_log(
                'testing {}/{} data {:.4f}s net {:.4f}s calc_gt {:.4f}s'.
                format((idx + 1), num_pairs,
                       data_time / (idx + 1) * test_data.batch_size,
                       net_time / (idx + 1) * test_data.batch_size,
                       post_time / (idx + 1) * test_data.batch_size), logger)

        t = time.time()

    all_rot_err = np.array(all_rot_err)
    all_trans_err = np.array(all_trans_err)

    # save inference results
    if not config.TEST.VISUALIZE:
        with open(pose_err_file, 'wb') as f:
            print("saving result cache to {}".format(pose_err_file), )
            cPickle.dump(
                [all_rot_err, all_trans_err, all_poses_est, all_poses_gt],
                f,
                protocol=2)
            print("done")

    if config.network.PRED_FLOW:
        print_and_log('evaluate flow:', logger)
        print_and_log(
            'EPE all: {}'.format(sum_EPE_all / max(num_inst_all, 1.0)), logger)
        print_and_log(
            'EPE ignoring invisible: {}'.format(
                sum_EPE_vizbg / max(num_inst_vizbg, 1.0)), logger)
        print_and_log(
            'EPE visible: {}'.format(sum_EPE_viz / max(num_inst_viz, 1.0)),
            logger)

    print_and_log('evaluate pose:', logger)
    imdb_test.evaluate_pose(config, all_poses_est, all_poses_gt, logger)
    # evaluate pose add
    pose_add_plots_dir = os.path.join(imdb_test.result_path, 'add_plots')
    mkdir_if_missing(pose_add_plots_dir)
    imdb_test.evaluate_pose_add(config,
                                all_poses_est,
                                all_poses_gt,
                                output_dir=pose_add_plots_dir,
                                logger=logger)
    pose_arp2d_plots_dir = os.path.join(imdb_test.result_path, 'arp_2d_plots')
    mkdir_if_missing(pose_arp2d_plots_dir)
    imdb_test.evaluate_pose_arp_2d(config,
                                   all_poses_est,
                                   all_poses_gt,
                                   output_dir=pose_arp2d_plots_dir,
                                   logger=logger)

    print_and_log('using {} seconds in total'.format(time.time() - t_start),
                  logger)
Exemplo n.º 6
0
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr,
              lr_step):
    new_args_name = args.cfg
    if args.vis:
        config.TRAIN.VISUALIZE = True
    logger, final_output_path = create_logger(config.output_path,
                                              new_args_name,
                                              config.dataset.image_set,
                                              args.temp)
    prefix = os.path.join(final_output_path, prefix)
    logger.info('called with args {}'.format(args))

    print(config.train_iter.SE3_PM_LOSS)
    if config.train_iter.SE3_PM_LOSS:
        print("SE3_PM_LOSS == True")
    else:
        print("SE3_PM_LOSS == False")

    if not config.network.STANDARD_FLOW_REP:
        print_and_log("[h, w] representation for flow is dep", logger)

    # load dataset and prepare imdb for training
    image_sets = [iset for iset in config.dataset.image_set.split('+')]
    datasets = [dset for dset in config.dataset.dataset.split('+')]
    print("config.dataset.class_name: {}".format(config.dataset.class_name))
    print("image_sets: {}".format(image_sets))
    if datasets[0].startswith('ModelNet'):
        pairdbs = [
            load_gt_pairdb(config,
                           datasets[i],
                           image_sets[i] + class_name.split('/')[-1],
                           config.dataset.root_path,
                           config.dataset.dataset_path,
                           class_name=class_name,
                           result_path=final_output_path)
            for class_name in config.dataset.class_name
            for i in range(len(image_sets))
        ]
    else:
        pairdbs = [
            load_gt_pairdb(config,
                           datasets[i],
                           image_sets[i] + class_name,
                           config.dataset.root_path,
                           config.dataset.dataset_path,
                           class_name=class_name,
                           result_path=final_output_path)
            for class_name in config.dataset.class_name
            for i in range(len(image_sets))
        ]
    pairdb = merge_pairdb(pairdbs)

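    # keep a timestamped copy of the symbol definition alongside the run outputs (skipped for temp runs)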
    if not args.temp:
        src_file = os.path.join(curr_path, 'symbols', config.symbol + '.py')
        dst_file = os.path.join(
            final_output_path,
            '{}_{}.py'.format(config.symbol, time.strftime('%Y-%m-%d-%H-%M')))
        os.popen('cp {} {}'.format(src_file, dst_file))

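    # instantiate the network symbol class named in the config; its module must already be imported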
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)

    # setup multi-gpu
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_PAIRS * batch_size

    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load training data
    train_data = TrainDataLoader(sym,
                                 pairdb,
                                 config,
                                 batch_size=input_batch_size,
                                 shuffle=config.TRAIN.SHUFFLE,
                                 ctx=ctx)

    train_data.get_batch_parallel()
    max_scale = [
        max([v[0] for v in config.SCALES]),
        max([v[1] for v in config.SCALES])
    ]
    max_data_shape = [('image_observed', (config.TRAIN.BATCH_PAIRS, 3,
                                          max_scale[0], max_scale[1])),
                      ('image_rendered', (config.TRAIN.BATCH_PAIRS, 3,
                                          max_scale[0], max_scale[1])),
                      ('depth_gt_observed', (config.TRAIN.BATCH_PAIRS, 1,
                                             max_scale[0], max_scale[1])),
                      ('src_pose', (config.TRAIN.BATCH_PAIRS, 3, 4)),
                      ('tgt_pose', (config.TRAIN.BATCH_PAIRS, 3, 4))]
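    # optional inputs: depth and mask channels, matching the network's configured input modalities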
    if config.network.INPUT_DEPTH:
        max_data_shape.append(('depth_observed', (config.TRAIN.BATCH_PAIRS, 1,
                                                  max_scale[0], max_scale[1])))
        max_data_shape.append(('depth_rendered', (config.TRAIN.BATCH_PAIRS, 1,
                                                  max_scale[0], max_scale[1])))
    if config.network.INPUT_MASK:
        max_data_shape.append(('mask_observed', (config.TRAIN.BATCH_PAIRS, 1,
                                                 max_scale[0], max_scale[1])))
        max_data_shape.append(('mask_rendered', (config.TRAIN.BATCH_PAIRS, 1,
                                                 max_scale[0], max_scale[1])))

    rot_param = 3 if config.network.ROT_TYPE == "EULER" else 4
    max_label_shape = [('rot', (config.TRAIN.BATCH_PAIRS, rot_param)),
                       ('trans', (config.TRAIN.BATCH_PAIRS, 3))]
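    # optional labels: dense flow (+weights), sampled point clouds for the point-matching loss, and the observed mask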
    if config.network.PRED_FLOW:
        max_label_shape.append(('flow', (config.TRAIN.BATCH_PAIRS, 2,
                                         max_scale[0], max_scale[1])))
        max_label_shape.append(('flow_weights', (config.TRAIN.BATCH_PAIRS, 2,
                                                 max_scale[0], max_scale[1])))
    if config.train_iter.SE3_PM_LOSS:
        max_label_shape.append(
            ('point_cloud_model', (config.TRAIN.BATCH_PAIRS, 3,
                                   config.train_iter.NUM_3D_SAMPLE)))
        max_label_shape.append(
            ('point_cloud_weights', (config.TRAIN.BATCH_PAIRS, 3,
                                     config.train_iter.NUM_3D_SAMPLE)))
        max_label_shape.append(
            ('point_cloud_observed', (config.TRAIN.BATCH_PAIRS, 3,
                                      config.train_iter.NUM_3D_SAMPLE)))
    if config.network.PRED_MASK:
        max_label_shape.append(
            ('mask_gt_observed', (config.TRAIN.BATCH_PAIRS, 1, max_scale[0],
                                  max_scale[1])))

    # max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape, max_label_shape)
    print_and_log(
        'providing maximum shape, {}, {}'.format(max_data_shape,
                                                 max_label_shape), logger)

    # infer max shape
    '''
    max_label_shape = [('label', (config.TRAIN.BATCH_IMAGES, 1,
                                  max([v[0] for v in max_scale]),
                                  max([v[1] for v in max_scale])))]
    max_data_shape, max_label_shape = train_data.infer_shape(
        max_data_shape, max_label_shape)
    print('providing maximum shape', max_data_shape, max_label_shape)
    '''
    # infer shape
    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    print_and_log('\ndata_shape_dict: {}\n'.format(data_shape_dict), logger)
    sym_instance.infer_shape(data_shape_dict)

    print('************(wg): inferring shape **************')
    internals = sym.get_internals()
    _, out_shapes, _ = internals.infer_shape(**data_shape_dict)
    print(sym.list_outputs())
    shape_dict = dict(zip(internals.list_outputs(), out_shapes))
    pprint.pprint(shape_dict)

    # load and initialize params
    if config.TRAIN.RESUME:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    elif pretrained == 'xavier':
        print('xavier')
        # weights are initialized later inside mod.fit via the Mixed initializer below
        # arg_params = {}
        # aux_params = {}
        # sym_instance.init_weights(config, arg_params, aux_params)
    else:
        print(pretrained)
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        print('arg_params: ', arg_params.keys())
        print('aux_params: ', aux_params.keys())
        if not config.network.skip_initialize:
            sym_instance.init_weights(config, arg_params, aux_params)

    # check parameter shapes
    if pretrained != 'xavier':
        sym_instance.check_parameter_shapes(arg_params, aux_params,
                                            data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

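    # MutableModule accepts batches of varying shape, bounded by the per-device maximum shapes above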
    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(batch_size)],
        max_label_shapes=[max_label_shape for _ in range(batch_size)],
        fixed_param_prefix=fixed_param_prefix,
        config=config)

    # decide training params
    # metrics
    eval_metrics = mx.metric.CompositeEvalMetric()

    metric_list = []
    iter_idx = 0
    if config.network.PRED_FLOW:
        metric_list.append(metric.Flow_L2LossMetric(config, iter_idx))
        metric_list.append(metric.Flow_CurLossMetric(config, iter_idx))
    if config.train_iter.SE3_DIST_LOSS:
        metric_list.append(metric.Rot_L2LossMetric(config, iter_idx))
        metric_list.append(metric.Trans_L2LossMetric(config, iter_idx))
    if config.train_iter.SE3_PM_LOSS:
        metric_list.append(metric.PointMatchingLossMetric(config, iter_idx))
    if config.network.PRED_MASK:
        metric_list.append(metric.MaskLossMetric(config, iter_idx))

    # Visualize Training Batches
    if config.TRAIN.VISUALIZE:
        metric_list.append(metric.SimpleVisualize(config))
        # metric_list.append(metric.MaskVisualize(config, save_dir = final_output_path))
        metric_list.append(
            metric.MinibatchVisualize(config))  # flow visualization

    for child_metric in metric_list:
        eval_metrics.add(child_metric)

    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    epoch_end_callback = mx.callback.module_checkpoint(
        mod, prefix, period=1, save_optimizer_states=True)
    # decide learning rate
    base_lr = lr
    lr_factor = 0.1
    lr_epoch = [float(epoch) for epoch in lr_step.split(',')]
    lr_epoch_diff = [
        epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch
    ]
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(epoch * len(pairdb) / batch_size) for epoch in lr_epoch_diff
    ]
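    # illustrative example: resuming at begin_epoch=5 with lr_step='4,6' leaves only the epoch-6
    # boundary, so training restarts at base_lr * 0.1 and decays again after
    # (6 - 5) * len(pairdb) / batch_size iterations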
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)

    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)

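    # wrap the loader so the next batch is prefetched in the background while the current one trains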
    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # train
    if config.TRAIN.optimizer == 'adam':
        optimizer_params = {'learning_rate': lr}
        if pretrained == 'xavier':
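            # zero the final rot/trans regression weights and Xavier-initialize everything else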
            init = mx.init.Mixed(['rot_weight|trans_weight', '.*'], [
                mx.init.Zero(),
                mx.init.Xavier(
                    rnd_type='gaussian', factor_type="in", magnitude=2)
            ])
            mod.fit(train_data,
                    eval_metric=eval_metrics,
                    epoch_end_callback=epoch_end_callback,
                    batch_end_callback=batch_end_callback,
                    kvstore=config.default.kvstore,
                    optimizer='adam',
                    optimizer_params=optimizer_params,
                    begin_epoch=begin_epoch,
                    num_epoch=end_epoch,
                    prefix=prefix,
                    initializer=init,
                    force_init=True)
        else:
            mod.fit(train_data,
                    eval_metric=eval_metrics,
                    epoch_end_callback=epoch_end_callback,
                    batch_end_callback=batch_end_callback,
                    kvstore=config.default.kvstore,
                    optimizer='adam',
                    optimizer_params=optimizer_params,
                    arg_params=arg_params,
                    aux_params=aux_params,
                    begin_epoch=begin_epoch,
                    num_epoch=end_epoch,
                    prefix=prefix)
    elif config.TRAIN.optimizer == 'sgd':
        # optimizer
        optimizer_params = {
            'momentum': config.TRAIN.momentum,
            'wd': config.TRAIN.wd,
            'learning_rate': lr,
            'lr_scheduler': lr_scheduler,
            'rescale_grad': 1.0,
            'clip_gradient': None
        }
        if pretrained == 'xavier':
            init = mx.init.Mixed(['rot_weight|trans_weight', '.*'], [
                mx.init.Zero(),
                mx.init.Xavier(
                    rnd_type='gaussian', factor_type="in", magnitude=2)
            ])
            mod.fit(train_data,
                    eval_metric=eval_metrics,
                    epoch_end_callback=epoch_end_callback,
                    batch_end_callback=batch_end_callback,
                    kvstore=config.default.kvstore,
                    optimizer='sgd',
                    optimizer_params=optimizer_params,
                    begin_epoch=begin_epoch,
                    num_epoch=end_epoch,
                    prefix=prefix,
                    initializer=init,
                    force_init=True)
        else:
            mod.fit(train_data,
                    eval_metric=eval_metrics,
                    epoch_end_callback=epoch_end_callback,
                    batch_end_callback=batch_end_callback,
                    kvstore=config.default.kvstore,
                    optimizer='sgd',
                    optimizer_params=optimizer_params,
                    arg_params=arg_params,
                    aux_params=aux_params,
                    begin_epoch=begin_epoch,
                    num_epoch=end_epoch,
                    prefix=prefix)