Example #1
    def on_batch_end(self, state):
        # @TODO: check this one
        if not state.is_train:
            return

        self.accumulation_counter += 1

        if not self.fp16:
            # for _, value in state._optimizer.items():
            #     value.zero_grad()

            # @TODO: check this one
            if len(state._optimizer) > 0:
                for key, value in state.loss.items():
                    value.backward()

                # make an optimizer step only once every `accumulation_steps` batches
                if self.accumulation_counter % self.accumulation_steps == 0:
                    self.grad_step(state._optimizer)
                    state.model.zero_grad()
                    self.accumulation_counter = 0
        else:
            state.model.zero_grad()
            if len(state._optimizer) > 0:
                assert len(state._optimizer) == 1, \
                    "fp16 mode works only with one optimizer for now"

                # scale the loss so that small fp16 gradients do not underflow
                for key, value in state.loss.items():
                    scaled_loss = self.fp16_grad_scale * value.float()
                    scaled_loss.backward()

                # the optimizer holds the fp32 master weights,
                # while the model itself holds the fp16 working copy
                master_params = list(
                    state._optimizer["main"].param_groups[0]["params"])
                model_params = list(
                    filter(lambda p: p.requires_grad,
                           state.model.parameters()))

                # move the fp16 gradients onto the fp32 master params
                copy_grads(source=model_params, target=master_params)

                # undo the loss scaling before the optimizer step
                for param in master_params:
                    param.grad.data.mul_(1. / self.fp16_grad_scale)

                self.grad_step(state._optimizer)

                # push the updated fp32 master weights back into the fp16 model
                copy_params(source=master_params, target=model_params)
                torch.cuda.synchronize()
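
Both fp16 branches lean on `copy_grads` and `copy_params` helpers that are not shown in the snippet. Below is a minimal sketch of what such helpers could look like, assuming they only shuttle gradients and weights between two matching parameter lists (the names and signatures mirror the calls above; the bodies are an assumption, not the library's actual implementation):

def copy_grads(source, target):
    # hypothetical helper: move fp16 gradients onto the fp32 master params
    for src, dst in zip(source, target):
        if src.grad is not None:
            dst.grad = src.grad.detach().clone().float()

def copy_params(source, target):
    # hypothetical helper: copy the fp32 master weights back into the fp16 model
    for src, dst in zip(source, target):
        dst.data.copy_(src.data)
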
Example #2
    def on_batch_end(self, state):
        if not state.is_train:
            return

        self.accumulation_counter += 1
        if not self.fp16:
            model = state.model
            optimizer = state.get_key(key="optimizer",
                                      inner_key=self.optimizer_key)
            loss = state.get_key(key="loss", inner_key=self.loss_key)
            # gradients accumulate in .grad across batches until the optimizer step
            loss.backward()

            # make an optimizer step only once every `accumulation_steps` batches
            if self.accumulation_counter % self.accumulation_steps == 0:
                self.grad_step(optimizer=optimizer,
                               optimizer_wd=self.optimizer_wd,
                               grad_clip_fn=self.grad_clip_fn)
                model.zero_grad()
                self.accumulation_counter = 0
        else:
            model = state.model
            model.zero_grad()
            optimizer = state.get_key(key="optimizer",
                                      inner_key=self.optimizer_key)
            loss = state.get_key(key="loss", inner_key=self.loss_key)
            # scale the loss so that small fp16 gradients do not underflow
            scaled_loss = self.fp16_grad_scale * loss.float()
            scaled_loss.backward()

            # the optimizer holds the fp32 master weights, the model the fp16 copy
            master_params = list(optimizer.param_groups[0]["params"])
            model_params = list(
                filter(lambda p: p.requires_grad, model.parameters()))
            # move fp16 gradients onto the master params and undo the loss scaling
            copy_grads(source=model_params, target=master_params)
            for param in master_params:
                param.grad.data.mul_(1. / self.fp16_grad_scale)
            self.grad_step(optimizer=optimizer,
                           optimizer_wd=self.optimizer_wd,
                           grad_clip_fn=self.grad_clip_fn)
            copy_params(source=master_params, target=model_params)
            torch.cuda.synchronize()
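
Stripped of the callback machinery, the non-fp16 path in both examples is plain gradient accumulation. A minimal standalone sketch of the same pattern with an ordinary PyTorch model and optimizer follows (the model, data, and hyperparameters below are placeholders, not part of the examples above):

import torch

accumulation_steps = 4  # accumulate gradients over 4 mini-batches
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = torch.nn.CrossEntropyLoss()

for step in range(1, 17):
    features = torch.randn(8, 10)
    targets = torch.randint(0, 2, (8,))
    # backward() adds this batch's gradients on top of the ones already stored
    loss = criterion(model(features), targets)
    loss.backward()

    if step % accumulation_steps == 0:
        # one optimizer step per `accumulation_steps` batches,
        # then clear the accumulated gradients
        optimizer.step()
        model.zero_grad()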