コード例 #1
0
    def loop(self, sess, guesser_wrapper, qgen_wrapper, oracle_wrapper):
        """Run the GuessWhat game loop until ROS shuts down.

        Each iteration waits for one segmentation and one feature set from
        ``self.segmentations`` and ``self.features``, builds a single
        ``Game`` from them, plays it with a fresh ``BasicLooper`` in greedy
        mode, and hands the result to ``self._resolve_choice``.

        A segmentation that has been dequeued is kept until its features
        arrive. Dropping it on a features timeout would make the next
        segmentation pair with the features of the previous image.

        Args:
            sess: Session object forwarded to ``BasicLooper.process``.
            guesser_wrapper: Guesser wrapper forwarded to ``BasicLooper``.
            qgen_wrapper: Question-generator wrapper forwarded to
                ``BasicLooper``.
            oracle_wrapper: Oracle wrapper forwarded to ``BasicLooper``.
        """
        wait_for_img_status_published = False
        # Segmentation already taken from the queue but still waiting for
        # its features; None when nothing is pending.
        pending_seg = None
        while not rospy.is_shutdown():
            # Publish the waiting status once per wait, not on every poll.
            if not wait_for_img_status_published:
                self.status.publish('Waiting for image processing')
                wait_for_img_status_published = True
            try:
                if pending_seg is None:
                    pending_seg = self.segmentations.get(timeout=1)
                feats = self.features.get(timeout=1)
            except Empty:
                # Short timeouts let the shutdown condition be re-checked;
                # any pending segmentation is retained for the next try.
                continue

            seg, pending_seg = pending_seg, None
            wait_for_img_status_published = False

            self.status.publish('Starting new game')

            # The converter receives (index, object) tuples from enumerate.
            objects = [
                self.segmented_image_to_img_obj(indexed_obj)
                for indexed_obj in enumerate(seg.objects)
            ]
            game = Game(id=0,
                        object_id=0,
                        objects=objects,
                        qas=[],
                        image={
                            'id': 0,
                            'width': self.image_dim[0],
                            'height': self.image_dim[1],
                            'coco_url': ''
                        },
                        status='false',
                        which_set=None,
                        image_builder=ImgFeaturesBuilder(feats),
                        crop_builder=None)

            # NOTE(review): the looper is rebuilt for every game, presumably
            # so that games stored by a previous round do not leak into this
            # one -- confirm before hoisting it out of the loop.
            looper = BasicLooper(self.eval_config,
                                 guesser_wrapper=guesser_wrapper,
                                 qgen_wrapper=qgen_wrapper,
                                 oracle_wrapper=oracle_wrapper,
                                 tokenizer=self.tokenizer,
                                 batch_size=1)

            iterator = SingleGameIterator(self.tokenizer, game)
            looper.process(sess, iterator, mode='greedy',
                           store_games=True)  # beam_search, sampling or greedy

            self._resolve_choice(seg.header, looper)
コード例 #2
0
                   logger=logger)

        # create training tools
        loop_sources = qgen_network.get_sources(sess)
        logger.info("Sources: " + ', '.join(loop_sources))

        evaluator = Evaluator(loop_sources, qgen_network.scope_name, network=qgen_network, tokenizer=tokenizer)

        train_batchifier = LooperBatchifier(tokenizer, loop_sources, train=True)
        eval_batchifier = LooperBatchifier(tokenizer, loop_sources, train=False)

        # Initialize the looper to eval/train the game-simulation
        qgen_network.build_sampling_graph(qgen_config["model"], tokenizer=tokenizer, max_length=loop_config['loop']['max_depth'])
        looper_evaluator = BasicLooper(loop_config,
                                       oracle=oracle_network,
                                       guesser=guesser_network,
                                       qgen=qgen_network,
                                       tokenizer=tokenizer)

        test_iterator = Iterator(testset, pool=cpu_pool,
                                 batch_size=batch_size,
                                 batchifier=eval_batchifier,
                                 shuffle=False,
                                 use_padding=True)
        test_score = looper_evaluator.process(sess, test_iterator, mode="sampling")
        logger.info("Test success ratio (Init-Sampling): {}".format(test_score))

        logs = []
        # Start training
        final_val_score = 0.
        for epoch in range(no_epoch):
コード例 #3
0
            sources=qgen_network.get_sources(sess_loop),
            generate=True)
        qgen_wrapper = QGenWrapper(qgen_network,
                                   qgen_batchifier,
                                   tokenizer,
                                   max_length=12,
                                   k_best=20)

        xp_manager.configure_score_tracking("valid_accuracy", max_is_best=True)

        loop_config = {}  # fake config
        loop_config['loop'] = {}
        loop_config['loop']['max_question'] = 5
        game_engine = BasicLooper(loop_config,
                                  oracle_wrapper=oracle_wrapper,
                                  guesser_wrapper=guesser_wrapper,
                                  qgen_wrapper=qgen_wrapper,
                                  tokenizer=tokenizer,
                                  batch_size=64)

        logger.info(">>>  New Objects  <<<")
        compute_qgen_accuracy(sess_loop,
                              trainset,
                              batchifier=train_batchifier,
                              looper=game_engine,
                              mode=mode_to_evaluate,
                              cpu_pool=cpu_pool,
                              batch_size=batch_size,
                              name="ini.new_object",
                              save_path=xp_manager.dir_xp,
                              store_games=True)
コード例 #4
0
        qgen_wrapper = QGenWrapper(qgen_network, qgen_batchifier, tokenizer,
                                   max_length=loop_config['loop']['max_depth'],
                                   k_best=loop_config['loop']['beam_k_best'])

        oracle_split_mode = 1
        # oracle_split_mode = BatchifierSplitMode.from_string(oracle_config["model"]["question"]["input_type"])
        oracle_batchifier = oracle_batchifier_cstor(tokenizer, sources=oracle_network.get_sources(sess), split_mode=oracle_split_mode)
        oracle_wrapper = OracleWrapper(oracle_network, oracle_batchifier, tokenizer)

        guesser_batchifier = guesser_batchifier_cstor(tokenizer, sources=guesser_network.get_sources(sess))
        guesser_wrapper = GuesserWrapper(guesser_network, guesser_batchifier, tokenizer, guesser_listener)

        xp_manager.configure_score_tracking("valid_accuracy", max_is_best=True)
        game_engine = BasicLooper(loop_config,
                                  oracle_wrapper=oracle_wrapper,
                                  guesser_wrapper=guesser_wrapper,
                                  qgen_wrapper=qgen_wrapper,
                                  tokenizer=tokenizer,
                                  batch_size=loop_config["optimizer"]["batch_size"])

        # Compute the initial scores
        if args.test_ini:
            logger.info(">>>-------------- INITIAL SCORE ---------------------<<<")
            # evaluator = Evaluator(loop_sources, qgen_network.scope_name, network=qgen_network, tokenizer=tokenizer)
            cpu_pool = create_cpu_pool(args.no_thread, use_process=False)

            logger.info(">>>  Initial models  <<<")
            test_models(sess, testset, cpu_pool=cpu_pool, batch_size=batch_size*2,
                        oracle=oracle_network, oracle_batchifier=oracle_batchifier,
                        guesser=guesser_network, guesser_batchifier=guesser_batchifier, guesser_listener=guesser_listener,
                        qgen=qgen_network, qgen_batchifier=qgen_batchifier)
コード例 #5
0
        oracle_wrapper = OracleWrapper(oracle_network, tokenizer)
        guesser_wrapper = GuesserWrapper(guesser_network)
        qgen_network.build_sampling_graph(
            qgen_config["model"],
            tokenizer=tokenizer,
            max_length=loop_config['loop']['max_depth'])
        qgen_wrapper = QGenWrapper(qgen_network,
                                   tokenizer,
                                   max_length=loop_config['loop']['max_depth'],
                                   k_best=loop_config['loop']['beam_k_best'])

        looper_evaluator = BasicLooper(
            loop_config,
            oracle_wrapper=oracle_wrapper,
            guesser_wrapper=guesser_wrapper,
            qgen_wrapper=qgen_wrapper,
            tokenizer=tokenizer,
            batch_size=loop_config["optimizer"]["batch_size"])

        # Compute the initial scores
        logger.info(">>>-------------- INITIAL SCORE ---------------------<<<")

        logger.info(">>>  Initial models  <<<")
        test_model(sess,
                   testset,
                   cpu_pool=cpu_pool,
                   tokenizer=tokenizer,
                   oracle=oracle_network,
                   guesser=guesser_network,
                   qgen=qgen_network,
コード例 #6
0
                qgen_config["model"],
                tokenizer=tokenizer,
                max_length=eval_config['loop']['max_depth'])
            qgen_wrapper = QGenWrapper(
                qgen_network,
                tokenizer,
                max_length=eval_config['loop']['max_depth'],
                k_best=eval_config['loop']['beam_k_best'])

        else:
            qgen_wrapper = QGenUserWrapper(tokenizer)
            logger.info("No QGen was registered >>> use user input")

        looper_evaluator = BasicLooper(eval_config,
                                       oracle_wrapper=oracle_wrapper,
                                       guesser_wrapper=guesser_wrapper,
                                       qgen_wrapper=qgen_wrapper,
                                       tokenizer=tokenizer,
                                       batch_size=1)

        logs = []
        # Start training
        final_val_score = 0.

        batchifier = LooperBatchifier(tokenizer, generate_new_games=False)
        while True:

            # Start new game
            while True:
                id_str = input(
                    'Do you want to play a new game? (Yes/No) -->  ').lower()
                if id_str == "y" or id_str == "yes": break
コード例 #7
0
        oracle_wrapper = OracleWrapper(oracle_network, tokenizer)
        guesser_wrapper = GuesserWrapper(guesser_network)
        qgen_network.build_sampling_graph(
            qgen_config["model"],
            tokenizer=tokenizer,
            max_length=loop_config['loop']['max_depth'])
        qgen_wrapper = QGenWrapper(qgen_network,
                                   tokenizer,
                                   max_length=loop_config['loop']['max_depth'],
                                   k_best=loop_config['loop']['beam_k_best'])

        looper_evaluator = BasicLooper(
            loop_config,
            oracle_wrapper=oracle_wrapper,
            guesser_wrapper=guesser_wrapper,
            qgen_wrapper=qgen_wrapper,
            tokenizer=tokenizer,
            batch_size=loop_config["optimizer"]["batch_size"])

        # Compute the initial scores
        logger.info(">>>-------------- INITIAL SCORE ---------------------<<<")

        for split in ["nd_test", "nd_valid", "od_test", "od_valid"]:
            logger.info("Loading dataset split {}".format(split))
            testset = Dataset(args.data_dir, split, "guesswhat_nocaps",
                              image_builder, crop_builder)

            logger.info(">>>  New Games  <<<")
            dump_suffix = "gameplay_{}_{}".format(
                split, "rl" if args.rl_identifier else "sl")