def test(self, sess, test_data):
        '''
        Run the model over every batch of ``test_data`` and return the score
        computed by ``cau_samples_acc``.

        For each batch produced by ``batcher`` from ``test_data.samples`` the
        method fetches ``self.pred``, ``self.om_alpha``, ``self.location`` and
        ``self.last_output`` and stores the first three back into
        ``test_data`` under the batch's ``batch_ids``.

        sess: session-like object exposing ``run(fetches, feed_dict=...)``.

        test_data: dataset object. This method reads its ``samples``
        attribute and calls its ``pack_preds``, ``transform_ext_matrix`` and
        ``pack_ext_matrix`` methods.

        Returns the value of ``cau_samples_acc(test_data.samples)``
        (presumably the accuracy of the packed predictions -- confirm against
        that helper).
        '''
        # NOTE(review): batches are requested with random=True even though
        # this is evaluation; results are stored by batch id, so the order
        # presumably does not matter -- confirm against batcher.
        bt = batcher(samples=test_data.samples,
                     class_num=self.class_num,
                     random=True,
                     pad_idx=self.pad_idx,
                     eos=self.eos)
        while bt.has_next():  # batch round.
            batch_data = bt.next_batch()

            # Map every graph input to the matching entry of the batch.
            feed_dict = {
                self.inputs: batch_data['text_idxes'],
                self.left_inputs: batch_data['left_ctx_idxes'],
                self.right_inputs: batch_data['right_ctx_idxes'],
                self.aspects: batch_data['aspect_idxes'],
                self.sequence_length: batch_data['text_lens'],
                self.left_length: batch_data['left_ca_lens'],
                self.right_length: batch_data['right_ca_lens'],
                self.aspect_length: batch_data['aspect_lens'],
                self.lab_input: batch_data['labels'],
                self.sent_bitmap: batch_data['text_bitmap'],
                self.left_ctx_asp: batch_data['left_ctx_asp'],
                self.right_ctx_asp: batch_data['right_ctx_asp'],
                self.asp_mask: batch_data['asp_mask'],
            }
            # The reversed-length input is only fed when the model was
            # configured with reverse set to exactly True.
            if self.reverse is True:
                feed_dict[self.reverse_length] = batch_data[
                    'text_reverse_lens']

            # ``self.last_output`` is still fetched, as in the original, but
            # its value is not stored anywhere.
            pred, om_alphas, loc, _last_output = sess.run(
                [self.pred, self.om_alpha, self.location, self.last_output],
                feed_dict=feed_dict)

            batch_ids = batch_data['batch_ids']
            test_data.pack_preds(pred, batch_ids)
            om_alphas = test_data.transform_ext_matrix(om_alphas)
            test_data.pack_ext_matrix('om', om_alphas, batch_ids)
            test_data.pack_ext_matrix('loc', loc, batch_ids)

        # Score the samples now that every batch has been packed.
        return cau_samples_acc(test_data.samples)
    def train(self,
              sess,
              train_data,
              test_data=None,
              saver=None,
              threshold_acc=0.999):
        '''
        Train the model for ``self.nepoch`` epochs and return the best
        accuracy observed on ``test_data``.

        sess: session-like object exposing ``run(fetches, feed_dict=...)``.

        train_data: dataset object whose ``samples`` are batched for
        training; it is also passed to ``self.test`` after every epoch.

        test_data: optional dataset object evaluated after every epoch.
        When it is None no evaluation happens and 0.0 is returned.

        saver: forwarded unchanged to ``self.save_model``.

        threshold_acc: the model is saved, and the final ``TIPrint`` report
        is produced, only when the best test accuracy exceeds this value.

        Returns the highest test accuracy seen (0.0 if never evaluated).
        '''
        max_acc = 0.0
        max_train_acc = 0.0
        # Bound up front so the progress line below cannot hit an unbound
        # name if an epoch yields no batches.
        crt_step = 0

        def best_metrics():
            # A fresh dict per call, as in the original, in case TIPrint
            # mutates its argument. The key spelling 'train_accurayc' is
            # kept byte-for-byte because consumers may look it up.
            return {
                'predict_accuracy': max_acc,
                'train_accurayc': max_train_acc
            }

        for epoch in range(self.nepoch):  # epoch round.
            cost = 0.0  # accumulated loss of this epoch.
            bt = batcher(samples=train_data.samples,
                         class_num=self.class_num,
                         random=True,
                         pad_idx=self.pad_idx,
                         eos=self.eos)
            while bt.has_next():  # batch round.
                batch_data = bt.next_batch()
                # NOTE(review): unlike the sibling test(), the reversed
                # length input is not fed here (those lines were commented
                # out in the original) -- confirm this is intended when
                # self.reverse is True.
                feed_dict = {
                    self.inputs: batch_data['text_idxes'],
                    self.left_inputs: batch_data['left_ctx_idxes'],
                    self.right_inputs: batch_data['right_ctx_idxes'],
                    self.aspects: batch_data['aspect_idxes'],
                    self.sequence_length: batch_data['text_lens'],
                    self.left_length: batch_data['left_ca_lens'],
                    self.right_length: batch_data['right_ca_lens'],
                    self.aspect_length: batch_data['aspect_lens'],
                    self.lab_input: batch_data['labels'],
                    self.sent_bitmap: batch_data['text_bitmap'],
                    self.left_ctx_asp: batch_data['left_ctx_asp'],
                    self.right_ctx_asp: batch_data['right_ctx_asp'],
                    self.asp_mask: batch_data['asp_mask'],
                }
                # One optimisation step; the optimizer's own result is not
                # needed.
                crt_loss, crt_step, _ = sess.run(
                    [self.loss, self.global_step, self.optimize],
                    feed_dict=feed_dict)
                cost += np.sum(crt_loss)

            train_acc = self.test(sess, train_data)
            # Single-argument print(...) behaves the same on Python 2 and 3.
            # NOTE(review): 28 is an unexplained constant carried over from
            # the original (presumably steps per epoch) -- confirm. The
            # division is integer on Python 2 and true division on Python 3.
            print("train epoch: " + str(epoch) +
                  "    cost: " +
                  str(cost / len(train_data.samples)) +
                  "    acc: " +
                  str(train_acc) +
                  "    crt_step:" +
                  str(crt_step / 28))

            if test_data is not None:
                test_acc = self.test(sess, test_data)
                print("                  test_acc: " + str(test_acc))
                if max_acc < test_acc:
                    # New best epoch: remember it and let the dataset
                    # snapshot its current predictions.
                    max_acc = test_acc
                    max_train_acc = train_acc
                    test_data.update_best()
                    if max_acc > threshold_acc:
                        self.save_model(sess, self.config, saver)
                print("                   max_acc: " + str(max_acc))

        if max_acc > threshold_acc:
            if self.print_all:
                # The value returned by the first call is handed to the
                # second call as its last argument.
                suf = TIPrint(test_data.samples, self.config, True,
                              best_metrics(), True)
                TIPrint(test_data.samples, self.config, False,
                        best_metrics(), True, suf)
            else:
                TIPrint(test_data.samples, self.config, False,
                        best_metrics(), True)

        return max_acc
Example #3
0
    def train(self, sess, train_data, test_data=None, saver=None, threshold_acc=0.99):
        '''
        Train the model for ``self.nepoch`` epochs.

        Every batch from ``batcher`` is cut into slices of at most
        ``self.batch_size`` sequences. For each slice, one optimisation step
        is run per position ``s``: the model is fed the prefix ``[:s + 1]`` of
        every input sequence, the item at position ``s``, and the target at
        position ``s`` minus one.

        sess: session-like object exposing ``run(fetches, feed_dict=...)``.

        train_data: dataset object whose ``samples`` are batched.

        test_data: optional dataset object evaluated with ``self.test`` after
        every epoch; ``update_best`` and ``flush`` are called on it.

        saver: forwarded unchanged to ``self.save_model``.

        threshold_acc: the model is saved only when the best recall exceeds
        this value.

        Returns None. If the mean epoch loss is NaN, training stops early
        and ``self.error_during_train`` is set to True.
        '''
        max_recall = 0.0
        max_mrr = 0.0
        for epoch in range(self.nepoch):   # epoch round.
            epoch_losses = []
            bt = batcher(
                samples=train_data.samples,
                class_num=self.n_items,
                random=True
            )
            while bt.has_next():    # batch round.
                batch_data = bt.next_batch()
                all_in = batch_data['in_idxes']
                all_out = batch_data['out_idxes']

                # Walk the batch in strides of batch_size. This covers the
                # full slices, the shorter trailing slice and the "batch
                # already small enough" case with one code path. An empty
                # batch simply produces no slices.
                for start in range(0, len(all_in), self.batch_size):
                    stop = start + self.batch_size
                    chunk_in = all_in[start:stop]
                    pairs = list(zip(chunk_in, all_out[start:stop]))

                    # The step count comes from the first sequence only, so
                    # all sequences in a slice are assumed to be at least
                    # that long.
                    for s in range(len(chunk_in[0])):
                        feed_dict = {
                            self.inputs: [seq_in[:s + 1] for seq_in, _ in pairs],
                            self.last_inputs: [seq_in[s] for seq_in, _ in pairs],
                            # Targets are shifted down by one (presumably
                            # 1-based item ids to 0-based classes -- confirm).
                            self.lab_input: [seq_out[s] - 1 for _, seq_out in pairs],
                            self.sequence_length: [s + 1] * len(pairs)
                        }
                        # NOTE(review): global_step and embe_dict are
                        # fetched as in the original although their values
                        # are unused; dropping them would avoid copying
                        # them out every step -- confirm they carry no
                        # needed side effect.
                        crt_loss, _, _, _ = sess.run(
                            [self.loss, self.global_step, self.optimize, self.embe_dict],
                            feed_dict=feed_dict
                        )
                        epoch_losses += list(crt_loss)

            avgc = np.mean(epoch_losses)
            if np.isnan(avgc):
                print('Epoch {}: NaN error!'.format(str(epoch)))
                self.error_during_train = True
                return
            print('Epoch{}\tloss: {:.6f}'.format(epoch, avgc))

            if test_data is not None:
                recall, mrr = self.test(sess, test_data)
                print(recall, mrr)
                if max_recall < recall:
                    # New best epoch: keep its metrics and let the dataset
                    # snapshot its current predictions.
                    max_recall = recall
                    max_mrr = mrr
                    test_data.update_best()
                    if max_recall > threshold_acc:
                        self.save_model(sess, self.config, saver)
                print ("                   max_recall: " + str(max_recall)+" max_mrr: "+str(max_mrr))
                test_data.flush()

        # The report reads test_data.samples, so it is skipped when no test
        # set was supplied instead of failing on None.
        if self.is_print and test_data is not None:
            TIPrint(test_data.samples, self.config,
                    {'recall': max_recall, 'mrr': max_mrr}, True)
Example #4
0
    def test(self, sess, test_data):
        '''
        Evaluate the model on ``test_data`` and return ``(recall, mrr)``.

        Every batch from ``batcher`` is cut into slices of at most
        ``self.batch_size`` sequences. For each slice and each position ``s``
        the model is fed the prefix ``[:s + 1]`` of every input sequence;
        ``cau_recall_mrr_org`` scores the output against the target at
        position ``s`` minus one, and the attention values and ranks are
        stored back into ``test_data`` under the slice's batch ids.

        sess: session-like object exposing ``run(fetches, feed_dict=...)``.

        test_data: dataset object; its ``samples`` are batched and its
        ``pack_ext_matrix`` / ``pack_preds`` methods receive the results.

        Returns a tuple ``(np.mean(recall), np.mean(mrr))`` over all
        per-step values returned by ``cau_recall_mrr_org``.
        '''
        print('Measuring Recall@{} and MRR@{}'.format(self.cut_off, self.cut_off))

        mrr, recall = [], []
        c_loss = []
        bt = batcher(
            samples=test_data.samples,
            class_num=self.n_items,
            random=False
        )
        while bt.has_next():    # batch round.
            batch_data = bt.next_batch()
            all_in = batch_data['in_idxes']
            all_out = batch_data['out_idxes']
            all_ids = batch_data['batch_ids']

            # Walk the batch in strides of batch_size. This covers the full
            # slices, the shorter trailing slice and the "batch already
            # small enough" case with one code path. An empty batch simply
            # produces no slices.
            for start in range(0, len(all_in), self.batch_size):
                stop = start + self.batch_size
                chunk_in = all_in[start:stop]
                chunk_ids = all_ids[start:stop]
                pairs = list(zip(chunk_in, all_out[start:stop]))

                # The step count comes from the first sequence only, so all
                # sequences in a slice are assumed to be at least that long.
                for s in range(len(chunk_in[0])):
                    # Targets are shifted down by one (presumably 1-based
                    # item ids to 0-based classes -- confirm).
                    batch_out = [seq_out[s] - 1 for _, seq_out in pairs]
                    feed_dict = {
                        self.inputs: [seq_in[:s + 1] for seq_in, _ in pairs],
                        self.last_inputs: [seq_in[s] for seq_in, _ in pairs],
                        self.lab_input: batch_out,
                        self.sequence_length: [s + 1] * len(pairs)
                    }
                    # Forward pass only: no optimizer op is fetched here.
                    preds, loss, alpha = sess.run(
                        [self.softmax_input, self.loss, self.alph],
                        feed_dict=feed_dict
                    )
                    t_r, t_m, ranks = cau_recall_mrr_org(preds, batch_out, cutoff=self.cut_off)
                    test_data.pack_ext_matrix('alpha', alpha, chunk_ids)
                    test_data.pack_preds(ranks, chunk_ids)
                    c_loss += list(loss)
                    recall += t_r
                    mrr += t_m

        # Sample-level metrics are printed for inspection only; the returned
        # values are the means of the per-step metrics gathered above.
        r, m = cau_samples_recall_mrr(test_data.samples, self.cut_off)
        print(r, m)
        print(np.mean(c_loss))
        return np.mean(recall), np.mean(mrr)
Example #5
0
    def test(self, sess, test_data):
        '''
        Evaluate the model on ``test_data``, report the loss to NNI and
        return ``(recall, mrr)``.

        Every batch from ``batcher`` is cut into slices of at most
        ``self.batch_size`` sequences. For each slice and each position ``s``
        the model is fed the prefix ``[:s + 1]`` of every input sequence;
        ``cau_recall_mrr_org`` scores the output against the target at
        position ``s`` minus one, and the attention values and ranks are
        stored back into ``test_data`` under the slice's batch ids.

        After every step the mean of all losses collected so far is reported
        as an NNI intermediate result; the mean over the whole run is
        reported as the final result.

        sess: session-like object exposing ``run(fetches, feed_dict=...)``.

        test_data: dataset object; its ``samples`` are batched and its
        ``pack_ext_matrix`` / ``pack_preds`` methods receive the results.

        Returns a tuple ``(np.mean(recall), np.mean(mrr))`` over all
        per-step values returned by ``cau_recall_mrr_org``.
        '''
        print('Measuring Recall@{} and MRR@{}'.format(self.cut_off,
                                                      self.cut_off))

        mrr, recall = [], []
        c_loss = []
        bt = batcher(samples=test_data.samples,
                     class_num=self.n_items,
                     random=False)
        while bt.has_next():  # batch round.
            batch_data = bt.next_batch()
            all_in = batch_data['in_idxes']
            all_out = batch_data['out_idxes']
            all_ids = batch_data['batch_ids']

            # Walk the batch in strides of batch_size. This covers the full
            # slices, the shorter trailing slice and the "batch already
            # small enough" case with one code path, so every step is
            # reported the same way.
            for start in range(0, len(all_in), self.batch_size):
                stop = start + self.batch_size
                chunk_in = all_in[start:stop]
                chunk_ids = all_ids[start:stop]
                pairs = list(zip(chunk_in, all_out[start:stop]))

                # The step count comes from the first sequence only, so all
                # sequences in a slice are assumed to be at least that long.
                for s in range(len(chunk_in[0])):
                    # Targets are shifted down by one (presumably 1-based
                    # item ids to 0-based classes -- confirm).
                    batch_out = [seq_out[s] - 1 for _, seq_out in pairs]
                    feed_dict = {
                        self.inputs: [seq_in[:s + 1] for seq_in, _ in pairs],
                        self.last_inputs: [seq_in[s] for seq_in, _ in pairs],
                        self.lab_input: batch_out,
                        self.sequence_length: [s + 1] * len(pairs)
                    }
                    # Forward pass only: no optimizer op is fetched here.
                    preds, loss, alpha = sess.run(
                        [self.softmax_input, self.loss, self.alph],
                        feed_dict=feed_dict)
                    t_r, t_m, ranks = cau_recall_mrr_org(
                        preds, batch_out, cutoff=self.cut_off)
                    test_data.pack_ext_matrix('alpha', alpha, chunk_ids)
                    test_data.pack_preds(ranks, chunk_ids)
                    c_loss += list(loss)
                    recall += t_r
                    mrr += t_m

                    # Report one scalar (the running mean loss). The raw
                    # loss list cannot be formatted with %g and is not a
                    # single metric value.
                    running_loss = (float(np.mean(c_loss))
                                    if c_loss else float('nan'))
                    print("This is my total_loss --->  ", running_loss)
                    nni.report_intermediate_result(running_loss)
                    logger.debug('test loss %g', running_loss)
                    logger.debug('Pipe send intermediate result done.')

        # Report final result to the tuner
        final_loss = float(np.mean(c_loss)) if c_loss else float('nan')
        nni.report_final_result(final_loss)
        logger.debug('Final result is %g', final_loss)
        logger.debug('Send final result done.')

        # The original discarded this result; the call is kept in case the
        # helper has side effects on the samples -- confirm and drop if not.
        cau_samples_recall_mrr(test_data.samples, self.cut_off)
        return np.mean(recall), np.mean(mrr)