def find_the_best_branch(self, dataloader):
    # Exhaustively score every candidate branch point on the given dataloader
    # and pick, per task, the branch with the lowest accumulated loss.
    with set_mode(self.student, training=False), \
         set_mode(self.teachers, training=False), \
         torch.no_grad():
        n_blocks = len(self.student.student_decoders)
        branch_loss = {task: [0.0 for _ in range(n_blocks)]
                       for task in self.tasks}
        for batch in dataloader:
            batch = move_to_device(batch, self.device)
            data = batch if isinstance(batch, torch.Tensor) else batch[0]
            for candidate_branch in range(n_blocks):
                # Route all teachers through the same candidate branch.
                self.student.set_branch(
                    [candidate_branch for _ in range(len(self.teachers))])
                s_out_list = self.student(data)
                t_out_list = [teacher(data) for teacher in self.teachers]
                for task, s_out, t_out in zip(self.tasks, s_out_list, t_out_list):
                    task_loss = task.get_loss(s_out, t_out)
                    branch_loss[task][candidate_branch] += sum(task_loss.values())
        # Keep the branch with the smallest accumulated loss for each task.
        best_branch = [int(np.argmin(branch_loss[task])) for task in self.tasks]
        self.student.set_branch(best_branch)
        return best_branch
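# `set_mode` above is imported from elsewhere in the repo; a minimal sketch of
# what such a train/eval context manager typically looks like (an assumption;
# the name `set_mode_sketch` is hypothetical and the repo's own helper may
# handle nested modules differently):
from contextlib import contextmanager

@contextmanager
def set_mode_sketch(models, training=False):
    if not isinstance(models, (list, tuple)):
        models = [models]
    was_training = [m.training for m in models]
    try:
        for m in models:
            m.train(training)
        yield
    finally:
        # Restore each module's original train/eval state on exit.
        for m, t in zip(models, was_training):
            m.train(t)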
def step_fn(self, engine, batch):
    start_time = time.perf_counter()
    batch = move_to_device(batch, self._device)
    data = batch[0]
    s_out = self.student(data)
    with torch.no_grad():
        t_out = [teacher(data) for teacher in self.teachers]
    # Accumulate amalgamation and reconstruction losses over the hooked features.
    loss_amal = 0
    loss_recons = 0
    for amal_block, hooks, C in self._amal_blocks:
        features = [h.feat_out for h in hooks]
        fs, fts = features[0], features[1:]  # student features first, then teachers'
        rep, _fs, _fts = amal_block(fs, fts)
        # Pull the student towards the (detached) common representation ...
        loss_amal += F.mse_loss(_fs, rep.detach())
        # ... and reconstruct each teacher's features from it.
        loss_recons += sum(F.mse_loss(_ft, ft) for (_ft, ft) in zip(_fts, fts))
    loss_kd = tasks.loss.kldiv(s_out, torch.cat(t_out, dim=1))
    # loss_kd = F.mse_loss(s_out, torch.cat(t_out, dim=1))
    loss_dict = {
        "loss_kd": self._weights[0] * loss_kd,
        "loss_amal": self._weights[1] * loss_amal,
        "loss_recons": self._weights[2] * loss_recons,
    }
    loss = sum(loss_dict.values())
    self.optimizer.zero_grad()
    self._amal_optimimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    self._amal_optimimizer.step()
    self._amal_scheduler.step()
    step_time = time.perf_counter() - start_time
    metrics = {loss_name: loss_value.item()
               for (loss_name, loss_value) in loss_dict.items()}
    metrics.update({
        'total_loss': loss.item(),
        'step_time': step_time,
        'lr': float(self.optimizer.param_groups[0]['lr']),
    })
    return metrics
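# The hooks iterated above expose cached activations through `feat_out` (and
# `feat_in` in the variant further below). A minimal sketch of such a
# forward-hook wrapper; `FeatureHookSketch` is a hypothetical name and the
# repo's actual hook class may differ:
import torch.nn as nn

class FeatureHookSketch:
    def __init__(self, module: nn.Module):
        self.feat_in = None
        self.feat_out = None
        self._handle = module.register_forward_hook(self._on_forward)

    def _on_forward(self, module, inputs, output):
        self.feat_in = inputs[0]   # first positional input to the module
        self.feat_out = output     # the module's output activation

    def remove(self):
        self._handle.remove()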
def __call__(self, trainer):
    if trainer.tb_writer is None:
        trainer.logger.warning("summary writer was not found in trainer")
        return
    device = trainer.device
    model = self._model()
    with torch.no_grad(), set_mode(model, training=False):
        for img_id, idx in enumerate(self.idx_list):
            batch = move_to_device(self._dataset[idx], device)
            batch = [d.unsqueeze(0) for d in batch]  # add the batch dimension
            inputs, targets, preds = self._prepare_fn(model, batch)
            if self._normalizer is not None:
                inputs = self._normalizer(inputs)
            inputs = inputs.detach().cpu().numpy()
            preds = preds.detach().cpu().numpy()
            targets = targets.detach().cpu().numpy()
            if self._decode_fn:  # decode to RGB, 0~1, NCHW
                preds = self._decode_fn(preds)
                targets = self._decode_fn(targets)
            inputs = inputs[0]
            preds = preds[0]
            targets = targets[0]
            trainer.tb_writer.add_images(
                "%s-%d" % (self._tag, img_id),
                np.stack([inputs, targets, preds], axis=0),
                global_step=trainer.state.iter)
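# `add_images` consumes an NCHW batch; stacking input, target, and prediction
# along a new leading axis yields N=3 images per tag. A toy shape check with
# made-up sizes (purely illustrative):
import numpy as np
_inputs = np.random.rand(3, 64, 64).astype(np.float32)   # CHW, values in [0, 1]
_targets = np.random.rand(3, 64, 64).astype(np.float32)
_preds = np.random.rand(3, 64, 64).astype(np.float32)
_grid = np.stack([_inputs, _targets, _preds], axis=0)
assert _grid.shape == (3, 3, 64, 64)                      # NCHW with N=3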
def step_fn(self, engine, batch):
    student = self.student
    teacher = self.teacher
    start_time = time.perf_counter()
    batch = move_to_device(batch, self.device)
    inputs, targets = batch
    outputs = student(inputs)
    with torch.no_grad():
        # Soft targets from the frozen teacher.
        soft_targets = teacher(inputs)
    loss_dict = {
        "loss_kld": self.alpha * kldiv(outputs, soft_targets, T=self.T),
        "loss_ce": self.beta * F.cross_entropy(outputs, targets),
        "loss_additional": self.gamma * self.additional_kd_loss(engine, batch),
    }
    loss = sum(loss_dict.values())
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    step_time = time.perf_counter() - start_time
    metrics = {loss_name: loss_value.item()
               for (loss_name, loss_value) in loss_dict.items()}
    metrics.update({
        'total_loss': loss.item(),
        'step_time': step_time,
        'lr': float(self.optimizer.param_groups[0]['lr']),
    })
    return metrics
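# The exact `kldiv` used above is defined elsewhere in the repo; a sketch of
# the standard temperature-scaled KD divergence it presumably implements
# (the name `kldiv_sketch` is hypothetical):
import torch.nn.functional as F

def kldiv_sketch(logits, target_logits, T=1.0):
    # KL(softmax(targets / T) || softmax(logits / T)), scaled by T^2 so that
    # gradient magnitudes stay comparable across temperatures.
    log_p = F.log_softmax(logits / T, dim=1)
    q = F.softmax(target_logits / T, dim=1)
    return F.kl_div(log_p, q, reduction='batchmean') * (T * T)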
def step_fn(self, engine, batch):
    start_time = time.perf_counter()
    batch = move_to_device(batch, self._device)
    data = batch[0]
    s_out = self._student(data)
    with torch.no_grad():
        t_out = [teacher(data) for teacher in self._teachers]
    loss_amal = 0
    loss_recons = 0
    for amal_block, hooks, C in self._amal_blocks:
        # Hooked features, taken from the layer input or output as configured.
        features = [h.feat_in if self._on_layer_input else h.feat_out
                    for h in hooks]
        fs, fts = features[0], features[1:]  # student features first, then teachers'
        (hs, hts), (_fts, fts) = amal_block(fs, fts)
        _loss_amal, _loss_recons = self._cfl_criterion(hs, hts, _fts, fts)
        loss_amal += _loss_amal
        loss_recons += _loss_recons
    loss_kd = tasks.loss.kldiv(s_out, torch.cat(t_out, dim=1))
    loss_dict = {
        'loss_kd': self._weights[0] * loss_kd,
        'loss_amal': self._weights[1] * loss_amal,
        'loss_recons': self._weights[2] * loss_recons,
    }
    loss = sum(loss_dict.values())
    self._optimizer.zero_grad()
    self._amal_optimimizer.zero_grad()
    loss.backward()
    self._optimizer.step()
    self._amal_optimimizer.step()
    self._amal_scheduler.step()
    step_time = time.perf_counter() - start_time
    metrics = {loss_name: loss_value.item()
               for (loss_name, loss_value) in loss_dict.items()}
    metrics.update({
        'total_loss': loss.item(),
        'step_time': step_time,
        'lr': float(self._optimizer.param_groups[0]['lr']),
    })
    return metrics
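# `self._cfl_criterion` is the repo's common-feature-learning loss; the real
# implementation may align hs with each hts via kernel MMD rather than MSE.
# A deliberately simplified stand-in that only illustrates the expected
# inputs and outputs (the name `cfl_criterion_sketch` is hypothetical):
import torch.nn.functional as F

def cfl_criterion_sketch(hs, hts, _fts, fts):
    loss_amal = sum(F.mse_loss(hs, ht) for ht in hts)        # align common features
    loss_recons = sum(F.mse_loss(_ft, ft)                    # reconstruct teacher features
                      for _ft, ft in zip(_fts, fts))
    return loss_amal, loss_recons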
def step_fn(self, engine, batch):
    start_time = time.perf_counter()
    batch = move_to_device(batch, self._device)
    data = batch[0]
    n_blocks = len(self.student.student_decoders)
    if not self.is_finetuning:
        # Sample a random branch point for every teacher at each step.
        rand_branch_indices = [random.randint(0, n_blocks - 1)
                               for _ in range(len(self.teachers))]
        self.student.set_branch(rand_branch_indices)
    s_out_list = self.student(data)
    with torch.no_grad():
        t_out_list = [teacher(data) for teacher in self.teachers]
    loss_dict = {}
    for task, s_out, t_out in zip(self.tasks, s_out_list, t_out_list):
        task_loss = task.get_loss(s_out, t_out)
        loss_dict.update(task_loss)
    loss = sum(loss_dict.values())
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    step_time = time.perf_counter() - start_time
    metrics = {loss_name: loss_value.item()
               for (loss_name, loss_value) in loss_dict.items()}
    metrics.update({
        'total_loss': loss.item(),
        'step_time': step_time,
        'lr': float(self.optimizer.param_groups[0]['lr']),
        'branch': self.student.branch_indices.cpu().numpy().tolist(),
    })
    return metrics
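# A toy run of the branch sampling used above, with made-up sizes: each step
# routes every teacher's supervision to a randomly chosen decoder block, and
# find_the_best_branch later fixes one index per task for fine-tuning.
import random
random.seed(0)
_n_blocks, _n_teachers = 4, 2
for _ in range(3):
    print([random.randint(0, _n_blocks - 1) for _ in range(_n_teachers)])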
def step_fn(self, engine, batch):
    model = self.model
    start_time = time.perf_counter()
    batch = move_to_device(batch, self.device)
    inputs, targets = split_batch(batch)
    outputs = model(inputs)
    loss_dict = self.task.get_loss(outputs, targets)  # task-specific loss terms
    loss = sum(loss_dict.values())
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    step_time = time.perf_counter() - start_time
    metrics = {loss_name: loss_value.item()
               for (loss_name, loss_value) in loss_dict.items()}
    metrics.update({
        'total_loss': loss.item(),
        'step_time': step_time,
        'lr': float(self.optimizer.param_groups[0]['lr']),
    })
    return metrics
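# `split_batch` is imported from elsewhere; a sketch of a common implementation
# (an assumption: the repo's version may handle more batch layouts):
def split_batch_sketch(batch):
    if isinstance(batch, (list, tuple)):
        inputs, targets = batch[0], batch[1]
    else:
        inputs, targets = batch, None
    return inputs, targets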
def step_fn(self, engine, batch):
    batch = move_to_device(batch, self.device)
    self.eval_fn(engine, batch)