Example #1
0
def train(train_loader, val_loader, class_weights):
    """Train ENet for 200 epochs, validating every 10 epochs.

    Args:
        train_loader: DataLoader yielding training batches.
        val_loader: DataLoader yielding validation batches.
        class_weights: per-class weight tensor for the cross-entropy loss.

    Returns:
        The trained model (CUDA), holding the last epoch's weights.
    """
    # ``num_classes`` is read from the enclosing module's globals.
    model = ENet(num_classes)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=2e-4)
    # Large dataset: decay the learning rate every 10 epochs.
    # NOTE(review): gamma=1e-7 all but zeroes the LR after the first decay —
    # confirm this is intended (a typical StepLR gamma is ~0.1).
    lr_updater = lr_scheduler.StepLR(optimizer, step_size=10, gamma=1e-7)
    ignore_index = None
    metric = IoU(num_classes, ignore_index=ignore_index)

    model = model.cuda()
    criterion = criterion.cuda()

    start_epoch = 0
    best_miou = 0
    trainer = Train(model,
                    train_loader,
                    optimizer,
                    criterion,
                    metric,
                    use_cuda=True)
    validator = Test(model, val_loader, criterion, metric, use_cuda=True)
    n_epochs = 200
    for epoch in range(start_epoch, n_epochs):
        print(">>>> [Epoch: {0:d}] Training".format(epoch))

        epoch_loss, (iou, miou) = trainer.run_epoch(iteration_loss=True)
        # PyTorch >= 1.1 convention: step the LR scheduler *after* the
        # optimizer has stepped for the epoch; stepping before it skips the
        # initial learning rate entirely.
        lr_updater.step()

        print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
              format(epoch, epoch_loss, miou))

        if (epoch + 1) % 10 == 0 or epoch + 1 == n_epochs:
            print(">>>> [Epoch: {0:d}] Validation".format(epoch))

            loss, (iou, miou) = validator.run_epoch(iteration_loss=True)

            print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
                  format(epoch, loss, miou))

            # Print per class IoU on last epoch or if best iou
            if epoch + 1 == n_epochs or miou > best_miou:
                for class_iou in iou:
                    print(class_iou)

            # Save the model if it's the best thus far
            if miou > best_miou:
                print("\nBest model thus far. Saving...\n")
                best_miou = miou
                torch.save(
                    model.state_dict(),
                    '/mnt/disks/data/d4dl/snapshots/snapshot_' + str(epoch) +
                    '.pt')
    return model
Example #2
0
def main_script(args):
    """Dispatch train/test/full runs for the dataset selected in ``args``."""

    # Fail fast when either required directory is missing.
    for directory in (args.dataset_dir, args.save_dir):
        assert os.path.isdir(
            directory), "The directory \"{0}\" doesn't exist.".format(
                directory)

    # Import the requested dataset class.
    dataset_name = args.dataset.lower()
    if dataset_name == 'camvid':
        from data import CamVid as dataset
    elif dataset_name == 'cityscapes':
        from data import Cityscapes as dataset
    else:
        # Should never happen...but just in case it does
        raise RuntimeError("\"{0}\" is not a supported dataset.".format(
            args.dataset))

    loaders, w_class, class_encoding = load_dataset(dataset, args.color_space,
                                                    args.hue_value)
    train_loader, val_loader, test_loader = loaders

    mode = args.mode.lower()
    if mode in ('train', 'full'):
        model = train(train_loader, val_loader, w_class, class_encoding)
        if mode == 'full':
            test(model, test_loader, w_class, class_encoding)
    elif mode == 'test':
        # Build a fresh ENet; its weights are restored from the checkpoint.
        num_classes = len(class_encoding)
        model = ENet(num_classes)
        if use_cuda:
            model = model.cuda()

        # An optimizer is required only so load_checkpoint can restore state.
        optimizer = optim.Adam(model.parameters())

        model = utils.load_checkpoint(model, optimizer, args.save_dir,
                                      args.name)[0]
        test(model, test_loader, w_class, class_encoding)
    else:
        # Should never happen...but just in case it does
        raise RuntimeError(
            "\"{0}\" is not a valid choice for execution mode.".format(
                args.mode))
Example #3
0
def predict():
    """Run single-image inference with a checkpointed ENet on Cityscapes and
    display the prediction blended over the input image."""
    image_transform = transforms.Compose(
        [transforms.Resize(target_size),
         transforms.ToTensor()])

    label_transform = transforms.Compose(
        [transforms.Resize(target_size),
         ext_transforms.PILToLongTensor()])

    # Load the test split only for its color encoding (class -> color map).
    train_set = Cityscapes(data_dir,
                           mode='test',
                           transform=image_transform,
                           label_transform=label_transform)

    class_encoding = train_set.color_encoding

    num_classes = len(class_encoding)
    model = ENet(num_classes).to(device)

    # An optimizer is required only so load_checkpoint can restore state.
    optimizer = optim.Adam(model.parameters())

    # Load the previously saved model state into the ENet model.
    model = utils.load_checkpoint(model, optimizer, 'save',
                                  'ENet_cityscapes_mine.pth')[0]
    # Inference mode: freezes batch-norm statistics and disables dropout.
    # (Without this the checkpointed model predicts in training mode.)
    model.eval()

    image = Image.open('images/mainz_000000_008001_leftImg8bit.png')
    images = image_transform(image).unsqueeze(0).to(device)
    image = np.array(image)

    # Make predictions without autograd bookkeeping.
    with torch.no_grad():
        predictions = model(images)
    _, predictions = torch.max(predictions, 1)
    # Shift labels so classes run 0..18 (drops the 'unlabeled' class).
    prediction = predictions.cpu().numpy()[0] - 1

    mask_color = np.asarray(label_to_color_image(prediction, 'cityscapes'),
                            dtype=np.uint8)
    mask_color = cv2.resize(mask_color, (image.shape[1], image.shape[0]))
    print(image.shape)
    print(mask_color.shape)
    res = cv2.addWeighted(image, 0.3, mask_color, 0.7, 0.6)
    cv2.imshow('combined', res)
    cv2.waitKey(0)
Example #4
0
File: main.py  Project: iamstg/vegans
def main():
    """Main function."""

    loaders, class_weights, class_encoding = load_dataset(dataset)
    train_loader, val_loader, test_loader = loaders

    num_classes = len(class_encoding)

    # Adversarial pair: ENet plays the generator, the discriminative net
    # the critic.
    critic = DiscriminativeNet()
    generator = ENet(num_classes)

    dataloader = load_real_data(real_dataset)

    # Both sides use Adam with identical hyper-parameters.
    adam_kwargs = {'lr': 0.0001, 'betas': (0.5, 0.999)}
    optimizer_D = optim.Adam(critic.parameters(), **adam_kwargs)
    optimizer_G = optim.Adam(generator.parameters(), **adam_kwargs)

    gan = WGANGP(generator,
                 critic,
                 dataloader,
                 train_loader,
                 test_loader,
                 class_weights,
                 class_encoding,
                 ngpu=ngpu,
                 device=device,
                 nr_epochs=500,
                 print_every=10,
                 save_every=400,
                 optimizer_D=optimizer_D,
                 optimizer_G=optimizer_G)

    gan.train()
    samples_l, D_losses, G_losses = gan.get_training_results()
Example #5
0
def train(train_loader, val_loader, class_weights, class_encoding):
    """Train ENet, validating every 10 epochs and checkpointing on best mIoU."""
    print("\nTraining...\n")

    num_classes = len(class_encoding)

    # Build the network and echo its architecture as a sanity check.
    model = ENet(num_classes).to(device)
    print(model)

    # Weighted cross-entropy (combines LogSoftmax and NLLLoss) — the usual
    # choice for multi-class segmentation.
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Adam, as used by the ENet authors.
    optimizer = optim.Adam(model.parameters(),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)

    # Step-wise learning-rate decay.
    lr_updater = lr_scheduler.StepLR(optimizer, args.lr_decay_epochs,
                                     args.lr_decay)

    # IoU metric, optionally skipping the 'unlabeled' class.
    ignore_index = (list(class_encoding).index('unlabeled')
                    if args.ignore_unlabeled else None)
    metric = IoU(num_classes, ignore_index=ignore_index)

    # Optionally resume from a checkpoint.
    start_epoch, best_miou = 0, 0
    if args.resume:
        model, optimizer, start_epoch, best_miou = utils.load_checkpoint(
            model, optimizer, args.save_dir, args.name)
        print("Resuming from model: Start epoch = {0} "
              "| Best mean IoU = {1:.4f}".format(start_epoch, best_miou))

    # Start training.
    print()
    trainer = Train(model, train_loader, optimizer, criterion, metric, device)
    validator = Test(model, val_loader, criterion, metric, device)
    for epoch in range(start_epoch, args.epochs):
        print(">>>> [Epoch: {0:d}] Training".format(epoch))

        epoch_loss, (iou, miou) = trainer.run_epoch(args.print_step)
        lr_updater.step()

        print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
              format(epoch, epoch_loss, miou))

        is_last_epoch = epoch + 1 == args.epochs
        if (epoch + 1) % 10 == 0 or is_last_epoch:
            print(">>>> [Epoch: {0:d}] Validation".format(epoch))

            loss, (iou, miou) = validator.run_epoch(args.print_step)

            print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
                  format(epoch, loss, miou))

            # Report per-class IoU on the last epoch or on a new best.
            if is_last_epoch or miou > best_miou:
                for name, class_iou in zip(class_encoding.keys(), iou):
                    print("{0}: {1:.4f}".format(name, class_iou))

            # Checkpoint whenever validation mIoU improves.
            if miou > best_miou:
                print("\nBest model thus far. Saving...\n")
                best_miou = miou
                utils.save_checkpoint(model, optimizer, epoch + 1, best_miou,
                                      args)

    return model
Example #6
0
        from data import Cityscapes as dataset
    else:
        # Should never happen...but just in case it does
        raise RuntimeError("\"{0}\" is not a supported dataset.".format(
            args.dataset))

    loaders, w_class, class_encoding = load_dataset(dataset)
    train_loader, val_loader, test_loader = loaders

    if args.mode.lower() in {'train', 'full'}:
        model = train(train_loader, val_loader, w_class, class_encoding)

    if args.mode.lower() in {'test', 'full'}:
        if args.mode.lower() == 'test':
            # Intialize a new ENet model
            num_classes = len(class_encoding)
            model = ENet(num_classes).to(device)

        # Initialize a optimizer just so we can retrieve the model from the
        # checkpoint
        optimizer = optim.Adam(model.parameters())

        # Load the previoulsy saved model state to the ENet model
        model = utils.load_checkpoint(model, optimizer, args.save_dir,
                                      args.name)[0]

        if args.mode.lower() == 'test':
            print(model)

        test(model, test_loader, w_class, class_encoding)
Example #7
0
class Trainer(object): 
    def __init__(self, exp):
        """Build datasets, loaders, model, optimizer and losses from
        ``config/tusimple_config.yaml``.

        Args:
            exp: experiment name; used for the TensorBoard run directory
                and for result paths in the train/val methods.
        """
        # IoU and pixel-accuracy calculator over the 7 lane classes.
        self.metric = SegmentationMetric(7)
        cfg_path = os.path.join(os.getcwd(), 'config/tusimple_config.yaml')
        self.exp_name = exp
        self.writer = SummaryWriter('tensorboard/' + self.exp_name)
        with open(cfg_path) as file:
            cfg = yaml.load(file, Loader=yaml.FullLoader)
        self.device = torch.device(cfg['DEVICE'])
        self.max_epochs = cfg['TRAIN']['MAX_EPOCHS']
        self.dataset_path = cfg['DATASET']['PATH']
        # TODO remove this and refactor PROPERLY
        self.input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg['DATASET']['MEAN'], cfg['DATASET']['STD']),
        ])

        mean = cfg['DATASET']['MEAN']
        std = cfg['DATASET']['STD']
        # Train-time augmentation: resize slightly larger than the working
        # size, random-crop down to it, then flip/rotation jitter.
        # Validation only resizes.
        self.train_transform = Compose(Resize(size=(645, 373)),
                                       RandomCrop(size=(640, 368)),
                                       RandomFlip(0.5), Rotation(2),
                                       ToTensor(), Normalize(mean=mean,
                                                             std=std))
        self.val_transform = Compose(Resize(size=(640, 368)), ToTensor(),
                                     Normalize(mean=mean, std=std))
        self.train_dataset = tuSimple(path=cfg['DATASET']['PATH'],
                                      image_set='train',
                                      transforms=self.train_transform)
        self.val_dataset = tuSimple(path=cfg['DATASET']['PATH'],
                                    image_set='val',
                                    transforms=self.val_transform)
        self.train_loader = data.DataLoader(
            dataset=self.train_dataset,
            batch_size=cfg['TRAIN']['BATCH_SIZE'],
            shuffle=True,
            num_workers=0,
            pin_memory=True,
            drop_last=True,
        )
        self.val_loader = data.DataLoader(
            dataset=self.val_dataset,
            batch_size=cfg['TRAIN']['BATCH_SIZE'],
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            drop_last=False,
        )
        self.iters_per_epoch = len(
            self.train_dataset) // cfg['TRAIN']['BATCH_SIZE']
        self.max_iters = cfg['TRAIN']['MAX_EPOCHS'] * self.iters_per_epoch
        # -------- network --------
        self.model = ENet(num_classes=7).to(self.device)
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=cfg['OPTIM']['LR'],
            weight_decay=cfg['OPTIM']['DECAY'],
            momentum=0.9,
        )
        self.lr_scheduler = get_scheduler(self.optimizer,
                                          max_iters=self.max_iters,
                                          iters_per_epoch=self.iters_per_epoch)
        # Background class is down-weighted to 0.4; the six lanes weigh 1.
        weight = [0.4, 1, 1, 1, 1, 1, 1]
        self.criterion = nn.CrossEntropyLoss(
            weight=torch.tensor(weight)).cuda()
        self.bce = nn.BCELoss().cuda()
    def train(self, epoch, start_time):
        """Run one training epoch.

        Args:
            epoch: zero-based epoch index.
            start_time: wall-clock start of the whole run (``time.time()``),
                used for the ETA estimate.

        Returns:
            Mean batch loss over the epoch.
        """
        running_loss = 0.0
        logging.info('Start training, Total Epochs: {:d}, Total Iterations: {:d}'.format(self.max_epochs, self.max_iters))
        print("Train Epoch: {}".format(epoch))
        self.model.train()
        epoch_loss = 0
        # Global iteration counter, resumed from the epoch offset.
        iteration = epoch * self.iters_per_epoch if epoch > 0 else 0
        for batch_idx, sample in enumerate(self.train_loader):
            iteration += 1
            img = sample['img'].to(self.device)
            segLabel = sample['segLabel'].to(self.device)
            exist = sample['exist'].to(self.device)
            # ``outputs``: per-pixel class logits (cross-entropy head);
            # ``sig``: per-lane existence probabilities (BCE head).
            outputs, sig = self.model(img)
            ce = self.criterion(outputs, segLabel)
            bce = self.bce(sig, exist)
            loss = ce + (0.1 * bce)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()

            epoch_loss += loss.item()
            running_loss += loss.item()
            eta_seconds = ((time.time() - start_time) / iteration) * (self.max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            if iteration % 10 == 0:
                logging.info(
                "Epoch: {:d}/{:d} || Iters: {:d}/{:d} || Lr: {:6f} || "
                "Loss: {:.4f} || Cost Time: {} || Estimated Time: {}".format(
                epoch, self.max_epochs, iteration % self.iters_per_epoch, self.iters_per_epoch,
                self.optimizer.param_groups[0]['lr'], loss.item(), str(datetime.timedelta(seconds=int(time.time() - start_time))), eta_string))
            # Log a smoothed loss to TensorBoard every 10 batches.
            if batch_idx % 10 == 9:
                self.writer.add_scalar('train loss',
                                running_loss / 10,
                                epoch * len(self.train_loader) + batch_idx + 1)
                running_loss = 0.0
        if epoch % 1 == 0:
            save_dict = {
                    "epoch": epoch,
                    "model": self.model.state_dict(),
                    "optim": self.optimizer.state_dict(),
                    # NOTE(review): ``best_val_loss`` is read from a
                    # module-level global maintained elsewhere — verify it is
                    # initialised before the first save.
                    "best_val_loss": best_val_loss,
                    }
            save_name = os.path.join(os.getcwd(), 'results', self.exp_name, 'run.pth')
            save_name_epoch = os.path.join(os.getcwd(), 'results', self.exp_name, '{}.pth'.format(epoch))
            torch.save(save_dict, save_name)
            torch.save(save_dict, save_name_epoch)
            print("Model is saved: {}".format(save_name))
            print("Model is saved: {}".format(save_name_epoch))
            print("+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*")
        return epoch_loss/len(self.train_loader)
    def val(self, epoch, train_loss):
        """Run validation, log metrics, and checkpoint on a new best mIoU.

        Args:
            epoch: zero-based epoch index.
            train_loss: mean training loss of this epoch (kept for interface
                compatibility; currently unused).
        """
        self.metric.reset()
        global best_val_loss
        global best_mIoU
        print("Val Epoch: {}".format(epoch))
        self.model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch_idx, sample in enumerate(self.val_loader):
                img = sample['img'].to(self.device)
                segLabel = sample['segLabel'].to(self.device)
                exist = sample['exist'].to(self.device)
                outputs, sig = self.model(img)
                ce = self.criterion(outputs, segLabel)
                bce = self.bce(sig, exist)
                loss = ce + (0.1 * bce)
                val_loss += loss.item()
                self.metric.update(outputs, segLabel)
                pixAcc, mIoU = self.metric.get()
                logging.info("Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    batch_idx + 1, pixAcc * 100, mIoU * 100))
        pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True)
        print(category_iou)
        logging.info('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format(
            pixAcc * 100, mIoU * 100))
        # Append the epoch summary to val_out.txt directly via print's
        # ``file=`` argument instead of temporarily rebinding sys.stdout —
        # same output, no reliance on a global ``original_stdout``.
        with open('val_out.txt', 'a') as out:
            print(self.exp_name, 'Epoch:', epoch, 'pixAcc: {:.3f}, mIoU: {:.3f}'.format(pixAcc*100, mIoU*100), file=out)
        print("Validation loss: {}".format(val_loss))
        print("+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*")
        if (mIoU * 100) > best_mIoU:
            best_mIoU = mIoU*100
            save_dict = {
                    "epoch": epoch,
                    "model": self.model.state_dict(),
                    "optim": self.optimizer.state_dict(),
                    "best_val_loss": best_val_loss,
                    "best_mIoU": best_mIoU,
                    }
            save_name = os.path.join(os.getcwd(), 'results', self.exp_name, 'best_mIoU.pth')
            torch.save(save_dict, save_name)
            print("mIoU is higher than best mIoU! Model saved to {}".format(save_name))
    
    def eval(self):
        print("+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*")
        print("Evaluating.. ") 
        self.model.eval() 
        val_loss = 0 
        dump_to_json = [] 
        test_dataset = tuSimple(
                path=self.dataset_path,
                image_set='test',
                transforms=self.val_transform
                ) 
        test_loader = data.DataLoader(
                dataset = test_dataset,
                batch_size = 12, 
                shuffle = False,
                num_workers = 0, 
                pin_memory = True,
                drop_last = False,
                ) 
        progressbar = tqdm(range(len(test_loader))) 
        with torch.no_grad():
            with open('exist_out.txt','w') as f:
                for batch_idx, sample in enumerate(test_loader): 
                    img = sample['img'].to(self.device) 
                    img_name = sample['img_name']
                    #segLabel = sample['segLabel'].to(self.device) 
                    outputs, sig = self.model(img) 
                    seg_pred = F.softmax(outputs, dim=1)
                    seg_pred = seg_pred.detach().cpu().numpy()
                    exist_pred = sig.detach().cpu().numpy()
                    count = 0

                    for img_idx in range(len(seg_pred)):
                        seg = seg_pred[img_idx]
                        exist = [1 if exist_pred[img_idx ,i] > 0.5 else 0 for i in range(6)]
                        lane_coords = getLane.prob2lines_tusimple(seg, exist, resize_shape=(720,1280), y_px_gap=10, pts=56)
                        for i in range(len(lane_coords)):
                            # sort lane coords
                            lane_coords[i] = sorted(lane_coords[i], key=lambda pair:pair[1])
                        
                        #print(len(lane_coords))
                    # Visualisation 
                        savename = "{}/{}_{}_vis.png".format(os.path.join(os.getcwd(), 'vis'), batch_idx, count) 
                        count += 1
                        raw_file_name = img_name[img_idx]
                        pred_json = {}
                        pred_json['lanes'] = []
                        pred_json['h_samples'] = []
                        # truncate everything before 'clips' to be consistent with test_label.json gt
                        pred_json['raw_file'] = raw_file_name[raw_file_name.find('clips'):]
                        pred_json['run_time'] = 0

                        for l in lane_coords:
                            empty = all(lane[0] == -2 for lane in l)
                            if len(l)==0:
                                continue
                            if empty:
                                continue
                            pred_json['lanes'].append([])
                            for (x,y) in l:
                                pred_json['lanes'][-1].append(int(x))
                        for (x, y) in lane_coords[0]:
                            pred_json['h_samples'].append(int(y))
                        dump_to_json.append(json.dumps(pred_json))
                    progressbar.update(1)
                progressbar.close() 

                with open(os.path.join(os.getcwd(), "results", self.exp_name, "pred_json.json"), "w") as f:
                    for line in dump_to_json:
                        print(line, end="\n", file=f)

                print("Saved pred_json.json to {}".format(os.path.join(os.getcwd(), 'results', self.exp_name, "pred_json.json")))
           
                '''
                        raw_img = img[b].cpu().detach().numpy()
                        raw_img = raw_img.transpose(1, 2, 0)
                        # Normalize both to 0..1
                        min_val, max_val = np.min(raw_img), np.max(raw_img)
                        raw_img = (raw_img - min_val) / (max_val - min_val)
                        #rgb = rgb / 255.
                        #stack = np.hstack((raw_img, rgb))
                        background = Image.fromarray(np.uint8(raw_img*255))
                        overlay = Image.fromarray(rgb)
                        new_img = Image.blend(background, overlay, 0.4)
                        new_img.save(savename, "PNG")
                '''
                        
                '''
Example #8
0
class Trainer(object):
    """Knowledge-distillation trainer for TuSimple lane detection.

    A student ENet (trained) learns from a frozen teacher ENet via KL and
    MSE terms on top of the supervised cross-entropy/BCE losses.

    NOTE(review): relies on module-level globals ``best_mIoU`` and
    ``best_val_loss`` being initialised elsewhere before use — verify.
    """

    def __init__(self, s_exp_name, t_exp_name):
        # Load hyper-parameters and paths from the YAML config, then build
        # datasets, loaders, the student/teacher models, optimizer,
        # scheduler and loss terms.
        #   s_exp_name: student experiment name (TensorBoard/results dirs).
        #   t_exp_name: teacher experiment name.
        cfg_path = os.path.join(os.getcwd(), 'config/tusimple_config.yaml')
        self.s_exp_name = s_exp_name
        self.t_exp_name = t_exp_name
        self.writer = SummaryWriter('tensorboard/' + self.s_exp_name)
        # IoU and pixel-accuracy calculator over the 7 lane classes.
        self.metric = SegmentationMetric(7)
        with open(cfg_path) as cfg:
            config = yaml.load(cfg, Loader=yaml.FullLoader)
        self.device = torch.device(config['DEVICE'])
        self.max_epochs = config['TRAIN']['MAX_EPOCHS']
        self.dataset_path = config['DATASET']['PATH']
        self.mean = config['DATASET']['MEAN']
        self.std = config['DATASET']['STD']
        '''
        self.input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
            ])
        '''
        # Train-time augmentation: resize slightly larger, random-crop to the
        # working size, then flip/rotation jitter. Validation only resizes.
        self.train_transform = Compose(Resize(size=(645, 373)),
                                       RandomCrop(size=(640, 368)),
                                       RandomFlip(0.5), Rotation(2),
                                       ToTensor(),
                                       Normalize(mean=self.mean, std=self.std))
        self.val_transform = Compose(Resize(size=(640, 368)), ToTensor(),
                                     Normalize(mean=self.mean, std=self.std))
        self.train_dataset = tuSimple(path=config['DATASET']['PATH'],
                                      image_set='train',
                                      transforms=self.train_transform)
        self.val_dataset = tuSimple(
            path=config['DATASET']['PATH'],
            image_set='val',
            transforms=self.val_transform,
        )
        self.train_loader = data.DataLoader(
            dataset=self.train_dataset,
            batch_size=config['TRAIN']['BATCH_SIZE'],
            shuffle=True,
            num_workers=0,
            pin_memory=True,
            drop_last=True,
        )
        self.val_loader = data.DataLoader(
            dataset=self.val_dataset,
            batch_size=config['TRAIN']['BATCH_SIZE'],
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            drop_last=False,
        )
        self.iters_per_epoch = len(
            self.train_dataset) // config['TRAIN']['BATCH_SIZE']
        self.max_iters = self.max_epochs * self.iters_per_epoch

        # ------------network------------
        # Student (s_model) is optimised; teacher (t_model) stays frozen.
        self.s_model = ENet(num_classes=7).to(self.device)
        self.t_model = ENet(num_classes=7).to(self.device)
        self.optimizer = optim.SGD(
            self.s_model.parameters(),
            lr=config['OPTIM']['LR'],
            weight_decay=config['OPTIM']['DECAY'],
            momentum=0.9,
        )
        self.lr_scheduler = get_scheduler(
            self.optimizer,
            max_iters=self.max_iters,
            iters_per_epoch=self.iters_per_epoch,
        )
        self.ce = nn.CrossEntropyLoss(weight=torch.tensor(
            [0.4, 1, 1, 1, 1, 1, 1])).cuda()  #background weight 0.4
        self.bce = nn.BCELoss().cuda()
        self.kl = nn.KLDivLoss().cuda()  #reduction='batchmean' gives NaN
        self.mse = nn.MSELoss().cuda()

    def train(self, epoch, start_time):
        # Run one distillation training epoch.
        #   epoch: zero-based epoch index.
        #   start_time: wall-clock start of the run, used for the ETA.
        running_loss = 0.0
        # NOTE(review): ``is_better``/``prev_loss``/``iter_idx`` below are
        # assigned but never used in this method.
        is_better = True
        prev_loss = float('inf')
        logging.info(
            'Start training, Total Epochs: {:d}, Total Iterations: {:d}'.
            format(self.max_epochs, self.max_iters))
        print("Train Epoch: {}".format(epoch))
        self.s_model.train()
        self.t_model.eval()
        epoch_loss = 0
        # Global iteration counter, resumed from the epoch offset.
        iteration = epoch * self.iters_per_epoch if epoch > 0 else 0
        start_time = start_time
        for batch_idx, sample in enumerate(self.train_loader):
            iteration += 1
            img = sample['img'].to(self.device)
            segLabel = sample['segLabel'].to(self.device)
            exist = sample['exist'].to(self.device)
            # Teacher forward pass only provides targets — no gradients.
            with torch.no_grad():
                t_outputs, t_sig = self.t_model(img)
            s_outputs, s_sig = self.s_model(img)
            # Supervised terms (ground truth) ...
            ce = self.ce(s_outputs, segLabel)
            bce = self.bce(s_sig, exist)
            # ... plus distillation terms matching student to teacher logits.
            kl = self.kl(
                F.log_softmax(s_outputs, dim=1),
                F.softmax(t_outputs, dim=1),
            )
            mse = self.mse(s_outputs, t_outputs)  #/ s_outputs.size(0)
            loss = ce + (0.1 * bce) + kl + (0.5 * mse)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()
            epoch_loss += loss.item()
            running_loss += loss.item()
            eta_seconds = ((time.time() - start_time) /
                           iteration) * (self.max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            iter_idx = epoch * len(self.train_loader) + batch_idx
            if iteration % 10 == 0:
                logging.info(
                    "Epoch: {:d}/{:d} || Iters: {:d}/{:d} || Lr: {:6f} || Loss: {:.4f} || Cost Time: {} || Estimated Time: {}"
                    .format(
                        epoch,
                        self.max_epochs,
                        iteration % self.iters_per_epoch,
                        self.iters_per_epoch,
                        self.optimizer.param_groups[0]['lr'],
                        loss.item(),
                        str(
                            datetime.timedelta(seconds=int(time.time() -
                                                           start_time))),
                        eta_string,
                    ))
            # Log a smoothed loss to TensorBoard every 10 batches.
            if batch_idx % 10 == 9:
                self.writer.add_scalar(
                    'train_loss', running_loss / 10,
                    epoch * len(self.train_loader) + batch_idx + 1)
                running_loss = 0.0
        # Checkpoint the student every epoch.
        if epoch % 1 == 0:
            save_dict = {
                "epoch": epoch,
                "model": self.s_model.state_dict(),
                "optim": self.optimizer.state_dict(),
                "best_mIoU": best_mIoU,
                "best_val_loss": best_val_loss,
            }
            save_name = os.path.join(os.getcwd(), 'results', self.s_exp_name,
                                     'run.pth')
            torch.save(save_dict, save_name)
            print("Model is saved: {}".format(save_name))

    def val(self, epoch):
        # Validate the student model; checkpoint when mIoU improves.
        #   epoch: zero-based epoch index.
        self.metric.reset()
        global best_val_loss
        global best_mIoU
        print("Val Epoch: {}".format(epoch))
        self.s_model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch_idx, sample in enumerate(self.val_loader):
                img = sample['img'].to(self.device)
                segLabel = sample['segLabel'].to(self.device)
                exist = sample['exist'].to(self.device)
                outputs, sig = self.s_model(img)
                ce = self.ce(outputs, segLabel)
                bce = self.bce(sig, exist)
                loss = ce + (0.1 * bce)
                self.metric.update(outputs, segLabel)
                pixAcc, mIoU = self.metric.get()
                logging.info(
                    "Sample: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(
                        batch_idx + 1, pixAcc * 100, mIoU * 100))

        pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True)
        print(category_iou)
        logging.info("Final pixAcc: {:.3f}, mIoU: {:.3f}".format(
            pixAcc * 100,
            mIoU * 100,
        ))
        iter_idx = (epoch + 1) * len(self.train_loader)
        if (mIoU * 100) > best_mIoU:
            best_mIoU = mIoU * 100
            save_dict = {
                "epoch": epoch,
                "model": self.s_model.state_dict(),
                "optim": self.optimizer.state_dict(),
                "best_val_loss": best_val_loss,
                "best_mIoU": best_mIoU,
            }
            save_name = os.path.join(os.getcwd(), 'results', self.s_exp_name,
                                     'best_mIoU.pth')
            torch.save(save_dict, save_name)
            print("mIoU is higher than best mIoU! Model saved to {}".format(
                save_name))
Example #9
0
def train(train_loader, val_loader, class_weights, class_encoding):
    """Train ENet for semantic segmentation, validating every epoch.

    Args:
        train_loader: DataLoader yielding training batches.
        val_loader: DataLoader yielding validation batches.
        class_weights: per-class weight tensor for CrossEntropyLoss.
        class_encoding: ordered mapping of class name -> color; its length
            defines the number of classes the network predicts.

    Returns:
        The trained model (best checkpoints are persisted via
        utils.save_checkpoint).
    """
    print("\nTraining...\n")
    # Visdom window handle: created on the first epoch, appended to afterwards.
    vis_win = None

    num_classes = len(class_encoding)

    # Intialize ENet; when several GPUs are visible, wrap the model in
    # DataParallel so each batch is split across all of them.
    model = ENet(num_classes).to(device)
    if torch.cuda.device_count() > 1:
        print(">>>Use mult GPU for trainning>>>")
        gpu_list = list(range(torch.cuda.device_count()))
        model = nn.DataParallel(model, device_ids=gpu_list)
    # Check if the network architecture is correct
    print(model)

    # CrossEntropyLoss combines LogSoftMax and NLLLoss, which fits this
    # multi-class classification problem.
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # ENet authors used Adam as the optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)

    # Learning rate decay scheduler
    lr_updater = lr_scheduler.StepLR(optimizer, args.lr_decay_epochs,
                                     args.lr_decay)

    # Evaluation metric; optionally exclude the 'unlabeled' class from IoU.
    if args.ignore_unlabeled:
        ignore_index = list(class_encoding).index('unlabeled')
    else:
        ignore_index = None
    metric = IoU(num_classes, ignore_index=ignore_index)

    # Optionally resume from a checkpoint
    if args.resume:
        model, optimizer, start_epoch, best_miou = utils.load_checkpoint(
            model, optimizer, args.save_dir, args.name)
        print("Resuming from model: Start epoch = {0} "
              "| Best mean IoU = {1:.4f}".format(start_epoch, best_miou))
    else:
        start_epoch = 0
        best_miou = 0

    # Start Training
    print()
    train = Train(model, train_loader, optimizer, criterion, metric, device)
    val = Test(model, val_loader, criterion, metric, device)
    for epoch in range(start_epoch, args.epochs):
        print(">>>> [Epoch: {0:d}] Training".format(epoch))

        epoch_loss, (iou, miou) = train.run_epoch(args.print_step)
        # BUG FIX: step the scheduler *after* the epoch's optimizer steps.
        # Stepping it before training skipped the initial learning rate
        # (PyTorch >= 1.1 scheduler contract).
        lr_updater.step()

        print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
              format(epoch, epoch_loss, miou))

        print(">>>> [Epoch: {0:d}] Validation".format(epoch))

        loss, (iou, miou) = val.run_epoch(args.print_step)

        print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
              format(epoch, loss, miou))

        if (epoch + 1) % 10 == 0 or epoch + 1 == args.epochs:
            # Print per class IoU on last epoch or if best iou
            if epoch + 1 == args.epochs or miou > best_miou:
                for key, class_iou in zip(class_encoding.keys(), iou):
                    print("{0}: {1:.4f}".format(key, class_iou))

            # Save the model if it's the best thus far
            if miou > best_miou:
                print("\nBest model thus far. Saving...\n")
                best_miou = miou
                utils.save_checkpoint(model, optimizer, epoch + 1, best_miou,
                                      args)

        # Plot training/eval loss in visdom. The window handle must stay the
        # same across epochs so later points append to the existing curves.
        if vis_win is None:
            vis_win = viz.line(X=np.column_stack(
                (np.array(epoch), np.array(epoch))),
                               Y=np.column_stack(
                                   (np.array(epoch_loss), np.array(loss))),
                               opts=dict(legend=['training loss', 'eval loss'],
                                         title='loss'))
        else:
            viz.line(
                X=np.column_stack((np.array(epoch), np.array(epoch))),
                Y=np.column_stack((np.array(epoch_loss), np.array(loss))),
                win=vis_win,
                update='append')

    return model
示例#10
0
class Trainer(object):
    def __init__(self, exp, exp2):
        """Build datasets, loaders, two ENet models, optimizer and losses.

        Args:
            exp: experiment name for the primary model (tensorboard dir and
                checkpoint dir under results/<exp>).
            exp2: experiment name for the second model, used only at eval time.
        """
        cfg_path = os.path.join(os.getcwd(), 'config/tusimple_config.yaml')
        self.exp_name = exp
        self.exp_name2 = exp2

        self.writer = SummaryWriter('tensorboard/' + self.exp_name)
        with open(cfg_path) as file:
            cfg = yaml.load(file, Loader=yaml.FullLoader)
        self.device = torch.device(cfg['DEVICE'])
        self.max_epochs = cfg['TRAIN']['MAX_EPOCHS']
        self.dataset_path = cfg['DATASET']['PATH']
        # TODO remove this and refactor PROPERLY
        self.input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg['DATASET']['MEAN'],
                                 cfg['DATASET']['STD']),
        ])

        mean = cfg['DATASET']['MEAN']
        std = cfg['DATASET']['STD']
        # Train-time augmentation: over-resize slightly, random-crop to the
        # working size, random horizontal flip and a small rotation.
        self.train_transform = Compose(Resize(size=(645, 373)),
                                       RandomCrop(size=(640, 368)),
                                       RandomFlip(0.5), Rotation(2),
                                       ToTensor(), Normalize(mean=mean,
                                                             std=std))
        # Validation uses deterministic resize + normalization only.
        self.val_transform = Compose(Resize(size=(640, 368)), ToTensor(),
                                     Normalize(mean=mean, std=std))
        self.train_dataset = tuSimple(path=cfg['DATASET']['PATH'],
                                      image_set='train',
                                      transforms=self.train_transform)
        self.val_dataset = tuSimple(
            path=cfg['DATASET']['PATH'],
            image_set='val',
            transforms=self.val_transform,
        )
        self.train_loader = data.DataLoader(
            dataset=self.train_dataset,
            batch_size=cfg['TRAIN']['BATCH_SIZE'],
            shuffle=True,
            num_workers=0,
            pin_memory=True,
            drop_last=True,
        )
        self.val_loader = data.DataLoader(
            dataset=self.val_dataset,
            batch_size=cfg['TRAIN']['BATCH_SIZE'],
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            drop_last=False,
        )
        # -------- network --------
        # Two identical ENets (7 classes); the second one is only exercised
        # in eval() for prediction averaging.
        self.model = ENet(num_classes=7).to(self.device)
        self.model2 = ENet(num_classes=7).to(self.device)
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=cfg['OPTIM']['LR'],
            weight_decay=cfg['OPTIM']['DECAY'],
            momentum=0.9,
        )
        # Background (class 0) is down-weighted to 0.4 versus 1.0 per lane.
        self.criterion = nn.CrossEntropyLoss(
            weight=torch.tensor([0.4, 1, 1, 1, 1, 1, 1])).cuda()
        self.bce = nn.BCELoss().cuda()

    def train(self, epoch):
        """Run one training epoch and checkpoint to results/<exp>/run.pth.

        Args:
            epoch (int): current epoch index.

        Returns:
            float: mean training loss over the epoch.
        """
        running_loss = 0.0
        print("Train Epoch: {}".format(epoch))
        self.model.train()
        epoch_loss = 0
        progressbar = tqdm(range(len(self.train_loader)))
        for batch_idx, sample in enumerate(self.train_loader):
            img = sample['img'].to(self.device)
            segLabel = sample['segLabel'].to(self.device)
            exist = sample['exist'].to(self.device)
            # outputs -> per-pixel class logits (cross-entropy branch),
            # sig -> per-lane existence probabilities (BCE branch).
            outputs, sig = self.model(img)
            ce = self.criterion(outputs, segLabel)
            bce = self.bce(sig, exist)
            loss = ce + (0.1 * bce)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            epoch_loss += loss.item()
            running_loss += loss.item()
            progressbar.set_description("Batch loss: {:.3f}".format(
                loss.item()))
            progressbar.update(1)
            # Tensorboard: log the mean loss of every 10 batches.
            if batch_idx % 10 == 9:
                self.writer.add_scalar(
                    'train loss', running_loss / 10,
                    epoch * len(self.train_loader) + batch_idx + 1)
                running_loss = 0.0
        progressbar.close()
        # Checkpoint every epoch; val() promotes run.pth to run_best.pth when
        # the validation loss improves.
        save_dict = {
            "epoch": epoch,
            "model": self.model.state_dict(),
            "optim": self.optimizer.state_dict(),
            "best_val_loss": best_val_loss,
        }
        os.makedirs(os.path.join(os.getcwd(), 'results', self.exp_name),
                    exist_ok=True)
        save_name = os.path.join(os.getcwd(), 'results', self.exp_name,
                                 'run.pth')
        torch.save(save_dict, save_name)
        print("Model is saved: {}".format(save_name))
        print("+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*")
        return epoch_loss / len(self.train_loader)

    def val(self, epoch, train_loss):
        """Run one validation epoch; promote run.pth to run_best.pth when the
        summed validation loss beats the global best.

        Args:
            epoch (int): current epoch index.
            train_loss (float): mean training loss of this epoch, used to log
                the train-vs-val gap to tensorboard.
        """
        global best_val_loss
        print("Val Epoch: {}".format(epoch))
        self.model.eval()
        val_loss = 0
        progressbar = tqdm(range(len(self.val_loader)))
        with torch.no_grad():
            for batch_idx, sample in enumerate(self.val_loader):
                img = sample['img'].to(self.device)
                segLabel = sample['segLabel'].to(self.device)
                exist = sample['exist'].to(self.device)
                outputs, sig = self.model(img)
                ce = self.criterion(outputs, segLabel)
                bce = self.bce(sig, exist)
                # Same weighting as train(): CE + 0.1 * existence BCE.
                loss = ce + (0.1 * bce)
                val_loss += loss.item()
                progressbar.set_description("Batch loss: {:3f}".format(
                    loss.item()))
                progressbar.update(1)
                # Tensorboard: log the train/val gap once, on the last batch.
                if batch_idx + 1 == len(self.val_loader):
                    self.writer.add_scalar(
                        'train - val loss',
                        train_loss - (val_loss / len(self.val_loader)), epoch)
        progressbar.close()
        print("Validation loss: {}".format(val_loss))
        print("+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Copy the latest checkpoint written by train() as the new best.
            save_name = os.path.join(os.getcwd(), 'results', self.exp_name,
                                     'run.pth')
            copy_name = os.path.join(os.getcwd(), 'results', self.exp_name,
                                     'run_best.pth')
            print("val loss is lower than best val loss! Model saved to {}".
                  format(copy_name))
            shutil.copyfile(save_name, copy_name)

    def eval(self):
        print("+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*")
        print("Evaluating.. ")
        self.model.eval()
        self.model2.eval()
        val_loss = 0
        dump_to_json = []
        test_dataset = tuSimple(path=self.dataset_path,
                                image_set='test',
                                transforms=self.val_transform)
        test_loader = data.DataLoader(
            dataset=test_dataset,
            batch_size=12,
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            drop_last=False,
        )
        progressbar = tqdm(range(len(test_loader)))
        with torch.no_grad():
            with open('exist_out.txt', 'w') as f:
                for batch_idx, sample in enumerate(test_loader):
                    img = sample['img'].to(self.device)
                    img_name = sample['img_name']
                    #segLabel = sample['segLabel'].to(self.device)
                    outputs, sig = self.model(img)
                    outputs2, sig2 = self.model2(img)
                    #added_sig = sig2.add(sig)
                    #div_sig = torch.div(added_sig, 2.0)
                    #added_out = outputs.add(outputs2)
                    #div_out = torch.div(added_out, 2.0)
                    seg_pred1 = F.softmax(outputs, dim=1)
                    seg_pred2 = F.softmax(outputs2, dim=1)
                    seg_pred = seg_pred1.add(seg_pred2)
                    seg_pred = torch.div(seg_pred, 2.0)
                    seg_pred = seg_pred.detach().cpu().numpy()
                    sig_pred = sig.add(sig2)
                    exist_pred = sig_pred.detach().cpu().numpy()
                    count = 0

                    for img_idx in range(len(seg_pred)):
                        seg = seg_pred[img_idx]
                        exist = [
                            1 if exist_pred[img_idx, i] > 0.8 else 0
                            for i in range(6)
                        ]
                        lane_coords = getLane.prob2lines_tusimple(
                            seg,
                            exist,
                            resize_shape=(720, 1280),
                            y_px_gap=10,
                            pts=56)
                        for i in range(len(lane_coords)):
                            # sort lane coords
                            lane_coords[i] = sorted(lane_coords[i],
                                                    key=lambda pair: pair[1])

                        #print(len(lane_coords))
                    # Visualisation
                        savename = "{}/{}_{}_vis.png".format(
                            os.path.join(os.getcwd(), 'vis'), batch_idx, count)
                        count += 1
                        raw_file_name = img_name[img_idx]
                        pred_json = {}
                        pred_json['lanes'] = []
                        pred_json['h_samples'] = []
                        # truncate everything before 'clips' to be consistent with test_label.json gt
                        pred_json['raw_file'] = raw_file_name[raw_file_name.
                                                              find('clips'):]
                        pred_json['run_time'] = 0

                        for l in lane_coords:
                            empty = all(lane[0] == -2 for lane in l)
                            if len(l) == 0:
                                continue
                            if empty:
                                continue
                            pred_json['lanes'].append([])
                            for (x, y) in l:
                                pred_json['lanes'][-1].append(int(x))
                        for (x, y) in lane_coords[0]:
                            pred_json['h_samples'].append(int(y))
                        dump_to_json.append(json.dumps(pred_json))
                    progressbar.update(1)
                progressbar.close()

                with open(
                        os.path.join(os.getcwd(), "results", self.exp_name,
                                     "pred_json.json"), "w") as f:
                    for line in dump_to_json:
                        print(line, end="\n", file=f)

                print("Saved pred_json.json to {}".format(
                    os.path.join(os.getcwd(), "results", self.exp_name,
                                 "pred_json.json")))
                '''
                        raw_img = img[b].cpu().detach().numpy()
                        raw_img = raw_img.transpose(1, 2, 0)
                        # Normalize both to 0..1
                        min_val, max_val = np.min(raw_img), np.max(raw_img)
                        raw_img = (raw_img - min_val) / (max_val - min_val)
                        #rgb = rgb / 255.
                        #stack = np.hstack((raw_img, rgb))
                        background = Image.fromarray(np.uint8(raw_img*255))
                        overlay = Image.fromarray(rgb)
                        new_img = Image.blend(background, overlay, 0.4)
                        new_img.save(savename, "PNG")
                '''
                '''
def train(train_loader, val_loader, class_weights, class_encoding):
    print("\nTraining...\n")

    num_classes = len(class_encoding)

    # Intialize ENet
    model = ENet(num_classes).to(device)
    # Check if the network architecture is correct
    print(model)

    # We are going to use the CrossEntropyLoss loss function as it's most
    # frequentely used in classification problems with multiple classes which
    # fits the problem. This criterion  combines LogSoftMax and NLLLoss.
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # ENet authors used Adam as the optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)

    # Learning rate decay scheduler
    lr_updater = lr_scheduler.StepLR(optimizer, args.lr_decay_epochs,
                                     args.lr_decay)

    # Evaluation metric
    if args.ignore_unlabeled:
        ignore_index = list(class_encoding).index('unlabeled')
    else:
        ignore_index = None
    metric = IoU(num_classes, ignore_index=ignore_index)

    # Optionally resume from a checkpoint
    if args.resume:
        model, optimizer, start_epoch, best_miou = utils.load_checkpoint(
            model, optimizer, args.save_dir, args.name)
        print("Resuming from model: Start epoch = {0} "
              "| Best mean IoU = {1:.4f}".format(start_epoch, best_miou))
    else:
        start_epoch = 0
        best_miou = 0

    # Start Training
    print()
    train = Train(model, train_loader, optimizer, criterion, metric, device)
    val = Test(model, val_loader, criterion, metric, device)
    for epoch in range(start_epoch, args.epochs):
        print(">>>> [Epoch: {0:d}] Training".format(epoch))

        lr_updater.step()
        epoch_loss, (iou, miou) = train.run_epoch(args.print_step)

        # Visualization by TensorBoardX
        writer.add_scalar('data/train/loss', epoch_loss, epoch)
        writer.add_scalar('data/train/mean_IoU', miou, epoch)

        print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
              format(epoch, epoch_loss, miou))

        if (epoch + 1) % 1 == 0 or epoch + 1 == args.epochs:
            print(">>>> [Epoch: {0:d}] Validation".format(epoch))

            loss, (iou, miou) = val.run_epoch(args.print_step)

            # Visualization by TensorBoardX
            writer.add_scalar('data/val/loss', loss, epoch)
            writer.add_scalar('data/val/mean_IoU', miou, epoch)

            print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
                  format(epoch, loss, miou))

            # Print per class IoU on last epoch or if best iou
            if epoch + 1 == args.epochs or miou > best_miou:
                for key, class_iou in zip(class_encoding.keys(), iou):
                    print("{0}: {1:.4f}".format(key, class_iou))

            # Save the model if it's the best thus far
            if miou > best_miou:
                print("\nBest model thus far. Saving...\n")
                best_miou = miou
                utils.save_checkpoint(model, optimizer, epoch + 1, best_miou,
                                      args)

            # Visualization of the predicted batch in TensorBoard
            for i, batch in enumerate(val_loader):
                if i == 1:
                    break

                # Get the inputs and labels
                inputs = batch[0].to(device)
                labels = batch[1].to(device)

                # Forward propagation
                with torch.no_grad():
                    predictions = model(inputs)

                # Predictions is one-hot encoded with "num_classes" channels.
                # Convert it to a single int using the indices where the maximum (1) occurs
                _, predictions = torch.max(predictions.data, 1)

                label_to_rgb = transforms.Compose([
                    ext_transforms.LongTensorToRGBPIL(class_encoding),
                    transforms.ToTensor()
                ])
                color_predictions = utils.batch_transform(
                    predictions.cpu(), label_to_rgb)

                in_training_visualization(model, inputs, labels,
                                          class_encoding, writer, epoch, 'val')

    return model