def evaluate(valloader, args, params, mode):
    """Report validation mean IoU for the teacher and all student checkpoints.

    Prints, in order:
      1. the ResNet34 teacher (pretrained weights under ../saved_models),
      2. each student backbone trained on 10/20/30/40% of the data for `mode`,
      3. each student backbone trained on the full data for `mode`.

    Args:
        valloader: validation DataLoader forwarded to mean_iou.
        args: namespace providing at least `gpu` (device) and `dataset`.
        params: hyper-parameter dict containing 'num_classes'; its 'model'
            entry is overwritten as each student is evaluated.
        mode: training-mode string used by get_savename to resolve paths.
    """
    student_backbones = [
        'resnet10', 'resnet14', 'resnet18', 'resnet20', 'resnet26'
    ]

    def _load_and_eval(backbone, checkpoint_path):
        # Build the model, restore weights onto args.gpu and print the
        # rounded validation mean IoU.
        unet = models.unet.Unet(backbone,
                                classes=params['num_classes'],
                                encoder_weights=None).to(args.gpu)
        unet.load_state_dict(
            torch.load(checkpoint_path, map_location=args.gpu))
        print(round(mean_iou(unet, valloader, args), 5))

    print('Teacher Accuracy')
    _load_and_eval(
        'resnet34',
        '../saved_models/' + args.dataset + '/resnet34/pretrained_0.pt')

    print(f'Fractional data results for {mode}')
    for perc in [10, 20, 30, 40]:
        print('perc : ', perc)
        for model_name in student_backbones:
            params['model'] = model_name
            print('model : ', model_name)
            _load_and_eval(
                model_name,
                get_savename(params, dataset=args.dataset, mode=mode, p=perc))

    print(f'Full data results for {mode}')
    for model_name in student_backbones:
        # BUGFIX: the original never updated params['model'] in this loop,
        # so every full-data checkpoint path was resolved with the stale
        # 'resnet26' value left over from the fractional loop above.
        params['model'] = model_name
        print('model : ', model_name)
        _load_and_eval(
            model_name,
            get_savename(params, dataset=args.dataset, mode=mode))
# Ejemplo n.º 2
def train_simultaneous(hyper_params, teacher, student, sf_teacher, sf_student,
                       trainloader, valloader, args):
    """Train `student` with simultaneous (joint) knowledge distillation.

    Each epoch optimizes the student against the ground-truth labels
    (cross entropy) and the teacher's intermediate features (MSE) via
    train_simult, optionally logging metrics to Comet when an API key
    is supplied on the command line.
    """
    log_to_comet = bool(args.api_key)
    if log_to_comet:
        experiment = Experiment(
            api_key=args.api_key,
            project_name=('simultaneous-' + hyper_params['dataset'] + '-' +
                          hyper_params['model']),
            workspace=args.workspace)
        experiment.log_parameters(hyper_params)

    optimizer = torch.optim.Adam(student.parameters(),
                                 lr=hyper_params['learning_rate'])
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-2,
        steps_per_epoch=len(trainloader),
        epochs=hyper_params['num_epochs'])

    # camvid masks use 11 as the ignore label; cityscapes uses 250.
    if hyper_params['dataset'] == 'camvid':
        criterion = nn.CrossEntropyLoss(ignore_index=11)
    else:
        criterion = nn.CrossEntropyLoss(ignore_index=250)
        hyper_params['num_classes'] = 19
    criterion2 = nn.MSELoss()

    savename = get_savename(hyper_params,
                            dataset=args.dataset,
                            mode='simultaneous',
                            p=args.percentage)

    best_iou = 0
    for epoch in range(hyper_params['num_epochs']):
        (_, _, train_loss, val_loss, avg_iou, avg_px_acc,
         avg_dice_coeff) = train_simult(model=student,
                                        teacher=teacher,
                                        sf_teacher=sf_teacher,
                                        sf_student=sf_student,
                                        train_loader=trainloader,
                                        val_loader=valloader,
                                        num_classes=hyper_params['num_classes'],
                                        loss_function=criterion,
                                        loss_function2=criterion2,
                                        optimiser=optimizer,
                                        scheduler=scheduler,
                                        epoch=epoch,
                                        num_epochs=hyper_params['num_epochs'],
                                        savename=savename,
                                        highest_iou=best_iou,
                                        args=args)
        if log_to_comet:
            for metric_name, value in (('train_loss', train_loss),
                                       ('val_loss', val_loss),
                                       ('avg_iou', avg_iou),
                                       ('avg_pixel_acc', avg_px_acc),
                                       ('avg_dice_coeff', avg_dice_coeff)):
                experiment.log_metric(metric_name, value)
# Ejemplo n.º 3
def pretrain(hyper_params, unet, trainloader, valloader, args):
    """Pretrain `unet` with plain cross entropy on the chosen dataset.

    Runs `hyper_params['num_epochs']` epochs of supervised training via
    train(), checkpointing through get_savename and optionally logging
    metrics to Comet when args.api_key is provided.

    Args:
        hyper_params: dict with at least 'dataset', 'model',
            'learning_rate' and 'num_epochs'.
        unet: the segmentation model to train.
        trainloader / valloader: training / validation DataLoaders.
        args: namespace with `api_key`, `workspace`, `dataset` and
            `percentage`.

    Raises:
        ValueError: if args.dataset is neither 'camvid' nor 'cityscapes'.
    """
    if args.api_key:
        project_name = 'pretrain-' + hyper_params[
            'dataset'] + '-' + hyper_params['model']
        experiment = Experiment(api_key=args.api_key,
                                project_name=project_name,
                                workspace=args.workspace)
        experiment.log_parameters(hyper_params)

    optimizer = torch.optim.Adam(unet.parameters(),
                                 lr=hyper_params['learning_rate'])
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-2,
        steps_per_epoch=len(trainloader),
        epochs=hyper_params['num_epochs'])
    # camvid masks use 11 as the ignore label; cityscapes uses 250.
    if args.dataset == 'camvid':
        criterion = nn.CrossEntropyLoss(ignore_index=11)
        num_classes = 12
    elif args.dataset == 'cityscapes':
        criterion = nn.CrossEntropyLoss(ignore_index=250)
        num_classes = 19
    else:
        # Fail fast: the original fell through and crashed later with a
        # NameError on `num_classes` for any other dataset name.
        raise ValueError('unsupported dataset: ' + str(args.dataset))

    savename = get_savename(hyper_params,
                            dataset=args.dataset,
                            mode='pretrain',
                            p=args.percentage)
    highest_iou = 0
    for epoch in range(hyper_params['num_epochs']):
        unet, highest_iou, train_loss, val_loss, avg_iou, avg_pixel_acc, avg_dice_coeff = train(
            model=unet,
            train_loader=trainloader,
            val_loader=valloader,
            num_classes=num_classes,
            loss_function=criterion,
            optimiser=optimizer,
            scheduler=scheduler,
            epoch=epoch,
            num_epochs=hyper_params['num_epochs'],
            savename=savename,
            highest_iou=highest_iou,
            args=args)
        if args.api_key:
            experiment.log_metric('train_loss', train_loss)
            experiment.log_metric('val_loss', val_loss)
            experiment.log_metric('avg_iou', avg_iou)
            experiment.log_metric('avg_pixel_acc', avg_pixel_acc)
            experiment.log_metric('avg_dice_coeff', avg_dice_coeff)
def train_traditional(hyper_params, teacher, student, sf_teacher, sf_student,
                      trainloader, valloader, args):
    """Traditional two-stage feature distillation, then classifier training.

    Stages 0 and 1 match student features to teacher features with MSE
    (train_stage); the final phase loads the stage-1 checkpoint and trains
    the classifier head (stage 2) with cross entropy (train).

    NOTE(review): the original hard-coded a Comet API key and workspace
    here; this version uses args.api_key/args.workspace and skips Comet
    when no key is supplied, matching the sibling training functions.
    """
    for stage in range(2):
        # Load previous stage model (except zeroth stage)
        if stage != 0:
            hyper_params['stage'] = stage - 1
            student.load_state_dict(
                torch.load(
                    get_savename(hyper_params,
                                 args.dataset,
                                 mode='traditional-stage',
                                 p=args.percentage)))

        # update hyperparams dictionary
        hyper_params['stage'] = stage

        # Freeze all stages except current stage
        student = unfreeze_trad(student, hyper_params['stage'])

        if args.api_key:
            project_name = 'trad-kd-' + hyper_params[
                'dataset'] + '-' + hyper_params['model']
            experiment = Experiment(api_key=args.api_key,
                                    project_name=project_name,
                                    workspace=args.workspace)
            experiment.log_parameters(hyper_params)

        optimizer = torch.optim.Adam(student.parameters(),
                                     lr=hyper_params['learning_rate'])
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=1e-2,
            steps_per_epoch=len(trainloader),
            epochs=hyper_params['num_epochs'])
        criterion = nn.MSELoss()

        savename = get_savename(hyper_params,
                                args.dataset,
                                mode='traditional-stage',
                                p=args.percentage)
        lowest_val_loss = 100
        for epoch in range(hyper_params['num_epochs']):
            student, lowest_val_loss, train_loss, val_loss = train_stage(
                model=student,
                teacher=teacher,
                stage=hyper_params['stage'],
                sf_student=sf_student,
                sf_teacher=sf_teacher,
                train_loader=trainloader,
                val_loader=valloader,
                loss_function=criterion,
                optimiser=optimizer,
                scheduler=scheduler,
                epoch=epoch,
                num_epochs=hyper_params['num_epochs'],
                savename=savename,
                lowest_val=lowest_val_loss,
                args=args)
            if args.api_key:
                experiment.log_metric('train_loss', train_loss)
                experiment.log_metric('val_loss', val_loss)
            print(round(val_loss, 6))

    # Classifier training: restore the best stage-1 checkpoint, then
    # train only the classifier head (stage 2).
    hyper_params['stage'] = 1
    student.load_state_dict(
        torch.load(
            get_savename(hyper_params,
                         args.dataset,
                         mode='traditional-stage',
                         p=args.percentage)))
    hyper_params['stage'] = 2

    # Freeze all stages except current stage
    student = unfreeze_trad(student, hyper_params['stage'])

    if args.api_key:
        project_name = 'trad-kd-' + hyper_params[
            'dataset'] + '-' + hyper_params['model']
        experiment = Experiment(api_key=args.api_key,
                                project_name=project_name,
                                workspace=args.workspace)
        experiment.log_parameters(hyper_params)

    optimizer = torch.optim.Adam(student.parameters(),
                                 lr=hyper_params['learning_rate'])
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-2,
        steps_per_epoch=len(trainloader),
        epochs=hyper_params['num_epochs'])
    # camvid masks use 11 as the ignore label; cityscapes uses 250.
    if hyper_params['dataset'] == 'camvid':
        criterion = nn.CrossEntropyLoss(ignore_index=11)
        hyper_params['num_classes'] = 12
    else:
        criterion = nn.CrossEntropyLoss(ignore_index=250)
        hyper_params['num_classes'] = 19

    savename = get_savename(hyper_params,
                            args.dataset,
                            mode='traditional-kd',
                            p=args.percentage)
    highest_iou = 0
    for epoch in range(hyper_params['num_epochs']):
        student, highest_iou, train_loss, val_loss, avg_iou, avg_pixel_acc, avg_dice_coeff = train(
            model=student,
            train_loader=trainloader,
            val_loader=valloader,
            # BUGFIX: was hard-coded to 12 (camvid), which was wrong for
            # cityscapes (19 classes); use the per-dataset count set above.
            num_classes=hyper_params['num_classes'],
            loss_function=criterion,
            optimiser=optimizer,
            scheduler=scheduler,
            epoch=epoch,
            num_epochs=hyper_params['num_epochs'],
            savename=savename,
            highest_iou=highest_iou,
            args=args)
        if args.api_key:
            experiment.log_metric('train_loss', train_loss)
            experiment.log_metric('val_loss', val_loss)
            experiment.log_metric('avg_iou', avg_iou)
            experiment.log_metric('avg_pixel_acc', avg_pixel_acc)
            experiment.log_metric('avg_dice_coeff', avg_dice_coeff)
# Ejemplo n.º 5
def train_stagewise(hyper_params, teacher, student, sf_teacher, sf_student,
                    trainloader, valloader, args):
    """Stagewise KD: distill each of the 10 student stages in turn, then
    train the classifier head (stage 10) with cross entropy.

    Each stage freezes the rest of the student (unfreeze), matches its
    features to the teacher's with MSE (train_stage), and checkpoints via
    get_savename. Comet logging is enabled only when args.api_key is set.
    """
    for stage in range(10):
        # Load previous stage model (except zeroth stage)
        if stage != 0:
            # hyperparams dict for loading previous stage weights
            hyper_params['stage'] = stage - 1
            student.load_state_dict(
                torch.load(
                    # BUGFIX: pass the dataset so the load path matches the
                    # path the checkpoint was saved under (see `savename`
                    # below, which includes dataset=args.dataset).
                    get_savename(hyper_params,
                                 dataset=args.dataset,
                                 mode='stagewise',
                                 p=args.percentage)))

        # update hyperparams dictionary for current stage
        hyper_params['stage'] = stage
        # Freeze all stages except current stage
        student = unfreeze(student, hyper_params['stage'])

        if args.api_key:
            project_name = 'stagewise-' + hyper_params['model']
            experiment = Experiment(api_key=args.api_key,
                                    project_name=project_name,
                                    workspace=args.workspace)
            experiment.log_parameters(hyper_params)

        optimizer = torch.optim.Adam(student.parameters(),
                                     lr=hyper_params['learning_rate'])
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=1e-2,
            steps_per_epoch=len(trainloader),
            epochs=hyper_params['num_epochs'])
        criterion = nn.MSELoss()

        savename = get_savename(hyper_params,
                                dataset=args.dataset,
                                mode='stagewise',
                                p=args.percentage)
        lowest_val_loss = 100
        for epoch in range(hyper_params['num_epochs']):
            student, lowest_val_loss, train_loss, val_loss = train_stage(
                model=student,
                teacher=teacher,
                stage=hyper_params['stage'],
                sf_student=sf_student,
                sf_teacher=sf_teacher,
                train_loader=trainloader,
                val_loader=valloader,
                loss_function=criterion,
                optimiser=optimizer,
                scheduler=scheduler,
                epoch=epoch,
                num_epochs=hyper_params['num_epochs'],
                savename=savename,
                lowest_val=lowest_val_loss,
                args=args)
            if args.api_key:
                experiment.log_metric('train_loss', train_loss)
                experiment.log_metric('val_loss', val_loss)

    # Classifier training: restore the final (stage-9) checkpoint.
    hyper_params['stage'] = 9
    student.load_state_dict(
        torch.load(
            # BUGFIX: include the dataset for the same save/load-path
            # symmetry as above.
            get_savename(hyper_params,
                         dataset=args.dataset,
                         mode='stagewise',
                         p=args.percentage)))
    hyper_params['stage'] = 10
    # Freeze all stages except current stage
    student = unfreeze(student, hyper_params['stage'])
    if args.api_key:
        project_name = 'stagewise-' + hyper_params['model']
        experiment = Experiment(api_key=args.api_key,
                                project_name=project_name,
                                workspace=args.workspace)
        experiment.log_parameters(hyper_params)

    optimizer = torch.optim.Adam(student.parameters(),
                                 lr=hyper_params['learning_rate'])
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-2,
        steps_per_epoch=len(trainloader),
        epochs=hyper_params['num_epochs'])
    # camvid masks use 11 as the ignore label; cityscapes uses 250.
    # (The original always used 11, which was wrong for cityscapes; this
    # matches the dataset branch in the sibling training functions.)
    if hyper_params['dataset'] == 'camvid':
        criterion = nn.CrossEntropyLoss(ignore_index=11)
    else:
        criterion = nn.CrossEntropyLoss(ignore_index=250)
        hyper_params['num_classes'] = 19

    savename = get_savename(hyper_params,
                            dataset=args.dataset,
                            mode='classifier',
                            p=args.percentage)
    highest_iou = 0
    for epoch in range(hyper_params['num_epochs']):
        student, highest_iou, train_loss, val_loss, avg_iou, avg_pixel_acc, avg_dice_coeff = train(
            model=student,
            train_loader=trainloader,
            val_loader=valloader,
            num_classes=hyper_params['num_classes'],
            loss_function=criterion,
            optimiser=optimizer,
            scheduler=scheduler,
            epoch=epoch,
            num_epochs=hyper_params['num_epochs'],
            savename=savename,
            highest_iou=highest_iou,
            args=args)
        if args.api_key:
            experiment.log_metric('train_loss', train_loss)
            experiment.log_metric('val_loss', val_loss)
            experiment.log_metric('avg_iou', avg_iou)
            experiment.log_metric('avg_pixel_acc', avg_pixel_acc)
            experiment.log_metric('avg_dice_coeff', avg_dice_coeff)