Example 1

class RLGraphAgent(QAAgent):
    def __init__(self,
                 sess,
                 num_unrolls=1,
                 free_space_network_scope=None,
                 depth_scope=None):
        super(RLGraphAgent, self).__init__(sess, depth_scope)
        self.network = QAPlannerNetwork(constants.RL_GRU_SIZE, 1, num_unrolls)
        self.network.create_net()

        self.num_steps = 0
        self.global_step_id = 0
        self.global_num_steps = 0
        self.num_invalid_actions = 0
        self.global_num_invalid_actions = 0
        self.num_unrolls = num_unrolls
        self.coord_box = np.mgrid[
            int(-constants.STEPS_AHEAD * 1.0 / 2):np.ceil(constants.STEPS_AHEAD * 1.0 / 2),
            1:1 + constants.STEPS_AHEAD].transpose(1, 2, 0) / constants.STEPS_AHEAD

        if constants.USE_NAVIGATION_AGENT:
            self.nav_agent = GraphAgent(sess, True, 1, self.game_state,
                                        free_space_network_scope)
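
For intuition, the coord_box expression above builds a STEPS_AHEAD x STEPS_AHEAD grid of egocentric (x, z) offsets for the teleport actions, normalized by STEPS_AHEAD. A standalone sketch, assuming STEPS_AHEAD = 5 (the value implied by the 5x5 action grid in Example 4):

import numpy as np

STEPS_AHEAD = 5  # assumed value; matches the 5x5 teleport grid in Example 4
coord_box = np.mgrid[int(-STEPS_AHEAD * 1.0 / 2):np.ceil(STEPS_AHEAD * 1.0 / 2),
                     1:1 + STEPS_AHEAD].transpose(1, 2, 0) / STEPS_AHEAD
print(coord_box.shape)  # (5, 5, 2): one normalized (x, z) offset per grid cell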
Example 2

def run():
    global global_t
    global dataset
    global data_ind
    if constants.OBJECT_DETECTION:
        from darknet_object_detection import detector
        detector.setup_detectors(constants.PARALLEL_SIZE)

    os.environ["CUDA_VISIBLE_DEVICES"] = str(constants.GPU_ID)

    try:
        with tf.variable_scope('global_network'):
            if constants.END_TO_END_BASELINE:
                global_network = EndToEndBaselineNetwork()
            else:
                global_network = QAPlannerNetwork(constants.RL_GRU_SIZE, 1, 1)
            global_network.create_net()
        if constants.USE_NAVIGATION_AGENT:
            with tf.variable_scope('nav_global_network') as net_scope:
                free_space_network = FreeSpaceNetwork(constants.GRU_SIZE, 1, 1)
                free_space_network.create_net()
        else:
            net_scope = None

        conv_var_list = [
            v for v in tf.trainable_variables()
            if 'conv' in v.name and 'weight' in v.name and
            (v.get_shape().as_list()[0] != 1 or v.get_shape().as_list()[1] != 1)
        ]
        for var in conv_var_list:
            tf_util.conv_variable_summaries(var, scope=var.name.replace('/', '_')[:-2])
        conv_image_summary = tf.summary.merge_all()

        # prepare session
        sess = tf_util.Session()

        # Instantiate singletons without scope.
        if constants.PREDICT_DEPTH:
            from depth_estimation_network import depth_estimator
            with tf.variable_scope('') as depth_scope:
                depth_estimator = depth_estimator.get_depth_estimator(sess)
        else:
            depth_scope = None

        training_threads = []

        learning_rate_input = tf.placeholder(tf.float32, name='learning_rate')
        grad_applier = RMSPropApplier(learning_rate=learning_rate_input,
                decay=constants.RMSP_ALPHA,
                momentum=0.0,
                epsilon=constants.RMSP_EPSILON,
                clip_norm=constants.GRAD_NORM_CLIP)

        for i in range(constants.PARALLEL_SIZE):
            training_thread = A3CTrainingThread(
                    i, sess,
                    learning_rate_input,
                    grad_applier,
                    constants.MAX_TIME_STEP,
                    net_scope,
                    depth_scope)
            training_threads.append(training_thread)

        if constants.RUN_TEST:
            testing_thread = A3CTestingThread(constants.PARALLEL_SIZE + 1, sess, net_scope, depth_scope)

        sess.run(tf.global_variables_initializer())

        # Initialize pretrained weights after init.
        if constants.PREDICT_DEPTH:
            depth_estimator.load_weights()


        episode_reward_input = tf.placeholder(tf.float32, name='ep_reward')
        episode_length_input = tf.placeholder(tf.float32, name='ep_length')
        exist_answer_correct_input = tf.placeholder(tf.float32, name='exist_ans')
        count_answer_correct_input = tf.placeholder(tf.float32, name='count_ans')
        contains_answer_correct_input = tf.placeholder(tf.float32, name='contains_ans')
        percent_invalid_actions_input = tf.placeholder(tf.float32, name='inv')

        scalar_summaries = [
            tf.summary.scalar("Episode Reward", episode_reward_input),
            tf.summary.scalar("Episode Length", episode_length_input),
            tf.summary.scalar("Percent Invalid Actions", percent_invalid_actions_input),
        ]
        exist_summary = tf.summary.scalar("Answer Correct Existence", exist_answer_correct_input)
        count_summary = tf.summary.scalar("Answer Correct Counting", count_answer_correct_input)
        contains_summary = tf.summary.scalar("Answer Correct Containing", contains_answer_correct_input)
        exist_summary_op = tf.summary.merge(scalar_summaries + [exist_summary])
        count_summary_op = tf.summary.merge(scalar_summaries + [count_summary])
        contains_summary_op = tf.summary.merge(scalar_summaries + [contains_summary])
        summary_ops = [exist_summary_op, count_summary_op, contains_summary_op]

        summary_placeholders = {
            "episode_reward_input": episode_reward_input,
            "episode_length_input": episode_length_input,
            "exist_answer_correct_input": exist_answer_correct_input,
            "count_answer_correct_input": count_answer_correct_input,
            "contains_answer_correct_input": contains_answer_correct_input,
            "percent_invalid_actions_input" : percent_invalid_actions_input,
            }

        if constants.RUN_TEST:
            test_episode_reward_input = tf.placeholder(tf.float32, name='test_ep_reward')
            test_episode_length_input = tf.placeholder(tf.float32, name='test_ep_length')
            test_exist_answer_correct_input = tf.placeholder(tf.float32, name='test_exist_ans')
            test_count_answer_correct_input = tf.placeholder(tf.float32, name='test_count_ans')
            test_contains_answer_correct_input = tf.placeholder(tf.float32, name='test_contains_ans')
            test_percent_invalid_actions_input = tf.placeholder(tf.float32, name='test_inv')

            test_scalar_summaries = [
                tf.summary.scalar("Test Episode Reward", test_episode_reward_input),
                tf.summary.scalar("Test Episode Length", test_episode_length_input),
                tf.summary.scalar("Test Percent Invalid Actions", test_percent_invalid_actions_input),
            ]
            exist_summary = tf.summary.scalar("Test Answer Correct Existence", test_exist_answer_correct_input)
            count_summary = tf.summary.scalar("Test Answer Correct Counting", test_count_answer_correct_input)
            contains_summary = tf.summary.scalar("Test Answer Correct Containing", test_contains_answer_correct_input)
            test_exist_summary_op = tf.summary.merge(test_scalar_summaries + [exist_summary])
            test_count_summary_op = tf.summary.merge(test_scalar_summaries + [count_summary])
            test_contains_summary_op = tf.summary.merge(test_scalar_summaries + [contains_summary])
            test_summary_ops = [test_exist_summary_op, test_count_summary_op, test_contains_summary_op]

        if not constants.DEBUG:
            time_str = py_util.get_time_str()
            summary_writer = tf.summary.FileWriter(os.path.join(constants.LOG_FILE, time_str), sess.graph)
        else:
            summary_writer = None

        # init or load checkpoint with saver
        vars_to_save = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='global_network')
        saver = tf.train.Saver(vars_to_save, max_to_keep=3)
        print('-------------- Looking for checkpoints in ', constants.CHECKPOINT_DIR)
        global_t = tf_util.restore_from_dir(sess, constants.CHECKPOINT_DIR)

        if constants.USE_NAVIGATION_AGENT:
            print('now trying to restore nav model')
            tf_util.restore_from_dir(sess, 'logs/checkpoints/navigation', True)

        sess.graph.finalize()

        for i in range(constants.PARALLEL_SIZE):
            sess.run(training_threads[i].sync)
            training_threads[i].agent.reset()

        times = []

        if not constants.DEBUG and constants.RECORD_FEED_DICT:
            import h5py
            NUM_RECORD = 10000 * len(constants.USED_QUESTION_TYPES)
            time_str = py_util.get_time_str()
            if not os.path.exists('question_data_dump'):
                os.mkdir('question_data_dump')
            dataset = h5py.File('question_data_dump/maps_' + time_str + '.h5', 'w')

            dataset.create_dataset('question_data/existence_answer_placeholder', (NUM_RECORD, 1), dtype=np.int32)
            dataset.create_dataset('question_data/counting_answer_placeholder', (NUM_RECORD, 1), dtype=np.int32)

            dataset.create_dataset('question_data/question_type_placeholder', (NUM_RECORD, 1), dtype=np.int32)
            dataset.create_dataset('question_data/question_object_placeholder', (NUM_RECORD, 1), dtype=np.int32)
            dataset.create_dataset('question_data/question_container_placeholder', (NUM_RECORD, 21), dtype=np.int32)

            dataset.create_dataset('question_data/pose_placeholder', (NUM_RECORD, 3), dtype=np.int32)
            dataset.create_dataset('question_data/image_placeholder', (NUM_RECORD, 300, 300, 3), dtype=np.uint8)
            dataset.create_dataset('question_data/map_mask_placeholder',
                    (NUM_RECORD, constants.SPATIAL_MAP_HEIGHT, constants.SPATIAL_MAP_WIDTH, 27), dtype=np.uint8)
            dataset.create_dataset('question_data/meta_action_placeholder', (NUM_RECORD, 7), dtype=np.int32)
            dataset.create_dataset('question_data/possible_move_placeholder', (NUM_RECORD, 31), dtype=np.int32)

            dataset.create_dataset('question_data/answer_weight', (NUM_RECORD, 1), dtype=np.int32)
            dataset.create_dataset('question_data/taken_action', (NUM_RECORD, 32), dtype=np.int32)
            dataset.create_dataset('question_data/new_episode', (NUM_RECORD, 1), dtype=np.int32)

        time_lock = threading.Lock()

        data_ind = 0

        def train_function(parallel_index):
            global global_t
            global dataset
            global data_ind
            print('----------------------------------------thread', parallel_index, 'global_t', global_t)
            training_thread = training_threads[parallel_index]
            last_global_t = global_t
            last_global_t_image = global_t

            while global_t < constants.MAX_TIME_STEP:
                diff_global_t, ep_length, ep_reward, num_unrolls, feed_dict = training_thread.process(global_t, summary_writer,
                        summary_ops, summary_placeholders)
                time_lock.acquire()

                if not constants.DEBUG and constants.RECORD_FEED_DICT:
                    print('    NEW ENTRY: %d %s' % (data_ind, training_thread.agent.game_state.scene_name))
                    dataset['question_data/new_episode'][data_ind] = 1
                    for k, v in feed_dict.items():
                        key = 'question_data/' + k.name.split('/')[-1].split(':')[0]
                        if any([s in key for s in {
                            'gru_placeholder', 'num_unrolls',
                            'reward_placeholder', 'td_placeholder',
                            'learning_rate', 'phi_hat_prev_placeholder',
                            'next_map_mask_placeholder', 'next_pose_placeholder',
                            'episode_length_placeholder', 'question_count_placeholder',
                            'supervised_action_labels', 'supervised_action_weights_sigmoid',
                            'supervised_action_weights', 'question_direction_placeholder',
                        }]):
                            continue
                        v = np.ascontiguousarray(v)
                        if v.shape[0] == 1:
                            v = v[0, ...]
                        if len(v.shape) == 1 and num_unrolls > 1:
                            v = v[:, np.newaxis]
                        if 'map_mask_placeholder' in key:
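                            # Rescale the normalized coordinate channels to small non-negative
                            # integers so they survive uint8 storage (undone at load time in Example 5).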
                            v[:, :, :, :2] *= constants.STEPS_AHEAD
                            v[:, :, :, :2] += 2
                        data_loc = dataset[key][data_ind:data_ind + num_unrolls, ...]
                        data_len = data_loc.shape[0]
                        dataset[key][data_ind:data_ind + num_unrolls, ...] = v[:data_len, ...]
                    dataset.flush()

                    data_ind += num_unrolls
                    if data_ind >= NUM_RECORD:
                        # Everything is done
                        dataset.close()
                        raise Exception('Fake exception to exit the process. Don\'t worry, everything is fine.')
                if ep_length > 0:
                    times.append((ep_length, ep_reward))
                    print('Num episodes', len(times), 'Episode means', np.mean(times, axis=0))
                time_lock.release()
                global_t += diff_global_t
                # periodically save checkpoints to disk
                if not (constants.DEBUG or constants.RECORD_FEED_DICT) and parallel_index == 0:
                    if global_t - last_global_t_image > 10000:
                        print('Ran conv image summary')
                        summary_im_str = sess.run(conv_image_summary)
                        summary_writer.add_summary(summary_im_str, global_t)
                        last_global_t_image = global_t
                    if global_t - last_global_t > 10000:
                        print('Save checkpoint at timestamp %d' % global_t)
                        tf_util.save(saver, sess, constants.CHECKPOINT_DIR, global_t)
                        last_global_t = global_t

        def test_function():
            global global_t
            last_global_t_test = 0
            while global_t < constants.MAX_TIME_STEP:
                time.sleep(1)
                if global_t - last_global_t_test > 10000:
                    # RUN TEST
                    sess.run(testing_thread.sync)
                    from game_state import QuestionGameState
                    if testing_thread.agent.game_state is None:
                        testing_thread.agent.game_state = QuestionGameState(sess=sess)
                    for q_type in constants.USED_QUESTION_TYPES:
                        answers_correct = 0.0
                        ep_lengths = 0.0
                        ep_rewards = 0.0
                        invalid_percents = 0.0
                        rows = list(range(len(testing_thread.agent.game_state.test_datasets[q_type])))
                        random.shuffle(rows)
                        rows = rows[:16]
                        print('()()()()()()()rows', rows)
                        for rr, row in enumerate(rows):
                            (answer_correct, answer, gt_answer, ep_length,
                             ep_reward, invalid_percent, scene_num, seed,
                             required_interaction) = testing_thread.process((row, q_type))
                            answers_correct += int(answer_correct)
                            ep_lengths += ep_length
                            ep_rewards += ep_reward
                            invalid_percents += invalid_percent
                            print('############################### TEST ITERATION ##################################')
                            print('ep ', (rr + 1))
                            print('average correct', answers_correct / (rr + 1))
                            print('#################################################################################')
                        answers_correct /= len(rows)
                        ep_lengths /= len(rows)
                        ep_rewards /= len(rows)
                        invalid_percents /= len(rows)

                        # Write the summary
                        test_summary_str = sess.run(test_summary_ops[q_type], feed_dict={
                            test_episode_reward_input: ep_rewards,
                            test_episode_length_input: ep_lengths,
                            test_exist_answer_correct_input: answers_correct,
                            test_count_answer_correct_input: answers_correct,
                            test_contains_answer_correct_input: answers_correct,
                            test_percent_invalid_actions_input: invalid_percents,
                        })
                        summary_writer.add_summary(test_summary_str, global_t)
                        summary_writer.flush()
                        last_global_t_test = global_t
                    testing_thread.agent.game_state.env.stop()
                    testing_thread.agent.game_state = None



        train_threads = []
        for i in range(constants.PARALLEL_SIZE):
            train_threads.append(threading.Thread(target=train_function, args=(i,)))
            train_threads[i].daemon = True

        for t in train_threads:
            t.start()

        if constants.RUN_TEST:
            test_thread = threading.Thread(target=test_function)
            test_thread.daemon = True
            test_thread.start()

        for t in train_threads:
            t.join()

        if constants.RUN_TEST:
            test_thread.join()

        if not constants.DEBUG:
            if not os.path.exists(constants.CHECKPOINT_DIR):
                os.makedirs(constants.CHECKPOINT_DIR)
            saver.save(sess, os.path.join(constants.CHECKPOINT_DIR, 'checkpoint'), global_step=global_t)
            summary_writer.close()
            print('Saved.')

    except KeyboardInterrupt:
        print('KeyboardInterrupt received. Stopping and saving.')
    except Exception:
        import traceback
        traceback.print_exc()
    finally:
        if not constants.DEBUG:
            print('Now saving data. Please wait')
            tf_util.save(saver, sess, constants.CHECKPOINT_DIR, global_t)
            summary_writer.close()
            print('Saved.')
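
run() above follows the usual A3C host pattern: daemon worker threads advance a shared global step counter under a lock until the step budget is exhausted. A minimal, self-contained sketch of just that skeleton (process() and the constants here are stand-ins, not the project's API):

import threading

MAX_TIME_STEP = 100
global_t = 0
time_lock = threading.Lock()

def process():
    # Stand-in for one training rollout; returns the number of steps taken.
    return 1

def train_function(index):
    global global_t
    while global_t < MAX_TIME_STEP:
        diff = process()
        with time_lock:
            global_t += diff

threads = [threading.Thread(target=train_function, args=(i,), daemon=True)
           for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print('final global_t', global_t)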
Example 3

def main():
    if constants.OBJECT_DETECTION:
        from darknet_object_detection import detector
        detector.setup_detectors(constants.PARALLEL_SIZE)

    with tf.device('/gpu:' + str(constants.GPU_ID)):
        with tf.variable_scope('global_network'):
            if constants.END_TO_END_BASELINE:
                global_network = EndToEndBaselineNetwork()
            else:
                global_network = QAPlannerNetwork(constants.RL_GRU_SIZE, 1, 1)
            global_network.create_net()
        if constants.USE_NAVIGATION_AGENT:
            with tf.variable_scope('nav_global_network') as net_scope:
                free_space_network = FreeSpaceNetwork(constants.GRU_SIZE, 1, 1)
                free_space_network.create_net()
        else:
            net_scope = None

        # prepare session
        sess = tf_util.Session()

        if constants.PREDICT_DEPTH:
            from depth_estimation_network import depth_estimator
            with tf.variable_scope('') as depth_scope:
                depth_estimator = depth_estimator.get_depth_estimator(sess)
        else:
            depth_scope = None

        sess.run(tf.global_variables_initializer())

        # Initialize pretrained weights after init.
        if constants.PREDICT_DEPTH:
            depth_estimator.load_weights()

        testing_threads = []

        for i in range(constants.PARALLEL_SIZE):
            testing_thread = A3CTestingThread(i, sess, net_scope, depth_scope)
            testing_threads.append(testing_thread)

        tf_util.restore_from_dir(sess, constants.CHECKPOINT_DIR, True)

        if constants.USE_NAVIGATION_AGENT:
            print('now trying to restore nav model')
            tf_util.restore_from_dir(
                sess, os.path.join(constants.CHECKPOINT_PREFIX, 'navigation'),
                True)

    sess.graph.finalize()

    question_types = constants.USED_QUESTION_TYPES
    rows = []
    for q_type in question_types:
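        # Note: testing_thread is the last thread built above; all threads share the same test datasets.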
        curr_rows = list(
            range(len(testing_thread.agent.game_state.test_datasets[q_type])))
        #curr_rows = list(range(8))
        rows.extend(list(zip(curr_rows, [q_type] * len(curr_rows))))

    random.shuffle(rows)

    answers_correct = []
    ep_lengths = []
    ep_rewards = []
    invalid_percents = []
    time_lock = threading.Lock()
    if not os.path.exists(constants.LOG_FILE):
        os.makedirs(constants.LOG_FILE)
    out_file = open(
        constants.LOG_FILE + '/results_' + constants.TEST_SET + '_' +
        py_util.get_time_str() + '.csv', 'w')
    out_file.write(constants.LOG_FILE + '\n')
    out_file.write(
        'question_type, answer_correct, answer, gt_answer, episode_length, invalid_action_percent, scene number, seed, required_interaction\n'
    )

    def test_function(thread_ind):
        testing_thread = testing_threads[thread_ind]
        sess.run(testing_thread.sync)
        #from game_state import QuestionGameState
        #if testing_thread.agent.game_state is None:
        #testing_thread.agent.game_state = QuestionGameState(sess=sess)
        while len(rows) > 0:
            time_lock.acquire()
            if len(rows) == 0:
                time_lock.release()
                break
            row = rows.pop()
            time_lock.release()

            (answer_correct, answer, gt_answer, ep_length, ep_reward,
             invalid_percent, scene_num, seed,
             required_interaction) = testing_thread.process(row)
            question_type = row[1] + 1

            time_lock.acquire()
            output_str = (
                '%d, %d, %d, %d, %d, %f, %d, %d, %d\n' %
                (question_type, answer_correct, answer, gt_answer, ep_length,
                 invalid_percent, scene_num, seed, required_interaction))
            out_file.write(output_str)
            out_file.flush()
            answers_correct.append(int(answer_correct))
            ep_lengths.append(ep_length)
            ep_rewards.append(ep_reward)
            invalid_percents.append(invalid_percent)
            print('###############################')
            print('ep ', row)
            print('num episodes', len(answers_correct))
            print('average correct', np.mean(answers_correct))
            print('invalid percents', np.mean(invalid_percents),
                  np.median(invalid_percents))
            print('###############################')
            time_lock.release()

    test_threads = []
    for i in range(constants.PARALLEL_SIZE):
        test_threads.append(threading.Thread(target=test_function, args=(i, )))

    for t in test_threads:
        t.start()

    for t in test_threads:
        t.join()

    out_file.close()
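
main() hands work to its threads through a bare list guarded by a lock. The same idiom can be written with queue.Queue, which makes the empty-check and pop atomic so no manual acquire/release is needed; a small sketch with illustrative names:

import queue
import threading

work = queue.Queue()
for item in range(10):  # stand-ins for the (row, question_type) pairs
    work.put(item)

def worker(ind):
    while True:
        try:
            row = work.get_nowait()  # atomic pop; raises once the queue is drained
        except queue.Empty:
            return
        print('thread', ind, 'processed row', row)

threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()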
Example 4

class RLGraphAgent(QAAgent):
    def __init__(self,
                 sess,
                 num_unrolls=1,
                 free_space_network_scope=None,
                 depth_scope=None):
        super(RLGraphAgent, self).__init__(sess, depth_scope)
        self.network = QAPlannerNetwork(constants.RL_GRU_SIZE, 1, num_unrolls)
        self.network.create_net()

        self.num_steps = 0
        self.global_step_id = 0
        self.global_num_steps = 0
        self.num_invalid_actions = 0
        self.global_num_invalid_actions = 0
        self.num_unrolls = num_unrolls
        self.coord_box = np.mgrid[
            int(-constants.STEPS_AHEAD * 1.0 / 2):np.ceil(constants.STEPS_AHEAD * 1.0 / 2),
            1:1 + constants.STEPS_AHEAD].transpose(1, 2, 0) / constants.STEPS_AHEAD

        if constants.USE_NAVIGATION_AGENT:
            self.nav_agent = GraphAgent(sess, True, 1, self.game_state,
                                        free_space_network_scope)

    def inference(self):
        outputs = self.get_next_output()
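        # Output indices follow the fetch list in get_next_output below.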
        self.gru_state = outputs[0]
        self.v = outputs[1][0, ...]
        self.pi = outputs[2][0, ...]
        self.possible_moves_pred = outputs[3][0, ...]
        if self.game_state.question_type_ind == 1:
            self.answer = outputs[5][0, ...]
        else:
            self.answer = outputs[4][0, ...]
        self.question = outputs[6]
        self.memory_crops = outputs[7]
        self.possible_moves_weights = outputs[8]
        self.actions = outputs[9]
        self.teleport_input_crops = outputs[10]
        self.pi_logits = outputs[-1]

    def get_next_output(self):
        image = self.game_state.s_t[np.newaxis, np.newaxis, ...]

        pose_shifted = np.array(self.pose[:3])
        pose_shifted[0] -= self.bounds[0]
        pose_shifted[1] -= self.bounds[1]

        self.map_mask_padded = np.pad(
            self.spatial_map.memory[:, :, 1:],
            ((0, constants.SPATIAL_MAP_HEIGHT - self.bounds[3]),
             (0, constants.SPATIAL_MAP_WIDTH - self.bounds[2]), (0, 0)),
            'constant',
            constant_values=0).copy()

        self.feed_dict = {
            self.network.image_placeholder: image,
            self.network.pose_placeholder: pose_shifted[np.newaxis, np.newaxis, :],
            self.network.map_mask_placeholder: self.map_mask_padded[np.newaxis, np.newaxis, ...],
            self.network.gru_placeholder: self.gru_state,
            self.network.question_object_placeholder: np.array(self.game_state.object_target)[np.newaxis, np.newaxis],
            self.network.question_container_placeholder: self.game_state.container_target[np.newaxis, np.newaxis, :],
            self.network.question_direction_placeholder: self.game_state.direction_target[np.newaxis, np.newaxis, :],
            self.network.question_type_placeholder: np.array(self.game_state.question_type_ind)[np.newaxis, np.newaxis],
            self.network.existence_answer_placeholder: np.array(self.game_state.answer)[np.newaxis, np.newaxis],
            self.network.counting_answer_placeholder: np.array(self.game_state.answer)[np.newaxis, np.newaxis],
            self.network.answer_weight: np.array(int(self.game_state.can_end))[np.newaxis, np.newaxis],
            self.network.possible_move_placeholder: np.array(self.possible_moves)[np.newaxis, np.newaxis, :],
            self.network.meta_action_placeholder: self.last_meta_action[np.newaxis, np.newaxis, :],
            self.network.taken_action: self.last_action_one_hot[np.newaxis, :],
            self.network.episode_length_placeholder: np.array([self.num_steps / constants.MAX_EPISODE_LENGTH])[np.newaxis, :],
            self.network.question_count_placeholder: np.array([self.question_count])[np.newaxis, :],
        }
        if self.num_unrolls is None:
            self.feed_dict[self.network.num_unrolls] = 1

        self.prev_map_mask = self.spatial_map.memory.copy()

        outputs = self.sess.run([
            self.network.gru_state,
            self.network.v,
            self.network.pi,
            self.network.possible_moves,
            self.network.existence_answer,
            self.network.counting_answer,
            self.network.question_object_one_hot,
            self.network.memory_crops_rot,
            self.network.possible_moves_weights,
            self.network.actions,
            self.network.teleport_input_crops,
            self.network.pi_logits,
        ], feed_dict=self.feed_dict)
        return outputs

    def get_reward(self):
        self.reward = self.game_state.reward
        self.reward += self.new_coverage * 1.0 / (constants.STEPS_AHEAD**2 + 1)
        if constants.DEBUG:
            print('coverage %.2f - (%.3f%%)  reward %.3f' %
                  (float(self.coverage),
                   float(self.coverage * 100.0 / self.max_coverage),
                   self.reward))
        if self.game_state.can_end and not self.prev_can_end:
            self.reward = 10
        if self.game_state.can_end and self.prev_can_end:
            if self.game_state.question_type_ind in {0, 2}:
                self.reward = -1
        if self.terminal:
            if constants.DEBUG:
                print('answering', self.answer)
            if self.game_state.question_type_ind != 1:
                answer = self.answer[0] > 0.5
            else:
                answer = np.argmax(self.answer)

            if answer == self.game_state.answer and self.game_state.can_end:
                self.reward = 10
                print('Answer correct!!!!!')
            else:
                self.reward = -30  # Average is -10 for 50% chance, -20 for 25% chance
                print('Answer incorrect :( :( :( :( :(')
        if self.num_steps >= constants.MAX_EPISODE_LENGTH:
            self.reward = -30
            self.terminal = True
        self.prev_can_end = self.game_state.can_end
        return np.clip(self.reward / 10, -3, 1), self.terminal

    def reset(self, seed=None, test_ind=None):
        if self.game_state.env is not None:
            self.game_state.reset(seed=seed, test_ind=test_ind)
            self.bounds = self.game_state.bounds
        self.end_points = np.zeros_like(self.game_state.graph.memory[:, :, 0])
        for end_point in self.game_state.end_point:
            self.end_points[end_point[1] - self.game_state.graph.yMin,
                            end_point[0] - self.game_state.graph.xMin] = 1

        self.gru_state = np.zeros((1, constants.RL_GRU_SIZE))
        self.pose = self.game_state.pose
        self.prev_pose = self.pose
        self.visited_poses = set()
        self.reward = 0
        self.num_steps = 0
        self.question_count = 0
        self.num_invalid_actions = 0
        self.prev_can_end = False
        dilation_kernel = np.ones(
            (constants.SCENE_PADDING * 2 + 1, constants.SCENE_PADDING))
        free_space = self.game_state.xray_graph.memory.copy().squeeze()
        if len(free_space.shape) == 3:
            free_space = free_space[:, :, 0]
        free_space[free_space == graph_obj.MAX_WEIGHT] = 0
        free_space[free_space > 1] = 1
        self.dilation = np.zeros_like(free_space)
        for _ in range(2):
            self.dilation += scipy.ndimage.morphology.binary_dilation(
                free_space, structure=dilation_kernel).astype(int)
            dilation_kernel = np.rot90(dilation_kernel)
        self.dilation[self.dilation > 1] = 1
        self.max_coverage = np.sum(self.dilation)
        # Rows are:
        # 0 - Map weights (not fed to decision network)
        # 1 and 2 - meshgrid
        # 3 - coverage
        # 4 - teleport locations
        # 5 - free space map
        # 6 - visited locations
        # 7+ - object location
        if constants.USE_NAVIGATION_AGENT:
            if self.nav_agent is not None:
                self.nav_agent.reset(self.game_state.scene_name)
        self.spatial_map = graph_obj.Graph('layouts/%s-layout.npy' %
                                           self.game_state.scene_name,
                                           use_gt=True,
                                           construct_graph=False)
        self.spatial_map.memory = np.concatenate(
            (np.zeros((self.bounds[3], self.bounds[2], 7)),
             self.game_state.graph.memory[:, :, 1:].copy()),
            axis=2)

        self.coverage = 0
        self.terminal = False
        self.new_coverage = 0
        self.last_meta_action = np.zeros(7)
        self.last_action_one_hot = np.zeros(
            self.network.pi.get_shape().as_list()[-1])
        self.global_step_id += 1
        self.update_spatial_map({'action': 'Initialize'})

        # For drawing
        self.forward_pred = np.zeros((3, 3, 3))
        self.next_memory_crops_rot = np.zeros((3, 3, 3))

    def update_spatial_map(self, action):
        if action['action'] == 'Teleport':
            self.spatial_map.memory[action['z'] - self.bounds[1],
                                    action['x'] - self.bounds[0], 4] = 1

        self.pose = self.game_state.pose
        self.spatial_map.memory[:, :, 1:3] = 0
        self.spatial_map.update_graph((self.coord_box, 0),
                                      self.pose,
                                      rows=[1, 2])

        if not constants.USE_NAVIGATION_AGENT:
            path = self.game_state.xray_graph.get_shortest_path(
                self.prev_pose, self.pose)[1]
            if action['action'] == 'Teleport':
                self.num_steps += max(1, len(path) - 1)

            self.new_coverage = -1
            full_new_coverage = 0
            for pose in path[::-1]:
                patch = self.spatial_map.get_graph_patch(pose)
                patch_coverage = 1 - patch[0][:, :, 3]
                patch_coverage = np.sum(patch_coverage) + (1 - patch[1][3])
                if self.new_coverage == -1:
                    self.new_coverage = patch_coverage
                full_new_coverage += patch_coverage
                self.spatial_map.update_graph((np.ones(
                    (constants.STEPS_AHEAD, constants.STEPS_AHEAD, 1)), 1),
                                              pose,
                                              rows=[3])
                self.spatial_map.memory[pose[1] - self.bounds[1],
                                        pose[0] - self.bounds[0], 6] = 1
            self.new_coverage = max(0, self.new_coverage)
            if full_new_coverage > 0:
                # Make sure that new_coverage is positive if coverage increased at all.
                self.new_coverage += 0.00001
            self.coverage += full_new_coverage

            # Do GT occupancy map stuff.
            free_space = self.game_state.xray_graph.memory.copy().squeeze()
            if len(free_space.shape) == 3:
                free_space = free_space[:, :, 0]
            free_space[free_space == graph_obj.MAX_WEIGHT] = 0
            free_space[free_space > 1] = 1
            self.spatial_map.memory[:, :, 5] = free_space * self.spatial_map.memory[:, :, 3]
            #self.game_state.graph.memory[:, :, 0] = (1 + self.spatial_map.memory[:, :, 3] * (graph_obj.EPSILON +
            #(1 - self.spatial_map.memory[:, :, 5]) * (graph_obj.MAX_WEIGHT - graph_obj.EPSILON - 1)))

        else:
            self.spatial_map.memory[:, :, 3] = self.game_state.graph.memory[:, :, 0] > 1
            new_coverage = np.sum(self.spatial_map.memory[:, :, 3])
            self.new_coverage = new_coverage - self.coverage
            self.coverage = new_coverage
            self.spatial_map.memory[:, :, 5] = 1 - self.nav_agent.occupancy
            for pose in self.nav_agent.visited_spots:
                self.spatial_map.memory[pose[1] - self.bounds[1],
                                        pose[0] - self.bounds[0], 6] = 1

        if constants.RECORD_FEED_DICT:
            self.spatial_map.memory[:, :, 7:] = self.game_state.xray_graph.memory[:, :, 1:].copy()
        else:
            self.spatial_map.memory[:, :, 7:] = self.game_state.graph.memory[:, :, 1:].copy()

        patch = self.game_state.xray_graph.get_graph_patch(
            self.pose)[0].reshape((constants.STEPS_AHEAD**2, -1))
        patch = patch[:, 0]
        patch[patch == graph_obj.MAX_WEIGHT] = 0
        patch[patch > 1] = 1
        open_success = not self.game_state.get_action({'action': 'OpenObject'})[2]
        close_success = not self.game_state.get_action({'action': 'CloseObject'})[2]
        self.possible_moves = np.concatenate(
            (patch, [1], [1], [self.pose[3] != 330], [self.pose[3] != 60],
             [open_success], [close_success]))

    def step(self, action):
        self.prev_pose = self.pose
        self.visited_poses.add(self.pose)

        self.last_meta_action = np.zeros(7)
        if action['action'] == 'Teleport':
            self.last_meta_action[0] = 1
        elif action['action'] == 'RotateLeft':
            self.last_meta_action[1] = 1
        elif action['action'] == 'RotateRight':
            self.last_meta_action[2] = 1
        elif action['action'] == 'LookUp':
            self.last_meta_action[3] = 1
        elif action['action'] == 'LookDown':
            self.last_meta_action[4] = 1
        elif action['action'] == 'OpenObject':
            self.last_meta_action[5] = 1
        elif action['action'] == 'CloseObject':
            self.last_meta_action[6] = 1

        if action['action'] == 'Answer':
            self.terminal = True
        else:
            if (not constants.USE_NAVIGATION_AGENT
                    or action['action'] != 'Teleport'):
                self.game_state.step(action)
                if not self.game_state.event.metadata['lastActionSuccess']:
                    self.num_invalid_actions += 1
                    self.global_num_invalid_actions += 1
                if action['action'] != 'Teleport':
                    self.num_steps += 1
                    self.global_num_steps += 1
            else:
                num_steps, num_invalid_actions = self.nav_agent.goto(
                    action, self.global_step_id)
                # Still need to step to get reward etc.
                self.global_num_steps += num_steps
                self.global_step_id += num_steps
                self.num_steps += num_steps
                self.game_state.step(action)
                self.num_invalid_actions += num_invalid_actions
                self.global_num_invalid_actions += num_invalid_actions
            if constants.USE_NAVIGATION_AGENT and 'Rotate' in action['action']:
                self.nav_agent.inference()

            self.update_spatial_map(action)

        self.global_step_id += 1

    def get_action(self, action_ind):
        if action_ind < constants.STEPS_AHEAD**2:
            # Teleport
            action_x = action_ind % constants.STEPS_AHEAD - constants.STEPS_AHEAD // 2
            action_z = action_ind // constants.STEPS_AHEAD + 1
            x_shift = 0
            z_shift = 0
            if self.pose[2] == 0:
                x_shift = action_x
                z_shift = action_z
            elif self.pose[2] == 1:
                x_shift = action_z
                z_shift = -action_x
            elif self.pose[2] == 2:
                x_shift = -action_x
                z_shift = -action_z
            elif self.pose[2] == 3:
                x_shift = -action_z
                z_shift = action_x
            action_x = self.pose[0] + x_shift
            action_z = self.pose[1] + z_shift
            action = {
                'action': 'Teleport',
                'x': action_x,
                'z': action_z,
                'rotation': self.pose[2] * 90,
            }
        else:
            # Rotate/Look/Open/Close/Answer
            action_ind -= constants.STEPS_AHEAD**2
            if action_ind == 0:
                action = {'action': 'RotateLeft'}
            elif action_ind == 1:
                action = {'action': 'RotateRight'}
            elif action_ind == 2:
                action = {'action': 'LookUp'}
            elif action_ind == 3:
                action = {'action': 'LookDown'}
            elif action_ind == 4:
                action = {'action': 'OpenObject'}
            elif action_ind == 5:
                action = {'action': 'CloseObject'}
            elif action_ind == 6:
                action = {'action': 'Answer', 'value': self.answer}
            else:
                raise Exception('something very wrong happened')
        return action

    def draw_state(self, return_list=False, action=None):
        if not constants.DRAWING:
            return
        # Rows are:
        # 0 - Map weights (not fed to decision network)
        # 1 and 2 - meshgrid
        # 3 - coverage
        # 4 - teleport locations
        # 5 - free space map
        # 6 - visited locations
        # 7+ - object location
        from utils import drawing
        curr_image = self.game_state.detection_image.copy()
        state_image = self.game_state.draw_state()

        action_hist = np.zeros((3, 3, 3))
        pi = self.pi.copy()
        if constants.STEPS_AHEAD == 5:
            action_hist = np.concatenate((pi, np.zeros(3)))
            action_hist = action_hist.reshape(7, 5)
        elif constants.STEPS_AHEAD == 1:
            action_hist = np.concatenate((pi, np.zeros(1)))
            action_hist = action_hist.reshape(3, 3)

        flat_action_size = max(len(pi), 100)
        flat_action_hist = np.zeros((flat_action_size, flat_action_size))
        for ii, flat_action_i in enumerate(pi):
            bar_height = max(int(np.round(flat_action_i * flat_action_size)), 1)
            col_start = int(ii * flat_action_size / len(pi))
            col_end = int((ii + 1) * flat_action_size / len(pi))
            flat_action_hist[:bar_height, col_start:col_end] = ii + 1
        flat_action_hist = np.flipud(flat_action_hist)

        # Answer histogram
        ans = self.answer
        if len(ans) == 1:
            ans = [1 - ans[0], ans[0]]
        ans_size = max(len(ans), 100)
        ans_hist = np.zeros((ans_size, ans_size))
        for ii, ans_i in enumerate(ans):
            bar_height = max(int(np.round(ans_i * ans_size)), 1)
            col_start = int(ii * ans_size / len(ans))
            col_end = int((ii + 1) * ans_size / len(ans))
            ans_hist[:bar_height, col_start:col_end] = ii + 1
        ans_hist = np.flipud(ans_hist)

        dil = np.flipud(self.dilation)
        dil[0, 0] = 4
        coverage = int(self.coverage * 100 / self.max_coverage)

        possible = np.zeros((3, 3, 3))
        possible_pred = np.zeros((3, 3, 3))
        if constants.STEPS_AHEAD == 5:
            possible = self.possible_moves.copy()
            possible = np.concatenate((possible, np.zeros(4)))
            possible = possible.reshape(constants.STEPS_AHEAD + 2,
                                        constants.STEPS_AHEAD)

            possible_pred = self.possible_moves_pred.copy()
            possible_pred = np.concatenate((possible_pred, np.zeros(4)))
            possible_pred = possible_pred.reshape(constants.STEPS_AHEAD + 2,
                                                  constants.STEPS_AHEAD)

        elif constants.STEPS_AHEAD == 1:
            possible = self.possible_moves.copy()
            possible = np.concatenate((possible, np.zeros(2)))
            possible = possible.reshape(3, 3)

            possible_pred = self.possible_moves_pred.copy()
            possible_pred = np.concatenate((possible_pred, np.zeros(2)))
            possible_pred = possible_pred.reshape(3, 3)

        if self.game_state.question_type_ind in {2, 3}:
            obj_mem = self.spatial_map.memory[
                :, :, 7 + self.game_state.question_target[1]].copy()
            obj_mem += self.spatial_map.memory[
                :, :, 7 + self.game_state.object_target] * 2
        else:
            obj_mem = self.spatial_map.memory[
                :, :, 7 + self.game_state.object_target].copy()
        obj_mem[0, 0] = 2

        memory_map = np.flipud(self.spatial_map.memory[:, :, 7:].copy())
        curr_objs = np.argmax(memory_map, axis=2)

        gt_objs = np.flipud(
            np.argmax(self.game_state.xray_graph.memory[:, :, 1:], 2))
        curr_objs[0, 0] = np.max(gt_objs)
        memory_crop = self.memory_crops[0, ...].copy()
        memory_crop_cov = np.argmax(np.flipud(memory_crop), axis=2)

        gt_semantic_crop = np.flipud(
            np.argmax(self.next_memory_crops_rot, axis=2))

        images = [
            curr_image,
            state_image,
            dil + np.max(np.flipud(self.spatial_map.memory[:, :, 3:5]) *
                         np.array([1, 3]),
                         axis=2),
            memory_crop_cov,
            ans_hist,
            flat_action_hist,
            np.flipud(action_hist),
            np.flipud(possible),
            np.flipud(possible_pred),
            gt_objs,
            curr_objs,
            np.flipud(obj_mem),
        ]
        if isinstance(action, int):
            action = self.game_state.get_action(action)[0]
        action_str = game_util.get_action_str(action)
        if action_str == 'Answer':
            if self.game_state.question_type_ind != 1:
                action_str += ' ' + str(self.answer > 0.5)
            else:
                action_str += ' ' + str(np.argmax(self.answer))
        if self.game_state.question_type_ind == 0:
            question_str = '%03d S %s Ex Q: %s A: %s' % (
                self.num_steps, self.game_state.scene_name[9:],
                constants.OBJECTS[self.game_state.question_target],
                bool(self.game_state.answer))
        elif self.game_state.question_type_ind == 1:
            question_str = '%03d S %s # Q: %s A: %d' % (
                self.num_steps, self.game_state.scene_name[9:],
                constants.OBJECTS[self.game_state.question_target],
                self.game_state.answer)
        elif self.game_state.question_type_ind == 2:
            question_str = '%03d S %s Q: %s in %s A: %s' % (
                self.num_steps, self.game_state.scene_name[9:],
                constants.OBJECTS[self.game_state.question_target[0]],
                constants.OBJECTS[self.game_state.question_target[1]],
                bool(self.game_state.answer))
        else:
            raise Exception('No matching question number')
        titles = [
            question_str,
            str(self.answer),
            action_str,
            'coverage %d%% can end %s' %
            (coverage, bool(self.game_state.can_end)),
            'reward %.3f, value %.3f' % (self.reward, self.v),
        ]
        if return_list:
            return action_hist
        image = drawing.subplot(images,
                                4,
                                3,
                                curr_image.shape[1],
                                curr_image.shape[0],
                                titles=titles,
                                border=3)
        if not os.path.exists('visualizations/images'):
            os.makedirs('visualizations/images')
        cv2.imwrite(
            'visualizations/images/state_%05d.jpg' % self.global_step_id,
            image[:, :, ::-1])

        return image
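
The flat action index that get_action decodes packs a STEPS_AHEAD x STEPS_AHEAD teleport grid in front of the seven discrete actions. A standalone sketch of just the teleport decoding, assuming STEPS_AHEAD = 5; the rotation cases mirror the ones in the method above:

STEPS_AHEAD = 5  # assumed value

def decode_teleport(action_ind, rotation):
    # Egocentric offset: x in [-2, 2], z in [1, 5] steps ahead.
    action_x = action_ind % STEPS_AHEAD - STEPS_AHEAD // 2
    action_z = action_ind // STEPS_AHEAD + 1
    # Rotate the offset into world coordinates based on the agent's facing.
    if rotation == 0:
        return action_x, action_z
    elif rotation == 1:
        return action_z, -action_x
    elif rotation == 2:
        return -action_x, -action_z
    else:
        return -action_z, action_x

print(decode_teleport(12, 0))  # (0, 3): three steps straight ahead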
Example 5

def run():
    try:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(constants.GPU_ID)
        with tf.variable_scope('global_network'):
            network = QAPlannerNetwork(constants.RL_GRU_SIZE, int(constants.BATCH_SIZE), 1)
            network.create_net()
            training_step = network.training(network.rl_total_loss)

        conv_var_list = [
            v for v in tf.trainable_variables()
            if 'conv' in v.name and 'weight' in v.name and
            (v.get_shape().as_list()[0] != 1 or v.get_shape().as_list()[1] != 1)
        ]
        for var in conv_var_list:
            tf_util.conv_variable_summaries(var, scope=var.name.replace('/', '_')[:-2])
        summary_with_images = tf.summary.merge_all()

        with tf.variable_scope('supervised_loss'):
            loss_ph = tf.placeholder(tf.float32)
            accs_ph = [tf.placeholder(tf.float32) for _ in range(4)]
            loss_summary_op = tf.summary.merge([
                tf.summary.scalar('supervised_loss', loss_ph),
                tf.summary.scalar('acc_1_exist', accs_ph[0]),
                tf.summary.scalar('acc_2_count', accs_ph[1]),
                tf.summary.scalar('acc_3_contains', accs_ph[2]),
                ])

        # prepare session
        sess = tf_util.Session()
        sess.run(tf.global_variables_initializer())

        if not (constants.DEBUG or constants.DRAWING):
            from utils import py_util
            time_str = py_util.get_time_str()
            summary_writer = tf.summary.FileWriter(os.path.join(constants.LOG_FILE, time_str), sess.graph)
        else:
            summary_writer = None

        # init or load checkpoint with saver
        saver = tf.train.Saver(max_to_keep=3)
        start_it = tf_util.restore_from_dir(sess, constants.CHECKPOINT_DIR)

        sess.graph.finalize()

        import h5py
        h5_file = sorted(glob.glob('question_data_dump/*.h5'), key=os.path.getmtime)[-1]
        dataset = h5py.File(h5_file, 'r')
        num_entries = np.sum(np.sum(dataset['question_data/pose_placeholder'][...], axis=1) > 0)
        print('num_entries', num_entries)
        start_inds = dataset['question_data/new_episode'][:num_entries]
        start_inds = np.where(start_inds[1:] != 1)[0]

        curr_it = 0
        data_time_total = 0
        solver_time_total = 0
        total_time_total = 0

        for iteration in range(start_it, constants.MAX_TIME_STEP):
            if iteration == start_it or iteration % 10 == 1:
                current_time_start = time.time()
            t_start = time.time()

            rand_inds = np.sort(np.random.choice(start_inds, int(constants.BATCH_SIZE), replace=False))
            rand_inds = rand_inds.tolist()

            existence_answer_placeholder = dataset['question_data/existence_answer_placeholder'][rand_inds]
            counting_answer_placeholder = dataset['question_data/counting_answer_placeholder'][rand_inds]

            question_type_placeholder = dataset['question_data/question_type_placeholder'][rand_inds]
            question_object_placeholder = dataset['question_data/question_object_placeholder'][rand_inds]
            question_container_placeholder = dataset['question_data/question_container_placeholder'][rand_inds]

            pose_placeholder = dataset['question_data/pose_placeholder'][rand_inds]
            image_placeholder = dataset['question_data/image_placeholder'][rand_inds]
            map_mask_placeholder = np.ascontiguousarray(dataset['question_data/map_mask_placeholder'][rand_inds])

            meta_action_placeholder = dataset['question_data/meta_action_placeholder'][rand_inds]
            possible_move_placeholder = dataset['question_data/possible_move_placeholder'][rand_inds]

            taken_action = dataset['question_data/taken_action'][rand_inds]
            answer_weight = np.ones((constants.BATCH_SIZE, 1))

            map_mask_placeholder = np.ascontiguousarray(map_mask_placeholder, dtype=np.float32)
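            # Undo the offset/scale applied when the masks were recorded (see train_function in Example 2).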
            map_mask_placeholder[:, :, :, :2] -= 2
            map_mask_placeholder[:, :, :, :2] /= constants.STEPS_AHEAD

            map_mask_starts = map_mask_placeholder.copy()

            for bb in range(0, constants.BATCH_SIZE):
                object_ind = int(question_object_placeholder[bb])
                question_type_ind = int(question_type_placeholder[bb])
                if question_type_ind in {2, 3}:
                    container_ind = np.argmax(question_container_placeholder[bb])

                max_map_inds = np.argmax(map_mask_placeholder[bb, ...], axis=2)
                map_range = np.where(max_map_inds > 0)
                map_range_x = (np.min(map_range[1]), np.max(map_range[1]))
                map_range_y = (np.min(map_range[0]), np.max(map_range[0]))
                for jj in range(random.randint(0, 100)):
                    tmp_patch_start = (random.randint(map_range_x[0], map_range_x[1]),
                            random.randint(map_range_y[0], map_range_y[1]))
                    tmp_patch_end = (random.randint(map_range_x[0], map_range_x[1]),
                            random.randint(map_range_y[0], map_range_y[1]))
                    patch_start = (min(tmp_patch_start[0], tmp_patch_end[0]),
                            min(tmp_patch_start[1], tmp_patch_end[1]))
                    patch_end = (max(tmp_patch_start[0], tmp_patch_end[0]),
                            max(tmp_patch_start[1], tmp_patch_end[1]))

                    patch = map_mask_placeholder[bb, patch_start[1]:patch_end[1], patch_start[0]:patch_end[0], :]
                    if question_type_ind in {2, 3}:
                        obj_mem = patch[:, :, 6 + container_ind] + patch[:, :, 6 + object_ind]
                    else:
                        obj_mem = patch[:, :, 6 + object_ind].copy()
                    obj_mem += patch[:, :, 2]  # make sure seen locations stay marked.
                    if patch.size > 0 and np.max(obj_mem) == 0:
                        map_mask_placeholder[bb, patch_start[1]:patch_end[1], patch_start[0]:patch_end[0], 6:] = 0
            feed_dict = {
                    network.existence_answer_placeholder: np.ascontiguousarray(existence_answer_placeholder),
                    network.counting_answer_placeholder: np.ascontiguousarray(counting_answer_placeholder),

                    network.question_type_placeholder: np.ascontiguousarray(question_type_placeholder),
                    network.question_object_placeholder: np.ascontiguousarray(question_object_placeholder),
                    network.question_container_placeholder: np.ascontiguousarray(question_container_placeholder),
                    network.question_direction_placeholder: np.zeros((constants.BATCH_SIZE, 4), dtype=np.float32),

                    network.pose_placeholder: np.ascontiguousarray(pose_placeholder),
                    network.image_placeholder: np.ascontiguousarray(image_placeholder),
                    network.map_mask_placeholder: map_mask_placeholder,
                    network.meta_action_placeholder: np.ascontiguousarray(meta_action_placeholder),
                    network.possible_move_placeholder: np.ascontiguousarray(possible_move_placeholder),

                    network.taken_action: np.ascontiguousarray(taken_action),
                    network.answer_weight: np.ascontiguousarray(answer_weight),
                    network.episode_length_placeholder: np.ones((constants.BATCH_SIZE)),
                    network.question_count_placeholder: np.zeros((constants.BATCH_SIZE)),
                    }
            new_feed_dict = {}
            for key, value in feed_dict.items():
                if len(value.squeeze().shape) > 1:
                    new_feed_dict[key] = np.reshape(value, [int(constants.BATCH_SIZE), 1] + list(value.squeeze().shape[1:]))
                else:
                    new_feed_dict[key] = np.reshape(value, [int(constants.BATCH_SIZE), 1])
            feed_dict = new_feed_dict
            feed_dict[network.taken_action] = np.reshape(feed_dict[network.taken_action], (constants.BATCH_SIZE, -1))
            feed_dict[network.gru_placeholder] = np.zeros((int(constants.BATCH_SIZE), constants.RL_GRU_SIZE))

            data_t_end = time.time()
            if constants.DEBUG or constants.DRAWING:
                outputs = sess.run(
                        [training_step, network.rl_total_loss, network.existence_answer, network.counting_answer,
                            network.possible_moves, network.memory_crops_rot, network.taken_action],
                        feed_dict=feed_dict)
            else:
                if iteration == start_it + 10:
                    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()
                    outputs = sess.run([training_step, network.rl_total_loss, summary_with_images],
                            feed_dict=feed_dict,
                            options=run_options,
                            run_metadata=run_metadata)
                    loss_summary = outputs[2]
                    summary_writer.add_run_metadata(run_metadata, 'step_%07d' % iteration)
                    summary_writer.add_summary(loss_summary, iteration)
                    summary_writer.flush()
                elif iteration % 10 == 0:
                    if iteration % 100 == 0:
                        outputs = sess.run(
                                [training_step, network.rl_total_loss, summary_with_images],
                                feed_dict=feed_dict)
                        loss_summary = outputs[2]
                    else:  # The enclosing elif already guarantees iteration % 10 == 0.
                        outputs = sess.run(
                                [training_step, network.rl_total_loss,
                                 network.existence_answer,
                                 network.counting_answer,
                                 ],
                                feed_dict=feed_dict)
                        outputs[2] = outputs[2].reshape(-1, 1)
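                        # Accuracy per question type: mask correct predictions by type
                        # and normalize by that type's count (np.maximum avoids 0/0).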
                        acc_q0 = np.sum((existence_answer_placeholder == (outputs[2] > 0.5)) *
                            (question_type_placeholder == 0)) / np.maximum(1, np.sum(question_type_placeholder == 0))
                        acc_q1 = np.sum((counting_answer_placeholder == np.argmax(outputs[3], axis=1)[..., np.newaxis]) *
                            (question_type_placeholder == 1)) / np.maximum(1, np.sum(question_type_placeholder == 1))
                        acc_q2 = np.sum((existence_answer_placeholder == (outputs[2] > 0.5)) *
                            (question_type_placeholder == 2)) / np.maximum(1, np.sum(question_type_placeholder == 2))
                        acc_q3 = np.sum((existence_answer_placeholder == (outputs[2] > 0.5)) *
                            (question_type_placeholder == 3)) / np.maximum(1, np.sum(question_type_placeholder == 3))

                        curr_loss = outputs[1]
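                        # Evaluate a lightweight scalar summary from placeholders only;
                        # no second training step runs here.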
                        outputs = sess.run([loss_summary_op],
                                feed_dict={
                                    accs_ph[0]: acc_q0,
                                    accs_ph[1]: acc_q1,
                                    accs_ph[2]: acc_q2,
                                    accs_ph[3]: acc_q3,
                                    loss_ph: curr_loss,
                                    })

                        loss_summary = outputs[0]
                        # Append the loss so outputs[1] matches the loss index used by
                        # the other branches below.
                        outputs.append(curr_loss)
                    summary_writer.add_summary(loss_summary, iteration)
                    summary_writer.flush()
                else:
                    outputs = sess.run([training_step, network.rl_total_loss],
                            feed_dict=feed_dict)

            loss = outputs[1]  # Every branch above leaves rl_total_loss at index 1.
            solver_t_end = time.time()

            if constants.DEBUG or constants.DRAWING:
                # Look at outputs
                guess_bool = outputs[2].flatten()
                guess_count = outputs[3]
                possible_moves_pred = outputs[4]
                memory_crop = outputs[5]
                print('loss', loss)
                for bb in range(constants.BATCH_SIZE):
                    if constants.DRAWING:
                        import cv2
                        from utils import drawing
                        object_ind = int(question_object_placeholder[bb])
                        question_type_ind = int(question_type_placeholder[bb])
                        if question_type_ind == 1:
                            answer = counting_answer_placeholder[bb]
                            guess = guess_count[bb]
                        else:
                            answer = existence_answer_placeholder[bb]
                            guess = np.concatenate(([1 - guess_bool[bb]], [guess_bool[bb]]))
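                        # Build the object-memory view: for containment questions the
                        # container (1) and object (2) channels are overlaid.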
                        if question_type_ind in {2, 3}:
                            container_ind = np.argmax(question_container_placeholder[bb])
                            obj_mem = np.flipud(map_mask_placeholder[bb, :, :, 6 + container_ind]).copy()
                            obj_mem += 2 * np.flipud(map_mask_placeholder[bb, :, :, 6 + object_ind])
                        else:
                            obj_mem = np.flipud(map_mask_placeholder[bb, :, :, 6 + object_ind])

                        # Ground-truth and predicted possible-move grids, zero-padded so
                        # they reshape to a (STEPS_AHEAD + 2) x STEPS_AHEAD display grid.
                        possible = possible_move_placeholder[bb, ...].flatten()
                        possible = np.concatenate((possible, np.zeros(4)))
                        possible = possible.reshape(constants.STEPS_AHEAD + 2, constants.STEPS_AHEAD)

                        possible_pred = possible_moves_pred[bb, ...].flatten()
                        possible_pred = np.concatenate((possible_pred, np.zeros(4)))
                        possible_pred = possible_pred.reshape(constants.STEPS_AHEAD + 2, constants.STEPS_AHEAD)

                        mem2 = np.flipud(np.argmax(memory_crop[bb, ...], axis=2))
                        # Pin one corner to a high label, presumably to keep the colormap
                        # scale consistent across frames.
                        mem2[0, 0] = memory_crop.shape[-1] - 2

                        # Answer histogram
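                        # Each candidate answer becomes a vertical bar whose height is
                        # proportional to its predicted probability.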
                        ans = guess
                        if len(ans) == 1:
                            ans = [ans[0], 1 - ans[0]]
                        ans_size = max(len(ans), 100)
                        ans_hist = np.zeros((ans_size, ans_size))
                        for ii,ans_i in enumerate(ans):
                            ans_hist[:max(int(np.round(ans_i * ans_size)), 1),
                                    int(ii * ans_size / len(ans)):int((ii+1) * ans_size / len(ans))] = (ii + 1)
                        ans_hist = np.flipud(ans_hist)

                        image_list = [
                                image_placeholder[bb,...],
                                ans_hist,
                                np.flipud(possible),
                                np.flipud(possible_pred),
                                mem2,
                                np.flipud(np.argmax(map_mask_starts[bb, :, :, 2:], axis=2)),
                                obj_mem,
                                ]
                        if question_type_ind == 0:
                            question_str = 'Ex Q: %s A: %s' % (constants.OBJECTS[object_ind], bool(answer))
                        elif question_type_ind == 1:
                            question_str = '# Q: %s A: %d' % (constants.OBJECTS[object_ind], answer)
                        else:
                            # Types 2 and 3 both use the containment wording, matching the
                            # {2, 3} branch above; a bare elif == 2 would leave question_str
                            # undefined for type-3 questions.
                            question_str = 'Q: %s in %s A: %s' % (
                                    constants.OBJECTS[object_ind],
                                    constants.OBJECTS[container_ind],
                                    bool(answer))
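                        # Tile the debug views into a 4x2 grid: input frame, answer
                        # histogram, GT and predicted moves, memory argmax, start map,
                        # and object memory.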
                        image = drawing.subplot(image_list, 4, 2, constants.SCREEN_WIDTH, constants.SCREEN_HEIGHT,
                                titles=[question_str, 'A: %s' % str(np.argmax(guess))], border=2)
                        cv2.imshow('image', image[:, :, ::-1])
                        cv2.waitKey(0)
                    else:
                        # DEBUG without DRAWING: drop into the debugger for inspection.
                        pdb.set_trace()

            # Checkpoint every 1000 iterations and at the final step.
            if not (constants.DEBUG or constants.DRAWING) and (iteration % 1000 == 0 or iteration == constants.MAX_TIME_STEP - 1):
                saver_t_start = time.time()
                tf_util.save(saver, sess, constants.CHECKPOINT_DIR, iteration)
                saver_t_end = time.time()
                print('Saver:     %.3f' % (saver_t_end - saver_t_start))

            curr_it += 1

            data_time_total += data_t_end - t_start
            solver_time_total += solver_t_end - data_t_end
            total_time_total += time.time() - t_start

            # Report rolling timing averages every 10 iterations (and on the first).
            if iteration == start_it or iteration % 10 == 0:
                print('Iteration: %d' % (iteration))
                print('Loss:      %.3f' % loss)
                print('Data:      %.3f' % (data_time_total / curr_it))
                print('Solver:    %.3f' % (solver_time_total / curr_it))
                print('Total:     %.3f' % (total_time_total / curr_it))
                print('Current:   %.3f\n' % ((time.time() - current_time_start) / min(10, curr_it)))

    except Exception:
        # Log the failure; the finally block below still saves a checkpoint.
        import traceback
        traceback.print_exc()
    finally:
        # Save final model
        if not (constants.DEBUG or constants.DRAWING):
            tf_util.save(saver, sess, constants.CHECKPOINT_DIR, iteration)