Example #1
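These snippets are lifted from a larger module, so the imports are not shown. A minimal, assumed import block for Example #1 would be the following (TFTRTModel and process_performance_stats are project-local helpers that are not reproduced here):

import os
from time import time

import numpy as np
import tensorflow as tf
from PIL import Image

# Assumed project-local helpers, provided elsewhere in the surrounding repository:
# TFTRTModel, process_performance_stats
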
def predict(params, model, dataset, logger):
    prec = 'amp' if params.use_amp else 'fp32'
    if params.model_dir:
        if params.use_savedmodel:
            model = tf.keras.models.load_model(os.path.join(params.model_dir, f'saved_model_{prec}'))
        elif params.use_tftrt:
            model = TFTRTModel(model_dir=params.model_dir, precision=prec)
        else:
            checkpoint = tf.train.Checkpoint(model=model)
            checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()

    @tf.function
    def prediction_step(features):
        return tf.nn.softmax(model(features, training=False), axis=-1)

    if params.benchmark:
        assert params.max_steps > params.warmup_steps, \
            "max_steps value has to be greater than warmup_steps"
        timestamps = []
        for iteration, images in enumerate(dataset.test_fn(count=None, drop_remainder=True)):
            prediction_step(images)
            if iteration > params.warmup_steps:
                timestamps.append(time())
            if iteration >= params.max_steps:
                break

        deltas = np.array([timestamps[i + 1] - timestamps[i] for i in range(len(timestamps) - 1)])
        stats = process_performance_stats(deltas, params.batch_size, mode="test")
        logger.log(step=(), data=stats)
    else:
        predictions = np.concatenate([prediction_step(images).numpy()
                                      for images in dataset.test_fn(count=1)], axis=0)
        binary_masks = [np.argmax(p, axis=-1).astype(np.uint8) * 255 for p in predictions]
        multipage_tif = [Image.fromarray(mask).resize(size=(512, 512), resample=Image.BILINEAR)
                         for mask in binary_masks]

        output_dir = os.path.join(params.model_dir, 'predictions')
        os.makedirs(output_dir, exist_ok=True)
        multipage_tif[0].save(os.path.join(output_dir, 'test-masks.tif'),
                              compression="tiff_deflate",
                              save_all=True,
                              append_images=multipage_tif[1:])

        print("Predictions saved at {}".format(output_dir))
    logger.flush()
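
A hypothetical invocation, assuming params is a simple namespace carrying the flags that predict() reads; every name and value below is illustrative rather than taken from the original code:

from argparse import Namespace

# Illustrative only: the field names mirror the attributes predict() reads above.
params = Namespace(
    use_amp=False,
    use_savedmodel=False,
    use_tftrt=False,
    model_dir="/results",   # assumed checkpoint/output directory
    benchmark=False,
    max_steps=100,
    warmup_steps=10,
    batch_size=1,
)
# model, dataset and logger are assumed to come from the surrounding application:
# a Keras model, a data pipeline exposing test_fn(), and a logger with log()/flush().
# predict(params, model, dataset, logger)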
Example #2
def train(params, model, dataset, logger):
    np.random.seed(params.seed)
    tf.random.set_seed(params.seed)
    max_steps = params.max_steps // hvd.size()

    optimizer = tf.keras.optimizers.Adam(learning_rate=params.learning_rate)
    if params.use_amp:
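        # Dynamic loss scaling for AMP; this is the experimental API used before TF 2.4.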
        optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, "dynamic")

    ce_loss = tf.keras.metrics.Mean(name='ce_loss')
    f1_loss = tf.keras.metrics.Mean(name='dice_loss')
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    if params.resume_training and params.model_dir:
        checkpoint.restore(tf.train.latest_checkpoint(params.model_dir))

    @tf.function
    def train_step(features, labels, warmup_batch=False):
        with tf.GradientTape() as tape:
            output_map = model(features)
            crossentropy_loss, dice_loss = partial_losses(output_map, labels)
            added_losses = tf.add(crossentropy_loss, dice_loss, name="total_loss_ref")
            loss = added_losses + params.weight_decay * tf.add_n(
                [tf.nn.l2_loss(v) for v in model.trainable_variables
                 if 'batch_normalization' not in v.name])

            if params.use_amp:
                loss = optimizer.get_scaled_loss(loss)
        tape = hvd.DistributedGradientTape(tape)
        gradients = tape.gradient(loss, model.trainable_variables)
        if params.use_amp:
            gradients = optimizer.get_unscaled_gradients(gradients)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Note: broadcast should be done after the first gradient step to ensure optimizer
        # initialization.
        if warmup_batch:
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(optimizer.variables(), root_rank=0)

        ce_loss(crossentropy_loss)
        f1_loss(dice_loss)
        return loss

    if params.benchmark:
        assert max_steps * hvd.size() > params.warmup_steps, \
            "max_steps value has to be greater than warmup_steps"
        timestamps = []
        for iteration, (images, labels) in enumerate(dataset.train_fn(drop_remainder=True)):
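            # assign returned loss as a numpy object to transfer the data to host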
            loss = train_step(images, labels, warmup_batch=iteration == 0).numpy()
            if iteration > params.warmup_steps:
                timestamps.append(time())
            if iteration >= max_steps * hvd.size():
                break

        if hvd.rank() == 0:
            deltas = np.array([timestamps[i + 1] - timestamps[i] for i in range(len(timestamps) - 1)])
            stats = process_performance_stats(deltas, hvd.size() * params.batch_size, mode="train")
            logger.log(step=(), data=stats)
    else:
        for iteration, (images, labels) in enumerate(dataset.train_fn()):
            train_step(images, labels, warmup_batch=iteration == 0)
            if hvd.rank() == 0:
                if iteration % params.log_every == 0:
                    logger.log(step=(iteration, max_steps),
                               data={"train_ce_loss": float(ce_loss.result()),
                                     "train_dice_loss": float(f1_loss.result()),
                                     "train_total_loss": float(f1_loss.result() + ce_loss.result())})

                if (params.evaluate_every > 0) and (iteration % params.evaluate_every == 0):
                    evaluate(params, model, dataset, logger, restore_checkpoint=False)

                f1_loss.reset_states()
                ce_loss.reset_states()

            if iteration >= max_steps:
                break
        if hvd.rank() == 0:
            checkpoint.save(file_prefix=os.path.join(params.model_dir, "checkpoint"))

    logger.flush()
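
Example #2 (and Example #3, when Horovod is enabled) assumes Horovod has already been initialized by the caller. A minimal setup sketch using the standard Horovod TF2 API (not part of the original snippet) looks like this:

import horovod.tensorflow as hvd
import tensorflow as tf

hvd.init()
# Pin each worker process to a single GPU so processes do not compete for device memory.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')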
Example #3
def train(params, model, dataset, logger, tb_logger=None):
    np.random.seed(params.seed)
    tf.random.set_seed(params.seed)

    num_workers = hvd_size() if horovod_enabled() else 1
    worker_id = hvd_rank() if horovod_enabled() else 0
    max_steps = params.max_steps // num_workers

    optimizer = tf.keras.optimizers.Adam(learning_rate=params.learning_rate)

    ce_loss = tf.keras.metrics.Mean(name='ce_loss')
    f1_loss = tf.keras.metrics.Mean(name='dice_loss')
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    if params.resume_training and params.model_dir:
        checkpoint.restore(tf.train.latest_checkpoint(params.model_dir))

    if tb_logger is not None:
        write_hparams_v2(tb_logger.train_writer, vars(params))

    @tf.function
    def train_step(features, labels, warmup_batch=False):
        with tf.GradientTape() as tape:
            output_map = model(features)
            crossentropy_loss, dice_loss = partial_losses(output_map, labels)
            added_losses = tf.add(crossentropy_loss, dice_loss, name="total_loss_ref")
            loss = added_losses + params.weight_decay * tf.add_n(
                [tf.nn.l2_loss(v) for v in model.trainable_variables
                 if 'batch_normalization' not in v.name])

        if horovod_enabled():
            tape = hvd.DistributedGradientTape(tape)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Note: broadcast should be done after the first gradient step to ensure optimizer
        # initialization.
        if horovod_enabled() and warmup_batch:
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(optimizer.variables(), root_rank=0)

        ce_loss(crossentropy_loss)
        f1_loss(dice_loss)
        return loss

    if params.benchmark:
        assert max_steps * num_workers > params.warmup_steps, \
            "max_steps value has to be greater than warmup_steps"
        timestamps = []
        for iteration, (images, labels) in enumerate(dataset.train_fn(drop_remainder=True)):
            loss = train_step(images, labels, warmup_batch=iteration == 0).numpy()
            if iteration > params.warmup_steps:
                timestamps.append(time())

            if iteration >= max_steps * num_workers:
                break

        if worker_id == 0:
            deltas = np.array([timestamps[i + 1] - timestamps[i] for i in range(len(timestamps) - 1)])
            stats = process_performance_stats(deltas, num_workers * params.batch_size, mode="train")
            logger.log(step=(), data=stats)
    else:
        timestamp = time()
        dataset_fn = dataset.synth_fn if params.synth_data else dataset.train_fn
        for iteration, (images, labels) in enumerate(dataset_fn()):
            # assign returned loss as a numpy object to transfer the data to host
            loss = train_step(images, labels, warmup_batch=iteration == 0).numpy()
            if worker_id == 0 or params.log_all_workers:
                if iteration % params.log_every == 0:
                    duration = float(time() - timestamp) / params.log_every
                    timestamp = time()
                    data = {
                        "train_ce_loss": float(ce_loss.result()),
                        "train_dice_loss": float(f1_loss.result()),
                        "train_total_loss": float(f1_loss.result() + ce_loss.result()),
                        "iter duration [ms]": 1000 * duration,
                        "IPS": params.batch_size / duration
                    }
                    logger.log(step=(iteration, max_steps), data=data)

                    if tb_logger is not None:
                        with tb_logger.train_writer.as_default():
                            for name, value in data.items():
                                tf.summary.scalar(name, value, step=iteration)
                            # for consistency
                            tf.summary.scalar("loss", data["train_total_loss"], step=iteration)
                            tf.summary.scalar("examples/sec", data["IPS"], step=iteration)
                            tf.summary.scalar("global_step/sec", 1. / duration, step=iteration)

                if (params.evaluate_every > 0) and (iteration % params.evaluate_every == 0):
                    evaluate(params, model, dataset, logger, tb_logger,
                             restore_checkpoint=False)

                f1_loss.reset_states()
                ce_loss.reset_states()

            if iteration >= max_steps:
                break

        if not params.disable_ckpt_saving and worker_id == 0:
            checkpoint.save(file_prefix=os.path.join(params.model_dir, "checkpoint"))

    logger.flush()
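
Example #3 only requires tb_logger to expose a train_writer that behaves like a tf.summary writer (write_hparams_v2 is a project-local helper and is not reproduced here). A minimal stand-in, with the directory path purely illustrative, could be:

from types import SimpleNamespace
import tensorflow as tf

# Minimal stand-in for the tb_logger object used in Example #3.
tb_logger = SimpleNamespace(
    train_writer=tf.summary.create_file_writer("/tmp/unet_tb/train"),
)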