def _filter_checkpoint_params(param_dict):
    """Return a filtered copy of *param_dict* for loading into the bare network.

    Entries whose keys start with ``moments.``, ``moment1.`` or ``moment2.``
    (presumably optimizer state saved with the weights) are dropped. The
    ``centerface_network.`` prefix is stripped from the remaining keys.
    Every other entry is copied through unchanged, in iteration order.
    """
    skipped_prefixes = ('moments.', 'moment1.', 'moment2.')
    wrapper_prefix = 'centerface_network.'
    filtered = {}
    for key, value in param_dict.items():
        if key.startswith(skipped_prefixes):
            continue
        if key.startswith(wrapper_prefix):
            key = key[len(wrapper_prefix):]
        filtered[key] = value
    return filtered


def save_air():
    """Save air file: convert a CenterFace checkpoint into an AIR model.

    Command-line arguments:
        --pretrained: path of the checkpoint to load (default ``''``). If it is
            not an existing file, a warning is printed and the network is
            exported as constructed by ``CenterfaceMobilev2()``, without
            loading any checkpoint.
        --batch_size: batch dimension of the export input (default 8).

    Output file name: a trailing ``.ckpt`` in ``--pretrained`` is replaced with
    ``_<batch_size>b.air``. Any other path is passed to ``export`` unchanged.
    """
    print('============= centerface start save air ==================')

    parser = argparse.ArgumentParser(description='Convert ckpt to air')
    parser.add_argument('--pretrained', type=str, default='',
                        help='pretrained model to load')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size')
    args = parser.parse_args()

    network = CenterfaceMobilev2()

    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        load_param_into_net(network, _filter_checkpoint_params(param_dict))
        print('load model {} success'.format(args.pretrained))
    else:
        # Make it visible that the exported model does not carry the
        # requested weights instead of silently continuing.
        print('pretrained model {} not found, exporting without loading '
              'weights'.format(args.pretrained))

    # Dummy input with shape (batch_size, 3, 832, 832) used to trace the
    # graph for export; the values themselves are random.
    input_data = np.random.uniform(
        low=0, high=1.0, size=(args.batch_size, 3, 832, 832)).astype(np.float32)
    tensor_input_data = Tensor(input_data)

    # Replace only the trailing ".ckpt". A plain str.replace would also
    # rewrite ".ckpt" occurring inside a directory name.
    ckpt_suffix = '.ckpt'
    file_name = args.pretrained
    if file_name.endswith(ckpt_suffix):
        file_name = (file_name[:-len(ckpt_suffix)]
                     + '_' + str(args.batch_size) + 'b.air')

    export(network, tensor_input_data, file_name=file_name, file_format='AIR')
    print("export model success.")
args, _ = parser.parse_known_args() if __name__ == "__main__": # logger args.outputs_dir = os.path.join( args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) args.logger = get_logger(args.outputs_dir, args.rank) args.logger.save_args(args) if args.ckpt_name != "": args.start = 0 args.end = 1 for loop in range(args.start, args.end, 1): network = CenterfaceMobilev2() default_recurisive_init(network) if args.ckpt_name == "": ckpt_num = loop * args.device_num + args.rank + 1 ckpt_name = "0-" + str(ckpt_num) + "_" + str( args.steps_per_epoch * ckpt_num) + ".ckpt" else: ckpt_name = args.ckpt_name test_model = args.test_model + ckpt_name if not test_model: args.logger.info('load_model {} none'.format(test_model)) continue if os.path.isfile(test_model):
default='AIR', help='file format') parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend", help="device target") args = parser.parse_args() context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) if __name__ == '__main__': config = ConfigCenterface() net = CenterfaceMobilev2() param_dict = load_checkpoint(args.ckpt_file) param_dict_new = {} for key, values in param_dict.items(): if key.startswith('moments.') or key.startswith( 'moment1.') or key.startswith('moment2.'): continue elif key.startswith('centerface_network.'): param_dict_new[key[19:]] = values else: param_dict_new[key] = values load_param_into_net(net, param_dict_new) net = CenterFaceWithNms(net) net.set_train(False)
def create_network(name, *args, **kwargs):
    """Instantiate a network from its registry name.

    Args:
        name: network identifier; only ``"centerface"`` is recognized.
        *args, **kwargs: forwarded unchanged to the network constructor.

    Returns:
        A ``CenterfaceMobilev2`` instance when ``name == "centerface"``.

    Raises:
        NotImplementedError: for any other ``name``.
    """
    if name != "centerface":
        raise NotImplementedError(f"{name} is not implemented in the repo")
    return CenterfaceMobilev2(*args, **kwargs)
def centerface(*args, **kwargs):
    """Build a ``CenterfaceMobilev2`` network.

    Convenience alias: every positional and keyword argument is passed
    through unchanged to the ``CenterfaceMobilev2`` constructor, and the
    resulting instance is returned.
    """
    network = CenterfaceMobilev2(*args, **kwargs)
    return network