コード例 #1
0
ファイル: app.py プロジェクト: laranea/ddsp-streamlit-ui
def load_model(instrument_model, audio_length):
    """Locate a pretrained checkpoint and rebind gin sizes for an audio length.

    Looks in ``CKPT_DIR/solo_<instrument>_ckpt`` for a ``model.ckpt`` file,
    parses that folder's ``operative_config-0.gin`` and rebinds the
    synthesizer / preprocessor sizes. The training hop size
    (``Additive.n_samples / DefaultPreprocessor.time_steps``) is kept, and the
    number of frames is chosen to fit ``audio_length``.

    Args:
        instrument_model: Instrument name; lower-cased to build the folder name.
        audio_length: Number of audio samples to synthesize.

    Returns:
        Tuple ``(ckpt, time_steps, n_samples)``: the checkpoint path prefix,
        the number of frames, and the number of samples. ``n_samples`` is
        ``time_steps * hop_size``, so it may be shorter than ``audio_length``.

    Raises:
        FileNotFoundError: If the model folder contains no ``model.ckpt`` file.
    """
    # Build checkpoint path.
    # Assumes only one checkpoint in the folder, 'model.ckpt-[iter]'.
    model_dir = os.path.join(CKPT_DIR,
                             "solo_%s_ckpt" % instrument_model.lower())
    ckpt_files = [
        f for f in tf.gfile.ListDirectory(model_dir) if "model.ckpt" in f
    ]
    if not ckpt_files:
        raise FileNotFoundError(
            "No 'model.ckpt' checkpoint found in %s" % model_dir)
    # Keep the first two dot-separated parts, e.g. 'model.ckpt-123.index'
    # -> 'model.ckpt-123'.
    ckpt_name = ".".join(ckpt_files[0].split(".")[:2])
    ckpt = os.path.join(model_dir, ckpt_name)

    # Parse gin config
    with gin.unlock_config():
        gin_file = os.path.join(model_dir, "operative_config-0.gin")
        gin.parse_config_file(gin_file, skip_unknown=True)

    # Ensure dimensions and sampling rates are equal: keep the training hop size.
    time_steps_train = gin.query_parameter("DefaultPreprocessor.time_steps")
    n_samples_train = gin.query_parameter("Additive.n_samples")
    hop_size = int(n_samples_train / time_steps_train)

    # Whole frames only; trailing samples that do not fill a frame are dropped.
    time_steps = int(audio_length / hop_size)
    n_samples = time_steps * hop_size

    gin_params = [
        "Additive.n_samples = {}".format(n_samples),
        "FilteredNoise.n_samples = {}".format(n_samples),
        "DefaultPreprocessor.time_steps = {}".format(time_steps),
    ]

    with gin.unlock_config():
        gin.parse_config(gin_params)

    return ckpt, time_steps, n_samples
コード例 #2
0
ファイル: inference.py プロジェクト: Startup-Data/SatLunNeh
  def configure_gin(self, ckpt):
    """Load the checkpoint's operative gin config and adapt it for streaming.

    After ``parse_operative_config(ckpt)`` the model is rebound to work one
    frame at a time:
    - the preprocessor emits a single time step;
    - both synthesizers produce exactly one frame's worth of samples;
    - the processor DAG only adds the harmonic and filtered-noise signals
      (the reverb processor is removed).
    """
    parse_operative_config(ckpt)

    # Samples in one frame = training samples / training frames.
    frames = gin.query_parameter('F0PowerPreprocessor.time_steps')
    total_samples = gin.query_parameter('Harmonic.n_samples')
    frame_size = int(total_samples / frames)

    # Processor DAG without the reverb processor.
    dag_binding = """ProcessorGroup.dag = [
    (@synths.Harmonic(),
      ['amps', 'harmonic_distribution', 'f0_hz']),
    (@synths.FilteredNoise(),
      ['noise_magnitudes']),
    (@processors.Add(),
      ['filtered_noise/signal', 'harmonic/signal']),
    ]"""

    streaming_bindings = [
        'F0PowerPreprocessor.time_steps = 1',
        f'Harmonic.n_samples = {frame_size}',
        f'FilteredNoise.n_samples = {frame_size}',
        dag_binding,
    ]

    with gin.unlock_config():
      gin.parse_config(streaming_bindings)
コード例 #3
0
ファイル: play.py プロジェクト: rystrauss/interact
def play(agent_dir, num_episodes, max_episode_steps, save_videos):
    """Run an agent restored from its best checkpoint for some episodes.

    The agent class, environment id and total timesteps come from the
    ``train.agent``, ``train.env_id`` and ``train.total_timesteps`` gin
    bindings. Weights are restored from the latest checkpoint in
    ``<agent_dir>/best-weights``. Actions are chosen with
    ``deterministic=True`` and the environment is rendered every step.

    Args:
        agent_dir: Training output directory containing ``best-weights``.
        num_episodes: Number of completed episodes after which to stop.
        max_episode_steps: Per-episode time limit passed to ``make_env_fn``.
        save_videos: If true, wrap the env in ``Monitor`` writing to
            ``<agent_dir>/monitor`` and record every episode.
    """
    agent = get_agent(gin.query_parameter("train.agent"))(make_env_fn(
        gin.query_parameter("train.env_id"),
        episode_time_limit=max_episode_steps))
    agent.pretrain_setup(gin.query_parameter("train.total_timesteps"))

    ckpt_path = tf.train.latest_checkpoint(
        os.path.join(agent_dir, "best-weights"))
    checkpoint = tf.train.Checkpoint(agent)
    checkpoint.restore(
        ckpt_path).assert_existing_objects_matched().expect_partial()

    env = agent.make_env()

    if save_videos:
        env = Monitor(
            env,
            os.path.join(agent_dir, "monitor"),
            video_callable=lambda _: True,
            force=True,
        )

    try:
        episodes = 0
        obs = env.reset()
        while episodes < num_episodes:
            # The agent takes a batch, so add and then strip a batch axis.
            action = agent.act(np.expand_dims(obs, 0),
                               deterministic=True).numpy()
            obs, _, done, _ = env.step(action[0])
            env.render()
            if done:
                obs = env.reset()
                episodes += 1
    except KeyboardInterrupt:
        # Ctrl-C just ends playback early.
        pass
    finally:
        # Always release the environment (and flush Monitor output), not only
        # when interrupted.
        env.close()
コード例 #4
0
ファイル: inference.py プロジェクト: zeeps31/ddsp
 def parse_gin_config(self, ckpt):
     """Parse the operative config next to ``ckpt`` and adapt it for streaming.

     Looks up the latest operative config in ``ckpt``'s directory with
     ``train_util.get_latest_operative_config``, prints its path and parses it
     (unknown bindings are skipped). It then rebinds the model to produce one
     frame per call:
     - a single preprocessor time step;
     - one frame of samples for each synthesizer;
     - a processor DAG without the reverb processor.
     All parsing happens inside ``gin.unlock_config()``.
     """
     with gin.unlock_config():
         config_path = train_util.get_latest_operative_config(
             os.path.dirname(ckpt))
         print(f'Parsing from operative_config {config_path}')
         gin.parse_config_file(config_path, skip_unknown=True)

         # Samples in one frame, derived from the training configuration.
         frames = gin.query_parameter('F0PowerPreprocessor.time_steps')
         total_samples = gin.query_parameter('Harmonic.n_samples')
         frame_size = int(total_samples / frames)

         # Streaming overrides; the DAG binding drops the reverb processor.
         streaming_bindings = [
             'F0PowerPreprocessor.time_steps=1',
             f'Harmonic.n_samples={frame_size}',
             f'FilteredNoise.n_samples={frame_size}',
             """ProcessorGroup.dag = [
   (@synths.Harmonic(),
     ['amps', 'harmonic_distribution', 'f0_hz']),
   (@synths.FilteredNoise(),
     ['noise_magnitudes']),
   (@processors.Add(),
     ['filtered_noise/signal', 'harmonic/signal']),
   ]""",
         ]
         gin.parse_config(streaming_bindings)
コード例 #5
0
def main(argv):
    """Parse the gin configuration, create the output directories and train.

    Args:
        argv: Positional command-line arguments left after flag parsing; only
            the program name itself is allowed.

    Raises:
        app.UsageError: If extra positional arguments were given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments for train.')

    if FLAGS.gin_config:
        # Run the experiments on a server: config files plus explicit bindings.
        gin.parse_config_files_and_bindings(FLAGS.gin_config,
                                            FLAGS.gin_bindings)
    else:
        # Run the experiments locally from a single config file.
        gin.parse_config_file(FLAGS.gin_config_file)

    # Create the `checkpoint`, `summary` and MCTS-data directories used during
    # training to save/restore the model and write TF summaries.
    checkpoint_dir = gin.query_parameter('checkpoint_dir/macro.value')
    summary_dir = gin.query_parameter('summary_dir/macro.value')
    required_dirs = (checkpoint_dir, summary_dir,
                     os.path.join(checkpoint_dir, 'mcts_data'))
    for path in required_dirs:
        directory_handling.ensure_dir_exists(path)

    ppo_train()
コード例 #6
0
    def configure_gin(self, ckpt):
        """Load the operative gin config for ``ckpt`` and adapt it for streaming.

        The frame count is read from whichever preprocessor is bound to
        ``Autoencoder.preprocessor``. ``Harmonic.n_samples`` falls back to the
        ``%n_samples`` macro when its bound value is not an ``int``.

        The model is then rebound as follows:
        - it uses an ``F0PowerPreprocessor`` with one time step;
        - both synthesizers produce one frame's worth of samples;
        - the processor DAG only adds the harmonic and filtered-noise signals.
        """
        parse_operative_config(ckpt)

        # Work out how many audio samples make up one frame.
        preprocessor = gin.query_parameter('Autoencoder.preprocessor')
        frames = gin.query_parameter(
            f'{preprocessor.scoped_selector}.time_steps')
        total_samples = gin.query_parameter('Harmonic.n_samples')
        if not isinstance(total_samples, int):
            # Presumably bound to a macro reference rather than a literal.
            total_samples = gin.query_parameter('%n_samples')
        frame_size = int(total_samples / frames)

        # Processor DAG without the reverb and crop processors.
        dag_binding = """ProcessorGroup.dag = [
    (@synths.Harmonic(),
      ['amps', 'harmonic_distribution', 'f0_hz']),
    (@synths.FilteredNoise(),
      ['noise_magnitudes']),
    (@processors.Add(),
      ['filtered_noise/signal', 'harmonic/signal']),
    ]"""

        with gin.unlock_config():
            gin.parse_config([
                'Autoencoder.preprocessor = @F0PowerPreprocessor()',
                'F0PowerPreprocessor.time_steps = 1',
                f'Harmonic.n_samples = {frame_size}',
                f'FilteredNoise.n_samples = {frame_size}',
                dag_binding,
            ])
コード例 #7
0
    def test_dry_run(self, config):
        """Dry-run ``rl_trainer.train_rl`` with one gin config file.

        The gin config, including constants, is cleared before ``config`` is
        parsed. The config is not run if ``RLTask.dm_suite`` is bound to a
        truthy value, or if ``RLTask.env`` starts with ``DM-`` or
        ``LunarLander``. Otherwise training is initialized with zero epochs,
        and any error is re-raised as an ``AssertionError`` naming the file.
        """
        gin.clear_config(clear_constants=True)
        gin.parse_config_file(config)

        # Some tests, ex: DM suite can't be run in OSS - so skip them.
        skip = False
        try:
            skip = gin.query_parameter('RLTask.dm_suite')
        except ValueError:
            pass
        try:
            env_name = gin.query_parameter('RLTask.env')
            skip = skip or env_name.startswith(('DM-', 'LunarLander'))
        except ValueError:
            pass

        if skip:
            return

        try:
            rl_trainer.train_rl(
                output_dir=self.create_tempdir().full_path,
                # Don't run any actual training, just initialize all classes.
                n_epochs=0,
                train_batch_size=1,
                eval_batch_size=1,
            )
        except Exception as err:
            raise AssertionError('Error in gin config {}.'.format(
                os.path.basename(config))) from err
コード例 #8
0
def main(argv):
    """Build the environment and agent from flags plus gin configs, then run.

    Gin config files are collected in this order:
    1. the per-environment entries of ``gin_configs`` under ``configs/`` next
       to this file;
    2. the experiment's saved config when restoring;
    3. any extra ``gin_files``.
    Without ``gpu`` the CNN builders are bound to ``channels_last``. ``test``
    forces a small run that restores an existing experiment and does not
    train.

    Args:
        argv: Unused; all options are read from ``flags.FLAGS``.
    """
    args = flags.FLAGS
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    if args.test:
        # Small, quick configuration that restores an existing experiment.
        args.envs, args.batch_sz, args.log_freq = 4, 4, 10
        args.restore = True

    expt = rvr.utils.Experiment(args.results_dir, args.env, args.agent,
                                args.experiment, args.restore)

    configs_dir = os.path.dirname(os.path.abspath(__file__)) + '/configs/'
    gin_files = [configs_dir + name for name in gin_configs.get(args.env, [])]
    if args.restore:
        gin_files.append(expt.config_path)
    gin_files.extend(args.gin_files)

    if not args.gpu:
        # No GPU: bind both CNN builders to the channels_last layout.
        args.gin_bindings.extend([
            "build_cnn_nature.data_format = 'channels_last'",
            "build_fully_conv.data_format = 'channels_last'",
        ])

    gin.parse_config_files_and_bindings(gin_files, args.gin_bindings)

    # TODO: do this the other way around - put these as gin bindings
    if not args.traj_len:
        args.traj_len = int(
            gin.query_parameter('AdvantageActorCriticAgent.traj_len'))
    if not args.batch_sz:
        args.batch_sz = int(
            gin.query_parameter('AdvantageActorCriticAgent.batch_sz'))

    env_cls = rvr.envs.GymEnv if '-v' in args.env else rvr.envs.SC2Env
    env = env_cls(args.env, args.render, max_ep_len=args.max_ep_len)

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    sess_mgr = rvr.utils.tensorflow.SessionManager(
        sess, expt.path, args.ckpt_freq, training_enabled=not args.test)

    agent = agent_cls[args.agent](
        env.obs_spec(), env.act_spec(),
        sess_mgr=sess_mgr, n_envs=args.envs,
        traj_len=args.traj_len, batch_sz=args.batch_sz)
    agent.logger = rvr.utils.StreamLogger(
        args.envs, args.log_freq, args.eps_avg, sess_mgr, expt.log_path)

    if sess_mgr.training_enabled:
        expt.save_gin_config()
        expt.save_model_summary(agent.model)

    total_steps = args.updates * args.traj_len * args.batch_sz // args.envs
    agent.run(env, total_steps)
コード例 #9
0
def setup_logger():
    """Configure the run directory, logging, device check and MLflow tracking.

    Steps:
    1. Append a ``run_<date>_<time>_<hostname>`` sub-directory to the
       ``multi_tasking_train.model_storage_directory`` gin binding, rebind the
       parameter to that path and create the directory.
    2. Reset the module logger ``log`` so it writes INFO messages to
       ``<run dir>/log.txt`` and to the console.
    3. Validate ``multi_tasking_train.device``.
    4. Start an MLflow run and log gin parameters, the full gin config (as
       ``config_log.txt`` in the current directory) and this source file.

    Raises:
        Exception: If the device is ``'cuda'`` but CUDA is unavailable, or if
            it is neither ``'cuda'`` nor ``'cpu'``.
    """
    # Set run-specific environment configuration.
    timestamp = time.strftime("run_%Y_%m_%d_%H_%M_%S") + "_{machine}".format(
        machine=socket.gethostname())

    run_dir = os.path.join(
        gin.query_parameter('multi_tasking_train.model_storage_directory'),
        timestamp)
    gin.bind_parameter('multi_tasking_train.model_storage_directory', run_dir)
    os.makedirs(run_dir, exist_ok=True)

    # Replace any existing handlers with a file handler and a console handler.
    log.handlers.clear()
    formatter = logging.Formatter('%(message)s')
    fh = logging.FileHandler(os.path.join(run_dir, "log.txt"))
    ch = logging.StreamHandler()
    for handler in (fh, ch):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        log.addHandler(handler)
    log.setLevel(logging.INFO)

    # Set global GPU state.
    device = gin.query_parameter('multi_tasking_train.device')
    if device == 'cuda':
        if not torch.cuda.is_available():
            # Previously reported as "Unrecognized device: cuda".
            raise Exception(
                "Device 'cuda' requested but CUDA is not available")
        log.info("Using CUDA device:{0}".format(torch.cuda.current_device()))
    elif device == 'cpu':
        log.info("Utilizing CPU")
    else:
        raise Exception(f"Unrecognized device: {device}")

    # ML-Flow
    mlflow.set_tracking_uri(
        f"{gin.query_parameter('multi_tasking_train.ml_flow_directory')}")
    mlflow.set_experiment(
        f"/{gin.query_parameter('multi_tasking_train.experiment_name')}")

    mlflow.start_run()
    # NOTE(review): reads gin's private `_CONFIG` and logs only the bindings
    # stored under its first key.
    gin_parameters = gin.config._CONFIG.get(list(gin.config._CONFIG.keys())[0])
    mlflow.log_params(gin_parameters)

    all_params = gin.config_str()
    with open('config_log.txt', 'w') as f:
        f.write(all_params)
    mlflow.log_artifact("config_log.txt")
    mlflow.log_artifact(__file__)
コード例 #10
0
def finetune_gin_bindings(config):
  """Build gin binding strings that adapt a trained config for finetuning.

  Environment-specific bindings are chosen by substring match on the
  environment name (``config.env_str``, or the ``%ENV_STR`` macro when that is
  empty):
  - Minitaur: friction and goal velocity;
  - Cube: the goal task;
  - DrunkSpider: action noise, action scale and goal.
  Options whose ``config`` value is falsy are not emitted. ``NUM_STEPS`` is
  always rebound to the current ``%NUM_STEPS`` plus the finetuning steps.

  Args:
    config: Object exposing ``env_str``, ``friction``, ``goal_vel``,
      ``finetune``, ``action_noise``, ``action_scale`` and ``finetune_steps``.

  Returns:
    List of gin binding strings.
  """
  gin_bindings = []
  env_str = config.env_str or gin.query_parameter('%ENV_STR')
  if 'Minitaur' in env_str:
    if config.friction:
      gin_bindings.append('minitaur.MinitaurGoalVelocityEnv.friction = {}'.format(config.friction))
    if config.goal_vel:
      gin_bindings.append("GOAL_VELOCITY = {}".format(config.goal_vel))
  elif 'Cube' in env_str:
    if config.finetune:
      gin_bindings.append("cube_env.SafemrlCubeEnv.goal_task = ('more_left', 'more_right', 'more_up', 'more_down')")
  elif 'DrunkSpider' in env_str:
    if config.action_noise:
      gin_bindings.append('point_mass.PointMassEnv.action_noise = {}'.format(config.action_noise))
    if config.action_scale:
      gin_bindings.append('point_mass.PointMassEnv.action_scale = {}'.format(config.action_scale))
    if config.finetune:
      gin_bindings.append("point_mass.env_load_fn.goal = (6, 3)")
  # Set NUM_STEPS to the previous value + FINETUNE_STEPS. gin raises ValueError
  # for an unbound macro, so treat "unbound" like "bound to None" and fall back
  # to config.finetune_steps in both cases.
  try:
    ft_steps = gin.query_parameter("%FINETUNE_STEPS")
  except ValueError:
    ft_steps = None
  if ft_steps is None:
    ft_steps = config.finetune_steps
  num_steps = gin.query_parameter("%NUM_STEPS") + ft_steps
  gin_bindings.append('NUM_STEPS = {}'.format(num_steps))
  return gin_bindings
コード例 #11
0
def get_parallelized_combinations(varying_type: str):
    """Build the parameter-value combinations to run in parallel.

    ``%VARYING_PARS`` lists the names of the varied gin parameters. Each
    name's bound value is queried and used as a sequence of candidate values.

    Args:
        varying_type: How to combine the candidate values:
            - ``"combination"`` / ``"chunk"``: full Cartesian product;
            - ``"ordered_combination"``: element-wise ``zip``;
            - ``"random_search"``: ``%NUM_CORES`` combinations drawn uniformly,
              with replacement, from the Cartesian product.

    Returns:
        ``(variables, num_cores)``: the list of value tuples and its length.

    Raises:
        SystemExit: With a non-zero status for an unknown ``varying_type``.
    """
    lst = [
        gin.query_parameter(v) for v in gin.query_parameter("%VARYING_PARS")
    ]

    if varying_type in ("combination", "chunk"):
        variables = list(itertools.product(*lst))
    elif varying_type == "ordered_combination":
        variables = list(zip(*lst))
    elif varying_type == "random_search":
        grid = list(itertools.product(*lst))
        # randint's `high` is exclusive: use len(grid) so the last combination
        # can be drawn and a single-combination grid does not raise.
        picks = np.random.randint(0, len(grid),
                                  gin.query_parameter("%NUM_CORES"))
        variables = [grid[i] for i in picks]
    else:
        # sys.exit with a message prints it to stderr and exits with status 1.
        sys.exit("Choose proper way to combine varying parameters")

    num_cores = len(variables)

    return variables, num_cores
コード例 #12
0
def load(env_name):
    """Create the training and evaluation environments for ``env_name``.

    The raw observation is the current state followed by the goal state, each
    of length ``obs_dim``. Optional gin bindings ``obs_to_goal.start_index``
    and ``obs_to_goal.end_index`` (defaults ``0`` and ``obs_dim``) select a
    slice of the goal. The returned environments observe the full state plus
    only that goal slice.

    Args:
        env_name: (str) One of 'sawyer_reach', 'sawyer_push', 'sawyer_drawer',
            'sawyer_window' or 'sawyer_faucet'.

    Returns:
        tf_env, eval_tf_env, obs_dim: The filtered training and evaluation
        environments and the size of each observation half.

    Raises:
        NotImplementedError: For any other ``env_name``.
    """
    if env_name == 'sawyer_reach':
        make_env = load_sawyer_reach
    elif env_name == 'sawyer_push':
        make_env = load_sawyer_push
    elif env_name == 'sawyer_drawer':
        make_env = load_sawyer_drawer
    elif env_name == 'sawyer_window':
        make_env = load_sawyer_window
    elif env_name == 'sawyer_faucet':
        make_env = load_sawyer_faucet
    else:
        raise NotImplementedError('Unsupported environment: %s' % env_name)

    tf_env = make_env()
    eval_tf_env = make_env()
    if env_name == 'sawyer_push':
        eval_tf_env.envs[0]._env.gym.MODE = 'eval'  # pylint: disable=protected-access
    assert len(tf_env.envs) == 1
    assert len(eval_tf_env.envs) == 1

    # Observation = [state, goal]; keep the whole state but only the
    # user-selected goal dimensions (obs_to_goal.* gin bindings).
    obs_dim = tf_env.observation_spec().shape[0] // 2
    try:
        start_index = gin.query_parameter('obs_to_goal.start_index')
    except ValueError:
        start_index = 0
    try:
        end_index = gin.query_parameter('obs_to_goal.end_index')
    except ValueError:
        end_index = None
    if end_index is None:
        end_index = obs_dim

    state_indices = np.arange(obs_dim)
    goal_indices = np.arange(obs_dim + start_index, obs_dim + end_index)
    indices = np.concatenate([state_indices, goal_indices])

    def _filtered(batched_env):
        return tf_py_environment.TFPyEnvironment(
            wrappers.ObservationFilterWrapper(batched_env.envs[0], indices))

    return (_filtered(tf_env), _filtered(eval_tf_env), obs_dim)
コード例 #13
0
ファイル: env_utils.py プロジェクト: Danmou/MerCur-Re
def construct_envs(
    config: Config, training: bool
) -> VectorEnv:
    r"""Create a habitat VectorEnv with one environment per process.

    The dataset's scenes are shuffled and dealt round-robin across
    ``config.NUM_PROCESSES`` environments, so each one loads only its subset.
    Environment kwargs come from ``get_config`` plus the ``max_length`` and
    ``wrappers`` gin bindings of ``habitat_train_task`` or
    ``habitat_eval_task``.

    Args:
        config: configs that contain NUM_PROCESSES and TASK_CONFIG.DATASET
            (used to build the dataset and list its scenes).
        training: selects the ``habitat_train_task`` (True) or
            ``habitat_eval_task`` (False) gin scope.

    Returns:
        VectorEnv object created according to specification.

    Raises:
        AssertionError: if there are scenes but fewer than NUM_PROCESSES.
    """
    num_processes = config.NUM_PROCESSES
    dataset = make_dataset(config.TASK_CONFIG.DATASET.TYPE)
    scenes = dataset.get_scenes_to_load(config.TASK_CONFIG.DATASET)

    if len(scenes) > 0:
        random.shuffle(scenes)

        assert len(scenes) >= num_processes, (
            "reduce the number of processes as there "
            "aren't enough number of scenes"
        )

    # Deal scenes round-robin; an empty dataset leaves every split empty.
    scene_splits = [[] for _ in range(num_processes)]
    for idx, scene in enumerate(scenes):
        scene_splits[idx % num_processes].append(scene)

    assert sum(map(len, scene_splits)) == len(scenes)

    task = 'habitat_train_task' if training else 'habitat_eval_task'
    max_duration = gin.query_parameter(f'{task}.max_length')
    # Each gin-bound wrapper's scoped_configurable_fn() yields a
    # (wrapper, kwarg_fn) pair.
    wrapper_specs = [w.scoped_configurable_fn()
                     for w in gin.query_parameter(f'{task}.wrappers')]
    kwargs = get_config(training=training, max_steps=max_duration * 3)
    kwargs['max_duration'] = max_duration
    kwargs['action_repeat'] = 1
    kwargs['wrappers'] = [(wrapper, kwarg_fn(kwargs))
                          for wrapper, kwarg_fn in wrapper_specs]

    env_kwargs = []
    for split in scene_splits:
        # Shallow copy; only the 'config' entry is replaced per environment.
        kw = kwargs.copy()
        env_config = kw['config'].clone()
        if split:
            env_config.defrost()
            env_config.DATASET.CONTENT_SCENES = split
            env_config.freeze()
        kw['config'] = env_config
        env_kwargs.append(kw)

    envs = habitat.VectorEnv(
        make_env_fn=make_env_fn,
        env_fn_args=tuple(zip(env_kwargs, range(num_processes))),
    )
    return envs
コード例 #14
0
    def load_model(self, model):
        """Load a pretrained DDSP autoencoder sized to ``self.audio``.

        Steps:
        1. Parse ``pretrained/operative_config-0.gin``.
        2. Rebind the synthesizer and preprocessor sizes so the training hop
           size is kept for the length of ``self.audio``.
        3. Trim ``self.audio_features`` to match.
        4. Restore the checkpoint found in ``pretrained/`` into ``self.model``.
        5. Run one forward pass to build the model.

        Args:
            model: One of 'Violin', 'Flute', 'Flute2', 'Trumpet' or
                'Tenor_Saxophone'; stored in ``self.model_name``.
                NOTE(review): every name reads the same ``pretrained`` folder.

        Raises:
            ValueError: If ``model`` is not one of the pretrained models.
            FileNotFoundError: If the model folder has no checkpoint file.
        """
        self.model_name = model
        if model not in ('Violin', 'Flute', 'Flute2', 'Trumpet',
                         'Tenor_Saxophone'):
            # Uploaded checkpoints are not handled; without this check the
            # code below would fail on undefined `gin_file` / `model_dir`.
            raise ValueError('Unsupported model: %s' % model)

        # Pretrained models.
        model_dir = 'pretrained'
        gin_file = os.path.join(model_dir, 'operative_config-0.gin')

        # Parse gin config.
        with gin.unlock_config():
            gin.parse_config_file(gin_file, skip_unknown=True)

        # Assumes only one checkpoint in the folder, 'ckpt-[iter]'.
        ckpt_files = [f for f in tf.io.gfile.listdir(model_dir) if 'ckpt' in f]
        if not ckpt_files:
            raise FileNotFoundError('No checkpoint found in %s' % model_dir)
        ckpt_name = ckpt_files[0].split('.')[0]
        ckpt = os.path.join(model_dir, ckpt_name)

        # Ensure dimensions and sampling rates are equal: keep the training
        # hop size and use only whole frames of self.audio.
        time_steps_train = gin.query_parameter(
            'DefaultPreprocessor.time_steps')
        n_samples_train = gin.query_parameter('Additive.n_samples')
        hop_size = int(n_samples_train / time_steps_train)
        time_steps = int(self.audio.shape[1] / hop_size)
        n_samples = time_steps * hop_size
        gin_params = [
            'Additive.n_samples = {}'.format(n_samples),
            'FilteredNoise.n_samples = {}'.format(n_samples),
            'DefaultPreprocessor.time_steps = {}'.format(time_steps),
        ]

        with gin.unlock_config():
            gin.parse_config(gin_params)

        # Trim all input vectors to correct lengths.
        for key in ['f0_hz', 'f0_confidence', 'loudness_db']:
            self.audio_features[key] = self.audio_features[key][:time_steps]
        self.audio_features['audio'] = self.audio_features[
            'audio'][:, :n_samples]

        # Set up the model just to predict audio given new conditioning.
        self.model = ddsp.training.models.Autoencoder()
        self.model.restore(ckpt)

        # Build model by running a batch through it.
        start_time = time.time()
        _ = self.model(self.audio_features, training=False)
        print('Restoring model took %.1f seconds' % (time.time() - start_time))
コード例 #15
0
    def set_up_training(self):
        """Simulate data and build everything needed to train a PPO agent.

        The following are created in order:
        1. ``self.data_handler``: returns are generated, and parameters are
           also estimated unless ``experiment_type`` is ``"GP"``;
        2. ``self.action_space``;
        3. the market environment ``self.env``;
        4. the PPO agent ``self.train_agent``;
        5. the out-of-sample tester ``self.oos_test``.
        """
        self.logging.debug("Simulating Data")

        self.data_handler = DataHandler(N_train=self.len_series, rng=self.rng)
        self.data_handler.generate_returns()
        if self.experiment_type != "GP":
            # TODO check if these method really fit and change the parameters in the gin file
            self.data_handler.estimate_parameters()

        self.logging.debug("Instantiating action space")
        if self.MV_res:
            self.action_space = ResActionSpace()
        else:
            action_range, _, _ = get_action_boundaries(
                N_train=self.N_train,
                f_speed=self.data_handler.f_speed,
                returns=self.data_handler.returns,
                factors=self.data_handler.factors,
            )

            # Overwrite the first entry of the %ACTION_RANGE macro's value in
            # place before building ActionSpace().
            # NOTE(review): relies on gin.query_parameter returning the live
            # bound list rather than a copy.
            gin.query_parameter("%ACTION_RANGE")[0] = action_range
            self.action_space = ActionSpace()

        self.logging.debug("Instantiating market environment")
        self.env = self.env_cls(
            N_train=self.N_train,
            f_speed=self.data_handler.f_speed,
            returns=self.data_handler.returns,
            factors=self.data_handler.factors,
        )

        # The agent built below is PPO (the old message said DQN).
        self.logging.debug("Instantiating PPO model")
        input_shape = self.env.get_state_dim()

        # PPO step size = (series length / batch size) * epochs.
        step_size = (self.len_series / gin.query_parameter("PPO.batch_size")
                     ) * gin.query_parameter("%EPOCHS")
        gin.bind_parameter("PPO.step_size", step_size)
        self.train_agent = PPO(input_shape=input_shape,
                               action_space=self.action_space,
                               rng=self.rng)

        self.train_agent.model.to(self.device)

        self.logging.debug("Instantiating Out of sample tester")
        self.oos_test = Out_sample_vs_gp(
            savedpath=self.savedpath,
            tag="PPO",
            experiment_type=self.experiment_type,
            env_cls=self.env_cls,
            MV_res=self.MV_res,
        )

        self.oos_test.init_series_to_fill(iterations=self.col_names_oos)
コード例 #16
0
def main_runner(configs_path: str, algo: str):
    """Run a single experiment or a set of parallelized experiments.

    If the ``%VARYING_PARS`` gin macro is not None, every combination from
    ``get_parallelized_combinations(%VARYING_TYPE)`` is passed to
    ``parallel_exps`` via joblib:
    - ``"random_search"`` runs all sampled combinations at once, then
      re-executes this script with ``os.execv``;
    - ``"chunk"`` runs the combinations in batches of ``%NUM_CORES``.
    Otherwise a single runner is built and run.

    Parameters
    ----------
    configs_path: str
        Path where the config files are stored.
        NOTE(review): unused here; the parallel branches pass ``gin_path``,
        which is not defined in this function — confirm it is a module-level
        global, or whether ``configs_path`` was intended.

    algo: str
        Acronym of the algorithm to run: ``"DQN"`` or ``"PPO"``.

    Raises
    ------
    ValueError
        If ``algo`` is not ``"DQN"`` or ``"PPO"``.
    SystemExit
        With a non-zero status if ``%VARYING_TYPE`` is neither
        ``"random_search"`` nor ``"chunk"``.
    """

    # get runner to do the experiments
    if algo == "DQN":
        func = DQN_runner
    elif algo == "PPO":
        func = PPO_runner
    else:
        # Without this, `func` would be unbound below.
        raise ValueError(f"Unknown algo {algo!r}: expected 'DQN' or 'PPO'.")

    # launch runner (either parallelized or not)
    if gin.query_parameter("%VARYING_PARS") is not None:
        # get varying parameters, combinations and cores
        varying_type = gin.query_parameter("%VARYING_TYPE")
        varying_par_to_change = gin.query_parameter("%VARYING_PARS")
        combinations, num_cores = get_parallelized_combinations(varying_type)

        # choose way to parallelize
        if varying_type == "random_search":
            Parallel(n_jobs=num_cores)(delayed(parallel_exps)(
                var_par, varying_par_to_change, gin_path, func=func)
                                       for var_par in combinations)
            time.sleep(5)
            # Restart this script from scratch with the same arguments.
            os.execv(sys.executable, ["python"] + sys.argv)
        elif varying_type == "chunk":
            num_cores = gin.query_parameter("%NUM_CORES")
            for chunk_var in chunks(combinations, num_cores):
                Parallel(n_jobs=num_cores)(delayed(parallel_exps)(
                    var_par, varying_par_to_change, gin_path, func=func)
                                           for var_par in chunk_var)
                time.sleep(5)
        else:
            # Exit with status 1 (message on stderr) instead of a silent 0.
            sys.exit("Choose proper way to parallelize.")
    else:
        model_runner = func()
        model_runner.run()
コード例 #17
0
def main():
    """Resynthesize an audio file with a trained DDSP autoencoder.

    Usage: ``<script> AUDIO_PATH MODEL_DIR RESULT_PATH``.

    Steps:
    1. Decode the audio at ``DEFAULT_SAMPLE_RATE`` and compute its features.
    2. Parse ``MODEL_DIR/operative_config-0.gin`` and rebind it so the
       training hop size is kept for this audio length.
    3. Trim the features to whole frames.
    4. Restore the model from ``MODEL_DIR``.
    5. Write the model output to ``RESULT_PATH`` via ``outputToWav``.
    """
    if len(sys.argv) < 4:
        sys.exit('usage: %s AUDIO_PATH MODEL_DIR RESULT_PATH' % sys.argv[0])
    AUDIO_PATH = sys.argv[1]
    MODEL_DIR = sys.argv[2]
    RESULT_PATH = sys.argv[3]
    print(sys.argv)

    # Read the file with a context manager so the handle is closed.
    with open(AUDIO_PATH, "rb") as audio_file:
        audio_bytes = audio_file.read()
    audio = audio_bytes_to_np(audio_bytes,
                              sample_rate=DEFAULT_SAMPLE_RATE,
                              normalize_db=None)
    # Add a leading batch axis.
    audio = audio[np.newaxis, :]
    audio_features = ddsp.training.eval_util.compute_audio_features(audio)
    audio_features['loudness_db'] = audio_features['loudness_db'].astype(
        np.float32)

    # Parse the gin config.
    gin_file = os.path.join(MODEL_DIR, 'operative_config-0.gin')
    gin.parse_config_file(gin_file)

    # Ensure dimensions and sampling rates are equal.
    time_steps_train = gin.query_parameter('DefaultPreprocessor.time_steps')
    n_samples_train = gin.query_parameter('Additive.n_samples')
    hop_size = int(n_samples_train / time_steps_train)

    # Whole frames only; trailing samples that do not fill a frame are dropped.
    time_steps = int(audio.shape[1] / hop_size)
    n_samples = time_steps * hop_size

    print("===Trained model===")
    print("Time Steps", time_steps_train)
    print("Samples", n_samples_train)
    print("Hop Size", hop_size)
    print("\n===Resynthesis===")
    print("Time Steps", time_steps)
    print("Samples", n_samples)
    print('')

    gin_params = [
        'Additive.n_samples = {}'.format(n_samples),
        'FilteredNoise.n_samples = {}'.format(n_samples),
        'DefaultPreprocessor.time_steps = {}'.format(time_steps),
    ]

    with gin.unlock_config():
        gin.parse_config(gin_params)

    # Trim all input vectors to correct lengths. The audio trim belongs
    # outside the per-feature loop (it used to run once per key).
    for key in ['f0_hz', 'f0_confidence', 'loudness_db']:
        audio_features[key] = audio_features[key][:time_steps]
    audio_features['audio'] = audio_features['audio'][:, :n_samples]

    # Load model.
    model = ddsp.training.models.Autoencoder()
    model.restore(MODEL_DIR)

    # Resynthesize audio.
    audio_gen = model(audio_features, training=False)
    outputToWav(audio_gen, RESULT_PATH)
コード例 #18
0
ファイル: state_test.py プロジェクト: jackd/more-keras
 def test_gin_state(self):
     """A GinState scope isolates bindings and can be re-entered later."""
     key = 'f.x'
     gin.bind_parameter(key, 'global')
     self.assertEqual(gin.query_parameter(key), 'global')

     with GinState() as scoped_state:
         # Inside the new state, f() does not see the global binding.
         self.assertEqual(f()[0], 'default')
         gin.bind_parameter(key, 'temp')
         self.assertEqual(gin.query_parameter(key), 'temp')

     # Leaving the scope restores the global binding...
     self.assertEqual(gin.query_parameter(key), 'global')
     # ...and re-entering the same state brings the scoped binding back.
     with scoped_state:
         self.assertEqual(gin.query_parameter(key), 'temp')
コード例 #19
0
def _default_output_dir():
  """Default output directory: ``~/trax/<model>_<dataset>_<YYYYmmdd_HHMM>``.

  ``<model>`` is the name of the configurable bound to ``train.model``.
  ``<dataset>`` is the ``inputs.dataset_name`` binding, or ``'random'`` when
  that parameter is unbound. The ``~`` is not expanded here.
  """
  # gin raises ValueError for an unbound parameter; fall back instead of
  # crashing when no dataset is configured.
  try:
    dataset_name = gin.query_parameter("inputs.dataset_name")
  except ValueError:
    dataset_name = "random"
  dir_name = "{model_name}_{dataset_name}_{timestamp}".format(
      model_name=gin.query_parameter("train.model").configurable.name,
      dataset_name=dataset_name,
      timestamp=datetime.datetime.now().strftime("%Y%m%d_%H%M"),
  )
  dir_path = os.path.join("~", "trax", dir_name)
  print()
  trax.log("No --output_dir specified")
  return dir_path
コード例 #20
0
def run(**kwargs) -> None:
    """Run a T5 model for training, finetuning, evaluation etc.

    Disables TF2 behavior and delegates to ``utils.run(**kwargs)``. In
    ``"eval"`` mode the recursion limit is raised first, to
    ``1024 * <inputs sequence length> + 10`` (the length defaults to 512).
    """
    tf.disable_v2_behavior()

    is_eval = gin.query_parameter("utils.run.mode") == "eval"
    if is_eval:
        # Increase the recursion limit, see: https://github.com/pltrdy/rouge/issues/19
        lengths = gin.query_parameter("utils.run.sequence_length")
        input_length = lengths.get("inputs", 512)
        batch_size = 1024  # TODO: do not hardcode batch_size for recursionlimit calc
        sys.setrecursionlimit(batch_size * input_length + 10)

    utils.run(**kwargs)
コード例 #21
0
 def test_acc(_):
     """
     Report testing accuracy.

     Evaluates ``model`` on the gin-configured testing quakes, logs the
     accuracy under the ``Accurarcy/test`` scalar at the trainer's current
     epoch, and prints the accuracy and confusion matrix.
     NOTE(review): the tag and message keep the existing "Accurarcy"
     spelling; renaming the tag would start a new scalar series.

     :param _: unused
     :return: None
     """
     testing_quakes = gin.query_parameter(
         'triggered_earthquake_dataset.testing_quakes')
     data_dir = gin.query_parameter('triggered_earthquake_dataset.data_dir')
     acc, cm, _ = test_classification(model, testing_quakes, device, data_dir)

     writer.add_scalar('Accurarcy/test', acc, trainer.state.epoch)
     print('Testing Accurarcy: {:.2f}'.format(acc))
     print(cm)
コード例 #22
0
ファイル: trainer.py プロジェクト: koz4k2/trax
def _default_output_dir():
    """Return ``~/trax/<model>_<dataset>_<YYYYmmdd_HHMM>`` as the output dir.

    ``<dataset>`` is the ``inputs.dataset_name`` gin binding, or ``'random'``
    when it is unbound. ``<model>`` is the name of the configurable bound to
    ``train.model``. Also logs that no ``--output_dir`` was given.
    """
    try:
        dataset_name = gin.query_parameter('inputs.dataset_name')
    except ValueError:
        dataset_name = 'random'

    model_name = gin.query_parameter('train.model').configurable.name
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M')
    dir_name = '{model_name}_{dataset_name}_{timestamp}'.format(
        model_name=model_name,
        dataset_name=dataset_name,
        timestamp=timestamp,
    )

    print()
    trainer_lib.log('No --output_dir specified')
    return os.path.join('~', 'trax', dir_name)
コード例 #23
0
ファイル: run.py プロジェクト: stjordanis/reaver-pysc2
def main(argv):
    """Build the environment and agent from flags plus gin configs, then run.

    Config files are collected in this order:
    1. the per-environment entries of ``gin_configs`` under ``configs/`` next
       to this file;
    2. the experiment's saved config when restoring;
    3. any extra ``gin_files``.
    With no ``gpu`` the CNN builders are bound to ``channels_last``. ``test``
    forces a small run that restores an existing experiment and does not
    train.

    Args:
        argv: Unused; options come from ``flags.FLAGS``.
    """
    args = flags.FLAGS
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    if args.test:
        # Quick, restore-only configuration.
        args.envs, args.batch_sz, args.log_freq = 4, 4, 10
        args.restore = True

    expt = rvr.utils.Experiment(args.results_dir, args.env, args.agent,
                                args.experiment, args.restore)

    configs_dir = os.path.dirname(os.path.abspath(__file__)) + '/configs/'
    gin_files = [configs_dir + name for name in gin_configs.get(args.env, [])]
    if args.restore:
        gin_files.append(expt.config_path)
    gin_files.extend(args.gin_files)

    if not args.gpu:
        # No GPU: bind both CNN builders to the channels_last layout.
        args.gin_bindings.extend([
            "build_cnn_nature.data_format = 'channels_last'",
            "build_fully_conv.data_format = 'channels_last'",
        ])

    gin.parse_config_files_and_bindings(gin_files, args.gin_bindings)

    # TODO: do this the other way around - put these as gin bindings
    if not args.traj_len:
        args.traj_len = int(
            gin.query_parameter('AdvantageActorCriticAgent.traj_len'))
    if not args.batch_sz:
        args.batch_sz = int(
            gin.query_parameter('AdvantageActorCriticAgent.batch_sz'))

    env_cls = rvr.envs.GymEnv if '-v' in args.env else rvr.envs.SC2Env
    env = env_cls(args.env, args.render, max_ep_len=args.max_ep_len)

    session_config = tf.ConfigProto(allow_soft_placement=True)
    sess_mgr = rvr.utils.tensorflow.SessionManager(
        tf.Session(config=session_config), expt.path, args.ckpt_freq,
        training_enabled=not args.test)

    agent = agent_cls[args.agent](
        env.obs_spec(), env.act_spec(),
        sess_mgr=sess_mgr, n_envs=args.envs,
        traj_len=args.traj_len, batch_sz=args.batch_sz)
    agent.logger = rvr.utils.StreamLogger(
        args.envs, args.log_freq, args.eps_avg, sess_mgr, expt.log_path)

    if sess_mgr.training_enabled:
        expt.save_gin_config()
        expt.save_model_summary(agent.model)

    n_steps = args.updates * args.traj_len * args.batch_sz // args.envs
    agent.run(env, n_steps)
コード例 #24
0
def fetch_grid_results_dict_from_neptune(results_dict: GridResultDict,
                                         study_name: str) -> None:
    """Fill ``results_dict`` in place with per-seed results of a Neptune study.

    The Neptune project is ``<neptune.user_name>/<neptune.project_name>`` as
    bound in gin. For every row of the study dataframe, the row's values for
    the hyperparameter columns (the keys of ``get_hps_dict()``) form a
    frozenset key. The run's ``results.json`` is stored under
    ``results_dict[key][seed]``, where ``seed`` comes from ``data.split_seed``.
    ``results_dict`` must yield a mutable mapping for unseen keys
    (presumably a ``defaultdict(dict)`` — confirm with callers).
    """
    import neptune

    user = gin.query_parameter('neptune.user_name')
    project_name = gin.query_parameter('neptune.project_name')
    project = neptune.init('{}/{}'.format(user, project_name))

    study_df = download_dataframe_from_neptune(project, study_name)
    hp_columns = list(get_hps_dict().keys())

    for _, row in study_df.iterrows():
        hp_values = dict(row[hp_columns])
        seed = int(row['data.split_seed'])
        run_results = download_results_from_neptune(project, row['id'],
                                                    'results.json')
        results_dict[frozenset(hp_values.items())][seed] = run_results
コード例 #25
0
def get_params_dict() -> Dict[str, Any]:
    """Return the gin-bound ``optuna.params`` dict with top-level references resolved.

    Values that are ``gin.config.ConfigurableReference`` instances are replaced
    by ``value.__deepcopy__(None)``. In gin this presumably evaluates references
    written with ``()`` and returns the reference itself otherwise — confirm
    against the installed gin version. All other values are passed through
    unchanged. Nested containers are not traversed.
    """
    raw_params: dict = gin.query_parameter('optuna.params')
    resolved: Dict[str, Any] = {}
    for name, value in raw_params.items():
        if isinstance(value, gin.config.ConfigurableReference):
            value = value.__deepcopy__(None)
        resolved[name] = value
    return resolved
コード例 #26
0
    def save_metadata(_):
        '''
        Save a metadata file, used for inference.

        Pickles a transformer built with ``random_trim_offset=False`` to
        ``<model_dir>/transformer.p``. Then writes ``<model_dir>/metadata.json``
        holding the run name, class labels (gin
        ``triggered_earthquake_dataset.labels``), the last checkpoint path, the
        classifier path, the embedding size, the layer count and the
        transformer path.

        :param _: ignored (presumably the event-handler's engine argument).
        :return: None
        '''
        transformer = triggered_earthquake_transform(random_trim_offset=False)
        transformer_path = os.path.join(model_dir, 'transformer.p')
        # Close (and thereby flush) the pickle before metadata.json, which
        # references it, is written; the previous bare open() leaked the handle.
        with open(transformer_path, 'wb') as transformer_file:
            pickle.dump(transformer, transformer_file)

        metadata = {
            'name':
            run_name,
            'classes':
            gin.query_parameter('triggered_earthquake_dataset.labels'),
            'model_state_path':
            save_handler.last_checkpoint,
            'classifier_path':
            os.path.join(model_dir, '{}_classifier.p'.format(prefix)),
            'embedding_size':
            embedding_size,
            'num_layers':
            num_layers,
            'transformer':
            transformer_path
        }

        with open(os.path.join(model_dir, 'metadata.json'), 'w') as f:
            json.dump(metadata, f)
コード例 #27
0
    def report_embeddings(_):
        """
        Write train and test embeddings to TensorBoard.

        For each split, embeddings are computed with batch size 1. Each label
        row is mapped to a text label via argmax into the gin-bound
        ``triggered_earthquake_dataset.labels``. The result is logged with
        ``writer.add_embedding`` under the current epoch, tagged
        ``train_embeddings`` / ``test_embeddings``.

        :param _: ignored (presumably the event-handler's engine argument).
        :return: None
        """
        splits = (
            ('train_embeddings', DataLoader(ds_train, batch_size=1)),
            ('test_embeddings', DataLoader(ds_test, batch_size=1)),
        )
        label_names = gin.query_parameter(
            'triggered_earthquake_dataset.labels')

        for tag, loader in splits:
            embeddings, label_rows = get_embeddings(model,
                                                    loader,
                                                    device=device)
            named_labels = [
                label_names[np.argmax(row)] for row in label_rows.squeeze(1)
            ]
            writer.add_embedding(embeddings.squeeze(1),
                                 metadata=named_labels,
                                 global_step=trainer.state.epoch,
                                 tag=tag)
コード例 #28
0
def main(_):
    """Play back a trained policy in a single, non-parallel environment.

    Sets the random seed via ``common.set_random_seed``. Parses the gin file
    plus ``--gin_param`` bindings and builds the algorithm from the gin-bound
    ``TrainerConfig.algorithm_ctor``. Then calls ``policy_trainer.play`` with
    the playback flags. ``--ignored_parameter_prefixes`` is a comma-separated
    list; empty means none. The environment is closed on every exit path,
    including failures during setup.
    """
    seed = common.set_random_seed(FLAGS.random_seed)
    gin_file = common.get_gin_file()
    gin.parse_config_files_and_bindings(gin_file, FLAGS.gin_param)
    algorithm_ctor = gin.query_parameter(
        'TrainerConfig.algorithm_ctor').scoped_configurable_fn
    env = create_environment(nonparallel=True, seed=seed)
    # Guard everything after creation: previously reset / transformer /
    # algorithm construction ran outside the try and leaked env on error.
    try:
        env.reset()
        common.set_global_env(env)
        config = policy_trainer.TrainerConfig(root_dir="")
        data_transformer = create_data_transformer(
            config.data_transformer_ctor, env.observation_spec())
        config.data_transformer = data_transformer
        observation_spec = data_transformer.transformed_observation_spec
        common.set_transformed_observation_spec(observation_spec)
        algorithm = algorithm_ctor(
            observation_spec=observation_spec,
            action_spec=env.action_spec(),
            config=config)
        ignored_prefixes = (FLAGS.ignored_parameter_prefixes.split(",")
                            if FLAGS.ignored_parameter_prefixes else [])
        policy_trainer.play(
            FLAGS.root_dir,
            env,
            algorithm,
            checkpoint_step=FLAGS.checkpoint_step or "latest",
            epsilon_greedy=FLAGS.epsilon_greedy,
            num_episodes=FLAGS.num_episodes,
            max_episode_length=FLAGS.max_episode_length,
            sleep_time_per_step=FLAGS.sleep_time_per_step,
            record_file=FLAGS.record_file,
            ignored_parameter_prefixes=ignored_prefixes)
    finally:
        env.close()
コード例 #29
0
def log_vote(instance, log_file=None, guess=False):
    """Write ``instance.to_json()`` plus a newline to a vote log under a portalocker lock.

    When ``log_file`` is None the name is derived from the gin-bound
    ``Codenamer.name``: ``votes-<name>.json``, or ``votes-<name>-guess.json``
    when ``guess`` is true. The line is flushed and fsync'ed before the lock is
    released. NOTE(review): portalocker.Lock's default open mode is presumably
    append — confirm against the installed version.
    """
    if log_file is None:
        suffix = '-guess' if guess else ''
        log_file = 'votes-' + gin.query_parameter('Codenamer.name') + suffix + '.json'

    with portalocker.Lock(log_file) as locked_file:
        locked_file.write(instance.to_json() + "\n")
        # Push the line through Python's buffer and then the OS cache to disk
        # while the lock is still held.
        locked_file.flush()
        os.fsync(locked_file.fileno())
コード例 #30
0
def get_vocabulary(mixture_or_task_name=None):
  """Return vocabulary from the mixture or task.

  If no name is given, the gin macro ``%MIXTURE_NAME`` is tried. When the
  provider's output features include both "inputs" and "targets", a
  ``(inputs_vocab, targets_vocab)`` tuple is returned. Otherwise the single
  vocabulary shared by all features is returned. The t5 default vocabulary is
  returned when no name is available, or when features disagree (a warning is
  logged).
  """
  if not mixture_or_task_name:
    # Fall back to the mixture/task name recorded in the gin config.
    try:
      mixture_or_task_name = gin.query_parameter("%MIXTURE_NAME")
    except ValueError:
      logging.warning("Could not extract mixture/task name from gin config.")

  if not mixture_or_task_name:
    logging.warning("Using default vocabulary.")
    return t5.data.get_default_vocabulary()

  features = t5.data.get_mixture_or_task(mixture_or_task_name).output_features
  if "inputs" in features and "targets" in features:
    return (features["inputs"].vocabulary, features["targets"].vocabulary)

  feature_values = list(features.values())
  vocabulary = feature_values[0].vocabulary
  # any() stops at the first feature whose vocabulary differs.
  if any(feature.vocabulary != vocabulary for feature in feature_values[1:]):
    logging.warning("No feature_name was provided to get_vocabulary, but "
                    "output_features have different vocabularies.")
    vocabulary = None

  if vocabulary:
    return vocabulary
  logging.warning("Using default vocabulary.")
  return t5.data.get_default_vocabulary()
コード例 #31
0
    def update_communicator(self):
        """Set current parameters and update existing ones.

        Does nothing when ``self.communicator`` is None. Otherwise it performs
        these steps:

        - Reports the value of each gin parameter the communicator asks for.
        - Applies every pending ``(key, value)`` update to ``self._config``.
        - Publishes the flattened config back to the communicator.
        - Re-applies gin variables via ``self.set_gin_variables()``.

        This is a best-effort sync: any exception is logged and reported, and
        ``self._config`` is rolled back to a deep copy taken before the call.
        """
        if self.communicator is None:
            return

        config_backup = deepcopy(self._config)

        try:
            gin_queries = ray.get(
                self.communicator.get_clean_gin_queries.remote())

            for q in gin_queries:
                val = gin.query_parameter(q)
                self.communicator.add_msg.remote(f"Gin value {q}: {val}")

            updates = ray.get(self.communicator.get_clear_updates.remote())
            for k, v in updates:
                self.communicator.add_msg.remote(
                    f"Updating parameter {k}: {v}")
                logging.info(f"Updating parameter {k}: {v}")
                param_update_from_flat(self._config, k, v)

            self.communicator.set_current_parameters.remote(
                param_flatten_dict_keys(self._config))
            self.set_gin_variables()
        except Exception as e:  # deliberately broad: best-effort remote sync
            # Roll back FIRST: previously the restore ran after a remote call,
            # so a failure while reporting left a half-applied config in place.
            # NOTE(review): gin variables already set by a failing
            # set_gin_variables() are not rolled back here.
            self._config = config_backup
            logging.warning("Remote parameter update failed %s", e)
            self.communicator.add_msg.remote("Error: " + str(e))
コード例 #32
0
ファイル: experiment.py プロジェクト: stjordanis/reaver-pysc2
    def save_gin_config(self):
        """Write gin's operative config string to ``self.config_path``.

        ``AdvantageActorCriticAgent.batch_sz`` is passed to the agent from the
        args rather than through gin, so it can be missing from the operative
        config. In that case its bound value is inserted as a line just before
        the first ``AdvantageActorCriticAgent.`` line. If there is no such
        line, it goes just before the final line of the config string.
        """
        config_str = gin.operative_config_str()

        if 'AdvantageActorCriticAgent.batch_sz' not in config_str:
            batch_sz = gin.query_parameter('AdvantageActorCriticAgent.batch_sz')
            lines = config_str.split('\n')
            # str.split always yields at least one element, so len(lines) - 1 >= 0.
            insert_at = next(
                (idx for idx, line in enumerate(lines)
                 if 'AdvantageActorCriticAgent.' in line),
                len(lines) - 1,
            )
            lines.insert(insert_at, 'AdvantageActorCriticAgent.batch_sz = ' + str(batch_sz))
            config_str = '\n'.join(lines)

        with open(self.config_path, 'w') as out_file:
            out_file.write(config_str)