def _get_transform_executable(className):
    '''
        Resolves a dot-separated transform class name to an executable.

        If the second-to-last path token is one of "labels", "points" or
        "boundingBoxes", the final token is looked up as an attribute of
        the corresponding custom transforms module (LabelsTransforms,
        PointsTransforms, BoundingBoxesTransforms). Every other name,
        including a name without any dot, is passed on unchanged to
        "get_class_executable".

        Raises an AttributeError if the custom transforms module has no
        attribute with the requested name.
    '''
    #TODO: dirty hack to be able to import custom transform functions...
    tokens = className.split('.')

    # a name without a dot has no parent token; indexing tokens[-2]
    # would raise an IndexError, so treat it as "no custom namespace"
    namespace = tokens[-2] if len(tokens) > 1 else None

    if namespace == 'labels':
        return getattr(LabelsTransforms, tokens[-1])
    elif namespace == 'points':
        return getattr(PointsTransforms, tokens[-1])
    elif namespace == 'boundingBoxes':
        return getattr(BoundingBoxesTransforms, tokens[-1])
    else:
        return get_class_executable(className)
def __init__(self, project, config, dbConnector, fileServer, options):
    '''
        Forwards all arguments to the parent constructor, then tries to
        substitute global definitions in JSON-enhanced options and
        finally resolves the executables (model, criterion, optimizer,
        dataset) named in the options.

        Executables that cannot be resolved are set to None, except for
        the optimizer, which falls back to SGD.
    '''
    super(GenericPyTorchModel, self).__init__(project, config, dbConnector,
                                              fileServer, options)

    # try to fill and substitute global definitions in JSON-enhanced options
    if isinstance(options, dict) and 'defs' in options:
        try:
            self.options = optionsHelper.substitute_definitions(
                options.copy())
        except Exception:
            # something went wrong; deliberately keep the options as they
            # are (NOTE(review): presumably assigned by the parent
            # constructor -- confirm)
            pass

    def _resolve(path, fallback=None):
        # Best-effort lookup: any failure (missing key, failed import, ...)
        # yields the fallback. "except Exception" instead of a bare except
        # so that KeyboardInterrupt/SystemExit are not swallowed.
        try:
            return get_class_executable(
                optionsHelper.get_hierarchical_value(self.options, path))
        except Exception:
            return fallback

    # retrieve executables
    self.model_class = _resolve(['options', 'model', 'class'])
    self.criterion_class = _resolve(
        ['options', 'train', 'criterion', 'class'])
    self.optim_class = _resolve(['options', 'train', 'optim'], SGD)
    self.dataset_class = _resolve(['options', 'dataset'])
def __init__(self,
             project,
             config,
             dbConnector,
             fileServer,
             options,
             defaultOptions=None):
    '''
        Forwards the arguments to the parent constructor, optionally
        merges the options with the provided defaults via "check_args",
        and resolves the executables (model, criterion, optimizer,
        dataset) named in the options.

        Executables that cannot be resolved are set to None, except for
        the optimizer, which falls back to SGD.
    '''
    super(GenericPyTorchModel_Legacy,
          self).__init__(project, config, dbConnector, fileServer, options)

    # parse the options and compare with the provided defaults (if provided)
    if defaultOptions is not None:
        self.options = check_args(self.options, defaultOptions)
    else:
        self.options = options

    def _resolve(keys, fallback=None):
        # Best-effort lookup: walk the nested options along "keys" and
        # import the class found there. Any failure yields the fallback.
        # "except Exception" instead of a bare except so that
        # KeyboardInterrupt/SystemExit are not swallowed.
        try:
            node = self.options
            for key in keys:
                node = node[key]
            return get_class_executable(node)
        except Exception:
            return fallback

    # retrieve executables
    self.model_class = _resolve(('model', 'class'))
    self.criterion_class = _resolve(('train', 'criterion', 'class'))
    self.optim_class = _resolve(('train', 'optim', 'class'), SGD)
    self.dataset_class = _resolve(('dataset', 'class'))
def _import_model_state_file(self,
                             project,
                             modelURI,
                             modelDefinition,
                             stateDict=None,
                             public=True,
                             anonymous=False,
                             namePolicy='skip',
                             customName=None):
    '''
        Receives two files:
        - "modelDefinition": JSON-encoded metadata file (dict, str or
          bytes)
        - "stateDict" (optional): bytes object containing model state

        Parses the files for completeness and integrity and attempts to
        launch a model with them. If all checks are successful, the new
        model state is inserted into the database and the UUID is
        returned.

        If no "stateDict" is given but the model definition contains an
        entry "ai_model_state_uri", the state is loaded from there
        (prefix "aide://" denotes a file on disk; everything else is
        opened with urllib).

        Parameter "namePolicy" can have one of three values:
        - "skip" (default): skips model import if another with the same
          name already exists in the Model Marketplace database (the ID
          of the existing model is returned)
        - "increment": appends or increments a number at the end of the
          name if it is already found in the database
        - "custom": uses the value under "customName"

        Raises an Exception with a descriptive message if any of the
        checks fails, and an AssertionError for an invalid naming policy
        or custom name.
    '''
    # check inputs
    if not isinstance(modelDefinition, dict):
        try:
            if isinstance(modelDefinition, bytes):
                modelDefinition = json.load(io.BytesIO(modelDefinition))
            elif isinstance(modelDefinition, str):
                modelDefinition = json.loads(modelDefinition)
        except Exception as e:
            raise Exception(
                f'Invalid model state definition file (message: "{e}").')
        if not isinstance(modelDefinition, dict):
            # unsupported input type, or JSON that is not an object;
            # fail here instead of with an unrelated error further down
            raise Exception(
                'Invalid model state definition file (message: "not a JSON object").'
            )
    if stateDict is not None and not isinstance(stateDict, bytes):
        raise Exception(
            'Invalid model state dict provided (not a binary file).')

    # check naming policy
    namePolicy = str(namePolicy).lower()
    assert namePolicy in (
        'skip', 'increment',
        'custom'), f'Invalid naming policy "{namePolicy}".'

    if namePolicy == 'skip':
        # check if model with same name already exists
        modelID = self.getModelIdByName(modelDefinition['name'])
        if modelID is not None:
            return modelID

    elif namePolicy == 'custom':
        assert isinstance(
            customName,
            str) and len(customName), 'Invalid custom name provided.'

        # check if name is available
        if self.getModelIdByName(customName) is not None:
            raise Exception(f'Custom name "{customName}" is unavailable.')

    # project metadata
    projectMeta = self.dbConnector.execute(
        '''
        SELECT annotationType, predictionType
        FROM aide_admin.project
        WHERE shortname = %s;
    ''', (project, ), 1)
    if projectMeta is None or not len(projectMeta):
        raise Exception(
            f'Project with shortname "{project}" not found in database.')
    projectMeta = projectMeta[0]

    # check fields
    for field in ('author', 'citation_info', 'license'):
        if field not in modelDefinition:
            modelDefinition[field] = None

    for field in self.MODEL_STATE_REQUIRED_FIELDS:
        if field not in modelDefinition:
            raise Exception(f'Missing field "{field}" in AIDE JSON file.')

        if field == 'aide_model_version':
            # check model definition version; parse via float() so that
            # strings like "1.0" are accepted (str.isnumeric() rejects
            # anything containing a decimal point)
            rawVersion = modelDefinition['aide_model_version']
            try:
                modelVersion = float(rawVersion)
            except (TypeError, ValueError):
                raise Exception(
                    f'Invalid AIDE model version "{rawVersion}" in JSON file.'
                )
            if modelVersion > self.MAX_AIDE_MODEL_VERSION:
                raise Exception(
                    f'Model state contains a newer model version than supported by this installation of AIDE ({modelVersion} > {self.MAX_AIDE_MODEL_VERSION}).\nPlease update AIDE to the latest version.'
                )

        if field == 'ai_model_library':
            # check if model library is installed
            modelLibrary = modelDefinition[field]
            if modelLibrary not in PREDICTION_MODELS:
                raise Exception(
                    f'Model library "{modelLibrary}" is not installed in this instance of AIDE.'
                )

            # check if annotation and prediction types match
            if projectMeta['annotationtype'] not in PREDICTION_MODELS[
                    modelLibrary]['annotationType']:
                raise Exception(
                    'Project\'s annotation type is not compatible with this model state.'
                )
            if projectMeta['predictiontype'] not in PREDICTION_MODELS[
                    modelLibrary]['predictionType']:
                raise Exception(
                    'Project\'s prediction type is not compatible with this model state.'
                )

    # check if model state URI provided. The definition is a dict, so the
    # entry has to be looked up as a key (hasattr() never finds it).
    stateDictURI = modelDefinition.get('ai_model_state_uri')
    if stateDict is None and isinstance(stateDictURI,
                                        str) and len(stateDictURI):
        uriPrefix = 'aide://'
        try:
            if stateDictURI.lower().startswith(uriPrefix):
                # load from disk; cut the prefix by length so that its
                # removal is as case-insensitive as its detection
                stateDictPath = stateDictURI[len(uriPrefix):].strip('/')
                if not os.path.isfile(stateDictPath):
                    raise Exception(
                        f'Model state file path provided ("{stateDictPath}"), but file could not be found.'
                    )
                with open(stateDictPath, 'rb') as f:
                    stateDict = f.read()  #TODO: BytesIO instead

            else:
                # network import
                # NOTE(review): the URI comes from the imported definition
                # and is opened as-is (any scheme urllib supports); confirm
                # that only trusted definitions reach this point.
                with request.urlopen(stateDictURI) as f:
                    stateDict = f.read(
                    )  #TODO: progress bar; load in chunks; etc.

        except Exception as e:
            raise Exception(
                f'Model state URI provided ("{stateDictURI}"), but could not be loaded (message: "{str(e)}").'
            )

    # model name
    modelName = modelDefinition['name']
    if namePolicy == 'increment':
        # check if model name needs to be incremented
        allNames = self.dbConnector.execute(
            'SELECT name FROM "aide_admin".modelmarketplace;', None, 'all')
        allNames = (set([a['name'].strip() for a in allNames])
                    if allNames is not None else set())
        if modelName.strip() in allNames:
            startIdx = 1
            insertPos = len(modelName)
            trailingNumber = re.findall(r' \d+$', modelName.strip())
            if len(trailingNumber):
                # continue counting from the existing trailing number and
                # overwrite it (including the preceding space)
                startIdx = int(trailingNumber[0])
                insertPos = modelName.rfind(str(startIdx)) - 1
            while modelName.strip() in allNames:
                modelName = modelName[:insertPos] + f' {startIdx}'
                startIdx += 1

    elif namePolicy == 'custom':
        modelName = customName

    # remaining parameters
    modelAuthor = modelDefinition['author']
    modelDescription = (modelDefinition['description']
                        if 'description' in modelDefinition else None)
    modelTags = (';;'.join(modelDefinition['tags'])
                 if 'tags' in modelDefinition else None)
    labelClasses = modelDefinition['labelclasses']  #TODO: parse?
    if not isinstance(labelClasses, str):
        labelClasses = json.dumps(labelClasses)
    modelOptions = (modelDefinition['ai_model_settings']
                    if 'ai_model_settings' in modelDefinition else None)
    modelLibrary = modelDefinition['ai_model_library']
    alCriterion_library = (modelDefinition['alcriterion_library'] if
                           'alcriterion_library' in modelDefinition else None)
    annotationType = PREDICTION_MODELS[modelLibrary]['annotationType']  #TODO
    predictionType = PREDICTION_MODELS[modelLibrary]['predictionType']  #TODO
    citationInfo = modelDefinition['citation_info']
    license = modelDefinition['license']
    if not isinstance(annotationType, str):
        annotationType = ','.join(annotationType)
    if not isinstance(predictionType, str):
        predictionType = ','.join(predictionType)
    timeCreated = (modelDefinition['time_created']
                   if 'time_created' in modelDefinition else None)
    try:
        timeCreated = datetime.fromtimestamp(timeCreated)
    except Exception:
        # missing or unparseable timestamp; use the current time instead
        timeCreated = current_time()

    # try to launch model with data
    try:
        modelClass = get_class_executable(modelLibrary)
        # instantiated only to verify that the model can be launched
        modelClass(project=project,
                   config=self.config,
                   dbConnector=self.dbConnector,
                   fileServer=FileServer(
                       self.config).get_secure_instance(project),
                   options=modelOptions)

        # verify options
        if modelOptions is not None:
            try:
                optionMeta = modelClass.verifyOptions(modelOptions)
                if 'options' in optionMeta:
                    modelOptions = optionMeta['options']
                #TODO: parse warnings and errors
            except Exception:
                # model library does not support option verification
                pass
        if isinstance(modelOptions, dict):
            modelOptions = json.dumps(modelOptions)

    except Exception as e:
        raise Exception(
            f'Model from imported state could not be launched (message: "{str(e)}").'
        )

    # import model state into Model Marketplace
    success = self.dbConnector.execute(
        '''
        INSERT INTO aide_admin.modelMarketplace
        (name, description, tags, labelclasses, author, statedict,
        model_library, model_settings, alCriterion_library,
        annotationType, predictionType,
        citation_info, license,
        timeCreated,
        origin_project, origin_uuid, origin_uri, public, anonymous)
        VALUES %s
        RETURNING id;
    ''', [(modelName, modelDescription, modelTags, labelClasses, modelAuthor,
           stateDict, modelLibrary, modelOptions, alCriterion_library,
           annotationType, predictionType, citationInfo, license, timeCreated,
           project, None, modelURI, public, anonymous)], 1)
    if success is None or not len(success):

        #TODO: temporary fix to get ID: try again by re-querying DB
        success = self.dbConnector.execute(
            '''
            SELECT id FROM aide_admin.modelMarketplace
            WHERE name = %s;
        ''', (modelName, ), 1)
        if success is None or not len(success):
            raise Exception(
                'Model could not be imported into Model Marketplace.')

    # model import to Marketplace successful; now import to project
    return success[0]['id']
def main():
    '''
        Command-line entry point of the AIDE local model tester.

        Parses the arguments, verifies that the AIDE file server is
        running, determines the model library and settings to use
        (project defaults unless overridden on the command line),
        instantiates the model and runs either its "train" or its
        "inference" function on up to 512 training images of the
        project.

        Raises an Exception if the file server is not running, if the
        project cannot be found, or if training returns None. An unknown
        evaluation mode terminates the program through argparse.
    '''
    # parse arguments
    parser = argparse.ArgumentParser(description='AIDE local model tester')
    parser.add_argument('--project',
                        type=str,
                        required=True,
                        help='Project shortname to draw sample data from.')
    parser.add_argument(
        '--mode',
        type=str,
        required=True,
        help=
        'Evaluation mode (function to call). One of {"train", "inference"}.')
    parser.add_argument(
        '--modelLibrary',
        type=str,
        required=False,
        help=
        'Optional AI model library override. Provide a dot-separated Python import path here.'
    )
    parser.add_argument(
        '--modelSettings',
        type=str,
        required=False,
        help=
        'Optional AI model settings override (absolute or relative path to settings file, or else "none" to not use any predefined settings).'
    )
    args = parser.parse_args()
    #TODO: continue

    # validate explicitly; an assert would be stripped when running with -O
    mode = args.mode.lower()
    if mode not in ('train', 'inference'):
        parser.error(f'"{args.mode}" is not a known evaluation mode.')

    # initialize required modules
    config = Config()
    dbConnector = Database(config)
    fileServer = FileServer(config).get_secure_instance(args.project)
    aiw = AIWorker(config, dbConnector, True)
    aicw = AIControllerWorker(config, None)

    # check if AIDE file server is reachable
    admin = AdminMiddleware(config, dbConnector)
    connDetails = admin.getServiceDetails(True, False)
    fsVersion = connDetails['FileServer']['aide_version']
    if not isinstance(fsVersion, str):
        # no file server running
        raise Exception(
            'ERROR: AIDE file server is not running, but required for running models. Make sure to launch it prior to running this script.'
        )
    elif fsVersion != AIDE_VERSION:
        print(
            f'WARNING: the AIDE version of File Server instance ({fsVersion}) differs from this one ({AIDE_VERSION}).'
        )

    # get model trainer instance and settings
    queryStr = '''
        SELECT ai_model_library, ai_model_settings FROM aide_admin.project
        WHERE shortname = %s;
    '''
    result = dbConnector.execute(queryStr, (args.project, ), 1)
    if result is None or not len(result):
        raise Exception(
            f'Project "{args.project}" could not be found in this installation of AIDE.'
        )
    modelLibrary = result[0]['ai_model_library']
    modelSettings = result[0]['ai_model_settings']

    customSettingsSpecified = False
    # argparse always defines the attribute (None if the flag is omitted)
    if isinstance(args.modelSettings, str) and len(args.modelSettings):
        # settings override specified
        if args.modelSettings.lower() == 'none':
            modelSettings = None
            customSettingsSpecified = True
        elif not os.path.isfile(args.modelSettings):
            print(
                f'WARNING: model settings override provided, but file cannot be found ("{args.modelSettings}"). Falling back to project default ("{modelSettings}").'
            )
        else:
            modelSettings = args.modelSettings
            customSettingsSpecified = True

    if isinstance(args.modelLibrary, str) and len(args.modelLibrary):
        # library override specified; try to import it
        try:
            modelClass = helpers.get_class_executable(args.modelLibrary)
            if modelClass is None:
                # explicit error instead of a bare "raise" without an
                # active exception; handled by the fallback below
                raise ImportError(
                    f'Model library "{args.modelLibrary}" could not be resolved.'
                )
            modelLibrary = args.modelLibrary

            # re-check if current model settings are compatible; warn and set to None if not
            if modelLibrary != result[0][
                    'ai_model_library'] and not customSettingsSpecified:
                # project model settings are not compatible with provided model
                print(
                    'WARNING: custom model library specified differs from the one currently set in project. Model settings will be set to None.'
                )
                modelSettings = None

        except Exception:
            print(
                f'WARNING: model library override provided ("{args.modelLibrary}"), but could not be imported. Falling back to project default ("{modelLibrary}").'
            )

    # initialize instance
    print(f'Using model library "{modelLibrary}".')
    modelTrainer = aiw._init_model_instance(args.project, modelLibrary,
                                            modelSettings)

    stateDict = None  #TODO: load latest unless override is specified?

    # get data
    data = aicw.get_training_images(project=args.project, maxNumImages=512)
    data = __load_metadata(args.project, dbConnector, data[0],
                           (mode == 'train'))

    # helper functions
    def updateStateFun(state, message, done=None, total=None):
        # prints the progress message, followed by "done/total" if both
        # values are available
        print(message, end='')
        if done is not None and total is not None:
            print(f': {done}/{total}')
        else:
            print('')

    # launch task
    if mode == 'train':
        result = modelTrainer.train(stateDict, data, updateStateFun)
        if result is None:
            raise Exception(
                'Training function must return an object (i.e., trained model state) to be stored in the database.'
            )

    elif mode == 'inference':
        result = modelTrainer.inference(stateDict, data, updateStateFun)
def train(self, stateDict, data, updateStateFun):
    '''
        Builds a model from "stateDict" and a data loader from "data",
        then runs one pass over the data loader, updating the model with
        the optimizer and criterion configured under the "train" section
        of the options.

        "updateStateFun" is called after every batch with the number of
        images processed so far.

        Returns the result of "exportModelState" for the trained model.
    '''
    # model
    model, labelclassMap = self.initializeModel(stateDict, data)

    # data pipeline: transform -> dataset -> collated data loader
    trainOpts = self.options['train']
    transform = parse_transforms(trainOpts['transform'])
    dataset = self.dataset_class(data, self.fileServer, labelclassMap,
                                 transform, trainOpts['ignore_unsure'],
                                 **self.options['dataset']['kwargs'])
    collator = Collator(self.project, self.dbConnector)
    loaderKwargs = trainOpts['dataLoader']['kwargs']
    dataLoader = DataLoader(dataset,
                            collate_fn=collator.collate,
                            **loaderKwargs)

    # optimizer
    optimOpts = trainOpts['optim']
    optimizer_class = get_class_executable(optimOpts['class'])
    optimizer = optimizer_class(params=model.parameters(),
                                **optimOpts['kwargs'])

    # loss criterion
    criterionOpts = trainOpts['criterion']
    criterion_class = get_class_executable(criterionOpts['class'])
    criterion = criterion_class(**criterionOpts['kwargs'])

    # seed the random number generators and move the model to the device
    device = self.get_device()
    seed = self.options['general']['seed']
    torch.manual_seed(seed)
    if 'cuda' in device:
        torch.cuda.manual_seed(seed)
    model.to(device)

    # training loop
    numTotal = None
    numProcessed = 0
    for (images, targets, _fVec, _) in tqdm(dataLoader):
        images = images.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        loss_value = criterion(model(images), targets)
        loss_value.backward()
        optimizer.step()

        # report progress to the worker
        numProcessed += images.size(0)
        numTotal = len(dataLoader.dataset)
        updateStateFun(state='PROGRESS',
                       message='training',
                       done=numProcessed,
                       total=numTotal)

    # finished; hand back the exported model state
    return self.exportModelState(model)
def train(self, stateDict, data, updateStateFun):
    '''
        Builds a model from "stateDict" and a data loader from "data",
        then runs one pass over the data loader, updating the model with
        the optimizer and criterion configured under the "train" section
        of the options.

        According to the original documentation, training is fully or
        weakly supervised depending on the annotation types of the
        individual images, which is handled by the dataset class.

        "updateStateFun" is called after every batch with the number of
        images processed so far.

        Returns the result of "exportModelState" for the trained model.
    '''
    # model and the output size it produces for the configured input size
    model, labelclassMap = self.initializeModel(stateDict, data)
    inputSize = tuple(self.options['general']['image_size'])
    targetSize = model.getOutputSize(inputSize)

    # data pipeline: transform -> dataset -> encoder/collator -> loader
    trainOpts = self.options['train']
    transform = parse_transforms(trainOpts['transform'])
    dataset = self.dataset_class(data, self.fileServer, labelclassMap,
                                 transform, trainOpts['ignore_unsure'],
                                 **self.options['dataset']['kwargs'])
    numClasses = len(labelclassMap.keys())
    dataEncoder = encoder.DataEncoder(numClasses)
    collator = collation.Collator(self.project, self.dbConnector, targetSize,
                                  dataEncoder)
    loaderKwargs = trainOpts['dataLoader']['kwargs']
    dataLoader = DataLoader(dataset=dataset,
                            collate_fn=collator.collate_fn,
                            **loaderKwargs)

    # optimizer
    optimOpts = trainOpts['optim']
    optimizer_class = get_class_executable(optimOpts['class'])
    optimizer = optimizer_class(params=model.parameters(),
                                **optimOpts['kwargs'])

    # loss criterion
    criterionOpts = trainOpts['criterion']
    criterion_class = get_class_executable(criterionOpts['class'])
    criterion = criterion_class(**criterionOpts['kwargs'])

    # seed the random number generators and move the model to the device
    device = self.get_device()
    seed = self.options['general']['seed']
    torch.manual_seed(seed)
    if 'cuda' in device:
        torch.cuda.manual_seed(seed)
    model.to(device)

    # training loop
    numProcessed = 0
    for (images, locs_target, cls_images, _fVec, _) in tqdm(dataLoader):
        images = images.to(device)
        locs_target = locs_target.to(device)
        cls_images = cls_images.to(device)

        optimizer.zero_grad()
        locs_pred = model(images)
        loss_value = criterion(locs_pred, locs_target, cls_images)
        loss_value.backward()
        optimizer.step()

        # report progress to the worker
        numProcessed += images.size(0)
        updateStateFun(state='PROGRESS',
                       message='training',
                       done=numProcessed,
                       total=len(dataLoader.dataset))

    # finished; hand back the exported model state
    return self.exportModelState(model)
# setup
# NOTE(review): "args" is defined outside of this fragment; it has to
# provide "settings_filepath" and "modelPath".
print('Setup...')
if 'AIDE_CONFIG_PATH' not in os.environ:
    os.environ['AIDE_CONFIG_PATH'] = str(args.settings_filepath)

# imported only now, after the config path environment variable is set
# (NOTE(review): presumably because Config reads it -- confirm)
from util.configDef import Config
from modules.Database.app import Database

config = Config()
dbConn = Database(config)

# load model class function
print('Load and verify state dict...')
from util.helpers import get_class_executable, current_time

modelClass = getattr(
    get_class_executable(config.getProperty('AIController',
                                            'model_lib_path')),
    'model_class')

# load state dict; the context manager makes sure the file is closed again
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only use this with model files from a trusted source.
with open(args.modelPath, 'rb') as stateDictFile:
    stateDict = torch.load(stateDictFile)

# verify model state
model = modelClass.loadFromStateDict(stateDict)

# load class definitions from database
classdef_db = {}
labelClasses = dbConn.execute(
    'SELECT * FROM {schema}.labelclass;'.format(
        schema=config.getProperty('Database', 'schema')), None, 'all')
# guard against a missing result set (no rows returned)
if labelClasses is not None:
    for lc in labelClasses:
        classdef_db[lc['id']] = lc
def __init__(self, config, dbConnector, fileServer, options):
    '''
        Resolves every entry listed under options["rank"]["heuristics"]
        to an executable via "get_class_executable" and stores the
        results, in the same order, in "self.heuristics".
    '''
    # parse provided functions
    heuristicNames = options['rank']['heuristics']
    self.heuristics = [
        get_class_executable(name) for name in heuristicNames
    ]
def train(self, stateDict, data, updateStateFun):
    '''
        Builds a model from "stateDict" and a bounding box data loader
        from "data", then runs one pass over the data loader, updating
        the model with the configured optimizer and a focal loss
        criterion. All parameters are read from the JSON-enhanced
        options of this instance.

        "updateStateFun" is called after every batch with the number of
        images processed so far.

        Raises an Exception if the model predictions contain Inf or NaN
        values.

        Returns the result of "exportModelState" for the trained model.
    '''
    getValue = optionsHelper.get_hierarchical_value

    # model
    addMissing = getValue(
        self.options,
        ['options', 'general', 'labelClasses', 'add_missing', 'value'])
    removeObsolete = getValue(
        self.options,
        ['options', 'general', 'labelClasses', 'remove_obsolete', 'value'])
    model, labelclassMap = self.initializeModel(stateDict, data, addMissing,
                                                removeObsolete)

    # input size as (width, height)
    width = int(
        getValue(self.options,
                 ['options', 'general', 'imageSize', 'width', 'value']))
    height = int(
        getValue(self.options,
                 ['options', 'general', 'imageSize', 'height', 'value']))
    inputSize = (width, height)

    # data pipeline: transform -> dataset -> encoder/collator -> loader
    transformDef = getValue(self.options,
                            ['options', 'train', 'transform', 'value'])
    transform = RetinaNet._init_transform_instances(transformDef, inputSize)

    ignoreUnsure = getValue(
        self.options,
        ['options', 'train', 'encoding', 'ignore_unsure', 'value'],
        fallback=False)
    dataset = BoundingBoxesDataset(data=data,
                                   fileServer=self.fileServer,
                                   labelclassMap=labelclassMap,
                                   targetFormat='xyxy',
                                   transform=transform,
                                   ignoreUnsure=ignoreUnsure)

    minIoU_pos = getValue(
        self.options, ['options', 'train', 'encoding', 'minIoU_pos', 'value'],
        fallback=0.5)
    maxIoU_neg = getValue(
        self.options, ['options', 'train', 'encoding', 'maxIoU_neg', 'value'],
        fallback=0.4)
    dataEncoder = encoder.DataEncoder(minIoU_pos=minIoU_pos,
                                      maxIoU_neg=maxIoU_neg)

    # the collator receives the size in reversed order (height, width)
    collator = collation.Collator(self.project, self.dbConnector,
                                  (inputSize[1], inputSize[0]), dataEncoder)

    shuffle = getValue(
        self.options, ['options', 'train', 'dataLoader', 'shuffle', 'value'],
        fallback=True)
    dataLoader = DataLoader(dataset=dataset,
                            collate_fn=collator.collate_fn,
                            shuffle=shuffle)

    # optimizer: every non-reserved entry is passed on as keyword argument
    optimArgs = getValue(self.options,
                         ['options', 'train', 'optim', 'value'], None)
    optimClass = get_class_executable(optimArgs['id'])
    optimKwargs = {
        key: getValue(entry, ['value'])
        for key, entry in optimArgs.items()
        if key not in optionsHelper.RESERVED_KEYWORDS
    }
    optimizer = optimClass(params=model.parameters(), **optimKwargs)

    # loss criterion: same treatment of the options as for the optimizer
    critArgs = getValue(self.options, ['options', 'train', 'criterion'], None)
    critKwargs = {
        key: getValue(entry, ['value'])
        for key, entry in critArgs.items()
        if key not in optionsHelper.RESERVED_KEYWORDS
    }
    criterion = loss.FocalLoss(**critKwargs)

    # seed the random number generators and move the model to the device
    device = self.get_device()
    seed = int(
        getValue(self.options, ['options', 'general', 'seed', 'value'],
                 fallback=0))
    torch.manual_seed(seed)
    if 'cuda' in device:
        torch.cuda.manual_seed(seed)
    model.to(device)

    # training loop
    numProcessed = 0
    for (images, bboxes_target, labels_target, _fVec,
         _) in tqdm(dataLoader):
        images = images.to(device)
        bboxes_target = bboxes_target.to(device)
        labels_target = labels_target.to(device)

        optimizer.zero_grad()
        bboxes_pred, labels_pred = model(images)
        loss_value = criterion(bboxes_pred, bboxes_target, labels_pred,
                               labels_target)
        loss_value.backward()
        optimizer.step()

        # check for Inf and NaN values and raise exception if needed
        predictions = (bboxes_pred, labels_pred)
        hasInf = [torch.any(torch.isinf(p)).item() for p in predictions]
        hasNaN = [torch.any(torch.isnan(p)).item() for p in predictions]
        if any(hasInf + hasNaN):
            raise Exception(
                'Model produced Inf and/or NaN values; training was aborted. Try reducing the learning rate.'
            )

        # report progress to the worker
        numProcessed += images.size(0)
        updateStateFun(state='PROGRESS',
                       message='training',
                       done=numProcessed,
                       total=len(dataLoader.dataset))

    # finished; hand back the exported model state
    return self.exportModelState(model)