def make_bitransformer_ll(input_vocab_size=gin.REQUIRED,
                          output_vocab_size=gin.REQUIRED,
                          layout=None,
                          mesh_shape=None,
                          encoder_name="encoder",
                          decoder_name="decoder",
                          cut_cross_attention=False):
    """Gin-configurable bitransformer constructor.
  In your config file you need to set the encoder and decoder layers like this:
  encoder/make_layer_stack.layers = [
    @transformer_layers.SelfAttention,
    @transformer_layers.DenseReluDense,
  ]
  decoder/make_layer_stack.layers = [
    @transformer_layers.SelfAttention,
    @transformer_layers.EncDecAttention,
    @transformer_layers.DenseReluDense,
  ]
  Args:
    input_vocab_size: an integer
    output_vocab_size: an integer
    layout: optional - an input to mtf.convert_to_layout_rules
      Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape
    mesh_shape: optional - an input to mtf.convert_to_shape
      Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape
    encoder_name: optional - a string giving the Unitransformer encoder name.
    decoder_name: optional - a string giving the Unitransformer decoder name.
    cut_cross_attention: optional - a boolean - if True, omit the
      encoder-decoder attention layers from the decoder layer stack.
  Returns:
    a Bitransformer_ll
  """
    with gin.config_scope("encoder"):
        encoder = Unitransformer_ll(layer_stack=make_layer_stack(),
                                    input_vocab_size=input_vocab_size,
                                    output_vocab_size=None,
                                    autoregressive=False,
                                    name=encoder_name,
                                    layout=layout,
                                    mesh_shape=mesh_shape)
    with gin.config_scope("decoder"):
        if cut_cross_attention:
            layer_stack = make_layer_stack(layers=[
                mtf.transformer.transformer_layers.SelfAttention,
                [
                    mtf.transformer.transformer_layers.DenseReluDense,
                    "layer_002"
                ]
            ])
        else:
            layer_stack = make_layer_stack()

        decoder = Unitransformer_ll(layer_stack=layer_stack,
                                    input_vocab_size=output_vocab_size,
                                    output_vocab_size=output_vocab_size,
                                    autoregressive=True,
                                    name=decoder_name,
                                    layout=layout,
                                    mesh_shape=mesh_shape)
    return Bitransformer_ll(encoder,
                            decoder,
                            cut_cross_attention=cut_cross_attention)
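The constructor above relies on gin config scopes: inside "with gin.config_scope('encoder')", any binding prefixed with "encoder/" takes effect, which is how the same make_layer_stack() call yields different stacks for the encoder and the decoder. A minimal, self-contained sketch of that mechanism (the layer names are placeholder strings, not the real transformer_layers classes):

import gin

@gin.configurable
def make_layer_stack(layers=gin.REQUIRED):
    return list(layers)

gin.parse_config("""
encoder/make_layer_stack.layers = ['self_attention', 'dense_relu_dense']
decoder/make_layer_stack.layers = ['self_attention', 'enc_dec_attention', 'dense_relu_dense']
""")

with gin.config_scope('encoder'):
    print(make_layer_stack())  # ['self_attention', 'dense_relu_dense']
with gin.config_scope('decoder'):
    print(make_layer_stack())  # three layers, including 'enc_dec_attention'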
Code example #2
    def __init__(
        self,
        env_class,
        agent_class,
        network_fn,
        config,
        scope,
        init_hooks,
        compress_episodes,
    ):

        gin.parse_config(config, skip_unknown=True)

        for hook in init_hooks:
            hook()

        import tensorflow as tf
        tf.config.threading.set_inter_op_parallelism_threads(1)
        tf.config.threading.set_intra_op_parallelism_threads(1)

        with gin.config_scope(scope):
            self.env = env_class()
            self.agent = agent_class()
            self._request_handler = core.RequestHandler(network_fn)

        self._compress_episodes = compress_episodes
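Here gin.parse_config(config, skip_unknown=True) lets each worker apply the full config string even when some referenced configurables are not imported in that process. A minimal sketch of that behavior (worker_setup and the unknown binding are invented for illustration):

import gin

@gin.configurable
def worker_setup(threads=1):
    return threads

config = """
worker_setup.threads = 4
some_unimported_module.some_fn.param = 1
"""
# skip_unknown=True skips bindings whose configurables are not
# registered in this process instead of raising an error.
gin.parse_config(config, skip_unknown=True)
assert worker_setup() == 4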
Code example #3
File: simple_trainer.py Project: zzszmyf/trax
    def train_model(self):
        """Train the model.

    Returns:
      whether the training was skipped due to a restart.
    """
        logging.info('SimPLe epoch [% 6d]: training model.',
                     self._simple_epoch)
        start_time = time.time()

        (train_stream, eval_stream) = self._make_input_streams()
        # Ignore n_devices for now.
        inputs = trax_inputs.Inputs(train_stream=(lambda _: train_stream),
                                    eval_stream=(lambda _: eval_stream))

        if self._simple_epoch == 0:
            train_steps = self._n_model_initial_train_steps
        else:
            train_steps = self._n_model_train_steps_per_epoch
        self._model_train_step += train_steps
        with gin.config_scope('world_model'):
            state = trainer_lib.train(
                model=self._sim_env.model,
                inputs=inputs,
                steps=self._model_train_step,
                output_dir=self._model_dir,
                has_weights=True,
            )

        logging.vlog(1, 'Training model took %0.2f sec.',
                     time.time() - start_time)
        return state.step > self._model_train_step
Code example #4
File: simple_trainer.py Project: zhaoqiuye/trax
    def train_model(self):
        """Train the model.

    Returns:
      whether the training was skipped due to a restart.
    """
        logging.info('SimPLe epoch [% 6d]: training model.',
                     self._simple_epoch)
        start_time = time.time()

        (train_stream, eval_stream) = self._make_input_streams()
        # Ignore n_devices for now.
        inputs = trax_inputs.Inputs(train_stream=(lambda _: train_stream),
                                    eval_stream=(lambda _: eval_stream))
        (obs, act, _, _) = next(train_stream)
        # TODO(pkozakowski): Refactor Inputs so this can be inferred correctly.
        inputs._input_shape = (tuple(obs.shape)[1:], tuple(act.shape)[1:])  # pylint: disable=protected-access
        inputs._input_dtype = (obs.dtype, act.dtype)  # pylint: disable=protected-access

        if self._simple_epoch == 0:
            train_steps = self._n_model_initial_train_steps
        else:
            train_steps = self._n_model_train_steps_per_epoch
        self._model_train_step += train_steps
        with gin.config_scope('world_model'):
            state = trainer_lib.train(
                model=self._sim_env.model,
                inputs=inputs,
                steps=self._model_train_step,
                output_dir=self._model_dir,
            )

        logging.vlog(1, 'Training model took %0.2f sec.',
                     time.time() - start_time)
        return state.step > self._model_train_step
Code example #5
def build_primitive_clustering(scopes):
    separators = []
    for scope in scopes:
        with gin.config_scope(scope):
            separator = build_primitive_separator()
            separators.append(separator)
    return separators
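Each pass through the loop enters a different scope, so one gin file can give every separator its own parameters. A hedged sketch with invented scope names and a stand-in for build_primitive_separator:

import gin

@gin.configurable
def build_primitive_separator(threshold=0.5):
    return {'threshold': threshold}

gin.parse_config("""
vocals/build_primitive_separator.threshold = 0.2
drums/build_primitive_separator.threshold = 0.8
""")

# Using build_primitive_clustering from the example above:
print(build_primitive_clustering(['vocals', 'drums']))
# [{'threshold': 0.2}, {'threshold': 0.8}]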
Code example #6
    def __init__(self,
                 dataset_name,
                 train_iters=-1,
                 valid_iters=-1,
                 test_iters=-1):

        self.loaders = {}
        self.count = 0
        self.iters = {
            'train': train_iters,
            'valid': valid_iters,
            'test': test_iters
        }

        self.datasets = {}
        preloaded_data = None

        splits = [s for s in self.iters if self.iters[s] != 0]

        for split in splits:
            d = datasets.__dict__[dataset_name](split,
                                                preloaded_data=preloaded_data)
            self.datasets[split] = d

            if preloaded_data is None and not d.remote_loading:
                preloaded_data = d.data

            with gin.config_scope(split):
                self.loaders[split] = initialize_dataloader(d)
            if self.iters[split] == -1:
                self.iters[split] = len(self.loaders[split])

        self.data_iterator = {s: iter(self.loaders[s]) for s in self.loaders}
        self.cached_samples = {}
Code example #7
def make_bitransformer(
    input_vocab_size=gin.REQUIRED,
    output_vocab_size=gin.REQUIRED,
    layout=None,
    mesh_shape=None):
  """Gin-configurable bitransformer constructor.

  In your config file you need to set the encoder and decoder layers like this:
  encoder/make_layer_stack.layers = [
    @transformer_layers.SelfAttention,
    @transformer_layers.DenseReluDense,
  ]
  decoder/make_layer_stack.layers = [
    @transformer_layers.SelfAttention,
    @transformer_layers.EncDecAttention,
    @transformer_layers.DenseReluDense,
  ]

  Args:
    input_vocab_size: an integer
    output_vocab_size: an integer
    layout: optional - an input to mtf.convert_to_layout_rules
      Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape
    mesh_shape: optional - an input to mtf.convert_to_shape
      Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape
  Returns:
    a Bitransformer
  """
  with gin.config_scope("encoder"):
    encoder = Unitransformer(
        layer_stack=make_layer_stack(),
        input_vocab_size=input_vocab_size,
        output_vocab_size=None,
        autoregressive=False,
        name="encoder",
        layout=layout,
        mesh_shape=mesh_shape)
  with gin.config_scope("decoder"):
    decoder = Unitransformer(
        layer_stack=make_layer_stack(),
        input_vocab_size=output_vocab_size,
        output_vocab_size=output_vocab_size,
        autoregressive=True,
        name="decoder",
        layout=layout,
        mesh_shape=mesh_shape)
  return Bitransformer(encoder, decoder)
Code example #8
File: evaluate.py Project: nussl/models
def evaluate(output_folder, separation_algorithm, eval_class, 
             block_on_gpu, num_workers, seed):
    nussl.utils.seed(seed)
    logging.info(gin.operative_config_str())
    
    with gin.config_scope('test'):
        test_dataset = build_dataset()
    
    results_folder = os.path.join(output_folder, 'results')
    os.makedirs(results_folder, exist_ok=True)
    set_model_to_none = False

    if block_on_gpu:
        # make an instance that'll be used on GPU
        # has an empty audio signal for now
        gpu_algorithm = separation_algorithm(
            nussl.AudioSignal(), device='cuda')
        set_model_to_none = True

    def forward_on_gpu(audio_signal):
        # set the audio signal of the object to this item's mix
        gpu_algorithm.audio_signal = audio_signal
        if hasattr(gpu_algorithm, 'forward'):
            gpu_output = gpu_algorithm.forward()
        elif hasattr(gpu_algorithm, 'extract_features'):
            gpu_output = gpu_algorithm.extract_features()
        return gpu_output

    pbar = tqdm.tqdm(total=len(test_dataset))

    def separate_and_evaluate(item, gpu_output):
        if set_model_to_none:
            separator = separation_algorithm(item['mix'], model_path=None)
        else:
            separator = separation_algorithm(item['mix'])
        estimates = separator(gpu_output)
        source_names = sorted(list(item['sources'].keys()))
        sources = [item['sources'][k] for k in source_names]
        
        # other arguments come from gin config
        evaluator = eval_class(sources, estimates)
        scores = evaluator.evaluate()
        output_path = os.path.join(
            results_folder, f"{item['mix'].file_name}.json")
        with open(output_path, 'w') as f:
            json.dump(scores, f, indent=2)
        pbar.update(1)
    
    pool = ThreadPoolExecutor(max_workers=num_workers)
    
    for i in range(len(test_dataset)):
        item = test_dataset[i]
        gpu_output = forward_on_gpu(item['mix'])
        if i == 0:
            separate_and_evaluate(item, gpu_output)
            continue
        pool.submit(separate_and_evaluate, item, gpu_output)
    
    pool.shutdown(wait=True)
Code example #9
 def _setup_for_cache(scope):
     with gin.config_scope(scope):
         _dataset = helpers.build_dataset()
         _dataset.cache_populated = False
         gin.bind_parameter(
             f'{scope}/build_dataset.dataset_class', 
             _dataset
         )
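gin.bind_parameter accepts the same 'scope/function.parameter' keys as a config file, so scoped values can be set programmatically, as the snippet above does. A minimal sketch (the names are illustrative, not the real nussl helpers):

import gin

@gin.configurable
def build_dataset(dataset_class=None):
    return dataset_class

gin.bind_parameter('train/build_dataset.dataset_class', 'CachedDataset')

with gin.config_scope('train'):
    assert build_dataset() == 'CachedDataset'
# Outside the scope the parameter keeps its default.
assert build_dataset() is None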
Code example #10
def get_base_dataset(num_examples=100, **kwargs):
    # return tf.data.Dataset.from_tensor_slices(
    #   get_base_data(num_examples=num_examples, **kwargs))
    from deep_cloud.problems.modelnet import ModelnetProblem
    from deep_cloud.problems.builders import pointnet_builder
    problem = ModelnetProblem(builder=pointnet_builder(2), positions_only=False)
    with gin.config_scope('train'):
        dataset = problem.get_base_dataset('validation').map(augment_cloud)
    return dataset
Code example #11
def separate_and_evaluate(separator, i, gpu_output, evaluator, results_folder,
                          save_audio_path):
    with gin.config_scope('segment_and_separate'):
        test_dataset = build_dataset()
    item = test_dataset[i]
    file_name = item['mix'].file_name
    output_path = os.path.join(results_folder, f"{file_name}.json")
    if os.path.exists(output_path):
        return f"{file_name} exists!"
    if item['mix'].loudness() < -40:
        return f"{file_name} too quiet! Skipping."

    estimates = _separate(separator, item, gpu_output)
    extra_data = {}
    if hasattr(separator, 'confidence'):
        source_labels = evaluator.source_labels
        _confidence = {}
        confidence_approaches = [
            k for k in dir(nussl.ml.confidence) if 'confidence' in k
        ]
        confidence_approaches = [
            'silhouette_confidence', 'posterior_confidence'
        ]
        for k in confidence_approaches:
            _confidence[k] = [float(separator.confidence(k, threshold=99))]
        for k in source_labels:
            extra_data[k] = _confidence

    if 'sources' in item:
        source_names = sorted(list(item['sources'].keys()))
        sources = [item['sources'][k] for k in source_names]
        evaluator.estimated_sources_list = estimates
        evaluator.true_sources_list = sources
    else:
        evaluator = None

    extra_data['metadata'] = {
        'original_path':
        item['mix'].path_to_input_file,
        'separated_path': [
            os.path.join(save_audio_path, f's{i}', file_name)
            for i in range(len(estimates))
        ]
    }

    _evaluate(evaluator, file_name, results_folder, save_audio_path,
              extra_data)
    for i, e in enumerate(estimates):
        audio_dir = os.path.join(save_audio_path, f's{i}')
        audio_path = os.path.join(audio_dir, f'{file_name}')
        os.makedirs(audio_dir, exist_ok=True)
        e.write_audio_to_file(audio_path)

    return f"Done with {file_name}"
Code example #12
 def test_runtime_no_error(self, num_ways, num_support, num_query, kwargs):
     """Testing run-time errors thrown when arguments are not set correctly."""
     # The following scope removes the gin-config set.
     with gin.config_scope('none'):
         # No error thrown
         _ = sampling.EpisodeDescriptionSampler(self.dataset_spec,
                                                self.split,
                                                num_ways=num_ways,
                                                num_support=num_support,
                                                num_query=num_query,
                                                **kwargs)
Code example #13
 def build_models(self, model_dict, summary=True):
     model_dictionary = self._build_model_dict(model_dict)
     compiled_models = {}
     for model_name, model_arch in model_dictionary.items():
         with gin.config_scope(model_name):
             _, in_shape, out_shape = data_shape()
             cnn_model = comp_model(model_name=model_name,
                                    model_arch=model_arch,
                                    input_shape=in_shape,
                                    classes=out_shape)
             compiled_models[model_name] = cnn_model
             if summary:
                 compiled_models[model_name].summary()
     return compiled_models
Code example #14
def load_transformer_model(ckpt_dir, model_cls, domain=None):
    """Loads a model from directory."""

    if domain is None:
        domain = data.protein_domain

    config_path = os.path.join(ckpt_dir, 'config.gin')
    with gin.config_scope('load_model'):
        with tf.io.gfile.GFile(config_path) as f:
            gin.parse_config(f, skip_unknown=True)
        model = model_cls(domain=domain)
        model.load_checkpoint(ckpt_dir)

    return model
Code example #15
File: train.py Project: nussl/cookiecutter
def cache(num_cache_workers, batch_size):
    num_cache_workers = min(num_cache_workers, multiprocessing.cpu_count())
    for scope in ['train', 'val']:
        with gin.config_scope(scope):
            dataset = build_dataset()
            dataset.cache_populated = False
            cache_dataloader = torch.utils.data.DataLoader(
                dataset, num_workers=num_cache_workers, batch_size=batch_size)
            nussl.ml.train.cache_dataset(cache_dataloader)

    alert = "Make sure to change cache_populated = True in your gin config!"
    border = ''.join(['=' for _ in alert])

    logging.info(f'\n\n{border}\n' f'{alert}\n' f'{border}\n')
Code example #16
def train(game,
          identity,
          opponent_mix_str,
          epoch,
          writer,
          save_path: str = None,
          scope: str = None):
    """ Train a best response policy.

    :param game: game object holding the environment and both strategy sets.
    :param identity: truthy to train the attacker, falsy to train the defender.
    :param opponent_mix_str: the opponent's mixed strategy.
    :param epoch: current epoch, used to name the saved strategy file.
    """
    env = game.env
    env.reset_everything()
    env.set_training_flag(identity)

    if identity:  # Training the attacker.
        if len(opponent_mix_str) != len(game.def_str):
            raise ValueError("The length must match while training.")
        env.defender.set_mix_strategy(opponent_mix_str)
        env.defender.set_str_set(game.def_str)
        if save_path is None:
            save_path = osp.join(settings.get_attacker_strategy_dir(),
                                 "att_str_epoch" + str(epoch) + ".pkl")

    else:  # Training the defender.
        if len(opponent_mix_str) != len(game.att_str):
            raise ValueError("The length must match while training.")
        env.attacker.set_mix_strategy(opponent_mix_str)
        env.attacker.set_str_set(game.att_str)
        if save_path is None:
            save_path = osp.join(settings.get_defender_strategy_dir(),
                                 "def_str_epoch" + str(epoch) + ".pkl")

    name = "attacker" if identity else "defender"
    scope = name if scope is None else scope
    with gin.config_scope(scope):
        learner = learner_factory()
        policy, best_deviation, _, report = learner.learn_multi_nets(
            env, epoch=epoch, writer=writer, game=game)

    # add online policy, without loading policies every time.
    game.total_strategies[identity].append(policy)

    torch.save(policy, save_path, pickle_module=dill)
    # fp.save_pkl(replay_buffer, save_path[:-4]+".replay_buffer.pkl")
    return best_deviation, report
Code example #17
def separate_and_evaluate(separator, i, gpu_output, evaluator, results_folder,
                          debug):
    with gin.config_scope('test'):
        test_dataset = build_dataset()
    item = test_dataset[i]
    file_name = item['mix'].file_name
    output_path = os.path.join(results_folder, f"{file_name}.json")
    if os.path.exists(output_path):
        return f"{file_name} exists!"

    estimates = _separate(separator, item, gpu_output)
    source_names = sorted(list(item['sources'].keys()))
    sources = [item['sources'][k] for k in source_names]
    evaluator.estimated_sources_list = estimates
    evaluator.true_sources_list = sources
    return _evaluate(evaluator, file_name, results_folder, debug)
Code example #18
  def test_load_model(self):
    with gin.config_scope('test'):
      for k, v in lm_cfg.items():
        gin.bind_parameter('FlaxLM.%s' % k, v)

      lm = models.FlaxLM(domain=_test_domain(), random_seed=1)

      save_dir = self._tmpdir / 'save_ckpt'
      lm.save_checkpoint(save_dir)
      config_str = gin.operative_config_str()
      with tf.gfile.GFile(str(save_dir / 'config.gin'), 'w') as f:
        f.write(config_str)

      loaded_model = models.load_model(save_dir, model_cls=models.FlaxLM)
      self.assertAllEqual(
          lm.optimizer.target.params['embed']['embedding'],
          loaded_model.optimizer.target.params['embed']['embedding'])
Code example #19
    def train_models(self,
                     train_list,
                     compiled_models,
                     model_type='bin_classifier',
                     save_figs=False,
                     print_class_rep=True):
        score_dict = {}
        for model_name in tqdm(train_list):
            history_dict = {}
            with gin.config_scope(model_name):
                image_size, _, _ = data_shape()
                #print(f'\nModel name: {model_name} \nImage Size: {image_size}')
                train_gen = directory_flow(dir=self.train_dir,
                                           shuffle=True,
                                           image_size=image_size)
                test_gen = directory_flow(dir=self.test_dir,
                                          shuffle=False,
                                          image_size=image_size)
                save_name = str(model_name) + ('.h5')
                history = fit_generator(model_name=model_name,
                                        model=compiled_models[model_name],
                                        gen=train_gen,
                                        validation_data=test_gen)
                save_model(model=compiled_models[model_name],
                           model_name=save_name)
                history_dict[model_name] = history.history
                test_gen.reset()
                preds = compiled_models[model_name].predict_generator(
                    test_gen,
                    verbose=1,
                    steps=math.ceil(
                        len(test_gen.classes) / test_gen.batch_size))
                score_dict[model_name] = self.score_models(
                    preds,
                    model_name,
                    history=history_dict[model_name],
                    save_figs=save_figs,
                    model_type=model_type,
                    print_class_rep=print_class_rep,
                    test_gen=test_gen)

        model_table = pd.DataFrame(
            score_dict).transpose().reset_index().rename(
                mapper={'index': 'Model_Name'}, axis=1)
        return compiled_models, model_table
Code example #20
def main(_):
    device = 'cuda' if FLAGS.gpu else 'cpu'
    keynames = get_all_keynames_from_dir(FLAGS.input_path)

    logging.info('Retrieving experiment data')
    gin.parse_config_file(
        '/home/marcospiau/final_project_ia376j/src/models/gin/defaults.gin',
        skip_unknown=True)
    project = neptune.init(FLAGS.neptune_project)
    experiment = project.get_experiments(FLAGS.neptune_experiment_id)[0]
    experiment_channels = experiment.get_channels()
    with tempfile.TemporaryDirectory() as tmp_folder:
        experiment.download_artifact('gin_operative_config.gin', tmp_folder)
        gin.parse_config_file(os.path.join(tmp_folder,
                                           'gin_operative_config.gin'),
                              skip_unknown=True)

    logging.info('Loading model checkpoint')
    model = T5OCRBaseline.load_from_checkpoint(
        experiment_channels['best_model_path'].y)
    model.eval()
    model.freeze()
    model.to(device)

    logging.info('Preparing datasets and dataloaders')
    with gin.config_scope('sroie_t5_baseline'):
        task_functions_maps = get_tasks_functions_maps()
    datasets = get_datasets_dict_from_task_functions_map(
        keynames=keynames, tasks_functions_maps=task_functions_maps)
    loader_kwargs = {
        'num_workers': mp.cpu_count(),
        'shuffle': False,
        'pin_memory': True
    }
    dataloaders = get_dataloaders_dict_from_datasets_dict(
        datasets_dict=datasets, dataloader_kwargs=loader_kwargs)

    logging.info('Making predictions')
    preds = predict_all_fields(model, dataloaders, keynames, device=device)
    logging.info('Saving predictions')
    os.makedirs(FLAGS.output_path, exist_ok=True)
    save_predictions_in_dir(preds, FLAGS.output_path)
Code example #21
    def train_model(self):
        """Train the model.

    Returns:
      whether the training was skipped due to a restart.
    """
        logging.info('SimPLe epoch [% 6d]: training model.',
                     self._simple_epoch)
        start_time = time.time()

        (train_stream, eval_stream) = self._make_input_streams()
        # Ignore n_devices for now.
        inputs = lambda _: trax_inputs.Inputs(  # pylint: disable=g-long-lambda
            train_stream=(lambda: train_stream),
            train_eval_stream=(lambda: train_stream),
            eval_stream=(lambda: eval_stream),
            input_shape=self._sim_env.model_input_shape,
            input_dtype=self._sim_env.model_input_dtype,
            # TODO(lukaszkaiser): correct those, they may differ from inputs.
            target_shape=self._sim_env.model_input_shape,
            target_dtype=self._sim_env.model_input_dtype)

        if self._simple_epoch == 0:
            train_steps = self._n_model_initial_train_steps
        else:
            train_steps = self._n_model_train_steps_per_epoch
        self._model_train_step += train_steps
        with gin.config_scope('world_model'):
            state = trainer_lib.train(
                model=self._sim_env.model,
                inputs=inputs,
                train_steps=self._model_train_step,
                output_dir=self._model_dir,
                has_weights=True,
            )

        logging.vlog(1, 'Training model took %0.2f sec.',
                     time.time() - start_time)
        return state.step > self._model_train_step
Code example #22
def main(_):
    if FLAGS.module_import:
        for module in FLAGS.module_import:
            importlib.import_module(module)

    # Add search path for gin files stored in package.
    gin.add_config_file_search_path(
        pkg_resources.resource_filename(__name__, "gin"))

    models_dir_name = FLAGS.model_dir_name
    if FLAGS.model_dir_counter >= 0:
        models_dir_name += "_%s" % str(FLAGS.model_dir_counter)
    models_dir = os.path.join(FLAGS.base_dir, models_dir_name)

    model_dir = os.path.join(models_dir, FLAGS.model_size)
    try:
        tf.io.gfile.makedirs(model_dir)
        suffix = 0
        command_filename = os.path.join(model_dir, "command")
        while tf.io.gfile.exists(command_filename):
            suffix += 1
            command_filename = os.path.join(model_dir,
                                            "command.{}".format(suffix))
        with tf.io.gfile.GFile(command_filename, "w") as f:
            f.write(" ".join(sys.argv))
    except tf.errors.PermissionDeniedError:
        logging.info(
            "No write access to model directory. Skipping command logging.")

    utils.parse_gin_defaults_and_flags()

    # Load and print a few examples.
    st_task = TaskRegistry_ll.get("processed_cctk")
    sequence_length = {"inputs": 64, "targets": 64}
    sequence_length[
        "attribute"] = 64  # Or "attribute": 1 but packing not efficient...
    sequence_length["codeprefixedtargets"] = 64
    sequence_length["controlcode"] = 64

    with gin.config_scope('caet5'):
        ds = st_task.get_dataset(split="validation",
                                 sequence_length=sequence_length)

    print("A few preprocessed validation examples...")
    for ex in tfds.as_numpy(ds.take(5)):
        print(ex)
    """
    print("unitests")

    mixture_or_task_name = "processed_cctk"
    from caet5.models.mesh_transformer import mesh_train_dataset_fn_ll
    from caet5.data.utils import get_mixture_or_task_ll, MixtureRegistry_ll

    from mesh_tensorflow_caet5.dataset import pack_or_pad_ll

    mixture_or_task = get_mixture_or_task_ll("mixture_processed_cctk")

    with gin.config_scope('caet5'):
        dsbis = mixture_or_task.get_dataset(split="train", sequence_length=sequence_length)

    
    #ds2 = pack_or_pad_ll(dsbis, sequence_length, pack=False,
    #                     feature_keys=tuple(mixture_or_task.output_features), ensure_eos=True)
    

    def filter_attribute_1_fn(x):
        return tf.equal(x["attribute"][0], 1)

    def filter_attribute_2_fn(x):
        return tf.equal(x["attribute"][0], 2)

    ds_attribute_1 = dsbis.filter(filter_attribute_1_fn)
    ds_attribute_2 = dsbis.filter(filter_attribute_2_fn)

    ds2_attribute_1 = pack_or_pad_ll(
        ds_attribute_1, sequence_length, pack=False,
        feature_keys=tuple(mixture_or_task.output_features),
        ensure_eos=True)  # (not straightforward) Adapt packing so that pack=True
    ds2_attribute_2 = pack_or_pad_ll(
        ds_attribute_2, sequence_length, pack=False,
        feature_keys=tuple(mixture_or_task.output_features),
        ensure_eos=True)  # (not straightforward) Adapt packing so that pack=True

    ds3_attribute_1 = ds2_attribute_1
    ds3_attribute_2 = ds2_attribute_2

    def f1():
        return ds3_attribute_1

    def f2():
        return ds3_attribute_2

    def interleave_map_fn(x):
        return tf.cond(tf.equal(x, 0), f1, f2)

    ds3 = tf.data.Dataset.range(2).interleave(
        interleave_map_fn, cycle_length=2,
        block_length=4,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    print("A few preprocessed validation examples...")
    for ex in tfds.as_numpy(ds3.take(80)):
        print(ex)
    """

    if FLAGS.use_model_api:
        # Modifying original T5 in CAE-T5
        transformer.make_bitransformer = make_bitransformer_ll
        utils.tpu_estimator_model_fn = tpu_estimator_model_fn_ll

        model_parallelism, train_batch_size, keep_checkpoint_max = {
            "small": (1, 256, 16),
            "base": (2, 128, 8),
            "large": (8, 64, 4),
            "3B": (8, 16, 1),
            "11B": (8, 16, 1)
        }[FLAGS.model_size]

        model = MtfModel_ll(
            tpu_job_name=FLAGS.tpu_job_name,
            tpu=FLAGS.tpu,
            gcp_project=FLAGS.gcp_project,
            tpu_zone=FLAGS.tpu_zone,
            model_dir=model_dir,
            model_parallelism=model_parallelism,
            batch_size=train_batch_size,
            learning_rate_schedule=0.003,
            save_checkpoints_steps=2000,
            keep_checkpoint_max=keep_checkpoint_max,  # if ON_CLOUD else None,
            iterations_per_loop=100,
            model_type="bitransformer",
            unsupervised_attribute_transfer_metrics=True)

        if FLAGS.checkpoint_mode != "specific" and FLAGS.checkpoint_steps:
            raise ValueError(
                "checkpoint_mode is set to %s and checkpoint_steps is "
                "also set. To use a particular checkpoint, please set "
                "checkpoint_mode to 'specific'. For other modes, please "
                "ensure that checkpoint_steps is not set." %
                FLAGS.checkpoint_mode)

        if FLAGS.checkpoint_mode == "latest":
            checkpoint_steps = -1
        elif FLAGS.checkpoint_mode == "all":
            checkpoint_steps = "all"
        else:
            checkpoint_steps = [int(c) for c in FLAGS.checkpoint_steps]

        if FLAGS.mode == "finetune":
            pretrained_dir = os.path.join(FLAGS.base_pretrained_model_dir,
                                          FLAGS.model_size)

            model.finetune(mixture_or_task_name=FLAGS.mixture_or_task,
                           pretrained_model_dir=pretrained_dir,
                           finetune_steps=FLAGS.train_steps)

        elif FLAGS.mode == "eval":
            model.batch_size = train_batch_size * 4
            model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
                       checkpoint_steps=checkpoint_steps,
                       summary_dir=FLAGS.eval_summary_dir,
                       split=FLAGS.eval_split)

            # print_random_predictions("yelp", sequence_length, model_dir, n=10)

        elif FLAGS.mode == "predict":
            if FLAGS.predict_batch_size > 0:
                model.batch_size = FLAGS.predict_batch_size
            model.predict(checkpoint_steps=checkpoint_steps,
                          input_file=FLAGS.input_file,
                          output_file=FLAGS.output_file,
                          temperature=0)
        else:
            raise ValueError("--mode flag must be set when using Model API.")

    else:
        raise NotImplementedError()
Code example #23
def make_scaper_datasets(scopes=['train', 'val']):
    for scope in scopes:
        with gin.config_scope(scope):
            mix_with_scaper()
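The config side of this loop would bind mix_with_scaper differently under each scope. A hedged sketch with invented parameter names (the real mix_with_scaper signature may differ):

import gin

@gin.configurable
def mix_with_scaper(num_mixtures=10, folder='out'):
    print(f'Generating {num_mixtures} mixtures into {folder}')

gin.parse_config("""
train/mix_with_scaper.num_mixtures = 20000
train/mix_with_scaper.folder = 'data/train'
val/mix_with_scaper.num_mixtures = 500
val/mix_with_scaper.folder = 'data/val'
""")

for scope in ['train', 'val']:
    with gin.config_scope(scope):
        mix_with_scaper()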
Code example #24
def main(_):
    # https://github.com/google-research/text-to-text-transfer-transformer/blob/c0ea75dbe9e35a629ae2e3c964ef32adc0e997f3/t5/models/mesh_transformer_main.py#L149
    # Add search path for gin files stored in package.
    gin.add_config_file_search_path(
        pkg_resources.resource_filename(__name__, "gin"))
    gin.parse_config_files_and_bindings(FLAGS.gin_file,
                                        FLAGS.gin_param,
                                        finalize_config=True)
    pl.seed_everything(1234)
    with gin.config_scope('sroie_t5_baseline'):
        task_functions_maps = get_tasks_functions_maps()

    # Datasets
    with gin.config_scope('train_sroie'):
        train_keynames = get_all_keynames_from_dir()

    with gin.config_scope('validation_sroie'):
        val_keynames = get_all_keynames_from_dir()

    train_datasets = get_datasets_dict_from_task_functions_map(
        keynames=train_keynames, tasks_functions_maps=task_functions_maps)
    val_datasets = get_datasets_dict_from_task_functions_map(
        keynames=val_keynames, tasks_functions_maps=task_functions_maps)

    with gin.config_scope('task_train'):
        task_train = operative_macro()

    # Initializing model
    model = T5OCRBaseline()

    # Trainer
    if FLAGS.debug:
        logger = False
        trainer_callbacks = []
    else:
        logger = NeptuneLogger(
            close_after_fit=False,
            api_key=os.environ["NEPTUNE_API_TOKEN"],
            # project_name is set via gin file
            # params=None,
            tags=[model.t5_model_prefix, task_train, 't5_ocr_baseline'])
        with gin.config_scope('sroie_t5_baseline'):
            checkpoint_callback = config_model_checkpoint(
                monitor=None if FLAGS.best_model_run_mode else "val_f1",
                dirpath=("/home/marcospiau/final_project_ia376j/checkpoints/"
                         f"{logger.project_name.replace('/', '_')}/"
                         "t5_ocr_baseline/"),
                prefix=(
                    f"experiment_id={logger.experiment.id}-task={task_train}-"
                    "t5_model_prefix="
                    f"{model.t5_model_prefix.replace('-', '_')}"),
                filename=("{step}-{epoch}-{val_precision:.6f}-{val_recall:.6f}"
                          "-{val_f1:.6f}-{val_exact_match:.6f}"),
                mode="max",
                save_top_k=None if FLAGS.best_model_run_mode else 1,
                verbose=True)
        early_stop_callback = config_early_stopping_callback()
        trainer_callbacks = [checkpoint_callback, early_stop_callback]

    trainer = Trainer(
        checkpoint_callback=not (FLAGS.debug),
        log_gpu_memory=True,
        # profiler=FLAGS.debug,
        logger=logger,
        callbacks=trainer_callbacks,
        progress_bar_refresh_rate=1,
        log_every_n_steps=1)
    # Dataloaders
    train_loader_kwargs = {
        'num_workers': mp.cpu_count(),
        'shuffle': True if (trainer.overfit_batches == 0) else False,
        'pin_memory': True
    }

    if trainer.overfit_batches != 0:
        with gin.unlock_config():
            gin.bind_parameter(
                'get_dataloaders_dict_from_datasets_dict.batch_size', 1)

    eval_loader_kwargs = {**train_loader_kwargs, **{'shuffle': False}}

    train_dataloaders = get_dataloaders_dict_from_datasets_dict(
        datasets_dict=train_datasets, dataloader_kwargs=train_loader_kwargs)
    val_dataloaders = get_dataloaders_dict_from_datasets_dict(
        datasets_dict=val_datasets, dataloader_kwargs=eval_loader_kwargs)

    # Logging important artifacts and params
    if logger:
        to_upload = {
            'gin_operative_config.gin': gin.operative_config_str(),
            'gin_complete_config.gin': gin.config_str(),
            'abseil_flags.txt': FLAGS.flags_into_string()
        }
        for destination, content in to_upload.items():
            buffer = StringIO(initial_value=content)
            buffer.seek(0)
            logger.log_artifact(buffer, destination=destination)
        params_to_log = dict()
        params_to_log['str_replace_newlines'] = gin.query_parameter(
            'sroie_t5_baseline/get_default_preprocessing_functions.'
            'str_replace_newlines')
        params_to_log['task_train'] = task_train
        params_to_log['patience'] = early_stop_callback.patience
        params_to_log['max_epochs'] = trainer.max_epochs
        params_to_log['min_epochs'] = trainer.min_epochs
        params_to_log[
            'accumulate_grad_batches'] = trainer.accumulate_grad_batches
        params_to_log['batch_size'] = train_dataloaders[task_train].batch_size

        for k, v in params_to_log.items():
            logger.experiment.set_property(k, v)

    trainer.fit(model,
                train_dataloader=train_dataloaders[task_train],
                val_dataloaders=val_dataloaders[task_train])

    # Logging best metrics and saving best checkpoint on Neptune experiment
    if logger:
        trainer.logger.experiment.log_text(
            log_name='best_model_path',
            x=trainer.checkpoint_callback.best_model_path)
        if not (FLAGS.best_model_run_mode):
            trainer.logger.experiment.log_metric(
                'best_model_val_f1',
                trainer.checkpoint_callback.best_model_score.item())
        if FLAGS.upload_best_checkpoint:
            trainer.logger.experiment.log_artifact(
                trainer.checkpoint_callback.best_model_path)

        trainer.logger.experiment.stop()
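Because the script parses its files with finalize_config=True, later adjustments such as the overfit_batches branch above need gin.unlock_config. A minimal sketch of that locking pattern (loader is a stand-in configurable):

import gin

@gin.configurable
def loader(batch_size=32):
    return batch_size

gin.parse_config('loader.batch_size = 64')
gin.finalize()  # locks the config, like finalize_config=True

with gin.unlock_config():
    gin.bind_parameter('loader.batch_size', 1)

assert loader() == 1
assert gin.query_parameter('loader.batch_size') == 1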
Code example #25
        def logits_and_loss(mtf_features):
            """Compute logits and loss.
            Args:
              mtf_features: a dictionary
            Returns:
              logits: a mtf.Tensor
              loss: a mtf.Tensor
            """
            if model_type == "lm":  # TOTRY Adapt that to our case
                if "inputs" in mtf_features:
                    mtf_features = _dynamic_text2self(mtf_features)
                _, _, length_dim = mtf_features["targets"].shape
                inputs = mtf.shift(mtf_features["targets"],
                                   offset=1,
                                   dim=length_dim,
                                   wrap=False)
            else:
                inputs = mtf_features["inputs"]

            if attribute_embedding:
                attributes = mtf_features["attribute"]
            else:
                attributes = None

            if control_codes:
                codeprefixedtargets = mtf_features["codeprefixedtargets"]
            else:
                codeprefixedtargets = None

            if isinstance(transformer_model, transformer.Unitransformer):
                position_kwargs = dict(
                    sequence_id=mtf_features.get("targets_segmentation", None),
                    position=mtf_features.get("targets_position", None),
                )
            elif isinstance(transformer_model, transformer.Bitransformer
                            ) or model_type == "bi_student_teacher":
                if control_codes:
                    position_kwargs = dict(
                        encoder_sequence_id=mtf_features.get(
                            "inputs_segmentation", None),
                        decoder_sequence_id=mtf_features.get(
                            "codeprefixedtargets_segmentation", None),
                        decoder_subsequence_id=mtf_features.get(
                            "codeprefixedtargets_subsegmentation", None),
                        encoder_position=mtf_features.get(
                            "inputs_position", None),
                        decoder_position=mtf_features.get(
                            "codeprefixedtargets_position", None),
                    )
                else:
                    position_kwargs = dict(
                        encoder_sequence_id=mtf_features.get(
                            "inputs_segmentation", None),
                        decoder_sequence_id=mtf_features.get(
                            "targets_segmentation", None),
                        decoder_subsequence_id=mtf_features.get(
                            "targets_subsegmentation", None),
                        encoder_position=mtf_features.get(
                            "inputs_position", None),
                        decoder_position=mtf_features.get(
                            "targets_position", None),
                    )
            else:
                raise ValueError("unrecognized class")

            if isinstance(transformer_model, Bitransformer_ll):
                if cycle_consistency_loss:
                    logits_ae, l_ae = transformer_model.call_simple(
                        inputs=inputs,
                        targets=mtf_features["targets"],
                        compute_loss=True,
                        attributes=attributes,
                        codeprefixedtargets=codeprefixedtargets,
                        mode=mode,
                        variable_dtype=get_variable_dtype(),
                        **position_kwargs)

                    if has_partial_sequences:
                        controlcodes = mtf_features["controlcode"]
                    else:
                        controlcodes = None

                    with gin.config_scope('training'):
                        mtf_samples = transformer_model.decode(
                            inputs,
                            attributes=attributes,
                            controlcodes=controlcodes,
                            has_partial_sequences=has_partial_sequences,
                            remove_partial_sequences=remove_partial_sequences,
                            variable_dtype=get_variable_dtype())
                        # mtf_samples = mtf.anonymize(mtf_samples)
                    outputs = mtf_samples

                    logits_cycle, l_cycle = transformer_model.call_simple(
                        inputs=outputs,
                        targets=mtf_features["targets"],
                        compute_loss=True,
                        attributes=attributes,
                        codeprefixedtargets=codeprefixedtargets,
                        mode=mode,
                        variable_dtype=get_variable_dtype(),
                        **position_kwargs)

                    loss_ae_cycle = lambda_ae * l_ae + lambda_cycle * l_cycle
                    return logits_cycle, loss_ae_cycle
                else:
                    return transformer_model.call_simple(
                        inputs=inputs,
                        targets=mtf_features["targets"],
                        compute_loss=True,
                        attributes=attributes,
                        codeprefixedtargets=codeprefixedtargets,
                        mode=mode,
                        variable_dtype=get_variable_dtype(),
                        **position_kwargs)
            else:
                return transformer_model.call_simple(
                    inputs=inputs,
                    targets=mtf_features["targets"],
                    compute_loss=True,
                    mode=mode,
                    variable_dtype=get_variable_dtype(),
                    num_microbatches=num_microbatches,
                    **position_kwargs)
Code example #26
def segment_and_separate(output_folder,
                         separation_algorithm,
                         eval_class,
                         block_on_gpu,
                         num_workers,
                         seed,
                         save_audio_path,
                         use_threadpool=False,
                         num_sources=None):
    nussl.utils.seed(seed)
    logging.info(gin.operative_config_str())

    with gin.config_scope('segment_and_separate'):
        test_dataset = build_dataset()

    results_folder = os.path.join(output_folder, 'results')
    os.makedirs(results_folder, exist_ok=True)
    set_model_to_none = False

    if block_on_gpu:
        # make an instance that'll be used on GPU
        # has an empty audio signal for now
        gpu_algorithm = separation_algorithm(nussl.AudioSignal(),
                                             device='cuda')
        set_model_to_none = True

    def forward_on_gpu(audio_signal):
        # set the audio signal of the object to this item's mix
        if block_on_gpu:
            gpu_algorithm.audio_signal = audio_signal
            if hasattr(gpu_algorithm, 'forward'):
                gpu_output = gpu_algorithm.forward()
            elif hasattr(gpu_algorithm, 'extract_features'):
                gpu_output = gpu_algorithm.extract_features()
            model_output = {
                k: v.cpu()
                for k, v in gpu_algorithm.model_output.items()
            }
            return gpu_output, model_output
        else:
            return None

    pbar = tqdm.tqdm(total=len(test_dataset))

    PoolExecutor = (ThreadPoolExecutor
                    if use_threadpool else ProcessPoolExecutor)

    with PoolExecutor(max_workers=num_workers) as pool:

        def update(future):
            desc = future.result()
            pbar.update()
            pbar.set_description(desc)

        indices = list(range(len(test_dataset)))
        shuffle(indices)
        for i in indices:
            item = test_dataset[i]

            file_name = item['mix'].file_name
            output_path = os.path.join(results_folder, f"{file_name}.json")
            if os.path.exists(output_path):
                pbar.set_description(f"{file_name} exists!")
                pbar.update()
                continue

            pbar.set_description(f"Starting {item['mix'].file_name}")
            gpu_output = forward_on_gpu(item['mix'])
            kwargs = {'model_path': None} if set_model_to_none else {}

            empty_signal = nussl.AudioSignal(audio_data_array=np.random.rand(
                1, 100),
                                             sample_rate=100)
            separator = separation_algorithm(empty_signal, **kwargs)
            if set_model_to_none:
                separator.model = None
            if 'sources' in item:
                num_sources = len(item['sources'])

            dummy_signal_list = [
                nussl.AudioSignal(audio_data_array=np.random.rand(1, 100),
                                  sample_rate=100) for _ in range(num_sources)
            ]
            evaluator = eval_class(dummy_signal_list, dummy_signal_list)
            args = (separator, i, gpu_output, evaluator, results_folder,
                    save_audio_path)

            if num_workers == 1:
                desc = separate_and_evaluate(*args)
                pbar.update()
                pbar.set_description(desc)
            else:
                future = pool.submit(separate_and_evaluate, *args)
                future.add_done_callback(update)
Code example #27
def train_eval(
    # Params for experiment identification
    root_dir: str,
    training_id: str,
    model_id: str,
    env_id: str,
    # Params for training process
    num_iterations: int,
    target_update_period: int,
    # Params for eval
    eval_interval: int,
    num_eval_episodes: int,
    # Param for checkpoints
    checkpoint_interval: int,
    replay_size: int,
):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    experiment_dir = os.path.join(root_dir,
                                  "-".join([model_id, env_id, training_id]))
    model_dir = os.path.join(experiment_dir, 'model')
    plots_dir = os.path.join(experiment_dir, 'plots')
    model_path = os.path.join(model_dir, 'dqn_model.h5')
    image_path = os.path.join(plots_dir, 'reward_plot.png')
    csv_path = os.path.join(plots_dir, 'reward_data.csv')

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
        os.makedirs(plots_dir)

    tf.profiler.experimental.server.start(6008)

    # create the environment
    env = utility.create_environment()
    epoch_counter = tf.Variable(0,
                                trainable=False,
                                name='Epoch',
                                dtype=tf.int64)
    # Epsilon implementing decaying behaviour for the two agents
    decaying_epsilon = utility.decaying_epsilon(step=epoch_counter)
    # create an agent and a network
    tf_agent = agent.DQNAgent(epsilon=decaying_epsilon,
                              obs_spec=env.observation_spec())
    # replay buffer
    replay_memory = agent.ReplayMemory(maxlen=replay_size)
    reward_tracker = agent.RewardTracker()

    start_time = datetime.now()

    # Initial collection
    with gin.config_scope('initial_step'):
        collection_step(env=env,
                        tf_agent=tf_agent,
                        replay_memory=replay_memory,
                        reward_tracker=reward_tracker)

    for _ in range(num_iterations):
        epoch_counter.assign_add(1)
        tf.summary.scalar(name='Epsilon',
                          data=decaying_epsilon(),
                          step=epoch_counter)

        episode_reward = collection_step(env=env,
                                         tf_agent=tf_agent,
                                         replay_memory=replay_memory,
                                         reward_tracker=reward_tracker)

        training_step(tf_agent=tf_agent, replay_memory=replay_memory)

        avg_reward = reward_tracker.mean()
        tf.summary.scalar(name='Average Reward',
                          data=avg_reward,
                          step=epoch_counter)

        print(
            "\rTime: {}, Episode: {}, Reward: {}, Avg Reward {}, eps: {:.3f}".
            format(datetime.now() - start_time, epoch_counter.numpy(),
                   np.round(episode_reward, 2), np.round(avg_reward, 2),
                   decaying_epsilon().numpy()),
            end="")

        # Copy weights from main model to target model
        if epoch_counter.numpy() % target_update_period == 0:
            tf_agent.update_target_model()

        # Checkpointing
        if epoch_counter.numpy() % checkpoint_interval == 0:
            tf_agent.save_model(model_path)
            plot_learning_curve(reward_tracker=reward_tracker,
                                image_path=image_path,
                                csv_path=csv_path)

        # Evaluation Run
        if epoch_counter.numpy() % eval_interval == 0:
            pass
Code example #28
File: neutra.py Project: guangyusong/google-research
    def __init__(self,
                 train_batch_size=4096,
                 test_chain_batch_size=4096,
                 bijector="iaf",
                 log_dir="/tmp/neutra",
                 base_learning_rate=1e-3,
                 q_base_scale=1.,
                 learning_rate_schedule=[[6000, 1e-1]]):
        target, target_spec = GetTargetSpec()
        self.target = target
        self.target_spec = target_spec
        with gin.config_scope("train"):
            train_target, train_target_spec = GetTargetSpec()
            self.train_target = train_target
            self.train_target_spec = train_target_spec

        if bijector == "rnvp":
            bijector_fn = tf.make_template("bijector",
                                           MakeRNVPBijectorFn,
                                           num_dims=self.target_spec.num_dims)
        elif bijector == "iaf":
            bijector_fn = tf.make_template("bijector",
                                           MakeIAFBijectorFn,
                                           num_dims=self.target_spec.num_dims)
        elif bijector == "affine":
            bijector_fn = tf.make_template("bijector",
                                           MakeAffineBijectorFn,
                                           num_dims=self.target_spec.num_dims)
        else:
            bijector_fn = lambda *args, **kwargs: tfb.Identity()

        self.train_bijector = bijector_fn(train=True)
        self.bijector = bijector_fn(train=False)
        if train_target_spec.bijector is not None:
            print("Using train target bijector")
            self.train_bijector = tfb.Chain(
                [train_target_spec.bijector, self.train_bijector])
        if target_spec.bijector is not None:
            print("Using target bijector")
            self.bijector = tfb.Chain([target_spec.bijector, self.bijector])

        q_base = tfd.Independent(
            tfd.Normal(loc=tf.zeros(self.target_spec.num_dims),
                       scale=q_base_scale *
                       tf.ones(self.target_spec.num_dims)), 1)
        self.q_x_train = tfd.TransformedDistribution(q_base,
                                                     self.train_bijector)
        self.q_x = tfd.TransformedDistribution(q_base, self.bijector)

        # Params
        self.train_batch_size = int(train_batch_size)
        self.test_chain_batch_size = tf.placeholder_with_default(
            test_chain_batch_size, [], "test_chain_batch_size")
        self.test_batch_size = tf.placeholder_with_default(
            16384 * 8, [], "test_batch_size")
        self.test_num_steps = tf.placeholder_with_default(
            1000, [], "test_num_steps")
        self.test_num_leapfrog_steps = tf.placeholder_with_default(
            tf.to_int32(2), [], "test_num_leapfrog_steps")
        self.test_step_size = tf.placeholder_with_default(
            0.1, [], "test_step_size")

        # Test
        self.neutra_outputs = MakeNeuTra(
            target=self.target,
            q=self.q_x,
            batch_size=self.test_chain_batch_size,
            num_steps=self.test_num_steps,
            num_leapfrog_steps=self.test_num_leapfrog_steps,
            step_size=self.test_step_size,
        )
        self.z_chain = tf.reshape(
            self.bijector.inverse(
                tf.reshape(self.neutra_outputs.x_chain,
                           [-1, self.target_spec.num_dims])),
            tf.shape(self.neutra_outputs.x_chain))
        self.target_samples = self.target.sample(self.test_batch_size)
        self.target_z = self.bijector.inverse(self.target_samples)
        self.q_samples = self.q_x.sample(self.test_batch_size)

        self.target_cov = utils.Covariance(self.target_samples)
        self.target_eigvals, self.target_eigvecs = tf.linalg.eigh(
            self.target_cov)

        self.cached_target_eigvals = tf.get_local_variable(
            "cached_target_eigvals",
            self.target_eigvals.shape,
            initializer=tf.zeros_initializer())
        self.cached_target_eigvecs = tf.get_local_variable(
            "cached_target_eigvecs",
            self.target_eigvecs.shape,
            initializer=tf.zeros_initializer())
        self.cached_target_stats_update_op = [
            self.cached_target_eigvals.assign(self.target_eigvals),
            self.cached_target_eigvecs.assign(self.target_eigvecs),
            tf.print("Assigning target stats")
        ]

        def variance(x):
            x -= tf.reduce_mean(x, 0, keep_dims=True)
            x = tf.square(x)
            return x

        def rotated_variance(x):
            x2 = tf.reshape(x, [-1, self.target_spec.num_dims])
            x2 -= tf.reduce_mean(x2, 0, keep_dims=True)
            x2 = tf.matmul(x2, self.cached_target_eigvecs)
            x2 = tf.square(x2)
            return tf.reshape(x2, tf.shape(x))

        functions = [
            ("mean", tf.identity),
            #        ("var", variance),
            ("square", tf.square),
            #        ("rot_square", rot_square),
            #        ("rot_var", rotated_variance),
        ]

        self.cached_target_mean = {}
        self.cached_target_mean_update_op = [
            tf.print("Assigning target means.")
        ]
        self.neutra_stats = {}
        self.q_stats = {}

        for name, f in functions:
            target_mean = tf.reduce_mean(f(self.target_samples), 0)
            cached_target_mean = tf.get_local_variable(name + "_cached_mean",
                                                       target_mean.shape)
            if self.target_spec.stats is not None:
                self.cached_target_mean_update_op.append(
                    cached_target_mean.assign(self.target_spec.stats[name]))
            else:
                self.cached_target_mean_update_op.append(
                    cached_target_mean.assign(target_mean))

            self.cached_target_mean[name] = cached_target_mean
            self.q_stats[name] = ComputeQStats(f(self.q_samples),
                                               cached_target_mean)
            self.neutra_stats[name] = ComputeChainStats(
                f(self.neutra_outputs.x_chain), cached_target_mean,
                self.test_num_leapfrog_steps)

        # Training
        self.train_q_samples = self.q_x_train.sample(self.train_batch_size)
        self.train_log_q_x = self.q_x_train.log_prob(self.train_q_samples)
        self.kl_q_p = tf.reduce_mean(
            self.train_log_q_x - self.target.log_prob(self.train_q_samples))

        loss = self.kl_q_p
        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if reg_losses:
            tf.logging.info("Regularizing.")
            loss += tf.add_n(reg_losses)
        self.loss = tf.check_numerics(loss, "Loss has NaNs")

        self.global_step = tf.train.get_or_create_global_step()
        steps, factors = list(zip(*learning_rate_schedule))
        learning_rate = base_learning_rate * tf.train.piecewise_constant(
            self.global_step, steps, [1.0] + list(factors))
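        # tf.train.piecewise_constant takes N boundaries and N + 1 values,
        # hence the prepended 1.0: the base rate applies before the first
        # boundary. Worked example with an assumed schedule
        # learning_rate_schedule = [(1000, 0.1), (2000, 0.01)]:
        #   step < 1000          -> base_learning_rate * 1.0
        #   1000 <= step < 2000  -> base_learning_rate * 0.1
        #   step >= 2000         -> base_learning_rate * 0.01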

        opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
        self.train_op = opt.minimize(self.loss, global_step=self.global_step)

        tf.summary.scalar("kl_q_p", self.kl_q_p)
        tf.summary.scalar("loss", self.loss)

        self.init = [
            tf.global_variables_initializer(),
            tf.local_variables_initializer(),
            tf.print("Initializing variables")
        ]

        self.saver = tf.train.Saver()
        self.log_dir = log_dir
Code example #29
def evaluate_qmix(opponents: typing.List, mixture: typing.List):
    """ . """
    assert len(opponents) == len(mixture)
    name = "player1"
    env = GridWorldSoccer()

    # -------------------------------------------------------------------------
    # Train best-response to each pure-strategy opponent.
    logger.info("Training best-response against each pure-strategy.")
    best_responses = []
    replay_buffers = []
    best_response_paths = []
    for opponent_i, opponent in enumerate(opponents):
        logger.info(f"  - Training against opponent {opponent_i}")
        br_path = osp.join(settings.get_run_dir(),
                           f"v{opponent_i}.best_response.pkl")
        best_response_paths += [br_path]
        with gin.config_scope("pure"):
            response, replay_buffer = _train(
                br_path, opponent,
                SummaryWriter(logdir=osp.join(settings.get_run_dir(),
                                              f"br_vs_{opponent_i}")))
        best_responses += [response]
        replay_buffers += [replay_buffer]

    # -------------------------------------------------------------------------
    # Simulate the performance of QMixture.
    logger.info("Simulating the performance of the QMixture.")
    qmix = QMixture(mixture=mixture, q_funcs=best_responses)
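    # A minimal sketch of the weighting QMixture presumably applies (its
    # actual internals are not shown in this snippet):
    #   Q_mix(s, a) = sum_i mixture[i] * q_funcs[i](s, a)
    # i.e. the mixed policy acts with respect to the probability-weighted
    # combination of the per-opponent Q-functions.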

    # Save policy, for future evaluation.
    qmix_path = osp.join(settings.get_run_dir(), "qmix.pkl")
    torch.save(qmix, qmix_path, pickle_module=dill)

    qmix_rewards = []
    mixed_reward = 0.0
    reward_std = 0.0
    for opponent_i, opponent in enumerate(opponents):
        rewards, _ = simulate_profile(env=env,
                                      nn_att=qmix,
                                      nn_def=opponent,
                                      n_episodes=250,
                                      save_dir=None,
                                      summary_writer=None,
                                      raw_rewards=True)

        logger.info(
            f"  - Opponent {opponent_i} vs. QMix: {np.mean(rewards)}, {np.std(rewards)}"
        )
        qmix_rewards += [rewards]
        mixed_reward += mixture[opponent_i] * np.mean(rewards)
        reward_std += mixture[opponent_i]**2 * np.std(rewards)**2
    reward_std = np.sqrt(reward_std)
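    # The standard deviation above propagates the per-opponent variances
    # through the weighted sum, i.e. Var(sum_i w_i * R_i) = sum_i w_i**2 *
    # Var(R_i), which assumes the per-opponent estimates are independent.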
    logger.info(
        f"Expected reward against mixture opponent: {mixed_reward}, {reward_std}"
    )
    with open(osp.join(settings.get_run_dir(), "qmix.simulated_reward.pkl"),
              "wb") as outfile:
        dill.dump(mixed_reward, outfile)

    # -------------------------------------------------------------------------
    # Simulate the performance of QMixture with state frequencies.
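    # NOTE: The block below is deliberately disabled via a bare string
    # literal; it is kept for reference and never executes.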
    """
    logger.info("Simulating the performance of the QMixture with State-Frequency weighting.")
    qmix_statefreq = QMixtureStateFreq(mixture=mixture, q_funcs=best_responses, replay_buffers=replay_buffers)

    # Save policy, for future evaluation.
    qmix_statefreq_path = osp.join(settings.get_run_dir(), "qmix_statefreq.pkl")
    torch.save(qmix_statefreq, qmix_statefreq_path, pickle_module=dill)

    qmix_statefreq_rewards = []
    mixed_statefreq_reward = 0.0
    for opponent_i, opponent in enumerate(opponents):
        rewards, _ = simulate_profile(
            env=env,
            nn_att=qmix_statefreq,
            nn_def=opponent,
            n_episodes=250,
            save_dir=None,
            summary_writer=SummaryWriter(logdir=osp.join(settings.get_run_dir(), f"simulate_statefreq_vs_{opponent_i}")),
            raw_rewards=True)

        logger.info(f"  - Opponent {opponent_i}: {np.mean(rewards)}, {np.std(rewards)}")
        with open(osp.join(settings.get_run_dir(), f"qmix_statefreq.rewards_v{opponent_i}.pkl"), "wb") as outfile:
            dill.dump(rewards, outfile)
        qmix_statefreq_rewards += [rewards]
        mixed_statefreq_reward += mixture[opponent_i] * np.mean(rewards)
    logger.info(f"Expected reward against mixture opponent: {mixed_statefreq_reward}")
    dill.dump(mixed_statefreq_reward, open(osp.join(settings.get_run_dir(), "qmix_statefreq.simulated_reward.pkl"), "wb"))
    """
    # -------------------------------------------------------------------------
    # Train best-response to opponent mixture.
    logger.info("Training a best-response against the mixture opponent.")
    mixture_br_path = osp.join(settings.get_run_dir(),
                               "mixture.best_response.pkl")
    opponent_agent = Agent(mixture=mixture, policies=opponents)

    with gin.config_scope("mix"):
        mixture_br, _ = _train(
            mixture_br_path, opponent_agent,
            SummaryWriter(
                logdir=osp.join(settings.get_run_dir(), "br_vs_mixture")))

    # -------------------------------------------------------------------------
    # Evaluate the mixture policy against the individual opponent strategies.
    logger.info(
        "Evaluating the best-response trained against mixture opponents on pure-strategy opponents."
    )

    mix_br_reward = 0.0
    reward_std = 0.0
    for opponent_i, opponent in enumerate(opponents):
        rewards, _ = simulate_profile(env=env,
                                      nn_att=mixture_br,
                                      nn_def=opponent,
                                      n_episodes=250,
                                      save_dir=None,
                                      summary_writer=None,
                                      raw_rewards=True)

        logger.info(
            f"  - Opponent {opponent_i} vs. MixtureBR: {np.mean(rewards)}, {np.std(rewards)}"
        )
        mix_br_reward += mixture[opponent_i] * np.mean(rewards)
        reward_std += mixture[opponent_i]**2 * np.std(rewards)**2
    reward_std = np.sqrt(reward_std)
    logger.info(
        f"Expected reward for mixture best-response: {mix_br_reward}, {reward_std}"
    )

    # -------------------------------------------------------------------------
    # Evaluate pure-strategy-best-response policies against all opponents (all pure strategy + mixture).
    logger.info(
        "Evaluating pure-strategy-best-response against all opponent policies."
    )

    response_rewards = {}
    response_std = {}
    for opponent_i, opponent in enumerate(opponents):
        for response_i, best_response in enumerate(best_responses):
            rewards, _ = simulate_profile(env=env,
                                          nn_att=best_response,
                                          nn_def=opponent,
                                          n_episodes=250,
                                          save_dir=None,
                                          summary_writer=None,
                                          raw_rewards=True)

            logger.info(
                f"  - Opponent {opponent_i} vs. Best-Response {response_i}: {np.mean(rewards)}, {np.std(rewards)}"
            )
            if response_i not in response_rewards:
                response_rewards[response_i] = 0.0
                response_std[response_i] = 0.0
            response_rewards[response_i] += mixture[opponent_i] * np.mean(
                rewards)
            response_std[response_i] += mixture[opponent_i]**2 * np.std(
                rewards)**2

    for key, value in response_rewards.items():
        logger.info(
            f"Expected reward of response {key} against mixture: {value}, {np.sqrt(response_std[key])}"
        )
    logger.info("Finished.")
Code example #30
def GinWrapper(*args, **kwargs):
    # `function` and `saved_scopes` are free variables here: this fragment is
    # the inner wrapper of an enclosing decorator (see the sketch below).
    with gin.config_scope(saved_scopes):
        return function(*args, **kwargs)
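
# A minimal sketch of the enclosing decorator this fragment appears to come
# from. Assumptions: the decorator name is invented here, and capturing the
# active scope via gin.current_scope_str() is illustrative rather than taken
# from the source.
import gin


def preserve_gin_scope(function):
    """Re-enters the gin config scope that was active at decoration time."""
    saved_scopes = gin.current_scope_str()

    def GinWrapper(*args, **kwargs):
        with gin.config_scope(saved_scopes):
            return function(*args, **kwargs)

    return GinWrapper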