コード例 #1
0
ファイル: predict.py プロジェクト: amoliu/pytorch_fnet
def save_predictions_csv(
    path_csv: str,
    pred_records: List[Dict],
    dataset: Any,
) -> None:
    """Saves csv with metadata of predictions.

    If a CSV already exists at ``path_csv`` it is merged with the new
    records via ``DataFrame.combine_first``: non-null values from the new
    records take precedence, and rows/columns only present in the old file
    are kept. Columns are sorted by name before saving.

    Parameters
    ----------
    path_csv
        CSV save path. Its parent directory is created if missing; a bare
        filename (no directory component) is written relative to the current
        working directory.
    pred_records
        List of metadata for each prediction. Each record must contain an
        ``'index'`` key, which becomes the DataFrame index.
    dataset
        Dataset from where signal-target pairs were retrieved. When it is a
        ``FnetDataset``, its ``df`` metadata is joined onto the predictions;
        prediction columns that overlap with ``dataset.df`` get a ``'_pre'``
        suffix.

    """
    df = pd.DataFrame(pred_records).set_index('index')
    if isinstance(dataset, FnetDataset):
        # For FnetDataset, add additional metadata. The index is renamed to
        # match dataset.df so the join aligns on the same index.
        df = df.rename_axis(dataset.df.index.name).join(
            dataset.df, lsuffix='_pre'
        )
    if os.path.exists(path_csv):
        df_old = pd.read_csv(path_csv)
        col_index = df_old.columns[0]  # Assumes first col is index col
        df_old = df_old.set_index(col_index)
        df = df.combine_first(df_old)
    df = df.sort_index(axis=1)
    dirname = os.path.dirname(path_csv)
    # dirname is '' for a bare filename: the current directory always exists
    # and os.makedirs('') would raise, so only create non-empty paths.
    if dirname and not os.path.exists(dirname):
        # exist_ok guards against the directory appearing between the
        # existence check and the creation (e.g. a concurrent run).
        os.makedirs(dirname, exist_ok=True)
        print('Created:', dirname)
    retry_if_oserror(df.to_csv)(path_csv)
    print('Saved:', path_csv)
コード例 #2
0
def save_csv(path_csv, df: pd.DataFrame) -> None:
    """Write ``df`` to ``path_csv``, merging with any CSV already there.

    When a file already exists at ``path_csv`` it is loaded and its first
    column is used as the index. It is then combined with ``df`` via
    ``DataFrame.combine_first``, so non-null values in ``df`` win and
    rows/columns only present in the old file are kept. Columns are sorted
    by name before the result is written.
    """
    merged = df
    if os.path.exists(path_csv):
        existing = pd.read_csv(path_csv)
        # The first column is assumed to hold the index written previously.
        existing = existing.set_index(existing.columns[0])
        merged = merged.combine_first(existing)
    ordered = merged.sort_index(axis=1)
    write_csv = retry_if_oserror(ordered.to_csv)
    write_csv(path_csv)
    print('Saved:', path_csv)
コード例 #3
0
    def save(self, path_save: str):
        """Saves model to disk.

        The model is moved with ``to_gpu(-1)`` (presumably CPU — confirm
        against ``to_gpu``) while its state is saved. Afterwards it is moved
        back to the devices in ``self.gpu_ids``, even if saving raises.

        Parameters
        ----------
        path_save
            Filename to which model is saved. Must not be an existing
            directory.

        """
        assert not os.path.isdir(path_save), (
            f'path_save is a directory: {path_save}'
        )
        curr_gpu_ids = self.gpu_ids
        self.to_gpu(-1)
        try:
            retry_if_oserror(torch.save)(self.get_state(), path_save)
        finally:
            # Restore the original device placement even when saving fails,
            # so a failed save does not silently leave the model moved.
            self.to_gpu(curr_gpu_ids)
コード例 #4
0
    def save(self, path_save: str):
        """Saves model to disk.

        The model is moved with ``to_gpu(-1)`` (presumably CPU — confirm
        against ``to_gpu``) while its state is saved. Afterwards it is moved
        back to the devices in ``self.gpu_ids``, even if saving raises.

        Parameters
        ----------
        path_save
            Filename to which model is saved. Its parent directory is created
            if missing; a bare filename is saved relative to the current
            working directory.

        """
        dirname = os.path.dirname(path_save)
        # dirname is '' for a bare filename; os.makedirs('') would raise, and
        # the current directory always exists, so skip it.
        if dirname and not os.path.exists(dirname):
            # exist_ok guards against the directory appearing between the
            # existence check and the creation.
            os.makedirs(dirname, exist_ok=True)
            logger.info(f"Created: {dirname}")
        curr_gpu_ids = self.gpu_ids
        self.to_gpu(-1)
        try:
            retry_if_oserror(torch.save)(self.get_state(), path_save)
        finally:
            # Restore the original device placement even when saving fails.
            self.to_gpu(curr_gpu_ids)