    def test_unique(self):
        N = 100
        nsampled = 50
        RND = 100

        sampler = LogUniformSampler(N)
        histogram = [0 for idx in range(N)]

        all_values = [idx for idx in range(N)]
        sample_ids, expected, sample_freq = sampler.sample(
            nsampled, np.asarray(all_values, dtype=np.int32))

        sample_set = set(sample_ids)
        self.assertEqual(len(sample_set), nsampled)

        for rnd in range(RND):
            sample_ids, true_freq, sample_freq = sampler.sample(
                nsampled, np.asarray(all_values, dtype=np.int32))

            for idx in range(N):
                self.assertTrue(
                    EXPECT_NEAR(expected[idx], true_freq[idx],
                                expected[idx] * 0.5))

            for idx in range(nsampled):
                histogram[sample_ids[idx]] += 1

        for idx in range(N):
            average_count = histogram[idx] / RND
            self.assertTrue(EXPECT_NEAR(expected[idx], average_count, 0.2))
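
The test methods in this listing call an EXPECT_NEAR helper that is not shown. A minimal sketch of such a helper, assuming it only checks an absolute tolerance (the name and behavior mirror the gtest macro of the same name and are an assumption, not part of the original code):

def EXPECT_NEAR(expected, actual, tolerance):
    # Hypothetical helper (not in the original snippets): True when the two
    # values differ by no more than the given absolute tolerance.
    return abs(expected - actual) <= tolerance
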
Example #2
class SampledSoftmax(nn.Module):
    def __init__(self, ntokens, nsampled, nhid, tied_weight):
        super(SampledSoftmax, self).__init__()

        # Parameters
        self.ntokens = ntokens
        self.nsampled = nsampled

        self.sampler = LogUniformSampler(self.ntokens)
        self.params = nn.Linear(nhid, ntokens)

        if tied_weight is not None:
            self.params.weight = tied_weight
        else:
            util.initialize(self.params, self.ntokens)
        self.params.bias.data.fill_(0)

    def forward(self, inputs, labels, train=False):
        if train:
            sample_values = self.sampler.sample(self.nsampled, labels.data.cpu().numpy())
            return self.sampled(inputs, labels, sample_values)
        else:
            return self.full(inputs, labels)

    def sampled(self, inputs, labels, sample_values):
        assert(inputs.data.get_device() == labels.data.get_device())
        device_id = labels.data.get_device()

        batch_size, d = inputs.size()
        sample_ids, true_freq, sample_freq = sample_values

        # sample ids according to word distribution - Unique
        sample_ids = Variable(torch.LongTensor(sample_ids), requires_grad=False).cuda(device_id)
        true_freq = Variable(torch.FloatTensor(true_freq), requires_grad=False).cuda(device_id)
        sample_freq = Variable(torch.FloatTensor(sample_freq), requires_grad=False).cuda(device_id)

        # gather true labels - weights and frequencies
        true_weights = self.params.weight[labels.data, :]
        true_bias = self.params.bias[labels.data]

        # gather sample ids - weights and frequencies
        sample_weights = self.params.weight[sample_ids.data, :]
        sample_bias = self.params.bias[sample_ids.data]

        # calculate logits
        true_logits = torch.sum(torch.mul(inputs, true_weights), dim=1) + true_bias
        sample_logits = torch.matmul(inputs, torch.t(sample_weights)) + sample_bias

        # perform correction
        true_logits = true_logits.sub(torch.log(true_freq))
        sample_logits = sample_logits.sub(torch.log(sample_freq))

        # return logits and new_labels
        logits = torch.cat((torch.unsqueeze(true_logits, dim=1), sample_logits), dim=1)
        new_targets = Variable(torch.zeros(batch_size).long(), requires_grad=False).cuda(device_id)
        return logits, new_targets

    def full(self, inputs, labels):
        return self.params(inputs), labels
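
A hypothetical usage sketch for the module above, assuming LogUniformSampler and SampledSoftmax from this listing are importable and a GPU is available (the module moves the sampled values to CUDA internally); the sizes are illustrative. Column 0 of the returned logits holds the true class, so the new targets are all zeros and a plain cross-entropy over nsampled + 1 columns gives the sampled softmax loss:

import torch
import torch.nn as nn
from torch.autograd import Variable

ntokens, nsampled, nhid, batch_size = 10000, 64, 128, 32   # hypothetical sizes

# Pass a tied weight so the sketch does not depend on the external util module.
tied_weight = nn.Parameter(torch.randn(ntokens, nhid) * 0.01)
ss = SampledSoftmax(ntokens, nsampled, nhid, tied_weight).cuda()

hidden = Variable(torch.randn(batch_size, nhid)).cuda()
labels = Variable(torch.LongTensor(batch_size).random_(0, ntokens)).cuda()

logits, new_targets = ss(hidden, labels, train=True)
loss = nn.CrossEntropyLoss()(logits.view(-1, nsampled + 1), new_targets)
loss.backward()
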
Example #3
class SampledSoftmax(nn.Module):
    def __init__(self, ntokens, nsampled, nhid, device):
        super(SampledSoftmax, self).__init__()

        # Parameters
        self.ntokens = ntokens
        self.nsampled = nsampled
        self.device = device
        self.sampler = LogUniformSampler(self.ntokens)

        self.weight = nn.Parameter(torch.Tensor(ntokens, nhid))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = math.sqrt(6.0 / (self.weight.size(0) + self.weight.size(1)))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, inputs, labels):
        # sample ids according to word distribution - Unique
        sample_values = self.sampler.sample(self.nsampled,
                                            labels.data.cpu().numpy())
        return self.sampled(inputs,
                            labels,
                            sample_values,
                            remove_accidental_match=True)

    """@Dai Quoc Nguyen: Implement the sampled softmax loss function as described in the paper
    On Using Very Large Target Vocabulary for Neural Machine Translation https://www.aclweb.org/anthology/P15-1001/"""

    def sampled(self,
                inputs,
                labels,
                sample_values,
                remove_accidental_match=False):
        assert (inputs.data.get_device() == labels.data.get_device())

        batch_size, d = inputs.size()
        sample_ids, true_freq, sample_freq = sample_values

        sample_ids = Variable(torch.LongTensor(sample_ids)).to(self.device)

        # gather true labels
        true_weights = torch.index_select(self.weight, 0, labels)

        # gather sample ids
        sample_weights = torch.index_select(self.weight, 0, sample_ids)

        # calculate logits
        true_logits = torch.exp(
            torch.sum(torch.mul(inputs, true_weights), dim=1))
        sample_logits = torch.exp(torch.matmul(inputs,
                                               torch.t(sample_weights)))

        logits = -torch.log(true_logits / torch.sum(sample_logits, dim=1))

        return logits
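
A hypothetical usage sketch for this variant, with illustrative sizes; it assumes a recent PyTorch where get_device() returns -1 for CPU tensors so the device check passes on CPU. Unlike the previous example, forward returns a per-example loss, so it is reduced with mean() before backpropagation:

import torch
from torch.autograd import Variable

ntokens, nsampled, nhid, batch_size = 10000, 64, 128, 32   # hypothetical sizes
device = torch.device('cpu')
ss = SampledSoftmax(ntokens, nsampled, nhid, device)

inputs = Variable(torch.randn(batch_size, nhid))
labels = Variable(torch.LongTensor(batch_size).random_(0, ntokens))

loss = ss(inputs, labels).mean()   # per-example losses -> scalar
loss.backward()
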
Example #4
    def test_AccidentalMatch(self):
        np.random.seed(1000)
        num_classes = 5
        batch_size = 3
        nsampled = 4
        nhid = 10
        labels = np.random.randint(low=0, high=num_classes, size=batch_size)

        (weights, biases, hidden_acts, sampled_vals, exp_logits,
         exp_labels) = self._GenerateTestData(num_classes=num_classes,
                                              dim=nhid,
                                              batch_size=batch_size,
                                              num_true=1,
                                              labels=labels,
                                              sampled=[1, 0, 2, 3],
                                              subtract_log_q=True)

        ss = model.SampledSoftmax(num_classes,
                                  nsampled,
                                  nhid,
                                  tied_weight=None)
        ss.params.weight.data = torch.from_numpy(weights)
        ss.params.bias.data = torch.from_numpy(biases)
        ss.params.cuda()

        hidden_acts = Variable(torch.from_numpy(hidden_acts)).cuda()
        labels = Variable(torch.LongTensor(labels)).cuda()

        sampler = LogUniformSampler(nsampled)
        sampled_values = sampler.sample(nsampled, labels.data.cpu().numpy())
        sample_ids, true_freq, sample_freq = sampled_values
        logits, new_targets = ss.sampled(hidden_acts,
                                         labels,
                                         sampled_values,
                                         remove_accidental_match=True)

        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits.view(-1, nsampled + 1), new_targets)

        np_logits = logits.data.cpu().numpy()
        for row in range(batch_size):
            label = labels[row]
            for col in range(nsampled):
                if sample_ids[col] == label:
                    self.assertTrue(
                        EXPECT_NEAR(np.exp(np_logits[row, col + 1]), 0, 1e-4))
Example #5
class SampledSoftmax(nn.Module):
    def __init__(self, ntokens, nsampled, nhid, tied_weight):
        super(SampledSoftmax, self).__init__()

        # Parameters
        self.ntokens = ntokens
        self.nsampled = nsampled

        self.sampler = LogUniformSampler(self.ntokens)
        self.params = nn.Linear(nhid, ntokens)

        if tied_weight is not None:
            self.params.weight = tied_weight
        else:
            util.initialize(self.params.weight)

    def forward(self, inputs, labels):
        if self.training:
            # sample ids according to word distribution - Unique
            sample_values = self.sampler.sample(self.nsampled,
                                                labels.data.cpu().numpy())
            return self.sampled(inputs,
                                labels,
                                sample_values,
                                remove_accidental_match=True)
        else:
            return self.full(inputs, labels)

    def sampled(self,
                inputs,
                labels,
                sample_values,
                remove_accidental_match=False):
        assert (inputs.data.get_device() == labels.data.get_device())
        device_id = labels.data.get_device()

        batch_size, d = inputs.size()
        sample_ids, true_freq, sample_freq = sample_values

        sample_ids = Variable(torch.LongTensor(sample_ids)).cuda(device_id)
        true_freq = Variable(torch.FloatTensor(true_freq)).cuda(device_id)
        sample_freq = Variable(torch.FloatTensor(sample_freq)).cuda(device_id)

        # gather true labels - weights and frequencies
        true_weights = torch.index_select(self.params.weight, 0, labels)
        true_bias = torch.index_select(self.params.bias, 0, labels)

        # gather sample ids - weights and frequencies
        sample_weights = torch.index_select(self.params.weight, 0, sample_ids)
        sample_bias = torch.index_select(self.params.bias, 0, sample_ids)

        # calculate logits
        true_logits = torch.sum(torch.mul(inputs, true_weights),
                                dim=1) + true_bias
        sample_logits = torch.matmul(inputs,
                                     torch.t(sample_weights)) + sample_bias
        # remove true labels from sample set
        if remove_accidental_match:
            acc_hits = self.sampler.accidental_match(
                labels.data.cpu().numpy(),
                sample_ids.data.cpu().numpy())
            acc_hits = list(zip(*acc_hits))
            sample_logits[acc_hits] = -1e37

        # perform correction
        true_logits = true_logits.sub(torch.log(true_freq))
        sample_logits = sample_logits.sub(torch.log(sample_freq))

        # return logits and new_labels
        logits = torch.cat(
            (torch.unsqueeze(true_logits, dim=1), sample_logits), dim=1)
        new_targets = Variable(torch.zeros(batch_size).long()).cuda(device_id)
        return logits, new_targets

    def full(self, inputs, labels):
        return self.params(inputs), labels
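
The accidental-hit handling above assumes that accidental_match returns (row, column) pairs where a sampled id equals the true label of that row; list(zip(*acc_hits)) turns those pairs into parallel row and column index sequences so the colliding sample logits can be masked in one indexing step. An illustrative sketch with hypothetical values:

import torch

sample_logits = torch.zeros(2, 4)          # (batch, nsampled) logits
acc_hits = [(0, 1), (1, 3)]                # hypothetical (row, column) collisions
if acc_hits:
    rows, cols = zip(*acc_hits)            # -> (0, 1) and (1, 3)
    sample_logits[list(rows), list(cols)] = -1e37   # mask colliding negatives
print(sample_logits)
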
Example #6
class Trainer(object):
    def __init__(self, log, model, beam_size, train_data, eval_data, optim,
                 use_cuda, loss_func, cate_loss_func, topk, input_size, args):
        self.m_log = log
        self.model = model
        self.train_data = train_data
        self.eval_data = eval_data
        self.optim = optim
        self.m_loss_func = loss_func
        self.topk = topk
        self.evaluation = Evaluation(self.m_log,
                                     self.model,
                                     self.m_loss_func,
                                     use_cuda,
                                     self.topk,
                                     warm_start=args.warm_start)
        self.device = torch.device('cuda' if use_cuda else 'cpu')
        self.args = args

        self.m_cate_loss_func = cate_loss_func

        ### early stopping
        self.m_patience = args.patience
        self.m_best_recall = 0.0
        self.m_best_mrr = 0.0
        self.m_early_stop = False
        self.m_counter = 0

        self.m_best_cate_recall = 0.0
        self.m_best_cate_mrr = 0.0

        self.m_batch_iter = 0

        #### sample negative items
        self.m_sampler = LogUniformSampler(input_size)
        self.m_nsampled = args.negative_num
        self.m_remove_match = True

        self.m_teacher_forcing_ratio = 2.0
        self.m_beam_size = beam_size
        self.m_teacher_forcing_flag = True

        self.m_logsoftmax = nn.LogSoftmax(dim=1)

    def saveModel(self, epoch, loss, recall, mrr):
        checkpoint = {
            'model': self.model.state_dict(),
            'args': self.args,
            'epoch': epoch,
            'optim': self.optim,
            'loss': loss,
            'recall': recall,
            'mrr': mrr
        }
        model_name = os.path.join(self.args.checkpoint_dir, "model_best.pt")
        torch.save(checkpoint, model_name)

    def train(self, start_epoch, end_epoch, batch_size, start_time=None):

        if start_time is None:
            self.start_time = time.time()
        else:
            self.start_time = start_time

        ### start training
        for epoch in range(start_epoch, end_epoch + 1):

            msg = "*" * 10 + str(epoch) + "*" * 5
            self.m_log.addOutput2IO(msg)
            print("teaching", self.m_teacher_forcing_flag)
            st = time.time()

            ### an epoch
            train_mixture_loss, train_loss, train_cate_loss = self.train_epoch(
                epoch, batch_size)

            ### evaluate model on the train dataset or the validation dataset
            mixture_loss, loss, recall, mrr, cate_loss, cate_recall, cate_mrr = self.evaluation.eval(
                self.train_data, batch_size, "train")

            print("train", train_loss)
            print("mix", mixture_loss)
            print("loss", loss)
            print("recall", recall)
            print("mrr", mrr)

            msg = "train Epoch: {}, train loss: {:.4f},  mixture loss: {:.4f}, loss: {:.4f}, recall: {:.4f}, mrr: {:.4f}, cate_loss: {:.4f}, cate_recall: {:.4f}, cate_mrr: {:.4f}, time: {}".format(
                epoch, train_mixture_loss, mixture_loss, loss, recall, mrr,
                cate_loss, cate_recall, cate_mrr,
                time.time() - st)
            self.m_log.addOutput2IO(msg)
            self.m_log.addScalar2Tensorboard("train_mixture_loss",
                                             train_mixture_loss, epoch)
            self.m_log.addScalar2Tensorboard("mixture_loss", mixture_loss,
                                             epoch)
            self.m_log.addScalar2Tensorboard("train_loss_eval", loss, epoch)
            self.m_log.addScalar2Tensorboard("train_recall", recall, epoch)
            self.m_log.addScalar2Tensorboard("train_mrr", mrr, epoch)

            self.m_log.addScalar2Tensorboard("train_cate_loss_eval", cate_loss,
                                             epoch)
            self.m_log.addScalar2Tensorboard("train_cate_recall", cate_recall,
                                             epoch)
            self.m_log.addScalar2Tensorboard("train_cate_mrr", cate_mrr, epoch)

            if self.m_best_cate_recall == 0:
                self.m_best_cate_recall = cate_recall
            elif self.m_best_cate_recall >= cate_recall:
                self.m_teacher_forcing_flag = False
                # self.m_teacher_forcing_flag = True
            else:
                self.m_best_cate_recall = cate_recall

            ### evaluate model on test dataset
            mixture_loss, loss, recall, mrr, cate_loss, cate_recall, cate_mrr = self.evaluation.eval(
                self.eval_data, batch_size, "test")
            msg = "Epoch: {}, mixture loss: {:.4f}, loss: {:.4f}, recall: {:.4f}, mrr: {:.4f}, cate_loss: {:.4f}, cate_recall: {:.4f}, cate_mrr: {:.4f}, time: {}".format(
                epoch, mixture_loss, loss, recall, mrr, cate_loss, cate_recall,
                cate_mrr,
                time.time() - st)
            self.m_log.addOutput2IO(msg)
            self.m_log.addScalar2Tensorboard("test_mixture_loss", mixture_loss,
                                             epoch)
            self.m_log.addScalar2Tensorboard("test_loss", loss, epoch)
            self.m_log.addScalar2Tensorboard("test_recall", recall, epoch)
            self.m_log.addScalar2Tensorboard("test_mrr", mrr, epoch)

            self.m_log.addScalar2Tensorboard("test_cate_loss", cate_loss,
                                             epoch)
            self.m_log.addScalar2Tensorboard("test_cate_recall", cate_recall,
                                             epoch)
            self.m_log.addScalar2Tensorboard("test_cate_mrr", cate_mrr, epoch)

            if self.m_best_recall == 0:
                self.m_best_recall = recall
                self.saveModel(epoch, loss, recall, mrr)
            elif self.m_best_recall > recall:
                self.m_counter += 1
                if self.m_counter > self.m_patience:
                    break
                msg = "early stop counter " + str(self.m_counter)
                self.m_log.addOutput2IO(msg)
            else:
                self.m_best_recall = recall
                self.m_best_mrr = mrr
                self.saveModel(epoch, loss, recall, mrr)
                self.m_counter = 0

            msg = "best recall: " + str(
                self.m_best_recall) + "\t best mrr: \t" + str(self.m_best_mrr)
            self.m_log.addOutput2IO(msg)

    def train_epoch(self, epoch, batch_size):
        self.model.train()

        losses = []
        cate_losses = []
        mixture_losses = []

        def reset_hidden(hidden, mask):
            """Helper function that resets hidden state when some sessions terminate"""
            if len(mask) != 0:
                hidden[:, mask, :] = 0
            return hidden

        dataloader = self.train_data

        for (x_long_cate_action_batch, x_long_cate_batch,
             mask_long_cate_action_batch, mask_long_cate_batch,
             max_long_cate_actionNum_batch, max_long_cateNum_batch,
             pad_x_long_cate_actionNum_batch, x_long_cateNum_batch,
             x_short_action_batch, x_short_cate_batch,
             mask_short_action_batch, pad_x_short_actionNum_batch,
             y_action_batch, y_cate_batch, y_action_idx_batch) in dataloader:

            ### negative samples
            sample_values = self.m_sampler.sample(self.m_nsampled,
                                                  y_action_batch)
            sample_ids, true_freq, sample_freq = sample_values

            ### optionally exclude the current positive samples from the negative samples
            if self.m_remove_match:
                acc_hits = self.m_sampler.accidental_match(
                    y_action_batch, sample_ids)
                acc_hits = list(zip(*acc_hits))

            ### cates of long-range actions
            x_long_cate_batch = x_long_cate_batch.to(self.device)

            ### mask of cates of long-range actions
            mask_long_cate_batch = mask_long_cate_batch.to(self.device)

            ### items of long-range actions
            x_long_cate_action_batch = x_long_cate_action_batch.to(self.device)

            ### mask of items of long-range actions
            mask_long_cate_action_batch = mask_long_cate_action_batch.to(
                self.device)

            ### items of short-range actions
            x_short_action_batch = x_short_action_batch.to(self.device)

            ### mask of items of short-range actions
            mask_short_action_batch = mask_short_action_batch.to(self.device)

            ### cate of short-range actions
            x_short_cate_batch = x_short_cate_batch.to(self.device)

            ### items of target action
            y_action_batch = y_action_batch.to(self.device)

            ### cates of target action
            y_cate_batch = y_cate_batch.to(self.device)

            ### idx of target action in seq
            y_action_idx_batch = y_action_idx_batch.to(self.device)
            # batch_size = x_batch.size(0)

            self.optim.zero_grad()

            seq_cate_short_input = self.model.m_cateNN(
                x_short_cate_batch, mask_short_action_batch,
                pad_x_short_actionNum_batch, "train")
            logit_cate_short = self.model.m_cateNN.m_cate_h2o(
                seq_cate_short_input)

            pred_item_prob = None

            ### two decoding modes: with teacher forcing, the itemNN is conditioned on
            ### the ground-truth category (y_cate_batch); otherwise item log-probs are
            ### computed for each of the top-k predicted categories and combined with
            ### the category log-probs via logsumexp over the beam
            if self.m_teacher_forcing_flag:
                seq_cate_input, seq_short_input = self.model.m_itemNN(
                    x_long_cate_action_batch, x_long_cate_batch,
                    mask_long_cate_action_batch, mask_long_cate_batch,
                    max_long_cate_actionNum_batch, max_long_cateNum_batch,
                    pad_x_long_cate_actionNum_batch, x_long_cateNum_batch,
                    x_short_action_batch, mask_short_action_batch,
                    pad_x_short_actionNum_batch, y_cate_batch, "train")

                mixture_output = torch.cat((seq_cate_input, seq_short_input),
                                           dim=1)
                fc_output = self.model.fc(mixture_output)

                sampled_logit_batch, sampled_target_batch = self.model.m_ss(
                    fc_output, y_action_batch, sample_ids, true_freq,
                    sample_freq, acc_hits, self.device, self.m_remove_match)

                sampled_prob_batch = self.m_logsoftmax(sampled_logit_batch)
                pred_item_prob = sampled_prob_batch

            else:
                log_prob_cate_short = self.m_logsoftmax(logit_cate_short)
                log_prob_cate_short, pred_cate_index = torch.topk(
                    log_prob_cate_short, self.m_beam_size, dim=-1)

                pred_cate_index = pred_cate_index.detach()

                for beam_index in range(self.m_beam_size):
                    pred_cate_beam = pred_cate_index[:, beam_index]
                    prob_cate_beam = log_prob_cate_short[:, beam_index]

                    seq_cate_input, seq_short_input = self.model.m_itemNN(
                        x_long_cate_action_batch, x_long_cate_batch,
                        mask_long_cate_action_batch, mask_long_cate_batch,
                        max_long_cate_actionNum_batch, max_long_cateNum_batch,
                        pad_x_long_cate_actionNum_batch, x_long_cateNum_batch,
                        x_short_action_batch, mask_short_action_batch,
                        pad_x_short_actionNum_batch, pred_cate_beam, "train")

                    mixture_output = torch.cat(
                        (seq_cate_input, seq_short_input), dim=1)
                    fc_output = self.model.fc(mixture_output)

                    sampled_logit_batch, sampled_target_batch = self.model.m_ss(
                        fc_output, y_action_batch, sample_ids, true_freq,
                        sample_freq, acc_hits, self.device,
                        self.m_remove_match)

                    sampled_prob_batch = self.m_logsoftmax(sampled_logit_batch)

                    if pred_item_prob is None:
                        pred_item_prob = sampled_prob_batch + prob_cate_beam.reshape(
                            -1, 1)
                        pred_item_prob = pred_item_prob.unsqueeze(-1)
                    else:
                        pred_item_prob_beam = sampled_prob_batch + prob_cate_beam.reshape(
                            -1, 1)
                        pred_item_prob_beam = pred_item_prob_beam.unsqueeze(-1)
                        pred_item_prob = torch.cat(
                            (pred_item_prob, pred_item_prob_beam), dim=-1)

                pred_item_prob = torch.logsumexp(pred_item_prob, dim=-1)

            loss_batch = self.m_loss_func(pred_item_prob, sampled_target_batch,
                                          "prob")
            losses.append(loss_batch.item())

            cate_loss_batch = self.m_cate_loss_func(logit_cate_short,
                                                    y_cate_batch, "logit")
            cate_losses.append(cate_loss_batch.item())

            mixture_loss_batch = loss_batch + cate_loss_batch

            mixture_losses.append(mixture_loss_batch.item())

            mixture_loss_batch.backward()

            max_norm = 5.0

            self.m_batch_iter += 1

            torch.nn.utils.clip_grad_norm(self.model.parameters(), max_norm)

            self.optim.step()

        mean_mixture_losses = np.mean(mixture_losses)

        mean_losses = np.mean(losses)

        mean_cate_losses = np.mean(cate_losses)

        return mean_mixture_losses, mean_losses, mean_cate_losses
Example #7
p = log_uniform_distribution(N)
start_time = time.time()
torch.multinomial(p, num_samples, replacement=True)
end_time = time.time()
print("non_unique multinomial cuda", end_time - start_time)

start_time = time.time()
torch.multinomial(p, 100)
end_time = time.time()
print("unique multinomial cuda", end_time - start_time)

sampler = LogUniformSampler(N)
labels = np.random.choice(N, batch_size)

start_time = time.time()
sample_id, true_freq, sample_freq = sampler.sample(num_samples, labels)
end_time = time.time()
print("unique log_uniform c++", end_time - start_time)

"""
sampler = LogUniformSampler()
start_time = time.time()
sample_id = sampler.sample(N, num_samples, unique=True, labels=labels.tolist())
end_time = time.time()
print("unique no_accidental_hits log_uniform c++", end_time - start_time)

label_set = set(labels.tolist())
for idx in sample_id:
    assert(idx not in label_set)
"""
Example #8
class Evaluation(object):
    def __init__(self,
                 log,
                 model,
                 loss_func,
                 use_cuda,
                 input_size,
                 k=20,
                 warm_start=5):
        self.model = model
        self.m_loss_func = loss_func

        self.topk = k
        self.warm_start = warm_start
        self.device = torch.device('cuda' if use_cuda else 'cpu')
        self.m_log = log
        beam_size = 5
        self.m_beam_size = beam_size

        self.m_sampler = LogUniformSampler(input_size)
        self.m_nsampled = 100
        self.m_remove_match = True
        print("evaluation is based on sampled 100", self.m_nsampled)

    def eval(self, eval_data, batch_size, train_test_flag):
        self.model.eval()

        mixture_losses = []

        losses = []
        recalls = []
        mrrs = []
        weights = []

        cate_losses = []
        cate_recalls = []
        cate_mrrs = []
        cate_weights = []

        dataloader = eval_data

        with torch.no_grad():
            total_test_num = []

            for (x_long_cate_action_batch, x_long_cate_batch,
                 mask_long_cate_action_batch, mask_long_cate_batch,
                 max_long_cate_actionNum_batch, max_long_cateNum_batch,
                 pad_x_long_cate_actionNum_batch, x_long_cateNum_batch,
                 x_short_action_batch, x_short_cate_batch,
                 mask_short_action_batch, pad_x_short_actionNum_batch,
                 y_action_batch, y_cate_batch, y_action_idx_batch) in dataloader:

                ### speed up evaluation on the train dataset by sub-sampling ~1% of batches
                if train_test_flag == "train":
                    eval_flag = random.randint(1, 101)
                    if eval_flag != 10:
                        continue

                sample_values = self.m_sampler.sample(self.m_nsampled,
                                                      y_action_batch)
                sample_ids, true_freq, sample_freq = sample_values

                ### optionally exclude the current positive samples from the negative samples
                if self.m_remove_match:
                    acc_hits = self.m_sampler.accidental_match(
                        y_action_batch, sample_ids)
                    acc_hits = list(zip(*acc_hits))

                x_long_cate_action_batch = x_long_cate_action_batch.to(
                    self.device)

                mask_long_cate_action_batch = mask_long_cate_action_batch.to(
                    self.device)

                x_long_cate_batch = x_long_cate_batch.to(self.device)
                mask_long_cate_batch = mask_long_cate_batch.to(self.device)

                x_short_action_batch = x_short_action_batch.to(self.device)
                mask_short_action_batch = mask_short_action_batch.to(
                    self.device)

                x_short_cate_batch = x_short_cate_batch.to(self.device)

                y_action_batch = y_action_batch.to(self.device)
                y_cate_batch = y_cate_batch.to(self.device)

                ### mask for targets whose index within the session is at least warm_start
                warm_start_mask = (y_action_idx_batch >= self.warm_start)

                ### cateNN
                seq_cate_short_input = self.model.m_cateNN(
                    x_short_cate_batch, mask_short_action_batch,
                    pad_x_short_actionNum_batch, "test")
                logit_cate_short = self.model.m_cateNN.m_cate_h2o(
                    seq_cate_short_input)

                ### retrieve top k predicted categories
                prob_cate_short = F.softmax(logit_cate_short, dim=-1)
                cate_prob_beam, cate_id_beam = prob_cate_short.topk(
                    dim=1, k=self.m_beam_size)

                item_prob_flag = False

                for beam_index in range(self.m_beam_size):

                    ### for each category, predict item, then mix predictions for rec

                    cate_id_beam_batch = cate_id_beam[:, beam_index]
                    cate_id_beam_batch = cate_id_beam_batch.reshape(-1, 1)

                    seq_cate_input, seq_short_input = self.model.m_itemNN(
                        x_long_cate_action_batch, x_long_cate_batch,
                        mask_long_cate_action_batch, mask_long_cate_batch,
                        max_long_cate_actionNum_batch, max_long_cateNum_batch,
                        pad_x_long_cate_actionNum_batch, x_long_cateNum_batch,
                        x_short_action_batch, mask_short_action_batch,
                        pad_x_short_actionNum_batch, cate_id_beam_batch,
                        "test")

                    mixture_output = torch.cat(
                        (seq_cate_input, seq_short_input), dim=1)
                    output_batch = self.model.fc(mixture_output)

                    ### sampled_logit_batch
                    sampled_logit_batch, sampled_target_batch = self.model.m_ss(
                        output_batch, y_action_batch, sample_ids, true_freq,
                        sample_freq, acc_hits, self.device,
                        self.m_remove_match)

                    ### batch_size*voc_size
                    prob_batch = F.softmax(sampled_logit_batch, dim=-1)

                    ## batch_size*1
                    cate_prob_batch = cate_prob_beam[:, beam_index]

                    item_prob_batch = prob_batch * cate_prob_batch.reshape(
                        -1, 1)

                    # if not item_prob:
                    if not item_prob_flag:
                        item_prob_flag = True
                        item_prob = item_prob_batch
                    else:
                        item_prob += item_prob_batch

                ### evaluate cate prediction
                cate_loss_batch = self.m_loss_func(logit_cate_short,
                                                   y_cate_batch, "logit")
                cate_losses.append(cate_loss_batch.item())
                cate_topk = 5

                cate_recall_batch, cate_mrr_batch = evaluate(logit_cate_short,
                                                             y_cate_batch,
                                                             warm_start_mask,
                                                             k=cate_topk)

                cate_weights.append(int(warm_start_mask.int().sum()))
                cate_recalls.append(cate_recall_batch)
                cate_mrrs.append(cate_mrr_batch)

                ### evaluate item prediction
                recall_batch, mrr_batch = evaluate(item_prob,
                                                   y_action_batch,
                                                   warm_start_mask,
                                                   k=self.topk)

                weights.append(int(warm_start_mask.int().sum()))
                recalls.append(recall_batch)
                mrrs.append(mrr_batch)

                total_test_num.append(y_action_batch.view(-1).size(0))

        mean_mixture_losses = 0.0

        mean_losses = 0.0
        mean_recall = np.average(recalls, weights=weights)
        mean_mrr = np.average(mrrs, weights=weights)

        mean_cate_losses = np.mean(cate_losses)
        mean_cate_recall = np.average(cate_recalls, weights=cate_weights)
        mean_cate_mrr = np.average(cate_mrrs, weights=cate_weights)

        msg = "total_test_num" + str(np.sum(total_test_num))
        self.m_log.addOutput2IO(msg)

        return mean_mixture_losses, mean_losses, mean_recall, mean_mrr, mean_cate_losses, mean_cate_recall, mean_cate_mrr