Ejemplo n.º 1
0
def run():
    """Train the free-space navigation network from a threaded replay buffer.

    The function
      * builds a ``FreeSpaceNetwork`` under the ``nav_global_network`` scope,
      * starts ``constants.PARALLEL_SIZE`` daemon threads that keep filling the
        module-level ``data_buffer`` / ``data_counts`` with generated episodes,
      * runs training iterations from the restored checkpoint step up to
        ``constants.MAX_TIME_STEP``, carrying the per-slot ``memory`` and
        ``gru_state`` across unrolls of the same sequence,
      * writes summaries and checkpoints unless ``constants.DEBUG`` or
        ``constants.DRAWING`` is set (in which case the outputs are inspected
        interactively instead).

    Exceptions (including Ctrl-C) are printed, not re-raised; the ``finally``
    clause then saves a last checkpoint if training got far enough to have
    something to save.
    """
    # Pre-bind everything the ``finally`` clause reads so that an early failure
    # (e.g. while building the graph) does not raise a NameError there and hide
    # the real traceback.
    sess = None
    saver = None
    iteration = None
    try:
        with tf.variable_scope('nav_global_network'):
            network = FreeSpaceNetwork(constants.GRU_SIZE,
                                       constants.BATCH_SIZE,
                                       constants.NUM_UNROLLS)
            network.create_net()
            training_step = network.training_op

        with tf.variable_scope('loss'):
            loss_summary_op = tf.summary.merge([
                tf.summary.scalar('loss', network.loss),
            ])
        # Result is not used; the call is kept so the constructed graph is
        # unchanged.
        tf.summary.merge_all()
        conv_var_list = [
            v for v in tf.trainable_variables()
            if 'conv' in v.name and 'weight' in v.name and
            (v.get_shape().as_list()[0] != 1 or v.get_shape().as_list()[1] != 1
             )
        ]
        for var in conv_var_list:
            # ``[:-2]`` strips the last two characters of the variable name
            # (presumably the ':0' tensor suffix).
            tf_util.conv_variable_summaries(var,
                                            scope=var.name.replace('/',
                                                                   '_')[:-2])
        summary_with_images = tf.summary.merge_all()

        # prepare session
        sess = tf_util.Session()

        # Per batch slot: [next start index, sequence length]. Equal values
        # mean the slot needs a new sequence.
        seq_inds = np.zeros((constants.BATCH_SIZE, 2), dtype=np.int32)

        sequence_generators = [
            SequenceGenerator(sess) for _ in range(constants.PARALLEL_SIZE)
        ]

        sess.run(tf.global_variables_initializer())

        if not (constants.DEBUG or constants.DRAWING):
            from utils import py_util
            time_str = py_util.get_time_str()
            summary_writer = tf.summary.FileWriter(
                os.path.join(constants.LOG_FILE, time_str), sess.graph)
        else:
            summary_writer = None

        saver = tf.train.Saver(max_to_keep=3)

        # init or load checkpoint
        start_it = tf_util.restore_from_dir(sess, constants.CHECKPOINT_DIR)

        sess.graph.finalize()

        data_lock = threading.Lock()

        def load_new_data(thread_index):
            """Generate episodes forever and push them into the replay buffer.

            Appends while the buffer is below ``REPLAY_BUFFER_SIZE``; afterwards
            overwrites the entry with the highest use count.
            """
            global data_buffer
            global data_counts

            sequence_generator = sequence_generators[thread_index]
            counter = 0
            while True:
                # Wait until there is room, or at least one entry has been
                # used and may therefore be replaced.
                while not (len(data_buffer) < constants.REPLAY_BUFFER_SIZE
                           or np.max(data_counts) > 0):
                    time.sleep(1)
                counter += 1
                if constants.DEBUG:
                    print('\nThread %d' % thread_index)
                episode_steps, _, goal_pose = (
                    sequence_generator.generate_episode())
                # list of per-step dicts -> dict of per-key lists
                new_data = {
                    key: [step[key] for step in episode_steps]
                    for key in episode_steps[0]
                }
                new_data['goal_pose'] = goal_pose
                new_data['memory'] = np.zeros(
                    (constants.SPATIAL_MAP_HEIGHT, constants.SPATIAL_MAP_WIDTH,
                     constants.MEMORY_SIZE))
                new_data['gru_state'] = np.zeros(constants.GRU_SIZE)
                if constants.DRAWING:
                    new_data['debug_images'] = sequence_generator.debug_images
                with data_lock:
                    if len(data_buffer) < constants.REPLAY_BUFFER_SIZE:
                        data_counts[len(data_buffer)] = 0
                        data_buffer.append(new_data)
                        counts = data_counts[:len(data_buffer)]
                        if counter % 10 == 0:
                            print(
                                'Buffer size %d  Num used %d  Max used amount %d'
                                % (len(data_buffer), len(
                                    counts[counts > 0]), np.max(counts)))
                    else:
                        max_count_ind = np.argmax(data_counts)
                        data_buffer[max_count_ind] = new_data
                        data_counts[max_count_ind] = 0
                        if counter % 10 == 0:
                            print('Num used %d  Max used amount %d' %
                                  (len(data_counts[data_counts > 0]),
                                   np.max(data_counts)))

        threads = []
        for i in range(constants.PARALLEL_SIZE):
            load_data_thread = threading.Thread(target=load_new_data,
                                                args=(i, ))
            load_data_thread.daemon = True
            load_data_thread.start()
            threads.append(load_data_thread)
            time.sleep(1)

        sequences = [None] * constants.BATCH_SIZE

        curr_it = 0
        # Small non-zero start values for the accumulated timings.
        dataTimeTotal = 0.00001
        solverTimeTotal = 0.00001
        summaryTimeTotal = 0.00001
        totalTimeTotal = 0.00001

        # Buffer indices currently assigned to a batch slot, and the reverse
        # slot -> buffer index map, so no two slots train on the same entry.
        chosen_inds = set()
        loc_to_chosen_ind = {}
        for iteration in range(start_it, constants.MAX_TIME_STEP):
            if iteration == start_it or iteration % 10 == 1:
                currentTimeStart = time.time()
            tStart = time.time()
            batch_data = []
            batch_action = []
            batch_memory = []
            batch_gru_state = []
            batch_labels = []
            batch_pose = []
            batch_mask = []
            batch_goal_pose = []
            batch_pose_indicator = []
            batch_possible_label = []
            batch_debug_images = []
            for bb in range(constants.BATCH_SIZE):
                if seq_inds[bb, 0] == seq_inds[bb, 1]:
                    # Pick a new random sequence among entries used fewer than
                    # 100 times and not held by another slot; wait for the
                    # loader threads if none is available yet.
                    pickable_inds = set(
                        np.where(data_counts < 100)[0]) - chosen_inds
                    while not pickable_inds:
                        time.sleep(1)
                        pickable_inds = set(
                            np.where(data_counts < 100)[0]) - chosen_inds
                    # random.sample() no longer accepts a set (Python 3.11+),
                    # so draw from a list instead.
                    random_ind = random.choice(list(pickable_inds))
                    with data_lock:
                        sequences[bb] = data_buffer[random_ind]
                        # A freshly picked sequence starts with blank
                        # recurrent state.
                        sequences[bb]['memory'] = np.zeros(
                            (constants.SPATIAL_MAP_HEIGHT,
                             constants.SPATIAL_MAP_WIDTH,
                             constants.MEMORY_SIZE))
                        sequences[bb]['gru_state'] = np.zeros(
                            constants.GRU_SIZE)
                        data_counts[random_ind] += 1
                        if bb in loc_to_chosen_ind:
                            chosen_inds.remove(loc_to_chosen_ind[bb])
                        loc_to_chosen_ind[bb] = random_ind
                        chosen_inds.add(random_ind)
                    seq_inds[bb, 0] = 0
                    seq_inds[bb, 1] = len(sequences[bb]['color'])
                data_len = min(constants.NUM_UNROLLS,
                               seq_inds[bb, 1] - seq_inds[bb, 0])
                ind0 = seq_inds[bb, 0]
                ind1 = seq_inds[bb, 0] + data_len
                # Slicing yields new lists, so the padding below does not touch
                # the buffered sequence.
                data = sequences[bb]['color'][ind0:ind1]
                action = sequences[bb]['action'][ind0:ind1]
                labels = sequences[bb]['label'][ind0:ind1]
                memory = sequences[bb]['memory'].copy()
                gru_state = sequences[bb]['gru_state'].copy()
                pose = sequences[bb]['pose'][ind0:ind1]
                goal_pose = sequences[bb]['goal_pose']
                mask = sequences[bb]['weight'][ind0:ind1]
                pose_indicator = sequences[bb]['pose_indicator'][ind0:ind1]
                possible_label = sequences[bb]['possible_label'][ind0:ind1]
                if constants.DRAWING:
                    batch_debug_images.append(
                        sequences[bb]['debug_images'][ind0:ind1])
                if data_len < constants.NUM_UNROLLS:
                    # Tail of the sequence: mark the slot as exhausted and pad
                    # up to NUM_UNROLLS (zeros everywhere, last pose repeated).
                    seq_inds[bb, :] = 0
                    pad_len = constants.NUM_UNROLLS - data_len
                    for seq in (data, action, labels, mask, pose_indicator,
                                possible_label):
                        seq.extend(
                            [np.zeros_like(seq[0]) for _ in range(pad_len)])
                    pose.extend([pose[-1]] * pad_len)
                else:
                    seq_inds[bb, 0] += constants.NUM_UNROLLS
                batch_data.append(data)
                batch_action.append(action)
                batch_memory.append(memory)
                batch_gru_state.append(gru_state)
                batch_pose.append(pose)
                batch_goal_pose.append(goal_pose)
                batch_labels.append(labels)
                batch_mask.append(mask)
                batch_pose_indicator.append(pose_indicator)
                batch_possible_label.append(possible_label)

            feed_dict = {
                network.image_placeholder:
                np.ascontiguousarray(batch_data),
                network.action_placeholder:
                np.ascontiguousarray(batch_action),
                network.gru_placeholder:
                np.ascontiguousarray(batch_gru_state),
                network.pose_placeholder:
                np.ascontiguousarray(batch_pose),
                network.goal_pose_placeholder:
                np.ascontiguousarray(batch_goal_pose),
                network.labels_placeholder:
                np.ascontiguousarray(batch_labels)[..., np.newaxis],
                network.mask_placeholder:
                np.ascontiguousarray(batch_mask),
                network.pose_indicator_placeholder:
                np.ascontiguousarray(batch_pose_indicator),
                network.possible_label_placeholder:
                np.ascontiguousarray(batch_possible_label),
                network.memory_placeholders:
                np.ascontiguousarray(batch_memory),
            }
            dataTEnd = time.time()
            summaryTime = 0
            # In every branch: outputs[1] is the loss, outputs[2] the GRU state
            # and outputs[-1] is network.gru_outputs.
            if constants.DEBUG or constants.DRAWING:
                outputs = sess.run([
                    training_step, network.loss, network.gru_state,
                    network.patch_weights_sigm, network.gru_outputs_full,
                    network.is_possible_sigm,
                    network.pose_indicator_placeholder,
                    network.terminal_patches, network.gru_outputs
                ],
                                   feed_dict=feed_dict)
            elif iteration == start_it + 10:
                # One fully traced step, recorded as run metadata.
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
                outputs = sess.run([
                    training_step, network.loss, network.gru_state,
                    summary_with_images, network.gru_outputs
                ],
                                   feed_dict=feed_dict,
                                   options=run_options,
                                   run_metadata=run_metadata)
                loss_summary = outputs[3]
                summary_writer.add_run_metadata(run_metadata,
                                                'step_%07d' % iteration)
                summary_writer.add_summary(loss_summary, iteration)
                summary_writer.flush()
            elif iteration % 10 == 0:
                # Full (image) summaries every 100 steps, loss-only otherwise.
                if iteration % 100 == 0:
                    summary_op = summary_with_images
                else:
                    summary_op = loss_summary_op
                outputs = sess.run([
                    training_step, network.loss, network.gru_state,
                    summary_op, network.gru_outputs
                ],
                                   feed_dict=feed_dict)
                loss_summary = outputs[3]
                summaryTStart = time.time()
                summary_writer.add_summary(loss_summary, iteration)
                summary_writer.flush()
                summaryTime = time.time() - summaryTStart
            else:
                outputs = sess.run([
                    training_step, network.loss, network.gru_state,
                    network.gru_outputs
                ],
                                   feed_dict=feed_dict)

            # Carry the recurrent state over to the next unroll of each slot.
            gru_state_out = outputs[2]
            memory_out = outputs[-1]
            for mm in range(constants.BATCH_SIZE):
                sequences[mm]['memory'] = memory_out[mm, ...]
                sequences[mm]['gru_state'] = gru_state_out[mm, ...]

            loss = outputs[1]
            solverTEnd = time.time()

            if constants.DEBUG or constants.DRAWING:
                # Look at outputs
                patch_weights = outputs[3]
                is_possible = outputs[5]
                pose_indicator = outputs[6]
                terminal_patches = outputs[7]
                # Hold the lock so the loader threads do not modify the buffer
                # while it is being inspected; ``with`` guarantees release even
                # if drawing or the debugger raises.
                with data_lock:
                    for bb in range(constants.BATCH_SIZE):
                        for tt in range(constants.NUM_UNROLLS):
                            if batch_mask[bb][tt] == 0:
                                # Padded step: nothing more to show here.
                                break
                            if constants.DRAWING:
                                import cv2
                                from utils import drawing
                                label = np.flipud(batch_labels[bb][tt])
                                debug_images = batch_debug_images[bb][tt]
                                state_image = debug_images['state_image']

                                print('Possible pred %.3f' %
                                      is_possible[bb, tt])
                                print('Possible label %.3f' %
                                      batch_possible_label[bb][tt])
                                patch = np.flipud(patch_weights[bb, tt, ...])
                                patch_occupancy = patch[:, :, 0]
                                print('occ', patch_occupancy)
                                print('label', label)
                                terminal_patch = np.flipud(
                                    np.sum(terminal_patches[bb, tt, ...],
                                           axis=2))
                                image_list = [
                                    debug_images['color'],
                                    state_image,
                                    debug_images['label_memory'][:, :, 0],
                                    debug_images['memory_map'][:, :, 0],
                                    label[:, :],
                                    patch_occupancy,
                                    np.flipud(pose_indicator[bb, tt]),
                                    terminal_patch,
                                ]

                                image = drawing.subplot(
                                    image_list, 4, 2, constants.SCREEN_WIDTH,
                                    constants.SCREEN_HEIGHT)
                                # Reverse the channel axis for display.
                                cv2.imshow('image', image[:, :, ::-1])
                                cv2.waitKey(0)
                            else:
                                pdb.set_trace()

            if not (constants.DEBUG or constants.DRAWING) and (
                    iteration % 500 == 0
                    or iteration == constants.MAX_TIME_STEP - 1):
                saverTStart = time.time()
                tf_util.save(saver, sess, constants.CHECKPOINT_DIR, iteration)
                saverTEnd = time.time()
                print('Saver:     %.3f' % (saverTEnd - saverTStart))

            curr_it += 1

            dataTimeTotal += dataTEnd - tStart
            summaryTimeTotal += summaryTime
            solverTimeTotal += solverTEnd - dataTEnd - summaryTime
            totalTimeTotal += time.time() - tStart

            if iteration == start_it or iteration % 10 == 0:
                print('Iteration: %d' % (iteration))
                print('Loss:      %.3f' % loss)
                print('Data:      %.3f' % (dataTimeTotal / curr_it))
                print('Solver:    %.3f' % (solverTimeTotal / curr_it))
                print('Summary:   %.3f' % (summaryTimeTotal / curr_it))
                print('Total:     %.3f' % (totalTimeTotal / curr_it))
                print('Current:   %.3f\n' %
                      ((time.time() - currentTimeStart) / min(10, curr_it)))

    except (Exception, KeyboardInterrupt):
        # Best-effort top-level handler: report the failure (or Ctrl-C) and
        # fall through to the final checkpoint. SystemExit is deliberately not
        # swallowed.
        import traceback
        traceback.print_exc()
    finally:
        # Save final model, but only if setup got far enough to have a session,
        # a saver and at least one started iteration.
        if (not (constants.DEBUG or constants.DRAWING) and sess is not None
                and saver is not None and iteration is not None):
            tf_util.save(saver, sess, constants.CHECKPOINT_DIR, iteration)
    def create_dump():
        """Generate "contains" existence questions and dump them to an HDF5 file.

        Uses names that are not defined in this function and must come from the
        surrounding scope: ``dataset_type``, ``num_record``, ``scene_numbers``,
        ``num_samples_per_scene``, ``all_object_classes``, ``receptacles`` and
        ``DEBUG``.

        The output file is
        ``questions/<dataset_type>/data_contains/Contains_Questions_<time>.h5``
        with an int32 dataset ``questions/question`` of shape
        ``(num_record, 5)``. Each row is
        ``[scene_num, scene_seed, object class id, parent class id, answer]``.
        Rows are written in pairs: for one question, one scene seed where the
        answer is false and one where it is true.

        The HDF5 file is closed and the Unity environment is stopped even when
        generation fails part-way.
        """
        time_str = py_util.get_time_str()
        prefix = 'questions/'
        dump_dir = prefix + dataset_type + '/data_contains'
        if not os.path.exists(dump_dir):
            os.makedirs(dump_dir)

        h5 = h5py.File(dump_dir + '/Contains_Questions_' + time_str + '.h5', 'w')
        episode = None
        try:
            h5.create_dataset('questions/question', (num_record, 5), dtype=np.int32)
            print('Generating %d contains questions' % num_record)

            # Generate contains questions
            data_ind = 0
            episode = Episode()
            scene_number = -1
            while data_ind < num_record:
                k = 0

                # Cycle through the available scenes.
                scene_number += 1
                scene_num = scene_numbers[scene_number % len(scene_numbers)]

                scene_name = 'FloorPlan%d' % scene_num
                episode.initialize_scene(scene_name)
                num_tries = 0
                while k < num_samples_per_scene and num_tries < 10 * num_samples_per_scene:
                    # randomly pick a pickable object in the scene
                    # (replaced by the question's own object class once a
                    # question has been generated below)
                    object_class = random.choice(all_object_classes)
                    generated_no, generated_yes = False, False
                    question = None
                    temp_data = []
                    num_tries += 1

                    grid_file = 'layouts/%s-layout.npy' % scene_name
                    xray_graph = graph_obj.Graph(grid_file, use_gt=True, construct_graph=False)
                    scene_bounds = [xray_graph.xMin, xray_graph.yMin,
                                    xray_graph.xMax - xray_graph.xMin + 1,
                                    xray_graph.yMax - xray_graph.yMin + 1]

                    for i in range(20):  # try 20 times
                        scene_seed = random.randint(0, 999999999)
                        episode.initialize_episode(scene_seed=scene_seed)  # randomly initialize the scene
                        if not question:
                            question = ExistenceQuestion.get_true_contain_question(episode, all_object_classes, receptacles)
                            if DEBUG:
                                print(str(question))
                            if not question:
                                continue
                            object_class = question.object_class
                            parent_class = question.parent_object_class

                        answer = question.get_answer(episode)
                        object_target = constants.OBJECT_CLASS_TO_ID[object_class]

                        if answer and not generated_yes:
                            event = episode.event
                            xray_graph.memory[:, :, 1:] = 0
                            # Instances of the target class whose parent
                            # receptacle is of the asked-about class.
                            objs = {obj['objectId']: obj for obj in event.metadata['objects']
                                    if obj['objectType'] == object_class and obj['parentReceptacle'].split('|')[0] == parent_class}
                            for obj in objs.values():
                                obj_point = game_util.get_object_point(obj, scene_bounds)
                                xray_graph.memory[obj_point[1], obj_point[0],
                                                  constants.OBJECT_CLASS_TO_ID[obj['objectType']] + 1] = 1

                            # Make sure findable.
                            # AssertionError is used as a multi-level break out
                            # of the search below: raised with ``answer`` left
                            # as-is it means "object seen, question is valid";
                            # raised after ``answer = None`` it means the
                            # attempt is abandoned.
                            try:
                                graph_points = xray_graph.points.copy()
                                graph_points = graph_points[np.random.permutation(graph_points.shape[0]), :]
                                num_checked_points = 0
                                for start_point in graph_points:
                                    headings = np.random.permutation(4)
                                    for heading in headings:
                                        start_point = (start_point[0], start_point[1], heading)
                                        patch = xray_graph.get_graph_patch(start_point)[0]
                                        if patch[:, :, object_target + 1].max() > 0:
                                            action = {'action': 'TeleportFull',
                                                      'x': start_point[0] * constants.AGENT_STEP_SIZE,
                                                      'y': episode.agent_height,
                                                      'z': start_point[1] * constants.AGENT_STEP_SIZE,
                                                      'rotateOnTeleport': True,
                                                      'rotation': start_point[2] * 90,
                                                      'horizon': -30,
                                                      }
                                            event = episode.env.step(action)
                                            num_checked_points += 1
                                            if num_checked_points > 1000:
                                                # Give up on this seed.
                                                answer = None
                                                raise AssertionError
                                            # Four looks per pose, with a
                                            # LookDown between consecutive ones.
                                            for jj in range(4):
                                                open_success = True
                                                opened_objects = set()
                                                parents = [game_util.get_object(obj['parentReceptacle'], event.metadata)
                                                           for obj in objs.values()]
                                                openable_parents = [parent for parent in parents
                                                                    if parent['visible'] and parent['openable'] and not parent['isopen']]
                                                while open_success:
                                                    for obj in objs.values():
                                                        if obj['objectId'] in event.instance_detections2D:
                                                            if game_util.check_object_size(event.instance_detections2D[obj['objectId']]):
                                                                # Target detected: stop searching, keep the answer.
                                                                raise AssertionError
                                                    if len(openable_parents) > 0:
                                                        action = {'action': 'OpenObject'}
                                                        game_util.set_open_close_object(action, event)
                                                        event = episode.env.step(action)
                                                        open_success = event.metadata['lastActionSuccess']
                                                        if open_success:
                                                            opened_objects.add(episode.env.last_action['objectId'])
                                                    else:
                                                        open_success = False
                                                # Close again whatever was opened at this view.
                                                for opened in opened_objects:
                                                    event = episode.env.step({
                                                        'action': 'CloseObject',
                                                        'objectId': opened,
                                                        'forceVisible': True})
                                                    if not event.metadata['lastActionSuccess']:
                                                        answer = None
                                                        raise AssertionError
                                                if jj < 3:
                                                    event = episode.env.step({'action': 'LookDown'})
                                # Search finished without a detection: discard.
                                answer = None
                            except AssertionError:
                                if answer is None and DEBUG:
                                    print('failed to generate')

                        print(str(question), answer)

                        # ``answer`` may be None (discarded attempt), so the
                        # explicit comparisons below are intentional.
                        if answer == False and not generated_no:  # noqa: E712
                            generated_no = True
                            temp_data.append([scene_num, scene_seed, constants.OBJECT_CLASS_TO_ID[object_class], constants.OBJECT_CLASS_TO_ID[parent_class], answer])
                        elif answer == True and not generated_yes:  # noqa: E712
                            generated_yes = True
                            temp_data.append([scene_num, scene_seed, constants.OBJECT_CLASS_TO_ID[object_class], constants.OBJECT_CLASS_TO_ID[parent_class], answer])

                        if generated_no and generated_yes:
                            # Write the pair; the bound check keeps the second
                            # row inside the dataset when num_record is odd.
                            for offset, row in enumerate(temp_data[:2]):
                                if data_ind + offset < num_record:
                                    h5['questions/question'][data_ind + offset, :] = np.array(row)
                            h5.flush()
                            data_ind += 2
                            k += 2
                            break
                    print("# generated samples: {}".format(data_ind))
        finally:
            # Always release the file handle and the simulator, in that order,
            # even if question generation raised.
            try:
                h5.close()
            finally:
                env = getattr(episode, 'env', None)
                if env is not None:
                    env.stop_unity()
Ejemplo n.º 3
0
def run():
    """Build the question-answering A3C graph and run training (plus optional testing).

    Creates the global planner network (and, when enabled, the navigation and
    depth networks), one ``A3CTrainingThread`` per ``constants.PARALLEL_SIZE``,
    the TensorBoard summary ops and a ``Saver`` over the trainable variables in
    the ``global_network`` scope. Training runs in daemon threads until the
    module-level ``global_t`` reaches ``constants.MAX_TIME_STEP``. When
    ``constants.RECORD_FEED_DICT`` is set (and ``constants.DEBUG`` is not),
    the feed dicts returned by the training threads are also dumped into an
    HDF5 file. Unless ``constants.DEBUG`` is set, a checkpoint is written on
    the way out, including after an error or Ctrl+C.
    """
    global global_t
    global dataset
    global data_ind
    if constants.OBJECT_DETECTION:
        from darknet_object_detection import detector
        detector.setup_detectors(constants.PARALLEL_SIZE)

    os.environ["CUDA_VISIBLE_DEVICES"] = str(constants.GPU_ID)

    # Bound before the ``try`` so the ``finally`` clause can tell how far
    # setup got instead of failing with a NameError that hides the real error.
    sess = None
    saver = None
    summary_writer = None

    try:
        with tf.variable_scope('global_network'):
            if constants.END_TO_END_BASELINE:
                global_network = EndToEndBaselineNetwork()
            else:
                global_network = QAPlannerNetwork(constants.RL_GRU_SIZE, 1, 1)
            global_network.create_net()
        if constants.USE_NAVIGATION_AGENT:
            with tf.variable_scope('nav_global_network') as net_scope:
                free_space_network = FreeSpaceNetwork(constants.GRU_SIZE, 1, 1)
                free_space_network.create_net()
        else:
            net_scope = None

        # Image summaries for conv weights whose first two dims are not both 1.
        conv_var_list = [v for v in tf.trainable_variables() if 'conv' in v.name and 'weight' in v.name and
                        (v.get_shape().as_list()[0] != 1 or v.get_shape().as_list()[1] != 1)]
        for var in conv_var_list:
            tf_util.conv_variable_summaries(var, scope=var.name.replace('/', '_')[:-2])
        conv_image_summary = tf.summary.merge_all()

        # prepare session
        sess = tf_util.Session()

        # Instantiate singletons without scope.
        if constants.PREDICT_DEPTH:
            from depth_estimation_network import depth_estimator
            with tf.variable_scope('') as depth_scope:
                depth_estimator = depth_estimator.get_depth_estimator(sess)
        else:
            depth_scope = None

        training_threads = []

        learning_rate_input = tf.placeholder(tf.float32, name='learning_rate')
        grad_applier = RMSPropApplier(learning_rate=learning_rate_input,
                decay=constants.RMSP_ALPHA,
                momentum=0.0,
                epsilon=constants.RMSP_EPSILON,
                clip_norm=constants.GRAD_NORM_CLIP)

        for i in range(constants.PARALLEL_SIZE):
            training_thread = A3CTrainingThread(
                    i, sess,
                    learning_rate_input,
                    grad_applier,
                    constants.MAX_TIME_STEP,
                    net_scope,
                    depth_scope)
            training_threads.append(training_thread)

        if constants.RUN_TEST:
            testing_thread = A3CTestingThread(constants.PARALLEL_SIZE + 1, sess, net_scope, depth_scope)

        sess.run(tf.global_variables_initializer())

        # Initialize pretrained weights after init.
        if constants.PREDICT_DEPTH:
            depth_estimator.load_weights()


        episode_reward_input = tf.placeholder(tf.float32, name='ep_reward')
        episode_length_input = tf.placeholder(tf.float32, name='ep_length')
        exist_answer_correct_input = tf.placeholder(tf.float32, name='exist_ans')
        count_answer_correct_input = tf.placeholder(tf.float32, name='count_ans')
        contains_answer_correct_input = tf.placeholder(tf.float32, name='contains_ans')
        percent_invalid_actions_input = tf.placeholder(tf.float32, name='inv')

        scalar_summaries = [
            tf.summary.scalar("Episode Reward", episode_reward_input),
            tf.summary.scalar("Episode Length", episode_length_input),
            tf.summary.scalar("Percent Invalid Actions", percent_invalid_actions_input),
        ]
        # No trailing commas here: each of these must be a summary tensor,
        # not a 1-tuple wrapping one.
        exist_summary = tf.summary.scalar("Answer Correct Existence", exist_answer_correct_input)
        count_summary = tf.summary.scalar("Answer Correct Counting", count_answer_correct_input)
        contains_summary = tf.summary.scalar("Answer Correct Containing", contains_answer_correct_input)
        exist_summary_op = tf.summary.merge(scalar_summaries + [exist_summary])
        count_summary_op = tf.summary.merge(scalar_summaries + [count_summary])
        contains_summary_op = tf.summary.merge(scalar_summaries + [contains_summary])
        summary_ops = [exist_summary_op, count_summary_op, contains_summary_op]

        summary_placeholders = {
            "episode_reward_input": episode_reward_input,
            "episode_length_input": episode_length_input,
            "exist_answer_correct_input": exist_answer_correct_input,
            "count_answer_correct_input": count_answer_correct_input,
            "contains_answer_correct_input": contains_answer_correct_input,
            "percent_invalid_actions_input" : percent_invalid_actions_input,
            }

        if constants.RUN_TEST:
            test_episode_reward_input = tf.placeholder(tf.float32, name='test_ep_reward')
            test_episode_length_input = tf.placeholder(tf.float32, name='test_ep_length')
            test_exist_answer_correct_input = tf.placeholder(tf.float32, name='test_exist_ans')
            test_count_answer_correct_input = tf.placeholder(tf.float32, name='test_count_ans')
            test_contains_answer_correct_input = tf.placeholder(tf.float32, name='test_contains_ans')
            test_percent_invalid_actions_input = tf.placeholder(tf.float32, name='test_inv')

            test_scalar_summaries = [
                tf.summary.scalar("Test Episode Reward", test_episode_reward_input),
                tf.summary.scalar("Test Episode Length", test_episode_length_input),
                tf.summary.scalar("Test Percent Invalid Actions", test_percent_invalid_actions_input),
            ]
            exist_summary = tf.summary.scalar("Test Answer Correct Existence", test_exist_answer_correct_input)
            count_summary = tf.summary.scalar("Test Answer Correct Counting", test_count_answer_correct_input)
            contains_summary = tf.summary.scalar("Test Answer Correct Containing", test_contains_answer_correct_input)
            test_exist_summary_op = tf.summary.merge(test_scalar_summaries + [exist_summary])
            test_count_summary_op = tf.summary.merge(test_scalar_summaries + [count_summary])
            test_contains_summary_op = tf.summary.merge(test_scalar_summaries + [contains_summary])
            test_summary_ops = [test_exist_summary_op, test_count_summary_op, test_contains_summary_op]

        if not constants.DEBUG:
            time_str = py_util.get_time_str()
            summary_writer = tf.summary.FileWriter(os.path.join(constants.LOG_FILE, time_str), sess.graph)
        else:
            summary_writer = None

        # init or load checkpoint with saver
        vars_to_save = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='global_network')
        saver = tf.train.Saver(vars_to_save, max_to_keep=3)
        print('-------------- Looking for checkpoints in ', constants.CHECKPOINT_DIR)
        global_t = tf_util.restore_from_dir(sess, constants.CHECKPOINT_DIR)

        if constants.USE_NAVIGATION_AGENT:
            print('now trying to restore nav model')
            tf_util.restore_from_dir(sess, 'logs/checkpoints/navigation', True)

        # No graph modifications are allowed past this point.
        sess.graph.finalize()

        for i in range(constants.PARALLEL_SIZE):
            sess.run(training_threads[i].sync)
            training_threads[i].agent.reset()

        # (episode length, episode reward) pairs of finished episodes.
        times = []

        if not constants.DEBUG and constants.RECORD_FEED_DICT:
            import h5py
            NUM_RECORD = 10000 * len(constants.USED_QUESTION_TYPES)
            time_str = py_util.get_time_str()
            if not os.path.exists('question_data_dump'):
                os.mkdir('question_data_dump')
            dataset = h5py.File('question_data_dump/maps_' + time_str + '.h5', 'w')

            dataset.create_dataset('question_data/existence_answer_placeholder', (NUM_RECORD, 1), dtype=np.int32)
            dataset.create_dataset('question_data/counting_answer_placeholder', (NUM_RECORD, 1), dtype=np.int32)

            dataset.create_dataset('question_data/question_type_placeholder', (NUM_RECORD, 1), dtype=np.int32)
            dataset.create_dataset('question_data/question_object_placeholder', (NUM_RECORD, 1), dtype=np.int32)
            dataset.create_dataset('question_data/question_container_placeholder', (NUM_RECORD, 21), dtype=np.int32)

            dataset.create_dataset('question_data/pose_placeholder', (NUM_RECORD, 3), dtype=np.int32)
            dataset.create_dataset('question_data/image_placeholder', (NUM_RECORD, 300, 300, 3), dtype=np.uint8)
            dataset.create_dataset('question_data/map_mask_placeholder',
                    (NUM_RECORD, constants.SPATIAL_MAP_HEIGHT, constants.SPATIAL_MAP_WIDTH, 27), dtype=np.uint8)
            dataset.create_dataset('question_data/meta_action_placeholder', (NUM_RECORD, 7), dtype=np.int32)
            dataset.create_dataset('question_data/possible_move_placeholder', (NUM_RECORD, 31), dtype=np.int32)

            dataset.create_dataset('question_data/answer_weight', (NUM_RECORD, 1), dtype=np.int32)
            dataset.create_dataset('question_data/taken_action', (NUM_RECORD, 32), dtype=np.int32)
            dataset.create_dataset('question_data/new_episode', (NUM_RECORD, 1), dtype=np.int32)

        # Guards the shared bookkeeping: ``times``, ``data_ind``, the HDF5
        # dump and the ``global_t`` counter.
        time_lock = threading.Lock()

        data_ind = 0

        def train_function(parallel_index):
            """Worker loop for the training thread at index ``parallel_index``."""
            global global_t
            global dataset
            global data_ind
            print('----------------------------------------thread', parallel_index, 'global_t', global_t)
            training_thread = training_threads[parallel_index]
            last_global_t = global_t
            last_global_t_image = global_t

            # Feed-dict entries whose name contains any of these are not
            # written to the dump.
            skipped_feed_keys = (
                'gru_placeholder', 'num_unrolls',
                'reward_placeholder', 'td_placeholder',
                'learning_rate', 'phi_hat_prev_placeholder',
                'next_map_mask_placeholder', 'next_pose_placeholder',
                'episode_length_placeholder', 'question_count_placeholder',
                'supervised_action_labels', 'supervised_action_weights_sigmoid',
                'supervised_action_weights', 'question_direction_placeholder',
            )

            while global_t < constants.MAX_TIME_STEP:
                diff_global_t, ep_length, ep_reward, num_unrolls, feed_dict = training_thread.process(global_t, summary_writer,
                        summary_ops, summary_placeholders)
                # ``with`` releases the lock even when the body raises, so a
                # failing thread can no longer leave the others blocked.
                with time_lock:
                    if not constants.DEBUG and constants.RECORD_FEED_DICT:
                        if data_ind >= NUM_RECORD:
                            # Another thread already filled and closed the
                            # dump; there is nothing left to record.
                            return
                        print('    NEW ENTRY: %d %s' % (data_ind, training_thread.agent.game_state.scene_name))
                        dataset['question_data/new_episode'][data_ind] = 1
                        for k, v in feed_dict.items():
                            # Dataset name = last path component of the
                            # tensor name, without the ":<index>" suffix.
                            key = 'question_data/' + k.name.split('/')[-1].split(':')[0]
                            if any(s in key for s in skipped_feed_keys):
                                continue
                            v = np.ascontiguousarray(v)
                            if v.shape[0] == 1:
                                v = v[0, ...]
                            if len(v.shape) == 1 and num_unrolls > 1:
                                v = v[:, np.newaxis]
                            if 'map_mask_placeholder' in key:
                                # Rescale and shift the first two channels
                                # before storing (the dataset is uint8).
                                # NOTE(review): presumably maps them into a
                                # non-negative range -- confirm.
                                v[:, :, :, :2] *= constants.STEPS_AHEAD
                                v[:, :, :, :2] += 2
                            # The slice is clipped at NUM_RECORD, so only
                            # write as many rows as actually fit.
                            data_loc = dataset[key][data_ind:data_ind + num_unrolls, ...]
                            data_len = data_loc.shape[0]
                            dataset[key][data_ind:data_ind + num_unrolls, ...] = v[:data_len, ...]
                        dataset.flush()

                        data_ind += num_unrolls
                        if data_ind >= NUM_RECORD:
                            # Everything is done
                            dataset.close()
                            raise Exception('Fake exception to exit the process. Don\'t worry, everything is fine.')
                    if ep_length > 0:
                        times.append((ep_length, ep_reward))
                        print('Num episodes', len(times), 'Episode means', np.mean(times, axis=0))
                    # Updated under the lock: ``+=`` on a global shared by
                    # all training threads is not atomic.
                    global_t += diff_global_t
                # periodically save checkpoints to disk
                if not (constants.DEBUG or constants.RECORD_FEED_DICT) and parallel_index == 0:
                    if global_t - last_global_t_image > 10000:
                        print('Ran conv image summary')
                        summary_im_str = sess.run(conv_image_summary)
                        summary_writer.add_summary(summary_im_str, global_t)
                        last_global_t_image = global_t
                    if global_t - last_global_t > 10000:
                        print('Save checkpoint at timestamp %d' % global_t)
                        tf_util.save(saver, sess, constants.CHECKPOINT_DIR, global_t)
                        last_global_t = global_t

        def test_function():
            """Run a test pass each time ``global_t`` has advanced by more than 10000."""
            global global_t
            last_global_t_test = 0
            while global_t < constants.MAX_TIME_STEP:
                time.sleep(1)
                if global_t - last_global_t_test > 10000:
                    # RUN TEST
                    sess.run(testing_thread.sync)
                    from game_state import QuestionGameState
                    if testing_thread.agent.game_state is None:
                        testing_thread.agent.game_state = QuestionGameState(sess=sess)
                    for q_type in constants.USED_QUESTION_TYPES:
                        answers_correct = 0.0
                        ep_lengths = 0.0
                        ep_rewards = 0.0
                        invalid_percents = 0.0
                        # Evaluate on a random sample of at most 16 rows.
                        rows = list(range(len(testing_thread.agent.game_state.test_datasets[q_type])))
                        random.shuffle(rows)
                        rows = rows[:16]
                        print('()()()()()()()rows', rows)
                        for rr,row in enumerate(rows):
                            answer_correct, answer, gt_answer, ep_length, ep_reward, invalid_percent, scene_num, seed, required_interaction = testing_thread.process((row, q_type))
                            answers_correct += int(answer_correct)
                            ep_lengths += ep_length
                            ep_rewards += ep_reward
                            invalid_percents += invalid_percent
                            print('############################### TEST ITERATION ##################################')
                            print('ep ', (rr + 1))
                            print('average correct', answers_correct / (rr + 1))
                            print('#################################################################################')
                        answers_correct /= len(rows)
                        ep_lengths /= len(rows)
                        ep_rewards /= len(rows)
                        invalid_percents /= len(rows)

                        # Write the summary (there is no writer in DEBUG mode).
                        if summary_writer is not None:
                            # ``q_type`` indexes the [exist, count, contains]
                            # list of merged test summaries.
                            test_summary_str = sess.run(test_summary_ops[q_type], feed_dict={
                                test_episode_reward_input : ep_rewards,
                                test_episode_length_input : ep_lengths,
                                test_exist_answer_correct_input : answers_correct,
                                test_count_answer_correct_input : answers_correct,
                                test_contains_answer_correct_input : answers_correct,
                                test_percent_invalid_actions_input : invalid_percents,
                                })
                            summary_writer.add_summary(test_summary_str, global_t)
                            summary_writer.flush()
                        last_global_t_test = global_t
                    testing_thread.agent.game_state.env.stop()
                    testing_thread.agent.game_state = None



        train_threads = []
        for i in range(constants.PARALLEL_SIZE):
            train_threads.append(threading.Thread(target=train_function, args=(i,)))
            train_threads[i].daemon = True

        for t in train_threads:
            t.start()

        if constants.RUN_TEST:
            test_thread = threading.Thread(target=test_function)
            test_thread.daemon = True
            test_thread.start()

        for t in train_threads:
            t.join()

        if constants.RUN_TEST:
            test_thread.join()

        if not constants.DEBUG:
            if not os.path.exists(constants.CHECKPOINT_DIR):
                os.makedirs(constants.CHECKPOINT_DIR)
            saver.save(sess, constants.CHECKPOINT_DIR + '/' + 'checkpoint', global_step = global_t)
            # The summary writer is closed exactly once, in ``finally``.
            print('Saved.')

    except KeyboardInterrupt:
        print('Press Ctrl+C to stop')
    except Exception:
        import traceback
        traceback.print_exc()
    finally:
        if not constants.DEBUG:
            # Only save when setup got far enough to create session and saver.
            can_save = saver is not None and sess is not None
            if can_save:
                print('Now saving data. Please wait')
                tf_util.save(saver, sess, constants.CHECKPOINT_DIR, global_t)
            if summary_writer is not None:
                summary_writer.close()
            if can_save:
                print('Saved.')
Ejemplo n.º 4
0
def _load_question_split(split_name):
    """Load and concatenate every question CSV of one dataset split.

    Reads ``questions/<split_name>/*/*.csv`` in sorted filename order through
    ``parse_question.get_sequences`` and concatenates the per-file values key
    by key.

    Returns:
        dict mapping each key produced by ``get_sequences`` to a numpy array
        built from the concatenated values.
    """
    filenames = sorted(
        glob.glob(os.path.join('questions', split_name, '*', '*.csv')))
    merged = {}
    for filename in filenames:
        for key, val in parse_question.get_sequences(filename).items():
            merged.setdefault(key, []).extend(val)
    return {key: np.array(val) for key, val in merged.items()}


def _question_feed_dict(network, data, start):
    """Build the feed dict for the ``constants.BATCH_SIZE`` rows of ``data`` from ``start``."""
    stop = start + constants.BATCH_SIZE
    return {
        network.question_placeholder: data['questions'][start:stop, ...],
        network.question_type_label: data['question_types'][start:stop, ...],
        network.object_id_label: data['object_ids'][start:stop, ...],
        network.container_id_label: data['container_ids'][start:stop, ...],
    }


def run():
    """Train the question embedding network, testing at the start of every epoch.

    Loads the ``train`` and ``train_test`` question splits, restores the
    network, then runs training steps until ``constants.MAX_TIME_STEP``
    iterations. At the first iteration and whenever the training data is
    exhausted, the data is reshuffled and (outside ``constants.DEBUG``) the
    mean test losses/accuracies are written to the summary log. The network
    is saved every 1000 iterations and once more on exit unless
    ``constants.DEBUG`` is set.
    """
    # Bound before the ``try`` so the ``finally`` clause can tell how far
    # setup got instead of failing with a NameError that hides the real error.
    network = None
    iteration = None
    log_writer = None
    try:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(constants.GPU_ID)
        sess = tf_util.Session()
        network = QuestionEmbeddingNetwork(sess, constants.BATCH_SIZE)
        network.create_network()
        network.create_train_ops()

        if not constants.DEBUG:
            log_writer = tf.summary.FileWriter(
                os.path.join(constants.LOG_FILE, py_util.get_time_str()),
                sess.graph)
            with tf.name_scope('train'):
                train_summary = tf.summary.merge([
                    tf.summary.scalar('total_loss', network.loss),
                    tf.summary.scalar('question_type_loss',
                                      network.question_type_loss),
                    tf.summary.scalar('object_id_loss',
                                      network.object_id_loss),
                    tf.summary.scalar('container_id_loss',
                                      network.container_id_loss),
                    tf.summary.scalar('question_type_accuracy',
                                      network.question_type_accuracy),
                    tf.summary.scalar('object_id_accuracy',
                                      network.object_id_accuracy),
                    tf.summary.scalar('container_id_accuracy',
                                      network.container_id_accuracy),
                ])

            with tf.name_scope('test'):
                # One placeholder carries all seven averaged test metrics, in
                # the same order as the ``sess.run`` fetch list in the test
                # loop below.
                test_summary_ph = tf.placeholder(tf.float32, [None],
                                                 'test_summary_placeholder')
                test_summary = tf.summary.merge([
                    tf.summary.scalar('total_loss', test_summary_ph[0]),
                    tf.summary.scalar('question_type_loss',
                                      test_summary_ph[1]),
                    tf.summary.scalar('object_id_loss', test_summary_ph[2]),
                    tf.summary.scalar('container_id_loss', test_summary_ph[3]),
                    tf.summary.scalar('question_type_accuracy',
                                      test_summary_ph[4]),
                    tf.summary.scalar('object_id_accuracy',
                                      test_summary_ph[5]),
                    tf.summary.scalar('container_id_accuracy',
                                      test_summary_ph[6]),
                ])

        data_train = _load_question_split('train')
        num_data_train = data_train['questions'].shape[0]

        data_test = _load_question_split('train_test')
        num_data_test = data_test['questions'].shape[0]

        sess.run(tf.global_variables_initializer())
        iteration = network.restore()
        sess.graph.finalize()
        num_its = 0

        data_train = shuffle_data(data_train)
        data_ind = 0
        total_iteration_time = 0
        while iteration < constants.MAX_TIME_STEP:
            t_start = time.time()
            if num_its == 0 or data_ind + constants.BATCH_SIZE > num_data_train:
                # end of epoch
                data_ind = 0
                data_train = shuffle_data(data_train)

                if not constants.DEBUG:
                    # run test
                    test_start = time.time()
                    test_losses = []
                    # Only full batches are evaluated; a trailing partial
                    # batch is dropped.
                    for test_data_ind in range(
                            0, num_data_test - constants.BATCH_SIZE + 1,
                            constants.BATCH_SIZE):
                        losses = sess.run(
                            [
                                network.loss, network.question_type_loss,
                                network.object_id_loss,
                                network.container_id_loss,
                                network.question_type_accuracy,
                                network.object_id_accuracy,
                                network.container_id_accuracy
                            ],
                            feed_dict=_question_feed_dict(
                                network, data_test, test_data_ind))
                        test_losses.append(losses)
                    # With fewer test rows than one batch there is nothing to
                    # average; skip the summary rather than feeding NaN.
                    if test_losses:
                        summary_losses = np.mean(test_losses, axis=0)
                        summary = sess.run(
                            test_summary,
                            feed_dict={test_summary_ph: summary_losses})
                        log_writer.add_summary(summary, iteration)
                        log_writer.flush()
                    print('test time               %.3f' %
                          (time.time() - test_start))

            feed_dict = _question_feed_dict(network, data_train, data_ind)
            # Log a training summary every 10th step (never in DEBUG mode).
            if num_its % 10 == 0 and not constants.DEBUG:
                _, summary = sess.run([network.train_op, train_summary],
                                      feed_dict=feed_dict)
                log_writer.add_summary(summary, iteration)
                log_writer.flush()
            else:
                sess.run(network.train_op, feed_dict=feed_dict)
            data_ind += constants.BATCH_SIZE

            t_end = time.time()
            iteration_time = t_end - t_start
            total_iteration_time += iteration_time
            if num_its % 100 == 0:
                print('\niteration               %d' % iteration)
                print('time per iteration      %.3f' % (total_iteration_time /
                                                        (num_its + 1)))
                print('time for last iteration %.3f' % (iteration_time))
            num_its += 1
            iteration += 1
            if iteration % 1000 == 0:
                network.save(iteration)
    except KeyboardInterrupt:
        pass
    except Exception:
        import traceback
        traceback.print_exc()
        print('error')
    finally:
        if not constants.DEBUG:
            # Save only once the network exists and ``restore`` has provided
            # an iteration count.
            if network is not None and iteration is not None:
                print('\nsaving')
                network.save(iteration)
            if log_writer is not None:
                log_writer.close()
    def create_dump():
        time_str = py_util.get_time_str()
        prefix = 'questions/'
        if not os.path.exists(prefix + dataset_type + '/data_counting'):
            os.makedirs(prefix + dataset_type + '/data_counting')

        h5 = h5py.File(
            prefix + dataset_type + '/data_counting/Counting_Questions_' +
            time_str + '.h5', 'w')
        h5.create_dataset('questions/question', (num_record, 4),
                          dtype=np.int32)
        print('Generating %d counting questions' % num_record)

        # Generate counting questions
        data_ind = 0
        episode = Episode()
        scene_number = random.randint(0, len(scene_numbers) - 1)
        while data_ind < num_record:
            k = 0

            scene_number += 1
            scene_num = scene_numbers[scene_number % len(scene_numbers)]

            scene_name = 'FloorPlan%d' % scene_num
            episode.initialize_scene(scene_name)
            num_tries = 0
            while num_tries < num_samples_per_scene:
                # randomly pick a pickable object in the scene
                object_class = random.choice(all_object_classes)
                question = CountQuestion(
                    object_class
                )  # randomly generate a general counting question
                generated = [None] * (constants.MAX_COUNTING_ANSWER + 1)
                generated_counts = set()

                num_tries += 1

                grid_file = 'layouts/%s-layout.npy' % scene_name
                xray_graph = graph_obj.Graph(grid_file,
                                             use_gt=True,
                                             construct_graph=False)
                scene_bounds = [
                    xray_graph.xMin, xray_graph.yMin,
                    xray_graph.xMax - xray_graph.xMin + 1,
                    xray_graph.yMax - xray_graph.yMin + 1
                ]

                for i in range(100):
                    if DEBUG:
                        print('starting try ', i)
                    scene_seed = random.randint(0, 999999999)
                    episode.initialize_episode(
                        scene_seed=scene_seed,
                        max_num_repeats=constants.MAX_COUNTING_ANSWER + 1,
                        remove_prob=0.5)
                    answer = question.get_answer(episode)
                    object_target = constants.OBJECT_CLASS_TO_ID[object_class]

                    if answer > 0 and answer not in generated_counts:
                        if DEBUG:
                            print('target', str(question), object_target,
                                  answer)
                        event = episode.event

                        # Make sure findable
                        try:
                            objs = {
                                obj['objectId']: obj
                                for obj in event.metadata['objects']
                                if obj['objectType'] == object_class
                            }
                            xray_graph.memory[:, :, 1:] = 0
                            for obj in objs.values():
                                obj_point = game_util.get_object_point(
                                    obj, scene_bounds)
                                xray_graph.memory[obj_point[1], obj_point[0],
                                                  object_target + 1] = 1
                            start_graph = xray_graph.memory.copy()

                            graph_points = xray_graph.points.copy()
                            graph_points = graph_points[np.random.permutation(
                                graph_points.shape[0]), :]
                            num_checked_points = 0
                            point_ind = 0

                            # Initial check to make sure all objects are visible on the grid.
                            while point_ind < len(graph_points):
                                start_point = graph_points[point_ind]
                                headings = np.random.permutation(4)
                                for heading in headings:
                                    start_point = (start_point[0],
                                                   start_point[1], heading)
                                    patch = xray_graph.get_graph_patch(
                                        start_point)[0][:, :,
                                                        object_target + 1]
                                    if patch.max() > 0:
                                        point_ind = 0
                                        xray_graph.update_graph(
                                            (np.zeros((constants.STEPS_AHEAD,
                                                       constants.STEPS_AHEAD,
                                                       1)), 0), start_point,
                                            [object_target + 1])
                                point_ind += 1
                            if np.max(xray_graph.memory[:, :, object_target +
                                                        1]) > 0:
                                if DEBUG:
                                    print('some points could not be reached')
                                answer = None
                                raise AssertionError

                            xray_graph.memory = start_graph
                            point_ind = 0
                            seen_objs = set()
                            while point_ind < len(graph_points):
                                start_point = graph_points[point_ind]
                                headings = np.random.permutation(4)
                                for heading in headings:
                                    start_point = (start_point[0],
                                                   start_point[1], heading)
                                    patch = xray_graph.get_graph_patch(
                                        start_point)[0]
                                    if patch[:, :,
                                             object_target + 1].max() > 0:
                                        action = {
                                            'action': 'TeleportFull',
                                            'x': start_point[0] *
                                            constants.AGENT_STEP_SIZE,
                                            'y': episode.agent_height,
                                            'z': start_point[1] *
                                            constants.AGENT_STEP_SIZE,
                                            'rotateOnTeleport': True,
                                            'rotation': start_point[2] * 90,
                                            'horizon': -30,
                                        }
                                        event = episode.env.step(action)
                                        num_checked_points += 1
                                        if num_checked_points > 20:
                                            if DEBUG:
                                                print('timeout')
                                            answer = None
                                            raise AssertionError
                                        changed = False

                                        for jj in range(4):
                                            open_success = True
                                            opened_objects = set()
                                            parents = [
                                                game_util.get_object(
                                                    obj['parentReceptacle'],
                                                    event.metadata)
                                                for obj in objs.values()
                                            ]
                                            openable_parents = [
                                                parent for parent in parents
                                                if parent['visible']
                                                and parent['openable']
                                                and not parent['isopen']
                                            ]
                                            while open_success:
                                                obj_list = list(objs.values())
                                                for obj in obj_list:
                                                    if obj['objectId'] in event.instance_detections2D:
                                                        if game_util.check_object_size(
                                                                event.
                                                                instance_detections2D[
                                                                    obj['objectId']]
                                                        ):
                                                            seen_objs.add(
                                                                obj['objectId']
                                                            )
                                                            if DEBUG:
                                                                print(
                                                                    'seen',
                                                                    seen_objs)
                                                            del objs[obj[
                                                                'objectId']]
                                                            changed = True
                                                            num_checked_points = 0
                                                            if len(seen_objs
                                                                   ) == answer:
                                                                raise AssertionError
                                                if len(openable_parents) > 0:
                                                    action = {
                                                        'action': 'OpenObject'
                                                    }
                                                    game_util.set_open_close_object(
                                                        action, event)
                                                    event = episode.env.step(
                                                        action)
                                                    open_success = event.metadata[
                                                        'lastActionSuccess']
                                                    if open_success:
                                                        opened_objects.add(
                                                            episode.env.
                                                            last_action[
                                                                'objectId'])
                                                else:
                                                    open_success = False
                                            for opened in opened_objects:
                                                event = episode.env.step({
                                                    'action':
                                                    'CloseObject',
                                                    'objectId':
                                                    opened,
                                                    'forceVisible':
                                                    True
                                                })
                                                if not event.metadata[
                                                        'lastActionSuccess']:
                                                    answer = None
                                                    raise AssertionError
                                            if jj < 3:
                                                event = episode.env.step(
                                                    {'action': 'LookDown'})
                                        if changed:
                                            point_ind = 0
                                            num_checked_points = 0
                                            xray_graph.memory[:, :,
                                                              object_target +
                                                              1] = 0
                                            for obj in objs.values():
                                                obj_point = game_util.get_object_point(
                                                    obj, scene_bounds)
                                                xray_graph.memory[
                                                    obj_point[1], obj_point[0],
                                                    object_target + 1] = 1
                                point_ind += 1
                            if DEBUG:
                                print('ran out of points')
                            answer = None
                        except AssertionError:
                            if answer is not None:
                                if DEBUG:
                                    print('success')
                            pass

                    print(str(question), object_target, answer)

                    if answer is not None and answer < len(
                            generated) and answer not in generated_counts:
                        generated[answer] = [
                            scene_num, scene_seed,
                            constants.OBJECT_CLASS_TO_ID[object_class], answer
                        ]
                        generated_counts.add(answer)
                        print('\tcounts', sorted(list(generated_counts)))

                    if len(generated_counts) == len(generated):
                        for q in generated:
                            if data_ind >= h5['questions/question'].shape[0]:
                                num_tries = 2**32
                                break
                            h5['questions/question'][data_ind, :] = np.array(q)
                            data_ind += 1
                            k += 1
                        h5.flush()
                        break
                print("# generated samples: {}".format(data_ind))

        h5.close()
        episode.env.stop_unity()
Ejemplo n.º 6
0
def main():
    """Evaluate the agent on the test questions with several worker threads.

    Builds the global QA network (and, when enabled by ``constants``, the
    navigation network and depth estimator), restores checkpoints, creates
    ``constants.PARALLEL_SIZE`` ``A3CTestingThread`` workers and lets them
    drain a shared, shuffled list of ``(row_index, question_type)`` episodes.
    One CSV line per episode is appended to a results file under
    ``constants.LOG_FILE`` and running statistics are printed to stdout.
    """
    if constants.OBJECT_DETECTION:
        from darknet_object_detection import detector
        detector.setup_detectors(constants.PARALLEL_SIZE)

    with tf.device('/gpu:' + str(constants.GPU_ID)):
        with tf.variable_scope('global_network'):
            if constants.END_TO_END_BASELINE:
                global_network = EndToEndBaselineNetwork()
            else:
                global_network = QAPlannerNetwork(constants.RL_GRU_SIZE, 1, 1)
            global_network.create_net()
        if constants.USE_NAVIGATION_AGENT:
            with tf.variable_scope('nav_global_network') as net_scope:
                free_space_network = FreeSpaceNetwork(constants.GRU_SIZE, 1, 1)
                free_space_network.create_net()
        else:
            net_scope = None

        # prepare session
        sess = tf_util.Session()

        if constants.PREDICT_DEPTH:
            from depth_estimation_network import depth_estimator
            with tf.variable_scope('') as depth_scope:
                # NOTE: rebinds the imported module name to the estimator
                # instance; the module itself is not needed afterwards.
                depth_estimator = depth_estimator.get_depth_estimator(sess)
        else:
            depth_scope = None

        sess.run(tf.global_variables_initializer())

        # Initialize pretrained weights after init.
        if constants.PREDICT_DEPTH:
            depth_estimator.load_weights()

        testing_threads = []

        for i in range(constants.PARALLEL_SIZE):
            testing_thread = A3CTestingThread(i, sess, net_scope, depth_scope)
            testing_threads.append(testing_thread)

        tf_util.restore_from_dir(sess, constants.CHECKPOINT_DIR, True)

        if constants.USE_NAVIGATION_AGENT:
            print('now trying to restore nav model')
            tf_util.restore_from_dir(
                sess, os.path.join(constants.CHECKPOINT_PREFIX, 'navigation'),
                True)

    sess.graph.finalize()

    # Build the work queue: one (row index, question type) pair per test
    # entry.  The dataset sizes are read from the last worker created, which
    # is the object the original loop variable pointed at.
    reference_thread = testing_threads[-1]
    question_types = constants.USED_QUESTION_TYPES
    rows = []
    for q_type in question_types:
        curr_rows = list(
            range(len(reference_thread.agent.game_state.test_datasets[q_type])))
        rows.extend(list(zip(curr_rows, [q_type] * len(curr_rows))))

    random.shuffle(rows)

    answers_correct = []
    ep_lengths = []
    ep_rewards = []
    invalid_percents = []
    # Guards both the shared ``rows`` queue and the result bookkeeping/CSV.
    time_lock = threading.Lock()
    if not os.path.exists(constants.LOG_FILE):
        os.makedirs(constants.LOG_FILE)
    out_file = open(
        constants.LOG_FILE + '/results_' + constants.TEST_SET + '_' +
        py_util.get_time_str() + '.csv', 'w')

    def test_function(thread_ind):
        """Worker loop: pop episodes from ``rows`` until it is empty."""
        testing_thread = testing_threads[thread_ind]
        sess.run(testing_thread.sync)
        while True:
            # ``with`` releases the lock on every exit path, including the
            # ``break``; the previous acquire()/break sequence left the lock
            # held forever and dead-locked the remaining workers.
            with time_lock:
                if not rows:
                    break
                row = rows.pop()

            (answer_correct, answer, gt_answer, ep_length, ep_reward,
             invalid_percent, scene_num, seed,
             required_interaction) = testing_thread.process(row)
            question_type = row[1] + 1

            with time_lock:
                output_str = (
                    '%d, %d, %d, %d, %d, %f, %d, %d, %d\n' %
                    (question_type, answer_correct, answer, gt_answer,
                     ep_length, invalid_percent, scene_num, seed,
                     required_interaction))
                out_file.write(output_str)
                out_file.flush()
                answers_correct.append(int(answer_correct))
                ep_lengths.append(ep_length)
                ep_rewards.append(ep_reward)
                invalid_percents.append(invalid_percent)
                print('###############################')
                print('ep ', row)
                print('num episodes', len(answers_correct))
                print('average correct', np.mean(answers_correct))
                print('invalid percents', np.mean(invalid_percents),
                      np.median(invalid_percents))
                print('###############################')

    try:
        out_file.write(constants.LOG_FILE + '\n')
        out_file.write(
            'question_type, answer_correct, answer, gt_answer, episode_length, invalid_action_percent, scene number, seed, required_interaction\n'
        )

        test_threads = []
        for i in range(constants.PARALLEL_SIZE):
            test_threads.append(
                threading.Thread(target=test_function, args=(i, )))

        for t in test_threads:
            t.start()

        for t in test_threads:
            t.join()
    finally:
        # Close the results file even if starting/joining the workers fails.
        out_file.close()
Ejemplo n.º 7
0
def run():
    """Train the QA planner network on pre-dumped supervised question data.

    Builds a ``QAPlannerNetwork`` under the ``global_network`` scope, restores
    the latest checkpoint from ``constants.CHECKPOINT_DIR`` and then, for each
    iteration up to ``constants.MAX_TIME_STEP``, samples a batch from the most
    recently modified ``question_data_dump/*.h5`` file, randomly blanks
    rectangular regions of the map memory as augmentation, and runs one
    training step.  Outside DEBUG/DRAWING mode, summaries are written every
    10 iterations and a checkpoint every 1000 iterations and on exit.

    Any exception is printed rather than propagated; the final checkpoint is
    still attempted if training got as far as running an iteration.
    """
    # Bound before the ``try`` so the ``finally`` clause can tell how far
    # setup got instead of raising NameError and masking the real failure.
    sess = None
    saver = None
    dataset = None
    iteration = None
    try:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(constants.GPU_ID)
        # Single integer batch size used everywhere below (the original
        # mixed int(BATCH_SIZE) and the raw constant).
        batch_size = int(constants.BATCH_SIZE)

        with tf.variable_scope('global_network'):
            network = QAPlannerNetwork(constants.RL_GRU_SIZE, batch_size, 1)
            network.create_net()
            training_step = network.training(network.rl_total_loss)

        # Image summaries for every conv weight whose first two dims are not
        # both 1.
        conv_var_list = [
            v for v in tf.trainable_variables()
            if 'conv' in v.name and 'weight' in v.name and
            (v.get_shape().as_list()[0] != 1 or v.get_shape().as_list()[1] != 1)
        ]
        for var in conv_var_list:
            tf_util.conv_variable_summaries(var, scope=var.name.replace('/', '_')[:-2])
        summary_with_images = tf.summary.merge_all()

        with tf.variable_scope('supervised_loss'):
            loss_ph = tf.placeholder(tf.float32)
            accs_ph = [tf.placeholder(tf.float32) for _ in range(4)]
            # NOTE(review): accs_ph[3] is fed below but has no scalar summary
            # here -- confirm whether a fourth accuracy curve was intended.
            loss_summary_op = tf.summary.merge([
                tf.summary.scalar('supervised_loss', loss_ph),
                tf.summary.scalar('acc_1_exist', accs_ph[0]),
                tf.summary.scalar('acc_2_count', accs_ph[1]),
                tf.summary.scalar('acc_3_contains', accs_ph[2]),
            ])

        # prepare session
        sess = tf_util.Session()
        sess.run(tf.global_variables_initializer())

        if not (constants.DEBUG or constants.DRAWING):
            from utils import py_util
            time_str = py_util.get_time_str()
            summary_writer = tf.summary.FileWriter(
                os.path.join(constants.LOG_FILE, time_str), sess.graph)
        else:
            summary_writer = None

        # init or load checkpoint with saver
        saver = tf.train.Saver(max_to_keep=3)
        start_it = tf_util.restore_from_dir(sess, constants.CHECKPOINT_DIR)

        sess.graph.finalize()

        import h5py
        # Most recently modified dump; opened read-only since it is only read.
        h5_file = sorted(glob.glob('question_data_dump/*.h5'), key=os.path.getmtime)[-1]
        dataset = h5py.File(h5_file, 'r')
        # Entries whose pose row sums to a positive value count as filled.
        num_entries = np.sum(np.sum(dataset['question_data/pose_placeholder'][...], axis=1) > 0)
        print('num_entries', num_entries)
        start_inds = dataset['question_data/new_episode'][:num_entries]
        # Candidate sample indices: positions i where new_episode[i + 1] != 1.
        start_inds = np.where(start_inds[1:] != 1)[0]

        curr_it = 0
        data_time_total = 0
        solver_time_total = 0
        total_time_total = 0

        for iteration in range(start_it, constants.MAX_TIME_STEP):
            if iteration == start_it or iteration % 10 == 1:
                current_time_start = time.time()
            t_start = time.time()

            # Sorted, duplicate-free index list for the HDF5 fancy-index reads.
            rand_inds = np.sort(np.random.choice(start_inds, batch_size, replace=False))
            rand_inds = rand_inds.tolist()

            existence_answer_placeholder = dataset['question_data/existence_answer_placeholder'][rand_inds]
            counting_answer_placeholder = dataset['question_data/counting_answer_placeholder'][rand_inds]

            question_type_placeholder = dataset['question_data/question_type_placeholder'][rand_inds]
            question_object_placeholder = dataset['question_data/question_object_placeholder'][rand_inds]
            question_container_placeholder = dataset['question_data/question_container_placeholder'][rand_inds]

            pose_placeholder = dataset['question_data/pose_placeholder'][rand_inds]
            image_placeholder = dataset['question_data/image_placeholder'][rand_inds]
            map_mask_placeholder = np.ascontiguousarray(
                dataset['question_data/map_mask_placeholder'][rand_inds], dtype=np.float32)

            meta_action_placeholder = dataset['question_data/meta_action_placeholder'][rand_inds]
            possible_move_placeholder = dataset['question_data/possible_move_placeholder'][rand_inds]

            taken_action = dataset['question_data/taken_action'][rand_inds]
            answer_weight = np.ones((batch_size, 1))

            # Shift and rescale the first two map channels in place.
            map_mask_placeholder[:, :, :, :2] -= 2
            map_mask_placeholder[:, :, :, :2] /= constants.STEPS_AHEAD

            # Un-augmented copy, only used for visualisation in DRAWING mode.
            map_mask_starts = map_mask_placeholder.copy()

            # Augmentation: zero channels 6+ inside random rectangles that
            # contain neither the queried object/container nor channel 2.
            for bb in range(batch_size):
                object_ind = int(question_object_placeholder[bb])
                question_type_ind = int(question_type_placeholder[bb])
                if question_type_ind in {2, 3}:
                    container_ind = np.argmax(question_container_placeholder[bb])

                # Bounding box of cells whose arg-max channel is not channel 0.
                max_map_inds = np.argmax(map_mask_placeholder[bb, ...], axis=2)
                map_range = np.where(max_map_inds > 0)
                map_range_x = (np.min(map_range[1]), np.max(map_range[1]))
                map_range_y = (np.min(map_range[0]), np.max(map_range[0]))
                for _ in range(random.randint(0, 100)):
                    tmp_patch_start = (random.randint(map_range_x[0], map_range_x[1]),
                                       random.randint(map_range_y[0], map_range_y[1]))
                    tmp_patch_end = (random.randint(map_range_x[0], map_range_x[1]),
                                     random.randint(map_range_y[0], map_range_y[1]))
                    patch_start = (min(tmp_patch_start[0], tmp_patch_end[0]),
                                   min(tmp_patch_start[1], tmp_patch_end[1]))
                    patch_end = (max(tmp_patch_start[0], tmp_patch_end[0]),
                                 max(tmp_patch_start[1], tmp_patch_end[1]))

                    patch = map_mask_placeholder[bb, patch_start[1]:patch_end[1], patch_start[0]:patch_end[0], :]
                    if question_type_ind in {2, 3}:
                        obj_mem = patch[:, :, 6 + container_ind] + patch[:, :, 6 + object_ind]
                    else:
                        obj_mem = patch[:, :, 6 + object_ind].copy()
                    obj_mem += patch[:, :, 2]  # make sure seen locations stay marked.
                    if patch.size > 0 and np.max(obj_mem) == 0:
                        map_mask_placeholder[bb, patch_start[1]:patch_end[1], patch_start[0]:patch_end[0], 6:] = 0

            feed_dict = {
                network.existence_answer_placeholder: np.ascontiguousarray(existence_answer_placeholder),
                network.counting_answer_placeholder: np.ascontiguousarray(counting_answer_placeholder),

                network.question_type_placeholder: np.ascontiguousarray(question_type_placeholder),
                network.question_object_placeholder: np.ascontiguousarray(question_object_placeholder),
                network.question_container_placeholder: np.ascontiguousarray(question_container_placeholder),
                network.question_direction_placeholder: np.zeros((batch_size, 4), dtype=np.float32),

                network.pose_placeholder: np.ascontiguousarray(pose_placeholder),
                network.image_placeholder: np.ascontiguousarray(image_placeholder),
                network.map_mask_placeholder: map_mask_placeholder,
                network.meta_action_placeholder: np.ascontiguousarray(meta_action_placeholder),
                network.possible_move_placeholder: np.ascontiguousarray(possible_move_placeholder),

                network.taken_action: np.ascontiguousarray(taken_action),
                network.answer_weight: np.ascontiguousarray(answer_weight),
                network.episode_length_placeholder: np.ones((batch_size)),
                network.question_count_placeholder: np.zeros((batch_size)),
            }
            # Insert a length-1 second axis into every fed array
            # (presumably the unroll/time dimension -- confirm against the network).
            new_feed_dict = {}
            for key, value in feed_dict.items():
                if len(value.squeeze().shape) > 1:
                    new_feed_dict[key] = np.reshape(value, [batch_size, 1] + list(value.squeeze().shape[1:]))
                else:
                    new_feed_dict[key] = np.reshape(value, [batch_size, 1])
            feed_dict = new_feed_dict
            feed_dict[network.taken_action] = np.reshape(feed_dict[network.taken_action], (batch_size, -1))
            feed_dict[network.gru_placeholder] = np.zeros((batch_size, constants.RL_GRU_SIZE))

            data_t_end = time.time()
            if constants.DEBUG or constants.DRAWING:
                outputs = sess.run(
                    [training_step, network.rl_total_loss, network.existence_answer, network.counting_answer,
                     network.possible_moves, network.memory_crops_rot, network.taken_action],
                    feed_dict=feed_dict)
            else:
                if iteration == start_it + 10:
                    # One fully traced step for the profiler view.
                    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()
                    outputs = sess.run([training_step, network.rl_total_loss, summary_with_images],
                                       feed_dict=feed_dict,
                                       options=run_options,
                                       run_metadata=run_metadata)
                    loss_summary = outputs[2]
                    summary_writer.add_run_metadata(run_metadata, 'step_%07d' % iteration)
                    summary_writer.add_summary(loss_summary, iteration)
                    summary_writer.flush()
                elif iteration % 10 == 0:
                    if iteration % 100 == 0:
                        outputs = sess.run(
                            [training_step, network.rl_total_loss, summary_with_images],
                            feed_dict=feed_dict)
                        loss_summary = outputs[2]
                    else:
                        outputs = sess.run(
                            [training_step, network.rl_total_loss,
                             network.existence_answer,
                             network.counting_answer,
                             ],
                            feed_dict=feed_dict)
                        outputs[2] = outputs[2].reshape(-1, 1)
                        # Per-question-type accuracy; np.maximum(1, ...) avoids
                        # dividing by zero when a type is absent from the batch.
                        acc_q0 = np.sum((existence_answer_placeholder == (outputs[2] > 0.5)) *
                                        (question_type_placeholder == 0)) / np.maximum(1, np.sum(question_type_placeholder == 0))
                        acc_q1 = np.sum((counting_answer_placeholder == np.argmax(outputs[3], axis=1)[..., np.newaxis]) *
                                        (question_type_placeholder == 1)) / np.maximum(1, np.sum(question_type_placeholder == 1))
                        acc_q2 = np.sum((existence_answer_placeholder == (outputs[2] > 0.5)) *
                                        (question_type_placeholder == 2)) / np.maximum(1, np.sum(question_type_placeholder == 2))
                        acc_q3 = np.sum((existence_answer_placeholder == (outputs[2] > 0.5)) *
                                        (question_type_placeholder == 3)) / np.maximum(1, np.sum(question_type_placeholder == 3))

                        curr_loss = outputs[1]
                        outputs = sess.run([loss_summary_op],
                                           feed_dict={
                                               accs_ph[0]: acc_q0,
                                               accs_ph[1]: acc_q1,
                                               accs_ph[2]: acc_q2,
                                               accs_ph[3]: acc_q3,
                                               loss_ph: curr_loss,
                                           })

                        loss_summary = outputs[0]
                        # Keep the loss at index 1 for the shared read below.
                        outputs.append(curr_loss)
                    summary_writer.add_summary(loss_summary, iteration)
                    summary_writer.flush()
                else:
                    outputs = sess.run([training_step, network.rl_total_loss],
                                       feed_dict=feed_dict)

            loss = outputs[1]
            solver_t_end = time.time()

            if constants.DEBUG or constants.DRAWING:
                # Look at outputs
                guess_bool = outputs[2].flatten()
                guess_count = outputs[3]
                possible_moves_pred = outputs[4]
                memory_crop = outputs[5]
                print('loss', loss)
                if constants.DRAWING:
                    import cv2
                    from utils import drawing
                for bb in range(batch_size):
                    if constants.DRAWING:
                        object_ind = int(question_object_placeholder[bb])
                        question_type_ind = question_type_placeholder[bb]
                        if question_type_ind == 1:
                            answer = counting_answer_placeholder[bb]
                            guess = guess_count[bb]
                        else:
                            answer = existence_answer_placeholder[bb]
                            guess = np.concatenate(([1 - guess_bool[bb]], [guess_bool[bb]]))
                        if question_type_ind[0] in {2, 3}:
                            container_ind = np.argmax(question_container_placeholder[bb])
                            obj_mem = np.flipud(map_mask_placeholder[bb, :, :, 6 + container_ind]).copy()
                            obj_mem += 2 * np.flipud(map_mask_placeholder[bb, :, :, 6 + object_ind])
                        else:
                            obj_mem = np.flipud(map_mask_placeholder[bb, :, :, 6 + object_ind])

                        possible = possible_move_placeholder[bb, ...].flatten()
                        possible = np.concatenate((possible, np.zeros(4)))
                        possible = possible.reshape(constants.STEPS_AHEAD + 2, constants.STEPS_AHEAD)

                        possible_pred = possible_moves_pred[bb, ...].flatten()
                        possible_pred = np.concatenate((possible_pred, np.zeros(4)))
                        possible_pred = possible_pred.reshape(constants.STEPS_AHEAD + 2, constants.STEPS_AHEAD)

                        mem2 = np.flipud(np.argmax(memory_crop[bb, ...], axis=2))
                        mem2[0, 0] = memory_crop.shape[-1] - 2

                        # Answer histogram
                        ans = guess
                        if len(ans) == 1:
                            ans = [ans[0], 1 - ans[0]]
                        ans_size = max(len(ans), 100)
                        ans_hist = np.zeros((ans_size, ans_size))
                        for ii, ans_i in enumerate(ans):
                            ans_hist[:max(int(np.round(ans_i * ans_size)), 1),
                                     int(ii * ans_size / len(ans)):int((ii + 1) * ans_size / len(ans))] = (ii + 1)
                        ans_hist = np.flipud(ans_hist)

                        image_list = [
                            image_placeholder[bb, ...],
                            ans_hist,
                            np.flipud(possible),
                            np.flipud(possible_pred),
                            mem2,
                            np.flipud(np.argmax(map_mask_starts[bb, :, :, 2:], axis=2)),
                            obj_mem,
                        ]
                        if question_type_ind == 0:
                            question_str = 'Ex Q: %s A: %s' % (constants.OBJECTS[object_ind], bool(answer))
                        elif question_type_ind == 1:
                            question_str = '# Q: %s A: %d' % (constants.OBJECTS[object_ind], answer)
                        elif question_type_ind[0] in {2, 3}:
                            # Types 2 and 3 both carry a container index (set
                            # above); type 3 previously left question_str
                            # unbound or stale from an earlier sample.
                            question_str = 'Q: %s in %s A: %s' % (
                                constants.OBJECTS[object_ind],
                                constants.OBJECTS[container_ind],
                                bool(answer))
                        else:
                            # Unknown question type: still show a caption
                            # rather than failing on an unbound name.
                            question_str = 'Q(type %d): %s' % (
                                int(question_type_ind[0]), constants.OBJECTS[object_ind])
                        image = drawing.subplot(image_list, 4, 2, constants.SCREEN_WIDTH, constants.SCREEN_HEIGHT,
                                                titles=[question_str, 'A: %s' % str(np.argmax(guess))], border=2)
                        cv2.imshow('image', image[:, :, ::-1])
                        cv2.waitKey(0)
                    else:
                        pdb.set_trace()

            if not (constants.DEBUG or constants.DRAWING) and (iteration % 1000 == 0 or iteration == constants.MAX_TIME_STEP - 1):
                saver_t_start = time.time()
                tf_util.save(saver, sess, constants.CHECKPOINT_DIR, iteration)
                saver_t_end = time.time()
                print('Saver:     %.3f' % (saver_t_end - saver_t_start))

            curr_it += 1

            data_time_total += data_t_end - t_start
            solver_time_total += solver_t_end - data_t_end
            total_time_total += time.time() - t_start

            if iteration == start_it or (iteration) % 10 == 0:
                print('Iteration: %d' % (iteration))
                print('Loss:      %.3f' % loss)
                print('Data:      %.3f' % (data_time_total / curr_it))
                print('Solver:    %.3f' % (solver_time_total / curr_it))
                print('Total:     %.3f' % (total_time_total / curr_it))
                print('Current:   %.3f\n' % ((time.time() - current_time_start) / min(10, curr_it)))

    except BaseException:
        # Deliberately catches everything, exactly like the bare ``except``
        # it replaces -- presumably so that Ctrl-C stops training and still
        # reaches the checkpoint save below.
        import traceback
        traceback.print_exc()
    finally:
        try:
            # Save final model, but only if training actually ran: without a
            # saver or a completed loop entry there is nothing to save.
            if (not (constants.DEBUG or constants.DRAWING)
                    and saver is not None and iteration is not None):
                tf_util.save(saver, sess, constants.CHECKPOINT_DIR, iteration)
        finally:
            if dataset is not None:
                dataset.close()