def __init__(
    self,
    config,
    data_config: WMTQEDataset.Config = None,
    module_dict: Dict[str, Any] = None,
):
    """Prepare the data encoders and build any modules that are still unset.

    Args:
        config: Forwarded to the parent constructor; the rest of this
            method reads its settings from ``self.config``.
        data_config: Forwarded to the parent constructor as a keyword.
        module_dict: When truthy, handed to ``self._load_dict`` and the
            other two initialisation paths are skipped.
    """
    super().__init__(config, data_config=data_config)

    # NOTE(review): the source layout of this method was flattened. The
    # vocabulary steps and module construction are placed after the branch
    # chain because the ``if not self.<module>`` guards only matter when a
    # load path may already have populated the module -- confirm nesting.

    # Exactly one of the three initialisation paths runs.
    if module_dict:
        # Load modules and weights
        self._load_dict(module_dict)
    elif self.config.load_encoder:
        self._load_encoder(self.config.load_encoder)
    else:
        # Initialize data processing
        processing_config = self.config.data_processing
        input_field_encoders = BertEncoder.input_data_encoders(
            self.config.model.encoder
        )
        self.data_encoders = WMTQEDataEncoder(
            config=processing_config,
            field_encoders=input_field_encoders,
        )

    # Add possibly missing fields, like outputs
    vocabs_to_load = self.config.load_vocabs
    if vocabs_to_load:
        self.data_encoders.load_vocabularies(vocabs_to_load)
    if self.train_dataset:
        self.data_encoders.fit_vocabularies(self.train_dataset)

    # Each module below is only created when the attribute is still falsy.

    # Input to features
    if not self.encoder:
        self.encoder = BertEncoder(
            vocabs=self.data_encoders.vocabularies,
            config=self.config.model.encoder,
        )

    # Features to output (sized from the encoder, so the encoder comes first)
    if not self.decoder:
        self.decoder = LinearDecoder(
            inputs_dims=self.encoder.size(),
            config=self.config.model.decoder,
        )

    # Output layers: ``outputs`` is sized from the decoder ...
    if not self.outputs:
        self.outputs = QEOutputs(
            inputs_dims=self.decoder.size(),
            vocabs=self.data_encoders.vocabularies,
            config=self.config.model.outputs,
        )

    # ... while ``tlm_outputs`` is sized from the encoder.
    if not self.tlm_outputs:
        self.tlm_outputs = TLMOutputs(
            inputs_dims=self.encoder.size(),
            vocabs=self.data_encoders.vocabularies,
            config=self.config.model.tlm_outputs,
        )
class ModelConfig(BaseConfig):
    # Configuration container with one field per configured part:
    # ``encoder`` and ``tlm_outputs``. Each field defaults to the matching
    # ``Config`` class called with no arguments; the default expressions are
    # evaluated once, when this class body executes.
    # NOTE(review): how ``BaseConfig`` treats these class-level defaults
    # (copied per instance or shared) is not visible here -- confirm.
    encoder: PredictorEncoder.Config = PredictorEncoder.Config()
    tlm_outputs: TLMOutputs.Config = TLMOutputs.Config()
class ModelConfig(BaseConfig):
    # Configuration container with one field per configured part:
    # ``encoder``, ``decoder``, ``outputs`` and ``tlm_outputs``. Each field
    # defaults to the matching ``Config`` class called with no arguments; the
    # default expressions are evaluated once, when this class body executes.
    # NOTE(review): how ``BaseConfig`` treats these class-level defaults
    # (copied per instance or shared) is not visible here -- confirm.
    encoder: BertEncoder.Config = BertEncoder.Config()
    decoder: LinearDecoder.Config = LinearDecoder.Config()
    outputs: QEOutputs.Config = QEOutputs.Config()
    tlm_outputs: TLMOutputs.Config = TLMOutputs.Config()
class ModelConfig(BaseConfig):
    # Configuration container with one field per configured part:
    # ``encoder``, ``decoder``, ``outputs`` and ``tlm_outputs``. Each field
    # defaults to the matching ``Config`` class called with no arguments; the
    # default expressions are evaluated once, when this class body executes.
    # NOTE(review): how ``BaseConfig`` treats these class-level defaults
    # (copied per instance or shared) is not visible here -- confirm.
    encoder: PredictorEncoder.Config = PredictorEncoder.Config()
    decoder: EstimatorDecoder.Config = EstimatorDecoder.Config()
    outputs: QEOutputs.Config = QEOutputs.Config()
    tlm_outputs: TLMOutputs.Config = TLMOutputs.Config()