def compute_loss(self, batch, split_name="V"):
    """Return the DANN task loss, adversarial loss and logging metrics.

    Args:
        batch: A pair ``((x_s, y_s), (x_tu, y_tu))``: source inputs and
            labels followed by target inputs and labels. The target labels
            only feed the logged target accuracy; their loss is discarded.
        split_name: Prefix prepended to every key of the metrics dict.

    Returns:
        A tuple ``(task_loss, adv_loss, log_metrics)`` where ``task_loss``
        is the source classification loss, ``adv_loss`` is the sum of the
        source and target domain losses, and ``log_metrics`` maps
        ``"{split_name}_..."`` keys to the second value returned by
        ``losses.cross_entropy_logits`` (presumably per-sample correctness;
        the helper is defined elsewhere).

    Raises:
        NotImplementedError: If ``batch`` has three elements, i.e. the
            semi-supervised layout.
    """
    if len(batch) == 3:
        raise NotImplementedError("DANN does not support semi-supervised setting.")

    source_pair, target_pair = batch
    source_x, source_y = source_pair
    target_x, target_y = target_pair
    n_source = len(source_y)

    # forward() yields three outputs; the first one is not needed here.
    _, source_logits, source_domain_out = self.forward(source_x)
    _, target_logits, target_domain_out = self.forward(target_x)

    task_loss, source_correct = losses.cross_entropy_logits(source_logits, source_y)
    _, target_correct = losses.cross_entropy_logits(target_logits, target_y)

    # Domain labels: 0 for every source sample, 1 for every target sample.
    source_domain_labels = torch.zeros(n_source)
    target_domain_labels = torch.ones(len(target_domain_out))
    source_domain_loss, source_domain_correct = losses.cross_entropy_logits(
        source_domain_out, source_domain_labels
    )
    target_domain_loss, target_domain_correct = losses.cross_entropy_logits(
        target_domain_out, target_domain_labels
    )
    adv_loss = source_domain_loss + target_domain_loss

    all_domain_correct = torch.cat((source_domain_correct, target_domain_correct))
    log_metrics = {
        f"{split_name}_source_acc": source_correct,
        f"{split_name}_target_acc": target_correct,
        f"{split_name}_domain_acc": all_domain_correct,
        f"{split_name}_source_domain_acc": source_domain_correct,
        f"{split_name}_target_domain_acc": target_domain_correct,
    }
    return task_loss, adv_loss, log_metrics
def compute_loss(self, batch, split_name="V"):
    """Return the WDGRL task loss, adversarial loss and logging metrics.

    Args:
        batch: A pair ``((x_s, y_s), (x_tu, y_tu))``: source inputs and
            labels followed by target inputs and labels. The target labels
            only feed the logged target accuracy; their loss is discarded.
        split_name: Prefix prepended to every key of the metrics dict.

    Returns:
        A tuple ``(task_loss, adv_loss, log_metrics)``. ``task_loss`` is the
        source classification loss. ``adv_loss`` is
        ``mean(d_source) - (1 + self._beta_ratio) * mean(d_target)`` computed
        on the third output of ``forward``; the same value is also logged
        under ``"{split_name}_wasserstein_dist"``.

    Raises:
        NotImplementedError: If ``batch`` has three elements, i.e. the
            semi-supervised layout.
    """
    if len(batch) == 3:
        raise NotImplementedError("WDGRL does not support semi-supervised setting.")

    source_pair, target_pair = batch
    source_x, source_y = source_pair
    target_x, target_y = target_pair
    n_source = len(source_y)

    # forward() yields three outputs; the first one is not needed here.
    _, source_logits, source_domain_out = self.forward(source_x)
    _, target_logits, target_domain_out = self.forward(target_x)

    task_loss, source_correct = losses.cross_entropy_logits(source_logits, source_y)
    _, target_correct = losses.cross_entropy_logits(target_logits, target_y)

    # The cross-entropy on the domain outputs is evaluated for its accuracy
    # output only (labels: 0 = source, 1 = target); its loss is dropped.
    _, source_domain_correct = losses.cross_entropy_logits(
        source_domain_out, torch.zeros(n_source)
    )
    _, target_domain_correct = losses.cross_entropy_logits(
        target_domain_out, torch.ones(len(target_domain_out))
    )

    source_mean = source_domain_out.mean()
    target_mean = target_domain_out.mean()
    wasserstein_distance = source_mean - (1 + self._beta_ratio) * target_mean

    all_domain_correct = torch.cat((source_domain_correct, target_domain_correct))
    log_metrics = {
        f"{split_name}_source_acc": source_correct,
        f"{split_name}_target_acc": target_correct,
        f"{split_name}_domain_acc": all_domain_correct,
        f"{split_name}_source_domain_acc": source_domain_correct,
        f"{split_name}_target_domain_acc": target_domain_correct,
        f"{split_name}_wasserstein_dist": wasserstein_distance,
    }
    return task_loss, wasserstein_distance, log_metrics
def compute_loss(self, batch, split_name="V"):
    """Return the semi-supervised task loss, adversarial loss and metrics.

    Args:
        batch: A triple ``((x_s, y_s), (x_tl, y_tl), (x_tu, y_tu))``:
            source data, labelled target data and unlabelled target data.
            ``y_tu`` only feeds the logged target accuracy.
        split_name: Prefix prepended to every key of the metrics dict.

    Returns:
        A tuple ``(task_loss, adv_loss, log_metrics)``.

        * ``task_loss`` is the source classification loss while
          ``self.current_epoch < self._init_epochs``; afterwards it is the
          average of the source and labelled-target classification losses
          weighted by their respective number of samples.
        * ``adv_loss`` is the sum of the source and target domain losses,
          or, when ``self._method is Method.MME``, the value of
          ``losses.entropy_logits_loss`` on the unlabelled target logits.
        * ``log_metrics`` maps ``"{split_name}_..."`` keys to the second
          value returned by ``losses.cross_entropy_logits``.

    Raises:
        ValueError: If ``batch`` does not contain exactly three elements.
    """
    # Explicit check instead of ``assert``: assertions are stripped under
    # ``python -O``, which would turn a wrong batch layout into an opaque
    # tuple-unpacking error on the next line.
    if len(batch) != 3:
        raise ValueError(
            "Expected a semi-supervised batch of 3 elements "
            "(source, labelled target, unlabelled target), "
            f"got {len(batch)}."
        )
    (x_s, y_s), (x_tl, y_tl), (x_tu, y_tu) = batch
    batch_size = len(y_s)

    # forward() yields three outputs; the first one is not needed here.
    _, y_hat, d_hat = self.forward(x_s)
    _, y_tl_hat, d_tl_hat = self.forward(x_tl)
    _, y_tu_hat, d_tu_hat = self.forward(x_tu)
    # Labelled and unlabelled target samples share the same domain label.
    d_target_pred = torch.cat((d_tl_hat, d_tu_hat))

    loss_cls_s, ok_src = losses.cross_entropy_logits(y_hat, y_s)
    loss_cls_tl, ok_tl = losses.cross_entropy_logits(y_tl_hat, y_tl)
    _, ok_tu = losses.cross_entropy_logits(y_tu_hat, y_tu)
    ok_tgt = torch.cat((ok_tl, ok_tu))

    if self.current_epoch < self._init_epochs:
        # init phase doesn't use few-shot learning
        # ad-hoc decision but makes models more comparable between each other
        task_loss = loss_cls_s
    else:
        # Sample-weighted average of the two supervised losses.
        task_loss = (batch_size * loss_cls_s + len(y_tl) * loss_cls_tl) / (
            batch_size + len(y_tl)
        )

    # Domain labels: 0 for every source sample, 1 for every target sample.
    loss_dmn_src, dok_src = losses.cross_entropy_logits(
        d_hat, torch.zeros(batch_size)
    )
    loss_dmn_tgt, dok_tgt = losses.cross_entropy_logits(
        d_target_pred, torch.ones(len(d_target_pred))
    )

    if self._method is Method.MME:
        # only keep accuracy, overwrite "domain" loss
        loss_dmn_src = 0
        loss_dmn_tgt = losses.entropy_logits_loss(y_tu_hat)

    adv_loss = loss_dmn_src + loss_dmn_tgt

    log_metrics = {
        f"{split_name}_source_acc": ok_src,
        f"{split_name}_target_acc": ok_tgt,
        f"{split_name}_domain_acc": torch.cat((dok_src, dok_tgt)),
        f"{split_name}_source_domain_acc": dok_src,
        f"{split_name}_target_domain_acc": dok_tgt,
    }
    return task_loss, adv_loss, log_metrics
def compute_loss(self, batch, split_name="V"):
    """Return the MMD task loss, the MMD term and logging metrics.

    Args:
        batch: A pair ``((x_s, y_s), (x_tu, y_tu))``: source inputs and
            labels followed by target inputs and labels. The target labels
            only feed the logged target accuracy; their loss is discarded.
        split_name: Prefix prepended to every key of the metrics dict.

    Returns:
        A tuple ``(task_loss, mmd, log_metrics)`` where ``task_loss`` is the
        source classification loss and ``mmd`` is the value returned by
        ``self._compute_mmd`` for the two ``forward`` outputs of the source
        and target inputs; ``mmd`` is also logged as ``"{split_name}_mmd"``.

    Raises:
        NotImplementedError: If ``batch`` has three elements, i.e. the
            semi-supervised layout.
    """
    if len(batch) == 3:
        raise NotImplementedError("MMD does not support semi-supervised setting.")

    source_pair, target_pair = batch
    source_x, source_y = source_pair
    target_x, target_y = target_pair

    # Here forward() yields two outputs: phi and the class logits.
    source_phi, source_logits = self.forward(source_x)
    target_phi, target_logits = self.forward(target_x)

    task_loss, source_correct = losses.cross_entropy_logits(source_logits, source_y)
    _, target_correct = losses.cross_entropy_logits(target_logits, target_y)

    mmd = self._compute_mmd(source_phi, target_phi, source_logits, target_logits)

    log_metrics = {
        f"{split_name}_source_acc": source_correct,
        f"{split_name}_target_acc": target_correct,
        f"{split_name}_mmd": mmd,
    }
    return task_loss, mmd, log_metrics