def train(args, train_list, val_list, u_model, net_input_shape):
    # Compile the loaded model
    model, loss_weights = compile_model(args=args, train_list=train_list, net_input_shape=net_input_shape,
                                        uncomp_model=u_model)

    # Load pre-trained weights
    if args.custom_weights_path != '':
        try:
            model.load_weights(args.custom_weights_path)
        except Exception as e:
            print(e)
            print('!!! Failed to load weights file. Training without pre-trained weights. !!!')

    # Set the callbacks
    callbacks = get_callbacks(args)

    # Training the network
    history = model.fit_generator(
        generate_train_batches(root_path=args.data_root_dir, train_list=train_list, net_shape=net_input_shape,
                               mod_dirs=args.modality_dir_list, exp_name=args.exp_name, net=args.net,
                               MIP_choices=args.MIP_choices, n_class=args.num_classes, batchSize=args.batch_size,
                               numSlices=args.slices, subSampAmt=args.subsamp, stride=args.stride,
                               shuff=args.shuffle_data, aug_data=args.aug_data),
        max_queue_size=40, workers=4, use_multiprocessing=False,
        steps_per_epoch=int(np.ceil(len(train_list)/args.batch_size*12)), # 12 avg. num of loops in train generator
        validation_data=generate_val_batches(root_path=args.data_root_dir, val_list=val_list, net_shape=net_input_shape,
                                             mod_dirs=args.modality_dir_list, exp_name=args.exp_name, net=args.net,
                                             MIP_choices=args.MIP_choices, n_class=args.num_classes,
                                             batchSize=args.batch_size, numSlices=args.slices, subSampAmt=0,
                                             stride=args.stride, shuff=args.shuffle_data),
        validation_steps=int(np.ceil(len(val_list)/args.batch_size)),
        epochs=args.epochs, class_weight=loss_weights, callbacks=callbacks, verbose=args.verbose)

    # Plot the training data collected
    plot_training(history, args.net, args.num_classes, args.output_dir, args.output_name, args.time)
Example #2
def train(args, train_list, val_list, u_model, net_input_shape):
    # Compile the loaded model
    num_epoch = 400
    model = compile_model(args=args, net_input_shape=net_input_shape, uncomp_model=u_model)
    loss_vec = np.zeros((num_epoch, 1))
    dice_hard_vec = np.zeros((num_epoch, 1))
    val_loss_vec = np.zeros((num_epoch, 1))
    val_dice_hard_vec = np.zeros((num_epoch, 1))
    val_out_seg_loss_vec = np.zeros((num_epoch, 1))
    out_seg_loss_vec = np.zeros((num_epoch, 1))
    val_out_recon_loss_vec = np.zeros((num_epoch, 1))
    out_recon_loss_vec = np.zeros((num_epoch, 1))

    # Train one epoch at a time, rebuilding the callbacks for each iteration
    for i in range(num_epoch):
        print('Epoch %d' % i)
        callbacks = get_callbacks(args, i)

        # Training the network
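        # Fresh train/val generators are rebuilt every epoch before the single-epoch fit_generator call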
        train_batches = generate_train_batches(args.data_root_dir, train_list, net_input_shape, net=args.net, image_shape=net_input_shape,
                                               batchSize=args.batch_size, numSlices=args.slices,
                                               subSampAmt=args.subsamp,
                                               stride=args.stride, shuff=args.shuffle_data, aug_data=args.aug_data)
        val_data = generate_val_batches(args.data_root_dir, val_list, net_input_shape, net=args.net,
                                        batchSize=args.batch_size, numSlices=args.slices, subSampAmt=0,
                                        stride=20, shuff=args.shuffle_data)
        history = model.fit_generator(train_batches.it, max_queue_size=40, workers=4, use_multiprocessing=False,
                                      steps_per_epoch=247, validation_data=val_data.it, validation_steps=200,
                                      callbacks=callbacks, verbose=1)
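
        # fit_generator above runs a single epoch per call, so history.history[<metric>][0]
        # below holds that epoch's value for each tracked metric.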

        # Record this epoch's metrics
        if args.net == 'segcapsr3':
            print(history.history.keys())
            loss_vec[i] = history.history['loss'][0]
            dice_hard_vec[i] = history.history['out_seg_dice_hard'][0]
            val_loss_vec[i] = history.history['val_loss'][0]
            val_dice_hard_vec[i] = history.history['val_out_seg_dice_hard'][0]
            val_out_seg_loss_vec[i] = history.history['val_out_seg_loss'][0]
            out_seg_loss_vec[i] = history.history['out_seg_loss'][0]
            val_out_recon_loss_vec[i] = history.history['val_out_recon_loss'][0]
            out_recon_loss_vec[i] = history.history['out_recon_loss'][0]
            # Dump the running metric arrays to a per-epoch text file
            with open(join(args.output_dir, '_errors_' + str(i) + '.txt'), "w+") as f:
                f.writelines(
                    [str(loss_vec), str(dice_hard_vec), str(out_seg_loss_vec), str(out_recon_loss_vec),
                     str(val_loss_vec), str(val_dice_hard_vec), str(val_out_seg_loss_vec),
                     str(val_out_recon_loss_vec)])
        else:
            loss_vec[i] = history.history['loss'][0]
            dice_hard_vec[i] = history.history['dice_hard'][0]
            val_loss_vec[i] = history.history['val_loss'][0]
            val_dice_hard_vec[i] = history.history['val_dice_hard'][0]
            with open(join(args.output_dir, '_errors_' + str(i) + '.txt'), "w+") as f:
                f.writelines([str(loss_vec), str(dice_hard_vec), str(val_loss_vec), str(val_dice_hard_vec)])
    plot_training(loss_vec, dice_hard_vec, val_loss_vec, val_dice_hard_vec, args, num_epoch)
Example #3
def train(args, train_list, val_list, u_model, net_input_shape):
    # Compile the loaded model
    model = compile_model(args=args,
                          net_input_shape=net_input_shape,
                          uncomp_model=u_model)
    weights_path = args.weights_path
    if args.retrain == 1:
        print('\nRetraining model from weights_path=%s' % weights_path)
        model.load_weights(weights_path)
    # Set the callbacks
    callbacks = get_callbacks(args)

    # Training the network
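    # tf.keras Model.fit accepts Python generators directly; fit_generator is deprecated in TF 2.x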
    history = model.fit(
        generate_train_batches(args.data_root_dir,
                               train_list,
                               net_input_shape,
                               net=args.net,
                               batchSize=args.batch_size,
                               numSlices=args.slices,
                               subSampAmt=args.subsamp,
                               stride=args.stride,
                               shuff=args.shuffle_data,
                               aug_data=args.aug_data),
        max_queue_size=40,
        workers=4,
        use_multiprocessing=False,
        steps_per_epoch=args.steps,
        validation_data=generate_val_batches(args.data_root_dir,
                                             val_list,
                                             net_input_shape,
                                             net=args.net,
                                             batchSize=args.batch_size,
                                             numSlices=args.slices,
                                             subSampAmt=0,
                                             stride=20,
                                             shuff=args.shuffle_data),
        validation_steps=250,  # Set validation stride larger to see more of the data.
        epochs=args.epochs,
        callbacks=callbacks,
        verbose=1)

    # Plot the training data collected
    plot_training(history, args)
Example #4
def train(args, train_list, val_list, u_model, net_input_shape):
    # Compile the loaded model
    model = compile_model(args=args,
                          net_input_shape=net_input_shape,
                          uncomp_model=u_model)
    # Set the callbacks
    callbacks = get_callbacks(args)
    # print(callbacks)
    train_batches = generate_train_batches(args.data_root_dir,
                                           train_list,
                                           net_input_shape,
                                           net=args.net,
                                           batchSize=args.batch_size,
                                           numSlices=args.slices,
                                           subSampAmt=args.subsamp,
                                           stride=args.stride,
                                           shuff=args.shuffle_data,
                                           aug_data=args.aug_data)

    val_batches = generate_val_batches(args.data_root_dir,
                                       val_list,
                                       net_input_shape,
                                       net=args.net,
                                       batchSize=args.batch_size,
                                       numSlices=args.slices,
                                       subSampAmt=0,
                                       stride=20,
                                       shuff=args.shuffle_data)
    print("train_batches..", train_batches)
    # Training the network
    history = model.fit_generator(
        train_batches,
        max_queue_size=40,
        workers=4,
        use_multiprocessing=False,
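        # steps_per_epoch, validation_steps, and epochs are hard-coded in this variant; tune them to your dataset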
        steps_per_epoch=10000,
        validation_data=val_batches,
        validation_steps=500,  # Set validation stride larger to see more of the data.
        epochs=200,
        callbacks=callbacks,
        verbose=1)

    # Plot the training data collected
    plot_training(history, args)
Example #5
def train(args, train_list, val_list, u_model, net_input_shape):
    # Compile the loaded model
    model = compile_model(args=args, net_input_shape=net_input_shape, uncomp_model=u_model)
    # Set the callbacks
    callbacks = get_callbacks(args)

    # Training the network
    history = model.fit_generator(
        generate_train_batches(args.data_root_dir, train_list, net_input_shape, net=args.net,
                               batchSize=args.batch_size, numSlices=args.slices, subSampAmt=args.subsamp,
                               stride=args.stride, shuff=args.shuffle_data, aug_data=args.aug_data),
        max_queue_size=40, workers=4, use_multiprocessing=False,
        steps_per_epoch=10000,
        validation_data=generate_val_batches(args.data_root_dir, val_list, net_input_shape, net=args.net,
                                             batchSize=args.batch_size,  numSlices=args.slices, subSampAmt=0,
                                             stride=20, shuff=args.shuffle_data),
        validation_steps=500, # Set validation stride larger to see more of the data.
        epochs=200,
        callbacks=callbacks,
        verbose=1)

    # Plot the training data collected
    plot_training(history, args)
Example #6
def train(args, train_list, val_list, u_model, net_input_shape):
    # Compile the loaded model
    model = compile_model(args=args,
                          net_input_shape=net_input_shape,
                          uncomp_model=u_model)
    # Set the callbacks
    callbacks = get_callbacks(args)

    # Training the network
    history = model.fit(
        generate_train_batches(args.data_root_dir,
                               train_list,
                               net_input_shape,
                               net=args.net,
                               batch_size=args.batch_size,
                               shuff=args.shuffle_data,
                               aug_data=args.aug_data),
        max_queue_size=40,
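        # workers=0 runs the generator on the main thread (no background worker threads)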
        workers=0,
        use_multiprocessing=False,
        steps_per_epoch=ceil(len(train_list) / args.batch_size),
        validation_data=generate_val_batches(args.data_root_dir,
                                             val_list,
                                             net_input_shape,
                                             net=args.net,
                                             batch_size=args.batch_size,
                                             shuff=args.shuffle_data),
        validation_steps=ceil(len(val_list) / args.batch_size),  # one full pass over the validation list per epoch
        epochs=50,
        callbacks=callbacks,
        verbose=1)

    # Plot the training data collected
    plot_training(history, args)