def test(model, test_loader, criterion, to_log=None):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for i, conc_data in enumerate(test_loader):
            # prepare input and target to device
            h_data, p_data = conc_data
            azi, ele, target = h_data
            azi = torch.permute(azi, (0, 2, 1, 3, 4))
            ele = torch.permute(ele, (0, 2, 1, 3, 4))
            phase, _ = p_data
            azi = azi.to(device, dtype=torch.float)
            ele = ele.to(device, dtype=torch.float)
            phase = phase.to(device, dtype=torch.float)
            target = target.to(device, dtype=torch.long)
            output = model(azi, ele, phase)
            loss = criterion(output, target)
            test_loss += loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.sampler)
    test_loss *= test_loader.batch_size
    acc = 100. * correct / len(test_loader.sampler)
    format_str = 'Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.sampler), acc)
    print(format_str)
    if to_log is not None:
        write_log(format_str, to_log)
    return test_loss.item(), acc
def __extract_h_whisker(self, feature_map, h_map):
    """Extract features from feature map using bilinear interpolation of whisker points.

    Parameters
    ==========
    feature_map : torch.Tensor
        Has shape (B, C, H, W).
    h_map : torch.Tensor
        Has shape (B, K, A, D=2).

    Returns
    =======
    torch.Tensor
        Has shape (B, (n whiskers)*(n points per whisker)*C).
    """
    B, _, _, _ = h_map.shape
    # whisker_points has shape (B, n whiskers, n points per whisker, 2)
    whisker_points = util.torchu.expand_and_repeat(self.whisker_points, 0, B)
    # h_whisker has shape (B, C, n whiskers, n points per whisker)
    h_whisker = F.grid_sample(feature_map, whisker_points)
    # h_whisker has shape (B, (n whiskers)*(n points per whisker)*C)
    h_whisker = torch.permute(h_whisker, (0, 2, 3, 1)).reshape(B, -1)
    return h_whisker
def forward(self, inputs: torch.Tensor, training=None, mask=None):
    # inputs.shape = [batch_size, seq_len, embedding_dim]
    batch_size = inputs.shape[0]
    query = self.query_dense(inputs)  # (batch_size, seq_len, h)
    key = self.key_dense(inputs)  # (batch_size, seq_len, h)
    value = self.value_dense(inputs)  # (batch_size, seq_len, h)
    query = self.separate_heads(query, batch_size)  # (batch_size, num_heads, seq_len, projection_dim)
    key = self.separate_heads(key, batch_size)  # (batch_size, num_heads, seq_len, projection_dim)
    value = self.separate_heads(value, batch_size)  # (batch_size, num_heads, seq_len, projection_dim)
    outputs, weights = self.attention(query, key, value, mask=mask)
    outputs = torch.permute(outputs, (0, 2, 1, 3))  # (batch_size, seq_len, num_heads, projection_dim)
    concat_outputs = torch.reshape(outputs, (batch_size, -1, self.embedding_size))  # (batch_size, seq_len, h)
    projected_outputs = self.combine_heads(concat_outputs)  # (batch_size, seq_len, h)
    return projected_outputs
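# A minimal, self-contained sketch of the head split/merge performed above by
# separate_heads and the permute/reshape pair; shapes and names here are
# illustrative assumptions, not the class's actual implementation.
import torch

batch_size, seq_len, num_heads, projection_dim = 2, 5, 4, 8
h = num_heads * projection_dim
x = torch.randn(batch_size, seq_len, h)
# split: (B, S, H) -> (B, num_heads, S, projection_dim)
heads = x.reshape(batch_size, seq_len, num_heads, projection_dim).permute(0, 2, 1, 3)
# merge back: (B, num_heads, S, projection_dim) -> (B, S, H)
merged = torch.permute(heads, (0, 2, 1, 3)).reshape(batch_size, -1, h)
assert torch.equal(merged, x)  # split followed by merge is the identity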
def forward(self, input):
    # Prepare attributes
    input_shape = list(map(int, list(input.shape)))
    block_shape = self.block_shape
    crop = self.crop
    # number of spatial dimensions
    m = len(block_shape)
    # rest of dimensions
    n = len(input.shape) - m
    # output batch size
    batch_size = input_shape[0] // np.prod(block_shape)
    unfolded_shape = list(block_shape) + [batch_size] + input_shape[1:]
    fold_shape = [batch_size] + input_shape[1:n] + [
        input_shape[i + n] * block_shape[i] for i in range(m)
    ]
    permute_dims = list(range(m, m + n)) + [
        i + mod for i in range(m) for mod in [n + m, 0]
    ]
    # Actual model starts here
    unfolded_input = input.reshape(unfolded_shape)
    permuted = torch.permute(unfolded_input, permute_dims)
    full_output = permuted.reshape(fold_shape)
    # crop output tensor
    crop_output = full_output
    for i in range(m):
        crop_size = sum(crop[i])
        crop_output = crop_output.narrow(i + n, crop[i][0],
                                         fold_shape[i + n] - crop_size)
    return crop_output
def get_predictions(input_sents, model, tokenizer, k=5, bert=True):
    token_preds = []
    tok_probs = []
    for tokens_tensor, mi, tokenized_text in prep_input(input_sents, tokenizer, bert=bert):
        tokens_tensor = tokens_tensor.to(device)
        with torch.no_grad():
            predictions = model(tokens_tensor).logits
            predictions = torch.permute(predictions, (1, 0, 2))
        predicted_tokens = []
        predicted_token_probs = []
        if bert:
            softpred = torch.softmax(predictions[0, mi], 0)
        else:
            softpred = torch.softmax(predictions[0, mi, :], 0)
        top_inds = torch.argsort(softpred, descending=True)[:k].cpu().numpy()
        top_probs = [softpred[tgt_ind].item() for tgt_ind in top_inds]
        top_tok_preds = tokenizer.convert_ids_to_tokens(top_inds)
        if not bert:
            top_tok_preds = [re.sub(r'</w>', '', e) for e in top_tok_preds]
        token_preds.append(top_tok_preds)
        tok_probs.append(top_probs)
    return token_preds, tok_probs
def forward(self, input):
    # Prepare attributes
    input_shape = list(map(int, list(input.shape)))
    block_shape = self.block_shape
    pad = self.pad
    # number of spatial dimensions
    m = len(block_shape)
    # rest of dimensions
    n = len(input.shape) - m
    # output batch size
    batch_size = input_shape[0]
    out_spatial_dim = [
        (input_shape[i + n] + pad[i * 2] + pad[i * 2 + 1]) // block_shape[i]
        for i in range(m)
    ]
    unfolded_shape = [batch_size] + input_shape[1:n] + [
        dim for i in range(m)
        for dim in [out_spatial_dim[i], block_shape[i]]
    ]
    fold_shape = [batch_size * np.prod(block_shape)] + input_shape[1:n] + out_spatial_dim
    permute_dims = list(range(n + 1, n + 2 * m, 2)) + list(range(n)) + list(range(n, n + 2 * m, 2))
    # Actual model starts here
    padded_input = torch.nn.functional.pad(input, pad)
    unfolded_input = padded_input.reshape(unfolded_shape)
    permuted = torch.permute(unfolded_input, permute_dims)
    output = permuted.reshape(fold_shape)
    return output
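# A minimal, self-contained round-trip check of the space-to-batch /
# batch-to-space reshape-permute trick used by the two forward() methods
# above, specialised to a 4-D NCHW tensor with a 2x2 block and no
# padding/cropping. All names here are local illustrations.
import torch

b, c, h, w, blk = 2, 3, 8, 8, 2
x = torch.randn(b, c, h, w)

# space-to-batch: (b, c, h, w) -> (b*blk*blk, c, h//blk, w//blk)
unfolded = x.reshape(b, c, h // blk, blk, w // blk, blk)
y = torch.permute(unfolded, (3, 5, 0, 1, 2, 4)).reshape(b * blk * blk, c, h // blk, w // blk)

# batch-to-space: invert the transform
refolded = y.reshape(blk, blk, b, c, h // blk, w // blk)
x_back = torch.permute(refolded, (2, 3, 4, 0, 5, 1)).reshape(b, c, h, w)
assert torch.equal(x_back, x)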
def test(model, test_loader, criterion, to_log=None):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for (data, target) in test_loader:
            target = target.long()
            data = torch.permute(data, (0, 2, 1, 3, 4))
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            test_loss += loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.sampler)
    test_loss *= test_loader.batch_size
    acc = 100. * correct / len(test_loader.sampler)
    format_str = 'Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.sampler), acc)
    print(format_str)
    if to_log is not None:
        write_log(format_str, to_log)
    return test_loss.item(), acc
def forward(self, inputs: Dict) -> Dict:
    # encoder outputs
    encoder_outputs = [inputs[k]["encoder_output"] for k in inputs]

    # ================ Flatten ================
    batch_size = encoder_outputs[0].shape[0]
    encoder_outputs = [torch.reshape(eo, [batch_size, -1]) for eo in encoder_outputs]

    # ================ Project ================
    projected = [self.projectors[i](eo) for i, eo in enumerate(encoder_outputs)]
    hidden = torch.stack(projected)
    hidden = torch.permute(hidden, (1, 0, 2))  # shape [bs, num_eo, h]

    # ================ Aggregate ================
    hidden = torch.mean(hidden, dim=1)

    # ================ Fully Connected ================
    if self.fc_stack is not None:
        hidden = self.fc_stack(hidden)

    return_data = {"combiner_output": hidden}
    if len(inputs) == 1:
        # Workaround for including additional tensors from output of input encoders for
        # potential use in decoders, e.g. LSTM state for seq2seq.
        # TODO(Justin): Think about how to make this communication work for multi-sequence
        # features. Other combiners.
        for key, value in [d for d in inputs.values()][0].items():
            if key != "encoder_output":
                return_data[key] = value
    return return_data
def m_probs(self) -> torch.Tensor:
    r"""
    Posterior spot presence probability :math:`q(m=1, z=z_\mathsf{MAP})`.
    """
    return Vindex(torch.permute(pyro.param("m_probs").data, (1, 2, 3, 4, 0)))[
        ..., self.z_map.long()
    ]
def tensor_indexing_ops(self):
    x = torch.randn(2, 4)
    y = torch.randn(2, 4, 2)
    t = torch.tensor([[0, 0], [1, 0]])
    mask = x.ge(0.5)
    i = [0, 1]
    return (
        torch.cat((x, x, x), 0),
        torch.concat((x, x, x), 0),
        torch.conj(x),
        torch.chunk(x, 2),
        torch.dsplit(y, i),
        torch.column_stack((x, x)),
        torch.dstack((x, x)),
        torch.gather(x, 0, t),
        torch.hsplit(x, i),
        torch.hstack((x, x)),
        torch.index_select(x, 0, torch.tensor([0, 1])),
        torch.masked_select(x, mask),
        torch.movedim(x, 1, 0),
        torch.moveaxis(x, 1, 0),
        torch.narrow(x, 0, 0, 2),
        torch.nonzero(x),
        torch.permute(x, (0, 1)),
        torch.reshape(x, (-1,)),
    )
def forward(self, x):
    [batch, channel, length, width, height] = x.shape
    x = self.physnet_lstm['cnn_blocks'](x)
    x, _ = self.physnet_lstm['cov_lstm'](x)
    x = torch.permute(x[0], (0, 2, 1, 3, 4))
    x = self.physnet_lstm['cnn_flatten'](x)
    return x.view(-1, length)
def forward(self, x):
    # bs, c_in, ts, n_vertex = x.shape
    x = torch.permute(x, (0, 2, 3, 1))

    if self.Ks - 1 < 0:
        raise ValueError(
            f'ERROR: the graph convolution kernel size Ks has to be a positive integer, '
            f'but received {self.Ks}.')
    elif self.Ks - 1 == 0:
        x_0 = x
        x_list = [x_0]
    elif self.Ks - 1 == 1:
        x_0 = x
        x_1 = torch.einsum('hi,btij->bthj', self.gso, x)
        x_list = [x_0, x_1]
    elif self.Ks - 1 >= 2:
        x_0 = x
        x_1 = torch.einsum('hi,btij->bthj', self.gso, x)
        x_list = [x_0, x_1]
        for k in range(2, self.Ks):
            x_list.append(
                torch.einsum('hi,btij->bthj', 2 * self.gso, x_list[k - 1]) - x_list[k - 2])
    x = torch.stack(x_list, dim=2)

    cheb_graph_conv = torch.einsum('btkhi,kij->bthj', x, self.weight)

    if self.bias is not None:
        cheb_graph_conv = torch.add(cheb_graph_conv, self.bias)
    else:
        cheb_graph_conv = cheb_graph_conv

    return cheb_graph_conv
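# A small standalone check (illustrative, with made-up shapes) that the einsum
# recurrence above builds Chebyshev polynomials of the graph shift operator:
# T_0 = I, T_1 = gso, T_k = 2 * gso @ T_{k-1} - T_{k-2}.
import torch

n_vertex, Ks = 4, 3
gso = torch.randn(n_vertex, n_vertex)
x = torch.randn(1, 1, n_vertex, 2)  # (bs, ts, n_vertex, c_in)

x_list = [x, torch.einsum('hi,btij->bthj', gso, x)]
for k in range(2, Ks):
    x_list.append(torch.einsum('hi,btij->bthj', 2 * gso, x_list[k - 1]) - x_list[k - 2])

# the k-th term equals T_k(gso) applied to x along the vertex dimension
t2 = 2 * gso @ gso - torch.eye(n_vertex)  # T_2 as an explicit matrix
ref = torch.einsum('hi,btij->bthj', t2, x)
assert torch.allclose(x_list[2], ref, atol=1e-5)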
def forward(self, pred, pred_semseg, label_info, label_i=0, sparse=True, shape=(30, 30)):
    target = label_info['disp_label_0']
    if sparse:
        target = F.adaptive_max_pool2d(target, shape)

    # calculate disp output
    bs, num_class, maxdisp, h, w = pred.shape
    pred = torch.permute(pred, (0, 3, 4, 1, 2))
    pred = pred.view(-1, num_class, maxdisp)
    disp_scans = pred_semseg.view(-1)
    disp_scans.requires_grad = False
    # index over the flattened (bs*h*w) batch; arange(bs) would not broadcast
    pred_disp = pred[torch.arange(pred.size(0)), disp_scans]
    disp_scans = label_info['disp_scans'][0].view(1, maxdisp)
    disp_pred = torch.sum(pred_disp * disp_scans, dim=1)
    disp_pred = disp_pred.view(bs, h, w)

    # loss and pred
    EPE_map = self.SmoothL1Loss(disp_pred, target)
    epe_pred = torch.abs(disp_pred - target)

    # ignore false disp values
    positive = target.ge(0)
    EPE_map = torch.masked_select(EPE_map, positive)
    epe_pred = torch.masked_select(epe_pred, positive)

    # normalization
    loss = EPE_map.mean()
    epe_pred = epe_pred.mean()
    return loss, epe_pred
def forward(
    self,
    inputs,  # encoder outputs
) -> Dict:
    encoder_outputs = [inputs[k]["encoder_output"] for k in inputs]

    # ================ Flatten ================
    batch_size = encoder_outputs[0].shape[0]
    encoder_outputs = [torch.reshape(eo, [batch_size, -1]) for eo in encoder_outputs]

    # ================ Project & Concat ================
    projected = [self.projectors[i](eo) for i, eo in enumerate(encoder_outputs)]
    hidden = torch.stack(projected)  # shape [num_eo, bs, h]
    hidden = torch.permute(hidden, (1, 0, 2))  # shape [bs, num_eo, h]

    # ================ Transformer Layers ================
    hidden = self.transformer_stack(hidden)

    # ================ Sequence Reduction ================
    if self.reduce_output is not None:
        hidden = self.reduce_sequence(hidden)

    # ================ FC Layers ================
    hidden = self.fc_stack(hidden)

    return_data = {"combiner_output": hidden}
    if len(inputs) == 1:
        for key, value in [d for d in inputs.values()][0].items():
            if key != "encoder_output":
                return_data[key] = value
    return return_data
def evaluate(model, resdir, testloader):
    # output paths
    cmdir = os.path.join(resdir, 'cm.pdf')
    logdir = os.path.join(resdir, 'cm_log.txt')
    model_path = os.path.join(resdir, 'best.pth.tar')
    # load weights
    model.load_state_dict(torch.load(model_path))
    model = model.to(device)
    # test
    all_pred = []
    all_target = []
    test_loss = 0
    with torch.no_grad():
        for (azi, ele, target) in testloader:
            azi = torch.permute(azi, (0, 2, 1, 3, 4))
            ele = torch.permute(ele, (0, 2, 1, 3, 4))
            # prepare input and target to device
            azi = azi.to(device, dtype=torch.float)
            ele = ele.to(device, dtype=torch.float)
            target = target.to(device, dtype=torch.long)
            output = model(azi, ele)
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            pred = pred.cpu().numpy().flatten()
            target = target.cpu().numpy().flatten()
            all_pred = np.concatenate((all_pred, pred), axis=0)
            all_target = np.concatenate((all_target, target), axis=0)
    # print
    write_log(
        classification_report(all_target, all_pred, target_names=emotion_list),
        logdir)
    cm = confusion_matrix(all_target, all_pred)
    ax = sns.heatmap(cm, annot=True, cmap='Blues')
    ax.set_title('Confusion Matrix\n\n')
    ax.set_xlabel('\nPredicted Classes')
    ax.set_ylabel('Actual Classes')
    ## Tick labels - list must be in alphabetical order
    ax.xaxis.set_ticklabels(emotion_list)
    ax.yaxis.set_ticklabels(emotion_list)
    plt.savefig(cmdir)
def pointcloud_project(cfg, point_cloud, transform, sigma):
    tr_pc = pc_perspective_transform(cfg, point_cloud, transform)
    voxels = pointcloud2voxels(cfg, tr_pc, sigma)
    voxels = torch.permute(voxels, [0, 2, 1, 3, 4])
    proj, probs = util.drc_pytorch.drc_projection(voxels, cfg)
    proj = torch.flip(proj, [1])
    return proj, voxels
def __call__(self, pred, device='cpu'):
    # NOTE: the incoming `pred` (and the `device` argument) are immediately
    # shadowed below; this looks like a stub left in for testing.
    pred = torch.rand(size=(3, 5, 4))  # agent, states, actions
    idx = torch.randint(low=0, high=pred.shape[0], size=pred.shape[1:]).to(pred.device)
    idx_ohe = one_hot_encoding(idx, n_categories=pred.shape[0], unsqueeze=True)
    return (pred * torch.permute(idx_ohe, [2, 0, 1])).sum(0)
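# Standalone illustration (with torch.nn.functional.one_hot standing in for the
# external one_hot_encoding helper) that multiplying by a permuted one-hot mask
# and summing over dim 0 is a gather along the agent dimension.
import torch
import torch.nn.functional as F

pred = torch.rand(3, 5, 4)  # (agent, states, actions)
idx = torch.randint(0, pred.shape[0], pred.shape[1:])  # (states, actions)
ohe = F.one_hot(idx, num_classes=pred.shape[0])  # (states, actions, agent)
masked = (pred * torch.permute(ohe, (2, 0, 1))).sum(0)
gathered = torch.gather(pred, 0, idx.unsqueeze(0)).squeeze(0)
assert torch.allclose(masked, gathered)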
def tensor_indexing_ops(self):
    x = torch.randn(2, 4)
    y = torch.randn(4, 4)
    t = torch.tensor([[0, 0], [1, 0]])
    mask = x.ge(0.5)
    i = [0, 1]
    # wrap the ops in a tuple: len() only accepts a single argument
    return len((
        torch.cat((x, x, x), 0),
        torch.concat((x, x, x), 0),
        torch.conj(x),
        torch.chunk(x, 2),
        torch.dsplit(torch.randn(2, 2, 4), i),
        torch.column_stack((x, x)),
        torch.dstack((x, x)),
        torch.gather(x, 0, t),
        torch.hsplit(x, i),
        torch.hstack((x, x)),
        torch.index_select(x, 0, torch.tensor([0, 1])),
        x.index(t),
        torch.masked_select(x, mask),
        torch.movedim(x, 1, 0),
        torch.moveaxis(x, 1, 0),
        torch.narrow(x, 0, 0, 2),
        torch.nonzero(x),
        torch.permute(x, (0, 1)),
        torch.reshape(x, (-1,)),
        torch.row_stack((x, x)),
        torch.select(x, 0, 0),
        torch.scatter(x, 0, t, x),
        x.scatter(0, t, x.clone()),
        torch.diagonal_scatter(y, torch.ones(4)),
        torch.select_scatter(y, torch.ones(4), 0, 0),
        torch.slice_scatter(x, x),
        torch.scatter_add(x, 0, t, x),
        x.scatter_(0, t, y),
        x.scatter_add_(0, t, y),
        # torch.scatter_reduce(x, 0, t, reduce="sum"),
        torch.split(x, 1),
        torch.squeeze(x, 0),
        torch.stack([x, x]),
        torch.swapaxes(x, 0, 1),
        torch.swapdims(x, 0, 1),
        torch.t(x),
        torch.take(x, t),
        torch.take_along_dim(x, torch.argmax(x)),
        torch.tensor_split(x, 1),
        torch.tensor_split(x, [0, 1]),
        torch.tile(x, (2, 2)),
        torch.transpose(x, 0, 1),
        torch.unbind(x),
        torch.unsqueeze(x, -1),
        torch.vsplit(x, i),
        torch.vstack((x, x)),
        torch.where(x),
        torch.where(t > 0, t, 0),
        torch.where(t > 0, t, t),
    ))
def transpose_int(self: Tensor, dim0: int, dim1: int) -> Tensor:
    dim0, dim1 = utils.canonicalize_dims(self.dim(), (dim0, dim1))  # type: ignore[misc]

    if self.dim() <= 1:
        return self

    if dim0 == dim1:
        return self
    perm = list(range(self.dim()))
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    return torch.permute(self, perm)
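# Quick sanity check (illustrative) that building a full permutation with the
# two dims swapped, as transpose_int does above, reproduces torch.transpose.
import torch

x = torch.randn(2, 3, 4)
dim0, dim1 = 0, 2
perm = list(range(x.dim()))
perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
assert torch.equal(torch.permute(x, perm), torch.transpose(x, dim0, dim1))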
def forward(self, x):
    # bs, c_in, ts, n_vertex = x.shape
    x = torch.permute(x, (0, 2, 3, 1))

    first_mul = torch.einsum('hi,btij->bthj', self.gso, x)
    second_mul = torch.einsum('bthi,ij->bthj', first_mul, self.weight)

    if self.bias is not None:
        graph_conv = torch.add(second_mul, self.bias)
    else:
        graph_conv = second_mul

    return graph_conv
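# The two einsums above are just batched matrix products; an illustrative
# check with made-up shapes:
import torch

bs, ts, n_vertex, c_in, c_out = 2, 3, 4, 5, 6
x = torch.randn(bs, ts, n_vertex, c_in)
gso = torch.randn(n_vertex, n_vertex)
weight = torch.randn(c_in, c_out)

out_einsum = torch.einsum('bthi,ij->bthj',
                          torch.einsum('hi,btij->bthj', gso, x), weight)
out_matmul = gso @ x @ weight  # matmul broadcasts over the (bs, ts) dims
assert torch.allclose(out_einsum, out_matmul, atol=1e-5)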
def forward(self, input, length):
    """
    DO NOT MODIFY FUNCTION SIGNATURE
    TODO: Create the forward pass through the network.
    """
    input = torch.permute(input, (0, 2, 1))
    x = self.max_pool(self.rlu(self.conv(input)))
    x = self.max_pool(self.rlu(self.conv(x)))
    x = self.rlu(self.conv(x))
    x, _ = torch.max(x, 2)  # torch.max over a dim returns (values, indices)
    # L = (((length+3)/4+3)/4+3)*50
    # NOTE: constructing a fresh Linear inside forward() leaves it untrained;
    # a layer defined in __init__ would normally be applied here instead.
    x = tnn.Linear(50, 1)(x)
    x = x.view(-1)
    return x
def transpose_int(self: Tensor, dim0: int, dim1: int) -> Tensor:
    dim0, dim1 = utils.canonicalize_dims(self.dim(), (dim0, dim1))  # type: ignore[misc]
    # NB: these no-op views force this operator to return a
    # fresh TensorImpl, which is important for autograd to
    # work correctly (assert will fail if you don't do it)
    if self.dim() <= 1:
        return self.view(self.shape)

    if dim0 == dim1:
        return self.view(self.shape)
    perm = list(range(self.dim()))
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    return torch.permute(self, perm)
def contractive_loss(self):
    W_1 = []
    W_2 = []
    for layer in self.encoder1.children():
        # convolution filters
        if isinstance(layer, nn.Conv2d):  # `layer is nn.Conv2d` would never be True
            param = next(layer.parameters())
            if len(param.size()) == 4:
                W_1.append((param**2).sum(dim=[1, 2, 3]))
    for layer in self.encoder2.children():
        # convolution filters
        if isinstance(layer, nn.Conv2d):
            param = next(layer.parameters())
            if len(param.size()) == 4:
                W_2.append((param**2).sum(dim=[1, 2, 3]))
    contractive_loss_1 = 0.0
    for w, latent in zip(W_1, self.latent1):
        dlatent = latent * (1 - latent)  # derivative of the sigmoid that produced latent
        c = dlatent.size(1)
        # permute makes the tensor non-contiguous, so reshape instead of view;
        # w is 1-D, so use mv (not mm) and reduce to a scalar
        contractive_loss_1 += torch.mv(
            torch.permute(dlatent**2, [0, 2, 3, 1]).reshape(-1, c), w).sum()
    contractive_loss_2 = 0.0
    for w, latent in zip(W_2, self.latent2):
        dlatent = latent * (1 - latent)
        c = dlatent.size(1)
        contractive_loss_2 += torch.mv(
            torch.permute(dlatent**2, [0, 2, 3, 1]).reshape(-1, c), w).sum()
    contractive_loss = contractive_loss_1 + contractive_loss_2
    return contractive_loss
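# Numerical check (illustrative) that the h * (1 - h) factor used above is the
# derivative of the sigmoid that produced h, which is what the contractive
# penalty needs.
import torch

z = torch.randn(5, requires_grad=True)
h = torch.sigmoid(z)
h.sum().backward()
assert torch.allclose(z.grad, h * (1 - h), atol=1e-6)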
def forward(self, x, y, z):
    # NOTE: lexicographic string comparison is fragile ('1.10' < '1.9' is True);
    # packaging.version.parse would be more robust here.
    if torch.__version__ < '1.9':
        x = x.permute(1, 0, 2)
        x = x.permute(0, 1, 2)
        y = y.permute(2, 3, 1, 0)
        y = y.permute(3, 1, 0, 2)
        z = z.permute(1, 3, 0, 4, 2)
        z = z.permute(0, 2, 4, 3, 1)
    else:
        x = torch.permute(x, (1, 0, 2))
        x = torch.permute(x, (0, 1, 2))
        y = torch.permute(y, (2, 3, 1, 0))
        y = torch.permute(y, (3, 1, 0, 2))
        z = torch.permute(z, (1, 3, 0, 4, 2))
        z = torch.permute(z, (0, 2, 4, 3, 1))
    return x, y, z
def __extract_map_features(self, feature_map, x_tml_grid):
    """Extract features from feature map using bilinear interpolation of car points.

    Parameters
    ==========
    feature_map : torch.Tensor
        Has shape (B, C, H, W).
    x_tml_grid : torch.Tensor
        Has shape (B, K, A, D=2).

    Returns
    =======
    torch.Tensor
        Has shape (B, K, A, C).
    """
    # x_tml_grid = x_tml_grid.reshape(-1, D)
    # h_map has shape (B, C, K, A)
    h_map = F.grid_sample(feature_map, x_tml_grid)
    # h_map has shape (B, K, A, C)
    h_map = torch.permute(h_map, (0, 2, 3, 1))
    return h_map
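# Illustrative toy use of F.grid_sample as above: grid coordinates are
# normalized to [-1, 1] in (x, y) order, and the channel dim moves last via
# permute. Shapes here are made up.
import torch
import torch.nn.functional as F

feature_map = torch.randn(1, 8, 16, 16)        # (B, C, H, W)
grid = torch.rand(1, 4, 6, 2) * 2 - 1          # (B, K, A, 2), values in [-1, 1]
sampled = F.grid_sample(feature_map, grid, align_corners=False)  # (B, C, K, A)
features = torch.permute(sampled, (0, 2, 3, 1))  # (B, K, A, C)
print(features.shape)  # torch.Size([1, 4, 6, 8])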
def get_probabilities(input_sents, tgtlist, model, tokenizer, bert=True):
    token_probs = []
    for i, (tokens_tensor, mi, _) in enumerate(prep_input(input_sents, tokenizer, bert=bert)):
        tokens_tensor = tokens_tensor.to(device)
        with torch.no_grad():
            predictions = model(tokens_tensor).logits
            predictions = torch.permute(predictions, (1, 0, 2))
        tgt = tgtlist[i]
        if bert:
            softpred = torch.softmax(predictions[0, mi], 0)
        else:
            softpred = torch.softmax(predictions[0, mi, :], 0)
        try:
            tgt_ind = tokenizer.convert_tokens_to_ids([tgt])[0]
        except Exception:
            this_tgt_prob = np.nan
        else:
            this_tgt_prob = softpred[tgt_ind].item()
        token_probs.append(this_tgt_prob)
    return token_probs
def __getitem__(self, idx: int):
    input_ID = self.inputs[idx]
    label_ID = self.labels[idx]
    x, y = imread(input_ID), imread(label_ID)
    if self.transform is not None:
        x, y = self.transform(x, y)
    # y = y[np.newaxis, :]
    # one-hot encode the label map: class i (1..10) goes to channel i - 1;
    # class 0 is skipped, so the last channel stays all-zero
    y_tmp = np.zeros([11, 360, 480])
    for i in range(11):
        if i == 0:
            continue
        y_tmp[i - 1, :, :] = y == i
    y = y_tmp
    x, y = torch.from_numpy(x).type(self.inputs_dtype), torch.from_numpy(y).type(self.labels_dtype)
    x = torch.permute(x, [2, 0, 1])  # HWC -> CHW
    return x, y
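# The channel loop above can be written as one vectorized comparison; a small
# illustrative check (sizes taken from the code above, labels assumed in 0..10):
import numpy as np

y = np.random.randint(0, 11, size=(360, 480))
loop = np.zeros([11, 360, 480])
for i in range(1, 11):
    loop[i - 1] = y == i
vec = (y == np.arange(1, 12).reshape(-1, 1, 1)).astype(loop.dtype)
assert np.array_equal(loop, vec)  # holds when labels stay in 0..10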
def forward(self, instance_prob, x=None):
    """
    :param instance_prob:
    :param x: should specify x if self.pooling='att', x.shape=(B, C, 1, T)
    :return:
    """
    if self.pooling == 'max':
        bag_prob, _ = instance_prob.max(dim=1)
        return bag_prob, instance_prob
    elif self.pooling == 'ave':
        bag_prob = instance_prob.mean(dim=1)
        return bag_prob, instance_prob
    elif self.pooling == 'lin':
        bag_prob = (instance_prob * instance_prob).sum(dim=1) / instance_prob.sum(dim=1)
        return bag_prob, instance_prob
    elif self.pooling == 'exp':
        bag_prob = (instance_prob * instance_prob.exp()).sum(dim=1) / instance_prob.exp().sum(dim=1)
        return bag_prob, instance_prob
    elif self.pooling == 'att':
        x = x.view(x.size(0), x.size(1), int(x.size(2) * x.size(3)))  # (B, C, F, T) -> (B, C, F*T)
        x = torch.permute(x, (0, 2, 1))  # (B, C, F*T) -> (B, F*T/nb_ins, C)
        instance_att = F.softmax(self.fc_att(x), dim=1)  # (B, nb_ins, feature_dim) -> (B, nb_ins, nb_class)
        bag_prob = (instance_prob * instance_att).sum(dim=1)
        return bag_prob, instance_prob, instance_att
def forward(self, x, y, z):
    # NOTE: as above, lexicographic version comparison misorders '1.10' and '1.9'
    if torch.__version__ < '1.9':
        x = x.permute(1, 0)
        x = x.permute(0, 1)
        y = y.permute(2, 1, 0)
        y = y.permute(1, 0, 2)
        z = z.permute(1, 3, 0, 2)
        z = z.permute(2, 0, 3, 1)
    else:
        x = torch.permute(x, (1, 0))
        x = torch.permute(x, (0, 1))
        y = torch.permute(y, (2, 1, 0))
        y = torch.permute(y, (1, 0, 2))
        z = torch.permute(z, (1, 3, 0, 2))
        z = torch.permute(z, (2, 0, 3, 1))
    x = F.relu(x)
    y = F.relu(y)
    z = F.relu(z)
    return x, y, z
def forward(
    self,
    inputs: Dict,  # encoder outputs
) -> Dict:
    unembeddable_encoder_outputs = [
        inputs[k]["encoder_output"] for k in inputs if k in self.unembeddable_features
    ]
    embeddable_encoder_outputs = [
        inputs[k]["encoder_output"] for k in inputs if k in self.embeddable_features
    ]

    batch_size = (
        embeddable_encoder_outputs[0].shape[0]
        if len(embeddable_encoder_outputs) > 0
        else unembeddable_encoder_outputs[0].shape[0]
    )

    # ================ Project & Concat embeddables ================
    if len(embeddable_encoder_outputs) > 0:
        # ============== Flatten =================
        embeddable_encoder_outputs = [
            torch.reshape(eo, [batch_size, -1]) for eo in embeddable_encoder_outputs
        ]

        projected = [self.projectors[i](eo) for i, eo in enumerate(embeddable_encoder_outputs)]
        hidden = torch.stack(projected)  # num_eo, bs, h
        hidden = torch.permute(hidden, (1, 0, 2))  # bs, num_eo, h

        if self.embed_input_feature_name:
            i_f_names_idcs = torch.reshape(torch.arange(0, len(embeddable_encoder_outputs)), [-1, 1])
            embedded_i_f_names = self.embed_i_f_name_layer(i_f_names_idcs)
            embedded_i_f_names = torch.unsqueeze(embedded_i_f_names, dim=0)
            embedded_i_f_names = torch.tile(embedded_i_f_names, [batch_size, 1, 1])
            if self.embed_input_feature_name == "add":
                hidden = hidden + embedded_i_f_names
            else:
                hidden = torch.cat([hidden, embedded_i_f_names], -1)

        # ================ Transformer Layers ================
        hidden = self.transformer_stack(hidden)

        # ================ Sequence Reduction ================
        hidden = self.reduce_sequence(hidden)
    else:
        # create empty tensor because there are no category features
        hidden = torch.empty([batch_size, 0])

    # ================ Concat Skipped ================
    if len(unembeddable_encoder_outputs) > 0:
        unembeddable_encoder_outputs = [
            torch.reshape(eo, [batch_size, -1]) for eo in unembeddable_encoder_outputs
        ]
        # ================ Flatten ================
        if len(unembeddable_encoder_outputs) > 1:
            unembeddable_hidden = torch.cat(unembeddable_encoder_outputs, -1)  # tf.keras.layers.concatenate
        else:
            unembeddable_hidden = list(unembeddable_encoder_outputs)[0]
        unembeddable_hidden = self.layer_norm(unembeddable_hidden)
    else:
        # create empty tensor because there are no numeric/binary features
        unembeddable_hidden = torch.tile(self.empty_hidden, [batch_size, 0])

    # ================ Concat Skipped and Others ================
    hidden = torch.cat([hidden, unembeddable_hidden], -1)

    # ================ FC Layers ================
    hidden = self.fc_stack(hidden)

    return_data = {"combiner_output": hidden}
    if len(inputs) == 1:
        for key, value in [d for d in inputs.values()][0].items():
            if key != "encoder_output":
                return_data[key] = value

    return return_data