コード例 #1
0
    def build_model(self, args):
        """Build the model and, when required, attach a mean-pool gating network.

        The model is created with ``fairseq.models.build_model(args, self)``.
        If the task does not use a uniform prior and the model has no
        ``gating_network`` attribute, a ``modules.MeanPoolGatingNetwork`` is
        built and stored on ``model.gating_network``.

        Args:
            args: parsed arguments. Reads ``num_experts`` and, optionally,
                ``mean_pool_gating_network_encoder_dim`` /
                ``encoder_embed_dim`` and
                ``mean_pool_gating_network_dropout`` / ``dropout``.

        Returns:
            The built model (with ``gating_network`` attached if one was
            created here).

        Raises:
            ValueError: if a gating network is needed but
                ``self.args.mean_pool_gating_network`` is not set, or if no
                encoder dimension or no dropout value can be resolved.
        """
        from fairseq import models
        model = models.build_model(args, self)

        # Nothing more to do for a uniform prior or a model that already
        # ships its own gating network.
        if self.uniform_prior or hasattr(model, 'gating_network'):
            return model

        if not self.args.mean_pool_gating_network:
            raise ValueError(
                'translation_moe task with learned prior requires the model to '
                'have a gating network; try using --mean-pool-gating-network'
            )

        # A dimension of 0 is never meaningful, so a truthiness test is fine.
        encoder_dim = getattr(args, 'mean_pool_gating_network_encoder_dim', None)
        if not encoder_dim:
            # assume that encoder_embed_dim is the encoder's output dimension
            encoder_dim = getattr(args, 'encoder_embed_dim', None)
        if not encoder_dim:
            raise ValueError(
                'Must specify --mean-pool-gating-network-encoder-dim')

        # Compare against None rather than relying on truthiness: a dropout
        # of 0.0 is falsy but is a legitimate, explicitly chosen value and
        # must not be treated as "unspecified".
        dropout = getattr(args, 'mean_pool_gating_network_dropout', None)
        if dropout is None:
            dropout = getattr(args, 'dropout', None)
        if dropout is None:
            raise ValueError(
                'Must specify --mean-pool-gating-network-dropout')

        model.gating_network = modules.MeanPoolGatingNetwork(
            encoder_dim,
            args.num_experts,
            dropout,
        )
        return model
コード例 #2
0
    def build_model(self, args):
        """Build the model plus the estimator(s) selected by the task settings.

        The model is created with ``fairseq.models.build_model(args, self)``.
        If the task does not use a uniform prior and the model has no
        ``gating_network`` attribute, a ``modules.MeanPoolGatingNetwork`` is
        built and stored on ``model.gating_network``. Afterwards one or two
        ``Estimator`` instances are created depending on
        ``self.share_xml_dict``, ``self.estimator_xml_only``,
        ``self.share_estimator``, ``args.estimator_transformer_dim`` and
        ``args.estimator_xml_dim``.

        Args:
            args: parsed arguments. Reads ``num_experts``,
                ``estimator_xml_dim``, ``estimator_transformer_dim`` and,
                optionally, ``mean_pool_gating_network_encoder_dim`` /
                ``encoder_embed_dim`` and
                ``mean_pool_gating_network_dropout`` / ``dropout``.

        Returns:
            tuple: ``(model, estimator, xml_estimator)``. ``xml_estimator`` is
            ``None`` unless the separate-estimator branch is taken with a
            non-zero ``args.estimator_xml_dim``.

        Raises:
            ValueError: if a gating network is needed but cannot be
                configured, if an estimator has to be built but no dropout
                value can be resolved, or if no estimator configuration is
                selected at all.
        """
        from fairseq import models
        model = models.build_model(args, self)

        # Resolve dropout once, before any branching. It used to be bound only
        # inside the gating-network branch, so building an estimator with a
        # uniform prior (or a model that already had a gating network) failed
        # with UnboundLocalError. `is None` keeps an explicit 0.0 usable.
        dropout = getattr(args, 'mean_pool_gating_network_dropout', None)
        if dropout is None:
            dropout = getattr(args, 'dropout', None)

        if not self.uniform_prior and not hasattr(model, 'gating_network'):
            if not self.args.mean_pool_gating_network:
                raise ValueError(
                    'translation_moe task with learned prior requires the model to '
                    'have a gating network; try using --mean-pool-gating-network'
                )

            encoder_dim = getattr(args, 'mean_pool_gating_network_encoder_dim', None)
            if not encoder_dim:
                # assume that encoder_embed_dim is the encoder's output dimension
                encoder_dim = getattr(args, 'encoder_embed_dim', None)
            if not encoder_dim:
                raise ValueError('Must specify --mean-pool-gating-network-encoder-dim')
            if dropout is None:
                raise ValueError('Must specify --mean-pool-gating-network-dropout')

            model.gating_network = modules.MeanPoolGatingNetwork(
                encoder_dim, args.num_experts, dropout,
            )

        def _make_estimator(input_dim, **extra):
            # Single construction point for every estimator variant below.
            if dropout is None:
                raise ValueError('Must specify --mean-pool-gating-network-dropout')
            return Estimator(self.estimator_hidden_dim, input_dim, dropout,
                             topk_time_step=self.topk_time_step, **extra)

        estimator = None
        xml_estimator = None

        if self.share_xml_dict:
            # One estimator whose input size is the sum of both dimensions.
            estimator = _make_estimator(
                args.estimator_xml_dim + args.estimator_transformer_dim)
        elif self.estimator_xml_only:
            estimator = _make_estimator(args.estimator_xml_dim)
        elif args.estimator_transformer_dim != 0:
            # Pass share_estimator only when it is enabled so that the
            # Estimator's own default applies otherwise, as before.
            extra = {'share_estimator': True} if self.share_estimator else {}
            estimator = _make_estimator(args.estimator_transformer_dim, **extra)
            if args.estimator_xml_dim != 0:
                xml_estimator = _make_estimator(args.estimator_xml_dim)
        else:
            # The previous message here was a copy of the gating-network error
            # and did not describe this failure.
            raise ValueError(
                'no estimator configuration selected: enable share_xml_dict or '
                'estimator_xml_only, or set estimator_transformer_dim to a '
                'non-zero value'
            )

        return model, estimator, xml_estimator