Example 1
    def find_the_best_branch(self, dataloader):
        # Evaluate every candidate branching point on the given dataloader
        # and keep, for each task, the branch with the lowest accumulated
        # distillation loss.
        with set_mode(self.student, training=False), \
             set_mode(self.teachers, training=False), \
             torch.no_grad():
            n_blocks = len(self.student.student_decoders)
            branch_loss = {
                task: [0 for _ in range(n_blocks)]
                for task in self.tasks
            }
            for batch in dataloader:
                batch = move_to_device(batch, self.device)
                data = batch if isinstance(batch, torch.Tensor) else batch[0]
                for candidate_branch in range(n_blocks):
                    self.student.set_branch(
                        [candidate_branch for _ in range(len(self.teachers))])
                    s_out_list = self.student(data)
                    t_out_list = [teacher(data) for teacher in self.teachers]
                    for task, s_out, t_out in zip(self.tasks, s_out_list,
                                                  t_out_list):
                        task_loss = task.get_loss(s_out, t_out)
                        branch_loss[task][candidate_branch] += sum(
                            task_loss.values())
            best_branch = []
            for task in self.tasks:
                best_branch.append(int(np.argmin(branch_loss[task])))

            self.student.set_branch(best_branch)
            return best_branch
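
The set_mode helper used in the with-statement above is not part of the excerpt. A minimal sketch, assuming it is a context manager that temporarily switches one module (or a list of modules) between train/eval mode and restores the previous state on exit; the library's actual helper may differ:

from contextlib import contextmanager

@contextmanager
def set_mode(module_or_modules, training=False):
    # Hypothetical sketch, not the library's actual implementation.
    modules = (module_or_modules
               if isinstance(module_or_modules, (list, tuple))
               else [module_or_modules])
    previous = [m.training for m in modules]
    try:
        for m in modules:
            m.train(training)
        yield
    finally:
        for m, prev in zip(modules, previous):
            m.train(prev)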
Example 2
    def step_fn(self, engine, batch):
        start_time = time.perf_counter()
        batch = move_to_device(batch, self._device)
        data = batch[0]
        s_out = self.student(data)
        with torch.no_grad():
            t_out = [teacher(data) for teacher in self.teachers]
        loss_amal = 0
        loss_recons = 0
        for amal_block, hooks, C in self._amal_blocks:
            features = [h.feat_out for h in hooks]
            fs, fts = features[0], features[1:]
            rep, _fs, _fts = amal_block(fs, fts)
            # Align the student feature with the (detached) amalgamated
            # representation, and reconstruct each teacher feature.
            loss_amal += F.mse_loss(_fs, rep.detach())
            loss_recons += sum(F.mse_loss(_ft, ft)
                               for (_ft, ft) in zip(_fts, fts))
        loss_kd = tasks.loss.kldiv(s_out, torch.cat(t_out, dim=1))
        #loss_kd = F.mse_loss(s_out, torch.cat(t_out, dim=1))
        loss_dict = {
            "loss_kd": self._weights[0] * loss_kd,
            "loss_amal": self._weights[1] * loss_amal,
            "loss_recons": self._weights[2] * loss_recons
        }
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        self._amal_optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self._amal_optimizer.step()
        self._amal_scheduler.step()
        step_time = time.perf_counter() - start_time
        metrics = {
            loss_name: loss_value.item()
            for (loss_name, loss_value) in loss_dict.items()
        }
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float(self.optimizer.param_groups[0]['lr'])
        })
        return metrics
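
The hook objects whose feat_out fields are read in the amalgamation loop are created elsewhere. A plausible sketch, assuming a standard PyTorch forward hook that caches each block's input and output (the class name here is illustrative, not the library's):

class FeatureHook:
    # Illustrative sketch: cache the input/output of a module on every
    # forward pass so the training step can read them afterwards.
    def __init__(self, module):
        self.feat_in = None
        self.feat_out = None
        module.register_forward_hook(self._capture)

    def _capture(self, module, inputs, outputs):
        self.feat_in = inputs[0]
        self.feat_out = outputs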
Example 3
    def __call__(self, trainer):
        if trainer.tb_writer is None:
            trainer.logger.warning("summary writer was not found in trainer")
            return
        device = trainer.device
        model = self._model()
        with torch.no_grad(), set_mode(model, training=False):
            for img_id, idx in enumerate(self.idx_list):
                batch = move_to_device(self._dataset[idx], device)
                batch = [d.unsqueeze(0) for d in batch]
                inputs, targets, preds = self._prepare_fn(model, batch)
                if self._normalizer is not None:
                    inputs = self._normalizer(inputs)
                inputs = inputs.detach().cpu().numpy()
                preds = preds.detach().cpu().numpy()
                targets = targets.detach().cpu().numpy()
                if self._decode_fn:  # to RGB 0~1 NCHW
                    preds = self._decode_fn(preds)
                    targets = self._decode_fn(targets)
                inputs = inputs[0]
                preds = preds[0]
                targets = targets[0]
                trainer.tb_writer.add_images(
                    "%s-%d" % (self._tag, img_id),
                    np.stack([inputs, targets, preds], axis=0),
                    global_step=trainer.state.iter)
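
move_to_device appears in every example but is defined elsewhere. A minimal sketch, assuming it recursively moves tensors nested in lists, tuples, or dicts onto the target device:

import torch

def move_to_device(obj, device):
    # Assumed behaviour, not necessarily the library's implementation.
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(o, device) for o in obj)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    return obj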
Example 4
    def step_fn(self, engine, batch):
        student = self.student
        teacher = self.teacher
        start_time = time.perf_counter()
        batch = move_to_device(batch, self.device)
        inputs, targets = batch
        outputs = student(inputs)
        with torch.no_grad():
            soft_targets = teacher(inputs)

        loss_dict = {
            "loss_kld": self.alpha * kldiv(outputs, soft_targets, T=self.T),
            "loss_ce": self.beta * F.cross_entropy(outputs, targets),
            "loss_additional": self.gamma * self.additional_kd_loss(engine, batch)
        }

        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        step_time = time.perf_counter() - start_time
        metrics = {
            loss_name: loss_value.item()
            for (loss_name, loss_value) in loss_dict.items()
        }
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float(self.optimizer.param_groups[0]['lr'])
        })
        return metrics
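
kldiv is the temperature-scaled soft-target loss of Hinton-style knowledge distillation. The excerpt does not include it; a common formulation, which the library's version may or may not match exactly:

import torch.nn.functional as F

def kldiv(logits, targets, T=1.0):
    # KL divergence between softened student and teacher distributions,
    # scaled by T^2 so gradient magnitudes stay comparable across
    # temperatures (standard KD practice, assumed here).
    log_p = F.log_softmax(logits / T, dim=1)
    q = F.softmax(targets / T, dim=1)
    return F.kl_div(log_p, q, reduction='batchmean') * (T * T)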
Example 5
    def step_fn(self, engine, batch):
        start_time = time.perf_counter()
        batch = move_to_device(batch, self._device)
        data = batch[0]
        s_out = self._student(data)
        with torch.no_grad():
            t_out = [teacher(data) for teacher in self._teachers]
        loss_amal = 0
        loss_recons = 0
        for amal_block, hooks, C in self._amal_blocks:
            features = [
                h.feat_in if self._on_layer_input else h.feat_out
                for h in hooks
            ]
            fs, fts = features[0], features[1:]
            (hs, hts), (_fts, fts) = amal_block(fs, fts)
            _loss_amal, _loss_recons = self._cfl_criterion(hs, hts, _fts, fts)
            loss_amal += _loss_amal
            loss_recons += _loss_recons
        loss_kd = tasks.loss.kldiv(s_out, torch.cat(t_out, dim=1))
        loss_dict = {
            'loss_kd': self._weights[0] * loss_kd,
            'loss_amal': self._weights[1] * loss_amal,
            'loss_recons': self._weights[2] * loss_recons
        }
        loss = sum(loss_dict.values())
        self._optimizer.zero_grad()
        self._amal_optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
        self._amal_optimizer.step()
        self._amal_scheduler.step()
        step_time = time.perf_counter() - start_time

        metrics = {
            loss_name: loss_value.item()
            for (loss_name, loss_value) in loss_dict.items()
        }
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float(self._optimizer.param_groups[0]['lr'])
        })
        return metrics
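
The _cfl_criterion call returns one alignment term and one reconstruction term. A hedged sketch consistent with the call site; common-feature-learning criteria typically use an MMD-style distance in the common space rather than the plain MSE used here for brevity:

import torch.nn.functional as F

def cfl_loss(hs, hts, _fts, fts):
    # Sketch only: align the student's common-space feature with each
    # teacher's, and reconstruct each teacher's original feature.
    loss_amal = sum(F.mse_loss(hs, ht) for ht in hts)
    loss_recons = sum(F.mse_loss(_ft, ft) for _ft, ft in zip(_fts, fts))
    return loss_amal, loss_recons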
Example 6
    def step_fn(self, engine, batch):
        start_time = time.perf_counter()
        batch = move_to_device(batch, self._device)
        data = batch[0]
        n_blocks = len(self.student.student_decoders)
        if not self.is_finetuning:
            rand_branch_indices = [
                random.randint(0, n_blocks - 1)
                for _ in range(len(self.teachers))
            ]
            self.student.set_branch(rand_branch_indices)

        s_out_list = self.student(data)
        with torch.no_grad():
            t_out_list = [teacher(data) for teacher in self.teachers]

        loss_dict = {}
        for task, s_out, t_out in zip(self.tasks, s_out_list, t_out_list):
            task_loss = task.get_loss(s_out, t_out)
            loss_dict.update(task_loss)
        loss = sum(loss_dict.values())

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        step_time = time.perf_counter() - start_time
        metrics = {
            loss_name: loss_value.item()
            for (loss_name, loss_value) in loss_dict.items()
        }
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float(self.optimizer.param_groups[0]['lr']),
            'branch': self.student.branch_indices.cpu().numpy().tolist()
        })
        return metrics
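
Examples 1 and 6 fit together: training samples a random branch assignment per step, and find_the_best_branch from Example 1 then fixes the assignment before finetuning. A hedged usage sketch (the trainer and loader names are assumptions, not from the source):

for batch in train_loader:
    metrics = trainer.step_fn(None, batch)   # random branches per step
best_branch = trainer.find_the_best_branch(val_loader)
trainer.is_finetuning = True                 # keep best_branch fixed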
Example 7
    def step_fn(self, engine, batch):
        model = self.model
        start_time = time.perf_counter()
        batch = move_to_device(batch, self.device)
        inputs, targets = split_batch(batch)
        outputs = model(inputs)
        loss_dict = self.task.get_loss(outputs, targets)
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        step_time = time.perf_counter() - start_time
        metrics = {
            loss_name: loss_value.item()
            for (loss_name, loss_value) in loss_dict.items()
        }
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float(self.optimizer.param_groups[0]['lr'])
        })
        return metrics
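
split_batch is another helper defined outside the excerpt. A minimal sketch, assuming it separates (inputs, targets) pairs and tolerates bare-tensor batches, mirroring the isinstance pattern seen in Example 1:

import torch

def split_batch(batch):
    # Assumed behaviour: (inputs, targets) for pair batches, and
    # (batch, None) when the batch is a single tensor.
    if isinstance(batch, torch.Tensor):
        return batch, None
    return batch[0], batch[1]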
Example 8
    def step_fn(self, engine, batch):
        batch = move_to_device(batch, self.device)
        self.eval_fn(engine, batch)
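
This evaluation step delegates all work to eval_fn. One way such a callback might look, assuming a metric accumulator on the evaluator (all names here are illustrative, not from the source):

def eval_fn(self, engine, batch):
    # Sketch: forward pass without gradients, feeding an accumulator
    # that is read out after the evaluation loop finishes.
    inputs, targets = batch
    with torch.no_grad():
        outputs = self.model(inputs)
    self.metric_accumulator.update(outputs, targets)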