def build_hparams(cell_name='amoeba_net_d'):
    """Assemble the tf.HParams used to train Amoeba Net.

    Starts from `imagenet_hparams()` and registers the normal-cell and
    reduction-cell specifications looked up for `cell_name`, then sets
    `data_format` to 'NHWC'.

    Args:
      cell_name: Name of the cell in model_specs.py used to build the
        amoebanet network. It is passed to both
        `model_specs.get_normal_cell` and `model_specs.get_reduction_cell`.
        The cell names in that module correspond to architectures found by
        the evolutionary search described in
        https://arxiv.org/abs/1802.01548.

    Returns:
      The tf.HParams object with the cell hyperparameters added.
    """
    hparams = imagenet_hparams()

    # Each entry pairs a cell lookup with the hparam names that receive the
    # three values it returns, in the order they are returned.
    cell_sources = (
        (model_specs.get_normal_cell,
         ('normal_cell_operations',
          'normal_cell_hiddenstate_indices',
          'normal_cell_used_hiddenstates')),
        (model_specs.get_reduction_cell,
         ('reduction_cell_operations',
          'reduction_cell_hiddenstate_indices',
          'reduction_cell_used_hiddenstates')),
    )
    for get_cell, (ops_key, indices_key, used_key) in cell_sources:
        # Explicit 3-way unpacking keeps the failure on a malformed spec.
        cell_ops, cell_indices, cell_used = get_cell(cell_name)
        hparams.add_hparam(ops_key, cell_ops)
        hparams.add_hparam(indices_key, cell_indices)
        hparams.add_hparam(used_key, cell_used)

    # set_hparam (not add_hparam): 'data_format' is presumably already
    # defined by imagenet_hparams() -- TODO confirm.
    hparams.set_hparam('data_format', 'NHWC')
    return hparams
def build_hparams():
    """Assemble the tf.HParams used to train Amoeba Net from command-line flags.

    Starts from `model_lib.imagenet_hparams()` and adds, in order:
    `reduction_size` (from `FLAGS.reduction_size`), the normal-cell and
    reduction-cell specifications looked up for `FLAGS.cell_name`, and
    `stem_reduction_size` (from `FLAGS.stem_reduction_size`). It then sets
    `data_format` to 'NHWC' and passes the result to `override_with_flags`.

    Returns:
      The tf.HParams object after `override_with_flags` has been applied.
    """
    hparams = model_lib.imagenet_hparams()
    hparams.add_hparam('reduction_size', FLAGS.reduction_size)

    # Each entry pairs a cell lookup with the hparam names that receive the
    # three values it returns, in the order they are returned.
    cell_sources = (
        (model_specs.get_normal_cell,
         ('normal_cell_operations',
          'normal_cell_hiddenstate_indices',
          'normal_cell_used_hiddenstates')),
        (model_specs.get_reduction_cell,
         ('reduction_cell_operations',
          'reduction_cell_hiddenstate_indices',
          'reduction_cell_used_hiddenstates')),
    )
    for get_cell, (ops_key, indices_key, used_key) in cell_sources:
        # Explicit 3-way unpacking keeps the failure on a malformed spec.
        cell_ops, cell_indices, cell_used = get_cell(FLAGS.cell_name)
        hparams.add_hparam(ops_key, cell_ops)
        hparams.add_hparam(indices_key, cell_indices)
        hparams.add_hparam(used_key, cell_used)

    hparams.add_hparam('stem_reduction_size', FLAGS.stem_reduction_size)

    # set_hparam (not add_hparam): 'data_format' is presumably already
    # defined by model_lib.imagenet_hparams() -- TODO confirm.
    hparams.set_hparam('data_format', 'NHWC')

    # Runs last, after every hparam above has been registered.
    override_with_flags(hparams)
    return hparams