Example #1
    def __init__(self, name, num_splits, input_size, ff_size, dropout,
                 dropout_prob, epsilon, use_default_memory_proportion,
                 available_memory_proportion, **kwargs):
        scope_provider = kwargs['scope_provider']
        super().__init__(params=[], scope=scope_provider.get_scope(name), **kwargs)
        ffwd_splits = []
        self.split_size = input_size // num_splits
        self.name = name
        for i in range(num_splits):
            ffwd_splits.append(
                FeedForward(f'{name}/Split{i}',
                            self.split_size,
                            ff_size,
                            dropout,
                            dropout_prob,
                            epsilon,
                            residual=False,
                            use_default_memory_proportion=use_default_memory_proportion,
                            available_memory_proportion=available_memory_proportion,
                            **kwargs))
        self.layers = ffwd_splits
        self.accum_scope = scope_provider.get_scope(f'{name}/FFAccum', 'next')
        self.norm = Norm(
            scope_provider.get_scope(
                f'{name}/FFNorm', self.accum_scope.execution_phase),
            input_size, epsilon, **kwargs)
        if dropout:
            self.dropout = Dropout(
                scope_provider.get_scope(f'{name}/FFDropout',
                                          self.accum_scope.execution_phase), dropout_prob,
                **kwargs)
        else:
            self.dropout = lambda x: x
        self.total_execution_phases = self.total_phases()
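Example #1 only constructs the per-split layers; the forward pass is not shown. The sketch below is a minimal NumPy rendering of the split-then-accumulate idea it implies: each chunk of width split_size goes through its own two-layer feed-forward, and the chunks are recombined before Norm/Dropout. The function name, the ReLU stand-in for the activation and the concatenation step are assumptions, not the library's actual implementation.

import numpy as np

def serialised_ffn(x, weights, num_splits):
    """x: [batch, input_size]; weights: one (w1, w2) pair per split."""
    input_size = x.shape[-1]
    assert input_size % num_splits == 0
    chunks = np.split(x, num_splits, axis=-1)    # one slice of width split_size per FeedForward split
    outputs = []
    for chunk, (w1, w2) in zip(chunks, weights):
        hidden = np.maximum(chunk @ w1, 0.0)     # split_size -> ff_size (ReLU stand-in)
        outputs.append(hidden @ w2)              # ff_size -> split_size, residual=False
    return np.concatenate(outputs, axis=-1)      # recombined to input_size before Norm/Dropout

input_size, ff_size, num_splits = 8, 16, 4
split_size = input_size // num_splits            # mirrors self.split_size above
rng = np.random.default_rng(0)
weights = [(rng.normal(size=(split_size, ff_size)),
            rng.normal(size=(ff_size, split_size))) for _ in range(num_splits)]
out = serialised_ffn(rng.normal(size=(2, input_size)), weights, num_splits)
assert out.shape == (2, input_size)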
Example #2
    def __init__(self,
                 name,
                 input_size,
                 ff_size,
                 dropout,
                 dropout_prob,
                 epsilon,
                 residual=True,
                 intermediate_act_func='gelu',
                 alpha=None,
                 increment_scope=True,
                 serialize_matmul=False,
                 use_default_memory_proportion=True,
                 available_memory_proportion=None,
                 **kwargs):
        scope_provider = kwargs['scope_provider']
        self.apply_dropout = dropout
        if increment_scope:
            scope = scope_provider.get_scope(name, 'next')
        else:
            scope = scope_provider.get_scope(name, 'prev')
        super(FeedForward, self).__init__(params=[], scope=scope, **kwargs)
        self.residual = residual

        if serialize_matmul:
            split = Split(dim='output_channels',
                          num_splits=ff_size // input_size)
        else:
            split = None
        self.dense1 = Dense(
            scope_provider.get_scope("1", 'prev'),
            input_size,
            ff_size,
            split=split,
            activation=intermediate_act_func,
            alpha=alpha,
            use_default_memory_proportion=use_default_memory_proportion,
            available_memory_proportion=available_memory_proportion,
            **kwargs)
        if serialize_matmul:
            split = Split(dim='reducing_dim', num_splits=ff_size // input_size)
        else:
            split = None
        self.dense2 = Dense(
            scope_provider.get_scope("2", "prev"),
            ff_size,
            input_size,
            split=split,
            activation=None,
            use_default_memory_proportion=use_default_memory_proportion,
            available_memory_proportion=available_memory_proportion,
            **kwargs)
        if residual:
            if dropout:
                self.dropout = Dropout(
                    scope_provider.get_scope("Dropout", "prev"), dropout_prob,
                    **kwargs)
            self.norm = Norm(scope_provider.get_scope("Norm", "prev"),
                             input_size, epsilon, **kwargs)
        self.total_execution_phases = self.total_phases()
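The Split(dim='output_channels', ...) hint on dense1 and Split(dim='reducing_dim', ...) on dense2 serialize the two matmuls in different ways. The NumPy sketch below shows only the underlying arithmetic identity, assuming num_splits = ff_size // input_size as in the constructor; it is not the library's Split implementation.

import numpy as np

rng = np.random.default_rng(0)
input_size, ff_size = 4, 12
num_splits = ff_size // input_size                  # as in serialize_matmul above
x = rng.normal(size=(2, input_size))
w1 = rng.normal(size=(input_size, ff_size))         # dense1 weight
w2 = rng.normal(size=(ff_size, input_size))         # dense2 weight

# dense1: split over output channels, concatenate the partial results.
parts = np.split(w1, num_splits, axis=1)
out1 = np.concatenate([x @ p for p in parts], axis=1)
assert np.allclose(out1, x @ w1)

# dense2: split over the reducing dimension, sum the partial products.
h = np.maximum(out1, 0.0)                            # stand-in activation
h_parts = np.split(h, num_splits, axis=1)
w2_parts = np.split(w2, num_splits, axis=0)
out2 = sum(hp @ wp for hp, wp in zip(h_parts, w2_parts))
assert np.allclose(out2, h @ w2)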
Example #3
    def __init__(self, name: str, num_splits, hidden_size, num_heads,
                 serialize_matmul, available_memory_proportion, epsilon,
                 dropout, dropout_prob, attn_dropout, attn_dropout_prob,
                 batch_size, sequence_length, dtype, task, num_mask_tokens,
                 use_default_mem_proportion, **kwargs):
        scope_provider = kwargs['scope_provider']
        # AttentionSplitHidden splits num_heads, keeping size_per_head the same.
        # Since hidden_size = num_heads * size_per_head, both num_heads and
        # hidden_size must be multiples of num_splits.

        if hidden_size % num_splits:
            raise ValueError('Hidden size must be a multiple of num_splits.')

        if num_heads % num_splits:
            raise ValueError('Num heads must be a multiple of num_splits.')

        super().__init__(params=[],
                         scope=scope_provider.get_scope(name),
                         **kwargs)
        attention_splits = []
        self.split_size = hidden_size // num_splits
        self.name = name
        for i in range(num_splits):
            attention_splits.append(
                Attention(
                    f"Split{i}",
                    hidden_size,
                    self.split_size,
                    num_heads // num_splits,
                    serialize_matmul,
                    available_memory_proportion,
                    epsilon,
                    dropout,
                    dropout_prob,
                    attn_dropout,
                    attn_dropout_prob,
                    batch_size,
                    sequence_length,
                    dtype,
                    task,
                    num_mask_tokens,
                    residual=False,
                    use_default_mem_proportion=use_default_mem_proportion,
                    **kwargs))
        self.layers = attention_splits
        self.accum_scope = scope_provider.get_scope('AttnAccum', 'next')
        self.norm = Norm(
            scope_provider.get_scope('AttnNorm',
                                     self.accum_scope.execution_phase),
            hidden_size, epsilon, dtype, **kwargs)
        if dropout:
            self.dropout = Dropout(scope_provider.get_scope(
                'AttnDropout', self.accum_scope.execution_phase),
                                   dropout_prob,
                                   dtype=dtype,
                                   **kwargs)
        else:
            self.dropout = lambda x: x
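A quick arithmetic check of the invariants the two ValueError guards above enforce: since hidden_size = num_heads * size_per_head, splitting both hidden_size and num_heads by num_splits keeps size_per_head unchanged. The concrete numbers below are illustrative, not taken from the source.

hidden_size, num_heads, num_splits = 768, 12, 4
assert hidden_size % num_splits == 0 and num_heads % num_splits == 0

size_per_head = hidden_size // num_heads        # 64, unchanged by splitting
split_size = hidden_size // num_splits          # 192, hidden width handled per split
heads_per_split = num_heads // num_splits       # 3 heads per Attention split
assert split_size == heads_per_split * size_per_head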
Example #4
    def __init__(self, vocab_size, hidden_size, sequence_length,
                 max_positional_length, num_vocab_splits, epsilon,
                 apply_dropout, dropout_prob, mode, dtype, detach,
                 weight_transposed, custom=True, **kwargs):
        scope_provider = kwargs['scope_provider']
        additional_scopes = [
            kwargs['builder'].outlineAttributes({'outline_scope': 'Embeddings'})
        ]
        scope = scope_provider.get_scope(
            'Embeddings', additional_scopes=additional_scopes)
        super().__init__(scope, **kwargs)
        if num_vocab_splits > 1:
            self.token_embedding = EmbeddingSerialised(
                scope_provider.get_scope('Token'),
                input_dim=vocab_size,
                output_dim=hidden_size,
                num_splits=num_vocab_splits,
                custom=custom,
                dtype=dtype,
                detach=detach,
                weight_transposed=weight_transposed,
                **kwargs)
        else:
            self.token_embedding = Embedding(
                scope_provider.get_scope('Token', execution_phase='next'),
                input_dim=vocab_size,
                output_dim=hidden_size,
                custom=custom,
                dtype=dtype,
                detach=detach,
                weight_transposed=weight_transposed,
                **kwargs)
        num_segments = 2
        self.segment_embedding = Embedding(
            scope_provider.get_scope(
                'Segment', execution_phase='next'), num_segments,
            hidden_size, dtype, **kwargs)

        self.position_embedding = Embedding(
            scope_provider.get_scope('Position', execution_phase='prev'),
            max_positional_length,
            hidden_size,
            dtype,
            **kwargs)

        self.add = Add(scope_provider.get_scope('Sum', execution_phase='prev'), **kwargs)
        self.norm = Norm(scope_provider.get_scope('Norm', execution_phase='prev'),
                         hidden_size, epsilon, dtype, **kwargs)
        self.apply_dropout = apply_dropout
        if apply_dropout:
            self.dropout = Dropout(
                scope_provider.get_scope(
                    'Dropout', execution_phase='prev'), dropout_prob,
                **kwargs)
        self.total_execution_phases = self.total_phases()
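The constructor above only wires up the token, segment and position embedding tables plus Add, Norm and optional Dropout; the forward combination is not shown. Below is a minimal NumPy sketch of the usual lookup-sum-normalize pattern, with dropout omitted; treat the details as an assumption rather than the library's behaviour.

import numpy as np

def embed(token_ids, segment_ids, tok_table, seg_table, pos_table, eps=1e-6):
    seq_len = token_ids.shape[-1]
    summed = (tok_table[token_ids]                 # token embedding lookup
              + seg_table[segment_ids]             # segment (sentence A/B) lookup
              + pos_table[np.arange(seq_len)])     # position embedding, 0..seq_len-1
    mean = summed.mean(axis=-1, keepdims=True)     # layer norm over the hidden dim
    var = summed.var(axis=-1, keepdims=True)
    return (summed - mean) / np.sqrt(var + eps)

vocab, hidden, max_pos = 100, 16, 32
rng = np.random.default_rng(0)
out = embed(rng.integers(0, vocab, size=(2, 8)),
            rng.integers(0, 2, size=(2, 8)),
            rng.normal(size=(vocab, hidden)),
            rng.normal(size=(2, hidden)),
            rng.normal(size=(max_pos, hidden)))
assert out.shape == (2, 8, hidden)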
Example #5
    def __init__(self,
                 name: str,
                 input_size,
                 hidden_size,
                 num_heads,
                 serialize_matmul,
                 available_memory_proportion,
                 epsilon,
                 dropout,
                 dropout_prob,
                 attn_dropout,
                 attn_dropout_prob,
                 batch_size,
                 sequence_length,
                 dtype,
                 task,
                 num_mask_tokens,
                 split_qkv=False,
                 residual=True,
                 prefetch_masks=True,
                 use_default_mem_proportion=True,
                 mask=None,
                 **kwargs):
        if split_qkv:
            params = [
                Parameter(name='Q',
                          shape=[input_size, hidden_size],
                          value=None),
                Parameter(name='K',
                          shape=[input_size, hidden_size],
                          value=None),
                Parameter(name='V',
                          shape=[input_size, hidden_size],
                          value=None),
                Parameter(name='Out',
                          shape=[hidden_size, input_size],
                          value=None)
            ]
        else:
            params = [
                Parameter(name='QKV',
                          shape=[input_size, 3 * hidden_size],
                          value=None),
                Parameter(name='Out',
                          shape=[hidden_size, input_size],
                          value=None)
            ]
        scope_provider = kwargs['scope_provider']
        super(Attention,
              self).__init__(params=params,
                             scope=scope_provider.get_scope(name, 'next'),
                             dtype=dtype,
                             **kwargs)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.serialize_matmul = serialize_matmul
        self.available_memory_proportion = available_memory_proportion
        self.use_default_mem_proportion = use_default_mem_proportion
        self.split_qkv = split_qkv
        self.batch_size = batch_size
        self.seq_len = sequence_length
        if hidden_size % num_heads != 0:
            raise ValueError('Hidden size must be a multiple of num_heads.')
        self.qkv_length = hidden_size // num_heads
        self.dtype = dtype
        self.residual = residual
        self.task = task
        self.num_mask_tokens = num_mask_tokens
        self.mask = mask
        self.prefetch_masks = prefetch_masks
        if prefetch_masks:
            additional_scopes = [
                self.builder.recomputeOutput(popart.RecomputeType.Checkpoint),
                self.builder.outputTensorLocation(
                    popart.TensorLocation(popart.TensorStorage.OnChip))
            ]
            self.mask_execution_phase = scope_provider.get_scope(
                'Mask', 'prev').execution_phase % 2
            self.mask_scope = scope_provider.get_scope(
                'Mask',
                self.mask_execution_phase,
                additional_scopes=additional_scopes)
        else:
            self.mask_scope = scope_provider.get_scope('Mask', 'prev')

        if self.residual:
            self.norm = Norm(scope_provider.get_scope('Norm', 'prev'),
                             hidden_size, epsilon, dtype, **kwargs)
        if dropout:
            self.dropout = Dropout(scope_provider.get_scope('Dropout', 'prev'),
                                   dropout_prob, **kwargs)
        else:
            self.dropout = lambda x: x

        if attn_dropout:
            self.attn_dropout = Dropout(
                scope_provider.get_scope('AttnDropout', 'prev'),
                attn_dropout_prob, **kwargs)
        else:
            self.attn_dropout = lambda x: x

        self.total_execution_phases = self.total_phases()
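split_qkv only changes how the projection parameters are laid out: one fused [input_size, 3 * hidden_size] QKV weight versus three separate [input_size, hidden_size] weights. The NumPy sketch below checks that the two layouts produce identical Q, K and V; shapes are illustrative, not taken from the source.

import numpy as np

rng = np.random.default_rng(0)
input_size, hidden_size, seq = 16, 12, 5
x = rng.normal(size=(seq, input_size))
wq, wk, wv = [rng.normal(size=(input_size, hidden_size)) for _ in range(3)]

# split_qkv=True: three projections applied independently.
q, k, v = x @ wq, x @ wk, x @ wv

# split_qkv=False: one fused weight, sliced into Q, K, V after the matmul.
qkv = x @ np.concatenate([wq, wk, wv], axis=1)     # [seq, 3 * hidden_size]
q2, k2, v2 = np.split(qkv, 3, axis=1)
assert np.allclose(q, q2) and np.allclose(k, k2) and np.allclose(v, v2)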