# Shared imports assumed by the excerpts below (module paths are a best guess,
# not taken from the source). The create() factories are nested closures in the
# original code and rely on enclosing-scope names such as root, valid, parse_fn,
# fn, augment and cls.
import glob
import multiprocessing
import os

import numpy as np
import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2

import utils  # project-local helpers (get_available_gpus, HookReport, ...)

FLAGS = tf.flags.FLAGS
POOL = None  # lazily created by init_pool() below


def create():
    p_labeled = p_unlabeled = None
    para = max(1, len(utils.get_available_gpus())) * FLAGS.para_augment

    if FLAGS.p_unlabeled:
        # Per-class sampling weights for the unlabeled set, e.g. '1,2,1'.
        sequence = FLAGS.p_unlabeled.split(',')
        p_unlabeled = np.array(list(map(float, sequence)), dtype=np.float32)
        p_unlabeled /= np.max(p_unlabeled)

    train_labeled = parse_fn(dataset([root + '-label.tfrecord']))
    train_unlabeled = parse_fn(dataset([root + '-unlabel.tfrecord']).skip(valid))
    if FLAGS.whiten:
        mean, std = compute_mean_std(train_labeled.concatenate(train_unlabeled))
    else:
        mean, std = 0, 1

    return cls(name + name_suffix + fullname + '-' + str(valid),
               train_labeled=fn(train_labeled).map(augment[0], para),
               train_unlabeled=fn(train_unlabeled).map(augment[1], para),
               eval_labeled=parse_fn(dataset([root + '-label.tfrecord'])),
               eval_unlabeled=parse_fn(dataset([root + '-unlabel.tfrecord']).skip(valid)),
               valid=parse_fn(dataset([root + '-unlabel.tfrecord']).take(valid)),
               test=parse_fn(dataset([os.path.join(DATA_DIR, '%s-test.tfrecord' % name)])),
               nclass=nclass, colors=colors,
               p_labeled=p_labeled, p_unlabeled=p_unlabeled,
               height=height, width=width, mean=mean, std=std)
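# Worked example of the p_unlabeled parsing above: FLAGS.p_unlabeled = '1,2,1'
# yields np.array([1., 2., 1.]) / 2. = [0.5, 1., 0.5], i.e. sampling weights
# normalized so that the most frequent class has weight 1.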
def parse(self):
    if self.parse_fn:
        para = 4 * max(1, len(utils.get_available_gpus())) * FLAGS.para_parse
        if self.image_shape:
            return self.map(lambda x: self.parse_fn(x, self.image_shape), para)
        else:
            return self.map(self.parse_fn, para)
    return self
def to_record_iterator(self, filenames, size, resize, repeat=False,
                       random_flip_x=False, limit=0, crop=(0, 0), para=4):
    para *= 4 * max(1, len(utils.get_available_gpus()))
    filenames = sorted(sum([glob.glob(x) for x in filenames], []))
    if not filenames:
        raise ValueError('Empty dataset, did you mount gcsfuse bucket?')
    dataset = tf.data.TFRecordDataset(filenames)
    if limit is not None:
        if limit > 0:
            dataset = dataset.take(limit)
        elif limit < 0:
            dataset = dataset.skip(-limit)
    if repeat:
        dataset = dataset.repeat()

    def fpy(image):
        # Faster than using tf primitives.
        image = image.transpose(2, 0, 1)  # HWC -> CHW
        if random_flip_x and np.random.randint(2):
            image = image[:, :, ::-1]
        return image

    def f(image, *args):
        delta = [0, 0]
        if sum(crop):
            image = image[crop[0]:-crop[0], crop[1]:-crop[1]]
            delta[0] -= 2 * crop[0]
            delta[1] -= 2 * crop[1]
        if resize[0] - delta[0] != size[0] or resize[1] - delta[1] != size[1]:
            image = tf.image.resize_area([image], list(resize))[0]
        image = tf.py_func(fpy, [image], tf.float32)
        image = tf.reshape(image, [size[-1]] + list(resize))
        return (image,) + args

    dataset = dataset.map(self.record_parse_fn, num_parallel_calls=para)
    # Drop records whose channel count does not match the requested size.
    dataset = dataset.filter(
        lambda image, *_: tf.equal(tf.shape(image)[2], size[2]))
    dataset = dataset.map(f, para)
    return dataset.map(self.iterator_dict)
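# Hypothetical call (a sketch; the file pattern and shapes are assumptions, not
# values from the source): build a repeating, randomly-flipped iterator over
# 32x32x3 records.
#   it = self.to_record_iterator(['gs://bucket/data-*.tfrecord'],
#                                size=(32, 32, 3), resize=(32, 32),
#                                repeat=True, random_flip_x=True)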
def create():
    para = max(1, len(utils.get_available_gpus())) * FLAGS.para_augment
    # valid: sample size of the validation set, held out from the training files.
    train_labeled = parse_fn(dataset(train_files).skip(valid))
    if FLAGS.whiten:
        mean, std = compute_mean_std(train_labeled)
    else:
        mean, std = 0, 1

    return cls(name + name_suffix + fullname + '-' + str(valid),
               train_labeled=fn(train_labeled).map(augment, para),
               eval_labeled=train_labeled.take(5000),  # No need to eval on everything.
               valid=parse_fn(dataset(train_files).take(valid)),
               test=parse_fn(dataset(test_files)),
               nclass=nclass, colors=colors, height=height, width=width,
               mean=mean, std=std)
def train(self, dataset):
    batch = FLAGS.batch
    with dataset.graph.as_default():
        train_data = dataset.train.batch(batch)
        train_data = train_data.prefetch(16)
        train_data = iter(as_iterator(train_data, dataset.sess))

    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()
        # kimg counts are in units of 1024 images; << 10 converts kimg to images.
        ops = self.model(dataset=dataset,
                         total_steps=(FLAGS.total_kimg << 10) // batch,
                         **self.params)
        self.add_summaries(dataset, ops, **self.params)
        stop_hook = tf.train.StopAtStepHook(
            last_step=1 + (FLAGS.total_kimg << 10) // batch)
        report_hook = utils.HookReport(FLAGS.report_kimg << 10, batch)
        config = tf.ConfigProto()
        if len(utils.get_available_gpus()) > 1:
            config.allow_soft_placement = True
        if FLAGS.log_device_placement:
            config.log_device_placement = True
        config.gpu_options.allow_growth = True

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=self.checkpoint_dir,
                config=config,
                hooks=[stop_hook],
                chief_only_hooks=[report_hook],
                save_checkpoint_secs=600,
                save_summaries_steps=(FLAGS.save_kimg << 10) // batch) as sess:
            self.sess = sess
            self.nimg_cur = batch * self.tf_sess.run(global_step)
            while not sess.should_stop():
                self.train_step(train_data, ops)
                self.nimg_cur = batch * self.tf_sess.run(global_step)
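# Worked example of the kimg arithmetic above (illustrative values, not defaults
# asserted by the source): with FLAGS.total_kimg = 1024 and batch = 64, the run
# covers 1024 << 10 = 1,048,576 images, so total_steps = 1048576 // 64 = 16384.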
def augment(self):
    if self.augment_fn:
        para = max(1, len(utils.get_available_gpus())) * FLAGS.para_augment
        return self.map(self.augment_fn.tf, para)
    return self
def init_pool():
    global POOL
    if POOL is None:
        para = max(1, len(utils.get_available_gpus())) * FLAGS.para_augment
        POOL = multiprocessing.Pool(para)
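# A minimal usage sketch (an assumption, not code from the source): init_pool is
# called lazily before dispatching CPU-bound work to the process pool. The
# numpy_flip helper below is hypothetical; any picklable function would do.
def numpy_flip(image):
    return image[:, ::-1]  # horizontal flip of an HWC numpy array

def augment_batch_async(images):
    init_pool()
    return POOL.map(numpy_flip, images)  # fan the batch out across workers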
def default_parse(dataset: tf.data.Dataset, parse_fn=record_parse) -> tf.data.Dataset:
    para = 4 * max(1, len(utils.get_available_gpus())) * FLAGS.para_parse
    return dataset.map(parse_fn, num_parallel_calls=para)
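# Hedged example of wiring default_parse into a pipeline (the shuffle/batch
# numbers are illustrative assumptions): record_parse is assumed to decode one
# serialized tf.Example, as elsewhere in this codebase.
def example_pipeline(filenames, batch=64):
    raw = tf.data.TFRecordDataset(filenames)
    parsed = default_parse(raw)  # parallel decode: para = 4 * GPUs * para_parse
    return parsed.shuffle(1024).batch(batch).prefetch(4)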
def train(self, dataset, schedule):
    assert isinstance(schedule, TrainSchedule)
    batch = FLAGS.batch
    resume_step = utils.get_latest_global_step_in_subdir(self.checkpoint_dir)
    phase_start = schedule.phase_index(resume_step * batch)
    checkpoint_dir = lambda stage: os.path.join(self.checkpoint_dir, 'stage_%d' % stage)

    for phase in schedule.schedule[phase_start:]:
        print('Resume step %d Phase %dK:%dK LOD %d:%d' %
              (resume_step,
               phase.nimg_start >> 10, phase.nimg_stop >> 10,
               phase.lod_start, phase.lod_stop))
        assert isinstance(phase, TrainPhase)

        def lod_fn():
            return phase.lod(self.nimg_cur)

        with dataset.graph.as_default():
            train_data = dataset.train.batch(batch)
            train_data = train_data.prefetch(64)
            train_data = iter(as_iterator(train_data, dataset.sess))

        with tf.Graph().as_default():
            global_step = tf.train.get_or_create_global_step()
            ops = self.model(dataset=dataset,
                             lod_start=phase.lod_start,
                             lod_stop=phase.lod_stop,
                             lod_max=schedule.lod_max,
                             total_steps=schedule.total_nimg // batch,
                             **self.params)
            self.add_summaries(dataset, ops, lod_fn, **self.params)
            stop_hook = tf.train.StopAtStepHook(last_step=phase.nimg_stop // batch)
            report_hook = utils.HookReport(FLAGS.report_kimg << 10, batch)
            config = tf.ConfigProto()
            if len(utils.get_available_gpus()) > 1:
                config.allow_soft_placement = True
            if FLAGS.log_device_placement:
                config.log_device_placement = True
            config.gpu_options.allow_growth = True
            config.graph_options.rewrite_options.layout_optimizer = (
                rewriter_config_pb2.RewriterConfig.OFF)

            # When growing the model, load the previously trained layer weights.
            stage_step_last = utils.get_latest_global_step(
                checkpoint_dir(phase.lod_stop - 1))
            stage_step = utils.get_latest_global_step(checkpoint_dir(phase.lod_stop))
            if stage_step_last and not stage_step:
                last_checkpoint = utils.find_latest_checkpoint(
                    checkpoint_dir(phase.lod_stop - 1))
                tf.train.init_from_checkpoint(
                    last_checkpoint,
                    {x: x for x in self.stage_scopes(phase.lod_stop - 1)})

            with tf.train.MonitoredTrainingSession(
                    checkpoint_dir=checkpoint_dir(phase.lod_stop),
                    config=config,
                    hooks=[stop_hook],
                    chief_only_hooks=[report_hook],
                    save_checkpoint_secs=600,
                    save_summaries_steps=(FLAGS.save_kimg << 10) // batch) as sess:
                self.sess = sess
                self.nimg_cur = batch * self.tf_sess.run(global_step)
                while not sess.should_stop():
                    self.train_step(train_data, lod_fn(), ops)
                    resume_step = self.tf_sess.run(global_step)
                    self.nimg_cur = batch * resume_step
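# Illustrative resume arithmetic (assumed values, not from the source): if the
# latest global step found under any stage_* subdir is 8000 and batch = 32, then
# resume_step * batch = 256,000 images seen so far, and
# schedule.phase_index(256000) selects the phase covering that image count.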