Example 1: WDGRL
    def compute_loss(self, batch, split_name="V"):
        if len(batch) == 3:
            raise NotImplementedError("WDGRL does not support semi-supervised setting.")
        (x_s, y_s), (x_tu, y_tu) = batch
        batch_size = len(y_s)

        _, y_hat, d_hat = self.forward(x_s)
        _, y_t_hat, d_t_hat = self.forward(x_tu)

        loss_cls, ok_src = losses.cross_entropy_logits(y_hat, y_s)
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)

        _, dok_src = losses.cross_entropy_logits(d_hat, torch.zeros(batch_size))
        _, dok_tgt = losses.cross_entropy_logits(d_t_hat, torch.ones(len(d_t_hat)))

        wasserstein_distance = d_hat.mean() - (1 + self._beta_ratio) * d_t_hat.mean()
        adv_loss = wasserstein_distance
        task_loss = loss_cls

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": torch.cat((dok_src, dok_tgt)),
            f"{split_name}_source_domain_acc": dok_src,
            f"{split_name}_target_domain_acc": dok_tgt,
            f"{split_name}_wasserstein_dist": wasserstein_distance,
        }
        return task_loss, adv_loss, log_metrics
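All of these examples funnel logits through losses.cross_entropy_logits and log its second return value (the "ok"/"dok" tensors) as accuracy. A minimal sketch of such a helper, assuming (not verbatim from the library) that it returns the loss together with per-sample correctness, that the domain heads emit two logits, and that any weights are already normalized to sum to one:

    import torch
    import torch.nn.functional as F

    def cross_entropy_logits(logits, targets, weights=None):
        # Assumed contract: return (loss, per-sample boolean correctness).
        targets = targets.long().to(logits.device)
        correct = logits.argmax(dim=1).eq(targets)  # the "ok"/"dok" tensors above
        if weights is None:
            loss = F.cross_entropy(logits, targets)
        else:
            # weighted mean over per-sample losses (see the CDAN example below)
            per_sample = F.cross_entropy(logits, targets, reduction="none")
            loss = torch.sum(weights.to(logits.device) * per_sample)
        return loss, correct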
Example 2: DANN
    def compute_loss(self, batch, split_name="V"):
        if len(batch) == 3:
            raise NotImplementedError("DANN does not support semi-supervised setting.")
        (x_s, y_s), (x_tu, y_tu) = batch
        batch_size = len(y_s)

        _, y_hat, d_hat = self.forward(x_s)
        _, y_t_hat, d_t_hat = self.forward(x_tu)

        loss_cls, ok_src = losses.cross_entropy_logits(y_hat, y_s)
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)

        loss_dmn_src, dok_src = losses.cross_entropy_logits(
            d_hat, torch.zeros(batch_size)
        )
        loss_dmn_tgt, dok_tgt = losses.cross_entropy_logits(
            d_t_hat, torch.ones(len(d_t_hat))
        )
        adv_loss = loss_dmn_src + loss_dmn_tgt
        task_loss = loss_cls

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": torch.cat((dok_src, dok_tgt)),
            f"{split_name}_source_domain_acc": dok_src,
            f"{split_name}_target_domain_acc": dok_tgt,
        }
        return task_loss, adv_loss, log_metrics
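A single adv_loss can simultaneously train the domain discriminator and confuse the feature extractor only if d_hat is produced through a gradient reversal layer, which these snippets hide inside self.forward. The standard self-contained construction (an illustration, not necessarily this codebase's exact version):

    import torch

    class GradReverse(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, alpha):
            ctx.alpha = alpha
            return x.view_as(x)  # identity on the forward pass

        @staticmethod
        def backward(ctx, grad_output):
            # reverse and scale the gradient flowing back into the features
            return grad_output.neg() * ctx.alpha, None

    def grad_reverse(x, alpha=1.0):
        return GradReverse.apply(x, alpha)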
Example 3: semi-supervised adversarial DA (with an MME branch)
    def compute_loss(self, batch, split_name="V"):
        assert len(batch) == 3
        (x_s, y_s), (x_tl, y_tl), (x_tu, y_tu) = batch
        batch_size = len(y_s)

        _, y_hat, d_hat = self.forward(x_s)
        _, y_tl_hat, d_tl_hat = self.forward(x_tl)
        _, y_tu_hat, d_tu_hat = self.forward(x_tu)
        d_target_pred = torch.cat((d_tl_hat, d_tu_hat))

        loss_cls_s, ok_src = losses.cross_entropy_logits(y_hat, y_s)
        loss_cls_tl, ok_tl = losses.cross_entropy_logits(y_tl_hat, y_tl)
        _, ok_tu = losses.cross_entropy_logits(y_tu_hat, y_tu)
        ok_tgt = torch.cat((ok_tl, ok_tu))

        if self.current_epoch < self._init_epochs:
            # The init phase doesn't use few-shot learning. This is an ad-hoc
            # decision, but it makes the models more comparable with each other.
            task_loss = loss_cls_s
        else:
            task_loss = (batch_size * loss_cls_s + len(y_tl) * loss_cls_tl) / (
                batch_size + len(y_tl)
            )

        loss_dmn_src, dok_src = losses.cross_entropy_logits(
            d_hat, torch.zeros(batch_size)
        )
        loss_dmn_tgt, dok_tgt = losses.cross_entropy_logits(
            d_target_pred, torch.ones(len(d_target_pred))
        )

        if self._method is Method.MME:
            # only keep accuracy, overwrite "domain" loss
            loss_dmn_src = 0
            loss_dmn_tgt = losses.entropy_logits_loss(y_tu_hat)

        adv_loss = loss_dmn_src + loss_dmn_tgt

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": torch.cat((dok_src, dok_tgt)),
            f"{split_name}_source_domain_acc": dok_src,
            f"{split_name}_target_domain_acc": dok_tgt,
        }
        return task_loss, adv_loss, log_metrics
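In the Method.MME branch, the domain loss on unlabeled target samples is replaced by the entropy of their class predictions. A plausible sketch of losses.entropy_logits_loss, assuming it is the batch-mean Shannon entropy of the softmax outputs:

    import torch.nn.functional as F

    def entropy_logits_loss(logits):
        # H(p) = -sum_c p_c * log(p_c), averaged over the batch
        p = F.softmax(logits, dim=1)
        log_p = F.log_softmax(logits, dim=1)
        return -(p * log_p).sum(dim=1).mean()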
Example 4: multi-source MMD
    def compute_loss(self, batch, split_name="valid"):
        x, y, domain_labels = batch
        phi_x = self.forward(x)
        tgt_idx = torch.where(domain_labels == self.target_label)[0]
        n_src = len(self.src_domains)
        domain_dist = 0
        loss_cls = 0
        ok_src = []
        for src_domain in self.src_domains:
            src_domain_idx = torch.where(
                domain_labels == self.domain_to_idx[src_domain])[0]
            phi_src = self.domain_net[src_domain].forward(
                phi_x[src_domain_idx])
            phi_tgt = self.domain_net[src_domain].forward(phi_x[tgt_idx])
            kernels = losses.gaussian_kernel(
                phi_src,
                phi_tgt,
                kernel_mul=self._kernel_mul,
                kernel_num=self._kernel_num,
            )
            domain_dist += losses.compute_mmd_loss(kernels, len(phi_src))
            y_src_hat = self.classifiers[src_domain](phi_src)
            loss_cls_, ok_src_ = losses.cross_entropy_logits(
                y_src_hat, y[src_domain_idx])
            loss_cls += loss_cls_
            ok_src.append(ok_src_)

        domain_dist += self.cls_discrepancy(phi_x[tgt_idx])
        loss_cls = loss_cls / n_src
        ok_src = torch.cat(ok_src)

        y_tgt_hat = self._get_avg_cls_output(phi_x[tgt_idx])
        _, ok_tgt = losses.cross_entropy_logits(y_tgt_hat, y[tgt_idx])

        task_loss = loss_cls
        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": domain_dist,
        }

        return task_loss, domain_dist, log_metrics
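Example 4 leans on losses.gaussian_kernel and losses.compute_mmd_loss. A compact sketch of the usual multi-bandwidth construction these names suggest, assuming equal source/target batch sizes for the biased MMD^2 estimate (an illustration of the technique, not the library's exact code):

    import torch

    def gaussian_kernel(source, target, kernel_mul=2.0, kernel_num=5):
        # Sum of RBF kernel matrices over kernel_num bandwidths, centered on a
        # mean-distance heuristic and spaced by factors of kernel_mul.
        total = torch.cat([source, target], dim=0)
        n = total.size(0)
        l2 = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(dim=2)
        bandwidth = l2.detach().sum() / (n * n - n)
        bandwidth /= kernel_mul ** (kernel_num // 2)
        return sum(torch.exp(-l2 / (bandwidth * kernel_mul ** i)) for i in range(kernel_num))

    def compute_mmd_loss(kernels, batch_size):
        # Biased MMD^2: within-domain kernel blocks minus cross-domain blocks.
        xx = kernels[:batch_size, :batch_size]
        yy = kernels[batch_size:, batch_size:]
        xy = kernels[:batch_size, batch_size:]
        yx = kernels[batch_size:, :batch_size]
        return torch.mean(xx + yy - xy - yx)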
Example 5: shared classifier with a generic domain distance
    def compute_loss(self, batch, split_name="valid"):
        x, y, domain_labels = batch
        phi_x = self.forward(x)
        loss_dist = self._compute_domain_dist(phi_x, domain_labels)
        src_idx = torch.where(domain_labels != self.target_label)[0]
        tgt_idx = torch.where(domain_labels == self.target_label)[0]
        cls_output = self.classifier(phi_x)
        loss_cls, ok_src = losses.cross_entropy_logits(cls_output[src_idx],
                                                       y[src_idx])
        _, ok_tgt = losses.cross_entropy_logits(cls_output[tgt_idx],
                                                y[tgt_idx])

        task_loss = loss_cls
        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": loss_dist,
        }

        return task_loss, loss_dist, log_metrics
Example 6: MMD
    def compute_loss(self, batch, split_name="V"):
        if len(batch) == 3:
            raise NotImplementedError("MMD does not support semi-supervised setting.")
        (x_s, y_s), (x_tu, y_tu) = batch

        phi_s, y_hat = self.forward(x_s)
        phi_t, y_t_hat = self.forward(x_tu)

        loss_cls, ok_src = losses.cross_entropy_logits(y_hat, y_s)
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)

        mmd = self._compute_mmd(phi_s, phi_t, y_hat, y_t_hat)
        task_loss = loss_cls

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_mmd": mmd,
        }
        return task_loss, mmd, log_metrics
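Examples 6 and 7 call self._compute_mmd, whose body isn't shown. Given the kernel helpers sketched under Example 4, a hypothetical implementation consistent with these call sites:

    def _compute_mmd(self, phi_s, phi_t, y_hat, y_t_hat):
        # Hypothetical: plain feature-space MMD; joint/conditional variants
        # would also fold in the classifier outputs y_hat / y_t_hat.
        kernels = gaussian_kernel(phi_s, phi_t, kernel_mul=2.0, kernel_num=5)
        return compute_mmd_loss(kernels, len(phi_s))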
Example 7: MMD for joint RGB/flow video input
    def compute_loss(self, batch, split_name="valid"):
        # _s refers to source, _tu refers to unlabeled target
        if self.image_modality == "joint" and len(batch) == 4:
            ((x_s_rgb, y_s), (x_s_flow, y_s_flow),
             (x_tu_rgb, y_tu), (x_tu_flow, y_tu_flow)) = batch
            [phi_s_rgb, phi_s_flow], y_hat = self.forward({"rgb": x_s_rgb, "flow": x_s_flow})
            [phi_t_rgb, phi_t_flow], y_t_hat = self.forward({"rgb": x_tu_rgb, "flow": x_tu_flow})
            mmd_rgb = self._compute_mmd(phi_s_rgb, phi_t_rgb, y_hat, y_t_hat)
            mmd_flow = self._compute_mmd(phi_s_flow, phi_t_flow, y_hat, y_t_hat)
            mmd = mmd_rgb + mmd_flow
        elif self.image_modality in ["rgb", "flow"] and len(batch) == 2:
            (x_s, y_s), (x_tu, y_tu) = batch
            phi_s, y_hat = self.forward(x_s)
            phi_t, y_t_hat = self.forward(x_tu)
            mmd = self._compute_mmd(phi_s, phi_t, y_hat, y_t_hat)
        else:
            raise NotImplementedError(f"Batch len is {len(batch)}. Check the Dataloader.")

        # Uncomment when checking whether rgb & flow labels are equal.
        # print('rgb_s:{}, flow_s:{}, rgb_f:{}, flow_f:{}'.format(y_s, y_s_flow, y_tu, y_tu_flow))
        # print('equal: {}/{}'.format(torch.all(torch.eq(y_s, y_s_flow)), torch.all(torch.eq(y_tu, y_tu_flow))))

        # ok is abbreviation for (all) correct
        loss_cls, ok_src = losses.cross_entropy_logits(y_hat, y_s)
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)
        task_loss = loss_cls
        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": mmd,
        }
        return task_loss, mmd, log_metrics
Example 8: CDAN
    def compute_loss(self, batch, split_name="valid"):
        if len(batch) == 3:
            raise NotImplementedError(
                "CDAN does not support semi-supervised setting.")
        (x_s, y_s), (x_tu, y_tu) = batch
        batch_size = len(y_s)

        _, y_hat, d_hat = self.forward(x_s)
        _, y_t_hat, d_t_hat = self.forward(x_tu)

        loss_cls, ok_src = losses.cross_entropy_logits(y_hat, y_s)
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)

        if self.entropy:
            e_s = self._compute_entropy_weights(y_hat)
            e_t = self._compute_entropy_weights(y_t_hat)
            source_weight = e_s / torch.sum(e_s)
            target_weight = e_t / torch.sum(e_t)
        else:
            source_weight = None
            target_weight = None

        loss_dmn_src, dok_src = losses.cross_entropy_logits(
            d_hat, torch.zeros(batch_size), source_weight)
        loss_dmn_tgt, dok_tgt = losses.cross_entropy_logits(
            d_t_hat, torch.ones(len(d_t_hat)), target_weight)

        adv_loss = loss_dmn_src + loss_dmn_tgt
        task_loss = loss_cls

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": torch.cat((dok_src, dok_tgt)),
            f"{split_name}_source_domain_acc": dok_src,
            f"{split_name}_target_domain_acc": dok_tgt,
        }
        return task_loss, adv_loss, log_metrics
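The entropy branch in Example 8 follows the CDAN+E recipe: confidently classified samples get more weight in the domain loss. A sketch of the weighting under the standard w = 1 + exp(-H(p)) form (the in-library _compute_entropy_weights may differ):

    import torch

    def compute_entropy_weights(logits):
        # Per-sample Shannon entropy of the class posterior.
        p = torch.softmax(logits, dim=1)
        entropy = -(p * torch.log(p + 1e-12)).sum(dim=1)
        return 1.0 + torch.exp(-entropy)  # low entropy (confident) -> larger weight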
Example 9: multi-source classification-loss helper
    def _compute_cls_loss(self, x, y, domain_labels: torch.Tensor):
        if len(y) == 0:
            return 0.0, 0.0
        cls_loss = 0.0
        ok_src = []
        n_src = 0
        for domain_ in self.domain_to_idx.keys():
            if domain_ == self.target_domain:
                continue
            domain_idx = torch.where(domain_labels == self.domain_to_idx[domain_])[0]
            cls_output = self.classifiers[domain_](x[domain_idx])
            loss_cls_, ok_src_ = losses.cross_entropy_logits(cls_output, y[domain_idx])
            cls_loss += loss_cls_
            ok_src.append(ok_src_)
            n_src += 1
        cls_loss = cls_loss / n_src
        ok_src = torch.cat(ok_src)
        return cls_loss, ok_src
Example 10: multi-source moment matching
    def compute_loss(self, batch, split_name="valid"):
        x, y, domain_labels = batch
        phi_x = self.forward(x)
        moment_loss = self._compute_domain_dist(phi_x, domain_labels)
        src_idx = torch.where(domain_labels != self.target_label)[0]
        tgt_idx = torch.where(domain_labels == self.target_label)[0]
        cls_loss, ok_src = self._compute_cls_loss(phi_x[src_idx], y[src_idx],
                                                  domain_labels[src_idx])
        if len(tgt_idx) > 0:
            y_tgt_hat = _average_cls_output(phi_x[tgt_idx], self.classifiers)
            _, ok_tgt = losses.cross_entropy_logits(y_tgt_hat, y[tgt_idx])
        else:
            ok_tgt = 0.0

        task_loss = cls_loss
        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": moment_loss,
        }

        return task_loss, moment_loss, log_metrics
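Examples 4 and 10 score target samples by averaging the per-source classifier heads (_get_avg_cls_output / _average_cls_output). A hypothetical helper matching those call sites, assuming classifiers is a dict-like module mapping:

    import torch

    def _average_cls_output(phi, classifiers):
        # Average the class logits produced by each source-domain classifier.
        outputs = [clf(phi) for clf in classifiers.values()]
        return torch.stack(outputs, dim=0).mean(dim=0)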
Example 11: WDGRL for RGB/flow video input
    def compute_loss(self, batch, split_name="valid"):
        # _s refers to source, _tu refers to unlabeled target
        x_s_rgb = x_tu_rgb = x_s_flow = x_tu_flow = None
        if self.rgb:
            if self.flow:  # For joint input
                ((x_s_rgb, y_s), (x_s_flow, y_s_flow),
                 (x_tu_rgb, y_tu), (x_tu_flow, y_tu_flow)) = batch
            else:  # For rgb input
                (x_s_rgb, y_s), (x_tu_rgb, y_tu) = batch
        else:  # For flow input
            (x_s_flow, y_s), (x_tu_flow, y_tu) = batch

        _, y_hat, [d_hat_rgb, d_hat_flow] = self.forward({"rgb": x_s_rgb, "flow": x_s_flow})
        _, y_t_hat, [d_t_hat_rgb, d_t_hat_flow] = self.forward({"rgb": x_tu_rgb, "flow": x_tu_flow})
        batch_size = len(y_s)

        # ok is abbreviation for (all) correct, dok refers to domain correct
        if self.rgb:
            _, dok_src_rgb = losses.cross_entropy_logits(
                d_hat_rgb, torch.zeros(batch_size))
            _, dok_tgt_rgb = losses.cross_entropy_logits(
                d_t_hat_rgb, torch.ones(batch_size))
        if self.flow:
            _, dok_src_flow = losses.cross_entropy_logits(
                d_hat_flow, torch.zeros(batch_size))
            _, dok_tgt_flow = losses.cross_entropy_logits(
                d_t_hat_flow, torch.ones(batch_size))

        if self.rgb and self.flow:  # For joint input
            dok = torch.cat((dok_src_rgb, dok_src_flow, dok_tgt_rgb, dok_tgt_flow))
            dok_src = torch.cat((dok_src_rgb, dok_src_flow))
            dok_tgt = torch.cat((dok_tgt_rgb, dok_tgt_flow))
            wasserstein_distance_rgb = d_hat_rgb.mean() - (1 + self._beta_ratio) * d_t_hat_rgb.mean()
            wasserstein_distance_flow = d_hat_flow.mean() - (1 + self._beta_ratio) * d_t_hat_flow.mean()
            wasserstein_distance = (wasserstein_distance_rgb + wasserstein_distance_flow) / 2
        else:
            if self.rgb:  # For rgb input
                d_hat = d_hat_rgb
                d_t_hat = d_t_hat_rgb
                dok_src = dok_src_rgb
                dok_tgt = dok_tgt_rgb
            else:  # For flow input
                d_hat = d_hat_flow
                d_t_hat = d_t_hat_flow
                dok_src = dok_src_flow
                dok_tgt = dok_tgt_flow

            wasserstein_distance = d_hat.mean() - (1 + self._beta_ratio) * d_t_hat.mean()
            dok = torch.cat((dok_src, dok_tgt))

        loss_cls, ok_src = losses.cross_entropy_logits(y_hat, y_s)
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)
        adv_loss = wasserstein_distance
        task_loss = loss_cls

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": dok,
            f"{split_name}_source_domain_acc": dok_src,
            f"{split_name}_target_domain_acc": dok_tgt,
            f"{split_name}_wasserstein_dist": wasserstein_distance,
        }
        return task_loss, adv_loss, log_metrics
Example 12: entropy-weighted adversarial DA for RGB/flow video input
    def compute_loss(self, batch, split_name="valid"):
        # _s refers to source, _tu refers to unlabeled target
        x_s_rgb = x_tu_rgb = x_s_flow = x_tu_flow = None
        if self.rgb:
            if self.flow:  # For joint input
                ((x_s_rgb, y_s), (x_s_flow, y_s_flow),
                 (x_tu_rgb, y_tu), (x_tu_flow, y_tu_flow)) = batch
            else:  # For rgb input
                (x_s_rgb, y_s), (x_tu_rgb, y_tu) = batch
        else:  # For flow input
            (x_s_flow, y_s), (x_tu_flow, y_tu) = batch

        _, y_hat, [d_hat_rgb, d_hat_flow] = self.forward({"rgb": x_s_rgb, "flow": x_s_flow})
        _, y_t_hat, [d_t_hat_rgb, d_t_hat_flow] = self.forward({"rgb": x_tu_rgb, "flow": x_tu_flow})
        batch_size = len(y_s)

        if self.entropy:
            e_s = self._compute_entropy_weights(y_hat)
            e_t = self._compute_entropy_weights(y_t_hat)
            source_weight = e_s / torch.sum(e_s)
            target_weight = e_t / torch.sum(e_t)
        else:
            source_weight = None
            target_weight = None

        if self.rgb:
            loss_dmn_src_rgb, dok_src_rgb = losses.cross_entropy_logits(
                d_hat_rgb, torch.zeros(batch_size), source_weight)
            loss_dmn_tgt_rgb, dok_tgt_rgb = losses.cross_entropy_logits(
                d_t_hat_rgb, torch.ones(batch_size), target_weight)

        if self.flow:
            loss_dmn_src_flow, dok_src_flow = losses.cross_entropy_logits(
                d_hat_flow, torch.zeros(batch_size), source_weight)
            loss_dmn_tgt_flow, dok_tgt_flow = losses.cross_entropy_logits(
                d_t_hat_flow, torch.ones(batch_size), target_weight)

        # ok is abbreviation for (all) correct, dok refers to domain correct
        if self.rgb and self.flow:  # For joint input
            loss_dmn_src = loss_dmn_src_rgb + loss_dmn_src_flow
            loss_dmn_tgt = loss_dmn_tgt_rgb + loss_dmn_tgt_flow
            dok = torch.cat((dok_src_rgb, dok_src_flow, dok_tgt_rgb, dok_tgt_flow))
            dok_src = torch.cat((dok_src_rgb, dok_src_flow))
            dok_tgt = torch.cat((dok_tgt_rgb, dok_tgt_flow))
        else:
            if self.rgb:  # For rgb input
                d_hat = d_hat_rgb
                d_t_hat = d_t_hat_rgb
            else:  # For flow input
                d_hat = d_hat_flow
                d_t_hat = d_t_hat_flow

            loss_dmn_src, dok_src = losses.cross_entropy_logits(
                d_hat, torch.zeros(batch_size))
            loss_dmn_tgt, dok_tgt = losses.cross_entropy_logits(
                d_t_hat, torch.ones(batch_size))
            dok = torch.cat((dok_src, dok_tgt))

        loss_cls, ok_src = losses.cross_entropy_logits(y_hat, y_s)
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)
        adv_loss = loss_dmn_src + loss_dmn_tgt
        task_loss = loss_cls

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": dok,
            f"{split_name}_source_domain_acc": dok_src,
            f"{split_name}_target_domain_acc": dok_tgt,
        }

        return task_loss, adv_loss, log_metrics
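All twelve examples share one contract: compute_loss returns (task_loss, adversarial_loss_or_distance, log_metrics). As a closing illustration, a hypothetical PyTorch Lightning-style training_step consuming that triple; lamb_da stands in for whatever trade-off weight or schedule the model actually uses:

    def training_step(self, batch, batch_idx):
        task_loss, adv_loss, log_metrics = self.compute_loss(batch, split_name="train")
        loss = task_loss + self.lamb_da * adv_loss  # lamb_da: assumed trade-off weight
        for name, metric in log_metrics.items():
            if isinstance(metric, torch.Tensor) and metric.dim() > 0:
                metric = metric.float().mean()  # boolean correctness -> accuracy
            self.log(name, metric)
        return loss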