Exemple #1
0
    def test_batch_norm_class(self):
        """Run ``trainer.train`` end to end on a FakeBatchNormClassifier.

        The training config is five iterations of gradient descent with a
        constant learning rate of 1.0. The model config sets both path drop
        probabilities to 1.0. Logs and checkpoints go under a fixed
        directory in /tmp. The test has no explicit assertion; it passes
        when building the classifier and training it complete without
        raising.
        """
        test_root_dir = '/tmp/avod_unit_test/'

        train_text = """
        optimizer {
          gradient_descent {
            learning_rate {
              constant_learning_rate {
                learning_rate: 1.0
              }
            }
          }
        }
        max_iterations: 5
        """
        model_text = """
            path_drop_probabilities: [1.0, 1.0]
        """

        # Parse both text protos into freshly created messages.
        train_config = train_pb2.TrainConfig()
        text_format.Merge(train_text, train_config)
        model_config = model_pb2.ModelConfig()
        text_format.Merge(model_text, model_config)

        # Not part of the text proto, so it is set explicitly; presumably
        # lets repeated runs reuse the fixed directory (confirm in trainer).
        train_config.overwrite_checkpoints = True

        paths = model_config.paths_config
        paths.logdir = test_root_dir + 'logs/'
        paths.checkpoint_dir = test_root_dir

        trainer.train(FakeBatchNormClassifier(model_config), train_config)
Exemple #2
0
def train(model_config, train_config, dataset_config):
    """Build the configured KITTI model in a fresh graph and train it.

    Args:
        model_config: model configuration. Its ``model_name`` selects one of
            'rpn_model', 'bev_only_rpn_model', 'avod_model' or
            'bev_only_avod_model'.
        train_config: configuration forwarded unchanged to ``trainer.train``.
        dataset_config: configuration passed to
            ``DatasetBuilder.build_kitti_dataset`` with ``use_defaults=False``.

    Raises:
        ValueError: if ``model_name`` is not one of the names listed above.
    """
    dataset = DatasetBuilder.build_kitti_dataset(dataset_config,
                                                 use_defaults=False)
    requested = model_config.model_name

    with tf.Graph().as_default():
        # Every supported model class shares the same constructor signature.
        model_classes = {
            'rpn_model': RpnModel,
            'bev_only_rpn_model': BevOnlyRpnModel,
            'avod_model': AvodModel,
            'bev_only_avod_model': BevOnlyAvodModel,
        }
        model_cls = model_classes.get(requested)
        if model_cls is None:
            raise ValueError('Invalid model_name')

        model = model_cls(model_config,
                          train_val_test='train',
                          dataset=dataset)
        trainer.train(model, train_config)
Exemple #3
0
def train(model_config, train_config, dataset_config):
    """Train an RPN or AVOD model on the KITTI dataset.

    The three arguments hold the parsed contents of the config file.

    Args:
        model_config: model parameters. ``model_name`` must be 'rpn_model'
            or 'avod_model'.
        train_config: training parameters, passed through to
            ``trainer.train``.
        dataset_config: dataset parameters used to build the KITTI dataset
            (with ``use_defaults=False``).

    Raises:
        ValueError: if ``model_name`` is neither 'rpn_model' nor 'avod_model'.
    """
    dataset = DatasetBuilder.build_kitti_dataset(dataset_config,
                                                 use_defaults=False)
    # Either 'avod_model' or 'rpn_model'.
    model_name = model_config.model_name
    stage = 'train'

    with tf.Graph().as_default():
        # Pick the class first; both share one constructor signature.
        if model_name == 'rpn_model':
            model_cls = RpnModel
        elif model_name == 'avod_model':
            model_cls = AvodModel
        else:
            raise ValueError('Invalid model_name')

        model = model_cls(model_config, train_val_test=stage, dataset=dataset)
        trainer.train(model, train_config)
Exemple #4
0
def train(model_config, train_config, dataset_config):
    """Construct an RPN, AVOD or RetinaNet model and run the trainer on it.

    Args:
        model_config: model configuration. ``model_name`` must be
            'rpn_model', 'avod_model' or 'retinanet_model'.
        train_config: configuration forwarded to ``trainer.train``.
        dataset_config: configuration for
            ``DatasetBuilder.build_kitti_dataset`` (``use_defaults=False``).

    Raises:
        ValueError: if ``model_name`` is not a supported model.
    """
    kitti_dataset = DatasetBuilder.build_kitti_dataset(dataset_config,
                                                       use_defaults=False)
    name = model_config.model_name

    with tf.Graph().as_default():
        registry = {
            'rpn_model': RpnModel,
            'avod_model': AvodModel,
            'retinanet_model': RetinanetModel,
        }
        if name not in registry:
            raise ValueError('Invalid model_name')

        model = registry[name](model_config,
                               train_val_test='train',
                               dataset=kitti_dataset)
        trainer.train(model, train_config)
Exemple #5
0
def train(model_config, train_config, dataset_config):
    """Build the configured model in a fresh graph and train it exactly once.

    Each model name maps to one model class and one trainer module:

    - 'rpn_model', 'avod_model': ``trainer``
    - 'avod_moe_model': ``trainer_moe``
    - 'epbrm': ``epbrm_trainer``
    - 'avod_model_new_bev', 'avod_model_double_fusion_new_bev':
      ``trainer_new_bev``

    Args:
        model_config: model configuration. ``model_name`` selects the model
            class and trainer listed above.
        train_config: configuration forwarded to the selected trainer.
        dataset_config: configuration passed to
            ``DatasetBuilder.build_kitti_dataset`` with ``use_defaults=False``.

    Raises:
        ValueError: if ``model_name`` is not one of the names above.
    """
    dataset = DatasetBuilder.build_kitti_dataset(dataset_config,
                                                 use_defaults=False)

    train_val_test = 'train'
    model_name = model_config.model_name

    with tf.Graph().as_default():
        if model_name == 'rpn_model':
            model = RpnModel(model_config,
                             train_val_test=train_val_test,
                             dataset=dataset)
            model_trainer = trainer
        elif model_name == 'avod_model':
            model = AvodModel(model_config,
                              train_val_test=train_val_test,
                              dataset=dataset)
            model_trainer = trainer
        elif model_name == 'avod_moe_model':
            model = AvodMoeModel(model_config,
                                 train_val_test=train_val_test,
                                 dataset=dataset)
            model_trainer = trainer_moe
        elif model_name == 'epbrm':
            # epBRM is constructed without a train_val_test argument.
            model = epBRM(model_config, dataset=dataset)
            model_trainer = epbrm_trainer
        elif model_name == 'avod_model_new_bev':
            model = AvodModelBEV(model_config,
                                 train_val_test=train_val_test,
                                 dataset=dataset)
            model_trainer = trainer_new_bev
        elif model_name == 'avod_model_double_fusion_new_bev':
            model = AvodModelDoubleFusionBEV(model_config,
                                             train_val_test=train_val_test,
                                             dataset=dataset)
            model_trainer = trainer_new_bev
        else:
            raise ValueError('Invalid model_name')

        # A single dispatch point: every model is trained once, by the
        # trainer that matches it, never additionally by the generic one.
        model_trainer.train(model, train_config)
Exemple #6
0
    def test_load_model_weights(self):
        """Check that RPN checkpoint weights reload correctly into AVOD.

        First, an RPN is trained for one iteration so a checkpoint exists.
        The latest checkpoint is then restored in the same graph and its
        model variables are read. Next, a fresh AVOD graph is built and the
        same checkpoint is restored into it. AVOD must expose more model
        variables than the RPN, and its leading variables must equal the RPN
        weights exactly.
        """
        mode = 'train'

        # Keep training to a single iteration and set overwrite_checkpoints
        # on the shared train config.
        self.train_config.max_iterations = 1
        self.train_config.overwrite_checkpoints = True

        # Phase 1: train the RPN, then restore and read back its weights.
        with tf.Graph().as_default():
            rpn_model = RpnModel(self.model_config,
                                 train_val_test=mode,
                                 dataset=self.dataset)
            trainer.train(rpn_model, self.train_config)

            checkpoint_dir = self.model_config.paths_config.checkpoint_dir

            initializer = tf.global_variables_initializer()
            saver = tf.train.Saver()
            with tf.Session() as sess:
                sess.run(initializer)

                trainer_utils.load_checkpoints(checkpoint_dir, saver)
                latest = saver.last_checkpoints[-1]
                trainer_utils.load_model_weights(sess, latest)

                rpn_weights = sess.run(slim.get_model_variables())
                self.assertGreater(len(rpn_weights), 0,
                                   msg='Loaded RPN weights are empty')

        # Phase 2: build AVOD in a new graph and restore the RPN checkpoint.
        with tf.Graph().as_default():
            avod_model = AvodModel(self.model_config,
                                   train_val_test=mode,
                                   dataset=self.dataset)
            avod_model.build()

            initializer = tf.global_variables_initializer()
            saver = tf.train.Saver()
            with tf.Session() as sess:
                sess.run(initializer)

                trainer_utils.load_checkpoints(checkpoint_dir, saver)
                latest = saver.last_checkpoints[-1]
                trainer_utils.load_model_weights(sess, latest)

                avod_vars = slim.get_model_variables()
                avod_weights = sess.run(avod_vars)

                # AVOD carries the RPN weights plus its own, so it must
                # have strictly more model variables.
                self.assertGreater(len(avod_weights),
                                   len(rpn_weights),
                                   msg='Expected more weights for AVOD')

                # Model variables are ordered, so the first len(rpn_weights)
                # AVOD variables are the RPN ones.
                reloaded = sess.run(avod_vars[:len(rpn_weights)])

                # Every reloaded tensor must match the weights from phase 1.
                for got, expected in zip(reloaded, rpn_weights):
                    np.testing.assert_array_equal(got, expected)