def test_batch_norm_class(self):
    """Smoke-test model and trainer set-up with FakeBatchNormClassifier.

    Builds a TrainConfig (gradient descent, constant learning rate 1.0,
    5 iterations, overwrite_checkpoints enabled) and a ModelConfig from
    text-format protos. It points the log and checkpoint directories under
    /tmp/avod_unit_test/ and runs ``trainer.train``.

    The test makes no assertions: it passes as long as training
    completes without raising.
    """
    # Text-format protos, spelled as single-line strings.
    train_proto_text = (
        " optimizer { gradient_descent { learning_rate {"
        " constant_learning_rate { learning_rate: 1.0 } } } }"
        " max_iterations: 5 "
    )
    model_proto_text = " path_drop_probabilities: [1.0, 1.0] "

    train_config = train_pb2.TrainConfig()
    text_format.Merge(train_proto_text, train_config)
    train_config.overwrite_checkpoints = True

    model_config = model_pb2.ModelConfig()
    text_format.Merge(model_proto_text, model_config)

    # Redirect all outputs of this test into a scratch directory.
    scratch_root = '/tmp/avod_unit_test/'
    paths = model_config.paths_config
    paths.logdir = '{}logs/'.format(scratch_root)
    paths.checkpoint_dir = scratch_root

    trainer.train(FakeBatchNormClassifier(model_config), train_config)
def train(model_config, train_config, dataset_config):
    """Build the KITTI dataset and the configured model, then train it.

    Args:
        model_config: Model config. Its ``model_name`` selects the model
            class: 'rpn_model', 'bev_only_rpn_model', 'avod_model' or
            'bev_only_avod_model'.
        train_config: Training config, forwarded to ``trainer.train``.
        dataset_config: Dataset config, passed to
            ``DatasetBuilder.build_kitti_dataset`` with
            ``use_defaults=False``.

    Raises:
        ValueError: If ``model_config.model_name`` is not one of the
            names listed above.
    """
    dataset = DatasetBuilder.build_kitti_dataset(dataset_config,
                                                 use_defaults=False)
    model_name = model_config.model_name

    # Model construction and training happen in one fresh default graph.
    with tf.Graph().as_default():
        if model_name == 'rpn_model':
            model_class = RpnModel
        elif model_name == 'bev_only_rpn_model':
            model_class = BevOnlyRpnModel
        elif model_name == 'avod_model':
            model_class = AvodModel
        elif model_name == 'bev_only_avod_model':
            model_class = BevOnlyAvodModel
        else:
            raise ValueError('Invalid model_name')

        # Every supported model class shares the same constructor signature.
        model = model_class(model_config,
                            train_val_test='train',
                            dataset=dataset)
        trainer.train(model, train_config)
def train(model_config, train_config, dataset_config):
    """Build the KITTI dataset and the configured model, then train it.

    Args:
        model_config: Model parameters. ``model_name`` must be either
            'rpn_model' or 'avod_model'.
        train_config: Training parameters, forwarded to ``trainer.train``.
        dataset_config: Dataset parameters, passed to
            ``DatasetBuilder.build_kitti_dataset`` with
            ``use_defaults=False``.

    Raises:
        ValueError: If ``model_config.model_name`` is not a supported name.
    """
    dataset = DatasetBuilder.build_kitti_dataset(dataset_config,
                                                 use_defaults=False)
    # Supported values: 'avod_model' and 'rpn_model'.
    model_name = model_config.model_name

    with tf.Graph().as_default():
        if model_name not in ('rpn_model', 'avod_model'):
            raise ValueError('Invalid model_name')

        model_class = RpnModel if model_name == 'rpn_model' else AvodModel
        model = model_class(model_config,
                            train_val_test='train',
                            dataset=dataset)
        trainer.train(model, train_config)
def train(model_config, train_config, dataset_config):
    """Build the KITTI dataset and the configured model, then train it.

    Args:
        model_config: Model config. Its ``model_name`` selects one of
            'rpn_model', 'avod_model' or 'retinanet_model'.
        train_config: Training config, forwarded to ``trainer.train``.
        dataset_config: Dataset config, passed to
            ``DatasetBuilder.build_kitti_dataset`` with
            ``use_defaults=False``.

    Raises:
        ValueError: If ``model_config.model_name`` is not supported.
    """
    dataset = DatasetBuilder.build_kitti_dataset(dataset_config,
                                                 use_defaults=False)
    model_name = model_config.model_name

    with tf.Graph().as_default():
        if model_name == 'rpn_model':
            model_class = RpnModel
        elif model_name == 'avod_model':
            model_class = AvodModel
        elif model_name == 'retinanet_model':
            model_class = RetinanetModel
        else:
            raise ValueError('Invalid model_name')

        model = model_class(model_config,
                            train_val_test='train',
                            dataset=dataset)
        trainer.train(model, train_config)
def train(model_config, train_config, dataset_config):
    """Build the KITTI dataset and the configured model, then train it once.

    The model class and the trainer module are both chosen from
    ``model_config.model_name``:

    * 'rpn_model'                        -> RpnModel, ``trainer``
    * 'avod_model'                       -> AvodModel, ``trainer``
    * 'avod_moe_model'                   -> AvodMoeModel, ``trainer_moe``
    * 'epbrm'                            -> epBRM, ``epbrm_trainer``
    * 'avod_model_new_bev'               -> AvodModelBEV, ``trainer_new_bev``
    * 'avod_model_double_fusion_new_bev' -> AvodModelDoubleFusionBEV,
                                            ``trainer_new_bev``

    Args:
        model_config: Model config selecting the model via ``model_name``.
        train_config: Training config, forwarded to the selected trainer.
        dataset_config: Dataset config, passed to
            ``DatasetBuilder.build_kitti_dataset`` with
            ``use_defaults=False``.

    Raises:
        ValueError: If ``model_config.model_name`` is not one of the names
            above.
    """
    dataset = DatasetBuilder.build_kitti_dataset(dataset_config,
                                                 use_defaults=False)

    train_val_test = 'train'
    model_name = model_config.model_name

    with tf.Graph().as_default():
        # Each branch only selects the model and its trainer. Training
        # happens exactly once, below. Previously several branches trained
        # inline and then fell through to a second, generic trainer.train
        # call. That retrained RPN/AVOD, and re-ran MoE/epBRM with the
        # wrong trainer.
        if model_name == 'rpn_model':
            model = RpnModel(model_config,
                             train_val_test=train_val_test,
                             dataset=dataset)
            model_trainer = trainer
        elif model_name == 'avod_model':
            model = AvodModel(model_config,
                              train_val_test=train_val_test,
                              dataset=dataset)
            model_trainer = trainer
        elif model_name == 'avod_moe_model':
            model = AvodMoeModel(model_config,
                                 train_val_test=train_val_test,
                                 dataset=dataset)
            model_trainer = trainer_moe
        elif model_name == 'epbrm':
            # epBRM is constructed without a train_val_test argument.
            model = epBRM(model_config, dataset=dataset)
            model_trainer = epbrm_trainer
        elif model_name == 'avod_model_new_bev':
            model = AvodModelBEV(model_config,
                                 train_val_test=train_val_test,
                                 dataset=dataset)
            model_trainer = trainer_new_bev
        elif model_name == 'avod_model_double_fusion_new_bev':
            model = AvodModelDoubleFusionBEV(model_config,
                                             train_val_test=train_val_test,
                                             dataset=dataset)
            model_trainer = trainer_new_bev
        else:
            raise ValueError('Invalid model_name')

        model_trainer.train(model, train_config)
def test_load_model_weights(self):
    """Check that trained RPN weights can be restored, including into AVOD.

    Stage 1 trains an RpnModel for a single iteration, with
    overwrite_checkpoints enabled. It then restores the latest RPN
    checkpoint into the same graph and records the model-variable values.

    Stage 2 builds an AvodModel in a new graph and restores that same RPN
    checkpoint. It then checks that AVOD has more model variables than RPN,
    and that its leading variables hold exactly the stage-1 RPN values.
    """
    mode = 'train'

    # Keep the training run as short as possible.
    self.train_config.max_iterations = 1
    self.train_config.overwrite_checkpoints = True

    # Stage 1: train RPN, then read its weights back from the checkpoint.
    with tf.Graph().as_default():
        rpn_model = RpnModel(self.model_config,
                             train_val_test=mode,
                             dataset=self.dataset)
        trainer.train(rpn_model, self.train_config)

        checkpoint_dir = self.model_config.paths_config.checkpoint_dir

        rpn_init = tf.global_variables_initializer()
        rpn_saver = tf.train.Saver()
        with tf.Session() as session:
            session.run(rpn_init)

            trainer_utils.load_checkpoints(checkpoint_dir, rpn_saver)
            latest_checkpoint = rpn_saver.last_checkpoints[-1]
            trainer_utils.load_model_weights(session, latest_checkpoint)

            rpn_weights = session.run(slim.get_model_variables())
            self.assertGreater(len(rpn_weights), 0,
                               msg='Loaded RPN weights are empty')

    # Stage 2: build AVOD and restore the RPN checkpoint into it.
    with tf.Graph().as_default():
        avod_model = AvodModel(self.model_config,
                               train_val_test=mode,
                               dataset=self.dataset)
        avod_model.build()

        avod_init = tf.global_variables_initializer()
        avod_saver = tf.train.Saver()
        with tf.Session() as session:
            session.run(avod_init)

            trainer_utils.load_checkpoints(checkpoint_dir, avod_saver)
            latest_checkpoint = avod_saver.last_checkpoints[-1]
            trainer_utils.load_model_weights(session, latest_checkpoint)

            avod_vars = slim.get_model_variables()
            avod_weights = session.run(avod_vars)

            # AVOD's model variables should cover RPN's plus its own.
            self.assertGreater(len(avod_weights), len(rpn_weights),
                               msg='Expected more weights for AVOD')

            # Model variables are ordered, so the first len(rpn_weights)
            # AVOD variables are the RPN ones. Their values must match
            # what was loaded in stage 1.
            reloaded = session.run(avod_vars[:len(rpn_weights)])
            for reloaded_value, expected_value in zip(reloaded, rpn_weights):
                np.testing.assert_array_equal(reloaded_value, expected_value)