Example 1
    def find_the_best_branch(self, dataloader):
        # Evaluate every candidate branch point over the dataloader and pick,
        # for each task, the branch with the lowest distillation loss.
        with set_mode(self.student, training=False), \
             set_mode(self.teachers, training=False), \
             torch.no_grad():
            n_blocks = len(self.student.student_decoders)
            branch_loss = {
                task: [0 for _ in range(n_blocks)]
                for task in self.tasks
            }
            for batch in dataloader:
                batch = move_to_device(batch, self.device)
                data = batch if isinstance(batch, torch.Tensor) else batch[0]
                for candidate_branch in range(n_blocks):
                    self.student.set_branch(
                        [candidate_branch for _ in range(len(self.teachers))])
                    s_out_list = self.student(data)
                    t_out_list = [teacher(data) for teacher in self.teachers]
                    for task, s_out, t_out in zip(self.tasks, s_out_list,
                                                  t_out_list):
                        task_loss = task.get_loss(s_out, t_out)
                        branch_loss[task][candidate_branch] += sum(
                            task_loss.values())
            best_branch = []
            for task in self.tasks:
                best_branch.append(int(np.argmin(branch_loss[task])))

            self.student.set_branch(best_branch)
            return best_branch
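This example (and Example 6 below) passes each batch through move_to_device, which must handle both a bare tensor and a container of tensors. A minimal sketch of such a helper, assuming it simply recurses over common containers; the library's actual implementation may differ:

    import torch

    def move_to_device(obj, device):
        # Recursively move tensors nested in lists, tuples and dicts to `device`.
        if isinstance(obj, torch.Tensor):
            return obj.to(device)
        if isinstance(obj, (list, tuple)):
            return type(obj)(move_to_device(o, device) for o in obj)
        if isinstance(obj, dict):
            return {k: move_to_device(v, device) for k, v in obj.items()}
        return obj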
Example 2
    def run(self, max_iter, start_iter=0, epoch_length=None):
        # Collect the parameters of all amalgamation blocks.
        block_params = []
        for block, _, _ in self._amal_blocks:
            block_params.extend(list(block.parameters()))
        # Build a dedicated optimizer for the block parameters, mirroring the
        # type and learning rate of the main optimizer, with a cosine schedule.
        if isinstance(self._optimizer, torch.optim.SGD):
            self._amal_optimizer = torch.optim.SGD(
                block_params,
                lr=self._optimizer.param_groups[0]['lr'],
                momentum=0.9,
                weight_decay=1e-4)
        else:
            self._amal_optimizer = torch.optim.Adam(
                block_params,
                lr=self._optimizer.param_groups[0]['lr'],
                weight_decay=1e-4)
        self._amal_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self._amal_optimizer, T_max=max_iter)

        with set_mode(self._student, training=True), \
             set_mode(self._teachers, training=False):
            super(CommonFeatureAmalgamator,
                  self).run(self.step_fn,
                            self._dataloader,
                            start_iter=start_iter,
                            max_iter=max_iter,
                            epoch_length=epoch_length)
Example 3
    def run(self,
            max_iter,
            start_iter=0,
            epoch_length=None,
            stage_callback=None):
        # Stage 1: branch search (train jointly for the first half of the iterations)
        with set_mode(self.student, training=True), \
             set_mode(self.teachers, training=False):
            super(TaskBranchingAmalgamator,
                  self).run(self.step_fn,
                            self._dataloader,
                            start_iter=start_iter,
                            max_iter=max_iter // 2,
                            epoch_length=epoch_length)
        branch = self.find_the_best_branch(self._dataloader)
        self.logger.info("[Task Branching] the best branch indices: %s" %
                         (branch))

        if stage_callback is not None:
            stage_callback()

        # Stage 2: finetune the student with the selected branches (remaining iterations)
        self.is_finetuning = True
        with set_mode(self.student, training=True), \
             set_mode(self.teachers, training=False):
            super(TaskBranchingAmalgamator,
                  self).run(self.step_fn,
                            self._dataloader,
                            start_iter=max_iter - max_iter // 2,
                            max_iter=max_iter,
                            epoch_length=epoch_length)
Example 4
    def run(self, max_iter, start_iter=0, epoch_length=None):
        self.student.to(self.device)
        self.teacher.to(self.device)

        with set_mode(self.student, training=True), \
             set_mode(self.teacher, training=False):
            super(BasicTrainer, self).run(self.step_fn,
                                          self.dataloader,
                                          start_iter=start_iter,
                                          max_iter=max_iter,
                                          epoch_length=epoch_length)
Example 5
 def run(self, max_iter, start_iter=0, epoch_length=None):
     with set_mode(self.student, training=True), \
          set_mode(self.teacher, training=False):
         super(KDDistiller, self).run(self.step_fn,
                                      self.dataloader,
                                      start_iter=start_iter,
                                      max_iter=max_iter,
                                      epoch_length=epoch_length)
Example 6
 def __call__(self, trainer):
     if trainer.tb_writer is None:
         trainer.logger.warning("summary writer was not found in trainer")
         return
     device = trainer.device
     model = self._model()
     with torch.no_grad(), set_mode(model, training=False):
         for img_id, idx in enumerate(self.idx_list):
             batch = move_to_device(self._dataset[idx], device)
             batch = [d.unsqueeze(0) for d in batch]
             inputs, targets, preds = self._prepare_fn(model, batch)
             if self._normalizer is not None:
                 inputs = self._normalizer(inputs)
             inputs = inputs.detach().cpu().numpy()
             preds = preds.detach().cpu().numpy()
             targets = targets.detach().cpu().numpy()
             if self._decode_fn:  # to RGB 0~1 NCHW
                 preds = self._decode_fn(preds)
                 targets = self._decode_fn(targets)
             inputs = inputs[0]
             preds = preds[0]
             targets = targets[0]
             trainer.tb_writer.add_images("%s-%d" % (self._tag, img_id),
                                          np.stack([inputs, targets, preds],
                                                   axis=0),
                                          global_step=trainer.state.iter)
Example 7
    def train(self, start_iter, max_iter, optim_s, optim_g, device=None):
        if device is None:
            device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        self.optim_s, self.optim_g = optim_s, optim_g

        self.model.to(self.device)
        self.teacher.to(self.device)
        self.generator.to(self.device)
        # Placeholder loader: ZSKT is data-free, so inputs are synthesized by the generator.
        self.train_loader = [0, ]

        with set_mode(self.student, training=True), \
             set_mode(self.teacher, training=False), \
             set_mode(self.generator, training=True):
            super(ZSKTDistiller, self).train(start_iter, max_iter)
Example 8
 def eval(self, model, device=None):
     device = device if device is not None else \
         torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     self._model = weakref.ref(model)  # use weakref here
     self.device = device
     self.metric.reset()
     model.to(device)
     if self.progress:
         self.progress_callback.callback.reset()
     with torch.no_grad(), set_mode(model, training=False):
         super(BasicEvaluator, self).run(self.step_fn,
                                         self.dataloader,
                                         max_iter=len(self.dataloader))
     return self.metric.get_results()
Example 9
 def eval(self, model, device=None):
     self.teacher.to(device)
     with set_mode(self.teacher, training=False):
         return super(TeacherEvaluator, self).eval(model, device=device)
Example 10
 def run(self, max_iter, start_iter=0, epoch_length=None):
     self.model.to(self.device)
     with set_mode(self.model, training=True):
         super(BasicTrainer, self).run(self.step_fn,
                                       self.dataloader,
                                       start_iter=start_iter,
                                       max_iter=max_iter,
                                       epoch_length=epoch_length)
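Every snippet above wraps its loop in the set_mode context manager to put the student into training mode and the teachers into eval mode (or to switch everything to eval for validation). A minimal sketch of how such a helper could look, assuming it only toggles module.training and restores the previous state on exit; the library's actual implementation may differ:

    from contextlib import contextmanager

    @contextmanager
    def set_mode(modules, training=True):
        # Accept a single nn.Module (or ModuleList) as well as a list/tuple of modules.
        if not isinstance(modules, (list, tuple)):
            modules = [modules]
        prev_modes = [m.training for m in modules]
        try:
            for m in modules:
                m.train(training)  # train(False) is equivalent to eval()
            yield
        finally:
            # Restore each module's original train/eval state when the block exits.
            for m, prev in zip(modules, prev_modes):
                m.train(prev)

With a helper like this, the nested with blocks in the examples leave every model in the mode it had before the block was entered, even if an exception is raised inside the loop.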