Example #1
def create_samplers(config, dataset, built_datasets):
    """Create sampler for training, validation and testing.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataset (Dataset): An instance object of Dataset, which contains all interaction records.
        built_datasets (list of Dataset): A list of split Datasets, containing the datasets for
            training, validation and testing.

    Returns:
        tuple:
            - train_sampler (AbstractSampler): The sampler for training.
            - valid_sampler (AbstractSampler): The sampler for validation.
            - test_sampler (AbstractSampler): The sampler for testing.
    """
    phases = ['train', 'valid', 'test']
    train_neg_sample_args = config['train_neg_sample_args']
    eval_neg_sample_args = config['eval_neg_sample_args']
    sampler = None
    train_sampler, valid_sampler, test_sampler = None, None, None

    if train_neg_sample_args['strategy'] != 'none':
        if not config['repeatable']:
            sampler = Sampler(phases, built_datasets,
                              train_neg_sample_args['distribution'])
        else:
            sampler = RepeatableSampler(phases, dataset,
                                        train_neg_sample_args['distribution'])
        train_sampler = sampler.set_phase('train')

    if eval_neg_sample_args['strategy'] != 'none':
        if sampler is None:
            if not config['repeatable']:
                sampler = Sampler(phases, built_datasets,
                                  eval_neg_sample_args['distribution'])
            else:
                sampler = RepeatableSampler(
                    phases, dataset, eval_neg_sample_args['distribution'])
        else:
            sampler.set_distribution(eval_neg_sample_args['distribution'])
        valid_sampler = sampler.set_phase('valid')
        test_sampler = sampler.set_phase('test')

    return train_sampler, valid_sampler, test_sampler
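
A minimal usage sketch for `create_samplers`: `config` and `dataset` are assumed to be RecBole's usual `Config`/`Dataset` objects, and the splits are assumed to come from `dataset.build()` (whose exact signature varies by version); only `create_samplers` itself is defined above.

# Hedged usage sketch; everything except create_samplers is assumed context.
built_datasets = dataset.build()  # assumed: returns the [train, valid, test] splits
train_sampler, valid_sampler, test_sampler = create_samplers(config, dataset, built_datasets)

# Each returned sampler may be None when the corresponding strategy is 'none'.
if train_sampler is None:
    print('training runs without negative sampling')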
Example #2
def data_preparation(config, dataset, save=False):
    """Split the dataset by :attr:`config['eval_setting']` and call :func:`dataloader_construct` to create
    corresponding dataloader.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataset (Dataset): An instance object of Dataset, which contains all interaction records.
        save (bool, optional): If ``True``, it will call :func:`save_datasets` to save the split datasets.
            Defaults to ``False``.

    Returns:
        tuple:
            - train_data (AbstractDataLoader): The dataloader for training.
            - valid_data (AbstractDataLoader): The dataloader for validation.
            - test_data (AbstractDataLoader): The dataloader for testing.
    """
    model_type = config['MODEL_TYPE']

    es_str = [_.strip() for _ in config['eval_setting'].split(',')]
    es = EvalSetting(config)

    kwargs = {}
    if 'RS' in es_str[0]:
        kwargs['ratios'] = config['split_ratio']
        if kwargs['ratios'] is None:
            raise ValueError('`ratios` should be set if `RS` is set')
    if 'LS' in es_str[0]:
        kwargs['leave_one_num'] = config['leave_one_num']
        if kwargs['leave_one_num'] is None:
            raise ValueError('`leave_one_num` should be set if `LS` is set')
    kwargs['group_by_user'] = config['group_by_user']
    getattr(es, es_str[0])(**kwargs)

    if es.split_args['strategy'] != 'loo' and model_type == ModelType.SEQUENTIAL:
        raise ValueError('Sequential models require "loo" split strategy.')

    built_datasets = dataset.build(es)
    train_dataset, valid_dataset, test_dataset = built_datasets
    phases = ['train', 'valid', 'test']

    if save:
        save_datasets(config['checkpoint_dir'], name=phases, dataset=built_datasets)

    kwargs = {}
    if config['training_neg_sample_num']:
        es.neg_sample_by(config['training_neg_sample_num'])
        if model_type != ModelType.SEQUENTIAL:
            sampler = Sampler(phases, built_datasets, es.neg_sample_args['distribution'])
        else:
            sampler = RepeatableSampler(phases, dataset, es.neg_sample_args['distribution'])
        kwargs['sampler'] = sampler.set_phase('train')
        kwargs['neg_sample_args'] = copy.deepcopy(es.neg_sample_args)
        if model_type == ModelType.KNOWLEDGE:
            kg_sampler = KGSampler(dataset, es.neg_sample_args['distribution'])
            kwargs['kg_sampler'] = kg_sampler
    train_data = dataloader_construct(
        name='train',
        config=config,
        eval_setting=es,
        dataset=train_dataset,
        dl_format=config['MODEL_INPUT_TYPE'],
        batch_size=config['train_batch_size'],
        shuffle=True,
        **kwargs
    )

    kwargs = {}
    if len(es_str) > 1 and getattr(es, es_str[1], None):
        getattr(es, es_str[1])()
        if 'sampler' not in locals():
            sampler = Sampler(phases, built_datasets, es.neg_sample_args['distribution'])
        kwargs['sampler'] = [sampler.set_phase('valid'), sampler.set_phase('test')]
        kwargs['neg_sample_args'] = copy.deepcopy(es.neg_sample_args)
    valid_data, test_data = dataloader_construct(
        name='evaluation',
        config=config,
        eval_setting=es,
        dataset=[valid_dataset, test_dataset],
        batch_size=config['eval_batch_size'],
        **kwargs
    )

    return train_data, valid_data, test_data
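
To make the branching above concrete, here is an illustrative set of configuration values this variant reads; the key names are taken from the function body, while the literal values ('RS,full', the 0.8/0.1/0.1 ratios) are example choices, not from the source.

# Illustrative configuration consumed by this data_preparation variant.
config_example = {
    'eval_setting': 'RS,full',       # first token: split strategy; optional second token: eval mode
    'split_ratio': [0.8, 0.1, 0.1],  # required because 'RS' appears in the first token
    'leave_one_num': None,           # only needed if 'LS' were used instead of 'RS'
    'group_by_user': True,
    'training_neg_sample_num': 1,    # > 0 enables negative sampling for training
}

es_str = [s.strip() for s in config_example['eval_setting'].split(',')]
# es_str[0] == 'RS'   -> selects the split method invoked via getattr(es, es_str[0])
# es_str[1] == 'full' -> optional evaluation setting applied before building the eval dataloaders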
Example #3
def data_preparation(config, dataset, save=False):
    """Split the dataset by :attr:`config['eval_setting']` and call :func:`dataloader_construct` to create
    corresponding dataloader.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataset (Dataset): An instance object of Dataset, which contains all interaction records.
        save (bool, optional): If ``True``, it will call :func:`save_datasets` to save the split datasets.
            Defaults to ``False``.

    Returns:
        tuple:
            - train_data (AbstractDataLoader): The dataloader for training.
            - valid_data (AbstractDataLoader): The dataloader for validation.
            - test_data (AbstractDataLoader): The dataloader for testing.
    """
    model_type = config['MODEL_TYPE']

    es_str = [_.strip() for _ in config['eval_setting'].split(',')]
    es = EvalSetting(config)
    es.set_ordering_and_splitting(es_str[0])

    built_datasets = dataset.build(es)
    train_dataset, valid_dataset, test_dataset = built_datasets
    phases = ['train', 'valid', 'test']
    sampler = None

    if save:
        save_datasets(config['checkpoint_dir'],
                      name=phases,
                      dataset=built_datasets)

    kwargs = {}
    if config['training_neg_sample_num']:
        if dataset.label_field in dataset.inter_feat:
            raise ValueError(
                f'`training_neg_sample_num` should be 0 '
                f'if inter_feat has label_field [{dataset.label_field}].')
        train_distribution = config[
            'training_neg_sample_distribution'] or 'uniform'
        es.neg_sample_by(by=config['training_neg_sample_num'],
                         distribution=train_distribution)
        if model_type != ModelType.SEQUENTIAL:
            sampler = Sampler(phases, built_datasets,
                              es.neg_sample_args['distribution'])
        else:
            sampler = RepeatableSampler(phases, dataset,
                                        es.neg_sample_args['distribution'])
        kwargs['sampler'] = sampler.set_phase('train')
        kwargs['neg_sample_args'] = copy.deepcopy(es.neg_sample_args)
        if model_type == ModelType.KNOWLEDGE:
            kg_sampler = KGSampler(dataset, es.neg_sample_args['distribution'])
            kwargs['kg_sampler'] = kg_sampler
    train_data = dataloader_construct(name='train',
                                      config=config,
                                      eval_setting=es,
                                      dataset=train_dataset,
                                      dl_format=config['MODEL_INPUT_TYPE'],
                                      batch_size=config['train_batch_size'],
                                      shuffle=True,
                                      **kwargs)

    kwargs = {}
    if len(es_str) > 1 and getattr(es, es_str[1], None):
        if dataset.label_field in dataset.inter_feat:
            raise ValueError(
                f'It cannot validate with `{es_str[1]}` '
                f'when inter_feat has label_field [{dataset.label_field}].')
        getattr(es, es_str[1])()
        if sampler is None:
            if model_type != ModelType.SEQUENTIAL:
                sampler = Sampler(phases, built_datasets,
                                  es.neg_sample_args['distribution'])
            else:
                sampler = RepeatableSampler(phases, dataset,
                                            es.neg_sample_args['distribution'])
        sampler.set_distribution(es.neg_sample_args['distribution'])
        kwargs['sampler'] = [
            sampler.set_phase('valid'),
            sampler.set_phase('test')
        ]
        kwargs['neg_sample_args'] = copy.deepcopy(es.neg_sample_args)
    valid_data, test_data = dataloader_construct(
        name='evaluation',
        config=config,
        eval_setting=es,
        dataset=[valid_dataset, test_dataset],
        batch_size=config['eval_batch_size'],
        **kwargs)

    return train_data, valid_data, test_data
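
Compared with Example #2, this variant reuses one sampler across all three phases and only switches its distribution before evaluation via `set_distribution`. A hedged sketch of that reuse pattern (the import path and the 'uniform'/'popularity' values are assumptions; only the method names come from the code above):

from recbole.sampler import Sampler  # assumed import path for the Sampler used above

phases = ['train', 'valid', 'test']
built_datasets = dataset.build(es)          # as in the function body above
sampler = Sampler(phases, built_datasets, 'uniform')

train_sampler = sampler.set_phase('train')  # handed to the train dataloader
sampler.set_distribution('popularity')      # switch distribution for evaluation
valid_sampler = sampler.set_phase('valid')
test_sampler = sampler.set_phase('test')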
Example #4
File: utils.py  Project: zrymsm/RecBole
def data_preparation(config, dataset, save=False):
    """Split the dataset by :attr:`config['eval_setting']` and call :func:`dataloader_construct` to create
    corresponding dataloader.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataset (Dataset): An instance object of Dataset, which contains all interaction records.
        save (bool, optional): If ``True``, it will call :func:`save_split_dataloaders` to save the
            split dataloaders. Defaults to ``False``.

    Returns:
        tuple:
            - train_data (AbstractDataLoader): The dataloader for training.
            - valid_data (AbstractDataLoader): The dataloader for validation.
            - test_data (AbstractDataLoader): The dataloader for testing.
    """
    model_type = config['MODEL_TYPE']

    es = EvalSetting(config)

    built_datasets = dataset.build(es)
    train_dataset, valid_dataset, test_dataset = built_datasets
    phases = ['train', 'valid', 'test']
    sampler = None
    logger = getLogger()
    train_neg_sample_args = config['train_neg_sample_args']
    eval_neg_sample_args = es.neg_sample_args

    # Training
    train_kwargs = {
        'config': config,
        'dataset': train_dataset,
        'batch_size': config['train_batch_size'],
        'dl_format': config['MODEL_INPUT_TYPE'],
        'shuffle': True,
    }
    if train_neg_sample_args['strategy'] != 'none':
        if dataset.label_field in dataset.inter_feat:
            raise ValueError(
                f'`training_neg_sample_num` should be 0 '
                f'if inter_feat has label_field [{dataset.label_field}].')
        if model_type != ModelType.SEQUENTIAL:
            sampler = Sampler(phases, built_datasets,
                              train_neg_sample_args['distribution'])
        else:
            sampler = RepeatableSampler(phases, dataset,
                                        train_neg_sample_args['distribution'])
        train_kwargs['sampler'] = sampler.set_phase('train')
        train_kwargs['neg_sample_args'] = train_neg_sample_args
        if model_type == ModelType.KNOWLEDGE:
            kg_sampler = KGSampler(dataset,
                                   train_neg_sample_args['distribution'])
            train_kwargs['kg_sampler'] = kg_sampler

    dataloader = get_data_loader('train', config, train_neg_sample_args)
    logger.info(
        set_color('Build', 'pink') +
        set_color(f' [{dataloader.__name__}]', 'yellow') + ' for ' +
        set_color('[train]', 'yellow') + ' with format ' +
        set_color(f'[{train_kwargs["dl_format"]}]', 'yellow'))
    if train_neg_sample_args['strategy'] != 'none':
        logger.info(
            set_color('[train]', 'pink') +
            set_color(' Negative Sampling', 'blue') +
            f': {train_neg_sample_args}')
    else:
        logger.info(
            set_color('[train]', 'pink') +
            set_color(' No Negative Sampling', 'yellow'))
    logger.info(
        set_color('[train]', 'pink') + set_color(' batch_size', 'cyan') +
        ' = ' + set_color(f'[{train_kwargs["batch_size"]}]', 'yellow') + ', ' +
        set_color('shuffle', 'cyan') + ' = ' +
        set_color(f'[{train_kwargs["shuffle"]}]\n', 'yellow'))
    train_data = dataloader(**train_kwargs)

    # Evaluation
    eval_kwargs = {
        'config': config,
        'batch_size': config['eval_batch_size'],
        'dl_format': InputType.POINTWISE,
        'shuffle': False,
    }
    valid_kwargs = {'dataset': valid_dataset}
    test_kwargs = {'dataset': test_dataset}
    if eval_neg_sample_args['strategy'] != 'none':
        if dataset.label_field in dataset.inter_feat:
            raise ValueError(
                f'It cannot validate with `{es.es_str[1]}` '
                f'when inter_feat has label_field [{dataset.label_field}].')
        if sampler is None:
            if model_type != ModelType.SEQUENTIAL:
                sampler = Sampler(phases, built_datasets,
                                  eval_neg_sample_args['distribution'])
            else:
                sampler = RepeatableSampler(
                    phases, dataset, eval_neg_sample_args['distribution'])
        else:
            sampler.set_distribution(eval_neg_sample_args['distribution'])
        eval_kwargs['neg_sample_args'] = eval_neg_sample_args
        valid_kwargs['sampler'] = sampler.set_phase('valid')
        test_kwargs['sampler'] = sampler.set_phase('test')
    valid_kwargs.update(eval_kwargs)
    test_kwargs.update(eval_kwargs)

    dataloader = get_data_loader('evaluation', config, eval_neg_sample_args)
    logger.info(
        set_color('Build', 'pink') +
        set_color(f' [{dataloader.__name__}]', 'yellow') + ' for ' +
        set_color('[evaluation]', 'yellow') + ' with format ' +
        set_color(f'[{eval_kwargs["dl_format"]}]', 'yellow'))
    logger.info(es)
    logger.info(
        set_color('[evaluation]', 'pink') + set_color(' batch_size', 'cyan') +
        ' = ' + set_color(f'[{eval_kwargs["batch_size"]}]', 'yellow') + ', ' +
        set_color('shuffle', 'cyan') + ' = ' +
        set_color(f'[{eval_kwargs["shuffle"]}]\n', 'yellow'))

    valid_data = dataloader(**valid_kwargs)
    test_data = dataloader(**test_kwargs)

    if save:
        save_split_dataloaders(config,
                               dataloaders=(train_data, valid_data, test_data))

    return train_data, valid_data, test_data
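
For context, a typical end-to-end call sequence around this function, as a hedged sketch: `Config` and `create_dataset` are RecBole's usual entry points and are assumptions here, since only `data_preparation` itself appears above.

# End-to-end usage sketch; the imports below are assumed RecBole entry points.
from recbole.config import Config
from recbole.data import create_dataset

config = Config(model='BPR', dataset='ml-100k')
dataset = create_dataset(config)

# Builds the train/valid/test dataloaders; with save=True the split dataloaders
# are persisted via save_split_dataloaders, as in the code above.
train_data, valid_data, test_data = data_preparation(config, dataset, save=True)

for interaction in train_data:
    pass  # each batch is an Interaction object consumed by the model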