Example #1
import re

import tensorflow as tf  # TensorFlow 1.x (this example relies on tf.contrib).
from tensorflow.core.protobuf import rewriter_config_pb2

# resnet_params, data_input, train_model_fn, get_src_num_classes, and FLAGS
# come from the surrounding project and are assumed to be in scope.


def main(unused_argv):
    params = resnet_params.from_file(FLAGS.param_file)
    params = resnet_params.override(params, FLAGS.param_overrides)
    resnet_params.log_hparams_to_model_dir(params, FLAGS.model_dir)
    tf.logging.info('Model params: {}'.format(params))

    if params['use_async_checkpointing']:
        save_checkpoints_steps = None
    else:
        save_checkpoints_steps = FLAGS.save_checkpoints_steps

    # TO BE MODIFIED: `cluster` should be a TPUClusterResolver for a real run.
    config = tf.contrib.tpu.RunConfig(
        cluster='',
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        keep_checkpoint_max=50,
        log_step_count_steps=FLAGS.log_step_count_steps,
        session_config=tf.ConfigProto(graph_options=tf.GraphOptions(
            rewrite_options=rewriter_config_pb2.RewriterConfig(
                disable_meta_optimizer=True))))

    if FLAGS.warm_start_ckpt_path:
        var_names = []
        checkpoint_path = FLAGS.warm_start_ckpt_path
        reader = tf.train.NewCheckpointReader(checkpoint_path)
        for key in reader.get_variable_to_shape_map():
            # Exclude optimizer slot variables and step counters from
            # warm-starting; everything else in the checkpoint is restored.
            extra_str = ''
            keep_str = 'Momentum|global_step|finetune_global_step'
            if not re.findall('({}{})'.format(keep_str, extra_str), key):
                var_names.append(key)

        tf.logging.info('Warm-starting tensors: %s', sorted(var_names))

        vars_to_warm_start = var_names
        warm_start_settings = tf.estimator.WarmStartSettings(
            ckpt_to_initialize_from=checkpoint_path,
            vars_to_warm_start=vars_to_warm_start)
    else:
        warm_start_settings = None

    # TO BE MODIFIED: tf.estimator.Estimator rejects the batch-size kwargs
    # below, so the TPU-specific estimator is used here instead.
    resnet_classifier = tf.contrib.tpu.TPUEstimator(
        model_fn=train_model_fn,
        config=config,
        params=params,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=1024,
        warm_start_from=warm_start_settings)

    use_bfloat16 = params['precision'] == 'bfloat16'

    num_classes = get_src_num_classes()

    def make_input_dataset(params):
        """Returns the merged input dataset."""
        def _merge_datasets(train_batch, finetune_batch, target_batch):
            """Merges the source, finetune, and target batches."""
            train_features, train_labels = train_batch
            finetune_features, finetune_labels = finetune_batch
            target_features, target_labels = target_batch
            features = {
                'src': train_features,
                'finetune': finetune_features,
                'target': target_features
            }
            labels = {
                'src': train_labels,
                'finetune': finetune_labels,
                'target': target_labels
            }
            return (features, labels)

        # TO BE MODIFIED: location of the input data.
        data_dir = ''

        num_parallel_calls = 8
        src_train = data_input.ImageNetInput(
            dataset_name=FLAGS.source_dataset,
            is_training=True,
            data_dir=data_dir,
            transpose_input=params['transpose_input'],
            cache=False,
            image_size=params['image_size'],
            num_parallel_calls=num_parallel_calls,
            use_bfloat16=use_bfloat16,
            num_classes=num_classes)
        finetune_dataset = data_input.ImageNetInput(
            dataset_name=FLAGS.target_dataset,
            task_id=1,
            is_training=True,
            data_dir=data_dir,
            dataset_split='l2l_train',
            transpose_input=params['transpose_input'],
            cache=False,
            image_size=params['image_size'],
            num_parallel_calls=num_parallel_calls,
            use_bfloat16=use_bfloat16)
        target_dataset = data_input.ImageNetInput(
            dataset_name=FLAGS.target_dataset,
            task_id=2,
            is_training=True,
            data_dir=data_dir,
            dataset_split='l2l_valid',
            transpose_input=params['transpose_input'],
            cache=False,
            image_size=params['image_size'],
            num_parallel_calls=num_parallel_calls,
            use_bfloat16=use_bfloat16)

        train_params = dict(params)
        train_params['batch_size'] = int(
            round(params['batch_size'] * FLAGS.train_batch_size_multiplier))

        train_data = src_train.input_fn(train_params)

        target_train_params = dict(params)
        target_train_params['batch_size'] = int(
            round(params['batch_size'] * FLAGS.target_train_batch_multiplier))
        finetune_data = finetune_dataset.input_fn(target_train_params)

        target_params = dict(params)
        target_params['batch_size'] = int(
            round(params['batch_size'] * FLAGS.target_batch_multiplier))

        target_data = target_dataset.input_fn(target_params)
        dataset = tf.data.Dataset.zip((train_data, finetune_data, target_data))
        dataset = dataset.map(_merge_datasets)
        dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        return dataset

    if FLAGS.mode == 'train':
        max_train_steps = FLAGS.train_steps

        resnet_classifier.train(make_input_dataset, max_steps=max_train_steps)
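
The warm-start block above restores every checkpoint tensor except optimizer slots and step counters, selecting names with a plain regular expression. A minimal sketch of that filtering in isolation, using hypothetical variable names:

import re

# Hypothetical checkpoint variable names for illustration.
all_vars = [
    'resnet/conv1/kernel',
    'resnet/conv1/kernel/Momentum',
    'global_step',
    'finetune_global_step',
]
keep_str = 'Momentum|global_step|finetune_global_step'
warm_started = [v for v in all_vars
                if not re.findall('({})'.format(keep_str), v)]
print(warm_started)  # ['resnet/conv1/kernel']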
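
make_input_dataset builds three independent pipelines, zips them, and re-keys each element into a single features/labels pair. A self-contained sketch of the same zip-and-merge pattern, with toy in-memory data standing in for the ImageNetInput pipelines (all names here are illustrative):

import tensorflow as tf

src = tf.data.Dataset.from_tensor_slices(([1., 2.], [0, 1])).batch(2)
finetune = tf.data.Dataset.from_tensor_slices(([3., 4.], [1, 0])).batch(2)
target = tf.data.Dataset.from_tensor_slices(([5., 6.], [0, 0])).batch(2)

def _merge(src_batch, finetune_batch, target_batch):
    # Re-key the three (features, labels) tuples into one dict-valued pair.
    src_x, src_y = src_batch
    ft_x, ft_y = finetune_batch
    tgt_x, tgt_y = target_batch
    features = {'src': src_x, 'finetune': ft_x, 'target': tgt_x}
    labels = {'src': src_y, 'finetune': ft_y, 'target': tgt_y}
    return features, labels

dataset = tf.data.Dataset.zip((src, finetune, target)).map(_merge)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)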
Example #2
import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib.learn.RunConfig).
from tensorflow.python.estimator import estimator

# resnet_params, data_input, get_model_fn, and FLAGS come from the
# surrounding project and are assumed to be in scope.


def main(unused_argv):
    params = resnet_params.from_file(FLAGS.param_file)
    params = resnet_params.override(params, FLAGS.param_overrides)

    params['batch_size'] = FLAGS.target_batch_size

    resnet_params.log_hparams_to_model_dir(params, FLAGS.model_dir)
    print('Model params: {}'.format(params))

    if params['use_async_checkpointing']:
        save_checkpoints_steps = None
    else:
        save_checkpoints_steps = max(1000, params['iterations_per_loop'])
    run_config_args = {
        'model_dir': FLAGS.model_dir,
        'save_checkpoints_steps': save_checkpoints_steps,
        'log_step_count_steps': FLAGS.log_step_count_steps,
        'keep_checkpoint_max': 100,
    }

    run_config_args['master'] = FLAGS.master
    config = tf.contrib.learn.RunConfig(**run_config_args)

    # Pass params through so make_input_dataset receives them via input_fn.
    resnet_classifier = tf.estimator.Estimator(get_model_fn(config),
                                               config=config,
                                               params=params)

    use_bfloat16 = params['precision'] == 'bfloat16'

    def _merge_datasets(train_batch):
        feature, label = train_batch
        features = {
            'feature': feature,
        }
        labels = {
            'label': label,
        }
        return (features, labels)

    def make_input_dataset(params):
        """Returns input dataset."""
        finetune_dataset = data_input.ImageNetInput(
            dataset_name=FLAGS.target_dataset,
            num_classes=data_input.num_classes_map[FLAGS.target_dataset],
            task_id=1,
            is_training=True,
            data_dir=FLAGS.data_dir,
            dataset_split=FLAGS.dataset_split,
            transpose_input=params['transpose_input'],
            cache=False,
            image_size=params['image_size'],
            num_parallel_calls=params['num_parallel_calls'],
            use_bfloat16=use_bfloat16)
        finetune_data = finetune_dataset.input_fn(params)
        dataset = tf.data.Dataset.zip((finetune_data, ))
        dataset = dataset.map(_merge_datasets)
        dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        return dataset

    # pylint: disable=protected-access
    current_step = estimator._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)

    train_steps = FLAGS.train_steps
    while current_step < train_steps:
        next_checkpoint = train_steps
        resnet_classifier.train(input_fn=make_input_dataset,
                                max_steps=next_checkpoint)
        current_step = next_checkpoint
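
Because next_checkpoint is set directly to train_steps, the loop above issues a single train call; the loop shape just mirrors the usual TPU driver that trains in checkpoint-sized increments. The protected _load_global_step_from_checkpoint_dir helper can also be approximated with public checkpoint APIs. A minimal sketch, assuming the checkpoint stores a standard global_step tensor:

import tensorflow as tf

def load_global_step(model_dir):
    """Reads global_step from the latest checkpoint, or returns 0 if none."""
    ckpt = tf.train.latest_checkpoint(model_dir)
    if ckpt is None:
        return 0
    reader = tf.train.NewCheckpointReader(ckpt)
    return reader.get_tensor('global_step')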