def _compute_mmd(self, phi_s, phi_t, y_hat, y_t_hat):
    """Return the MMD loss between source and target feature batches.

    A multi-bandwidth Gaussian kernel is built over ``phi_s`` and ``phi_t``
    with ``losses.gaussian_kernel`` and reduced with
    ``losses.compute_mmd_loss``.

    Args:
        phi_s: Source features; the size of the first dimension is passed
            to ``compute_mmd_loss`` as the batch size.
        phi_t: Target features.
        y_hat: Source classifier outputs. Unused here; presumably accepted so
            this method shares a signature with other ``_compute_mmd``
            implementations.
        y_t_hat: Target classifier outputs. Unused here.

    Returns:
        Whatever ``losses.compute_mmd_loss`` returns for the kernel matrix.
    """
    n_source = int(phi_s.shape[0])
    kernel_matrix = losses.gaussian_kernel(
        phi_s,
        phi_t,
        kernel_mul=self._kernel_mul,
        kernel_num=self._kernel_num,
    )
    return losses.compute_mmd_loss(kernel_matrix, n_source)
def _compute_mmd(self, phi_s, phi_t, y_hat, y_t_hat):
    """Return the joint MMD loss over features and class probabilities.

    Two Gaussian kernels are built with ``losses.gaussian_kernel``:

    * one over the features (``phi_s`` vs ``phi_t``);
    * one over the softmax of the classifier outputs (``y_hat`` vs
      ``y_t_hat``).

    Their element-wise product is the joint kernel, which is reduced with
    ``losses.compute_mmd_loss``.

    Args:
        phi_s: Source features; the size of the first dimension is passed
            to ``compute_mmd_loss`` as the batch size.
        phi_t: Target features.
        y_hat: Source classifier outputs; softmax is applied over the last
            dimension.
        y_t_hat: Target classifier outputs; softmax is applied over the last
            dimension.

    Returns:
        Whatever ``losses.compute_mmd_loss`` returns for the joint kernel.

    Raises:
        ValueError: If ``self._kernel_mul`` or ``self._kernel_num`` has fewer
            than two entries. Each of the two kernels needs its own entry;
            previously ``zip`` silently dropped the label kernel in that case.
    """
    # Per-kernel fixed bandwidths, in the order (features, probabilities).
    # The feature kernel passes fix_sigma=None; presumably gaussian_kernel
    # then derives the bandwidth from the data (confirm in `losses`). The
    # probability kernel uses the fixed value 1.68 kept from the original
    # implementation; its origin is not documented here.
    fix_sigmas = (None, 1.68)
    for attr_name, values in (
        ("_kernel_mul", self._kernel_mul),
        ("_kernel_num", self._kernel_num),
    ):
        if len(values) < len(fix_sigmas):
            raise ValueError(
                f"joint MMD needs one {attr_name} entry per kernel: "
                f"expected at least {len(fix_sigmas)}, got {len(values)}"
            )

    # Functional softmax: same result as torch.nn.Softmax(dim=-1) without
    # constructing a throwaway Module on every call.
    source_list = [phi_s, torch.nn.functional.softmax(y_hat, dim=-1)]
    target_list = [phi_t, torch.nn.functional.softmax(y_t_hat, dim=-1)]
    batch_size = int(phi_s.size()[0])

    joint_kernels = None
    for source, target, k_mul, k_num, sigma in zip(
        source_list, target_list, self._kernel_mul, self._kernel_num, fix_sigmas
    ):
        kernels = losses.gaussian_kernel(
            source, target, kernel_mul=k_mul, kernel_num=k_num, fix_sigma=sigma
        )
        # Joint kernel = element-wise product of the per-representation kernels.
        joint_kernels = kernels if joint_kernels is None else joint_kernels * kernels
    return losses.compute_mmd_loss(joint_kernels, batch_size)
def compute_loss(self, batch, split_name="valid"):
    """Compute the multi-source classification loss and domain distance.

    The shared ``self.forward`` is applied to every input in the batch. For
    each domain in ``self.src_domains``, the source rows and the target rows
    are passed through that domain's ``self.domain_net`` branch. The MMD
    between the two is added to the domain distance, and the source rows are
    classified with the domain's ``self.classifiers`` entry. Finally,
    ``self.cls_discrepancy`` on the target rows is added to the distance.

    Args:
        batch: A 3-tuple ``(x, y, domain_labels)``. Rows whose domain label
            equals ``self.target_label`` are treated as target samples; rows
            equal to ``self.domain_to_idx[d]`` belong to source domain ``d``.
        split_name: Prefix for the keys of the returned metrics dict.

    Returns:
        A tuple ``(task_loss, domain_dist, log_metrics)``:

        * ``task_loss`` is the source cross-entropy averaged over the number
          of source domains.
        * ``domain_dist`` is the summed per-source MMD plus the classifier
          discrepancy on target rows.
        * ``log_metrics`` holds the per-sample results from
          ``losses.cross_entropy_logits`` for source and target.

    Note:
        The ``f"{split_name}_domain_acc"`` entry holds ``domain_dist``, not an
        accuracy. The key is kept unchanged in case downstream logging relies
        on it.
    """
    inputs, labels, domain_labels = batch
    features = self.forward(inputs)
    tgt_idx = torch.where(domain_labels == self.target_label)[0]
    num_sources = len(self.src_domains)

    domain_dist = 0
    cls_loss_total = 0
    src_correct = []
    for domain in self.src_domains:
        src_rows = torch.where(domain_labels == self.domain_to_idx[domain])[0]
        branch = self.domain_net[domain]
        phi_src = branch.forward(features[src_rows])
        phi_tgt = branch.forward(features[tgt_idx])
        kernel_matrix = losses.gaussian_kernel(
            phi_src,
            phi_tgt,
            kernel_mul=self._kernel_mul,
            kernel_num=self._kernel_num,
        )
        # NOTE(review): the number of source rows is passed as the batch size.
        # This presumably assumes each source domain contributes as many rows
        # as the target; confirm against losses.compute_mmd_loss.
        domain_dist += losses.compute_mmd_loss(kernel_matrix, len(phi_src))
        src_logits = self.classifiers[domain](phi_src)
        domain_cls_loss, domain_correct = losses.cross_entropy_logits(
            src_logits, labels[src_rows]
        )
        cls_loss_total += domain_cls_loss
        src_correct.append(domain_correct)

    domain_dist += self.cls_discrepancy(features[tgt_idx])
    task_loss = cls_loss_total / num_sources
    src_correct = torch.cat(src_correct)

    # Target predictions come from the averaged classifier outputs.
    tgt_logits = self._get_avg_cls_output(features[tgt_idx])
    _, tgt_correct = losses.cross_entropy_logits(tgt_logits, labels[tgt_idx])

    log_metrics = {
        f"{split_name}_source_acc": src_correct,
        f"{split_name}_target_acc": tgt_correct,
        f"{split_name}_domain_acc": domain_dist,
    }
    return task_loss, domain_dist, log_metrics