Example #1
import os

import tensorflow as tf

# `cfg`, `data`, and `model` are repo-local modules; the import lines below
# are assumptions added so the snippet stands on its own.
from config import cfg
import data
import model


def train(tf_config, logger):
    dataset = data.Dataset3D(cfg.DATASET, cfg.RNG_SEED, training=True)
    imgs, labels = dataset.preprocessing(
        augment=True, batch_size=cfg.TRAIN.BATCH_SIZE, num_epochs=cfg.TRAIN.EPOCH)

    net, _ = model.unet_3d(imgs, bn_training=True, layers=4, features_root=32,
                           dropout_training=True, dataset=cfg.DATASET)
    # 1x1x1 conv head producing the single-channel output map.
    with tf.variable_scope('cls'):
        net = tf.layers.conv3d(net, 1, 1, activation=tf.nn.relu)
    # Per-voxel MSE against the RATIO-scaled labels, plus an L1 penalty
    # between the per-volume sums of prediction and ground truth.
    loss_pixel = tf.losses.mean_squared_error(
        labels * cfg.MODEL.RATIO[cfg.DATASET], net)
    loss_pixel_sum = tf.losses.absolute_difference(
        tf.reduce_sum(labels, axis=[1, 2, 3, 4]),
        tf.reduce_sum(net / cfg.MODEL.RATIO[cfg.DATASET], axis=[1, 2, 3, 4]))

    # Cosine decay with warm restarts; weight decay is rescaled by the same
    # factor so the decay-to-learning-rate ratio stays constant for AdamW.
    lr_decayed = tf.train.cosine_decay_restarts(
        cfg.SOLVER.BASE_LR, tf.train.get_or_create_global_step(), cfg.SOLVER.RESTART_STEP)
    wd = cfg.SOLVER.WEIGHT_DECAY * lr_decayed / cfg.SOLVER.BASE_LR
    optimizer = tf.contrib.opt.AdamWOptimizer(wd, learning_rate=lr_decayed)
    step_pixel = optimizer.minimize(
        loss_pixel, global_step=tf.train.get_or_create_global_step())

    tf.summary.scalar('per_pixel_mse', loss_pixel)
    tf.summary.scalar('sum_mae', loss_pixel_sum)
    merged = tf.summary.merge_all()
    # Group the batch-norm update ops with the optimizer step so the moving
    # statistics are refreshed on every iteration.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    step = tf.group([step_pixel, update_ops])
    saver = tf.train.Saver(max_to_keep=1000)
    if not os.path.exists(cfg.OUTPUT_DIR):
        os.mkdir(cfg.OUTPUT_DIR)

    with tf.Session(config=tf_config) as sess:
        summary_writer = tf.summary.FileWriter(
            os.path.join(cfg.OUTPUT_DIR, 'train'), sess.graph)
        if tf.train.latest_checkpoint(cfg.OUTPUT_DIR) is None:
            sess.run(tf.global_variables_initializer())
            start_step = 0
            logger.info('Saving path is {}'.format(cfg.OUTPUT_DIR))
        else:
            weights_path = tf.train.latest_checkpoint(cfg.OUTPUT_DIR)
            start_step = int(weights_path.split('-')[-1])
            tf.train.Saver().restore(sess, weights_path)
            logger.info('Restoring weights from {}'.format(weights_path))
        logger.info('Training at Step {}'.format(start_step + 1))

        for i in range(start_step, cfg.TRAIN.STEP):
            if i % cfg.LOG_PERIOD == 0 or i == cfg.TRAIN.STEP - 1:
                loss_pixel_val, loss_sum_val, summary, _ = sess.run(
                    [loss_pixel, loss_pixel_sum, merged, step])
                summary_writer.add_summary(summary, i + 1)
                logger.info('Step:{}/{} per_pixel:{:6.3f}  sum:{:6.3f}'.format(
                    i + 1, cfg.TRAIN.STEP, loss_pixel_val, loss_sum_val))
            else:
                sess.run([step])
            if i == cfg.TRAIN.STEP - 1:
                weights_path = saver.save(
                    sess, os.path.join(cfg.OUTPUT_DIR, 'model'), global_step=i + 1)
                logger.info('Saving weights to {}'.format(weights_path))
    tf.reset_default_graph()
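
A minimal sketch of how train might be invoked, assuming the module above is importable; the session config and logger wiring here are illustrative and not part of the original example:

import logging

import tensorflow as tf

# Hypothetical entry point: the original snippet never shows the call site.
if __name__ == '__main__':
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True  # allocate GPU memory on demand

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger('train')

    train(tf_config, logger)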
Example #2
def _predict(imgs, labels, reuse=False):
    # Inference graph: batch norm and dropout run in inference mode;
    # reuse=True attaches to weights created by an earlier call.
    net, _ = model.unet_3d(imgs, bn_training=False, layers=4, features_root=32,
                           dropout_training=False, dataset=cfg.DATASET, reuse=reuse)
    with tf.variable_scope('cls', reuse=reuse):
        net = tf.layers.conv3d(net, 1, 1, activation=tf.nn.relu)
    # Undo the RATIO scaling so the summed prediction is comparable to the
    # summed labels.
    pred_sum = tf.reduce_sum(
        net / cfg.MODEL.RATIO[cfg.DATASET], axis=[0, 1, 2, 3, 4])
    label_sum = tf.reduce_sum(labels, axis=[0, 1, 2, 3, 4])
    return pred_sum, label_sum, net
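
With TF1 variable sharing, reuse=True lets a second call attach to the variables the first call created instead of allocating new ones. A minimal sketch of building two weight-sharing graph copies (the placeholder shapes are assumptions):

# Hypothetical usage: two graph copies over one set of weights.
# Inputs are 5-D (batch, depth, height, width, channels); shapes assumed.
imgs_a = tf.placeholder(tf.float32, [None, 32, 64, 64, 1])
labels_a = tf.placeholder(tf.float32, [None, 32, 64, 64, 1])
imgs_b = tf.placeholder(tf.float32, [None, 32, 64, 64, 1])
labels_b = tf.placeholder(tf.float32, [None, 32, 64, 64, 1])

pred_sum_a, label_sum_a, _ = _predict(imgs_a, labels_a, reuse=False)
# reuse=True: no new variables; this copy shares the weights built above.
pred_sum_b, label_sum_b, _ = _predict(imgs_b, labels_b, reuse=True)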
Example #3
# We'll save the worker logs and models separately but only
# use the logs/saved model from worker 0.
args.saved_model = "./worker{}/3d_unet_decathlon.hdf5".format(hvd.rank())

# Optimize CPU threads for TensorFlow
CONFIG = tf.ConfigProto(inter_op_parallelism_threads=args.interop_threads,
                        intra_op_parallelism_threads=args.intraop_threads)

SESS = tf.Session(config=CONFIG)

K.backend.set_session(SESS)

model, opt = unet_3d(
    use_upsampling=args.use_upsampling,
    n_cl_in=args.number_input_channels,
    learning_rate=args.lr * hvd.size(),  # linear LR scaling with worker count
    n_cl_out=1,  # single channel (greyscale)
    dropout=0.2,
    print_summary=print_summary)

# Wrap the optimizer so gradients are averaged across Horovod workers.
opt = hvd.DistributedOptimizer(opt)

model.compile(
    optimizer=opt,
    # loss=[combined_dice_ce_loss],
    loss=[dice_coef_loss],
    metrics=[dice_coef, "accuracy", sensitivity, specificity])

if hvd.rank() == 0:
    start_time = datetime.datetime.now()
    print("Started script on {}".format(start_time))
Example #4
# The call that opens this snippet is truncated; the assignment and class
# name below (DatasetGenerator) are assumed so the fragment parses.
brats_data = DatasetGenerator(crop_dim=crop_dim,
                              batch_size=args.batch_size,
                              train_test_split=args.train_test_split,
                              validate_test_split=args.validate_test_split,
                              number_output_classes=args.number_output_classes,
                              random_seed=args.random_seed,
                              shard=hvd.rank())

if hvd.rank() == 0:
    print("{} workers".format(hvd.size()))
    brats_data.print_info()  # Print dataset information
"""
2. Create the TensorFlow model
"""
model = unet_3d(input_dim=crop_dim,
                filters=args.filters,
                number_output_classes=args.number_output_classes,
                use_upsampling=args.use_upsampling,
                concat_axis=-1,
                model_name=args.saved_model_name)

local_opt = K.optimizers.Adam()
hvd_opt = hvd.DistributedOptimizer(local_opt)

model.compile(loss=dice_loss,
              metrics=[dice_coef, soft_dice_coef],
              optimizer=hvd_opt)

checkpoint = K.callbacks.ModelCheckpoint(args.saved_model_name,
                                         verbose=1,
                                         save_best_only=True)

# TensorBoard
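
The snippet breaks off at the TensorBoard comment. In a Horovod run the usual continuation is a TensorBoard callback plus a callback list in which only rank 0 logs and checkpoints; a sketch under those assumptions (the log directory is made up):

tb_logs = K.callbacks.TensorBoard(log_dir="./tensorboard_logs")

# Sync initial weights across workers; only rank 0 saves and logs.
callbacks = [hvd.callbacks.BroadcastGlobalVariablesCallback(0)]
if hvd.rank() == 0:
    callbacks.extend([checkpoint, tb_logs])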
Example #5
from pathlib import Path

from keras.callbacks import (CSVLogger, EarlyStopping, ModelCheckpoint,
                             ReduceLROnPlateau)

# `yaml_utils`, `data_generator`, `create_data_yaml`, `unet_3d`, and
# `output_path` are this repo's own modules/objects; their import paths are
# not shown in the snippet, so they are left out here.


def data_loader():
    train_list = yaml_utils.read(str(output_path / 't22seg_train.yaml'))
    train_generator = data_generator(train_list, batch_size=6)

    test_list = yaml_utils.read(str(output_path / 't22seg_test.yaml'))
    test_generator = data_generator(test_list, batch_size=12)
    return train_generator, len(train_list), test_generator, len(test_list)


if __name__ == '__main__':
    create_data_yaml(output_path)  # first, build the train/test yaml lists

    # Second, create the generators.
    train_generator, train_steps, validation_generator, validation_steps = data_loader()

    # Third, create the model; input_shape is (channels, x, y, z).
    _model = unet_3d(input_shape=(1, 64, 64, 32))

    # Keep all outputs in folder '_' so they are easy to find and delete.
    Path('_').mkdir(parents=True, exist_ok=True)
    # Finally, train the model.
    _model.fit_generator(
        generator=train_generator,
        steps_per_epoch=train_steps,
        epochs=200,
        validation_data=validation_generator,
        validation_steps=validation_steps,
        callbacks=[
            ModelCheckpoint('_/tumor_segmentation_model.h5',
                            save_best_only=True),
            CSVLogger('_/training.log', append=True),
            ReduceLROnPlateau(factor=0.5, patience=50, verbose=1),
            # patience=None (as in the original) raises a TypeError inside
            # Keras; a patience of the full run disables early stopping.
            EarlyStopping(verbose=1, patience=200),
        ])
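
After training, the best weights are in the checkpoint that ModelCheckpoint wrote above. A hypothetical inference step for reloading it follows; if unet_3d compiled the model with custom losses or metrics, they would also have to be passed through load_model's custom_objects argument:

import numpy as np
from keras.models import load_model

# Reload the best checkpoint saved during training.
best_model = load_model('_/tumor_segmentation_model.h5')

# Dummy volume in the model's (batch, channels, x, y, z) layout.
volume = np.zeros((1, 1, 64, 64, 32), dtype='float32')
prediction = best_model.predict(volume)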