Пример #1
0
class Wav2Vec2FTASRModel(BaseFairseqModel):
    def __init__(self, cfg: Wav2Vec2FTASRConfig):
        """Create the feature extractor, optional quantizers, encoder and heads.

        Submodules are created in a fixed order; the order determines both
        parameter registration order and the order in which random
        initialisation is drawn, so it should not be rearranged casually.

        Args:
            cfg: configuration object. Every hyper-parameter used by this
                model (masking, quantization, negatives, dimensions,
                freezing) is read from its attributes below.
        """
        super().__init__()
        self.cfg = cfg

        # SECURITY NOTE(review): ``conv_feature_layers`` is a string that is
        # evaluated as Python code; it must only come from a trusted config.
        feature_enc_layers = eval(cfg.conv_feature_layers)
        # First field of the last layer spec; used below as the feature
        # width (presumably the channel count of the last conv layer).
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        # A projection is only created when the widths differ AND the input
        # is not quantized; otherwise this attribute is None.
        self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim)
                                  if self.embed != cfg.encoder_embed_dim
                                  and not cfg.quantize_input else None)

        # Time-axis masking hyper-parameters (consumed by apply_mask).
        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        # Channel-axis masking hyper-parameters (consumed by apply_mask).
        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        # n_freeze gates the freezing done in set_grad_to_false; n_remove is
        # only stored here.
        self.n_freeze = cfg.n_freeze_layers
        self.n_remove = cfg.n_remove_layers

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        # Both quantizers default to None and are only built when enabled.
        self.quantizer = None
        self.input_quantizer = None

        # Negative-sampling settings (consumed by sample_negatives/forward).
        self.n_negatives = cfg.num_negatives
        self.cross_sample_negatives = cfg.cross_sample_negatives
        self.codebook_negatives = cfg.codebook_negatives
        self.negatives_from_everywhere = cfg.negatives_from_everywhere

        self.logit_temp = cfg.logit_temp

        # A non-positive final_dim falls back to the encoder width.
        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        if cfg.quantize_targets:
            vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim
            self.quantizer = GumbelVectorQuantizer(
                dim=self.embed,
                num_vars=cfg.latent_vars,
                temp=cfg.latent_temp,
                groups=cfg.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
            )
            self.project_q = nn.Linear(vq_dim, final_dim)
        else:
            # Without a target quantizer the raw features are projected.
            self.project_q = nn.Linear(self.embed, final_dim)

        if cfg.quantize_input:
            if cfg.same_quantizer and self.quantizer is not None:
                # Reuse the target quantizer for the input.
                vq_dim = final_dim
                self.input_quantizer = self.quantizer
            else:
                vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim
                self.input_quantizer = GumbelVectorQuantizer(
                    dim=self.embed,
                    num_vars=cfg.latent_vars,
                    temp=cfg.latent_temp,
                    groups=cfg.latent_groups,
                    combine_groups=False,
                    vq_dim=vq_dim,
                    time_first=True,
                )
            self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim)

        # Learned vector written over masked time steps in apply_mask.
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_())

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU())

        self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
        # Apply the freezing policy once every submodule exists.
        self.set_grad_to_false()

    def set_grad_to_false(self):
        if self.n_freeze > 0:
            for param in self.feature_extractor.parameters():
                param.requires_grad = False
            for param in self.post_extract_proj.parameters():
                param.requires_grad = False
            for param in self.layer_norm.parameters():
                param.requires_grad = False
        for param in self.quantizer.parameters():
            param.requires_grad = False

        for param in self.project_q.parameters():
            param.requires_grad = False

        for param in self.final_proj.parameters():
            param.requires_grad = False

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        return state_dict

    @classmethod
    def build_model(cls, cfg: Wav2Vec2FTASRConfig, task=None):
        """Instantiate the model from ``cfg``.

        ``task`` is accepted for interface compatibility and is not used.
        """
        model = cls(cfg)
        return model

    def apply_mask(self, x, padding_mask):
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (torch.from_numpy(mask_channel_indices).to(
                x.device).unsqueeze(1).expand(-1, T, -1))
            x[mask_channel_indices] = 0

        return x, mask_indices

    def sample_negatives(self, y, num):

        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
            return y.new(0)

        bsz, tsz, fsz = y.shape
        y = y.view(-1, fsz)  # BTC => (BxT)C

        cross_high = tsz * bsz
        high = tsz
        with torch.no_grad():
            assert high > 1, f"{bsz,tsz,fsz}"

            if self.n_negatives > 0:
                tszs = (buffered_arange(num).unsqueeze(-1).expand(
                    -1, self.n_negatives).flatten())

                neg_idxs = torch.randint(low=0,
                                         high=high - 1,
                                         size=(bsz, self.n_negatives * num))
                neg_idxs[neg_idxs >= tszs] += 1

            if self.cross_sample_negatives > 0:
                tszs = (buffered_arange(num).unsqueeze(-1).expand(
                    -1, self.cross_sample_negatives).flatten())

                cross_neg_idxs = torch.randint(
                    low=0,
                    high=cross_high - 1,
                    size=(bsz, self.cross_sample_negatives * num),
                )
                cross_neg_idxs[cross_neg_idxs >= tszs] += 1

        if self.n_negatives > 0:
            for i in range(1, bsz):
                neg_idxs[i] += i * high
        else:
            neg_idxs = cross_neg_idxs

        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
            neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

        negs = y[neg_idxs.view(-1)]
        negs = negs.view(bsz, num,
                         self.n_negatives + self.cross_sample_negatives,
                         fsz).permute(2, 0, 1, 3)  # to NxBxTxC
        return negs, neg_idxs

    def compute_preds(self, x, y, negatives):

        neg_is_pos = (y == negatives).all(-1)
        y = y.unsqueeze(0)
        targets = torch.cat([y, negatives], dim=0)

        logits = torch.cosine_similarity(x.float(), targets.float(),
                                         dim=-1).type_as(x)

        logits /= self.logit_temp

        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")

        return logits

    def forward(self,
                source,
                padding_mask=None,
                mask=True,
                features_only=False):
        if len(source.shape) <= 2:
            if self.feature_grad_mult > 0:
                features = self.feature_extractor(source)
                if self.feature_grad_mult != 1.0:
                    features = GradMultiply.apply(features,
                                                  self.feature_grad_mult)
            else:
                with torch.no_grad():
                    features = self.feature_extractor(source)
            features = features.transpose(1, 2)
        else:
            features = source

        features_pen = features.float().pow(2).mean()
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            extra = padding_mask.size(1) % features.size(1)
            if extra > 0:
                padding_mask = padding_mask[:, :-extra]
            padding_mask = padding_mask.view(padding_mask.size(0),
                                             features.size(1), -1)
            padding_mask = padding_mask.all(-1)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        num_vars = None
        code_ppl = None
        prob_ppl = None
        curr_temp = None

        if self.input_quantizer:
            q = self.input_quantizer(features, produce_targets=False)
            features = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]
            features = self.project_inp(features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask)
            if mask_indices is not None:
                y = unmasked_features[mask_indices].view(
                    unmasked_features.size(0), -1, unmasked_features.size(-1))
            else:
                y = unmasked_features
        else:
            x = features
            y = unmasked_features
            mask_indices = None

        x = self.encoder(x, padding_mask=padding_mask)

        if features_only:
            return {"x": x, "padding_mask": padding_mask}

        if self.quantizer:
            q = self.quantizer(y, produce_targets=False)
            y = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]

            y = self.project_q(y)

            if self.negatives_from_everywhere:
                neg_cands, *_ = self.quantizer(unmasked_features,
                                               produce_targets=False)
                negs, _ = self.sample_negatives(neg_cands, y.size(1))
                negs = self.project_q(negs)

            else:
                negs, _ = self.sample_negatives(y, y.size(1))

            if self.codebook_negatives > 0:
                cb_negs = self.quantizer.sample_from_codebook(
                    y.size(0) * y.size(1), self.codebook_negatives)
                cb_negs = cb_negs.view(self.codebook_negatives, y.size(0),
                                       y.size(1), -1)  # order doesnt matter
                cb_negs = self.project_q(cb_negs)
                negs = torch.cat([negs, cb_negs], dim=0)
        else:
            y = self.project_q(y)

            if self.negatives_from_everywhere:
                negs, _ = self.sample_negatives(unmasked_features, y.size(1))
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(y, y.size(1))

        x = x[mask_indices].view(x.size(0), -1, x.size(-1))

        if self.target_glu:
            y = self.target_glu(y)
            negs = self.target_glu(negs)

        x = self.final_proj(x)
        x = self.compute_preds(x, y, negs)

        result = {
            "x": x,
            "padding_mask": padding_mask,
            "features_pen": features_pen
        }

        if prob_ppl is not None:
            result["prob_perplexity"] = prob_ppl
            result["code_perplexity"] = code_ppl
            result["num_vars"] = num_vars
            result["temp"] = curr_temp

        return result

    def quantize(self, x):
        assert self.quantizer is not None
        x = self.feature_extractor(x)
        x = x.transpose(1, 2)
        x = self.layer_norm(x)
        return self.quantizer.forward_idx(x)

    def extract_features(self, source, padding_mask, mask=False):
        res = self.forward(source, padding_mask, mask=mask, features_only=True)
        return res["x"], res["padding_mask"]

    def get_logits(self, net_output):
        logits = net_output["x"]
        logits = logits.transpose(0, 2)
        logits = logits.reshape(-1, logits.size(-1))
        return logits

    def get_targets(self, sample, net_output, expand_steps=True):
        x = net_output["x"]
        return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long)

    def get_extra_losses(self, net_output):
        pen = []

        if "prob_perplexity" in net_output:
            pen.append(
                (net_output["num_vars"] - net_output["prob_perplexity"]) /
                net_output["num_vars"])

        if "features_pen" in net_output:
            pen.append(net_output["features_pen"])

        return pen

    def remove_pretraining_modules(self):
        self.quantizer = None
        self.project_q = None
        self.target_glu = None
        self.final_proj = None
Пример #2
0
class ABCModel(BaseFairseqModel):
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser.

        Options are registered in a fixed order, which is also the order in
        which they appear in ``--help`` output. No defaults are set here.

        Args:
            parser: object exposing ``add_argument`` (e.g. an
                ``argparse.ArgumentParser``).
        """

        # --- feature extractor / transformer encoder ---
        parser.add_argument(
            "--extractor-mode",
            choices=["default", "layer_norm"],
            help=
            "mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with --normalize)",
        )

        parser.add_argument(
            "--encoder-layers",
            type=int,
            metavar="L",
            help="num encoder layers in the transformer",
        )
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="H",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="F",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-attention-heads",
            type=int,
            metavar="A",
            help="num encoder attention heads",
        )
        parser.add_argument(
            "--activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )

        # --- dropout ---
        parser.add_argument(
            "--dropout",
            type=float,
            metavar="D",
            help="dropout probability for the transformer",
        )

        parser.add_argument(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for attention weights",
        )

        parser.add_argument(
            "--activation-dropout",
            type=float,
            metavar="D",
            help="dropout probability after activation in FFN",
        )

        parser.add_argument(
            "--final-dim",
            type=int,
            metavar="D",
            help=
            "project final representations and targets to this many dimensions",
        )

        parser.add_argument(
            "--layer-norm-first",
            action="store_true",
            help="apply layernorm first in the transformer",
        )

        parser.add_argument(
            "--encoder-layerdrop",
            type=float,
            help="probability of dropping a tarnsformer layer",
        )

        # NOTE: this string is passed to ``eval`` in ``__init__``; only
        # trusted values should be supplied.
        parser.add_argument(
            "--conv-feature-layers",
            type=str,
            metavar="EXPR",
            help=
            "convolutional feature extraction layers [(dim, kernel_size, stride), ...]",
        )

        parser.add_argument("--logit-temp",
                            type=float,
                            help="temperature to divide logits by")

        # --- quantization ---
        parser.add_argument("--quantize-targets",
                            action="store_true",
                            help="use quantized targets")

        parser.add_argument("--quantize-input",
                            action="store_true",
                            help="use quantized inputs")

        parser.add_argument(
            "--same-quantizer",
            action="store_true",
            help="use same quantizer for inputs and targets",
        )

        parser.add_argument(
            "--feature-grad-mult",
            type=float,
            help="multiply feature extractor var grads by this",
        )

        parser.add_argument(
            "--latent-vars",
            type=int,
            metavar="N",
            help="number of latent variables V in each group of the codebook",
        )

        parser.add_argument(
            "--latent-groups",
            type=int,
            metavar="N",
            help="number of groups G of latent variables in the codebook",
        )

        parser.add_argument(
            "--latent-dim",
            type=int,
            metavar="N",
            help=
            "if set, uses this dimensionality for latent variables. otherwise uses final_dim / latent_groups",
        )

        # --- time-axis masking ---
        parser.add_argument("--mask-length", type=int, help="mask length")

        parser.add_argument("--mask-prob",
                            type=float,
                            help="probability of replacing a token with mask")

        parser.add_argument(
            "--mask-selection",
            type=str,
            choices=["static", "uniform", "normal", "poisson"],
            help="how to choose masks",
        )

        parser.add_argument(
            "--mask-other",
            type=float,
            help=
            "secondary mask argument (used for more complex distributions), see help in compute_mask_indices",
        )

        # NOTE(review): the value of this flag is passed as ``no_overlap=``
        # in ``apply_mask``, so the help text reads inverted relative to the
        # flag name — confirm the intended wording.
        parser.add_argument(
            "--no-mask-overlap",
            action="store_true",
            help="whether to allow masks to overlap",
        )

        parser.add_argument(
            "--mask-min-space",
            type=int,
            help="min space between spans (if no overlap is enabled)",
        )

        # --- channel-axis masking ---
        parser.add_argument(
            "--mask-channel-length",
            type=int,
            help="repeat the mask indices multiple times",
        )

        parser.add_argument(
            "--mask-channel-prob",
            type=float,
            help="probability of replacing a token with mask",
        )

        parser.add_argument(
            "--mask-channel-selection",
            type=str,
            choices=["static", "uniform", "normal", "poisson"],
            help="how to choose masks",
        )

        parser.add_argument(
            "--mask-channel-other",
            type=float,
            help=
            "secondary mask argument (used for more complex distributions), see help in compute_mask_indices",
        )

        parser.add_argument(
            "--no-mask-channel-overlap",
            action="store_true",
            help="whether to allow masks to overlap",
        )

        parser.add_argument(
            "--mask-channel-min-space",
            type=int,
            help="min space between spans (if no overlap is enabled)",
        )

        parser.add_argument(
            "--dropout-input",
            type=float,
            metavar="D",
            help="dropout to apply to the input (after feat extr)",
        )

        parser.add_argument(
            "--dropout-features",
            type=float,
            metavar="D",
            help="dropout to apply to the features (after feat extr)",
        )

        # --- negative sampling ---
        parser.add_argument("--num-negatives",
                            type=int,
                            metavar="N",
                            help="number of negative examples")

        parser.add_argument(
            "--negatives-from-everywhere",
            action="store_true",
            help="sample negatives from everywhere, not just masked states",
        )

        parser.add_argument(
            "--cross-sample-negatives",
            type=int,
            metavar="N",
            help="num of cross sampled negatives",
        )

        parser.add_argument(
            "--codebook-negatives",
            type=int,
            metavar="N",
            help="num of codebook sampled negatives",
        )

        # Disabled option, kept for reference:
        #parser.add_argument(
        #    "--byol-all",
        #    action="store_true",
        #    help="apply byol to whole network"
        #)

        # --- online/target network (BYOL-style) options ---
        parser.add_argument("--base-decay",
                            type=float,
                            metavar="D",
                            help="base decay value of target network")

        parser.add_argument("--projection-dim",
                            type=int,
                            metavar="N",
                            help="byol projection network's output dimension")

        parser.add_argument("--prediction-dim",
                            type=int,
                            metavar="N",
                            help="byol prediction network's output dimension")

        parser.add_argument("--byol-hidden-dim",
                            type=int,
                            metavar="N",
                            help="hidden dimension of MLPs used for byol")

        parser.add_argument("--shared-quantizer",
                            action="store_true",
                            help="share quantizer with target network")

        parser.add_argument("--shared-emb",
                            action="store_true",
                            help="share mask embedding parameter")

        parser.add_argument("--mlp-prediction",
                            action="store_true",
                            help="use MLP as prediction network")

        parser.add_argument("--mlp-encoder",
                            action="store_true",
                            help="use MLP as encoder network")

        parser.add_argument(
            "--separate-mask-indices",
            action="store_true",
            help="use two separate mask indices for Transformers")

        # --- positional embedding ---
        parser.add_argument(
            "--conv-pos",
            type=int,
            metavar="N",
            help="number of filters for convolutional positional embeddings",
        )

        parser.add_argument(
            "--conv-pos-groups",
            type=int,
            metavar="N",
            help="number of groups for convolutional positional embedding",
        )

        # NOTE: this string is passed to ``eval`` in ``__init__`` when a
        # quantizer is built; only trusted values should be supplied.
        parser.add_argument(
            "--latent-temp",
            type=str,
            metavar="D",
            help=
            "temperature for latent variable sampling. can be tuple of 3 values (start, end, decay)",
        )

        parser.add_argument("--target-glu",
                            action="store_true",
                            help="adds projection + glu to targets")

        parser.add_argument("--conv-bias",
                            action="store_true",
                            help="include bias in conv encoder")

    def __init__(self, args):
        """Build the online network and its gradient-free target twin.

        Every learnable component that has a target counterpart is created
        twice. The parameters of the two copies are collected, in matching
        order, in ``self.online_params`` and ``self.target_params`` so that
        ``update_target`` can pair them with ``zip``. All target parameters
        have ``requires_grad`` set to ``False`` at the end.

        Submodules are created in a fixed order; the order determines
        parameter registration and random-initialisation order and must
        stay aligned between the two parameter lists.

        Args:
            args: namespace holding the options registered in ``add_args``
                plus ``total_num_update``.
        """
        super().__init__()
        self.args = args

        # SECURITY NOTE(review): ``conv_feature_layers`` (and ``latent_temp``
        # below) are strings evaluated as Python code; they must only come
        # from trusted configuration.
        feature_enc_layers = eval(args.conv_feature_layers)
        # First field of the last layer spec, used below as feature width.
        self.embed = feature_enc_layers[-1][0]

        # Parallel lists: online_params[i] is averaged into target_params[i].
        self.online_params = []
        self.target_params = []

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=args.extractor_mode,
            conv_bias=args.conv_bias,
        )

        self.feature_extractor_target = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=args.extractor_mode,
            conv_bias=args.conv_bias,
        )

        self.online_params += list(self.feature_extractor.parameters())
        self.target_params += list(self.feature_extractor_target.parameters())

        # A projection is only needed when the widths differ and the input
        # is not quantized; otherwise both attributes are None.
        if self.embed != args.encoder_embed_dim and not args.quantize_input:
            self.post_extract_proj = (nn.Linear(self.embed,
                                                args.encoder_embed_dim))
            self.post_extract_proj_target = (nn.Linear(self.embed,
                                                       args.encoder_embed_dim))
            self.online_params += list(self.post_extract_proj.parameters())
            self.target_params += list(
                self.post_extract_proj_target.parameters())
        else:
            self.post_extract_proj = None
            self.post_extract_proj_target = None

        # Time-axis masking hyper-parameters.
        self.mask_prob = args.mask_prob
        self.mask_selection = args.mask_selection
        self.mask_other = args.mask_other
        self.mask_length = args.mask_length
        self.no_mask_overlap = args.no_mask_overlap
        self.mask_min_space = args.mask_min_space

        # Channel-axis masking hyper-parameters.
        self.mask_channel_prob = args.mask_channel_prob
        self.mask_channel_selection = args.mask_channel_selection
        self.mask_channel_other = args.mask_channel_other
        self.mask_channel_length = args.mask_channel_length
        self.no_mask_channel_overlap = args.no_mask_channel_overlap
        self.mask_channel_min_space = args.mask_channel_min_space

        self.dropout_input = nn.Dropout(args.dropout_input)
        self.dropout_features = nn.Dropout(args.dropout_features)

        # State of the decay schedule used by update_target.
        self.base_decay = args.base_decay
        self.step = 0
        self.total_steps = args.total_num_update

        self.feature_grad_mult = args.feature_grad_mult

        self.quantizer = None
        self.input_quantizer = None
        self.shared_quantizer = args.shared_quantizer

        # A non-positive final_dim falls back to the encoder width.
        final_dim = args.final_dim if args.final_dim > 0 else args.encoder_embed_dim

        if args.quantize_targets:
            vq_dim = args.latent_dim if args.latent_dim > 0 else final_dim
            self.quantizer = GumbelVectorQuantizer(
                dim=self.embed,
                num_vars=args.latent_vars,
                temp=eval(args.latent_temp),
                groups=args.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
            )

            # NOTE(review): with shared_quantizer set, no ``quantizer_target``
            # attribute is created at all.
            if not args.shared_quantizer:
                self.quantizer_target = GumbelVectorQuantizer(
                    dim=self.embed,
                    num_vars=args.latent_vars,
                    temp=eval(args.latent_temp),
                    groups=args.latent_groups,
                    combine_groups=False,
                    vq_dim=vq_dim,
                    time_first=True,
                )
                self.online_params += list(self.quantizer.parameters())
                self.target_params += list(self.quantizer_target.parameters())

            ### TODO: separate project_q?
            self.project_q = nn.Linear(vq_dim, final_dim)
        else:
            self.project_q = nn.Linear(self.embed, final_dim)

        if args.quantize_input:
            if args.same_quantizer and self.quantizer is not None:
                # Reuse the target quantizer(s) for the input.
                vq_dim = final_dim
                self.input_quantizer = self.quantizer
                if not args.shared_quantizer:
                    self.input_quantizer_target = self.quantizer_target
            else:
                vq_dim = (args.latent_dim
                          if args.latent_dim > 0 else args.encoder_embed_dim)
                self.input_quantizer = GumbelVectorQuantizer(
                    dim=self.embed,
                    num_vars=args.latent_vars,
                    temp=eval(args.latent_temp),
                    groups=args.latent_groups,
                    combine_groups=False,
                    vq_dim=vq_dim,
                    time_first=True,
                )

                if not args.shared_quantizer:
                    self.input_quantizer_target = GumbelVectorQuantizer(
                        dim=self.embed,
                        num_vars=args.latent_vars,
                        temp=eval(args.latent_temp),
                        groups=args.latent_groups,
                        combine_groups=False,
                        vq_dim=vq_dim,
                        time_first=True,
                    )
                    self.online_params += list(
                        self.input_quantizer.parameters())
                    self.target_params += list(
                        self.input_quantizer_target.parameters())

            self.project_inp = nn.Linear(vq_dim, args.encoder_embed_dim)

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(args.encoder_embed_dim).uniform_())

        # NOTE(review): the target mask embedding is not added to the
        # online/target lists (registration is commented out), so it is
        # neither frozen below nor touched by update_target.
        if not args.shared_emb:
            self.mask_emb_target = nn.Parameter(
                torch.FloatTensor(args.encoder_embed_dim).uniform_())
            #self.online_params += list(self.mask_emb)
            #self.target_params += list(self.mask_emb_target)

        if args.mlp_encoder:
            # Fixed: the last layer used the misspelled ``nn.Linaer``, which
            # raised AttributeError whenever --mlp-encoder was enabled.
            self.encoder = nn.Sequential(
                nn.Linear(args.encoder_embed_dim, args.byol_hidden_dim),
                nn.BatchNorm1d(args.byol_hidden_dim), nn.ReLU(inplace=True),
                nn.Linear(args.byol_hidden_dim, final_dim))
            self.encoder_target = nn.Sequential(
                nn.Linear(args.encoder_embed_dim, args.byol_hidden_dim),
                nn.BatchNorm1d(args.byol_hidden_dim), nn.ReLU(inplace=True),
                nn.Linear(args.byol_hidden_dim, final_dim))
        else:
            self.encoder = TransformerEncoder(args)
            self.encoder_target = TransformerEncoder(args)

        self.online_params += list(self.encoder.parameters())
        self.target_params += list(self.encoder_target.parameters())

        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if args.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU())
            self.target_glu_target = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU())
            self.online_params += list(self.target_glu.parameters())
            self.target_params += list(self.target_glu_target.parameters())

        self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim)
        self.final_proj_target = nn.Linear(args.encoder_embed_dim, final_dim)

        self.online_params += list(self.final_proj.parameters())
        self.target_params += list(self.final_proj_target.parameters())

        ### TODO: Transformer prediction network?
        if args.mlp_prediction:
            self.prediction = MLPPrediction(final_dim, args.byol_hidden_dim)
        else:
            self.prediction = TransformerEncoder(args, args.final_dim)

        # The target network is only changed by update_target, never by
        # back-propagation.
        for param in self.target_params:
            param.requires_grad = False

    #def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
    #    gen = self._named_members(
    #        lambda module: module._parameters.items(),
    #        prefix=prefix, recurse=recurse)
    #    for elem in gen:
    #        if elem[1] not in self.target_params:
    #            yield elem

    def update_target(self):
        """Blend the online parameters into the target parameters.

        The blend weight ``decay`` follows a cosine curve in
        ``self.step / self.total_steps``: it equals ``self.base_decay`` when
        ``self.step`` is 0 and reaches 1 when ``self.step`` equals
        ``self.total_steps``. Each target tensor becomes
        ``decay * target + (1 - decay) * online``.

        Side effects: ``self.step`` is advanced by one, and the ``.data`` of
        every entry in ``self.target_params`` is replaced. Parameters are
        paired positionally with ``zip``, so any surplus entries in the longer
        list are ignored.
        """
        cosine = np.cos(np.pi * self.step / self.total_steps)
        decay = 1 - (1 - self.base_decay) * (cosine + 1) / 2.0
        self.step += 1

        for src, dst in zip(self.online_params, self.target_params):
            previous = dst.data
            dst.data = decay * previous + (1 - decay) * src.data

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq.

        Delegates to the parent class implementation and then returns the
        ``state_dict`` object that was passed in.

        Args:
            state_dict: the state dict to upgrade.
            name: forwarded unchanged to the parent implementation.

        Returns:
            The same ``state_dict`` object.
        """
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, args, task=None):
        """Create a model instance from ``args``.

        ``base_architecture`` is applied to ``args`` first so that all
        arguments are present before the constructor reads them. ``task`` is
        accepted but not used here.
        """
        base_architecture(args)
        model = cls(args)
        return model

    def apply_mask(self, x, padding_mask, mask_indices=None):
        """Mask ``x`` in place along its second and third axes.

        ``x`` must be 3-dimensional; its shape is unpacked as
        ``(batch, frames, channels)``.

        Second-axis masking: when ``mask_indices`` is supplied it is used as
        is; otherwise, if ``self.mask_prob > 0``, a fresh mask is sampled with
        ``compute_mask_indices``. Selected positions are overwritten with
        ``self.mask_emb``. Third-axis masking: when
        ``self.mask_channel_prob > 0`` a ``(batch, channels)`` mask is sampled,
        expanded over the second axis, and the selected entries are zeroed.

        Returns:
            ``(x, mask_indices)`` where ``x`` is the same (mutated) tensor and
            ``mask_indices`` is the mask that was used, or ``None`` when no
            second-axis masking took place.
        """
        batch, frames, channels = x.shape

        # Sample a mask only when the caller did not provide one.
        if mask_indices is None and self.mask_prob > 0:
            sampled = compute_mask_indices(
                (batch, frames),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(sampled).to(x.device)

        if mask_indices is not None:
            x[mask_indices] = self.mask_emb

        if self.mask_channel_prob > 0:
            channel_mask = compute_mask_indices(
                (batch, channels),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            channel_mask = torch.from_numpy(channel_mask).to(x.device)
            # Broadcast the per-sample channel mask over every frame.
            channel_mask = channel_mask.unsqueeze(1).expand(-1, frames, -1)
            x[channel_mask] = 0

        return x, mask_indices

    def quantize(self, x):
        """Return ``self.quantizer.forward_idx`` of the normalised features.

        ``x`` is passed through ``self.feature_extractor``, axes 1 and 2 are
        swapped, and ``self.layer_norm`` is applied before the quantizer is
        queried. ``self.quantizer`` must be set (checked with ``assert``).
        """
        assert self.quantizer is not None
        normed = self.layer_norm(self.feature_extractor(x).transpose(1, 2))
        return self.quantizer.forward_idx(normed)

    def extract_features(self, source, padding_mask, mask=False):
        """Return the online encoder output and padding mask for ``source``.

        Calls ``self.predict`` with ``features_only=True`` and returns the
        ``"x"`` and ``"padding_mask"`` entries of its result as a tuple.
        """
        outputs = self.predict(source,
                               padding_mask,
                               mask=mask,
                               features_only=True)
        encoded = outputs["x"]
        out_padding_mask = outputs["padding_mask"]
        return encoded, out_padding_mask

    def predict(self,
                source,
                padding_mask=None,
                mask=True,
                features_only=False,
                mask_indices=None):
        """Run the online branch on ``source``.

        Pipeline: feature extractor -> swap axes 1 and 2 -> layer norm ->
        optional projection / input quantizer -> optional masking -> encoder
        -> gather masked positions -> ``final_proj`` -> ``prediction``.

        Args:
            source: input passed to ``self.feature_extractor``.
            padding_mask: optional 2-D mask at input resolution. It is reduced
                to feature resolution; a feature frame counts as padded only
                when all of the input positions folded into it are padded.
            mask: whether to call ``self.apply_mask`` on the features.
            features_only: when true, return right after the encoder.
            mask_indices: optional precomputed mask forwarded to
                ``self.apply_mask``; ignored when ``mask`` is false.

        Returns:
            If ``features_only`` is true, a dict with keys ``"x"`` (encoder
            output) and ``"padding_mask"``.

            Otherwise a tuple ``(result, mask_indices)``. ``result`` has keys
            ``"x"``, ``"padding_mask"`` and ``"features_pen"``, plus
            ``"prob_perplexity"``, ``"code_perplexity"``, ``"num_vars"`` and
            ``"temp"`` when a quantizer produced them. ``mask_indices`` is the
            mask that was applied, or ``None`` when nothing was masked.
        """
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)

        # Mean squared activation of the raw extractor output.
        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        # Keep a copy: apply_mask() below overwrites ``features`` in place.
        unmasked_features = features.clone()

        if padding_mask is not None:
            # Drop trailing positions that do not fill a whole frame, then
            # group the remainder per frame.
            extra = padding_mask.size(1) % features.size(1)
            if extra > 0:
                padding_mask = padding_mask[:, :-extra]
            padding_mask = padding_mask.view(padding_mask.size(0),
                                             features.size(1), -1)
            padding_mask = padding_mask.all(-1)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        num_vars = None
        code_ppl = None
        prob_ppl = None
        curr_temp = None

        if self.input_quantizer:
            q = self.input_quantizer(features, produce_targets=False)
            features = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]
            features = self.project_inp(features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask,
                                              mask_indices)
            if mask_indices is not None:
                y = unmasked_features[mask_indices].view(
                    unmasked_features.size(0), -1, unmasked_features.size(-1))
            else:
                y = unmasked_features
        else:
            x = features
            y = unmasked_features
            mask_indices = None

        x = self.encoder(x, padding_mask=padding_mask)
        if features_only:
            return {"x": x, "padding_mask": padding_mask}

        # ``y`` is not part of the returned result; only the quantizer
        # statistics gathered here are.
        if self.quantizer:
            q = self.quantizer(y, produce_targets=False)
            y = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]
        y = self.project_q(y)

        # Keep only the masked positions. Indexing with ``None`` would insert
        # a new leading axis and the view would then fold the batch into the
        # time axis, so the unmasked case leaves ``x`` untouched.
        if mask_indices is not None:
            x = x[mask_indices].view(x.size(0), -1, x.size(-1))

        if self.target_glu:
            y = self.target_glu(y)

        x = self.final_proj(x)

        if self.args.mlp_prediction:
            x = self.prediction(x)
        else:
            # NOTE(review): ``padding_mask`` is at full frame resolution while
            # ``x`` may have been reduced to masked positions above - confirm
            # the predictor handles that combination.
            x = self.prediction(x, padding_mask=padding_mask)

        result = {
            "x": x,
            "padding_mask": padding_mask,
            "features_pen": features_pen
        }

        if prob_ppl is not None:
            result["prob_perplexity"] = prob_ppl
            result["code_perplexity"] = code_ppl
            result["num_vars"] = num_vars
            result["temp"] = curr_temp

        return result, mask_indices

    def target_predict(self,
                       source,
                       padding_mask=None,
                       mask=True,
                       mask_indices=None):
        """Run the target branch on ``source`` with gradients disabled.

        Mirrors the online branch but uses the ``*_target`` modules
        (``feature_extractor_target``, ``post_extract_proj_target``,
        ``encoder_target``, ``final_proj_target``) and has no prediction
        head. When ``self.shared_quantizer`` is true the online quantizers and
        their projections are reused; otherwise the ``*_target`` quantizers
        and projections are used.

        Args:
            source: input passed to ``self.feature_extractor_target``.
            padding_mask: optional 2-D mask at input resolution. It is reduced
                to feature resolution; a feature frame counts as padded only
                when all of the input positions folded into it are padded.
            mask: whether to call ``self.apply_mask`` on the features.
            mask_indices: optional precomputed mask forwarded to
                ``self.apply_mask``; ignored when ``mask`` is false.

        Returns:
            A dict with keys ``"x"``, ``"padding_mask"`` and
            ``"features_pen"``, plus ``"prob_perplexity"``,
            ``"code_perplexity"``, ``"num_vars"`` and ``"temp"`` when a
            quantizer produced them.
        """
        with torch.no_grad():
            # Everything here runs without gradients, so the gradient
            # scaling used by the online branch would be a no-op.
            features = self.feature_extractor_target(source)

            # Mean squared activation of the raw extractor output.
            features_pen = features.float().pow(2).mean()

            features = features.transpose(1, 2)
            features = self.layer_norm(features)
            # Keep a copy: apply_mask() below overwrites ``features`` in
            # place.
            unmasked_features = features.clone()

            if padding_mask is not None:
                # Drop trailing positions that do not fill a whole frame,
                # then group the remainder per frame.
                extra = padding_mask.size(1) % features.size(1)
                if extra > 0:
                    padding_mask = padding_mask[:, :-extra]
                padding_mask = padding_mask.view(padding_mask.size(0),
                                                 features.size(1), -1)
                padding_mask = padding_mask.all(-1)

            if self.post_extract_proj_target is not None:
                features = self.post_extract_proj_target(features)

            features = self.dropout_input(features)
            unmasked_features = self.dropout_features(unmasked_features)

            num_vars = None
            code_ppl = None
            prob_ppl = None
            curr_temp = None

            if self.input_quantizer:
                if self.shared_quantizer:
                    input_quantizer = self.input_quantizer
                    project_inp = self.project_inp
                else:
                    # Use the target copy, matching how ``quantizer_target``
                    # is selected further down.
                    input_quantizer = self.input_quantizer_target
                    project_inp = self.project_inp_target
                q = input_quantizer(features, produce_targets=False)
                features = q["x"]
                num_vars = q["num_vars"]
                code_ppl = q["code_perplexity"]
                prob_ppl = q["prob_perplexity"]
                curr_temp = q["temp"]
                features = project_inp(features)

            if mask:
                # NOTE(review): apply_mask() fills masked positions with
                # ``self.mask_emb``; if a separate ``mask_emb_target`` exists
                # it is not used here - confirm this is intended.
                x, mask_indices = self.apply_mask(features, padding_mask,
                                                  mask_indices)
                if mask_indices is not None:
                    y = unmasked_features[mask_indices].view(
                        unmasked_features.size(0), -1,
                        unmasked_features.size(-1))
                else:
                    y = unmasked_features
            else:
                x = features
                y = unmasked_features
                mask_indices = None

            x = self.encoder_target(x, padding_mask=padding_mask)

            # ``y`` is not part of the returned result; only the quantizer
            # statistics gathered here are.
            if self.quantizer:
                if self.shared_quantizer:
                    q = self.quantizer(y, produce_targets=False)
                else:
                    q = self.quantizer_target(y, produce_targets=False)
                y = q["x"]
                num_vars = q["num_vars"]
                code_ppl = q["code_perplexity"]
                prob_ppl = q["prob_perplexity"]
                curr_temp = q["temp"]

                if self.shared_quantizer:
                    y = self.project_q(y)
                else:
                    y = self.project_q_target(y)
            else:
                y = self.project_q(y)

            # Keep only the masked positions. Indexing with ``None`` would
            # insert a new leading axis and the view would then fold the
            # batch into the time axis, so the unmasked case leaves ``x``
            # untouched.
            if mask_indices is not None:
                x = x[mask_indices].view(x.size(0), -1, x.size(-1))

            if self.target_glu:
                y = self.target_glu_target(y)

            x = self.final_proj_target(x)

            result = {
                "x": x,
                "padding_mask": padding_mask,
                "features_pen": features_pen
            }

            if prob_ppl is not None:
                result["prob_perplexity"] = prob_ppl
                result["code_perplexity"] = code_ppl
                result["num_vars"] = num_vars
                result["temp"] = curr_temp

            return result

    #def forward(self, source, padding_mask=None, mask=True, features_only=True):
    #    result = self.prediction(source[0], padding_mask, mask, features_only)
    #
    #    return result["x"], result["padding_mask"]

    def forward(self,
                source,
                padding_mask=None,
                mask=True,
                features_only=False):
        """Run the online and target branches on two views of the input.

        Args:
            source: indexable holding the two views as ``source[0]`` and
                ``source[1]``.
            padding_mask: optional indexable with the matching padding masks
                as ``padding_mask[0]`` and ``padding_mask[1]``.
            mask: forwarded to ``predict`` and ``target_predict``.
            features_only: forwarded to ``predict``.

        Returns:
            ``(result_0, result_1, result_target_0, result_target_1)``: the
            online results for views 0 and 1 followed by the target results
            for views 0 and 1. The target for view 0 is computed with the
            mask used for online view 1, and the target for view 1 with the
            mask used for online view 0.

        Side effects: on every call but the first, ``update_target`` is run
        before anything else; ``self.step`` advances by exactly one per call.
        """
        # update_target() also bumps ``self.step``; assigning from the
        # snapshot keeps the counter at one tick per forward pass instead of
        # two, so the decay schedule is not run at double speed.
        step = self.step
        if step != 0:
            self.update_target()
        self.step = step + 1

        padding_mask_0 = padding_mask[0] if padding_mask is not None else None
        padding_mask_1 = padding_mask[1] if padding_mask is not None else None

        # NOTE(review): with features_only=True, predict() returns a bare
        # dict rather than a (result, mask_indices) tuple, so the unpacking
        # below does not produce a result and a mask in that mode - confirm
        # whether forward() is meant to support features_only.
        result_0, mask_indices0 = self.predict(source[0], padding_mask_0, mask,
                                               features_only)

        if self.args.separate_mask_indices:
            # The second view samples its own mask.
            result_1, mask_indices1 = self.predict(
                source[1],
                padding_mask_1,
                mask,
                features_only,
            )
        else:
            # The second view reuses the mask of the first view.
            result_1, _ = self.predict(
                source[1],
                padding_mask_1,
                mask,
                features_only,
                mask_indices=mask_indices0)
            mask_indices1 = mask_indices0

        result_target_0 = self.target_predict(
            source[0],
            padding_mask_0,
            mask,
            mask_indices=mask_indices1)
        result_target_1 = self.target_predict(
            source[1],
            padding_mask_1,
            mask,
            mask_indices=mask_indices0)
        return result_0, result_1, result_target_0, result_target_1

    def remove_pretraining_modules(self):
        """Set the attributes listed below to ``None``.

        Covers the online quantizer, its projection, the GLU and final
        projection, the prediction head, and the ``*_target`` attributes.
        ``mask_emb_target`` is cleared only when ``self.args.shared_emb`` is
        false.
        """
        online_only = ("quantizer", "project_q", "target_glu", "final_proj")
        target_side = (
            "feature_extractor_target",
            "post_extract_proj_target",
            "quantizer_target",
            "input_quantizer_target",
            "encoder_target",
            "target_glu_target",
            "final_proj_target",
            "prediction",
        )
        for attr in online_only + target_side:
            setattr(self, attr, None)

        if not self.args.shared_emb:
            self.mask_emb_target = None