def load_modelinfo(cfg, api_model_key):
    """Assemble the model-info record for the model identified by `api_model_key`.

    The model configuration is fetched with `get_modelcfg`. The module named by
    its 'dnnarch' entry supplies `load_model_and_weights`, `detect`,
    `detect_with_json` and `detect_batch`.

    Args:
        cfg: application configuration, passed to `get_modelcfg` and to the
            architecture module's `load_model_and_weights`.
        api_model_key: key used to look up the model configuration.

    Returns:
        dict with the keys 'API_MODEL_KEY', 'MODEL', 'MODELCFG', 'DETECT',
        'DETECT_WITH_JSON', 'DETECT_BATCH' and 'DNNARCH'. All values are None
        when no model configuration is found or when any step raises; the
        failure is logged with its traceback and is not re-raised.
    """
    log.info("----------------------------->")
    log.info("api_model_key: {}".format(api_model_key))

    modelinfo = dict.fromkeys((
        'API_MODEL_KEY',
        'MODEL',
        'MODELCFG',
        'DETECT',
        'DETECT_WITH_JSON',
        'DETECT_BATCH',
        'DNNARCH',
    ))

    try:
        modelcfg = get_modelcfg(cfg, api_model_key)
        log.debug("modelcfg: {}".format(modelcfg))

        if not modelcfg:
            raise Exception("No modelinfo found for the criteria!")

        run_mode = modelcfg['mode']
        arch_name = modelcfg['dnnarch']
        arch_module = apputil.get_module(arch_name)

        loader = apputil.get_module_fn(arch_module, "load_model_and_weights")
        loaded_model = loader(run_mode, modelcfg, cfg)
        log.info("model: {}".format(loaded_model))

        fn_detect = apputil.get_module_fn(arch_module, "detect")
        fn_detect_with_json = apputil.get_module_fn(arch_module, "detect_with_json")
        fn_detect_batch = apputil.get_module_fn(arch_module, "detect_batch")

        ## the record is filled only after every lookup succeeded, so a
        ## failure above leaves all values as None
        modelinfo.update({
            'API_MODEL_KEY': api_model_key,
            'MODEL': loaded_model,
            'MODELCFG': modelcfg,
            'DETECT': fn_detect,
            'DETECT_WITH_JSON': fn_detect_with_json,
            'DETECT_BATCH': fn_detect_batch,
            'DNNARCH': arch_name,
        })
    except Exception:
        msg = "'Not a Valid Model or error in loading model and weights'"
        log.error("Exception occurred: {}".format(msg), exc_info=True)

    return modelinfo
def detect_from_images(appcfg, dnnmod, images, path, model, class_names, cmdcfg, api_model_key, show_bbox=False):
    """Run detection on `images` through the `_create_res` coroutine.

    A timestamped `predict-<ddmmyy_HHMMSS>` directory is created under `path`
    with the sub-directories 'splash', 'mask', 'annotations', 'viz', 'mmask'
    and 'oframe'. `_create_res` is then driven to completion on a private
    event loop.

    Convention:
      image - image filename
      filepath - the absolute path of the image input file location
      im - binary data after reading the image file

    TODO:
      1. Prediction details log:
        - model details (path), copy of configuration, arch used, all class_names used in predictions, execution time etc.
      2. Verify that masks are properly scaled to the original image dimensions
      3. Impact on prediction of replacing skimage.io.imread with imread wrapper
      4. call response providing the pointer to the saved files
      5. viz from jsonres
      6. memory leak in reading image as read time increases
      7. async file and DB operation. MongoDB limit of 16 MB datasize

    Args:
        appcfg: application configuration; `appcfg['APP']['DBCFG']['CBIRCFG']` must exist.
        dnnmod: architecture module providing the `detect` function.
        images: image filenames, forwarded to `_create_res`.
        path: base directory; the output directory is created beneath it.
        model, class_names, cmdcfg, api_model_key: forwarded to `_create_res`.
        show_bbox: forwarded to `_create_res` as a keyword argument.

    Returns:
        A pair of empty lists `(file_names, res)`; results are not returned
        to the caller by this function.
    """
    ## always create abs filepaths and respective directories
    timestamp = "{:%d%m%y_%H%M%S}".format(datetime.datetime.now())
    filepath = os.path.join(path, "predict-" + timestamp)
    common.mkdir_p(filepath)
    for subdir in ('splash', 'mask', 'annotations', 'viz', 'mmask', 'oframe'):
        common.mkdir_p(os.path.join(filepath, subdir))

    detect = apputil.get_module_fn(dnnmod, "detect")

    ## CBIR DB settings for the MongoDB persistence, which is disabled for now
    ## (TODO item 7). The value is unused; the lookup is kept so that a missing
    ## configuration key still fails here, as it did before.
    CBIRCFG = appcfg['APP']['DBCFG']['CBIRCFG']

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(_create_res(detect, filepath, images, path, model, class_names, cmdcfg, api_model_key, show_bbox=show_bbox))
    finally:
        ## shut down and close file descriptors, also after an interrupt
        try:
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            ## always close the loop, even if the async-generator shutdown
            ## fails, and do not leave a closed loop installed as the current one
            asyncio.set_event_loop(None)
            loop.close()

    file_names = []
    res = []
    return file_names, res
def inspect_annon(args, mode, appcfg):
    """Inspect a dataset subset from the command line as a quick data sanity check.

    Loads the subset named by `args.eval_on`, copies the dataset name, class
    names and class count onto the data and architecture configurations,
    obtains the DNN config from the architecture module's `get_dnncfg`, and
    hands everything to `all_steps` of the 'inspect_annon' module.

    Args:
        args: parsed command-line arguments; only `eval_on` is read.
        mode: not used by this function.
        appcfg: application configuration.

    Returns:
        None
    """
    log.debug("---------------------------->")
    log.debug("Inspecting annotations...")

    eval_subset = args.eval_on
    log.debug("subset: {}".format(eval_subset))

    data_cfg = apputil.get_datacfg(appcfg)
    db_cfg = apputil.get_dbcfg(appcfg)

    ## the two trailing values (stats / verification totals) are not needed here
    (ds, n_classes, n_images, labels,
     _total_stats, _total_verify) = apputil.get_dataset_instance(appcfg, db_cfg, data_cfg, eval_subset)
    palette = viz.random_colors(len(labels))

    log.debug("class_names: {}".format(labels))
    log.debug("len(class_names): {}".format(len(labels)))
    log.debug("len(colors), colors: {},{}".format(len(palette), palette))
    log.debug("num_classes: {}".format(n_classes))
    log.debug("num_images: {}".format(n_images))

    ds_name = ds.name
    data_cfg.name = ds_name
    data_cfg.classes = labels
    data_cfg.num_classes = n_classes

    log.debug("len(dataset.image_info): {}".format(len(ds.image_info)))
    log.debug("len(dataset.image_ids): {}".format(len(ds.image_ids)))

    inspector = apputil.get_module('inspect_annon')

    ## the architecture config doubles as the command config
    arch_cfg = apputil.get_archcfg(appcfg)
    log.debug("archcfg: {}".format(arch_cfg))

    arch_cfg.name = ds_name
    arch_cfg.config.NAME = ds_name
    arch_cfg.config.NUM_CLASSES = n_classes

    arch_module = apputil.get_module(arch_cfg.dnnarch)
    build_dnncfg = apputil.get_module_fn(arch_module, "get_dnncfg")
    dnn_cfg = build_dnncfg(arch_cfg.config)
    log.debug("config.MINI_MASK_SHAPE: {}".format(dnn_cfg.MINI_MASK_SHAPE))
    log.debug("type(dnncfg.MINI_MASK_SHAPE): {}".format(type(dnn_cfg.MINI_MASK_SHAPE)))

    inspector.all_steps(ds, data_cfg, dnn_cfg)
def detect_from_videos(appcfg, dnnmod, videos, path, model, class_names, cmdcfg, api_model_key, show_bbox=False):
    """Run detection frame by frame on each video and collect per-frame JSON results.

    Code adopted from:

    Copyright (c) 2018 Matterport, Inc.
    Licensed under the MIT License (see LICENSE for details)
    Originally, Written by Waleed Abdulla

    ---

    Key contribution:
    * saving the annotated results directly
    * saving the annotated mask only
    * annotation results as json response for consumption in API, VGG VIA compatible results

    Copyright (c) 2020 mangalbhaskar
    Licensed under [see LICENSE for details]
    Written by mangalbhaskar

    ---

    Conventions:
      video - video filename
      filepath - the absolute path of the video input file location
      vid - binary data after reading the video file

    Args:
        appcfg: application configuration (not read by this function).
        dnnmod: architecture module providing the `detect` function.
        videos: iterable of video filenames relative to `path`.
        path: directory holding the videos; output goes to a timestamped
            `predict-<ddmmyy_HHMMSS>` directory beneath it.
        model: model object passed to `detect`.
        class_names: class labels; each label is paired with a random colour.
        cmdcfg: command configuration; `save_viz_and_json` (default False)
            enables writing per-frame JSON files and output directories.
        api_model_key: not read by this function.
        show_bbox: forwarded to `viz.get_display_instances` when
            `save_viz_and_json` is set.

    Returns:
        (file_names, res): `file_names` has one entry per video - the input
        name, or `<stem>.avi` when `save_viz_and_json` is set. `res` is a flat
        list with one JSON string per successfully read frame, over all videos.
    """
    import cv2

    save_viz_and_json = cmdcfg.save_viz_and_json if 'save_viz_and_json' in cmdcfg else False
    if save_viz_and_json:
        timestamp = "{:%d%m%y_%H%M%S}".format(datetime.datetime.now())
        filepath = os.path.join(path, "predict-" + timestamp)
        log.debug("filepath: {}".format(filepath))
        common.mkdir_p(filepath)

    file_names = []
    res = []
    detect = apputil.get_module_fn(dnnmod, "detect")
    colors = viz.random_colors(len(class_names))
    log.debug("class_names: {}".format(class_names))
    log.debug("len(class_names), class_names: {},{}".format(len(class_names), class_names))
    log.debug("len(colors), colors: {},{}".format(len(colors), colors))

    ## class name -> colour lookup handed to the viz helpers
    cc = dict(zip(class_names, colors))

    for video in videos:
        ## Run model detection and save the outputs
        log.debug("Running on {}".format(video))

        ## Read Video
        ##---------------------------------------------
        filepath_video = os.path.join(path, video)
        log.debug("Processing video with filepath_video: {}".format(filepath_video))

        vname, _ = os.path.splitext(video)
        file_name = video

        if save_viz_and_json:
            ## oframe - original image frame from the video
            ## pframe or viz - annotations visualization frame from the video
            ## annotations - annotations json per frame
            video_viz_basepath = os.path.join(filepath, vname)
            path_annotations = os.path.join(video_viz_basepath, "annotations")
            for subdir in ("oframe", "pframe", "annotations", "splash", "mask", "mmask", "viz"):
                d = os.path.join(video_viz_basepath, subdir)
                log.debug("videos dirs: {}".format(d))
                common.mkdir_p(d)

            ## Name of the visualisation video. The writer itself is not
            ## implemented yet (see the TODO after the frame loop), so only the
            ## name is reported.
            fext = ".avi"
            file_name = vname + fext
            filepath_pvideo = os.path.join(filepath, vname, file_name)
            log.debug("filepath_pvideo: {}".format(filepath_pvideo))

        vid = cv2.VideoCapture(filepath_video)
        try:
            count = 0
            success = True
            ## 0 disables the respective limit
            frame_cutoff = 0
            from_frame = 0
            while success:
                log.debug("-------")
                log.debug("frame: {}".format(count))

                if frame_cutoff and count >= frame_cutoff:
                    break

                ## start predictions specific 'from the specific frame number'
                ## NOTE(review): this only advances the counter and does not
                ## read a frame, so a non-zero from_frame would renumber frames
                ## rather than skip them - revisit before enabling.
                if from_frame and count < from_frame:
                    count += 1
                    continue

                ## Read next image
                success, oframe_im = vid.read()
                if success:
                    oframe_name = str(count) + "_" + video + ".png"

                    ## OpenCV returns images as BGR, convert to RGB
                    oframe_im_rgb = oframe_im[..., ::-1]

                    ## Detect objects
                    t1 = time.time()
                    r = detect(model, im=oframe_im_rgb, verbose=1)[0]
                    t2 = time.time()
                    time_taken = (t2 - t1)
                    log.debug('Total time taken in detect: %f seconds' %(time_taken))

                    ## Build the detections json response
                    ##---------------------------------------------
                    t1 = time.time()
                    if save_viz_and_json:
                        ## the viz helper receives the per-video output directory
                        ## and the frame name (presumably to save the
                        ## visualisations itself - confirm in viz)
                        jsonres = viz.get_display_instances(
                            oframe_im_rgb, r['rois'], r['masks'], r['class_ids'],
                            class_names, r['scores'],
                            colors=cc, show_bbox=show_bbox, auto_show=False,
                            filepath=video_viz_basepath, filename=oframe_name)
                    else:
                        jsonres = viz.get_detections(
                            oframe_im_rgb, r['rois'], r['masks'], r['class_ids'],
                            class_names, r['scores'], colors=cc)
                    t2 = time.time()
                    time_taken = (t2 - t1)
                    log.debug('Total time taken in detections: %f seconds' %(time_taken))

                    ## Convert Json response to VIA Json response
                    ##---------------------------------------------
                    t1 = time.time()
                    ## original frames are not saved to disk, so the size is reported as 0
                    size_oframe = 0
                    jsonres["filename"] = oframe_name
                    jsonres["size"] = size_oframe
                    via_jsonres = {}
                    via_jsonres[oframe_name + str(size_oframe)] = jsonres
                    json_str = common.numpy_to_json(via_jsonres)
                    t2 = time.time()
                    time_taken = (t2 - t1)
                    log.debug('Total time taken in json_str: %f seconds' %(time_taken))

                    ## Save output
                    ##---------------------------------------------
                    if save_viz_and_json:
                        filepath_jsonres = os.path.join(path_annotations, oframe_name + ".json")
                        log.debug("filepath_jsonres: {}".format(filepath_jsonres))
                        with open(filepath_jsonres, 'w') as fw:
                            fw.write(json_str)

                    res.append(json_str)
                count += 1
        finally:
            ## release the capture handle even if detection or saving fails
            vid.release()

        file_names.append(file_name)

        ## TODO: create the visualisation video from the individual frames, either with
        ## a cv2.VideoWriter or externally, e.g.:
        ## https://stackoverflow.com/questions/36643139/python-and-opencv-cannot-write-readable-avi-video-files
        ## ffmpeg -framerate 29 -i MAH04240.mp4-%d.png -c:v libx264 -r 30 MAH04240-maskrcnn-viz.mp4
        ## ffmpeg -framerate 29 -i %d_MAH04240.mp4.png -c:v libx264 -r 30 MAH04240-maskrcnn-viz.mp4

    return file_names, res
def evaluate(args, mode, appcfg):
    """Prepare the report configuration (paths, report names etc.) and call the report generation function.

    Loads the dataset subset named by `args.eval_on` and the model
    configuration referenced by the architecture config, creates a timestamped
    evaluation directory under `appcfg['PATHS']['AI_LOGS']/<dnnarch>/` with the
    sub-directories 'splash', 'mask', 'annotations' and 'viz', and delegates to
    the architecture module's `evaluate` function.

    Args:
        args: parsed command-line arguments; `eval_on`, `iou` and `save_viz` are read.
        mode: passed through to the architecture module's `evaluate`.
        appcfg: application configuration.

    Returns:
        Whatever the architecture module's `evaluate` function returns.
    """
    log.debug("evaluate---------------------------->")

    subset = args.eval_on
    iou_threshold = args.iou
    log.debug("subset: {}".format(subset))
    log.debug("iou_threshold: {}".format(iou_threshold))
    get_mask = True

    datacfg = apputil.get_datacfg(appcfg)
    dbcfg = apputil.get_dbcfg(appcfg)

    log.debug("appcfg: {}".format(appcfg))
    log.debug("datacfg: {}".format(datacfg))

    dataset, num_classes, num_images, class_names, total_stats, total_verify = apputil.get_dataset_instance(appcfg, dbcfg, datacfg, subset)
    colors = viz.random_colors(len(class_names))

    log.debug("-------")
    log.debug("len(colors), colors: {},{}".format(len(colors), colors))
    log.debug("class_names: {}".format(class_names))
    log.debug("len(class_names): {}".format(len(class_names)))
    log.debug("num_classes: {}".format(num_classes))
    log.debug("num_images: {}".format(num_images))
    log.debug("len(dataset.image_info): {}".format(len(dataset.image_info)))
    log.debug("len(dataset.image_ids): {}".format(len(dataset.image_ids)))
    log.debug("-------")

    name = dataset.name
    datacfg.name = name
    datacfg.classes = class_names
    datacfg.num_classes = num_classes

    ## the architecture config doubles as the command config
    archcfg = apputil.get_archcfg(appcfg)
    log.debug("archcfg: {}".format(archcfg))
    cmdcfg = archcfg

    ## the command line decides whether visualisations and json are saved
    save_viz = args.save_viz
    log.debug("save_viz: {}".format(save_viz))
    cmdcfg.save_viz_and_json = save_viz

    modelcfg_path = os.path.join(appcfg.PATHS.AI_MODEL_CFG_PATH, cmdcfg.model_info)
    log.info("modelcfg_path: {}".format(modelcfg_path))
    modelcfg = apputil.get_modelcfg(modelcfg_path)

    ## for prediction, get the label information from the model information
    class_names_model = apputil.get_class_names(modelcfg)
    log.debug("class_names_model: {}".format(class_names_model))

    ## the run is named after the dataset, the network config after the model
    cmdcfg.name = name
    cmdcfg.config.NAME = modelcfg.name
    cmdcfg.config.NUM_CLASSES = len(class_names_model)

    weights_path = apputil.get_abs_path(appcfg, modelcfg, 'AI_WEIGHTS_PATH')
    cmdcfg['weights_path'] = weights_path

    ## Prepare directory structure and filenames for reporting the evaluation results
    ## create log directory based on timestamp for evaluation reporting
    timestamp = "{:%d%m%y_%H%M%S}".format(datetime.datetime.now())

    save_viz_and_json = cmdcfg.save_viz_and_json
    ## -1 is used when the config does not set a number of results
    if 'evaluate_no_of_result' not in cmdcfg:
        evaluate_no_of_result = -1
    else:
        evaluate_no_of_result = cmdcfg.evaluate_no_of_result

    def clean_iou(iou):
        """Return the first three digits of `iou` without the decimal point, e.g. 0.5 -> '050'."""
        return "{:f}".format(iou).replace('.', '')[:3]

    def get_cfgfilename(cfg_filepath):
        """Return the file name component of `cfg_filepath`."""
        return os.path.basename(cfg_filepath)

    path = appcfg['PATHS']['AI_LOGS']
    evaluate_dir = "evaluate_" + clean_iou(iou_threshold) + "-" + name + "-" + subset + "-" + timestamp
    filepath = os.path.join(path, cmdcfg.dnnarch, evaluate_dir)
    log.debug("filepath: {}".format(filepath))

    common.mkdir_p(filepath)
    for d in ['splash', 'mask', 'annotations', 'viz']:
        common.mkdir_p(os.path.join(filepath, d))

    ## gt - ground truth
    ## pr/pred - prediction

    ## generate the summary on the evaluation run
    evaluate_run_summary = defaultdict(list)
    evaluate_run_summary['name'] = name
    evaluate_run_summary['execution_start_time'] = timestamp
    evaluate_run_summary['subset'] = subset
    evaluate_run_summary['total_labels'] = num_classes
    evaluate_run_summary['total_images'] = num_images
    evaluate_run_summary['evaluate_no_of_result'] = evaluate_no_of_result
    evaluate_run_summary['evaluate_dir'] = evaluate_dir
    evaluate_run_summary['dataset'] = get_cfgfilename(appcfg.DATASET[appcfg.ACTIVE.DATASET].cfg_file)
    evaluate_run_summary['arch'] = get_cfgfilename(appcfg.ARCH[appcfg.ACTIVE.ARCH].cfg_file)
    evaluate_run_summary['model'] = cmdcfg['model_info']

    ## classification report and confusion matrix - json and csv
    ## generate the filenames for what reports to be generated
    reportcfg = {
        'filepath': filepath,
        'evaluate_run_summary_reportfile': os.path.join(filepath, "evaluate_run_summary_rpt-" + subset),
        'classification_reportfile': os.path.join(filepath, "classification_rpt-" + subset),
        'confusionmatrix_reportfile': os.path.join(filepath, "confusionmatrix_rpt-" + subset),
        'iou_threshold': iou_threshold,
        'evaluate_run_summary': evaluate_run_summary,
        'save_viz_and_json': save_viz_and_json,
        'evaluate_no_of_result': evaluate_no_of_result,
    }

    log.debug("reportcfg: {}".format(reportcfg))

    dnnmod = apputil.get_module(cmdcfg.dnnarch)
    fn_evaluate = apputil.get_module_fn(dnnmod, "evaluate")

    evaluate_run_summary = fn_evaluate(mode, cmdcfg, appcfg, modelcfg, dataset, datacfg, class_names, reportcfg, get_mask)

    return evaluate_run_summary
def predict(args, mode, appcfg):
    """Execute the prediction and store the generated results.

    Loads the model configuration referenced by the architecture config, loads
    the model and its weights, then dispatches the inputs reported by
    `apputil.get_path_dtls` to this module's `detect_from_images` and/or
    `detect_from_videos`.

    TODO:
      1. create the prediction configuration
      2. PDB specification

    Args:
        args: parsed command-line arguments; `save_viz` and `show_bbox` are
            read here, and `args` is also passed to `apputil.get_path_dtls`.
        mode: passed through to the architecture module's `load_model_and_weights`.
        appcfg: application configuration.

    Returns:
        None. The values returned by the detect functions are discarded.
    """
    log.debug("predict---------------------------->")

    ## the architecture config doubles as the command config
    archcfg = apputil.get_archcfg(appcfg)
    log.debug("cmdcfg/archcfg: {}".format(archcfg))
    cmdcfg = archcfg

    ## the command line decides whether visualisations and json are saved
    save_viz = args.save_viz
    show_bbox = args.show_bbox
    log.debug("save_viz: {}".format(save_viz))
    cmdcfg.save_viz_and_json = save_viz

    modelcfg_path = os.path.join(appcfg.PATHS.AI_MODEL_CFG_PATH, cmdcfg.model_info)
    log.info("modelcfg_path: {}".format(modelcfg_path))
    modelcfg = apputil.get_modelcfg(modelcfg_path)
    log.debug("modelcfg: {}".format(modelcfg))
    api_model_key = apputil.get_api_model_key(modelcfg)
    log.debug("api_model_key: {}".format(api_model_key))

    ## for prediction, get the label information from the model information
    class_names = apputil.get_class_names(modelcfg)
    log.debug("class_names: {}".format(class_names))

    num_classes = len(class_names)
    name = modelcfg.name

    cmdcfg.name = name
    cmdcfg.config.NAME = name
    cmdcfg.config.NUM_CLASSES = num_classes

    dnnmod = apputil.get_module(cmdcfg.dnnarch)

    ## todo: hard-coding clear up
    cmdcfg['log_dir'] = 'predict'
    log_dir_path = apputil.get_abs_path(appcfg, cmdcfg, 'AI_LOGS')
    cmdcfg['log_dir_path'] = log_dir_path

    weights_path = apputil.get_abs_path(appcfg, modelcfg, 'AI_WEIGHTS_PATH')
    cmdcfg['weights_path'] = weights_path

    load_model_and_weights = apputil.get_module_fn(dnnmod, "load_model_and_weights")
    model = load_model_and_weights(mode, cmdcfg, appcfg)

    path_dtls = apputil.get_path_dtls(args, appcfg)
    log.debug("path_dtls: {}".format(path_dtls))

    for t in ["images", "videos"]:
        if not path_dtls[t]:
            continue

        fname = "detect_from_" + t
        log.info("fname: {}".format(fname))
        ## default of None so that a missing handler is logged below instead
        ## of raising AttributeError
        fn = getattr(this, fname, None)
        if not fn:
            log.error("Unkown fn: {}".format(fname))
            continue

        fn(appcfg, dnnmod, path_dtls[t], path_dtls['path'], model, class_names, cmdcfg, api_model_key, show_bbox)

    return
def train(args, mode, appcfg):
    """Train the configured architecture and write the resulting model-info file.

    Loads the 'train' and 'val' dataset subsets, requires their class names to
    be identical, loads the model and weights through the architecture module,
    and - unless `args.create_modelinfo` is set - runs the architecture
    module's `train` function. The model-info is written in a `finally` block,
    so it is saved even when training raises.

    Args:
        args: parsed command-line arguments; `create_modelinfo` is read. When
            it is truthy, training is skipped and only the model-info is written.
        mode: passed through to the architecture module's `load_model_and_weights`.
        appcfg: application configuration.

    Returns:
        Path of the written model-info file.

    Raises:
        AssertionError: if the class names of the train and val subsets differ.
    """
    log.debug("train---------------------------->")

    datacfg = apputil.get_datacfg(appcfg)

    ## Training dataset.
    subset = "train"
    log.info("subset: {}".format(subset))
    dbcfg = apputil.get_dbcfg(appcfg)

    dataset_train, num_classes_train, num_images_train, class_names_train, total_stats_train, total_verify_train = apputil.get_dataset_instance(appcfg, dbcfg, datacfg, subset)
    colors = viz.random_colors(len(class_names_train))

    log.info("-------")
    log.info("len(colors), colors: {},{}".format(len(colors), colors))
    log.info("subset, class_names_train: {}, {}".format(subset, class_names_train))
    log.info("subset, len(class_names_train): {}, {}".format(subset, len(class_names_train)))
    log.info("subset, num_classes_train: {}, {}".format(subset, num_classes_train))
    log.info("subset, num_images_train: {}, {}".format(subset, num_images_train))
    log.info("subset, len(dataset_train.image_info): {}, {}".format(subset, len(dataset_train.image_info)))
    log.info("subset, len(dataset_train.image_ids): {}, {}".format(subset, len(dataset_train.image_ids)))

    ## Validation dataset
    subset = "val"
    log.info("subset: {}".format(subset))
    dataset_val, num_classes_val, num_images_val, class_names_val, total_stats_val, total_verify_val = apputil.get_dataset_instance(appcfg, dbcfg, datacfg, subset)

    log.info("-------")
    log.info("subset, class_names_val: {}, {}".format(subset, class_names_val))
    log.info("subset, len(class_names_val): {}, {}".format(subset, len(class_names_val)))
    log.info("subset, num_classes_val: {}, {}".format(subset, num_classes_val))
    log.info("subset, num_images_val: {}, {}".format(subset, num_images_val))
    log.info("subset, len(dataset_val.image_info): {}, {}".format(subset, len(dataset_val.image_info)))
    log.info("subset, len(dataset_val.image_ids): {}, {}".format(subset, len(dataset_val.image_ids)))
    log.info("-------")

    ## Ensure label sequence and class_names of train and val dataset are exactly same, if not abort training.
    ## Raised explicitly (not via `assert`) so the check also runs under `python -O`.
    if class_names_train != class_names_val:
        raise AssertionError("class_names of train and val dataset differ: {} != {}".format(class_names_train, class_names_val))

    ## the architecture config doubles as the command config
    archcfg = apputil.get_archcfg(appcfg)
    log.debug("archcfg: {}".format(archcfg))
    cmdcfg = archcfg

    name = dataset_train.name

    ## generate the modelinfo template to be used for evaluate and prediction
    modelinfocfg = {
        'classes': class_names_train.copy(),
        'classinfo': None,
        'config': cmdcfg.config.copy(),
        'dataset': cmdcfg.dbname,
        'dbname': cmdcfg.dbname,
        'dnnarch': cmdcfg.dnnarch,
        'framework_type': cmdcfg.framework_type,
        'id': None,
        'load_weights': cmdcfg.load_weights.copy(),
        'name': name,
        'num_classes': num_classes_train,
        'problem_id': None,
        'rel_num': None,
        'weights': None,
        'weights_path': None,
        'log_dir': None,
        'checkpoint_path': None,
        'model_info': None,
        'timestamp': None,
        'creator': None,
    }

    datacfg.name = name
    datacfg.classes = class_names_train
    datacfg.num_classes = num_classes_train

    cmdcfg.name = name
    cmdcfg.config.NAME = name
    cmdcfg.config.NUM_CLASSES = num_classes_train

    modelcfg_path = os.path.join(appcfg.PATHS.AI_MODEL_CFG_PATH, cmdcfg.model_info)
    log.info("modelcfg_path: {}".format(modelcfg_path))
    modelcfg = apputil.get_modelcfg(modelcfg_path)

    log_dir_path = apputil.get_abs_path(appcfg, cmdcfg, 'AI_LOGS')
    cmdcfg['log_dir_path'] = log_dir_path

    weights_path = apputil.get_abs_path(appcfg, modelcfg, 'AI_WEIGHTS_PATH')
    cmdcfg['weights_path'] = weights_path

    dnnmod = apputil.get_module(cmdcfg.dnnarch)
    load_model_and_weights = apputil.get_module_fn(dnnmod, "load_model_and_weights")
    model = load_model_and_weights(mode, cmdcfg, appcfg)

    modelinfocfg['log_dir'] = model.log_dir
    modelinfocfg['checkpoint_path'] = model.checkpoint_path

    if 'creator' in cmdcfg:
        modelinfocfg['creator'] = cmdcfg['creator']

    log.info("modelinfocfg: {}".format(modelinfocfg))

    fn_create_modelinfo = apputil.get_module_fn(dnnmod, "create_modelinfo")
    modelinfo = fn_create_modelinfo(modelinfocfg)

    create_modelinfo = args.create_modelinfo
    try:
        if not create_modelinfo:
            log.info("Training...")
            fn_train = apputil.get_module_fn(dnnmod, "train")
            fn_train(model, dataset_train, dataset_val, cmdcfg)
            log.info("Training Completed!!!")
    finally:
        ## save modelinfo
        ## populate the relative weights_path of the last model from the training
        ## if any model is generated, otherwise None
        logs_path = appcfg['PATHS']['AI_LOGS']
        dnn = cmdcfg.dnnarch

        ## TODO: '*' matches every file; use e.g. '*.h5' if a specific format is needed
        list_of_files = glob.glob(os.path.join(model.log_dir, dnn + '*'))

        new_weights_path = None
        if list_of_files:
            ## most recent file by ctime; strip the logs directory to make the path relative
            latest_file = max(list_of_files, key=os.path.getctime)
            new_weights_path = latest_file.replace(logs_path + '/', '')

        modelinfo['weights_path'] = new_weights_path

        modelinfo_filepath = apputil.get_abs_path(appcfg, modelinfo, 'AI_MODEL_CFG_PATH')
        common.yaml_safe_dump(modelinfo_filepath, modelinfo)
        log.info("TRAIN:MODELINFO_FILEPATH: {}".format(modelinfo_filepath))
        log.info("---x--x--x---")

    return modelinfo_filepath