Пример #1
0
def save_air():
    """Export a pretrained CenterFace checkpoint to an AIR model file.

    Command-line options (parsed from ``sys.argv``):
        --pretrained: path of the checkpoint file to export (default '').
        --batch_size: batch dimension of the dummy export input (default 8).

    Checkpoint entries whose names start with ``moments.``, ``moment1.`` or
    ``moment2.`` are skipped. A leading ``centerface_network.`` is stripped
    from the remaining names before they are loaded into
    ``CenterfaceMobilev2``. The model is written next to the checkpoint as
    ``<checkpoint path without .ckpt>_<batch_size>b.air``. If
    ``--pretrained`` is not an existing file, a message is printed and
    nothing is exported.
    """
    print('============= centerface start save air ==================')

    parser = argparse.ArgumentParser(description='Convert ckpt to air')
    parser.add_argument('--pretrained',
                        type=str,
                        default='',
                        help='pretrained model to load')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size')

    args = parser.parse_args()
    network = CenterfaceMobilev2()

    if not os.path.isfile(args.pretrained):
        # Previously this case exited silently; make the no-op visible.
        print('pretrained model {} not found, skip export'.format(
            args.pretrained))
        return

    # Presumably optimizer state rather than network weights -- not loaded.
    skip_prefixes = ('moments.', 'moment1.', 'moment2.')
    # Wrapper prefix stripped so names match the bare network's parameters.
    wrapper_prefix = 'centerface_network.'

    param_dict = load_checkpoint(args.pretrained)
    param_dict_new = {}
    for key, values in param_dict.items():
        if key.startswith(skip_prefixes):
            continue
        if key.startswith(wrapper_prefix):
            key = key[len(wrapper_prefix):]
        param_dict_new[key] = values
    load_param_into_net(network, param_dict_new)
    print('load model {} success'.format(args.pretrained))

    # Random input of shape (batch_size, 3, 832, 832), float32, fed to export.
    input_data = np.random.uniform(low=0,
                                   high=1.0,
                                   size=(args.batch_size, 3, 832,
                                         832)).astype(np.float32)

    # Strip only a trailing ".ckpt". str.replace would also rewrite ".ckpt"
    # inside directory names, and would leave paths without ".ckpt" lacking
    # the batch-size/.air suffix.
    stem, ext = os.path.splitext(args.pretrained)
    if ext != '.ckpt':
        stem = args.pretrained
    air_file_name = stem + '_' + str(args.batch_size) + 'b.air'

    export(network,
           Tensor(input_data),
           file_name=air_file_name,
           file_format='AIR')

    print("export model success.")
Пример #2
0
# Parse the command-line flags registered on `parser`. parse_known_args()
# returns (namespace, leftover_args); unrecognized arguments are tolerated
# and the leftover list is discarded into `_`.
args, _ = parser.parse_known_args()

if __name__ == "__main__":
    # logger
    args.outputs_dir = os.path.join(
        args.ckpt_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    if args.ckpt_name != "":
        args.start = 0
        args.end = 1

    for loop in range(args.start, args.end, 1):
        network = CenterfaceMobilev2()
        default_recurisive_init(network)

        if args.ckpt_name == "":
            ckpt_num = loop * args.device_num + args.rank + 1
            ckpt_name = "0-" + str(ckpt_num) + "_" + str(
                args.steps_per_epoch * ckpt_num) + ".ckpt"
        else:
            ckpt_name = args.ckpt_name

        test_model = args.test_model + ckpt_name
        if not test_model:
            args.logger.info('load_model {} none'.format(test_model))
            continue

        if os.path.isfile(test_model):
Пример #3
0
                    default='AIR',
                    help='file format')
# Target backend for MindSpore; restricted to the three supported devices.
parser.add_argument(
    "--device_target",
    type=str,
    default="Ascend",
    choices=["Ascend", "GPU", "CPU"],
    help="device target",
)
args = parser.parse_args()

# Run in graph mode on the device selected on the command line.
context.set_context(
    device_id=args.device_id,
    device_target=args.device_target,
    mode=context.GRAPH_MODE,
)

if __name__ == '__main__':
    config = ConfigCenterface()
    net = CenterfaceMobilev2()

    # Entries with these prefixes are presumably optimizer state rather than
    # network weights, so they are not loaded.
    skipped_prefixes = ('moments.', 'moment1.', 'moment2.')
    # Leading wrapper prefix stripped so the remaining names can be loaded
    # into the bare network.
    wrapper_prefix = 'centerface_network.'

    param_dict = load_checkpoint(args.ckpt_file)
    param_dict_new = {}
    for param_name, param_value in param_dict.items():
        if param_name.startswith(skipped_prefixes):
            continue
        if param_name.startswith(wrapper_prefix):
            param_name = param_name[len(wrapper_prefix):]
        param_dict_new[param_name] = param_value

    load_param_into_net(net, param_dict_new)

    # Wrap the backbone (presumably adding NMS post-processing) and put the
    # resulting cell in inference mode.
    net = CenterFaceWithNms(net)
    net.set_train(False)
Пример #4
0
def create_network(name, *args, **kwargs):
    """Build the network registered under ``name``.

    Only ``"centerface"`` is supported; it returns
    ``CenterfaceMobilev2(*args, **kwargs)``.

    Raises:
        NotImplementedError: if ``name`` is anything other than "centerface".
    """
    is_supported = name == "centerface"
    if not is_supported:
        raise NotImplementedError(f"{name} is not implemented in the repo")
    return CenterfaceMobilev2(*args, **kwargs)
Пример #5
0
def centerface(*args, **kwargs):
    """Construct a CenterfaceMobilev2, forwarding all arguments unchanged."""
    model_cls = CenterfaceMobilev2
    return model_cls(*args, **kwargs)