Example #1
    def load_data(self):

        train_filepath = valid_path_append(self.path, 'train.csv')
        test_filepath = valid_path_append(self.path, 'test.csv')

        train_texts, train_labels = self.load_split(train_filepath)
        test_texts, test_labels = self.load_split(test_filepath)

        train_dict = self.make_dict(train_texts, train_labels)
        test_dict = self.make_dict(test_texts, test_labels)
        data_dict = {'train': train_dict, 'test': test_dict}

        self.data_dict = data_dict
        return data_dict
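
Every example on this page routes its file paths through ngraph's valid_path_append helper, whose implementation is not shown here. The snippet below is only a rough sketch of its assumed behavior, inferred from the call sites above (one extra argument yields a single joined path; passing '' plus a filename is unpacked as workdir, filepath); it is not the library's actual code.

    import os

    def valid_path_append(path, *subpaths):
        # Hypothetical sketch only, inferred from the call sites in these examples;
        # not the actual ngraph implementation. Each subpath is joined onto the
        # base path; a single subpath returns one path, several return a tuple
        # (e.g. unpacked above as `workdir, filepath`).
        joined = [os.path.join(path, p) for p in subpaths]
        return joined[0] if len(joined) == 1 else tuple(joined)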
Example #2
    def parse_text_window(self):
        wiki_window_file = valid_path_append(
            self.path, 'movieqa/lower_wiki-w=0-d=3-m-4.txt')
        # The wiki_window_file is expected to be created ahead of time;
        # this block looks for it and completes the pre-processing if it is missing,
        # which can take a couple of hours.
        if os.path.exists(wiki_window_file):
            print('Wiki Window Previously created')
        else:
            print('Pre-processing text through wiki windows')
            print('Note that this could take a number of hours')
            subprocess.call(['python', 'wikiwindows.py', self.path])

        knowledge_dict = defaultdict(list)
        with open(wiki_window_file) as read:
            movie_ent = None
            for l in read:
                l = l.strip()
                if len(l) > 0:
                    nid, line = l.split(' ', 1)

                    if nid == '1':
                        movie_ent = line.split(' ')[0]

                    sentence, center = line.split('\t')
                    # Add the movie sentences as well as the center encoding
                    knowledge_dict[movie_ent].append((sentence, center))
                    knowledge_dict[center].append((sentence, movie_ent))

        return knowledge_dict
Example #3
File: ptb.py  Project: rsumner31/ngraph
    def load_data(self):
        self.data_dict = {}
        self.vocab = None
        for phase in ['train', 'test', 'valid']:
            filename, filesize = self.filemap[phase]['filename'], self.filemap[phase]['size']
            workdir, filepath = valid_path_append(self.path, '', filename)
            if not os.path.exists(filepath):
                fetch_file(self.url, filename, filepath, filesize)

            tokens = open(filepath).read()  # add tokenization here if necessary

            if self.use_words:
                tokens = tokens.strip().split()

            self.vocab = sorted(set(tokens)) if self.vocab is None else self.vocab

            # vocab dicts
            self.token_to_index = dict((t, i) for i, t in enumerate(self.vocab))
            self.index_to_token = dict((i, t) for i, t in enumerate(self.vocab))

            # map tokens to indices
            X = np.asarray([self.token_to_index[t] for t in tokens], dtype=np.uint32)
            if self.shift_target:
                y = np.concatenate((X[1:], X[:1]))
            else:
                y = X.copy()

            self.data_dict[phase] = {'inp_txt': X, 'tgt_txt': y}

        return self.data_dict
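
The shift_target branch above builds standard next-token language-model targets: y is X rotated left by one position, so each input index is paired with the index of the token that follows it. A tiny standalone illustration with made-up indices:

    import numpy as np

    # Toy token-index sequence; in load_data these come from token_to_index.
    X = np.asarray([5, 2, 7, 7, 1], dtype=np.uint32)
    y = np.concatenate((X[1:], X[:1]))
    # X -> [5, 2, 7, 7, 1]
    # y -> [2, 7, 7, 1, 5]   (each target is the next input token, wrapping at the end)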
Example #4
    def load_data(self):
        """
        Fetch the MNIST dataset and load it into memory.

        Arguments:
            path (str, optional): Local directory in which to cache the raw
                                  dataset.  Defaults to current directory.

        Returns:
            tuple: Both training and test sets are returned.
        """
        workdir, filepath = valid_path_append(self.path, '', self.filename)
        if not os.path.exists(filepath):
            fetch_file(self.url, self.filename, filepath, self.size)

        with gzip.open(filepath, 'rb') as f:
            self.train_set, self.valid_set = pickle_load(f)

        self.train_set = {'image': {'data': self.train_set[0].reshape(60000, 28, 28),
                                    'axes': ('batch', 'height', 'width')},
                          'label': {'data': self.train_set[1],
                                    'axes': ('batch',)}}
        self.valid_set = {'image': {'data': self.valid_set[0].reshape(10000, 28, 28),
                                    'axes': ('batch', 'height', 'width')},
                          'label': {'data': self.valid_set[1],
                                    'axes': ('batch',)}}

        return self.train_set, self.valid_set
Example #5
    def parse_text_window(self):
        wiki_window_file = valid_path_append(self.path, 'movieqa/lower_wiki-w=0-d=3-m-4.txt')
        # The wiki_window_file is expected to be created ahead of time;
        # this block looks for it and completes the pre-processing if it is missing,
        # which can take a couple of hours.
        if os.path.exists(wiki_window_file):
            print('Wiki Window Previously created')
        else:
            print('Pre-processing text through wiki windows')
            print('Note that this could take a number of hours')
            subprocess.call(['python', 'wikiwindows.py', self.path])

        knowledge_dict = defaultdict(list)
        with open(wiki_window_file) as read:
            movie_ent = None
            for l in read:
                l = l.strip()
                if len(l) > 0:
                    nid, line = l.split(' ', 1)

                    if nid == '1':
                        movie_ent = line.split(' ')[0]

                    sentence, center = line.split('\t')
                    # Add the movie sentences as well as the center encoding
                    knowledge_dict[movie_ent].append((sentence, center))
                    knowledge_dict[center].append((sentence, movie_ent))

        return knowledge_dict
Example #6
File: lsun.py  Project: rsumner31/ngraph
    def download_lsun(self, category, dset, tag='latest', overwrite=False):
        """
        Download LSUN data and unpack
        Arguments:
            category (str): LSUN category (valid selections: lsun_categories)
            dset (str): dataset, "train", "val", or "test"
            tag (str, optional): version tag, defaults to most recent
            overwrite (bool): whether to overwrite existing data
        """
        dfile = 'test_lmdb' if dset == 'test' else '{0}_{1}_lmdb'.format(
            category, dset)
        self.filepath = filepath = valid_path_append(self.path, dfile)
        if not os.path.exists(filepath) or overwrite:
            filepath += '.zip'
            if not os.path.exists(filepath):
                url = LSUN.url + \
                    'download.cgi?tag={0}&category={1}&set={2}'.format(tag, category, dset)
                print('Data download might take a long time.')
                print('Downloading {0} {1} set...'.format(category, dset))
                subprocess.call(['curl', url, '-o', filepath])
                # TODO
                # should change to fetch_file,
                # but currently did not get the correct "Content-length" or total_size
                # fetch_file(url, 'bedroom_train_lmdb.zip', filepath)
            print('Extracting {0} {1} set...'.format(category, dset))
            zf = zipfile.ZipFile(filepath, 'r')
            zf.extractall(self.path)
            zf.close()
            print('Deleting {} ...'.format(filepath))
            os.remove(filepath)
        else:
            pass  # data already downloaded
        print("LSUN {0} {1} dataset downloaded and unpacked.".format(
            category, dset))
Example #7
File: tsp.py  Project: rsumner31/ngraph
    def load_data(self):
        self.data_dict = {}
        for phase in ['train', 'test']:
            filename = self.filemap[phase]['filename']
            workdir, filepath = valid_path_append(self.path, '', filename)
            if not os.path.exists(filepath):
                for file_name, file_id in GOOGLE_DRIVE_IDS.items():
                    destination = './' + file_name
                    print('\nDownloading and unzipping traveling salesman data {} released '
                          'with Pointer Networks paper\n'.format(file_name))
                    self.download_file_from_google_drive(file_id, destination)
                    with zipfile.ZipFile(destination, 'r') as z:
                        z.extractall('./')

            cities = int(re.search(r'\d+', filename).group())
            print('Loading and preprocessing tsp{} {} data...'.format(cities, phase))
            with open(filepath, 'r') as f:
                X, y, y_teacher = [], [], []
                for i, line in tqdm(enumerate(f)):
                    inputs, outputs = line.split('output')
                    X.append(np.array([float(j) for j in inputs.split()]).reshape([-1, 2]))
                    y.append(np.array([int(j) - 1 for j in outputs.split()])[:-1])  # delete last
                    # teacher forcing array as decoder's input while training
                    y_teacher.append([X[i][j - 1] for j in y[i]])

            X = np.array(X)
            y = np.array(y)
            y_teacher = np.array(y_teacher)
            self.data_dict[phase] = {'inp_txt': X, 'tgt_txt': y, 'teacher_tgt': y_teacher}

        return self.data_dict
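
The loop above implies a line format of flattened city coordinates, the literal keyword 'output', and a 1-indexed tour whose trailing return-to-start index is dropped (the '# delete last' comment). A small standalone illustration on a made-up line, mirroring the same split and conversions:

    import numpy as np

    # Hypothetical sample line in the format the parser above expects.
    line = '0.1 0.2  0.9 0.4  0.5 0.8 output 1 3 2 1'
    inputs, outputs = line.split('output')
    X_i = np.array([float(j) for j in inputs.split()]).reshape([-1, 2])
    # X_i -> three cities as (x, y) pairs: [[0.1, 0.2], [0.9, 0.4], [0.5, 0.8]]
    y_i = np.array([int(j) - 1 for j in outputs.split()])[:-1]
    # y_i -> [0, 2, 1]: the 0-indexed visiting order, with the final return to city 1 removed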
Example #8
File: imdb.py  Project: ugiwgh/ngraph
    def load_data(self, test_split=0.2):
        self.data_dict = {}
        self.vocab = None
        workdir, filepath = valid_path_append(self.path, '', self.filename)
        if not os.path.exists(filepath):
            fetch_file(self.url, self.filename, filepath, self.filesize)

        with open(filepath, 'rb') as f:
            X, y = pickle_load(f)

        X = preprocess_text(X, self.vocab_size)
        X = pad_sentences(X,
                          pad_idx=self.pad_idx,
                          pad_to_len=self.sentence_length,
                          pad_from='left')

        if self.shuffle:
            indices = np.arange(len(y))
            np.random.shuffle(indices)
            X = X[indices]
            y = np.asarray(y)[indices]

        # split the data
        X_train = X[:int(len(X) * (1 - test_split))]
        y_train = y[:int(len(X) * (1 - test_split))]

        X_test = X[int(len(X) * (1 - test_split)):]
        y_test = y[int(len(X) * (1 - test_split)):]

        y_train = np.array(y_train)
        y_test = np.array(y_test)

        self.nclass = 1 + max(np.max(y_train), np.max(y_test))

        self.data_dict['train'] = {
            'review': {
                'data': X_train,
                'axes': ('batch', 'REC')
            },
            'label': {
                'data': y_train,
                'axes': ('batch', )
            }
        }
        self.data_dict['valid'] = {
            'review': {
                'data': X_test,
                'axes': ('batch', 'REC')
            },
            'label': {
                'data': y_test,
                'axes': ('batch', )
            }
        }
        return self.data_dict
Example #9
    def parse_kb(self, reverse_dictionary):
        # Check data repo to see if this exists, if not then run code and save
        workdir, kb_file_path = valid_path_append(self.path, '', 'movieqa/kb_dictionary.pkl')
        if os.path.exists(kb_file_path) and not self.reparse:
            print('Loading files from path')
            print(kb_file_path)
            knowledge_dict = pickle.load(open(kb_file_path, "rb"))
            return knowledge_dict

        action_words = ['directed_by', 'written_by', 'starred_actors', 'release_year',
                        'in_language', 'has_tags', 'has_plot', 'has_imdb_votes', 'has_imdb_rating']
        rev_actions_pre = 'REV_'

        with open(self.kb_file) as file:
            babi_data = file.read()

        lines = self.data_to_list(babi_data)

        knowledge_dict = defaultdict(list)
        fact = None
        for line in lines:
            if len(line) > 1:
                nid, line = line.lower().split(' ', 1)

                if int(nid) == 1:
                    fact = None

                for a in action_words:
                    # find the action word and split on that, ignoring the has_plot info for now
                    if len(line.split(a)) > 1 and a != 'has_plot':
                        # there can be more than one entity on the left hand side
                        # (particularly with starred_actors)
                        entities = line.split(a)
                        # Let's use the info that we know the fact for related stories
                        if not fact:
                            fact = ex_entity_names(entities[0], reverse_dictionary, self.re_list)
                        subject_entities = [ex_entity_names(e_0, reverse_dictionary, self.re_list)
                                            for e_0 in entities[1].split(', ')]

                        # create a hash for the knowledge base where key is the subject
                        # add the fact and reverse fact
                        for subject in subject_entities:
                            knowledge_dict[fact].append((fact + ' ' + a, subject))
                            # Also add reverse here
                            knowledge_dict[subject].append((subject + ' ' +
                                                            rev_actions_pre+a, fact))

        kb_out_path = ensure_dirs_exist(os.path.join(workdir, kb_file_path))

        print('Writing to ', kb_out_path)
        with open(kb_out_path, 'wb') as f:
            pickle.dump(knowledge_dict, f)

        return knowledge_dict
Example #10
    def load_data(self):
        self.data_dict = {}
        workdir, filepath = valid_path_append(self.path, '', self.filename)
        if not os.path.exists(filepath):
            fetch_file(self.url, self.filename, filepath)

        tokens = open(filepath).read()
        train_samples = int(self.train_split * len(tokens))
        train = tokens[:train_samples]
        test = tokens[train_samples:]

        return train, test
Example #11
    def load_data(self):
        """
        Fetch the CIFAR-10 dataset and load it into memory.

        Arguments:
            path (str, optional): Local directory in which to cache the raw
                                  dataset.  Defaults to current directory.
            normalize (bool, optional): Whether to scale values between 0 and 1.
                                        Defaults to True.

        Returns:
            tuple: Both training and test sets are returned.
        """
        workdir, filepath = valid_path_append(self.path, '', self.filename)
        if not os.path.exists(filepath):
            fetch_file(self.url, self.filename, filepath, self.size)

        batchdir = os.path.join(workdir, 'cifar-10-batches-py')
        if not os.path.exists(os.path.join(batchdir, 'data_batch_1')):
            assert os.path.exists(filepath), "Must have cifar-10-python.tar.gz"
            with tarfile.open(filepath, 'r:gz') as f:
                f.extractall(workdir)

        train_batches = [os.path.join(batchdir, 'data_batch_' + str(i)) for i in range(1, 6)]
        Xlist, ylist = [], []
        for batch in train_batches:
            with open(batch, 'rb') as f:
                d = pickle_load(f)
                Xlist.append(d['data'])
                ylist.append(d['labels'])

        X_train = np.vstack(Xlist).reshape(-1, 3, 32, 32)
        y_train = np.vstack(ylist).ravel()

        with open(os.path.join(batchdir, 'test_batch'), 'rb') as f:
            d = pickle_load(f)
            X_test, y_test = d['data'], d['labels']
            X_test = X_test.reshape(-1, 3, 32, 32)

        self.train_set = {'image': {'data': X_train,
                                    'axes': ('batch', 'channel', 'height', 'width')},
                          'label': {'data': y_train,
                                    'axes': ('batch',)}}
        self.valid_set = {'image': {'data': X_test,
                                    'axes': ('batch', 'channel', 'height', 'width')},
                          'label': {'data': np.array(y_test),
                                    'axes': ('batch',)}}

        return self.train_set, self.valid_set
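
For reference, each of the five pickled CIFAR-10 training batches holds 10000 rows of 3072 bytes (3 x 32 x 32), so the vstack/reshape above yields a (50000, 3, 32, 32) NCHW array. A quick standalone check of that reshape, using zero-filled stand-ins for the pickled batches:

    import numpy as np

    # Zero-filled stand-ins for the five 10000x3072 CIFAR-10 training batches.
    Xlist = [np.zeros((10000, 3072), dtype=np.uint8) for _ in range(5)]
    X_train = np.vstack(Xlist).reshape(-1, 3, 32, 32)
    assert X_train.shape == (50000, 3, 32, 32)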
Example #12
    def load_data(self, path=".", subset='wiki_entities'):
        """
        Fetch the Facebook WikiMovies dataset and load it to memory.

        Arguments:
            path (str, optional): Local directory in which to cache the raw
                                  dataset.  Defaults to current directory.

        Returns:
            tuple: knowledge base, entity list, training and test files are returned
        """
        self.data_dict = {}
        self.vocab = None
        workdir, filepath = valid_path_append(path, '', self.filename)
        babi_dir_name = self.filename.split('.')[0]

        if subset == 'wiki-entities':
            subset_folder = 'wiki_entities'
        else:
            subset_folder = subset

        file_base = babi_dir_name + '/questions/' + subset_folder + '/' + subset + '_qa_{}.txt'
        train_file = os.path.join(workdir, file_base.format('train'))
        test_file = os.path.join(workdir, file_base.format('test'))

        entity_file_path = babi_dir_name + '/knowledge_source/entities.txt'
        entity_file = os.path.join(workdir, entity_file_path)

        # Check for the existence of the entity file
        # If it isn't there then we know we need to fetch everything
        if not os.path.exists(entity_file):
            if license_prompt('WikiMovies',
                              'https://research.fb.com/downloads/babi/',
                              self.path) is False:
                sys.exit(0)

            fetch_file(self.url, self.filename, filepath, self.size)

        knowledge_file_path = babi_dir_name + '/knowledge_source/' + subset_folder + '/' \
            + subset_folder + '_kb.txt'
        kb_file = os.path.join(workdir, knowledge_file_path)

        return entity_file, kb_file, train_file, test_file
Example #13
    def load_data(self, path=".", subset='wiki_entities'):
        """
        Fetch the Facebook WikiMovies dataset and load it to memory.

        Arguments:
            path (str, optional): Local directory in which to cache the raw
                                  dataset.  Defaults to current directory.

        Returns:
            tuple: knowledge base, entity list, training and test files are returned
        """
        self.data_dict = {}
        self.vocab = None
        workdir, filepath = valid_path_append(path, '', self.filename)
        babi_dir_name = self.filename.split('.')[0]

        if subset == 'wiki-entities':
            subset_folder = 'wiki_entities'
        else:
            subset_folder = subset

        file_base = babi_dir_name + '/questions/' + subset_folder + '/' + subset + '_qa_{}.txt'
        train_file = os.path.join(workdir, file_base.format('train'))
        test_file = os.path.join(workdir, file_base.format('test'))

        entity_file_path = babi_dir_name + '/knowledge_source/entities.txt'
        entity_file = os.path.join(workdir, entity_file_path)

        # Check for the existence of the entity file
        # If it isn't there then we know we need to fetch everything
        if not os.path.exists(entity_file):
            if license_prompt('WikiMovies',
                              'https://research.fb.com/downloads/babi/',
                              self.path) is False:
                sys.exit(0)

            fetch_file(self.url, self.filename, filepath, self.size)

        knowledge_file_path = babi_dir_name + '/knowledge_source/' + subset_folder + '/' \
            + subset_folder + '_kb.txt'
        kb_file = os.path.join(workdir, knowledge_file_path)

        return entity_file, kb_file, train_file, test_file
Example #14
    def load_data(self, data_directory=None, manifest_file=None):
        """
        Create a manifest file for the requested dataset. First downloads the
        dataset and extracts it, if necessary.

        Arguments:
            data_directory (str): Path to data directory. Defaults to <path>/<version>
            manifest_file (str): Path to manifest file. Defaults to <data_directory>/manifest.tsv

        Returns:
            Path to manifest file
        """

        if manifest_file is None:
            if self.manifest_file is not None:
                manifest_file = self.manifest_file
            else:
                manifest_file = os.path.join(self.path, "manifest.tsv")

        if os.path.exists(manifest_file):
            return manifest_file

        # Download the file
        workdir, filepath = valid_path_append(self.path, '', self.source_file)
        if not os.path.exists(filepath):
            fetch_file(self.url, self.source_file, filepath)

        # Untar the file
        if data_directory is None:
            data_directory = os.path.join(self.path, self.version)
        if not os.path.exists(data_directory):
            print("Extracting tar file to {}".format(data_directory))
            with contextlib.closing(tarfile.open(filepath)) as tf:
                tf.extractall(data_directory)

        # Ingest the file
        ingest_librispeech(data_directory, manifest_file)

        return manifest_file
Example #15
    def load_data(self):
        """
        Fetch and extract the Facebook bAbI-dialog dataset if not already downloaded.

        Returns:
            tuple: training and test filenames are returned
        """
        if self.task < 5:
            self.candidate_answer_filename = 'dialog-babi-candidates.txt'
            self.kb_filename = 'dialog-babi-kb-all.txt'
            self.cands_mat_filename = 'babi-cands-with-matchtype_{}.npy'
            self.vocab_filename = 'dialog-babi-vocab-task{}'.format(self.task + 1) +\
                                  '_matchtype{}.pkl'.format(self.use_match_type)
        else:
            self.candidate_answer_filename = 'dialog-babi-task6-dstc2-candidates.txt'
            self.kb_filename = 'dialog-babi-task6-dstc2-kb.txt'
            self.cands_mat_filename = 'dstc2-cands-with-matchtype_{}.npy'
            self.vocab_filename = 'dstc2-vocab-task{}_matchtype{}.pkl'.format(
                self.task + 1, self.use_match_type)

        self.vectorized_filename = 'vectorized_task{}.pkl'.format(self.task +
                                                                  1)

        self.data_dict = {}
        self.vocab = None
        self.workdir, filepath = valid_path_append(self.path, '',
                                                   self.filename)
        if not os.path.exists(filepath):
            if license_prompt('bAbI-dialog',
                              'https://research.fb.com/downloads/babi/',
                              self.path) is False:
                sys.exit(0)

            download_unlicensed_file(self.url, self.filename, filepath,
                                     self.size)

        self.babi_dir_name = self.filename.split('.')[0]

        self.candidate_answer_filename = self.babi_dir_name + \
            '/' + self.candidate_answer_filename
        self.kb_filename = self.babi_dir_name + '/' + self.kb_filename
        self.cands_mat_filename = os.path.join(
            self.workdir, self.babi_dir_name + '/' + self.cands_mat_filename)
        self.vocab_filename = self.babi_dir_name + '/' + self.vocab_filename
        self.vectorized_filename = self.babi_dir_name + '/' + self.vectorized_filename

        task_name = self.babi_dir_name + '/' + self.tasks[self.task] + '{}.txt'

        train_file = os.path.join(self.workdir, task_name.format('trn'))
        dev_file = os.path.join(self.workdir, task_name.format('dev'))
        test_file_postfix = 'tst-OOV' if self.oov else 'tst'
        test_file = os.path.join(self.workdir,
                                 task_name.format(test_file_postfix))

        cand_file = os.path.join(self.workdir, self.candidate_answer_filename)
        kb_file = os.path.join(self.workdir, self.kb_filename)
        vocab_file = os.path.join(self.workdir, self.vocab_filename)
        vectorized_file = os.path.join(self.workdir, self.vectorized_filename)

        if (os.path.exists(train_file) is False
                or os.path.exists(dev_file) is False
                or os.path.exists(test_file) is False
                or os.path.exists(cand_file) is False):
            with tarfile.open(filepath, 'r:gz') as f:
                f.extractall(self.workdir)

        return train_file, dev_file, test_file, cand_file, kb_file, vocab_file, vectorized_file
Example #16
    def __init__(self, path='.', subset='wiki-entities', reparse=False,
                 mem_source='kb'):

        self.url = 'http://www.thespermwhale.com/jaseweston/babi'
        self.size = 11745123
        self.filename = 'movieqa.tar.gz'
        self.path = path
        self.reparse = reparse
        data_sub_path = 'movieqa/parsed_data_{}'.format(mem_source)
        data_dict_out_path = valid_path_append(self.path, data_sub_path + '_full_parse.pkl')
        data_dict_train_out_path = valid_path_append(self.path, data_sub_path + '_train.pkl')
        data_dict_test_out_path = valid_path_append(self.path, data_sub_path + '_test.pkl')
        infer_elems_out_path = valid_path_append(self.path, data_sub_path + '_infer_elems.pkl')

        # First try reading from the previously parsed data
        if os.path.exists(data_dict_out_path) and not self.reparse:
            print('Extracting pre-parsed data from ', data_dict_out_path)
            self.data_dict = pickle.load(open(data_dict_out_path, "rb"))
            self.story_length = self.data_dict['info']['story_length']
            self.memory_size = self.data_dict['info']['memory_size']
            self.vocab_size = self.data_dict['info']['vocab_size']

            inference_elems = pickle.load(open(infer_elems_out_path, 'rb'))
            self.full_rev_entity_dict = inference_elems['rev_entity_dict']
            self.full_entity_dict = inference_elems['entity_dict']
            self.knowledge_dict = inference_elems['knowledge_dict']
            self.re_list = inference_elems['regex_list']
            self.word_to_index = inference_elems['word_index']
            self.index_to_word = inference_elems['index_to_word']

        else:
            if not os.path.exists(data_dict_train_out_path) or self.reparse:
                print('Preparing WikiMovies dataset or extracting from %s' % path)
                self.entity_file, self.kb_file, self.train_file, self.test_file = \
                    self.load_data(path, subset=subset)
                print('Creating Entity Dictionary')
                self.full_entity_dict, self.full_rev_entity_dict, self.re_list = \
                    self.create_entity_dict()
                print('Creating knowledge base information')
                if mem_source == 'text':
                    print('Creating knowledge base information from text')
                    self.knowledge_dict = self.parse_text_window()
                else:
                    print('Creating knowledge base information from kb')
                    self.knowledge_dict = self.parse_kb(self.full_rev_entity_dict)

                print('Parsing files')
                self.train_parsed = WIKIMOVIES.parse_wikimovies(self.train_file,
                                                                self.full_rev_entity_dict,
                                                                self.knowledge_dict, self.re_list)
                self.test_parsed = WIKIMOVIES.parse_wikimovies(self.test_file,
                                                               self.full_rev_entity_dict,
                                                               self.knowledge_dict, self.re_list)

                print('Writing to ', data_dict_train_out_path)
                with open(data_dict_train_out_path, 'wb') as f:
                    pickle.dump(self.train_parsed, f)

                with open(data_dict_test_out_path, 'wb') as f:
                    pickle.dump(self.test_parsed, f)

                # Save items needed for inference
                save_elems = {'rev_entity_dict': self.full_rev_entity_dict,
                              'entity_dict': self.full_entity_dict,
                              'knowledge_dict': self.knowledge_dict,
                              'regex_list': self.re_list
                              }
                with open(infer_elems_out_path, 'wb') as f:
                    pickle.dump(save_elems, f)

            else:
                self.data_dict = {}
                self.test_parsed = pickle.load(open(data_dict_test_out_path, 'rb'))
                self.train_parsed = pickle.load(open(data_dict_train_out_path, 'rb'))

            print('Computing Stats')
            self.compute_statistics(self.train_parsed, self.test_parsed)

            print('Vectorizing')
            self.test = self.vectorize_stories(self.test_parsed)
            print('done test')
            self.train = self.vectorize_stories(self.train_parsed)
            print('done train')

            self.story_length = self.story_maxlen
            self.memory_size = self.max_storylen
            self.query_length = self.query_maxlen

            self.data_dict['train'] = {'keys': {'data': self.train[0],
                                                'axes': ('batch', 'memory_axis', 'sentence_axis')},
                                       'values': {'data': self.train[1],
                                                  'axes': ('batch', 'memory_axis', 1)},
                                       'query': {'data': self.train[2],
                                                 'axes': ('batch', 'sentence_axis')},
                                       'answer': {'data': self.train[3],
                                                  'axes': ('batch', 'vocab_axis')}
                                       }

            self.data_dict['test'] = {'keys': {'data': self.test[0],
                                               'axes': ('batch', 'memory_axis', 'sentence_axis')},
                                      'values': {'data': self.test[1],
                                                 'axes': ('batch', 'memory_axis', 1)},
                                      'query': {'data': self.test[2],
                                                'axes': ('batch', 'sentence_axis')},
                                      'answer': {'data': self.test[3],
                                                 'axes': ('batch', 'vocab_axis')}
                                      }

            self.data_dict['info'] = {'story_length': self.story_length,
                                      'memory_size': self.memory_size,
                                      'vocab_size': self.vocab_size
                                      }

            print('Writing to ', data_dict_out_path)
            with open(data_dict_out_path, 'wb') as f:
                pickle.dump(self.data_dict, f)

            # Make sure the index_to_word has an entity UNK for the zero ID
            # This is used during inference
            self.index_to_word[0] = 'UNKNOWN'

            # Also save out the inference elements with the word_to_index
            inference_elems = pickle.load(open(infer_elems_out_path, 'rb'))
            save_elems = {'entity_dict': inference_elems['entity_dict'],
                          'rev_entity_dict': inference_elems['rev_entity_dict'],
                          'knowledge_dict': inference_elems['knowledge_dict'],
                          'regex_list': inference_elems['regex_list'],
                          'word_index': self.word_to_index,
                          'index_to_word': self.index_to_word,
                          }
            with open(infer_elems_out_path, 'wb') as f:
                pickle.dump(save_elems, f)
Example #17
    def __init__(self,
                 path='.',
                 subset='wiki-entities',
                 reparse=False,
                 mem_source='kb'):

        self.url = 'http://www.thespermwhale.com/jaseweston/babi'
        self.size = 11745123
        self.filename = 'movieqa.tar.gz'
        self.path = path
        self.reparse = reparse
        data_sub_path = 'movieqa/parsed_data_{}'.format(mem_source)
        data_dict_out_path = valid_path_append(
            self.path, data_sub_path + '_full_parse.pkl')
        data_dict_train_out_path = valid_path_append(
            self.path, data_sub_path + '_train.pkl')
        data_dict_test_out_path = valid_path_append(
            self.path, data_sub_path + '_test.pkl')
        infer_elems_out_path = valid_path_append(
            self.path, data_sub_path + '_infer_elems.pkl')

        # First try reading from the previously parsed data
        if os.path.exists(data_dict_out_path) and not self.reparse:
            print('Extracting pre-parsed data from ', data_dict_out_path)
            self.data_dict = pickle.load(open(data_dict_out_path, "rb"))
            self.story_length = self.data_dict['info']['story_length']
            self.memory_size = self.data_dict['info']['memory_size']
            self.vocab_size = self.data_dict['info']['vocab_size']

            inference_elems = pickle.load(open(infer_elems_out_path, 'rb'))
            self.full_rev_entity_dict = inference_elems['rev_entity_dict']
            self.full_entity_dict = inference_elems['entity_dict']
            self.knowledge_dict = inference_elems['knowledge_dict']
            self.re_list = inference_elems['regex_list']
            self.word_to_index = inference_elems['word_index']
            self.index_to_word = inference_elems['index_to_word']

        else:
            if not os.path.exists(data_dict_train_out_path) or self.reparse:
                print('Preparing WikiMovies dataset or extracting from %s' %
                      path)
                self.entity_file, self.kb_file, self.train_file, self.test_file = \
                    self.load_data(path, subset=subset)
                print('Creating Entity Dictionary')
                self.full_entity_dict, self.full_rev_entity_dict, self.re_list = \
                    self.create_entity_dict()
                print('Creating knowledge base information')
                if mem_source == 'text':
                    print('Creating knowledge base information from text')
                    self.knowledge_dict = self.parse_text_window()
                else:
                    print('Creating knowledge base information from kb')
                    self.knowledge_dict = self.parse_kb(
                        self.full_rev_entity_dict)

                print('Parsing files')
                self.train_parsed = WIKIMOVIES.parse_wikimovies(
                    self.train_file, self.full_rev_entity_dict,
                    self.knowledge_dict, self.re_list)
                self.test_parsed = WIKIMOVIES.parse_wikimovies(
                    self.test_file, self.full_rev_entity_dict,
                    self.knowledge_dict, self.re_list)

                print('Writing to ', data_dict_train_out_path)
                with open(data_dict_train_out_path, 'wb') as f:
                    pickle.dump(self.train_parsed, f)

                with open(data_dict_test_out_path, 'wb') as f:
                    pickle.dump(self.test_parsed, f)

                # Save items needed for inference
                save_elems = {
                    'rev_entity_dict': self.full_rev_entity_dict,
                    'entity_dict': self.full_entity_dict,
                    'knowledge_dict': self.knowledge_dict,
                    'regex_list': self.re_list
                }
                with open(infer_elems_out_path, 'wb') as f:
                    pickle.dump(save_elems, f)

            else:
                self.data_dict = {}
                self.test_parsed = pickle.load(
                    open(data_dict_test_out_path, 'rb'))
                self.train_parsed = pickle.load(
                    open(data_dict_train_out_path, 'rb'))

            print('Computing Stats')
            self.compute_statistics(self.train_parsed, self.test_parsed)

            print('Vectorizing')
            self.test = self.vectorize_stories(self.test_parsed)
            print('done test')
            self.train = self.vectorize_stories(self.train_parsed)
            print('done train')

            self.story_length = self.story_maxlen
            self.memory_size = self.max_storylen
            self.query_length = self.query_maxlen

            self.data_dict['train'] = {
                'keys': {
                    'data': self.train[0],
                    'axes': ('batch', 'memory_axis', 'sentence_axis')
                },
                'values': {
                    'data': self.train[1],
                    'axes': ('batch', 'memory_axis', 1)
                },
                'query': {
                    'data': self.train[2],
                    'axes': ('batch', 'sentence_axis')
                },
                'answer': {
                    'data': self.train[3],
                    'axes': ('batch', 'vocab_axis')
                }
            }

            self.data_dict['test'] = {
                'keys': {
                    'data': self.test[0],
                    'axes': ('batch', 'memory_axis', 'sentence_axis')
                },
                'values': {
                    'data': self.test[1],
                    'axes': ('batch', 'memory_axis', 1)
                },
                'query': {
                    'data': self.test[2],
                    'axes': ('batch', 'sentence_axis')
                },
                'answer': {
                    'data': self.test[3],
                    'axes': ('batch', 'vocab_axis')
                }
            }

            self.data_dict['info'] = {
                'story_length': self.story_length,
                'memory_size': self.memory_size,
                'vocab_size': self.vocab_size
            }

            print('Writing to ', data_dict_out_path)
            with open(data_dict_out_path, 'wb') as f:
                pickle.dump(self.data_dict, f)

            # Make sure the index_to_word has an entity UNK for the zero ID
            # This is used during inference
            self.index_to_word[0] = 'UNKNOWN'

            # Also save out the inference elements with the word_to_index
            inference_elems = pickle.load(open(infer_elems_out_path, 'rb'))
            save_elems = {
                'entity_dict': inference_elems['entity_dict'],
                'rev_entity_dict': inference_elems['rev_entity_dict'],
                'knowledge_dict': inference_elems['knowledge_dict'],
                'regex_list': inference_elems['regex_list'],
                'word_index': self.word_to_index,
                'index_to_word': self.index_to_word,
            }
            with open(infer_elems_out_path, 'wb') as f:
                pickle.dump(save_elems, f)
Example #18
    def load_data(self):
        """
        Fetch and extract the Facebook bAbI-dialog dataset if not already downloaded.

        Returns:
            tuple: training and test filenames are returned
        """
        if self.task < 5:
            self.candidate_answer_filename = 'dialog-babi-candidates.txt'
            self.kb_filename = 'dialog-babi-kb-all.txt'
            self.cands_mat_filename = 'babi-cands-with-matchtype_{}.npy'
            self.vocab_filename = 'dialog-babi-vocab-task{}.pkl'.format(self.task + 1)
        else:
            self.candidate_answer_filename = 'dialog-babi-task6-dstc2-candidates.txt'
            self.kb_filename = 'dialog-babi-task6-dstc2-kb.txt'
            self.cands_mat_filename = 'dstc2-cands-with-matchtype_{}.npy'
            self.vocab_filename = 'dstc2-vocab-task{}.pkl'.format(self.task + 1)

        self.vectorized_filename = 'vectorized_task{}.pkl'.format(self.task + 1)

        self.data_dict = {}
        self.vocab = None
        self.workdir, filepath = valid_path_append(
            self.path, '', self.filename)
        if not os.path.exists(filepath):
            if license_prompt('bAbI-dialog',
                              'https://research.fb.com/downloads/babi/',
                              self.path) is False:
                sys.exit(0)

            fetch_file(self.url, self.filename, filepath, self.size)

        self.babi_dir_name = self.filename.split('.')[0]

        self.candidate_answer_filename = self.babi_dir_name + \
            '/' + self.candidate_answer_filename
        self.kb_filename = self.babi_dir_name + '/' + self.kb_filename
        self.cands_mat_filename = os.path.join(
            self.workdir, self.babi_dir_name + '/' + self.cands_mat_filename)
        self.vocab_filename = self.babi_dir_name + '/' + self.vocab_filename
        self.vectorized_filename = self.babi_dir_name + '/' + self.vectorized_filename

        task_name = self.babi_dir_name + '/' + self.tasks[self.task] + '{}.txt'

        train_file = os.path.join(self.workdir, task_name.format('trn'))
        dev_file = os.path.join(self.workdir, task_name.format('dev'))
        test_file_postfix = 'tst-OOV' if self.oov else 'tst'
        test_file = os.path.join(
            self.workdir,
            task_name.format(test_file_postfix))

        cand_file = os.path.join(self.workdir, self.candidate_answer_filename)
        kb_file = os.path.join(self.workdir, self.kb_filename)
        vocab_file = os.path.join(self.workdir, self.vocab_filename)
        vectorized_file = os.path.join(self.workdir, self.vectorized_filename)

        if (os.path.exists(train_file) is False
                or os.path.exists(dev_file) is False
                or os.path.exists(test_file) is False
                or os.path.exists(cand_file) is False):
            with tarfile.open(filepath, 'r:gz') as f:
                f.extractall(self.workdir)

        return train_file, dev_file, test_file, cand_file, kb_file, vocab_file, vectorized_file
Example #19
    def parse_kb(self, reverse_dictionary):
        # Check data repo to see if this exists, if not then run code and save
        workdir, kb_file_path = valid_path_append(self.path, '',
                                                  'movieqa/kb_dictionary.pkl')
        if os.path.exists(kb_file_path) and not self.reparse:
            print('Loading files from path')
            print(kb_file_path)
            knowledge_dict = pickle.load(open(kb_file_path, "rb"))
            return knowledge_dict

        action_words = [
            'directed_by', 'written_by', 'starred_actors', 'release_year',
            'in_language', 'has_tags', 'has_plot', 'has_imdb_votes',
            'has_imdb_rating'
        ]
        rev_actions_pre = 'REV_'

        with open(self.kb_file) as file:
            babi_data = file.read()

        lines = self.data_to_list(babi_data)

        knowledge_dict = defaultdict(list)
        fact = None
        for line in lines:
            if len(line) > 1:
                nid, line = line.lower().split(' ', 1)

                if int(nid) == 1:
                    fact = None

                for a in action_words:
                    # find the action word and split on that, ignoring the has_plot info for now
                    if len(line.split(a)) > 1 and a != 'has_plot':
                        # there can be more than one entity on the left hand side
                        # (particularly with starred_actors)
                        entities = line.split(a)
                        # Let's use the info that we know the fact for related stories
                        if not fact:
                            fact = ex_entity_names(entities[0],
                                                   reverse_dictionary,
                                                   self.re_list)
                        subject_entities = [
                            ex_entity_names(e_0, reverse_dictionary,
                                            self.re_list)
                            for e_0 in entities[1].split(', ')
                        ]

                        # create a hash for the knowledge base where key is the subject
                        # add the fact and reverse fact
                        for subject in subject_entities:
                            knowledge_dict[fact].append(
                                (fact + ' ' + a, subject))
                            # Also add reverse here
                            knowledge_dict[subject].append(
                                (subject + ' ' + rev_actions_pre + a, fact))

        kb_out_path = ensure_dirs_exist(os.path.join(workdir, kb_file_path))

        print('Writing to ', kb_out_path)
        with open(kb_out_path, 'wb') as f:
            pickle.dump(knowledge_dict, f)

        return knowledge_dict