def load_checkpoint(
    self,
    filename,
    reset_optimizer=False,
    reset_lr_scheduler=False,
    optimizer_overrides=None,
    reset_meters=False,
):
    """Load all training state from a checkpoint file."""
    extra_state = super().load_checkpoint(
        filename,
        reset_optimizer=reset_optimizer,
        reset_lr_scheduler=reset_lr_scheduler,
        optimizer_overrides=optimizer_overrides,
        reset_meters=reset_meters,
    )
    # Restore the model-parallel CUDA RNG tracker states saved by save_checkpoint.
    if extra_state is not None and "rng_tracker_states" in extra_state:
        get_cuda_rng_tracker().set_states(extra_state["rng_tracker_states"])
    return extra_state
def save_checkpoint(self, filename, extra_state):
    """Save all training state in a checkpoint file."""
    extra_state["rng_tracker_states"] = get_cuda_rng_tracker().get_states()
    super().save_checkpoint(filename, extra_state)
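
# Illustrative usage sketch only (hypothetical helper and filename, not part of the
# trainer above): the CUDA RNG tracker states ride along in extra_state, so a
# save/load round trip restores the model-parallel dropout RNG streams.
def _example_rng_state_roundtrip(trainer, filename="checkpoint_last.pt"):
    # save_checkpoint stashes get_cuda_rng_tracker().get_states() into extra_state
    trainer.save_checkpoint(filename, extra_state={"epoch": 1})
    # load_checkpoint reads the states back and re-applies them via set_states()
    return trainer.load_checkpoint(filename)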
def forward(
    self,
    query,
    key: Optional[Tensor],
    value: Optional[Tensor],
    key_padding_mask: Optional[Tensor] = None,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    static_kv: bool = False,
    attn_mask: Optional[Tensor] = None,
    **unused_kwargs,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Input shape: Time x Batch x Channel

    Args:
        key_padding_mask (ByteTensor, optional): mask to exclude keys that
            are pads, of shape `(batch, src_len)`, where padding elements
            are indicated by 1s.
        attn_mask (ByteTensor, optional): typically used to implement
            causal attention, where the mask prevents the attention from
            looking forward in time (default: None).
    """
    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == self.embed_dim
    assert list(query.size()) == [tgt_len, bsz, embed_dim]

    if incremental_state is not None:
        saved_state = self._get_input_buffer(incremental_state)
        if saved_state is not None and "prev_key" in saved_state:
            # previous time steps are cached - no need to recompute
            # key and value if they are static
            if static_kv:
                assert self.encoder_decoder_attention and not self.self_attention
                key = value = None
    else:
        saved_state = None

    if self.self_attention:
        q = self.q_proj(query)
        k = self.k_proj(query)
        v = self.v_proj(query)
    elif self.encoder_decoder_attention:
        # encoder-decoder attention
        q = self.q_proj(query)
        if key is None:
            assert value is None
            k = v = None
        else:
            k = self.k_proj(key)
            v = self.v_proj(key)
    else:
        assert key is not None and value is not None
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
    q *= self.scaling

    # reshape to (bsz * num_heads_partition, seq_len, head_dim) for batched matmul
    q = (
        q.contiguous()
        .view(tgt_len, bsz * self.num_heads_partition, self.head_dim)
        .transpose(0, 1)
    )
    if k is not None:
        k = (
            k.contiguous()
            .view(-1, bsz * self.num_heads_partition, self.head_dim)
            .transpose(0, 1)
        )
    if v is not None:
        v = (
            v.contiguous()
            .view(-1, bsz * self.num_heads_partition, self.head_dim)
            .transpose(0, 1)
        )

    if saved_state is not None:
        # saved states are stored with shape (bsz, num_heads_partition, seq_len, head_dim)
        if "prev_key" in saved_state:
            _prev_key = saved_state["prev_key"]
            assert _prev_key is not None
            prev_key = _prev_key.view(
                bsz * self.num_heads_partition, -1, self.head_dim
            )
            if static_kv:
                k = prev_key
            else:
                assert k is not None
                k = torch.cat([prev_key, k], dim=1)
        if "prev_value" in saved_state:
            _prev_value = saved_state["prev_value"]
            assert _prev_value is not None
            prev_value = _prev_value.view(
                bsz * self.num_heads_partition, -1, self.head_dim
            )
            if static_kv:
                v = prev_value
            else:
                assert v is not None
                v = torch.cat([prev_value, v], dim=1)
        prev_key_padding_mask: Optional[Tensor] = None
        if "prev_key_padding_mask" in saved_state:
            prev_key_padding_mask = saved_state["prev_key_padding_mask"]
        assert k is not None and v is not None
        key_padding_mask = ModelParallelMultiheadAttention._append_prev_key_padding_mask(
            key_padding_mask=key_padding_mask,
            prev_key_padding_mask=prev_key_padding_mask,
            batch_size=bsz,
            src_len=k.size(1),
            static_kv=static_kv,
        )

        saved_state["prev_key"] = k.view(
            bsz, self.num_heads_partition, -1, self.head_dim
        )
        saved_state["prev_value"] = v.view(
            bsz, self.num_heads_partition, -1, self.head_dim
        )
        saved_state["prev_key_padding_mask"] = key_padding_mask
        # In this branch incremental_state is never None
        assert incremental_state is not None
        incremental_state = self._set_input_buffer(incremental_state, saved_state)
    assert k is not None
    src_len = k.size(1)

    # This is part of a workaround to get around fork/join parallelism
    # not supporting Optional types.
    if key_padding_mask is not None and key_padding_mask.dim() == 0:
        key_padding_mask = None

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    attn_weights = torch.bmm(q, k.transpose(1, 2))

    assert list(attn_weights.size()) == [
        bsz * self.num_heads_partition,
        tgt_len,
        src_len,
    ]

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        attn_weights += attn_mask

    if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(
            bsz, self.num_heads_partition, tgt_len, src_len
        )
        if not self.tpu:
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )
        else:
            attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
            attn_weights = attn_weights.transpose(0, 2)
        attn_weights = attn_weights.view(
            bsz * self.num_heads_partition, tgt_len, src_len
        )

    attn_weights_float = utils.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights_float.type_as(attn_weights)

    # draw dropout randomness from the model-parallel CUDA RNG tracker
    with get_cuda_rng_tracker().fork():
        attn_probs = self.dropout_module(attn_weights)

    assert v is not None
    attn = torch.bmm(attn_probs, v)
    assert list(attn.size()) == [
        bsz * self.num_heads_partition,
        tgt_len,
        self.head_dim,
    ]
    embed_dim_partition = embed_dim // self.model_parallel_size
    attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim_partition)
    attn = self.out_proj(attn)
    # Return attn_weights as None to keep the return type the same as the
    # single-GPU multihead attention. This will be deprecated.
    attn_weights: Optional[Tensor] = None

    return attn, attn_weights
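
# Illustrative sketch (assumed shapes, not part of this module): one way to build the
# two masks described in forward()'s docstring. attn_mask is added to the attention
# logits, so -inf above the diagonal yields causal attention; key_padding_mask is
# (batch, src_len) with True/1 marking padded key positions to be excluded.
def _example_attention_masks(tgt_len=4, src_len=4, bsz=2):
    causal_attn_mask = torch.triu(
        torch.full((tgt_len, src_len), float("-inf")), diagonal=1
    )
    key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
    key_padding_mask[0, -1] = True  # e.g. the last key of the first sequence is a pad
    return causal_attn_mask, key_padding_mask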