예제 #1
0
파일: cnn_model.py 프로젝트: b2220333/SLAM
 def __init__(self, batch_size, rgbd_input_batch, output_dim, normalization_epsilon,
              weights_path='resources/VGG_16_4ch.npy'):
     """Store the network configuration and load the initial (pretrained) weights.

     Args:
         batch_size: Stored as ``self.batch_size``.
         rgbd_input_batch: Network input, stored as ``self.network_input``
             (the name suggests an RGB-D batch; shape is not checked here).
         output_dim: Stored as ``self.output_dim``.
         normalization_epsilon: Stored as ``self.epsilon``; presumably the
             epsilon of a normalization layer -- confirm against its users.
         weights_path: ``.npy`` file holding the pickled parameter mapping.
             Defaults to the previously hard-coded location.
     """
     self.network_input = rgbd_input_batch
     self.output_dim = output_dim
     self.logger = get_logger()
     self.batch_size = batch_size
     self.total_weights = 0
     # The file holds a pickled mapping wrapped in an object array; ``.item()``
     # unwraps it. NumPy >= 1.16.3 refuses to unpickle object arrays unless
     # allow_pickle=True is passed explicitly.
     # NOTE(security): unpickling can execute arbitrary code -- only load
     # weight files from a trusted source.
     params = np.load(weights_path, allow_pickle=True).item()
     # Keys are converted to UTF-8 bytes; downstream lookups presumably depend
     # on this key type -- TODO confirm before changing it.
     self.initial_params = {key.encode('utf-8'): value for key, value in params.items()}
     self.logger.info('Weight keys:{}'.format(self.initial_params.keys()))
     self.epsilon = normalization_epsilon
예제 #2
0
파일: associate.py 프로젝트: b2220333/SLAM
def create_association_data(base_dir):
    """Write an ``associate.txt`` file into every dataset directory of *base_dir*.

    For each immediate sub-directory ``<base_dir>/<name>``, the result of
    ``get_association(<name>/rgb.txt, <name>/depth.txt)`` is written to
    ``<base_dir>/<name>/associate.txt``, overwriting any existing file.

    Plain files directly inside *base_dir* are skipped: they cannot contain
    ``rgb.txt``/``depth.txt``, so building paths beneath them is meaningless.

    Args:
        base_dir: Directory whose immediate sub-directories are datasets.
    """
    logger = get_logger()
    # os.listdir order is arbitrary; sort for deterministic processing and logs.
    for filename in sorted(os.listdir(base_dir)):
        dataset_dir = os.path.join(base_dir, filename)
        if not os.path.isdir(dataset_dir):
            logger.info('Skipping non-directory entry :{}'.format(filename))
            continue
        logger.info('Associating data set :{}'.format(filename))
        rgb_file = os.path.join(dataset_dir, 'rgb.txt')
        depth_file = os.path.join(dataset_dir, 'depth.txt')
        association_file = os.path.join(dataset_dir, 'associate.txt')
        association_data = get_association(rgb_file, depth_file)
        with open(association_file, 'w') as fw:
            fw.write(association_data)
        logger.info('Wrote association data to :{}'.format(association_file))
예제 #3
0
import os

import tensorflow as tf
from slam.network.model_input import PoseNetInputProvider
from slam.utils.logging_utils import get_logger
from slam.network.google_net_noLRN import GoogleNet
"""
Posenet impl.
"""

if __name__ == '__main__':
    img_h = 224
    img_w = 224
    input_provider = PoseNetInputProvider()
    logger = get_logger()

    base_dir = '/usr/prakt/s085/posenet/'
    LOG_DIR = os.path.join(base_dir, 'logs/')
    LEARNED_WEIGHTS_FILENAME = os.path.join(
        base_dir, 'checkpoints/learned_weights.ckpt')

    epoch = 1000
    batch_size = 1

    rgb_input_batch = tf.placeholder(tf.float32, [batch_size, img_h, img_w, 3],
                                     name='rgbd_input')
    groundtruth_batch = tf.placeholder(tf.float32, [batch_size, 7],
                                       name='groundtruth')

    google_net = GoogleNet({'data': rgb_input_batch}, 7)
    loss = google_net.add_l2_loss(groundtruth_batch)
예제 #4
0
 def __init__(self, sequence_info, batch_size):
     """Remember the sequence description and batch size; set ``index`` to 0.

     Args:
         sequence_info: Kept as ``self.sequence_info``; its structure is not
             visible here -- presumably per-sequence metadata, confirm.
         batch_size: Kept as ``self.batch_size``.
     """
     self.sequence_info, self.index, self.batch_size = sequence_info, 0, batch_size
     self.logger = get_logger()
예제 #5
0
 def __init__(self, input_provider, seqdir_vs_offset, sequence_length):
     """Capture the collaborators and settings; start ``counter`` at 0.

     Args:
         input_provider: Kept as ``self.input_provider``.
         seqdir_vs_offset: Kept as ``self.seqdir_vs_offset``; presumably a
             mapping of sequence directory to offset -- confirm with callers.
         sequence_length: Kept as ``self.sequence_length``.
     """
     (self.input_provider,
      self.seqdir_vs_offset,
      self.sequence_length) = (input_provider, seqdir_vs_offset, sequence_length)
     self.counter = 0
     self.logger = get_logger()
예제 #6
0
 def __init__(self):
     """Resolve the configured training files and size the batch to match.

     Each name returned by the config provider's ``training_filenames()`` is
     joined onto the class-level ``BASE_DATA_DIR``. ``batch_size`` equals the
     number of resulting paths.
     """
     self.config_provider = get_config_provider()
     names = self.config_provider.training_filenames()
     paths = []
     for name in names:
         paths.append(os.path.join(self.BASE_DATA_DIR, name))
     self.training_filenames = paths
     self.logger = get_logger()
     self.batch_size = len(paths)