Example 1: DANN
    def compute_loss(self, batch, split_name="V"):
        if len(batch) == 3:
            raise NotImplementedError("DANN does not support semi-supervised setting.")
        (x_s, y_s), (x_tu, y_tu) = batch
        batch_size = len(y_s)

        _, y_hat, d_hat = self.forward(x_s)
        _, y_t_hat, d_t_hat = self.forward(x_tu)

        loss_cls, ok_src = losses.cross_entropy_logits(y_hat, y_s)
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)

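        # domain discrimination: source samples are labeled 0, target samples 1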
        loss_dmn_src, dok_src = losses.cross_entropy_logits(
            d_hat, torch.zeros(batch_size)
        )
        loss_dmn_tgt, dok_tgt = losses.cross_entropy_logits(
            d_t_hat, torch.ones(len(d_t_hat))
        )
        adv_loss = loss_dmn_src + loss_dmn_tgt
        task_loss = loss_cls

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": torch.cat((dok_src, dok_tgt)),
            f"{split_name}_source_domain_acc": dok_src,
            f"{split_name}_target_domain_acc": dok_tgt,
        }
        return task_loss, adv_loss, log_metrics
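
All four examples rely on losses.cross_entropy_logits, which evidently returns both the cross-entropy loss and a per-sample correctness mask (the masks are concatenated above to form the combined domain accuracy). The helper below is a minimal sketch inferred from that usage; the actual losses module may differ in details such as label casting or optional class weights.

import torch
import torch.nn.functional as F

def cross_entropy_logits(logits, targets):
    # Sketch inferred from usage: returns (mean CE loss, per-sample
    # correctness mask) so masks can be concatenated for logging.
    targets = targets.to(logits.device).long()
    loss = F.cross_entropy(logits, targets)
    ok = logits.argmax(dim=1).eq(targets)  # bool tensor, one entry per sample
    return loss, ok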
Example 2: WDGRL
    def compute_loss(self, batch, split_name="V"):
        if len(batch) == 3:
            raise NotImplementedError("WDGRL does not support semi-supervised setting.")
        (x_s, y_s), (x_tu, y_tu) = batch
        batch_size = len(y_s)

        _, y_hat, d_hat = self.forward(x_s)
        _, y_t_hat, d_t_hat = self.forward(x_tu)

        loss_cls, ok_src = losses.cross_entropy_logits(y_hat, y_s)
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)

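        # domain labels (0 = source, 1 = target) are used for accuracy logging only;
        # the raw critic scores drive the Wasserstein estimate below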
        _, dok_src = losses.cross_entropy_logits(d_hat, torch.zeros(batch_size))
        _, dok_tgt = losses.cross_entropy_logits(d_t_hat, torch.ones(len(d_t_hat)))

        wasserstein_distance = d_hat.mean() - (1 + self._beta_ratio) * d_t_hat.mean()
        adv_loss = wasserstein_distance
        task_loss = loss_cls

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": torch.cat((dok_src, dok_tgt)),
            f"{split_name}_source_domain_acc": dok_src,
            f"{split_name}_target_domain_acc": dok_tgt,
            f"{split_name}_wasserstein_dist": wasserstein_distance,
        }
        return task_loss, adv_loss, log_metrics
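
WDGRL trains the domain critic to estimate the Wasserstein distance between source and target features, which is only meaningful if the critic stays approximately 1-Lipschitz. The snippet above does not show how that constraint is enforced; the original WDGRL paper uses a WGAN-GP-style gradient penalty on interpolated features. The function below is a hedged sketch of such a penalty (the name gradient_penalty and the critic argument are illustrative, not taken from the code above); it assumes h_s and h_t are feature batches of equal size.

import torch

def gradient_penalty(critic, h_s, h_t):
    # Penalize deviations of the critic's gradient norm from 1 at random
    # interpolates between source and target features (WGAN-GP style).
    alpha = torch.rand(h_s.size(0), 1, device=h_s.device)
    interpolates = (alpha * h_s + (1 - alpha) * h_t).requires_grad_(True)
    scores = critic(interpolates)
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=interpolates,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )[0]
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()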
Example 3: semi-supervised DA (including MME)
    def compute_loss(self, batch, split_name="V"):
        assert len(batch) == 3
        (x_s, y_s), (x_tl, y_tl), (x_tu, y_tu) = batch
        batch_size = len(y_s)

        _, y_hat, d_hat = self.forward(x_s)
        _, y_tl_hat, d_tl_hat = self.forward(x_tl)
        _, y_tu_hat, d_tu_hat = self.forward(x_tu)
        d_target_pred = torch.cat((d_tl_hat, d_tu_hat))

        loss_cls_s, ok_src = losses.cross_entropy_logits(y_hat, y_s)
        loss_cls_tl, ok_tl = losses.cross_entropy_logits(y_tl_hat, y_tl)
        _, ok_tu = losses.cross_entropy_logits(y_tu_hat, y_tu)
        ok_tgt = torch.cat((ok_tl, ok_tu))

        if self.current_epoch < self._init_epochs:
            # the initial epochs do not use the few-shot (labeled target) term;
            # an ad-hoc choice, but it makes the methods more comparable to one another
            task_loss = loss_cls_s
        else:
            task_loss = (batch_size * loss_cls_s + len(y_tl) * loss_cls_tl) / (
                batch_size + len(y_tl)
            )

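        # domain discrimination: source (label 0) vs. all target data (label 1)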
        loss_dmn_src, dok_src = losses.cross_entropy_logits(
            d_hat, torch.zeros(batch_size)
        )
        loss_dmn_tgt, dok_tgt = losses.cross_entropy_logits(
            d_target_pred, torch.ones(len(d_target_pred))
        )

        if self._method is Method.MME:
            # keep the domain accuracies for logging only; MME replaces the
            # "domain" loss with the entropy of the unlabeled target predictions
            loss_dmn_src = 0
            loss_dmn_tgt = losses.entropy_logits_loss(y_tu_hat)

        adv_loss = loss_dmn_src + loss_dmn_tgt

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": torch.cat((dok_src, dok_tgt)),
            f"{split_name}_source_domain_acc": dok_src,
            f"{split_name}_target_domain_acc": dok_tgt,
        }
        return task_loss, adv_loss, log_metrics
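
When self._method is Method.MME, the adversarial term becomes the entropy of the classifier's predictions on the unlabeled target data (minimax entropy) rather than a domain-classification loss. losses.entropy_logits_loss is not shown above; the sketch below is a plausible implementation: the mean Shannon entropy of the softmax outputs, with a small epsilon for numerical stability.

import torch
import torch.nn.functional as F

def entropy_logits_loss(logits):
    # Mean Shannon entropy of the softmax predictions; MME minimaxes this
    # on unlabeled target data through the adversarial training setup.
    p = F.softmax(logits, dim=1)
    entropy = -(p * torch.log(p + 1e-5)).sum(dim=1)  # per-sample entropy
    return entropy.mean()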
Example 4: MMD
    def compute_loss(self, batch, split_name="V"):
        if len(batch) == 3:
            raise NotImplementedError("MMD does not support semi-supervised setting.")
        (x_s, y_s), (x_tu, y_tu) = batch

        phi_s, y_hat = self.forward(x_s)
        phi_t, y_t_hat = self.forward(x_tu)

        loss_cls, ok_src = losses.cross_entropy_logits(y_hat, y_s)
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)

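        # MMD between source and target representations serves as the alignment loss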
        mmd = self._compute_mmd(phi_s, phi_t, y_hat, y_t_hat)
        task_loss = loss_cls

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_mmd": mmd,
        }
        return task_loss, mmd, log_metrics
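
self._compute_mmd is not shown above; note that it receives the class logits as well as the features, which suggests a joint kernel over features and predictions (as in JAN) rather than a purely feature-level statistic. The function below is a minimal, hedged sketch of the basic building block only: a biased MMD^2 estimate with a single Gaussian kernel over features (the name gaussian_mmd and the fixed bandwidth sigma are illustrative assumptions).

import torch

def gaussian_mmd(phi_s, phi_t, sigma=1.0):
    # Biased MMD^2 estimate with one Gaussian kernel: large values mean the
    # source and target feature distributions are easy to tell apart.
    def kernel(a, b):
        d2 = torch.cdist(a, b).pow(2)  # pairwise squared distances
        return torch.exp(-d2 / (2 * sigma ** 2))
    return (
        kernel(phi_s, phi_s).mean()
        + kernel(phi_t, phi_t).mean()
        - 2 * kernel(phi_s, phi_t).mean()
    )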