# Example #1
    def _create(self, data_path, P):
        """Create the on-disk database and gather its basic statistics.

        Scans the main/user/item input files to determine the matrix
        dimensions, the non-zero count and the fixed id-column widths,
        then builds the database via ``self._create_database`` and
        fills in the user/item id maps.

        :param data_path: destination path for the database file.
        :param P: dict holding ``uid_path``, ``iid_path`` and
            ``main_path`` entries (falsy when a file is absent).
        :return: tuple ``(db, itemids)`` where ``itemids`` maps raw
            item token to its 1-based column index.
        """
        def get_max_column_length(fname):
            # Longest raw line (newline included); used to size the
            # fixed-width string columns of the id map.
            with open(fname) as fin:
                max_col = 0
                for l in fin:
                    max_col = max(max_col, len(l))
            return max_col

        uid_path, iid_path, main_path = (
            P['uid_path'], P['iid_path'], P['main_path'])
        # One user per line of the uid file when given, otherwise one
        # user per line of the main data file.
        with open(uid_path if uid_path else main_path) as fin:
            num_users = sum(1 for _ in fin)

        uid_max_col = len(str(num_users)) + 1
        if uid_path:
            uid_max_col = get_max_column_length(uid_path) + 1

        vali_n = self.opt.data.validation.get('n', 0)
        num_nnz, vali_limit, itemids = 0, 0, set()
        self.logger.info(f'gathering itemids from {main_path}...')
        # Per-line hold-out only applies to the oldest/newest schemes;
        # for any other validation method nothing is reserved here.
        if self.opt.data.validation.name not in ["oldest", "newest"]:
            vali_n = 0
        with open(main_path) as fin:
            for line in log.iter_pbar(log_level=log.DEBUG, iterable=fin):
                data = line.strip().split()
                if not iid_path:
                    itemids |= set(data)

                data_size = len(data)
                # Clamp at 0: on an empty line min(vali_n, -1) == -1,
                # which used to corrupt both vali_limit and num_nnz.
                _vali_size = max(0, min(vali_n, data_size - 1))
                vali_limit += _vali_size
                if self.opt.data.internal_data_type == 'stream':
                    num_nnz += (data_size - _vali_size)
                elif self.opt.data.internal_data_type == 'matrix':
                    num_nnz += len(set(data[:(data_size - _vali_size)]))
        if iid_path:
            with open(iid_path) as fin:
                itemids = {iid.strip(): idx + 1 for idx, iid in enumerate(fin)}
        else:  # in case of item information is not given
            itemids = {i: idx + 1 for idx, i in enumerate(itemids)}
        # default=1 keeps an empty item set from raising ValueError.
        iid_max_col = max((len(k) + 1 for k in itemids.keys()), default=1)
        num_items = len(itemids)

        self.logger.info('Found %d unique itemids' % len(itemids))

        try:
            db = self._create_database(data_path,
                                       num_users=num_users,
                                       num_items=num_items,
                                       num_nnz=num_nnz,
                                       uid_max_col=uid_max_col,
                                       iid_max_col=iid_max_col,
                                       num_validation_samples=vali_limit)
            idmap = db['idmap']
            # if not given, assume id as is
            if uid_path:
                with open(uid_path) as fin:
                    idmap['rows'][:] = np.loadtxt(fin, dtype=f'S{uid_max_col}')
            else:
                idmap['rows'][:] = np.array(
                    [str(i) for i in range(1, num_users + 1)],
                    dtype=f'S{uid_max_col}')
            if iid_path:
                with open(iid_path) as fin:
                    idmap['cols'][:] = np.loadtxt(fin, dtype=f'S{iid_max_col}')
            else:
                # Recover the raw tokens in column order (1-based index).
                cols = sorted(itemids.items(), key=lambda x: x[1])
                cols = [k for k, _ in cols]
                idmap['cols'][:] = np.array(cols, dtype=f'S{iid_max_col}')
        except Exception as e:
            self.logger.error('Cannot create db: %s' % (str(e)))
            self.logger.error(traceback.format_exc())
            raise
        return db, itemids
# Example #2
    def _create_working_data(self,
                             db,
                             stream_main_path,
                             itemids,
                             with_sppmi=False,
                             windows=5):
        """Convert the raw stream file into temporary triplet files.

        Splits each user's item sequence into train/validation parts
        according to the validation method recorded in ``db``, writes
        ``user item value`` triplets to a temporary file and, when
        requested, also emits word/context co-occurrence pairs for
        SPPMI computation.

        :param db: database handle; the optional ``db['vali']`` group
            carries the validation method and its parameters.
        :param stream_main_path: path of the raw input stream file.
        :param itemids: mapping from raw item token to 1-based index.
        :param with_sppmi: also produce the SPPMI pair file.
        :param windows: forward co-occurrence window size for SPPMI.
        :return: ``(main_path, vali_lines, sppmi_path_or_None)``
        """
        vali_method = None if 'vali' not in db else db['vali'].attrs['method']
        vali_indexes, vali_n = set(), 0
        if vali_method == 'sample':
            vali_indexes = set(db['vali']['indexes'])
        elif vali_method in ['newest']:
            vali_n = db['vali'].attrs['n']
        vali_lines = []

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", ResourceWarning)
            w_sppmi = None
            if with_sppmi:
                w_sppmi = open(
                    aux.get_temporary_file(root=self.opt.data.tmp_dir), "w")
            file_path = aux.get_temporary_file(root=self.opt.data.tmp_dir)
            try:
                with open(stream_main_path) as fin,\
                    open(file_path, 'w') as w:
                    total_index = 0
                    internal_data_type = self.opt.data.internal_data_type
                    for line_idx, data in log.iter_pbar(
                            log_level=log.DEBUG, iterable=enumerate(fin)):
                        data = data.strip().split()
                        user = line_idx + 1
                        vali_data, train_data = [], []
                        if vali_method in ['newest']:
                            # Hold out up to vali_n trailing items while
                            # always keeping at least one for training.
                            vali_data_size = min(vali_n, len(data) - 1)
                            train_data_size = len(data) - vali_data_size
                            vali = data[train_data_size:]
                            data = data[:train_data_size]
                            for col in Counter(vali):
                                vali_data.append(itemids[col])
                        # 'sample' validation marks global positions; route
                        # each remaining item accordingly.  (The old code had
                        # byte-identical branches for stream and matrix.)
                        for idx, col in enumerate(data):
                            col = itemids[col]
                            if (idx + total_index) in vali_indexes:
                                vali_data.append(col)
                            else:
                                train_data.append(col)
                        total_index += len(data)
                        if internal_data_type == 'stream':
                            for col in train_data:
                                w.write(f'{user} {col} 1\n')
                            # BUGFIX: previously wrote the stale loop
                            # variable `val` (undefined under 'sample'
                            # validation); stream triplets carry value 1.
                            for col in vali_data:
                                vali_lines.append(f'{user} {col} 1')
                        else:
                            for col, val in Counter(train_data).items():
                                w.write(f'{user} {col} {val}\n')
                            for col, val in Counter(vali_data).items():
                                vali_lines.append(f'{user} {col} {val}')
                        if with_sppmi:
                            # Emit each pair within the forward window in
                            # both (w, c) and (c, w) orders.
                            sz = len(train_data)
                            for i in range(sz):
                                for j in range(i + 1, min(sz, i + windows + 1)):
                                    _w, _c = train_data[i], train_data[j]
                                    w_sppmi.write(f'{_w} {_c}\n')
                                    w_sppmi.write(f'{_c} {_w}\n')
                    if with_sppmi:
                        return w.name, vali_lines, w_sppmi.name
                    return w.name, vali_lines, None
            finally:
                # Close the sppmi handle even on error (the old code
                # leaked it when an exception was raised mid-loop).
                if w_sppmi is not None:
                    w_sppmi.close()