Example #1
def get_config(method):
    """Default config."""
    batch_size = config_dict.FieldReference(64)
    alpha = config_dict.FieldReference(0.001)
    num_tasks = config_dict.FieldReference(50)

    env_config = config_dict.ConfigDict({
        'name': 'pw',
        'task': 'random_policy',
        'pw': config_dict.ConfigDict({
            'arena': 'sutton',
            'num_bins': 20,
            'rollout_length': 50,
            'gamma': 0.9,
            'samples_for_ground_truth': 100_000
        }),
        'gym': config_dict.ConfigDict({
            'id': None,
        }),
        'random': config_dict.ConfigDict({
            'num_states': 2048,
            'num_actions': 5
        })
    })
Example #2
def main(_):
    placeholder = config_dict.FieldReference(0)
    cfg = config_dict.ConfigDict()
    cfg.placeholder = placeholder
    cfg.optional = config_dict.FieldReference(0, field_type=int)
    cfg.nested = config_dict.ConfigDict()
    cfg.nested.placeholder = placeholder

    try:
        cfg.optional = 'tom'  # Raises TypeError as this field is an integer.
    except TypeError as e:
        print(e)

    cfg.optional = 1555  # Works fine.
    cfg.placeholder = 1  # Changes the value of both placeholder and
    # nested.placeholder fields.

    # Note that the indirection provided by FieldReferences will be lost if
    # accessed through a ConfigDict:
    placeholder = config_dict.FieldReference(0)
    cfg.field1 = placeholder
    cfg.field2 = placeholder  # This field will be tied to cfg.field1.
    cfg.field3 = cfg.field1  # This will just be an int field initialized to 0.

    print(cfg)
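A minimal follow-up sketch (field names here are illustrative, not from the example above): using get_ref() instead of plain attribute access keeps the indirection that the comment above says is otherwise lost.

from ml_collections import config_dict

cfg = config_dict.ConfigDict()
cfg.field1 = config_dict.FieldReference(0)
# get_ref() returns the underlying FieldReference, so field4 stays tied to field1.
cfg.field4 = cfg.get_ref('field1')
cfg.field1 = 7
print(cfg.field4)  # Prints 7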
Example #3
def get_config():
    config = config_dict.ConfigDict()
    config.field1 = 1
    config.field2 = 'tom'
    config.nested = config_dict.ConfigDict()
    config.nested.field = 2.23
    config.tuple = (1, 2, 3)
    return config
Example #4
def get_config():
    """Returns a ConfigDict. Used for tests."""
    cfg = config_dict.ConfigDict()
    cfg.integer = 1
    cfg.reference = config_dict.FieldReference(1)
    cfg.list = [1, 2, 3]
    cfg.nested_list = [[1, 2, 3]]
    cfg.nested_configdict = config_dict.ConfigDict()
    cfg.nested_configdict.integer = 1
    cfg.unusable_config = UnusableConfig()

    return cfg
Example #5
def get_base_config():
  """Returns base config object for an experiment."""
  config = config_dict.ConfigDict()
  config.experiment_kwargs = config_dict.ConfigDict()

  config.training_steps = 10000  # Number of training steps.

  config.interval_type = "secs"
  config.save_checkpoint_interval = 300
  config.log_tensors_interval = 60
  config.log_train_data_interval = 120.0  # None to turn off

  # Overrides of `interval_type` for specific periodic operations. If `None`,
  # we use the value of `interval_type`.
  config.logging_interval_type = None
  config.checkpoint_interval_type = None

  # If set to True we checkpoint on all hosts, which may be useful
  # for model parallelism. Otherwise we checkpoint on host 0.
  config.train_checkpoint_all_hosts = False

  # If True, asynchronously logs training data from every training step.
  config.log_all_train_data = False

  # If true, run evaluate() on the experiment once before you load a checkpoint.
  # This is useful for getting initial values of metrics at random weights, or
  # when debugging locally if you do not have any train job running.
  config.eval_initial_weights = False

  # When True, the eval job immediately loads a checkpoint, runs evaluate()
  # once, then terminates.
  config.one_off_evaluate = False

  # Number of checkpoints to keep by default
  config.max_checkpoints_to_keep = 5

  # Settings for the RNGs used during training and evaluation.
  config.random_seed = 42
  config.random_mode_train = "unique_host_unique_device"
  config.random_mode_eval = "same_host_same_device"

  # The metric (returned by the step function) used as a fitness score.
  # It saves a separate series of checkpoints corresponding to
  # those which produce a better fitness score than previously seen.
  # By default it is assumed that higher is better, but this behaviour can be
  # changed to lower is better, i.e. behaving as a loss score, by setting
  # `best_model_eval_metric_higher_is_better = False`.
  # If `best_model_eval_metric` is empty (the default), best checkpointing is
  # disabled.
  config.best_model_eval_metric = ""
  config.best_model_eval_metric_higher_is_better = True

  return config
Example #6
def main(_):
    print_section('Attribute Types.')
    cfg = config_dict.ConfigDict()
    cfg.int = 1
    cfg.list = [1, 2, 3]
    cfg.tuple = (1, 2, 3)
    cfg.set = {1, 2, 3}
    cfg.frozenset = frozenset({1, 2, 3})
    cfg.dict = {
        'nested_int': 4,
        'nested_list': [4, 5, 6],
        'nested_tuple': ([4], 5, 6),
    }

    print('Types of cfg fields:')
    print('list: ', type(cfg.list))  # List
    print('set: ', type(cfg.set))  # Set
    print('nested_list: ', type(cfg.dict.nested_list))  # List
    print('nested_tuple[0]: ', type(cfg.dict.nested_tuple[0]))  # List

    frozen_cfg = config_dict.FrozenConfigDict(cfg)
    print('\nTypes of FrozenConfigDict(cfg) fields:')
    print('list: ', type(frozen_cfg.list))  # Tuple
    print('set: ', type(frozen_cfg.set))  # Frozenset
    print('nested_list: ', type(frozen_cfg.dict.nested_list))  # Tuple
    print('nested_tuple[0]: ', type(frozen_cfg.dict.nested_tuple[0]))  # Tuple

    cfg_from_frozen = config_dict.ConfigDict(frozen_cfg)
    print('\nTypes of ConfigDict(FrozenConfigDict(cfg)) fields:')
    print('list: ', type(cfg_from_frozen.list))  # List
    print('set: ', type(cfg_from_frozen.set))  # Set
    print('nested_list: ', type(cfg_from_frozen.dict.nested_list))  # List
    print('nested_tuple[0]: ',
          type(cfg_from_frozen.dict.nested_tuple[0]))  # List

    print(
        '\nCan use FrozenConfigDict.as_configdict() to convert to ConfigDict:')
    print(cfg_from_frozen == frozen_cfg.as_configdict())  # True

    print_section('Immutability.')
    try:
        frozen_cfg.new_field = 1  # Raises AttributeError because of immutability.
    except AttributeError as e:
        print(e)

    print_section('"==" and eq_as_configdict().')
    # FrozenConfigDict.__eq__() is not type-invariant with respect to ConfigDict
    print(frozen_cfg == cfg)  # False
    # FrozenConfigDict.eq_as_configdict() is type-invariant with respect to
    # ConfigDict
    print(frozen_cfg.eq_as_configdict(cfg))  # True
    # .eq_as_configdict() is also a method of ConfigDict
    print(cfg.eq_as_configdict(frozen_cfg))  # True
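A minimal follow-up sketch (values are illustrative): because a FrozenConfigDict is immutable and hashable, it can be used as a dictionary key, e.g. to cache results per configuration.

from ml_collections import config_dict

cfg = config_dict.ConfigDict({'lr': 0.1, 'layers': [1, 2]})
frozen = config_dict.FrozenConfigDict(cfg)  # The list becomes a tuple.
cache = {frozen: 'result-for-this-config'}  # Hashable, so usable as a key.
print(cache[frozen])  # Prints result-for-this-config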
Example #7
def get_config(config_string):
  """A config which takes an extra string argument."""
  possible_configs = {
      'type_a': config_dict.ConfigDict({
          'thing_a': 23,
          'thing_b': 42,
      }),
      'type_b': config_dict.ConfigDict({
          'thing_a': 19,
          'thing_c': 65,
      }),
  }
  return possible_configs[config_string]
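A hedged sketch of how such a parameterized config is typically selected from the command line with config_flags (the launcher file and paths here are illustrative): the part of the flag value after ':' is passed to get_config() as config_string.

from absl import app, flags
from ml_collections import config_flags

config_flags.DEFINE_config_file('config')
FLAGS = flags.FLAGS

def main(_):
    print(FLAGS.config)  # ConfigDict for the selected variant.

if __name__ == '__main__':
    app.run(main)

# Run as: python launcher.py --config=path/to/the_config.py:type_a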
Example #8
def get_dataset_config():
    """Gets the config for dataset."""
    config = config_dict.ConfigDict()
    # The path to the specification of grid evaluator.
    # If not specified, normal evaluator will be used.
    config.grid_evaluator_spec = ''
    # The directory of saved mgcdb84 dataset.
    config.dataset_directory = ''
    # The data types of the MGCDB84 dataset to use. If specified, the training
    # and validation sets will be obtained by partitioning the data of the
    # specified types. If not specified, the training and validation sets will
    # be those of MGCDB84.
    config.mgcdb84_types = ''
    # The fractions of the training, validation and test sets. Only used if
    # mgcdb84_types is specified. Comma-separated string of 3 floats.
    config.train_validation_test_split = '0.6,0.2,0.2'
    # The targets used for training. Defaults to mgcdb84_ref, which uses the
    # reference values given by MGCDB84 as targets. targets can also be set to
    # the exchange-correlation energies of a certain functional, which can be
    # specified by an existing functional name in xc_functionals or the path to
    # a json file specifying the functional form and parameters.
    config.targets = 'mgcdb84_ref'
    # The number of targets used for training. Defaults to 0 (use all targets).
    config.num_targets = 0
    # If True, only spin unpolarized molecules are used.
    config.spin_singlet = False
    # The evaluation mode for training, validation and test sets. Possible values
    # are jit, onp and jnp. Comma separated string.
    config.eval_modes = 'jit,onp,onp'
    return config
Example #9
    def testOverrideValues(self):
        config_flags.DEFINE_config_file('config')
        with self.assertRaisesWithLiteralMatch(
                config_flags.UnparsedFlagError,
                'The flag has not been parsed yet'):
            flags.FLAGS['config'].override_values  # pylint: disable=pointless-statement

        original_float = -1.0
        original_dictfloat = -2.0
        config = config_dict.ConfigDict({
            'integer': -1,
            'float': original_float,
            'dict': {
                'float': original_dictfloat
            }
        })
        integer_override = 0
        dictfloat_override = 1.1
        values = _parse_flags(
            './program --test_config={} --test_config.integer={} '
            '--test_config.dict.float={}'.format(_TEST_CONFIG_FILE,
                                                 integer_override,
                                                 dictfloat_override))

        config.update_from_flattened_dict(
            config_flags.get_override_values(values['test_config']))
        self.assertEqual(config['integer'], integer_override)
        self.assertEqual(config['float'], original_float)
        self.assertEqual(config['dict']['float'], dictfloat_override)
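A minimal standalone sketch of the update_from_flattened_dict call the test relies on (values are illustrative): the keys are dotted paths into the nested config.

from ml_collections import config_dict

config = config_dict.ConfigDict({'integer': -1, 'dict': {'float': -2.0}})
config.update_from_flattened_dict({'integer': 0, 'dict.float': 1.1})
print(config.integer, config.dict.float)  # Prints 0 1.1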
Example #10
def distort_image_with_autoaugment(image, augmentation_name):
    """Applies the AutoAugment policy to `image`.

    AutoAugment is from the paper: https://arxiv.org/abs/1805.09501.

    Args:
      image: `Tensor` of shape [height, width, 3] representing an image.
      augmentation_name: The name of the AutoAugment policy to use. The available
        options are `v0` and `test`. `v0` is the policy used for
        all of the results in the paper and was found to achieve the best results
        on the COCO dataset. `v1`, `v2` and `v3` are additional good policies
        found on the COCO dataset that have slight variation in what operations
        were used during the search procedure along with how many operations are
        applied in parallel to a single image (2 vs 3).

    Returns:
      A tuple containing the augmented versions of `image`.
    """
    available_policies = {'v0': policy_v0, 'test': policy_vtest}
    if augmentation_name not in available_policies:
        raise ValueError(
            'Invalid augmentation_name: {}'.format(augmentation_name))

    policy = available_policies[augmentation_name]()
    # Hparams that will be used for AutoAugment.
    augmentation_hparams = config_dict.ConfigDict(
        dict(cutout_const=100, translate_const=250))

    return build_and_apply_nas_policy(policy, image, augmentation_hparams)
Example #11
def get_config(launch_on_gcp):
  """Returns the configuration for this experiment."""
  config = config_dict.ConfigDict()
  config.user = getpass.getuser()
  config.priority = 'prod'
  config.platform = 'tpu-v3'
  config.tpu_topology = '2x2'
  config.experiment_name = (
      os.path.splitext(os.path.basename(__file__))[0] + '_' +
      datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
  output_dir = 'gs://launcher-beta-test-bucket/diabetic_retinopathy_detection/{}'.format(
      config.experiment_name)
  data_dir = 'gs://ub-data/retinopathy'
  config.args = {
      'train_epochs': 90,
      'train_batch_size': 64,
      'eval_batch_size': 64,
      'output_dir': output_dir,
      # Checkpoint every eval to get the best checkpoints via early stopping.
      'checkpoint_interval': 1,
      'lr_schedule': 'linear',
      # Best hyperparameters.
      'base_learning_rate': 0.31448,
      'one_minus_momentum': 0.0052243,
      'l2': 0.0000051293,
      'use_validation': False,
      'data_dir': data_dir,
  }
  return config
Example #12
def get_config():
    """Returns the configuration for this experiment."""
    config = config_dict.ConfigDict()
    config.user = getpass.getuser()
    config.priority = 'prod'
    config.platform = 'tpu-v2'
    config.tpu_topology = '2x2'

    config.experiment_name = (
        os.path.splitext(os.path.basename(__file__))[0] + '_' +
        datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
    output_dir = 'gs://drd-radial-severity-finetune/indomain/{}'.format(
        config.experiment_name)
    config.args = {
        'batch_size': 16,
        'num_mc_samples_train': 1,
        'num_mc_samples_eval': 5,
        'train_epochs': 90,
        'num_cores': 8,
        'class_reweight_mode': 'minibatch',
        'dr_decision_threshold': 'moderate',
        'distribution_shift': 'severity',
        'checkpoint_interval': 1,
        'output_dir': output_dir,
        'data_dir': 'gs://ub-data/retinopathy',

        # Config
        'l2': 0.00084192,
        'one_minus_momentum': 0.027963,
        'stddev_stddev_init': 0.037535,
        'stddev_mean_init': 0.012607,
        'base_learning_rate': 0.20617
    }
    return config
Example #13
def get_config():
  """Returns the configuration for this experiment."""
  config = config_dict.ConfigDict()
  config.user = getpass.getuser()
  config.priority = 'prod'
  config.platform = 'tpu-v2'
  config.vm = 'n1-standard-64'
  config.tpu_topology = '2x2'
  config.experiment_name = (
      os.path.splitext(os.path.basename(__file__))[0] + '_' +
      datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
  output_dir = (
      'gs://launcher-beta-test-bucket/diabetic_retinopathy_detection/{}'.format(
          config.experiment_name))
  data_dir = 'gs://ub-data/retinopathy'
  config.args = {
      # 'train_epochs': 90,
      # 'use_gpu': False,  # Use TPU.
      # 'train_batch_size': 64,
      # 'eval_batch_size': 64,
      'output_dir': output_dir,
      # 'checkpoint_interval': -1,
      # 'lr_schedule': 'linear',
      'data_dir': data_dir,
  }
  return config
Example #14
def get_config():
    config = base_config.get_base_config()

    config.random_seed = 0
    images_per_epoch = 1281167
    train_batch_size = 2048
    num_epochs = 300
    steps_per_epoch = images_per_epoch / train_batch_size
    config.training_steps = ((images_per_epoch * num_epochs) //
                             train_batch_size)
    config.experiment_kwargs = config_dict.ConfigDict(
        dict(config=dict(
            lr=1e-3,
            num_epochs=num_epochs,
            image_size=224,
            num_classes=1000,
            which_dataset='imagenet',
            loss='softmax_cross_entropy',
            transpose=True,
            dtype=jnp.bfloat16,
            lr_schedule=dict(name='cosine_decay_schedule',
                             kwargs=dict(init_value=1e-3,
                                         decay_steps=config.training_steps)),
            optimizer_weights=dict(
                name='adamw', kwargs=dict(b1=0.9, b2=0.999, weight_decay=0.05)),
            optimizer_biases=dict(name='adam', kwargs=dict(b1=0.9, b2=0.999)),
            model=dict(name='BoTNet',
                       config_kwargs=dict(stage_sizes=[3, 4, 6, 6],
                                          dtype=jnp.bfloat16)),
            augment_name='cutmix_mixup_randaugment_405')))
Example #15
def get_config():
    """Returns the configuration for this experiment."""
    config = config_dict.ConfigDict()
    config.user = getpass.getuser()
    config.priority = 'prod'
    config.platform = 'gpu'
    config.gpu_type = 't4'
    config.num_gpus = 1
    config.experiment_name = (
        os.path.splitext(os.path.basename(__file__))[0] + '_' +
        datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
    output_dir = 'gs://launcher-beta-test-bucket/{}'.format(
        config.experiment_name)
    config.args = {
        'train_epochs': 90,
        'per_core_batch_size': 64,
        'checkpoint_interval': -1,
        'data_dir': output_dir,
        'output_dir': output_dir,
        'download_data': True,
        'train_proportion': 0.9,
        'eval_on_ood': True,
        'ood_dataset': 'cifar100,svhn_cropped',
        # If drop_remainder=False, it causes the `TPU has inputs with dynamic
        # shapes` issue for sngp.py. To make the evaluation comparable, we set
        # it to True for deterministic.py too.
        'drop_remainder_for_eval': True,
    }
    return config
Example #16
def lazy_configdict_advanced():
    """Advanced lazy computation with ConfigDict."""
    # FieldReferences can be used with ConfigDict as well
    config = config_dict.ConfigDict()
    config.float_field = 12.6
    config.integer_field = 123
    config.list_field = [0, 1, 2]

    config.float_multiply_field = config.get_ref('float_field') * 3
    print(config.float_multiply_field)  # Prints 37.8

    config.float_field = 10.0
    print(config.float_multiply_field)  # Prints 30.0

    config.longer_list_field = config.get_ref('list_field') + [3, 4, 5]
    print(config.longer_list_field)  # Prints [0, 1, 2, 3, 4, 5]

    config.list_field = [-1]
    print(config.longer_list_field)  # Prints [-1, 3, 4, 5]

    # Both operands can be references
    config.ref_subtraction = (config.get_ref('float_field') -
                              config.get_ref('integer_field'))
    print(config.ref_subtraction)  # Prints -113.0

    config.integer_field = 10
    print(config.ref_subtraction)  # Prints 0.0
Example #17
def lazy_configdict():
    """Example usage of lazy computation with ConfigDict."""
    config = config_dict.ConfigDict()
    config.reference_field = config_dict.FieldReference(1)
    config.integer_field = 2
    config.float_field = 2.5

    # No lazy evaluation because we didn't use get_ref()
    config.no_lazy = config.integer_field * config.float_field

    # This will lazily evaluate ONLY config.integer_field
    config.lazy_integer = config.get_ref('integer_field') * config.float_field

    # This will lazily evaluate ONLY config.float_field
    config.lazy_float = config.integer_field * config.get_ref('float_field')

    # This will lazily evaluate BOTH config.integer_field and config.float_field
    config.lazy_both = (config.get_ref('integer_field') *
                        config.get_ref('float_field'))

    config.integer_field = 3
    print(config.no_lazy)  # Prints 5.0 - uses integer_field's original value

    print(config.lazy_integer)  # Prints 7.5

    config.float_field = 3.5
    print(config.lazy_float)  # Prints 7.0
    print(config.lazy_both)  # Prints 10.5
Example #18
def get_config(launch_on_gcp):
    """Returns the configuration for this experiment."""
    config = config_dict.ConfigDict()
    config.user = getpass.getuser()
    config.priority = 'prod'
    config.platform = 'tpu-v3'
    config.tpu_topology = '2x2'
    config.experiment_name = (
        os.path.splitext(os.path.basename(__file__))[0] + '_' +
        datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
    output_dir = 'gs://launcher-beta-test-bucket/diabetic_retinopathy_detection/{}'.format(
        config.experiment_name)
    data_dir = 'gs://ub-data/retinopathy'
    config.args = {
        'train_epochs': 90,
        'use_gpu': False,  # Use TPU.
        'batch_size': 32,
        'output_dir': output_dir,
        # Checkpoint every eval to get the best checkpoints via early stopping.
        'checkpoint_interval': 1,
        # Second-best hparams.
        'base_learning_rate': 0.84557,
        'one_minus_momentum': 0.023980,
        'l2': 0.00019403,
        'stddev_mean_init': 0.000018096,
        'stddev_stddev_init': 0.067054,
        'use_validation': False,
        'data_dir': data_dir,
    }
    return config
Example #19
def get_config():
    """Returns the configuration for this experiment."""
    config = config_dict.ConfigDict()
    config.user = getpass.getuser()
    config.priority = 'prod'
    config.platform = 'tpu-v2'
    config.tpu_topology = '2x2'

    config.experiment_name = (
        os.path.splitext(os.path.basename(__file__))[0] + '_' +
        datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
    output_dir = 'gs://drd-radial-aptos-finetune/indomain/{}'.format(
        config.experiment_name)
    config.args = {
        'batch_size': 16,
        'num_mc_samples_train': 1,
        'num_mc_samples_eval': 5,
        'train_epochs': 90,
        'num_cores': 8,
        'class_reweight_mode': 'minibatch',
        'dr_decision_threshold': 'moderate',
        'distribution_shift': 'aptos',
        'checkpoint_interval': 1,
        'output_dir': output_dir,
        'data_dir': 'gs://ub-data/retinopathy',
        'l2': 0.0005243849811857283,
        'one_minus_momentum': 0.016699763426232056,
        'stddev_stddev_init': 0.06282496582469976,
        'stddev_mean_init': 0.00014497837766733678,
        'base_learning_rate': 0.18958209702776632
    }
    return config
Example #20
def get_config(launch_on_gcp):
  """Returns the configuration for this experiment."""
  del launch_on_gcp
  config = config_dict.ConfigDict()
  config.user = getpass.getuser()
  config.priority = 'prod'
  config.platform = 'gpu'
  config.gpu_type = 't4'
  config.num_gpus = 1
  config.experiment_name = (
      os.path.splitext(os.path.basename(__file__))[0] + '_' +
      datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
  output_dir = 'gs://launcher-beta-test-bucket/{}'.format(
      config.experiment_name)
  data_dir = 'gs://ub-data/retinopathy'
  config.args = {
      'train_epochs': 90,
      'train_batch_size': 64,
      'eval_batch_size': 64,
      'checkpoint_interval': -1,
      'lr_schedule': 'step',
      'output_dir': output_dir,
      'data_dir': data_dir,
  }
  return config
Example #21
def get_config():
    """Returns the configuration for this experiment."""
    config = config_dict.ConfigDict()
    config.user = getpass.getuser()
    config.priority = 'prod'
    config.platform = 'gpu'
    config.gpu_type = 't4'
    config.num_gpus = 1
    config.experiment_name = (
        os.path.splitext(os.path.basename(__file__))[0] + '_' +
        datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
    output_dir = 'gs://launcher-beta-test-bucket/{}'.format(
        config.experiment_name)
    config.args = {
        'base_learning_rate': 0.08,
        'l2': 3e-4,
        'gp_mean_field_factor': 20,  # 25,
        'train_epochs': 250,
        'per_core_batch_size': 64,
        'data_dir': output_dir,
        'output_dir': output_dir,
        'download_data': True,
        'train_proportion': 0.9,
        'eval_on_ood': True,
        'ood_dataset': 'cifar100,svhn_cropped',
        # If drop_remainder=False, it causes the
        # `TPU has inputs with dynamic shapes` issue.
        'drop_remainder_for_eval': True,
        'corruptions_interval': 10,
        'use_mc_dropout': True,
        'num_dropout_samples': 10,
    }
    return config
Example #22
def get_config():
    """Returns the configuration for this experiment."""
    config = config_dict.ConfigDict()
    config.user = getpass.getuser()
    config.priority = 'prod'
    config.platform = 'tpu-v2'
    config.tpu_topology = '2x2'

    config.experiment_name = (
        os.path.splitext(os.path.basename(__file__))[0] + '_' +
        datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
    output_dir = 'gs://drd-rank1-severity-results/{}'.format(
        config.experiment_name)
    config.args = {
        'per_core_batch_size': 16,
        'num_mc_samples_train': 1,
        'num_mc_samples_eval': 5,
        'train_epochs': 90,
        'num_cores': 8,
        'class_reweight_mode': 'minibatch',
        'dr_decision_threshold': 'moderate',
        'distribution_shift': 'severity',
        'checkpoint_interval': 1,
        'output_dir': output_dir,
        'data_dir': 'gs://ub-data/retinopathy',
    }
    return config
Example #23
def get_config():
    """Returns the configuration for this experiment."""
    config = config_dict.ConfigDict()
    config.user = getpass.getuser()
    config.priority = 'prod'
    config.platform = 'tpu-v2'
    config.tpu_topology = '2x2'
    config.experiment_name = (
        os.path.splitext(os.path.basename(__file__))[0] + '_' +
        datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
    output_dir = 'gs://drd-dropout-aptos-finetune/joint/{}'.format(
        config.experiment_name)
    config.args = {
        'per_core_batch_size': 16,
        'num_dropout_samples_eval': 5,
        'train_epochs': 90,
        'num_cores': 8,
        'class_reweight_mode': 'minibatch',
        'dr_decision_threshold': 'moderate',
        'distribution_shift': 'aptos',
        'checkpoint_interval': 1,
        'output_dir': output_dir,
        'data_dir': 'gs://ub-data/retinopathy',

        # Hypers
        'base_learning_rate': 0.0028274,
        'one_minus_momentum': 0.024251,
        'l2': 0.000041296,
        'dropout_rate': 0.067338
    }
    return config
Example #24
def get_config():
    """Returns the configuration for this experiment."""
    config = config_dict.ConfigDict()
    config.user = getpass.getuser()
    config.priority = 'prod'
    config.platform = 'tpu-v2'
    config.tpu_topology = '2x2'
    config.experiment_name = (
        os.path.splitext(os.path.basename(__file__))[0] + '_' +
        datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
    output_dir = 'gs://drd-deterministic-aptos-finetune/indomain/{}'.format(
        config.experiment_name)
    config.args = {
        'per_core_batch_size': 32,
        'train_epochs': 90,
        'num_cores': 8,
        'class_reweight_mode': 'minibatch',
        'dr_decision_threshold': 'moderate',
        'distribution_shift': 'aptos',
        'checkpoint_interval': 1,
        'lr_schedule': 'linear',
        'output_dir': output_dir,
        'data_dir': 'gs://ub-data/retinopathy',

        # Hypers
        'final_decay_factor': 0.010000,
        'one_minus_momentum': 0.0098467,
        'l2': 0.00010674,
        'base_learning_rate': 0.023072
    }
    return config
Example #25
def get_config():
    """Returns the configuration for this experiment."""
    config = config_dict.ConfigDict()
    config.user = getpass.getuser()
    config.priority = 'prod'
    config.platform = 'gpu'
    config.gpu_type = 't4'
    config.num_gpus = 1
    config.experiment_name = (
        os.path.splitext(os.path.basename(__file__))[0] + '_' +
        datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
    output_dir = 'gs://launcher-beta-test-bucket/{}'.format(
        config.experiment_name)
    config.args = {
        'train_epochs': 250,
        'per_core_batch_size': 64,
        'checkpoint_interval': -1,
        'data_dir': output_dir,
        'output_dir': output_dir,
        'download_data': True,
        'train_proportion': 0.9,
        'eval_on_ood': True,
        'ood_dataset': 'cifar100,svhn_cropped',
        'drop_remainder_for_eval': True,
    }
    return config
Example #26
    def testFieldReferenceResolved(self):
        """Tests that FieldReferences are resolved."""
        cfg = config_dict.ConfigDict({'fr': config_dict.FieldReference(1)})
        frozen_cfg = config_dict.FrozenConfigDict(cfg)
        self.assertNotIsInstance(frozen_cfg._fields['fr'],
                                 config_dict.FieldReference)
        hash(frozen_cfg)  # With FieldReference resolved, frozen_cfg is hashable.
Example #27
    def get_state(self, ckpt_series):
        if ckpt_series not in GLOBAL_CHECKPOINT_DICT:
            active = threading.local()
            new_series = CheckpointNT(active, [])
            GLOBAL_CHECKPOINT_DICT[ckpt_series] = new_series
        if not hasattr(GLOBAL_CHECKPOINT_DICT[ckpt_series].active, "state"):
            GLOBAL_CHECKPOINT_DICT[ckpt_series].active.state = (
                config_dict.ConfigDict())
        return GLOBAL_CHECKPOINT_DICT[ckpt_series].active.state
Example #28
def get_config():
    """Gets the config."""
    config = config_dict.ConfigDict()
    config.xc = get_xc_config()
    config.re = get_re_config()
    config.infra = get_infra_config()
    config.dataset = get_dataset_config()
    config.opt = get_opt_config()
    return config
Example #29
File: utils.py  Project: deepmind/jaxline
    def get_experiment_state(self, ckpt_series: str):
        """Returns the experiment state for a given checkpoint series."""
        if ckpt_series not in GLOBAL_CHECKPOINT_DICT:
            active = threading.local()
            new_series = CheckpointNT(active, [])
            GLOBAL_CHECKPOINT_DICT[ckpt_series] = new_series
        if not hasattr(GLOBAL_CHECKPOINT_DICT[ckpt_series].active, "state"):
            GLOBAL_CHECKPOINT_DICT[ckpt_series].active.state = (
                config_dict.ConfigDict())
        return GLOBAL_CHECKPOINT_DICT[ckpt_series].active.state
Example #30
    def testInitConfigDict(self):
        """Tests that ConfigDict initialization handles FrozenConfigDict.

        Initializing a ConfigDict on a dictionary with FrozenConfigDict values
        should unfreeze these values.
        """
        dict_without_fcd_node = _test_dict_deepcopy()
        dict_without_fcd_node.pop('ref')
        dict_with_fcd_node = copy.deepcopy(dict_without_fcd_node)
        dict_with_fcd_node['dict'] = config_dict.FrozenConfigDict(
            dict_with_fcd_node['dict'])
        cd_without_fcd_node = config_dict.ConfigDict(dict_without_fcd_node)
        cd_with_fcd_node = config_dict.ConfigDict(dict_with_fcd_node)
        fcd_without_fcd_node = config_dict.FrozenConfigDict(
            dict_without_fcd_node)
        fcd_with_fcd_node = config_dict.FrozenConfigDict(dict_with_fcd_node)

        self.assertEqual(cd_without_fcd_node, cd_with_fcd_node)
        self.assertEqual(fcd_without_fcd_node, fcd_with_fcd_node)