Example 1
    def on_epoch_start(self, state: State):
        # On resume, the checkpointed epoch has already run: advance the
        # global epoch once, then shift the stage epoch by the stored offset.
        if self.checkpoint_resume and state.stage_epoch == 0:
            state.epoch += 1

        state.stage_epoch = state.stage_epoch + self.checkpoint_stage_epoch
        state.checkpoint_data = {'stage_epoch': state.stage_epoch}
        # Only the master process reports progress.
        if self.master:
            self.step.start(1, name=state.stage)
            self.step.start(2,
                            name=f'epoch {state.stage_epoch}',
                            index=state.stage_epoch)
Example 2
    def on_batch_start(self, state: State):
        # Mixup is applied to training batches only.
        if not state.loader_name.startswith("train"):
            return

        data = state.input['features']
        inp1 = state.input['h1_targets']
        inp2 = state.input['h2_targets']
        inp3 = state.input['h3_targets']

        data, targets = mixup(data, inp1, inp2, inp3, self.alpha)

        state.input['features'] = data
        state.input['targets'] = targets
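The mixup helper called above is not shown. Below is a minimal sketch matching the call signature and following the standard mixup recipe; the packed (targets, shuffled_targets, lam) return format per head is an assumption, not the source's actual return type:

    import numpy as np
    import torch

    def mixup(data, t1, t2, t3, alpha):
        # Sample the mixing coefficient from Beta(alpha, alpha).
        lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
        index = torch.randperm(data.shape[0], device=data.device)
        mixed = lam * data + (1 - lam) * data[index]
        # Keep originals, shuffled partners and lam so each head's loss
        # can be mixed as lam * L(y_a) + (1 - lam) * L(y_b).
        targets = [(t, t[index], lam) for t in (t1, t2, t3)]
        return mixed, targets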
Example 3
    def on_batch_start(self, state: State):
        """
        Mixes data according to Cutmix algorithm.
        :param state: current state
        :return: void
        """
        if not self.is_needed:
            return

        if self.alpha > 0:
            self.lam = np.random.beta(self.alpha, self.alpha)
        else:
            self.lam = 1

        self.index = torch.randperm(state.batch_in[self.fields[0]].shape[0])
        self.index = self.index.to(state.device)  # .to() is not in-place

        bbx1, bby1, bbx2, bby2 = \
            self._rand_bbox(state.batch_in[self.fields[0]].shape, self.lam)

        for f in self.fields:
            state.batch_in[f][:, :, bbx1:bbx2, bby1:bby2] = \
                state.batch_in[f][self.index, :, bbx1:bbx2, bby1:bby2]

        self.lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) /
                        (state.batch_in[self.fields[0]].shape[-1] *
                         state.batch_in[self.fields[0]].shape[-2]))
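Neither self._rand_bbox here nor the module-level rand_bbox used in later examples is shown. Below is a minimal sketch following the reference CutMix box sampler, assuming NCHW batches: it cuts a box whose area fraction is approximately 1 - lam, which is why the caller recomputes lam from the clipped box afterwards.

    import numpy as np

    def rand_bbox(size, lam):
        # Box side lengths scale with sqrt(1 - lam) of the feature map.
        W, H = size[2], size[3]
        cut_rat = np.sqrt(1.0 - lam)
        cut_w, cut_h = int(W * cut_rat), int(H * cut_rat)
        # Uniformly random box center, clipped to the image borders.
        cx, cy = np.random.randint(W), np.random.randint(H)
        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)
        return bbx1, bby1, bbx2, bby2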
Example 4
    def on_batch_start(self, state: State):
        step = state.step - self.loader_step_start
        # Throttle: log at most once every 10 seconds, except for the
        # final batch of the loader.
        if self.last_batch_logged and \
                step != state.batch_size * state.loader_len:
            if (now() - self.last_batch_logged).total_seconds() < 10:
                return

        task = self.get_parent_task()
        task.batch_index = int(step / state.batch_size)
        task.batch_total = state.loader_len
        if task.batch_index > task.batch_total:
            state.step = state.batch_size
            task.batch_index = 1
        task.loader_name = state.loader_name

        duration = int((now() - self.loader_started_time).total_seconds())
        task.epoch_duration = duration
        # Linear ETA: project total time from elapsed time and batch progress;
        # max(..., 1) guards against division by zero on the first batch.
        task.epoch_time_remaining = int(
            duration *
            (task.batch_total / max(task.batch_index, 1))) - task.epoch_duration
        if state.loss is not None:
            # noinspection PyUnresolvedReferences
            task.loss = float(state.loss['loss'])

        self.task_provider.update()
        self.last_batch_logged = now()
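A standalone check of the ETA arithmetic above, with illustrative numbers:

    elapsed = 120          # seconds since the loader started
    done, total = 40, 100  # batches processed so far / total batches
    # Same formula as above: project total time linearly, subtract elapsed.
    eta = int(elapsed * (total / done)) - elapsed
    assert eta == 180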
Example 5
    def _compute_metric(self, state: State):
        criterion_kwargs = {
            self.real_data_criterion_key: state.batch_in[self.input_key],
            self.fake_data_criterion_key: state.batch_out[self.output_key],
            self.critic_criterion_key: state.model[self.critic_model_key],
            self.condition_args_criterion_key:
                [state.batch_in[key] for key in self.condition_keys],
        }
        criterion = state.get_attr("criterion", self.criterion_key)
        return criterion(**criterion_kwargs)
Example 6
    def do_cutmix(self, state: State):
        bbx1, bby1, bbx2, bby2 = rand_bbox(state.input[self.fields[0]].shape, self.lam)
        for f in self.fields:
            state.input[f][:, :, bbx1:bbx2, bby1:bby2] = state.input[f][
                self.index, :, bbx1:bbx2, bby1:bby2
            ]
        # Rescale lam to the patch area actually pasted after clipping.
        self.lam = 1 - (
            (bbx2 - bbx1)
            * (bby2 - bby1)
            / (
                state.input[self.fields[0]].shape[-1]
                * state.input[self.fields[0]].shape[-2]
            )
        )
Example 7
    def on_batch_start(self, state: State):
        if not self.is_needed:
            return

        if self.alpha > 0:
            self.lam = np.random.beta(self.alpha, self.alpha)
        else:
            self.lam = 1

        self.index = torch.randperm(state.batch_in[self.fields[0]].shape[0])
        self.index = self.index.to(state.device)  # .to() is not in-place

        for f in self.fields:
            state.batch_in[f] = self.lam * state.batch_in[f] + \
                                (1 - self.lam) * state.batch_in[f][self.index]
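A standalone demonstration of the blend above on a toy tensor (not from the source):

    import torch

    lam = 0.7
    x = torch.arange(8.0).reshape(4, 2)
    index = torch.randperm(x.shape[0])
    # Each row becomes a convex combination of itself and a permuted partner.
    mixed = lam * x + (1 - lam) * x[index]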
Example 8
    def on_batch_start(self, state: State):
        if not self.is_needed:
            return

        if self.alpha > 0:
            self.lam = np.random.beta(self.alpha, self.alpha)
        else:
            self.lam = 1

        self.index = torch.randperm(state.input[self.fields[0]].shape[0])
        self.index = self.index.to(state.device)  # .to() is not in-place

        bbx1, bby1, bbx2, bby2 = rand_bbox(
            state.input[self.input_key].size(), self.lam)
        for f in self.fields:
            state.input[f][:, :, bbx1:bbx2, bby1:bby2] = \
                state.input[f][self.index, :, bbx1:bbx2, bby1:bby2]
Example 9
    def on_epoch_end(self, state: State) -> None:
        if state.stage.startswith("infer"):
            return

        score = state.metric_manager.valid_values[self.metric]
        # Reset the counter on any improvement (or on the very first score).
        if self.best_score is None or self.is_better(score, self.best_score):
            self.num_bad_epochs = 0
            self.best_score = score
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            print(f"Early stop at epoch {state.stage_epoch}")
            state.early_stop = True
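is_better is not shown. A minimal sketch, assuming a minimize flag and a min_delta threshold (both names are illustrative, not from the source):

    def is_better(self, score, best):
        # Require improvement by at least min_delta in the chosen direction.
        if self.minimize:
            return score <= best - self.min_delta
        return score >= best + self.min_delta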
Example 10
    def on_batch_end(self, state: State):
        """On batch end event"""
        super().on_batch_end(state)
        if not state.need_backward:
            return

        optimizer = state.get_key(key="optimizer",
                                  inner_key=self.optimizer_key)

        need_gradient_step = \
            self._accumulation_counter % self.accumulation_steps == 0

        if need_gradient_step:
            # Clamp every parameter elementwise into
            # [-weight_clamp_value, weight_clamp_value].
            for group in optimizer.param_groups:
                for param in group["params"]:
                    param.data.clamp_(min=-self.weight_clamp_value,
                                      max=self.weight_clamp_value)
Example 11
    def on_batch_start(self, state: State):
        """Batch start hook.

        Args:
            state (State): current state
        """
        if not self.is_needed:
            return

        if self.alpha > 0:
            self.lam = np.random.beta(self.alpha, self.alpha)
        else:
            self.lam = 1

        self.index = torch.randperm(state.input[self.fields[0]].shape[0])
        self.index = self.index.to(state.device)  # .to() is not in-place

        for f in self.fields:
            state.input[f] = (self.lam * state.input[f] +
                              (1 - self.lam) * state.input[f][self.index])
Example 12
    def on_batch_end(self, state: State) -> None:
        """On batch end event.
        Args:
            state (State): current state
        """
        super().on_batch_end(state)
        if not state.is_train_loader:
            return

        optimizer = state.get_attr(key="optimizer",
                                   inner_key=self.optimizer_key)

        need_gradient_step = (self._accumulation_counter %
                              self.accumulation_steps == 0)

        if need_gradient_step:
            for group in optimizer.param_groups:
                for param in group["params"]:
                    param.data.clamp_(
                        min=-self.weight_clamp_value,
                        max=self.weight_clamp_value,
                    )
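A standalone sketch of the clamping loop on a toy model, with an illustrative clamp value (WGAN-style weight clipping is one common use of this pattern):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    clamp = 0.01
    for group in opt.param_groups:
        for p in group["params"]:
            # In-place elementwise clamp of every parameter tensor.
            p.data.clamp_(min=-clamp, max=clamp)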
Example 13
    def on_batch_start(self, state: State):
        # Stash the previous batch's metrics before the new batch overwrites them.
        state.prev_batch_metrics = state.batch_metrics
        super().on_batch_start(state)
Example 14
    def on_stage_start(self, state: State):
        # Replace the default loggers with a single console logger.
        state.loggers = {'console': VerboseLogger()}

    def do_mixup(self, state: State):
        # Convex combination of each sample with its permuted partner.
        for f in self.fields:
            state.input[f] = (
                self.lam * state.input[f] + (1 - self.lam) * state.input[f][self.index]
            )