def test(args, shared_model, dataset, targets, log):
    """Evaluate a snapshot of `shared_model` on the given test set.

    Copies the shared model's weights into a fresh local SVM (so a
    concurrently-training shared model cannot mutate weights mid-eval),
    runs one forward pass per sample, and logs the weighted F1 score and
    overall accuracy.

    Args:
        args: namespace with at least a boolean ``gpu`` flag.
        shared_model: model whose ``state_dict()`` supplies the weights.
        dataset: indexable collection of per-sample numpy arrays.
        targets: 1-D numpy array of integer class labels, one per sample.
        log: logger receiving progress and result messages.

    Returns:
        float: overall accuracy, ``correct_cnt / targets.shape[0]``.
    """
    start_time = time.time()
    # NOTE(review): elapsed time here is always ~0s since start_time was
    # just taken; message kept byte-identical to preserve log output.
    log.info('Test time ' + time.strftime(
        "%Hh %Mm %Ss", time.gmtime(time.time() - start_time)) +
        ', ' + 'Start testing.')

    # Evaluate against a private copy of the weights.
    local_model = SVM()
    local_model.load_state_dict(shared_model.state_dict())
    if args.gpu:
        local_model = local_model.cuda()

    correct_cnt = 0
    predictions = np.zeros([targets.shape[0]], dtype=np.int64)
    for idx in range(targets.shape[0]):
        data = Variable(torch.from_numpy(dataset[idx]))
        if args.gpu:
            data = data.cuda()
        target = targets[idx]
        output = local_model(data)
        if args.gpu:
            # f1_score / comparison below need host-side values.
            output = output.cpu()
        # Argmax over dim 0 -- assumes `output` is a 1-D score vector for
        # a single sample (TODO confirm against SVM.forward).
        predict_class = output.max(0)[1].data.numpy()[0]
        predictions[idx] = predict_class
        if target == predict_class:
            correct_cnt += 1

    log.info('Overall f1 score = %0.4f' % (
        f1_score(list(targets), list(predictions), average='weighted')))
    log.info('Overall accuracy = %0.2f%%' % (
        100 * correct_cnt / targets.shape[0]))
    return correct_cnt / targets.shape[0]
if not os.path.exists(args.model_dir): os.mkdir(args.model_dir) if not os.path.exists(args.log_dir): os.mkdir(args.log_dir) if args.train: model = SVM() if args.model_load: try: saved_state = torch.load( os.path.join(args.model_dir, 'best_model.dat')) model.load_state_dict(saved_state) except: print('Cannot load existing model from file!') if args.gpu: model = model.cuda() dataset = torch.from_numpy(np.load("../output/data/dataset_train.npy")) targets = torch.from_numpy( np.int64(np.load("../output/data/target_train.npy"))) dataset_test = np.load(dataset_path) targets_test = np.int64(np.load(target_path)) if args.L2norm: log_test = setup_logger( 0, 'test_log_norm', os.path.join(args.log_dir, 'test_log_norm.txt')) log = setup_logger( 0, 'train_log_norm', os.path.join(args.log_dir, 'train_log_norm.txt')) optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=10) else: