Example 1
    def evaluate(self, batch_size=256, shuffle=True, num_workers=8, save=True):
        """Evaluate the model on the validation set.

        Args:
          batch_size: batch size used when loading the validation data
          shuffle: whether to shuffle the data (only used by the commented-out iterator alternative)
          num_workers: number of parallel workers used to load the data
          save: if True, write the metrics to `self.evaluation_path`
        """
        print("Started loading validation dataset")
        X_valid, y_valid = self.valid_dataset.load_all(batch_size=batch_size,
                                                       num_workers=num_workers)
        # Alternative: pull a single large batch from the training-style iterator
        # it = self.valid_dataset.batch_train_iter(batch_size=batch_size,
        #                                          shuffle=shuffle,
        #                                          num_workers=num_workers)
        # X_valid, y_valid = next(it)
        print("Finished loading validation dataset")
        metric_res = self.model.score(X_valid, y_valid)

        if save:
            write_json(metric_res, self.evaluation_path, indent=2)

        if self.cometml_experiment:
            self.cometml_experiment.log_multiple_metrics(flatten(metric_res), prefix="best/")

        return metric_res
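
These snippets rely on project-level helpers (`write_json`, `flatten`, and later `prefix_dict`, `load_pickle`, `upload_dir`, `kv_string2dict`) whose implementations are not shown. Below is a minimal sketch of the two used most often, assuming `write_json` is a thin wrapper around `json.dump` and `flatten` collapses a nested metrics dict into one level; the real implementations may differ.

import json


def write_json(obj, path, **json_kwargs):
    # Dump `obj` as JSON to `path`, forwarding options such as indent and sort_keys.
    with open(path, "w") as f:
        json.dump(obj, f, **json_kwargs)


def flatten(d, separator="/", prefix=""):
    # Recursively flatten a nested dict, joining keys with `separator`.
    out = {}
    for k, v in d.items():
        key = f"{prefix}{separator}{k}" if prefix else str(k)
        if isinstance(v, dict):
            out.update(flatten(v, separator=separator, prefix=key))
        else:
            out[key] = v
    return out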
Example 2
def log_gin_config(output_dir, cometml_experiment=None, wandb_run=None):
    gin_config_str = gin.operative_config_str()

    print("Used config: " + "-" * 40)
    print(gin_config_str)
    print("-" * 52)
    with open(os.path.join(output_dir, "config.gin"), "w") as f:
        f.write(gin_config_str)
    # drop `import` lines and parse the remaining gin config into a dictionary
    gin_config_str = "\n".join(
        [x for x in gin_config_str.split("\n") if not x.startswith("import")])
    gin_config_dict = yaml.load(
        gin_config_str.replace("@", "").replace(" = %", ": ").replace(" = ", ": "),
        Loader=yaml.SafeLoader)
    write_json(gin_config_dict,
               os.path.join(output_dir, "config.gin.json"),
               sort_keys=True,
               indent=2)

    if cometml_experiment is not None:
        # log the parsed gin config as experiment parameters
        cometml_experiment.log_multiple_params(gin_config_dict)

    if wandb_run is not None:
        # This allows the config values to be displayed on the dashboard
        wandb_run.config.update(
            {k.replace(".", "/"): v
             for k, v in gin_config_dict.items()})
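
The string rewrite above converts gin's `key = value` syntax into YAML `key: value` lines so the operative config can be parsed into a flat dict. A small standalone illustration (the config lines below are made up):

import yaml

example_gin = "train.batch_size = 256\ntrain.model = @MyModel()\ntrain.lr = %LR"
yaml_str = example_gin.replace("@", "").replace(" = %", ": ").replace(" = ", ": ")
print(yaml.load(yaml_str, Loader=yaml.SafeLoader))
# -> {'train.batch_size': 256, 'train.model': 'MyModel()', 'train.lr': 'LR'}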
Example 3
    def evaluate(self, metric, scaler_path=None, eval_type=np.float32, save=True):
        """Evaluate the model on the validation set.

        Args:
          metric: a function accepting (y_true, y_pred) and returning the evaluation metric(s)
          scaler_path: optional path to a pickled scaler; if given, it is applied to X_valid
          eval_type: evaluation dtype; any value other than np.float32 triggers a float16 downcast
          save: if True, write the metrics to `self.evaluation_path`
        """
        print("Started loading validation dataset")
        X_valid, y_valid = self.valid_dataset.load_all()

        if scaler_path:
            scaler = load_pickle(scaler_path)
            print("Started scaling X.")
            X_infl = X_valid.astype(np.float32)
            X_infl = scaler.transform(X_infl)

            if eval_type is not np.float32:
                # downcast to float16; clip so entries that overflow to inf
                # (float16 max is ~65504) become finite again
                X_valid = X_infl.astype(np.float16)
                if isinstance(X_valid, csr_matrix):
                    X_valid.data = np.minimum(X_valid.data, 65500)
                else:
                    X_valid = np.minimum(X_valid, 65500)
                del X_infl
            else:
                # keep the scaled float32 features
                X_valid = X_infl
            print("Finished scaling X.")

        print("Finished loading validation dataset. Shape: ", X_valid.shape, "True values:", y_valid.sum()/y_valid.shape[0])
        
        y_pred = self.model.predict(X_valid)
        metric_res = metric(y_valid, y_pred)
        print("metric_res", metric_res, np.amax(X_valid))

        if save:
            write_json(metric_res, self.evaluation_path, indent=2)

        if self.cometml_experiment:
            self.cometml_experiment.log_multiple_metrics(flatten(metric_res), prefix="best/")

        return metric_res
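
The float16 branch above clips values to 65500 because float16 cannot represent numbers above roughly 65504; entries that overflow to inf during the downcast are pulled back to a finite value. A minimal standalone illustration:

import numpy as np

x = np.array([1.0, 70000.0], dtype=np.float32)
x16 = x.astype(np.float16)     # 70000 overflows to inf (float16 max is ~65504)
print(x16)
print(np.minimum(x16, 65500))  # the inf entry is clipped back to a finite value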
Example 4
    def evaluate(self, batch_size=256, shuffle=True, num_workers=8, save=True):
        """Evaluate the model on the validation sets.

        Args:
          batch_size: batch size used when iterating over the validation datasets
          shuffle: unused in this implementation
          num_workers: unused in this implementation
          save: if True, write the metrics to `self.evaluation_path`
        """
        print("Started loading validation dataset")
        # one batch iterator per validation dataset
        valid_iterators = [d.batch_iter(batch_size) for d in self.valid_datasets]

        metric_res_b = []
        print("Loading and evaluating")
        n_batches = len(self.valid_datasets[0]) // batch_size
        for batch_num in tqdm(range(n_batches)):
            batch = None
            for it in valid_iterators:
                # Concatenating features from all kipoi datasets;
                # we assume that the variants have been curated,
                # i.e. they are the same and in the exact same order.
                if batch is None:
                    batch = next(it)
                else:
                    batch = np.concatenate([batch, next(it)], axis=1)
            # first column holds the label, the remaining columns the features
            X_batch, y_batch = batch[:, 1:], batch[:, 0]
            metric_res_b.append(self.model.test_on_batch(X_batch, y_batch))

        metric_res = np.average(metric_res_b)

        if save:
            write_json(metric_res, self.evaluation_path, indent=2)

        if self.cometml_experiment:
            self.cometml_experiment.log_multiple_metrics(flatten(metric_res), prefix="best/")

        return metric_res
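
The batch merging above (one iterator per dataset, features concatenated column-wise, label taken from column 0) can be shown in isolation; the arrays below are made up for illustration:

import numpy as np

# batches from two hypothetical datasets describing the same two variants
batch_a = np.array([[0.0, 1.0, 2.0],   # column 0: label, remaining columns: features
                    [1.0, 3.0, 4.0]])
batch_b = np.array([[5.0, 6.0],
                    [7.0, 8.0]])

batch = np.concatenate([batch_a, batch_b], axis=1)
X_batch, y_batch = batch[:, 1:], batch[:, 0]
print(X_batch.shape)  # (2, 4)
print(y_batch)        # [0. 1.]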
Example 5
    def evaluate(self,
                 metric,
                 batch_size=256,
                 num_workers=8,
                 eval_train=False,
                 eval_skip=(),
                 save=True,
                 **kwargs):
        """Evaluate the model on the validation set
        Args:
          metrics: a list or a dictionary of metrics
          batch_size:
          num_workers:
          eval_train: if True, also compute the evaluation metrics on the training set
          save: save the json file to the output directory
        """
        if len(kwargs) > 0:
            logger.warning(
                f"Extra kwargs were provided to trainer.evaluate: {kwargs}")
        # construct a list of datasets to evaluate
        if eval_train:
            eval_datasets = [('train', self.train_dataset)] + self.valid_dataset
        else:
            eval_datasets = self.valid_dataset

        # skip some datasets during evaluation
        try:
            if len(eval_skip) > 0:
                eval_datasets = [t for t in eval_datasets
                                 if t[0] not in eval_skip]
        except TypeError:
            logger.warning(
                f"eval datasets don't contain tuples. Unable to skip them using {eval_skip}"
            )

        from copy import deepcopy

        metric_res = OrderedDict()
        for d in eval_datasets:
            if len(d) == 2:
                dataset_name, dataset = d
                eval_metric = metric  # use the default eval metric
            elif len(d) == 3:
                # specialized evaluation metric was passed
                dataset_name, dataset, eval_metric = d
            else:
                # TODO - this should be made more explicit with classes
                raise ValueError(
                    "Valid dataset needs to be a list of tuples of 2 or 3 elements: "
                    "(name, dataset) or (name, dataset, metric)")
            logger.info(f"Evaluating dataset: {dataset_name}")
            lpreds = []
            llabels = []
            for inputs, targets in tqdm(
                    dataset.batch_train_iter(cycle=False,
                                             num_workers=num_workers,
                                             batch_size=batch_size),
                    total=len(dataset) // batch_size):
                lpreds.append(self.model.predict_on_batch(inputs))
                llabels.append(deepcopy(targets))
                del inputs
                del targets
            preds = numpy_collate_concat(lpreds)
            labels = numpy_collate_concat(llabels)
            del lpreds
            del llabels
            metric_res[dataset_name] = eval_metric(labels, preds)

        if save:
            write_json(metric_res, self.evaluation_path, indent=2)
            logger.info("Saved metrics to {}".format(self.evaluation_path))

        if self.cometml_experiment is not None:
            self.cometml_experiment.log_multiple_metrics(flatten(
                metric_res, separator='/'),
                                                         prefix="eval/")

        if self.wandb_run is not None:
            self.wandb_run.summary.update(
                flatten(prefix_dict(metric_res, prefix="eval/"),
                        separator='/'))
        metric_res = {**self.metrics, **metric_res}
        return metric_res
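
The trainer above expects `self.valid_dataset` to be a list of `(name, dataset)` or `(name, dataset, metric)` tuples, and `eval_skip` filters them by name. A minimal sketch of that convention with placeholder names, datasets, and metric (all made up):

def default_metric(y_true, y_pred):
    # placeholder metric for the sketch
    return {"n": len(y_true)}


eval_datasets = [
    ("train", "train_dataset_placeholder"),
    ("valid", "valid_dataset_placeholder"),
    ("valid-extra", "extra_dataset_placeholder", default_metric),
]
eval_skip = ("valid-extra",)

# keep only entries whose name is not listed in eval_skip;
# entries may be (name, dataset) or (name, dataset, metric)
kept = [t for t in eval_datasets if t[0] not in eval_skip]
print([t[0] for t in kept])  # -> ['train', 'valid']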
Example 6
def gin_train(gin_files,
              output_dir,
              gin_bindings='',
              gpu=0,
              memfrac=0.45,
              framework='tf',
              cometml_project="",
              wandb_project="",
              remote_dir="",
              run_id=None,
              note_params="",
              force_overwrite=False):
    """Train a model using gin-config

    Args:
      gin_files: comma-separated list of gin files
      output_dir: where to store the results. Note: a subdirectory `run_id`
        will be created in `output_dir`.
      gin_bindings: semicolon-separated list of additional gin bindings to use
      gpu: which gpu to use. Example: gpu=1
      memfrac: what fraction of the GPU's memory to use
      framework: which framework to use. Available: tf
      cometml_project: comet_ml project name. Example: Avsecz/basepair.
        If not specified, cometml will not get used
      wandb_project: wandb `<entity>/<project>` name. Example: Avsecz/test.
        If not specified, wandb will not be used
      remote_dir: additional path to the remote directory. Can be an s3 path.
        Example: `s3://mybucket/model1/exp1`
      run_id: manual run id. If not specified, it will be either randomly
        generated or re-used from wandb or comet.ml.
      note_params: take note of additional key=value pairs.
        Example: --note-params note='my custom note',feature_set=this
      force_overwrite: if True, the output directory will be overwritten
    """

    sys.path.append(os.getcwd())
    if cometml_project:
        logger.info("Using comet.ml")
        workspace, project_name = cometml_project.split("/")
        cometml_experiment = Experiment(project_name=project_name,
                                        workspace=workspace)
        # TODO - get the experiment id
        # specify output_dir to that directory
    else:
        cometml_experiment = None

    if wandb_project:
        assert "/" in wandb_project
        entity, project = wandb_project.split("/")
        if wandb is None:
            logger.warn("wandb not installed. Not using it")
            wandb_run = None
        else:
            wandb._set_stage_dir("./")  # Don't prepend wandb to output file
            if run_id is not None:
                wandb.init(project=project,
                           dir=output_dir,
                           entity=entity,
                           resume=run_id)
            else:
                # automatically set the output
                wandb.init(project=project, entity=entity, dir=output_dir)
            wandb_run = wandb.run
            logger.info("Using wandb")
            print(wandb_run)
    else:
        wandb_run = None

    # update the output directory
    if run_id is None:
        if wandb_run is not None:
            run_id = os.path.basename(wandb_run.dir)
        elif cometml_experiment is not None:
            run_id = cometml_experiment.id
        else:
            # random run_id
            run_id = str(uuid4())
    output_dir = os.path.join(output_dir, run_id)
    if remote_dir:
        remote_dir = os.path.join(remote_dir, run_id)
    if wandb_run is not None:
        # make sure the output directory is the same
        # wandb_run._dir = os.path.normpath(output_dir)  # This doesn't work
        # assert os.path.normpath(wandb_run.dir) == os.path.normpath(output_dir)
        # TODO - fix this assertion-> the output directories should be the same
        # in order for snakemake to work correctly
        pass
    # -----------------------------

    if os.path.exists(os.path.join(output_dir, 'config.gin')):
        if force_overwrite:
            logger.info(
                f"config.gin already exists in the output "
                f"directory {output_dir}. Removing the whole directory.")
            import shutil
            shutil.rmtree(output_dir)
        else:
            raise ValueError(f"Output directory {output_dir} shouldn't exist!")
    # create the output directory; it may already exist, but without a config.gin
    os.makedirs(output_dir, exist_ok=True)

    # add logging to the file
    add_file_logging(output_dir, logger)

    if framework == 'tf':
        import gin.tf
        if gpu is not None:
            logger.info(f"Using gpu: {gpu}, memory fraction: {memfrac}")
            create_tf_session(gpu, per_process_gpu_memory_fraction=memfrac)

    gin.parse_config_files_and_bindings(gin_files.split(","),
                                        bindings=gin_bindings.split(";"),
                                        skip_unknown=False)

    # write note_params.json
    if note_params:
        logger.info(f"note_params: {note_params}")
        note_params_dict = kv_string2dict(note_params)
    else:
        note_params_dict = dict()
    write_json(note_params_dict,
               os.path.join(output_dir, "note_params.json"),
               sort_keys=True,
               indent=2)

    # comet - log environment
    if cometml_experiment is not None:
        # log other parameters
        cometml_experiment.log_multiple_params(dict(gin_files=gin_files,
                                                    gin_bindings=gin_bindings,
                                                    output_dir=output_dir,
                                                    gpu=gpu),
                                               prefix='cli/')
        cometml_experiment.log_multiple_params(note_params_dict)

        exp_url = f"https://www.comet.ml/{cometml_experiment.workspace}/{cometml_experiment.project_name}/{cometml_experiment.id}"
        logger.info("Comet.ml url: " + exp_url)
        # write the information about comet.ml experiment
        write_json(
            {
                "url": exp_url,
                "key": cometml_experiment.id,
                "project": cometml_experiment.project_name,
                "workspace": cometml_experiment.workspace
            },
            os.path.join(output_dir, "cometml.json"),
            sort_keys=True,
            indent=2)

    # wandb - log environment
    if wandb_run is not None:
        write_json(
            {
                "url": wandb_run.get_url(),
                "key": wandb_run.id,
                "project": wandb_run.project,
                "path": wandb_run.path,
                "group": wandb_run.group
            },
            os.path.join(output_dir, "wandb.json"),
            sort_keys=True,
            indent=2)
        # store general configs
        wandb_run.config.update(
            prefix_dict(dict(gin_files=gin_files,
                             gin_bindings=gin_bindings,
                             output_dir=output_dir,
                             gpu=gpu),
                        prefix='cli/'))
        wandb_run.config.update(note_params_dict)

    if remote_dir:
        logger.info("Test file upload to: {}".format(remote_dir))
        upload_dir(output_dir, remote_dir)
    return train(output_dir=output_dir,
                 remote_dir=remote_dir,
                 cometml_experiment=cometml_experiment,
                 wandb_run=wandb_run)
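
Because the separators are easy to mix up, here is a minimal sketch of the argument format `gin_train` parses: `gin_files` is split on commas and `gin_bindings` on semicolons (the file names and bindings below are hypothetical):

gin_files = "model.gin,data.gin"
gin_bindings = "train.epochs = 10;train.lr = 0.001"

print(gin_files.split(","))     # -> ['model.gin', 'data.gin']
print(gin_bindings.split(";"))  # -> ['train.epochs = 10', 'train.lr = 0.001']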
Example 7
    def evaluate(self,
                 metric,
                 batch_size=256,
                 num_workers=8,
                 eval_train=False,
                 eval_skip=(),
                 save=True,
                 **kwargs):
        """Evaluate the model on the validation set
        Args:
          metric: a function accepting (y_true, y_pred) and returning the evaluation metric(s)
          batch_size:
          num_workers:
          eval_train: if True, also compute the evaluation metrics on the training set
          save: save the json file to the output directory
        """
        if len(kwargs) > 0:
            logger.warning(
                f"Extra kwargs were provided to trainer.evaluate(): {kwargs}")
        # Save the complete model -> HACK
        self.seq_model.save(os.path.join(self.output_dir, 'seq_model.pkl'))

        # construct a list of datasets to evaluate
        if eval_train:
            eval_datasets = [('train', self.train_dataset)] + self.valid_dataset
        else:
            eval_datasets = self.valid_dataset

        # skip some datasets during evaluation
        try:
            if len(eval_skip) > 0:
                logger.info(f"Using eval_skip: {eval_skip}")
                eval_datasets = [t for t in eval_datasets
                                 if t[0] not in eval_skip]
        except TypeError:
            logger.warning(
                f"eval datasets don't contain tuples. Unable to skip them using {eval_skip}"
            )

        metric_res = OrderedDict()
        for d in eval_datasets:
            if len(d) == 2:
                dataset_name, dataset = d
                eval_metric = None  # Ignore the provided metric
            elif len(d) == 3:
                # specialized evaluation metric was passed
                dataset_name, dataset, eval_metric = d
            else:
                raise ValueError(
                    "Valid dataset needs to be a list of tuples of 2 or 3 elements: "
                    "(name, dataset) or (name, dataset, metric)")
            logger.info(f"Evaluating dataset: {dataset_name}")
            metric_res[dataset_name] = self.seq_model.evaluate(
                dataset,
                eval_metric=eval_metric,
                num_workers=num_workers,
                batch_size=batch_size)
        if save:
            write_json(metric_res, self.evaluation_path, indent=2)
            logger.info("Saved metrics to {}".format(self.evaluation_path))

        if self.cometml_experiment is not None:
            self.cometml_experiment.log_multiple_metrics(flatten(
                metric_res, separator='/'),
                                                         prefix="eval/")

        if self.wandb_run is not None:
            self.wandb_run.summary.update(
                flatten(prefix_dict(metric_res, prefix="eval/"),
                        separator='/'))

        return metric_res