コード例 #1
0
    def test_custom_tensorboard(self):
        """Drives CustomTensorBoard through its train and predict hooks.

        Runs one train epoch/batch cycle and one predict batch built from a
        tiny hand-written example. It then checks which event-file
        subdirectories the callback wrote under the log directory.
        """
        import tempfile  # pylint: disable=g-import-not-at-top

        # A fresh directory per run (instead of a fixed /tmp path) keeps the
        # glob assertions below independent of earlier or concurrent runs.
        log_dir = tempfile.mkdtemp()
        self.addCleanup(tf.io.gfile.rmtree, log_dir)

        callback = callback_utils.CustomTensorBoard(
            log_dir=log_dir,
            metric_classes=None,
            batch_update_freq=1,
            num_qualitative_examples=10,
            split='val')
        model = tf.keras.Model()
        model.compile(optimizer=tf.keras.optimizers.SGD(
            learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
                0.01, 1, 0.1)))
        # Project-specific attribute, presumably read by the callback when it
        # logs training losses -- confirm against callback_utils.
        model.loss_names_to_losses = {'total_loss': 5.}
        inputs = {
            standard_fields.InputDataFields.point_positions:
            tf.constant([[[3., 2., 1.], [2., 3., 1.]]]),
            standard_fields.InputDataFields.num_valid_points:
            tf.constant([1]),
            standard_fields.InputDataFields.object_class_points:
            tf.constant([[[0], [1]]]),
            # Ground-truth objects.
            standard_fields.InputDataFields.objects_length:
            tf.constant([[[3.]]]),
            standard_fields.InputDataFields.objects_height:
            tf.constant([[[1.]]]),
            standard_fields.InputDataFields.objects_width:
            tf.constant([[[2.]]]),
            standard_fields.InputDataFields.objects_center:
            tf.constant([[[0., 0., 0.]]]),
            standard_fields.InputDataFields.objects_rotation_matrix:
            tf.eye(3, 3)[tf.newaxis, tf.newaxis, Ellipsis],
            standard_fields.InputDataFields.objects_class:
            tf.constant([[[1]]]),
            standard_fields.InputDataFields.camera_image_name:
            tf.convert_to_tensor([['image1', 'image2']], dtype=tf.string)
        }
        outputs = {
            standard_fields.DetectionResultFields.object_semantic_points:
            tf.constant([[[3., 2.], [2., 3.]]]),
            # Predicted objects.
            standard_fields.DetectionResultFields.objects_length:
            tf.constant([[3.]]),
            standard_fields.DetectionResultFields.objects_height:
            tf.constant([[1.]]),
            standard_fields.DetectionResultFields.objects_width:
            tf.constant([[2.]]),
            standard_fields.DetectionResultFields.objects_center:
            tf.constant([[0., 0., 0.]]),
            standard_fields.DetectionResultFields.objects_rotation_matrix:
            tf.expand_dims(tf.eye(3, 3), axis=0),
            standard_fields.DetectionResultFields.objects_class:
            tf.constant([[1]]),
        }

        # Training hooks: one epoch containing one batch.
        callback.set_model(model)
        callback.on_train_begin()
        callback.on_epoch_begin(epoch=1, logs=None)
        callback.on_train_batch_begin(batch=1, logs=None)
        callback.on_train_batch_end(batch=1, logs=None)
        callback.on_epoch_end(epoch=1, logs=None)
        callback.on_train_end()
        self.assertNotEmpty(
            tf.io.gfile.glob(os.path.join(log_dir, 'train/events*')))

        # Prediction hooks: a single batch carrying the inputs/outputs above.
        callback.on_predict_begin()
        callback.on_predict_batch_begin(batch=1, logs=None)
        callback.on_predict_batch_end(batch=1,
                                      logs={
                                          'outputs': outputs,
                                          'inputs': inputs
                                      })
        callback.on_predict_end()
        # The predict pass is expected to produce event files under
        # `eval_val_mesh` but none under `eval_val`.
        self.assertEmpty(
            tf.io.gfile.glob(os.path.join(log_dir, 'eval_val/events*')))
        self.assertNotEmpty(
            tf.io.gfile.glob(os.path.join(log_dir, 'eval_val_mesh/events*')))
コード例 #2
0
ファイル: eval.py プロジェクト: ziyouzizai111/google-research
def evaluation(model_class=None,
               input_fn=None,
               num_quantitative_examples=1000,
               num_qualitative_examples=50):
    """Continuously evaluates the newest checkpoint found in `FLAGS.ckpt_dir`.

    Builds the model and validation input once, then loops forever:
      1. Polls `FLAGS.ckpt_dir` for the latest checkpoint.
      2. Skips it if its epoch equals the last evaluated epoch.
      3. Otherwise restores it and runs the model over up to
         `num_quantitative_examples` validation batches.
      4. Forwards every batch to a `callback_utils.CustomTensorBoard`
         callback that writes to `FLAGS.eval_dir`.

    This function never returns.

    Args:
      model_class: Zero-argument callable that constructs the model to
        evaluate.
      input_fn: Callable invoked as `input_fn(is_training=False,
        batch_size=1)`. Its result must support `.take(n)`; presumably a
        `tf.data.Dataset` (confirm with callers).
      num_quantitative_examples: Maximum number of validation batches run
        for each checkpoint.
      num_qualitative_examples: Forwarded to the TensorBoard callback's
        `num_qualitative_examples` argument.
    """

    tensorboard_callback = callback_utils.CustomTensorBoard(
        log_dir=FLAGS.eval_dir,
        batch_update_freq=1,
        split=FLAGS.split,
        num_qualitative_examples=num_qualitative_examples,
        num_steps_per_epoch=FLAGS.num_steps_per_epoch)
    model = model_class()
    checkpoint = tf.train.Checkpoint(model=model,
                                     ckpt_saved_epoch=tf.Variable(
                                         initial_value=-1, dtype=tf.int64))
    val_inputs = input_fn(is_training=False, batch_size=1)
    # Epoch of the last fully evaluated checkpoint; -1 means none yet.
    num_evaluated_epoch = -1

    while True:
        ckpt_path = tf.train.latest_checkpoint(FLAGS.ckpt_dir)
        if not ckpt_path:
            logging.info(
                'No checkpoint found at %s, will check again 10 s later..',
                FLAGS.ckpt_dir)
            time.sleep(10)
            continue

        # The epoch is the number after the last '-' in the checkpoint's
        # file name.
        ckpt_num_of_epoch = int(ckpt_path.split('/')[-1].split('-')[-1])
        if num_evaluated_epoch == ckpt_num_of_epoch:
            logging.info(
                'Found old epoch %d ckpt, skip and will check later.',
                num_evaluated_epoch)
            time.sleep(30)
            continue

        try:
            logging.info('Restoring new checkpoint[epoch:%d] at %s',
                         ckpt_num_of_epoch, ckpt_path)
            checkpoint.restore(ckpt_path)
        except tf.errors.NotFoundError:
            # The checkpoint files may have disappeared after
            # `latest_checkpoint` listed them; back off and poll again.
            logging.info(
                'Restoring from checkpoint has failed. Maybe file missing. '
                'Try again now.')
            time.sleep(3)
            continue

        tensorboard_callback.set_epoch_number(ckpt_num_of_epoch)
        logging.info('Start qualitative eval for %d steps...',
                     num_quantitative_examples)
        try:
            # TODO(huangrui): there is still possibility of crash due to
            # not found ckpt files.
            model._predict_counter.assign(0)  # pylint: disable=protected-access
            tensorboard_callback.set_model(model)
            tensorboard_callback.on_predict_begin()
            for i, inputs in enumerate(
                    val_inputs.take(num_quantitative_examples), start=1):
                tensorboard_callback.on_predict_batch_begin(batch=i)
                outputs = model(inputs, training=False)
                model._predict_counter.assign_add(1)  # pylint: disable=protected-access
                tensorboard_callback.on_predict_batch_end(
                    batch=i, logs={'outputs': outputs, 'inputs': inputs})
                if i % FLAGS.num_steps_per_log == 0:
                    logging.info('eval progress %d / %d...', i,
                                 num_quantitative_examples)
            tensorboard_callback.on_predict_end()
        except tf.errors.NotFoundError:
            logging.info(
                'Restoring from checkpoint has failed. Maybe file missing. '
                'Try again now.')
            # Back off like the restore-failure path above instead of
            # re-polling immediately.
            time.sleep(3)
            continue

        num_evaluated_epoch = ckpt_num_of_epoch
        logging.info('Finished eval for epoch %d, sleeping for :%d s...',
                     num_evaluated_epoch, 100)
        time.sleep(100)