def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
    super().__init__()
    img_size = to_2tuple(img_size)
    patch_size = to_2tuple(patch_size)
    patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
    self.img_size = img_size
    self.patch_size = patch_size
    self.patches_resolution = patches_resolution
    self.num_patches = patches_resolution[0] * patches_resolution[1]
    self.in_chans = in_chans
    self.embed_dim = embed_dim
    self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    if norm_layer is not None:
        self.norm = norm_layer(embed_dim)
    else:
        self.norm = None
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, isBin=True):
    super().__init__()
    img_size = to_2tuple(img_size)
    patch_size = to_2tuple(patch_size)
    self.img_size = img_size
    self.patch_size = patch_size
    assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
        f"img_size {img_size} should be divided by patch_size {patch_size}."
    self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
    self.num_patches = self.H * self.W
    self.norm = nn.LayerNorm(embed_dim)
    self.isBin = isBin
    if self.isBin:
        self.binary_act = BinaryActivation()
        self.proj = HardBinaryConv(in_chans, embed_dim, kernel_size=patch_size[0], stride=patch_size)
    else:
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
    super().__init__()
    assert isinstance(backbone, nn.Module)
    img_size = to_2tuple(img_size)
    self.img_size = img_size
    self.backbone = backbone
    if feature_size is None:
        with torch.no_grad():
            # Infer the backbone's output feature map size and channel dim with a dummy forward pass.
            training = backbone.training
            if training:
                backbone.eval()
            o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
            feature_size = o.shape[-2:]
            feature_dim = o.shape[1]
            backbone.train(training)
    else:
        feature_size = to_2tuple(feature_size)
        feature_dim = self.backbone.feature_info.channels()[-1]
    self.num_patches = feature_size[0] * feature_size[1]
    self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1)
def __init__(self, img_size=224, patch_size=16, in_chans=3, outer_dim=768, inner_dim=24, inner_stride=4):
    super().__init__()
    img_size = to_2tuple(img_size)
    patch_size = to_2tuple(patch_size)
    num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
    self.img_size = img_size
    self.patch_size = patch_size
    self.num_patches = num_patches
    self.inner_dim = inner_dim
    self.num_words = math.ceil(patch_size[0] / inner_stride) * math.ceil(patch_size[1] / inner_stride)
    self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
    self.proj = nn.Conv2d(in_chans, inner_dim, kernel_size=7, padding=3, stride=inner_stride)
def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
    super().__init__()
    assert isinstance(backbone, nn.Module)
    img_size = to_2tuple(img_size)
    self.img_size = img_size
    self.backbone = backbone
    if feature_size is None:
        with torch.no_grad():
            # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
            # map for all networks, the feature metadata has reliable channel and stride info, but using
            # stride to calc feature dim requires info about padding of each stage that isn't captured.
            training = backbone.training
            if training:
                backbone.eval()
            o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
            feature_size = o.shape[-2:]
            feature_dim = o.shape[1]
            backbone.train(training)
    else:
        feature_size = to_2tuple(feature_size)
        feature_dim = self.backbone.feature_info.channels()[-1]
    self.num_patches = feature_size[0] * feature_size[1]
    self.proj = nn.Linear(feature_dim, embed_dim)
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
    super().__init__()
    img_size = to_2tuple(img_size)
    patch_size = to_2tuple(patch_size)
    num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
    self.img_size = img_size
    self.patch_size = patch_size
    self.num_patches = num_patches
    self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
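# A minimal forward sketch for the plain patch embedding above (an assumption, not part of
# the original listing): the strided conv projects each non-overlapping patch, then the
# spatial grid is flattened into a token sequence of shape (B, num_patches, embed_dim).
def forward(self, x):
    B, C, H, W = x.shape
    assert H == self.img_size[0] and W == self.img_size[1], \
        f"Input size ({H}x{W}) doesn't match model ({self.img_size[0]}x{self.img_size[1]})."
    x = self.proj(x)                  # (B, embed_dim, H / patch, W / patch)
    x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
    return x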
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
    super().__init__()
    img_size = to_2tuple(img_size)
    patch_size = to_2tuple(patch_size)
    patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
    self.img_size = img_size
    self.patch_size = patch_size
    self.patches_resolution = patches_resolution
    self.num_patches = patches_resolution[0] * patches_resolution[1]
    self.in_chans = in_chans
    self.embed_dim = embed_dim
def __init__(self, img_size=256, patch_size=16, in_channels=3, embed_dim=768):
    super(PatchEmbed, self).__init__()
    img_size = to_2tuple(img_size)
    patch_size = to_2tuple(patch_size)
    self.img_size = img_size
    self.patch_size = patch_size
    assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
        f"img_size {img_size} should be divided by patch_size {patch_size}"
    self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
    self.num_patches = self.H * self.W
    self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
    self.norm = nn.LayerNorm(embed_dim)
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
    super().__init__()
    img_size = to_2tuple(img_size)
    patch_size = to_2tuple(patch_size)
    self.img_size = img_size
    self.patch_size = patch_size
    assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
        f"img_size {img_size} should be divided by patch_size {patch_size}."
    # Note: self.H, self.W and self.num_patches are not used,
    # since the image size may change on the fly.
    self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
    self.num_patches = self.H * self.W
    self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    self.norm = nn.LayerNorm(embed_dim)
def __init__(self, img_size=224, patch_size=4, embed_dim=96):
    super().__init__()
    img_size = to_2tuple(img_size)
    patch_size = to_2tuple(patch_size)
    patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
    self.img_size = img_size
    self.patch_size = patch_size
    self.patches_resolution = patches_resolution
    self.num_patches = patches_resolution[0] * patches_resolution[1]
    self.embed_dim = embed_dim
    self.proj = nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=4, stride=4)
def __init__(self, patch_size, nx, ny, in_chans=3, embed_dim=768, nglo=1,
             norm_layer=nn.LayerNorm, norm_embed=True, drop_rate=0.0, ape=True):
    # maximal global/x-direction/y-direction tokens: nglo, nx, ny
    super().__init__()
    patch_size = to_2tuple(patch_size)
    self.patch_size = patch_size
    self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    self.norm_embed = norm_layer(embed_dim) if norm_embed else None
    self.nx = nx
    self.ny = ny
    self.Nglo = nglo
    if nglo >= 1:
        self.cls_token = nn.Parameter(torch.zeros(1, nglo, embed_dim))
        trunc_normal_(self.cls_token, std=.02)
    else:
        self.cls_token = None
    self.ape = ape
    if ape:
        self.cls_pos_embed = nn.Parameter(torch.zeros(1, nglo, embed_dim))
        self.x_pos_embed = nn.Parameter(torch.zeros(1, nx, embed_dim // 2))
        self.y_pos_embed = nn.Parameter(torch.zeros(1, ny, embed_dim // 2))
        trunc_normal_(self.cls_pos_embed, std=.02)
        trunc_normal_(self.x_pos_embed, std=.02)
        trunc_normal_(self.y_pos_embed, std=.02)
    self.pos_drop = nn.Dropout(p=drop_rate)
def __init__(
    self,
    in_features,
    hidden_features=None,
    out_features=None,
    act_layer=nn.GELU,
    gate_layer=None,
    drop=0.0,
):
    super().__init__()
    out_features = out_features or in_features
    hidden_features = hidden_features or in_features
    drop_probs = to_2tuple(drop)
    self.fc1 = nn.Linear(in_features, hidden_features)
    self.act = act_layer()
    self.drop1 = nn.Dropout(drop_probs[0])
    if gate_layer is not None:
        assert hidden_features % 2 == 0
        self.gate = gate_layer(hidden_features)
        hidden_features = hidden_features // 2  # FIXME base reduction on gate property?
    else:
        self.gate = nn.Identity()
    self.fc2 = nn.Linear(hidden_features, out_features)
    self.drop2 = nn.Dropout(drop_probs[1])
def __init__(self, patch_size=16, embed_dim=768):
    super().__init__()
    patch_size = to_2tuple(patch_size)
    self.patch_size = patch_size
    if patch_size[0] == 16:
        self.proj = torch.nn.Sequential(
            conv3x3(3, embed_dim // 8, 2), nn.GELU(),
            conv3x3(embed_dim // 8, embed_dim // 4, 2), nn.GELU(),
            conv3x3(embed_dim // 4, embed_dim // 2, 2), nn.GELU(),
            conv3x3(embed_dim // 2, embed_dim, 2),
        )
    elif patch_size[0] == 8:
        self.proj = torch.nn.Sequential(
            conv3x3(3, embed_dim // 4, 2), nn.GELU(),
            conv3x3(embed_dim // 4, embed_dim // 2, 2), nn.GELU(),
            conv3x3(embed_dim // 2, embed_dim, 2),
        )
    else:
        raise ValueError("For convolutional projection, patch size has to be in [8, 16]")
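# conv3x3 is referenced by the convolutional projections above and below but is not defined
# in this listing. A plausible helper, assuming a 3x3 convolution followed by batch norm;
# the exact definition in the source repo may differ:
def conv3x3(in_planes, out_planes, stride=1):
    return torch.nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(out_planes),
    )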
def __init__(
    self,
    dim,
    seq_len,
    mlp_ratio=(0.5, 4.0),
    mlp_layer=Mlp,
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
    act_layer=nn.GELU,
    drop=0.0,
    drop_path=0.0,
):
    super().__init__()
    tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
    self.norm1 = norm_layer(dim)
    self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
    self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
    self.norm2 = norm_layer(dim)
    self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop)
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
             mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
             act_layer=nn.GELU, norm_layer=nn.LayerNorm):
    super().__init__()
    self.dim = dim
    self.input_resolution = input_resolution
    self.num_heads = num_heads
    self.window_size = window_size
    self.shift_size = shift_size
    self.mlp_ratio = mlp_ratio
    if min(self.input_resolution) <= self.window_size:
        # if window size is larger than input resolution, we don't partition windows
        self.shift_size = 0
        self.window_size = min(self.input_resolution)
    assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
    self.norm1 = norm_layer(dim)
    self.attn = WindowAttention(
        dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
        qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
    self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
    self.norm2 = norm_layer(dim)
    mlp_hidden_dim = int(dim * mlp_ratio)
    self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
    if self.shift_size > 0:
        attn_mask = self.calculate_mask(self.input_resolution)
    else:
        attn_mask = None
    self.register_buffer("attn_mask", attn_mask)
def __init__(self, img_size=224, in_chans=3, outer_dim=768, inner_dim=24):
    super().__init__()
    img_size = to_2tuple(img_size)
    self.img_size = img_size
    self.inner_dim = inner_dim
    self.num_patches = (img_size[0] // 8) * (img_size[1] // 8)
    self.num_words = 16
    self.common_conv = nn.Sequential(
        nn.Conv2d(in_chans, inner_dim * 2, 3, stride=2, padding=1),
        nn.BatchNorm2d(inner_dim * 2),
        nn.ReLU(inplace=True),
    )
    self.inner_convs = nn.Sequential(
        nn.Conv2d(inner_dim * 2, inner_dim, 3, stride=1, padding=1),
        nn.BatchNorm2d(inner_dim),
        nn.ReLU(inplace=False),
    )
    self.outer_convs = nn.Sequential(
        nn.Conv2d(inner_dim * 2, inner_dim * 4, 3, stride=2, padding=1),
        nn.BatchNorm2d(inner_dim * 4),
        nn.ReLU(inplace=True),
        nn.Conv2d(inner_dim * 4, inner_dim * 8, 3, stride=2, padding=1),
        nn.BatchNorm2d(inner_dim * 8),
        nn.ReLU(inplace=True),
        nn.Conv2d(inner_dim * 8, outer_dim, 3, stride=1, padding=1),
        nn.BatchNorm2d(outer_dim),
        nn.ReLU(inplace=False),
    )
    self.unfold = nn.Unfold(kernel_size=4, padding=0, stride=4)
def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
    super().__init__()
    patch_size = to_2tuple(patch_size)
    self.patch_size = patch_size
    self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    self.norm = nn.LayerNorm(embed_dim)
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
    super().__init__()
    img_size = to_2tuple(img_size)
    patch_size = to_2tuple(patch_size)
    assert max(patch_size) > stride, "Set larger patch_size than stride"
    self.img_size = img_size
    self.patch_size = patch_size
    self.H, self.W = img_size[0] // stride, img_size[1] // stride
    self.num_patches = self.H * self.W
    self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                          padding=(patch_size[0] // 2, patch_size[1] // 2))
    self.norm = nn.LayerNorm(embed_dim)
    self.apply(self._init_weights)
def __init__(self, patch_size=16, in_ch=3, out_ch=768, with_pos=False):
    super().__init__()
    self.patch_size = to_2tuple(patch_size)
    self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=patch_size + 1, stride=patch_size,
                          padding=patch_size // 2)
    self.norm = nn.BatchNorm2d(out_ch)
    self.with_pos = with_pos
    if self.with_pos:
        self.pos = PA(out_ch)
def __init__(self, img_size: int = 224, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768):
    super().__init__()
    img_size = to_2tuple(img_size)
    patch_size = to_2tuple(patch_size)
    num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
    self.img_size = img_size
    self.patch_size = patch_size
    self.num_patches = num_patches
    self.proj = nn.Sequential(
        conv3x3(in_chans, embed_dim // 8, 2), nn.GELU(),
        conv3x3(embed_dim // 8, embed_dim // 4, 2), nn.GELU(),
        conv3x3(embed_dim // 4, embed_dim // 2, 2), nn.GELU(),
        conv3x3(embed_dim // 2, embed_dim, 2),
    )
def __init__(
    self,
    dim,
    input_resolution,
    num_heads,
    group_size=7,
    lsda_flag=0,
    mlp_ratio=4.0,
    qkv_bias=True,
    qk_scale=None,
    drop=0.0,
    attn_drop=0.0,
    drop_path=0.0,
    act_layer=nn.GELU,
    norm_layer=nn.LayerNorm,
    num_patch_size=1,
):
    super().__init__()
    self.dim = dim
    self.input_resolution = input_resolution
    self.num_heads = num_heads
    self.group_size = group_size
    self.lsda_flag = lsda_flag
    self.mlp_ratio = mlp_ratio
    self.num_patch_size = num_patch_size
    if min(self.input_resolution) <= self.group_size:
        # if group size is larger than input resolution, we don't partition groups
        self.lsda_flag = 0
        self.group_size = min(self.input_resolution)
    self.norm1 = norm_layer(dim)
    self.attn = Attention(
        dim,
        group_size=to_2tuple(self.group_size),
        num_heads=num_heads,
        qkv_bias=qkv_bias,
        qk_scale=qk_scale,
        attn_drop=attn_drop,
        proj_drop=drop,
        position_bias=True,
    )
    self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
    self.norm2 = norm_layer(dim)
    mlp_hidden_dim = int(dim * mlp_ratio)
    self.mlp = Mlp(
        in_features=dim,
        hidden_features=mlp_hidden_dim,
        act_layer=act_layer,
        drop=drop,
    )
    attn_mask = None
    self.register_buffer("attn_mask", attn_mask)
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
    super().__init__()
    img_size = to_2tuple(img_size)
    patch_size = to_2tuple(patch_size)
    self.img_size = img_size
    self.patch_size = patch_size
    self.patch_grid = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
    self.num_patches = self.patch_grid[0] * self.patch_grid[1]
    self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
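# Usage sketch (illustrative assumption; the enclosing class name and forward behaviour may
# differ in the source): with the defaults above, a 224x224 RGB image produces a 14x14 patch
# grid, i.e. 196 tokens of dimension 768.
#
#   embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
#   tokens = embed(torch.randn(2, 3, 224, 224))  # expected shape: (2, 196, 768)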
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
    super().__init__()
    new_patch_size = to_2tuple(patch_size // 2)
    img_size = to_2tuple(img_size)
    patch_size = to_2tuple(patch_size)
    num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
    self.img_size = img_size
    self.patch_size = patch_size
    self.num_patches = num_patches
    self.embed_dim = embed_dim
    self.conv1 = nn.Conv2d(in_chans, 128, kernel_size=7, stride=2, padding=3, bias=False)  # 112x112
    self.bn1 = nn.BatchNorm2d(128)
    self.relu = nn.ReLU(inplace=True)
    self.conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)  # 112x112
    self.bn2 = nn.BatchNorm2d(128)
    self.conv3 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn3 = nn.BatchNorm2d(128)
    self.proj = nn.Conv2d(128, embed_dim, kernel_size=new_patch_size, stride=new_patch_size)
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, attn_cfg=None):
    super().__init__(**(attn_cfg or {}))  # guard against the default attn_cfg=None
    img_size = to_2tuple(img_size)
    patch_size = to_2tuple(patch_size)
    self.img_size = img_size
    self.patch_size = patch_size
    assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
        f"img_size {img_size} should be divided by patch_size {patch_size}."
    self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
    self.num_patches = self.H * self.W
    self.norm = nn.LayerNorm(embed_dim)
    self.fc = nn.Linear(in_chans, embed_dim)
    self.proj = nn.Linear(embed_dim, embed_dim)
def __init__(self, backbone, img_size=224, patch_size=16, feature_size=None, in_chans=3, embed_dim=768):
    super().__init__()
    assert isinstance(backbone, nn.Module)
    img_size = to_2tuple(img_size)
    self.img_size = img_size
    self.backbone = backbone
    if feature_size is None:
        with torch.no_grad():
            # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
            # map for all networks, the feature metadata has reliable channel and stride info, but using
            # stride to calc feature dim requires info about padding of each stage that isn't captured.
            training = backbone.training
            if training:
                backbone.eval()
            o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
            if isinstance(o, (list, tuple)):
                o = o[-1]  # last feature if backbone outputs list/tuple of features
            feature_size = o.shape[-2:]
            feature_dim = o.shape[1]
            backbone.train(training)
    else:
        feature_size = to_2tuple(feature_size)
        feature_dim = self.backbone.feature_info.channels()[-1]
    print('feature_size is {}, feature_dim is {}, patch_size is {}'.format(
        feature_size, feature_dim, patch_size))
    self.num_patches = (feature_size[0] // patch_size) * (feature_size[1] // patch_size)
    self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
    super().__init__()
    patch_size = to_2tuple(patch_size)
    self.patch_size = patch_size
    self.in_chans = in_chans
    self.embed_dim = embed_dim
    self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    if norm_layer is not None:
        self.norm = norm_layer(embed_dim)
    else:
        self.norm = None
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
             mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
             act_layer=nn.GELU, norm_layer=nn.LayerNorm):
    super().__init__()
    self.dim = dim
    self.input_resolution = input_resolution
    self.num_heads = num_heads
    self.window_size = window_size
    self.shift_size = shift_size
    self.mlp_ratio = mlp_ratio
    if min(self.input_resolution) <= self.window_size:
        # if window size is larger than input resolution, we don't partition windows
        self.shift_size = 0
        self.window_size = min(self.input_resolution)
    assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
    self.norm1 = norm_layer(dim)
    self.attn = WindowAttention(
        dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
        qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
    self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
    self.norm2 = norm_layer(dim)
    mlp_hidden_dim = int(dim * mlp_ratio)
    self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
    if self.shift_size > 0:
        # calculate attention mask for SW-MSA
        H, W = self.input_resolution
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    else:
        attn_mask = None
    self.register_buffer("attn_mask", attn_mask)
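# window_partition is called by the mask computation above but is not defined in this
# listing. A sketch following the usual Swin-style partitioning (an assumption; the source
# implementation may differ):
def window_partition(x, window_size):
    # (B, H, W, C) -> (num_windows * B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)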
def __init__(self, in_chans, embed_dim, resolution, activation):
    super().__init__()
    img_size: Tuple[int, int] = to_2tuple(resolution)
    self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
    self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
    self.in_chans = in_chans
    self.embed_dim = embed_dim
    n = embed_dim
    self.seq = nn.Sequential(
        Conv2d_BN(in_chans, n // 2, 3, 2, 1),
        activation(),
        Conv2d_BN(n // 2, n, 3, 2, 1),
    )
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=nn.SyncBatchNorm):
    super().__init__()
    img_size = to_2tuple(img_size)
    patch_size = to_2tuple(patch_size)
    num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
    self.img_size = img_size
    self.patch_size = patch_size
    self.num_patches = num_patches
    self.proj = torch.nn.Sequential(*[
        nn.Conv2d(in_chans, embed_dim // 4, kernel_size=4, stride=4),
        norm_layer(embed_dim // 4),
        nn.GELU(),
        nn.Conv2d(embed_dim // 4, embed_dim // 4, kernel_size=2, stride=2),
        norm_layer(embed_dim // 4),
        nn.GELU(),
        nn.Conv2d(embed_dim // 4, embed_dim, kernel_size=2, stride=2),
        norm_layer(embed_dim),
    ])
def forward(self, x):
    x = self.forward_features(x)
    x_rot = self.pre_logits_rot(x[:, 0])
    x_rot = self.rot_head(x_rot)
    x_contrastive = self.pre_logits_contrastive(x[:, 1])
    x_contrastive = self.contrastive_head(x_contrastive)
    if self.training_mode == 'finetune':
        return x_rot, x_contrastive
    x_rec = x[:, 2:].transpose(1, 2)
    x_rec = self.convTrans(x_rec.unflatten(2, to_2tuple(int(math.sqrt(x_rec.size()[2])))))
    return x_rot, x_contrastive, x_rec, self.rot_w, self.contrastive_w, self.recons_w