예제 #1
0
    def test_batch_norm_class(self):
        """Exercise the model and trainer set-up with a fake classifier.

        A TrainConfig (gradient descent, constant learning rate 1.0,
        max_iterations 5) and a ModelConfig (path_drop_probabilities
        [1.0, 1.0]) are parsed from text protos. Log and checkpoint
        directories are pointed under /tmp/pplp_unit_test/ and
        trainer.train is run on a FakeBatchNormClassifier.
        """
        # Text-format protos; contents are parsed by text_format.Merge below.
        train_proto_text = """
        optimizer {
          gradient_descent {
            learning_rate {
              constant_learning_rate {
                learning_rate: 1.0
              }
            }
          }
        }
        max_iterations: 5
        """
        model_proto_text = """
            path_drop_probabilities: [1.0, 1.0]
        """

        # Training configuration; existing checkpoints may be overwritten.
        train_cfg = train_pb2.TrainConfig()
        text_format.Merge(train_proto_text, train_cfg)
        train_cfg.overwrite_checkpoints = True

        # Model configuration, with its output paths redirected to the
        # unit-test directory.
        model_cfg = model_pb2.ModelConfig()
        text_format.Merge(model_proto_text, model_cfg)

        output_root = '/tmp/pplp_unit_test/'
        output_paths = model_cfg.paths_config
        output_paths.checkpoint_dir = output_root
        output_paths.logdir = output_root + 'logs/'

        trainer.train(FakeBatchNormClassifier(model_cfg), train_cfg)
예제 #2
0
def train(model_config, train_config, dataset_config):
    """Build the dataset and the configured model, then run the trainer.

    Args:
        model_config: model configuration; its ``model_name`` selects
            which model class is built ('rpn_model' or 'avod_3d_model').
        train_config: training configuration forwarded to trainer.train.
        dataset_config: dataset configuration passed to
            DatasetBuilder.build_panoptic_dataset.

    Raises:
        ValueError: if ``model_config.model_name`` is not one of the
            supported names.
    """
    panoptic_dataset = DatasetBuilder.build_panoptic_dataset(
        dataset_config, use_defaults=False)

    requested_model = model_config.model_name
    split = 'train'

    # The model is constructed and trained inside a fresh default graph.
    with tf.Graph().as_default():
        # Pick the model class first, then instantiate it once.
        if requested_model == 'rpn_model':
            model_cls = RpnModel
        elif requested_model == 'avod_3d_model':
            model_cls = AvodModel
        else:
            raise ValueError('Invalid model_name')

        network = model_cls(model_config,
                            train_val_test=split,
                            dataset=panoptic_dataset)
        trainer.train(network, train_config)
예제 #3
0
    def test_load_model_weights(self):
        """Check that weights checkpointed by an RPN run reload into AVOD.

        An RpnModel is trained for one iteration and its checkpoint is
        loaded back to capture the RPN model variables. An AvodModel is
        then built in a separate graph and the same checkpoint is loaded
        into it. The test asserts that AVOD has more model variables than
        RPN and that its leading variables equal the captured RPN values.
        """
        split = 'train'

        def restore_latest(session, init, ckpt_saver, ckpt_dir):
            # Run the initializer, let trainer_utils populate the saver's
            # checkpoint list, then load the last listed checkpoint.
            session.run(init)
            trainer_utils.load_checkpoints(ckpt_dir, ckpt_saver)
            latest_checkpoint = ckpt_saver.last_checkpoints[-1]
            trainer_utils.load_model_weights(session, latest_checkpoint)

        # Keep the training run as short as possible and allow it to
        # overwrite any existing checkpoints.
        self.train_config.max_iterations = 1
        self.train_config.overwrite_checkpoints = True

        # --- Stage 1: train RPN and read its weights back ---
        with tf.Graph().as_default():
            rpn_model = RpnModel(self.model_config,
                                 train_val_test=split,
                                 dataset=self.dataset)
            trainer.train(rpn_model, self.train_config)

            rpn_checkpoint_dir = self.model_config.paths_config.checkpoint_dir

            initializer = tf.global_variables_initializer()
            rpn_saver = tf.train.Saver()
            with tf.Session() as sess:
                restore_latest(sess, initializer, rpn_saver,
                               rpn_checkpoint_dir)

                rpn_weights = sess.run(slim.get_model_variables())
                self.assertGreater(len(rpn_weights),
                                   0,
                                   msg='Loaded RPN weights are empty')

        # --- Stage 2: build AVOD and load the same checkpoint ---
        with tf.Graph().as_default():
            avod_model = AvodModel(self.model_config,
                                   train_val_test=split,
                                   dataset=self.dataset)
            avod_model.build()

            initializer = tf.global_variables_initializer()
            avod_saver = tf.train.Saver()
            with tf.Session() as sess:
                restore_latest(sess, initializer, avod_saver,
                               rpn_checkpoint_dir)

                avod_vars = slim.get_model_variables()
                avod_weights = sess.run(avod_vars)

                # AVOD is expected to hold the RPN weights plus its own.
                num_rpn_weights = len(rpn_weights)
                self.assertGreater(len(avod_weights),
                                   num_rpn_weights,
                                   msg='Expected more weights for AVOD')

                # The RPN variables are taken by position; this relies on
                # the model variables being ordered with RPN first.
                rpn_weights_reload = sess.run(avod_vars[0:num_rpn_weights])

                # Every reloaded array must equal the one captured from
                # the RPN-only graph.
                for reloaded, expected in zip(rpn_weights_reload,
                                              rpn_weights):
                    np.testing.assert_array_equal(reloaded, expected)