import logging
import math
import random
import sys
import traceback
from concurrent.futures import ThreadPoolExecutor
from queue import Queue

import numpy as np
import torch
from torch.utils.data import IterableDataset

# Project-local streaming sampler; the import path below is an assumption,
# adjust it to wherever StreamSamplerTrain lives in your repo.
from streaming import StreamSamplerTrain

class DataLoaderTrain(IterableDataset):
    '''
    DataLoader used for training, built on a producer-consumer architecture.
    - Dynamic batching
    - Generates batches for two-stage encoding
    '''
    def __init__(self,
                 args,
                 data_files,
                 news_idx_incache,
                 prefetch_step,
                 prefetch_step2,
                 end,
                 local_rank,
                 world_size,
                 news_features,
                 enable_prefetch=True,
                 enable_prefetch_stream=False,
                 global_step=0):
        '''
        Args:
            args: parameters
            data_files(shared list): the paths of the train data, stored in a shared list
            news_idx_incache(shared dict): {news_id: (index in cache, encoded step)}
            prefetch_step(shared list): syncs the dataloaders across ranks
            prefetch_step2(shared list): avoids skipping the data of the last step
            end(shared bool value): if True, stop all data processes
            local_rank(int): the rank of the current process
            world_size(int): the number of processes
            news_features(dict): {news_id: (segments_ids, segments_mask, key_position, key_frequency, elements)}
        '''
        self.args = args
        self.beta_for_cache = args.beta_for_cache
        self.data_files = data_files
        self.news_idx_incache = news_idx_incache
        self.prefetch_step = prefetch_step
        self.prefetch_step2 = prefetch_step2
        self.end = end
        self.local_rank = local_rank
        self.world_size = world_size
        self.news_features = news_features
        self.enable_prefetch = enable_prefetch
        self.enable_prefetch_stream = enable_prefetch_stream
        self.global_step = global_step

    def __iter__(self):
        """Implement IterableDataset method to provide data iterator."""
        if self.enable_prefetch:
            self.start_async()
        else:
            self.outputs = iter(self.dynamic_batch())
        return self

    def start_async(self):
        logging.info('start async...')
        self.aval_count = 0
        self.end.value = False
        self.prefetch_step[self.local_rank] = 0
        self.outputs = Queue(10)  # bounded queue: the producer blocks once 10 batches are waiting
        self.pool = ThreadPoolExecutor(1)
        self.pool.submit(self._produce)

    def __next__(self):
        if self.enable_prefetch:
            if self.end.value and self.aval_count == 0:
                raise StopIteration
            next_batch = self.outputs.get()
            self.outputs.task_done()
            self.aval_count -= 1
            return next_batch
        else:
            next_data = next(self.outputs)
            self.sync_prefetch_step(self.prefetch_step)
            if self.end.value:
                raise StopIteration
            return next_data

    def _produce(self):
        try:
            for address_cache, update_cache, batch in self.dynamic_batch():
                self.sync_prefetch_step(self.prefetch_step)
                if self.end.value:
                    break
                self.sync_prefetch_step(self.prefetch_step2)  # avoid discarding the data of the last batch

                self.outputs.put((address_cache, update_cache, batch))
                self.aval_count += 1
            self.pool.shutdown(wait=False)
        except Exception:
            error_type, error_value, error_trace = sys.exc_info()
            traceback.print_tb(error_trace)
            logging.info(error_value)
            self.pool.shutdown(wait=False)
            raise

    def dynamic_batch(self):
        '''
        Each training example is routed to a bucket based on its max sequence length.
        Before each insertion the target bucket is checked; once a bucket is full,
        its accumulated examples are emitted as a mini-batch.
        '''
        # Buffer Buckets
        logging.info('init bucket')
        blocks = [[] for x in range(self.args.bucket_num)]   # (history_ids, neg_ids) per user
        block_encode_set = [set() for x in range(self.args.bucket_num)]
        block_cache_set = [set() for x in range(self.args.bucket_num)]
        block_max_length = [0 for x in range(self.args.bucket_num)]
        block_space = [(self.args.seg_length // self.args.bucket_num) * i for i in range(self.args.bucket_num)]
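        # For illustration (values assumed): with seg_length = 128 and bucket_num = 4,
        # block_space = [0, 32, 64, 96], so a user whose longest un-cached news spans
        # 40 tokens falls into the bucket covering lengths (32, 64].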

        self.use_cache = False

        if self.args.enable_gpu:
            torch.cuda.set_device(self.local_rank)

        self.sampler = StreamSamplerTrain(data_files=self.data_files)
        if self.enable_prefetch_stream:
            self.sampler_batch = self.sampler
        else:
            self.sampler_batch = self.sampler._generate_batch()

        for one_user in self.sampler_batch:
            news_set, history, negs = self._process(one_user)
            cache_set, encode_set = self.split_news_set(news_set, self.use_cache)
            max_len = 0
            if len(encode_set) > 0:
                max_len = max([len(self.news_features[nid][0][0]) if nid in self.news_features else 0 for nid in encode_set])

            # Scan buckets from longest to shortest and place the user in the first
            # one whose lower bound max_len exceeds.
            for i in range(self.args.bucket_num - 1, -1, -1):
                if max_len > block_space[i]:
                    # Flush the bucket first if adding this user's news would push the
                    # estimated encoder input (segment length incl. bus tokens x unique
                    # news x segments per news) past batch_size.
                    projected = (max(block_max_length[i], max_len) + self.args.bus_num) \
                        * len(block_encode_set[i] | encode_set) * self.args.seg_num
                    if projected > self.args.batch_size:
                        if len(block_encode_set[i]) == 0:
                            break
                        address_cache, update_cache, batch = self.gen_batch_for_two_stage(
                            block_encode_set[i], block_cache_set[i], blocks[i],
                            block_max_length[i], self.global_step)
                        self.global_step += 1

                        yield address_cache, update_cache, batch

                        # Reset the flushed bucket and re-sample the cache-usage policy.
                        block_encode_set[i] = set()
                        block_cache_set[i] = set()
                        blocks[i] = []
                        block_max_length[i] = 0
                        self.update_use_cache()

                    block_max_length[i] = max(block_max_length[i], max_len)
                    block_encode_set[i] = block_encode_set[i] | encode_set
                    block_cache_set[i] = block_cache_set[i] | cache_set
                    blocks[i].append((history, negs))
                    break
        self.end.value = True

    def sync_prefetch_step(self, prefetch_step):
        # Spin-wait barrier: advance this rank's counter, then wait until every rank
        # has reached the same step (or training has been signalled to stop).
        prefetch_step[self.local_rank] += 1
        while sum(prefetch_step) != prefetch_step[self.local_rank] * self.world_size:
            if self.end.value:
                break

    def drop_encoder_prob(self, step):
        return 1 - math.exp(-step*self.beta_for_cache)
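
    # Worked example (beta assumed for illustration): with beta_for_cache = 1e-3,
    # drop_encoder_prob(0) = 0.0, drop_encoder_prob(693) ≈ 1 - e^(-0.693) ≈ 0.50,
    # and drop_encoder_prob(3000) ≈ 0.95, so the chance of reusing cached embeddings
    # (self.use_cache below) grows toward 1 as training proceeds.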

    def update_use_cache(self):
        if random.random() < self.drop_encoder_prob(self.global_step):
            self.use_cache = True
        else:
            self.use_cache = False

    def split_news_set(self, news_set, use_cache):
        '''
        For each news article, the dataloader checks the cache first:
        if a copy of the news embedding is in the cache, it outputs the index of the news in the cache;
        otherwise, it outputs the features of this news as inputs to the news encoder.
        '''
        if use_cache:
            cache_set = set()
            encode_set = set()
            for n in news_set:
                if n == 'MISS':
                    continue
                if self.global_step - self.news_idx_incache[n][1] <= self.args.max_step_in_cache:
                    cache_set.add(n)
                else:
                    encode_set.add(n)

            return cache_set,encode_set
        else:
            news_set.discard('MISS')
            return set(), news_set
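
    # E.g. with max_step_in_cache = 20 (an assumed setting): a news embedding refreshed
    # at step 100 is served from the cache until global_step exceeds 120, after which
    # the article is routed back to the news encoder.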


    def gen_batch_for_two_stage(self, encode_set, cache_set, data, max_len, global_step):
        '''
        Once a mini-batch is assembled, gather all of the news articles from its users:
        news in cache_set are addressed by their cache index, while news in encode_set
        are packed as features for the news encoder.
        '''
        news_index = {'MISS': 0}
        idx = 1

        if len(cache_set)==0:
            address_cache = None
        else:
            address_cache = []
            for n in cache_set:
                address_cache.append(self.news_idx_incache[n][0])
                news_index[n] = idx
                idx += 1
            address_cache = np.array(address_cache)

        update_cache = []
        segments = []
        token_masks = []
        seg_masks = []
        key_position = []
        fre_ids = []
        elements = []
        for n in encode_set:
            news_index[n] = idx
            idx += 1

            tokens,s_mask,positions,fre_cnt, elem = self.news_features[n]
            for i in range(self.args.seg_num):
                text,mask = self.pad_to_fix_len(tokens[i],max_len,padding_front=False)
                segments.append(text)
                token_masks.append(mask)
                if self.args.content_refinement:
                    position,_ = self.pad_to_fix_len(positions[i],max_len,padding_front=False)
                    key_position.append(position)
                    fre,_ = self.pad_to_fix_len(fre_cnt[i],max_len,padding_front=False)
                    fre = [min(x,self.args.max_keyword_freq-1) for x in fre]
                    fre_ids.append(fre)

            seg_masks.append(s_mask[:self.args.seg_num])
            elements.append(elem)

            # update cache
            update_cache.append(self.news_idx_incache[n][0])
            self.news_idx_incache[n] = [self.news_idx_incache[n][0],global_step]

        batch_hist = []
        batch_negs = []
        batch_mask = []
        max_hist_len = max([len(x[0]) for x in data])
        for history,negs in data:
            history = self.trans_to_nindex(history,news_index)
            history,mask = self.pad_to_fix_len(history,max_hist_len)
            batch_hist.append(history)
            batch_mask.append(mask)

            temp_negs = [self.trans_to_nindex(n,news_index) for n in negs]
            temp_negs = self.pad_to_fix_len_neg(temp_negs,max_hist_len-1)
            batch_negs.append(temp_negs)

        # Build CPU tensors once, then (optionally) move everything to the GPU.
        segments = torch.LongTensor(segments)
        token_masks = torch.FloatTensor(token_masks)
        seg_masks = torch.FloatTensor(seg_masks)
        if self.args.content_refinement:
            key_position = torch.LongTensor(key_position)
            fre_ids = torch.LongTensor(fre_ids)
        else:
            key_position, fre_ids = None, None
        elements = torch.LongTensor(elements)
        batch_hist = torch.LongTensor(batch_hist)
        batch_negs = torch.LongTensor(batch_negs)
        batch_mask = torch.FloatTensor(batch_mask)
        if self.args.enable_gpu:
            segments, token_masks, seg_masks = segments.cuda(), token_masks.cuda(), seg_masks.cuda()
            if self.args.content_refinement:
                key_position, fre_ids = key_position.cuda(), fre_ids.cuda()
            elements = elements.cuda()
            batch_hist, batch_negs, batch_mask = batch_hist.cuda(), batch_negs.cuda(), batch_mask.cuda()

        return address_cache, np.array(update_cache), (segments, token_masks, seg_masks, key_position,
                                                       fre_ids, elements, batch_hist, batch_mask, batch_negs)

    def trans_to_nindex(self, nids, news_index):
        return [news_index[i] for i in nids]

    def pad_to_fix_len(self, x, fix_length, padding_value=0, padding_front=True):
        if padding_front:
            pad_x = x[-fix_length:] + [padding_value] * (fix_length - len(x))
            mask = [1] * min(fix_length, len(x)) + [0] * (fix_length - len(x))
        else:
            pad_x = x[:fix_length] + [padding_value] * (fix_length - len(x))
            mask = [1] * min(fix_length, len(x)) + [0] * (fix_length-len(x))
        return pad_x,mask
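
    # Examples: pad_to_fix_len([5, 6, 7], 5) -> ([5, 6, 7, 0, 0], [1, 1, 1, 0, 0]).
    # For inputs longer than fix_length, padding_front=True keeps the last fix_length
    # items: pad_to_fix_len([1, 2, 3, 4], 3) -> ([2, 3, 4], [1, 1, 1]);
    # padding_front=False keeps the first fix_length items instead.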

    def pad_to_fix_len_neg(self, x, fix_length, padding_value=0,padding_front=True):
        if padding_front:
            pad_x = x[-fix_length:] + [[padding_value] * self.args.npratio] * (fix_length - len(x))
        else:
            pad_x = x[:fix_length] + [[padding_value] * self.args.npratio] * (fix_length - len(x))
        return pad_x
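
    # Example: with npratio = 2 (assumed), pad_to_fix_len_neg([[3, 4]], 3)
    # -> [[3, 4], [0, 0], [0, 0]]; each padding slot is a full row of npratio zeros.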

    def _process(self, line):
        '''
        Parse one raw training line of the form "uid\tsession|session|...", where each
        session is "pos1;pos2&neg1;neg2;...".
        Returns (set of all news ids, clicked history, negatives per click).
        '''
        clicked = []
        negnews = []
        u_set = []

        uid, sessions = line.strip().split('\t')
        for sess in sessions.split('|'):
            pos, neg = sess.split('&')
            pos = [p if p in self.news_features else 'MISS' for p in pos.split(';')]
            clicked.extend(pos)

            neg = neg.split(';')
            for p in pos:
                if len(neg) < self.args.npratio:
                    # oversample when the session has fewer negatives than npratio
                    neg = neg * (int(self.args.npratio / len(neg)) + 1)
                sample_neg = [n if n in self.news_features else 'MISS' for n in random.sample(neg, self.args.npratio)]
                negnews.append(sample_neg)

        clicked = clicked[-self.args.user_log_length:]
        negnews = negnews[-(self.args.user_log_length-1):]

        for p in clicked:
            u_set.append(p)
        for ns in negnews:
            u_set.extend(ns)

        u_set = set(u_set)
        return u_set,clicked,negnews
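

def build_train_loader_example(args, train_file_paths, news_features, local_rank, world_size):
    '''
    A minimal usage sketch (hypothetical helper, not part of the original module):
    wires up the shared state DataLoaderTrain expects. The shared objects come from
    a multiprocessing.Manager so that every rank sees the same prefetch counters,
    cache index, and stop flag.
    '''
    from multiprocessing import Manager

    manager = Manager()
    return DataLoaderTrain(
        args=args,
        data_files=manager.list(train_file_paths),     # shared list of shard paths
        news_idx_incache=manager.dict(),               # news_id -> [cache index, encoded step]
        prefetch_step=manager.list([0] * world_size),  # one prefetch counter per rank
        prefetch_step2=manager.list([0] * world_size),
        end=manager.Value('b', False),                 # shared stop flag
        local_rank=local_rank,
        world_size=world_size,
        news_features=news_features,
    )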