def eval(args):
    """Evaluate a trained ResNet3D model on the 'valid' split and print its accuracy.

    The name shadows the builtin ``eval``; it is kept unchanged so existing
    callers keep working.

    Args:
        args: parsed command-line namespace. Reads ``args.config`` (passed to
            ``parse_config``), ``args.model_name`` (upper-cased and handed to
            ``KineticsReader``) and ``args.weights`` (path given to
            ``fluid.load_dygraph``). ``vars(args)`` is merged into the
            ``'valid'`` section of the parsed config.

    Exits the process with a non-zero status if ``args.weights`` is empty.
    """
    import sys

    # Parse the config file and overlay the command-line arguments on 'valid'.
    config = parse_config(args.config)
    val_config = merge_configs(config, 'valid', vars(args))
    print_configs(val_config, "Valid")

    # Fail fast, before any model or reader is built. A bare builtin exit()
    # exits with status 0 (success); sys.exit(msg) writes msg to stderr and
    # exits with status 1 so scripts/CI can detect the failure.
    if not args.weights:
        sys.exit("model path must be specified")
    weights = args.weights

    with fluid.dygraph.guard():
        val_model = ResNet3D.Resnet('ResNet3D', [2, 2, 2, 2], ResNet3D.get_inplanes())
        # Reader over the validation ('valid') split.
        val_reader = KineticsReader(args.model_name.upper(), 'valid', val_config).create_reader()

        para_state_dict, _ = fluid.load_dygraph(weights)
        val_model.load_dict(para_state_dict)
        val_model.eval()  # switch the network to inference mode

        acc_list = []
        for data in val_reader():
            # Each sample is presumably a (clip, label) pair - TODO confirm
            # against KineticsReader.
            dy_x_data = np.array([x[0] for x in data]).astype('float32')
            y_data = np.array([[x[1]] for x in data]).astype('int64')
            img = fluid.dygraph.to_variable(dy_x_data)
            label = fluid.dygraph.to_variable(y_data)
            label.stop_gradient = True  # labels are constants, no gradient needed
            _, acc = val_model(img, label)
            acc_list.append(acc.numpy()[0])

        # np.mean([]) would print nan plus a RuntimeWarning; report it plainly.
        if not acc_list:
            print("validation reader produced no batches; accuracy is undefined")
            return
        # NOTE(review): `acc` is presumably the batch-mean accuracy, so this is an
        # unweighted mean over batches; a smaller final batch makes it differ
        # slightly from true per-sample accuracy.
        print("验证集准确率为:{}".format(np.mean(acc_list)))
def infer(args):
    """Run a trained ResNet3D model on the 'infer' split and print its predictions.

    For every sample of every batch, prints the label supplied by the reader
    next to the predicted class name. Class names come from ``label_dir.npy``,
    a pickled dict that is inverted so that predicted indices map back to
    its keys.

    Args:
        args: parsed command-line namespace. Reads ``args.config`` (passed to
            ``parse_config``), ``args.model_name`` (upper-cased and handed to
            ``KineticsReader``) and ``args.weights`` (path given to
            ``fluid.load_dygraph``). ``vars(args)`` is merged into the
            ``'infer'`` section of the parsed config.

    Exits the process with a non-zero status if ``args.weights`` is empty.
    """
    import sys

    # Parse the config file and overlay the command-line arguments on 'infer'.
    config = parse_config(args.config)
    infer_config = merge_configs(config, 'infer', vars(args))
    print_configs(infer_config, "Infer")

    # Fail fast, before any model or reader is built. A bare builtin exit()
    # exits with status 0 (success); sys.exit(msg) writes msg to stderr and
    # exits with status 1.
    if not args.weights:
        sys.exit("model path must be specified")
    weights = args.weights

    with fluid.dygraph.guard():
        infer_model = ResNet3D.Resnet('ResNet3D', [2, 2, 2, 2], ResNet3D.get_inplanes())
        # SECURITY: allow_pickle=True unpickles the file, which can execute
        # arbitrary code - only load a trusted label_dir.npy.
        name_to_id = np.load('label_dir.npy', allow_pickle=True).item()
        id_to_name = {v: k for k, v in name_to_id.items()}

        # Reader over the 'infer' split.
        infer_reader = KineticsReader(args.model_name.upper(), 'infer', infer_config).create_reader()

        para_state_dict, _ = fluid.load_dygraph(weights)
        infer_model.load_dict(para_state_dict)
        infer_model.eval()  # switch the network to inference mode

        for data in infer_reader():
            # Each sample is presumably a (clip, label) pair - TODO confirm
            # against KineticsReader.
            dy_x_data = np.array([x[0] for x in data]).astype('float32')
            y_data = [x[1] for x in data]
            img = fluid.dygraph.to_variable(dy_x_data)
            # Scores are presumably (batch, num_classes) - TODO confirm. Score
            # every row: keeping only row [0] silently dropped all but the first
            # sample's prediction whenever the batch size was > 1.
            scores = infer_model(img).numpy()
            # argmax returns the first index of the maximum, matching the former
            # np.where(out == np.max(out))[0][0].
            pred_ids = np.argmax(scores, axis=-1)
            for true_label, pred_id in zip(y_data, pred_ids):
                print("实际标签{}, 预测结果{}".format(true_label, id_to_name[int(pred_id)]))