Example 1
 def predict_level(self, level, test_x, k, labels_num):
     data_cnf, model_cnf = self.data_cnf, self.model_cnf
     model = self.models.get(level, None)
     if level == 0:
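         # Level 0: a plain AttentionRNN scores the coarsest label groups directly.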
         logger.info(F'Predicting Level-{level}, Top: {k}')
         if model is None:
             model = Model(AttentionRNN, labels_num=labels_num, model_path=F'{self.model_path}-Level-{level}',
                           emb_init=self.emb_init, **data_cnf['model'], **model_cnf['model'])
         test_loader = DataLoader(MultiLabelDataset(test_x), model_cnf['predict']['batch_size'],
                                  num_workers=4)
         return model.predict(test_loader, k=k)
     else:
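         # Higher levels: predict the parent level first, then let this level's
         # XMLModel score only the candidate groups returned by that prediction.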
         if level == self.level - 1:
             groups = np.load(F'{self.groups_path}-Level-{level-1}.npy')
         else:
             groups = self.get_inter_groups(labels_num)
         group_scores, group_labels = self.predict_level(level - 1, test_x, self.top, len(groups))
         torch.cuda.empty_cache()
         logger.info(F'Predicting Level-{level}, Top: {k}')
         if model is None:
             model = XMLModel(network=FastAttentionRNN, labels_num=labels_num,
                              model_path=F'{self.model_path}-Level-{level}',
                              emb_init=self.emb_init, **data_cnf['model'], **model_cnf['model'])
         test_loader = DataLoader(XMLDataset(test_x, labels_num=labels_num,
                                             groups=groups, group_labels=group_labels, group_scores=group_scores),
                                  model_cnf['predict']['batch_size'], num_workers=4)
         return model.predict(test_loader, k=k)
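Prediction is refined level by level: calling this method for the last level recurses back to level 0, and each level's group scores and labels become the candidate set for the next, finer level. A minimal sketch of how the recursion might be started at inference time, assuming a trained instance of this class named tree (the name tree and the value of k are assumptions, not taken from the example above):

    # Hypothetical driver: score the top 100 labels for every test document.
    scores, labels = tree.predict_level(tree.level - 1, test_x, k=100,
                                        labels_num=tree.labels_num)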
Example 2
    def train_level(self, level, train_x, train_y, valid_x, valid_y):
        model_cnf, data_cnf = self.model_cnf, self.data_cnf
        if level == 0:
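            # Level 0: poll until the label-group assignments for this level have been
            # written by a separate process, then train an AttentionRNN on the
            # group-level targets.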
            while not os.path.exists(F'{self.groups_path}-Level-{level}.npy'):
                time.sleep(30)
            groups = np.load(F'{self.groups_path}-Level-{level}.npy')
            train_y, valid_y = self.get_mapping_y(groups, self.labels_num,
                                                  train_y, valid_y)
            labels_num = len(groups)
            train_loader = DataLoader(MultiLabelDataset(train_x, train_y),
                                      model_cnf['train'][level]['batch_size'],
                                      num_workers=4,
                                      shuffle=True)
            valid_loader = DataLoader(MultiLabelDataset(valid_x,
                                                        valid_y,
                                                        training=False),
                                      model_cnf['valid']['batch_size'],
                                      num_workers=4)
            model = Model(AttentionRNN,
                          labels_num=labels_num,
                          model_path=F'{self.model_path}-Level-{level}',
                          emb_init=self.emb_init,
                          **data_cnf['model'],
                          **model_cnf['model'])
            if not os.path.exists(model.model_path):
                logger.info(
                    F'Training Level-{level}, Number of Labels: {labels_num}')
                model.train(train_loader, valid_loader,
                            **model_cnf['train'][level])
                model.optimizer = None
                logger.info(F'Finish Training Level-{level}')

            self.models[level] = model

            logger.info(F'Generating Candidates for Level-{level+1}, '
                        F'Number of Labels: {labels_num}, Top: {self.top}')
            train_loader = DataLoader(MultiLabelDataset(train_x),
                                      model_cnf['valid']['batch_size'],
                                      num_workers=4)
            return train_y, model.predict(
                train_loader, k=self.top), model.predict(valid_loader,
                                                         k=self.top)
        else:
            train_group_y, train_group, valid_group = self.train_level(
                level - 1, train_x, train_y, valid_x, valid_y)
            torch.cuda.empty_cache()

            logger.info('Getting Candidates')
            _, group_labels = train_group
            group_candidates = np.empty((len(train_x), self.top), dtype=np.int64)
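            # Build up to self.top candidate parent groups per sample: keep all true
            # (positive) groups when they fit, padded with the highest-ranked predicted
            # groups; otherwise prefer the positives that were also predicted.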
            for i, labels in tqdm(enumerate(group_labels),
                                  leave=False,
                                  desc='Parents'):
                ys, ye = train_group_y.indptr[i], train_group_y.indptr[i + 1]
                positive = set(train_group_y.indices[ys:ye])
                if self.top >= len(positive):
                    candidates = positive
                    for la in labels:
                        if len(candidates) == self.top:
                            break
                        if la not in candidates:
                            candidates.add(la)
                else:
                    candidates = set()
                    for la in labels:
                        if la in positive:
                            candidates.add(la)
                        if len(candidates) == self.top:
                            break
                    if len(candidates) < self.top:
                        candidates = (list(candidates) +
                                      list(positive - candidates))[:self.top]
                group_candidates[i] = np.asarray(list(candidates))

            if level < self.level - 1:
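                # Intermediate level: wait for this level's group assignments and map
                # the raw labels onto group-level targets before training.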
                while not os.path.exists(
                        F'{self.groups_path}-Level-{level}.npy'):
                    time.sleep(30)
                groups = np.load(F'{self.groups_path}-Level-{level}.npy')
                train_y, valid_y = self.get_mapping_y(groups, self.labels_num,
                                                      train_y, valid_y)
                labels_num, last_groups = len(groups), self.get_inter_groups(
                    len(groups))
            else:
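                # Last level: train directly on the original label set; the parent
                # groups are the ones saved for the previous level.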
                groups, labels_num = None, train_y.shape[1]
                last_groups = np.load(
                    F'{self.groups_path}-Level-{level-1}.npy')

            train_loader = DataLoader(XMLDataset(
                train_x,
                train_y,
                labels_num=labels_num,
                groups=last_groups,
                group_labels=group_candidates),
                                      model_cnf['train'][level]['batch_size'],
                                      num_workers=4,
                                      shuffle=True)
            group_scores, group_labels = valid_group
            valid_loader = DataLoader(XMLDataset(valid_x,
                                                 valid_y,
                                                 training=False,
                                                 labels_num=labels_num,
                                                 groups=last_groups,
                                                 group_labels=group_labels,
                                                 group_scores=group_scores),
                                      model_cnf['valid']['batch_size'],
                                      num_workers=4)
            model = XMLModel(network=FastAttentionRNN,
                             labels_num=labels_num,
                             emb_init=self.emb_init,
                             model_path=F'{self.model_path}-Level-{level}',
                             **data_cnf['model'],
                             **model_cnf['model'])
            if not os.path.exists(model.model_path):
                logger.info(
                    F'Loading parameters of Level-{level} from Level-{level-1}'
                )
                last_model = self.get_last_models(level - 1)
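                # Warm-start the embedding, LSTM and shared linear layers from the
                # previous level's trained network.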
                model.network.module.emb.load_state_dict(
                    last_model.module.emb.state_dict())
                model.network.module.lstm.load_state_dict(
                    last_model.module.lstm.state_dict())
                model.network.module.linear.load_state_dict(
                    last_model.module.linear.state_dict())
                logger.info(
                    F'Training Level-{level}, '
                    F'Number of Labels: {labels_num}, '
                    F'Candidates Number: {train_loader.dataset.candidates_num}'
                )
                model.train(train_loader, valid_loader,
                            **model_cnf['train'][level])
                model.optimizer = model.state = None
                logger.info(F'Finish Training Level-{level}')
            self.models[level] = model
            if level == self.level - 1:
                return
            logger.info(F'Generating Candidates for Level-{level+1}, '
                        F'Number of Labels: {labels_num}, Top: {self.top}')
            group_scores, group_labels = train_group
            train_loader = DataLoader(XMLDataset(train_x,
                                                 labels_num=labels_num,
                                                 groups=last_groups,
                                                 group_labels=group_labels,
                                                 group_scores=group_scores),
                                      model_cnf['valid']['batch_size'],
                                      num_workers=4)
            return train_y, model.predict(
                train_loader, k=self.top), model.predict(valid_loader,
                                                         k=self.top)
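Training follows the same recursion: invoking the method for the last level first trains level 0, then each intermediate level, handing the candidate groups predicted by each trained level to the next one. A minimal sketch of the top-level call, again assuming an instance named tree and already-vectorized train/validation data (names outside the example are assumptions):

    # Hypothetical driver: trains every level in turn, from level 0 up to the last.
    tree.train_level(tree.level - 1, train_x, train_y, valid_x, valid_y)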