示例#1
0
def train_and_predict():
    """Build, compile and train the DenseUNet model, saving a checkpoint per epoch.

    Configuration is read from the module-level ``args`` object:
    ``model_weight`` (weights file loaded by layer name), ``n_gpus``,
    ``batch_size`` and ``save_path`` (output root; ``model/`` and
    ``history/`` sub-directories are created beneath it).

    Despite the name, no prediction step is performed here; the function
    only trains and returns ``None``.
    """
    print('-' * 30)
    print('Creating and compiling model...')
    print('-' * 30)

    model = DenseUNet(reduction=0.5)
    # by_name=True: only layers whose names match the weight file are loaded.
    model.load_weights(args.model_weight, by_name=True)

    if args.n_gpus > 1:
        print("Using %d GPUs" % (args.n_gpus))
        # Floor division keeps the per-GPU mini batch an integer on Python 3
        # as well (plain "/" would yield a float there).
        model = make_parallel(model,
                              args.n_gpus,
                              mini_batch=max(args.batch_size // args.n_gpus, 1))
    sgd = SGD(lr=1e-3, momentum=0.9, nesterov=True)
    model.compile(optimizer=sgd, loss=[weighted_crossentropy_2ddense])

    (trainidx, img_list, tumor_list, tumorlines, liverlines, tumoridx,
     liveridx, minindex_list, maxindex_list) = load_fast_files(args)

    print('-' * 30)
    print('Fitting model......')
    print('-' * 30)

    # Create each output directory independently so that a partially
    # existing layout (e.g. history/ present but model/ missing) or a missing
    # parent of save_path does not make the run crash.
    model_dir = os.path.join(args.save_path, 'model')
    history_dir = os.path.join(args.save_path, 'history')
    for directory in (model_dir, history_dir):
        if not os.path.isdir(directory):
            os.makedirs(directory)

    # Drop loss logs left over from a previous run.
    for stale_name in ('lossbatch.txt', 'lossepoch.txt'):
        stale_path = os.path.join(history_dir, stale_name)
        if os.path.exists(stale_path):
            os.remove(stale_path)

    # save_best_only=False with period=1 writes the full model after every
    # epoch; `monitor`/`mode` only affect the value shown in the file name.
    model_checkpoint = ModelCheckpoint(
        os.path.join(model_dir, 'weights.{epoch:02d}-{loss:.2f}.hdf5'),
        monitor='loss',
        verbose=1,
        save_best_only=False,
        save_weights_only=False,
        mode='min',
        period=1)

    model.fit_generator(
        generate_arrays_from_file(args.batch_size, trainidx, img_list,
                                  tumor_list, tumorlines, liverlines, tumoridx,
                                  liveridx, minindex_list, maxindex_list),
        # NOTE(review): an earlier comment here claimed this was "changed to
        # 1", but the value in effect is 20 -- confirm the intended setting.
        steps_per_epoch=20,
        epochs=50000,
        verbose=2,
        callbacks=[model_checkpoint],
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False)

    print('Finised Training .......')
def train_and_predict():
    """Build a multi-GPU DenseUNet, compile it and train it for 600 epochs.

    Configuration is read from the module-level ``args`` object:
    ``model_weight`` (weights file loaded by layer name), ``b`` (global
    batch size) and ``save_path`` (output root; ``model/`` and ``history/``
    sub-directories are created beneath it).

    Despite the name, no prediction step is performed here; the function
    only trains and returns ``None``.
    """
    print('-' * 30)
    print('Creating and compiling model...')
    print('-' * 30)

    model = DenseUNet(reduction=0.5, args=args)
    # by_name=True: only layers whose names match the weight file are loaded.
    model.load_weights(args.model_weight, by_name=True)
    # One replica per 6 samples of the global batch; floor division is
    # equivalent to the former int(args.b / 6) for positive batch sizes.
    model = make_parallel(model, args.b // 6, mini_batch=6)
    sgd = SGD(lr=1e-3, momentum=0.9, nesterov=True)
    model.compile(optimizer=sgd,
                  loss=[weighted_crossentropy_2ddense],
                  metrics=[dice_liver, dice_lesion])

    (trainidx, img_list, tumor_list, tumorlines, liverlines, tumoridx,
     liveridx, minindex_list, maxindex_list) = load_fast_files(args)

    print('-' * 30)
    print('Fitting model......')
    print('-' * 30)

    # Create each output directory independently so that a partially
    # existing layout (e.g. history/ present but model/ missing) or a missing
    # parent of save_path does not make the run crash.
    model_dir = os.path.join(args.save_path, 'model')
    history_dir = os.path.join(args.save_path, 'history')
    for directory in (model_dir, history_dir):
        if not os.path.isdir(directory):
            os.makedirs(directory)

    # Drop loss logs left over from a previous run.
    for stale_name in ('lossbatch.txt', 'lossepoch.txt'):
        stale_path = os.path.join(history_dir, stale_name)
        if os.path.exists(stale_path):
            os.remove(stale_path)

    # save_best_only=False with period=1 writes the full model after every
    # epoch; `monitor`/`mode` only affect the value shown in the file name.
    model_checkpoint = ModelCheckpoint(
        os.path.join(model_dir, 'weights.{epoch:02d}-{loss:.2f}.hdf5'),
        monitor='loss',
        verbose=1,
        save_best_only=False,
        save_weights_only=False,
        mode='min',
        period=1)

    # steps_per_epoch must be a whole number of batches; "/" produces a float
    # on Python 3. Clamp to 1 so an oversized batch never yields zero steps.
    # NOTE(review): 27386 is presumably the number of training samples --
    # confirm against the dataset.
    steps = max(27386 // args.b, 1)

    # NOTE(review): workers=3 with use_multiprocessing=True feeds from a
    # plain generator; confirm duplicated batches across workers are
    # acceptable for this data pipeline.
    model.fit_generator(
        generate_arrays_from_file(args.b, trainidx, img_list, tumor_list,
                                  tumorlines, liverlines, tumoridx, liveridx,
                                  minindex_list, maxindex_list),
        steps_per_epoch=steps,
        epochs=600,
        verbose=1,
        callbacks=[model_checkpoint],
        max_queue_size=10,
        workers=3,
        use_multiprocessing=True)

    print('Finised Training .......')
示例#3
0
def train_and_predict():
    """Train a multi-GPU DenseUNet with periodic validation and TensorBoard logs.

    Configuration is read from the module-level ``args`` object:
    ``model_weight`` (weights file loaded by layer name), ``b`` (global
    batch size) and ``save_path`` (output root; ``model/`` and ``history/``
    sub-directories are created beneath it). TensorBoard events are written
    to ``./logs``.

    Despite the name, no prediction step is performed here; the function
    only trains and returns ``None``.
    """
    print('-' * 30)
    print('Creating and compiling model...')
    print('-' * 30)

    model = DenseUNet(reduction=0.5, args=args)
    # by_name=True: only layers whose names match the weight file are loaded.
    model.load_weights(args.model_weight, by_name=True)
    # One replica per 10 samples of the global batch. Floor division keeps
    # the replica count an integer on Python 3 ("/" would yield a float).
    model = make_parallel(model, args.b // 10, mini_batch=10)
    sgd = SGD(lr=1e-3, momentum=0.9, nesterov=True)
    model.compile(optimizer=sgd,
                  loss=[weighted_crossentropy_2ddense],
                  metrics=['accuracy'])

    # Load training data.
    (trainidx, img_list, tumor_list, tumorlines, liverlines, tumoridx,
     liveridx, minindex_list, maxindex_list) = load_fast_files(args)

    # Load validation data.
    (vtrainidx, vimg_list, vtumor_list, vtumorlines, vliverlines, vtumoridx,
     vliveridx, vminindex_list, vmaxindex_list) = val_load_fast_files(args)

    print('-' * 30)
    print('Fitting model......')
    print('-' * 30)

    # Create each output directory independently so that a partially
    # existing layout (e.g. history/ present but model/ missing) or a missing
    # parent of save_path does not make the run crash.
    model_dir = os.path.join(args.save_path, 'model')
    history_dir = os.path.join(args.save_path, 'history')
    for directory in (model_dir, history_dir):
        if not os.path.isdir(directory):
            os.makedirs(directory)

    # Drop loss logs left over from a previous run.
    for stale_name in ('lossbatch.txt', 'lossepoch.txt'):
        stale_path = os.path.join(history_dir, stale_name)
        if os.path.exists(stale_path):
            os.remove(stale_path)

    # The Keras callback keyword is `log_dir`; the previous `logdir` spelling
    # raises TypeError when the callback is constructed.
    tensor_board = TensorBoard(log_dir='./logs',
                               histogram_freq=0,
                               write_graph=True,
                               write_images=False)

    # save_best_only=False with period=1 writes the full model after every
    # epoch; `monitor`/`mode` only affect the value shown in the file name.
    model_checkpoint = ModelCheckpoint(
        os.path.join(model_dir, 'weights.{epoch:02d}-{loss:.2f}.hdf5'),
        monitor='loss',
        verbose=1,
        save_best_only=False,
        save_weights_only=False,
        mode='min',
        period=1)

    # Step and batch counts must be whole numbers; "/" produces floats on
    # Python 3. Clamp to 1 so extreme batch sizes never yield zero.
    # NOTE(review): 27386 is presumably the number of training samples --
    # confirm against the dataset.
    steps = max(27386 // args.b, 1)
    val_batch_size = max(args.b // 2, 1)

    train_generator = generate_arrays_from_file(
        args.b, trainidx, img_list, tumor_list, tumorlines, liverlines,
        tumoridx, liveridx, minindex_list, maxindex_list)
    val_generator = generate_arrays_from_file(
        val_batch_size, vtrainidx, vimg_list, vtumor_list, vtumorlines,
        vliverlines, vtumoridx, vliveridx, vminindex_list, vmaxindex_list)

    # Validation runs every 5th epoch for 30 batches.
    # NOTE(review): workers=3 with use_multiprocessing=True feeds from a
    # plain generator; confirm duplicated batches across workers are
    # acceptable for this data pipeline.
    model.fit_generator(
        train_generator,
        steps_per_epoch=steps,
        epochs=6000,
        verbose=1,
        callbacks=[model_checkpoint, tensor_board],
        validation_data=val_generator,
        validation_steps=30,
        validation_freq=5,
        max_queue_size=10,
        workers=3,
        use_multiprocessing=True)

    print('Finised Training .......')