def build_hparams(cell_name='amoeba_net_d'):
  """Construct the tf.HParams used to train an AmoebaNet model.

  Starts from the defaults returned by `imagenet_hparams()` and attaches the
  normal-cell and reduction-cell specifications that `model_specs` provides
  for `cell_name`. It then sets 'data_format' to 'NHWC'.

  Args:
    cell_name: Which of the cells in model_specs.py to use to build the
      amoebanet neural network; the cell names defined in that module
      correspond to architectures discovered by an evolutionary search
      described in https://arxiv.org/abs/1802.01548.

  Returns:
    The hparams object from `imagenet_hparams()`, extended with the six
    '*_cell_*' entries and with data_format set to 'NHWC'.
  """
  hparams = imagenet_hparams()

  def _register_cell(spec, operations_key, indices_key, used_key):
    # Each model_specs cell getter yields exactly three parts; unpacking
    # (rather than zipping) keeps a malformed spec a hard error.
    cell_ops, cell_indices, cell_used = spec
    hparams.add_hparam(operations_key, cell_ops)
    hparams.add_hparam(indices_key, cell_indices)
    hparams.add_hparam(used_key, cell_used)

  _register_cell(model_specs.get_normal_cell(cell_name),
                 'normal_cell_operations',
                 'normal_cell_hiddenstate_indices',
                 'normal_cell_used_hiddenstates')
  _register_cell(model_specs.get_reduction_cell(cell_name),
                 'reduction_cell_operations',
                 'reduction_cell_hiddenstate_indices',
                 'reduction_cell_used_hiddenstates')

  hparams.set_hparam('data_format', 'NHWC')
  return hparams
# Example #2
def build_hparams():
  """Assemble tf.HParams for AmoebaNet training from command-line flags.

  Starts from `model_lib.imagenet_hparams()` and records
  `FLAGS.reduction_size`. It then attaches the normal-cell and reduction-cell
  specifications that `model_specs` provides for `FLAGS.cell_name`, records
  `FLAGS.stem_reduction_size`, and sets 'data_format' to 'NHWC'. Finally it
  passes the result through `override_with_flags` before returning it.

  NOTE(review): FLAGS is presumably the module's global flag object (e.g.
  absl.flags.FLAGS); confirm against the file's imports.

  Returns:
    The populated hparams object.
  """
  hparams = model_lib.imagenet_hparams()
  hparams.add_hparam('reduction_size', FLAGS.reduction_size)

  def _register_cell(spec, operations_key, indices_key, used_key):
    # Each model_specs cell getter yields exactly three parts; unpacking
    # (rather than zipping) keeps a malformed spec a hard error.
    cell_ops, cell_indices, cell_used = spec
    hparams.add_hparam(operations_key, cell_ops)
    hparams.add_hparam(indices_key, cell_indices)
    hparams.add_hparam(used_key, cell_used)

  _register_cell(model_specs.get_normal_cell(FLAGS.cell_name),
                 'normal_cell_operations',
                 'normal_cell_hiddenstate_indices',
                 'normal_cell_used_hiddenstates')
  _register_cell(model_specs.get_reduction_cell(FLAGS.cell_name),
                 'reduction_cell_operations',
                 'reduction_cell_hiddenstate_indices',
                 'reduction_cell_used_hiddenstates')
  hparams.add_hparam('stem_reduction_size', FLAGS.stem_reduction_size)

  hparams.set_hparam('data_format', 'NHWC')
  # Applied last so that it sees every value assigned above.
  override_with_flags(hparams)

  return hparams