Example #1
def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir):
    """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of worker replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
  """

    detection_model = create_model_fn()
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = tf.train.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            input_queue = create_input_queue(
                train_config.batch_size // num_clones, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

        # Gather initial summaries.
        # TODO(rathodv): See if summaries can be added/extracted from global tf
        # collections so that they don't have to be passed around.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn,
                                     train_config=train_config)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer = optimizer_builder.build(
                train_config.optimizer, global_summaries)

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            var_map = detection_model.restore_map(
                from_detection_checkpoint=train_config.from_detection_checkpoint)
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map, train_config.fine_tune_checkpoint))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        with tf.device(deploy_config.optimizer_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer, regularization_losses=None)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows ops without a GPU implementation to be placed on the CPU.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        session_config.gpu_options.allow_growth = True

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
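
A minimal usage sketch for the function above. The builder helpers (build_detection_model, build_input_dict) and the concrete flag values are illustrative assumptions, not part of the snippet:

# Illustrative wiring only: build_detection_model and build_input_dict are
# hypothetical stand-ins for the real model and input builders.
import functools

def run_local_training(train_config, train_dir):
    create_model_fn = functools.partial(build_detection_model,   # hypothetical
                                        is_training=True)
    create_tensor_dict_fn = functools.partial(build_input_dict)  # hypothetical

    train(create_tensor_dict_fn=create_tensor_dict_fn,
          create_model_fn=create_model_fn,
          train_config=train_config,
          master='',                 # empty string: run locally
          task=0,
          num_clones=1,
          worker_replicas=1,
          clone_on_cpu=False,
          ps_tasks=0,
          worker_job_name='lonely_worker',
          is_chief=True,
          train_dir=train_dir)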
Example #2
def general_train(make_loss, hparams, make_hooks=None):
    """Trains a general model with a loss.

  Args:
    make_loss: Function which creates the loss (and possibly registers accuracy
      summaries and other features).
    hparams: Hyperparameters (see default_hparams() for details).
    make_hooks: Optional, function which creates additional hooks for training.

  Returns:
    Final loss.

  Raises:
    ValueError: If flags are missing or invalid.
  """
    train_dir = mode_dir('train')
    if not tf.gfile.Exists(train_dir):
        tf.gfile.MakeDirs(train_dir)
    if hparams.seed:
        tf.set_random_seed(hparams.seed)

    # Configure keras
    keras.backend.set_learning_phase(1)
    keras.backend.manual_variable_initialization(True)

    with tf.device(
            tf.train.replica_device_setter(FLAGS.ps_tasks,
                                           merge_devices=True)):
        # Set the caching device to prevent hangs during distributed training
        vs = tf.get_variable_scope()
        if vs.caching_device is None:
            vs.set_caching_device(lambda op: op.device)

        # Grab loss and global step
        total_loss = make_loss()
        global_step = slim.get_or_create_global_step()

        # Set up Polyak averaging if desired
        if hparams.use_averages:
            moving_average_variables = tf.trainable_variables()
            moving_average_variables.extend(slim.losses.get_losses())
            moving_average_variables.append(total_loss)
            variable_averages = tf.train.ExponentialMovingAverage(
                hparams.moving_average_decay, global_step)
            # For sync_replicas, averaging happens in the chief queue runner
            if not hparams.sync_replicas:
                tf.add_to_collection(
                    tf.GraphKeys.UPDATE_OPS,
                    variable_averages.apply(moving_average_variables))
        else:
            variable_averages = None
            moving_average_variables = None

        # Decay learning rate exponentially
        learning_rate = tf.train.exponential_decay(
            hparams.learning_rate,
            global_step,
            hparams.decay_steps,
            hparams.learning_rate_decay_factor,
            staircase=True)
        tf.contrib.deprecated.scalar_summary('learning rate', learning_rate)

        # Create optimizer
        if hparams.optimizer == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate, epsilon=1e-3)
        elif hparams.optimizer == 'rmsprop':
            optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate,
                                                  decay=0.9,
                                                  momentum=0.9,
                                                  epsilon=1e-5)
        else:
            raise ValueError('Unknown optimizer %s' % hparams.optimizer)

        is_chief = FLAGS.task == 0
        chief_only_hooks = []

        hooks = [
            tf.train.LoggingTensorHook(
                {
                    'global_step': global_step,
                    'total_loss': total_loss
                },
                every_n_iter=FLAGS.log_every_n_iter),
            tf.train.NanTensorHook(total_loss),
            tf.train.StopAtStepHook(hparams.max_steps),
        ]

        if make_hooks is not None:
            hooks.extend(make_hooks())

        # If desired, optimize synchronously
        if hparams.sync_replicas:
            optimizer = tf.train.SyncReplicasOptimizer(
                optimizer=optimizer,
                replicas_to_aggregate=FLAGS.worker_replicas -
                hparams.backup_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                total_num_replicas=FLAGS.worker_replicas)
            sync_replicas_hook = optimizer.make_session_run_hook(is_chief)
            hooks.append(sync_replicas_hook)

        # Train
        train_tensor = slim.learning.create_train_op(
            total_loss,
            optimizer,
            clip_gradient_norm=hparams.gradient_clipping_norm)
        saver = tf.train.Saver(keep_checkpoint_every_n_hours=2)

        scaffold = tf.train.Scaffold(saver=saver)

        if FLAGS.save_summaries_secs > 0:
            save_summaries_secs = FLAGS.save_summaries_secs
            save_summaries_steps = None
        else:
            save_summaries_steps = FLAGS.save_summaries_steps
            save_summaries_secs = None
        with tf.train.MonitoredTrainingSession(
                master=FLAGS.master,
                is_chief=is_chief,
                hooks=hooks,
                chief_only_hooks=chief_only_hooks,
                checkpoint_dir=train_dir,
                scaffold=scaffold,
                save_checkpoint_secs=FLAGS.save_checkpoint_secs,
                save_summaries_secs=save_summaries_secs,
                save_summaries_steps=save_summaries_steps) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_tensor)
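
For context, a sketch of what a make_loss callable and the hyperparameters could look like. The loss is a toy regression loss and the HParams fields are inferred from how general_train() reads them above; the function also relies on module-level FLAGS and mode_dir() not shown here:

# Sketch only: toy loss and inferred hyperparameter fields.
import tensorflow as tf

def make_toy_loss():
    # Registers a trivial least-squares loss with tf.losses so that
    # tf.losses.get_total_loss() / slim.losses.get_losses() can see it.
    x = tf.random_normal([32, 10])
    w = tf.get_variable('w', [10, 1])
    y = tf.matmul(x, w)
    tf.losses.mean_squared_error(labels=tf.zeros_like(y), predictions=y)
    return tf.losses.get_total_loss()

toy_hparams = tf.contrib.training.HParams(
    seed=0,
    use_averages=False,
    moving_average_decay=0.999,
    sync_replicas=False,
    learning_rate=1e-3,
    decay_steps=10000,
    learning_rate_decay_factor=0.5,
    optimizer='adam',
    backup_replicas=0,
    gradient_clipping_norm=10.0,
    max_steps=1000)

# final_loss = general_train(make_toy_loss, toy_hparams)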
Example #3
    def resnet_model_fn(features, labels, mode, params):
        """Returns the model function."""
        global_step = tf.train.get_global_step()

        feature = features['feature']
        labels = labels['label']
        one_hot_labels = model_utils.get_label(labels,
                                               params,
                                               bird_num_classes,
                                               batch_size=params['batch_size'])

        def get_logits():
            """Return the logits."""
            end_points, aux_logits = None, None
            if FLAGS.model_type == 'resnet':
                avg_pool = model.resnet_v1_model(feature, labels, mode, params)
            else:
                raise ValueError('Unknown model_type: %s' % FLAGS.model_type)
            name = 'final_dense_dst'
            with tf.variable_scope('target_CLS'):
                logits = tf.layers.dense(
                    inputs=avg_pool,
                    units=bird_num_classes,
                    kernel_initializer=tf.random_normal_initializer(
                        stddev=.01),
                    name=name)
                if end_points is not None:
                    aux_pool = end_points['AuxLogits_Pool']
                    aux_logits = tf.layers.dense(
                        inputs=aux_pool,
                        units=bird_num_classes,
                        kernel_initializer=tf.random_normal_initializer(
                            stddev=.001),
                        name='Aux{}'.format(name))
            return logits, aux_logits, end_points

        logits, _, _ = get_logits()
        logits = tf.cast(logits, tf.float32)

        if FLAGS.model_type == 'resnet':
            dst_loss = tf.losses.softmax_cross_entropy(
                logits=logits,
                weights=1.,
                onehot_labels=one_hot_labels,
                label_smoothing=params['label_smoothing'])
            dst_l2_loss = FLAGS.weight_decay * tf.add_n([
                tf.nn.l2_loss(v) for v in tf.trainable_variables()
                if 'batch_normalization' not in v.name
            ])
            loss = dst_loss + dst_l2_loss

        train_op = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            cur_finetune_step = tf.train.get_global_step()
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                if FLAGS.model_type == 'resnet':
                    finetune_learning_rate = rampcosine()
                else:
                    finetune_learning_rate = rampcosine()
                if FLAGS.optimizer == 'momentum':
                    optimizer = tf.train.MomentumOptimizer(
                        learning_rate=finetune_learning_rate,
                        momentum=params['momentum'],
                        use_nesterov=True)
                elif FLAGS.optimizer == 'RMS':
                    optimizer = tf.train.RMSPropOptimizer(
                        finetune_learning_rate,
                        RMSPROP_DECAY,
                        momentum=RMSPROP_MOMENTUM,
                        epsilon=RMSPROP_EPSILON)
                elif FLAGS.optimizer == 'adam':
                    optimizer = tf.train.AdamOptimizer(finetune_learning_rate)

                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer,
                    replicas_to_aggregate=FLAGS.sync_replicas,
                    total_num_replicas=run_config.num_worker_replicas)
                train_op = tf.contrib.training.create_train_op(loss, optimizer)
                with tf.variable_scope('finetune'):
                    train_op = optimizer.minimize(loss, cur_finetune_step)
                if FLAGS.moving_average:
                    ema = tf.train.ExponentialMovingAverage(
                        decay=MOVING_AVERAGE_DECAY, num_updates=global_step)
                    variables_to_average = (tf.trainable_variables() +
                                            tf.moving_average_variables())
                    with tf.control_dependencies([train_op]):
                        with tf.name_scope('moving_average'):
                            train_op = ema.apply(variables_to_average)
        else:
            train_op = None

        batch_size = params['batch_size']  # pylint: disable=unused-variable
        eval_metrics = None
        if mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = model_utils.metric_fn(labels, logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            with tf.control_dependencies([train_op]):
                tf.summary.scalar('classifier/finetune_loss', loss)
                tf.summary.scalar('classifier/finetune_lr',
                                  finetune_learning_rate)
        else:
            train_op = None

        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
            eval_metric_ops=eval_metrics,
        )
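
This model_fn is presumably nested inside a builder that supplies FLAGS, run_config and the helper modules. A hedged sketch of how such a model_fn is typically handed to an Estimator, assuming the enclosing builder exposes it; the directory, params values and make_input_fn are placeholders:

# Sketch only: model_dir, params values and make_input_fn are placeholders.
import tensorflow as tf

classifier = tf.estimator.Estimator(
    model_fn=resnet_model_fn,
    model_dir='/tmp/resnet_finetune',   # hypothetical checkpoint directory
    params={
        'batch_size': 64,               # read by the model_fn above
        'label_smoothing': 0.1,
        'momentum': 0.9,
    })

# classifier.train(input_fn=make_input_fn(), max_steps=10000)  # make_input_fn is hypothetical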
Example #4
def main(unused_argv):
    # Create training directory if it doesn't already exist.
    if not tf.gfile.IsDirectory(FLAGS.train_dir):
        tf.logging.info("Creating training directory: %s", FLAGS.train_dir)
        tf.gfile.MakeDirs(FLAGS.train_dir)

    # Set up the model config.
    model_config = configuration.model_config(
        input_file_pattern=FLAGS.input_file_pattern)
    if FLAGS.model_config_overrides:
        model_config.parse_json(FLAGS.model_config_overrides)
    _log_config(model_config, "model_config")

    # Set up the training config.
    training_config = configuration.training_config()
    if FLAGS.training_config_overrides:
        training_config.parse_json(FLAGS.training_config_overrides)
    _log_config(training_config, "training_config")

    tf.logging.info("Building training graph.")
    g = tf.Graph()
    with g.as_default(), g.device(
            tf.train.replica_device_setter(FLAGS.ps_tasks)):
        # Build the model.
        model = skip_thoughts_model.SkipThoughtsModel(model_config,
                                                      mode="train")
        model.build()

        _log_variable_device_placement()

        hooks = [
            # Stop training if loss is NaN.
            tf.train.NanTensorHook(model.total_loss),
            # Log every training step.
            tf.train.LoggingTensorHook(
                {
                    "global_step": model.global_step,
                    "total_loss": model.total_loss
                },
                every_n_iter=1)
        ]

        # Set up the learning rate and optimizer.
        learning_rate = training.create_learning_rate(training_config,
                                                      model.global_step)
        optimizer = training.create_optimizer(training_config, learning_rate)

        # Set up distributed sync or async training.
        is_chief = (FLAGS.task == 0)
        if FLAGS.sync_replicas:
            optimizer = tf.train.SyncReplicasOptimizer(
                optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.total_num_replicas)
            hooks.append(optimizer.make_session_run_hook(is_chief))
        else:
            # Startup delay for non-chief asynchronous workers.
            if not is_chief and training_config.startup_delay_steps:
                hooks.append(
                    tf.train.GlobalStepWaiterHook(
                        training_config.startup_delay_steps))

        train_tensor = training.create_train_op(training_config, optimizer,
                                                model)
        keep_every_n = training_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            max_to_keep=training_config.max_checkpoints_to_keep,
            keep_checkpoint_every_n_hours=keep_every_n,
            save_relative_paths=True)
        scaffold = tf.train.Scaffold(saver=saver)

        # Possibly set a step limit.
        if training_config.number_of_steps:
            hooks.append(
                tf.train.StopAtStepHook(
                    last_step=training_config.number_of_steps))

        # Create the TensorFlow session.
        with tf.train.MonitoredTrainingSession(
                master=FLAGS.master,
                is_chief=is_chief,
                checkpoint_dir=FLAGS.train_dir,
                scaffold=scaffold,
                hooks=hooks,
                save_checkpoint_secs=training_config.save_model_secs,
                save_summaries_steps=None,
                save_summaries_secs=training_config.save_summaries_secs
        ) as sess:

            # Run training.
            while not sess.should_stop():
                sess.run(train_tensor)
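
The FLAGS read by main() would be declared at module level along these lines; the defaults and help strings here are placeholders mirroring how the flags are used above:

# Sketch only: flag names mirror the ones referenced in main(); defaults are placeholders.
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("input_file_pattern", None, "File pattern of sharded input TFRecord files.")
tf.app.flags.DEFINE_string("train_dir", None, "Directory for checkpoints and summaries.")
tf.app.flags.DEFINE_string("model_config_overrides", "", "JSON overrides for the model config.")
tf.app.flags.DEFINE_string("training_config_overrides", "", "JSON overrides for the training config.")
tf.app.flags.DEFINE_string("master", "", "BNS name of the TensorFlow master, or empty to run locally.")
tf.app.flags.DEFINE_integer("ps_tasks", 0, "Number of parameter server tasks.")
tf.app.flags.DEFINE_integer("task", 0, "Task id of this replica.")
tf.app.flags.DEFINE_boolean("sync_replicas", False, "Use synchronous replica training.")
tf.app.flags.DEFINE_integer("replicas_to_aggregate", 1, "Gradients to aggregate before applying updates.")
tf.app.flags.DEFINE_integer("total_num_replicas", 1, "Total number of worker replicas.")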
Example #5
def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir):
    """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of worker replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
  """

    detection_model = create_model_fn()  # Build the detection model object.
    # For SSD the augmentation is typically ssd_random_crop; in the Faster
    # R-CNN config file it is random_horizontal_flip.
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():  # A default graph is needed to build the model.
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        # The global step keeps track of how many training steps have run.
        with tf.device(deploy_config.variables_device()):  # Usually the parameter server / CPU.
            global_step = slim.create_global_step()  # Create the global step tensor.


        # The following creates an input queue of images, boxes and targets.
        with tf.device(deploy_config.inputs_device()):  # Device used to build the inputs (typically the CPU).
            input_queue = _create_input_queue(
                train_config.batch_size // num_clones,  # Per-clone batch size.
                create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)  # e.g. random_horizontal_flip.

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))  # Collect summaries already in the graph.
        global_summaries = set([])
        # Create the losses.
        model_fn = functools.partial(
            _create_losses,  # Builds the losses; it takes the model factory as an argument.
            create_model_fn=create_model_fn)
        clones = model_deploy.create_clones(
            deploy_config, model_fn,
            [input_queue])  # Create one clone per device from model_fn and the input queue.
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer = optimizer_builder.build(
                train_config.optimizer,
                global_summaries)  # Returns the configured optimizer, e.g. RMSProp or Adam.

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(  # Synchronize gradient aggregation across replicas.
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:  # Path to the fine-tune checkpoint file.
            init_fn = detection_model.restore_fn(  # Returns an initializer fn that restores the
                train_config.fine_tune_checkpoint,  # feature-extractor weights from the checkpoint.
                from_detection_checkpoint=train_config.from_detection_checkpoint)

        with tf.device(deploy_config.optimizer_device()):
            # optimize_clones returns the total loss and the (gradient, variable) pairs.
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer, regularization_losses=None)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:  # Skipped when no bias gradient multiplier is configured.
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                # Useful for inspecting the (gradient, variable) tuples.
                print("Printing grads_and_vars to check the tuples")
                print(grads_and_vars)
                # Returns the gradients and variables for everything except the frozen variables.
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)
            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars,  # Apply the gradients to the variables.
                global_step=global_step)
            update_ops.append(grad_updates)  # Include the gradient update in the update ops.

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows ops without a GPU implementation to be placed on the CPU.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(  # Saver used to write the checkpoints.
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        slim.learning.train(  # Run the full training loop with slim's helper.
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
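
For intuition only, a rough sketch of what freezing gradients by regex amounts to; this is an illustration, not the actual variables_helper.freeze_gradients_matching_regex implementation:

# Illustration only: drop (gradient, variable) pairs whose variable name
# matches any of the given patterns, so those variables are never updated.
import re

def drop_gradients_matching_regex(grads_and_vars, regex_list):
    kept = []
    for grad, var in grads_and_vars:
        if any(re.search(pattern, var.op.name) for pattern in regex_list):
            continue  # matched variable stays frozen
        kept.append((grad, var))
    return kept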
Example #6
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          num_examples,
          total_configs,
          model_config,
          is_first_training=True):
    """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of worker replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    num_examples: The number of examples in the training dataset.
    total_configs: List of configs used to generate the config summaries.
    model_config: Model configuration (including the mtl options used above).
    is_first_training: Whether this is the first training run. If False, the
      global step is initialized from the step encoded in
      train_config.fine_tune_checkpoint.
  """

    detection_model = create_model_fn()
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            if is_first_training:
                global_step = slim.create_global_step()
            else:
                prev_global_step = int(
                    train_config.fine_tune_checkpoint.split('-')[-1])
                global_step = variable_scope.get_variable(
                    ops.GraphKeys.GLOBAL_STEP,
                    dtype=dtypes.int64,
                    initializer=tf.constant(prev_global_step,
                                            dtype=dtypes.int64),
                    trainable=False,
                    collections=[
                        ops.GraphKeys.GLOBAL_VARIABLES,
                        ops.GraphKeys.GLOBAL_STEP
                    ])

        with tf.device(deploy_config.inputs_device()):
            input_queue = _create_input_queue(
                train_config.batch_size // num_clones,
                create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options,
                ignore_options=train_config.ignore_options,
                mtl_window=model_config.mtl.window,
                mtl_edgemask=model_config.mtl.edgemask)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        kwargs = {}
        kwargs['mtl'] = model_config.mtl

        update_schedule = None
        model_fn = functools.partial(
            _create_losses,
            create_model_fn=create_model_fn,
            show_image_summary=train_config.show_image_summary,
            update_schedule=update_schedule,
            **kwargs)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope
        with tf.device(deploy_config.optimizer_device()):
            training_optimizer = optimizer_builder.build(
                train_config.optimizer, global_summaries)

        sync_optimizer = None
        if train_config.sync_replicas:
            # TODO: support synchronous updates for manual loss updates
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            var_map = detection_model.restore_map(
                from_detection_checkpoint=train_config.from_detection_checkpoint,
                restore_box_predictor=train_config.restore_box_predictor,
                restore_window=train_config.restore_window,
                restore_edgemask=train_config.restore_edgemask,
                restore_closeness=train_config.restore_closeness,
                restore_mtl_refine=train_config.restore_mtl_refine,
            )
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map, train_config.fine_tune_checkpoint))
            init_saver = tf.train.Saver(available_var_map)

            mtl = model_config.mtl
            mtl_init_saver_list = []

            def _get_mtl_init_saver(scope_name):
                _var_map = detection_model._feature_extractor.mtl_restore_from_classification_checkpoint_fn(
                    scope_name)
                if train_config.from_detection_checkpoint:
                    _var_map_new = dict()
                    for name, val in _var_map.items():
                        _var_map_new[detection_model.
                                     second_stage_feature_extractor_scope +
                                     '/' + name] = val
                    _var_map = _var_map_new
                _available_var_map = (
                    variables_helper.get_variables_available_in_checkpoint(
                        _var_map, train_config.fine_tune_checkpoint))
                if _available_var_map:
                    return tf.train.Saver(_available_var_map)
                else:
                    return None

            # if mtl.share_second_stage_init and mtl.shared_feature == 'proposal_feature_maps':
            if mtl.share_second_stage_init and not train_config.from_detection_checkpoint:
                if mtl.window:
                    mtl_init_saver_list.append(
                        _get_mtl_init_saver(
                            detection_model.window_box_predictor_scope))
                if mtl.closeness:
                    mtl_init_saver_list.append(
                        _get_mtl_init_saver(
                            detection_model.closeness_box_predictor_scope))
                if mtl.edgemask:
                    mtl_init_saver_list.append(
                        _get_mtl_init_saver(
                            detection_model.edgemask_predictor_scope))

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)
                for mtl_init_saver in mtl_init_saver_list:
                    if mtl_init_saver is not None:
                        mtl_init_saver.restore(
                            sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        def _get_trainable_variables(except_scopes=None):
            trainable_variables = tf.trainable_variables()
            if except_scopes is None:
                return trainable_variables
            for var in tf.trainable_variables():
                if any([scope in var.name for scope in except_scopes]):
                    trainable_variables.remove(var)
            return trainable_variables

        def _get_update_ops(except_scopes=None):
            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)
            if except_scopes is None:
                return update_ops
            for var in tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                         first_clone_scope):
                if any([scope in var.name for scope in except_scopes]):
                    update_ops.remove(var)
            return update_ops

        with tf.device(deploy_config.optimizer_device()):

            def _single_update():
                kwargs = {}
                _training_optimizer = training_optimizer
                kwargs['var_list'] = None
                update_ops = _get_update_ops()
                total_loss, grads_and_vars = model_deploy.optimize_clones(
                    clones,
                    _training_optimizer,
                    regularization_losses=None,
                    **kwargs)

                # Optionally multiply gradients by train_config.{grad_multiplier,
                # divide_grad_by_batch}.
                if train_config.grad_multiplier or train_config.divide_grad_by_batch:
                    base_multiplier = train_config.grad_multiplier \
                        if train_config.grad_multiplier else 1.0
                    batch_divider = float(train_config.batch_size) \
                        if train_config.divide_grad_by_batch else 1.0
                    total_multiplier = base_multiplier / batch_divider
                    grads_and_vars = variables_helper.multiply_gradients_by_scalar_multiplier(
                        grads_and_vars, multiplier=total_multiplier)

                # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
                if train_config.bias_grad_multiplier:
                    biases_regex_list = ['.*/biases']
                    grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                        grads_and_vars,
                        biases_regex_list,
                        multiplier=train_config.bias_grad_multiplier)

                # Optionally freeze some layers by setting their gradients to be zero.
                if train_config.freeze_variables:
                    grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                        grads_and_vars, train_config.freeze_variables)

                # Optionally clip gradients
                if train_config.gradient_clipping_by_norm > 0:
                    with tf.name_scope('clip_grads'):
                        grads_and_vars = slim.learning.clip_gradient_norms(
                            grads_and_vars,
                            train_config.gradient_clipping_by_norm)

                # Create gradient updates.
                grad_updates = _training_optimizer.apply_gradients(
                    grads_and_vars, global_step=global_step)
                # update_ops.append(grad_updates)
                total_update_ops = update_ops + [grad_updates]

                update_op = tf.group(*total_update_ops)
                with tf.control_dependencies([update_op]):
                    train_tensor = tf.identity(total_loss, name='train_op')
                return train_tensor

            train_tensor = _single_update()

        # Add summaries.
        def _get_total_loss_with_collection(collection,
                                            add_regularization_losses=True,
                                            name="total_loss"):
            losses = tf.losses.get_losses(loss_collection=collection)
            if add_regularization_losses:
                losses += tf.losses.get_regularization_losses()
            return math_ops.add_n(losses, name=name)

        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # not contained in global_summaries
        config_summary_list = select_config_summary_list(total_configs,
                                                         as_matrix=False)

        # Soft placement allows ops without a GPU implementation to be placed on the CPU.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        custom_learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            global_step=(None if is_first_training else global_step),
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            log_every_n_steps=(train_config.log_every_n_steps
                               if train_config.log_every_n_steps else None),
            save_summaries_secs=train_config.save_summaries_secs,
            save_interval_secs=train_config.save_interval_secs,
            sync_optimizer=sync_optimizer,
            saver=saver,
            batch_size=train_config.batch_size,
            num_examples=num_examples,
            config_summary_list=config_summary_list)
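
When is_first_training is False, the previous global step is recovered from the checkpoint filename; a small standalone sketch of that parsing convention:

# Standalone sketch: TensorFlow checkpoints are conventionally named
# '<prefix>-<global_step>', so the step can be parsed from the path.
def parse_global_step(checkpoint_path):
    return int(checkpoint_path.split('-')[-1])

assert parse_global_step('/data/model_dir/model.ckpt-123456') == 123456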