예제 #1
0
def define_flags_with_default(**kwargs):
    """Defines one flag per keyword argument, using its value as the default.

    The kind of flag is chosen from the type of the default value:
    `ConfigDict` values become config-dict flags, while `bool`, `int`,
    `float` and `str` values become the matching absl flag.

    Args:
      **kwargs: Mapping of flag name to default value.

    Returns:
      The `kwargs` dict that was passed in.

    Raises:
      ValueError: If a default value is of none of the supported types.
    """
    help_text = 'automatically defined flag'
    for name, default in kwargs.items():
        # Config dicts go through config_flags and take no help text here.
        if isinstance(default, ConfigDict):
            config_flags.DEFINE_config_dict(name, default)
            continue

        # bool has to be tested before int: True and False are int instances.
        if isinstance(default, bool):
            define = absl.flags.DEFINE_bool
        elif isinstance(default, int):
            define = absl.flags.DEFINE_integer
        elif isinstance(default, float):
            define = absl.flags.DEFINE_float
        elif isinstance(default, str):
            define = absl.flags.DEFINE_string
        else:
            raise ValueError('Incorrect value type')
        define(name, default, help_text)
    return kwargs
def _parse_flags(command,
                 default=None,
                 config=None,
                 lock_config=True,
                 required=False):
    """Parses arguments simulating sys.argv."""

    if config is not None and default is not None:
        raise ValueError('If config is supplied a default should not be.')

    # Storing copy of the old sys.argv.
    old_argv = list(sys.argv)

    # Overwriting sys.argv, as sys has a global state it gets propagated.
    # The module shlex is useful here because it splits the input similar to
    # sys.argv. For instance, string arguments are not split by space.
    sys.argv = shlex.split(command)

    # Actual parsing.
    values = flags.FlagValues()
    if config is None:
        config_flags.DEFINE_config_file('test_config',
                                        default=default,
                                        flag_values=values,
                                        lock_config=lock_config)
    else:
        config_flags.DEFINE_config_dict('test_config',
                                        config=config,
                                        flag_values=values,
                                        lock_config=lock_config)

    if required:
        flags.mark_flag_as_required('test_config', flag_values=values)
    values(sys.argv)

    # Going back to original values.
    sys.argv = old_argv

    return values
예제 #3
0
# If true, will precompute server-side embedding projections such as PCA.
config.warm_projections = config_dict.placeholder(bool)

# If true, will disable capabilities not allowed in demo mode, such as
# saving generated datapoints to disk.
config.demo_mode = config_dict.placeholder(bool)

# Which layout to use by default (can be changed via URL); see layout.ts
config.default_layout = 'default'

# What URL base to use when copying the LIT URL (e.g., something other
# than just a local server address).
config.canonical_url = config_dict.placeholder(str)

# Custom page title for this server.
config.page_title = config_dict.placeholder(str)

# Whether the LIT instance is a development demo.
config.development_demo = False

# Absolute path of the 'client/build/default' directory located under the
# directory that contains this file.
config.client_root = os.path.join(
    pathlib.Path(__file__).parent.absolute(), 'client', 'build', 'default')

# Register the whole config object under the flag name 'lit'.
config_flags.DEFINE_config_dict('lit', config)
# LINT.ThenChange(server_flags.py)


def get_flags():
    """Returns the module-level `config` object registered as the 'lit' flag."""
    return config
예제 #4
0
# Dataset sizing: batch size and the cap on the dev dataset size.
flags.DEFINE_integer('batch_size', 1024, 'Batch size.')
flags.DEFINE_integer('max_dev_size', 10 * 1024, 'Maximum dev dataset size.')

# Ridge regularizers, one per regression stage.
flags.DEFINE_float('stage1_reg', 1e-5,
                   'ridge regularizer for stage 1 regression')
flags.DEFINE_float('stage2_reg', 1e-5,
                   'ridge regularizer for stage 2 regression')
# Random Fourier feature count and Gaussian kernel gamma.
# NOTE(review): gamma defaults to None; how an unset gamma is handled is not
# visible in this excerpt.
flags.DEFINE_integer('n_component', 512, 'Number of random Fourier features.')
flags.DEFINE_float('gamma', None, 'Gamma in Gaussian kernel.')

flags.DEFINE_integer('evaluate_init_samples', 100,
                     'Number of initial samples for evaluation.')

flags.DEFINE_integer('max_steps', 1, 'Max number of steps.')

# The problem config is exposed as a single config-dict flag; its default
# comes from utils.get_problem_config().
config_flags.DEFINE_config_dict('problem_config', utils.get_problem_config(),
                                'ConfigDict instance for problem config.')
FLAGS = flags.FLAGS


def main(_):
    problem_config = FLAGS.problem_config

    # Load the offline dataset and environment.
    dataset, dev_dataset, environment = utils.load_data_and_env(
        task_name=problem_config['task_name'],
        noise_level=problem_config['noise_level'],
        near_policy_dataset=problem_config['near_policy_dataset'],
        dataset_path=FLAGS.dataset_path,
        batch_size=FLAGS.batch_size,
        max_dev_size=FLAGS.max_dev_size,
        shuffle=False,
예제 #5
0
from model_dispatcher import load_model
from utils import set_seed

# Configure experiment runner
FLAGS = flags.FLAGS
flags.DEFINE_bool('debug', False, "Show debugging information.")
flags.DEFINE_bool('log', False, "Log this experiment to wandb.")

# Configure experiment tracking
# The fields of this config are unpacked as keyword arguments to wandb.init
# in main() when the `log` flag is set.
config_wandb = ml_collections.ConfigDict()
config_wandb.project = "hparam-src"
# job_type and notes are declared as str placeholders; no concrete value is
# assigned to them here.
config_wandb.job_type = placeholder(str)
config_wandb.notes = placeholder(str)
config_flags.DEFINE_config_dict(
    'wandb',
    config_wandb,
    "Configuration for W&B experiment tracking.",
)


def main(_):
    """Runs the experiment pipeline.

    When the `log` flag is set, a W&B run is initialised first. The pipeline
    then calls `set_seed()` and loads the train/test splits.

    NOTE(review): the visible excerpt ends after the data-loading step; any
    later pipeline stages are not shown here.

    Args:
      _: Unused positional argument.
    """
    if FLAGS.log:
        # All flags are passed as the run config; the fields of the `wandb`
        # config dict are passed as keyword arguments.
        wandb.init(config=FLAGS, **FLAGS.wandb)

    # Pipeline
    ## Setup
    set_seed()

    ## Data
    X_train, X_test, y_train, y_test = load_train_test_splits()
r"""Example of basic DEFINE_config_dict usage.

To run this example:
python define_config_dict_basic.py -- --my_config_dict.field1=8 \
  --my_config_dict.nested.field=2.1 --my_config_dict.tuple='(1, 2, (1, 2))'
"""

from absl import app
from absl import flags

import ml_collections
from ml_collections.config_flags import config_flags

# Default configuration. Fields can be overridden from the command line as
# --my_config_dict.<field>=<value>; see the run command in the module
# docstring.
config = ml_collections.ConfigDict()
config.field1 = 1
config.field2 = 'tom'
config.nested = ml_collections.ConfigDict()
config.nested.field = 2.23
config.tuple = (1, 2, 3)

FLAGS = flags.FLAGS
# Expose `config` as the flag 'my_config_dict'.
config_flags.DEFINE_config_dict('my_config_dict', config)


def main(_):
    """Prints the value of the `my_config_dict` flag.

    Args:
      _: Unused positional argument.
    """
    parsed_config = FLAGS.my_config_dict
    print(parsed_config)


# Script entry point: hand `main` to absl's app.run when executed directly.
if __name__ == '__main__':
    app.run(main)
예제 #7
0
import pandas as pd

import ml_collections
from absl import flags
from ml_collections.config_flags import config_flags
from ml_collections.config_dict import placeholder

from sklearn.model_selection import train_test_split
from definitions import DATA_DIR


# Configure data
FLAGS = flags.FLAGS
config_data = ml_collections.ConfigDict()
# CSV file read by load_data().
config_data.path = DATA_DIR / "0_raw" / "concrete.csv"
# Fraction of rows drawn with DataFrame.sample in load_data().
config_data.frac = 1.0
# Passed as test_size to train_test_split in load_train_test_splits().
config_data.test_size = 0.2
# Expose the data config as the flag 'data'.
config_flags.DEFINE_config_dict("data", config_data)


def load_data():
    """Loads the dataset and separates the target column from the features.

    The CSV at `FLAGS.data.path` is read, a random sample of fraction
    `FLAGS.data.frac` of its rows is taken, and the index is reset. The
    "CompressiveStrength" column is then removed from the frame and returned
    as the target.

    Returns:
      A `(features, target)` tuple: the sampled DataFrame without the target
      column, and the target column as a Series.
    """
    data_cfg = FLAGS.data
    frame = pd.read_csv(data_cfg.path)
    sampled = frame.sample(frac=data_cfg.frac)
    features = sampled.reset_index(drop=True)
    # pop removes the column from `features` in place and returns it.
    target = features.pop("CompressiveStrength")
    return features, target


def load_train_test_splits():
    """Loads the data and splits it into train and test parts.

    Returns:
      The result of `train_test_split` on the features and target from
      `load_data()`, with `FLAGS.data.test_size` as the test size.
    """
    features, target = load_data()
    holdout_size = FLAGS.data.test_size
    return train_test_split(features, target, test_size=holdout_size)