def main(args):
    """Train the multi-view CNN_Text model, then evaluate it.

    Step 1: if any of the expected ``./data/*.npy`` files is missing, run
    ``data_loader.preprocess`` on the raw data under ``args.dir``.

    Step 2: build the training DataLoader and construct the model. When
    ``args.model`` is given, restore the model's weights from that path.

    Step 3: train for ``num_epochs`` epochs and finally run ``evaluation``.

    Args:
        args: Parsed command-line namespace. Reads ``args.gpu`` (CUDA device
            index), ``args.dir`` (raw-data directory passed to
            ``data_loader.preprocess``) and ``args.model`` (path to a saved
            state dict, or ``None`` to train from scratch).
    """
    # Device configuration: use the requested GPU when CUDA is available.
    device = torch.device(
        'cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu')

    # Hyper-parameters
    num_epochs = 80
    num_classes = 8
    learning_rate = 0.08
    num_views = 3
    num_layers = 4
    data_path = args.dir

    # Pre-processed arrays that must all exist before training can start.
    file_list = [
        './data/train_web_content.npy',
        './data/train_web_links.npy',
        './data/train_web_title.npy',
        './data/test_web_content.npy',
        './data/test_web_links.npy',
        './data/test_web_title.npy',
        './data/train_label.npy',
        './data/test_label.npy',
    ]
    if not all(os.path.exists(path) for path in file_list):
        print(
            'Raw data has not been pre-processed! Start pre-processing the raw data.'
        )
        data_loader.preprocess(data_path)
    else:
        print('Loading the existing data set...')

    # Pass num_classes instead of repeating the literal so the two cannot drift.
    train_dataset = data_loader.Load_datasets('train', num_classes)
    train_loader = DataLoader(train_dataset,
                              batch_size=32,
                              shuffle=True,
                              num_workers=4)

    # Shape of one training sample, used to size the network's input.
    input_dims = np.array(train_dataset.data[0]).shape
    model = CNN_Text(input_dims, [64, 32, 32, 32], [1, 2, 3, 4], num_classes,
                     0.5, num_layers, num_views).to(device)
    model = model.double()
    # Extra attributes attached to the model; presumably consumed by
    # train_model / evaluation (not visible here) — TODO confirm.
    model.device = device
    model.learning_rate = learning_rate
    model.epoch = 0

    if args.model is not None:
        # map_location lets a checkpoint saved on one device (e.g. a GPU)
        # be loaded on whatever device this run uses.
        model.load_state_dict(torch.load(args.model, map_location=device))
        print('Successfully load pre-trained model!')

    # train the model until the model is fully trained
    train_model(model, train_loader, num_epochs)
    print('Finish training process!')
    evaluation(model)
# Build the token and label vocabularies from the MR train/dev splits. They
# must match the vocabularies the snapshot was trained with (assumed to be the
# same splits — TODO confirm), or the embedding/output sizes will not line up.
text_field = data.Field(lower=True)
label_field = data.Field(sequential=False)
train_data, dev_data = MR.splits(text_field, label_field)
text_field.build_vocab(train_data, dev_data)
label_field.build_vocab(train_data, dev_data)

# Model hyper-parameters. These must match the ones used to create the
# snapshot; otherwise load_state_dict fails on mismatched parameter shapes.
args = Args()
args.dropout = 0.5
args.max_norm = 3.0
args.embed_dim = 128
args.kernel_num = 100
args.kernel_sizes = '3,4,5'
args.static = False
args.snapshot = 'snapshot/best.pt'
args.embed_num = len(text_field.vocab)
# NOTE(review): the -1 presumably drops the label vocab's <unk> entry — confirm.
args.class_num = len(label_field.vocab) - 1
args.kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]

model = CNN_Text(args)
# Load onto the CPU first so a GPU-saved snapshot also works on CPU-only
# hosts, then move to the serving device (`device` is defined elsewhere).
model.load_state_dict(torch.load(args.snapshot, map_location='cpu'))
model = model.to(device)
# Switch to inference mode. This disables dropout (p=0.5 above), so repeated
# requests for the same text give the same prediction. It is harmless if
# `predict` also calls eval().
model.eval()


@app.route('/cls/<text>')
def classify_text(text):
    """Classify the URL segment ``text`` with the loaded model.

    Logs the raw input and the confidence returned by ``predict``. Returns the
    first value from ``predict``, presumably the predicted label — confirm.
    """
    app.logger.warning(text)
    result, conf = predict(text, model, text_field, label_field, device)
    app.logger.warning(conf)
    return result