def forward(self, inp, edge_info, rel_dist=None, basis=None):
    """Apply the equivariant convolution from every input degree to every output degree.

    Args:
        inp: dict keyed by degree string; each tensor is gathered over
            neighbors with ``batched_index_select(..., neighbor_indices, dim=1)``.
        edge_info: tuple ``(neighbor_indices, neighbor_masks, edges)``;
            ``edges`` and ``neighbor_masks`` may be None.
        rel_dist: neighbor distances laid out as ``'b m n'``. Despite the
            default it is required, since it is rearranged unconditionally.
        basis: passed unchanged to every radial kernel function.

    Returns:
        dict keyed by output degree string. When ``self.self_interaction`` is
        set, it is combined with the self-interaction term via
        ``self.self_interact_sum``.
    """
    neighbor_indices, neighbor_masks, edges = edge_info
    rel_dist = rearrange(rel_dist, 'b m n -> b m n ()')

    # The edge features are identical for every (degree_in, degree_out) pair,
    # so concatenate them once instead of once per kernel.
    edge_features = torch.cat((rel_dist, edges), dim=-1) if exists(edges) else rel_dist

    # one kernel per (degree_in, degree_out) pair, keyed like '(di,do)'
    kernels = {}
    for (di, _), (do, _) in (self.fiber_in * self.fiber_out):
        etype = f'({di},{do})'
        kernels[etype] = self.kernel_unary[etype](edge_features, basis=basis)

    outputs = {}
    for degree_out in self.fiber_out.degrees:
        output = 0
        degree_out_key = str(degree_out)

        # sum contributions from every input degree
        for degree_in, m_in in self.fiber_in:
            x = inp[str(degree_in)]
            x = batched_index_select(x, neighbor_indices, dim=1)
            # flatten (order * channels) into one axis so the kernel acts as a matrix
            x = x.view(*x.shape[:3], to_order(degree_in) * m_in, 1)

            kernel = kernels[f'({degree_in},{degree_out})']
            output = output + einsum('... o i, ... i c -> ... o c', kernel, x)

        if self.pool:
            # reduce over the neighbor axis, honoring the mask when present
            output = masked_mean(output, neighbor_masks, dim=2) if exists(neighbor_masks) else output.mean(dim=2)

        leading_shape = x.shape[:2] if self.pool else x.shape[:3]
        output = output.view(*leading_shape, -1, to_order(degree_out))

        outputs[degree_out_key] = output

    if self.self_interaction:
        self_interact_out = self.self_interact(inp)
        outputs = self.self_interact_sum(outputs, self_interact_out)

    return outputs
def forward(self, feats, coors, mask=None, adj_mat=None, edges=None, return_type=None, return_pooled=False):
    """Run the SE(3) transformer over points with coordinates and an optional adjacency matrix.

    Each node's neighborhood is its nearest points by distance. Nodes marked in
    the (possibly N-degree expanded) adjacency matrix are given priority. The
    features are then passed through conv_in, the transformer ``net``, conv_out
    and norm.

    Args:
        feats: a tensor (wrapped as ``{'0': feats[..., None]}``) or a dict keyed
            by degree string. ``self.token_emb`` is applied first if it exists.
        coors: coordinates laid out as ``'b n d'``. The reshape to
            ``(b, n, n - 1, 3)`` below requires d == 3.
        mask: optional per-node mask ``'b i'``. It is turned into a pairwise mask
            for neighbors, and is also used for pooling.
        adj_mat: optional adjacency matrix, ``(n, n)`` (broadcast over the batch)
            or ``(b, n, n)``. Required when ``self.attend_sparse_neighbors`` is set.
        edges: optional pairwise edge tokens, embedded with ``self.edge_emb``.
        return_type: degree to return. It is forced to 0 when
            ``self.output_degrees == 1``.
        return_pooled: if True, average the outputs over the node axis.

    Returns:
        dict of tensors keyed by degree string, or the single tensor for
        ``return_type`` when that is set.
    """
    # keep the per-node mask for pooling; `mask` is replaced by a pairwise mask below
    _mask = mask

    if self.output_degrees == 1:
        return_type = 0

    if exists(self.token_emb):
        feats = self.token_emb(feats)

    assert not (self.attend_sparse_neighbors and not exists(adj_mat)), 'adjacency matrix (adjacency_mat) or edges (edges) must be passed in'
    assert not (exists(edges) and not exists(self.edge_emb)), 'edge embedding (num_edge_tokens & edge_dim) must be supplied if one were to train on edge types'

    if torch.is_tensor(feats):
        feats = {'0': feats[..., None]}

    b, n, d, *_, device = *feats['0'].shape, feats['0'].device

    assert d == self.dim, f'feature dimension {d} must be equal to dimension given at init {self.dim}'
    assert set(map(int, feats.keys())) == set(range(self.input_degrees)), f'input must have {self.input_degrees} degree'

    num_degrees, neighbors, max_sparse_neighbors, valid_radius = self.num_degrees, self.num_neighbors, self.max_sparse_neighbors, self.valid_radius

    assert self.attend_sparse_neighbors or neighbors > 0, 'you must either attend to sparsely bonded neighbors, or set number of locally attended neighbors to be greater than 0'

    # se3 transformer by default cannot have a node attend to itself
    # (computed up front so the N-degree adjacency labels can be de-diagonalized too)
    exclude_self_mask = rearrange(~torch.eye(n, dtype=torch.bool, device=device), 'i j -> () i j')

    # create N-degrees adjacent matrix from 1st degree connections
    if exists(self.num_adj_degrees):
        if len(adj_mat.shape) == 2:
            adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b=b)

        adj_indices = adj_mat.clone().long()

        for ind in range(self.num_adj_degrees - 1):
            degree = ind + 2

            # NOTE(review): the current matrix is squared each iteration, so the
            # hop count doubles while `degree` grows by one. Also, if adj_mat has
            # no self-loops, the subtraction below can give -1, which .bool()
            # treats as True. Confirm this is intended.
            next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
            next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
            adj_indices.masked_fill_(next_degree_mask, degree)
            adj_mat = next_degree_adj_mat.clone()

        # drop the diagonal so the labels line up with the (b, n, n - 1) edge layout below
        adj_indices = adj_indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)

    # calculate sparsely connected neighbors
    sparse_neighbor_mask = None
    num_sparse_neighbors = 0

    if self.attend_sparse_neighbors:
        assert exists(adj_mat), 'adjacency matrix must be passed in (keyword argument adj_mat)'

        if exists(adj_mat):
            # check the rank, not the length: len() of an (n, n) matrix is n
            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat, 'i j -> b i j', b=b)

            adj_mat = adj_mat.masked_select(exclude_self_mask).reshape(b, n, n - 1)

            adj_mat_values = adj_mat.float()
            adj_mat_max_neighbors = adj_mat_values.sum(dim=-1).max().item()

            if max_sparse_neighbors < adj_mat_max_neighbors:
                # small noise so topk picks among equally-valued bonded entries at random
                noise = torch.empty_like(adj_mat_values).uniform_(-0.01, 0.01)
                adj_mat_values += noise

            num_sparse_neighbors = int(min(max_sparse_neighbors, adj_mat_max_neighbors))
            values, indices = adj_mat_values.topk(num_sparse_neighbors, dim=-1)
            sparse_neighbor_mask = torch.zeros_like(adj_mat_values).scatter_(-1, indices, values)
            sparse_neighbor_mask = sparse_neighbor_mask > 0.5

    # exclude edge of token to itself
    # indices[b, i, j] = j: after dropping the diagonal, row i lists the other
    # nodes in the same order as the columns of rel_pos
    indices = repeat(torch.arange(n, device=device), 'j -> b i j', b=b, i=n)
    rel_pos = rearrange(coors, 'b n d -> b n () d') - rearrange(coors, 'b n d -> b () n d')

    indices = indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)
    rel_pos = rel_pos.masked_select(exclude_self_mask[..., None]).reshape(b, n, n - 1, 3)

    if exists(mask):
        mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j')
        mask = mask.masked_select(exclude_self_mask).reshape(b, n, n - 1)

    if exists(edges):
        edges = self.edge_emb(edges)
        edges = edges.masked_select(exclude_self_mask[..., None]).reshape(b, n, n - 1, -1)

    if exists(self.adj_emb):
        adj_emb = self.adj_emb(adj_indices)
        edges = torch.cat((edges, adj_emb), dim=-1) if exists(edges) else adj_emb

    rel_dist = rel_pos.norm(dim=-1)

    # use sparse neighbor mask to assign priority of bonded
    # clone: the in-place fill must not overwrite the true distances in rel_dist
    modified_rel_dist = rel_dist.clone()

    if exists(sparse_neighbor_mask):
        modified_rel_dist.masked_fill_(sparse_neighbor_mask, 0.)

    # if number of local neighbors by distance is set to 0, then only fetch the sparse neighbors defined by adjacency matrix
    if neighbors == 0:
        valid_radius = 0

    # get neighbors and neighbor mask, excluding self
    neighbors = int(min(neighbors, n - 1))
    total_neighbors = int(neighbors + num_sparse_neighbors)
    assert total_neighbors > 0, 'you must be fetching at least 1 neighbor'

    # make sure total neighbors does not exceed the length of the sequence itself
    total_neighbors = int(min(total_neighbors, n - 1))

    dist_values, nearest_indices = modified_rel_dist.topk(total_neighbors, dim=-1, largest=False)
    neighbor_mask = dist_values <= valid_radius

    neighbor_rel_dist = batched_index_select(rel_dist, nearest_indices, dim=2)
    neighbor_rel_pos = batched_index_select(rel_pos, nearest_indices, dim=2)
    neighbor_indices = batched_index_select(indices, nearest_indices, dim=2)

    if exists(mask):
        neighbor_mask = neighbor_mask & batched_index_select(mask, nearest_indices, dim=2)

    if exists(edges):
        edges = batched_index_select(edges, nearest_indices, dim=2)

    # calculate basis
    basis = get_basis(neighbor_rel_pos, num_degrees - 1, differentiable=self.differentiable_coors)

    # main logic
    edge_info = (neighbor_indices, neighbor_mask, edges)
    x = feats

    # project in
    x = self.conv_in(x, edge_info, rel_dist=neighbor_rel_dist, basis=basis)

    # transformer layers
    x = self.net(x, edge_info=edge_info, rel_dist=neighbor_rel_dist, basis=basis)

    # project out
    x = self.conv_out(x, edge_info, rel_dist=neighbor_rel_dist, basis=basis)

    # norm
    x = self.norm(x)

    # reduce dim if specified
    if exists(self.linear_out):
        x = self.linear_out(x)
        x = map_values(lambda t: t.squeeze(dim=2), x)

    if return_pooled:
        mask_fn = (lambda t: masked_mean(t, _mask, dim=1)) if exists(_mask) else (lambda t: t.mean(dim=1))
        x = map_values(mask_fn, x)

    if '0' in x:
        x['0'] = x['0'].squeeze(dim=-1)

    if exists(return_type):
        return x[str(return_type)]

    return x
def forward(
    self,
    feats,
    coors,
    mask = None,
    adj_mat = None,
    edges = None,
    return_type = None,
    return_pooled = False,
    neighbor_mask = None,
    global_feats = None
):
    """Run the SE(3) transformer over points with coordinates.

    Each node's neighborhood is its nearest points by distance. Adjacency-matrix
    neighbors are prioritized, and the neighborhood can be restricted by a
    user-supplied ``neighbor_mask`` or a causal mask. The equivariant basis is
    built for that neighborhood. Features then pass through conv_in, the
    pre-convolutions, the transformer ``net``, conv_out and norm.

    Args:
        feats: a tensor (wrapped as ``{'0': feats[..., None]}``) or a dict keyed
            by degree string. ``self.token_emb`` is applied first if it exists.
        coors: coordinates laid out as ``'b n d'``. The reshape to
            ``(b, n, n - 1, 3)`` below requires d == 3.
        mask: optional per-node mask ``'b i'``. It is turned into a pairwise mask,
            and is also used for pooling.
        adj_mat: optional adjacency matrix, ``(n, n)`` (broadcast over the batch)
            or ``(b, n, n)``. Required when ``self.attend_sparse_neighbors`` is set.
        edges: optional pairwise edges. Required when ``self.has_edges``;
            embedded with ``self.edge_emb`` when that exists.
        return_type: degree to return. It is forced to 0 when
            ``self.output_degrees == 1``.
        return_pooled: if True, average the outputs over the node axis (masked
            by ``mask`` when given).
        neighbor_mask: optional pairwise mask, presumably boolean ``(b, n, n)``
            (it is inverted with ``~`` and de-diagonalized to ``(b, n, n - 1)``).
            Pairs set to False are pushed to the maximum distance.
        global_feats: optional global features. They must be given if and only
            if ``self.accept_global_feats`` is set.

    Returns:
        dict of tensors keyed by degree string, or the single tensor for
        ``return_type`` when that is set.
    """
    assert not (self.accept_global_feats ^ exists(global_feats)), 'you cannot pass in global features unless you init the class correctly'

    # keep the per-node mask for pooling; `mask` is replaced by a pairwise mask below
    _mask = mask

    if self.output_degrees == 1:
        return_type = 0

    if exists(self.token_emb):
        feats = self.token_emb(feats)

    assert not (self.attend_sparse_neighbors and not exists(adj_mat)), 'adjacency matrix (adjacency_mat) or edges (edges) must be passed in'
    assert not (self.has_edges and not exists(edges)), 'edge embedding (num_edge_tokens & edge_dim) must be supplied if one were to train on edge types'

    if torch.is_tensor(feats):
        feats = {'0': feats[..., None]}

    if torch.is_tensor(global_feats):
        global_feats = {'0': global_feats[..., None]}

    b, n, d, *_, device = *feats['0'].shape, feats['0'].device

    assert d == self.dim_in[0], f'feature dimension {d} must be equal to dimension given at init {self.dim_in[0]}'
    assert set(map(int, feats.keys())) == set(range(self.input_degrees)), f'input must have {self.input_degrees} degree'

    num_degrees, neighbors, max_sparse_neighbors, valid_radius = self.num_degrees, self.num_neighbors, self.max_sparse_neighbors, self.valid_radius

    assert self.attend_sparse_neighbors or neighbors > 0, 'you must either attend to sparsely bonded neighbors, or set number of locally attended neighbors to be greater than 0'

    # se3 transformer by default cannot have a node attend to itself
    exclude_self_mask = rearrange(~torch.eye(n, dtype = torch.bool, device = device), 'i j -> () i j')

    # drop the diagonal: (b, n, n) -> (b, n, n - 1)
    remove_self = lambda t: t.masked_select(exclude_self_mask).reshape(b, n, n - 1)
    get_max_value = lambda t: torch.finfo(t.dtype).max

    # create N-degrees adjacent matrix from 1st degree connections
    if exists(self.num_adj_degrees):
        if len(adj_mat.shape) == 2:
            adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)

        adj_indices = adj_mat.clone().long()

        for ind in range(self.num_adj_degrees - 1):
            degree = ind + 2

            # NOTE(review): the current matrix is squared each iteration, so the
            # hop count doubles while `degree` grows by one. Also, if adj_mat has
            # no self-loops, the subtraction below can give -1, which .bool()
            # treats as True. Confirm this is intended.
            next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
            next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
            adj_indices.masked_fill_(next_degree_mask, degree)
            adj_mat = next_degree_adj_mat.clone()

        adj_indices = adj_indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)

    # calculate sparsely connected neighbors
    sparse_neighbor_mask = None
    num_sparse_neighbors = 0

    if self.attend_sparse_neighbors:
        assert exists(adj_mat), 'adjacency matrix must be passed in (keyword argument adj_mat)'

        if exists(adj_mat):
            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat, 'i j -> b i j', b = b)

            adj_mat = remove_self(adj_mat)

            adj_mat_values = adj_mat.float()
            adj_mat_max_neighbors = adj_mat_values.sum(dim = -1).max().item()

            if max_sparse_neighbors < adj_mat_max_neighbors:
                # small noise so topk picks among equally-valued bonded entries at random
                noise = torch.empty_like(adj_mat_values).uniform_(-0.01, 0.01)
                adj_mat_values += noise

            num_sparse_neighbors = int(min(max_sparse_neighbors, adj_mat_max_neighbors))
            values, indices = adj_mat_values.topk(num_sparse_neighbors, dim = -1)
            sparse_neighbor_mask = torch.zeros_like(adj_mat_values).scatter_(-1, indices, values)
            sparse_neighbor_mask = sparse_neighbor_mask > 0.5

    # exclude edge of token to itself
    # indices[b, i, j] = j: after dropping the diagonal, row i lists the other
    # nodes in the same order as the columns of rel_pos. neighbor_indices
    # therefore points at the neighbor, not back at the query node.
    indices = repeat(torch.arange(n, device = device), 'j -> b i j', b = b, i = n)
    rel_pos = rearrange(coors, 'b n d -> b n () d') - rearrange(coors, 'b n d -> b () n d')

    indices = indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)
    rel_pos = rel_pos.masked_select(exclude_self_mask[..., None]).reshape(b, n, n - 1, 3)

    if exists(mask):
        mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j')
        mask = mask.masked_select(exclude_self_mask).reshape(b, n, n - 1)

    if exists(edges):
        if exists(self.edge_emb):
            edges = self.edge_emb(edges)

        edges = edges.masked_select(exclude_self_mask[..., None]).reshape(b, n, n - 1, -1)

    if exists(self.adj_emb):
        adj_emb = self.adj_emb(adj_indices)
        edges = torch.cat((edges, adj_emb), dim = -1) if exists(edges) else adj_emb

    rel_dist = rel_pos.norm(dim = -1)

    # rel_dist gets modified using adjacency or neighbor mask
    # (cloned so rel_dist keeps the true distances gathered later)
    modified_rel_dist = rel_dist.clone()
    max_value = get_max_value(modified_rel_dist) # for masking out nodes from being considered as neighbors

    # neighbors
    if exists(neighbor_mask):
        neighbor_mask = remove_self(neighbor_mask)

        max_neighbors = neighbor_mask.sum(dim = -1).max().item()
        if max_neighbors > neighbors:
            print(f'neighbor_mask shows maximum number of neighbors as {max_neighbors} but specified number of neighbors is {neighbors}')

        modified_rel_dist.masked_fill_(~neighbor_mask, max_value)

    # use sparse neighbor mask to assign priority of bonded
    if exists(sparse_neighbor_mask):
        modified_rel_dist.masked_fill_(sparse_neighbor_mask, 0.)

    # mask out future nodes to high distance if causal turned on
    # in the diagonal-free (n, n - 1) layout, column j >= i is original node j + 1 > i
    if self.causal:
        causal_mask = torch.ones(n, n - 1, device = device).triu().bool()
        modified_rel_dist.masked_fill_(causal_mask[None, ...], max_value)

    # if number of local neighbors by distance is set to 0, then only fetch the sparse neighbors defined by adjacency matrix
    if neighbors == 0:
        valid_radius = 0

    # get neighbors and neighbor mask, excluding self
    neighbors = int(min(neighbors, n - 1))
    total_neighbors = int(neighbors + num_sparse_neighbors)
    assert total_neighbors > 0, 'you must be fetching at least 1 neighbor'

    total_neighbors = int(min(total_neighbors, n - 1)) # make sure total neighbors does not exceed the length of the sequence itself

    dist_values, nearest_indices = modified_rel_dist.topk(total_neighbors, dim = -1, largest = False)
    neighbor_mask = dist_values <= valid_radius

    neighbor_rel_dist = batched_index_select(rel_dist, nearest_indices, dim = 2)
    neighbor_rel_pos = batched_index_select(rel_pos, nearest_indices, dim = 2)
    neighbor_indices = batched_index_select(indices, nearest_indices, dim = 2)

    if exists(mask):
        neighbor_mask = neighbor_mask & batched_index_select(mask, nearest_indices, dim = 2)

    if exists(edges):
        edges = batched_index_select(edges, nearest_indices, dim = 2)

    # calculate rotary pos emb
    rotary_pos_emb = None
    rotary_query_pos_emb = None
    rotary_key_pos_emb = None

    if self.rotary_position:
        seq = torch.arange(n, device = device)
        seq_pos_emb = self.rotary_pos_emb(seq)
        self_indices = torch.arange(neighbor_indices.shape[1], device = device)
        self_indices = repeat(self_indices, 'i -> b i ()', b = b)
        neighbor_indices_with_self = torch.cat((self_indices, neighbor_indices), dim = 2)
        pos_emb = batched_index_select(seq_pos_emb, neighbor_indices_with_self, dim = 0)

        rotary_key_pos_emb = pos_emb
        rotary_query_pos_emb = repeat(seq_pos_emb, 'n d -> b n d', b = b)

    if self.rotary_rel_dist:
        neighbor_rel_dist_with_self = F.pad(neighbor_rel_dist, (1, 0), value = 0) * 1e2
        rel_dist_pos_emb = self.rotary_pos_emb(neighbor_rel_dist_with_self)
        rotary_key_pos_emb = safe_cat(rotary_key_pos_emb, rel_dist_pos_emb, dim = -1)

        query_dist = torch.zeros(n, device = device)
        query_pos_emb = self.rotary_pos_emb(query_dist)
        query_pos_emb = repeat(query_pos_emb, 'n d -> b n d', b = b)

        rotary_query_pos_emb = safe_cat(rotary_query_pos_emb, query_pos_emb, dim = -1)

    if exists(rotary_query_pos_emb) and exists(rotary_key_pos_emb):
        rotary_pos_emb = (rotary_query_pos_emb, rotary_key_pos_emb)

    # calculate basis
    basis = get_basis(neighbor_rel_pos, num_degrees - 1, differentiable = self.differentiable_coors)

    # main logic
    edge_info = (neighbor_indices, neighbor_mask, edges)
    x = feats

    # project in
    x = self.conv_in(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)

    # preconvolution layers
    for conv, nonlin in self.convs:
        x = nonlin(x)
        x = conv(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)

    # transformer layers
    x = self.net(x, edge_info = edge_info, rel_dist = neighbor_rel_dist, basis = basis, global_feats = global_feats, pos_emb = rotary_pos_emb)

    # project out
    x = self.conv_out(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)

    # norm
    x = self.norm(x)

    # reduce dim if specified
    if exists(self.linear_out):
        x = self.linear_out(x)
        x = map_values(lambda t: t.squeeze(dim = 2), x)

    if return_pooled:
        mask_fn = (lambda t: masked_mean(t, _mask, dim = 1)) if exists(_mask) else (lambda t: t.mean(dim = 1))
        x = map_values(mask_fn, x)

    if '0' in x:
        x['0'] = x['0'].squeeze(dim = -1)

    if exists(return_type):
        return x[str(return_type)]

    return x
def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None):
    """Equivariant multi-head attention of each node over its gathered neighbors.

    Args:
        features: dict keyed by degree string. It is projected to queries
            (``to_q``), values (``to_v``) and keys (``to_k``, or shared with
            values when ``to_k`` is absent).
        edge_info: tuple ``(neighbor_indices, neighbor_mask, edges)``.
            ``neighbor_mask`` is laid out as ``'b i j'`` and may be None.
        rel_dist, basis: forwarded to the convolutional key/value projections.
        global_feats: optional dict of global features. Their keys and values
            are prepended for degree '0' only.
        pos_emb: optional ``(query_pos_emb, key_pos_emb)`` rotary embeddings,
            applied to degree '0' only.

    Returns:
        ``self.to_out`` applied to the dict of per-degree attention outputs.
    """
    h, attend_self = self.heads, self.attend_self
    neighbor_indices, neighbor_mask, edges = edge_info

    if exists(neighbor_mask):
        neighbor_mask = rearrange(neighbor_mask, 'b i j -> b () i j')

    queries = self.to_q(features)
    values = self.to_v(features, edge_info, rel_dist, basis)

    if self.linear_proj_keys:
        keys = self.to_k(features)
        keys = map_values(lambda val: batched_index_select(val, neighbor_indices, dim = 1), keys)
    elif not exists(self.to_k):
        keys = values
    else:
        keys = self.to_k(features, edge_info, rel_dist, basis)

    if attend_self:
        self_keys, self_values = self.to_self_k(features), self.to_self_v(features)

    if exists(global_feats):
        global_keys, global_values = self.to_global_k(global_feats), self.to_global_v(global_feats)

    outputs = {}
    for degree in features.keys():
        q, k, v = map(lambda t: t[degree], (queries, keys, values))

        q = rearrange(q, 'b i (h d) m -> b h i d m', h = h)

        # each extra key/value slot below is prepended along the neighbor axis (dim 2)
        if attend_self:
            self_k, self_v = map(lambda t: t[degree], (self_keys, self_values))
            self_k, self_v = map(lambda t: rearrange(t, 'b n d m -> b n () d m'), (self_k, self_v))
            k = torch.cat((self_k, k), dim = 2)
            v = torch.cat((self_v, v), dim = 2)

        if exists(pos_emb) and degree == '0':
            query_pos_emb, key_pos_emb = pos_emb
            query_pos_emb = rearrange(query_pos_emb, 'b i d -> b () i d ()')
            key_pos_emb = rearrange(key_pos_emb, 'b i j d -> b i j d ()')
            q = apply_rotary_pos_emb(q, query_pos_emb)
            k = apply_rotary_pos_emb(k, key_pos_emb)
            v = apply_rotary_pos_emb(v, key_pos_emb)

        if self.use_null_kv:
            null_k, null_v = map(lambda t: t[degree], (self.null_keys, self.null_values))
            null_k, null_v = map(lambda t: repeat(t, 'd m -> b i () d m', b = q.shape[0], i = q.shape[2]), (null_k, null_v))
            k = torch.cat((null_k, k), dim = 2)
            v = torch.cat((null_v, v), dim = 2)

        if exists(global_feats) and degree == '0':
            global_k, global_v = map(lambda t: t[degree], (global_keys, global_values))
            global_k, global_v = map(lambda t: repeat(t, 'b j d m -> b i j d m', i = k.shape[1]), (global_k, global_v))
            k = torch.cat((global_k, k), dim = 2)
            v = torch.cat((global_v, v), dim = 2)

        sim = einsum('b h i d m, b i j d m -> b h i j', q, k) * self.scale

        if exists(neighbor_mask):
            # left-pad with True: the prepended self/null/global slots are always attendable
            num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1]
            mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True)
            # Most negative finite value of sim's own dtype. A float32 constant
            # (torch.finfo() with no argument) cannot be represented in
            # float16/bfloat16, so masked_fill_ would fail under mixed precision.
            max_neg_value = -torch.finfo(sim.dtype).max
            sim.masked_fill_(~mask, max_neg_value)

        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b i j d m -> b h i d m', attn, v)
        outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')

    return self.to_out(outputs)
def forward(
    self,
    inp,
    edge_info,
    rel_dist = None,
    basis = None
):
    """Apply the equivariant convolution in chunks along the node axis to bound memory.

    Inputs, edge features and every basis tensor are split along dim 1 into
    ``self.splits`` chunks with ``fast_split``. For each chunk the kernel is
    computed and applied, and the chunk results are concatenated back with
    ``safe_cat``.

    Args:
        inp: dict keyed by degree string; each tensor is gathered over neighbors.
        edge_info: tuple ``(neighbor_indices, neighbor_masks, edges)``;
            ``edges`` and ``neighbor_masks`` may be None.
        rel_dist: neighbor distances laid out as ``'b m n'`` (required despite
            the default). Fourier-encoded when ``self.fourier_encode_dist``.
        basis: dict of basis tensors, each split along dim 1.

    Returns:
        dict keyed by output degree string. When ``self.self_interaction`` is
        set, it is combined with the self-interaction term.
    """
    splits = self.splits
    neighbor_indices, neighbor_masks, edges = edge_info
    rel_dist = rearrange(rel_dist, 'b m n -> b m n ()')

    outputs = {}

    if self.fourier_encode_dist:
        rel_dist = fourier_encode(rel_dist[..., None], num_encodings = self.num_fourier_features)

    # split basis into a list of per-chunk dicts with the same keys as `basis`
    basis_keys = list(basis.keys())
    split_basis_values = zip(*(fast_split(t, splits, dim = 1) for t in basis.values()))
    split_basis = [dict(zip(basis_keys, chunk_values)) for chunk_values in split_basis_values]

    # Edge features do not depend on the degree pair: build and split them once.
    # Materialize the split because it is reused for every degree pair
    # (fast_split may return a one-shot iterator).
    edge_features = torch.cat((rel_dist, edges), dim = -1) if exists(edges) else rel_dist
    split_edge_features = tuple(fast_split(edge_features, splits, dim = 1))

    # go through every permutation of input degree type to output degree type
    for degree_out in self.fiber_out.degrees:
        output = 0
        degree_out_key = str(degree_out)

        for degree_in, m_in in self.fiber_in:
            etype = f'({degree_in},{degree_out})'
            kernel_fn = self.kernel_unary[etype]

            x = inp[str(degree_in)]
            x = batched_index_select(x, neighbor_indices, dim = 1)
            x = x.view(*x.shape[:3], to_order(degree_in) * m_in, 1)

            output_chunk = None
            split_x = fast_split(x, splits, dim = 1)

            # process input, edges, and basis in chunks along the sequence dimension
            for x_chunk, edge_chunk, basis_chunk in zip(split_x, split_edge_features, split_basis):
                kernel = kernel_fn(edge_chunk, basis = basis_chunk)
                chunk = einsum('... o i, ... i c -> ... o c', kernel, x_chunk)
                output_chunk = safe_cat(output_chunk, chunk, dim = 1)

            output = output + output_chunk

        if self.pool:
            output = masked_mean(output, neighbor_masks, dim = 2) if exists(neighbor_masks) else output.mean(dim = 2)

        leading_shape = x.shape[:2] if self.pool else x.shape[:3]
        output = output.view(*leading_shape, -1, to_order(degree_out))

        outputs[degree_out_key] = output

    if self.self_interaction:
        self_interact_out = self.self_interact(inp)
        outputs = self.self_interact_sum(outputs, self_interact_out)

    return outputs
def forward(self, feats, coors, mask=None, edges=None, return_type=None):
    """Run the SE(3) transformer over each node's nearest neighbors by distance.

    Args:
        feats: a tensor (wrapped as ``{'0': feats[..., None]}``) or a dict keyed
            by degree string. ``self.token_emb`` is applied first if it exists.
        coors: coordinates laid out as ``'b n d'``. The reshape to
            ``(b, n, n - 1, 3)`` below requires d == 3.
        mask: optional per-node mask ``'b i'``, turned into a pairwise mask.
        edges: optional pairwise edge tokens, embedded with ``self.edge_emb``.
        return_type: if set, only that degree's tensor is returned.

    Returns:
        dict of tensors keyed by degree string, or the single tensor for
        ``return_type``.
    """
    if exists(self.token_emb):
        feats = self.token_emb(feats)

    assert not (exists(edges) and not exists(self.edge_emb)), 'edge embedding (num_edge_tokens & edge_dim) must be supplied if one were to train on edge types'

    if exists(edges):
        edges = self.edge_emb(edges)

    if torch.is_tensor(feats):
        feats = {'0': feats[..., None]}

    b, n, d, *_, device = *feats['0'].shape, feats['0'].device

    assert d == self.dim, f'feature dimension {d} must be equal to dimension given at init {self.dim}'
    assert set(map(int, feats.keys())) == set(range(self.input_degrees)), f'input must have {self.input_degrees} degree'

    num_degrees, neighbors = self.num_degrees, self.num_neighbors
    neighbors = min(neighbors, n - 1)

    # exclude edge of token to itself
    exclude_self_mask = rearrange(~torch.eye(n, dtype=torch.bool, device=device), 'i j -> () i j')

    # indices[b, i, j] = j: after dropping the diagonal, row i lists the other
    # nodes in the same order as the columns of rel_pos, so neighbor_indices
    # points at the neighbor rather than back at the query node.
    indices = repeat(torch.arange(n, device=device), 'j -> b i j', b=b, i=n)
    rel_pos = rearrange(coors, 'b n d -> b n () d') - rearrange(coors, 'b n d -> b () n d')

    indices = indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)
    rel_pos = rel_pos.masked_select(exclude_self_mask[..., None]).reshape(b, n, n - 1, 3)

    if exists(mask):
        mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j')
        mask = mask.masked_select(exclude_self_mask).reshape(b, n, n - 1)

    if exists(edges):
        edges = edges.masked_select(exclude_self_mask[..., None]).reshape(b, n, n - 1, -1)

    rel_dist = rel_pos.norm(dim=-1)

    # get neighbors and neighbor mask, excluding self
    neighbor_rel_dist, nearest_indices = rel_dist.topk(neighbors, dim=-1, largest=False)
    neighbor_rel_pos = batched_index_select(rel_pos, nearest_indices, dim=2)
    neighbor_indices = batched_index_select(indices, nearest_indices, dim=2)

    basis = get_basis(neighbor_rel_pos, num_degrees - 1, differentiable=self.differentiable_coors)

    neighbor_mask = neighbor_rel_dist <= self.valid_radius

    if exists(mask):
        neighbor_mask = neighbor_mask & batched_index_select(mask, nearest_indices, dim=2)

    if exists(edges):
        edges = batched_index_select(edges, nearest_indices, dim=2)

    # main logic
    edge_info = (neighbor_indices, neighbor_mask, edges)
    x = feats

    # project in
    x = self.conv_in(x, edge_info, rel_dist=neighbor_rel_dist, basis=basis)

    # transformer layers
    x = self.net(x, edge_info=edge_info, rel_dist=neighbor_rel_dist, basis=basis)

    # project out
    x = self.conv_out(x, edge_info, rel_dist=neighbor_rel_dist, basis=basis)

    # norm
    x = self.norm(x)

    # reduce dim if specified
    if exists(self.linear_out):
        x = self.linear_out(x)
        x = {k: v.squeeze(dim=2) for k, v in x.items()}

    if exists(return_type):
        return x[str(return_type)]

    return x