import tensorflow as tf
from tensorflow.contrib import framework as contrib_framework


def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=True):
  """Gets the function initializing model variables from a checkpoint.

  Args:
    train_logdir: Log directory for training.
    tf_initial_checkpoint: TensorFlow checkpoint for initialization.
    initialize_last_layer: Initialize last layer or not.
    last_layers: Last layers of the model.
    ignore_missing_vars: Ignore missing variables in the checkpoint.

  Returns:
    Initialization function.
  """
  if tf_initial_checkpoint is None:
    tf.logging.info('Not initializing the model from a checkpoint.')
    return None

  if tf.train.latest_checkpoint(train_logdir):
    tf.logging.info('Ignoring initialization; other checkpoint exists')
    return None

  tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

  # Variables that will not be restored. Alternative exclude lists for other
  # backbones are kept below as comments.

  # MobilenetV3_Large:
  # exclude_list = ['global_step']

  # MobilenetV3_Small:
  # exclude_list = ['global_step', 'image_pooling/', 'aspp0/',
  #                 'decoder/feature_projection0',
  #                 'MobilenetV3/expanded_conv/squeeze_excite/']

  exclude_list = [
      'global_step',
      'decoder/feature_projection0',
      'decoder/decoder_conv0_depthwise',
      'decoder/decoder_conv0_pointwise',
  ]
  if not initialize_last_layer:
    exclude_list.extend(last_layers)

  variables_to_restore = contrib_framework.get_variables_to_restore(
      exclude=exclude_list)

  if variables_to_restore:
    init_op, init_feed_dict = contrib_framework.assign_from_checkpoint(
        tf_initial_checkpoint,
        variables_to_restore,
        ignore_missing_vars=ignore_missing_vars)
    global_step = tf.train.get_or_create_global_step()

    def restore_fn(sess):
      sess.run(init_op, init_feed_dict)
      sess.run([global_step])

    return restore_fn

  return None
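The restore_fn returned above takes a single session argument, which is the init_fn signature slim's training loop expects. Below is a minimal, hypothetical usage sketch; the log directory, checkpoint path, and the 'logits' scope are illustrative placeholders, not values taken from the example.

import tensorflow as tf

slim = tf.contrib.slim

# Illustrative placeholder paths; point these at a real log dir / checkpoint.
train_logdir = '/tmp/train_logs'
initial_checkpoint = '/tmp/model.ckpt'

# Tiny stand-in graph so the sketch is self-contained.
inputs = tf.random_normal([8, 4])
logits = slim.fully_connected(inputs, 2, scope='logits')
loss = tf.reduce_mean(tf.square(logits))
train_op = slim.learning.create_train_op(
    loss, tf.train.GradientDescentOptimizer(0.01))

init_fn = get_model_init_fn(
    train_logdir,
    initial_checkpoint,
    initialize_last_layer=False,
    last_layers=['logits'],  # assumed scope of the classifier head
    ignore_missing_vars=True)

# slim calls init_fn(sess) once, after variable initialization.
slim.learning.train(
    train_op,
    logdir=train_logdir,
    init_fn=init_fn,
    number_of_steps=100)
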
Example #2
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=False):
    """Gets the function initializing model variables from a checkpoint.

  Args:
    train_logdir: Log directory for training.
    tf_initial_checkpoint: TensorFlow checkpoint for initialization.
    initialize_last_layer: Initialize last layer or not.
    last_layers: Last layers of the model.
    ignore_missing_vars: Ignore missing variables in the checkpoint.

  Returns:
    Initialization function.
  """
    if tf_initial_checkpoint is None:
        tf.compat.v1.logging.info(
            'Not initializing the model from a checkpoint.')
        return None

    if tf.train.latest_checkpoint(train_logdir):
        tf.compat.v1.logging.info(
            'Ignoring initialization; other checkpoint exists')
        return None

    tf.compat.v1.logging.info('Initializing model from path: %s',
                              tf_initial_checkpoint)

    # Variables that will not be restored.
    exclude_list = ['global_step', 'logits']
    if not initialize_last_layer:
        exclude_list.extend(last_layers)

    variables_to_restore = contrib_framework.get_variables_to_restore(
        exclude=exclude_list)

    if variables_to_restore:
        init_op, init_feed_dict = contrib_framework.assign_from_checkpoint(
            tf_initial_checkpoint,
            variables_to_restore,
            ignore_missing_vars=ignore_missing_vars)
        global_step = tf.compat.v1.train.get_or_create_global_step()

        def restore_fn(sess):
            sess.run(init_op, init_feed_dict)
            sess.run([global_step])

        return restore_fn

    return None
Example #3
def train(train_dir,
          config,
          dataset_fn,
          checkpoints_to_keep=5,
          keep_checkpoint_every_n_hours=1,
          num_steps=None,
          master='',
          num_sync_workers=0,
          num_ps_tasks=0,
          task=0):
    """Train loop."""
    tf.gfile.MakeDirs(train_dir)
    is_chief = (task == 0)
    if is_chief:
        _trial_summary(config.hparams, config.train_examples_path
                       or config.tfds_name, train_dir)

    with tf.Graph().as_default():
        with tf.device(
                tf.train.replica_device_setter(num_ps_tasks,
                                               merge_devices=True)):
            model = config.model
            model.build(config.hparams,
                        config.data_converter.output_depth,
                        encoder_train=config.encoder_train,
                        decoder_train=config.decoder_train)
            optimizer = model.train(**_get_input_tensors(dataset_fn(), config))
            restored_vars = _get_restore_vars(config.var_train_pattern)
            _set_trainable_vars(config.var_train_pattern)

            hooks = []
            if num_sync_workers:
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer, num_sync_workers)
                hooks.append(optimizer.make_session_run_hook(is_chief))

            grads, var_list = zip(*optimizer.compute_gradients(model.loss))
            global_norm = tf.global_norm(grads)
            tf.summary.scalar('global_norm', global_norm)

            if config.hparams.clip_mode == 'value':
                g = config.hparams.grad_clip
                clipped_grads = [
                    tf.clip_by_value(grad, -g, g) for grad in grads
                ]
            elif config.hparams.clip_mode == 'global_norm':
                clipped_grads = tf.cond(
                    global_norm < config.hparams.grad_norm_clip_to_zero,
                    lambda: tf.clip_by_global_norm(grads,
                                                   config.hparams.grad_clip,
                                                   use_norm=global_norm)[0],
                    lambda: [tf.zeros(tf.shape(g)) for g in grads])
            else:
                raise ValueError('Unknown clip_mode: {}'.format(
                    config.hparams.clip_mode))
            train_op = optimizer.apply_gradients(zip(clipped_grads, var_list),
                                                 global_step=model.global_step,
                                                 name='train_step')

            logging_dict = {
                'global_step': model.global_step,
                'loss': model.loss
            }

            hooks.append(
                tf.train.LoggingTensorHook(logging_dict, every_n_iter=5))
            if num_steps:
                hooks.append(tf.train.StopAtStepHook(last_step=num_steps))

            variables_to_restore = contrib_framework.get_variables_to_restore(
                include=[v.name for v in restored_vars])
            init_assign_op, init_feed_dict = contrib_framework.assign_from_checkpoint(
                config.pretrained_path, variables_to_restore)

            def InitAssignFn(scaffold, sess):
                sess.run(init_assign_op, init_feed_dict)

            scaffold = tf.train.Scaffold(
                init_fn=InitAssignFn,
                saver=tf.train.Saver(
                    max_to_keep=checkpoints_to_keep,
                    keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
                ))
            contrib_training.train(train_op=train_op,
                                   logdir=train_dir,
                                   scaffold=scaffold,
                                   hooks=hooks,
                                   save_checkpoint_secs=60,
                                   master=master,
                                   is_chief=is_chief)
Example #4
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_first_layer,
                      initialize_last_layer,
                      last_layers=None,
                      restore_adam=False,
                      ignore_missing_vars=False):
    """Gets the function initializing model variables from a checkpoint.
  Args:
    train_logdir: Log directory for training.
    tf_initial_checkpoint: TensorFlow checkpoint for initialization.
    initialize_last_layer: Initialize first layer or not.
    initialize_last_layer: Initialize last layer or not.
    last_layers: Last layers of the model.
    restore_adam: Restore Adam optimization parameters or not.
    ignore_missing_vars: Ignore missing variables in the checkpoint.
  Returns:
    Initialization function.
  """
    if tf_initial_checkpoint is None:
        tf.logging.info('Not initializing the model from a checkpoint.')
        return None

    if tf.train.latest_checkpoint(train_logdir):
        tf.logging.info('Ignoring initialization; other checkpoint exists')
        return None

    tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

    # Variables that will not be restored.
    exclude_list = ['global_step']
    if not initialize_last_layer:
        exclude_list.extend(last_layers)

    if not initialize_first_layer:
        exclude_list.append('resnet_v1_50/conv1_1/weights:0')

    variables_to_restore = contrib_framework.get_variables_to_restore(
        exclude=exclude_list)

    # Optionally drop the Adam optimizer slot variables from the restore list.
    if not restore_adam:
        variables_to_restore = [
            v for v in variables_to_restore if 'Adam' not in v.name
        ]

    if variables_to_restore:
        init_op, init_feed_dict = contrib_framework.assign_from_checkpoint(
            tf_initial_checkpoint,
            variables_to_restore,
            ignore_missing_vars=ignore_missing_vars)

        global_step = tf.train.get_or_create_global_step()

        def restore_fn(scaffold, sess):
            sess.run(init_op, init_feed_dict)
            sess.run([global_step])

        return restore_fn

    return None
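Unlike the earlier examples, this variant's restore_fn takes (scaffold, sess), which is the signature tf.train.Scaffold expects for its init_fn. A hypothetical wiring sketch follows; the paths and the 'logits' scope are illustrative placeholders, not values taken from the example.

import tensorflow as tf

# Illustrative placeholder paths; point these at a real log dir / checkpoint.
train_logdir = '/tmp/train_logs'
initial_checkpoint = '/tmp/model.ckpt'

# Minimal stand-in variables so there is something to save and (potentially)
# restore.
weights = tf.get_variable('logits/weights', shape=[4, 2])
global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)

init_fn = get_model_init_fn(
    train_logdir,
    initial_checkpoint,
    initialize_first_layer=True,
    initialize_last_layer=False,
    last_layers=['logits'],  # assumed scope of the classifier head
    restore_adam=False,
    ignore_missing_vars=True)

# Scaffold invokes init_fn(scaffold, sess) after running the init ops.
scaffold = tf.train.Scaffold(init_fn=init_fn)
with tf.train.MonitoredTrainingSession(
        checkpoint_dir=train_logdir, scaffold=scaffold) as sess:
    sess.run(train_op)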