Example #1
0
    def build_model(self):
        """Build and register the networks F, E and G.

        F is a ``SimpleNet`` backbone. Its input channel count is the
        per-image channel count (``cfg.DATASET.N_CHANNELS``, forced to 1 when
        ``'grayscale'`` appears in ``cfg.INPUT.TRANSFORMS``) multiplied by
        ``cfg.DATASET.NUM_STACK``. E (``Experts``) and G (``Gate``) are built
        on F's feature dimension ``self.F.fdim``.

        Every network is:
        - stored on ``self``,
        - moved to ``self.device``,
        - reported with its parameter count,
        - registered with its own optimizer (``self.optim_<name>``) and LR
          scheduler (``self.sched_<name>``), both built from ``cfg.OPTIM``.
        """
        cfg = self.cfg

        def attach(name, net):
            # Store ``net`` as ``self.<name>``, move it to the device, print
            # its size, then create and register its optimizer/scheduler.
            setattr(self, name, net)
            net.to(self.device)
            print('# params: {:,}'.format(count_num_param(net)))
            optimizer = build_optimizer(net, cfg.OPTIM)
            setattr(self, 'optim_' + name, optimizer)
            scheduler = build_lr_scheduler(optimizer, cfg.OPTIM)
            setattr(self, 'sched_' + name, scheduler)
            self.register_model(name, net, optimizer, scheduler)

        channels = cfg.DATASET.N_CHANNELS
        if 'grayscale' in cfg.INPUT.TRANSFORMS:
            # A grayscale transform leaves a single channel per image.
            channels = 1
            print("Found grayscale! Set img_channels to 1")
        in_channels = channels * cfg.DATASET.NUM_STACK
        print(f'Building F with {in_channels} in channels')
        # NOTE(review): 0 is passed as SimpleNet's third argument; presumably
        # this means "no classification head" (F is a feature extractor) —
        # confirm against SimpleNet.
        attach('F', SimpleNet(cfg, cfg.MODEL, 0, in_channels=in_channels))
        feat_dim = self.F.fdim

        print('Building E')
        attach('E', Experts(self.dm.num_source_domains,
                            feat_dim,
                            self.num_classes,
                            regressive=self.is_regressive))

        print('Building G')
        attach('G', Gate(feat_dim, self.dm.num_source_domains))
Example #2
0
    def build_model(self):
        """Build and register the networks F, D and G.

        F and D are ``SimpleNet`` instances that receive ``self.num_classes``
        and ``self.dm.num_source_domains``, respectively, as their third
        argument. G is produced by the module-level ``build_model`` builder
        from ``cfg.TRAINER.DDAIG.G_ARCH``.

        Each network is:
        - stored on ``self``,
        - moved to ``self.device``,
        - reported with its parameter count,
        - registered with its own optimizer (``self.optim_<name>``) and LR
          scheduler (``self.sched_<name>``), both built from ``cfg.OPTIM``.
        """
        cfg = self.cfg

        # Factories rather than instances, so that each network is only
        # constructed after its "Building ..." line has been printed.
        # NOTE: the bare name ``build_model`` inside the G factory resolves
        # to the module-level function, not to this method.
        factories = (
            ('F', lambda: SimpleNet(cfg, cfg.MODEL, self.num_classes)),
            ('D', lambda: SimpleNet(cfg, cfg.MODEL,
                                    self.dm.num_source_domains)),
            ('G', lambda: build_model(cfg.TRAINER.DDAIG.G_ARCH,
                                      verbose=cfg.VERBOSE)),
        )

        for name, make in factories:
            print(f'Building {name}')
            net = make()
            setattr(self, name, net)
            net.to(self.device)
            n_params = count_num_param(net)
            print('# params: {:,}'.format(n_params))
            optim = build_optimizer(net, cfg.OPTIM)
            setattr(self, f'optim_{name}', optim)
            sched = build_lr_scheduler(optim, cfg.OPTIM)
            setattr(self, f'sched_{name}', sched)
            self.register_model(name, net, optim, sched)
Example #3
0
    def build_model(self):
        """Build and register the backbone F and two linear heads C1 and C2.

        F is a ``SimpleNet`` built with 0 as its third argument. Its feature
        dimension ``self.F.fdim`` sizes two independent ``nn.Linear`` heads
        mapping to ``self.num_classes``.

        Every network is:
        - stored on ``self``,
        - moved to ``self.device``,
        - reported with its parameter count,
        - registered with its own optimizer (``self.optim_<name>``) and LR
          scheduler (``self.sched_<name>``), both built from ``cfg.OPTIM``.
        """
        cfg = self.cfg

        def wire(name, module):
            # Attach ``module`` as ``self.<name>`` together with a dedicated
            # optimizer and LR scheduler, and register all three.
            setattr(self, name, module)
            module.to(self.device)
            size = count_num_param(module)
            print('# params: {:,}'.format(size))
            opt = build_optimizer(module, cfg.OPTIM)
            setattr(self, 'optim_' + name, opt)
            sch = build_lr_scheduler(opt, cfg.OPTIM)
            setattr(self, 'sched_' + name, sch)
            self.register_model(name, module, opt, sch)

        print('Building F')
        wire('F', SimpleNet(cfg, cfg.MODEL, 0))
        feat_dim = self.F.fdim

        # C1 and C2 share an architecture but are separate modules, each with
        # its own optimizer and scheduler.
        for head in ('C1', 'C2'):
            print('Building ' + head)
            wire(head, nn.Linear(feat_dim, self.num_classes))
Example #4
0
    def build_model(self):
        """Build and register the networks F and D.

        Both are ``SimpleNet`` instances. F receives ``self.num_classes`` and
        D receives ``self.dm.num_source_domains`` as the third argument
        (presumably the number of outputs — confirm against SimpleNet).

        Each network is:
        - stored on ``self``,
        - moved to ``self.device``,
        - reported with its parameter count,
        - registered with its own optimizer (``self.optim_<name>``) and LR
          scheduler (``self.sched_<name>``), both built from ``cfg.OPTIM``.
        """
        cfg = self.cfg

        def setup(name, net):
            # Keep ``net`` on ``self`` and attach a dedicated optimizer and
            # LR scheduler, stored as ``optim_<name>`` / ``sched_<name>``.
            setattr(self, name, net)
            net.to(self.device)
            n_params = count_num_param(net)
            print('# params: {:,}'.format(n_params))
            optim = build_optimizer(net, cfg.OPTIM)
            setattr(self, f'optim_{name}', optim)
            sched = build_lr_scheduler(optim, cfg.OPTIM)
            setattr(self, f'sched_{name}', sched)
            self.register_model(name, net, optim, sched)

        print('Building F')
        setup('F', SimpleNet(cfg, cfg.MODEL, self.num_classes))

        print('Building D')
        setup('D', SimpleNet(cfg, cfg.MODEL, self.dm.num_source_domains))
Example #5
0
    def build_model(self):
        """Build and register the backbone F and the ``Experts`` module E.

        F is a ``SimpleNet`` built with 0 as its third argument. E is built
        from the number of source domains, F's feature dimension
        ``self.F.fdim`` and ``self.num_classes``.

        Both networks are:
        - stored on ``self``,
        - moved to ``self.device``,
        - reported with their parameter counts,
        - registered with their own optimizer (``self.optim_<name>``) and LR
          scheduler (``self.sched_<name>``), both built from ``cfg.OPTIM``.
        """
        cfg = self.cfg

        def attach(name, net):
            # Store ``net`` as ``self.<name>``, move it to the device, print
            # its size, then create and register its optimizer/scheduler.
            setattr(self, name, net)
            net.to(self.device)
            print('# params: {:,}'.format(count_num_param(net)))
            optimizer = build_optimizer(net, cfg.OPTIM)
            setattr(self, 'optim_' + name, optimizer)
            scheduler = build_lr_scheduler(optimizer, cfg.OPTIM)
            setattr(self, 'sched_' + name, scheduler)
            self.register_model(name, net, optimizer, scheduler)

        print('Building F')
        attach('F', SimpleNet(cfg, cfg.MODEL, 0))
        feat_dim = self.F.fdim

        print('Building E')
        attach('E', Experts(self.dm.num_source_domains, feat_dim,
                            self.num_classes))
Example #6
0
    def build_model(self):
        """Build and register the backbone F and the ``Prototypes`` head C.

        F is a ``SimpleNet`` built with 0 as its third argument. C is built
        from F's feature dimension and ``self.num_classes``.

        Both networks are:
        - stored on ``self``,
        - moved to ``self.device``,
        - reported with their parameter counts,
        - registered with their own optimizer (``self.optim_<name>``) and LR
          scheduler (``self.sched_<name>``), both built from ``cfg.OPTIM``.

        Finally ``self.revgrad`` is set to a ``ReverseGrad`` instance, which
        is neither moved to the device nor registered here.
        """
        cfg = self.cfg

        def setup(name, net):
            # Keep ``net`` on ``self`` and attach a dedicated optimizer and
            # LR scheduler, stored as ``optim_<name>`` / ``sched_<name>``.
            setattr(self, name, net)
            net.to(self.device)
            n_params = count_num_param(net)
            print('# params: {:,}'.format(n_params))
            optim = build_optimizer(net, cfg.OPTIM)
            setattr(self, f'optim_{name}', optim)
            sched = build_lr_scheduler(optim, cfg.OPTIM)
            setattr(self, f'sched_{name}', sched)
            self.register_model(name, net, optim, sched)

        print('Building F')
        setup('F', SimpleNet(cfg, cfg.MODEL, 0))

        print('Building C')
        setup('C', Prototypes(self.F.fdim, self.num_classes))

        # NOTE(review): presumably a parameter-free gradient-reversal layer,
        # which would explain why it gets no optimizer or registration —
        # confirm against ReverseGrad.
        self.revgrad = ReverseGrad()
Example #7
0
    def build_model(self):
        """Build one ``SimpleNet`` per source domain and register them.

        The networks are collected in an ``nn.ModuleList`` stored as
        ``self.models``, which is:
        - moved to ``self.device`` as a whole,
        - reported with its total parameter count,
        - registered under ``'models'`` without an optimizer or scheduler.
        """
        cfg = self.cfg

        print('Building multiple source models')
        per_domain = []
        for _ in range(self.dm.num_source_domains):
            per_domain.append(
                SimpleNet(cfg, cfg.MODEL, self.num_classes,
                          cfg.MODEL.CLASSIFIER.TYPE)
            )
        self.models = nn.ModuleList(per_domain)
        self.models.to(self.device)
        total = count_num_param(self.models)
        print('# params: {:,}'.format(total))
        self.register_model('models', self.models)
Example #8
0
    def build_model(self):
        """Build and register the backbone F and the per-domain heads C.

        F is a ``SimpleNet`` built with 0 as its third argument. C is an
        ``nn.ModuleList`` holding one ``PairClassifiers`` (sized by
        ``self.F.fdim`` and ``self.num_classes``) per source domain.

        Both F and C are:
        - stored on ``self``,
        - moved to ``self.device``,
        - reported with their parameter counts,
        - registered with their own optimizer (``self.optim_<name>``) and LR
          scheduler (``self.sched_<name>``), both built from ``cfg.OPTIM``.
        C's optimizer covers all of its per-domain heads.
        """
        cfg = self.cfg

        def attach(name, net):
            # Store ``net`` as ``self.<name>``, move it to the device, print
            # its size, then create and register its optimizer/scheduler.
            setattr(self, name, net)
            net.to(self.device)
            print('# params: {:,}'.format(count_num_param(net)))
            optimizer = build_optimizer(net, cfg.OPTIM)
            setattr(self, 'optim_' + name, optimizer)
            scheduler = build_lr_scheduler(optimizer, cfg.OPTIM)
            setattr(self, 'sched_' + name, scheduler)
            self.register_model(name, net, optimizer, scheduler)

        print('Building F')
        attach('F', SimpleNet(cfg, cfg.MODEL, 0))
        feat_dim = self.F.fdim

        print('Building C')
        heads = []
        for _ in range(self.dm.num_source_domains):
            heads.append(PairClassifiers(feat_dim, self.num_classes))
        attach('C', nn.ModuleList(heads))
Example #9
0
    def build_critic(self):
        """Build and register the critic network.

        The critic is an ``'mlp'`` head from ``build_head`` with
        ``self.model.fdim`` input features, hidden layers
        ``[fdim, fdim // 2]`` and ``'leaky_relu'`` activation, followed by an
        ``nn.Linear`` producing a single output.

        It is stored as ``self.critic``, moved to ``self.device`` and
        registered as ``'critic'``. Its optimizer (``self.optim_c``) and LR
        scheduler (``self.sched_c``) are built from ``cfg.OPTIM``.
        """
        cfg = self.cfg

        print('Building critic network')
        in_dim = self.model.fdim
        half_dim = in_dim // 2
        body = build_head(
            'mlp',
            verbose=cfg.VERBOSE,
            in_features=in_dim,
            hidden_layers=[in_dim, half_dim],
            activation='leaky_relu',
        )
        critic = nn.Sequential(body, nn.Linear(half_dim, 1))
        self.critic = critic
        # The parameter count is reported before the move to the device.
        n_params = count_num_param(critic)
        print('# params: {:,}'.format(n_params))
        critic.to(self.device)
        optimizer = build_optimizer(critic, cfg.OPTIM)
        self.optim_c = optimizer
        scheduler = build_lr_scheduler(optimizer, cfg.OPTIM)
        self.sched_c = scheduler
        self.register_model('critic', critic, optimizer, scheduler)
Example #10
0
    def build_model(self):
        """Build and register the default classification model.

        Steps:
        1. Create a ``SimpleNet`` with ``self.num_classes`` as its third
           argument.
        2. Optionally load pretrained weights from ``cfg.MODEL.INIT_WEIGHTS``
           (only when that value is truthy).
        3. Move the model to ``self.device`` and print its parameter count.
        4. Register it as ``'model'`` together with an optimizer
           (``self.optim``) and LR scheduler (``self.sched``), both built
           from ``cfg.OPTIM``.

        Subclassed trainers may override this method to build other models.
        """
        cfg = self.cfg

        print('Building model')
        net = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        self.model = net
        weights = cfg.MODEL.INIT_WEIGHTS
        if weights:
            load_pretrained_weights(net, cfg.MODEL.INIT_WEIGHTS)
        net.to(self.device)
        n_params = count_num_param(net)
        print('# params: {:,}'.format(n_params))
        optimizer = build_optimizer(net, cfg.OPTIM)
        self.optim = optimizer
        scheduler = build_lr_scheduler(optimizer, cfg.OPTIM)
        self.sched = scheduler
        self.register_model('model', net, optimizer, scheduler)