Ejemplo n.º 1
0
def Labels2PrecisionRecall(labels, cols):
    """Build a per-class precision/recall table from predicted and true labels.

    Args:
        labels: pair ``(y_pred, y_true)`` of class-index tensors.
        cols: class names; ``len(cols)`` is used as the number of classes.

    Returns:
        ``PrecisionRecallTable(precision, recall, cols)`` where precision is
        true positives / predicted count and recall is true positives / true
        count, both per class (one-hot column sums).
    """
    predicted, actual = labels
    n_classes = len(cols)
    pred_onehot = to_onehot(predicted, n_classes)
    true_onehot = to_onehot(actual, n_classes)

    # A tiny epsilon keeps classes that never occur from dividing by zero.
    eps = 1e-30
    hits = (pred_onehot * true_onehot).sum(0)
    precision = hits / (pred_onehot.sum(0) + eps)
    recall = hits / (true_onehot.sum(0) + eps)
    return PrecisionRecallTable(precision, recall, cols)
Ejemplo n.º 2
0
    def update(self, output: Sequence[torch.Tensor]) -> None:
        """Accumulate true-positive and actual-positive statistics for one batch.

        Inputs are normalized according to ``self._type``:

        * ``"binary"``: both tensors are flattened.
        * ``"multiclass"``: ``y`` is one-hot encoded to ``y_pred.size(1)``
          classes and ``y_pred`` is replaced by the one-hot of its argmax over
          dim 1.
        * ``"multilabel"``: ``(N, C, ...)`` tensors are reshaped to
          ``(C, N x ...)``.

        Both are then cast to float64 on ``self._device``. ``y.sum(dim=0)``
        gives the actual positives and ``(y * y_pred).sum(dim=0)`` the true
        positives. For ``"multilabel"`` with ``self._average`` set, the sum of
        their ratios and the number of entries are accumulated. Otherwise the
        values are concatenated (multilabel) or added element-wise.

        Raises:
            ValueError: for ``"multiclass"`` input when ``y`` contains a class
                index ``>= y_pred.size(1)``.
        """
        self._check_shape(output)
        self._check_type(output)
        y_pred, y = output[0].detach(), output[1].detach()

        if self._type == "binary":
            y_pred = y_pred.view(-1)
            y = y.view(-1)
        elif self._type == "multiclass":
            num_classes = y_pred.size(1)
            if y.max() + 1 > num_classes:
                raise ValueError(
                    f"y_pred contains less classes than y. Number of predicted classes is {num_classes}"
                    f" and element in y has invalid class = {y.max().item() + 1}."
                )
            y = to_onehot(y.view(-1), num_classes=num_classes)
            indices = torch.argmax(y_pred, dim=1).view(-1)
            y_pred = to_onehot(indices, num_classes=num_classes)
        elif self._type == "multilabel":
            # if y, y_pred shape is (N, C, ...) -> (C, N x ...)
            num_classes = y_pred.size(1)
            y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1)
            y = torch.transpose(y, 1, 0).reshape(num_classes, -1)

        # Convert from int cuda/cpu to double on self._device
        y_pred = y_pred.to(dtype=torch.float64, device=self._device)
        y = y.to(dtype=torch.float64, device=self._device)
        correct = y * y_pred
        actual_positives = y.sum(dim=0)
        # No special case is needed for an all-zero ``correct``: its sum is
        # already zeros of the same shape/dtype as ``actual_positives``, and
        # testing ``correct.sum() == 0`` would force a host-device sync per batch.
        true_positives = correct.sum(dim=0)

        if self._type == "multilabel":
            if not self._average:
                self._true_positives = torch.cat(
                    [self._true_positives, true_positives],
                    dim=0)  # type: torch.Tensor
                self._positives = torch.cat(
                    [self._positives, actual_positives],
                    dim=0)  # type: torch.Tensor
            else:
                self._true_positives += torch.sum(
                    true_positives / (actual_positives + self.eps))
                self._positives += len(actual_positives)
        else:
            self._true_positives += true_positives
            self._positives += actual_positives

        self._updated = True
Ejemplo n.º 3
0
    def update(self, output):
        """Accumulate true-positive and actual-positive counts from one batch.

        Inputs are normalized according to ``self._type``:

        * ``"binary"``: both tensors are flattened.
        * ``"multiclass"``: ``y`` is one-hot encoded to ``y_pred.size(1)``
          classes and ``y_pred`` is replaced by the one-hot of its argmax over
          dim 1.
        * ``"multilabel"``: ``(N, C, ...)`` tensors are reshaped to
          ``(C, N x ...)``.

        ``y.sum(dim=0)`` gives the actual positives and
        ``(y * y_pred).sum(dim=0)`` the true positives. Both are converted to
        double-precision CPU tensors (``torch.DoubleTensor``) before they are
        accumulated.

        Raises:
            ValueError: for ``"multiclass"`` input when ``y`` contains a class
                index ``>= y_pred.size(1)``.
        """
        y_pred, y = output
        self._check_shape(output)
        self._check_type((y_pred, y))

        if self._type == "binary":
            y_pred = y_pred.view(-1)
            y = y.view(-1)
        elif self._type == "multiclass":
            num_classes = y_pred.size(1)
            if y.max() + 1 > num_classes:
                raise ValueError(
                    "y_pred contains less classes than y. Number of predicted classes is {}"
                    " and element in y has invalid class = {}.".format(
                        num_classes,
                        y.max().item() + 1))
            y = to_onehot(y.view(-1), num_classes=num_classes)
            indices = torch.argmax(y_pred, dim=1).view(-1)
            y_pred = to_onehot(indices, num_classes=num_classes)
        elif self._type == "multilabel":
            # if y, y_pred shape is (N, C, ...) -> (C, N x ...)
            num_classes = y_pred.size(1)
            y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1)
            y = torch.transpose(y, 1, 0).reshape(num_classes, -1)

        y = y.type_as(y_pred)
        correct = y * y_pred
        actual_positives = y.sum(dim=0).type(
            torch.DoubleTensor)  # Convert from int cuda/cpu to double cpu

        # No special case is needed for an all-zero ``correct``: its sum is
        # already zeros of the right shape, and testing ``correct.sum() == 0``
        # would force a host-device sync on every batch.
        # Convert from int cuda/cpu to double cpu.
        # We need double precision for the division true_positives / actual_positives
        true_positives = correct.sum(dim=0).type(torch.DoubleTensor)

        if self._type == "multilabel":
            if not self._average:
                self._true_positives = torch.cat(
                    [self._true_positives, true_positives], dim=0)
                self._positives = torch.cat(
                    [self._positives, actual_positives], dim=0)
            else:
                self._true_positives += torch.sum(
                    true_positives / (actual_positives + self.eps))
                self._positives += len(actual_positives)
        else:
            self._true_positives += true_positives
            self._positives += actual_positives
Ejemplo n.º 4
0
    def _test_NC():
        """Compare ConfusionMatrix with sklearn for (N, C) scores and one-hot targets.

        Covers several class counts with a single update each, plus one run
        split into mini-batches.
        """

        def _check(cm, y_pred, y_labels, num_classes):
            # Reference: sklearn's confusion matrix of argmax predictions.
            np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
            np_y = y_labels.numpy().ravel()
            expected = confusion_matrix(np_y, np_y_pred, labels=list(range(num_classes)))
            assert np.all(expected == cm.compute().numpy())

        # Single-update cases as (num_classes, num_samples); (2, 4) is the 2-class case.
        for num_classes, num_samples in ((4, 10), (10, 4), (2, 4)):
            cm = ConfusionMatrix(num_classes=num_classes)
            y_pred = torch.rand(num_samples, num_classes)
            y_labels = torch.randint(0, num_classes, size=(num_samples,)).type(torch.LongTensor)
            y = to_onehot(y_labels, num_classes=num_classes)
            cm.update((y_pred, y))
            _check(cm, y_pred, y_labels, num_classes)

        # Batched Updates
        num_classes = 5
        cm = ConfusionMatrix(num_classes=num_classes)

        y_pred = torch.rand(100, num_classes)
        y_labels = torch.randint(0, num_classes, size=(100,)).type(torch.LongTensor)
        y = to_onehot(y_labels, num_classes=num_classes)

        batch_size = 16
        # Ceiling division: unlike ``n // batch_size + 1`` it never issues an
        # empty trailing batch when n is a multiple of batch_size.
        n_iters = (y.shape[0] + batch_size - 1) // batch_size

        for i in range(n_iters):
            idx = i * batch_size
            cm.update((y_pred[idx: idx + batch_size], y[idx: idx + batch_size]))

        _check(cm, y_pred, y_labels, num_classes)
Ejemplo n.º 5
0
    def update(self, output):
        """Add one batch to the running confusion matrix.

        Rows index the target class and columns the predicted class (argmax of
        ``y_pred`` over dim 1). The accumulator is converted to the tensor type
        of the float one-hot targets before the batch counts are added.
        """
        y_pred, y = self._check_shape(output)

        predicted_class = torch.max(y_pred, dim=1)[1]
        pred_onehot = to_onehot(predicted_class.reshape(-1), self.num_classes).float()
        target_onehot_t = to_onehot(y.reshape(-1), self.num_classes).transpose(0, 1).float()

        # Keep the accumulator's tensor type (dtype/device) aligned with the batch.
        if self.confusion_matrix.type() != target_onehot_t.type():
            self.confusion_matrix = self.confusion_matrix.type_as(target_onehot_t)

        batch_counts = torch.matmul(target_onehot_t, pred_onehot).float()
        self.confusion_matrix += batch_counts
        self._num_examples += y_pred.shape[0]
Ejemplo n.º 6
0
 def forward(self, input, target):
   """Summed squared error between ``input`` and one-hot encoded ``target``.

   ``target`` holds class indices, encoded to ``input.shape[1]`` classes. When
   ``self.class_weights`` is set, each sample's summed error (over dim 1) is
   scaled by the weight of its target class times the number of classes.
   """
   # ``.to(input)`` puts the target on input's device and dtype instead of a
   # module-level ``device`` global, so the loss works wherever input lives.
   target_onehot = to_onehot(target, num_classes=input.shape[1]).to(input)
   mse = (input - target_onehot) ** 2
   if self.class_weights is not None:
     weights = self.class_weights[target] * input.shape[1]
     return (mse.sum(1) * weights).sum()
   else:
     return mse.sum()
Ejemplo n.º 7
0
    def transform_fn(output):
        """Select column ``label_index`` and one-hot encode its rounded prediction.

        ``output`` is a 3-tuple whose first item is ignored. Returns
        ``(one_hot_prediction, y_true[:, label_index])``.
        """
        _, scores, targets = output
        # Round the selected score column to an integer class before encoding.
        predicted = torch.round(scores[:, label_index]).long()
        return to_onehot(predicted, num_classes), targets[:, label_index]
Ejemplo n.º 8
0
def test_to_onehot():
    """Exercise ``to_onehot`` eagerly, as a scripted function and inside a scripted module."""
    identity = torch.eye(4, dtype=torch.uint8)
    assert to_onehot(torch.tensor([0, 1, 2, 3], dtype=torch.long), 4).equal(identity)

    # Round trip: the class dimension is dim 1, so argmax over it recovers the labels.
    for shape in [(1000,), (4, 250, 255), (4, 150, 155, 4, 6)]:
        labels = torch.randint(0, 21, size=shape)
        encoded = to_onehot(labels, num_classes=21)
        assert labels.equal(torch.argmax(encoded, dim=1))

    # TorchScript compatibility
    x = torch.tensor([0, 1, 2, 3])

    # The bare function, scripted.
    scripted_fn = torch.jit.script(to_onehot)
    assert scripted_fn(x, 4).allclose(to_onehot(x, 4))

    # Used inside an nn.Module.
    class SLP(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 1)

        def forward(self, x):
            x = to_onehot(x, 4)
            return self.linear(x.to(torch.float))

    eager_model = SLP()
    scripted_model = torch.jit.script(eager_model)
    assert eager_model(x).allclose(scripted_model(x))
Ejemplo n.º 9
0
def test_to_onehot():
    """Check ``to_onehot`` on a 1-D index tensor and via argmax round trips."""
    expected = torch.eye(4, dtype=torch.uint8)
    assert to_onehot(torch.tensor([0, 1, 2, 3], dtype=torch.long), 4).equal(expected)

    shapes = ((1000, ), (4, 250, 255), (4, 150, 155, 4, 6))
    for shape in shapes:
        labels = torch.randint(0, 21, size=shape)
        # The class dimension is dim 1, so argmax over it must give the labels back.
        recovered = torch.argmax(to_onehot(labels, num_classes=21), dim=1)
        assert labels.equal(recovered)
Ejemplo n.º 10
0
    def forward(self, input, target):
        """Focal loss on softmax probabilities over the last dimension.

        ``target`` holds class indices, one-hot encoded to ``input.size(-1)``
        classes. Probabilities are clamped to ``[self.eps, 1 - self.eps]``
        before the log. Returns the mean over all elements when
        ``self.size_average`` is truthy, otherwise the sum.
        """
        onehot = to_onehot(target, input.size(-1))
        probs = F.softmax(input, dim=-1).clamp(self.eps, 1. - self.eps)

        cross_entropy = -1 * onehot * torch.log(probs)
        # Focal modulation: down-weight confident entries by (1 - p) ** gamma.
        focal = cross_entropy * (1 - probs) ** self.gamma
        return focal.mean() if self.size_average else focal.sum()
Ejemplo n.º 11
0
def output_transform_seg(process_output):
    """
    Output transform for segmentation metrics.

    ``process_output`` is ``(model_output, target)``. ``model_output['out']``
    holds per-class scores on dim 1. The argmax class map and the target are
    flattened to 1-D, and the predictions are one-hot encoded to
    ``NUM_CLASSES`` classes.

    Returns:
        dict with ``y_pred`` (one-hot predictions, one row per flattened
        element) and ``y`` (flattened target labels).
    """
    y_pred = process_output[0]['out'].argmax(dim=1)  # class map, class dim removed
    y = process_output[1]  # presumably (B, W, H), matching y_pred
    # Fully flatten to 1-D; ``reshape`` also accepts non-contiguous targets,
    # where ``view`` would raise.
    y_pred_ = y_pred.reshape(-1)
    y_ = y.reshape(-1)
    y_pred_one_hot = to_onehot(y_pred_, num_classes=NUM_CLASSES)
    return dict(y_pred=y_pred_one_hot, y=y_)  # output format is according to `DiceCoefficient` docs
Ejemplo n.º 12
0
    def update(self, output):
        """Accumulate true-positive and actual-positive counts from one batch.

        Inputs are normalized according to ``self._type``:

        * ``"binary"``: both tensors are flattened.
        * ``"multiclass"``: ``y`` is one-hot encoded to ``y_pred.size(1)``
          classes and ``y_pred`` is replaced by the one-hot of its argmax over
          dim 1.
        * ``"multilabel"``: ``(N, C, ...)`` tensors are reshaped to
          ``(C, N x ...)`` and the per-batch values are concatenated.

        Both statistics are converted to double-precision CPU tensors
        (``torch.DoubleTensor``) before they are accumulated.

        Raises:
            ValueError: for ``"multiclass"`` input when ``y`` contains a class
                index ``>= y_pred.size(1)``.
        """
        y_pred, y = self._check_shape(output)
        self._check_type((y_pred, y))

        if self._type == "binary":
            y_pred = y_pred.view(-1)
            y = y.view(-1)
        elif self._type == "multiclass":
            num_classes = y_pred.size(1)
            # Reject labels that do not fit y_pred's class dimension before
            # one-hot encoding them, with an explicit message.
            if y.max() + 1 > num_classes:
                raise ValueError(
                    "y_pred contains less classes than y. Number of predicted classes is {}"
                    " and element in y has invalid class = {}.".format(
                        num_classes,
                        y.max().item() + 1))
            y = to_onehot(y.view(-1), num_classes=num_classes)
            indices = torch.max(y_pred, dim=1)[1].view(-1)
            y_pred = to_onehot(indices, num_classes=num_classes)
        elif self._type == "multilabel":
            # if y, y_pred shape is (N, C, ...) -> (C, N x ...)
            num_classes = y_pred.size(1)
            y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1)
            y = torch.transpose(y, 1, 0).reshape(num_classes, -1)

        y = y.type_as(y_pred)
        correct = y * y_pred
        actual_positives = y.sum(dim=0).type(
            torch.DoubleTensor)  # Convert from int cuda/cpu to double cpu

        # No special case is needed for an all-zero ``correct``: its sum is
        # already zeros of the right shape, and testing ``correct.sum() == 0``
        # would force a host-device sync on every batch.
        # Convert from int cuda/cpu to double cpu.
        # We need double precision for the division true_positives / actual_positives
        true_positives = correct.sum(dim=0).type(torch.DoubleTensor)

        if self._type == "multilabel":
            self._true_positives = torch.cat(
                [self._true_positives, true_positives], dim=0)
            self._positives = torch.cat([self._positives, actual_positives],
                                        dim=0)
        else:
            self._true_positives += true_positives
            self._positives += actual_positives
 def forward(self, input, target):
     """Cross-entropy against label-smoothed one-hot targets, summed.

     The target class gets probability 0.9 and every other class
     ``0.1 / (C - 1)``, with ``C = input.shape[1]``. When
     ``self.class_weights`` is set, every term is additionally multiplied by
     ``self.class_weights * C``.
     """
     num_classes = input.shape[1]
     on_value = 0.9
     off_value = 0.1 / (num_classes - 1)
     # log_softmax is the numerically stable form of log(exp(x) / sum(exp(x)));
     # the hand-rolled exp/sum overflows to inf/nan for large logits.
     log_probs = torch.log_softmax(input, dim=1)
     onehot_labels = to_onehot(target, num_classes)
     # Build the smoothed targets directly on input's device/dtype: a 1 becomes
     # on_value and a 0 becomes off_value. This avoids a CPU round trip and a
     # module-level ``device`` global.
     soft_labels = onehot_labels.to(log_probs) * (on_value - off_value) + off_value
     if self.class_weights is not None:
         loss = -torch.sum(
             log_probs * soft_labels * self.class_weights * num_classes)
     else:
         loss = -torch.sum(log_probs * soft_labels)
     return loss
Ejemplo n.º 14
0
    def train(loop: Loop):
        """Train for NUM_EPOCHS epochs, validating at the end of each epoch.

        Per training batch: forward, ``loop.backward``, optimizer step (which
        also zeroes gradients), scheduler step, and learning-rate logging.
        Per epoch: run ``valid_loader``, feed one-hot argmax predictions to the
        precision/recall/accuracy metrics, then log precision, recall, f1 and
        accuracy.
        """
        for _ in loop.iterate_epochs(NUM_EPOCHS):
            for x, y in loop.iterate_dataloader(train_loader, mode="train"):
                y_pred_logits = model(x)

                loss: torch.Tensor = criterion(y_pred_logits, y)
                loop.backward(loss)
                # Makes optimizer step and also
                # zeroes grad after (default)
                loop.optimizer_step(optim, zero_grad=True)

                # Here we call scheduler.step() every iteration
                # because we have a one-cycle scheduler.
                # With an ordinary per-epoch scheduler it could instead be
                # called once after the dataloader loop.
                scheduler.step()

                # Log learning rate. All metrics are written to tensorboard
                # with specified names
                # If iteration='auto' (default) it is determined by where the call is
                # performed. Here it will be batches
                loop.metrics.log("lr",
                                 scheduler.get_last_lr()[0],
                                 iteration="auto")

            # Loop disables gradients and calls Module.eval() inside loop
            # for all attached modules when mode="valid" (default)
            for x, y in loop.iterate_dataloader(valid_loader, mode="valid"):
                y_pred_logits: torch.Tensor = model(x)

                # NOTE(review): assumes a 10-class problem — confirm it matches the model head.
                y_pred = to_onehot(y_pred_logits.argmax(dim=-1),
                                   num_classes=10)

                precision.update((y_pred, y))
                recall.update((y_pred, y))
                accuracy.update((y_pred, y))

            # These are epoch-level metrics because they are logged outside
            # the dataloader loop.
            # Here we log the metrics without resetting them.
            # NOTE(review): presumably consuming ``f1`` below resets precision
            # and recall too (if f1 is built from them); otherwise they
            # accumulate across epochs — confirm.
            loop.metrics.log("valid/precision", precision.compute().mean())
            loop.metrics.log("valid/recall", recall.compute().mean())

            # .log() method above accepts values (tensors, floats, np.array's)
            # .consume() accepts Metric object. It resets it after logging
            loop.metrics.consume("valid/f1", f1)
            loop.metrics.consume("valid/accuracy", accuracy)
Ejemplo n.º 15
0
    def __call__(self, output):
        """Normalize ``(y_pred, y)`` output into long tensors.

        Args:
            output: either a ``(y_pred, y)`` tuple or a dict with ``"y_pred"``
                and ``"y"`` keys.

        Returns:
            ``(y_pred, y)``. When ``self._num_classes`` is truthy, both are
            clamped to ``[0, num_classes - 1]`` and cast to long, and ``y_pred``
            is one-hot encoded. Otherwise both are only cast to long.

        Raises:
            ValueError: if ``output`` is neither a tuple nor a dict.
        """
        if isinstance(output, tuple):
            y_pred, y = output
        elif isinstance(output, dict):
            y_pred = output["y_pred"]
            y = output["y"]
        else:
            raise ValueError(
                "Expected output to be a (y_pred, y) tuple or a dict with "
                "'y_pred' and 'y' keys, got {}".format(type(output).__name__))

        if self._num_classes:
            # Clamp into the valid class range so one-hot encoding cannot go
            # out of bounds.
            max_class = self._num_classes - 1
            y_pred = y_pred.clamp(min=0, max=max_class).long()
            y = y.clamp(min=0, max=max_class).long()
            y_pred = to_onehot(y_pred, self._num_classes)
        else:
            y_pred = y_pred.long()
            y = y.long()
        return y_pred, y
Ejemplo n.º 16
0
    def stats_collect_function(engine, batch):
        """Collect per-batch class-distribution statistics.

        Returns a dict with:
            class_distrib: fraction of target elements in each class (CPU tensor).
            class_presence: running per-class score on ``engine.state``. It is
                raised by 1 for classes present in this batch (fraction > 1e-3)
                and lowered by 1 for absent ones.
            num_classes: number of classes present in this batch.
        """
        _, y = prepare_batch(batch, device=device, non_blocking=non_blocking)

        y_ohe = to_onehot(y.reshape(-1), config.num_classes)

        # Cast to float first: ``to_onehot`` may return an integer tensor, and
        # ``mean`` is only defined for floating-point dtypes.
        class_distrib = y_ohe.float().mean(dim=0).cpu()
        present = class_distrib > 1e-3
        class_presence = present.float()
        num_present = present.sum().item()

        # Net effect: +1 for present classes, -1 for absent ones.
        engine.state.class_presence += class_presence
        engine.state.class_presence -= (1 - class_presence)

        return {
            "class_distrib": class_distrib,
            "class_presence": engine.state.class_presence,
            "num_classes": num_present,
        }
Ejemplo n.º 17
0
    def forward(self, y_pred, y):
        """Soft Jaccard (IoU) loss: ``1 - intersection / union``.

        Args:
            y_pred: raw scores of shape ``(B, C, *spatial)``; softmax is
                applied over dim 1.
            y: either already one-hot with the same rank as ``y_pred``, or
                integer labels of shape ``(B, *spatial)``, which are one-hot
                encoded to ``C`` classes. Any number of spatial dims is accepted.

        Returns:
            ``1 - intersection / union``. Both are summed over positions per
            (batch, class), and class ``self.ignore_index`` is dropped if set.
            ``self.reduction`` of "mean" or "sum" reduces them to scalars; any
            other value keeps the per-(batch, class) terms.

        Raises:
            ValueError: if ``y`` has a different rank and its shape is not
                ``y_pred``'s shape with the class dim removed.
        """
        y_pred = torch.softmax(y_pred, dim=1)

        b, c = y_pred.shape[0], y_pred.shape[1]
        if y_pred.ndim != y.ndim:
            # Expected label shape: y_pred's shape without the class dim. This
            # works for any number of spatial dims, not only 2-D images.
            input_shape = (y_pred.shape[0],) + tuple(y_pred.shape[2:])
            if input_shape == tuple(y.shape):
                y = to_onehot(y, num_classes=c).to(y_pred)
            else:
                raise ValueError("Shapes mismatch: {} vs {}".format(
                    y_pred.shape, y.shape))

        y_pred = y_pred.reshape(b, c, -1)
        y = y.reshape(b, c, -1)

        intersection = y_pred * y
        # Small epsilon keeps the final division finite.
        union = y_pred + y - intersection + 1e-10

        intersection = torch.sum(intersection, dim=-1)
        union = torch.sum(union, dim=-1)

        if self.ignore_index is not None:
            indices = list(range(c))
            indices.remove(self.ignore_index)
            intersection = intersection[:, indices]
            union = union[:, indices]

        if self.reduction == "mean":
            intersection = torch.mean(intersection)
            union = torch.mean(union)
        elif self.reduction == "sum":
            intersection = torch.sum(intersection)
            union = torch.sum(union)

        return 1.0 - intersection / union
Ejemplo n.º 18
0
 def forward(self, x):
     """One-hot encode ``x`` over 4 classes and apply ``self.linear`` to the float encoding."""
     onehot = to_onehot(x, 4).to(torch.float)
     return self.linear(onehot)
Ejemplo n.º 19
0
def train():
    """Evaluate the two-stage emotion pipeline on the validation set.

    Stage 1, ``OpenAIGPTForEmotionDetection`` (loaded from a fixed checkpoint
    directory), flags a sample as emotional when its argmax index is not 0.
    Stage 2, ``OpenAIGPTDoubleHeadLMEmotionRecognitionModel`` (loaded from
    ``config.model_checkpoint``), classifies those samples. Samples whose true
    label is 0 are skipped, and the remaining labels are shifted down by one.

    Prints, in order:
      * recognition accuracy;
      * the fraction of validation batches counted in that accuracy;
      * per-class precision;
      * per-class recall;
      * the confusion matrix (rows: true class, columns: predicted class).

    The optimizer is created and passed to ``amp.initialize`` (when fp16 is
    enabled) but never stepped.
    """
    config_file = "configs/train_daily_dialog_full_pipeline_config.json"
    config = Config.from_json_file(config_file)

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if config.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", config.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(config))

    # Initialize distributed training if needed
    config.distributed = (config.local_rank != -1)
    if config.distributed:
        torch.cuda.set_device(config.local_rank)
        config.device = torch.device("cuda", config.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    # Stage 1 model. A tokenizer used to be loaded from this directory too,
    # but it was overwritten below before any use, so it is not loaded.
    # NOTE(review): hard-coded absolute path; consider moving it into the config.
    model_checkpoint = "/home/rohola/codes/transfer-learning-conv-ai/logs/emotion_detection_log/"
    model_class = OpenAIGPTForEmotionDetection
    emotion_detection_model = model_class.from_pretrained(model_checkpoint)
    emotion_detection_model.set_num_special_tokens(len(SPECIAL_TOKENS))
    emotion_detection_model.to(config.device)

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    tokenizer_class = GPT2Tokenizer if "gpt2" in config.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint)
    model_class = OpenAIGPTDoubleHeadLMEmotionRecognitionModel
    emotion_recognition_model = model_class.from_pretrained(
        config.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    emotion_recognition_model.set_num_special_tokens(len(SPECIAL_TOKENS))
    emotion_recognition_model.to(config.device)
    optimizer = OpenAIAdam(emotion_recognition_model.parameters(),
                           lr=config.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if config.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        emotion_recognition_model, optimizer = amp.initialize(
            emotion_recognition_model, optimizer, opt_level=config.fp16)
    if config.distributed:
        emotion_recognition_model = DistributedDataParallel(
            emotion_recognition_model,
            device_ids=[config.local_rank],
            output_device=config.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        config, tokenizer)

    # Both models must be in eval mode; otherwise dropout in the recognition
    # model makes the evaluation stochastic.
    emotion_detection_model.eval()
    emotion_recognition_model.eval()
    n_emotions = 0
    num_correct = 0
    all_predicted_positives = 0
    all_true_positives = 0
    all_actual_positives = 0
    # NOTE(review): assumes 6 emotion classes (recognition head size) — confirm.
    confusion_matrix = torch.zeros(6, 6, dtype=torch.float, device=config.device)
    num_all = len(val_loader)
    for batch in val_loader:
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(config.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids = batch
            model_outputs = emotion_detection_model(
                input_ids, mc_token_ids, token_type_ids=token_type_ids)
            mc_logits = model_outputs[1]  # So we can also use GPT2 outputs
            # NOTE(review): ``.item()`` here and below assumes a validation batch size of 1.
            indices = torch.argmax(mc_logits, dim=1)
            if indices.item() == 0:  # detector found no emotion
                continue

            recognition_output = emotion_recognition_model(
                input_ids,
                mc_token_ids,
                token_type_ids=token_type_ids,
                token_emotion_ids=token_emotion_ids)

            # Label 0 is "no emotion": skip it and shift emotion labels to start at 0.
            if mc_labels.item() == 0:
                continue
            mc_labels = mc_labels - 1

            mc_recognition_logit = recognition_output[1]
            indices = torch.argmax(mc_recognition_logit, dim=1)
            correct = torch.eq(indices, mc_labels).view(-1)
            num_correct += torch.sum(correct).item()
            n_emotions += 1

            # precision
            num_classes = mc_recognition_logit.size(1)
            mc_labels = to_onehot(mc_labels.view(-1),
                                  num_classes=num_classes)
            predictions = to_onehot(indices.view(-1), num_classes=num_classes)
            mc_labels = mc_labels.type_as(predictions)
            correct = mc_labels * predictions
            # Convert from int cuda/cpu to double cpu
            all_positives = predictions.sum(dim=0).type(torch.DoubleTensor)
            # An all-zero ``correct`` already sums to zeros; no special case needed.
            true_positives = correct.sum(dim=0).type(torch.DoubleTensor)
            all_predicted_positives += all_positives
            all_true_positives += true_positives

            # recall
            actual_positives = mc_labels.sum(dim=0).type(
                torch.DoubleTensor)
            all_actual_positives += actual_positives

            # confusion matrix (rows: true class, columns: predicted class)
            confusion_matrix += torch.matmul(mc_labels.transpose(0, 1).float(),
                                             predictions.float())

    if n_emotions == 0:
        # All accumulators are still the integer 0; dividing would raise.
        logger.warning("No validation sample reached emotion recognition; nothing to report.")
        return

    print(num_correct / n_emotions)  # accuracy for all classes of emotion
    print(n_emotions / num_all)

    print(all_true_positives / all_predicted_positives)  # NaN for never-predicted classes
    print(all_true_positives / all_actual_positives)

    print(confusion_matrix)
Ejemplo n.º 20
0
def test_to_onehot():
    """Encoding 0..3 with 4 classes must give the 4x4 identity matrix."""
    expected = torch.eye(4)
    actual = to_onehot(torch.LongTensor([0, 1, 2, 3]), 4)
    assert actual.equal(expected)
Ejemplo n.º 21
0
    def forward(self, inputs, labels=None):
        """
        Encode, pool and classify a batch; in training mode also apply mixup.

        :param inputs: [bsz, max_seq_leng] token ids; positions with id 0 are
            masked out as padding.
        :param labels: class indices [bsz] when ``self.multi_label`` is false,
            otherwise [bsz, num_class] targets. Required in training mode.
        :return: ``(logits, mix_logits, mix_labels)`` in training mode,
            ``(logits, hidden)`` otherwise.
        :raises ValueError: if ``labels`` is missing in training mode.
        :raises Exception: for an unsupported ``self.mixup_type``.
        """
        inputs = inputs.t()
        mask = (inputs > 0).float()
        inputs_len = (inputs > 0).int().sum(dim=0)

        hidden = self.encoder(inputs, mask, inputs_len)

        pool_values = []
        for pool in self.summary_type:
            if pool == 'max':
                val = max_pooling(hidden, mask)
                pool_values.append(val)
            elif pool == 'mean':
                val = mean_pooling(hidden, inputs_len, mask)
                pool_values.append(val)
            elif pool == 'first':
                seq_len, bsz, dim = hidden.size()
                val = hidden[0, :, :].view(bsz, -1).contiguous()
                pool_values.append(val)
            elif pool == 'last':
                seq_len, bsz, dim = hidden.size()
                val = hidden[-1, :, :].view(bsz, -1).contiguous()
                pool_values.append(val)
            elif pool == 'struct_att':
                val, att = self.strut_att(hidden, mask)
                bsz, head_num, dim = val.size()
                val = val.contiguous().view(bsz, -1)
                pool_values.append(val)
            elif pool == 'none':
                pool_values.append(hidden)

        if len(self.summary_type) == 1:
            hidden = pool_values[0]
        else:
            hidden = torch.cat(pool_values, dim=-1).contiguous()

        # [bsz, hid_dim]
        bsz, _ = hidden.size()
        hidden = self.normalize(hidden)
        logits = self.cls(hidden)

        if self.training:
            if labels is None:
                raise ValueError("labels are required in training mode for mixup")
            # Mixup: pair every sample with a random partner from the batch.
            indices = torch.randperm(bsz, device=logits.device)
            shuf_labels = torch.index_select(labels, 0, indices)
            shuf_hidden = torch.index_select(hidden, 0, indices)

            if self.mixup_type == 'mixup':
                lam = self.beta_dist.sample(sample_shape=(bsz, 1))
                lam = lam.to(inputs.device)
                lam_x, lam_y = lam, lam

            elif self.mixup_type == 'prior_mix':
                lam_x = self.beta_dist.sample(sample_shape=(bsz,))
                lam_x = lam_x.to(inputs.device).reshape(-1)
                lam_y = self.prior_mixup(labels, shuf_labels).reshape(-1)
                # Harmonic mean of the sampled and the prior coefficient.
                lam_y = 2. * lam_x * lam_y / (lam_x + lam_y)

            else:
                raise Exception('Unsupported mixup type %s' % self.mixup_type)

            # One coefficient per sample, shaped (bsz, 1) so it broadcasts
            # row-wise over the feature/class dim. Without this, a (bsz, 1)
            # lambda unsqueezed to (bsz, 1, 1) mixes labels into a
            # (bsz, bsz, C) tensor, and a (bsz,) lambda does not broadcast
            # against (bsz, hid_dim) hidden states.
            lam_x = lam_x.reshape(bsz, 1)
            lam_y = lam_y.reshape(bsz, 1)

            mix_hidden = lam_x * hidden + (1 - lam_x) * shuf_hidden

            if not self.multi_label:
                onehot_label = to_onehot(labels, self.num_class)
                onehot_shuf_label = to_onehot(shuf_labels, self.num_class)
            else:
                onehot_label = labels
                onehot_shuf_label = shuf_labels

            mix_labels = lam_y * onehot_label + (1 - lam_y) * onehot_shuf_label

            mix_logits = self.cls(mix_hidden)

            return logits, mix_logits, mix_labels

        return logits, hidden
Ejemplo n.º 22
0
 def forward(self, input, target):
   """Mean squared error between ``input`` and the one-hot encoding of ``target``.

   ``target`` holds class indices, encoded to ``input.shape[1]`` classes.
   """
   # ``.to(input)`` matches input's device *and* dtype. It avoids a global
   # ``device`` and gives mse_loss a floating target instead of the integer
   # one-hot tensor.
   target_onehot = to_onehot(target, num_classes=input.shape[1]).to(input)
   return nn.functional.mse_loss(input, target_onehot)
def train():
    """Evaluate ``OpenAIGPTForEmotionDetection`` on the validation set.

    Loads the tokenizer and model from ``config.model_checkpoint`` and runs the
    model over ``val_loader`` without gradients. It prints, in order: per-sample
    accuracy, per-class precision (true positives / predicted positives) and
    the number of evaluated samples. The optimizer is created and passed to
    ``amp.initialize`` (when fp16 is enabled) but never stepped.
    """
    config_file = "configs/train_daily_dialog_emotion_detection_config.json"
    config = Config.from_json_file(config_file)

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if config.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", config.local_rank)
    logger.info("Arguments: %s", pformat(config))

    # Initialize distributed training if needed
    config.distributed = (config.local_rank != -1)
    if config.distributed:
        torch.cuda.set_device(config.local_rank)
        config.device = torch.device("cuda", config.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    tokenizer_class = GPT2Tokenizer if "gpt2" in config.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint)
    model_class = OpenAIGPTForEmotionDetection
    model = model_class.from_pretrained(config.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(config.device)
    optimizer = OpenAIAdam(model.parameters(), lr=config.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if config.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=config.fp16)
    if config.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[config.local_rank],
                                        output_device=config.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        config, tokenizer)

    model.eval()
    num_samples = 0
    num_correct = 0
    positives = 0
    all_true_positives = 0
    for batch in val_loader:
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(config.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            model_outputs = model(input_ids,
                                  mc_token_ids,
                                  token_type_ids=token_type_ids)
            mc_logits = model_outputs[1]  # So we can also use GPT2 outputs
            indices = torch.argmax(mc_logits, dim=1)

            correct = torch.eq(indices, mc_labels).view(-1)
            num_correct += torch.sum(correct).item()
            # Accuracy is per sample, so count samples rather than batches.
            num_samples += mc_labels.numel()

            # Per-class precision statistics from one-hot labels/predictions.
            num_classes = mc_logits.size(1)
            mc_labels = to_onehot(mc_labels.view(-1), num_classes=num_classes)
            predictions = to_onehot(indices.view(-1), num_classes=num_classes)
            mc_labels = mc_labels.type_as(predictions)
            correct = mc_labels * predictions
            # Convert from int cuda/cpu to double cpu
            all_positives = predictions.sum(dim=0).type(torch.DoubleTensor)
            # An all-zero ``correct`` already sums to zeros; no special case needed.
            true_positives = correct.sum(dim=0).type(torch.DoubleTensor)
            positives += all_positives
            all_true_positives += true_positives

    if num_samples == 0:
        # The accumulators are still the integer 0; dividing would raise.
        logger.warning("Validation loader produced no samples; nothing to report.")
        return

    print(num_correct / num_samples)
    print(all_true_positives / positives)  # NaN for never-predicted classes
    print(num_samples)