# NOTE: both `main` functions below target the TF 1.x API surface
# (tf.contrib, tf.logging). The imports cover the symbols used directly in
# this excerpt; FLAGS, resnet_params, data_input, train_model_fn,
# get_src_num_classes, get_model_fn, and the `estimator` module are assumed
# to be defined or imported elsewhere in the repository.
import re

import tensorflow as tf  # TensorFlow 1.x
from tensorflow.core.protobuf import rewriter_config_pb2


def main(unused_argv):
  params = resnet_params.from_file(FLAGS.param_file)
  params = resnet_params.override(params, FLAGS.param_overrides)
  resnet_params.log_hparams_to_model_dir(params, FLAGS.model_dir)
  tf.logging.info('Model params: {}'.format(params))

  if params['use_async_checkpointing']:
    save_checkpoints_steps = None
  else:
    save_checkpoints_steps = FLAGS.save_checkpoints_steps

  # TO BE MODIFIED: `cluster` should point at a TPUClusterResolver (or be
  # left unset for local runs).
  config = tf.contrib.tpu.RunConfig(
      cluster='',
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=save_checkpoints_steps,
      keep_checkpoint_max=50,
      log_step_count_steps=FLAGS.log_step_count_steps,
      session_config=tf.ConfigProto(
          graph_options=tf.GraphOptions(
              rewrite_options=rewriter_config_pb2.RewriterConfig(
                  disable_meta_optimizer=True))))

  if FLAGS.warm_start_ckpt_path:
    # Warm-start every tensor in the checkpoint except optimizer slots and
    # step counters.
    var_names = []
    checkpoint_path = FLAGS.warm_start_ckpt_path
    reader = tf.train.NewCheckpointReader(checkpoint_path)
    extra_str = ''
    keep_str = 'Momentum|global_step|finetune_global_step'
    for key in reader.get_variable_to_shape_map():
      if not re.findall('({}{})'.format(keep_str, extra_str), key):
        var_names.append(key)

    tf.logging.info('Warm-starting tensors: %s', sorted(var_names))

    vars_to_warm_start = var_names
    warm_start_settings = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=checkpoint_path,
        vars_to_warm_start=vars_to_warm_start)
  else:
    warm_start_settings = None

  # TO BE MODIFIED
  # NOTE: train_batch_size / eval_batch_size are TPUEstimator-only arguments
  # and are not accepted by tf.estimator.Estimator; the per-split batch sizes
  # are derived below from params['batch_size'] and the *_multiplier flags.
  resnet_classifier = tf.estimator.Estimator(
      model_fn=train_model_fn,
      config=config,
      params=params,
      warm_start_from=warm_start_settings)

  use_bfloat16 = params['precision'] == 'bfloat16'
  num_classes = get_src_num_classes()

  def make_input_dataset(params):
    """Returns the zipped source / fine-tune / target input dataset."""

    def _merge_datasets(train_batch, finetune_batch, target_batch):
      """Merges the three splits into one (features, labels) pair of dicts."""
      train_features, train_labels = train_batch
      finetune_features, finetune_labels = finetune_batch
      target_features, target_labels = target_batch
      features = {
          'src': train_features,
          'finetune': finetune_features,
          'target': target_features
      }
      labels = {
          'src': train_labels,
          'finetune': finetune_labels,
          'target': target_labels
      }
      return (features, labels)

    # TO BE MODIFIED: point data_dir at the input data directory.
    data_dir = ''
    num_parallel_calls = 8

    # Source-dataset training split.
    src_train = data_input.ImageNetInput(
        dataset_name=FLAGS.source_dataset,
        is_training=True,
        data_dir=data_dir,
        transpose_input=params['transpose_input'],
        cache=False,
        image_size=params['image_size'],
        num_parallel_calls=num_parallel_calls,
        use_bfloat16=use_bfloat16,
        num_classes=num_classes)
    # Target-dataset 'l2l_train' split.
    finetune_dataset = data_input.ImageNetInput(
        dataset_name=FLAGS.target_dataset,
        task_id=1,
        is_training=True,
        data_dir=data_dir,
        dataset_split='l2l_train',
        transpose_input=params['transpose_input'],
        cache=False,
        image_size=params['image_size'],
        num_parallel_calls=num_parallel_calls,
        use_bfloat16=use_bfloat16)
    # Target-dataset 'l2l_valid' split.
    target_dataset = data_input.ImageNetInput(
        dataset_name=FLAGS.target_dataset,
        task_id=2,
        is_training=True,
        data_dir=data_dir,
        dataset_split='l2l_valid',
        transpose_input=params['transpose_input'],
        cache=False,
        image_size=params['image_size'],
        num_parallel_calls=num_parallel_calls,
        use_bfloat16=use_bfloat16)

    # Each split gets its own batch size: params['batch_size'] scaled by the
    # corresponding multiplier flag.
    train_params = dict(params)
    train_params['batch_size'] = int(
        round(params['batch_size'] * FLAGS.train_batch_size_multiplier))
    train_data = src_train.input_fn(train_params)

    target_train_params = dict(params)
    target_train_params['batch_size'] = int(
        round(params['batch_size'] * FLAGS.target_train_batch_multiplier))
    finetune_data = finetune_dataset.input_fn(target_train_params)

    target_params = dict(params)
    target_params['batch_size'] = int(
        round(params['batch_size'] * FLAGS.target_batch_multiplier))
    target_data = target_dataset.input_fn(target_params)

    # Zip the three pipelines so every step sees one batch from each split,
    # then merge them into a single (features, labels) structure.
    dataset = tf.data.Dataset.zip((train_data, finetune_data, target_data))
    dataset = dataset.map(_merge_datasets)
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset

  if FLAGS.mode == 'train':
    max_train_steps = FLAGS.train_steps
    resnet_classifier.train(make_input_dataset, max_steps=max_train_steps)
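# Illustrative sketch (not part of the script above): how tf.data.Dataset.zip
# plus a merge map combines several input pipelines into a single
# (features, labels) structure, mirroring _merge_datasets. The toy in-memory
# datasets below are hypothetical stand-ins for the real ImageNetInput
# pipelines.
def _zip_and_merge_example():
  src = tf.data.Dataset.from_tensors(
      (tf.zeros([2, 4]), tf.zeros([2], tf.int32)))
  fin = tf.data.Dataset.from_tensors(
      (tf.ones([1, 4]), tf.ones([1], tf.int32)))
  tgt = tf.data.Dataset.from_tensors(
      (tf.ones([3, 4]), tf.ones([3], tf.int32)))

  def _merge(src_batch, fin_batch, tgt_batch):
    features = {'src': src_batch[0], 'finetune': fin_batch[0],
                'target': tgt_batch[0]}
    labels = {'src': src_batch[1], 'finetune': fin_batch[1],
              'target': tgt_batch[1]}
    return features, labels

  # Each element of the zipped dataset carries one batch from every split.
  return tf.data.Dataset.zip((src, fin, tgt)).map(_merge)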
# A second `main`, apparently from a companion fine-tuning script: it trains
# on the target dataset only, with a single input pipeline.
def main(unused_argv):
  params = resnet_params.from_file(FLAGS.param_file)
  params = resnet_params.override(params, FLAGS.param_overrides)
  params['batch_size'] = FLAGS.target_batch_size
  resnet_params.log_hparams_to_model_dir(params, FLAGS.model_dir)
  tf.logging.info('Model params: {}'.format(params))

  if params['use_async_checkpointing']:
    save_checkpoints_steps = None
  else:
    save_checkpoints_steps = (
        FLAGS.pre_train_steps + FLAGS.finetune_steps + FLAGS.ctrl_steps)
  # NOTE: this unconditional assignment overrides the branch above.
  save_checkpoints_steps = max(1000, params['iterations_per_loop'])

  run_config_args = {
      'model_dir': FLAGS.model_dir,
      'save_checkpoints_steps': save_checkpoints_steps,
      'log_step_count_steps': FLAGS.log_step_count_steps,
      'keep_checkpoint_max': 100,
  }
  run_config_args['master'] = FLAGS.master
  config = tf.contrib.learn.RunConfig(**run_config_args)

  resnet_classifier = tf.estimator.Estimator(
      get_model_fn(config), config=config)

  use_bfloat16 = params['precision'] == 'bfloat16'

  def _merge_datasets(train_batch):
    feature, label = train_batch
    features = {
        'feature': feature,
    }
    labels = {
        'label': label,
    }
    return (features, labels)

  def make_input_dataset(params):
    """Returns the target-dataset input pipeline."""
    finetune_dataset = data_input.ImageNetInput(
        dataset_name=FLAGS.target_dataset,
        num_classes=data_input.num_classes_map[FLAGS.target_dataset],
        task_id=1,
        is_training=True,
        data_dir=FLAGS.data_dir,
        dataset_split=FLAGS.dataset_split,
        transpose_input=params['transpose_input'],
        cache=False,
        image_size=params['image_size'],
        num_parallel_calls=params['num_parallel_calls'],
        use_bfloat16=use_bfloat16)
    finetune_data = finetune_dataset.input_fn(params)
    dataset = tf.data.Dataset.zip((finetune_data,))
    dataset = dataset.map(_merge_datasets)
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset

  # pylint: disable=protected-access
  current_step = estimator._load_global_step_from_checkpoint_dir(
      FLAGS.model_dir)

  train_steps = FLAGS.train_steps
  while current_step < train_steps:
    next_checkpoint = train_steps
    resnet_classifier.train(
        input_fn=make_input_dataset, max_steps=next_checkpoint)
    current_step = next_checkpoint
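# Assumed entry point (not shown in the excerpt above): scripts that define
# main(unused_argv) in this style are normally launched through the TF 1.x
# flag parser, roughly:
if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main)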