Example #1
def train(train_loader, val_loader, class_weights):
    model = ENet(num_classes)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=2e-4)
    lr_updater = lr_scheduler.StepLR(
        optimizer, step_size=10, gamma=1e-7)  # large dataset: decay the LR every 10 epochs
    ignore_index = None
    metric = IoU(num_classes, ignore_index=ignore_index)

    model = model.cuda()
    criterion = criterion.cuda()

    # model, optimizer, start_epoch, best_miou = utils.load_checkpoint(
    #        model, optimizer, args.save_dir, args.name)
    # print("Resuming from model: Start epoch = {0} "
    #       "| Best mean IoU = {1:.4f}".format(start_epoch, best_miou))
    start_epoch = 0
    best_miou = 0
    train = Train(model,
                  train_loader,
                  optimizer,
                  criterion,
                  metric,
                  use_cuda=True)
    val = Test(model, val_loader, criterion, metric, use_cuda=True)
    n_epochs = 200
    for epoch in range(start_epoch, n_epochs):
        print(">>>> [Epoch: {0:d}] Training".format(epoch))

        epoch_loss, (iou, miou) = train.run_epoch(iteration_loss=True)
        lr_updater.step()  # step the scheduler after the epoch's optimizer updates (PyTorch >= 1.1)

        print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
              format(epoch, epoch_loss, miou))

        if (epoch + 1) % 10 == 0 or epoch + 1 == n_epochs:
            print(">>>> [Epoch: {0:d}] Validation".format(epoch))

            loss, (iou, miou) = val.run_epoch(iteration_loss=True)

            print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
                  format(epoch, loss, miou))

            # Print per class IoU on last epoch or if best iou
            if epoch + 1 == n_epochs or miou > best_miou:
                for class_iou in iou:
                    print(class_iou)

            # Save the model if it's the best thus far
            if miou > best_miou:
                print("\nBest model thus far. Saving...\n")
                best_miou = miou
                torch.save(
                    model.state_dict(),
                    '/mnt/disks/data/d4dl/snapshots/snapshot_' + str(epoch) +
                    '.pt')
    return model
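
A minimal sketch of how this train() function might be invoked, assuming hypothetical build_loaders() and estimate_class_weights() helpers (neither appears in the snippet above):

import torch

# Hypothetical driver; build_loaders() and estimate_class_weights() are
# placeholders standing in for the project's own data pipeline.
def main():
    train_loader, val_loader = build_loaders(batch_size=8)
    weights = estimate_class_weights(train_loader, num_classes)  # e.g. median-frequency balancing
    class_weights = torch.as_tensor(weights, dtype=torch.float)
    model = train(train_loader, val_loader, class_weights)

if __name__ == '__main__':
    main()
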
Example #2
class Trainer(object): 
    def __init__(self, exp): 
        # IoU and pixAcc Metric calculator
        self.metric = SegmentationMetric(7)
        cfg_path = os.path.join(os.getcwd(), 'config/tusimple_config.yaml') 
        self.exp_name = exp
        self.writer = SummaryWriter('tensorboard/' + self.exp_name)
        with open(cfg_path) as file: 
            cfg = yaml.load(file, Loader=yaml.FullLoader)
        self.device = torch.device(cfg['DEVICE'])
        self.max_epochs = cfg['TRAIN']['MAX_EPOCHS']
        self.dataset_path = cfg['DATASET']['PATH']
        # TODO remove this and refactor PROPERLY
        self.input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg['DATASET']['MEAN'], cfg['DATASET']['STD']),
        ])

        mean = cfg['DATASET']['MEAN']
        std = cfg['DATASET']['STD']
        self.train_transform = Compose(Resize(size=(645, 373)),
                                       RandomCrop(size=(640, 368)),
                                       RandomFlip(0.5), Rotation(2),
                                       ToTensor(),
                                       Normalize(mean=mean, std=std))

        self.val_transform = Compose(Resize(size=(640, 368)), ToTensor(),
                                     Normalize(mean=mean, std=std))
        data_kwargs = {
            'transform': self.input_transform, 
            'size': cfg['DATASET']['SIZE'],
        } 
        self.train_dataset = tuSimple(
                path=cfg['DATASET']['PATH'],
                image_set='train',
                transforms=self.train_transform,
                )
        self.val_dataset = tuSimple(
                path=cfg['DATASET']['PATH'],
                image_set='val',
                transforms=self.val_transform,
                )
        self.train_loader = data.DataLoader(
                dataset=self.train_dataset,
                batch_size=cfg['TRAIN']['BATCH_SIZE'],
                shuffle=True,
                num_workers=0,
                pin_memory=True,
                drop_last=True,
                )
        self.val_loader = data.DataLoader(
                dataset=self.val_dataset,
                batch_size=cfg['TRAIN']['BATCH_SIZE'],
                shuffle=False,
                num_workers=0,
                pin_memory=True,
                drop_last=False,
                )
        self.iters_per_epoch = len(self.train_dataset) // cfg['TRAIN']['BATCH_SIZE']
        self.max_iters = cfg['TRAIN']['MAX_EPOCHS'] * self.iters_per_epoch
        # -------- network --------
        weight = [0.4, 1, 1, 1, 1, 1, 1]  # down-weight the dominant background class
        self.model = ENet(num_classes=7).to(self.device) 
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=cfg['OPTIM']['LR'],
            weight_decay=cfg['OPTIM']['DECAY'],
            momentum=0.9,
        )
        self.lr_scheduler = get_scheduler(
            self.optimizer,
            max_iters=self.max_iters,
            iters_per_epoch=self.iters_per_epoch,
        )
        #self.optimizer = optim.Adam(
        #    self.model.parameters(),
        #    lr = cfg['OPTIM']['LR'],
        #    weight_decay=0,
        #    )
        self.criterion = nn.CrossEntropyLoss(weight=torch.tensor(weight)).cuda()
        self.bce = nn.BCELoss().cuda()
    def train(self, epoch, start_time):
        running_loss = 0.0
        logging.info(
            'Start training, Total Epochs: {:d}, Total Iterations: {:d}'.format(
                self.max_epochs, self.max_iters))
        print("Train Epoch: {}".format(epoch))
        self.model.train() 
        epoch_loss = 0
        #progressbar = tqdm(range(len(self.train_loader)))
        iteration = epoch * self.iters_per_epoch
        for batch_idx, sample in enumerate(self.train_loader): 
            iteration += 1
            img = sample['img'].to(self.device) 
            segLabel = sample['segLabel'].to(self.device) 
            exist = sample['exist'].to(self.device)
            # outputs feeds the weighted cross-entropy loss; sig feeds the
            # binary cross-entropy loss on the lane-existence predictions
            outputs, sig = self.model(img)
            ce = self.criterion(outputs, segLabel)
            bce = self.bce(sig, exist)
            loss = ce + (0.1 * bce)

            self.optimizer.zero_grad() 
            loss.backward() 
            self.optimizer.step()
            self.lr_scheduler.step()
            #print("LR", self.optimizer.param_groups[0]['lr'])

            epoch_loss += loss.item() 
            running_loss += loss.item() 
            eta_seconds = ((time.time() - start_time) / iteration) * (self.max_iters - iteration) 
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            iter_idx = epoch * len(self.train_loader) + batch_idx
            #progressbar.set_description("Batch loss: {:.3f}".format(loss.item()))
            #progressbar.update(1)
            # Tensorboard
            if iteration % 10 == 0:
                logging.info(
                    "Epoch: {:d}/{:d} || Iters: {:d}/{:d} || Lr: {:6f} || "
                    "Loss: {:.4f} || Cost Time: {} || Estimated Time: {}".format(
                        epoch, self.max_epochs,
                        iteration % self.iters_per_epoch, self.iters_per_epoch,
                        self.optimizer.param_groups[0]['lr'], loss.item(),
                        str(datetime.timedelta(seconds=int(time.time() - start_time))),
                        eta_string))
            if batch_idx % 10 == 9: 
                self.writer.add_scalar('train loss',
                                running_loss / 10,
                                epoch * len(self.train_loader) + batch_idx + 1)
                running_loss = 0.0
        #progressbar.close() 
        # epoch % 1 == 0 is always true: a checkpoint is written every epoch
        if epoch % 1 == 0:
            save_dict = {
                    "epoch": epoch,
                    "model": self.model.state_dict(),
                    "optim": self.optimizer.state_dict(),
                    "best_val_loss": best_val_loss,
                    }
            save_name = os.path.join(os.getcwd(), 'results', self.exp_name, 'run.pth')
            save_name_epoch = os.path.join(os.getcwd(), 'results', self.exp_name, '{}.pth'.format(epoch))
            torch.save(save_dict, save_name) 
            torch.save(save_dict, save_name_epoch) 
            print("Model is saved: {}".format(save_name))
            print("Model is saved: {}".format(save_name_epoch))
            print("+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*")
        return epoch_loss / len(self.train_loader)
    def val(self, epoch, train_loss):
        self.metric.reset()
        global best_val_loss
        global best_mIoU
        print("Val Epoch: {}".format(epoch))
        self.model.eval()
        val_loss = 0 
        #progressbar = tqdm(range(len(self.val_loader)))
        with torch.no_grad(): 
            for batch_idx, sample in enumerate(self.val_loader):
                img = sample['img'].to(self.device) 
                segLabel = sample['segLabel'].to(self.device) 
                exist = sample['exist'].to(self.device)
                outputs, sig = self.model(img) 
                ce = self.criterion(outputs, segLabel)
                bce = self.bce(sig, exist)
                loss = ce + (0.1 * bce)
                val_loss += loss.item() 
                self.metric.update(outputs, segLabel)
                pixAcc, mIoU = self.metric.get()
                logging.info("Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    batch_idx + 1, pixAcc * 100, mIoU * 100))
                #progressbar.set_description("Batch loss: {:3f}".format(loss.item()))
                #progressbar.update(1)
                # Tensorboard
                #if batch_idx + 1 == len(self.val_loader):
                #    self.writer.add_scalar('train - val loss',
                #                    train_loss - (val_loss / len(self.val_loader)),
                #                    epoch)
        #progressbar.close() 
        pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True)
        print(category_iou)
        logging.info('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format(
            pixAcc * 100, mIoU * 100))
        iter_idx = (epoch + 1) * len(self.train_loader)
        with open('val_out.txt', 'a') as out:
            print(self.exp_name, 'Epoch:', epoch,
                  'pixAcc: {:.3f}, mIoU: {:.3f}'.format(pixAcc * 100, mIoU * 100),
                  file=out)
        print("Validation loss: {}".format(val_loss)) 
        print("+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*")
        if (mIoU * 100) > best_mIoU:
            best_mIoU = mIoU*100
            save_dict = {
                    "epoch": epoch,
                    "model": self.model.state_dict(),
                    "optim": self.optimizer.state_dict(),
                    "best_val_loss": best_val_loss,
                    "best_mIoU": best_mIoU,
                    }
            save_name = os.path.join(os.getcwd(), 'results', self.exp_name, 'best_mIoU.pth')
            torch.save(save_dict, save_name)
            print("mIoU is higher than best mIoU! Model saved to {}".format(save_name))
        #if val_loss < best_val_loss: 
        #    best_val_loss = val_loss
        #    save_name = os.path.join(os.getcwd(), 'results', self.exp_name, 'run.pth') 
        #    copy_name = os.path.join(os.getcwd(), 'results', self.exp_name, 'run_best.pth') 
        #    print("val loss is lower than best val loss! Model saved to {}".format(copy_name))
        #    shutil.copyfile(save_name, copy_name) 
    
    def eval(self):
        print("+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*")
        print("Evaluating.. ") 
        self.model.eval() 
        val_loss = 0 
        dump_to_json = [] 
        test_dataset = tuSimple(
                path=self.dataset_path,
                image_set='test',
                transforms=self.val_transform,
                )
        test_loader = data.DataLoader(
                dataset=test_dataset,
                batch_size=12,
                shuffle=False,
                num_workers=0,
                pin_memory=True,
                drop_last=False,
                )
        progressbar = tqdm(range(len(test_loader))) 
        with torch.no_grad():
            with open('exist_out.txt', 'w') as f:  # note: f is opened here but never written to
                for batch_idx, sample in enumerate(test_loader): 
                    img = sample['img'].to(self.device) 
                    img_name = sample['img_name']
                    #segLabel = sample['segLabel'].to(self.device) 
                    outputs, sig = self.model(img) 
                    seg_pred = F.softmax(outputs, dim=1)
                    seg_pred = seg_pred.detach().cpu().numpy()
                    exist_pred = sig.detach().cpu().numpy()
                    count = 0

                    for img_idx in range(len(seg_pred)):
                        seg = seg_pred[img_idx]
                        exist = [1 if exist_pred[img_idx, i] > 0.5 else 0 for i in range(6)]
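                        # decode the probability maps into per-lane x-coordinates,
                        # sampled every y_px_gap rows of the 720x1280 frame
                        # (TuSimple annotates 56 h_samples per lane)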
                        lane_coords = getLane.prob2lines_tusimple(seg, exist, resize_shape=(720,1280), y_px_gap=10, pts=56)
                        for i in range(len(lane_coords)):
                            # sort lane coords
                            lane_coords[i] = sorted(lane_coords[i], key=lambda pair:pair[1])
                        
                        #print(len(lane_coords))
                        # Visualisation (the blending code at the bottom of this method is disabled)
                        savename = "{}/{}_{}_vis.png".format(os.path.join(os.getcwd(), 'vis'), batch_idx, count)
                        count += 1
                        raw_file_name = img_name[img_idx]
                        pred_json = {}
                        pred_json['lanes'] = []
                        pred_json['h_samples'] = []
                        # truncate everything before 'clips' to be consistent with test_label.json gt
                        pred_json['raw_file'] = raw_file_name[raw_file_name.find('clips'):]
                        pred_json['run_time'] = 0

                        for lane in lane_coords:
                            if len(lane) == 0:
                                continue
                            if all(pt[0] == -2 for pt in lane):  # lane predicted absent
                                continue
                            pred_json['lanes'].append([])
                            for (x, y) in lane:
                                pred_json['lanes'][-1].append(int(x))
                        for (x, y) in lane_coords[0]:
                            pred_json['h_samples'].append(int(y))
                        dump_to_json.append(json.dumps(pred_json))
                    progressbar.update(1)
                progressbar.close() 

                with open(os.path.join(os.getcwd(), "results", self.exp_name, "pred_json.json"), "w") as f:
                    for line in dump_to_json:
                        print(line, end="\n", file=f)

                print("Saved pred_json.json to {}".format(os.path.join(os.getcwd(), 'results', self.exp_name, "pred_json.json")))
           
                '''
                        raw_img = img[b].cpu().detach().numpy()
                        raw_img = raw_img.transpose(1, 2, 0)
                        # Normalize both to 0..1
                        min_val, max_val = np.min(raw_img), np.max(raw_img)
                        raw_img = (raw_img - min_val) / (max_val - min_val)
                        #rgb = rgb / 255.
                        #stack = np.hstack((raw_img, rgb))
                        background = Image.fromarray(np.uint8(raw_img*255))
                        overlay = Image.fromarray(rgb)
                        new_img = Image.blend(background, overlay, 0.4)
                        new_img.save(savename, "PNG")
                '''
                        
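
The loss in Example #2 (and in #4 below) combines weighted per-pixel cross-entropy on the segmentation logits with binary cross-entropy on the six lane-existence outputs, mixed 10:1. A minimal self-contained sketch of that combination on dummy tensors (spatial size shrunk for brevity; 7 classes and 6 lanes as in the code above):

import torch
import torch.nn as nn

# dummy batch: 4 images, 7-class segmentation maps, 6 lane-existence scores
outputs = torch.randn(4, 7, 46, 80)             # raw segmentation logits
sig = torch.sigmoid(torch.randn(4, 6))          # existence probabilities in (0, 1)
segLabel = torch.randint(0, 7, (4, 46, 80))     # per-pixel class indices
exist = torch.randint(0, 2, (4, 6)).float()     # ground-truth lane existence

criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.4, 1, 1, 1, 1, 1, 1]))
bce = nn.BCELoss()

loss = criterion(outputs, segLabel) + 0.1 * bce(sig, exist)
print(loss.item())
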
Example #3
class Trainer(object):
    def __init__(self, s_exp_name, t_exp_name):
        cfg_path = os.path.join(os.getcwd(), 'config/tusimple_config.yaml')
        self.s_exp_name = s_exp_name
        self.t_exp_name = t_exp_name
        self.writer = SummaryWriter('tensorboard/' + self.s_exp_name)
        self.metric = SegmentationMetric(7)
        with open(cfg_path) as cfg:
            config = yaml.load(cfg, Loader=yaml.FullLoader)
        self.device = torch.device(config['DEVICE'])
        self.max_epochs = config['TRAIN']['MAX_EPOCHS']
        self.dataset_path = config['DATASET']['PATH']
        self.mean = config['DATASET']['MEAN']
        self.std = config['DATASET']['STD']
        '''
        self.input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
            ])
        '''
        self.train_transform = Compose(Resize(size=(645, 373)),
                                       RandomCrop(size=(640, 368)),
                                       RandomFlip(0.5), Rotation(2),
                                       ToTensor(),
                                       Normalize(mean=self.mean, std=self.std))
        self.val_transform = Compose(Resize(size=(640, 368)), ToTensor(),
                                     Normalize(mean=self.mean, std=self.std))
        self.train_dataset = tuSimple(path=config['DATASET']['PATH'],
                                      image_set='train',
                                      transforms=self.train_transform)
        self.val_dataset = tuSimple(
            path=config['DATASET']['PATH'],
            image_set='val',
            transforms=self.val_transform,
        )
        self.train_loader = data.DataLoader(
            dataset=self.train_dataset,
            batch_size=config['TRAIN']['BATCH_SIZE'],
            shuffle=True,
            num_workers=0,
            pin_memory=True,
            drop_last=True,
        )
        self.val_loader = data.DataLoader(
            dataset=self.val_dataset,
            batch_size=config['TRAIN']['BATCH_SIZE'],
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            drop_last=False,
        )
        self.iters_per_epoch = len(
            self.train_dataset) // config['TRAIN']['BATCH_SIZE']
        self.max_iters = self.max_epochs * self.iters_per_epoch

        # ------------network------------
        self.s_model = ENet(num_classes=7).to(self.device)
        self.t_model = ENet(num_classes=7).to(self.device)
        self.optimizer = optim.SGD(
            self.s_model.parameters(),
            lr=config['OPTIM']['LR'],
            weight_decay=config['OPTIM']['DECAY'],
            momentum=0.9,
        )
        self.lr_scheduler = get_scheduler(
            self.optimizer,
            max_iters=self.max_iters,
            iters_per_epoch=self.iters_per_epoch,
        )
        self.ce = nn.CrossEntropyLoss(weight=torch.tensor(
            [0.4, 1, 1, 1, 1, 1, 1])).cuda()  #background weight 0.4
        self.bce = nn.BCELoss().cuda()
        self.kl = nn.KLDivLoss().cuda()  #reduction='batchmean' gives NaN
        self.mse = nn.MSELoss().cuda()

    def train(self, epoch, start_time):
        running_loss = 0.0
        logging.info(
            'Start training, Total Epochs: {:d}, Total Iterations: {:d}'.
            format(self.max_epochs, self.max_iters))
        print("Train Epoch: {}".format(epoch))
        self.s_model.train()
        self.t_model.eval()
        epoch_loss = 0
        iteration = epoch * self.iters_per_epoch
        for batch_idx, sample in enumerate(self.train_loader):
            iteration += 1
            img = sample['img'].to(self.device)
            segLabel = sample['segLabel'].to(self.device)
            exist = sample['exist'].to(self.device)
            with torch.no_grad():
                t_outputs, t_sig = self.t_model(img)
            s_outputs, s_sig = self.s_model(img)
            ce = self.ce(s_outputs, segLabel)
            bce = self.bce(s_sig, exist)
            kl = self.kl(
                F.log_softmax(s_outputs, dim=1),
                F.softmax(t_outputs, dim=1),
            )
            mse = self.mse(s_outputs, t_outputs)  #/ s_outputs.size(0)
            loss = ce + (0.1 * bce) + kl + (0.5 * mse)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()
            epoch_loss += loss.item()
            running_loss += loss.item()
            eta_seconds = ((time.time() - start_time) /
                           iteration) * (self.max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            iter_idx = epoch * len(self.train_loader) + batch_idx
            if iteration % 10 == 0:
                logging.info(
                    "Epoch: {:d}/{:d} || Iters: {:d}/{:d} || Lr: {:6f} || Loss: {:.4f} || Cost Time: {} || Estimated Time: {}"
                    .format(
                        epoch,
                        self.max_epochs,
                        iteration % self.iters_per_epoch,
                        self.iters_per_epoch,
                        self.optimizer.param_groups[0]['lr'],
                        loss.item(),
                        str(
                            datetime.timedelta(seconds=int(time.time() -
                                                           start_time))),
                        eta_string,
                    ))
            if batch_idx % 10 == 9:
                self.writer.add_scalar(
                    'train_loss', running_loss / 10,
                    epoch * len(self.train_loader) + batch_idx + 1)
                running_loss = 0.0
        # epoch % 1 == 0 is always true: a checkpoint is written every epoch
        if epoch % 1 == 0:
            save_dict = {
                "epoch": epoch,
                "model": self.s_model.state_dict(),
                "optim": self.optimizer.state_dict(),
                "best_mIoU": best_mIoU,
                "best_val_loss": best_val_loss,
            }
            save_name = os.path.join(os.getcwd(), 'results', self.s_exp_name,
                                     'run.pth')
            torch.save(save_dict, save_name)
            print("Model is saved: {}".format(save_name))

    def val(self, epoch):
        self.metric.reset()
        global best_val_loss
        global best_mIoU
        print("Val Epoch: {}".format(epoch))
        self.s_model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch_idx, sample in enumerate(self.val_loader):
                img = sample['img'].to(self.device)
                segLabel = sample['segLabel'].to(self.device)
                exist = sample['exist'].to(self.device)
                outputs, sig = self.s_model(img)
                ce = self.ce(outputs, segLabel)
                bce = self.bce(sig, exist)
                loss = ce + (0.1 * bce)
                self.metric.update(outputs, segLabel)
                pixAcc, mIoU = self.metric.get()
                logging.info(
                    "Sample: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(
                        batch_idx + 1, pixAcc * 100, mIoU * 100))

        pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True)
        print(category_iou)
        logging.info("Final pixAcc: {:.3f}, mIoU: {:.3f}".format(
            pixAcc * 100,
            mIoU * 100,
        ))
        iter_idx = (epoch + 1) * len(self.train_loader)
        if (mIoU * 100) > best_mIoU:
            best_mIoU = mIoU * 100
            save_dict = {
                "epoch": epoch,
                "model": self.s_model.state_dict(),
                "optim": self.optimizer.state_dict(),
                "best_val_loss": best_val_loss,
                "best_mIoU": best_mIoU,
            }
            save_name = os.path.join(os.getcwd(), 'results', self.s_exp_name,
                                     'best_mIoU.pth')
            torch.save(save_dict, save_name)
            print("mIoU is higher than best mIoU! Model saved to {}".format(
                save_name))
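
Example #3 distils a frozen teacher ENet into a student: on top of the supervised CE and BCE terms it adds a KL divergence between the student's log-softmax and the teacher's softmax, plus an MSE term on the raw logits, weighted as loss = ce + 0.1*bce + kl + 0.5*mse. A minimal sketch of just the distillation terms on dummy logits (spatial size shrunk for brevity; nn.KLDivLoss expects log-probabilities as input and probabilities as target, and the default reduction='mean' is kept because the code notes 'batchmean' produced NaNs):

import torch
import torch.nn as nn
import torch.nn.functional as F

s_outputs = torch.randn(4, 7, 46, 80)   # student segmentation logits
t_outputs = torch.randn(4, 7, 46, 80)   # teacher logits (computed under no_grad in the code above)

kl = nn.KLDivLoss()   # default reduction='mean'; see the NaN note in the snippet
mse = nn.MSELoss()

distill = kl(F.log_softmax(s_outputs, dim=1),
             F.softmax(t_outputs, dim=1)) + 0.5 * mse(s_outputs, t_outputs)
print(distill.item())
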
Example #4
class Trainer(object):
    def __init__(self, exp, exp2):
        cfg_path = os.path.join(os.getcwd(), 'config/tusimple_config.yaml')
        self.exp_name = exp
        self.exp_name2 = exp2

        self.writer = SummaryWriter('tensorboard/' + self.exp_name)
        with open(cfg_path) as file:
            cfg = yaml.load(file, Loader=yaml.FullLoader)
        self.device = torch.device(cfg['DEVICE'])
        self.max_epochs = cfg['TRAIN']['MAX_EPOCHS']
        self.dataset_path = cfg['DATASET']['PATH']
        # TODO remove this and refactor PROPERLY
        self.input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg['DATASET']['MEAN'],
                                 cfg['DATASET']['STD']),
        ])

        mean = cfg['DATASET']['MEAN']
        std = cfg['DATASET']['STD']
        self.train_transform = Compose(Resize(size=(645, 373)),
                                       RandomCrop(size=(640, 368)),
                                       RandomFlip(0.5), Rotation(2),
                                       ToTensor(), Normalize(mean=mean,
                                                             std=std))

        self.val_transform = Compose(Resize(size=(640, 368)), ToTensor(),
                                     Normalize(mean=mean, std=std))
        data_kwargs = {
            'transform': self.input_transform,
            'size': cfg['DATASET']['SIZE'],
        }
        self.train_dataset = tuSimple(path=cfg['DATASET']['PATH'],
                                      image_set='train',
                                      transforms=self.train_transform)
        self.val_dataset = tuSimple(
            path=cfg['DATASET']['PATH'],
            image_set='val',
            transforms=self.val_transform,
        )
        self.train_loader = data.DataLoader(
            dataset=self.train_dataset,
            batch_size=cfg['TRAIN']['BATCH_SIZE'],
            shuffle=True,
            num_workers=0,
            pin_memory=True,
            drop_last=True,
        )
        self.val_loader = data.DataLoader(
            dataset=self.val_dataset,
            batch_size=cfg['TRAIN']['BATCH_SIZE'],
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            drop_last=False,
        )
        # -------- network --------
        weight = [0.4, 1, 1, 1, 1, 1, 1]  # down-weight the dominant background class
        self.model = ENet(num_classes=7).to(self.device)
        self.model2 = ENet(num_classes=7).to(self.device)
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=cfg['OPTIM']['LR'],
            weight_decay=cfg['OPTIM']['DECAY'],
            momentum=0.9,
        )
        #self.optimizer = optim.Adam(
        #    self.model.parameters(),
        #    lr = cfg['OPTIM']['LR'],
        #    weight_decay=0,
        #    )
        self.criterion = nn.CrossEntropyLoss(
            weight=torch.tensor(weight)).cuda()
        self.bce = nn.BCELoss().cuda()

    def train(self, epoch):
        running_loss = 0.0
        print("Train Epoch: {}".format(epoch))
        self.model.train()
        epoch_loss = 0
        progressbar = tqdm(range(len(self.train_loader)))
        for batch_idx, sample in enumerate(self.train_loader):
            img = sample['img'].to(self.device)
            segLabel = sample['segLabel'].to(self.device)
            exist = sample['exist'].to(self.device)
            # outputs is crossentropy, sig is binary cross entropy
            outputs, sig = self.model(img)
            ce = self.criterion(outputs, segLabel)
            bce = self.bce(sig, exist)
            loss = ce + (0.1 * bce)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            epoch_loss += loss.item()
            running_loss += loss.item()
            iter_idx = epoch * len(self.train_loader) + batch_idx
            progressbar.set_description("Batch loss: {:.3f}".format(
                loss.item()))
            progressbar.update(1)
            # Tensorboard
            if batch_idx % 10 == 9:
                self.writer.add_scalar(
                    'train loss', running_loss / 10,
                    epoch * len(self.train_loader) + batch_idx + 1)
                running_loss = 0.0
        progressbar.close()
        # epoch % 1 == 0 is always true: a checkpoint is written every epoch
        if epoch % 1 == 0:
            save_dict = {
                "epoch": epoch,
                "model": self.model.state_dict(),
                "optim": self.optimizer.state_dict(),
                "best_val_loss": best_val_loss,
            }
            os.makedirs(os.path.join(os.getcwd(), 'results', self.exp_name),
                        exist_ok=True)
            save_name = os.path.join(os.getcwd(), 'results', self.exp_name,
                                     'run.pth')
            torch.save(save_dict, save_name)
            print("Model is saved: {}".format(save_name))
            print("+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*")
        return epoch_loss / len(self.train_loader)

    def val(self, epoch, train_loss):
        global best_val_loss
        print("Val Epoch: {}".format(epoch))
        self.model.eval()
        val_loss = 0
        progressbar = tqdm(range(len(self.val_loader)))
        with torch.no_grad():
            for batch_idx, sample in enumerate(self.val_loader):
                img = sample['img'].to(self.device)
                segLabel = sample['segLabel'].to(self.device)
                exist = sample['exist'].to(self.device)
                outputs, sig = self.model(img)
                ce = self.criterion(outputs, segLabel)
                bce = self.bce(sig, exist)
                loss = ce + (0.1 * bce)
                val_loss += loss.item()
                progressbar.set_description("Batch loss: {:3f}".format(
                    loss.item()))
                progressbar.update(1)
                # Tensorboard
                if batch_idx + 1 == len(self.val_loader):
                    self.writer.add_scalar(
                        'train - val loss',
                        train_loss - (val_loss / len(self.val_loader)), epoch)
        progressbar.close()
        iter_idx = (epoch + 1) * len(self.train_loader)
        print("Validation loss: {}".format(val_loss))
        print("+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_name = os.path.join(os.getcwd(), 'results', self.exp_name,
                                     'run.pth')
            copy_name = os.path.join(os.getcwd(), 'results', self.exp_name,
                                     'run_best.pth')
            print("val loss is lower than best val loss! Model saved to {}".
                  format(copy_name))
            shutil.copyfile(save_name, copy_name)

    def eval(self):
        print("+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*")
        print("Evaluating.. ")
        self.model.eval()
        self.model2.eval()
        val_loss = 0
        dump_to_json = []
        test_dataset = tuSimple(path=self.dataset_path,
                                image_set='test',
                                transforms=self.val_transform)
        test_loader = data.DataLoader(
            dataset=test_dataset,
            batch_size=12,
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            drop_last=False,
        )
        progressbar = tqdm(range(len(test_loader)))
        with torch.no_grad():
            with open('exist_out.txt', 'w') as f:  # note: f is opened here but never written to
                for batch_idx, sample in enumerate(test_loader):
                    img = sample['img'].to(self.device)
                    img_name = sample['img_name']
                    #segLabel = sample['segLabel'].to(self.device)
                    outputs, sig = self.model(img)
                    outputs2, sig2 = self.model2(img)
                    #added_sig = sig2.add(sig)
                    #div_sig = torch.div(added_sig, 2.0)
                    #added_out = outputs.add(outputs2)
                    #div_out = torch.div(added_out, 2.0)
                    seg_pred1 = F.softmax(outputs, dim=1)
                    seg_pred2 = F.softmax(outputs2, dim=1)
                    seg_pred = seg_pred1.add(seg_pred2)
                    seg_pred = torch.div(seg_pred, 2.0)
                    seg_pred = seg_pred.detach().cpu().numpy()
                    # existence scores from the two models are summed (not averaged);
                    # the 0.8 threshold below is applied to that sum
                    sig_pred = sig.add(sig2)
                    exist_pred = sig_pred.detach().cpu().numpy()
                    count = 0

                    for img_idx in range(len(seg_pred)):
                        seg = seg_pred[img_idx]
                        exist = [
                            1 if exist_pred[img_idx, i] > 0.8 else 0
                            for i in range(6)
                        ]
                        lane_coords = getLane.prob2lines_tusimple(
                            seg,
                            exist,
                            resize_shape=(720, 1280),
                            y_px_gap=10,
                            pts=56)
                        for i in range(len(lane_coords)):
                            # sort lane coords
                            lane_coords[i] = sorted(lane_coords[i],
                                                    key=lambda pair: pair[1])

                        #print(len(lane_coords))
                        # Visualisation (the blending code at the bottom of this method is disabled)
                        savename = "{}/{}_{}_vis.png".format(
                            os.path.join(os.getcwd(), 'vis'), batch_idx, count)
                        count += 1
                        raw_file_name = img_name[img_idx]
                        pred_json = {}
                        pred_json['lanes'] = []
                        pred_json['h_samples'] = []
                        # truncate everything before 'clips' to be consistent with test_label.json gt
                        pred_json['raw_file'] = raw_file_name[raw_file_name.find('clips'):]
                        pred_json['run_time'] = 0

                        for lane in lane_coords:
                            if len(lane) == 0:
                                continue
                            if all(pt[0] == -2 for pt in lane):  # lane predicted absent
                                continue
                            pred_json['lanes'].append([])
                            for (x, y) in lane:
                                pred_json['lanes'][-1].append(int(x))
                        for (x, y) in lane_coords[0]:
                            pred_json['h_samples'].append(int(y))
                        dump_to_json.append(json.dumps(pred_json))
                    progressbar.update(1)
                progressbar.close()

                with open(
                        os.path.join(os.getcwd(), "results", self.exp_name,
                                     "pred_json.json"), "w") as out:
                    for line in dump_to_json:
                        print(line, file=out)

                print("Saved pred_json.json to {}".format(
                    os.path.join(os.getcwd(), "results", self.exp_name,
                                 "pred_json.json")))
                '''
                        raw_img = img[b].cpu().detach().numpy()
                        raw_img = raw_img.transpose(1, 2, 0)
                        # Normalize both to 0..1
                        min_val, max_val = np.min(raw_img), np.max(raw_img)
                        raw_img = (raw_img - min_val) / (max_val - min_val)
                        #rgb = rgb / 255.
                        #stack = np.hstack((raw_img, rgb))
                        background = Image.fromarray(np.uint8(raw_img*255))
                        overlay = Image.fromarray(rgb)
                        new_img = Image.blend(background, overlay, 0.4)
                        new_img.save(savename, "PNG")
                '''
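
Example #4 ensembles two independently trained ENets at test time by averaging their softmax outputs before lane decoding. A minimal sketch of that averaging step on dummy logits (spatial size shrunk for brevity):

import torch
import torch.nn.functional as F

# logits from two independently trained models on the same batch
outputs = torch.randn(4, 7, 46, 80)
outputs2 = torch.randn(4, 7, 46, 80)

# average the per-class probabilities, not the raw logits
seg_pred = (F.softmax(outputs, dim=1) + F.softmax(outputs2, dim=1)) / 2.0
assert torch.allclose(seg_pred.sum(dim=1), torch.ones(4, 46, 80))
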