Exemple #1
0
def main(args):
    """Build, optionally restore, train and evaluate the CNN_Text model.

    Args:
        args: Namespace-like object. This function reads ``args.gpu`` (CUDA
            device index used when CUDA is available), ``args.dir`` (path
            handed to ``data_loader.preprocess``) and ``args.model`` (path of
            a saved state dict, or ``None`` to train from scratch).
    """
    # Device configuration: requested GPU when CUDA is available, else CPU.
    device = torch.device(
        'cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu')
    num_epochs = 80
    num_classes = 8
    learning_rate = 0.08
    num_views = 3
    num_layers = 4
    data_path = args.dir
    file_list = [
        './data/train_web_content.npy', './data/train_web_links.npy',
        './data/train_web_title.npy', './data/test_web_content.npy',
        './data/test_web_links.npy', './data/test_web_title.npy',
        './data/train_label.npy', './data/test_label.npy'
    ]
    # Pre-process only when at least one of the expected .npy files is missing.
    # NOTE(review): presumably preprocess() writes exactly these files - confirm.
    if not all(os.path.exists(path) for path in file_list):
        print(
            'Raw data has not been pre-processed! Start pre-processing the raw data.'
        )
        data_loader.preprocess(data_path)
    else:
        print('Loading the existing data set...')
    # Use num_classes instead of repeating the literal so both stay in sync.
    train_dataset = data_loader.Load_datasets('train', num_classes)
    train_loader = DataLoader(train_dataset,
                              batch_size=32,
                              shuffle=True,
                              num_workers=4)
    input_dims = np.array(train_dataset.data[0]).shape
    # NOTE(review): the two lists and 0.5 look like per-layer channel counts,
    # kernel sizes and a dropout rate - confirm against CNN_Text's signature.
    model = CNN_Text(input_dims, [64, 32, 32, 32], [1, 2, 3, 4], num_classes,
                     0.5, num_layers, num_views).to(device)
    model = model.double()
    # Training settings are attached to the model object itself.
    # NOTE(review): presumably read by train_model/evaluation - confirm.
    model.device = device
    model.learning_rate = learning_rate
    model.epoch = 0
    if args.model is not None:
        # Fixed: the original read the misspelled attribute ``args.mpodel``.
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        model.load_state_dict(torch.load(args.model, map_location=device))
        print('Successfully load pre-trained model!')
    # train the model until the model is fully trained
    train_model(model, train_loader, num_epochs)
    print('Finish training process!')
    evaluation(model)
Exemple #2
0
# Text is lower-cased; labels are treated as single (non-sequential) values.
text_field = data.Field(lower=True)
label_field = data.Field(sequential=False)

# Both vocabularies are built over the train and dev splits together.
train_data, dev_data = MR.splits(text_field, label_field)
text_field.build_vocab(train_data, dev_data)
label_field.build_vocab(train_data, dev_data)

# Model configuration, assembled as attributes on a plain Args object.
args = Args()
args.snapshot = 'snapshot/best.pt'
args.static = False
args.dropout, args.max_norm = 0.5, 3.0
args.embed_dim, args.kernel_num = 128, 100

# Kernel sizes are given as a comma-separated string and parsed to ints.
args.kernel_sizes = '3,4,5'
args.kernel_sizes = list(map(int, args.kernel_sizes.split(',')))

# Sizes derived from the vocabularies built above.
# NOTE(review): the "- 1" presumably drops a special token from the label
# vocabulary - confirm against the torchtext version in use.
args.embed_num = len(text_field.vocab)
args.class_num = len(label_field.vocab) - 1

# Restore the snapshot weights on CPU first, then move the model to `device`.
model = CNN_Text(args)
model.load_state_dict(torch.load(args.snapshot, map_location='cpu'))
model = model.to(device)


@app.route('/cls/<text>')
def classify_text(text):
    """Classify the text taken from the URL path and return the result.

    The incoming text and the second value returned by ``predict`` are both
    written to the application log at WARNING level.

    Args:
        text: The ``<text>`` segment of the request path.

    Returns:
        The first element of the pair returned by ``predict``.
    """
    log = app.logger
    log.warning(text)
    # NOTE(review): predict() presumably yields (label, confidence) - confirm.
    label, confidence = predict(text, model, text_field, label_field, device)
    log.warning(confidence)
    return label