Example 1
                loss = loss_r + loss_a

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Pre-AMP updates, superseded by the GradScaler calls above:
            # loss.backward()
            # optimizer.step()
            scheduler.step()

            loss_value += loss.item()
            if index % 2 == 0 and index > 0:
                et = time.time()
                # Report the running average over the 2-batch logging window.
                print(
                    'epoch {}: ({}/{}) batch || training time for 2 batches: {:.2f}s || training loss: {:.4f} || learning rate: {}'
                    .format(epoch, index, len(sample_train_loader), et - st,
                            loss_value / 2, optimizer.param_groups[0]['lr']))
                loss_value = 0
                st = time.time()

            if index % 500 == 0 and index > 0:
                print('Save epoch : {} , index : {}'.format(epoch, index))
                torch.save(
                    {
                        'epoch': epoch,
                        'model_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': loss_value
                    }, '/root/data/test_param/{}_{}.pth'.format(epoch, index))
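The snippet above shows only the backward half of PyTorch's automatic mixed precision (AMP) recipe. A minimal sketch of the forward half it pairs with, assuming a CRAFT-style net and a two-term loss as in the code above (the loader, criterion, and target names are placeholders, not part of the original):

import torch

scaler = torch.cuda.amp.GradScaler()  # scales losses so fp16 gradients don't underflow

for index, (images, region_gt, affinity_gt) in enumerate(sample_train_loader):
    optimizer.zero_grad()
    # Forward pass under autocast: eligible ops run in fp16, the rest stay fp32.
    with torch.cuda.amp.autocast():
        preds, _ = net(images.cuda())
        loss_r = criterion(preds[:, :, :, 0], region_gt.cuda())
        loss_a = criterion(preds[:, :, :, 1], affinity_gt.cuda())
        loss = loss_r + loss_a
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales gradients, skips the step on inf/NaN
    scaler.update()                # retunes the scale factor for the next iteration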
Example 2
import os
import time

import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

import craft_utils  # project-local
import test  # project-local module providing copyStateDict
from craft import CRAFT  # project-local
from dataset import CRAFTDataset  # project-local


def train(args):
    # load net
    net = CRAFT()  # initialize

    # Treat a missing checkpoint path as training from scratch (guard against
    # args.trained_model being None, which os.path.exists would reject).
    if args.trained_model is not None and not os.path.exists(args.trained_model):
        args.trained_model = None

    if args.trained_model is not None:
        print('Loading weights from checkpoint (' + args.trained_model + ')')
        if args.cuda:
            net.load_state_dict(test.copyStateDict(torch.load(args.trained_model)))
        else:
            net.load_state_dict(test.copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    # # LinkRefiner
    # refine_net = None
    # if args.refine:
    #     from refinenet import RefineNet
    #
    #     refine_net = RefineNet()
    #     print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
    #     if args.cuda:
    #         refine_net.load_state_dict(test.copyStateDict(torch.load(args.refiner_model)))
    #         refine_net = refine_net.cuda()
    #         refine_net = torch.nn.DataParallel(refine_net)
    #     else:
    #         refine_net.load_state_dict(test.copyStateDict(torch.load(args.refiner_model, map_location='cpu')))
    #
    #     args.poly = True

    criterion = craft_utils.CRAFTLoss()
    optimizer = optim.Adam(net.parameters(), args.learning_rate)
    train_data = CRAFTDataset(args)
    dataloader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True)
    t0 = time.time()

    for epoch in range(args.max_epoch):
        pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f'Epoch {epoch}')
        running_loss = 0.0
        for i, data in pbar:
            x, y_region, y_link, y_conf = data
            if args.cuda:  # honor the flag instead of assuming a GPU is present
                x = x.cuda()
                y_region = y_region.cuda()
                y_link = y_link.cuda()
                y_conf = y_conf.cuda()
            optimizer.zero_grad()

            y, feature = net(x)

            score_text = y[:, :, :, 0]
            score_link = y[:, :, :, 1]

            L = criterion(score_text, score_link, y_region, y_link, y_conf)

            L.backward()
            optimizer.step()

            running_loss += L.item()  # .data is deprecated; .item() suffices
            if i % 2000 == 1999 or i == len(dataloader) - 1:
                window = i % 2000 + 1  # batches accumulated since the last reset
                pbar.set_postfix_str('[%d, %5d] loss: %.3f' %
                                     (epoch + 1, i + 1, running_loss / window))
                running_loss = 0.0

    # Save trained model (unwrap DataParallel so the keys carry no 'module.' prefix)
    if isinstance(net, torch.nn.DataParallel):
        torch.save(net.module.state_dict(), args.weight)
    else:
        torch.save(net.state_dict(), args.weight)

    print(f'Training finished: {time.time() - t0:.1f}s spent on {args.max_epoch} epochs')
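test.copyStateDict above is a project-local helper this example leans on when loading checkpoints. A minimal sketch of what such a helper typically does (this mirrors the CRAFT reference implementation, but treat it as an assumption here): it strips the 'module.' prefix that torch.nn.DataParallel adds to state-dict keys, so checkpoints load into a bare model.

from collections import OrderedDict


def copyStateDict(state_dict):
    # Drop a leading 'module.' from every key if the checkpoint was saved
    # from a DataParallel-wrapped model; otherwise copy the keys unchanged.
    start_idx = 1 if list(state_dict.keys())[0].startswith('module') else 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = '.'.join(k.split('.')[start_idx:])
        new_state_dict[name] = v
    return new_state_dict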
Example 3
            if use_cuda:
                images, region_scores, affinity_scores, confidence_mask = (
                    images.cuda(), region_scores.cuda(), affinity_scores.cuda(),
                    confidence_mask.cuda())

            pred_scores, _ = net(images)
            pred_region_scores = pred_scores[:, :, :, 0]
            pred_affinity_scores = pred_scores[:, :, :, 1]
            # No extra transfer needed: net(images) already returns GPU tensors
            # when its inputs are on the GPU.
            loss = criterion(region_scores, affinity_scores,
                             pred_region_scores, pred_affinity_scores,
                             confidence_mask)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_value += loss.item()

            if (i + 1) % 10 == 0:
                pause = time.time()
                print(
                    'Epoch {}:({}/{}) batch\nTraining time for 10 batches: {}\nTraining loss: {}'
                    .format(epoch, i + 1, len(train_loader), pause - start,
                            loss_value / 10))
                loss_value = 0

            if i + 1 == len(train_loader):  # last batch of the epoch
                print('Saving state dict, version', epoch)
                torch.save(
                    net.state_dict(),
                    './models/realweights/Epoch_{}.pth'.format(epoch + 1))
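The loop above writes one raw state dict per epoch. A hedged sketch of reloading such a file for evaluation (the filename is illustrative; copyStateDict is the helper sketched after Example 2):

import torch

net = CRAFT()
state = torch.load('./models/realweights/Epoch_1.pth', map_location='cpu')
net.load_state_dict(copyStateDict(state))  # strips 'module.' if saved from DataParallel
net.eval()  # inference mode: disables dropout, freezes batch-norm statistics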
def train(train_img_path, train_gt_path, pths_path, batch_size, lr,
          num_workers, epoch_iter, save_interval):
    file_num = len(os.listdir(train_img_path))
    trainset = custom_dataset(train_img_path, train_gt_path)
    train_loader = data.DataLoader(trainset, batch_size=batch_size, \
                                   shuffle=True, num_workers=num_workers, drop_last=True)
    criterion = Maploss()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = CRAFT()
    data_parallel = False

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        data_parallel = True

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=args.weight_decay)  # args: module-level namespace
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=[epoch_iter // 2],
                                         gamma=0.1)

    step_index = 0
    for epoch in range(epoch_iter):
        if epoch % 50 == 0 and epoch != 0:
            step_index += 1
            adjust_learning_rate(optimizer, args.gamma, step_index)

        model.train()
        epoch_loss = 0
        epoch_time = time.time()
        for i, (img, gt_score, gt_geo, ignored_map) in enumerate(train_loader):
            start_time = time.time()
            img, gt_score, gt_geo, ignored_map = img.to(device), gt_score.to(
                device), gt_geo.to(device), ignored_map.to(device)
            pred_score, pred_geo = model(img)
            loss = criterion(gt_score, pred_score, gt_geo, pred_geo,
                             ignored_map)

            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print('Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'.format(\
                    epoch+1, epoch_iter, i+1, int(file_num/batch_size), time.time()-start_time, loss.item()))

        scheduler.step()  # step the schedule once per epoch, after the optimizer updates
        print('epoch_loss is {:.8f}, epoch_time is {:.8f}'.format(
            epoch_loss / int(file_num / batch_size),
            time.time() - epoch_time))
        print(time.asctime(time.localtime(time.time())))
        print('=' * 50)
        if (epoch + 1) % save_interval == 0:
            state_dict = model.module.state_dict() if data_parallel \
                else model.state_dict()
            torch.save(
                state_dict,
                os.path.join(pths_path,
                             'model_epoch_{}.pth'.format(epoch + 1)))
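adjust_learning_rate is not defined in this excerpt. A minimal sketch of the usual step-decay helper that the step_index bookkeeping above suggests (the gamma-exponent form is an assumption):

def adjust_learning_rate(optimizer, gamma, step):
    # Decay every param group's LR to initial_lr * gamma**step, remembering
    # the initial LR on first use so repeated calls don't compound.
    for param_group in optimizer.param_groups:
        param_group.setdefault('initial_lr', param_group['lr'])
        param_group['lr'] = param_group['initial_lr'] * (gamma ** step)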