def test_unique(self):
    """LogUniformSampler draws unique ids whose counts match its estimates.

    A reference draw over labels 0..N-1 provides ``expected`` per-label
    counts.  Each of RND further draws must report per-label counts within
    ``expected * 0.5`` of the reference (per EXPECT_NEAR), and the empirical
    per-id hit rate accumulated over those draws must be within 0.2 of the
    reference.
    """
    N = 100
    nsampled = 50
    RND = 100
    sampler = LogUniformSampler(N)
    histogram = [0] * N

    reference = sampler.sample(nsampled, np.arange(N, dtype=np.int32))
    sample_ids, expected, sample_freq = reference
    # All drawn ids must be distinct.
    self.assertEqual(len(set(sample_ids)), nsampled)

    for _ in range(RND):
        sample_ids, true_freq, sample_freq = sampler.sample(
            nsampled, np.arange(N, dtype=np.int32))
        for idx in range(N):
            tolerance = expected[idx] * 0.5
            self.assertTrue(
                EXPECT_NEAR(expected[idx], true_freq[idx], tolerance))
        for idx in range(nsampled):
            histogram[sample_ids[idx]] += 1

    for idx, hits in enumerate(histogram):
        self.assertTrue(EXPECT_NEAR(expected[idx], hits / RND, 0.2))
class SampledSoftmax(nn.Module):
    """Linear output layer trained with sampled softmax.

    With ``forward(..., train=True)`` each row's true label is scored against
    ``nsampled`` ids drawn by a ``LogUniformSampler`` over ``ntokens``. Every
    logit is corrected by subtracting the log of the frequency the sampler
    reports for it. Otherwise the full ``nn.Linear`` projection is returned.
    """

    def __init__(self, ntokens, nsampled, nhid, tied_weight):
        super(SampledSoftmax, self).__init__()

        # Parameters
        self.ntokens = ntokens
        self.nsampled = nsampled

        self.sampler = LogUniformSampler(self.ntokens)
        self.params = nn.Linear(nhid, ntokens)

        if tied_weight is not None:
            # Reuse an externally owned weight matrix instead of a fresh one.
            self.params.weight = tied_weight
        else:
            util.initialize(self.params, self.ntokens)
        # The bias always starts at zero, whether or not the weight is tied.
        self.params.bias.data.fill_(0)

    def forward(self, inputs, labels, train=False):
        """Return ``(logits, targets)``.

        Uses sampled logits when ``train`` is true, and full-vocabulary logits
        otherwise.
        """
        if train:
            sample_values = self.sampler.sample(
                self.nsampled, labels.detach().cpu().numpy())
            return self.sampled(inputs, labels, sample_values)
        return self.full(inputs, labels)

    def sampled(self, inputs, labels, sample_values):
        """Frequency-corrected logits for the true labels plus the samples.

        Args:
            inputs: hidden states, unpacked as ``(batch_size, nhid)``.
            labels: true label ids, one per row of ``inputs``; must live on the
                same device as ``inputs``.
            sample_values: ``(sample_ids, true_freq, sample_freq)`` as produced
                by ``self.sampler.sample``.

        Returns:
            ``(logits, new_targets)``. Column 0 of ``logits`` holds the
            true-label logit and columns ``1..`` hold the sampled logits, so
            every entry of ``new_targets`` is 0.
        """
        assert inputs.device == labels.device
        device = labels.device
        batch_size, _ = inputs.size()
        sample_ids, true_freq, sample_freq = sample_values

        # Build the sampler outputs directly on the inputs' device. Unlike
        # ``.cuda(get_device())``, this also works for CPU tensors, where
        # get_device() is -1.
        sample_ids = torch.as_tensor(sample_ids, dtype=torch.long, device=device)
        true_freq = torch.as_tensor(true_freq, dtype=torch.float32, device=device)
        sample_freq = torch.as_tensor(sample_freq, dtype=torch.float32, device=device)

        # gather true labels - weights and biases
        true_weights = self.params.weight[labels]
        true_bias = self.params.bias[labels]

        # gather sample ids - weights and biases
        sample_weights = self.params.weight[sample_ids]
        sample_bias = self.params.bias[sample_ids]

        # calculate logits
        true_logits = (inputs * true_weights).sum(dim=1) + true_bias
        sample_logits = inputs.matmul(sample_weights.t()) + sample_bias

        # Subtract the log frequencies reported by the sampler.
        true_logits = true_logits - torch.log(true_freq)
        sample_logits = sample_logits - torch.log(sample_freq)

        logits = torch.cat((true_logits.unsqueeze(1), sample_logits), dim=1)
        new_targets = torch.zeros(batch_size, dtype=torch.long, device=device)
        return logits, new_targets

    def full(self, inputs, labels):
        """Full-vocabulary logits; labels are passed through unchanged."""
        return self.params(inputs), labels
class SampledSoftmax(nn.Module):
    """Sampled softmax loss over a learned ``(ntokens, nhid)`` output matrix.

    ``forward`` draws ``nsampled`` ids with a ``LogUniformSampler`` over
    ``ntokens`` and returns a per-example loss (see ``sampled``).
    """

    def __init__(self, ntokens, nsampled, nhid, device):
        super(SampledSoftmax, self).__init__()

        # Parameters
        self.ntokens = ntokens
        self.nsampled = nsampled
        self.device = device

        self.sampler = LogUniformSampler(self.ntokens)
        self.weight = nn.Parameter(torch.Tensor(ntokens, nhid))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise ``self.weight`` uniformly in +/- sqrt(6 / (rows + cols))."""
        stdv = math.sqrt(6.0 / (self.weight.size(0) + self.weight.size(1)))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, inputs, labels):
        # sample ids according to word distribution - Unique
        sample_values = self.sampler.sample(
            self.nsampled, labels.detach().cpu().numpy())
        return self.sampled(inputs, labels, sample_values,
                            remove_accidental_match=True)

    def sampled(self, inputs, labels, sample_values,
                remove_accidental_match=False):
        """@Dai Quoc Nguyen: sampled softmax loss as described in "On Using
        Very Large Target Vocabulary for Neural Machine Translation"
        (https://www.aclweb.org/anthology/P15-1001/).

        For each row with true logit ``t = <x, W[label]>`` and sampled logits
        ``s_j = <x, W[id_j]>``, returns ``-log(exp(t) / sum_j exp(s_j))``. The
        denominator sums only over the sampled ids, so the true label
        contributes to it only if it was itself sampled. The value is computed
        as ``logsumexp(s) - t``, which is algebraically identical but does not
        overflow for large logits the way explicit ``exp`` does.

        Args:
            inputs: hidden states; must be 2-D ``(batch_size, nhid)``.
            labels: true label ids on the same device as ``inputs``.
            sample_values: ``(sample_ids, true_freq, sample_freq)``. Only
                ``sample_ids`` is used; no log-frequency correction is applied.
            remove_accidental_match: accepted for API compatibility.
                NOTE(review): currently ignored, so sampled ids equal to a
                row's label are not masked.

        Returns:
            A 1-D tensor with one loss value per row of ``inputs``.
        """
        assert inputs.device == labels.device
        if inputs.dim() != 2:
            raise ValueError("inputs must be 2-D (batch_size, nhid), got "
                             "shape {}".format(tuple(inputs.shape)))
        sample_ids, _true_freq, _sample_freq = sample_values
        sample_ids = torch.as_tensor(sample_ids, dtype=torch.long,
                                     device=self.device)

        # gather true labels
        true_weights = torch.index_select(self.weight, 0, labels)
        # gather sample ids
        sample_weights = torch.index_select(self.weight, 0, sample_ids)

        # raw (un-exponentiated) logits
        true_logits = torch.sum(inputs * true_weights, dim=1)
        sample_logits = torch.matmul(inputs, sample_weights.t())

        return torch.logsumexp(sample_logits, dim=1) - true_logits
def test_AccidentalMatch(self):
    """Sampled ids that coincide with a row's true label get ~zero mass.

    After ``ss.sampled(..., remove_accidental_match=True)``, every sampled
    column whose id equals the row's label must satisfy exp(logit) ~ 0
    (within 1e-4). The cross-entropy loss over the result must be finite.
    """
    np.random.seed(1000)
    num_classes = 5
    batch_size = 3
    nsampled = 4
    nhid = 10

    labels_np = np.random.randint(low=0, high=num_classes, size=batch_size)
    weights, biases, hidden_acts, _, _, _ = self._GenerateTestData(
        num_classes=num_classes, dim=nhid, batch_size=batch_size,
        num_true=1, labels=labels_np, sampled=[1, 0, 2, 3],
        subtract_log_q=True)

    ss = model.SampledSoftmax(num_classes, nsampled, nhid, tied_weight=None)
    ss.params.weight.data = torch.from_numpy(weights)
    ss.params.bias.data = torch.from_numpy(biases)
    ss.params.cuda()

    hidden_acts = torch.from_numpy(hidden_acts).cuda()
    labels = torch.as_tensor(labels_np, dtype=torch.long).cuda()
    # Host-side copy of the labels, used both for the sampler and for the
    # id comparisons below (avoids comparing numpy ints to CUDA tensors).
    labels_host = labels.cpu().numpy()

    # Sampler over [0, nsampled); assuming unique draws, all nsampled ids
    # are selected, so every label < nsampled is an accidental match.
    sampler = LogUniformSampler(nsampled)
    sampled_values = sampler.sample(nsampled, labels_host)
    sample_ids, true_freq, sample_freq = sampled_values

    logits, new_targets = ss.sampled(hidden_acts, labels, sampled_values,
                                     remove_accidental_match=True)

    # The masked logits must still yield a usable loss.
    criterion = nn.CrossEntropyLoss()
    loss = criterion(logits.view(-1, nsampled + 1), new_targets)
    self.assertTrue(bool(torch.isfinite(loss).item()))

    np_logits = logits.detach().cpu().numpy()
    for row, label in enumerate(labels_host):
        for col in range(nsampled):
            if sample_ids[col] == label:
                # Column 0 holds the true logit; samples start at column 1.
                self.assertTrue(
                    EXPECT_NEAR(np.exp(np_logits[row, col + 1]), 0, 1e-4))
class SampledSoftmax(nn.Module):
    """Linear output layer trained with sampled softmax.

    In training mode, each row's true label is scored against ``nsampled``
    ids drawn by a ``LogUniformSampler`` over ``ntokens``. Sampled ids equal
    to the row's label are masked out, and all logits are corrected by the
    log of the sampler-reported frequency. In eval mode the full
    ``nn.Linear`` projection is returned.
    """

    def __init__(self, ntokens, nsampled, nhid, tied_weight):
        super(SampledSoftmax, self).__init__()

        # Parameters
        self.ntokens = ntokens
        self.nsampled = nsampled

        self.sampler = LogUniformSampler(self.ntokens)
        self.params = nn.Linear(nhid, ntokens)

        if tied_weight is not None:
            # Reuse an externally owned weight matrix instead of a fresh one.
            self.params.weight = tied_weight
        else:
            util.initialize(self.params.weight)

    def forward(self, inputs, labels):
        """Return ``(logits, targets)``, sampled in training mode, full otherwise."""
        if self.training:
            # sample ids according to word distribution - Unique
            sample_values = self.sampler.sample(
                self.nsampled, labels.detach().cpu().numpy())
            return self.sampled(inputs, labels, sample_values,
                                remove_accidental_match=True)
        return self.full(inputs, labels)

    def sampled(self, inputs, labels, sample_values,
                remove_accidental_match=False):
        """Frequency-corrected logits for the true labels plus the samples.

        Args:
            inputs: hidden states, unpacked as ``(batch_size, nhid)``.
            labels: true label ids on the same device as ``inputs``.
            sample_values: ``(sample_ids, true_freq, sample_freq)`` as produced
                by ``self.sampler.sample``.
            remove_accidental_match: when true, set to -1e37 the sampled
                logits at the (row, column) pairs reported by
                ``self.sampler.accidental_match``.

        Returns:
            ``(logits, new_targets)``. Column 0 of ``logits`` is the true-label
            logit and columns ``1..`` are the sampled logits, so every entry
            of ``new_targets`` is 0.
        """
        assert inputs.device == labels.device
        device = labels.device
        batch_size, _ = inputs.size()
        sample_ids, true_freq, sample_freq = sample_values

        # Build the sampler outputs on the inputs' device. Unlike
        # ``.cuda(get_device())``, this also works for CPU tensors.
        sample_ids = torch.as_tensor(sample_ids, dtype=torch.long, device=device)
        true_freq = torch.as_tensor(true_freq, dtype=torch.float32, device=device)
        sample_freq = torch.as_tensor(sample_freq, dtype=torch.float32, device=device)

        # gather true labels - weights and biases
        true_weights = torch.index_select(self.params.weight, 0, labels)
        true_bias = torch.index_select(self.params.bias, 0, labels)

        # gather sample ids - weights and biases
        sample_weights = torch.index_select(self.params.weight, 0, sample_ids)
        sample_bias = torch.index_select(self.params.bias, 0, sample_ids)

        # calculate logits
        true_logits = torch.sum(inputs * true_weights, dim=1) + true_bias
        sample_logits = torch.matmul(inputs, sample_weights.t()) + sample_bias

        # remove true labels from sample set
        if remove_accidental_match:
            acc_hits = self.sampler.accidental_match(
                labels.detach().cpu().numpy(),
                sample_ids.detach().cpu().numpy())
            if len(acc_hits) > 0:
                # acc_hits is a sequence of (row, sample_column) pairs; index
                # with explicit row/column tensors rather than relying on the
                # legacy "list of sequences as a tuple index" behaviour.
                rows, cols = zip(*acc_hits)
                hit_rows = torch.as_tensor(rows, dtype=torch.long, device=device)
                hit_cols = torch.as_tensor(cols, dtype=torch.long, device=device)
                sample_logits[hit_rows, hit_cols] = -1e37

        # Subtract the log frequencies reported by the sampler.
        true_logits = true_logits - torch.log(true_freq)
        sample_logits = sample_logits - torch.log(sample_freq)

        logits = torch.cat((true_logits.unsqueeze(1), sample_logits), dim=1)
        new_targets = torch.zeros(batch_size, dtype=torch.long, device=device)
        return logits, new_targets

    def full(self, inputs, labels):
        """Full-vocabulary logits; labels are passed through unchanged."""
        return self.params(inputs), labels
class Trainer(object):
    """Train the category/item model, evaluating and checkpointing per epoch.

    Each epoch does the following:
      * runs ``train_epoch``;
      * evaluates on the training data (``"train"`` flag) and on ``eval_data``
        (``"test"`` flag);
      * logs scalars;
      * turns teacher forcing off once training category recall stops
        improving;
      * early-stops on test item recall, saving the best checkpoint.
    """

    def __init__(self, log, model, beam_size, train_data, eval_data, optim,
                 use_cuda, loss_func, cate_loss_func, topk, input_size, args):
        self.m_log = log
        self.model = model
        self.train_data = train_data
        self.eval_data = eval_data
        self.optim = optim
        self.m_loss_func = loss_func
        self.topk = topk
        # Evaluation's signature is (log, model, loss_func, use_cuda,
        # input_size, k, warm_start). Pass input_size so its negative sampler
        # spans the same id range as self.m_sampler below. Pass topk as k by
        # keyword; passed positionally it would land in input_size and k
        # would stay at its default.
        self.evaluation = Evaluation(self.m_log, self.model, self.m_loss_func,
                                     use_cuda, input_size, k=self.topk,
                                     warm_start=args.warm_start)
        self.device = torch.device('cuda' if use_cuda else 'cpu')
        self.args = args
        self.m_cate_loss_func = cate_loss_func

        ### early stopping
        self.m_patience = args.patience
        self.m_best_recall = 0.0
        self.m_best_mrr = 0.0
        self.m_early_stop = False
        self.m_counter = 0
        self.m_best_cate_recall = 0.0
        self.m_best_cate_mrr = 0.0
        self.m_batch_iter = 0

        #### sample negative items
        self.m_sampler = LogUniformSampler(input_size)
        self.m_nsampled = args.negative_num
        self.m_remove_match = True

        self.m_teacher_forcing_ratio = 2.0
        self.m_beam_size = beam_size
        self.m_teacher_forcing_flag = True
        self.m_logsoftmax = nn.LogSoftmax(dim=1)

    def saveModel(self, epoch, loss, recall, mrr):
        """Save model state and metadata to ``<checkpoint_dir>/model_best.pt``."""
        # NOTE(review): the optimizer object itself (not its state_dict) is
        # pickled; kept as-is so existing checkpoint loaders keep working.
        checkpoint = {
            'model': self.model.state_dict(),
            'args': self.args,
            'epoch': epoch,
            'optim': self.optim,
            'loss': loss,
            'recall': recall,
            'mrr': mrr
        }
        model_name = os.path.join(self.args.checkpoint_dir, "model_best.pt")
        torch.save(checkpoint, model_name)

    def train(self, start_epoch, end_epoch, batch_size, start_time=None):
        """Run epochs ``start_epoch..end_epoch`` (inclusive) with early stopping."""
        self.start_time = time.time() if start_time is None else start_time

        for epoch in range(start_epoch, end_epoch + 1):
            msg = "*" * 10 + str(epoch) + "*" * 5
            self.m_log.addOutput2IO(msg)
            print("teaching", self.m_teacher_forcing_flag)
            st = time.time()

            ### an epoch
            train_mixture_loss, train_loss, train_cate_loss = self.train_epoch(
                epoch, batch_size)

            ### evaluate model on train dataset
            (mixture_loss, loss, recall, mrr, cate_loss, cate_recall,
             cate_mrr) = self.evaluation.eval(self.train_data, batch_size, "train")

            print("train", train_loss)
            print("mix", mixture_loss)
            print("loss", loss)
            print("recall", recall)
            print("mrr", mrr)

            msg = "train Epoch: {}, train loss: {:.4f}, mixture loss: {:.4f}, loss: {:.4f}, recall: {:.4f}, mrr: {:.4f}, cate_loss: {:.4f}, cate_recall: {:.4f}, cate_mrr: {:.4f}, time: {}".format(
                epoch, train_mixture_loss, mixture_loss, loss, recall, mrr,
                cate_loss, cate_recall, cate_mrr, time.time() - st)
            self.m_log.addOutput2IO(msg)
            self.m_log.addScalar2Tensorboard("train_mixture_loss", train_mixture_loss, epoch)
            self.m_log.addScalar2Tensorboard("mixture_loss", mixture_loss, epoch)
            self.m_log.addScalar2Tensorboard("train_loss_eval", loss, epoch)
            self.m_log.addScalar2Tensorboard("train_recall", recall, epoch)
            self.m_log.addScalar2Tensorboard("train_mrr", mrr, epoch)
            self.m_log.addScalar2Tensorboard("train_cate_loss_eval", cate_loss, epoch)
            self.m_log.addScalar2Tensorboard("train_cate_recall", cate_recall, epoch)
            self.m_log.addScalar2Tensorboard("train_cate_mrr", cate_mrr, epoch)

            # Teacher forcing is switched off (permanently) the first time the
            # training category recall fails to beat the best seen so far.
            if self.m_best_cate_recall == 0:
                self.m_best_cate_recall = cate_recall
            elif self.m_best_cate_recall >= cate_recall:
                self.m_teacher_forcing_flag = False
            else:
                self.m_best_cate_recall = cate_recall

            ### evaluate model on test dataset
            (mixture_loss, loss, recall, mrr, cate_loss, cate_recall,
             cate_mrr) = self.evaluation.eval(self.eval_data, batch_size, "test")
            msg = "Epoch: {}, mixture loss: {:.4f}, loss: {:.4f}, recall: {:.4f}, mrr: {:.4f}, cate_loss: {:.4f}, cate_recall: {:.4f}, cate_mrr: {:.4f}, time: {}".format(
                epoch, mixture_loss, loss, recall, mrr, cate_loss,
                cate_recall, cate_mrr, time.time() - st)
            self.m_log.addOutput2IO(msg)
            self.m_log.addScalar2Tensorboard("test_mixture_loss", mixture_loss, epoch)
            self.m_log.addScalar2Tensorboard("test_loss", loss, epoch)
            self.m_log.addScalar2Tensorboard("test_recall", recall, epoch)
            self.m_log.addScalar2Tensorboard("test_mrr", mrr, epoch)
            self.m_log.addScalar2Tensorboard("test_cate_loss", cate_loss, epoch)
            self.m_log.addScalar2Tensorboard("test_cate_recall", cate_recall, epoch)
            self.m_log.addScalar2Tensorboard("test_cate_mrr", cate_mrr, epoch)

            # Early stopping on test item recall. The first epoch (best still
            # 0) is treated like an improvement, so best mrr is recorded too.
            if self.m_best_recall == 0 or recall >= self.m_best_recall:
                self.m_best_recall = recall
                self.m_best_mrr = mrr
                self.saveModel(epoch, loss, recall, mrr)
                self.m_counter = 0
            else:
                self.m_counter += 1
                if self.m_counter > self.m_patience:
                    break
                msg = "early stop counter " + str(self.m_counter)
                self.m_log.addOutput2IO(msg)

            msg = "best recall: " + str(
                self.m_best_recall) + "\t best mrr: \t" + str(self.m_best_mrr)
            self.m_log.addOutput2IO(msg)

    def _sampled_log_prob(self, item_inputs, cate_batch, y_action_batch,
                          sample_values, acc_hits):
        """Score items for one category assignment.

        Runs ``m_itemNN -> fc -> m_ss`` on the training inputs and returns
        ``(log-softmax over m_ss logits, targets returned by m_ss)``.
        """
        sample_ids, true_freq, sample_freq = sample_values
        seq_cate_input, seq_short_input = self.model.m_itemNN(
            *item_inputs, cate_batch, "train")
        mixture_output = torch.cat((seq_cate_input, seq_short_input), dim=1)
        fc_output = self.model.fc(mixture_output)
        sampled_logit_batch, sampled_target_batch = self.model.m_ss(
            fc_output, y_action_batch, sample_ids, true_freq, sample_freq,
            acc_hits, self.device, self.m_remove_match)
        return self.m_logsoftmax(sampled_logit_batch), sampled_target_batch

    def train_epoch(self, epoch, batch_size):
        """Train for one pass over ``train_data``.

        Returns:
            ``(mean mixture loss, mean item loss, mean category loss)``
            averaged over batches.
        """
        self.model.train()
        losses = []
        cate_losses = []
        mixture_losses = []

        for (x_long_cate_action_batch, x_long_cate_batch,
             mask_long_cate_action_batch, mask_long_cate_batch,
             max_long_cate_actionNum_batch, max_long_cateNum_batch,
             pad_x_long_cate_actionNum_batch, x_long_cateNum_batch,
             x_short_action_batch, x_short_cate_batch,
             mask_short_action_batch, pad_x_short_actionNum_batch,
             y_action_batch, y_cate_batch,
             y_action_idx_batch) in self.train_data:

            ### negative samples
            sample_ids, true_freq, sample_freq = self.m_sampler.sample(
                self.m_nsampled, y_action_batch)
            sample_values = (sample_ids, true_freq, sample_freq)

            ### whether excluding current pos sample from negative samples
            # Defaults to [] so acc_hits is always bound when passed to m_ss.
            acc_hits = []
            if self.m_remove_match:
                acc_hits = list(zip(*self.m_sampler.accidental_match(
                    y_action_batch, sample_ids)))

            ### long-range actions: cates, their masks, items, item masks
            x_long_cate_batch = x_long_cate_batch.to(self.device)
            mask_long_cate_batch = mask_long_cate_batch.to(self.device)
            x_long_cate_action_batch = x_long_cate_action_batch.to(self.device)
            mask_long_cate_action_batch = mask_long_cate_action_batch.to(self.device)
            ### short-range actions: items, masks, cates
            x_short_action_batch = x_short_action_batch.to(self.device)
            mask_short_action_batch = mask_short_action_batch.to(self.device)
            x_short_cate_batch = x_short_cate_batch.to(self.device)
            ### target action: item, cate, position in sequence
            y_action_batch = y_action_batch.to(self.device)
            y_cate_batch = y_cate_batch.to(self.device)
            y_action_idx_batch = y_action_idx_batch.to(self.device)

            self.optim.zero_grad()

            seq_cate_short_input = self.model.m_cateNN(
                x_short_cate_batch, mask_short_action_batch,
                pad_x_short_actionNum_batch, "train")
            logit_cate_short = self.model.m_cateNN.m_cate_h2o(seq_cate_short_input)

            item_inputs = (
                x_long_cate_action_batch, x_long_cate_batch,
                mask_long_cate_action_batch, mask_long_cate_batch,
                max_long_cate_actionNum_batch, max_long_cateNum_batch,
                pad_x_long_cate_actionNum_batch, x_long_cateNum_batch,
                x_short_action_batch, mask_short_action_batch,
                pad_x_short_actionNum_batch)

            if self.m_teacher_forcing_flag:
                # Condition item prediction on the ground-truth category.
                pred_item_prob, sampled_target_batch = self._sampled_log_prob(
                    item_inputs, y_cate_batch, y_action_batch,
                    sample_values, acc_hits)
            else:
                # Condition on each of the top-m_beam_size predicted
                # categories, then combine log p(cate) + log p(item | cate)
                # across the beam with logsumexp.
                log_prob_cate_short, pred_cate_index = torch.topk(
                    self.m_logsoftmax(logit_cate_short), self.m_beam_size, dim=-1)
                pred_cate_index = pred_cate_index.detach()

                beam_log_probs = []
                for beam_index in range(self.m_beam_size):
                    sampled_prob_batch, sampled_target_batch = self._sampled_log_prob(
                        item_inputs, pred_cate_index[:, beam_index],
                        y_action_batch, sample_values, acc_hits)
                    beam_log_probs.append(
                        sampled_prob_batch
                        + log_prob_cate_short[:, beam_index].reshape(-1, 1))
                pred_item_prob = torch.logsumexp(
                    torch.stack(beam_log_probs, dim=-1), dim=-1)

            loss_batch = self.m_loss_func(pred_item_prob, sampled_target_batch, "prob")
            losses.append(loss_batch.item())

            cate_loss_batch = self.m_cate_loss_func(logit_cate_short, y_cate_batch, "logit")
            cate_losses.append(cate_loss_batch.item())

            mixture_loss_batch = loss_batch + cate_loss_batch
            mixture_losses.append(mixture_loss_batch.item())

            mixture_loss_batch.backward()
            self.m_batch_iter += 1
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
            self.optim.step()

        mean_mixture_losses = np.mean(mixture_losses)
        mean_losses = np.mean(losses)
        mean_cate_losses = np.mean(cate_losses)

        return mean_mixture_losses, mean_losses, mean_cate_losses
# Benchmark: time torch.multinomial (with and without replacement) on the
# log-uniform distribution, then LogUniformSampler.sample.
# N, num_samples and batch_size must be defined earlier in the script.
p = log_uniform_distribution(N)

# CUDA kernels run asynchronously, so synchronise before reading the clock;
# otherwise only the kernel launch would be timed.
if p.is_cuda:
    sync = lambda: torch.cuda.synchronize(p.device)
else:
    sync = lambda: None

sync()
start_time = time.perf_counter()
torch.multinomial(p, num_samples, replacement=True)
sync()
end_time = time.perf_counter()
print("non_unique multinomial cuda", end_time - start_time)

sync()
start_time = time.perf_counter()
# Draw num_samples (not a hard-coded 100) so all three timings are comparable.
torch.multinomial(p, num_samples)
sync()
end_time = time.perf_counter()
print("unique multinomial cuda", end_time - start_time)

sampler = LogUniformSampler(N)
labels = np.random.choice(N, batch_size)
start_time = time.perf_counter()
sample_id, true_freq, sample_freq = sampler.sample(num_samples, labels)
end_time = time.perf_counter()
print("unique log_uniform c++", end_time - start_time)

# Disabled benchmark, written against an older LogUniformSampler API
# (no-arg constructor, sample(N, num_samples, unique=..., labels=...)):
#
# sampler = LogUniformSampler()
# start_time = time.time()
# sample_id = sampler.sample(N, num_samples, unique=True, labels=labels.tolist())
# end_time = time.time()
# print("unique no_accidental_hits log_uniform c++", end_time - start_time)
#
# label_set = set(labels.tolist())
# for idx in sample_id:
#     assert(idx not in label_set)
class Evaluation(object):
    """Evaluate item and category recommendation quality over a data loader.

    Item scores are computed over the sampled candidate columns returned by
    ``model.m_ss`` and mixed over the top-``m_beam_size`` predicted
    categories. Category scores use the full ``m_cate_h2o`` logits.
    """

    def __init__(self, log, model, loss_func, use_cuda, input_size, k=20, warm_start=5):
        self.model = model
        self.m_loss_func = loss_func
        self.topk = k
        self.warm_start = warm_start
        self.device = torch.device('cuda' if use_cuda else 'cpu')
        self.m_log = log
        self.m_beam_size = 5
        # Negative sampler constructed with input_size (presumably the item
        # vocabulary size; confirm against callers).
        self.m_sampler = LogUniformSampler(input_size)
        self.m_nsampled = 100
        self.m_remove_match = True
        print("evaluation is based on sampled 100", self.m_nsampled)

    @staticmethod
    def _weighted_mean(values, weights):
        """Weighted average of ``values``, or 0.0 if there is nothing to average.

        ``np.average`` raises ZeroDivisionError when the weights sum to zero.
        That happens when every batch was skipped (random subsampling in
        ``"train"`` mode) or every example was masked by ``warm_start``.
        """
        if len(values) == 0 or np.sum(weights) == 0:
            return 0.0
        return np.average(values, weights=weights)

    def eval(self, eval_data, batch_size, train_test_flag):
        """Evaluate the model on ``eval_data``.

        Args:
            eval_data: iterable of 15-tuples of batch tensors.
            batch_size: unused; kept for API compatibility.
            train_test_flag: ``"train"`` evaluates a random ~1/101 of the
                batches to keep per-epoch evaluation cheap; any other value
                evaluates every batch.

        Returns:
            ``(mixture_loss, loss, recall, mrr, cate_loss, cate_recall,
            cate_mrr)``. ``mixture_loss`` and ``loss`` are always 0.0 (item
            losses are not computed here).
        """
        self.model.eval()

        recalls = []
        mrrs = []
        weights = []

        cate_losses = []
        cate_recalls = []
        cate_mrrs = []
        cate_weights = []

        total_test_num = []

        with torch.no_grad():
            for (x_long_cate_action_batch, x_long_cate_batch,
                 mask_long_cate_action_batch, mask_long_cate_batch,
                 max_long_cate_actionNum_batch, max_long_cateNum_batch,
                 pad_x_long_cate_actionNum_batch, x_long_cateNum_batch,
                 x_short_action_batch, x_short_cate_batch,
                 mask_short_action_batch, pad_x_short_actionNum_batch,
                 y_action_batch, y_cate_batch,
                 y_action_idx_batch) in eval_data:

                ### speed evaluation for train dataset: randint(1, 101) is
                ### inclusive, so about 1 batch in 101 is kept.
                if train_test_flag == "train" and random.randint(1, 101) != 10:
                    continue

                sample_ids, true_freq, sample_freq = self.m_sampler.sample(
                    self.m_nsampled, y_action_batch)

                ### whether excluding current pos sample from negative samples
                # Defaults to [] so acc_hits is always bound when passed to m_ss.
                acc_hits = []
                if self.m_remove_match:
                    acc_hits = list(zip(*self.m_sampler.accidental_match(
                        y_action_batch, sample_ids)))

                x_long_cate_action_batch = x_long_cate_action_batch.to(self.device)
                mask_long_cate_action_batch = mask_long_cate_action_batch.to(self.device)
                x_long_cate_batch = x_long_cate_batch.to(self.device)
                mask_long_cate_batch = mask_long_cate_batch.to(self.device)
                x_short_action_batch = x_short_action_batch.to(self.device)
                mask_short_action_batch = mask_short_action_batch.to(self.device)
                x_short_cate_batch = x_short_cate_batch.to(self.device)
                y_action_batch = y_action_batch.to(self.device)
                y_cate_batch = y_cate_batch.to(self.device)

                # Only targets at sequence position >= warm_start are scored.
                warm_start_mask = (y_action_idx_batch >= self.warm_start)

                ### cateNN
                seq_cate_short_input = self.model.m_cateNN(
                    x_short_cate_batch, mask_short_action_batch,
                    pad_x_short_actionNum_batch, "test")
                logit_cate_short = self.model.m_cateNN.m_cate_h2o(seq_cate_short_input)

                ### retrieve top k predicted categories
                prob_cate_short = F.softmax(logit_cate_short, dim=-1)
                cate_prob_beam, cate_id_beam = prob_cate_short.topk(
                    dim=1, k=self.m_beam_size)

                # Mix item predictions over the category beam:
                # p(item) = sum_c p(c) * p(item | c).
                item_prob = None
                sampled_target_batch = None
                for beam_index in range(self.m_beam_size):
                    cate_id_beam_batch = cate_id_beam[:, beam_index].reshape(-1, 1)

                    seq_cate_input, seq_short_input = self.model.m_itemNN(
                        x_long_cate_action_batch, x_long_cate_batch,
                        mask_long_cate_action_batch, mask_long_cate_batch,
                        max_long_cate_actionNum_batch, max_long_cateNum_batch,
                        pad_x_long_cate_actionNum_batch, x_long_cateNum_batch,
                        x_short_action_batch, mask_short_action_batch,
                        pad_x_short_actionNum_batch, cate_id_beam_batch, "test")

                    mixture_output = torch.cat((seq_cate_input, seq_short_input), dim=1)
                    output_batch = self.model.fc(mixture_output)

                    sampled_logit_batch, sampled_target_batch = self.model.m_ss(
                        output_batch, y_action_batch, sample_ids, true_freq,
                        sample_freq, acc_hits, self.device, self.m_remove_match)

                    prob_batch = F.softmax(sampled_logit_batch, dim=-1)
                    item_prob_batch = prob_batch * cate_prob_beam[:, beam_index].reshape(-1, 1)
                    if item_prob is None:
                        item_prob = item_prob_batch
                    else:
                        item_prob = item_prob + item_prob_batch

                ### evaluate cate prediction
                cate_loss_batch = self.m_loss_func(logit_cate_short, y_cate_batch, "logit")
                cate_losses.append(cate_loss_batch.item())

                cate_topk = 5
                cate_recall_batch, cate_mrr_batch = evaluate(
                    logit_cate_short, y_cate_batch, warm_start_mask, k=cate_topk)
                cate_weights.append(int(warm_start_mask.int().sum()))
                cate_recalls.append(cate_recall_batch)
                cate_mrrs.append(cate_mrr_batch)

                ### evaluate item prediction
                # item_prob has one column per candidate returned by m_ss, so
                # it is scored against the targets m_ss returned for those
                # columns, not against raw item ids.
                recall_batch, mrr_batch = evaluate(
                    item_prob, sampled_target_batch, warm_start_mask, k=self.topk)
                weights.append(int(warm_start_mask.int().sum()))
                recalls.append(recall_batch)
                mrrs.append(mrr_batch)

                total_test_num.append(y_action_batch.view(-1).size(0))

        mean_mixture_losses = 0.0
        mean_losses = 0.0
        mean_recall = self._weighted_mean(recalls, weights)
        mean_mrr = self._weighted_mean(mrrs, weights)
        mean_cate_losses = np.mean(cate_losses) if cate_losses else 0.0
        mean_cate_recall = self._weighted_mean(cate_recalls, cate_weights)
        mean_cate_mrr = self._weighted_mean(cate_mrrs, cate_weights)

        msg = "total_test_num" + str(np.sum(total_test_num))
        self.m_log.addOutput2IO(msg)

        return mean_mixture_losses, mean_losses, mean_recall, mean_mrr, mean_cate_losses, mean_cate_recall, mean_cate_mrr