def _compute_region_loss(self, task, attn_scores):
    """Computes the region-prior loss for `task` from raw attention scores.

    Penalizes each attention head for deviating from the task's target
    region, as specified in self.task_to_region_config.

    Args:
        task (str) the task whose region config should be applied
        attn_scores (torch.Tensor) attention scores of shape
            (batch_size, num_heads, L, H, W), or (batch_size, L, H, W)
            for a single attention head
    Returns:
        region_loss (torch.Tensor) the loss summed over all heads
    """
    # a single attention head: add a singleton head dimension
    if len(attn_scores.shape) == 4:
        attn_scores = attn_scores.unsqueeze(1)
    batch_size, num_heads, L, H, W = attn_scores.shape
    exam_dims = (L, H, W)

    region_config = self.task_to_region_config[task]
    region_loss_fn = region_config['loss_fn']

    # build the target region and broadcast it across the batch
    region_target = self._load_region(exam_dims, **region_config['region'])
    region_target = region_target.unsqueeze(0).expand(
        (batch_size,) + exam_dims).double()
    region_target = place_on_gpu(region_target, attn_scores.device)

    region_loss = 0.0
    for head in range(num_heads):
        head_scores = attn_scores[:, head]
        if region_config['loss_class'] == 'KLDivLoss':
            # KLDivLoss expects log-probabilities: normalize the scores
            # over all voxels with a log-softmax. Use reshape, not view:
            # the head slice is not contiguous when num_heads > 1.
            head_scores_flat = head_scores.reshape(batch_size, L * H * W)
            log_head_scores_flat = nn.functional.log_softmax(
                head_scores_flat, dim=-1)
            region_preds = log_head_scores_flat.view(batch_size, L, H, W)
        else:
            region_preds = head_scores
        region_loss += region_loss_fn(region_preds.double(), region_target)
    return place_on_gpu(region_loss.float(), attn_scores.device)
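# A minimal, standalone sketch of the KLDivLoss branch above, assuming a
# uniform region prior; the shapes, the prior, and reduction='batchmean'
# are illustrative assumptions, not this repo's actual region configs:
#
#     import torch
#     import torch.nn as nn
#
#     batch_size, L, H, W = 2, 4, 8, 8
#     attn = torch.randn(batch_size, L, H, W)        # raw attention scores
#     log_preds = nn.functional.log_softmax(
#         attn.view(batch_size, -1), dim=-1)         # log-probabilities
#     prior = torch.full((batch_size, L * H * W), 1.0 / (L * H * W))
#     loss = nn.KLDivLoss(reduction='batchmean')(log_preds, prior)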
def generate_gradients(self, inputs, targets, target_task=None,
                       token_idx=None, device=0):
    """Generates gradients through the scan for the first element in the
    batch specified by `inputs` and `targets`. If the batch size is
    greater than 1, only gradients for the first example are computed.
    Supports multi-task output via the `target_task` argument, and tasks
    with multiple outputs (as in masked language modeling or natural
    language generation) via `token_idx`.

    @inputs (dict or torch.Tensor) the 3D input scan; if dict, should
        contain the key "scan"
    @targets (dict or torch.Tensor) target tensor, or dict of target
        tensors keyed by task
    @target_task (None or str) required if targets is a dict
    @token_idx (None or int) optional index of a specific output token,
        for tasks with multiple outputs
    @device (int) the GPU device used for the forward and backward passes
    @grad (torch.Tensor) the gradient through the scan, on the CPU
    @scan (torch.Tensor) the scan itself, on the CPU
    """
    inputs = place_on_gpu(inputs, device=device)
    self.model = self.model.to(device=device)

    # require gradient on the scan so we can backpropagate through the pixels
    scan = inputs["scan"] if isinstance(inputs, dict) else inputs
    scan.requires_grad = True

    # forward pass
    output = self.model(inputs, targets)
    self.model.zero_grad()

    # backward pass: select the target class for the first example
    targets = targets[target_task] if isinstance(targets, dict) else targets
    target_class = (targets[0, token_idx]
                    if token_idx is not None else targets[0])
    # if not an index but a softmax distribution (probabilistic case),
    # take the most likely class
    if len(target_class.shape) > 0:
        target_class = target_class.argmax()

    output = output[target_task] if target_task is not None else output
    if isinstance(output, dict):
        output = output['out']
    output = output[0, token_idx] if token_idx is not None else output[0]
    output[target_class].backward()
    scan.requires_grad = False

    grad, scan = place_on_cpu([scan.grad, scan])

    # empty relu outputs so we don't leak CPU memory
    self.forward_relu_outputs = []
    return grad, scan
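# Usage sketch for a simple saliency map; `interpreter` (an instance of this
# class wrapping a trained model), the batch variables, and the task name
# "abnormality" are hypothetical placeholders:
#
#     grad, scan = interpreter.generate_gradients(
#         inputs=batch_inputs, targets=batch_targets,
#         target_task="abnormality", device=0)
#     saliency = grad.abs()        # voxel-wise gradient magnitude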
def forward(self, task, targets):
    """Computes per-example relevance weights for `task`.

    If no strategy function is registered for the task, every example in
    the batch is considered relevant; otherwise the strategy function
    derives relevance from the targets. The relevance mask is scaled by
    the task's loss weight.
    """
    if not self.task_to_strategy_fn[task]:
        batch_size, _ = targets[task].shape
        # no strategy: mark every example as relevant
        # (uint8 so the mask supports boolean/bitwise ops)
        relevance = torch.ones(batch_size, dtype=torch.uint8,
                               device=targets[task].device)
    else:
        relevance = self.task_to_strategy_fn[task](targets)
    relevance = relevance.to(targets[task].dtype) * self.task_to_weight[task]

    if targets[task].is_cuda:
        relevance = place_on_gpu(relevance, targets[task].device)
    return relevance
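# A sketch of what a strategy function registered in task_to_strategy_fn
# might look like; the "abnormality" task name and the rule (only examples
# with a positive label are relevant) are illustrative assumptions:
#
#     def abnormality_strategy(targets):
#         # relevant iff the example's label for the task is nonzero
#         return (targets["abnormality"].sum(dim=-1) > 0).to(torch.uint8)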
def compute_mt_attention(model, inputs, targets, task=None, device=0):
    """Runs a forward pass and collects the attention scores produced by
    each task head of a multi-task model. If `task` is given, returns the
    attention scores for that task only; otherwise returns a dict mapping
    each task to its attention scores.
    """
    inputs = place_on_gpu(inputs, device=device)
    model = model.to(device=device)
    model.eval()

    # ask every task head to expose its attention scores
    for task_head in model.decoder.task_heads.values():
        task_head.region_aware = True

    model_output = model(inputs, targets)

    if task is not None:
        return model_output[task]['attn_scores']
    else:
        return {
            task: output["attn_scores"]
            for task, output in model_output.items()
        }
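# Usage sketch; `mt_model` and the batch variables are assumed to come from
# the surrounding code, and the task name "abnormality" is a placeholder:
#
#     attn = compute_mt_attention(mt_model, batch_inputs, batch_targets,
#                                 task="abnormality", device=0)
#     print(attn.shape)   # e.g. (batch_size, num_heads, L, H, W)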
def compute_attention(model, attention_module, inputs, targets, device=0):
    """Runs a forward pass and returns the attention probabilities most
    recently recorded by `attention_module`.
    """
    inputs = place_on_gpu(inputs, device=device)
    model = model.to(device=device)
    model.eval()

    # require gradient on the scan so we can backpropagate through the pixels
    scan = inputs["scan"] if isinstance(inputs, dict) else inputs
    scan.requires_grad = True

    # forward pass with attention recording enabled
    attention_module.keep_attention = True
    model(inputs, targets)
    attention_module.keep_attention = False

    attention_probs = attention_module.attention_probs[-1]

    # clear the stored attention so it doesn't leak memory across calls
    attention_module.attention_probs = []
    return attention_probs
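# Usage sketch; `model` and the batch variables are assumed to come from the
# surrounding code, and `model.decoder.attention` is a hypothetical path to
# the attention module whose probabilities we want:
#
#     attn_probs = compute_attention(
#         model, model.decoder.attention, batch_inputs, batch_targets,
#         device=0)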
def score(self, dataloader, metric_configs=[], log_predictions=False):
    """Evaluates the model on the data yielded by `dataloader`, computing
    the metrics in `metric_configs` over all batches.
    """
    logging.info("Validation")
    self.eval()

    # move to cuda
    if self.cuda:
        self._to_gpu()

    metrics = Metrics(metric_configs)
    with tqdm(total=len(dataloader)) as t, torch.no_grad():
        for i, (inputs, targets, info) in enumerate(dataloader):
            # move to GPU if available
            if self.cuda:
                inputs, targets = place_on_gpu([inputs, targets],
                                               self.device)

            # forward pass
            predictions = self.predict(inputs)

            if log_predictions:
                self._log_predictions(
                    inputs=inputs,
                    targets=targets,
                    predictions=predictions,
                    info=info,
                )

            labels = self._get_labels(targets)
            metrics.add(predictions, labels, info)

            # update the progress bar
            t.update()

    metrics.compute()
    return metrics
def predict_many(self, dataloader):
    """Runs prediction over all batches in `dataloader`, yielding the
    inputs, labels, predictions, and info for each batch.
    """
    logging.info("Prediction")
    self.eval()

    # move to cuda
    if self.cuda:
        self._to_gpu()

    with tqdm(total=len(dataloader)) as t, torch.no_grad():
        for i, (inputs, labels, info) in enumerate(dataloader):
            # move to GPU if available
            if self.cuda:
                inputs = place_on_gpu(inputs, self.device)

            # forward pass
            predictions = self.predict(inputs)

            # update the progress bar
            t.update()

            yield inputs, labels, predictions, info
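# Usage sketch; `model` and `test_dataloader` are assumed to be constructed
# elsewhere in the repo:
#
#     for inputs, labels, predictions, info in model.predict_many(test_dataloader):
#         ...   # e.g. write predictions to disk, batch by batch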
def _train_epoch(
    self,
    dataloader,
    metric_configs=[],
    summary_period=1,
    writer=None,
    log_predictions=True,
):
    """Trains the model for one epoch.

    Args:
        dataloader (DataLoader) yields (inputs, targets, info) batches
        metric_configs (list) configurations for the metrics to compute
        summary_period (int) compute metrics every `summary_period` batches
        writer tensorboard-style writer for logging the loss
        log_predictions (bool) whether to log predictions with the metrics
    """
    logging.info("Training")
    self.train()

    metrics = Metrics(metric_configs)
    avg_loss = 0
    with tqdm(total=len(dataloader)) as t:
        for i, (inputs, targets, info) in enumerate(dataloader):
            if self.cuda:
                inputs, targets = place_on_gpu([inputs, targets],
                                               self.device)

            # forward pass
            outputs = self.forward(inputs, targets)

            # loss for dynamic dataloader
            if hasattr(dataloader, "get_loss_weights"):
                loss = self.loss(outputs, targets,
                                 dataloader.get_loss_weights(targets))
            else:
                loss = self.loss(outputs, targets)

            # backward pass
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            loss = loss.cpu().detach().numpy()

            # compute metrics periodically
            if i % summary_period == 0:
                predictions = self.predict(inputs)
                if log_predictions:
                    self._log_predictions(inputs, targets, predictions,
                                          info)
                labels = self._get_labels(targets)
                metrics.add(predictions, labels, info, {"loss": loss})
                # `labels` and `predictions` only exist on summary
                # iterations, so free them here rather than below
                del predictions, labels

            # update dynamic dataloader
            if hasattr(dataloader, "update_batch"):
                dataloader.update_batch([item["idx"] for item in info])

            # compute average loss and update the progress bar
            avg_loss = ((avg_loss * i) + loss) / (i + 1)
            if writer is not None:
                writer.add_scalar(tag="loss", scalar_value=loss)
            t.set_postfix(loss="{:05.3f}".format(float(avg_loss)))
            t.update()

            del loss, outputs, inputs, targets

    metrics.compute()
    return metrics
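# A sketch of how an epoch might be driven, assuming hypothetical
# `train_dataloader` / `val_dataloader` objects and a tensorboard
# SummaryWriter; the metric config format is a placeholder, not this
# repo's actual schema:
#
#     from torch.utils.tensorboard import SummaryWriter
#
#     writer = SummaryWriter(log_dir="logs")
#     for epoch in range(num_epochs):
#         train_metrics = model._train_epoch(train_dataloader,
#                                            writer=writer)
#         val_metrics = model.score(val_dataloader)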