Exemplo n.º 1
0
    def _handle_batch(self, batch):
        """Run one train/valid step for a classifier fed flattened inputs.

        ``batch`` is unpacked as ``(x, y)``. Each sample of ``x`` is
        flattened (the leading dimension is kept, the rest collapsed) before
        being passed to ``self.model``. Cross-entropy loss and top-1/top-3
        accuracy (computed by ``metrics.accuracy``) are merged into
        ``self.state.batch_metrics``. When the current loader is a training
        loader, the loss is backpropagated, the optimizer is stepped, and its
        gradients are zeroed.
        """
        inputs, targets = batch
        flat_inputs = inputs.view(inputs.size(0), -1)
        logits = self.model(flat_inputs)

        loss = F.cross_entropy(logits, targets)
        top1, top3 = metrics.accuracy(logits, targets, topk=(1, 3))

        state = self.state
        step_metrics = {
            "loss": loss,
            "accuracy01": top1,
            "accuracy03": top3,
        }
        state.batch_metrics.update(step_metrics)

        # Non-training loaders only record metrics; parameters stay untouched.
        if not state.is_train_loader:
            return

        loss.backward()
        state.optimizer.step()
        state.optimizer.zero_grad()
Exemplo n.º 2
0
    def _handle_batch(self, batch):
        """Run one train/valid step for a model whose output is indexed by key.

        ``batch`` is unpacked as ``(x, y)``. ``self.model(x)`` must return
        something subscriptable with ``'logits'``; only that entry is used
        here. Cross-entropy loss and top-1/top-2 accuracy (computed by
        ``metrics.accuracy``) replace ``self.batch_metrics`` wholesale. On a
        training loader the loss is backpropagated, then the optimizer is
        stepped and its gradients are zeroed.
        """
        inputs, targets = batch
        logits = self.model(inputs)['logits']

        loss = F.cross_entropy(logits, targets)
        top1, top2 = metrics.accuracy(logits, targets, topk=(1, 2))
        self.batch_metrics = dict(
            [
                ("loss", loss),
                ("accuracy01", top1),
                ("accuracy02", top2),
            ]
        )

        # Only training loaders update the parameters.
        if not self.is_train_loader:
            return

        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
Exemplo n.º 3
0
    def _handle_batch(self, batch):
        """Run one train/valid step for a joint classifier + autoencoder.

        ``batch`` is unpacked as ``(x, y)`` and ``x`` is flattened per sample
        (leading dimension kept). ``self.model`` is expected to return a pair
        ``(logits, reconstruction)``. The total loss is the sum of the
        classification cross-entropy and the MSE between the reconstruction
        and the flattened input. The component losses, the total, and
        top-1/3/5 accuracy (computed by ``metrics.accuracy``) replace
        ``self.batch_metrics``. On a training loader the total loss is
        backpropagated and the optimizer is stepped, then zeroed.
        """
        inputs, targets = batch
        flat_inputs = inputs.view(inputs.size(0), -1)
        logits, reconstruction = self.model(flat_inputs)

        clf_loss = F.cross_entropy(logits, targets)
        # The autoencoder branch is trained to reproduce its own flattened input.
        ae_loss = F.mse_loss(reconstruction, flat_inputs)
        total_loss = clf_loss + ae_loss
        top1, top3, top5 = metrics.accuracy(logits, targets, topk=(1, 3, 5))

        self.batch_metrics = {
            "loss_clf": clf_loss,
            "loss_ae": ae_loss,
            "loss": total_loss,
            "accuracy01": top1,
            "accuracy03": top3,
            "accuracy05": top5,
        }

        # Only training loaders update the parameters.
        if not self.is_train_loader:
            return

        total_loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
Exemplo n.º 4
0
    def _handle_batch(self, batch):
        """Run one train/valid step for a classifier + variational autoencoder.

        ``batch`` is unpacked as ``(x, y)`` and ``x`` is flattened per sample
        (leading dimension kept). ``self.model`` must return five values:
        ``(logits, reconstruction, z_logprob, loc, log_scale)``.

        The total loss adds, in this order:
          * cross-entropy of ``logits`` against ``y``;
          * MSE between the reconstruction and the flattened input;
          * a Gaussian KL-style term built from ``loc`` and ``log_scale``,
            weighted by 0.1;
          * the mean of ``z_logprob``, weighted by 0.01.

        All components, the total, and top-1/3/5 accuracy (computed by
        ``metrics.accuracy``) replace ``self.state.batch_metrics``. On a
        training loader the total loss is backpropagated and the optimizer is
        stepped, then zeroed.
        """
        inputs, targets = batch
        flat_inputs = inputs.view(inputs.size(0), -1)
        logits, reconstruction, z_logprob, loc, log_scale = self.model(
            flat_inputs
        )

        clf_loss = F.cross_entropy(logits, targets)
        ae_loss = F.mse_loss(reconstruction, flat_inputs)

        # NOTE(review): this is the closed-form KL(N(loc, sigma^2) || N(0, 1))
        # written for a *log-variance* input. If ``log_scale`` is actually a
        # log standard deviation, the usual form would use 2 * log_scale —
        # confirm against the model before relying on this term.
        kl_mean = torch.mean(1 + log_scale - loc.pow(2) - log_scale.exp())
        kld_loss = -0.5 * kl_mean * 0.1
        # presumably z_logprob is the log-density of the sampled latent — verify.
        logprob_loss = torch.mean(z_logprob) * 0.01

        total_loss = clf_loss + ae_loss + kld_loss + logprob_loss
        top1, top3, top5 = metrics.accuracy(logits, targets, topk=(1, 3, 5))

        state = self.state
        state.batch_metrics = {
            "loss_clf": clf_loss,
            "loss_ae": ae_loss,
            "loss_kld": kld_loss,
            "loss_logprob": logprob_loss,
            "loss": total_loss,
            "accuracy01": top1,
            "accuracy03": top3,
            "accuracy05": top5,
        }

        # Only training loaders update the parameters.
        if not state.is_train_loader:
            return

        total_loss.backward()
        state.optimizer.step()
        state.optimizer.zero_grad()