Example 1
import tensorflow as tf

from coltran.utils import base_utils  # assumed: repo-local module providing act_to_func


def cond_with_context(inputs, cond_layer, context, cond_type, cond_act):
    """Conditions `inputs` on `context` via a shift or an affine transform."""
    cond_act_func = base_utils.act_to_func(cond_act)
    cond_out = cond_layer(context)
    if cond_type == 'shift':
        inputs += cond_out
    elif cond_type == 'affine':
        # 'affine' expects cond_layer to output 2x channels, split into shift/scale.
        shift, scale = tf.split(cond_out, num_or_size_splits=2, axis=-1)
        inputs *= cond_act_func(scale)
        inputs += cond_act_func(shift)
    return inputs
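
With cond_type='affine' the conditioning layer must produce twice the channel count of inputs, so its output can be split into a FiLM-style shift and scale. A minimal usage sketch (the Dense conditioning layer and all tensor shapes are illustrative assumptions, not from the source):

# Usage sketch: shapes and the Dense conditioning layer are illustrative.
inputs = tf.random.normal([2, 8, 8, 64])     # activations to condition
context = tf.random.normal([2, 8, 8, 32])    # conditioning signal
cond_layer = tf.keras.layers.Dense(2 * 64)   # 2x channels: shift + scale
out = cond_with_context(inputs, cond_layer, context,
                        cond_type='affine', cond_act='identity')
assert out.shape == inputs.shape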
Example 2
class SelfAttentionND(tf.keras.layers.Layer):  # enclosing class assumed; base class is an assumption
    """N-D self-attention layer; only __init__ is shown in this excerpt."""

    def __init__(self,
                 hidden_size,
                 num_heads=1,
                 num_channels_per_head=None,
                 mask=None,
                 kernel_initializer='glorot_uniform',
                 nd_block_size=None,
                 resolution=None,
                 cond_init='glorot_uniform',
                 cond_k=False,
                 cond_q=False,
                 cond_v=False,
                 cond_scale=False,
                 cond_act='identity',
                 **kwargs):
        super(SelfAttentionND, self).__init__(**kwargs)
        if nd_block_size:
            nd_block_size = list(nd_block_size)
        num_channels_per_head = num_channels_per_head or hidden_size // num_heads
        self.num_filters = [num_heads, num_channels_per_head]
        self.kernel_initializer = kernel_initializer
        self.hidden_size = hidden_size
        self.cond_k = cond_k
        self.cond_q = cond_q
        self.cond_v = cond_v
        self.cond_scale = cond_scale
        self.cond_init = cond_init
        self.cond_act_func = base_utils.act_to_func(cond_act)
        self.project_cond_q, self.project_cond_k, self.project_cond_v = None, None, None
        self.cond_filters = self.num_filters
        if cond_scale:
            self.cond_filters = [num_heads, 2 * num_channels_per_head]

        self.nd_block_size = nd_block_size
        self.resolution = resolution
        self.mask = mask
        self.num_channels_per_head = num_channels_per_head
        self.num_heads = num_heads

        # By default, apply attention over the third-to-last dimension;
        # the last two dimensions are heads and channels.
        self.attention_dim_q = self.attention_dim_k = -3

        # Self attention type.
        self.is_block_attention = bool(self.nd_block_size)
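
A hedged construction sketch (all argument values are illustrative assumptions): with cond_scale=True the conditioning projections emit 2x channels per head, so cond_filters doubles relative to num_filters, and passing nd_block_size switches the layer into block attention.

# Construction sketch; argument values are illustrative assumptions.
attn = SelfAttentionND(
    hidden_size=256,
    num_heads=4,               # 256 // 4 = 64 channels per head
    nd_block_size=[8, 8],      # switches on block attention
    resolution=[32, 32],
    cond_q=True, cond_k=True, cond_v=True,
    cond_scale=True)
assert attn.num_filters == [4, 64]
assert attn.cond_filters == [4, 128]  # doubled because cond_scale=True
assert attn.is_block_attention        # nd_block_size was provided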
Example 3
class ConditionalLayerNorm(tf.keras.layers.Layer):  # enclosing class assumed; base class is an assumption
    """Conditional layer norm; only __init__ is shown in this excerpt."""

    def __init__(self,
                 spatial_average='learnable',
                 sequence='sc',
                 out_init='glorot_uniform',
                 out_act='identity',
                 **kwargs):
        super(ConditionalLayerNorm, self).__init__(**kwargs)
        self.spatial_average = spatial_average
        self.sequence = sequence
        self.out_init = out_init
        self.out_act = out_act
        self.out_act_func = base_utils.act_to_func(out_act)
        if self.spatial_average not in ['mean', 'learnable']:
            raise ValueError(
                'Expected spatial_average to be "mean" or "learnable", '
                'got %s' % self.spatial_average)
        if self.sequence not in ['sc', 'cs']:
            raise ValueError(
                'Expected sequence to be "sc" or "cs", '
                'got %s' % self.sequence)
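
A short construction sketch exercising the validation above (keyword values are illustrative):

# Valid construction uses the documented choices for both flags.
norm = ConditionalLayerNorm(spatial_average='mean', sequence='cs')

# Anything else fails fast with the ValueError raised above.
try:
    ConditionalLayerNorm(spatial_average='max')
except ValueError as err:
    print(err)  # Expected spatial_average to be "mean" or "learnable", got max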