示例#1
0
def _subnetwork_generator(checkpoint_dir):
    """Creates an `improve_nas.Generator` with a fixed hyperparameter set.

    Args:
      checkpoint_dir: Directory handed to the generator as its
        `checkpoint_dir` argument.

    Returns:
      An `improve_nas.Generator` over a single numeric feature column "x"
      of shape [32, 32, 3], seeded with 11, using `_optimizer` as its
      optimizer factory and 3 iteration steps.
    """
    settings = dict(
        clip_gradients=5.0,
        stem_multiplier=3.0,
        drop_path_keep_prob=0.6,
        num_cells=3,
        use_aux_head=False,
        aux_head_weight=0.4,
        label_smoothing=0.1,
        num_conv_filters=4,
        dense_dropout_keep_prob=1.0,
        filter_scaling_rate=2.0,
        complexity_decay_rate=0.9,
        num_reduction_layers=2,
        data_format="NHWC",
        skip_reduction_layer_input=0,
        initial_learning_rate=0.01,
        use_bounded_activation=False,
        weight_decay=0.0001,
        knowledge_distillation=improve_nas.KnowledgeDistillation.NONE,
        snapshot=False,
        learn_mixture_weights=False,
        mixture_weight_type=adanet.MixtureWeightType.SCALAR,
        model_version="cifar",
        total_training_steps=100,
    )
    generator_hparams = tf.contrib.training.HParams(**settings)

    image_column = tf.feature_column.numeric_column(key="x", shape=[32, 32, 3])
    return improve_nas.Generator(
        [image_column],
        seed=11,
        optimizer_fn=_optimizer,
        iteration_steps=3,
        checkpoint_dir=checkpoint_dir,
        hparams=generator_hparams,
    )
示例#2
0
    def estimator(self,
                  data_provider,
                  run_config,
                  hparams,
                  train_steps=None,
                  seed=None):
        """Returns an AdaNet `Estimator` for train and evaluation.

        The training budget is split evenly across boosting iterations: the
        `max_iteration_steps` given to AdaNet is
        `int(train_steps / hparams.boosting_iterations)`.

        Note: this sets `hparams.total_training_steps` to that per-iteration
        step count, mutating the `hparams` object passed in. Calling this
        again with the same `hparams` overwrites the value instead of failing.

        Args:
          data_provider: Data `Provider` for dataset to model.
          run_config: `RunConfig` object to configure the runtime settings.
          hparams: `HParams` instance defining custom hyperparameters.
          train_steps: number of train steps. Required, despite the `None`
            default.
          seed: An integer seed if determinism is required.

        Returns:
          Returns an `Estimator`.

        Raises:
          ValueError: If `train_steps` is None, or if `hparams.generator` is
            neither `GeneratorType.SIMPLE` nor `GeneratorType.DYNAMIC`.
        """
        if train_steps is None:
            raise ValueError(
                "`train_steps` must be set to build the estimator.")

        # Resolve the generator class up front so an invalid setting fails
        # before `hparams` is mutated below.
        if hparams.generator == GeneratorType.SIMPLE:
            generator_cls = improve_nas.Generator
        elif hparams.generator == GeneratorType.DYNAMIC:
            generator_cls = improve_nas.DynamicGenerator
        else:
            # 1-tuple so that tuple-valued settings still format correctly.
            raise ValueError("Invalid generator: `%s`" % (hparams.generator,))

        max_iteration_steps = int(train_steps / hparams.boosting_iterations)

        optimizer_fn = optimizer.fn_with_name(
            hparams.optimizer,
            learning_rate_schedule=hparams.learning_rate_schedule,
            cosine_decay_steps=max_iteration_steps)

        # `add_hparam` raises if the name already exists, which made a second
        # call with the same `hparams` object fail; overwrite in that case.
        if "total_training_steps" in hparams.values():
            hparams.set_hparam("total_training_steps", max_iteration_steps)
        else:
            hparams.add_hparam("total_training_steps", max_iteration_steps)

        # Both generator types take the same constructor arguments.
        subnetwork_generator = generator_cls(
            feature_columns=data_provider.get_feature_columns(),
            optimizer_fn=optimizer_fn,
            iteration_steps=max_iteration_steps,
            checkpoint_dir=run_config.model_dir,
            hparams=hparams,
            seed=seed)

        evaluator = None
        if hparams.use_evaluator:
            # Candidate ensembles are evaluated on the *train* partition in
            # EVAL mode.
            evaluator = adanet.Evaluator(
                input_fn=data_provider.get_input_fn(
                    partition="train",
                    mode=tf.estimator.ModeKeys.EVAL,
                    batch_size=hparams.evaluator_batch_size),
                steps=hparams.evaluator_steps)

        return adanet.Estimator(
            head=data_provider.get_head(),
            subnetwork_generator=subnetwork_generator,
            max_iteration_steps=max_iteration_steps,
            adanet_lambda=hparams.adanet_lambda,
            adanet_beta=hparams.adanet_beta,
            mixture_weight_type=hparams.mixture_weight_type,
            force_grow=hparams.force_grow,
            evaluator=evaluator,
            config=run_config,
            model_dir=run_config.model_dir)