def loop(self, sess, guesser_wrapper, qgen_wrapper, oracle_wrapper):
    """Play GuessWhat games, one per incoming image, until ROS shuts down.

    Each iteration:
    1. pulls one message from ``self.segmentations`` and one from
       ``self.features``, each with ``timeout=1``;
    2. turns every segmented object into a game candidate;
    3. builds a single ``Game`` from those candidates;
    4. plays it in ``'greedy'`` mode through a freshly built
       ``BasicLooper``;
    5. passes the segmentation header and the looper to
       ``self._resolve_choice``.

    Status strings are published on ``self.status``. The "Waiting for
    image processing" message goes out once per waiting period, not on
    every timeout. It is re-armed only after both messages have been
    received.

    Args:
        sess: Forwarded unchanged to ``BasicLooper.process``. Presumably
            a TensorFlow session; confirm against the caller.
        guesser_wrapper: Guesser wrapper handed to ``BasicLooper``.
        qgen_wrapper: Question-generator wrapper handed to ``BasicLooper``.
        oracle_wrapper: Oracle wrapper handed to ``BasicLooper``.
    """
    announced_waiting = False

    while not rospy.is_shutdown():
        if not announced_waiting:
            self.status.publish('Waiting for image processing')
            announced_waiting = True

        try:
            segmentation = self.segmentations.get(timeout=1)
            # NOTE(review): if the get() below times out, the segmentation
            # popped above is discarded by the `continue`. Confirm the
            # producers always enqueue both messages close together, or
            # seg/feature pairs may drift apart.
            features = self.features.get(timeout=1)
        except Empty:
            # Nothing (or only half a pair) arrived in time; poll again
            # without re-announcing that we are waiting.
            continue
        announced_waiting = False

        self.status.publish('Starting new game')

        # Each candidate is built from an (index, segmented_object) pair.
        candidates = [self.segmented_image_to_img_obj(indexed_obj)
                      for indexed_obj in enumerate(segmentation.objects)]
        image_meta = {
            'id': 0,
            'width': self.image_dim[0],
            'height': self.image_dim[1],
            'coco_url': '',
        }
        game = Game(id=0,
                    object_id=0,
                    objects=candidates,
                    qas=[],
                    image=image_meta,
                    status='false',
                    which_set=None,
                    image_builder=ImgFeaturesBuilder(features),
                    crop_builder=None)

        # A fresh looper per game; batch_size=1 because exactly one game
        # is played per image.
        looper = BasicLooper(self.eval_config,
                             guesser_wrapper=guesser_wrapper,
                             qgen_wrapper=qgen_wrapper,
                             oracle_wrapper=oracle_wrapper,
                             tokenizer=self.tokenizer,
                             batch_size=1)
        single_game = SingleGameIterator(self.tokenizer, game)
        # Other mode values: beam_search, sampling or greedy.
        looper.process(sess, single_game, mode='greedy', store_games=True)

        self._resolve_choice(segmentation.header, looper)
logger=logger) # create training tools loop_sources = qgen_network.get_sources(sess) logger.info("Sources: " + ', '.join(loop_sources)) evaluator = Evaluator(loop_sources, qgen_network.scope_name, network=qgen_network, tokenizer=tokenizer) train_batchifier = LooperBatchifier(tokenizer, loop_sources, train=True) eval_batchifier = LooperBatchifier(tokenizer, loop_sources, train=False) # Initialize the looper to eval/train the game-simulation qgen_network.build_sampling_graph(qgen_config["model"], tokenizer=tokenizer, max_length=loop_config['loop']['max_depth']) looper_evaluator = BasicLooper(loop_config, oracle=oracle_network, guesser=guesser_network, qgen=qgen_network, tokenizer=tokenizer) test_iterator = Iterator(testset, pool=cpu_pool, batch_size=batch_size, batchifier=eval_batchifier, shuffle=False, use_padding=True) test_score = looper_evaluator.process(sess, test_iterator, mode="sampling") logger.info("Test success ratio (Init-Sampling): {}".format(test_score)) logs = [] # Start training final_val_score = 0. for epoch in range(no_epoch):
sources=qgen_network.get_sources(sess_loop), generate=True) qgen_wrapper = QGenWrapper(qgen_network, qgen_batchifier, tokenizer, max_length=12, k_best=20) xp_manager.configure_score_tracking("valid_accuracy", max_is_best=True) loop_config = {} # fake config loop_config['loop'] = {} loop_config['loop']['max_question'] = 5 game_engine = BasicLooper(loop_config, oracle_wrapper=oracle_wrapper, guesser_wrapper=guesser_wrapper, qgen_wrapper=qgen_wrapper, tokenizer=tokenizer, batch_size=64) logger.info(">>> New Objects <<<") compute_qgen_accuracy(sess_loop, trainset, batchifier=train_batchifier, looper=game_engine, mode=mode_to_evaluate, cpu_pool=cpu_pool, batch_size=batch_size, name="ini.new_object", save_path=xp_manager.dir_xp, store_games=True)
qgen_wrapper = QGenWrapper(qgen_network, qgen_batchifier, tokenizer, max_length=loop_config['loop']['max_depth'], k_best=loop_config['loop']['beam_k_best']) oracle_split_mode = 1 # oracle_split_mode = BatchifierSplitMode.from_string(oracle_config["model"]["question"]["input_type"]) oracle_batchifier = oracle_batchifier_cstor(tokenizer, sources=oracle_network.get_sources(sess), split_mode=oracle_split_mode) oracle_wrapper = OracleWrapper(oracle_network, oracle_batchifier, tokenizer) guesser_batchifier = guesser_batchifier_cstor(tokenizer, sources=guesser_network.get_sources(sess)) guesser_wrapper = GuesserWrapper(guesser_network, guesser_batchifier, tokenizer, guesser_listener) xp_manager.configure_score_tracking("valid_accuracy", max_is_best=True) game_engine = BasicLooper(loop_config, oracle_wrapper=oracle_wrapper, guesser_wrapper=guesser_wrapper, qgen_wrapper=qgen_wrapper, tokenizer=tokenizer, batch_size=loop_config["optimizer"]["batch_size"]) # Compute the initial scores if args.test_ini: logger.info(">>>-------------- INITIAL SCORE ---------------------<<<") # evaluator = Evaluator(loop_sources, qgen_network.scope_name, network=qgen_network, tokenizer=tokenizer) cpu_pool = create_cpu_pool(args.no_thread, use_process=False) logger.info(">>> Initial models <<<") test_models(sess, testset, cpu_pool=cpu_pool, batch_size=batch_size*2, oracle=oracle_network, oracle_batchifier=oracle_batchifier, guesser=guesser_network, guesser_batchifier=guesser_batchifier, guesser_listener=guesser_listener, qgen=qgen_network, qgen_batchifier=qgen_batchifier)
oracle_wrapper = OracleWrapper(oracle_network, tokenizer) guesser_wrapper = GuesserWrapper(guesser_network) qgen_network.build_sampling_graph( qgen_config["model"], tokenizer=tokenizer, max_length=loop_config['loop']['max_depth']) qgen_wrapper = QGenWrapper(qgen_network, tokenizer, max_length=loop_config['loop']['max_depth'], k_best=loop_config['loop']['beam_k_best']) looper_evaluator = BasicLooper( loop_config, oracle_wrapper=oracle_wrapper, guesser_wrapper=guesser_wrapper, qgen_wrapper=qgen_wrapper, tokenizer=tokenizer, batch_size=loop_config["optimizer"]["batch_size"]) # Compute the initial scores logger.info(">>>-------------- INITIAL SCORE ---------------------<<<") logger.info(">>> Initial models <<<") test_model(sess, testset, cpu_pool=cpu_pool, tokenizer=tokenizer, oracle=oracle_network, guesser=guesser_network, qgen=qgen_network,
qgen_config["model"], tokenizer=tokenizer, max_length=eval_config['loop']['max_depth']) qgen_wrapper = QGenWrapper( qgen_network, tokenizer, max_length=eval_config['loop']['max_depth'], k_best=eval_config['loop']['beam_k_best']) else: qgen_wrapper = QGenUserWrapper(tokenizer) logger.info("No QGen was registered >>> use user input") looper_evaluator = BasicLooper(eval_config, oracle_wrapper=oracle_wrapper, guesser_wrapper=guesser_wrapper, qgen_wrapper=qgen_wrapper, tokenizer=tokenizer, batch_size=1) logs = [] # Start training final_val_score = 0. batchifier = LooperBatchifier(tokenizer, generate_new_games=False) while True: # Start new game while True: id_str = input( 'Do you want to play a new game? (Yes/No) --> ').lower() if id_str == "y" or id_str == "yes": break
oracle_wrapper = OracleWrapper(oracle_network, tokenizer) guesser_wrapper = GuesserWrapper(guesser_network) qgen_network.build_sampling_graph( qgen_config["model"], tokenizer=tokenizer, max_length=loop_config['loop']['max_depth']) qgen_wrapper = QGenWrapper(qgen_network, tokenizer, max_length=loop_config['loop']['max_depth'], k_best=loop_config['loop']['beam_k_best']) looper_evaluator = BasicLooper( loop_config, oracle_wrapper=oracle_wrapper, guesser_wrapper=guesser_wrapper, qgen_wrapper=qgen_wrapper, tokenizer=tokenizer, batch_size=loop_config["optimizer"]["batch_size"]) # Compute the initial scores logger.info(">>>-------------- INITIAL SCORE ---------------------<<<") for split in ["nd_test", "nd_valid", "od_test", "od_valid"]: logger.info("Loading dataset split {}".format(split)) testset = Dataset(args.data_dir, split, "guesswhat_nocaps", image_builder, crop_builder) logger.info(">>> New Games <<<") dump_suffix = "gameplay_{}_{}".format( split, "rl" if args.rl_identifier else "sl")