Example #1
    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)

        print('Building D')
        self.D = SimpleNet(cfg, cfg.MODEL, self.dm.num_source_domains)
        self.D.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.D)))
        self.optim_D = build_optimizer(self.D, cfg.OPTIM)
        self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
        self.register_model('D', self.D, self.optim_D, self.sched_D)

        print('Building G')
        self.G = build_model(cfg.TRAINER.DDAIG.G_ARCH, verbose=cfg.VERBOSE)
        self.G.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.G)))
        self.optim_G = build_optimizer(self.G, cfg.OPTIM)
        self.sched_G = build_lr_scheduler(self.optim_G, cfg.OPTIM)
        self.register_model('G', self.G, self.optim_G, self.sched_G)
Example #2
    def build_model(self):
        cfg = self.cfg
        img_channels = cfg.DATASET.N_CHANNELS
        if 'grayscale' in cfg.INPUT.TRANSFORMS:
            img_channels = 1
            print("Found grayscale! Set img_channels to 1")
        backbone_in_channels = img_channels * cfg.DATASET.NUM_STACK
        print(f'Building F with {backbone_in_channels} in channels')
        self.F = SimpleNet(cfg, cfg.MODEL, 0, in_channels=backbone_in_channels)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print('Building E')
        self.E = Experts(self.dm.num_source_domains,
                         fdim,
                         self.num_classes,
                         regressive=self.is_regressive)
        self.E.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.E)))
        self.optim_E = build_optimizer(self.E, cfg.OPTIM)
        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
        self.register_model('E', self.E, self.optim_E, self.sched_E)

        print('Building G')
        self.G = Gate(fdim, self.dm.num_source_domains)
        self.G.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.G)))
        self.optim_G = build_optimizer(self.G, cfg.OPTIM)
        self.sched_G = build_lr_scheduler(self.optim_G, cfg.OPTIM)
        self.register_model('G', self.G, self.optim_G, self.sched_G)
Example #3
    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print('Building C1')
        self.C1 = nn.Linear(fdim, self.num_classes)
        self.C1.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.C1)))
        self.optim_C1 = build_optimizer(self.C1, cfg.OPTIM)
        self.sched_C1 = build_lr_scheduler(self.optim_C1, cfg.OPTIM)
        self.register_model('C1', self.C1, self.optim_C1, self.sched_C1)

        print('Building C2')
        self.C2 = nn.Linear(fdim, self.num_classes)
        self.C2.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.C2)))
        self.optim_C2 = build_optimizer(self.C2, cfg.OPTIM)
        self.sched_C2 = build_lr_scheduler(self.optim_C2, cfg.OPTIM)
        self.register_model('C2', self.C2, self.optim_C2, self.sched_C2)
Example #4
class MME(TrainerXU):
    """Minimax Entropy.

    https://arxiv.org/abs/1904.06487.
    """
    def __init__(self, cfg):
        super().__init__(cfg)
        self.lmda = cfg.TRAINER.MME.LMDA

    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)

        print('Building C')
        self.C = Prototypes(self.F.fdim, self.num_classes)
        self.C.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.C)))
        self.optim_C = build_optimizer(self.C, cfg.OPTIM)
        self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
        self.register_model('C', self.C, self.optim_C, self.sched_C)

        self.revgrad = ReverseGrad()

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        feat_x = self.F(input_x)
        logit_x = self.C(feat_x)
        loss_x = F.cross_entropy(logit_x, label_x)
        self.model_backward_and_update(loss_x)

        feat_u = self.F(input_u)
        feat_u = self.revgrad(feat_u)
        logit_u = self.C(feat_u)
        prob_u = F.softmax(logit_u, 1)
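        # loss_u is the negative entropy of the predictions on unlabeled data.
        # Minimizing it trains C (the prototypes) to maximize entropy, while the
        # reversed gradient drives F to minimize entropy (the minimax game of MME).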
        loss_u = -(-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()
        self.model_backward_and_update(loss_u * self.lmda)

        output_dict = {
            'loss_x': loss_x.item(),
            'acc_x': compute_accuracy(logit_x.detach(), label_x)[0].item(),
            'loss_u': loss_u.item(),
            'lr': self.optim_F.param_groups[0]['lr']
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return output_dict

    def model_inference(self, input):
        return self.C(self.F(input))
Example #5
    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)

        print('Building D')
        self.D = SimpleNet(cfg, cfg.MODEL, self.dm.num_source_domains)
        self.D.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.D)))
        self.optim_D = build_optimizer(self.D, cfg.OPTIM)
        self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
        self.register_model('D', self.D, self.optim_D, self.sched_D)
Example #6
    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print('Building E')
        self.E = Experts(self.dm.num_source_domains, fdim, self.num_classes)
        self.E.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.E)))
        self.optim_E = build_optimizer(self.E, cfg.OPTIM)
        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
        self.register_model('E', self.E, self.optim_E, self.sched_E)
Example #7
    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)

        print('Building C')
        self.C = Prototypes(self.F.fdim, self.num_classes)
        self.C.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.C)))
        self.optim_C = build_optimizer(self.C, cfg.OPTIM)
        self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
        self.register_model('C', self.C, self.optim_C, self.sched_C)

        self.revgrad = ReverseGrad()
Example #8
    def build_model(self):
        cfg = self.cfg

        print('Building multiple source models')
        self.models = nn.ModuleList([
            SimpleNet(cfg, cfg.MODEL, self.num_classes,
                      cfg.MODEL.CLASSIFIER.TYPE)
            for _ in range(self.dm.num_source_domains)
        ])
        self.models.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.models)))
        self.register_model('models', self.models)
Example #9
    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print('Building C')
        self.C = nn.ModuleList([
            PairClassifiers(fdim, self.num_classes)
            for _ in range(self.dm.num_source_domains)
        ])
        self.C.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.C)))
        self.optim_C = build_optimizer(self.C, cfg.OPTIM)
        self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
        self.register_model('C', self.C, self.optim_C, self.sched_C)
Example #10
class M3SDA(TrainerXU):
    """Moment Matching for Multi-Source Domain Adaptation.

    https://arxiv.org/abs/1812.01754.
    """
    def __init__(self, cfg):
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.dm.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain

        self.n_step_F = cfg.TRAINER.M3SDA.N_STEP_F
        self.lmda = cfg.TRAINER.M3SDA.LMDA

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == 'RandomDomainSampler'
        assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X

    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print('Building C')
        self.C = nn.ModuleList([
            PairClassifiers(fdim, self.num_classes)
            for _ in range(self.dm.num_source_domains)
        ])
        self.C.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.C)))
        self.optim_C = build_optimizer(self.C, cfg.OPTIM)
        self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
        self.register_model('C', self.C, self.optim_C, self.sched_C)

    def forward_backward(self, batch_x, batch_u):
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, domain_x, input_u = parsed

        input_x = torch.split(input_x, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
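        # With RandomDomainSampler, each chunk of split_batch samples comes from a
        # single source domain, so the first element identifies the chunk's domain.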
        domain_x = [d[0].item() for d in domain_x]

        # Step A
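        # Train F and every domain's classifier pair on labeled source data,
        # plus the moment-matching loss between source and target features.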
        loss_x = 0
        feat_x = []

        for x, y, d in zip(input_x, label_x, domain_x):
            f = self.F(x)
            z1, z2 = self.C[d](f)
            loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)

            feat_x.append(f)

        loss_x /= self.n_domain

        feat_u = self.F(input_u)
        loss_msda = self.moment_distance(feat_x, feat_u)

        loss_step_A = loss_x + loss_msda * self.lmda
        self.model_backward_and_update(loss_step_A)

        # Step B
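        # Keep F fixed; update the classifiers to stay accurate on source while
        # maximizing their pairwise discrepancy on the target features.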
        with torch.no_grad():
            feat_u = self.F(input_u)

        loss_x, loss_dis = 0, 0

        for x, y, d in zip(input_x, label_x, domain_x):
            with torch.no_grad():
                f = self.F(x)
            z1, z2 = self.C[d](f)
            loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)

            z1, z2 = self.C[d](feat_u)
            p1 = F.softmax(z1, 1)
            p2 = F.softmax(z2, 1)
            loss_dis += self.discrepancy(p1, p2)

        loss_x /= self.n_domain
        loss_dis /= self.n_domain

        loss_step_B = loss_x - loss_dis
        self.model_backward_and_update(loss_step_B, 'C')

        # Step C
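        # Keep the classifiers fixed; update F to minimize the discrepancy
        # (repeated n_step_F times).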
        for _ in range(self.n_step_F):
            feat_u = self.F(input_u)

            loss_dis = 0

            for d in domain_x:
                z1, z2 = self.C[d](feat_u)
                p1 = F.softmax(z1, 1)
                p2 = F.softmax(z2, 1)
                loss_dis += self.discrepancy(p1, p2)

            loss_dis /= self.n_domain
            loss_step_C = loss_dis

            self.model_backward_and_update(loss_step_C, 'F')

        loss_summary = {
            'loss_step_A': loss_step_A.item(),
            'loss_step_B': loss_step_B.item(),
            'loss_step_C': loss_step_C.item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def moment_distance(self, x, u):
        # x (list): a list of feature matrices.
        # u (torch.Tensor): feature matrix.
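        # First (mean) and second (variance) moments are matched both between
        # each source domain and the target, and between pairs of source domains.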
        x_mean = [xi.mean(0) for xi in x]
        u_mean = u.mean(0)
        dist1 = self.pairwise_distance(x_mean, u_mean)

        x_var = [xi.var(0) for xi in x]
        u_var = u.var(0)
        dist2 = self.pairwise_distance(x_var, u_var)

        return (dist1 + dist2) / 2

    def pairwise_distance(self, x, u):
        # x (list): a list of feature vectors.
        # u (torch.Tensor): feature vector.
        dist = 0
        count = 0

        for xi in x:
            dist += self.euclidean(xi, u)
            count += 1

        for i in range(len(x) - 1):
            for j in range(i + 1, len(x)):
                dist += self.euclidean(x[i], x[j])
                count += 1

        return dist / count

    def euclidean(self, input1, input2):
        return ((input1 - input2)**2).sum().sqrt()

    def discrepancy(self, y1, y2):
        return (y1 - y2).abs().mean()

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x['img']
        label_x = batch_x['label']
        domain_x = batch_x['domain']
        input_u = batch_u['img']

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)

        return input_x, label_x, domain_x, input_u

    def model_inference(self, input):
        f = self.F(input)
        p = 0
        for C_i in self.C:
            z = C_i(f)
            p += F.softmax(z, 1)
        p = p / len(self.C)
        return p
Example #11
class MCD(TrainerXU):
    """Maximum Classifier Discrepancy.

    https://arxiv.org/abs/1712.02560.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.n_step_F = cfg.TRAINER.MCD.N_STEP_F

    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print('Building C1')
        self.C1 = nn.Linear(fdim, self.num_classes)
        self.C1.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.C1)))
        self.optim_C1 = build_optimizer(self.C1, cfg.OPTIM)
        self.sched_C1 = build_lr_scheduler(self.optim_C1, cfg.OPTIM)
        self.register_model('C1', self.C1, self.optim_C1, self.sched_C1)

        print('Building C2')
        self.C2 = nn.Linear(fdim, self.num_classes)
        self.C2.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.C2)))
        self.optim_C2 = build_optimizer(self.C2, cfg.OPTIM)
        self.sched_C2 = build_lr_scheduler(self.optim_C2, cfg.OPTIM)
        self.register_model('C2', self.C2, self.optim_C2, self.sched_C2)

    def forward_backward(self, batch_x, batch_u):
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, input_u = parsed

        # Step A
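        # Train F, C1 and C2 to classify the labeled source data.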
        feat_x = self.F(input_x)
        logit_x1 = self.C1(feat_x)
        logit_x2 = self.C2(feat_x)
        loss_x1 = F.cross_entropy(logit_x1, label_x)
        loss_x2 = F.cross_entropy(logit_x2, label_x)
        loss_step_A = loss_x1 + loss_x2
        self.model_backward_and_update(loss_step_A)

        # Step B
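        # Keep F fixed; update C1 and C2 to stay accurate on source while
        # maximizing their discrepancy on the target predictions.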
        with torch.no_grad():
            feat_x = self.F(input_x)
        logit_x1 = self.C1(feat_x)
        logit_x2 = self.C2(feat_x)
        loss_x1 = F.cross_entropy(logit_x1, label_x)
        loss_x2 = F.cross_entropy(logit_x2, label_x)
        loss_x = loss_x1 + loss_x2

        with torch.no_grad():
            feat_u = self.F(input_u)
        pred_u1 = F.softmax(self.C1(feat_u), 1)
        pred_u2 = F.softmax(self.C2(feat_u), 1)
        loss_dis = self.discrepancy(pred_u1, pred_u2)

        loss_step_B = loss_x - loss_dis
        self.model_backward_and_update(loss_step_B, ['C1', 'C2'])

        # Step C
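        # Keep C1 and C2 fixed; update F to minimize the discrepancy
        # (repeated n_step_F times).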
        for _ in range(self.n_step_F):
            feat_u = self.F(input_u)
            pred_u1 = F.softmax(self.C1(feat_u), 1)
            pred_u2 = F.softmax(self.C2(feat_u), 1)
            loss_step_C = self.discrepancy(pred_u1, pred_u2)
            self.model_backward_and_update(loss_step_C, 'F')

        loss_summary = {
            'loss_step_A': loss_step_A.item(),
            'loss_step_B': loss_step_B.item(),
            'loss_step_C': loss_step_C.item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def discrepancy(self, y1, y2):
        return (y1 - y2).abs().mean()

    def model_inference(self, input):
        feat = self.F(input)
        return self.C1(feat)
Example #12
class CrossGrad(TrainerX):
    """Cross-gradient training.

    https://arxiv.org/abs/1804.10745.
    """
    def __init__(self, cfg):
        super().__init__(cfg)
        self.eps_f = cfg.TRAINER.CG.EPS_F
        self.eps_d = cfg.TRAINER.CG.EPS_D
        self.alpha_f = cfg.TRAINER.CG.ALPHA_F
        self.alpha_d = cfg.TRAINER.CG.ALPHA_D

    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)

        print('Building D')
        self.D = SimpleNet(cfg, cfg.MODEL, self.dm.num_source_domains)
        self.D.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.D)))
        self.optim_D = build_optimizer(self.D, cfg.OPTIM)
        self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
        self.register_model('D', self.D, self.optim_D, self.sched_D)

    def forward_backward(self, batch):
        input, label, domain = self.parse_batch_train(batch)

        input.requires_grad = True

        # Compute domain perturbation
        loss_d = F.cross_entropy(self.D(input), domain)
        loss_d.backward()
        grad_d = torch.clamp(input.grad.data, min=-0.1, max=0.1)
        input_d = input.data + self.eps_f * grad_d

        # Compute label perturbation
        input.grad.data.zero_()
        loss_f = F.cross_entropy(self.F(input), label)
        loss_f.backward()
        grad_f = torch.clamp(input.grad.data, min=-0.1, max=0.1)
        input_f = input.data + self.eps_d * grad_f
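        # input_d perturbs the input along the domain classifier's gradient and is
        # used below to augment the label net F; input_f perturbs it along the label
        # classifier's gradient and augments the domain net D.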

        input = input.detach()

        # Update label net
        loss_f1 = F.cross_entropy(self.F(input), label)
        loss_f2 = F.cross_entropy(self.F(input_d), label)
        loss_f = (1 - self.alpha_f) * loss_f1 + self.alpha_f * loss_f2
        self.model_backward_and_update(loss_f, 'F')

        # Update domain net
        loss_d1 = F.cross_entropy(self.D(input), domain)
        loss_d2 = F.cross_entropy(self.D(input_f), domain)
        loss_d = (1 - self.alpha_d) * loss_d1 + self.alpha_d * loss_d2
        self.model_backward_and_update(loss_d, 'D')

        loss_summary = {'loss_f': loss_f.item(), 'loss_d': loss_d.item()}

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def model_inference(self, input):
        return self.F(input)
Example #13
class DAELDG(TrainerX):
    """Domain Adaptive Ensemble Learning.

    DG version: only use labeled source data.

    https://arxiv.org/abs/2003.07325.
    """
    def __init__(self, cfg):
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.dm.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain

        self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == 'RandomDomainSampler'
        assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = self.dm.train_loader_x
        self.train_loader_u = self.dm.train_loader_u
        self.val_loader = self.dm.val_loader
        self.test_loader = self.dm.test_loader
        self.num_classes = self.dm.num_classes

    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print('Building E')
        self.E = Experts(self.dm.num_source_domains, fdim, self.num_classes)
        self.E.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.E)))
        self.optim_E = build_optimizer(self.E, cfg.OPTIM)
        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
        self.register_model('E', self.E, self.optim_E, self.sched_E)

    def forward_backward(self, batch):
        parsed_data = self.parse_batch_train(batch)
        input, input2, label, domain = parsed_data

        input = torch.split(input, self.split_batch, 0)
        input2 = torch.split(input2, self.split_batch, 0)
        label = torch.split(label, self.split_batch, 0)
        domain = torch.split(domain, self.split_batch, 0)
        domain = [d[0].item() for d in domain]

        loss = 0
        loss_cr = 0
        acc = 0

        feat = [self.F(x) for x in input]
        feat2 = [self.F(x) for x in input2]

        for feat_i, feat2_i, label_i, i in zip(feat, feat2, label, domain):
            cr_s = [j for j in domain if j != i]

            # Learning expert
            pred_i = self.E(i, feat_i)
            loss += (-label_i * torch.log(pred_i + 1e-5)).sum(1).mean()
            expert_label_i = pred_i.detach()
            acc += compute_accuracy(pred_i.detach(),
                                    label_i.max(1)[1])[0].item()

            # Consistency regularization
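            # The other domains' experts, applied to the strongly augmented view
            # (feat2_i), should match this expert's prediction on the normal view.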
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat2_i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1)
            cr_pred = cr_pred.mean(1)
            loss_cr += ((cr_pred - expert_label_i)**2).sum(1).mean()

        loss /= self.n_domain
        loss_cr /= self.n_domain
        acc /= self.n_domain

        # Total objective: expert loss plus consistency regularization
        loss_all = loss + loss_cr
        self.model_backward_and_update(loss_all)

        loss_summary = {
            'loss': loss_all.item(),
            'acc': acc,
            'loss_cr': loss_cr.item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        input = batch['img']
        input2 = batch['img2']
        label = batch['label']
        domain = batch['domain']

        label = create_onehot(label, self.num_classes)

        input = input.to(self.device)
        input2 = input2.to(self.device)
        label = label.to(self.device)

        return input, input2, label, domain

    def model_inference(self, input):
        f = self.F(input)
        p = []
        for k in range(self.dm.num_source_domains):
            p_k = self.E(k, f)
            p_k = p_k.unsqueeze(1)
            p.append(p_k)
        p = torch.cat(p, 1)
        p = p.mean(1)
        return p
Example #14
class AltDAELGated(TrainerXU):
    """Domain Adaptive Ensemble Learning.

    https://arxiv.org/abs/2003.07325.
    """
    def __init__(self, cfg):
        self.is_regressive = cfg.TRAINER.DAEL.TASK.lower() == "regression"
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.dm.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain

        self.weight_u = cfg.TRAINER.DAEL.WEIGHT_U
        self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == 'RandomDomainSampler'
        assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
        assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = self.dm.train_loader_x
        self.train_loader_u = self.dm.train_loader_u
        self.val_loader = self.dm.val_loader
        self.test_loader = self.dm.test_loader
        self.num_classes = self.dm.num_classes

    def build_model(self):
        cfg = self.cfg
        img_channels = cfg.DATASET.N_CHANNELS
        if 'grayscale' in cfg.INPUT.TRANSFORMS:
            img_channels = 1
            print("Found grayscale! Set img_channels to 1")
        backbone_in_channels = img_channels * cfg.DATASET.NUM_STACK
        print(f'Building F with {backbone_in_channels} in channels')
        self.F = SimpleNet(cfg, cfg.MODEL, 0, in_channels=backbone_in_channels)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print('Building E')
        self.E = Experts(self.dm.num_source_domains,
                         fdim,
                         self.num_classes,
                         regressive=self.is_regressive)
        self.E.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.E)))
        self.optim_E = build_optimizer(self.E, cfg.OPTIM)
        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
        self.register_model('E', self.E, self.optim_E, self.sched_E)

        print('Building G')
        self.G = Gate(fdim, self.dm.num_source_domains)
        self.G.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.G)))
        self.optim_G = build_optimizer(self.G, cfg.OPTIM)
        self.sched_G = build_lr_scheduler(self.optim_G, cfg.OPTIM)
        self.register_model('G', self.G, self.optim_G, self.sched_G)

    def d_closest(self, d_filter):
        # Fraction of samples for which each domain has the largest gate value
        n_dom = d_filter.shape[1]
        closest = d_filter.max(1)[1]
        n_closest = torch.zeros(n_dom)
        for dom in range(n_dom):
            n_closest[dom] = (closest == dom).float().sum().item() / len(d_filter)
        return n_closest

    def forward_backward(self, batch_x, batch_u):
        # Load data
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data
        input_x = torch.split(input_x, self.split_batch, 0)
        input_x2 = torch.split(input_x2, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
        domain_x = [d[0].item() for d in domain_x]

        # x = data with small augmentations. x2 = data with large augmentations
        # They both correspond to the same datapoints. Same scheme for u and u2.

        # Generate pseudo label
        with torch.no_grad():
            # Unsupervised predictions
            feat_u = self.F(input_u)
            pred_u = []
            for k in range(self.dm.num_source_domains):
                pred_uk = self.E(k, feat_u)
                pred_uk = pred_uk.unsqueeze(1)
                pred_u.append(pred_uk)
            pred_u = torch.cat(pred_u, 1)  # (B, K, C)
            # Pseudolabel = weighted predictions
            u_filter = self.G(feat_u)  # (B, K)
            # (B): 1 if the largest gate weight is >= conf_thre, 0 otherwise
            label_u_mask = (u_filter.max(1)[0] >= self.conf_thre)
            new_u_filter = torch.zeros(*u_filter.shape).to(self.device)
            for i, row in enumerate(u_filter):
                j_max = row.max(0)[1]
                new_u_filter[i, j_max] = 1
            u_filter = new_u_filter
            d_closest = self.d_closest(u_filter).max(0)[1]
            u_filter = u_filter.unsqueeze(2).expand(*pred_u.shape)
            # Zero out all non-chosen experts and keep only the selected one
            pred_fu = (pred_u * u_filter).sum(1)
            pseudo_label_u = pred_fu.max(1)[1]  # (B)
            pseudo_label_u = create_onehot(pseudo_label_u,
                                           self.num_classes).to(self.device)
        # Init losses
        loss_x = 0
        loss_cr = 0
        acc_x = 0
        loss_filter = 0
        acc_filter = 0

        # Supervised and unsupervised features
        feat_x = [self.F(x) for x in input_x]
        feat_x2 = [self.F(x) for x in input_x2]
        feat_u2 = self.F(input_u2)

        for feat_xi, feat_x2i, label_xi, i in zip(feat_x, feat_x2, label_x,
                                                  domain_x):
            cr_s = [j for j in domain_x if j != i]

            # Learning expert
            pred_xi = self.E(i, feat_xi)
            expert_label_xi = pred_xi.detach()
            if self.is_regressive:
                loss_x += ((pred_xi - label_xi)**2).sum(1).mean()
            else:
                loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()
                acc_x += compute_accuracy(pred_xi.detach(),
                                          label_xi.max(1)[1])[0].item()

            x_filter = self.G(feat_xi)
            # Filter must be 1 for expert, 0 otherwise
            filter_label = torch.zeros(len(domain_x)).to(self.device)
            filter_label[i] = 1
            filter_label = filter_label.unsqueeze(0).expand(*x_filter.shape)
            loss_filter += (-filter_label *
                            torch.log(x_filter + 1e-5)).sum(1).mean()
            acc_filter += compute_accuracy(x_filter.detach(),
                                           filter_label.max(1)[1])[0].item()

            # Consistency regularization - Mean must follow the leading expert
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat_x2i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1).mean(1)
            loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean()

        loss_x /= self.n_domain
        loss_cr /= self.n_domain
        if not self.is_regressive:
            acc_x /= self.n_domain
        loss_filter /= self.n_domain
        acc_filter /= self.n_domain

        # Unsupervised loss
        pred_u = []
        for k in range(self.dm.num_source_domains):
            pred_uk = self.E(k, feat_u2)
            pred_uk = pred_uk.unsqueeze(1)
            pred_u.append(pred_uk)
        pred_u = torch.cat(pred_u, 1).to(self.device)
        pred_u = pred_u.mean(1)
        if self.is_regressive:
            # Per-sample squared error against the pseudo labels
            l_u = ((pseudo_label_u - pred_u)**2).sum(1)
        else:
            # Per-sample cross entropy against the one-hot pseudo labels
            l_u = (-pseudo_label_u * torch.log(pred_u + 1e-5)).sum(1)
        loss_u = (l_u * label_u_mask).mean()

        loss = 0
        loss += loss_x
        loss += loss_cr
        loss += loss_filter
        loss += loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            'loss_x': loss_x.item(),
            'loss_filter': loss_filter.item(),
            'acc_filter': acc_filter,
            'loss_cr': loss_cr.item(),
            'loss_u': loss_u.item(),
            'd_closest': d_closest.item()
        }
        if not self.is_regressive:
            loss_summary['acc_x'] = acc_x

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x['img']
        input_x2 = batch_x['img2']
        label_x = batch_x['label']
        domain_x = batch_x['domain']
        input_u = batch_u['img']
        input_u2 = batch_u['img2']
        if self.is_regressive:
            label_x = torch.cat([torch.unsqueeze(x, 1) for x in label_x],
                                1)  #Stack list of tensors
        else:
            label_x = create_onehot(label_x, self.num_classes)

        input_x = input_x.to(self.device)
        input_x2 = input_x2.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)
        input_u2 = input_u2.to(self.device)

        return input_x, input_x2, label_x, domain_x, input_u, input_u2

    def parse_batch_test(self, batch):
        if self.is_regressive:
            input = batch['img']
            label = batch['label']
            label = torch.cat([torch.unsqueeze(x, 1) for x in label],
                              1)  #Stack list of tensors
            input = input.to(self.device)
            label = label.to(self.device)
        else:
            input, label = super().parse_batch_test(batch)

        return input, label

    def model_inference(self, input):
        f = self.F(input)
        g = self.G(f).unsqueeze(2)
        p = []
        for k in range(self.dm.num_source_domains):
            p_k = self.E(k, f)
            p_k = p_k.unsqueeze(1)
            p.append(p_k)
        new_g = torch.zeros(*g.shape).to(self.device)
        for i, row in enumerate(g):
            j_max = row.max(0)[1]
            new_g[i, j_max] = 1
        p = torch.cat(p, 1)
        p = (p * new_g).sum(1)
        return p, g

    @torch.no_grad()
    def test(self):
        """A generic testing pipeline."""
        self.set_model_mode('eval')
        self.evaluator.reset()

        split = self.cfg.TEST.SPLIT
        print('Do evaluation on {} set'.format(split))
        data_loader = self.val_loader if split == 'val' else self.test_loader
        assert data_loader is not None

        all_d_filter = []
        for batch_idx, batch in enumerate(data_loader):
            input, label = self.parse_batch_test(batch)
            output, d_filter = self.model_inference(input)
            all_d_filter.append(d_filter)
            self.evaluator.process(output, label)

        results = self.evaluator.evaluate()
        all_d_filter = torch.cat(all_d_filter,
                                 0).mean(0).cpu().detach().numpy()
        print(f"* Average gate: {list(all_d_filter)}")
        for k, v in results.items():
            tag = '{}/{}'.format(split, k)
            self.write_scalar(tag, v, self.epoch)
Example #15
class DDAIG(TrainerX):
    """Deep Domain-Adversarial Image Generation.

    https://arxiv.org/abs/2003.06054.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.lmda = cfg.TRAINER.DDAIG.LMDA
        self.clamp = cfg.TRAINER.DDAIG.CLAMP
        self.clamp_min = cfg.TRAINER.DDAIG.CLAMP_MIN
        self.clamp_max = cfg.TRAINER.DDAIG.CLAMP_MAX
        self.warmup = cfg.TRAINER.DDAIG.WARMUP
        self.alpha = cfg.TRAINER.DDAIG.ALPHA

    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)

        print('Building D')
        self.D = SimpleNet(cfg, cfg.MODEL, self.dm.num_source_domains)
        self.D.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.D)))
        self.optim_D = build_optimizer(self.D, cfg.OPTIM)
        self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
        self.register_model('D', self.D, self.optim_D, self.sched_D)

        print('Building G')
        self.G = build_model(cfg.TRAINER.DDAIG.G_ARCH, verbose=cfg.VERBOSE)
        self.G.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.G)))
        self.optim_G = build_optimizer(self.G, cfg.OPTIM)
        self.sched_G = build_lr_scheduler(self.optim_G, cfg.OPTIM)
        self.register_model('G', self.G, self.optim_G, self.sched_G)

    def forward_backward(self, batch):
        input, label, domain = self.parse_batch_train(batch)

        #############
        # Update G
        #############
        input_p = self.G(input, lmda=self.lmda)
        if self.clamp:
            input_p = torch.clamp(
                input_p, min=self.clamp_min, max=self.clamp_max
            )
        loss_g = 0
        # Minimize label loss
        loss_g += F.cross_entropy(self.F(input_p), label)
        # Maximize domain loss
        loss_g -= F.cross_entropy(self.D(input_p), domain)
        self.model_backward_and_update(loss_g, 'G')

        # Perturb data with new G
        with torch.no_grad():
            input_p = self.G(input, lmda=self.lmda)
            if self.clamp:
                input_p = torch.clamp(
                    input_p, min=self.clamp_min, max=self.clamp_max
                )

        #############
        # Update F
        #############
        loss_f = F.cross_entropy(self.F(input), label)
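        # After the warmup epochs, also train F on the G-perturbed images,
        # mixing the two losses with weight alpha.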
        if (self.epoch + 1) > self.warmup:
            loss_fp = F.cross_entropy(self.F(input_p), label)
            loss_f = (1. - self.alpha) * loss_f + self.alpha * loss_fp
        self.model_backward_and_update(loss_f, 'F')

        #############
        # Update D
        #############
        loss_d = F.cross_entropy(self.D(input), domain)
        self.model_backward_and_update(loss_d, 'D')

        output_dict = {
            'loss_g': loss_g.item(),
            'loss_f': loss_f.item(),
            'loss_d': loss_d.item(),
            'lr': self.optim_F.param_groups[0]['lr']
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return output_dict

    def model_inference(self, input):
        return self.F(input)
Example #16
class DAELReg(TrainerXU):
    """Domain Adaptive Ensemble Learning.

    https://arxiv.org/abs/2003.07325.
    """
    def __init__(self, cfg):
        super().__init__(cfg)
        self.is_regressive = cfg.TRAINER.DAEL.TASK.lower() == "regression"
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.dm.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain

        self.weight_u = cfg.TRAINER.DAEL.WEIGHT_U
        self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == 'RandomDomainSampler'
        assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
        assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = self.dm.train_loader_x
        self.train_loader_u = self.dm.train_loader_u
        self.val_loader = self.dm.val_loader
        self.test_loader = self.dm.test_loader
        self.num_classes = self.dm.num_classes

    def build_model(self):
        cfg = self.cfg
        img_channels = cfg.DATASET.N_CHANNELS
        if 'grayscale' in cfg.INPUT.TRANSFORMS:
            img_channels = 1
            print("Found grayscale! Set img_channels to 1")

        backbone_in_channels = img_channels * cfg.DATASET.NUM_STACK
        print(f'Building F with {backbone_in_channels} in channels')
        self.F = SimpleNet(cfg, cfg.MODEL, 0, in_channels=backbone_in_channels)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print('Building E')
        self.E = Experts(self.dm.num_source_domains, fdim, self.num_classes)
        self.E.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.E)))
        self.optim_E = build_optimizer(self.E, cfg.OPTIM)
        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
        self.register_model('E', self.E, self.optim_E, self.sched_E)

    def forward_backward(self, batch_x, batch_u):
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data

        input_x = torch.split(input_x, self.split_batch, 0)
        input_x2 = torch.split(input_x2, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
        domain_x = [d[0].item() for d in domain_x]

        # Generate pseudo label
        with torch.no_grad():
            feat_u = self.F(input_u)
            pred_u = []
            for k in range(self.dm.num_source_domains):
                pred_uk = self.E(k, feat_u)
                pred_uk = pred_uk.unsqueeze(1)
                pred_u.append(pred_uk)
            pred_u = torch.cat(pred_u, 1)  # (B, K, C)
            # Get the median prediction for each action
            pred_u = pred_u.median(1).values
            # Note: there is no leading expert; we just take the median of all expert predictions

        loss_x = 0
        loss_cr = 0
        if not self.is_regressive:
            acc_x = 0

        feat_x = [self.F(x) for x in input_x]
        feat_x2 = [self.F(x) for x in input_x2]
        #         feat_u2 = self.F(input_u2)

        for feat_xi, feat_x2i, label_xi, i in zip(feat_x, feat_x2, label_x,
                                                  domain_x):
            cr_s = [j for j in domain_x if j != i]

            # Learning expert
            pred_xi = self.E(i, feat_xi)
            if self.is_regressive:
                loss_x += ((pred_xi - label_xi)**2).sum(1).mean()
            else:
                loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()
                acc_x += compute_accuracy(pred_xi.detach(),
                                          label_xi.max(1)[1])[0].item()
            expert_label_xi = pred_xi.detach()

            # Consistency regularization
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat_x2i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1)
            cr_pred = cr_pred.mean(1)
            loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean()

        loss_x /= self.n_domain
        loss_cr /= self.n_domain
        if not self.is_regressive:
            acc_x /= self.n_domain

        # Unsupervised loss -> None yet
        # Pending: provide a means of establishing a lead expert so that loss can be calculated

        loss = 0
        loss += loss_x
        loss += loss_cr
        #         loss += loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            'loss_x': loss_x.item(),
            'loss_cr': loss_cr.item(),
            #             'loss_u': loss_u.item()
        }
        if not self.is_regressive:
            loss_summary['acc_x'] = acc_x
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x['img']
        input_x2 = batch_x['img2']
        label_x = batch_x['label']
        domain_x = batch_x['domain']
        input_u = batch_u['img']
        input_u2 = batch_u['img2']

        if self.is_regressive:
            label_x = torch.cat([torch.unsqueeze(x, 1) for x in label_x],
                                1)  #Stack list of tensors
        else:
            label_x = create_onehot(label_x, self.num_classes)

        input_x = input_x.to(self.device)
        input_x2 = input_x2.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)
        input_u2 = input_u2.to(self.device)

        return input_x, input_x2, label_x, domain_x, input_u, input_u2

    def parse_batch_test(self, batch):
        if self.is_regressive:
            input = batch['img']
            label = batch['label']
            label = torch.cat([torch.unsqueeze(x, 1) for x in label],
                              1)  #Stack list of tensors
            input = input.to(self.device)
            label = label.to(self.device)
        else:
            input, label = super().parse_batch_test(batch)

        return input, label

    def model_inference(self, input):
        f = self.F(input)
        p = []
        for k in range(self.dm.num_source_domains):
            p_k = self.E(k, f)
            p_k = p_k.unsqueeze(1)
            p.append(p_k)
        p = torch.cat(p, 1)
        p = p.mean(1)
        return p
Example #17
class DAEL(TrainerXU):
    """Domain Adaptive Ensemble Learning.

    https://arxiv.org/abs/2003.07325.
    """
    def __init__(self, cfg):
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.dm.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain

        self.weight_u = cfg.TRAINER.DAEL.WEIGHT_U
        self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == 'RandomDomainSampler'
        assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
        assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = self.dm.train_loader_x
        self.train_loader_u = self.dm.train_loader_u
        self.val_loader = self.dm.val_loader
        self.test_loader = self.dm.test_loader
        self.num_classes = self.dm.num_classes

    def build_model(self):
        cfg = self.cfg

        print('Building F')
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model('F', self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print('Building E')
        self.E = Experts(self.dm.num_source_domains, fdim, self.num_classes)
        self.E.to(self.device)
        print('# params: {:,}'.format(count_num_param(self.E)))
        self.optim_E = build_optimizer(self.E, cfg.OPTIM)
        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
        self.register_model('E', self.E, self.optim_E, self.sched_E)

    def forward_backward(self, batch_x, batch_u):
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data

        input_x = torch.split(input_x, self.split_batch, 0)
        input_x2 = torch.split(input_x2, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
        domain_x = [d[0].item() for d in domain_x]
        # Generate pseudo label
        with torch.no_grad():
            feat_u = self.F(input_u)
            pred_u = []
            for k in range(self.dm.num_source_domains):
                pred_uk = self.E(k, feat_u)
                pred_uk = pred_uk.unsqueeze(1)
                pred_u.append(pred_uk)
            pred_u = torch.cat(pred_u, 1)  # (B, K, C)
            # Get the highest probability and index (label) for each expert
            experts_max_p, experts_max_idx = pred_u.max(2)  # (B, K)
            # Get the most confident expert
            max_expert_p, max_expert_idx = experts_max_p.max(1)  # (B)
            pseudo_label_u = []
            for i, experts_label in zip(max_expert_idx, experts_max_idx):
                pseudo_label_u.append(experts_label[i])
            pseudo_label_u = torch.stack(pseudo_label_u, 0)
            pseudo_label_u = create_onehot(pseudo_label_u, self.num_classes)
            pseudo_label_u = pseudo_label_u.to(self.device)
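            # Keep only pseudo labels whose most confident expert reaches conf_thre.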
            label_u_mask = (max_expert_p >= self.conf_thre).float()

        loss_x = 0
        loss_cr = 0
        acc_x = 0

        feat_x = [self.F(x) for x in input_x]
        feat_x2 = [self.F(x) for x in input_x2]
        feat_u2 = self.F(input_u2)

        for feat_xi, feat_x2i, label_xi, i in zip(feat_x, feat_x2, label_x,
                                                  domain_x):
            cr_s = [j for j in domain_x if j != i]

            # Learning expert
            pred_xi = self.E(i, feat_xi)
            loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()
            expert_label_xi = pred_xi.detach()
            acc_x += compute_accuracy(pred_xi.detach(),
                                      label_xi.max(1)[1])[0].item()

            # Consistency regularization
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat_x2i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1)
            cr_pred = cr_pred.mean(1)
            loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean()

        loss_x /= self.n_domain
        loss_cr /= self.n_domain
        acc_x /= self.n_domain

        # Unsupervised loss
        pred_u = []
        for k in range(self.dm.num_source_domains):
            pred_uk = self.E(k, feat_u2)
            pred_uk = pred_uk.unsqueeze(1)
            pred_u.append(pred_uk)
        pred_u = torch.cat(pred_u, 1)
        pred_u = pred_u.mean(1)
        l_u = (-pseudo_label_u * torch.log(pred_u + 1e-5)).sum(1)
        loss_u = (l_u * label_u_mask).mean()

        loss = 0
        loss += loss_x
        loss += loss_cr
        loss += loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            'loss_x': loss_x.item(),
            'acc_x': acc_x,
            'loss_cr': loss_cr.item(),
            'loss_u': loss_u.item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x['img']
        input_x2 = batch_x['img2']
        label_x = batch_x['label']
        domain_x = batch_x['domain']
        input_u = batch_u['img']
        input_u2 = batch_u['img2']

        label_x = create_onehot(label_x, self.num_classes)

        input_x = input_x.to(self.device)
        input_x2 = input_x2.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)
        input_u2 = input_u2.to(self.device)

        return input_x, input_x2, label_x, domain_x, input_u, input_u2

    def model_inference(self, input):
        f = self.F(input)
        p = []
        for k in range(self.dm.num_source_domains):
            p_k = self.E(k, f)
            p_k = p_k.unsqueeze(1)
            p.append(p_k)
        p = torch.cat(p, 1)
        p = p.mean(1)
        return p