Esempio n. 1
0
    def train(self,
              batch_size=256,
              epochs=100,
              early_stop_patience=4,
              num_workers=8,
              train_epoch_frac=1.0,
              valid_epoch_frac=1.0,
              train_samples_per_epoch=None,
              validation_samples=None,
              train_batch_sampler=None,
              tensorboard=True):
        """Train the model, save it and log the metrics of the best epoch.

        Only the first entry of `self.valid_dataset` is used for validation
        during training.

        Args:
          batch_size: batch size of the validation iterator and of the
            training iterator (unless `train_batch_sampler` is given). Also
            used to convert sample counts into step counts.
          epochs: maximum number of epochs
          early_stop_patience: early stopping patience
          num_workers: how many workers to use in parallel
          train_epoch_frac: if smaller than 1, then make the epoch shorter.
            Ignored when `train_samples_per_epoch` is given.
          valid_epoch_frac: same as train_epoch_frac for the validation
            dataset. Ignored when `validation_samples` is given.
          train_samples_per_epoch: if given, the number of training steps per
            epoch is `int(train_samples_per_epoch / batch_size)` (at least 1)
          validation_samples: if given, the number of validation steps is
            `int(validation_samples / batch_size)` (at least 1)
          train_batch_sampler: batch Sampler for training. Useful for say
            Stratified sampling
          tensorboard: if True, tensorboard output will be added

        Raises:
          ValueError: if the validation dataset used for training is empty
        """
        # Validate first: checking the size before any iterator is created or
        # consumed guarantees the explicit error below instead of a failure
        # somewhere inside the data loading machinery.
        valid_dataset = self.valid_dataset[0][1]  # take the first one
        if len(valid_dataset) == 0:
            raise ValueError("len(self.valid_dataset) == 0")

        if train_batch_sampler is not None:
            # The sampler decides the batch composition, hence the neutral
            # shuffle / batch_size / drop_last values.
            train_it = self.train_dataset.batch_train_iter(
                shuffle=False,
                batch_size=1,
                drop_last=None,
                batch_sampler=train_batch_sampler,
                num_workers=num_workers)
        else:
            train_it = self.train_dataset.batch_train_iter(
                batch_size=batch_size, shuffle=True, num_workers=num_workers)
        # Pull one batch up front.
        # NOTE(review): presumably this starts the workers / surfaces data
        # loading errors before training starts -- confirm.
        next(train_it)

        valid_it = valid_dataset.batch_train_iter(batch_size=batch_size,
                                                  shuffle=True,
                                                  num_workers=num_workers)
        next(valid_it)

        tb = [TensorBoard(log_dir=self.output_dir)] if tensorboard else []

        if self.wandb_run is not None:
            from wandb.keras import WandbCallback
            # save_model=False: the model is saved explicitly after training
            wcp = [WandbCallback(save_model=False)]
        else:
            wcp = []

        # NOTE: the step counts are always derived from `batch_size`, also
        # when `train_batch_sampler` produces batches of a different size.
        if train_samples_per_epoch is None:
            train_steps_per_epoch = max(
                int(len(self.train_dataset) / batch_size * train_epoch_frac),
                1)
        else:
            train_steps_per_epoch = max(
                int(train_samples_per_epoch / batch_size), 1)

        if validation_samples is None:
            # parametrize with valid_epoch_frac
            validation_steps = max(
                int(len(valid_dataset) / batch_size * valid_epoch_frac), 1)
        else:
            validation_steps = max(int(validation_samples / batch_size), 1)

        # train the model
        callbacks = [
            EarlyStopping(patience=early_stop_patience,
                          restore_best_weights=True),
            CSVLogger(self.history_path)
        ] + tb + wcp
        self.model.fit_generator(
            train_it,
            epochs=epochs,
            steps_per_epoch=train_steps_per_epoch,
            validation_data=valid_it,
            validation_steps=validation_steps,
            callbacks=callbacks)
        # EarlyStopping is asked to restore the best weights
        # (restore_best_weights=True), so the model is saved without
        # reloading a checkpoint.
        self.model.save(self.ckp_file)

        # log metrics from the best epoch (lowest val_loss in the CSV history)
        try:
            dfh = pd.read_csv(self.history_path)
            m = dict(dfh.iloc[dfh.val_loss.idxmin()])
            if self.cometml_experiment is not None:
                self.cometml_experiment.log_multiple_metrics(
                    m, prefix="best-epoch/")
            if self.wandb_run is not None:
                self.wandb_run.summary.update(
                    flatten(prefix_dict(m, prefix="best-epoch/"),
                            separator='/'))
        except FileNotFoundError as e:
            # best effort: a missing history file must not fail the training
            logger.warning(e)
Esempio n. 2
0
def gin_train(gin_files,
              output_dir,
              gin_bindings='',
              gpu=0,
              memfrac=0.45,
              framework='tf',
              cometml_project="",
              wandb_project="",
              remote_dir="",
              run_id=None,
              note_params="",
              force_overwrite=False):
    """Train a model using gin-config

    Args:
      gin_files: comma separated list of gin files
      output_dir: where to store the results. Note: a subdirectory `run_id`
        will be created in `output_dir`.
      gin_bindings: semicolon (`;`) separated list of additional gin-bindings
        to use
      gpu: which gpu to use. Example: gpu=1. If None, no TF session is
        created by this function.
      memfrac: what fraction of the GPU's memory to use
      framework: which framework to use. Available: tf
      cometml_project: comet_ml project name. Example: Avsecz/basepair.
        If not specified, cometml will not get used
      wandb_project: wandb `<entity>/<project>` name. Example: Avsecz/test.
        If not specified, wandb will not be used
      remote_dir: additional path to the remote directory. Can be an s3 path.
        Example: `s3://mybucket/model1/exp1`
      run_id: manual run id. If not specified, it will be either randomly
        generated or re-used from wandb or comet.ml.
      note_params: take note of additional key=value pairs.
        Example: --note-params note='my custom note',feature_set=this
      force_overwrite: if True, the output directory will be overwritten

    Returns:
      The return value of `train(...)` called with the final output
      directory, remote directory and experiment trackers.

    Raises:
      ValueError: if `<output_dir>/<run_id>/config.gin` already exists and
        `force_overwrite` is False
    """

    # make modules located in the current working directory importable
    sys.path.append(os.getcwd())
    if cometml_project:
        logger.info("Using comet.ml")
        workspace, project_name = cometml_project.split("/")
        cometml_experiment = Experiment(project_name=project_name,
                                        workspace=workspace)
        # TODO - get the experiment id
        # specify output_dir to that directory
    else:
        cometml_experiment = None

    if wandb_project:
        assert "/" in wandb_project
        entity, project = wandb_project.split("/")
        if wandb is None:
            logger.warning("wandb not installed. Not using it")
            wandb_run = None
        else:
            wandb._set_stage_dir("./")  # Don't prepend wandb to output file
            if run_id is not None:
                wandb.init(project=project,
                           dir=output_dir,
                           entity=entity,
                           resume=run_id)
            else:
                # automatically set the output
                wandb.init(project=project, entity=entity, dir=output_dir)
            wandb_run = wandb.run
            logger.info("Using wandb")
            print(wandb_run)
    else:
        wandb_run = None

    # update the output directory
    if run_id is None:
        if wandb_run is not None:
            run_id = os.path.basename(wandb_run.dir)
        elif cometml_experiment is not None:
            run_id = cometml_experiment.id
        else:
            # random run_id
            run_id = str(uuid4())
    output_dir = os.path.join(output_dir, run_id)
    if remote_dir:
        remote_dir = os.path.join(remote_dir, run_id)
    if wandb_run is not None:
        # make sure the output directory is the same
        # wandb_run._dir = os.path.normpath(output_dir)  # This doesn't work
        # assert os.path.normpath(wandb_run.dir) == os.path.normpath(output_dir)
        # TODO - fix this assertion-> the output directories should be the same
        # in order for snakemake to work correctly
        pass
    # -----------------------------

    if os.path.exists(os.path.join(output_dir, 'config.gin')):
        if force_overwrite:
            # both fragments must be f-strings, otherwise {output_dir} is
            # logged literally
            logger.info(
                f"config.gin already exists in the output "
                f"directory {output_dir}. Removing the whole directory.")
            import shutil
            shutil.rmtree(output_dir)
        else:
            raise ValueError(f"Output directory {output_dir} shouldn't exist!")
    os.makedirs(output_dir,
                exist_ok=True)  # make the output directory. It shouldn't exist

    # add logging to the file
    add_file_logging(output_dir, logger)

    # The conditional `import gin.tf` below makes `gin` a local name of this
    # function. Import it unconditionally so that `gin` is bound (and the
    # config can be parsed) for every value of `framework`.
    import gin
    if framework == 'tf':
        import gin.tf  # noqa: F401 -- imported for its side effects
        if gpu is not None:
            logger.info(f"Using gpu: {gpu}, memory fraction: {memfrac}")
            create_tf_session(gpu, per_process_gpu_memory_fraction=memfrac)

    gin.parse_config_files_and_bindings(gin_files.split(","),
                                        bindings=gin_bindings.split(";"),
                                        skip_unknown=False)

    # write note_params.json
    if note_params:
        logger.info(f"note_params: {note_params}")
        note_params_dict = kv_string2dict(note_params)
    else:
        note_params_dict = dict()
    write_json(note_params_dict,
               os.path.join(output_dir, "note_params.json"),
               sort_keys=True,
               indent=2)

    # comet - log environment
    if cometml_experiment is not None:
        # log other parameters
        cometml_experiment.log_multiple_params(dict(gin_files=gin_files,
                                                    gin_bindings=gin_bindings,
                                                    output_dir=output_dir,
                                                    gpu=gpu),
                                               prefix='cli/')
        cometml_experiment.log_multiple_params(note_params_dict)

        exp_url = f"https://www.comet.ml/{cometml_experiment.workspace}/{cometml_experiment.project_name}/{cometml_experiment.id}"
        logger.info("Comet.ml url: " + exp_url)
        # write the information about comet.ml experiment
        write_json(
            {
                "url": exp_url,
                "key": cometml_experiment.id,
                "project": cometml_experiment.project_name,
                "workspace": cometml_experiment.workspace
            },
            os.path.join(output_dir, "cometml.json"),
            sort_keys=True,
            indent=2)

    # wandb - log environment
    if wandb_run is not None:
        write_json(
            {
                "url": wandb_run.get_url(),
                "key": wandb_run.id,
                "project": wandb_run.project,
                "path": wandb_run.path,
                "group": wandb_run.group
            },
            os.path.join(output_dir, "wandb.json"),
            sort_keys=True,
            indent=2)
        # store general configs
        wandb_run.config.update(
            prefix_dict(dict(gin_files=gin_files,
                             gin_bindings=gin_bindings,
                             output_dir=output_dir,
                             gpu=gpu),
                        prefix='cli/'))
        wandb_run.config.update(note_params_dict)

    if remote_dir:
        # upload what has been written so far; this also exercises the
        # upload path before the (long) training starts
        logger.info("Test file upload to: {}".format(remote_dir))
        upload_dir(output_dir, remote_dir)
    return train(output_dir=output_dir,
                 remote_dir=remote_dir,
                 cometml_experiment=cometml_experiment,
                 wandb_run=wandb_run)
Esempio n. 3
0
    def evaluate(self,
                 metric,
                 batch_size=256,
                 num_workers=8,
                 eval_train=False,
                 eval_skip=(),
                 save=True,
                 **kwargs):
        """Evaluate the model on the validation set(s)

        `self.valid_dataset` is a list of `(name, dataset)` or
        `(name, dataset, metric)` tuples. Predictions are computed batch by
        batch, concatenated and passed to the metric together with the labels.

        Args:
          metric: default metric, called as `metric(labels, preds)`. Used for
            datasets that don't carry their own metric.
          batch_size:
          num_workers:
          eval_train: if True, also compute the evaluation metrics on the
            training set (stored under the name 'train')
          eval_skip: names of the datasets to skip
          save: save the json file to the output directory
          **kwargs: ignored; a warning is logged when any are provided

        Returns:
          dict: `self.metrics` updated with one entry per evaluated dataset

        Raises:
          ValueError: if a dataset entry is not a tuple of 2 or 3 elements
        """
        from copy import deepcopy

        if len(kwargs) > 0:
            logger.warning(
                f"Extra kwargs were provided to trainer.evaluate: {kwargs}")
        # contruct a list of dataset to evaluate
        if eval_train:
            eval_datasets = [('train', self.train_dataset)
                             ] + self.valid_dataset
        else:
            eval_datasets = self.valid_dataset

        # Skip datasets by name. Entries are kept as they are (instead of
        # being unpacked into exactly two values), so skipping also works
        # for (name, dataset, metric) entries.
        try:
            if len(eval_skip) > 0:
                eval_datasets = [d for d in eval_datasets
                                 if d[0] not in eval_skip]
        except (TypeError, IndexError, KeyError):
            # best effort: evaluate everything if the entries can't be indexed
            logger.warning(
                f"eval datasets don't contain tuples. Unable to skip them using {eval_skip}"
            )

        metric_res = OrderedDict()
        for d in eval_datasets:
            if len(d) == 2:
                dataset_name, dataset = d
                eval_metric = metric  # use the default eval metric
            elif len(d) == 3:
                # specialized evaluation metric was passed
                dataset_name, dataset, eval_metric = d
            else:
                # TODO - this should be made more explicit with classes
                raise ValueError(
                    "Valid dataset needs to be a list of tuples of 2 or 3 elements"
                    "(name, dataset) or (name, dataset, metric)")
            logger.info(f"Evaluating dataset: {dataset_name}")
            lpreds = []
            llabels = []
            # `total` only drives the progress bar
            for inputs, targets in tqdm(
                    dataset.batch_train_iter(cycle=False,
                                             num_workers=num_workers,
                                             batch_size=batch_size),
                    total=len(dataset) // batch_size):
                lpreds.append(self.model.predict_on_batch(inputs))
                # keep a copy of the labels so the batch can be released
                llabels.append(deepcopy(targets))
                del inputs
                del targets
            preds = numpy_collate_concat(lpreds)
            labels = numpy_collate_concat(llabels)
            # drop the per-batch lists before computing the metric
            del lpreds
            del llabels
            metric_res[dataset_name] = eval_metric(labels, preds)

        if save:
            write_json(metric_res, self.evaluation_path, indent=2)
            logger.info("Saved metrics to {}".format(self.evaluation_path))

        if self.cometml_experiment is not None:
            self.cometml_experiment.log_multiple_metrics(
                flatten(metric_res, separator='/'), prefix="eval/")

        if self.wandb_run is not None:
            self.wandb_run.summary.update(
                flatten(prefix_dict(metric_res, prefix="eval/"),
                        separator='/'))
        # evaluation results take precedence over same-named self.metrics
        metric_res = {**self.metrics, **metric_res}
        return metric_res
Esempio n. 4
0
    def evaluate(self,
                 metric,
                 batch_size=256,
                 num_workers=8,
                 eval_train=False,
                 eval_skip=(),
                 save=True,
                 **kwargs):
        """Evaluate the model on the validation set(s)

        `self.valid_dataset` is a list of `(name, dataset)` or
        `(name, dataset, metric)` tuples. Each dataset is evaluated through
        `self.seq_model.evaluate`.

        Args:
          metric: a function accepting (y_true, y_pred) and returning the
            evaluation metric(s). NOTE: it is not used by this method;
            `(name, dataset)` entries are evaluated with `eval_metric=None`
            and `(name, dataset, metric)` entries with their own metric.
          batch_size:
          num_workers:
          eval_train: if True, also compute the evaluation metrics on the
            training set (stored under the name 'train')
          eval_skip: names of the datasets to skip
          save: save the json file to the output directory
          **kwargs: ignored; a warning is logged when any are provided

        Returns:
          OrderedDict: dataset name -> result of `self.seq_model.evaluate`

        Raises:
          ValueError: if a dataset entry is not a tuple of 2 or 3 elements
        """
        if len(kwargs) > 0:
            logger.warning(
                f"Extra kwargs were provided to trainer.evaluate(): {kwargs}")
        # Save the complete model -> HACK
        self.seq_model.save(os.path.join(self.output_dir, 'seq_model.pkl'))

        # contruct a list of dataset to evaluate
        if eval_train:
            eval_datasets = [('train', self.train_dataset)
                             ] + self.valid_dataset
        else:
            eval_datasets = self.valid_dataset

        # Skip some datasets for evaluation. Entries are kept as they are
        # (instead of being unpacked into exactly two values), so skipping
        # also works for (name, dataset, metric) entries and their metric
        # is preserved.
        try:
            if len(eval_skip) > 0:
                logger.info(f"Using eval_skip: {eval_skip}")
                eval_datasets = [d for d in eval_datasets
                                 if d[0] not in eval_skip]
        except (TypeError, IndexError, KeyError):
            # best effort: evaluate everything if the entries can't be indexed
            logger.warning(
                f"eval datasets don't contain tuples. Unable to skip them using {eval_skip}"
            )

        metric_res = OrderedDict()
        for d in eval_datasets:
            if len(d) == 2:
                dataset_name, dataset = d
                eval_metric = None  # Ignore the provided metric
            elif len(d) == 3:
                # specialized evaluation metric was passed
                dataset_name, dataset, eval_metric = d
            else:
                raise ValueError(
                    "Valid dataset needs to be a list of tuples of 2 or 3 elements"
                    "(name, dataset) or (name, dataset, metric)")
            logger.info(f"Evaluating dataset: {dataset_name}")
            metric_res[dataset_name] = self.seq_model.evaluate(
                dataset,
                eval_metric=eval_metric,
                num_workers=num_workers,
                batch_size=batch_size)
        if save:
            write_json(metric_res, self.evaluation_path, indent=2)
            logger.info("Saved metrics to {}".format(self.evaluation_path))

        if self.cometml_experiment is not None:
            self.cometml_experiment.log_multiple_metrics(
                flatten(metric_res, separator='/'), prefix="eval/")

        if self.wandb_run is not None:
            self.wandb_run.summary.update(
                flatten(prefix_dict(metric_res, prefix="eval/"),
                        separator='/'))

        return metric_res