Example #1
def main():
    global args
    args = parser.parse_args()
    init_log('global', logging.INFO)
    logger = logging.getLogger('global')

    train_data = custom_dset(args.data_img, args.data_txt)
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                              collate_fn=collate_fn, num_workers=args.workers)
    logger.info("==============Build Dataset Done==============")

    model = East(args.pretrained)
    logger.info("==============Build Model Done================")
    logger.info(model)

    model = torch.nn.DataParallel(model).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            pretrained_dict = torch.load(args.resume)
            model.load_state_dict(pretrained_dict, strict=True)
            logger.info("=> loaded checkpoint '{}'".format(args.resume))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    crit = LossFunc()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)

    train(epochs=args.epochs, model=model, train_loader=train_loader,
          crit=crit, optimizer=optimizer, scheduler=scheduler,
          save_step=args.save_freq, weight_decay=args.weight_decay)
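Example #1 reads its hyperparameters from a module-level `parser` that is not shown; a minimal sketch of an argparse definition consistent with the `args.*` fields used above (flag names and defaults are assumptions inferred from the example, not the original project's CLI):

import argparse

# Hypothetical parser matching the args.* attributes referenced in Example #1.
parser = argparse.ArgumentParser(description='EAST training')
parser.add_argument('--data-img', type=str, help='directory of training images')
parser.add_argument('--data-txt', type=str, help='directory of label files')
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--workers', type=int, default=4)
parser.add_argument('--pretrained', type=str, default='', help='backbone weights')
parser.add_argument('--resume', type=str, default='', help='checkpoint to resume from')
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--epochs', type=int, default=1500)
parser.add_argument('--save-freq', type=int, default=20)
parser.add_argument('--weight-decay', type=float, default=0)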
Example #2
def main():
    root_path = '/home/mathu/Documents/express_recognition/data/telephone_txt/result/'
    train_img = root_path + 'print_pic'
    train_txt = root_path + 'print_txt'
    # root_path = '/home/mathu/Documents/express_recognition/data/icdar2015/'
    # train_img = root_path + 'train2015'
    # train_txt = root_path + 'train_label'

    trainset = custom_dset(train_img, train_txt)
    trainloader = DataLoader(trainset,
                             batch_size=16,
                             shuffle=True,
                             collate_fn=collate_fn,
                             num_workers=4)
    model = East()
    model = model.cuda()
    model.load_state_dict(torch.load('./checkpoints_total/model_1440.pth'))

    crit = LossFunc()
    weight_decay = 0
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    #  weight_decay=1)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)

    train(epochs=1500,
          model=model,
          trainloader=trainloader,
          crit=crit,
          optimizer=optimizer,
          scheduler=scheduler,
          save_step=20,
          weight_decay=weight_decay)

    write.close()
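`write.close()` refers to a module-level object that is not defined in this snippet; presumably it is a TensorBoard summary writer created next to `main()`. A sketch under that assumption (the original may use tensorboardX instead, and the log directory is illustrative):

from torch.utils.tensorboard import SummaryWriter

# Assumed module-level writer that train() logs to and main() closes.
write = SummaryWriter(log_dir='runs/east')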
Example #3
def main():
    # prepare output directory
    # global epoch
    print('EAST <==> TEST <==> Create Res_file and Img_with_box <==> Begin')
    result_root = os.path.abspath(cfg.res_img_path)
    if not os.path.exists(result_root):
        os.mkdir(result_root)

    print('EAST <==> Prepare <==> Network <==> Begin')
    model = East()
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)
    model.cuda()
    # Load the model
    if os.path.isfile(cfg.checkpoint):
        print(
            "EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(
                cfg.checkpoint))
        checkpoint = torch.load(cfg.checkpoint)
        epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print(
            "EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(
                cfg.checkpoint))
    else:
        print('Can not find checkpoint !!!')
        exit(1)

    predict(model, epoch)
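The test script above only requires a checkpoint dictionary carrying the 'epoch' and 'state_dict' keys; a minimal sketch of a writer that produces such a file (the helper name and path are illustrative, not from the original project):

import torch

def save_compatible_checkpoint(model, epoch, path):
    """Sketch: write a checkpoint with exactly the two keys Example #3 reads back."""
    torch.save({'epoch': epoch, 'state_dict': model.state_dict()}, path)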
Example #4
def score(dev, latency, batch_size, num_batches):
    sym1, sym2 = East(isTrain=False)
    sym = mx.sym.Group([sym1, sym2])

    if 'cpu' in str(dev):
        sym = sym.get_backend_symbol('MKLDNN')
    # sym, arg, aux = onnx_mxnet.import_model("test.onnx")

    data_shape = [('data', (batch_size, 3, 1024, 1024))]
    mod = mx.mod.Module(symbol=sym, context=dev)
    mod.bind(for_training=False,
             inputs_need_grad=False,
             data_shapes=data_shape)
    mod.init_params(initializer=mx.init.Xavier(magnitude=2.))

    # get data
    data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=dev) for _, shape in mod.data_shapes]
    batch = mx.io.DataBatch(data, []) # empty label

    # run
    dry_run = 5                 # use 5 iterations to warm up
    for i in range(dry_run + num_batches):
        if i == dry_run:
            tic = time.time()
        mod.forward(batch, is_train=False)
        for output in mod.get_outputs():
            output.wait_to_read()

    if latency:
        logging.info('latency: %f ms', (time.time() - tic) / num_batches * 1000)
    # return num images per second
    return num_batches * batch_size / (time.time() - tic)
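A sketch of how this MXNet benchmark might be invoked (the device, batch size, and iteration count are illustrative):

import logging
import mxnet as mx

logging.basicConfig(level=logging.INFO)

# Measure throughput on GPU 0 with batch size 1, also logging latency.
fps = score(dev=mx.gpu(0), latency=True, batch_size=1, num_batches=20)
logging.info('throughput: %.2f images/sec', fps)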
Example #5
def main():
    root_path = './dataset/'
    train_img = root_path + 'train2015/'
    train_txt = root_path + 'train_label/'

    trainset = custom_dset(train_img, train_txt)
    print(trainset)
    trainloader = DataLoader(trainset,
                             batch_size=16,
                             shuffle=True,
                             collate_fn=collate_fn,
                             num_workers=4)
    model = East()
    model = model.cuda()

    crit = LossFunc()
    weight_decay = 0
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    #  weight_decay=1)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)

    train(epochs=1500,
          model=model,
          trainloader=trainloader,
          crit=crit,
          optimizer=optimizer,
          scheduler=scheduler,
          save_step=20,
          weight_decay=weight_decay)

    write.close()
Example #6
    def __init__(self, lr, weight_path, output_path):
        self.output_path = output_path
        self.model = East()
        self.model = nn.DataParallel(self.model, device_ids=[0])
        self.model = self.model.cuda()
        init_weights(self.model, init_type='xavier')
        cudnn.benchmark = True

        self.criterion = LossFunc()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.scheduler = lr_scheduler.StepLR(self.optimizer,
                                             step_size=10000,
                                             gamma=0.94)
        self.weightpath = os.path.abspath(weight_path)
        logging.debug(
            "EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(
                self.weightpath))
        checkpoint = torch.load(self.weightpath)

        self.start_epoch = checkpoint['epoch']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        logging.debug(
            "EAST <==> Prepare <==> Loading checkpoint '{}', epoch={} <==> Done"
            .format(self.weightpath, self.start_epoch))
        self.model.eval()
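The constructor above relies on `init_weights(self.model, init_type='xavier')`, which is not shown; a minimal sketch, assuming it applies Xavier initialization to conv and linear layers (the project's helper may cover more init types):

import torch.nn as nn

def init_weights(model, init_type='xavier'):
    """Sketch: initialize conv/linear weights; only 'xavier' is handled here."""
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            if init_type == 'xavier':
                nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)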
Example #7
def main():

    # Prepare for dataset
    print('EAST <==> Prepare <==> DataLoader <==> Begin')
    trainset = custom_dset(transform=transforms.ToTensor())
    train_loader = DataLoader(trainset, batch_size=cfg.train_batch_size_per_gpu * cfg.gpu,
                              shuffle=True, num_workers=cfg.num_workers)
    print('EAST <==> Prepare <==> Batch_size:{} <==> Begin'.format(cfg.train_batch_size_per_gpu * cfg.gpu))
    print('EAST <==> Prepare <==> DataLoader <==> Done')

    # test dataloader
    # import numpy as np
    # import matplotlib.pyplot as plt
    # for batch_idx, (img, img_path, score_map, geo_map, training_mask) in enumerate(train_loader):
    #     print("batch index:", batch_idx, ",img batch shape", np.shape(geo_map.numpy()))
    #     h1 = img.numpy()[0].transpose(1, 2, 0).astype(np.int64)
    #     h2 = score_map.numpy()[0].transpose(1, 2, 0).astype(np.float32)[:, :, 0]
    #     plt.figure()
    #     plt.subplot(1, 2, 1)
    #     plt.imshow(h1)
    #     plt.subplot(1, 2, 2)
    #     plt.imshow(h2, cmap='gray')
    #     plt.show()

    # Model
    print('EAST <==> Prepare <==> Network <==> Begin')
    model = East()
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)
    criterion = loss.LossFunc().cuda()
    weight_loss = utils.Regularization(model, cfg.l2_weight_decay, p=2).cuda()

    pre_params = list(map(id, model.module.mobilenet.parameters()))
    post_params = filter(lambda p: id(p) not in pre_params, model.module.parameters())
    optimizer = torch.optim.Adam([{'params': model.module.mobilenet.parameters(), 'lr': cfg.pre_lr},
                                  {'params': post_params, 'lr': cfg.lr}])
    # Decay rule: decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=cfg.decay_steps, gamma=cfg.decay_rate)
    model.cuda()

    # init or resume from a checkpoint
    if cfg.resume and os.path.isfile(cfg.checkpoint):
        start_epoch = utils.Loading_checkpoint(model, optimizer, scheduler)
    else:
        start_epoch = 0

    print('EAST <==> Prepare <==> Network <==> Done')

    tensorboard_writer = init_tensorboard_writer('tensorboards/{}'.format(str(int(time.time()))))

    # train Model
    for epoch in range(start_epoch, cfg.max_epochs):

        scheduler.step()
        fit(train_loader, model, criterion, optimizer, epoch, weight_loss, tensorboard_writer)

        # Save a checkpoint
        if epoch % cfg.save_eval_iteration == 0:
            utils.save_checkpoint(epoch, model, optimizer, scheduler)
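Example #7 adds an explicit weight penalty through `utils.Regularization(model, cfg.l2_weight_decay, p=2)`; the project helper is not shown, but a minimal sketch, assuming it scales the sum of p-norms of the weight tensors (the real implementation may instead take the model as a `forward` argument):

import torch
import torch.nn as nn

class Regularization(nn.Module):
    """Sketch: weight_decay * sum of p-norms over weight parameters."""
    def __init__(self, model, weight_decay, p=2):
        super().__init__()
        self.model = model
        self.weight_decay = weight_decay
        self.p = p

    def forward(self):
        reg = sum(torch.norm(param, p=self.p)
                  for name, param in self.model.named_parameters()
                  if 'weight' in name)
        return self.weight_decay * reg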
Example #8
def addEast(locationName, locationInfo, pictureLink, top, left):
    session = createSession()
    east_object = East(top=top,
                       left=left,
                       locationName=locationName,
                       locationInfo=locationInfo,
                       pictureLink=pictureLink)
    session.add(east_object)
    session.commit()
    session.close()
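In `addEast` the name `East` is not the detector but a database record; a minimal SQLAlchemy mapping consistent with the fields passed above (table name and column types are assumptions):

from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class East(Base):
    """Sketch of the ORM model addEast() appears to populate; columns are assumed."""
    __tablename__ = 'east'

    id = Column(Integer, primary_key=True)
    top = Column(Integer)
    left = Column(Integer)
    locationName = Column(String(250))
    locationInfo = Column(String(250))
    pictureLink = Column(String(250))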
Example #9
def main():
    # prepare output directory
    # global epoch
    print('EAST <==> TEST <==> Create Res_file and Img_with_box <==> Begin')
    result_root = os.path.abspath(cfg.res_img_path)
    if not os.path.exists(result_root):
        os.mkdir(result_root)

    print('EAST <==> Prepare <==> Network <==> Begin')
    model = East()
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)
    model  # .cuda() is commented out; this example runs the model on the CPU
    if os.path.isfile(cfg.checkpoint):
        print(
            "EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(
                cfg.checkpoint))
        checkpoint = torch.load(cfg.checkpoint, map_location='cpu')
        epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print(
            "EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(
                cfg.checkpoint))
    else:
        print('Can not find checkpoint !!!')
        exit(1)
    print()
    print('###############')
    print()

    print('Original Size:')
    print_size_of_model(model)

    ###############
    print()
    print('Pruned model size')
    import torch.nn.utils.prune as prune
    for name, module in model.named_modules():
        # prune 40% of weights and 30% of biases in all 2D conv layers
        if isinstance(module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=0.4)
            prune.l1_unstructured(module, name='bias', amount=0.3)
            prune.remove(module, 'weight')
            prune.remove(module, 'bias')
        # prune 40% of weights and biases in all linear layers
        elif isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=0.4)
            prune.l1_unstructured(module, name='bias', amount=0.4)
            prune.remove(module, 'weight')
            prune.remove(module, 'bias')
    model = model.to_sparse()
    #print(dict(model.named_buffers()).keys())
    print_size_of_model(model)
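`print_size_of_model` is not defined in the snippet; a common implementation (an assumption, following the pattern used in the PyTorch quantization tutorials) serializes the state dict to a temporary file and reports its size:

import os
import torch

def print_size_of_model(model):
    """Sketch: report the on-disk size of the model's parameters in MB."""
    torch.save(model.state_dict(), 'temp.p')
    print('Size (MB):', os.path.getsize('temp.p') / 1e6)
    os.remove('temp.p')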
Example #10
def model_init(config):
    train_root_path = os.path.abspath(os.path.join(config["dataroot"],
                                                   'train'))
    train_img = os.path.join(train_root_path, 'img')
    train_gt = os.path.join(train_root_path, 'gt')

    trainset = custom_dset(train_img, train_gt)
    train_loader = DataLoader(trainset,
                              batch_size=config["train_batch_size_per_gpu"] *
                              config["gpu"],
                              shuffle=True,
                              collate_fn=collate_fn,
                              num_workers=config["num_workers"])

    logging.debug('Data loader created: Batch_size:{}, GPU {}:({})'.format(
        config["train_batch_size_per_gpu"] * config["gpu"], config["gpu"],
        config["gpu_ids"]))

    # Model
    model = East()
    model = nn.DataParallel(model, device_ids=config["gpu_ids"])
    model = model.cuda()
    init_weights(model, init_type=config["init_type"])
    logging.debug("Model initiated, init type: {}".format(config["init_type"]))

    cudnn.benchmark = True
    criterion = LossFunc()
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)

    # init or resume
    if config["resume"] and os.path.isfile(config["checkpoint"]):
        start_epoch = load_checkpoint(config, model, optimizer)
    else:
        start_epoch = 0
    logging.debug("Model is running...")
    return model, criterion, optimizer, scheduler, train_loader, start_epoch
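`load_checkpoint(config, model, optimizer)` is assumed to restore the weights and optimizer state and return the epoch to resume from; a minimal sketch consistent with how it is called above (the real helper may also restore the scheduler):

import os
import torch

def load_checkpoint(config, model, optimizer):
    """Sketch: restore model/optimizer state and return the stored epoch."""
    checkpoint = torch.load(os.path.abspath(config["checkpoint"]))
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch']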
Example #11
def main():
    hmean = .0
    is_best = False

    warnings.simplefilter('ignore', np.RankWarning)
    # Prepare for dataset
    print('EAST <==> Prepare <==> DataLoader <==> Begin')
    # train_root_path = os.path.abspath(os.path.join('./dataset/', 'train'))
    train_root_path = cfg.dataroot
    train_img = os.path.join(train_root_path, 'img')
    train_gt = os.path.join(train_root_path, 'gt')

    trainset = custom_dset(train_img, train_gt)
    train_loader = DataLoader(trainset,
                              batch_size=cfg.train_batch_size_per_gpu *
                              cfg.gpu,
                              shuffle=True,
                              collate_fn=collate_fn,
                              num_workers=cfg.num_workers)
    print('EAST <==> Prepare <==> Batch_size:{} <==> Begin'.format(
        cfg.train_batch_size_per_gpu * cfg.gpu))
    print('EAST <==> Prepare <==> DataLoader <==> Done')

    # test dataloader
    """
    for i in range(100000):
        for j, (a,b,c,d) in enumerate(train_loader):
            print(i, j,'/',len(train_loader))
    """

    # Model
    print('EAST <==> Prepare <==> Network <==> Begin')
    model = East()
    model = nn.DataParallel(model, device_ids=cfg.gpu_ids)
    model = model.cuda()
    init_weights(model, init_type=cfg.init_type)
    cudnn.benchmark = True

    criterion = LossFunc()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)

    # init or resume
    if cfg.resume and os.path.isfile(cfg.checkpoint):
        weightpath = os.path.abspath(cfg.checkpoint)
        print(
            "EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(
                weightpath))
        checkpoint = torch.load(weightpath)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print(
            "EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(
                weightpath))
    else:
        start_epoch = 0
    print('EAST <==> Prepare <==> Network <==> Done')

    for epoch in range(start_epoch, cfg.max_epochs):

        train(train_loader, model, criterion, scheduler, optimizer, epoch)

        if epoch % cfg.eval_iteration == 0:

            # create res_file and img_with_box
            output_txt_dir_path = predict(model, criterion, epoch)

            # Zip file
            submit_path = MyZip(output_txt_dir_path, epoch)

            # submit and compute Hmean
            hmean_ = compute_hmean(submit_path)

            is_best = hmean_ > hmean
            if is_best:
                hmean = hmean_

            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'is_best': is_best,
            }
            save_checkpoint(state, epoch)
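`save_checkpoint(state, epoch)` is assumed to write the dictionary built above to disk; a minimal sketch (the output directory and filename pattern are illustrative):

import os
import torch

def save_checkpoint(state, epoch, out_dir='./checkpoints'):
    """Sketch: persist the training state assembled in Example #11."""
    os.makedirs(out_dir, exist_ok=True)
    torch.save(state, os.path.join(out_dir, 'model_{}.pth.tar'.format(epoch)))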
Example #12
BATCH_SIZE = 8

X_train, X_val, y_train, y_val = train_test_split(image_paths,
                                                  boxes,
                                                  test_size=0.35,
                                                  shuffle=True,
                                                  random_state=2021)
train_dataset = ReceiptDataset(X_train, y_train)
val_dataset = ReceiptDataset(X_val, y_val)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Model
EPOCHS = 50
model = East().to(device)
model.load_state_dict(torch.load('east1.pt'))
lr = 1e-4
loss_fn = Loss().to(device)
optimizer = Adam(model.parameters(), lr=lr)
best_val_loss = 0.455

train_loss = list()
val_loss = list()

for epoch in range(EPOCHS):
    print('Epoch {}'.format(epoch + 1))

    train_batch_loss = list()
    for X_batch_train, gt_score, gt_geo in tqdm(train_dataloader):
        X_batch_train = X_batch_train.to(device)
Example #13
def main():
    # prepare output directory
    # global epoch
    print('EAST <==> TEST <==> Create Res_file and Img_with_box <==> Begin')
    result_root = os.path.abspath(cfg.res_img_path)
    if not os.path.exists(result_root):
        os.mkdir(result_root)

    print('EAST <==> Prepare <==> Network <==> Begin')
    model = East()
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)
    model  # .cuda() is commented out; this example runs the model on the CPU
    # Load the model
    if os.path.isfile(cfg.checkpoint):
        print(
            "EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(
                cfg.checkpoint))
        checkpoint = torch.load(cfg.checkpoint, map_location='cpu')
        epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print(
            "EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(
                cfg.checkpoint))
    else:
        print('Can not find checkpoint !!!')
        exit(1)
    print()
    print('###############')
    print()
    example = torch.rand(1, 3, 224, 224)

    #traced_script_module = torch.jit.trace(model, example)
    uninplace(model)
    l1 = [['module.conv1', 'module.bn1', 'module.relu1'],
          ['module.conv2', 'module.bn2', 'module.relu2'],
          ['module.conv3', 'module.bn3', 'module.relu3'],
          ['module.conv4', 'module.bn4', 'module.relu4'],
          ['module.conv5', 'module.bn5', 'module.relu5'],
          ['module.conv6', 'module.bn6', 'module.relu6'],
          ['module.conv7', 'module.bn7', 'module.relu7']]
    #s4
    l2 = [['module.s4.0.conv.0', 'module.s4.0.conv.1', 'module.s4.0.conv.2'],
          ['module.s4.0.conv.3', 'module.s4.0.conv.4', 'module.s4.0.conv.5'],
          ['module.s4.0.conv.6', 'module.s4.0.conv.7'],
          ['module.s4.1.conv.0', 'module.s4.1.conv.1', 'module.s4.1.conv.2'],
          ['module.s4.1.conv.3', 'module.s4.1.conv.4', 'module.s4.1.conv.5'],
          ['module.s4.1.conv.6', 'module.s4.1.conv.7'],
          ['module.s4.2.conv.0', 'module.s4.2.conv.1', 'module.s4.2.conv.2'],
          ['module.s4.2.conv.3', 'module.s4.2.conv.4', 'module.s4.2.conv.5'],
          ['module.s4.2.conv.6', 'module.s4.2.conv.7']]
    #s3
    l3 = [['module.s3.0.conv.0', 'module.s3.0.conv.1', 'module.s3.0.conv.2'],
          ['module.s3.0.conv.3', 'module.s3.0.conv.4', 'module.s3.0.conv.5'],
          ['module.s3.0.conv.6', 'module.s3.0.conv.7'],
          ['module.s3.1.conv.0', 'module.s3.1.conv.1', 'module.s3.1.conv.2'],
          ['module.s3.1.conv.3', 'module.s3.1.conv.4', 'module.s3.1.conv.5'],
          ['module.s3.1.conv.6', 'module.s3.1.conv.7'],
          ['module.s3.2.conv.0', 'module.s3.2.conv.1', 'module.s3.2.conv.2'],
          ['module.s3.2.conv.3', 'module.s3.2.conv.4', 'module.s3.2.conv.5'],
          ['module.s3.2.conv.6', 'module.s3.2.conv.7'],
          ['module.s3.3.conv.0', 'module.s3.3.conv.1', 'module.s3.3.conv.2'],
          ['module.s3.3.conv.3', 'module.s3.3.conv.4', 'module.s3.3.conv.5'],
          ['module.s3.3.conv.6', 'module.s3.3.conv.7'],
          ['module.s3.4.conv.0', 'module.s3.4.conv.1', 'module.s3.4.conv.2'],
          ['module.s3.4.conv.3', 'module.s3.4.conv.4', 'module.s3.4.conv.5'],
          ['module.s3.4.conv.6', 'module.s3.4.conv.7'],
          ['module.s3.5.conv.0', 'module.s3.5.conv.1', 'module.s3.5.conv.2'],
          ['module.s3.5.conv.3', 'module.s3.5.conv.4', 'module.s3.5.conv.5'],
          ['module.s3.5.conv.6', 'module.s3.5.conv.7'],
          ['module.s3.6.conv.0', 'module.s3.6.conv.1', 'module.s3.6.conv.2'],
          ['module.s3.6.conv.3', 'module.s3.6.conv.4', 'module.s3.6.conv.5'],
          ['module.s3.6.conv.6', 'module.s3.6.conv.7']]
    #s2
    l4 = [['module.s2.0.conv.0', 'module.s2.0.conv.1', 'module.s2.0.conv.2'],
          ['module.s2.0.conv.3', 'module.s2.0.conv.4', 'module.s2.0.conv.5'],
          ['module.s2.0.conv.6', 'module.s2.0.conv.7'],
          ['module.s2.1.conv.0', 'module.s2.1.conv.1', 'module.s2.1.conv.2'],
          ['module.s2.1.conv.3', 'module.s2.1.conv.4', 'module.s2.1.conv.5'],
          ['module.s2.1.conv.6', 'module.s2.1.conv.7'],
          ['module.s2.2.conv.0', 'module.s2.2.conv.1', 'module.s2.2.conv.2'],
          ['module.s2.2.conv.3', 'module.s2.2.conv.4', 'module.s2.2.conv.5'],
          ['module.s2.2.conv.6', 'module.s2.2.conv.7']]
    #s1
    l5 = [['module.s1.0.0', 'module.s1.0.1', 'module.s1.0.2'],
          ['module.s1.1.conv.0', 'module.s1.1.conv.1', 'module.s1.1.conv.2'],
          ['module.s1.1.conv.3', 'module.s1.1.conv.4'],
          ['module.s1.2.conv.0', 'module.s1.2.conv.1', 'module.s1.2.conv.2'],
          ['module.s1.2.conv.3', 'module.s1.2.conv.4', 'module.s1.2.conv.5'],
          ['module.s1.2.conv.6', 'module.s1.2.conv.7']]
    #( s1 - 1 and 2)
    l6 = [[
        'module.mobilenet.features.0.0', 'module.mobilenet.features.0.1',
        'module.mobilenet.features.0.2'
    ],
          [
              'module.mobilenet.features.1.conv.0',
              'module.mobilenet.features.1.conv.1',
              'module.mobilenet.features.1.conv.2'
          ],
          [
              'module.mobilenet.features.1.conv.3',
              'module.mobilenet.features.1.conv.4'
          ]]

    l7 = [[
        'module.mobilenet.features.2.conv.0',
        'module.mobilenet.features.2.conv.1',
        'module.mobilenet.features.2.conv.2'
    ],
          [
              'module.mobilenet.features.2.conv.3',
              'module.mobilenet.features.2.conv.4',
              'module.mobilenet.features.2.conv.5'
          ],
          [
              'module.mobilenet.features.2.conv.6',
              'module.mobilenet.features.2.conv.7'
          ],
          [
              'module.mobilenet.features.3.conv.0',
              'module.mobilenet.features.3.conv.1',
              'module.mobilenet.features.3.conv.2'
          ],
          [
              'module.mobilenet.features.3.conv.3',
              'module.mobilenet.features.3.conv.4',
              'module.mobilenet.features.3.conv.5'
          ],
          [
              'module.mobilenet.features.3.conv.6',
              'module.mobilenet.features.3.conv.7'
          ],
          [
              'module.mobilenet.features.4.conv.0',
              'module.mobilenet.features.4.conv.1',
              'module.mobilenet.features.4.conv.2'
          ],
          [
              'module.mobilenet.features.4.conv.3',
              'module.mobilenet.features.4.conv.4',
              'module.mobilenet.features.4.conv.5'
          ],
          [
              'module.mobilenet.features.4.conv.6',
              'module.mobilenet.features.4.conv.7'
          ],
          [
              'module.mobilenet.features.5.conv.0',
              'module.mobilenet.features.5.conv.1',
              'module.mobilenet.features.5.conv.2'
          ],
          [
              'module.mobilenet.features.5.conv.3',
              'module.mobilenet.features.5.conv.4',
              'module.mobilenet.features.5.conv.5'
          ],
          [
              'module.mobilenet.features.5.conv.6',
              'module.mobilenet.features.5.conv.7'
          ],
          [
              'module.mobilenet.features.6.conv.0',
              'module.mobilenet.features.6.conv.1',
              'module.mobilenet.features.6.conv.2'
          ],
          [
              'module.mobilenet.features.6.conv.3',
              'module.mobilenet.features.6.conv.4',
              'module.mobilenet.features.6.conv.5'
          ],
          [
              'module.mobilenet.features.6.conv.6',
              'module.mobilenet.features.6.conv.7'
          ],
          [
              'module.mobilenet.features.7.conv.0',
              'module.mobilenet.features.7.conv.1',
              'module.mobilenet.features.7.conv.2'
          ],
          [
              'module.mobilenet.features.7.conv.3',
              'module.mobilenet.features.7.conv.4',
              'module.mobilenet.features.7.conv.5'
          ],
          [
              'module.mobilenet.features.7.conv.6',
              'module.mobilenet.features.7.conv.7'
          ],
          [
              'module.mobilenet.features.8.conv.0',
              'module.mobilenet.features.8.conv.1',
              'module.mobilenet.features.8.conv.2'
          ],
          [
              'module.mobilenet.features.8.conv.3',
              'module.mobilenet.features.8.conv.4',
              'module.mobilenet.features.8.conv.5'
          ],
          [
              'module.mobilenet.features.8.conv.6',
              'module.mobilenet.features.8.conv.7'
          ],
          [
              'module.mobilenet.features.9.conv.0',
              'module.mobilenet.features.9.conv.1',
              'module.mobilenet.features.9.conv.2'
          ],
          [
              'module.mobilenet.features.9.conv.3',
              'module.mobilenet.features.9.conv.4',
              'module.mobilenet.features.9.conv.5'
          ],
          [
              'module.mobilenet.features.9.conv.6',
              'module.mobilenet.features.9.conv.7'
          ],
          [
              'module.mobilenet.features.10.conv.0',
              'module.mobilenet.features.10.conv.1',
              'module.mobilenet.features.10.conv.2'
          ],
          [
              'module.mobilenet.features.10.conv.3',
              'module.mobilenet.features.10.conv.4',
              'module.mobilenet.features.10.conv.5'
          ],
          [
              'module.mobilenet.features.10.conv.6',
              'module.mobilenet.features.10.conv.7'
          ],
          [
              'module.mobilenet.features.11.conv.0',
              'module.mobilenet.features.11.conv.1',
              'module.mobilenet.features.11.conv.2'
          ],
          [
              'module.mobilenet.features.11.conv.3',
              'module.mobilenet.features.11.conv.4',
              'module.mobilenet.features.11.conv.5'
          ],
          [
              'module.mobilenet.features.11.conv.6',
              'module.mobilenet.features.11.conv.7'
          ],
          [
              'module.mobilenet.features.12.conv.0',
              'module.mobilenet.features.12.conv.1',
              'module.mobilenet.features.12.conv.2'
          ],
          [
              'module.mobilenet.features.12.conv.3',
              'module.mobilenet.features.12.conv.4',
              'module.mobilenet.features.12.conv.5'
          ],
          [
              'module.mobilenet.features.12.conv.6',
              'module.mobilenet.features.12.conv.7'
          ],
          [
              'module.mobilenet.features.13.conv.0',
              'module.mobilenet.features.13.conv.1',
              'module.mobilenet.features.13.conv.2'
          ],
          [
              'module.mobilenet.features.13.conv.3',
              'module.mobilenet.features.13.conv.4',
              'module.mobilenet.features.13.conv.5'
          ],
          [
              'module.mobilenet.features.13.conv.6',
              'module.mobilenet.features.13.conv.7'
          ],
          [
              'module.mobilenet.features.14.conv.0',
              'module.mobilenet.features.14.conv.1',
              'module.mobilenet.features.14.conv.2'
          ],
          [
              'module.mobilenet.features.14.conv.3',
              'module.mobilenet.features.14.conv.4',
              'module.mobilenet.features.14.conv.5'
          ],
          [
              'module.mobilenet.features.14.conv.6',
              'module.mobilenet.features.14.conv.7'
          ],
          [
              'module.mobilenet.features.15.conv.0',
              'module.mobilenet.features.15.conv.1',
              'module.mobilenet.features.15.conv.2'
          ],
          [
              'module.mobilenet.features.15.conv.3',
              'module.mobilenet.features.15.conv.4',
              'module.mobilenet.features.15.conv.5'
          ],
          [
              'module.mobilenet.features.15.conv.6',
              'module.mobilenet.features.15.conv.7'
          ],
          [
              'module.mobilenet.features.16.conv.0',
              'module.mobilenet.features.16.conv.1',
              'module.mobilenet.features.16.conv.2'
          ],
          [
              'module.mobilenet.features.16.conv.3',
              'module.mobilenet.features.16.conv.4',
              'module.mobilenet.features.16.conv.5'
          ],
          [
              'module.mobilenet.features.16.conv.6',
              'module.mobilenet.features.16.conv.7'
          ]]

    #modules_to_fuse=l1+l2+l3+l4+l5
    #print(model)
    print('Original Size:')
    print_size_of_model(model)
    print()
    fused_model = torch.quantization.fuse_modules(model, l7)
    fused_model = torch.quantization.fuse_modules(fused_model, l6)
    fused_model = torch.quantization.fuse_modules(fused_model,
                                                  l1 + l2 + l3 + l4 + l5)

    print('Fused model Size:')
    print_size_of_model(fused_model)
    print()
    #print(fused_model)

    #fused_model.qconfig = torch.quantization.QConfig(activation=torch.quantization.default_histogram_observer,weight=torch.quantization.default_per_channel_weight_observer)
    fused_model.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(fused_model, inplace=True)

    from data_loader import custom_dset
    from torchvision import transforms
    from torch.utils.data import DataLoader
    trainset = custom_dset(transform=transforms.ToTensor())
    train_loader = DataLoader(trainset,
                              batch_size=cfg.train_batch_size_per_gpu *
                              cfg.gpu,
                              shuffle=True,
                              num_workers=0)
    for i, (img, img_path, score_map, geo_map,
            training_mask) in enumerate(train_loader):
        f_score, f_geometry = fused_model(img)

    quantized = torch.quantization.convert(fused_model, inplace=False)
    print('Quantized model Size:')
    print_size_of_model(quantized)

    print('Done')
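`uninplace(model)` is called before module fusion but is not defined in the snippet; presumably it turns off in-place activations so that fusion and quantization see separate modules. A sketch under that assumption:

import torch.nn as nn

def uninplace(model):
    """Sketch: disable in-place ReLU/ReLU6 so fuse_modules/quantization can rewrite them."""
    for module in model.modules():
        if isinstance(module, (nn.ReLU, nn.ReLU6)):
            module.inplace = False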
Example #14
def main():
    # prepare output directory
    # global epoch
    print('EAST <==> TEST <==> Create Res_file and Img_with_box <==> Begin')
    result_root = os.path.abspath(cfg.res_img_path)
    if not os.path.exists(result_root):
        os.mkdir(result_root)

    print('EAST <==> Prepare <==> Network <==> Begin')
    model = East()
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)
    model  # .cuda() is commented out; this example runs the model on the CPU
    # Load the model
    if os.path.isfile(cfg.checkpoint):
        print(
            "EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(
                cfg.checkpoint))
        checkpoint = torch.load(cfg.checkpoint, map_location='cpu')
        epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print(
            "EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(
                cfg.checkpoint))
    else:
        print('Can not find checkpoint !!!')
        exit(1)
    print()
    print('###############')
    print()
    example = torch.rand(1, 3, 224, 224)

    #traced_script_module = torch.jit.trace(model, example)
    uninplace(model)
    l1 = [['module.conv1', 'module.bn1', 'module.relu1'],
          ['module.conv2', 'module.bn2', 'module.relu2'],
          ['module.conv3', 'module.bn3', 'module.relu3'],
          ['module.conv4', 'module.bn4', 'module.relu4'],
          ['module.conv5', 'module.bn5', 'module.relu5'],
          ['module.conv6', 'module.bn6', 'module.relu6'],
          ['module.conv7', 'module.bn7', 'module.relu7']]
    #s4
    l2 = [['module.s4.0.conv.0', 'module.s4.0.conv.1', 'module.s4.0.conv.2'],
          ['module.s4.0.conv.3', 'module.s4.0.conv.4', 'module.s4.0.conv.5'],
          ['module.s4.0.conv.6', 'module.s4.0.conv.7'],
          ['module.s4.1.conv.0', 'module.s4.1.conv.1', 'module.s4.1.conv.2'],
          ['module.s4.1.conv.3', 'module.s4.1.conv.4', 'module.s4.1.conv.5'],
          ['module.s4.1.conv.6', 'module.s4.1.conv.7'],
          ['module.s4.2.conv.0', 'module.s4.2.conv.1', 'module.s4.2.conv.2'],
          ['module.s4.2.conv.3', 'module.s4.2.conv.4', 'module.s4.2.conv.5'],
          ['module.s4.2.conv.6', 'module.s4.2.conv.7']]
    #s3
    l3 = [['module.s3.0.conv.0', 'module.s3.0.conv.1', 'module.s3.0.conv.2'],
          ['module.s3.0.conv.3', 'module.s3.0.conv.4', 'module.s3.0.conv.5'],
          ['module.s3.0.conv.6', 'module.s3.0.conv.7'],
          ['module.s3.1.conv.0', 'module.s3.1.conv.1', 'module.s3.1.conv.2'],
          ['module.s3.1.conv.3', 'module.s3.1.conv.4', 'module.s3.1.conv.5'],
          ['module.s3.1.conv.6', 'module.s3.1.conv.7'],
          ['module.s3.2.conv.0', 'module.s3.2.conv.1', 'module.s3.2.conv.2'],
          ['module.s3.2.conv.3', 'module.s3.2.conv.4', 'module.s3.2.conv.5'],
          ['module.s3.2.conv.6', 'module.s3.2.conv.7'],
          ['module.s3.3.conv.0', 'module.s3.3.conv.1', 'module.s3.3.conv.2'],
          ['module.s3.3.conv.3', 'module.s3.3.conv.4', 'module.s3.3.conv.5'],
          ['module.s3.3.conv.6', 'module.s3.3.conv.7'],
          ['module.s3.4.conv.0', 'module.s3.4.conv.1', 'module.s3.4.conv.2'],
          ['module.s3.4.conv.3', 'module.s3.4.conv.4', 'module.s3.4.conv.5'],
          ['module.s3.4.conv.6', 'module.s3.4.conv.7'],
          ['module.s3.5.conv.0', 'module.s3.5.conv.1', 'module.s3.5.conv.2'],
          ['module.s3.5.conv.3', 'module.s3.5.conv.4', 'module.s3.5.conv.5'],
          ['module.s3.5.conv.6', 'module.s3.5.conv.7'],
          ['module.s3.6.conv.0', 'module.s3.6.conv.1', 'module.s3.6.conv.2'],
          ['module.s3.6.conv.3', 'module.s3.6.conv.4', 'module.s3.6.conv.5'],
          ['module.s3.6.conv.6', 'module.s3.6.conv.7']]
    #s2
    l4 = [['module.s2.0.conv.0', 'module.s2.0.conv.1', 'module.s2.0.conv.2'],
          ['module.s2.0.conv.3', 'module.s2.0.conv.4', 'module.s2.0.conv.5'],
          ['module.s2.0.conv.6', 'module.s2.0.conv.7'],
          ['module.s2.1.conv.0', 'module.s2.1.conv.1', 'module.s2.1.conv.2'],
          ['module.s2.1.conv.3', 'module.s2.1.conv.4', 'module.s2.1.conv.5'],
          ['module.s2.1.conv.6', 'module.s2.1.conv.7'],
          ['module.s2.2.conv.0', 'module.s2.2.conv.1', 'module.s2.2.conv.2'],
          ['module.s2.2.conv.3', 'module.s2.2.conv.4', 'module.s2.2.conv.5'],
          ['module.s2.2.conv.6', 'module.s2.2.conv.7']]
    #s1
    l5 = [['module.s1.0.0', 'module.s1.0.1', 'module.s1.0.2'],
          ['module.s1.1.conv.0', 'module.s1.1.conv.1', 'module.s1.1.conv.2'],
          ['module.s1.1.conv.3', 'module.s1.1.conv.4'],
          ['module.s1.2.conv.0', 'module.s1.2.conv.1', 'module.s1.2.conv.2'],
          ['module.s1.2.conv.3', 'module.s1.2.conv.4', 'module.s1.2.conv.5'],
          ['module.s1.2.conv.6', 'module.s1.2.conv.7']]
    #( s1 - 1 and 2)
    l6 = [[
        'module.mobilenet.features.0.0', 'module.mobilenet.features.0.1',
        'module.mobilenet.features.0.2'
    ],
          [
              'module.mobilenet.features.1.conv.0',
              'module.mobilenet.features.1.conv.1',
              'module.mobilenet.features.1.conv.2'
          ],
          [
              'module.mobilenet.features.1.conv.3',
              'module.mobilenet.features.1.conv.4'
          ]]

    l7 = [[
        'module.mobilenet.features.2.conv.0',
        'module.mobilenet.features.2.conv.1',
        'module.mobilenet.features.2.conv.2'
    ],
          [
              'module.mobilenet.features.2.conv.3',
              'module.mobilenet.features.2.conv.4',
              'module.mobilenet.features.2.conv.5'
          ],
          [
              'module.mobilenet.features.2.conv.6',
              'module.mobilenet.features.2.conv.7'
          ],
          [
              'module.mobilenet.features.3.conv.0',
              'module.mobilenet.features.3.conv.1',
              'module.mobilenet.features.3.conv.2'
          ],
          [
              'module.mobilenet.features.3.conv.3',
              'module.mobilenet.features.3.conv.4',
              'module.mobilenet.features.3.conv.5'
          ],
          [
              'module.mobilenet.features.3.conv.6',
              'module.mobilenet.features.3.conv.7'
          ],
          [
              'module.mobilenet.features.4.conv.0',
              'module.mobilenet.features.4.conv.1',
              'module.mobilenet.features.4.conv.2'
          ],
          [
              'module.mobilenet.features.4.conv.3',
              'module.mobilenet.features.4.conv.4',
              'module.mobilenet.features.4.conv.5'
          ],
          [
              'module.mobilenet.features.4.conv.6',
              'module.mobilenet.features.4.conv.7'
          ],
          [
              'module.mobilenet.features.5.conv.0',
              'module.mobilenet.features.5.conv.1',
              'module.mobilenet.features.5.conv.2'
          ],
          [
              'module.mobilenet.features.5.conv.3',
              'module.mobilenet.features.5.conv.4',
              'module.mobilenet.features.5.conv.5'
          ],
          [
              'module.mobilenet.features.5.conv.6',
              'module.mobilenet.features.5.conv.7'
          ],
          [
              'module.mobilenet.features.6.conv.0',
              'module.mobilenet.features.6.conv.1',
              'module.mobilenet.features.6.conv.2'
          ],
          [
              'module.mobilenet.features.6.conv.3',
              'module.mobilenet.features.6.conv.4',
              'module.mobilenet.features.6.conv.5'
          ],
          [
              'module.mobilenet.features.6.conv.6',
              'module.mobilenet.features.6.conv.7'
          ],
          [
              'module.mobilenet.features.7.conv.0',
              'module.mobilenet.features.7.conv.1',
              'module.mobilenet.features.7.conv.2'
          ],
          [
              'module.mobilenet.features.7.conv.3',
              'module.mobilenet.features.7.conv.4',
              'module.mobilenet.features.7.conv.5'
          ],
          [
              'module.mobilenet.features.7.conv.6',
              'module.mobilenet.features.7.conv.7'
          ],
          [
              'module.mobilenet.features.8.conv.0',
              'module.mobilenet.features.8.conv.1',
              'module.mobilenet.features.8.conv.2'
          ],
          [
              'module.mobilenet.features.8.conv.3',
              'module.mobilenet.features.8.conv.4',
              'module.mobilenet.features.8.conv.5'
          ],
          [
              'module.mobilenet.features.8.conv.6',
              'module.mobilenet.features.8.conv.7'
          ],
          [
              'module.mobilenet.features.9.conv.0',
              'module.mobilenet.features.9.conv.1',
              'module.mobilenet.features.9.conv.2'
          ],
          [
              'module.mobilenet.features.9.conv.3',
              'module.mobilenet.features.9.conv.4',
              'module.mobilenet.features.9.conv.5'
          ],
          [
              'module.mobilenet.features.9.conv.6',
              'module.mobilenet.features.9.conv.7'
          ],
          [
              'module.mobilenet.features.10.conv.0',
              'module.mobilenet.features.10.conv.1',
              'module.mobilenet.features.10.conv.2'
          ],
          [
              'module.mobilenet.features.10.conv.3',
              'module.mobilenet.features.10.conv.4',
              'module.mobilenet.features.10.conv.5'
          ],
          [
              'module.mobilenet.features.10.conv.6',
              'module.mobilenet.features.10.conv.7'
          ],
          [
              'module.mobilenet.features.11.conv.0',
              'module.mobilenet.features.11.conv.1',
              'module.mobilenet.features.11.conv.2'
          ],
          [
              'module.mobilenet.features.11.conv.3',
              'module.mobilenet.features.11.conv.4',
              'module.mobilenet.features.11.conv.5'
          ],
          [
              'module.mobilenet.features.11.conv.6',
              'module.mobilenet.features.11.conv.7'
          ],
          [
              'module.mobilenet.features.12.conv.0',
              'module.mobilenet.features.12.conv.1',
              'module.mobilenet.features.12.conv.2'
          ],
          [
              'module.mobilenet.features.12.conv.3',
              'module.mobilenet.features.12.conv.4',
              'module.mobilenet.features.12.conv.5'
          ],
          [
              'module.mobilenet.features.12.conv.6',
              'module.mobilenet.features.12.conv.7'
          ],
          [
              'module.mobilenet.features.13.conv.0',
              'module.mobilenet.features.13.conv.1',
              'module.mobilenet.features.13.conv.2'
          ],
          [
              'module.mobilenet.features.13.conv.3',
              'module.mobilenet.features.13.conv.4',
              'module.mobilenet.features.13.conv.5'
          ],
          [
              'module.mobilenet.features.13.conv.6',
              'module.mobilenet.features.13.conv.7'
          ],
          [
              'module.mobilenet.features.14.conv.0',
              'module.mobilenet.features.14.conv.1',
              'module.mobilenet.features.14.conv.2'
          ],
          [
              'module.mobilenet.features.14.conv.3',
              'module.mobilenet.features.14.conv.4',
              'module.mobilenet.features.14.conv.5'
          ],
          [
              'module.mobilenet.features.14.conv.6',
              'module.mobilenet.features.14.conv.7'
          ],
          [
              'module.mobilenet.features.15.conv.0',
              'module.mobilenet.features.15.conv.1',
              'module.mobilenet.features.15.conv.2'
          ],
          [
              'module.mobilenet.features.15.conv.3',
              'module.mobilenet.features.15.conv.4',
              'module.mobilenet.features.15.conv.5'
          ],
          [
              'module.mobilenet.features.15.conv.6',
              'module.mobilenet.features.15.conv.7'
          ],
          [
              'module.mobilenet.features.16.conv.0',
              'module.mobilenet.features.16.conv.1',
              'module.mobilenet.features.16.conv.2'
          ],
          [
              'module.mobilenet.features.16.conv.3',
              'module.mobilenet.features.16.conv.4',
              'module.mobilenet.features.16.conv.5'
          ],
          [
              'module.mobilenet.features.16.conv.6',
              'module.mobilenet.features.16.conv.7'
          ]]

    #modules_to_fuse=l1+l2+l3+l4+l5
    #print(model)
    print('Original Size:')
    print_size_of_model(model)
    print()
    fused_model = torch.quantization.fuse_modules(model, l7)
    fused_model = torch.quantization.fuse_modules(fused_model, l6)
    fused_model = torch.quantization.fuse_modules(fused_model,
                                                  l1 + l2 + l3 + l4 + l5)

    print('Fused model Size:')
    print_size_of_model(fused_model)
    print()
    #print(fused_model)

    #fused_model.qconfig = torch.quantization.QConfig(activation=torch.quantization.default_histogram_observer,weight=torch.quantization.default_per_channel_weight_observer)
    fused_model.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(fused_model, inplace=True)

    from data_loader import custom_dset
    from torchvision import transforms
    from torch.utils.data import DataLoader
    trainset = custom_dset(transform=transforms.ToTensor())
    device = torch.device('cpu')
    train_loader = DataLoader(trainset,
                              batch_size=cfg.train_batch_size_per_gpu *
                              cfg.gpu,
                              shuffle=True,
                              num_workers=0)
    for i, (img, img_path, score_map, geo_map,
            training_mask) in enumerate(train_loader):
        img, score_map, geo_map, training_mask = img.to(device), score_map.to(
            device), geo_map.to(device), training_mask.to(device)
        f_score, f_geometry = fused_model(img)

    quantized = torch.quantization.convert(fused_model, inplace=False)
    print('Quantized model Size:')
    print_size_of_model(quantized)

    num_train_batches = 20
    print('***QAT***')
    print()
    criterion = loss.LossFunc()
    pre_params = list(map(id, model.module.mobilenet.parameters()))
    post_params = filter(lambda p: id(p) not in pre_params,
                         model.module.parameters())

    optimizer = torch.optim.Adam([{
        'params': model.module.mobilenet.parameters(),
        'lr': cfg.pre_lr
    }, {
        'params': post_params,
        'lr': cfg.lr
    }])
    fused_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    # Train and check accuracy after each epoch
    for nepoch in range(8):
        train_one_epoch(fused_model, criterion, optimizer, train_loader,
                        torch.device('cpu'), num_train_batches)
        if nepoch > 3:
            # Freeze quantizer parameters
            fused_model.apply(torch.quantization.disable_observer)
        if nepoch > 2:
            # Freeze batch norm mean and variance estimates
            fused_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
    quantized_model = torch.quantization.convert(fused_model.eval(),
                                                 inplace=False)

    print('QAT model Size:')
    print_size_of_model(quantized_model)
    print('Done')
    print(quantized)
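`train_one_epoch`, used for the QAT fine-tuning above, is not shown; a minimal sketch consistent with its call signature and with how the loader's batches are unpacked elsewhere in this example. The argument order of `criterion` is an assumption:

def train_one_epoch(model, criterion, optimizer, loader, device, num_batches):
    """Sketch: run a bounded number of training batches on `device`."""
    model.train()
    for i, (img, img_path, score_map, geo_map, training_mask) in enumerate(loader):
        if i >= num_batches:
            break
        img = img.to(device)
        score_map, geo_map, training_mask = (score_map.to(device),
                                             geo_map.to(device),
                                             training_mask.to(device))
        f_score, f_geometry = model(img)
        # Assumed loss signature: (gt_score, pred_score, gt_geo, pred_geo, mask)
        loss_value = criterion(score_map, f_score, geo_map, f_geometry, training_mask)
        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()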
Example #15
def main():
    warnings.simplefilter('ignore', np.RankWarning)
    #Model
    video_root_path = os.path.abspath('./dataset/train/')
    video_name_list = sorted(
        [p for p in os.listdir(video_root_path) if p.split('_')[0] == 'Video'])
    #print('video_name_list', video_name_list)
    print('EAST <==> Prepare <==> Network <==> Begin')
    model = East()
    AGD_model = AGD()
    model = nn.DataParallel(model, device_ids=cfg.gpu_ids)
    #AGD_model = nn.DataParallel(AGD_model, device_ids=cfg.gpu_ids)
    model = model.cuda()
    AGD_model = AGD_model.cuda()
    init_weights(model, init_type=cfg.init_type)
    cudnn.benchmark = True

    criterion1 = LossFunc()
    #
    criterion2 = Ass_loss()

    optimizer1 = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    optimizer2 = torch.optim.Adam(AGD_model.parameters(), lr=cfg.lr)
    scheduler = lr_scheduler.StepLR(optimizer1, step_size=10000, gamma=0.94)

    # init or resume
    if cfg.resume and os.path.isfile(cfg.checkpoint):
        weightpath = os.path.abspath(cfg.checkpoint)
        print(
            "EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(
                weightpath))
        checkpoint = torch.load(weightpath)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        #AGD_model.load_state_dict(checkpoint['model2.state_dict'])
        optimizer1.load_state_dict(checkpoint['optimizer'])
        #optimizer2.load_state_dict(checkpoint['optimizer2'])
        print(
            "EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(
                weightpath))
    else:
        start_epoch = 0
    print('EAST <==> Prepare <==> Network <==> Done')

    for epoch in range(start_epoch + 1, cfg.max_epochs):
        for video_name in video_name_list:
            print(
                'EAST <==> epoch:{} <==> Prepare <==> DataLoader <==>{} Begin'.
                format(epoch, video_name))
            trainset = custom_dset(os.path.join(video_root_path, video_name))
            #sampler = sampler_for_video_clip(len(trainset))
            train_loader = DataLoader(trainset,
                                      batch_size=cfg.train_batch_size_per_gpu *
                                      cfg.gpu,
                                      shuffle=False,
                                      collate_fn=collate_fn,
                                      num_workers=cfg.num_workers,
                                      drop_last=True)
            print('EAST <==> Prepare <==> Batch_size:{} <==> Begin'.format(
                cfg.train_batch_size_per_gpu * cfg.gpu))
            print(
                'EAST <==> epoch:{} <==> Prepare <==> DataLoader <==>{} Done'.
                format(epoch, video_name))

            train(train_loader, model, AGD_model, criterion1, criterion2,
                  scheduler, optimizer1, optimizer2, epoch)
            '''
            for i, (img, score_map, geo_map, training_mask, coord_ids) in enumerate(train_loader):
                print('i{} img.shape:{} geo_map.shape{} training_mask.shape{} coord_ids.len{}'.format(i, score_map.shape, geo_map.shape, training_mask.shape, len(coord_ids)))
            '''

        if epoch % cfg.eval_iteration == 0:
            state = {
                'epoch': epoch,
                'model1.state_dict': model.state_dict(),
                'model2.state_dict': AGD_model.state_dict(),
                'optimizer1': optimizer1.state_dict(),
                'optimizer2': optimizer2.state_dict()
            }
            save_checkpoint(state, epoch)