def gtrobustlosstrain(parse, config: ConfigParser):
    # Compose the WandB run name from the experiment settings
    dataset_name = config['name'].split('_')[0]
    lr_scheduler_name = config['lr_scheduler']['type']
    loss_fn_name = config['train_loss']['type']

    wandb_run_name_list = []

    if parse.distillation:
        if parse.distill_mode == 'eigen':
            wandb_run_name_list.append('distil')
        else:
            wandb_run_name_list.append('kmeans')
    else:
        wandb_run_name_list.append('baseline')
    wandb_run_name_list.append(dataset_name)
    wandb_run_name_list.append(lr_scheduler_name)
    wandb_run_name_list.append(loss_fn_name)
    wandb_run_name_list.append(str(config['trainer']['asym']))
    wandb_run_name_list.append(str(config['trainer']['percent']))
    wandb_run_name = '_'.join(wandb_run_name_list)

    # no_wandb is presumably a store_false flag, so True here means WandB logging is enabled
    if parse.no_wandb:
        wandb.init(config=config, project='noisylabel', entity='goguryeo', name=wandb_run_name)

    # By default, PyTorch uses multi-threaded CPU ops.
    # Restrict the whole procedure to a single core.
    torch.set_num_threads(1)

    logger = config.get_logger('train')

    # Set seeds for reproducibility
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])
    torch.backends.cudnn.deterministic = True
    np.random.seed(config['seed'])

    data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=config['data_loader']['args']['batch_size'],
        shuffle=False if parse.distillation else config['data_loader']['args']['shuffle'],
        # validation_split=config['data_loader']['args']['validation_split'],
        validation_split=0.0,
        num_batches=config['data_loader']['args']['num_batches'],
        training=True,
        num_workers=config['data_loader']['args']['num_workers'],
        pin_memory=config['data_loader']['args']['pin_memory'])

    # valid_data_loader = data_loader.split_validation()
    valid_data_loader = None

    # test_data_loader = None
    test_data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=128,
        shuffle=False,
        validation_split=0.0,
        training=False,
        num_workers=2).split_validation()

    # build model architecture, then print to console
    model = config.initialize('arch', module_arch)

    if parse.no_wandb:
        wandb.watch(model)

    if parse.distillation:
        # Load a pre-trained teacher and (unless reinit is set) warm-start the student from it
        teacher = config.initialize('arch', module_arch)
        teacher.load_state_dict(torch.load('./checkpoint/' + parse.load_name)['state_dict'])
        if not parse.reinit:
            model.load_state_dict(torch.load('./checkpoint/' + parse.load_name)['state_dict'])
        for params in teacher.parameters():
            params.requires_grad = False

        if parse.distill_mode == 'eigen':
            # Select presumably clean samples via the singular-vector criterion
            tea_label_list, tea_out_list = get_out_list(teacher, data_loader)
            singular_dict, v_ortho_dict = get_singular_value_vector(tea_label_list, tea_out_list)

            for key in v_ortho_dict.keys():
                v_ortho_dict[key] = v_ortho_dict[key].cuda()

            teacher_idx = singular_label(v_ortho_dict, tea_out_list, tea_label_list)
        else:
            teacher_idx = get_out_list(teacher, data_loader)

        # Rebuild the training loader restricted to the selected indices
        data_loader = getattr(module_data, config['data_loader']['type'])(
            config['data_loader']['args']['data_dir'],
            batch_size=config['data_loader']['args']['batch_size'],
            shuffle=config['data_loader']['args']['shuffle'],
            # validation_split=config['data_loader']['args']['validation_split'],
            validation_split=0.0,
            num_batches=config['data_loader']['args']['num_batches'],
            training=True,
            num_workers=config['data_loader']['args']['num_workers'],
            pin_memory=config['data_loader']['args']['pin_memory'],
            teacher_idx=teacher_idx)
    else:
        teacher = None

    # get function handles of loss and metrics
    logger.info(config.config)
    if hasattr(data_loader.dataset, 'num_raw_example'):
        num_examp = data_loader.dataset.num_raw_example
    else:
        num_examp = len(data_loader.dataset)

    if config['train_loss']['type'] == 'ELR_GTLoss':
        train_loss = getattr(module_loss, 'ELR_GTLoss')(
            num_examp=num_examp,
            num_classes=config['num_classes'],
            beta=config['train_loss']['args']['beta'])
    elif config['train_loss']['type'] == 'SCE_GTLoss':
        train_loss = getattr(module_loss, 'SCE_GTLoss')(
            alpha=config['train_loss']['args']['alpha'],
            beta=config['train_loss']['args']['beta'],
            num_classes=config['num_classes'])
    elif config['train_loss']['type'] == 'GCE_GTLoss':
        train_loss = getattr(module_loss, 'GCE_GTLoss')(
            q=config['train_loss']['args']['q'],
            k=config['train_loss']['args']['k'],
            trainset_size=num_examp,
            truncated=config['train_loss']['args']['truncated'])
    elif config['train_loss']['type'] == 'CCE_GTLoss':
        train_loss = getattr(module_loss, 'CCE_GTLoss')()

    val_loss = getattr(module_loss, config['val_loss'])
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # Build optimizer and learning-rate scheduler.
    # To disable the scheduler, delete every line that references lr_scheduler.
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = config.initialize('optimizer', torch.optim, [{'params': trainable_params}])
    lr_scheduler = config.initialize('lr_scheduler', torch.optim.lr_scheduler, optimizer)

    # ELR_GTLoss / SCE_GTLoss / GCE_GTLoss / CCE_GTLoss all share the same GroundTruthTrainer setup
    if config['train_loss']['type'] in ('ELR_GTLoss', 'SCE_GTLoss', 'GCE_GTLoss', 'CCE_GTLoss'):
        trainer = GroundTruthTrainer(model, train_loss, metrics, optimizer,
                                     config=config,
                                     data_loader=data_loader,
                                     teacher=teacher,
                                     valid_data_loader=valid_data_loader,
                                     test_data_loader=test_data_loader,
                                     lr_scheduler=lr_scheduler,
                                     val_criterion=val_loss,
                                     mode=parse.mode,
                                     entropy=parse.entropy,
                                     threshold=parse.threshold)

    trainer.train()
    logger = config.get_logger('trainer', config['trainer']['verbosity'])
    cfg_trainer = config['trainer']
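
# The entry points below (trainClothing1m, coteachingtrain) call fix_seed(config['seed']),
# which is imported from the project's utilities. As a reference only, the sketch below
# mirrors the inline seeding performed in gtrobustlosstrain above; the name
# _fix_seed_sketch is hypothetical, and the repo's own fix_seed is what the functions
# actually call.
def _fix_seed_sketch(seed: int) -> None:
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True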
def trainClothing1m(parse, config: ConfigParser):
    # implementation for WandB
    wandb_run_name_list = wandbRunlist(config, parse)
    wandb_run_name = '_'.join(wandb_run_name_list)

    if parse.no_wandb:
        wandb.init(config=config, project='noisylabel', entity='goguryeo', name=wandb_run_name)

    # By default, PyTorch uses multi-threaded CPU ops;
    # here the current thread count is kept unchanged.
    numthread = torch.get_num_threads()
    torch.set_num_threads(numthread)

    logger = config.get_logger('train')

    # Set seed for reproducibility
    fix_seed(config['seed'])

    data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=config['data_loader']['args']['batch_size'],
        shuffle=False if parse.distillation else config['data_loader']['args']['shuffle'],
        validation_split=0.0,
        num_batches=config['data_loader']['args']['num_batches'],
        training=True,
        num_workers=config['data_loader']['args']['num_workers'],
        pin_memory=config['data_loader']['args']['pin_memory'])

    # valid_data_loader = data_loader.split_validation()
    valid_data_loader = None

    # test_data_loader = None
    test_data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=128,
        shuffle=False,
        validation_split=0.0,
        training=False,
        num_workers=0).split_validation()

    print('---------')

    # build model architecture, then print to console
    # model = config.initialize('arch', module_arch)
    model = getattr(module_arch, 'resnet50')(pretrained=True, num_classes=config["num_classes"])

    if parse.no_wandb:
        wandb.watch(model)

    if parse.distillation:
        # Select presumably clean samples with the teacher, then rebuild the loader on them
        teacher = config.initialize('arch', module_arch)
        data_loader = getattr(module_data, config['data_loader']['type'])(
            config['data_loader']['args']['data_dir'],
            batch_size=config['data_loader']['args']['batch_size'],
            shuffle=config['data_loader']['args']['shuffle'],
            # validation_split=config['data_loader']['args']['validation_split'],
            validation_split=0.0,
            num_batches=config['data_loader']['args']['num_batches'],
            training=True,
            num_workers=config['data_loader']['args']['num_workers'],
            pin_memory=config['data_loader']['args']['pin_memory'],
            teacher_idx=extract_cleanidx(teacher, data_loader, parse))
    else:
        teacher = None

    # get function handles of loss and metrics
    logger.info(config.config)
    if hasattr(data_loader.dataset, 'num_raw_example'):
        num_examp = data_loader.dataset.num_raw_example
    else:
        num_examp = len(data_loader.dataset)

    if config['train_loss']['type'] == 'ELRLoss':
        train_loss = getattr(module_loss, 'ELRLoss')(
            num_examp=num_examp,
            num_classes=config['num_classes'],
            beta=config['train_loss']['args']['beta'])
    elif config['train_loss']['type'] == 'SCELoss':
        train_loss = getattr(module_loss, 'SCELoss')(
            alpha=config['train_loss']['args']['alpha'],
            beta=config['train_loss']['args']['beta'],
            num_classes=config['num_classes'])
    elif config['train_loss']['type'] == 'GCELoss':
        train_loss = getattr(module_loss, 'GCELoss')(
            q=config['train_loss']['args']['q'],
            k=config['train_loss']['args']['k'],
            trainset_size=num_examp,
            truncated=config['train_loss']['args']['truncated'])
    elif config['train_loss']['type'] == 'GTLoss':
        train_loss = getattr(module_loss, 'GTLoss')()
    else:
        train_loss = getattr(module_loss, 'CCELoss')()

    print(train_loss)
    val_loss = getattr(module_loss, config['val_loss'])
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # Build optimizer and learning-rate scheduler.
    # To disable the scheduler, delete every line that references lr_scheduler.
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = config.initialize('optimizer', torch.optim, [{'params': trainable_params}])
    lr_scheduler = config.initialize('lr_scheduler', torch.optim.lr_scheduler, optimizer)

    # Every loss type (ELR/SCE/GCE, truncated or not, and the CCE fallback)
    # uses the same RealDatasetTrainer setup
    trainer = RealDatasetTrainer(model, train_loss, metrics, optimizer,
                                 config=config,
                                 data_loader=data_loader,
                                 parse=parse,
                                 teacher=teacher,
                                 valid_data_loader=valid_data_loader,
                                 test_data_loader=test_data_loader,
                                 lr_scheduler=lr_scheduler,
                                 val_criterion=val_loss,
                                 mode=parse.mode,
                                 entropy=parse.entropy,
                                 threshold=parse.threshold)

    trainer.train()
    logger = config.get_logger('trainer', config['trainer']['verbosity'])
    cfg_trainer = config['trainer']
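
# trainClothing1m builds its backbone through module_arch's resnet50 wrapper. The actual
# wrapper lives in the repo's model package; the sketch below only illustrates one common
# way such a wrapper adapts an ImageNet-pretrained torchvision ResNet-50 to a different
# class count (e.g. 14 classes for Clothing1M). The name _resnet50_sketch is hypothetical.
def _resnet50_sketch(pretrained: bool = True, num_classes: int = 14):
    import torch.nn as nn
    import torchvision.models as tv_models

    # Load ImageNet weights (the `pretrained` flag is deprecated in newer torchvision
    # releases in favour of `weights=`, but still works on the versions this repo targets)
    backbone = tv_models.resnet50(pretrained=pretrained)
    # Replace the 1000-way ImageNet head with a num_classes-way classifier
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
    return backbone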
def coteachingtrain(parse, config: ConfigParser):
    # implementation for WandB
    wandb_run_name_list = wandbRunlist(config, parse)
    wandb_run_name = '_'.join(wandb_run_name_list)

    if parse.no_wandb:
        wandb.init(config=config, project='noisylabel', entity='goguryeo', name=wandb_run_name)

    # By default, PyTorch uses multi-threaded CPU ops;
    # here the current thread count is kept unchanged.
    numthread = torch.get_num_threads()
    torch.set_num_threads(numthread)

    logger = config.get_logger('train')

    # Set seed for reproducibility
    fix_seed(config['seed'])

    data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=config['data_loader']['args']['batch_size'],
        shuffle=False if parse.distillation else config['data_loader']['args']['shuffle'],
        validation_split=0.0,
        num_batches=config['data_loader']['args']['num_batches'],
        training=True,
        num_workers=config['data_loader']['args']['num_workers'],
        pin_memory=config['data_loader']['args']['pin_memory'],
        seed=parse.dataseed)  # parse.seed

    valid_data_loader = None

    test_data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=128,
        shuffle=False,
        validation_split=0.0,
        training=False,
        num_workers=2).split_validation()

    # build model architecture, then print to console
    model = config.initialize('arch', module_arch)

    if parse.no_wandb:
        wandb.watch(model)

    if parse.distillation:
        # Select presumably clean samples with the teacher, then rebuild the loader on them
        teacher = config.initialize('teacher_arch', module_arch)
        data_loader = getattr(module_data, config['data_loader']['type'])(
            config['data_loader']['args']['data_dir'],
            batch_size=config['data_loader']['args']['batch_size'],
            shuffle=config['data_loader']['args']['shuffle'],
            validation_split=0.0,
            num_batches=config['data_loader']['args']['num_batches'],
            training=True,
            num_workers=config['data_loader']['args']['num_workers'],
            pin_memory=config['data_loader']['args']['pin_memory'],
            seed=parse.dataseed,
            teacher_idx=extract_cleanidx(teacher, data_loader, parse))
    else:
        teacher = None

    # get function handles of loss and metrics
    logger.info(config.config)
    if hasattr(data_loader.dataset, 'num_raw_example'):
        num_examp = data_loader.dataset.num_raw_example
    else:
        num_examp = len(data_loader.dataset)

    # F-coteaching
    if config['train_loss']['type'] == 'CCELoss':
        train_loss = getattr(module_loss, 'CCELoss')()
    # coteaching
    elif config['train_loss']['type'] == 'CoteachingLoss':
        train_loss = getattr(module_loss, 'CoteachingLoss')(
            forget_rate=config['trainer']['percent'],
            num_gradual=int(config['train_loss']['args']['num_gradual']),
            n_epoch=config['trainer']['epochs'])
    # coteaching_plus
    elif config['train_loss']['type'] == 'CoteachingPlusLoss':
        train_loss = getattr(module_loss, 'CoteachingPlusLoss')(
            forget_rate=config['trainer']['percent'],
            num_gradual=int(config['train_loss']['args']['num_gradual']),
            n_epoch=config['trainer']['epochs'])

    val_loss = getattr(module_loss, config['val_loss'])
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # F-coteaching: a single network trained with plain CCE on the filtered data
    if config['train_loss']['type'] == 'CCELoss':
        model = config.initialize('arch', module_arch)
        trainer = FCoteachingTrainer(model, train_loss, metrics, None,
                                     config=config,
                                     data_loader=data_loader,
                                     parse=parse,
                                     teacher=teacher,
                                     valid_data_loader=valid_data_loader,
                                     test_data_loader=test_data_loader,
                                     lr_scheduler=None,
                                     val_criterion=val_loss,
                                     mode=parse.mode,
                                     entropy=parse.entropy,
                                     threshold=parse.threshold)

    # Co-teaching variants: two networks trained jointly.
    # CoteachingLoss / CoteachingPlusLoss / CoteachingDistillLoss / CoteachingPlusDistillLoss
    # all share the same two-model CoteachingTrainer setup.
    elif config['train_loss']['type'] in ('CoteachingLoss', 'CoteachingPlusLoss',
                                          'CoteachingDistillLoss', 'CoteachingPlusDistillLoss'):
        model1 = config.initialize('arch', module_arch)
        model2 = config.initialize('arch', module_arch)

        trainable_params1 = filter(lambda p: p.requires_grad, model1.parameters())
        trainable_params2 = filter(lambda p: p.requires_grad, model2.parameters())

        optimizer1 = config.initialize('optimizer', torch.optim, [{'params': trainable_params1}])
        optimizer2 = config.initialize('optimizer', torch.optim, [{'params': trainable_params2}])

        if isinstance(optimizer1, torch.optim.Adam):
            # With Adam, learning-rate decay is handled inside the trainer
            # (see the epoch_decay_start / learning_rate arguments below)
            lr_scheduler = None
        else:
            lr_scheduler1 = config.initialize('lr_scheduler', torch.optim.lr_scheduler, optimizer1)
            lr_scheduler2 = config.initialize('lr_scheduler', torch.optim.lr_scheduler, optimizer2)
            lr_scheduler = [lr_scheduler1, lr_scheduler2]

        trainer = CoteachingTrainer([model1, model2], train_loss, metrics,
                                    [optimizer1, optimizer2],
                                    config=config,
                                    data_loader=data_loader,
                                    parse=parse,
                                    teacher=teacher,
                                    valid_data_loader=valid_data_loader,
                                    test_data_loader=test_data_loader,
                                    lr_scheduler=lr_scheduler,
                                    val_criterion=val_loss,
                                    mode=parse.mode,
                                    entropy=parse.entropy,
                                    threshold=parse.threshold,
                                    epoch_decay_start=config['trainer']['epoch_decay_start'],
                                    n_epoch=config['trainer']['epochs'],
                                    learning_rate=config['optimizer']['args']['lr'])

    trainer.train()
    logger = config.get_logger('trainer', config['trainer']['verbosity'])
    cfg_trainer = config['trainer']
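
# CoteachingLoss above is constructed from forget_rate, num_gradual and n_epoch. In the
# standard Co-teaching recipe (Han et al., 2018) these define a drop-rate schedule: the
# fraction of large-loss samples discarded per epoch ramps linearly from 0 to forget_rate
# over the first num_gradual epochs and then stays constant. The repo's CoteachingLoss
# presumably follows a similar scheme; the helper below is only an illustrative sketch.
def _coteaching_rate_schedule(forget_rate: float, num_gradual: int, n_epoch: int):
    """Return the per-epoch fraction of samples to drop (length n_epoch)."""
    rate_schedule = np.ones(n_epoch) * forget_rate
    rate_schedule[:num_gradual] = np.linspace(0.0, forget_rate, num_gradual)
    return rate_schedule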