Example #1
def train(model,
          train_images,
          train_annotations,
          input_height=None,
          input_width=None,
          n_classes=None,
          verify_dataset=True,
          checkpoints_path=None,
          epochs=5,
          batch_size=2,
          validate=False,
          val_images=None,
          val_annotations=None,
          val_batch_size=2,
          auto_resume_checkpoint=False,
          load_weights=None,
          steps_per_epoch=512,
          val_steps_per_epoch=512,
          gen_use_multiprocessing=False,
          optimizer_name='adadelta',
          loss_type='categorical_crossentropy',
          metrics_used=['accuracy'],
          do_augment=False,
          augmentation_name="aug_all"):

    lib_found = importlib.util.find_spec("models")
    found = lib_found is not None
    if found:
        from models.all_models import model_from_name
    else:
        from .models.all_models import model_from_name

    # check if the user gave a model name instead of the model object
    if isinstance(model, six.string_types):
        # create the model from the name
        assert (n_classes is not None), "Please provide the n_classes"
        if (input_height is not None) and (input_width is not None):
            model = model_from_name[model](n_classes,
                                           input_height=input_height,
                                           input_width=input_width)
        else:
            model = model_from_name[model](n_classes)

    n_classes = model.n_classes
    input_height = model.input_height
    input_width = model.input_width
    output_height = model.output_height
    output_width = model.output_width

    if validate:
        assert val_images is not None
        assert val_annotations is not None

    # def compile_model(model):
    #      model.compile(loss=loss_type,
    #                   optimizer=optimizer_name,
    #                   metrics=metrics_used)

    # def finetune_model(model, initial_epoch, finetune_epochs):
    #     if not validate:
    #         history = model.fit_generator(train_gen, steps_per_epoch,
    #                                       epochs=finetune_epochs, callbacks=callbacks,
    #                                       initial_epoch=initial_epoch)
    #     else:
    #         history = model.fit_generator(train_gen,
    #                             steps_per_epoch,
    #                             validation_data=val_gen,
    #                             validation_steps=val_steps_per_epoch,
    #                             epochs=finetune_epochs, callbacks=callbacks,
    #                             use_multiprocessing=gen_use_multiprocessing,
    #                             initial_epoch=initial_epoch)

    if optimizer_name is not None:

        # if ignore_zero_class:
        #     loss_k = masked_categorical_crossentropy
        # else:
        #     #loss_k = 'categorical_crossentropy'
        #     loss_k = jaccard_distance

        # model.compile(loss=loss_k,
        #               optimizer=optimizer_name,
        #               #metrics=['accuracy'])
        #               metrics=['accuracy', metrics.MeanIoU(name='model_iou', num_classes=n_classes)])
        #compile_model(model)
        model.compile(loss=loss_type,
                      optimizer=optimizer_name,
                      metrics=metrics_used)

    if checkpoints_path is not None:
        with open(checkpoints_path + "_config.json", "w") as f:
            json.dump(
                {
                    "model_class": model.model_name,
                    "n_classes": n_classes,
                    "input_height": input_height,
                    "input_width": input_width,
                    "output_height": output_height,
                    "output_width": output_width
                }, f)

    if load_weights is not None and len(load_weights) > 0:
        print("Loading weights from ", load_weights)
        model.load_weights(load_weights)

    if auto_resume_checkpoint and (checkpoints_path is not None):
        latest_checkpoint = find_latest_checkpoint(checkpoints_path)
        if latest_checkpoint is not None:
            print("Loading the weights from latest checkpoint ",
                  latest_checkpoint)
            model.load_weights(latest_checkpoint)

    if verify_dataset:
        print("Verifying training dataset")
        verified = verify_segmentation_dataset(train_images, train_annotations,
                                               n_classes)
        assert verified
        if validate:
            print("Verifying validation dataset")
            verified = verify_segmentation_dataset(val_images, val_annotations,
                                                   n_classes)
            assert verified

    train_gen = image_segmentation_generator(
        train_images,
        train_annotations,
        batch_size,
        n_classes,
        input_height,
        input_width,
        output_height,
        output_width,
        do_augment=do_augment,
        augmentation_name=augmentation_name)

    if validate:
        val_gen = image_segmentation_generator(val_images, val_annotations,
                                               val_batch_size, n_classes,
                                               input_height, input_width,
                                               output_height, output_width)

    callbacks = []
    if checkpoints_path is not None:
        # building the callbacks unconditionally would crash on the string
        # concatenation below when checkpoints_path is None
        callbacks = [
            CheckpointsCallback(checkpoints_path),
            CSVLogger(checkpoints_path + model.model_name + '_training.csv')
        ]

    if not validate:
        history = model.fit_generator(train_gen,
                                      steps_per_epoch,
                                      epochs=epochs,
                                      callbacks=callbacks)
    else:
        history = model.fit_generator(
            train_gen,
            steps_per_epoch,
            validation_data=val_gen,
            validation_steps=val_steps_per_epoch,
            epochs=epochs,
            callbacks=callbacks,
            use_multiprocessing=gen_use_multiprocessing)

    # if finetune_epochs > 0:

    # pruning = ridurre.KMeansFilterPruning(0.6, compile_model,
    #                                       finetune_model, finetune_epochs,
    #                                       maximum_pruning_percent=0.4,
    #                                       maximum_prune_iterations=12)
    # model, num = pruning.run_pruning(model)
    # print(model.summary())

    # list all data in history
    if epochs > 0:
        print(history.history.keys())
        # NOTE: the history keys below assume the model was compiled with a
        # masked categorical accuracy metric and a mean-IoU metric, and that
        # validate=True (otherwise the 'val_*' keys are absent)
        # summarize history for accuracy
        plt.figure()
        plt.plot(history.history['masked_categorical_accuracy'])
        plt.plot(history.history['val_masked_categorical_accuracy'])
        plt.title('model accuracy')
        plt.ylabel('accuracy')
        plt.xlabel('epoch')
        plt.legend(['train', 'validation'], loc='upper left')
        plt.savefig('accuracy.png')  # save before show(), which clears the figure
        plt.show()
        # summarize history for loss
        plt.figure()
        plt.plot(history.history['loss'])
        plt.plot(history.history['val_loss'])
        plt.title('model loss')
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.legend(['train', 'validation'], loc='upper left')
        plt.savefig('loss.png')
        plt.show()
        # summarize history for mean IoU
        plt.figure()
        plt.plot(history.history['mean_iou'])
        plt.plot(history.history['val_mean_iou'])
        plt.title('model mean IoU')
        plt.ylabel('mean_iou')
        plt.xlabel('epoch')
        plt.legend(['train', 'validation'], loc='upper left')
        plt.savefig('mean_iou.png')
        plt.show()
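
For context, train() accepts either a compiled model object or a model name
string resolved through model_from_name. A minimal usage sketch (the model
name and dataset paths below are placeholders, not taken from the example):

train("vgg_unet",
      train_images="dataset/images_train/",
      train_annotations="dataset/annotations_train/",
      n_classes=51,
      input_height=320,
      input_width=640,
      checkpoints_path="checkpoints/vgg_unet",
      epochs=5)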
Example #2

def train(model,
          train_images,
          train_annotations,
          input_height=None,
          input_width=None,
          n_classes=None,
          verify_dataset=True,
          checkpoints_path=None,
          epochs=5,
          batch_size=2,
          validate=True,
          val_images=None,
          val_annotations=None,
          val_batch_size=2,
          auto_resume_checkpoint=False,
          load_weights=None,
          steps_per_epoch=512,
          optimizer_name='adadelta',
          do_augment=False,
          loss_name='categorical_crossentropy'):
	#categorical_crossentropy
	from models.all_models import model_from_name
	#from .models.all_models import model_from_name
	# check if user gives model name instead of the model object
	if isinstance(model, six.string_types):
		# create the model from the name
		assert (n_classes is not None), "Please provide the n_classes"
		if (input_height is not None) and (input_width is not None):
			model = model_from_name[model](
				n_classes, input_height=input_height, input_width=input_width)
		else:
			model = model_from_name[model](n_classes)
	
	n_classes = model.n_classes
	input_height = model.input_height
	input_width = model.input_width
	output_height = model.output_height
	output_width = model.output_width
	
	csv_logger = CSVLogger('.../Loss_Acc.csv', append=True, separator=' ')
	checkpoint = ModelCheckpoint('model-{epoch:03d}.h5', verbose=1, monitor='val_loss', save_best_only=True, mode='auto')
	reduce_lr_loss = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=1, min_delta=1e-4, mode='min')
	
	if validate:
		assert val_images is not None
		assert val_annotations is not None
		
	
	if optimizer_name is not None:
		# NOTE: this example hard-codes a Dice loss and a Jaccard metric
		# (custom functions assumed to be defined elsewhere); the loss_name
		# argument is not used
		model.compile(loss=dice_coef_loss,
					  optimizer=optimizer_name,
					  metrics=[jacard_coef, 'accuracy'])
	
	if checkpoints_path is not None:
		with open(checkpoints_path+"_config.json", "w") as f:
			json.dump({
				"model_class": model.model_name,
				"n_classes": n_classes,
				"input_height": input_height,
				"input_width": input_width,
				"output_height": output_height,
				"output_width": output_width
			}, f)

	if load_weights is not None and len(load_weights) > 0:
		print("Loading weights from ", load_weights)
		model.load_weights(load_weights)

	if auto_resume_checkpoint and (checkpoints_path is not None):
		latest_checkpoint = find_latest_checkpoint(checkpoints_path)
		if latest_checkpoint is not None:
			print("Loading the weights from latest checkpoint ",
				  latest_checkpoint)
			model.load_weights(latest_checkpoint)

	if verify_dataset:
		print("Verifying training dataset") 
		print("Verifying training dataset::")        
		verified = verify_segmentation_dataset(train_images, train_annotations, n_classes)
		assert verified
		if validate:
			print("Verifying validation dataset::")
			verified = verify_segmentation_dataset(val_images, val_annotations, n_classes)
			assert verified

	train_gen = image_segmentation_generator(
		train_images, train_annotations,  batch_size,  n_classes,
		input_height, input_width, output_height, output_width , do_augment=do_augment )

	if validate:
		val_gen = image_segmentation_generator(
			val_images, val_annotations,  val_batch_size,
			n_classes, input_height, input_width, output_height, output_width)

	if not validate:
		for ep in range(epochs):
			print("Starting Epoch # ", ep)
			model.fit_generator(train_gen, steps_per_epoch, epochs=1)
			if checkpoints_path is not None:
				model.save_weights(checkpoints_path + "." + str(ep))
				print("saved ", checkpoints_path + "." + str(ep))
			print("Finished Epoch #", ep)
	else:
		for ep in range(epochs):
			print("Starting Epoch # ", ep)
			model.fit_generator(train_gen, steps_per_epoch,
								validation_data=val_gen,
								validation_steps=200, epochs=1,
								callbacks=[csv_logger, checkpoint, reduce_lr_loss])
			if checkpoints_path is not None:
				model.save_weights(checkpoints_path + "." + str(ep))
				print("saved ", checkpoints_path + "." + str(ep))
			print("Finished Epoch #", ep)
Example #3
def train(model,
          train_images,
          train_annotations,
          input_height=None,
          input_width=None,
          n_classes=None,
          verify_dataset=True,
          checkpoints_path=None,
          epochs=5,
          batch_size=2,
          validate=False,
          val_images=None,
          val_annotations=None,
          val_batch_size=2,
          auto_resume_checkpoint=False,
          load_weights=None,
          steps_per_epoch=512,
          optimizer_name='adadelta',
          do_augment=False):

    from models.all_models import model_from_name

    if isinstance(model, six.string_types):

        assert (n_classes is not None), "Please provide the n_classes"
        if (input_height is not None) and (input_width is not None):
            model = model_from_name[model](n_classes,
                                           input_height=input_height,
                                           input_width=input_width)
        else:
            model = model_from_name[model](n_classes)

    n_classes = model.n_classes
    input_height = model.input_height
    input_width = model.input_width
    output_height = model.output_height
    output_width = model.output_width

    if validate:
        assert val_images is not None
        assert val_annotations is not None

    if optimizer_name is not None:
        model.compile(loss='categorical_crossentropy',
                      optimizer=optimizer_name,
                      metrics=['accuracy'])

    if checkpoints_path is not None:
        with open(checkpoints_path + "_config.json", "w") as f:
            json.dump(
                {
                    "model_class": model.model_name,
                    "n_classes": n_classes,
                    "input_height": input_height,
                    "input_width": input_width,
                    "output_height": output_height,
                    "output_width": output_width
                }, f)

    if load_weights is not None and len(load_weights) > 0:
        print("Loading weights from ", load_weights)
        model.load_weights(load_weights)

    if auto_resume_checkpoint and (checkpoints_path is not None):
        latest_checkpoint = find_latest_checkpoint(checkpoints_path)
        if latest_checkpoint is not None:
            print("Loading the weights from latest checkpoint ",
                  latest_checkpoint)
            model.load_weights(latest_checkpoint)

    if verify_dataset:
        print("Verifying training dataset")
        verified = verify_segmentation_dataset(train_images, train_annotations,
                                               n_classes)
        assert verified
        if validate:
            print("Verifying validation dataset")
            verified = verify_segmentation_dataset(val_images, val_annotations,
                                                   n_classes)
            assert verified

    train_gen = image_segmentation_generator(train_images,
                                             train_annotations,
                                             batch_size,
                                             n_classes,
                                             input_height,
                                             input_width,
                                             output_height,
                                             output_width,
                                             do_augment=do_augment)

    if validate:
        val_gen = image_segmentation_generator(val_images, val_annotations,
                                               val_batch_size, n_classes,
                                               input_height, input_width,
                                               output_height, output_width)

    if not validate:
        for ep in range(epochs):
            print("Starting Epoch ", ep)
            model.fit_generator(train_gen, steps_per_epoch, epochs=1)
            if checkpoints_path is not None:
                model.save_weights(checkpoints_path + "." + str(ep))
                print("saved ", checkpoints_path + "." + str(ep))
            print("Finished Epoch", ep)
    else:
        for ep in range(epochs):
            print("Starting Epoch ", ep)
            model.fit_generator(train_gen,
                                steps_per_epoch,
                                validation_data=val_gen,
                                validation_steps=200,
                                epochs=1)
            if checkpoints_path is not None:
                model.save_weights(checkpoints_path + "." + str(ep))
                print("saved ", checkpoints_path + "." + str(ep))
            print("Finished Epoch", ep)
Example #4
def train(model,
          train_images,
          train_annotations,
          input_height=None,
          input_width=None,
          n_classes=None,
          verify_dataset=True,
          checkpoints_path=None,
          epochs=5,
          batch_size=2,
          validate=False,
          val_images=None,
          val_annotations=None,
          val_batch_size=2,
          auto_resume_checkpoint=False,
          load_weights=None,
          steps_per_epoch=512,
          optimizer_name='adadelta',
          callbacks=None):

    if isinstance(model, six.string_types):
        # the user gave a model name instead of a model object;
        # create the model from the name
        assert (n_classes is not None), "Please provide the n_classes"
        if (input_height is not None) and (input_width is not None):
            model = model_from_name[model](n_classes,
                                           input_height=input_height,
                                           input_width=input_width)
        else:
            model = model_from_name[model](n_classes)

    n_classes = model.n_classes
    input_height = model.input_height
    input_width = model.input_width
    output_height = model.output_height
    output_width = model.output_width

    model.summary()
    if validate:
        assert val_images is not None
        assert val_annotations is not None

    if optimizer_name is not None:
        model.compile(loss='categorical_crossentropy',
                      optimizer=optimizer_name,
                      metrics=['accuracy'])

    if checkpoints_path is not None:
        with open(checkpoints_path + "_config.json", "w") as f:
            json.dump({
                "model_class": model.model_name,
                "n_classes": n_classes,
                "input_height": input_height,
                "input_width": input_width,
                "output_height": output_height,
                "output_width": output_width
            }, f)

    if load_weights is not None and len(load_weights) > 0:
        print("Loading weights from ", load_weights)
        model.load_weights(load_weights)

    latest_ep = -1
    if auto_resume_checkpoint and (checkpoints_path is not None):
        latest_checkpoint, latest_ep = find_latest_checkpoint(checkpoints_path)
        if latest_checkpoint is not None:
            print("Loading the weights from latest checkpoint ",
                  latest_checkpoint)
            model.load_weights(latest_checkpoint)

    if verify_dataset:
        print("Verifying train dataset")
        verify_segmentation_dataset(train_images, train_annotations, n_classes)
        if validate:
            print("Verifying val dataset")
            verify_segmentation_dataset(val_images, val_annotations, n_classes)

    train_gen = image_segmentation_generator(train_images, train_annotations,
                                             batch_size, n_classes,
                                             input_height, input_width,
                                             output_height, output_width)

    if validate:
        val_gen = image_segmentation_generator(val_images, val_annotations,
                                               val_batch_size, n_classes,
                                               input_height, input_width,
                                               output_height, output_width)

    if not validate:
        for ep in range(latest_ep + 1, latest_ep + 1 + epochs):
            print("Starting Epoch ", ep)
            model.fit_generator(train_gen,
                                steps_per_epoch,
                                epochs=1,
                                callbacks=callbacks)
            if checkpoints_path is not None:
                model.save_weights(checkpoints_path + "." + str(ep))
                print("saved ", checkpoints_path + "." + str(ep))
            print("Finished Epoch", ep)
    else:
        for ep in range(latest_ep + 1, latest_ep + 1 + epochs):
            print("Starting Epoch ", ep)
            model.fit_generator(train_gen,
                                steps_per_epoch,
                                validation_data=val_gen,
                                validation_steps=200,
                                callbacks=callbacks,
                                epochs=1)

            if checkpoints_path is not None:
                model.save_weights(checkpoints_path + "." + str(ep))
                print("saved ", checkpoints_path + "." + str(ep))
            print("Finished Epoch", ep)
Example #5
def action(args):
    verify_segmentation_dataset(args.images_path, args.segs_path,
                                args.n_classes)
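
The action handler presumably receives parsed command-line arguments; a hedged
sketch of the argparse wiring it assumes (argument names taken from the
attributes used above, everything else is illustrative):

import argparse

parser = argparse.ArgumentParser(description="Verify a segmentation dataset")
parser.add_argument("--images_path", type=str, required=True)
parser.add_argument("--segs_path", type=str, required=True)
parser.add_argument("--n_classes", type=int, required=True)
args = parser.parse_args()
action(args)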