Example #1
    net = net.cuda()
    # net = CRAFT_net

    # if args.cuda:
    net = torch.nn.DataParallel(net, device_ids=[0, 1, 2, 3]).cuda()
    cudnn.benchmark = True

    # ICDAR2015 takes the net itself: pseudo ground truth for the real data
    # is generated with the current model (weakly supervised fine-tuning).
    realdata = ICDAR2015(net, './data/icdar15', target_size=768)
    real_data_loader = torch.utils.data.DataLoader(
            realdata,
            batch_size=args.batch_size * 5,
            shuffle=True,
            num_workers=0,
            drop_last=True,
            pin_memory=True)

    optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    criterion = Maploss()
    # criterion = torch.nn.MSELoss(reduce=True, size_average=True)
    net.train()

    step_index = 0

    loss_time = 0
    loss_value = 0
    compare_loss = 1

    batch_time = AverageMeter(100)
    iter_time = AverageMeter(100)

    for epoch in range(1000):
        loss_value = 0
        # if epoch % 50 == 0 and epoch != 0:
        #     step_index += 1
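        # (The snippet breaks off here. A minimal sketch of the inner loop,
        #  assuming the loader yields (image, region_gt, affinity_gt, mask)
        #  and that Maploss takes (region_gt, affinity_gt, region_pred,
        #  affinity_pred, mask):)
        for images, gh_label, gah_label, mask in real_data_loader:
            images = images.cuda()
            gh_label = gh_label.cuda()
            gah_label = gah_label.cuda()
            mask = mask.cuda()

            out, _ = net(images)
            loss = criterion(gh_label, gah_label,
                             out[:, :, :, 0], out[:, :, :, 1], mask)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_value += loss.item()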
Example #3
        net = nn.DataParallel(net)
        data_parallel = True
    cudnn.benchmark = False

    print('Load the real data')
    real_data = ICDAR2013(net, 'D:/Datasets/ICDAR_2013')
    real_data_loader = torch.utils.data.DataLoader(real_data,
                                                   batch_size=5,
                                                   shuffle=True,
                                                   num_workers=0,
                                                   drop_last=True,
                                                   pin_memory=True)

    lr = .0001
    epochs = 15
    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=5e-4)
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=[epochs // 3],
                                         gamma=0.1)
    criterion = Loss()

    print('Begin training ...')
    net.train()
    for epoch in range(epochs):
        scheduler.step()  # pre-1.1 PyTorch ordering; newer versions expect this after the epoch's optimizer steps
        loss_value = 0
        start = time.time()
        for i, (real_images, real_region_scores, real_affinity_scores,
                real_confidence_mask) in enumerate(real_data_loader):
            syn_images, syn_region_scores, syn_affinity_scores, syn_confidence_mask = next(
                batch_syn)
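            # (The snippet breaks off here. `batch_syn` is an iterator over a
            #  synthetic-data loader created earlier in the full script. A
            #  plausible continuation, assuming the usual recipe of training on
            #  the concatenated synthetic + real batch and that Loss() takes
            #  (region_gt, affinity_gt, region_pred, affinity_pred, mask):)
            images = torch.cat((syn_images, real_images), dim=0).cuda()
            region_gt = torch.cat((syn_region_scores, real_region_scores), dim=0).cuda()
            affinity_gt = torch.cat((syn_affinity_scores, real_affinity_scores), dim=0).cuda()
            mask = torch.cat((syn_confidence_mask, real_confidence_mask), dim=0).cuda()

            out, _ = net(images)
            loss = criterion(region_gt, affinity_gt,
                             out[:, :, :, 0], out[:, :, :, 1], mask)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_value += loss.item()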
Example #4
dataset = ImageLoader_synthtext(args)
assert dataset
data_loader = torch.utils.data.DataLoader(dataset, args.batch_size, num_workers=4, shuffle=True, collate_fn=collate)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion = torch.nn.MSELoss(reduction='mean')
criterion = criterion.to(device)
craft = CRAFT(pretrained=True)

if args.go_on != '':  # resume training from an earlier checkpoint
    print('loading pretrained model from %s' % args.pre_model)
    craft.load_state_dict(torch.load(args.pre_model), strict=False)
craft = craft.to(device)

loss_avg = averager()
optimizer = optim.Adam(craft.parameters(), lr=args.lr)

def train_batch(data):
    div = 10  # scale factor applied to both loss terms
    craft.train()
    img, char_label, interval_label = data
    img = img.to(device)
    char_label = char_label.to(device)
    interval_label = interval_label.to(device)

    img.requires_grad_()
    optimizer.zero_grad()
    preds, _ = craft(img)
    # Channel 0 of the prediction is the character (region) score map,
    # channel 1 the interval (affinity) score map.
    cost_char = criterion(preds[:, :, :, 0], char_label).sum() / div
    cost_interval = criterion(preds[:, :, :, 1], interval_label).sum() / div
    cost = cost_char + cost_interval
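    # (The snippet ends here; the usual closing steps of a training batch,
    #  assuming the standard backward/step pattern used elsewhere on this page:)
    cost.backward()
    optimizer.step()
    return cost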
Example #5
def train(args):
    # load net
    net = CRAFT()  # initialize

    if not os.path.exists(args.trained_model):
        args.trained_model = None

    if args.trained_model is not None:
        print('Loading weights from checkpoint (' + args.trained_model + ')')
        if args.cuda:
            net.load_state_dict(test.copyStateDict(torch.load(args.trained_model)))
        else:
            net.load_state_dict(test.copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    # # LinkRefiner
    # refine_net = None
    # if args.refine:
    #     from refinenet import RefineNet
    #
    #     refine_net = RefineNet()
    #     print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
    #     if args.cuda:
    #         refine_net.load_state_dict(test.copyStateDict(torch.load(args.refiner_model)))
    #         refine_net = refine_net.cuda()
    #         refine_net = torch.nn.DataParallel(refine_net)
    #     else:
    #         refine_net.load_state_dict(test.copyStateDict(torch.load(args.refiner_model, map_location='cpu')))
    #
    #     args.poly = True

    criterion = craft_utils.CRAFTLoss()
    optimizer = optim.Adam(net.parameters(), args.learning_rate)
    train_data = CRAFTDataset(args)
    dataloader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True)
    t0 = time.time()

    for epoch in range(args.max_epoch):
        pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f'Epoch {epoch}')
        running_loss = 0.0
        for i, data in pbar:
            x, y_region, y_link, y_conf = data
            x = x.cuda()
            y_region = y_region.cuda()
            y_link = y_link.cuda()
            y_conf = y_conf.cuda()
            optimizer.zero_grad()

            y, feature = net(x)

            score_text = y[:, :, :, 0]
            score_link = y[:, :, :, 1]

            L = criterion(score_text, score_link, y_region, y_link, y_conf)

            L.backward()
            optimizer.step()

            running_loss += L.item()
            if i % 2000 == 1999 or i == len(dataloader) - 1:
                pbar.set_postfix_str('[%d, %5d] loss: %.3f' %
                                     (epoch + 1, i + 1, running_loss / ((i % 2000) + 1)))
                running_loss = 0.0

    # Save trained model
    torch.save(net.state_dict(), args.weight)

    print(f'training finished\n {time.time() - t0} spent for {args.max_epoch} epochs')
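
# (Hypothetical usage, assuming argparse flags matching the attribute names
#  used above, e.g.:
#    python train.py --cuda --trained_model weights/craft_mlt_25k.pth \
#        --learning_rate 1e-4 --batch_size 8 --max_epoch 10 --weight craft_final.pth)
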
def train(train_img_path, train_gt_path, pths_path, batch_size, lr,
          num_workers, epoch_iter, save_interval):
    file_num = len(os.listdir(train_img_path))
    trainset = custom_dataset(train_img_path, train_gt_path)
    train_loader = data.DataLoader(trainset, batch_size=batch_size,
                                   shuffle=True, num_workers=num_workers, drop_last=True)
    criterion = Maploss()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = CRAFT()
    data_parallel = False

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        data_parallel = True

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=args.weight_decay)  # note: relies on a module-level `args`
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=[epoch_iter // 2],
                                         gamma=0.1)

    step_index = 0
    for epoch in range(epoch_iter):
        # Manual LR decay every 50 epochs via `adjust_learning_rate`, applied
        # on top of the MultiStepLR schedule; both use module-level globals.
        if epoch % 50 == 0 and epoch != 0:
            step_index += 1
            adjust_learning_rate(optimizer, args.gamma, step_index)

        model.train()
        epoch_loss = 0
        epoch_time = time.time()
        for i, (img, gt_score, gt_geo, ignored_map) in enumerate(train_loader):
            start_time = time.time()
            img, gt_score, gt_geo, ignored_map = img.to(device), gt_score.to(
                device), gt_geo.to(device), ignored_map.to(device)
            pred_score, pred_geo = model(img)
            loss = criterion(gt_score, pred_score, gt_geo, pred_geo,
                             ignored_map)

            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print('Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'.format(
                epoch + 1, epoch_iter, i + 1, int(file_num / batch_size), time.time() - start_time, loss.item()))

        scheduler.step()  # step the schedule after the epoch's optimizer updates
        print('epoch_loss is {:.8f}, epoch_time is {:.8f}'.format(
            epoch_loss / int(file_num / batch_size),
            time.time() - epoch_time))
        print(time.asctime(time.localtime(time.time())))
        print('=' * 50)
        if (epoch + 1) % save_interval == 0:
            state_dict = model.module.state_dict() if data_parallel else model.state_dict()
            torch.save(
                state_dict,
                os.path.join(pths_path,
                             'model_epoch_{}.pth'.format(epoch + 1)))
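
# (Not part of the snippet: a minimal sketch of reloading one of the saved
#  checkpoints for evaluation; the path and epoch number are illustrative.)
model = CRAFT()
state_dict = torch.load('./pths/model_epoch_600.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()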