Example #1
    def test_comet(self):
        """Test with a comet hook."""
        from comet_ml import Experiment
        from datetime import datetime
        comet = Experiment(project_name="Testing",
                           auto_output_logging="native")
        comet.log_dataset_info(name="Karcher", path="shonan")
        comet.add_tag("GaussNewton")
        comet.log_parameter("method", "GaussNewton")
        time = datetime.now()
        comet.set_name("GaussNewton-" + str(time.month) + "/" + str(time.day) +
                       " " + str(time.hour) + ":" + str(time.minute) + ":" +
                       str(time.second))

        # Hook: log the Karcher error to Comet at every optimizer iteration.
        def hook(optimizer, error):
            comet.log_metric("Karcher error", error, optimizer.iterations())

        gtsam_optimize(self.optimizer, self.params, hook)
        comet.end()

        actual = self.optimizer.values()
        self.gtsamAssertEquals(actual.atRot3(KEY), self.expected)
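
The hook closes over the live Comet experiment. A minimal sketch of the same pattern as a reusable factory, assuming any optimizer that exposes iterations() (as GTSAM's do); make_logging_hook is a hypothetical helper, not part of GTSAM or comet_ml:

def make_logging_hook(experiment, metric_name):
    """Build a per-iteration hook that logs metric_name to Comet."""
    def hook(optimizer, error):
        # log_metric takes the step as its third argument
        experiment.log_metric(metric_name, error, optimizer.iterations())
    return hook

# usage: gtsam_optimize(self.optimizer, self.params,
#                       make_logging_hook(comet, "Karcher error"))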
Example #2
import time

import numpy as np
import torch
import torch.nn as nn
from comet_ml import Experiment
from sklearn.metrics import accuracy_score, classification_report, recall_score
from torch.utils.data import DataLoader

# project-local helpers; module names inferred from the calls below
import cdssm
import putils
import pytorch_data_loader
from logger import Logger


def run(args, train, sparse_evidences, claims_dict):
    BATCH_SIZE = args.batch_size
    LEARNING_RATE = args.learning_rate
    DATA_SAMPLING = args.data_sampling
    NUM_EPOCHS = args.epochs
    MODEL = args.model
    RANDOMIZE = args.no_randomize
    PRINT = args.print

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    logger = Logger('./logs/{}'.format(time.strftime("%Y%m%d-%H%M%S")))

    if MODEL:
        print("Loading pretrained model...")
        model = torch.load(MODEL)
    else:
        model = cdssm.CDSSM()
    model = model.to(device)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
        model = nn.DataParallel(model)

    print("Created model with {:,} parameters.".format(
        putils.count_parameters(model)))

    print("Created dataset...")

    # use an 80/20 train/validate split!
    train_size = int(len(train) * 0.80)
    train_dataset = pytorch_data_loader.WikiDataset(
        train[:train_size],
        claims_dict,
        data_sampling=DATA_SAMPLING,
        sparse_evidences=sparse_evidences,
        randomize=RANDOMIZE)
    val_dataset = pytorch_data_loader.WikiDataset(
        train[train_size:],
        claims_dict,
        data_sampling=DATA_SAMPLING,
        sparse_evidences=sparse_evidences,
        randomize=RANDOMIZE)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=BATCH_SIZE,
                                  num_workers=0,
                                  shuffle=True,
                                  collate_fn=pytorch_data_loader.PadCollate())
    val_dataloader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                num_workers=0,
                                shuffle=True,
                                collate_fn=pytorch_data_loader.PadCollate())

    # Loss and optimizer
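    # NLLLoss expects log-probabilities; the model evidently returns
    # log-softmax outputs (note the torch.exp(y_pred) in the printing block)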
    criterion = torch.nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=LEARNING_RATE,
                                 weight_decay=1e-3)

    # log every ~2% of an epoch, but no more often than every 20 batches
    OUTPUT_FREQ = max(int((len(train_dataset) / BATCH_SIZE) * 0.02), 20)
    parameters = {
        "batch size": BATCH_SIZE,
        "epochs": NUM_EPOCHS,
        "learning rate": LEARNING_RATE,
        "optimizer": optimizer.__class__.__name__,
        "loss": criterion.__class__.__name__,
        "training size": train_size,
        "data sampling rate": DATA_SAMPLING,
        "data": args.data,
        "sparse_evidences": args.sparse_evidences,
        "randomize": RANDOMIZE,
        "model": MODEL
    }
    experiment = Experiment(api_key="YLsW4AvRTYGxzdDqlWRGCOhee",
                            project_name="clsm",
                            workspace="moinnadeem")
    experiment.add_tag("train")
    experiment.log_asset("cdssm.py")
    experiment.log_dataset_info(name=args.data)
    experiment.log_parameters(parameters)

    model_checkpoint_dir = "models/saved_model"
    for key, value in parameters.items():
        if isinstance(value, str):
            value = value.replace("/", "-")
        if key != "model":
            model_checkpoint_dir += "_{}-{}".format(key.replace(" ", "_"),
                                                    value)

    print("Training...")
    beginning_time = time.time()
    best_loss = torch.tensor(float("inf"),
                             dtype=torch.float)  # begin loss at infinity

    for epoch in range(NUM_EPOCHS):
        beginning_time = time.time()
        train_running_loss = 0.0
        train_running_accuracy = 0.0
        model.train()
        experiment.log_current_epoch(epoch)

        with experiment.train():
            for train_batch_num, inputs in enumerate(train_dataloader):
                claims_tensors, claims_text, evidences_tensors, evidences_text, labels = inputs

                claims_tensors = claims_tensors.to(device)
                evidences_tensors = evidences_tensors.to(device)
                labels = labels.to(device)

                y_pred = model(claims_tensors, evidences_tensors)

                y = labels
                y_pred = y_pred.squeeze()

                # labels are one-hot; torch.max(y, 1)[1] recovers class indices
                loss = criterion(y_pred, torch.max(y, 1)[1])

                y = y.float()
                binary_y = torch.max(y, 1)[1]
                binary_pred = torch.max(y_pred, 1)[1]
                accuracy = (binary_y == binary_pred).float().mean()
                train_running_accuracy += accuracy.item()
                train_running_loss += loss.item()

                if PRINT:
                    for idx in range(len(y)):
                        print(
                            "Claim: {}, Evidence: {}, Prediction: {}, Label: {}"
                            .format(claims_text[idx], evidences_text[idx],
                                    torch.exp(y_pred[idx]), y[idx]))

                if (train_batch_num %
                        OUTPUT_FREQ) == 0 and train_batch_num > 0:
                    elapsed_time = time.time() - beginning_time
                    binary_y = torch.max(y, 1)[1]
                    binary_pred = torch.max(y_pred, 1)[1]
                    print(
                        "[{}:{}:{:3f}s] training loss: {}, training accuracy: {}, training recall: {}"
                        .format(
                            epoch, train_batch_num /
                            (len(train_dataset) / BATCH_SIZE), elapsed_time,
                            train_running_loss / OUTPUT_FREQ,
                            train_running_accuracy / OUTPUT_FREQ,
                            recall_score(binary_y.cpu().detach().numpy(),
                                         binary_pred.cpu().detach().numpy())))

                    # 1. Log scalar values (scalar summary)
                    info = {
                        'train_loss': train_running_loss / OUTPUT_FREQ,
                        'train_accuracy': train_running_accuracy / OUTPUT_FREQ
                    }

                    for tag, value in info.items():
                        experiment.log_metric(
                            tag,
                            value,
                            step=epoch * len(train_dataloader) + train_batch_num)
                        logger.scalar_summary(tag, value, train_batch_num + 1)

                    ## 2. Log values and gradients of the parameters (histogram summary)
                    for tag, value in model.named_parameters():
                        tag = tag.replace('.', '/')
                        logger.histo_summary(tag,
                                             value.detach().cpu().numpy(),
                                             train_batch_num + 1)
                        logger.histo_summary(tag + '/grad',
                                             value.grad.detach().cpu().numpy(),
                                             train_batch_num + 1)

                    train_running_loss = 0.0
                    beginning_time = time.time()
                    train_running_accuracy = 0.0
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()


        print("Running validation...")
        model.eval()
        pred = []
        true = []
        avg_loss = 0.0
        val_running_accuracy = 0.0
        val_running_loss = 0.0
        beginning_time = time.time()
        with experiment.validate():
            for val_batch_num, val_inputs in enumerate(val_dataloader):
                claims_tensors, claims_text, evidences_tensors, evidences_text, labels = val_inputs

                claims_tensors = claims_tensors.to(device)
                evidences_tensors = evidences_tensors.to(device)
                labels = labels.to(device)

                y_pred = model(claims_tensors, evidences_tensors)

                y = labels
                y_pred = y_pred.squeeze()

                loss = criterion(y_pred, torch.max(y, 1)[1])

                y = y.float()

                binary_y = torch.max(y, 1)[1]
                binary_pred = torch.max(y_pred, 1)[1]
                true.extend(binary_y.tolist())
                pred.extend(binary_pred.tolist())

                accuracy = (binary_y == binary_pred).to("cuda")

                accuracy = accuracy.float().mean()
                val_running_accuracy += accuracy.item()
                val_running_loss += loss.item()
                avg_loss += loss.item()

                if (val_batch_num % OUTPUT_FREQ) == 0 and val_batch_num > 0:
                    elapsed_time = time.time() - beginning_time
                    print(
                        "[{}:{}:{:3f}s] validation loss: {}, accuracy: {}, recall: {}"
                        .format(
                            epoch,
                            val_batch_num / (len(val_dataset) / BATCH_SIZE),
                            elapsed_time, val_running_loss / OUTPUT_FREQ,
                            val_running_accuracy / OUTPUT_FREQ,
                            recall_score(binary_y.cpu().detach().numpy(),
                                         binary_pred.cpu().detach().numpy())))

                    # 1. Log scalar values (scalar summary)
                    info = {'val_accuracy': val_running_accuracy / OUTPUT_FREQ}

                    for tag, value in info.items():
                        experiment.log_metric(
                            tag,
                            value,
                            step=epoch * len(val_dataloader) + val_batch_num)
                        logger.scalar_summary(tag, value, val_batch_num + 1)

                    ## 2. Log values and gradients of the parameters (histogram summary)
                    for tag, value in model.named_parameters():
                        tag = tag.replace('.', '/')
                        logger.histo_summary(tag,
                                             value.detach().cpu().numpy(),
                                             val_batch_num + 1)
                        logger.histo_summary(tag + '/grad',
                                             value.grad.detach().cpu().numpy(),
                                             val_batch_num + 1)

                    val_running_accuracy = 0.0
                    val_running_loss = 0.0
                    beginning_time = time.time()


        accuracy = accuracy_score(true, pred)
        print("[{}] mean accuracy: {}, mean loss: {}".format(
            epoch, accuracy, avg_loss / len(val_dataloader)))

        true = np.array(true).astype("int")
        pred = np.array(pred).astype("int")
        print(classification_report(true, pred))

        epoch_val_loss = avg_loss / len(val_dataloader)
        is_best = epoch_val_loss <= best_loss
        best_loss = min(epoch_val_loss, best_loss)

        putils.save_checkpoint(
            {
                "epoch": epoch,
                "model": model,
                "best_loss": best_loss
            },
            is_best,
            filename="{}_loss_{}".format(model_checkpoint_dir, best_loss))
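
Stripped of the model details, the Comet skeleton in this example is compact. A minimal sketch of the pattern, assuming hypothetical per-batch train_step and val_step callables that return a scalar loss; an Experiment constructed without an api_key reads it from the environment or Comet config:

from comet_ml import Experiment

def fit_with_comet(model, train_loader, val_loader, num_epochs, train_step, val_step):
    experiment = Experiment(project_name="clsm")
    for epoch in range(num_epochs):
        experiment.log_current_epoch(epoch)
        with experiment.train():  # metrics logged here get a "train_" prefix
            for i, batch in enumerate(train_loader):
                loss = train_step(model, batch)
                experiment.log_metric("loss", loss,
                                      step=epoch * len(train_loader) + i)
        with experiment.validate():  # metrics here get a "validate_" prefix
            for i, batch in enumerate(val_loader):
                experiment.log_metric("loss", val_step(model, batch),
                                      step=epoch * len(val_loader) + i)
    experiment.end()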
Example #3
def comet_lgbm(save_path):
    from comet_ml import Experiment
    # log_code is a constructor argument; assigning it as an attribute
    # after construction has no effect
    exp = Experiment(api_key="sqMrI9jc8kzJYobRXRuptF5Tj",
                     project_name="baseline",
                     workspace="gdreiman1",
                     log_code=True)
    
    import pickle
    import pandas as pd
    import lightgbm as lgb
    import numpy as np
    import sklearn
    import matplotlib.pyplot as plt
    from sklearn.metrics import precision_recall_fscore_support as prf
    #%%
    def single_roc(y_preds, y_true):
        from sklearn.metrics import roc_curve, auc, precision_recall_curve
        fpr, tpr, _ = roc_curve(y_true, y_preds)
        roc_auc = auc(fpr, tpr)
        plt.figure()
        lw = 2
        plt.plot(fpr, tpr, color='darkorange',
                 lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
        plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver operating characteristic example')
        
        precision, recall, thresholds = precision_recall_curve(y_true, y_preds)
        plt.plot(recall, precision, color='blue',
                 lw=lw, label='Precision vs Recall')
        # show the plot
        plt.legend(loc="lower right")
        plt.show()
    def multi_roc(y_preds, y_true, name, n_classes):
        from sklearn.metrics import roc_curve, auc
        from itertools import cycle
        lw = 2
        name_store = ['Active', 'Inactive', 'Inconclusive']
        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        for i in range(n_classes):
            fpr[i], tpr[i], _ = roc_curve(y_true[:, i], y_preds[:, i])
            roc_auc[i] = auc(fpr[i], tpr[i])
        
        # Compute micro-average ROC curve and ROC area
        fpr["micro"], tpr["micro"], _ = roc_curve(y_true[:, i].ravel(), y_preds[:, i].ravel())
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
        # Compute macro-average ROC curve and ROC area
        
        # First aggregate all false positive rates
        all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
        
        # Then interpolate all ROC curves at these points
        mean_tpr = np.zeros_like(all_fpr)
        for i in range(n_classes):
            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
        
        # Finally average it and compute AUC
        mean_tpr /= n_classes
        
        fpr["macro"] = all_fpr
        tpr["macro"] = mean_tpr
        roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
        
        # Plot all ROC curves
        plt.figure()
        plt.plot(fpr["micro"], tpr["micro"],
                 label='micro-average ROC curve (area = {0:0.2f})'
                       ''.format(roc_auc["micro"]),
                 color='deeppink', linestyle=':', linewidth=4)
        
        plt.plot(fpr["macro"], tpr["macro"],
                 label='macro-average ROC curve (area = {0:0.2f})'
                       ''.format(roc_auc["macro"]),
                 color='navy', linestyle=':', linewidth=4)
        
        colors = cycle(['aqua', 'darkorange', 'cornflowerblue','green'])
        for i, color in zip(range(n_classes), colors):
            plt.plot(fpr[i], tpr[i], color=color, lw=lw,
                     label='ROC curve of {} (area = {:0.2f})'.format(
                         name_store[i], roc_auc[i]))
        
        plt.plot([0, 1], [0, 1], 'k--', lw=lw)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Multi-class ROC for '+name)
    
        plt.legend(loc="lower right")
    #%%
    #save_path = r'C:\Users\gdrei\Dropbox\UCL\Thesis\May_13\AID_1345083_processed.pkl'
    model_type = 'lgbm'
    #get data cleaned
    pickle_off = open(save_path,'rb')
    activity_table=pickle.load(pickle_off)
    pickle_off.close()
    #get length of MFP
    fp_length = len(activity_table.iloc[5]['MFP'])
    
    
    from sklearn.preprocessing import StandardScaler, LabelEncoder
    scaler = StandardScaler(copy = False)
    le = LabelEncoder()
    labels = le.fit_transform(activity_table['PUBCHEM_ACTIVITY_OUTCOME'])
    #split data:
    from sklearn.model_selection import StratifiedShuffleSplit
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.5, train_size=None, random_state=2562)
    X_mfp = np.concatenate(np.array(activity_table['MFP'])).ravel()
    X_mfp = X_mfp.reshape((-1,fp_length))
    for train_ind, test_ind in splitter.split(X_mfp,labels):
        # standardize data
        X_train_molchars_std = scaler.fit_transform(np.array(activity_table.iloc[train_ind,4:]))
        X_test_molchars_std = scaler.transform(np.array(activity_table.iloc[test_ind,4:]))
        X_train = np.concatenate((X_mfp[train_ind,:],X_train_molchars_std),axis = 1)
        X_test = np.concatenate((X_mfp[test_ind,:],X_test_molchars_std),axis = 1)
        y_train = labels[train_ind]
        y_test = labels[test_ind]
        
    # fit LightGBM through its sklearn-style API; lgb.Dataset objects are
    # only needed for the native lgb.train() interface
    #make model class
    lgbm_model = lgb.LGBMClassifier(boosting_type='gbdt', num_leaves=31, max_depth=-1, learning_rate=0.1, n_estimators=500, subsample_for_bin=200000, 
                                    objective='binary', is_unbalance=True, min_split_gain=0.0, min_child_weight=0.001, min_child_samples=20, subsample=1.0, 
                                    subsample_freq=0, colsample_bytree=1.0, reg_alpha=0.0, reg_lambda=0.0, random_state=None, n_jobs=-1, silent=True, 
                                    importance_type='split')
    #train model
    trained_mod = lgbm_model.fit(X_train,y_train)
    #predict classes and class_probs
    test_class_preds = lgbm_model.predict(X_test)
    test_prob_preds = lgbm_model.predict_proba(X_test)
    #calculate Class report
    class_rep = sklearn.metrics.classification_report(y_test,test_class_preds)
    
    print(class_rep)
    if len(set(y_test)) == 2:
        single_roc(test_prob_preds[:,1],y_test)
        prec,rec,f_1,supp = prf(y_test, test_class_preds, average=None)
    else:
        from tensorflow.keras.utils import to_categorical
        multi_roc(test_prob_preds,to_categorical(y_test),'',3)
        prec,rec,f_1,supp = prf(y_test, test_class_preds, average=None)
    
    
    #%%
    # Comet saving zone
    #get AID number
    import os
    # get the base file name
    folder, base = os.path.split(save_path)
    # split the file name at the last '_'; assumes names like AID_xxx_endinfo.pkl
    AID, _, end_info = base.rpartition('_')
    #save data location, AID info, and version info
    exp.log_dataset_info(name = AID, version = end_info, path = save_path)
    #save model params
    exp.log_parameters(trained_mod.get_params())
    #save metrics report to comet
    if len(f_1) == 2:
        class_names = ['Active', 'Inactive']
    else:
        class_names = ['Active', 'Inconclusive', 'Inactive']
    for i, name in enumerate(class_names):
        exp.log_metric('f1 class ' + name, f_1[i])
        exp.log_metric('Recall class ' + name, rec[i])
        exp.log_metric('Precision class ' + name, prec[i])
    exp.log_other('Classification Report',class_rep)
    #save model in data_folder with comet experiment number associated
    exp_num = exp.get_key()
    model_save = os.path.join(folder, model_type + '_' + exp_num + '.pkl')
    with open(model_save, 'wb') as pickle_on:
        pickle.dump(trained_mod, pickle_on)
    #log trained model location
    exp.log_other('Trained Model Path',model_save)
    #save some informative tags:
    tags = [AID,end_info,model_type]
    exp.add_tags(tags)
    #save ROC curve
    exp.log_figure(figure_name = 'ROC-Pres/Recall',figure=plt)
    plt.show()

    #tell comet that the experiment is over
    exp.end()
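
The per-class metric logging is the reusable piece of this example. A minimal sketch, assuming a live Experiment exp; comet_ml's log_metrics accepts a flat dict, so the three log_metric calls per class can collapse into one call:

from sklearn.metrics import precision_recall_fscore_support

def log_class_metrics(exp, y_true, y_pred, class_names):
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average=None)
    for i, name in enumerate(class_names):
        exp.log_metrics({
            'f1 class ' + name: f1[i],
            'Recall class ' + name: rec[i],
            'Precision class ' + name: prec[i],
        })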
Example #4
import time

import torch
from comet_ml import Experiment
from tqdm import tqdm

# project-local modules; the import paths below are assumptions inferred
# from how the names are used in this function
from parameters import NetworkParameters, TrainingParameters, get_parameters
from datasets import build_dataset, build_dataloader, build_validate_fn
from model import PINet
from trainer import TrainerLaneDetector


def training(dataset: str, log_comet_ml=False):
    print('Train PINet')

    # ------------------------------------------------------------
    # 1. get parameters
    network_params = NetworkParameters()
    training_params = TrainingParameters()

    # ------------------------------------------------------------
    # 2. build dataset -> output sample dict
    dataset_params = get_parameters(dataset)
    train_dataset, val_dataset = build_dataset(dataset_params)
    validate_fn = build_validate_fn(dataset_params)

    # ------------------------------------------------------------
    # 3. build dataloader -> output batch tensor
    train_generator = None
    if train_dataset is not None:
        train_generator = build_dataloader(train_dataset,
                                           training_params.batch_size)

    # ------------------------------------------------------------
    # 4. define network
    print('Get model, train from scratch')
    model = PINet()
    device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")  # start from gpu_0
    model.to(device)

    # ------------------------------------------------------------
    # 5. associate network model to a training instance
    # ToDo: resume training
    lane_detector = TrainerLaneDetector(model, network_params, training_params)
    # lane_detector.load_weights(p.model_path)
    lane_detector.training_mode()

    # ------------------------------------------------------------
    # 6. define logging environment
    if log_comet_ml:
        print('Logging with comet_ml')
        experiment = Experiment(api_key="7XBXyqF4ctwO6HnnHnBGuv7U3",
                                project_name="lanedetection",
                                workspace="masszhou")
        experiment.log_parameters(vars(network_params))
        experiment.log_parameters(vars(training_params))
        experiment.log_parameters(vars(dataset_params))
        experiment.log_dataset_info(name=dataset)
    else:
        experiment = None
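        # (comet_ml also supports Experiment(disabled=True), which keeps every
        # logging call valid without sending anything; returning None instead
        # forces a log_comet_ml guard at each call site.)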

    # ------------------------------------------------------------
    # 7. start training phase
    step = 0
    timestr = time.strftime("%Y%m%d-%H%M%S")
    for epoch in range(training_params.num_epochs):
        pbar = tqdm(total=len(train_generator))
        loss_p = -1.0
        for batch in train_generator:
            # inputs -> ndarray[#batch, 3, 256, 512]
            # target_lanes -> List[ndarray] e.g. [[4, 48],..., ]
            # target_h -> List[ndarray] e.g. [[4, 48],..., ]
            # test_image -> ndarray [3, 256, 512]
            loss_p, metrics, outputs = lane_detector.train(batch,
                                                           epoch=epoch,
                                                           step=step)

            # parse confidence map from result
            outputs_last_block = outputs[-1]
            confidence, _, _ = outputs_last_block  # [8, 1, 32, 64]
            confidence = confidence[0].cpu().data.numpy()
            confidence = confidence.transpose((1, 2, 0))

            if log_comet_ml:
                experiment.log_metric("train total loss", loss_p)
                experiment.log_metric("learning rate", lane_detector.get_lr())
                experiment.log_metrics(metrics)
                if step % 500 == 0:
                    # log confidence output from the first image in the batch
                    experiment.log_image(confidence,
                                         name="image id:{}".format(
                                             batch["image_id"][0]))

            pbar.set_description(f'epoch {epoch}')
            pbar.set_postfix(total_loss=loss_p)
            pbar.update()
            step += 1
        pbar.close()

        # save model per epoch
        file_name = f"./tmp/{timestr}_epoch-{epoch}_totalstep-{step}_loss-{loss_p:.2f}.pth"
        lane_detector.save_model_v2(file_name)

        if val_dataset is not None:
            file_name = f"./tmp/{timestr}_epoch-{epoch}_validation.json"
            scores = validate_fn(dataset=val_dataset,
                                 net=lane_detector,
                                 validate_file_name=file_name,
                                 logger=experiment)
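
The throttled image logging inside the training loop is another pattern worth extracting. A minimal sketch, assuming conf is an HxWxC numpy array and experiment is a live comet_ml Experiment or None; log_confidence_map is a hypothetical helper:

def log_confidence_map(experiment, conf, image_id, step, every=500):
    """Send a confidence map to Comet, but only every `every` steps."""
    if experiment is not None and step % every == 0:
        experiment.log_image(conf, name="image id:{}".format(image_id), step=step)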