Example No. 1
def test_model(model, criterion, dataload):
    '''
    if dataload is not None:
        valid_loss = 0
        valid_acc = 0
        print(len(dataload))
        for x, y in dataload:
            # inputs = x.to(device)
            # labels = y.to(device)
            out = model(x)
            loss = criterion(out, y)
            print(loss.item())
            valid_loss += loss.item()
            # valid_acc += get_acc(out, y)

        # epoch_str = ("Valid Loss: %f, Valid Acc: %f" %
        #              (valid_loss / len(dataload), valid_acc / len(dataload)))
        epoch_str = "Valid Loss: %f" % (valid_loss / len(dataload))

    else:
        epoch_str = "test_dataload is none"

    print(epoch_str)
    '''
    iou, loss = eval_net(model, criterion, dataload, True)
    # dice, loss = eval_net(model, criterion, dataload)
    print("ave_test_iou:{},ave_test_loss:{}".format(iou, loss))
    '''
    save_data = "ckp_xin/test/caoying.txt"
    with open(save_data, 'a+') as fw:
        fw.write("ave_test_dice:{},ave_test_loss:{}".format(dice, loss))
        fw.write("\n")
    '''

    return iou, loss
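For context, a hedged usage sketch of test_model: it simply forwards to eval_net (defined elsewhere in the repository, not shown here), so any (model, criterion, DataLoader) triple with matching shapes works. The tensors and model below are toy stand-ins.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins: 8 RGB images with binary masks; the real project
# supplies its own dataset, network, and eval_net implementation.
images = torch.randn(8, 3, 64, 64)
masks = torch.randint(0, 2, (8, 1, 64, 64)).float()
test_loader = DataLoader(TensorDataset(images, masks), batch_size=4)

model = nn.Conv2d(3, 1, kernel_size=3, padding=1)  # minimal stand-in net
iou, loss = test_model(model, nn.BCEWithLogitsLoss(), test_loader)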
Example No. 2
def eval_one_epoch(net, eval_dataloader, device, global_step, fw_iou_avg,
                   writer):
    val_loss, pixel_acc_avg, mean_iou_avg, _fw_iou_avg = eval_net(
        net, eval_dataloader, device)
    # keep the best frequency-weighted IoU seen so far
    fw_iou_avg = max(fw_iou_avg, _fw_iou_avg)
    logging.info('Validation cross entropy: {}'.format(val_loss))
    writer.add_scalar('Loss/test', val_loss, global_step)
    writer.add_scalar('pixel_acc_avg', pixel_acc_avg, global_step)
    writer.add_scalar('mean_iou_avg', mean_iou_avg, global_step)
    writer.add_scalar('fw_iou_avg', fw_iou_avg, global_step)
    return fw_iou_avg
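A hedged call-site sketch for eval_one_epoch: the caller threads the running best fw_iou_avg back in every epoch, which is why the function returns it. Names like epochs, net, eval_dataloader, and global_step are assumed to exist in the surrounding training script.

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
fw_iou_avg = 0.0
for epoch in range(epochs):
    # ... one epoch of training updates net and global_step ...
    fw_iou_avg = eval_one_epoch(net, eval_dataloader, device,
                                global_step, fw_iou_avg, writer)
writer.close()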
Example No. 3
# NOTE: the original snippet starts mid-loop; these opening lines are a
# minimal reconstruction implied by the variables used below (epochs,
# train_loader, batch_size, optimizer, criterion, val_loader, net).
for epoch in range(epochs):
    net.train()
    epoch_loss = 0
    for batch in train_loader:
        imgs = batch['image']  # assumed key, paired with batch['mask']
        true_masks = batch['mask']
        # print(true_masks.size())
        imgs = imgs.to(device=device, dtype=torch.float32)
        mask_type = torch.float32 if net.n_classes == 1 else torch.long
        true_masks = true_masks.to(device=device, dtype=mask_type)

        masks_pred = net(imgs)
        # masks_pred = masks_pred.to("cpu", torch.double)
        # print(masks_pred.size())

        loss = criterion(masks_pred, true_masks)
        # loss = criterion_(masks_pred, true_masks)
        # print("###########")
        # print(loss.item())
        epoch_loss += loss.item()
        # writer.add_scalar('Loss/train', loss.item(), global_step)

        print("epoch : %d, batch : %5d, loss : %.5f" %
              (epoch, (global_step / batch_size), loss.item()))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        global_step += 1
        if global_step % 1000 == 0 and global_step > 0:  # validate every 1000 steps
            val_pre = eval_net(net, val_loader, device)
            print("val loss : %.5f" % val_pre)

    if epoch % 10 == 0 and epoch > 0:
        torch.save(net.state_dict(), dir_checkpoint + 'epoch_%d.pth' % epoch)
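None of these snippets include eval_net itself; below is a minimal sketch of what the (net, loader, device) variant called in this example plausibly does, returning the mean criterion loss over the validation loader. The 'image'/'mask' batch keys and the MSE criterion are assumptions.

import torch

def eval_net(net, loader, device):
    # Sketch only: mean validation loss over the loader, assuming
    # batches are dicts with 'image' and 'mask' keys as in the loop above.
    net.eval()
    criterion = torch.nn.MSELoss()  # assumed criterion
    total = 0.0
    with torch.no_grad():
        for batch in loader:
            imgs = batch['image'].to(device=device, dtype=torch.float32)
            true_masks = batch['mask'].to(device=device, dtype=torch.float32)
            total += criterion(net(imgs), true_masks).item()
    net.train()
    return total / max(len(loader), 1)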
Example No. 4
def train_net(net,
              device,
              epochs=10,
              batch_size=1,
              lr=0.001,
              val_percent=0.1,
              save_cp=True,
              img_scale=0.5,
              gamma=1):

    dataset = TUMDataset(train_seqs, root_dir)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    torch.manual_seed(0)
    train, val = random_split(dataset, [n_train, n_val])
    torch.manual_seed(torch.initial_seed())
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=True)

    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.Adam(net.parameters(), lr=lr) # weight_decay=1e-8, momentum=0.9
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
    
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = CustomMSELoss  # nn.MSELoss()

    Superpoint_model = PoseEstimation()

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                im1 = batch["gray1"]
                im2 = batch["gray2"]
                d1 = batch["depth1"]
                d2 = batch["depth2"]
                
                hm1, hm2 = Superpoint_model.forward(im1, im2)

                fx = batch["fx"]
                fy = batch["fy"]
                cx = batch["cx"]
                cy = batch["cy"]
                
                imgs = torch.cat((torch.unsqueeze(im1, 1).to(device),
                                  torch.unsqueeze(im2, 1).to(device),
                                  torch.unsqueeze(d1, 1).to(device),
                                  torch.unsqueeze(d2, 1).to(device),
                                  torch.unsqueeze(hm1, 1).to(device),
                                  torch.unsqueeze(hm2, 1).to(device)), 1)

                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                GT_depth = torch.unsqueeze(d1,1).to(device=device, dtype=mask_type)

                depth_pred = net(imgs)
                loss = criterion(depth_pred, GT_depth)

                #unproj_loss = unproject_loss(hm1, hm2, batch, device)
                #loss += gamma * unproj_loss

                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % (n_train // (10 * batch_size)) == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
                        writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
                    val_score = eval_net(net, val_loader, device, Superpoint_model)
                    scheduler.step(val_score)
                    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

                    if net.n_classes > 1:
                        logging.info('Validation cross entropy: {}'.format(val_score))
                    else:
                        logging.info('Validation Loss: {}'.format(val_score))
                    writer.add_scalar('Loss/test', val_score, global_step)

                    writer.add_images('images', torch.unsqueeze(imgs[:,0,:,:],1), global_step)
                    if net.n_classes == 1:
                        writer.add_images('depth/true', GT_depth, global_step)
                        writer.add_images('depth/pred', depth_pred, global_step)

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(),
                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
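A hedged invocation sketch for this train_net: the network must expose n_channels (6 here: two gray frames, two depth maps, and two SuperPoint heatmaps) and n_classes, and train_seqs, root_dir, and dir_checkpoint must already be defined at module level. UNet is a stand-in name for whatever architecture the repository actually uses.

import logging
import torch

logging.basicConfig(level=logging.INFO)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = UNet(n_channels=6, n_classes=1).to(device)  # hypothetical UNet class
train_net(net, device, epochs=10, batch_size=4, lr=1e-3, img_scale=0.5)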
Example No. 5
def train_net(net, epochs=5, batch_size=2, lr=0.1, save_cp=True, gpu=True):
    train_img_dir = './FarmlandC_data/train/image/'
    train_msk_dir = './FarmlandC_data/train/label/'
    val_img_dir = './FarmlandC_data/val/image/'
    val_msk_dir = './FarmlandC_data/val/label/'
    dir_checkpoint = './checkpoints_C/'

    train_img_list = os.listdir(train_img_dir)
    val_img_list = os.listdir(val_img_dir)

    print('''
Starting training:
    Training size: {}
    Validation size: {}
    Epochs: {}
    Batch size: {}
    Learning rate: {}
    Checkpoints: {}
    CUDA: {}
    '''.format(len(train_img_list), len(val_img_list), epochs, batch_size, lr,
               str(save_cp), str(gpu)))

    optimizer = optim.Adam(net.parameters(),
                           lr=lr,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0,
                           amsgrad=False)

    best_val_loss = 1000.0  # lowest validation loss seen so far
    iter_num = 0
    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train()
        sub_num = 0
        epoch_loss = 0
        kits = SAR_Dataset(imgPath=train_img_dir, labPath=train_msk_dir)
        train_loader = DataLoader(kits,
                                  batch_size,
                                  shuffle=True,
                                  num_workers=4)

        for i, data in enumerate(train_loader):
            imgs = data[0]
            true_masks = data[1]
            if gpu:
                imgs = imgs.float().cuda()
                true_masks = true_masks.float().cuda()
            masks_pred = net(imgs)
            loss = binary_loss(masks_pred, true_masks)
            print('iter: %d  loss: %.4f' % (sub_num, loss.item()))
            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            sub_num += 1

        # run validation during training
        with torch.no_grad():
            val_loss = eval_net(net, val_img_dir, val_msk_dir, batch_size, gpu)
            print('Validation loss: {}'.format(val_loss))

        # save the model whenever the validation loss improves
        if save_cp and val_loss < best_val_loss:
            torch.save(
                net.state_dict(), dir_checkpoint +
                'CP{}_val_{:.4f}.pth'.format(iter_num, val_loss))
            print('Checkpoint {} saved !'.format(iter_num))
            best_val_loss = val_loss
        iter_num += 1
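This example calls a binary_loss helper that is not shown; a minimal sketch assuming plain binary cross-entropy on raw logits (the original may also mix in a Dice term):

import torch.nn as nn

def binary_loss(masks_pred, true_masks):
    # Assumption: masks_pred are raw logits; BCEWithLogitsLoss applies
    # the sigmoid internally, which is numerically stabler than
    # sigmoid followed by BCELoss.
    return nn.BCEWithLogitsLoss()(masks_pred, true_masks)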