Example #1
    def forward(self, outputs, targets):
        res = {}
        value = any_value(outputs)
        device_n_key = '_device|overall|_n'
        if device_n_key in outputs:
            bs = int(outputs[device_n_key].sum().item())
        else:
            bs = len(value)
        overall_loss_items = torch.zeros(bs, dtype=value.dtype)
        for path, loss in self._all_losses.items():
            loss_items = loss(outputs, targets).loss_items
            res[path] = loss_items.squeeze(dim=-1)
            device_key = f'_device|indices|{path}|loss_per_sample'
            if device_key in self.ctx:
                indices = self.ctx[device_key].detach().cpu()
                valid_indices_mask = indices >= 0
                res[f'indices|{path}'] = indices[valid_indices_mask]
                res[path] = res[path][valid_indices_mask]
                overall_loss_items[indices[valid_indices_mask]] += res[path]
            else:
                indices = torch.arange(bs, device=value.device)
                precondition = outputs[f'precondition|{path}']
                axes = tuple(range(1, len(precondition.shape)))
                if len(axes) > 0:
                    precondition = precondition.sum(dim=axes) > 0
                if precondition.sum() == 0:
                    # Since the metric decorator will add a 0 term to the loss, we have
                    # to mark the indices as well; the -1 sentinel is later filtered out
                    # by the `indices >= 0` mask in the branch above.
                    res[f'indices|{path}'] = torch.ones(1, dtype=indices.dtype, device=indices.device) * -1
                else:
                    res[f'indices|{path}'] = indices[precondition]
                    overall_loss_items[precondition] += res[path].cpu()
        res['overall'] = overall_loss_items
        res['indices|overall'] = torch.arange(bs, device=value.device)
        return res
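
These snippets repeatedly call `any_value` to grab an arbitrary tensor from the outputs dict, so the batch size, dtype, and device can be read off it. Its definition is not shown in these examples; what follows is only a guessed sketch, assuming a flat str-to-tensor dict whose bookkeeping entries use the `_device|` and `precondition|` key prefixes seen above.

import torch


def any_value(outputs):
    # Return the first real tensor, skipping reserved bookkeeping entries.
    for key, value in outputs.items():
        if key.startswith(('_device|', 'precondition|')):
            continue
        if torch.is_tensor(value):
            return value
    raise ValueError('outputs contains no tensor values')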
Example #2
    def __call__(self, x):
        if not isinstance(x, Results):
            x = Results(x)

        if x.idx is None:
            n = len(any_value(x.module_output.logits))
            tree = Tree()
            batch_node = tree.create_node(tag='batch', identifier='batch')
            for i in range(n):
                inp_tree = Tree()
                inp_node = inp_tree.create_node(tag=f'inp {i}', identifier=f'inp_{i}.{self.prefix}')
                x_for_id = Results(x.module_output, i)
                out = TreeExplanation(inp_tree, inp_node, x_for_id, f'inp_{i}.')
                out = self.flow(self, x_for_id, out)
                tree.paste(batch_node.identifier, out.tree)

            if len(self.prefix) == 0:
                return tree
            return TreeExplanation(tree, start_node=batch_node, results=x, prefix=self.prefix)

        tree = Tree()
        start_node = tree.create_node(tag=self.task_name, identifier=f'inp_{x.idx}.{self.prefix}.{self.task_name}')
        out = TreeExplanation(tree, start_node=start_node, results=x, prefix=f'inp_{x.idx}.{self.prefix}')
        out = self.flow(self, x, out)

        no_new_nodes_added = (len(out.tree.nodes) == 1) and (out.start_node.identifier in out.tree)
        # if all the nodes of a nested flow are empty, remove the whole flow node.
        if no_new_nodes_added:
            out.tree = Tree()
        if len(self.prefix) == 0:
            return out.tree
        return out
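
The batch branch above builds one subtree per input and grafts it under a shared root via treelib's `paste`. A self-contained toy version of that pattern, independent of the `Results`/`TreeExplanation` wrappers:

from treelib import Tree

batch_tree = Tree()
batch_node = batch_tree.create_node(tag='batch', identifier='batch')

for i in range(2):
    inp_tree = Tree()
    inp_tree.create_node(tag=f'inp {i}', identifier=f'inp_{i}')
    # The flow would attach per-task explanation nodes here.
    inp_tree.create_node(tag='task', identifier=f'inp_{i}.task', parent=f'inp_{i}')
    # Graft the per-input subtree under the batch root.
    batch_tree.paste(batch_node.identifier, inp_tree)

batch_tree.show()  # batch -> inp 0 -> task, inp 1 -> task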
Example #3
    def forward(self, *args):
        """
        TaskFlowLoss can be invoked either by giving two arguments: (outputs, targets), or bby giving a single
        LossFlowData argument, which holds the outputs and the targets.
        :param args:
        :return:
        """
        if args_are_dicts(args):
            outputs, targets = args
        else:
            loss_flow_data = args[0]
            outputs = loss_flow_data.outputs
            targets = loss_flow_data.targets

        value = any_value(outputs)
        key = self.prefix + self.task_name
        result_from_device_reducing = self._device_reducing_cache(key, outputs)
        if result_from_device_reducing is not None:
            return result_from_device_reducing
        loss_items = torch.zeros(1, dtype=value.dtype, device=value.device)
        flow_result = self.flow(self, LossFlowData(outputs, targets), LossItems(loss_items))

        is_root = self.prefix == ''
        if not is_root:
            return LossItems(flow_result.loss_items)

        return flow_result.loss_items
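
`args_are_dicts` is not shown in these examples; presumably it distinguishes the two-argument `(outputs, targets)` call from the single `LossFlowData` call. A plausible sketch, offered only as an assumption:

def args_are_dicts(args):
    # True for a root-style call forward(outputs, targets) with two dicts,
    # False for a nested call forward(loss_flow_data).
    return len(args) == 2 and all(isinstance(arg, dict) for arg in args)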
Example #4
def reduce_on_device(criterion, per_sample_criterion, leaf_criterions, metrics,
                     outputs, targets, r_device_metrics, r_leaf_losses,
                     r_per_sample_losses):
    loss = criterion(outputs, targets)
    any_tensor = any_value(targets)
    n = len(any_tensor)
    criterion_n = torch.tensor(n,
                               dtype=any_tensor.dtype,
                               device=any_tensor.device)
    reduced_with_grad = {
        f'_device|{criterion.prefix}{criterion.task_name}|loss': loss,
        f'_device|{criterion.prefix}{criterion.task_name}|_n': criterion_n,
        '_device|overall|loss': loss,
        '_device|overall|_n': criterion_n
    }
    reduced = {}

    with torch.no_grad():
        if r_device_metrics:
            compute_device_metrics(reduced, any_tensor, metrics, outputs,
                                   targets)
        if r_leaf_losses:
            compute_leaf_losses(leaf_criterions, outputs, reduced, targets)
        if r_per_sample_losses:
            compute_per_sample_losses(reduced, per_sample_criterion, outputs,
                                      targets, n)

    return reduced_with_grad, reduced
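
The `_device|...|_n` entries record the per-device batch size so that the true batch size can be recovered after per-device results are gathered (e.g. by `DataParallel`); Example #1 consumes them via `outputs[device_n_key].sum().item()`. A toy illustration of that round trip:

import torch

# After gathering, each device's scalar count ends up stacked in one tensor.
gathered_counts = torch.tensor([16.0, 16.0])  # two devices, 16 samples each

# The consumer sums the per-device counts to recover the global batch size.
bs = int(gathered_counts.sum().item())
assert bs == 32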
Example #5
    def on_batch_end(self, state: State):
        if not isinstance(state.loaders[state.loader_name].sampler, SequentialSampler):
            return
        outputs = state.output['logits']
        targets = state.input['targets']
        overall_res = self.overall_loss(outputs, targets)
        start = self.loader_counts[state.loader_name]
        bs = len(any_value(outputs))

        for path, loss in overall_res.items():
            if path.startswith('indices'):
                continue
            self.interpretations[state.loader_name][path].append(loss.detach().cpu().numpy())
            ind_key = f'indices|{path}'
            indices = overall_res[ind_key] + start
            self.interpretations[state.loader_name][ind_key].append(indices.detach().cpu().numpy())
        self.loader_counts[state.loader_name] += bs
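
The `loader_counts` offset converts per-batch indices into dataset-wide positions, which is only correct when the loader visits samples in order (hence the `SequentialSampler` guard at the top). A standalone illustration of the offset bookkeeping:

import numpy as np

loader_count = 0
dataset_positions = []
for batch_indices in (np.arange(4), np.arange(4)):  # two batches of size 4
    # Shift local indices by the number of samples already consumed.
    dataset_positions.append(batch_indices + loader_count)
    loader_count += len(batch_indices)

# Sequential iteration yields positions 0..7 over the whole dataset.
assert np.concatenate(dataset_positions).tolist() == list(range(8))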
Example #6
    def forward(self, outputs, targets):
        res = {}
        value = any_value(outputs)
        bs = len(value)
        overall_loss_items = torch.zeros(bs,
                                         device=value.device,
                                         dtype=value.dtype)
        for path, loss in self._all_losses.items():
            loss_items = loss(outputs, targets).loss_items
            res[path] = loss_items.squeeze(dim=-1)
            indices = torch.arange(bs, device=value.device)
            precondition = outputs[f'precondition|{path}']
            axes = tuple(range(1, len(precondition.shape)))
            if len(axes) > 0:
                precondition = precondition.sum(dim=axes) > 0
            res[f'indices|{path}'] = indices[precondition]
            overall_loss_items[precondition] += res[path]
        res['overall'] = overall_loss_items
        res['indices|overall'] = torch.arange(bs, device=value.device)
        return res
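
The precondition mask restricts a task's contribution to the samples where the task actually applies; the leaf loss is assumed here to return one loss value per applicable sample. A standalone sketch of the accumulation:

import torch

bs = 4
indices = torch.arange(bs)
precondition = torch.tensor([True, False, True, False])

# One loss value per applicable sample (length == precondition.sum()).
task_loss = torch.tensor([0.7, 0.3])

overall = torch.zeros(bs)
overall[precondition] += task_loss          # accumulate only where the task applies
applicable = indices[precondition]          # tensor([0, 2]), kept for bookkeeping

assert torch.allclose(overall, torch.tensor([0.7, 0.0, 0.3, 0.0]))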
Example #7
    def forward(self, *args):
        """
        TaskFlowLoss can be invoked either by giving two arguments: (outputs, targets), or bby giving a single
        LossFlowData argument, which holds the outputs and the targets.
        :param args:
        :return:
        """
        is_root = len(args) == 2
        if is_root:
            outputs, targets = args
        else:
            loss_flow_data = args[0]
            outputs = loss_flow_data.outputs
            targets = loss_flow_data.targets

        value = any_value(outputs)
        loss_items = torch.zeros(1, dtype=value.dtype, device=value.device)
        flow_result = self.flow(self, LossFlowData(outputs, targets),
                                LossItems(loss_items))

        if not is_root:
            return LossItems(flow_result.loss_items)

        return flow_result.loss_items
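
Unlike Example #3, which dispatches via `args_are_dicts`, this variant keys the root check purely on the argument count. A minimal sketch of that dispatch with a stand-in `LossFlowData` (the real class lives in the library):

from dataclasses import dataclass


@dataclass
class LossFlowData:
    outputs: dict
    targets: dict


def unpack(*args):
    # Root invocation: forward(outputs, targets); nested: forward(LossFlowData(...)).
    is_root = len(args) == 2
    if is_root:
        outputs, targets = args
    else:
        outputs, targets = args[0].outputs, args[0].targets
    return is_root, outputs, targets


assert unpack({'a': 1}, {'b': 2})[0] is True
assert unpack(LossFlowData({'a': 1}, {'b': 2}))[0] is False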