class ModelConfig(base_config.Config):
    """General common configuration for all models.

    Declares the four settings shared by every model config: a model
    identifier plus backbone, head and loss sub-configs. Subclasses override
    these defaults to describe a concrete model.
    """

    # NOTE(review): the default sub-config objects below are constructed once,
    # when the class body is executed. Unless base_config.Config copies them
    # per instance, every instance may share the same objects. Confirm before
    # mutating them in place.

    # Model identifier; left empty here so subclasses must set it.
    model_name: str = ''

    # Backbone sub-config; the base default is a plain BackboneConfig().
    backbone_config: BackboneConfig = BackboneConfig()

    # Head stack built with no constructor arguments.
    head_config: head_cfg.HeadStack = head_cfg.HeadStack()

    # Loss stack built with no constructor arguments.
    loss_config: loss_cfg.LossStack = loss_cfg.LossStack()
class TxFACModel(ModelConfig):
    """Configs for Tx + MLP-FAC.

    Uses a TxBackboneConfig backbone, places a single FACBridge head in the
    head stack's ``bridge`` slot, and pairs it with a single AsymmetricNCE
    loss in the loss stack's ``bridge`` slot.
    """

    model_name: str = 'tx_mlp_fac'

    backbone_config: BackboneConfig = TxBackboneConfig()

    # One-element tuple: the bridge head is the FAC projection only.
    head_config: head_cfg.HeadStack = head_cfg.HeadStack(
        bridge=(
            head_cfg.FACBridge(),
        ),
    )

    # One-element tuple: the only bridge loss is AsymmetricNCE.
    loss_config: loss_cfg.LossStack = loss_cfg.LossStack(
        bridge=(
            loss_cfg.AsymmetricNCE(),
        ),
    )
class UnifiedTxFACModel(ModelConfig):
    """Configs for Unified VATT Tx + MLP-FAC.

    Uses a UTBackboneConfig backbone, places a single FACBridge head in the
    head stack's ``bridge`` slot, and pairs it with a single AsymmetricNCE
    loss in the loss stack's ``bridge`` slot.
    """

    model_name: str = 'uvatt_mlp_fac'

    # NOTE(review): this annotation is UTBackboneConfig, while ModelConfig and
    # the sibling configs annotate backbone_config as BackboneConfig. Verify
    # this is intentional, for example if the config system reads annotations.
    backbone_config: UTBackboneConfig = UTBackboneConfig()

    # One-element tuple: the bridge head is the FAC projection only.
    head_config: head_cfg.HeadStack = head_cfg.HeadStack(
        bridge=(
            head_cfg.FACBridge(),
        ),
    )

    # One-element tuple: the only bridge loss is AsymmetricNCE.
    loss_config: loss_cfg.LossStack = loss_cfg.LossStack(
        bridge=(
            loss_cfg.AsymmetricNCE(),
        ),
    )
class MMVFACModel(ModelConfig):
    """Configs for MMV + MLP-FAC baseline.

    Uses a CNNBackboneConfig backbone, places a single FACBridge head in the
    head stack's ``bridge`` slot, and pairs it with a single AsymmetricNCE
    loss in the loss stack's ``bridge`` slot.
    """

    model_name: str = 'mmv_fac'

    backbone_config: BackboneConfig = CNNBackboneConfig()

    # One-element tuple: the bridge head is the FAC projection only.
    head_config: head_cfg.HeadStack = head_cfg.HeadStack(
        bridge=(
            head_cfg.FACBridge(),
        ),
    )

    # One-element tuple: the only bridge loss is AsymmetricNCE.
    loss_config: loss_cfg.LossStack = loss_cfg.LossStack(
        bridge=(
            loss_cfg.AsymmetricNCE(),
        ),
    )