def Labels2PrecisionRecall(labels, cols):
    """Build a per-class precision/recall table from predicted and true labels.

    Args:
        labels: pair ``(y_pred, y_true)`` of class-index tensors; both are
            one-hot encoded with ``len(cols)`` classes.
        cols: class names; its length defines the number of classes and it is
            passed through to the resulting table.

    Returns:
        ``PrecisionRecallTable(precision, recall, cols)``.
    """
    predicted, actual = labels
    class_count = len(cols)

    pred_onehot = to_onehot(predicted, class_count)
    true_onehot = to_onehot(actual, class_count)

    # A tiny constant keeps the divisions finite for classes with no samples.
    eps = 1e-30
    hits = (pred_onehot * true_onehot).sum(0)
    precision = hits / (pred_onehot.sum(0) + eps)
    recall = hits / (true_onehot.sum(0) + eps)

    return PrecisionRecallTable(precision, recall, cols)
def update(self, output: Sequence[torch.Tensor]) -> None:
    """Accumulate true-positive and actual-positive counts from one batch.

    Args:
        output: sequence whose first two items are the predictions and the
            targets. Both are detached before use.

    Raises:
        ValueError: in the multiclass case, when a target class index is not
            covered by the class dimension of the predictions.
    """
    self._check_shape(output)
    self._check_type(output)
    y_pred = output[0].detach()
    y = output[1].detach()

    kind = self._type
    if kind == "binary":
        y_pred, y = y_pred.view(-1), y.view(-1)
    elif kind == "multiclass":
        num_classes = y_pred.size(1)
        if y.max() + 1 > num_classes:
            raise ValueError(
                f"y_pred contains less classes than y. Number of predicted classes is {num_classes}"
                f" and element in y has invalid class = {y.max().item() + 1}."
            )
        y = to_onehot(y.view(-1), num_classes=num_classes)
        indices = torch.argmax(y_pred, dim=1).view(-1)
        y_pred = to_onehot(indices, num_classes=num_classes)
    elif kind == "multilabel":
        # (N, C, ...) -> (C, N x ...)
        num_classes = y_pred.size(1)
        y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1)
        y = torch.transpose(y, 1, 0).reshape(num_classes, -1)

    # Work in double precision on the metric's device.
    y_pred = y_pred.to(dtype=torch.float64, device=self._device)
    y = y.to(dtype=torch.float64, device=self._device)

    correct = y * y_pred
    actual_positives = y.sum(dim=0)
    true_positives = (
        torch.zeros_like(actual_positives)
        if correct.sum() == 0
        else correct.sum(dim=0)
    )

    if kind != "multilabel":
        self._true_positives += true_positives
        self._positives += actual_positives
    elif self._average:
        # Averaged multilabel: accumulate the summed ratios and the number of
        # entries they were computed over.
        self._true_positives += torch.sum(
            true_positives / (actual_positives + self.eps))
        self._positives += len(actual_positives)
    else:
        # Non-averaged multilabel: keep every entry by concatenation.
        self._true_positives = torch.cat(
            [self._true_positives, true_positives], dim=0)
        self._positives = torch.cat(
            [self._positives, actual_positives], dim=0)

    self._updated = True
def update(self, output):
    """Accumulate true-positive and actual-positive counts from one batch.

    Args:
        output: pair ``(y_pred, y)`` of prediction and target tensors.

    Raises:
        ValueError: in the multiclass case, when a target class index is not
            covered by the class dimension of the predictions.
    """
    y_pred, y = output
    self._check_shape(output)
    self._check_type((y_pred, y))

    kind = self._type
    if kind == "binary":
        y_pred, y = y_pred.view(-1), y.view(-1)
    elif kind == "multiclass":
        num_classes = y_pred.size(1)
        if y.max() + 1 > num_classes:
            raise ValueError(
                "y_pred contains less classes than y. Number of predicted classes is {}"
                " and element in y has invalid class = {}.".format(
                    num_classes, y.max().item() + 1))
        y = to_onehot(y.view(-1), num_classes=num_classes)
        indices = torch.argmax(y_pred, dim=1).view(-1)
        y_pred = to_onehot(indices, num_classes=num_classes)
    elif kind == "multilabel":
        # (N, C, ...) -> (C, N x ...)
        num_classes = y_pred.size(1)
        y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1)
        y = torch.transpose(y, 1, 0).reshape(num_classes, -1)

    y = y.type_as(y_pred)
    correct = y * y_pred

    # Counts are converted to CPU double tensors; double precision is needed
    # for the division true_positives / actual_positives.
    actual_positives = y.sum(dim=0).type(torch.DoubleTensor)
    true_positives = (
        torch.zeros_like(actual_positives)
        if correct.sum() == 0
        else correct.sum(dim=0)
    )
    true_positives = true_positives.type(torch.DoubleTensor)

    if kind != "multilabel":
        self._true_positives += true_positives
        self._positives += actual_positives
    elif self._average:
        self._true_positives += torch.sum(
            true_positives / (actual_positives + self.eps))
        self._positives += len(actual_positives)
    else:
        self._true_positives = torch.cat(
            [self._true_positives, true_positives], dim=0)
        self._positives = torch.cat(
            [self._positives, actual_positives], dim=0)
def _test_NC():
    """Compare ``ConfusionMatrix`` with the reference ``confusion_matrix``.

    Covers three single-update cases with different class/sample counts and
    one case where the data is fed in several batches.
    """

    def reference(scores, labels, n_classes):
        # Reference matrix from arg-max predictions and integer labels.
        np_y_pred = scores.numpy().argmax(axis=1).ravel()
        np_y = labels.numpy().ravel()
        return confusion_matrix(np_y, np_y_pred, labels=list(range(n_classes)))

    # Single-update cases: (num_classes, number of samples); last one is 2-classes.
    for num_classes, n_samples in ((4, 10), (10, 4), (2, 4)):
        cm = ConfusionMatrix(num_classes=num_classes)
        y_pred = torch.rand(n_samples, num_classes)
        y_labels = torch.randint(
            0, num_classes, size=(n_samples,)).type(torch.LongTensor)
        y = to_onehot(y_labels, num_classes=num_classes)
        cm.update((y_pred, y))
        assert np.all(
            reference(y_pred, y_labels, num_classes) == cm.compute().numpy())

    # Batched updates
    num_classes = 5
    cm = ConfusionMatrix(num_classes=num_classes)
    y_pred = torch.rand(100, num_classes)
    y_labels = torch.randint(0, num_classes, size=(100,)).type(torch.LongTensor)
    y = to_onehot(y_labels, num_classes=num_classes)

    batch_size = 16
    n_iters = y.shape[0] // batch_size + 1
    for start in range(0, n_iters * batch_size, batch_size):
        stop = start + batch_size
        cm.update((y_pred[start:stop], y[start:stop]))

    assert np.all(
        reference(y_pred, y_labels, num_classes) == cm.compute().numpy())
def update(self, output):
    """Add one batch to the running confusion matrix.

    Args:
        output: value accepted by ``self._check_shape``, which returns the
            pair ``(y_pred, y)`` used here.
    """
    y_pred, y = self._check_shape(output)

    # Predicted class = index of the largest score along dim 1.
    predicted = torch.max(y_pred, dim=1)[1]
    pred_onehot = to_onehot(predicted.reshape(-1), self.num_classes)
    true_onehot = to_onehot(y.reshape(-1), self.num_classes)

    true_onehot_t = true_onehot.transpose(0, 1).float()
    pred_onehot = pred_onehot.float()

    # Keep the accumulator in the same tensor type as the batch operands.
    if self.confusion_matrix.type() != true_onehot_t.type():
        self.confusion_matrix = self.confusion_matrix.type_as(true_onehot_t)

    batch_matrix = torch.matmul(true_onehot_t, pred_onehot).float()
    self.confusion_matrix += batch_matrix
    self._num_examples += y_pred.shape[0]
def forward(self, input, target):
    """Return the summed squared error between ``input`` and one-hot ``target``.

    Args:
        input: score tensor; dimension 1 is used as the number of classes.
        target: class-index tensor, one-hot encoded via ``to_onehot`` and also
            used to index ``self.class_weights``.

    Returns:
        Scalar tensor. When ``self.class_weights`` is set, each sample's
        squared error (summed over dim 1) is scaled by the weight of its
        target class times the number of classes before the final sum.
    """
    num_classes = input.shape[1]
    # Follow the device of `input` rather than a module-level global so the
    # subtraction below cannot fail on a device mismatch.
    target_onehot = to_onehot(target, num_classes=num_classes).to(
        device=input.device)
    squared_error = (input - target_onehot) ** 2

    if self.class_weights is None:
        return squared_error.sum()

    weights = self.class_weights[target] * num_classes
    return (squared_error.sum(1) * weights).sum()
def transform_fn(output):
    """Select one label column and one-hot encode its rounded prediction.

    Args:
        output: 3-item sequence; the first item is ignored, the second holds
            predictions and the third holds targets.

    Returns:
        ``(y_pred, y_true)`` where ``y_pred`` is the one-hot encoding of the
        rounded predictions in column ``label_index`` and ``y_true`` is the
        same column of the targets.
    """
    _ignored, scores, targets = output
    rounded = torch.round(scores[:, label_index]).long()
    onehot_pred = to_onehot(rounded, num_classes)
    selected_true = targets[:, label_index]
    return onehot_pred, selected_true
def test_to_onehot():
    """Check ``to_onehot`` eagerly, for several input ranks, and under TorchScript."""
    expected = torch.eye(4, dtype=torch.uint8)
    actual = to_onehot(torch.tensor([0, 1, 2, 3], dtype=torch.long), 4)
    assert actual.equal(expected)

    # Round trip: arg-max over dim 1 of the encoding recovers the indices.
    for shape in ((1000,), (4, 250, 255), (4, 150, 155, 4, 6)):
        y = torch.randint(0, 21, size=shape)
        recovered = torch.argmax(to_onehot(y, num_classes=21), dim=1)
        assert y.equal(recovered)

    # Test with `TorchScript`
    x = torch.tensor([0, 1, 2, 3])

    # The raw function, scripted, must agree with the eager call.
    scripted_to_onehot = torch.jit.script(to_onehot)
    assert scripted_to_onehot(x, 4).allclose(to_onehot(x, 4))

    # The function must also be usable inside a scripted `torch.nn.Module`.
    class SLP(torch.nn.Module):
        def __init__(self):
            super(SLP, self).__init__()
            self.linear = torch.nn.Linear(4, 1)

        def forward(self, x):
            x = to_onehot(x, 4)
            return self.linear(x.to(torch.float))

    eager_model = SLP()
    scripted_model = torch.jit.script(eager_model)
    assert eager_model(x).allclose(scripted_model(x))
def test_to_onehot():
    """Check ``to_onehot`` against an identity matrix and via arg-max round trips."""
    expected = torch.eye(4, dtype=torch.uint8)
    actual = to_onehot(torch.tensor([0, 1, 2, 3], dtype=torch.long), 4)
    assert actual.equal(expected)

    # Round trip: arg-max over dim 1 of the encoding recovers the indices.
    for shape in ((1000,), (4, 250, 255), (4, 150, 155, 4, 6)):
        y = torch.randint(0, 21, size=shape)
        recovered = torch.argmax(to_onehot(y, num_classes=21), dim=1)
        assert y.equal(recovered)
def forward(self, input, target):
    """Compute the focal loss of ``input`` scores against ``target`` classes.

    Args:
        input: score tensor; softmax is taken over the last dimension, whose
            size is used as the number of classes.
        target: class-index tensor, one-hot encoded via ``to_onehot``.

    Returns:
        The mean of the element-wise loss when ``self.size_average`` is
        truthy, otherwise its sum.
    """
    # Probabilities are clamped away from 0 and 1 so the log stays finite.
    prob = F.softmax(input, dim=-1).clamp(self.eps, 1. - self.eps)

    # Cast the one-hot target to the probability dtype/device before any
    # arithmetic: on an unsigned-integer encoding, `-1 * y` would wrap around
    # (to 255 for uint8) instead of producing -1.
    y = to_onehot(target, input.size(-1)).to(dtype=prob.dtype,
                                             device=prob.device)

    loss = -1 * y * torch.log(prob)  # cross entropy
    loss = loss * (1 - prob) ** self.gamma  # focal loss

    return loss.mean() if self.size_average else loss.sum()
def output_transform_seg(process_output):
    """Output transform for segmentation metrics.

    Takes the arg-max over dim 1 of ``process_output[0]['out']`` as the
    prediction, flattens prediction and target ``process_output[1]``, and
    one-hot encodes the flattened prediction with ``NUM_CLASSES`` classes.

    Returns:
        ``dict(y_pred=..., y=...)``; the output format follows the
        ``DiceCoefficient`` docs.
    """
    predicted = process_output[0]['out'].argmax(dim=1)
    targets = process_output[1]

    flat_pred = predicted.view(-1)
    flat_true = targets.view(-1)

    onehot_pred = to_onehot(flat_pred, num_classes=NUM_CLASSES)
    return {"y_pred": onehot_pred, "y": flat_true}
def update(self, output):
    """Accumulate true-positive and actual-positive counts from one batch.

    Args:
        output: value accepted by ``self._check_shape``, which returns the
            pair ``(y_pred, y)`` used here.
    """
    y_pred, y = self._check_shape(output)
    self._check_type((y_pred, y))

    kind = self._type
    if kind == "binary":
        y_pred, y = y_pred.view(-1), y.view(-1)
    elif kind == "multiclass":
        num_classes = y_pred.size(1)
        y = to_onehot(y.view(-1), num_classes=num_classes)
        _, indices = torch.max(y_pred, dim=1)
        y_pred = to_onehot(indices.view(-1), num_classes=num_classes)
    elif kind == "multilabel":
        # (N, C, ...) -> (C, N x ...)
        num_classes = y_pred.size(1)
        y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1)
        y = torch.transpose(y, 1, 0).reshape(num_classes, -1)

    y = y.type_as(y_pred)
    correct = y * y_pred

    # Counts are converted to CPU double tensors; double precision is needed
    # for the division true_positives / actual_positives.
    actual_positives = y.sum(dim=0).type(torch.DoubleTensor)
    true_positives = (
        torch.zeros_like(actual_positives)
        if correct.sum() == 0
        else correct.sum(dim=0)
    )
    true_positives = true_positives.type(torch.DoubleTensor)

    if kind != "multilabel":
        self._true_positives += true_positives
        self._positives += actual_positives
        return

    # Multilabel: keep every entry by concatenation.
    self._true_positives = torch.cat(
        [self._true_positives, true_positives], dim=0)
    self._positives = torch.cat([self._positives, actual_positives], dim=0)
def forward(self, input, target):
    """Cross entropy against label-smoothed targets.

    The true class gets probability 0.9 and the remaining 0.1 is spread
    evenly over the other classes.

    Args:
        input: score tensor; dimension 1 is the class dimension.
        target: class-index tensor, one-hot encoded via ``to_onehot``.

    Returns:
        Scalar tensor: the negated sum of ``log_softmax(input) * soft_labels``,
        additionally scaled by ``self.class_weights`` times the number of
        classes when class weights are set.
    """
    num_classes = input.shape[1]

    # log_softmax is the numerically stable form of log(exp(x) / sum(exp(x)));
    # the explicit form overflows for large scores.
    log_probs = torch.log_softmax(input, dim=1)

    onehot_labels = to_onehot(target, num_classes)
    # Build the soft labels directly on the device of `input` instead of
    # round-tripping through the CPU and a module-level global device.
    is_true_class = onehot_labels.to(device=input.device) == 1
    soft_labels = torch.where(
        is_true_class,
        torch.tensor([0.9], device=input.device),
        torch.tensor([0.1 / (num_classes - 1)], device=input.device))

    if self.class_weights is not None:
        return -torch.sum(
            log_probs * soft_labels * self.class_weights * num_classes)
    return -torch.sum(log_probs * soft_labels)
def train(loop: Loop):
    """Run ``NUM_EPOCHS`` epochs of training followed by validation.

    Each epoch iterates ``train_loader`` in "train" mode (backward pass,
    optimizer step, scheduler step, learning-rate logging), then iterates
    ``valid_loader`` in "valid" mode to update the metrics, and finally logs
    the metrics for the epoch.
    """
    for _epoch in loop.iterate_epochs(NUM_EPOCHS):
        for inputs, targets in loop.iterate_dataloader(train_loader, mode="train"):
            logits = model(inputs)
            loss: torch.Tensor = criterion(logits, targets)
            loop.backward(loss)

            # Optimizer step; gradients are zeroed afterwards (the default).
            loop.optimizer_step(optim, zero_grad=True)

            # The scheduler is a one-cycle scheduler, so it is stepped on
            # every iteration. A usual scheduler could instead be stepped
            # once after the whole dataloader loop.
            scheduler.step()

            # All metrics are written to tensorboard under the given names.
            # With iteration='auto' (the default) the iteration is determined
            # from where the call is made; here that means batches.
            current_lr = scheduler.get_last_lr()[0]
            loop.metrics.log("lr", current_lr, iteration="auto")

        # With mode="valid" (the default) the loop disables gradients and
        # calls Module.eval() on all attached modules.
        for inputs, targets in loop.iterate_dataloader(valid_loader, mode="valid"):
            logits: torch.Tensor = model(inputs)
            batch_output = (
                to_onehot(logits.argmax(dim=-1), num_classes=10),
                targets,
            )
            for metric in (precision, recall, accuracy):
                metric.update(batch_output)

        # Logged outside the dataloader loops, so these are epoch metrics.
        # .log() takes plain values (tensors, floats, np.array's) and does not
        # reset the metric.
        loop.metrics.log("valid/precision", precision.compute().mean())
        loop.metrics.log("valid/recall", recall.compute().mean())

        # .consume() takes a Metric object and resets it after logging.
        loop.metrics.consume("valid/f1", f1)
        loop.metrics.consume("valid/accuracy", accuracy)
def __call__(self, output):
    """Convert a ``(y_pred, y)`` output into integer tensors for metrics.

    Args:
        output: either a tuple ``(y_pred, y)`` or a dict with the keys
            ``"y_pred"`` and ``"y"``.

    Returns:
        ``(y_pred, y)`` cast to long. When ``self._num_classes`` is truthy,
        both are first clamped to ``[0, num_classes - 1]`` and ``y_pred`` is
        one-hot encoded.

    Raises:
        ValueError: if ``output`` is neither a tuple nor a dict.
    """
    if isinstance(output, tuple):
        y_pred, y = output
    elif isinstance(output, dict):
        y_pred, y = output["y_pred"], output["y"]
    else:
        raise ValueError

    if not self._num_classes:
        return y_pred.long(), y.long()

    upper = self._num_classes - 1
    y_pred = y_pred.clamp(min=0, max=upper).long()
    y = y.clamp(min=0, max=upper).long()
    return to_onehot(y_pred, self._num_classes), y
def stats_collect_function(engine, batch):
    """Collect class statistics for one batch and update a running presence score.

    The targets are flattened and one-hot encoded; the mean over dim 0 gives
    the class distribution. A class counts as present when its share exceeds
    1e-3. ``engine.state.class_presence`` is increased by 1 for present
    classes and decreased by 1 for absent ones.

    Returns:
        Dict with ``"class_distrib"``, the updated ``"class_presence"`` from
        the engine state, and ``"num_classes"`` (number of present classes).
    """
    x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
    y_ohe = to_onehot(y.reshape(-1), config.num_classes)

    class_distrib = y_ohe.mean(dim=0).cpu()
    is_present = class_distrib > 1e-3
    class_presence = is_present.cpu().float()
    num_classes = is_present.sum().item()

    engine.state.class_presence += class_presence
    engine.state.class_presence -= (1 - class_presence)

    return dict(
        class_distrib=class_distrib,
        class_presence=engine.state.class_presence,
        num_classes=num_classes,
    )
def forward(self, y_pred, y):
    """Soft Jaccard (IoU) loss between softmaxed predictions and targets.

    Args:
        y_pred: score tensor; softmax is applied over dim 1, and dims 0 and 1
            are kept when flattening the remaining dims.
        y: target tensor. If its rank differs from ``y_pred`` it must have the
            shape of ``y_pred`` without dim 1 and is one-hot encoded;
            otherwise it is used as-is.

    Returns:
        ``1 - intersection / union``, reduced according to ``self.reduction``
        ("mean" or "sum"; any other value leaves the per-item values).

    Raises:
        ValueError: if the ranks differ and the shapes are incompatible.
    """
    probs = torch.softmax(y_pred, dim=1)
    b, c = probs.shape[0], probs.shape[1]

    if probs.ndim != y.ndim:
        full_shape = probs.shape
        expected_shape = (full_shape[0], full_shape[2], full_shape[3])
        if expected_shape == y.shape:
            y = to_onehot(y, num_classes=c).to(probs)
        else:
            raise ValueError("Shapes mismatch: {} vs {}".format(
                probs.shape, y.shape))

    probs = probs.reshape(b, c, -1)
    y = y.reshape(b, c, -1)

    overlap = probs * y
    # The small constant keeps the union strictly positive.
    combined = probs + y - overlap + 1e-10
    intersection = torch.sum(overlap, dim=-1)
    union = torch.sum(combined, dim=-1)

    if self.ignore_index is not None:
        kept = list(range(c))
        kept.remove(self.ignore_index)
        intersection = intersection[:, kept]
        union = union[:, kept]

    reduction = self.reduction
    if reduction == "mean":
        intersection, union = torch.mean(intersection), torch.mean(union)
    elif reduction == "sum":
        intersection, union = torch.sum(intersection), torch.sum(union)

    return 1.0 - intersection / union
def forward(self, x):
    """One-hot encode ``x`` with 4 classes and apply ``self.linear`` to the float result."""
    onehot = to_onehot(x, 4)
    features = onehot.to(torch.float)
    return self.linear(features)
# train(): despite its name, the visible body only evaluates.
#
# Steps visible in the code below:
#   1. Loads a Config from "configs/train_daily_dialog_full_pipeline_config.json",
#      configures logging by local rank, and initialises NCCL distributed
#      mode when config.local_rank != -1.
#   2. Loads an emotion *detection* model (OpenAIGPTForEmotionDetection) from a
#      hard-coded absolute checkpoint path, and an emotion *recognition* model
#      (OpenAIGPTDoubleHeadLMEmotionRecognitionModel) from config.model_checkpoint.
#   3. Builds an OpenAIAdam optimizer and optionally wraps the recognition
#      model with apex amp / DistributedDataParallel.
#   4. Iterates val_loader, runs the detection model, and for detected
#      emotions runs the recognition model, accumulating correct counts,
#      predicted/true/actual positives and a 6x6 confusion matrix.
#   5. Prints accuracy, the share of evaluated items, per-class
#      true_positives / predicted_positives and true_positives /
#      actual_positives, and the confusion matrix.
#
# NOTE(review): the original indentation of this function is lost in this
# view, so the nesting of `if indices.item() != 0`, `if mc_labels.item() != 0`
# and `else: continue` cannot be determined here — confirm against the
# properly formatted file before changing control flow.
# NOTE(review): `.item()` on `indices` / `mc_labels` presumably assumes one
# example per batch — TODO confirm the validation batch size.
# NOTE(review): the confusion matrix is created with `.cuda()` and a fixed
# 6x6 size regardless of config.device / the number of classes — verify.
# NOTE(review): the optimizer is built but no optimizer step is visible.
def train(): config_file = "configs/train_daily_dialog_full_pipeline_config.json" config = Config.from_json_file(config_file) # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes logging.basicConfig( level=logging.INFO if config.local_rank in [-1, 0] else logging.WARN) logger.warning( "Running process %d", config.local_rank ) # This is a logger.warning: it will be printed by all distributed processes logger.info("Arguments: %s", pformat(config)) # Initialize distributed training if needed config.distributed = (config.local_rank != -1) if config.distributed: torch.cuda.set_device(config.local_rank) config.device = torch.device("cuda", config.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') model_checkpoint = "/home/rohola/codes/transfer-learning-conv-ai/logs/emotion_detection_log/" tokenizer_class = OpenAIGPTTokenizer tokenizer = tokenizer_class.from_pretrained(model_checkpoint) model_class = OpenAIGPTForEmotionDetection emotion_detection_model = model_class.from_pretrained(model_checkpoint) tokenizer.set_special_tokens(SPECIAL_TOKENS) emotion_detection_model.set_num_special_tokens(len(SPECIAL_TOKENS)) emotion_detection_model.to(config.device) logger.info( "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning" ) tokenizer_class = GPT2Tokenizer if "gpt2" in config.model_checkpoint else OpenAIGPTTokenizer tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint) model_class = OpenAIGPTDoubleHeadLMEmotionRecognitionModel emotion_recognition_model = model_class.from_pretrained( config.model_checkpoint) tokenizer.set_special_tokens(SPECIAL_TOKENS) emotion_recognition_model.set_num_special_tokens(len(SPECIAL_TOKENS)) emotion_recognition_model.to(config.device) optimizer = OpenAIAdam(emotion_recognition_model.parameters(), lr=config.lr) # Prepare model for FP16 and distributed training if needed 
(order is important, distributed should be the last) if config.fp16: from apex import amp # Apex is only required if we use fp16 training emotion_recognition_model, optimizer = amp.initialize( emotion_recognition_model, optimizer, opt_level=config.fp16) if config.distributed: emotion_recognition_model = DistributedDataParallel( emotion_recognition_model, device_ids=[config.local_rank], output_device=config.local_rank) logger.info("Prepare datasets") train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders( config, tokenizer) emotion_detection_model.eval() n_emotions = 0 num_correct = 0 all_predicted_positives = 0 all_true_positives = 0 all_actual_positives = 0 confusion_matrix = torch.zeros(6, 6, dtype=torch.float).cuda() num_all = len(val_loader) for batch in val_loader: with torch.no_grad(): batch = tuple( input_tensor.to(config.device) for input_tensor in batch) input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids = batch model_outputs = emotion_detection_model( input_ids, mc_token_ids, token_type_ids=token_type_ids) lm_logits, mc_logits = model_outputs[0], model_outputs[ 1] # So we can also use GPT2 outputs indices = torch.argmax(mc_logits, dim=1) if indices.item() != 0: #have emotion recognition_output = emotion_recognition_model( input_ids, mc_token_ids, token_type_ids=token_type_ids, token_emotion_ids=token_emotion_ids) if mc_labels.item() != 0: mc_labels = mc_labels - 1 else: continue #mc_labels = torch.randint(0, 6, size=(1,)).cuda() mc_recognition_logit = recognition_output[1] indices = torch.argmax(mc_recognition_logit, dim=1) correct = torch.eq(indices, mc_labels).view(-1) num_correct += torch.sum(correct).item() n_emotions += 1 #precision num_classes = mc_recognition_logit.size(1) print(mc_labels) mc_labels = to_onehot(mc_labels.view(-1), num_classes=num_classes) indices = torch.argmax(mc_recognition_logit, dim=1).view(-1) mc_recognition_logit = to_onehot(indices, num_classes=num_classes) mc_labels = 
mc_labels.type_as(mc_recognition_logit) correct = mc_labels * mc_recognition_logit all_positives = mc_recognition_logit.sum(dim=0).type( torch.DoubleTensor ) # Convert from int cuda/cpu to double cpu if correct.sum() == 0: true_positives = torch.zeros_like(all_positives) else: true_positives = correct.sum(dim=0) true_positives = true_positives.type(torch.DoubleTensor) all_predicted_positives += all_positives all_true_positives += true_positives #recall actual_positives = mc_labels.sum(dim=0).type( torch.DoubleTensor) all_actual_positives += actual_positives #confusion matrix mc_labels_t = mc_labels.transpose(0, 1).float() mc_recognition_logit = mc_recognition_logit.float() confusion_matrix += torch.matmul(mc_labels_t, mc_recognition_logit).float() print(num_correct / n_emotions) # accuracy for all classes of emotion print(n_emotions / num_all) print(all_true_positives / all_predicted_positives) print(all_true_positives / all_actual_positives) print(confusion_matrix)
def test_to_onehot():
    """Encoding the indices 0..3 with 4 classes must equal the 4x4 identity."""
    expected = torch.eye(4)
    actual = to_onehot(torch.LongTensor([0, 1, 2, 3]), 4)
    assert actual.equal(expected)
# forward(): encode token ids, pool the encoder output, classify, and, in
# training mode, additionally produce mixup logits and mixed labels.
#
# Steps visible in the code below:
#   1. `inputs` is transposed; positions with id > 0 form the mask and the
#      per-sequence lengths.
#   2. The encoder output is pooled once per entry of self.summary_type
#      ('max', 'mean', 'first', 'last', 'struct_att', 'none'); several pooled
#      values are concatenated along the last dim.
#   3. The pooled vector is normalised and classified into `logits`.
#   4. When self.training is true, the batch is shuffled with a random
#      permutation and hidden states / labels are interpolated using
#      self.mixup_type ('mixup' or 'prior_mix'; anything else raises
#      Exception). Returns (logits, mix_logits, mix_labels).
#   5. Otherwise returns (logits, hidden).
#
# NOTE(review): the original indentation is lost in this view; in particular
# it cannot be determined whether `lam_y = lam_y.unsqueeze(-1)` belongs to the
# multi-label `else` branch or runs unconditionally — confirm against the
# formatted file.
# NOTE(review): the 'none' pooling keeps the un-pooled encoder output, while
# `bsz, hid_dim = hidden.size()` unpacks exactly two dims — presumably 'none'
# is not used together with this head; verify.
# NOTE(review): `labels` defaults to None but is indexed in the training
# branch — callers presumably always pass labels while training; confirm.
def forward(self, inputs, labels=None): """ :param inputs: [bsz, max_seq_leng] :param labels: [bsz, num_class] :return: """ inputs = inputs.t() mask = (inputs > 0).float() inputs_len = (inputs > 0).int().sum(dim=0) hidden = self.encoder(inputs, mask, inputs_len) pool_values = [] for pool in self.summary_type: if pool == 'max': val = max_pooling(hidden, mask) pool_values.append(val) elif pool == 'mean': val = mean_pooling(hidden, inputs_len, mask) pool_values.append(val) elif pool == 'first': seq_len, bsz, dim = hidden.size() val = hidden[0, :, :].view(bsz, -1).contiguous() pool_values.append(val) elif pool == 'last': seq_len, bsz, dim = hidden.size() val = hidden[-1, :, :].view(bsz, -1).contiguous() pool_values.append(val) elif pool == 'struct_att': val, att = self.strut_att(hidden, mask) bsz, head_num, dim = val.size() val = val.contiguous().view(bsz, -1) pool_values.append(val) elif pool == 'none': pool_values.append(hidden) if len(self.summary_type) == 1: hidden = pool_values[0] else: hidden = torch.cat(pool_values, dim=-1).contiguous() # [bsz, hid_dim] bsz, hid_dim = hidden.size() # logits = self.cls(self.dropout(hidden)) hidden = self.normalize(hidden) logits = self.cls(hidden) if self.training: # Mixup indices = torch.randperm(bsz, device=logits.device) shuf_labels = torch.index_select(labels, 0, indices) shuf_hidden = torch.index_select(hidden, 0, indices) if self.mixup_type == 'mixup': lam = self.beta_dist.sample(sample_shape=(bsz, 1)) lam = lam.to(inputs.device) lam_x, lam_y = lam, lam elif self.mixup_type == 'prior_mix': lam_x = self.beta_dist.sample(sample_shape=(bsz,)) lam_x = lam_x.to(inputs.device) lam_y = self.prior_mixup(labels, shuf_labels) lam_y = 2. 
* lam_x * lam_y / (lam_x + lam_y) else: raise Exception('Unsupported mixup type %s' % self.mixup_type) mix_hidden = lam_x * hidden + (1 - lam_x) * shuf_hidden if not self.multi_label: onehot_label = to_onehot(labels, self.num_class) onehot_shuf_label = to_onehot(shuf_labels, self.num_class) else: onehot_label = labels onehot_shuf_label = shuf_labels lam_y = lam_y.unsqueeze(-1) mix_labels = lam_y * onehot_label + (1 - lam_y) * onehot_shuf_label mix_logits = self.cls(mix_hidden) return logits, mix_logits, mix_labels return logits, hidden
def forward(self, input, target):
    """Mean squared error between ``input`` and the one-hot encoded ``target``.

    Args:
        input: score tensor; dimension 1 is used as the number of classes.
        target: class-index tensor, one-hot encoded via ``to_onehot``.

    Returns:
        The value of ``nn.functional.mse_loss(input, one_hot_target)``.
    """
    # Match the device and dtype of `input` rather than relying on a
    # module-level global device and on the integer dtype `to_onehot` may
    # return; mse_loss expects both operands to agree.
    target_onehot = to_onehot(target, num_classes=input.shape[1]).to(
        device=input.device, dtype=input.dtype)
    return nn.functional.mse_loss(input, target_onehot)
# train(): despite its name, the visible body only evaluates an emotion
# detection model.
#
# Steps visible in the code below:
#   1. Loads a Config from
#      "configs/train_daily_dialog_emotion_detection_config.json", configures
#      logging by local rank, and initialises NCCL distributed mode when
#      config.local_rank != -1.
#   2. Loads tokenizer and OpenAIGPTForEmotionDetection model from
#      config.model_checkpoint, adds the special tokens, builds an OpenAIAdam
#      optimizer, and optionally applies apex amp / DistributedDataParallel.
#   3. Iterates val_loader under torch.no_grad(), counts correct arg-max
#      predictions and accumulates per-class predicted positives and true
#      positives from one-hot encodings.
#   4. Prints num_correct / num_all, all_true_positives / positives, and
#      n_emotions.
#
# NOTE(review): the original indentation is lost in this view, so the exact
# nesting of the loop / `with` body and the position of the final prints
# cannot be determined here — confirm against the formatted file.
# NOTE(review): `num_all = len(val_loader)` is the number of batches, so
# `num_correct / num_all` is presumably an accuracy only when each batch
# holds a single example — TODO confirm the validation batch size.
# NOTE(review): `n_emotions` is initialised to 0 and never modified before
# being printed; the optimizer is built but no optimizer step is visible.
def train(): config_file = "configs/train_daily_dialog_emotion_detection_config.json" config = Config.from_json_file(config_file) # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes logging.basicConfig( level=logging.INFO if config.local_rank in [-1, 0] else logging.WARN) logger.warning("Running process %d", config.local_rank) logger.info("Arguments: %s", pformat(config)) # Initialize distributed training if needed config.distributed = (config.local_rank != -1) if config.distributed: torch.cuda.set_device(config.local_rank) config.device = torch.device("cuda", config.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') logger.info( "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning" ) tokenizer_class = GPT2Tokenizer if "gpt2" in config.model_checkpoint else OpenAIGPTTokenizer tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint) model_class = OpenAIGPTForEmotionDetection model = model_class.from_pretrained(config.model_checkpoint) tokenizer.set_special_tokens(SPECIAL_TOKENS) model.set_num_special_tokens(len(SPECIAL_TOKENS)) model.to(config.device) optimizer = OpenAIAdam(model.parameters(), lr=config.lr) # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last) if config.fp16: from apex import amp # Apex is only required if we use fp16 training model, optimizer = amp.initialize(model, optimizer, opt_level=config.fp16) if config.distributed: model = DistributedDataParallel(model, device_ids=[config.local_rank], output_device=config.local_rank) logger.info("Prepare datasets") train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders( config, tokenizer) model.eval() n_emotions = 0 num_correct = 0 positives = 0 all_true_positives = 0 num_all = len(val_loader) for batch in val_loader: with torch.no_grad(): batch = tuple( 
input_tensor.to(config.device) for input_tensor in batch) input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch # logger.info(tokenizer.decode(input_ids[0, -1, :].tolist())) model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids) lm_logits, mc_logits = model_outputs[0], model_outputs[ 1] # So we can also use GPT2 outputs indices = torch.argmax(mc_logits, dim=1) correct = torch.eq(indices, mc_labels).view(-1) num_correct += torch.sum(correct).item() num_classes = mc_logits.size(1) mc_labels = to_onehot(mc_labels.view(-1), num_classes=num_classes) indices = torch.argmax(mc_logits, dim=1).view(-1) mc_logits = to_onehot(indices, num_classes=num_classes) mc_labels = mc_labels.type_as(mc_logits) correct = mc_labels * mc_logits all_positives = mc_logits.sum(dim=0).type( torch.DoubleTensor) # Convert from int cuda/cpu to double cpu if correct.sum() == 0: true_positives = torch.zeros_like(all_positives) else: true_positives = correct.sum(dim=0) true_positives = true_positives.type(torch.DoubleTensor) positives += all_positives all_true_positives += true_positives print(num_correct / num_all) print(all_true_positives / positives) print(n_emotions)