Exemplo n.º 1
0
    def do_print(stream=None):
        """Write the current date, training flags header, module-level
        pixel_link parameters, and the dataset file list to *stream*.

        Args:
            stream: file-like object passed to ``print``/``pprint``;
                ``None`` writes to stdout.
        """
        print(util.log.get_date_str(), file=stream)
        print('\n# =========================================================================== #', file=stream)
        print('# Training flags:', file=stream)
        print('# =========================================================================== #', file=stream)

        def print_ckpt(path):
            # Report the latest checkpoint under *path*; True when one exists.
            # NOTE(review): never called in this block — looks like a
            # truncated copy of a sibling do_print; confirm before removing.
            ckpt = util.tf.get_latest_ckpt(path)
            if ckpt is not None:
                print('Resume Training from : %s' % (ckpt), file=stream)
                return True
            return False

        print('\n# =========================================================================== #', file=stream)
        print('# pixel_link net parameters:', file=stream)
        print('# =========================================================================== #', file=stream)
        # Renamed from ``vars`` to avoid shadowing the builtin.
        module_vars = globals()
        for key in module_vars:
            var = module_vars[key]
            # Only dump simple config-like values (numbers, strings, sequences).
            if util.dtype.is_number(var) or util.dtype.is_str(var) \
                    or util.dtype.is_list(var) or util.dtype.is_tuple(var):
                pprint('%s=%s' % (key, str(var)), stream=stream)

        print('\n# =========================================================================== #', file=stream)
        print('# Training | Evaluation dataset files:', file=stream)
        print('# =========================================================================== #', file=stream)
        data_files = parallel_reader.get_data_files(dataset.data_sources)
        pprint(sorted(data_files), stream=stream)
        print('', file=stream)
    def print_config(stream=None):
        """Dump run configuration to *stream* in three banner-delimited
        sections: flags, SSD net parameters, and dataset files."""
        bar = '# =========================================================================== #'

        def heading(title):
            # A leading newline visually separates consecutive sections.
            print('\n' + bar, file=stream)
            print(title, file=stream)
            print(bar, file=stream)

        heading('# Training | Evaluation flags:')
        for name in flags:
            print("param [ {} ]: {}".format(name, flags[name].value), file=stream)

        heading('# SSD net parameters:')
        pprint(dict(ssd_params._asdict()), stream=stream)

        heading('# Training | Evaluation dataset files:')
        pprint(sorted(parallel_reader.get_data_files(data_sources)), stream=stream)
        print('', file=stream)
Exemplo n.º 3
0
    def print_config(stream=None):
        """Pretty-print flags, SSD net parameters and the dataset file list
        to *stream*, each under its own banner."""
        rule = '# =========================================================================== #'

        def section(title, render):
            # Banner, title, banner — then the section body.
            print('\n' + rule, file=stream)
            print(title, file=stream)
            print(rule, file=stream)
            render()

        section('# Training | Evaluation flags:',
                lambda: pprint(flags, stream=stream))
        section('# SSD net parameters:',
                lambda: pprint(dict(ssd_params._asdict()), stream=stream))
        section('# Training | Evaluation dataset files:',
                lambda: pprint(sorted(parallel_reader.get_data_files(data_sources)),
                               stream=stream))
        print('', file=stream)
Exemplo n.º 4
0
    def do_print(stream=None):
        """Write the date, flags, checkpoint-resume status, module-level
        pixel_link parameters and the dataset file list to *stream*.

        Args:
            stream: file-like object passed to ``print``/``pprint``;
                ``None`` writes to stdout.
        """
        print(logger.get_date_str(), file=stream)
        print('\n# =========================================================================== #', file=stream)
        print('# Training flags:', file=stream)
        print('# =========================================================================== #', file=stream)

        def print_ckpt(path):
            # Report the latest checkpoint under *path*; True when one exists.
            ckpt = ckpt_util.get_latest_ckpt(path)
            if ckpt is not None:
                print('Resume Training from : %s' % (ckpt), file=stream)
                return True
            return False

        # Prefer resuming from train_dir; fall back to checkpoint_path.
        if not print_ckpt(flags.train_dir):
            print_ckpt(flags.checkpoint_path)

        pprint(flags.__flags, stream=stream)

        print('\n# =========================================================================== #', file=stream)
        print('# pixel_link net parameters:', file=stream)
        print('# =========================================================================== #', file=stream)
        # Renamed from ``vars`` to avoid shadowing the builtin; use
        # isinstance instead of ``type(x) == T`` comparisons.
        module_vars = globals()
        for key in module_vars:
            var = module_vars[key]
            # Only dump simple config-like values (numbers, strings, sequences).
            if is_number(var) or isinstance(var, (str, list, tuple)):
                pprint('%s=%s' % (key, str(var)), stream=stream)

        print('\n# =========================================================================== #', file=stream)
        print('# Training | Evaluation dataset files:', file=stream)
        print('# =========================================================================== #', file=stream)
        data_files = parallel_reader.get_data_files(dataset.data_sources)
        pprint(sorted(data_files), stream=stream)
        print('', file=stream)
Exemplo n.º 5
0
    def do_print(stream=None):
        """Write training flags, module-level seglink parameters and the
        dataset file list to *stream*.

        Args:
            stream: file-like object passed to ``print``/``pprint``;
                ``None`` writes to stdout.
        """
        print(
            '\n# =========================================================================== #',
            file=stream)
        print('# Training flags:', file=stream)
        print(
            '# =========================================================================== #',
            file=stream)
        pprint(flags.__flags, stream=stream)

        print(
            '\n# =========================================================================== #',
            file=stream)
        print('# seglink net parameters:', file=stream)
        print(
            '# =========================================================================== #',
            file=stream)
        # Renamed from ``vars`` to avoid shadowing the builtin.
        module_vars = globals()
        for key in module_vars:
            var = module_vars[key]
            # Only dump simple config-like values (numbers, strings, sequences).
            if util.dtype.is_number(var) or util.dtype.is_str(
                    var) or util.dtype.is_list(var) or util.dtype.is_tuple(
                        var):
                pprint('%s=%s' % (key, str(var)), stream=stream)

        print(
            '\n# =========================================================================== #',
            file=stream)
        print('# Training | Evaluation dataset files:', file=stream)
        print(
            '# =========================================================================== #',
            file=stream)
        data_files = parallel_reader.get_data_files(dataset.data_sources)
        pprint(sorted(data_files), stream=stream)
        print('', file=stream)
Exemplo n.º 6
0
    def do_print(stream=None):
        """Write the date, flags, checkpoint-resume status, module-level
        pixel_link parameters and the dataset file list to *stream*.

        Args:
            stream: file-like object passed to ``print``/``pprint``;
                ``None`` writes to stdout.
        """
        print(util.log.get_date_str(), file=stream)
        print('\n# =========================================================================== #', file=stream)
        print('# Training flags:', file=stream)
        print('# =========================================================================== #', file=stream)

        def print_ckpt(path):
            # Report the latest checkpoint under *path*; True when one exists.
            ckpt = util.tf.get_latest_ckpt(path)
            if ckpt is not None:
                print('Resume Training from : %s' % (ckpt), file=stream)
                return True
            return False

        # Prefer resuming from train_dir; fall back to checkpoint_path.
        if not print_ckpt(flags.train_dir):
            print_ckpt(flags.checkpoint_path)

        pprint(flags.__flags, stream=stream)

        print('\n# =========================================================================== #', file=stream)
        print('# pixel_link net parameters:', file=stream)
        print('# =========================================================================== #', file=stream)
        # Renamed from ``vars`` to avoid shadowing the builtin.
        module_vars = globals()
        for key in module_vars:
            var = module_vars[key]
            # Only dump simple config-like values (numbers, strings, sequences).
            if util.dtype.is_number(var) or util.dtype.is_str(var) \
                    or util.dtype.is_list(var) or util.dtype.is_tuple(var):
                pprint('%s=%s' % (key, str(var)), stream=stream)

        print('\n# =========================================================================== #', file=stream)
        print('# Training | Evaluation dataset files:', file=stream)
        print('# =========================================================================== #', file=stream)
        data_files = parallel_reader.get_data_files(dataset.data_sources)
        pprint(sorted(data_files), stream=stream)
        print('', file=stream)
Exemplo n.º 7
0
    def print_config(stream=None):
        """Write the SSD net parameters and the training dataset file list
        to *stream*, each under a banner."""
        separator = '# =========================================================================== #'

        print('\n' + separator, file=stream)
        print('# SSD 网络参数:', file=stream)
        print(separator, file=stream)
        pprint(dict(ssd_params._asdict()), stream=stream)

        print('\n' + separator, file=stream)
        print('# 训练数据dataset files:', file=stream)
        print(separator, file=stream)
        matched_files = parallel_reader.get_data_files(data_sources)
        pprint(sorted(matched_files), stream=stream)
        print('', file=stream)
  def _verify_read_up_to_out(self, shared_queue):
    """Read records in batches of up to 4 via ParallelReader.read_up_to and
    check every record of every file is counted exactly once.

    Args:
      shared_queue: queue instance shared by the parallel readers.
    """
    with self.test_session():
      num_files = 3
      num_records_per_file = 7
      tfrecord_paths = test_utils.create_tfrecord_files(
          self.get_temp_dir(),
          num_files=num_files,
          num_records_per_file=num_records_per_file)

    p_reader = parallel_reader.ParallelReader(
        io_ops.TFRecordReader, shared_queue, num_readers=5)

    data_files = parallel_reader.get_data_files(tfrecord_paths)
    # num_epochs=1 makes the producer raise OutOfRangeError after every
    # file has been consumed once, which terminates the read loop below.
    filename_queue = input_lib.string_input_producer(data_files, num_epochs=1)
    key, value = p_reader.read_up_to(filename_queue, 4)

    count0 = 0
    count1 = 0
    count2 = 0
    all_keys_count = 0
    all_values_count = 0

    sv = supervisor.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
      sv.start_queue_runners(sess)
      while True:
        try:
          current_keys, current_values = sess.run([key, value])
          # assertEquals is a deprecated alias of assertEqual.
          self.assertEqual(len(current_keys), len(current_values))
          all_keys_count += len(current_keys)
          all_values_count += len(current_values)
          # Keys embed the shard name ("i-of-3"); tally records per shard.
          for current_key in current_keys:
            if '0-of-3' in str(current_key):
              count0 += 1
            if '1-of-3' in str(current_key):
              count1 += 1
            if '2-of-3' in str(current_key):
              count2 += 1
        except errors_impl.OutOfRangeError:
          break

    # Each shard must contribute all of its records, and the totals must
    # account for every record exactly once.
    self.assertEqual(count0, num_records_per_file)
    self.assertEqual(count1, num_records_per_file)
    self.assertEqual(count2, num_records_per_file)
    self.assertEqual(
        all_keys_count,
        num_files * num_records_per_file)
    self.assertEqual(all_values_count, all_keys_count)
    self.assertEqual(
        count0 + count1 + count2,
        all_keys_count)
Exemplo n.º 9
0
    def _verify_read_up_to_out(self, shared_queue):
        """Read records in batches of up to 4 via ParallelReader.read_up_to
        and check every record of every file is counted exactly once.

        Args:
            shared_queue: queue instance shared by the parallel readers.
        """
        with self.test_session():
            num_files = 3
            num_records_per_file = 7
            tfrecord_paths = test_utils.create_tfrecord_files(
                self.get_temp_dir(),
                num_files=num_files,
                num_records_per_file=num_records_per_file)

        p_reader = parallel_reader.ParallelReader(io_ops.TFRecordReader,
                                                  shared_queue,
                                                  num_readers=5)

        data_files = parallel_reader.get_data_files(tfrecord_paths)
        # num_epochs=1 makes the producer raise OutOfRangeError after every
        # file has been consumed once, which terminates the loop below.
        filename_queue = input_lib.string_input_producer(data_files,
                                                         num_epochs=1)
        key, value = p_reader.read_up_to(filename_queue, 4)

        count0 = 0
        count1 = 0
        count2 = 0
        all_keys_count = 0
        all_values_count = 0

        sv = supervisor.Supervisor(logdir=self.get_temp_dir())
        with sv.prepare_or_wait_for_session() as sess:
            sv.start_queue_runners(sess)
            while True:
                try:
                    current_keys, current_values = sess.run([key, value])
                    # assertEquals is a deprecated alias of assertEqual.
                    self.assertEqual(len(current_keys), len(current_values))
                    all_keys_count += len(current_keys)
                    all_values_count += len(current_values)
                    # Keys embed the shard name ("i-of-3"); tally per shard.
                    for current_key in current_keys:
                        if '0-of-3' in str(current_key):
                            count0 += 1
                        if '1-of-3' in str(current_key):
                            count1 += 1
                        if '2-of-3' in str(current_key):
                            count2 += 1
                except errors_impl.OutOfRangeError:
                    break

        # Each shard must contribute all of its records, and the totals
        # must account for every record exactly once.
        self.assertEqual(count0, num_records_per_file)
        self.assertEqual(count1, num_records_per_file)
        self.assertEqual(count2, num_records_per_file)
        self.assertEqual(all_keys_count, num_files * num_records_per_file)
        self.assertEqual(all_values_count, all_keys_count)
        self.assertEqual(count0 + count1 + count2, all_keys_count)
Exemplo n.º 10
0
    def print_config(stream=None):
        """Dump flags, SSD net parameters and dataset files to *stream*,
        each section framed by a banner line."""
        line = '# =========================================================================== #'

        # Thunks keep evaluation order identical to printing order.
        for title, render in (
            ('# Training | Evaluation flags:',
             lambda: pprint(flags, stream=stream)),
            ('# SSD net parameters:',
             lambda: pprint(dict(ssd_params._asdict()), stream=stream)),
            ('# Training | Evaluation dataset files:',
             lambda: pprint(sorted(parallel_reader.get_data_files(data_sources)),
                            stream=stream)),
        ):
            print('\n' + line, file=stream)
            print(title, file=stream)
            print(line, file=stream)
            render()
        print('', file=stream)
    def _verify_all_data_sources_read(self, shared_queue):
        """Read 50 records with one reader per file and check that every
        shard contributed at least one record.

        Args:
            shared_queue: queue instance shared by the parallel readers.
        """
        with self.cached_session():
            tfrecord_paths = test_utils.create_tfrecord_files(
                self.get_temp_dir(), num_files=3)

        # One reader per input file so all shards are read concurrently.
        num_readers = len(tfrecord_paths)
        p_reader = parallel_reader.ParallelReader(io_ops.TFRecordReader,
                                                  shared_queue,
                                                  num_readers=num_readers)

        data_files = parallel_reader.get_data_files(tfrecord_paths)
        filename_queue = input_lib.string_input_producer(data_files)
        key, value = p_reader.read(filename_queue)

        count0 = 0
        count1 = 0
        count2 = 0

        num_reads = 50

        sv = supervisor.Supervisor(logdir=self.get_temp_dir())
        with sv.prepare_or_wait_for_session() as sess:
            sv.start_queue_runners(sess)

            # Keys embed the shard name ("i-of-3"); tally records per shard.
            for _ in range(num_reads):
                current_key, _ = sess.run([key, value])
                if '0-of-3' in str(current_key):
                    count0 += 1
                if '1-of-3' in str(current_key):
                    count1 += 1
                if '2-of-3' in str(current_key):
                    count2 += 1

        self.assertGreater(count0, 0)
        self.assertGreater(count1, 0)
        self.assertGreater(count2, 0)
        # assertEquals is a deprecated alias of assertEqual.
        self.assertEqual(count0 + count1 + count2, num_reads)
Exemplo n.º 12
0
  def _verify_all_data_sources_read(self, shared_queue):
    """Read 50 records with one reader per file and check that every shard
    contributed at least one record.

    Args:
      shared_queue: queue instance shared by the parallel readers.
    """
    with self.test_session():
      tfrecord_paths = test_utils.create_tfrecord_files(
          self.get_temp_dir(), num_files=3)

    # One reader per input file so all shards are read concurrently.
    num_readers = len(tfrecord_paths)
    p_reader = parallel_reader.ParallelReader(
        io_ops.TFRecordReader, shared_queue, num_readers=num_readers)

    data_files = parallel_reader.get_data_files(tfrecord_paths)
    filename_queue = input_lib.string_input_producer(data_files)
    key, value = p_reader.read(filename_queue)

    count0 = 0
    count1 = 0
    count2 = 0

    num_reads = 50

    sv = supervisor.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
      sv.start_queue_runners(sess)

      # Keys embed the shard name ("i-of-3"); tally records per shard.
      for _ in range(num_reads):
        current_key, _ = sess.run([key, value])
        if '0-of-3' in str(current_key):
          count0 += 1
        if '1-of-3' in str(current_key):
          count1 += 1
        if '2-of-3' in str(current_key):
          count2 += 1

    self.assertGreater(count0, 0)
    self.assertGreater(count1, 0)
    self.assertGreater(count2, 0)
    # assertEquals is a deprecated alias of assertEqual.
    self.assertEqual(count0 + count1 + count2, num_reads)