class OLTR_For_Textcnn(nn.Module):
    """OLTR (Open Long-Tailed Recognition) head on top of a frozen, pretrained TextCNN.

    The TextCNN backbone is loaded from ``pretrained_model_path`` and frozen;
    only the OLTR classifier and the per-class centroids are trainable.
    """

    def __init__(self,
                 pretrained_model_path,
                 vocabulary_size,
                 filter_sizes,
                 filter_num,
                 data=None,
                 train=False,
                 cuda=1):
        """
        Args:
            pretrained_model_path: path to a saved TextCNN ``state_dict``.
            vocabulary_size: embedding vocabulary size for the TextCNN.
            filter_sizes: comma-separated kernel sizes, e.g. ``"3,4,5"``.
            filter_num: number of filters per kernel size.
            data: training iterator used to initialize centroids
                (required when ``train`` is True).
            train: if True, compute initial centroids from ``data``.
            cuda: CUDA device index used when CUDA is available.

        Raises:
            ValueError: if ``train`` is True but ``data`` is None.
        """
        super(OLTR_For_Textcnn, self).__init__()
        self.device = torch.device(
            'cuda:%d' % cuda if torch.cuda.is_available() else 'cpu')
        # Single source of truth for the class count (was hard-coded twice:
        # once in the TextCNN constructor and once as self.classes_num).
        self.classes_num = 183
        self.feature_dim = len(filter_sizes.split(",")) * filter_num
        self.textcnn = TextCNN(vocabulary_size=vocabulary_size,
                               class_num=self.classes_num,
                               filter_num=filter_num,
                               filter_sizes=filter_sizes,
                               embedding_dim=128)
        checkpoint = torch.load(pretrained_model_path,
                                map_location=self.device)
        self.textcnn.load_state_dict(checkpoint)
        self.textcnn = self.textcnn.to(self.device)
        # Freeze the backbone: OLTR training only updates classifier/centroids.
        for param in self.textcnn.parameters():
            param.requires_grad = False
        self.textcnn.eval()

        self.classifier = OLTR_classifier(self.feature_dim, self.classes_num)

        self.centroids = nn.Parameter(
            torch.randn(self.classes_num, self.feature_dim))
        if train and data is not None:
            print("update centroid with data")
            self.centroids.data = self.centroids_cal(data)
        elif train and data is None:
            raise ValueError("Train mode should update centroid with data")
        else:
            print("Test mode should load pretrained centroid")

    def forward(self, x, *args):
        """Return ``(logits, feature)`` for a batch of token ids ``x``."""
        feature = self.textcnn.extract_feature(x)
        logits, _ = self.classifier(feature, self.centroids)
        return logits, feature

    def class_count(self, data):
        """Count examples per class in ``data.dataset``.

        Classes with zero examples are reported as 1 so the counts can be
        used as division denominators in :meth:`centroids_cal`.

        Returns:
            list[int] of length ``self.classes_num``.
        """
        labels = np.array([int(ex.label) for ex in data.dataset],
                          dtype=np.int64)
        # O(N) histogram instead of the original O(N * C) per-class scan;
        # slice guards against labels >= classes_num extending the result.
        counts = np.bincount(labels, minlength=self.classes_num)
        counts = counts[:self.classes_num]
        return np.maximum(counts, 1).tolist()

    def centroids_cal(self, data):
        """Compute per-class mean feature vectors over the training iterator."""
        centroids = torch.zeros(self.classes_num,
                                self.feature_dim).to(self.device)

        print('Calculating centroids.')

        self.textcnn.eval()

        # Initial centroids come from the frozen backbone's features only.
        with torch.no_grad():
            for batch in data:
                inputs, labels = batch.text, batch.label
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                features = self.textcnn.extract_feature(inputs)
                # Accumulate each feature vector into its class's slot.
                for i in range(len(labels)):
                    centroids[labels[i]] += features[i]

        # Average the summed features by the (floored-at-1) class counts.
        centroids /= torch.tensor(self.class_count(data),
                                  dtype=torch.float,
                                  device=self.device).unsqueeze(1)

        return centroids
# --- Пример #2 ("Example #2") — separator left over from the scraped source listing ---
    # NOTE(review): fragment — the loop header that initializes ``samples``
    # (presumably ``samples = []`` followed by ``for _ in range(class_num):``)
    # was lost in extraction; this line appends one empty bucket per class.
    samples.append([])

# Group examples by label: samples[label] collects text tensors of that class.
for batch in train_iter:
    for i,label in enumerate(batch.label.numpy().tolist()):
        # NOTE(review): appends the whole ``batch.text`` (not column/row i)
        # once per label in the batch — correct only if batch size is 1,
        # which the later ``feature.squeeze(0)`` usage suggests; confirm.
        samples[label].append(batch.text)


# Visualize the frozen TextCNN's features with 2-D t-SNE, one figure per
# hand-picked subset of class ids; each subset is a comma-separated string.
max_points = 100
select_classes = [
    "1,3,4,57,79,136,0,9,35,51",
    "74,88,159,129,121,123,148,165,169,178",
    "1,57,64,59,75",
    "55,68,60,100,106",
]

for use_classes in select_classes:
    subset = [int(c) for c in use_classes.split(",")]
    point_labels = []
    point_features = []
    for cls in subset:
        # Cap each class at max_points examples so plots stay balanced.
        n_take = min(max_points, len(samples[cls]))
        point_labels.extend([cls] * n_take)
        for text in samples[cls][:n_take]:
            vec = textcnn.extract_feature(text.to(device))
            point_features.append(vec.squeeze(0).detach().cpu().numpy().tolist())

    embedded = TSNE(n_components=2, random_state=33).fit_transform(
        np.array(point_features))
    plt.figure(figsize=(5, 5))
    plt.scatter(embedded[:, 0], embedded[:, 1],
                c=np.array(point_labels), label="t-SNE")
    plt.legend()
    plt.savefig(pic_dir + "class%s.png" % use_classes)