def _subnetwork_generator(checkpoint_dir):
    """Builds an `improve_nas.Generator` with a fixed set of hyperparameters.

    Args:
        checkpoint_dir: Forwarded unchanged as the generator's
            `checkpoint_dir` argument.

    Returns:
        An `improve_nas.Generator` over a single numeric feature column keyed
        "x" with shape [32, 32, 3], using `_optimizer` as the optimizer
        function, 3 iteration steps and seed 11.
    """
    # Hyperparameter values are hard-coded; order matches the original call.
    hparam_values = dict(
        clip_gradients=5.0,
        stem_multiplier=3.0,
        drop_path_keep_prob=0.6,
        num_cells=3,
        use_aux_head=False,
        aux_head_weight=0.4,
        label_smoothing=0.1,
        num_conv_filters=4,
        dense_dropout_keep_prob=1.0,
        filter_scaling_rate=2.0,
        complexity_decay_rate=0.9,
        num_reduction_layers=2,
        data_format="NHWC",
        skip_reduction_layer_input=0,
        initial_learning_rate=0.01,
        use_bounded_activation=False,
        weight_decay=0.0001,
        knowledge_distillation=improve_nas.KnowledgeDistillation.NONE,
        snapshot=False,
        learn_mixture_weights=False,
        mixture_weight_type=adanet.MixtureWeightType.SCALAR,
        model_version="cifar",
        total_training_steps=100,
    )
    generator_hparams = tf.contrib.training.HParams(**hparam_values)

    image_column = tf.feature_column.numeric_column(key="x", shape=[32, 32, 3])

    return improve_nas.Generator(
        [image_column],
        seed=11,
        optimizer_fn=_optimizer,
        iteration_steps=3,
        checkpoint_dir=checkpoint_dir,
        hparams=generator_hparams,
    )
def estimator(self, data_provider, run_config, hparams, train_steps=None,
              seed=None):
    """Returns an AdaNet `Estimator` for train and evaluation.

    Args:
        data_provider: Data `Provider` for dataset to model.
        run_config: `RunConfig` object to configure the runtime settings.
        hparams: `HParams` instance defining custom hyperparameters. It is
            modified in place: a `total_training_steps` hyperparameter is
            added to it.
        train_steps: Number of train steps. Although it defaults to `None`
            in the signature, a value is required because the per-iteration
            step budget is derived from it.
        seed: An integer seed if determinism is required.

    Returns:
        An `adanet.Estimator`.

    Raises:
        ValueError: If `train_steps` is `None`, or if `hparams.generator` is
            neither `GeneratorType.SIMPLE` nor `GeneratorType.DYNAMIC`.
    """
    # Fail fast with a clear message; otherwise the division below would
    # raise an opaque TypeError on `None / int`.
    if train_steps is None:
        raise ValueError(
            "`train_steps` must be specified to compute the number of steps "
            "per boosting iteration.")

    # Split the total step budget evenly across boosting iterations.
    max_iteration_steps = int(train_steps / hparams.boosting_iterations)

    optimizer_fn = optimizer.fn_with_name(
        hparams.optimizer,
        learning_rate_schedule=hparams.learning_rate_schedule,
        cosine_decay_steps=max_iteration_steps)

    # NOTE: this mutates the caller's `hparams`. Presumably `add_hparam`
    # rejects a name that already exists, so calling this method twice with
    # the same `hparams` object would fail -- verify against the HParams docs.
    hparams.add_hparam("total_training_steps", max_iteration_steps)

    # Both generator flavours take identical constructor arguments, so only
    # the class differs between the branches.
    if hparams.generator == GeneratorType.SIMPLE:
        generator_cls = improve_nas.Generator
    elif hparams.generator == GeneratorType.DYNAMIC:
        generator_cls = improve_nas.DynamicGenerator
    else:
        raise ValueError("Invalid generator: `%s`" % hparams.generator)

    subnetwork_generator = generator_cls(
        feature_columns=data_provider.get_feature_columns(),
        optimizer_fn=optimizer_fn,
        iteration_steps=max_iteration_steps,
        checkpoint_dir=run_config.model_dir,
        hparams=hparams,
        seed=seed)

    # The evaluator is optional; when enabled it reads the "train" partition
    # in EVAL mode.
    evaluator = None
    if hparams.use_evaluator:
        evaluator_input_fn = data_provider.get_input_fn(
            partition="train",
            mode=tf.estimator.ModeKeys.EVAL,
            batch_size=hparams.evaluator_batch_size)
        evaluator = adanet.Evaluator(
            input_fn=evaluator_input_fn,
            steps=hparams.evaluator_steps)

    return adanet.Estimator(
        head=data_provider.get_head(),
        subnetwork_generator=subnetwork_generator,
        max_iteration_steps=max_iteration_steps,
        adanet_lambda=hparams.adanet_lambda,
        adanet_beta=hparams.adanet_beta,
        mixture_weight_type=hparams.mixture_weight_type,
        force_grow=hparams.force_grow,
        evaluator=evaluator,
        config=run_config,
        model_dir=run_config.model_dir)