class MME(TrainerXU):
    """Minimax Entropy.

    https://arxiv.org/abs/1904.06487.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.lmda = cfg.TRAINER.MME.LMDA

    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)

        print('Building C')
        self.C = Prototypes(self.F.fdim, self.num_classes)
        self.C.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.C)))
        self.optim_C = build_optimizer(self.C, cfg.OPTIM)
        self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
        self.register_model('C', self.C, self.optim_C, self.sched_C)

        self.revgrad = ReverseGrad()

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        # Supervised loss on labeled source data
        feat_x = self.F(input_x)
        logit_x = self.C(feat_x)
        loss_x = F.cross_entropy(logit_x, label_x)
        self.model_backward_and_update(loss_x)

        # Adversarial entropy on unlabeled target data: loss_u is the negative entropy,
        # so C maximizes the entropy while the reversed gradient makes F minimize it
        feat_u = self.F(input_u)
        feat_u = self.revgrad(feat_u)
        logit_u = self.C(feat_u)
        prob_u = F.softmax(logit_u, 1)
        loss_u = -(-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()
        self.model_backward_and_update(loss_u * self.lmda)

        output_dict = {
            'loss_x': loss_x.item(),
            'acc_x': compute_accuracy(logit_x.detach(), label_x)[0].item(),
            'loss_u': loss_u.item(),
            'lr': self.optim_F.param_groups[0]['lr']
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return output_dict

    def model_inference(self, input):
        return self.C(self.F(input))
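# Hedged sketch (not part of this repo): ReverseGrad above is assumed to be the usual
# gradient-reversal layer. A minimal stand-alone version is given below for reference
# only; the actual implementation imported by MME may differ in name and options.
import torch


class _ReverseGradFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, grad_scaling=1.0):
        ctx.grad_scaling = grad_scaling
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity in the forward pass; flipped (and optionally scaled) gradient backwards,
        # which is what turns the entropy term into a minimax game between F and C.
        return -ctx.grad_scaling * grad_output, None


class ReverseGradSketch(torch.nn.Module):
    """Hypothetical drop-in for the ReverseGrad module used by MME."""

    def __init__(self, grad_scaling=1.0):
        super().__init__()
        self.grad_scaling = grad_scaling

    def forward(self, x):
        return _ReverseGradFn.apply(x, self.grad_scaling)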
def build_model(self):
    cfg = self.cfg

    print('Building multiple source models')
    self.models = nn.ModuleList([
        SimpleNet(cfg, cfg.MODEL, self.num_classes, cfg.MODEL.CLASSIFIER.TYPE)
        for _ in range(self.dm.num_source_domains)
    ])
    self.models.to(self.device)
    print('# params: {:,}'.format(count_num_param(self.models)))
    self.register_model('models', self.models)
class M3SDA(TrainerXU):
    """Moment Matching for Multi-Source Domain Adaptation.

    https://arxiv.org/abs/1812.01754.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.dm.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain
        self.n_step_F = cfg.TRAINER.M3SDA.N_STEP_F
        self.lmda = cfg.TRAINER.M3SDA.LMDA

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == 'RandomDomainSampler'
        assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X

    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print('Building C')
        self.C = nn.ModuleList([
            PairClassifiers(fdim, self.num_classes)
            for _ in range(self.dm.num_source_domains)
        ])
        self.C.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.C)))
        self.optim_C = build_optimizer(self.C, cfg.OPTIM)
        self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
        self.register_model('C', self.C, self.optim_C, self.sched_C)

    def forward_backward(self, batch_x, batch_u):
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, domain_x, input_u = parsed

        input_x = torch.split(input_x, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
        domain_x = [d[0].item() for d in domain_x]

        # Step A
        loss_x = 0
        feat_x = []

        for x, y, d in zip(input_x, label_x, domain_x):
            f = self.F(x)
            z1, z2 = self.C[d](f)
            loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)
            feat_x.append(f)

        loss_x /= self.n_domain

        feat_u = self.F(input_u)
        loss_msda = self.moment_distance(feat_x, feat_u)

        loss_step_A = loss_x + loss_msda * self.lmda
        self.model_backward_and_update(loss_step_A)

        # Step B
        with torch.no_grad():
            feat_u = self.F(input_u)

        loss_x, loss_dis = 0, 0

        for x, y, d in zip(input_x, label_x, domain_x):
            with torch.no_grad():
                f = self.F(x)
            z1, z2 = self.C[d](f)
            loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)

            z1, z2 = self.C[d](feat_u)
            p1 = F.softmax(z1, 1)
            p2 = F.softmax(z2, 1)
            loss_dis += self.discrepancy(p1, p2)

        loss_x /= self.n_domain
        loss_dis /= self.n_domain

        loss_step_B = loss_x - loss_dis
        self.model_backward_and_update(loss_step_B, 'C')

        # Step C
        for _ in range(self.n_step_F):
            feat_u = self.F(input_u)

            loss_dis = 0

            for d in domain_x:
                z1, z2 = self.C[d](feat_u)
                p1 = F.softmax(z1, 1)
                p2 = F.softmax(z2, 1)
                loss_dis += self.discrepancy(p1, p2)

            loss_dis /= self.n_domain
            loss_step_C = loss_dis

            self.model_backward_and_update(loss_step_C, 'F')

        loss_summary = {
            'loss_step_A': loss_step_A.item(),
            'loss_step_B': loss_step_B.item(),
            'loss_step_C': loss_step_C.item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def moment_distance(self, x, u):
        # x (list): a list of feature matrices.
        # u (torch.Tensor): feature matrix.
        x_mean = [xi.mean(0) for xi in x]
        u_mean = u.mean(0)
        dist1 = self.pairwise_distance(x_mean, u_mean)

        x_var = [xi.var(0) for xi in x]
        u_var = u.var(0)
        dist2 = self.pairwise_distance(x_var, u_var)

        return (dist1 + dist2) / 2

    def pairwise_distance(self, x, u):
        # x (list): a list of feature vectors.
        # u (torch.Tensor): feature vector.
        dist = 0
        count = 0

        for xi in x:
            dist += self.euclidean(xi, u)
            count += 1

        for i in range(len(x) - 1):
            for j in range(i + 1, len(x)):
                dist += self.euclidean(x[i], x[j])
                count += 1

        return dist / count

    def euclidean(self, input1, input2):
        return ((input1 - input2)**2).sum().sqrt()

    def discrepancy(self, y1, y2):
        return (y1 - y2).abs().mean()

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x['img']
        label_x = batch_x['label']
        domain_x = batch_x['domain']
        input_u = batch_u['img']

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)

        return input_x, label_x, domain_x, input_u

    def model_inference(self, input):
        f = self.F(input)
        p = 0
        for C_i in self.C:
            z = C_i(f)
            p += F.softmax(z, 1)
        p = p / len(self.C)
        return p
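# Hedged sketch: PairClassifiers is imported elsewhere and not shown in this excerpt.
# M3SDA's training steps unpack it as `z1, z2 = self.C[d](f)`, while model_inference
# treats `C_i(f)` as a single logit tensor, so a minimal stand-in consistent with both
# usages (the real module may add hidden layers) could be:
import torch.nn as nn


class PairClassifiersSketch(nn.Module):
    """Hypothetical pair of classifiers sharing one feature input."""

    def __init__(self, fdim, num_classes):
        super().__init__()
        self.c1 = nn.Linear(fdim, num_classes)
        self.c2 = nn.Linear(fdim, num_classes)

    def forward(self, x):
        z1 = self.c1(x)
        if not self.training:
            # Inference path: a single prediction, matching model_inference above
            return z1
        z2 = self.c2(x)
        return z1, z2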
class MCD(TrainerXU):
    """Maximum Classifier Discrepancy.

    https://arxiv.org/abs/1712.02560.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.n_step_F = cfg.TRAINER.MCD.N_STEP_F

    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print('Building C1')
        self.C1 = nn.Linear(fdim, self.num_classes)
        self.C1.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.C1)))
        self.optim_C1 = build_optimizer(self.C1, cfg.OPTIM)
        self.sched_C1 = build_lr_scheduler(self.optim_C1, cfg.OPTIM)
        self.register_model('C1', self.C1, self.optim_C1, self.sched_C1)

        print('Building C2')
        self.C2 = nn.Linear(fdim, self.num_classes)
        self.C2.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.C2)))
        self.optim_C2 = build_optimizer(self.C2, cfg.OPTIM)
        self.sched_C2 = build_lr_scheduler(self.optim_C2, cfg.OPTIM)
        self.register_model('C2', self.C2, self.optim_C2, self.sched_C2)

    def forward_backward(self, batch_x, batch_u):
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, input_u = parsed

        # Step A
        feat_x = self.F(input_x)
        logit_x1 = self.C1(feat_x)
        logit_x2 = self.C2(feat_x)
        loss_x1 = F.cross_entropy(logit_x1, label_x)
        loss_x2 = F.cross_entropy(logit_x2, label_x)
        loss_step_A = loss_x1 + loss_x2
        self.model_backward_and_update(loss_step_A)

        # Step B
        with torch.no_grad():
            feat_x = self.F(input_x)
        logit_x1 = self.C1(feat_x)
        logit_x2 = self.C2(feat_x)
        loss_x1 = F.cross_entropy(logit_x1, label_x)
        loss_x2 = F.cross_entropy(logit_x2, label_x)
        loss_x = loss_x1 + loss_x2

        with torch.no_grad():
            feat_u = self.F(input_u)
        pred_u1 = F.softmax(self.C1(feat_u), 1)
        pred_u2 = F.softmax(self.C2(feat_u), 1)
        loss_dis = self.discrepancy(pred_u1, pred_u2)

        loss_step_B = loss_x - loss_dis
        self.model_backward_and_update(loss_step_B, ['C1', 'C2'])

        # Step C
        for _ in range(self.n_step_F):
            feat_u = self.F(input_u)
            pred_u1 = F.softmax(self.C1(feat_u), 1)
            pred_u2 = F.softmax(self.C2(feat_u), 1)
            loss_step_C = self.discrepancy(pred_u1, pred_u2)
            self.model_backward_and_update(loss_step_C, 'F')

        loss_summary = {
            'loss_step_A': loss_step_A.item(),
            'loss_step_B': loss_step_B.item(),
            'loss_step_C': loss_step_C.item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def discrepancy(self, y1, y2):
        return (y1 - y2).abs().mean()

    def model_inference(self, input):
        feat = self.F(input)
        return self.C1(feat)
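# Hedged sketch: a tiny self-contained check of the discrepancy term used by MCD above.
# It only illustrates that the L1 discrepancy is taken between softmax probabilities of
# the two classifier heads; the tensors below are made up and carry no meaning.
import torch
from torch.nn import functional as F

if __name__ == '__main__':
    logits1 = torch.randn(4, 10)  # stand-in for C1(F(x_u))
    logits2 = torch.randn(4, 10)  # stand-in for C2(F(x_u))
    p1 = F.softmax(logits1, 1)
    p2 = F.softmax(logits2, 1)
    # The same quantity is maximized w.r.t. C1/C2 in Step B (via the minus sign)
    # and minimized w.r.t. F in Step C.
    print((p1 - p2).abs().mean().item())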
class CrossGrad(TrainerX):
    """Cross-gradient training.

    https://arxiv.org/abs/1804.10745.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.eps_f = cfg.TRAINER.CG.EPS_F
        self.eps_d = cfg.TRAINER.CG.EPS_D
        self.alpha_f = cfg.TRAINER.CG.ALPHA_F
        self.alpha_d = cfg.TRAINER.CG.ALPHA_D

    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)

        print('Building D')
        self.D = SimpleNet(cfg, cfg.MODEL, self.dm.num_source_domains)
        self.D.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.D)))
        self.optim_D = build_optimizer(self.D, cfg.OPTIM)
        self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
        self.register_model('D', self.D, self.optim_D, self.sched_D)

    def forward_backward(self, batch):
        input, label, domain = self.parse_batch_train(batch)

        input.requires_grad = True

        # Compute domain perturbation
        loss_d = F.cross_entropy(self.D(input), domain)
        loss_d.backward()
        grad_d = torch.clamp(input.grad.data, min=-0.1, max=0.1)
        input_d = input.data + self.eps_f * grad_d

        # Compute label perturbation
        input.grad.data.zero_()
        loss_f = F.cross_entropy(self.F(input), label)
        loss_f.backward()
        grad_f = torch.clamp(input.grad.data, min=-0.1, max=0.1)
        input_f = input.data + self.eps_d * grad_f

        input = input.detach()

        # Update label net
        loss_f1 = F.cross_entropy(self.F(input), label)
        loss_f2 = F.cross_entropy(self.F(input_d), label)
        loss_f = (1 - self.alpha_f) * loss_f1 + self.alpha_f * loss_f2
        self.model_backward_and_update(loss_f, 'F')

        # Update domain net
        loss_d1 = F.cross_entropy(self.D(input), domain)
        loss_d2 = F.cross_entropy(self.D(input_f), domain)
        loss_d = (1 - self.alpha_d) * loss_d1 + self.alpha_d * loss_d2
        self.model_backward_and_update(loss_d, 'D')

        loss_summary = {'loss_f': loss_f.item(), 'loss_d': loss_d.item()}

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def model_inference(self, input):
        return self.F(input)
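# Hedged sketch: the input-perturbation step used by CrossGrad, isolated on a toy
# classifier so the requires_grad / backward / clamp pattern above is easier to follow.
# The network, epsilon and shapes here are placeholders, not the trainer's configuration.
import torch
import torch.nn as nn
from torch.nn import functional as F

if __name__ == '__main__':
    toy_domain_net = nn.Linear(8, 3)           # stands in for self.D
    x = torch.randn(4, 8, requires_grad=True)  # stands in for the input batch
    domain = torch.randint(0, 3, (4,))

    loss_d = F.cross_entropy(toy_domain_net(x), domain)
    loss_d.backward()

    # Clamp the raw input gradient, then step along it; the perturbed input is what
    # the label network F is trained on in the trainer above.
    grad_d = torch.clamp(x.grad.data, min=-0.1, max=0.1)
    eps_f = 1.0
    input_d = x.data + eps_f * grad_d
    print(input_d.shape)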
class DAELDG(TrainerX):
    """Domain Adaptive Ensemble Learning.

    DG version: only use labeled source data.

    https://arxiv.org/abs/2003.07325.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.dm.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain
        self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == 'RandomDomainSampler'
        assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = self.dm.train_loader_x
        self.train_loader_u = self.dm.train_loader_u
        self.val_loader = self.dm.val_loader
        self.test_loader = self.dm.test_loader
        self.num_classes = self.dm.num_classes

    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print('Building E')
        self.E = Experts(self.dm.num_source_domains, fdim, self.num_classes)
        self.E.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.E)))
        self.optim_E = build_optimizer(self.E, cfg.OPTIM)
        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
        self.register_model('E', self.E, self.optim_E, self.sched_E)

    def forward_backward(self, batch):
        parsed_data = self.parse_batch_train(batch)
        input, input2, label, domain = parsed_data

        input = torch.split(input, self.split_batch, 0)
        input2 = torch.split(input2, self.split_batch, 0)
        label = torch.split(label, self.split_batch, 0)
        domain = torch.split(domain, self.split_batch, 0)
        domain = [d[0].item() for d in domain]

        loss_x = 0
        loss_cr = 0
        acc = 0

        feat = [self.F(x) for x in input]
        feat2 = [self.F(x) for x in input2]

        for feat_i, feat2_i, label_i, i in zip(feat, feat2, label, domain):
            cr_s = [j for j in domain if j != i]

            # Learning expert
            pred_i = self.E(i, feat_i)
            loss_x += (-label_i * torch.log(pred_i + 1e-5)).sum(1).mean()
            expert_label_i = pred_i.detach()
            acc += compute_accuracy(pred_i.detach(), label_i.max(1)[1])[0].item()

            # Consistency regularization
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat2_i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1)
            cr_pred = cr_pred.mean(1)
            loss_cr += ((cr_pred - expert_label_i)**2).sum(1).mean()

        loss_x /= self.n_domain
        loss_cr /= self.n_domain
        acc /= self.n_domain

        # Total objective: expert loss plus consistency regularization
        loss = 0
        loss += loss_x
        loss += loss_cr
        self.model_backward_and_update(loss)

        loss_summary = {
            'loss': loss.item(),
            'acc': acc,
            'loss_cr': loss_cr.item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        input = batch['img']
        input2 = batch['img2']
        label = batch['label']
        domain = batch['domain']

        label = create_onehot(label, self.num_classes)

        input = input.to(self.device)
        input2 = input2.to(self.device)
        label = label.to(self.device)

        return input, input2, label, domain

    def model_inference(self, input):
        f = self.F(input)
        p = []
        for k in range(self.dm.num_source_domains):
            p_k = self.E(k, f)
            p_k = p_k.unsqueeze(1)
            p.append(p_k)
        p = torch.cat(p, 1)
        p = p.mean(1)
        return p
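# Hedged sketch: Experts is imported elsewhere and not defined in this excerpt. The DAEL
# trainers call it as self.E(domain_index, feature) and treat the output as a probability
# vector (it is passed to torch.log and compared with one-hot labels), so a minimal
# stand-in under those assumptions could be:
import torch.nn as nn


class ExpertsSketch(nn.Module):
    """Hypothetical bank of per-source-domain heads with softmax outputs."""

    def __init__(self, n_source, fdim, num_classes):
        super().__init__()
        self.heads = nn.ModuleList(
            [nn.Linear(fdim, num_classes) for _ in range(n_source)]
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, i, x):
        # i selects the expert of source domain i; x is the shared backbone feature
        return self.softmax(self.heads[i](x))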
class AltDAELGated(TrainerXU):
    """Domain Adaptive Ensemble Learning.

    https://arxiv.org/abs/2003.07325.
    """

    def __init__(self, cfg):
        self.is_regressive = cfg.TRAINER.DAEL.TASK.lower() == "regression"
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.dm.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain
        self.weight_u = cfg.TRAINER.DAEL.WEIGHT_U
        self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == 'RandomDomainSampler'
        assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
        assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = self.dm.train_loader_x
        self.train_loader_u = self.dm.train_loader_u
        self.val_loader = self.dm.val_loader
        self.test_loader = self.dm.test_loader
        self.num_classes = self.dm.num_classes

    def build_model(self):
        cfg = self.cfg

        img_channels = cfg.DATASET.N_CHANNELS
        if 'grayscale' in cfg.INPUT.TRANSFORMS:
            img_channels = 1
            print("Found grayscale! Set img_channels to 1")
        backbone_in_channels = img_channels * cfg.DATASET.NUM_STACK

        print(f'Building F with {backbone_in_channels} in channels')
        self.F = SimpleNet(cfg, cfg.MODEL, 0, in_channels=backbone_in_channels)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print('Building E')
        self.E = Experts(self.dm.num_source_domains, fdim, self.num_classes,
                         regressive=self.is_regressive)
        self.E.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.E)))
        self.optim_E = build_optimizer(self.E, cfg.OPTIM)
        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
        self.register_model('E', self.E, self.optim_E, self.sched_E)

        print('Building G')
        self.G = Gate(fdim, self.dm.num_source_domains)
        self.G.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.G)))
        self.optim_G = build_optimizer(self.G, cfg.OPTIM)
        self.sched_G = build_lr_scheduler(self.optim_G, cfg.OPTIM)
        self.register_model('G', self.G, self.optim_G, self.sched_G)

    def d_closest(self, d_filter):
        n_dom = d_filter.shape[1]
        closest = d_filter.max(1)[1]
        n_closest = torch.zeros(n_dom)
        for dom in range(n_dom):
            times_closest = torch.Tensor([
                1 for i in range(len(closest)) if closest[i] == dom
            ]).sum().item()
            n_closest[dom] = (times_closest / len(d_filter))
        return n_closest

    def forward_backward(self, batch_x, batch_u):
        # Load data
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data
        input_x = torch.split(input_x, self.split_batch, 0)
        input_x2 = torch.split(input_x2, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
        domain_x = [d[0].item() for d in domain_x]
        # x = data with small augmentations. x2 = data with large augmentations.
        # They both correspond to the same datapoints. Same scheme for u and u2.

        # Generate pseudo label
        with torch.no_grad():
            # Unsupervised predictions
            feat_u = self.F(input_u)
            pred_u = []
            for k in range(self.dm.num_source_domains):
                pred_uk = self.E(k, feat_u)
                pred_uk = pred_uk.unsqueeze(1)
                pred_u.append(pred_uk)
            pred_u = torch.cat(pred_u, 1)  # (B, K, C)

            # Pseudolabel = weighted predictions
            u_filter = self.G(feat_u)  # (B, K)
            # (B). 1 if >= 1 expert clears the threshold, 0 otherwise
            label_u_mask = (u_filter.max(1)[0] >= self.conf_thre)
            new_u_filter = torch.zeros(*u_filter.shape).to(self.device)
            for i, row in enumerate(u_filter):
                j_max = row.max(0)[1]
                new_u_filter[i, j_max] = 1
            u_filter = new_u_filter
            d_closest = self.d_closest(u_filter).max(0)[1]
            u_filter = u_filter.unsqueeze(2).expand(*pred_u.shape)
            # Zero out all non-chosen experts
            pred_fu = (pred_u * u_filter).sum(1)
            pseudo_label_u = pred_fu.max(1)[1]  # (B)
            pseudo_label_u = create_onehot(pseudo_label_u,
                                           self.num_classes).to(self.device)

        # Init losses
        loss_x = 0
        loss_cr = 0
        acc_x = 0
        loss_filter = 0
        acc_filter = 0

        # Supervised and unsupervised features
        feat_x = [self.F(x) for x in input_x]
        feat_x2 = [self.F(x) for x in input_x2]
        feat_u2 = self.F(input_u2)

        for feat_xi, feat_x2i, label_xi, i in zip(feat_x, feat_x2, label_x,
                                                  domain_x):
            cr_s = [j for j in domain_x if j != i]

            # Learning expert
            pred_xi = self.E(i, feat_xi)
            expert_label_xi = pred_xi.detach()
            if self.is_regressive:
                loss_x += ((pred_xi - label_xi)**2).sum(1).mean()
            else:
                loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()
                acc_x += compute_accuracy(pred_xi.detach(),
                                          label_xi.max(1)[1])[0].item()

            # Filter must be 1 for the expert's own domain, 0 otherwise
            x_filter = self.G(feat_xi)
            filter_label = torch.Tensor([0 for _ in range(len(domain_x))
                                         ]).to(self.device)
            filter_label[i] = 1
            filter_label = filter_label.unsqueeze(0).expand(*x_filter.shape)
            loss_filter += (-filter_label *
                            torch.log(x_filter + 1e-5)).sum(1).mean()
            acc_filter += compute_accuracy(x_filter.detach(),
                                           filter_label.max(1)[1])[0].item()

            # Consistency regularization - mean must follow the leading expert
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat_x2i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1).mean(1)
            loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean()

        loss_x /= self.n_domain
        loss_cr /= self.n_domain
        if not self.is_regressive:
            acc_x /= self.n_domain
        loss_filter /= self.n_domain
        acc_filter /= self.n_domain

        # Unsupervised loss
        pred_u = []
        for k in range(self.dm.num_source_domains):
            pred_uk = self.E(k, feat_u2)
            pred_uk = pred_uk.unsqueeze(1)
            pred_u.append(pred_uk)
        pred_u = torch.cat(pred_u, 1).to(self.device)
        pred_u = pred_u.mean(1)
        # Squared error for the regression task, cross-entropy for classification
        # (mirroring the supervised expert loss); keep the per-sample loss so the
        # confidence mask can zero out low-confidence pseudo labels.
        if self.is_regressive:
            l_u = ((pseudo_label_u - pred_u)**2).sum(1)
        else:
            l_u = (-pseudo_label_u * torch.log(pred_u + 1e-5)).sum(1)
        loss_u = (l_u * label_u_mask).mean()

        loss = 0
        loss += loss_x
        loss += loss_cr
        loss += loss_filter
        loss += loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            'loss_x': loss_x.item(),
            'loss_filter': loss_filter.item(),
            'acc_filter': acc_filter,
            'loss_cr': loss_cr.item(),
            'loss_u': loss_u.item(),
            # 'd_closest': d_closest.max(0)[1]
            'd_closest': d_closest.item()
        }
        if not self.is_regressive:
            loss_summary['acc_x'] = acc_x

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x['img']
        input_x2 = batch_x['img2']
        label_x = batch_x['label']
        domain_x = batch_x['domain']
        input_u = batch_u['img']
        input_u2 = batch_u['img2']

        if self.is_regressive:
            # Stack list of tensors
            label_x = torch.cat([torch.unsqueeze(x, 1) for x in label_x], 1)
        else:
            label_x = create_onehot(label_x, self.num_classes)

        input_x = input_x.to(self.device)
        input_x2 = input_x2.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)
        input_u2 = input_u2.to(self.device)

        return input_x, input_x2, label_x, domain_x, input_u, input_u2

    def parse_batch_test(self, batch):
        if self.is_regressive:
            input = batch['img']
            label = batch['label']
            # Stack list of tensors
            label = torch.cat([torch.unsqueeze(x, 1) for x in label], 1)
            input = input.to(self.device)
            label = label.to(self.device)
        else:
            input, label = super().parse_batch_test(batch)
        return input, label

    def model_inference(self, input):
        f = self.F(input)
        g = self.G(f).unsqueeze(2)
        p = []
        for k in range(self.dm.num_source_domains):
            p_k = self.E(k, f)
            p_k = p_k.unsqueeze(1)
            p.append(p_k)
        new_g = torch.zeros(*g.shape).to(self.device)
        for i, row in enumerate(g):
            j_max = row.max(0)[1]
            new_g[i, j_max] = 1
        p = torch.cat(p, 1)
        p = (p * new_g).sum(1)
        return p, g

    @torch.no_grad()
    def test(self):
        """A generic testing pipeline."""
        self.set_model_mode('eval')
        self.evaluator.reset()

        split = self.cfg.TEST.SPLIT
        print('Do evaluation on {} set'.format(split))
        data_loader = self.val_loader if split == 'val' else self.test_loader
        assert data_loader is not None

        all_d_filter = []
        for batch_idx, batch in enumerate(data_loader):
            input, label = self.parse_batch_test(batch)
            output, d_filter = self.model_inference(input)
            all_d_filter.append(d_filter)
            self.evaluator.process(output, label)

        results = self.evaluator.evaluate()
        all_d_filter = torch.cat(all_d_filter, 0).mean(0).cpu().detach().numpy()
        print(f"* Average gate: {list(all_d_filter)}")

        for k, v in results.items():
            tag = '{}/{}'.format(split, k)
            self.write_scalar(tag, v, self.epoch)
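# Hedged sketch: Gate is not defined in this excerpt. AltDAELGated uses self.G(feat) as a
# per-sample distribution over the K source domains (its max is thresholded against
# CONF_THRE and it is trained with a cross-entropy-style loss against a one-hot domain
# target), so one plausible minimal implementation is:
import torch.nn as nn


class GateSketch(nn.Module):
    """Hypothetical gating head scoring each source-domain expert for a feature."""

    def __init__(self, fdim, n_source):
        super().__init__()
        self.fc = nn.Linear(fdim, n_source)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # Returns a (B, K) matrix of expert weights that sums to 1 per sample
        return self.softmax(self.fc(x))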
class DDAIG(TrainerX):
    """Deep Domain-Adversarial Image Generation.

    https://arxiv.org/abs/2003.06054.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.lmda = cfg.TRAINER.DDAIG.LMDA
        self.clamp = cfg.TRAINER.DDAIG.CLAMP
        self.clamp_min = cfg.TRAINER.DDAIG.CLAMP_MIN
        self.clamp_max = cfg.TRAINER.DDAIG.CLAMP_MAX
        self.warmup = cfg.TRAINER.DDAIG.WARMUP
        self.alpha = cfg.TRAINER.DDAIG.ALPHA

    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)

        print('Building D')
        self.D = SimpleNet(cfg, cfg.MODEL, self.dm.num_source_domains)
        self.D.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.D)))
        self.optim_D = build_optimizer(self.D, cfg.OPTIM)
        self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
        self.register_model('D', self.D, self.optim_D, self.sched_D)

        print('Building G')
        self.G = build_model(cfg.TRAINER.DDAIG.G_ARCH, verbose=cfg.VERBOSE)
        self.G.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.G)))
        self.optim_G = build_optimizer(self.G, cfg.OPTIM)
        self.sched_G = build_lr_scheduler(self.optim_G, cfg.OPTIM)
        self.register_model('G', self.G, self.optim_G, self.sched_G)

    def forward_backward(self, batch):
        input, label, domain = self.parse_batch_train(batch)

        #############
        # Update G
        #############
        input_p = self.G(input, lmda=self.lmda)
        if self.clamp:
            input_p = torch.clamp(
                input_p, min=self.clamp_min, max=self.clamp_max
            )

        loss_g = 0
        # Minimize label loss
        loss_g += F.cross_entropy(self.F(input_p), label)
        # Maximize domain loss
        loss_g -= F.cross_entropy(self.D(input_p), domain)
        self.model_backward_and_update(loss_g, 'G')

        # Perturb data with new G
        with torch.no_grad():
            input_p = self.G(input, lmda=self.lmda)
            if self.clamp:
                input_p = torch.clamp(
                    input_p, min=self.clamp_min, max=self.clamp_max
                )

        #############
        # Update F
        #############
        loss_f = F.cross_entropy(self.F(input), label)
        if (self.epoch + 1) > self.warmup:
            loss_fp = F.cross_entropy(self.F(input_p), label)
            loss_f = (1. - self.alpha) * loss_f + self.alpha * loss_fp
        self.model_backward_and_update(loss_f, 'F')

        #############
        # Update D
        #############
        loss_d = F.cross_entropy(self.D(input), domain)
        self.model_backward_and_update(loss_d, 'D')

        output_dict = {
            'loss_g': loss_g.item(),
            'loss_f': loss_f.item(),
            'loss_d': loss_d.item(),
            'lr': self.optim_F.param_groups[0]['lr']
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return output_dict

    def model_inference(self, input):
        return self.F(input)
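# Hedged sketch: the generator built from cfg.TRAINER.DDAIG.G_ARCH is not shown here.
# The trainer only requires it to map an image batch (plus a mixing weight lmda) to a
# perturbed batch of the same shape, so a toy fully-convolutional stand-in under that
# assumption is given below; the registered G architectures may differ substantially.
import torch
import torch.nn as nn


class ToyPerturbationGenerator(nn.Module):
    """Hypothetical G: predicts an additive perturbation and blends it with weight lmda."""

    def __init__(self, in_channels=3, hidden=16):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, in_channels, 3, padding=1),
            nn.Tanh(),
        )

    def forward(self, x, lmda=1.0):
        # Same call signature the trainer uses: self.G(input, lmda=self.lmda)
        return x + lmda * self.body(x)


if __name__ == '__main__':
    g = ToyPerturbationGenerator()
    x = torch.randn(2, 3, 32, 32)
    print(g(x, lmda=0.3).shape)  # torch.Size([2, 3, 32, 32])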
class DAELReg(TrainerXU):
    """Domain Adaptive Ensemble Learning.

    https://arxiv.org/abs/2003.07325.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.is_regressive = cfg.TRAINER.DAEL.TASK.lower() == "regression"
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.dm.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain
        self.weight_u = cfg.TRAINER.DAEL.WEIGHT_U
        self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE
        # self.e_crit = e_crit
        # self.cr_crit = cr_crit

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == 'RandomDomainSampler'
        assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
        assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = self.dm.train_loader_x
        self.train_loader_u = self.dm.train_loader_u
        self.val_loader = self.dm.val_loader
        self.test_loader = self.dm.test_loader
        self.num_classes = self.dm.num_classes

    def build_model(self):
        cfg = self.cfg

        img_channels = cfg.DATASET.N_CHANNELS
        if 'grayscale' in cfg.INPUT.TRANSFORMS:
            img_channels = 1
            print("Found grayscale! Set img_channels to 1")
        backbone_in_channels = img_channels * cfg.DATASET.NUM_STACK

        print(f'Building F with {backbone_in_channels} in channels')
        self.F = SimpleNet(cfg, cfg.MODEL, 0, in_channels=backbone_in_channels)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print('Building E')
        self.E = Experts(self.dm.num_source_domains, fdim, self.num_classes)
        self.E.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.E)))
        self.optim_E = build_optimizer(self.E, cfg.OPTIM)
        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
        self.register_model('E', self.E, self.optim_E, self.sched_E)

    def forward_backward(self, batch_x, batch_u):
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data
        input_x = torch.split(input_x, self.split_batch, 0)
        input_x2 = torch.split(input_x2, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
        domain_x = [d[0].item() for d in domain_x]

        # Generate pseudo label
        with torch.no_grad():
            feat_u = self.F(input_u)
            pred_u = []
            for k in range(self.dm.num_source_domains):
                pred_uk = self.E(k, feat_u)
                pred_uk = pred_uk.unsqueeze(1)
                pred_u.append(pred_uk)
            pred_u = torch.cat(pred_u, 1)  # (B, K, C)
            # Get the median prediction for each action. Note that there is no
            # leading expert; we just take the median of all predictions.
            pred_u = pred_u.median(1).values

        loss_x = 0
        loss_cr = 0
        if not self.is_regressive:
            acc_x = 0

        feat_x = [self.F(x) for x in input_x]
        feat_x2 = [self.F(x) for x in input_x2]
        # feat_u2 = self.F(input_u2)

        for feat_xi, feat_x2i, label_xi, i in zip(feat_x, feat_x2, label_x,
                                                  domain_x):
            cr_s = [j for j in domain_x if j != i]

            # Learning expert
            pred_xi = self.E(i, feat_xi)
            if self.is_regressive:
                loss_x += ((pred_xi - label_xi)**2).sum(1).mean()
            else:
                loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()
                acc_x += compute_accuracy(pred_xi.detach(),
                                          label_xi.max(1)[1])[0].item()
            expert_label_xi = pred_xi.detach()

            # Consistency regularization
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat_x2i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1)
            cr_pred = cr_pred.mean(1)
            loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean()

        loss_x /= self.n_domain
        loss_cr /= self.n_domain
        if not self.is_regressive:
            acc_x /= self.n_domain

        # Unsupervised loss -> None yet
        # Pending: provide a means of establishing a lead expert so that loss can
        # be calculated

        loss = 0
        loss += loss_x
        loss += loss_cr
        # loss += loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            'loss_x': loss_x.item(),
            'loss_cr': loss_cr.item(),
            # 'loss_u': loss_u.item()
        }
        if not self.is_regressive:
            loss_summary['acc_x'] = acc_x

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x['img']
        input_x2 = batch_x['img2']
        label_x = batch_x['label']
        domain_x = batch_x['domain']
        input_u = batch_u['img']
        input_u2 = batch_u['img2']

        if self.is_regressive:
            # Stack list of tensors
            label_x = torch.cat([torch.unsqueeze(x, 1) for x in label_x], 1)
        else:
            label_x = create_onehot(label_x, self.num_classes)

        input_x = input_x.to(self.device)
        input_x2 = input_x2.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)
        input_u2 = input_u2.to(self.device)

        return input_x, input_x2, label_x, domain_x, input_u, input_u2

    def parse_batch_test(self, batch):
        if self.is_regressive:
            input = batch['img']
            label = batch['label']
            # Stack list of tensors
            label = torch.cat([torch.unsqueeze(x, 1) for x in label], 1)
            input = input.to(self.device)
            label = label.to(self.device)
        else:
            input, label = super().parse_batch_test(batch)
        return input, label

    def model_inference(self, input):
        f = self.F(input)
        p = []
        for k in range(self.dm.num_source_domains):
            p_k = self.E(k, f)
            p_k = p_k.unsqueeze(1)
            p.append(p_k)
        p = torch.cat(p, 1)
        p = p.mean(1)
        return p
class DAEL(TrainerXU):
    """Domain Adaptive Ensemble Learning.

    https://arxiv.org/abs/2003.07325.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.dm.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain
        self.weight_u = cfg.TRAINER.DAEL.WEIGHT_U
        self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == 'RandomDomainSampler'
        assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
        assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = self.dm.train_loader_x
        self.train_loader_u = self.dm.train_loader_u
        self.val_loader = self.dm.val_loader
        self.test_loader = self.dm.test_loader
        self.num_classes = self.dm.num_classes

    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print('Building E')
        self.E = Experts(self.dm.num_source_domains, fdim, self.num_classes)
        self.E.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.E)))
        self.optim_E = build_optimizer(self.E, cfg.OPTIM)
        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
        self.register_model('E', self.E, self.optim_E, self.sched_E)

    def forward_backward(self, batch_x, batch_u):
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data
        input_x = torch.split(input_x, self.split_batch, 0)
        input_x2 = torch.split(input_x2, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
        domain_x = [d[0].item() for d in domain_x]

        # Generate pseudo label
        with torch.no_grad():
            feat_u = self.F(input_u)
            pred_u = []
            for k in range(self.dm.num_source_domains):
                pred_uk = self.E(k, feat_u)
                pred_uk = pred_uk.unsqueeze(1)
                pred_u.append(pred_uk)
            pred_u = torch.cat(pred_u, 1)  # (B, K, C)
            # Get the highest probability and index (label) for each expert
            experts_max_p, experts_max_idx = pred_u.max(2)  # (B, K)
            # Get the most confident expert
            max_expert_p, max_expert_idx = experts_max_p.max(1)  # (B)
            pseudo_label_u = []
            for i, experts_label in zip(max_expert_idx, experts_max_idx):
                pseudo_label_u.append(experts_label[i])
            pseudo_label_u = torch.stack(pseudo_label_u, 0)
            pseudo_label_u = create_onehot(pseudo_label_u, self.num_classes)
            pseudo_label_u = pseudo_label_u.to(self.device)
            label_u_mask = (max_expert_p >= self.conf_thre).float()

        loss_x = 0
        loss_cr = 0
        acc_x = 0

        feat_x = [self.F(x) for x in input_x]
        feat_x2 = [self.F(x) for x in input_x2]
        feat_u2 = self.F(input_u2)

        for feat_xi, feat_x2i, label_xi, i in zip(feat_x, feat_x2, label_x,
                                                  domain_x):
            cr_s = [j for j in domain_x if j != i]

            # Learning expert
            pred_xi = self.E(i, feat_xi)
            loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()
            expert_label_xi = pred_xi.detach()
            acc_x += compute_accuracy(pred_xi.detach(),
                                      label_xi.max(1)[1])[0].item()

            # Consistency regularization
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat_x2i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1)
            cr_pred = cr_pred.mean(1)
            loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean()

        loss_x /= self.n_domain
        loss_cr /= self.n_domain
        acc_x /= self.n_domain

        # Unsupervised loss
        pred_u = []
        for k in range(self.dm.num_source_domains):
            pred_uk = self.E(k, feat_u2)
            pred_uk = pred_uk.unsqueeze(1)
            pred_u.append(pred_uk)
        pred_u = torch.cat(pred_u, 1)
        pred_u = pred_u.mean(1)
        l_u = (-pseudo_label_u * torch.log(pred_u + 1e-5)).sum(1)
        loss_u = (l_u * label_u_mask).mean()

        loss = 0
        loss += loss_x
        loss += loss_cr
        loss += loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            'loss_x': loss_x.item(),
            'acc_x': acc_x,
            'loss_cr': loss_cr.item(),
            'loss_u': loss_u.item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x['img']
        input_x2 = batch_x['img2']
        label_x = batch_x['label']
        domain_x = batch_x['domain']
        input_u = batch_u['img']
        input_u2 = batch_u['img2']

        label_x = create_onehot(label_x, self.num_classes)

        input_x = input_x.to(self.device)
        input_x2 = input_x2.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)
        input_u2 = input_u2.to(self.device)

        return input_x, input_x2, label_x, domain_x, input_u, input_u2

    def model_inference(self, input):
        f = self.F(input)
        p = []
        for k in range(self.dm.num_source_domains):
            p_k = self.E(k, f)
            p_k = p_k.unsqueeze(1)
            p.append(p_k)
        p = torch.cat(p, 1)
        p = p.mean(1)
        return p
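# Hedged sketch: create_onehot is imported from the surrounding codebase and not shown
# here. The DAEL trainers only rely on it turning a (B,) index tensor into a
# (B, num_classes) one-hot float matrix, so an equivalent stand-in would be:
import torch


def create_onehot_sketch(labels, num_classes):
    """Hypothetical equivalent of create_onehot: (B,) long tensor -> (B, C) one-hot."""
    onehot = torch.zeros(labels.shape[0], num_classes)
    onehot[torch.arange(labels.shape[0]), labels] = 1.0
    return onehot


if __name__ == '__main__':
    pseudo = torch.tensor([2, 0, 1])
    print(create_onehot_sketch(pseudo, num_classes=4))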