import shutil
import tempfile

from absl import flags
import apache_beam as beam
from apache_beam.transforms import combiners
from moonlight.pipeline import pipeline_flags
from moonlight.training.clustering import staffline_patches_dofn
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.python.lib.io import file_io
from tensorflow.python.lib.io import tf_record

FLAGS = flags.FLAGS
flags.DEFINE_multi_string('music_pattern', [],
                          'Pattern for the input music score PNGs.')
flags.DEFINE_string('output_path', None, 'Path to the output TFRecords.')
flags.DEFINE_integer('patch_height', 18,
                     'The normalized height of a staffline.')
flags.DEFINE_integer('patch_width', 15,
                     'The width of a horizontal patch of a staffline.')
flags.DEFINE_integer('num_stafflines', 19,
                     'The number of stafflines to extract.')
flags.DEFINE_integer('num_pages', 0, 'Subsample the pages to run on.')
flags.DEFINE_integer('num_outputs', 0, 'Number of output patches.')
flags.DEFINE_integer('max_patches_per_page', 10,
                     'Sample patches per page if above this amount.')
flags.DEFINE_integer('timeout_ms', 600000, 'Timeout for processing a page.')
flags.DEFINE_integer('kmeans_num_clusters', 1000,
                     'Number of k-means clusters.')
flags.DEFINE_integer('kmeans_batch_size', 10000,
flags.DEFINE_integer('decode_top_k', 8,
                     'Maximum number of tokens to consider for begin/end.')
flags.DEFINE_integer('decode_max_size', 16,
                     'Maximum number of sentence pieces in an answer.')
flags.DEFINE_float('dropout_rate', 0.1, 'Dropout rate for hidden layers.')
flags.DEFINE_float('attention_dropout_rate', 0.3,
                   'Dropout rate for attention layers.')
flags.DEFINE_float('label_smoothing', 1e-1, 'Degree of label smoothing.')
flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files')

FLAGS = flags.FLAGS


@contextlib.contextmanager
def worker_context():
  if FLAGS.master:
    with tf.device('/job:worker') as d:
      yield d
  else:
    yield


def read_sentencepiece_model(path):
# from mesh_tensorflow.transformer import utils
import gin
import t5

_DEFAULT_MODULE_IMPORTS = []

FLAGS = flags.FLAGS

flags.DEFINE_string("task", None, "A registered Task.")
flags.DEFINE_integer("max_examples", -1,
                     "Maximum number of examples. -1 for no limit.")
flags.DEFINE_string("format_string", "{inputs}\t{targets}",
                    "Format for printing examples.")
flags.DEFINE_multi_string(
    "module_import", _DEFAULT_MODULE_IMPORTS,
    "Modules to import. Use this when your Task is defined outside "
    "of the T5 codebase so that it is registered.")
flags.DEFINE_string("split", "train",
                    "Which split of the dataset, e.g. train or validation.")
flags.DEFINE_bool("detokenize", False, "If True, then decode ids to strings.")


@gin.configurable
def sequence_length(value=512):
  """Sequence length used when tokenizing.

  Args:
    value: an integer or dictionary

  Returns:
    a dictionary
from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.beta.serving import export_saved_model_lib

FLAGS = flags.FLAGS

flags.DEFINE_string('experiment', None,
                    'experiment type, e.g. retinanet_resnetfpn_coco')
flags.DEFINE_string('export_dir', None, 'The export directory.')
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path.')
flags.DEFINE_multi_string(
    'config_file',
    default=None,
    help='YAML/JSON files which specify overrides. The override order '
    'follows the order of args. Note that each file '
    'can be used as an override template to override the default parameters '
    'specified in Python. If the same parameter is specified in both '
    '`--config_file` and `--params_override`, `config_file` will be used '
    'first, followed by params_override.')
flags.DEFINE_string(
    'params_override', '',
    'The JSON/YAML file or string which specifies the parameter to be '
    'overridden on top of the `config_file` template.')
flags.DEFINE_integer('batch_size', None, 'The batch size.')
flags.DEFINE_string(
    'input_type', 'image_tensor',
    'One of `image_tensor`, `image_bytes`, `tf_example` and `tflite`.')
flags.DEFINE_string(
    'input_image_size', '224,224',
    'The comma-separated string of two integers representing the height,width '
(though it remains to be verified to what extent) updating the priorities on
the RB while we are iterating over the dataset does not become a problem,
because on every iteration self.get_next() is called, which runs with the new
(and correct) priorities. N.B. the prefetch() method applied to the dataset
probably plays a role here, since items that have already been prefetched do
not appear to pick up the updated priority; but we can probably ignore these
subtleties and count on them not changing much.
2) Implementing the change would not be particularly painful for the driver,
since we could use the DynamicStepDriver, which outputs the state of things at
the last step; that state can then be passed back to the DynamicStepDriver at
the next iteration so that it resumes from there (and this should maintain
consistency for important metrics such as AverageReturn).
"""

flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
                    'Root directory for writing logs/summaries/checkpoints.')
flags.DEFINE_multi_string(
    'gin_files', [],
    'List of paths to gin configuration files (e.g. '
    '"configs/hanabi_rainbow.gin").')
flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files '
    '(e.g. "train_eval.num_iterations=100").')

FLAGS = flags.FLAGS


# TODO: Very much unfinished function. It should run an episode, stopping step
# by step and printing everything we might want to see.
def run_verbose_mode(agent_1, agent_2):
  env = rl_env.make('Hanabi-Full-CardKnowledge', num_players=2)
  tf_env = tf_py_environment.TFPyEnvironment(env)
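# A minimal sketch of the DynamicStepDriver hand-off described in the note
# above, assuming a TF-Agents `tf_env`, `collect_policy`, `replay_buffer`,
# and `num_iterations` already exist; all names here are illustrative, not
# taken from the original script.
from tf_agents.drivers import dynamic_step_driver

# Collect a fixed number of steps per call; the replay buffer observes
# every transition.
collect_driver = dynamic_step_driver.DynamicStepDriver(
    tf_env,
    collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=1)

# run() returns the final (time_step, policy_state); feeding them back in on
# the next call resumes collection mid-episode, which is what keeps metrics
# such as AverageReturn consistent across iterations.
time_step, policy_state = None, None
for _ in range(num_iterations):
  time_step, policy_state = collect_driver.run(
      time_step=time_step, policy_state=policy_state)
  # ... sample from the replay buffer and train here ...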
from typing import Any, Callable, Dict, List, Optional, Text

EXPORTER_FN = Callable[
    [model_interface.ModelInterface,
     abstract_export_generator.AbstractExportGenerator],
    List[tf.estimator.Exporter]]

FLAGS = flags.FLAGS

try:
  flags.DEFINE_list(
      'gin_configs', None,
      'A comma-separated list of paths to Gin configuration files.')
  flags.DEFINE_multi_string(
      'gin_bindings', [],
      'A newline separated list of Gin parameter bindings.')
except flags.DuplicateFlagError:
  pass

gin_configurable_eval_spec = gin.external_configurable(
    tf.estimator.EvalSpec, name='tf.estimator.EvalSpec')


def print_spec(tensor_spec):
  """Iterate over a spec and print its values in sorted order.

  Args:
    tensor_spec: A dict, (named)tuple, list or a hierarchy thereof filled by
      TensorSpecs(subclasses) or Tensors.
  """
flags.DEFINE_integer('pre_train_steps', 100, help='Pretrain steps.')
flags.DEFINE_integer('finetune_steps', 100, help='Finetune steps.')
flags.DEFINE_integer('ctrl_steps', 100, help='Controller steps.')
flags.DEFINE_string(
    'param_file',
    None,
    help=('Base set of model parameters to use with this model. To see '
          'documentation on the parameters, see the docstring in '
          'resnet_params.'))
flags.DEFINE_multi_string(
    'param_overrides',
    None,
    help=('Model parameter overrides for this model. For example, if '
          'experimenting with larger numbers of train_steps, a possible value '
          'is --param_overrides=train_steps=28152. If you have a collection '
          'of parameters that make sense to use together repeatedly, consider '
          'extending resnet_params.param_sets_table.'))
flags.DEFINE_string(
    'data_dir',
    '',
    help=('The directory where the ImageNet input data is stored. Please see '
          'the README.md for the expected data format.'))
flags.DEFINE_string(
    'model_dir',
    None,
    help=('The directory where the model and training/evaluation summaries '
          'are stored.'))
flags.DEFINE_string('mode',
from tapas.scripts import calc_metrics_utils
from tapas.scripts import prediction_utils as pred_utils
from tapas.utils import experiment_utils  # pylint: disable=unused-import
import tensorflow.compat.v1 as tf

FLAGS = flags.FLAGS

flags.DEFINE_string("data_format", "tfrecord", "The input data format.")
flags.DEFINE_string(
    "compression_type",
    "GZIP",
    "Compression to use when reading tfrecords. '' for no compression.",
)
flags.DEFINE_multi_string(
    "input_file_train", None,
    "Input TF example files (can be a glob or comma separated).")
flags.DEFINE_multi_string(
    "input_file_eval", None,
    "Input TF example files (can be a glob or comma separated).")
flags.DEFINE_multi_string(
    "input_file_predict", None,
    "Input TF example files (can be a glob or comma separated).")
flags.DEFINE_string(
    "prediction_output_dir", None,
    "If not none or empty writes predictions to this directory. Otherwise "
    "writes predictions to model_dir.")
flags.DEFINE_bool(
    'deploy', True,
    'Deploy the test environment. '
    'Set to false to skip the deployment phase and go straight to tests.')
flags.DEFINE_bool(
    'skip_before_all', False,
    'True to skip @before_all methods. '
    'Like --nodeploy, this is used to skip set up steps. '
    'Useful when developing new tests.')
flags.DEFINE_bool(
    'no_external_access', False,
    'True to skip creating RDP/SSH firewall '
    'rules during deployment. Should be used in automated test runs.')
flags.DEFINE_bool('cleanup', False,
                  'Clean up the host environment after the test.')
flags.DEFINE_string('error_logs_dir', None,
                    'Where to collect extra logs on test failures.')
flags.DEFINE_multi_string('test_arg', None, 'Flags passed to tests.')


def ConfigureLogging():
  # Filter out logs from low-level loggers.
  error_only_loggers = [
      'googleapiclient.discovery_cache', 'google.auth', 'google_auth_httplib2'
  ]
  for logger in error_only_loggers:
    logging.getLogger(logger).setLevel(logging.ERROR)
  message = 'We recommend that most server applications use service accounts.'
  warnings.filterwarnings('ignore', '.*%s' % message)

  logging.error('%s: Logging level error is visible.', __file__)
  logging.warning('%s: Logging level warning is visible.', __file__)
  logging.info('%s: Logging level info is visible.', __file__)
import os
import time

from absl import app
from absl import flags
import gin.tf
# tf.logging below is a TF1-style API, so the compat.v1 import is assumed.
import tensorflow.compat.v1 as tf

import experiment_prototype

flags.DEFINE_string('env', 'cartpole', 'Env name.')
flags.DEFINE_string('base_dir', './results', 'Base dir.')
flags.DEFINE_string('agent_id', 'mmd', 'Agent id.')
flags.DEFINE_string('agent_name', None, 'Agent name.')
flags.DEFINE_integer('run_id', 0, 'Run id.')
flags.DEFINE_multi_string('gin_files', './configs/mmd_atari.gin',
                          'List of paths to gin configuration files.')
flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values in the config files.')
flags.DEFINE_bool('debug', False, 'Debug mode.')

FLAGS = flags.FLAGS


def main(unused_argv):
  agent_name = FLAGS.agent_id if FLAGS.agent_name is None else FLAGS.agent_name
  if FLAGS.debug:
    base_dir = os.path.join('./results/tmp', str(time.time()))
  else:
    base_dir = os.path.join(FLAGS.base_dir, FLAGS.env, agent_name,
                            'run_%d' % FLAGS.run_id)
  tf.logging.set_verbosity(tf.logging.INFO)
  experiment_prototype.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os

from absl import app
from absl import flags
import gin.tf

from bisimulation_aaai2020.grid_world import grid_world

flags.DEFINE_string('grid_file', None, 'Path to file defining grid world MDP.')
flags.DEFINE_string('base_dir', None, 'Base directory to store stats.')
flags.DEFINE_multi_string('gin_files', [],
                          'List of paths to gin configuration files.')
flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files.')
flags.DEFINE_bool('exact_metric', True,
                  'Whether to compute the metric using the exact method.')
flags.DEFINE_bool('sampled_metric', True,
                  'Whether to compute the metric using sampling.')
flags.DEFINE_bool('learn_metric', True,
                  'Whether to compute the metric using learning.')
flags.DEFINE_bool('sample_distance_pairs', True,
                  'Whether to aggregate states (needs a learned metric).')
flags.DEFINE_integer('num_samples_per_cell', 100,
                     'Number of samples per cell when aggregating.')
flags.DEFINE_bool('verbose', False, 'Whether to print verbose messages.')
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Flags for trainer.py and rl_trainer.py.

We keep these flags in sync across the trainer and the rl_trainer binaries.
"""

from absl import flags
from absl import logging

# Common flags.
flags.DEFINE_string('output_dir', None,
                    'Path to the directory to save logs and checkpoints.')
flags.DEFINE_multi_string('config_file', None,
                          'Configuration file with parameters (.gin).')
flags.DEFINE_multi_string('config', None,
                          'Configuration parameters (gin string).')

# TPU flags.
flags.DEFINE_bool('use_tpu', False, "Whether we're running on TPU.")
flags.DEFINE_string(
    'jax_xla_backend', 'xla',
    'Either "xla" for the XLA service directly, or "tpu_driver" '
    'for a TPU Driver backend.')
flags.DEFINE_string(
    'jax_backend_target', 'local',
    'Either "local" or "rpc:address" to connect to a '
    'remote service target.')

# trainer.py flags.
from absl import flags

'''
flags.DEFINE_multi_string("gin_file", "dataset.gin", "Path to a Gin file.")
flags.DEFINE_multi_string(
    "gin_file", "gs://t5-data/pretrained_models/small/operative_config.gin",
    "Path to a Gin file.")
flags.DEFINE_multi_string("gin_param",
                          "utils.tpu_mesh_shape.model_parallelism = 1",
                          "Gin parameter binding.")
flags.DEFINE_multi_string("gin_param",
                          "utils.tpu_mesh_shape.tpu_topology = '2x2'",
                          "Gin parameter binding.")
flags.DEFINE_multi_string("gin_param", "MIXTURE_NAME = 'glue_mrpc_v002'",
                          "Gin parameter binding.")
flags.DEFINE_list("gin_location_prefix", [], "Gin file search path.")
'''

flags.DEFINE_string(
    "tpu_job_name", None,
    "Name of TPU worker binary. Only necessary if job name is changed from "
    "default tpu_worker.")
flags.DEFINE_string("model_dir", "transformer_standalone",
                    "Estimator model_dir")
flags.DEFINE_string(
    "tpu", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url."
)
flags.DEFINE_string(
    "gcp_project", None,
    "Project name for the Cloud TPU-enabled project. If not specified, we "
    "will attempt to automatically detect the GCE project from metadata.")
flags.DEFINE_string(
    "tpu_zone", None,
    "GCE zone where the Cloud TPU is located in. If not specified, we "
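# Aside: the commented-out block above defines `gin_file` and `gin_param`
# more than once; executing it would raise flags.DuplicateFlagError. A multi
# flag is defined a single time and then repeated at invocation time. A
# self-contained sketch (the script name and bindings are illustrative):
from absl import app, flags

flags.DEFINE_multi_string("gin_param", [],
                          "Gin parameter binding; may be repeated.")
FLAGS = flags.FLAGS


def main(_):
  # Invoked e.g. as:
  #   python run.py \
  #     --gin_param="utils.tpu_mesh_shape.model_parallelism = 1" \
  #     --gin_param="MIXTURE_NAME = 'glue_mrpc_v002'"
  for binding in FLAGS.gin_param:
    print(binding)


if __name__ == "__main__":
  app.run(main)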
FLAGS = flags.FLAGS

flags.DEFINE_string('clust_spec_csv_path', '../sandbox/clust_spec_example.csv',
                    'cluster specification csv file path.')
flags.DEFINE_string('remote_working_folder', '/root/code',
                    'unified remote working folder across all machines.')
flags.DEFINE_boolean('force_overwrite_remote', True,
                     'Should force overwriting remote working folder?')
flags.DEFINE_string(
    'remote_worker_pre_cmd', '',
    'pre command for each REMOTE worker, '
    'e.g., activate python virtual env.')
flags.DEFINE_multi_string(
    'prepare_python_package_path', [],
    'pip installable package path in comma separated format: '
    'package_local_fullpath,remote_fullpath '
    'It can occur multiple times. E.g.,\n'
    '-p /root/Arena,/root/Arena\n'
    '-p /home/me/pysc2,/home/work/pysc2\n'
    '...',
    short_name='p')
flags.DEFINE_integer('n_process', 1,
                     'number of parallel processes for connections.')


def _parse_packages(args):

  def _known_package_ext_whitelist(package_basename):
    if package_basename.lower() == 'pysc2':  # Tencent PySC2 Extension
      return ['.py', '.serialized', '.SC2Map', '.md']
    if package_basename.lower() == 'arena':
      return ['.py', '.md', '.wad', '.cfg']
    return ['.py', '.md', '.csv']  # default to these typical files
    'For more information see train_util.get_strategy().')
flags.DEFINE_boolean(
    'allow_memory_growth', False,
    'Whether to grow the GPU memory usage as is needed by the '
    'process. Prevents crashes on GPUs with smaller memory.')
flags.DEFINE_boolean(
    'hypertune', False,
    'Enable metric reporting for hyperparameter tuning, such '
    'as on Google Cloud AI-Platform.')
flags.DEFINE_float(
    'early_stop_loss_value', None,
    'Stops training early when the `total_loss` reaches below '
    'this value during training.')

# Gin config flags.
flags.DEFINE_multi_string('gin_search_path', [],
                          'Additional gin file search paths.')
flags.DEFINE_multi_string(
    'gin_file', [],
    'List of paths to the config files. If a file is in a gstorage bucket, '
    'specify the whole gstorage path: '
    'gs://bucket-name/dir/in/bucket/file.gin.')
flags.DEFINE_multi_string('gin_param', [],
                          'Newline separated list of Gin parameter bindings.')

# Evaluation/sampling specific flags.
flags.DEFINE_boolean('run_once', False, 'Whether evaluation will run once.')
flags.DEFINE_integer('initial_delay_secs', None,
                     'Time to wait before evaluation starts.')

GIN_PATH = pkg_resources.resource_filename(__name__, 'gin')
flags.DEFINE_enum(
    'inference_model', 'streaming_f0_pw',
    [
        'autoencoder',
        'streaming_f0_pw',
        'vst_extract_features',
        'vst_predict_controls',
        'vst_synthesize',
    ],
    'Specify the ddsp.training.inference model to use for '
    'converting a checkpoint to a SavedModel. Names are '
    'snake_case versions of class names.')

# Optional flags.
flags.DEFINE_multi_string('gin_param', [],
                          'Gin parameters for custom inference model kwargs.')
flags.DEFINE_boolean('debug', False, 'DEBUG: Do not save the model.')

# Conversion formats.
flags.DEFINE_boolean('tfjs', True,
                     'Convert SavedModel to TFJS for deploying on the web.')
flags.DEFINE_boolean('tflite', True,
                     'Convert SavedModel to TFLite for embedded C++ apps.')
flags.DEFINE_string('metadata_file', None,
                    'Optional metadata file to pack into TFLite model.')

FLAGS = flags.FLAGS


def get_inference_model(ckpt):
  """Restore model from checkpoint using global FLAGS.
def test_write_help_in_xmlformat(self):
  fv = flags.FlagValues()
  # Since these flags are defined by the top module, they are all key.
  flags.DEFINE_integer('index', 17, 'An integer flag', flag_values=fv)
  flags.DEFINE_integer('nb_iters', 17, 'An integer flag',
                       lower_bound=5, upper_bound=27, flag_values=fv)
  flags.DEFINE_string('file_path', '/path/to/my/dir', 'A test string flag.',
                      flag_values=fv)
  flags.DEFINE_boolean('use_gpu', False, 'Use gpu for performance.',
                       flag_values=fv)
  flags.DEFINE_enum('cc_version', 'stable', ['stable', 'experimental'],
                    'Compiler version to use.', flag_values=fv)
  flags.DEFINE_list('files', 'a.cc,a.h,archive/old.zip', 'Files to process.',
                    flag_values=fv)
  flags.DEFINE_list('allow_users', ['alice', 'bob'], 'Users with access.',
                    flag_values=fv)
  flags.DEFINE_spaceseplist('dirs', 'src libs bins', 'Directories to create.',
                            flag_values=fv)
  flags.DEFINE_multi_string('to_delete', ['a.cc', 'b.h'], 'Files to delete',
                            flag_values=fv)
  flags.DEFINE_multi_integer('cols', [5, 7, 23], 'Columns to select',
                             flag_values=fv)
  flags.DEFINE_multi_enum('flavours', ['APPLE', 'BANANA'],
                          ['APPLE', 'BANANA', 'CHERRY'],
                          'Compilation flavour.', flag_values=fv)
  # Define a few flags in a different module.
  module_bar.define_flags(flag_values=fv)
  # And declare only a few of them to be key. This way, we have
  # different kinds of flags, defined in different modules, and not
  # all of them are key flags.
  flags.declare_key_flag('tmod_bar_z', flag_values=fv)
  flags.declare_key_flag('tmod_bar_u', flag_values=fv)

  # Generate flag help in XML format in the StringIO sio.
  sio = io.StringIO() if six.PY3 else io.BytesIO()
  fv.write_help_in_xml_format(sio)

  # Check that we got the expected result.
  expected_output_template = EXPECTED_HELP_XML_START
  main_module_name = sys.argv[0]
  module_bar_name = module_bar.__name__

  if main_module_name < module_bar_name:
    expected_output_template += EXPECTED_HELP_XML_FOR_FLAGS_FROM_MAIN_MODULE
    expected_output_template += EXPECTED_HELP_XML_FOR_FLAGS_FROM_MODULE_BAR
  else:
    expected_output_template += EXPECTED_HELP_XML_FOR_FLAGS_FROM_MODULE_BAR
    expected_output_template += EXPECTED_HELP_XML_FOR_FLAGS_FROM_MAIN_MODULE

  expected_output_template += EXPECTED_HELP_XML_END

  # XML representation of the whitespace list separators.
  whitespace_separators = _list_separators_in_xmlformat(
      string.whitespace, indent=' ')
  expected_output = (
      expected_output_template % {
          'basename_of_argv0': os.path.basename(sys.argv[0]),
          'usage_doc': sys.modules['__main__'].__doc__,
          'main_module_name': main_module_name,
          'module_bar_name': module_bar_name,
          'whitespace_separators': whitespace_separators
      })

  actual_output = sio.getvalue()
  self.assertMultiLineEqual(expected_output, actual_output)

  # Also check that our result is valid XML. minidom.parseString
  # throws an xml.parsers.expat.ExpatError in case of an error.
  xml.dom.minidom.parseString(actual_output)
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for asserting against golden files."""

import contextlib
import difflib
import io
import os.path
import re
import sys
import traceback
from typing import Dict, Optional

from absl import flags

flags.DEFINE_multi_string('golden', [], 'List of golden files available.')
flags.DEFINE_bool('update_goldens', False, 'Set true to update golden files.')
flags.DEFINE_bool('verbose', False, 'Set true to show golden diff output.')

FLAGS = flags.FLAGS


class MismatchedGoldenError(RuntimeError):
  pass


_filename_to_golden_map: Optional[Dict[str, str]] = None


def _filename_to_golden_path(filename: str) -> str:
  """Retrieve the `--golden` path flag for a golden file named `filename`."""
############################################################################
# Finally, write commands, script, and results to disk
############################################################################

# For reproducibility, copy command and script contents to results.
if results_dir not in ('.',):
  cmd = 'python ' + ' '.join(sys.argv)
  with open(os.path.join(results_dir, 'command.sh'), 'w') as f:
    f.write(cmd)
  with open(__file__, 'r') as f:
    this_script = f.readlines()
  with open(os.path.join(results_dir, __file__), 'w') as f:
    f.write(''.join(this_script))

results_filename = os.path.join(results_dir, 'results.p')
with open(results_filename, 'wb') as f:
  pickle.dump(results, f)

# Finally, write gin config to disk.
with open(os.path.join(results_dir, 'config.gin'), 'w') as f:
  f.write(gin.operative_config_str())


if __name__ == "__main__":
  FLAGS = flags.FLAGS
  flags.DEFINE_string('gin_file', './config/multi_step_simulation.gin',
                      'Config file path.')
  flags.DEFINE_multi_string(
      'gin_param', None, 'Newline separated list of Gin parameter bindings.')
  app.run(main)
from tf_agents.metrics import tf_metrics
from tf_agents.metrics import batched_py_metric
from tf_agents.networks import actor_distribution_network
from tf_agents.networks import normal_projection_network
from tf_agents.networks.utils import mlp_layers
from tf_agents.policies import greedy_policy
from tf_agents.policies import py_tf_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.utils import common
from tf_agents.utils import episode_utils

from IPython import embed

flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
                    'Root directory for writing logs/summaries/checkpoints.')
flags.DEFINE_multi_string('gin_file', None, 'Path to the gin config files.')
flags.DEFINE_multi_string('gin_param', None, 'Gin binding to pass through.')
flags.DEFINE_integer('num_iterations', 1000000,
                     'Total number of train/eval iterations to perform.')
flags.DEFINE_integer(
    'initial_collect_steps', 1000,
    'Number of steps to collect at the beginning of training using a random '
    'policy.')
flags.DEFINE_integer(
    'collect_steps_per_iteration', 1,
    'Number of steps to collect and add to the replay buffer after every '
    'training iteration.')
flags.DEFINE_integer('num_parallel_environments', 1,
                     'Number of environments to run in parallel.')
flags.DEFINE_integer('num_parallel_environments_eval', 1,
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags

from dopamine.discrete_domains import run_experiment
import tensorflow.compat.v1 as tf

flags.DEFINE_string('base_dir', None,
                    'Base directory to host all required sub-directories.')
flags.DEFINE_integer('level', 1, 'Game level to start at.')
flags.DEFINE_multi_string(
    'gin_files', [],
    'List of paths to gin configuration files (e.g. '
    '"dopamine/agents/dqn/dqn.gin").')
flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files '
    '(e.g. "DQNAgent.epsilon_train=0.1", '
    '"create_environment.game_name=\'Pong\'").')

FLAGS = flags.FLAGS


def main(unused_argv):
  """Main method.

  Args:
    unused_argv: Arguments (unused).
import gin
import gym
import gym_minigrid  # pylint: disable=unused-import
from gym_minigrid.wrappers import RGBImgObsWrapper
import matplotlib.pylab as plt
import tensorflow as tf

from minigrid_basics.custom_wrappers import tabular_wrapper  # pylint: disable=unused-import
from minigrid_basics.envs import mon_minigrid

FLAGS = flags.FLAGS

flags.DEFINE_string('file_path', '/tmp/rw_four_directions',
                    'Path in which we will save the observations.')
flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override default parameter values '
    '(e.g. "MonMiniGridEnv.stochasticity=0.1").')


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  gin.parse_config_files_and_bindings(
      [os.path.join(mon_minigrid.GIN_FILES_PREFIX, 'classic_fourrooms.gin')],
      bindings=FLAGS.gin_bindings,
      skip_unknown=False)
  env_id = mon_minigrid.register_environment()
  env = gym.make(env_id)
  env = RGBImgObsWrapper(env)  # Get pixel observations.
  # Get tabular observation and drop the 'mission' field:
import numpy as np
import relabelling_replay_buffer
import tensorflow as tf
from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.sac import sac_agent
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import greedy_policy
from tf_agents.policies import random_tf_policy
from tf_agents.utils import common
import utils

flags.DEFINE_string("root_dir", None,
                    "Root directory for writing logs/summaries/checkpoints.")
flags.DEFINE_multi_string("gin_file", None,
                          "Path to the trainer config files.")
flags.DEFINE_multi_string("gin_bindings", None, "Gin binding to pass through.")

FLAGS = flags.FLAGS


@gin.configurable
def train_eval(
    root_dir,
    env_name="HalfCheetah-v2",
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
from absl import flags
from absl import logging
import gin
from tensor2tensor.jax import j2j
import tensorflow as tf

FLAGS = flags.FLAGS

flags.DEFINE_string("dataset", None, "Which dataset to use.")
flags.DEFINE_string("model", None, "Which model to train.")
flags.DEFINE_string("data_dir", None, "Path to the directory with data.")
flags.DEFINE_string("output_dir", None,
                    "Path to the directory to save logs and checkpoints.")
flags.DEFINE_multi_string("config_file", None,
                          "Configuration file with parameters (.gin).")
flags.DEFINE_multi_string("config", None,
                          "Configuration parameters (gin string).")

# For iterators over datasets so we can do "for example in dataset".
tf.enable_v2_behavior()


def j2j_train(model_name, dataset_name, data_dir=None, output_dir=None,
              config_file=None, config=None):
  """Main function to train the given model on the given dataset.
def define_flags():
  """Defines flags."""
  flags.DEFINE_string(
      'experiment', default=None, help='The experiment type registered.')
  flags.DEFINE_enum(
      'mode',
      default=None,
      enum_values=[
          'train', 'eval', 'train_and_eval', 'continuous_eval',
          'continuous_train_and_eval'
      ],
      help='Mode to run: `train`, `eval`, `train_and_eval`, '
      '`continuous_eval`, and `continuous_train_and_eval`.')
  flags.DEFINE_string(
      'model_dir',
      default=None,
      help='The directory where the model and training/evaluation summaries '
      'are stored.')
  flags.DEFINE_multi_string(
      'config_file',
      default=None,
      help='YAML/JSON files which specify overrides. The override order '
      'follows the order of args. Note that each file '
      'can be used as an override template to override the default parameters '
      'specified in Python. If the same parameter is specified in both '
      '`--config_file` and `--params_override`, `config_file` will be used '
      'first, followed by params_override.')
  flags.DEFINE_string(
      'params_override',
      default=None,
      help='A YAML/JSON string or a YAML file which specifies additional '
      'overrides over the default parameters and those specified in '
      '`--config_file`. Note that this is supposed to be used only to '
      'override the model parameters, but not the parameters like TPU '
      'specific flags. One canonical use case of `--config_file` and '
      '`--params_override` is that users first define a template config file '
      'using `--config_file`, then use `--params_override` to adjust the '
      'minimal set of tuning parameters, for example setting up different '
      '`train_batch_size`. The final override order of parameters: '
      'default_model_params --> params from config_file --> params in '
      'params_override. See also the help message of `--config_file`.')
  flags.DEFINE_multi_string(
      'gin_file', default=None, help='List of paths to the config files.')
  flags.DEFINE_multi_string(
      'gin_params',
      default=None,
      help='Newline separated list of Gin parameter bindings.')
  flags.DEFINE_string(
      'tpu',
      default=None,
      help='The Cloud TPU to use for training. This should be either the name '
      'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
      'url.')
  flags.DEFINE_string(
      'tf_data_service', default=None, help='The tf.data service address.')
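# As an illustration of the override order spelled out in the help strings
# above; the experiment name, config paths, and override key below are
# hypothetical, not taken from the snippet:
#
#   python train.py \
#     --experiment=my_experiment \
#     --mode=train \
#     --model_dir=/tmp/model \
#     --config_file=configs/base.yaml \
#     --config_file=configs/tpu.yaml \
#     --params_override='task.train_data.global_batch_size=256'
#
# YAML templates are applied in argument order, then params_override wins for
# any keys it repeats: default_model_params --> params from config_file(s)
# --> params in params_override.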
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import pickle

from absl import app
from absl import flags

from tleague.model_pools.model_pool_apis import ModelPoolAPIs
from tleague.model_pools.model import Model

FLAGS = flags.FLAGS

flags.DEFINE_string("model_pool_addrs", "localhost:10003:10004",
                    "Model Pool address.")
flags.DEFINE_multi_string("model_path", [], "model file path")
flags.DEFINE_multi_string("model_key", [], "model_keys")


def main(_):
  model_pool_apis = ModelPoolAPIs(FLAGS.model_pool_addrs.split(','))
  keys = model_pool_apis.pull_keys()
  for key, model_path in zip(FLAGS.model_key, FLAGS.model_path):
    if key in keys:
      m = model_pool_apis.pull_model(key)
      with open(model_path, 'rb') as f:
        model = pickle.load(f)
      if isinstance(model, Model):
        model = model.model
      model_pool_apis.push_model(model, m.hyperparam, m.key, m.createtime,
                                 m.freezetime, m.updatetime)
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Common flags."""

from absl import flags

flags.DEFINE_integer('trial_id', 0, 'The trial ID from 0 to num_trials-1.')
flags.DEFINE_integer('max_episode_len', 1000, 'Number of steps in an episode.')
flags.DEFINE_string('env_name', 'cartpole-swingup', 'Name of the environment.')
flags.DEFINE_string('root_dir', None,
                    'Path to output trajectories from data collection.')
flags.DEFINE_multi_string('gin_files', None, 'Paths to the gin-config files.')
flags.DEFINE_multi_string('gin_bindings', None, 'Gin binding parameters.')
flags.DEFINE_integer('seed', None, 'Random seed for model_dir/data_dir.')
common_tpu_flags.define_common_tpu_flags()
common_hparams_flags.define_common_hparams_flags()

FLAGS = flags.FLAGS

FAKE_DATA_DIR = 'gs://cloud-tpu-test-datasets/fake_imagenet'

flags.DEFINE_string(
    'hparams_file',
    default=None,
    help='Set of model parameters to override the default hparams.')
flags.DEFINE_multi_string(
    'hparams',
    default=None,
    help='This is used to override only the model hyperparameters. It should '
    'not be used to override the other parameters like the tpu specific '
    'flags etc. For example, if experimenting with larger numbers of '
    'train_steps, a possible value is --hparams=train_steps=28152.')
flags.DEFINE_string(
    'default_hparams_file',
    default=os.path.join(os.path.dirname(__file__), './configs/default.yaml'),
    help='Default set of model parameters to use with this model. Look at '
    'configs/default.yaml for this.')
flags.DEFINE_integer(
    'resnet_depth',
    default=None,
    help='Depth of ResNet model to use. Must be one of {18, 34, 50, 101, 152,'
from six.moves import range
import tensorflow as tf
from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.sac import tanh_normal_projection_network
from tf_agents.drivers import dynamic_step_driver
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import greedy_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.utils import common

flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
                    'Root directory for writing logs/summaries/checkpoints.')
flags.DEFINE_multi_string('gin_file', None,
                          'Path to the trainer config files.')
flags.DEFINE_multi_string('gin_bindings', None, 'Gin binding to pass through.')

FLAGS = flags.FLAGS


@gin.configurable
def bce_loss(y_true, y_pred, label_smoothing=0):
  loss_fn = tf.keras.losses.BinaryCrossentropy(
      label_smoothing=label_smoothing,
      reduction=tf.keras.losses.Reduction.NONE)
  return loss_fn(y_true[:, None], y_pred[:, None])


@gin.configurable
class ClassifierCriticNetwork(critic_network.CriticNetwork):
  """Creates a critic network."""
  params = utils.get_gin_params_as_dict(gin.config._CONFIG)
  neptune.init(project_qualified_name="melindafkiss/sandbox")
  exp = neptune.create_experiment(params=params, name="exp")
  # ONLY WORKS FOR ONE GIN-CONFIG FILE.
  with open(FLAGS.gin_file[0]) as ginf:
    param = ginf.readline()
    while param:
      param = (param.replace('.', '-').replace('=', '-').replace(' ', '')
               .replace('\'', '').replace('\n', '').replace('@', ''))
      # neptune.append_tag(param)
      param = ginf.readline()
  # for tag in opts['tags'].split(','):
  #   neptune.append_tag(tag)
else:
  neptune.init('shared/onboarding',
               api_token='ANONYMOUS',
               backend=neptune.OfflineBackend())

er = ExperimentRunner(prefix=exp.id)
er.train()

params = utils.get_gin_params_as_dict(gin.config._OPERATIVE_CONFIG)
for k, v in params.items():
  neptune.set_property(k, v)
neptune.stop()
print('fin')


if __name__ == '__main__':
  flags.DEFINE_multi_string('gin_file', None,
                            'List of paths to the config files.')
  flags.DEFINE_multi_string(
      'gin_param', None,
      'Newline separated list of Gin parameter bindings.')
  FLAGS = flags.FLAGS
  app.run(main)