class TRANSFORMER(TransformerBaseWrapper):
    """
    Use this class to extract features from the Transformer model,
    or to fine-tune the pre-trained Transformer on any downstream task.
    Also, this class is `pytorch-kaldi` ready, hence we need to use `str`
    instead of `bool` in the options dict, as pytorch-kaldi scripts will pass in str.

    Params:
        `options`: a python dictionary containing the following keys:
            ckpt_file: str, a path specifying the pre-trained ckpt file
            load_pretrain: str, ['True', 'False'], whether to load pre-trained weights
            no_grad: str, ['True', 'False'], if 'True' no gradient will flow through this class (features are extracted under torch.no_grad())
            dropout: float/str, use a float to modify the dropout value during downstream fine-tuning, or use the str `default` for the pre-train default values
            spec_aug: str, ['True', 'False'], whether to apply SpecAugment on inputs (used for ASR training)
            spec_aug_prev: str, ['True', 'False'], if 'True' apply SpecAugment on the input acoustic features, else apply it on the output representations (used for ASR training)
            weighted_sum: str, ['True', 'False'], whether to use a learnable weighted sum to integrate hidden representations from all layers; if 'False', use the last layer
            select_layer: int, select from all hidden representations, set to -1 to select the last (only used when weighted_sum is 'False')
            permute_input: str, ['True', 'False'], this attribute is for the forward method: if 'True' the input/output is of shape (T, B, D), if 'False' it is (B, T, D)
        `inp_dim`: int, input dimension of the model
        `config`: optional, path to a yaml config; if given it is read instead of the config stored in `ckpt_file`

    An example `options` dictionary:
    options = {
        'ckpt_file'     : './result/result_transformer/libri_sd1337_fmllrBase960-F-N-K-RA/states-1000000.ckpt',
        'load_pretrain' : 'True',
        'no_grad'       : 'True',
        'dropout'       : 'default',
        'spec_aug'      : 'False',
        'spec_aug_prev' : 'True',
        'weighted_sum'  : 'False',
        'select_layer'  : -1,
        'permute_input' : 'False',
    }
    """
    def __init__(self, options, inp_dim, config=None, online_config=None):
        super(TRANSFORMER, self).__init__(options, inp_dim, config, online_config)

        # Build model
        self.model = TransformerModel(self.model_config, self.inp_dim).to(self.device)
        self.model.eval() if self.no_grad else self.model.train()
        self.out_dim = self.hidden_size  # This attribute is necessary for pytorch-kaldi and run_downstream.py

        # Load from a PyTorch state_dict
        if self.load:
            self.model = self.load_model(self.model, self.all_states['Transformer'])
            print('[Transformer] - Number of parameters: ' + str(sum(p.numel() for p in self.model.parameters() if p.requires_grad)))

    def forward(self, x):
        if hasattr(self, 'preprocessor'):
            x = self.preprocessor(x.transpose(1, 2).contiguous())[0]
        if self.no_grad:
            with torch.no_grad():
                x = self._forward(x)
        else:
            x = self._forward(x)
        return x
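
# Minimal usage sketch (not part of the original class): it shows how the example `options`
# dict from the docstring above and `forward` fit together. The checkpoint path is the one
# given in the docstring; the tensor name `example_features` and the 40-dim input are
# hypothetical stand-ins, shown only to illustrate the expected shapes.
#
#   options = {
#       'ckpt_file'     : './result/result_transformer/libri_sd1337_fmllrBase960-F-N-K-RA/states-1000000.ckpt',
#       'load_pretrain' : 'True',
#       'no_grad'       : 'True',
#       'dropout'       : 'default',
#       'spec_aug'      : 'False',
#       'spec_aug_prev' : 'True',
#       'weighted_sum'  : 'False',
#       'select_layer'  : -1,
#       'permute_input' : 'False',
#   }
#   transformer = TRANSFORMER(options=options, inp_dim=40)  # 40-dim acoustic features (hypothetical)
#   example_features = torch.randn(8, 1200, 40)             # (B, T, D) since permute_input is 'False'
#   reps = transformer(example_features)                    # -> (B, T', transformer.out_dim) representations
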

class TRANSFORMER(nn.Module):
    def __init__(self, options, inp_dim, config=None):
        super(TRANSFORMER, self).__init__()

        if config is not None:
            self.config = yaml.load(open(config, 'r'), Loader=yaml.FullLoader)
        else:
            all_states = torch.load(options["ckpt_file"], map_location='cpu')
            self.config = all_states['Settings']['Config']

        self.no_grad = bool(strtobool(options['no_grad']))
        self.spec_aug = bool(strtobool(options['spec_aug']))
        self.spec_aug_prev = bool(strtobool(options['spec_aug_prev']))
        self.weighted_sum = bool(strtobool(options['weighted_sum']))
        self.select_layer = int(options['select_layer'])
        if (not self.no_grad) and (not self.spec_aug_prev):
            raise RuntimeError('`no_grad` and `spec_aug_prev` cannot both be set to False!')

        # increase dropout
        if str(options['dropout']) != 'default':
            self.config['transformer']['hidden_dropout_prob'] = float(options['dropout'])
            self.config['transformer']['attention_probs_dropout_prob'] = float(options['dropout'])

        # Model Config
        self.model_config = TransformerConfig(self.config)
        self.dr = self.model_config.downsample_rate
        self.hidden_size = self.model_config.hidden_size
        self.num_layers = self.model_config.num_hidden_layers
        self.max_input_length = self.config['transformer']['max_input_length'] if 'max_input_length' in self.config['transformer'] else 0
        if self.max_input_length > 0:
            print('[Transformer] - Maximum input length: ', self.max_input_length)
        if self.select_layer not in range(-1, self.num_layers):
            raise RuntimeError("Out of range int for 'select_layer'!")

        # use weighted sum from all layers
        if self.weighted_sum:
            self.weight = nn.Parameter(torch.ones(self.num_layers) / self.num_layers)

        # Build model
        self.inp_dim = inp_dim if inp_dim > 0 else self.config['transformer']['input_dim']
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model = TransformerModel(self.model_config, self.inp_dim).to(self.device)
        self.model.eval() if self.no_grad else self.model.train()

        # Load from a PyTorch state_dict
        load = bool(strtobool(options["load_pretrain"]))
        if load:
            self.load_model(all_states['Transformer'])
            print('[Transformer] - Number of parameters: ' + str(sum(p.numel() for p in self.model.parameters() if p.requires_grad)))

        self.out_dim = self.hidden_size  # 768, this attribute is for pytorch-kaldi and the downstream runner
        self.permute_input = True  # This attribute is for the forward method:
        # if True the input/output is of shape (T, B, D), if False it is (B, T, D).

    def load_model(self, state_dict):
        try:
            old_keys = []
            new_keys = []
            for key in state_dict.keys():
                new_key = None
                if 'gamma' in key:
                    new_key = key.replace('gamma', 'weight')
                if 'beta' in key:
                    new_key = key.replace('beta', 'bias')
                if new_key:
                    old_keys.append(key)
                    new_keys.append(new_key)
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

            missing_keys = []
            unexpected_keys = []
            error_msgs = []
            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, '_metadata', None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata

            def load(module, prefix=''):
                local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
                module._load_from_state_dict(
                    state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + '.')

            load(self.model)
            if len(missing_keys) > 0:
                print('Weights of {} not initialized from pretrained model: {}'.format(
                    self.model.__class__.__name__, missing_keys))
            if len(unexpected_keys) > 0:
                print('Weights from pretrained model not used in {}: {}'.format(
                    self.model.__class__.__name__, unexpected_keys))
            if len(error_msgs) > 0:
                raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                    self.model.__class__.__name__, '\n\t'.join(error_msgs)))
            print('[Transformer] - Pre-trained weights loaded!')
        except:
            raise RuntimeError('[Transformer] - Pre-trained weights NOT loaded!')

    def down_sample_frames(self, spec):
        spec = spec.contiguous()
        left_over = spec.shape[1] % self.dr
        if left_over != 0:
            spec = spec[:, :-left_over, :]
        spec_stacked = spec.view(spec.shape[0], spec.shape[1] // self.dr, spec.shape[2] * self.dr)
        return spec_stacked

    def process_input_data(self, spec):
        """Process input data for the model"""
        # add an arbitrary batch axis B if input `spec` has shape of TxD
        if len(spec.shape) == 2:
            spec = spec.unsqueeze(0)
        # input `spec` should have shape BxTxD
        elif len(spec.shape) != 3:
            raise ValueError('Input argument `spec` has invalid shape: {}'.format(spec.shape))

        # Down sample
        if self.dr > 1:
            spec_stacked = self.down_sample_frames(spec)  # (batch_size, seq_len, feature_dim * dr)
        else:
            spec_stacked = spec

        # Record length for each uttr
        spec_len = np.sum(np.sum(spec_stacked.cpu().data.numpy(), axis=-1) != 0, axis=-1)
        spec_len = [int(sl) for sl in spec_len]

        batch_size = spec_stacked.shape[0]
        seq_len = spec_stacked.shape[1]

        pos_enc = position_encoding(seq_len, self.hidden_size)  # (seq_len, hidden_size)
        attn_mask = np.ones((batch_size, seq_len))  # (batch_size, seq_len)

        # zero vectors for padding dimension
        for idx in range(len(spec_stacked)):
            attn_mask[idx][spec_len[idx]:] = 0

        if self.spec_aug and self.spec_aug_prev and self.model.training:
            spec_stacked = spec_augment(spec_stacked, mask_T=70, mask_F=4, num_T=2, num_F=2, p=1.0)  # (batch_size, seq_len, feature_dim * dr)
        spec_stacked = spec_stacked.to(device=self.device, dtype=torch.float32)  # (batch_size, seq_len, feature_dim * dr)
        pos_enc = torch.FloatTensor(pos_enc).to(device=self.device, dtype=torch.float32).expand(spec_stacked.size(0), *pos_enc.size())  # (batch_size, seq_len, hidden_size)
        attn_mask = torch.FloatTensor(attn_mask).to(device=self.device, dtype=torch.float32)  # (batch_size, seq_len)
        return spec_stacked, pos_enc, attn_mask  # (x, pos_enc, attention_mask)

    def tile_representations(self, reps):
        """
        Tile up the speech
        representations to match the amount of input frames.
        Input  - encoded_layers shape: (batch_size, sequence_length, hidden_size)
        Output - tiled_encoded_layers shape: (batch_size, sequence_length * downsample_rate, hidden_size)
        """
        if len(reps.shape) != 3:
            raise ValueError('Input argument `reps` has invalid shape: {}'.format(reps.shape))

        tiled_reps = reps.repeat(1, 1, self.dr)
        tiled_reps = tiled_reps.reshape(reps.size(0), reps.size(1) * self.dr, reps.size(2))
        return tiled_reps  # (batch_size, sequence_length * downsample_rate, hidden_size)

    def upsample(self, x, input_len):
        # Compute padding to compensate for the frames lost in downsampling
        left_over = input_len % self.dr
        if left_over % 2 == 0:
            left_pad = left_over // 2
            right_pad = left_pad
        else:
            left_pad = left_over // 2
            right_pad = left_over // 2 + 1

        x = self.tile_representations(x)

        # padding
        x = x.permute(0, 2, 1).contiguous()  # (B, T, D) -> (B, D, T)
        padding = nn.ReplicationPad1d((left_pad, right_pad))
        x = padding(x)
        x = x.permute(0, 2, 1).contiguous()  # (B, D, T) -> (B, T, D)
        return x

    def _forward(self, x):
        if self.permute_input:
            x = x.permute(1, 0, 2).contiguous()  # (T, B, D) -> (B, T, D)
        input_len = x.shape[1]

        # forward the whole sequence at once
        if self.max_input_length == 0 or input_len <= self.max_input_length:
            spec_stacked, pos_enc, attn_mask = self.process_input_data(x)  # x shape: (B, T, D)
            x = self.model(spec_stacked, pos_enc, attn_mask, output_all_encoded_layers=self.weighted_sum or self.select_layer != -1)  # (B, T, D) or (N, B, T, D)
        # forward the sequence in chunks then concat
        else:
            chunks = torch.chunk(x, chunks=math.ceil(input_len / self.max_input_length), dim=1)
            x_ = []
            for chunk in chunks:
                spec_stacked, pos_enc, attn_mask = self.process_input_data(chunk)  # chunk shape: (B, T, D)
                chunk = self.model(spec_stacked, pos_enc, attn_mask, output_all_encoded_layers=self.weighted_sum or self.select_layer != -1)  # (B, T, D) or (N, B, T, D)
                x_.append(torch.stack(chunk) if type(chunk) is list else chunk)
            x = torch.cat(x_, dim=2 if (self.weighted_sum or self.select_layer != -1) else 1)

        # Apply weighted sum
        if self.weighted_sum:
            if type(x) is list:
                x = torch.stack(x)
            softmax_weight = nn.functional.softmax(self.weight, dim=-1)
            B, T, D = x.shape[1], x.shape[2], x.shape[3]
            x = x.reshape(self.num_layers, -1)
            x = torch.matmul(softmax_weight, x).reshape(B, T, D)
        # Select a specific layer
        elif self.select_layer != -1:
            x = x[self.select_layer]

        if self.spec_aug and not self.spec_aug_prev and self.model.training:
            x = spec_augment(x, mask_T=70, mask_F=86, num_T=2, num_F=2, p=1.0)  # (B, T, D)

        # If using a downsampling model, apply tiling and padding
        if self.dr > 1:
            x = self.upsample(x, input_len)  # (B, T, D)

        # permute back before returning
        if self.permute_input:
            x = x.permute(1, 0, 2).contiguous()  # (B, T, D) -> (T, B, D)
        return x  # (B, T, D) or (T, B, D)

    def forward(self, x):
        if self.no_grad:
            with torch.no_grad():
                self.model.eval()
                x = self._forward(x)
        else:
            x = self._forward(x)
        return x
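
# A minimal standalone sketch (not part of the original class) of the shape round-trip
# performed by `down_sample_frames`, `tile_representations` and `upsample`, assuming a
# downsample rate of 3 and a 768-dim hidden size; all tensors are dummy values used only
# to illustrate shapes, and the Transformer forward pass itself is faked with random data.
if __name__ == '__main__':
    import torch
    import torch.nn as nn

    dr, hidden_size = 3, 768
    spec = torch.randn(2, 100, 40)                                    # (B, T, D) acoustic features

    # down_sample_frames: drop leftover frames, then stack every `dr` consecutive frames
    left_over = spec.shape[1] % dr                                    # 100 % 3 = 1
    trimmed = spec[:, :-left_over, :] if left_over else spec          # (2, 99, 40)
    stacked = trimmed.reshape(2, trimmed.shape[1] // dr, 40 * dr)     # (2, 33, 120)

    # the Transformer maps (2, 33, 120) -> (2, 33, hidden_size); stand-in random output here
    reps = torch.randn(2, stacked.shape[1], hidden_size)              # (2, 33, 768)

    # tile_representations + upsample: repeat each frame `dr` times, then pad back to T
    tiled = reps.repeat(1, 1, dr).reshape(2, reps.shape[1] * dr, hidden_size)  # (2, 99, 768)
    left_pad, right_pad = left_over // 2, left_over // 2 + (left_over % 2)
    padded = nn.ReplicationPad1d((left_pad, right_pad))(tiled.permute(0, 2, 1)).permute(0, 2, 1)
    print(padded.shape)                                               # torch.Size([2, 100, 768])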