Example #1
def run_tasks(exp_dir: str, log_path: str, args):
    Utils.cmlog = functools.partial(Utils.safe_log, path='{}/{}/cm.log'.format(exp_dir, EXP_LOGS_DIR))
    Utils.write_json(vars(args), f'{exp_dir}/{EXP_LOGS_DIR}/cmdline-args.json')

    model_class = modelnamedict[args.model]
    dataset_names = [args.dataset]
    mstd_group_by: Optional[List[str]] = ['height', 'sigquant', 'softquant', 'black_box', 'gamma', 'use_last_model']

    for dataset_name in dataset_names:
        normalize_x_kwargs = None
        normalize_y_kwargs = None

        if args.datanorm:
            normalize_x_kwargs = {'category': 'zscore'}

            if DataLoader.stats[dataset_name]['n_classes'] <= 0:
                normalize_y_kwargs = {'category': 'minmax', 'interval': (0, 1)}

        get_dataset_kwargs = {
            'dataset_name': dataset_name,
            'normalize_x_kwargs': normalize_x_kwargs,
            'normalize_y_kwargs': normalize_y_kwargs,
            'shuffle_seed': 1
        }

        train, validn, test = DataLoader.get_dataset(**get_dataset_kwargs)
        setattr(args, 'train_size', train._x.shape[0])
        if JOIN_TRAIN_VAL:
            train = Utils.join_datasets(train, validn)

        assert len(cast(np.ndarray, train['y']).shape) == 1
        assert len(cast(np.ndarray, validn['y']).shape) == 1
        assert len(cast(np.ndarray, test['y']).shape) == 1

        if LOG_BASE_PERF:
            logger = Logger(log_path)
            logger.log('\n{}\n'.format(DataLoader.get_base_perf(train, validn, test)))
            logger.close()

        if SAVE_PLOTS:
            train.visualize(save_path='{}/train-data.png'.format(exp_dir))
            test.visualize(save_path='{}/test-data.png'.format(exp_dir))
        mss = HpUtils.get_model_search_set(model_class, dataset_name, args)
        if mss.model_class in [DGTPredictor, LinPredictor]:
            train, validn, test = train.to_tensor(), validn.to_tensor(), test.to_tensor()

        get_best_model(
            log_path, mss, train, validn, test, get_dataset_kwargs,
            exp_dir=exp_dir, devices_info=args.DEVICES_INFO, show_hpsearch_stats=HPSEARCH_STATS
        )

        if args.compute_mstd:
            nseeds = NSEEDS  # same value regardless of model class
            for i in os.listdir('{}/{}'.format(exp_dir, EXP_LOGS_DIR)):
                if '{}-search-summary'.format(dataset_name) in i:
                    MstdUtils.summary_file('{}/{}/{}'.format(exp_dir, EXP_LOGS_DIR, i), args, group_by=mstd_group_by, metric='validn_acc', nseeds=nseeds, nshuffleseeds=args.ndshuffles)
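
The normalization settings above ('zscore' for features, 'minmax' onto (0, 1) for regression targets) presumably map to the usual transforms inside DataLoader.normalize_all_datasets, whose code is not shown here. A minimal standalone NumPy sketch of those two transforms (function names are illustrative, not the project's API):

import numpy as np

def zscore_normalize(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # Standardize each feature column to zero mean and unit variance.
    return (x - x.mean(axis=0)) / (x.std(axis=0) + eps)

def minmax_normalize(y: np.ndarray, interval=(0.0, 1.0)) -> np.ndarray:
    # Rescale targets linearly into the given interval.
    lo, hi = interval
    return lo + (y - y.min()) * (hi - lo) / (y.max() - y.min())

x = np.random.randn(100, 5).astype(np.float32) * 3 + 7
y = np.random.rand(100).astype(np.float32) * 50
print(zscore_normalize(x).mean(axis=0).round(3), minmax_normalize(y).min(), minmax_normalize(y).max())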
Example #2
    def plot_sat_info(src_file: str, dst_file: str, start_point: int = 0):
        sat_info = Utils.read_json(src_file)
        int_nodes = 2**sat_info['height'] - 1
        node_safe = True
        if int_nodes > SAT_INFO_NODE_MAX:
            int_nodes = 1
            node_safe = False

        fig, ax = plt.subplots(nrows=int_nodes,
                               ncols=1,
                               figsize=(15, 6 * int_nodes),
                               squeeze=False)
        msize = 6
        sat_info_np = {k: np.array(v) for k, v in sat_info.items()}

        ax[0][0].set_title(
            'Fraction of predicate layer\'s outputs (given all training data) whose absolute difference from 0.5 falls in various intervals. \nAn interval closer to 0.5 implies an activation value closer to 0/1, which implies a smaller gradient'
        )
        for i in range(int_nodes):
            ax[i][0].set_xlabel('Epochs')
            ax[i][0].set_ylabel('Fraction of pred outputs at {}'.format(
                'all internal nodes combined' if not node_safe else 'node {}'.
                format(i)))
            for j in sat_info_np.keys():
                if j != 'epoch' and j != 'height':
                    ax[i][0].plot(sat_info_np['epoch'][start_point:],
                                  sat_info_np[j][start_point:, i],
                                  label=j,
                                  marker='.',
                                  markersize=10)
            ax[i][0].grid()
            ax[i][0].legend()

        plt.savefig(dst_file)
        plt.close(fig)
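
From the indexing above (sat_info_np['epoch'] is 1-D, every other key is indexed as [epoch, node]), src_file is presumably a JSON dict with a 'height' scalar, an 'epoch' list, and one 2-D array per saturation interval. A hedged sketch of a compatible toy file; the interval key names are purely illustrative:

import json
import numpy as np

height, n_epochs = 3, 20
int_nodes = 2 ** height - 1
sat_info = {
    'height': height,
    'epoch': list(range(n_epochs)),
    # One row per epoch, one column per internal node; these key names are made up here.
    'frac_0.4_to_0.5': np.random.rand(n_epochs, int_nodes).tolist(),
    'frac_0.3_to_0.4': np.random.rand(n_epochs, int_nodes).tolist(),
}
with open('sat-info.json', 'w') as f:
    json.dump(sat_info, f)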
Example #3
File: xhp.py Project: ajay0/dtnet
    def __init__(self, model_class: Type[LearnablePredictor], hps: List[Dict[str, Any]], expand: bool=True, use_lforb: bool=False):
        self.hps: List[Dict[str, Any]]
        if expand:
            self.hps = Utils.get_list_dict_combinations(hps)
        else:
            self.hps = hps

        self.model_class = model_class
        self.search_size = len(self.hps)
        self.use_lforb = use_lforb # use ulm=True run to compute stats for ulm=False
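
Utils.get_list_dict_combinations is not shown here; from its use it presumably expands a dict of candidate-value lists into every concrete hyperparameter combination. A self-contained sketch of that assumed behavior using itertools:

import itertools
from typing import Any, Dict, List

def get_list_dict_combinations(hps: List[Dict[str, List[Any]]]) -> List[Dict[str, Any]]:
    # Expand each dict of candidate lists into the Cartesian product of settings.
    out: List[Dict[str, Any]] = []
    for grid in hps:
        keys = list(grid.keys())
        for values in itertools.product(*(grid[k] for k in keys)):
            out.append(dict(zip(keys, values)))
    return out

print(get_list_dict_combinations([{'lr': [0.1, 0.01], 'height': [4, 6]}]))
# four configurations: every (lr, height) pair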
Example #4
    def log_stats(summary_file: str,
                  logs_dir: str,
                  config_idx: Optional[int] = None):
        log_path = '{}/{}'.format(logs_dir, MAIN_PROGRESS_FILE)
        logger = Logger(log_path)
        logger.log('config_idx={}\n'.format(config_idx))

        df = pd.read_csv(summary_file)

        summary_stats = {'config_idx': config_idx, 'metrics': {}}

        for idx, metric in enumerate(MstdUtils.metrics):
            if metric in df.columns:
                mean = df[metric].mean()
                median = df[metric].median()
                std = df[metric].std(ddof=0)
                mn = df[metric].min()
                mx = df[metric].max()
                nancount = int(df[metric].isna().sum())
                vals = df[metric].tolist()

                summary_stats['metrics'][metric] = {
                    'mean': mean,
                    'median': median,
                    'std': std,
                    'min': mn,
                    'max': mx,
                    'nancount': nancount,
                    'vals': vals
                }
                logger.log(
                    '{:<14} - mean:{:.5f}, median:{:.5f}, std:{:.5f}, min:{:.5f}, max:{:.5f}, nancount:{}\n'
                    .format(metric, mean, median, std, mn, mx, nancount),
                    stdout=True)
                if idx % 3 == 2:
                    logger.log('\n')

        logger.close()
        Utils.write_json(summary_stats,
                         '{}/{}'.format(logs_dir, MSTD_SUMMARY_STATS_FILE))
        return summary_stats
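
The per-metric statistics above are plain pandas aggregations over the search-summary CSV; a self-contained sketch on a toy frame (column names are illustrative):

import pandas as pd

df = pd.DataFrame({'train_acc': [0.91, 0.93, 0.92, None],
                   'test_acc': [0.88, 0.90, 0.89, 0.87]})
for metric in ['train_acc', 'test_acc']:
    col = df[metric]
    print('{:<10} mean={:.5f} median={:.5f} std={:.5f} min={:.5f} max={:.5f} nancount={}'.format(
        metric, col.mean(), col.median(), col.std(ddof=0), col.min(), col.max(),
        int(col.isna().sum())))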
Example #5
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 weight: Optional[torch.Tensor] = None,
                 bias: Optional[torch.Tensor] = None,
                 same: bool = False):
        super().__init__()

        self._l = nn.Linear(in_features, out_features, bias=False)

        if weight is not None:
            self._l.weight = nn.Parameter(weight, requires_grad=False)

        if bias is None:
            self._bias = Utils.get_initialized_bias(
                in_features, 1 if same else out_features)
        else:
            self._bias = nn.Parameter(bias, requires_grad=False)
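
The module above decouples the bias from nn.Linear and, when same=True, shares a single bias across all outputs. A minimal self-contained sketch of the same idea, with a zero-initialized bias standing in for Utils.get_initialized_bias (whose initialization scheme is not shown):

import torch
import torch.nn as nn

class LinearWithDecoupledBias(nn.Module):
    def __init__(self, in_features: int, out_features: int, same: bool = False):
        super().__init__()
        self._l = nn.Linear(in_features, out_features, bias=False)
        # One shared bias when same=True, otherwise one bias per output feature.
        self._bias = nn.Parameter(torch.zeros(1 if same else out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._l(x) + self._bias

m = LinearWithDecoupledBias(4, 3, same=True)
print(m(torch.randn(2, 4)).shape)  # torch.Size([2, 3])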
Example #6
    def get_base_perf(train: Dataset[np.ndarray], validn: Dataset[np.ndarray],
                      test: Dataset[np.ndarray]) -> str:
        assert cast(np.ndarray, train['x']).dtype in [np.float32]
        assert cast(np.ndarray, train['y']).dtype in [np.int64, np.float32]
        assert cast(np.ndarray,
                    validn['x']).dtype == cast(np.ndarray, train['x']).dtype
        assert cast(np.ndarray,
                    validn['y']).dtype == cast(np.ndarray, train['y']).dtype
        assert cast(np.ndarray,
                    test['x']).dtype == cast(np.ndarray, train['x']).dtype
        assert cast(np.ndarray,
                    test['y']).dtype == cast(np.ndarray, train['y']).dtype

        sep = '------------------\n'
        ret = ''
        ret += sep
        ret += ' Base Perf\n'
        ret += sep

        is_classification = cast(np.ndarray, train['y']).dtype == np.int64
        n_features, n_classes = DataLoader.stats[train.name][
            'n_features'], DataLoader.stats[train.name]['n_classes']

        if is_classification:
            mode = stats.mode(train['y'])[0][0]
            model = FixedConstantPredictor(n_features, n_classes, mode)
        else:
            mean = cast(np.ndarray, train['y']).mean()
            model = FixedConstantPredictor(n_features, n_classes, mean)
        model.acc_func, model.acc_func_type = Utils.get_acc_def(
            is_classification)

        ret += ' train_acc={:.5f}\n'.format(model.acc(train))
        ret += ' validn_acc={:.5f}\n'.format(model.acc(validn))
        ret += ' test_acc={:.5f}\n'.format(model.acc(test))

        ret += sep
        return ret
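
The base performance is just a constant predictor: the majority class for classification, the mean target for regression. A standalone NumPy sketch of that baseline (the regression "accuracy" depends on Utils.get_acc_def, which is not shown, so MSE is used below only as an assumption):

import numpy as np

def base_perf(y_train: np.ndarray, y_eval: np.ndarray) -> float:
    if y_train.dtype == np.int64:  # classification: always predict the majority class
        vals, counts = np.unique(y_train, return_counts=True)
        return float((y_eval == vals[counts.argmax()]).mean())
    # regression: always predict the training mean (scored as MSE here, by assumption)
    return float(((y_eval - y_train.mean()) ** 2).mean())

y_tr = np.array([0, 0, 1, 0, 1], dtype=np.int64)
y_te = np.array([0, 1, 0], dtype=np.int64)
print(base_perf(y_tr, y_te))  # 0.666...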
Example #7
def setup_config_dir(model: LearnablePredictor, dataset_name: str, dseed: int,
                     exp_dir: str, config_idx: int) -> Tuple[str, str, str]:

    config = '{}{}{}'.format(type(model).__name__, DESC_SEP, dataset_name)
    for k, v in model.get_hyperparams().items():
        config += '{}{}={}'.format(
            DESC_SEP, k,
            Utils.shorten_mid(str(v), begin_keep=1000, end_keep=1000) if
            (isinstance(v, list) or isinstance(v, np.ndarray)) else v)

    # TODO: config will be used to create directory and it may have illegal characters
    config_dir = f'{exp_dir}/{CONFIGS_ROOT_DIR}/{config_idx}{DESC_SEP}{dataset_name}-{dseed}{DESC_SEP}{type(model).__name__}'
    logs_dir = f'{config_dir}/{CONFIG_LOGS_DIR}'
    plots_dir = f'{config_dir}/{CONFIG_PLOTS_DIR}'
    os.mkdir(config_dir)
    os.mkdir(logs_dir)
    os.mkdir(plots_dir)
    model.log_f = Logger('{}/progress.log'.format(logs_dir)).log

    model.logs_dir = logs_dir
    model.plots_dir = plots_dir

    img_save_prefix = plots_dir
    return config, config_dir, img_save_prefix
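
On the TODO above: one hedged way to guard against illegal characters before using config as a directory name (not part of the original code, just an illustration):

import re

def sanitize_for_path(config: str, max_len: int = 200) -> str:
    # Replace anything outside a conservative whitelist and cap the length.
    return re.sub(r'[^A-Za-z0-9._=-]+', '_', config)[:max_len]

print(sanitize_for_path('DGTPredictor--mnist--lr=[0.1, 0.01]'))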
Example #8
    def config(config_dir: str,
               args,
               exp_root=None,
               desc: str = '',
               verify_row: Optional[pd.Series] = None,
               nseeds: int = NSEEDS,
               group_by_metric=None,
               nshuffleseeds: int = 1):
        time_str = str(time.time()).split(".")[0]

        # Load datasets and other data needed
        if config_dir[-1] in ['/', '\\']:
            config_dir = config_dir[:-1]

        with open('{}/{}.pkl'.format(config_dir, CONFIG_REPRO_DATA_FILE),
                  'rb') as f:
            repro_data = pickle.load(f)

        model_class = repro_data['model_class']
        hp = repro_data['hp']
        get_dataset_kwargs = repro_data['get_dataset_kwargs']

        # Compute mstd stats
        hps = {}
        for k, v in hp.items():
            hps[k] = [v]
        hps['seed'] = list(range(1, nseeds + 1))

        if verify_row is not None:
            assert repro_data['seed'] == verify_row['seed']
        if repro_data['seed'] not in hps['seed']:
            hps['seed'][0] = repro_data['seed']

        mss = ModelSearchSet(model_class, [hps])

        if exp_root is None:
            exp_root = f'{Utils.get_exp_dir(config_dir)}/{MSTD_DIR}/{desc}'
        try:
            os.makedirs(exp_root)
        except FileExistsError:
            pass

        desc = f'{desc.strip()}{DESC_SEP}{time_str}'
        shuffle_seeds = list(range(1, nshuffleseeds + 1))
        if get_dataset_kwargs['shuffle_seed'] not in shuffle_seeds:
            shuffle_seeds[0] = get_dataset_kwargs['shuffle_seed']
        orig_shuffle_seed = get_dataset_kwargs['shuffle_seed']

        local_mstd_summary_file = f'{exp_root}/local-meanstd-run-summary.csv'
        concat_local_mstd_summary_file = f'{exp_root}/../concat-local-meanstd-run-summary.csv'
        mstd_summary_file = f'{exp_root}/../meanstd-run-summary.csv'

        gethp_output = model_class(**hp).get_hyperparams()
        mstd_dict_aux = {}
        for shuffle_seed in shuffle_seeds:

            get_dataset_kwargs['shuffle_seed'] = shuffle_seed
            train_data, validn_data, test_data = DataLoader.get_dataset(
                **get_dataset_kwargs)
            if JOIN_TRAIN_VAL:
                train_data = Utils.join_datasets(train_data, validn_data)
            if model_class in [DGTPredictor, LinPredictor]:
                train_data, validn_data, test_data = train_data.to_tensor(
                ), validn_data.to_tensor(), test_data.to_tensor()

            # curr_desc = f'{get_dataset_kwargs["dataset_name"]}-{shuffle_seed}{DESC_SEP}{desc.split(DESC_SEP, maxsplit=1)[1]}'
            curr_desc = f'shuffle_seed-{shuffle_seed}'
            exp_dir = f'{exp_root}/exp{DESC_SEP}{curr_desc}{DESC_SEP}{time_str}'

            os.mkdir(exp_dir)
            os.mkdir('{}/{}'.format(exp_dir, CONFIGS_ROOT_DIR))
            logs_dir = '{}/{}'.format(exp_dir, EXP_LOGS_DIR)
            os.mkdir(logs_dir)
            log_path = '{}/{}'.format(logs_dir, MAIN_PROGRESS_FILE)

            get_best_model(log_path,
                           mss,
                           train_data,
                           validn_data,
                           test_data,
                           get_dataset_kwargs,
                           exp_dir=exp_dir,
                           devices_info=args.DEVICES_INFO,
                           show_hpsearch_stats=False)

            summary_file = '{}/{}/{}-search-summary{}{}.csv'.format(
                exp_dir, EXP_LOGS_DIR, get_dataset_kwargs['dataset_name'],
                DESC_SEP, time_str)

            # Verify
            if (shuffle_seed == orig_shuffle_seed) and (verify_row
                                                        is not None):
                df = pd.read_csv(summary_file)
                new = df.loc[df['seed'] == repro_data['seed']]
                assert len(new) == 1
                new = new.iloc[0]
                orig = verify_row
                assert set(orig.index) == set(new.index)

                sel = set([
                    'train_acc', 'validn_acc', 'test_acc', 'train_auc',
                    'validn_auc', 'test_auc', 'dt_train_acc', 'dt_validn_acc',
                    'dt_test_acc', 'dt_train_auc', 'dt_validn_auc',
                    'dt_test_auc', 'cdt_train_acc', 'cdt_validn_acc',
                    'cdt_test_acc'
                ]).intersection(set(orig.index)).intersection(set(new.index))
                diff = new.loc[sel] - orig.loc[sel]

                s = ''
                s += '\n{}\n'.format(Utils.get_padded_text('Verifying'))
                s += '\n--------\nOriginal\n--------\n{}'.format(orig)
                s += '\n---\nNew\n---\n{}'.format(new)
                s += '\n---------------\nDiff (new-orig)\n---------------\n{}'.format(
                    diff)
                s += '\n{}\n'.format(Utils.get_padded_text(''))

                logger = Logger(log_path)
                logger.log(s, stdout=False)
                logger.close()

            summary_stats = MstdUtils.log_stats(
                summary_file, logs_dir,
                verify_row['config_idx'] if verify_row is not None else None)

            # add a line to local-meanstd-run-summary.csv
            local_mstd_dict = {'shuffle_seed': shuffle_seed}
            for metric in MstdUtils.metrics:
                if metric in summary_stats['metrics']:
                    local_mstd_dict[f'{metric}_mean'] = summary_stats[
                        'metrics'][metric]['mean']
                    local_mstd_dict[f'{metric}_median'] = summary_stats[
                        'metrics'][metric]['median']
                    local_mstd_dict[f'{metric}_std'] = summary_stats[
                        'metrics'][metric]['std']

                    if metric not in mstd_dict_aux:
                        mstd_dict_aux[metric] = []
                    mstd_dict_aux[metric].extend(
                        summary_stats['metrics'][metric]['vals'])

            local_mstd_dict.update(gethp_output)
            local_mstd_dict['group_by_metric'] = group_by_metric

            Utils.append_linedict_to_csv(local_mstd_dict,
                                         local_mstd_summary_file)
            Utils.append_linedict_to_csv(local_mstd_dict,
                                         concat_local_mstd_summary_file)

        # add a line to meanstd-run-summary.csv
        mstd_dict = {}
        for metric in mstd_dict_aux:
            mstd_dict[f'{metric}_mean'] = pd.Series(
                mstd_dict_aux[metric]).mean()
            mstd_dict[f'{metric}_std'] = pd.Series(
                mstd_dict_aux[metric]).std(ddof=0)
        mstd_dict.update(gethp_output)
        mstd_dict['group_by_metric'] = group_by_metric
        for metric in mstd_dict_aux:
            mstd_dict[f'{metric}_nancount'] = int(
                pd.Series(mstd_dict_aux[metric]).isna().sum())
        Utils.append_linedict_to_csv(mstd_dict, mstd_summary_file)
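
Utils.append_linedict_to_csv is used throughout but not shown; presumably it appends one dict as a CSV row, writing the header on first use. A minimal sketch of that assumed behavior:

import csv
import os
from typing import Any, Dict

def append_linedict_to_csv(line: Dict[str, Any], path: str) -> None:
    # Append one row; write the header only if the file does not exist yet.
    write_header = not os.path.exists(path)
    with open(path, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=list(line.keys()))
        if write_header:
            writer.writeheader()
        writer.writerow(line)

append_linedict_to_csv({'shuffle_seed': 1, 'test_acc_mean': 0.91}, 'meanstd-run-summary.csv')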
Example #9
def process_model(model: LearnablePredictor, train_data: Dataset,
                  validn_data: Dataset, test_data: Dataset, config_dir: str,
                  config_idx: int, seed: int, compute_auc: bool) -> Dict:
    if SAVE_PLOTS:
        save_path = '{}/plots/model-decisions-post-train.png'.format(
            config_dir)
        model.visualize_decisions(
            test_data['x'] if POST_TRAIN_PLOT_TEST else train_data['x'],
            save_path=save_path)

    train_acc = model.acc(train_data, denormalize=DENORMALIZE_RESULTS)
    validn_acc = model.acc(validn_data, denormalize=DENORMALIZE_RESULTS)
    test_acc = model.acc(test_data, denormalize=DENORMALIZE_RESULTS)

    if compute_auc:
        train_auc = model.auc(train_data)
        validn_auc = model.auc(validn_data)
        test_auc = model.auc(test_data)

    # Get dt_acc
    if isinstance(model, DTExtractablePredictor):
        if not model._is_pure_dt:
            dt_c = model.extract_dt_predictor()

            dt_c.acc_func = model.acc_func
            dt_c.acc_func_type = model.acc_func_type

            dt_train_acc = dt_c.acc(train_data,
                                    denormalize=DENORMALIZE_RESULTS)
            dt_validn_acc = dt_c.acc(validn_data,
                                     denormalize=DENORMALIZE_RESULTS)
            dt_test_acc = dt_c.acc(test_data, denormalize=DENORMALIZE_RESULTS)

            if compute_auc:
                dt_train_auc = dt_c.auc(train_data)
                dt_validn_auc = dt_c.auc(validn_data)
                dt_test_auc = dt_c.auc(test_data)

        if (SAVE_PLOTS or SAVE_TREE) and (model._is_pure_dt):
            dt_c = model.extract_dt_predictor()
        if SAVE_PLOTS:
            dt_c.visualize_decisions(
                cast(
                    torch.Tensor, test_data['x']
                    if POST_TRAIN_PLOT_TEST else train_data['x']),
                save_path='{}/plots/model-tree-decisions-post-train.png'.
                format(config_dir))
        if SAVE_TREE:
            dt_c.visualize_tree(save_path='{}/{}/model-tree.svg'.format(
                config_dir, CONFIG_PLOTS_DIR),
                                data=train_data)

    if isinstance(model, DTExtractablePredictor) and isinstance(
            model, DGTPredictor):
        if not model._is_pure_dt:
            cdt_c = model.extract_cdt_predictor(train_data)
            cdt_c.acc_func = model.acc_func
            cdt_c.acc_func_type = model.acc_func_type

            cdt_train_acc = cdt_c.acc(train_data,
                                      denormalize=DENORMALIZE_RESULTS)
            cdt_validn_acc = cdt_c.acc(validn_data,
                                       denormalize=DENORMALIZE_RESULTS)
            cdt_test_acc = cdt_c.acc(test_data,
                                     denormalize=DENORMALIZE_RESULTS)

        if SAVE_CDT_TREE:
            if model._is_pure_dt:
                cdt_c = model.extract_cdt_predictor(train_data)
            cdt_c.visualize_tree(save_path='{}/{}/model-ctree.svg'.format(
                config_dir, CONFIG_PLOTS_DIR),
                                 data=train_data)

    if isinstance(model, SkCARTPredictor) and SAVE_TREE:
        fig, ax = plt.subplots(dpi=400)
        plot_tree(model._model, ax=ax, precision=5)
        plt.savefig('{}/{}/model-tree.png'.format(config_dir,
                                                  CONFIG_PLOTS_DIR))
        plt.close(fig)

    # Collect stats
    stats: Dict[str, Any] = OrderedDict()
    stats['train_acc'] = train_acc
    stats['validn_acc'] = validn_acc
    stats['test_acc'] = test_acc
    if compute_auc:
        stats['train_auc'] = train_auc
        stats['validn_auc'] = validn_auc
        stats['test_auc'] = test_auc
    if isinstance(model, DTExtractablePredictor):
        if not model._is_pure_dt:
            stats['dt_train_acc'] = dt_train_acc
            stats['dt_validn_acc'] = dt_validn_acc
            stats['dt_test_acc'] = dt_test_acc
            if compute_auc:
                stats['dt_train_auc'] = dt_train_auc
                stats['dt_validn_auc'] = dt_validn_auc
                stats['dt_test_auc'] = dt_test_auc
    if isinstance(model, DTExtractablePredictor) and isinstance(
            model, DGTPredictor):
        if not model._is_pure_dt:
            stats['cdt_train_acc'] = cdt_train_acc
            stats['cdt_validn_acc'] = cdt_validn_acc
            stats['cdt_test_acc'] = cdt_test_acc
    stats['config_idx'] = config_idx
    stats['model_name'] = type(model).__name__
    stats['dataset'] = train_data.name
    stats['seed'] = seed
    stats['shuffle_seed'] = train_data.shuffle_seed

    stats.update(model.get_hyperparams())
    for key, val in stats.items():
        if isinstance(val, list) or isinstance(val, np.ndarray):
            stats[key] = Utils.shorten_mid(str(val),
                                           begin_keep=1000,
                                           end_keep=1000)
    """
    if isinstance(model, DGTPredictor):
        sat_info = model.get_sat_info(train_data)
        stats['sat_info'] = sat_info
        stats['ignore_config'] = 0
    """

    stats['config_dir'] = os.path.abspath(config_dir)

    return stats
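
Utils.shorten_mid (used above and in setup_config_dir) presumably keeps only the beginning and end of very long strings so that list- or array-valued hyperparameters stay readable in the CSV. A sketch of that assumed behavior:

def shorten_mid(s: str, begin_keep: int, end_keep: int, marker: str = '...') -> str:
    # Keep the first begin_keep and last end_keep characters of overly long strings.
    if len(s) <= begin_keep + end_keep + len(marker):
        return s
    return s[:begin_keep] + marker + s[-end_keep:]

print(shorten_mid('x' * 3000, begin_keep=1000, end_keep=1000))  # 2003 chars total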
Example #10
def get_best_model_aux(
        proc_num: int,
        # lock for coordinating MAIN_PROGRESS_FILE logging
        l: 'multiprocessing.synchronize.Lock',
        # common lock for coordinating all other logging (typically for small messages)
        cmlock: 'multiprocessing.synchronize.Lock',
        started_hps: Value,
        started_hps_lk: 'multiprocessing.synchronize.Lock',
        finished_hps: Value,
        total_hps: int,
        start_time: float,
        model_class: Type[LearnablePredictor],
        proc_hps: List[List[Dict[str, Any]]],
        train_data: Dataset,
        validn_data: Dataset,
        test_data: Dataset,
        get_dataset_kwargs: Dict[str, Any],
        exp_dir: str,
        proc_device_id: List[int],
        log_path: str,
        model_search_summary_path: str,
        use_lforb: bool):
    # Prep process-specific data
    hps = proc_hps[proc_num]
    device_ids = None if proc_device_id[proc_num] == -1 else [
        proc_device_id[proc_num]
    ]

    assert train_data.n_labels == 1
    compute_auc = Utils.is_binary_labels(train_data.to_ndarray()['y'])

    Utils.cmlog = functools.partial(Utils.safe_log,
                                    path='{}/{}/cm.log'.format(
                                        exp_dir, EXP_LOGS_DIR),
                                    l=cmlock)

    logger = Logger(log_path)
    for i, hp in enumerate(hps):
        with started_hps_lk:
            config_idx = started_hps.value
            started_hps.value += 1

        hp = hp.copy()
        seed = hp['seed']
        # print('seed: {}, type(seed): {}'.format(seed, type(seed)), flush=True)
        np.random.seed(seed)
        torch.manual_seed(seed)
        hp.pop('seed')

        # Prep
        try:
            model = model_class(**hp, device_ids=device_ids)  # type: ignore
        except TypeError:
            model = model_class(**hp)  # type: ignore
        model._use_lforb = use_lforb
        config, config_dir, img_save_prefix = setup_config_dir(
            model, train_data.name, train_data.shuffle_seed, exp_dir,
            config_idx)

        # Save data for easy reproducibility (currently used in mstd computation)
        repro_data = {
            'get_dataset_kwargs': get_dataset_kwargs,
            'model_class': model_class,
            'seed': seed,
            'hp': hp
        }
        with open('{}/{}.pkl'.format(config_dir, CONFIG_REPRO_DATA_FILE),
                  'wb') as f:
            pickle.dump(repro_data, f, protocol=pickle.HIGHEST_PROTOCOL)

        # Set what model.acc() should do
        acc_func, acc_func_type = Utils.get_acc_def(
            DataLoader.is_classification(get_dataset_kwargs['dataset_name']),
            hp.get('criterion', None))
        model.acc_func = acc_func
        model.acc_func_type = acc_func_type

        model.train(train_data, validn_data, test_data)
        stats = process_model(model, train_data, validn_data, test_data,
                              config_dir, config_idx, seed, compute_auc)

        # Log
        l.acquire()

        finished_hps.value += 1
        logger.log('\n>> [{:.2f}% ({}/{})]:\nRan: (cidx={}): {}\n'.format(
            finished_hps.value * 100 / total_hps, finished_hps.value,
            total_hps, config_idx, config))
        logger.log('Config dir: {}\n'.format(os.path.abspath(config_dir)))

        logger.log('\ntrain_acc={:.5f}%\n'.format(stats['train_acc']))
        logger.log('validn_acc={:.5f}%\n'.format(stats['validn_acc']))
        logger.log('test_acc={:.5f}%\n'.format(stats['test_acc']))
        if compute_auc:
            logger.log('train_auc={:.5f}\n'.format(stats['train_auc']))
            logger.log('validn_auc={:.5f}\n'.format(stats['validn_auc']))
            logger.log('test_auc={:.5f}\n'.format(stats['test_auc']))

        if isinstance(model, DTExtractablePredictor):
            if not model._is_pure_dt:
                logger.log('dt_train_acc={:.5f}%\n'.format(
                    stats['dt_train_acc']))
                logger.log('dt_validn_acc={:.5f}%\n'.format(
                    stats['dt_validn_acc']))
                logger.log('dt_test_acc={:.5f}%\n'.format(
                    stats['dt_test_acc']))
                if compute_auc:
                    logger.log('dt_train_auc={:.5f}\n'.format(
                        stats['dt_train_auc']))
                    logger.log('dt_validn_auc={:.5f}\n'.format(
                        stats['dt_validn_auc']))
                    logger.log('dt_test_auc={:.5f}\n'.format(
                        stats['dt_test_auc']))

        if isinstance(model, DTExtractablePredictor) and isinstance(
                model, DGTPredictor):
            if not model._is_pure_dt:
                logger.log('cdt_train_acc={:.5f}%\n'.format(
                    stats['cdt_train_acc']))
                logger.log('cdt_validn_acc={:.5f}%\n'.format(
                    stats['cdt_validn_acc']))
                logger.log('cdt_test_acc={:.5f}%\n'.format(
                    stats['cdt_test_acc']))

        total_time = time.time() - start_time
        per_hp_time = total_time / finished_hps.value
        rem_time = per_hp_time * (total_hps - finished_hps.value)
        logger.log('Time: per hp={}, total={}, rem={}\n'.format(
            td(seconds=per_hp_time), td(seconds=total_time),
            td(seconds=rem_time)))

        Utils.append_linedict_to_csv(stats, model_search_summary_path)
        l.release()

        if use_lforb:
            assert hp['use_last_model']

            old_logs_dir = model.logs_dir
            config_idx += total_hps
            _, config_dir, _ = setup_config_dir(model, train_data.name,
                                                train_data.shuffle_seed,
                                                exp_dir, config_idx)
            model.load_best_model(old_logs_dir)

            hp['use_last_model'] = False
            repro_data = {
                'get_dataset_kwargs': get_dataset_kwargs,
                'model_class': model_class,
                'seed': seed,
                'hp': hp
            }
            with open('{}/{}.pkl'.format(config_dir, CONFIG_REPRO_DATA_FILE),
                      'wb') as f:
                pickle.dump(repro_data, f, protocol=pickle.HIGHEST_PROTOCOL)

            stats = process_model(model, train_data, validn_data, test_data,
                                  config_dir, config_idx, seed, compute_auc)

            l.acquire()
            Utils.append_linedict_to_csv(stats, model_search_summary_path)
            l.release()

    logger.close()
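
The caller get_best_model is not shown; from the proc_hps / proc_device_id arguments it presumably partitions the expanded hyperparameter list and the available devices across worker processes, e.g. round-robin. A hedged standalone sketch of such a partition:

from typing import Any, Dict, List

def split_round_robin(hps: List[Dict[str, Any]], n_procs: int) -> List[List[Dict[str, Any]]]:
    # Assign configuration i to process i % n_procs.
    buckets: List[List[Dict[str, Any]]] = [[] for _ in range(n_procs)]
    for i, hp in enumerate(hps):
        buckets[i % n_procs].append(hp)
    return buckets

hps = [{'lr': lr, 'seed': s} for lr in (0.1, 0.01) for s in (1, 2, 3)]
print([len(b) for b in split_round_robin(hps, 4)])  # [2, 2, 1, 1]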
Example #11
class DataLoader():
    stats: Dict[str,
                Dict[str,
                     int]] = Utils.read_json(f'{DATASETS_ROOT}/stats.json')

    @staticmethod
    def get_dataset(dataset_name: str, **kwargs) -> TVT:
        assert 'shuffle_seed' in kwargs

        try:
            train, validn, test = getattr(
                DataLoader,
                'prep_{}_all'.format(dataset_name.replace('-', '_')))(**kwargs)
        except AttributeError:
            train, validn, test = DataLoader._prep_generic_all(
                dataset_name=dataset_name, **kwargs)

        train.shuffle_seed = kwargs['shuffle_seed']
        validn.shuffle_seed = kwargs['shuffle_seed']
        test.shuffle_seed = kwargs['shuffle_seed']
        return train, validn, test

    @staticmethod
    def _prep_generic(dataset_name: str,
                      category: Optional[str]) -> Dataset[np.ndarray]:
        if category in ['train', 'val', 'test']:
            x = np.load(
                f'{DATASETS_ROOT}/{dataset_name}/{dataset_name}-{category}-x.npy'
            ).astype(np.float32)
            y = np.load(
                f'{DATASETS_ROOT}/{dataset_name}/{dataset_name}-{category}-y.npy'
            ).astype(np.int64 if DataLoader.
                     is_classification(dataset_name) else np.float32)

            return Dataset(x,
                           y,
                           name=dataset_name,
                           copy=False,
                           autoshrink_y=True)

        else:
            raise ValueError('category must be in ["train", "test", "val"]')

    @staticmethod
    def _prep_generic_all(dataset_name: str,
                          normalize_x_kwargs: Optional[Dict[str, Any]] = None,
                          normalize_y_kwargs: Optional[Dict[str, Any]] = None,
                          shuffle_seed: int = SEED_DEF) -> TVT:

        train = DataLoader._prep_generic(dataset_name, category='train')
        validn = DataLoader._prep_generic(dataset_name, category='val')
        test = DataLoader._prep_generic(dataset_name, category='test')
        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    @staticmethod
    def prep_pdbbind_all(normalize_x_kwargs: Optional[Dict[str, Any]] = None,
                         normalize_y_kwargs: Optional[Dict[str, Any]] = None,
                         shuffle_seed: int = SEED_DEF) -> TVT:
        url = r'https://raw.githubusercontent.com/guanghelee/iclr20-lcn/master/data/PDBbind.pkl.gz'
        path = f'{DATASETS_ROOT}/pdbbind/PDBbind.pkl.gz'
        DataLoader.download(url, path)

        with gzip.open(path, 'rb') as f:
            xtrain, ytrain, xval, yval, xtest, ytest = pickle.load(f)

        train = Dataset(xtrain.astype(np.float32),
                        ytrain.astype(np.float32),
                        name='pdbbind',
                        copy=False,
                        autoshrink_y=True)
        validn = Dataset(xval.astype(np.float32),
                         yval.astype(np.float32),
                         name='pdbbind',
                         copy=False,
                         autoshrink_y=True)
        test = Dataset(xtest.astype(np.float32),
                       ytest.astype(np.float32),
                       name='pdbbind',
                       copy=False,
                       autoshrink_y=True)

        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    @staticmethod
    def prep_abalone(category, seed) -> Dataset[np.ndarray]:
        npz = np.load(
            open('{}/abalone/abalone{}.npz'.format(DATASETS_ROOT, seed - 1),
                 'rb'))
        x, y = npz[f'X_{category}'].astype(
            np.float32), npz[f'y_{category}'].astype(np.float32)
        return Dataset(x, y, name='abalone', copy=False, autoshrink_y=True)

    @staticmethod
    def prep_abalone_all(normalize_x_kwargs: Optional[Dict[str, Any]] = None,
                         normalize_y_kwargs: Optional[Dict[str, Any]] = None,
                         shuffle_seed: int = SEED_DEF) -> TVT:
        from xconstants import TAO

        train, validn = cast(
            TV,
            DataLoader.prep_abalone('train', shuffle_seed).shuffle(
                seed=shuffle_seed).split(0.8 if not TAO else 0.9))
        test = DataLoader.prep_abalone('test',
                                       shuffle_seed).shuffle(seed=shuffle_seed)

        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    @staticmethod
    def prep_cpuactiv(category, seed) -> Dataset[np.ndarray]:
        with open('{}/cpuactiv/cpu_act{}.npz'.format(DATASETS_ROOT, seed - 1),
                  'rb') as f:
            npz = np.load(f)
            x, y = npz[f'X_{category}'].astype(
                np.float32), npz[f'y_{category}'].astype(np.float32)
        return Dataset(x, y, name='cpuactiv', copy=False, autoshrink_y=True)

    @staticmethod
    def prep_cpuactiv_all(normalize_x_kwargs: Optional[Dict[str, Any]] = None,
                          normalize_y_kwargs: Optional[Dict[str, Any]] = None,
                          shuffle_seed: int = SEED_DEF) -> TVT:
        from xconstants import TAO

        train, validn = cast(
            TV,
            DataLoader.prep_cpuactiv('train', shuffle_seed).shuffle(
                seed=shuffle_seed).split(0.8 if not TAO else 0.9))
        test = DataLoader.prep_cpuactiv(
            'test', shuffle_seed).shuffle(seed=shuffle_seed)

        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    @staticmethod
    def prep_ailerons(category: str) -> Dataset[np.ndarray]:
        df = pd.read_csv(f'{DATASETS_ROOT}/ailerons/ailerons.{category}',
                         sep=',',
                         header=None)

        x = np.array(df.iloc[:, :-1], dtype=np.float32)
        y = np.array(df.iloc[:, -1], dtype=np.float32) * 1e4
        return Dataset(x, y, name='ailerons', copy=False, autoshrink_y=True)

    @staticmethod
    def prep_ailerons_all(normalize_x_kwargs: Optional[Dict[str, Any]] = None,
                          normalize_y_kwargs: Optional[Dict[str, Any]] = None,
                          shuffle_seed: int = SEED_DEF) -> TVT:
        url = r'https://www.dcc.fc.up.pt/~ltorgo/Regression/ailerons.tgz'
        DataLoader.download(url, f'{DATASETS_ROOT}/ailerons')

        from xconstants import TAO

        train, validn = cast(
            TV,
            DataLoader.prep_ailerons('data').shuffle(
                seed=shuffle_seed).split(0.8 if not TAO else 0.9))
        test = DataLoader.prep_ailerons('test').shuffle(seed=shuffle_seed)
        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    @staticmethod
    def prep_ctslice(category, seed) -> Dataset[np.ndarray]:
        raise NotImplementedError
        with open('{}/ctslice/ctslice{}.npz'.format(DATASETS_ROOT, seed - 1),
                  'rb') as f:
            npz = np.load(f)
            x, y = npz[f'X_{category}'].astype(
                np.float32), npz[f'y_{category}'].astype(np.float32)
        return Dataset(x, y, name='ctslice', copy=False, autoshrink_y=True)

    @staticmethod
    def prep_ctslice_all(normalize_x_kwargs: Optional[Dict[str, Any]] = None,
                         normalize_y_kwargs: Optional[Dict[str, Any]] = None,
                         shuffle_seed: int = SEED_DEF) -> TVT:
        from xconstants import TAO

        train, validn = cast(
            TV,
            DataLoader.prep_ctslice('train', shuffle_seed).shuffle(
                seed=shuffle_seed).split(0.8 if not TAO else 0.9))
        test = DataLoader.prep_ctslice('test',
                                       shuffle_seed).shuffle(seed=shuffle_seed)

        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    @staticmethod
    def prep_year(category: Optional[str]) -> Dataset[np.ndarray]:
        if category in ['train', 'test']:
            x = np.load(f'{DATASETS_ROOT}/year/year-{category}-x.npy').astype(
                np.float32)
            y = np.load(f'{DATASETS_ROOT}/year/year-{category}-y.npy').astype(
                np.float32)

            return Dataset(x, y, name='year', copy=False, autoshrink_y=True)
        else:
            raise ValueError('category must be in ["train", "test"]')

    @staticmethod
    def prep_year_all(normalize_x_kwargs: Optional[Dict[str, Any]] = None,
                      normalize_y_kwargs: Optional[Dict[str, Any]] = None,
                      shuffle_seed: int = SEED_DEF) -> TVT:
        url = r'https://archive.ics.uci.edu/ml/machine-learning-databases/00203/YearPredictionMSD.txt.zip'
        path = f'{DATASETS_ROOT}/year'
        if DataLoader.download(url, path):
            df = pd.read_csv(f'{path}/YearPredictionMSD.txt', header=None)
            split_at = 463715
            train, test = df.iloc[:split_at], df.iloc[split_at:]
            np.save(f'{path}/year-train-y.npy', train.iloc[:, 0].to_numpy(np.int64))
            np.save(f'{path}/year-train-x.npy', train.iloc[:, 1:].to_numpy(np.float32))
            np.save(f'{path}/year-test-y.npy', test.iloc[:, 0].to_numpy(np.int64))
            np.save(f'{path}/year-test-x.npy', test.iloc[:, 1:].to_numpy(np.float32))

        from xconstants import TAO

        train, validn = cast(
            TV,
            DataLoader.prep_year(category='train').shuffle(
                seed=shuffle_seed).split(0.8 if not TAO else 0.9))
        test = DataLoader.prep_year(category='test')
        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    @staticmethod
    def prep_mic(category: Optional[str]) -> Dataset[np.ndarray]:
        if category in ['train', 'test', 'val']:
            x = np.load('{}/mic/mic-{}-x.npy'.format(
                DATASETS_ROOT, category)).astype(np.float32)
            y = np.load('{}/mic/mic-{}-y.npy'.format(
                DATASETS_ROOT, category)).astype(np.float32)

            return Dataset(x, y, name='mic', copy=False, autoshrink_y=True)
        else:
            raise ValueError('category must be in ["train", "test", "val"]')

    @staticmethod
    def prep_mic_all(normalize_x_kwargs: Optional[Dict[str, Any]] = None,
                     normalize_y_kwargs: Optional[Dict[str, Any]] = None,
                     shuffle_seed: int = SEED_DEF) -> TVT:

        # other papers don't use the given val split; they split train instead, so we do the same
        train, validn = cast(
            TV,
            DataLoader.prep_mic(category='train').shuffle(
                seed=shuffle_seed).split(0.8))
        test = cast(Dataset[np.ndarray], DataLoader.prep_mic(category='test'))
        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    @staticmethod
    def prep_yah(category: Optional[str]) -> Dataset[np.ndarray]:
        if category in ['train', 'test', 'val']:
            x = np.load('{}/yah/yah-{}-x.npy'.format(
                DATASETS_ROOT, category)).astype(np.float32)
            y = np.load('{}/yah/yah-{}-y.npy'.format(
                DATASETS_ROOT, category)).astype(np.float32)

            return Dataset(x, y, name='yah', copy=False, autoshrink_y=True)
        else:
            raise ValueError('category must be in ["train", "test", "val"]')

    @staticmethod
    def prep_yah_all(normalize_x_kwargs: Optional[Dict[str, Any]] = None,
                     normalize_y_kwargs: Optional[Dict[str, Any]] = None,
                     shuffle_seed: int = SEED_DEF) -> TVT:

        train = cast(Dataset[np.ndarray],
                     DataLoader.prep_yah(category='train'))
        validn = cast(Dataset[np.ndarray], DataLoader.prep_yah(category='val'))
        test = cast(Dataset[np.ndarray], DataLoader.prep_yah(category='test'))
        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    @staticmethod
    def prep_connect_4(shuffle_seed: int) -> Dataset[np.ndarray]:
        url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/connect-4'
        path = f'{DATASETS_ROOT}/connect-4/connect-4.txt'
        DataLoader.download(url, path)

        return DataLoader.read_libsvm_format(
            path,
            n_features=DataLoader.stats['connect-4']['n_features'],
            n_classes=DataLoader.stats['connect-4']['n_classes'],
            name='connect-4',
            shuffle_seed=shuffle_seed)

    @staticmethod
    def prep_connect_4_all(normalize_x_kwargs: Optional[Dict[str, Any]] = None,
                           normalize_y_kwargs: Optional[Dict[str, Any]] = None,
                           shuffle_seed: int = SEED_DEF) -> TVT:

        train, validn, test = cast(
            TVT,
            DataLoader.prep_connect_4(shuffle_seed).split(0.64, 0.16))
        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    @staticmethod
    def prep_mnist(shuffle_seed: int,
                   category: Optional[str] = None) -> Dataset[np.ndarray]:
        if category is None:
            url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.bz2'
            path = f'{DATASETS_ROOT}/mnist/mnist.txt'
            DataLoader.download(url, path)

            return DataLoader.read_libsvm_format(
                path,
                n_features=DataLoader.stats['mnist']['n_features'],
                n_classes=DataLoader.stats['mnist']['n_classes'],
                name='mnist',
                shuffle_seed=shuffle_seed)

        elif category == 'test':
            url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.t.bz2'
            path = f'{DATASETS_ROOT}/mnist/mnist-test.txt'
            DataLoader.download(url, path)

            return DataLoader.read_libsvm_format(
                path,
                n_features=DataLoader.stats['mnist']['n_features'],
                n_classes=DataLoader.stats['mnist']['n_classes'],
                name='mnist',
                shuffle_seed=1)

        else:
            raise ValueError('category must be in [None, "test"]')

    @staticmethod
    def prep_mnist_all(normalize_x_kwargs: Optional[Dict[str, Any]] = None,
                       normalize_y_kwargs: Optional[Dict[str, Any]] = None,
                       shuffle_seed: int = SEED_DEF) -> TVT:

        train, validn = cast(
            Tuple[Dataset[np.ndarray], Dataset[np.ndarray]],
            DataLoader.prep_mnist(shuffle_seed, category=None).split(0.8))
        test = DataLoader.prep_mnist(shuffle_seed, category='test')
        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    @staticmethod
    def prep_protein(category: Optional[str] = None) -> Dataset[np.ndarray]:
        if category is None:
            url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/protein.bz2'
            path = f'{DATASETS_ROOT}/protein/protein.txt'
            DataLoader.download(url, path)

            return DataLoader.read_libsvm_format(
                path,
                n_features=DataLoader.stats['protein']['n_features'],
                n_classes=DataLoader.stats['protein']['n_classes'],
                name='protein')

        elif category == 'train':
            url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/protein.tr.bz2'
            path = f'{DATASETS_ROOT}/protein/protein-train.txt'
            DataLoader.download(url, path)

            return DataLoader.read_libsvm_format(
                path,
                n_features=DataLoader.stats['protein']['n_features'],
                n_classes=DataLoader.stats['protein']['n_classes'],
                name='protein')

        elif category == 'test':
            url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/protein.t.bz2'
            path = f'{DATASETS_ROOT}/protein/protein-test.txt'
            DataLoader.download(url, path)

            return DataLoader.read_libsvm_format(
                path,
                n_features=DataLoader.stats['protein']['n_features'],
                n_classes=DataLoader.stats['protein']['n_classes'],
                name='protein')

        elif category == 'val':
            url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/protein.val.bz2'
            path = f'{DATASETS_ROOT}/protein/protein-val.txt'
            DataLoader.download(url, path)

            return DataLoader.read_libsvm_format(
                path,
                n_features=DataLoader.stats['protein']['n_features'],
                n_classes=DataLoader.stats['protein']['n_classes'],
                name='protein')

        else:
            raise ValueError(
                'category must be in [None, "train", "test", "val"]')

    @staticmethod
    def prep_protein_all(normalize_x_kwargs: Optional[Dict[str, Any]] = None,
                         normalize_y_kwargs: Optional[Dict[str, Any]] = None,
                         shuffle_seed: int = SEED_DEF) -> TVT:

        train = DataLoader.prep_protein(category='train')
        validn = DataLoader.prep_protein(category='val')
        test = DataLoader.prep_protein(category='test')
        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    @staticmethod
    def prep_sensit_combined(
            shuffle_seed: int,
            category: Optional[str] = None) -> Dataset[np.ndarray]:
        if category is None:
            url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/vehicle/combined.bz2'
            path = f'{DATASETS_ROOT}/sensit-combined/sensit-combined.txt'
            DataLoader.download(url, path)

            return DataLoader.read_libsvm_format(
                path,
                n_features=DataLoader.stats['sensit-combined']['n_features'],
                n_classes=DataLoader.stats['sensit-combined']['n_classes'],
                name='sensit-combined',
                shuffle_seed=shuffle_seed)

        elif category == 'test':
            url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/vehicle/combined.t.bz2'
            path = f'{DATASETS_ROOT}/sensit-combined/sensit-combined-test.txt'
            DataLoader.download(url, path)

            return DataLoader.read_libsvm_format(
                path,
                n_features=DataLoader.stats['sensit-combined']['n_features'],
                n_classes=DataLoader.stats['sensit-combined']['n_classes'],
                name='sensit-combined',
                shuffle_seed=1)

        else:
            raise ValueError('category must be in [None, "test"]')

    @staticmethod
    def prep_sensit_combined_all(
            normalize_x_kwargs: Optional[Dict[str, Any]] = None,
            normalize_y_kwargs: Optional[Dict[str, Any]] = None,
            shuffle_seed: int = SEED_DEF) -> TVT:

        train, validn = cast(
            Tuple[Dataset[np.ndarray], Dataset[np.ndarray]],
            DataLoader.prep_sensit_combined(shuffle_seed,
                                            category=None).split(0.8))
        test = DataLoader.prep_sensit_combined(shuffle_seed, category='test')
        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    @staticmethod
    def prep_letter(category: Optional[str] = None) -> Dataset[np.ndarray]:
        if category is None:
            url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/letter.scale'
            path = f'{DATASETS_ROOT}/letter/letter.txt'
            DataLoader.download(url, path)

            return DataLoader.read_libsvm_format(
                path,
                n_features=DataLoader.stats['letter']['n_features'],
                n_classes=DataLoader.stats['letter']['n_classes'],
                name='letter')

        elif category == 'train':
            url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/letter.scale.tr'
            path = f'{DATASETS_ROOT}/letter/letter-train.txt'
            DataLoader.download(url, path)

            return DataLoader.read_libsvm_format(
                path,
                n_features=DataLoader.stats['letter']['n_features'],
                n_classes=DataLoader.stats['letter']['n_classes'],
                name='letter')

        elif category == 'test':
            url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/letter.scale.t'
            path = f'{DATASETS_ROOT}/letter/letter-test.txt'
            DataLoader.download(url, path)

            return DataLoader.read_libsvm_format(
                path,
                n_features=DataLoader.stats['letter']['n_features'],
                n_classes=DataLoader.stats['letter']['n_classes'],
                name='letter')

        elif category == 'val':
            url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/letter.scale.val'
            path = f'{DATASETS_ROOT}/letter/letter-val.txt'
            DataLoader.download(url, path)

            return DataLoader.read_libsvm_format(
                path,
                n_features=DataLoader.stats['letter']['n_features'],
                n_classes=DataLoader.stats['letter']['n_classes'],
                name='letter')

        else:
            raise ValueError(
                'category must be in [None, "train", "test", "val"]')

    @staticmethod
    def prep_letter_all(normalize_x_kwargs: Optional[Dict[str, Any]] = None,
                        normalize_y_kwargs: Optional[Dict[str, Any]] = None,
                        shuffle_seed: int = SEED_DEF) -> TVT:
        train = DataLoader.prep_letter(category='train')
        validn = DataLoader.prep_letter(category='val')
        test = DataLoader.prep_letter(category='test')

        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    @staticmethod
    def prep_pendigits(shuffle_seed: int,
                       category: Optional[str] = None) -> Dataset[np.ndarray]:
        if category is None:
            url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits'
            path = f'{DATASETS_ROOT}/pendigits/pendigits.txt'
            DataLoader.download(url, path)

            return DataLoader.read_libsvm_format(
                path,
                n_features=DataLoader.stats['pendigits']['n_features'],
                n_classes=DataLoader.stats['pendigits']['n_classes'],
                name='pendigits',
                shuffle_seed=shuffle_seed)

        elif category == 'test':
            url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits.t'
            path = f'{DATASETS_ROOT}/pendigits/pendigits-test.txt'
            DataLoader.download(url, path)

            return DataLoader.read_libsvm_format(
                path,
                n_features=DataLoader.stats['pendigits']['n_features'],
                n_classes=DataLoader.stats['pendigits']['n_classes'],
                name='pendigits',
                shuffle_seed=1)

        else:
            raise ValueError('category must be in [None, "test"]')

    @staticmethod
    def prep_pendigits_all(normalize_x_kwargs: Optional[Dict[str, Any]] = None,
                           normalize_y_kwargs: Optional[Dict[str, Any]] = None,
                           shuffle_seed: int = SEED_DEF) -> TVT:

        train, validn = cast(
            Tuple[Dataset[np.ndarray], Dataset[np.ndarray]],
            DataLoader.prep_pendigits(shuffle_seed, category=None).split(0.8))
        test = DataLoader.prep_pendigits(shuffle_seed, category='test')
        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    @staticmethod
    def prep_satimage(category: Optional[str] = None) -> Dataset[np.ndarray]:
        if category is None:
            url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/satimage.scale'
            path = f'{DATASETS_ROOT}/satimage/satimage.txt'
            DataLoader.download(url, path)

            return DataLoader.read_libsvm_format(
                path,
                n_features=DataLoader.stats['satimage']['n_features'],
                n_classes=DataLoader.stats['satimage']['n_classes'],
                name='satimage')

        elif category == 'train':
            url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/satimage.scale.tr'
            path = f'{DATASETS_ROOT}/satimage/satimage-train.txt'
            DataLoader.download(url, path)

            return DataLoader.read_libsvm_format(
                path,
                n_features=DataLoader.stats['satimage']['n_features'],
                n_classes=DataLoader.stats['satimage']['n_classes'],
                name='satimage')

        elif category == 'test':
            url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/satimage.scale.t'
            path = f'{DATASETS_ROOT}/satimage/satimage-test.txt'
            DataLoader.download(url, path)

            return DataLoader.read_libsvm_format(
                path,
                n_features=DataLoader.stats['satimage']['n_features'],
                n_classes=DataLoader.stats['satimage']['n_classes'],
                name='satimage')

        elif category == 'val':
            url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/satimage.scale.val'
            path = f'{DATASETS_ROOT}/satimage/satimage-val.txt'
            DataLoader.download(url, path)

            return DataLoader.read_libsvm_format(
                path,
                n_features=DataLoader.stats['satimage']['n_features'],
                n_classes=DataLoader.stats['satimage']['n_classes'],
                name='satimage')

        else:
            raise ValueError(
                'category must be in [None, "train", "test", "val"]')

    @staticmethod
    def prep_satimage_all(normalize_x_kwargs: Optional[Dict[str, Any]] = None,
                          normalize_y_kwargs: Optional[Dict[str, Any]] = None,
                          shuffle_seed: int = SEED_DEF) -> TVT:

        train = DataLoader.prep_satimage(category='train')
        validn = DataLoader.prep_satimage(category='val')
        test = DataLoader.prep_satimage(category='test')
        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    @staticmethod
    def prep_segment(shuffle_seed: int) -> Dataset[np.ndarray]:
        url = r'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/segment.scale'
        path = f'{DATASETS_ROOT}/segment/segment.txt'
        DataLoader.download(url, path)

        return DataLoader.read_libsvm_format(
            path,
            n_features=DataLoader.stats['segment']['n_features'],
            n_classes=DataLoader.stats['segment']['n_classes'],
            name='segment',
            shuffle_seed=shuffle_seed)

    @staticmethod
    def prep_segment_all(normalize_x_kwargs: Optional[Dict[str, Any]] = None,
                         normalize_y_kwargs: Optional[Dict[str, Any]] = None,
                         shuffle_seed: int = SEED_DEF) -> TVT:
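        # segment has no official split, so the shuffled data is divided
        # 64/16/20 into train/validation/test below.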

        train, validn, test = cast(
            TVT,
            DataLoader.prep_segment(shuffle_seed).split(0.64, 0.16))
        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    @staticmethod
    def prep_bace(category: Optional[str]) -> Dataset[np.ndarray]:
        if category in ['train', 'val', 'test']:
            if category == 'val':
                category = 'valid'

            url = rf'https://raw.githubusercontent.com/guanghelee/iclr20-lcn/master/data/bace_split/{category}.fgp2048.csv'
            path = f'{DATASETS_ROOT}/bace/{category}.fgp2048.csv'
            DataLoader.download(url, path)

            df = pd.read_csv(path)
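            # Each row's 'mol' entry in the fgp2048 CSV appears to be a
            # 2048-character bit string; it is expanded character-by-character
            # into a float feature vector, with the binary 'Class' column as
            # the label.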
            x = np.array([[int(j) for j in i]
                          for i in df['mol'].values]).astype(np.float32)
            y = df['Class'].values.astype(np.int64)

            return Dataset(x, y, name='bace', copy=False, autoshrink_y=True)
        else:
            raise ValueError('category must be in ["train", "test", "val"]')

    @staticmethod
    def prep_bace_all(normalize_x_kwargs: Optional[Dict[str, Any]] = None,
                      normalize_y_kwargs: Optional[Dict[str, Any]] = None,
                      shuffle_seed: int = SEED_DEF) -> TVT:

        train = DataLoader.prep_bace(category='train')
        validn = DataLoader.prep_bace(category='val')
        test = DataLoader.prep_bace(category='test')
        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    @staticmethod
    def prep_hiv(category: Optional[str]) -> Dataset[np.ndarray]:
        if category in ['train', 'val', 'test']:
            if category == 'val':
                category = 'valid'

            url = rf'https://raw.githubusercontent.com/guanghelee/iclr20-lcn/master/data/HIV_split/{category}.fgp2048.csv'
            path = f'{DATASETS_ROOT}/hiv/{category}.fgp2048.csv'
            DataLoader.download(url, path)

            df = pd.read_csv(path)
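            # As with bace, the 'smiles' column of this fgp2048 CSV is parsed
            # character-by-character as binary digits, so it is presumably a
            # fingerprint string rather than a raw SMILES string.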
            x = np.array([[int(j) for j in i]
                          for i in df['smiles'].values]).astype(np.float32)
            y = df['HIV_active'].values.astype(np.int64)

            return Dataset(x, y, name='hiv', copy=False, autoshrink_y=True)
        else:
            raise ValueError('category must be in ["train", "test", "val"]')

    @staticmethod
    def prep_hiv_all(normalize_x_kwargs: Optional[Dict[str, Any]] = None,
                     normalize_y_kwargs: Optional[Dict[str, Any]] = None,
                     shuffle_seed: int = SEED_DEF) -> TVT:

        train = DataLoader.prep_hiv(category='train')
        validn = DataLoader.prep_hiv(category='val')
        test = DataLoader.prep_hiv(category='test')
        return DataLoader.normalize_all_datasets(train, validn, test,
                                                 normalize_x_kwargs,
                                                 normalize_y_kwargs)

    """
    Notes:
        Converts feature names to int (and throws if such conversion is not posible). Assumes feature names are 1-indexed and converts them to 0-indexed.
        Converts feature values to float (and throws if such conversion is not possible).
        Assumes only 1 label exists and n_classes is the number of classes for that label. Incase of classification, converts labels to 0-indexed if it is not.
    """

    @staticmethod
    def read_libsvm_format(
            file_path: str,
            n_features: int,
            n_classes: int,
            name: str = '',
            shuffle_seed: Optional[int] = 1) -> Dataset[np.ndarray]:
        is_classification = (n_classes > 0)

        with open(file_path, 'r') as f:
            content = f.read()
        assert ':  ' not in content, 'Error while reading: {}'.format(
            file_path)

        content = content.replace(': ', ':')
        content = content.strip()
        lines = content.split('\n')
        lines = [line.strip() for line in lines]

        x = np.zeros((len(lines), n_features), dtype=np.float32)
        y = np.zeros((len(lines), ),
                     dtype=np.int64 if is_classification else np.float32)

        for line_idx, line in enumerate(lines):
            for unit_idx, unit in enumerate(line.split()):
                if unit_idx == 0:
                    assert ':' not in unit
                    if is_classification:
                        y[line_idx] = int(unit.strip())
                    else:
                        y[line_idx] = float(unit.strip())
                else:
                    feat, val = unit.strip().split(':')
                    feat: int = int(feat)
                    val: float = float(val)
                    x[line_idx][feat - 1] = val

        if is_classification:
            # To get classes in [0..n_classes-1]
            y = y - np.min(y)

        return Dataset(x, y, name=name, copy=False).shuffle(seed=shuffle_seed)

    @staticmethod
    def normalize_all_datasets(
            train: Dataset[np.ndarray],
            validn: Dataset[np.ndarray],
            test: Dataset[np.ndarray],
            normalize_x_kwargs: Optional[Dict[str, Any]] = None,
            normalize_y_kwargs: Optional[Dict[str, Any]] = None) -> TVT:
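        # Normalization parameters are fitted on the training set and then
        # 'mirrored' onto the validation and test sets, so all three splits are
        # scaled with the same (training) statistics.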

        if normalize_x_kwargs is not None:
            train = train.normalize_x(**normalize_x_kwargs)
            validn = validn.normalize_x(category='mirror',
                                        mirror_params=train.mirror_x_params)
            test = test.normalize_x(category='mirror',
                                    mirror_params=train.mirror_x_params)

        if normalize_y_kwargs is not None:
            train = train.normalize_y(**normalize_y_kwargs)
            validn = validn.normalize_y(category='mirror',
                                        mirror_params=train.mirror_y_params)
            test = test.normalize_y(category='mirror',
                                    mirror_params=train.mirror_y_params)

        return train, validn, test

    @staticmethod
    def get_base_perf(train: Dataset[np.ndarray], validn: Dataset[np.ndarray],
                      test: Dataset[np.ndarray]) -> str:
        assert cast(np.ndarray, train['x']).dtype in [np.float32]
        assert cast(np.ndarray, train['y']).dtype in [np.int64, np.float32]
        assert cast(np.ndarray,
                    validn['x']).dtype == cast(np.ndarray, train['x']).dtype
        assert cast(np.ndarray,
                    validn['y']).dtype == cast(np.ndarray, train['y']).dtype
        assert cast(np.ndarray,
                    test['x']).dtype == cast(np.ndarray, train['x']).dtype
        assert cast(np.ndarray,
                    test['y']).dtype == cast(np.ndarray, train['y']).dtype

        sep = '------------------\n'
        ret = ''
        ret += sep
        ret += ' Base Perf\n'
        ret += sep

        is_classification = cast(np.ndarray, train['y']).dtype == np.int64
        n_features, n_classes = DataLoader.stats[train.name][
            'n_features'], DataLoader.stats[train.name]['n_classes']
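        # Baseline: a constant predictor that always outputs the training
        # majority class (classification) or the training mean target
        # (regression).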

        if is_classification:
            mode = stats.mode(train['y'])[0][0]
            model = FixedConstantPredictor(n_features, n_classes, mode)
        else:
            mean = cast(np.ndarray, train['y']).mean()
            model = FixedConstantPredictor(n_features, n_classes, mean)
        model.acc_func, model.acc_func_type = Utils.get_acc_def(
            is_classification)

        ret += ' train_acc={:.5f}\n'.format(model.acc(train))
        ret += ' validn_acc={:.5f}\n'.format(model.acc(validn))
        ret += ' test_acc={:.5f}\n'.format(model.acc(test))

        ret += sep
        return ret

    @staticmethod
    def is_classification(dataset: str) -> bool:
        n_classes = DataLoader.stats[dataset]['n_classes']
        return n_classes > 0

    @staticmethod
    def download(url: str, dst: str):
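        # Returns False and reuses the cached file if dst already exists;
        # otherwise downloads url, extracting tgz/zip archives into a directory
        # named dst, decompressing bz2 files, or writing the response as text.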
        if os.path.exists(dst):
            print(f'Using dataset from: {dst}')
            return False

        os.makedirs(os.path.dirname(dst), exist_ok=True)

        print(f'Downloading dataset from: {url}...', end='')
        r = requests.get(url)
        print('Done')

        print(f'Writing to: {dst}...', end='')
        if url.endswith('tgz'):
            os.mkdir(dst)
            tarobj = tarfile.open(fileobj=io.BytesIO(r.content))
            for member in tarobj.getmembers():
                if not member.isfile():
                    continue  # skip directories and other non-file members
                extracted = tarobj.extractfile(member)
                assert extracted is not None
                with open(f'{dst}/{os.path.basename(member.name)}', 'w') as f:
                    f.write(extracted.read().decode('utf-8'))
            tarobj.close()

        elif url.endswith('zip'):
            os.mkdir(dst)
            with zipfile.ZipFile(io.BytesIO(r.content)) as zipobj:
                for i in zipobj.namelist():
                    if i.endswith('/'):
                        continue  # skip directory entries
                    with zipobj.open(i) as member:
                        with open(f'{dst}/{i}', 'wb') as f:
                            f.write(member.read())

        else:
            if url.endswith('bz2'):
                towrite = bz2.decompress(r.content).decode('utf-8')
            else:
                towrite = r.text

            with open(dst, 'w') as f:
                f.write(towrite)
        print('Done')

        return True
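    # Hypothetical end-to-end sketch (kwargs shown are illustrative; adapt to
    # the dataset and normalization actually needed):
    #     train, validn, test = DataLoader.prep_satimage_all(
    #         normalize_x_kwargs={'category': 'zscore'})
    #     print(DataLoader.get_base_perf(train, validn, test))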