Example No. 1
    def intp_tensors(self, preact_only=False, graph=None):
        """Return the required interpretation tensors (scalars)

        Note: Since we are predicting a track,
            we should return a single scalar here
        """
        if graph is None:
            graph = tf.get_default_graph()

        preact = graph.get_tensor_by_name(self.pre_act)
        postact = graph.get_tensor_by_name(self.post_act)

        # Construct the profile summary ops
        preact_tensors = self.profile_contrib(preact)
        postact_tensors = dict_prefix_key(self.profile_contrib(postact),
                                          'output_')

        if self.activation is None:
            # the post-activation doesn't
            # have any specific meaning when
            # we don't use any activation function
            return preact_tensors

        if preact_only:
            return preact_tensors
        else:
            return {**preact_tensors, **postact_tensors}
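
A minimal usage sketch (assumptions: TF1-style graph/session semantics, an `interpreter` instance of the class above, and a `feed_dict` mapping the model's input placeholders to data; all three names are placeholders):

import tensorflow as tf

graph = tf.get_default_graph()
tensors = interpreter.intp_tensors(preact_only=False, graph=graph)
with tf.Session(graph=graph) as sess:
    # run the pre- and post-activation summary ops in a single call
    values = sess.run(tensors, feed_dict=feed_dict)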
Example No. 2
    def evaluate(self, metric, batch_size=256, num_workers=8, eval_train=False, eval_skip=[], save=True, **kwargs):
        """Evaluate the model on the validation set
        Args:
          metric: a function accepting (y_true, y_pred) and returning the evaluation metric(s)
          batch_size:
          num_workers:
          eval_train: if True, also compute the evaluation metrics on the training set
          save: save the json file to the output directory
        """
        if len(kwargs) > 0:
            logger.warn(f"Extra kwargs were provided to trainer.evaluate(): {kwargs}")
        # Save the complete model -> HACK
        self.seq_model.save(os.path.join(self.output_dir, 'seq_model.pkl'))

        # construct a list of datasets to evaluate
        if eval_train:
            eval_datasets = [('train', self.train_dataset)] + self.valid_dataset
        else:
            eval_datasets = self.valid_dataset

        # skip some datasets for evaluation
        try:
            if len(eval_skip) > 0:
                logger.info(f"Using eval_skip: {eval_skip}")
                eval_datasets = [(k, v) for k, v in eval_datasets if k not in eval_skip]
        except Exception:
            logger.warn(f"eval datasets don't contain tuples. Unable to skip them using {eval_skip}")

        metric_res = OrderedDict()
        for d in eval_datasets:
            if len(d) == 2:
                dataset_name, dataset = d
                eval_metric = None  # Ignore the provided metric
            elif len(d) == 3:
                # specialized evaluation metric was passed
                dataset_name, dataset, eval_metric = d
            else:
                raise ValueError("Valid dataset needs to be a list of tuples of 2 or 3 elements"
                                 "(name, dataset) or (name, dataset, metric)")
            logger.info(f"Evaluating dataset: {dataset_name}")
            metric_res[dataset_name] = self.seq_model.evaluate(dataset,
                                                               eval_metric=eval_metric,
                                                               num_workers=num_workers,
                                                               batch_size=batch_size)
        if save:
            write_json(metric_res, self.evaluation_path, indent=2)
            logger.info("Saved metrics to {}".format(self.evaluation_path))

        if self.cometml_experiment is not None:
            self.cometml_experiment.log_metrics(flatten(metric_res, separator='/'), prefix="eval/")

        if self.wandb_run is not None:
            self.wandb_run.summary.update(flatten(dict_prefix_key(metric_res, prefix="eval/"), separator='/'))

        return metric_res
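
A rough usage sketch (the `trainer` instance and `my_metric` below are placeholders; `my_metric` follows the `(y_true, y_pred)` signature from the docstring):

def my_metric(y_true, y_pred):
    # toy metric returning a dict of scores
    return {"mse": float(((y_true - y_pred) ** 2).mean())}

metrics = trainer.evaluate(metric=my_metric,
                           batch_size=128,
                           num_workers=4,
                           eval_train=True,
                           eval_skip=['train-subset'])  # dataset name is illustrative
print(metrics)  # OrderedDict keyed by dataset name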
Example No. 3
def bpnet_train(dataspec,
                output_dir,
                premade='bpnet9',
                config=None,
                override='',
                gpu=0,
                memfrac_gpu=0.45,
                num_workers=8,
                vmtouch=False,
                in_memory=False,
                wandb_project="",
                cometml_project="",
                run_id=None,
                note_params="",
                overwrite=False):
    """Train a model using gin-config

    Output files:
      train.log - log file
      model.h5 - Keras model HDF5 file
      seqmodel.pkl - Serialized SeqModel. This is the main trained model.
      eval-report.ipynb/.html - evaluation report containing training loss curves and some example model predictions.
        You can specify your own ipynb using `--override='report_template.name="my-template.ipynb"'`.
      model.gin -> copied from the input
      dataspec.yml -> copied from the input
    """
    cometml_experiment, wandb_run, output_dir = start_experiment(
        output_dir=output_dir,
        cometml_project=cometml_project,
        wandb_project=wandb_project,
        run_id=run_id,
        note_params=note_params,
        overwrite=overwrite)
    # remember the executed command
    write_json(
        {
            "dataspec": dataspec,
            "output_dir": output_dir,
            "premade": premade,
            "config": config,
            "override": override,
            "gpu": gpu,
            "memfrac_gpu": memfrac_gpu,
            "num_workers": num_workers,
            "vmtouch": vmtouch,
            "in_memory": in_memory,
            "wandb_project": wandb_project,
            "cometml_project": cometml_project,
            "run_id": run_id,
            "note_params": note_params,
            "overwrite": overwrite
        },
        os.path.join(output_dir, 'bpnet-train.kwargs.json'),
        indent=2)

    # copy dataspec.yml and input config file over
    if config is not None:
        shutil.copyfile(config, os.path.join(output_dir, 'input-config.gin'))

    # parse and validate the dataspec
    ds = DataSpec.load(dataspec)
    related_dump_yaml(ds.abspath(), os.path.join(output_dir, 'dataspec.yml'))
    if vmtouch:
        if shutil.which('vmtouch') is None:
            logger.warn(
                "vmtouch is currently not installed. "
                "--vmtouch disabled. Please install vmtouch to enable it")
        else:
            # use vmtouch to load all file to memory
            ds.touch_all_files()

    # --------------------------------------------
    # Parse the config file
    # import gin.tf
    if gpu is not None:
        logger.info(f"Using gpu: {gpu}, memory fraction: {memfrac_gpu}")
        create_tf_session(gpu, per_process_gpu_memory_fraction=memfrac_gpu)

    gin_files = _get_gin_files(premade, config)

    # infer different hyper-parameters from the dataspec file
    if len(ds.bias_specs) > 0:
        use_bias = True
        if len(ds.bias_specs) > 1:
            # TODO - allow multiple bias tracks
            # - split the heads separately
            raise ValueError("Only a single bias track is currently supported")

        bias = [v for k, v in ds.bias_specs.items()][0]
        n_bias_tracks = len(bias.tracks)
    else:
        use_bias = False
        n_bias_tracks = 0
    tasks = list(ds.task_specs)
    # TODO - handle multiple track widths?
    tracks_per_task = [len(v.tracks) for k, v in ds.task_specs.items()][0]
    # figure out the right hyper-parameters
    dataspec_bindings = [
        f'dataspec="{dataspec}"', f'use_bias={use_bias}',
        f'n_bias_tracks={n_bias_tracks}', f'tracks_per_task={tracks_per_task}',
        f'tasks={tasks}'
    ]

    gin.parse_config_files_and_bindings(
        gin_files,
        bindings=dataspec_bindings + override.split(";"),
        # NOTE: custom files were inserted right after
        # the user's config file and before the `override`
        # parameters specified at the command line.
        # This allows the user to disable the bias correction
        # even if it is specified in the config file
        skip_unknown=False)

    # --------------------------------------------
    # Remember the parsed configs

    # comet - log environment
    if cometml_experiment is not None:
        # log other parameters
        cometml_experiment.log_parameters(dict(premade=premade,
                                               config=config,
                                               override=override,
                                               gin_files=gin_files,
                                               gpu=gpu),
                                          prefix='cli/')

    # wandb - log environment
    if wandb_run is not None:

        # store general configs
        wandb_run.config.update(
            dict_prefix_key(dict(premade=premade,
                                 config=config,
                                 override=override,
                                 gin_files=gin_files,
                                 gpu=gpu),
                            prefix='cli/'))

    return train(
        output_dir=output_dir,
        cometml_experiment=cometml_experiment,
        wandb_run=wandb_run,
        num_workers=num_workers,
        in_memory=in_memory,
        # to execute the sub-notebook
        memfrac_gpu=memfrac_gpu,
        gpu=gpu)
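
A hypothetical direct call (the file paths and the gin override bindings below are placeholders):

bpnet_train(dataspec='dataspec.yml',
            output_dir='output/',
            premade='bpnet9',
            override='train.epochs=50;train.batch_size=128',  # gin bindings, illustrative only
            gpu=0,
            memfrac_gpu=0.45,
            num_workers=8,
            in_memory=True,
            overwrite=True)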
Example No. 4
def start_experiment(output_dir,
                     cometml_project="",
                     wandb_project="",
                     run_id=None,
                     note_params="",
                     extra_kwargs=None,
                     overwrite=False):
    """Start a model training experiment. This will create a new output directory
    and setup the experiment management handles
    """
    sys.path.append(os.getcwd())
    if cometml_project:
        logger.info("Using comet.ml")
        if Experiment is None:
            raise ImportError("Comet.ml could not be imported")
        workspace, project_name = cometml_project.split("/")
        cometml_experiment = Experiment(project_name=project_name,
                                        workspace=workspace)
        # TODO - get the experiment id
        # specify output_dir to that directory
    else:
        cometml_experiment = None

    if wandb_project:
        assert "/" in wandb_project
        entity, project = wandb_project.split("/")
        if wandb is None:
            logger.warn("wandb not installed. Not using it")
            wandb_run = None
        else:
            logger.info("Using wandb. Running wandb.init()")
            wandb._set_stage_dir("./")  # Don't prepend wandb to output file
            if run_id is not None:
                wandb.init(project=project,
                           dir=output_dir,
                           entity=entity,
                           reinit=True,
                           resume=run_id)
            else:
                # automatically set the output
                wandb.init(project=project,
                           entity=entity,
                           reinit=True,
                           dir=output_dir)
            wandb_run = wandb.run
            if wandb_run is None:
                logger.warn("Wandb run is None")
            print(wandb_run)
    else:
        wandb_run = None

    # update the output directory
    if run_id is None:
        if wandb_run is not None:
            run_id = os.path.basename(wandb_run.dir)
        elif cometml_experiment is not None:
            run_id = cometml_experiment.id
        else:
            # random run_id
            run_id = str(uuid4())
    output_dir = os.path.join(output_dir, run_id)

    if wandb_run is not None:
        # make sure the output directory is the same
        # wandb_run._dir = os.path.normpath(output_dir)  # This doesn't work
        # assert os.path.normpath(wandb_run.dir) == os.path.normpath(output_dir)
        # TODO - fix this assertion-> the output directories should be the same
        # in order for snakemake to work correctly
        pass
    # -----------------------------

    if os.path.exists(os.path.join(output_dir, 'config.gin')):
        if overwrite:
            logger.info(
                f"config.gin already exists in the output "
                f"directory {output_dir}. Removing the whole directory.")
            shutil.rmtree(output_dir)
        else:
            raise ValueError(f"Output directory {output_dir} shouldn't exist!")
    os.makedirs(output_dir,
                exist_ok=True)  # make the output directory (removed above if it already existed)

    # add logging to the file
    add_file_logging(output_dir, logger)

    # write note_params.json
    if note_params:
        logger.info(f"note_params: {note_params}")
        note_params_dict = kv_string2dict(note_params)
    else:
        note_params_dict = dict()
    write_json(note_params_dict,
               os.path.join(output_dir, "note_params.json"),
               sort_keys=True,
               indent=2)

    if cometml_experiment is not None:
        cometml_experiment.log_parameters(note_params_dict)
        cometml_experiment.log_parameters(dict(output_dir=output_dir),
                                          prefix='cli/')

        exp_url = f"https://www.comet.ml/{cometml_experiment.workspace}/{cometml_experiment.project_name}/{cometml_experiment.id}"
        logger.info("Comet.ml url: " + exp_url)
        # write the information about comet.ml experiment
        write_json(
            {
                "url": exp_url,
                "key": cometml_experiment.id,
                "project": cometml_experiment.project_name,
                "workspace": cometml_experiment.workspace
            },
            os.path.join(output_dir, "cometml.json"),
            sort_keys=True,
            indent=2)

    if wandb_run is not None:
        wandb_run.config.update(note_params_dict)
        write_json(
            {
                "url": wandb_run.get_url(),
                "key": wandb_run.id,
                "project": wandb_run.project,
                "path": wandb_run.path,
                "group": wandb_run.group
            },
            os.path.join(output_dir, "wandb.json"),
            sort_keys=True,
            indent=2)
        wandb_run.config.update(
            dict_prefix_key(dict(output_dir=output_dir), prefix='cli/'))

    return cometml_experiment, wandb_run, output_dir
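
A minimal sketch of how the returned handles could be used (the wandb project string is a placeholder; passing empty strings disables the respective trackers):

cometml_experiment, wandb_run, output_dir = start_experiment(
    output_dir='output/',
    wandb_project='my-entity/my-project',  # "<entity>/<project>", placeholder
    cometml_project='',                    # empty -> comet.ml disabled
    note_params='note=smoke-test',         # key=value string, parsed by kv_string2dict
    overwrite=True)
print(output_dir)  # the base output_dir extended with the generated run_id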
Example No. 5
    def train(self,
              batch_size=256,
              epochs=100,
              early_stop_patience=4,
              num_workers=8,
              train_epoch_frac=1.0,
              valid_epoch_frac=1.0,
              train_samples_per_epoch=None,
              validation_samples=None,
              train_batch_sampler=None,
              tensorboard=True):
        """Train the model
        Args:
          batch_size:
          epochs:
          patience: early stopping patience
          num_workers: how many workers to use in parallel
          train_epoch_frac: if smaller than 1, then make the epoch shorter
          valid_epoch_frac: same as train_epoch_frac for the validation dataset
          train_batch_sampler: batch Sampler for training. Useful for say Stratified sampling
          tensorboard: if True, tensorboard output will be added
        """

        if train_batch_sampler is not None:
            train_it = self.train_dataset.batch_train_iter(shuffle=False,
                                                           batch_size=1,
                                                           drop_last=None,
                                                           batch_sampler=train_batch_sampler,
                                                           num_workers=num_workers)
        else:
            train_it = self.train_dataset.batch_train_iter(batch_size=batch_size,
                                                           shuffle=True,
                                                           num_workers=num_workers)
        next(train_it)
        valid_dataset = self.valid_dataset[0][1]  # take the first one
        valid_it = valid_dataset.batch_train_iter(batch_size=batch_size,
                                                  shuffle=True,
                                                  num_workers=num_workers)
        next(valid_it)

        if tensorboard:
            tb = [TensorBoard(log_dir=self.output_dir)]
        else:
            tb = []

        if self.wandb_run is not None:
            from wandb.keras import WandbCallback
            wcp = [WandbCallback(save_model=False)]  # we save the model using ModelCheckpoint
        else:
            wcp = []

        # train the model
        if len(valid_dataset) == 0:
            raise ValueError("len(self.valid_dataset) == 0")

        if train_samples_per_epoch is None:
            train_steps_per_epoch = max(int(len(self.train_dataset) / batch_size * train_epoch_frac), 1)
        else:
            train_steps_per_epoch = max(int(train_samples_per_epoch / batch_size), 1)

        if validation_samples is None:
            # parametrize with valid_epoch_frac
            validation_steps = max(int(len(valid_dataset) / batch_size * valid_epoch_frac), 1)
        else:
            validation_steps = max(int(validation_samples / batch_size), 1)

        self.model.fit_generator(
            train_it,
            epochs=epochs,
            steps_per_epoch=train_steps_per_epoch,
            validation_data=valid_it,
            validation_steps=validation_steps,
            callbacks=[
                EarlyStopping(
                    patience=early_stop_patience,
                    restore_best_weights=True
                ),
                CSVLogger(self.history_path)
            ] + tb + wcp
        )
        self.model.save(self.ckp_file)

        # log metrics from the best epoch
        try:
            dfh = pd.read_csv(self.history_path)
            m = dict(dfh.iloc[dfh.val_loss.idxmin()])
            if self.cometml_experiment is not None:
                self.cometml_experiment.log_metrics(m, prefix="best-epoch/")
            if self.wandb_run is not None:
                self.wandb_run.summary.update(flatten(dict_prefix_key(m, prefix="best-epoch/"), separator='/'))
        except FileNotFoundError as e:
            logger.warning(e)
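
A hypothetical call on a trainer instance wrapping a compiled Keras model:

trainer.train(batch_size=128,
              epochs=50,
              early_stop_patience=5,
              num_workers=8,
              train_epoch_frac=0.5,  # use half of the training set per epoch
              tensorboard=True)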