Example #1
 def testBuildEmptyOptimizer(self):
   optimizer_text_proto = """
   """
   optimizer_proto = optimizer_pb2.Optimizer()
   text_format.Merge(optimizer_text_proto, optimizer_proto)
   with self.assertRaises(ValueError):
     optimizer_builder.build(optimizer_proto)
Example #2
 def testBuildEmptyOptimizer(self):
     optimizer_text_proto = """
 """
     global_summaries = set([])
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
     with self.assertRaises(ValueError):
         optimizer_builder.build(optimizer_proto, global_summaries)
Example #3
def test_train_cfg(cfg_file):
    with open(cfg_file) as f:
        cfg = json.load(f)

    cfg = build_config(cfg)
    print(cfg)
    train_cfg = cfg['training']
    dataloader_cfg = train_cfg['dataloader']
    model_cfg = train_cfg['model']
    optimizer_cfg = train_cfg['optimizer']
    loss_cfg = train_cfg['losses']
    scheduler_cfg = train_cfg['scheduler']
    device = torch.device('cuda')
    dataloader = dataloader_builder(dataloader_cfg)
    dataset = dataloader.dataset
    model = model_builder.build(model_cfg, dataset.info)

    optimizer = optimizer_builder.build(optimizer_cfg, model.parameters())

    #optimizer = torch.optim.SGD(model.parameters(), lr=eps0, momentum=0.9, weight_decay=5e-4)
    lr_scheduler = scheduler_builder.build(scheduler_cfg, optimizer)
    loss = loss_builder.build(loss_cfg)
    file_logger = log.get_file_logger()
    model = torch.nn.DataParallel(model)
    # new experiment
    model = model.train()
    trained_models = []
    while lr_scheduler.run:
        lr_scheduler.step()
        for batch_id, (data, split_info) in enumerate(dataloader):
            #print(data)
            optimizer.zero_grad()
            data['imgs'] = data['img'].to(device)
            print("imgs", data['imgs'])
            imgs = Variable(data['img'], requires_grad=True)
            endpoints = model(imgs, model.module.endpoints)
            # Theoretically, the losses could also be calculated in a distributed fashion.
            losses = loss(endpoints, data, split_info)
            print("losses", losses)
            print(torch.mean(losses))
            loss_mean = torch.mean(losses)
            loss_mean.backward()
            optimizer.step()
        break
    path = file_logger.save_checkpoint(model, optimizer,
                                       lr_scheduler.last_epoch)
    if path:
        trained_models.append(path)
    file_logger.close()
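The test_train_cfg example above obtains its optimizer from a project-specific optimizer_builder.build(optimizer_cfg, model.parameters()); the builder itself is not part of this listing. A minimal sketch of what such a config-driven PyTorch builder could look like, assuming a config dict of the form {'name': ..., plus optimizer keyword arguments} (the actual schema may differ):

import torch


def build(optimizer_cfg, parameters):
    """Hypothetical sketch: map a config dict onto a torch.optim optimizer."""
    cfg = dict(optimizer_cfg)        # copy so the caller's config stays untouched
    name = cfg.pop('name').lower()   # e.g. 'sgd' or 'adam'; 'name' is an assumed key
    if name == 'sgd':
        return torch.optim.SGD(parameters, **cfg)
    if name == 'adam':
        return torch.optim.Adam(parameters, **cfg)
    raise ValueError('Unknown optimizer: {}'.format(name))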
Example #4
 def testBuildAdamOptimizer(self):
   optimizer_text_proto = """
     adam_optimizer: {
       learning_rate: {
         constant_learning_rate {
           learning_rate: 0.002
         }
       }
     }
     use_moving_average: false
   """
   optimizer_proto = optimizer_pb2.Optimizer()
   text_format.Merge(optimizer_text_proto, optimizer_proto)
   optimizer, _ = optimizer_builder.build(optimizer_proto)
   self.assertTrue(isinstance(optimizer, tf.train.AdamOptimizer))
Example #5
 def testBuildMomentumOptimizer(self):
   optimizer_text_proto = """
     momentum_optimizer: {
       learning_rate: {
         constant_learning_rate {
           learning_rate: 0.001
         }
       }
       momentum_optimizer_value: 0.99
     }
     use_moving_average: false
   """
   optimizer_proto = optimizer_pb2.Optimizer()
   text_format.Merge(optimizer_text_proto, optimizer_proto)
   optimizer, _ = optimizer_builder.build(optimizer_proto)
   self.assertTrue(isinstance(optimizer, tf.train.MomentumOptimizer))
Example #6
 def testBuildMovingAverageOptimizer(self):
     optimizer_text_proto = """
   adam_optimizer: {
     learning_rate: {
       constant_learning_rate {
         learning_rate: 0.002
       }
     }
   }
   use_moving_average: True
 """
     global_summaries = set([])
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
     optimizer = optimizer_builder.build(optimizer_proto, global_summaries)
     self.assertTrue(
         isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer))
Example #7
 def testBuildMovingAverageOptimizerWithNonDefaultDecay(self):
   optimizer_text_proto = """
     adam_optimizer: {
       learning_rate: {
         constant_learning_rate {
           learning_rate: 0.002
         }
       }
     }
     use_moving_average: True
     moving_average_decay: 0.2
   """
   optimizer_proto = optimizer_pb2.Optimizer()
   text_format.Merge(optimizer_text_proto, optimizer_proto)
   optimizer, _ = optimizer_builder.build(optimizer_proto)
   self.assertTrue(
       isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer))
   # TODO: Find a way to not depend on the private members.
   self.assertAlmostEqual(optimizer._ema._decay, 0.2)
Example #8
 def testBuildRMSPropOptimizer(self):
   optimizer_text_proto = """
     rms_prop_optimizer: {
       learning_rate: {
         exponential_decay_learning_rate {
           initial_learning_rate: 0.004
           decay_steps: 800720
           decay_factor: 0.95
         }
       }
       momentum_optimizer_value: 0.9
       decay: 0.9
       epsilon: 1.0
     }
     use_moving_average: false
   """
   optimizer_proto = optimizer_pb2.Optimizer()
   text_format.Merge(optimizer_text_proto, optimizer_proto)
   optimizer, _ = optimizer_builder.build(optimizer_proto)
   self.assertTrue(isinstance(optimizer, tf.train.RMSPropOptimizer))
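Read together, the tests above pin down the contract of the TensorFlow optimizer_builder.build: it dispatches on the oneof field of the Optimizer proto, optionally wraps the result in tf.contrib.opt.MovingAverageOptimizer when use_moving_average is set, and raises ValueError for an empty config. Note that the listing mixes two builder versions, one returning only the optimizer (and taking a global_summaries set) and one returning (optimizer, summary_vars). A condensed sketch of the dispatch for the latter variant, with _create_learning_rate standing in for a learning-rate helper that is not shown here:

import tensorflow as tf


def build(optimizer_config):
    """Sketch of the dispatch behaviour the tests above exercise."""
    optimizer_type = optimizer_config.WhichOneof('optimizer')
    summary_vars = []

    if optimizer_type == 'rms_prop_optimizer':
        config = optimizer_config.rms_prop_optimizer
        learning_rate = _create_learning_rate(config.learning_rate)  # assumed helper
        summary_vars.append(learning_rate)
        optimizer = tf.train.RMSPropOptimizer(
            learning_rate,
            decay=config.decay,
            momentum=config.momentum_optimizer_value,
            epsilon=config.epsilon)
    elif optimizer_type == 'momentum_optimizer':
        config = optimizer_config.momentum_optimizer
        learning_rate = _create_learning_rate(config.learning_rate)  # assumed helper
        summary_vars.append(learning_rate)
        optimizer = tf.train.MomentumOptimizer(
            learning_rate, momentum=config.momentum_optimizer_value)
    elif optimizer_type == 'adam_optimizer':
        config = optimizer_config.adam_optimizer
        learning_rate = _create_learning_rate(config.learning_rate)  # assumed helper
        summary_vars.append(learning_rate)
        optimizer = tf.train.AdamOptimizer(learning_rate)
    else:
        raise ValueError('Optimizer %s not supported.' % optimizer_type)

    if optimizer_config.use_moving_average:
        optimizer = tf.contrib.opt.MovingAverageOptimizer(
            optimizer, average_decay=optimizer_config.moving_average_decay)

    return optimizer, summary_vars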
Example #9
def train(create_tensor_dict_fn_list, create_model_fn, train_config, master,
          task, num_clones, worker_replicas, clone_on_cpu, ps_tasks,
          worker_job_name, is_chief, train_dir):
    """Training function for models.
  Args:
    create_tensor_dict_fn_list: a list of functions, each creating a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of worker replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
  """
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = tf.train.create_global_step()

        with tf.device(deploy_config.inputs_device()), \
             tf.name_scope('Input'):
            input_queue_list = []
            for i, create_tensor_dict_fn in enumerate(
                    create_tensor_dict_fn_list):
                input_queue_list.append(
                    _create_input_queue(
                        train_config.batch_size[i] // num_clones,
                        create_tensor_dict_fn,
                        train_config.batch_queue_capacity,
                        train_config.num_batch_queue_threads,
                        train_config.prefetch_queue_capacity,
                        data_augmentation_options))

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue_list])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()), \
             tf.name_scope('Optimizer'):
            training_optimizer = optimizer_builder.build(
                train_config.optimizer, global_summaries)

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        '''
    if train_config.fine_tune_checkpoint:
      var_map = detection_model.restore_map(
        from_detection_checkpoint=train_config.from_detection_checkpoint
      )
      available_var_map = variables_helper.get_variables_available_in_checkpoint(
        var_map,
        train_config.fine_tune_checkpoint
      )
      init_saver = tf.train.Saver(available_var_map)
      def initializer_fn(sess):
        init_saver.restore(sess, train_config.fine_tune_checkpoint)
      init_fn = initializer_fn
     '''
        if train_config.fine_tune_checkpoint:
            all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            restore_vars = [
                var for var in all_vars
                if (var.name.split('/')[0] == 'FeatureExtractor'
                    and var.name.split('/')[1] == 'Convnet')
            ]
            pre_train_saver = tf.train.Saver(restore_vars)

            def load_pretrain(scaffold, sess):
                pre_train_saver.restore(sess,
                                        train_config.fine_tune_checkpoint)
        else:
            load_pretrain = None

        with tf.device(deploy_config.optimizer_device()), \
             tf.variable_scope('OptimizeClones'):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer, regularization_losses=None)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = [r'.*bias(?:es)?', r'.*beta']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = tf.contrib.training.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for (grad, var) in grads_and_vars:
            var_name = var.op.name
            grad_name = 'grad/' + var_name
            global_summaries.add(tf.summary.histogram(grad_name, grad))
            global_summaries.add(tf.summary.histogram(var_name, var))
        # for model_var in tf.contrib.framework.get_model_variables():
        #   global_summaries.add(tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        scaffold = tf.train.Scaffold(init_fn=load_pretrain,
                                     summary_op=summary_op,
                                     saver=saver)
        stop_hook = tf.train.StopAtStepHook(
            num_steps=(train_config.num_steps
                       if train_config.num_steps else None), )
        profile_hook = profile_session_run_hooks.ProfileAtStepHook(
            at_step=200, checkpoint_dir=train_dir)
        tf.contrib.training.train(
            train_tensor,
            train_dir,
            master=master,
            is_chief=is_chief,
            scaffold=scaffold,
            hooks=[stop_hook, profile_hook],
            chief_only_hooks=None,
            save_checkpoint_secs=train_config.save_checkpoint_secs,
            save_summaries_steps=train_config.save_summaries_steps,
            config=session_config)
Example #10
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
    """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of worker replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    graph_hook_fn: Optional function that is called after the inference graph is
      built (before optimization). This is helpful to perform additional changes
      to the training graph such as adding FakeQuant ops. The function should
      modify the default graph.

  Raises:
    ValueError: If both num_clones > 1 and train_config.sync_replicas is true.
  """

    detection_model = create_model_fn()
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        if num_clones != 1 and train_config.sync_replicas:
            raise ValueError('In Synchronous SGD mode num_clones must '
                             'be 1. Found num_clones: {}'.format(num_clones))
        batch_size = train_config.batch_size // num_clones
        if train_config.sync_replicas:
            batch_size //= train_config.replicas_to_aggregate

        with tf.device(deploy_config.inputs_device()):
            input_queue = create_input_queue(
                batch_size, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

        # Gather initial summaries.
        # TODO(rathodv): See if summaries can be added/extracted from global tf
        # collections so that they don't have to be passed around.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn,
                                     train_config=train_config)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        if graph_hook_fn:
            with tf.device(deploy_config.variables_device()):
                graph_hook_fn()

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)
            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var, family='LearningRate')

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=worker_replicas)
            sync_optimizer = training_optimizer

        with tf.device(deploy_config.optimizer_device()):
            regularization_losses = (
                None if train_config.add_regularization_loss else [])
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones,
                training_optimizer,
                regularization_losses=regularization_losses)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram('ModelVars/' + model_var.op.name,
                                     model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar('Losses/' + loss_tensor.op.name,
                                  loss_tensor))
        global_summaries.add(
            tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            if not train_config.fine_tune_checkpoint_type:
                # train_config.from_detection_checkpoint field is deprecated. For
                # backward compatibility, fine_tune_checkpoint_type is set based on
                # from_detection_checkpoint.
                if train_config.from_detection_checkpoint:
                    train_config.fine_tune_checkpoint_type = 'detection'
                else:
                    train_config.fine_tune_checkpoint_type = 'classification'
            var_map = detection_model.restore_map(
                fine_tune_checkpoint_type=train_config.
                fine_tune_checkpoint_type,
                load_all_detection_checkpoint_vars=(
                    train_config.load_all_detection_checkpoint_vars))
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map,
                    train_config.fine_tune_checkpoint,
                    include_global_step=False))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
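The docstring above describes train()'s distributed-training parameters. For orientation, a hedged single-machine invocation sketch matching the signature of this variant (all concrete values are assumptions; create_input_dict_fn, input_config, model_builder and the config protos come from the surrounding object-detection codebase and are not defined in this listing):

import functools

# Hypothetical single-machine call; every argument value here is an assumption.
train(create_tensor_dict_fn=functools.partial(create_input_dict_fn, input_config),
      create_model_fn=functools.partial(model_builder.build,
                                        model_config=model_config,
                                        is_training=True),
      train_config=train_config,
      master='',
      task=0,
      num_clones=1,
      worker_replicas=1,
      clone_on_cpu=False,
      ps_tasks=0,
      worker_job_name='lonely_worker',
      is_chief=True,
      train_dir='/tmp/train_dir')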
Example #11
def run_train(dataloader_cfg, model_cfg, scheduler_cfg, optimizer_cfg,
              loss_cfg, validation_cfg, checkpoint_frequency,
              restore_checkpoint, max_epochs, _run):

    # Lets cuDNN benchmark conv implementations and choose the fastest.
    # Only good if sizes stay the same within the main loop!
    torch.backends.cudnn.benchmark = True
    exit_handler = ExitHandler()

    device = _run.config['device']
    device_id = _run.config['device_id']

    # during training just one dataloader
    dataloader = dataloader_builder.build(dataloader_cfg)[0]

    epoch = 0
    if restore_checkpoint is not None:
        model_cfg, optimizer_cfg, epoch = utils.restore_checkpoint(
            restore_checkpoint, model_cfg, optimizer_cfg)

    def overwrite(to_overwrite, dic):
        to_overwrite.update(dic)
        return to_overwrite

    # some models depend on dataset, for example num_joints
    model_cfg = overwrite(dataloader.dataset.info, model_cfg)
    model = model_builder.build(model_cfg)

    loss_cfg['model'] = model
    loss = loss_builder.build(loss_cfg)
    loss = loss.to(device)

    parameters = list(model.parameters()) + list(loss.parameters())
    optimizer = optimizer_builder.build(optimizer_cfg, parameters)

    lr_scheduler = scheduler_builder.build(scheduler_cfg, optimizer, epoch)

    if validation_cfg is None:
        validation_dataloaders = None
    else:
        validation_dataloaders = dataloader_builder.build(validation_cfg)
        keep = False

    file_logger = log.get_file_logger()
    logger = log.get_logger()

    model = torch.nn.DataParallel(model, device_ids=device_id)
    model.cuda()

    model = model.train()
    trained_models = []

    exit_handler.register(file_logger.save_checkpoint, model, optimizer,
                          "atexit", model_cfg)

    start_training_time = time.time()
    end = time.time()
    while epoch < max_epochs:
        epoch += 1
        lr_scheduler.step()
        logger.info("Starting Epoch %d/%d", epoch, max_epochs)
        len_batch = len(dataloader)
        acc_time = 0
        for batch_id, data in enumerate(dataloader):
            optimizer.zero_grad()
            endpoints = model(data, model.module.endpoints)
            logger.debug("datasets %s", list(data['split_info'].keys()))

            data.update(endpoints)
            # Theoretically, the losses could also be calculated in a distributed fashion.
            losses = loss(endpoints, data)
            loss_mean = torch.mean(losses)
            loss_mean.backward()
            optimizer.step()

            acc_time += time.time() - end
            end = time.time()

            report_after_batch(_run=_run,
                               logger=logger,
                               batch_id=batch_id,
                               batch_len=len_batch,
                               acc_time=acc_time,
                               loss_mean=loss_mean,
                               max_mem=torch.cuda.max_memory_allocated())

        if epoch % checkpoint_frequency == 0:
            path = file_logger.save_checkpoint(model, optimizer, epoch,
                                               model_cfg)
            trained_models.append(path)

        report_after_epoch(_run=_run, epoch=epoch, max_epoch=max_epochs)

        if validation_dataloaders is not None and \
                epoch % checkpoint_frequency == 0:
            model.eval()

            # Lets cuDNN benchmark conv implementations and choose the fastest.
            # Only good if sizes stay the same within the main loop!
            # not the case for segmentation
            torch.backends.cudnn.benchmark = False
            score = evaluate(validation_dataloaders, model, epoch, keep=keep)
            logger.info(score)
            log_score(score, _run, prefix="val_", step=epoch)
            torch.backends.cudnn.benchmark = True
            model.train()

    report_after_training(_run=_run,
                          max_epoch=max_epochs,
                          total_time=time.time() - start_training_time)
    path = file_logger.save_checkpoint(model, optimizer, epoch, model_cfg)
    if path:
        trained_models.append(path)
    file_logger.close()
    # TODO get best performing val model
    evaluate_last = _run.config['training'].get('evaluate_last', 1)
    if len(trained_models) < evaluate_last:
        logger.info("Only saved %d models (evaluate_last=%d)",
                    len(trained_models), evaluate_last)
    return trained_models[-evaluate_last:]
Example #12
def _general_model_fn(features, pipeline_config, result_folder, dataset_info,
                      feature_extractor, mode, num_gpu,
                      visualization_file_names, eval_dir):
    num_classes = pipeline_config.dataset.num_classes
    add_background_class = pipeline_config.train_config.loss.name == 'softmax'
    if add_background_class:
        assert (num_classes == 1)
        num_classes += 1

    image_batch = features[standard_fields.InputDataFields.image_decoded]

    if mode == tf.estimator.ModeKeys.PREDICT:
        annotation_mask_batch = None
    else:
        annotation_mask_batch = features[
            standard_fields.InputDataFields.annotation_mask]

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Augment images
        image_batch, annotation_mask_batch = preprocessor.apply_data_augmentation(
            pipeline_config.train_config.data_augmentation_options,
            images=image_batch,
            gt_masks=annotation_mask_batch,
            batch_size=pipeline_config.train_config.batch_size)

    # General preprocessing
    image_batch_preprocessed = preprocessor.preprocess(
        image_batch,
        pipeline_config.dataset.val_range,
        scale_input=pipeline_config.dataset.scale_input)

    network_output = feature_extractor.build_network(
        image_batch_preprocessed,
        is_training=mode == tf.estimator.ModeKeys.TRAIN,
        num_classes=num_classes,
        use_batch_norm=pipeline_config.model.use_batch_norm,
        bn_momentum=pipeline_config.model.batch_norm_momentum,
        bn_epsilon=pipeline_config.model.batch_norm_epsilon,
        activation_fn=activation_fn_builder.build(pipeline_config.model))

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Record model variable summaries
        for var in tf.trainable_variables():
            tf.summary.histogram('ModelVars/' + var.op.name, var)

    network_output_shape = network_output.get_shape().as_list()
    if mode != tf.estimator.ModeKeys.PREDICT:
        if (network_output_shape[1:3] !=
                annotation_mask_batch.get_shape().as_list()[1:3]):
            annotation_mask_batch = image_utils.central_crop(
                annotation_mask_batch,
                desired_size=network_output.get_shape().as_list()[1:3])

        annotation_mask_batch = tf.cast(tf.clip_by_value(
            annotation_mask_batch, 0, 1),
                                        dtype=tf.int64)

        assert (len(annotation_mask_batch.get_shape()) == 4)
        assert (annotation_mask_batch.get_shape().as_list()[:3] ==
                network_output.get_shape().as_list()[:3])

    # We should not apply the weighted loss during evaluation. It would just
    # make our loss minimal for the f2 score, but we get the same optimum if
    # we simply optimize for the f1 score.
    if (pipeline_config.train_config.loss.use_weighted
            and mode == tf.estimator.ModeKeys.TRAIN):
        patient_ratio = dataset_info[
            standard_fields.PickledDatasetInfo.patient_ratio]
        cancer_pixels = tf.reduce_sum(tf.to_float(annotation_mask_batch))
        healthy_pixels = tf.to_float(
            tf.size(annotation_mask_batch)) - cancer_pixels

        batch_pixel_ratio = tf.div(healthy_pixels, cancer_pixels + 1.0)

        loss_weight = (
            ((batch_pixel_ratio * patient_ratio) +
             pipeline_config.train_config.loss.weight_constant_add) *
            pipeline_config.train_config.loss.weight_constant_multiply)
    else:
        loss_weight = tf.constant(1.0)

    if mode == tf.estimator.ModeKeys.PREDICT:
        loss = None
    else:
        loss = _loss(tf.reshape(annotation_mask_batch, [-1]),
                     tf.reshape(network_output, [-1, num_classes]),
                     loss_name=pipeline_config.train_config.loss.name,
                     pos_weight=loss_weight)
        loss = tf.identity(loss, name='ModelLoss')
        tf.summary.scalar(loss.op.name, loss, family='Loss')

        total_loss = tf.identity(loss, name='TotalLoss')

        if mode == tf.estimator.ModeKeys.TRAIN:
            if pipeline_config.train_config.add_regularization_loss:
                regularization_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if regularization_losses:
                    regularization_loss = tf.add_n(regularization_losses,
                                                   name='RegularizationLoss')
                    total_loss = tf.add_n([loss, regularization_loss],
                                          name='TotalLoss')
                    tf.summary.scalar(regularization_loss.op.name,
                                      regularization_loss,
                                      family='Loss')

        tf.summary.scalar(total_loss.op.name, total_loss, family='Loss')
        total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

    scaffold = None
    update_ops = []
    if mode == tf.estimator.ModeKeys.TRAIN:
        if pipeline_config.train_config.optimizer.use_moving_average:
            # EMAs are currently not supported with tf's DistributionStrategy.
            # Re-enable this once the underlying bugs are fixed.
            logging.warn(
                'EMA is currently not supported with tf DistributionStrategy.')
            exit(1)
            pipeline_config.train_config.optimizer.use_moving_average = False
            # The swapping saver will swap the trained variables with their moving
            # averages before saving, thus removing the need to care for moving
            # averages during evaluation
            # scaffold = tf.train.Scaffold(saver=optimizer.swapping_saver())

        optimizer, optimizer_summary_vars = optimizer_builder.build(
            pipeline_config.train_config.optimizer)
        for var in optimizer_summary_vars:
            tf.summary.scalar(var.op.name, var, family='LearningRate')

        grads_and_vars = optimizer.compute_gradients(total_loss)

        update_ops.append(
            optimizer.apply_gradients(grads_and_vars,
                                      global_step=tf.train.get_global_step()))

    graph_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    update_ops.append(graph_update_ops)
    update_op = tf.group(*update_ops, name='update_barrier')
    with tf.control_dependencies([update_op]):
        if mode == tf.estimator.ModeKeys.PREDICT:
            train_op = None
        else:
            train_op = tf.identity(total_loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        logging.info("Total number of trainable parameters: {}".format(
            np.sum([
                np.prod(v.get_shape().as_list())
                for v in tf.trainable_variables()
            ])))

        # Training hooks do not work with MirroredStrategy; fixed in TF 1.13.
        #print_hook = session_hooks.PrintHook(
        #  file_name=features[standard_fields.InputDataFields.image_file],
        #  batch_pixel_ratio=batch_pixel_ratio)
        return tf.estimator.EstimatorSpec(mode,
                                          loss=total_loss,
                                          train_op=train_op,
                                          scaffold=scaffold)
    elif mode == tf.estimator.ModeKeys.EVAL:
        if pipeline_config.train_config.loss.name == 'sigmoid':
            scaled_network_output = tf.nn.sigmoid(network_output)[:, :, :, 0]
        elif pipeline_config.train_config.loss.name == 'softmax':
            assert (network_output.get_shape().as_list()[-1] == 2)
            scaled_network_output = tf.nn.softmax(network_output)[:, :, :, 1]

            # Metrics
        metric_dict, statistics_dict = metric_utils.get_metrics(
            scaled_network_output,
            annotation_mask_batch,
            tp_thresholds=np.array(pipeline_config.metrics_tp_thresholds,
                                   dtype=np.float32),
            parallel_iterations=min(pipeline_config.eval_config.batch_size,
                                    util_ops.get_cpu_count()))

        vis_hook = session_hooks.VisualizationHook(
            result_folder=result_folder,
            visualization_file_names=visualization_file_names,
            file_name=features[standard_fields.InputDataFields.image_file],
            image_decoded=image_batch,
            annotation_decoded=features[
                standard_fields.InputDataFields.annotation_decoded],
            predicted_mask=scaled_network_output,
            eval_dir=eval_dir)
        patient_metric_hook = session_hooks.PatientMetricHook(
            statistics_dict=statistics_dict,
            patient_id=features[standard_fields.InputDataFields.patient_id],
            result_folder=result_folder,
            tp_thresholds=pipeline_config.metrics_tp_thresholds,
            eval_dir=eval_dir)

        return tf.estimator.EstimatorSpec(
            mode,
            loss=total_loss,
            train_op=train_op,
            evaluation_hooks=[vis_hook, patient_metric_hook],
            eval_metric_ops=metric_dict)
    elif mode == tf.estimator.ModeKeys.PREDICT:
        if pipeline_config.train_config.loss.name == 'sigmoid':
            scaled_network_output = tf.nn.sigmoid(network_output)[:, :, :, 0]
        elif pipeline_config.train_config.loss.name == 'softmax':
            assert (network_output.get_shape().as_list()[-1] == 2)
            scaled_network_output = tf.nn.softmax(network_output)[:, :, :, 1]

        vis_hook = session_hooks.VisualizationHook(
            result_folder=result_folder,
            visualization_file_names=None,
            file_name=features[standard_fields.InputDataFields.image_file],
            image_decoded=image_batch,
            annotation_decoded=None,
            predicted_mask=scaled_network_output,
            eval_dir=eval_dir)

        predicted_mask = tf.stack([
            scaled_network_output * 255,
            tf.zeros_like(scaled_network_output),
            tf.zeros_like(scaled_network_output)
        ],
                                  axis=3)

        predicted_mask_overlay = tf.clip_by_value(
            features[standard_fields.InputDataFields.image_decoded] * 0.5 +
            predicted_mask, 0, 255)

        return tf.estimator.EstimatorSpec(
            mode,
            prediction_hooks=[vis_hook],
            predictions={
                'image_file':
                features[standard_fields.InputDataFields.image_file],
                'prediction': predicted_mask_overlay
            })
    else:
        assert (False)
Example #13
    def model_fn(features, labels, mode, params=None):
        params = params or {}
        total_loss, train_op, predictions, export_outputs = None, None, None, None
        is_training = mode == tf.estimator.ModeKeys.TRAIN
        model = init_model_fn(is_training=is_training, add_summaries=True)
        scaffold = None

        preprocessed_images = features[fields.InputDataFields.image]
        tf.logging.info('msg:{}'.format(preprocessed_images))
        predictions = model.predict(preprocessed_images)

        if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
            total_loss = model.loss(predictions,
                                    labels[fields.InputDataFields.label])
            global_step = tf.train.get_or_create_global_step()
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
            )

        if mode == tf.estimator.ModeKeys.TRAIN:

            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var)

            train_op = tf.contrib.layers.optimize_loss(
                loss=total_loss,
                global_step=global_step,
                # learning_rate=None,
                learning_rate=0.001,
                # clip_gradients=clip_gradients_value,
                # optimizer=training_optimizer,
                optimizer='Adam',
                # update_ops=model.updates(),
                # variables=trainable_variables,
                # summaries=summaries,
                name='')

        def postprocess_wrapper(predictions):
            return model.format_label(predictions)

        if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
            predictions = postprocess_wrapper(predictions)

        eval_metric_ops = None
        if mode == tf.estimator.ModeKeys.EVAL:
            eval_metric_ops = eval_utils.get_eval_metric_ops_for_evaluatiors(
                eval_config, model.format_label(predictions),
                model.format_label(labels[fields.InputDataFields.label]))

            # for var in optimizer_summary_vars:
            #     eval_metric_ops[var.op.name] = (var, tf.no_op())

        if scaffold is None:
            # keep_checkpoint_every_n_hours = (
            #     train_config.keep_checkpoint_every_n_hours)
            saver = tf.train.Saver(
                sharded=True,
                # keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
                save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            scaffold = tf.train.Scaffold(saver=saver)

        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=predictions,
                                          loss=total_loss,
                                          train_op=train_op,
                                          eval_metric_ops=eval_metric_ops,
                                          export_outputs=export_outputs,
                                          scaffold=scaffold)
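The model_fn above returns a tf.estimator.EstimatorSpec, so it is meant to be handed to a tf.estimator.Estimator. A minimal, hypothetical wiring sketch (model_dir, train_input_fn and eval_input_fn are assumed names not defined in this listing):

import tensorflow as tf

# Hypothetical wiring of the model_fn defined above into an Estimator.
estimator = tf.estimator.Estimator(model_fn=model_fn,
                                   model_dir='/tmp/model_dir',
                                   params={})
estimator.train(input_fn=train_input_fn, max_steps=10000)
metrics = estimator.evaluate(input_fn=eval_input_fn)
print(metrics)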
Example #14
  def model_fn(features, labels, mode, params=None):
    """Constructs the object detection model.

    Args:
      features: Dictionary of feature tensors, returned from `input_fn`.
      labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
        otherwise None.
      mode: Mode key from tf.estimator.ModeKeys.
      params: Parameter dictionary passed from the estimator.

    Returns:
      An `EstimatorSpec` that encapsulates the model and its serving
        configurations.
    """
    params = params or {}
    total_loss, train_op, detections, export_outputs = None, None, None, None
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    detection_model = detection_model_fn(is_training=is_training,
                                         add_summaries=(not use_tpu))
    scaffold_fn = None

    if mode == tf.estimator.ModeKeys.TRAIN:
      labels = unstack_batch(
          labels,
          unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors)
    elif mode == tf.estimator.ModeKeys.EVAL:
      labels = unstack_batch(labels, unpad_groundtruth_tensors=False)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      gt_boxes_list = labels[fields.InputDataFields.groundtruth_boxes]
      gt_classes_list = labels[fields.InputDataFields.groundtruth_classes]
      gt_masks_list = None
      if fields.InputDataFields.groundtruth_instance_masks in labels:
        gt_masks_list = labels[
            fields.InputDataFields.groundtruth_instance_masks]
      gt_keypoints_list = None
      if fields.InputDataFields.groundtruth_keypoints in labels:
        gt_keypoints_list = labels[fields.InputDataFields.groundtruth_keypoints]
      detection_model.provide_groundtruth(
          groundtruth_boxes_list=gt_boxes_list,
          groundtruth_classes_list=gt_classes_list,
          groundtruth_masks_list=gt_masks_list,
          groundtruth_keypoints_list=gt_keypoints_list)

    preprocessed_images = features[fields.InputDataFields.image]
    prediction_dict = detection_model.predict(
        preprocessed_images, features[fields.InputDataFields.true_image_shape])
    detections = detection_model.postprocess(
        prediction_dict, features[fields.InputDataFields.true_image_shape])

    if mode == tf.estimator.ModeKeys.TRAIN:
      if train_config.fine_tune_checkpoint and hparams.load_pretrained:
        asg_map = detection_model.restore_map(
            from_detection_checkpoint=train_config.from_detection_checkpoint,
            load_all_detection_checkpoint_vars=(
                train_config.load_all_detection_checkpoint_vars))
        available_var_map = (
            variables_helper.get_variables_available_in_checkpoint(
                asg_map, train_config.fine_tune_checkpoint,
                include_global_step=False))
        if use_tpu:
          def tpu_scaffold():
            tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                          available_var_map)
            return tf.train.Scaffold()
          scaffold_fn = tpu_scaffold
        else:
          tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                        available_var_map)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      losses_dict = detection_model.loss(
          prediction_dict, features[fields.InputDataFields.true_image_shape])
      losses = [loss_tensor for loss_tensor in losses_dict.values()]
      total_loss = tf.add_n(losses, name='total_loss')

    if mode == tf.estimator.ModeKeys.TRAIN:
      global_step = tf.train.get_or_create_global_step()
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)

      if use_tpu:
        training_optimizer = tpu_optimizer.CrossShardOptimizer(
            training_optimizer)

      # Optionally freeze some layers by setting their gradients to be zero.
      trainable_variables = None
      if train_config.freeze_variables:
        trainable_variables = tf.contrib.framework.filter_variables(
            tf.trainable_variables(),
            exclude_patterns=train_config.freeze_variables)

      clip_gradients_value = None
      if train_config.gradient_clipping_by_norm > 0:
        clip_gradients_value = train_config.gradient_clipping_by_norm

      if not use_tpu:
        for var in optimizer_summary_vars:
          tf.summary.scalar(var.op.name, var)
      summaries = [] if use_tpu else None
      train_op = tf.contrib.layers.optimize_loss(
          loss=total_loss,
          global_step=global_step,
          learning_rate=None,
          clip_gradients=clip_gradients_value,
          optimizer=training_optimizer,
          variables=trainable_variables,
          summaries=summaries,
          name='')  # Preventing scope prefix on all variables.

    if mode == tf.estimator.ModeKeys.PREDICT:
      export_outputs = {
          tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
              tf.estimator.export.PredictOutput(detections)
      }

    eval_metric_ops = None
    if mode == tf.estimator.ModeKeys.EVAL:
      # Detection summaries during eval.
      class_agnostic = (fields.DetectionResultFields.detection_classes
                        not in detections)
      groundtruth = _get_groundtruth_data(detection_model, class_agnostic)
      eval_dict = eval_util.result_dict_for_single_example(
          tf.expand_dims(features[fields.InputDataFields.original_image][0], 0),
          features[inputs.HASH_KEY][0],
          detections,
          groundtruth,
          class_agnostic=class_agnostic,
          scale_to_absolute=False)

      if class_agnostic:
        category_index = label_map_util.create_class_agnostic_category_index()
      else:
        category_index = label_map_util.create_category_index_from_labelmap(
            eval_input_config.label_map_path)
      detection_and_groundtruth = vis_utils.draw_side_by_side_evaluation_image(
          eval_dict, category_index, max_boxes_to_draw=20, min_score_thresh=0.2)
      if not use_tpu:
        tf.summary.image('Detections_Left_Groundtruth_Right',
                         detection_and_groundtruth)

      # Eval metrics on a single image.
      detection_fields = fields.DetectionResultFields()
      input_data_fields = fields.InputDataFields()
      coco_evaluator = coco_evaluation.CocoDetectionEvaluator(
          category_index.values())
      eval_metric_ops = coco_evaluator.get_estimator_eval_metric_ops(
          image_id=eval_dict[input_data_fields.key],
          groundtruth_boxes=eval_dict[input_data_fields.groundtruth_boxes],
          groundtruth_classes=eval_dict[input_data_fields.groundtruth_classes],
          detection_boxes=eval_dict[detection_fields.detection_boxes],
          detection_scores=eval_dict[detection_fields.detection_scores],
          detection_classes=eval_dict[detection_fields.detection_classes])

    if use_tpu:
      return tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          scaffold_fn=scaffold_fn,
          predictions=detections,
          loss=total_loss,
          train_op=train_op,
          eval_metrics=eval_metric_ops,
          export_outputs=export_outputs)
    else:
      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=detections,
          loss=total_loss,
          train_op=train_op,
          eval_metric_ops=eval_metric_ops,
          export_outputs=export_outputs)
Example #15
    def model_fn(features, labels, mode, params=None):
        """Constructs the object detection model.

    Args:
      features: Dictionary of feature tensors, returned from `input_fn`.
      labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
        otherwise None.
      mode: Mode key from tf.estimator.ModeKeys.
      params: Parameter dictionary passed from the estimator.

    Returns:
      An `EstimatorSpec` that encapsulates the model and its serving
        configurations.
    """
        params = params or {}
        total_loss, train_op, detections, export_outputs = None, None, None, None
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        # Make sure to set the Keras learning phase. True during training,
        # False for inference.
        tf.keras.backend.set_learning_phase(is_training)
        detection_model = detection_model_fn(is_training=is_training,
                                             add_summaries=(not use_tpu))
        scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:
            labels = unstack_batch(labels,
                                   unpad_groundtruth_tensors=train_config.
                                   unpad_groundtruth_tensors)
        elif mode == tf.estimator.ModeKeys.EVAL:
            # For evaluating on train data, it is necessary to check whether
            # groundtruth must be unpadded.
            boxes_shape = (labels[fields.InputDataFields.groundtruth_boxes].
                           get_shape().as_list())
            unpad_groundtruth_tensors = boxes_shape[
                1] is not None and not use_tpu
            labels = unstack_batch(
                labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)

        if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
            gt_boxes_list = labels[fields.InputDataFields.groundtruth_boxes]
            gt_classes_list = labels[
                fields.InputDataFields.groundtruth_classes]
            gt_masks_list = None
            if fields.InputDataFields.groundtruth_instance_masks in labels:
                gt_masks_list = labels[
                    fields.InputDataFields.groundtruth_instance_masks]
            gt_keypoints_list = None
            if fields.InputDataFields.groundtruth_keypoints in labels:
                gt_keypoints_list = labels[
                    fields.InputDataFields.groundtruth_keypoints]
            gt_weights_list = None
            if fields.InputDataFields.groundtruth_weights in labels:
                gt_weights_list = labels[
                    fields.InputDataFields.groundtruth_weights]
            gt_confidences_list = None
            if fields.InputDataFields.groundtruth_confidences in labels:
                gt_confidences_list = labels[
                    fields.InputDataFields.groundtruth_confidences]
            gt_is_crowd_list = None
            if fields.InputDataFields.groundtruth_is_crowd in labels:
                gt_is_crowd_list = labels[
                    fields.InputDataFields.groundtruth_is_crowd]
            detection_model.provide_groundtruth(
                groundtruth_boxes_list=gt_boxes_list,
                groundtruth_classes_list=gt_classes_list,
                groundtruth_confidences_list=gt_confidences_list,
                groundtruth_masks_list=gt_masks_list,
                groundtruth_keypoints_list=gt_keypoints_list,
                groundtruth_weights_list=gt_weights_list,
                groundtruth_is_crowd_list=gt_is_crowd_list)

        preprocessed_images = features[fields.InputDataFields.image]
        if use_tpu and train_config.use_bfloat16:
            with tf.contrib.tpu.bfloat16_scope():
                prediction_dict = detection_model.predict(
                    preprocessed_images,
                    features[fields.InputDataFields.true_image_shape])
                for k, v in prediction_dict.items():
                    if v.dtype == tf.bfloat16:
                        prediction_dict[k] = tf.cast(v, tf.float32)
        else:
            prediction_dict = detection_model.predict(
                preprocessed_images,
                features[fields.InputDataFields.true_image_shape])
        if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
            detections = detection_model.postprocess(
                prediction_dict,
                features[fields.InputDataFields.true_image_shape])

        if mode == tf.estimator.ModeKeys.TRAIN:
            if train_config.fine_tune_checkpoint and hparams.load_pretrained:
                if not train_config.fine_tune_checkpoint_type:
                    # train_config.from_detection_checkpoint field is deprecated. For
                    # backward compatibility, set train_config.fine_tune_checkpoint_type
                    # based on train_config.from_detection_checkpoint.
                    if train_config.from_detection_checkpoint:
                        train_config.fine_tune_checkpoint_type = 'detection'
                    else:
                        train_config.fine_tune_checkpoint_type = 'classification'
                asg_map = detection_model.restore_map(
                    fine_tune_checkpoint_type=train_config.
                    fine_tune_checkpoint_type,
                    load_all_detection_checkpoint_vars=(
                        train_config.load_all_detection_checkpoint_vars))
                available_var_map = (
                    variables_helper.get_variables_available_in_checkpoint(
                        asg_map,
                        train_config.fine_tune_checkpoint,
                        include_global_step=False))
                if use_tpu:

                    def tpu_scaffold():
                        tf.train.init_from_checkpoint(
                            train_config.fine_tune_checkpoint,
                            available_var_map)
                        return tf.train.Scaffold()

                    scaffold_fn = tpu_scaffold
                else:
                    tf.train.init_from_checkpoint(
                        train_config.fine_tune_checkpoint, available_var_map)

        if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
            losses_dict = detection_model.loss(
                prediction_dict,
                features[fields.InputDataFields.true_image_shape])
            losses = [loss_tensor for loss_tensor in losses_dict.values()]
            if train_config.add_regularization_loss:
                regularization_losses = detection_model.regularization_losses()
                if regularization_losses:
                    regularization_loss = tf.add_n(regularization_losses,
                                                   name='regularization_loss')
                    losses.append(regularization_loss)
                    losses_dict[
                        'Loss/regularization_loss'] = regularization_loss
            total_loss = tf.add_n(losses, name='total_loss')
            losses_dict['Loss/total_loss'] = total_loss

            if 'graph_rewriter_config' in configs:
                graph_rewriter_fn = graph_rewriter_builder.build(
                    configs['graph_rewriter_config'], is_training=is_training)
                graph_rewriter_fn()

            # TODO(rathodv): Stop creating optimizer summary vars in EVAL mode once we
            # can write learning rate summaries on TPU without host calls.
            global_step = tf.train.get_or_create_global_step()
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)

        if mode == tf.estimator.ModeKeys.TRAIN:
            if use_tpu:
                training_optimizer = tf.contrib.tpu.CrossShardOptimizer(
                    training_optimizer)

            # Optionally freeze some layers by setting their gradients to be zero.
            trainable_variables = None
            include_variables = (train_config.update_trainable_variables
                                 if train_config.update_trainable_variables
                                 else None)
            exclude_variables = (train_config.freeze_variables
                                 if train_config.freeze_variables else None)
            trainable_variables = tf.contrib.framework.filter_variables(
                tf.trainable_variables(),
                include_patterns=include_variables,
                exclude_patterns=exclude_variables)

            clip_gradients_value = None
            if train_config.gradient_clipping_by_norm > 0:
                clip_gradients_value = train_config.gradient_clipping_by_norm

            if not use_tpu:
                for var in optimizer_summary_vars:
                    tf.summary.scalar(var.op.name, var)
            summaries = [] if use_tpu else None
            if train_config.summarize_gradients:
                summaries = [
                    'gradients', 'gradient_norm', 'global_gradient_norm'
                ]
            train_op = tf.contrib.layers.optimize_loss(
                loss=total_loss,
                global_step=global_step,
                learning_rate=None,
                clip_gradients=clip_gradients_value,
                optimizer=training_optimizer,
                update_ops=detection_model.updates(),
                variables=trainable_variables,
                summaries=summaries,
                name='')  # Preventing scope prefix on all variables.

        if mode == tf.estimator.ModeKeys.PREDICT:
            exported_output = exporter_lib.add_output_tensor_nodes(detections)
            export_outputs = {
                tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
                tf.estimator.export.PredictOutput(exported_output)
            }

        eval_metric_ops = None
        scaffold = None
        if mode == tf.estimator.ModeKeys.EVAL:
            class_agnostic = (fields.DetectionResultFields.detection_classes
                              not in detections)
            groundtruth = _prepare_groundtruth_for_eval(
                detection_model, class_agnostic,
                eval_input_config.max_number_of_boxes)
            use_original_images = fields.InputDataFields.original_image in features
            if use_original_images:
                eval_images = features[fields.InputDataFields.original_image]
                true_image_shapes = tf.slice(
                    features[fields.InputDataFields.true_image_shape], [0, 0],
                    [-1, 3])
                original_image_spatial_shapes = features[
                    fields.InputDataFields.original_image_spatial_shape]
            else:
                eval_images = features[fields.InputDataFields.image]
                true_image_shapes = None
                original_image_spatial_shapes = None

            eval_dict = eval_util.result_dict_for_batched_example(
                eval_images,
                features[inputs.HASH_KEY],
                detections,
                groundtruth,
                class_agnostic=class_agnostic,
                scale_to_absolute=True,
                original_image_spatial_shapes=original_image_spatial_shapes,
                true_image_shapes=true_image_shapes)

            if class_agnostic:
                category_index = (
                    label_map_util.create_class_agnostic_category_index())
            else:
                category_index = label_map_util.create_category_index_from_labelmap(
                    eval_input_config.label_map_path)
            vis_metric_ops = None
            if not use_tpu and use_original_images:
                eval_metric_op_vis = vis_utils.VisualizeSingleFrameDetections(
                    category_index,
                    max_examples_to_draw=eval_config.num_visualizations,
                    max_boxes_to_draw=eval_config.max_num_boxes_to_visualize,
                    min_score_thresh=eval_config.min_score_threshold,
                    use_normalized_coordinates=False)
                vis_metric_ops = eval_metric_op_vis.get_estimator_eval_metric_ops(
                    eval_dict)

            # Eval metrics on a single example.
            eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
                eval_config, list(category_index.values()), eval_dict)
            for loss_key, loss_tensor in iter(losses_dict.items()):
                eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
            for var in optimizer_summary_vars:
                eval_metric_ops[var.op.name] = (var, tf.no_op())
            if vis_metric_ops is not None:
                eval_metric_ops.update(vis_metric_ops)
            eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

            if eval_config.use_moving_averages:
                variable_averages = tf.train.ExponentialMovingAverage(0.0)
                variables_to_restore = variable_averages.variables_to_restore()
                keep_checkpoint_every_n_hours = (
                    train_config.keep_checkpoint_every_n_hours)
                saver = tf.train.Saver(
                    variables_to_restore,
                    keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours
                )
                scaffold = tf.train.Scaffold(saver=saver)

        # EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
        if use_tpu and mode != tf.estimator.ModeKeys.EVAL:
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                scaffold_fn=scaffold_fn,
                predictions=detections,
                loss=total_loss,
                train_op=train_op,
                eval_metrics=eval_metric_ops,
                export_outputs=export_outputs)
        else:
            if scaffold is None:
                keep_checkpoint_every_n_hours = (
                    train_config.keep_checkpoint_every_n_hours)
                saver = tf.train.Saver(
                    sharded=True,
                    keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
                    save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                scaffold = tf.train.Scaffold(saver=saver)
            return tf.estimator.EstimatorSpec(mode=mode,
                                              predictions=detections,
                                              loss=total_loss,
                                              train_op=train_op,
                                              eval_metric_ops=eval_metric_ops,
                                              export_outputs=export_outputs,
                                              scaffold=scaffold)
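
The model_fn above only assembles an EstimatorSpec (or TPUEstimatorSpec); nothing runs until it is handed to an estimator. Below is a minimal sketch of that wiring, assuming a TF 1.x tf.estimator setup; the input function, model directory, and step count are placeholder names, not part of the original example.

import tensorflow as tf


def run_estimator_training(model_fn, train_input_fn, model_dir, num_train_steps):
    # Hypothetical wiring: build a RunConfig, construct the Estimator with the
    # model_fn defined above, and run training on the supplied input_fn.
    run_config = tf.estimator.RunConfig(model_dir=model_dir,
                                        save_checkpoints_steps=1000)
    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
    return estimator
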
def train_segmentation_model(create_model_fn,
                             create_input_fn,
                             train_config,
                             master,
                             task,
                             is_chief,
                             startup_delay_steps,
                             train_dir,
                             num_clones,
                             num_worker_replicas,
                             num_ps_tasks,
                             clone_on_cpu,
                             replica_id,
                             num_replicas,
                             max_checkpoints_to_keep,
                             save_interval_secs,
                             image_summaries,
                             log_memory=False,
                             gradient_checkpoints=None,
                             sync_bn_accross_gpu=False):
    """Create an instance of the FastSegmentationModel"""
    _, segmentation_model = create_model_fn()
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=num_worker_replicas,
        num_ps_tasks=num_ps_tasks)
    startup_delay_steps = task * startup_delay_steps

    per_clone_batch_size = train_config.batch_size // num_clones

    preprocess_fn = None
    if train_config.preprocessor_step:
        preprocess_fn = functools.partial(
            preprocessor_builder.build,
            preprocessor_config_list=train_config.preprocessor_step)

    with tf.Graph().as_default():
        # CPU of common ps server
        with tf.device(deploy_config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

        with tf.device(deploy_config.inputs_device()): # CPU of each worker
            input_queue = create_training_input(
                create_input_fn,
                preprocess_fn,
                per_clone_batch_size,
                batch_queue_capacity=train_config.batch_queue_capacity,
                batch_queue_threads=train_config.num_batch_queue_threads,
                prefetch_queue_capacity=train_config.prefetch_queue_capacity)

        # Create the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            # Note: it is assumed that any loss created by `model_fn`
            # is collected at the tf.GraphKeys.LOSSES collection.
            model_fn = functools.partial(create_training_model_losses,
                                    create_model_fn=create_model_fn,
                                    train_config=train_config,
                                    train_dir=train_dir,
                                    gradient_checkpoints=gradient_checkpoints)
            clones = model_deploy.create_clones(deploy_config,
                model_fn, [input_queue])
            first_clone_scope = deploy_config.clone_scope(0)

            if sync_bn_accross_gpu:
                # Attempt to sync batch-norm updates across all GPUs in a tower.
                # Caution: this is very slow and might not be needed.
                update_ops = []
                for idx in range(num_clones):
                    nth_clone_scope = deploy_config.clone_scope(idx)
                    update_ops.extend(tf.get_collection(
                        tf.GraphKeys.UPDATE_OPS, nth_clone_scope))
            else:
                # Gather updates from first GPU only
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               first_clone_scope)

        # Initialize the set used to collect summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('Losses/%s' % loss.op.name, loss))

        with tf.device(deploy_config.optimizer_device()): # CPU of each worker
            (training_optimizer,
              optimizer_summary_vars) = optimizer_builder.build(
                train_config.optimizer)
            for var in optimizer_summary_vars:
                summaries.add(
                    tf.summary.scalar(var.op.name, var, family='LearningRate'))

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Fine tune from classification or segmentation checkpoints
        trainable_vars = tf.get_collection(
                              tf.GraphKeys.TRAINABLE_VARIABLES)
        if train_config.fine_tune_checkpoint:
            if not train_config.fine_tune_checkpoint_type:
                raise ValueError('Must specify `fine_tune_checkpoint_type`.')

            tf.logging.info('Initializing %s model from checkpoint %s',
                train_config.fine_tune_checkpoint_type,
                train_config.fine_tune_checkpoint)

            variables_to_restore = segmentation_model.restore_map(
              fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type)

            init_fn = slim.assign_from_checkpoint_fn(
                        train_config.fine_tune_checkpoint,
                        variables_to_restore,
                        ignore_missing_vars=True)

            if train_config.freeze_fine_tune_backbone:
                tf.logging.info(
                    'Freezing %s scope from checkpoint.',
                    segmentation_model.shared_feature_extractor_scope)
                non_frozen_vars = []
                for var in trainable_vars:
                    if not var.op.name.startswith(
                      segmentation_model.shared_feature_extractor_scope):
                        non_frozen_vars.append(var)
                        tf.logging.info('Training variable: %s', var.op.name)
                trainable_vars = non_frozen_vars
        else:
            init_fn = None
            tf.logging.info('Not initializing the model from a checkpoint. '
                            'Initializing from scratch!')

        # TODO(@oandrien): we might want to add gradient multiplier here
        # for the last layer if we have trouble with training
        # CPU of common ps server
        with tf.device(deploy_config.optimizer_device()):
            reg_losses = (None if train_config.add_regularization_loss
                               else [])
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer,
                regularization_losses=reg_losses,
                var_list=trainable_vars)
            total_loss = tf.check_numerics(total_loss,
                                          'LossTensor is inf or nan.')
            summaries.add(
                tf.summary.scalar('Losses/TotalLoss', total_loss))

            grad_updates = training_optimizer.apply_gradients(grads_and_vars,
                                                    global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_op = tf.identity(total_loss, name='train_op')

        # TODO: this ideally should not be hardcoded like this.
        #   should have a way to access the prediction and GT tensor
        if image_summaries:
            graph = tf.get_default_graph()
            pixel_scaling = max(1, 255 // 19)
            summ_first_clone_scope = (first_clone_scope + '/'
                if first_clone_scope else '')
            main_labels = graph.get_tensor_by_name(
                '%sSegmentationLoss/Reshape:0' % summ_first_clone_scope)
            main_preds = graph.get_tensor_by_name(
                '%sSegmentationLoss/Reshape_1:0' % summ_first_clone_scope)
            main_preds = tf.cast(main_preds * pixel_scaling, tf.uint8)
            summaries.add(
              tf.summary.image('VerifyTrainImages/Predictions', main_preds))
            main_labels = tf.cast(main_labels * pixel_scaling, tf.uint8)
            summaries.add(
              tf.summary.image('VerifyTrainImages/Groundtruths', main_labels))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones()
        # or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        session_config = tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=True)

        # Save checkpoints regularly.
        saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)

        # HACK to see memory usage.
        # TODO: Clean up, pretty messy.
        def train_step_mem(sess, train_op, global_step, train_step_kwargs):
            start_time = time.time()
            run_metadata = tf.RunMetadata()
            options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            total_loss, np_global_step = sess.run([train_op, global_step],
                                        options=options,
                                        run_metadata=run_metadata)
            time_elapsed = time.time() - start_time

            if 'should_log' in train_step_kwargs:
                if sess.run(train_step_kwargs['should_log']):
                    tf.logging.info(
                        'global step %d: loss = %.4f (%.3f sec/step)',
                        np_global_step, total_loss, time_elapsed)

            if log_memory:
                mem_use = mem_util.peak_memory(run_metadata)['/gpu:0']/1e6
                tf.logging.info('Memory used: %.2f MB', mem_use)

            if 'should_stop' in train_step_kwargs:
                should_stop = sess.run(train_step_kwargs['should_stop'])
            else:
                should_stop = False

            return total_loss, should_stop

        # Main training loop
        slim.learning.train(
            train_op,
            train_step_fn=train_step_mem,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            number_of_steps=train_config.num_steps,
            startup_delay_steps=startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            save_summaries_secs=120,
            save_interval_secs=save_interval_secs,
            saver=saver)
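
For reference, the optimizer_builder.build call unpacked in the optimizer_device block above is expected to return a two-tuple: the optimizer itself and a list of scalar tensors (typically the learning rate) that get logged under the 'LearningRate' summary family. A rough, illustrative sketch of a builder honoring that contract; the helper name and the plain momentum optimizer are assumptions, not the library's actual implementation.

import tensorflow as tf


def build_momentum_optimizer(learning_rate, momentum=0.9):
    # Hypothetical helper mirroring the (optimizer, summary_vars) contract
    # used by optimizer_builder.build in the example above: the second
    # element is a list of scalars the caller turns into summaries.
    lr = tf.constant(learning_rate, dtype=tf.float32, name='learning_rate')
    optimizer = tf.train.MomentumOptimizer(lr, momentum=momentum)
    return optimizer, [lr]
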
Beispiel #17
0
def train(create_tensor_dict_fn, create_model_fn, train_config, input_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir, save_interval_secs=3600, log_every_n_steps=1000):
    """Training function for detection models.

    Args:
      create_tensor_dict_fn: a function to create a tensor input dictionary.
      create_model_fn: a function that creates a DetectionModel and generates
                       losses.
      train_config: a train_pb2.TrainConfig protobuf.
      input_config: a input_reader.InputReader protobuf.
      master: BNS name of the TensorFlow master to use.
      task: The task id of this training instance.
      num_clones: The number of clones to run per machine.
      worker_replicas: The number of work replicas to train with.
      clone_on_cpu: True if clones should be forced to run on CPU.
      ps_tasks: Number of parameter server tasks.
      worker_job_name: Name of the worker job.
      is_chief: Whether this replica is the chief replica.
      train_dir: Directory to write checkpoints and training summaries to.
      save_interval_secs: Interval in seconds to save a check point file.
      log_every_n_steps: The frequency, in global steps, at which the loss and global step are logged.
    """

    detection_model = create_model_fn()

    preprocess_input_options = [
        preprocessor_input_builder.build(step)
        for step in input_config.preprocess_input_options]

    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            input_queue = _create_input_queue(train_config.batch_size // num_clones,
                                              create_tensor_dict_fn,
                                              train_config.batch_queue_capacity,
                                              train_config.num_batch_queue_threads,
                                              train_config.prefetch_queue_capacity,
                                              data_augmentation_options,
                                              preprocess_input_options)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        if detection_model.is_rbbox:
            model_fn = functools.partial(_create_losses_rbbox,
                                         create_model_fn=create_model_fn)
        else:
            model_fn = functools.partial(_create_losses,
                                         create_model_fn=create_model_fn)
        clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer = optimizer_builder.build(train_config.optimizer,
                                                         global_summaries)

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            var_map = detection_model.restore_map(
                from_detection_checkpoint=train_config.from_detection_checkpoint)
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map, train_config.fine_tune_checkpoint))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        with tf.device(deploy_config.optimizer_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer, regularization_losses=None)
            total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(grads_and_vars,
                                                              global_step=global_step)
            update_ops.append(grad_updates)

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                           first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            max_to_keep=None,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            log_every_n_steps=log_every_n_steps,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps if train_config.num_steps else None),
            save_summaries_secs=240,
            save_interval_secs=save_interval_secs,
            sync_optimizer=sync_optimizer,
            saver=saver)
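
A hypothetical call site for the train() function above might look like the following. The model_builder and input_reader_builder modules and the parsed pipeline_config proto are assumed to come from the surrounding project, and the flag-like values are placeholders.

import functools

# Hypothetical wiring, based on the argument list documented in the
# docstring above; only the single-machine defaults are filled in.
create_model_fn = functools.partial(model_builder.build,
                                    model_config=pipeline_config.model,
                                    is_training=True)
create_tensor_dict_fn = functools.partial(input_reader_builder.build,
                                          pipeline_config.train_input_reader)
train(create_tensor_dict_fn,
      create_model_fn,
      pipeline_config.train_config,
      pipeline_config.train_input_reader,
      master='',
      task=0,
      num_clones=1,
      worker_replicas=1,
      clone_on_cpu=False,
      ps_tasks=0,
      worker_job_name='worker',
      is_chief=True,
      train_dir='/tmp/train')
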
def train_segmentation_model(create_model_fn,
                             create_input_fn,
                             train_config,
                             model_config,
                             master,
                             task,
                             is_chief,
                             startup_delay_steps,
                             train_dir,
                             num_clones,
                             num_worker_replicas,
                             num_ps_tasks,
                             clone_on_cpu,
                             replica_id,
                             num_replicas,
                             max_checkpoints_to_keep,
                             save_interval_secs,
                             image_summaries,
                             log_memory=False,
                             gradient_checkpoints=None,
                             sync_bn_accross_gpu=False):
    """Create an instance of the SegmentationModel"""
    _, segmentation_model = create_model_fn()
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=num_worker_replicas,
        num_ps_tasks=num_ps_tasks)
    startup_delay_steps = task * startup_delay_steps

    per_clone_batch_size = train_config.batch_size  #// num_clones

    preprocess_fn = None
    if train_config.preprocessor_step:
        preprocess_fn = functools.partial(
            preprocessor_builder.build,
            preprocessor_config_list=train_config.preprocessor_step)

    with tf.Graph().as_default():
        # CPU of common ps server
        with tf.device(deploy_config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

        with tf.device(deploy_config.inputs_device()):  # CPU of each worker
            dataset = create_training_input(create_input_fn, preprocess_fn,
                                            per_clone_batch_size, num_clones)
            dataset = dataset.apply(tf.data.experimental.ignore_errors())
            data_iterator = dataset.make_one_shot_iterator()
        # Create the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            # Note: it is assumed that any loss created by `model_fn`
            # is collected at the tf.GraphKeys.LOSSES collection.
            model_fn = functools.partial(
                create_training_model_losses,
                create_model_fn=create_model_fn,
                train_config=train_config,
                train_dir=train_dir,
                gradient_checkpoints=gradient_checkpoints)
            clones = model_deploy.create_clones(deploy_config, model_fn,
                                                [data_iterator.get_next])
            first_clone_scope = deploy_config.clone_scope(0)

            if sync_bn_accross_gpu:
                # Attempt to sync batch-norm updates across all GPUs in a tower.
                # Caution: this is very slow and might not be needed.
                update_ops = []
                for idx in range(num_clones):
                    nth_clone_scope = deploy_config.clone_scope(idx)
                    update_ops.extend(
                        tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                          nth_clone_scope))
            else:
                # Gather updates from first GPU only
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               first_clone_scope)

        # Initialize the set used to collect summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('Losses/%s' % loss.op.name, loss))

        with tf.device(deploy_config.optimizer_device()):  # CPU of each worker
            (training_optimizer,
             optimizer_summary_vars) = optimizer_builder.build(
                 train_config.optimizer, num_clones)
            for var in optimizer_summary_vars:
                summaries.add(
                    tf.summary.scalar(var.op.name, var, family='LearningRate'))

        # Add summaries for model variables.
        # for model_var in slim.get_model_variables():
        #     summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Fine tune from classification or segmentation checkpoints
        trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if train_config.fine_tune_checkpoint:
            if not train_config.fine_tune_checkpoint_type:
                raise ValueError('Must specify `fine_tune_checkpoint_type`.')

            tf.logging.info('Initializing %s model from checkpoint %s',
                            train_config.fine_tune_checkpoint_type,
                            train_config.fine_tune_checkpoint)

            variables_to_restore = segmentation_model.restore_map(
                train_config.fine_tune_checkpoint,
                fine_tune_checkpoint_type=train_config.
                fine_tune_checkpoint_type)

            writer = tf.summary.FileWriter(train_dir)
            writer.close()

            init_fn = slim.assign_from_checkpoint_fn(
                train_config.fine_tune_checkpoint,
                variables_to_restore,
                ignore_missing_vars=True)

            if train_config.freeze_fine_tune_backbone:
                tf.logging.info(
                    'Freezing %s scope from checkpoint.',
                    segmentation_model.shared_feature_extractor_scope)
                non_frozen_vars = []
                for var in trainable_vars:
                    if not var.op.name.startswith(
                            segmentation_model.shared_feature_extractor_scope):
                        non_frozen_vars.append(var)
                        tf.logging.info('Training variable: %s', var.op.name)
                trainable_vars = non_frozen_vars
        else:
            init_fn = None
            tf.logging.info('Not initializing the model from a checkpoint. '
                            'Initializing from scratch!')

        # TODO(@oandrien): we might want to add gradient multiplier here
        # for the last layer if we have trouble with training
        # CPU of common ps server
        with tf.device(deploy_config.optimizer_device()):
            reg_losses = (None if train_config.add_regularization_loss else [])
            if model_config.pspnet.train_reduce and reg_losses is None:
                regularization_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                reg_losses = [
                    r for r in regularization_losses if "dim_reduce" in r.name
                ]

            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones,
                training_optimizer,
                regularization_losses=reg_losses,
                var_list=trainable_vars)
            # total_loss = tf.check_numerics(total_loss,
            #                               'total_loss is inf or nan.')
            summaries.add(tf.summary.scalar('Losses/TotalLoss', total_loss))
            # with tf.variable_scope("grad_clip"):
            #     grads_and_vars = [(tf.clip_by_norm(grad, 1.), var) for grad, var in grads_and_vars]
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_op = tf.identity(total_loss, name='train_op')

        # TODO: this ideally should not be hardcoded like this.
        #   should have a way to access the prediction and GT tensor
        if image_summaries:
            graph = tf.get_default_graph()
            pixel_scaling = max(1, 255 // 19)
            summ_first_clone_scope = (first_clone_scope +
                                      '/' if first_clone_scope else '')
            input_img = graph.get_tensor_by_name('%sInputs:0' %
                                                 summ_first_clone_scope)
            main_labels = graph.get_tensor_by_name(
                '%sSegmentationLoss/ScaledLabels:0' % summ_first_clone_scope)
            main_preds = graph.get_tensor_by_name(
                '%sSegmentationLoss/ScaledPreds:0' % summ_first_clone_scope)
            summaries.add(
                tf.summary.image('VerifyTrainImages/Inputs', input_img))
            logits = main_preds
            main_preds = tf.cast(
                tf.expand_dims(tf.argmax(main_preds, -1), -1) * pixel_scaling,
                tf.uint8)
            summaries.add(
                tf.summary.image('VerifyTrainImages/Predictions', main_preds))
            main_labels = tf.cast(main_labels * pixel_scaling, tf.uint8)
            summaries.add(
                tf.summary.image('VerifyTrainImages/Groundtruths',
                                 main_labels))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones()
        # or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        session_config.gpu_options.allow_growth = True

        #load_vars = [v for v in tf.global_variables() if "Dont_Load" not in v.op.name]

        # Save checkpoints regularly.
        saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)

        # gpu_mems = []
        # for i in range(num_clones):
        #     with tf.device("/gpu:"+str(i)):
        #         gpu_mems.append(tf.cast(BytesInUse(), tf.float32)/float(1024*1024))

        # HACK to see memory usage.
        # TODO: Clean up, pretty messy.
        # import pdb; pdb.set_trace()
        def train_step_mem(sess, train_op, global_step, train_step_kwargs):
            start_time = time.time()
            if log_memory:
                run_metadata = tf.RunMetadata()
                options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            else:
                run_metadata = None
                options = None
            total_loss, np_global_step, cur_gvs, dbg = sess.run(
                [train_op, global_step, grads_and_vars, dist_builder.DEBUG],
                options=options,
                run_metadata=run_metadata)
            time_elapsed = time.time() - start_time
            # graph = tf.get_default_graph()
            # main_labels = graph.get_tensor_by_name('SegmentationLoss/ScaledLabels:0')
            # label_out = sess.run(main_labels)
            # if len(np.unique(label_out)) != 1:
            #     print(label_out)
            #     import pdb; pdb.set_trace()
            #     print(label_out)
            if 'should_log' in train_step_kwargs:
                if sess.run(train_step_kwargs['should_log']):
                    tf.logging.info(
                        'global step %d: loss = %.4f (%.3f sec/step)',
                        np_global_step, total_loss, time_elapsed)

            if log_memory:
                peaks = mem_util.peak_memory(run_metadata)
                for mem_use in peaks:
                    # mem_use = mem_util.peak_memory(run_metadata)['/gpu:0']/1e6
                    if "/gpu" in mem_use:
                        tf.logging.info('Memory used (%s): %.2f MB', mem_use,
                                        peaks[mem_use] / 1e6)
                # for m in mem:
                #     tf.logging.info('Memory used: %.2f MB',(m))

            if 'should_stop' in train_step_kwargs:
                should_stop = sess.run(train_step_kwargs['should_stop'])
            else:
                should_stop = False

            return total_loss, should_stop

        # Main training loop
        slim.learning.train(train_op,
                            train_step_fn=train_step_mem,
                            logdir=train_dir,
                            master=master,
                            is_chief=is_chief,
                            session_config=session_config,
                            number_of_steps=train_config.num_steps,
                            startup_delay_steps=startup_delay_steps,
                            init_fn=init_fn,
                            summary_op=summary_op,
                            save_summaries_secs=60,
                            save_interval_secs=save_interval_secs,
                            saver=saver)
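
Taken together, the examples above call optimizer_builder.build with three different signatures: build(config) returning an (optimizer, summary_vars) tuple, build(config, global_summaries) returning only the optimizer, and build(config, num_clones). When mixing code written against different versions of the builder, a small, purely illustrative shim like the one below can normalize the return value; the helper name is made up for this sketch.

def build_optimizer_compat(optimizer_builder, optimizer_config, *extra_args):
    """Hypothetical shim: always return an (optimizer, summary_vars) pair.

    Some builder variants return only the optimizer (filling a summaries set
    passed as the second argument), others return a two-tuple of
    (optimizer, summary_vars). This helper normalizes the result.
    """
    result = optimizer_builder.build(optimizer_config, *extra_args)
    if isinstance(result, tuple):
        optimizer, summary_vars = result
    else:
        optimizer, summary_vars = result, []
    return optimizer, summary_vars
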