Example no. 1
    def __init__(self, configs: Configs):
        # Get the device
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        # Initialize the dataset
        self.dataset = TinyShakespeareDataset(configs.seq_len)
        # Initialize the dataloader
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=configs.batch_size,
                                     collate_fn=transpose_batch,
                                     shuffle=True)

        # FFN with Gated Linear Unit
        # $$FFN_{GLU}(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$
        if configs.glu_variant == 'GLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.Sigmoid(), True, False, False, False)
        # FFN with Bilinear hidden layer
        # $$FFN_{Bilinear}(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$
        elif configs.glu_variant == 'Bilinear':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.Identity(), True, False, False, False)
        # FFN with ReLU gate
        # $$FFN_{ReGLU}(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$
        elif configs.glu_variant == 'ReGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.ReLU(), True, False, False, False)
        # FFN with GELU gate
        # $$FFN_{GEGLU}(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$
        elif configs.glu_variant == 'GEGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.GELU(), True, False, False, False)
        # FFN with Swish gate
        # $$FFN_{SwiGLU}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$
        # where $\text{Swish}_\beta(x) = x \sigma(\beta x)$
        elif configs.glu_variant == 'SwiGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.SiLU(), True, False, False, False)
        # FFN with ReLU activation
        # $$FFN_{ReLU}(x, W_1, W_2, b_1, b_2) = \text{ReLU}(x W_1 + b_1) W_2 + b_2$$
        elif configs.glu_variant == 'ReLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.ReLU())
        # FFN with GELU activation
        # $$FFN_{GELU}(x, W_1, W_2, b_1, b_2) = \text{GELU}(x W_1 + b_1) W_2 + b_2$$
        elif configs.glu_variant == 'GELU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.GELU())
        else:
            raise ValueError(f'Unknown variant {configs.glu_variant}')

        # Number of different characters
        n_chars = len(self.dataset.stoi)

        # Initialize [Multi-Head Attention module](../mha.html)
        mha = MultiHeadAttention(configs.n_heads, configs.d_model,
                                 configs.dropout)
        # Initialize the [Transformer Block](../models.html#TransformerLayer)
        transformer_layer = TransformerLayer(d_model=configs.d_model,
                                             self_attn=mha,
                                             src_attn=None,
                                             feed_forward=ffn,
                                             dropout_prob=configs.dropout)
        # Initialize the model with an
        # [embedding layer](../models.html#EmbeddingsWithPositionalEncoding)
        # (with fixed positional encoding)
        # [transformer encoder](../models.html#Encoder) and
        # a linear layer to generate logits.
        self.model = AutoregressiveModel(
            EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
            Encoder(transformer_layer, configs.n_layers),
            nn.Linear(configs.d_model, n_chars))

        # Move the model to the current device
        self.model.to(self.device)

        # Initialize [Noam optimizer](../../optimizers/noam.html)
        self.optimizer = Noam(self.model.parameters(),
                              lr=1.0,
                              warmup=2_000,
                              d_model=configs.d_model)

        # Cross-entropy loss
        self.loss_func = nn.CrossEntropyLoss()
        # Number of training epochs;
        # *note that our dataset definition repeats the data `seq_len` times in a single epoch
        self.epochs = configs.epochs
        # Gradient clipping norm
        self.grad_norm_clip = configs.grad_norm_clip

        # Set tracker configurations
        tracker.set_scalar("loss.*", True)
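The variants selected above differ only in the activation applied to the gate projection. For illustration, here is a minimal self-contained sketch of the SwiGLU case, written independently of the `FeedForward` class used in the snippet:

import torch
import torch.nn as nn

class SwiGLUSketch(nn.Module):
    """Illustrative gated FFN: (SiLU(x W1) * x V) W2, without biases."""
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection W1
        self.v = nn.Linear(d_model, d_ff, bias=False)   # value projection V
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # output projection W2
        self.act = nn.SiLU()                            # Swish with beta = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.act(self.w1(x)) * self.v(x))

x = torch.randn(2, 8, 64)
print(SwiGLUSketch(64, 256)(x).shape)  # torch.Size([2, 8, 64])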
Example no. 2
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=80,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 hybrid_backbone=None,
                 norm_layer=None,
                 init_values=None,
                 use_checkpoint=False,
                 use_abs_pos_emb=True,
                 use_rel_pos_bias=False,
                 use_shared_rel_pos_bias=False,
                 out_indices=[3, 5, 7, 11]):
        super().__init__()
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(hybrid_backbone,
                                           img_size=img_size,
                                           in_chans=in_chans,
                                           embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(img_size=img_size,
                                          patch_size=patch_size,
                                          in_chans=in_chans,
                                          embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.out_indices = out_indices

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            self.pos_embed = nn.Parameter(
                torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(
                window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
               ]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim,
                  num_heads=num_heads,
                  mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias,
                  qk_scale=qk_scale,
                  drop=drop_rate,
                  attn_drop=attn_drop_rate,
                  drop_path=dpr[i],
                  norm_layer=norm_layer,
                  init_values=init_values,
                  window_size=self.patch_embed.patch_shape
                  if use_rel_pos_bias else None) for i in range(depth)
        ])

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        # trunc_normal_(self.mask_token, std=.02)
        self.out_indices = out_indices

        if patch_size == 16:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim,
                                   embed_dim,
                                   kernel_size=2,
                                   stride=2),
                nn.SyncBatchNorm(embed_dim),
                nn.GELU(),
                nn.ConvTranspose2d(embed_dim,
                                   embed_dim,
                                   kernel_size=2,
                                   stride=2),
            )

            self.fpn2 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim,
                                   embed_dim,
                                   kernel_size=2,
                                   stride=2), )

            self.fpn3 = nn.Identity()

            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
        elif patch_size == 8:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim,
                                   embed_dim,
                                   kernel_size=2,
                                   stride=2), )

            self.fpn2 = nn.Identity()

            self.fpn3 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2), )

            self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=4, stride=4), )
        self.apply(self._init_weights)
        self.fix_init_weight()
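A quick numeric illustration (added here, not part of the original snippet) of the stochastic depth decay rule used above: drop-path probabilities grow linearly from 0 to `drop_path_rate` across the transformer blocks.

import torch

depth, drop_path_rate = 4, 0.1  # small illustrative values; the constructor above defaults to depth=12
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
print([round(p, 3) for p in dpr])  # [0.0, 0.033, 0.067, 0.1] -> later blocks are dropped more often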
Example no. 3
 def __setstate__(self, state):
     if 'activation' not in state:
         warnings.warn(message="'state' does not contain 'activation'. 'nn.GELU()' is used as default.")
         state['activation'] = nn.GELU()
     super(CustomTransformerEncoderLayer, self).__setstate__(state)
Example no. 4
    def __init__(self, in_features, start_index=1):
        super(ProjectReadout, self).__init__()
        self.start_index = start_index

        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features),
                                     nn.GELU())
Example no. 5
 def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
     super().__init__()
     self.ffn = FeedForward(d_model, d_ff, dropout, nn.GELU())
Example no. 6
 def _get_activation_fn(self, activation):
     if callable(activation): return activation()
     elif activation.lower() == "relu": return nn.ReLU()
     elif activation.lower() == "gelu": return nn.GELU()
     raise ValueError(f'{activation} is not available. You can use "relu", "gelu", or a callable')
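A standalone, runnable copy of the helper above (the method itself needs a `self`), showing that it accepts either a string or a callable such as the `nn.ReLU` class:

import torch.nn as nn

def get_activation_fn(activation):
    # same logic as the method above, lifted out of the class for illustration
    if callable(activation):
        return activation()
    elif activation.lower() == "relu":
        return nn.ReLU()
    elif activation.lower() == "gelu":
        return nn.GELU()
    raise ValueError(f'{activation} is not available. You can use "relu", "gelu", or a callable')

print(get_activation_fn("gelu"))   # an nn.GELU() instance
print(get_activation_fn(nn.ReLU))  # the class is callable and gets instantiated as nn.ReLU()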
Example no. 7
def conv_3x3_bn(inp, oup, image_size, downsample=False):
    stride = 1 if downsample == False else 2
    return nn.Sequential(nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                         nn.BatchNorm2d(oup), nn.GELU())
Example no. 8
    def __init__(self, dim, kernel_size, groups, window_size, num_head,
                 reduction_ratio):
        super(CoT_Mixer, self).__init__()

        self.num_head = num_head
        hidden_dim = self.num_head * window_size**2

        self.key_embed = nn.Sequential(
            nn.Conv2d(
                in_channels=dim,
                out_channels=dim,
                kernel_size=kernel_size,
                stride=1,
                padding=kernel_size // 2,
                dilation=1,
                groups=groups,
                bias=False,
            ), nn.BatchNorm2d(num_features=dim), nn.ReLU(inplace=True))

        self.attn_map = nn.Sequential(
            nn.Conv2d(
                in_channels=dim * 2,
                out_channels=dim,
                kernel_size=1,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                bias=False,
            ), nn.BatchNorm2d(num_features=dim), nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=dim,
                out_channels=hidden_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                bias=False,
            ), nn.BatchNorm2d(num_features=hidden_dim))

        self.value_embed = nn.Sequential(
            nn.Conv2d(
                in_channels=dim,
                out_channels=dim,
                kernel_size=1,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                bias=False,
            ), nn.BatchNorm2d(num_features=dim))

        self.pool = nn.AdaptiveAvgPool2d(output_size=window_size)

        self.norm = nn.LayerNorm(normalized_shape=dim)

        self.mlp = nn.Sequential(
            nn.Linear(in_features=dim,
                      out_features=dim // reduction_ratio,
                      bias=False), nn.GELU(),
            nn.Linear(in_features=dim // reduction_ratio,
                      out_features=dim,
                      bias=False))
Example no. 9
    def __init__(self,
                 model_size,
                 inner_size,
                 dropout=0.,
                 variational=False,
                 activation='relu',
                 glu=False,
                 weight_drop=0.0,
                 dropout_residual=False,
                 res_dropout=0.0):
        super().__init__()
        self.model_size = model_size
        self.inner_size = inner_size
        self.dropout = dropout
        self.bias = True
        self.variational = variational
        self.activation = activation
        self.glu = glu
        self.weight_drop = weight_drop
        self.autograd = False
        self.fused_dropout_add = False
        self.dropout_residual = dropout_residual
        self.res_dropout = res_dropout

        if self.activation == 'relu':
            if self.glu:
                self.act = nn.ReLU(inplace=True)
            else:
                self.act = ReLUDropout(p=self.dropout,
                                       variational=self.variational,
                                       batch_first=False)
        elif self.activation == 'gelu':
            self.act = nn.GELU()
        elif self.activation == 'agelu':
            self.act = AGELU()
        elif self.activation in ['silu', 'swish']:
            self.act = SiLU()
        elif self.activation in ['sigmoid']:
            if self.glu:
                self.act = nn.functional.glu
            else:
                print(
                    "Sigmoid activation function is recommended to be used with -glu"
                )
                raise NotImplementedError

        self.in_proj_weight = Parameter(
            torch.Tensor(inner_size * (2 if glu else 1), model_size))
        self.out_proj_weight = Parameter(torch.Tensor(model_size, inner_size))

        self.in_proj_bias = Parameter(
            torch.Tensor(inner_size * (2 if glu else 1)))
        self.out_proj_bias = Parameter(torch.Tensor(model_size))

        self.reset_parameters()

        self.fused = False

        # At the moment fused mlp is supported for RELU, SiLU, Swish, GELU and AGELU (approximated GELU)
        if not self.glu and \
                self.activation in ['relu', 'silu', 'swish', 'gelu', 'agelu'] and not self.variational:
            if self.activation == 'relu':
                from onmt.modules.mlp.mlp import mlp_relu_function
                if mlp_relu_function is not None:
                    self.fused_function = mlp_relu_function
                    self.fused = True
            elif self.activation in ['silu', 'swish']:
                from onmt.modules.mlp.mlp import mlp_silu_function
                if mlp_silu_function is not None:
                    self.fused_function = mlp_silu_function
                    self.fused = True
            elif self.activation == 'gelu':
                if self.dropout_residual:
                    from onmt.modules.mlp.mlp import mlp_gelu_dropout_add_function
                    if mlp_gelu_dropout_add_function is not None:
                        self.fused_function = mlp_gelu_dropout_add_function
                        self.fused = True
                        self.fused_dropout_add = True
                if not self.fused:
                    from onmt.modules.mlp.mlp import mlp_gelu_function
                    if mlp_gelu_function is not None:
                        self.fused_function = mlp_gelu_function
                        self.fused = True
            elif self.activation == 'agelu':
                from onmt.modules.mlp.mlp import mlp_agelu_function
                if mlp_agelu_function is not None:
                    self.fused_function = mlp_agelu_function
                    self.fused = True
Example no. 10
    def __init__(self):
        super(Model, self).__init__()

        self.act_0 = nn.GELU()
Example no. 11
 def __init__(self, in_features, hidden_features, out_features, p=0.):
     super().__init__()
     self.fc1 = nn.Linear(in_features, hidden_features)
     self.act = nn.GELU()
     self.fc2 = nn.Linear(hidden_features, out_features)
     self.drop = nn.Dropout(p)
Example no. 12
 def __init__(self, d_model, d_ff, dropout=0.1):
     super(PositionwiseFeedForward, self).__init__()
     self.w_1 = nn.Linear(d_model, d_ff)
     self.w_2 = nn.Linear(d_ff, d_model)
     self.dropout = nn.Dropout(dropout)
     self.activation = nn.GELU()
Example no. 13
 def __init__(self, eps=1e-8):
     super().__init__()
     self.eps = eps
     self.fn = nn.Sequential(nn.LayerNorm(1), nn.GELU())
Example no. 14
def gelu():
    return nn.GELU()
Example no. 15
    def __init__(self,
                 in_channels,
                 out_channels,
                 g_spectral_norm,
                 activation_fn,
                 conditional_bn,
                 z_dims_after_concat,
                 synchronized_bn,
                 upsample,
                 channel_ratio=4):
        super(GenBlock, self).__init__()
        self.conditional_bn = conditional_bn
        self.in_channels, self.out_channels = in_channels, out_channels
        self.upsample = upsample
        self.hidden_channels = self.in_channels // channel_ratio

        if self.conditional_bn:
            self.bn1 = ConditionalBatchNorm2d_for_skip_and_shared(
                num_features=in_channels,
                z_dims_after_concat=z_dims_after_concat,
                spectral_norm=g_spectral_norm,
                synchronized_bn=synchronized_bn)
            self.bn2 = ConditionalBatchNorm2d_for_skip_and_shared(
                num_features=self.hidden_channels,
                z_dims_after_concat=z_dims_after_concat,
                spectral_norm=g_spectral_norm,
                synchronized_bn=synchronized_bn)
            self.bn3 = ConditionalBatchNorm2d_for_skip_and_shared(
                num_features=self.hidden_channels,
                z_dims_after_concat=z_dims_after_concat,
                spectral_norm=g_spectral_norm,
                synchronized_bn=synchronized_bn)
            self.bn4 = ConditionalBatchNorm2d_for_skip_and_shared(
                num_features=self.hidden_channels,
                z_dims_after_concat=z_dims_after_concat,
                spectral_norm=g_spectral_norm,
                synchronized_bn=synchronized_bn)
        else:
            if synchronized_bn:
                self.bn1 = sync_batchnorm_2d(in_features=in_channels)
                self.bn2 = sync_batchnorm_2d(in_features=self.hidden_channels)
                self.bn3 = sync_batchnorm_2d(in_features=self.hidden_channels)
                self.bn4 = sync_batchnorm_2d(in_features=self.hidden_channels)
            else:
                self.bn1 = batchnorm_2d(in_features=in_channels)
                self.bn2 = batchnorm_2d(in_features=self.hidden_channels)
                self.bn3 = batchnorm_2d(in_features=self.hidden_channels)
                self.bn4 = batchnorm_2d(in_features=self.hidden_channels)

        if activation_fn == "ReLU":
            self.activation = nn.ReLU(inplace=True)
        elif activation_fn == "Leaky_ReLU":
            self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        elif activation_fn == "ELU":
            self.activation = nn.ELU(alpha=1.0, inplace=True)
        elif activation_fn == "GELU":
            self.activation = nn.GELU()
        else:
            raise NotImplementedError

        if g_spectral_norm:
            self.conv2d1 = snconv2d(in_channels=in_channels,
                                    out_channels=self.hidden_channels,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0)
            self.conv2d2 = snconv2d(in_channels=self.hidden_channels,
                                    out_channels=self.hidden_channels,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1)
            self.conv2d3 = snconv2d(in_channels=self.hidden_channels,
                                    out_channels=self.hidden_channels,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1)
            self.conv2d4 = snconv2d(in_channels=self.hidden_channels,
                                    out_channels=out_channels,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0)
        else:
            self.conv2d1 = conv2d(in_channels=in_channels,
                                  out_channels=self.hidden_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
            self.conv2d2 = conv2d(in_channels=self.hidden_channels,
                                  out_channels=self.hidden_channels,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)
            self.conv2d3 = conv2d(in_channels=self.hidden_channels,
                                  out_channels=self.hidden_channels,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)
            self.conv2d4 = conv2d(in_channels=self.hidden_channels,
                                  out_channels=out_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
Example no. 16
class _GELU(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


_act_map = {
    "none": None,
    "relu": nn.ReLU(),
    "tanh": nn.Tanh(),
    "softmax": nn.Softmax(dim=-1),
    "sigmoid": nn.Sigmoid(),
    "leaky_relu": nn.LeakyReLU(1 / 5.5),
    "prelu": nn.PReLU(),
    "gelu": nn.GELU() if torch.__version__ >= "1.4.0" else _GELU()
}


def map_activation_str_to_layer(act_str):
    try:
        return _act_map[act_str]
    except KeyError:
        raise NotImplementedError(
            "Error: %s activation function is not supported now." % (act_str))


def anneal_fn(fn, t, T, lambda0=0.0, lambda1=1.0):
    if not fn or fn == "none":
        return lambda1
    elif fn == "logistic":
Example no. 17
    def __init__(self, img_size, d_conv_dim, d_spectral_norm, attention,
                 attention_after_nth_dis_block, activation_fn,
                 conditional_strategy, hypersphere_dim, num_classes,
                 nonlinear_embed, normalize_embed, synchronized_bn, initialize,
                 D_depth):
        super(Discriminator, self).__init__()
        d_in_dims_collection = {
            "32": [3] + [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2],
            "64":
            [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8],
            "96": [3] + [
                d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8,
                d_conv_dim * 16
            ],
            "128": [3] + [
                d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8,
                d_conv_dim * 16
            ]
        }

        d_out_dims_collection = {
            "32":
            [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2],
            "64": [
                d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8,
                d_conv_dim * 16
            ],
            "96": [
                d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8,
                d_conv_dim * 16, d_conv_dim * 16
            ],
            "128": [
                d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8,
                d_conv_dim * 16, d_conv_dim * 16
            ]
        }

        d_down = {
            "32": [False, True, True, False],
            "64": [False, True, True, True, True],
            "96": [False, True, True, True, True, True],
            "128": [False, True, True, True, True, True]
        }

        self.nonlinear_embed = nonlinear_embed
        self.normalize_embed = normalize_embed
        self.conditional_strategy = conditional_strategy

        self.in_dims = d_in_dims_collection[str(img_size)]
        self.out_dims = d_out_dims_collection[str(img_size)]
        down = d_down[str(img_size)]

        if d_spectral_norm:
            self.input_conv = snconv2d(in_channels=self.in_dims[0],
                                       out_channels=self.out_dims[0],
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)
        else:
            self.input_conv = conv2d(in_channels=self.in_dims[0],
                                     out_channels=self.out_dims[0],
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)

        self.blocks = []
        for index in range(len(self.in_dims)):
            if index == 0:
                self.blocks += [[self.input_conv]]
            else:
                self.blocks += [[
                    DiscBlock(in_channels=self.in_dims[index]
                              if d_index == 0 else self.out_dims[index],
                              out_channels=self.out_dims[index],
                              d_spectral_norm=d_spectral_norm,
                              activation_fn=activation_fn,
                              synchronized_bn=synchronized_bn,
                              downsample=True
                              if down[index] and d_index == 0 else False)
                ] for d_index in range(D_depth)]

            if index == attention_after_nth_dis_block and attention is True:
                self.blocks += [[
                    Self_Attn(self.out_dims[index], d_spectral_norm)
                ]]

        self.blocks = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks])

        if activation_fn == "ReLU":
            self.activation = nn.ReLU(inplace=True)
        elif activation_fn == "Leaky_ReLU":
            self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        elif activation_fn == "ELU":
            self.activation = nn.ELU(alpha=1.0, inplace=True)
        elif activation_fn == "GELU":
            self.activation = nn.GELU()
        else:
            raise NotImplementedError

        if d_spectral_norm:
            self.linear1 = snlinear(in_features=self.out_dims[-1],
                                    out_features=1)
            if self.conditional_strategy in [
                    'ContraGAN', 'Proxy_NCA_GAN', 'XT_Xent_GAN'
            ]:
                self.linear2 = snlinear(in_features=self.out_dims[-1],
                                        out_features=hypersphere_dim)
                if self.nonlinear_embed:
                    self.linear3 = snlinear(in_features=hypersphere_dim,
                                            out_features=hypersphere_dim)
                self.embedding = sn_embedding(num_classes, hypersphere_dim)
            elif self.conditional_strategy == 'cGAN':
                self.embedding = sn_embedding(num_classes, self.out_dims[-1])
            elif self.conditional_strategy == 'ACGAN':
                self.linear4 = snlinear(in_features=self.out_dims[-1],
                                        out_features=num_classes)
            else:
                pass
        else:
            self.linear1 = linear(in_features=self.out_dims[-1],
                                  out_features=1)
            if self.conditional_strategy in [
                    'ContraGAN', 'Proxy_NCA_GAN', 'XT_Xent_GAN'
            ]:
                self.linear2 = linear(in_features=self.out_dims[-1],
                                      out_features=hypersphere_dim)
                if self.nonlinear_embed:
                    self.linear3 = linear(in_features=hypersphere_dim,
                                          out_features=hypersphere_dim)
                self.embedding = embedding(num_classes, hypersphere_dim)
            elif self.conditional_strategy == 'cGAN':
                self.embedding = embedding(num_classes, self.out_dims[-1])
            elif self.conditional_strategy == 'ACGAN':
                self.linear4 = linear(in_features=self.out_dims[-1],
                                      out_features=num_classes)
            else:
                pass

        # Weight init
        if initialize is not False:
            init_weights(self.modules, initialize)
Example no. 18
    def __init__(self,
                 block,
                 layers,
                 num_classes=1000,
                 zero_init_residual=False,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = groups * width_per_group
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(
                                 replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(groups * self.base_width,
                               self.inplanes,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.gelu = nn.GELU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.layer1 = self._make_layer(block,
                                       self.inplanes,
                                       layers[0],
                                       stride=2)
        self.layer2 = self._make_layer(block,
                                       self.inplanes,
                                       layers[1],
                                       stride=4,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       self.inplanes,
                                       layers[2],
                                       stride=4,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       self.inplanes,
                                       layers[3],
                                       stride=4,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(self.inplanes, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
Example no. 19
 def __init__(self, inp, oup, expansion=0.25):
     super().__init__()
     self.avg_pool = nn.AdaptiveAvgPool2d(1)
     self.fc = nn.Sequential(
         nn.Linear(oup, int(inp * expansion), bias=False), nn.GELU(),
         nn.Linear(int(inp * expansion), oup, bias=False), nn.Sigmoid())
Example no. 20
    def __init__(self, cfg: Wav2VecConfig):
        super().__init__()

        self.prediction_steps = cfg.prediction_steps
        offset = cfg.offset

        if cfg.activation == "relu":
            activation = nn.ReLU()
        elif cfg.activation == "gelu":
            activation = nn.GELU()
        else:
            raise Exception("unknown activation " + cfg.activation)

        feature_enc_layers = eval(cfg.conv_feature_layers)
        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            log_compression=cfg.log_compression,
            skip_connections=cfg.skip_connections_feat,
            residual_scale=cfg.residual_scale,
            non_affine_group_norm=cfg.non_affine_group_norm,
            activation=activation,
        )
        embed = feature_enc_layers[-1][0]

        self.vector_quantizer = None
        if cfg.vq_type == "gumbel":
            self.vector_quantizer = GumbelVectorQuantizer(
                dim=embed,
                num_vars=cfg.vq_vars,
                temp=cfg.vq_temp,
                groups=cfg.vq_groups,
                combine_groups=cfg.combine_groups,
                vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed,
                time_first=False,
                activation=activation,
                weight_proj_depth=cfg.vq_depth,
                weight_proj_factor=2,
            )
        elif cfg.vq_type == "kmeans":
            self.vector_quantizer = KmeansVectorQuantizer(
                dim=embed,
                num_vars=cfg.vq_vars,
                groups=cfg.vq_groups,
                combine_groups=cfg.combine_groups,
                vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed,
                time_first=False,
                gamma=cfg.vq_gamma,
            )
        else:
            assert (
                cfg.vq_type == "none" or cfg.vq_type is None
            ), "Unknown quantizer type"

        if cfg.offset == "auto":
            jin = 0
            rin = 0
            for _, k, stride in feature_enc_layers:
                if rin == 0:
                    rin = k
                rin = rin + (k - 1) * jin
                if jin == 0:
                    jin = stride
                else:
                    jin *= stride
            offset = math.ceil(rin / jin)

        offset = int(offset)

        def make_aggregator():
            if cfg.aggregator == "cnn":
                agg_layers = eval(cfg.conv_aggregator_layers)
                agg_dim = agg_layers[-1][0]
                feature_aggregator = ConvAggegator(
                    conv_layers=agg_layers,
                    embed=embed,
                    dropout=cfg.dropout,
                    skip_connections=cfg.skip_connections_agg,
                    residual_scale=cfg.residual_scale,
                    non_affine_group_norm=cfg.non_affine_group_norm,
                    conv_bias=not cfg.no_conv_bias,
                    zero_pad=cfg.agg_zero_pad,
                    activation=activation,
                )
            elif cfg.aggregator == "gru":
                agg_dim = cfg.gru_dim
                feature_aggregator = nn.Sequential(
                    TransposeLast(),
                    nn.GRU(
                        input_size=embed,
                        hidden_size=agg_dim,
                        num_layers=1,
                        dropout=cfg.dropout,
                    ),
                    TransposeLast(deconstruct_idx=0),
                )
            else:
                raise Exception("unknown aggregator type " + cfg.aggregator)

            return feature_aggregator, agg_dim

        self.feature_aggregator, agg_dim = make_aggregator()

        self.wav2vec_predictions = Wav2VecPredictionsModel(
            in_dim=agg_dim,
            out_dim=embed,
            prediction_steps=cfg.prediction_steps,
            n_negatives=cfg.num_negatives,
            cross_sample_negatives=cfg.cross_sample_negatives,
            sample_distance=cfg.sample_distance,
            dropout=cfg.dropout,
            offset=offset,
            balanced_classes=cfg.balanced_classes,
            infonce=cfg.infonce,
        )

        self.dropout_feats = nn.Dropout(p=cfg.dropout_features)
        self.dropout_agg = nn.Dropout(p=cfg.dropout_agg)

        if cfg.project_features == "none":
            self.project_features = None
        elif cfg.project_features == "same":
            self.project_features = self.feature_aggregator
        elif cfg.project_features == "new":
            self.project_features, _ = make_aggregator()
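The `offset == "auto"` branch above derives the prediction offset from the receptive field and total stride of the conv feature extractor. Here is a standalone trace of that loop with made-up layer shapes (not the actual `cfg.conv_feature_layers`):

import math

# (out_channels, kernel_size, stride) per conv layer; illustrative values only
feature_enc_layers = [(512, 10, 5), (512, 8, 4), (512, 4, 2)]

jin, rin = 0, 0
for _, k, stride in feature_enc_layers:
    if rin == 0:
        rin = k
    rin = rin + (k - 1) * jin                   # receptive field grows by (k - 1) * current total stride
    jin = stride if jin == 0 else jin * stride  # total stride is the product of layer strides
offset = math.ceil(rin / jin)
print(rin, jin, offset)  # 105 40 3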
Example no. 21
    def __init__(self,
                 C,
                 steps=3,
                 reduction=8,
                 se=False,
                 genotype=None,
                 drop_prob=0.,
                 mlp_ratio=4):
        super(RFConvNeXtAttention, self).__init__()
        assert genotype is not None
        self._ops = nn.ModuleList()
        self._C = C
        self._steps = steps
        self._stride = 1
        self._se = se
        self.C_in = C
        self.conv3x3 = False
        self.reduction = reduction
        # self.norm1 = nn.BatchNorm2d(C)
        self.norm1 = nn.LayerNorm(C, eps=1e-6)

        self.genotype = genotype
        op_names, indices = zip(*self.genotype.normal)
        concat = genotype.normal_concat

        self.bottle = nn.Conv2d(C,
                                C // self.reduction,
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False)

        self.drop_path = DropPath(
            drop_prob) if drop_prob > 0. else nn.Identity()

        self.norm2 = nn.BatchNorm2d(C)

        mlp_hidden_dim = int(C // mlp_ratio)

        # pointwise/1x1 convs, implemented with linear layers
        self.pwconv1 = nn.Linear(C, mlp_hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(mlp_hidden_dim, C)

        if self._se:
            self.se = SE(self.C_in, reduction=self.reduction)

        if self.conv3x3:
            self.conv3x3 = nn.Conv2d(C // self.reduction * self._steps,
                                     C,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1,
                                     bias=False)
        else:
            self.conv1x1 = nn.Conv2d(C // self.reduction * self._steps,
                                     C,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0,
                                     bias=False)

        self._compile(C, op_names, indices, concat)
Example no. 22
    plt.legend(loc='best')
    plt.show()


# Function that computes prediction accuracy:
def get_accuracy(model, data_loader):
    correct, total = 0, 0  # number of correct predictions, number of samples
    for tweets, labels in data_loader:  # iterate over the data
        output = model(tweets)  # run the model
        pred = output.max(1, keepdim=True)[1]  # predicted category
        correct += pred.eq(labels.view_as(pred)).sum().item()
        total += labels.shape[0]
    return correct / total


mymodel = nn.Sequential(nn.Linear(200, 100), nn.GELU(), nn.Linear(100, 2))

train_network(mymodel,
              train_loader,
              valid_loader,
              num_epochs=1000,
              learning_rate=0.0000001)

print("Final test accuracy:",
      get_accuracy(mymodel, test_loader))  # accuracy on the test set


def test_model(model, glove_vector, re):
    emb = sum(glove_vector[w]
              for w in re)  # turn the tweet into a sum of its word embeddings
    out = model(emb.unsqueeze(0))  # what the model predicts
Example no. 23
 def __init__(self, dim, mult=4, dropout=0.):
     super().__init__()
     self.net = nn.Sequential(nn.Linear(dim, dim * mult), nn.GELU(),
                              nn.Dropout(dropout),
                              nn.Linear(dim * mult, dim))
Example no. 24
 def __init__(self, d_model):
     super(position_wise_FFNN, self).__init__()
     self.linear1 = nn.Linear(d_model, d_model *
                              4)  # the Transformer uses a 4x expansion, so the hidden size is set to 4 * d_model :)
     self.GELU = nn.GELU()
     self.linear2 = nn.Linear(d_model * 4, d_model)
Example no. 25
    def __init__(self, img_size, d_conv_dim, d_spectral_norm, attention,
                 attention_after_nth_dis_block, activation_fn,
                 conditional_strategy, hypersphere_dim, num_classes,
                 nonlinear_embed, normalize_embed, initialize, D_depth,
                 mixed_precision):
        super(Discriminator, self).__init__()
        self.in_dims = [3] + [64, 128]
        self.out_dims = [64, 128, 256]

        self.d_spectral_norm = d_spectral_norm
        self.conditional_strategy = conditional_strategy
        self.num_classes = num_classes
        self.nonlinear_embed = nonlinear_embed
        self.normalize_embed = normalize_embed
        self.mixed_precision = mixed_precision

        self.blocks = []
        for index in range(len(self.in_dims)):
            self.blocks += [[
                DiscBlock(in_channels=self.in_dims[index],
                          out_channels=self.out_dims[index],
                          d_spectral_norm=d_spectral_norm,
                          activation_fn=activation_fn)
            ]]

            if index + 1 == attention_after_nth_dis_block and attention is True:
                self.blocks += [[
                    Self_Attn(self.out_dims[index], d_spectral_norm)
                ]]

        self.blocks = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks])

        if self.d_spectral_norm:
            self.conv = snconv2d(in_channels=256,
                                 out_channels=512,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1)
        else:
            self.conv = conv2d(in_channels=256,
                               out_channels=512,
                               kernel_size=3,
                               stride=1,
                               padding=1)
            self.bn = batchnorm_2d(in_features=512)

        if activation_fn == "ReLU":
            self.activation = nn.ReLU(inplace=True)
        elif activation_fn == "Leaky_ReLU":
            self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        elif activation_fn == "ELU":
            self.activation = nn.ELU(alpha=1.0, inplace=True)
        elif activation_fn == "GELU":
            self.activation = nn.GELU()
        else:
            raise NotImplementedError

        if d_spectral_norm:
            self.linear1 = snlinear(in_features=512, out_features=1)
            if self.conditional_strategy in [
                    'ContraGAN', 'Proxy_NCA_GAN', 'NT_Xent_GAN'
            ]:
                self.linear2 = snlinear(in_features=512,
                                        out_features=hypersphere_dim)
                if self.nonlinear_embed:
                    self.linear3 = snlinear(in_features=hypersphere_dim,
                                            out_features=hypersphere_dim)
                self.embedding = sn_embedding(num_classes, hypersphere_dim)
            elif self.conditional_strategy == 'ProjGAN':
                self.embedding = sn_embedding(num_classes, 512)
            elif self.conditional_strategy == 'ACGAN':
                self.linear4 = snlinear(in_features=512,
                                        out_features=num_classes)
            else:
                pass
        else:
            self.linear1 = linear(in_features=512, out_features=1)
            if self.conditional_strategy in [
                    'ContraGAN', 'Proxy_NCA_GAN', 'NT_Xent_GAN'
            ]:
                self.linear2 = linear(in_features=512,
                                      out_features=hypersphere_dim)
                if self.nonlinear_embed:
                    self.linear3 = linear(in_features=hypersphere_dim,
                                          out_features=hypersphere_dim)
                self.embedding = embedding(num_classes, hypersphere_dim)
            elif self.conditional_strategy == 'ProjGAN':
                self.embedding = embedding(num_classes, 512)
            elif self.conditional_strategy == 'ACGAN':
                self.linear4 = linear(in_features=512,
                                      out_features=num_classes)
            else:
                pass

        # Weight init
        if initialize is not False:
            init_weights(self.modules, initialize)
Example no. 26
    def __init__(self,
                 block,
                 num_blocks,
                 num_classes=10,
                 normalize=False,
                 normalize_only_FN=False,
                 scale=15,
                 activation='ReLU',
                 softplus_beta=1):
        super(PreActResNet, self).__init__()
        self.in_planes = 64

        self.normalize = normalize
        self.normalize_only_FN = normalize_only_FN
        self.scale = scale

        self.activation = activation
        self.softplus_beta = softplus_beta

        self.conv1 = nn.Conv2d(3,
                               64,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.bn = normal_func(512 * block.expansion,
                              track_running_stats=track_running_stats,
                              affine=affine)

        if self.normalize:
            self.linear = nn.Linear(512 * block.expansion,
                                    num_classes,
                                    bias=False)
        else:
            self.linear = nn.Linear(512 * block.expansion, num_classes)

        if activation == 'ReLU':
            self.relu = nn.ReLU(inplace=True)
            print('ReLU')
        elif activation == 'Softplus':
            self.relu = nn.Softplus(beta=softplus_beta, threshold=20)
            print('Softplus')
        elif activation == 'GELU':
            self.relu = nn.GELU()
            print('GELU')
        elif activation == 'ELU':
            self.relu = nn.ELU(alpha=1.0, inplace=True)
            print('ELU')
        elif activation == 'LeakyReLU':
            self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
            print('LeakyReLU')
        elif activation == 'SELU':
            self.relu = nn.SELU(inplace=True)
            print('SELU')
        elif activation == 'CELU':
            self.relu = nn.CELU(alpha=1.2, inplace=True)
            print('CELU')
        elif activation == 'Tanh':
            self.relu = nn.Tanh()
            print('Tanh')
        print('Use activation of ' + activation)
Example no. 27
 def __init__(self, dim, hidden_dim, dropout=0.):
     super().__init__()
     self.net = nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU(),
                              nn.Dropout(dropout),
                              nn.Linear(hidden_dim, dim))
Example no. 28
    def __init__(self, z_dim, shared_dim, img_size, g_conv_dim,
                 g_spectral_norm, attention, attention_after_nth_gen_block,
                 activation_fn, conditional_strategy, num_classes,
                 synchronized_bn, initialize, G_depth):
        super(Generator, self).__init__()
        g_in_dims_collection = {
            "32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],
            "64":
            [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "96": [
                g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8,
                g_conv_dim * 4, g_conv_dim * 2
            ],
            "128": [
                g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8,
                g_conv_dim * 4, g_conv_dim * 2
            ]
        }

        g_out_dims_collection = {
            "32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],
            "64": [g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "96": [
                g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4,
                g_conv_dim * 2, g_conv_dim
            ],
            "128": [
                g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4,
                g_conv_dim * 2, g_conv_dim
            ]
        }

        bottom_collection = {"32": 4, "64": 4, "96": 3, "128": 4}
        self.z_dim = z_dim
        self.shared_dim = shared_dim
        self.num_classes = num_classes
        conditional_bn = True if conditional_strategy in [
            "ACGAN", "cGAN", "ContraGAN", "Proxy_NCA_GAN", "XT_Xent_GAN"
        ] else False

        self.in_dims = g_in_dims_collection[str(img_size)]
        self.out_dims = g_out_dims_collection[str(img_size)]
        self.bottom = bottom_collection[str(img_size)]
        self.n_blocks = len(self.in_dims)
        self.z_dims_after_concat = self.z_dim + self.shared_dim

        if g_spectral_norm:
            self.linear0 = snlinear(in_features=self.z_dims_after_concat,
                                    out_features=self.in_dims[0] *
                                    self.bottom * self.bottom)
        else:
            self.linear0 = linear(in_features=self.z_dims_after_concat,
                                  out_features=self.in_dims[0] * self.bottom *
                                  self.bottom)

        self.shared = embedding(self.num_classes, self.shared_dim)

        self.blocks = []
        for index in range(self.n_blocks):
            self.blocks += [[
                GenBlock(in_channels=self.in_dims[index],
                         out_channels=self.in_dims[index]
                         if g_index == 0 else self.out_dims[index],
                         g_spectral_norm=g_spectral_norm,
                         activation_fn=activation_fn,
                         conditional_bn=conditional_bn,
                         z_dims_after_concat=self.z_dims_after_concat,
                         synchronized_bn=synchronized_bn,
                         upsample=True if g_index == (G_depth - 1) else False)
            ] for g_index in range(G_depth)]

            if index + 1 == attention_after_nth_gen_block and attention is True:
                self.blocks += [[
                    Self_Attn(self.out_dims[index], g_spectral_norm)
                ]]

        self.blocks = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks])

        if synchronized_bn:
            self.bn4 = sync_batchnorm_2d(in_features=self.out_dims[-1])
        else:
            self.bn4 = batchnorm_2d(in_features=self.out_dims[-1])

        if activation_fn == "ReLU":
            self.activation = nn.ReLU(inplace=True)
        elif activation_fn == "Leaky_ReLU":
            self.activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        elif activation_fn == "ELU":
            self.activation = nn.ELU(alpha=1.0, inplace=True)
        elif activation_fn == "GELU":
            self.activation = nn.GELU()
        else:
            raise NotImplementedError

        if g_spectral_norm:
            self.conv2d5 = snconv2d(in_channels=self.out_dims[-1],
                                    out_channels=3,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1)
        else:
            self.conv2d5 = conv2d(in_channels=self.out_dims[-1],
                                  out_channels=3,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

        self.tanh = nn.Tanh()

        # Weight init
        if initialize is not False:
            init_weights(self.modules, initialize)
Example no. 29
    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()

        bert_config = BertConfig(
            vocab_size=config["vocab_size"],
            hidden_size=config["hidden_size"],
            num_hidden_layers=config["num_layers"],
            num_attention_heads=config["num_heads"],
            intermediate_size=config["hidden_size"] * config["mlp_ratio"],
            max_position_embeddings=config["max_text_len"],
            hidden_dropout_prob=config["drop_rate"],
            attention_probs_dropout_prob=config["drop_rate"],
        )
        self.tempeture_max_OT = config['tempeture_max_OT']
        self.text_embeddings = BertEmbeddings(bert_config)
        self.text_embeddings.apply(objectives.init_weights)

        self.token_type_embeddings = nn.Embedding(2, config["hidden_size"])
        self.token_type_embeddings.apply(objectives.init_weights)

        import vilt.modules.vision_transformer as vit

        if self.hparams.config["load_path"] == "":
            self.transformer = getattr(vit, self.hparams.config["vit"])(
                pretrained=config["pretrained_flag"], config=self.hparams.config)
        else:
            self.transformer = getattr(vit, self.hparams.config["vit"])(
                pretrained=False, config=self.hparams.config
            )

        self.pooler = heads.Pooler(config["hidden_size"])
        self.pooler.apply(objectives.init_weights)

        if config["loss_names"]["mlm"] > 0:
            self.mlm_score = heads.MLMHead(bert_config)
            self.mlm_score.apply(objectives.init_weights)

        if config["loss_names"]["itm"] > 0:
            self.itm_score = heads.ITMHead(config["hidden_size"])
            self.itm_score.apply(objectives.init_weights)

        if config["loss_names"]["mpp"] > 0:
            self.mpp_score = heads.MPPHead(bert_config)
            self.mpp_score.apply(objectives.init_weights)

        # ===================== Downstream ===================== #
        if (
            self.hparams.config["load_path"] != ""
            and not self.hparams.config["test_only"]
        ):
            ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")
            state_dict = ckpt["state_dict"]
            self.load_state_dict(state_dict, strict=False)
            print(f'Loading checkpoint from {self.hparams.config["load_path"]}')

        hs = self.hparams.config["hidden_size"]

        if self.hparams.config["loss_names"]["vqa"] > 0:
            vs = self.hparams.config["vqav2_label_size"]
            self.vqa_classifier = nn.Sequential(
                nn.Linear(hs, hs * 2),
                nn.LayerNorm(hs * 2),
                nn.GELU(),
                nn.Linear(hs * 2, vs),
            )
            self.vqa_classifier.apply(objectives.init_weights)

        if self.hparams.config["loss_names"]["nlvr2"] > 0:
            self.nlvr2_classifier = nn.Sequential(
                nn.Linear(hs * 2, hs * 2),
                nn.LayerNorm(hs * 2),
                nn.GELU(),
                nn.Linear(hs * 2, 2),
            )
            self.nlvr2_classifier.apply(objectives.init_weights)
            emb_data = self.token_type_embeddings.weight.data
            self.token_type_embeddings = nn.Embedding(3, hs)
            self.token_type_embeddings.apply(objectives.init_weights)
            self.token_type_embeddings.weight.data[0, :] = emb_data[0, :]
            self.token_type_embeddings.weight.data[1, :] = emb_data[1, :]
            self.token_type_embeddings.weight.data[2, :] = emb_data[1, :]

        if self.hparams.config["loss_names"]["irtr"] > 0:
            self.rank_output = nn.Linear(hs, 1)
            self.rank_output.weight.data = self.itm_score.fc.weight.data[1:, :]
            self.rank_output.bias.data = self.itm_score.fc.bias.data[1:]
            self.margin = 0.2
            for p in self.itm_score.parameters():
                p.requires_grad = False

        vilt_utils.set_metrics(self)
        self.current_tasks = list()

        # ===================== load downstream (test_only) ======================

        if self.hparams.config["load_path"] != "" and self.hparams.config["test_only"]:
            ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")
            state_dict = ckpt["state_dict"]
            self.load_state_dict(state_dict, strict=False)
            print(f'Loading checkpoint from {self.hparams.config["load_path"]}')
Example no. 30
    def __init__(
        self,
        dim,
        num_vars,
        temp,
        groups,
        combine_groups,
        vq_dim,
        time_first,
        activation=nn.GELU(),
        weight_proj_depth=1,
        weight_proj_factor=1,
    ):
        """Vector quantization using gumbel softmax

        Args:
            dim: input dimension (channels)
            num_vars: number of quantized vectors per group
            temp: temperature for training. this should be a tuple of 3 elements: (start, stop, decay factor)
            groups: number of groups for vector quantization
            combine_groups: whether to use the vectors for all groups
            vq_dim: dimensionality of the resulting quantized vector
            time_first: if true, expect input in BxTxC format, otherwise in BxCxT
            activation: what activation to use (should be a module). this is only used if weight_proj_depth is > 1
            weight_proj_depth: number of layers (with activation in between) to project input before computing logits
            weight_proj_factor: this is used only if weight_proj_depth is > 1. scales the inner dimensionality of
                                projections by this factor
        """
        super().__init__()

        self.groups = groups
        self.combine_groups = combine_groups
        self.input_dim = dim
        self.num_vars = num_vars
        self.time_first = time_first

        assert (
            vq_dim % groups == 0
        ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation"

        var_dim = vq_dim // groups
        num_groups = groups if not combine_groups else 1

        self.vars = nn.Parameter(torch.FloatTensor(1, num_groups * num_vars, var_dim))
        nn.init.uniform_(self.vars)

        if weight_proj_depth > 1:

            def block(input_dim, output_dim):
                return nn.Sequential(nn.Linear(input_dim, output_dim), activation)

            inner_dim = self.input_dim * weight_proj_factor
            self.weight_proj = nn.Sequential(
                *[
                    block(self.input_dim if i == 0 else inner_dim, inner_dim)
                    for i in range(weight_proj_depth - 1)
                ],
                nn.Linear(inner_dim, groups * num_vars),
            )
        else:
            self.weight_proj = nn.Linear(self.input_dim, groups * num_vars)
            nn.init.normal_(self.weight_proj.weight, mean=0, std=1)
            nn.init.zeros_(self.weight_proj.bias)

        if isinstance(temp, str):
            import ast
            temp = ast.literal_eval(temp)
        assert len(temp) == 3, f"{temp}, {len(temp)}"

        self.max_temp, self.min_temp, self.temp_decay = temp
        self.curr_temp = self.max_temp
        self.codebook_indices = None
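Finally, a hedged usage sketch for the constructor documented above; it assumes the `GumbelVectorQuantizer` class from the snippet is importable, and every argument value here is illustrative rather than a default:

quantizer = GumbelVectorQuantizer(
    dim=512,                    # input channels of the features to quantize
    num_vars=320,               # codebook entries per group
    temp=(2.0, 0.5, 0.999995),  # (start, stop, decay factor) for the Gumbel temperature
    groups=2,
    combine_groups=False,
    vq_dim=256,                 # must be divisible by groups (see the assert above)
    time_first=True,            # input is B x T x C
)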