def forward(self, sample_batch, targets, supp_dict):
    """
    Run the teacher and the student on one batch and return the total loss.

    :param sample_batch: batch handed to the teacher and student forward procedures.
    :param targets: targets passed to the forward procedures, ``extract_org_loss``
        and ``self.criterion``.
    :param supp_dict: supplementary dict forwarded to ``get_teacher_output``,
        the student forward procedure and ``extract_org_loss``.
    :return: whatever ``self.criterion`` returns for this batch.
    """
    teacher_outputs, teacher_io = self.get_teacher_output(sample_batch, targets, supp_dict=supp_dict)

    student_outputs = self.student_forward_proc(self.student_model, sample_batch, targets, supp_dict)
    student_io = extract_io_dict(self.student_io_dict, self.device)
    if isinstance(self.student_model, SpecialModule):
        self.student_model.post_forward(student_io)

    org_loss_dict = self.extract_org_loss(
        self.org_criterion,
        student_outputs,
        teacher_outputs,
        targets,
        uses_teacher_output=self.uses_teacher_output,
        supp_dict=supp_dict,
    )

    # Extract a second time and merge into the first extraction.
    # NOTE(review): presumably this picks up entries added to self.student_io_dict
    # by post_forward / extract_org_loss -- confirm against the helpers.
    late_student_io = extract_io_dict(self.student_io_dict, self.device)
    update_io_dict(student_io, late_student_io)

    io_by_role = {'teacher': teacher_io, 'student': student_io}
    return self.criterion(io_by_role, org_loss_dict, targets)
def get_teacher_output(self, sample_batch, targets, supp_dict):
    """
    Return the teacher's outputs and its extracted I/O dict for one batch.

    Cached values found under ``supp_dict['cached_data']`` are reused when present;
    otherwise the teacher forward procedure is run. When ``supp_dict['cache_file_path']``
    is a list or tuple, one cache file per (teacher output, path) pair is written
    with ``torch.save``.

    :param sample_batch: input batch; its ``device`` attribute is read when cached data is used.
    :param targets: targets passed to the teacher forward procedure.
    :param supp_dict: supplementary dict; the keys ``'cached_data'`` and
        ``'cache_file_path'`` are read here, and the dict itself is forwarded
        to the teacher forward procedure.
    :return: tuple ``(teacher_outputs, extracted_teacher_io_dict)``.
    """
    cached_data = supp_dict.get('cached_data', None)
    cache_file_paths = supp_dict.get('cache_file_path', None)
    writes_cache = isinstance(cache_file_paths, (list, tuple))
    teacher_outputs = None
    cached_extracted_teacher_output_dict = None
    # Use cached data if available
    if isinstance(cached_data, dict):
        device = sample_batch.device
        teacher_outputs = cached_data['teacher_outputs']
        cached_extracted_teacher_output_dict = cached_data['extracted_outputs']
        if device.type != 'cpu':
            teacher_outputs = change_device(teacher_outputs, device)
            cached_extracted_teacher_output_dict = change_device(cached_extracted_teacher_output_dict, device)

        # A non-updatable teacher needs nothing beyond the cached values
        if not self.teacher_updatable:
            return teacher_outputs, cached_extracted_teacher_output_dict

    if teacher_outputs is None:
        if self.teacher_updatable:
            teacher_outputs = self.teacher_forward_proc(self.teacher_model, sample_batch, targets, supp_dict)
        else:
            # No gradients are tracked for a teacher that is not updated
            with torch.no_grad():
                teacher_outputs = self.teacher_forward_proc(self.teacher_model, sample_batch, targets, supp_dict)

    if cached_extracted_teacher_output_dict is not None:
        if isinstance(self.teacher_model, SpecialModule) or \
                (check_if_wrapped(self.teacher_model) and isinstance(self.teacher_model.module, SpecialModule)):
            # Feed the cached entries back into the teacher I/O dict before post-processing
            self.teacher_io_dict.update(cached_extracted_teacher_output_dict)
            if isinstance(self.teacher_model, SpecialModule):
                self.teacher_model.post_forward(self.teacher_io_dict)

        extracted_teacher_io_dict = extract_io_dict(self.teacher_io_dict, self.device)
        return teacher_outputs, extracted_teacher_io_dict

    # Deep copy of teacher info dict if teacher special module contains trainable module(s).
    # The copy is only consumed by the cache-writing branch below, so it is taken only
    # when cache files will actually be written.
    teacher_io_dict4cache = copy.deepcopy(self.teacher_io_dict) \
        if self.teacher_updatable and writes_cache else None
    extracted_teacher_io_dict = extract_io_dict(self.teacher_io_dict, self.device)
    if isinstance(self.teacher_model, SpecialModule):
        self.teacher_model.post_forward(extracted_teacher_io_dict)

    update_io_dict(extracted_teacher_io_dict, extract_io_dict(self.teacher_io_dict, self.device))
    # Write cache files if output file paths (cache_file_paths) are given
    if writes_cache:
        if teacher_io_dict4cache is None:
            teacher_io_dict4cache = extracted_teacher_io_dict

        cpu_device = torch.device('cpu')
        # detach() first: numpy() refuses tensors that require grad, which is the
        # case when the teacher is updatable
        teacher_output_arrays = teacher_outputs.detach().cpu().numpy()
        for i, (teacher_output, cache_file_path) in enumerate(zip(teacher_output_arrays, cache_file_paths)):
            sub_dict = extract_sub_model_output_dict(teacher_io_dict4cache, i)
            sub_dict = tensor2numpy2tensor(sub_dict, cpu_device)
            cache_dict = {'teacher_outputs': torch.Tensor(teacher_output), 'extracted_outputs': sub_dict}
            make_parent_dirs(cache_file_path)
            torch.save(cache_dict, cache_file_path)
    return teacher_outputs, extracted_teacher_io_dict
def forward(self, sample_batch, targets, supp_dict):
    """
    Run the model on one batch and return the total loss (no teacher involved).

    :param sample_batch: batch handed to the model forward procedure.
    :param targets: targets passed to the forward procedure, ``extract_org_loss``
        and ``self.criterion``.
    :param supp_dict: supplementary dict forwarded to the forward procedure
        and ``extract_org_loss``.
    :return: whatever ``self.criterion`` returns for this batch.
    """
    model_outputs = self.model_forward_proc(self.model, sample_batch, targets, supp_dict)
    model_io = extract_io_dict(self.model_io_dict, self.device)
    if isinstance(self.model, SpecialModule):
        self.model.post_forward(model_io)

    # There is no teacher here, so None is passed in the teacher-output slot
    org_loss_dict = self.extract_org_loss(
        self.org_criterion,
        model_outputs,
        None,
        targets,
        uses_teacher_output=False,
        supp_dict=supp_dict,
    )

    # Extract a second time and merge into the first extraction.
    # NOTE(review): presumably this picks up entries added to self.model_io_dict
    # by post_forward / extract_org_loss -- confirm against the helpers.
    late_model_io = extract_io_dict(self.model_io_dict, self.device)
    update_io_dict(model_io, late_model_io)

    # The model's entries go under 'student'; the 'teacher' entry is left empty
    io_by_role = {'student': model_io, 'teacher': {}}
    return self.criterion(io_by_role, org_loss_dict, targets)