def test_operative_config_str(self):
    """Check which bindings show up in the operative config string.

    Inside the ``GinState()`` block only ``g`` has been called, so the
    operative config must mention ``g.z`` and must not mention ``f.x``.
    After the block, ``f`` is called and the expectations flip.
    NOTE(review): this relies on ``GinState`` isolating gin's state for the
    duration of the ``with`` block -- confirm against its definition.
    """
    with GinState():
        g()
        op_config_string = gin.operative_config_str()
        self.assertNotIn('f.x', op_config_string)
        self.assertIn('g.z', op_config_string)

    f()
    op_config_string = gin.operative_config_str()
    self.assertIn('f.x', op_config_string)
    self.assertNotIn('g.z', op_config_string)
def gin_log_config(log_dir, slug):
    """Best-effort dump of the operative gin config to ``<log_dir>/<slug>.gin``.

    Nothing is written when ``log_dir`` is ``None`` or is not an existing
    directory. If writing fails, the config is logged at INFO level together
    with a warning instead of propagating the error.

    Args:
        log_dir: Directory to write the config into, or ``None`` to skip.
        slug: Base name (without extension) of the written file.
    """
    try:
        if log_dir is not None and tf.io.gfile.isdir(log_dir):
            path = os.path.join(log_dir, "%s.gin" % slug)
            with tf.io.gfile.GFile(path, "w") as fp:
                fp.write(gin.operative_config_str())
    except Exception:
        # Deliberately broad: failing to persist the config must not abort the
        # run. Unlike a bare ``except:``, KeyboardInterrupt/SystemExit still
        # propagate.
        logging.info(gin.operative_config_str())
        logging.warning("Failed to write log gin config")
def log_gin_config(output_dir, cometml_experiment=None, wandb_run=None):
    """Print, save and optionally upload the operative gin config.

    The config is printed, written verbatim to ``<output_dir>/config.gin``,
    converted to a dictionary and written to ``<output_dir>/config.gin.json``.
    The dictionary is then forwarded to the optional experiment trackers.

    Args:
        output_dir: Existing directory that receives both config files.
        cometml_experiment: Optional object exposing ``log_multiple_params``.
        wandb_run: Optional object exposing ``config.update``.
    """
    gin_config_str = gin.operative_config_str()

    print("Used config: " + "-" * 40)
    print(gin_config_str)
    print("-" * 52)
    with open(os.path.join(output_dir, "config.gin"), "w") as f:
        f.write(gin_config_str)

    # Parse the gin config string into a dictionary: drop the import lines,
    # then rewrite "name = value" bindings as "name: value" YAML mappings.
    yaml_source = "\n".join(
        line for line in gin_config_str.split("\n")
        if not line.startswith("import"))
    yaml_source = (
        yaml_source.replace("@", "").replace(" = %", ": ").replace(" = ", ": "))
    # safe_load: calling yaml.load() without an explicit Loader is rejected by
    # PyYAML >= 6 and may construct arbitrary objects on older releases.
    # An empty document parses to None, so fall back to an empty mapping.
    gin_config_dict = yaml.safe_load(yaml_source) or {}
    write_json(gin_config_dict,
               os.path.join(output_dir, "config.gin.json"),
               sort_keys=True,
               indent=2)

    if cometml_experiment is not None:
        cometml_experiment.log_multiple_params(gin_config_dict)

    if wandb_run is not None:
        # Replacing "." by "/" allows the values to be displayed on the dashboard
        wandb_run.config.update(
            {k.replace(".", "/"): v for k, v in gin_config_dict.items()})
def main(experiment_name: str):
    """Train the problem under an MLFlow run and persist config and models.

    Results live under ``$SCRATCH/<experiment_name>/<run_id>`` (``./`` when
    ``SCRATCH`` is unset): the operative gin config as ``full_config.gin`` and
    every trained model as ``models/model-<index>``.

    Args:
        experiment_name: MLFlow experiment name, also used as a directory name.
    """
    print("Using remote MLFlow server")
    lib_mlflow.setup()

    print(f"Using experiment name {experiment_name}")
    lib_mlflow.try_until_success(mlflow.set_experiment, experiment_name)
    active_run = lib_mlflow.try_until_success(mlflow.start_run)
    run_id = active_run.info.run_id
    lib_mlflow.log_param("run_id", active_run.info.run_id)

    scratch_root = Path(os.environ.get("SCRATCH", "./"))
    base_dir = scratch_root / experiment_name / run_id
    results, models = lib_problem.Problem(base_dir=base_dir).train()

    # The operative config is read after training, so it also carries the
    # default parameters of everything that was called.
    operative_config = gin.operative_config_str()
    lib_mlflow.log_params(lib_analysis._parse_gin_config(operative_config))
    (base_dir / "full_config.gin").write_text(operative_config)

    print("Saving models")
    models_dir = base_dir / "models"
    models_dir.mkdir(parents=True, exist_ok=True)
    for index, trained_model in enumerate(models):
        target = models_dir / f"model-{index}"
        # Start from a clean directory for every model.
        if target.exists():
            shutil.rmtree(target)
        target.mkdir()
        tf.keras.models.save_model(trained_model, str(target))
    results["model_base_save_path"] = models_dir
    print(results)

    lib_mlflow.shutdown()
    lib_mlflow.try_until_success(mlflow.end_run)
def main():
    """Fit a classification model and print the operative gin config."""
    parse_args()

    train_split, test_split, class_count = core_gin.get_datasets()
    # The first entry of the element spec is used as the model input spec.
    features_spec = train_split.element_spec[0]
    classifier = core_gin.get_classification_model(features_spec, class_count)
    core_gin.fit(classifier, train_split, test_split)

    print(gin.operative_config_str())
def __init__(self, model, train_dataset, eval_dataset, num_epoch, batch_size, save_dir, snapshot=None):
    """Build the input pipeline, session and model, then dump the gin config.

    Args:
        model: Model exposing ``dim``, ``build``, ``setup_optimizer``, ``load``.
        train_dataset: TFRecord source used for training.
        eval_dataset: TFRecord source used for validation.
        num_epoch: Stored as ``self.num_epoch``.
        batch_size: Batch size for both datasets.
        save_dir: Directory for the ``tfboard`` summaries and ``0.gin``.
        snapshot: Optional snapshot handed to ``model.load``.
    """
    # Training and validation datasets share the same record parser.
    parse_record = functools.partial(data_parser, dim=model.dim)
    train_data = tf.data.TFRecordDataset(train_dataset).map(parse_record)
    train_data = train_data.shuffle(10000).batch(batch_size)
    eval_data = tf.data.TFRecordDataset(eval_dataset).map(parse_record)
    eval_data = eval_data.batch(batch_size)

    # One iterator defined by structure; the two init ops below point it at
    # either dataset.
    shared_iterator = tf.data.Iterator.from_structure(
        train_data.output_types, train_data.output_shapes)
    self.start, self.action, self.result = shared_iterator.get_next()
    self.training_init_op = shared_iterator.make_initializer(train_data)
    self.validation_init_op = shared_iterator.make_initializer(eval_data)

    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=16,
        intra_op_parallelism_threads=16)
    session_config.gpu_options.allow_growth = True
    self.sess = tf.Session(config=session_config)

    self.num_epoch = num_epoch
    self.model = model
    self.model.build(input=self.start, action=self.action)
    self.model.setup_optimizer(0.001, self.result)
    self.global_step = 0
    self.save_dir = save_dir

    board_dir = os.path.join(save_dir, 'tfboard')
    self.train_writer = tf.summary.FileWriter(board_dir, self.sess.graph)

    self.sess.run(tf.global_variables_initializer())
    if snapshot is not None:
        self.model.load(self.sess, snapshot)

    # Record the configuration that was in effect while building the trainer.
    operative_config = gin.operative_config_str()
    with open(os.path.join(save_dir, '0.gin'), 'w') as gin_file:
        gin_file.write(operative_config)
def on_train_begin(self, logs=None):
    """Snapshot the launching script, run metadata and gin configs.

    Copies ``sys.argv[0]`` into ``self.save_path``, writes ``meta.json``,
    replaces previously saved ``*gin`` files with the configs listed in
    ``sys.argv[2]`` (``;``-separated) and writes the operative config to
    ``config.gin``.

    Args:
        logs: Unused.

    Raises:
        AssertionError: If copying the source script fails.
    """
    logger.info(
        "Saving meta data information from the beginning of training")

    # Run the copy outside of an ``assert`` statement: asserts are stripped
    # under ``python -O``, which would silently skip the copy itself.
    copy_status = os.system("cp {} {}".format(sys.argv[0], self.save_path))
    if copy_status != 0:
        raise AssertionError("Failed to execute cp of source script")

    utc_date = datetime.datetime.utcnow().strftime("%Y_%m_%d")
    time_start = time.time()
    cmd = "python " + " ".join(sys.argv)
    self.meta = {
        "cmd": cmd,
        "save_path": self.save_path,
        "most_recent_train_start_date": utc_date,
        # Stored negated; presumably the end time is added later to obtain
        # the elapsed time -- TODO confirm against the end-of-training hook.
        "execution_time": -time_start
    }

    # Context manager so the handle is flushed and closed deterministically.
    with open(os.path.join(self.save_path, "meta.json"), "w") as meta_file:
        json.dump(self.meta, meta_file, indent=4)

    # Copy gin configs used, for reference, to the save folder
    os.system("rm " + os.path.join(self.save_path, "*gin"))
    # NOTE(review): every config is copied to the same base_config.gin, so
    # only the last entry of sys.argv[2] survives -- confirm this is intended.
    for gin_config in sys.argv[2].split(";"):
        os.system("cp {} {}/base_config.gin".format(
            gin_config, self.save_path))

    with open(os.path.join(self.save_path, "config.gin"), "w") as f:
        f.write(gin.operative_config_str())
def main(argv):
    """Visualize the curiosity reward and save flags and gin config.

    Args:
        argv: Command-line arguments; anything beyond the program name is
            rejected.

    Raises:
        app.UsageError: If more than one argument is passed.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    xm.setup_work_unit()

    workdir = FLAGS.workdir
    if not gfile.Exists(workdir):
        gfile.MakeDirs(workdir)
    utils.dump_flags_to_file(os.path.join(workdir, 'flags.txt'))

    # These bindings are set before the user-provided ones are parsed.
    gin.bind_parameter('CuriosityEnvWrapper.scale_task_reward', 0.)
    gin.bind_parameter('CuriosityEnvWrapper.scale_surrogate_reward', 1.)
    gin.bind_parameter('AntWrapper.enable_die_condition',
                       FLAGS.ant_env_enable_die_condition)
    gin.parse_config_files_and_bindings(None, FLAGS.gin_bindings)

    # Hardware crashes with:
    # Failed to open library!
    # dlopen: cannot load any more object with static TLS
    FLAGS.renderer = 'software'

    if FLAGS.xm_xid != -1:
        work_unit = xmanager_api.XManagerApi().get_current_work_unit()
    else:
        work_unit = None
    visualize_curiosity_reward(work_unit)

    config_file = os.path.join(FLAGS.workdir, 'gin_config.txt')
    with gfile.GFile(config_file, 'w') as handle:
        handle.write(gin.operative_config_str())
def train(device, save, epochs, episodes, seed, task):
    """Train a ``model.Learner`` on generated transitions, keeping the best.

    Args:
        device: Device the learner and the batches are placed on.
        save: Destination passed to ``torch.save`` for the best checkpoint.
        epochs: Number of epochs to run.
        episodes: Number of episodes generated from the environment.
        seed: Seed handed to the environment constructor.
        task: Environment key; only ``'keytask'`` is registered here.
    """
    env_cls = dict(keytask=KeyTask)[task]
    env = env_cls(seed=seed, max_steps=100)
    replay = gen.gen_env(env, n_episodes=episodes)
    transitions = gen.Transitions(replay)
    loader = data.DataLoader(transitions, batch_size=512, shuffle=True)

    learner = model.Learner(device, action_dim=len(env.action_space))
    optimizer = optim.Adam(learner.parameters(), lr=0.001)

    L.info('Starting training for %s epochs.', epochs)
    L.info('Best model will be saved to %s.', save)

    with trange(1, epochs + 1) as progress:
        best_loss = 1e9
        for _epoch in progress:
            running_loss = 0
            for raw_batch in tqdm(loader):
                # Unpacking into four names also checks the batch arity.
                obs, action, reward, next_obs = on_device = [
                    tensor.to(device) for tensor in raw_batch
                ]
                optimizer.zero_grad()
                loss = learner.loss(*on_device)
                loss.backward()
                running_loss += loss.item()
                optimizer.step()

            # Summed batch losses divided by the number of dataset items.
            avg_loss = running_loss / len(loader.dataset)
            if avg_loss < best_loss:
                best_loss = avg_loss
                checkpoint = dict(model=learner.state_dict(),
                                  config=gin.operative_config_str())
                torch.save(checkpoint, save)
            progress.set_postfix(best=best_loss, avg=avg_loss)

    L.info('Done training. Best loss: %s', best_loss)
def j2j_train(model_name, dataset_name, data_dir=None, output_dir=None, config_file=None, config=None):
    """Main function to train the given model on the given dataset.

    Args:
        model_name: The name of the model to train.
        dataset_name: The name of the dataset to train on.
        data_dir: Directory where the data is located.
        output_dir: Directory where to put the logs and checkpoints.
        config_file: the gin configuration file to use.
        config: list of gin binding strings overriding gin parameters, or
            ``None``. The list passed in is never modified.
    """
    gin.bind_parameter("train_fn.dataset", dataset_name)
    if FLAGS.model:
        # Build a new list instead of ``+=`` so the caller's list is not
        # extended in place.
        model_binding = ["train_fn.model=@models." + model_name]
        config = ([] if config is None else config) + model_binding
    gin.parse_config_files_and_bindings(config_file, config)

    if output_dir:
        if not tf.gfile.Exists(output_dir):
            tf.gfile.MkDir(output_dir)
        config_path = os.path.join(output_dir, "gin.config")
        # TODO(lukaszkaiser): why is the file empty if there's no provided config?
        with tf.gfile.Open(config_path, "w") as f:
            f.write(gin.operative_config_str())

    j2j.train_fn(data_dir, output_dir=output_dir)
def get_gin_confg_strs():
    """Return markdown-formatted operative and inoperative gin configs.

    The operative string comes from ``gin.operative_config_str()`` and the
    inoperative one from ``gin_utils.inoperative_config_str()``. The
    inoperative string is only passed through the markdown formatter when it
    is non-empty; otherwise it is returned unchanged.

    Returns:
        md_operative_config_str (str): a markdown-formatted operative str
        md_inoperative_config_str (str): a markdown-formatted inoperative str
    """
    operative_md = _markdownify_gin_config_str(
        gin.operative_config_str(),
        'All parameter values used by configurable functions that are actually called'
    )

    inoperative = gin_utils.inoperative_config_str()
    if not inoperative:
        return operative_md, inoperative

    inoperative_md = _markdownify_gin_config_str(
        inoperative,
        "All parameter values configured but not used by program. The configured "
        "functions are either not called or called with explicit parameter values "
        "overriding the config.")
    return operative_md, inoperative_md
def log_gin_config(output_dir, cometml_experiment=None, wandb_run=None, prefix=''):
    """Print, save and upload the operative gin config.

    Writes ``<prefix>config.gin`` and ``<prefix>config.gin.json`` into
    ``output_dir`` and forwards the dictionary form to cometml and wandb when
    those handles are given.
    """
    operative = gin.operative_config_str()

    print("Used config: " + "-" * 40)
    print(operative)
    print("-" * 52)

    gin_file = os.path.join(output_dir, f"{prefix}config.gin")
    with open(gin_file, "w") as handle:
        handle.write(operative)

    as_dict = gin2dict(operative)
    json_file = os.path.join(output_dir, f"{prefix}config.gin.json")
    write_json(as_dict, json_file, sort_keys=True, indent=2)

    if cometml_experiment is not None:
        cometml_experiment.log_parameters(as_dict)

    if wandb_run is not None:
        # Keys use "/" instead of "." so they can be shown on the dashboard.
        slash_keyed = {}
        for key, value in as_dict.items():
            slash_keyed[key.replace(".", "/")] = value
        wandb_run.config.update(slash_keyed)
def main(args):
    """Run retrieval-based pose prediction for the configured dataset.

    Args:
        args: Parsed arguments providing ``gpu_id``, ``mode``, ``dataset``,
            ``cmu_slice`` and ``log_images``.
    """
    # Define visible GPU devices.
    # NOTE(review): the device count is queried before CUDA_VISIBLE_DEVICES
    # is set -- confirm the mask still takes effect afterwards.
    print('>> Found {} GPUs to use.'.format(torch.cuda.device_count()))
    if args.gpu_id:
        print('>> Using GPU {}'.format(args.gpu_id))
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device("cpu")

    # The gin file is selected by run mode and dataset name.
    config_file = 'configs/runs/run_{}_on_{}.gin'.format(
        args.mode, args.dataset)
    gin.parse_config_file(config_file)

    # CMU needs an additional slice-specific binding step.
    if args.dataset == 'cmu':
        bind_cmu_parameters(args.cmu_slice, args.mode)

    dataset = get_dataset_loader()

    # Retrieval model; its construction contributes to the printed config.
    net = network.ImageRetrievalModel(device=device)
    print(gin.operative_config_str())

    # Rankings are fetched when available and computed otherwise.
    ranks = rank_images.fetch_or_compute_ranks(dataset, net)

    pose_predictor = get_pose_predictor(dataset=dataset,
                                        network=net,
                                        ranks=ranks,
                                        log_images=args.log_images)
    predictions = pose_predictor.run()
    pose_predictor.save(predictions)
    print("All ranks computed.")
def dump(self):
    """Write `gin.operative_config_str()` to ``self._path``.

    Raises:
        IOError: If something already exists at ``self._path``.
    """
    if os.path.exists(self._path):
        raise IOError(f"Operative config already exists at {self._path}")

    with open(self._path, "w") as handle:
        handle.write(gin.operative_config_str())
def train_agent(self, agent: PTUtilAgent, **kwargs) -> None:
    """Handles all the necessary subroutines for training a PTUtilAgent.

    Moves the agent to ``self.device``, registers this trainer on it and runs
    episodes from ``self.episode_start`` through ``self.episode_end``
    (inclusive), stopping early when ``self.early_stoppage`` is set.

    Args:
        agent: The agent model to train.
        **kwargs: Forwarded unchanged to ``self._begin_fit``.
    """
    self.model = agent
    self.model.to(self.device)
    self.model.set_trainer(self)

    # Checkpoint not loaded, start from 0 and dump config params
    if not self.checkpoint_loaded:
        self.episode_start = 1
        self.step_global = 0
        self.logging_buffer.append(
            LoggingItem("INFO", gin.operative_config_str()))
        self.episode_losses = []
        self.episode_rewards = []
    # Clear the flag so the next call starts fresh unless a checkpoint is
    # loaded again. NOTE(review): the extent of the ``if`` body above was
    # reconstructed from a collapsed source line -- confirm.
    self.checkpoint_loaded = False

    self._begin_fit(**kwargs)
    for ep in range(self.episode_start, self.episode_end + 1):
        self.step_ep = ep
        self._begin_episode()
        # Step until the environment reports the episode as done.
        while not self.env.is_done():
            self.step_global += 1
            # presumably reset by _begin_episode -- verify
            self.step_episode += 1
            with torch.cuda.amp.autocast(self.use_amp):
                loss, reward = self._run_episode_step()
            self.episode_loss += loss
            self.episode_reward += reward
        self._after_episode()
        if self.early_stoppage:
            break
    self._after_fit()
def evaluate(output_folder, separation_algorithm, eval_class, block_on_gpu, num_workers, seed):
    """Separate every item of the test dataset and write its scores as JSON.

    For each item, ``forward_on_gpu`` runs in the calling thread, and the
    separation plus evaluation runs on a thread pool (the first item runs
    inline). Scores are written to
    ``<output_folder>/results/<mix file name>.json``.

    Args:
        output_folder: Directory under which ``results`` is created.
        separation_algorithm: Callable building a separator from a mix.
        eval_class: Callable building an evaluator from sources and estimates.
        block_on_gpu: When truthy, one shared instance is built with
            ``device='cuda'`` and per-item separators get ``model_path=None``.
        num_workers: ``max_workers`` of the thread pool.
        seed: Passed to ``nussl.utils.seed``.
    """
    nussl.utils.seed(seed)
    logging.info(gin.operative_config_str())

    # The dataset is built inside the 'test' gin scope.
    with gin.config_scope('test'):
        test_dataset = build_dataset()

    results_folder = os.path.join(output_folder, 'results')
    os.makedirs(results_folder, exist_ok=True)

    set_model_to_none = False
    if block_on_gpu:
        # make an instance that'll be used on GPU
        # has an empty audio signal for now
        gpu_algorithm = separation_algorithm(
            nussl.AudioSignal(), device='cuda')
        set_model_to_none = True

    # NOTE(review): gpu_algorithm is only bound when block_on_gpu is truthy,
    # yet forward_on_gpu is called for every item below -- confirm callers
    # always enable block_on_gpu.
    def forward_on_gpu(audio_signal):
        # set the audio signal of the object to this item's mix
        gpu_algorithm.audio_signal = audio_signal
        # NOTE(review): gpu_output stays unbound if neither attribute exists.
        if hasattr(gpu_algorithm, 'forward'):
            gpu_output = gpu_algorithm.forward()
        elif hasattr(gpu_algorithm, 'extract_features'):
            gpu_output = gpu_algorithm.extract_features()
        return gpu_output

    pbar = tqdm.tqdm(total=len(test_dataset))

    def separate_and_evaluate(item, gpu_output):
        if set_model_to_none:
            separator = separation_algorithm(item['mix'], model_path=None)
        else:
            separator = separation_algorithm(item['mix'])
        estimates = separator(gpu_output)

        # Sources are ordered by sorted source name.
        source_names = sorted(list(item['sources'].keys()))
        sources = [item['sources'][k] for k in source_names]

        # other arguments come from gin config
        evaluator = eval_class(sources, estimates)
        scores = evaluator.evaluate()
        output_path = os.path.join(
            results_folder, f"{item['mix'].file_name}.json")
        with open(output_path, 'w') as f:
            json.dump(scores, f, indent=2)
        pbar.update(1)

    pool = ThreadPoolExecutor(max_workers=num_workers)
    for i in range(len(test_dataset)):
        item = test_dataset[i]
        gpu_output = forward_on_gpu(item['mix'])
        # The first item is processed inline, so a failure there raises here.
        if i == 0:
            separate_and_evaluate(item, gpu_output)
            continue
        # NOTE(review): the returned future is discarded, so exceptions
        # raised in worker threads are presumably never surfaced -- confirm.
        pool.submit(separate_and_evaluate, item, gpu_output)
    pool.shutdown(wait=True)
def _save_gin(self):
    """Write the operative config to ``config.gin`` and log every binding."""
    target = os.path.join(self._output_dir, 'config.gin')
    operative = gin.operative_config_str()
    with open(target, 'w') as handle:
        handle.write(operative)

    # Each extracted binding is reported as a (name, value) property.
    for binding in gin_utils.extract_bindings(operative):
        name, value = binding
        metric_logging.log_property(name, value)
def _write_gin_configs(output_file):
    """Log the operative gin config and write it to `output_file`."""
    operative = gin.operative_config_str()
    separator = '=' * 80

    logging.info(separator)
    logging.info('Gin configs\n%s', operative)
    logging.info(separator)

    with tf.gfile.GFile(output_file, 'w') as handle:
        handle.write(operative)
def init_saves(self):
    """Create the save directory, dump the gin config and set output paths.

    Sets ``self.output_log``, ``self.save_path`` and ``self.summary_writer``,
    all located under ``self.save_location_dir``.
    """
    # makedirs(exist_ok=True) also creates missing parent directories and
    # avoids the check-then-create race of exists() followed by mkdir().
    os.makedirs(self.save_location_dir, exist_ok=True)

    with open(os.path.join(self.save_location_dir, 'config.gin'), 'w') as conf:
        conf.write(gin.operative_config_str())

    self.output_log = os.path.join(self.save_location_dir, 'output_log.txt')
    self.save_path = os.path.join(self.save_location_dir, 'best_model.pt')
    self.summary_writer = SummaryWriter(
        os.path.join(self.save_location_dir, 'logs'), 300)
def main(_):
    """Parse the gin flags, run the episodes and optionally print the config."""
    config_files = FLAGS.gin_file
    logging.info('parsing config files: %s', config_files)
    gin.parse_config_files_and_bindings(
        config_files, FLAGS.gin_param, skip_unknown=True)

    run_episodes()

    # The operative config is only printed in debug mode.
    if not FLAGS.debug:
        return
    print(gin.operative_config_str())
def _save_gin_config(output_dir: str):
    """Log the operative gin config, add it as a text summary and save it.

    Args:
        output_dir: Directory that receives ``gin_config.txt``.
    """
    # Every newline is doubled to improve how the config reads in the logs.
    spaced_config = '\n\n'.join(gin.operative_config_str().split('\n'))

    logging.info('gin config: %s', spaced_config)
    tf.summary.text('gin_config', spaced_config, 0)

    target = os.path.join(output_dir, 'gin_config.txt')
    with tf.io.gfile.GFile(target, 'w') as handle:
        handle.write(spaced_config)
def on_train_begin(self, logs=None):
    """Write the operative gin config to the log directory and log it."""
    super(GinConfigWriter, self).on_train_begin(logs)

    target = os.path.join(self._log_dir, 'operative-config.gin')
    config_text = gin.operative_config_str()
    with tf.io.gfile.GFile(target, 'w') as handle:
        handle.write(config_text)

    message = 'Starting training with operative config:\n{}'.format(
        config_text)
    logging.info(message)
def _save_gin(output_dir, sw=None):
    """Save the operative gin config to ``<output_dir>/config.gin``.

    Args:
        output_dir: Directory that receives ``config.gin``.
        sw: Optional summary writer; when truthy it also receives the config
            as markdown text under ``gin_config``.
    """
    target = os.path.join(output_dir, "config.gin")
    operative = gin.operative_config_str()
    with gfile.GFile(target, "w") as handle:
        handle.write(operative)

    if not sw:
        return
    markdown = jaxboard.markdownify_operative_config_str(operative)
    sw.text("gin_config", markdown)
def save_gin_config(filename_suffix: str, model_dir: str):
    """Write the operative gin config into ``model_dir``.

    Args:
        filename_suffix: Inserted into the file name
            ``operative_config.<suffix>.gin``.
        model_dir: Destination directory; created when missing.
    """
    file_name = 'operative_config.{}.gin'.format(filename_suffix)
    target = os.path.join(model_dir, file_name)
    logging.info('Saving gin configurations to %s', target)

    tf.io.gfile.makedirs(model_dir)
    with tf.io.gfile.GFile(target, 'w') as handle:
        handle.write(gin.operative_config_str())
def dump(root, name):
    """Write ``<name>.gin`` and ``<name>.args`` into ``root`` if missing.

    The ``.gin`` file receives the operative gin config and the ``.args``
    file the space-joined command-line arguments. Files that already exist
    are left untouched.
    """
    config_file = os.path.join(root, f'{name}.gin')
    if not os.path.isfile(config_file):
        with open(config_file, 'w') as handle:
            handle.write(gin.operative_config_str())

    args_file = os.path.join(root, f'{name}.args')
    if not os.path.isfile(args_file):
        with open(args_file, 'w') as handle:
            handle.write(' '.join(sys.argv[1:]) + '\n')
def after_create_session(self, session=None, coord=None):
    """Log Gin's operative config, one log record per line.

    Nothing is logged when ``self._only_once`` is set and the config has
    already been written at least once.
    """
    already_logged = self._only_once and self._written_at_least_once
    if not already_logged:
        logging.info('Gin operative configuration:')
        config_lines = gin.operative_config_str().splitlines()
        for line in config_lines:
            logging.info(line)
        self._written_at_least_once = True
def record_operative_gin_configurations(operative_config_dir):
    """Write the operative Gin config into the given directory.

    An existing log file is first renamed to ``<file>.old`` (replacing any
    earlier ``.old`` file), so exactly one previous version is kept.
    """
    log_file = operative_config_path(operative_config_dir)

    already_there = tf.io.gfile.exists(log_file)
    if already_there:
        backup_file = log_file + '.old'
        tf.io.gfile.rename(log_file, backup_file, overwrite=True)

    with tf.io.gfile.GFile(log_file, 'w') as handle:
        handle.write(gin.operative_config_str())
def save_gin(self):
    """Save the operative gin config to ``config.gin`` in the output dir.

    When a training summary writer is set, the config is also sent to it as
    markdown text under ``gin_config``.
    """
    target = os.path.join(self._output_dir, 'config.gin')
    operative = gin.operative_config_str()
    with tf.io.gfile.GFile(target, 'w') as handle:
        handle.write(operative)

    writer = self._train_sw
    if not writer:
        return
    markdown = jaxboard.markdownify_operative_config_str(operative)
    writer.text('gin_config', markdown)
def _save_operative_config(self):
    """Write the operative gin config for the current iteration."""
    # These attributes are read only for their side effects, so that the
    # operative config is complete before it is written.
    self.generators
    self.model

    log_dir = self.logdir
    if not tf.io.gfile.isdir(log_dir):
        tf.io.gfile.makedirs(log_dir)

    target = operative_config_path(self.logdir, self._iteration)
    with open(target, 'w') as handle:
        handle.write(gin.operative_config_str())
def save_gin(self):
    """Save the operative gin config to ``config.gin`` in the output dir.

    When ``self._sw`` is set, the config is also sent to it as markdown text
    under ``gin_config``.
    """
    assert self._output_dir is not None

    target = os.path.join(self._output_dir, 'config.gin')
    operative = gin.operative_config_str()
    with tf.io.gfile.GFile(target, 'w') as handle:
        handle.write(operative)

    if not self._sw:
        return
    markdown = jaxboard.markdownify_operative_config_str(operative)
    self._sw.text('gin_config', markdown)
def save_gin_config(self):
    """Write the operative gin config to ``self.config_path``.

    When the batch size binding is missing from the operative config, it is
    queried from gin and inserted before the first line mentioning
    ``AdvantageActorCriticAgent.`` (or before the last line if no line does).
    """
    config_str = gin.operative_config_str()

    if 'AdvantageActorCriticAgent.batch_sz' not in config_str:
        # gin ignores batch size since it's passed manually from args
        # as a hacky workaround - insert it manually as the first param
        batch_sz = gin.query_parameter('AdvantageActorCriticAgent.batch_sz')
        lines = config_str.split('\n')
        insert_at = next(
            (idx for idx, line in enumerate(lines)
             if 'AdvantageActorCriticAgent.' in line),
            len(lines) - 1)
        binding = 'AdvantageActorCriticAgent.batch_sz = ' + str(batch_sz)
        lines.insert(insert_at, binding)
        config_str = '\n'.join(lines)

    with open(self.config_path, 'w') as cfg_file:
        cfg_file.write(config_str)