示例#1
0
        # Report the loss measured on the test set (test_loss is computed
        # before this view begins).
        logger.info("Testing loss: {}".format(test_loss))

        # Save the test scores: record the test loss in the experiment
        # manager's user data.
        xp_manager.update_user_data(user_data={
            "test_loss": test_loss,
        })

    # Second phase: evaluate the QGen inside simulated games ("loop").
    # Clear the default graph so the loop models are built in a fresh graph.
    # NOTE(review): per the TF1 docs, ops/tensors (and sessions bound to the
    # old graph) created before this call must not be used with the new graph.
    tf.reset_default_graph()
    with tf.Session(config=config_gpu) as sess_loop:
        # compute the loop accuracy
        logger.info("==================loop===================")
        # NOTE(review): not used in the visible code.
        mode_to_evaluate = ["sampling", "greedy", "beam"]
        cpu_pool = create_cpu_pool(args.no_thread, use_process=False)

        # NOTE(review): both batchifiers are reassigned further down
        # (LooperBatchifier(tokenizer, loop_sources, train=...)) before any
        # visible use, and with a different constructor signature; these two
        # constructions look redundant -- confirm which signature is current.
        train_batchifier = LooperBatchifier(tokenizer, generate_new_games=True)
        eval_batchifier = LooperBatchifier(tokenizer, generate_new_games=False)
        # Hard-coded experiment directories and checkpoint ids of the
        # pretrained oracle and guesser whose configs are loaded below.
        oracle_dir = "out/oracle/"
        oracle_checkpoint = "6e0ec7f150b27f46296406853f498af6"
        guesser_dir = "out/guesser/"
        guesser_checkpoint = "c48036b430ebca1c44a25188edb05034"
        # NOTE(review): oracle_config / guesser_config are not used in the
        # visible code, and oracle_network / guesser_network (used below) are
        # not created here. If those networks were built before
        # tf.reset_default_graph(), they belong to the discarded graph --
        # verify where they are constructed.
        oracle_config = get_config_from_xp(oracle_dir, oracle_checkpoint)
        guesser_config = get_config_from_xp(guesser_dir, guesser_checkpoint)

        # Build the QGen network in the new graph from config["model"].
        # NOTE(review): qgen_batchifier_cstor is unused in the visible code,
        # and the sampling graph below is built from qgen_config["model"]
        # instead of config["model"] -- confirm both refer to the same config.
        qgen_network, qgen_batchifier_cstor = create_qgen(
            config["model"], num_words=tokenizer.no_words)
        # Collect every global variable whose name contains "qgen" (the
        # trailing comment is a disabled extra filter) and print their names.
        qgen_var = [v for v in tf.global_variables()
                    if "qgen" in v.name]  # and 'rl_baseline' not in v.name
        for v in qgen_var:
            print(v.name)
        # NOTE(review): a Saver is created for the QGen variables, but no
        # restore() (nor any variable initializer) is run in the visible code.
        qgen_saver = tf.train.Saver(var_list=qgen_var)
        # check that models are correctly loaded
        # NOTE(review): this passes `sess`, not the `sess_loop` opened above.
        # If `sess` was created before tf.reset_default_graph(), it is bound to
        # the old graph and cannot run the ops built here; likely should be
        # sess_loop.
        test_model(sess, testset, cpu_pool=cpu_pool, tokenizer=tokenizer,
                   oracle=oracle_network,
                   guesser=guesser_network,
                   qgen=qgen_network,
                   batch_size=100,
                   logger=logger)

        # create training tools
        # Ask the QGen network which input sources it needs (a collection of
        # strings, joined for logging).
        # NOTE(review): also called with `sess` rather than `sess_loop`.
        loop_sources = qgen_network.get_sources(sess)
        logger.info("Sources: " + ', '.join(loop_sources))

        # NOTE(review): `evaluator` is not used in the visible code.
        evaluator = Evaluator(loop_sources, qgen_network.scope_name, network=qgen_network, tokenizer=tokenizer)

        # Rebuild the batchifiers restricted to the QGen sources: train=True
        # for train_batchifier, train=False for eval_batchifier.
        train_batchifier = LooperBatchifier(tokenizer, loop_sources, train=True)
        eval_batchifier = LooperBatchifier(tokenizer, loop_sources, train=False)

        # Initialize the looper to eval/train the game-simulation
        # Build the QGen sampling ops with max_length taken from
        # loop_config['loop']['max_depth'].
        # NOTE(review): confirm max_depth and max_length use the same unit
        # (e.g. tokens vs. number of questions).
        qgen_network.build_sampling_graph(qgen_config["model"], tokenizer=tokenizer, max_length=loop_config['loop']['max_depth'])
        looper_evaluator = BasicLooper(loop_config,
                                       oracle=oracle_network,
                                       guesser=guesser_network,
                                       qgen=qgen_network,
                                       tokenizer=tokenizer)

        # Iterate the test set in a fixed order (shuffle=False) with padded
        # batches produced by eval_batchifier.
        # NOTE(review): uses the outer `batch_size`, whereas test_model above
        # hard-codes batch_size=100.
        test_iterator = Iterator(testset, pool=cpu_pool,
                                 batch_size=batch_size,
                                 batchifier=eval_batchifier,
                                 shuffle=False,
                                 use_padding=True)
示例#3
0
        else:
            # No QGen registered: fall back to a wrapper that takes the
            # questions from user input.
            qgen_wrapper = QGenUserWrapper(tokenizer)
            logger.info("No QGen was registered >>> use user input")

        # Game simulator wiring the oracle, guesser and QGen wrappers
        # together, with batch_size=1 (presumably one game at a time for
        # interactive play -- confirm against BasicLooper).
        looper_evaluator = BasicLooper(eval_config,
                                       oracle_wrapper=oracle_wrapper,
                                       guesser_wrapper=guesser_wrapper,
                                       qgen_wrapper=qgen_wrapper,
                                       tokenizer=tokenizer,
                                       batch_size=1)

        # NOTE(review): `logs` and `final_val_score` are not used in the
        # visible code, and the "Start training" comment does not match the
        # interactive loop below -- likely leftovers from a training script.
        logs = []
        # Start training
        final_val_score = 0.

        # Batchifier configured with generate_new_games=False.
        batchifier = LooperBatchifier(tokenizer, generate_new_games=False)
        # Interactive session; each iteration appears to play one game (the
        # loop body continues past the end of this view).
        while True:

            # Start new game
            # Re-prompt until the lower-cased answer is "y"/"yes" (continue)
            # or "n"/"no" (terminate the process with status 0).
            # NOTE(review): the bare `exit` helper is installed by the `site`
            # module and is absent under `python -S`; sys.exit(0) is the
            # robust equivalent.
            while True:
                id_str = input(
                    'Do you want to play a new game? (Yes/No) -->  ').lower()
                if id_str == "y" or id_str == "yes": break
                elif id_str == "n" or id_str == "no": exit(0)

            # Pick id image
            # Per the prompt text, -1 requests a random image id.
            image_id = 0
            while True:
                # NOTE(review): int() raises ValueError on non-numeric input;
                # unless that is handled past the end of this view, a typo
                # aborts the session. `id_str` is also reused here to hold an
                # int rather than the yes/no string.
                id_str = int(
                    input(
                        'What is the image id you want to select? (-1 for random id) -->  '