def forward(self, original_batch_input, resize_batch_input,
                org_ball_pos_xy, global_ball_pos_xy, event_class, target_seg):
        """Run the wrapped model and fold the task losses into one scalar.

        The global ball loss is always computed. The local ball, event and
        segmentation losses are added only when the model returns a
        prediction (not ``None``) for that task. Each active task consumes
        the next entry of ``self.log_vars``, in the order global ball,
        local ball, events, segmentation.

        NOTE(review): the two ball terms are weighted as
        ``loss / exp(2 * s) + s`` while the event and segmentation terms use
        ``loss / (2 * exp(s)) + s``. The asymmetry is preserved as-is here;
        confirm it is intentional.

        Args:
            original_batch_input: forwarded unchanged to ``self.model``.
            resize_batch_input: forwarded unchanged to ``self.model``.
            org_ball_pos_xy: forwarded unchanged to ``self.model``.
            global_ball_pos_xy: per-sample positions handed to
                ``create_target_ball`` for the global ball target.
            event_class: per-sample values handed to
                ``create_target_events``.
            target_seg: target passed to ``self.seg_loss_criterion``.

        Returns:
            Tuple ``(pred_ball_global, pred_ball_local, pred_events,
            pred_seg, local_ball_pos_xy, total_loss, log_vars)`` where
            ``log_vars`` is ``self.log_vars`` converted to a Python list.
        """
        (pred_ball_global, pred_ball_local, pred_events, pred_seg,
         local_ball_pos_xy) = self.model(original_batch_input,
                                         resize_batch_input, org_ball_pos_xy)
        num_samples = pred_ball_global.size(0)

        def build_ball_target(pred_map, positions_xy):
            # Target has the same shape/dtype/device as the prediction and is
            # filled one sample at a time.
            target_map = torch.zeros_like(pred_map)
            for idx in range(num_samples):
                target_map[idx] = create_target_ball(
                    positions_xy[idx],
                    sigma=self.sigma,
                    w=self.w,
                    h=self.h,
                    thresh_mask=self.thresh_ball_pos_mask,
                    device=self.device)
            return target_map

        # Position in self.log_vars of the task currently being weighted.
        slot = 0

        # Global ball position: always present.
        global_target = build_ball_target(pred_ball_global, global_ball_pos_xy)
        global_ball_loss = self.ball_loss_criterion(pred_ball_global,
                                                    global_target)
        log_var = self.log_vars[slot]
        total_loss = global_ball_loss / torch.exp(2 * log_var) + log_var

        # Local ball position (optional).
        if pred_ball_local is not None:
            slot += 1
            local_target = build_ball_target(pred_ball_local,
                                             local_ball_pos_xy)
            local_ball_loss = self.ball_loss_criterion(pred_ball_local,
                                                       local_target)
            log_var = self.log_vars[slot]
            total_loss += local_ball_loss / torch.exp(2 * log_var) + log_var

        # Event spotting (optional); the target has two columns per sample.
        if pred_events is not None:
            slot += 1
            events_target = torch.zeros((num_samples, 2), device=self.device)
            for idx in range(num_samples):
                events_target[idx] = create_target_events(event_class[idx],
                                                          device=self.device)
            event_loss = self.event_loss_criterion(pred_events, events_target)
            log_var = self.log_vars[slot]
            total_loss += event_loss / (2 * torch.exp(log_var)) + log_var

        # Segmentation (optional); the target is supplied by the caller.
        if pred_seg is not None:
            slot += 1
            seg_loss = self.seg_loss_criterion(pred_seg, target_seg)
            log_var = self.log_vars[slot]
            total_loss += seg_loss / (2 * torch.exp(log_var)) + log_var

        # Final weights: [math.exp(log_var) ** 0.5 for log_var in log_vars]

        log_vars_list = self.log_vars.data.tolist()
        return (pred_ball_global, pred_ball_local, pred_events, pred_seg,
                local_ball_pos_xy, total_loss, log_vars_list)
    def forward(self, original_batch_input, resize_batch_input,
                org_ball_pos_xy, global_ball_pos_xy, event_class, target_seg):
        """Run the wrapped model and sum the task losses with fixed weights.

        The global ball loss is always computed. The local ball, event and
        segmentation losses are added only when the model returns a
        prediction (not ``None``) for that task. Each active task is scaled
        by the next entry of ``self.tasks_loss_weight``, in the order global
        ball, local ball, events, segmentation.

        Args:
            original_batch_input: forwarded unchanged to ``self.model``.
            resize_batch_input: forwarded unchanged to ``self.model``.
            org_ball_pos_xy: forwarded unchanged to ``self.model``.
            global_ball_pos_xy: per-sample positions handed to
                ``create_target_ball`` for the global ball target.
            event_class: per-sample values handed to
                ``create_target_events``.
            target_seg: target passed to ``self.seg_loss_criterion``.

        Returns:
            Tuple ``(pred_ball_global, pred_ball_local, pred_events,
            pred_seg, local_ball_pos_xy, total_loss, None)``. The last slot
            is always ``None``.
        """
        (pred_ball_global, pred_ball_local, pred_events, pred_seg,
         local_ball_pos_xy) = self.model(original_batch_input,
                                         resize_batch_input, org_ball_pos_xy)
        num_samples = pred_ball_global.size(0)

        def build_ball_target(pred_map, positions_xy):
            # Target has the same shape/dtype/device as the prediction and is
            # filled one sample at a time.
            target_map = torch.zeros_like(pred_map)
            for idx in range(num_samples):
                target_map[idx] = create_target_ball(
                    positions_xy[idx],
                    sigma=self.sigma,
                    w=self.w,
                    h=self.h,
                    thresh_mask=self.thresh_ball_pos_mask,
                    device=self.device)
            return target_map

        # Global ball position: always present.
        global_target = build_ball_target(pred_ball_global, global_ball_pos_xy)
        # Position in self.tasks_loss_weight of the task being weighted.
        slot = 0
        global_ball_loss = self.ball_loss_criterion(pred_ball_global,
                                                    global_target)
        total_loss = global_ball_loss * self.tasks_loss_weight[slot]

        # Local ball position (optional).
        if pred_ball_local is not None:
            slot += 1
            local_target = build_ball_target(pred_ball_local,
                                             local_ball_pos_xy)
            local_ball_loss = self.ball_loss_criterion(pred_ball_local,
                                                       local_target)
            total_loss += local_ball_loss * self.tasks_loss_weight[slot]

        # Event spotting (optional); the target has two columns per sample.
        if pred_events is not None:
            slot += 1
            events_target = torch.zeros((num_samples, 2), device=self.device)
            for idx in range(num_samples):
                events_target[idx] = create_target_events(event_class[idx],
                                                          device=self.device)
            event_loss = self.event_loss_criterion(pred_events, events_target)
            total_loss += event_loss * self.tasks_loss_weight[slot]

        # Segmentation (optional); the target is supplied by the caller.
        if pred_seg is not None:
            slot += 1
            seg_loss = self.seg_loss_criterion(pred_seg, target_seg)
            total_loss += seg_loss * self.tasks_loss_weight[slot]

        return (pred_ball_global, pred_ball_local, pred_events, pred_seg,
                local_ball_pos_xy, total_loss, None)