def test_batch_norm_class(self):
    """Smoke-tests model/trainer wiring with a batch-norm classifier.

    Builds minimal train and model configs from inline text protos,
    points all output paths at a throwaway /tmp directory, and runs a
    short training loop. Passes if training completes without error.
    """
    # Plain SGD with a constant learning rate, capped at 5 iterations.
    train_proto_text = """
        optimizer {
            gradient_descent {
                learning_rate {
                    constant_learning_rate {
                        learning_rate: 1.0
                    }
                }
            }
        }
        max_iterations: 5
    """
    # Disable path drop (keep both branches with probability 1.0).
    model_proto_text = """
        path_drop_probabilities: [1.0, 1.0]
    """

    train_config = train_pb2.TrainConfig()
    text_format.Merge(train_proto_text, train_config)

    model_config = model_pb2.ModelConfig()
    text_format.Merge(model_proto_text, model_config)

    # Let repeated test runs clobber any checkpoints left behind.
    train_config.overwrite_checkpoints = True

    # Keep all test artifacts (logs + checkpoints) under /tmp.
    test_root_dir = '/tmp/pplp_unit_test/'
    paths_config = model_config.paths_config
    paths_config.logdir = test_root_dir + 'logs/'
    paths_config.checkpoint_dir = test_root_dir

    classifier = FakeBatchNormClassifier(model_config)
    trainer.train(classifier, train_config)
def train(model_config, train_config, dataset_config):
    """Builds the configured detection model and trains it.

    Args:
        model_config: model config proto; ``model_name`` selects the
            network to build ('rpn_model' or 'avod_3d_model').
        train_config: trainer config proto passed to ``trainer.train``.
        dataset_config: dataset config consumed by ``DatasetBuilder``.

    Raises:
        ValueError: if ``model_config.model_name`` is not recognized.
    """
    dataset = DatasetBuilder.build_panoptic_dataset(dataset_config,
                                                    use_defaults=False)

    # Models are constructed in training mode.
    train_val_test = 'train'

    # Dispatch table from config name to model class.
    model_classes = {
        'rpn_model': RpnModel,
        'avod_3d_model': AvodModel,
    }
    model_class = model_classes.get(model_config.model_name)
    if model_class is None:
        raise ValueError('Invalid model_name')

    # Build and train inside a fresh default graph.
    with tf.Graph().as_default():
        model = model_class(model_config,
                            train_val_test=train_val_test,
                            dataset=dataset)
        trainer.train(model, train_config)
def test_load_model_weights(self):
    """Tests that RPN weights can be restored into an AVOD graph.

    Trains an RPN model for one iteration to produce a checkpoint,
    reloads those weights in a fresh session, then builds an AVOD
    model (a superset of the RPN) and restores the same checkpoint
    into it. Asserts the AVOD graph has more variables than the RPN
    graph and that the shared (RPN) weights are bit-identical.
    """
    # Tests loading weights

    train_val_test = 'train'

    # Overwrite the training iterations
    self.train_config.max_iterations = 1
    self.train_config.overwrite_checkpoints = True

    with tf.Graph().as_default():
        model = RpnModel(self.model_config,
                         train_val_test=train_val_test,
                         dataset=self.dataset)
        trainer.train(model, self.train_config)

        paths_config = self.model_config.paths_config
        rpn_checkpoint_dir = paths_config.checkpoint_dir

        # load the weights back in
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(init_op)
            # NOTE(review): load_checkpoints appears to populate
            # saver.last_checkpoints from the directory — confirm.
            trainer_utils.load_checkpoints(rpn_checkpoint_dir, saver)
            checkpoint_to_restore = saver.last_checkpoints[-1]
            trainer_utils.load_model_weights(sess, checkpoint_to_restore)

            rpn_vars = slim.get_model_variables()
            rpn_weights = sess.run(rpn_vars)
            self.assertGreater(len(rpn_weights), 0,
                               msg='Loaded RPN weights are empty')

    # Second, fresh graph: build AVOD and restore the RPN checkpoint.
    with tf.Graph().as_default():
        model = AvodModel(self.model_config,
                          train_val_test=train_val_test,
                          dataset=self.dataset)
        model.build()

        # load the weights back in
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(init_op)
            trainer_utils.load_checkpoints(rpn_checkpoint_dir, saver)
            checkpoint_to_restore = saver.last_checkpoints[-1]
            trainer_utils.load_model_weights(sess, checkpoint_to_restore)

            avod_vars = slim.get_model_variables()
            avod_weights = sess.run(avod_vars)

            # AVOD weights should include both RPN + AVOD weights
            self.assertGreater(len(avod_weights), len(rpn_weights),
                               msg='Expected more weights for AVOD')

            # grab weights corresponding to RPN by index
            # since the model variables are ordered
            rpn_len = len(rpn_weights)
            loaded_rpn_vars = avod_vars[0:rpn_len]
            rpn_weights_reload = sess.run(loaded_rpn_vars)

            # Make sure the reloaded weights match the originally
            # loaded weights
            for i in range(rpn_len):
                np.testing.assert_array_equal(rpn_weights_reload[i],
                                              rpn_weights[i])