Example #1
def _get_transform_executable(className):
    #TODO: dirty hack to be able to import custom transform functions...
    tokens = className.split('.')
    if tokens[-2] == 'labels':
        return getattr(LabelsTransforms, tokens[-1])
    elif tokens[-2] == 'points':
        return getattr(PointsTransforms, tokens[-1])
    elif tokens[-2] == 'boundingBoxes':
        return getattr(BoundingBoxesTransforms, tokens[-1])
    else:
        return get_class_executable(className)
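
A hedged usage sketch of the dispatcher above; both import paths are hypothetical and only illustrate how the second-to-last dot token selects the branch:

# The 'labels' token routes to LabelsTransforms; any other token falls
# through to the generic get_class_executable() import. Both paths below
# are placeholders, not names from the source.
transform_cls = _get_transform_executable(
    'ai.models.pytorch.functional.datasets.labels.SomeTransform')
generic_cls = _get_transform_executable(
    'torchvision.transforms.RandomHorizontalFlip')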
Example #2
    def __init__(self, project, config, dbConnector, fileServer, options):
        super(GenericPyTorchModel, self).__init__(project, config, dbConnector,
                                                  fileServer, options)

        # try to fill and substitute global definitions in JSON-enhanced options
        if isinstance(options, dict) and 'defs' in options:
            try:
                updatedOptions = optionsHelper.substitute_definitions(
                    options.copy())
                self.options = updatedOptions
            except Exception:
                # something went wrong; ignore
                pass

        # retrieve executables
        try:
            self.model_class = get_class_executable(
                optionsHelper.get_hierarchical_value(
                    self.options, ['options', 'model', 'class']))
        except Exception:
            self.model_class = None
        try:
            self.criterion_class = get_class_executable(
                optionsHelper.get_hierarchical_value(
                    self.options, ['options', 'train', 'criterion', 'class']))
        except Exception:
            self.criterion_class = None
        try:
            self.optim_class = get_class_executable(
                optionsHelper.get_hierarchical_value(
                    self.options, ['options', 'train', 'optim']))
        except Exception:
            self.optim_class = SGD
        try:
            self.dataset_class = get_class_executable(
                optionsHelper.get_hierarchical_value(self.options,
                                                     ['options', 'dataset']))
        except Exception:
            self.dataset_class = None
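
To make the hierarchical lookups above concrete, here is a hedged sketch of the options layout this constructor appears to expect; the keys are inferred from the get_hierarchical_value calls, while the class paths are illustrative placeholders:

# Assumed options skeleton; note that 'optim' and 'dataset' hold the
# import path directly, whereas 'model' and 'criterion' nest it under
# a 'class' key (mirroring the lookups above).
options = {
    'defs': {},  # global definitions substituted by optionsHelper
    'options': {
        'model': {'class': 'some.module.ModelClass'},
        'train': {
            'criterion': {'class': 'torch.nn.CrossEntropyLoss'},
            'optim': 'torch.optim.SGD',
        },
        'dataset': 'some.module.DatasetClass',
    },
}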
Example #3
    def __init__(self,
                 project,
                 config,
                 dbConnector,
                 fileServer,
                 options,
                 defaultOptions=None):
        super(GenericPyTorchModel_Legacy,
              self).__init__(project, config, dbConnector, fileServer, options)

        # parse the options and compare with the provided defaults (if provided)
        if defaultOptions is not None:
            self.options = check_args(self.options, defaultOptions)
        else:
            self.options = options

        # retrieve executables
        try:
            self.model_class = get_class_executable(
                self.options['model']['class'])
        except Exception:
            self.model_class = None
        try:
            self.criterion_class = get_class_executable(
                self.options['train']['criterion']['class'])
        except Exception:
            self.criterion_class = None
        try:
            self.optim_class = get_class_executable(
                self.options['train']['optim']['class'])
        except Exception:
            self.optim_class = SGD
        try:
            self.dataset_class = get_class_executable(
                self.options['dataset']['class'])
        except Exception:
            self.dataset_class = None
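
For contrast with Example #2, a hedged sketch of the flat, legacy options layout this constructor reads; the class paths are again placeholders:

# Assumed legacy options skeleton: every executable sits under a
# 'class' key, without the enclosing 'options' level used above.
options = {
    'model': {'class': 'some.module.ModelClass'},
    'train': {
        'criterion': {'class': 'torch.nn.CrossEntropyLoss'},
        'optim': {'class': 'torch.optim.SGD'},
    },
    'dataset': {'class': 'some.module.DatasetClass'},
}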
Example #4
    def _import_model_state_file(self,
                                 project,
                                 modelURI,
                                 modelDefinition,
                                 stateDict=None,
                                 public=True,
                                 anonymous=False,
                                 namePolicy='skip',
                                 customName=None):
        '''
            Receives two files:
            - "modelDefinition": JSON-encoded metadata file
            - "stateDict" (optional): BytesIO file containing model state
            
            Parses the files for completeness and integrity and attempts to launch 
            a model with them. If all checks are successful, the new model state is
            inserted into the database and the UUID is returned.
            Parameter "namePolicy" can have one of three values:
                - "skip" (default): skips model import if another with the same name already
                                    exists in the Model Marketplace database
                - "increment":      appends or increments a number at the end of the name if
                                    it is already found in the database
                - "custom":         uses the value under "customName"
        '''
        # check inputs
        if not isinstance(modelDefinition, dict):
            try:
                if isinstance(modelDefinition, bytes):
                    modelDefinition = json.load(io.BytesIO(modelDefinition))
                elif isinstance(modelDefinition, str):
                    modelDefinition = json.loads(modelDefinition)
            except Exception as e:
                raise Exception(
                    f'Invalid model state definition file (message: "{e}").')

        if stateDict is not None and not isinstance(stateDict, bytes):
            raise Exception(
                'Invalid model state dict provided (not a binary file).')

        # check naming policy
        namePolicy = str(namePolicy).lower()
        assert namePolicy in ('skip', 'increment', 'custom'), \
            f'Invalid naming policy "{namePolicy}".'

        if namePolicy == 'skip':
            # check if model with same name already exists
            modelID = self.getModelIdByName(modelDefinition['name'])
            if modelID is not None:
                return modelID

        elif namePolicy == 'custom':
            assert isinstance(customName, str) and len(customName), \
                'Invalid custom name provided.'

            # check if name is available
            if self.getModelIdByName(customName) is not None:
                raise Exception(f'Custom name "{customName}" is unavailable.')

        # project metadata
        projectMeta = self.dbConnector.execute(
            '''
            SELECT annotationType, predictionType
            FROM aide_admin.project
            WHERE shortname = %s;
        ''', (project, ), 1)
        if projectMeta is None or not len(projectMeta):
            raise Exception(
                f'Project with shortname "{project}" not found in database.')
        projectMeta = projectMeta[0]

        # check fields
        for field in ('author', 'citation_info', 'license'):
            if field not in modelDefinition:
                modelDefinition[field] = None

        for field in self.MODEL_STATE_REQUIRED_FIELDS:
            if field not in modelDefinition:
                raise Exception(f'Missing field "{field}" in AIDE JSON file.')
            if field == 'aide_model_version':
                # check model definition version; str.isnumeric() would
                # reject fractional versions such as "1.5", so attempt a
                # float conversion directly
                modelVersion = modelDefinition['aide_model_version']
                try:
                    modelVersion = float(modelVersion)
                except (TypeError, ValueError):
                    raise Exception(
                        f'Invalid AIDE model version "{modelVersion}" in JSON file.'
                    )
                if modelVersion > self.MAX_AIDE_MODEL_VERSION:
                    raise Exception(
                        f'Model state contains a newer model version than supported by this installation of AIDE ({modelVersion} > {self.MAX_AIDE_MODEL_VERSION}).\nPlease update AIDE to the latest version.'
                    )

            if field == 'ai_model_library':
                # check if model library is installed
                modelLibrary = modelDefinition[field]
                if modelLibrary not in PREDICTION_MODELS:
                    raise Exception(
                        f'Model library "{modelLibrary}" is not installed in this instance of AIDE.'
                    )
                # check if annotation and prediction types match
                if projectMeta['annotationtype'] not in PREDICTION_MODELS[
                        modelLibrary]['annotationType']:
                    raise Exception(
                        'Project\'s annotation type is not compatible with this model state.'
                    )
                if projectMeta['predictiontype'] not in PREDICTION_MODELS[
                        modelLibrary]['predictionType']:
                    raise Exception(
                        'Project\'s prediction type is not compatible with this model state.'
                    )

        # check if model state URI provided
        if stateDict is None and 'ai_model_state_uri' in modelDefinition:
            stateDictURI = modelDefinition['ai_model_state_uri']
            try:
                if stateDictURI.lower().startswith('aide://'):
                    # load from disk
                    stateDictPath = stateDictURI.replace('aide://',
                                                         '').strip('/')
                    if not os.path.isfile(stateDictPath):
                        raise Exception(
                            f'Model state file path provided ("{stateDictPath}"), but file could not be found.'
                        )
                    with open(stateDictPath, 'rb') as f:
                        stateDict = f.read()  #TODO: BytesIO instead

                else:
                    # network import
                    with request.urlopen(stateDictURI) as f:
                        stateDict = f.read()  #TODO: progress bar; load in chunks; etc.

            except Exception as e:
                raise Exception(
                    f'Model state URI provided ("{stateDictURI}"), but could not be loaded (message: "{str(e)}").'
                )

        # model name
        modelName = modelDefinition['name']
        if namePolicy == 'increment':
            # check if model name needs to be incremented
            allNames = self.dbConnector.execute(
                'SELECT name FROM "aide_admin".modelmarketplace;', None, 'all')
            allNames = (set([a['name'].strip() for a in allNames])
                        if allNames is not None else set())
            if modelName.strip() in allNames:
                startIdx = 1
                insertPos = len(modelName)
                trailingNumber = re.findall(r' \d+$', modelName.strip())
                if len(trailingNumber):
                    startIdx = int(trailingNumber[0])
                    insertPos = modelName.rfind(str(startIdx)) - 1
                while modelName.strip() in allNames:
                    modelName = modelName[:insertPos] + f' {startIdx}'
                    startIdx += 1

        elif namePolicy == 'custom':
            modelName = customName

        # remaining parameters
        modelAuthor = modelDefinition['author']
        modelDescription = (modelDefinition['description']
                            if 'description' in modelDefinition else None)
        modelTags = (';;'.join(modelDefinition['tags'])
                     if 'tags' in modelDefinition else None)
        labelClasses = modelDefinition['labelclasses']  #TODO: parse?
        if not isinstance(labelClasses, str):
            labelClasses = json.dumps(labelClasses)
        modelOptions = (modelDefinition['ai_model_settings']
                        if 'ai_model_settings' in modelDefinition else None)
        modelLibrary = modelDefinition['ai_model_library']
        alCriterion_library = (modelDefinition['alcriterion_library']
                               if 'alcriterion_library' in modelDefinition else
                               None)
        annotationType = PREDICTION_MODELS[modelLibrary][
            'annotationType']  #TODO
        predictionType = PREDICTION_MODELS[modelLibrary][
            'predictionType']  #TODO
        citationInfo = modelDefinition['citation_info']
        license = modelDefinition['license']
        if not isinstance(annotationType, str):
            annotationType = ','.join(annotationType)
        if not isinstance(predictionType, str):
            predictionType = ','.join(predictionType)
        timeCreated = (modelDefinition['time_created']
                       if 'time_created' in modelDefinition else None)
        try:
            timeCreated = datetime.fromtimestamp(timeCreated)
        except Exception:
            timeCreated = current_time()

        # try to launch model with data
        try:
            modelClass = get_class_executable(modelLibrary)
            modelClass(project=project,
                       config=self.config,
                       dbConnector=self.dbConnector,
                       fileServer=FileServer(
                           self.config).get_secure_instance(project),
                       options=modelOptions)

            # verify options
            if modelOptions is not None:
                try:
                    optionMeta = modelClass.verifyOptions(modelOptions)
                    if 'options' in optionMeta:
                        modelOptions = optionMeta['options']
                    #TODO: parse warnings and errors
                except Exception:
                    # model library does not support option verification
                    pass
                if isinstance(modelOptions, dict):
                    modelOptions = json.dumps(modelOptions)

        except Exception as e:
            raise Exception(
                f'Model from imported state could not be launched (message: "{str(e)}").'
            )

        # import model state into Model Marketplace
        success = self.dbConnector.execute(
            '''
            INSERT INTO aide_admin.modelMarketplace
                (name, description, tags, labelclasses, author, statedict,
                model_library, model_settings, alCriterion_library,
                annotationType, predictionType,
                citation_info, license,
                timeCreated,
                origin_project, origin_uuid, origin_uri, public, anonymous)
            VALUES %s
            RETURNING id;
        ''', [(modelName, modelDescription, modelTags, labelClasses,
               modelAuthor, stateDict, modelLibrary, modelOptions,
               alCriterion_library, annotationType, predictionType,
               citationInfo, license, timeCreated, project, None, modelURI,
               public, anonymous)], 1)
        if success is None or not len(success):

            #TODO: temporary fix to get ID: try again by re-querying DB
            success = self.dbConnector.execute(
                '''
                SELECT id FROM aide_admin.modelMarketplace
                WHERE name = %s;
            ''', (modelName, ), 1)
            if success is None or not len(success):
                raise Exception(
                    'Model could not be imported into Model Marketplace.')

        # model import to Marketplace successful; now import to project
        return success[0]['id']
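
A hedged usage sketch of the import routine above, exercising the "increment" naming policy from the docstring; the file names and the marketplace object holding the method are assumptions for illustration:

# Hypothetical caller; 'marketplace' stands in for whatever middleware
# instance exposes _import_model_state_file in the real code base.
with open('model_definition.json', 'rb') as f_def, \
        open('model_state.pt', 'rb') as f_state:
    model_id = marketplace._import_model_state_file(
        project='my_project',
        modelURI='https://example.com/models/my_model',
        modelDefinition=f_def.read(),   # bytes are parsed as JSON
        stateDict=f_state.read(),       # must be bytes (or None for URI import)
        namePolicy='increment')         # 'skip' | 'increment' | 'custom'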
Example #5
def main():

    # parse arguments
    parser = argparse.ArgumentParser(description='AIDE local model tester')
    parser.add_argument('--project',
                        type=str,
                        required=True,
                        help='Project shortname to draw sample data from.')
    parser.add_argument(
        '--mode',
        type=str,
        required=True,
        help='Evaluation mode (function to call). One of {"train", "inference"}.')
    parser.add_argument(
        '--modelLibrary',
        type=str,
        required=False,
        help='Optional AI model library override. Provide a dot-separated '
             'Python import path here.')
    parser.add_argument(
        '--modelSettings',
        type=str,
        required=False,
        help='Optional AI model settings override (absolute or relative path '
             'to settings file, or else "none" to not use any predefined '
             'settings).')
    args = parser.parse_args()
    #TODO: continue

    assert args.mode.lower() in (
        'train', 'inference'), f'"{args.mode}" is not a known evaluation mode.'
    mode = args.mode.lower()

    # initialize required modules
    config = Config()
    dbConnector = Database(config)
    fileServer = FileServer(config).get_secure_instance(args.project)
    aiw = AIWorker(config, dbConnector, True)
    aicw = AIControllerWorker(config, None)

    # check if AIDE file server is reachable
    admin = AdminMiddleware(config, dbConnector)
    connDetails = admin.getServiceDetails(True, False)
    fsVersion = connDetails['FileServer']['aide_version']
    if not isinstance(fsVersion, str):
        # no file server running
        raise Exception(
            'ERROR: AIDE file server is not running, but required for running models. Make sure to launch it prior to running this script.'
        )
    elif fsVersion != AIDE_VERSION:
        print(
            f'WARNING: the AIDE version of File Server instance ({fsVersion}) differs from this one ({AIDE_VERSION}).'
        )

    # get model trainer instance and settings
    queryStr = '''
        SELECT ai_model_library, ai_model_settings FROM aide_admin.project
        WHERE shortname = %s;
    '''
    result = dbConnector.execute(queryStr, (args.project, ), 1)
    if result is None or not len(result):
        raise Exception(
            f'Project "{args.project}" could not be found in this installation of AIDE.'
        )

    modelLibrary = result[0]['ai_model_library']
    modelSettings = result[0]['ai_model_settings']

    customSettingsSpecified = False
    if hasattr(args, 'modelSettings') and isinstance(
            args.modelSettings, str) and len(args.modelSettings):
        # settings override specified
        if args.modelSettings.lower() == 'none':
            modelSettings = None
            customSettingsSpecified = True
        elif not os.path.isfile(args.modelSettings):
            print(
                f'WARNING: model settings override provided, but file cannot be found ("{args.modelSettings}"). Falling back to project default ("{modelSettings}").'
            )
        else:
            modelSettings = args.modelSettings
            customSettingsSpecified = True

    if hasattr(args, 'modelLibrary') and isinstance(
            args.modelLibrary, str) and len(args.modelLibrary):
        # library override specified; try to import it
        try:
            modelClass = helpers.get_class_executable(args.modelLibrary)
            if modelClass is None:
                # get_class_executable returned None; raise explicitly to
                # trigger the fallback in the except branch below
                raise ImportError(args.modelLibrary)
            modelLibrary = args.modelLibrary

            # re-check if current model settings are compatible; warn and set to None if not
            if modelLibrary != result[0][
                    'ai_model_library'] and not customSettingsSpecified:
                # project model settings are not compatible with provided model
                print(
                    'WARNING: custom model library specified differs from the one currently set in project. Model settings will be set to None.'
                )
                modelSettings = None

        except Exception:
            print(
                f'WARNING: model library override provided ("{args.modelLibrary}"), but could not be imported. Falling back to project default ("{modelLibrary}").'
            )

    # initialize instance
    print(f'Using model library "{modelLibrary}".')
    modelTrainer = aiw._init_model_instance(args.project, modelLibrary,
                                            modelSettings)

    stateDict = None  #TODO: load latest unless override is specified?

    # get data
    data = aicw.get_training_images(project=args.project, maxNumImages=512)
    data = __load_metadata(args.project, dbConnector, data[0],
                           (mode == 'train'))

    # helper functions
    def updateStateFun(state, message, done=None, total=None):
        print(message, end='')
        if done is not None and total is not None:
            print(f': {done}/{total}')
        else:
            print('')

    # launch task
    if mode == 'train':
        result = modelTrainer.train(stateDict, data, updateStateFun)
        if result is None:
            raise Exception(
                'Training function must return an object (i.e., trained model state) to be stored in the database.'
            )

    elif mode == 'inference':
        result = modelTrainer.inference(stateDict, data, updateStateFun)
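
A hedged invocation sketch; the script file name model_tester.py is an assumption, while the flags come from the argparse definitions above:

# Train on sample data from project "my_project", overriding the model
# library and suppressing any predefined settings:
python model_tester.py --project my_project --mode train \
    --modelLibrary some.module.ModelClass --modelSettings none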
Example #6
    def train(self, stateDict, data, updateStateFun):
        '''
            Initializes a model based on the given stateDict and a data loader
            from the provided data and trains the model, taking into account
            the parameters specified in the 'options' given to the class.
            Returns a serializable state dict of the resulting model.
        '''

        # initialize model
        model, labelclassMap = self.initializeModel(stateDict, data)

        # setup transform, data loader, dataset, optimizer, criterion
        transform = parse_transforms(self.options['train']['transform'])

        dataset = self.dataset_class(data, self.fileServer, labelclassMap,
                                     transform,
                                     self.options['train']['ignore_unsure'],
                                     **self.options['dataset']['kwargs'])

        collator = Collator(self.project, self.dbConnector)
        dataLoader = DataLoader(
            dataset,
            collate_fn=collator.collate,
            **self.options['train']['dataLoader']['kwargs'])

        optimizer_class = get_class_executable(
            self.options['train']['optim']['class'])
        optimizer = optimizer_class(params=model.parameters(),
                                    **self.options['train']['optim']['kwargs'])

        criterion_class = get_class_executable(
            self.options['train']['criterion']['class'])
        criterion = criterion_class(
            **self.options['train']['criterion']['kwargs'])

        # train model
        device = self.get_device()
        torch.manual_seed(self.options['general']['seed'])
        if 'cuda' in device:
            torch.cuda.manual_seed(self.options['general']['seed'])

        model.to(device)
        imgCount = 0
        for (img, labels, fVec, _) in tqdm(dataLoader):
            img, labels = img.to(device), labels.to(device)

            optimizer.zero_grad()
            pred = model(img)
            loss_value = criterion(pred, labels)
            loss_value.backward()
            optimizer.step()

            # update worker state
            imgCount += img.size(0)
            updateStateFun(state='PROGRESS',
                           message='training',
                           done=imgCount,
                           total=len(dataLoader.dataset))

        # all done; return state dict as bytes
        return self.exportModelState(model)
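
A hedged sketch of the options subtree this train() consumes; the keys are inferred from the lookups above and the concrete values are placeholders:

# Assumed options skeleton for the classification train() above.
options = {
    'general': {'seed': 0},
    'dataset': {'kwargs': {}},
    'train': {
        'transform': {},  # structure handled by parse_transforms()
        'ignore_unsure': True,
        'dataLoader': {'kwargs': {'batch_size': 32, 'shuffle': True}},
        'optim': {'class': 'torch.optim.SGD', 'kwargs': {'lr': 1e-3}},
        'criterion': {'class': 'torch.nn.CrossEntropyLoss', 'kwargs': {}},
    },
}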
Example #7
    def train(self, stateDict, data, updateStateFun):
        '''
            Initializes a model based on the given stateDict and a data loader
            from the provided data and trains the model, taking into account
            the parameters specified in the 'options' given to the class.
            Trains the model in either mode (fully or weakly supervised),
            depending on the individual images' annotation types; this is
            handled by the dataset class.
            Returns a serializable state dict of the resulting model.
        '''

        # initialize model
        model, labelclassMap = self.initializeModel(stateDict, data)

        inputSize = tuple(self.options['general']['image_size'])
        targetSize = model.getOutputSize(inputSize)

        # setup transform, data loader, dataset, optimizer, criterion
        transform = parse_transforms(self.options['train']['transform'])

        dataset = self.dataset_class(data, self.fileServer, labelclassMap,
                                     transform,
                                     self.options['train']['ignore_unsure'],
                                     **self.options['dataset']['kwargs'])

        dataEncoder = encoder.DataEncoder(len(labelclassMap.keys()))
        collator = collation.Collator(self.project, self.dbConnector,
                                      targetSize, dataEncoder)
        dataLoader = DataLoader(
            dataset=dataset,
            collate_fn=collator.collate_fn,
            **self.options['train']['dataLoader']['kwargs'])

        # optimizer
        optimizer_class = get_class_executable(
            self.options['train']['optim']['class'])
        optimizer = optimizer_class(params=model.parameters(),
                                    **self.options['train']['optim']['kwargs'])

        # loss criterion
        criterion_class = get_class_executable(
            self.options['train']['criterion']['class'])
        criterion = criterion_class(
            **self.options['train']['criterion']['kwargs'])

        # train model
        device = self.get_device()
        torch.manual_seed(self.options['general']['seed'])
        if 'cuda' in device:
            torch.cuda.manual_seed(self.options['general']['seed'])

        model.to(device)
        imgCount = 0
        for (img, locs_target, cls_images, fVec, _) in tqdm(dataLoader):
            img, locs_target, cls_images = img.to(device), \
                                                locs_target.to(device), \
                                                cls_images.to(device)

            optimizer.zero_grad()
            locs_pred = model(img)
            loss_value = criterion(locs_pred, locs_target, cls_images)
            loss_value.backward()
            optimizer.step()

            # update worker state
            imgCount += img.size(0)
            updateStateFun(state='PROGRESS',
                           message='training',
                           done=imgCount,
                           total=len(dataLoader.dataset))

        # all done; return state dict as bytes
        return self.exportModelState(model)
Example #8
    # setup
    print('Setup...')
    if 'AIDE_CONFIG_PATH' not in os.environ:
        os.environ['AIDE_CONFIG_PATH'] = str(args.settings_filepath)

    from util.configDef import Config
    from modules.Database.app import Database
    config = Config()
    dbConn = Database(config)

    # load model class function
    print('Load and verify state dict...')
    from util.helpers import get_class_executable, current_time
    modelClass = getattr(
        get_class_executable(
            config.getProperty('AIController', 'model_lib_path')),
        'model_class')

    # load state dict
    with open(args.modelPath, 'rb') as f:
        stateDict = torch.load(f)

    # verify model state
    model = modelClass.loadFromStateDict(stateDict)

    # load class definitions from database
    classdef_db = {}
    labelClasses = dbConn.execute(
        'SELECT * FROM {schema}.labelclass;'.format(
            schema=config.getProperty('Database', 'schema')), None, 'all')
    for lc in labelClasses:
        classdef_db[lc['id']] = lc
Example #9
    def __init__(self, config, dbConnector, fileServer, options):

        # parse provided functions
        self.heuristics = []
        for h in options['rank']['heuristics']:
            self.heuristics.append(get_class_executable(h))
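
A hedged sketch of the options fragment this constructor reads; the heuristic import paths are illustrative placeholders:

# Assumed options fragment: 'rank' -> 'heuristics' holds a list of
# dot-separated import paths, each resolved via get_class_executable().
options = {
    'rank': {
        'heuristics': [
            'some.module.maxConfidence',
            'some.module.breakingTies',
        ],
    },
}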
Example #10
    def train(self, stateDict, data, updateStateFun):
        '''
            Initializes a model based on the given stateDict and a data loader
            from the provided data and trains the model, taking into account
            the parameters specified in the 'options' given to the class.
            Returns a serializable state dict of the resulting model.
        '''
        # initialize model
        model, labelclassMap = self.initializeModel(
            stateDict, data,
            optionsHelper.get_hierarchical_value(
                self.options,
                ['options', 'general', 'labelClasses', 'add_missing', 'value']),
            optionsHelper.get_hierarchical_value(
                self.options,
                ['options', 'general', 'labelClasses', 'remove_obsolete', 'value']))

        # setup transform, data loader, dataset, optimizer, criterion
        inputSize = (
            int(optionsHelper.get_hierarchical_value(
                self.options,
                ['options', 'general', 'imageSize', 'width', 'value'])),
            int(optionsHelper.get_hierarchical_value(
                self.options,
                ['options', 'general', 'imageSize', 'height', 'value'])))

        transform = RetinaNet._init_transform_instances(
            optionsHelper.get_hierarchical_value(
                self.options, ['options', 'train', 'transform', 'value']),
            inputSize)

        dataset = BoundingBoxesDataset(
            data=data,
            fileServer=self.fileServer,
            labelclassMap=labelclassMap,
            targetFormat='xyxy',
            transform=transform,
            ignoreUnsure=optionsHelper.get_hierarchical_value(
                self.options,
                ['options', 'train', 'encoding', 'ignore_unsure', 'value'],
                fallback=False))

        dataEncoder = encoder.DataEncoder(
            minIoU_pos=optionsHelper.get_hierarchical_value(
                self.options,
                ['options', 'train', 'encoding', 'minIoU_pos', 'value'],
                fallback=0.5),
            maxIoU_neg=optionsHelper.get_hierarchical_value(
                self.options,
                ['options', 'train', 'encoding', 'maxIoU_neg', 'value'],
                fallback=0.4))
        collator = collation.Collator(self.project, self.dbConnector,
                                      (inputSize[1], inputSize[0]),
                                      dataEncoder)
        dataLoader = DataLoader(
            dataset=dataset,
            collate_fn=collator.collate_fn,
            shuffle=optionsHelper.get_hierarchical_value(
                self.options,
                ['options', 'train', 'dataLoader', 'shuffle', 'value'],
                fallback=True))

        # optimizer
        optimArgs = optionsHelper.get_hierarchical_value(
            self.options, ['options', 'train', 'optim', 'value'], None)
        optimArgs_out = {}
        optimClass = get_class_executable(optimArgs['id'])
        for key in optimArgs.keys():
            if key not in optionsHelper.RESERVED_KEYWORDS:
                optimArgs_out[key] = optionsHelper.get_hierarchical_value(
                    optimArgs[key], ['value'])
        optimizer = optimClass(params=model.parameters(), **optimArgs_out)

        # loss criterion
        critArgs = optionsHelper.get_hierarchical_value(
            self.options, ['options', 'train', 'criterion'], None)
        critArgs_out = {}
        for key in critArgs.keys():
            if key not in optionsHelper.RESERVED_KEYWORDS:
                critArgs_out[key] = optionsHelper.get_hierarchical_value(
                    critArgs[key], ['value'])
        criterion = loss.FocalLoss(**critArgs_out)

        # train model
        device = self.get_device()
        seed = int(
            optionsHelper.get_hierarchical_value(
                self.options, ['options', 'general', 'seed', 'value'],
                fallback=0))
        torch.manual_seed(seed)
        if 'cuda' in device:
            torch.cuda.manual_seed(seed)
        model.to(device)
        imgCount = 0
        for (img, bboxes_target, labels_target, fVec, _) in tqdm(dataLoader):
            img, bboxes_target, labels_target = img.to(device), \
                                                bboxes_target.to(device), \
                                                labels_target.to(device)

            optimizer.zero_grad()
            bboxes_pred, labels_pred = model(img)
            loss_value = criterion(bboxes_pred, bboxes_target, labels_pred,
                                   labels_target)
            loss_value.backward()
            optimizer.step()

            # check for Inf and NaN values and raise exception if needed
            if any([
                    torch.any(torch.isinf(bboxes_pred)).item(),
                    torch.any(torch.isinf(labels_pred)).item(),
                    torch.any(torch.isnan(bboxes_pred)).item(),
                    torch.any(torch.isnan(labels_pred)).item()
            ]):
                raise Exception(
                    'Model produced Inf and/or NaN values; training was aborted. Try reducing the learning rate.'
                )

            # update worker state
            imgCount += img.size(0)
            updateStateFun(state='PROGRESS',
                           message='training',
                           done=imgCount,
                           total=len(dataLoader.dataset))

        # all done; return state dict as bytes
        return self.exportModelState(model)
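
For contrast with the flat layouts of the earlier examples, a hedged sketch of the GUI-style options tree these get_hierarchical_value calls traverse; every leaf wraps its setting in a 'value' key, and the optimizer node additionally carries its import path under 'id' (all concrete values are placeholders):

# Assumed GUI-style options tree for the RetinaNet train() above.
options = {
    'options': {
        'general': {
            'seed': {'value': 0},
            'imageSize': {'width': {'value': 800}, 'height': {'value': 600}},
            'labelClasses': {
                'add_missing': {'value': True},
                'remove_obsolete': {'value': False},
            },
        },
        'train': {
            'transform': {'value': []},
            'encoding': {
                'ignore_unsure': {'value': False},
                'minIoU_pos': {'value': 0.5},
                'maxIoU_neg': {'value': 0.4},
            },
            'dataLoader': {'shuffle': {'value': True}},
            'optim': {'value': {'id': 'torch.optim.SGD',
                                'lr': {'value': 1e-3}}},
            'criterion': {'gamma': {'value': 2.0}},  # consumed by FocalLoss
        },
    },
}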