示例#1
0
文件: fm.py 项目: misads/ctr
    def train(self, train_data, dev_data=None):
        """Run mini-batch training for `hparams.epoch` epochs.

        Args:
            train_data: pair (features, labels); both must have equal length.
            dev_data: optional evaluation pair passed through to self.eval().

        Side effects: updates model weights via sess.run, periodically logs
        progress, evaluates on dev_data, restores the best checkpoint from
        'model_tmp/model' at the end and removes the temporary directory.
        """
        hparams = self.hparams
        sess = self.sess
        assert len(train_data[0]) == len(
            train_data[1]), "Size of features data must be equal to label"
        # Optional step budget; None when hparams has no `steps` attribute.
        # (Previously this was probed with a bare `try/except: pass`, which
        # also swallowed errors raised by self.eval and skipped the `break`.)
        max_steps = getattr(hparams, 'steps', None)
        # Initialize before the epoch loop so the final eval below does not
        # hit an unbound name when hparams.epoch == 0.
        start_time = time.time()
        for epoch in range(hparams.epoch):
            info = {'loss': [], 'norm': []}
            start_time = time.time()
            for idx in range(len(train_data[0]) // hparams.batch_size + 3):
                if max_steps is not None and max_steps <= idx:
                    T = (time.time() - start_time)
                    self.eval(T, dev_data, hparams, sess)
                    break
                if idx * hparams.batch_size >= len(train_data[0]):
                    # Data exhausted for this epoch: evaluate and move on.
                    T = (time.time() - start_time)
                    self.eval(T, dev_data, hparams, sess)
                    break

                batch = train_data[0][idx * hparams.batch_size:
                                      min((idx + 1) * hparams.batch_size, len(train_data[0]))]
                batch = utils.hash_batch(batch, hparams)
                label = train_data[1][idx * hparams.batch_size:
                                      min((idx + 1) * hparams.batch_size, len(train_data[1]))]
                loss, _, norm = sess.run([self.loss, self.update, self.grad_norm],
                                         feed_dict={self.features: batch, self.label: label})
                info['loss'].append(loss)
                info['norm'].append(norm)
                if (idx + 1) % hparams.num_display_steps == 0:
                    info['learning_rate'] = hparams.learning_rate
                    info["train_ppl"] = np.mean(info['loss'])
                    info["avg_grad_norm"] = np.mean(info['norm'])
                    utils.print_step_info("  ", epoch, idx + 1, info)
                    # Reset running statistics for the next display window.
                    info = {'loss': [], 'norm': []}
                if (idx + 1) % hparams.num_eval_steps == 0 and dev_data:
                    T = (time.time() - start_time)
                    self.eval(T, dev_data, hparams, sess)

        # Restore the checkpoint written during eval(), run a final eval,
        # then remove the temporary checkpoint directory.
        self.saver.restore(sess, 'model_tmp/model')
        T = (time.time() - start_time)
        self.eval(T, dev_data, hparams, sess)
        # Portable replacement for os.system("rm -r model_tmp").
        import shutil
        shutil.rmtree("model_tmp", ignore_errors=True)
示例#2
0
文件: ffm.py 项目: ncoll/ctrNet-tool
 def infer(self, dev_data):
     """Return concatenated model probabilities for `dev_data`.

     Args:
         dev_data: pair (features, labels); both must have equal length.
     Returns:
         np.ndarray of per-example predictions from self.prob.
     """
     hparams = self.hparams
     sess = self.sess
     assert len(dev_data[0]) == len(
         dev_data[1]), "Size of features data must be equal to label"
     preds = []
     for idx in range(len(dev_data[0]) // hparams.batch_size + 1):
         batch = dev_data[0][idx * hparams.batch_size:
                             min((idx + 1) * hparams.batch_size, len(dev_data[0]))]
         # The `+ 1` in the range yields one trailing empty slice whenever
         # len(dev_data[0]) is an exact multiple of batch_size; skip it so
         # an empty batch is never fed to the graph (same guard as
         # get_embedding uses).
         if len(batch) == 0:
             break
         batch = utils.hash_batch(batch, hparams)
         label = dev_data[1][idx * hparams.batch_size:
                             min((idx + 1) * hparams.batch_size, len(dev_data[1]))]
         pred = sess.run(self.prob,
                         feed_dict={self.features: batch, self.label: label})
         preds.append(pred)
     # NOTE: raises ValueError on empty dev_data (np.concatenate of []),
     # same as the original behavior.
     preds = np.concatenate(preds)
     return preds
示例#3
0
文件: xdeepfm.py 项目: misads/ctr
 def get_embedding(self, dev_data):
     """Return the concatenated embedding tensor (self.emb_inp_v2) for `dev_data`.

     Args:
         dev_data: pair (features, labels); both must have equal length.
     Returns:
         np.ndarray: batch embeddings concatenated along axis 0.
     """
     hparams = self.hparams
     sess = self.sess
     assert len(dev_data[0]) == len(
         dev_data[1]), "Size of features data must be equal to label"
     embedding = []
     for idx in range(len(dev_data[0]) // hparams.batch_size + 1):
         batch = dev_data[0][idx * hparams.batch_size:
                             min((idx + 1) * hparams.batch_size, len(dev_data[0]))]
         # Skip the trailing empty slice produced by the `+ 1` in the range
         # when the data length is an exact multiple of batch_size.
         if len(batch) == 0:
             break
         batch = utils.hash_batch(batch, hparams)
         label = dev_data[1][idx * hparams.batch_size:
                             min((idx + 1) * hparams.batch_size, len(dev_data[1]))]
         temp = sess.run(self.emb_inp_v2,
                         feed_dict={self.features: batch, self.label: label})
         embedding.append(temp)
     embedding = np.concatenate(embedding, 0)
     return embedding