def evaluate(valloader, args, params, mode):
    """Report validation mIoU for the pretrained teacher and all students.

    Prints the teacher's (resnet34) validation mIoU, then each student's
    mIoU when trained on 10/20/30/40% of the data, then each student's
    mIoU when trained on the full dataset, for the given distillation mode.

    Args:
        valloader: validation DataLoader, forwarded to ``mean_iou``.
        args: namespace with at least ``gpu`` and ``dataset``.
        params: hyper-parameter dict; ``num_classes`` is read and
            ``model`` is updated in place before each ``get_savename`` call.
        mode: distillation mode string forwarded to ``get_savename``.
    """
    student_names = ['resnet10', 'resnet14', 'resnet18', 'resnet20', 'resnet26']

    def _score(encoder_name, weights_path):
        # Build a fresh network each time so weights never leak between runs.
        net = models.unet.Unet(encoder_name,
                               classes=params['num_classes'],
                               encoder_weights=None).to(args.gpu)
        # map_location keeps loading robust when the checkpoint was saved on
        # a different device (the original only did this for the teacher).
        net.load_state_dict(torch.load(weights_path, map_location=args.gpu))
        return round(mean_iou(net, valloader, args), 5)

    print('Teacher Accuracy')
    print(_score('resnet34',
                 '../saved_models/' + args.dataset + '/resnet34/pretrained_0.pt'))

    print(f'Fractional data results for {mode}')
    for perc in [10, 20, 30, 40]:
        print('perc : ', perc)
        for model_name in student_names:
            params['model'] = model_name
            print('model : ', model_name)
            print(_score(model_name,
                         get_savename(params, dataset=args.dataset,
                                      mode=mode, p=perc)))

    print(f'Full data results for {mode}')
    for model_name in student_names:
        # BUG FIX: the original never updated params['model'] in this loop,
        # so every full-data checkpoint path was computed with the stale
        # model name left over from the fractional loop ('resnet26') while
        # a different architecture was constructed.
        params['model'] = model_name
        print('model : ', model_name)
        print(_score(model_name,
                     get_savename(params, dataset=args.dataset, mode=mode)))
def train_simultaneous(hyper_params, teacher, student, sf_teacher, sf_student,
                       trainloader, valloader, args):
    """Train a student with simultaneous (joint) knowledge distillation.

    Optimizes a cross-entropy segmentation loss plus an MSE
    feature-matching loss against the teacher for
    ``hyper_params['num_epochs']`` epochs, optionally logging metrics to
    Comet when ``args.api_key`` is set.

    Args:
        hyper_params: dict with 'dataset', 'model', 'learning_rate',
            'num_epochs'; 'num_classes' is overwritten to 19 for
            non-camvid datasets (mutated in place).
        teacher / student: teacher and student networks.
        sf_teacher / sf_student: feature-hook containers for distillation.
        trainloader / valloader: training and validation DataLoaders.
        args: namespace with api_key, workspace, dataset, percentage.
    """
    if args.api_key:
        project_name = ('simultaneous-' + hyper_params['dataset'] + '-'
                        + hyper_params['model'])
        experiment = Experiment(api_key=args.api_key,
                                project_name=project_name,
                                workspace=args.workspace)
        experiment.log_parameters(hyper_params)

    optimizer = torch.optim.Adam(student.parameters(),
                                 lr=hyper_params['learning_rate'])
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-2,
        steps_per_epoch=len(trainloader),
        epochs=hyper_params['num_epochs'])

    # camvid ignores label 11; other datasets (cityscapes) ignore 250.
    if hyper_params['dataset'] == 'camvid':
        criterion = nn.CrossEntropyLoss(ignore_index=11)
    else:
        criterion = nn.CrossEntropyLoss(ignore_index=250)
        hyper_params['num_classes'] = 19
    criterion2 = nn.MSELoss()

    savename = get_savename(hyper_params,
                            dataset=args.dataset,
                            mode='simultaneous',
                            p=args.percentage)
    highest_iou = 0
    for epoch in range(hyper_params['num_epochs']):
        # BUG FIX: the original discarded the first two return values
        # (`_, _, ...`), so highest_iou was passed in as 0 every epoch and
        # best-checkpoint tracking reset each time. Capture them, mirroring
        # pretrain()/train_stagewise(). (Assumes train_simult follows the
        # same (model, best_iou, ...) return convention as train() —
        # TODO confirm against its definition.)
        student, highest_iou, train_loss, val_loss, avg_iou, avg_px_acc, avg_dice_coeff = train_simult(
            model=student,
            teacher=teacher,
            sf_teacher=sf_teacher,
            sf_student=sf_student,
            train_loader=trainloader,
            val_loader=valloader,
            num_classes=hyper_params['num_classes'],
            loss_function=criterion,
            loss_function2=criterion2,
            optimiser=optimizer,
            scheduler=scheduler,
            epoch=epoch,
            num_epochs=hyper_params['num_epochs'],
            savename=savename,
            highest_iou=highest_iou,
            args=args)
        if args.api_key:
            experiment.log_metric('train_loss', train_loss)
            experiment.log_metric('val_loss', val_loss)
            experiment.log_metric('avg_iou', avg_iou)
            experiment.log_metric('avg_pixel_acc', avg_px_acc)
            experiment.log_metric('avg_dice_coeff', avg_dice_coeff)
def pretrain(hyper_params, unet, trainloader, valloader, args):
    """Train a segmentation network from scratch (no distillation).

    Runs ``train`` for ``hyper_params['num_epochs']`` epochs with a
    dataset-appropriate cross-entropy loss, saving the best checkpoint via
    ``get_savename`` and optionally logging metrics to Comet when
    ``args.api_key`` is set.

    Args:
        hyper_params: dict with 'dataset', 'model', 'learning_rate',
            'num_epochs'.
        unet: the network to train.
        trainloader / valloader: training and validation DataLoaders.
        args: namespace with api_key, workspace, dataset, percentage.

    Raises:
        ValueError: if ``args.dataset`` is neither 'camvid' nor
            'cityscapes'. (The original silently fell through and later
            crashed with a NameError on ``criterion``.)
    """
    if args.api_key:
        project_name = ('pretrain-' + hyper_params['dataset'] + '-'
                        + hyper_params['model'])
        experiment = Experiment(api_key=args.api_key,
                                project_name=project_name,
                                workspace=args.workspace)
        experiment.log_parameters(hyper_params)

    optimizer = torch.optim.Adam(unet.parameters(),
                                 lr=hyper_params['learning_rate'])
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-2,
        steps_per_epoch=len(trainloader),
        epochs=hyper_params['num_epochs'])

    # Dataset-specific class count and void label to ignore in the loss.
    if args.dataset == 'camvid':
        criterion = nn.CrossEntropyLoss(ignore_index=11)
        num_classes = 12
    elif args.dataset == 'cityscapes':
        criterion = nn.CrossEntropyLoss(ignore_index=250)
        num_classes = 19
    else:
        raise ValueError(f'Unsupported dataset: {args.dataset!r}')

    savename = get_savename(hyper_params,
                            dataset=args.dataset,
                            mode='pretrain',
                            p=args.percentage)
    highest_iou = 0
    for epoch in range(hyper_params['num_epochs']):
        unet, highest_iou, train_loss, val_loss, avg_iou, avg_pixel_acc, avg_dice_coeff = train(
            model=unet,
            train_loader=trainloader,
            val_loader=valloader,
            num_classes=num_classes,
            loss_function=criterion,
            optimiser=optimizer,
            scheduler=scheduler,
            epoch=epoch,
            num_epochs=hyper_params['num_epochs'],
            savename=savename,
            highest_iou=highest_iou,
            args=args)
        if args.api_key:
            experiment.log_metric('train_loss', train_loss)
            experiment.log_metric('val_loss', val_loss)
            experiment.log_metric('avg_iou', avg_iou)
            experiment.log_metric('avg_pixel_acc', avg_pixel_acc)
            experiment.log_metric('avg_dice_coeff', avg_dice_coeff)
def train_traditional(hyper_params, teacher, student, sf_teacher, sf_student,
                      trainloader, valloader, args):
    """Traditional two-phase knowledge distillation.

    Phase 1: for stages 0 and 1, train the student's intermediate features
    to match the teacher's via MSE (``train_stage``), loading the previous
    stage's best weights before each stage. Phase 2 (stage 2): load the
    best stage-1 weights and train the classifier head with a
    cross-entropy segmentation loss (``train``).

    Side effects: mutates ``hyper_params['stage']`` and
    ``hyper_params['num_classes']``; writes checkpoints via
    ``get_savename``; logs to Comet only when ``args.api_key`` is set.
    """
    # --- Phase 1: feature-matching stages --------------------------------
    for stage in range(2):
        # Load previous stage model (except zeroth stage).
        if stage != 0:
            hyper_params['stage'] = stage - 1
            student.load_state_dict(
                torch.load(
                    get_savename(hyper_params,
                                 args.dataset,
                                 mode='traditional-stage',
                                 p=args.percentage)))
        hyper_params['stage'] = stage
        # Freeze all stages except the current one.
        student = unfreeze_trad(student, hyper_params['stage'])

        # SECURITY/CONSISTENCY FIX: the original hard-coded a Comet API key
        # and workspace string here; use the CLI-provided credentials and
        # skip logging when absent, as every other trainer in this file does.
        if args.api_key:
            project_name = ('trad-kd-' + hyper_params['dataset'] + '-'
                            + hyper_params['model'])
            experiment = Experiment(api_key=args.api_key,
                                    project_name=project_name,
                                    workspace=args.workspace)
            experiment.log_parameters(hyper_params)

        optimizer = torch.optim.Adam(student.parameters(),
                                     lr=hyper_params['learning_rate'])
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=1e-2,
            steps_per_epoch=len(trainloader),
            epochs=hyper_params['num_epochs'])
        criterion = nn.MSELoss()
        savename = get_savename(hyper_params,
                                args.dataset,
                                mode='traditional-stage',
                                p=args.percentage)
        lowest_val_loss = 100
        for epoch in range(hyper_params['num_epochs']):
            student, lowest_val_loss, train_loss, val_loss = train_stage(
                model=student,
                teacher=teacher,
                stage=hyper_params['stage'],
                sf_student=sf_student,
                sf_teacher=sf_teacher,
                train_loader=trainloader,
                val_loader=valloader,
                loss_function=criterion,
                optimiser=optimizer,
                scheduler=scheduler,
                epoch=epoch,
                num_epochs=hyper_params['num_epochs'],
                savename=savename,
                lowest_val=lowest_val_loss,
                args=args)
            if args.api_key:
                experiment.log_metric('train_loss', train_loss)
                experiment.log_metric('val_loss', val_loss)
            print(round(val_loss, 6))

    # --- Phase 2: classifier training on the best stage-1 weights --------
    hyper_params['stage'] = 1
    student.load_state_dict(
        torch.load(
            get_savename(hyper_params,
                         args.dataset,
                         mode='traditional-stage',
                         p=args.percentage)))
    hyper_params['stage'] = 2
    # Freeze all stages except the classifier stage.
    student = unfreeze_trad(student, hyper_params['stage'])

    if args.api_key:
        project_name = ('trad-kd-' + hyper_params['dataset'] + '-'
                        + hyper_params['model'])
        experiment = Experiment(api_key=args.api_key,
                                project_name=project_name,
                                workspace=args.workspace)
        experiment.log_parameters(hyper_params)

    optimizer = torch.optim.Adam(student.parameters(),
                                 lr=hyper_params['learning_rate'])
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-2,
        steps_per_epoch=len(trainloader),
        epochs=hyper_params['num_epochs'])
    # camvid: 12 classes, void label 11; other datasets: 19 classes, void 250.
    if hyper_params['dataset'] == 'camvid':
        criterion = nn.CrossEntropyLoss(ignore_index=11)
        hyper_params['num_classes'] = 12
    else:
        criterion = nn.CrossEntropyLoss(ignore_index=250)
        hyper_params['num_classes'] = 19

    savename = get_savename(hyper_params,
                            args.dataset,
                            mode='traditional-kd',
                            p=args.percentage)
    highest_iou = 0
    for epoch in range(hyper_params['num_epochs']):
        # BUG FIX: the original always passed num_classes=12 here, even
        # after setting hyper_params['num_classes'] = 19 for non-camvid
        # datasets, so cityscapes training used the wrong class count.
        student, highest_iou, train_loss, val_loss, avg_iou, avg_pixel_acc, avg_dice_coeff = train(
            model=student,
            train_loader=trainloader,
            val_loader=valloader,
            num_classes=hyper_params['num_classes'],
            loss_function=criterion,
            optimiser=optimizer,
            scheduler=scheduler,
            epoch=epoch,
            num_epochs=hyper_params['num_epochs'],
            savename=savename,
            highest_iou=highest_iou,
            args=args)
        if args.api_key:
            experiment.log_metric('train_loss', train_loss)
            experiment.log_metric('val_loss', val_loss)
            experiment.log_metric('avg_iou', avg_iou)
            experiment.log_metric('avg_pixel_acc', avg_pixel_acc)
            experiment.log_metric('avg_dice_coeff', avg_dice_coeff)
def train_stagewise(hyper_params, teacher, student, sf_teacher, sf_student,
                    trainloader, valloader, args):
    """Stagewise knowledge distillation.

    Trains student stages 0-9 one at a time against the teacher's features
    with an MSE loss (``train_stage``), loading the previous stage's best
    weights before each stage, then loads the best stage-9 weights and
    trains the classifier head (stage 10) with cross-entropy (``train``).

    Side effects: mutates ``hyper_params['stage']``; writes checkpoints via
    ``get_savename``; logs to Comet only when ``args.api_key`` is set.
    """
    # --- Feature-matching stages 0..9 ------------------------------------
    for stage in range(10):
        # Load previous stage model (except zeroth stage).
        if stage != 0:
            hyper_params['stage'] = stage - 1
            # BUG FIX: the original omitted dataset=args.dataset here, while
            # the matching save path below includes it — so the load path
            # could differ from the path the checkpoint was saved under.
            student.load_state_dict(
                torch.load(
                    get_savename(hyper_params,
                                 dataset=args.dataset,
                                 mode='stagewise',
                                 p=args.percentage)))
        hyper_params['stage'] = stage
        # Freeze all stages except the current one.
        student = unfreeze(student, hyper_params['stage'])

        if args.api_key:
            project_name = 'stagewise-' + hyper_params['model']
            experiment = Experiment(api_key=args.api_key,
                                    project_name=project_name,
                                    workspace=args.workspace)
            experiment.log_parameters(hyper_params)

        optimizer = torch.optim.Adam(student.parameters(),
                                     lr=hyper_params['learning_rate'])
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=1e-2,
            steps_per_epoch=len(trainloader),
            epochs=hyper_params['num_epochs'])
        criterion = nn.MSELoss()
        savename = get_savename(hyper_params,
                                dataset=args.dataset,
                                mode='stagewise',
                                p=args.percentage)
        lowest_val_loss = 100
        for epoch in range(hyper_params['num_epochs']):
            student, lowest_val_loss, train_loss, val_loss = train_stage(
                model=student,
                teacher=teacher,
                stage=hyper_params['stage'],
                sf_student=sf_student,
                sf_teacher=sf_teacher,
                train_loader=trainloader,
                val_loader=valloader,
                loss_function=criterion,
                optimiser=optimizer,
                scheduler=scheduler,
                epoch=epoch,
                num_epochs=hyper_params['num_epochs'],
                savename=savename,
                lowest_val=lowest_val_loss,
                args=args)
            if args.api_key:
                experiment.log_metric('train_loss', train_loss)
                experiment.log_metric('val_loss', val_loss)

    # --- Classifier (stage 10) on the best stage-9 weights ----------------
    hyper_params['stage'] = 9
    # BUG FIX: same missing dataset=args.dataset as above.
    student.load_state_dict(
        torch.load(
            get_savename(hyper_params,
                         dataset=args.dataset,
                         mode='stagewise',
                         p=args.percentage)))
    hyper_params['stage'] = 10
    # Freeze all stages except the classifier stage.
    student = unfreeze(student, hyper_params['stage'])

    if args.api_key:
        project_name = 'stagewise-' + hyper_params['model']
        experiment = Experiment(api_key=args.api_key,
                                project_name=project_name,
                                workspace=args.workspace)
        experiment.log_parameters(hyper_params)

    optimizer = torch.optim.Adam(student.parameters(),
                                 lr=hyper_params['learning_rate'])
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-2,
        steps_per_epoch=len(trainloader),
        epochs=hyper_params['num_epochs'])
    # CONSISTENCY FIX: ignore_index was hard-coded to 11 (camvid's void
    # label) regardless of dataset; use 250 for non-camvid datasets, matching
    # every other trainer in this file.
    if args.dataset == 'camvid':
        criterion = nn.CrossEntropyLoss(ignore_index=11)
    else:
        criterion = nn.CrossEntropyLoss(ignore_index=250)

    savename = get_savename(hyper_params,
                            dataset=args.dataset,
                            mode='classifier',
                            p=args.percentage)
    highest_iou = 0
    for epoch in range(hyper_params['num_epochs']):
        student, highest_iou, train_loss, val_loss, avg_iou, avg_pixel_acc, avg_dice_coeff = train(
            model=student,
            train_loader=trainloader,
            val_loader=valloader,
            num_classes=hyper_params['num_classes'],
            loss_function=criterion,
            optimiser=optimizer,
            scheduler=scheduler,
            epoch=epoch,
            num_epochs=hyper_params['num_epochs'],
            savename=savename,
            highest_iou=highest_iou,
            args=args)
        if args.api_key:
            experiment.log_metric('train_loss', train_loss)
            experiment.log_metric('val_loss', val_loss)
            experiment.log_metric('avg_iou', avg_iou)
            experiment.log_metric('avg_pixel_acc', avg_pixel_acc)
            experiment.log_metric('avg_dice_coeff', avg_dice_coeff)