def testBuildEmptyOptimizer(self):
    """An Optimizer proto with nothing set must be rejected by the builder."""
    blank_text = """ """
    empty_config = optimizer_pb2.Optimizer()
    text_format.Merge(blank_text, empty_config)
    with self.assertRaises(ValueError):
        optimizer_builder.build(empty_config)
def testBuildEmptyOptimizer(self):
    """Building from an empty Optimizer proto (with a summaries set) raises."""
    blank_text = """ """
    summaries = set()
    empty_config = optimizer_pb2.Optimizer()
    text_format.Merge(blank_text, empty_config)
    with self.assertRaises(ValueError):
        optimizer_builder.build(empty_config, summaries)
def test_train_cfg(cfg_file):
    """Run a short, debug-oriented training loop for a JSON config file.

    Loads ``cfg_file``, passes it through ``build_config`` and builds the
    dataloader, model, optimizer, LR scheduler and loss from its
    ``'training'`` section. It then trains while ``lr_scheduler.run`` is
    truthy. Only the first batch of each scheduler step is used (see the
    ``break``). Intermediate tensors are printed, and a checkpoint is
    attempted after each step.

    Args:
        cfg_file: path to a JSON training configuration.
    """
    with open(cfg_file) as f:
        cfg = json.load(f)
    cfg = build_config(cfg)
    print(cfg)
    train_cfg = cfg['training']
    dataloader_cfg = train_cfg['dataloader']
    model_cfg = train_cfg['model']
    optimizer_cfg = train_cfg['optimizer']
    loss_cfg = train_cfg['losses']
    scheduler_cfg = train_cfg['scheduler']
    device = torch.device('cuda')
    # NOTE(review): the builder is called directly here, whereas run_train
    # uses dataloader_builder.build(...)[0] -- confirm which API applies.
    dataloader = dataloader_builder(dataloader_cfg)
    dataset = dataloader.dataset
    # The model builder also receives the dataset's metadata (dataset.info).
    model = model_builder.build(model_cfg, dataset.info)
    optimizer = optimizer_builder.build(optimizer_cfg, model.parameters())
    #optimizer = torch.optim.SGD(model.parameters(), lr=eps0, momentum=0.9, weight_decay=5e-4)
    lr_scheduler = scheduler_builder.build(scheduler_cfg, optimizer)
    loss = loss_builder.build(loss_cfg)
    file_logger = log.get_file_logger()
    model = torch.nn.DataParallel(model)
    # new experiment
    model = model.train()
    trained_models = []
    # NOTE(review): presumably lr_scheduler.run becomes falsy once the
    # schedule is exhausted -- confirm in scheduler_builder.
    while lr_scheduler.run:
        lr_scheduler.step()
        for batch_id, (data, split_info) in enumerate(dataloader):
            #print(data)
            optimizer.zero_grad()
            data['imgs'] = data['img'].to(device)
            print("imgs", data['imgs'])
            # NOTE(review): the model is fed data['img'] (not the device copy
            # in data['imgs']) wrapped with requires_grad=True, which also
            # requests gradients w.r.t. the input images -- confirm intended.
            imgs = Variable(data['img'], requires_grad=True)
            endpoints = model(imgs, model.module.endpoints)
            # theoretically losses could also be calculated distributed.
            losses = loss(endpoints, data, split_info)
            print("losses", losses)
            print(torch.mean(losses))
            loss_mean = torch.mean(losses)
            loss_mean.backward()
            optimizer.step()
            # Debug run: only the first batch of each scheduler step is used.
            break
        path = file_logger.save_checkpoint(model, optimizer, lr_scheduler.last_epoch)
        if path:
            trained_models.append(path)
    file_logger.close()
def testBuildAdamOptimizer(self):
    """A constant-LR adam_optimizer config builds a tf.train.AdamOptimizer."""
    optimizer_text_proto = """ adam_optimizer: { learning_rate: { constant_learning_rate { learning_rate: 0.002 } } } use_moving_average: false """
    optimizer_proto = optimizer_pb2.Optimizer()
    text_format.Merge(optimizer_text_proto, optimizer_proto)
    # build() returns a pair here; only the optimizer (first item) is checked.
    optimizer, _ = optimizer_builder.build(optimizer_proto)
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(optimizer, tf.train.AdamOptimizer)
def testBuildMomentumOptimizer(self):
    """A momentum_optimizer config builds a tf.train.MomentumOptimizer."""
    optimizer_text_proto = """ momentum_optimizer: { learning_rate: { constant_learning_rate { learning_rate: 0.001 } } momentum_optimizer_value: 0.99 } use_moving_average: false """
    optimizer_proto = optimizer_pb2.Optimizer()
    text_format.Merge(optimizer_text_proto, optimizer_proto)
    # build() returns a pair here; only the optimizer (first item) is checked.
    optimizer, _ = optimizer_builder.build(optimizer_proto)
    self.assertIsInstance(optimizer, tf.train.MomentumOptimizer)
def testBuildMovingAverageOptimizer(self):
    """With use_moving_average set, build yields a MovingAverageOptimizer."""
    optimizer_text_proto = """ adam_optimizer: { learning_rate: { constant_learning_rate { learning_rate: 0.002 } } } use_moving_average: True """
    global_summaries = set()
    optimizer_proto = optimizer_pb2.Optimizer()
    text_format.Merge(optimizer_text_proto, optimizer_proto)
    # NOTE(review): this uses the two-argument build(proto, summaries) form
    # returning the optimizer directly, while the other builder tests unpack
    # (optimizer, _) from build(proto) -- confirm which builder API applies.
    optimizer = optimizer_builder.build(optimizer_proto, global_summaries)
    self.assertIsInstance(optimizer, tf.contrib.opt.MovingAverageOptimizer)
def testBuildMovingAverageOptimizerWithNonDefaultDecay(self):
    """moving_average_decay from the config reaches the EMA decay value."""
    optimizer_text_proto = """ adam_optimizer: { learning_rate: { constant_learning_rate { learning_rate: 0.002 } } } use_moving_average: True moving_average_decay: 0.2 """
    optimizer_proto = optimizer_pb2.Optimizer()
    text_format.Merge(optimizer_text_proto, optimizer_proto)
    optimizer, _ = optimizer_builder.build(optimizer_proto)
    self.assertIsInstance(optimizer, tf.contrib.opt.MovingAverageOptimizer)
    # TODO: Find a way to not depend on the private members.
    self.assertAlmostEqual(optimizer._ema._decay, 0.2)
def testBuildRMSPropOptimizer(self):
    """An rms_prop_optimizer config with exponential decay builds RMSProp."""
    optimizer_text_proto = """ rms_prop_optimizer: { learning_rate: { exponential_decay_learning_rate { initial_learning_rate: 0.004 decay_steps: 800720 decay_factor: 0.95 } } momentum_optimizer_value: 0.9 decay: 0.9 epsilon: 1.0 } use_moving_average: false """
    optimizer_proto = optimizer_pb2.Optimizer()
    text_format.Merge(optimizer_text_proto, optimizer_proto)
    optimizer, _ = optimizer_builder.build(optimizer_proto)
    self.assertIsInstance(optimizer, tf.train.RMSPropOptimizer)
def train(create_tensor_dict_fn_list, create_model_fn, train_config, master,
          task, num_clones, worker_replicas, clone_on_cpu, ps_tasks,
          worker_job_name, is_chief, train_dir):
    """Training function for models.

    Builds one input queue per entry of ``create_tensor_dict_fn_list``,
    replicates the model with ``model_deploy`` and runs the training loop with
    ``tf.contrib.training.train``.

    Args:
        create_tensor_dict_fn_list: list of functions, each creating a tensor
            input dictionary. Entry ``i`` is batched with
            ``train_config.batch_size[i] // num_clones`` examples.
        create_model_fn: a function that creates a DetectionModel and generates
            losses.
        train_config: a train_pb2.TrainConfig protobuf.
        master: BNS name of the TensorFlow master to use.
        task: The task id of this training instance.
        num_clones: The number of clones to run per machine.
        worker_replicas: The number of work replicas to train with. Also used
            as the total replica count for synchronous training.
        clone_on_cpu: True if clones should be forced to run on CPU.
        ps_tasks: Number of parameter server tasks.
        worker_job_name: Name of the worker job.
        is_chief: Whether this replica is the chief replica.
        train_dir: Directory to write checkpoints and training summaries to.
    """
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = tf.train.create_global_step()

        with tf.device(deploy_config.inputs_device()), \
                tf.name_scope('Input'):
            input_queue_list = []
            for i, create_tensor_dict_fn in enumerate(create_tensor_dict_fn_list):
                input_queue_list.append(
                    _create_input_queue(
                        train_config.batch_size[i] // num_clones,
                        create_tensor_dict_fn,
                        train_config.batch_queue_capacity,
                        train_config.num_batch_queue_threads,
                        train_config.prefetch_queue_capacity,
                        data_augmentation_options))

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue_list])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        with tf.device(deploy_config.optimizer_device()), \
                tf.name_scope('Optimizer'):
            training_optimizer = optimizer_builder.build(
                train_config.optimizer, global_summaries)

        sync_optimizer = None
        if train_config.sync_replicas:
            # Use the worker_replicas argument (as DeploymentConfig does above)
            # rather than a train_config field for the total replica count.
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=worker_replicas)
            sync_optimizer = training_optimizer

        # Optionally restore the trainable FeatureExtractor/Convnet/* variables
        # from a pre-trained checkpoint once the session is initialized.
        load_pretrain = None
        if train_config.fine_tune_checkpoint:
            all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            restore_vars = [
                var for var in all_vars
                if var.name.split('/')[:2] == ['FeatureExtractor', 'Convnet']
            ]
            pre_train_saver = tf.train.Saver(restore_vars)

            def load_pretrain(scaffold, sess):
                # Scaffold calls init_fn as init_fn(scaffold, session).
                pre_train_saver.restore(sess, train_config.fine_tune_checkpoint)

        with tf.device(deploy_config.optimizer_device()), \
                tf.variable_scope('OptimizeClones'):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer, regularization_losses=None)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = [r'.*bias(?:es)?', r'.*beta']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = tf.contrib.training.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for grad, var in grads_and_vars:
            var_name = var.op.name
            global_summaries.add(tf.summary.histogram('grad/' + var_name, grad))
            global_summaries.add(tf.summary.histogram(var_name, var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        scaffold = tf.train.Scaffold(init_fn=load_pretrain,
                                     summary_op=summary_op,
                                     saver=saver)

        hooks = []
        # StopAtStepHook raises ValueError when neither num_steps nor last_step
        # is given, so only add it when a step limit is configured; a falsy
        # num_steps means "no step limit".
        if train_config.num_steps:
            hooks.append(tf.train.StopAtStepHook(num_steps=train_config.num_steps))
        hooks.append(profile_session_run_hooks.ProfileAtStepHook(
            at_step=200, checkpoint_dir=train_dir))
        if sync_optimizer is not None:
            # SyncReplicasOptimizer only works when its session hook (which
            # initializes the token queue / chief queue runner) is registered.
            hooks.append(sync_optimizer.make_session_run_hook(is_chief))

        tf.contrib.training.train(
            train_tensor,
            train_dir,
            master=master,
            is_chief=is_chief,
            scaffold=scaffold,
            hooks=hooks,
            chief_only_hooks=None,
            save_checkpoint_secs=train_config.save_checkpoint_secs,
            save_summaries_steps=train_config.save_summaries_steps,
            config=session_config)
def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir, graph_hook_fn=None):
    """Training function for detection models.

    Args:
        create_tensor_dict_fn: a function to create a tensor input dictionary.
        create_model_fn: a function that creates a DetectionModel and generates
            losses.
        train_config: a train_pb2.TrainConfig protobuf. It is not modified.
        master: BNS name of the TensorFlow master to use.
        task: The task id of this training instance.
        num_clones: The number of clones to run per machine.
        worker_replicas: The number of work replicas to train with.
        clone_on_cpu: True if clones should be forced to run on CPU.
        ps_tasks: Number of parameter server tasks.
        worker_job_name: Name of the worker job.
        is_chief: Whether this replica is the chief replica.
        train_dir: Directory to write checkpoints and training summaries to.
        graph_hook_fn: Optional function that is called after the inference
            graph is built (before optimization). This is helpful to perform
            additional changes to the training graph such as adding FakeQuant
            ops. The function should modify the default graph.

    Raises:
        ValueError: If both num_clones > 1 and train_config.sync_replicas is true.
    """
    # Used later (restore_map) to pick variables for fine-tuning.
    detection_model = create_model_fn()
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        if num_clones != 1 and train_config.sync_replicas:
            # A single message string (adjacent literals are joined before
            # .format), so the error does not render as a tuple.
            raise ValueError('In Synchronous SGD mode num_clones must be 1. '
                             'Found num_clones: {}'.format(num_clones))
        batch_size = train_config.batch_size // num_clones
        if train_config.sync_replicas:
            batch_size //= train_config.replicas_to_aggregate

        with tf.device(deploy_config.inputs_device()):
            input_queue = create_input_queue(
                batch_size, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

        # Gather initial summaries.
        # TODO(rathodv): See if summaries can be added/extracted from global tf
        # collections so that they don't have to be passed around.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn,
                                     train_config=train_config)
        clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
        first_clone_scope = clones[0].scope

        if graph_hook_fn:
            with tf.device(deploy_config.variables_device()):
                graph_hook_fn()

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)
            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var, family='LearningRate')

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=worker_replicas)
            sync_optimizer = training_optimizer

        with tf.device(deploy_config.optimizer_device()):
            # None lets optimize_clones gather regularization losses itself;
            # [] disables them.
            regularization_losses = (
                None if train_config.add_regularization_loss else [])
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer,
                regularization_losses=regularization_losses)
            total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram('ModelVars/' + model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar('Losses/' + loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            fine_tune_checkpoint_type = train_config.fine_tune_checkpoint_type
            if not fine_tune_checkpoint_type:
                # train_config.from_detection_checkpoint field is deprecated. For
                # backward compatibility, the checkpoint type is derived from it.
                # A local variable keeps the caller's train_config unmodified.
                if train_config.from_detection_checkpoint:
                    fine_tune_checkpoint_type = 'detection'
                else:
                    fine_tune_checkpoint_type = 'classification'
            var_map = detection_model.restore_map(
                fine_tune_checkpoint_type=fine_tune_checkpoint_type,
                load_all_detection_checkpoint_vars=(
                    train_config.load_all_detection_checkpoint_vars))
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map, train_config.fine_tune_checkpoint,
                    include_global_step=False))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(
                train_config.num_steps if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
def run_train(dataloader_cfg, model_cfg, scheduler_cfg, optimizer_cfg,
              loss_cfg, validation_cfg, checkpoint_frequency,
              restore_checkpoint, max_epochs, _run):
    """Train a model for up to ``max_epochs`` epochs and return checkpoints.

    Builds the training dataloader, model, loss, optimizer and LR scheduler
    from their configs. If ``restore_checkpoint`` is given, the model and
    optimizer configs and the starting epoch are restored from it. The model
    is wrapped in ``torch.nn.DataParallel`` on ``_run.config['device_id']``.
    A checkpoint is saved every ``checkpoint_frequency`` epochs and once more
    after training. Validation runs on the same schedule when
    ``validation_cfg`` is not None.

    Args:
        dataloader_cfg: config for the training dataloader (first one is used).
        model_cfg: model config; merged on top of the dataset's ``info``.
        scheduler_cfg: LR scheduler config.
        optimizer_cfg: optimizer config.
        loss_cfg: loss config. Its ``'model'`` entry is set to the built model.
        validation_cfg: dataloader config for validation, or None.
        checkpoint_frequency: save/validate every this many epochs.
        restore_checkpoint: checkpoint to resume from, or None.
        max_epochs: epoch number at which training stops.
        _run: run object providing ``config`` (``'device'``, ``'device_id'``,
            ``'training'``) -- presumably a sacred Run; confirm with caller.

    Returns:
        List of the last ``_run.config['training']['evaluate_last']``
        (default 1) saved checkpoint paths.
    """
    # Lets cuDNN benchmark conv implementations and choose the fastest.
    # Only good if sizes stay the same within the main loop!
    torch.backends.cudnn.benchmark = True
    exit_handler = ExitHandler()
    device = _run.config['device']
    device_id = _run.config['device_id']

    # during training just one dataloader
    dataloader = dataloader_builder.build(dataloader_cfg)[0]

    epoch = 0
    if restore_checkpoint is not None:
        model_cfg, optimizer_cfg, epoch = utils.restore_checkpoint(
            restore_checkpoint, model_cfg, optimizer_cfg)

    # Some models depend on the dataset, for example num_joints. Entries in
    # model_cfg take precedence over the dataset info. Merge into a new dict
    # so the live dataloader.dataset.info is not mutated.
    merged_model_cfg = dict(dataloader.dataset.info)
    merged_model_cfg.update(model_cfg)
    model_cfg = merged_model_cfg

    model = model_builder.build(model_cfg)

    loss_cfg['model'] = model
    loss = loss_builder.build(loss_cfg)
    loss = loss.to(device)

    # The loss may have learnable parameters, so optimize them jointly.
    parameters = list(model.parameters()) + list(loss.parameters())
    optimizer = optimizer_builder.build(optimizer_cfg, parameters)

    lr_scheduler = scheduler_builder.build(scheduler_cfg, optimizer, epoch)

    if validation_cfg is None:
        validation_dataloaders = None
    else:
        validation_dataloaders = dataloader_builder.build(validation_cfg)
    keep = False

    file_logger = log.get_file_logger()
    logger = log.get_logger()

    model = torch.nn.DataParallel(model, device_ids=device_id)
    model.cuda()
    model = model.train()

    trained_models = []

    exit_handler.register(file_logger.save_checkpoint, model, optimizer,
                          "atexit", model_cfg)

    start_training_time = time.time()
    end = time.time()
    while epoch < max_epochs:
        epoch += 1
        lr_scheduler.step()
        logger.info("Starting Epoch %d/%d", epoch, max_epochs)

        len_batch = len(dataloader)
        acc_time = 0
        for batch_id, data in enumerate(dataloader):
            optimizer.zero_grad()
            endpoints = model(data, model.module.endpoints)
            logger.debug("datasets %s", list(data['split_info'].keys()))

            data.update(endpoints)
            # theoretically losses could also be calculated distributed.
            losses = loss(endpoints, data)
            loss_mean = torch.mean(losses)
            loss_mean.backward()
            optimizer.step()

            acc_time += time.time() - end
            end = time.time()

            report_after_batch(_run=_run, logger=logger, batch_id=batch_id,
                               batch_len=len_batch, acc_time=acc_time,
                               loss_mean=loss_mean,
                               max_mem=torch.cuda.max_memory_allocated())

        if epoch % checkpoint_frequency == 0:
            path = file_logger.save_checkpoint(model, optimizer, epoch, model_cfg)
            # Same guard as for the final checkpoint below.
            if path:
                trained_models.append(path)
        report_after_epoch(_run=_run, epoch=epoch, max_epoch=max_epochs)

        if validation_dataloaders is not None and \
                epoch % checkpoint_frequency == 0:
            model.eval()
            # Lets cuDNN benchmark conv implementations and choose the fastest.
            # Only good if sizes stay the same within the main loop!
            # not the case for segmentation
            torch.backends.cudnn.benchmark = False
            score = evaluate(validation_dataloaders, model, epoch, keep=keep)
            logger.info(score)
            log_score(score, _run, prefix="val_", step=epoch)
            torch.backends.cudnn.benchmark = True
            model.train()

    report_after_training(_run=_run, max_epoch=max_epochs,
                          total_time=time.time() - start_training_time)
    path = file_logger.save_checkpoint(model, optimizer, epoch, model_cfg)
    if path:
        trained_models.append(path)
    file_logger.close()
    # TODO get best performing val model
    evaluate_last = _run.config['training'].get('evaluate_last', 1)
    if len(trained_models) < evaluate_last:
        logger.info("Only saved %d models (evaluate_last=%d)",
                    len(trained_models), evaluate_last)
    # Slice from an explicit start index: trained_models[-evaluate_last:]
    # would return *every* model for evaluate_last == 0 (since -0 == 0).
    return trained_models[max(len(trained_models) - evaluate_last, 0):]
def _general_model_fn(features, pipeline_config, result_folder, dataset_info,
                      feature_extractor, mode, num_gpu,
                      visualization_file_names, eval_dir):
    """Build a tf.estimator.EstimatorSpec for a mask-prediction network.

    The network output is scored per pixel against
    ``features[...annotation_mask]`` with ``_loss``. TRAIN mode augments the
    inputs and builds the optimizer update. EVAL mode computes metrics via
    ``metric_utils.get_metrics`` and attaches visualization and per-patient
    metric hooks. PREDICT mode returns the predicted mask overlaid (in the
    first/red channel) on the decoded image.

    Args:
        features: dict of input tensors keyed by
            ``standard_fields.InputDataFields`` entries.
        pipeline_config: pipeline config proto (dataset, model, train_config,
            eval_config, metrics_tp_thresholds are read).
        result_folder: output folder passed to the session hooks.
        dataset_info: dict read for ``PickledDatasetInfo.patient_ratio`` when
            weighted loss is enabled.
        feature_extractor: object whose ``build_network`` creates the logits.
        mode: a ``tf.estimator.ModeKeys`` value.
        num_gpu: NOTE(review): not used in this function.
        visualization_file_names: passed to the EVAL VisualizationHook.
        eval_dir: passed to the EVAL/PREDICT hooks.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for ``mode``.
    """
    num_classes = pipeline_config.dataset.num_classes
    # With a softmax loss an explicit background class is added, so a
    # single foreground class becomes two output channels.
    add_background_class = pipeline_config.train_config.loss.name == 'softmax'
    if add_background_class:
        # NOTE(review): assert is stripped under -O; consider raising instead.
        assert (num_classes == 1)
        num_classes += 1

    image_batch = features[standard_fields.InputDataFields.image_decoded]

    # Ground-truth masks are not available when predicting.
    if mode == tf.estimator.ModeKeys.PREDICT:
        annotation_mask_batch = None
    else:
        annotation_mask_batch = features[
            standard_fields.InputDataFields.annotation_mask]

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Augment images
        image_batch, annotation_mask_batch = preprocessor.apply_data_augmentation(
            pipeline_config.train_config.data_augmentation_options,
            images=image_batch, gt_masks=annotation_mask_batch,
            batch_size=pipeline_config.train_config.batch_size)

    # General preprocessing
    image_batch_preprocessed = preprocessor.preprocess(
        image_batch, pipeline_config.dataset.val_range,
        scale_input=pipeline_config.dataset.scale_input)

    network_output = feature_extractor.build_network(
        image_batch_preprocessed, is_training=mode == tf.estimator.ModeKeys.TRAIN,
        num_classes=num_classes,
        use_batch_norm=pipeline_config.model.use_batch_norm,
        bn_momentum=pipeline_config.model.batch_norm_momentum,
        bn_epsilon=pipeline_config.model.batch_norm_epsilon,
        activation_fn=activation_fn_builder.build(pipeline_config.model))

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Record model variable summaries
        for var in tf.trainable_variables():
            tf.summary.histogram('ModelVars/' + var.op.name, var)

    network_output_shape = network_output.get_shape().as_list()
    if mode != tf.estimator.ModeKeys.PREDICT:
        # Center-crop the ground truth when its spatial size (dims 1:3)
        # differs from the network output's.
        if (network_output_shape[1:3] !=
                annotation_mask_batch.get_shape().as_list()[1:3]):
            annotation_mask_batch = image_utils.central_crop(
                annotation_mask_batch,
                desired_size=network_output.get_shape().as_list()[1:3])

        # Binarize the mask to {0, 1} as int64 labels.
        annotation_mask_batch = tf.cast(tf.clip_by_value(
            annotation_mask_batch, 0, 1), dtype=tf.int64)

        assert (len(annotation_mask_batch.get_shape()) == 4)
        assert (annotation_mask_batch.get_shape().as_list()[:3] ==
                network_output.get_shape().as_list()[:3])

    # We should not apply the loss weighting to evaluation. This would just
    # cause our loss to be minimum for f2 score, but we also get the same
    # optimum if we just optimize for f1 score.
    if (pipeline_config.train_config.loss.use_weighted and
            mode == tf.estimator.ModeKeys.TRAIN):
        patient_ratio = dataset_info[
            standard_fields.PickledDatasetInfo.patient_ratio]

        # Positive-class weight grows with the healthy/cancer pixel ratio of
        # the batch; +1.0 avoids division by zero for all-healthy batches.
        cancer_pixels = tf.reduce_sum(tf.to_float(annotation_mask_batch))
        healthy_pixels = tf.to_float(
            tf.size(annotation_mask_batch)) - cancer_pixels

        batch_pixel_ratio = tf.div(healthy_pixels, cancer_pixels + 1.0)

        loss_weight = (
            ((batch_pixel_ratio * patient_ratio) +
             pipeline_config.train_config.loss.weight_constant_add) *
            pipeline_config.train_config.loss.weight_constant_multiply)
    else:
        loss_weight = tf.constant(1.0)

    if mode == tf.estimator.ModeKeys.PREDICT:
        loss = None
    else:
        # Per-pixel loss over flattened labels and logits.
        loss = _loss(tf.reshape(annotation_mask_batch, [-1]),
                     tf.reshape(network_output, [-1, num_classes]),
                     loss_name=pipeline_config.train_config.loss.name,
                     pos_weight=loss_weight)
        loss = tf.identity(loss, name='ModelLoss')
        tf.summary.scalar(loss.op.name, loss, family='Loss')

        total_loss = tf.identity(loss, name='TotalLoss')

        if mode == tf.estimator.ModeKeys.TRAIN:
            if pipeline_config.train_config.add_regularization_loss:
                regularization_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if regularization_losses:
                    regularization_loss = tf.add_n(regularization_losses,
                                                   name='RegularizationLoss')
                    total_loss = tf.add_n([loss, regularization_loss],
                                          name='TotalLoss')
                    tf.summary.scalar(regularization_loss.op.name,
                                      regularization_loss, family='Loss')

            tf.summary.scalar(total_loss.op.name, total_loss, family='Loss')
        total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

    scaffold = None
    update_ops = []

    if mode == tf.estimator.ModeKeys.TRAIN:
        if pipeline_config.train_config.optimizer.use_moving_average:
            # EMAs are currently not supported with tf's DistributionStrategy.
            # Reenable once they fixed the bugs.
            # NOTE(review): logging.warn is deprecated (logging.warning), and the
            # statement after exit(1) below is unreachable.
            logging.warn(
                'EMA is currently not supported with tf DistributionStrategy.')
            exit(1)
            pipeline_config.train_config.optimizer.use_moving_average = False
            # The swapping saver will swap the trained variables with their moving
            # averages before saving, thus removing the need to care for moving
            # averages during evaluation
            # scaffold = tf.train.Scaffold(saver=optimizer.swapping_saver())

        optimizer, optimizer_summary_vars = optimizer_builder.build(
            pipeline_config.train_config.optimizer)
        for var in optimizer_summary_vars:
            tf.summary.scalar(var.op.name, var, family='LearningRate')

        grads_and_vars = optimizer.compute_gradients(total_loss)

        update_ops.append(
            optimizer.apply_gradients(grads_and_vars,
                                      global_step=tf.train.get_global_step()))

        # NOTE(review): this appends the whole UPDATE_OPS list as a single
        # element, so tf.group below receives a nested list; this relies on
        # tf.group flattening its inputs -- update_ops.extend(...) would be
        # clearer.
        graph_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_ops.append(graph_update_ops)

    update_op = tf.group(*update_ops, name='update_barrier')
    # The returned train op is the loss, gated on all updates having run.
    with tf.control_dependencies([update_op]):
        if mode == tf.estimator.ModeKeys.PREDICT:
            train_op = None
        else:
            train_op = tf.identity(total_loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        logging.info("Total number of trainable parameters: {}".format(
            np.sum([
                np.prod(v.get_shape().as_list())
                for v in tf.trainable_variables()
            ])))

        # Training Hooks are not working with MirroredStrategy. Fixed in 1.13
        #print_hook = session_hooks.PrintHook(
        #  file_name=features[standard_fields.InputDataFields.image_file],
        #  batch_pixel_ratio=batch_pixel_ratio)
        return tf.estimator.EstimatorSpec(mode, loss=total_loss,
                                          train_op=train_op,
                                          scaffold=scaffold)
    elif mode == tf.estimator.ModeKeys.EVAL:
        # Foreground probability per pixel.
        # NOTE(review): for any other loss name scaled_network_output is never
        # assigned and the code below raises NameError.
        if pipeline_config.train_config.loss.name == 'sigmoid':
            scaled_network_output = tf.nn.sigmoid(network_output)[:, :, :, 0]
        elif pipeline_config.train_config.loss.name == 'softmax':
            assert (network_output.get_shape().as_list()[-1] == 2)
            scaled_network_output = tf.nn.softmax(network_output)[:, :, :, 1]

        # Metrics
        metric_dict, statistics_dict = metric_utils.get_metrics(
            scaled_network_output, annotation_mask_batch,
            tp_thresholds=np.array(pipeline_config.metrics_tp_thresholds,
                                   dtype=np.float32),
            parallel_iterations=min(pipeline_config.eval_config.batch_size,
                                    util_ops.get_cpu_count()))

        vis_hook = session_hooks.VisualizationHook(
            result_folder=result_folder,
            visualization_file_names=visualization_file_names,
            file_name=features[standard_fields.InputDataFields.image_file],
            image_decoded=image_batch,
            annotation_decoded=features[
                standard_fields.InputDataFields.annotation_decoded],
            predicted_mask=scaled_network_output,
            eval_dir=eval_dir)
        patient_metric_hook = session_hooks.PatientMetricHook(
            statistics_dict=statistics_dict,
            patient_id=features[standard_fields.InputDataFields.patient_id],
            result_folder=result_folder,
            tp_thresholds=pipeline_config.metrics_tp_thresholds,
            eval_dir=eval_dir)

        return tf.estimator.EstimatorSpec(
            mode, loss=total_loss, train_op=train_op,
            evaluation_hooks=[vis_hook, patient_metric_hook],
            eval_metric_ops=metric_dict)
    elif mode == tf.estimator.ModeKeys.PREDICT:
        if pipeline_config.train_config.loss.name == 'sigmoid':
            scaled_network_output = tf.nn.sigmoid(network_output)[:, :, :, 0]
        elif pipeline_config.train_config.loss.name == 'softmax':
            assert (network_output.get_shape().as_list()[-1] == 2)
            scaled_network_output = tf.nn.softmax(network_output)[:, :, :, 1]

        vis_hook = session_hooks.VisualizationHook(
            result_folder=result_folder,
            visualization_file_names=None,
            file_name=features[standard_fields.InputDataFields.image_file],
            image_decoded=image_batch,
            annotation_decoded=None,
            predicted_mask=scaled_network_output,
            eval_dir=eval_dir)

        # Red-channel mask (probability * 255) blended onto the half-intensity
        # decoded image, clipped to [0, 255].
        predicted_mask = tf.stack([
            scaled_network_output * 255,
            tf.zeros_like(scaled_network_output),
            tf.zeros_like(scaled_network_output)
        ], axis=3)

        predicted_mask_overlay = tf.clip_by_value(
            features[standard_fields.InputDataFields.image_decoded] * 0.5 +
            predicted_mask, 0, 255)

        return tf.estimator.EstimatorSpec(
            mode,
            prediction_hooks=[vis_hook],
            predictions={
                'image_file': features[standard_fields.InputDataFields.image_file],
                'prediction': predicted_mask_overlay
            })
    else:
        # Unknown mode. NOTE(review): assert is stripped under -O; raising
        # ValueError would be safer.
        assert (False)
def model_fn(features, labels, mode, params=None):
    """Builds a `tf.estimator.EstimatorSpec` for the model from `init_model_fn`.

    NOTE(review): `init_model_fn`, `eval_config`, `fields`, `optimizer_builder`
    and `eval_utils` are free names here; presumably they are captured from an
    enclosing factory or module scope -- confirm against the caller.

    Args:
      features: Dict of input tensors; `fields.InputDataFields.image` is read
        as the model input.
      labels: Dict of groundtruth tensors; `fields.InputDataFields.label` is
        read in TRAIN and EVAL modes.
      mode: A `tf.estimator.ModeKeys` value.
      params: Optional parameter dict from the estimator; defaulted to `{}`
        and not otherwise used.

    Returns:
      A `tf.estimator.EstimatorSpec` carrying predictions, loss, train op and
      eval metric ops, plus a scaffold whose saver is sharded and saves
      relative paths.
    """
    params = params or {}
    total_loss, train_op, predictions, export_outputs = None, None, None, None
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    model = init_model_fn(is_training=is_training, add_summaries=True)
    scaffold = None
    preprocessed_images = features[fields.InputDataFields.image]
    # Debug aid: logs the repr of the input tensor.
    tf.logging.info('msg:{}'.format(preprocessed_images))
    predictions = model.predict(preprocessed_images)
    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        total_loss = model.loss(predictions, labels[fields.InputDataFields.label])
        global_step = tf.train.get_or_create_global_step()
        # NOTE(review): `build` is called with no arguments, while the other
        # optimizer_builder call sites in this file pass
        # `train_config.optimizer`; confirm this signature is valid.
        # `training_optimizer` is not used anywhere below.
        training_optimizer, optimizer_summayr_vars = optimizer_builder.build(
        )
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Learning-rate style summaries for the optimizer's summary variables.
        for var in optimizer_summayr_vars:
            tf.summary.scalar(var.op.name, var)
        # NOTE(review): optimizer and learning rate are hard-coded (Adam,
        # 0.001); the configured `training_optimizer` is ignored.
        train_op = tf.contrib.layers.optimize_loss(
            loss=total_loss,
            global_step=global_step,
            # learning_rate=None,
            learning_rate=0.001,
            # clip_gradients=clip_gradients_value,
            # optimizer=training_optimizer,
            optimizer='Adam',
            # update_ops=model.updates(),
            # variables=trainable_variables,
            # summaries=summaries,
            name='')

    def postprocess_wrapper(predictions):
        """Converts raw model predictions via `model.format_label`."""
        return model.format_label(predictions)

    if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
        predictions = postprocess_wrapper(predictions)
    eval_metric_ops = None
    if mode == tf.estimator.ModeKeys.EVAL:
        # NOTE(review): `predictions` was already passed through
        # `model.format_label` just above, so it is formatted twice here;
        # only correct if `format_label` is idempotent -- confirm.
        # NOTE(review): helper name is spelled "evaluatiors"; verify it
        # exists in `eval_utils`.
        eval_metric_ops = eval_utils.get_eval_metric_ops_for_evaluatiors(
            eval_config,
            model.format_label(predictions),
            model.format_label(labels[fields.InputDataFields.label]))
        # for var in optimizer_summayr_vars:
        #     eval_metric_ops[var.op.name] = (var, tf.no_op())
    if scaffold is None:
        # keep_checkpoint_every_n_hours = (
        #     train_config.keep_checkpoint_every_n_hours)
        saver = tf.train.Saver(
            sharded=True,
            # keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            save_relative_paths=True)
        # Register the saver so other code that looks up SAVERS finds it.
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        scaffold = tf.train.Scaffold(saver=saver)
    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=predictions,
                                      loss=total_loss,
                                      train_op=train_op,
                                      eval_metric_ops=eval_metric_ops,
                                      export_outputs=export_outputs,
                                      scaffold=scaffold)
def model_fn(features, labels, mode, params=None):
    """Constructs the object detection model.

    NOTE(review): `detection_model_fn`, `use_tpu`, `train_config`,
    `eval_input_config`, `hparams` and `tpu_optimizer` are free names here;
    presumably they are captured from an enclosing factory -- confirm
    against the caller.

    Args:
      features: Dictionary of feature tensors, returned from `input_fn`.
      labels: Dictionary of groundtruth tensors if mode is TRAIN or
        EVAL, otherwise None.
      mode: Mode key from tf.estimator.ModeKeys.
      params: Parameter dictionary passed from the estimator.

    Returns:
      An `EstimatorSpec` that encapsulates the model and its serving
        configurations (a `TPUEstimatorSpec` when `use_tpu` is set).
    """
    params = params or {}
    total_loss, train_op, detections, export_outputs = None, None, None, None
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    detection_model = detection_model_fn(is_training=is_training,
                                         add_summaries=(not use_tpu))
    scaffold_fn = None

    if mode == tf.estimator.ModeKeys.TRAIN:
        labels = unstack_batch(
            labels,
            unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors)
    elif mode == tf.estimator.ModeKeys.EVAL:
        labels = unstack_batch(labels, unpad_groundtruth_tensors=False)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        gt_boxes_list = labels[fields.InputDataFields.groundtruth_boxes]
        gt_classes_list = labels[fields.InputDataFields.groundtruth_classes]
        # Masks and keypoints are optional; pass None when absent.
        gt_masks_list = None
        if fields.InputDataFields.groundtruth_instance_masks in labels:
            gt_masks_list = labels[
                fields.InputDataFields.groundtruth_instance_masks]
        gt_keypoints_list = None
        if fields.InputDataFields.groundtruth_keypoints in labels:
            gt_keypoints_list = labels[
                fields.InputDataFields.groundtruth_keypoints]
        detection_model.provide_groundtruth(
            groundtruth_boxes_list=gt_boxes_list,
            groundtruth_classes_list=gt_classes_list,
            groundtruth_masks_list=gt_masks_list,
            groundtruth_keypoints_list=gt_keypoints_list)

    preprocessed_images = features[fields.InputDataFields.image]
    prediction_dict = detection_model.predict(
        preprocessed_images, features[fields.InputDataFields.true_image_shape])
    detections = detection_model.postprocess(
        prediction_dict, features[fields.InputDataFields.true_image_shape])

    if mode == tf.estimator.ModeKeys.TRAIN:
        if train_config.fine_tune_checkpoint and hparams.load_pretrained:
            asg_map = detection_model.restore_map(
                from_detection_checkpoint=train_config.from_detection_checkpoint,
                load_all_detection_checkpoint_vars=(
                    train_config.load_all_detection_checkpoint_vars))
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    asg_map, train_config.fine_tune_checkpoint,
                    include_global_step=False))
            if use_tpu:
                # On TPU the restore is deferred into the scaffold_fn handed
                # to TPUEstimatorSpec below.
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(
                        train_config.fine_tune_checkpoint, available_var_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(
                    train_config.fine_tune_checkpoint, available_var_map)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        losses_dict = detection_model.loss(
            prediction_dict, features[fields.InputDataFields.true_image_shape])
        # `dict.itervalues()` does not exist on Python 3; `values()` works on
        # both Python 2 and 3 with the same iteration order.
        losses = [loss_tensor for loss_tensor in losses_dict.values()]
        total_loss = tf.add_n(losses, name='total_loss')

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()
        training_optimizer, optimizer_summary_vars = optimizer_builder.build(
            train_config.optimizer)
        if use_tpu:
            training_optimizer = tpu_optimizer.CrossShardOptimizer(
                training_optimizer)

        # Optionally freeze some layers by setting their gradients to be zero.
        trainable_variables = None
        if train_config.freeze_variables:
            trainable_variables = tf.contrib.framework.filter_variables(
                tf.trainable_variables(),
                exclude_patterns=train_config.freeze_variables)

        clip_gradients_value = None
        if train_config.gradient_clipping_by_norm > 0:
            clip_gradients_value = train_config.gradient_clipping_by_norm

        if not use_tpu:
            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var)
        summaries = [] if use_tpu else None
        train_op = tf.contrib.layers.optimize_loss(
            loss=total_loss,
            global_step=global_step,
            learning_rate=None,
            clip_gradients=clip_gradients_value,
            optimizer=training_optimizer,
            variables=trainable_variables,
            summaries=summaries,
            name='')  # Preventing scope prefix on all variables.

    if mode == tf.estimator.ModeKeys.PREDICT:
        export_outputs = {
            tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
                tf.estimator.export.PredictOutput(detections)
        }

    eval_metric_ops = None
    if mode == tf.estimator.ModeKeys.EVAL:
        # Detection summaries during eval.
        class_agnostic = (fields.DetectionResultFields.detection_classes
                          not in detections)
        groundtruth = _get_groundtruth_data(detection_model, class_agnostic)
        # Only the first example of the batch is used for eval below.
        eval_dict = eval_util.result_dict_for_single_example(
            tf.expand_dims(features[fields.InputDataFields.original_image][0], 0),
            features[inputs.HASH_KEY][0],
            detections,
            groundtruth,
            class_agnostic=class_agnostic,
            scale_to_absolute=False)

        if class_agnostic:
            category_index = label_map_util.create_class_agnostic_category_index()
        else:
            category_index = label_map_util.create_category_index_from_labelmap(
                eval_input_config.label_map_path)
        detection_and_groundtruth = vis_utils.draw_side_by_side_evaluation_image(
            eval_dict, category_index, max_boxes_to_draw=20,
            min_score_thresh=0.2)
        if not use_tpu:
            tf.summary.image('Detections_Left_Groundtruth_Right',
                             detection_and_groundtruth)

        # Eval metrics on a single image.
        detection_fields = fields.DetectionResultFields()
        input_data_fields = fields.InputDataFields()
        # Pass a concrete list rather than a Python 3 dict view.
        coco_evaluator = coco_evaluation.CocoDetectionEvaluator(
            list(category_index.values()))
        eval_metric_ops = coco_evaluator.get_estimator_eval_metric_ops(
            image_id=eval_dict[input_data_fields.key],
            groundtruth_boxes=eval_dict[input_data_fields.groundtruth_boxes],
            groundtruth_classes=eval_dict[input_data_fields.groundtruth_classes],
            detection_boxes=eval_dict[detection_fields.detection_boxes],
            detection_scores=eval_dict[detection_fields.detection_scores],
            detection_classes=eval_dict[detection_fields.detection_classes])

    if use_tpu:
        # NOTE(review): TPUEstimatorSpec's `eval_metrics` normally expects a
        # (metric_fn, tensors) pair, not a dict of metric ops -- confirm for
        # EVAL with use_tpu.
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            scaffold_fn=scaffold_fn,
            predictions=detections,
            loss=total_loss,
            train_op=train_op,
            eval_metrics=eval_metric_ops,
            export_outputs=export_outputs)
    else:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=detections,
            loss=total_loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops,
            export_outputs=export_outputs)
def model_fn(features, labels, mode, params=None):
    """Constructs the object detection model.

    NOTE(review): `detection_model_fn`, `use_tpu`, `train_config`,
    `eval_config`, `eval_input_config`, `hparams` and `configs` are free
    names here; presumably they are captured from an enclosing factory --
    confirm against the caller.

    Args:
      features: Dictionary of feature tensors, returned from `input_fn`.
      labels: Dictionary of groundtruth tensors if mode is TRAIN or
        EVAL, otherwise None.
      mode: Mode key from tf.estimator.ModeKeys.
      params: Parameter dictionary passed from the estimator.

    Returns:
      An `EstimatorSpec` that encapsulates the model and its serving
        configurations. When `use_tpu` is set and mode is not EVAL, a
        `TPUEstimatorSpec` is returned instead.
    """
    params = params or {}
    total_loss, train_op, detections, export_outputs = None, None, None, None
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Make sure to set the Keras learning phase. True during training,
    # False for inference.
    tf.keras.backend.set_learning_phase(is_training)
    detection_model = detection_model_fn(is_training=is_training,
                                         add_summaries=(not use_tpu))
    scaffold_fn = None

    if mode == tf.estimator.ModeKeys.TRAIN:
        labels = unstack_batch(labels,
                               unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors)
    elif mode == tf.estimator.ModeKeys.EVAL:
        # For evaling on train data, it is necessary to check whether groundtruth
        # must be unpadded.
        boxes_shape = (labels[fields.InputDataFields.groundtruth_boxes].
                       get_shape().as_list())
        # Unpad only when dimension 1 of the groundtruth boxes tensor is
        # statically known and we are not on TPU.
        unpad_groundtruth_tensors = boxes_shape[1] is not None and not use_tpu
        labels = unstack_batch(
            labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        gt_boxes_list = labels[fields.InputDataFields.groundtruth_boxes]
        gt_classes_list = labels[
            fields.InputDataFields.groundtruth_classes]
        # The remaining groundtruth fields are optional and default to None
        # when the key is absent from `labels`.
        gt_masks_list = None
        if fields.InputDataFields.groundtruth_instance_masks in labels:
            gt_masks_list = labels[
                fields.InputDataFields.groundtruth_instance_masks]
        gt_keypoints_list = None
        if fields.InputDataFields.groundtruth_keypoints in labels:
            gt_keypoints_list = labels[
                fields.InputDataFields.groundtruth_keypoints]
        gt_weights_list = None
        if fields.InputDataFields.groundtruth_weights in labels:
            gt_weights_list = labels[
                fields.InputDataFields.groundtruth_weights]
        gt_confidences_list = None
        if fields.InputDataFields.groundtruth_confidences in labels:
            gt_confidences_list = labels[
                fields.InputDataFields.groundtruth_confidences]
        gt_is_crowd_list = None
        if fields.InputDataFields.groundtruth_is_crowd in labels:
            gt_is_crowd_list = labels[
                fields.InputDataFields.groundtruth_is_crowd]
        detection_model.provide_groundtruth(
            groundtruth_boxes_list=gt_boxes_list,
            groundtruth_classes_list=gt_classes_list,
            groundtruth_confidences_list=gt_confidences_list,
            groundtruth_masks_list=gt_masks_list,
            groundtruth_keypoints_list=gt_keypoints_list,
            groundtruth_weights_list=gt_weights_list,
            groundtruth_is_crowd_list=gt_is_crowd_list)

    preprocessed_images = features[fields.InputDataFields.image]
    if use_tpu and train_config.use_bfloat16:
        with tf.contrib.tpu.bfloat16_scope():
            prediction_dict = detection_model.predict(
                preprocessed_images,
                features[fields.InputDataFields.true_image_shape])
            # Cast bfloat16 outputs back to float32 for the loss/postprocess
            # code below.
            for k, v in prediction_dict.items():
                if v.dtype == tf.bfloat16:
                    prediction_dict[k] = tf.cast(v, tf.float32)
    else:
        prediction_dict = detection_model.predict(
            preprocessed_images,
            features[fields.InputDataFields.true_image_shape])
    if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
        detections = detection_model.postprocess(
            prediction_dict, features[fields.InputDataFields.true_image_shape])

    if mode == tf.estimator.ModeKeys.TRAIN:
        if train_config.fine_tune_checkpoint and hparams.load_pretrained:
            if not train_config.fine_tune_checkpoint_type:
                # train_config.from_detection_checkpoint field is deprecated. For
                # backward compatibility, set train_config.fine_tune_checkpoint_type
                # based on train_config.from_detection_checkpoint.
                # Note: this mutates the `train_config` object in place.
                if train_config.from_detection_checkpoint:
                    train_config.fine_tune_checkpoint_type = 'detection'
                else:
                    train_config.fine_tune_checkpoint_type = 'classification'
            asg_map = detection_model.restore_map(
                fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
                load_all_detection_checkpoint_vars=(
                    train_config.load_all_detection_checkpoint_vars))
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    asg_map,
                    train_config.fine_tune_checkpoint,
                    include_global_step=False))
            if use_tpu:
                # On TPU the restore is deferred into the scaffold_fn handed
                # to TPUEstimatorSpec at the end of this function.
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(
                        train_config.fine_tune_checkpoint, available_var_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(
                    train_config.fine_tune_checkpoint, available_var_map)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        losses_dict = detection_model.loss(
            prediction_dict, features[fields.InputDataFields.true_image_shape])
        losses = [loss_tensor for loss_tensor in losses_dict.values()]
        if train_config.add_regularization_loss:
            regularization_losses = detection_model.regularization_losses()
            if regularization_losses:
                regularization_loss = tf.add_n(regularization_losses,
                                               name='regularization_loss')
                losses.append(regularization_loss)
                losses_dict['Loss/regularization_loss'] = regularization_loss
        total_loss = tf.add_n(losses, name='total_loss')
        # Also exposed through losses_dict so EVAL reports it as a metric.
        losses_dict['Loss/total_loss'] = total_loss

        if 'graph_rewriter_config' in configs:
            graph_rewriter_fn = graph_rewriter_builder.build(
                configs['graph_rewriter_config'], is_training=is_training)
            graph_rewriter_fn()

        # TODO(rathodv): Stop creating optimizer summary vars in EVAL mode once we
        # can write learning rate summaries on TPU without host calls.
        global_step = tf.train.get_or_create_global_step()
        training_optimizer, optimizer_summary_vars = optimizer_builder.build(
            train_config.optimizer)

    if mode == tf.estimator.ModeKeys.TRAIN:
        if use_tpu:
            training_optimizer = tf.contrib.tpu.CrossShardOptimizer(
                training_optimizer)

        # Optionally freeze some layers by setting their gradients to be zero.
        trainable_variables = None
        include_variables = (train_config.update_trainable_variables
                             if train_config.update_trainable_variables else None)
        exclude_variables = (train_config.freeze_variables
                             if train_config.freeze_variables else None)
        trainable_variables = tf.contrib.framework.filter_variables(
            tf.trainable_variables(),
            include_patterns=include_variables,
            exclude_patterns=exclude_variables)

        clip_gradients_value = None
        if train_config.gradient_clipping_by_norm > 0:
            clip_gradients_value = train_config.gradient_clipping_by_norm

        if not use_tpu:
            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var)
        summaries = [] if use_tpu else None
        if train_config.summarize_gradients:
            summaries = ['gradients', 'gradient_norm', 'global_gradient_norm']
        train_op = tf.contrib.layers.optimize_loss(
            loss=total_loss,
            global_step=global_step,
            learning_rate=None,
            clip_gradients=clip_gradients_value,
            optimizer=training_optimizer,
            update_ops=detection_model.updates(),
            variables=trainable_variables,
            summaries=summaries,
            name='')  # Preventing scope prefix on all variables.

    if mode == tf.estimator.ModeKeys.PREDICT:
        exported_output = exporter_lib.add_output_tensor_nodes(detections)
        export_outputs = {
            tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
                tf.estimator.export.PredictOutput(exported_output)
        }

    eval_metric_ops = None
    scaffold = None
    if mode == tf.estimator.ModeKeys.EVAL:
        class_agnostic = (fields.DetectionResultFields.detection_classes
                          not in detections)
        groundtruth = _prepare_groundtruth_for_eval(
            detection_model, class_agnostic,
            eval_input_config.max_number_of_boxes)
        use_original_images = fields.InputDataFields.original_image in features
        if use_original_images:
            eval_images = features[fields.InputDataFields.original_image]
            # Keeps the first 3 entries of each row of true_image_shape.
            true_image_shapes = tf.slice(
                features[fields.InputDataFields.true_image_shape], [0, 0],
                [-1, 3])
            original_image_spatial_shapes = features[
                fields.InputDataFields.original_image_spatial_shape]
        else:
            eval_images = features[fields.InputDataFields.image]
            true_image_shapes = None
            original_image_spatial_shapes = None

        eval_dict = eval_util.result_dict_for_batched_example(
            eval_images,
            features[inputs.HASH_KEY],
            detections,
            groundtruth,
            class_agnostic=class_agnostic,
            scale_to_absolute=True,
            original_image_spatial_shapes=original_image_spatial_shapes,
            true_image_shapes=true_image_shapes)

        if class_agnostic:
            category_index = label_map_util.create_class_agnostic_category_index(
            )
        else:
            category_index = label_map_util.create_category_index_from_labelmap(
                eval_input_config.label_map_path)
        # Visualization metrics are only added off-TPU and when original
        # images are available.
        vis_metric_ops = None
        if not use_tpu and use_original_images:
            eval_metric_op_vis = vis_utils.VisualizeSingleFrameDetections(
                category_index,
                max_examples_to_draw=eval_config.num_visualizations,
                max_boxes_to_draw=eval_config.max_num_boxes_to_visualize,
                min_score_thresh=eval_config.min_score_threshold,
                use_normalized_coordinates=False)
            vis_metric_ops = eval_metric_op_vis.get_estimator_eval_metric_ops(
                eval_dict)

        # Eval metrics on a single example.
        eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
            eval_config, list(category_index.values()), eval_dict)
        # Report every loss term as a streaming mean.
        for loss_key, loss_tensor in iter(losses_dict.items()):
            eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
        # Optimizer summary vars are exposed as (value, no-op update) pairs.
        for var in optimizer_summary_vars:
            eval_metric_ops[var.op.name] = (var, tf.no_op())
        if vis_metric_ops is not None:
            eval_metric_ops.update(vis_metric_ops)
        # Metric names are normalized to str.
        eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

        if eval_config.use_moving_averages:
            # Restore the moving-average shadow variables for evaluation.
            variable_averages = tf.train.ExponentialMovingAverage(0.0)
            variables_to_restore = variable_averages.variables_to_restore()
            keep_checkpoint_every_n_hours = (
                train_config.keep_checkpoint_every_n_hours)
            saver = tf.train.Saver(
                variables_to_restore,
                keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours
            )
            scaffold = tf.train.Scaffold(saver=saver)

    # EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
    if use_tpu and mode != tf.estimator.ModeKeys.EVAL:
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            scaffold_fn=scaffold_fn,
            predictions=detections,
            loss=total_loss,
            train_op=train_op,
            eval_metrics=eval_metric_ops,
            export_outputs=export_outputs)
    else:
        # Default scaffold: a sharded saver with relative paths, registered in
        # the SAVERS collection.
        if scaffold is None:
            keep_checkpoint_every_n_hours = (
                train_config.keep_checkpoint_every_n_hours)
            saver = tf.train.Saver(
                sharded=True,
                keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
                save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            scaffold = tf.train.Scaffold(saver=saver)
        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=detections,
                                          loss=total_loss,
                                          train_op=train_op,
                                          eval_metric_ops=eval_metric_ops,
                                          export_outputs=export_outputs,
                                          scaffold=scaffold)
def train_segmentation_model(create_model_fn,
                             create_input_fn,
                             train_config,
                             master,
                             task,
                             is_chief,
                             startup_delay_steps,
                             train_dir,
                             num_clones,
                             num_worker_replicas,
                             num_ps_tasks,
                             clone_on_cpu,
                             replica_id,
                             num_replicas,
                             max_checkpoints_to_keep,
                             save_interval_secs,
                             image_summaries,
                             log_memory=False,
                             gradient_checkpoints=None,
                             sync_bn_accross_gpu=False):
    """Create an instance of the FastSegmentationModel and train it.

    Builds a multi-clone training graph with `model_deploy` and runs
    `slim.learning.train` on it.

    Args:
      create_model_fn: Callable returning a 2-tuple whose second element is
        the segmentation model (used for `restore_map` and
        `shared_feature_extractor_scope`).
      create_input_fn: Input factory forwarded to `create_training_input`.
      train_config: Training config proto (batch size, queue sizes,
        optimizer, fine-tune checkpoint settings, `num_steps`, ...).
      master: TensorFlow master address passed to `slim.learning.train`.
      task: Task index; used as the deploy replica id and to scale
        `startup_delay_steps`.
      is_chief: Whether this worker is the chief.
      startup_delay_steps: Per-task startup delay; multiplied by `task`.
      train_dir: Directory for checkpoints and summaries.
      num_clones: Number of model clones per worker. The per-clone batch
        size is `train_config.batch_size // num_clones` (remainder dropped).
      num_worker_replicas: Number of worker replicas.
      num_ps_tasks: Number of parameter-server tasks.
      clone_on_cpu: Whether clones are placed on CPU.
      replica_id: Unused; `task` is used as the replica id.
      num_replicas: Unused; `num_worker_replicas` is used instead.
      max_checkpoints_to_keep: `max_to_keep` for the checkpoint saver.
      save_interval_secs: Seconds between checkpoint saves.
      image_summaries: If True, add prediction/groundtruth image summaries
        read from hard-coded tensor names in the first clone.
      log_memory: If True, log peak '/gpu:0' memory on logged steps.
      gradient_checkpoints: Forwarded to `create_training_model_losses`.
      sync_bn_accross_gpu: If True, collect UPDATE_OPS from every clone
        instead of only the first one.

    Raises:
      ValueError: If `fine_tune_checkpoint` is set without
        `fine_tune_checkpoint_type`.
    """
    _, segmentation_model = create_model_fn()
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=num_worker_replicas,
        num_ps_tasks=num_ps_tasks)
    startup_delay_steps = task * startup_delay_steps

    per_clone_batch_size = train_config.batch_size // num_clones

    preprocess_fn = None
    if train_config.preprocessor_step:
        preprocess_fn = functools.partial(
            preprocessor_builder.build,
            preprocessor_config_list=train_config.preprocessor_step)

    with tf.Graph().as_default():
        # CPU of common ps server
        with tf.device(deploy_config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

        with tf.device(deploy_config.inputs_device()):  # CPU of each worker
            input_queue = create_training_input(
                create_input_fn,
                preprocess_fn,
                per_clone_batch_size,
                batch_queue_capacity=train_config.batch_queue_capacity,
                batch_queue_threads=train_config.num_batch_queue_threads,
                prefetch_queue_capacity=train_config.prefetch_queue_capacity)

        # Create the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            # Note: it is assumed that any loss created by `model_fn`
            # is collected at the tf.GraphKeys.LOSSES collection.
            model_fn = functools.partial(
                create_training_model_losses,
                create_model_fn=create_model_fn,
                train_config=train_config,
                train_dir=train_dir,
                gradient_checkpoints=gradient_checkpoints)
            clones = model_deploy.create_clones(deploy_config, model_fn,
                                                [input_queue])
            first_clone_scope = deploy_config.clone_scope(0)

            if sync_bn_accross_gpu:
                # Attempt to sync BN updates across all GPU's in a tower.
                # Caution since this is very slow. Might not be needed
                update_ops = []
                for idx in range(num_clones):
                    # Query clone `idx`'s scope; using clone 0 on every
                    # iteration duplicated the first clone's update ops and
                    # skipped all others.
                    nth_clone_scope = deploy_config.clone_scope(idx)
                    update_ops.extend(tf.get_collection(
                        tf.GraphKeys.UPDATE_OPS, nth_clone_scope))
            else:
                # Gather updates from first GPU only
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               first_clone_scope)

        # Init variable to collect summaries
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('Losses/%s' % loss.op.name, loss))

        with tf.device(deploy_config.optimizer_device()):  # CPU of each worker
            (training_optimizer,
             optimizer_summary_vars) = optimizer_builder.build(
                 train_config.optimizer)
            for var in optimizer_summary_vars:
                summaries.add(
                    tf.summary.scalar(var.op.name, var, family='LearningRate'))

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Fine tune from classification or segmentation checkpoints
        trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        # Bound unconditionally: previously `init_fn` only existed when a
        # fine-tune checkpoint was set, so training from scratch crashed at
        # `slim.learning.train(init_fn=init_fn)`.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            if not train_config.fine_tune_checkpoint_type:
                raise ValueError('Must specify `fine_tune_checkpoint_type`.')

            tf.logging.info('Initializing %s model from checkpoint %s',
                            train_config.fine_tune_checkpoint_type,
                            train_config.fine_tune_checkpoint)

            variables_to_restore = segmentation_model.restore_map(
                fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type)

            init_fn = slim.assign_from_checkpoint_fn(
                train_config.fine_tune_checkpoint,
                variables_to_restore,
                ignore_missing_vars=True)

            if train_config.freeze_fine_tune_backbone:
                # Supply the scope so the '%s' placeholder is actually filled.
                tf.logging.info(
                    'Freezing %s scope from checkpoint.',
                    segmentation_model.shared_feature_extractor_scope)
                non_frozen_vars = []
                for var in trainable_vars:
                    if not var.op.name.startswith(
                            segmentation_model.shared_feature_extractor_scope):
                        non_frozen_vars.append(var)
                        tf.logging.info('Training variable: %s', var.op.name)
                trainable_vars = non_frozen_vars
        else:
            tf.logging.info('Not initializing the model from a checkpoint. '
                            'Initializing from scratch!')

        # TODO(@oandrien): we might want to add gradient multiplier here
        # for the last layer if we have trouble with training
        # CPU of common ps server
        with tf.device(deploy_config.optimizer_device()):
            # None lets optimize_clones gather regularization losses itself;
            # [] disables them.
            reg_losses = (None if train_config.add_regularization_loss
                          else [])
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer,
                regularization_losses=reg_losses,
                var_list=trainable_vars)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')
            summaries.add(
                tf.summary.scalar('Losses/TotalLoss', total_loss))

            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_op = tf.identity(total_loss, name='train_op')

        # TODO: this ideally should not be hardcoded like this.
        #   should have a way to access the prediction and GT tensor
        if image_summaries:
            graph = tf.get_default_graph()
            # NOTE(review): 19 presumably is the number of classes; the label
            # ids are stretched over the 0-255 range for display.
            pixel_scaling = max(1, 255 // 19)
            summ_first_clone_scope = (first_clone_scope + '/'
                                      if first_clone_scope else '')
            main_labels = graph.get_tensor_by_name(
                '%sSegmentationLoss/Reshape:0' % summ_first_clone_scope)
            main_preds = graph.get_tensor_by_name(
                '%sSegmentationLoss/Reshape_1:0' % summ_first_clone_scope)
            main_preds = tf.cast(main_preds * pixel_scaling, tf.uint8)
            summaries.add(
                tf.summary.image('VerifyTrainImages/Predictions', main_preds))
            main_labels = tf.cast(main_labels * pixel_scaling, tf.uint8)
            summaries.add(
                tf.summary.image('VerifyTrainImages/Groundtruths', main_labels))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones()
        # or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # NOTE(review): log_device_placement=True logs the placement of every
        # op, which is very verbose.
        session_config = tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=True)

        # Save checkpoints regularly.
        saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)

        # HACK to see memory usage.
        # TODO: Clean up, pretty messy.
        def train_step_mem(sess, train_op, global_step, train_step_kwargs):
            """Runs one traced training step; returns (loss, should_stop)."""
            start_time = time.time()
            run_metadata = tf.RunMetadata()
            # FULL_TRACE fills run_metadata, which peak_memory reads below.
            options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            total_loss, np_global_step = sess.run(
                [train_op, global_step],
                options=options,
                run_metadata=run_metadata)
            time_elapsed = time.time() - start_time

            if 'should_log' in train_step_kwargs:
                if sess.run(train_step_kwargs['should_log']):
                    tf.logging.info(
                        'global step %d: loss = %.4f (%.3f sec/step)',
                        np_global_step, total_loss, time_elapsed)

            if log_memory:
                mem_use = mem_util.peak_memory(run_metadata)['/gpu:0'] / 1e6
                tf.logging.info('Memory used: %.2f MB', (mem_use))

            if 'should_stop' in train_step_kwargs:
                should_stop = sess.run(train_step_kwargs['should_stop'])
            else:
                should_stop = False

            return total_loss, should_stop

        # Main training loop
        slim.learning.train(
            train_op,
            train_step_fn=train_step_mem,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            number_of_steps=train_config.num_steps,
            startup_delay_steps=startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            save_summaries_secs=120,
            save_interval_secs=save_interval_secs,
            saver=saver)
def train(create_tensor_dict_fn, create_model_fn, train_config, input_config,
          master, task, num_clones, worker_replicas, clone_on_cpu, ps_tasks,
          worker_job_name, is_chief, train_dir, save_interval_secs=3600,
          log_every_n_steps=1000):
    """Training function for detection models.

    Args:
      create_tensor_dict_fn: a function to create a tensor input dictionary.
      create_model_fn: a function that creates a DetectionModel and generates
        losses.
      train_config: a train_pb2.TrainConfig protobuf.
      input_config: a input_reader.InputReader protobuf.
      master: BNS name of the TensorFlow master to use.
      task: The task id of this training instance.
      num_clones: The number of clones to run per machine.
      worker_replicas: The number of work replicas to train with. Also used
        as `total_num_replicas` when `train_config.sync_replicas` is set.
      clone_on_cpu: True if clones should be forced to run on CPU.
      ps_tasks: Number of parameter server tasks.
      worker_job_name: Name of the worker job.
      is_chief: Whether this replica is the chief replica.
      train_dir: Directory to write checkpoints and training summaries to.
      save_interval_secs: Interval in seconds to save a check point file.
      log_every_n_steps: The frequency, in terms of global steps, that the loss
        and global step are logged.
    """
    detection_model = create_model_fn()
    preprocess_input_options = [
        preprocessor_input_builder.build(step)
        for step in input_config.preprocess_input_options]
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            input_queue = _create_input_queue(
                train_config.batch_size // num_clones,
                create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options,
                preprocess_input_options)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        # Rotated-box models use a dedicated loss builder.
        if detection_model.is_rbbox:
            model_fn = functools.partial(_create_losses_rbbox,
                                         create_model_fn=create_model_fn)
        else:
            model_fn = functools.partial(_create_losses,
                                         create_model_fn=create_model_fn)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer = optimizer_builder.build(
                train_config.optimizer, global_summaries)

        sync_optimizer = None
        if train_config.sync_replicas:
            # `tf.SyncReplicasOptimizer` is not a TF1 top-level symbol; the
            # class lives in `tf.train`. The replica count comes from this
            # function's `worker_replicas` argument, the same value given to
            # the DeploymentConfig above.
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            var_map = detection_model.restore_map(
                from_detection_checkpoint=train_config.from_detection_checkpoint)
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map, train_config.fine_tune_checkpoint))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                """Restores the checkpoint variables present in the model."""
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        with tf.device(deploy_config.optimizer_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer, regularization_losses=None)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(tf.summary.histogram(model_var.op.name,
                                                      model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(tf.summary.scalar(loss_tensor.op.name,
                                                   loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                           first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            max_to_keep=None,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            log_every_n_steps=log_every_n_steps,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(
                train_config.num_steps if train_config.num_steps else None),
            save_summaries_secs=240,
            save_interval_secs=save_interval_secs,
            sync_optimizer=sync_optimizer,
            saver=saver)
def train_segmentation_model(create_model_fn,
                             create_input_fn,
                             train_config,
                             model_config,
                             master,
                             task,
                             is_chief,
                             startup_delay_steps,
                             train_dir,
                             num_clones,
                             num_worker_replicas,
                             num_ps_tasks,
                             clone_on_cpu,
                             replica_id,
                             num_replicas,
                             max_checkpoints_to_keep,
                             save_interval_secs,
                             image_summaries,
                             log_memory=False,
                             gradient_checkpoints=None,
                             sync_bn_accross_gpu=False):
    """Build a multi-clone segmentation training graph and run it with slim.

    Creates the segmentation model, replicates it across `num_clones` clones
    via `model_deploy`, builds the optimizer and gradient-update ops, adds
    loss / learning-rate / optional image summaries, optionally restores
    weights from `train_config.fine_tune_checkpoint`, and finally hands the
    resulting train op to `slim.learning.train`.

    Args:
        create_model_fn: Callable whose result unpacks into two values; the
            second is the segmentation model (its `restore_map` and
            `shared_feature_extractor_scope` are used here).
        create_input_fn: Input factory forwarded to `create_training_input`.
        train_config: Training config (batch size, optimizer, fine-tuning,
            regularization and step settings are read from it).
        model_config: Model config; only `pspnet.train_reduce` is read.
        master: TF master address passed to `slim.learning.train`.
        task: Worker task index; used as the deployment replica id and to
            scale `startup_delay_steps`.
        is_chief: Whether this worker is the chief.
        startup_delay_steps: Per-task startup delay; multiplied by `task`.
        train_dir: Directory for checkpoints and summaries.
        num_clones: Number of model clones per worker.
        num_worker_replicas: Number of worker replicas (deployment config).
        num_ps_tasks: Number of parameter-server tasks.
        clone_on_cpu: Whether to place clones on CPU.
        replica_id: Not used by this function (`task` is used instead).
        num_replicas: Not used by this function (`num_worker_replicas` is
            used instead).
        max_checkpoints_to_keep: `max_to_keep` for the checkpoint saver.
        save_interval_secs: Seconds between checkpoint saves.
        image_summaries: If true, add input / prediction / groundtruth image
            summaries for the first clone.
        log_memory: If true, every step is run with a full trace and peak
            GPU memory is logged whenever a step is logged.
        gradient_checkpoints: Forwarded to `create_training_model_losses`.
        sync_bn_accross_gpu: If true, collect UPDATE_OPS (e.g. batch-norm
            updates) from every clone instead of only the first one.

    Raises:
        ValueError: If `fine_tune_checkpoint` is set without
            `fine_tune_checkpoint_type`.
    """
    _, segmentation_model = create_model_fn()
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=num_worker_replicas,
        num_ps_tasks=num_ps_tasks)
    # Stagger worker start-up: worker `task` waits task * startup_delay_steps.
    startup_delay_steps = task * startup_delay_steps

    # NOTE(review): the full batch size is used (division by num_clones is
    # disabled); presumably create_training_input, which also receives
    # num_clones, handles the per-clone split -- TODO confirm.
    per_clone_batch_size = train_config.batch_size

    preprocess_fn = None
    if train_config.preprocessor_step:
        preprocess_fn = functools.partial(
            preprocessor_builder.build,
            preprocessor_config_list=train_config.preprocessor_step)

    with tf.Graph().as_default():
        # CPU of common ps server
        with tf.device(deploy_config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

        with tf.device(deploy_config.inputs_device()):  # CPU of each worker
            dataset = create_training_input(create_input_fn, preprocess_fn,
                                            per_clone_batch_size, num_clones)
            # Skip input elements whose pipeline raises instead of aborting.
            dataset = dataset.apply(tf.data.experimental.ignore_errors())
            data_iterator = dataset.make_one_shot_iterator()

        # Build the clones on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            # Note: it is assumed that any loss created by `model_fn`
            # is collected at the tf.GraphKeys.LOSSES collection.
            model_fn = functools.partial(
                create_training_model_losses,
                create_model_fn=create_model_fn,
                train_config=train_config,
                train_dir=train_dir,
                gradient_checkpoints=gradient_checkpoints)
            clones = model_deploy.create_clones(deploy_config, model_fn,
                                                [data_iterator.get_next])
            first_clone_scope = deploy_config.clone_scope(0)

            if sync_bn_accross_gpu:
                # Attempt to sync BN updates across all GPUs in a tower.
                # Caution since this is very slow. Might not be needed.
                update_ops = []
                for clone_idx in range(num_clones):
                    clone_scope = deploy_config.clone_scope(clone_idx)
                    update_ops.extend(
                        tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                          clone_scope))
            else:
                # Gather updates from the first GPU only.
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               first_clone_scope)

        # Init set to collect summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('Losses/%s' % loss.op.name, loss))

        with tf.device(deploy_config.optimizer_device()):  # CPU of each worker
            (training_optimizer,
             optimizer_summary_vars) = optimizer_builder.build(
                 train_config.optimizer, num_clones)
            for var in optimizer_summary_vars:
                summaries.add(
                    tf.summary.scalar(var.op.name, var, family='LearningRate'))

        # Per-model-variable histogram summaries are not added.

        # Fine tune from classification or segmentation checkpoints.
        trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if train_config.fine_tune_checkpoint:
            if not train_config.fine_tune_checkpoint_type:
                raise ValueError('Must specify `fine_tune_checkpoint_type`.')

            tf.logging.info('Initializing %s model from checkpoint %s',
                            train_config.fine_tune_checkpoint_type,
                            train_config.fine_tune_checkpoint)

            variables_to_restore = segmentation_model.restore_map(
                train_config.fine_tune_checkpoint,
                fine_tune_checkpoint_type=(
                    train_config.fine_tune_checkpoint_type))

            # NOTE(review): opens and immediately closes a summary writer on
            # train_dir; presumably only to make sure the event directory
            # exists before training starts -- confirm.
            writer = tf.summary.FileWriter(train_dir)
            writer.close()

            init_fn = slim.assign_from_checkpoint_fn(
                train_config.fine_tune_checkpoint,
                variables_to_restore,
                ignore_missing_vars=True)

            if train_config.freeze_fine_tune_backbone:
                backbone_scope = (
                    segmentation_model.shared_feature_extractor_scope)
                # The scope must be passed as an argument; without it the
                # literal '%s' ends up in the log.
                tf.logging.info('Freezing %s scope from checkpoint.',
                                backbone_scope)
                # Keep only variables outside the shared feature extractor
                # so the optimizer never updates the backbone.
                non_frozen_vars = []
                for var in trainable_vars:
                    if not var.op.name.startswith(backbone_scope):
                        non_frozen_vars.append(var)
                        tf.logging.info('Training variable: %s', var.op.name)
                trainable_vars = non_frozen_vars
        else:
            init_fn = None
            tf.logging.info('Not initializing the model from a checkpoint. '
                            'Initializing from scratch!')

        # TODO(@oandrien): we might want to add gradient multiplier here
        # for the last layer if we have trouble with training
        # CPU of common ps server
        with tf.device(deploy_config.optimizer_device()):
            # Per model_deploy.optimize_clones: None gathers all collected
            # regularization losses, [] excludes them.
            reg_losses = (None if train_config.add_regularization_loss
                          else [])
            if model_config.pspnet.train_reduce and reg_losses is None:
                # Restrict regularization to losses whose name contains
                # "dim_reduce".
                regularization_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                reg_losses = [
                    r for r in regularization_losses
                    if "dim_reduce" in r.name
                ]

            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer,
                regularization_losses=reg_losses,
                var_list=trainable_vars)
            # NOTE: no tf.check_numerics on total_loss and no gradient
            # clipping are applied here.
            summaries.add(tf.summary.scalar('Losses/TotalLoss', total_loss))

            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            # train_op yields the loss, but only after every update op ran.
            with tf.control_dependencies([update_op]):
                train_op = tf.identity(total_loss, name='train_op')

        # TODO: this ideally should not be hardcoded like this.
        # should have a way to access the prediction and GT tensor
        if image_summaries:
            graph = tf.get_default_graph()
            # Spread label ids over the uint8 range so they are visible.
            # NOTE(review): 19 is hard-coded, presumably the class count
            # -- confirm against the model config.
            pixel_scaling = max(1, 255 // 19)
            summ_first_clone_scope = (first_clone_scope + '/'
                                      if first_clone_scope else '')
            input_img = graph.get_tensor_by_name(
                '%sInputs:0' % summ_first_clone_scope)
            main_labels = graph.get_tensor_by_name(
                '%sSegmentationLoss/ScaledLabels:0' % summ_first_clone_scope)
            main_preds = graph.get_tensor_by_name(
                '%sSegmentationLoss/ScaledPreds:0' % summ_first_clone_scope)
            summaries.add(
                tf.summary.image('VerifyTrainImages/Inputs', input_img))
            # argmax over the last axis (presumably per-class logits) gives
            # a label map; expand_dims restores a channel axis for display.
            main_preds = tf.cast(
                tf.expand_dims(tf.argmax(main_preds, -1), -1) * pixel_scaling,
                tf.uint8)
            summaries.add(
                tf.summary.image('VerifyTrainImages/Predictions', main_preds))
            main_labels = tf.cast(main_labels * pixel_scaling, tf.uint8)
            summaries.add(
                tf.summary.image('VerifyTrainImages/Groundtruths',
                                 main_labels))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones()
        # or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        session_config.gpu_options.allow_growth = True

        # Save checkpoints regularly.
        saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)

        def train_step_mem(sess, train_op, global_step, train_step_kwargs):
            """Run one training step, optionally logging peak GPU memory.

            Used as slim's `train_step_fn`: runs `train_op`, logs loss and
            step time when `should_log` fires, and returns
            `(total_loss, should_stop)`.
            """
            start_time = time.time()
            if log_memory:
                # Collect full trace metadata for the memory report below.
                run_metadata = tf.RunMetadata()
                options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            else:
                run_metadata = None
                options = None
            # Only fetch what is needed: fetching grads_and_vars here would
            # copy every gradient and variable back to the host each step.
            # dist_builder.DEBUG stays in the fetch list so whatever it
            # holds is still evaluated.
            step_loss, np_global_step, _ = sess.run(
                [train_op, global_step, dist_builder.DEBUG],
                options=options,
                run_metadata=run_metadata)
            time_elapsed = time.time() - start_time

            if 'should_log' in train_step_kwargs:
                if sess.run(train_step_kwargs['should_log']):
                    tf.logging.info(
                        'global step %d: loss = %.4f (%.3f sec/step)',
                        np_global_step, step_loss, time_elapsed)
                    if log_memory:
                        peaks = mem_util.peak_memory(run_metadata)
                        for device_name in peaks:
                            if "/gpu" in device_name:
                                tf.logging.info('Memory used (%s): %.2f MB',
                                                device_name,
                                                peaks[device_name] / 1e6)

            if 'should_stop' in train_step_kwargs:
                should_stop = sess.run(train_step_kwargs['should_stop'])
            else:
                should_stop = False

            return step_loss, should_stop

        # Main training loop.
        slim.learning.train(train_op,
                            train_step_fn=train_step_mem,
                            logdir=train_dir,
                            master=master,
                            is_chief=is_chief,
                            session_config=session_config,
                            number_of_steps=train_config.num_steps,
                            startup_delay_steps=startup_delay_steps,
                            init_fn=init_fn,
                            summary_op=summary_op,
                            save_summaries_secs=60,
                            save_interval_secs=save_interval_secs,
                            saver=saver)