Example #1
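These fragments come from a larger segmentation project; project-specific modules such as cm (common helpers), cb (callbacks) and the ITKToolsBinDir path are defined elsewhere in that repository. The standard imports the fragments below rely on would be roughly:

import shutil
import subprocess

from keras import callbacks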
def SegmentDist(File1, File2, OutDr, name):

    cm.mkdir(OutDr + '/Temp')
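    # pxsegmentationdistance (ITKTools) compares the two segmentations and writes
    # the requested out.mhd plus the distance (outDIST.mhd) and edge (outEDGE.mhd)
    # images used below; subprocess.getoutput captures its console output.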
    cmd1 = subprocess.getoutput(ITKToolsBinDir +
                                '/pxsegmentationdistance -in ' + File1 + ' ' +
                                File2 + ' -car true -out ' + OutDr + '/Temp' +
                                '/out.mhd')
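    # pxunaryimageoperator with -ops ABS turns the signed distance image into
    # absolute distances.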
    cmd2 = subprocess.getoutput(ITKToolsBinDir + '/pxunaryimageoperator -in ' +
                                OutDr + '/Temp' +
                                '/outDIST.mhd -ops ABS -out ' + OutDr +
                                '/Temp' + '/outDISTabs.mhd')
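    # pxstatisticsonimage computes arithmetic statistics of the absolute distances,
    # restricted to the edge voxels in outEDGE.mhd; cmd3 holds its printed output.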
    cmd3 = subprocess.getoutput(ITKToolsBinDir + '/pxstatisticsonimage -in ' +
                                OutDr + '/Temp' +
                                '/outDISTabs.mhd -s arithmetic -mask ' +
                                OutDr + '/Temp' + '/outEDGE.mhd')
    shutil.rmtree(OutDr + '/Temp')

    # text_file = open(OutDr + '/surface_distance_' + name + '.txt', 'w')
    # text_file.write(str(cmd1))
    # text_file.close()
    #
    # text_file = open(OutDr + '/surface_distance_' + name + '.txt', 'a')
    # text_file.write(str(cmd2))
    # text_file.close()

    text_file = open(OutDr + '/surface_distance_' + name + '.txt', 'w')
    text_file.write(str(cmd3))
    text_file.close()

    return [cmd1, cmd2, cmd3]
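A minimal, hypothetical call of the function above (the file names and output directory are placeholders):

stats = SegmentDist('/tmp/masksAortaPredicted.mhd',
                    '/tmp/masksAortaGroundTruth.mhd',
                    '/tmp/results', 'aorta')
print(stats[2])  # the surface-distance statistics that were also written to the text file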
  def __init__(self, train_set_x, savePaths, model, perSample):
    super(callbacks.Callback, self).__init__()

    self.train_set_x = train_set_x
    self.savePath = savePaths
    self.model = model
    self.perSample = perSample
    cm.mkdir(self.savePath + 'gradientsPerEpoch/')
def SegmentDist(File1, File2, OutDr):

    cm.mkdir(OutDr + '/Temp')
    cmd1 = subprocess.getoutput(ITKToolsBinDir +
                                '/pxsegmentationdistance -in ' + File1 + ' ' +
                                File2 + ' -car true -out ' + OutDr + '/Temp' +
                                '/out.mhd')
    cmd2 = subprocess.getoutput(ITKToolsBinDir + '/pxunaryimageoperator -in ' +
                                OutDr + '/Temp' +
                                '/outDIST.mhd -ops ABS -out ' + OutDr +
                                '/Temp' + '/outDISTabs.mhd')
    cmd3 = subprocess.getoutput(ITKToolsBinDir + '/pxstatisticsonimage -in ' +
                                OutDr + '/Temp' +
                                '/outDISTabs.mhd -s arithmetic -mask ' +
                                OutDr + '/Temp' + '/outEDGE.mhd')
    shutil.rmtree(OutDr + '/Temp')
    return [cmd1, cmd2, cmd3]
  def on_epoch_end(self, epoch, logs={}):

    absGrad, layer_names, gradPerSample = self.compute_gradients(self.model, self.train_set_x)

    # save overall gradient
    path_overall = self.savePath + 'gradients.csv'
    if epoch == 0:
      self.writeNamesCSV(layer_names, path_overall)
    self.writeCSV(absGrad, path_overall)

    # save gradients of the current epoch
    if self.perSample:
      path_epoch = self.savePath + 'gradientsPerEpoch/' + str(epoch) + '/'  # subdirectory of the folder created in __init__
      path_epoch_csv = path_epoch + 'gradients.csv'
      cm.mkdir(path_epoch)
      self.writeNamesCSV(layer_names, path_epoch_csv)
      for i in range(len(self.train_set_x)):
        self.writeCSV(gradPerSample[i], path_epoch_csv)
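The __init__ and on_epoch_end methods above belong to a gradient-logging Keras callback; judging from the commented-out call cb.recordGradients_Florian(...) in a later example on this page, it would be attached to training roughly like this (a sketch, not the repository's exact code):

grad_cb = cb.recordGradients_Florian(x_train, cm.workingPath.gradient_path,
                                     model, perSample=False)
model.fit(x_train, y_train, batch_size=1, epochs=10, callbacks=[grad_cb])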
Example #5
    filelist.write('\n')
filelist.close()

# Load file:

axis_process = "Axial"  ## axis you want to process
# axis_process = "Sagittal"
# axis_process = "Coronal"

# Training set:
print('-' * 30)
print('Loading files...')
print('-' * 30)

vol_slices = []
cm.mkdir(cm.workingPath.trainingPatchesSet_path)

for nb_file in range(len(originFile_list)):

    out_images = []
    out_masks = []

    # Read information from dcm file:
    originVol, originVol_num, originVolwidth, originVolheight = dp.loadFile(
        originFile_list[nb_file])
    maskAortaVol, maskAortaVol_num, maskAortaVolwidth, maskAortaVolheight = dp.loadFile(
        maskAortaFile_list[nb_file])
    maskPulVol, maskPulVol_num, maskPulVolwidth, maskPulVolheight = dp.loadFile(
        maskPulFile_list[nb_file])

    maskVol = maskAortaVol
def train_and_predict(use_existing):

    cm.mkdir(cm.workingPath.model_path)
    cm.mkdir(cm.workingPath.best_model_path)
    cm.mkdir(cm.workingPath.visual_path)

    # lrate = callbacks.LearningRateScheduler(cb.step_decay)

    print('-' * 30)
    print('Loading and preprocessing train data...')
    print('-' * 30)

    # Scanning training data list:
    originFile_list = sorted(
        glob(cm.workingPath.trainingPatchesSet_path + 'img_*.npy'))
    mask_list = sorted(
        glob(cm.workingPath.trainingPatchesSet_path + 'mask_*.npy'))

    # imgs_train = np.load(cm.workingPath.home_path + 'trainImages3D16.npy')
    # imgs_mask_train = np.load(cm.workingPath.home_path + 'trainMasks3D16.npy')
    # originFile_list = sorted(glob(cm.workingPath.training3DSet_path + 'trainImages_1.npy'))
    # mask_list = sorted(glob(cm.workingPath.training3DSet_path + 'trainMasks_1.npy'))
    #
    full_list = list(zip(originFile_list, mask_list))
    #
    shuffle(full_list)
    originFile_list, mask_list = zip(*full_list)
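    # Shuffling the zipped list keeps every image file paired with its own mask
    # file while randomising the training order.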

    # Scanning validation data list:
    originValFile_list = sorted(
        glob(cm.workingPath.validationSet_path + 'valImages.npy'))
    maskVal_list = sorted(
        glob(cm.workingPath.validationSet_path + 'valMasks.npy'))

    x_val = np.load(originValFile_list[0])
    y_val = np.load(maskVal_list[0])

    # Calculate the total amount of training sets:
    # nb_file = int(len(originFile_list))
    # nb_val_file = int(len(originValFile_list))

    print('_' * 30)
    print('Creating and compiling model...')
    print('_' * 30)

    # Select the model you want to train:
    # model = nw.get_3D_unet()
    # model = nw.get_3D_Eunet()
    # model = DenseUNet_3D.get_3d_denseunet()
    model = UNet_3D.get_3d_unet()
    # model = UNet_3D.get_3d_wnet(opti)
    # model = RSUNet_3D.get_3d_rsunet(opti)
    # model = RSUNet_3D_Gerda.get_3d_rsunet_Gerdafeature(opti)

    # Plot the model:
    modelname = 'model.png'
    plot_model(model,
               show_shapes=True,
               to_file=cm.workingPath.model_path + modelname)
    model.summary()

    # Should we load existing weights?
    if use_existing:
        model.load_weights(cm.workingPath.model_path + 'unet.hdf5')

    print('-' * 30)
    print('Fitting model...')
    print('-' * 30)

    # Callbacks:
    filepath = cm.workingPath.model_path + 'weights.{epoch:02d}-{loss:.5f}-{val_loss:.5f}.hdf5'
    bestfilepath = cm.workingPath.model_path + 'Best_weights.{epoch:02d}-{loss:.5f}-{val_loss:.5f}.hdf5'

    model_checkpoint = callbacks.ModelCheckpoint(filepath,
                                                 monitor='val_loss',
                                                 verbose=0,
                                                 save_best_only=False)
    model_best_checkpoint = callbacks.ModelCheckpoint(bestfilepath,
                                                      monitor='val_loss',
                                                      verbose=0,
                                                      save_best_only=True)
    # history = cb.LossHistory_lr()
    record_history = cb.RecordLossHistory()
    # model_history = callbacks.TensorBoard(log_dir='./logs', histogram_freq=1, write_graph=True, write_images=True,
    #  							embeddings_freq=1, embeddings_layer_names=None, embeddings_metadata= None)
    callbacks_list = [record_history, model_best_checkpoint]

    #model.fit(imgs_train, imgs_mask_train, batch_size=1, epochs=4000, verbose=1, shuffle=True,
    #          validation_split=0.1, callbacks=callbacks_list)

    model_info = model.fit_generator(
        ManyFilesBatchGenerator(originFile_list, mask_list,
                                batch_size=1),  # BATCH_SIZE
        epochs=4000,
        verbose=1,
        shuffle=True,
        validation_data=(x_val, y_val),
        callbacks=callbacks_list)

    print('training finished')
Example #7
    filelist.write('\n')
filelist.close()

# Load file:

axis_process = "Axial"  ## axis you want to process
# axis_process = "Sagittal"
# axis_process = "Coronal"

# Training set:
print('-' * 30)
print('Loading files...')
print('-' * 30)

vol_slices = []
cm.mkdir(cm.workingPath.validationPatchesSet_path)

for nb_file in range(len(originFile_list)):

    out_images = []
    out_masks = []

    # Read information from dcm file:
    originVol, originVol_num, originVolwidth, originVolheight = dp.loadFile(
        originFile_list[nb_file])
    maskAortaVol, maskAortaVol_num, maskAortaVolwidth, maskAortaVolheight = dp.loadFile(
        maskAortaFile_list[nb_file])
    maskPulVol, maskPulVol_num, maskPulVolwidth, maskPulVolheight = dp.loadFile(
        maskPulFile_list[nb_file])
    maskVol = maskAortaVol
def train_and_predict(use_existing):

    cm.mkdir(cm.workingPath.model_path)
    cm.mkdir(cm.workingPath.best_model_path)
    cm.mkdir(cm.workingPath.tensorboard_path)

    # class LossHistory(callbacks.Callback):
    #   def on_train_begin(self, logs={}):
    #     self.losses = []
    #     self.val_losses = []
    #     self.sd = []
    #
    #   def on_epoch_end(self, epoch, logs={}):
    #     self.losses.append(logs.get('loss'))
    #
    #     self.val_losses.append(logs.get('val_loss'))
    #
    #     self.sd.append(step_decay(len(self.losses)))
    #     print('\nlr:', step_decay(len(self.losses)))
    #     lrate_file = list(self.sd)
    #     np.savetxt(cm.workingPath.model_path + 'lrate.txt', lrate_file, newline='\r\n')
    #
    # learning_rate = 0.00001
    #
    # adam = Adam(lr=learning_rate)
    #
    # opti = adam
    #
    # def step_decay(losses):
    #   if len(history.losses)==0:
    #     lrate = 0.00001
    #     return lrate
    #   elif float(2 * np.sqrt(np.array(history.losses[-1]))) < 1.0:
    #     lrate = 0.00001 * 1.0 / (1.0 + 0.1 * len(history.losses))
    #     return lrate
    #   else:
    #     lrate = 0.00001
    #     return lrate
    #
    # history = LossHistory()
    # lrate = callbacks.LearningRateScheduler(step_decay)

    print('-' * 30)
    print('Loading and preprocessing train data...')
    print('-' * 30)

    # Choose which subset you would like to use:

    imgs_train = np.load(cm.workingPath.home_path + 'trainImages3D16.npy')
    imgs_mask_train = np.load(cm.workingPath.home_path + 'trainMasks3D16.npy')
    # imgs_train = np.load(cm.workingPath.home_path + 'trainImages3Dtest.npy')
    # imgs_mask_train = np.load(cm.workingPath.home_path + 'trainMasks3Dtest.npy')

    # x_val = np.load(cm.workingPath.validationSet_path + 'valImages.npy')
    # y_val = np.load(cm.workingPath.validationSet_path + 'valMasks.npy')

    # imgs_train = np.load(cm.workingPath.home_path + 'vesselImages.npy')
    # imgs_mask_train = np.load(cm.workingPath.home_path + 'vesselMasks.npy')
    # x_val = np.load(cm.workingPath.home_path + 'vesselValImages.npy')
    # y_val = np.load(cm.workingPath.home_path + 'vesselValMasks.npy')

    # imgs_train = np.load(cm.workingPath.trainingSet_path + 'trainImages_0000.npy')
    # imgs_mask_train = np.load(cm.workingPath.trainingSet_path + 'trainMasks_0000.npy')

    print('_' * 30)
    print('Creating and compiling model...')
    print('_' * 30)

    # model = DenseUNet_3D.get_3d_denseunet()
    # model = CRFRNN.get_3d_crfrnn_model_def()
    model = UNet_3D.get_3d_unet_bn()
    # model = UNet_2D.get_2d_unet_crf()
    # model = UNet_3D.get_3d_wnet1()
    # model = RSUNet_3D.get_3d_rsunet()
    # model = crfrnn_model.get_3d_crfrnn_model_def()
    # model = CNN.get_3d_cnn()

    modelname = 'model.png'
    plot_model(model,
               show_shapes=True,
               to_file=cm.workingPath.model_path + modelname)
    model.summary()

    # Callbacks:
    filepath = cm.workingPath.model_path + 'weights.{epoch:02d}-{loss:.5f}-{val_loss:.5f}.hdf5'
    bestfilepath = cm.workingPath.model_path + 'Best_weights.{epoch:02d}-{loss:.5f}-{val_loss:.5f}.hdf5'

    model_checkpoint = callbacks.ModelCheckpoint(filepath,
                                                 monitor='val_loss',
                                                 verbose=0,
                                                 save_best_only=False)
    model_best_checkpoint = callbacks.ModelCheckpoint(bestfilepath,
                                                      monitor='val_loss',
                                                      verbose=0,
                                                      save_best_only=True)

    record_history = cb.RecordLossHistory()
    # record_gradients = cb.recordGradients_Florian(x_val, cm.workingPath.gradient_path, model, False)

    # tbCallBack = callbacks.TensorBoard(log_dir=cm.workingPath.tensorboard_path, histogram_freq=1, write_graph=False,
    #                                    write_images=False, write_grads=False, batch_size=1)

    callbacks_list = [record_history, model_best_checkpoint]

    # Should we load existing weights?
    # Set argument for call to train_and_predict to true at end of script
    if use_existing:
        model.load_weights('./unet.hdf5')

    print('-' * 30)
    print('Fitting model...')
    print('-' * 30)

    print(imgs_train.shape)
    print(imgs_mask_train.shape)

    model.fit(imgs_train,
              imgs_mask_train,
              batch_size=1,
              epochs=400,
              verbose=1,
              shuffle=True,
              validation_split=0.1,
              callbacks=callbacks_list)

    print('training finished')
def train_and_predict(use_existing):
    cm.mkdir(cm.workingPath.model_path)
    cm.mkdir(cm.workingPath.best_model_path)
    print('-' * 30)
    print('Loading and preprocessing train data...')
    print('-' * 30)

    # Choose which subset you would like to use:

    imgs_train = np.load(cm.workingPath.home_path + 'trainImages3D16.npy')
    imgs_mask_train = np.load(cm.workingPath.home_path + 'trainMasks3D16.npy')
    # imgs_train = np.load(cm.workingPath.home_path + 'trainImages3Dtest.npy')
    # imgs_mask_train = np.load(cm.workingPath.home_path + 'trainMasks3Dtest.npy')
    # imgs_train = np.load(cm.workingPath.trainingSet_path + 'trainImages_0000.npy')
    # imgs_mask_train = np.load(cm.workingPath.trainingSet_path + 'trainMasks_0000.npy')

    # imgs_val = np.load(cm.workingPath.validationSet_path + 'valImages3D.npy')
    # imgs_mask_val = np.load(cm.workingPath.validationSet_path + 'valMasks3D.npy')

    # imgs_train = np.load(cm.workingPath.training3DSet_path + 'trainImages_0000.npy')
    # imgs_mask_train = np.load(cm.workingPath.training3DSet_path + 'trainMasks_0000.npy')

    print('_' * 30)
    print('Creating and compiling model...')
    print('_' * 30)

    # model = DenseUNet_3D.get_3d_denseunet()
    # model = LSTM_UNet_3D.time_GRU_unet_1_level()
    model = UNet_3D.get_3d_unet()
    # model = ResNet_3D.get_3d_resnet_34()

    modelname = 'model.png'
    plot_model(model,
               show_shapes=True,
               to_file=cm.workingPath.model_path + modelname)
    model.summary()

    # Callbacks:
    filepath = cm.workingPath.model_path + 'weights.{epoch:02d}-{loss:.5f}.hdf5'
    bestfilepath = cm.workingPath.model_path + 'Best_weights.{epoch:02d}-{loss:.5f}.hdf5'

    model_checkpoint = callbacks.ModelCheckpoint(filepath,
                                                 monitor='val_loss',
                                                 verbose=0,
                                                 save_best_only=False)
    model_best_checkpoint = callbacks.ModelCheckpoint(bestfilepath,
                                                      monitor='val_loss',
                                                      verbose=0,
                                                      save_best_only=True)

    # history = cm.LossHistory_Gerda(cm.workingPath.working_path)
    history = cb.LossHistory()
    # model_history = callbacks.TensorBoard(log_dir='./logs', histogram_freq=1, write_graph=True, write_images=True,
    #  							embeddings_freq=1, embeddings_layer_names=None, embeddings_metadata= None)

    callbacks_list = [history, model_best_checkpoint]

    # Should we load existing weights?
    # Set argument for call to train_and_predict to true at end of script
    if use_existing:
        model.load_weights('./unet.hdf5')

    print('-' * 30)
    print('Fitting model...')
    print('-' * 30)

    model.fit(imgs_train,
              imgs_mask_train,
              batch_size=1,
              epochs=4000,
              verbose=1,
              shuffle=True,
              validation_split=0.1,
              callbacks=callbacks_list)

    print('training finished')
def train_and_predict(use_existing, x_train, x_val, y_train, y_val, cross, pre_lrate):
  cm.mkdir(cm.workingPath.model_path)
  cm.mkdir(cm.workingPath.best_model_path)

  class LossHistory(callbacks.Callback):
    def on_train_begin(self, logs={}):
      self.losses = [1]
      self.val_losses = []
      self.lr = []

    def on_epoch_end(self, epoch, logs={}):
      self.losses.append(logs.get('loss'))
      loss_file = (list(self.losses))
      np.savetxt(cm.workingPath.model_path + 'loss_' + 'Val.%02d.txt' % (cross), loss_file[1:], newline='\r\n')

      self.val_losses.append(logs.get('val_loss'))
      val_loss_file = (list(self.val_losses))
      np.savetxt(cm.workingPath.model_path + 'val_loss_' + 'Val.%02d.txt' % (cross), val_loss_file, newline='\r\n')

      self.lr.append(step_decay(len(self.losses)))
      print('\nLearning rate:', step_decay(len(self.losses)))
      lrate_file = (list(self.lr))
      np.savetxt(cm.workingPath.model_path + 'lrate_' + 'Val.%02d.txt' % (cross), lrate_file, newline='\r\n')

  if cross == 0:
    learning_rate = 0.0001
    decay_rate = 5e-6
    momentum = 0.9
  else:
    learning_rate = pre_lrate
    decay_rate = 5e-6
    momentum = 0.9

  # sgd = SGD(lr=learning_rate, momentum=momentum, decay=decay_rate, nesterov=False)
  adam = Adam(lr=learning_rate, decay=decay_rate)

  opti = adam

  def step_decay(losses):
    if float(2 * np.sqrt(np.array(history.losses[-1]))) < 100:
      if cross == 0:
        lrate = 0.0001 * 1 / (1 + 0.1 * len(history.losses))
        # lrate = 0.0001
        momentum = 0.8
        decay_rate = 2e-6
      else:
        lrate = pre_lrate * 1 / (1 + 0.1 * len(history.losses))
        # lrate = 0.0001
        momentum = 0.8
        decay_rate = 2e-6
      return lrate
    else:
      if cross == 0:
        lrate = 0.0001
      else:
        lrate = pre_lrate
      return lrate

  history = LossHistory()
  lrate = LearningRateScheduler(step_decay)
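  # A quick check of the schedule above: below the threshold the rate is the base
  # rate divided by (1 + 0.1 * n), where n = len(history.losses); with a base of
  # 0.0001 that gives roughly 9.1e-05 at n = 1, 5e-05 at n = 10 and 1.7e-05 at n = 50.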

  print('_' * 30)
  print('Creating and compiling model...')
  print('_' * 30)

  # model = nw.get_simple_unet(opti)
  # model = nw.get_shallow_unet(sgd)
  model = nw.get_unet(opti)
  # model = nw.get_dropout_unet()
  # model = nw.get_unet_less_feature()
  # model = nw.get_unet_dilated_conv_4()
  # model = nw.get_unet_dilated_conv_7()
  # model = nw.get_2D_Deeply_supervised_network()

  modelname = 'model.png'
  # plot_model(model, show_shapes=True, to_file=cm.workingPath.model_path + modelname)
  model.summary()

  # Callbacks:

  filepath = cm.workingPath.model_path + 'Val.%02d_' % (cross) + 'weights.{epoch:02d}-{loss:.5f}.hdf5'
  bestfilepath = cm.workingPath.best_model_path + 'Val.%02d_' % (cross) + 'Best_weights.{epoch:02d}-{loss:.5f}.hdf5'
  unet_hdf5_path = cm.workingPath.working_path + 'unet.hdf5'

  model_checkpoint = callbacks.ModelCheckpoint(filepath, monitor='loss', verbose=0, save_best_only=True)
  model_best_checkpoint = callbacks.ModelCheckpoint(bestfilepath, monitor='val_loss', verbose=0, save_best_only=True)
  model_best_unet_hdf5 = callbacks.ModelCheckpoint(unet_hdf5_path, monitor='val_loss', verbose=0, save_best_only=True)

  history = LossHistory()

  callbacks_list = [history, lrate, model_checkpoint, model_best_checkpoint, model_best_unet_hdf5]

  # Should we load existing weights?
  # Set argument for call to train_and_predict to true at end of script
  if use_existing:
    model.load_weights('./unet.hdf5')

  print('-' * 30)
  print('Fitting model...')
  print('-' * 30)

  model.fit(x_train, y_train, batch_size=4, epochs=50, verbose=1, shuffle=True,
            validation_data=(x_val, y_val), callbacks=callbacks_list)

  print('training finished')
def train_and_predict(use_existing):

    cm.mkdir(cm.workingPath.model_path)
    cm.mkdir(cm.workingPath.best_model_path)

    class LossHistory(callbacks.Callback):
        def on_train_begin(self, logs={}):
            self.losses = []
            self.val_losses = []
            self.sd = []

        def on_epoch_end(self, epoch, logs={}):
            self.losses.append(logs.get('loss'))

            self.val_losses.append(logs.get('val_loss'))

            self.sd.append(step_decay(len(self.losses)))
            print('\nlr:', step_decay(len(self.losses)))
            lrate_file = list(self.sd)
            np.savetxt(cm.workingPath.model_path + 'lrate.txt',
                       lrate_file,
                       newline='\r\n')

    learning_rate = 0.00001

    adam = Adam(lr=learning_rate)

    opti = adam

    def step_decay(losses):
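        # Note: 2 * sqrt(loss) < 1.0 is equivalent to loss < 0.25, so the
        # 1 / (1 + 0.1 * n) decay only takes effect once the loss has dropped
        # below 0.25; otherwise the base rate of 1e-5 is kept.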
        if len(history.losses) == 0:
            lrate = 0.00001
            return lrate
        elif float(2 * np.sqrt(np.array(history.losses[-1]))) < 1.0:
            lrate = 0.00001 * 1.0 / (1.0 + 0.1 * len(history.losses))
            return lrate
        else:
            lrate = 0.00001
            return lrate

    history = LossHistory()
    lrate = callbacks.LearningRateScheduler(step_decay)

    print('-' * 30)
    print('Loading and preprocessing train data...')
    print('-' * 30)

    # Choose which subset you would like to use:

    imgs_train = np.load(cm.workingPath.home_path + 'trainImages3D16.npy')
    imgs_mask_train = np.load(cm.workingPath.home_path + 'trainMasks3D16.npy')

    # imgs_train = np.load(cm.workingPath.home_path + 'vesselImages.npy')
    # imgs_mask_train = np.load(cm.workingPath.home_path + 'vesselMasks.npy')

    # imgs_train = np.load(cm.workingPath.trainingPatchesSet_path + 'img_0000.npy')
    # imgs_mask_train = np.load(cm.workingPath.trainingPatchesSet_path + 'mask_0000.npy')
    # x_val = np.load(cm.workingPath.home_path + 'vesselValImages.npy')
    # y_val = np.load(cm.workingPath.home_path + 'vesselValMasks.npy')

    # x_val = np.load(cm.workingPath.validationSet_path + 'valImages.npy')
    # y_val = np.load(cm.workingPath.validationSet_path + 'valMasks.npy')

    print('_' * 30)
    print('Creating and compiling model...')
    print('_' * 30)

    # model = nw.get_3D_unet()
    # model = nw.get_3D_Eunet()
    # model = DenseUNet_3D.get_3d_denseunet()
    model = UNet_3D.get_3d_unet()
    # model = UNet_3D.get_3d_wnet(opti)
    # model = RSUNet_3D.get_3d_rsunet(opti)

    modelname = 'model.png'
    plot_model(model,
               show_shapes=True,
               to_file=cm.workingPath.model_path + modelname)
    model.summary()

    # Callbacks:
    filepath = cm.workingPath.model_path + 'weights.{epoch:02d}-{loss:.5f}-{val_loss:.5f}.hdf5'
    bestfilepath = cm.workingPath.model_path + 'Best_weights.{epoch:02d}-{loss:.5f}-{val_loss:.5f}.hdf5'

    model_checkpoint = callbacks.ModelCheckpoint(filepath,
                                                 monitor='val_loss',
                                                 verbose=0,
                                                 save_best_only=False)
    model_best_checkpoint = callbacks.ModelCheckpoint(bestfilepath,
                                                      monitor='val_loss',
                                                      verbose=0,
                                                      save_best_only=True)

    # history = cm.LossHistory_Gerda(cm.workingPath.working_path)
    record_history = cb.RecordLossHistory()
    # model_history = callbacks.TensorBoard(log_dir='./logs', histogram_freq=1, write_graph=True, write_images=True,
    #  							embeddings_freq=1, embeddings_layer_names=None, embeddings_metadata= None)

    callbacks_list = [record_history, model_best_checkpoint]

    # Should we load existing weights?
    # Set argument for call to train_and_predict to true at end of script
    if use_existing:
        model.load_weights('./unet.hdf5')

    print('-' * 30)
    print('Fitting model...')
    print('-' * 30)

    print(imgs_train.shape)
    print(imgs_mask_train.shape)

    model.fit(imgs_train,
              imgs_mask_train,
              batch_size=1,
              epochs=4000,
              verbose=1,
              shuffle=True,
              validation_split=0.1,
              callbacks=callbacks_list)

    print('training finished')
def train_and_predict(use_existing):

    cm.mkdir(cm.workingPath.model_path)
    cm.mkdir(cm.workingPath.best_model_path)
    cm.mkdir(cm.workingPath.visual_path)

    class LossHistory(callbacks.Callback):
        def on_train_begin(self, logs={}):
            self.losses = [1, 1]
            self.val_losses = []
            self.sd = []

        def on_epoch_end(self, epoch, logs={}):
            self.losses.append(logs.get('loss'))
            loss_file = list(self.losses)
            np.savetxt(cm.workingPath.model_path + 'loss.txt',
                       loss_file,
                       newline='\r\n')

            self.val_losses.append(logs.get('val_loss'))
            val_loss_file = list(self.val_losses)
            np.savetxt(cm.workingPath.model_path + 'val_loss.txt',
                       val_loss_file,
                       newline='\r\n')

            self.sd.append(step_decay(len(self.losses)))
            print('\nlr:', step_decay(len(self.losses)))
            lrate_file = list(self.sd)
            np.savetxt(cm.workingPath.model_path + 'lrate.txt',
                       lrate_file,
                       newline='\r\n')

    learning_rate = 0.0001
    decay_rate = 5e-6
    momentum = 0.9

    # sgd = SGD(lr=learning_rate, momentum=momentum, decay=decay_rate, nesterov=False)
    adam = Adam(lr=learning_rate, decay=decay_rate)

    opti = adam

    def step_decay(losses):
        if float(2 * np.sqrt(np.array(history.losses[-1]))) < 1.0:
            lrate = 0.0001 * 1 / (1 + 0.1 * len(history.losses))
            # lrate = 0.0001
            momentum = 0.8
            decay_rate = 2e-6
            return lrate
        else:
            lrate = 0.0001
            return lrate

    history = LossHistory()
    lrate = LearningRateScheduler(step_decay)

    print('-' * 30)
    print('Loading and preprocessing train data...')
    print('-' * 30)

    # Choose which subset you would like to use:
    i = 0

    imgs_train = np.load(cm.workingPath.trainingSet_path +
                         'trainImages_%04d.npy' % (i)).astype(np.float32)
    imgs_mask_train = np.load(cm.workingPath.trainingSet_path +
                              'trainMasks_%04d.npy' % (i)).astype(np.float32)

    # imgs_test = np.load(cm.workingPath.trainingSet_path + 'testImages_%04d.npy'%(i)).astype(np.float32)
    # imgs_mask_test_true = np.load(cm.workingPath.trainingSet_path + 'testMasks_%04d.npy'%(i)).astype(np.float32)

    # Mean for data centering:
    # mean= np.mean(imgs_train)
    # Std for data normalization:
    # std = np.std(imgs_train)

    # imgs_train -= mean
    # imgs_train /= std

    print('_' * 30)
    print('Creating and compiling model...')
    print('_' * 30)

    model = UNet_2D.get_unet(opti)
    # model = nw.get_shallow_unet(sgd)
    # model = nw.get_unet(opti)
    # model = nw.get_dropout_unet()
    # model = nw.get_unet_less_feature()
    # model = nw.get_unet_dilated_conv_4()
    # model = nw.get_unet_dilated_conv_7()
    # model = nw.get_2D_Deeply_supervised_network()

    modelname = 'model.png'
    plot_model(model,
               show_shapes=True,
               to_file=cm.workingPath.model_path + modelname)
    model.summary()
    # config = model.get_config()
    # print(config)

    # Callbacks:

    filepath = cm.workingPath.model_path + 'weights.{epoch:02d}-{loss:.5f}.hdf5'
    bestfilepath = cm.workingPath.model_path + 'Best_weights.{epoch:02d}-{loss:.5f}.hdf5'

    model_checkpoint = callbacks.ModelCheckpoint(filepath,
                                                 monitor='loss',
                                                 verbose=0,
                                                 save_best_only=True)
    model_best_checkpoint = callbacks.ModelCheckpoint(bestfilepath,
                                                      monitor='val_loss',
                                                      verbose=0,
                                                      save_best_only=True)

    # history = cm.LossHistory_Gerda(cm.workingPath.working_path)
    history = LossHistory()
    # model_history = callbacks.TensorBoard(log_dir='./logs', histogram_freq=1, write_graph=True, write_images=True,
    #  							embeddings_freq=1, embeddings_layer_names=None, embeddings_metadata= None)
    # gradients = cb.recordGradients_Florian(imgs_train, cm.workingPath.model_path, model, True)
    visual = vs.visualize_activation_in_layer(model, imgs_train[100])

    callbacks_list = [
        history, lrate, visual, model_checkpoint, model_best_checkpoint
    ]

    # Should we load existing weights?
    # Set argument for call to train_and_predict to true at end of script
    if use_existing:
        model.load_weights('./unet.hdf5')

    print('-' * 30)
    print('Fitting model...')
    print('-' * 30)

    temp_weights = model.get_weights()
    vs.plot_conv_weights(temp_weights[2], cm.workingPath.visual_path, 'conv_1')

    model.fit(imgs_train,
              imgs_mask_train,
              batch_size=8,
              epochs=1,
              verbose=1,
              shuffle=True,
              validation_split=0.1,
              callbacks=callbacks_list)

    print('training finished')
Example #13
def model_test(use_existing):

    cm.mkdir(cm.workingPath.testingResults_path)
    cm.mkdir(cm.workingPath.testingNPY_path)

    # Loading test data:

    filename = cm.filename
    modelname = cm.modellist[0]

    # Single CT:
    originFile_list = sorted(
        glob(cm.workingPath.originTestingSet_path + filename))
    maskAortaFile_list = sorted(
        glob(cm.workingPath.aortaTestingSet_path + filename))
    maskPulFile_list = sorted(
        glob(cm.workingPath.pulTestingSet_path + filename))

    # Zahra CTs:
    # originFile_list = sorted(glob(cm.workingPath.originTestingSet_path + "vol*.dcm"))
    # maskAortaFile_list = sorted(glob(cm.workingPath.aortaTestingSet_path + "vol*.dcm"))
    # maskPulFile_list = sorted(glob(cm.workingPath.pulTestingSet_path + "vol*.dcm"))

    # Lidia CTs:
    # originFile_list = sorted(glob(cm.workingPath.originLidiaTestingSet_path + "vol*.dcm"))[61:]
    # maskAortaFile_list = sorted(glob(cm.workingPath.originLidiaTestingSet_path + "vol*.dcm"))[61:]
    # maskPulFile_list = sorted(glob(cm.workingPath.originLidiaTestingSet_path + "vol*.dcm"))[61:]

    # Abnormal CTs:
    # originFile_list = sorted(glob(cm.workingPath.originAbnormalTestingSet_path + "vol126*.dcm"))
    # maskAortaFile_list = sorted(glob(cm.workingPath.originAbnormalTestingSet_path + "vol126*.dcm"))
    # maskPulFile_list = sorted(glob(cm.workingPath.originAbnormalTestingSet_path + "vol126*.dcm"))

    for i in range(len(originFile_list)):

        # Show runtime:
        starttime = datetime.datetime.now()

        vol_slices = []
        out_test_images = []
        out_test_masks = []

        current_file = originFile_list[i].split('/')[-1]
        current_dir = cm.workingPath.testingResults_path + str(
            current_file[:-17])
        cm.mkdir(current_dir)
        cm.mkdir(current_dir + '/Plots/')
        cm.mkdir(current_dir + '/Surface_Distance/Aorta/')
        cm.mkdir(current_dir + '/Surface_Distance/Pul/')
        cm.mkdir(current_dir + '/DICOM/')
        cm.mkdir(current_dir + '/mhd/')

        stdout_backup = sys.stdout
        log_file = open(current_dir + "/logs.txt", "w")
        sys.stdout = log_file

        print('-' * 30)
        print('Loading test data %04d/%04d...' % (i + 1, len(originFile_list)))

        originVol, originVol_num, originVolwidth, originVolheight = dp.loadFile(
            originFile_list[i])
        maskAortaVol, maskAortaVol_num, maskAortaVolwidth, maskAortaVolheight = dp.loadFile(
            maskAortaFile_list[i])
        maskPulVol, maskPulVol_num, maskPulVolwidth, maskPulVolheight = dp.loadFile(
            maskPulFile_list[i])
        maskVol = maskAortaVol

        for j in range(len(maskAortaVol)):
            maskAortaVol[j] = np.where(maskAortaVol[j] != 0, 1, 0)
        for j in range(len(maskPulVol)):
            maskPulVol[j] = np.where(maskPulVol[j] != 0, 2, 0)

        maskVol = maskVol + maskPulVol
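        # Aorta voxels are labelled 1 and pulmonary voxels 2, so voxels present in
        # both masks sum to 3 and are zeroed by the "> 2" check below before
        # everything is collapsed into a single vessel class.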

        for j in range(len(maskVol)):
            maskVol[j] = np.where(maskVol[j] > 2, 0, maskVol[j])
            # maskVol[j] = np.where(maskVol[j] != 0, 0, maskVol[j])

        # Make the Vessel class
        for j in range(len(maskVol)):
            maskVol[j] = np.where(maskVol[j] != 0, 1, 0)

        for i in range(originVol.shape[0]):
            img = originVol[i, :, :]

            out_test_images.append(img)
        for i in range(maskVol.shape[0]):
            img = maskVol[i, :, :]

            out_test_masks.append(img)

        vol_slices.append(originVol.shape[0])

        maskAortaVol = None
        maskPulVol = None
        maskVol = None
        originVol = None

        nb_class = 2
        outmasks_onehot = to_categorical(out_test_masks, num_classes=nb_class)
        final_test_images = np.ndarray([sum(vol_slices), 512, 512, 1],
                                       dtype=np.int16)
        final_test_masks = np.ndarray([sum(vol_slices), 512, 512, nb_class],
                                      dtype=np.int8)

        for i in range(len(out_test_images)):
            final_test_images[i, :, :, 0] = out_test_images[i]
            final_test_masks[i, :, :, :] = outmasks_onehot[i]

        outmasks_onehot = None
        out_test_masks = None
        out_test_images = None

        row = cm.img_rows_3d
        col = cm.img_cols_3d
        row_1 = int((512 - row) / 2)
        row_2 = int(512 - (512 - row) / 2)
        col_1 = int((512 - col) / 2)
        col_2 = int(512 - (512 - col) / 2)
        slices = cm.slices_3d
        gaps = cm.gaps_3d
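        # The patch size is row x col x slices with stride gaps; below, every
        # 512 x 512 slice is centre-cropped down to row x col.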

        final_images_crop = final_test_images[:, row_1:row_2, col_1:col_2, :]
        final_masks_crop = final_test_masks[:, row_1:row_2, col_1:col_2, :]

        sitk.WriteImage(
            sitk.GetImageFromArray(np.uint16(final_test_masks[:, :, :, 1])),
            current_dir + '/DICOM/masksAortaGroundTruth.dcm')

        sitk.WriteImage(
            sitk.GetImageFromArray(np.uint16(final_test_masks[:, :, :, 1])),
            current_dir + '/mhd/masksAortaGroundTruth.mhd')

        # clear the masks for the final step:
        final_test_masks = np.where(final_test_masks == 0, 0, 0)

        num_patches = int((sum(vol_slices) - slices) / gaps)

        test_image = np.ndarray([1, slices, row, col, 1], dtype=np.int16)

        # Use zeros here because predictions are accumulated into this buffer below:
        predicted_mask_volume = np.zeros(
            [sum(vol_slices), row, col, nb_class], dtype=np.float32)

        # model = DenseUNet_3D.get_3d_denseunet()
        # model = UNet_3D.get_3d_unet_bn()
        # model = RSUNet_3D.get_3d_rsunet(opti)
        # model = UNet_3D.get_3d_wnet(opti)
        # model = UNet_3D.get_3d_unet()
        # model = UNet_3D.get_3d_unet()
        model = CNN_3D.get_3d_cnn()
        # model = RSUNet_3D_Gerda.get_3d_rsunet_Gerdafeature(opti)

        using_start_end = 1
        start_slice = cm.start_slice
        end_slice = -1

        if use_existing:
            model.load_weights(modelname)

        for i in range(num_patches):
            count1 = i * gaps
            count2 = i * gaps + slices
            test_image[0] = final_images_crop[count1:count2]

            predicted_mask = model.predict(test_image)

            if i == int(num_patches * 0.63):
                vs.visualize_activation_in_layer_one_plot_add_weights(
                    model, test_image, current_dir)
            else:
                pass

            predicted_mask_volume[count1:count2] += predicted_mask[
                0, :, :, :, :]

        t = len(predicted_mask_volume)
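        # Each interior slice is covered by slices / gaps overlapping windows, so
        # the accumulated probabilities are divided by that count; the first and
        # last few slices are covered by fewer windows and get smaller divisors.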
        for i in range(0, slices, gaps):
            predicted_mask_volume[i:(i + gaps)] = \
                predicted_mask_volume[i:(i + gaps)] / (i / gaps + 1)

        for i in range(0, slices, gaps):
            predicted_mask_volume[(t - i - gaps):(t - i)] = \
                predicted_mask_volume[(t - i - gaps):(t - i)] / (i / gaps + 1)

        for i in range(slices, (len(predicted_mask_volume) - slices)):
            predicted_mask_volume[i] = predicted_mask_volume[i] / (slices / gaps)

        np.save(cm.workingPath.testingNPY_path + 'testImages.npy',
                final_images_crop)
        np.save(cm.workingPath.testingNPY_path + 'testMasks.npy',
                final_masks_crop)
        np.save(cm.workingPath.testingNPY_path + 'masksTestPredicted.npy',
                predicted_mask_volume)

        final_images_crop = None
        final_masks_crop = None
        predicted_mask_volume = None

        imgs_origin = np.load(cm.workingPath.testingNPY_path +
                              'testImages.npy').astype(np.int16)
        imgs_true = np.load(cm.workingPath.testingNPY_path +
                            'testMasks.npy').astype(np.int8)
        imgs_predict = np.load(cm.workingPath.testingNPY_path +
                               'masksTestPredicted.npy').astype(np.float32)
        imgs_predict_threshold = np.load(cm.workingPath.testingNPY_path +
                                         'masksTestPredicted.npy').astype(
                                             np.float32)

        ########## ROC curve aorta

        actual = imgs_true[:, :, :, 1].reshape(-1)
        predictions = imgs_predict[:, :, :, 1].reshape(-1)
        # predictions = np.where(predictions < (0.7), 0, 1)

        false_positive_rate_aorta, true_positive_rate_aorta, thresholds_aorta = roc_curve(
            actual, predictions, pos_label=1)
        roc_auc_aorta = auc(false_positive_rate_aorta,
                            true_positive_rate_aorta)
        plt.figure(1, figsize=(6, 6))
        plt.figure(1)
        plt.title('ROC of Aorta')
        plt.plot(false_positive_rate_aorta, true_positive_rate_aorta, 'b',
                 label='AUC = %0.2f' % roc_auc_aorta)
        plt.legend(loc='lower right')
        plt.plot([0, 1], [0, 1], 'r--')
        plt.xlim([-0.0, 1.0])
        plt.ylim([-0.0, 1.0])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        # plt.show()
        saveName = '/Plots/ROC_Aorta_curve.png'
        plt.savefig(current_dir + saveName)
        plt.close()

        false_positive_rate_aorta = None
        true_positive_rate_aorta = None

        imgs_predict_threshold = np.where(imgs_predict_threshold < 0.3, 0, 1)

        if using_start_end == 1:
            aortaMean = lf.dice_coef_np(
                imgs_predict_threshold[start_slice:end_slice, :, :, 1],
                imgs_true[start_slice:end_slice, :, :, 1])
        else:
            aortaMean = lf.dice_coef_np(imgs_predict_threshold[:, :, :, 1],
                                        imgs_true[:, :, :, 1])

        np.savetxt(current_dir + '/Plots/AortaDicemean.txt',
                   np.array(aortaMean).reshape(1, ),
                   fmt='%.5f')

        print('Model file:', modelname)
        print('-' * 30)
        print('Aorta Dice Coeff', aortaMean)
        print('-' * 30)

        # Draw the subplots of figures:

        color1 = 'gray'  # ***
        color2 = 'viridis'  # ******

        transparent1 = 1.0
        transparent2 = 0.5

        # Slice parameters:

        #################################### Aorta
        # Automatically:

        steps = 40
        slice = range(0, len(imgs_origin), steps)
        plt_row = 3
        plt_col = int(len(imgs_origin) / steps)

        plt.figure(3, figsize=(25, 12))
        plt.figure(3)

        for i in slice:
            if i == 0:
                plt_num = int(i / steps) + 1
            else:
                plt_num = int(i / steps)

            if plt_num <= plt_col:

                ax1 = plt.subplot(plt_row, plt_col, plt_num)
                title = 'slice=' + str(i)
                plt.title(title)
                ax1.imshow(imgs_origin[i, :, :, 0],
                           cmap=color1,
                           alpha=transparent1)
                ax1.imshow(imgs_true[i, :, :, 1],
                           cmap=color2,
                           alpha=transparent2)

                ax2 = plt.subplot(plt_row, plt_col, plt_num + plt_col)
                title = 'slice=' + str(i)
                plt.title(title)
                ax2.imshow(imgs_origin[i, :, :, 0],
                           cmap=color1,
                           alpha=transparent1)
                ax2.imshow(imgs_predict[i, :, :, 1],
                           cmap=color2,
                           alpha=transparent2)

                ax3 = plt.subplot(plt_row, plt_col, plt_num + 2 * plt_col)
                title = 'slice=' + str(i)
                plt.title(title)
                ax3.imshow(imgs_origin[i, :, :, 0],
                           cmap=color1,
                           alpha=transparent1)
                ax3.imshow(imgs_predict_threshold[i, :, :, 1],
                           cmap=color2,
                           alpha=transparent2)
            else:
                pass

        modelname = cm.modellist[0]

        imageName = re.findall(r'\d+\.?\d*', modelname)
        epoch_num = int(imageName[0]) + 1
        accuracy = float(
            np.loadtxt(current_dir + '/Plots/AortaDicemean.txt', float))

        # saveName = 'epoch_' + str(epoch_num) + '_dice_' +str(accuracy) + '.png'
        saveName = '/Plots/epoch_Aorta_%02d_dice_%.3f.png' % (epoch_num - 1,
                                                              accuracy)

        plt.subplots_adjust(left=0.0,
                            bottom=0.05,
                            right=1.0,
                            top=0.95,
                            hspace=0.3,
                            wspace=0.3)
        plt.savefig(current_dir + saveName)
        plt.close()
        # plt.show()

        print('Images saved')
        # Save npy as dcm files:

        final_test_aorta_predicted_threshold = final_test_masks[:, :, :, 1]

        final_test_aorta_predicted_threshold[:, row_1:row_2, col_1:col_2] = \
            imgs_predict_threshold[:, :, :, 1]

        new_imgs_dcm = sitk.GetImageFromArray(
            np.uint16(final_test_images + 4000))
        new_imgs_aorta_predict_dcm = sitk.GetImageFromArray(
            np.uint16(final_test_aorta_predicted_threshold))

        sitk.WriteImage(new_imgs_dcm,
                        current_dir + '/DICOM/imagesPredicted.dcm')
        sitk.WriteImage(new_imgs_aorta_predict_dcm,
                        current_dir + '/DICOM/masksAortaPredicted.dcm')

        sitk.WriteImage(new_imgs_dcm, current_dir + '/mhd/imagesPredicted.mhd')
        sitk.WriteImage(new_imgs_aorta_predict_dcm,
                        current_dir + '/mhd/masksAortaPredicted.mhd')

        # mt.SegmentDist(current_dir + '/mhd/masksAortaPredicted.mhd',current_dir + '/mhd/masksAortaGroundTruth.mhd', current_dir + '/Surface_Distance/Aorta')
        # mt.SegmentDist(current_dir + '/mhd/masksPulPredicted.mhd',current_dir + '/mhd/masksPulGroundTruth.mhd', current_dir + '/Surface_Distance/Pul')

        print('DICOM saved')

        # Clear memory for the next testing sample:

        final_test_aorta_predicted_threshold = None
        final_test_pul_predicted_threshold = None
        imgs_predict_threshold = None
        new_imgs_dcm = None
        new_imgs_aorta_predict_dcm = None
        new_imgs_pul_predict_dcm = None
        final_test_images = None
        final_test_masks = None
        imgs_origin = None
        imgs_predict = None
        imgs_true = None
        predicted_mask = None
        predictions = None

        endtime = datetime.datetime.now()
        print('-' * 30)
        print('running time:', endtime - starttime)

        log_file.close()
        sys.stdout = stdout_backup
Example #14
from keras.applications.inception_v3 import InceptionV3
from keras.layers import Dense, Dropout, GlobalAveragePooling2D
from keras.preprocessing import image
from keras import backend as K
from tqdm import tqdm
import numpy as np
import os
# (cm below is a project-specific helper module whose import is not shown in this fragment)

np.random.seed(111)

print(os.listdir("/hdd2/PythonCodes/Git_clone/seaser/data/"))

input_path = '/hdd2/PythonCodes/Git_clone/seaser/data/'
cats = os.listdir(input_path)
print("Total number of sub-directories found: ", len(cats))

# Store the meta-data in a dataframe for convenience

cm.mkdir(cm.workingPath.model_path)
cm.mkdir(cm.workingPath.best_model_path)

x_train = np.load(cm.workingPath.home_path + 'xtrain.npy')
y_train = np.load(cm.workingPath.home_path + 'ytrain.npy')

# create the base pre-trained model
base_model = InceptionV3(weights='imagenet', include_top=False)
# add a global spatial average pooling layer
x = base_model.output
x = GlobalAveragePooling2D()(x)
# let's add a fully-connected layer
x = Dense(1024, activation='relu')(x)
x = Dropout(0.2)(x)
# and a final prediction layer -- here 18 outputs with ReLU activation
predictions = Dense(18, activation='relu')(x)
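The snippet stops before the new head is attached to the backbone; a minimal continuation in the spirit of the standard Keras transfer-learning example (the optimizer and loss below are placeholders, not taken from this code) might look like:

from keras.models import Model

model = Model(inputs=base_model.input, outputs=predictions)
for layer in base_model.layers:
    layer.trainable = False  # freeze the InceptionV3 backbone, train only the new head
model.compile(optimizer='adam', loss='mse')
model.fit(x_train, y_train, batch_size=16, epochs=10, validation_split=0.1)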
def train_and_predict(use_existing):

    cm.mkdir(cm.workingPath.model_path)
    cm.mkdir(cm.workingPath.best_model_path)
    cm.mkdir(cm.workingPath.visual_path)

    # learning_rate = 0.00001
    #
    # adam = Adam(lr=learning_rate)
    #
    # opti = adam
    #
    # lrate = callbacks.LearningRateScheduler(cb.step_decay)

    print('-' * 30)
    print('Loading and preprocessing train data...')
    print('-' * 30)

    # Scanning training data list:
    originFile_list = sorted(
        glob(cm.workingPath.trainingPatchesSet_path + 'img_*.npy'))
    mask_list = sorted(
        glob(cm.workingPath.trainingPatchesSet_path + 'mask_*.npy'))

    # Scanning validation data list:
    originValFile_list = sorted(
        glob(cm.workingPath.validationSet_path + 'valImages.npy'))
    maskVal_list = sorted(
        glob(cm.workingPath.validationSet_path + 'valMasks.npy'))

    x_val = np.load(originValFile_list[0])
    y_val = np.load(maskVal_list[0])

    # Calculate the total amount of training sets:
    nb_file = int(len(originFile_list))
    nb_val_file = int(len(originValFile_list))

    # Make a random list (shuffle the training data):
    random_scale = nb_file
    rand_i = np.random.choice(range(random_scale),
                              size=random_scale,
                              replace=False)
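    # np.random.choice over range(nb_file) with replace=False produces a random
    # permutation of the training-file indices.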

    # train_num, val_num, train_list, val_list = dp.train_split(nb_file, rand_i)
    train_num, train_list = dp.train_val_split(nb_file, rand_i)

    print('_' * 30)
    print('Creating and compiling model...')
    print('_' * 30)

    # Select the model you want to train:
    # model = nw.get_3D_unet()
    # model = nw.get_3D_Eunet()
    # model = DenseUNet_3D.get_3d_denseunet()
    model = UNet_3D.get_3d_unet()
    # model = RSUNet_3D.get_3d_rsunet(opti)
    # model = RSUNet_3D_Gerda.get_3d_rsunet_Gerdafeature(opti)

    # Plot the model:
    modelname = 'model.png'
    plot_model(model,
               show_shapes=True,
               to_file=cm.workingPath.model_path + modelname)
    model.summary()

    # Should we load existing weights?
    if use_existing:
        model.load_weights(cm.workingPath.model_path + 'unet.hdf5')

    print('-' * 30)
    print('Fitting model...')
    print('-' * 30)

    nb_epoch = 4000

    temp_weights = model.get_weights()
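    # The loop below trains on one .npy file at a time: the weights are captured
    # with get_weights() and restored with set_weights() around every fit() call,
    # so learning carries over between files, and the validation data is only
    # evaluated on the last file of each epoch.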

    for e in range(nb_epoch):

        # Set callbacks:
        filepath = cm.workingPath.model_path + 'weights.epoch_%02d-{loss:.5f}-{val_loss:.5f}.hdf5' % (
            e + 1)
        model_checkpoint = callbacks.ModelCheckpoint(filepath,
                                                     monitor='loss',
                                                     verbose=0,
                                                     save_best_only=False)
        record_history = cb.RecordLossHistory()

        for i in range(train_num):

            print("epoch %04d, batch %04d" % (e + 1, i + 1))
            x_train, y_train = dp.BatchGenerator(i, originFile_list, mask_list,
                                                 train_list)

            # gradients = cb.recordGradients_Florian(x_train, cm.workingPath.model_path, model, True)
            callbacks_list = [record_history, model_checkpoint]

            if i == (train_num - 1):

                model.set_weights(temp_weights)
                model.fit(x_train,
                          y_train,
                          batch_size=1,
                          epochs=1,
                          verbose=1,
                          validation_data=(x_val, y_val),
                          callbacks=callbacks_list)
                temp_weights = model.get_weights()

            else:

                model.set_weights(temp_weights)
                model.fit(x_train, y_train, batch_size=1, epochs=1, verbose=1)
                temp_weights = model.get_weights()

    print('training finished')