def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=True):
  """Gets the function initializing model variables from a checkpoint.

  Args:
    train_logdir: Log directory for training.
    tf_initial_checkpoint: TensorFlow checkpoint for initialization.
    initialize_last_layer: Initialize last layer or not.
    last_layers: Last layers of the model.
    ignore_missing_vars: Ignore missing variables in the checkpoint.

  Returns:
    Initialization function.
  """
  if tf_initial_checkpoint is None:
    tf.logging.info('Not initializing the model from a checkpoint.')
    return None

  if tf.train.latest_checkpoint(train_logdir):
    tf.logging.info('Ignoring initialization; other checkpoint exists')
    return None

  tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

  # Variables that will not be restored.
  # MobilenetV3_Large:
  # exclude_list = ['global_step']
  # MobilenetV3_Small:
  # exclude_list = ['global_step', 'image_pooling/', 'aspp0/',
  #                 'decoder/feature_projection0',
  #                 'MobilenetV3/expanded_conv/squeeze_excite/']
  exclude_list = ['global_step',
                  'decoder/feature_projection0',
                  'decoder/decoder_conv0_depthwise',
                  'decoder/decoder_conv0_pointwise']
  if not initialize_last_layer:
    exclude_list.extend(last_layers)

  variables_to_restore = contrib_framework.get_variables_to_restore(
      exclude=exclude_list)

  if variables_to_restore:
    init_op, init_feed_dict = contrib_framework.assign_from_checkpoint(
        tf_initial_checkpoint,
        variables_to_restore,
        ignore_missing_vars=ignore_missing_vars)
    global_step = tf.train.get_or_create_global_step()

    def restore_fn(sess):
      sess.run(init_op, init_feed_dict)
      sess.run([global_step])

    return restore_fn

  return None
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=False):
  """Gets the function initializing model variables from a checkpoint.

  Args:
    train_logdir: Log directory for training.
    tf_initial_checkpoint: TensorFlow checkpoint for initialization.
    initialize_last_layer: Initialize last layer or not.
    last_layers: Last layers of the model.
    ignore_missing_vars: Ignore missing variables in the checkpoint.

  Returns:
    Initialization function.
  """
  if tf_initial_checkpoint is None:
    tf.compat.v1.logging.info('Not initializing the model from a checkpoint.')
    return None

  if tf.train.latest_checkpoint(train_logdir):
    tf.compat.v1.logging.info('Ignoring initialization; other checkpoint exists')
    return None

  tf.compat.v1.logging.info('Initializing model from path: %s',
                            tf_initial_checkpoint)

  # Variables that will not be restored.
  exclude_list = ['global_step', 'logits']
  if not initialize_last_layer:
    exclude_list.extend(last_layers)

  variables_to_restore = contrib_framework.get_variables_to_restore(
      exclude=exclude_list)

  if variables_to_restore:
    init_op, init_feed_dict = contrib_framework.assign_from_checkpoint(
        tf_initial_checkpoint,
        variables_to_restore,
        ignore_missing_vars=ignore_missing_vars)
    global_step = tf.compat.v1.train.get_or_create_global_step()

    def restore_fn(sess):
      sess.run(init_op, init_feed_dict)
      sess.run([global_step])

    return restore_fn

  return None
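# Hypothetical usage sketch (not part of the original code): the log
# directory, checkpoint path, and last-layer scope below are placeholders.
# The restore_fn returned by the two variants above takes a single `sess`
# argument, so it can be run once after variables have been created and
# before training starts.
import tensorflow as tf  # assumes TF 1.x-style graph-mode APIs

init_fn = get_model_init_fn(
    train_logdir='/tmp/train_logs',                # hypothetical log dir
    tf_initial_checkpoint='/tmp/init/model.ckpt',  # hypothetical checkpoint
    initialize_last_layer=False,
    last_layers=['logits'],                        # hypothetical scope name
    ignore_missing_vars=True)

with tf.compat.v1.Session() as sess:
  sess.run(tf.compat.v1.global_variables_initializer())
  if init_fn is not None:
    init_fn(sess)  # restores the non-excluded variables from the checkpoint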
def train(train_dir,
          config,
          dataset_fn,
          checkpoints_to_keep=5,
          keep_checkpoint_every_n_hours=1,
          num_steps=None,
          master='',
          num_sync_workers=0,
          num_ps_tasks=0,
          task=0):
  """Train loop."""
  tf.gfile.MakeDirs(train_dir)
  is_chief = (task == 0)
  if is_chief:
    _trial_summary(config.hparams,
                   config.train_examples_path or config.tfds_name,
                   train_dir)
  with tf.Graph().as_default():
    with tf.device(
        tf.train.replica_device_setter(num_ps_tasks, merge_devices=True)):
      model = config.model
      model.build(config.hparams,
                  config.data_converter.output_depth,
                  encoder_train=config.encoder_train,
                  decoder_train=config.decoder_train)

      optimizer = model.train(**_get_input_tensors(dataset_fn(), config))

      restored_vars = _get_restore_vars(config.var_train_pattern)
      _set_trainable_vars(config.var_train_pattern)

      hooks = []
      if num_sync_workers:
        optimizer = tf.train.SyncReplicasOptimizer(optimizer, num_sync_workers)
        hooks.append(optimizer.make_session_run_hook(is_chief))

      grads, var_list = zip(*optimizer.compute_gradients(model.loss))
      global_norm = tf.global_norm(grads)
      tf.summary.scalar('global_norm', global_norm)

      if config.hparams.clip_mode == 'value':
        g = config.hparams.grad_clip
        clipped_grads = [tf.clip_by_value(grad, -g, g) for grad in grads]
      elif config.hparams.clip_mode == 'global_norm':
        clipped_grads = tf.cond(
            global_norm < config.hparams.grad_norm_clip_to_zero,
            lambda: tf.clip_by_global_norm(
                grads, config.hparams.grad_clip, use_norm=global_norm)[0],
            lambda: [tf.zeros(tf.shape(g)) for g in grads])
      else:
        raise ValueError('Unknown clip_mode: {}'.format(
            config.hparams.clip_mode))

      train_op = optimizer.apply_gradients(zip(clipped_grads, var_list),
                                           global_step=model.global_step,
                                           name='train_step')

      logging_dict = {'global_step': model.global_step, 'loss': model.loss}
      hooks.append(tf.train.LoggingTensorHook(logging_dict, every_n_iter=5))
      if num_steps:
        hooks.append(tf.train.StopAtStepHook(last_step=num_steps))

      variables_to_restore = contrib_framework.get_variables_to_restore(
          include=[v.name for v in restored_vars])
      init_assign_op, init_feed_dict = contrib_framework.assign_from_checkpoint(
          config.pretrained_path, variables_to_restore)

      def InitAssignFn(scaffold, sess):
        sess.run(init_assign_op, init_feed_dict)

      scaffold = tf.train.Scaffold(
          init_fn=InitAssignFn,
          saver=tf.train.Saver(
              max_to_keep=checkpoints_to_keep,
              keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours))

      contrib_training.train(train_op=train_op,
                             logdir=train_dir,
                             scaffold=scaffold,
                             hooks=hooks,
                             save_checkpoint_secs=60,
                             master=master,
                             is_chief=is_chief)
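# The helpers _get_restore_vars and _set_trainable_vars used above are not
# shown in this listing. A minimal sketch of what they might do, assuming
# `config.var_train_pattern` is a regular expression over variable names, is
# given below; this is an illustration, not the original implementation.
import re

import tensorflow as tf


def _get_restore_vars(var_train_pattern):
  # Restore every global variable except those matched by the train pattern,
  # on the assumption that freshly trained variables are absent from the
  # pretrained checkpoint.
  return [v for v in tf.compat.v1.global_variables()
          if not re.search(var_train_pattern, v.name)]


def _set_trainable_vars(var_train_pattern):
  # Rebuild the TRAINABLE_VARIABLES collection so that the optimizer only
  # updates variables whose names match the pattern.
  trainable = [v for v in tf.compat.v1.trainable_variables()
               if re.search(var_train_pattern, v.name)]
  collection = tf.compat.v1.get_collection_ref(
      tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)
  del collection[:]
  collection.extend(trainable)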
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_first_layer,
                      initialize_last_layer,
                      last_layers=None,
                      restore_adam=False,
                      ignore_missing_vars=False):
  """Gets the function initializing model variables from a checkpoint.

  Args:
    train_logdir: Log directory for training.
    tf_initial_checkpoint: TensorFlow checkpoint for initialization.
    initialize_first_layer: Initialize first layer or not.
    initialize_last_layer: Initialize last layer or not.
    last_layers: Last layers of the model.
    restore_adam: Restore Adam optimization parameters or not.
    ignore_missing_vars: Ignore missing variables in the checkpoint.

  Returns:
    Initialization function.
  """
  if tf_initial_checkpoint is None:
    tf.logging.info('Not initializing the model from a checkpoint.')
    return None

  if tf.train.latest_checkpoint(train_logdir):
    tf.logging.info('Ignoring initialization; other checkpoint exists')
    return None

  tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

  # Variables that will not be restored.
  exclude_list = ['global_step']
  if not initialize_last_layer:
    exclude_list.extend(last_layers)
  if not initialize_first_layer:
    exclude_list.append('resnet_v1_50/conv1_1/weights:0')

  variables_to_restore = contrib_framework.get_variables_to_restore(
      exclude=exclude_list)

  # Optionally drop the Adam slot variables so only model weights are restored.
  if not restore_adam:
    variables_to_restore = [
        v for v in variables_to_restore if 'Adam' not in v.name
    ]

  if variables_to_restore:
    init_op, init_feed_dict = contrib_framework.assign_from_checkpoint(
        tf_initial_checkpoint,
        variables_to_restore,
        ignore_missing_vars=ignore_missing_vars)
    global_step = tf.train.get_or_create_global_step()

    def restore_fn(scaffold, sess):
      sess.run(init_op, init_feed_dict)
      sess.run([global_step])

    return restore_fn

  return None
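# Hypothetical wiring (not from the original code): because this variant's
# restore_fn takes (scaffold, sess), it matches the init_fn signature of
# tf.compat.v1.train.Scaffold. Paths and layer names below are placeholders.
import tensorflow as tf  # assumes TF 1.x-style graph-mode APIs

init_fn = get_model_init_fn(
    train_logdir='/tmp/train_logs',                  # hypothetical log dir
    tf_initial_checkpoint='/tmp/resnet/model.ckpt',  # hypothetical checkpoint
    initialize_first_layer=False,
    initialize_last_layer=False,
    last_layers=['logits'],                          # hypothetical scope name
    restore_adam=False)

scaffold = tf.compat.v1.train.Scaffold(init_fn=init_fn)
with tf.compat.v1.train.MonitoredTrainingSession(
    checkpoint_dir='/tmp/train_logs', scaffold=scaffold) as sess:
  pass  # the training loop would run here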