def main():
    global args
    args = parser.parse_args()
    init_log('global', logging.INFO)
    logger = logging.getLogger('global')

    train_data = custom_dset(args.data_img, args.data_txt)
    train_loader = DataLoader(train_data, batch_size=args.batch_size,
                              shuffle=True, collate_fn=collate_fn,
                              num_workers=args.workers)
    logger.info("==============Build Dataset Done==============")

    model = East(args.pretrained)
    logger.info("==============Build Model Done================")
    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            pretrained_dict = torch.load(args.resume)
            model.load_state_dict(pretrained_dict, strict=True)
            logger.info("=> loaded checkpoint '{}'".format(args.resume))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    crit = LossFunc()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)

    train(epochs=args.epochs, model=model, train_loader=train_loader,
          crit=crit, optimizer=optimizer, scheduler=scheduler,
          save_step=args.save_freq, weight_decay=args.weight_decay)
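# main() above reads a module-level `parser` that the snippet does not show.
# A minimal sketch consistent with the args.* accesses; the flag defaults
# here are assumptions, not the repo's values.
import argparse

parser = argparse.ArgumentParser(description='EAST training')
parser.add_argument('--data_img', type=str, help='directory of training images')
parser.add_argument('--data_txt', type=str, help='directory of ground-truth text files')
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--workers', type=int, default=4)
parser.add_argument('--pretrained', type=str, default='', help='path to pretrained backbone weights')
parser.add_argument('--resume', type=str, default='', help='checkpoint to resume from')
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--epochs', type=int, default=1500)
parser.add_argument('--save_freq', type=int, default=20)
parser.add_argument('--weight_decay', type=float, default=0.0)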
def main():
    root_path = './dataset/'
    train_img = root_path + 'train2015/'
    train_txt = root_path + 'train_label/'

    trainset = custom_dset(train_img, train_txt)
    print(trainset)
    trainloader = DataLoader(trainset, batch_size=16, shuffle=True,
                             collate_fn=collate_fn, num_workers=4)

    model = East()
    model = model.cuda()
    crit = LossFunc()
    weight_decay = 0
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # weight_decay=1)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)

    train(epochs=1500, model=model, trainloader=trainloader, crit=crit,
          optimizer=optimizer, scheduler=scheduler, save_step=20,
          weight_decay=weight_decay)
    write.close()  # `write` is assumed to be a module-level SummaryWriter
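# The loaders above pass a custom collate_fn that is not shown. A hedged,
# minimal sketch: assuming each dataset item is a tuple of equally sized
# (img, score_map, geo_map, training_mask) tensors, it simply stacks them
# into batch tensors. The repo's real collate_fn may carry extra fields.
import torch

def collate_fn(batch):
    imgs, score_maps, geo_maps, training_masks = zip(*batch)
    return (torch.stack(imgs, 0), torch.stack(score_maps, 0),
            torch.stack(geo_maps, 0), torch.stack(training_masks, 0))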
def __init__(self, lr, weight_path, output_path):
    self.output_path = output_path
    self.model = East()
    self.model = nn.DataParallel(self.model, device_ids=[0])
    self.model = self.model.cuda()
    init_weights(self.model, init_type='xavier')
    cudnn.benchmark = True
    self.criterion = LossFunc()
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
    self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=10000, gamma=0.94)

    self.weightpath = os.path.abspath(weight_path)
    logging.debug("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(self.weightpath))
    checkpoint = torch.load(self.weightpath)
    self.start_epoch = checkpoint['epoch']
    self.model.load_state_dict(checkpoint['state_dict'])
    self.optimizer.load_state_dict(checkpoint['optimizer'])
    logging.debug("EAST <==> Prepare <==> Loading checkpoint '{}', epoch={} <==> Done".format(self.weightpath, self.start_epoch))
    self.model.eval()
def main():
    # Prepare the dataset
    print('EAST <==> Prepare <==> DataLoader <==> Begin')
    trainset = custom_dset(transform=transforms.ToTensor())
    train_loader = DataLoader(trainset,
                              batch_size=cfg.train_batch_size_per_gpu * cfg.gpu,
                              shuffle=True, num_workers=cfg.num_workers)
    print('EAST <==> Prepare <==> Batch_size:{} <==> Begin'.format(cfg.train_batch_size_per_gpu * cfg.gpu))
    print('EAST <==> Prepare <==> DataLoader <==> Done')

    # test dataloader
    # import numpy as np
    # import matplotlib.pyplot as plt
    # for batch_idx, (img, img_path, score_map, geo_map, training_mask) in enumerate(train_loader):
    #     print("batch index:", batch_idx, ",img batch shape", np.shape(geo_map.numpy()))
    #     h1 = img.numpy()[0].transpose(1, 2, 0).astype(np.int64)
    #     h2 = score_map.numpy()[0].transpose(1, 2, 0).astype(np.float32)[:, :, 0]
    #     plt.figure()
    #     plt.subplot(1, 2, 1)
    #     plt.imshow(h1)
    #     plt.subplot(1, 2, 2)
    #     plt.imshow(h2, cmap='gray')
    #     plt.show()

    # Model
    print('EAST <==> Prepare <==> Network <==> Begin')
    model = East()
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)
    criterion = loss.LossFunc().cuda()
    weight_loss = utils.Regularization(model, cfg.l2_weight_decay, p=2).cuda()

    pre_params = list(map(id, model.module.mobilenet.parameters()))
    post_params = filter(lambda p: id(p) not in pre_params, model.module.parameters())
    optimizer = torch.optim.Adam([
        {'params': model.module.mobilenet.parameters(), 'lr': cfg.pre_lr},
        {'params': post_params, 'lr': cfg.lr}])
    # decay rule: decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=cfg.decay_steps, gamma=cfg.decay_rate)
    model.cuda()

    # init or resume from a checkpoint
    if cfg.resume and os.path.isfile(cfg.checkpoint):
        start_epoch = utils.Loading_checkpoint(model, optimizer, scheduler)
    else:
        start_epoch = 0
    print('EAST <==> Prepare <==> Network <==> Done')

    tensorboard_writer = init_tensorboard_writer('tensorboards/{}'.format(str(int(time.time()))))

    # train the model
    for epoch in range(start_epoch, cfg.max_epochs):
        scheduler.step()
        fit(train_loader, model, criterion, optimizer, epoch, weight_loss, tensorboard_writer)
        # save the model
        if epoch % cfg.save_eval_iteration == 0:
            utils.save_checkpoint(epoch, model, optimizer, scheduler)
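# utils.Regularization is not defined in this snippet. A minimal sketch
# consistent with the call Regularization(model, cfg.l2_weight_decay, p=2):
# an nn.Module whose forward() returns weight_decay times the sum of p-norms
# over the weight parameters. This is an assumption, not the repo's code.
import torch
import torch.nn as nn

class Regularization(nn.Module):
    def __init__(self, model, weight_decay, p=2):
        super(Regularization, self).__init__()
        self.weight_decay = weight_decay
        self.p = p

    def forward(self, model):
        reg_loss = 0.0
        for name, param in model.named_parameters():
            if 'weight' in name:
                # accumulate the p-norm of each weight tensor
                reg_loss = reg_loss + torch.norm(param, p=self.p)
        return self.weight_decay * reg_loss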
def score(dev, latency, batch_size, num_batches):
    sym1, sym2 = East(isTrain=False)
    sym = mx.sym.Group([sym1, sym2])
    if 'cpu' in str(dev):
        sym = sym.get_backend_symbol('MKLDNN')
    # sym, arg, aux = onnx_mxnet.import_model("test.onnx")
    data_shape = [('data', (batch_size, 3, 1024, 1024))]
    mod = mx.mod.Module(symbol=sym, context=dev)
    mod.bind(for_training=False, inputs_need_grad=False, data_shapes=data_shape)
    mod.init_params(initializer=mx.init.Xavier(magnitude=2.))

    # get data
    data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=dev)
            for _, shape in mod.data_shapes]
    batch = mx.io.DataBatch(data, [])  # empty label

    # run
    dry_run = 5  # use 5 iterations to warm up
    for i in range(dry_run + num_batches):
        if i == dry_run:
            tic = time.time()
        mod.forward(batch, is_train=False)
        for output in mod.get_outputs():
            output.wait_to_read()
    if latency:
        logging.info('latency: %f ms', (time.time() - tic) / num_batches * 1000)

    # return num images per second
    return num_batches * batch_size / (time.time() - tic)
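# Example use of the benchmark above (an illustration; assumes an MXNet
# build where the MKL-DNN backend is available for the CPU path):
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    fps = score(dev=mx.cpu(), latency=True, batch_size=1, num_batches=10)
    logging.info('throughput: %.2f images/sec', fps)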
def addEast(locationName, locationInfo, pictureLink, top, left):
    session = createSession()
    east_object = East(top=top, left=left, locationName=locationName,
                       locationInfo=locationInfo, pictureLink=pictureLink)
    session.add(east_object)
    session.commit()
    session.close()
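# The East class used above is an ORM model, not the detector network, and
# is not shown here. A minimal SQLAlchemy sketch consistent with the keyword
# arguments; the table name and column types are assumptions.
from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

class East(Base):
    __tablename__ = 'east'
    id = Column(Integer, primary_key=True)
    top = Column(Integer)
    left = Column(Integer)
    locationName = Column(String)
    locationInfo = Column(String)
    pictureLink = Column(String)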
def main():
    root_path = '/home/mathu/Documents/express_recognition/data/telephone_txt/result/'
    train_img = root_path + 'print_pic'
    train_txt = root_path + 'print_txt'
    # root_path = '/home/mathu/Documents/express_recognition/data/icdar2015/'
    # train_img = root_path + 'train2015'
    # train_txt = root_path + 'train_label'

    trainset = custom_dset(train_img, train_txt)
    trainloader = DataLoader(trainset, batch_size=16, shuffle=True,
                             collate_fn=collate_fn, num_workers=4)

    model = East()
    model = model.cuda()
    model.load_state_dict(torch.load('./checkpoints_total/model_1440.pth'))
    crit = LossFunc()
    weight_decay = 0
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # weight_decay=1)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)

    train(epochs=1500, model=model, trainloader=trainloader, crit=crit,
          optimizer=optimizer, scheduler=scheduler, save_step=20,
          weight_decay=weight_decay)
    write.close()
def main():
    # prepare output directory
    # global epoch
    print('EAST <==> TEST <==> Create Res_file and Img_with_box <==> Begin')
    result_root = os.path.abspath(cfg.res_img_path)
    if not os.path.exists(result_root):
        os.mkdir(result_root)

    print('EAST <==> Prepare <==> Network <==> Begin')
    model = East()
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)
    model.cuda()

    # load the checkpoint
    if os.path.isfile(cfg.checkpoint):
        print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(cfg.checkpoint))
        checkpoint = torch.load(cfg.checkpoint)
        epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(cfg.checkpoint))
    else:
        print('Can not find checkpoint !!!')
        exit(1)

    predict(model, epoch)
def model_init(config):
    train_root_path = os.path.abspath(os.path.join(config["dataroot"], 'train'))
    train_img = os.path.join(train_root_path, 'img')
    train_gt = os.path.join(train_root_path, 'gt')

    trainset = custom_dset(train_img, train_gt)
    train_loader = DataLoader(trainset,
                              batch_size=config["train_batch_size_per_gpu"] * config["gpu"],
                              shuffle=True, collate_fn=collate_fn,
                              num_workers=config["num_workers"])
    logging.debug('Data loader created: Batch_size:{}, GPU {}:({})'.format(
        config["train_batch_size_per_gpu"] * config["gpu"], config["gpu"], config["gpu_ids"]))

    # Model
    model = East()
    model = nn.DataParallel(model, device_ids=config["gpu_ids"])
    model = model.cuda()
    init_weights(model, init_type=config["init_type"])
    logging.debug("Model initiated, init type: {}".format(config["init_type"]))
    cudnn.benchmark = True

    criterion = LossFunc()
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)

    # init or resume
    if config["resume"] and os.path.isfile(config["checkpoint"]):
        start_epoch = load_checkpoint(config, model, optimizer)
    else:
        start_epoch = 0

    logging.debug("Model is running...")
    return model, criterion, optimizer, scheduler, train_loader, start_epoch
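# Example of how model_init() might be consumed; the train() call signature
# here is borrowed from the training loop in another snippet of this
# collection, so treat it as an assumption.
import json

def run(config_path='config.json'):
    with open(config_path) as f:
        config = json.load(f)
    model, criterion, optimizer, scheduler, train_loader, start_epoch = model_init(config)
    for epoch in range(start_epoch, config["max_epochs"]):
        train(train_loader, model, criterion, scheduler, optimizer, epoch)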
def main():
    # prepare output directory
    # global epoch
    print('EAST <==> TEST <==> Create Res_file and Img_with_box <==> Begin')
    result_root = os.path.abspath(cfg.res_img_path)
    if not os.path.exists(result_root):
        os.mkdir(result_root)

    print('EAST <==> Prepare <==> Network <==> Begin')
    model = East()
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)
    # model.cuda()  # kept on CPU for the size measurements below

    if os.path.isfile(cfg.checkpoint):
        print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(cfg.checkpoint))
        checkpoint = torch.load(cfg.checkpoint, map_location='cpu')
        epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(cfg.checkpoint))
    else:
        print('Can not find checkpoint !!!')
        exit(1)

    print()
    print('###############')
    print()
    print('Original Size:')
    print_size_of_model(model)

    print()
    print('Pruned model size')
    import torch.nn.utils.prune as prune
    for name, module in model.named_modules():
        # prune 40% of connections in all 2D-conv layers
        # (assumes every conv/linear layer actually has a bias)
        if isinstance(module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=0.4)
            prune.l1_unstructured(module, name='bias', amount=0.3)
            prune.remove(module, 'weight')
            prune.remove(module, 'bias')
        # prune 40% of connections in all linear layers
        elif isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=0.4)
            prune.l1_unstructured(module, name='bias', amount=0.4)
            prune.remove(module, 'weight')
            prune.remove(module, 'bias')

    # NOTE: to_sparse() is a Tensor method, not an nn.Module one; as written
    # this line raises AttributeError unless the model defines it itself.
    model = model.to_sparse()
    # print(dict(model.named_buffers()).keys())
    print_size_of_model(model)
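# print_size_of_model() is used throughout these snippets but never defined.
# A common minimal implementation (an assumption): serialize the state_dict
# to a temporary file and report the file size.
import os
import torch

def print_size_of_model(model):
    torch.save(model.state_dict(), 'temp.p')
    print('Size (MB):', os.path.getsize('temp.p') / 1e6)
    os.remove('temp.p')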
def main():
    hmean = .0
    is_best = False

    warnings.simplefilter('ignore', np.RankWarning)

    # Prepare for dataset
    print('EAST <==> Prepare <==> DataLoader <==> Begin')
    # train_root_path = os.path.abspath(os.path.join('./dataset/', 'train'))
    train_root_path = cfg.dataroot
    train_img = os.path.join(train_root_path, 'img')
    train_gt = os.path.join(train_root_path, 'gt')

    trainset = custom_dset(train_img, train_gt)
    train_loader = DataLoader(trainset,
                              batch_size=cfg.train_batch_size_per_gpu * cfg.gpu,
                              shuffle=True, collate_fn=collate_fn,
                              num_workers=cfg.num_workers)
    print('EAST <==> Prepare <==> Batch_size:{} <==> Begin'.format(
        cfg.train_batch_size_per_gpu * cfg.gpu))
    print('EAST <==> Prepare <==> DataLoader <==> Done')

    # test dataloader
    """
    for i in range(100000):
        for j, (a, b, c, d) in enumerate(train_loader):
            print(i, j, '/', len(train_loader))
    """

    # Model
    print('EAST <==> Prepare <==> Network <==> Begin')
    model = East()
    model = nn.DataParallel(model, device_ids=cfg.gpu_ids)
    model = model.cuda()
    init_weights(model, init_type=cfg.init_type)
    cudnn.benchmark = True

    criterion = LossFunc()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)

    # init or resume
    if cfg.resume and os.path.isfile(cfg.checkpoint):
        weightpath = os.path.abspath(cfg.checkpoint)
        print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(weightpath))
        checkpoint = torch.load(weightpath)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(weightpath))
    else:
        start_epoch = 0
    print('EAST <==> Prepare <==> Network <==> Done')

    for epoch in range(start_epoch, cfg.max_epochs):
        train(train_loader, model, criterion, scheduler, optimizer, epoch)

        if epoch % cfg.eval_iteration == 0:
            # create res_file and img_with_box
            output_txt_dir_path = predict(model, criterion, epoch)
            # Zip file
            submit_path = MyZip(output_txt_dir_path, epoch)
            # submit and compute Hmean
            hmean_ = compute_hmean(submit_path)

            if hmean_ > hmean:
                is_best = True

            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'is_best': is_best,
            }
            save_checkpoint(state, epoch)
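# save_checkpoint() is not shown; a minimal sketch consistent with the state
# dict assembled above (the output path layout is an assumption):
import os
import torch

def save_checkpoint(state, epoch, out_dir='./checkpoints'):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    torch.save(state, os.path.join(out_dir, 'model_{}.pth'.format(epoch)))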
BATCH_SIZE = 8

X_train, X_val, y_train, y_val = train_test_split(image_paths, boxes,
                                                  test_size=0.35,
                                                  shuffle=True,
                                                  random_state=2021)
train_dataset = ReceiptDataset(X_train, y_train)
val_dataset = ReceiptDataset(X_val, y_val)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Model
EPOCHS = 50
model = East().to(device)
model.load_state_dict(torch.load('east1.pt'))
lr = 1e-4
loss_fn = Loss().to(device)
optimizer = Adam(model.parameters(), lr=lr)
best_val_loss = 0.455

train_loss = list()
val_loss = list()
for epoch in range(EPOCHS):
    print('Epoch {}'.format(epoch + 1))
    train_batch_loss = list()
    for X_batch_train, gt_score, gt_geo in tqdm(train_dataloader):
        X_batch_train = X_batch_train.to(device)
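        # (hedged sketch) The snippet breaks off here; one plausible
        # continuation of the training step, assuming the model returns
        # (score, geometry) and Loss follows the common EAST signature
        # loss_fn(gt_score, pred_score, gt_geo, pred_geo):
        gt_score, gt_geo = gt_score.to(device), gt_geo.to(device)
        pred_score, pred_geo = model(X_batch_train)
        loss = loss_fn(gt_score, pred_score, gt_geo, pred_geo)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_batch_loss.append(loss.item())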
def main():
    # prepare output directory
    # global epoch
    print('EAST <==> TEST <==> Create Res_file and Img_with_box <==> Begin')
    result_root = os.path.abspath(cfg.res_img_path)
    if not os.path.exists(result_root):
        os.mkdir(result_root)

    print('EAST <==> Prepare <==> Network <==> Begin')
    model = East()
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)
    # model.cuda()  # quantization runs on CPU

    # load the checkpoint
    if os.path.isfile(cfg.checkpoint):
        print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(cfg.checkpoint))
        checkpoint = torch.load(cfg.checkpoint, map_location='cpu')
        epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(cfg.checkpoint))
    else:
        print('Can not find checkpoint !!!')
        exit(1)

    print()
    print('###############')
    print()

    example = torch.rand(1, 3, 224, 224)
    # traced_script_module = torch.jit.trace(model, example)
    uninplace(model)

    # Groups of (conv, bn, relu) module names to fuse. The stem convs, the
    # s1-s4 merging stages, and the MobileNet backbone follow a regular
    # naming pattern, so the lists are built programmatically (the generated
    # names are identical to the original hand-written lists).
    l1 = [['module.conv{}'.format(i), 'module.bn{}'.format(i), 'module.relu{}'.format(i)]
          for i in range(1, 8)]

    def conv_groups(prefix, blocks):
        groups = []
        for i in blocks:
            p = '{}.{}.conv'.format(prefix, i)
            groups += [[p + '.0', p + '.1', p + '.2'],
                       [p + '.3', p + '.4', p + '.5'],
                       [p + '.6', p + '.7']]
        return groups

    l2 = conv_groups('module.s4', range(3))  # s4
    l3 = conv_groups('module.s3', range(7))  # s3
    l4 = conv_groups('module.s2', range(3))  # s2
    # s1 (blocks 0-2 are irregular)
    l5 = [['module.s1.0.0', 'module.s1.0.1', 'module.s1.0.2'],
          ['module.s1.1.conv.0', 'module.s1.1.conv.1', 'module.s1.1.conv.2'],
          ['module.s1.1.conv.3', 'module.s1.1.conv.4'],
          ['module.s1.2.conv.0', 'module.s1.2.conv.1', 'module.s1.2.conv.2'],
          ['module.s1.2.conv.3', 'module.s1.2.conv.4', 'module.s1.2.conv.5'],
          ['module.s1.2.conv.6', 'module.s1.2.conv.7']]
    # MobileNet features 0 and 1 (irregular)
    l6 = [['module.mobilenet.features.0.0', 'module.mobilenet.features.0.1', 'module.mobilenet.features.0.2'],
          ['module.mobilenet.features.1.conv.0', 'module.mobilenet.features.1.conv.1', 'module.mobilenet.features.1.conv.2'],
          ['module.mobilenet.features.1.conv.3', 'module.mobilenet.features.1.conv.4']]
    # MobileNet features 2-16 (regular inverted-residual blocks)
    l7 = conv_groups('module.mobilenet.features', range(2, 17))

    # modules_to_fuse = l1 + l2 + l3 + l4 + l5
    # print(model)
    print('Original Size:')
    print_size_of_model(model)
    print()

    # each fuse_modules() call must be chained on the previous result,
    # otherwise the earlier fusions are discarded
    fused_model = torch.quantization.fuse_modules(model, l7)
    fused_model = torch.quantization.fuse_modules(fused_model, l6)
    fused_model = torch.quantization.fuse_modules(fused_model, l1 + l2 + l3 + l4 + l5)
    print('Fused model Size:')
    print_size_of_model(fused_model)
    print()
    # print(fused_model)

    # fused_model.qconfig = torch.quantization.QConfig(activation=torch.quantization.default_histogram_observer, weight=torch.quantization.default_per_channel_weight_observer)
    fused_model.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(fused_model, inplace=True)

    # calibrate the observers on training images
    from data_loader import custom_dset
    from torchvision import transforms
    from torch.utils.data import DataLoader
    trainset = custom_dset(transform=transforms.ToTensor())
    train_loader = DataLoader(trainset,
                              batch_size=cfg.train_batch_size_per_gpu * cfg.gpu,
                              shuffle=True, num_workers=0)
    for i, (img, img_path, score_map, geo_map, training_mask) in enumerate(train_loader):
        f_score, f_geometry = fused_model(img)

    quantized = torch.quantization.convert(fused_model, inplace=False)
    print('Quantized model Size:')
    print_size_of_model(quantized)
    print('Done')
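# uninplace() is called before fusion but not defined in this snippet.
# fuse_modules() cannot fuse in-place activations, so a minimal sketch
# (an assumption) flips the inplace flag off on every ReLU/ReLU6:
import torch.nn as nn

def uninplace(model):
    for m in model.modules():
        if isinstance(m, (nn.ReLU, nn.ReLU6)):
            m.inplace = False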
class EASTPredictor(object):
    def __init__(self, lr, weight_path, output_path):
        self.output_path = output_path
        self.model = East()
        self.model = nn.DataParallel(self.model, device_ids=[0])
        self.model = self.model.cuda()
        init_weights(self.model, init_type='xavier')
        cudnn.benchmark = True
        self.criterion = LossFunc()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=10000, gamma=0.94)

        self.weightpath = os.path.abspath(weight_path)
        logging.debug("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(self.weightpath))
        checkpoint = torch.load(self.weightpath)
        self.start_epoch = checkpoint['epoch']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        logging.debug("EAST <==> Prepare <==> Loading checkpoint '{}', epoch={} <==> Done".format(self.weightpath, self.start_epoch))
        self.model.eval()

    def resize_image(self, im, max_side_len=2400):
        '''
        Resize an image to a size multiple of 32, as required by the network.
        :param im: the image to resize
        :param max_side_len: limit on the max image side to avoid running out of GPU memory
        :return: the resized image and the resize ratios (ratio_h, ratio_w)
        '''
        h, w, _ = im.shape
        resize_w = w
        resize_h = h
        resize_h = resize_h if resize_h % 32 == 0 else (resize_h // 32 - 1) * 32
        resize_w = resize_w if resize_w % 32 == 0 else (resize_w // 32 - 1) * 32
        # resize_h, resize_w = 512, 512
        im = cv2.resize(im, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return im, (ratio_h, ratio_w)

    def detect(self, score_map, geo_map, score_map_thresh=1e-5, box_thresh=1e-8, nms_thres=0.1):
        '''
        Restore text boxes from the score map and geometry map.
        :param score_map:
        :param geo_map:
        :param score_map_thresh: threshold for the score map
        :param box_thresh: threshold for boxes
        :param nms_thres: threshold for NMS
        :return:
        '''
        if len(score_map.shape) == 4:
            score_map = score_map[0, :, :, 0]
            geo_map = geo_map[0, :, :, ]
        # filter the score map
        xy_text = np.argwhere(score_map > score_map_thresh)
        # sort the text boxes via the y axis
        xy_text = xy_text[np.argsort(xy_text[:, 0])]
        # restore
        start = time.time()
        text_box_restored = restore_rectangle(xy_text[:, ::-1] * 4,
                                              geo_map[xy_text[:, 0], xy_text[:, 1], :])  # N*4*2
        logging.debug('{} text boxes before nms'.format(text_box_restored.shape[0]))
        boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
        boxes[:, :8] = text_box_restored.reshape((-1, 8))
        boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
        # nms part
        start = time.time()
        # boxes = nms_locality.nms_locality(boxes.astype(np.float64), nms_thres)
        logging.debug('{} boxes before merging'.format(boxes.shape[0]))
        boxes = lanms.merge_quadrangle_n9(boxes.astype('float32'), nms_thres)
        if boxes.shape[0] == 0:
            return None
        logging.debug('{} boxes before checking scores'.format(boxes.shape[0]))
        # here we filter some low-score boxes by the average score map; this is different from the original paper
        for i, box in enumerate(boxes):
            mask = np.zeros_like(score_map, dtype=np.uint8)
            cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1)
            boxes[i, 8] = cv2.mean(score_map, mask)[0]
        boxes = boxes[boxes[:, 8] > box_thresh]
        return boxes

    def sort_poly(self, p):
        min_axis = np.argmin(np.sum(p, axis=1))
        p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]]
        if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
            return p
        else:
            return p[[0, 3, 2, 1]]

    def predict_one_file(self, img_file):
        im = cv2.imread(img_file)[:, :, ::-1]
        return self.predict_one_image(im, img_file)

    def predict_one_image(self, im, img_file):
        im_resized, (ratio_h, ratio_w) = self.resize_image(im)
        im_resized = im_resized.astype(np.float32)
        im_resized = im_resized.transpose(2, 0, 1)
        im_resized = torch.from_numpy(im_resized)
        im_resized = im_resized.cuda()
        im_resized = im_resized.unsqueeze(0)

        score, geometry = self.model(im_resized)
        score = score.permute(0, 2, 3, 1)
        geometry = geometry.permute(0, 2, 3, 1)
        score = score.data.cpu().numpy()
        geometry = geometry.data.cpu().numpy()

        boxes = self.detect(score_map=score, geo_map=geometry)
        letters = None
        if boxes is not None:
            boxes = boxes[:, :8].reshape((-1, 4, 2))
            boxes[:, :, 0] /= ratio_w
            boxes[:, :, 1] /= ratio_h
            logging.debug("found {} boxes".format(len(boxes)))
            fstem = pathlib.Path(img_file).stem
            letters = self.save_boxes(os.path.join(self.output_path, fstem + "_boxes.txt"), boxes)
            cv2.imwrite(os.path.join(self.output_path, fstem + "_with_box.jpg"), im[:, :, ::-1])
        else:
            logging.debug("Did not find boxes")
        return letters

    def save_boxes(self, filename, boxes):
        letters = []
        with open(filename, 'w+') as f:
            for box in boxes:
                box = self.sort_poly(box.astype(np.int32))
                if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5:
                    logging.debug('wrong direction')
                    continue
                # if box[0, 0] < 0 or box[0, 1] < 0 or box[1, 0] < 0 or box[1, 1] < 0 or box[2, 0] < 0 or box[2, 1] < 0 or box[3, 0] < 0 or box[3, 1] < 0:
                #     logging.debug("wrong box, {}".format(box))
                #     continue
                for x in range(4):
                    for y in [0, 1]:
                        if box[x, y] < 0:
                            box[x, y] = 0
                poly = np.array([[box[0, 0], box[0, 1]], [box[1, 0], box[1, 1]],
                                 [box[2, 0], box[2, 1]], [box[3, 0], box[3, 1]]])
                p_area = polygon_area(poly)
                if p_area > 0:
                    poly = poly[(0, 3, 2, 1), :]
                f.write('{},{},{},{},{},{},{},{}\r\n'.format(
                    poly[0, 0], poly[0, 1], poly[1, 0], poly[1, 1],
                    poly[2, 0], poly[2, 1], poly[3, 0], poly[3, 1]))
                # cv2.polylines(im[:, :, ::-1], [box.astype(np.int32).reshape((-1, 1, 2))], True, color=(255, 0, 0), thickness=1)
                letters.append(('', poly[0, 0], poly[0, 1], poly[2, 0], poly[2, 1]))
        return letters
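# Example use of the predictor above (paths are placeholders):
logging.basicConfig(level=logging.DEBUG)
predictor = EASTPredictor(lr=1e-4, weight_path='./checkpoints/model.pth',
                          output_path='./output')
letters = predictor.predict_one_file('./demo/test.jpg')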
def main():
    # prepare output directory
    # global epoch
    print('EAST <==> TEST <==> Create Res_file and Img_with_box <==> Begin')
    result_root = os.path.abspath(cfg.res_img_path)
    if not os.path.exists(result_root):
        os.mkdir(result_root)

    print('EAST <==> Prepare <==> Network <==> Begin')
    model = East()
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)
    # model.cuda()  # quantization runs on CPU

    # load the checkpoint
    if os.path.isfile(cfg.checkpoint):
        print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(cfg.checkpoint))
        checkpoint = torch.load(cfg.checkpoint, map_location='cpu')
        epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(cfg.checkpoint))
    else:
        print('Can not find checkpoint !!!')
        exit(1)

    print()
    print('###############')
    print()

    example = torch.rand(1, 3, 224, 224)
    # traced_script_module = torch.jit.trace(model, example)
    uninplace(model)

    # Groups of (conv, bn, relu) module names to fuse, built programmatically
    # from the regular naming pattern exactly as in the previous snippet.
    l1 = [['module.conv{}'.format(i), 'module.bn{}'.format(i), 'module.relu{}'.format(i)]
          for i in range(1, 8)]

    def conv_groups(prefix, blocks):
        groups = []
        for i in blocks:
            p = '{}.{}.conv'.format(prefix, i)
            groups += [[p + '.0', p + '.1', p + '.2'],
                       [p + '.3', p + '.4', p + '.5'],
                       [p + '.6', p + '.7']]
        return groups

    l2 = conv_groups('module.s4', range(3))  # s4
    l3 = conv_groups('module.s3', range(7))  # s3
    l4 = conv_groups('module.s2', range(3))  # s2
    # s1 (blocks 0-2 are irregular)
    l5 = [['module.s1.0.0', 'module.s1.0.1', 'module.s1.0.2'],
          ['module.s1.1.conv.0', 'module.s1.1.conv.1', 'module.s1.1.conv.2'],
          ['module.s1.1.conv.3', 'module.s1.1.conv.4'],
          ['module.s1.2.conv.0', 'module.s1.2.conv.1', 'module.s1.2.conv.2'],
          ['module.s1.2.conv.3', 'module.s1.2.conv.4', 'module.s1.2.conv.5'],
          ['module.s1.2.conv.6', 'module.s1.2.conv.7']]
    # MobileNet features 0 and 1 (irregular)
    l6 = [['module.mobilenet.features.0.0', 'module.mobilenet.features.0.1', 'module.mobilenet.features.0.2'],
          ['module.mobilenet.features.1.conv.0', 'module.mobilenet.features.1.conv.1', 'module.mobilenet.features.1.conv.2'],
          ['module.mobilenet.features.1.conv.3', 'module.mobilenet.features.1.conv.4']]
    # MobileNet features 2-16 (regular inverted-residual blocks)
    l7 = conv_groups('module.mobilenet.features', range(2, 17))

    # modules_to_fuse = l1 + l2 + l3 + l4 + l5
    # print(model)
    print('Original Size:')
    print_size_of_model(model)
    print()

    # each fuse_modules() call must be chained on the previous result,
    # otherwise the earlier fusions are discarded
    fused_model = torch.quantization.fuse_modules(model, l7)
    fused_model = torch.quantization.fuse_modules(fused_model, l6)
    fused_model = torch.quantization.fuse_modules(fused_model, l1 + l2 + l3 + l4 + l5)
    print('Fused model Size:')
    print_size_of_model(fused_model)
    print()
    # print(fused_model)

    # fused_model.qconfig = torch.quantization.QConfig(activation=torch.quantization.default_histogram_observer, weight=torch.quantization.default_per_channel_weight_observer)
    fused_model.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(fused_model, inplace=True)

    # calibrate the observers on training images (on CPU)
    from data_loader import custom_dset
    from torchvision import transforms
    from torch.utils.data import DataLoader
    trainset = custom_dset(transform=transforms.ToTensor())
    device = torch.device('cpu')
    train_loader = DataLoader(trainset,
                              batch_size=cfg.train_batch_size_per_gpu * cfg.gpu,
                              shuffle=True, num_workers=0)
    for i, (img, img_path, score_map, geo_map, training_mask) in enumerate(train_loader):
        img, score_map, geo_map, training_mask = (img.to(device), score_map.to(device),
                                                  geo_map.to(device), training_mask.to(device))
        f_score, f_geometry = fused_model(img)

    quantized = torch.quantization.convert(fused_model, inplace=False)
    print('Quantized model Size:')
    print_size_of_model(quantized)

    num_train_batches = 20
    print('***QAT***')
    print()
    criterion = loss.LossFunc()
    pre_params = list(map(id, model.module.mobilenet.parameters()))
    post_params = filter(lambda p: id(p) not in pre_params, model.module.parameters())
    optimizer = torch.optim.Adam([
        {'params': model.module.mobilenet.parameters(), 'lr': cfg.pre_lr},
        {'params': post_params, 'lr': cfg.lr}])

    # NOTE: switching to a QAT qconfig normally also requires
    # torch.quantization.prepare_qat() before training; this code trains
    # the model prepared above as-is.
    fused_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

    # Train and check accuracy after each epoch
    for nepoch in range(8):
        train_one_epoch(fused_model, criterion, optimizer, train_loader,
                        torch.device('cpu'), num_train_batches)
        if nepoch > 3:
            # Freeze quantizer parameters
            fused_model.apply(torch.quantization.disable_observer)
        if nepoch > 2:
            # Freeze batch norm mean and variance estimates
            fused_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    quantized_model = torch.quantization.convert(fused_model.eval(), inplace=False)
    print('QAT model Size:')
    print_size_of_model(quantized_model)
    print('Done')
    print(quantized)
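# train_one_epoch() is not shown; a hedged sketch consistent with the call
# above (model, criterion, optimizer, loader, device, batch budget). The
# batch layout and LossFunc signature are assumed from the calibration loop
# and the other training snippets in this collection.
def train_one_epoch(model, criterion, optimizer, loader, device, num_batches):
    model.train()
    for i, (img, img_path, score_map, geo_map, training_mask) in enumerate(loader):
        if i >= num_batches:
            break  # stay within the per-epoch batch budget
        img = img.to(device)
        score_map, geo_map, training_mask = (score_map.to(device),
                                             geo_map.to(device),
                                             training_mask.to(device))
        f_score, f_geometry = model(img)
        loss = criterion(score_map, f_score, geo_map, f_geometry, training_mask)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()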
def main():
    warnings.simplefilter('ignore', np.RankWarning)

    video_root_path = os.path.abspath('./dataset/train/')
    video_name_list = sorted(
        [p for p in os.listdir(video_root_path) if p.split('_')[0] == 'Video'])
    # print('video_name_list', video_name_list)

    # Model
    print('EAST <==> Prepare <==> Network <==> Begin')
    model = East()
    AGD_model = AGD()
    model = nn.DataParallel(model, device_ids=cfg.gpu_ids)
    # AGD_model = nn.DataParallel(AGD_model, device_ids=cfg.gpu_ids)
    model = model.cuda()
    AGD_model = AGD_model.cuda()
    init_weights(model, init_type=cfg.init_type)
    cudnn.benchmark = True

    criterion1 = LossFunc()
    criterion2 = Ass_loss()  # needed by train() below; was commented out in the original
    optimizer1 = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    optimizer2 = torch.optim.Adam(AGD_model.parameters(), lr=cfg.lr)
    scheduler = lr_scheduler.StepLR(optimizer1, step_size=10000, gamma=0.94)

    # init or resume
    if cfg.resume and os.path.isfile(cfg.checkpoint):
        weightpath = os.path.abspath(cfg.checkpoint)
        print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(weightpath))
        checkpoint = torch.load(weightpath)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        # AGD_model.load_state_dict(checkpoint['model2.state_dict'])
        optimizer1.load_state_dict(checkpoint['optimizer'])
        # optimizer2.load_state_dict(checkpoint['optimizer2'])
        print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(weightpath))
    else:
        start_epoch = 0
    print('EAST <==> Prepare <==> Network <==> Done')

    for epoch in range(start_epoch + 1, cfg.max_epochs):
        for video_name in video_name_list:
            print('EAST <==> epoch:{} <==> Prepare <==> DataLoader <==>{} Begin'.format(epoch, video_name))
            trainset = custom_dset(os.path.join(video_root_path, video_name))
            # sampler = sampler_for_video_clip(len(trainset))
            train_loader = DataLoader(trainset,
                                      batch_size=cfg.train_batch_size_per_gpu * cfg.gpu,
                                      shuffle=False, collate_fn=collate_fn,
                                      num_workers=cfg.num_workers, drop_last=True)
            print('EAST <==> Prepare <==> Batch_size:{} <==> Begin'.format(cfg.train_batch_size_per_gpu * cfg.gpu))
            print('EAST <==> epoch:{} <==> Prepare <==> DataLoader <==>{} Done'.format(epoch, video_name))

            train(train_loader, model, AGD_model, criterion1, criterion2,
                  scheduler, optimizer1, optimizer2, epoch)
            '''
            for i, (img, score_map, geo_map, training_mask, coord_ids) in enumerate(train_loader):
                print('i{} img.shape:{} geo_map.shape{} training_mask.shape{} coord_ids.len{}'.format(
                    i, score_map.shape, geo_map.shape, training_mask.shape, len(coord_ids)))
            '''

        if epoch % cfg.eval_iteration == 0:
            state = {
                'epoch': epoch,
                'model1.state_dict': model.state_dict(),
                'model2.state_dict': AGD_model.state_dict(),
                'optimizer1': optimizer1.state_dict(),
                'optimizer2': optimizer2.state_dict()
            }
            save_checkpoint(state, epoch)