Exemplo n.º 1
0
def eval(args):
    """Evaluate a trained ResNet3D checkpoint on the validation split.

    Loads the dygraph parameters stored at ``args.weights`` into a ResNet3D
    model, feeds every batch produced by the 'valid' ``KineticsReader``
    through it and prints the mean of the per-batch accuracies.

    Note: the name shadows the ``eval`` builtin; it is kept because callers
    reference this function by that name.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``weights`` (checkpoint path), ``config`` (config file path) and
            ``model_name``; ``vars(args)`` is also passed to
            ``merge_configs``.

    Raises:
        SystemExit: With status 1 when ``args.weights`` is empty.
    """
    import sys

    # Fail fast: validate the checkpoint path before any config parsing,
    # model construction or reader setup is attempted.
    if not args.weights:
        print("model path must be specified")
        # Non-zero status so shells/CI can detect the failure; the `exit()`
        # helper from `site` exits with 0 and is missing under `python -S`.
        sys.exit(1)
    weights = args.weights

    # Parse the config file and overlay the command-line arguments.
    config = parse_config(args.config)
    val_config = merge_configs(config, 'valid', vars(args))
    print_configs(val_config, "Valid")

    with fluid.dygraph.guard():
        val_model = ResNet3D.Resnet(
            'ResNet3D', [2, 2, 2, 2], ResNet3D.get_inplanes())

        val_reader = KineticsReader(
            args.model_name.upper(), 'valid', val_config).create_reader()

        # Restore the trained parameters and switch to evaluation mode.
        para_state_dict, _ = fluid.load_dygraph(weights)
        val_model.load_dict(para_state_dict)
        val_model.eval()

        acc_list = []
        for data in val_reader():
            # Each reader item is indexed as (sample, label).
            dy_x_data = np.array([x[0] for x in data]).astype('float32')
            y_data = np.array([[x[1]] for x in data]).astype('int64')

            img = fluid.dygraph.to_variable(dy_x_data)
            label = fluid.dygraph.to_variable(y_data)
            label.stop_gradient = True

            # The model returns (output, accuracy) when given labels.
            _, acc = val_model(img, label)
            acc_list.append(acc.numpy()[0])

        # NOTE(review): this is an unweighted mean over batches, so a smaller
        # final batch counts as much as a full one.
        # An empty reader still reports nan, but without numpy's
        # "mean of empty slice" RuntimeWarning.
        mean_acc = np.mean(acc_list) if acc_list else float('nan')
        print("验证集准确率为:{}".format(mean_acc))
Exemplo n.º 2
0
def infer(args):
    """Run a trained ResNet3D checkpoint over the inference split.

    Loads the dygraph parameters stored at ``args.weights`` into a ResNet3D
    model and, for every batch produced by the 'infer' ``KineticsReader``,
    prints the batch's labels next to the predicted class name.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``weights`` (checkpoint path), ``config`` (config file path) and
            ``model_name``; ``vars(args)`` is also passed to
            ``merge_configs``.

    Raises:
        SystemExit: With status 1 when ``args.weights`` is empty.
    """
    import sys

    # Fail fast: validate the checkpoint path before any config parsing,
    # model construction or reader setup is attempted.
    if not args.weights:
        print("model path must be specified")
        # Non-zero status so shells/CI can detect the failure; the `exit()`
        # helper from `site` exits with 0 and is missing under `python -S`.
        sys.exit(1)
    weights = args.weights

    # Parse the config file and overlay the command-line arguments.
    config = parse_config(args.config)
    infer_config = merge_configs(config, 'infer', vars(args))
    print_configs(infer_config, "Infer")

    with fluid.dygraph.guard():
        infer_model = ResNet3D.Resnet(
            'ResNet3D', [2, 2, 2, 2], ResNet3D.get_inplanes())

        # The stored dict is inverted here so it can be indexed by the
        # predicted class id below.
        label_dic = np.load('label_dir.npy', allow_pickle=True).item()
        label_dic = {v: k for k, v in label_dic.items()}

        infer_reader = KineticsReader(
            args.model_name.upper(), 'infer', infer_config).create_reader()

        # Restore the trained parameters and switch to evaluation mode.
        para_state_dict, _ = fluid.load_dygraph(weights)
        infer_model.load_dict(para_state_dict)
        infer_model.eval()

        for data in infer_reader():
            # Each reader item is indexed as (sample, label).
            dy_x_data = np.array([x[0] for x in data]).astype('float32')
            y_data = [x[1] for x in data]

            img = fluid.dygraph.to_variable(dy_x_data)

            # NOTE(review): only the first sample's output is used, while
            # y_data holds the labels of the whole batch; this lines up only
            # when the infer batch size is 1 -- confirm against the config.
            out = infer_model(img).numpy()[0]
            # Index of the highest score (first one on ties); assumes `out`
            # is a 1-D score vector -- TODO confirm.
            pred_id = int(np.argmax(out))
            print("实际标签{}, 预测结果{}".format(y_data, label_dic[pred_id]))