def forward(self, sample_batch, targets, supp_dict):
        """Run teacher and student on one batch and return the total loss.

        The teacher side is obtained from ``self.get_teacher_output``; the
        student is run through ``self.student_forward_proc``. The student's
        I/O dict is extracted twice: once right after the forward pass and
        once after the original loss has been computed, with the second
        extraction merged into the first via ``update_io_dict``.

        Args:
            sample_batch: Batch handed to the teacher and student forward procs.
            targets: Targets forwarded to the forward procs and the criteria.
            supp_dict: Supplementary dict passed through to the forward procs
                and to ``self.extract_org_loss``.

        Returns:
            Whatever ``self.criterion`` returns for the teacher/student I/O
            dicts, the original loss dict and ``targets``.
        """
        teacher_result = self.get_teacher_output(
            sample_batch, targets, supp_dict=supp_dict)
        teacher_outputs, teacher_io = teacher_result

        student_outputs = self.student_forward_proc(
            self.student_model, sample_batch, targets, supp_dict)
        student_io = extract_io_dict(self.student_io_dict, self.device)
        if isinstance(self.student_model, SpecialModule):
            self.student_model.post_forward(student_io)

        org_loss_dict = self.extract_org_loss(
            self.org_criterion, student_outputs, teacher_outputs, targets,
            uses_teacher_output=self.uses_teacher_output, supp_dict=supp_dict)

        # Second extraction: merge whatever is in the student I/O dict after
        # post_forward / the original-loss computation into the first result.
        late_student_io = extract_io_dict(self.student_io_dict, self.device)
        update_io_dict(student_io, late_student_io)

        io_by_role = {'teacher': teacher_io, 'student': student_io}
        return self.criterion(io_by_role, org_loss_dict, targets)
# Example #2
0
    def get_teacher_output(self, sample_batch, targets, supp_dict):
        """Return the teacher's outputs and extracted I/O dict for one batch.

        Cached results in ``supp_dict['cached_data']`` are used when present;
        otherwise the teacher is run through ``self.teacher_forward_proc``
        (under ``torch.no_grad()`` unless ``self.teacher_updatable``). When
        ``supp_dict['cache_file_path']`` is a list or tuple, one cache file
        per entry is written with ``torch.save``.

        Args:
            sample_batch: Batch passed to the teacher forward proc. Its
                ``.device`` is read when cached data is used.
            targets: Targets passed to the teacher forward proc.
            supp_dict: Mapping queried with ``.get`` for ``'cached_data'``
                and ``'cache_file_path'``; also passed to the forward proc.

        Returns:
            Tuple ``(teacher_outputs, extracted_teacher_io_dict)``.
        """
        cached_data = supp_dict.get('cached_data', None)
        cache_file_paths = supp_dict.get('cache_file_path', None)
        writes_cache = isinstance(cache_file_paths, (list, tuple))
        teacher_outputs = None
        cached_extracted_teacher_output_dict = None
        # Use cached data if available
        if cached_data is not None and isinstance(cached_data, dict):
            device = sample_batch.device
            teacher_outputs = cached_data['teacher_outputs']
            cached_extracted_teacher_output_dict = cached_data['extracted_outputs']
            if device.type != 'cpu':
                teacher_outputs = change_device(teacher_outputs, device)
                cached_extracted_teacher_output_dict = change_device(cached_extracted_teacher_output_dict, device)
            # A frozen teacher needs nothing beyond the cached values.
            if not self.teacher_updatable:
                return teacher_outputs, cached_extracted_teacher_output_dict

        if teacher_outputs is None:
            if self.teacher_updatable:
                teacher_outputs = self.teacher_forward_proc(self.teacher_model, sample_batch, targets, supp_dict)
            else:
                with torch.no_grad():
                    teacher_outputs = self.teacher_forward_proc(self.teacher_model, sample_batch, targets, supp_dict)

        if cached_extracted_teacher_output_dict is not None:
            # Updatable teacher with cached data: feed the cached I/O back into
            # the teacher I/O dict when the (possibly wrapped) teacher is a
            # SpecialModule, then extract from that dict.
            if isinstance(self.teacher_model, SpecialModule) or \
                    (check_if_wrapped(self.teacher_model) and isinstance(self.teacher_model.module, SpecialModule)):
                self.teacher_io_dict.update(cached_extracted_teacher_output_dict)
                if isinstance(self.teacher_model, SpecialModule):
                    self.teacher_model.post_forward(self.teacher_io_dict)

            extracted_teacher_io_dict = extract_io_dict(self.teacher_io_dict, self.device)
            return teacher_outputs, extracted_teacher_io_dict

        # Deep copy of teacher info dict if teacher special module contains trainable module(s).
        # The copy is only consumed by the cache-writing branch below, so it is
        # only made when cache files were actually requested. (The previous
        # condition compared isinstance(...) with `is not None`, which is always
        # true and forced a deep copy on every call for updatable teachers.)
        teacher_io_dict4cache = (
            copy.deepcopy(self.teacher_io_dict)
            if self.teacher_updatable and writes_cache else None
        )
        extracted_teacher_io_dict = extract_io_dict(self.teacher_io_dict, self.device)
        if isinstance(self.teacher_model, SpecialModule):
            self.teacher_model.post_forward(extracted_teacher_io_dict)

        update_io_dict(extracted_teacher_io_dict, extract_io_dict(self.teacher_io_dict, self.device))
        # Write cache files if output file paths (cache_file_paths) are given
        if writes_cache:
            if teacher_io_dict4cache is None:
                teacher_io_dict4cache = extracted_teacher_io_dict

            cpu_device = torch.device('cpu')
            # detach() first: Tensor.numpy() refuses tensors that require grad,
            # which can be the case when the teacher ran outside torch.no_grad().
            teacher_outputs_np = teacher_outputs.detach().cpu().numpy()
            for i, (teacher_output, cache_file_path) in enumerate(zip(teacher_outputs_np, cache_file_paths)):
                sub_dict = extract_sub_model_output_dict(teacher_io_dict4cache, i)
                sub_dict = tensor2numpy2tensor(sub_dict, cpu_device)
                cache_dict = {'teacher_outputs': torch.Tensor(teacher_output), 'extracted_outputs': sub_dict}
                make_parent_dirs(cache_file_path)
                torch.save(cache_dict, cache_file_path)
        return teacher_outputs, extracted_teacher_io_dict
    def forward(self, sample_batch, targets, supp_dict):
        """Run the single model on one batch and return the total loss.

        No teacher is involved: the original loss is computed with
        ``teacher_outputs=None`` and ``uses_teacher_output=False``, and the
        criterion receives an empty dict under the ``'teacher'`` key.

        Args:
            sample_batch: Batch handed to ``self.model_forward_proc``.
            targets: Targets forwarded to the forward proc and the criteria.
            supp_dict: Supplementary dict passed to the forward proc and to
                ``self.extract_org_loss``.

        Returns:
            Whatever ``self.criterion`` returns.
        """
        outputs = self.model_forward_proc(self.model, sample_batch, targets, supp_dict)
        model_io = extract_io_dict(self.model_io_dict, self.device)
        if isinstance(self.model, SpecialModule):
            self.model.post_forward(model_io)

        org_loss_dict = self.extract_org_loss(
            self.org_criterion, outputs, None, targets,
            uses_teacher_output=False, supp_dict=supp_dict,
        )

        # Merge whatever is in the model I/O dict after the steps above into
        # the first extraction.
        late_model_io = extract_io_dict(self.model_io_dict, self.device)
        update_io_dict(model_io, late_model_io)

        io_by_role = {'student': model_io, 'teacher': {}}
        return self.criterion(io_by_role, org_loss_dict, targets)