def calculate_loss(
    self, X_dict: Dict[str, Any], Y_dict: Dict[str, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], Dict[str, float]]:
    """Compute the per-task losses and how many examples contributed to each.

    Parameters
    ----------
    X_dict
        A dict of data fields
    Y_dict
        A dict from task names to label sets

    Returns
    -------
    Dict[str, torch.Tensor], Dict[str, float]
        A dict of losses by task name and seen examples by task name
    """
    loss_dict: Dict[str, torch.Tensor] = {}
    count_dict: Dict[str, float] = {}

    labels_to_tasks = self._get_labels_to_tasks(Y_dict.keys())
    outputs = self.forward(X_dict, task_names=labels_to_tasks.values())

    for label_name, task_name in labels_to_tasks.items():
        Y = Y_dict[label_name]

        # An example is "active" (participates in the loss) unless its
        # label is the ignore marker -1; for 2-D labels, a row is active
        # when any of its entries differs from -1.
        if Y.dim() == 1:
            active = Y.detach() != -1
        else:
            active = torch.any(Y.detach() != -1, dim=1)

        # Skip tasks with no active examples in this batch
        if not active.any():
            continue

        # Note: keyed by label_name, while task_name indexes model attributes
        count_dict[label_name] = active.sum().item()

        # Output of the final operation in this task's op sequence
        task_output = outputs[self.op_sequences[task_name][-1].name]

        # Drop inactive examples, which is only possible when the output
        # is a plain Tensor that supports boolean-mask indexing
        if isinstance(task_output, torch.Tensor) and not active.all():
            task_output = task_output[active]
            Y = Y[active]

        loss_dict[label_name] = self.loss_funcs[task_name](
            task_output, move_to_device(Y, self.config.device)
        )

    return loss_dict, count_dict
def forward(  # type: ignore
    self, X_dict: Dict[str, Any], task_names: Iterable[str]
) -> OutputDict:
    """Do a forward pass through the network for all specified tasks.

    Parameters
    ----------
    X_dict
        A dict of data fields
    task_names
        The names of the tasks to execute the forward pass for

    Returns
    -------
    OutputDict
        A dict mapping each operation name to its corresponding output

    Raises
    ------
    ValueError
        If a specified Operation failed to execute; the underlying
        exception (e.g. a TypeError from an invalid Operation input)
        is attached as the cause via exception chaining
    """
    X_dict_moved = move_to_device(X_dict, self.config.device)

    outputs: OutputDict = {"_input_": X_dict_moved}  # type: ignore

    # Call forward for each task, using cached result if available
    # Each op_sequence consists of one or more operations that are executed in order
    for task_name in task_names:
        op_sequence = self.op_sequences[task_name]

        for operation in op_sequence:
            # Operations shared between tasks are computed only once
            if operation.name in outputs:
                continue
            try:
                if operation.inputs:
                    # Feed the inputs the module requested in the requested order
                    inputs = []
                    for op_input in operation.inputs:
                        if isinstance(op_input, tuple):
                            # The output of the indicated operation has a dict
                            # of fields; extract the designated field by name
                            op_name, field_key = op_input
                            inputs.append(outputs[op_name][field_key])
                        else:
                            # The output of the indicated operation has only
                            # one field; use that as the input to the current op
                            op_name = op_input
                            inputs.append(outputs[op_name])

                    output = self.module_pool[operation.module_name].forward(*inputs)
                else:
                    # Feed the entire outputs dict for the module to pull from
                    output = self.module_pool[operation.module_name].forward(outputs)
            except Exception as e:
                # Chain the original exception so its traceback is preserved
                raise ValueError(
                    f"Unsuccessful operation {operation}: {repr(e)}.") from e
            outputs[operation.name] = output

    return outputs