def build_model(self):
    """Build the feature extractor F, the experts E and the gate G.

    F's input channel count is ``cfg.DATASET.N_CHANNELS`` (overridden to 1
    when ``'grayscale'`` appears in ``cfg.INPUT.TRANSFORMS``) multiplied by
    ``cfg.DATASET.NUM_STACK``. E and G are sized from ``self.F.fdim`` and
    ``self.dm.num_source_domains``. Each network is moved to
    ``self.device``, given its own optimizer and LR scheduler (stored as
    ``self.optim_<name>`` / ``self.sched_<name>``) and registered via
    ``self.register_model``.
    """
    cfg = self.cfg

    def _attach(name, net):
        # Shared post-construction wiring: device, size report, optim/sched,
        # attribute storage and registration -- in the original order.
        net.to(self.device)
        print('# params: {:,}'.format(count_num_param(net)))
        optim = build_optimizer(net, cfg.OPTIM)
        setattr(self, 'optim_' + name, optim)
        sched = build_lr_scheduler(optim, cfg.OPTIM)
        setattr(self, 'sched_' + name, sched)
        self.register_model(name, net, optim, sched)

    img_channels = cfg.DATASET.N_CHANNELS
    if 'grayscale' in cfg.INPUT.TRANSFORMS:
        img_channels = 1
        print("Found grayscale! Set img_channels to 1")
    in_ch = img_channels * cfg.DATASET.NUM_STACK

    print(f'Building F with {in_ch} in channels')
    self.F = SimpleNet(cfg, cfg.MODEL, 0, in_channels=in_ch)
    _attach('F', self.F)

    fdim = self.F.fdim

    print('Building E')
    self.E = Experts(
        self.dm.num_source_domains,
        fdim,
        self.num_classes,
        regressive=self.is_regressive,
    )
    _attach('E', self.E)

    print('Building G')
    self.G = Gate(fdim, self.dm.num_source_domains)
    _attach('G', self.G)
def build_model(self):
    """Build the task net F, the domain net D and the generator G.

    F is a ``SimpleNet`` with ``self.num_classes`` outputs and D is a
    ``SimpleNet`` with one output per source domain. G comes from
    ``build_model(cfg.TRAINER.DDAIG.G_ARCH, verbose=cfg.VERBOSE)``. Each
    network is moved to ``self.device``, given an optimizer and LR
    scheduler (stored as ``self.optim_<name>`` / ``self.sched_<name>``)
    and registered under its single-letter name.
    """
    cfg = self.cfg

    # Constructors are deferred so each runs only after its
    # "Building <name>" banner, preserving the original sequencing.
    factories = (
        ('F', lambda: SimpleNet(cfg, cfg.MODEL, self.num_classes)),
        ('D', lambda: SimpleNet(cfg, cfg.MODEL, self.dm.num_source_domains)),
        # NOTE(review): ``build_model`` here is a module-level name lookup,
        # not this method -- presumably an imported network builder.
        # Confirm the import; a module-level ``def build_model`` would
        # shadow it.
        ('G', lambda: build_model(
            cfg.TRAINER.DDAIG.G_ARCH, verbose=cfg.VERBOSE)),
    )

    for name, factory in factories:
        print('Building ' + name)
        net = factory()
        setattr(self, name, net)
        net.to(self.device)
        print('# params: {:,}'.format(count_num_param(net)))
        optim = build_optimizer(net, cfg.OPTIM)
        setattr(self, 'optim_' + name, optim)
        sched = build_lr_scheduler(optim, cfg.OPTIM)
        setattr(self, 'sched_' + name, sched)
        self.register_model(name, net, optim, sched)
def build_model(self):
    """Build a shared feature extractor F and two linear heads C1 and C2.

    F is a ``SimpleNet`` built with ``0`` as its third argument. C1 and C2
    are ``nn.Linear(F.fdim, self.num_classes)`` layers. Every module is
    moved to ``self.device``, paired with its own optimizer and LR
    scheduler (``self.optim_<name>`` / ``self.sched_<name>``) and
    registered under its name.
    """
    cfg = self.cfg

    print('Building F')
    self.F = SimpleNet(cfg, cfg.MODEL, 0)
    self.F.to(self.device)
    print(f'# params: {count_num_param(self.F):,}')
    self.optim_F = build_optimizer(self.F, cfg.OPTIM)
    self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
    self.register_model('F', self.F, self.optim_F, self.sched_F)

    feat_dim = self.F.fdim

    # The two heads are identical in shape; they are built one after the
    # other, exactly as before, so parameter initialisation order is kept.
    for head in ('C1', 'C2'):
        print(f'Building {head}')
        clf = nn.Linear(feat_dim, self.num_classes)
        setattr(self, head, clf)
        clf.to(self.device)
        print(f'# params: {count_num_param(clf):,}')
        opt = build_optimizer(clf, cfg.OPTIM)
        setattr(self, f'optim_{head}', opt)
        sch = build_lr_scheduler(opt, cfg.OPTIM)
        setattr(self, f'sched_{head}', sch)
        self.register_model(head, clf, opt, sch)
def build_model(self):
    """Build the label network F and the domain network D.

    Both are ``SimpleNet`` instances: F is sized with ``self.num_classes``
    and D with ``self.dm.num_source_domains``. Each is moved to
    ``self.device``, paired with an optimizer and LR scheduler (exposed
    as ``self.optim_<name>`` / ``self.sched_<name>``) and registered under
    its name.
    """
    cfg = self.cfg

    def _install(name, net):
        """Store ``net`` as ``self.<name>`` and wire up its training state."""
        setattr(self, name, net)
        net.to(self.device)
        print('# params: {:,}'.format(count_num_param(net)))
        optimizer = build_optimizer(net, cfg.OPTIM)
        setattr(self, 'optim_' + name, optimizer)
        scheduler = build_lr_scheduler(optimizer, cfg.OPTIM)
        setattr(self, 'sched_' + name, scheduler)
        self.register_model(name, net, optimizer, scheduler)

    print('Building F')
    _install('F', SimpleNet(cfg, cfg.MODEL, self.num_classes))

    print('Building D')
    _install('D', SimpleNet(cfg, cfg.MODEL, self.dm.num_source_domains))
def build_model(self):
    """Build the shared backbone F and the expert module E.

    F is a ``SimpleNet`` built with ``0`` as its third argument; E is an
    ``Experts`` module constructed from the number of source domains,
    F's feature dimension and ``self.num_classes``. Both are moved to
    ``self.device``, get an optimizer and LR scheduler
    (``self.optim_<name>`` / ``self.sched_<name>``) and are registered.
    """
    cfg = self.cfg

    print('Building F')
    backbone = SimpleNet(cfg, cfg.MODEL, 0)
    self.F = backbone
    backbone.to(self.device)
    print(f'# params: {count_num_param(backbone):,}')
    self.optim_F = build_optimizer(backbone, cfg.OPTIM)
    self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
    self.register_model('F', backbone, self.optim_F, self.sched_F)

    feat_dim = backbone.fdim

    print('Building E')
    experts = Experts(self.dm.num_source_domains, feat_dim, self.num_classes)
    self.E = experts
    experts.to(self.device)
    print(f'# params: {count_num_param(experts):,}')
    self.optim_E = build_optimizer(experts, cfg.OPTIM)
    self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
    self.register_model('E', experts, self.optim_E, self.sched_E)
def build_model(self):
    """Build the feature extractor F, the prototype classifier C and revgrad.

    F is a ``SimpleNet`` built with ``0`` as its third argument; C is a
    ``Prototypes(F.fdim, self.num_classes)`` module. Both are moved to
    ``self.device``, given an optimizer and LR scheduler
    (``self.optim_<name>`` / ``self.sched_<name>``) and registered. A
    ``ReverseGrad()`` instance is stored as ``self.revgrad``
    (presumably a gradient-reversal layer -- confirm against its
    definition); it is not registered.
    """
    cfg = self.cfg

    print('Building F')
    extractor = SimpleNet(cfg, cfg.MODEL, 0)
    self.F = extractor
    extractor.to(self.device)
    print(f'# params: {count_num_param(extractor):,}')
    self.optim_F = build_optimizer(extractor, cfg.OPTIM)
    self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
    self.register_model('F', extractor, self.optim_F, self.sched_F)

    print('Building C')
    protos = Prototypes(extractor.fdim, self.num_classes)
    self.C = protos
    protos.to(self.device)
    print(f'# params: {count_num_param(protos):,}')
    self.optim_C = build_optimizer(protos, cfg.OPTIM)
    self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
    self.register_model('C', protos, self.optim_C, self.sched_C)

    self.revgrad = ReverseGrad()
def build_model(self):
    """Build one ``SimpleNet`` per source domain and register them as a group.

    Each network is constructed with ``self.num_classes`` and
    ``cfg.MODEL.CLASSIFIER.TYPE``. The networks are wrapped in an
    ``nn.ModuleList`` stored as ``self.models``, moved to ``self.device``
    and registered under ``'models'`` without passing an optimizer or
    scheduler.
    """
    cfg = self.cfg
    print('Building multiple source models')

    per_domain = [
        SimpleNet(cfg, cfg.MODEL, self.num_classes, cfg.MODEL.CLASSIFIER.TYPE)
        for _ in range(self.dm.num_source_domains)
    ]
    self.models = nn.ModuleList(per_domain)
    self.models.to(self.device)

    n_params = count_num_param(self.models)
    print(f'# params: {n_params:,}')
    self.register_model('models', self.models)
def build_model(self):
    """Build the feature extractor F and one classifier pair per source domain.

    F is a ``SimpleNet`` built with ``0`` as its third argument. C is an
    ``nn.ModuleList`` holding one ``PairClassifiers(F.fdim,
    self.num_classes)`` per source domain. Both F and C are moved to
    ``self.device``, get a single optimizer and LR scheduler each
    (``self.optim_<name>`` / ``self.sched_<name>``) and are registered.
    """
    cfg = self.cfg

    def _setup(name, module):
        # Common wiring applied to F and to the whole C list.
        setattr(self, name, module)
        module.to(self.device)
        print('# params: {:,}'.format(count_num_param(module)))
        opt = build_optimizer(module, cfg.OPTIM)
        setattr(self, 'optim_' + name, opt)
        sch = build_lr_scheduler(opt, cfg.OPTIM)
        setattr(self, 'sched_' + name, sch)
        self.register_model(name, module, opt, sch)

    print('Building F')
    _setup('F', SimpleNet(cfg, cfg.MODEL, 0))

    fdim = self.F.fdim

    print('Building C')
    classifier_pairs = nn.ModuleList([
        PairClassifiers(fdim, self.num_classes)
        for _ in range(self.dm.num_source_domains)
    ])
    _setup('C', classifier_pairs)
def build_critic(self):
    """Build, register and optimise a scalar critic on top of model features.

    The critic is an ``'mlp'`` head from ``build_head`` (input width
    ``self.model.fdim``, hidden layers ``[fdim, fdim // 2]``, leaky-ReLU
    activation) followed by ``nn.Linear(fdim // 2, 1)``. It is stored as
    ``self.critic``, moved to ``self.device``, and registered under
    ``'critic'`` with an optimizer/scheduler kept on ``self.optim_c`` /
    ``self.sched_c``.
    """
    cfg = self.cfg
    print('Building critic network')

    in_dim = self.model.fdim
    mid_dim = in_dim // 2
    mlp = build_head(
        'mlp',
        verbose=cfg.VERBOSE,
        in_features=in_dim,
        hidden_layers=[in_dim, mid_dim],
        activation='leaky_relu',
    )
    self.critic = nn.Sequential(mlp, nn.Linear(mid_dim, 1))

    # The size is reported before the device move, as in the original.
    print(f'# params: {count_num_param(self.critic):,}')
    self.critic.to(self.device)

    self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
    self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
    self.register_model('critic', self.critic, self.optim_c, self.sched_c)
def build_model(self):
    """Build and register the default classification model.

    Creates a ``SimpleNet`` with ``self.num_classes`` outputs, optionally
    loads weights from ``cfg.MODEL.INIT_WEIGHTS`` (when that value is
    truthy), moves it to ``self.device``, creates its optimizer and LR
    scheduler (``self.optim`` / ``self.sched``) and registers everything
    under ``'model'``. Subclasses may override this to build other models.
    """
    cfg = self.cfg
    print('Building model')

    model = SimpleNet(cfg, cfg.MODEL, self.num_classes)
    self.model = model

    init_weights = cfg.MODEL.INIT_WEIGHTS
    if init_weights:
        load_pretrained_weights(model, init_weights)

    model.to(self.device)
    print(f'# params: {count_num_param(model):,}')

    self.optim = build_optimizer(model, cfg.OPTIM)
    self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
    self.register_model('model', model, self.optim, self.sched)