def find_the_best_branch(self, dataloader):
    # Evaluate every candidate branching point and pick, for each task,
    # the decoder block whose outputs best match the corresponding teacher.
    with set_mode(self.student, training=False), \
         set_mode(self.teachers, training=False), \
         torch.no_grad():
        n_blocks = len(self.student.student_decoders)
        branch_loss = {task: [0 for _ in range(n_blocks)] for task in self.tasks}
        for batch in dataloader:
            batch = move_to_device(batch, self.device)
            data = batch if isinstance(batch, torch.Tensor) else batch[0]
            for candidate_branch in range(n_blocks):
                self.student.set_branch([candidate_branch for _ in range(len(self.teachers))])
                s_out_list = self.student(data)
                t_out_list = [teacher(data) for teacher in self.teachers]
                for task, s_out, t_out in zip(self.tasks, s_out_list, t_out_list):
                    task_loss = task.get_loss(s_out, t_out)
                    branch_loss[task][candidate_branch] += sum(task_loss.values())
        # For each task, choose the branch with the lowest accumulated loss.
        best_branch = [int(np.argmin(branch_loss[task])) for task in self.tasks]
        self.student.set_branch(best_branch)
        return best_branch
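# set_mode(...) is used throughout these snippets to switch modules between train
# and eval mode and restore the previous mode on exit. Below is a minimal sketch of
# such a context manager, assuming it only needs to toggle .training and accept
# either a single module or a list of modules; the library's own helper may do more.
from contextlib import contextmanager

@contextmanager
def set_mode_sketch(models, training=True):
    if not isinstance(models, (list, tuple)):
        models = [models]
    previous = [m.training for m in models]
    try:
        for m in models:
            m.train(training)
        yield
    finally:
        # restore each module to the mode it was in before entering the block
        for m, prev in zip(models, previous):
            m.train(prev)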
def run(self, max_iter, start_iter=0, epoch_length=None):
    # Gather the parameters of all amalgamation blocks and optimize them with
    # the same optimizer family (and learning rate) as the student optimizer.
    block_params = []
    for block, _, _ in self._amal_blocks:
        block_params.extend(list(block.parameters()))
    if isinstance(self._optimizer, torch.optim.SGD):
        self._amal_optimimizer = torch.optim.SGD(
            block_params, lr=self._optimizer.param_groups[0]['lr'],
            momentum=0.9, weight_decay=1e-4)
    else:
        self._amal_optimimizer = torch.optim.Adam(
            block_params, lr=self._optimizer.param_groups[0]['lr'],
            weight_decay=1e-4)
    self._amal_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        self._amal_optimimizer, T_max=max_iter)
    with set_mode(self._student, training=True), \
         set_mode(self._teachers, training=False):
        super(CommonFeatureAmalgamator, self).run(
            self.step_fn, self._dataloader,
            start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)
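# The amalgamation blocks optimized above align student features with the teachers'
# features in a learned common space (common feature learning). Below is a minimal,
# self-contained sketch of that idea, not the library's actual block implementation;
# the class name, channel arguments, and loss weighting are assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F

class CommonFeatureBlockSketch(nn.Module):
    def __init__(self, s_channels, t_channels, hidden_channels):
        super().__init__()
        # 1x1 convs project student / teacher feature maps into a shared space
        self.s_proj = nn.Conv2d(s_channels, hidden_channels, kernel_size=1)
        self.t_proj = nn.Conv2d(t_channels, hidden_channels, kernel_size=1)
        # decoder reconstructs the teacher features from the common space
        self.t_rec = nn.Conv2d(hidden_channels, t_channels, kernel_size=1)

    def forward(self, f_s, f_t):
        h_s, h_t = self.s_proj(f_s), self.t_proj(f_t)
        align_loss = F.mse_loss(h_s, h_t)            # pull student towards teacher in the common space
        rec_loss = F.mse_loss(self.t_rec(h_t), f_t)  # keep the common space informative
        return align_loss + rec_loss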
def run(self, max_iter, start_iter=0, epoch_length=None, stage_callback=None):
    # Stage 1 (Branching): jointly train the student against all teachers for the
    # first half of the iterations, then select the best branch for each task.
    with set_mode(self.student, training=True), \
         set_mode(self.teachers, training=False):
        super(TaskBranchingAmalgamator, self).run(
            self.step_fn, self._dataloader,
            start_iter=start_iter, max_iter=max_iter // 2, epoch_length=epoch_length)
    branch = self.find_the_best_branch(self._dataloader)
    self.logger.info("[Task Branching] the best branch indices: %s" % (branch,))

    if stage_callback is not None:
        stage_callback()

    # Stage 2 (Finetuning): train the student with the selected branches for the
    # remaining iterations.
    self.is_finetuning = True
    with set_mode(self.student, training=True), \
         set_mode(self.teachers, training=False):
        super(TaskBranchingAmalgamator, self).run(
            self.step_fn, self._dataloader,
            start_iter=max_iter - max_iter // 2, max_iter=max_iter, epoch_length=epoch_length)
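# The stage_callback hook above runs between the branching and finetuning stages,
# e.g. to snapshot the student once the branch indices have been fixed. A minimal
# sketch; `amalgamator` and the checkpoint path are hypothetical names.
import torch

def make_stage_callback(amalgamator, path='stage1_student.pth'):
    def _callback():
        torch.save(amalgamator.student.state_dict(), path)
    return _callback

# Hypothetical usage:
# amalgamator.run(max_iter=20000, stage_callback=make_stage_callback(amalgamator))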
def run(self, max_iter, start_iter=0, epoch_length=None):
    self.student.to(self.device)
    self.teacher.to(self.device)
    with set_mode(self.student, training=True), \
         set_mode(self.teacher, training=False):
        super(BasicTrainer, self).run(
            self.step_fn, self.dataloader,
            start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)
def run(self, max_iter, start_iter=0, epoch_length=None):
    with set_mode(self.student, training=True), \
         set_mode(self.teacher, training=False):
        super(KDDistiller, self).run(
            self.step_fn, self.dataloader,
            start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)
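# A KD distiller's step function typically minimizes a soft-target loss of this
# form (Hinton et al., temperature-scaled KL divergence). A minimal sketch with
# hypothetical names; the actual step_fn may also combine it with a hard-label term.
import torch.nn.functional as F

def soft_target_loss(student_logits, teacher_logits, T=4.0):
    # KL divergence between temperature-softened distributions, scaled by T^2
    # so that gradient magnitudes stay comparable across temperatures.
    p_t = F.softmax(teacher_logits / T, dim=1)
    log_p_s = F.log_softmax(student_logits / T, dim=1)
    return F.kl_div(log_p_s, p_t, reduction='batchmean') * (T * T)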
def __call__(self, trainer):
    if trainer.tb_writer is None:
        trainer.logger.warning("summary writer was not found in trainer")
        return
    device = trainer.device
    model = self._model()  # dereference the weakref to the model
    with torch.no_grad(), set_mode(model, training=False):
        for img_id, idx in enumerate(self.idx_list):
            batch = move_to_device(self._dataset[idx], device)
            batch = [d.unsqueeze(0) for d in batch]  # add a batch dimension
            inputs, targets, preds = self._prepare_fn(model, batch)
            if self._normalizer is not None:
                inputs = self._normalizer(inputs)
            inputs = inputs.detach().cpu().numpy()
            preds = preds.detach().cpu().numpy()
            targets = targets.detach().cpu().numpy()
            if self._decode_fn:  # decode predictions/targets to RGB 0~1 NCHW
                preds = self._decode_fn(preds)
                targets = self._decode_fn(targets)
            inputs = inputs[0]
            preds = preds[0]
            targets = targets[0]
            trainer.tb_writer.add_images(
                "%s-%d" % (self._tag, img_id),
                np.stack([inputs, targets, preds], axis=0),
                global_step=trainer.state.iter)
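# The _decode_fn hook above converts raw predictions/targets into RGB images in
# [0, 1] with NCHW layout before they are written to TensorBoard. A hypothetical
# decoder for semantic segmentation, assuming predictions arrive as an (N, H, W)
# array of class indices; the function name and palette are illustration only.
import numpy as np

def decode_seg_to_rgb(label_map, palette=None):
    if palette is None:
        palette = np.array([[0, 0, 0], [128, 0, 0], [0, 128, 0], [0, 0, 128]],
                           dtype=np.float32) / 255.0
    label_map = np.asarray(label_map, dtype=np.int64)
    rgb = palette[label_map % len(palette)]    # (N, H, W, 3), values in [0, 1]
    return rgb.transpose(0, 3, 1, 2)           # reorder to NCHW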
def train(self, start_iter, max_iter, optim_s, optim_g, device=None):
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.device = device
    self.optim_s, self.optim_g = optim_s, optim_g
    self.model.to(self.device)
    self.teacher.to(self.device)
    self.generator.to(self.device)
    # Placeholder loader: zero-shot KD synthesizes its own data with the generator,
    # so no real dataset is iterated.
    self.train_loader = [0, ]
    with set_mode(self.student, training=True), \
         set_mode(self.teacher, training=False), \
         set_mode(self.generator, training=True):
        super(ZSKTDistiller, self).train(start_iter, max_iter)
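# Zero-shot KD (ZSKT) trains on images synthesized by the generator instead of a
# real dataset (hence the placeholder train_loader above): the generator seeks
# images on which student and teacher disagree, while the student learns to match
# the teacher on them. A minimal sketch of one adversarial round with hypothetical
# names (the generator is assumed to map a noise vector to images); this is not
# the library's step implementation.
import torch
import torch.nn.functional as F

def zskt_step_sketch(student, teacher, generator, optim_s, optim_g,
                     z_dim=100, batch_size=64, device='cpu'):
    # Generator step: maximize the student-teacher divergence on synthesized images.
    z = torch.randn(batch_size, z_dim, device=device)
    fake = generator(z)
    kl = F.kl_div(F.log_softmax(student(fake), dim=1),
                  F.softmax(teacher(fake).detach(), dim=1), reduction='batchmean')
    optim_g.zero_grad()
    (-kl).backward()
    optim_g.step()

    # Student step: minimize the same divergence on fresh synthesized images.
    z = torch.randn(batch_size, z_dim, device=device)
    fake = generator(z).detach()
    kl = F.kl_div(F.log_softmax(student(fake), dim=1),
                  F.softmax(teacher(fake).detach(), dim=1), reduction='batchmean')
    optim_s.zero_grad()
    kl.backward()
    optim_s.step()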
def eval(self, model, device=None):
    device = device if device is not None else \
        torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self._model = weakref.ref(model)  # use weakref here
    self.device = device
    self.metric.reset()
    model.to(device)
    if self.progress:
        self.porgress_callback.callback.reset()
    with torch.no_grad(), set_mode(model, training=False):
        super(BasicEvaluator, self).run(
            self.step_fn, self.dataloader, max_iter=len(self.dataloader))
    return self.metric.get_results()
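# The evaluator above relies on a metric object exposing reset() and get_results(),
# with its step_fn presumably calling some update during iteration. A minimal sketch
# of such a metric for classification accuracy; the class name and the update
# signature are assumptions, not the library's metric API.
class AccuracyMetricSketch:
    def reset(self):
        self.correct, self.total = 0, 0

    def update(self, preds, targets):
        # preds: (N, C) logits, targets: (N,) class indices
        self.correct += (preds.argmax(dim=1) == targets).sum().item()
        self.total += targets.numel()

    def get_results(self):
        return {'acc': self.correct / max(self.total, 1)}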
def eval(self, model, device=None):
    self.teacher.to(device)
    with set_mode(self.teacher, training=False):
        return super(TeacherEvaluator, self).eval(model, device=device)
def run(self, max_iter, start_iter=0, epoch_length=None):
    self.model.to(self.device)
    with set_mode(self.model, training=True):
        super(BasicTrainer, self).run(
            self.step_fn, self.dataloader,
            start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)