Example #1
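Builds the train/eval/valid/test splits of a semi-supervised dataset from labeled and unlabeled TFRecord files, optionally whitening with statistics computed over both splits; augmentation runs with parallelism scaled to the number of available GPUs.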
        def create():
            p_labeled = p_unlabeled = None
            para = max(1, len(utils.get_available_gpus())) * FLAGS.para_augment

            # Optional class distribution for the unlabeled data, given as a
            # comma-separated list of floats and normalized by its maximum.
            if FLAGS.p_unlabeled:
                sequence = FLAGS.p_unlabeled.split(',')
                p_unlabeled = np.array(list(map(float, sequence)), dtype=np.float32)
                p_unlabeled /= np.max(p_unlabeled)

            # The first `valid` unlabeled records are reserved for validation.
            train_labeled = parse_fn(dataset([root + '-label.tfrecord']))
            train_unlabeled = parse_fn(dataset([root + '-unlabel.tfrecord']).skip(valid))
            if FLAGS.whiten:
                mean, std = compute_mean_std(train_labeled.concatenate(train_unlabeled))
            else:
                mean, std = 0, 1

            return cls(name + name_suffix + fullname + '-' + str(valid),
                       train_labeled=fn(train_labeled).map(augment[0], para),
                       train_unlabeled=fn(train_unlabeled).map(augment[1], para),
                       eval_labeled=parse_fn(dataset([root + '-label.tfrecord'])),
                       eval_unlabeled=parse_fn(dataset([root + '-unlabel.tfrecord']).skip(valid)),
                       valid=parse_fn(dataset([root + '-unlabel.tfrecord']).take(valid)),
                       test=parse_fn(dataset([os.path.join(DATA_DIR, '%s-test.tfrecord' % name)])),
                       nclass=nclass, colors=colors, p_labeled=p_labeled, p_unlabeled=p_unlabeled,
                       height=height, width=width, mean=mean, std=std)
Example #2
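Parses the dataset with the configured parse_fn, forwarding image_shape when one is set; the parallelism is again derived from the GPU count.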
    def parse(self):
        if self.parse_fn:
            para = 4 * max(1, len(utils.get_available_gpus())) * FLAGS.para_parse
            if self.image_shape:
                return self.map(lambda x: self.parse_fn(x, self.image_shape), para)
            else:
                return self.map(self.parse_fn, para)
        return self
Example #3
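Builds a TFRecord pipeline with optional take/skip limits, center cropping, area resizing, and random horizontal flips; the flip runs in NumPy through tf.py_func, which the code notes is faster than TF primitives.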
    def to_record_iterator(self,
                           filenames,
                           size,
                           resize,
                           repeat=False,
                           random_flip_x=False,
                           limit=0,
                           crop=(0, 0),
                           para=4):
        para *= 4 * max(1, len(utils.get_available_gpus()))
        filenames = sorted(sum([glob.glob(x) for x in filenames], []))
        if not filenames:
            raise ValueError('Empty dataset, did you mount gcsfuse bucket?')
        dataset = tf.data.TFRecordDataset(filenames)
        if limit is not None:
            if limit > 0:
                dataset = dataset.take(limit)
            elif limit < 0:
                dataset = dataset.skip(-limit)
        if repeat:
            dataset = dataset.repeat()

        def fpy(image):
            # Faster than using tf primitives.
            image = image.transpose(2, 0, 1)
            if random_flip_x and np.random.randint(2):
                image = image[:, :, ::-1]
            return image

        def f(image, *args):
            delta = [0, 0]
            if sum(crop):
                # Crop `crop` pixels from each border; both margins must be
                # nonzero here, since x[0:-0] would yield an empty slice.
                image = image[crop[0]:-crop[0], crop[1]:-crop[1]]
                delta[0] -= 2 * crop[0]
                delta[1] -= 2 * crop[1]
            if resize[0] - delta[0] != size[0] or resize[1] - delta[1] != size[1]:
                image = tf.image.resize_area([image], list(resize))[0]
            image = tf.py_func(fpy, [image], tf.float32)
            image = tf.reshape(image, [size[-1]] + list(resize))
            return (image,) + args

        dataset = dataset.map(self.record_parse_fn, num_parallel_calls=para)
        # Drop records whose channel count does not match the expected one.
        dataset = dataset.filter(
            lambda image, *_: tf.equal(tf.shape(image)[2], size[2]))
        dataset = dataset.map(f, para)
        return dataset.map(self.iterator_dict)
Example #4
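The fully supervised variant of Example #1: the first records of the training files (as many as the validation-set size, valid) are held out for validation, and evaluation uses a 5000-record sample.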
        def create():
            para = max(1, len(utils.get_available_gpus())) * FLAGS.para_augment
            # valid: sample size of a validation set
            train_labeled = parse_fn(dataset(train_files).skip(valid))
            if FLAGS.whiten:
                mean, std = compute_mean_std(train_labeled)
            else:
                mean, std = 0, 1

            return cls(
                name + name_suffix + fullname + '-' + str(valid),
                train_labeled=fn(train_labeled).map(augment, para),
                eval_labeled=train_labeled.take(5000),  # No need to eval on everything.
                valid=parse_fn(dataset(train_files).take(valid)),
                test=parse_fn(dataset(test_files)),
                nclass=nclass,
                colors=colors,
                height=height,
                width=width,
                mean=mean,
                std=std)
Example #5
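A training loop driven by tf.train.MonitoredTrainingSession with stop and reporting hooks; the *_kimg flags count thousands of images, hence the << 10 shifts.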
    def train(self, dataset):
        batch = FLAGS.batch

        with dataset.graph.as_default():
            train_data = dataset.train.batch(batch)
            train_data = train_data.prefetch(16)
            train_data = iter(as_iterator(train_data, dataset.sess))

        with tf.Graph().as_default():
            global_step = tf.train.get_or_create_global_step()
            ops = self.model(dataset=dataset,
                             total_steps=(FLAGS.total_kimg << 10) // batch,
                             **self.params)
            self.add_summaries(dataset, ops, **self.params)
            stop_hook = tf.train.StopAtStepHook(
                last_step=1 + (FLAGS.total_kimg << 10) // batch)
            report_hook = utils.HookReport(FLAGS.report_kimg << 10, batch)
            config = tf.ConfigProto()
            if len(utils.get_available_gpus()) > 1:
                config.allow_soft_placement = True
            if FLAGS.log_device_placement:
                config.log_device_placement = True
            config.gpu_options.allow_growth = True

            with tf.train.MonitoredTrainingSession(
                    checkpoint_dir=self.checkpoint_dir,
                    config=config,
                    hooks=[stop_hook],
                    chief_only_hooks=[report_hook],
                    save_checkpoint_secs=600,
                    save_summaries_steps=(FLAGS.save_kimg << 10) // batch) as sess:
                self.sess = sess
                self.nimg_cur = batch * self.tf_sess.run(global_step)
                while not sess.should_stop():
                    self.train_step(train_data, ops)
                    self.nimg_cur = batch * self.tf_sess.run(global_step)
Example #6
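Applies the configured augmentation function over the dataset with GPU-scaled parallelism.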
    def augment(self):
        if self.augment_fn:
            para = max(1, len(utils.get_available_gpus())) * FLAGS.para_augment
            return self.map(self.augment_fn.tf, para)
        return self
Example #7
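Lazily creates a module-level multiprocessing pool, sized like the map parallelism in the other examples.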
def init_pool():
    global POOL
    if POOL is None:
        para = max(1, len(utils.get_available_gpus())) * FLAGS.para_augment
        POOL = multiprocessing.Pool(para)
Example #8
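Maps a record-parsing function over a tf.data.Dataset with GPU-scaled parallelism.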
def default_parse(dataset: tf.data.Dataset, parse_fn=record_parse) -> tf.data.Dataset:
    para = 4 * max(1, len(utils.get_available_gpus())) * FLAGS.para_parse
    return dataset.map(parse_fn, num_parallel_calls=para)
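The record_parse default above is not part of this listing. A minimal sketch of what such a parser might look like, assuming each record stores one encoded image and an integer label; the feature names, the decoder, and the [-1, 1] scaling are assumptions, not the listing's actual implementation:

    import tensorflow as tf

    def record_parse(serialized_example):
        # Assumed schema: one encoded image plus an integer class label.
        features = tf.parse_single_example(
            serialized_example,
            features={'image': tf.FixedLenFeature([], tf.string),
                      'label': tf.FixedLenFeature([], tf.int64)})
        image = tf.image.decode_image(features['image'])
        # Scale pixels to [-1, 1]; the exact normalization is an assumption.
        image = tf.cast(image, tf.float32) * (2.0 / 255) - 1.0
        return dict(image=image, label=features['label'])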
Example #9
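A phased training loop, apparently for a progressively grown model: each phase trains between two levels of detail, and when a new stage starts, its weights are initialized from the previous stage's checkpoint via tf.train.init_from_checkpoint.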
    def train(self, dataset, schedule):
        assert isinstance(schedule, TrainSchedule)
        batch = FLAGS.batch
        resume_step = utils.get_latest_global_step_in_subdir(
            self.checkpoint_dir)
        phase_start = schedule.phase_index(resume_step * batch)
        checkpoint_dir = lambda stage: os.path.join(
            self.checkpoint_dir, 'stage_%d' % stage)

        for phase in schedule.schedule[phase_start:]:
            print('Resume step %d  Phase %dK:%dK  LOD %d:%d' %
                  (resume_step, phase.nimg_start >> 10, phase.nimg_stop >> 10,
                   phase.lod_start, phase.lod_stop))
            assert isinstance(phase, TrainPhase)

            def lod_fn():
                return phase.lod(self.nimg_cur)

            with dataset.graph.as_default():
                train_data = dataset.train.batch(batch)
                train_data = train_data.prefetch(64)
                train_data = iter(as_iterator(train_data, dataset.sess))

            with tf.Graph().as_default():
                global_step = tf.train.get_or_create_global_step()
                ops = self.model(dataset=dataset,
                                 lod_start=phase.lod_start,
                                 lod_stop=phase.lod_stop,
                                 lod_max=schedule.lod_max,
                                 total_steps=schedule.total_nimg // batch,
                                 **self.params)
                self.add_summaries(dataset, ops, lod_fn, **self.params)
                stop_hook = tf.train.StopAtStepHook(
                    last_step=phase.nimg_stop // batch)
                report_hook = utils.HookReport(FLAGS.report_kimg << 10, batch)
                config = tf.ConfigProto()
                if len(utils.get_available_gpus()) > 1:
                    config.allow_soft_placement = True
                if FLAGS.log_device_placement:
                    config.log_device_placement = True
                config.gpu_options.allow_growth = True
                config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF

                # When growing the model, load the previously trained layer weights.
                stage_step_last = utils.get_latest_global_step(
                    checkpoint_dir(phase.lod_stop - 1))
                stage_step = utils.get_latest_global_step(
                    checkpoint_dir(phase.lod_stop))
                if stage_step_last and not stage_step:
                    last_checkpoint = utils.find_latest_checkpoint(
                        checkpoint_dir(phase.lod_stop - 1))
                    tf.train.init_from_checkpoint(
                        last_checkpoint,
                        {x: x for x in self.stage_scopes(phase.lod_stop - 1)})

                with tf.train.MonitoredTrainingSession(
                        checkpoint_dir=checkpoint_dir(phase.lod_stop),
                        config=config,
                        hooks=[stop_hook],
                        chief_only_hooks=[report_hook],
                        save_checkpoint_secs=600,
                        save_summaries_steps=(FLAGS.save_kimg << 10) // batch) as sess:
                    self.sess = sess
                    self.nimg_cur = batch * self.tf_sess.run(global_step)
                    while not sess.should_stop():
                        self.train_step(train_data, lod_fn(), ops)
                        resume_step = self.tf_sess.run(global_step)
                        self.nimg_cur = batch * resume_step