def build_model(self, args):
    """Build the translation model and, if needed, attach a gating network.

    Model construction is delegated to ``fairseq.models.build_model``. When
    the task uses a learned prior (``self.uniform_prior`` is falsy) and the
    built model has no ``gating_network`` attribute, a
    ``modules.MeanPoolGatingNetwork`` is created and assigned to
    ``model.gating_network``. This requires
    ``self.args.mean_pool_gating_network`` to be truthy.

    Args:
        args: parsed arguments. Reads ``mean_pool_gating_network_encoder_dim``
            (falling back to ``encoder_embed_dim``),
            ``mean_pool_gating_network_dropout`` (falling back to
            ``dropout``) and ``num_experts``.

    Returns:
        The built model, with ``gating_network`` attached when one was
        created here.

    Raises:
        ValueError: if a learned prior is requested without
            ``mean_pool_gating_network``, or if the encoder dimension or the
            dropout cannot be determined from ``args``.
    """
    from fairseq import models

    model = models.build_model(args, self)

    if not self.uniform_prior and not hasattr(model, 'gating_network'):
        # NOTE(review): this flag is read from ``self.args``, while every other
        # option comes from ``args``; presumably the same namespace -- confirm.
        if self.args.mean_pool_gating_network:
            # A dimension of 0 is meaningless, so any falsy value falls back.
            if getattr(args, 'mean_pool_gating_network_encoder_dim', None):
                encoder_dim = args.mean_pool_gating_network_encoder_dim
            elif getattr(args, 'encoder_embed_dim', None):
                # assume that encoder_embed_dim is the encoder's output dimension
                encoder_dim = args.encoder_embed_dim
            else:
                raise ValueError('Must specify --mean-pool-gating-network-encoder-dim')

            # Compare against None rather than testing truthiness: 0.0 is a
            # valid dropout and must be neither ignored nor rejected.
            if getattr(args, 'mean_pool_gating_network_dropout', None) is not None:
                dropout = args.mean_pool_gating_network_dropout
            elif getattr(args, 'dropout', None) is not None:
                dropout = args.dropout
            else:
                raise ValueError('Must specify --mean-pool-gating-network-dropout')

            model.gating_network = modules.MeanPoolGatingNetwork(
                encoder_dim, args.num_experts, dropout,
            )
        else:
            raise ValueError(
                'translation_moe task with learned prior requires the model to '
                'have a gating network; try using --mean-pool-gating-network'
            )
    return model
def build_model(self, args):
    """Build the model, an optional gating network, and the estimator(s).

    Model construction is delegated to ``fairseq.models.build_model``. When
    the task uses a learned prior (``self.uniform_prior`` is falsy) and the
    model has no ``gating_network`` attribute, this method:

    * attaches a ``modules.MeanPoolGatingNetwork`` as
      ``model.gating_network``, which requires
      ``self.args.mean_pool_gating_network``;
    * builds ``Estimator`` instances. Which ones depends on
      ``self.share_xml_dict``, ``self.estimator_xml_only``,
      ``self.share_estimator``, ``args.estimator_transformer_dim`` and
      ``args.estimator_xml_dim``.

    Args:
        args: parsed arguments. Reads the gating-network options
            (``mean_pool_gating_network_encoder_dim`` / ``encoder_embed_dim``,
            ``mean_pool_gating_network_dropout`` / ``dropout``,
            ``num_experts``) and the estimator dimensions
            (``estimator_xml_dim``, ``estimator_transformer_dim``).

    Returns:
        ``(model, estimator, xml_estimator)``. Both estimators stay ``None``
        when the learned-prior branch is skipped. ``xml_estimator`` is only
        built alongside a transformer estimator with a non-zero
        ``estimator_xml_dim``.

    Raises:
        ValueError: if a learned prior is requested without
            ``mean_pool_gating_network``, if the encoder dimension or the
            dropout cannot be determined, or if no estimator configuration
            is enabled.
    """
    from fairseq import models

    model = models.build_model(args, self)
    xml_estimator = None
    estimator = None

    if not self.uniform_prior and not hasattr(model, 'gating_network'):
        if self.args.mean_pool_gating_network:
            # A dimension of 0 is meaningless, so any falsy value falls back.
            if getattr(args, 'mean_pool_gating_network_encoder_dim', None):
                encoder_dim = args.mean_pool_gating_network_encoder_dim
            elif getattr(args, 'encoder_embed_dim', None):
                # assume that encoder_embed_dim is the encoder's output dimension
                encoder_dim = args.encoder_embed_dim
            else:
                raise ValueError('Must specify --mean-pool-gating-network-encoder-dim')

            # Compare against None rather than testing truthiness: 0.0 is a
            # valid dropout and must be neither ignored nor rejected.
            if getattr(args, 'mean_pool_gating_network_dropout', None) is not None:
                dropout = args.mean_pool_gating_network_dropout
            elif getattr(args, 'dropout', None) is not None:
                dropout = args.dropout
            else:
                raise ValueError('Must specify --mean-pool-gating-network-dropout')

            model.gating_network = modules.MeanPoolGatingNetwork(
                encoder_dim, args.num_experts, dropout,
            )
        else:
            raise ValueError(
                'translation_moe task with learned prior requires the model to '
                'have a gating network; try using --mean-pool-gating-network'
            )

        # NOTE(review): the original source was whitespace-collapsed, so this
        # nesting is reconstructed. The estimator setup is placed in the
        # learned-prior branch, the only place where ``dropout`` is always
        # bound. Confirm against the original indentation.
        def make_estimator(input_dim, **extra):
            # Every estimator shares the hidden size, dropout and top-k setting.
            return Estimator(self.estimator_hidden_dim, input_dim, dropout,
                             topk_time_step=self.topk_time_step, **extra)

        if self.share_xml_dict:
            # One estimator over the concatenated XML + transformer features.
            estimator = make_estimator(
                args.estimator_xml_dim + args.estimator_transformer_dim)
        elif self.estimator_xml_only:
            estimator = make_estimator(args.estimator_xml_dim)
        elif args.estimator_transformer_dim != 0:
            if self.share_estimator:
                estimator = make_estimator(args.estimator_transformer_dim,
                                           share_estimator=True)
            else:
                estimator = make_estimator(args.estimator_transformer_dim)
            # A separate estimator for XML features, only if they exist.
            if args.estimator_xml_dim != 0:
                xml_estimator = make_estimator(args.estimator_xml_dim)
        else:
            # Previously this reused the gating-network message, which is
            # misleading here: the gating network was just built above.
            raise ValueError(
                'translation_moe task with learned prior requires an estimator; '
                'enable share_xml_dict or estimator_xml_only, or set a non-zero '
                'estimator_transformer_dim'
            )
    return model, estimator, xml_estimator