def visualize(datagen, batch_size, view_size=4):
    """
    Read the batch from 'datagen' and display 'view_size' number of
    of images and their corresponding Ground Truth.

    Args:
        datagen: a dataflow yielding (imgs, segs) batches; wrapped in
            RepeatedData so display loops forever until interrupted.
        batch_size (int): size of the batches produced by `datagen`.
        view_size (int): number of images from each batch to display;
            must not exceed `batch_size`.
    """

    def prep_imgs(img, ann):
        # Build one wide image: the RGB input followed by a heat-map
        # rendering of every annotation channel, concatenated horizontally.
        cmap = plt.get_cmap('viridis')
        # cmap may randomly fails if of other types
        ann = ann.astype('float32')
        # split HxWxC annotation into C separate HxWx1 channels
        ann_chs = np.dsplit(ann, ann.shape[-1])
        for i, ch in enumerate(ann_chs):
            ch = np.squeeze(ch)
            # normalize to -1 to 1 range else
            # cmap may behave stupidly
            # NOTE(review): this only rescales by the value range — it does
            # not subtract the min, so the result is 0..1 only when min==0.
            ch = ch / (np.max(ch) - np.min(ch) + 1.0e-16)
            # take RGB from RGBA heat map
            ann_chs[i] = cmap(ch)[..., :3]
        # assumes img is uint8 RGB in 0..255 — TODO confirm at the caller
        img = img.astype('float32') / 255.0
        prepped_img = np.concatenate([img] + ann_chs, axis=1)
        return prepped_img

    assert view_size <= batch_size, 'Number of displayed images must <= batch size'
    # -1 repeats the dataflow indefinitely
    ds = RepeatedData(datagen, -1)
    ds.reset_state()
    for imgs, segs in ds.get_data():
        for idx in range(0, view_size):
            displayed_img = prep_imgs(imgs[idx], segs[idx])
            # one stacked row per displayed image
            plt.subplot(view_size, 1, idx + 1)
            plt.imshow(displayed_img)
        plt.show()
    return
def visualize(datagen, batch_size):
    """
    Read the batch from 'datagen' and display 'view_size' number of
    of images and their corresponding Ground Truth.

    In segmentation modes ("seg_gland"/"seg_nuc") each image is shown next
    to a jet heat map of every label channel; otherwise images are shown in
    a 2x4 grid with their (class) label as the subplot title.

    Args:
        datagen: a dataflow yielding (imgs, labs) batches; wrapped in
            RepeatedData so display loops forever until interrupted.
        batch_size (int): size of the batches produced by `datagen`
            (currently unused; display counts are hard-coded to 4/8).
    """
    cfg = Config()

    def prep_imgs(img, lab):
        # Deal with HxWx1 case
        img = np.squeeze(img)
        if cfg.model_mode == "seg_gland" or cfg.model_mode == "seg_nuc":
            cmap = plt.get_cmap("jet")
            # cmap may randomly fails if of other types
            lab = lab.astype("float32")
            # split HxWxC label into C separate channels
            lab_chs = np.dsplit(lab, lab.shape[-1])
            for i, ch in enumerate(lab_chs):
                ch = np.squeeze(ch)
                # cmap may behave stupidly
                ch = ch / (np.max(ch) - np.min(ch) + 1.0e-16)
                # take RGB from RGBA heat map
                lab_chs[i] = cmap(ch)[..., :3]
            img = img.astype("float32") / 255.0
            prepped_img = np.concatenate([img] + lab_chs, axis=1)
        else:
            prepped_img = img
        return prepped_img

    ds = RepeatedData(datagen, -1)
    ds.reset_state()
    for imgs, labs in ds.get_data():
        if cfg.model_mode == "seg_gland" or cfg.model_mode == "seg_nuc":
            for idx in range(0, 4):
                displayed_img = prep_imgs(imgs[idx], labs[idx])
                # plot the image and the label
                plt.subplot(4, 1, idx + 1)
                plt.imshow(displayed_img, vmin=-1, vmax=1)
                plt.axis("off")
            plt.show()
        else:
            for idx in range(0, 8):
                displayed_img = prep_imgs(imgs[idx], labs[idx])
                # plot the image and the label
                plt.subplot(2, 4, idx + 1)
                plt.imshow(displayed_img)
                if len(cfg.label_names) > 0:
                    lab_title = cfg.label_names[int(labs[idx])]
                else:
                    # BUG FIX: was `lab_tite = ...`, which left `lab_title`
                    # undefined and raised NameError at plt.title() below
                    # whenever cfg.label_names was empty.
                    lab_title = int(labs[idx])
                plt.title(lab_title)
                plt.axis("off")
            plt.show()
    return
def get_train_dataflow(batch_size=2):
    """
    Build the training dataflow: load roidbs for all configured training
    datasets, drop images without usable ground truth, group by aspect
    ratio, run preprocessing in worker threads, and return an infinite
    batch iterator.

    Args:
        batch_size (int): number of images per batch.

    Returns:
        An iterator over preprocessed training batches that never ends
        (the dataflow is wrapped in RepeatedData with num=-1).
    """
    print("In train dataflow")
    # Concatenate the roidbs of every training dataset into one list.
    roidbs = list(itertools.chain.from_iterable(DatasetRegistry.get(x).training_roidbs() for x in cfg.DATA.TRAIN))
    print_class_histogram(roidbs)
    print("Done loading roidbs")
    # Filter out images that have no gt boxes, but this filter shall not be applied for testing.
    # The model does support training with empty images, but it is not useful for COCO.
    num = len(roidbs)
    # keep only images that still have boxes after removing crowd boxes
    roidbs = list(filter(lambda img: len(img["boxes"][img["is_crowd"] == 0]) > 0, roidbs))
    logger.info(
        "Filtered {} images which contain no non-crowd groudtruth boxes. Total #images for training: {}".format(
            num - len(roidbs), len(roidbs)
        )
    )
    # Group images into portrait vs landscape so batches have similar shapes.
    aspect_grouping = [1]
    aspect_ratios = [float(x["height"]) / float(x["width"]) for x in roidbs]
    group_ids = _quantize(aspect_ratios, aspect_grouping)
    ds = AspectGroupingDataFlow(roidbs, group_ids, batch_size=batch_size, drop_uneven=True)
    preprocess = TrainingDataPreprocessor()
    buffer_size = cfg.DATA.NUM_WORKERS * 10
    # ds = MultiProcessMapData(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
    ds = MultiThreadMapData(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
    ds.reset_state()
    # to get an infinite data flow
    ds = RepeatedData(ds, num=-1)
    dataiter = ds.__iter__()
    return dataiter
def build_iter(self):
    """Assemble the input pipeline and return an infinite batch iterator.

    The generator is wrapped into a dataflow, repeated forever, batched,
    and (unless visualisation mode is on) prefetched in worker processes.
    """
    flow = DataFromGenerator(self.generator)
    flow = RepeatedData(flow, -1)
    flow = BatchData(flow, self.batch_size)
    # Skip ZMQ prefetching when visualising, so data stays in-process.
    if not cfg.TRAIN.vis:
        flow = PrefetchDataZMQ(flow, self.process_num)
    flow.reset_state()
    return flow.get_data()
def visualize(datagen, batch_size, view_size=4, aug_only=False, preview=False):
    """
    Read the batch from 'datagen' and display 'view_size' number of
    of images and their corresponding Ground Truth.

    Args:
        datagen: a dataflow yielding (imgs, segs) batches; wrapped in
            RepeatedData so display loops until `preview` breaks out.
        batch_size (int): size of the batches produced by `datagen`.
        view_size (int): number of images per batch to display;
            must not exceed `batch_size`.
        aug_only (bool): when True, show only the augmented input image
            and skip the ground-truth heat maps.
        preview (bool): when True, stop after the first batch.
    """

    def prep_imgs(img, ann):
        # Build one wide image: the RGB input followed by a heat-map
        # rendering of every annotation channel, concatenated horizontally.
        cmap = plt.get_cmap("viridis")
        # cmap may randomly fails if of other types
        ann = ann.astype("float32")
        # split HxWxC annotation into C separate HxWx1 channels
        ann_chs = np.dsplit(ann, ann.shape[-1])
        for i, ch in enumerate(ann_chs):
            ch = np.squeeze(ch)
            # normalize to -1 to 1 range else
            # cmap may behave stupidly
            ch = ch / (np.max(ch) - np.min(ch) + 1.0e-16)
            # take RGB from RGBA heat map
            ann_chs[i] = cmap(ch)[..., :3]
        # assumes img is uint8 RGB in 0..255 — TODO confirm at the caller
        img = img.astype("float32") / 255.0
        prepped_img = np.concatenate([img] + ann_chs, axis=1)
        return prepped_img

    assert view_size <= batch_size, "Number of displayed images must <= batch size"
    ds = RepeatedData(datagen, -1)
    ds.reset_state()
    for imgs, segs in ds.get_data():
        for idx in range(0, view_size):
            displayed_img = prep_imgs(imgs[idx], segs[idx])
            plt.subplot(view_size, 1, idx + 1)
            if aug_only:
                plt.imshow(imgs[idx])  # displayed_img
            else:
                plt.imshow(displayed_img)
        # NOTE(review): dumps each figure to a throwaway temp-file path;
        # NamedTemporaryFile is created only to borrow its name, so the
        # empty temp file itself leaks — confirm this is intentional.
        plt.savefig(f"{str(tempfile.NamedTemporaryFile().name)}.png")
        plt.show()
        if preview:
            break
    return
def __init__(self, ds, infinite=True):
    """
    Args:
        ds (DataFlow): the input DataFlow.
        infinite (bool): When set to False, will raise StopIteration when
            ds is exhausted.
    """
    if not isinstance(ds, DataFlow):
        raise ValueError("FeedInput takes a DataFlow! Got {}".format(ds))
    self.ds = ds
    # Endlessly repeat the dataflow when infinite iteration is requested;
    # otherwise iterate it once and let exhaustion surface to the caller.
    self._iter_ds = RepeatedData(self.ds, -1) if infinite else self.ds
def __init__(self, ds, queue=None):
    """
    Args:
        ds(DataFlow): the input DataFlow.
        queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
            should match the corresponding InputDesc of the model.
            Defaults to a FIFO queue of size 50.
    """
    # Validate the dataflow first so a bad argument fails fast.
    if not isinstance(ds, DataFlow):
        raise ValueError("QueueInput takes a DataFlow! Got {}".format(ds))
    self.ds = ds
    self.queue = queue
    self._started = False
    # Repeat the dataflow forever; the queue is refilled continuously.
    self._inf_ds = RepeatedData(ds, -1)
def get_val_dataflow(datadir, batch_size, augmentors=None, parallel=None,
                     num_splits=None, split_index=None, dataname="val"):
    """
    Build the validation dataflow over ILSVRC12 files: read, augment in
    parallel threads, batch, and repeat forever.

    Args:
        datadir (str): root directory of the ILSVRC12 data. Required.
        batch_size (int): images per batch; the final partial batch is dropped.
        augmentors (list): imgaug augmentors; defaults to fbresnet_augmentor(False).
        parallel (int): number of mapper threads; defaults to min(40, #cpus).
        num_splits / split_index: validation-data sharding parameters —
            currently disabled (see the `assert False` below).
        dataname (str): dataset split name, default "val".

    Returns:
        An infinite, batched DataFlow of (image, class) pairs.
    """
    if augmentors is None:
        augmentors = fbresnet_augmentor(False)
    assert datadir is not None
    assert isinstance(augmentors, list)
    if parallel is None:
        parallel = min(40, multiprocessing.cpu_count())
    if num_splits is None:
        ds = dataset.ILSVRC12Files(datadir, dataname, shuffle=True)
    else:
        # shard validation data
        # NOTE(review): `assert False` makes this whole branch unreachable
        # (unless run with -O) — sharding looks deliberately disabled; the
        # code below is kept for when it is re-enabled. Confirm intent.
        assert False
        assert split_index < num_splits
        files = dataset.ILSVRC12Files(datadir, dataname, shuffle=True)
        files.reset_state()
        files = list(files.get_data())
        logger.info("Number of validation data = {}".format(len(files)))
        split_size = len(files) // num_splits
        start, end = split_size * split_index, split_size * (split_index + 1)
        end = min(end, len(files))
        logger.info("Local validation split = {} - {}".format(start, end))
        files = files[start:end]
        ds = DataFromList(files, shuffle=True)
    aug = imgaug.AugmentorList(augmentors)

    def mapf(dp):
        # Load one (filename, class) pair into an augmented RGB image.
        fname, cls = dp
        im = cv2.imread(fname, cv2.IMREAD_COLOR)
        # from BGR to RGB
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        im = aug.augment(im)
        return im, cls

    ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=min(2000, ds.size()), strict=True)
    ds = BatchData(ds, batch_size, remainder=False)
    ds = RepeatedData(ds, num=-1)
    # do not fork() under MPI
    return ds
                    type=int)
# NOTE(review): the line above is the tail of a parser.add_argument(...)
# call whose beginning lies outside this chunk — kept as-is.
args = parser.parse_args()
logger.auto_set_dir(action='d')
# Load get_config() from the config script given on the command line.
get_config_func = imp.load_source('config_script', args.config).get_config
config = get_config_func()
config.dataset.reset_state()

if args.output:
    # Dump up to args.number datapoints from the chosen component as PNGs.
    mkdir_p(args.output)
    cnt = 0
    index = args.index  # TODO: as an argument?
    for dp in config.dataset.get_data():
        imgbatch = dp[index]
        if cnt > args.number:
            break
        for bi, img in enumerate(imgbatch):
            cnt += 1
            fname = os.path.join(args.output, '{:03d}-{}.png'.format(cnt, bi))
            # scale pixel values before writing (e.g. 255 for [0,1] images)
            cv2.imwrite(fname, img * args.scale)

# Benchmark: iterate the (repeated) dataflow and just discard datapoints.
NR_DP_TEST = args.number
logger.info("Testing dataflow speed:")
ds = RepeatedData(config.dataset, -1)
with tqdm.tqdm(total=NR_DP_TEST, leave=True, unit='data points') as pbar:
    for idx, dp in enumerate(ds.get_data()):
        del dp
        if idx > NR_DP_TEST:
            break
        pbar.update()
                    default=10, type=int)
# NOTE(review): the line above is the tail of a parser.add_argument(...)
# call whose beginning lies outside this chunk — kept as-is.
args = parser.parse_args()
logger.auto_set_dir(action='d')
# Load get_config() from the config script given on the command line.
get_config_func = imp.load_source('config_script', args.config).get_config
config = get_config_func()
config.dataset.reset_state()

if args.output:
    # Dump up to args.number datapoints from the chosen component as PNGs.
    mkdir_p(args.output)
    cnt = 0
    index = args.index  # TODO: as an argument?
    for dp in config.dataset.get_data():
        imgbatch = dp[index]
        if cnt > args.number:
            break
        for bi, img in enumerate(imgbatch):
            cnt += 1
            fname = os.path.join(args.output, '{:03d}-{}.png'.format(cnt, bi))
            # scale pixel values before writing (e.g. 255 for [0,1] images)
            cv2.imwrite(fname, img * args.scale)

# Benchmark: iterate the (repeated) dataflow and just discard datapoints.
NR_DP_TEST = args.number
logger.info("Testing dataflow speed:")
ds = RepeatedData(config.dataset, -1)
with tqdm.tqdm(total=NR_DP_TEST, leave=True, unit='data points') as pbar:
    for idx, dp in enumerate(ds.get_data()):
        del dp
        if idx > NR_DP_TEST:
            break
        pbar.update()
def train(args: Namespace, data_params: MoleculeData, experiment: Experiment,
          mol_metrics: GraphMolecularMetrics) -> None:
    """Train the MolGAN estimator.

    Builds the (repeated) training dataflow, wires generator/discriminator
    train-op hooks, configures checkpointing and learning-rate schedules,
    then runs `estimator.train` and logs the elapsed time.

    Args:
        args: parsed command-line arguments (batch size, epochs, LRs, ...).
        data_params: dataset description used to build model inputs.
        experiment: factory for input_fn / model_fn / predict_fn.
        mol_metrics: molecular metrics used for reward and evaluation.
    """
    ds_train = create_dataflow(args.data_dir, 'train', args.batch_size)
    ds_train_repeat = PrefetchDataZMQ(ds_train, nr_proc=1)
    # times 2, because we consume 2 batches per step
    ds_train_repeat = RepeatedData(ds_train_repeat, 2 * args.epochs)
    train_input_fn = experiment.make_train_fn(ds_train_repeat, args.batch_size,
                                              args.num_latent, data_params)

    def hooks_fn(train_ops: MolGANTrainOps,
                 train_steps: tfgan.GANTrainSteps) -> EstimatorTrainHooks:
        # Build one hook per sub-network; the value network (if present)
        # is trained jointly with the discriminator.
        if train_ops.valuenet_train_op is not None:
            generator_hook = FeedableTrainOpsHook(
                train_ops.generator_train_op,
                train_steps.generator_train_steps,
                train_input_fn,
                return_feed_dict=False)
            discriminator_hook = WithRewardTrainOpsHook(
                [train_ops.discriminator_train_op, train_ops.valuenet_train_op],
                train_steps.discriminator_train_steps,
                train_input_fn,
                mol_metrics)
        else:
            generator_hook = FeedableTrainOpsHook(
                train_ops.generator_train_op,
                train_steps.generator_train_steps,
                train_input_fn,
                return_feed_dict=True)
            discriminator_hook = FeedableTrainOpsHook(
                train_ops.discriminator_train_op,
                train_steps.discriminator_train_steps,
                train_input_fn)
        return [generator_hook, discriminator_hook]

    model = experiment.make_model_fn(args, data_params, hooks_fn)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    # enable XLA JIT
    # sess_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
    config = tf.estimator.RunConfig(model_dir=str(args.model_dir),
                                    session_config=sess_config,
                                    save_summary_steps=ds_train.size(),
                                    save_checkpoints_secs=None,
                                    save_checkpoints_steps=4 * ds_train.size(),
                                    keep_checkpoint_max=2)
    estimator = tf.estimator.Estimator(model.model_fn, config=config)

    train_hooks = [PrintParameterSummary()]
    if args.restore_from_checkpoint is not None:
        train_hooks.append(
            RestoreFromCheckpointHook(str(args.restore_from_checkpoint)))
    if args.debug:
        from tensorflow.python import debug as tf_debug
        train_hooks.append(tf_debug.TensorBoardDebugHook("localhost:6064"))

    predict_fn = experiment.make_predict_fn(args.data_dir,
                                            args.num_latent,
                                            n_samples=1000,
                                            batch_size=1000)
    ckpt_listener = PredictAndEvalMolecule(estimator, predict_fn, mol_metrics,
                                           str(args.model_dir))

    # Step-wise LR decay at epochs 80 / 150 / 200 for both networks.
    hparams_setter = [
        ScheduledHyperParamSetter('generator_learning_rate:0',
                                  args.generator_learning_rate,
                                  [(80, 0.5 * args.generator_learning_rate),
                                   (150, 0.1 * args.generator_learning_rate),
                                   (200, 0.01 * args.generator_learning_rate)],
                                  steps_per_epoch=ds_train.size()),
        ScheduledHyperParamSetter('discriminator_learning_rate:0',
                                  args.discriminator_learning_rate,
                                  [(80, 0.5 * args.discriminator_learning_rate),
                                   (150, 0.1 * args.discriminator_learning_rate),
                                   (200, 0.01 * args.discriminator_learning_rate)],
                                  steps_per_epoch=ds_train.size())
    ]
    train_hooks.extend(hparams_setter)

    if args.weight_reward_loss > 0:
        if args.weight_reward_loss_schedule == 'linear':
            lambda_setter = ScheduledHyperParamSetter(
                model.params, 'lam',
                [(args.reward_loss_delay, 1.0),
                 (args.epochs, 1.0 - args.weight_reward_loss)], True)
        elif args.weight_reward_loss_schedule == 'const':
            lambda_setter = ScheduledHyperParamSetter(
                model.params, 'lam',
                [(args.reward_loss_delay + 1, 1.0 - args.weight_reward_loss)],
                False)
        else:
            raise ValueError('unknown schedule: {!r}'.format(
                args.weight_reward_loss_schedule))
        # BUG FIX: previously `hparams_setter.append(lambda_setter)` — but
        # train_hooks was already extended from hparams_setter above, so the
        # lambda schedule was never actually registered as a training hook.
        train_hooks.append(lambda_setter)

    train_start = time.time()
    estimator.train(train_input_fn,
                    hooks=train_hooks,
                    saving_listeners=[ckpt_listener])
    train_end = time.time()
    time_d = datetime.timedelta(seconds=int(train_end - train_start))
    LOG.info('Training for %d epochs finished in %s', args.epochs, time_d)