Example #1
0
    def predict(self, image, point_cloud, frame_calib):
        """Run one AVOD forward pass on a single frame and draw the result.

        A fresh graph is created, the model is built in 'test' mode, the
        checkpoint returned by `self._restore_chkpt()` is restored into a
        new session and the prediction tensors are evaluated once.

        Args:
            image: input image; forwarded to the model's feed dict builder
                and to `self._draw_predictions`.
            point_cloud: point cloud forwarded to the feed dict builder.
            frame_calib: frame calibration forwarded to the feed dict
                builder and to `self._draw_predictions`.
        """
        with tf.Graph().as_default():
            avod = AvodModel(
                self.model_config,
                train_val_test='test',
                dataset=self.dataset)

            # build() hands back the dictionary of prediction tensors
            pred_tensors = avod.build()
            self._saver = tf.train.Saver()
            inputs = avod.create_pred_feed_dict(
                image, point_cloud, frame_calib)

            session = self._create_sess()
            checkpoint_path = self._restore_chkpt()
            self._saver.restore(session, checkpoint_path)
            outputs = session.run(pred_tensors, feed_dict=inputs)

            box_representation = \
                self.model_config.avod_config.avod_box_representation
            rpn_results = self.get_rpn_proposals_and_scores(outputs)
            avod_results = self.get_avod_predicted_boxes_3d_and_scores(
                outputs, box_representation)

            print("Attempting to draw predictions")

            rendered = self._draw_predictions(
                image, avod_results, frame_calib)
Example #2
0
def train(model_config, train_config, dataset_config):
    """Build the KITTI dataset and the configured model, then train it.

    Args:
        model_config: model configuration; its `model_name` selects the
            model class to instantiate.
        train_config: training configuration handed to `trainer.train`.
        dataset_config: dataset configuration used to build the dataset.

    Raises:
        ValueError: if `model_config.model_name` is not a supported model.
    """
    dataset = DatasetBuilder.build_kitti_dataset(
        dataset_config, use_defaults=False)
    model_name = model_config.model_name

    with tf.Graph().as_default():
        # Map every supported model name onto its class
        model_classes = {
            'rpn_model': RpnModel,
            'bev_only_rpn_model': BevOnlyRpnModel,
            'avod_model': AvodModel,
            'bev_only_avod_model': BevOnlyAvodModel,
        }
        if model_name not in model_classes:
            raise ValueError('Invalid model_name')

        model = model_classes[model_name](model_config,
                                          train_val_test='train',
                                          dataset=dataset)
        trainer.train(model, train_config)
Example #3
0
def train(model_config, train_config, dataset_config):
    """Train the model named in `model_config` on the KITTI dataset.

    The three arguments carry the detailed settings read from the config
    file.

    Args:
        model_config: model parameters.
        train_config: training parameters.
        dataset_config: dataset parameters.

    Raises:
        ValueError: if `model_config.model_name` is neither 'rpn_model'
            nor 'avod_model'.
    """
    dataset = DatasetBuilder.build_kitti_dataset(
        dataset_config, use_defaults=False)

    # Supported names are 'avod_model' and 'rpn_model'
    model_name = model_config.model_name

    with tf.Graph().as_default():
        if model_name == 'rpn_model':
            model_cls = RpnModel
        elif model_name == 'avod_model':
            model_cls = AvodModel
        else:
            raise ValueError('Invalid model_name')

        model = model_cls(model_config,
                          train_val_test='train',
                          dataset=dataset)
        trainer.train(model, train_config)
Example #4
0
def set_up_model(pipeline_config, data_split):
    """Create a model whose dataset is always configured for testing.

    Args:
        pipeline_config: pipeline config file passed to
            `config_builder.get_configs_from_pipeline_file`.
        data_split: value passed to the model as `train_val_test`.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if the configured model name is not supported.
    """
    model_config, _, _, dataset_config = \
        config_builder.get_configs_from_pipeline_file(
            pipeline_config, is_training=False)

    dataset_config = config_builder.proto_to_obj(dataset_config)

    # The dataset always runs in test mode: unlabeled test split, no
    # augmentation
    test_overrides = (
        ('data_split', 'test'),
        ('data_split_dir', 'testing'),
        ('has_labels', False),
        ('aug_list', []),
    )
    for field, value in test_overrides:
        setattr(dataset_config, field, value)

    dataset = DatasetBuilder.build_kitti_dataset(
        dataset_config, use_defaults=False)

    model_classes = {
        'rpn_model': RpnModel,
        'avod_model': AvodModel,
        'avod_ssd_model': AvodSSDModel,
    }
    model_name = model_config.model_name
    if model_name not in model_classes:
        raise ValueError('Invalid model_name')

    return model_classes[model_name](model_config,
                                     train_val_test=data_split,
                                     dataset=dataset)
Example #5
0
def get_proposal_network(model_config, dataset, model_path, GPU_INDEX=0):
    """Build an AVOD model in test mode and restore its weights.

    Args:
        model_config: model configuration passed to `AvodModel`.
        dataset: dataset object passed to `AvodModel`.
        model_path: checkpoint path restored into the session.
        GPU_INDEX: index of the GPU the graph is placed on.

    Returns:
        Tuple `(rpn_pred, sess, rpn_model)`: the result of
        `rpn_model.build()`, the session holding the restored weights,
        and the model itself.
    """
    with tf.Graph().as_default():
        device_name = '/gpu:{}'.format(GPU_INDEX)
        with tf.device(device_name):
            rpn_model = AvodModel(
                model_config, train_val_test='test', dataset=dataset)
            rpn_pred = rpn_model.build()
            saver = tf.train.Saver()

        # Session with on-demand GPU memory and soft device placement
        session_config = tf.ConfigProto()
        session_config.gpu_options.allow_growth = True
        session_config.allow_soft_placement = True
        sess = tf.Session(config=session_config)

        saver.restore(sess, model_path)
        return rpn_pred, sess, rpn_model
Example #6
0
def train(model_config, train_config, dataset_config):
    """Instantiate the configured model on the KITTI dataset and train it.

    Args:
        model_config: model configuration; `model_name` must be one of
            'rpn_model', 'avod_model' or 'retinanet_model'.
        train_config: training configuration handed to `trainer.train`.
        dataset_config: dataset configuration used to build the dataset.

    Raises:
        ValueError: if `model_config.model_name` is not supported.
    """
    dataset = DatasetBuilder.build_kitti_dataset(
        dataset_config, use_defaults=False)

    model_name = model_config.model_name
    # Keyword arguments shared by every model constructor
    model_kwargs = {'train_val_test': 'train', 'dataset': dataset}

    with tf.Graph().as_default():
        if model_name == 'rpn_model':
            model = RpnModel(model_config, **model_kwargs)
        elif model_name == 'avod_model':
            model = AvodModel(model_config, **model_kwargs)
        elif model_name == 'retinanet_model':
            model = RetinanetModel(model_config, **model_kwargs)
        else:
            raise ValueError('Invalid model_name')

        trainer.train(model, train_config)
Example #7
0
    def test_avod_loss(self):
        """Smoke-test building the AVOD model and evaluating its losses."""
        # "val" mode is used so that the first sample is loaded each time
        model = AvodModel(
            self.model_config, train_val_test="val", dataset=self.dataset)

        prediction_tensors = model.build()
        loss_tensors, _ = model.loss(prediction_tensors)
        inputs = model.create_feed_dict()

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            loss_values = sess.run(loss_tensors, feed_dict=inputs)
            print('Losses ', loss_values)
Example #8
0
def inference(model_config, eval_config, dataset_config, data_split,
              ckpt_indices):
    """Run test-mode inference for the requested checkpoints.

    Args:
        model_config: model configuration proto; `model_name` selects the
            model class.
        eval_config: evaluation configuration proto; forced into
            single-pass 'test' mode.
        dataset_config: dataset configuration proto.
        data_split: name of the data split to run on; 'test' reads from
            the 'testing' directory, anything else from 'training'.
        ckpt_indices: checkpoint indices stored on the evaluation config.

    Raises:
        ValueError: if the configured model name is not supported.
    """
    # Evaluation settings are written on the proto first, then the proto
    # is converted so the checkpoint indices can be overwritten.
    eval_config.eval_mode = 'test'
    eval_config.evaluate_repeatedly = False
    # Enable this to see the memory actually being used
    eval_config.allow_gpu_mem_growth = True
    eval_config = config_builder.proto_to_obj(eval_config)
    eval_config.ckpt_indices = ckpt_indices

    # Dataset: requested split, no labels and no augmentation
    dataset_config = config_builder.proto_to_obj(dataset_config)
    dataset_config.data_split = data_split
    dataset_config.data_split_dir = \
        'testing' if data_split == 'test' else 'training'
    dataset_config.has_labels = False
    dataset_config.aug_list = []

    dataset = DatasetBuilder.build_kitti_dataset(
        dataset_config, use_defaults=False)

    # Model: read the name, then convert to an object so the repeated
    # path-drop field can be overwritten for evaluation
    model_name = model_config.model_name
    model_config = config_builder.proto_to_obj(model_config)
    model_config.path_drop_probabilities = [1.0, 1.0]

    with tf.Graph().as_default():
        model_classes = {
            'avod_model': AvodModel,
            'rpn_model': RpnModel,
            'bev_only_rpn_model': BevOnlyRpnModel,
            'bev_only_avod_model': BevOnlyAvodModel,
        }
        if model_name not in model_classes:
            raise ValueError('Invalid model name {}'.format(model_name))

        model = model_classes[model_name](
            model_config,
            train_val_test=eval_config.eval_mode,
            dataset=dataset)

        evaluator = Evaluator(model, dataset_config, eval_config)
        evaluator.run_latest_checkpoints()
Example #9
0
def inferPerspective(model_config, eval_config, dataset_config,
                     additional_cls):
    """Run inference for one perspective and export KITTI predictions.

    Args:
        model_config: model configuration; its `paths_config.pred_dir` is
            redirected into the perspective's directory.
        eval_config: evaluation configuration; `eval_mode` is passed to
            the model as `train_val_test`.
        dataset_config: dataset configuration describing the perspective.
        additional_cls: when falsy, ground planes are estimated before
            inference; also forwarded to the prediction converter.

    Returns:
        None. Returns early when the split contains no files.

    Raises:
        ValueError: if the configured model name is not supported.
    """
    model_name = model_config.model_name

    dataset_dir = dataset_config.dataset_dir
    entity_perspect_dir = dataset_dir + dataset_config.data_split_dir + '/'

    logging.debug("Inferring perspective: %s\n %s\n %s",
                  dataset_config.data_split, entity_perspect_dir,
                  dataset_config.dataset_dir)

    files_in_range = create_split.create_split(
        dataset_config.dataset_dir,
        entity_perspect_dir,
        dataset_config.data_split)

    # Nothing falls inside cfg.MIN_IDX..cfg.MAX_IDX: skip this perspective
    if not files_in_range:
        logging.debug(
            "No files within the range cfg.MIN_IDX, cfg.MAX_IDX, skipping perspective"
        )
        return

    if not additional_cls:
        estimate_ground_planes.estimate_ground_planes(
            entity_perspect_dir, dataset_config, 0)

    dataset = DatasetBuilder.build_kitti_dataset(
        dataset_config, use_defaults=False)

    # Redirect the inference output into the perspective directory
    output_subdir = '/{}/'.format(cfg.AVOD_OUTPUT_DIR)
    model_config.paths_config.pred_dir = entity_perspect_dir + output_subdir
    logging.debug("Prediction directory: %s",
                  model_config.paths_config.pred_dir)

    with tf.Graph().as_default():
        if model_name == 'avod_model':
            model_cls = AvodModel
        elif model_name == 'rpn_model':
            model_cls = RpnModel
        else:
            raise ValueError('Invalid model name {}'.format(model_name))

        model = model_cls(model_config,
                          train_val_test=eval_config.eval_mode,
                          dataset=dataset)
        evaluator = Evaluator(model, dataset_config, eval_config)
        evaluator.run_latest_checkpoints()

    save_kitti_predictions.convertPredictionsToKitti(
        dataset, model_config.paths_config.pred_dir, additional_cls)
Example #10
0
def train(model_config, train_config, dataset_config):
    """Build the configured model and train it exactly once.

    Each supported model name is paired with the trainer that handles
    it, and that trainer is invoked a single time. Previously the
    'rpn_model', 'avod_model', 'avod_moe_model' and 'epbrm' branches
    started training inside the branch and then fell through to a second
    `trainer.train` call at the end of the function.

    Args:
        model_config: model configuration; `model_name` selects both the
            model class and the trainer.
        train_config: training configuration handed to the trainer.
        dataset_config: dataset configuration used to build the dataset.

    Raises:
        ValueError: if `model_config.model_name` is not supported.
    """
    dataset = DatasetBuilder.build_kitti_dataset(dataset_config,
                                                 use_defaults=False)

    train_val_test = 'train'
    model_name = model_config.model_name

    with tf.Graph().as_default():
        # Each branch picks the model and the matching training function;
        # training itself is started once, after the selection.
        if model_name == 'rpn_model':
            model = RpnModel(model_config,
                             train_val_test=train_val_test,
                             dataset=dataset)
            train_fn = trainer.train
        elif model_name == 'avod_model':
            model = AvodModel(model_config,
                              train_val_test=train_val_test,
                              dataset=dataset)
            train_fn = trainer.train
        elif model_name == 'avod_moe_model':
            model = AvodMoeModel(model_config,
                                 train_val_test=train_val_test,
                                 dataset=dataset)
            train_fn = trainer_moe.train
        elif model_name == 'epbrm':
            model = epBRM(model_config, dataset=dataset)
            train_fn = epbrm_trainer.train
        elif model_name == 'avod_model_new_bev':
            model = AvodModelBEV(model_config,
                                 train_val_test=train_val_test,
                                 dataset=dataset)
            train_fn = trainer_new_bev.train
        elif model_name == 'avod_model_double_fusion_new_bev':
            model = AvodModelDoubleFusionBEV(model_config,
                                             train_val_test=train_val_test,
                                             dataset=dataset)
            train_fn = trainer_new_bev.train
        else:
            raise ValueError('Invalid model_name')

        train_fn(model, train_config)
Example #11
0
def set_up_model_train_mode(pipeline_config_path, data_split):
    """Build a model together with the op that performs a training step.

    Args:
        pipeline_config_path: pipeline config file to read the model,
            training and dataset configurations from.
        data_split: value passed to the model as `train_val_test`.

    Returns:
        Tuple `(model, train_op)`.

    Raises:
        ValueError: if the configured model name is not supported.
    """
    model_config, train_config, _, dataset_config = \
        config_builder.get_configs_from_pipeline_file(
            pipeline_config_path, is_training=True)

    dataset = DatasetBuilder.build_kitti_dataset(
        dataset_config, use_defaults=False)

    model_classes = {
        'rpn_model': RpnModel,
        'avod_model': AvodModel,
        'avod_ssd_model': AvodSSDModel,
    }
    model_name = model_config.model_name
    if model_name not in model_classes:
        raise ValueError('Invalid model_name')
    model = model_classes[model_name](model_config,
                                      train_val_test=data_split,
                                      dataset=dataset)

    predictions = model.build()
    _, total_loss = model.loss(predictions)

    # The optimizer builder needs a summary set and a global step tensor
    summaries = set()
    global_step = tf.Variable(0, trainable=False)
    optimizer = optimizer_builder.build(
        train_config.optimizer, summaries, global_step)

    train_op = slim.learning.create_train_op(total_loss, optimizer)
    return model, train_op
Example #12
0
def set_up_model_test_mode(pipeline_config_path, data_split):
    """Return the model, built in test mode, and its model config.

    The dataset is constructed once, after the dataset configuration has
    been adjusted for the requested split. An earlier version also built
    a dataset from the unmodified configuration and immediately
    discarded it.

    Args:
        pipeline_config_path: pipeline config file to read the model and
            dataset configurations from.
        data_split: data split to load; 'test' reads from the 'testing'
            directory, any other value from 'training'.

    Returns:
        Tuple `(model, model_config)`.

    Raises:
        ValueError: if the configured model name is not supported.
    """
    model_config, _, _, dataset_config = \
        config_builder.get_configs_from_pipeline_file(
            pipeline_config_path, is_training=False)

    # Convert to an object so the defaults can be overwritten
    dataset_config = config_builder.proto_to_obj(dataset_config)

    # Use the requested split
    dataset_config.data_split = data_split
    dataset_config.data_split_dir = 'training'
    if data_split == 'test':
        dataset_config.data_split_dir = 'testing'

    # Remove augmentation when in test mode
    dataset_config.aug_list = []

    # Build the dataset object
    dataset = DatasetBuilder.build_kitti_dataset(dataset_config,
                                                 use_defaults=False)

    model_name = model_config.model_name
    if model_name == 'rpn_model':
        model = RpnModel(model_config, train_val_test='test', dataset=dataset)
    elif model_name == 'avod_model':
        model = AvodModel(model_config, train_val_test='test', dataset=dataset)
    elif model_name == 'avod_ssd_model':
        model = AvodSSDModel(model_config,
                             train_val_test='test',
                             dataset=dataset)
    else:
        raise ValueError('Invalid model_name')

    return model, model_config
Example #13
0
def evaluate(model_config, eval_config, dataset_config):
    """Evaluate the configured model on the configured data split.

    Args:
        model_config: model configuration proto; `model_name` must be
            'avod_model' or 'rpn_model'.
        eval_config: evaluation configuration; `eval_mode` must be 'val'
            or 'test', and `evaluate_repeatedly` selects repeated or
            single-pass evaluation.
        dataset_config: dataset configuration proto; `data_split` must be
            'train', 'test' or start with 'val'.

    Raises:
        ValueError: on an unsupported evaluation mode, data split or
            model name.
    """
    eval_mode = eval_config.eval_mode
    if eval_mode not in ('val', 'test'):
        raise ValueError('Evaluation mode can only be set to `val` or `test`')
    evaluate_repeatedly = eval_config.evaluate_repeatedly

    # Derive the split directory and label availability from the split
    data_split = dataset_config.data_split
    if data_split == 'train':
        split_dir, has_labels = 'training', True
    elif data_split.startswith('val'):
        # Labels of the val split are only loaded in 'val' mode, not when
        # running in 'test' mode
        split_dir, has_labels = 'training', eval_mode == 'val'
    elif data_split == 'test':
        split_dir, has_labels = 'testing', False
    else:
        raise ValueError('Invalid data split', data_split)

    dataset_config.data_split_dir = split_dir
    dataset_config.has_labels = has_labels

    # Repeated fields can only be overwritten on the object form
    dataset_config = config_builder.proto_to_obj(dataset_config)
    # No augmentation during evaluation
    dataset_config.aug_list = []

    dataset = DatasetBuilder.build_kitti_dataset(
        dataset_config, use_defaults=False)

    model_name = model_config.model_name
    model_config = config_builder.proto_to_obj(model_config)
    # Path-drop probabilities are fixed to 1.0 for evaluation
    model_config.path_drop_probabilities = [1.0, 1.0]

    with tf.Graph().as_default():
        if model_name == 'avod_model':
            model_cls = AvodModel
        elif model_name == 'rpn_model':
            model_cls = RpnModel
        else:
            raise ValueError('Invalid model name {}'.format(model_name))

        model = model_cls(model_config,
                          train_val_test=eval_mode,
                          dataset=dataset)
        evaluator = Evaluator(model, dataset_config, eval_config)

        if evaluate_repeatedly:
            evaluator.repeated_checkpoint_run()
        else:
            evaluator.run_latest_checkpoints()
Example #14
0
def test(model_config, eval_config,
         dataset_config, data_split,
         ckpt_indices):
    """Run inference with the newest checkpoint and display detections.

    For every sample of one epoch the model is evaluated, predictions
    below a fixed score threshold are discarded, and the remaining boxes
    are drawn on the camera image ('result' window) and on the BEV
    density map ('lidar' window). The function blocks on a key press
    after each displayed sample.

    Args:
        model_config: model configuration proto; `model_name` must be
            'avod_model' or 'rpn_model'.
        eval_config: evaluation configuration proto; forced into
            single-pass 'test' mode with GPU memory growth enabled.
        dataset_config: dataset configuration proto.
        data_split: data split to run on; 'test' reads from the 'testing'
            directory, anything else from 'training'.
        ckpt_indices: checkpoint indices stored on the evaluation config.

    Raises:
        ValueError: if the configured model name is not supported.
    """
    # Overwrite the defaults
    dataset_config = config_builder.proto_to_obj(dataset_config)

    dataset_config.data_split = data_split
    dataset_config.data_split_dir = 'training'
    if data_split == 'test':
        dataset_config.data_split_dir = 'testing'

    eval_config.eval_mode = 'test'
    eval_config.evaluate_repeatedly = False

    dataset_config.has_labels = False
    # Enable this to see the memory actually being used
    eval_config.allow_gpu_mem_growth = True

    eval_config = config_builder.proto_to_obj(eval_config)
    # Grab the checkpoint indices to evaluate
    eval_config.ckpt_indices = ckpt_indices

    # Remove augmentation during evaluation in test mode
    dataset_config.aug_list = []

    # Build the dataset object
    dataset = DatasetBuilder.build_kitti_dataset(dataset_config,
                                                 use_defaults=False)

    # Setup the model
    model_name = model_config.model_name
    # Overwrite repeated field
    model_config = config_builder.proto_to_obj(model_config)
    # Switch path drop off during evaluation
    model_config.path_drop_probabilities = [1.0, 1.0]

    with tf.Graph().as_default():
        if model_name == 'avod_model':
            model = AvodModel(model_config,
                              train_val_test=eval_config.eval_mode,
                              dataset=dataset)
        elif model_name == 'rpn_model':
            model = RpnModel(model_config,
                             train_val_test=eval_config.eval_mode,
                             dataset=dataset)
        else:
            raise ValueError('Invalid model name {}'.format(model_name))

        # Create the global step variable in the graph. The handle is not
        # needed afterwards, but the variable is created before the Saver
        # is constructed, exactly as before.
        tf.Variable(0, trainable=False, name='global_step')

        allow_gpu_mem_growth = eval_config.allow_gpu_mem_growth
        if allow_gpu_mem_growth:
            # GPU memory config
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = allow_gpu_mem_growth
            _sess = tf.Session(config=config)
        else:
            _sess = tf.Session()

        # Release the session even when inference or drawing fails
        try:
            _prediction_dict = model.build()
            _saver = tf.train.Saver()

            trainer_utils.load_checkpoints(
                model_config.paths_config.checkpoint_dir, _saver)
            num_checkpoints = len(_saver.last_checkpoints)
            print("test:", num_checkpoints)
            # Always restore the newest checkpoint
            checkpoint_to_restore = _saver.last_checkpoints[-1]

            _saver.restore(_sess, checkpoint_to_restore)

            num_samples = model.dataset.num_samples
            num_valid_samples = 0

            # Iterate until the dataset reports that the epoch is complete
            current_epoch = model.dataset.epochs_completed
            while current_epoch == model.dataset.epochs_completed:
                feed_dict = model.create_feed_dict()

                # Get sample name from model
                sample_name = model.sample_info['sample_name']

                num_valid_samples += 1
                print("Step: {} / {}, Inference on sample {}".format(
                    num_valid_samples, num_samples,
                    sample_name))

                print("test mode")
                inference_start_time = time.time()
                # Don't calculate loss or run summaries for test
                predictions = _sess.run(_prediction_dict,
                                        feed_dict=feed_dict)
                inference_time = time.time() - inference_start_time

                print("inference time:", inference_time)

                predictions_and_scores = \
                    get_avod_predicted_boxes_3d_and_scores(predictions)

                # Columns 0-6 hold the 3D box, 7 the score and 8 the
                # class index
                prediction_boxes_3d = predictions_and_scores[:, 0:7]
                prediction_scores = predictions_and_scores[:, 7]
                prediction_class_indices = predictions_and_scores[:, 8]
                gt_classes = ['Car']
                fig_size = (10, 6.1)

                avod_score_threshold = 0.1
                if len(prediction_boxes_3d) > 0:

                    # Apply score mask
                    avod_score_mask = \
                        prediction_scores >= avod_score_threshold
                    prediction_boxes_3d = \
                        prediction_boxes_3d[avod_score_mask]
                    prediction_scores = prediction_scores[avod_score_mask]
                    prediction_class_indices = \
                        prediction_class_indices[avod_score_mask]

                # Only samples with predictions left after the score
                # filter are visualised
                if len(prediction_boxes_3d) > 0:

                    dataset_dir = model.dataset.dataset_dir
                    sample_name = (model.dataset.sample_names[
                        model.dataset._index_in_epoch - 1])
                    img_idx = int(sample_name)
                    print("frame_index", img_idx)
                    image_path = model.dataset.get_rgb_image_path(
                        sample_name)
                    image = Image.open(image_path)
                    image_size = image.size

                    if model.dataset.has_labels:
                        gt_objects = obj_utils.read_labels(
                            dataset.label_dir, img_idx)
                    else:
                        gt_objects = []
                    filtered_gt_objs = \
                        model.dataset.kitti_utils.filter_labels(
                            gt_objects, classes=gt_classes)

                    stereo_calib = calib_utils.read_calibration(
                        dataset.calib_dir, img_idx)
                    calib_p2 = stereo_calib.p2
                    # Project the 3D box predictions to image space and
                    # remember which ones yielded a 2D box
                    image_filter = []
                    final_boxes_2d = []
                    for i in range(len(prediction_boxes_3d)):
                        box_3d = prediction_boxes_3d[i, 0:7]
                        img_box = box_3d_projector.project_to_image_space(
                            box_3d, calib_p2,
                            truncate=True, image_size=image_size,
                            discard_before_truncation=False)
                        if img_box is not None:
                            image_filter.append(True)
                            final_boxes_2d.append(img_box)
                        else:
                            image_filter.append(False)
                    final_boxes_2d = np.asarray(final_boxes_2d)
                    final_prediction_boxes_3d = \
                        prediction_boxes_3d[image_filter]
                    final_scores = prediction_scores[image_filter]
                    final_class_indices = \
                        prediction_class_indices[image_filter]

                    num_of_predictions = final_boxes_2d.shape[0]

                    # Convert to objs
                    final_prediction_objs = \
                        [box_3d_encoder.box_3d_to_object_label(
                            prediction, obj_type='Prediction')
                            for prediction in final_prediction_boxes_3d]
                    for (obj, score) in zip(final_prediction_objs,
                                            final_scores):
                        obj.score = score

                    pred_fig, pred_2d_axes, pred_3d_axes = \
                        vis_utils.visualization(dataset.rgb_image_dir,
                                                img_idx,
                                                display=False,
                                                fig_size=fig_size)

                    draw_predictions(filtered_gt_objs,
                                     calib_p2,
                                     num_of_predictions,
                                     final_prediction_objs,
                                     final_class_indices,
                                     final_boxes_2d,
                                     pred_2d_axes,
                                     pred_3d_axes,
                                     True,
                                     True,
                                     gt_classes,
                                     False)

                    print(type(pred_fig))
                    # Render the figure and copy its RGB buffer into an
                    # array. np.frombuffer replaces the deprecated binary
                    # mode of np.fromstring.
                    # NOTE(review): tostring_rgb is itself deprecated in
                    # recent Matplotlib releases - confirm the pinned
                    # version before upgrading.
                    pred_fig.canvas.draw()
                    img = np.frombuffer(pred_fig.canvas.tostring_rgb(),
                                        dtype=np.uint8)
                    img = img.reshape(
                        pred_fig.canvas.get_width_height()[::-1] + (3,))
                    cv2.imshow('result', img)

                    # Draw bird's eye view
                    kitti_utils = model.dataset.kitti_utils
                    print(img.shape[0:2])
                    # NOTE(review): the image shape passed here is
                    # hard-coded to (370, 1242) - confirm it matches the
                    # dataset images.
                    point_cloud = kitti_utils.get_point_cloud(
                        'lidar', img_idx, (370, 1242))
                    ground_plane = kitti_utils.get_ground_plane(sample_name)
                    bev_images = kitti_utils.create_bev_maps(
                        point_cloud, ground_plane)

                    density_map = np.array(bev_images.get("density_map"))
                    _, box_points_norm = box_3d_projector.project_to_bev(
                        final_prediction_boxes_3d, [[-40, 40], [0, 70]])
                    density_map = draw_boxes(density_map, box_points_norm)
                    cv2.imshow('lidar', density_map)
                    # Block until a key is pressed before the next sample
                    cv2.waitKey(-1)
        finally:
            _sess.close()
Example #15
0
    def test_load_model_weights(self):
        """Check that RPN weights saved by training reload into AVOD.

        An RPN model is trained for one iteration and its checkpoint is
        loaded back. The same checkpoint is then loaded into an AVOD
        model, whose leading model variables must equal the RPN weights.
        """
        mode = 'train'

        # A single iteration is enough to produce a checkpoint
        self.train_config.max_iterations = 1
        self.train_config.overwrite_checkpoints = True

        with tf.Graph().as_default():
            rpn_model = RpnModel(self.model_config,
                                 train_val_test=mode,
                                 dataset=self.dataset)
            trainer.train(rpn_model, self.train_config)

            rpn_checkpoint_dir = self.model_config.paths_config.checkpoint_dir

            # Load the weights back in
            init_op = tf.global_variables_initializer()
            saver = tf.train.Saver()
            with tf.Session() as sess:
                sess.run(init_op)

                trainer_utils.load_checkpoints(rpn_checkpoint_dir, saver)
                newest_checkpoint = saver.last_checkpoints[-1]
                trainer_utils.load_model_weights(sess, newest_checkpoint)

                rpn_weights = sess.run(slim.get_model_variables())
                self.assertGreater(len(rpn_weights), 0,
                                   msg='Loaded RPN weights are empty')

        with tf.Graph().as_default():
            avod_model = AvodModel(self.model_config,
                                   train_val_test=mode,
                                   dataset=self.dataset)
            avod_model.build()

            # Load the weights back in
            init_op = tf.global_variables_initializer()
            saver = tf.train.Saver()
            with tf.Session() as sess:
                sess.run(init_op)

                trainer_utils.load_checkpoints(rpn_checkpoint_dir, saver)
                newest_checkpoint = saver.last_checkpoints[-1]
                trainer_utils.load_model_weights(sess, newest_checkpoint)

                avod_vars = slim.get_model_variables()
                avod_weights = sess.run(avod_vars)

                # AVOD weights should include both RPN + AVOD weights
                self.assertGreater(len(avod_weights),
                                   len(rpn_weights),
                                   msg='Expected more weights for AVOD')

                # The model variables are ordered, so the RPN part is the
                # leading slice of the AVOD variables
                rpn_count = len(rpn_weights)
                reloaded_rpn_weights = sess.run(avod_vars[0:rpn_count])

                # The reloaded weights must match the originally loaded
                # weights
                for reloaded, original in zip(reloaded_rpn_weights,
                                              rpn_weights):
                    np.testing.assert_array_equal(reloaded, original)