Beispiel #1
0
    def test_batch_norm_class(self):
        """Smoke-test the model and trainer set-up with a fake classifier.

        Parses a text-proto train config (gradient descent with a constant
        learning rate of 1.0, ``max_iterations`` of 5) and a text-proto
        model config, points the log and checkpoint paths below
        ``/tmp/mlod_unit_test/``, and runs ``trainer.train`` on a
        ``FakeBatchNormClassifier`` built from the model config.
        """
        train_proto_text = """
        optimizer {
          gradient_descent {
            learning_rate {
              constant_learning_rate {
                learning_rate: 1.0
              }
            }
          }
        }
        max_iterations: 5
        """
        model_proto_text = """
            path_drop_probabilities: [1.0, 1.0]
        """

        # Parse the training configuration and enable checkpoint overwriting
        trainer_cfg = train_pb2.TrainConfig()
        text_format.Merge(train_proto_text, trainer_cfg)
        trainer_cfg.overwrite_checkpoints = True

        # Parse the model configuration
        net_cfg = model_pb2.ModelConfig()
        text_format.Merge(model_proto_text, net_cfg)

        # Redirect all outputs of this test to a fixed scratch directory
        scratch_dir = '/tmp/mlod_unit_test/'
        output_paths = net_cfg.paths_config
        output_paths.logdir = '{}{}'.format(scratch_dir, 'logs/')
        output_paths.checkpoint_dir = scratch_dir

        trainer.train(FakeBatchNormClassifier(net_cfg), trainer_cfg)
Beispiel #2
0
def train(rpn_model_config, mlod_model_config, rpn_train_config,
          mlod_train_config, dataset_config):
    """Train the RPN model, then train the MLOD model stagewise.

    The RPN model is trained first with ``rpn_train_config``. The MLOD
    model is then trained with ``stagewise_training=True`` and the RPN
    checkpoint directory passed as ``init_checkpoint_dir``.

    The MLOD configs are merged on top of *copies* of the RPN configs, so
    MLOD-specific settings override the RPN ones while common settings are
    kept. The config objects passed in by the caller are not modified by
    this merge.

    Args:
        rpn_model_config: Model config for the RPN stage; its
            ``paths_config.checkpoint_dir`` is where RPN checkpoints are
            looked up.
        mlod_model_config: Model config merged over the RPN model config
            for the MLOD stage.
        rpn_train_config: Train config for the RPN stage.
        mlod_train_config: Train config merged over the RPN train config
            for the MLOD stage.
        dataset_config: Config passed to
            ``DatasetBuilder.build_kitti_dataset``.
    """
    train_val_test = 'train'
    dataset = DatasetBuilder.build_kitti_dataset(dataset_config,
                                                 use_defaults=False)

    rpn_checkpoint_dir = rpn_model_config.paths_config.checkpoint_dir

    # Stage 1: train the RPN in its own graph
    with tf.Graph().as_default():
        model = RpnModel(rpn_model_config,
                         train_val_test=train_val_test,
                         dataset=dataset)
        trainer.train(model, rpn_train_config)

        # load the weights back in
        saver = tf.train.Saver()
        init_op = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init_op)
            trainer_utils.load_checkpoints(rpn_checkpoint_dir, saver)
            # The last entry of saver.last_checkpoints is the one restored
            checkpoint_to_restore = saver.last_checkpoints[-1]
            trainer_utils.load_model_weights(model, sess,
                                             checkpoint_to_restore)

    # Merge RPN configs with MLOD - This will overwrite
    # the appropriate configs set for MLOD while keeping
    # the common configs the same. Copy *before* merging so that the
    # caller's RPN configs are left untouched.
    mlod_model_merged = deepcopy(rpn_model_config)
    mlod_model_merged.MergeFrom(mlod_model_config)
    mlod_train_merged = deepcopy(rpn_train_config)
    mlod_train_merged.MergeFrom(mlod_train_config)

    # Stage 2: train MLOD in a fresh graph, initialised from the RPN
    # checkpoint directory
    with tf.Graph().as_default():
        model = MlodModel(mlod_model_merged,
                          train_val_test=train_val_test,
                          dataset=dataset)
        trainer.train(model,
                      mlod_train_merged,
                      stagewise_training=True,
                      init_checkpoint_dir=rpn_checkpoint_dir)
Beispiel #3
0
def train(model_config, train_config, dataset_config):
    """Build the model named in ``model_config`` and train it.

    Args:
        model_config: Model config; ``model_config.model_name`` must be
            ``'rpn_model'`` or ``'mlod_model'``.
        train_config: Train config passed through to ``trainer.train``.
        dataset_config: Config passed to
            ``DatasetBuilder.build_kitti_dataset``.

    Raises:
        ValueError: If ``model_config.model_name`` is not one of the
            supported model names. This is checked before the dataset is
            built so that a bad config fails immediately.
    """
    model_name = model_config.model_name
    # Fail fast: reject unknown models before building the dataset
    if model_name not in ('rpn_model', 'mlod_model'):
        raise ValueError('Invalid model_name: {!r}'.format(model_name))

    dataset = DatasetBuilder.build_kitti_dataset(dataset_config,
                                                 use_defaults=False)

    train_val_test = 'train'

    with tf.Graph().as_default():
        if model_name == 'rpn_model':
            model = RpnModel(model_config,
                             train_val_test=train_val_test,
                             dataset=dataset)
        else:
            # Only 'mlod_model' can reach here after the check above
            model = MlodModel(model_config,
                              train_val_test=train_val_test,
                              dataset=dataset)

        trainer.train(model, train_config)
Beispiel #4
0
def train_and_eval(model_config, train_config, eval_config, dataset_config):
    """Alternate training and evaluation until ``max_iterations`` is reached.

    Each round trains up to ``current_train_iter`` iterations (by
    overwriting ``train_config.max_iterations``), advances
    ``current_train_iter`` by the evaluation interval, and then runs
    ``evaluator.run_latest_checkpoints`` on a model built in
    ``eval_config.eval_mode``. Training and evaluation each use their own
    graph and their own dataset built from a copy of ``dataset_config``.

    Note that ``train_config.max_iterations`` is modified in place and is
    not restored afterwards.

    Args:
        model_config: Model config; ``model_name`` selects ``'mlod_model'``
            or ``'rpn_model'``.
        train_config: Train config providing ``max_iterations`` and
            ``checkpoint_interval``.
        eval_config: Eval config providing ``eval_mode`` and
            ``eval_interval``.
        dataset_config: Dataset config, copied once for training and once
            for evaluation.

    Raises:
        ValueError: If ``eval_mode`` is ``'train'``, if the checkpoint
            interval is zero, if the evaluation interval is smaller than or
            not a multiple of the checkpoint interval, or (once the loop
            runs) if ``model_name`` is not a supported model.
    """
    model_name = model_config.model_name
    train_val_test = 'train'
    eval_mode = eval_config.eval_mode
    if eval_mode == 'train':
        raise ValueError('Evaluation mode can only be set to `val` or `test`.')

    # keep a copy as this will be overwritten inside
    # the training loop below
    max_train_iter = train_config.max_iterations
    checkpoint_interval = train_config.checkpoint_interval
    eval_interval = eval_config.eval_interval

    # Guard the modulo below against a division by zero
    if checkpoint_interval == 0:
        raise ValueError(
            'Checkpoint interval must be non-zero (given {}).'.format(
                checkpoint_interval))

    # Every evaluation must coincide with a saved checkpoint
    if eval_interval < checkpoint_interval or \
            (eval_interval % checkpoint_interval) != 0:
        raise ValueError(
            'Evaluation interval (given {}) must be greater than or equal '
            'to and divisible by the checkpoint interval (given {}).'.format(
                eval_interval, checkpoint_interval))

    # Dataset Configuration - built only after the configs are validated
    dataset_config_train = DatasetBuilder.copy_config(dataset_config)
    dataset_config_eval = DatasetBuilder.copy_config(dataset_config)

    dataset_train = DatasetBuilder.build_kitti_dataset(dataset_config_train,
                                                       use_defaults=False)
    dataset_eval = DatasetBuilder.build_kitti_dataset(dataset_config_eval,
                                                      use_defaults=False)

    def build_model(mode, dataset):
        # Instantiate the configured model for the given mode and dataset
        if model_name == 'mlod_model':
            return MlodModel(model_config,
                             train_val_test=mode,
                             dataset=dataset)
        if model_name == 'rpn_model':
            return RpnModel(model_config,
                            train_val_test=mode,
                            dataset=dataset)
        raise ValueError('Invalid model name {}'.format(model_name))

    # Use the evaluation losses file to continue from the latest
    # checkpoint
    already_evaluated_ckpts = evaluator.get_evaluated_ckpts(
        model_config, model_name)

    if len(already_evaluated_ckpts) != 0:
        # NOTE(review): presumably the last entry is the most recently
        # evaluated iteration - confirm against evaluator
        current_train_iter = already_evaluated_ckpts[-1]
    else:
        current_train_iter = eval_interval

    # while training is not finished
    while current_train_iter <= max_train_iter:
        # Train
        with tf.Graph().as_default():
            model = build_model(train_val_test, dataset_train)

            # overwrite the training epochs
            train_config.max_iterations = current_train_iter
            print('\n*************** Training ****************\n')
            trainer.train(model, train_config)
        current_train_iter += eval_interval

        # Evaluate
        with tf.Graph().as_default():
            model = build_model(eval_mode, dataset_eval)

            print('\n*************** Evaluating *****************\n')
            evaluator.run_latest_checkpoints(model, dataset_config_eval)
    print('\n************ Finished training and evaluating *************\n')
Beispiel #5
0
    def test_load_model_weights(self):
        """Check stagewise training, i.e. loading RPN weights onto MLOD.

        Trains an RPN model for one iteration, restores its latest
        checkpoint and records the model variables. An MLOD model is then
        built in a fresh graph, the same checkpoint is loaded into it, and
        the leading MLOD model variables are compared against the recorded
        RPN weights.
        """
        mode = 'train'
        # Limit training to a single iteration and allow overwriting
        self.train_config.overwrite_checkpoints = True
        self.train_config.max_iterations = 1

        # --- Stage 1: train the RPN and read its weights back ---
        with tf.Graph().as_default():
            rpn_model = RpnModel(self.model_config,
                                 train_val_test=mode,
                                 dataset=self.dataset)
            trainer.train(rpn_model, self.train_config)

            ckpt_dir = self.model_config.paths_config.checkpoint_dir

            initializer = tf.global_variables_initializer()
            rpn_saver = tf.train.Saver()
            with tf.Session() as rpn_sess:
                rpn_sess.run(initializer)

                # Restore the last checkpoint listed by the saver
                trainer_utils.load_checkpoints(ckpt_dir, rpn_saver)
                trainer_utils.load_model_weights(
                    rpn_model, rpn_sess, rpn_saver.last_checkpoints[-1])

                rpn_weights = rpn_sess.run(slim.get_model_variables())
                self.assertGreater(len(rpn_weights),
                                   0,
                                   msg='Loaded RPN weights are empty')

        # --- Stage 2: build MLOD and load the RPN checkpoint into it ---
        with tf.Graph().as_default():
            mlod_model = MlodModel(self.model_config,
                                   train_val_test=mode,
                                   dataset=self.dataset)
            mlod_model.build()

            initializer = tf.global_variables_initializer()
            mlod_saver = tf.train.Saver()
            with tf.Session() as mlod_sess:
                mlod_sess.run(initializer)

                trainer_utils.load_checkpoints(ckpt_dir, mlod_saver)
                trainer_utils.load_model_weights(
                    mlod_model, mlod_sess, mlod_saver.last_checkpoints[-1])

                mlod_vars = slim.get_model_variables()
                mlod_weights = mlod_sess.run(mlod_vars)
                # MLOD weights should include both RPN + MLOD weights
                self.assertGreater(len(mlod_weights),
                                   len(rpn_weights),
                                   msg='Expected more weights for MLOD')

                # The test relies on the model variables being ordered, so
                # the first len(rpn_weights) MLOD variables are the RPN ones
                num_rpn = len(rpn_weights)
                reloaded_weights = mlod_sess.run(mlod_vars[:num_rpn])

                # The reloaded weights must match the originally loaded ones
                for reloaded, expected in zip(reloaded_weights, rpn_weights):
                    np.testing.assert_array_equal(reloaded, expected)