def __init__(self, name, num_splits, input_size, ff_size, dropout,
             dropout_prob, epsilon, use_default_memory_proportion,
             available_memory_proportion, **kwargs):
    """Build a feed-forward layer split into ``num_splits`` FeedForward blocks.

    Each split is a non-residual ``FeedForward`` whose input size is
    ``input_size // num_splits``. A separate accumulation scope (requested
    with the ``'next'`` phase specifier) hosts a ``Norm`` and an optional
    ``Dropout``. These are presumably applied after the split outputs are
    accumulated; the forward pass is not shown here.

    Args:
        name: Base scope name; every sub-scope name is prefixed with it.
        num_splits: Number of ``FeedForward`` splits to create.
        input_size: Total input size; must be a multiple of ``num_splits``.
        ff_size: Intermediate size, passed unchanged to every split.
        dropout: Forwarded to every split. If truthy, a post-accumulation
            ``Dropout`` is also created; otherwise ``self.dropout`` is the
            identity function.
        dropout_prob: Forwarded to every split and to the post-accumulation
            ``Dropout``.
        epsilon: Forwarded to every split and to the ``Norm``.
        use_default_memory_proportion: Forwarded to every split.
        available_memory_proportion: Forwarded to every split.
        **kwargs: Must contain ``'scope_provider'``; forwarded to the base
            class and to every sub-block.

    Raises:
        KeyError: If ``kwargs`` has no ``'scope_provider'``.
        ValueError: If ``input_size`` is not a multiple of ``num_splits``.
    """
    scope_provider = kwargs['scope_provider']
    # The integer division below would silently drop any remainder. The
    # splits would then cover fewer than ``input_size`` features, while the
    # Norm is still built for the full ``input_size``.
    if input_size % num_splits:
        raise ValueError('Input size must be a multiple of num_splits.')
    super().__init__(params=[], scope=scope_provider.get_scope(name), **kwargs)
    self.split_size = input_size // num_splits
    self.name = name
    # Splits are constructed in index order, so scopes are requested in the
    # same order as before.
    self.layers = [
        FeedForward(f'{name}/Split{i}',
                    self.split_size,
                    ff_size,
                    dropout,
                    dropout_prob,
                    epsilon,
                    residual=False,
                    use_default_memory_proportion=use_default_memory_proportion,
                    available_memory_proportion=available_memory_proportion,
                    **kwargs)
        for i in range(num_splits)
    ]
    # The Norm and the optional Dropout share the accumulation scope's phase.
    self.accum_scope = scope_provider.get_scope(f'{name}/FFAccum', 'next')
    self.norm = Norm(
        scope_provider.get_scope(f'{name}/FFNorm',
                                 self.accum_scope.execution_phase),
        input_size, epsilon, **kwargs)
    if dropout:
        self.dropout = Dropout(
            scope_provider.get_scope(f'{name}/FFDropout',
                                     self.accum_scope.execution_phase),
            dropout_prob, **kwargs)
    else:
        self.dropout = lambda x: x
    self.total_execution_phases = self.total_phases()
def __init__(self, name, input_size, ff_size, dropout, dropout_prob, epsilon,
             residual=True, intermediate_act_func='gelu', alpha=None,
             increment_scope=True, serialize_matmul=False,
             use_default_memory_proportion=True,
             available_memory_proportion=None, **kwargs):
    """Build a two-Dense feed-forward block with optional Dropout and Norm.

    The first ``Dense`` is constructed with ``(input_size, ff_size)`` and
    ``intermediate_act_func`` as its activation. The second is constructed
    with ``(ff_size, input_size)`` and no activation. ``Dropout`` (only when
    ``dropout`` is truthy) and ``Norm`` are created only when ``residual`` is
    truthy.

    Args:
        name: Scope name for this block.
        input_size: First size argument of the first ``Dense``; second size
            argument of the second ``Dense``; size of the ``Norm``.
        ff_size: Second size argument of the first ``Dense``; first size
            argument of the second ``Dense``.
        dropout: Stored as ``self.apply_dropout``; when truthy and
            ``residual`` is set, a ``Dropout`` block is created.
        dropout_prob: Passed to the ``Dropout`` block.
        epsilon: Passed to the ``Norm`` block.
        residual: Stored on ``self``; controls creation of Dropout and Norm.
        intermediate_act_func: Activation of the first ``Dense``.
        alpha: Passed to the first ``Dense`` only.
        increment_scope: If truthy, the block scope is requested with the
            ``'next'`` phase specifier; otherwise with ``'prev'``.
        serialize_matmul: If truthy, both Dense layers get a ``Split`` with
            ``ff_size // input_size`` pieces. The first is split along
            ``'output_channels'`` and the second along ``'reducing_dim'``.
        use_default_memory_proportion: Forwarded to both Dense layers.
        available_memory_proportion: Forwarded to both Dense layers.
        **kwargs: Must contain ``'scope_provider'``; forwarded to the base
            class and to every sub-block.
    """
    provider = kwargs['scope_provider']
    self.apply_dropout = dropout
    phase = 'next' if increment_scope else 'prev'
    super(FeedForward, self).__init__(
        params=[], scope=provider.get_scope(name, phase), **kwargs)
    self.residual = residual

    memory_options = {
        'use_default_memory_proportion': use_default_memory_proportion,
        'available_memory_proportion': available_memory_proportion,
    }

    # First projection, optionally serialised along its output channels.
    expand_split = (
        Split(dim='output_channels', num_splits=ff_size // input_size)
        if serialize_matmul else None)
    self.dense1 = Dense(provider.get_scope("1", 'prev'),
                        input_size,
                        ff_size,
                        split=expand_split,
                        activation=intermediate_act_func,
                        alpha=alpha,
                        **memory_options,
                        **kwargs)

    # Second projection, optionally serialised along its reducing dimension.
    contract_split = (
        Split(dim='reducing_dim', num_splits=ff_size // input_size)
        if serialize_matmul else None)
    self.dense2 = Dense(provider.get_scope("2", "prev"),
                        ff_size,
                        input_size,
                        split=contract_split,
                        activation=None,
                        **memory_options,
                        **kwargs)

    # Dropout's scope is requested before Norm's, matching the original order.
    if residual and dropout:
        self.dropout = Dropout(provider.get_scope("Dropout", "prev"),
                               dropout_prob, **kwargs)
    if residual:
        self.norm = Norm(provider.get_scope("Norm", "prev"),
                         input_size, epsilon, **kwargs)
    self.total_execution_phases = self.total_phases()
def __init__(self, name: str, num_splits, hidden_size, num_heads,
             serialize_matmul, available_memory_proportion, epsilon, dropout,
             dropout_prob, attn_dropout, attn_dropout_prob, batch_size,
             sequence_length, dtype, task, num_mask_tokens,
             use_default_mem_proportion, **kwargs):
    """Build an attention layer whose heads are split across ``num_splits``
    independent ``Attention`` blocks.

    Each split is a non-residual ``Attention`` with input size
    ``hidden_size``, hidden size ``hidden_size // num_splits`` and
    ``num_heads // num_splits`` heads. A separate accumulation scope
    (requested with the ``'next'`` phase specifier) hosts a ``Norm`` and an
    optional ``Dropout``. These are presumably applied after the split outputs
    are accumulated; the forward pass is not shown here.

    Args:
        name: Scope name for this block.
        num_splits: Number of ``Attention`` splits to create.
        hidden_size: Total hidden size; must be a multiple of ``num_splits``.
            Also the size of the ``Norm``.
        num_heads: Total number of heads; must be a multiple of
            ``num_splits``.
        serialize_matmul, available_memory_proportion, epsilon, dropout,
        dropout_prob, attn_dropout, attn_dropout_prob, batch_size,
        sequence_length, dtype, task, num_mask_tokens,
        use_default_mem_proportion: Forwarded to every ``Attention`` split.
            ``epsilon`` and ``dtype`` also go to the ``Norm``. If ``dropout``
            is truthy, a post-accumulation ``Dropout`` with ``dropout_prob``
            is created; otherwise ``self.dropout`` is the identity function.
        **kwargs: Must contain ``'scope_provider'``; forwarded to the base
            class and to every sub-block.

    Raises:
        ValueError: If ``hidden_size`` or ``num_heads`` is not a multiple of
            ``num_splits``.
    """
    scope_provider = kwargs['scope_provider']
    # This block splits num_heads across the splits while keeping
    # size_per_head unchanged. Since hidden_size = num_heads * size_per_head,
    # both num_heads and hidden_size must be multiples of num_splits.
    if hidden_size % num_splits:
        raise ValueError('Hidden size must be a multiple of num_splits.')
    if num_heads % num_splits:
        raise ValueError('Num heads must be a multiple of num_splits.')
    super().__init__(params=[], scope=scope_provider.get_scope(name), **kwargs)
    self.split_size = hidden_size // num_splits
    self.name = name
    heads_per_split = num_heads // num_splits
    # Splits are constructed in index order, so scopes are requested in the
    # same order as before.
    self.layers = [
        Attention(f"Split{i}",
                  hidden_size,
                  self.split_size,
                  heads_per_split,
                  serialize_matmul,
                  available_memory_proportion,
                  epsilon,
                  dropout,
                  dropout_prob,
                  attn_dropout,
                  attn_dropout_prob,
                  batch_size,
                  sequence_length,
                  dtype,
                  task,
                  num_mask_tokens,
                  residual=False,
                  use_default_mem_proportion=use_default_mem_proportion,
                  **kwargs)
        for i in range(num_splits)
    ]
    # The Norm and the optional Dropout share the accumulation scope's phase.
    self.accum_scope = scope_provider.get_scope('AttnAccum', 'next')
    self.norm = Norm(
        scope_provider.get_scope('AttnNorm', self.accum_scope.execution_phase),
        hidden_size, epsilon, dtype, **kwargs)
    if dropout:
        self.dropout = Dropout(
            scope_provider.get_scope('AttnDropout',
                                     self.accum_scope.execution_phase),
            dropout_prob, dtype=dtype, **kwargs)
    else:
        self.dropout = lambda x: x
    # NOTE(review): this constructor does not set self.total_execution_phases,
    # although other blocks in this file set it via self.total_phases().
    # Confirm that this is intended.
def __init__(self, vocab_size, hidden_size, sequence_length,
             max_positional_length, num_vocab_splits, epsilon, apply_dropout,
             dropout_prob, mode, dtype, detach, weight_transposed, custom=True,
             **kwargs):
    """Build the embedding stage: token, segment and position embeddings,
    followed by an ``Add``, a ``Norm`` and an optional ``Dropout``.

    Args:
        vocab_size: ``input_dim`` of the token embedding.
        hidden_size: Output size of every embedding and size of the ``Norm``.
        sequence_length: Not used by this constructor.
        max_positional_length: Input size of the position embedding.
        num_vocab_splits: If greater than 1, the token embedding is an
            ``EmbeddingSerialised`` with this many splits. Otherwise it is a
            plain ``Embedding``.
        epsilon: Passed to the ``Norm``.
        apply_dropout: Stored on ``self``; if truthy, a ``Dropout`` is built.
        dropout_prob: Passed to the ``Dropout``.
        mode: Not used by this constructor.
        dtype: Passed to every embedding and to the ``Norm``.
        detach, weight_transposed, custom: Passed to the token embedding only.
        **kwargs: Must contain ``'scope_provider'`` and ``'builder'``;
            forwarded to the base class and to every sub-block.
    """
    provider = kwargs['scope_provider']
    outline_attr = kwargs['builder'].outlineAttributes(
        {'outline_scope': 'Embeddings'})
    super().__init__(
        provider.get_scope('Embeddings', additional_scopes=[outline_attr]),
        **kwargs)

    # Options shared by both token-embedding variants.
    token_options = {
        'input_dim': vocab_size,
        'output_dim': hidden_size,
        'custom': custom,
        'dtype': dtype,
        'detach': detach,
        'weight_transposed': weight_transposed,
    }
    if num_vocab_splits > 1:
        self.token_embedding = EmbeddingSerialised(
            provider.get_scope('Token'),
            num_splits=num_vocab_splits,
            **token_options,
            **kwargs)
    else:
        self.token_embedding = Embedding(
            provider.get_scope('Token', execution_phase='next'),
            **token_options,
            **kwargs)

    segment_count = 2
    self.segment_embedding = Embedding(
        provider.get_scope('Segment', execution_phase='next'),
        segment_count, hidden_size, dtype, **kwargs)
    self.position_embedding = Embedding(
        provider.get_scope('Position', execution_phase='prev'),
        max_positional_length, hidden_size, dtype, **kwargs)

    self.add = Add(provider.get_scope('Sum', execution_phase='prev'), **kwargs)
    self.norm = Norm(provider.get_scope('Norm', execution_phase='prev'),
                     hidden_size, epsilon, dtype, **kwargs)

    self.apply_dropout = apply_dropout
    if apply_dropout:
        self.dropout = Dropout(
            provider.get_scope('Dropout', execution_phase='prev'),
            dropout_prob, **kwargs)
    self.total_execution_phases = self.total_phases()
def __init__(self, name: str, input_size, hidden_size, num_heads,
             serialize_matmul, available_memory_proportion, epsilon, dropout,
             dropout_prob, attn_dropout, attn_dropout_prob, batch_size,
             sequence_length, dtype, task, num_mask_tokens, split_qkv=False,
             residual=True, prefetch_masks=True,
             use_default_mem_proportion=True, mask=None, **kwargs):
    """Build a multi-head attention block and its sub-blocks.

    With ``split_qkv``, weights are declared as separate ``Q``/``K``/``V``
    parameters of shape ``[input_size, hidden_size]``. Otherwise they are one
    fused ``QKV`` parameter of shape ``[input_size, 3 * hidden_size]``. Both
    layouts also declare ``Out`` with shape ``[hidden_size, input_size]``.

    Args:
        name: Scope name, requested with the ``'next'`` phase specifier.
        input_size: First dimension of the Q/K/V (or QKV) parameters.
        hidden_size: Attention hidden size; must be a multiple of
            ``num_heads``. Also the size of the ``Norm``.
        num_heads: Number of heads. ``self.qkv_length`` is set to
            ``hidden_size // num_heads``.
        serialize_matmul, available_memory_proportion,
        use_default_mem_proportion, batch_size, task, num_mask_tokens, mask:
            Stored on ``self`` under the same names.
        sequence_length: Stored as ``self.seq_len``.
        dtype: Stored on ``self``; passed to the base class and the ``Norm``.
        epsilon: Passed to the ``Norm`` (only built when ``residual``).
        dropout, dropout_prob: If ``dropout`` is truthy, a ``Dropout`` with
            ``dropout_prob`` is built; otherwise ``self.dropout`` is identity.
        attn_dropout, attn_dropout_prob: If ``attn_dropout`` is truthy, a
            ``Dropout`` with ``attn_dropout_prob`` is built; otherwise
            ``self.attn_dropout`` is identity.
        split_qkv: Selects separate vs fused Q/K/V parameters.
        residual: Stored on ``self``; if truthy, a ``Norm`` is created.
        prefetch_masks: If truthy, the mask scope is created with
            checkpoint-recompute and on-chip output-location attributes.
            Its phase is the ``'prev'`` Mask phase modulo 2.
        **kwargs: Must contain ``'scope_provider'``; forwarded to the base
            class and to every sub-block.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of ``num_heads``.
    """
    # Validate before building parameters, calling the base initialiser or
    # requesting any scope, so a bad configuration is rejected first.
    if hidden_size % num_heads != 0:
        raise ValueError('Hidden size must be a multiple of num_heads')

    if split_qkv:
        params = [
            Parameter(name='Q', shape=[input_size, hidden_size], value=None),
            Parameter(name='K', shape=[input_size, hidden_size], value=None),
            Parameter(name='V', shape=[input_size, hidden_size], value=None),
            Parameter(name='Out', shape=[hidden_size, input_size], value=None)
        ]
    else:
        params = [
            Parameter(name='QKV',
                      shape=[input_size, 3 * hidden_size],
                      value=None),
            Parameter(name='Out', shape=[hidden_size, input_size], value=None)
        ]
    scope_provider = kwargs['scope_provider']
    super(Attention, self).__init__(params=params,
                                    scope=scope_provider.get_scope(
                                        name, 'next'),
                                    dtype=dtype,
                                    **kwargs)
    self.num_heads = num_heads
    self.hidden_size = hidden_size
    self.serialize_matmul = serialize_matmul
    self.available_memory_proportion = available_memory_proportion
    self.use_default_mem_proportion = use_default_mem_proportion
    self.split_qkv = split_qkv
    self.batch_size = batch_size
    self.seq_len = sequence_length
    # Per-head size; exact because of the divisibility check above.
    self.qkv_length = hidden_size // num_heads
    self.dtype = dtype
    self.residual = residual
    self.task = task
    self.num_mask_tokens = num_mask_tokens
    self.mask = mask
    self.prefetch_masks = prefetch_masks
    if prefetch_masks:
        additional_scopes = [
            self.builder.recomputeOutput(popart.RecomputeType.Checkpoint),
            self.builder.outputTensorLocation(
                popart.TensorLocation(popart.TensorStorage.OnChip))
        ]
        # Reduce the phase to its parity. Presumably this keeps masks in
        # phase 0 or 1 so they can be prefetched once; TODO confirm.
        self.mask_execution_phase = scope_provider.get_scope(
            'Mask', 'prev').execution_phase % 2
        self.mask_scope = scope_provider.get_scope(
            'Mask',
            self.mask_execution_phase,
            additional_scopes=additional_scopes)
    else:
        self.mask_scope = scope_provider.get_scope('Mask', 'prev')

    if self.residual:
        self.norm = Norm(scope_provider.get_scope('Norm', 'prev'),
                         hidden_size, epsilon, dtype, **kwargs)
    if dropout:
        self.dropout = Dropout(scope_provider.get_scope('Dropout', 'prev'),
                               dropout_prob, **kwargs)
    else:
        self.dropout = lambda x: x

    if attn_dropout:
        self.attn_dropout = Dropout(
            scope_provider.get_scope('AttnDropout', 'prev'),
            attn_dropout_prob, **kwargs)
    else:
        self.attn_dropout = lambda x: x
    self.total_execution_phases = self.total_phases()