コード例 #1
0
 def train_model(foldername, Model_Type, num_objects=2, num_experiments=1,
                 enhance_data=False, batch_size=1, show_network_summary=True):
     """Create an ImageAI ``ModelTraining`` for ``Model_Type`` and train it on ``foldername``.

     Args:
         foldername: Dataset directory, passed to ``setDataDirectory``.
         Model_Type: Exactly one of ``"ResNet"``, ``"SqueezeNet"``,
             ``"InceptionV3"`` or ``"DenseNet"`` (case-sensitive).
         num_objects, num_experiments, enhance_data, batch_size,
         show_network_summary: Forwarded unchanged to
             ``ModelTraining.trainModel``.

     Raises:
         ValueError: If ``Model_Type`` is not one of the supported names. The
             check runs before any trainer object is created.
     """
     # Exact-match lookup. A substring test such as ``Model_Type in "ResNet"``
     # would accept "", "Net" or "Res" and silently pick ResNet, and an
     # unrecognised name would fall through with no model type selected.
     setter_names = {
         "ResNet": "setModelTypeAsResNet",
         "SqueezeNet": "setModelTypeAsSqueezeNet",
         "InceptionV3": "setModelTypeAsInceptionV3",
         "DenseNet": "setModelTypeAsDenseNet",
     }
     setter_name = setter_names.get(Model_Type)
     if setter_name is None:
         raise ValueError(
             "Unsupported Model_Type {!r}; expected one of: {}".format(
                 Model_Type, ", ".join(setter_names)
             )
         )

     model_trainer = ModelTraining()
     # Call the matching setModelTypeAs*() method on the trainer.
     getattr(model_trainer, setter_name)()
     model_trainer.setDataDirectory(foldername)
     model_trainer.trainModel(
         num_objects=num_objects,
         num_experiments=num_experiments,
         enhance_data=enhance_data,
         batch_size=batch_size,
         show_network_summary=show_network_summary,
     )
コード例 #2
0
ファイル: test_model_training.py プロジェクト: AarC10/ImageAI
def test_squeezenet_training():
    """Train SqueezeNet on the sample dataset and check the generated artifacts.

    Runs one training experiment over ``sample_dataset`` and asserts that the
    JSON folder (containing ``model_class.json``) and a non-empty models
    folder were produced. Both generated folders are removed before and after
    the run.
    """
    generated_folders = (sample_dataset_json_folder, sample_dataset_models_folder)

    # Remove leftovers from an earlier (possibly failed) run so the assertions
    # below can only be satisfied by artifacts this run produced.
    for folder in generated_folders:
        shutil.rmtree(folder, ignore_errors=True)

    try:
        trainer = ModelTraining()
        trainer.setModelTypeAsSqueezeNet()
        trainer.setDataDirectory(data_directory=sample_dataset)
        trainer.trainModel(num_objects=10,
                           num_experiments=1,
                           enhance_data=True,
                           batch_size=16,
                           show_network_summary=True)

        assert os.path.isdir(sample_dataset_json_folder)
        assert os.path.isdir(sample_dataset_models_folder)
        assert os.path.isfile(
            os.path.join(sample_dataset_json_folder, "model_class.json"))
        assert len(os.listdir(sample_dataset_models_folder)) > 0
    finally:
        # Always clean up, even when training or an assertion fails;
        # ignore_errors keeps a missing folder from masking the real failure.
        for folder in generated_folders:
            shutil.rmtree(folder, ignore_errors=True)
コード例 #3
0
from imageai.Prediction.Custom import ModelTraining
from imageai.Detection import ObjectDetection
# Train a custom SqueezeNet model on the images under
# "ImageAI_Custom_CNN/vehicles" with num_objects=2, 100 experiments,
# batch size 16, and both enhance_data and show_network_summary enabled.
model_trainer = ModelTraining()
model_trainer.setModelTypeAsSqueezeNet()
model_trainer.setDataDirectory("ImageAI_Custom_CNN/vehicles")
model_trainer.trainModel(
    num_objects=2,
    num_experiments=100,
    enhance_data=True,
    batch_size=16,
    show_network_summary=True,
)