def forward(self, feat_emb_value):
    # Equivalent explicit computation, kept for reference:
    # square_of_sum = torch.sum(feat_emb_value, dim=1)           # N * emb_dim
    # square_of_sum = torch.mul(square_of_sum, square_of_sum)    # N * emb_dim
    # sum_of_square = torch.mul(feat_emb_value, feat_emb_value)  # N * num_fields * emb_dim
    # sum_of_square = torch.sum(sum_of_square, dim=1)            # N * emb_dim
    # bi_out = square_of_sum - sum_of_square
    bi_out = bi_interaction(feat_emb_value)  # N * emb_dim
    return bi_out
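# --- A minimal sketch (an assumption, not shown in this snippet) of the bi_interaction
# helper called above. Based on the commented-out reference code, it is taken to compute
# square-of-sum minus sum-of-square over the field dimension, without the 0.5 factor
# (callers apply the 0.5 themselves where the FM formula needs it).
import torch

def bi_interaction(feat_emb_value):
    """Bi-interaction pooling: (sum_i v_i)^2 - sum_i v_i^2 along the field axis."""
    summed = torch.sum(feat_emb_value, dim=1)                          # N * emb_dim
    square_of_sum = summed * summed                                    # N * emb_dim
    sum_of_square = torch.sum(feat_emb_value * feat_emb_value, dim=1)  # N * emb_dim
    return square_of_sum - sum_of_square                               # N * emb_dim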
def forward(self, feat_index):
    feat_emb = self.emb_layer(feat_index)  # N * num_categories * emb_dim
    field_wise_emb_list = [
        feat_emb[:, field_range]  # N * num_categories_in_field * emb_dim
        for field_range in self.field_ranges
    ]
    field_emb_list = [
        torch.sum(field_wise_emb, dim=1).unsqueeze(dim=1)  # N * 1 * emb_dim
        for field_wise_emb in field_wise_emb_list
    ]
    field_emb = torch.cat(field_emb_list, dim=1)  # N * num_fields * emb_dim

    # S part (first-order / linear terms)
    y_S = self.first_order_weights(feat_index)   # N * num_categories * 1
    y_S = y_S.squeeze(dim=2)                     # N * num_categories (keep the batch dim even when N == 1)
    y_S = torch.sum(y_S, dim=1)                  # N
    y_S = torch.add(y_S, self.first_order_bias)  # N
    y_S = y_S.unsqueeze(dim=1)                   # N * 1

    # MF part (inter-field interactions) -> N * emb_dim
    p, q = build_cross(self.num_fields, field_emb)  # each N * num_pairs * emb_dim
    y_MF = torch.mul(p, q)             # N * num_pairs * emb_dim
    y_MF = torch.mul(y_MF, self.r_mf)  # N * num_pairs * emb_dim
    y_MF = torch.sum(y_MF, dim=1)      # N * emb_dim

    # FM part (intra-field interactions)
    field_wise_fm = [
        bi_interaction(field_wise_emb).unsqueeze(dim=1)  # N * 1 * emb_dim
        for field_wise_emb in field_wise_emb_list
    ]
    field_wise_fm = torch.cat(field_wise_fm, dim=1)  # N * num_fields * emb_dim
    y_FM = torch.mul(field_wise_fm, self.r_fm)       # N * num_fields * emb_dim
    y_FM = torch.sum(y_FM, dim=1)                    # N * emb_dim

    # dnn part
    fc_in = field_emb.reshape((-1, self.num_fields * self.emb_dim))
    y_dnn = self.fc_layers(fc_in)

    # output
    fwBI = y_MF + y_FM
    fwBI = torch.cat([y_S, fwBI], dim=1)  # N * (emb_dim + 1)
    y = torch.cat([fwBI, y_dnn], dim=1)   # N * (fc_dims[-1] + emb_dim + 1)
    y = self.output_layer(y)
    return y
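# --- A minimal sketch (an assumption, not shown in this snippet) of the build_cross
# helper used in the MF part above. It is assumed to enumerate every distinct field
# pair (i, j) with i < j and return two aligned tensors p and q, so that p[:, k] and
# q[:, k] hold the embeddings of the k-th pair; num_pairs = num_fields * (num_fields - 1) / 2.
import torch

def build_cross(num_fields, field_emb):
    # field_emb: N * num_fields * emb_dim
    rows, cols = [], []
    for i in range(num_fields - 1):
        for j in range(i + 1, num_fields):
            rows.append(i)
            cols.append(j)
    p = field_emb[:, rows]  # N * num_pairs * emb_dim
    q = field_emb[:, cols]  # N * num_pairs * emb_dim
    return p, q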
def forward(self, feat_index, feat_value):
    # Shapes: batch size N, number of fields F, embedding dim K.
    # A single sample should first be expanded with a batch dimension (1 * F).
    # feat_index: N * F
    # feat_value: N * F

    # compute first order
    feat_value = torch.unsqueeze(feat_value, dim=2)             # N * F * 1
    first_order_weights = self.first_order_weights(feat_index)  # N * F * 1
    first_order = torch.mul(feat_value, first_order_weights)    # N * F * 1
    first_order = torch.squeeze(first_order, dim=2)             # N * F
    y_first_order = torch.sum(first_order, dim=1)               # N

    # compute second order
    # look up embedding table
    feat_emb = self.emb_layer(feat_index)             # N * F * K
    feat_emb_value = torch.mul(feat_emb, feat_value)  # N * F * K, element-wise mul
    # Equivalent explicit computation, kept for reference:
    # squared_feat_emb = torch.pow(feat_emb_value, 2)     # N * F * K
    # sum_of_square = torch.sum(squared_feat_emb, dim=1)  # N * K
    # summed_feat_emb = torch.sum(feat_emb_value, dim=1)  # N * K
    # square_of_sum = torch.pow(summed_feat_emb, 2)       # N * K
    BI = bi_interaction(feat_emb_value)                 # N * K
    y_second_order = 0.5 * BI                           # N * K
    y_second_order = torch.sum(y_second_order, dim=1)   # N

    # compute y
    y = self.bias + y_first_order + y_second_order  # N
    y = self.output_layer(y)
    return y
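# --- Quick numerical check (relies on the bi_interaction sketch above): 0.5 * BI,
# summed over the embedding dimension, equals the explicit sum of pairwise inner
# products <v_i, v_j> over all field pairs, the classic FM second-order identity
# sum_{i<j} <v_i, v_j> = 0.5 * [ (sum_i v_i)^2 - sum_i v_i^2 ] summed over dims.
import torch

N, F, K = 4, 5, 8
feat_emb_value = torch.randn(N, F, K)

fast = 0.5 * torch.sum(bi_interaction(feat_emb_value), dim=1)  # N
naive = torch.zeros(N)
for i in range(F - 1):
    for j in range(i + 1, F):
        naive += torch.sum(feat_emb_value[:, i] * feat_emb_value[:, j], dim=1)

assert torch.allclose(fast, naive, atol=1e-5)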