Example #1
 def test_operative_config_str(self):
     with GinState():
         g()
         op_config_string = gin.operative_config_str()
         self.assertTrue('f.x' not in op_config_string)
         self.assertTrue('g.z' in op_config_string)
     f()
     op_config_string = gin.operative_config_str()
     self.assertTrue('f.x' in op_config_string)
     self.assertTrue('g.z' not in op_config_string)
Example #2
def gin_log_config(log_dir, slug):
    """
    Attempt to write the operative config to our log dir and TensorBoard...
    Also prints the config to tf.logging.INFO
    """
    try:
        if log_dir is not None and tf.io.gfile.isdir(log_dir):
            path = os.path.join(log_dir, "%s.gin" % slug)
            with tf.io.gfile.GFile(path, "w") as fp:
                fp.write(gin.operative_config_str())
    except Exception:
        logging.info(gin.operative_config_str())
        logging.warning("Failed to write gin config to log dir")
Example #3
def log_gin_config(output_dir, cometml_experiment=None, wandb_run=None):
    gin_config_str = gin.operative_config_str()

    print("Used config: " + "-" * 40)
    print(gin_config_str)
    print("-" * 52)
    with open(os.path.join(output_dir, "config.gin"), "w") as f:
        f.write(gin_config_str)
    # strip import lines, then parse the gin config string into a dictionary
    gin_config_str = "\n".join(
        [x for x in gin_config_str.split("\n") if not x.startswith("import")])
    gin_config_dict = yaml.safe_load(
        gin_config_str.replace("@", "")
        .replace(" = %", ": ")
        .replace(" = ", ": "))
    write_json(gin_config_dict,
               os.path.join(output_dir, "config.gin.json"),
               sort_keys=True,
               indent=2)

    if cometml_experiment is not None:
        # Skip any rows starting with import
        cometml_experiment.log_multiple_params(gin_config_dict)

    if wandb_run is not None:
        # This allows the metrics to be displayed on the dashboard
        wandb_run.config.update(
            {k.replace(".", "/"): v
             for k, v in gin_config_dict.items()})
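A minimal, self-contained sketch of the string substitution used above to turn a gin config string into a YAML-parsable dictionary. It is not taken from the project; the config text and parameter names are hypothetical.

import yaml

# Hypothetical operative config string, shaped like gin.operative_config_str() output.
gin_config_str = """
# Parameters for model_fn:
model_fn.hidden_units = 128
model_fn.activation = @tf.nn.relu
model_fn.dropout = %DROPOUT_RATE
"""

# Same substitutions as in log_gin_config: drop '@', turn ' = %' and ' = ' into ': '.
yaml_str = (gin_config_str
            .replace("@", "")
            .replace(" = %", ": ")
            .replace(" = ", ": "))
config_dict = yaml.safe_load(yaml_str)
print(config_dict)
# {'model_fn.hidden_units': 128, 'model_fn.activation': 'tf.nn.relu',
#  'model_fn.dropout': 'DROPOUT_RATE'}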
Example #4
def main(experiment_name: str):
    print("Using remote MLFlow server")
    lib_mlflow.setup()

    print(f"Using experiment name {experiment_name}")
    lib_mlflow.try_until_success(mlflow.set_experiment, experiment_name)

    run = lib_mlflow.try_until_success(mlflow.start_run)

    run_id = run.info.run_id
    lib_mlflow.log_param("run_id", run.info.run_id)
    base_dir = Path(os.environ.get("SCRATCH", "./")) / experiment_name / run_id
    problem = lib_problem.Problem(base_dir=base_dir)
    results, models = problem.train()

    gin_config_string = gin.operative_config_str()
    lib_mlflow.log_params(lib_analysis._parse_gin_config(gin_config_string))
    # Include also all the default parameters in the final gin config file.
    (base_dir / "full_config.gin").write_text(gin_config_string)

    print("Saving models")
    model_base_save_path = base_dir / "models"
    model_base_save_path.mkdir(parents=True, exist_ok=True)
    for im, m in enumerate(models):
        out = model_base_save_path / f"model-{im}"
        if out.exists():
            shutil.rmtree(out)
        out.mkdir()
        tf.keras.models.save_model(m, str(out))
    results["model_base_save_path"] = model_base_save_path

    print(results)

    lib_mlflow.shutdown()
    lib_mlflow.try_until_success(mlflow.end_run)
Example #5
def main():
    parse_args()
    train_ds, test_ds, num_classes = core_gin.get_datasets()
    inputs_spec = train_ds.element_spec[0]
    model = core_gin.get_classification_model(inputs_spec, num_classes)
    core_gin.fit(model, train_ds, test_ds)
    print(gin.operative_config_str())
Example #6
    def __init__(self, model, train_dataset, eval_dataset, num_epoch, batch_size, save_dir, snapshot=None):

        # create TensorFlow Dataset objects
        parser = functools.partial(data_parser, dim=model.dim)
        tr_data = tf.data.TFRecordDataset(train_dataset)
        tr_data = tr_data.map(parser).shuffle(10000).batch(batch_size)
        val_data = tf.data.TFRecordDataset(eval_dataset)
        val_data = val_data.map(parser).batch(batch_size)
        # create TensorFlow Iterator object
        iterator = tf.data.Iterator.from_structure(tr_data.output_types,
                                                   tr_data.output_shapes)
        self.start, self.action, self.result = iterator.get_next()
        # create two initialization ops to switch between the datasets
        self.training_init_op = iterator.make_initializer(tr_data)
        self.validation_init_op = iterator.make_initializer(val_data)

        tf_config = tf.ConfigProto(
            inter_op_parallelism_threads=16,
            intra_op_parallelism_threads=16)
        tf_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=tf_config)
        self.num_epoch = num_epoch
        self.model = model
        self.model.build(input=self.start, action=self.action)
        self.model.setup_optimizer(0.001, self.result)
        self.global_step = 0
        self.save_dir = save_dir
        self.train_writer = tf.summary.FileWriter(os.path.join(save_dir, 'tfboard'), self.sess.graph)
        self.sess.run(tf.global_variables_initializer())
        if snapshot is not None:
            self.model.load(self.sess, snapshot)
        config_str = gin.operative_config_str()
        with open(os.path.join(save_dir, '0.gin'), 'w') as f:
            f.write(config_str)
Example #7
File: base.py Project: wDaniec/toolkit
    def on_train_begin(self, logs=None):
        logger.info(
            "Saving meta data information from the beginning of training")

        assert os.system("cp {} {}".format(
            sys.argv[0],
            self.save_path)) == 0, "Failed to execute cp of source script"

        utc_date = datetime.datetime.utcnow().strftime("%Y_%m_%d")

        time_start = time.time()
        cmd = "python " + " ".join(sys.argv)
        self.meta = {
            "cmd": cmd,
            "save_path": self.save_path,
            "most_recent_train_start_date": utc_date,
            "execution_time": -time_start
        }

        json.dump(self.meta,
                  open(os.path.join(self.save_path, "meta.json"), "w"),
                  indent=4)

        # Copy gin configs used, for reference, to the save folder
        os.system("rm " + os.path.join(self.save_path, "*gin"))
        for gin_config in sys.argv[2].split(";"):
            os.system("cp {} {}/base_config.gin".format(
                gin_config, self.save_path))

        with open(os.path.join(self.save_path, "config.gin"), "w") as f:
            f.write(gin.operative_config_str())
Example #8
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  xm.setup_work_unit()
  if not gfile.Exists(FLAGS.workdir):
    gfile.MakeDirs(FLAGS.workdir)
  utils.dump_flags_to_file(os.path.join(FLAGS.workdir, 'flags.txt'))
  gin.bind_parameter('CuriosityEnvWrapper.scale_task_reward', 0.)
  gin.bind_parameter('CuriosityEnvWrapper.scale_surrogate_reward', 1.)
  gin.bind_parameter('AntWrapper.enable_die_condition',
                     FLAGS.ant_env_enable_die_condition)
  gin.parse_config_files_and_bindings(None,
                                      FLAGS.gin_bindings)
  # Hardware crashes with:
  # Failed to open library!
  # dlopen: cannot load any more object with static TLS
  FLAGS.renderer = 'software'

  work_unit = None
  if FLAGS.xm_xid != -1:
    work_unit = xmanager_api.XManagerApi().get_current_work_unit()

  visualize_curiosity_reward(work_unit)
  with gfile.GFile(os.path.join(FLAGS.workdir, 'gin_config.txt'), 'w') as f:
    f.write(gin.operative_config_str())
Example #9
def train(device, save, epochs, episodes, seed, task):
    Env = dict(keytask=KeyTask)[task]
    env = Env(seed=seed, max_steps=100)

    replay = gen.gen_env(env, n_episodes=episodes)
    tloader = gen.Transitions(replay)
    loader = data.DataLoader(tloader, batch_size=512, shuffle=True)

    learner = model.Learner(device, action_dim=len(env.action_space))

    optimizer = optim.Adam(learner.parameters(), lr=0.001)
    L.info('Starting training for %s epochs.', epochs)
    L.info('Best model will be saved to %s.', save)
    with trange(1, epochs + 1) as t:
        best_loss = 1e9
        for epoch in t:
            loss_ = 0
            for step, batch in enumerate(tqdm(loader)):
                obs, action, reward, next_obs = batch = [
                    x.to(device) for x in batch
                ]
                optimizer.zero_grad()
                loss = learner.loss(*batch)
                loss.backward()
                loss_ += loss.item()
                optimizer.step()
            avg_loss = loss_ / len(loader.dataset)
            if avg_loss < best_loss:
                best_loss = avg_loss
                torch.save(
                    dict(model=learner.state_dict(),
                         config=gin.operative_config_str()), save)
            t.set_postfix(best=best_loss, avg=avg_loss)
    L.info('Done training. Best loss: %s', best_loss)
Example #10
def j2j_train(model_name,
              dataset_name,
              data_dir=None,
              output_dir=None,
              config_file=None,
              config=None):
    """Main function to train the given model on the given dataset.

    Args:
      model_name: The name of the model to train.
      dataset_name: The name of the dataset to train on.
      data_dir: Directory where the data is located.
      output_dir: Directory where to put the logs and checkpoints.
      config_file: the gin configuration file to use.
      config: string (in gin format) to override gin parameters.
    """
    gin.bind_parameter("train_fn.dataset", dataset_name)
    if FLAGS.model:
        config = [] if config is None else config
        config += ["train_fn.model=@models." + model_name]
    gin.parse_config_files_and_bindings(config_file, config)
    if output_dir:
        if not tf.gfile.Exists(output_dir):
            tf.gfile.MkDir(output_dir)
        config_path = os.path.join(output_dir, "gin.config")
        # TODO(lukaszkaiser): why is the file empty if there's no provided config?
        with tf.gfile.Open(config_path, "w") as f:
            f.write(gin.operative_config_str())
    j2j.train_fn(data_dir, output_dir=output_dir)
Example #11
def get_gin_confg_strs():
    """
    Obtain both the operative and inoperative config strs from gin.

    The operative configuration consists of all parameter values used by
    configurable functions that are actually called during execution of the
    current program; the inoperative configuration consists of all parameter
    values configured but not used by configurable functions. See
    `gin.operative_config_str()` and `gin_utils.inoperative_config_str` for
    more detail on how each config string is generated.

    Returns:
        md_operative_config_str (str): a markdown-formatted operative str
        md_inoperative_config_str (str): a markdown-formatted inoperative str
    """
    operative_config_str = gin.operative_config_str()
    md_operative_config_str = _markdownify_gin_config_str(
        operative_config_str,
        'All parameter values used by configurable functions that are actually called'
    )
    md_inoperative_config_str = gin_utils.inoperative_config_str()
    if md_inoperative_config_str:
        md_inoperative_config_str = _markdownify_gin_config_str(
            md_inoperative_config_str,
            "All parameter values configured but not used by program. The configured "
            "functions are either not called or called with explicit parameter values "
            "overriding the config.")
    return md_operative_config_str, md_inoperative_config_str
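The distinction the docstring above draws can be reproduced directly with gin. A minimal sketch, independent of the surrounding project and using made-up configurable names: bindings for a configurable that is actually called appear in the operative config, while bindings for one that is never called do not.

import gin

@gin.configurable
def used_fn(x=1):
    return x

@gin.configurable
def unused_fn(y=2):
    return y

gin.parse_config("""
used_fn.x = 10
unused_fn.y = 20
""")

used_fn()  # only used_fn is called in this "program"
print(gin.operative_config_str())
# The output lists used_fn.x = 10, but contains no unused_fn.y binding,
# because unused_fn was never called.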
Example #12
def log_gin_config(output_dir,
                   cometml_experiment=None,
                   wandb_run=None,
                   prefix=''):
    """Save the config.gin file containing the whole config, convert it
    to a dictionary and upload it to cometml and wandb.
    """
    gin_config_str = gin.operative_config_str()

    print("Used config: " + "-" * 40)
    print(gin_config_str)
    print("-" * 52)
    with open(os.path.join(output_dir, f"{prefix}config.gin"), "w") as f:
        f.write(gin_config_str)

    gin_config_dict = gin2dict(gin_config_str)
    write_json(gin_config_dict,
               os.path.join(output_dir, f"{prefix}config.gin.json"),
               sort_keys=True,
               indent=2)

    if cometml_experiment is not None:
        # Skip any rows starting with import
        cometml_experiment.log_parameters(gin_config_dict)

    if wandb_run is not None:
        # This allows the metrics to be displayed on the dashboard
        wandb_run.config.update(
            {k.replace(".", "/"): v
             for k, v in gin_config_dict.items()})
Example #13
File: run.py Project: lxxue/S2DHM
def main(args):
    # Define visible GPU devices
    print('>> Found {} GPUs to use.'.format(torch.cuda.device_count()))
    if args.gpu_id:
        print('>> Using GPU {}'.format(args.gpu_id))
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

    # Load gin config based on dataset name
    gin.parse_config_file('configs/runs/run_{}_on_{}.gin'.format(
        args.mode, args.dataset))

    # For CMU, pick a slice
    if args.dataset == 'cmu':
        bind_cmu_parameters(args.cmu_slice, args.mode)

    # Create dataset loader
    dataset = get_dataset_loader()

    # Load retrieval model and initialize pose predictor
    net = network.ImageRetrievalModel(device=device)
    print(gin.operative_config_str())

    # Check if image retrieval rankings exist, if not compute them
    ranks = rank_images.fetch_or_compute_ranks(dataset, net)

    # Predict query images poses
    pose_predictor = get_pose_predictor(dataset=dataset,
                                        network=net,
                                        ranks=ranks,
                                        log_images=args.log_images)
    pose_predictor.save(pose_predictor.run())
    print("All ranks computed.")
Example #14
File: callbacks.py Project: jackd/kblocks
 def dump(self):
     """Save `gin.operative_config_str()` to file."""
     exists = os.path.exists(self._path)
     if exists:
         raise IOError(f"Operative config already exists at {self._path}")
     with open(self._path, "w") as fp:
         fp.write(gin.operative_config_str())
Example #15
    def train_agent(self, agent: PTUtilAgent, **kwargs) -> None:
        """Handles all the necessary subroutines for training a PTUtilAgent.

        Args:
            agent: The agent model to train.
        """
        self.model = agent
        self.model.to(self.device)
        self.model.set_trainer(self)

        # Checkpoint not loaded, start from 0 and dump config params
        if not self.checkpoint_loaded:
            self.episode_start = 1
            self.step_global = 0
            self.logging_buffer.append(
                LoggingItem("INFO", gin.operative_config_str()))
            self.episode_losses = []
            self.episode_rewards = []
        self.checkpoint_loaded = False

        self._begin_fit(**kwargs)
        for ep in range(self.episode_start, self.episode_end + 1):
            self.step_ep = ep
            self._begin_episode()
            while not self.env.is_done():
                self.step_global += 1
                self.step_episode += 1
                with torch.cuda.amp.autocast(self.use_amp):
                    loss, reward = self._run_episode_step()
                self.episode_loss += loss
                self.episode_reward += reward
            self._after_episode()
            if self.early_stoppage:
                break
        self._after_fit()
Example #16
File: evaluate.py Project: nussl/models
def evaluate(output_folder, separation_algorithm, eval_class, 
             block_on_gpu, num_workers, seed):
    nussl.utils.seed(seed)
    logging.info(gin.operative_config_str())
    
    with gin.config_scope('test'):
        test_dataset = build_dataset()
    
    results_folder = os.path.join(output_folder, 'results')
    os.makedirs(results_folder, exist_ok=True)
    set_model_to_none = False

    if block_on_gpu:
        # make an instance that'll be used on GPU
        # has an empty audio signal for now
        gpu_algorithm = separation_algorithm(
            nussl.AudioSignal(), device='cuda')
        set_model_to_none = True

    def forward_on_gpu(audio_signal):
        # set the audio signal of the object to this item's mix
        gpu_algorithm.audio_signal = audio_signal
        if hasattr(gpu_algorithm, 'forward'):
            gpu_output = gpu_algorithm.forward()
        elif hasattr(gpu_algorithm, 'extract_features'):
            gpu_output = gpu_algorithm.extract_features()
        return gpu_output

    pbar = tqdm.tqdm(total=len(test_dataset))

    def separate_and_evaluate(item, gpu_output):
        if set_model_to_none:
            separator = separation_algorithm(item['mix'], model_path=None)
        else:
            separator = separation_algorithm(item['mix'])
        estimates = separator(gpu_output)
        source_names = sorted(list(item['sources'].keys()))
        sources = [item['sources'][k] for k in source_names]
        
        # other arguments come from gin config
        evaluator = eval_class(sources, estimates)
        scores = evaluator.evaluate()
        output_path = os.path.join(
            results_folder, f"{item['mix'].file_name}.json")
        with open(output_path, 'w') as f:
            json.dump(scores, f, indent=2)
        pbar.update(1)
    
    pool = ThreadPoolExecutor(max_workers=num_workers)
    
    for i in range(len(test_dataset)):
        item = test_dataset[i]
        gpu_output = forward_on_gpu(item['mix'])
        if i == 0:
            separate_and_evaluate(item, gpu_output)
            continue
        pool.submit(separate_and_evaluate, item, gpu_output)
    
    pool.shutdown(wait=True)
Example #17
    def _save_gin(self):
        config_path = os.path.join(self._output_dir, 'config.gin')
        config_str = gin.operative_config_str()
        with open(config_path, 'w') as f:
            f.write(config_str)

        for (name, value) in gin_utils.extract_bindings(config_str):
            metric_logging.log_property(name, value)
Example #18
def _write_gin_configs(output_file):
    """Writes current gin configs to `output_file`."""
    config_str = gin.operative_config_str()
    logging.info('=' * 80)
    logging.info('Gin configs\n%s', config_str)
    logging.info('=' * 80)
    with tf.gfile.GFile(output_file, 'w') as f:
        f.write(config_str)
Example #19
 def init_saves(self):
     if not os.path.exists(self.save_location_dir):
         os.mkdir(self.save_location_dir)
     with open(os.path.join(self.save_location_dir, 'config.gin'), 'w') as conf:
         conf.write(gin.operative_config_str())
     self.output_log = os.path.join(self.save_location_dir, 'output_log.txt')
     self.save_path = os.path.join(self.save_location_dir, 'best_model.pt')
     self.summary_writer = SummaryWriter(os.path.join(self.save_location_dir, 'logs'), 300)
Example #20
def main(_):
    logging.info('parsing config files: %s', FLAGS.gin_file)
    gin.parse_config_files_and_bindings(FLAGS.gin_file,
                                        FLAGS.gin_param,
                                        skip_unknown=True)
    run_episodes()
    if FLAGS.debug:
        print(gin.operative_config_str())
Example #21
File: trainer.py Project: jaigupta/lmml
def _save_gin_config(output_dir: str):
    # Replace \n with \n\n to improve logging.
    gin_config = gin.operative_config_str().replace('\n', '\n\n')
    logging.info('gin config: %s', gin_config)
    tf.summary.text('gin_config', gin_config, 0)
    with tf.io.gfile.GFile(os.path.join(output_dir, 'gin_config.txt'),
                           'w') as f:
        f.write(gin_config)
Example #22
File: callbacks.py Project: jackd/pointnet
 def on_train_begin(self, logs=None):
     super(GinConfigWriter, self).on_train_begin(logs)
     path = os.path.join(self._log_dir, 'operative-config.gin')
     operative_config = gin.operative_config_str()
     with tf.io.gfile.GFile(path, 'w') as fp:
         fp.write(operative_config)
     logging.info('Starting training with operative config:\n{}'.format(
         operative_config))
Example #23
def _save_gin(output_dir, sw=None):
    config_path = os.path.join(output_dir, "config.gin")
    config_str = gin.operative_config_str()
    with gfile.GFile(config_path, "w") as f:
        f.write(config_str)
    if sw:
        sw.text("gin_config",
                jaxboard.markdownify_operative_config_str(config_str))
Example #24
def save_gin_config(filename_suffix: str, model_dir: str):
  """Serializes and saves the experiment config."""
  gin_save_path = os.path.join(
      model_dir, 'operative_config.{}.gin'.format(filename_suffix))
  logging.info('Saving gin configurations to %s', gin_save_path)
  tf.io.gfile.makedirs(model_dir)
  with tf.io.gfile.GFile(gin_save_path, 'w') as f:
    f.write(gin.operative_config_str())
Example #25
def dump(root, name):
    gin_path = os.path.join(root, f'{name}.gin')
    if not os.path.isfile(gin_path):
        with open(gin_path, 'w') as f:
            f.write(gin.operative_config_str())
    arg_path = os.path.join(root, f'{name}.args')
    if not os.path.isfile(arg_path):
        with open(arg_path, 'w') as f:
            print(' '.join(sys.argv[1:]), file=f)
Example #26
    def after_create_session(self, session=None, coord=None):
        """Logs Gin's operative config."""
        if self._only_once and self._written_at_least_once:
            return

        logging.info('Gin operative configuration:')
        for gin_config_line in gin.operative_config_str().splitlines():
            logging.info(gin_config_line)
        self._written_at_least_once = True
Example #27
def record_operative_gin_configurations(operative_config_dir):
    """Record operative Gin configurations in the given directory."""
    gin_log_file = operative_config_path(operative_config_dir)
    # If it exists already, rename it instead of overwriting it.
    # This just saves the previous one, not all the ones before.
    if tf.io.gfile.exists(gin_log_file):
        tf.io.gfile.rename(gin_log_file, gin_log_file + '.old', overwrite=True)
    with tf.io.gfile.GFile(gin_log_file, 'w') as f:
        f.write(gin.operative_config_str())
Example #28
 def save_gin(self):
     config_path = os.path.join(self._output_dir, 'config.gin')
     config_str = gin.operative_config_str()
     with tf.io.gfile.GFile(config_path, 'w') as f:
         f.write(config_str)
     sw = self._train_sw
     if sw:
         sw.text('gin_config',
                 jaxboard.markdownify_operative_config_str(config_str))
Example #29
 def _save_operative_config(self):
     # calling these setters ensures the operative config is complete
     self.generators
     self.model
     if not tf.io.gfile.isdir(self.logdir):
         tf.io.gfile.makedirs(self.logdir)
     path = operative_config_path(self.logdir, self._iteration)
     with open(path, 'w') as fp:
         fp.write(gin.operative_config_str())
Example #30
 def save_gin(self):
   assert self._output_dir is not None
   config_path = os.path.join(self._output_dir, 'config.gin')
   config_str = gin.operative_config_str()
   with tf.io.gfile.GFile(config_path, 'w') as f:
     f.write(config_str)
   if self._sw:
     self._sw.text('gin_config',
                   jaxboard.markdownify_operative_config_str(config_str))
Example #31
    def save_gin_config(self):
        config_str = gin.operative_config_str()

        if 'AdvantageActorCriticAgent.batch_sz' not in config_str:
            # gin ignores batch size since it's passed manually from args
            # as a hacky workaround - insert it manually as the first param
            batch_sz = gin.query_parameter('AdvantageActorCriticAgent.batch_sz')
            config_lines = config_str.split('\n')
            first_ac_line = 0
            for first_ac_line in range(0, len(config_lines)):
                if 'AdvantageActorCriticAgent.' in config_lines[first_ac_line]:
                    break
            config_lines.insert(first_ac_line, 'AdvantageActorCriticAgent.batch_sz = ' + str(batch_sz))
            config_str = '\n'.join(config_lines)

        with open(self.config_path, 'w') as cfg_file:
            cfg_file.write(config_str)