コード例 #1
0
    def testOverrideValues(self):
        # Reading override_values before any parsing happened must fail.
        config_flags.DEFINE_config_file('config')
        with self.assertRaisesWithLiteralMatch(
                config_flags.UnparsedFlagError,
                'The flag has not been parsed yet'):
            _ = flags.FLAGS['config'].override_values

        # Baseline config: 'integer' and 'dict.float' get overridden from the
        # command line below, 'float' must keep its original value.
        original_float, original_dictfloat = -1.0, -2.0
        config = ml_collections.ConfigDict({
            'integer': -1,
            'float': original_float,
            'dict': {'float': original_dictfloat},
        })

        integer_override, dictfloat_override = 0, 1.1
        command = ('./program --test_config={} --test_config.integer={} '
                   '--test_config.dict.float={}').format(
                       _TEST_CONFIG_FILE, integer_override, dictfloat_override)
        values = _parse_flags(command)

        overrides = values['test_config'].override_values
        config.update_from_flattened_dict(overrides)
        self.assertEqual(config['integer'], integer_override)
        self.assertEqual(config['float'], original_float)
        self.assertEqual(config['dict']['float'], dictfloat_override)
コード例 #2
0
    def testIsConfigFile(self):
        config_flags.DEFINE_config_file('is_a_config_flag')
        flags.DEFINE_integer('not_a_config_flag', -1, '')

        registered = flags.FLAGS
        # Only the flag created by DEFINE_config_file counts as a config flag.
        self.assertTrue(
            config_flags.is_config_flag(registered['is_a_config_flag']))
        self.assertFalse(
            config_flags.is_config_flag(registered['not_a_config_flag']))
コード例 #3
0
def _parse_flags(command,
                 default=None,
                 config=None,
                 lock_config=True,
                 required=False):
    """Parses arguments simulating sys.argv.

    Defines a `test_config` flag on a fresh `flags.FlagValues` instance and
    parses `command` with it, temporarily installing the split command line as
    `sys.argv`.

    Args:
      command: Command line to parse, e.g. './program --test_config=path'. It is
        split with `shlex.split`, so quoted arguments are not split on spaces.
      default: Default passed to `config_flags.DEFINE_config_file`. Must be None
        when `config` is given.
      config: If not None, `test_config` is defined with
        `config_flags.DEFINE_config_dict` using this config instead of as a
        config-file flag.
      lock_config: Forwarded to the flag definition.
      required: If True, `test_config` is marked as required.

    Returns:
      The parsed `flags.FlagValues` instance.

    Raises:
      ValueError: If both `config` and `default` are supplied.
    """
    if config is not None and default is not None:
        raise ValueError('If config is supplied a default should not be.')

    # sys.argv is process-global state, so the simulated command line set here
    # is visible to everything below. It is restored in `finally` so that a
    # failing parse (e.g. a test asserting on an error) cannot leak the fake
    # argv into subsequent tests.
    old_argv = list(sys.argv)
    sys.argv = shlex.split(command)
    try:
        values = flags.FlagValues()
        if config is None:
            config_flags.DEFINE_config_file('test_config',
                                            default=default,
                                            flag_values=values,
                                            lock_config=lock_config)
        else:
            config_flags.DEFINE_config_dict('test_config',
                                            config=config,
                                            flag_values=values,
                                            lock_config=lock_config)

        if required:
            flags.mark_flag_as_required('test_config', flag_values=values)
        values(sys.argv)
    finally:
        # Going back to original values, even if parsing raised.
        sys.argv = old_argv

    return values
コード例 #4
0
from absl import logging  # pylint: disable=unused-import

from ml_collections.config_flags import config_flags

from ipagnn.lib import path_utils
from ipagnn.lib import setup
from ipagnn.workflows import analysis_workflows

DEFAULT_DATA_DIR = os.path.expanduser(os.path.join('~', 'tensorflow_datasets'))
DEFAULT_CONFIG = 'ipagnn/config/config.py'

# Command-line flags read by main() below.
flags.DEFINE_string(name='data_dir',
                    default=DEFAULT_DATA_DIR,
                    help='Where to place the data.')
flags.DEFINE_string(
    name='run_dir',
    default='/tmp/learned_interpreters/default/',
    help='The directory to use for this run of the experiment.')
config_flags.DEFINE_config_file('config', DEFAULT_CONFIG, 'config file')
FLAGS = flags.FLAGS


def main(argv):
    """Entry point; builds `run_configuration` from the parsed flags.

    The config and its command-line overrides come from the `--config` flag
    and are handed to `setup.configure` together with the data and run
    directories.
    """
    del argv  # Unused.

    # Shared by expand_run_dir and setup.configure.
    xm_parameters = {}
    data_dir = FLAGS.data_dir
    run_dir = path_utils.expand_run_dir(FLAGS.run_dir, xm_parameters)

    config_flag = FLAGS['config']
    config = FLAGS.config
    override_values = config_flag.override_values

    run_configuration = setup.configure(
        data_dir, run_dir, config, override_values, xm_parameters)
コード例 #5
0
# Lint as: python3
"""Example of ConfigDict usage.

This example includes loading a ConfigDict in FLAGS, locking it, type
safety, iteration over fields, checking for a particular field, unpacking with
`**`, and loading a dictionary from its string representation.
"""

from absl import app
from absl import flags
from ml_collections.config_flags import config_flags
import yaml

FLAGS = flags.FLAGS

# Config file flag; overridable per field as --my_config.<field>=<value>.
config_flags.DEFINE_config_file(
    name='my_config',
    default='ml_collections/config_dict/examples/config.py')


def dummy_function(string, **unused_kwargs):
    """Returns 'Hello <string>'; any keyword arguments are ignored."""
    greeting = f'Hello {string}'
    return greeting


def print_section(name):
    """Prints `name` upper-cased between two dashed rules of equal length.

    Two blank lines precede the banner and one blank line follows it.
    """
    rule = '-' * len(name)
    for line in ('', '', rule, name.upper(), rule, ''):
        print(line)

コード例 #6
0
ファイル: run_classifier.py プロジェクト: skye/flax_bert
from tensorflow.io import gfile

import datasets
import transformers
from transformers import BertTokenizerFast

import ml_collections
from ml_collections.config_flags import config_flags

FLAGS = flags.FLAGS

flags.DEFINE_string(
    name='output_dir',
    default=None,
    help='The output directory where the model checkpoints will be written.')

config_flags.DEFINE_config_file(
    name='config', default=None, help_string='Hyperparameter configuration')


def get_config():
    config = FLAGS.config
    hf_config = transformers.AutoConfig.from_pretrained(config.init_checkpoint)
    assert hf_config.model_type == 'bert', 'Only BERT is supported.'
    model_config = ml_collections.ConfigDict({
        'vocab_size':
        hf_config.vocab_size,
        'hidden_size':
        hf_config.hidden_size,
        'num_hidden_layers':
        hf_config.num_hidden_layers,
        'num_attention_heads':
        hf_config.num_attention_heads,
コード例 #7
0
import tensorflow as tf
from tensorflow.io import gfile
import uncertainty_baselines as ub
import checkpoint_utils  # local file import from baselines.jft
import data_uncertainty_utils  # local file import from baselines.jft
import input_utils  # local file import from baselines.jft
import preprocess_utils  # local file import from baselines.jft
import train_utils  # local file import from baselines.jft
import ood_utils  # local file import from experimental.near_ood.vit

# TODO(dusenberrymw): Open-source remaining imports.
fewshot = None

config_flags.DEFINE_config_file(
    name='config',
    default=None,
    help_string='Training configuration.',
    lock_config=True)
flags.DEFINE_string('output_dir', None, 'Work unit directory.')
# The remaining flags are accepted but not used by this script.
flags.DEFINE_integer(
    'num_cores', None, 'Unused. How many devices being used.')
flags.DEFINE_boolean(
    'use_gpu', None, 'Unused. Whether or not running on GPU.')
flags.DEFINE_string(
    name='tpu',
    default=None,
    help='Unused. Name of the TPU. Only used if use_gpu is False.')

FLAGS = flags.FLAGS


def main(config, output_dir):
コード例 #8
0
from ml_collections.config_flags import config_flags
from tensorflow.io import gfile

import data
import modeling
import training

FLAGS = flags.FLAGS

flags.DEFINE_string(
    name="output_dir",
    default=None,
    help="The output directory where the model checkpoints will be written.",
)

config_flags.DEFINE_config_file(
    name="config", default=None, help_string="Hyperparameter configuration")


def get_config():
    config = FLAGS.config
    hf_config = transformers.AutoConfig.from_pretrained(config.init_checkpoint)
    assert hf_config.model_type == "bert", "Only BERT is supported."
    model_config = ml_collections.ConfigDict({
        "vocab_size":
        hf_config.vocab_size,
        "hidden_size":
        hf_config.hidden_size,
        "num_hidden_layers":
        hf_config.num_hidden_layers,
        "num_attention_heads":
        hf_config.num_attention_heads,
コード例 #9
0
from abstract_nas.train import train
from abstract_nas.zoo import utils as zoo_utils

# Flag holders; read their values via `.value` once flags are parsed.
_MAX_NUM_TRIALS = flags.DEFINE_integer(
    name="max_num_trials",
    default=50,
    help="The total number of trials to run.")
_RESULTS_DIR = flags.DEFINE_string(
    name="results_dir",
    default="",
    help="The directory in which to write the results.")
_WRITE_CHECKPOINTS = flags.DEFINE_bool(
    name="checkpoints",
    default=False,
    help="Whether to checkpoint the evolved models.")
_STUDY_NAME = flags.DEFINE_string(
    name="study_name",
    default="abstract_nas_demo",
    help="The name of this study, used for writing results.")

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file("config")


def save_results(population):
  """Saves results as csv.

  The output path is `<results_dir>/<study_name>.csv`. Saving is skipped, with
  a warning, when no results directory is configured or when that file
  already exists, so earlier results are never overwritten.

  Args:
    population: Presumably the evolved population whose results are written;
      TODO confirm against callers.
  """
  results_dir = _RESULTS_DIR.value
  study_name = _STUDY_NAME.value
  if not results_dir:
    # `logging.warn` is a deprecated alias of `logging.warning`.
    logging.warning("No results_dir defined, so skipping saving results.")
    return

  results_filename = f"{results_dir}/{study_name}.csv"
  if os.path.exists(results_filename):
    logging.warning(
        "Results file %s already exists, so skipping saving results.",
        results_filename)
    return
コード例 #10
0
from absl import flags
import gym
import matplotlib.pyplot as plt
from ml_collections.config_flags import config_flags
import torch
from wrappers import wrapper_from_config
import xmagical
from xmagical.utils import KeyboardEnvInteractor

FLAGS = flags.FLAGS

flags.DEFINE_string(name="embodiment", default=None,
                    help="The agent embodiment.")

config_flags.DEFINE_config_file(
    name="config",
    default="configs/rl/default.py",
    help_string="File path to the training hyperparameter configuration.",
    lock_config=True,
)

# Parsing fails unless --embodiment is given.
flags.mark_flag_as_required("embodiment")


def make_env():
    """Builds the SweepToTop state environment for `FLAGS.embodiment`.

    The gym environment is wrapped according to `FLAGS.config`, on CUDA when
    it is available and on the CPU otherwise.
    """
    xmagical.register_envs()
    env_id = "SweepToTop-{}-State-Allo-TestLayout-v0".format(
        FLAGS.embodiment.capitalize())
    raw_env = gym.make(env_id)
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    return wrapper_from_config(FLAGS.config, raw_env, torch.device(device_name))

コード例 #11
0
from cascaded_networks.datasets.dataset_handler import DataHandler
from cascaded_networks.models import densenet
from cascaded_networks.models import resnet
from cascaded_networks.modules import eval_handler
from cascaded_networks.modules import losses
from cascaded_networks.modules import train_handler
from cascaded_networks.modules import utils

# Setup Flags
FLAGS = flags.FLAGS
flags.DEFINE_string(name='gcs_path', default=None, help='gcs_path dir')
flags.DEFINE_bool(
    name='hyper_param_sweep', default=None, help='conducting hyperparam sweep')
flags.DEFINE_integer(name='n_gpus', default=None, help='Number of GPUs')

config_flags.DEFINE_config_file(
    'config', None, 'Path to the Training configuration.')


def main(_):
  """Training entry point; all settings come from `FLAGS.config`."""
  config = FLAGS.config
  if config.debug:
    config.epochs = 5  # Keep debug runs short.

  utils.make_reproducible(config.random_seed)  # Seed for reproducibility.

  # Output sub-path, taken from the config's local output directory.
  gcs_subpath = config.local_output_dir
コード例 #12
0
import os
import sys
import wandb
import pytorch_lightning as pl

from datasets import load_dataset
from utils.wandb import get_experiments, load_model
from attack import trainer

# args
from absl import app
from absl import flags
from ml_collections.config_flags import config_flags

FLAGS = flags.FLAGS
# Default selects the 'attack' config from config.py.
config_flags.DEFINE_config_file(name="config", default="config.py:attack")


def model_name(args):
    """Returns the model name '<prior>_<arc_type>_<z_dim>' for `args`."""
    return '_'.join([args.prior, args.arc_type, str(args.z_dim)])


def cli_main(_):
    """Entry point: fixes the global seed, then raises absl logging to INFO."""
    pl.seed_everything(1234)

    # Only configure absl logging if it has already been imported elsewhere.
    if "absl.logging" in sys.modules:
        import absl.logging

        for set_level in (absl.logging.set_verbosity,
                          absl.logging.set_stderrthreshold):
            set_level("info")
コード例 #13
0
    'preds_col', None,
    'Name to use for the DeepNull prediction column. If unspecified, will be '
    'the target column name with "_deepnull" suffix added.')
_NUM_FOLDS = flags.DEFINE_integer(
    name='num_folds',
    default=5,
    help='The number of cross-validation folds to use.')
_SEED = flags.DEFINE_integer(
    name='seed', default=None, help='Random seed to use.')
_LOGDIR = flags.DEFINE_string(
    name='logdir',
    default='/tmp',
    help='Directory in which to write temporary outputs.')
_VERBOSE = flags.DEFINE_boolean(
    name='verbose',
    default=False,
    help='If True, prints verbose model training output.')

# N.B.: In ml_collections v0.1.0, DEFINE_config_file returns nothing, so this
# flag is read through FLAGS. Once a newer release is on PyPI, keep the
# returned holder like the flags above.
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
    name='model_config',
    default=None,
    help_string=(
        'Specifies the model config file to use. If unspecified, defaults to the '
        'MLP-based TF model used for all main results of the DeepNull paper.'))


def main(argv: Sequence[str]) -> None:
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.random.set_seed(_SEED.value)
    logging.info('Loading data from %s', _INPUT_TSV.value)
    input_df, binary_col_map = data.load_plink_or_bolt_file(
        path_or_buf=_INPUT_TSV.value, missing_value=_MISSING_VALUE.value)

    if FLAGS.model_config is None:
        full_config = config.get_config('deepnull')
    else:
コード例 #14
0
from absl import app
from absl import logging
from ml_collections.config_flags import config_flags

from stable_transfer.classification import accuracy
from stable_transfer.classification import gbc
from stable_transfer.classification import hscore
from stable_transfer.classification import leep
from stable_transfer.classification import logme
from stable_transfer.classification import nleep
from stable_transfer.classification import transfer_experiment

_CONFIG_DIR = './'

# Flag holder for the experiment config; the default is relative to
# _CONFIG_DIR.
_CONFIG = config_flags.DEFINE_config_file(
    name='my_config',
    default=(f'{_CONFIG_DIR}/stable_transfer/classification/'
             'config_transfer_experiment.py'),
)


def run_experiment(experiment):
    """Run the experiment defined in the config file.

    Dispatches on `experiment.config.experiment.metric` and returns the result
    of the matching metric function, called with `experiment`. Recognized
    metrics are 'accuracy', 'leep', 'logme', 'hscore' and 'nleep'; any other
    value yields None.
    """
    metric = experiment.config.experiment.metric
    scorers = (
        ('accuracy', accuracy.get_test_accuracy),
        ('leep', leep.get_train_leep),
        ('logme', logme.get_train_logme),
        ('hscore', hscore.get_train_hscore),
        ('nleep', nleep.get_train_nleep),
    )
    for metric_name, scorer in scorers:
        if metric == metric_name:
            return scorer(experiment)
    return None
コード例 #15
0
 def testModuleName(self):
     # The config flag should be listed under the module key './program' in
     # flags_by_module_dict() after parsing that command line.
     config_flags.DEFINE_config_file('flag')
     program = './program'
     _parse_flags(program)
     module_flags = flags.FLAGS.flags_by_module_dict()[program]
     self.assertIn(flags.FLAGS['flag'], module_flags)
コード例 #16
0
ファイル: run_experiment.py プロジェクト: mtyhon/ckconv
import torch

# project
from path_handler import model_path
from model import get_model
import dataset
import trainer
import tester

# args
from absl import app
from absl import flags
from ml_collections.config_flags import config_flags

FLAGS = flags.FLAGS
# Loaded from config.py unless --config points elsewhere.
config_flags.DEFINE_config_file(name="config", default="config.py")


def main(_):
    """Entry point: sets absl logging to INFO, prints the config, seeds torch."""
    # Only configure absl logging if it has already been imported elsewhere.
    if "absl.logging" in sys.modules:
        import absl.logging

        for set_level in (absl.logging.set_verbosity,
                          absl.logging.set_stderrthreshold):
            set_level("info")

    config = FLAGS.config
    print(config)

    torch.manual_seed(config.seed)  # Set the seed.
コード例 #17
0
from importlib import import_module

from absl import flags
from ml_collections.config_flags import config_flags


# Configure model
FLAGS = flags.FLAGS
# Must provide `constructor` (dotted class path) and `hparams`; see load_model.
config_flags.DEFINE_config_file(name='model')


def load_model():
    """Instantiates the learner described by the `--model` config.

    `FLAGS.model.constructor` is a dotted path 'module.path.ClassName'. The
    module is imported and the class is called with `FLAGS.model.hparams` as
    keyword arguments.
    """
    spec = FLAGS.model
    module_path, class_name = spec.constructor.rsplit('.', 1)
    learner_cls = getattr(import_module(module_path), class_name)
    return learner_cls(**spec.hparams)
コード例 #18
0
r"""Example of basic DEFINE_config_file usage.

To run this example with a basic config file:
python define_config_dict_basic.py \
  --my_config=ml_collections/config_flags/examples/config.py \
  --my_config.field1=8 --my_config.nested.field=2.1 \
  --my_config.tuple='(1, 2, (1, 2))'

To run this example with a parameterised config file:
python define_config_dict_basic.py \
  --my_config=ml_collections/config_flags/examples/parameterised_config.py:linear \
  --my_config.model_config.output_size=256
"""
# pylint: enable=line-too-long

from absl import app

from ml_collections.config_flags import config_flags

_CONFIG = config_flags.DEFINE_config_file('my_config')


def main(_):
    """Prints the config loaded from --my_config, including any overrides."""
    config = _CONFIG.value
    print(config)


if __name__ == '__main__':
    app.run(main)
コード例 #19
0
To run this example with a basic config file:
python define_config_dict_basic.py \
  --my_config=ml_collections/config_flags/examples/config.py \
  --my_config.field1=8 --my_config.nested.field=2.1 \
  --my_config.tuple='(1, 2, (1, 2))'

To run this example with a parameterised config file:
python define_config_dict_basic.py \
  --my_config=ml_collections/config_flags/examples/parameterised_config.py:linear \
  --my_config.model_config.output_size=256
"""
# pylint: enable=line-too-long

from absl import app
from absl import flags

from ml_collections.config_flags import config_flags

FLAGS = flags.FLAGS
config_flags.DEFINE_config_file('my_config')


def main(_):
    """Prints the config loaded from --my_config, including any overrides."""
    config = FLAGS.my_config
    print(config)


if __name__ == '__main__':
    app.run(main)
コード例 #20
0
# pylint: disable=line-too-long

# Note: adjacent string literals are concatenated without a separator, so the
# first help fragment must end with a space.
flags.DEFINE_string(
    'query', None, 'natural language description of the desired object. '
    'can also be specified in the config file')
flags.DEFINE_integer('seed', 0,
                     'random seed. change to get a different generation')
flags.DEFINE_string('executable_name', 'train',
                    'executable name. [train|eval]')
flags.DEFINE_string('experiment_dir', 'results', 'experiment output directory')
flags.DEFINE_string('work_unit_dir', None,
                    'work unit output directory within experiment_dir')
flags.DEFINE_string('config_json', None,
                    'hyperparameter file to read in .json')
flags.DEFINE_string('extra_args_json_str', None, 'extra args to pass in')
# lock_config=False leaves the loaded ConfigDict unlocked.
config_flags.DEFINE_config_file('config', lock_config=False)
FLAGS = flags.FLAGS
# pylint: enable=line-too-long


def main(executable_dict, argv):
    """Entry point: prepares the TF/JAX runtime for this process.

    Args:
      executable_dict: Presumably maps --executable_name values to callables;
        TODO confirm against callers.
      argv: Unused command-line arguments.
    """
    del argv

    work_unit = platform.work_unit()
    tf.enable_v2_behavior()
    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    # `jax.host_id` / `jax.host_count` are deprecated aliases of these.
    host_id = jax.process_index()
    n_host = jax.process_count()
コード例 #21
0
ファイル: main.py プロジェクト: ziyouzizai111/google-research
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training and evaluation for NCSNv3."""

from . import ncsn_lib
from absl import app
from absl import flags
from ml_collections.config_flags import config_flags
import tensorflow as tf

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file(
    name="config",
    default=None,
    help_string="Training configuration.",
    lock_config=True)
flags.DEFINE_string(name="workdir", default=None, help="Work unit directory.")
flags.DEFINE_string(
    name="mode", default="train", help="Running mode: train or eval")
flags.DEFINE_string(
    name="eval_folder",
    default="eval",
    help="The folder name for storing evaluation results")
# Parsing fails unless both --workdir and --config are supplied.
flags.mark_flags_as_required(["workdir", "config"])


def main(argv):
    """Runs training or evaluation depending on --mode.

    Raises:
      ValueError: If --mode is neither "train" nor "eval".
    """
    del argv
    # Hide GPUs from TensorFlow (presumably so TF does not reserve GPU memory
    # needed by the training library; TODO confirm).
    tf.config.experimental.set_visible_devices([], "GPU")
    if FLAGS.mode == "train":
        ncsn_lib.train(FLAGS.config, FLAGS.workdir)
    elif FLAGS.mode == "eval":
        ncsn_lib.evaluate(FLAGS.config, FLAGS.workdir, FLAGS.eval_folder)
    else:
        # Previously an unknown mode exited silently without doing any work.
        raise ValueError(
            f"Unknown --mode {FLAGS.mode!r}; expected 'train' or 'eval'.")
コード例 #22
0
from jax.flatten_util import ravel_pytree
from ml_collections.config_flags import config_flags
import numpy as np
import scipy.linalg

import adversarial
import data as data_loader
import model
from norm import norm_f
from norm import norm_type_dual
import optim
import summary as summary_tools

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file(
    name='config', default=None, help_string='Config file name.')


def evaluate_risks(data, predict_f, loss_f, model_param):
  """Returns the risk of a model for various loss functions.

  Args:
    data: An array of data samples for approximating the risk.
    predict_f: Function that predicts labels given input.
    loss_f: Function that outputs model's specific loss function.
    model_param: Model parameters.

  Returns:
    Dictionary of risks for following loss functions:
        (model's loss, 0/1, adversarial risk wrt a single norm-ball).
  """
コード例 #23
0
import jax.numpy as jnp
from ml_collections.config_flags import config_flags
import numpy as np
import optax
import pandas as pd
from sklearn import metrics
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_probability.substrates import jax as tfp
from counterfactual_fairness import adult
from counterfactual_fairness import causal_network
from counterfactual_fairness import utils
from counterfactual_fairness import variational

FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
    name='config',
    default='adult_pscf_config.py',
    help_string='Training configuration.')

LOG_EVERY = 100

# Short aliases for `causal_network` callables that return distribution
# modules or a Node itself; they keep the model-building code legible.
Node = causal_network.Node
Gaussian = causal_network.Gaussian
MLPMultinomial = causal_network.MLPMultinomial


def build_input(train_data: pd.DataFrame,
                batch_size: int,
                training_steps: int,
                shuffle_size: int = 10000):