def __init__(self, configuration, pre_embed=None):
    configuration = deepcopy(configuration)
    self.configuration = deepcopy(configuration)

    configuration['model']['encoder']['pre_embed'] = pre_embed
    encoder_copy = deepcopy(configuration['model']['encoder'])
    self.Pencoder = Encoder.from_params(Params(configuration['model']['encoder'])).to(device)
    self.Qencoder = Encoder.from_params(Params(encoder_copy)).to(device)

    configuration['model']['decoder']['hidden_size'] = self.Pencoder.output_size
    self.decoder = AttnDecoderQA.from_params(Params(configuration['model']['decoder'])).to(device)

    self.bsize = configuration['training']['bsize']

    self.adversary_multi = AdversaryMulti(self.decoder)

    weight_decay = configuration['training'].get('weight_decay', 1e-5)
    self.params = list(self.Pencoder.parameters()) + list(self.Qencoder.parameters()) \
        + list(self.decoder.parameters())
    self.optim = torch.optim.Adam(self.params, weight_decay=weight_decay, amsgrad=True)
    # self.optim = torch.optim.Adagrad(self.params, lr=0.05, weight_decay=weight_decay)
    self.criterion = nn.CrossEntropyLoss()

    import time
    dirname = configuration['training']['exp_dirname']
    basepath = configuration['training'].get('basepath', 'outputs')
    self.time_str = time.ctime().replace(' ', '_')
    self.dirname = os.path.join(basepath, dirname, self.time_str)

    self.swa_settings = configuration['training']['swa']
    if self.swa_settings[0]:
        self.swa_all_optim = SWA(self.optim)
        self.running_norms = []

def main(args):
    np.random.seed(432)
    torch.random.manual_seed(432)

    try:
        os.makedirs(args.outpath)
    except OSError:
        pass
    experiment_path = utils.get_new_model_path(args.outpath)
    print(experiment_path)

    train_writer = SummaryWriter(os.path.join(experiment_path, 'train_logs'))
    val_writer = SummaryWriter(os.path.join(experiment_path, 'val_logs'))
    scheduler = cyclical_lr(5, 1e-5, 2e-3)
    trainer = train.Trainer(train_writer, val_writer, scheduler=scheduler)

    train_transform = data.build_preprocessing()
    eval_transform = data.build_preprocessing()

    trainds, evalds = data.build_dataset(args.datadir, None)
    trainds.transform = train_transform
    evalds.transform = eval_transform

    model = models.resnet34()
    base_opt = torch.optim.Adam(model.parameters())
    opt = SWA(base_opt, swa_start=30, swa_freq=10)

    trainloader = DataLoader(trainds, batch_size=args.batch_size, shuffle=True,
                             num_workers=8, pin_memory=True)
    evalloader = DataLoader(evalds, batch_size=args.batch_size, shuffle=False,
                            num_workers=16, pin_memory=True)

    export_path = os.path.join(experiment_path, 'last.pth')

    best_lwlrap = 0
    for epoch in range(args.epochs):
        print('Epoch {} - lr {:.6f}'.format(epoch, scheduler(epoch)))
        trainer.train_epoch(model, opt, trainloader, scheduler(epoch))
        metrics = trainer.eval_epoch(model, evalloader)
        print('Epoch: {} - lwlrap: {:.4f}'.format(epoch, metrics['lwlrap']))

        # save best model
        if metrics['lwlrap'] > best_lwlrap:
            best_lwlrap = metrics['lwlrap']
            torch.save(model.state_dict(), export_path)

    print('Best metrics {:.4f}'.format(best_lwlrap))
    opt.swap_swa_sgd()

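# Hedged note on main() above: opt.swap_swa_sgd() runs only after the best
# checkpoint has been written, so the exported weights are never the averaged
# ones. A sketch of an export step that saves the SWA weights too ('swa.pth'
# is a made-up filename; the arguments are the objects built inside main()):
def export_swa_weights(opt, model, trainloader, experiment_path):
    opt.swap_swa_sgd()                 # install the averaged weights
    opt.bn_update(trainloader, model)  # recompute BatchNorm running statistics
    torch.save(model.state_dict(), os.path.join(experiment_path, 'swa.pth'))
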
def set_parameters(self, parameters):
    self.parameters = tuple(parameters)
    self.optimizer = self.optimizer_cls(self.parameters, **self.optimizer_kwargs)
    if self.swa_start is not None:
        from torchcontrib.optim import SWA
        assert self.swa_freq is not None, self.swa_freq
        assert self.swa_lr is not None, self.swa_lr
        self.optimizer = SWA(self.optimizer, swa_start=self.swa_start,
                             swa_freq=self.swa_freq, swa_lr=self.swa_lr)

def reset_optimizer(self):
    self.base_optimizer = optimizers[self.optimizer_params['type']](
        self.net.parameters(), **self.optimizer_params['args'])
    if self.swa_params is not None:
        self.optimizer = SWA(self.base_optimizer, **self.swa_params)
        self.swa = True
        self.averaged_weights = False
    else:
        self.optimizer = self.base_optimizer
        self.swa = False

def configure_optimizers(self):
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # materialize the generator: named_parameters() can only be iterated once,
    # and it is consumed by both comprehensions below
    param_optimizer = list(self.model.named_parameters())
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.001,
        },
        {
            'params': [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=self.global_config.lr)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=self.global_config.warmup_steps,
        num_training_steps=self.total_steps(),
    )
    if self.global_config.swa:
        optimizer = SWA(optimizer, self.global_config.swa_start,
                        self.global_config.swa_freq, self.global_config.swa_lr)
    return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]

def build_optimizer(self, name: str, model: torch.nn.Module) -> torch.optim.Optimizer:
    """No bias decay:
    Bag of Tricks for Image Classification with Convolutional Neural Networks
    (https://arxiv.org/pdf/1812.01187.pdf)"""
    weight_p, bias_p = [], []
    for p_name, p in model.named_parameters():
        if 'bias' in p_name:
            bias_p += [p]
        else:
            weight_p += [p]
    # bias parameters are excluded from weight decay
    parameters = [
        {'params': weight_p, 'weight_decay': self.weight_decay},
        {'params': bias_p, 'weight_decay': 0},
    ]
    # pass the grouped parameters (not model.parameters()) so the
    # no-bias-decay split above actually takes effect
    if name == 'Adam':
        return torch.optim.Adam(parameters, lr=self.base_lr)
    if name == 'SGD':
        return torch.optim.SGD(parameters, lr=self.base_lr)
    if name == 'SWA':
        """Stochastic Weight Averaging:
        Averaging Weights Leads to Wider Optima and Better Generalization
        (https://arxiv.org/pdf/1803.05407.pdf)"""
        base_opt = torch.optim.SGD(parameters, lr=self.base_lr)
        return SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=self.base_lr)

def set_model_optimizer(self):
    """
    Set model optimizer based on user parameter selection
    1) Set SGD or Adam optimizer
    2) Set SWA if requested (check that you have installed the library: pip install torchcontrib)
    3) Print whether ZCA preprocessing is used (sometimes useful for CIFAR10) and whether debug
       mode is on or off (to check the model on the test set without taking decisions based on
       it -- all decisions are taken based on the validation set)
    """
    if self.args.optimizer == 'sgd':
        prRed('... SGD ...')
        optimizer = torch.optim.SGD(self.model.parameters(), self.args.lr,
                                    momentum=self.args.momentum,
                                    weight_decay=self.args.weight_decay,
                                    nesterov=self.args.nesterov)
    else:
        prRed('... Adam optimizer ...')
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)

    if self.args.swa:
        prRed('Using SWA!')
        from torchcontrib.optim import SWA
        # no swa_start/swa_freq given: SWA runs in manual mode, so the
        # training loop has to call optimizer.update_swa() itself
        optimizer = SWA(optimizer)

    self.model_optimizer = optimizer

    if self.args.use_zca:
        prPurple('*Use ZCA preprocessing*')
    if self.args.debug:
        prPurple('*Debug mode on*')

def init_SWA(self, optimizer):
    print("Using SWA")
    opt = SWA(
        optimizer,
        swa_start=self.config_dict["iters_per_epoch"] * 5,
        swa_freq=self.config_dict["iters_per_epoch"] * 2,
        swa_lr=self.config_dict["lr"] * 1e-1,
    )
    return opt

def _set_optimizer_scheduler(self):
    self.log('Initializing optimizer and scheduler.', direct_out=True)

    def is_backbone(n):
        return 'backbone' in n

    param_optimizer = list(self.model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    # use different learning rates for the backbone transformer and the classifier head
    if self.use_diff_lr:
        backbone_lr = self.config.lr * xm.xrt_world_size()
        head_lr = self.config.lr * xm.xrt_world_size() * 500
        optimizer_grouped_parameters = [
            # {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
            # {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
            {"params": [p for n, p in param_optimizer if is_backbone(n)], "lr": backbone_lr},
            {"params": [p for n, p in param_optimizer if not is_backbone(n)], "lr": head_lr},
        ]
        self.log(f'Different learning rate for backbone: {backbone_lr} head: {head_lr}')
    else:
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.001},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0},
        ]

    try:
        self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.config.lr * xm.xrt_world_size())
        # self.optimizer = SGD(optimizer_grouped_parameters, lr=self.config.lr * xm.xrt_world_size(), momentum=0.9)
    except Exception:
        # a parameter appearing in both groups would make the optimizer fail;
        # log the intersection to debug it
        param_g_1 = [p for n, p in param_optimizer if is_backbone(n)]
        param_g_2 = [p for n, p in param_optimizer if not is_backbone(n)]
        param_intersect = list(set(param_g_1) & set(param_g_2))
        self.log(f'intersect: {param_intersect}', direct_out=True)

    if self.use_SWA:
        self.optimizer = SWA(self.optimizer)

    if 'num_training_steps' in self.config.scheduler_params:
        num_training_steps = int(self.config.train_lenght / self.config.batch_size
                                 / xm.xrt_world_size() * self.config.n_epochs)
        self.log(f'Number of training steps: {num_training_steps}', direct_out=True)
        self.config.scheduler_params['num_training_steps'] = num_training_steps

    self.scheduler = self.config.SchedulerClass(self.optimizer, **self.config.scheduler_params)

def make_optimizer(self, max_steps):
    optimizer = OPTIMIZERS[self.config.train.optimizer]
    optimizer = optimizer(self.parameters(), self.config.train.learning_rate,
                          weight_decay=self.config.train.weight_decay)
    self.optimizer = SWA(optimizer, swa_start=int(0.8 * max_steps), swa_freq=100)
    self.scheduler = make_scheduler(self.config.train.scheduler, max_steps=max_steps)(optimizer)

def configure_optimizers(self):
    optim = next(o for o in dir(torch.optim) if o.lower() == FLAGS.optim.lower())  # e.g. "Adam"
    optimizer = getattr(torch.optim, optim)(self.parameters(), lr=FLAGS.learning_rate)
    # optimizer = torch.optim.SGD(self.parameters(), lr=FLAGS.learning_rate, weight_decay=0.00001)
    # NOTE: this hard-coded AdamW overrides the optimizer selected via FLAGS.optim above
    optimizer = torch.optim.AdamW(self.parameters(), lr=FLAGS.learning_rate, weight_decay=0.00001)
    if FLAGS.SWA:
        iterations_per_epoch = int(len(train_ds) / FLAGS.batch_size)
        optimizer = SWA(optimizer,
                        swa_start=int(FLAGS.swa_start * iterations_per_epoch),
                        swa_freq=50,
                        swa_lr=FLAGS.learning_rate / 10)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=FLAGS.lr_milestones,
                                                     gamma=0.5)
    # scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=FLAGS.learning_rate / 2,
    #                                               max_lr=2 * FLAGS.learning_rate,
    #                                               step_size_up=2, step_size_down=2,
    #                                               cycle_momentum=False)  # step sizes are in epochs
    return [optimizer], [scheduler]

from dataclasses import dataclass
from typing import List, Optional, Type

import torch
from torch.optim import AdamW, Optimizer
from torchcontrib.optim import SWA


# The @dataclass decorator (and these imports) are required for __post_init__
# to run; they appear to have been lost from the original snippet.
@dataclass
class ParamOptim:
    params: List[torch.Tensor]
    lr: float = 1e-3
    eps: float = 1e-8
    clip_grad: Optional[float] = None
    optimizer: Type[Optimizer] = AdamW

    def __post_init__(self):
        base_opt = self.optimizer(self.params, lr=self.lr, eps=self.eps)
        # no swa_start/swa_freq: manual mode, the caller must invoke update_swa()
        self.optim = SWA(base_opt)

    def set_lr(self, lr):
        for pg in self.optim.param_groups:
            pg['lr'] = lr
        return lr

    def step(self, loss):
        self.optim.zero_grad()
        loss.backward()
        if self.clip_grad is not None:
            torch.nn.utils.clip_grad_norm_(self.params, self.clip_grad)
        self.optim.step()
        return loss

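# Hedged usage sketch for ParamOptim: SWA(base_opt) with no swa_start/swa_freq is
# torchcontrib's *manual* mode, so nothing is averaged unless update_swa() is
# called explicitly. `model`, `loader` and `loss_fn` are stand-in arguments,
# not names from the snippets in this file.
def fit_with_manual_swa(model, loader, loss_fn, update_every=100):
    po = ParamOptim(params=list(model.parameters()))
    for step, (x, y) in enumerate(loader):
        po.step(loss_fn(model(x), y))
        if step % update_every == 0:
            po.optim.update_swa()  # fold the current weights into the running average
    po.optim.swap_swa_sgd()        # install the averaged weights before evaluation
    return model
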
def configure_optimizers(self):
    if self.hparams['optimizer_name'] == 'adam':
        opt = torch.optim.Adam(self.parameters(), lr=self.lr)
        return opt
    elif self.hparams['optimizer_name'] == 'rmsprop':
        opt = torch.optim.RMSprop(self.parameters(), lr=self.hparams['lr'], momentum=.001)
        return opt  # this return was missing in the original branch
    elif self.hparams['optimizer_name'] == 'swa':
        opt = torch.optim.Adam(self.parameters(), lr=self.hparams['lr'])
        return SWA(opt, swa_start=100, swa_freq=50, swa_lr=self.hparams['lr'])

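# Caveat on the 'swa' branch above (an assumption about the surrounding code, not
# something shown in it): Lightning does not know about torchcontrib's wrapper,
# so nothing ever calls swap_swa_sgd(). One hedged way to handle it is a hook on
# the same LightningModule:
def on_train_end(self):
    opt = self.trainer.optimizers[0]
    if isinstance(opt, SWA):
        opt.swap_swa_sgd()  # copy the SWA running average into the model weights
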
def _create_optimizer(self, sgd):
    optimizer = AdamW(
        self._model.parameters(),
        lr=getattr(sgd, "pytt_lr", sgd.alpha),
        eps=sgd.eps,
        betas=(sgd.b1, sgd.b2),
        weight_decay=getattr(sgd, "pytt_weight_decay", 0.0),
    )
    if getattr(sgd, "pytt_use_swa", False):
        optimizer = SWA(optimizer, swa_start=1, swa_freq=10, swa_lr=sgd.alpha)
    optimizer.zero_grad()
    return optimizer

def make_optimizer(cfg, model):
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if "bias" in key:
            lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
    if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(
            params, momentum=cfg.SOLVER.MOMENTUM)
        # wrap the SGD optimizer with SWA for the training loop
        optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
    else:
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
    return optimizer

def _dist_train(model, dataset, cfg, validate=False):
    # prepare data loaders
    data_loaders = [
        build_dataloader(dataset, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu, dist=True)
    ]
    # put model on gpus
    model = MMDistributedDataParallel(model.cuda())
    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    if cfg.swa is not None:
        optimizer = SWA(optimizer, cfg.swa['swa_start'], cfg.swa['swa_freq'], cfg.swa['swa_lr'])
    runner = Runner(model, batch_processor, optimizer, cfg.work_dir, cfg.log_level)
    # register hooks
    optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    runner.register_hook(DistSamplerSeedHook())
    # register eval hooks
    if validate:
        val_dataset_cfg = cfg.data.val
        if isinstance(model.module, RPN):
            # TODO: implement recall hooks for other datasets
            runner.register_hook(CocoDistEvalRecallHook(val_dataset_cfg))
        else:
            dataset_type = getattr(datasets, val_dataset_cfg.type)
            if issubclass(dataset_type, datasets.CocoDataset):
                runner.register_hook(CocoDistEvalmAPHook(val_dataset_cfg))
            else:
                runner.register_hook(DistEvalmAPHook(val_dataset_cfg))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)

# Fetch the loss
from model.loss_functions.mmd_loss import loss_function
loss_fn = loss_function

# Load the VAE model
model = VariationalAutoencoder(params).cuda() if params.cuda else VariationalAutoencoder(params)

if args.swa:
    # Use the Adam optimizer, wrapped in SWA
    base_optimizer = optim.Adam(model.parameters(), lr=1e-3, eps=params.eps,
                                betas=(params.betas[0], params.betas[1]),
                                weight_decay=params.weight_decay)
    optimizer = SWA(base_optimizer, swa_start=10, swa_freq=5, swa_lr=1e-3)
else:
    # Use the Adam optimizer
    optimizer = optim.Adam(model.parameters(), lr=params.learning_rate, eps=params.eps,
                           betas=(params.betas[0], params.betas[1]),
                           weight_decay=params.weight_decay)

# Train the model
train(model, train_dl, args.dataloader, optimizer, loss_fn, params, model_dir,
      args.swa, args.restore_file)

def main(args, logger):
    # trn_df = pd.read_csv(f'{MNT_DIR}/inputs/origin/train.csv')
    trn_df = pd.read_pickle(f'{MNT_DIR}/inputs/nes_info/trn_df.pkl')
    trn_df['is_original'] = 1

    gkf = GroupKFold(n_splits=5).split(
        X=trn_df.question_body,
        groups=trn_df.question_body_le,
    )

    histories = {
        'trn_loss': {},
        'val_loss': {},
        'val_metric': {},
        'val_metric_raws': {},
    }
    loaded_fold = -1
    loaded_epoch = -1
    if args.checkpoint:
        histories, loaded_fold, loaded_epoch = load_checkpoint(args.checkpoint)

    fold_best_metrics = []
    fold_best_metrics_raws = []
    for fold, (trn_idx, val_idx) in enumerate(gkf):
        if fold < loaded_fold:
            fold_best_metrics.append(np.max(histories["val_metric"][fold]))
            fold_best_metrics_raws.append(
                histories["val_metric_raws"][fold][np.argmax(histories["val_metric"][fold])])
            continue
        sel_log(f' --------------------------- start fold {fold} --------------------------- ',
                logger)
        fold_trn_df = trn_df.iloc[trn_idx]  # .query('is_original == 1')
        fold_trn_df = fold_trn_df.drop(['is_original', 'question_body_le'], axis=1)
        # use only original rows for validation
        fold_val_df = trn_df.iloc[val_idx].query('is_original == 1')
        fold_val_df = fold_val_df.drop(['is_original', 'question_body_le'], axis=1)
        if args.debug:
            fold_trn_df = fold_trn_df.sample(100, random_state=71)
            fold_val_df = fold_val_df.sample(100, random_state=71)
        temp = pd.Series(list(itertools.chain.from_iterable(
            fold_trn_df.question_title.apply(lambda x: x.split(' '))
            + fold_trn_df.question_body.apply(lambda x: x.split(' '))
            + fold_trn_df.answer.apply(lambda x: x.split(' '))
        ))).value_counts()
        tokens = temp[temp >= 10].index.tolist()
        # tokens = []
        tokens = [
            'CAT_TECHNOLOGY'.casefold(),
            'CAT_STACKOVERFLOW'.casefold(),
            'CAT_CULTURE'.casefold(),
            'CAT_SCIENCE'.casefold(),
            'CAT_LIFE_ARTS'.casefold(),
        ]

        trn_dataset = QUESTDataset(
            df=fold_trn_df,
            mode='train',
            tokens=tokens,
            augment=[],
            tokenizer_type=TOKENIZER_TYPE,
            pretrained_model_name_or_path=TOKENIZER_PRETRAIN,
            do_lower_case=True,
            LABEL_COL=LABEL_COL,
            t_max_len=30,
            q_max_len=239 * 2,
            a_max_len=239 * 0,
            tqa_mode=TQA_MODE,
            TBSEP='[TBSEP]',
            pos_id_type='arange',
            MAX_SEQUENCE_LENGTH=MAX_SEQ_LEN,
            rm_zero=RM_ZERO,
        )
        # update token
        trn_sampler = RandomSampler(data_source=trn_dataset)
        trn_loader = DataLoader(trn_dataset,
                                batch_size=BATCH_SIZE,
                                sampler=trn_sampler,
                                num_workers=os.cpu_count(),
                                worker_init_fn=lambda x: np.random.seed(),
                                drop_last=True,
                                pin_memory=True)

        val_dataset = QUESTDataset(
            df=fold_val_df,
            mode='valid',
            tokens=tokens,
            augment=[],
            tokenizer_type=TOKENIZER_TYPE,
            pretrained_model_name_or_path=TOKENIZER_PRETRAIN,
            do_lower_case=True,
            LABEL_COL=LABEL_COL,
            t_max_len=30,
            q_max_len=239 * 2,
            a_max_len=239 * 0,
            tqa_mode=TQA_MODE,
            TBSEP='[TBSEP]',
            pos_id_type='arange',
            MAX_SEQUENCE_LENGTH=MAX_SEQ_LEN,
            rm_zero=RM_ZERO,
        )
        val_sampler = RandomSampler(data_source=val_dataset)
        val_loader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                sampler=val_sampler,
                                num_workers=os.cpu_count(),
                                worker_init_fn=lambda x: np.random.seed(),
                                drop_last=False,
                                pin_memory=True)

        fobj = BCEWithLogitsLoss()
        state_dict = BertModel.from_pretrained(MODEL_PRETRAIN).state_dict()
        model = BertModelForBinaryMultiLabelClassifier(
            num_labels=len(LABEL_COL),
            config_path=MODEL_CONFIG_PATH,
            state_dict=state_dict,
            token_size=len(trn_dataset.tokenizer),
            MAX_SEQUENCE_LENGTH=MAX_SEQ_LEN,
        )
        # optimizer = optim.Adam(model.parameters(), lr=3e-5)
        optimizer = optim.SGD(model.parameters(), lr=1e-1)
        optimizer = SWA(optimizer, swa_start=2, swa_freq=5, swa_lr=1e-1)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=MAX_EPOCH, eta_min=1e-2)

        # load checkpoint model, optim, scheduler
        if args.checkpoint and fold == loaded_fold:
            load_checkpoint(args.checkpoint, model, optimizer, scheduler)

        for epoch in tqdm(list(range(MAX_EPOCH))):
            if fold <= loaded_fold and epoch <= loaded_epoch:
                continue
            if epoch < 1:
                model.freeze_unfreeze_bert(freeze=True, logger=logger)
            else:
                model.freeze_unfreeze_bert(freeze=False, logger=logger)
            model = DataParallel(model)
            model = model.to(DEVICE)
            trn_loss = train_one_epoch(model, fobj, optimizer, trn_loader, DEVICE)
            if epoch > 2:
                # evaluate with the averaged weights, then swap back for training
                optimizer.swap_swa_sgd()
                optimizer.bn_update(trn_loader, model)
            val_loss, val_metric, val_metric_raws, val_y_preds, val_y_trues, val_qa_ids = test(
                model, fobj, val_loader, DEVICE, mode='valid')
            if epoch > 2:
                optimizer.swap_swa_sgd()
            scheduler.step()

            if fold in histories['trn_loss']:
                histories['trn_loss'][fold].append(trn_loss)
            else:
                histories['trn_loss'][fold] = [trn_loss]
            if fold in histories['val_loss']:
                histories['val_loss'][fold].append(val_loss)
            else:
                histories['val_loss'][fold] = [val_loss]
            if fold in histories['val_metric']:
                histories['val_metric'][fold].append(val_metric)
            else:
                histories['val_metric'][fold] = [val_metric]
            if fold in histories['val_metric_raws']:
                histories['val_metric_raws'][fold].append(val_metric_raws)
            else:
                histories['val_metric_raws'][fold] = [val_metric_raws]

            logging_val_metric_raws = ''
            for val_metric_raw in val_metric_raws:
                logging_val_metric_raws += f'{float(val_metric_raw):.4f}, '

            sel_log(
                f'fold : {fold} -- epoch : {epoch} -- '
                f'trn_loss : {float(trn_loss.detach().to("cpu").numpy()):.4f} -- '
                f'val_loss : {float(val_loss.detach().to("cpu").numpy()):.4f} -- '
                f'val_metric : {float(val_metric):.4f} -- '
                f'val_metric_raws : {logging_val_metric_raws}', logger)

            model = model.to('cpu')
            model = model.module
            save_checkpoint(f'{MNT_DIR}/checkpoints/{EXP_ID}/{fold}',
                            model, optimizer, scheduler, histories,
                            val_y_preds, val_y_trues, val_qa_ids,
                            fold, epoch, val_loss, val_metric)

        fold_best_metrics.append(np.max(histories["val_metric"][fold]))
        fold_best_metrics_raws.append(
            histories["val_metric_raws"][fold][np.argmax(histories["val_metric"][fold])])
        save_and_clean_for_prediction(f'{MNT_DIR}/checkpoints/{EXP_ID}/{fold}',
                                      trn_dataset.tokenizer, clean=False)
        del model

    # calc training stats
    fold_best_metric_mean = np.mean(fold_best_metrics)
    fold_best_metric_std = np.std(fold_best_metrics)
    fold_stats = f'{EXP_ID} : {fold_best_metric_mean:.4f} +- {fold_best_metric_std:.4f}'
    sel_log(fold_stats, logger)
    send_line_notification(fold_stats)

    fold_best_metrics_raws_mean = np.mean(fold_best_metrics_raws, axis=0)
    fold_raw_stats = ''
    for metric_stats_raw in fold_best_metrics_raws_mean:
        fold_raw_stats += f'{float(metric_stats_raw):.4f},'
    sel_log(fold_raw_stats, logger)
    send_line_notification(fold_raw_stats)

    sel_log('now saving best checkpoints...', logger)

def get_default_optimizer():
    base_opt = torch.optim.SGD(model.parameters(), momentum=.9, lr=1e-1)
    optimizer = SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)
    return optimizer

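# Hedged usage sketch for an SWA-wrapped optimizer like the one returned above,
# following torchcontrib's documented auto-mode pattern. `model`, `loader` and
# `loss_fn` are stand-in arguments, not names from the snippets in this file.
def train_with_swa(model, loader, loss_fn, epochs=100):
    opt = get_default_optimizer()
    for _ in range(epochs):
        for inputs, targets in loader:
            opt.zero_grad()
            loss_fn(model(inputs), targets).backward()
            opt.step()  # after swa_start, every swa_freq-th step updates the average
    opt.swap_swa_sgd()            # copy the running SWA average into the model
    opt.bn_update(loader, model)  # recompute BatchNorm statistics for the averaged weights
    return model
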
class Model():
    def __init__(self, configuration, pre_embed=None):
        configuration = deepcopy(configuration)
        self.configuration = deepcopy(configuration)

        configuration['model']['encoder']['pre_embed'] = pre_embed
        self.encoder = Encoder.from_params(Params(configuration['model']['encoder'])).to(device)

        configuration['model']['decoder']['hidden_size'] = self.encoder.output_size
        self.decoder = AttnDecoder.from_params(Params(configuration['model']['decoder'])).to(device)

        self.encoder_params = list(self.encoder.parameters())
        self.attn_params = list(
            [v for k, v in self.decoder.named_parameters() if 'attention' in k])
        self.decoder_params = list(
            [v for k, v in self.decoder.named_parameters() if 'attention' not in k])

        self.bsize = configuration['training']['bsize']
        weight_decay = configuration['training'].get('weight_decay', 1e-5)

        self.encoder_optim = torch.optim.Adam(self.encoder_params, lr=0.001,
                                              weight_decay=weight_decay, amsgrad=True)
        self.attn_optim = torch.optim.Adam(self.attn_params, lr=0.001,
                                           weight_decay=0, amsgrad=True)
        self.decoder_optim = torch.optim.Adam(self.decoder_params, lr=0.001,
                                              weight_decay=weight_decay, amsgrad=True)
        self.adversarymulti = AdversaryMulti(decoder=self.decoder)

        self.all_params = self.encoder_params + self.attn_params + self.decoder_params
        self.all_optim = torch.optim.Adam(self.all_params, lr=0.001,
                                          weight_decay=weight_decay, amsgrad=True)
        # self.all_optim = adagrad.Adagrad(self.all_params, weight_decay=weight_decay)

        pos_weight = configuration['training'].get('pos_weight', [1.0] * self.decoder.output_size)
        self.pos_weight = torch.Tensor(pos_weight).to(device)
        self.criterion = nn.BCEWithLogitsLoss(reduction='none').to(device)

        self.swa_settings = configuration['training']['swa']

        import time
        dirname = configuration['training']['exp_dirname']
        basepath = configuration['training'].get('basepath', 'outputs')
        self.time_str = time.ctime().replace(' ', '_')
        self.dirname = os.path.join(basepath, dirname, self.time_str)
        self.temperature = configuration['training']['temperature']
        self.train_losses = []

        if self.swa_settings[0]:
            # self.attn_optim = SWA(self.attn_optim, swa_start=3, swa_freq=1, swa_lr=0.05)
            # self.decoder_optim = SWA(self.decoder_optim, swa_start=3, swa_freq=1, swa_lr=0.05)
            # self.encoder_optim = SWA(self.encoder_optim, swa_start=3, swa_freq=1, swa_lr=0.05)
            self.swa_all_optim = SWA(self.all_optim)
            self.running_norms = []

    @classmethod
    def init_from_config(cls, dirname, **kwargs):
        config = json.load(open(dirname + '/config.json', 'r'))
        config.update(kwargs)
        obj = cls(config)
        obj.load_values(dirname)
        return obj

    def get_param_buffer_norms(self):
        for p in self.swa_all_optim.param_groups[0]['params']:
            param_state = self.swa_all_optim.state[p]
            if 'swa_buffer' not in param_state:
                self.swa_all_optim.update_swa()
        norms = []
        # for p in np.array(self.swa_all_optim.param_groups[0]['params'])[[1, 2, 5, 6, 9]]:
        for p in np.array(self.swa_all_optim.param_groups[0]['params'])[[6, 9]]:
            param_state = self.swa_all_optim.state[p]
            buf = np.squeeze(param_state['swa_buffer'].cpu().numpy())
            cur_state = np.squeeze(p.data.cpu().numpy())
            norm = np.linalg.norm(buf - cur_state)
            norms.append(norm)
        if self.swa_settings[3] == 2:
            return np.max(norms)
        return np.mean(norms)

    def total_iter_num(self):
        return self.swa_all_optim.param_groups[0]['step_counter']

    def iter_for_swa_update(self, iter_num):
        return iter_num > self.swa_settings[1] \
            and iter_num % self.swa_settings[2] == 0

    def check_and_update_swa(self):
        if self.iter_for_swa_update(self.total_iter_num()):
            cur_step_diff_norm = self.get_param_buffer_norms()
            if self.swa_settings[3] == 0:
                self.swa_all_optim.update_swa()
                return
            if not self.running_norms:
                running_mean_norm = 0
            else:
                running_mean_norm = np.mean(self.running_norms)
            if cur_step_diff_norm > running_mean_norm:
                self.swa_all_optim.update_swa()
                self.running_norms = [cur_step_diff_norm]
            elif cur_step_diff_norm > 0:
                self.running_norms.append(cur_step_diff_norm)

    def train(self, data_in, target_in, train=True):
        sorting_idx = get_sorting_index_with_noise_from_lengths(
            [len(x) for x in data_in], noise_frac=0.1)
        data = [data_in[i] for i in sorting_idx]
        target = [target_in[i] for i in sorting_idx]

        self.encoder.train()
        self.decoder.train()
        bsize = self.bsize
        N = len(data)
        loss_total = 0

        batches = list(range(0, N, bsize))
        batches = shuffle(batches)

        for n in tqdm(batches):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            batch_target = target[n:n + bsize]
            batch_target = torch.Tensor(batch_target).to(device)

            if len(batch_target.shape) == 1:  # (B, )
                batch_target = batch_target.unsqueeze(-1)  # (B, 1)

            bce_loss = self.criterion(batch_data.predict / self.temperature, batch_target)
            weight = batch_target * self.pos_weight + (1 - batch_target)
            bce_loss = (bce_loss * weight).mean(1).sum()

            loss = bce_loss
            self.train_losses.append(bce_loss.detach().cpu().numpy() + 0)

            if hasattr(batch_data, 'reg_loss'):
                loss += batch_data.reg_loss

            if train:
                if self.swa_settings[0]:
                    self.check_and_update_swa()
                    self.swa_all_optim.zero_grad()
                    loss.backward()
                    self.swa_all_optim.step()
                else:
                    # self.encoder_optim.zero_grad()
                    # self.decoder_optim.zero_grad()
                    # self.attn_optim.zero_grad()
                    self.all_optim.zero_grad()
                    loss.backward()
                    # self.encoder_optim.step()
                    # self.decoder_optim.step()
                    # self.attn_optim.step()
                    self.all_optim.step()
            loss_total += float(loss.data.cpu().item())

        if self.swa_settings[0] and \
                self.swa_all_optim.param_groups[0]['step_counter'] > self.swa_settings[1]:
            print("\nSWA swapping\n")
            # self.attn_optim.swap_swa_sgd()
            # self.encoder_optim.swap_swa_sgd()
            # self.decoder_optim.swap_swa_sgd()
            self.swa_all_optim.swap_swa_sgd()
            self.running_norms = []
        return loss_total * bsize / N

    def predictor(self, inp_text_permutations):
        text_permutations = [dataset_vec.map2idxs(x.split()) for x in inp_text_permutations]
        outputs = []
        bsize = 512
        N = len(text_permutations)
        for n in range(0, N, bsize):
            torch.cuda.empty_cache()
            batch_doc = text_permutations[n:n + bsize]
            batch_data = BatchHolder(batch_doc)
            self.encoder(batch_data)
            self.decoder(batch_data)
            batch_data.predict = torch.sigmoid(batch_data.predict)
            pred = batch_data.predict.cpu().data.numpy()
            for i in range(len(pred)):
                if math.isnan(pred[i][0]):
                    pred[i][0] = 0.5
            outputs.extend(pred)
        ret_val = [[output_i[0], 1 - output_i[0]] for output_i in outputs]
        ret_val = np.array(ret_val)
        return ret_val

    def evaluate(self, data):
        self.encoder.eval()
        self.decoder.eval()
        bsize = self.bsize
        N = len(data)

        outputs = []
        attns = []
        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            batch_data.predict = torch.sigmoid(batch_data.predict / self.temperature)
            if self.decoder.use_attention:
                attn = batch_data.attn.cpu().data.numpy()
                attns.append(attn)

            predict = batch_data.predict.cpu().data.numpy()
            outputs.append(predict)

        outputs = [x for y in outputs for x in y]
        if self.decoder.use_attention:
            attns = [x for y in attns for x in y]
        return outputs, attns

    def get_lime_explanations(self, data):
        explanations = []
        explainer = LimeTextExplainer(class_names=["A", "B"])
        for data_i in data:
            sentence = ' '.join(dataset_vec.map2words(data_i))
            exp = explainer.explain_instance(text_instance=sentence,
                                             classifier_fn=self.predictor,
                                             num_features=len(data_i),
                                             num_samples=5000).as_list()
            explanations.append(exp)
        return explanations

    def gradient_mem(self, data):
        self.encoder.train()
        self.decoder.train()
        bsize = self.bsize
        N = len(data)

        grads = {'XxE': [], 'XxE[X]': [], 'H': []}

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]

            grads_xxe = []
            grads_xxex = []
            grads_H = []

            for i in range(self.decoder.output_size):
                batch_data = BatchHolder(batch_doc)
                batch_data.keep_grads = True
                batch_data.detach = True

                self.encoder(batch_data)
                self.decoder(batch_data)

                torch.sigmoid(batch_data.predict[:, i]).sum().backward()

                g = batch_data.embedding.grad
                em = batch_data.embedding
                g1 = (g * em).sum(-1)
                grads_xxex.append(g1.cpu().data.numpy())

                g1 = (g * self.encoder.embedding.weight.sum(0)).sum(-1)
                grads_xxe.append(g1.cpu().data.numpy())

                g1 = batch_data.hidden.grad.sum(-1)
                grads_H.append(g1.cpu().data.numpy())

            grads_xxe = np.array(grads_xxe).swapaxes(0, 1)
            grads_xxex = np.array(grads_xxex).swapaxes(0, 1)
            grads_H = np.array(grads_H).swapaxes(0, 1)

            # (a leftover ipdb.set_trace() debugging breakpoint was removed here)
            grads['XxE'].append(grads_xxe)
            grads['XxE[X]'].append(grads_xxex)
            grads['H'].append(grads_H)

        for k in grads:
            grads[k] = [x for y in grads[k] for x in y]

        return grads

    def remove_and_run(self, data):
        self.encoder.train()
        self.decoder.train()
        bsize = self.bsize
        N = len(data)
        outputs = []

        for n in tqdm(range(0, N, bsize)):
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)
            po = np.zeros((batch_data.B, batch_data.maxlen, self.decoder.output_size))

            for i in range(1, batch_data.maxlen - 1):
                batch_data = BatchHolder(batch_doc)
                batch_data.seq = torch.cat(
                    [batch_data.seq[:, :i], batch_data.seq[:, i + 1:]], dim=-1)
                batch_data.lengths = batch_data.lengths - 1
                batch_data.masks = torch.cat(
                    [batch_data.masks[:, :i], batch_data.masks[:, i + 1:]], dim=-1)

                self.encoder(batch_data)
                self.decoder(batch_data)

                po[:, i] = torch.sigmoid(batch_data.predict).cpu().data.numpy()

            outputs.append(po)

        outputs = [x for y in outputs for x in y]
        return outputs

    def permute_attn(self, data, num_perm=100):
        self.encoder.train()
        self.decoder.train()
        bsize = self.bsize
        N = len(data)
        permutations = []

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)
            batch_perms = np.zeros((batch_data.B, num_perm, self.decoder.output_size))

            self.encoder(batch_data)
            self.decoder(batch_data)

            for i in range(num_perm):
                batch_data.permute = True
                self.decoder(batch_data)
                output = torch.sigmoid(batch_data.predict)
                batch_perms[:, i] = output.cpu().data.numpy()

            permutations.append(batch_perms)

        permutations = [x for y in permutations for x in y]
        return permutations

    def save_values(self, use_dirname=None, save_model=True, append_to_dir_name=''):
        if use_dirname is not None:
            dirname = use_dirname
        else:
            dirname = self.dirname + append_to_dir_name
        self.last_epch_dirname = dirname
        os.makedirs(dirname, exist_ok=True)
        shutil.copy2(file_name, dirname + '/')
        json.dump(self.configuration, open(dirname + '/config.json', 'w'))
        if save_model:
            torch.save(self.encoder.state_dict(), dirname + '/enc.th')
            torch.save(self.decoder.state_dict(), dirname + '/dec.th')
        return dirname

    def load_values(self, dirname):
        self.encoder.load_state_dict(
            torch.load(dirname + '/enc.th', map_location={'cuda:1': 'cuda:0'}))
        self.decoder.load_state_dict(
            torch.load(dirname + '/dec.th', map_location={'cuda:1': 'cuda:0'}))

    def adversarial_multi(self, data):
        self.encoder.eval()
        self.decoder.eval()

        for p in self.encoder.parameters():
            p.requires_grad = False
        for p in self.decoder.parameters():
            p.requires_grad = False

        bsize = self.bsize
        N = len(data)

        adverse_attn = []
        adverse_output = []

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)
            self.adversarymulti(batch_data)

            attn_volatile = batch_data.attn_volatile.cpu().data.numpy()  # (B, 10, L)
            predict_volatile = batch_data.predict_volatile.cpu().data.numpy()  # (B, 10, O)
            adverse_attn.append(attn_volatile)
            adverse_output.append(predict_volatile)

        adverse_output = [x for y in adverse_output for x in y]
        adverse_attn = [x for y in adverse_attn for x in y]
        return adverse_output, adverse_attn

    def logodds_attention(self, data, logodds_map: Dict):
        self.encoder.eval()
        self.decoder.eval()

        bsize = self.bsize
        N = len(data)

        adverse_attn = []
        adverse_output = []

        logodds = np.zeros((self.encoder.vocab_size, ))
        for k, v in logodds_map.items():
            if v is not None:
                logodds[k] = abs(v)
            else:
                logodds[k] = float('-inf')
        logodds = torch.Tensor(logodds).to(device)

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            attn = batch_data.attn  # (B, L)
            batch_data.attn_logodds = logodds[batch_data.seq]
            self.decoder.get_output_from_logodds(batch_data)

            attn_volatile = batch_data.attn_volatile.cpu().data.numpy()  # (B, L)
            predict_volatile = torch.sigmoid(
                batch_data.predict_volatile).cpu().data.numpy()  # (B, O)
            adverse_attn.append(attn_volatile)
            adverse_output.append(predict_volatile)

        adverse_output = [x for y in adverse_output for x in y]
        adverse_attn = [x for y in adverse_attn for x in y]
        return adverse_output, adverse_attn

    def logodds_substitution(self, data, top_logodds_words: Dict):
        self.encoder.eval()
        self.decoder.eval()

        bsize = self.bsize
        N = len(data)

        adverse_X = []
        adverse_attn = []
        adverse_output = []

        words_neg = torch.Tensor(top_logodds_words[0][0]).long().cuda().unsqueeze(0)
        words_pos = torch.Tensor(top_logodds_words[0][1]).long().cuda().unsqueeze(0)
        words_to_select = torch.cat([words_neg, words_pos], dim=0)  # (2, 5)

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            predict_class = (torch.sigmoid(batch_data.predict).squeeze(-1) > 0.5) * 1  # (B,)
            attn = batch_data.attn  # (B, L)
            top_val, top_idx = torch.topk(attn, 5, dim=-1)
            subs_words = words_to_select[1 - predict_class.long()]  # (B, 5)
            batch_data.seq.scatter_(1, top_idx, subs_words)

            self.encoder(batch_data)
            self.decoder(batch_data)

            attn_volatile = batch_data.attn.cpu().data.numpy()  # (B, L)
            predict_volatile = torch.sigmoid(batch_data.predict).cpu().data.numpy()  # (B, O)
            X_volatile = batch_data.seq.cpu().data.numpy()

            adverse_X.append(X_volatile)
            adverse_attn.append(attn_volatile)
            adverse_output.append(predict_volatile)

        adverse_X = [x for y in adverse_X for x in y]
        adverse_output = [x for y in adverse_output for x in y]
        adverse_attn = [x for y in adverse_attn for x in y]
        return adverse_output, adverse_attn, adverse_X

    def predict(self, batch_data, lengths, masks):
        batch_holder = BatchHolderIndentity(batch_data, lengths, masks)
        self.encoder(batch_holder)
        self.decoder(batch_holder)
        # batch_holder.predict = torch.sigmoid(batch_holder.predict)
        predict = batch_holder.predict
        return predict

class HM:
    def __init__(self):
        if args.train is not None:
            self.train_tuple = get_tuple(args.train, bs=args.batch_size,
                                         shuffle=True, drop_last=False)
        if args.valid is not None:
            valid_bsize = 2048 if args.multiGPU else 50
            self.valid_tuple = get_tuple(args.valid, bs=valid_bsize,
                                         shuffle=False, drop_last=False)
        else:
            self.valid_tuple = None

        # Select model; X is the default
        if args.model == "X":
            self.model = ModelX(args)
        elif args.model == "V":
            self.model = ModelV(args)
        elif args.model == "U":
            self.model = ModelU(args)
        elif args.model == "D":
            self.model = ModelD(args)
        elif args.model == 'O':
            self.model = ModelO(args)
        else:
            print(args.model, "is not implemented.")

        # Load pre-trained weights from paths
        if args.loadpre is not None:
            self.model.load(args.loadpre)

        # GPU options
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()
        self.model = self.model.cuda()

        # Losses and optimizer
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.nllloss = nn.NLLLoss()

        if args.train is not None:
            batch_per_epoch = len(self.train_tuple.loader)
            self.t_total = int(batch_per_epoch * args.epochs // args.acc)
            print("Total Iters: %d" % self.t_total)

        def is_backbone(n):
            if "encoder" in n:
                return True
            elif "embeddings" in n:
                return True
            elif "pooler" in n:
                return True
            print("F: ", n)
            return False

        no_decay = ['bias', 'LayerNorm.weight']
        params = list(self.model.named_parameters())
        if args.reg:
            optimizer_grouped_parameters = [
                {"params": [p for n, p in params if is_backbone(n)], "lr": args.lr},
                {"params": [p for n, p in params if not is_backbone(n)], "lr": args.lr * 500},
            ]
            for n, p in self.model.named_parameters():
                print(n)
            self.optim = AdamW(optimizer_grouped_parameters, lr=args.lr)
        else:
            optimizer_grouped_parameters = [
                {'params': [p for n, p in params if not any(nd in n for nd in no_decay)],
                 'weight_decay': args.wd},
                {'params': [p for n, p in params if any(nd in n for nd in no_decay)],
                 'weight_decay': 0.0},
            ]
            self.optim = AdamW(optimizer_grouped_parameters, lr=args.lr)

        if args.train is not None:
            self.scheduler = get_linear_schedule_with_warmup(
                self.optim, self.t_total * 0.1, self.t_total)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

        # SWA method:
        if args.contrib:
            self.optim = SWA(self.optim, swa_start=self.t_total * 0.75,
                             swa_freq=5, swa_lr=args.lr)
        if args.swa:
            self.swa_model = AveragedModel(self.model)
            self.swa_start = self.t_total * 0.75
            self.swa_scheduler = SWALR(self.optim, swa_lr=args.lr)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        print("Batches:", len(loader))

        self.optim.zero_grad()

        best_roc = 0.
        ups = 0
        total_loss = 0.

        for epoch in range(args.epochs):
            if args.reg:
                if args.model != "X":
                    print(self.model.model.layer_weights)

            id2ans = {}
            id2prob = {}

            for i, (ids, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):
                if ups == args.midsave:
                    self.save("MID")

                self.model.train()
                if args.swa:
                    self.swa_model.train()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.long().cuda()

                # Model expects visual feats as a tuple of feats & boxes
                logit = self.model(sent, (feats, boxes))

                # Note: LogSoftmax does not change the ordering, hence there is nothing wrong
                # with taking it as our prediction. ROC AUC stays exactly the same for
                # logsoftmax / normal softmax, but logsoftmax is better for loss calculation
                # due to stronger penalization and its simplifying properties
                # (log(a/b) = log(a) - log(b))
                logit = self.logsoftmax(logit)
                score = logit[:, 1]

                if i < 1:
                    print(logit[0, :].detach())

                # Note: This loss is the same as CrossEntropy (we split it up into
                # logsoftmax & negative log-likelihood loss)
                loss = self.nllloss(logit.view(-1, 2), target.view(-1))

                # Scale the loss by batch size, as we have batches of different sizes
                # (we do not "drop_last"), and divide by acc for gradient accumulation.
                # Not scaling the loss worsens performance by ~2abs%
                loss = loss * logit.size(0) / args.acc
                loss.backward()

                total_loss += loss.detach().item()

                # Acts as argmax - extract the higher score & the corresponding index (0 or 1)
                _, predict = logit.detach().max(1)
                # Collect labels for accuracy
                for qid, l in zip(ids, predict.cpu().numpy()):
                    id2ans[qid] = l
                # Collect probabilities for ROC AUC
                for qid, l in zip(ids, score.detach().cpu().numpy()):
                    id2prob[qid] = l

                if (i + 1) % args.acc == 0:
                    nn.utils.clip_grad_norm_(self.model.parameters(), args.clip)
                    self.optim.step()

                    if (args.swa) and (ups > self.swa_start):
                        self.swa_model.update_parameters(self.model)
                        self.swa_scheduler.step()
                    else:
                        self.scheduler.step()
                    self.optim.zero_grad()
                    ups += 1

                    # Do validation in between
                    if ups % 250 == 0:
                        log_str = "\nEpoch(U) %d(%d): Train AC %0.2f RA %0.4f LOSS %0.4f\n" % (
                            epoch, ups, evaluator.evaluate(id2ans) * 100,
                            evaluator.roc_auc(id2prob) * 100, total_loss)
                        # Set loss back to 0 after printing it
                        total_loss = 0.

                        if self.valid_tuple is not None:  # Do validation
                            acc, roc_auc = self.evaluate(eval_tuple)
                            if roc_auc > best_roc:
                                best_roc = roc_auc
                                best_acc = acc
                                # Only save BEST when no midsave is specified, to save space
                                # if args.midsave < 0:
                                #     self.save("BEST")

                            log_str += "\nEpoch(U) %d(%d): DEV AC %0.2f RA %0.4f \n" % (
                                epoch, ups, acc * 100., roc_auc * 100)
                            log_str += "Epoch(U) %d(%d): BEST AC %0.2f RA %0.4f \n" % (
                                epoch, ups, best_acc * 100., best_roc * 100.)

                        print(log_str, end='')

                        with open(self.output + "/log.log", 'a') as f:
                            f.write(log_str)
                            f.flush()

            if (epoch + 1) == args.epochs:
                if args.contrib:
                    self.optim.swap_swa_sgd()

        self.save("LAST" + args.train)

    def predict(self, eval_tuple: DataTuple, dump=None, out_csv=True):
        dset, loader, evaluator = eval_tuple
        id2ans = {}
        id2prob = {}

        for i, datum_tuple in enumerate(loader):
            ids, feats, boxes, sent = datum_tuple[:4]

            self.model.eval()
            if args.swa:
                self.swa_model.eval()

            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(sent, (feats, boxes))

                # Note: LogSoftmax does not change the ordering, hence there is
                # nothing wrong with taking it as our prediction
                logit = self.logsoftmax(logit)
                score = logit[:, 1]

                if args.swa:
                    logit = self.swa_model(sent, (feats, boxes))
                    logit = self.logsoftmax(logit)

                _, predict = logit.max(1)

                for qid, l in zip(ids, predict.cpu().numpy()):
                    id2ans[qid] = l
                # Collect probabilities for ROC AUC
                for qid, l in zip(ids, score.cpu().numpy()):
                    id2prob[qid] = l

        if dump is not None:
            if out_csv == True:
                evaluator.dump_csv(id2ans, id2prob, dump)
            else:
                evaluator.dump_result(id2ans, dump)

        return id2ans, id2prob

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        id2ans, id2prob = self.predict(eval_tuple, dump=dump)
        acc = eval_tuple.evaluator.evaluate(id2ans)
        roc_auc = eval_tuple.evaluator.roc_auc(id2prob)
        return acc, roc_auc

    def save(self, name):
        if args.swa:
            torch.save(self.swa_model.state_dict(),
                       os.path.join(self.output, "%s.pth" % name))
        else:
            torch.save(self.model.state_dict(),
                       os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s" % path)
        new_state_dict = {}
        for key, value in state_dict.items():
            # n_averaged is a key in SWA models we cannot load, so we skip it
            if key.startswith("n_averaged"):
                print("n_averaged:", value)
                continue
            # SWA models will start with "module."
            if key.startswith("module."):
                new_state_dict[key[len("module."):]] = value
            else:
                new_state_dict[key] = value
        state_dict = new_state_dict
        self.model.load_state_dict(state_dict)

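# For reference, a minimal self-contained sketch of the torch.optim.swa_utils
# pattern that the HM class above mixes with torchcontrib (model/loader/loss_fn/
# optimizer are stand-in arguments, not names from this file):
def train_with_swa_utils(model, loader, loss_fn, optimizer, epochs=100, swa_start=75):
    from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
    swa_model = AveragedModel(model)
    swa_scheduler = SWALR(optimizer, swa_lr=0.05)
    for epoch in range(epochs):
        for inputs, targets in loader:
            optimizer.zero_grad()
            loss_fn(model(inputs), targets).backward()
            optimizer.step()
        if epoch >= swa_start:
            swa_model.update_parameters(model)  # running average of the weights
            swa_scheduler.step()                # constant SWA learning rate
    update_bn(loader, swa_model)  # recompute BatchNorm stats for the averaged model
    return swa_model
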
def main(args, dst_folder):
    # best_ac only records the best top1_ac for the validation set
    best_ac = 0.0

    # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    if args.cuda_dev == 1:
        torch.cuda.set_device(1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.deterministic = True  # fix the GPU to deterministic mode
    torch.manual_seed(args.seed)  # CPU seed
    if device.type == "cuda":
        torch.cuda.manual_seed_all(args.seed)  # GPU seed
    random.seed(args.seed)  # python seed for image transformation
    np.random.seed(args.seed)

    if args.dataset == 'svhn':
        mean = [x / 255 for x in [127.5, 127.5, 127.5]]
        std = [x / 255 for x in [127.5, 127.5, 127.5]]
    elif args.dataset == 'cifar100':
        mean = [0.5071, 0.4867, 0.4408]
        std = [0.2675, 0.2565, 0.2761]

    if args.DA == "standard":
        transform_train = transforms.Compose([
            transforms.Pad(2, padding_mode='reflect'),
            transforms.RandomCrop(32),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    elif args.DA == "jitter":
        transform_train = transforms.Compose([
            transforms.Pad(2, padding_mode='reflect'),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            transforms.RandomCrop(32),
            # SVHNPolicy(),
            # AutoAugment(),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # Cutout(n_holes=1, length=20),
            transforms.Normalize(mean, std),
        ])
    else:
        print("Wrong value for --DA argument.")

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # data loader
    train_loader, test_loader, train_noisy_indexes = data_config(
        args, transform_train, transform_test, dst_folder)

    if args.network == "MT_Net":
        print("Loading MT_Net...")
        model = MT_Net(num_classes=args.num_classes, dropRatio=args.dropout).to(device)
    elif args.network == "WRN28_2_wn":
        print("Loading WRN28_2...")
        model = WRN28_2_wn(num_classes=args.num_classes, dropout=args.dropout).to(device)
    elif args.network == "PreactResNet18_WNdrop":
        print("Loading preActResNet18_WNdrop...")
        model = PreactResNet18_WNdrop(drop_val=args.dropout,
                                      num_classes=args.num_classes).to(device)

    print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    milestones = args.M

    if args.swa == 'True':
        # to install it:
        # pip3 install torchcontrib
        # git clone https://github.com/pytorch/contrib.git
        # cd contrib
        # sudo python3 setup.py install
        from torchcontrib.optim import SWA
        # base_optimizer = RAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=1e-4)
        base_optimizer = optim.SGD(model.parameters(), lr=args.lr,
                                   momentum=args.momentum, weight_decay=1e-4)
        optimizer = SWA(base_optimizer, swa_lr=args.swa_lr)
    else:
        # optimizer = RAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=1e-4)
        optimizer = optim.SGD(model.parameters(), lr=args.lr,
                              momentum=args.momentum, weight_decay=1e-4)

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    loss_train_epoch = []
    loss_val_epoch = []
    acc_train_per_epoch = []
    acc_val_per_epoch = []
    new_labels = []

    exp_path = os.path.join('./', 'noise_models_{0}'.format(args.experiment_name),
                            str(args.labeled_samples))
    res_path = os.path.join('./', 'metrics_{0}'.format(args.experiment_name),
                            str(args.labeled_samples))

    if not os.path.isdir(res_path):
        os.makedirs(res_path)
    if not os.path.isdir(exp_path):
        os.makedirs(exp_path)

    cont = 0
    load = False
    save = True

    if args.initial_epoch != 0:
        initial_epoch = args.initial_epoch
        load = True
        save = False

    if args.dataset_type == 'sym_noise_warmUp':
        load = False
        save = True

    if load:
        if args.loss_term == 'Reg_ep':
            train_type = 'C'
        if args.loss_term == 'MixUp_ep':
            train_type = 'M'
        if args.dropout > 0.0:
            train_type = train_type + 'drop' + str(int(10 * args.dropout))
        if args.beta == 0.0:
            train_type = train_type + 'noReg'
        path = './checkpoints/warmUp_{6}_{5}_{0}_{1}_{2}_{3}_S{4}.hdf5'.format(
            initial_epoch, args.dataset, args.labeled_samples, args.network,
            args.seed, args.Mixup_Alpha, train_type)

        checkpoint = torch.load(path)
        print("Load model in epoch " + str(checkpoint['epoch']))
        print("Path loaded: ", path)
        model.load_state_dict(checkpoint['state_dict'])

        print("Relabeling the unlabeled samples...")
        model.eval()
        initial_rand_relab = args.label_noise
        results = np.zeros((len(train_loader.dataset), 10), dtype=np.float32)
        for images, images_pslab, labels, soft_labels, index in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            soft_labels = soft_labels.to(device)

            outputs = model(images)
            prob, loss = loss_soft_reg_ep(outputs, labels, soft_labels, device, args)
            results[index.detach().numpy().tolist()] = prob.cpu().detach().numpy().tolist()

        train_loader.dataset.update_labels_randRelab(results, train_noisy_indexes,
                                                     initial_rand_relab)

    print("Start training...")
    for epoch in range(1, args.epoch + 1):
        st = time.time()
        scheduler.step()
        # train for one epoch
        print(args.experiment_name, args.labeled_samples)

        loss_per_epoch, top_5_train_ac, top1_train_acc_original_labels, \
            top1_train_ac, train_time = train_CrossEntropy_partialRelab(
                args, model, device, train_loader, optimizer, epoch, train_noisy_indexes)

        loss_train_epoch += [loss_per_epoch]

        # test
        if args.validation_exp == "True":
            loss_per_epoch, acc_val_per_epoch_i = validating(args, model, device, test_loader)
        else:
            loss_per_epoch, acc_val_per_epoch_i = testing(args, model, device, test_loader)

        loss_val_epoch += loss_per_epoch
        acc_train_per_epoch += [top1_train_ac]
        acc_val_per_epoch += acc_val_per_epoch_i

        ############################################################################
        # SAVING MODELS
        ############################################################################
        if not os.path.exists('./checkpoints'):
            os.mkdir('./checkpoints')

        if epoch == 1:
            best_acc_val = acc_val_per_epoch_i[-1]
            snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestAccVal_%.5f' % (
                epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1],
                args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
        else:
            if acc_val_per_epoch_i[-1] > best_acc_val:
                best_acc_val = acc_val_per_epoch_i[-1]
                if cont > 0:
                    try:
                        os.remove(os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
                        os.remove(os.path.join(exp_path, snapBest + '.pth'))
                    except OSError:
                        pass
                snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestAccVal_%.5f' % (
                    epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1],
                    args.labeled_samples, best_acc_val)
                torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
                torch.save(optimizer.state_dict(),
                           os.path.join(exp_path, 'opt_' + snapBest + '.pth'))

        cont += 1

        if epoch == args.epoch:
            snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestValLoss_%.5f' % (
                epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1],
                args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))

        # Save models for ensembles:
        if (epoch >= 150) and (epoch % 2 == 0) and (args.save_checkpoint == "True"):
            print("Saving model ...")
            out_path = './checkpoints/ENS_{0}_{1}'.format(args.experiment_name,
                                                          args.labeled_samples)
            if not os.path.exists(out_path):
                os.makedirs(out_path)
            torch.save(model.state_dict(), out_path + "/epoch_{0}.pth".format(epoch))

        # Saving the model to load it again
        # cond = epoch % 1 == 0
        if args.dataset_type == 'sym_noise_warmUp':
            if args.loss_term == 'Reg_ep':
                train_type = 'C'
            if args.loss_term == 'MixUp_ep':
                train_type = 'M'
            if args.dropout > 0.0:
                train_type = train_type + 'drop' + str(int(10 * args.dropout))
            if args.beta == 0.0:
                train_type = train_type + 'noReg'
            cond = (epoch == args.epoch)
            name = 'warmUp_{1}_{0}'.format(args.Mixup_Alpha, train_type)
            save = True
        else:
            cond = (epoch == args.epoch)
            name = 'warmUp_{1}_{0}'.format(args.Mixup_Alpha, train_type)
            save = True

        if cond and save:
            print("Saving models...")
            path = './checkpoints/{0}_{1}_{2}_{3}_{4}_S{5}.hdf5'.format(
                name, epoch, args.dataset, args.labeled_samples, args.network, args.seed)
            save_checkpoint({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss_train_epoch': np.asarray(loss_train_epoch),
                'loss_val_epoch': np.asarray(loss_val_epoch),
                'acc_train_per_epoch': np.asarray(acc_train_per_epoch),
                'acc_val_per_epoch': np.asarray(acc_val_per_epoch),
                'labels': np.asarray(train_loader.dataset.soft_labels),
            }, filename=path)

        ############################################################################
        # SAVING METRICS
        ############################################################################
        # Save losses:
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_train.npy',
                np.asarray(loss_train_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_val.npy',
                np.asarray(loss_val_epoch))
        # Save accuracies:
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_train.npy',
                np.asarray(acc_train_per_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_val.npy',
                np.asarray(acc_val_per_epoch))
        # Save the new labels
        new_labels.append(train_loader.dataset.labels)
        np.save(res_path + '/' + str(args.labeled_samples) + '_new_labels.npy',
                np.asarray(new_labels))
        # logging.info('Epoch: [{}|{}], train_loss: {:.3f}, top1_train_ac: {:.3f}, '
        #              'top1_val_ac: {:.3f}, train_time: {:.3f}'.format(
        #                  epoch, args.epoch, loss_per_epoch[-1], top1_train_ac,
        #                  acc_val_per_epoch_i[-1], time.time() - st))

    # applying swa
    if args.swa == 'True':
        optimizer.swap_swa_sgd()
        optimizer.bn_update(train_loader, model, device)
        if args.validation_exp == "True":
            loss_swa, acc_val_swa = validating(args, model, device, test_loader)
        else:
            loss_swa, acc_val_swa = testing(args, model, device, test_loader)

        snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestValLoss_%.5f_swaAcc_%.5f' % (
            epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1],
            args.labeled_samples, best_acc_val, acc_val_swa[0])
        torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))

    # save_fig(dst_folder)
    print('Best ac: %f' % best_acc_val)
    record_result(dst_folder, best_ac)

def train(model_name, optim='adam'):
    train_dataset = PretrainDataset(output_shape=config['image_resolution'])
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'],
                              shuffle=True, num_workers=8, pin_memory=True, drop_last=True)
    val_dataset = IDRND_dataset_CV(fold=0, mode=config['mode'].replace('train', 'val'),
                                   double_loss_mode=True,
                                   output_shape=config['image_resolution'])
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'],
                            shuffle=True, num_workers=4, drop_last=False)

    if model_name == 'EF':
        model = DoubleLossModelTwoHead(
            base_model=EfficientNet.from_pretrained('efficientnet-b3')).to(device)
        model.load_state_dict(torch.load(
            f"../models_weights/pretrained/{model_name}_{4}_2.0090592697255896_1.0.pth"))
    elif model_name == 'EFGAP':
        model = DoubleLossModelTwoHead(
            base_model=EfficientNetGAP.from_pretrained('efficientnet-b3')).to(device)
        model.load_state_dict(torch.load(
            f"../models_weights/pretrained/{model_name}_{4}_2.3281182915644134_1.0.pth"))

    criterion = FocalLoss(add_weight=False).to(device)
    criterion4class = CrossEntropyLoss().to(device)

    if optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'],
                                     weight_decay=config['weight_decay'])
    elif optim == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=config['learning_rate'],
                                    weight_decay=config['weight_decay'], nesterov=False)
    else:
        optimizer = torch.optim.SGD(model.parameters(), momentum=0.9,
                                    lr=config['learning_rate'],
                                    weight_decay=config['weight_decay'], nesterov=True)

    steps_per_epoch = train_loader.__len__() - 15
    swa = SWA(optimizer,
              swa_start=config['swa_start'] * steps_per_epoch,
              swa_freq=int(config['swa_freq'] * steps_per_epoch),
              swa_lr=config['learning_rate'] / 10)
    scheduler = ExponentialLR(swa, gamma=0.9)
    # scheduler = StepLR(swa, step_size=5 * steps_per_epoch, gamma=0.5)

    global_step = 0
    for epoch in trange(10):
        if epoch < 5:
            scheduler.step()
            continue
        model.train()
        train_bar = tqdm(train_loader)
        train_bar.set_description_str(desc=f"N epochs - {epoch}")

        for step, batch in enumerate(train_bar):
            global_step += 1
            image = batch['image'].to(device)
            label4class = batch['label0'].to(device)
            label = batch['label1'].to(device)

            output4class, output = model(image)
            loss4class = criterion4class(output4class, label4class)
            loss = criterion(output.squeeze(), label)
            swa.zero_grad()
            total_loss = loss4class * 0.5 + loss * 0.5
            total_loss.backward()
            swa.step()

            train_writer.add_scalar(tag="learning_rate",
                                    scalar_value=scheduler.get_lr()[0],
                                    global_step=global_step)
            train_writer.add_scalar(tag="BinaryLoss", scalar_value=loss.item(),
                                    global_step=global_step)
            train_writer.add_scalar(tag="SoftMaxLoss", scalar_value=loss4class.item(),
                                    global_step=global_step)
            train_bar.set_postfix_str(f"Loss = {loss.item()}")
            try:
                train_writer.add_scalar(tag="idrnd_score",
                                        scalar_value=idrnd_score_pytorch(label, output),
                                        global_step=global_step)
                train_writer.add_scalar(tag="far_score",
                                        scalar_value=far_score(label, output),
                                        global_step=global_step)
                train_writer.add_scalar(tag="frr_score",
                                        scalar_value=frr_score(label, output),
                                        global_step=global_step)
                train_writer.add_scalar(tag="accuracy",
                                        scalar_value=bce_accuracy(label, output),
                                        global_step=global_step)
            except Exception:
                pass

        if (epoch > config['swa_start'] and epoch % 2 == 0) \
                or (epoch == config['number_epochs'] - 1):
            # swap in the averaged weights, refresh BatchNorm stats, swap back
            swa.swap_swa_sgd()
            swa.bn_update(train_loader, model, device)
            swa.swap_swa_sgd()

        scheduler.step()
        evaluate(model, val_loader, epoch, model_name)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('--mixup', action='store_true')
    parser.add_argument('--balance', action='store_true')
    parser.add_argument('--balance-datasets', action='store_true')
    parser.add_argument('--swa', action='store_true')
    parser.add_argument('--show', action='store_true')
    parser.add_argument('--use-idrid', action='store_true')
    parser.add_argument('--use-messidor', action='store_true')
    parser.add_argument('--use-aptos2015', action='store_true')
    parser.add_argument('--use-aptos2019', action='store_true')
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('--coarse', action='store_true')
    parser.add_argument('-acc', '--accumulation-steps', type=int, default=1,
                        help='Number of batches to process')
    parser.add_argument('-dd', '--data-dir', type=str, default='data', help='Data directory')
    parser.add_argument('-m', '--model', type=str, default='resnet18_gap', help='Model name')
    parser.add_argument('-b', '--batch-size', type=int, default=8,
                        help='Batch size during training, e.g. -b 64')
    parser.add_argument('-e', '--epochs', type=int, default=100, help='Epochs to run')
    parser.add_argument('-es', '--early-stopping', type=int, default=None,
                        help='Maximum number of epochs without improvement')
    parser.add_argument('-f', '--fold', action='append', type=int, default=None)
    parser.add_argument('-ft', '--fine-tune', default=0, type=int)
    parser.add_argument('-lr', '--learning-rate', type=float, default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('--criterion-reg', type=str, default=None, nargs='+', help='Criterion')
    parser.add_argument('--criterion-ord', type=str, default=None, nargs='+', help='Criterion')
    parser.add_argument('--criterion-cls', type=str, default=['ce'], nargs='+', help='Criterion')
    parser.add_argument('-l1', type=float, default=0, help='L1 regularization loss')
    parser.add_argument('-l2', type=float, default=0, help='L2 regularization loss')
    parser.add_argument('-o', '--optimizer', default='Adam', help='Name of the optimizer')
    parser.add_argument('-p', '--preprocessing', default=None, help='Preprocessing method')
    parser.add_argument('-c', '--checkpoint', type=str, default=None,
                        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-w', '--workers', default=multiprocessing.cpu_count(), type=int,
                        help='Num workers')
    parser.add_argument('-a', '--augmentations', default='medium', type=str,
                        help='Augmentation preset')
    parser.add_argument('-tta', '--tta', default=None, type=str,
                        help='Type of TTA to use [fliplr, d4]')
    parser.add_argument('-t', '--transfer', default=None, type=str,
                        help='Checkpoint to transfer weights from')
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('-s', '--scheduler', default='multistep', type=str, help='LR scheduler')
    parser.add_argument('--size', default=512, type=int,
                        help='Image size for training & inference')
    parser.add_argument('-wd', '--weight-decay', default=0, type=float, help='L2 weight decay')
    parser.add_argument('-wds', '--weight-decay-step', default=None, type=float,
                        help='L2 weight decay step to add after each epoch')
    parser.add_argument('-d', '--dropout', default=0.0, type=float,
                        help='Dropout before head layer')
    parser.add_argument('--warmup', default=0, type=int,
                        help='Number of warmup epochs with 0.1 of the initial LR and a frozen encoder')
    parser.add_argument('-x', '--experiment', default=None, type=str,
                        help='Experiment name')  # original help string was copied from --dropout
    args = parser.parse_args()

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    l1 = args.l1
    l2 = args.l2
    early_stopping = args.early_stopping
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = (args.size, args.size)
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    fine_tune = args.fine_tune
    criterion_reg_name = args.criterion_reg
    criterion_cls_name = args.criterion_cls
    criterion_ord_name = args.criterion_ord
    folds = args.fold
    mixup = args.mixup
    balance = args.balance
    balance_datasets = args.balance_datasets
    use_swa = args.swa
    show_batches = args.show
    scheduler_name = args.scheduler
    verbose = args.verbose
    weight_decay = args.weight_decay
    use_idrid = args.use_idrid
    use_messidor = args.use_messidor
    use_aptos2015 = args.use_aptos2015
    use_aptos2019 = args.use_aptos2019
    warmup = args.warmup
    dropout = args.dropout
    use_unsupervised = False
    experiment = args.experiment
    preprocessing = args.preprocessing
    weight_decay_step = args.weight_decay_step
    coarse_grading = args.coarse

    class_names = get_class_names(coarse_grading)
    assert use_aptos2015 or use_aptos2019 or use_idrid or use_messidor

    current_time = datetime.now().strftime('%b%d_%H_%M')
    random_name = get_random_name()
    if folds is None or len(folds) == 0:
        folds = [None]

    for fold in folds:
        torch.cuda.empty_cache()
        checkpoint_prefix = f'{model_name}_{args.size}_{augmentations}'
        if preprocessing is not None:
            checkpoint_prefix += f'_{preprocessing}'
        if use_aptos2019:
            checkpoint_prefix += '_aptos2019'
        if use_aptos2015:
            checkpoint_prefix += '_aptos2015'
        if use_messidor:
            checkpoint_prefix += '_messidor'
        if use_idrid:
            checkpoint_prefix += '_idrid'
        if coarse_grading:
            checkpoint_prefix += '_coarse'
        if fold is not None:
            checkpoint_prefix += f'_fold{fold}'
        checkpoint_prefix += f'_{random_name}'
        if experiment is not None:
            checkpoint_prefix = experiment

        directory_prefix = f'{current_time}/{checkpoint_prefix}'
        log_dir = os.path.join('runs', directory_prefix)
        os.makedirs(log_dir, exist_ok=False)

        config_fname = os.path.join(log_dir, f'{checkpoint_prefix}.json')
        with open(config_fname, 'w') as f:
            train_session_args = vars(args)
            f.write(json.dumps(train_session_args, indent=2))

        set_manual_seed(args.seed)
        num_classes = len(class_names)
        model = get_model(model_name, num_classes=num_classes, dropout=dropout).cuda()

        if args.transfer:
            transfer_checkpoint = fs.auto_file(args.transfer)
            print("Transferring weights from model checkpoint", transfer_checkpoint)
            checkpoint = load_checkpoint(transfer_checkpoint)
            pretrained_dict = checkpoint['model_state_dict']
            for name, value in pretrained_dict.items():
                try:
                    model.load_state_dict(collections.OrderedDict([(name, value)]), strict=False)
                except Exception as e:
                    print(e)
            report_checkpoint(checkpoint)

        if args.checkpoint:
            checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
            unpack_checkpoint(checkpoint, model=model)
            report_checkpoint(checkpoint)

        train_ds, valid_ds, train_sizes = get_datasets(
            data_dir=data_dir,
            use_aptos2019=use_aptos2019,
            use_aptos2015=use_aptos2015,
            use_idrid=use_idrid,
            use_messidor=use_messidor,
            use_unsupervised=False,
            coarse_grading=coarse_grading,
            image_size=image_size,
            augmentation=augmentations,
            preprocessing=preprocessing,
            target_dtype=int,
            fold=fold,
            folds=4)

        train_loader, valid_loader = get_dataloaders(
            train_ds, valid_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            train_sizes=train_sizes,
            balance=balance,
            balance_datasets=balance_datasets,
            balance_unlabeled=False)

        loaders = collections.OrderedDict()
loaders["train"] = train_loader loaders["valid"] = valid_loader print('Datasets :', data_dir) print(' Train size :', len(train_loader), len(train_loader.dataset)) print(' Valid size :', len(valid_loader), len(valid_loader.dataset)) print(' Aptos 2019 :', use_aptos2019) print(' Aptos 2015 :', use_aptos2015) print(' IDRID :', use_idrid) print(' Messidor :', use_messidor) print('Train session :', directory_prefix) print(' FP16 mode :', fp16) print(' Fast mode :', fast) print(' Mixup :', mixup) print(' Balance cls. :', balance) print(' Balance ds. :', balance_datasets) print(' Warmup epoch :', warmup) print(' Train epochs :', num_epochs) print(' Fine-tune ephs :', fine_tune) print(' Workers :', num_workers) print(' Fold :', fold) print(' Log dir :', log_dir) print(' Augmentations :', augmentations) print('Model :', model_name) print(' Parameters :', count_parameters(model)) print(' Image size :', image_size) print(' Dropout :', dropout) print(' Classes :', class_names, num_classes) print('Optimizer :', optimizer_name) print(' Learning rate :', learning_rate) print(' Batch size :', batch_size) print(' Criterion (cls):', criterion_cls_name) print(' Criterion (reg):', criterion_reg_name) print(' Criterion (ord):', criterion_ord_name) print(' Scheduler :', scheduler_name) print(' Weight decay :', weight_decay, weight_decay_step) print(' L1 reg. :', l1) print(' L2 reg. :', l2) print(' Early stopping :', early_stopping) # model training callbacks = [] criterions = {} main_metric = 'cls/kappa' if criterion_reg_name is not None: cb, crits = get_reg_callbacks(criterion_reg_name, class_names=class_names, show=show_batches) callbacks += cb criterions.update(crits) if criterion_ord_name is not None: cb, crits = get_ord_callbacks(criterion_ord_name, class_names=class_names, show=show_batches) callbacks += cb criterions.update(crits) if criterion_cls_name is not None: cb, crits = get_cls_callbacks(criterion_cls_name, num_classes=num_classes, num_epochs=num_epochs, class_names=class_names, show=show_batches) callbacks += cb criterions.update(crits) if l1 > 0: callbacks += [ LPRegularizationCallback(start_wd=l1, end_wd=l1, schedule=None, prefix='l1', p=1) ] if l2 > 0: callbacks += [ LPRegularizationCallback(start_wd=l2, end_wd=l2, schedule=None, prefix='l2', p=2) ] callbacks += [CustomOptimizerCallback()] runner = SupervisedRunner(input_key='image') # Pretrain/warmup if warmup: set_trainable(model.encoder, False, False) optimizer = get_optimizer('Adam', get_optimizable_parameters(model), learning_rate=learning_rate * 0.1) runner.train(fp16=fp16, model=model, criterion=criterions, optimizer=optimizer, scheduler=None, callbacks=callbacks, loaders=loaders, logdir=os.path.join(log_dir, 'warmup'), num_epochs=warmup, verbose=verbose, main_metric=main_metric, minimize_metric=False, checkpoint_data={"cmd_args": vars(args)}) del optimizer # Main train if num_epochs: set_trainable(model.encoder, True, False) optimizer = get_optimizer(optimizer_name, get_optimizable_parameters(model), learning_rate=learning_rate, weight_decay=weight_decay) if use_swa: from torchcontrib.optim import SWA optimizer = SWA(optimizer, swa_start=len(train_loader), swa_freq=512) scheduler = get_scheduler(scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(train_loader)) # Additional callbacks that specific to main stage only added here to copy of callbacks main_stage_callbacks = callbacks if early_stopping: es_callback = EarlyStoppingCallback(early_stopping, min_delta=1e-4, metric=main_metric, 
minimize=False) main_stage_callbacks = callbacks + [es_callback] runner.train(fp16=fp16, model=model, criterion=criterions, optimizer=optimizer, scheduler=scheduler, callbacks=main_stage_callbacks, loaders=loaders, logdir=os.path.join(log_dir, 'main'), num_epochs=num_epochs, verbose=verbose, main_metric=main_metric, minimize_metric=False, checkpoint_data={"cmd_args": vars(args)}) del optimizer, scheduler best_checkpoint = os.path.join(log_dir, 'main', 'checkpoints', 'best.pth') model_checkpoint = os.path.join(log_dir, 'main', 'checkpoints', f'{checkpoint_prefix}.pth') clean_checkpoint(best_checkpoint, model_checkpoint) # Restoring best model from checkpoint checkpoint = load_checkpoint(best_checkpoint) unpack_checkpoint(checkpoint, model=model) report_checkpoint(checkpoint) # Stage 3 - Fine tuning if fine_tune: set_trainable(model.encoder, False, False) optimizer = get_optimizer(optimizer_name, get_optimizable_parameters(model), learning_rate=learning_rate) scheduler = get_scheduler('multistep', optimizer, lr=learning_rate, num_epochs=fine_tune, batches_in_epoch=len(train_loader)) runner.train(fp16=fp16, model=model, criterion=criterions, optimizer=optimizer, scheduler=scheduler, callbacks=callbacks, loaders=loaders, logdir=os.path.join(log_dir, 'finetune'), num_epochs=fine_tune, verbose=verbose, main_metric=main_metric, minimize_metric=False, checkpoint_data={"cmd_args": vars(args)}) best_checkpoint = os.path.join(log_dir, 'finetune', 'checkpoints', 'best.pth') model_checkpoint = os.path.join(log_dir, 'finetune', 'checkpoints', f'{checkpoint_prefix}.pth') clean_checkpoint(best_checkpoint, model_checkpoint)
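# When swa_lr is omitted, as in the use_swa branch above, the torchcontrib
# wrapper leaves the learning rate entirely to the existing scheduler and only
# maintains the weight average. A minimal sketch of that wiring; the optimizer,
# scheduler, and step counts are invented for illustration, and the training
# framework is assumed to call opt.step() per batch and sched.step() per epoch:
import torch
from torchcontrib.optim import SWA

model = torch.nn.Linear(4, 4)
base = torch.optim.Adam(model.parameters(), lr=1e-4)
batches_in_epoch = 100  # e.g. len(train_loader)

opt = SWA(base, swa_start=batches_in_epoch, swa_freq=512)  # no swa_lr: the scheduler stays in charge
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[30, 60], gamma=0.1)

# after training, opt.swap_swa_sgd() installs the averaged weights into `model`.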
class Trainer():
    def __init__(self, config_path):
        self.image_config, self.model_config, self.run_config = LoadConfig(
            config_path=config_path).train_config()
        # NOTE: the original tested the function object `torch.cuda.is_available`
        # (always truthy) instead of calling it
        self.device = torch.device('cuda:%d' % self.run_config['device_ids'][0]
                                   if torch.cuda.is_available() else 'cpu')
        self.model = getModel(self.model_config)
        os.makedirs(self.run_config['model_save_path'], exist_ok=True)
        self.run_config['num_workers'] = self.run_config['num_workers'] * len(
            self.run_config['device_ids'])

        self.train_set = Data(root=self.image_config['image_path'],
                              phase='train',
                              data_name=self.image_config['data_name'],
                              img_mode=self.image_config['image_mode'],
                              n_classes=self.model_config['num_classes'],
                              size=self.image_config['image_size'],
                              scale=self.image_config['image_scale'])
        self.valid_set = Data(root=self.image_config['image_path'],
                              phase='valid',
                              data_name=self.image_config['data_name'],
                              img_mode=self.image_config['image_mode'],
                              n_classes=self.model_config['num_classes'],
                              size=self.image_config['image_size'],
                              scale=self.image_config['image_scale'])
        self.className = self.valid_set.className
        self.train_loader = DataLoader(self.train_set,
                                       batch_size=self.run_config['batch_size'],
                                       shuffle=True,
                                       num_workers=self.run_config['num_workers'],
                                       pin_memory=True,
                                       drop_last=False)
        self.valid_loader = DataLoader(self.valid_set,
                                       batch_size=self.run_config['batch_size'],
                                       shuffle=True,
                                       num_workers=self.run_config['num_workers'],
                                       pin_memory=True,
                                       drop_last=False)

        train_params = self.model.parameters()
        self.optimizer = RAdam(train_params,
                               lr=eval(self.run_config['lr']),
                               weight_decay=eval(self.run_config['weight_decay']))
        if self.run_config['swa']:
            self.optimizer = SWA(self.optimizer, swa_start=10, swa_freq=5, swa_lr=0.005)
        # Set the learning-rate adjustment policy
        self.lr_scheduler = utils.adjustLR.AdjustLr(self.optimizer)
        if self.run_config['use_weight_balance']:
            weight = utils.weight_balance.getWeight(self.run_config['weights_file'])
        else:
            weight = None
        self.Criterion = SegmentationLosses(weight=weight, cuda=True,
                                            device=self.device, batch_average=False)
        self.metric = utils.metrics.MetricMeter(self.model_config['num_classes'])

    @logger.catch  # record errors in the log
    def __call__(self):
        # Set up logging
        self.global_name = self.model_config['model_name']
        logger.add(os.path.join(self.image_config['image_path'], 'log',
                                'log_' + self.global_name + '/train_{time}.log'),
                   format="{time} {level} {message}",
                   level="INFO",
                   encoding='utf-8')
        self.writer = SummaryWriter(logdir=os.path.join(
            self.image_config['image_path'], 'run', 'runs_' + self.global_name))
        logger.info("image_config: {} \n model_config: {} \n run_config: {}",
                    self.image_config, self.model_config, self.run_config)

        # Use data parallelism when more than one GPU is available
        if len(self.run_config['device_ids']) > 1:
            self.model = nn.DataParallel(self.model,
                                         device_ids=self.run_config['device_ids'])
        self.model.to(device=self.device)

        cnt = 0
        # Load the pretrained model if one is given
        if self.run_config['pretrain'] != '':
            logger.info("loading pretrain %s" % self.run_config['pretrain'])
            try:
                self.load_checkpoint(use_optimizer=True, use_epoch=True, use_miou=True)
            except Exception:
                print('loading model with changed keys!')
                self.load_checkpoint_with_changed(use_optimizer=False,
                                                  use_epoch=False,
                                                  use_miou=False)

        logger.info("start training")
        for epoch in range(self.run_config['start_epoch'], self.run_config['epoch']):
            lr = self.optimizer.param_groups[0]['lr']
            print('epoch=%d, lr=%.8f' % (epoch, lr))
            self.train_epoch(epoch, lr)
            valid_miou = self.valid_epoch(epoch)
            # Choose the learning-rate schedule to apply
            self.lr_scheduler.LambdaLR_(milestone=5, gamma=0.92).step(epoch=epoch)
            self.save_checkpoint(epoch, valid_miou, 'last_' + self.global_name)
            if valid_miou > self.run_config['best_miou']:
                cnt = 0
                self.save_checkpoint(epoch, valid_miou, 'best_' + self.global_name)
                logger.info("############# %d saved ##############" % epoch)
                self.run_config['best_miou'] = valid_miou
            else:
                cnt += 1
                if cnt == self.run_config['early_stop']:
                    logger.info("early stop")
                    break
        self.writer.close()

    def train_epoch(self, epoch, lr):
        self.metric.reset()
        train_loss = 0.0
        train_miou = 0.0
        tbar = tqdm(self.train_loader)
        self.model.train()
        for i, (image, mask, edge) in enumerate(tbar):
            tbar.set_description('train_miou:%.6f' % train_miou)
            tbar.set_postfix({"train_loss": train_loss})
            image = image.to(self.device)
            mask = mask.to(self.device)
            edge = edge.to(self.device)
            self.optimizer.zero_grad()
            out = self.model(image)
            if isinstance(out, tuple):
                aux_out, final_out = out[0], out[1]
            else:
                aux_out, final_out = None, out
            if self.model_config['model_name'] == 'ocrnet':
                aux_loss = self.Criterion.build_loss(mode='rmi')(aux_out, mask)
                cls_loss = self.Criterion.build_loss(mode='ce')(final_out, mask)
                loss = 0.4 * aux_loss + cls_loss
                loss = loss.mean()
            elif self.model_config['model_name'] == 'hrnet_duc':
                loss_body = self.Criterion.build_loss(
                    mode=self.run_config['loss_type'])(final_out, mask)
                loss_edge = self.Criterion.build_loss(mode='dice')(aux_out.squeeze(), edge)
                loss = loss_body + loss_edge
                loss = loss.mean()
            else:
                loss = self.Criterion.build_loss(
                    mode=self.run_config['loss_type'])(final_out, mask)
            loss.backward()
            self.optimizer.step()
            with torch.no_grad():
                train_loss = ((train_loss * i) + loss.item()) / (i + 1)
                _, pred = torch.max(final_out, dim=1)
                self.metric.add(pred.cpu().numpy(), mask.cpu().numpy())
                train_miou, train_ious = self.metric.miou()
        if self.run_config['swa']:
            # The original called swap_swa_sgd() after every optimizer step, which
            # swaps the averaged weights in and out on each batch; torchcontrib
            # intends a single swap before evaluating the averaged model.
            self.optimizer.swap_swa_sgd()
        train_fwiou = self.metric.fw_iou()
        train_accu = self.metric.pixel_accuracy()
        train_fwaccu = self.metric.pixel_accuracy_class()
        logger.info(
            "Epoch:%2d\t lr:%.8f\t Train loss:%.4f\t Train FWiou:%.4f\t Train Miou:%.4f\t "
            "Train accu:%.4f\t Train fwaccu:%.4f" %
            (epoch, lr, train_loss, train_fwiou, train_miou, train_accu, train_fwaccu))
        cls = ""
        ious = list()
        ious_dict = OrderedDict()
        for i, c in enumerate(self.className):
            ious_dict[c] = train_ious[i]
            ious.append(ious_dict[c])
            cls += "%s:" % c + "%.4f "
        ious = tuple(ious)
        logger.info(cls % ious)
        # TensorBoard
        self.writer.add_scalar("lr", lr, epoch)
        self.writer.add_scalar("loss/train_loss", train_loss, epoch)
        self.writer.add_scalar("miou/train_miou", train_miou, epoch)
        self.writer.add_scalar("fwiou/train_fwiou", train_fwiou, epoch)
        self.writer.add_scalar("accuracy/train_accu", train_accu, epoch)
        self.writer.add_scalar("fwaccuracy/train_fwaccu", train_fwaccu, epoch)
        self.writer.add_scalars("ious/train_ious", ious_dict, epoch)

    def valid_epoch(self, epoch):
        self.metric.reset()
        valid_loss = 0.0
        valid_miou = 0.0
        tbar = tqdm(self.valid_loader)
        self.model.eval()
        with torch.no_grad():
            for i, (image, mask, edge) in enumerate(tbar):
                tbar.set_description('valid_miou:%.6f' % valid_miou)
                tbar.set_postfix({"valid_loss": valid_loss})
                image = image.to(self.device)
                mask = mask.to(self.device)
                edge = edge.to(self.device)
                out = self.model(image)
                if isinstance(out, tuple):
                    aux_out, final_out = out[0], out[1]
                else:
                    aux_out, final_out = None, out
                if self.model_config['model_name'] == 'ocrnet':
                    aux_loss = self.Criterion.build_loss(mode='rmi')(aux_out, mask)
                    cls_loss = self.Criterion.build_loss(mode='ce')(final_out, mask)
                    loss = 0.4 * aux_loss + cls_loss
                    loss = loss.mean()
                elif self.model_config['model_name'] == 'hrnet_duc':
                    loss_body = self.Criterion.build_loss(
                        mode=self.run_config['loss_type'])(final_out, mask)
                    loss_edge = self.Criterion.build_loss(mode='dice')(aux_out.squeeze(), edge)
                    loss = loss_body + loss_edge
                    # loss = loss.mean()
                else:
                    loss = self.Criterion.build_loss(mode='ce')(final_out, mask)
                valid_loss = ((valid_loss * i) + float(loss)) / (i + 1)
                _, pred = torch.max(final_out, dim=1)
                self.metric.add(pred.cpu().numpy(), mask.cpu().numpy())
                valid_miou, valid_ious = self.metric.miou()
        valid_fwiou = self.metric.fw_iou()
        valid_accu = self.metric.pixel_accuracy()
        valid_fwaccu = self.metric.pixel_accuracy_class()
        logger.info(
            "epoch:%d\t valid loss:%.4f\t valid fwiou:%.4f\t valid miou:%.4f valid accu:%.4f\t "
            "valid fwaccu:%.4f\t" %
            (epoch, valid_loss, valid_fwiou, valid_miou, valid_accu, valid_fwaccu))
        ious = list()
        cls = ""
        ious_dict = OrderedDict()
        for i, c in enumerate(self.className):
            ious_dict[c] = valid_ious[i]
            ious.append(ious_dict[c])
            cls += "%s:" % c + "%.4f "
        ious = tuple(ious)
        logger.info(cls % ious)
        self.writer.add_scalar("loss/valid_loss", valid_loss, epoch)
        self.writer.add_scalar("miou/valid_miou", valid_miou, epoch)
        self.writer.add_scalar("fwiou/valid_fwiou", valid_fwiou, epoch)
        self.writer.add_scalar("accuracy/valid_accu", valid_accu, epoch)
        self.writer.add_scalar("fwaccuracy/valid_fwaccu", valid_fwaccu, epoch)
        self.writer.add_scalars("ious/valid_ious", ious_dict, epoch)
        return valid_miou

    def save_checkpoint(self, epoch, best_miou, flag):
        meta = {
            'epoch': epoch,
            'model': self.model.state_dict(),
            'optim': self.optimizer.state_dict(),
            'bmiou': best_miou
        }
        try:
            torch.save(meta,
                       os.path.join(self.run_config['model_save_path'], '%s.pth' % flag),
                       _use_new_zipfile_serialization=False)
        except TypeError:  # older torch versions lack this keyword
            torch.save(meta,
                       os.path.join(self.run_config['model_save_path'], '%s.pth' % flag))

    def load_checkpoint(self, use_optimizer, use_epoch, use_miou):
        state_dict = torch.load(self.run_config['pretrain'], map_location=self.device)
        self.model.load_state_dict(state_dict['model'])
        if use_optimizer:
            self.optimizer.load_state_dict(state_dict['optim'])
        if use_epoch:
            self.run_config['start_epoch'] = state_dict['epoch'] + 1
        if use_miou:
            self.run_config['best_miou'] = state_dict['bmiou']

    def load_checkpoint_with_changed(self, use_optimizer, use_epoch, use_miou):
        state_dict = torch.load(self.run_config['pretrain'], map_location=self.device)
        pretrain_dict = state_dict['model']
        model_dict = self.model.state_dict()
        # Keep only weights whose keys match the current model and are not edge-specific
        pretrain_dict = {
            k: v for k, v in pretrain_dict.items()
            if k in model_dict and 'edge' not in k
        }
        model_dict.update(pretrain_dict)
        self.model.load_state_dict(model_dict)
        if use_optimizer:
            self.optimizer.load_state_dict(state_dict['optim'])
        if use_epoch:
            self.run_config['start_epoch'] = state_dict['epoch'] + 1
        if use_miou:
            self.run_config['best_miou'] = state_dict['bmiou']
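# The Trainer above originally swapped the SWA weights after every optimizer
# step. The usual torchcontrib pattern keeps the SGD weights during training
# and only swaps in the average for evaluation, swapping back afterwards. A
# hedged sketch of that pattern; model, optimizer, and valid_fn are invented:
import torch
from torchcontrib.optim import SWA

model = torch.nn.Linear(8, 3)
opt = SWA(torch.optim.SGD(model.parameters(), lr=0.01), swa_start=10, swa_freq=5)

def validate_with_swa(opt, model, valid_fn):
    """Evaluate the running average without disturbing the SGD weights."""
    opt.swap_swa_sgd()       # averaged weights -> model
    score = valid_fn(model)  # e.g. mIoU over the validation loader
    opt.swap_swa_sgd()       # SGD weights back -> training continues unchanged
    return score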
def train(model, device, trainloader, testloader, optimizer, criterion, metric,
          epochs, learning_rate, swa=True, enable_scheduler=True, model_arch=''):
    '''Function to perform model training.'''
    model.to(device)
    steps = 0
    running_loss = 0
    running_metric = 0
    print_every = 100
    train_losses = []
    test_losses = []
    train_metrics = []
    test_metrics = []

    if swa:
        # Initialize stochastic weight averaging in manual mode:
        # snapshots are taken only when update_swa() is called below
        opt = SWA(optimizer)
    else:
        opt = optimizer

    # Learning-rate cosine annealing
    if enable_scheduler:
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, len(trainloader), eta_min=1e-7)

    for epoch in range(epochs):
        if enable_scheduler:
            scheduler.step()
        for inputs, labels in trainloader:
            steps += 1
            # Move input and label tensors to the default device
            inputs, labels = inputs.to(device), labels.to(device)
            opt.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels.float())
            loss.backward()
            opt.step()
            # Accumulate Python floats; the original accumulated graph-attached
            # tensors, which leaks memory across iterations
            running_loss += loss.item()
            running_metric += metric(outputs, labels.float())

            if steps % print_every == 0:
                test_loss = 0
                test_metric = 0
                model.eval()
                with torch.no_grad():
                    for inputs, labels in testloader:
                        inputs, labels = inputs.to(device), labels.to(device)
                        outputs = model(inputs)
                        test_loss += criterion(outputs, labels.float()).item()
                        test_metric += metric(outputs, labels.float())
                print(f"Epoch {epoch+1}/{epochs}.. "
                      f"Train loss: {running_loss/print_every:.3f}.. "
                      f"Test loss: {test_loss/len(testloader):.3f}.. "
                      f"Train metric: {running_metric/print_every:.3f}.. "
                      f"Test metric: {test_metric/len(testloader):.3f}.. ")
                train_losses.append(running_loss / print_every)
                test_losses.append(test_loss / len(testloader))
                train_metrics.append(running_metric / print_every)
                test_metrics.append(test_metric / len(testloader))
                running_loss = 0
                running_metric = 0
                model.train()
        if swa:
            # Take one SWA snapshot per epoch
            opt.update_swa()

    save_model(model, model_arch, learning_rate, epochs, train_losses, test_losses,
               train_metrics, test_metrics, filepath='models_checkpoints')
    if swa:
        # Replace the model weights with their running average before returning
        opt.swap_swa_sgd()
    return model, train_losses, test_losses, train_metrics, test_metrics
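# The function above uses SWA in manual mode: SWA(optimizer) is constructed
# without swa_start/swa_freq, so nothing is averaged until update_swa() is
# called explicitly (here, once per epoch). A condensed sketch contrasting the
# two modes; the model and optimizer are placeholders:
import torch
from torchcontrib.optim import SWA

model = torch.nn.Linear(4, 1)

# Manual mode: you choose the snapshot points yourself
opt = SWA(torch.optim.SGD(model.parameters(), lr=0.1))
for epoch in range(5):
    # ... train one epoch, calling opt.step() per batch ...
    opt.update_swa()   # snapshot at each epoch boundary
opt.swap_swa_sgd()     # finish with the averaged weights in the model

# Auto mode: snapshots happen every swa_freq calls to step() after swa_start
# opt = SWA(base_opt, swa_start=1000, swa_freq=100, swa_lr=0.05)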
def train(self, train_loader, eval_loader, epoch):
    # Define the optimizer
    if self.args.swa:
        logger.info('SWA training')
        base_opt = torch.optim.SGD(self.model.parameters(), lr=self.args.learning_rate)
        optimizer = SWA(base_opt,
                        swa_start=self.args.swa_start,
                        swa_freq=self.args.swa_freq,
                        swa_lr=self.args.swa_lr)
        scheduler = CyclicLR(optimizer,
                             base_lr=5e-5,
                             max_lr=7e-5,
                             step_size_up=(self.args.epochs * len(train_loader) /
                                           self.args.batch_accumulation),
                             cycle_momentum=False)
    else:
        logger.info('Adam training')
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.args.warmup,
            num_training_steps=(self.args.epochs * len(train_loader) /
                                self.args.batch_accumulation))

    bar = tqdm(range(self.args.train_steps), total=self.args.train_steps)
    train_batches = cycle(train_loader)
    loss_sum = 0.0
    start = time.time()
    self.model.train()
    for step in bar:
        batch = next(train_batches)
        input_ids, input_mask, segment_ids, label_ids = [t.to(self.device) for t in batch]
        loss, _ = self.model(input_ids=input_ids,
                             token_type_ids=segment_ids,
                             attention_mask=input_mask,
                             labels=label_ids)
        if self.gpu_num > 1:
            loss = loss.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # in SWA auto mode this also maintains the weight average
        scheduler.step()
        # optimizer.update_swa()
        loss_sum += loss.cpu().item()
        train_loss = loss_sum / (step + 1)
        bar.set_description("loss {}".format(train_loss))

        if (step + 1) % self.args.eval_steps == 0:
            logger.info("***** Training result *****")
            logger.info('  time %.2fs ', time.time() - start)
            logger.info("  %s = %s", 'global_step', str(step + 1))
            logger.info("  %s = %s", 'train loss', str(train_loss))
            # Evaluate every eval_steps
            self.result = {'epoch': epoch, 'global_step': step + 1, 'loss': train_loss}
            if self.args.swa:
                optimizer.swap_swa_sgd()  # evaluate the averaged weights
            self.evaluate(eval_loader, epoch)
            if self.args.swa:
                optimizer.swap_swa_sgd()  # restore the SGD weights and keep training

    if self.args.swa:
        # Finish the epoch with the averaged weights in the model
        optimizer.swap_swa_sgd()
    logger.info('The training of epoch ' + str(epoch + 1) + ' has finished.')
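# Swapping the averaged weights in for every periodic evaluation and then
# swapping back, as above, is easy to get wrong when there are several exit
# paths. A small context manager makes the pairing explicit; this is a sketch
# (swa_weights is an invented helper, `optimizer` is torchcontrib's SWA):
import contextlib

@contextlib.contextmanager
def swa_weights(optimizer, enabled=True):
    """Temporarily install the SWA running average into the model."""
    if enabled:
        optimizer.swap_swa_sgd()
    try:
        yield
    finally:
        if enabled:
            optimizer.swap_swa_sgd()  # restore the training weights

# usage inside the loop above:
# with swa_weights(optimizer, enabled=self.args.swa):
#     self.evaluate(eval_loader, epoch)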
def __init__(self):
    if args.train is not None:
        self.train_tuple = get_tuple(args.train, bs=args.batch_size,
                                     shuffle=True, drop_last=False)
    if args.valid is not None:
        valid_bsize = 2048 if args.multiGPU else 50
        self.valid_tuple = get_tuple(args.valid, bs=valid_bsize,
                                     shuffle=False, drop_last=False)
    else:
        self.valid_tuple = None

    # Select model; X is the default
    if args.model == "X":
        self.model = ModelX(args)
    elif args.model == "V":
        self.model = ModelV(args)
    elif args.model == "U":
        self.model = ModelU(args)
    elif args.model == "D":
        self.model = ModelD(args)
    elif args.model == 'O':
        self.model = ModelO(args)
    else:
        print(args.model, " is not implemented.")

    # Load pre-trained weights from paths
    if args.loadpre is not None:
        self.model.load(args.loadpre)

    # GPU options
    if args.multiGPU:
        self.model.lxrt_encoder.multi_gpu()
    self.model = self.model.cuda()

    # Losses and optimizer
    self.logsoftmax = nn.LogSoftmax(dim=1)
    self.nllloss = nn.NLLLoss()
    if args.train is not None:
        batch_per_epoch = len(self.train_tuple.loader)
        self.t_total = int(batch_per_epoch * args.epochs // args.acc)
        print("Total Iters: %d" % self.t_total)

    def is_backbone(n):
        if "encoder" in n:
            return True
        elif "embeddings" in n:
            return True
        elif "pooler" in n:
            return True
        print("F: ", n)
        return False

    no_decay = ['bias', 'LayerNorm.weight']
    params = list(self.model.named_parameters())
    if args.reg:
        # Backbone parameters keep the base LR; the remaining parameters train 500x faster
        optimizer_grouped_parameters = [
            {"params": [p for n, p in params if is_backbone(n)], "lr": args.lr},
            {"params": [p for n, p in params if not is_backbone(n)], "lr": args.lr * 500},
        ]
        for n, p in self.model.named_parameters():
            print(n)
        self.optim = AdamW(optimizer_grouped_parameters, lr=args.lr)
    else:
        # No weight decay on biases and LayerNorm weights
        optimizer_grouped_parameters = [
            {'params': [p for n, p in params if not any(nd in n for nd in no_decay)],
             'weight_decay': args.wd},
            {'params': [p for n, p in params if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0},
        ]
        self.optim = AdamW(optimizer_grouped_parameters, lr=args.lr)

    if args.train is not None:
        self.scheduler = get_linear_schedule_with_warmup(
            self.optim, self.t_total * 0.1, self.t_total)

    self.output = args.output
    os.makedirs(self.output, exist_ok=True)

    # SWA, two variants: torchcontrib's optimizer wrapper ...
    if args.contrib:
        self.optim = SWA(self.optim,
                         swa_start=self.t_total * 0.75,
                         swa_freq=5,
                         swa_lr=args.lr)
    # ... or PyTorch's built-in swa_utils pair (AveragedModel + SWALR)
    if args.swa:
        self.swa_model = AveragedModel(self.model)
        self.swa_start = self.t_total * 0.75
        self.swa_scheduler = SWALR(self.optim, swa_lr=args.lr)
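# The last block wires up both SWA APIs: torchcontrib's optimizer wrapper and
# the torch.optim.swa_utils pair (AveragedModel + SWALR) that ships with
# PyTorch since 1.6. A minimal, self-contained sketch of the swa_utils flavour;
# the model, loader, and schedule values are invented for illustration:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 2))
loader = DataLoader(TensorDataset(torch.randn(128, 8), torch.randint(0, 2, (128,))), batch_size=16)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

swa_model = AveragedModel(model)           # keeps the running average of `model`
swa_start = 6                              # epoch at which averaging begins
swa_scheduler = SWALR(optim, swa_lr=5e-4)  # anneals the LR toward swa_lr

for epoch in range(10):
    for x, y in loader:
        optim.zero_grad()
        criterion(model(x), y).backward()
        optim.step()
    if epoch >= swa_start:
        swa_model.update_parameters(model)  # fold current weights into the average
        swa_scheduler.step()

update_bn(loader, swa_model)  # recompute BN statistics for the averaged model
# `swa_model` is the model to evaluate or export.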