コード例 #1
0
    def __init__(self, i_net):
        """Build the backbone selected by ``i_net`` and load the training data.

        The roidb is always the combined ``"voc_2007_trainval"`` set,
        regardless of the backbone.

        Args:
            i_net: Backbone identifier. One of ``'vgg16'``, ``'res50'``,
                ``'res101'`` or ``'res152'``. Stored as ``self.netname``.

        Raises:
            NotImplementedError: If ``i_net`` is not a supported backbone.
        """
        # Create network
        self.netname = i_net
        # Report the requested backbone. Previously this printed the undefined
        # name ``net``, which raised NameError on every construction.
        print('Train: cfg.FLAGS.network', i_net)
        if i_net == 'vgg16':
            self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch)
            self.pre_model = cfg.FLAGS.pretrained_model
        elif i_net == 'res50':
            self.net = resnetv1(num_layers=50)
            self.pre_model = cfg.FLAGS.pretrained_resnet50_model
        elif i_net == 'res101':
            self.net = resnetv1(num_layers=101)
            self.pre_model = cfg.FLAGS.pretrained_resnet101_model
        elif i_net == 'res152':
            self.net = resnetv1(num_layers=152)
            self.pre_model = cfg.FLAGS.pretrained_resnet152_model
        else:
            raise NotImplementedError('Unsupported network: {}'.format(i_net))

        print('self.net', self.net)
        self.imdb, self.roidb = combined_roidb("voc_2007_trainval")

        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')
コード例 #2
0
ファイル: train.py プロジェクト: IAMLabUMD/tpami2020
    def __init__(self, dataset):
        """Create the VGG16 network and load the roidb for ``dataset``.

        The known ``tego*`` keys are loaded through ``load_db`` under their
        split names. Any other key falls back to the combined
        ``"voc_2007_trainval"`` roidb.

        Args:
            dataset: Dataset key. Also used as the output-directory tag.

        Raises:
            NotImplementedError: If ``cfg.FLAGS.network`` is not ``'vgg16'``.
        """
        # Only the VGG16 backbone is handled by this trainer.
        if cfg.FLAGS.network != 'vgg16':
            raise NotImplementedError
        self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch)

        # Dataset key -> split name understood by load_db.
        tego_splits = {
            "tego": "tego_train",
            "tego_wholeBB": "tego_train-wholeBB",
            "tego_blind": "tego_train-blind",
            "tego_sighted": "tego_train-sighted",
            "tego_blind_wholeBB": "tego_train-blind-wholeBB",
            "tego_sighted_wholeBB": "tego_train-sighted-wholeBB",
        }
        split_name = tego_splits.get(dataset)
        if split_name is not None:
            self.imdb, self.roidb = load_db(split_name)
        else:
            self.imdb, self.roidb = combined_roidb("voc_2007_trainval")

        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(self.imdb, dataset)
コード例 #3
0
ファイル: train.py プロジェクト: Joekma/work_notes
    def __init__(self):
        """Create the VGG16 network and load the ``geetcodechinese_2019_train`` roidb.

        Raises:
            NotImplementedError: If ``cfg.FLAGS.network`` is not ``'vgg16'``.
        """
        if cfg.FLAGS.network != 'vgg16':
            raise NotImplementedError
        self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch)

        imdb_name = "geetcodechinese_2019_train"
        self.imdb, self.roidb = combined_roidb(imdb_name)
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')
コード例 #4
0
    def __init__(self):
        """Create a 50-layer ResNet v1 and load the ``voc_2007_trainval`` roidb.

        Raises:
            NotImplementedError: If ``cfg.FLAGS.network`` is not ``'resnet_v1'``.
        """
        if cfg.FLAGS.network != 'resnet_v1':
            raise NotImplementedError
        self.net = resnetv1(batch_size=cfg.FLAGS.ims_per_batch, num_layers=50)

        imdb_name = "voc_2007_trainval"
        self.imdb, self.roidb = combined_roidb(imdb_name)

        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')
コード例 #5
0
    def __init__(self):
        """Create the backbone named by ``cfg.FLAGS.net`` and load the ``Columbia`` roidb.

        Raises:
            NotImplementedError: If ``cfg.FLAGS.net`` is neither ``'vgg16'``
                nor ``'resnetv1'``.
        """
        # Choose the backbone constructor first. Both take the same batch_size
        # argument, so it is applied once below.
        net_name = cfg.FLAGS.net
        if net_name == 'vgg16':
            build_net = vgg16
        elif net_name == 'resnetv1':
            build_net = resnetv1
        else:
            raise NotImplementedError
        self.net = build_net(batch_size=cfg.FLAGS.ims_per_batch)

        imdb_name = "Columbia"
        self.imdb, self.roidb = combined_roidb(imdb_name)

        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')
コード例 #6
0
    def __init__(self):
        """Create the VGG16 network and load the ``jinnan2_round1_train_20190222`` roidb.

        The size of ``self.imdb.roidb`` and the size of the returned
        ``self.roidb`` are printed, in that order.

        Raises:
            NotImplementedError: If ``cfg.FLAGS.network`` is not ``'vgg16'``.
        """
        if cfg.FLAGS.network != 'vgg16':
            raise NotImplementedError
        self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch)

        # Load the dataset. Per the original author's note, combined_roidb also
        # augments/extends the roidb (e.g. flipping). Not visible here, so confirm.
        self.imdb, self.roidb = combined_roidb('jinnan2_round1_train_20190222')
        print(len(self.imdb.roidb))
        # Per the original author's note, self.roidb also includes the
        # horizontally flipped entries, roughly doubling the count. TODO confirm.
        print(len(self.roidb))

        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')
コード例 #7
0
    def __init__(self):
        """Create the VGG16 network and load the ``voc_2007_trainval`` roidb.

        Raises:
            NotImplementedError: If ``cfg.FLAGS.network`` is not ``'vgg16'``.
        """
        # Only the VGG16 backbone is supported.
        if cfg.FLAGS.network != 'vgg16':
            raise NotImplementedError
        self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch)

        # Load the training data.
        imdb_name = "voc_2007_trainval"
        self.imdb, self.roidb = combined_roidb(imdb_name)

        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        # Directory where the trained model is saved.
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')
コード例 #8
0
    def __init__(self):
        """Create the VGG16 network and load the ``voc_2007_trainval`` roidb.

        Raises:
            NotImplementedError: If ``cfg.FLAGS.network`` is not ``'vgg16'``.
        """
        if cfg.FLAGS.network != 'vgg16':
            raise NotImplementedError
        self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch)

        # Per the original author's note (factory.py is not visible here, so
        # verify): combined_roidb goes through get_imdb(name) in factory.py,
        # which returns __sets[name](). __sets maps each name to a lambda,
        # e.g. ``lambda split=split, year=year: pascal_voc(split, year)``,
        # that builds a pascal_voc(split, year) object. The name only selects
        # which imdb is built.
        imdb_name = "voc_2007_trainval"
        self.imdb, self.roidb = combined_roidb(imdb_name)

        # Per the original author's note, the data layer turns the raw images
        # and gt_boxes into network input (shifting/scaling etc.). TODO confirm.
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        # Directory where the trained model is saved.
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')
コード例 #9
0
    def __init__(self, dataset):
        """Create the configured backbone and load the roidb named ``dataset``.

        The imdb name and the output directory are printed, in that order.

        Args:
            dataset: Name passed straight to ``combined_roidb``.

        Raises:
            NotImplementedError: If ``cfg.FLAGS.network`` is neither
                ``'vgg16'`` nor ``'RESNET_v1_50'``.
        """
        # Create network
        if cfg.FLAGS.network == 'vgg16':
            self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch)
        elif cfg.FLAGS.network == 'RESNET_v1_50':
            self.net = resnetv1(batch_size=cfg.FLAGS.ims_per_batch)
        else:
            raise NotImplementedError

        # Load the requested database.
        self.imdb, self.roidb = combined_roidb(dataset)
        print(self.imdb.name)

        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')
        print(self.output_dir)
コード例 #10
0
        print('im_detect: {:d}/{:d} {:.3f}s {:.3f}s' \
              .format(i + 1, num_images, _t['im_detect'].average_time,
                      _t['misc'].average_time))

    det_file = os.path.join(output_dir, 'detections.pkl')
    with open(det_file, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

    print('Evaluating detections')
    imdb.evaluate_detections(all_boxes, output_dir)


if __name__ == '__main__':
    imdb, roidb = combined_roidb("voc_2007_trainval")
    data_layer = RoIDataLayer(roidb, imdb.num_classes)
    output_dir = cfg.get_output_dir(imdb, 'default')

    args = parse_args()

    # model path
    demonet = args.demo_net
    dataset = args.dataset
    # tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default', NETS[demonet][0])
    tfmodel = os.path.join('default', DATASETS[dataset][0], 'default',
                           NETS[demonet][0])

    if not os.path.isfile(tfmodel + '.meta'):
        print(tfmodel)
        raise IOError(
            ('{:s} not found.\nDid you download the proper networks from '