Ejemplo n.º 1
0
  def test_enum_vals(self, ks, vs):
    """Check that u.enum_vals returns the enum's values in mapping order.

    The test setup is relied on to make the values in `vs` unique; duplicate
    values would otherwise become Enum aliases rather than distinct members.
    """
    mapping = {key: val for key, val in zip(ks, vs)}
    expected = list(mapping.values())
    actual = u.enum_vals(Enum('TestEnum', mapping))
    self.assertListEqual(expected, actual)
Ejemplo n.º 2
0
def test_enum_vals(ks, vs):
    """u.enum_vals should return the enum's values, in mapping order.

    Relies on the test setup to make the values in `vs` unique, so no value
    collapses into an Enum alias of an earlier member.
    """
    members = {}
    for key, val in zip(ks, vs):
        members[key] = val
    test_enum = Enum('TestEnum', members)
    expected = [*members.values()]
    assert expected == u.enum_vals(test_enum)
Ejemplo n.º 3
0
def region_arg(parser):
    """Register a ``--region`` option on ``parser``.

    Values are converted with ``ct.parse_region``. The help text lists the
    values of ``ct.valid_regions()`` and mentions ``conf.DEFAULT_REGION``.
    """
    valid = u.enum_vals(ct.valid_regions())
    converter = ct.parse_region
    default_region = conf.DEFAULT_REGION.value
    help_text = (
        "Region to use for Cloud job submission and image persistence. "
        f"Must be one of {valid}. "
        f"(Defaults to $REGION or '{default_region}'.)"
    )
    parser.add_argument("--region", type=converter, help=help_text)
Ejemplo n.º 4
0
def machine_type_arg(parser):
    """Register a ``--machine_type`` option on ``parser``.

    Values are converted with ``ct.parse_machine_type``. The help text lists
    every ``ct.MachineType`` value plus the configured defaults for GPU mode
    and for CPU mode (``--nogpu``).
    """
    choices = u.enum_vals(ct.MachineType)
    defaults = conf.DEFAULT_MACHINE_TYPE
    cpu_default = defaults[conf.JobMode.CPU].value
    gpu_default = defaults[conf.JobMode.GPU].value

    help_text = (
        f"Cloud machine type to request. Must be one of {choices}. "
        f"Defaults to '{gpu_default}' in GPU mode, or '{cpu_default}' "
        "if --nogpu is passed."
    )
    parser.add_argument("--machine_type",
                        type=ct.parse_machine_type,
                        help=help_text)
Ejemplo n.º 5
0
def parse_region(s: str) -> Region:
    """Parse ``s`` into a region, raising an argparse-friendly error on failure.

    Args:
        s: The raw string to parse (e.g. a command-line argument value).

    Returns:
        The value returned by ``u.any_of(s, Region)``.

    Raises:
        argparse.ArgumentTypeError: if ``u.any_of`` raises ``ValueError``. The
            message lists the values of ``valid_regions()``, and the original
            ``ValueError`` is attached as ``__cause__``.
    """
    try:
        return u.any_of(s, Region)
    except ValueError as err:
        valid_values = u.enum_vals(valid_regions())
        # Chain explicitly so tracebacks show the real cause instead of
        # "During handling of the above exception, another exception occurred".
        raise argparse.ArgumentTypeError(
            f"'{s}' isn't a valid region. Must be one of {valid_values}."
        ) from err
Ejemplo n.º 6
0
def parse_machine_type(s: str) -> MachineType:
    """Parse ``s`` into a ``MachineType``, raising an argparse-friendly error.

    Args:
        s: The raw string to parse (e.g. a command-line argument value).

    Returns:
        ``MachineType(s)``.

    Raises:
        argparse.ArgumentTypeError: if ``MachineType(s)`` raises ``ValueError``.
            The message lists every ``MachineType`` value, and the original
            ``ValueError`` is attached as ``__cause__``.
    """
    try:
        return MachineType(s)
    except ValueError as err:
        valid_values = u.enum_vals(MachineType)
        # Chain explicitly so tracebacks show the real cause instead of
        # "During handling of the above exception, another exception occurred".
        raise argparse.ArgumentTypeError(
            f"'{s}' isn't a valid machine type. Must be one of {valid_values}."
        ) from err
Ejemplo n.º 7
0
def _validate_machine_type(gpu_spec: Optional[ct.GPUSpec],
                           machine_type: Optional[ct.MachineType]):
    """Exit if the GPU spec and machine type form an unsupported combination.

    The check only runs when both arguments are provided; if either is None
    this is a no-op. Otherwise, when
    ``gpu_spec.valid_machine_type(machine_type)`` is falsy, an error listing
    the allowed machine types is written via ``u.err`` and the process exits
    with status 1.
    """
    if gpu_spec is None or machine_type is None:
        return
    if gpu_spec.valid_machine_type(machine_type):
        return

    # Show the allowed types sorted so that at least the machine prefixes
    # stick together. sorted() builds a new list instead of sorting the
    # helper's return value in place.
    allowed = sorted(u.enum_vals(gpu_spec.allowed_machine_types()))
    u.err(f"\n'{machine_type.value}' isn't a valid machine type "
          f"for {gpu_spec.count} {gpu_spec.gpu.name} GPUs.\n\n")
    u.err(ct.with_advice_suffix("gpu", f"Try one of these: {allowed}"))
    u.err("\n")
    sys.exit(1)
Ejemplo n.º 8
0
def _validate_accelerator_region(spec: Optional[Union[ct.GPUSpec, ct.TPUSpec]],
                                 region: ct.Region):
  """Exit if an accelerator spec is supplied and is not offered in ``region``.

  If ``spec`` is None this is a no-op. Otherwise, when
  ``spec.valid_region(region)`` is falsy, the allowed regions and a link to
  the regional-support docs are written via ``u.err`` and the process exits
  with status 1.
  """
  if spec is None:
    return

  accel = spec.accelerator_type
  if spec.valid_region(region):
    return

  # Show the allowed regions sorted, so values sharing a prefix are listed
  # together. sorted() avoids mutating the helper's return value in place.
  allowed = sorted(u.enum_vals(spec.allowed_regions()))
  u.err(f"\n'{region.value}' isn't a valid region "
        f"for {accel}s of type {spec.name}.\n\n")
  u.err(f"Try one of these: {allowed}\n\n")
  u.err("See this page for more info about regional "
        f"support for {accel}s: https://cloud.google.com/ml-engine/docs/regions\n")
  sys.exit(1)