def iteration(self, data: Tuple[Tensor, Tensor]) -> None:
        input, target = data
        if self.is_train:
            if self.is_supervised:
                output = self.model['student'](input)
                loss = self.loss_f(output, target)
            else:
                # Mix the batch with the teacher's detached predictions; the
                # student is trained to match the mixed predictions.
                teacher_output = self.model['teacher'](input).detach()
                input, teacher_target = mixup(input, teacher_output,
                                              self.beta.sample())
                output = self.model['student'](input)
                loss = self.consistency_weight * nn.MSELoss()(output,
                                                              teacher_target)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # Update the teacher model from the student (alpha controls the update).
            _update_teacher(self.model['student'], self.model['teacher'],
                            self.alpha, self.epoch)
        else:
            output = self.model['student'](input)
            loss = self.loss_f(output, target)

        if self._is_debug and torch.isnan(loss):
            self.logger.warning("loss is NaN")
        self.reporter.add('accuracy', accuracy(output, target))
        self.reporter.add('loss', loss.detach_())
        if self._report_topk is not None:
            for top_k in self._report_topk:
                self.reporter.add(f'accuracy@{top_k}',
                                  accuracy(output, target, top_k))
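The helpers mixup and _update_teacher are not defined in this snippet. Below is a minimal sketch of the formulations they are commonly assumed to follow (batch-level mixup and a mean-teacher style moving-average update); the signatures are guesses, not the original implementation.

import torch

def mixup(input: torch.Tensor, target: torch.Tensor, gamma: float):
    # Convex combination of a batch with a shuffled copy of itself,
    # applying the same permutation to inputs and targets.
    perm = torch.randperm(input.size(0), device=input.device)
    mixed_input = gamma * input + (1 - gamma) * input[perm]
    mixed_target = gamma * target + (1 - gamma) * target[perm]
    return mixed_input, mixed_target

@torch.no_grad()
def _update_teacher(student, teacher, alpha: float, epoch: int = 0):
    # Exponential moving average of the student's weights into the teacher;
    # epoch might be used to ramp alpha early in training (assumption).
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.mul_(alpha).add_(s_param, alpha=1 - alpha)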
Example #2
    def iteration(self, data: Tuple[Tensor, Tensor]) -> None:
        input = data[0]
        target = data[1]
        if self.is_train:
            unlabel_data = data[2]
            # data[3:-1] holds differently augmented views of the unlabeled batch.
            augment_unlabel_list = data[3:-1]
            num_list = len(augment_unlabel_list)
            # MixMatch: mix labeled and unlabeled samples with guessed labels.
            mix_input, mix_unlabel_list, mix_target, mix_fake_target_list = self.generate_mixmatch(
                input, target, augment_unlabel_list)
            # Cross entropy against the soft (mixed) targets for the labeled part.
            output = nn.LogSoftmax(dim=1)(self.model(mix_input))
            loss = -(output * mix_target).sum(dim=1).mean()
            for mix_unlabel, mix_fake_target in zip(mix_unlabel_list,
                                                    mix_fake_target_list):
                fake_output = self.model(mix_unlabel)
                loss += self.consistency_weight * \
                    nn.MSELoss()(fake_output, mix_fake_target) / num_list
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        else:
            output = self.model(input)
            loss = self.loss_f(output, target)

        if self._is_debug and torch.isnan(loss):
            self.logger.warning("loss is NaN")
        self.reporter.add('accuracy', accuracy(output, target))
        self.reporter.add('loss', loss.detach_())
        if self._report_topk is not None:
            for top_k in self._report_topk:
                self.reporter.add(f'accuracy@{top_k}',
                                  accuracy(output, target, top_k))
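generate_mixmatch is not shown here. In MixMatch, the guessed labels for the unlabeled views are typically produced by averaging the model's predictions over the augmentations and sharpening the result; a sketch of that step under this assumption (function names are illustrative):

import torch
import torch.nn.functional as F

def sharpen(probs: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    # Raise probabilities to 1/T and renormalize to lower their entropy.
    p = probs ** (1.0 / temperature)
    return p / p.sum(dim=1, keepdim=True)

@torch.no_grad()
def guess_labels(model, augmented_batches, temperature: float = 0.5) -> torch.Tensor:
    # Average the softmax predictions over augmented views of the same
    # unlabeled images, then sharpen; no gradients flow through this step.
    mean_probs = torch.stack(
        [F.softmax(model(x), dim=1) for x in augmented_batches]).mean(dim=0)
    return sharpen(mean_probs, temperature)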
 def iteration(self, data: Tuple[Tensor, Tensor]) -> None:
     input, target = data
     if self.is_train:
          # Samples with target == -100 are treated as unlabeled.
          unlabel = input[target == -100]
         output = self.model['student'](input)
         loss = self.loss_f(output, target)
         teacher_output = self.model['teacher'](unlabel).detach()
         unlabel, teacher_target = mixup(unlabel, teacher_output,
                                         self.beta.sample())
         student_output = self.model['student'](unlabel)
         loss += self.consistency_weight() * nn.MSELoss()(student_output,
                                                          teacher_target)
         self.optimizer.zero_grad()
         loss.backward()
         self.optimizer.step()
         _update_teacher(self.model['student'], self.model['teacher'],
                         self.alpha())
         self.consistency_weight.step()
         self.alpha.step()
     else:
         output = self.model['student'](input)
         loss = self.loss_f(output, target)
         self.reporter.add('accuracy', accuracy(output, target))
     self.reporter.add('loss', loss.detach_())
     if self._report_topk is not None:
         for top_k in self._report_topk:
             self.reporter.add(f'accuracy@{top_k}',
                               accuracy(output, target, top_k))
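In this variant, consistency_weight and alpha are called like functions and advanced with .step(), which suggests scheduled scalars. A minimal sketch of such a scheduler, assuming a linear ramp-up (the interface is a guess, not the original implementation):

class RampUpScalar:
    def __init__(self, max_value: float, ramp_up_steps: int):
        self.max_value = max_value
        self.ramp_up_steps = ramp_up_steps
        self._num_steps = 0

    def __call__(self) -> float:
        # Linearly ramp from 0 to max_value over ramp_up_steps iterations.
        factor = min(1.0, self._num_steps / max(1, self.ramp_up_steps))
        return self.max_value * factor

    def step(self) -> None:
        self._num_steps += 1

For example, the trainer could construct self.consistency_weight = RampUpScalar(1.0, 4000) and call it once per iteration, as done above.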
Example #4
    def iteration(self, data: Tuple[Tensor, Tensor]) -> None:
        input, target = data
        # Forward pass with the noise layers disabled.
        _disable_noise(self.mbn_list)
        output = self.model['backbone'](input)
        if self.is_train:
            # Noisy forward pass; the decoder consumes the noisy hidden state
            # and its loss (weighted by lam_list) is added to the supervised loss.
            _enable_noise(self.mbn_list)
            output = self.model['backbone'](input)
            self.model['decoder'](self.mbn_list[-1].noise_h)
            loss = _get_loss_d(self.model['decoder'], self.lam_list)
            loss += self.loss_f(output, target)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        else:
            loss = self.loss_f(output, target)
            self.reporter.add('accuracy', accuracy(output, target))

        if self._is_debug and torch.isnan(loss):
            self.logger.warning("loss is NaN")

        self.reporter.add('loss', loss.detach_())
        if self._report_topk is not None:
            for top_k in self._report_topk:
                self.reporter.add(
                    f'accuracy@{top_k}', accuracy(output, target, top_k))
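_get_loss_d and the mbn_list noise hooks are not defined in the snippet. In ladder-network-style training, the decoder loss is usually a lambda-weighted sum of per-layer denoising reconstruction errors; a rough sketch under that assumption (the attribute names are hypothetical):

import torch.nn as nn

def _get_loss_d(decoder, lam_list):
    # Weighted sum of reconstruction losses, one weight per decoder layer.
    loss = 0.0
    for recon, clean, lam in zip(decoder.reconstructions,
                                 decoder.clean_targets, lam_list):
        loss = loss + lam * nn.MSELoss()(recon, clean.detach())
    return loss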
Example #5
    def iteration(self, data: Tuple[torch.Tensor, torch.Tensor]) -> None:
        input, target = data
        with torch.cuda.amp.autocast(self._use_amp):
            output = self.model(input)
            loss = self.loss_f(output, target)

        if self.is_train:
            self.optimizer.zero_grad(set_to_none=True)
            if self._use_amp:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()
            if self.cfg.grad_clip > 0:
                if self._use_amp:
                    self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.cfg.grad_clip)
            if self._use_amp:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()

        if self._is_debug and torch.isnan(loss):
            self.logger.warning("loss is NaN")

        self.reporter.add('accuracy', accuracy(output, target))
        self.reporter.add('loss', loss.detach_())
        if self._report_topk is not None:
            for top_k in self._report_topk:
                self.reporter.add(f'accuracy@{top_k}',
                                  accuracy(output, target, top_k))
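The scaler used above is PyTorch's gradient scaler for mixed-precision training; a sketch of how it is presumably constructed alongside the optimizer (not shown in the snippet):

import torch

use_amp = torch.cuda.is_available()  # stand-in for the trainer's _use_amp flag
# With enabled=False the scaler is a no-op, so the same loop works with and without AMP.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)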
    def iteration(self, data: Tuple[Tensor, Tensor]) -> None:
        if len(data) == 3:
            student_input, teacher_input, target = data
        else:
            input, target = data

        if self.is_train:
            # Consistency loss: the student's log-probabilities should match the
            # teacher's detached probabilities under KL divergence.
            student_output = nn.LogSoftmax(dim=1)(
                self.model['student'](student_input))
            teacher_output = nn.Softmax(dim=1)(
                self.model['teacher'](teacher_input))
            loss = self.consistency_weight * \
                nn.KLDivLoss(reduction='batchmean')(
                    student_output, teacher_output.detach())
            if self.is_supervised:
                loss += nn.NLLLoss()(student_output, target)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            _update_teacher(self.model['student'], self.model['teacher'],
                            self.alpha, self.epoch)
        else:
            teacher_output = self.model['teacher'](input)
            loss = self.loss_f(teacher_output, target)

        if self._is_debug and torch.isnan(loss):
            self.logger.warning("loss is NaN")
        self.reporter.add('accuracy', accuracy(teacher_output, target))
        self.reporter.add('loss', loss.detach_())
        if self._report_topk is not None:
            for top_k in self._report_topk:
                self.reporter.add(f'accuracy@{top_k}',
                                  accuracy(teacher_output, target, top_k))
Example #7
 def test_iteration(self, data: Tuple[Tensor, Tensor]):
     input, target = data
      # Encode the input and classify it with the joint (x, z) discriminator.
      z = self.model['generator_z'](input)
      output, _ = self.model['discriminator_x_z'](
          self.model['discriminator_x'](input),
          self.model['discriminator_z'](z))
     loss = self.loss_f(output, target)
     if self._is_debug and torch.isnan(loss):
         self.logger.warning("loss is NaN")
     self.reporter.add('accuracy', accuracy(output, target))
     self.reporter.add('loss', loss.detach_())
     if self._report_topk is not None:
         for top_k in self._report_topk:
             self.reporter.add(f'accuracy@{top_k}',
                               accuracy(output, target, top_k))
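accuracy is used throughout these snippets but never defined; a common top-k implementation that is compatible with how it is called here (a sketch, not necessarily the original):

import torch

def accuracy(output: torch.Tensor, target: torch.Tensor, top_k: int = 1) -> torch.Tensor:
    # Fraction of samples whose true label appears among the top_k predictions.
    _, pred = output.topk(top_k, dim=1)
    correct = pred.eq(target.view(-1, 1)).any(dim=1)
    return correct.float().mean()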
Example #8
 def supervised_iteration(self, data: Tuple[Tensor, Tensor]):
     input, target = data
      # According to the paper, only the discriminator is useful for the supervised task.
     z = self.model['generator_z'](input).detach()
     output, _ = self.model['discriminator_x_z'](
         self.model['discriminator_x'](input),
         self.model['discriminator_z'](z))
     loss = self.loss_f(output, target)
     self.optimizerD.zero_grad()
     loss.backward()
     self.optimizerD.step()
     if self._is_debug and torch.isnan(loss):
         self.logger.warning("loss is NaN")
     self.reporter.add('accuracy', accuracy(output, target))
     self.reporter.add('loss', loss.detach_())
     if self._report_topk is not None:
         for top_k in self._report_topk:
             self.reporter.add(f'accuracy@{top_k}',
                               accuracy(output, target, top_k))
 def iteration(self, data: Tuple[Tensor, Tensor]) -> None:
     if self._is_train:
         if self.dataset_type == 'mix':
             loss = self._mix_iteration(data)
         else:
             loss = self._split_iteration(data)
         self.optimizer.zero_grad()
         loss.backward()
         self.optimizer.step()
         _update_teacher(self.model['student'], self.model['teacher'],
                         self.alpha())
         self.reporter.add('loss', loss.detach_())
         self.consistency_weight.step()
         self.alpha.step()
     else:
         input, target = data
         teacher_output = self.model['teacher'](input)
         loss = self.loss_f(teacher_output, target)
         self.reporter.add('accuracy', accuracy(teacher_output, target))
         self.reporter.add('loss', loss.detach_())
         if self._report_topk is not None:
             for top_k in self._report_topk:
                 self.reporter.add(f'accuracy@{top_k}',
                                    accuracy(teacher_output, target, top_k))
Example #10
    def iteration(self,
                  data):
        input, target = data
        _target = target  # keep the integer labels for accuracy reporting
        target = F.one_hot(target, self.num_classes)
        if self.is_train and self.cfg.mixup:
            if self.cfg.input_only:
                input = partial_mixup(input, np.random.beta(self.cfg.alpha + 1, self.cfg.alpha),
                                      torch.randperm(input.size(0), device=input.device, dtype=torch.long))
            else:
                input, target = mixup(input, target, np.random.beta(self.cfg.alpha, self.cfg.alpha))

        output = self.model(input)
        loss = self.loss_f(output, target)

        self.reporter.add('loss', loss.detach())
        self.reporter.add('accuracy', accuracy(output, _target))

        if self.is_train:
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
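partial_mixup and mixup are not shown in this last snippet; a sketch consistent with how they are called above, where partial_mixup mixes a tensor with a permuted copy of itself and mixup applies the same permutation to inputs and one-hot targets (the originals may differ in detail):

import torch

def partial_mixup(input: torch.Tensor, gamma: float,
                  indices: torch.Tensor) -> torch.Tensor:
    # Mix each sample with the sample selected by indices.
    perm_input = input[indices]
    return input.mul(gamma).add(perm_input, alpha=1 - gamma)

def mixup(input: torch.Tensor, target: torch.Tensor, gamma: float):
    indices = torch.randperm(input.size(0), device=input.device,
                             dtype=torch.long)
    return partial_mixup(input, gamma, indices), \
        partial_mixup(target, gamma, indices)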