Example #1
def train(train_img_path, train_gt_path, pths_path, batch_size, lr, num_workers, epoch_iter, interval):
    file_num = len(os.listdir(train_img_path))
    trainset = custom_dataset(train_img_path, train_gt_path)
    train_loader = data.DataLoader(trainset, batch_size=batch_size, \
                                   shuffle=True, num_workers=num_workers, drop_last=True)

    criterion = Loss()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = EAST()
    data_parallel = False
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        data_parallel = True
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
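    # decay the learning rate by 10x at 1/3 and 2/3 of the total epochs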
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[epoch_iter // 3, epoch_iter * 2 // 3], gamma=0.1)

    train_loss = []

    for epoch in range(epoch_iter):
        model.train()
        epoch_loss = 0
        epoch_time = time.time()
        for i, (img, gt_score, gt_geo, ignored_map) in enumerate(train_loader):
            start_time = time.time()
            img, gt_score, gt_geo, ignored_map = img.to(device), gt_score.to(device), gt_geo.to(device), ignored_map.to(
                device)
            pred_score, pred_geo = model(img)
            loss = criterion(gt_score, pred_score, gt_geo, pred_geo, ignored_map)

            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print('Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'.format( \
                epoch + 1, epoch_iter, i + 1, int(file_num / batch_size), time.time() - start_time, loss.item()))

        scheduler.step()  # step once per epoch, after the optimizer updates
        epoch_loss_mean = epoch_loss / int(file_num / batch_size)
        train_loss.append(epoch_loss_mean)
        print('epoch_loss is {:.8f}, epoch_time is {:.8f}'.format(epoch_loss_mean,
                                                                  time.time() - epoch_time))
        print(time.asctime(time.localtime(time.time())))
        print('=' * 50)
        if (epoch + 1) % interval == 0:
            savePath = os.path.join(pths_path,
                                    'lossImg{}.jpg'.format(epoch + 1))
            drawLoss(train_loss, savePath)
            lossPath = os.path.join(pths_path,
                                    'loss{}.npy'.format(epoch + 1))
            train_loss_np = np.array(train_loss, dtype=float)
            np.save(lossPath, train_loss_np)
            state_dict = model.module.state_dict() if data_parallel else model.state_dict()
            torch.save(state_dict, os.path.join(pths_path, 'model_epoch_{}.pth'.format(epoch + 1)))
            lr_state = scheduler.state_dict()
            torch.save(lr_state, os.path.join(pths_path, 'scheduler_epoch_{}.pth'.format(epoch + 1)))
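
A minimal invocation sketch for this train() (all paths and hyperparameter values below are illustrative, not taken from the original source):

if __name__ == '__main__':
    train(train_img_path='../data/train/img/',
          train_gt_path='../data/train/gt/',
          pths_path='./pths/',
          batch_size=24,
          lr=1e-3,
          num_workers=4,
          epoch_iter=600,
          interval=5)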
Example #2
def eval_model(model_name, test_img_path, submit_path, save_flag=True):
    if os.path.exists(submit_path):
        shutil.rmtree(submit_path)
    os.mkdir(submit_path)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = EAST(pretrained=False).to(device)
    model.load_state_dict(torch.load(model_name))
    model.eval()

    start_time = time.time()
    detect_dataset(model, device, test_img_path, submit_path)
    os.chdir(submit_path)
    res = subprocess.getoutput('zip -q submit.zip *.txt')
    res = subprocess.getoutput('mv submit.zip ../')
    os.chdir('../')
    res = subprocess.getoutput(
        'python ./evaluate/script.py -g=./evaluate/gt.zip -s=./submit.zip')
    print(res)
    os.remove('./submit.zip')
    print('eval time is {}'.format(time.time() - start_time))

    if not save_flag:
        shutil.rmtree(submit_path)


def detect_dataset(model, device, test_img_path, submit_path):
    # detect text boxes on every test image and write one ICDAR-style
    # result file per image ('res_*.txt', one comma-separated box per line)
    img_files = sorted(
        os.path.join(test_img_path, f) for f in os.listdir(test_img_path))
    for img_file in img_files:
        # detect() signature mirrors the __main__ block below
        boxes = detect(img_file, Image.open(img_file), model, device)
        seq = []
        if boxes is not None:
            seq.extend([
                ','.join([str(int(b)) for b in box[:-1]]) + '\n'
                for box in boxes
            ])
        with open(
                os.path.join(
                    submit_path, 'res_' +
                    os.path.basename(img_file).replace('.jpg', '.txt')),
                'w') as f:
            f.writelines(seq)
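
A hedged usage sketch for eval_model (paths hypothetical; assumes the ./evaluate/ scripts referenced above exist and that zip is on PATH):

eval_model(model_name='./pths/model_epoch_600.pth',
           test_img_path='../data/test/img/',
           submit_path='./submit')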


if __name__ == '__main__':
    img_path = '../data/train/img/tr_img_00017.jpg'
    model_path = './pths/max.pth'
    res_img = './res.bmp'
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = EAST(pretrained=False).to(device)
    model.load_state_dict(torch.load(model_path))
    model.eval()
    img = Image.open(img_path)
    print(img.size)
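    # scale so the longer side becomes 800 px before running detection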
    r = min(800 / img.width, 800 / img.height)
    print(r)
    img = img.resize((int(img.width * r), int(img.height * r)), Image.BILINEAR)
    boxes = detect(img_path, img, model, device)
    plot_img = plot_boxes(img, boxes)
    plot_img.save(res_img)
Example #3
def train(train_img_path, train_gt_path, pths_path, batch_size, lr,
          num_workers, epoch_iter, interval, checkpoint, eval_interval,
          test_img_path, submit_path):
    file_num = len(os.listdir(train_img_path))
    trainset = custom_dataset(train_img_path, train_gt_path)
    train_loader = data.DataLoader(trainset, batch_size=batch_size,
                                   shuffle=True, num_workers=num_workers,
                                   drop_last=True)

    criterion = Loss()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = EAST(pretrained=False)
    if checkpoint:
        model.load_state_dict(torch.load(checkpoint))
    data_parallel = False
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        # model = DataParallelModel(model)
        data_parallel = True
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9,weight_decay=0)
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=[epoch_iter // 2],
                                         gamma=0.1)
    whole_number = epoch_iter * (len(trainset) // batch_size)
    print("epoch size:%d" % (epoch_iter))
    print("batch size:%d" % (batch_size))
    print("data number:%d" % (len(trainset)))
    all_loss = []
    current_i = 0
    for epoch in range(epoch_iter):

        model.train()

        epoch_loss = 0
        epoch_time = time.time()
        for i, (img, gt_score, gt_geo, ignored_map,
                _) in enumerate(train_loader):
            current_i += 1
            start_time = time.time()
            img, gt_score, gt_geo, ignored_map = img.to(device), gt_score.to(
                device), gt_geo.to(device), ignored_map.to(device)
            pred_score, pred_geo = model(img)
            loss = criterion(gt_score, pred_score, gt_geo, pred_geo,
                             ignored_map)

            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_now = scheduler.get_last_lr()
            progress_bar(40, loss.item(), current_i, whole_number, lr_now[0])
        scheduler.step()
        print('epoch_loss is {:.8f}, epoch_time is {:.8f}'.format(
            epoch_loss / int(file_num / batch_size),
            time.time() - epoch_time))
        all_loss.append(epoch_loss / int(file_num / batch_size))
        print(time.asctime(time.localtime(time.time())))
        plt.plot(all_loss)
        plt.savefig('loss_landscape.png')
        plt.close()
        print('=' * 50)
        if (epoch + 1) % interval == 0:
            state_dict = model.module.state_dict(
            ) if data_parallel else model.state_dict()
            torch.save(
                state_dict,
                os.path.join(pths_path,
                             'model_epoch_{}.pth'.format(epoch + 1)))
            with open(os.path.join(pths_path, 'loss.pkl'), 'wb') as output:
                pkl.dump(all_loss, output)
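
progress_bar is not defined in this snippet; a plausible stand-in matching the call progress_bar(40, loss.item(), current_i, whole_number, lr_now[0]) above (hypothetical implementation):

import sys

def progress_bar(width, loss, current, total, lr):
    # render a fixed-width console bar with the running loss and current lr
    filled = int(width * current / max(total, 1))
    bar = '#' * filled + '-' * (width - filled)
    sys.stdout.write('\r[{}] {}/{} loss={:.6f} lr={:.6f}'.format(
        bar, current, total, loss, lr))
    sys.stdout.flush()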
Example #5
        seq = []
        if boxes is not None:
            seq.extend([
                ','.join([str(int(b)) for b in box[:-1]]) + '\n'
                for box in boxes
            ])
        with open(
                os.path.join(
                    submit_path, 'res_' +
                    os.path.basename(img_file).replace('.jpg', '.txt')),
                'w') as f:
            f.writelines(seq)


if __name__ == '__main__':
    img_path = '/youedata/dengjinhong/zjw/dataset/icdar2013/Challenge2_Test_Task12_Images/img_1.jpg'
    model_path = '/youedata/dengjinhong/zjw/code/EAST_Tansfer/checkpoint/baseline_k1_rlop_2/model_epoch_best.pth'
    res_img = './res.bmp'
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = EAST().to(device)
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    img = Image.open(img_path)

    boxes = detect(img, model, device)
    seq = []

    if boxes is not None:
        seq.extend([
            ','.join([str(int(b))
                      for b in [box[0], box[1], box[4], box[5]]]) + '\n'
            for box in boxes
        ])
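
The snippet ends before seq is written out; a plausible continuation in the ICDAR2013 submission format (output file name hypothetical):

    # box[0], box[1] and box[4], box[5] are the first and third vertices of
    # the 8-value quad (top-left and bottom-right under the usual clockwise
    # vertex order), i.e. the axis-aligned rectangle ICDAR2013 expects
    with open('./res_img_1.txt', 'w') as f:
        f.writelines(seq)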
Example #6
def train(source_img_path,
          source_gt_path,
          target_img_path,
          target_gt_path,
          valid_img_path,
          valid_gt_path,
          pths_path,
          batch_size,
          lr,
          num_workers,
          epoch_iter,
          interval,
          pretrain_model_path=None,
          scheduler_path=None,
          current_epoch_num=0):

    if not os.path.exists(pths_path):
        os.mkdir(pths_path)

    # source_train_set = IC13_dataset(source_img_path, source_gt_path)
    source_train_set = custom_dataset(source_img_path, source_gt_path)
    target_train_set = custom_dataset(target_img_path, target_gt_path)
    valid_train_set = valid_dataset(valid_img_path, valid_gt_path)

    source_train_loader = data.DataLoader(source_train_set,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=num_workers,
                                          drop_last=True)
    target_train_loader = data.DataLoader(target_train_set,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=num_workers,
                                          drop_last=True)
    valid_loader = data.DataLoader(valid_train_set,
                                   batch_size=batch_size,
                                   shuffle=False,
                                   num_workers=num_workers,
                                   drop_last=False)

    criterion = Loss().to(device)
    loss_domain = torch.nn.CrossEntropyLoss()

    model = EAST()
    if pretrain_model_path is not None:
        model.load_state_dict(torch.load(pretrain_model_path))
    data_parallel = False
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        data_parallel = True

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[epoch_iter // 3, epoch_iter * 2 // 3],
        gamma=0.1)
    if scheduler_path is not None:
        scheduler.load_state_dict(torch.load(scheduler_path))
    best_loss = 1000
    best_model_wts = copy.deepcopy(model.state_dict())
    best_num = 0

    train_loss = []
    valid_loss = []

    for epoch in range(current_epoch_num, epoch_iter):
        model.train()
        target_train_iter = iter(target_train_loader)

        epoch_loss = 0
        epoch_time = time.time()
        for i, (s_img, s_gt_score, s_gt_geo,
                s_ignored_map) in enumerate(source_train_loader):
            start_time = time.time()

            try:
                t_img, t_gt_score, t_gt_geo, t_ignored_map = next(
                    target_train_iter)
            except StopIteration:
                target_train_iter = iter(target_train_loader)
                t_img, t_gt_score, t_gt_geo, t_ignored_map = next(
                    target_train_iter)

            s_img, s_gt_score, s_gt_geo, s_ignored_map = s_img.to(
                device), s_gt_score.to(device), s_gt_geo.to(
                    device), s_ignored_map.to(device)
            t_img = t_img.to(device)

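            # the boolean flag appears to select the forward path: False runs
            # the full detection head (plus domain logits) on source images,
            # True returns only domain logits for target images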
            pred_score, pred_geo, pred_cls = model(s_img, False)

            # source domain label = 0
            domain_s = torch.zeros(pred_cls.size(0),
                                   dtype=torch.long,
                                   device=device)
            loss_domain_s = loss_domain(pred_cls, domain_s)

            target_cls = model(t_img, True)
            # target domain label = 1
            domain_t = torch.ones(target_cls.size(0),
                                  dtype=torch.long,
                                  device=device)
            loss_domain_t = loss_domain(target_cls, domain_t)

            loss = criterion(s_gt_score, pred_score, s_gt_geo, pred_geo,
                             s_ignored_map) + loss_domain_s + loss_domain_t

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            # print('Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'.format( \
            #     epoch + 1, epoch_iter, i + 1, int(len(source_train_loader)), time.time() - start_time, loss.item()))

        epoch_loss_mean = epoch_loss / len(source_train_loader)
        train_loss.append(epoch_loss_mean)
        print('Epoch[{}], Train, epoch_loss is {:.8f}, epoch_time is {:.8f}'.
              format(epoch, epoch_loss_mean,
                     time.time() - epoch_time))

        val_epoch_loss = eval(model, valid_loader, criterion, epoch)
        val_loss_mean = val_epoch_loss / len(valid_loader)
        valid_loss.append(val_loss_mean)

        print(time.asctime(time.localtime(time.time())))
        print('=' * 50)

        if val_loss_mean < best_loss:
            best_num = epoch + 1
            best_loss = val_loss_mean
            best_model_wts = copy.deepcopy(
                model.module.state_dict() if data_parallel else model.state_dict())
            # save best model
            print('best model num:{}, best loss is {:.8f}'.format(
                best_num, best_loss))
            torch.save(best_model_wts,
                       os.path.join(pths_path, 'model_epoch_best.pth'))
        if (epoch + 1) % interval == 0:
            savePath = os.path.join(pths_path,
                                    'lossImg{}.jpg'.format(epoch + 1))
            drawLoss(train_loss, valid_loss, savePath)
            print(time.asctime(time.localtime(time.time())))
            state_dict = model.module.state_dict(
            ) if data_parallel else model.state_dict()
            lr_state = scheduler.state_dict()
            torch.save(
                state_dict,
                os.path.join(pths_path,
                             'model_epoch_{}.pth'.format(epoch + 1)))
            torch.save(
                lr_state,
                os.path.join(pths_path,
                             'scheduler_epoch_{}.pth'.format(epoch + 1)))
            print("save model")
            print('=' * 50)
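
A minimal invocation sketch for this domain-adaptation train() (all paths and values illustrative; note the function also reads a module-level device):

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if __name__ == '__main__':
    train(source_img_path='../data/source/img/',
          source_gt_path='../data/source/gt/',
          target_img_path='../data/target/img/',
          target_gt_path='../data/target/gt/',
          valid_img_path='../data/valid/img/',
          valid_gt_path='../data/valid/gt/',
          pths_path='./pths/',
          batch_size=24,
          lr=1e-3,
          num_workers=4,
          epoch_iter=600,
          interval=5)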
Example #7
def main(args):
    source_train_set = custom_dataset(args.train_data_path, args.train_gt_path)
    target_train_set = custom_dataset(args.target_data_path,
                                      args.target_gt_path)
    valid_train_set = valid_dataset(args.val_data_path, args.val_gt_path)

    source_train_loader = data.DataLoader(source_train_set,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=args.num_workers,
                                          drop_last=True)
    target_train_loader = data.DataLoader(target_train_set,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=args.num_workers,
                                          drop_last=True)
    valid_loader = data.DataLoader(valid_train_set,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=args.num_workers,
                                   drop_last=False)

    criterion = Loss().to(device)
    # domain loss
    loss_domain = torch.nn.CrossEntropyLoss().to(device)

    best_loss = 1000
    best_num = 0

    model = EAST()
    if args.pretrained_model_path:
        model.load_state_dict(torch.load(args.pretrained_model_path))

    current_epoch_num = 0
    # resume (scheduler state is restored below, once the scheduler exists)
    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        best_loss = checkpoint['best_loss']
        current_epoch_num = checkpoint['epoch']

    data_parallel = False
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        data_parallel = True

    model.to(device)

    total_epoch = args.epochs
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[total_epoch // 3, total_epoch * 2 // 3], gamma=0.1)
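    # ReduceLROnPlateau multiplies the lr by factor (0.1) once val_loss has
    # not improved beyond threshold for patience (6) consecutive epochs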
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               mode='min',
                                               factor=0.1,
                                               patience=6,
                                               threshold=args.lr / 100)
    # resume scheduler state (checkpoint was loaded above)
    if args.resume:
        scheduler.load_state_dict(checkpoint['scheduler'])

    for epoch in range(current_epoch_num, total_epoch):
        each_epoch_start = time.time()
        # scheduler.step(epoch)
        # add lr in tensorboardX
        writer.add_scalar('epoch/lr', get_learning_rate(optimizer), epoch)

        train(source_train_loader, target_train_loader, model, criterion,
              loss_domain, optimizer, epoch)

        val_loss = eval(model, valid_loader, criterion, loss_domain, epoch)
        scheduler.step(val_loss)

        if val_loss < best_loss:
            best_num = epoch + 1
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.module.state_dict(
            ) if data_parallel else model.state_dict())
            # save best model

            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': best_model_wts,
                    'best_loss': best_loss,
                    'scheduler': scheduler.state_dict(),
                }, os.path.join(save_folder, "model_epoch_best.pth"))

            log.write('best model num:{}, best loss is {:.8f}'.format(
                best_num, best_loss))
            log.write('\n')

        if (epoch + 1) % int(args.save_interval) == 0:
            state_dict = model.module.state_dict(
            ) if data_parallel else model.state_dict()
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': state_dict,
                    'best_loss': best_loss,
                    'scheduler': scheduler.state_dict(),
                },
                os.path.join(save_folder,
                             'model_epoch_{}.pth'.format(epoch + 1)))
            log.write('save model')
            log.write('\n')

        log.write('=' * 50)
        log.write('\n')
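
main() relies on module-level device, writer, log, save_folder, and a get_learning_rate helper that this snippet does not define; a minimal bootstrap sketch (argument names inferred from the args.* uses above, defaults hypothetical):

import argparse
import os

import torch
from tensorboardX import SummaryWriter


def get_learning_rate(optimizer):
    # assumed helper: report the lr of the first parameter group
    return optimizer.param_groups[0]['lr']


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_data_path', required=True)
    parser.add_argument('--train_gt_path', required=True)
    parser.add_argument('--target_data_path', required=True)
    parser.add_argument('--target_gt_path', required=True)
    parser.add_argument('--val_data_path', required=True)
    parser.add_argument('--val_gt_path', required=True)
    parser.add_argument('--batch_size', type=int, default=24)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--epochs', type=int, default=600)
    parser.add_argument('--save_interval', type=int, default=5)
    parser.add_argument('--pretrained_model_path', default=None)
    parser.add_argument('--resume', default=None)
    parser.add_argument('--save_folder', default='./checkpoint')
    args = parser.parse_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    save_folder = args.save_folder
    os.makedirs(save_folder, exist_ok=True)
    writer = SummaryWriter(os.path.join(save_folder, 'tb'))
    log = open(os.path.join(save_folder, 'train.log'), 'a')

    main(args)

    log.close()
    writer.close()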