def validate(model,
             criterion,
             loader,
             device,
             writer,
             cur_epoch,
             calc_roc=False) -> Tuple[float, float]:
    """Evaluate ``model`` on ``loader`` and log per-sample and grouped scores.

    Args:
        model: network to evaluate; it is switched to eval mode here.
        criterion: loss function applied to ``model(inputs)`` and the labels.
        loader: data loader whose batches provide ``'frames'``, ``'label'``
            and ``'name'`` entries.
        device: device the frames and labels are moved to.
        writer: summary writer forwarded to ``write_scores``.
        cur_epoch: epoch index used as the logging step.
        calc_roc: accepted for signature compatibility; not used.

    Returns:
        ``(mean_loss, f1)``. ``mean_loss`` is the batch-size-weighted loss sum
        divided by ``len(loader.dataset)``. ``f1`` is the ``'f1'`` entry of
        ``Scores.calc_scores``.

    Side effects:
        Logs under the ``'val'`` and ``'eval'`` tags. Writes the collected
        score data to ``training_video_<grouped f1>.csv`` in the current
        working directory.
    """
    model.eval()
    total_loss = 0.0
    tracker = Scores()

    progress = tqdm(enumerate(loader), total=len(loader), desc='Validation')
    for _, batch in progress:
        frames = batch['frames'].to(device, dtype=torch.float)
        targets = batch['label'].to(device)
        names = batch['name']

        with torch.no_grad():
            logits = model(frames)
            batch_loss = criterion(logits, targets)
            _, predicted = torch.max(logits, 1)
            # Weight by batch size so the final division by the dataset
            # length yields a per-sample mean.
            total_loss += batch_loss.item() * frames.size(0)

        tracker.add(predicted, targets, tags=names)

    frame_metrics = tracker.calc_scores(as_dict=True)
    dataset_size = len(loader.dataset)
    frame_metrics['loss'] = total_loss / dataset_size
    write_scores(writer, 'val', frame_metrics, cur_epoch)

    # Scores aggregated per tag (video name); the meaning of ratio=0.55 is
    # defined by Scores.calc_scores_eye -- presumably a voting threshold.
    grouped_metrics = tracker.calc_scores_eye(as_dict=True, ratio=0.55)
    write_scores(writer, 'eval', grouped_metrics, cur_epoch)
    tracker.data.to_csv(f'training_video_{grouped_metrics["f1"]}.csv', index=False)

    return total_loss / dataset_size, frame_metrics['f1']
def validate(model, criterion, loader, device, writer, hp, cur_epoch, calc_roc=False) -> Tuple[float, float]:
    """Evaluate a MIL model on ``loader`` and record per-eye scores.

    The model must expose ``calculate_objective(inputs, labels)``, which
    returns ``(loss, attention_weights, prob)``. It must also expose
    ``calculate_classification_error(inputs, labels)``, which returns
    ``(error, preds)``.

    Args:
        model: MIL network; it is switched to eval mode here.
        criterion: accepted for signature compatibility; not used.
        loader: batches provide ``'frames'``, ``'label'``, ``'name'`` and
            ``'frame_names'`` entries.
        device: device the frames and labels are moved to.
        writer: summary writer forwarded to ``write_scores``.
        hp: hyper-parameters; not used here.
        cur_epoch: epoch index used as the logging step.
        calc_roc: if false, log ``'val'``/``'eval'`` scores and dump a CSV to
            ``RES_PATH`` only when the eye F1 exceeds 0.1. If true, skip
            logging and always dump a date-stamped ``best_mil_model`` CSV.

    Returns:
        ``(mean_loss, eye_f1)``. ``mean_loss`` is the unweighted sum of batch
        losses divided by ``len(loader.dataset)``; that is a per-sample mean
        only if each batch holds one bag (TODO confirm batch size 1).
    """
    model.eval()
    loss_sum = 0.0
    tracker = Scores()

    for _, batch in tqdm(enumerate(loader), total=len(loader), desc='Validation'):
        bag = batch['frames'].to(device, dtype=torch.float)
        target = batch['label'].to(device)
        eye_ids = batch['name']

        with torch.no_grad():
            bag_loss, attention, probabilities = model.calculate_objective(bag, target)
            _, predicted = model.calculate_classification_error(bag, target)
            loss_sum += bag_loss.item()

        tracker.add(predicted, target, probs=probabilities, tags=eye_ids,
                    attention=attention, files=batch['frame_names'])

    frame_metrics = tracker.calc_scores(as_dict=True)
    mean_loss = loss_sum / len(loader.dataset)
    frame_metrics['loss'] = mean_loss

    if calc_roc:
        eye_metrics = tracker.calc_scores_eye(as_dict=True)
        stamp = time.strftime("%Y%m%d")
        csv_name = f'{stamp}_best_mil_model_{frame_metrics["f1"]:0.2}.csv'
        tracker.data.to_csv(os.path.join(RES_PATH, csv_name), index=False)
    else:
        write_scores(writer, 'val', frame_metrics, cur_epoch)
        eye_metrics = tracker.calc_scores_eye(as_dict=True)
        write_scores(writer, 'eval', eye_metrics, cur_epoch)
        if eye_metrics['f1'] > 0.1:
            csv_name = f'training_mil_avg_{frame_metrics["f1"]}_{eye_metrics["f1"]}.csv'
            tracker.data.to_csv(os.path.join(RES_PATH, csv_name), index=False)

    return mean_loss, eye_metrics['f1']
def train_model(model, criterion, optimizer, scheduler, loaders, device, writer, hp, num_epochs=50,
                description='Vanilla'):
    """Train a MIL model and save its last, best and ``stump`` weights.

    The model must expose ``calculate_objective(inputs, labels)``, which
    returns ``(loss, _, _)``. It must also expose
    ``calculate_classification_error(inputs, labels)``, which returns
    ``(error, preds)``, and have a ``stump`` submodule.

    Args:
        model: MIL network to train in place.
        criterion: only forwarded to ``validate``.
        optimizer: optimizer stepped once per batch.
        scheduler: stepped once per epoch with the validation loss.
        loaders: ``loaders[0]`` is the training loader and ``loaders[1]`` the
            validation loader.
        device: device the frames and labels are moved to.
        writer: summary writer forwarded to ``write_scores``/``validate``.
        hp: hyper-parameters forwarded to ``validate``.
        num_epochs: number of training epochs.
        description: not used.

    Returns:
        A deep copy of the model from the epoch with the highest validation
        F1. If no epoch improved on the initial ``-1`` sentinel (zero epochs,
        or a NaN F1), a copy of the final model is returned instead.

    Side effects:
        Saves three state dicts under ``RES_PATH``: last model, best model
        and ``model.stump``.
    """
    print('Training model...')
    since = time.time()
    best_f1_val, val_f1 = -1, -1
    best_model = None

    for epoch in range(num_epochs):
        print(f'{time.strftime("%H:%M:%S")}> Epoch {epoch}/{num_epochs}')
        print('-' * 10)

        running_loss = 0.0
        scores = Scores()

        for _, batch in tqdm(enumerate(loaders[0]), total=len(loaders[0]), desc=f'Epoch {epoch}'):
            inputs = batch['frames'].to(device, dtype=torch.float)
            label = batch['label'].to(device)

            # validate() puts the model in eval mode at the end of each
            # epoch, so train mode is re-enabled here.
            model.train()
            optimizer.zero_grad()

            loss, _, _ = model.calculate_objective(inputs, label)
            # The prediction is only used for metrics; no autograd graph is
            # needed for this second forward pass.
            with torch.no_grad():
                _, pred = model.calculate_classification_error(inputs, label)

            loss.backward()
            optimizer.step()

            scores.add(pred, label)
            # Unweighted per-batch loss; a per-sample mean only if each batch
            # holds one bag (TODO confirm batch size 1).
            running_loss += loss.item()

        train_scores = scores.calc_scores(as_dict=True)
        train_scores['loss'] = running_loss / len(loaders[0].dataset)
        write_scores(writer, 'train', train_scores, epoch)
        val_loss, val_f1 = validate(model, criterion, loaders[1], device, writer, hp, epoch)

        if val_f1 > best_f1_val:
            best_f1_val = val_f1
            best_model = copy.deepcopy(model)

        scheduler.step(val_loss)

    time_elapsed = time.time() - since
    print(f'{time.strftime("%H:%M:%S")}> Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')

    if best_model is None:
        # No epoch ran, or the validation F1 never exceeded the sentinel
        # (e.g. NaN). Fall back to the current weights instead of crashing
        # on best_model.state_dict().
        best_model = copy.deepcopy(model)

    # One timestamp for all three files, so their names cannot disagree
    # across midnight.
    stamp = time.strftime("%Y%m%d")
    torch.save(model.state_dict(), os.path.join(RES_PATH, f'{stamp}_last_model_score{val_f1}.pth'))
    torch.save(best_model.state_dict(), os.path.join(RES_PATH, f'{stamp}_best_model_score{best_f1_val}.pth'))
    torch.save(model.stump.state_dict(), os.path.join(RES_PATH, f'{stamp}_stump_score{val_f1}.pth'))
    print(f'Best f1 score: {best_f1_val}, model saved...')
    return best_model
# Example #4 (score: 0)
def validate(model,
             criterion,
             loader,
             device,
             writer,
             cur_epoch,
             calc_roc=False) -> Tuple[float, float]:
    """Evaluate ``model`` on ``loader`` and log the ``'val'`` scores.

    Args:
        model: network to evaluate; it is switched to eval mode here.
        criterion: loss function applied to ``model(inputs)`` and the labels.
        loader: batches provide ``'image'``, ``'label'`` and ``'eye_id'``
            entries.
        device: device the images and labels are moved to.
        writer: summary writer forwarded to ``nn_utils.write_scores``.
        cur_epoch: epoch index used as the logging step.
        calc_roc: accepted for signature compatibility; not used.

    Returns:
        ``(mean_loss, f1)``. ``mean_loss`` is the batch-size-weighted loss sum
        divided by ``len(loader.dataset)``. ``f1`` is the ``'f1'`` entry of
        ``calc_scores``.
    """
    model.eval()
    loss_total = 0.0
    metrics = nn_utils.Scores()

    batches = tqdm(enumerate(loader), total=len(loader), desc='Validation')
    for _, batch in batches:
        images = batch['image'].to(device, dtype=torch.float)
        targets = batch['label'].to(device)
        eye_ids = batch['eye_id']

        with torch.no_grad():
            logits = model(images)
            batch_loss = criterion(logits, targets)
            _, predicted = torch.max(logits, 1)
            # Weight by batch size so dividing by the dataset length gives a
            # per-sample mean.
            loss_total += batch_loss.item() * images.size(0)

        metrics.add(predicted, targets, tags=eye_ids)

    results = metrics.calc_scores(as_dict=True)
    mean_loss = loss_total / len(loader.dataset)
    results['loss'] = mean_loss
    nn_utils.write_scores(writer, 'val', results, cur_epoch, full_report=True)

    return mean_loss, results['f1']
def validate(model,
             criterion,
             loader,
             device,
             writer,
             hp,
             cur_epoch,
             is_test=False) -> Tuple[float, float]:
    """Evaluate ``model`` on ``loader`` and log per-sample and per-eye scores.

    Args:
        model: network to evaluate; it is switched to eval mode here.
        criterion: loss function applied to ``model(inputs)`` and the labels.
        loader: batches provide ``'image'``, ``'label'`` and ``'eye_id'``
            entries.
        device: device the images and labels are moved to.
        writer: summary writer forwarded to ``write_scores``.
        hp: hyper-parameters; ``hp['voting_percentage']`` is passed to
            ``calc_scores_eye`` as ``top_percent``.
        cur_epoch: epoch index used as the logging step.
        is_test: log under ``'test'``/``'etest'`` instead of ``'val'``/``'eval'``.

    Returns:
        ``(mean_loss, eye_f1)``. ``mean_loss`` is the batch-size-weighted loss
        sum divided by ``len(loader.dataset)``.
    """
    model.eval()
    loss_total = 0.0
    to_probs = torch.nn.Softmax(dim=1)
    tracker = Scores()

    for batch in loader:
        images = batch['image'].to(device, dtype=torch.float)
        targets = batch['label'].to(device)
        eye_ids = batch['eye_id']

        with torch.no_grad():
            logits = model(images)
            batch_loss = criterion(logits, targets)
            _, predicted = torch.max(logits, 1)
            probabilities = to_probs(logits)
            loss_total += batch_loss.item() * images.size(0)

        tracker.add(predicted, targets, tags=eye_ids, probs=probabilities)

    frame_metrics = tracker.calc_scores(as_dict=True)
    mean_loss = loss_total / len(loader.dataset)
    frame_metrics['loss'] = mean_loss

    eye_metrics = tracker.calc_scores_eye(as_dict=True,
                                          top_percent=hp['voting_percentage'])

    frame_tag, eye_tag = ('test', 'etest') if is_test else ('val', 'eval')
    write_scores(writer, frame_tag, frame_metrics, cur_epoch)
    write_scores(writer, eye_tag, eye_metrics, cur_epoch)

    return mean_loss, eye_metrics['f1']
def train_model(model,
                criterion,
                optimizer,
                scheduler,
                loaders,
                device,
                writer,
                hp,
                num_epochs=50,
                description='Vanilla'):
    """Train a frame classifier and checkpoint it by validation F1.

    Args:
        model: network to train in place.
        criterion: loss applied to ``model(inputs)`` and the labels.
        optimizer: optimizer stepped once per batch.
        scheduler: stepped once per epoch with the validation loss.
        loaders: ``loaders[0]`` is train and ``loaders[1]`` validation.
            ``loaders[2]`` is a test loader, used only when
            ``hp['validation'] == 'tvt'``.
        device: device the images and labels are moved to.
        writer: summary writer forwarded to ``write_scores``/``validate``.
        hp: hyper-parameters. ``hp['validation']`` selects the split scheme:
            ``'tvt'`` also evaluates ``loaders[2]`` every epoch; ``'tt'`` and
            ``'tv'`` save and return the last-epoch checkpoint instead.
        num_epochs: number of training epochs.
        description: not used.

    Returns:
        ``(best_model_path, val_f1)``. ``best_model_path`` is the checkpoint
        saved (in the current working directory) at the best validation F1.
        For ``'tt'``/``'tv'`` it is the last-epoch checkpoint instead. It is
        ``None`` if nothing was saved. ``val_f1`` is the last epoch's
        validation F1, or ``-1.0`` when no epoch ran.
    """
    since = time.time()
    best_f1_val = -1
    best_model_path = None
    # Defaults so the code after the loop and the return are well-defined
    # when num_epochs == 0 (previously an UnboundLocalError). val_f1 is a
    # float because the ':0.2' format spec below rejects ints.
    val_f1 = -1.0
    epoch = -1

    for epoch in range(num_epochs):
        print(f'{time.strftime("%H:%M:%S")}> Epoch {epoch}/{num_epochs}')
        print('-' * 10)

        running_loss = 0.0
        scores = Scores()

        for _, batch in tqdm(enumerate(loaders[0]),
                             total=len(loaders[0]),
                             desc=f'Epoch {epoch}'):
            inputs = batch['image'].to(device, dtype=torch.float)
            labels = batch['label'].to(device)

            # validate() puts the model in eval mode, so train mode is
            # re-enabled before each update.
            model.train()
            optimizer.zero_grad()
            outputs = model(inputs)
            _, pred = torch.max(outputs, 1)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            scores.add(pred, labels)
            # Weight by batch size so dividing by the dataset length gives a
            # per-sample mean.
            running_loss += loss.item() * inputs.size(0)

        train_scores = scores.calc_scores(as_dict=True)
        train_scores['loss'] = running_loss / len(loaders[0].dataset)
        write_scores(writer, 'train', train_scores, epoch)
        val_loss, val_f1 = validate(model, criterion, loaders[1], device,
                                    writer, hp, epoch)
        if hp['validation'] == 'tvt':
            validate(model,
                     criterion,
                     loaders[2],
                     device,
                     writer,
                     hp,
                     epoch,
                     is_test=True)

        if val_f1 > best_f1_val:
            best_f1_val = val_f1
            best_model_path = f'best_frames_model_f1_{val_f1:0.2}_epoch_{epoch}.pth'
            torch.save(model.state_dict(), best_model_path)

        scheduler.step(val_loss)

    time_elapsed = time.time() - since
    print(
        f'{time.strftime("%H:%M:%S")}> Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s with best f1 score of {best_f1_val}'
    )

    # For test-set evaluation the last model (after num_epochs) is the one to
    # use, so best_model_path is overwritten with the final checkpoint.
    if hp['validation'] in ('tt', 'tv'):
        best_model_path = f'frames_model_evalf1_{val_f1:0.2}_epoch_{epoch}.pth'
        torch.save(model.state_dict(), best_model_path)
    return best_model_path, val_f1
def train_model(model,
                criterion,
                optimizer,
                scheduler,
                loaders,
                device,
                writer,
                num_epochs=50,
                description='Vanilla'):
    """Train ``model`` for ``num_epochs`` and save the final weights.

    Args:
        model: network to train; it is moved to ``device`` first.
        criterion: loss applied to ``model(inputs)`` and the labels.
        optimizer: optimizer stepped once per batch.
        scheduler: stepped once per epoch with the validation loss.
        loaders: ``loaders[0]`` is train and ``loaders[1]`` validation.
        device: device for the model, frames and labels.
        writer: summary writer forwarded to ``write_scores``/``validate``.
        num_epochs: number of training epochs.
        description: suffix of the saved file name ``model<description>``.

    Returns:
        The trained model (the last-epoch weights, not the best ones). The
        best validation F1 is only reported in the final log line.
    """
    since = time.time()
    best_f1_val = -1
    model.to(device)

    for epoch in range(num_epochs):
        print(f'{time.strftime("%H:%M:%S")}> Epoch {epoch}/{num_epochs}')
        print('-' * 10)

        epoch_loss = 0.0
        tracker = Scores()

        train_loader = loaders[0]
        for _, batch in tqdm(enumerate(train_loader),
                             total=len(train_loader),
                             desc=f'Epoch {epoch}'):
            frames = batch['frames'].to(device, dtype=torch.float)
            targets = batch['label'].to(device)

            # validate() puts the model in eval mode, so train mode is
            # re-enabled before each update.
            model.train()
            optimizer.zero_grad()
            logits = model(frames)
            _, predicted = torch.max(logits, 1)

            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            optimizer.step()

            epoch_loss += batch_loss.item() * frames.size(0)
            tracker.add(predicted, targets)

        epoch_metrics = tracker.calc_scores(as_dict=True)
        epoch_metrics['loss'] = epoch_loss / len(train_loader.dataset)
        write_scores(writer, 'train', epoch_metrics, epoch)
        val_loss, val_f1 = validate(model, criterion, loaders[1], device,
                                    writer, epoch)

        # max() keeps the first argument unless the second is strictly
        # greater, matching a plain "val_f1 > best" update.
        best_f1_val = max(best_f1_val, val_f1)

        scheduler.step(val_loss)

    elapsed = time.time() - since
    print(
        f'{time.strftime("%H:%M:%S")}> Training complete in {elapsed // 60:.0f}m {elapsed % 60:.0f}s with best f1 score of {best_f1_val}'
    )

    validate(model, criterion, loaders[1], device, writer, num_epochs,
             calc_roc=True)
    torch.save(model.state_dict(), f'model{description}')
    return model