    def __init__(self, settings):
        logger.info('Will init all required to inference.')

        self.source_gpu_device = settings['device_id']
        self._load_train_config()
        self._construct_and_fill_model()
        logger.info('Model is ready to inference.')
Example #2
    def _construct_and_fill_model(self):
        self.device_ids = sly.remap_gpu_devices(self.config['gpu_devices'])
        self.model = create_model(n_cls=len(self.train_classes), device_ids=self.device_ids)

        self.model = WeightsRW(self.helper.paths.model_dir).load_strictly(self.model)
        self.model.eval()
        logger.info('Weights are loaded.')
Example #3
    def __init__(self,
                 single_image_inference_initializer,
                 num_processes,
                 default_inference_mode_config: dict,
                 config_validator=None):
        self._config_validator = config_validator or AlwaysPassingConfigValidator()
        self._inference_mode_config = determine_task_inference_mode_config(
            deepcopy(default_inference_mode_config))

        self._in_project = read_single_project(TaskPaths.DATA_DIR)
        logger.info('Project structure has been read. Samples: {}.'.format(
            self._in_project.total_items))

        self._inference_request_queue = mp.JoinableQueue(maxsize=2 * num_processes)
        self._inference_result_queue = mp.JoinableQueue(maxsize=2 * num_processes)
        self._inference_processes = [
            mp.Process(target=single_inference_process_fn,
                       args=(single_image_inference_initializer,
                             self._inference_mode_config,
                             self._in_project.meta.to_json(),
                             self._inference_request_queue,
                             self._inference_result_queue),
                       daemon=True) for _ in range(num_processes)
        ]
        logger.info('Dataset inference preparation done.')
        for p in self._inference_processes:
            p.start()
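
The snippet above fans requests out to daemon worker processes over a pair of bounded JoinableQueues. Below is a minimal standalone sketch of that queue/worker pattern; worker_fn and the integer-doubling "inference" are hypothetical stand-ins for the Supervisely-specific pieces:

import multiprocessing as mp

def worker_fn(request_queue, result_queue):
    # Pull requests until a None sentinel arrives; push results back.
    while True:
        item = request_queue.get()
        if item is None:
            request_queue.task_done()
            break
        result_queue.put(item * 2)  # stand-in for single-image inference
        request_queue.task_done()

if __name__ == '__main__':
    num_processes = 2
    request_queue = mp.JoinableQueue(maxsize=2 * num_processes)
    result_queue = mp.JoinableQueue(maxsize=2 * num_processes)
    workers = [mp.Process(target=worker_fn,
                          args=(request_queue, result_queue),
                          daemon=True)
               for _ in range(num_processes)]
    for p in workers:
        p.start()
    for i in range(4):
        request_queue.put(i)
    request_queue.join()  # block until every request has been processed
    # Results may arrive out of order; daemon workers die with the parent process.
    print([result_queue.get() for _ in range(4)])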
Example #4
    def _construct_data_loaders(self):
        self.data_dicts = {}
        self.iters_cnt = {}
        for the_name, the_tag in self.name_to_tag.items():
            samples_lst = self._deprecated_samples_by_tag[the_tag]
            supervisely_lib.nn.dataset.ensure_samples_nonempty(
                samples_lst, the_tag, self.project.meta)

            img_paths, labels, num_boxes = load_dataset(
                samples_lst, self.class_title_to_idx, self.project.meta)
            dataset_dict = {
                'img_paths': img_paths,
                'labels': labels,
                'num_boxes': num_boxes,
                'sample_cnt': len(samples_lst)
            }
            self.data_dicts[the_name] = dataset_dict

            samples_per_iter = self.config['batch_size'][the_name] * len(
                self.config['gpu_devices'])
            self.iters_cnt[the_name] = math.ceil(
                float(len(samples_lst)) / samples_per_iter)
            logger.info('Prepared dataset.',
                        extra={
                            'dataset_purpose': the_name,
                            'dataset_tag': the_tag,
                            'sample_cnt': len(samples_lst)
                        })
Example #5
def serve():
    settings = {
        'device_id': 0,
        'cache_limit': 500,
        'connection': {
            'server_address': None,
            'token': None,
            'task_id': None,
        },
    }

    new_settings = sly.json_load(sly.TaskPaths(determine_in_project=False).settings_path)
    logger.info('Input settings', extra={'settings': new_settings})
    sly.update_recursively(settings, new_settings)
    logger.info('Full settings', extra={'settings': settings})

    def model_creator():
        res = UnetV2FastApplier(settings={
            'device_id': settings['device_id']
        })
        return res

    image_cache = SimpleCache(settings['cache_limit'])
    serv_instance = AgentRPCServicer(logger=logger,
                                     model_creator=model_creator,
                                     apply_cback=single_img_pipeline,
                                     conn_settings=settings['connection'],
                                     cache=image_cache)
    serv_instance.run_inf_loop()
Example #6
    def _determine_model_classes(self):
        self.class_title_to_idx, self.out_classes = create_output_classes(
            in_project_classes=self.helper.in_project_meta.classes)
        logger.info('Determined model internal class mapping',
                    extra={'class_mapping': self.class_title_to_idx})
        logger.info('Determined model out classes',
                    extra={'classes': self.out_classes.py_container})
Example #7
    def _construct_and_fill_model(self):
        model_dir = sly.TaskPaths(determine_in_project=False).model_dir
        self.device_ids = sly.remap_gpu_devices([self.source_gpu_device])

        src_train_cfg_path = join(model_dir, 'model.cfg')
        with open(src_train_cfg_path) as f:
            src_config = f.readlines()

        def repl_batch(row):
            if 'batch=' in row:
                return 'batch=1\n'
            if 'subdivisions=' in row:
                return 'subdivisions=1\n'
            return row

        changed_config = [repl_batch(x) for x in src_config]

        inf_cfg_path = join(model_dir, 'inf_model.cfg')
        if not os.path.exists(inf_cfg_path):
            with open(inf_cfg_path, 'w') as f:
                f.writelines(changed_config)

        self.net = load_net(inf_cfg_path.encode('utf-8'),
                            join(model_dir, 'model.weights').encode('utf-8'),
                            0)
        logger.info('Weights are loaded.')
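
The nested repl_batch helper above rewrites a darknet-style .cfg so inference always runs with batch=1 and subdivisions=1. A quick standalone check with hypothetical config lines:

def repl_batch(row):
    if 'batch=' in row:
        return 'batch=1\n'
    if 'subdivisions=' in row:
        return 'subdivisions=1\n'
    return row

src_config = ['[net]\n', 'batch=64\n', 'subdivisions=16\n', 'width=416\n']
print([repl_batch(x) for x in src_config])
# ['[net]\n', 'batch=1\n', 'subdivisions=1\n', 'width=416\n']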
Example #8
def main(args):
    logger.info('ENVS', extra={**args, 'DOCKER_PASSWORD': '******'})

    constants.HOST_DIR = args['AGENT_HOST_DIR']
    constants.AGENT_TASKS_DIR_HOST = os.path.join(constants.HOST_DIR, 'tasks')

    constants.SERVER_ADDRESS = args['SERVER_ADDRESS']
    constants.TOKEN = args['ACCESS_TOKEN']
    constants.TASKS_DOCKER_LABEL = 'supervisely_{}'.format(constants.TOKEN)

    constants.DOCKER_LOGIN = args['DOCKER_LOGIN']
    constants.DOCKER_PASSWORD = args['DOCKER_PASSWORD']
    constants.DOCKER_REGISTRY = args['DOCKER_REGISTRY']

    constants.WITH_LOCAL_STORAGE = flag_from_env(args['WITH_LOCAL_STORAGE'])
    constants.UPLOAD_RESULT_IMAGES = flag_from_env(
        args['UPLOAD_RESULT_IMAGES'])
    constants.PULL_ALWAYS = flag_from_env(args['PULL_ALWAYS'])

    constants.DELETE_TASK_DIR_ON_FINISH = flag_from_env(
        args['DELETE_TASK_DIR_ON_FINISH'])

    agent = Agent()
    agent.inf_loop()
    agent.wait_all()
Example #9
    def __init__(self):
        super().__init__(
            default_config=ObjectDetectionTrainer.get_default_config())
        logger.info('Model is ready to train.')
        # To be filled in by dump_model() callback inside train().
        self.saver = None
        self.sess = None
Example #10
    def _determine_model_classes(self):
        if 'classes' not in self.config:
            # Key-value tags are ignored as a source of class labels.
            img_tags = set(tag_meta.name for tag_meta in self.project.meta.img_tag_metas if
                           tag_meta.value_type == sly.TagValueType.NONE)
            img_tags -= set(self.config['dataset_tags'].values())
            train_classes = sorted(img_tags)
        else:
            train_classes = self.config['classes']

        if 'ignore_tags' in self.config:
            for tag in self.config['ignore_tags']:
                if tag in train_classes:
                    train_classes.remove(tag)

        if len(train_classes) < 2:
            raise RuntimeError('Training requires at least two input classes.')

        in_classification_tags_to_idx, self.classification_tags_sorted = create_classes(train_classes)
        self.classification_tags_to_idx = infer_training_class_to_idx_map(self.config['weights_init_type'],
                                                                          in_classification_tags_to_idx,
                                                                          sly.TaskPaths.MODEL_CONFIG_PATH,
                                                                          class_to_idx_config_key=self.classification_tags_to_idx_key)

        self.class_title_to_idx = {}
        self.out_classes = sly.ObjClassCollection()
        logger.info('Determined model internal class mapping', extra={'class_mapping': self.class_title_to_idx})
        logger.info('Determined model out classes', extra={'classes': self.classification_tags_sorted})
Example #11
def infer_training_class_to_idx_map(weights_init_type, in_project_class_to_idx, model_config_fpath,
                                    class_to_idx_config_key, special_class_ids=None):
    if weights_init_type == TRANSFER_LEARNING:
        logger.info('Transfer learning mode, using a class mapping created from scratch.')
        class_title_to_idx = in_project_class_to_idx
    elif weights_init_type == CONTINUE_TRAINING:
        logger.info('Continued training mode, reusing the existing class mapping from the model.')
        class_title_to_idx = read_validate_model_class_to_idx_map(
            model_config_fpath=model_config_fpath,
            in_project_classes_set=set(in_project_class_to_idx.keys()),
            class_to_idx_config_key=class_to_idx_config_key)
    else:
        raise RuntimeError('Unknown weights init type: {}'.format(weights_init_type))

    if special_class_ids is not None:
        for class_title, requested_class_id in special_class_ids.items():
            effective_class_id = class_title_to_idx[class_title]
            if requested_class_id != effective_class_id:
                error_msg = ('Unable to start training. Effective integer id for class {} does not match the '
                             'requested value in the training config ({} vs {}).'.format(
                                 class_title, effective_class_id, requested_class_id))
                logger.critical(error_msg, extra={'class_title_to_idx': class_title_to_idx,
                                                  'special_class_ids': special_class_ids})
                raise RuntimeError(error_msg)
    return class_title_to_idx
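
Note that the special_class_ids branch only validates that the chosen mapping agrees with the ids pinned in the training config. A standalone sketch of that check with made-up mappings (check_special_class_ids is a hypothetical helper, not part of the library):

def check_special_class_ids(class_title_to_idx, special_class_ids):
    # Raise if any class with a pinned id ended up mapped to a different integer.
    for class_title, requested_class_id in special_class_ids.items():
        effective_class_id = class_title_to_idx[class_title]
        if requested_class_id != effective_class_id:
            raise RuntimeError(
                'Effective integer id for class {} does not match the requested value '
                '({} vs {}).'.format(class_title, effective_class_id, requested_class_id))

check_special_class_ids({'background': 0, 'car': 1}, {'background': 0})    # passes silently
# check_special_class_ids({'background': 2, 'car': 1}, {'background': 0})  # would raise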
Example #12
    def _construct_and_fill_model(self):
        self.device_ids = sly.remap_gpu_devices(self.config['gpu_devices'])

        src_train_cfg_path = join(self.helper.paths.model_dir, 'model.cfg')
        with open(src_train_cfg_path) as f:
            src_config = f.readlines()

        def repl_batch(row):
            if 'batch=' in row:
                return 'batch=1\n'
            if 'subdivisions=' in row:
                return 'subdivisions=1\n'
            return row

        changed_config = [repl_batch(x) for x in src_config]

        inf_cfg_path = join(self.helper.paths.model_dir, 'inf_model.cfg')
        if not os.path.exists(inf_cfg_path):
            with open(inf_cfg_path, 'w') as f:
                f.writelines(changed_config)

        self.net = load_net(
            inf_cfg_path.encode('utf-8'),
            join(self.helper.paths.model_dir, 'model.weights').encode('utf-8'),
            0)
        # self.meta = load_meta(join(self.helper.paths.model_dir, 'model.names').encode('utf-8'))
        logger.info('Weights are loaded.')
Example #13
def main():
    sly.task_verification(check_in_graph)

    logger.info('DTL started')
    helper = sly.DtlHelper()
    net = Net(helper.graph, helper.in_project_metas, helper.paths.results_dir)
    helper.save_res_meta(net.get_result_project_meta())

    # is_archive = net.is_archive()
    results_counter = 0
    for pr_name, pr_dir in helper.in_project_dirs.items():
        root_path, project_name = sly.ProjectFS.split_dir_project(pr_dir)
        project_fs = sly.ProjectFS.from_disk(root_path, project_name, by_annotations=True)
        progress = sly.progress_counter_dtl(pr_name, project_fs.image_cnt)
        for sample in project_fs:
            try:
                img_desc = sly.ImageDescriptor(sample)
                ann = sly.json_load(sample.ann_path)
                data_el = (img_desc, ann)
                export_output_generator = net.start(data_el)
                for res_export in export_output_generator:
                    logger.trace("image processed", extra={'img_name': res_export[0][0].get_img_name()})
                    results_counter += 1
            except Exception:
                ex = {
                    'project_name': sample.project_name,
                    'ds_name': sample.ds_name,
                    'image_name': sample.image_name
                }
                logger.warn('Image was skipped because an error occurred', exc_info=True, extra=ex)
            progress.iter_done_report()

    logger.info('DTL finished', extra={'event_type': EventType.DTL_APPLIED, 'new_proj_size': results_counter})
Example #14
def create_segmentation_classes(in_project_classes, special_classes_config, bkg_input_idx,
                                weights_init_type, model_config_fpath, class_to_idx_config_key, start_class_id=1):
    extra_classes = {}
    special_class_ids = {}
    bkg_title = special_classes_config.get(BACKGROUND, None)
    if bkg_title is not None:
        extra_classes = {bkg_title: [34, 34, 34]} # Default background color
        special_class_ids[bkg_title] = bkg_input_idx

    exclude_titles = []
    neutral_title = special_classes_config.get(NEUTRAL, None)
    if neutral_title is not None:
        exclude_titles.append(neutral_title)
    out_classes = make_out_classes(in_project_classes, geometry_type=Bitmap, exclude_titles=exclude_titles,
                                   extra_classes=extra_classes)

    logger.info('Determined model out classes', extra={'out_classes': list(out_classes)})

    in_project_class_to_idx = make_new_class_to_idx_map(in_project_classes, start_class_id=start_class_id,
                                                        preset_class_ids=special_class_ids,
                                                        exclude_titles=exclude_titles)
    class_title_to_idx = infer_training_class_to_idx_map(weights_init_type,
                                                         in_project_class_to_idx,
                                                         model_config_fpath,
                                                         class_to_idx_config_key,
                                                         special_class_ids=special_class_ids)
    logger.info('Determined class mapping.', extra={'class_mapping': class_title_to_idx})
    return out_classes, class_title_to_idx
Example #15
    def run_inference(self):
        out_project_fs = copy(self.in_project_fs)
        out_project_fs.root_path = self.helper.paths.results_dir
        out_project_fs.make_dirs()

        inf_feeder = sly.InferenceFeederFactory.create(self.config, self.helper.in_project_meta, self.train_classes)
        out_pr_meta = inf_feeder.out_meta
        out_pr_meta.to_dir(out_project_fs.project_path)

        ia_cnt = out_project_fs.pr_structure.image_cnt
        progress = sly.progress_counter_inference(cnt_imgs=ia_cnt)

        for sample in self.in_project_fs:
            logger.info('Will process image',
                        extra={'dataset_name': sample.ds_name, 'image_name': sample.image_name})
            ann_packed = sly.json_load(sample.ann_path)
            ann = sly.Annotation.from_packed(ann_packed, self.helper.in_project_meta)

            img = cv2.imread(sample.img_path)[:, :, ::-1]
            res_ann = inf_feeder.feed(img, ann, self._infer_on_img)

            out_ann_fpath = out_project_fs.ann_path(sample.ds_name, sample.image_name)
            res_ann_packed = res_ann.pack()
            sly.json_dump(res_ann_packed, out_ann_fpath)

            if self.debug_copy_images:
                out_img_fpath = out_project_fs.img_path(sample.ds_name, sample.image_name)
                sly.ensure_base_path(out_img_fpath)
                shutil.copy(sample.img_path, out_img_fpath)

            progress.iter_done_report()

        sly.report_inference_finished()
Example #16
    def _construct_data_loaders(self):
        self.tf_data_dicts = {}
        self.iters_cnt = {}
        for the_name, the_tag in self.name_to_tag.items():
            samples_lst = self._deprecated_samples_by_tag[the_tag]
            supervisely_lib.nn.dataset.ensure_samples_nonempty(
                samples_lst, the_tag, self.project.meta)
            dataset_dict = {
                "samples": samples_lst,
                "classes_mapping": self.class_title_to_idx,
                "project_meta": self.project.meta,
                "sample_cnt": len(samples_lst)
            }
            self.tf_data_dicts[the_name] = dataset_dict
            num_gpu_devices = len(self.config['gpu_devices'])
            single_gpu_batch_size = self.config['batch_size'][the_name]
            effective_batch_size = single_gpu_batch_size * num_gpu_devices
            if len(samples_lst) < effective_batch_size:
                raise RuntimeError(
                    f'Not enough items in the {the_name!r} fold (tagged {the_tag!r}). There are only '
                    f'{len(samples_lst)} items, but the effective batch size is {effective_batch_size} '
                    f'({num_gpu_devices} GPU devices x {single_gpu_batch_size} single-GPU batch size).'
                )

            self.iters_cnt[the_name] = len(samples_lst) // effective_batch_size
            logger.info('Prepared dataset.',
                        extra={
                            'dataset_purpose': the_name,
                            'dataset_tag': the_tag,
                            'sample_cnt': len(samples_lst)
                        })
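
The iteration count here is plain floor division of the fold size by the effective batch size (per-GPU batch size times GPU count). With hypothetical numbers:

num_gpu_devices = 2
single_gpu_batch_size = 8
sample_cnt = 100

effective_batch_size = single_gpu_batch_size * num_gpu_devices  # 16
assert sample_cnt >= effective_batch_size, 'fold is too small for even one batch'
iters_cnt = sample_cnt // effective_batch_size  # 6 full iterations per epoch
print(effective_batch_size, iters_cnt)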
Example #17
def determine_task_inference_mode_config(default_inference_mode_config):
    raw_task_config = load_json_file(TaskPaths.TASK_CONFIG_PATH)
    task_config = maybe_convert_from_v1_inference_task_config(raw_task_config)
    logger.info('Input task config', extra={'config': task_config})
    result_config = get_effective_inference_mode_config(
        task_config.get(MODE, {}), default_inference_mode_config)
    logger.info('Full inference mode config', extra={'config': result_config})
    return result_config
Example #18
def construct_model(model_dir):
    if 'model.pb' not in os.listdir(model_dir):
        logger.info('Freezing training checkpoint!')
        freeze_graph('image_tensor', model_dir + '/model.config',
                     model_dir + '/model_weights/model.ckpt', model_dir)
    detection_graph = create_detection_graph(model_dir)
    session = tf.Session(graph=detection_graph)
    return detection_graph, session
Example #19
    def _determine_settings(self):
        input_config = self.helper.task_settings
        logger.info('Input config', extra={'config': input_config})
        config = deepcopy(self.default_settings)
        sly.update_recursively(config, input_config)
        logger.info('Full config', extra={'config': config})
        SettingsValidator.validate_train_cfg(config)
        self.config = config
Example #20
    def __init__(self):
        logger.info('Will init all required to train.')
        self.helper = sly.TaskHelperTrain()

        self._determine_settings()
        self._determine_model_classes()
        self._determine_out_config()
        self._construct_data_dicts()
Example #21
    def _determine_input_data(self):
        project_fs = sly.ProjectFS.from_disk_dir_project(self.helper.paths.project_dir)
        logger.info('Project structure has been read. Samples: {}.'.format(project_fs.pr_structure.image_cnt))
        self.in_project_fs = project_fs

        self.inf_feeder = sly.InferenceFeederFactory.create(
            self.config, self.helper.in_project_meta, self.train_classes
        )
Example #22
    def _construct_and_fill_model(self):
        model_dir = sly.TaskPaths(determine_in_project=False).model_dir
        self.device_ids = sly.remap_gpu_devices([self.source_gpu_device])
        self.model = create_model(n_cls=len(self.train_classes),
                                  device_ids=self.device_ids)

        self.model = WeightsRW(model_dir).load_strictly(self.model)
        self.model.eval()
        logger.info('Weights are loaded.')
Example #23
    def _construct_samples_dct(self):
        logger.info('Will collect samples (image/annotation pairs).')
        self.name_to_tag = self.config[DATASET_TAGS]
        self._deprecated_samples_by_tag = samples_by_tags(
            required_tags=list(self.name_to_tag.values()), project=self.project)
        self._samples_by_data_purpose = {
            purpose: self._deprecated_samples_by_tag[tag]
            for purpose, tag in self.config[DATASET_TAGS].items()
        }
Example #24
    def _determine_model_input_size(self):
        src_size = self.train_config[SETTINGS][INPUT_SIZE]
        self.input_size = (src_size[HEIGHT], src_size[WIDTH])
        logger.info('Model input size is read (for auto-rescale).',
                    extra={
                        INPUT_SIZE: {
                            WIDTH: self.input_size[1],
                            HEIGHT: self.input_size[0]
                        }
                    })
Example #25
    def _construct_and_fill_model(self):
        model_dir = sly.TaskPaths(determine_in_project=False).model_dir
        self.device_ids = sly.remap_gpu_devices([self.source_gpu_device])
        if 'model.pb' not in os.listdir(model_dir):
            logger.info('Freezing training checkpoint!')
            freeze_graph('image_tensor', model_dir + '/model.config',
                         model_dir + '/model_weights/model.ckpt', model_dir)
        self.detection_graph = create_detection_graph(model_dir)
        self.session = tf.Session(graph=self.detection_graph)
        logger.info('Weights are loaded.')
Example #26
    def _validation(self):
        # Compute validation metrics.

        # Switch the model to evaluation mode to stop batchnorm running average updates.
        self._model.eval()
        # Initialize the totals counters.
        validated_samples = 0
        total_val_metrics = {name: 0.0 for name in self._metrics_with_loss}
        total_loss = 0.0

        # Iterate over validation dataset batches.
        for val_it, (inputs, targets) in enumerate(self._data_loaders[VAL]):
            _check_all_pixels_have_segmentation_class(targets)

            # Move the data to the GPU and run inference.
            with torch.no_grad():
                inputs_cuda, targets_cuda = Variable(inputs).cuda(), Variable(targets).cuda()
                outputs_cuda = self._model(inputs_cuda)

            # The last batch may be smaller than the rest if the dataset does not have a whole number of full batches,
            # so read the batch size from the input.
            batch_size = inputs_cuda.size(0)

            # Compute the metrics and grab the values from GPU.
            batch_metrics = {
                name: metric_fn(outputs_cuda, targets_cuda).item()
                for name, metric_fn in self._metrics_with_loss.items()
            }
            for name, metric_value in batch_metrics.items():
                total_val_metrics[name] += metric_value * batch_size

            # Add up the totals.
            validated_samples += batch_size

            # Report progress.
            logger.info("Validation in progress",
                        extra={
                            'epoch': self.epoch_flt,
                            'val_iter': val_it,
                            'val_iters': self._val_iters
                        })

        # Compute the average loss from the accumulated totals.
        avg_metrics_values = {
            name: total_value / validated_samples
            for name, total_value in total_val_metrics.items()
        }

        # Report progress and metric values to be plotted in the training chart and return.
        report_metrics_validation(self.epoch_flt, avg_metrics_values)
        logger.info("Validation has been finished",
                    extra={'epoch': self.epoch_flt})
        return avg_metrics_values
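
The running totals above are weighted by batch size so that a smaller final batch does not skew the epoch average. A standalone sketch of that accumulation with made-up batch losses:

# Hypothetical (batch_size, loss) pairs; the last batch is smaller.
batches = [(16, 0.40), (16, 0.35), (8, 0.50)]

validated_samples = 0
total_loss = 0.0
for batch_size, batch_loss in batches:
    total_loss += batch_loss * batch_size
    validated_samples += batch_size

avg_loss = total_loss / validated_samples
print(round(avg_loss, 4))  # 0.4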
Example #27
    def _construct_data_loaders(self):
        self.device_ids = sly.env.remap_gpu_devices(self.config['gpu_devices'])

        src_size = self.config['input_size']
        input_size = (src_size['height'], src_size['width'])

        self.pytorch_datasets = {}
        self.data_loaders = {}

        shuffle_drop_last = {
            'train': True,
            'val': False
        }

        for the_name, the_tag in self.name_to_tag.items():
            samples_lst = self._deprecated_samples_by_tag[the_tag]
            samples_count = len(samples_lst)
            # note that now batch_size from config determines batch for single device
            batch_sz = self.config['batch_size'][the_name]
            batch_sz_full = batch_sz * len(self.device_ids)

            if samples_count < batch_sz_full:
                raise RuntimeError('Project should contain at least '
                                   '{}(batch size) * {}(gpu devices) = {} samples tagged by "{}", '
                                   'but found {} samples.'
                                   .format(batch_sz, len(self.device_ids), batch_sz_full, the_name, samples_count))

            the_ds = ResnetDataset(
                project_meta=self.project.meta,
                samples=samples_lst,
                out_size=input_size,
                class_mapping=self.classification_tags_to_idx,
                out_classes=self.classification_tags_sorted,
                allow_corrupted_cnt=self.config['allow_corrupted_samples'][the_name],
                spec_tags=list(self.config['dataset_tags'].values())
            )
            self.pytorch_datasets[the_name] = the_ds
            logger.info('Prepared dataset.', extra={
                'dataset_purpose': the_name, 'dataset_tag': the_tag, 'sample_cnt': len(samples_lst)
            })

            n_workers = self.config['data_workers'][the_name]
            self.data_loaders[the_name] = DataLoader(
                dataset=self.pytorch_datasets[the_name],
                batch_size=batch_sz_full,  # it looks like multi-gpu validation works
                num_workers=n_workers,
                shuffle=shuffle_drop_last[the_name],
                drop_last=shuffle_drop_last[the_name]
            )
        logger.info('DataLoaders are constructed.')

        self.train_iters = len(self.data_loaders['train'])
        self.val_iters = len(self.data_loaders['val'])
        self.epochs = self.config['epochs']
        self.eval_planner = EvalPlanner(epochs=self.epochs, val_every=self.config['val_every'])
Example #28
    def __init__(self):
        logger.info('Starting base single image inference applier init.')
        task_model_config = self._load_task_model_config()
        self._config = update_recursively(self.get_default_config(),
                                          task_model_config)
        # Only validate after merging task config with the defaults.
        self._validate_model_config(self._config)

        self._load_train_config()
        self._construct_and_fill_model()
        logger.info('Base single image inference applier init done.')
Example #29
    def _determine_model_classes(self):
        spec_cls = self.config['special_classes']
        self.class_title_to_idx, self.out_classes = sly.create_segmentation_classes(
            in_project_classes=self.helper.in_project_meta.classes,
            bkg_title=spec_cls['background'],
            neutral_title=spec_cls['neutral'],
            bkg_color=0,
            neutral_color=self.neutral_input_idx,
        )
        logger.info('Determined model internal class mapping', extra={'class_mapping': self.class_title_to_idx})
        logger.info('Determined model out classes', extra={'classes': self.out_classes.py_container})
Example #30
def create_detection_graph(model_dirpath):
    fpath = osp.join(model_dirpath, 'model.pb')
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(fpath, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    logger.info('Restored model weights from training.')
    return detection_graph