Example 1: capturing the optimizer's learning rate and momentum in the runner state at stage start
    def on_stage_start(self, state: RunnerState):
        # Fetch the optimizer registered under `self.optimizer_key`
        optimizer = state.get_key(
            key="optimizer", inner_key=self.optimizer_key
        )
        assert optimizer is not None
        lr = optimizer.defaults["lr"]
        momentum = get_optimizer_momentum(optimizer)
        # Expose the initial lr/momentum through the runner state
        state.set_key(lr, "lr", inner_key=self.optimizer_key)
        state.set_key(momentum, "momentum", inner_key=self.optimizer_key)
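The `get_optimizer_momentum` helper is referenced but not shown above; a minimal sketch of what such a utility typically does, assuming standard PyTorch optimizer conventions (SGD-style `momentum` vs. Adam-style `betas`), is:

    def get_optimizer_momentum(optimizer):
        # Adam-like optimizers store their momentum as the first beta;
        # SGD-like optimizers expose a `momentum` default directly
        betas = optimizer.defaults.get("betas", None)
        momentum = optimizer.defaults.get("momentum", None)
        return betas[0] if betas is not None else momentum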
Example 2: adjusting epoch counters when resuming from a checkpoint and reporting progress at epoch start
    def on_epoch_start(self, state: RunnerState):
        # When resuming, bump the global epoch counter once at the start of the stage
        if self.checkpoint_resume and state.stage_epoch == 0:
            state.epoch += 1

        # Offset the stage epoch by the epoch stored in the checkpoint
        state.stage_epoch = state.stage_epoch + self.checkpoint_stage_epoch
        state.checkpoint_data = {'stage_epoch': state.stage_epoch}
        if self.master:
            # Report stage/epoch progress only from the master process
            if state.stage_epoch == 0:
                self.step.start(1, name=state.stage)

            self.step.start(2,
                            name=f'epoch {state.stage_epoch}',
                            index=state.stage_epoch)
Example 3: unsupervised consistency loss (MSE) between predictions on augmented and original inputs
    def on_batch_end(self, state: RunnerState):
        if not self.is_needed:
            return

        targets = state.input[self.target_key]
        loss_mask = (targets == self.unsupervised_label)  # Mask selecting unlabeled samples

        loss_mask = loss_mask.float()  # Convert to float so it can weight the loss
        aug_regression = state.output[self.output_key]

        with torch.no_grad():
            orig_image: torch.Tensor = state.input[self.input_key].detach()

            # Compute regression targets on the original (non-augmented) images
            training = state.model.training
            state.model.eval()
            ori_regression_tgt = state.model(orig_image)[self.output_key]
            state.model.train(training)

        aug_loss = F.mse_loss(aug_regression, ori_regression_tgt, reduction='none')

        loss = aug_loss * loss_mask
        loss = loss.sum() / loss_mask.sum().clamp_min(1)

        state.metrics.add_batch_value(metrics_dict={
            self.prefix: loss.item(),
        })

        self._add_loss_to_state(state, loss)
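The `_add_loss_to_state` method is not part of the snippet; a minimal sketch, assuming the callback simply accumulates a weighted auxiliary loss into `state.loss` (the exact storage depends on how the main criterion callback is configured, and the `loss_weight` attribute is hypothetical), could look like this:

    def _add_loss_to_state(self, state: RunnerState, loss):
        # Hypothetical weight attribute scaling the auxiliary loss
        loss = loss * getattr(self, "loss_weight", 1.0)
        if state.loss is None:
            state.loss = loss
        else:
            state.loss = state.loss + loss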
Example 4: batch-level mixup applied to the input fields at batch start
    def on_batch_start(self, state: RunnerState):
        if not self.is_needed:
            return

        # Apply mixup only to a fraction `p` of the batches
        is_batch_needed = np.random.random() < self.p

        if is_batch_needed:
            lam = np.random.beta(self.alpha, self.alpha)
        else:
            lam = 1

        # Store the permutation and lambda so the loss can be mixed later (see the sketch below)
        index = torch.randperm(state.input[self.fields[0]].shape[0]).to(state.device)
        state.input["mixup_index"] = index
        state.input["mixup_lambda"] = lam

        for f in self.fields:
            a = lam * state.input[f]
            b = (1 - lam) * state.input[f][index]
            state.input[f] = a + b
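The stored `mixup_index` and `mixup_lambda` are presumably consumed when the loss is computed; a minimal sketch of the standard mixup criterion, with an illustrative `targets` key that is not taken from the source, could be:

    def mixup_criterion(criterion, logits, state):
        # Standard mixup loss: blend the criterion over original and permuted targets
        lam = state.input["mixup_lambda"]
        index = state.input["mixup_index"]
        targets = state.input["targets"]  # illustrative key, not from the source
        return lam * criterion(logits, targets) + \
            (1 - lam) * criterion(logits, targets[index])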
Example 5: per-class mixup, mixing samples only within the same label
    def on_batch_start(self, state: RunnerState):
        if not self.is_needed:
            return

        targets = state.input[self.target_key]

        # Mix samples only within the same class (five classes are assumed here)
        for label_index in torch.arange(5):
            mask = targets == label_index
            lam = np.random.beta(self.alpha, self.alpha)

            # Permute only the samples belonging to the current class
            index = torch.randperm(int(mask.sum()))

            for f in self.fields:
                state.input[f][mask] = lam * state.input[f][mask] + \
                                       (1 - lam) * state.input[f][mask][index]
Example 6: mixup skipped when the mixing coefficient is too unbalanced
    def on_batch_start(self, state: RunnerState):
        if not self.is_needed:
            return

        if self.alpha > 0:
            self.lam = np.random.beta(self.alpha, self.alpha)
        else:
            self.lam = 1

        if self.lam < 0.3 or self.lam > 0.7:
            # Skip mixup when the mixing is too unbalanced (lambda close to 0 or 1)
            return

        self.index = torch.randperm(state.input[self.fields[0]].shape[0])
        self.index = self.index.to(state.device)

        for f in self.fields:
            state.input[f] = self.lam * state.input[f] + \
                             (1 - self.lam) * state.input[f][self.index]
Example 7: UDA-style consistency loss (KL divergence) with softmax temperature and confidence masking
    def on_batch_end(self, state: RunnerState):
        if not self.is_needed:
            return

        targets = state.input[self.target_key]
        loss_mask = (targets == self.unsupervised_label)  # Mask selecting unlabeled samples

        loss_mask = loss_mask.float()  # Convert to float so it can weight the loss
        aug_logits = state.output[self.output_key]

        with torch.no_grad():
            orig_image: torch.Tensor = state.input[self.input_key].detach()

            # Compute the target distribution from the original (non-augmented) images
            training = state.model.training
            state.model.eval()
            ori_logits_tgt = state.model(orig_image)[self.output_key]
            state.model.train(training)

            if self.softmax_temperature is not None:
                # Softmax temperature controlling. See Chapter 3.2
                ori_logits_tgt = ori_logits_tgt / self.softmax_temperature

        aug_loss = _kl_divergence_with_logits(p_logits=ori_logits_tgt,
                                              q_logits=aug_logits)

        if self.confidence_masking_threshold is not None:
            # Confidence-based masking. See Chapter 3.2
            ori_prob = F.log_softmax(ori_logits_tgt, dim=1).exp()
            max_prob, _ = torch.max(ori_prob, dim=1)
            max_prob_mask = (max_prob > self.confidence_masking_threshold).float()
            loss_mask = loss_mask * max_prob_mask

        loss = aug_loss * loss_mask
        loss = loss.sum() / loss_mask.sum().clamp_min(1)

        state.metrics.add_batch_value(metrics_dict={
            self.prefix: loss.item(),
        })

        self._add_loss_to_state(state, loss)
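`_kl_divergence_with_logits` is used but not defined in the snippet; a minimal sketch of the usual UDA-style helper, assuming it returns a per-sample KL divergence so it can be multiplied by `loss_mask`, is:

    def _kl_divergence_with_logits(p_logits, q_logits):
        # Per-sample KL(p || q) computed directly from logits
        p = F.softmax(p_logits, dim=-1)
        log_p = F.log_softmax(p_logits, dim=-1)
        log_q = F.log_softmax(q_logits, dim=-1)
        return (p * (log_p - log_q)).sum(dim=-1)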
Example 8: registering loggers in the runner state as a dictionary at stage start
    def on_stage_start(self, state: RunnerState):
        state.loggers = {
            'console': VerboseLogger(),
            'raise': RaiseExceptionLogger()
        }
Example 9: registering loggers as a list at stage start
    def on_stage_start(self, state: RunnerState):
        state.loggers = [VerboseLogger(), RaiseExceptionLogger()]
Example 10: registering a single console logger at stage start
    def on_stage_start(self, state: RunnerState):
        state.loggers = {'console': VerboseLogger()}