Exemple #1
0
def visualize(datagen, batch_size, view_size=4):
    """
    Pull batches from 'datagen' forever and display the first
    'view_size' images of each batch alongside their Ground Truth.
    """
    def prep_imgs(img, ann):
        cmap = plt.get_cmap('viridis')
        # cmap can misbehave on non-float input
        ann = ann.astype('float32')
        channels = np.dsplit(ann, ann.shape[-1])
        for ch_idx, channel in enumerate(channels):
            channel = np.squeeze(channel)
            # rescale so cmap receives a sane value range
            channel = channel / (np.max(channel) - np.min(channel) + 1.0e-16)
            # keep RGB only, dropping the alpha of the RGBA heat map
            channels[ch_idx] = cmap(channel)[..., :3]
        scaled_img = img.astype('float32') / 255.0
        return np.concatenate([scaled_img] + channels, axis=1)

    assert view_size <= batch_size, 'Number of displayed images must <= batch size'
    ds = RepeatedData(datagen, -1)
    ds.reset_state()
    for imgs, segs in ds.get_data():
        for idx in range(view_size):
            plt.subplot(view_size, 1, idx + 1)
            plt.imshow(prep_imgs(imgs[idx], segs[idx]))
        plt.show()
Exemple #2
0
def visualize(datagen, batch_size):
    """
    Read batches from 'datagen' forever and display images together with
    their Ground Truth.

    In segmentation mode ('seg_gland'/'seg_nuc') each label channel is
    colour-mapped and concatenated next to the image; otherwise the raw
    image is shown with its class label as the subplot title.
    """
    cfg = Config()

    def prep_imgs(img, lab):

        # Deal with HxWx1 case
        img = np.squeeze(img)

        if cfg.model_mode == "seg_gland" or cfg.model_mode == "seg_nuc":
            cmap = plt.get_cmap("jet")
            # cmap may randomly fails if of other types
            lab = lab.astype("float32")
            lab_chs = np.dsplit(lab, lab.shape[-1])
            for i, ch in enumerate(lab_chs):
                ch = np.squeeze(ch)
                # rescale so cmap gets a sane value range
                ch = ch / (np.max(ch) - np.min(ch) + 1.0e-16)
                # take RGB from RGBA heat map
                lab_chs[i] = cmap(ch)[..., :3]
            img = img.astype("float32") / 255.0
            prepped_img = np.concatenate([img] + lab_chs, axis=1)
        else:
            prepped_img = img
        return prepped_img

    ds = RepeatedData(datagen, -1)
    ds.reset_state()
    for imgs, labs in ds.get_data():
        if cfg.model_mode == "seg_gland" or cfg.model_mode == "seg_nuc":
            for idx in range(0, 4):
                displayed_img = prep_imgs(imgs[idx], labs[idx])
                # plot the image and the label
                plt.subplot(4, 1, idx + 1)
                plt.imshow(displayed_img, vmin=-1, vmax=1)
                plt.axis("off")
            plt.show()
        else:
            for idx in range(0, 8):
                displayed_img = prep_imgs(imgs[idx], labs[idx])
                # plot the image and the label
                plt.subplot(2, 4, idx + 1)
                plt.imshow(displayed_img)
                if len(cfg.label_names) > 0:
                    lab_title = cfg.label_names[int(labs[idx])]
                else:
                    # BUG FIX: was 'lab_tite', leaving 'lab_title'
                    # undefined and raising NameError in plt.title below.
                    lab_title = int(labs[idx])
                plt.title(lab_title)
                plt.axis("off")
            plt.show()
    return
Exemple #3
0
def get_train_dataflow(batch_size=2):
    """Build the (infinite) training dataflow iterator.

    Loads roidbs from every configured training dataset, drops images
    that have no non-crowd ground-truth boxes, groups images of similar
    aspect ratio into batches, and runs preprocessing in worker threads.

    Args:
        batch_size (int): number of images per batch.

    Returns:
        An iterator yielding preprocessed training batches forever.
    """
    print("In train dataflow")
    roidbs = list(itertools.chain.from_iterable(DatasetRegistry.get(x).training_roidbs() for x in cfg.DATA.TRAIN))
    print_class_histogram(roidbs)
    print("Done loading roidbs")

    # Filter out images that have no gt boxes, but this filter shall not be applied for testing.
    # The model does support training with empty images, but it is not useful for COCO.
    num = len(roidbs)
    roidbs = list(filter(lambda img: len(img["boxes"][img["is_crowd"] == 0]) > 0, roidbs))
    # FIX: corrected "groudtruth" typo in the log message.
    logger.info(
        "Filtered {} images which contain no non-crowd groundtruth boxes. Total #images for training: {}".format(
            num - len(roidbs), len(roidbs)
        )
    )

    # bucket images into aspect-ratio groups so batched images have
    # similar shapes (threshold at ratio 1.0)
    aspect_grouping = [1]
    aspect_ratios = [float(x["height"]) / float(x["width"]) for x in roidbs]
    group_ids = _quantize(aspect_ratios, aspect_grouping)

    ds = AspectGroupingDataFlow(roidbs, group_ids, batch_size=batch_size, drop_uneven=True)
    preprocess = TrainingDataPreprocessor()
    buffer_size = cfg.DATA.NUM_WORKERS * 10
    # ds = MultiProcessMapData(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
    ds = MultiThreadMapData(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
    ds.reset_state()

    # to get an infinite data flow
    ds = RepeatedData(ds, num=-1)
    dataiter = ds.__iter__()

    return dataiter
Exemple #4
0
    def build_iter(self):
        """Build and return an infinite, batched data iterator."""
        flow = DataFromGenerator(self.generator)
        flow = RepeatedData(flow, -1)
        flow = BatchData(flow, self.batch_size)
        # when visualizing, keep data in-process (no ZMQ prefetch workers)
        if not cfg.TRAIN.vis:
            flow = PrefetchDataZMQ(flow, self.process_num)
        flow.reset_state()
        return flow.get_data()
Exemple #5
0
def visualize(datagen, batch_size, view_size=4, aug_only=False, preview=False):
    """
    Read batches from 'datagen' and display 'view_size' images with their
    corresponding Ground Truth.

    Args:
        datagen: dataflow yielding (images, annotations) batches.
        batch_size (int): size of batches produced by `datagen`.
        view_size (int): number of images displayed per batch.
        aug_only (bool): show only the augmented image, without the
            colour-mapped ground-truth channels.
        preview (bool): stop after the first batch instead of looping.
    """
    def prep_imgs(img, ann):
        cmap = plt.get_cmap("viridis")
        # cmap may randomly fails if of other types
        ann = ann.astype("float32")
        ann_chs = np.dsplit(ann, ann.shape[-1])
        for i, ch in enumerate(ann_chs):
            ch = np.squeeze(ch)
            # normalize to -1 to 1 range else
            # cmap may behave stupidly
            ch = ch / (np.max(ch) - np.min(ch) + 1.0e-16)
            # take RGB from RGBA heat map
            ann_chs[i] = cmap(ch)[..., :3]
        img = img.astype("float32") / 255.0
        prepped_img = np.concatenate([img] + ann_chs, axis=1)
        return prepped_img

    assert view_size <= batch_size, "Number of displayed images must <= batch size"
    ds = RepeatedData(datagen, -1)
    ds.reset_state()
    for imgs, segs in ds.get_data():
        for idx in range(0, view_size):
            plt.subplot(view_size, 1, idx + 1)
            if aug_only:
                # FIX: skip prep_imgs entirely here — the original
                # computed the montage and then discarded it.
                plt.imshow(imgs[idx])
            else:
                plt.imshow(prep_imgs(imgs[idx], segs[idx]))
        plt.savefig(f"{str(tempfile.NamedTemporaryFile().name)}.png")
        plt.show()
        if preview:
            break

    return
Exemple #6
0
 def __init__(self, ds, infinite=True):
     """
     Args:
         ds (DataFlow): the input DataFlow.
         infinite (bool): if False, iteration raises StopIteration once
             ds is exhausted; otherwise ds is repeated forever.
     """
     if not isinstance(ds, DataFlow):
         raise ValueError("FeedInput takes a DataFlow! Got {}".format(ds))
     self.ds = ds
     # wrap in an endless repeater unless the caller wants a single pass
     self._iter_ds = RepeatedData(self.ds, -1) if infinite else self.ds
Exemple #7
0
 def __init__(self, ds, queue=None):
     """
     Args:
         ds(DataFlow): the input DataFlow.
         queue (tf.QueueBase): queue whose types should match the
             corresponding InputDesc of the model; when None, a FIFO
             queue of size 50 is used.
     """
     if not isinstance(ds, DataFlow):
         raise ValueError("QueueInput takes a DataFlow! Got {}".format(ds))
     self.ds = ds
     self.queue = queue
     self._started = False
     # endless view over ds, used to keep the queue fed
     self._inf_ds = RepeatedData(ds, -1)
def get_val_dataflow(datadir,
                     batch_size,
                     augmentors=None,
                     parallel=None,
                     num_splits=None,
                     split_index=None,
                     dataname="val"):
    """Build the validation dataflow: read ILSVRC12 file names, decode
    and augment in worker threads, batch, and repeat forever."""
    if augmentors is None:
        augmentors = fbresnet_augmentor(False)
    assert datadir is not None
    assert isinstance(augmentors, list)
    if parallel is None:
        parallel = min(40, multiprocessing.cpu_count())

    if num_splits is None:
        ds = dataset.ILSVRC12Files(datadir, dataname, shuffle=True)
    else:
        # shard validation data -- path currently disabled on purpose
        assert False
        assert split_index < num_splits
        files = dataset.ILSVRC12Files(datadir, dataname, shuffle=True)
        files.reset_state()
        files = list(files.get_data())
        logger.info("Number of validation data = {}".format(len(files)))
        split_size = len(files) // num_splits
        start = split_size * split_index
        end = min(split_size * (split_index + 1), len(files))
        logger.info("Local validation split = {} - {}".format(start, end))
        ds = DataFromList(files[start:end], shuffle=True)

    aug = imgaug.AugmentorList(augmentors)

    def decode_and_augment(dp):
        fname, cls = dp
        #from BGR to RGB
        im = cv2.cvtColor(cv2.imread(fname, cv2.IMREAD_COLOR),
                          cv2.COLOR_BGR2RGB)
        return aug.augment(im), cls

    ds = MultiThreadMapData(ds,
                            parallel,
                            decode_and_augment,
                            buffer_size=min(2000, ds.size()),
                            strict=True)
    ds = BatchData(ds, batch_size, remainder=False)
    ds = RepeatedData(ds, num=-1)
    # do not fork() under MPI
    return ds
                    type=int)
args = parser.parse_args()
# 'd' deletes any existing log directory before re-creating it
logger.auto_set_dir(action='d')

# load the get_config() entry point from the user-supplied config script
get_config_func = imp.load_source('config_script', args.config).get_config
config = get_config_func()
config.dataset.reset_state()

# optionally dump images from each datapoint's args.index component to disk,
# stopping once args.number images have been written
if args.output:
    mkdir_p(args.output)
    cnt = 0
    index = args.index  # TODO: as an argument?
    for dp in config.dataset.get_data():
        imgbatch = dp[index]
        if cnt > args.number:
            break
        for bi, img in enumerate(imgbatch):
            cnt += 1
            fname = os.path.join(args.output, '{:03d}-{}.png'.format(cnt, bi))
            cv2.imwrite(fname, img * args.scale)

# benchmark: pull NR_DP_TEST datapoints from the infinitely-repeated
# dataflow and let tqdm report the throughput
NR_DP_TEST = args.number
logger.info("Testing dataflow speed:")
ds = RepeatedData(config.dataset, -1)
with tqdm.tqdm(total=NR_DP_TEST, leave=True, unit='data points') as pbar:
    for idx, dp in enumerate(ds.get_data()):
        del dp
        if idx > NR_DP_TEST:
            break
        pbar.update()
                    default=10, type=int)
args = parser.parse_args()
# 'd' deletes any existing log directory before re-creating it
logger.auto_set_dir(action='d')

# load the get_config() entry point from the user-supplied config script
get_config_func = imp.load_source('config_script', args.config).get_config
config = get_config_func()
config.dataset.reset_state()

# optionally dump images from each datapoint's args.index component to disk,
# stopping once args.number images have been written
if args.output:
    mkdir_p(args.output)
    cnt = 0
    index = args.index   # TODO: as an argument?
    for dp in config.dataset.get_data():
        imgbatch = dp[index]
        if cnt > args.number:
            break
        for bi, img in enumerate(imgbatch):
            cnt += 1
            fname = os.path.join(args.output, '{:03d}-{}.png'.format(cnt, bi))
            cv2.imwrite(fname, img * args.scale)

# benchmark: pull NR_DP_TEST datapoints from the infinitely-repeated
# dataflow and let tqdm report the throughput
NR_DP_TEST = args.number
logger.info("Testing dataflow speed:")
ds = RepeatedData(config.dataset, -1)
with tqdm.tqdm(total=NR_DP_TEST, leave=True, unit='data points') as pbar:
    for idx, dp in enumerate(ds.get_data()):
        del dp
        if idx > NR_DP_TEST:
            break
        pbar.update()
Exemple #11
0
def train(args: Namespace, data_params: MoleculeData, experiment: Experiment,
          mol_metrics: GraphMolecularMetrics) -> None:
    """Train the GAN estimator on the molecule dataflow.

    Builds the training input pipeline, selects the train-op hooks (with
    or without a value network), configures learning-rate and reward-loss
    schedules, and runs ``estimator.train`` with a checkpoint listener
    that evaluates generated molecules.

    Args:
        args: parsed command-line options (batch size, epochs, learning
            rates, model dir, schedules, etc.).
        data_params: dataset description passed to input/model functions.
        experiment: factory for input, model and predict functions.
        mol_metrics: molecular metrics used for reward and evaluation.
    """
    ds_train = create_dataflow(args.data_dir, 'train', args.batch_size)

    ds_train_repeat = PrefetchDataZMQ(ds_train, nr_proc=1)
    # times 2, because we consume 2 batches per step
    ds_train_repeat = RepeatedData(ds_train_repeat, 2 * args.epochs)

    train_input_fn = experiment.make_train_fn(ds_train_repeat, args.batch_size,
                                              args.num_latent, data_params)

    def hooks_fn(train_ops: MolGANTrainOps,
                 train_steps: tfgan.GANTrainSteps) -> EstimatorTrainHooks:
        # Returns [generator_hook, discriminator_hook]; the discriminator
        # hook also trains the value network when its train op exists.
        if train_ops.valuenet_train_op is not None:
            generator_hook = FeedableTrainOpsHook(
                train_ops.generator_train_op,
                train_steps.generator_train_steps,
                train_input_fn,
                return_feed_dict=False)

            discriminator_hook = WithRewardTrainOpsHook([
                train_ops.discriminator_train_op, train_ops.valuenet_train_op
            ], train_steps.discriminator_train_steps, train_input_fn,
                                                        mol_metrics)
        else:
            generator_hook = FeedableTrainOpsHook(
                train_ops.generator_train_op,
                train_steps.generator_train_steps,
                train_input_fn,
                return_feed_dict=True)

            discriminator_hook = FeedableTrainOpsHook(
                train_ops.discriminator_train_op,
                train_steps.discriminator_train_steps, train_input_fn)
        return [generator_hook, discriminator_hook]

    model = experiment.make_model_fn(args, data_params, hooks_fn)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    # enable XLA JIT
    # sess_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    # summaries once per epoch, checkpoints every 4 epochs
    config = tf.estimator.RunConfig(model_dir=str(args.model_dir),
                                    session_config=sess_config,
                                    save_summary_steps=ds_train.size(),
                                    save_checkpoints_secs=None,
                                    save_checkpoints_steps=4 * ds_train.size(),
                                    keep_checkpoint_max=2)

    estimator = tf.estimator.Estimator(model.model_fn, config=config)

    train_hooks = [PrintParameterSummary()]
    if args.restore_from_checkpoint is not None:
        train_hooks.append(
            RestoreFromCheckpointHook(str(args.restore_from_checkpoint)))

    if args.debug:
        from tensorflow.python import debug as tf_debug

        train_hooks.append(tf_debug.TensorBoardDebugHook("localhost:6064"))

    predict_fn = experiment.make_predict_fn(args.data_dir,
                                            args.num_latent,
                                            n_samples=1000,
                                            batch_size=1000)
    ckpt_listener = PredictAndEvalMolecule(estimator, predict_fn, mol_metrics,
                                           str(args.model_dir))

    # step-decay at epochs 80/150/200 for both learning rates
    hparams_setter = [
        ScheduledHyperParamSetter('generator_learning_rate:0',
                                  args.generator_learning_rate,
                                  [(80, 0.5 * args.generator_learning_rate),
                                   (150, 0.1 * args.generator_learning_rate),
                                   (200, 0.01 * args.generator_learning_rate)],
                                  steps_per_epoch=ds_train.size()),
        ScheduledHyperParamSetter(
            'discriminator_learning_rate:0',
            args.discriminator_learning_rate,
            [(80, 0.5 * args.discriminator_learning_rate),
             (150, 0.1 * args.discriminator_learning_rate),
             (200, 0.01 * args.discriminator_learning_rate)],
            steps_per_epoch=ds_train.size())
    ]
    train_hooks.extend(hparams_setter)

    # optionally blend in the reward loss via the 'lam' hyper-parameter
    if args.weight_reward_loss > 0:
        if args.weight_reward_loss_schedule == 'linear':
            lambda_setter = ScheduledHyperParamSetter(
                model.params, 'lam',
                [(args.reward_loss_delay, 1.0),
                 (args.epochs, 1.0 - args.weight_reward_loss)], True)
        elif args.weight_reward_loss_schedule == 'const':
            lambda_setter = ScheduledHyperParamSetter(
                model.params, 'lam',
                [(args.reward_loss_delay + 1, 1.0 - args.weight_reward_loss)],
                False)
        else:
            raise ValueError('unknown schedule: {!r}'.format(
                args.weight_reward_loss_schedule))

        # BUG FIX: the original appended to hparams_setter AFTER it had
        # already been extend()-ed into train_hooks, so lambda_setter was
        # never actually registered; append to train_hooks instead.
        train_hooks.append(lambda_setter)

    train_start = time.time()
    estimator.train(train_input_fn,
                    hooks=train_hooks,
                    saving_listeners=[ckpt_listener])
    train_end = time.time()

    time_d = datetime.timedelta(seconds=int(train_end - train_start))
    LOG.info('Training for %d epochs finished in %s', args.epochs, time_d)