Example #1
0
def vald_net(net, loader, device):
    """Evaluation using MAE.

    Runs ``net`` over every batch of ``loader`` with gradients disabled,
    accumulates the loss from ``utls.compute_loss``, and returns the
    accumulated loss divided by ``n_val``.  The network is put back into
    train mode before returning.

    Args:
        net: model returning ``(rec_imgs, rendered_imgs)`` for an image batch.
        loader: yields dicts with 3-channel 'image' and 'gt_xyz' tensors.
        device: torch device to evaluate on.

    Returns:
        Mean validation loss (a torch scalar).
    """
    net.eval()
    # NOTE(review): n_val is one more than the number of batches, so the
    # returned mean is biased slightly low; kept as-is to preserve the
    # original metric's scale.
    n_val = len(loader) + 1
    mae = 0

    with tqdm(total=n_val, desc='Validation round', unit='batch',
              leave=False) as pbar:
        for batch in loader:
            imgs = batch['image']
            xyz_gt = batch['gt_xyz']
            assert imgs.shape[1] == 3, (
                f'Network has been defined with 3 input channels, but loaded '
                f'training images have {imgs.shape[1]} channels. Please check '
                f'that the images are loaded correctly.')

            assert xyz_gt.shape[1] == 3, (
                f'Network has been defined with 3 input channels, but loaded '
                f'XYZ images have {xyz_gt.shape[1]} channels. Please check '
                f'that the images are loaded correctly.')

            imgs = imgs.to(device=device, dtype=torch.float32)
            xyz_gt = xyz_gt.to(device=device, dtype=torch.float32)

            with torch.no_grad():
                rec_imgs, rendered_imgs = net(imgs)
                loss = utls.compute_loss(imgs, xyz_gt, rec_imgs, rendered_imgs)
                mae = mae + loss

            # Fix: the bar total counts batches (unit='batch'), so advance by
            # one per batch; the original advanced by np.ceil(imgs.shape[0])
            # (the image count) and overshot the total by a factor of the
            # batch size.
            pbar.update(1)

    net.train()
    return mae / n_val
Example #2
0
def vald_net_normalization(net, loader, device):
    """Evaluate ``net`` on a validation loader and return the mean loss.

    Runs the network with gradients disabled over every batch, accumulates
    ``utils.compute_loss(results, gts)``, and returns the accumulated score
    divided by ``n_val``.  The network is put back into train mode before
    returning.

    Args:
        net: model mapping an input image batch to a result batch.
        loader: yields dicts with 3-channel 'input' and 'gt' tensors.
        device: torch device to evaluate on.

    Returns:
        Mean validation loss (a torch scalar).
    """
    net.eval()
    # NOTE(review): dividing by len(loader) + 1 under-estimates the mean by
    # one batch; kept to match the companion vald_net metric above.
    n_val = len(loader) + 1
    score = 0
    with tqdm(total=n_val, desc='Validation round', unit='batch',
              leave=False) as pbar:
        for batch in loader:
            imgs = batch['input']
            gts = batch['gt']
            assert imgs.shape[1] == 3, \
                f'Network has been defined with 3 input channels, ' \
                f'but loaded training images have {imgs.shape[1]} channels. Please check that ' \
                'the images are loaded correctly.'

            assert gts.shape[1] == 3, \
                f'Network has been defined with 3 input channels, ' \
                f'but loaded AWB GT images have {gts.shape[1]} channels. Please check that ' \
                'the images are loaded correctly.'

            imgs = imgs.to(device=device, dtype=torch.float32)
            gts = gts.to(device=device, dtype=torch.float32)

            with torch.no_grad():
                results = net(imgs)
                loss = utils.compute_loss(results, gts)
                score = score + loss

            # Fix: the bar total counts batches (unit='batch'), so advance by
            # one per batch; the original advanced by the image count and
            # overshot the total by a factor of the batch size.
            pbar.update(1)

    # Moved out of the `with` block for consistency with vald_net above;
    # behavior is unchanged (the loop had already completed either way).
    net.train()
    return score / n_val
Example #3
0
    def train_iteration(self, data, label, debug=1):
        """Perform one learning step on a single (data, label) pair.

        Forward pass, loss + output gradient, then backprop to the hidden
        weights.  When ``debug`` is truthy, dumps the internal tensors.

        Returns:
            (output, loss, g_w_hidden, g_b_hidden)
        """
        # Forward pass: network output plus the hidden-layer input needed
        # later for the weight gradient.
        output, input_hidden = self.compute_output(data)

        # Loss value and its gradient with respect to the output.
        loss, g_output = compute_loss(output, label)

        # Backward pass: gradients on the hidden weights and biases.
        g_w_hidden, g_b_hidden = self.compute_gradient(g_output, input_hidden)

        if debug:
            print('w={0} \n xhidden={1} \n y={2} \n dy={3} \n dw={4}'.format(
                self.w, input_hidden, output, g_output, g_w_hidden))

        return output, loss, g_w_hidden, g_b_hidden
Example #4
0
def Train():
    """Train the network for EPOCHS epochs and keep the best checkpoint.

    Loads data and model via the module-level ``LoadDataset``/``LoadNetwork``,
    tracks per-epoch train and validation loss, saves the weights whenever
    validation loss improves, and finally plots both loss curves.

    Relies on module globals: EPOCHS, LoadDataset, LoadNetwork, compute_loss,
    np, torch, plt, datetime.
    """
    # Timestamped output path; NOTE(review): assumes a 'models/' directory
    # already exists — torch.save will fail otherwise.
    model_save_path = f'models/model_{datetime.datetime.now().strftime("%Y_%m_%d__%H%M%S")}.pth'
    val_loss_best = np.inf

    dataloader_train, dataloader_val, _, service_count = LoadDataset()
    net, optimizer, loss_function, device = LoadNetwork(service_count)

    # Prepare loss history
    hist_loss = np.zeros(EPOCHS)
    hist_loss_val = np.zeros(EPOCHS)
    for idx_epoch in range(EPOCHS):
        running_loss = 0
        # (unused batch index dropped from the original enumerate loop)
        for x, y in dataloader_train:
            optimizer.zero_grad()

            # Propagate input
            netout = net(x.to(device))

            # Compute loss.  NOTE(review): arguments are (target, prediction);
            # confirm loss_function expects that order.
            loss = loss_function(y.to(device), netout)

            # Backpropagate loss
            loss.backward()

            # Update weights
            optimizer.step()

            running_loss += loss.item()

        train_loss = running_loss / len(dataloader_train)
        val_loss = compute_loss(net, dataloader_val, loss_function,
                                device).item()
        print(train_loss)
        hist_loss[idx_epoch] = train_loss
        hist_loss_val[idx_epoch] = val_loss

        # Keep only the best-validation weights on disk.
        if val_loss < val_loss_best:
            val_loss_best = val_loss
            torch.save(net.state_dict(), model_save_path)

    plt.plot(hist_loss, 'o-', label='train')
    plt.plot(hist_loss_val, 'o-', label='val')
    plt.legend()
    plt.show()
    # Fix: the original ':5f' is a width-5 spec (still prints 6 decimals);
    # '.5f' gives the intended 5 decimal places.
    print(f"model exported to {model_save_path} with loss {val_loss_best:.5f}")
Example #5
0
def train_net(net, device, dir_img, dir_gt, val_dir, val_dir_gt, epochs=300,
              batch_size=4, lr=0.0001, lrdf=0.5, lrdp=75, l2reg=0.001,
              chkpointperiod=1, patchsz=256, validationFrequency=10,
              save_cp=True):
    """Train an sRGB->XYZ->sRGB network on paired image patches.

    Builds patch datasets from the given directories, trains with Adam and a
    StepLR schedule, periodically validates via ``vald_net``, optionally logs
    to TensorBoard, writes periodic checkpoints, and saves the final model to
    'models/model_sRGB-XYZ-sRGB.pth'.

    Args:
        net: model returning ``(rec_imgs, rendered_imgs)`` for an image batch.
        device: torch device to train on.
        dir_img, dir_gt: training input / ground-truth image directories.
        val_dir, val_dir_gt: validation input / ground-truth directories.
        epochs: number of training epochs.
        batch_size: images per batch.
        lr: Adam learning rate.
        lrdf, lrdp: StepLR decay factor (gamma) and step period (epochs).
        l2reg: Adam weight decay.
        chkpointperiod: save a checkpoint every this many epochs.
        patchsz: square patch size fed to both datasets.
        validationFrequency: validate every this many epochs.
        save_cp: whether to write periodic checkpoints.

    Relies on module globals: use_tb, utls, BasicDataset, vald_net, logging.
    """

    dir_checkpoint = 'checkpoints/'

    # Patch datasets; validation drops the last partial batch so all its
    # batches have a uniform size.
    train = BasicDataset(dir_img, dir_gt, patch_size=patchsz)
    val = BasicDataset(val_dir, val_dir_gt, patch_size=patchsz)
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=8, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                            num_workers=8, pin_memory=True, drop_last=True)
    # writer is only created (and only used) when use_tb is set.
    if use_tb:
        writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs} epochs
        Batch size:      {batch_size}
        Patch size:      {patchsz} x {patchsz}
        Learning rate:   {lr}
        Training size:   {len(train)}
        Validation size: {len(val)}
        Validation Frq.: {validationFrequency}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        TensorBoard:     {use_tb}
    ''')

    optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999),
                           eps=1e-08, weight_decay=l2reg)
    # LR decays by lrdf every lrdp epochs (scheduler.step() is per-epoch).
    scheduler = optim.lr_scheduler.StepLR(optimizer, lrdp, gamma=lrdf,
                                          last_epoch=-1)

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=len(train), desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                xyz_gt = batch['gt_xyz']
                assert imgs.shape[1] == 3, (
                    f'Network has been defined with 3 input channels, '
                    f'but loaded training images have {imgs.shape[1]} channels.'
                    f' Please check that the images are loaded correctly.')

                assert xyz_gt.shape[1] == 3, (
                    f'Network has been defined with 3 input channels, '
                    f'but loaded XYZ images have {xyz_gt.shape[1]} channels. '
                    f'Please check that the images are loaded correctly.')

                imgs = imgs.to(device=device, dtype=torch.float32)
                xyz_gt = xyz_gt.to(device=device, dtype=torch.float32)

                # Forward: reconstructed XYZ and re-rendered sRGB images;
                # the loss compares both against their respective targets.
                rec_imgs, rendered_imgs = net(imgs)
                loss = utls.compute_loss(imgs, xyz_gt, rec_imgs, rendered_imgs)

                epoch_loss += loss.item()

                if use_tb:
                    writer.add_scalar('Loss/train', loss.item(), global_step)

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Progress is tracked in images (total=len(train)).
                pbar.update(np.ceil(imgs.shape[0]))
                global_step += 1

        if (epoch + 1) % validationFrequency == 0:
            val_score = vald_net(net, val_loader, device)
            logging.info('Validation loss: {}'.format(val_score))
            if use_tb:
                writer.add_scalar('learning_rate',
                                  optimizer.param_groups[0]['lr'], global_step)
                writer.add_scalar('Loss/test', val_score, global_step)
                # NOTE: these image tensors are carried over from the LAST
                # training batch of the epoch, not from the validation set.
                writer.add_images('images', imgs, global_step)
                writer.add_images('rendered-imgs', rendered_imgs, global_step)
                writer.add_images('rec-xyz', rec_imgs, global_step)
                writer.add_images('gt-xyz', xyz_gt, global_step)

        scheduler.step()

        if save_cp and (epoch + 1) % chkpointperiod == 0:
            if not os.path.exists(dir_checkpoint):
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')

            torch.save(net.state_dict(), dir_checkpoint +
                       f'ciexyznet{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved!')

    # Final model is always saved regardless of save_cp.
    if not os.path.exists('models'):
        os.mkdir('models')
        logging.info('Created trained models directory')
    torch.save(net.state_dict(), 'models/' + 'model_sRGB-XYZ-sRGB.pth')
    logging.info('Saved trained model!')
    if use_tb:
        writer.close()
    logging.info('End of training')
                    print('Epoch:', '%04d' % (idx_epoch + 1), 'loss =', '{:.6f}'.format(loss))
                    loss_list.append(loss.item())

                    # Backpropage loss
                    loss.backward()

                    # Update weights
                    optimizer.step()

                if ((idx_epoch+1)%test_interval) == 0:
                    test()
                    if max(correct_list) > best_correct:
                        best_correct = max(correct_list)
                        best_model = d_model
                        best_dropout = dropout
                val_loss = compute_loss(net, dataloader_train, loss_function, device).item()

                if val_loss < val_loss_best:
                    val_loss_best = val_loss

                if pbar is not None:
                    pbar.update()



    print("The best d_model is ",best_model)
    print("The best dropout is",best_dropout)
    print("The best correct is",best_correct)


    # print('\r\n', loss_list)
Example #7
0
def train_relighting_one_to_any(net, device, epochs, batch_size, lr,
                                val_percent, lrdf, lrdp, chkpointperiod,
                                patchsz, validationFrequency, in_dir_img,
                                gt_dir_img, tr_dir_img, resizing, save_cp):
    dir_checkpoint = 'checkpoints_relighting/'
    dataset = DataLoading(in_dir_img,
                          gt_dir=gt_dir_img,
                          patch_size=patchsz,
                          target_dir=tr_dir_img,
                          resizing=resizing)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True)
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True)
    if use_tb:
        writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}')
    global_step = 0

    logging.info(f'''Starting training relighting net:
        Epochs:          {epochs} epochs
        Batch size:      {batch_size}
        Patch size:      {patchsz} x {patchsz}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Validation Frq.: {validationFrequency}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        TensorBoard:     {use_tb}
    ''')

    optimizer = optim.Adam(net.parameters(),
                           lr=lr,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.00001)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          lrdp,
                                          gamma=lrdf,
                                          last_epoch=-1)

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['input']
                gts = batch['gt']
                targets = batch['target']
                assert imgs.shape[1] == 3, \
                    f'Network has been defined with 3 input channels, ' \
                    f'but loaded training images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                assert gts.shape[1] == 3, \
                    f'Network has been defined with 3 input channels, ' \
                    f'but loaded GT images have {gts.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                assert targets.shape[1] == 3, \
                    f'Network has been defined with 3 input channels, ' \
                    f'but loaded guide images have {targets.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                gts = gts.to(device=device, dtype=torch.float32)
                targets = targets.to(device=device, dtype=torch.float32)

                results = net(imgs, t=targets)
                loss = utils.compute_loss(results, gts)
                epoch_loss += loss.item()
                if use_tb:
                    writer.add_scalar('Loss/train', loss.item(), global_step)
                pbar.set_postfix(**{'loss (batch)': loss.item()})
                optimizer.zero_grad()
                loss.backward()
                # nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()
                pbar.update(np.ceil(imgs.shape[0]))
                global_step += 1

        if (epoch + 1) % validationFrequency == 0:
            val_score = vald_net(net, val_loader, 'relighting_one_to_any',
                                 device)
            logging.info('Validation MAE: {}'.format(val_score))
            if use_tb:
                writer.add_scalar('learning_rate',
                                  optimizer.param_groups[0]['lr'], global_step)
                writer.add_scalar('Loss/test', val_score, global_step)
                writer.add_images('images', imgs, global_step)
                writer.add_images('result', results, global_step)
                writer.add_images('GT', gts, global_step)

        scheduler.step()

        if save_cp and (epoch + 1) % chkpointperiod == 0:
            if not os.path.exists(dir_checkpoint):
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')

            torch.save(
                net.state_dict(),
                dir_checkpoint + f'relighting_net_one_to_any_{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved!')

    if not os.path.exists('models'):
        os.mkdir('models')
        logging.info('Created trained models directory')
    torch.save(net.state_dict(), 'models/' + 'relighting_net_one_to_any.pth')
    logging.info('Saved trained model!')
    if use_tb:
        writer.close()
    logging.info('End of training relighting net')
    return net