def __init__(
    self,
    test_file,
    class_key_name,
    filter,
    database,
):
    """Prepare evaluation state for a classifier backed by *database*.

    Builds a per-class measurement table whose 'true' column holds the
    ground-truth count of each class found in *test_file*; the 'right'
    and 'all' columns start out as 0 and are filled in later.
    """
    self.__classifier = classifier.Classifier(database)
    self.__test_file = test_file
    self.__class_key_name = class_key_name
    self.__filter = filter
    self.__database = database
    # Ground-truth class counts extracted from the test file.
    self.__true_classes = preprocess.get_classes(test_file, class_key_name,
                                                 filter)
    # Known classes come from the class_count table built at training time.
    with sqlite3.connect(database) as conn:
        rows = conn.cursor().execute("SELECT class FROM class_count")
        classes = [row[0] for row in rows]
    true_counts = [self.__true_classes[name] for name in classes]
    # NOTE(review): attribute keeps its original spelling ("measurments")
    # because external code may already reference it.
    self.measurments = pd.DataFrame(
        {'true': true_counts},
        index=classes,
        columns=['true', 'right', 'all'])
    self.measurments.fillna(0, inplace=True)
Example #2
0
def train_from_images(model, image_dir, params, modelfilename):
    """Train *model* on a prebuilt image dataset.

    Expects the dataset split into <image_dir>/train and <image_dir>/test;
    checkpoints the best model to *modelfilename* and logs to TensorBoard.
    """
    train_dir = os.path.join(image_dir, 'train')
    test_dir = os.path.join(image_dir, 'test')

    train_gen = image_iterator(train_dir, params['batch_size'],
                               params['encoder'])
    val_gen = test_image_iterator(test_dir,
                                  params['batch_size'],
                                  params['encoder'],
                                  loop=True)
    # One extra step so the final partial validation batch is included.
    val_steps = sum(get_classes(test_dir)[2]) // params['batch_size'] + 1

    checkpoint = ModelCheckpoint(modelfilename, save_best_only=True)
    model.fit_generator(train_gen,
                        params['steps_per_epoch'],
                        epochs=params['epochs'],
                        verbose=1,
                        validation_data=val_gen,
                        validation_steps=val_steps,
                        callbacks=[checkpoint, TensorBoard()])
Example #3
0
def train(data_file, database_file_path, class_key_name, class_filter):
    """Build the class_count and frequency tables from *data_file*.

    The frequency table has one row per word and one integer column per
    class; each cell holds that word's count within that class's bag.

    Fix: word values were interpolated into SQL with str.format, which
    broke on any word containing a quote and was SQL-injectable.  Values
    are now bound with '?' placeholders; class names remain backtick-quoted
    identifiers because column names cannot be bound as parameters.
    """
    classes = get_classes(data_file, class_key_name, class_filter)
    with sqlite3.connect(database_file_path) as conn:
        cur = conn.cursor()
        cur.execute(create_class_count_query)
        for clazz, count in classes.items():
            cur.execute(insert_class_count_query.format(clazz, count))

        # One INTEGER column per class; identifiers are backtick-quoted.
        create_freq_query = "CREATE TABLE IF NOT EXISTS frequency (word VARCHAR(20) PRIMARY KEY"
        for clazz in classes.keys():
            create_freq_query += (",`" + clazz + "` INTEGER DEFAULT 0 ")
        create_freq_query += ');'
        cur.execute(create_freq_query)

        f_map = frequency_map(data_file, class_key_name, class_filter)

        # Merge all per-class bags to get the vocabulary.  Counter '+'
        # matches the original semantics (drops non-positive counts).
        vocabulary = Counter()
        for bag in f_map.values():
            vocabulary = bag + vocabulary

        # Insert every word once, with the word bound as a parameter.
        cur.executemany("INSERT INTO frequency (word) VALUES(?);",
                        ((word,) for word in vocabulary.keys()))
        conn.commit()

        # Fill in the per-class counts; count and word are bound values.
        for clazz, bag in f_map.items():
            update_query = "UPDATE frequency SET `{}` = ? WHERE word = ?;".format(clazz)
            cur.executemany(update_query,
                            ((count, word) for word, count in bag.items()))

        conn.commit()
Example #4
0
# Register the prediction-mode flag (help text kept in the original locale).
parser.add_argument('--predict', action='store_true', help='预测模式')

args = parser.parse_args()

# '--filter_sizes' arrives as a comma-separated string, e.g. "3,4,5".
args.filter_sizes = list(map(int, args.filter_sizes.split(',')))

# Select the compute device: the requested CUDA GPU when available,
# otherwise fall back to CPU.
if args.device_num != -1 and torch.cuda.is_available():
    args.device = torch.device('cuda:'+str(args.device_num))
    args.cuda = True
    torch.cuda.set_device(torch.device(args.device))
else:
    args.device = torch.device('cpu')
    args.cuda = False

args.classes = preprocess.get_classes(args.data_path)  # list of class labels
args.class_num = len(args.classes)

# Build the text and label Field objects from the dataset.
TEXT, LABEL = preprocess.get_field(args.data_path)

# Enter prediction mode when --predict was given (handled below).
if args.predict:
    # 预测模式
    if os.path.exists(os.path.join(args.data_path, 'vocab.pkl')) and os.path.exists(os.path.join(args.data_path, 'saved_model.pt')):
        # 加载词向量和模型
        s = input('plase enter a sentence in classes '+str(args.classes)+'\n')
        with open(os.path.join(args.data_path, 'vocab.pkl'), 'rb') as f:
            TEXT.vocab = pickle.load(f)
        args.vocab_size = len(TEXT.vocab)
        text_cnn = module.TextCNN(args)
        # 预测