コード例 #1
0
def run():
    """Build the QA-planner graph and train it with asynchronous A3C workers.

    Sets up the global network (``EndToEndBaselineNetwork`` or
    ``QAPlannerNetwork``), optionally the navigation ``FreeSpaceNetwork`` and
    the depth estimator, one ``A3CTrainingThread`` per
    ``constants.PARALLEL_SIZE`` and, when ``constants.RUN_TEST`` is set, an
    ``A3CTestingThread``.  The workers run on Python threads until the
    module-level ``global_t`` reaches ``constants.MAX_TIME_STEP`` or, with
    ``constants.RECORD_FEED_DICT``, until the HDF5 feed-dict dump is full.

    Unless ``constants.DEBUG`` is set, a checkpoint of the trainable
    ``global_network`` variables is written on the way out (also after an
    error or Ctrl+C, provided setup got as far as restoring ``global_t``).

    Reads and writes the module globals ``global_t``, ``dataset`` and
    ``data_ind``.  Errors raised during setup or training are printed rather
    than re-raised.
    """
    global global_t
    global dataset
    global data_ind
    if constants.OBJECT_DETECTION:
        from darknet_object_detection import detector
        detector.setup_detectors(constants.PARALLEL_SIZE)

    os.environ["CUDA_VISIBLE_DEVICES"] = str(constants.GPU_ID)

    # Pre-bound so the ``finally`` clause never hits a NameError when setup
    # fails early.
    summary_writer = None
    # Becomes True once ``sess``/``saver`` exist and ``global_t`` is restored.
    ready_to_save = False
    # Set when the HDF5 dump is full (or on shutdown) so every worker stops.
    # A worker raising an exception only kills its own thread, so it cannot
    # be used to stop the others.
    stop_event = threading.Event()

    try:
        with tf.variable_scope('global_network'):
            if constants.END_TO_END_BASELINE:
                global_network = EndToEndBaselineNetwork()
            else:
                global_network = QAPlannerNetwork(constants.RL_GRU_SIZE, 1, 1)
            global_network.create_net()
        if constants.USE_NAVIGATION_AGENT:
            with tf.variable_scope('nav_global_network') as net_scope:
                free_space_network = FreeSpaceNetwork(constants.GRU_SIZE, 1, 1)
                free_space_network.create_net()
        else:
            net_scope = None

        # Image summaries for every conv weight whose kernel is not 1x1.
        conv_var_list = [v for v in tf.trainable_variables() if 'conv' in v.name and 'weight' in v.name and
                        (v.get_shape().as_list()[0] != 1 or v.get_shape().as_list()[1] != 1)]
        for var in conv_var_list:
            tf_util.conv_variable_summaries(var, scope=var.name.replace('/', '_')[:-2])
        conv_image_summary = tf.summary.merge_all()

        # prepare session
        sess = tf_util.Session()

        # Instantiate singletons without scope.
        if constants.PREDICT_DEPTH:
            from depth_estimation_network import depth_estimator
            with tf.variable_scope('') as depth_scope:
                depth_estimator = depth_estimator.get_depth_estimator(sess)
        else:
            depth_scope = None

        training_threads = []

        learning_rate_input = tf.placeholder(tf.float32, name='learning_rate')
        grad_applier = RMSPropApplier(learning_rate=learning_rate_input,
                decay=constants.RMSP_ALPHA,
                momentum=0.0,
                epsilon=constants.RMSP_EPSILON,
                clip_norm=constants.GRAD_NORM_CLIP)

        for i in range(constants.PARALLEL_SIZE):
            training_thread = A3CTrainingThread(
                    i, sess,
                    learning_rate_input,
                    grad_applier,
                    constants.MAX_TIME_STEP,
                    net_scope,
                    depth_scope)
            training_threads.append(training_thread)

        if constants.RUN_TEST:
            testing_thread = A3CTestingThread(constants.PARALLEL_SIZE + 1, sess, net_scope, depth_scope)

        sess.run(tf.global_variables_initializer())

        # Initialize pretrained weights after init.
        if constants.PREDICT_DEPTH:
            depth_estimator.load_weights()

        episode_reward_input = tf.placeholder(tf.float32, name='ep_reward')
        episode_length_input = tf.placeholder(tf.float32, name='ep_length')
        exist_answer_correct_input = tf.placeholder(tf.float32, name='exist_ans')
        count_answer_correct_input = tf.placeholder(tf.float32, name='count_ans')
        contains_answer_correct_input = tf.placeholder(tf.float32, name='contains_ans')
        percent_invalid_actions_input = tf.placeholder(tf.float32, name='inv')

        scalar_summaries = [
            tf.summary.scalar("Episode Reward", episode_reward_input),
            tf.summary.scalar("Episode Length", episode_length_input),
            tf.summary.scalar("Percent Invalid Actions", percent_invalid_actions_input),
        ]
        # No trailing commas: each of these must be a summary tensor, not a 1-tuple.
        exist_summary = tf.summary.scalar("Answer Correct Existence", exist_answer_correct_input)
        count_summary = tf.summary.scalar("Answer Correct Counting", count_answer_correct_input)
        contains_summary = tf.summary.scalar("Answer Correct Containing", contains_answer_correct_input)
        exist_summary_op = tf.summary.merge(scalar_summaries + [exist_summary])
        count_summary_op = tf.summary.merge(scalar_summaries + [count_summary])
        contains_summary_op = tf.summary.merge(scalar_summaries + [contains_summary])
        summary_ops = [exist_summary_op, count_summary_op, contains_summary_op]

        summary_placeholders = {
            "episode_reward_input": episode_reward_input,
            "episode_length_input": episode_length_input,
            "exist_answer_correct_input": exist_answer_correct_input,
            "count_answer_correct_input": count_answer_correct_input,
            "contains_answer_correct_input": contains_answer_correct_input,
            "percent_invalid_actions_input" : percent_invalid_actions_input,
            }

        if constants.RUN_TEST:
            test_episode_reward_input = tf.placeholder(tf.float32, name='test_ep_reward')
            test_episode_length_input = tf.placeholder(tf.float32, name='test_ep_length')
            test_exist_answer_correct_input = tf.placeholder(tf.float32, name='test_exist_ans')
            test_count_answer_correct_input = tf.placeholder(tf.float32, name='test_count_ans')
            test_contains_answer_correct_input = tf.placeholder(tf.float32, name='test_contains_ans')
            test_percent_invalid_actions_input = tf.placeholder(tf.float32, name='test_inv')

            test_scalar_summaries = [
                tf.summary.scalar("Test Episode Reward", test_episode_reward_input),
                tf.summary.scalar("Test Episode Length", test_episode_length_input),
                tf.summary.scalar("Test Percent Invalid Actions", test_percent_invalid_actions_input),
            ]
            exist_summary = tf.summary.scalar("Test Answer Correct Existence", test_exist_answer_correct_input)
            count_summary = tf.summary.scalar("Test Answer Correct Counting", test_count_answer_correct_input)
            contains_summary = tf.summary.scalar("Test Answer Correct Containing", test_contains_answer_correct_input)
            test_exist_summary_op = tf.summary.merge(test_scalar_summaries + [exist_summary])
            test_count_summary_op = tf.summary.merge(test_scalar_summaries + [count_summary])
            test_contains_summary_op = tf.summary.merge(test_scalar_summaries + [contains_summary])
            test_summary_ops = [test_exist_summary_op, test_count_summary_op, test_contains_summary_op]

        if not constants.DEBUG:
            time_str = py_util.get_time_str()
            summary_writer = tf.summary.FileWriter(os.path.join(constants.LOG_FILE, time_str), sess.graph)
        else:
            summary_writer = None

        # init or load checkpoint with saver
        vars_to_save = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='global_network')
        saver = tf.train.Saver(vars_to_save, max_to_keep=3)
        print('-------------- Looking for checkpoints in ', constants.CHECKPOINT_DIR)
        global_t = tf_util.restore_from_dir(sess, constants.CHECKPOINT_DIR)
        ready_to_save = True

        if constants.USE_NAVIGATION_AGENT:
            print('now trying to restore nav model')
            tf_util.restore_from_dir(sess, 'logs/checkpoints/navigation', True)

        sess.graph.finalize()

        for i in range(constants.PARALLEL_SIZE):
            sess.run(training_threads[i].sync)
            training_threads[i].agent.reset()

        times = []

        record_feed_dict = not constants.DEBUG and constants.RECORD_FEED_DICT
        if record_feed_dict:
            import h5py
            NUM_RECORD = 10000 * len(constants.USED_QUESTION_TYPES)
            time_str = py_util.get_time_str()
            if not os.path.exists('question_data_dump'):
                os.mkdir('question_data_dump')
            dataset = h5py.File('question_data_dump/maps_' + time_str + '.h5', 'w')

            dataset.create_dataset('question_data/existence_answer_placeholder', (NUM_RECORD, 1), dtype=np.int32)
            dataset.create_dataset('question_data/counting_answer_placeholder', (NUM_RECORD, 1), dtype=np.int32)

            dataset.create_dataset('question_data/question_type_placeholder', (NUM_RECORD, 1), dtype=np.int32)
            dataset.create_dataset('question_data/question_object_placeholder', (NUM_RECORD, 1), dtype=np.int32)
            dataset.create_dataset('question_data/question_container_placeholder', (NUM_RECORD, 21), dtype=np.int32)

            dataset.create_dataset('question_data/pose_placeholder', (NUM_RECORD, 3), dtype=np.int32)
            dataset.create_dataset('question_data/image_placeholder', (NUM_RECORD, 300, 300, 3), dtype=np.uint8)
            dataset.create_dataset('question_data/map_mask_placeholder',
                    (NUM_RECORD, constants.SPATIAL_MAP_HEIGHT, constants.SPATIAL_MAP_WIDTH, 27), dtype=np.uint8)
            dataset.create_dataset('question_data/meta_action_placeholder', (NUM_RECORD, 7), dtype=np.int32)
            dataset.create_dataset('question_data/possible_move_placeholder', (NUM_RECORD, 31), dtype=np.int32)

            dataset.create_dataset('question_data/answer_weight', (NUM_RECORD, 1), dtype=np.int32)
            dataset.create_dataset('question_data/taken_action', (NUM_RECORD, 32), dtype=np.int32)
            dataset.create_dataset('question_data/new_episode', (NUM_RECORD, 1), dtype=np.int32)

        # Feed-dict entries whose key contains any of these substrings are not
        # written to the HDF5 dump.
        skipped_feed_keys = (
            'gru_placeholder', 'num_unrolls',
            'reward_placeholder', 'td_placeholder',
            'learning_rate', 'phi_hat_prev_placeholder',
            'next_map_mask_placeholder', 'next_pose_placeholder',
            'episode_length_placeholder', 'question_count_placeholder',
            'supervised_action_labels', 'supervised_action_weights_sigmoid',
            'supervised_action_weights', 'question_direction_placeholder',
        )

        # Guards ``times``, ``dataset``/``data_ind`` and updates of ``global_t``.
        time_lock = threading.Lock()

        data_ind = 0

        def train_function(parallel_index):
            """Worker loop for training thread ``parallel_index``.

            Repeatedly runs one rollout/update, optionally records its feed
            dict, advances ``global_t`` and (thread 0 only) writes periodic
            image summaries and checkpoints.
            """
            global global_t
            global dataset
            global data_ind
            print('----------------------------------------thread', parallel_index, 'global_t', global_t)
            training_thread = training_threads[parallel_index]
            last_global_t = global_t
            last_global_t_image = global_t

            while global_t < constants.MAX_TIME_STEP and not stop_event.is_set():
                diff_global_t, ep_length, ep_reward, num_unrolls, feed_dict = training_thread.process(global_t, summary_writer,
                        summary_ops, summary_placeholders)
                # ``with`` releases the lock on every exit path, including errors.
                with time_lock:
                    # Another worker may have filled (and closed) the dump while
                    # this one was inside ``process``.
                    if stop_event.is_set():
                        break

                    if record_feed_dict:
                        print('    NEW ENTRY: %d %s' % (data_ind, training_thread.agent.game_state.scene_name))
                        dataset['question_data/new_episode'][data_ind] = 1
                        for k, v in feed_dict.items():
                            key = 'question_data/' + k.name.split('/')[-1].split(':')[0]
                            if any(s in key for s in skipped_feed_keys):
                                continue
                            v = np.ascontiguousarray(v)
                            if v.shape[0] == 1:
                                v = v[0, ...]
                            if len(v.shape) == 1 and num_unrolls > 1:
                                v = v[:, np.newaxis]
                            if 'map_mask_placeholder' in key:
                                v[:, :, :, :2] *= constants.STEPS_AHEAD
                                v[:, :, :, :2] += 2
                            # Clip the write to the rows still free in the dataset.
                            data_len = min(num_unrolls, dataset[key].shape[0] - data_ind)
                            dataset[key][data_ind:data_ind + num_unrolls, ...] = v[:data_len, ...]
                        dataset.flush()

                        data_ind += num_unrolls
                        if data_ind >= NUM_RECORD:
                            # Everything is done: close the file and tell every
                            # thread (including this one) to stop.
                            dataset.close()
                            stop_event.set()
                            print('Recorded %d entries; stopping all threads.' % data_ind)
                            break

                    if ep_length > 0:
                        times.append((ep_length, ep_reward))
                        print('Num episodes', len(times), 'Episode means', np.mean(times, axis=0))
                    # ``+=`` on a shared global is a read-modify-write; keep it
                    # under the lock so concurrent workers do not lose steps.
                    global_t += diff_global_t

                # periodically save checkpoints to disk
                if not (constants.DEBUG or constants.RECORD_FEED_DICT) and parallel_index == 0:
                    if global_t - last_global_t_image > 10000:
                        print('Ran conv image summary')
                        summary_im_str = sess.run(conv_image_summary)
                        summary_writer.add_summary(summary_im_str, global_t)
                        last_global_t_image = global_t
                    if global_t - last_global_t > 10000:
                        print('Save checkpoint at timestamp %d' % global_t)
                        tf_util.save(saver, sess, constants.CHECKPOINT_DIR, global_t)
                        last_global_t = global_t

        def test_function():
            """Every >10000 global steps, evaluate up to 16 test questions per type."""
            global global_t
            last_global_t_test = 0
            while global_t < constants.MAX_TIME_STEP and not stop_event.is_set():
                time.sleep(1)
                if global_t - last_global_t_test <= 10000:
                    continue
                # RUN TEST
                sess.run(testing_thread.sync)
                from game_state import QuestionGameState
                if testing_thread.agent.game_state is None:
                    testing_thread.agent.game_state = QuestionGameState(sess=sess)
                for q_type in constants.USED_QUESTION_TYPES:
                    rows = list(range(len(testing_thread.agent.game_state.test_datasets[q_type])))
                    random.shuffle(rows)
                    rows = rows[:16]
                    print('()()()()()()()rows', rows)
                    if not rows:
                        # Empty test split for this type: nothing to average.
                        continue
                    answers_correct = 0.0
                    ep_lengths = 0.0
                    ep_rewards = 0.0
                    invalid_percents = 0.0
                    for rr, row in enumerate(rows):
                        answer_correct, answer, gt_answer, ep_length, ep_reward, invalid_percent, scene_num, seed, required_interaction = testing_thread.process((row, q_type))
                        answers_correct += int(answer_correct)
                        ep_lengths += ep_length
                        ep_rewards += ep_reward
                        invalid_percents += invalid_percent
                        print('############################### TEST ITERATION ##################################')
                        print('ep ', (rr + 1))
                        print('average correct', answers_correct / (rr + 1))
                        print('#################################################################################')
                    num_rows = len(rows)
                    answers_correct /= num_rows
                    ep_lengths /= num_rows
                    ep_rewards /= num_rows
                    invalid_percents /= num_rows

                    # Write the summary (no writer exists in DEBUG mode).
                    if summary_writer is not None:
                        test_summary_str = sess.run(test_summary_ops[q_type], feed_dict={
                            test_episode_reward_input : ep_rewards,
                            test_episode_length_input : ep_lengths,
                            test_exist_answer_correct_input : answers_correct,
                            test_count_answer_correct_input : answers_correct,
                            test_contains_answer_correct_input : answers_correct,
                            test_percent_invalid_actions_input : invalid_percents,
                            })
                        summary_writer.add_summary(test_summary_str, global_t)
                        summary_writer.flush()
                last_global_t_test = global_t
                testing_thread.agent.game_state.env.stop()
                testing_thread.agent.game_state = None

        train_threads = [threading.Thread(target=train_function, args=(i,))
                         for i in range(constants.PARALLEL_SIZE)]
        for t in train_threads:
            t.daemon = True
            t.start()

        if constants.RUN_TEST:
            test_thread = threading.Thread(target=test_function)
            test_thread.daemon = True
            test_thread.start()

        for t in train_threads:
            t.join()

        if constants.RUN_TEST:
            test_thread.join()

    except KeyboardInterrupt:
        print('Press Ctrl+C to stop')
    except Exception:
        import traceback
        traceback.print_exc()
    finally:
        # Ask any still-running worker to wind down before the final save.
        stop_event.set()
        if not constants.DEBUG:
            # Single final save (previously the normal path saved twice).
            if ready_to_save:
                print('Now saving data. Please wait')
                if not os.path.exists(constants.CHECKPOINT_DIR):
                    os.makedirs(constants.CHECKPOINT_DIR)
                tf_util.save(saver, sess, constants.CHECKPOINT_DIR, global_t)
                print('Saved.')
            if summary_writer is not None:
                summary_writer.close()
コード例 #2
0
def run():
    """Train the navigation ``FreeSpaceNetwork`` from a replay buffer of episodes.

    ``constants.PARALLEL_SIZE`` daemon threads generate episodes with
    ``SequenceGenerator`` and store them in the module-level ``data_buffer``,
    tracking how often each entry was replayed in ``data_counts``.  The main
    thread assembles batches of ``constants.NUM_UNROLLS``-step windows from
    the buffer, runs the training op, and carries GRU/memory state across
    consecutive windows of the same episode.  It also writes TensorBoard
    summaries and saves a checkpoint every 500 iterations.

    In ``constants.DEBUG`` / ``constants.DRAWING`` mode no summaries or
    checkpoints are written.  Intermediate outputs are instead drawn with
    OpenCV (DRAWING) or inspected with ``pdb`` (DEBUG).  Errors, including
    Ctrl+C, are printed rather than re-raised.  A final checkpoint is then
    written if training had started.
    """
    # Pre-bound so the ``finally`` clause can tell how far setup got.
    sess = None
    saver = None
    iteration = None
    # Iteration of the most recent periodic checkpoint, to avoid re-saving it.
    last_saved_it = None
    try:
        with tf.variable_scope('nav_global_network'):
            network = FreeSpaceNetwork(constants.GRU_SIZE,
                                       constants.BATCH_SIZE,
                                       constants.NUM_UNROLLS)
            network.create_net()
            training_step = network.training_op

        with tf.variable_scope('loss'):
            loss_summary_op = tf.summary.merge([
                tf.summary.scalar('loss', network.loss),
            ])
        # NOTE(review): never used below; kept so the graph stays unchanged.
        summary_full = tf.summary.merge_all()
        # Image summaries for every conv weight whose kernel is not 1x1.
        conv_var_list = [
            v for v in tf.trainable_variables()
            if 'conv' in v.name and 'weight' in v.name and
            (v.get_shape().as_list()[0] != 1 or v.get_shape().as_list()[1] != 1)
        ]
        for var in conv_var_list:
            tf_util.conv_variable_summaries(
                var, scope=var.name.replace('/', '_')[:-2])
        summary_with_images = tf.summary.merge_all()

        # prepare session
        sess = tf_util.Session()

        # Per batch slot bb: [next step to read, episode length] of sequences[bb].
        seq_inds = np.zeros((constants.BATCH_SIZE, 2), dtype=np.int32)

        # One episode generator per loader thread.
        sequence_generators = [SequenceGenerator(sess)
                               for _ in range(constants.PARALLEL_SIZE)]

        sess.run(tf.global_variables_initializer())

        if not (constants.DEBUG or constants.DRAWING):
            from utils import py_util
            time_str = py_util.get_time_str()
            summary_writer = tf.summary.FileWriter(
                os.path.join(constants.LOG_FILE, time_str), sess.graph)
        else:
            summary_writer = None

        saver = tf.train.Saver(max_to_keep=3)

        # init or load checkpoint
        start_it = tf_util.restore_from_dir(sess, constants.CHECKPOINT_DIR)

        sess.graph.finalize()

        # Guards data_buffer / data_counts and the chosen-episode bookkeeping.
        data_lock = threading.Lock()

        def load_new_data(thread_index):
            """Loader loop: generate episodes forever and insert them into the buffer.

            While the buffer has room, new episodes are appended; once it is
            full, the most-replayed entry (largest ``data_counts`` value) is
            overwritten.
            """
            global data_buffer
            global data_counts

            sequence_generator = sequence_generators[thread_index]
            counter = 0
            while True:
                # Wait until there is room or some entry has been replayed.
                while not (len(data_buffer) < constants.REPLAY_BUFFER_SIZE
                           or np.max(data_counts) > 0):
                    time.sleep(1)
                counter += 1
                if constants.DEBUG:
                    print('\nThread %d' % thread_index)
                new_data, _, goal_pose = sequence_generator.generate_episode()
                # List of per-step dicts -> dict mapping each key to a per-step list.
                new_data = {key: [step[key] for step in new_data]
                            for key in new_data[0]}
                new_data['goal_pose'] = goal_pose
                new_data['memory'] = np.zeros(
                    (constants.SPATIAL_MAP_HEIGHT, constants.SPATIAL_MAP_WIDTH,
                     constants.MEMORY_SIZE))
                new_data['gru_state'] = np.zeros(constants.GRU_SIZE)
                if constants.DRAWING:
                    new_data['debug_images'] = sequence_generator.debug_images
                # ``with`` releases the lock even if anything below raises.
                with data_lock:
                    if len(data_buffer) < constants.REPLAY_BUFFER_SIZE:
                        data_counts[len(data_buffer)] = 0
                        data_buffer.append(new_data)
                        counts = data_counts[:len(data_buffer)]
                        if counter % 10 == 0:
                            print(
                                'Buffer size %d  Num used %d  Max used amount %d' %
                                (len(data_buffer), len(counts[counts > 0]),
                                 np.max(counts)))
                    else:
                        # Buffer full: overwrite the most-replayed episode.
                        max_count_ind = np.argmax(data_counts)
                        data_buffer[max_count_ind] = new_data
                        data_counts[max_count_ind] = 0
                        if counter % 10 == 0:
                            print('Num used %d  Max used amount %d' %
                                  (len(data_counts[data_counts > 0]),
                                   np.max(data_counts)))

        for i in range(constants.PARALLEL_SIZE):
            load_data_thread = threading.Thread(target=load_new_data,
                                                args=(i, ))
            load_data_thread.daemon = True
            load_data_thread.start()
            # Stagger thread start-up by one second.
            time.sleep(1)

        sequences = [None] * constants.BATCH_SIZE

        curr_it = 0
        # Running timing totals; seeded slightly above zero.
        dataTimeTotal = 0.00001
        solverTimeTotal = 0.00001
        summaryTimeTotal = 0.00001
        totalTimeTotal = 0.00001

        # Buffer indices currently assigned to a batch slot, and the
        # slot -> index mapping, so two slots never replay the same entry.
        chosen_inds = set()
        loc_to_chosen_ind = {}
        for iteration in range(start_it, constants.MAX_TIME_STEP):
            if iteration == start_it or iteration % 10 == 1:
                currentTimeStart = time.time()
            tStart = time.time()
            batch_data = []
            batch_action = []
            batch_memory = []
            batch_gru_state = []
            batch_labels = []
            batch_pose = []
            batch_mask = []
            batch_goal_pose = []
            batch_pose_indicator = []
            batch_possible_label = []
            batch_debug_images = []
            for bb in range(constants.BATCH_SIZE):
                if seq_inds[bb, 0] == seq_inds[bb, 1]:
                    # Pick a new random episode that has been replayed fewer
                    # than 100 times and is not in use by another slot.
                    pickable_inds = set(
                        np.where(data_counts < 100)[0]) - chosen_inds
                    while not pickable_inds:
                        time.sleep(1)
                        pickable_inds = set(
                            np.where(data_counts < 100)[0]) - chosen_inds
                    # random.sample() rejects sets on Python 3.11+; older
                    # versions converted the set to a tuple internally.
                    random_ind = random.sample(tuple(pickable_inds), 1)[0]
                    with data_lock:
                        sequences[bb] = data_buffer[random_ind]
                        sequences[bb]['memory'] = np.zeros(
                            (constants.SPATIAL_MAP_HEIGHT,
                             constants.SPATIAL_MAP_WIDTH, constants.MEMORY_SIZE))
                        sequences[bb]['gru_state'] = np.zeros(constants.GRU_SIZE)
                        data_counts[random_ind] += 1
                        if bb in loc_to_chosen_ind:
                            chosen_inds.remove(loc_to_chosen_ind[bb])
                        loc_to_chosen_ind[bb] = random_ind
                        chosen_inds.add(random_ind)
                    seq_inds[bb, 0] = 0
                    seq_inds[bb, 1] = len(sequences[bb]['color'])
                data_len = min(constants.NUM_UNROLLS,
                               seq_inds[bb, 1] - seq_inds[bb, 0])
                ind0 = seq_inds[bb, 0]
                ind1 = seq_inds[bb, 0] + data_len
                data = sequences[bb]['color'][ind0:ind1]
                action = sequences[bb]['action'][ind0:ind1]
                labels = sequences[bb]['label'][ind0:ind1]
                memory = sequences[bb]['memory'].copy()
                gru_state = sequences[bb]['gru_state'].copy()
                pose = sequences[bb]['pose'][ind0:ind1]
                goal_pose = sequences[bb]['goal_pose']
                mask = sequences[bb]['weight'][ind0:ind1]
                pose_indicator = sequences[bb]['pose_indicator'][ind0:ind1]
                possible_label = sequences[bb]['possible_label'][ind0:ind1]
                if constants.DRAWING:
                    batch_debug_images.append(
                        sequences[bb]['debug_images'][ind0:ind1])
                if data_len < constants.NUM_UNROLLS:
                    # Episode exhausted: pad to NUM_UNROLLS steps and reset the
                    # slot so a new episode is picked next iteration.
                    seq_inds[bb, :] = 0
                    pad = constants.NUM_UNROLLS - data_len
                    for field in (data, action, labels, mask, pose_indicator,
                                  possible_label):
                        field.extend([np.zeros_like(field[0])
                                      for _ in range(pad)])
                    # Poses are padded by repeating the last pose, not zeros.
                    pose.extend([pose[-1] for _ in range(pad)])
                else:
                    seq_inds[bb, 0] += constants.NUM_UNROLLS
                batch_data.append(data)
                batch_action.append(action)
                batch_memory.append(memory)
                batch_gru_state.append(gru_state)
                batch_pose.append(pose)
                batch_goal_pose.append(goal_pose)
                batch_labels.append(labels)
                batch_mask.append(mask)
                batch_pose_indicator.append(pose_indicator)
                batch_possible_label.append(possible_label)

            feed_dict = {
                network.image_placeholder:
                np.ascontiguousarray(batch_data),
                network.action_placeholder:
                np.ascontiguousarray(batch_action),
                network.gru_placeholder:
                np.ascontiguousarray(batch_gru_state),
                network.pose_placeholder:
                np.ascontiguousarray(batch_pose),
                network.goal_pose_placeholder:
                np.ascontiguousarray(batch_goal_pose),
                network.labels_placeholder:
                np.ascontiguousarray(batch_labels)[..., np.newaxis],
                network.mask_placeholder:
                np.ascontiguousarray(batch_mask),
                network.pose_indicator_placeholder:
                np.ascontiguousarray(batch_pose_indicator),
                network.possible_label_placeholder:
                np.ascontiguousarray(batch_possible_label),
                network.memory_placeholders:
                np.ascontiguousarray(batch_memory),
            }
            dataTEnd = time.time()
            summaryTime = 0
            if constants.DEBUG or constants.DRAWING:
                outputs = sess.run([
                    training_step, network.loss, network.gru_state,
                    network.patch_weights_sigm, network.gru_outputs_full,
                    network.is_possible_sigm,
                    network.pose_indicator_placeholder,
                    network.terminal_patches, network.gru_outputs
                ], feed_dict=feed_dict)
            elif iteration == start_it + 10:
                # One fully traced step for the TensorBoard profiler.
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
                outputs = sess.run([
                    training_step, network.loss, network.gru_state,
                    summary_with_images, network.gru_outputs
                ],
                                   feed_dict=feed_dict,
                                   options=run_options,
                                   run_metadata=run_metadata)
                loss_summary = outputs[3]
                summary_writer.add_run_metadata(run_metadata,
                                                'step_%07d' % iteration)
                summary_writer.add_summary(loss_summary, iteration)
                summary_writer.flush()
            elif iteration % 10 == 0:
                # Image summaries are only produced every 100 iterations.
                summary_op = (summary_with_images if iteration % 100 == 0
                              else loss_summary_op)
                outputs = sess.run([
                    training_step, network.loss, network.gru_state,
                    summary_op, network.gru_outputs
                ], feed_dict=feed_dict)
                loss_summary = outputs[3]
                summaryTStart = time.time()
                summary_writer.add_summary(loss_summary, iteration)
                summary_writer.flush()
                summaryTime = time.time() - summaryTStart
            else:
                outputs = sess.run([
                    training_step, network.loss, network.gru_state,
                    network.gru_outputs
                ], feed_dict=feed_dict)

            # Carry recurrent state into the next window of each episode.
            # The last fetch is ``network.gru_outputs`` in every branch above.
            gru_state_out = outputs[2]
            memory_out = outputs[-1]
            for mm in range(constants.BATCH_SIZE):
                sequences[mm]['memory'] = memory_out[mm, ...]
                sequences[mm]['gru_state'] = gru_state_out[mm, ...]

            loss = outputs[1]
            solverTEnd = time.time()

            if constants.DEBUG or constants.DRAWING:
                # Look at outputs
                patch_weights = outputs[3]
                is_possible = outputs[5]
                pose_indicator = outputs[6]
                terminal_patches = outputs[7]
                # Held for the whole inspection, which also pauses the loader
                # threads' buffer updates.
                with data_lock:
                    for bb in range(constants.BATCH_SIZE):
                        for tt in range(constants.NUM_UNROLLS):
                            if batch_mask[bb][tt] == 0:
                                break
                            if not constants.DRAWING:
                                pdb.set_trace()
                                continue
                            import cv2
                            from utils import drawing
                            label = np.flipud(batch_labels[bb][tt])
                            debug_images = batch_debug_images[bb][tt]
                            state_image = debug_images['state_image']

                            print('Possible pred %.3f' % is_possible[bb, tt])
                            print('Possible label %.3f' %
                                  batch_possible_label[bb][tt])
                            patch = np.flipud(patch_weights[bb, tt, ...])
                            patch_occupancy = patch[:, :, 0]
                            print('occ', patch_occupancy)
                            print('label', label)
                            terminal_patch = np.flipud(
                                np.sum(terminal_patches[bb, tt, ...], axis=2))
                            image_list = [
                                debug_images['color'],
                                state_image,
                                debug_images['label_memory'][:, :, 0],
                                debug_images['memory_map'][:, :, 0],
                                label[:, :],
                                patch_occupancy,
                                np.flipud(pose_indicator[bb, tt]),
                                terminal_patch,
                            ]

                            image = drawing.subplot(image_list, 4, 2,
                                                    constants.SCREEN_WIDTH,
                                                    constants.SCREEN_HEIGHT)
                            cv2.imshow('image', image[:, :, ::-1])
                            cv2.waitKey(0)

            if not (constants.DEBUG or constants.DRAWING) and (
                    iteration % 500 == 0
                    or iteration == constants.MAX_TIME_STEP - 1):
                saverTStart = time.time()
                tf_util.save(saver, sess, constants.CHECKPOINT_DIR, iteration)
                last_saved_it = iteration
                saverTEnd = time.time()
                print('Saver:     %.3f' % (saverTEnd - saverTStart))

            curr_it += 1

            dataTimeTotal += dataTEnd - tStart
            summaryTimeTotal += summaryTime
            solverTimeTotal += solverTEnd - dataTEnd - summaryTime
            totalTimeTotal += time.time() - tStart

            if iteration == start_it or iteration % 10 == 0:
                print('Iteration: %d' % (iteration))
                print('Loss:      %.3f' % loss)
                print('Data:      %.3f' % (dataTimeTotal / curr_it))
                print('Solver:    %.3f' % (solverTimeTotal / curr_it))
                print('Summary:   %.3f' % (summaryTimeTotal / curr_it))
                print('Total:     %.3f' % (totalTimeTotal / curr_it))
                print('Current:   %.3f\n' %
                      ((time.time() - currentTimeStart) / min(10, curr_it)))

    except (KeyboardInterrupt, Exception):
        # Ctrl+C is reported like any other error so the model still gets saved.
        import traceback
        traceback.print_exc()
    finally:
        # Save final model, unless training never started or this iteration
        # was already saved by the periodic checkpoint above.
        if (not (constants.DEBUG or constants.DRAWING)
                and iteration is not None and iteration != last_saved_it):
            tf_util.save(saver, sess, constants.CHECKPOINT_DIR, iteration)
コード例 #3
0
def main():
    """Evaluate a trained QA agent on the test questions and log per-episode results.

    Builds the global QA network (plus the navigation network and the depth
    estimator when enabled) and restores weights from
    ``constants.CHECKPOINT_DIR``. It then runs ``constants.PARALLEL_SIZE``
    worker threads. Each worker pops ``(row_index, question_type)`` pairs from
    a shared shuffled list and runs one episode through
    ``A3CTestingThread.process``. It appends one CSV line per episode to
    ``<LOG_FILE>/results_<TEST_SET>_<time>.csv``.
    """
    if constants.OBJECT_DETECTION:
        from darknet_object_detection import detector
        detector.setup_detectors(constants.PARALLEL_SIZE)

    with tf.device('/gpu:' + str(constants.GPU_ID)):
        with tf.variable_scope('global_network'):
            if constants.END_TO_END_BASELINE:
                global_network = EndToEndBaselineNetwork()
            else:
                global_network = QAPlannerNetwork(constants.RL_GRU_SIZE, 1, 1)
            global_network.create_net()
        if constants.USE_NAVIGATION_AGENT:
            with tf.variable_scope('nav_global_network') as net_scope:
                free_space_network = FreeSpaceNetwork(constants.GRU_SIZE, 1, 1)
                free_space_network.create_net()
        else:
            net_scope = None

        # prepare session
        sess = tf_util.Session()

        if constants.PREDICT_DEPTH:
            from depth_estimation_network import depth_estimator
            with tf.variable_scope('') as depth_scope:
                depth_estimator = depth_estimator.get_depth_estimator(sess)
        else:
            depth_scope = None

        sess.run(tf.global_variables_initializer())

        # Initialize pretrained weights after init so they are not overwritten.
        if constants.PREDICT_DEPTH:
            depth_estimator.load_weights()

        testing_threads = []
        for i in range(constants.PARALLEL_SIZE):
            testing_threads.append(
                A3CTestingThread(i, sess, net_scope, depth_scope))

        tf_util.restore_from_dir(sess, constants.CHECKPOINT_DIR, True)

        if constants.USE_NAVIGATION_AGENT:
            print('now trying to restore nav model')
            tf_util.restore_from_dir(
                sess, os.path.join(constants.CHECKPOINT_PREFIX, 'navigation'),
                True)

    sess.graph.finalize()

    # Enumerate every (row index, question type) pair to evaluate. The
    # datasets are read from the last-created testing thread's game state.
    # Presumably all threads load the same test datasets -- TODO confirm.
    test_datasets = testing_threads[-1].agent.game_state.test_datasets
    rows = []
    for q_type in constants.USED_QUESTION_TYPES:
        num_rows = len(test_datasets[q_type])
        rows.extend((row_ind, q_type) for row_ind in range(num_rows))

    random.shuffle(rows)

    answers_correct = []
    ep_lengths = []
    ep_rewards = []
    invalid_percents = []
    # Serializes access to `rows`, the result lists and the CSV file, which
    # are shared by all worker threads.
    time_lock = threading.Lock()
    if not os.path.exists(constants.LOG_FILE):
        os.makedirs(constants.LOG_FILE)
    results_path = (constants.LOG_FILE + '/results_' + constants.TEST_SET +
                    '_' + py_util.get_time_str() + '.csv')
    # `with` guarantees the CSV is closed even if starting/joining threads fails.
    with open(results_path, 'w') as out_file:
        out_file.write(constants.LOG_FILE + '\n')
        out_file.write(
            'question_type, answer_correct, answer, gt_answer, episode_length, invalid_action_percent, scene number, seed, required_interaction\n'
        )

        def test_function(thread_ind):
            """Worker loop: evaluate queued questions until the queue is empty."""
            testing_thread = testing_threads[thread_ind]
            sess.run(testing_thread.sync)
            while True:
                # Check and pop under the lock. Leaving the loop through `with`
                # releases the lock, so other workers can still observe the
                # empty queue and exit instead of blocking forever.
                with time_lock:
                    if not rows:
                        break
                    row = rows.pop()

                (answer_correct, answer, gt_answer, ep_length, ep_reward,
                 invalid_percent, scene_num, seed,
                 required_interaction) = testing_thread.process(row)
                # Question types are logged 1-based.
                question_type = row[1] + 1

                with time_lock:
                    output_str = (
                        '%d, %d, %d, %d, %d, %f, %d, %d, %d\n' %
                        (question_type, answer_correct, answer, gt_answer,
                         ep_length, invalid_percent, scene_num, seed,
                         required_interaction))
                    out_file.write(output_str)
                    out_file.flush()
                    answers_correct.append(int(answer_correct))
                    ep_lengths.append(ep_length)
                    ep_rewards.append(ep_reward)
                    invalid_percents.append(invalid_percent)
                    print('###############################')
                    print('ep ', row)
                    print('num episodes', len(answers_correct))
                    print('average correct', np.mean(answers_correct))
                    print('invalid percents', np.mean(invalid_percents),
                          np.median(invalid_percents))
                    print('###############################')

        test_threads = []
        for i in range(constants.PARALLEL_SIZE):
            test_threads.append(
                threading.Thread(target=test_function, args=(i, )))

        for t in test_threads:
            t.start()

        for t in test_threads:
            t.join()
コード例 #4
0
        ],
                             dtype=np.int32)[:2]
        return (self.states, self.bounds, goal_pose)


if __name__ == '__main__':
    from networks.free_space_network import FreeSpaceNetwork
    from utils import tf_util
    import tensorflow as tf
    sess = tf_util.Session()

    with tf.variable_scope('nav_global_network'):
        network = FreeSpaceNetwork(constants.GRU_SIZE, 1, 1)
        network.create_net()
    sess.run(tf.global_variables_initializer())
    start_it = tf_util.restore_from_dir(sess, constants.CHECKPOINT_DIR)

    import cv2

    sequence_generator = SequenceGenerator(sess)
    sequence_generator.planner_prob = 1
    counter = 0
    while True:
        states, bounds, goal_pose = sequence_generator.generate_episode()
        images = sequence_generator.debug_images
        for im_dict in images:
            counter += 1

            gt_map = (2 - im_dict['label_memory'][:, :, 0])

            image_list = [
コード例 #5
0
def run():
    """Supervised training loop for the QA planner network on dumped episodes.

    Loads the most recently modified ``question_data_dump/*.h5`` file and
    samples random batches of recorded steps. For each batch it randomly
    erases map-memory patches that hold no evidence for the queried
    object/container, then runs ``network.training(network.rl_total_loss)``
    for iterations ``start_it .. constants.MAX_TIME_STEP - 1``.

    Outside DEBUG/DRAWING mode it also:
    - writes TensorBoard summaries;
    - checkpoints every 1000 iterations;
    - saves a final checkpoint when the loop finishes or fails.

    In DEBUG mode it drops into ``pdb`` per batch item. In DRAWING mode it
    shows a visualization per batch item.
    """
    # Pre-bind everything the ``finally`` clause touches. A failure before
    # these are assigned is then reported by the ``except`` clause instead of
    # being replaced by a NameError raised while saving.
    sess = None
    saver = None
    summary_writer = None
    dataset = None
    iteration = None
    try:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(constants.GPU_ID)
        with tf.variable_scope('global_network'):
            network = QAPlannerNetwork(constants.RL_GRU_SIZE, int(constants.BATCH_SIZE), 1)
            network.create_net()
            training_step = network.training(network.rl_total_loss)

        # Summaries for every conv kernel that is not 1x1 (judged by its first two dims).
        conv_var_list = [v for v in tf.trainable_variables() if 'conv' in v.name and 'weight' in v.name and
                         (v.get_shape().as_list()[0] != 1 or v.get_shape().as_list()[1] != 1)]
        for var in conv_var_list:
            tf_util.conv_variable_summaries(var, scope=var.name.replace('/', '_')[:-2])
        summary_with_images = tf.summary.merge_all()

        # Scalars computed in Python and fed back in only so they reach TensorBoard.
        with tf.variable_scope('supervised_loss'):
            loss_ph = tf.placeholder(tf.float32)
            accs_ph = [tf.placeholder(tf.float32) for _ in range(4)]
            # NOTE(review): accs_ph[3] (question type 3 accuracy) is fed below
            # but has no scalar summary here -- confirm whether one was intended.
            loss_summary_op = tf.summary.merge([
                tf.summary.scalar('supervised_loss', loss_ph),
                tf.summary.scalar('acc_1_exist', accs_ph[0]),
                tf.summary.scalar('acc_2_count', accs_ph[1]),
                tf.summary.scalar('acc_3_contains', accs_ph[2]),
                ])

        # prepare session
        sess = tf_util.Session()
        sess.run(tf.global_variables_initializer())

        if not (constants.DEBUG or constants.DRAWING):
            from utils import py_util
            time_str = py_util.get_time_str()
            summary_writer = tf.summary.FileWriter(os.path.join(constants.LOG_FILE, time_str), sess.graph)
        else:
            summary_writer = None

        # init or load checkpoint with saver
        saver = tf.train.Saver(max_to_keep=3)
        start_it = tf_util.restore_from_dir(sess, constants.CHECKPOINT_DIR)

        sess.graph.finalize()

        import h5py
        # Train on the most recently modified question-data dump.
        h5_file = sorted(glob.glob('question_data_dump/*.h5'), key=os.path.getmtime)[-1]
        # Read-only: this loop never writes to the dump.
        dataset = h5py.File(h5_file, 'r')
        # Count rows whose pose sums to > 0; presumably the remaining rows are
        # unused preallocated space -- TODO confirm.
        num_entries = np.sum(np.sum(dataset['question_data/pose_placeholder'][...], axis=1) > 0)
        print('num_entries', num_entries)
        start_inds = dataset['question_data/new_episode'][:num_entries]
        # Candidate sample indices: i such that entry i + 1 is not flagged as a new episode.
        start_inds = np.where(start_inds[1:] != 1)[0]

        curr_it = 0
        data_time_total = 0
        solver_time_total = 0
        total_time_total = 0

        for iteration in range(start_it, constants.MAX_TIME_STEP):
            if iteration == start_it or iteration % 10 == 1:
                current_time_start = time.time()
            t_start = time.time()

            # h5py fancy indexing requires increasing indices, hence the sort.
            rand_inds = np.sort(np.random.choice(start_inds, int(constants.BATCH_SIZE), replace=False))
            rand_inds = rand_inds.tolist()

            existence_answer_placeholder = dataset['question_data/existence_answer_placeholder'][rand_inds]
            counting_answer_placeholder = dataset['question_data/counting_answer_placeholder'][rand_inds]

            question_type_placeholder = dataset['question_data/question_type_placeholder'][rand_inds]
            question_object_placeholder = dataset['question_data/question_object_placeholder'][rand_inds]
            question_container_placeholder = dataset['question_data/question_container_placeholder'][rand_inds]

            pose_placeholder = dataset['question_data/pose_placeholder'][rand_inds]
            image_placeholder = dataset['question_data/image_placeholder'][rand_inds]
            map_mask_placeholder = np.ascontiguousarray(dataset['question_data/map_mask_placeholder'][rand_inds])

            meta_action_placeholder = dataset['question_data/meta_action_placeholder'][rand_inds]
            possible_move_placeholder = dataset['question_data/possible_move_placeholder'][rand_inds]

            taken_action = dataset['question_data/taken_action'][rand_inds]
            answer_weight = np.ones((constants.BATCH_SIZE, 1))

            # Shift/scale the first two map channels as the network expects.
            map_mask_placeholder = np.ascontiguousarray(map_mask_placeholder, dtype=np.float32)
            map_mask_placeholder[:, :, :, :2] -= 2
            map_mask_placeholder[:, :, :, :2] /= constants.STEPS_AHEAD

            # Keep the un-augmented map for visualization.
            map_mask_starts = map_mask_placeholder.copy()

            # Augmentation: repeatedly pick a random rectangle inside the mapped
            # area. If it holds no evidence of the queried object (or container,
            # for types 2/3) and no channel-2 marks, clear its object channels.
            for bb in range(0, constants.BATCH_SIZE):
                object_ind = int(question_object_placeholder[bb])
                question_type_ind = int(question_type_placeholder[bb])
                if question_type_ind in {2, 3}:
                    container_ind = np.argmax(question_container_placeholder[bb])

                max_map_inds = np.argmax(map_mask_placeholder[bb, ...], axis=2)
                map_range = np.where(max_map_inds > 0)
                if map_range[0].size == 0:
                    # Nothing mapped for this item: there is no region to sample
                    # from, and np.min/np.max below would fail on empty input.
                    continue
                map_range_x = (np.min(map_range[1]), np.max(map_range[1]))
                map_range_y = (np.min(map_range[0]), np.max(map_range[0]))
                for jj in range(random.randint(0, 100)):
                    tmp_patch_start = (random.randint(map_range_x[0], map_range_x[1]),
                                       random.randint(map_range_y[0], map_range_y[1]))
                    tmp_patch_end = (random.randint(map_range_x[0], map_range_x[1]),
                                     random.randint(map_range_y[0], map_range_y[1]))
                    patch_start = (min(tmp_patch_start[0], tmp_patch_end[0]),
                                   min(tmp_patch_start[1], tmp_patch_end[1]))
                    patch_end = (max(tmp_patch_start[0], tmp_patch_end[0]),
                                 max(tmp_patch_start[1], tmp_patch_end[1]))

                    patch = map_mask_placeholder[bb, patch_start[1]:patch_end[1], patch_start[0]:patch_end[0], :]
                    if question_type_ind in {2, 3}:
                        obj_mem = patch[:, :, 6 + container_ind] + patch[:, :, 6 + object_ind]
                    else:
                        obj_mem = patch[:, :, 6 + object_ind].copy()
                    obj_mem += patch[:, :, 2]  # make sure seen locations stay marked.
                    if patch.size > 0 and np.max(obj_mem) == 0:
                        map_mask_placeholder[bb, patch_start[1]:patch_end[1], patch_start[0]:patch_end[0], 6:] = 0
            feed_dict = {
                network.existence_answer_placeholder: np.ascontiguousarray(existence_answer_placeholder),
                network.counting_answer_placeholder: np.ascontiguousarray(counting_answer_placeholder),

                network.question_type_placeholder: np.ascontiguousarray(question_type_placeholder),
                network.question_object_placeholder: np.ascontiguousarray(question_object_placeholder),
                network.question_container_placeholder: np.ascontiguousarray(question_container_placeholder),
                network.question_direction_placeholder: np.zeros((constants.BATCH_SIZE, 4), dtype=np.float32),

                network.pose_placeholder: np.ascontiguousarray(pose_placeholder),
                network.image_placeholder: np.ascontiguousarray(image_placeholder),
                network.map_mask_placeholder: map_mask_placeholder,
                network.meta_action_placeholder: np.ascontiguousarray(meta_action_placeholder),
                network.possible_move_placeholder: np.ascontiguousarray(possible_move_placeholder),

                network.taken_action: np.ascontiguousarray(taken_action),
                network.answer_weight: np.ascontiguousarray(answer_weight),
                network.episode_length_placeholder: np.ones((constants.BATCH_SIZE)),
                network.question_count_placeholder: np.zeros((constants.BATCH_SIZE)),
                }
            # Insert a length-1 time axis: (batch, ...) -> (batch, 1, ...).
            new_feed_dict = {}
            for key, value in feed_dict.items():
                if len(value.squeeze().shape) > 1:
                    new_feed_dict[key] = np.reshape(value, [int(constants.BATCH_SIZE), 1] + list(value.squeeze().shape[1:]))
                else:
                    new_feed_dict[key] = np.reshape(value, [int(constants.BATCH_SIZE), 1])
            feed_dict = new_feed_dict
            feed_dict[network.taken_action] = np.reshape(feed_dict[network.taken_action], (constants.BATCH_SIZE, -1))
            feed_dict[network.gru_placeholder] = np.zeros((int(constants.BATCH_SIZE), constants.RL_GRU_SIZE))

            data_t_end = time.time()
            if constants.DEBUG or constants.DRAWING:
                outputs = sess.run(
                    [training_step, network.rl_total_loss, network.existence_answer, network.counting_answer,
                     network.possible_moves, network.memory_crops_rot, network.taken_action],
                    feed_dict=feed_dict)
            else:
                if iteration == start_it + 10:
                    # One fully traced step for TensorBoard profiling.
                    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()
                    outputs = sess.run([training_step, network.rl_total_loss, summary_with_images],
                                       feed_dict=feed_dict,
                                       options=run_options,
                                       run_metadata=run_metadata)
                    loss_summary = outputs[2]
                    summary_writer.add_run_metadata(run_metadata, 'step_%07d' % iteration)
                    summary_writer.add_summary(loss_summary, iteration)
                    summary_writer.flush()
                elif iteration % 10 == 0:
                    if iteration % 100 == 0:
                        outputs = sess.run(
                            [training_step, network.rl_total_loss, summary_with_images],
                            feed_dict=feed_dict)
                        loss_summary = outputs[2]
                    else:
                        outputs = sess.run(
                            [training_step, network.rl_total_loss,
                             network.existence_answer,
                             network.counting_answer,
                             ],
                            feed_dict=feed_dict)
                        outputs[2] = outputs[2].reshape(-1, 1)
                        # Per-question-type batch accuracy (0 when a type is absent).
                        acc_q0 = np.sum((existence_answer_placeholder == (outputs[2] > 0.5)) *
                                        (question_type_placeholder == 0)) / np.maximum(1, np.sum(question_type_placeholder == 0))
                        acc_q1 = np.sum((counting_answer_placeholder == np.argmax(outputs[3], axis=1)[..., np.newaxis]) *
                                        (question_type_placeholder == 1)) / np.maximum(1, np.sum(question_type_placeholder == 1))
                        acc_q2 = np.sum((existence_answer_placeholder == (outputs[2] > 0.5)) *
                                        (question_type_placeholder == 2)) / np.maximum(1, np.sum(question_type_placeholder == 2))
                        acc_q3 = np.sum((existence_answer_placeholder == (outputs[2] > 0.5)) *
                                        (question_type_placeholder == 3)) / np.maximum(1, np.sum(question_type_placeholder == 3))

                        curr_loss = outputs[1]
                        outputs = sess.run([loss_summary_op],
                                           feed_dict={
                                               accs_ph[0]: acc_q0,
                                               accs_ph[1]: acc_q1,
                                               accs_ph[2]: acc_q2,
                                               accs_ph[3]: acc_q3,
                                               loss_ph: curr_loss,
                                               })

                        loss_summary = outputs[0]
                        # Keep the loss at outputs[1] for the shared read below.
                        outputs.append(curr_loss)
                    summary_writer.add_summary(loss_summary, iteration)
                    summary_writer.flush()
                else:
                    outputs = sess.run([training_step, network.rl_total_loss],
                                       feed_dict=feed_dict)

            loss = outputs[1]
            solver_t_end = time.time()

            if constants.DEBUG or constants.DRAWING:
                # Look at outputs
                guess_bool = outputs[2].flatten()
                guess_count = outputs[3]
                possible_moves_pred = outputs[4]
                memory_crop = outputs[5]
                print('loss', loss)
                for bb in range(constants.BATCH_SIZE):
                    if constants.DRAWING:
                        import cv2
                        import scipy.misc
                        from utils import drawing
                        object_ind = int(question_object_placeholder[bb])
                        question_type_ind = int(question_type_placeholder[bb])
                        if question_type_ind == 1:
                            answer = counting_answer_placeholder[bb]
                            guess = guess_count[bb]
                        else:
                            answer = existence_answer_placeholder[bb]
                            guess = np.concatenate(([1 - guess_bool[bb]], [guess_bool[bb]]))
                        if question_type_ind in {2, 3}:
                            container_ind = np.argmax(question_container_placeholder[bb])
                            obj_mem = np.flipud(map_mask_placeholder[bb, :, :, 6 + container_ind]).copy()
                            obj_mem += 2 * np.flipud(map_mask_placeholder[bb, :, :, 6 + object_ind])
                        else:
                            obj_mem = np.flipud(map_mask_placeholder[bb, :, :, 6 + object_ind])

                        # Pad the move grid with 4 zeros so it reshapes to (STEPS_AHEAD + 2, STEPS_AHEAD).
                        possible = possible_move_placeholder[bb, ...].flatten()
                        possible = np.concatenate((possible, np.zeros(4)))
                        possible = possible.reshape(constants.STEPS_AHEAD + 2, constants.STEPS_AHEAD)

                        possible_pred = possible_moves_pred[bb, ...].flatten()
                        possible_pred = np.concatenate((possible_pred, np.zeros(4)))
                        possible_pred = possible_pred.reshape(constants.STEPS_AHEAD + 2, constants.STEPS_AHEAD)

                        mem2 = np.flipud(np.argmax(memory_crop[bb, ...], axis=2))
                        mem2[0, 0] = memory_crop.shape[-1] - 2

                        # Answer histogram
                        ans = guess
                        if len(ans) == 1:
                            ans = [ans[0], 1 - ans[0]]
                        ans_size = max(len(ans), 100)
                        ans_hist = np.zeros((ans_size, ans_size))
                        for ii, ans_i in enumerate(ans):
                            ans_hist[:max(int(np.round(ans_i * ans_size)), 1),
                                     int(ii * ans_size / len(ans)):int((ii + 1) * ans_size / len(ans))] = (ii + 1)
                        ans_hist = np.flipud(ans_hist)

                        image_list = [
                            image_placeholder[bb, ...],
                            ans_hist,
                            np.flipud(possible),
                            np.flipud(possible_pred),
                            mem2,
                            np.flipud(np.argmax(map_mask_starts[bb, :, :, 2:], axis=2)),
                            obj_mem,
                            ]
                        if question_type_ind == 0:
                            question_str = 'Ex Q: %s A: %s' % (constants.OBJECTS[object_ind], bool(answer))
                        elif question_type_ind == 1:
                            question_str = '# Q: %s A: %d' % (constants.OBJECTS[object_ind], answer)
                        elif question_type_ind == 2:
                            question_str = 'Q: %s in %s A: %s' % (
                                constants.OBJECTS[object_ind],
                                constants.OBJECTS[container_ind],
                                bool(answer))
                        else:
                            # Any other type (e.g. 3) previously left question_str unbound.
                            question_str = 'Q%d: %s A: %s' % (
                                question_type_ind, constants.OBJECTS[object_ind], answer)
                        image = drawing.subplot(image_list, 4, 2, constants.SCREEN_WIDTH, constants.SCREEN_HEIGHT,
                                                titles=[question_str, 'A: %s' % str(np.argmax(guess))], border=2)
                        cv2.imshow('image', image[:, :, ::-1])
                        cv2.waitKey(0)
                    else:
                        pdb.set_trace()

            if not (constants.DEBUG or constants.DRAWING) and (iteration % 1000 == 0 or iteration == constants.MAX_TIME_STEP - 1):
                saver_t_start = time.time()
                tf_util.save(saver, sess, constants.CHECKPOINT_DIR, iteration)
                saver_t_end = time.time()
                print('Saver:     %.3f' % (saver_t_end - saver_t_start))

            curr_it += 1

            data_time_total += data_t_end - t_start
            solver_time_total += solver_t_end - data_t_end
            total_time_total += time.time() - t_start

            if iteration == start_it or (iteration) % 10 == 0:
                print('Iteration: %d' % (iteration))
                print('Loss:      %.3f' % loss)
                print('Data:      %.3f' % (data_time_total / curr_it))
                print('Solver:    %.3f' % (solver_time_total / curr_it))
                print('Total:     %.3f' % (total_time_total / curr_it))
                print('Current:   %.3f\n' % ((time.time() - current_time_start) / min(10, curr_it)))

    except Exception:
        import traceback
        traceback.print_exc()
    finally:
        # Save final model, but only if training actually reached the loop.
        if (not (constants.DEBUG or constants.DRAWING) and saver is not None
                and iteration is not None):
            tf_util.save(saver, sess, constants.CHECKPOINT_DIR, iteration)
        if summary_writer is not None:
            summary_writer.close()
        if dataset is not None:
            dataset.close()
コード例 #6
0
 def restore(self):
     """Restore ``self.sess`` from ``constants.CHECKPOINT_DIR``.

     Delegates to ``tf_util.restore_from_dir`` after logging the directory.

     Returns:
         The value ``tf_util.restore_from_dir`` returns. Elsewhere in this
         project that value is used as the iteration to resume from.
     """
     checkpoint_dir = os.path.join(constants.CHECKPOINT_DIR)
     print('loading ', checkpoint_dir)
     return tf_util.restore_from_dir(self.sess, checkpoint_dir)
コード例 #7
0
# tf.Session() with no explicit graph binds to the current default graph, and
# every op below (tracker, detector) is built there. A bare
# `tf.Graph().as_default()` call only creates an unused context manager, so
# it is not needed here.
sess = tf.Session(config=config)
keras.backend.set_session(sess)

# load tracker
tracker = Tracker(sess,
                  MEM_SIZE,
                  IMG_SIZE,
                  FEATURE_SIZE,
                  ori_height=height,
                  ori_width=width,
                  iou_threshold=0.3,
                  kl_threshold=0.6)
sess.run(tf.global_variables_initializer())
log_dir = '/home/msis_dasol/master_thesis/RAN/for_paper/VGG16_skip_connection/memsize_5'
tf_util.restore_from_dir(sess, os.path.join(log_dir, 'checkpoints'))

# load detector
yolov3 = YOLOv3(sess)
total_tracking_obejct = 0

# Bytes held by each global variable: element count * bytes per element.
# np.prod replaces np.product (deprecated in NumPy 1.25, removed in 2.0).
var_sizes = [
    np.prod(list(map(int, v.shape))) * v.dtype.size
    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
]
print('total parameter size :', sum(var_sizes) / (1024**2), 'MB')

print('Start tracking..')
for image_name in image_list:
    # load image
    image_path = DATA_PATH + image_name