Code example #1
    def eval_qbs(self, phocnet_bin_path, train_xml_file, test_xml_file,
                 phoc_unigram_levels, gpu_id, debug_mode, doc_img_dir,
                 deploy_proto_path, metric, annotation_delimiter, no_bigrams):
        self.logger.info('--- Query-by-String Evaluation ---')
        train_list = self._load_word_list_from_xml(train_xml_file, doc_img_dir)
        test_list = self._load_word_list_from_xml(test_xml_file, doc_img_dir)

        phoc_unigrams = unigrams_from_word_list(
            word_list=train_list, split_character=annotation_delimiter)
        phoc_size = np.sum(phoc_unigram_levels) * len(phoc_unigrams)
        if no_bigrams:
            n_bigrams = 0
            bigrams = None
            bigram_levels = None
        else:
            n_bigrams = 50
            bigrams = get_most_common_n_grams(
                words=[word.get_transcription() for word in train_list],
                num_results=n_bigrams,
                n=2)
            bigram_levels = [2]
            phoc_size += 100  # = n_bigrams * np.sum(bigram_levels) = 50 * 2

        # Set CPU/GPU mode
        if gpu_id is not None:
            self.logger.info('Setting Caffe to GPU mode using device %d',
                             gpu_id)
            caffe.set_mode_gpu()
            caffe.set_device(gpu_id)
        else:
            self.logger.info('Setting Caffe to CPU mode')
            caffe.set_mode_cpu()
        phocnet = self._load_pretrained_phocnet(phocnet_bin_path, gpu_id,
                                                debug_mode, deploy_proto_path,
                                                phoc_size)
        self.logger.info('Predicting PHOCs for %d test words', len(test_list))
        test_phocs = self._net_output_for_word_list(
            word_list=test_list,
            cnn=phocnet,
            suppress_caffe_output=not debug_mode)
        test_strings = [word.get_transcription() for word in test_list]
        qry_strings = list(sorted(set(test_strings)))
        qry_phocs = build_phoc(words=qry_strings,
                               phoc_unigrams=phoc_unigrams,
                               unigram_levels=phoc_unigram_levels,
                               split_character=annotation_delimiter,
                               phoc_bigrams=bigrams,
                               bigram_levels=bigram_levels)
        self.logger.info('Calculating mAP...')
        mean_ap, _ = map_from_query_test_feature_matrices(
            query_features=qry_phocs,
            test_features=test_phocs,
            query_labels=qry_strings,
            test_labels=test_strings,
            metric=metric,
            drop_first=False)
        self.logger.info('mAP: %f', mean_ap * 100)
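The phoc_size arithmetic in eval_qbs follows directly from the PHOC layout: one histogram bin per unigram in every split region of the pyramid, plus one bin per bigram in every bigram region, which is where the hard-coded += 100 (50 bigrams times 2 regions) comes from. A small self-contained sanity check, assuming the 37 unigrams and levels 1-5 listed in the next example (nothing here calls the phocnet library):

import numpy as np

unigram_levels = [1, 2, 3, 4, 5]
n_unigrams = 37            # '-', '0'-'9' and 'a'-'z', as listed in the next example
bigram_levels = [2]
n_bigrams = 50             # most common bigrams, as in eval_qbs above

phoc_size = np.sum(unigram_levels) * n_unigrams     # 15 * 37 = 555
phoc_size += np.sum(bigram_levels) * n_bigrams      # + 2 * 50 = 100, the hard-coded "+= 100"
print(phoc_size)                                    # 655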
Code example #2
def main():
    # IAM PHOC unigrams
    unigrams = ['-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', \
                'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', \
                'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', \
                'w', 'x', 'y', 'z']

    # get words from dictionary
    with open('./dictionary-words.txt', 'r') as handle:
        words = sorted([line.rstrip() for line in handle])

    # build PHOCs for dictionary using IAM PHOC parameters
    phocs = build_phoc(words=words,
                       phoc_unigrams=unigrams,
                       unigram_levels=[1, 2, 3, 4, 5])

    # save dictionary to file
    with open('./dictionary-phocs.pickle', 'wb') as handle:
        pickle.dump(dict(zip(words, phocs)), handle)
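Once main() has written dictionary-phocs.pickle, the word-to-PHOC mapping can serve as a lexicon for retrieval. A minimal sketch of how such a file could be consumed, assuming the stored PHOCs are numpy-compatible vectors and ranking by cosine similarity (rank_lexicon is an illustrative helper, not part of the phocnet API):

import pickle
import numpy as np

def rank_lexicon(query_phoc, phoc_dict_path='./dictionary-phocs.pickle'):
    # Load the word -> PHOC mapping written by main() above.
    with open(phoc_dict_path, 'rb') as handle:
        word_to_phoc = pickle.load(handle)
    words = sorted(word_to_phoc.keys())
    phocs = np.vstack([np.asarray(word_to_phoc[w], dtype=np.float32) for w in words])
    query = np.asarray(query_phoc, dtype=np.float32).ravel()
    # Cosine similarity between the query PHOC and every lexicon PHOC.
    sims = phocs.dot(query) / (np.linalg.norm(phocs, axis=1) * np.linalg.norm(query) + 1e-8)
    order = np.argsort(-sims)
    return [(words[i], float(sims[i])) for i in order]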
Code example #3
File: phocnet_evaluator.py Project: ssudholt/phocnet
 def eval_qbs(self, phocnet_bin_path, train_xml_file, test_xml_file, phoc_unigram_levels, 
              gpu_id, debug_mode, doc_img_dir, deploy_proto_path, metric, 
              annotation_delimiter, no_bigrams):
     self.logger.info('--- Query-by-String Evaluation ---')
     train_list = self._load_word_list_from_xml(train_xml_file, doc_img_dir)
     test_list = self._load_word_list_from_xml(test_xml_file, doc_img_dir)
     
     phoc_unigrams = unigrams_from_word_list(word_list=train_list, split_character=annotation_delimiter)
     phoc_size = np.sum(phoc_unigram_levels)*len(phoc_unigrams)
     if no_bigrams:
         n_bigrams = 0         
         bigrams = None
         bigram_levels = None
     else:
         n_bigrams = 50
         bigrams = get_most_common_n_grams(words=[word.get_transcription() 
                                                  for word in train_list], 
                                           num_results=n_bigrams, n=2)
         bigram_levels = [2]
         phoc_size += 100
     
     # Set CPU/GPU mode
     if gpu_id is not None:
         self.logger.info('Setting Caffe to GPU mode using device %d', gpu_id)
         caffe.set_mode_gpu()
         caffe.set_device(gpu_id)
     else:
         self.logger.info('Setting Caffe to CPU mode')
         caffe.set_mode_cpu()
     phocnet = self._load_pretrained_phocnet(phocnet_bin_path, gpu_id, debug_mode, 
                                             deploy_proto_path, phoc_size)
     self.logger.info('Predicting PHOCs for %d test words', len(test_list))
     test_phocs = self._net_output_for_word_list(word_list=test_list, cnn=phocnet, 
                                                 suppress_caffe_output=not debug_mode)
     test_strings = [word.get_transcription() for word in test_list] 
     qry_strings = list(sorted(set(test_strings)))
     qry_phocs = build_phoc(words=qry_strings, phoc_unigrams=phoc_unigrams, unigram_levels=phoc_unigram_levels, 
                            split_character=annotation_delimiter, phoc_bigrams=bigrams, bigram_levels=bigram_levels)
     self.logger.info('Calculating mAP...')
     mean_ap, _ = map_from_query_test_feature_matrices(query_features=qry_phocs, test_features=test_phocs, query_labels=qry_strings, 
                                                       test_labels=test_strings, metric=metric, drop_first=False)
     self.logger.info('mAP: %f', mean_ap*100)
Code example #4
 def train_phocnet(self):
     self.logger.info('--- Running PHOCNet Training ---')
     # --- Step 1: check if we need to create the LMDBs
     # load the word lists
     xml_reader = XMLReader(make_lower_case=self.use_lower_case_only)
     self.dataset_name, train_list, test_list = xml_reader.load_train_test_xml(train_xml_path=self.train_annotation_file, 
                                                                               test_xml_path=self.test_annotation_file, 
                                                                               img_dir=self.doc_img_dir)
     phoc_unigrams = unigrams_from_word_list(word_list=train_list, split_character=self.annotation_delimiter)
     self.logger.info('PHOC unigrams: %s', ' '.join(phoc_unigrams))
     self.test_iter = len(test_list)
     self.logger.info('Using dataset \'%s\'', self.dataset_name)
     
     # check if we need to create LMDBs
     lmdb_prefix = '%s_nti%d_pul%s' % (self.dataset_name, self.n_train_images,
                                       '-'.join([str(elem) for elem in self.phoc_unigram_levels]))
     train_word_images_lmdb_path = os.path.join(self.lmdb_dir, '%s_train_word_images_lmdb' % lmdb_prefix)
     train_phoc_lmdb_path = os.path.join(self.lmdb_dir, '%s_train_phocs_lmdb' % lmdb_prefix)
     test_word_images_lmdb_path = os.path.join(self.lmdb_dir, '%s_test_word_images_lmdb' % lmdb_prefix)
     test_phoc_lmdb_path = os.path.join(self.lmdb_dir, '%s_test_phocs_lmdb' % lmdb_prefix)
     lmdbs_exist = (os.path.exists(train_word_images_lmdb_path),
                    os.path.exists(train_phoc_lmdb_path),
                    os.path.exists(test_word_images_lmdb_path),
                    os.path.exists(test_phoc_lmdb_path))
                     
     if self.use_bigrams:
         n_bigrams = 50
         bigrams = get_most_common_n_grams(words=[word.get_transcription() 
                                                  for word in train_list], 
                                           num_results=n_bigrams, n=2)
         bigram_levels = [2]
     else:       
         n_bigrams = 0         
         bigrams = None
         bigram_levels = None        
     if not np.all(lmdbs_exist) or self.recreate_lmdbs:     
         self.logger.info('Creating LMDBs...')

         train_phocs = build_phoc(words=[word.get_transcription() for word in train_list], 
                                  phoc_unigrams=phoc_unigrams, unigram_levels=self.phoc_unigram_levels,
                                  phoc_bigrams=bigrams, bigram_levels=bigram_levels,
                                  split_character=self.annotation_delimiter,
                                  on_unknown_unigram='warn')
         test_phocs = build_phoc(words=[word.get_transcription() for word in test_list],
                                 phoc_unigrams=phoc_unigrams, unigram_levels=self.phoc_unigram_levels,
                                 phoc_bigrams=bigrams, bigram_levels=bigram_levels,
                                 split_character=self.annotation_delimiter,
                                 on_unknown_unigram='warn')
         self._create_train_test_phocs_lmdbs(train_list=train_list, train_phocs=train_phocs, 
                                             test_list=test_list, test_phocs=test_phocs,
                                             train_word_images_lmdb_path=train_word_images_lmdb_path,
                                             train_phoc_lmdb_path=train_phoc_lmdb_path,
                                             test_word_images_lmdb_path=test_word_images_lmdb_path,
                                             test_phoc_lmdb_path=test_phoc_lmdb_path)
     else:
         self.logger.info('Found LMDBs...')
     
     # --- Step 2: create the proto files
     self.logger.info('Saving proto files...')
     # prepare the output paths
     train_proto_path = os.path.join(self.proto_dir, 'train_phocnet_%s.prototxt' % self.dataset_name)
     test_proto_path = os.path.join(self.proto_dir, 'test_phocnet_%s.prototxt' % self.dataset_name)
     solver_proto_path = os.path.join(self.proto_dir, 'solver_phocnet_%s.prototxt' % self.dataset_name)
     
     # generate the proto files
     n_attributes = np.sum(self.phoc_unigram_levels)*len(phoc_unigrams)
     if self.use_bigrams:
         n_attributes += np.sum(bigram_levels)*n_bigrams
     mpg = ModelProtoGenerator(initialization='msra', use_cudnn_engine=self.gpu_id is not None)        
     train_proto = mpg.get_phocnet(word_image_lmdb_path=train_word_images_lmdb_path, phoc_lmdb_path=train_phoc_lmdb_path, 
                                   phoc_size=n_attributes, 
                                   generate_deploy=False)
     test_proto = mpg.get_phocnet(word_image_lmdb_path=test_word_images_lmdb_path, phoc_lmdb_path=test_phoc_lmdb_path, 
                                  phoc_size=n_attributes, generate_deploy=False)
     solver_proto = generate_solver_proto(train_net=train_proto_path, test_net=test_proto_path,
                                          base_lr=self.learning_rate, momentum=self.momentum, display=self.display,
                                          lr_policy='step', gamma=self.gamma, stepsize=self.step_size,
                                          solver_mode=self.solver_mode, iter_size=self.batch_size, max_iter=self.max_iter,
                                          average_loss=self.display, test_iter=self.test_iter, test_interval=self.test_interval,
                                          weight_decay=self.weight_decay)
     # save the proto files
     save_prototxt(file_path=train_proto_path, proto_object=train_proto, header_comment='Train PHOCNet %s' % self.dataset_name)
     save_prototxt(file_path=test_proto_path, proto_object=test_proto, header_comment='Test PHOCNet %s' % self.dataset_name)
     save_prototxt(file_path=solver_proto_path, proto_object=solver_proto, header_comment='Solver PHOCNet %s' % self.dataset_name)
     
     # --- Step 3: train the PHOCNet
     self.logger.info('Starting SGD...')
     self._run_sgd(solver_proto_path=solver_proto_path)
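The lmdb_prefix built in train_phocnet encodes the dataset name, the number of training images and the unigram pyramid levels, so runs with different settings never collide on the same LMDB directories. A quick illustration with made-up values (the dataset name, image count and lmdb_dir are hypothetical):

import os

dataset_name = 'iam'                    # hypothetical
n_train_images = 500000                 # hypothetical
phoc_unigram_levels = [1, 2, 3, 4, 5]
lmdb_dir = '/tmp/lmdbs'                 # hypothetical

lmdb_prefix = '%s_nti%d_pul%s' % (dataset_name, n_train_images,
                                  '-'.join(str(level) for level in phoc_unigram_levels))
print(lmdb_prefix)                                                        # iam_nti500000_pul1-2-3-4-5
print(os.path.join(lmdb_dir, '%s_train_word_images_lmdb' % lmdb_prefix))  # /tmp/lmdbs/iam_nti500000_pul1-2-3-4-5_train_word_images_lmdb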
Code example #5
    def _load_data_from_xml(self):
        # --- Step 1: check if we need to create the LMDBs
        # load the word lists
        xml_reader = XMLReader(make_lower_case=self.use_lower_case_only)
        data = xml_reader.load_train_test_xml(
            train_xml_path=self.train_annotation_file,
            test_xml_path=self.test_annotation_file,
            img_dir=self.doc_img_dir)
        self.dataset_name, train_list, test_list = data
        phoc_unigrams = unigrams_from_word_list(
            word_list=train_list, split_character=self.annotation_delimiter)
        self.logger.info('PHOC unigrams: %s', ' '.join(phoc_unigrams))
        self.test_iter = len(test_list)
        self.logger.info('Using dataset \'%s\'', self.dataset_name)

        # compute PHOC size
        if self.use_bigrams:
            n_bigrams = 50
            bigrams = get_most_common_n_grams(
                words=[word.get_transcription() for word in train_list],
                num_results=n_bigrams,
                n=2)
            bigram_levels = [2]
        else:
            n_bigrams = 0
            bigrams = None
            bigram_levels = None

        n_attributes = np.sum(self.phoc_unigram_levels) * len(phoc_unigrams)
        if self.use_bigrams:
            n_attributes += np.sum(bigram_levels) * n_bigrams

        # check if we need to create LMDBs
        lmdb_paths = self._create_lmdb_paths()

        lmdbs_exist = (os.path.exists(lmdb_paths[0]),
                       os.path.exists(lmdb_paths[1]),
                       os.path.exists(lmdb_paths[2]),
                       os.path.exists(lmdb_paths[3]))

        if not np.all(lmdbs_exist) or self.recreate_lmdbs:
            if not self.recreate_lmdbs:
                raise ValueError("LMDBs missing: {} - {}".format(
                    lmdb_paths[0], lmdb_paths[2]))
            self.logger.info('Creating LMDBs...')
            train_phocs = build_phoc(
                words=[word.get_transcription() for word in train_list],
                phoc_unigrams=phoc_unigrams,
                unigram_levels=self.phoc_unigram_levels,
                phoc_bigrams=bigrams,
                bigram_levels=bigram_levels,
                split_character=self.annotation_delimiter,
                on_unknown_unigram='warn')
            test_phocs = build_phoc(
                words=[word.get_transcription() for word in test_list],
                phoc_unigrams=phoc_unigrams,
                unigram_levels=self.phoc_unigram_levels,
                phoc_bigrams=bigrams,
                bigram_levels=bigram_levels,
                split_character=self.annotation_delimiter,
                on_unknown_unigram='warn')
            self._create_train_test_phocs_lmdbs(
                train_list=train_list,
                train_phocs=train_phocs,
                test_list=test_list,
                test_phocs=test_phocs,
                train_word_images_lmdb_path=lmdb_paths[0],
                train_phoc_lmdb_path=lmdb_paths[1],
                test_word_images_lmdb_path=lmdb_paths[2],
                test_phoc_lmdb_path=lmdb_paths[3])
        else:
            self.logger.info('Found LMDBs...')

        return lmdb_paths, n_attributes
Code example #6
    def eval_qbs(self, phocnet_bin_path, train_xml_file, test_xml_file, phoc_unigram_levels, 
                 gpu_id, debug_mode, doc_img_dir, deploy_proto_path, metric, 
                 annotation_delimiter, no_bigrams, dense_net,
                 min_image_height, min_image_width, max_pixel, image_size, protocol):
        
        self.min_image_height = min_image_height
        self.min_image_width = min_image_width
        self.max_pixel = max_pixel
        
        self.logger.info('--- Query-by-String Evaluation ---')
        if protocol == 'almazan':
            train_list = self._load_word_list_from_xml(train_xml_file, doc_img_dir)
            test_list = self._load_word_list_from_xml(test_xml_file, doc_img_dir)
        else:
            self.logger.info('--- Use Botany Evaluation protocol ---')
            train_list, test_list, qry_list = DatasetLoader.load_icfhr2016_competition(dataset_name='botany',
                                                                                       train_set='Train_III',
                                                                                       path='/vol/corpora/document-image-analysis/competition_icfhr2016/')
        
        
        phoc_unigrams = unigrams_from_word_list(word_list=train_list, split_character=annotation_delimiter)
        phoc_size = np.sum(phoc_unigram_levels)*len(phoc_unigrams)
        if no_bigrams:
            n_bigrams = 0
            bigrams = None
            bigram_levels = None
        else:
            n_bigrams = 50
            bigrams = get_most_common_n_grams(words=[word.get_transcription() for word in train_list], 
                                                  num_results=n_bigrams, n=2)
            bigram_levels = [2]
            phoc_size += 100
        
        if dense_net is None:
            phocnet = self._load_pretrained_phocnet(phocnet_bin_path, gpu_id, debug_mode, 
                                                    deploy_proto_path, phoc_size)
        else:
            self.logger.info('--- Load DenseNet ---')
            phocnet = self._load_pretrained_dense_net(phocnet_bin_path, gpu_id, dense_net, debug_mode)
        
        # Set CPU/GPU mode
        if gpu_id is not None:
            self.logger.info('Setting Caffe to GPU mode using device %d', gpu_id)
            caffe.set_mode_gpu()
            caffe.set_device(gpu_id)
        else:
            self.logger.info('Setting Caffe to CPU mode')
            caffe.set_mode_cpu()

        if protocol == 'almazan':
            self.logger.info('Predicting PHOCs for %d test words', len(test_list))
            test_phocs = self._net_output_for_word_list(word_list=test_list, cnn=phocnet,
                                                        image_size=image_size, max_pixel=max_pixel,
                                                        suppress_caffe_output=not debug_mode)
            test_strings = [word.get_transcription() for word in test_list] 
            qry_strings = list(sorted(set(test_strings)))
            qry_phocs = build_phoc(words=qry_strings, phoc_unigrams=phoc_unigrams, unigram_levels=phoc_unigram_levels, 
                                   split_character=annotation_delimiter, phoc_bigrams=bigrams, bigram_levels=bigram_levels)
            self.logger.info('Calculating mAP...')
            mean_ap, _ = map_from_query_test_feature_matrices(query_features=qry_phocs, test_features=test_phocs, query_labels=qry_strings, 
                                                              test_labels=test_strings, metric=metric, drop_first=False)
        else:
            qry_strings = [word.get_transcription() for word in qry_list]
            if not np.all([type(elem) == str or type(elem) == unicode for elem in qry_strings]):
                raise ValueError('query-test protocol needs a query list of strings only')

            self.logger.info('Predicting PHOCs for %d test words', len(test_list))
            
            test_phocs = self._net_output_for_word_list(word_list=test_list, cnn=phocnet,
                                                        image_size=image_size, max_pixel=max_pixel,
                                                        suppress_caffe_output=not debug_mode)
            
            transcriptions = [word.get_transcription() for word in test_list]
            
            qry_phocs = build_phoc(words=qry_strings, phoc_unigrams=phoc_unigrams, unigram_levels=phoc_unigram_levels,
                                   split_character=annotation_delimiter, phoc_bigrams=bigrams, bigram_levels=bigram_levels)
            
            self.logger.info('Calculating mAP...')
            mean_ap, _ = map_from_query_test_feature_matrices(query_features=qry_phocs, test_features=test_phocs, query_labels=qry_strings,
                                                                test_labels=transcriptions, metric=metric, drop_first=False)
        
        
        self.logger.info('mAP: %f', mean_ap*100)
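map_from_query_test_feature_matrices boils down to a standard ranked-retrieval mAP over PHOC distances. Below is a minimal, self-contained sketch of that metric (an illustration, not the phocnet implementation; it assumes numpy arrays and scipy). The examples above pass drop_first=False, which fits query-by-string: the query PHOC is synthesized from the string and is not itself a test item, so the top-ranked hit is kept.

import numpy as np
from scipy.spatial.distance import cdist

def mean_average_precision(query_features, test_features, query_labels, test_labels,
                           metric='cosine'):
    # Distance from every query PHOC to every test PHOC.
    dists = cdist(query_features, test_features, metric=metric)
    test_labels = np.asarray(test_labels)
    average_precisions = []
    for row, query_label in zip(dists, query_labels):
        order = np.argsort(row)                        # rank test items by distance
        relevant = (test_labels[order] == query_label).astype(np.float32)
        if relevant.sum() == 0:
            continue                                   # query with no relevant test item
        precision = np.cumsum(relevant) / np.arange(1, len(relevant) + 1)
        average_precisions.append((precision * relevant).sum() / relevant.sum())
    return float(np.mean(average_precisions))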
Code example #7
    def train_phocnet(self):
        self.logger.info('--- Running PHOCNet Prune Training ---')
        # --- Step 1: check if we need to create the LMDBs
        # load the word lists
        if self.train_annotation_file is not None and self.test_annotation_file is not None:
            xml_reader = XMLReader(make_lower_case=self.use_lower_case_only)
            self.dataset_name, train_list, test_list = xml_reader.load_train_test_xml(
                train_xml_path=self.train_annotation_file,
                test_xml_path=self.test_annotation_file,
                img_dir=self.doc_img_dir)
        elif self.dataset_name is not None:
            train_list, test_list, qry_list = DatasetLoader.load_icfhr2016_competition(
                'botany',
                train_set='Train_III',
                path=
                '/vol/corpora/document-image-analysis/competition_icfhr2016/')
        else:
            self.logger.info('Annotation missing')
            raise ValueError('Annotation missing: neither train/test XML files nor a dataset name were given')

        phoc_unigrams = unigrams_from_word_list(
            word_list=train_list, split_character=self.annotation_delimiter)
        self.logger.info('PHOC unigrams: %s', ' '.join(phoc_unigrams))
        self.test_iter = len(test_list)
        self.logger.info('Using dataset \'%s\'', self.dataset_name)

        lmdb_prefix = '%s_nti%d_pul%s' % (
            self.dataset_name, self.n_train_images, '-'.join(
                [str(elem) for elem in self.phoc_unigram_levels]))
        train_word_images_lmdb_path = os.path.join(
            self.lmdb_dir, '%s_train_word_images_lmdb' % lmdb_prefix)
        train_phoc_lmdb_path = os.path.join(
            self.lmdb_dir, '%s_train_phocs_lmdb' % lmdb_prefix)
        test_word_images_lmdb_path = os.path.join(
            self.lmdb_dir, '%s_test_word_images_lmdb' % lmdb_prefix)
        test_phoc_lmdb_path = os.path.join(self.lmdb_dir,
                                           '%s_test_phocs_lmdb' % lmdb_prefix)

        # check if we need to create LMDBs
        lmdbs_exist = (os.path.exists(train_word_images_lmdb_path),
                       os.path.exists(train_phoc_lmdb_path),
                       os.path.exists(test_word_images_lmdb_path),
                       os.path.exists(test_phoc_lmdb_path))

        if self.use_bigrams:
            n_bigrams = 50
            bigrams = get_most_common_n_grams(
                words=[word.get_transcription() for word in train_list],
                num_results=n_bigrams,
                n=2)
            bigram_levels = [2]
        else:
            n_bigrams = 0
            bigrams = None
            bigram_levels = None
        if not np.all(lmdbs_exist) or self.recreate_lmdbs:
            self.logger.info('Creating LMDBs...')

            train_phocs = build_phoc(
                words=[word.get_transcription() for word in train_list],
                phoc_unigrams=phoc_unigrams,
                unigram_levels=self.phoc_unigram_levels,
                phoc_bigrams=bigrams,
                bigram_levels=bigram_levels,
                split_character=self.annotation_delimiter,
                on_unknown_unigram='warn')
            test_phocs = build_phoc(
                words=[word.get_transcription() for word in test_list],
                phoc_unigrams=phoc_unigrams,
                unigram_levels=self.phoc_unigram_levels,
                phoc_bigrams=bigrams,
                bigram_levels=bigram_levels,
                split_character=self.annotation_delimiter,
                on_unknown_unigram='warn')
            self._create_train_test_phocs_lmdbs(
                train_list=train_list,
                train_phocs=train_phocs,
                test_list=test_list,
                test_phocs=test_phocs,
                train_word_images_lmdb_path=train_word_images_lmdb_path,
                train_phoc_lmdb_path=train_phoc_lmdb_path,
                test_word_images_lmdb_path=test_word_images_lmdb_path,
                test_phoc_lmdb_path=test_phoc_lmdb_path,
                image_size=self.image_size)

        else:
            self.logger.info('Found LMDBs...')

        # --- Step 2: create the proto files
        self.logger.info('Saving proto files...')
        # prepare the output paths

        name = self.get_net_name()

        train_proto_path = os.path.join(
            self.proto_dir, 'train_%s_%s.prototxt' % (name, self.dataset_name))
        test_proto_path = os.path.join(
            self.proto_dir, 'test_%s_%s.prototxt' % (name, self.dataset_name))
        solver_proto_path = os.path.join(
            self.proto_dir,
            'solver_%s_%s.prototxt' % (name, self.dataset_name))

        # generate the proto files
        n_attributes = np.sum(self.phoc_unigram_levels) * len(phoc_unigrams)
        if self.use_bigrams:
            n_attributes += np.sum(bigram_levels) * n_bigrams
        mpg = ModelProtoGenerator(initialization='msra',
                                  use_cudnn_engine=self.gpu_id is not None)

        if self.architecture == 'dense':
            train_proto = mpg.get_dense_phocnet(
                word_image_lmdb_path=train_word_images_lmdb_path,
                phoc_lmdb_path=train_phoc_lmdb_path,
                phoc_size=n_attributes,
                generate_deploy=False,
                nblocks=self.nblocks,
                growth_rate=self.growth_rate,
                nlayers=self.nlayers,
                config=self.config,
                no_batch_normalization=self.no_batch_normalization,
                use_bottleneck=self.use_bottleneck,
                use_compression=self.use_compression,
                pool_init=self.pool_init,
                conv_init=self.conv_init,
                init_7=self.init_7,
                pooling=self.pooling,
                max_out=self.max_out)

            test_proto = mpg.get_dense_phocnet(
                word_image_lmdb_path=test_word_images_lmdb_path,
                phoc_lmdb_path=test_phoc_lmdb_path,
                phoc_size=n_attributes,
                generate_deploy=False,
                nblocks=self.nblocks,
                growth_rate=self.growth_rate,
                nlayers=self.nlayers,
                config=self.config,
                no_batch_normalization=self.no_batch_normalization,
                use_bottleneck=self.use_bottleneck,
                use_compression=self.use_compression,
                pool_init=self.pool_init,
                conv_init=self.conv_init,
                init_7=self.init_7,
                pooling=self.pooling,
                max_out=self.max_out)
        elif self.architecture == 'phoc':
            train_proto = mpg.get_phocnet(
                word_image_lmdb_path=train_word_images_lmdb_path,
                phoc_lmdb_path=train_phoc_lmdb_path,
                phoc_size=n_attributes,
                pooling=self.pooling,
                generate_deploy=False,
                max_out=self.max_out)
            test_proto = mpg.get_phocnet(
                word_image_lmdb_path=test_word_images_lmdb_path,
                phoc_lmdb_path=test_phoc_lmdb_path,
                phoc_size=n_attributes,
                pooling=self.pooling,
                generate_deploy=False,
                max_out=self.max_out)
        elif self.architecture == 'hybrid':
            train_proto = mpg.get_hybrid_phocnet(
                word_image_lmdb_path=train_word_images_lmdb_path,
                phoc_lmdb_path=train_phoc_lmdb_path,
                phoc_size=n_attributes,
                generate_deploy=False,
                nblocks=self.nblocks,
                growth_rate=self.growth_rate,
                nlayers=self.nlayers,
                config=self.config,
                no_batch_normalization=self.no_batch_normalization,
                use_bottleneck=self.use_bottleneck,
                use_compression=self.use_compression,
                pooling=self.pooling,
                max_out=self.max_out)

            test_proto = mpg.get_hybrid_phocnet(
                word_image_lmdb_path=test_word_images_lmdb_path,
                phoc_lmdb_path=test_phoc_lmdb_path,
                phoc_size=n_attributes,
                generate_deploy=False,
                nblocks=self.nblocks,
                growth_rate=self.growth_rate,
                nlayers=self.nlayers,
                config=self.config,
                no_batch_normalization=self.no_batch_normalization,
                use_bottleneck=self.use_bottleneck,
                use_compression=self.use_compression,
                pooling=self.pooling,
                max_out=self.max_out)

        else:
            raise ValueError('Unknown Architecture!')

        # save the proto files
        save_prototxt(file_path=train_proto_path,
                      proto_object=train_proto,
                      header_comment='Train PHOCNet %s' % self.dataset_name)
        save_prototxt(file_path=test_proto_path,
                      proto_object=test_proto,
                      header_comment='Test PHOCNet %s' % self.dataset_name)

        solver_proto = generate_solver_proto(train_net=train_proto_path,
                                             test_net=test_proto_path,
                                             base_lr=self.learning_rate,
                                             momentum=self.momentum,
                                             display=self.display,
                                             lr_policy='step',
                                             gamma=self.gamma,
                                             stepsize=self.step_size,
                                             solver_mode=self.solver_mode,
                                             iter_size=self.batch_size,
                                             max_iter=self.max_iter,
                                             average_loss=self.display,
                                             test_iter=self.test_iter,
                                             test_interval=self.test_interval,
                                             weight_decay=self.weight_decay)

        save_prototxt(file_path=solver_proto_path,
                      proto_object=solver_proto,
                      header_comment='Solver PHOCNet %s' % self.dataset_name)

        # --- Step 3: train the PHOCNet
        self.logger.info('Starting SGD')
        solver = self._run_sgd(solver_proto_path=solver_proto_path)
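generate_solver_proto above is configured with Caffe's 'step' learning-rate policy, so the effective rate equals base_lr multiplied by gamma once every stepsize iterations. A short sketch of that schedule (the numeric values are illustrative, not the training defaults):

def step_lr(base_lr, gamma, stepsize, iteration):
    # Caffe 'step' policy: lr = base_lr * gamma ** floor(iteration / stepsize)
    return base_lr * (gamma ** (iteration // stepsize))

for it in (0, 30000, 70000, 140000):
    print(it, step_lr(base_lr=1e-4, gamma=0.1, stepsize=70000, iteration=it))
# 0 and 30000 -> 1e-4, 70000 -> 1e-5, 140000 -> 1e-6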