예제 #1
0
def main(args):
    """Run the routine selected by ``args.mode``.

    Example invocations:
        prepare: python3 run.py --mode prepare --pointer-gen
        train:   python3 run.py --mode train -b 100 -o output --gpu 0 --restore
        eval:    python3 run.py --mode eval --eval-model
        decode:  python3 run.py --mode decode --beam-size 10
                 --decode-model output_big_data/model/model-250000
                 --decode-dir output_big_data/result --gpu 1

    Raises:
        RuntimeError: if ``args.mode`` is not one of prepare, train, eval,
            decode or debug.
    """
    mode = args.mode

    # Modes that simply delegate to a single function.
    if mode == 'prepare':
        prepare(args)
        return
    if mode == 'train':
        train(args)
        return
    if mode == 'eval':
        evaluate(args)
        return
    if mode == 'debug':
        debug(args)
        return
    if mode != 'decode':
        raise RuntimeError(f'mode {args.mode} is invalid.')

    # Decode mode: the batch size is set to the beam width before anything
    # that reads ``args`` is constructed.
    args.batch_size = args.beam_size
    encoder_vocab = Vocab(args, "encoder_vocab")
    decoder_vocab = Vocab(args, "decoder_vocab")
    user_vocab = User_Vocab(args, name="user_vocab")
    test_path = "./test.data"
    test_batches = TestBatcher(args, encoder_vocab, decoder_vocab, user_vocab,
                               test_path).batcher()

    # Build the model on the CPU only when explicitly requested.
    if args.cpu:
        with tf.device('/cpu:0'):
            model = CommentModel(args, decoder_vocab)
    else:
        model = CommentModel(args, decoder_vocab)

    BeamSearchDecoder(args, model, test_batches, decoder_vocab).decode()
예제 #2
0
def evaluate(args):
    """Evaluate a saved checkpoint on the dev records and print the result.

    The dev-set size is read from ``dev_meta.json`` inside
    ``args.records_dir``; ``args.eval_model`` is restored and ``model.loss``
    is averaged over ``ceil(size / args.batch_size)`` batches. Perplexity is
    reported as ``exp(mean loss)``.

    Args:
        args: parsed options. This function itself reads ``records_dir``,
            ``batch_size`` and ``eval_model``; the helpers it calls receive
            the whole object.
    """
    with open(os.path.join(args.records_dir, 'dev_meta.json'),
              'r',
              encoding='utf8') as p:
        dev_total = json.load(p)['size']

    dev_records_file = os.path.join(args.records_dir, 'dev.tfrecords')
    parser = get_record_parser(args)
    dev_dataset = get_batch_dataset(dev_records_file, parser, args)
    dev_iterator = dev_dataset.make_one_shot_iterator()

    session_config = tf.ConfigProto(allow_soft_placement=True)
    session_config.gpu_options.allow_growth = True
    sess = tf.Session(config=session_config)
    try:
        model = CommentModel(args, dev_iterator)
        saver = tf.train.Saver(max_to_keep=10000)
        saver.restore(sess, args.eval_model)

        total_loss = 0
        batches_num = int(np.ceil(dev_total / args.batch_size))
        for _ in tqdm(range(batches_num), desc='eval'):
            total_loss += sess.run(model.loss)
    finally:
        # Release the session's resources even when restoring the checkpoint
        # or running a batch fails.
        sess.close()

    dev_loss = total_loss / batches_num
    dev_ppl = np.exp(dev_loss)
    print(
        f'{time.asctime()} - Evaluation result -> dev_loss:{dev_loss:.3f}  dev_ppl:{dev_ppl:.3f}'
    )
예제 #3
0
    def comment_sync_down(self, request):
        """Return all comments added or changed since the last synchronization.

        Args:
            request: incoming message whose ``date`` field carries the time of
                the last synchronization; it is parsed with
                ``string_to_datetime`` (the original note gives the format as
                %Y-%m-%dT%H:%M:%S).

        Returns:
            CommentMessageCollection: the outgoing message sent to clients,
            holding one ``CommentMessage`` per matching comment.
        """
        last_sync = string_to_datetime(request.date)
        changed = CommentModel.query(CommentModel.last_modified > last_sync)

        return CommentMessageCollection(items=[
            CommentMessage(
                content=row.content,
                creator=row.creator,
                review_uuid=row.review_uuid,
                uuid=row.uuid,
                last_modified=row.last_modified,
            )
            for row in changed
        ])
예제 #4
0
    def mutate(cls, _, info, post_id, comment):
        """Add a comment, authored by the current JWT identity, to a post.

        Returns:
            CommentLeaveMutation wrapping a ``ResponseMessageField`` that
            reports "Unknown post id" when no post matches ``post_id`` and
            success otherwise.
        """
        post = PostModel.objects(id=post_id).first()
        author = AccountModel.objects(id=get_jwt_identity()).first()
        new_comment = CommentModel(text=comment, author=author)

        if post is not None:
            # NOTE(review): if PostModel is a mongoengine Document, updates are
            # normally spelled ``post.update(push__<field>=...)``; confirm that
            # ``update_one`` / ``push_comment`` is what PostModel provides.
            post.update_one(push_comment=new_comment)
            outcome = ResponseMessageField(is_success=True, message="Comment successfully uploaded")
        else:
            outcome = ResponseMessageField(is_success=False, message="Unknown post id")

        return CommentLeaveMutation(outcome)
예제 #5
0
    def comment_list(self, unused_request):
        """Return every stored comment as a ``CommentMessageCollection``.

        The request payload is ignored; all ``CommentModel`` entities are
        converted to ``CommentMessage`` objects.
        """
        return CommentMessageCollection(items=[
            CommentMessage(
                content=row.content,
                last_modified=row.last_modified,
                creator=row.creator,
                uuid=row.uuid,
                review_uuid=row.review_uuid,
            )
            for row in CommentModel.query()
        ])
예제 #6
0
    def major(self):
        """Process every news category: articles, comments, replies and users.

        For each category the article briefs are fetched. Then, per article,
        the author, the article itself, its comments and each comment's
        replies are stored through the DAO objects, together with user
        behaviour rows and feature updates. Comment and reply authors
        additionally get "simulated browsing" behaviour on articles from one
        or two randomly sampled categories.

        Every stage is best-effort: a failure is printed and logged, and
        processing moves on to the next item at that nesting level. Handlers
        catch ``Exception`` (not a bare ``except``) so that KeyboardInterrupt
        and SystemExit can still stop the run.
        """
        categories = ['news_society', 'news_entertainment', 'news_tech', 'news_military', 'news_sports', 'news_car',
                      'news_finance', 'news_world', 'news_fashion', 'news_travel', 'news_discovery', 'news_baby',
                      'news_regimen', 'news_story', 'news_essay', 'news_game', 'news_history', 'news_food']

        for category in categories:
            print("当前类别: %s" % category)
            logging.info("当前类别: %s" % category)

            # Fetch the article briefs of this category.
            try:
                arts_brief_json = self.__art_pro.get_arts_brief_json_by_category(category)
                logging.info('%s arts_brief_json 获取 成功' % category)
            except Exception:
                print('%s arts_brief_json 获取 失败' % category)
                logging.exception('%s arts_brief_json 获取 失败' % category)
                continue

            for art_i, art_brief_json in enumerate(arts_brief_json):
                print("当前新闻: %d/%d %s" % (art_i, len(arts_brief_json), category))
                logging.info("当前新闻: %d/%d %s" % (art_i, len(arts_brief_json), category))

                # Article author.
                art_cus_mod = CusMod.CustomerModel()
                try:
                    self.__cus_pro.set_art_cus(art_brief_json, art_cus_mod)
                    self.__cus_dao.insert_then_get_cus(art_cus_mod)
                    self.__cus_dao.update_cus_feature(category, art_cus_mod.cus_id, flag=True)
                    logging.info("%s-%d art_cus 处理 成功" % (category, art_i))
                except Exception:
                    print("%s-%d art_cus 处理 失败" % (category, art_i))
                    logging.exception("%s-%d art_cus 处理 失败" % (category, art_i))
                    continue

                # Article: insert it unless it is already stored, then look up its id.
                art_mod = ArtMod.ArticleModel()
                try:
                    self.__art_pro.set_art(art_brief_json, category, art_cus_mod.cus_id, art_mod)
                    if not self.__art_dao.is_art_exist(art_mod.art_spider):
                        # The article does not exist yet.
                        self.__art_dao.insert_art(art_mod)
                    else:
                        print("art 已存在")
                        continue
                    art_mod.art_id = self.__art_dao.search_art_id_by_spider(art_mod.art_spider)
                    logging.info("%s-%d art 操作 成功" % (category, art_i))
                except Exception:
                    print("%s-%d art 操作 失败" % (category, art_i))
                    logging.exception("%s-%d art 操作 失败" % (category, art_i))
                    continue

                # Article / author behaviour rows.
                try:
                    if self.__art_dao.check_art_cus_relationship(art_mod.art_id, art_cus_mod.cus_id):
                        self.__cus_dao.insert_cus_behavior(
                            art_cus_mod.cus_id, art_cus_mod.cus_id, 1, art_mod.art_id, 1,
                            art_mod.art_id, cbr_time=art_mod.art_time
                        )
                        self.__cus_dao.insert_cus_behavior(
                            art_cus_mod.cus_id, art_cus_mod.cus_id, 2, art_mod.art_id, 1,
                            art_mod.art_id
                        )
                        self.__cus_dao.update_cus_feature(category, art_cus_mod.cus_id)
                        self.__art_dao.update_art_feature(1, art_mod.art_id, art_mod.art_time)
                    logging.info("%s-%d rt-cus 行为 1 数据库操作 成功" % (category, art_i))
                except Exception:
                    print("%s-%d rt-cus 行为 1 数据库操作 失败" % (category, art_i))
                    logging.exception("%s-%d rt-cus 行为 1 数据库操作 失败" % (category, art_i))
                    continue

                # Comments (and, nested below, their replies).
                try:
                    coms_json = self.__com_pro.get_coms_json(art_brief_json)
                    if coms_json is None:
                        continue
                    logging.info("%s-%d coms_json 获取 成功" % (category, art_i))
                except Exception:
                    print("\t%s-%d coms_json 获取 失败" % (category, art_i))
                    logging.exception("%s-%d coms_json 获取 失败" % (category, art_i))
                    continue

                for com_i, com_json in enumerate(coms_json):
                    print("\t当前评论: %d/%d" % (com_i, len(coms_json)))
                    logging.info("当前评论: %d/%d" % (com_i, len(coms_json)))

                    # Comment author.
                    com_cus_mod = CusMod.CustomerModel()
                    try:
                        self.__cus_pro.set_com_cus(com_json, com_cus_mod)
                        self.__cus_dao.insert_then_get_cus(com_cus_mod)
                        self.__cus_dao.update_cus_feature(category, com_cus_mod.cus_id, flag=True)
                        # Success path: report success (this used to log the error text).
                        logging.info("%s-%d-%d com_cus 处理 成功" % (category, art_i, com_i))
                    except Exception:
                        print("\t%s-%d-%d com_cus 处理 错误" % (category, art_i, com_i))
                        logging.exception("%s-%d-%d com_cus 处理 错误" % (category, art_i, com_i))
                        continue

                    # Comment: insert it unless it is already stored, then look up its id.
                    com_mod = ComMod.CommentModel()
                    try:
                        self.__com_pro.set_com(com_json, art_mod.art_id, com_cus_mod.cus_id, com_mod)
                        if not self.__com_dao.is_com_exist(com_mod.com_spider):
                            # The comment does not exist yet.
                            self.__com_dao.insert_com(com_mod)
                        else:
                            print("com 已存在")
                            continue
                        com_mod.com_id = self.__com_dao.search_com_id_by_spider(com_mod.com_spider)
                        # Success path: report success (this used to log the failure text).
                        logging.info("%s-%d-%d com 处理 成功" % (category, art_i, com_i))
                    except Exception:
                        print("\t%s-%d-%d com 处理 失败" % (category, art_i, com_i))
                        logging.exception("%s-%d-%d com 处理 失败" % (category, art_i, com_i))
                        continue

                    # Comment / user behaviour rows.
                    try:
                        if self.__com_dao.check_com_cus_relationship(art_mod.art_id, com_mod.com_id, com_cus_mod.cus_id):
                            self.__cus_dao.insert_cus_behavior(
                                com_cus_mod.cus_id, art_cus_mod.cus_id, 5, art_mod.art_id, 2,
                                com_mod.com_id, cbr_time=com_mod.com_time
                            )
                            self.__cus_dao.insert_cus_behavior(
                                com_cus_mod.cus_id, art_cus_mod.cus_id, 2, art_mod.art_id, 1,
                                art_mod.art_id
                            )
                            self.__cus_dao.update_cus_feature(category, com_cus_mod.cus_id)
                            self.__art_dao.update_art_feature(4, art_mod.art_id, art_mod.art_time)
                        logging.info("%s-%d-%d art-cus 行为 4 数据库操作 成功" % (category, art_i, com_i))
                    except Exception:
                        print("\t%s-%d-%d art-cus 行为 4 数据库操作 失败" % (category, art_i, com_i))
                        logging.exception("%s-%d-%d art-cus 行为 4 数据库操作 失败" % (category, art_i, com_i))
                        continue

                    # Simulated browsing for the comment author. A failure here
                    # does not skip the replies below.
                    try:
                        rand_category_num = random.randint(1, 2)
                        rand_cates = random.sample(categories, rand_category_num)
                        for rand_cate in rand_cates:
                            result_list = self.__art_dao.get_same_category_art(art_mod.art_id, rand_cate)
                            if result_list is not None:
                                for back_art in result_list:
                                    try:
                                        self.__cus_dao.insert_cus_behavior(
                                            com_cus_mod.cus_id, back_art[1], 2, back_art[0], 1, back_art[0]
                                        )
                                        self.__cus_dao.update_cus_feature(rand_cate, com_cus_mod.cus_id, update_num=1)
                                        self.__art_dao.update_art_feature(6, back_art[0], art_mod.art_time)
                                    except Exception:
                                        # Best-effort per article: skip the ones that fail.
                                        continue
                                print("\t%d 用户模拟浏览操作 数量 %d 完成" % (com_cus_mod.cus_id, len(result_list)))
                                logging.info("%d 模拟浏览操作 数量 %d 完成" % (com_cus_mod.cus_id, len(result_list)))
                    except Exception:
                        print("\t%d 用户模拟浏览操作 失败" % com_cus_mod.cus_id)
                        logging.exception("%d 用户模拟浏览操作 失败" % com_cus_mod.cus_id)

                    # Replies to this comment.
                    try:
                        reps_json = self.__rep_pro.get_reps_json(com_json)
                        if reps_json is None:
                            continue
                        logging.info("%s-%d-%d reps_json 获取 成功" % (category, art_i, com_i))
                    except Exception:
                        print("\t\t%s-%d-%d reps_json 获取 失败" % (category, art_i, com_i))
                        logging.exception("%s-%d-%d reps_json 获取 失败" % (category, art_i, com_i))
                        continue

                    for rep_i, rep_json in enumerate(reps_json):
                        # Reply author.
                        rep_cus_mod = CusMod.CustomerModel()
                        try:
                            self.__cus_pro.set_rep_cus(rep_json, rep_cus_mod)
                            self.__cus_dao.insert_then_get_cus(rep_cus_mod)
                            self.__cus_dao.update_cus_feature(category, rep_cus_mod.cus_id, flag=True)
                            logging.info("%s-%d-%d-%d rep_cus 处理 成功" % (category, art_i, com_i, rep_i))
                        except Exception:
                            print("\t\t%s-%d-%d-%d rep_cus 处理 失败" % (category, art_i, com_i, rep_i))
                            logging.exception("%s-%d-%d-%d rep_cus 处理 失败" % (category, art_i, com_i, rep_i))
                            continue

                        # Reply: insert it unless it is already stored, then look up its id.
                        rep_mod = RepMod.ReplyModel()
                        try:
                            self.__rep_pro.set_rep(rep_json, art_mod.art_id,
                                                   com_mod.com_id, rep_cus_mod.cus_id, rep_mod)
                            if not self.__rep_dao.is_rep_exist(rep_mod.rep_spider):
                                self.__rep_dao.search_rep_rep_by_spyder(rep_json, rep_mod)
                                self.__rep_dao.insert_rep(rep_mod)
                            else:
                                print("rep 已存在")
                                continue
                            rep_mod.rep_id = self.__rep_dao.search_rep_id_by_spider(rep_mod.rep_spider)
                            logging.info("%s-%d-%d-%d rep 处理 成功" % (category, art_i, com_i, rep_i))
                        except Exception:
                            print("\t\t%s-%d-%d-%d rep 处理 失败" % (category, art_i, com_i, rep_i))
                            logging.exception("%s-%d-%d-%d rep 处理 失败" % (category, art_i, com_i, rep_i))
                            continue

                        # Reply / user behaviour rows.
                        try:
                            if self.__rep_dao.check_rep_cus_relationship(art_mod.art_id, rep_mod.rep_id,
                                                                         rep_cus_mod.cus_id):
                                self.__cus_dao.insert_cus_behavior(
                                    rep_cus_mod.cus_id, art_cus_mod.cus_id, 8, art_mod.art_id, 3,
                                    rep_mod.rep_id, cbr_time=rep_mod.rep_time
                                )
                                self.__cus_dao.insert_cus_behavior(
                                    rep_cus_mod.cus_id, art_cus_mod.cus_id, 2, art_mod.art_id, 1,
                                    art_mod.art_id
                                )
                                self.__cus_dao.update_cus_feature(category, rep_cus_mod.cus_id)
                                self.__art_dao.update_art_feature(5, art_mod.art_id, art_mod.art_time)
                            logging.info("%s-%d-%d-%d art-cus 行为 5 数据库操作 成功" % (category, art_i, com_i, rep_i))
                        except Exception:
                            print("\t\t%s-%d-%d-%d art-cus 行为 5 数据库操作 失败" % (category, art_i, com_i, rep_i))
                            logging.exception("%s-%d-%d-%d art-cus 行为 5 数据库操作 失败" % (category, art_i, com_i, rep_i))
                            continue

                        # Simulated browsing for the reply author.
                        try:
                            rand_category_num = random.randint(1, 2)
                            rand_cates = random.sample(categories, rand_category_num)
                            for rand_cate in rand_cates:
                                result_list = self.__art_dao.get_same_category_art(art_mod.art_id, rand_cate)
                                if result_list is not None:
                                    for back_art in result_list:
                                        try:
                                            self.__cus_dao.insert_cus_behavior(
                                                rep_cus_mod.cus_id, back_art[1], 2, back_art[0], 1, back_art[0]
                                            )
                                            self.__cus_dao.update_cus_feature(rand_cate, rep_cus_mod.cus_id, update_num=1)
                                            self.__art_dao.update_art_feature(6, back_art[0], art_mod.art_time)
                                        except Exception:
                                            # Best-effort per article: skip the ones that fail.
                                            continue
                                    print("\t\t%d 用户模拟浏览操作 数量 %d 完成" % (rep_cus_mod.cus_id, len(result_list)))
                                    logging.info("%d 用户模拟浏览操作 数量 %d 完成" % (rep_cus_mod.cus_id, len(result_list)))
                        except Exception:
                            print("\t\t%d 用户模拟浏览操作 失败" % rep_cus_mod.cus_id)
                            logging.exception("%d 用户模拟浏览操作 失败" % rep_cus_mod.cus_id)
예제 #7
0
def train(args):
    """Train ``CommentModel`` indefinitely, checkpointing and evaluating periodically.

    Directories ``args.output``, its ``model``/``result`` sub-directories and
    the log directory are created if missing, and the parsed options are
    dumped to ``args.json``. Training then loops forever:

    * every ``args.period`` steps the merged summaries are written;
    * every ``args.checkpoint`` steps a checkpoint is saved and, unless
      ``args.no_eval`` is set, the dev set is evaluated with
      ``evaluate_batch``;
    * when the dev perplexity beats the best one seen so far, the model is
      also saved as ``best``; after ``args.patience`` evaluations without
      improvement the learning rate is halved.

    Args:
        args: parsed options. Read here: ``output``, ``log_dir``,
            ``records_dir``, ``data_dir``, ``batch_size``, ``restore``,
            ``restore_model``, ``period``, ``checkpoint``, ``no_eval`` and
            ``patience``; the helpers receive the whole object.
    """
    output_dir = args.output
    log_dir = args.log_dir if args.log_dir else os.path.join(output_dir, 'log')
    model_dir = os.path.join(output_dir, 'model')
    records_dir = args.records_dir if not args.data_dir else os.path.join(
        args.data_dir, args.records_dir)
    result_dir = os.path.join(output_dir, 'result')
    for path in (output_dir, log_dir, model_dir, result_dir):
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(path, exist_ok=True)

    # save the args info to output dir.
    with open(os.path.join(output_dir, 'args.json'), 'w') as p:
        json.dump(vars(args), p, indent=2)

    # load meta info
    with open(os.path.join(records_dir, 'train_meta.json'),
              'r',
              encoding='utf8') as p:
        train_total = json.load(p)['size']
        batch_num_per_epoch = int(np.ceil(train_total / args.batch_size))
        print(f'{time.asctime()} - batch num per epoch: {batch_num_per_epoch}')

    with open(os.path.join(records_dir, 'dev_meta.json'), 'r',
              encoding='utf8') as p:
        dev_total = json.load(p)['size']

    train_records_file = os.path.join(records_dir, 'train.tfrecords')
    dev_records_file = os.path.join(records_dir, 'dev.tfrecords')

    with tf.Graph().as_default(), tf.device('/gpu:0'):

        parser = get_record_parser(args)
        train_dataset = get_batch_dataset(train_records_file, parser, args)
        dev_dataset = get_batch_dataset(dev_records_file, parser, args)

        # A feedable iterator lets the same model read from either dataset,
        # selected per run through the string handle.
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, train_dataset.output_types, train_dataset.output_shapes)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_iterator = dev_dataset.make_one_shot_iterator()

        model = CommentModel(args, iterator)

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True
        sess = tf.Session(config=session_config)

        writer = tf.summary.FileWriter(log_dir)
        # Stored as a variable so the best perplexity is saved in, and
        # restored from, checkpoints.
        best_ppl = tf.Variable(300,
                               trainable=False,
                               name='best_ppl',
                               dtype=tf.float32)

        saver = tf.train.Saver(max_to_keep=10000)
        if args.restore:
            model_file = args.restore_model or tf.train.latest_checkpoint(
                model_dir)
            print(f'{time.asctime()} - Restore model from {model_file}..')
            # Restore only the variables present in the checkpoint and
            # initialize the remaining ones from scratch.
            var_list = [
                _[0] for _ in checkpoint_utils.list_variables(model_file)
            ]
            saved_vars = [
                _ for _ in tf.global_variables()
                if _.name.split(':')[0] in var_list
            ]
            res_saver = tf.train.Saver(saved_vars)
            res_saver.restore(sess, model_file)

            left_vars = [
                _ for _ in tf.global_variables()
                if _.name.split(':')[0] not in var_list
            ]
            sess.run(tf.initialize_variables(left_vars))
            print(
                f'{time.asctime()} - Restore {len(var_list)} vars and initialize {len(left_vars)} vars.'
            )
            print(left_vars)
        else:
            print(f'{time.asctime()} - Initialize model..')
            sess.run(tf.global_variables_initializer())

        train_handle = sess.run(train_iterator.string_handle())
        dev_handle = sess.run(dev_iterator.string_handle())

        sess.run(tf.assign(model.is_train, tf.constant(True,
                                                       dtype=tf.bool)))  #tmp

        patience = 0

        lr = sess.run(model.lr)
        b_ppl = sess.run(best_ppl)
        print(f'{time.asctime()} - lr: {lr:.3f}  best_ppl:{b_ppl:.3f}')

        t0 = datetime.now()

        while True:
            global_step = sess.run(model.global_step) + 1
            epoch = int(np.ceil(global_step / batch_num_per_epoch))

            # The fetch list is kept as-is; only the values used below are named.
            _, loss_gen, ppl, _, merge_sum, _, _ = sess.run(
                [
                    model.loss, model.loss_gen, model.ppl, model._train_op,
                    model._summaries, model.target, model.check_dec_outputs
                ],
                feed_dict={handle: train_handle})

            ela_time = str(datetime.now() - t0).split('.')[0]

            print(
                (f'{time.asctime()} - step/epoch:{global_step}/{epoch:<3d}   '
                 f'gen_loss:{loss_gen:<3.3f}  '
                 f'ppl:{ppl:<4.3f}  '
                 f'elapsed:{ela_time}\r'),
                end='')

            if global_step % args.period == 0:
                writer.add_summary(merge_sum, global_step)
                writer.flush()

            if global_step % args.checkpoint == 0:
                model_file = os.path.join(model_dir, 'model')
                saver.save(sess, model_file, global_step=global_step)

            if global_step % args.checkpoint == 0 and not args.no_eval:
                # NOTE(review): each tf.assign(...) call in this loop adds new
                # ops to the graph; consider building them once before the loop.
                sess.run(
                    tf.assign(model.is_train, tf.constant(False,
                                                          dtype=tf.bool)))
                metrics, summ = evaluate_batch(model,
                                               dev_total // args.batch_size,
                                               sess, handle, dev_handle,
                                               iterator)
                sess.run(
                    tf.assign(model.is_train, tf.constant(True,
                                                          dtype=tf.bool)))

                for s in summ:
                    writer.add_summary(s, global_step)

                dev_ppl = metrics['ppl']
                dev_gen_loss = metrics['gen_loss']

                tqdm.write(
                    f'{time.asctime()} - Evaluate after steps:{global_step}, '
                    f' gen_loss:{dev_gen_loss:.4f},  ppl:{dev_ppl:.3f}')

                if dev_ppl < b_ppl:
                    sess.run(tf.assign(best_ppl, dev_ppl))
                    # Keep the Python-side copy in sync with the variable.
                    # Without this, later evaluations are compared against
                    # the start-up value, so a worse model could overwrite
                    # the 'best' checkpoint and patience would never grow.
                    b_ppl = dev_ppl
                    saver.save(sess, save_path=os.path.join(model_dir, 'best'))
                    tqdm.write(
                        f'{time.asctime()} - the ppl is lower than current best ppl so saved the model.'
                    )
                    patience = 0
                else:
                    patience += 1

                if patience >= args.patience:
                    # Halve the learning rate after too many evaluations
                    # without improvement, then start counting again.
                    lr = lr / 2
                    sess.run(
                        tf.assign(model.lr, tf.constant(lr, dtype=tf.float32)))
                    patience = 0
                    tqdm.write(
                        f'{time.asctime()} - The lr is decayed form {lr*2} to {lr}.'
                    )
예제 #8
0
def debug(args):
    """Run dev batches through the model and print those with a huge loss.

    Builds ``CommentModel`` on 'data/records/dev.tfrecords' with freshly
    initialized variables, then repeatedly fetches the model's input tensors
    together with its loss. Whenever the loss exceeds 100000, the loss, the
    target and the target mask of that batch are pretty-printed.

    The loop has no exit condition of its own; it ends when ``sess.run``
    raises (presumably once the one-shot iterator is exhausted).

    Args:
        args: parsed options, passed through to the dataset helpers and the
            model.
    """
    from pprint import pprint

    from utils import get_record_parser, get_batch_dataset
    import tensorflow as tf

    parser = get_record_parser(args)
    dataset = get_batch_dataset('data/records/dev.tfrecords', parser, args)

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(handle,
                                                   dataset.output_types,
                                                   dataset.output_shapes)
    train_iterator = dataset.make_one_shot_iterator()

    model = CommentModel(args, iterator)

    session_config = tf.ConfigProto(allow_soft_placement=True)
    session_config.gpu_options.allow_growth = True

    sess = tf.Session(config=session_config)
    sess.run(tf.global_variables_initializer())

    train_handle = sess.run(train_iterator.string_handle())

    sess.run(tf.assign(model.is_train, tf.constant(True, dtype=tf.bool)))
    get_results = {
        # The loop below filters on the loss, so it must be fetched too;
        # without it results['loss'] raises KeyError on the first batch.
        'loss': model.loss,
        'description': model.description,
        'description_sen': model.description_sen,
        'description_off': model.description_off,
        'description_len': model.description_len,
        'description_mask': model.description_mask,
        'query': model.query,
        'query_len': model.query_len,
        'query_mask': model.query_mask,
        'response': model.response,
        'response_len': model.response_len,
        'response_mask': model.response_mask,
        'target': model.target,
        'target_len': model.target_len,
        'target_mask': model.target_mask,
    }

    while True:
        results = sess.run(get_results, feed_dict={handle: train_handle})
        # Convert numpy values to plain Python objects for printing.
        results = {k: v.tolist() for k, v in results.items()}
        if results['loss'] > 100000:
            pprint(results['loss'], width=1000)
            pprint(results['target'], width=1000)
            pprint(results['target_mask'], width=1000)
예제 #9
0
    def major(self):
        categories = [
            'news_society', 'news_entertainment', 'news_tech', 'news_military',
            'news_sports', 'news_car', 'news_finance', 'news_world',
            'news_fashion', 'news_travel', 'news_discovery', 'news_baby',
            'news_regimen', 'news_story', 'news_essay', 'news_game',
            'news_history', 'news_food'
        ]
        for category in categories:
            print("\n当前类别: %s" % category)
            """ 处理 art """
            try:
                arts_brief_json = self.__art_pro.get_arts_brief_json_by_category(
                    category)
                if len(arts_brief_json) != 0:
                    print("新闻总长度: %d" % len(arts_brief_json))
                # print('arts_brief_json 获取 成功')
            except:
                print('arts_brief_json 获取 失败')
                continue

            for art_brief_json in arts_brief_json:
                # art_cus
                art_cus_mod = CusMod.CustomerModel()
                try:
                    self.__cus_pro.set_art_cus(art_brief_json, art_cus_mod)
                    self.__cus_dao.group_check_insert_cus_then_search_id(
                        art_cus_mod)
                    # print("art_cus 处理 成功")
                except:
                    print("art_cus 处理 失败")
                    continue
                # art
                art_mod = ArtMod.ArticleModel()
                try:
                    self.__art_pro.set_art(art_brief_json, category,
                                           art_cus_mod.cus_id, art_mod)
                    if not self.__art_dao.is_art_exist(art_mod.art_spider):
                        # 新闻不存在的情况
                        self.__art_dao.insert_art(art_mod)
                    else:
                        print("art 已存在")
                        continue
                    art_mod.art_id = self.__art_dao.search_art_id_by_spider(
                        art_mod.art_spider)
                    # print("art 操作 成功")
                except:
                    print("art 操作 失败")
                    continue
                # art cus behavior
                try:
                    if self.__art_dao.check_art_cus_relationship(
                            art_mod.art_id, art_cus_mod.cus_id):
                        self.__cus_dao.insert_cus_behavior(
                            1, art_mod.art_id, art_cus_mod.cus_id,
                            art_mod.art_time)
                    else:
                        pass
                    # print("art-cus 行为 1 数据库操作 成功")
                except:
                    print("art-cus 行为 1 数据库操作 失败")
                    continue
                """ handel the coms """
                try:
                    coms_json = self.__com_pro.get_coms_json(art_brief_json)
                    if len(coms_json) != 0:
                        print("回复总长 %d" % len(coms_json))
                except:
                    print("coms_json 获取 失败")
                    continue

                for com_json in coms_json:
                    # com_cus
                    com_cus_mod = CusMod.CustomerModel()
                    try:
                        self.__cus_pro.set_com_cus(com_json, com_cus_mod)
                        self.__cus_dao.group_check_insert_cus_then_search_id(
                            com_cus_mod)
                        # print("com_cus 处理 成功")
                    except:
                        print("com_cus 处理 错误")
                        continue
                    # com
                    com_mod = ComMod.CommentModel()
                    try:
                        self.__com_pro.set_com(com_json, art_mod.art_id,
                                               com_cus_mod.cus_id, com_mod)
                        if not self.__com_dao.is_com_exist(com_mod.com_spider):
                            # if the com is not exist
                            self.__com_dao.insert_com(com_mod)
                        else:
                            print("com 已存在")
                            continue
                        com_mod.com_id = self.__com_dao.search_com_id_by_spider(
                            com_mod.com_spider)
                        self.__art_dao.update_art_com_number(art_mod.art_id)
                        # print("com 处理 成功")
                    except:
                        print("com 处理 失败")
                        continue
                    # com cus behavior
                    try:
                        if self.__com_dao.check_com_cus_relationship(
                                art_mod.art_id, com_mod.com_id,
                                com_cus_mod.cus_id):
                            self.__cus_dao.insert_cus_behavior(
                                4, art_mod.art_id, com_cus_mod.cus_id,
                                com_mod.com_time)
                        else:
                            pass
                        # print("art-cus 行为 4 数据库操作 成功")
                    except:
                        print("art-cus 行为 4 数据库操作 失败")
                        continue
                    """ handel the reps """
                    try:
                        reps_json = self.__rep_pro.get_reps_json(com_json)
                        if len(reps_json) != 0:
                            print("回复总长 %d" % len(reps_json))
                    except:
                        print("reps_json 获取 失败")
                        continue

                    for rep_json in reps_json:
                        # rep_cus
                        rep_cus_mod = CusMod.CustomerModel()
                        try:
                            self.__cus_pro.set_rep_cus(rep_json, rep_cus_mod)
                            self.__cus_dao.group_check_insert_cus_then_search_id(
                                rep_cus_mod)
                            # print("rep_cus 处理 成功")
                        except:
                            print("rep_cus 处理 失败")
                            continue
                        # rep
                        rep_mod = RepMod.ReplyModel()
                        try:
                            self.__rep_pro.set_rep(rep_json, art_mod.art_id,
                                                   com_mod.com_id,
                                                   rep_cus_mod.cus_id, rep_mod)
                            if not self.__rep_dao.is_rep_exist(
                                    rep_mod.rep_spider):
                                self.__rep_dao.search_rep_rep_by_spyder(
                                    rep_json, rep_mod)
                                self.__rep_dao.insert_rep(rep_mod)
                            else:
                                print("rep 已存在")
                                continue
                            rep_mod.rep_id = self.__rep_dao.search_rep_id_by_spider(
                                rep_mod.rep_spider)
                            # print("rep 处理 成功")
                        except:
                            print("rep 处理 失败")
                            continue

                        # rep cus behavior
                        try:
                            if self.__rep_dao.check_rep_cus_relationship(
                                    art_mod.art_id, rep_mod.rep_id,
                                    rep_cus_mod.cus_id):
                                self.__cus_dao.insert_cus_behavior(
                                    5, art_mod.art_id, rep_cus_mod.cus_id,
                                    rep_mod.rep_time)
                            else:
                                pass
                            # print("art-cus 行为 5 数据库操作 成功")
                        except:
                            print("art-cus 行为 5 数据库操作 失败")
                            continue
예제 #10
0
from model import CommentModel
from model import create

# Run the project's `create()` step before a model is instantiated.
# NOTE(review): presumably this builds/trains the artefacts that
# `model.load()` reads below -- confirm against model.py.
create()

# Instantiate the comment model and load its state.
model = CommentModel()
model.load()

# A few sample comments to run through the model.
sample_comments = [
    "amazing hot✨✨✨❤️❤️",
    "nice",
    "Summer is here in Colorado now",
]
res = model.predict(sample_comments)
print(res)
# Output noted by the original author: [1,1,0]
예제 #11
0
파일: service.py 프로젝트: 1q84/2057
 def __init__(self):
     """Attach the model objects this service works with.

     Each attribute is set to whatever the corresponding model class's
     ``instance()`` call returns.
     """
     # NOTE(review): ``instance()`` presumably hands back a shared
     # (singleton-style) object per model class -- confirm in the model base.
     self.user=UserModel.instance()
     self.note=NoteModel.instance()
     self.comment=CommentModel.instance()
     self.relation=RelationModel.instance()
     self.notification = NotificationModel.instance()