def main(datadir): # datadir = sys.argv[1] if os.path.exists(os.path.join(datadir, "static_res.pk")): return pkload(os.path.join(datadir, "static_res.pk")) pattern = re.compile("reviews_(.*?)_5.json") jsondir = os.path.join(datadir, "preprocess", "transform") src, tgt = pattern.findall("\n".join(os.listdir(jsondir))) cold = os.path.join(datadir, "preprocess", "cold") def getUser(domain, overuser=None): domainpattern = glob.glob(os.path.join(cold, "*%s*.pk" % domain)) domainfile = list(domainpattern)[0] domainuser = pkload(domainfile) domainjsonpattern = glob.glob( os.path.join(jsondir, "*%s*.json" % domain)) if domain == "overlap": return domainuser domainUC = 0 domainOUC = 0 domainuser = set(domainuser) overuser = set(overuser) for record in readJson(domainjsonpattern[0]): if record["reviewerID"] in domainuser: domainUC += 1 if record["reviewerID"] in overuser: domainOUC += 1 return domainuser, domainUC, domainOUC overuser = getUser("overlap") srcuser, srcUC, srcOUC = getUser(src, overuser) tgtuser, tgtUC, tgtOUC = getUser(tgt, overuser) print datadir, "done" static_res = { "Domain": [src, tgt], "User": { "overlap": len(overuser), src: len(srcuser), tgt: len(tgtuser) }, "Record": { "srcUC": srcUC, "srcOUC": srcOUC, "tgtUC": tgtUC, "tgtOUC": tgtOUC } } pkdump(static_res, os.path.join(datadir, "static_res.pk")) return static_res
def trainBatch(self, sess, batch):
    """Run one optimization step on *batch*.

    On failure the offending batch is pickled to ``debug.pk`` before the
    exception is re-raised, so it can be inspected offline.  The first
    batch that trains successfully is also dumped once to
    ``debugnormal.pk`` as a healthy reference for comparison.

    Returns (loss, mae, rmse) for the batch.
    """
    try:
        _, loss, mae, rmse = sess.run(
            [self.train_op, self.loss, self.mae, self.rmse],
            feed_dict=self._buildDict(batch)
        )
    except Exception:
        # Was a bare `except:`; narrowed so KeyboardInterrupt/SystemExit
        # are not intercepted just to write a debug dump.
        pkdump(batch, "debug.pk")
        raise
    if not os.path.exists("debugnormal.pk"):
        pkdump(batch, "debugnormal.pk")
    return loss, mae, rmse
def main(data_dir, src_domain, tgt_domain, epoches, mode, overlap_rate):
    """Train DSNRec on a (source, target) domain pair and log test RMSE.

    Writes TensorBoard summaries under ``log/<dir_name><domain_name>/train``
    and pickles the per-epoch test RMSE list next to it.  ``mode`` is kept
    for interface compatibility and is not used in this body.
    """
    domain_name = "%s_%s" % (src_domain, tgt_domain)
    if "MultiCross" in data_dir:
        dir_name = "DSNRec/MultiCross/"
    else:
        dir_name = "DSNRec/"
    if overlap_rate != 1.0:
        # Non-default overlap rates get a dedicated result directory and a
        # rate suffix on the run name.
        dir_name = dir_name.replace("MultiCross",
                                    "ChangeOverlapRate/%s" % domain_name)
        domain_name += "_%1.2f" % overlap_rate

    runConfig = config.configs["DEBUG"](dir_name + "DSNRec_" + domain_name)
    runConfig.setConsoleLogLevel("DEBUG")
    logger = runConfig.getLogger()
    gpuConfig = runConfig.getGPUConfig()

    pre_dir = os.path.join(data_dir, "preprocess")
    dataset = DSNRecDataset.DSNRecDataset(
        os.path.join(pre_dir, "uirepresent"),
        os.path.join(pre_dir, "cold"),
        src_domain, tgt_domain,
        overlap_rate=overlap_rate)
    batch_gen = dataset.generateTrainBatch("user", 1000)

    # NOTE: item_ipt_shp equals user_ipt_shp here -- both user and item
    # representations are derived from sentence vectors of the same size.
    user_ipt_shp, item_ipt_shp = dataset.getUIShp()
    enc_shp = []
    for ratio in (0.7, 0.5, 0.3):
        enc_shp.append(int(user_ipt_shp * ratio))
    dec_shp = []
    for ratio in (0.5, 0.7, 1):
        dec_shp.append(int(item_ipt_shp * ratio))

    model = DSNRec(item_ipt_shp, enc_shp, dec_shp,
                   user_ipt_shp, enc_shp, dec_shp, enc_shp)
    session = tf.Session(config=gpuConfig)
    model.initSess(session)
    writer = tf.summary.FileWriter(
        "log/" + dir_name + domain_name + "/train", session.graph)

    test_mses = []
    for epoch in range(epoches):
        # 100 mini-batches per epoch; the epoch summary is taken from the
        # last batch seen.
        for _ in range(100):
            raw = next(batch_gen)
            batch = {
                "item_ipt": raw["src"]["item"],
                "user_src_ipt": raw["src"]["user"],
                "user_src_rating": raw["src"]["rating"],
                "user_tgt_ipt": raw["tgt"]["user"]
            }
            model.trainBatch(batch, session)
        writer.add_summary(model.getSummary(session, batch), epoch)
        _, test_rmse = model.evaluate(
            session, dataset.generateTestBatch("user", 1000))
        logger.info("The test mse of epoch %d is %f.", epoch, test_rmse)
        test_mses.append(test_rmse)
    pkdump(test_mses, "log/" + dir_name + domain_name + "/test_mses.pk")
def main(data_dir, src_domain, tgt_domain, epoches, mode):
    """Train DSNRec for one (source, target) domain pair and record test RMSE.

    TensorBoard summaries go to ``log/DSNRec/<src>_<tgt>/train`` and the
    per-epoch test RMSE list is pickled alongside.  ``mode`` is accepted
    for interface compatibility and not used in this body.
    """
    runConfig = config.configs["DEBUG"]("DSNRec_%s_%s" % (src_domain,
                                                          tgt_domain))
    runConfig.setConsoleLogLevel("DEBUG")
    logger = runConfig.getLogger()
    gpuConfig = runConfig.getGPUConfig()

    pre_dir = os.path.join(data_dir, "preprocess")
    dataset = DSNRecDataset.DSNRecDataset(
        os.path.join(pre_dir, "uirepresent"),
        os.path.join(pre_dir, "cold"),
        src_domain, tgt_domain)
    batch_gen = dataset.generateTrainBatch("user", 1000)

    item_ipt_shp, user_ipt_shp = dataset.getUIShp()
    enc_shp = []
    for ratio in (0.7, 0.5, 0.3):
        enc_shp.append(int(user_ipt_shp * ratio))
    dec_shp = []
    for ratio in (0.5, 0.7, 1):
        dec_shp.append(int(item_ipt_shp * ratio))

    model = DSNRec(item_ipt_shp, enc_shp, dec_shp,
                   user_ipt_shp, enc_shp, dec_shp, enc_shp)
    session = tf.Session(config=gpuConfig)
    model.initSess(session)
    writer = tf.summary.FileWriter(
        "log/DSNRec/%s_%s/train" % (src_domain, tgt_domain), session.graph)

    test_mses = []
    for epoch in range(epoches):
        # 100 mini-batches per epoch; the summary is taken from the last
        # batch of the epoch.
        for _ in range(100):
            raw = next(batch_gen)
            batch = {
                "item_ipt": raw["src"]["item"],
                "user_src_ipt": raw["src"]["user"],
                "user_src_rating": raw["src"]["rating"],
                "user_tgt_ipt": raw["tgt"]["user"]
            }
            model.trainBatch(batch, session)
        writer.add_summary(model.getSummary(session, batch), epoch)
        _, test_rmse = model.evaluate(
            session, dataset.generateTestBatch("user", 1000))
        logger.info("The test mse of epoch %d is %f.", epoch, test_rmse)
        test_mses.append(test_rmse)
    pkdump(test_mses,
           "log/DSNRec/%s_%s/test_mses.pk" % (src_domain, tgt_domain))
def test_DSNRec(self): record_num = 100 item_ipt_shp = 32 item_ipt = np.random.randint(0, 9, size=record_num * item_ipt_shp).reshape( (record_num, item_ipt_shp)) item_enc_shp = [25, 9] item_dec_shp = [9, 25, item_ipt_shp] user_ipt_shp = 32 usrc_ipt = np.random.randint(0, 9, size=record_num * user_ipt_shp).reshape( (record_num, user_ipt_shp)) utgt_ipt = np.random.randint(0, 9, size=2 * record_num * user_ipt_shp).reshape( (2 * record_num, user_ipt_shp)) usrc_rating = np.random.randint(0, 6, size=record_num) user_enc_shp = [25, 16, 9] user_dec_shp = [9, 16, 25, user_ipt_shp] user_shr_shp = [25, 16, 9] dataset = DSNRecDataset.DSNRecDataset( "exam/data/preprocess/uirepresent", "exam/data/preprocess/cold", "Auto", "Musi") assert user_ipt_shp, item_ipt_shp == dataset.getUIShp() dsn_rec = DSNRec(item_ipt_shp, item_enc_shp, item_dec_shp, user_ipt_shp, user_enc_shp, user_dec_shp, user_shr_shp) sess = tf.Session(config=self.gpuConfig) batch = { "item_ipt": item_ipt, "user_src_ipt": usrc_ipt, "user_src_rating": usrc_rating, "user_tgt_ipt": utgt_ipt } dsn_rec.initSess(sess) train_writer = tf.summary.FileWriter('log/DSNRec/Musi_Auto/train', sess.graph) trainBatchGen = dataset.generateTrainBatch("user", 500) preds = [] for epoch in range(5): for i in range(100): batchData = next(trainBatchGen) batch = { "item_ipt": batchData["src"]["item"], "user_src_ipt": batchData["src"]["user"], "user_src_rating": batchData["src"]["rating"], "user_tgt_ipt": batchData["tgt"]["user"] } # for v in batch.values(): # print v.shape loss = dsn_rec.trainBatch(batch, sess) # print "loss of (i, epoch):(%d, %d) is %f" % (i, epoch, loss) # pdb.set_trace() summary = dsn_rec.getSummary(sess, batch) train_writer.add_summary(summary, epoch) pred, testloss = dsn_rec.evaluate( sess, dataset.generateTestBatch("user", 1000)) print "the loss of test dataset is", testloss preds.append((epoch, pred, testloss)) pkdump((preds, dataset.usersplit['src']['test']), "test_pred.pk")
def sentitrain(dir, domain, filter_size, filter_num, embd_size, epoches):
    """Train SentiRec for *domain*, keep the best checkpoint, and export
    the sentiment output vector for every (reviewerID, asin) pair.

    ``dir`` shadows the builtin of the same name but is kept for caller
    compatibility.  ``filter_size`` arrives as a comma-separated string and
    is parsed into a list of ints below.
    """
    runConfig = config.configs["DEBUG"](
        "sentitrain_%s_%s_%s" % (domain, filter_size, str(embd_size)))
    runConfig.setConsoleLogLevel("DEBUG")
    logger = runConfig.getLogger()
    gpuConfig = runConfig.getGPUConfig()
    session = tf.Session(config=gpuConfig)

    transPath = os.path.join(dir, "transform")
    data = []
    logger.info(transPath + "/*%s*" % domain)
    for d in glob.glob(transPath + "/*%s*" % domain):
        data.append(Dataset.SentiRecDataset(d))
    if not data:
        logger.error("The data of %s is not in %s", domain, transPath)
        raise Exception("dataset for domain %s not found in %s"
                        % (domain, transPath))
    data = data[0]

    vocabPath = os.path.join(dir, "vocab")
    vocab = pkload(os.path.join(vocabPath, "allDomain.pk"))
    vocab_size = len(vocab) + 1  # +1 reserves an index (padding/unknown)
    filter_size = [int(i) for i in filter_size.split(",")]
    sentc_len = data.getSentLen()
    sentirec = SentiRec(sentc_len, vocab_size, embd_size, filter_size,
                        filter_num)
    sentirec.initSess(session)
    train_writer = tf.summary.FileWriter('log/sentitrain/%s/train' % domain,
                                         session.graph)
    test_writer = tf.summary.FileWriter('log/sentitrain/%s/test' % domain,
                                        session.graph)

    minMae = 20
    minRmse = 20
    minEpoch = epoches
    batchSize = 1000
    saver = tf.train.Saver(max_to_keep=1)
    # Hoisted out of the improvement branch: previously modelSaveDir was
    # only bound when rmse improved, so the restore below could hit a
    # NameError if no epoch ever improved.
    modelSaveDir = os.path.join(dir, "sentiModel/%s/" % domain)
    if not os.path.exists(modelSaveDir):
        os.makedirs(modelSaveDir)

    for epoch in range(epoches):
        logger.info("Epoch %d" % epoch)

        @recordTime
        def senticEpoch():
            """One training pass; returns (mean mae, mean rmse)."""
            loss, mae, rmse = 0, 0, 0
            i = 0
            for batchData in data.getTrainBatch(
                    batchSize, itemgetter("reviewText", "overall")):
                sentcBatch = [d[0] for d in batchData]
                ratingBatch = [d[1] for d in batchData]
                batch = {"sentc_ipt": sentcBatch, "rating": ratingBatch}
                l, m, r = sentirec.trainBatch(session, batch)
                loss += l
                mae += m
                rmse += r
                i += 1
            logger.info("minMae is %f, epoch mae is %f" % (minMae, mae / i))
            logger.info("minRmse is %f, epoch rmse is %f" % (minRmse,
                                                             rmse / i))
            # Summary from the last training batch of the epoch.
            summary = sentirec.getSummary(session, batch)
            train_writer.add_summary(summary, epoch)
            if epoch % 50 == 0:
                # NOTE(review): relies on a module-level testEpoch counter
                # not visible in this chunk -- confirm it is initialised.
                global testEpoch
                for testBatch in data.getTestBatch(
                        batchSize, itemgetter("reviewText", "overall")):
                    testSB = [d[0] for d in testBatch]
                    testRB = [d[1] for d in testBatch]
                    batch = {"sentc_ipt": testSB, "rating": testRB}
                    testSummary = sentirec.getSummary(session, batch)
                    test_writer.add_summary(testSummary, testEpoch)
                    testEpoch += 1
            # BUGFIX: a second, unreachable return statement
            # ("return min((minMae, mae / i)), ...") followed this one in
            # the original; it has been removed as dead code.
            return mae / i, rmse / i

        mae, rmse = senticEpoch()
        if mae < minMae:
            minMae = mae
        if rmse < minRmse:
            # Checkpoint whenever the epoch rmse improves.
            minRmse = rmse
            minEpoch = epoch
            saver.save(session,
                       os.path.join(modelSaveDir, "%s-model" % domain),
                       global_step=epoch)

    # Restore the best checkpoint before exporting the output vectors.
    loader = tf.train.import_meta_graph(
        os.path.join(modelSaveDir, "%s-model-%d.meta" % (domain, minEpoch)))
    loader.restore(session, tf.train.latest_checkpoint(modelSaveDir))
    sentiOutput = {}
    for batchData in data._getBatch(
            data.index, batchSize,
            itemgetter("reviewText", "reviewerID", "asin")):
        sentcBatch = [d[0] for d in batchData]
        reviewerIDAsin = [(d[1], d[2]) for d in batchData]
        outputVec = sentirec.outputVector(session, sentcBatch)
        sentiOutput.update(dict(zip(reviewerIDAsin, outputVec)))
    outputPath = os.path.join(dir, "sentiRecOutput", domain + ".pk")
    pkdump(sentiOutput, outputPath)