def test_batch_norm_class(self):
    # This tests the model and trainer set up
    train_config_text_proto = """
        optimizer {
            gradient_descent {
                learning_rate {
                    constant_learning_rate {
                        learning_rate: 1.0
                    }
                }
            }
        }
        max_iterations: 5
    """
    model_config_text_proto = """
        path_drop_probabilities: [1.0, 1.0]
    """

    train_config = train_pb2.TrainConfig()
    text_format.Merge(train_config_text_proto, train_config)

    model_config = model_pb2.ModelConfig()
    text_format.Merge(model_config_text_proto, model_config)

    train_config.overwrite_checkpoints = True

    test_root_dir = '/tmp/mlod_unit_test/'
    paths_config = model_config.paths_config
    paths_config.logdir = test_root_dir + 'logs/'
    paths_config.checkpoint_dir = test_root_dir

    classifier = FakeBatchNormClassifier(model_config)
    trainer.train(classifier, train_config)
def train(rpn_model_config, mlod_model_config,
          rpn_train_config, mlod_train_config, dataset_config):

    train_val_test = 'train'

    dataset = DatasetBuilder.build_kitti_dataset(dataset_config,
                                                 use_defaults=False)

    paths_config = rpn_model_config.paths_config
    rpn_checkpoint_dir = paths_config.checkpoint_dir

    # Train the RPN on its own first
    with tf.Graph().as_default():
        model = RpnModel(rpn_model_config,
                         train_val_test=train_val_test,
                         dataset=dataset)
        trainer.train(model, rpn_train_config)

        # Load the trained RPN weights back in
        saver = tf.train.Saver()
        init_op = tf.global_variables_initializer()

        with tf.Session() as sess:
            sess.run(init_op)
            trainer_utils.load_checkpoints(rpn_checkpoint_dir, saver)
            checkpoint_to_restore = saver.last_checkpoints[-1]
            trainer_utils.load_model_weights(model, sess,
                                             checkpoint_to_restore)

    # Merge the MLOD configs on top of the RPN configs. Fields set in the
    # MLOD configs overwrite the corresponding RPN values, while the common
    # (unset) configs keep the RPN settings.
    rpn_model_config.MergeFrom(mlod_model_config)
    rpn_train_config.MergeFrom(mlod_train_config)
    mlod_model_merged = deepcopy(rpn_model_config)
    mlod_train_merged = deepcopy(rpn_train_config)

    # Train MLOD stagewise, initializing from the RPN checkpoint
    with tf.Graph().as_default():
        model = MlodModel(mlod_model_merged,
                          train_val_test=train_val_test,
                          dataset=dataset)
        trainer.train(model, mlod_train_merged,
                      stagewise_training=True,
                      init_checkpoint_dir=rpn_checkpoint_dir)
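# Illustrative sketch (not part of the original file; it assumes only standard
# protobuf MergeFrom semantics, nothing repo-specific): singular fields that
# are explicitly set in the MLOD config overwrite the corresponding RPN
# values, unset fields keep the RPN settings, and repeated fields are
# concatenated. The helper name and the checkpoint path below are
# hypothetical; the proto type and field names appear elsewhere in this file.
def _merge_from_example():
    rpn_config = model_pb2.ModelConfig()
    rpn_config.model_name = 'rpn_model'
    rpn_config.paths_config.checkpoint_dir = '/tmp/rpn_checkpoints/'

    mlod_config = model_pb2.ModelConfig()
    mlod_config.model_name = 'mlod_model'

    rpn_config.MergeFrom(mlod_config)

    # model_name was set in mlod_config, so it overwrites the RPN value;
    # paths_config was never set in mlod_config, so the RPN value is kept
    assert rpn_config.model_name == 'mlod_model'
    assert rpn_config.paths_config.checkpoint_dir == '/tmp/rpn_checkpoints/'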
def train(model_config, train_config, dataset_config):

    dataset = DatasetBuilder.build_kitti_dataset(dataset_config,
                                                 use_defaults=False)

    train_val_test = 'train'
    model_name = model_config.model_name

    with tf.Graph().as_default():
        if model_name == 'rpn_model':
            model = RpnModel(model_config,
                             train_val_test=train_val_test,
                             dataset=dataset)
        elif model_name == 'mlod_model':
            model = MlodModel(model_config,
                              train_val_test=train_val_test,
                              dataset=dataset)
        else:
            raise ValueError('Invalid model_name')

        trainer.train(model, train_config)
def train_and_eval(model_config, train_config, eval_config, dataset_config):

    # Dataset Configuration
    dataset_config_train = DatasetBuilder.copy_config(dataset_config)
    dataset_config_eval = DatasetBuilder.copy_config(dataset_config)

    dataset_train = DatasetBuilder.build_kitti_dataset(dataset_config_train,
                                                       use_defaults=False)
    dataset_eval = DatasetBuilder.build_kitti_dataset(dataset_config_eval,
                                                      use_defaults=False)

    model_name = model_config.model_name
    train_val_test = 'train'

    eval_mode = eval_config.eval_mode
    if eval_mode == 'train':
        raise ValueError('Evaluation mode can only be set to `val` or `test`.')

    # Keep a copy, as train_config.max_iterations is overwritten inside
    # the training loop below
    max_train_iter = train_config.max_iterations

    checkpoint_interval = train_config.checkpoint_interval
    eval_interval = eval_config.eval_interval

    if eval_interval < checkpoint_interval or \
            (eval_interval % checkpoint_interval) != 0:
        raise ValueError(
            'Evaluation interval (given {}) must be greater than or equal '
            'to, and divisible by, the checkpoint interval '
            '(given {}).'.format(eval_interval, checkpoint_interval))

    # Use the evaluation losses file to continue from the latest checkpoint
    already_evaluated_ckpts = evaluator.get_evaluated_ckpts(
        model_config, model_name)

    if len(already_evaluated_ckpts) != 0:
        current_train_iter = already_evaluated_ckpts[-1]
    else:
        current_train_iter = eval_interval

    # While training is not finished
    while current_train_iter <= max_train_iter:

        # Train
        with tf.Graph().as_default():
            if model_name == 'mlod_model':
                model = MlodModel(model_config,
                                  train_val_test=train_val_test,
                                  dataset=dataset_train)
            elif model_name == 'rpn_model':
                model = RpnModel(model_config,
                                 train_val_test=train_val_test,
                                 dataset=dataset_train)
            else:
                raise ValueError('Invalid model name {}'.format(model_name))

            # Overwrite the number of training iterations for this round
            train_config.max_iterations = current_train_iter

            print('\n*************** Training ****************\n')
            trainer.train(model, train_config)

        current_train_iter += eval_interval

        # Evaluate
        with tf.Graph().as_default():
            if model_name == 'mlod_model':
                model = MlodModel(model_config,
                                  train_val_test=eval_mode,
                                  dataset=dataset_eval)
            elif model_name == 'rpn_model':
                model = RpnModel(model_config,
                                 train_val_test=eval_mode,
                                 dataset=dataset_eval)
            else:
                raise ValueError('Invalid model name {}'.format(model_name))

            print('\n*************** Evaluating *****************\n')
            evaluator.run_latest_checkpoints(model, dataset_config_eval)

    print('\n************ Finished training and evaluating *************\n')
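# Illustrative sketch (not from the original file): the interval check in
# train_and_eval above accepts only evaluation intervals that are at least
# as large as, and an exact multiple of, the checkpoint interval, so that
# every evaluation pass lands on a saved checkpoint. The helper name below
# is hypothetical.
def _intervals_compatible(eval_interval, checkpoint_interval):
    """Return True if every evaluation step coincides with a checkpoint."""
    return (eval_interval >= checkpoint_interval
            and eval_interval % checkpoint_interval == 0)

# For example, with checkpoint_interval = 1000, eval_interval values of
# 1000, 2000 or 3000 pass the check, while 500 or 1500 trigger the
# ValueError above.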
def test_load_model_weights(self):
    # Tests stagewise training, i.e. loading RPN weights onto MLOD

    train_val_test = 'train'

    # Overwrite the number of training iterations
    self.train_config.max_iterations = 1
    self.train_config.overwrite_checkpoints = True

    rpn_weights = []
    rpn_weights_reload = []

    with tf.Graph().as_default():
        model = RpnModel(self.model_config,
                         train_val_test=train_val_test,
                         dataset=self.dataset)
        trainer.train(model, self.train_config)

        paths_config = self.model_config.paths_config
        rpn_checkpoint_dir = paths_config.checkpoint_dir

        # Load the weights back in
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()

        with tf.Session() as sess:
            sess.run(init_op)
            trainer_utils.load_checkpoints(rpn_checkpoint_dir, saver)
            checkpoint_to_restore = saver.last_checkpoints[-1]
            trainer_utils.load_model_weights(model, sess,
                                             checkpoint_to_restore)

            rpn_vars = slim.get_model_variables()
            rpn_weights = sess.run(rpn_vars)
            self.assertGreater(len(rpn_weights), 0,
                               msg='Loaded RPN weights are empty')

    with tf.Graph().as_default():
        model = MlodModel(self.model_config,
                          train_val_test=train_val_test,
                          dataset=self.dataset)
        model.build()

        # Load the RPN weights into the MLOD graph
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()

        with tf.Session() as sess:
            sess.run(init_op)
            trainer_utils.load_checkpoints(rpn_checkpoint_dir, saver)
            checkpoint_to_restore = saver.last_checkpoints[-1]
            trainer_utils.load_model_weights(model, sess,
                                             checkpoint_to_restore)

            mlod_vars = slim.get_model_variables()
            mlod_weights = sess.run(mlod_vars)

            # MLOD weights should include both RPN + MLOD weights
            self.assertGreater(len(mlod_weights), len(rpn_weights),
                               msg='Expected more weights for MLOD')

            # Grab the weights corresponding to the RPN by index,
            # since the model variables are ordered
            rpn_len = len(rpn_weights)
            loaded_rpn_vars = mlod_vars[0:rpn_len]
            rpn_weights_reload = sess.run(loaded_rpn_vars)

            # Make sure the reloaded weights match the originally
            # loaded weights
            for i in range(rpn_len):
                np.testing.assert_array_equal(rpn_weights_reload[i],
                                              rpn_weights[i])