def predict(params, model, dataset, logger):
    prec = 'amp' if params.use_amp else 'fp32'

    # Restore the model for inference: a SavedModel export, a TF-TRT converted
    # model, or the latest training checkpoint.
    if params.model_dir:
        if params.use_savedmodel:
            model = tf.keras.models.load_model(os.path.join(params.model_dir, f'saved_model_{prec}'))
        elif params.use_tftrt:
            model = TFTRTModel(model_dir=params.model_dir, precision=prec)
        else:
            checkpoint = tf.train.Checkpoint(model=model)
            checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()

    @tf.function
    def prediction_step(features):
        return tf.nn.softmax(model(features, training=False), axis=-1)

    if params.benchmark:
        assert params.max_steps > params.warmup_steps, \
            "max_steps value has to be greater than warmup_steps"
        # Benchmark mode: time inference steps after the warmup period.
        timestamps = []
        for iteration, images in enumerate(dataset.test_fn(count=None, drop_remainder=True)):
            prediction_step(images)
            if iteration > params.warmup_steps:
                timestamps.append(time())
            if iteration >= params.max_steps:
                break
        deltas = np.array([timestamps[i + 1] - timestamps[i] for i in range(len(timestamps) - 1)])
        stats = process_performance_stats(deltas, params.batch_size, mode="test")
        logger.log(step=(), data=stats)
    else:
        # Inference mode: run one pass over the test set and export the predicted
        # binary masks as a multipage TIFF.
        predictions = np.concatenate([prediction_step(images).numpy()
                                      for images in dataset.test_fn(count=1)], axis=0)
        binary_masks = [np.argmax(p, axis=-1).astype(np.uint8) * 255 for p in predictions]
        multipage_tif = [Image.fromarray(mask).resize(size=(512, 512), resample=Image.BILINEAR)
                         for mask in binary_masks]

        output_dir = os.path.join(params.model_dir, 'predictions')
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        multipage_tif[0].save(os.path.join(output_dir, 'test-masks.tif'),
                              compression="tiff_deflate",
                              save_all=True,
                              append_images=multipage_tif[1:])

        print("Predictions saved at {}".format(output_dir))

    logger.flush()
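# --- Hypothetical helper (sketch, not part of the original module) ---
# Shows how the multipage TIFF written by predict() above could be read back
# for a quick sanity check. The file name and output_dir layout mirror the
# save call above; the helper name `inspect_predictions` is an assumption
# added for illustration only.
def inspect_predictions(output_dir):
    from PIL import ImageSequence
    masks = Image.open(os.path.join(output_dir, 'test-masks.tif'))
    for page_idx, page in enumerate(ImageSequence.Iterator(masks)):
        mask = np.array(page)  # uint8 mask with values 0 (background) or 255 (foreground)
        print("page {}: shape={}, foreground fraction={:.3f}".format(
            page_idx, mask.shape, float((mask > 0).mean())))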
def train(params, model, dataset, logger):
    np.random.seed(params.seed)
    tf.random.set_seed(params.seed)

    max_steps = params.max_steps // hvd.size()

    optimizer = tf.keras.optimizers.Adam(learning_rate=params.learning_rate)
    if params.use_amp:
        optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, "dynamic")

    ce_loss = tf.keras.metrics.Mean(name='ce_loss')
    f1_loss = tf.keras.metrics.Mean(name='dice_loss')

    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    if params.resume_training and params.model_dir:
        checkpoint.restore(tf.train.latest_checkpoint(params.model_dir))

    @tf.function
    def train_step(features, labels, warmup_batch=False):
        with tf.GradientTape() as tape:
            output_map = model(features)
            crossentropy_loss, dice_loss = partial_losses(output_map, labels)
            added_losses = tf.add(crossentropy_loss, dice_loss, name="total_loss_ref")
            # L2 weight decay on all trainable variables except batch norm parameters.
            loss = added_losses + params.weight_decay * tf.add_n(
                [tf.nn.l2_loss(v) for v in model.trainable_variables
                 if 'batch_normalization' not in v.name])
            if params.use_amp:
                loss = optimizer.get_scaled_loss(loss)

        tape = hvd.DistributedGradientTape(tape)
        gradients = tape.gradient(loss, model.trainable_variables)
        if params.use_amp:
            gradients = optimizer.get_unscaled_gradients(gradients)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Note: broadcast should be done after the first gradient step to ensure optimizer
        # initialization.
        if warmup_batch:
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(optimizer.variables(), root_rank=0)

        ce_loss(crossentropy_loss)
        f1_loss(dice_loss)
        return loss

    if params.benchmark:
        assert max_steps * hvd.size() > params.warmup_steps, \
            "max_steps value has to be greater than warmup_steps"
        # Benchmark mode: time training steps after the warmup period.
        timestamps = []
        for iteration, (images, labels) in enumerate(dataset.train_fn(drop_remainder=True)):
            loss = train_step(images, labels, warmup_batch=iteration == 0).numpy()
            if iteration > params.warmup_steps:
                timestamps.append(time())
            if iteration >= max_steps * hvd.size():
                break
        if hvd.rank() == 0:
            deltas = np.array([timestamps[i + 1] - timestamps[i] for i in range(len(timestamps) - 1)])
            stats = process_performance_stats(deltas, hvd.size() * params.batch_size, mode="train")
            logger.log(step=(), data=stats)
    else:
        # Training mode: log losses periodically and evaluate on rank 0.
        for iteration, (images, labels) in enumerate(dataset.train_fn()):
            train_step(images, labels, warmup_batch=iteration == 0)
            if hvd.rank() == 0:
                if iteration % params.log_every == 0:
                    logger.log(step=(iteration, max_steps),
                               data={"train_ce_loss": float(ce_loss.result()),
                                     "train_dice_loss": float(f1_loss.result()),
                                     "train_total_loss": float(f1_loss.result() + ce_loss.result())})

                if (params.evaluate_every > 0) and (iteration % params.evaluate_every == 0):
                    evaluate(params, model, dataset, logger, restore_checkpoint=False)
                    f1_loss.reset_states()
                    ce_loss.reset_states()

            if iteration >= max_steps:
                break

        if hvd.rank() == 0:
            checkpoint.save(file_prefix=os.path.join(params.model_dir, "checkpoint"))

    logger.flush()
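# --- Hypothetical sketch (not the repository's partial_losses) ---
# train_step above unpacks partial_losses(output_map, labels) into a
# (cross-entropy, Dice) pair. The helper below, `example_partial_losses`, is
# an illustrative stand-in under the assumption that the output map holds
# logits and the labels are one-hot NHWC tensors; it is not the actual
# implementation imported elsewhere in the repo.
def example_partial_losses(y_pred, y_true, eps=1.0):
    # Per-pixel softmax cross-entropy, averaged over batch and spatial dims.
    ce = tf.reduce_mean(
        tf.keras.losses.categorical_crossentropy(y_true, y_pred, from_logits=True))
    # Soft Dice loss computed on the softmax probabilities.
    probs = tf.nn.softmax(y_pred, axis=-1)
    intersection = tf.reduce_sum(probs * y_true, axis=(1, 2, 3))
    union = tf.reduce_sum(probs + y_true, axis=(1, 2, 3))
    dice = tf.reduce_mean(1. - (2. * intersection + eps) / (union + eps))
    return ce, dice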
def train(params, model, dataset, logger, tb_logger=None):
    np.random.seed(params.seed)
    tf.random.set_seed(params.seed)

    num_workers = hvd_size() if horovod_enabled() else 1
    worker_id = hvd_rank() if horovod_enabled() else 0
    max_steps = params.max_steps // num_workers

    optimizer = tf.keras.optimizers.Adam(learning_rate=params.learning_rate)

    ce_loss = tf.keras.metrics.Mean(name='ce_loss')
    f1_loss = tf.keras.metrics.Mean(name='dice_loss')

    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    if params.resume_training and params.model_dir:
        checkpoint.restore(tf.train.latest_checkpoint(params.model_dir))

    if tb_logger is not None:
        write_hparams_v2(tb_logger.train_writer, vars(params))

    @tf.function
    def train_step(features, labels, warmup_batch=False):
        with tf.GradientTape() as tape:
            output_map = model(features)
            crossentropy_loss, dice_loss = partial_losses(output_map, labels)
            added_losses = tf.add(crossentropy_loss, dice_loss, name="total_loss_ref")
            # L2 weight decay on all trainable variables except batch norm parameters.
            loss = added_losses + params.weight_decay * tf.add_n(
                [tf.nn.l2_loss(v) for v in model.trainable_variables
                 if 'batch_normalization' not in v.name])

        if horovod_enabled():
            tape = hvd.DistributedGradientTape(tape)

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Note: broadcast should be done after the first gradient step to ensure optimizer
        # initialization.
        if horovod_enabled() and warmup_batch:
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(optimizer.variables(), root_rank=0)

        ce_loss(crossentropy_loss)
        f1_loss(dice_loss)
        return loss

    if params.benchmark:
        assert max_steps * num_workers > params.warmup_steps, \
            "max_steps value has to be greater than warmup_steps"
        # Benchmark mode: time training steps after the warmup period.
        timestamps = []
        for iteration, (images, labels) in enumerate(dataset.train_fn(drop_remainder=True)):
            loss = train_step(images, labels, warmup_batch=iteration == 0).numpy()
            if iteration > params.warmup_steps:
                timestamps.append(time())
            if iteration >= max_steps * num_workers:
                break
        if worker_id == 0:
            deltas = np.array([timestamps[i + 1] - timestamps[i] for i in range(len(timestamps) - 1)])
            stats = process_performance_stats(deltas, num_workers * params.batch_size, mode="train")
            logger.log(step=(), data=stats)
    else:
        # Training mode: optionally use synthetic data, log losses and throughput,
        # and mirror the metrics to TensorBoard when a tb_logger is provided.
        timestamp = time()
        dataset_fn = dataset.synth_fn if params.synth_data else dataset.train_fn
        for iteration, (images, labels) in enumerate(dataset_fn()):
            # assign returned loss as a numpy object to transfer the data to host
            loss = train_step(images, labels, warmup_batch=iteration == 0).numpy()
            if worker_id == 0 or params.log_all_workers:
                if iteration % params.log_every == 0:
                    duration = float(time() - timestamp) / params.log_every
                    timestamp = time()
                    data = {
                        "train_ce_loss": float(ce_loss.result()),
                        "train_dice_loss": float(f1_loss.result()),
                        "train_total_loss": float(f1_loss.result() + ce_loss.result()),
                        "iter duration [ms]": 1000 * duration,
                        "IPS": params.batch_size / duration
                    }
                    logger.log(step=(iteration, max_steps), data=data)

                    if tb_logger is not None:
                        with tb_logger.train_writer.as_default():
                            for name, value in data.items():
                                tf.summary.scalar(name, value, step=iteration)
                            # for consistency
                            tf.summary.scalar("loss", data["train_total_loss"], step=iteration)
                            tf.summary.scalar("examples/sec", data["IPS"], step=iteration)
                            tf.summary.scalar("global_step/sec", 1. / duration, step=iteration)

                if (params.evaluate_every > 0) and (iteration % params.evaluate_every == 0):
                    evaluate(params, model, dataset, logger, tb_logger, restore_checkpoint=False)
                    f1_loss.reset_states()
                    ce_loss.reset_states()

            if iteration >= max_steps:
                break

        if not params.disable_ckpt_saving and worker_id == 0:
            checkpoint.save(file_prefix=os.path.join(params.model_dir, "checkpoint"))

    logger.flush()
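# --- Hypothetical sketch (not the repository's process_performance_stats) ---
# Illustrates the kind of summary the benchmark branches above expect from
# process_performance_stats(): per-step time deltas (in seconds) reduced to a
# throughput figure and latency percentiles. The field names below are
# assumptions for illustration only.
def example_performance_stats(deltas, batch_size, mode):
    return {
        "{}_throughput [img/s]".format(mode): float(batch_size / np.mean(deltas)),
        "{}_latency_mean [ms]".format(mode): float(1000 * np.mean(deltas)),
        "{}_latency_90 [ms]".format(mode): float(1000 * np.percentile(deltas, 90)),
        "{}_latency_99 [ms]".format(mode): float(1000 * np.percentile(deltas, 99)),
    }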