def compute_loss(self, batch, split_name="V"):
    """Compute the WDGRL task loss, critic (Wasserstein) loss, and metrics.

    ``batch`` must be a pair ((x_s, y_s), (x_tu, y_tu)); a 3-tuple
    (semi-supervised) is rejected explicitly.
    """
    if len(batch) == 3:
        raise NotImplementedError("WDGRL does not support semi-supervised setting.")
    (x_s, y_s), (x_tu, y_tu) = batch
    n_source = len(y_s)

    _, y_hat, d_hat = self.forward(x_s)
    _, y_t_hat, d_t_hat = self.forward(x_tu)

    # Classification loss on labelled source; target labels only score accuracy.
    task_loss, ok_src = losses.cross_entropy_logits(y_hat, y_s)
    _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)

    # Domain accuracies only (losses discarded): source labelled 0, target 1.
    _, dok_src = losses.cross_entropy_logits(d_hat, torch.zeros(n_source))
    _, dok_tgt = losses.cross_entropy_logits(d_t_hat, torch.ones(len(d_t_hat)))

    # Critic objective: E[d(source)] - (1 + beta) * E[d(target)].
    wasserstein_distance = d_hat.mean() - (1 + self._beta_ratio) * d_t_hat.mean()
    adv_loss = wasserstein_distance

    log_metrics = {
        f"{split_name}_source_acc": ok_src,
        f"{split_name}_target_acc": ok_tgt,
        f"{split_name}_domain_acc": torch.cat((dok_src, dok_tgt)),
        f"{split_name}_source_domain_acc": dok_src,
        f"{split_name}_target_domain_acc": dok_tgt,
        f"{split_name}_wasserstein_dist": wasserstein_distance,
    }
    return task_loss, adv_loss, log_metrics
def compute_loss(self, batch, split_name="V"):
    """Compute the DANN task loss, adversarial (domain) loss, and metrics."""
    if len(batch) == 3:
        raise NotImplementedError("DANN does not support semi-supervised setting.")
    (x_s, y_s), (x_tu, y_tu) = batch
    n_source = len(y_s)

    _, y_hat, d_hat = self.forward(x_s)
    _, y_t_hat, d_t_hat = self.forward(x_tu)

    # Classification loss on labelled source; target labels only score accuracy.
    task_loss, ok_src = losses.cross_entropy_logits(y_hat, y_s)
    _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)

    # Domain discriminator: source labelled 0, target labelled 1.
    loss_dmn_src, dok_src = losses.cross_entropy_logits(d_hat, torch.zeros(n_source))
    loss_dmn_tgt, dok_tgt = losses.cross_entropy_logits(
        d_t_hat, torch.ones(len(d_t_hat))
    )
    adv_loss = loss_dmn_src + loss_dmn_tgt

    log_metrics = {
        f"{split_name}_source_acc": ok_src,
        f"{split_name}_target_acc": ok_tgt,
        f"{split_name}_domain_acc": torch.cat((dok_src, dok_tgt)),
        f"{split_name}_source_domain_acc": dok_src,
        f"{split_name}_target_domain_acc": dok_tgt,
    }
    return task_loss, adv_loss, log_metrics
def compute_loss(self, batch, split_name="V"):
    """Semi-supervised loss: source + few labelled target + unlabelled target.

    ``batch`` must be a 3-tuple ((x_s, y_s), (x_tl, y_tl), (x_tu, y_tu)).
    """
    assert len(batch) == 3
    (x_s, y_s), (x_tl, y_tl), (x_tu, y_tu) = batch
    n_source = len(y_s)

    _, y_hat, d_hat = self.forward(x_s)
    _, y_tl_hat, d_tl_hat = self.forward(x_tl)
    _, y_tu_hat, d_tu_hat = self.forward(x_tu)
    d_target_pred = torch.cat((d_tl_hat, d_tu_hat))

    loss_cls_s, ok_src = losses.cross_entropy_logits(y_hat, y_s)
    loss_cls_tl, ok_tl = losses.cross_entropy_logits(y_tl_hat, y_tl)
    _, ok_tu = losses.cross_entropy_logits(y_tu_hat, y_tu)
    ok_tgt = torch.cat((ok_tl, ok_tu))

    if self.current_epoch < self._init_epochs:
        # init phase doesn't use few-shot learning
        # ad-hoc decision but makes models more comparable between each other
        task_loss = loss_cls_s
    else:
        # Sample-count-weighted mix of source and labelled-target losses.
        task_loss = (n_source * loss_cls_s + len(y_tl) * loss_cls_tl) / (
            n_source + len(y_tl)
        )

    # Domain discriminator: source labelled 0, all target (tl + tu) labelled 1.
    loss_dmn_src, dok_src = losses.cross_entropy_logits(
        d_hat, torch.zeros(n_source)
    )
    loss_dmn_tgt, dok_tgt = losses.cross_entropy_logits(
        d_target_pred, torch.ones(len(d_target_pred))
    )
    if self._method is Method.MME:
        # only keep accuracy, overwrite "domain" loss
        loss_dmn_src = 0
        loss_dmn_tgt = losses.entropy_logits_loss(y_tu_hat)
    adv_loss = loss_dmn_src + loss_dmn_tgt

    log_metrics = {
        f"{split_name}_source_acc": ok_src,
        f"{split_name}_target_acc": ok_tgt,
        f"{split_name}_domain_acc": torch.cat((dok_src, dok_tgt)),
        f"{split_name}_source_domain_acc": dok_src,
        f"{split_name}_target_domain_acc": dok_tgt,
    }
    return task_loss, adv_loss, log_metrics
def compute_loss(self, batch, split_name="valid"):
    """Multi-source loss: per-source MMD to the target plus classifier CE.

    The adversarial term is the summed per-source MMD plus the inter-classifier
    discrepancy on target features.
    """
    x, y, domain_labels = batch
    phi_x = self.forward(x)
    tgt_idx = torch.where(domain_labels == self.target_label)[0]

    domain_dist = 0
    loss_cls = 0
    correct_src = []
    for src_domain in self.src_domains:
        src_domain_idx = torch.where(
            domain_labels == self.domain_to_idx[src_domain]
        )[0]
        # Project both source and target features through this source's
        # domain-specific subnetwork before measuring MMD.
        phi_src = self.domain_net[src_domain].forward(phi_x[src_domain_idx])
        phi_tgt = self.domain_net[src_domain].forward(phi_x[tgt_idx])
        kernels = losses.gaussian_kernel(
            phi_src,
            phi_tgt,
            kernel_mul=self._kernel_mul,
            kernel_num=self._kernel_num,
        )
        domain_dist += losses.compute_mmd_loss(kernels, len(phi_src))
        y_src_hat = self.classifiers[src_domain](phi_src)
        loss_cls_, ok_src_ = losses.cross_entropy_logits(y_src_hat, y[src_domain_idx])
        loss_cls += loss_cls_
        correct_src.append(ok_src_)

    # Add inter-classifier disagreement on the target as an alignment signal.
    domain_dist += self.cls_discrepancy(phi_x[tgt_idx])
    loss_cls = loss_cls / len(self.src_domains)
    ok_src = torch.cat(correct_src)

    # Target accuracy from the averaged classifier output.
    y_tgt_hat = self._get_avg_cls_output(phi_x[tgt_idx])
    _, ok_tgt = losses.cross_entropy_logits(y_tgt_hat, y[tgt_idx])

    log_metrics = {
        f"{split_name}_source_acc": ok_src,
        f"{split_name}_target_acc": ok_tgt,
        f"{split_name}_domain_acc": domain_dist,
    }
    return loss_cls, domain_dist, log_metrics
def compute_loss(self, batch, split_name="valid"):
    """Shared-classifier task loss plus a domain-distance regulariser."""
    x, y, domain_labels = batch
    phi_x = self.forward(x)
    loss_dist = self._compute_domain_dist(phi_x, domain_labels)

    # Split the mixed batch into source and target samples.
    is_target = domain_labels == self.target_label
    src_idx = torch.where(~is_target)[0]
    tgt_idx = torch.where(is_target)[0]

    cls_output = self.classifier(phi_x)
    task_loss, ok_src = losses.cross_entropy_logits(cls_output[src_idx], y[src_idx])
    # Target labels are only used to score accuracy, never for the loss.
    _, ok_tgt = losses.cross_entropy_logits(cls_output[tgt_idx], y[tgt_idx])

    log_metrics = {
        f"{split_name}_source_acc": ok_src,
        f"{split_name}_target_acc": ok_tgt,
        f"{split_name}_domain_acc": loss_dist,
    }
    return task_loss, loss_dist, log_metrics
def compute_loss(self, batch, split_name="V"):
    """Source classification loss plus MMD between source/target features."""
    if len(batch) == 3:
        raise NotImplementedError("MMD does not support semi-supervised setting.")
    (x_s, y_s), (x_tu, y_tu) = batch

    phi_s, y_hat = self.forward(x_s)
    phi_t, y_t_hat = self.forward(x_tu)

    task_loss, ok_src = losses.cross_entropy_logits(y_hat, y_s)
    # Target labels are only used to score accuracy, never for the loss.
    _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)
    mmd = self._compute_mmd(phi_s, phi_t, y_hat, y_t_hat)

    log_metrics = {
        f"{split_name}_source_acc": ok_src,
        f"{split_name}_target_acc": ok_tgt,
        f"{split_name}_mmd": mmd,
    }
    return task_loss, mmd, log_metrics
def compute_loss(self, batch, split_name="valid"):
    """Video MMD loss: handles joint (rgb+flow) or single-modality batches."""
    # _s refers to source, _tu refers to unlabeled target
    if self.image_modality == "joint" and len(batch) == 4:
        (x_s_rgb, y_s), (x_s_flow, y_s_flow), (x_tu_rgb, y_tu), (x_tu_flow, y_tu_flow) = batch
        [phi_s_rgb, phi_s_flow], y_hat = self.forward(
            {"rgb": x_s_rgb, "flow": x_s_flow}
        )
        [phi_t_rgb, phi_t_flow], y_t_hat = self.forward(
            {"rgb": x_tu_rgb, "flow": x_tu_flow}
        )
        # Sum the per-modality MMD terms (rgb + flow).
        mmd = self._compute_mmd(
            phi_s_rgb, phi_t_rgb, y_hat, y_t_hat
        ) + self._compute_mmd(phi_s_flow, phi_t_flow, y_hat, y_t_hat)
    elif self.image_modality in ["rgb", "flow"] and len(batch) == 2:
        (x_s, y_s), (x_tu, y_tu) = batch
        phi_s, y_hat = self.forward(x_s)
        phi_t, y_t_hat = self.forward(x_tu)
        mmd = self._compute_mmd(phi_s, phi_t, y_hat, y_t_hat)
    else:
        raise NotImplementedError(
            "Batch len is {}. Check the Dataloader.".format(len(batch))
        )

    # ok is abbreviation for (all) correct; target labels score accuracy only.
    task_loss, ok_src = losses.cross_entropy_logits(y_hat, y_s)
    _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)

    log_metrics = {
        f"{split_name}_source_acc": ok_src,
        f"{split_name}_target_acc": ok_tgt,
        f"{split_name}_domain_acc": mmd,
    }
    return task_loss, mmd, log_metrics
def compute_loss(self, batch, split_name="valid"):
    """CDAN loss: task CE plus an (optionally entropy-weighted) domain CE."""
    if len(batch) == 3:
        raise NotImplementedError(
            "CDAN does not support semi-supervised setting.")
    (x_s, y_s), (x_tu, y_tu) = batch
    n_source = len(y_s)

    _, y_hat, d_hat = self.forward(x_s)
    _, y_t_hat, d_t_hat = self.forward(x_tu)

    task_loss, ok_src = losses.cross_entropy_logits(y_hat, y_s)
    # Target labels are only used to score accuracy, never for the loss.
    _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)

    # CDAN+E: down-weight samples whose class predictions are uncertain.
    source_weight = target_weight = None
    if self.entropy:
        e_s = self._compute_entropy_weights(y_hat)
        e_t = self._compute_entropy_weights(y_t_hat)
        source_weight = e_s / torch.sum(e_s)
        target_weight = e_t / torch.sum(e_t)

    # Domain discriminator: source labelled 0, target labelled 1.
    loss_dmn_src, dok_src = losses.cross_entropy_logits(
        d_hat, torch.zeros(n_source), source_weight
    )
    loss_dmn_tgt, dok_tgt = losses.cross_entropy_logits(
        d_t_hat, torch.ones(len(d_t_hat)), target_weight
    )
    adv_loss = loss_dmn_src + loss_dmn_tgt

    log_metrics = {
        f"{split_name}_source_acc": ok_src,
        f"{split_name}_target_acc": ok_tgt,
        f"{split_name}_domain_acc": torch.cat((dok_src, dok_tgt)),
        f"{split_name}_source_domain_acc": dok_src,
        f"{split_name}_target_domain_acc": dok_tgt,
    }
    return task_loss, adv_loss, log_metrics
def _compute_cls_loss(self, x, y, domain_labels: torch.Tensor):
    """Average per-source-domain CE loss over the source classifiers.

    Returns (mean_loss, per-sample correctness tensor); (0.0, 0.0) when ``y``
    is empty. Target-domain samples are skipped.
    """
    if len(y) == 0:
        return 0.0, 0.0
    cls_loss = 0.0
    correct_chunks = []
    n_src = 0
    for domain_name, domain_id in self.domain_to_idx.items():
        if domain_name == self.target_domain:
            continue  # the target domain has no classifier to score here
        sample_idx = torch.where(domain_labels == domain_id)[0]
        logits = self.classifiers[domain_name](x[sample_idx])
        loss_, ok_ = losses.cross_entropy_logits(logits, y[sample_idx])
        cls_loss += loss_
        correct_chunks.append(ok_)
        n_src += 1
    return cls_loss / n_src, torch.cat(correct_chunks)
def compute_loss(self, batch, split_name="valid"):
    """Moment-matching domain loss plus averaged source-classifier CE."""
    x, y, domain_labels = batch
    phi_x = self.forward(x)
    moment_loss = self._compute_domain_dist(phi_x, domain_labels)

    # Split the mixed batch into source and target samples.
    src_idx = torch.where(domain_labels != self.target_label)[0]
    tgt_idx = torch.where(domain_labels == self.target_label)[0]
    task_loss, ok_src = self._compute_cls_loss(
        phi_x[src_idx], y[src_idx], domain_labels[src_idx]
    )

    if len(tgt_idx) > 0:
        # Target prediction = average over all source classifiers.
        y_tgt_hat = _average_cls_output(phi_x[tgt_idx], self.classifiers)
        _, ok_tgt = losses.cross_entropy_logits(y_tgt_hat, y[tgt_idx])
    else:
        ok_tgt = 0.0

    log_metrics = {
        f"{split_name}_source_acc": ok_src,
        f"{split_name}_target_acc": ok_tgt,
        f"{split_name}_domain_acc": moment_loss,
    }
    return task_loss, moment_loss, log_metrics
def compute_loss(self, batch, split_name="valid"):
    """Video WDGRL loss over rgb, flow, or joint (rgb+flow) modalities."""
    # _s refers to source, _tu refers to unlabeled target
    x_s_rgb = x_tu_rgb = x_s_flow = x_tu_flow = None
    if self.rgb and self.flow:  # For joint input
        (x_s_rgb, y_s), (x_s_flow, y_s_flow), (x_tu_rgb, y_tu), (x_tu_flow, y_tu_flow) = batch
    elif self.rgb:  # For rgb input
        (x_s_rgb, y_s), (x_tu_rgb, y_tu) = batch
    else:  # For flow input
        (x_s_flow, y_s), (x_tu_flow, y_tu) = batch

    _, y_hat, [d_hat_rgb, d_hat_flow] = self.forward(
        {"rgb": x_s_rgb, "flow": x_s_flow}
    )
    _, y_t_hat, [d_t_hat_rgb, d_t_hat_flow] = self.forward(
        {"rgb": x_tu_rgb, "flow": x_tu_flow}
    )
    batch_size = len(y_s)

    # ok is abbreviation for (all) correct, dok refers to domain correct.
    # NOTE(review): target domain labels use torch.ones(batch_size), which
    # assumes source and target batches have equal size — confirm the
    # dataloader guarantees this (the image variant uses len(d_t_hat)).
    if self.rgb:
        _, dok_src_rgb = losses.cross_entropy_logits(
            d_hat_rgb, torch.zeros(batch_size)
        )
        _, dok_tgt_rgb = losses.cross_entropy_logits(
            d_t_hat_rgb, torch.ones(batch_size)
        )
    if self.flow:
        _, dok_src_flow = losses.cross_entropy_logits(
            d_hat_flow, torch.zeros(batch_size)
        )
        _, dok_tgt_flow = losses.cross_entropy_logits(
            d_t_hat_flow, torch.ones(batch_size)
        )

    if self.rgb and self.flow:  # For joint input
        dok = torch.cat((dok_src_rgb, dok_src_flow, dok_tgt_rgb, dok_tgt_flow))
        dok_src = torch.cat((dok_src_rgb, dok_src_flow))
        dok_tgt = torch.cat((dok_tgt_rgb, dok_tgt_flow))
        # Average the two per-modality critic objectives.
        wasserstein_distance_rgb = (
            d_hat_rgb.mean() - (1 + self._beta_ratio) * d_t_hat_rgb.mean()
        )
        wasserstein_distance_flow = (
            d_hat_flow.mean() - (1 + self._beta_ratio) * d_t_hat_flow.mean()
        )
        wasserstein_distance = (
            wasserstein_distance_rgb + wasserstein_distance_flow
        ) / 2
    else:
        if self.rgb:  # For rgb input
            d_hat, d_t_hat = d_hat_rgb, d_t_hat_rgb
            dok_src, dok_tgt = dok_src_rgb, dok_tgt_rgb
        else:  # For flow input
            d_hat, d_t_hat = d_hat_flow, d_t_hat_flow
            dok_src, dok_tgt = dok_src_flow, dok_tgt_flow
        wasserstein_distance = (
            d_hat.mean() - (1 + self._beta_ratio) * d_t_hat.mean()
        )
        dok = torch.cat((dok_src, dok_tgt))

    task_loss, ok_src = losses.cross_entropy_logits(y_hat, y_s)
    _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)
    adv_loss = wasserstein_distance

    log_metrics = {
        f"{split_name}_source_acc": ok_src,
        f"{split_name}_target_acc": ok_tgt,
        f"{split_name}_domain_acc": dok,
        f"{split_name}_source_domain_acc": dok_src,
        f"{split_name}_target_domain_acc": dok_tgt,
        f"{split_name}_wasserstein_dist": wasserstein_distance,
    }
    return task_loss, adv_loss, log_metrics
def compute_loss(self, batch, split_name="valid"):
    """Video CDAN loss over rgb, flow, or joint (rgb+flow) modalities.

    Returns (task_loss, adv_loss, log_metrics) where ``adv_loss`` is the
    (optionally entropy-weighted) domain-discriminator loss.
    """
    # _s refers to source, _tu refers to unlabeled target
    x_s_rgb = x_tu_rgb = x_s_flow = x_tu_flow = None
    if self.rgb:
        if self.flow:  # For joint input
            (x_s_rgb, y_s), (x_s_flow, y_s_flow), (x_tu_rgb, y_tu), (x_tu_flow, y_tu_flow) = batch
        else:  # For rgb input
            (x_s_rgb, y_s), (x_tu_rgb, y_tu) = batch
    else:  # For flow input
        (x_s_flow, y_s), (x_tu_flow, y_tu) = batch

    _, y_hat, [d_hat_rgb, d_hat_flow] = self.forward(
        {"rgb": x_s_rgb, "flow": x_s_flow}
    )
    _, y_t_hat, [d_t_hat_rgb, d_t_hat_flow] = self.forward(
        {"rgb": x_tu_rgb, "flow": x_tu_flow}
    )
    batch_size = len(y_s)

    # CDAN+E: down-weight samples whose class predictions are uncertain.
    if self.entropy:
        e_s = self._compute_entropy_weights(y_hat)
        e_t = self._compute_entropy_weights(y_t_hat)
        source_weight = e_s / torch.sum(e_s)
        target_weight = e_t / torch.sum(e_t)
    else:
        source_weight = None
        target_weight = None

    # Domain discriminator per modality: source labelled 0, target labelled 1.
    if self.rgb:
        loss_dmn_src_rgb, dok_src_rgb = losses.cross_entropy_logits(
            d_hat_rgb, torch.zeros(batch_size), source_weight
        )
        loss_dmn_tgt_rgb, dok_tgt_rgb = losses.cross_entropy_logits(
            d_t_hat_rgb, torch.ones(batch_size), target_weight
        )
    if self.flow:
        loss_dmn_src_flow, dok_src_flow = losses.cross_entropy_logits(
            d_hat_flow, torch.zeros(batch_size), source_weight
        )
        loss_dmn_tgt_flow, dok_tgt_flow = losses.cross_entropy_logits(
            d_t_hat_flow, torch.ones(batch_size), target_weight
        )

    # ok is abbreviation for (all) correct, dok refers to domain correct
    if self.rgb and self.flow:  # For joint input
        loss_dmn_src = loss_dmn_src_rgb + loss_dmn_src_flow
        loss_dmn_tgt = loss_dmn_tgt_rgb + loss_dmn_tgt_flow
        dok = torch.cat((dok_src_rgb, dok_src_flow, dok_tgt_rgb, dok_tgt_flow))
        dok_src = torch.cat((dok_src_rgb, dok_src_flow))
        dok_tgt = torch.cat((dok_tgt_rgb, dok_tgt_flow))
    else:
        # BUG FIX: this branch previously recomputed the domain losses WITHOUT
        # source_weight/target_weight, discarding the weighted losses computed
        # above and silently disabling entropy conditioning (CDAN+E) for
        # single-modality input. Reuse the already-computed weighted results.
        if self.rgb:  # For rgb input
            loss_dmn_src, dok_src = loss_dmn_src_rgb, dok_src_rgb
            loss_dmn_tgt, dok_tgt = loss_dmn_tgt_rgb, dok_tgt_rgb
        else:  # For flow input
            loss_dmn_src, dok_src = loss_dmn_src_flow, dok_src_flow
            loss_dmn_tgt, dok_tgt = loss_dmn_tgt_flow, dok_tgt_flow
        dok = torch.cat((dok_src, dok_tgt))

    loss_cls, ok_src = losses.cross_entropy_logits(y_hat, y_s)
    # Target labels are only used to score accuracy, never for the loss.
    _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)
    adv_loss = loss_dmn_src + loss_dmn_tgt  # adv_loss = src + tgt
    task_loss = loss_cls

    log_metrics = {
        f"{split_name}_source_acc": ok_src,
        f"{split_name}_target_acc": ok_tgt,
        f"{split_name}_domain_acc": dok,
        f"{split_name}_source_domain_acc": dok_src,
        f"{split_name}_target_domain_acc": dok_tgt,
    }
    return task_loss, adv_loss, log_metrics