Code example #1
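Both snippets below implement the build_dataset helper used by nmt-keras-style training scripts. This first variant appears to come from a character-level fork (note the MAX_INPUT_WORD_LEN and CHAR_BPE options): it registers the target-language files as outputs and the source files, plus the shifted target ('state_below'), as inputs, optionally keeps raw file references for alignment, and finally stores the Dataset with saveDataset; when REBUILD_DATASET is off, the stored pickle is simply reloaded.
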
def build_dataset(params):
    """
    Builds (or loads) a Dataset instance.
    :param params: Parameters specifying Dataset options
    :return: Dataset object
    """

    if params['REBUILD_DATASET']:  # We build a new dataset instance
        if params['VERBOSE'] > 0:
            silence = False
            logging.info(
                'Building ' + params['DATASET_NAME'] + '_' + params['SRC_LAN'] + params['TRG_LAN'] + ' dataset')
        else:
            silence = True

        # If we are using character-level NMT only at the encoder level, we need a different tokenization function
        if params['MAX_INPUT_WORD_LEN'] > 0:
            conditional_tok = 'tokenize_none'
        else:
            conditional_tok = params['TOKENIZATION_METHOD']

        base_path = params['DATA_ROOT_PATH']
        name = params['DATASET_NAME'] + '_' + params['SRC_LAN'] + params['TRG_LAN']  # must match the name expected by loadDataset in the else branch
        ds = Dataset(name, base_path, silence=silence)

        # OUTPUT DATA
        # Let's load the train, val and test splits of the target language sentences (outputs)
        #    the files include a sentence per line.
        ds.setOutput(base_path + '/' + params['TEXT_FILES']['train'] + params['TRG_LAN'],
                     'train',
                     type='dense_text' if 'sparse' in params['LOSS'] else 'text',
                     id=params['OUTPUTS_IDS_DATASET'][0],
                     tokenization=conditional_tok,
                     build_vocabulary=True,
                     pad_on_batch=params.get('PAD_ON_BATCH', True),
                     sample_weights=params.get('SAMPLE_WEIGHTS', True),
                     fill=params.get('FILL', 'end'),
                     max_text_len=params.get('MAX_OUTPUT_TEXT_LEN', 70),
                     max_words=params.get('OUTPUT_VOCABULARY_SIZE', 0),
                     min_occ=params.get('MIN_OCCURRENCES_OUTPUT_VOCAB', 0),
                     bpe_codes=params.get('BPE_CODES_PATH', None))
        if params.get('ALIGN_FROM_RAW', True) and not params.get('HOMOGENEOUS_BATCHES', False):
            ds.setRawOutput(base_path + '/' + params['TEXT_FILES']['train'] + params['TRG_LAN'],
                            'train',
                            type='file-name',
                            id='raw_' + params['OUTPUTS_IDS_DATASET'][0])

        for split in ['val', 'test']:
            if params['TEXT_FILES'].get(split) is not None:
                ds.setOutput(base_path + '/' + params['TEXT_FILES'][split] + params['TRG_LAN'],
                             split,
                             type='dense_text' if 'sparse' in params['LOSS'] else 'text',
                             id=params['OUTPUTS_IDS_DATASET'][0],
                             pad_on_batch=params.get('PAD_ON_BATCH', True),
                             fill=params.get('FILL_TARGET', 'end'),
                             fill_char=params.get('FILL_CHAR', 'end'),
                             sample_weights=params.get('SAMPLE_WEIGHTS', True),
                             max_text_len=params.get('MAX_OUTPUT_TEXT_LEN', 70),
                             max_words=params.get('OUTPUT_VOCABULARY_SIZE', 0),
                             bpe_codes=params.get('BPE_CODES_PATH', None))
                if params.get('ALIGN_FROM_RAW', True) and not params.get('HOMOGENEOUS_BATCHES', False):
                    ds.setRawOutput(base_path + '/' + params['TEXT_FILES'][split] + params['TRG_LAN'],
                                    split,
                                    type='file-name',
                                    id='raw_' + params['OUTPUTS_IDS_DATASET'][0])

        # INPUT DATA
        # We must ensure that the 'train' split is the first (for building the vocabulary)
        for split in ['train', 'val', 'test']:
            if params['TEXT_FILES'].get(split) is not None:
                if split == 'train':
                    build_vocabulary = True
                else:
                    build_vocabulary = False
                ds.setInput(base_path + '/' + params['TEXT_FILES'][split] + params['SRC_LAN'],
                            split,
                            type='text',
                            id=params['INPUTS_IDS_DATASET'][0],
                            pad_on_batch=params.get('PAD_ON_BATCH', True),
                            tokenization=params.get('TOKENIZATION_METHOD', 'tokenize_none'),
                            build_vocabulary=build_vocabulary,
                            fill=params['FILL'],
                            fill_char=params.get('FILL_CHAR', 'end'),
                            max_text_len=params['MAX_INPUT_TEXT_LEN'],
                            max_word_len=params['MAX_INPUT_WORD_LEN'],
                            char_bpe=params['CHAR_BPE'],
                            max_words=params['INPUT_VOCABULARY_SIZE'],
                            min_occ=params['MIN_OCCURRENCES_INPUT_VOCAB'],
                            bpe_codes=params.get('BPE_CODES_PATH', None))

                if len(params['INPUTS_IDS_DATASET']) > 1:
                    if 'train' in split:
                        ds.setInput(base_path + '/' + params['TEXT_FILES'][split] + params['TRG_LAN'],
                                    split,
                                    type='text',
                                    id=params['INPUTS_IDS_DATASET'][1],
                                    required=False,
                                    tokenization=conditional_tok,
                                    pad_on_batch=params['PAD_ON_BATCH'],
                                    build_vocabulary=params['OUTPUTS_IDS_DATASET'][0],  # reuse the target-output vocabulary
                                    offset=1,  # shift targets one position to form the decoder input (state_below)
                                    fill=params.get('FILL', 'end'),
                                    max_text_len=params['MAX_OUTPUT_TEXT_LEN'],
                                    max_word_len=0,
                                    char_bpe=params['CHAR_BPE'],
                                    max_words=params['OUTPUT_VOCABULARY_SIZE'],
                                    bpe_codes=params.get('BPE_CODES_PATH', None))
                        if params.get('TIE_EMBEDDINGS', False):
                            ds.merge_vocabularies([params['INPUTS_IDS_DATASET'][1], params['INPUTS_IDS_DATASET'][0]])
                    else:
                        ds.setInput(None,
                                    split,
                                    type='ghost',
                                    id=params['INPUTS_IDS_DATASET'][-1],
                                    required=False)
                if params.get('ALIGN_FROM_RAW', True) and not params.get('HOMOGENEOUS_BATCHES', False):
                    ds.setRawInput(base_path + '/' + params['TEXT_FILES'][split] + params['SRC_LAN'],
                                   split,
                                   type='file-name',
                                   id='raw_' + params['INPUTS_IDS_DATASET'][0])
        if params.get('POS_UNK', False):
            if params.get('HEURISTIC', 0) > 0:
                ds.loadMapping(params['MAPPING'])

        # If we had multiple references per sentence
        keep_n_captions(ds, repeat=1, n=1, set_names=params['EVAL_ON_SETS'])

        # We have finished loading the dataset, now we can store it for future use
        saveDataset(ds, params['DATASET_STORE_PATH'])

    else:
        # We can easily recover it with a single line
        ds = loadDataset(params['DATASET_STORE_PATH'] + '/Dataset_' + params['DATASET_NAME'] + '_' + params['SRC_LAN'] + params['TRG_LAN'] + '.pkl')

        # If we had multiple references per sentence
        keep_n_captions(ds, repeat=1, n=1, set_names=params['EVAL_ON_SETS'])

    return ds
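
A minimal usage sketch for this first variant, assuming build_dataset is importable as defined above (in nmt-keras it lives in data_engine.prepare_data) and that the referenced text files exist. Every key and value below is an illustrative placeholder rather than a required or exhaustive configuration; in practice the dictionary comes from the project's config.py / load_parameters().

# Hypothetical configuration; key names mirror the ones read by build_dataset above,
# the values are placeholders only.
params = {
    'REBUILD_DATASET': True,
    'VERBOSE': 1,
    'DATASET_NAME': 'EuTrans',          # hypothetical corpus name
    'SRC_LAN': 'es',
    'TRG_LAN': 'en',
    'DATA_ROOT_PATH': 'data/EuTrans',   # hypothetical path to the text files
    'DATASET_STORE_PATH': 'datasets',
    'TEXT_FILES': {'train': 'training.', 'val': 'dev.', 'test': 'test.'},
    'INPUTS_IDS_DATASET': ['source_text', 'state_below'],
    'OUTPUTS_IDS_DATASET': ['target_text'],
    'LOSS': 'categorical_crossentropy',
    'TOKENIZATION_METHOD': 'tokenize_none',
    'MAX_INPUT_WORD_LEN': 0,            # 0 -> no character-level encoder
    'CHAR_BPE': False,
    'PAD_ON_BATCH': True,
    'FILL': 'end',
    'MAX_INPUT_TEXT_LEN': 50,
    'MAX_OUTPUT_TEXT_LEN': 50,
    'INPUT_VOCABULARY_SIZE': 0,         # 0 -> keep the full vocabulary
    'OUTPUT_VOCABULARY_SIZE': 0,
    'MIN_OCCURRENCES_INPUT_VOCAB': 0,
    'POS_UNK': False,
    'EVAL_ON_SETS': ['val'],
}
ds = build_dataset(params)

With REBUILD_DATASET set to False, the same call would instead go through the else branch and expect the pickle saved by a previous run under DATASET_STORE_PATH.
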
Code example #2
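This second variant follows the upstream nmt-keras implementation more closely: paths are built with os.path.join, input/output types can be overridden through INPUTS_TYPES_DATASET / OUTPUTS_TYPES_DATASET, label smoothing is forwarded to the training output, and references are prepared with prepare_references rather than keep_n_captions.
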
def build_dataset(params):
    """
    Builds (or loads) a Dataset instance.
    :param params: Parameters specifying Dataset options
    :return: Dataset object
    """

    if params['REBUILD_DATASET']:  # We build a new dataset instance
        if params['VERBOSE'] > 0:
            silence = False
            logger.info('Building ' + params['DATASET_NAME'] + '_' +
                        params['SRC_LAN'] + params['TRG_LAN'] + ' dataset')
        else:
            silence = True

        base_path = params['DATA_ROOT_PATH']
        name = params['DATASET_NAME'] + '_' + params['SRC_LAN'] + params['TRG_LAN']
        ds = Dataset(name, base_path, silence=silence)

        # OUTPUT DATA
        # Load the train, val and test splits of the target language sentences (outputs). The files include a sentence per line.
        ds.setOutput(
            os.path.join(base_path,
                         params['TEXT_FILES']['train'] + params['TRG_LAN']),
            'train',
            type=params.get(
                'OUTPUTS_TYPES_DATASET',
                ['dense-text'] if 'sparse' in params['LOSS'] else ['text'])[0],
            id=params['OUTPUTS_IDS_DATASET'][0],
            tokenization=params.get('TOKENIZATION_METHOD', 'tokenize_none'),
            build_vocabulary=True,
            pad_on_batch=params.get('PAD_ON_BATCH', True),
            sample_weights=params.get('SAMPLE_WEIGHTS', True),
            fill=params.get('FILL', 'end'),
            max_text_len=params.get('MAX_OUTPUT_TEXT_LEN', 70),
            max_words=params.get('OUTPUT_VOCABULARY_SIZE', 0),
            min_occ=params.get('MIN_OCCURRENCES_OUTPUT_VOCAB', 0),
            bpe_codes=params.get('BPE_CODES_PATH', None),
            label_smoothing=params.get('LABEL_SMOOTHING', 0.))

        for split in ['val', 'test']:
            if params['TEXT_FILES'].get(split) is not None:
                ds.setOutput(
                    os.path.join(
                        base_path,
                        params['TEXT_FILES'][split] + params['TRG_LAN']),
                    split,
                    type='text',  # the type of the references should always be 'text'
                    id=params['OUTPUTS_IDS_DATASET'][0],
                    pad_on_batch=params.get('PAD_ON_BATCH', True),
                    tokenization=params.get('TOKENIZATION_METHOD',
                                            'tokenize_none'),
                    sample_weights=params.get('SAMPLE_WEIGHTS', True),
                    max_text_len=params.get('MAX_OUTPUT_TEXT_LEN', 70),
                    max_words=params.get('OUTPUT_VOCABULARY_SIZE', 0),
                    bpe_codes=params.get('BPE_CODES_PATH', None),
                    label_smoothing=0.)

        # INPUT DATA
        # We must ensure that the 'train' split is the first (for building the vocabulary)
        for split in params['TEXT_FILES']:
            build_vocabulary = split == 'train'
            ds.setInput(os.path.join(
                base_path, params['TEXT_FILES'][split] + params['SRC_LAN']),
                        split,
                        type=params.get('INPUTS_TYPES_DATASET',
                                        ['text', 'text'])[0],
                        id=params['INPUTS_IDS_DATASET'][0],
                        pad_on_batch=params.get('PAD_ON_BATCH', True),
                        tokenization=params.get('TOKENIZATION_METHOD',
                                                'tokenize_none'),
                        build_vocabulary=build_vocabulary,
                        fill=params.get('FILL', 'end'),
                        max_text_len=params.get('MAX_INPUT_TEXT_LEN', 70),
                        max_words=params.get('INPUT_VOCABULARY_SIZE', 0),
                        min_occ=params.get('MIN_OCCURRENCES_INPUT_VOCAB', 0),
                        bpe_codes=params.get('BPE_CODES_PATH', None))

            if len(params['INPUTS_IDS_DATASET']) > 1:
                if 'train' in split:
                    ds.setInput(
                        os.path.join(
                            base_path,
                            params['TEXT_FILES'][split] + params['TRG_LAN']),
                        split,
                        type=params.get('INPUTS_TYPES_DATASET',
                                        ['text', 'text'])[1],
                        id=params['INPUTS_IDS_DATASET'][1],
                        required=False,
                        tokenization=params.get('TOKENIZATION_METHOD',
                                                'tokenize_none'),
                        pad_on_batch=params.get('PAD_ON_BATCH', True),
                        build_vocabulary=params['OUTPUTS_IDS_DATASET'][0],
                        offset=1,
                        fill=params.get('FILL', 'end'),
                        max_text_len=params.get('MAX_OUTPUT_TEXT_LEN', 70),
                        max_words=params.get('OUTPUT_VOCABULARY_SIZE', 0),
                        bpe_codes=params.get('BPE_CODES_PATH', None))
                    if params.get('TIE_EMBEDDINGS', False):
                        ds.merge_vocabularies([
                            params['INPUTS_IDS_DATASET'][1],
                            params['INPUTS_IDS_DATASET'][0]
                        ])
                else:
                    ds.setInput(None,
                                split,
                                type='ghost',
                                id=params['INPUTS_IDS_DATASET'][-1],
                                required=False)
            if params.get('ALIGN_FROM_RAW', True) and not params.get(
                    'HOMOGENEOUS_BATCHES', False):
                ds.setRawInput(os.path.join(
                    base_path,
                    params['TEXT_FILES'][split] + params['SRC_LAN']),
                               split,
                               type='file-name',
                               id='raw_' + params['INPUTS_IDS_DATASET'][0])
        if params.get('POS_UNK', False):
            if params.get('HEURISTIC', 0) > 0:
                ds.loadMapping(params['MAPPING'])
        # Prepare references
        prepare_references(ds, repeat=1, n=1, set_names=params['EVAL_ON_SETS'])

        # We have finished loading the dataset, now we can store it for future use
        saveDataset(ds, params['DATASET_STORE_PATH'])

    else:
        # We can easily recover it with a single line
        ds = loadDataset(
            os.path.join(
                params['DATASET_STORE_PATH'],
                'Dataset_' + params['DATASET_NAME'] + '_' + params['SRC_LAN'] +
                params['TRG_LAN'] + '.pkl'))

        # Prepare references
        prepare_references(ds, repeat=1, n=1, set_names=params['EVAL_ON_SETS'])

    return ds
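
Once either variant has stored the Dataset, reloading it in a later run can be sketched as below. The loadDataset helper and the Dataset_<name>_<src><trg>.pkl file name simply follow the else branch above; the vocabulary access at the end assumes the layout used by multimodal-keras-wrapper's Dataset class (one dict per data id with 'words2idx' / 'idx2words' entries), so verify it against the installed version.

import os
from keras_wrapper.dataset import loadDataset

# Hypothetical values; they only need to match the configuration used when the dataset was saved.
store_path, dataset_name, src_lan, trg_lan = 'datasets', 'EuTrans', 'es', 'en'
ds = loadDataset(os.path.join(store_path,
                              'Dataset_' + dataset_name + '_' + src_lan + trg_lan + '.pkl'))

# Assuming the standard vocabulary layout of the Dataset class:
print(len(ds.vocabulary['source_text']['words2idx']))   # source vocabulary size
print(len(ds.vocabulary['target_text']['words2idx']))   # target vocabulary size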