Example #1
    def _one_hotting(self):
        '''One-hot encoding is needed only when the pickle files don't
        exist, i.e. in the first training period.'''
        path = self._look_up_path()
        pickles = [_file for _file in os.listdir(path) if '.pk' in _file]

        if (not self.inference and hasattr(self, 'datadir')
                and len(pickles) == 0):
            if task_relates_to(self.task, 'classification'):
                ONE_HOTTOR['classification'](self.datadir, path)
            if task_relates_to(self.task, 'segmentation'):
                ONE_HOTTOR['segmentation'](self.seglabels, path)
        if self.inference or hasattr(self, 'tfrecordir'):
            assert len(pickles) > 0, "required pickle files not found."
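
The helpers `task_relates_to` and `SUPPORTED_BASE_TASKS` appear throughout these examples but are not part of the excerpts. A minimal sketch of what they might look like, assuming composite task names are joined with '-' (as in Example #8, where `self.task = '-'.join(task)`); the real definitions live elsewhere in dango and may differ:

    # Assumed sketch of the task helpers used in these examples.
    SUPPORTED_BASE_TASKS = ('classification', 'segmentation')
    SUPPORTED_TASKS = ('classification', 'segmentation',
                       'classification-segmentation')

    def task_relates_to(task, base_task):
        '''Return True if `base_task` is one component of `task`.'''
        return base_task in task.split('-')
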
Example #2
    def _generate_records(self, save_path, crxval_fold):
        '''Generate tfrecord files for a crxval fold.

        For a segmentation-related task, `instance` is `(image, mask)`;
        `label` is needed in the classification-segmentation task and
        can be omitted in the pure segmentation task.

        For a classification task, `instance` is just a str, and `mask`
        can be omitted.

        param: save_path: str
            path for saving tfrecord files.
        param: crxval_fold: tuple
            (crxval_fold_dict, crxval_fold_index).
        writes: tfrecord files for this crxval fold.
        '''
        save_path = os.path.join(
            save_path,
            "%.5d-of-%.5d.tfrecords" % (crxval_fold[1], self.num_crxval))
        with tf.python_io.TFRecordWriter(save_path) as writer:
            for label in crxval_fold[0].keys():
                for instance in crxval_fold[0][label]:
                    if task_relates_to(self.task, 'segmentation'):
                        assert isinstance(instance, (tuple, list))
                        image, mask = instance[0], instance[1]
                        examples = self._convert_to_record(image, label, mask)
                    if self.task == 'classification':
                        assert isinstance(instance, str)
                        examples = self._convert_to_record(instance, label)
                    for example in examples:
                        writer.write(example.SerializeToString())
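
`_convert_to_record` is not shown in the excerpt. As a hedged sketch, for a classification instance it might read the image bytes and wrap them, together with the label, into a `tf.train.Example`; the feature keys and label encoding below are assumptions, not dango's actual format:

    import tensorflow as tf

    def convert_to_record(image_path, label_index):
        '''Hypothetical stand-in for `self._convert_to_record`.'''
        with open(image_path, 'rb') as fp:
            image_bytes = fp.read()
        feature = {
            'image': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[image_bytes])),
            'classification': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[label_index])),
        }
        # yielded so the caller can iterate over one or more examples
        yield tf.train.Example(features=tf.train.Features(feature=feature))
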
Example #3
    def _load_pickles(self, pickles):
        '''In the training stage, whichever of the `datadir` or
        `tfrecordir` attributes exists, the path for storing pickle
        files can always be found by calling `ImageProducer._look_up_path`.

        In the inference stage, `datadir` may have changed. For instance,
        if `datadir` was `app` under `datalibs` in the training stage,
        the pickle files are in `datarecords.app.pickles`; if `datadir`
        is `app-infer` in the inference stage, then `pickles` should be
        provided so the pickle files in `datarecords.app.pickles` can
        still be found.

        param: pickles: str or None
            path for storing pickle files.
        '''
        self.feature_keys = ['image']
        pickles = pickles if self.inference else self._look_up_path()
        self.one_hotting_codes, self.label_codes = dict(), dict()
        for task in SUPPORTED_BASE_TASKS:
            if task_relates_to(self.task, task):
                _pickle = os.path.join(pickles, '{}.pk'.format(task))
                if not os.path.exists(_pickle):
                    raise IOError("{}.pk doesn't be provided.".format(task))
                with open(_pickle, 'rb') as fp:
                    _loaded = load(fp)
                    self.one_hotting_codes[task] = \
                        _loaded["{}-one-hotting".format(task)]
                    self.label_codes[task] = _loaded['{}-label'.format(task)]
                self.feature_keys.append(task)
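
Based on the keys read above, each `<task>.pk` file is expected to hold a dict with a `'<task>-one-hotting'` entry and a `'<task>-label'` entry. A sketch of writing a compatible classification pickle; the concrete value types are assumptions:

    from pickle import dump

    payload = {
        'classification-one-hotting': {'cat': [1.0, 0.0],
                                       'dog': [0.0, 1.0]},
        'classification-label': {0: 'cat', 1: 'dog'},
    }
    with open('classification.pk', 'wb') as fp:
        dump(payload, fp)
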
Example #4
    def _update_network(self, batches, sess):
        '''Optimize network.'''
        feed_dict = {
            self.network.input_images: batches['image'],
            self.network.eval_training: True,
            self.network.training: True
        }
        for task in SUPPORTED_BASE_TASKS:
            if task_relates_to(self.network.task, task):
                feed_dict[self.network.gts[task]] = batches[task]
        sess.run(self.network.optimizations, feed_dict=feed_dict)
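
A minimal, self-contained TF1-style illustration of the feed_dict pattern used above, with a toy graph standing in for `self.network`; all names, shapes and values here are made up:

    import numpy as np
    import tensorflow as tf

    # toy single-task graph: one dense layer trained with softmax cross-entropy
    images = tf.placeholder(tf.float32, [None, 4], name='input-images')
    labels = tf.placeholder(tf.float32, [None, 2], name='classification-labels')
    logits = tf.layers.dense(images, 2)
    loss = tf.losses.softmax_cross_entropy(labels, logits)
    optimizations = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        batches = {'image': np.ones((8, 4), np.float32),
                   'classification': np.tile([1.0, 0.0], (8, 1))}
        sess.run(optimizations, feed_dict={images: batches['image'],
                                           labels: batches['classification']})
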
Example #5
    def _check(self, datadir):
        '''Check that the attribute values are reasonable.

        param: datadir: str
            image data directory or tfrecord files path.
        '''
        # check `datadir` or `tfrecordir` attribute
        # if `datadir` is a tfrecords path, then the existence of tfrecord
        # files and pickle files will be checked.
        if datadir.rsplit(os.sep, 1)[-1] == 'records':  # tfrecords path
            if len([
                    _file
                    for _file in os.listdir(datadir) if '.tfrecords' in _file
            ]) <= 0:
                raise IOError("no threcord files exist.")
            self.tfrecordir = datadir

            _pickles = self._look_up_path()
            if (not os.path.exists(_pickles)
                    or len(os.listdir(_pickles)) <= 0):
                raise IOError("pickles must exist when using tfrecords.")

            self.using_records = True
            assert hasattr(self, 'num_crxval')
        else:  # image data directory
            if not os.path.isdir(datadir) or len(os.listdir(datadir)) <= 0:
                raise IOError("not found data in '{}'.".format(datadir))
            self.datadir = datadir

        # check `task`
        if self.task not in SUPPORTED_TASKS:
            raise ValueError("unrecognized task {}.".format(self.task))

        # check `seglabels`
        if task_relates_to(self.task, 'segmentation'):
            if (not isinstance(self.seglabels, (tuple, list))
                    or len(self.seglabels) <= 0):
                raise ValueError("no labels provided for segmentation.")

        # check `with_crop` and `with_pad`:
        # these two things are not consistent
        if all([self.with_crop, self.with_pad]):
            raise ValueError("'with_crop' and 'with_pad' should not "
                             "take True at the same time.")

        # check `input_size` and `input_channel`
        assert isinstance(self.input_channel, int)
        assert isinstance(self.input_size, (int, tuple, list))
        if isinstance(self.input_size, int):
            self.input_size = [self.input_size] * 2
        else:
            self.input_size = list(self.input_size)
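
A quick illustration of the dispatch rule at the top of `_check`: a path whose last component is `records` is treated as a tfrecords directory, anything else as an image data directory (the paths below are made up):

    import os

    for datadir in [os.path.join('datarecords', 'app', 'records'),
                    os.path.join('datalibs', 'app')]:
        is_records = datadir.rsplit(os.sep, 1)[-1] == 'records'
        print(datadir, '->',
              'tfrecords path' if is_records else 'image directory')
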
Example #6
File: loader.py Project: CourantNg/dango
    def _load_optimizer(self, _sys, _train):
        '''Load optimizer.'''
        self.optimizers = dict()
        self.simplex = _train['simplex']

        if self.simplex:
            optimizer = _train['optimizer'][0]
            init_lr = _train['lr'][0]
            
            if _train['decay_for'] is None:
                lr = load_lr('naive')(init_lr, name=self.task)
            else:
                lr = load_lr(_train['decay_type'][0])(
                    init_lr, _train['decay_step'][0],
                    _train['decay_rate'][0], name=self.task)
            self.optimizers[self.task] = load_optimizer(optimizer)(lr)
            msg = "Optimizer for [{}] has been loaded".format(self.task)
            tf.logging.info(msg)
        else:
            optimizers = Loader._convert_to_dict(
                _train['optimizer'], _sys['task'])
            init_lrs = Loader._convert_to_dict(
                _train['lr'], _sys['task'])
            
            decay_for = dict()
            for task in _sys['task']:
                decay_for[task] = False
                if _train['decay_for'] is None: continue
                if task in _train['decay_for']: decay_for[task] = True
            if _train['decay_for']:
                decay_types = Loader._convert_to_dict(
                    _train['decay_type'], _train['decay_for'])
                decay_steps = Loader._convert_to_dict(
                    _train['decay_step'], _train['decay_for'])
                decay_rates = Loader._convert_to_dict(
                    _train['decay_rate'], _train['decay_for'])

            for task in SUPPORTED_BASE_TASKS:
                if task_relates_to(self.task, task):
                    if decay_for[task]:
                        lr = load_lr(decay_types[task])(
                            init_lrs[task], decay_steps[task],
                            decay_rates[task], name=task)
                    else: lr = load_lr('naive')(init_lrs[task], name=task)
                    self.optimizers[task] = load_optimizer(
                        optimizers[task])(lr)
                    msg = "Optimizer for [{}] has been loaded.".format(task)
                    tf.logging.info(msg)
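
The `load_lr` and `load_optimizer` factories are not part of the excerpt. A hedged sketch of what they might return, built on the standard TF1 decay schedule and optimizer classes; dango's actual registries may differ:

    import tensorflow as tf

    def load_lr(decay_type):
        '''Hypothetical learning-rate factory.'''
        def naive(init_lr, name=None):
            return tf.constant(init_lr, name=name)

        def exponential(init_lr, decay_step, decay_rate, name=None):
            global_step = tf.train.get_or_create_global_step()
            return tf.train.exponential_decay(
                init_lr, global_step, decay_step, decay_rate, name=name)

        return {'naive': naive, 'exponential': exponential}[decay_type]

    def load_optimizer(name):
        '''Hypothetical optimizer factory; each entry is called with `lr`.'''
        return {'sgd': tf.train.GradientDescentOptimizer,
                'adam': tf.train.AdamOptimizer,
                'momentum': lambda lr: tf.train.MomentumOptimizer(lr, 0.9)}[name]
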
Example #7
    def _evaluate(self, phase):
        '''Evaluate training or validation performance.
        
        param: phase: str
            one of 'train' or 'validate'.
        return: merged:
            merged summary.
        '''
        all_collections = tf.get_collection(tf.GraphKeys.SUMMARIES)
        for task in SUPPORTED_BASE_TASKS:
            if task_relates_to(self.task, task):
                evaluation_name = "{}-{}-evaluation".format(task, phase)
                tf.summary.scalar(evaluation_name,
                                  self.evaluation[task],
                                  collections=[evaluation_name])
                all_collections += tf.get_collection(evaluation_name)

        return tf.summary.merge(all_collections)
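
A minimal, self-contained sketch of consuming a merged summary like the one returned above: evaluate it in a session and hand it to a `tf.summary.FileWriter` (`'./logs'` is just a placeholder directory):

    import tensorflow as tf

    value = tf.placeholder(tf.float32, shape=[])
    name = 'classification-train-evaluation'
    tf.summary.scalar(name, value, collections=[name])
    merged = tf.summary.merge(tf.get_collection(name))

    with tf.Session() as sess:
        writer = tf.summary.FileWriter('./logs', sess.graph)
        summary = sess.run(merged, feed_dict={value: 0.9})
        writer.add_summary(summary, global_step=0)
        writer.close()
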
Example #8
    def __init__(self,
                 input_size,
                 input_channel,
                 num_label,
                 batch=None,
                 loss=None,
                 task=None,
                 regularizer=None):
        '''
        param: input_size: int, list or tuple
            input size.
        param: input_channel: int
            input channel.
        param: num_label: int, tuple or dict
            number of labels for task(s) in SUPPORTED_BASE_TASKS.
        param: batch: int or None, optional
            batch size; it will generally be provided.
        param: loss: str, tuple or dict
            type of data loss.
        param: task: str or tuple, optional
            task-type.
        param: regularizer:
            regularizer function.
        '''
        # check `task`
        if not isinstance(task, (str, tuple)):
            raise TypeError("input 'task' should be str or str of tuple.")
        self.task = task if isinstance(task, str) else '-'.join(task)
        assert self.task in SUPPORTED_TASKS, "unrecognized task."

        # attributes `num_label` and `loss` are both dicts whose
        # keys are extracted from `task`.
        self.num_label = BaseNet._convert_to_dict(num_label, task)
        self.loss = BaseNet._convert_to_dict(loss, task)

        # input definition
        with tf.name_scope('inputs'):
            if task_relates_to(self.task, 'segmentation'):
                if batch is None:
                    raise ValueError(
                        "'batch' shouldn't be None in "
                        "segmentation related task, otherwise something "
                        "wrong may happen when using deconvolution.")
            self.input_shapes = BaseNet._generate_input_shapes(
                input_size, batch)

            self.input_images = tf.placeholder(
                shape=self.input_shapes['images'] + [input_channel],
                dtype=tf.float32,
                name='input-images')
            tf.summary.image('input-images', self.input_images, 3)
            tf.summary.histogram('input-images', self.input_images)

            self.gts = dict()
            for _task in SUPPORTED_BASE_TASKS:
                if task_relates_to(self.task, _task):
                    shape = self.input_shapes[_task] + [self.num_label[_task]]
                    self.gts[_task] = tf.placeholder(
                        shape=shape,
                        dtype=GTS_DTYPES[_task],
                        name="{}-labels".format(_task))

            if len(self.gts.keys()) == 0:
                raise IOError("no ground truth has been loaded yet.")

        self.training = tf.placeholder(tf.bool)  # for bn and/or dropout
        self.eval_training = tf.placeholder(tf.bool)  # for evaluation
        self.regularizer = regularizer
        self.logits = self.forward()

        tf.logging.info('trainable variables'.center(100, '*'))
        [tf.logging.info(item) for item in tf.trainable_variables()]
        if self.logits.get('segmentation', None) is not None:
            BaseNet._segmentation_show(self.gts['segmentation'],
                                       'segmentation-inputs')
            BaseNet._segmentation_show(self.logits['segmentation'],
                                       'segmentation-outputs')
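
`BaseNet._convert_to_dict` is not shown in the excerpt. A hedged guess at its behaviour, mapping an int/str, a tuple, or an already-keyed dict onto the base tasks contained in `task`; the real helper may differ:

    def _convert_to_dict(value, task):
        '''Hypothetical stand-in for `BaseNet._convert_to_dict`.'''
        tasks = task.split('-') if isinstance(task, str) else list(task)
        if isinstance(value, dict):
            return value
        if isinstance(value, (tuple, list)):
            # one value per base task, in the order the tasks were given
            return dict(zip(tasks, value))
        # a single scalar is shared by every base task
        return {t: value for t in tasks}

For example, `_convert_to_dict(10, 'classification')` would give `{'classification': 10}` under these assumptions.
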
Example #9
    def __init__(self,
                 datadir,
                 task,
                 input_size=224,
                 input_channel=1,
                 inference=False,
                 pickles=None,
                 num_instance=None,
                 seglabels=None,
                 using_records=False,
                 val_fraction=0.1,
                 num_crxval=5,
                 crxval_index=0,
                 sizes=[224, 256],
                 side=None,
                 with_crop=False,
                 crop_num=6,
                 with_pad=True,
                 min_pad=40,
                 with_rotation=False,
                 rotation_range=15,
                 with_histogram_equalisation=False,
                 with_zero_centralisation=False,
                 with_normalisation=False,
                 capacity=1000,
                 min_after_dequeue=750,
                 num_threads=3):
        '''Refer to doc of `ImageProducer` defined in 
        `dango.dataio.image.image_producer`.

        param: datadir: str
            directory of the image dataset; in the inference stage, this
            directory may not be in `datalibs`.
        param: pickles: str or None, optional
            path of pickle files (should be provided in the inference stage).
        param: crxval_index: int or None, optional
            indicating which tfrecord file is used for validation.
            if None, then no validation.
        param: with_rotation: bool
            indicating whether or not images will be rotated.
        param: rotation_range: float
            range of rotation: [-rotation_range, rotation_range].
        param: with_histogram_equalisation: bool
            indicating whether using histogram equalisation.
        param: with_zero_centralisation: bool
            indicating whether using zero centralisation.
        param: with_normalisation: bool
            indicating whether using normalisation.
        param: capacity: int
            capacity of input queue.
        param: min_after_dequeue: int
            minimal examples remained after dequeueing.
        param: num_threads: int
            number of threads for queueing input queue.
        '''
        super(ImageProvider, self).__init__(datadir=datadir,
                                            task=task,
                                            input_size=input_size,
                                            input_channel=input_channel,
                                            inference=inference,
                                            num_instance=num_instance,
                                            seglabels=seglabels,
                                            using_records=using_records,
                                            num_crxval=num_crxval,
                                            val_fraction=val_fraction,
                                            sizes=sizes,
                                            side=side,
                                            with_crop=with_crop,
                                            crop_num=crop_num,
                                            with_pad=with_pad,
                                            min_pad=min_pad)

        if task_relates_to(self.task, 'segmentation'):
            self.num_seglabels = len(self.seglabels)

        self.with_rotation = with_rotation
        self.rotation_range = rotation_range
        self.with_histogram_equalisation = with_histogram_equalisation
        self.with_zero_centralisation = with_zero_centralisation
        self.with_normalisation = with_normalisation

        self.capacity = capacity
        self.min_after_dequeue = min_after_dequeue
        self.num_threads = num_threads

        self._load_pickles(pickles)
        self._load_records(crxval_index)
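
A usage sketch of this constructor for a plain classification setup; it assumes `ImageProvider` has been imported from dango, the path is made up, and all keyword values are only illustrative:

    # hypothetical call: 'datalibs/app' stands in for a real image directory
    provider = ImageProvider(
        datadir='datalibs/app',
        task='classification',
        input_size=224,
        input_channel=1,
        num_crxval=5,
        crxval_index=0,
        with_rotation=True,
        rotation_range=10,
        with_normalisation=True)
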