class ParamOptim:
    """Parameters bundled with an SWA-wrapped optimizer that updates them.

    ``__post_init__`` instantiates ``optimizer(params, lr=lr, eps=eps)`` and
    stores it, wrapped in ``SWA``, on ``self.optim``.
    """

    # Tensors handed to the optimizer and to gradient clipping.
    params: List[torch.Tensor]
    # Learning rate forwarded to the base optimizer.
    lr: float = 1e-3
    # ``eps`` forwarded to the base optimizer.
    eps: float = 1e-8
    # Max gradient norm; ``None`` disables clipping in ``step``.
    clip_grad: float = None
    # Optimizer class called as ``optimizer(params, lr=..., eps=...)``.
    optimizer: Optimizer = AdamW

    def __post_init__(self):
        """Create the base optimizer and wrap it in ``SWA``."""
        self.optim = SWA(
            self.optimizer(self.params, lr=self.lr, eps=self.eps))

    def set_lr(self, lr):
        """Write ``lr`` into every parameter group and return it."""
        for group in self.optim.param_groups:
            group.update(lr=lr)
        return lr

    def step(self, loss):
        """Run one update for ``loss`` and return ``loss`` unchanged.

        Gradients are cleared, ``loss`` is back-propagated, the gradient
        norm of ``self.params`` is clipped when ``clip_grad`` is set, and
        the wrapped optimizer takes a step.
        """
        self.optim.zero_grad()
        loss.backward()
        max_norm = self.clip_grad
        if max_norm is not None:
            torch.nn.utils.clip_grad_norm_(self.params, max_norm)
        self.optim.step()
        return loss
def train(opt):
    """Train the three-head STR model with SWA, validating periodically.

    Relies on names defined outside this function: ``device``,
    ``data_loader``, ``valid_loader``, ``data``, ``top_converter``,
    ``middle_converter``, ``bottom_converter``, ``top_per_cls``,
    ``mid_per_cls`` and ``bot_per_cls``.

    Args:
        opt: option namespace. Fields read here include ``imgH``, ``imgW``,
            ``num_fiducial``, ``input_channel``, ``output_channel``,
            ``hidden_size``, ``batch_max_length``, ``saved_model``,
            ``experiment_name``, ``num_epoch``, ``grad_clip``,
            ``valInterval``, ``batch_size`` and the three ``*_n_cls`` counts.

    Side effects:
        Appends to ``opt.txt`` and ``log_train.txt`` and writes ``.pth``
        checkpoints under ``./models/<experiment_name>/``; prints progress.
    """
    model = www_model_jamo_vertical.STR(opt, device)
    print(
        'model parameters. height {}, width {}, num of fiducial {}, input channel {}, output channel {}, hidden size {}, batch max length {}'
        .format(opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel,
                opt.output_channel, opt.hidden_size, opt.batch_max_length))

    # Weight initialization: zero the biases, Kaiming-normal the weights.
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initializaed')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:
            # Deliberate best-effort: if the initializer rejects this tensor,
            # fall back to filling weights with ones.
            if 'weight' in name:
                param.data.fill_(1)

    # Load a pretrained model (best-effort: failure is reported, not fatal).
    if opt.saved_model != '':
        base_path = './models'
        pretrained_path = os.path.join(base_path, opt.saved_model)
        print(f'looking for pretrained model from {pretrained_path}')
        try:
            model.load_state_dict(torch.load(pretrained_path))
            print('loading complete ')
        except Exception as e:
            print(e)
            print('coud not find model')

    # Data parallel for multi GPU.
    model = torch.nn.DataParallel(model).to(device)
    model.train()

    # Index 0 (the [GO] token) is ignored by the validation criterion.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)
    # One averager for the whole run: it accumulates between validations and
    # is reset after each validation report, so "Train loss" in the log is
    # the mean over the interval rather than the last batch only.
    loss_avg = utils.Averager()

    # Keep only parameters that require gradients.
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Tranable params : ', sum(params_num))

    # Optimizer: Adam wrapped in SWA (manual mode), LR halved on plateau of
    # the validation accuracy.
    base_opt = torch.optim.Adam(filtered_parameters, lr=0.001)
    optimizer = SWA(base_opt)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='max',
                                                           verbose=True,
                                                           patience=2,
                                                           factor=0.5)

    # Dump the options used for this run.
    with open(f'./models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '---------------------Options-----------------\n'
        opt_log += ''.join(f'{str(k)} : {str(v)}\n'
                           for k, v in vars(opt).items())
        opt_log += '---------------------------------------------\n'
        opt_file.write(opt_log)

    # Start training.
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    swa_count = 0

    for n_epoch in range(opt.num_epoch):
        for n_iter, data_point in enumerate(data_loader):
            image_tensors, top, mid, bot = data_point
            image = image_tensors.to(device)
            text_top, _ = top_converter.encode(
                top, batch_max_length=opt.batch_max_length)
            text_mid, _ = middle_converter.encode(
                mid, batch_max_length=opt.batch_max_length)
            text_bot, _ = bottom_converter.encode(
                bot, batch_max_length=opt.batch_max_length)

            # Teacher forcing: feed all but the last token, predict all but
            # the first.
            pred_top, pred_mid, pred_bot = model(image, text_top[:, :-1],
                                                 text_mid[:, :-1],
                                                 text_bot[:, :-1])
            logits_top = pred_top.view(-1, pred_top.shape[-1])
            logits_mid = pred_mid.view(-1, pred_mid.shape[-1])
            logits_bot = pred_bot.view(-1, pred_bot.shape[-1])
            target_top = text_top[:, 1:].contiguous().view(-1)
            target_mid = text_mid[:, 1:].contiguous().view(-1)
            target_bot = text_bot[:, 1:].contiguous().view(-1)

            # Alternate between reduced focal loss (even iterations) and
            # class-balanced focal loss (odd iterations).
            if n_iter % 2 == 0:
                cost_top = utils.reduced_focal_loss(logits_top,
                                                    target_top,
                                                    ignore_index=0,
                                                    gamma=2,
                                                    alpha=0.25,
                                                    threshold=0.5)
                cost_mid = utils.reduced_focal_loss(logits_mid,
                                                    target_mid,
                                                    ignore_index=0,
                                                    gamma=2,
                                                    alpha=0.25,
                                                    threshold=0.5)
                cost_bot = utils.reduced_focal_loss(logits_bot,
                                                    target_bot,
                                                    ignore_index=0,
                                                    gamma=2,
                                                    alpha=0.25,
                                                    threshold=0.5)
            else:
                cost_top = utils.CB_loss(target_top, logits_top, top_per_cls,
                                         opt.top_n_cls, 'focal', 0.999, 0.5)
                cost_mid = utils.CB_loss(target_mid, logits_mid, mid_per_cls,
                                         opt.middle_n_cls, 'focal', 0.999,
                                         0.5)
                cost_bot = utils.CB_loss(target_bot, logits_bot, bot_per_cls,
                                         opt.bottom_n_cls, 'focal', 0.999,
                                         0.5)

            cost = cost_top * 0.33 + cost_mid * 0.33 + cost_bot * 0.33
            # Detached so the long-lived averager can never hold on to the
            # autograd graph of a batch.
            loss_avg.add(cost.detach())

            model.zero_grad()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
            # Running mean of the loss since the last validation report.
            print(loss_avg.val())

            # Validation every ``opt.valInterval`` iterations (not at 0).
            if n_iter != 0 and n_iter % opt.valInterval == 0:
                elapsed_time = time.time() - start_time
                with open(f'./models/{opt.experiment_name}/log_train.txt',
                          'a') as log:
                    model.eval()
                    with torch.no_grad():
                        (valid_loss, current_accuracy, current_norm_ED,
                         pred_top_str, pred_mid_str, pred_bot_str, label_top,
                         label_mid, label_bot, infer_time,
                         length_of_data) = evaluate.validation_jamo(
                             model, criterion, valid_loader, top_converter,
                             middle_converter, bottom_converter, opt)
                    scheduler.step(current_accuracy)
                    model.train()

                    present_time = time.localtime()
                    # NOTE(review): the "+9" on the hour looks like a fixed
                    # timezone shift and can print hours >= 24 - confirm.
                    loss_log = f'[epoch : {n_epoch}/{opt.num_epoch}] [iter : {n_iter*opt.batch_size} / {int(len(data) * 0.95)}]\n' + f'Train loss : {loss_avg.val():0.5f}, Valid loss : {valid_loss:0.5f}, Elapsed time : {elapsed_time:0.5f}, Present time : {present_time[1]}/{present_time[2]}, {present_time[3]+9} : {present_time[4]}'
                    loss_avg.reset()

                    current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"current_norm_ED":17s}: {current_norm_ED:0.2f}'

                    # Keep the best checkpoints.
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(
                            model.module.state_dict(),
                            f'./models/{opt.experiment_name}/best_accuracy_{round(current_accuracy,2)}.pth'
                        )
                    if current_norm_ED > best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(
                            model.module.state_dict(),
                            f'./models/{opt.experiment_name}/best_norm_ED.pth')

                    best_model_log = f'{"Best accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'
                    loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                    print(loss_model_log)
                    log.write(loss_model_log + '\n')

                    # Show a few random ground-truth / prediction pairs.
                    dashed_line = '-' * 80
                    head = f'{"Ground Truth":25s} | {"Prediction" :25s}| T/F'
                    predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'

                    # Never ask for more samples than validation returned;
                    # sampling without replacement would raise otherwise.
                    n_show = min(5, len(label_top))
                    random_idx = np.random.choice(range(len(label_top)),
                                                  size=n_show,
                                                  replace=False)
                    label_concat = np.concatenate([
                        np.asarray(label_top).reshape(1, -1),
                        np.asarray(label_mid).reshape(1, -1),
                        np.asarray(label_bot).reshape(1, -1)
                    ],
                                                  axis=0).reshape(3, -1)
                    pred_concat = np.concatenate([
                        np.asarray(pred_top_str).reshape(1, -1),
                        np.asarray(pred_mid_str).reshape(1, -1),
                        np.asarray(pred_bot_str).reshape(1, -1)
                    ],
                                                 axis=0).reshape(3, -1)

                    for i in random_idx:
                        label_sample = label_concat[:, i]
                        pred_sample = pred_concat[:, i]
                        gt_str = utils.str_combine(label_sample[0],
                                                   label_sample[1],
                                                   label_sample[2])
                        pred_str = utils.str_combine(pred_sample[0],
                                                     pred_sample[1],
                                                     pred_sample[2])
                        predicted_result_log += f'{gt_str:25s} | {pred_str:25s} | \t{str(pred_str == gt_str)}\n'
                    predicted_result_log += f'{dashed_line}'
                    print(predicted_result_log)
                    log.write(predicted_result_log + '\n')

                # Stochastic weight averaging: fold the current weights into
                # the running average after every validation.
                optimizer.update_swa()
                swa_count += 1
                if swa_count % 5 == 0:
                    # Swap in the averaged weights to checkpoint them, then
                    # swap back so training resumes from the SGD weights;
                    # swap_swa_sgd() must be called again before continuing.
                    optimizer.swap_swa_sgd()
                    torch.save(
                        model.module.state_dict(),
                        f'./models/{opt.experiment_name}/swa_{swa_count}.pth')
                    optimizer.swap_swa_sgd()

        # Periodic per-epoch checkpoint.
        if n_epoch % 5 == 0:
            torch.save(model.module.state_dict(),
                       f'./models/{opt.experiment_name}/{n_epoch}.pth')
def train(model_name, optim='adam'):
    """Fine-tune a pretrained two-head model with SWA and log to TensorBoard.

    Relies on names defined outside this function: ``config``, ``device``,
    ``train_writer``, ``evaluate`` and the dataset / model / metric helpers.

    Args:
        model_name: ``'EF'`` or ``'EFGAP'``; selects the backbone and the
            pretrained weight file that is loaded.
        optim: ``'adam'`` or ``'sgd'``; any other value selects SGD with
            Nesterov momentum 0.9.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    # Fail fast: without this, an unknown name left ``model`` unbound and the
    # function crashed later with an unrelated-looking error.
    if model_name not in ('EF', 'EFGAP'):
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected 'EF' or 'EFGAP'")

    train_dataset = PretrainDataset(output_shape=config['image_resolution'])
    train_loader = DataLoader(train_dataset,
                              batch_size=config['batch_size'],
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True,
                              drop_last=True)

    val_dataset = IDRND_dataset_CV(fold=0,
                                   mode=config['mode'].replace('train', 'val'),
                                   double_loss_mode=True,
                                   output_shape=config['image_resolution'])
    val_loader = DataLoader(val_dataset,
                            batch_size=config['batch_size'],
                            shuffle=True,
                            num_workers=4,
                            drop_last=False)

    if model_name == 'EF':
        model = DoubleLossModelTwoHead(base_model=EfficientNet.from_pretrained(
            'efficientnet-b3')).to(device)
        model.load_state_dict(
            torch.load(
                f"../models_weights/pretrained/{model_name}_{4}_2.0090592697255896_1.0.pth"
            ))
    else:  # 'EFGAP'
        model = DoubleLossModelTwoHead(
            base_model=EfficientNetGAP.from_pretrained('efficientnet-b3')).to(
                device)
        model.load_state_dict(
            torch.load(
                f"../models_weights/pretrained/{model_name}_{4}_2.3281182915644134_1.0.pth"
            ))

    criterion = FocalLoss(add_weight=False).to(device)
    criterion4class = CrossEntropyLoss().to(device)

    if optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config['learning_rate'],
                                     weight_decay=config['weight_decay'])
    elif optim == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=config['learning_rate'],
                                    weight_decay=config['weight_decay'],
                                    nesterov=False)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    momentum=0.9,
                                    lr=config['learning_rate'],
                                    weight_decay=config['weight_decay'],
                                    nesterov=True)

    # SWA runs in automatic mode; its schedule is expressed in optimizer
    # steps, hence the conversion from epochs.
    steps_per_epoch = len(train_loader) - 15
    swa = SWA(optimizer,
              swa_start=config['swa_start'] * steps_per_epoch,
              swa_freq=int(config['swa_freq'] * steps_per_epoch),
              swa_lr=config['learning_rate'] / 10)
    scheduler = ExponentialLR(swa, gamma=0.9)

    global_step = 0
    for epoch in trange(10):
        # The first five epochs are skipped; only the LR schedule advances.
        if epoch < 5:
            scheduler.step()
            continue

        model.train()
        train_bar = tqdm(train_loader)
        train_bar.set_description_str(desc=f"N epochs - {epoch}")

        for step, batch in enumerate(train_bar):
            global_step += 1
            image = batch['image'].to(device)
            label4class = batch['label0'].to(device)
            label = batch['label1'].to(device)

            output4class, output = model(image)
            loss4class = criterion4class(output4class, label4class)
            loss = criterion(output.squeeze(), label)
            swa.zero_grad()
            total_loss = loss4class * 0.5 + loss * 0.5
            total_loss.backward()
            swa.step()

            # Log the learning rate actually stored in the optimizer:
            # scheduler.get_lr() called outside scheduler.step() can report
            # a value that was never applied.
            train_writer.add_scalar(tag="learning_rate",
                                    scalar_value=swa.param_groups[0]['lr'],
                                    global_step=global_step)
            train_writer.add_scalar(tag="BinaryLoss",
                                    scalar_value=loss.item(),
                                    global_step=global_step)
            train_writer.add_scalar(tag="SoftMaxLoss",
                                    scalar_value=loss4class.item(),
                                    global_step=global_step)
            train_bar.set_postfix_str(f"Loss = {loss.item()}")

            try:
                train_writer.add_scalar(tag="idrnd_score",
                                        scalar_value=idrnd_score_pytorch(
                                            label, output),
                                        global_step=global_step)
                train_writer.add_scalar(tag="far_score",
                                        scalar_value=far_score(label, output),
                                        global_step=global_step)
                train_writer.add_scalar(tag="frr_score",
                                        scalar_value=frr_score(label, output),
                                        global_step=global_step)
                train_writer.add_scalar(tag="accuracy",
                                        scalar_value=bce_accuracy(
                                            label, output),
                                        global_step=global_step)
            except Exception:
                # Deliberate best-effort: a batch on which a metric cannot be
                # computed must not interrupt training.
                pass

        if (epoch > config['swa_start']
                and epoch % 2 == 0) or (epoch == config['number_epochs'] - 1):
            # Temporarily load the averaged weights, run bn_update with
            # them, then restore the SGD weights for further training.
            swa.swap_swa_sgd()
            swa.bn_update(train_loader, model, device)
            swa.swap_swa_sgd()

        scheduler.step()
        evaluate(model, val_loader, epoch, model_name)
def train(model,
          device,
          trainloader,
          testloader,
          optimizer,
          criterion,
          metric,
          epochs,
          learning_rate,
          swa=True,
          enable_scheduler=True,
          model_arch=''):
    """Train ``model`` and evaluate it on ``testloader`` every 100 steps.

    Args:
        model: network to train; moved to ``device``.
        device: device the model and every batch are moved to.
        trainloader: iterable of ``(inputs, labels)`` training batches.
        testloader: iterable of ``(inputs, labels)`` evaluation batches.
        optimizer: base optimizer; wrapped in ``SWA`` when ``swa`` is true.
        criterion: loss called as ``criterion(outputs, labels.float())``.
        metric: score called as ``metric(outputs, labels.float())``.
        epochs: number of passes over ``trainloader``.
        learning_rate: only forwarded to ``save_model``.
        swa: enable stochastic weight averaging (manual mode).
        enable_scheduler: enable cosine annealing of the learning rate.
        model_arch: only forwarded to ``save_model``.

    Returns:
        ``(model, train_losses, test_losses, train_metrics, test_metrics)``;
        the four lists get one entry per evaluation.
    """
    model.to(device)

    steps = 0
    running_loss = 0
    running_metric = 0
    print_every = 100
    train_losses, test_losses = [], []
    train_metrics, test_metrics = [], []

    # Stochastic weight averaging wraps the caller's optimizer.
    opt = SWA(optimizer) if swa else optimizer

    # Cosine annealing of the learning rate.
    # NOTE(review): T_max is the number of batches while step() runs once
    # per epoch - confirm this is the intended schedule.
    if enable_scheduler:
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   len(trainloader),
                                                   eta_min=0.0000001)

    for epoch in range(epochs):
        if enable_scheduler:
            scheduler.step()

        for inputs, labels in trainloader:
            steps += 1
            inputs, labels = inputs.to(device), labels.to(device)

            opt.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels.float())
            loss.backward()
            opt.step()

            # Accumulate detached values only: summing the raw loss would
            # keep the autograd graph of every batch alive until the next
            # report (and forever inside the returned history lists).
            running_loss += loss.detach()
            with torch.no_grad():
                running_metric += metric(outputs, labels.float())

            if steps % print_every == 0:
                test_loss = 0
                test_metric = 0
                model.eval()
                with torch.no_grad():
                    for inputs, labels in testloader:
                        inputs, labels = inputs.to(device), labels.to(device)
                        outputs = model(inputs)
                        test_loss += criterion(outputs, labels.float())
                        test_metric += metric(outputs, labels.float())

                print(f"Epoch {epoch+1}/{epochs}.. "
                      f"Train loss: {running_loss/print_every:.3f}.. "
                      f"Test loss: {test_loss/len(testloader):.3f}.. "
                      f"Train metric: {running_metric/print_every:.3f}.. "
                      f"Test metric: {test_metric/len(testloader):.3f}.. ")

                train_losses.append(running_loss / print_every)
                test_losses.append(test_loss / len(testloader))
                train_metrics.append(running_metric / print_every)
                test_metrics.append(test_metric / len(testloader))

                running_loss = 0
                running_metric = 0
                model.train()

        # Fold the weights reached at the end of the epoch into the average.
        if swa:
            opt.update_swa()

        save_model(model,
                   model_arch,
                   learning_rate,
                   epochs,
                   train_losses,
                   test_losses,
                   train_metrics,
                   test_metrics,
                   filepath='models_checkpoints')

    # Training is over: load the averaged weights into the model.
    if swa:
        opt.swap_swa_sgd()

    return model, train_losses, test_losses, train_metrics, test_metrics
class HM:
    """Training / inference wrapper around one of the ``Model*`` classifiers.

    Configuration comes from the module-level ``args`` namespace. The model
    is expected to return two-class logits; column 1 of their log-softmax is
    used as the ROC-AUC score.
    """

    def __init__(self):
        """Build data tuples, model, optimizer, scheduler and SWA helpers.

        Raises:
            ValueError: if ``args.model`` is not a supported model key.
        """
        if args.train is not None:
            self.train_tuple = get_tuple(args.train,
                                         bs=args.batch_size,
                                         shuffle=True,
                                         drop_last=False)

        if args.valid is not None:
            valid_bsize = 2048 if args.multiGPU else 50
            self.valid_tuple = get_tuple(args.valid,
                                         bs=valid_bsize,
                                         shuffle=False,
                                         drop_last=False)
        else:
            self.valid_tuple = None

        # Select Model, X is default
        if args.model == "X":
            self.model = ModelX(args)
        elif args.model == "V":
            self.model = ModelV(args)
        elif args.model == "U":
            self.model = ModelU(args)
        elif args.model == "D":
            self.model = ModelD(args)
        elif args.model == 'O':
            self.model = ModelO(args)
        else:
            # Without a model every following statement would fail with an
            # AttributeError; stop here with a clear message instead.
            print(args.model, " is not implemented.")
            raise ValueError("%s is not implemented." % args.model)

        # Load pre-trained weights from paths
        if args.loadpre is not None:
            self.model.load(args.loadpre)

        # GPU options
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()
        self.model = self.model.cuda()

        # Losses and optimizer
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.nllloss = nn.NLLLoss()

        if args.train is not None:
            batch_per_epoch = len(self.train_tuple.loader)
            # Number of optimizer updates (accumulation steps folded in).
            self.t_total = int(batch_per_epoch * args.epochs // args.acc)
            print("Total Iters: %d" % self.t_total)

        def is_backbone(n):
            """Return True for encoder / embeddings / pooler parameters."""
            if "encoder" in n:
                return True
            elif "embeddings" in n:
                return True
            elif "pooler" in n:
                return True
            print("F: ", n)
            return False

        no_decay = ['bias', 'LayerNorm.weight']
        params = list(self.model.named_parameters())

        if args.reg:
            # Non-backbone parameters train with a 500x larger LR.
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in params if is_backbone(n)],
                    "lr": args.lr
                },
                {
                    "params": [p for n, p in params if not is_backbone(n)],
                    "lr": args.lr * 500
                },
            ]
            for n, p in self.model.named_parameters():
                print(n)
            self.optim = AdamW(optimizer_grouped_parameters, lr=args.lr)
        else:
            # Biases and LayerNorm weights are excluded from weight decay.
            optimizer_grouped_parameters = [{
                'params':
                [p for n, p in params if not any(nd in n for nd in no_decay)],
                'weight_decay':
                args.wd
            }, {
                'params':
                [p for n, p in params if any(nd in n for nd in no_decay)],
                'weight_decay':
                0.0
            }]
            self.optim = AdamW(optimizer_grouped_parameters, lr=args.lr)

        if args.train is not None:
            # Linear warm-up over the first 10% of the updates.
            self.scheduler = get_linear_schedule_with_warmup(
                self.optim, self.t_total * 0.1, self.t_total)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

        # SWA Method: torchcontrib wrapper or torch.optim.swa_utils, both
        # starting after 75% of the updates.
        if args.contrib:
            self.optim = SWA(self.optim,
                             swa_start=self.t_total * 0.75,
                             swa_freq=5,
                             swa_lr=args.lr)

        if args.swa:
            self.swa_model = AveragedModel(self.model)
            self.swa_start = self.t_total * 0.75
            self.swa_scheduler = SWALR(self.optim, swa_lr=args.lr)

    def train(self, train_tuple, eval_tuple):
        """Run the training loop, validating every 250 optimizer updates.

        Args:
            train_tuple: ``(dataset, loader, evaluator)`` used for training.
            eval_tuple: tuple forwarded to ``evaluate`` when a validation
                tuple was configured.
        """
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))
                        ) if args.tqdm else (lambda x: x)

        print("Batches:", len(loader))

        self.optim.zero_grad()

        best_roc = 0.
        # Initialised together with best_roc so the "BEST" log line can
        # always be formatted, even if no validation has beaten best_roc yet.
        best_acc = 0.
        ups = 0
        total_loss = 0.

        for epoch in range(args.epochs):

            if args.reg:
                if args.model != "X":
                    print(self.model.model.layer_weights)

            id2ans = {}
            id2prob = {}

            for i, (ids, feats, boxes, sent,
                    target) in iter_wrapper(enumerate(loader)):

                if ups == args.midsave:
                    self.save("MID")

                self.model.train()
                if args.swa:
                    self.swa_model.train()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.long(
                ).cuda()

                # Model expects visual feats as tuple of feats & boxes
                logit = self.model(sent, (feats, boxes))

                # LogSoftmax keeps the ordering of the logits, so it can be
                # used directly for predictions and for the ROC-AUC score.
                logit = self.logsoftmax(logit)
                score = logit[:, 1]

                if i < 1:
                    print(logit[0, :].detach())

                # LogSoftmax followed by NLLLoss is the cross-entropy loss.
                loss = self.nllloss(logit.view(-1, 2), target.view(-1))

                # Scale by batch size (batches differ in size because
                # drop_last is off) and divide by the accumulation factor.
                # Not scaling the loss will worsen performance by ~2abs%
                loss = loss * logit.size(0) / args.acc
                loss.backward()

                total_loss += loss.detach().item()

                # Acts as argmax - index (0 or 1) of the higher score.
                _, predict = logit.detach().max(1)

                # Labels for accuracy
                for qid, l in zip(ids, predict.cpu().numpy()):
                    id2ans[qid] = l

                # Probabilities for ROC AUC
                for qid, l in zip(ids, score.detach().cpu().numpy()):
                    id2prob[qid] = l

                if (i + 1) % args.acc == 0:

                    nn.utils.clip_grad_norm_(self.model.parameters(),
                                             args.clip)
                    self.optim.step()

                    if (args.swa) and (ups > self.swa_start):
                        self.swa_model.update_parameters(self.model)
                        self.swa_scheduler.step()
                    else:
                        self.scheduler.step()

                    self.optim.zero_grad()
                    ups += 1

                    # Do Validation in between
                    if ups % 250 == 0:

                        log_str = "\nEpoch(U) %d(%d): Train AC %0.2f RA %0.4f LOSS %0.4f\n" % (
                            epoch, ups, evaluator.evaluate(id2ans) * 100,
                            evaluator.roc_auc(id2prob) * 100, total_loss)

                        # Set loss back to 0 after printing it
                        total_loss = 0.

                        if self.valid_tuple is not None:  # Do Validation
                            acc, roc_auc = self.evaluate(eval_tuple)
                            if roc_auc > best_roc:
                                best_roc = roc_auc
                                best_acc = acc

                            log_str += "\nEpoch(U) %d(%d): DEV AC %0.2f RA %0.4f \n" % (
                                epoch, ups, acc * 100., roc_auc * 100)
                            log_str += "Epoch(U) %d(%d): BEST AC %0.2f RA %0.4f \n" % (
                                epoch, ups, best_acc * 100., best_roc * 100.)

                        print(log_str, end='')

                        with open(self.output + "/log.log", 'a') as f:
                            f.write(log_str)
                            f.flush()

            if (epoch + 1) == args.epochs:
                if args.contrib:
                    # End of training: load the averaged weights.
                    self.optim.swap_swa_sgd()

        self.save("LAST" + args.train)

    def predict(self, eval_tuple: DataTuple, dump=None, out_csv=True):
        """Predict labels and scores for every item yielded by the loader.

        Args:
            eval_tuple: ``(dataset, loader, evaluator)`` to run on.
            dump: optional output path; results are dumped when given.
            out_csv: dump via ``dump_csv`` (labels and scores) when true,
                otherwise via ``dump_result`` (labels only).

        Returns:
            ``(id2ans, id2prob)``: predicted label and class-1
            log-probability keyed by item id.
        """
        dset, loader, evaluator = eval_tuple
        id2ans = {}
        id2prob = {}

        for i, datum_tuple in enumerate(loader):

            ids, feats, boxes, sent = datum_tuple[:4]

            self.model.eval()
            if args.swa:
                self.swa_model.eval()

            with torch.no_grad():

                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(sent, (feats, boxes))

                # LogSoftmax keeps the ordering of the logits.
                logit = self.logsoftmax(logit)
                # NOTE(review): the score always comes from self.model while
                # the label below comes from swa_model when args.swa is set
                # - confirm this mix is intended.
                score = logit[:, 1]

                if args.swa:
                    logit = self.swa_model(sent, (feats, boxes))
                    logit = self.logsoftmax(logit)

                _, predict = logit.max(1)

                for qid, l in zip(ids, predict.cpu().numpy()):
                    id2ans[qid] = l

                # Probabilities for ROC AUC
                for qid, l in zip(ids, score.cpu().numpy()):
                    id2prob[qid] = l

        if dump is not None:
            if out_csv:
                evaluator.dump_csv(id2ans, id2prob, dump)
            else:
                evaluator.dump_result(id2ans, dump)

        return id2ans, id2prob

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple; return ``(acc, roc_auc)``."""
        id2ans, id2prob = self.predict(eval_tuple, dump=dump)

        acc = eval_tuple.evaluator.evaluate(id2ans)
        roc_auc = eval_tuple.evaluator.roc_auc(id2prob)

        return acc, roc_auc

    def save(self, name):
        """Save ``<output>/<name>.pth``: the SWA model when ``args.swa``."""
        if args.swa:
            torch.save(self.swa_model.state_dict(),
                       os.path.join(self.output, "%s.pth" % name))
        else:
            torch.save(self.model.state_dict(),
                       os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        """Load a checkpoint saved by ``save`` into ``self.model``.

        Keys starting with ``n_averaged`` are skipped and a leading
        ``module.`` prefix is stripped, so SWA checkpoints load as well.
        """
        print("Load model from %s" % path)

        state_dict = torch.load("%s" % path)
        new_state_dict = {}
        for key, value in state_dict.items():
            # N_averaged is a key in SWA models we cannot load, so we skip it
            if key.startswith("n_averaged"):
                print("n_averaged:", value)
                continue
            # SWA Models will start with module
            if key.startswith("module."):
                new_state_dict[key[len("module."):]] = value
            else:
                new_state_dict[key] = value
        state_dict = new_state_dict
        self.model.load_state_dict(state_dict)
def main(args):
    """Fine-tune the classifier with SWA and validate after every epoch.

    Relies on names defined outside this function: ``device``, ``train_cfg``,
    ``val_cfg``, ``net``, ``amp`` and the ``Precision`` / ``Recall`` / ``F1``
    metric helpers.

    Args:
        args: namespace providing ``model_path``, ``dataset_path``,
            ``batch_size``, ``num_epochs`` and ``log_step``.

    Side effects:
        Loads ``net_F.ckpt`` from ``args.model_path`` and overwrites
        ``net_P.ckpt`` / ``net_R.ckpt`` / ``net_F.ckpt`` there whenever the
        corresponding validation score improves.
    """
    # Create model directory for saving trained models
    os.makedirs(args.model_path, exist_ok=True)

    test_dataset = a2d_dataset.A2DDataset(train_cfg, args.dataset_path)
    data_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=12)

    # Define model, Loss, and optimizer
    model = net('se_resnext101')
    model.cuda()
    model.load_state_dict(
        torch.load(os.path.join(args.model_path, 'net_F.ckpt')))

    criterion = nn.BCEWithLogitsLoss()
    base_optimizer = optim.SGD(model.parameters(), lr=0.005, momentum=0.9)
    optimizer = SWA(base_optimizer, swa_start=1, swa_freq=5, swa_lr=0.005)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     [5, 10, 20, 40, 75],
                                                     gamma=0.25)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level="O1",
                                      verbosity=0)

    # Train the models
    total_step = len(data_loader)
    best_P, best_R, best_F = 0, 0, 0

    for epoch in range(args.num_epochs):
        # Validation below switches to eval mode; switch back, otherwise
        # every epoch after the first would train in eval mode.
        model.train()

        # Report the LR stored in the optimizer; scheduler.get_lr() called
        # outside scheduler.step() can report a value that was never applied.
        print('epoch:{}, lr:{}'.format(epoch, optimizer.param_groups[0]['lr']))
        t1 = time.time()

        for i, data in enumerate(data_loader):
            optimizer.zero_grad()

            # mini-batch
            images = data[0].to(device)
            labels = data[1].type(torch.FloatTensor).to(device)

            # Forward, backward and optimize
            outputs = model(images)
            loss = criterion(outputs, labels)
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=5.0,
                                           norm_type=2)
            optimizer.step()
            optimizer.zero_grad()

            # Log info
            if i % args.log_step == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch, args.num_epochs, i, total_step, loss.item()))

        # Validate and checkpoint the averaged (SWA) weights.
        optimizer.swap_swa_sgd()
        scheduler.step()

        t2 = time.time()
        print('Time Spend per epoch: ', t2 - t1)

        val_dataset = a2d_dataset.A2DDataset(val_cfg, args.dataset_path)
        val_data_loader = DataLoader(val_dataset,
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=1)

        # One row per validation sample; the model is assumed to emit 43
        # outputs per sample - TODO confirm against the dataset config.
        X = np.zeros((len(val_data_loader), 43))
        Y = np.zeros((len(val_data_loader), 43))

        model.eval()
        with torch.no_grad():
            for batch_idx, data in enumerate(val_data_loader):
                # mini-batch
                images = data[0].to(device)
                labels = data[1].type(torch.FloatTensor).to(device)

                output = model(images)
                output = torch.sigmoid(output).cpu().detach().numpy()
                target = labels.cpu().detach().numpy()

                # Binarise the probabilities at 0.5.
                output[output >= 0.5] = 1
                output[output < 0.5] = 0

                X[batch_idx, :] = output
                Y[batch_idx, :] = target

        P = Precision(X, Y)
        R = Recall(X, Y)
        F = F1(X, Y)
        print('Precision: {:.1f} Recall: {:.1f} F1: {:.1f}'.format(
            100 * P, 100 * R, 100 * F))

        if (P > best_P):
            torch.save(model.state_dict(),
                       os.path.join(args.model_path, 'net_P.ckpt'))
            best_P = P
        if (R > best_R):
            torch.save(model.state_dict(),
                       os.path.join(args.model_path, 'net_R.ckpt'))
            best_R = R
        if (F > best_F):
            torch.save(model.state_dict(),
                       os.path.join(args.model_path, 'net_F.ckpt'))
            best_F = F

        # Restore the SGD weights so the next epoch continues from them
        # rather than from the average (swap_swa_sgd must be called again
        # before training continues).
        optimizer.swap_swa_sgd()
class Optimizer:
    """Deferred optimizer: configured first, bound to parameters later.

    Subclasses provide ``optimizer_cls``. The concrete optimizer is created
    by ``set_parameters`` and optionally wrapped in torchcontrib's ``SWA``
    when ``swa_start`` was given.
    """

    optimizer_cls = None
    optimizer = None
    parameters = None

    def __init__(self,
                 gradient_clipping,
                 swa_start=None,
                 swa_freq=None,
                 swa_lr=None,
                 **kwargs):
        """Store the configuration; ``kwargs`` go to ``optimizer_cls``."""
        self.swa_start = swa_start
        self.swa_freq = swa_freq
        self.swa_lr = swa_lr
        self.gradient_clipping = gradient_clipping
        self.optimizer_kwargs = kwargs

    def set_parameters(self, parameters):
        """Create the optimizer (and SWA wrapper, if configured)."""
        self.parameters = tuple(parameters)
        base = self.optimizer_cls(self.parameters, **self.optimizer_kwargs)
        self.optimizer = base
        if self.swa_start is None:
            return
        from torchcontrib.optim import SWA
        assert self.swa_freq is not None, self.swa_freq
        assert self.swa_lr is not None, self.swa_lr
        self.optimizer = SWA(base,
                             swa_start=self.swa_start,
                             swa_freq=self.swa_freq,
                             swa_lr=self.swa_lr)

    def check_if_set(self):
        """Assert that ``set_parameters`` has been called."""
        assert self.optimizer is not None, \
            'The optimizer is not initialized, call set_parameter before' \
            ' using any of the optimizer functions'

    def zero_grad(self):
        """Forward to the wrapped optimizer's ``zero_grad``."""
        self.check_if_set()
        return self.optimizer.zero_grad()

    def step(self):
        """Forward to the wrapped optimizer's ``step``."""
        self.check_if_set()
        return self.optimizer.step()

    def swap_swa_sgd(self):
        """Swap SWA and SGD weights; only valid for an SWA optimizer."""
        self.check_if_set()
        from torchcontrib.optim import SWA
        assert isinstance(self.optimizer, SWA), self.optimizer
        return self.optimizer.swap_swa_sgd()

    def clip_grad(self):
        """Clip the gradient norm of the parameters; return the norm."""
        self.check_if_set()
        # Todo: report clipped and unclipped
        # Todo: allow clip=None but still report grad_norm
        max_norm = self.gradient_clipping
        return torch.nn.utils.clip_grad_norm_(self.parameters, max_norm)

    def to(self, device):
        """Move every tensor in the optimizer state to ``device``.

        A ``device`` of ``None`` is a no-op.
        """
        if device is None:
            return
        self.check_if_set()
        for per_param in self.optimizer.state.values():
            tensor_keys = [
                key for key, value in per_param.items()
                if torch.is_tensor(value)
            ]
            for key in tensor_keys:
                per_param[key] = per_param[key].to(device)

    def cpu(self):
        """Move the optimizer state to the CPU."""
        return self.to('cpu')

    def cuda(self, device=None):
        """Move the optimizer state to a CUDA device (default device if None)."""
        assert device is None or isinstance(device, int), device
        target = torch.device('cuda') if device is None else device
        return self.to(target)

    def load_state_dict(self, state_dict):
        """Forward to the wrapped optimizer's ``load_state_dict``."""
        self.check_if_set()
        return self.optimizer.load_state_dict(state_dict)

    def state_dict(self):
        """Return the wrapped optimizer's ``state_dict``."""
        self.check_if_set()
        return self.optimizer.state_dict()
class Trainer: def __init__( self, net: ModuloNet, metrics=["cohen_kappa", "f1", "accuracy"], epochs=30, metric_to_maximize="accuracy", patience=None, batch_size=32, save_folder=None, loss=None, regularization=None, swa=None, optimizer=None, num_workers=0, net_methods=None ): if optimizer is None: optimizer = {'type': 'adam', 'args': {'lr': 1e-3}} if loss is None: loss = {'type': 'cross_entropy', 'args': {}} self.net_methods = net_methods if net_methods is not None else [] print('METHODS') print(self.net_methods) self.net = net print('####################') print("Device: ", net.device) print('Using:', num_workers, ' workers') print('Trainable params', sum(p.numel() for p in net.parameters() if p.requires_grad)) print('Total params', sum(p.numel() for p in net.parameters() if p.requires_grad)) print('####################') self.loss_function = loss_functions[loss['type']](**loss['args']) self.optimizer_params = optimizer self.swa_params = swa self.regularization = [] if regularization is not None: for regularizer in regularization: self.regularization += [ regularizers[regularizer['type']](self.net, **regularizer['args'])] self.reset_optimizer() self.metrics = { score: score_function for score, score_function in score_functions.items() if score in metrics + [metric_to_maximize] } self.iterations = 0 self.epochs = epochs self.metric_to_maximize = metric_to_maximize self.patience = patience if patience else epochs self.loss_values = [] self.save_folder = save_folder self.batch_size = batch_size self.num_workers = num_workers def reset_optimizer(self): self.base_optimizer = optimizers[self.optimizer_params['type']](self.net.parameters(), **self.optimizer_params[ 'args']) if self.swa_params is not None: self.optimizer = SWA(self.base_optimizer, **self.swa_params) self.swa = True self.averaged_weights = False else: self.optimizer = self.base_optimizer self.swa = False def on_batch_start(self): pass def on_epoch_end(self): pass def validate(self, validation_dataset, 
return_metrics_per_records=False, verbose=False): self.net.eval() if self.swa: self.optimizer.swap_swa_sgd() self.averaged_weights = not self.averaged_weights metrics_epoch = { metric: [] for metric in self.metrics.keys() } metrics_per_records = {} hypnograms = {} predictions = self.net.predict_on_dataset(validation_dataset, return_prob=False, verbose=verbose) record_weights = [] for record in validation_dataset.records: metrics_per_records[record] = {} hypnogram_target = validation_dataset.hypnogram[record] hypnogram_predicted = predictions[record] hypnograms[os.path.split(record)[-2]] = { 'predicted': hypnogram_predicted.astype(int).tolist(), 'target': hypnogram_target.astype(int).tolist()} record_weights += [np.sum(hypnogram_target >= 0)] for metric, metric_function in self.metrics.items(): metric_value = metric_function(hypnogram_target, hypnogram_predicted) metrics_per_records[record][metric] = metric_value metrics_epoch[metric].append(metric_value) record_weights = np.array(record_weights) for metric in metrics_epoch.keys(): metrics_epoch[metric] = np.array(metrics_epoch[metric]) record_weights_tp = record_weights[~np.isnan(metrics_epoch[metric])] metrics_epoch[metric] = metrics_epoch[metric][~np.isnan(metrics_epoch[metric])] try: metrics_epoch[metric] = np.average(metrics_epoch[metric], weights=record_weights_tp) except ZeroDivisionError: metrics_epoch[metric] = np.nan if self.metric_to_maximize == metric: value = metrics_epoch[metric] if self.swa: if self.averaged_weights: self.optimizer.swap_swa_sgd() self.averaged_weights = not self.averaged_weights if return_metrics_per_records: return metrics_epoch, value, metrics_per_records, hypnograms else: return metrics_epoch, value def train_on_batch(self, data, mask=-1): # 1. 
train network # Set network in train mode self.net.train() # Retrieve inputs args, hypnogram = self.net.get_args(data) device = self.net.device hypnogram = hypnogram.to(device) mask = [hypnogram != mask] hypnogram = hypnogram[mask] # zero the network parameters gradien self.optimizer.zero_grad() # forward + backward output = self.net.forward(*args)[0] output = output[mask] loss_train = self.loss_function(output, hypnogram) if self.regularization is not None: for regularizer in self.regularization: regularizer.regularized_all_param(loss_train) loss_train.backward() if isinstance(hypnogram, tuple): hypnogram = hypnogram[0] self.iterations += 1 return output, loss_train, hypnogram def validate_on_batch(self, data, mask=-1): # 2. Evaluate network on validation # Set network in eval mode self.net.eval() # Retrieve inputs args, hypnogram = self.net.get_args(data) device = self.net.device hypnogram = hypnogram.to(device) mask = [hypnogram != mask] hypnogram = hypnogram[mask] # forward output = self.net.forward(*args)[0] output = output[mask] loss_validation = self.loss_function(output, hypnogram) if isinstance(hypnogram, tuple): hypnogram = hypnogram[0] return output, loss_validation, hypnogram def train(self, train_dataset, validation_dataset, verbose=1, reset_optimizer=True): """ for epoch: for batch on train set: train net with optimizer SGD eval on a random batch of validation set print metrics on train set and val set every 1% of dataset Evaluate metrics BY RECORD, take mean if metric_to_maximize value > best_value: store best_net* else: patience += 1 if patience to big: return """ if reset_optimizer: self.reset_optimizer() if self.save_folder: self.save_weights('best_net') dataloader_train = DataLoader( train_dataset, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True ) metrics_final = { metric: 0 for metric in self.metrics.keys() } best_value = 0 counter_patience = 0 for epoch in range(0, self.epochs): if verbose == 0: 
print('EPOCH:', epoch) # init running_loss running_loss_train_epoch = 0 running_metrics = {metric: 0 for metric in self.metrics.keys()} buffer_outputs_train = ([], []) if verbose > 0: # Configurate progress bar t = tqdm(dataloader_train, 0) t.set_description("EPOCH {}".format(epoch)) update_postfix_every = max(int(len(t) * 0.05), 1) counter_update_postfix = 0 else: t = dataloader_train t_start_train = time.time() for i, data in enumerate(t): self.on_batch_start() if verbose > 0: if (i + 1) % update_postfix_every == 0 and i != 0: # compute desired metrics each update_postfix_every for metric_name, metric_function in self.metrics.items(): running_metrics[metric_name] += metric_function( buffer_outputs_train[0], buffer_outputs_train[1] ) buffer_outputs_train = ([], []) counter_update_postfix += 1 t.set_postfix( loss=running_loss_train_epoch / (i + 1), **{ k: v / counter_update_postfix for k, v in running_metrics.items() } ) self.loss_values.append((running_loss_train_epoch, i + 1)) # train output, loss_train, hypnogram = self.train_on_batch(data) # fill metrics for print running_loss_train_epoch += loss_train.item() buffer_outputs_train[0].extend(list(output.max(1)[1].cpu().numpy())) buffer_outputs_train[1].extend(list(hypnogram.cpu().numpy().flatten())) # gradient descent self.optimizer.step() t_stop_train = time.time() t_start_validation = time.time() metrics_epoch, value = self.validate(validation_dataset=validation_dataset) t_stop_validation = time.time() metrics_epoch["training_duration"] = t_stop_train - t_start_train metrics_epoch["validation_duration"] = t_stop_validation - t_start_validation if self.save_folder: self.save_weights(str(epoch) + "_net") json.dump(metrics_epoch, open(self.save_folder + str(epoch) + "_metrics_epoch.json", "w")) if value > best_value: print("New best {} !".format(self.metric_to_maximize), value) # best_net = copy.deepcopy(self.net) metrics_final = { metric: metrics_epoch[metric] for metric in self.metrics.keys() } best_value = 
value counter_patience = 0 if self.save_folder: self.save_weights('best_net') json.dump(metrics_epoch, open(self.save_folder + "metrics_best_epoch.json", "w")) else: counter_patience += 1 if counter_patience > self.patience: break self.on_epoch_end() return metrics_final def save_weights(self, file_name): self.net.save(self.save_folder + file_name)
class Trainer(object):
    """Training, validation and visualisation driver for a surrogate model.

    The trainer wraps the train/validation datasets in ``DataLoader``s, builds
    an optimizer (plus a cyclic LR scheduler unless SWA is enabled) and exposes
    per-step routines that optionally log scalars/images to ``tflogger``.
    Batches are expected to unpack as ``(u, p, f, J, H, _)``.
    """

    def __init__(self, args, surrogate, train_data, val_data, tflogger=None, pde=None):
        """Store collaborators, build the data loaders and the optimizer.

        Args:
            args: Namespace of hyper-parameters read throughout the class
                (``batch_size``, ``optimizer``, ``lr``, ``wd``, ``swa``, ...).
            surrogate: Model wrapper; the methods below use its ``net``,
                ``fsm``, ``scaler``, ``f``, ``f_J`` and ``f_J_Hvp`` members.
            train_data: Training dataset handed to a shuffling ``DataLoader``.
            val_data: Validation dataset handed to a non-shuffling loader.
            tflogger: Optional logger exposing ``log_scalar``/``log_images``.
            pde: Stored on the instance; not used by this class itself.
        """
        self.args = args
        self.pde = pde
        self.surrogate = surrogate
        self.tflogger = tflogger
        self.train_data = train_data
        self.val_data = val_data
        # NOTE: init_transformations() (permutation augmentation) is
        # intentionally not invoked here.
        self.train_loader = DataLoader(self.train_data,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       pin_memory=True)
        self.val_loader = DataLoader(self.val_data,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     pin_memory=True)
        self.init_optimizer()
        self.train_f_std = _cuda(torch.Tensor([[1.0]]))
        self.train_J_std = _cuda(torch.Tensor([[1.0]]))

    def init_transformations(self):
        """Augment ``train_data`` with index-permuted copies of every example.

        Builds four cyclic shifts of the vector indices (0, d/4, d/2, 3d/4),
        maps them through the vec<->ring permutation of ``surrogate.fsm`` and
        feeds a permuted copy of each stored example (u, J and both Hessian
        axes permuted; p and f copied unchanged) back into ``train_data``.
        """
        d = self.surrogate.fsm.vector_dim
        v2r = self.surrogate.fsm.vec_to_ring_map.cpu()
        v2r_perm = torch.argmax(v2r, dim=1).cpu().numpy()
        r2v_perm = torch.argmax(v2r, dim=0).cpu().numpy()

        # rotated = R * original
        p1 = np.array([i for i in range(d)])
        p2 = np.mod(p1 + int(d / 4), d)  # rotate 90
        p3 = np.mod(p2 + int(d / 4), d)  # rotate 180
        p4 = np.mod(p3 + int(d / 4), d)  # rotate 270

        # flip about 0th loc which is two points in 2d
        # TODO: make this dimension-agnostic
        # NOTE: currently unused; flips are not part of `ps` below.
        def flip(p):
            p = p.reshape(-1, 2)
            p = np.concatenate(([p[0]], p[1:][::-1]), axis=0)
            return p.reshape(-1)

        ps = [p1, p2, p3, p4]
        self.perms = [v2r_perm[p[r2v_perm]].tolist() for p in ps]

        self.train_data.memory_size = self.train_data.memory_size * 8
        # Snapshot the originals first: feed() below adds to train_data.
        us_orig = torch.stack([td[0] for td in self.train_data.data])
        Js_orig = torch.stack([td[3] for td in self.train_data.data])
        Hs_orig = torch.stack([td[4] for td in self.train_data.data])
        for pn, perm in enumerate(self.perms):
            print("proc perm ", pn)
            print("perm: ", perm)
            us = us_orig[:, perm]
            Js = Js_orig[:, perm]
            Hs = Hs_orig[:, perm, :]
            Hs = Hs[:, :, perm]
            for i in range(len(us)):
                self.train_data.feed(
                    Example(
                        us[i].clone().detach(),
                        self.train_data.data[i][1].clone().detach(),
                        self.train_data.data[i][2].clone().detach(),
                        Js[i].clone().detach(),
                        Hs[i].clone().detach(),
                    ))

    def init_optimizer(self):
        """Create ``self.optimizer`` and, when SWA is off, ``self.scheduler``.

        Sets ``self.optimizer`` to ``None`` when the surrogate exposes no
        ``parameters`` attribute. With ``args.swa`` the base optimizer is
        wrapped in torchcontrib's ``SWA`` and no scheduler is created;
        otherwise a triangular ``CyclicLR`` spanning one pass over the
        training loader is attached.

        Raises:
            ValueError: If ``args.optimizer`` is not one of ``adam``,
                ``amsgrad``, ``sgd`` or ``radam``.
        """
        if not hasattr(self.surrogate, "parameters"):
            self.optimizer = None
            return

        name = self.args.optimizer
        trainable = (p for p in self.surrogate.parameters() if p.requires_grad)
        if name in ("adam", "amsgrad"):
            self.optimizer = torch.optim.AdamW(
                trainable,
                self.args.lr,
                weight_decay=self.args.wd,
                amsgrad=(name == "amsgrad"),
            )
        elif name == "sgd":
            self.optimizer = torch.optim.SGD(
                trainable,
                self.args.lr,
                momentum=0.9,
                weight_decay=self.args.wd,
            )
        elif name == "radam":
            from ..util.radam import RAdam

            self.optimizer = RAdam(
                trainable,
                self.args.lr,
                weight_decay=self.args.wd,
                betas=ast.literal_eval(self.args.adam_betas),
            )
        else:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError("Unknown optimizer")

        steps_per_epoch = len(self.train_loader)
        if self.args.swa:
            from torchcontrib.optim import SWA

            self.optimizer = SWA(
                self.optimizer,
                swa_start=self.args.swa_start * steps_per_epoch,
                swa_freq=steps_per_epoch,
                swa_lr=self.args.lr,
            )
        else:
            self.scheduler = torch.optim.lr_scheduler.CyclicLR(
                self.optimizer,
                base_lr=self.args.lr * 1e-3,
                max_lr=3 * self.args.lr,
                step_size_up=int(math.ceil(steps_per_epoch / 2)),
                step_size_down=int(math.floor(steps_per_epoch / 2)),
                mode="triangular",
                scale_fn=None,
                scale_mode="cycle",
                cycle_momentum=False,
                base_momentum=0.8,
                max_momentum=0.9,
                last_epoch=-1,
            )

    def _pin_first_ring_entry(self, u):
        """Return ``u`` with ring entry 0 zeroed (poisson preprocessing)."""
        ring = self.surrogate.fsm.to_ring(u)
        ring[:, 0] = 0.0
        return self.surrogate.fsm.to_torch(ring)

    def _predict(self, u, p, f, J, H):
        """Run the surrogate and build the matching Hessian-vector target.

        When ``args.poisson`` is false, rows whose target ``f`` is negative
        are zeroed in place in the predictions AND in ``J``/``f``.

        Returns:
            Tuple ``(fhat, Jhat, Hvphat, Hvp)``.
        """
        if self.args.hess:
            vectors = torch.randn(*J.size()).to(J.device)
            fhat, Jhat, Hvphat = self.surrogate.f_J_Hvp(u, p, vectors=vectors)
            Hvp = (vectors.view(*J.size(), 1) * H).sum(dim=1)
        else:
            fhat, Jhat = self.surrogate.f_J(u, p)
            Hvphat = torch.zeros_like(Jhat)
            Hvp = torch.zeros_like(Jhat)
        if not self.args.poisson:
            # Compute the mask once, before f itself is zeroed.
            invalid = f.view(-1) < 0
            for tensor in (fhat, Jhat, Hvphat, Hvp, J, f):
                tensor[invalid] *= 0.0
        return fhat, Jhat, Hvphat, Hvp

    def cd_step(self, step, batch):
        """Do a single step of CD training.

        Draws negative samples by running ``args.cd_sgld_steps`` noisy gradient
        updates on ``u`` and minimises ``mean(log f(u)) - mean(log f(u_sgld))``.
        Log stats to tensorboard.
        """
        self.surrogate.net.train()
        if self.optimizer:
            self.optimizer.zero_grad()
        u, p, f, J, H, _ = batch
        u, p, f, J, H = _cuda(u), _cuda(p), _cuda(f), _cuda(J), _cuda(H)
        u_sgld = torch.autograd.Variable(u, requires_grad=True)
        lam, eps, TEMP = (
            self.args.cd_sgld_lambda,
            self.args.cd_sgld_eps,
            self.args.cd_sgld_temp,
        )
        for i in range(self.args.cd_sgld_steps):
            # Temperature starts at TEMP * steps and decays to TEMP.
            temp = TEMP * self.args.cd_sgld_steps / (i + 1)
            u_sgld = (u_sgld - 0.5 * lam * torch.autograd.grad(
                torch.log(self.surrogate.f(u_sgld, p) + eps).sum() / temp,
                u_sgld)[0].clamp(-0.1, +0.1) +
                      lam * _cuda(torch.randn(*u.size())))
        eplus = torch.log(self.surrogate.f(u.detach(), p) + eps).mean()
        eminus = torch.log(self.surrogate.f(u_sgld.detach(), p) + eps).mean()
        cd_loss = eplus - eminus
        # Guard the whole update like train_step does; previously only
        # zero_grad() was guarded and step() crashed without an optimizer.
        if self.optimizer:
            (cd_loss * self.args.cd_weight).backward()
            if self.args.clip_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(self.surrogate.net.parameters(),
                                               self.args.clip_grad_norm)
            self.optimizer.step()
        if self.tflogger is not None:
            self.tflogger.log_scalar("cd_E+_mean", eplus.item(), step)
            self.tflogger.log_scalar("cd_E-_mean", eminus.item(), step)
            self.tflogger.log_scalar("cd_loss", cd_loss.item(), step)

    def train_step(self, step, batch):
        """Do a single step of Sobolev training.

        Log stats to tensorboard.

        Returns:
            Tuple of floats ``(f_loss, f_pce, J_loss, J_sim, H_loss, H_sim,
            total_loss)``.
        """
        self.surrogate.net.train()
        if self.optimizer:
            self.optimizer.zero_grad()
        u, p, f, J, H, _ = batch

        with Timer() as timer:
            u, p, f, J, H = _cuda(u), _cuda(p), _cuda(f), _cuda(J), _cuda(H)
            if self.args.poisson:
                u = self._pin_first_ring_entry(u)
        if self.tflogger is not None:
            self.tflogger.log_scalar("batch_cuda_time", timer.interval, step)

        with Timer() as timer:
            fhat, Jhat, Hvphat, Hvp = self._predict(u, p, f, J, H)
        if self.tflogger is not None:
            self.tflogger.log_scalar("batch_forward_time", timer.interval, step)

        with Timer() as timer:
            f_loss, f_pce, J_loss, J_sim, H_loss, H_sim, total_loss = self.stats(
                step, u, f, J, Hvp, fhat, Jhat, Hvphat)
        if self.tflogger is not None:
            self.tflogger.log_scalar("stats_forward_time", timer.interval, step)

        # Debugging hook kept from the original: break on NaN/inf loss.
        if not np.isfinite(total_loss.data.cpu().numpy().sum()):
            pdb.set_trace()

        with Timer() as timer:
            if self.optimizer:
                total_loss.backward()
                if self.args.verbose:
                    # Per-parameter gradient norms (0 where grad is None).
                    log([
                        getattr(p.grad, "data",
                                torch.Tensor([0.0])).norm().cpu().numpy().sum()
                        for p in self.surrogate.net.parameters()
                    ])
                if self.args.clip_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(
                        self.surrogate.net.parameters(),
                        self.args.clip_grad_norm)
                self.optimizer.step()
                # No scheduler exists in SWA mode (see init_optimizer).
                if not self.args.swa:
                    self.scheduler.step()
                if self.args.verbose:
                    log("lr: {}".format(self.optimizer.param_groups[0]["lr"]))
        if self.tflogger is not None:
            self.tflogger.log_scalar("backward_time", timer.interval, step)

        return (
            f_loss.item(),
            f_pce.item(),
            J_loss.item(),
            J_sim.item(),
            H_loss.item(),
            H_sim.item(),
            total_loss.item(),
        )

    def val_step(self, step):
        """Do a single validation step. Log stats to tensorboard.

        Runs the surrogate over the whole validation loader, concatenates the
        per-batch tensors once and evaluates ``stats`` on the full set.

        Returns:
            List of floats in the order returned by ``stats``.

        Raises:
            ValueError: If the validation loader yields no batches.
        """
        self.surrogate.net.eval()
        # One bucket per tensor passed to stats(): u, f, J, Hvp, fhat, Jhat, Hvphat.
        buckets = [[] for _ in range(7)]
        for batch in self.val_loader:
            u, p, f, J, H, _ = batch
            u, p, f, J, H = _cuda(u), _cuda(p), _cuda(f), _cuda(J), _cuda(H)
            if self.args.poisson:
                u = self._pin_first_ring_entry(u)
            fhat, Jhat, Hvphat, Hvp = self._predict(u, p, f, J, H)
            for bucket, tensor in zip(buckets,
                                      (u, f, J, Hvp, fhat, Jhat, Hvphat)):
                bucket.append(tensor.data)
        if not buckets[0]:
            raise ValueError("val_step called with an empty validation loader")
        # Concatenate once instead of re-copying the accumulators per batch.
        u_, f_, J_, Hvp_, fhat_, Jhat_, Hvphat_ = (
            torch.cat(bucket, dim=0) for bucket in buckets)
        return list(r.item() for r in self.stats(
            step, u_, f_, J_, Hvp_, fhat_, Jhat_, Hvphat_, phase="val"))

    @staticmethod
    def _current_figure_to_png(fig):
        """Render the current pyplot figure to an in-memory PNG and close it."""
        fig.canvas.draw()
        buf = io.BytesIO()
        plt.savefig(buf, format="png")
        buf.seek(0)
        plt.close()
        return buf

    def visualize(self, step, batch, dataset_name):
        """Plot up to 16 examples of a batch and log the images.

        Two 4x4 figures are produced: boundary displacements with J / Jhat
        vectors, and the Jacobians drawn as boundaries. ``fsm.cuda`` is
        switched off while plotting and restored afterwards, even on error.
        """
        u, p, f, J, H, _ = batch
        u, p, f, J = u[:16], p[:16], f[:16], J[:16]
        if self.args.poisson:
            u = self._pin_first_ring_entry(u).cpu()
        fhat, Jhat = self.surrogate.f_J(u, p)
        assert len(u) <= 16
        assert len(u) == len(p)
        assert len(u) == len(f)
        assert len(u) == len(J)
        assert len(u) == len(fhat)
        assert len(u) == len(Jhat)
        u, f, J, fhat, Jhat = u.cpu(), f.cpu(), J.cpu(), fhat.cpu(), Jhat.cpu()

        fsm = self.surrogate.fsm
        cuda = fsm.cuda
        fsm.cuda = False
        try:
            fig, axes = plt.subplots(4, 4, figsize=(16, 16))
            axes = [ax for axs in axes for ax in axs]
            RCs = fsm.ring_coords.cpu().detach().numpy()
            rigid_remover = RigidRemover(fsm)
            for i in range(len(u)):
                ax = axes[i]
                locs = RCs + fsm.to_ring(u[i]).detach().numpy()
                plot_boundary(
                    lambda x: (0, 0),
                    1000,
                    label="reference, f={:.3e}".format(f[i].item()),
                    ax=ax,
                    color="k",
                )
                plot_boundary(
                    fsm.get_query_fn(u[i]),
                    1000,
                    ax=ax,
                    label="ub, fhat={:.3e}".format(fhat[i].item()),
                    linestyle="-",
                    color="darkorange",
                )
                plot_boundary(
                    fsm.get_query_fn(
                        rigid_remover(u[i].unsqueeze(0)).squeeze(0)),
                    1000,
                    ax=ax,
                    label="rigid removed",
                    linestyle="--",
                    color="blue",
                )
                if J is not None and Jhat is not None:
                    J_ = fsm.to_ring(J[i])
                    Jhat_ = fsm.to_ring(Jhat[i])
                    # Shared scale so J, Jhat and the residual are comparable.
                    normalizer = np.mean(
                        np.nan_to_num([
                            J_.norm(dim=1).detach().numpy(),
                            Jhat_.norm(dim=1).detach().numpy(),
                        ]))
                    plot_vectors(
                        locs,
                        J_.detach().numpy(),
                        ax=ax,
                        label="J",
                        color="darkgreen",
                        normalizer=normalizer,
                        scale=1.0,
                    )
                    plot_vectors(
                        locs,
                        Jhat_.detach().numpy(),
                        ax=ax,
                        color="darkorchid",
                        label="Jhat",
                        normalizer=normalizer,
                        scale=1.0,
                    )
                    plot_vectors(
                        locs,
                        (J_ - Jhat_).detach().numpy(),
                        ax=ax,
                        color="red",
                        label="residual J-Jhat",
                        normalizer=normalizer,
                        scale=1.0,
                    )
                ax.legend()
            buf = self._current_figure_to_png(fig)
            if self.tflogger is not None:
                self.tflogger.log_images(
                    "{} displacements".format(dataset_name), [buf], step)

            if J is not None and Jhat is not None:
                fig, axes = plt.subplots(4, 4, figsize=(16, 16))
                axes = [ax for axs in axes for ax in axs]
                for i in range(len(J)):
                    ax = axes[i]
                    J_ = fsm.to_ring(J[i])
                    Jhat_ = fsm.to_ring(Jhat[i])
                    normalizer = 10 * np.mean(
                        np.nan_to_num([
                            J_.norm(dim=1).detach().numpy(),
                            Jhat_.norm(dim=1).detach().numpy(),
                        ]))
                    plot_boundary(lambda x: (0, 0),
                                  1000,
                                  label="reference",
                                  ax=ax)
                    plot_boundary(
                        fsm.get_query_fn(J[i]),
                        100,
                        ax=ax,
                        label="true_J, f={:.3e}".format(f[i].item()),
                        linestyle="--",
                        normalizer=normalizer,
                    )
                    plot_boundary(
                        fsm.get_query_fn(Jhat[i]),
                        100,
                        ax=ax,
                        label="surrogate_J, fhat={:.3e}".format(
                            fhat[i].item()),
                        linestyle="-.",
                        normalizer=normalizer,
                    )
                    ax.legend()
                buf = self._current_figure_to_png(fig)
                if self.tflogger is not None:
                    self.tflogger.log_images(
                        "{} Jacobians".format(dataset_name), [buf], step)
        finally:
            # Always restore the flag, even if plotting raised.
            fsm.cuda = cuda

    def stats(self, step, u, f, J, Hvp, fhat, Jhat, Hvphat, phase="train"):
        """Take ground truth and predictions.

        Log stats and return loss.

        Returns:
            Tuple of tensors ``(f_loss, f_pce, J_loss, J_sim, H_loss, H_sim,
            total_loss)``, where ``total_loss = f_loss + J_weight * J_loss +
            H_weight * H_loss``.
        """
        mse = torch.nn.functional.mse_loss
        f_criterion = torch.nn.functional.l1_loss if self.args.l1_loss else mse
        f_loss = f_criterion(
            self.surrogate.scaler.scale(f + EPS, u),
            self.surrogate.scaler.scale(fhat + EPS, u),
        )

        # Computed once and reused for both the loss and the reported metric.
        J_sim = similarity(J, Jhat)
        H_sim = similarity(Hvp, Hvphat)
        if self.args.angle_magnitude:
            # Log-magnitude error plus (1 - similarity) as the angular term.
            J_loss = (self.args.mag_weight * mse(
                torch.log(J.norm(dim=1) + EPS),
                torch.log(Jhat.norm(dim=1) + EPS)) + 1.0 - J_sim)
            H_loss = (self.args.mag_weight * mse(
                torch.log(Hvp.norm(dim=1) + EPS),
                torch.log(Hvphat.norm(dim=1) + EPS),
            ) + 1.0 - H_sim)
        else:
            J_loss = mse(J, Jhat)
            H_loss = mse(Hvp, Hvphat)

        total_loss = (f_loss + self.args.J_weight * J_loss +
                      self.args.H_weight * H_loss)

        f_pce = error_percent(f, fhat)
        J_pce = error_percent(J, Jhat)
        H_pce = error_percent(Hvp, Hvphat)

        if self.tflogger is not None:
            log_scalar = self.tflogger.log_scalar
            log_scalar("train_set_size", len(self.train_data.data), step)
            log_scalar("val_set_size", len(self.val_data.data), step)
            if hasattr(self.surrogate, "net") and hasattr(
                    self.surrogate.net, "parameters"):
                log_scalar(
                    "param_norm_sum",
                    sum([
                        p.norm().sum().item()
                        for p in self.surrogate.net.parameters()
                    ]),
                    step,
                )
            log_scalar("total_loss_" + phase, total_loss.item(), step)
            log_scalar("f_loss_" + phase, f_loss.item(), step)
            log_scalar("f_pce_" + phase, f_pce.item(), step)
            log_scalar("f_mean_" + phase, f.mean().item(), step)
            log_scalar("f_std_" + phase, f.std().item(), step)
            log_scalar("fhat_mean_" + phase, fhat.mean().item(), step)
            log_scalar("fhat_std_" + phase, fhat.std().item(), step)
            log_scalar("J_loss_" + phase, J_loss.item(), step)
            log_scalar("J_pce_" + phase, J_pce.item(), step)
            log_scalar("J_sim_" + phase, J_sim.item(), step)
            log_scalar("H_loss_" + phase, H_loss.item(), step)
            log_scalar("H_pce_" + phase, H_pce.item(), step)
            log_scalar("H_sim_" + phase, H_sim.item(), step)
            log_scalar("J_mean_" + phase, J.mean().item(), step)
            log_scalar("J_std_mean_" + phase, J.std(dim=1).mean().item(), step)
            log_scalar("Jhat_mean_" + phase, Jhat.mean().item(), step)
            log_scalar("Jhat_std_mean_" + phase,
                       Jhat.std(dim=1).mean().item(), step)

        return (f_loss, f_pce, J_loss, J_sim, H_loss, H_sim, total_loss)
def fit(
    model,
    train_dataset,
    val_dataset,
    optimizer_name="adam",
    samples_per_player=0,
    epochs=50,
    batch_size=32,
    val_bs=32,
    warmup_prop=0.1,
    lr=1e-3,
    acc_steps=1,
    swa_first_epoch=50,
    num_classes_aux=0,
    aux_mode="sigmoid",
    verbose=1,
    first_epoch_eval=0,
    device="cuda",
):
    """
    Fitting function for the classification task.

    Args:
        model (torch model): Model to train.
        train_dataset (torch dataset): Dataset to train with.
        val_dataset (torch dataset): Dataset to validate with.
        optimizer_name (str, optional): Optimizer name. Defaults to 'adam'.
        samples_per_player (int, optional): Number of images to use per player. Defaults to 0.
        epochs (int, optional): Number of epochs. Defaults to 50.
        batch_size (int, optional): Training batch size. Defaults to 32.
        val_bs (int, optional): Validation batch size. Defaults to 32.
        warmup_prop (float, optional): Warmup proportion. Defaults to 0.1.
        lr (float, optional): Learning rate. Defaults to 1e-3.
        acc_steps (int, optional): Accumulation steps. Defaults to 1.
            NOTE: accepted for API compatibility; the loop below steps the
            optimizer on every batch and does not read this value.
        swa_first_epoch (int, optional): Epoch to start applying SWA from. Defaults to 50.
        num_classes_aux (int, optional): Number of auxiliary classes. Defaults to 0.
        aux_mode (str, optional): Mode for auxiliary classification. Defaults to 'sigmoid'.
        verbose (int, optional): Period (in epochs) to display logs at.
            0 disables logging. Defaults to 1.
        first_epoch_eval (int, optional): Epoch to start evaluating at. Defaults to 0.
        device (str, optional): Device for torch. Defaults to "cuda".

    Returns:
        numpy array [len(val_dataset)]: Last predictions on the validation data.
        numpy array [len(val_dataset) x num_classes_aux]: Last aux predictions on the val data.
    """
    optimizer = define_optimizer(optimizer_name, model.parameters(), lr=lr)
    if swa_first_epoch <= epochs:
        # SWA is driven manually below (update_swa / swap_swa_sgd).
        optimizer = SWA(optimizer)

    sigmoid_aux = aux_mode == "sigmoid"
    loss_fct = nn.BCEWithLogitsLoss()
    loss_fct_aux = nn.BCEWithLogitsLoss() if sigmoid_aux else nn.CrossEntropyLoss()
    aux_loss_weight = 1 if num_classes_aux else 0

    if samples_per_player:
        sampler = PlayerSampler(
            RandomSampler(train_dataset),
            train_dataset.players,
            batch_size=batch_size,
            drop_last=True,
            samples_per_player=samples_per_player,
        )
        train_loader = DataLoader(
            train_dataset,
            batch_sampler=sampler,
            num_workers=NUM_WORKERS,
            pin_memory=True,
        )
        print(
            f"Using {len(train_loader)} out of {len(train_dataset) // batch_size} "
            f"batches by limiting to {samples_per_player} samples per player.\n"
        )
    else:
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=NUM_WORKERS,
            pin_memory=True,
        )

    val_loader = DataLoader(
        val_dataset,
        batch_size=val_bs,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )

    num_training_steps = int(epochs * len(train_loader))
    num_warmup_steps = int(warmup_prop * num_training_steps)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps,
                                                num_training_steps)

    # Defined up front so the function still returns cleanly when epochs == 0.
    preds = np.empty(0)
    preds_aux = np.empty((0, num_classes_aux))
    y_pred = None

    for epoch in range(epochs):
        model.train()
        start_time = time.time()
        optimizer.zero_grad()
        avg_loss = 0

        # The previous epoch ended with the SWA weights swapped in for
        # evaluation; swap the SGD weights back before training resumes.
        if epoch + 1 > swa_first_epoch:
            optimizer.swap_swa_sgd()

        for batch in train_loader:
            images = batch[0].to(device)
            y_batch = batch[1].to(device).view(-1).float()
            y_batch_aux = batch[2].to(device).float()
            if not sigmoid_aux:
                y_batch_aux = y_batch_aux.long()

            y_pred, y_pred_aux = model(images)

            loss = loss_fct(y_pred.view(-1), y_batch)
            if aux_loss_weight:
                loss += aux_loss_weight * loss_fct_aux(y_pred_aux, y_batch_aux)

            loss.backward()
            avg_loss += loss.item() / len(train_loader)

            optimizer.step()
            scheduler.step()
            for param in model.parameters():
                param.grad = None

        if epoch + 1 >= swa_first_epoch:
            optimizer.update_swa()
            optimizer.swap_swa_sgd()

        preds = np.empty(0)
        preds_aux = np.empty((0, num_classes_aux))
        model.eval()
        avg_val_loss, auc, scores_aux = 0., 0., 0.

        # The last epoch is always evaluated so that predictions are returned.
        evaluated = epoch + 1 >= first_epoch_eval or epoch + 1 == epochs
        if evaluated:
            with torch.no_grad():
                for batch in val_loader:
                    images = batch[0].to(device)
                    y_batch = batch[1].to(device).view(-1).float()
                    y_batch_aux = batch[2].to(device).float()
                    if not sigmoid_aux:
                        y_batch_aux = y_batch_aux.long()

                    y_pred, y_pred_aux = model(images)

                    loss = loss_fct(y_pred.detach().view(-1), y_batch)
                    if aux_loss_weight:
                        loss += aux_loss_weight * loss_fct_aux(
                            y_pred_aux.detach(), y_batch_aux)
                    avg_val_loss += loss.item() / len(val_loader)

                    y_pred = torch.sigmoid(y_pred).view(-1)
                    preds = np.concatenate(
                        [preds, y_pred.detach().cpu().numpy()])

                    if num_classes_aux:
                        y_pred_aux = (y_pred_aux.sigmoid() if sigmoid_aux
                                      else y_pred_aux.softmax(-1))
                        preds_aux = np.concatenate(
                            [preds_aux, y_pred_aux.detach().cpu().numpy()])

            auc = roc_auc_score(val_dataset.labels, preds)

            if num_classes_aux:
                if sigmoid_aux:
                    scores_aux = np.round(
                        [
                            roc_auc_score(val_dataset.aux_labels[:, i],
                                          preds_aux[:, i])
                            for i in range(num_classes_aux)
                        ],
                        3,
                    ).tolist()
                else:
                    # One-vs-rest AUC per auxiliary class.
                    scores_aux = np.round(
                        [
                            roc_auc_score(
                                (val_dataset.aux_labels == i).astype(int),
                                preds_aux[:, i])
                            for i in range(num_classes_aux)
                        ],
                        3,
                    ).tolist()
            else:
                scores_aux = 0

        elapsed_time = time.time() - start_time
        # `verbose and ...` avoids a ZeroDivisionError when verbose == 0.
        if verbose and (epoch + 1) % verbose == 0:
            elapsed_time = elapsed_time * verbose
            current_lr = scheduler.get_last_lr()[0]
            print(
                f"Epoch {epoch + 1:02d}/{epochs:02d} \t lr={current_lr:.1e}\t t={elapsed_time:.0f}s \t"
                f"loss={avg_loss:.3f}",
                end="\t",
            )
            # Report validation metrics whenever they were actually computed,
            # including the forced evaluation on the final epoch.
            if evaluated:
                print(
                    f"val_loss={avg_val_loss:.3f} \t auc={auc:.3f}\t aucs_aux={scores_aux}"
                )
            else:
                print("")

    del val_loader, train_loader, y_pred
    torch.cuda.empty_cache()

    return preds, preds_aux
def train_model(model, criterion, optimizer, lr_scheduler, arc_model=None):
    """Train ``model`` for ``opt.max_epoch`` epochs and checkpoint it.

    After each epoch the model is scored with ``val_model``. The best-scoring
    weights are written to ``<checkpoints_dir>/<backbone>/<backbone>_best.pth``
    and periodic snapshots to ``<backbone>_<epoch>.pth``.

    Args:
        model: Network to optimise; inputs and labels are moved to CUDA.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        lr_scheduler: Scheduler factory. Only ``cos_lr_scheduler`` (stepped
            per batch with a fractional epoch) and ``exp_lr_scheduler``
            (stepped per epoch) are acted upon.
        arc_model: Unused; kept for call-site compatibility.
    """
    train_dataset = gaodeDataset(opt.trainValConcat_dir, opt.train_list,
                                 phase='train', input_size=opt.input_size)
    trainloader = DataLoader(train_dataset,
                             batch_size=opt.train_batch_size,
                             shuffle=True,
                             num_workers=opt.num_workers)

    total_iters = len(trainloader)
    logger.info('total_iters:{}'.format(total_iters))
    model_name = opt.backbone
    since = time.time()
    best_score = 0.0
    logger.info('start training...')

    if lr_scheduler is cos_lr_scheduler:
        return_lr_scheduler = lr_scheduler(optimizer, 4)
    if lr_scheduler is exp_lr_scheduler:
        return_lr_scheduler = lr_scheduler(optimizer)

    save_dir = os.path.join(opt.checkpoints_dir, model_name)
    best_model_out_path = os.path.join(save_dir, '{}_best.pth'.format(model_name))

    for epoch in range(1, opt.max_epoch + 1):
        # Re-enter train mode every epoch: validation at the end of the
        # previous epoch may have left the model in eval mode (val_model is
        # defined elsewhere - confirm).
        model.train(True)
        begin_time = time.time()
        logger.info('Epoch {}/{}'.format(epoch, opt.max_epoch))
        logger.info('-' * 10)
        running_corrects_linear = 0
        seen_samples = 0
        count = 0
        for i, data in enumerate(trainloader):
            count += 1
            inputs, labels = data
            labels = labels.type(torch.LongTensor)
            inputs, labels = inputs.cuda(), labels.cuda()

            out_linear = model(inputs)
            _, linear_preds = torch.max(out_linear.data, 1)
            loss = criterion(out_linear, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if lr_scheduler is cos_lr_scheduler:
                return_lr_scheduler.step(epoch + i / total_iters)

            # Log periodically and on a short (final) batch.
            if i % opt.print_interval == 0 or out_linear.size()[0] < opt.train_batch_size:
                spend_time = time.time() - begin_time
                logger.info(
                    ' Epoch:{}({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min:'.format(
                        epoch, count, total_iters,
                        loss.item(), optimizer.param_groups[-1]['lr'],
                        spend_time / count * total_iters // 60 - spend_time // 60))

            running_corrects_linear += torch.sum(linear_preds == labels.data)
            seen_samples += labels.size(0)

        if lr_scheduler is exp_lr_scheduler:
            return_lr_scheduler.step()

        weight_score = val_model(model, criterion)

        # Divide by the samples actually seen: the last batch may be smaller
        # than train_batch_size, and the loader may even be empty.
        epoch_acc_linear = float(running_corrects_linear) / max(seen_samples, 1)
        logger.info('Epoch:[{}/{}] train_acc={:.3f} '.format(epoch, opt.max_epoch,
                                                            epoch_acc_linear))

        os.makedirs(save_dir, exist_ok=True)
        # save the best model
        if weight_score > best_score:
            best_score = weight_score
            torch.save(model.state_dict(), best_model_out_path)
        # save based on epoch interval
        if epoch % opt.save_interval == 0 and epoch > opt.min_save_epoch:
            model_out_path = os.path.join(
                save_dir, '{}_'.format(model_name) + str(epoch) + '.pth')
            torch.save(model.state_dict(), model_out_path)

    logger.info('Best WeightF1: {:.3f}'.format(best_score))
    time_elapsed = time.time() - since
    logger.info('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
def main(args):
    """Render an animation comparing IPS-weighted SGD with CounterSample.

    A 2-D linear regression toy problem is optimised with SGD wrapped in SWA
    (averaging from step 0, every step), once per IPS learning rate and once
    with sampling proportional to the inverse propensities. The averaged
    weight trajectories are drawn over the contour of the IPS-weighted loss
    and written to ``args.out`` through ffmpeg.

    Args:
        args: Namespace providing ``out``, the output video path.
    """
    # Parameters for toy data and experiments
    plt.style.use('dark_background')
    rng_seed = 4200
    np.random.seed(rng_seed)
    wstar = torch.tensor([[0.973, 1.144]], dtype=torch.float)
    xs = torch.tensor(np.random.randn(50, 2), dtype=torch.float)
    labels = torch.mm(xs, wstar.T)
    p = torch.tensor(np.random.uniform(0.05, 1.0, xs.shape[0]), dtype=torch.float)
    ips = 1.0 / p
    n_iters = 500
    legend = {}
    data = {}
    colors = {}
    interpsize = 30

    # Figure
    fig = plt.figure(figsize=(8.4, 5.4), dpi=150)
    ax = fig.add_subplot(111)
    ax.set_xlim([-0.3, 1.8])
    ax.set_ylim([-0.3, 1.8])

    # The loss function we want to optimize
    def loss_fn(out, y, mult):
        l2loss = (out - y) ** 2.0
        return torch.mean(mult * l2loss)

    # True IPS-weighted loss over all datapoints, used for plotting contour
    def f(x1, x2):
        w = torch.tensor([[x1], [x2]], dtype=torch.float)
        o = torch.mm(xs, w)
        return float(loss_fn(o, torch.mm(xs, wstar.reshape((2, 1))), ips))

    def run_trajectory(label, lr, seed, draw_index, loss_weight):
        """Optimise a fresh zero-initialised model and record its SWA path.

        Stores a ``(3, n_iters * interpsize + 1)`` array of interpolated
        ``(w1, w2, loss)`` columns under ``data[label]``.
        """
        model = torch.nn.Linear(2, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        optimizer = SWA(optimizer, swa_start=0, swa_freq=1, swa_lr=lr)
        with torch.no_grad():
            model.bias.zero_()
            model.weight.zero_()
        np.random.seed(seed)
        track = np.zeros((3, n_iters * interpsize + 1))
        data[label] = track
        for t in range(n_iters):
            i = draw_index()
            optimizer.zero_grad()
            loss = loss_fn(model(xs[i, :]), labels[i], loss_weight(i))
            loss.backward()
            optimizer.step()

            # Swap the averaged weights in just to read them, then swap the
            # raw SGD iterate back so optimisation continues from it.
            optimizer.swap_swa_sgd()
            x, y = model.weight.data.numpy()[0]
            optimizer.swap_swa_sgd()

            # Interpolate from the previous recorded point for a smooth line.
            old_x, old_y, _ = track[:, t * interpsize]
            xr = np.linspace(old_x, x, num=interpsize)
            yr = np.linspace(old_y, y, num=interpsize)
            first = 1 + t * interpsize
            for k in range(interpsize):
                track[:, first + k] = np.array([xr[k], yr[k], f(xr[k], yr[k])])

    # Plot contour
    true_x1 = np.linspace(float(wstar[0, 0]) - 1.5, float(wstar[0, 0]) + 0.8)
    true_x2 = np.linspace(float(wstar[0, 1]) - 1.5, float(wstar[0, 1]) + 1.2)
    true_x1, true_x2 = np.meshgrid(true_x1, true_x2)
    true_y = np.array([
        [f(true_x1[i1, i2], true_x2[i1, i2]) for i2 in range(len(true_x2))]
        for i1 in range(len(true_x1))
    ])
    ax.contour(true_x1, true_x2, true_y, levels=10, colors='white', alpha=0.45)
    plt.plot(0.0, 0.0, marker='o', markersize=6, color="white")
    plt.plot(wstar[0, 0], wstar[0, 1], marker='*', markersize=6, color="white")

    # IPS-weighted approach
    for color_index, lr in enumerate([0.1, 0.01, 0.03]):
        if lr == 0.01:
            color = "violet"
        elif lr == 0.03:
            color = "orange"
        else:
            color = "lightgreen"
        label = f"IPS-SGD ($\\eta={lr}$)"
        colors[label] = color
        run_trajectory(
            label,
            lr,
            rng_seed + color_index + 1,
            lambda: np.random.randint(xs.shape[0]),
            lambda i: ips[i],
        )

    # Sample based approach
    lr = 10.0
    sample_probs = np.array(ips / torch.sum(ips))
    Mbar = float(np.mean(sample_probs))
    label = f"\\textsc{{CounterSample}} ($\\eta={lr}$)"
    run_trajectory(
        label,
        lr,
        rng_seed - 1,
        lambda: np.argwhere(np.random.multinomial(1, sample_probs) == 1.0)[0, 0],
        lambda i: Mbar,
    )
    colors[label] = "deepskyblue"

    # Print summary to quickly find performance at convergence
    for label in data.keys():
        print(f"{label}: {data[label][2, -1]}")

    # Create legend
    lines = {}
    for label in data.keys():
        line = data[label]
        # Placeholder geometry; update_lines() calls set_data on every frame.
        lines[label] = ax.plot(line[:, 0], line[:, 1], color=colors[label],
                               label=label, linewidth=2.0)
        legend[colors[label]] = label
    ax.set_xlabel('$w_1$')
    ax.set_ylabel('$w_2$')
    legend_lines = [
        lines[legend["deepskyblue"]][0],
        lines[legend["orange"]][0],
        lines[legend["violet"]][0],
        lines[legend["lightgreen"]][0]]
    legend_labels = [
        "\\textsc{CounterSample}",
        "IPS-SGD (best learning rate)",
        "IPS-SGD (learning rate too small)",
        "IPS-SGD (learning rate too large)"]
    legend_artist = ax.legend(legend_lines, legend_labels, loc='lower right')

    # Update function for animation
    n_frames = n_iters * 2
    n_data = n_iters * interpsize
    from math import floor

    def transform_num(num):
        # Quadratic easing from frame number to number of data columns shown.
        x = 3 * ((1.0 * num) / n_frames)
        y = (((x + 0.5)**2 - 0.5**2) / 12.0)
        return floor(y * n_data)

    def update_lines(num):
        print(f"frame {num:4d} / {n_frames:4d} [{num / n_frames * 100:.0f}%]", end="\r")
        num = transform_num(num)
        out = [legend_artist]
        for label in data.keys():
            line = lines[label][0]
            d = data[label]
            line.set_data(d[0:2, :num])
            out.append(line)
        return out

    # Write animation to file
    line_ani = animation.FuncAnimation(fig, update_lines, n_frames, interval=100,
                                       blit=False, repeat_delay=3000)
    writer = animation.FFMpegWriter(fps=60, codec='h264')  # for keynote
    line_ani.save(args.out, writer=writer)
    print("\033[K\n")
def main():
    """Train DeepLabV3-ResNet50 on crop segmentation with SWA averaging.

    The dataset and model are always built; training only runs when
    ``args.train`` is set. Training minimises the sum of a focal loss and a
    Lovasz-softmax loss, each evaluated on the main (``'out'``) and auxiliary
    (``'aux'``) heads with the auxiliary term weighted by 0.1.

    SWA is used in manual mode: the running average is updated every 5th
    iteration once ``i > 10`` within each epoch, and the averaged weights are
    swapped into the model after the last epoch before the final checkpoint is
    written.
    """
    assert torch.cuda.is_available()
    torch.backends.cudnn.benchmark = True
    model_fname = '../data/model_swa_8/deeplabv3_{0}_epoch%d.pth'.format(
        'crops')
    focal_loss = FocalLoss2d()
    train_dataset = CropSegmentation(train=True, crop_size=args.crop_size)
    model = torchvision.models.segmentation.deeplabv3_resnet50(
        pretrained=False, progress=True, num_classes=5, aux_loss=True)

    if not args.train:
        return

    model = nn.DataParallel(model).cuda()
    for param in model.parameters():
        param.requires_grad = True

    base_optimizer = optim.SGD(model.parameters(),
                               lr=config.lr,
                               momentum=0.9,
                               weight_decay=1e-4)
    # The scheduler is attached to the base optimizer. The original passed
    # `optimizer` here before that name was assigned, which raises
    # UnboundLocalError. SWA is constructed on top of the same optimizer
    # object, so the scheduled learning rate is what the wrapped SGD uses.
    scheduler = lr_scheduler.CosineAnnealingLR(base_optimizer,
                                               T_max=(args.epochs // 9) + 1)
    optimizer = SWA(base_optimizer)

    dataset_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=args.train,
        pin_memory=True,
        num_workers=args.workers)
    losses = AverageMeter()

    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {0}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print('=> loaded checkpoint {0} (epoch {1})'.format(
                args.resume, checkpoint['epoch']))
        else:
            print('=> no checkpoint found at {0}'.format(args.resume))

    for epoch in range(start_epoch, args.epochs):
        scheduler.step(epoch)
        model.train()
        for i, (inputs, target) in enumerate(dataset_loader):
            inputs = inputs.cuda()
            target = target.cuda()
            outputs = model(inputs)

            focal_total = (focal_loss(outputs['out'], target) +
                           0.1 * focal_loss(outputs['aux'], target))
            lovasz_total = (lovasz_softmax(outputs['out'], target) +
                            0.1 * lovasz_softmax(outputs['aux'], target))
            loss = focal_total + lovasz_total

            loss_value = loss.item()
            if np.isnan(loss_value) or np.isinf(loss_value):
                # Debugging aid kept from the original: drop into the
                # debugger when the loss diverges.
                pdb.set_trace()
            losses.update(loss_value, args.batch_size)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if i > 10 and i % 5 == 0:
                optimizer.update_swa()

            print('epoch: {0}\t'
                  'iter: {1}/{2}\t'
                  'loss: {loss.val:.4f} ({loss.ema:.4f})'.format(
                      epoch + 1, i + 1, len(dataset_loader), loss=losses))

        if epoch > 5:
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, model_fname % (epoch + 1))

    # Replace the model weights with the SWA running average and store them
    # under the fixed checkpoint index used by the original script.
    optimizer.swap_swa_sgd()
    torch.save(
        {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, model_fname % (665 + 1))
class Model():
    """Two-encoder question-answering model with optional SWA training.

    Holds a passage encoder (``Pencoder``), a question encoder (``Qencoder``)
    and an attention decoder, plus the optimizer, the loss and several
    analysis routines (attention permutation, adversarial attention,
    gradients, leave-one-token-out).
    """

    def __init__(self, configuration, pre_embed=None):
        """Build encoders, decoder, optimizer and SWA state from a config dict.

        Args:
            configuration: nested dict with ``'model'`` (``'encoder'``,
                ``'decoder'``) and ``'training'`` (``'bsize'``,
                ``'exp_dirname'``, ``'swa'``, optional ``'weight_decay'`` and
                ``'basepath'``) sections. It is deep-copied, so the caller's
                dict is not modified.
            pre_embed: value stored under the encoder config's ``'pre_embed'``
                key before the encoders are constructed.
        """
        configuration = deepcopy(configuration)
        # Stored before `pre_embed` is injected, so the saved config stays
        # free of the embedding payload.
        self.configuration = deepcopy(configuration)

        configuration['model']['encoder']['pre_embed'] = pre_embed
        encoder_copy = deepcopy(configuration['model']['encoder'])
        self.Pencoder = Encoder.from_params(
            Params(configuration['model']['encoder'])).to(device)
        self.Qencoder = Encoder.from_params(Params(encoder_copy)).to(device)

        configuration['model']['decoder'][
            'hidden_size'] = self.Pencoder.output_size
        self.decoder = AttnDecoderQA.from_params(
            Params(configuration['model']['decoder'])).to(device)

        self.bsize = configuration['training']['bsize']
        self.adversary_multi = AdversaryMulti(self.decoder)

        weight_decay = configuration['training'].get('weight_decay', 1e-5)
        self.params = (list(self.Pencoder.parameters()) +
                       list(self.Qencoder.parameters()) +
                       list(self.decoder.parameters()))
        self.optim = torch.optim.Adam(self.params,
                                      weight_decay=weight_decay,
                                      amsgrad=True)
        self.criterion = nn.CrossEntropyLoss()

        import time
        dirname = configuration['training']['exp_dirname']
        basepath = configuration['training'].get('basepath', 'outputs')
        self.time_str = time.ctime().replace(' ', '_')
        self.dirname = os.path.join(basepath, dirname, self.time_str)

        # swa_settings is indexed positionally below:
        #   [0] enable flag, [1] step threshold, [2] update period,
        #   [3] update mode (0 = always update, 2 = use max norm, else mean).
        self.swa_settings = configuration['training']['swa']
        # Always defined so SWA helpers never hit a missing attribute.
        self.running_norms = []
        if self.swa_settings[0]:
            self.swa_all_optim = SWA(self.optim)

    @classmethod
    def init_from_config(cls, dirname, **kwargs):
        """Rebuild a model from ``dirname/config.json`` and load its weights.

        ``kwargs`` override top-level keys of the loaded configuration.
        """
        # Context manager closes the handle; the original left it open.
        with open(dirname + '/config.json', 'r') as config_file:
            config = json.load(config_file)
        config.update(kwargs)
        obj = cls(config)
        obj.load_values(dirname)
        return obj

    def get_param_buffer_norms(self):
        """Return the max or mean L2 distance between weights and SWA buffers.

        Only the parameters at a fixed set of indices in the first param
        group are measured. Returns the maximum when ``swa_settings[3] == 2``
        and the mean otherwise.
        """
        # Make sure SWA buffers exist before they are read.
        for p in self.swa_all_optim.param_groups[0]['params']:
            param_state = self.swa_all_optim.state[p]
            if 'swa_buffer' not in param_state:
                self.swa_all_optim.update_swa()

        norms = []
        tracked = np.array(self.swa_all_optim.param_groups[0]['params'])[[
            1, 2, 5, 6, 10, 11, 14, 15, 18, 20, 24, 26
        ]]
        for p in tracked:
            param_state = self.swa_all_optim.state[p]
            buf = np.squeeze(param_state['swa_buffer'].cpu().numpy())
            cur_state = np.squeeze(p.data.cpu().numpy())
            norms.append(np.linalg.norm(buf - cur_state))

        if self.swa_settings[3] == 2:
            return np.max(norms)
        return np.mean(norms)

    def total_iter_num(self):
        """Return the SWA optimizer's step counter for the first param group."""
        return self.swa_all_optim.param_groups[0]['step_counter']

    def iter_for_swa_update(self, iter_num):
        """Return True when ``iter_num`` is past the threshold and on-period."""
        return iter_num > self.swa_settings[1] \
            and iter_num % self.swa_settings[2] == 0

    def check_and_update_swa(self):
        """Update the SWA average if the current step qualifies.

        In mode 0 the average is updated on every qualifying step. Otherwise
        it is updated only when the current weight-to-buffer distance exceeds
        the mean of the distances recorded since the last update.
        """
        if not self.iter_for_swa_update(self.total_iter_num()):
            return

        cur_step_diff_norm = self.get_param_buffer_norms()
        if self.swa_settings[3] == 0:
            self.swa_all_optim.update_swa()
            return

        if not self.running_norms:
            running_mean_norm = 0
        else:
            running_mean_norm = np.mean(self.running_norms)

        if cur_step_diff_norm > running_mean_norm:
            self.swa_all_optim.update_swa()
            self.running_norms = [cur_step_diff_norm]
        elif cur_step_diff_norm > 0:
            self.running_norms.append(cur_step_diff_norm)

    def train(self, train_data, train=True):
        """Run one pass over ``train_data`` and return the scaled mean loss.

        Args:
            train_data: object exposing ``P`` (documents), ``Q`` (questions),
                ``E`` (entity masks) and ``A`` (targets).
            train: when False the loss is computed but no optimizer step is
                taken.

        Returns:
            ``loss_total * bsize / N`` where ``N`` is the number of questions.
        """
        docs_in = train_data.P
        question_in = train_data.Q
        entity_masks_in = train_data.E
        target_in = train_data.A

        sorting_idx = get_sorting_index_with_noise_from_lengths(
            [len(x) for x in docs_in], noise_frac=0.1)
        docs = [docs_in[i] for i in sorting_idx]
        questions = [question_in[i] for i in sorting_idx]
        entity_masks = [entity_masks_in[i] for i in sorting_idx]
        target = [target_in[i] for i in sorting_idx]

        self.Pencoder.train()
        self.Qencoder.train()
        self.decoder.train()

        bsize = self.bsize
        N = len(questions)
        loss_total = 0

        batches = list(range(0, N, bsize))
        batches = shuffle(batches)

        for n in tqdm(batches):
            torch.cuda.empty_cache()
            batch_doc = BatchHolder(docs[n:n + bsize])
            batch_ques = BatchHolder(questions[n:n + bsize])
            batch_entity_masks = entity_masks[n:n + bsize]

            batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques)
            batch_data.entity_mask = torch.ByteTensor(
                np.array(batch_entity_masks)).to(device)

            self.Pencoder(batch_data.P)
            self.Qencoder(batch_data.Q)
            self.decoder(batch_data)

            batch_target = target[n:n + bsize]
            batch_target = torch.LongTensor(batch_target).to(device)

            ce_loss = self.criterion(batch_data.predict, batch_target)
            loss = ce_loss
            if hasattr(batch_data, 'reg_loss'):
                loss += batch_data.reg_loss

            if train:
                if self.swa_settings[0]:
                    self.check_and_update_swa()
                    self.swa_all_optim.zero_grad()
                    loss.backward()
                    self.swa_all_optim.step()
                else:
                    self.optim.zero_grad()
                    loss.backward()
                    self.optim.step()

            loss_total += float(loss.data.cpu().item())

        if self.swa_settings[0] and self.swa_all_optim.param_groups[0][
                'step_counter'] > self.swa_settings[1]:
            print("\nSWA swapping\n")
            self.swa_all_optim.swap_swa_sgd()
            self.running_norms = []

        return loss_total * bsize / N

    def evaluate(self, data):
        """Predict on ``data``; return ``(outputs, attns, scores)`` as lists.

        ``outputs`` are argmax class ids, ``scores`` the raw decoder outputs,
        and ``attns`` the attention arrays (empty when the decoder does not
        use attention).
        """
        docs = data.P
        questions = data.Q
        entity_masks = data.E

        # NOTE(review): the modules are put in train mode here, not eval
        # mode. Kept as-is because it may be deliberate - confirm.
        self.Pencoder.train()
        self.Qencoder.train()
        self.decoder.train()

        bsize = self.bsize
        N = len(questions)

        outputs = []
        attns = []
        scores = []
        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = BatchHolder(docs[n:n + bsize])
            batch_ques = BatchHolder(questions[n:n + bsize])
            batch_entity_masks = entity_masks[n:n + bsize]

            batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques)
            batch_data.entity_mask = torch.ByteTensor(
                np.array(batch_entity_masks)).to(device)

            self.Pencoder(batch_data.P)
            self.Qencoder(batch_data.Q)
            self.decoder(batch_data)

            prediction_scores = batch_data.predict.cpu().data.numpy()
            batch_data.predict = torch.argmax(batch_data.predict, dim=-1)
            if self.decoder.use_attention:
                attn = batch_data.attn
                attns.append(attn.cpu().data.numpy())

            predict = batch_data.predict.cpu().data.numpy()
            outputs.append(predict)
            scores.append(prediction_scores)

        outputs = [x for y in outputs for x in y]
        attns = [x for y in attns for x in y]
        scores = [x for y in scores for x in y]

        return outputs, attns, scores

    def save_values(self, use_dirname=None, save_model=True):
        """Write the source file, config and (optionally) weights to disk.

        Args:
            use_dirname: target directory; defaults to ``self.dirname``.
            save_model: when True, also save the three state dicts.

        Returns:
            The directory that was written to.
        """
        if use_dirname is not None:
            dirname = use_dirname
        else:
            dirname = self.dirname
        os.makedirs(dirname, exist_ok=True)
        shutil.copy2(file_name, dirname + '/')
        # Context manager guarantees the JSON is flushed and the handle
        # closed; the original never closed the file.
        with open(dirname + '/config.json', 'w') as config_file:
            json.dump(self.configuration, config_file)

        if save_model:
            torch.save(self.Pencoder.state_dict(), dirname + '/encP.th')
            torch.save(self.Qencoder.state_dict(), dirname + '/encQ.th')
            torch.save(self.decoder.state_dict(), dirname + '/dec.th')

        return dirname

    def load_values(self, dirname):
        """Load encoder and decoder state dicts saved by ``save_values``."""
        self.Pencoder.load_state_dict(torch.load(dirname + '/encP.th'))
        self.Qencoder.load_state_dict(torch.load(dirname + '/encQ.th'))
        self.decoder.load_state_dict(torch.load(dirname + '/dec.th'))

    def permute_attn(self, data, num_perm=100):
        """Re-run the decoder ``num_perm`` times with ``permute`` set.

        Returns:
            ``(permutations_predict, permutations_diff)``: per-example arrays
            holding, for each permutation, the argmax prediction and the
            output difference against the unpermuted prediction.
        """
        docs = data.P
        questions = data.Q
        entity_masks = data.E

        self.Pencoder.train()
        self.Qencoder.train()
        self.decoder.train()

        bsize = self.bsize
        N = len(questions)

        permutations_predict = []
        permutations_diff = []

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = BatchHolder(docs[n:n + bsize])
            batch_ques = BatchHolder(questions[n:n + bsize])
            batch_entity_masks = entity_masks[n:n + bsize]

            batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques)
            batch_data.entity_mask = torch.ByteTensor(
                np.array(batch_entity_masks)).to(device)

            self.Pencoder(batch_data.P)
            self.Qencoder(batch_data.Q)
            self.decoder(batch_data)

            predict_true = batch_data.predict.clone().detach()

            batch_perms_predict = np.zeros((batch_data.P.B, num_perm))
            batch_perms_diff = np.zeros((batch_data.P.B, num_perm))

            for i in range(num_perm):
                batch_data.permute = True
                self.decoder(batch_data)

                predict = torch.argmax(batch_data.predict, dim=-1)
                batch_perms_predict[:, i] = predict.cpu().data.numpy()

                predict_difference = self.adversary_multi.output_diff(
                    batch_data.predict, predict_true)
                batch_perms_diff[:, i] = predict_difference.squeeze(
                    -1).cpu().data.numpy()

            permutations_predict.append(batch_perms_predict)
            permutations_diff.append(batch_perms_diff)

        permutations_predict = [x for y in permutations_predict for x in y]
        permutations_diff = [x for y in permutations_diff for x in y]

        return permutations_predict, permutations_diff

    def adversarial_multi(self, data):
        """Run the multi-attention adversary over ``data``.

        Puts all modules in eval mode and turns off ``requires_grad`` on
        every model parameter; this method does not turn it back on.

        Returns:
            ``(outputs, attns, diffs)``: adversarial argmax predictions,
            adversarial attentions, and output differences against the
            original prediction.
        """
        docs = data.P
        questions = data.Q
        entity_masks = data.E

        self.Pencoder.eval()
        self.Qencoder.eval()
        self.decoder.eval()

        print(self.adversary_multi.K)

        self.params = (list(self.Pencoder.parameters()) +
                       list(self.Qencoder.parameters()) +
                       list(self.decoder.parameters()))
        for p in self.params:
            p.requires_grad = False

        bsize = self.bsize
        N = len(questions)
        batches = list(range(0, N, bsize))

        outputs, attns, diffs = [], [], []

        for n in tqdm(batches):
            torch.cuda.empty_cache()
            batch_doc = BatchHolder(docs[n:n + bsize])
            batch_ques = BatchHolder(questions[n:n + bsize])
            batch_entity_masks = entity_masks[n:n + bsize]

            batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques)
            batch_data.entity_mask = torch.ByteTensor(
                np.array(batch_entity_masks)).to(device)

            self.Pencoder(batch_data.P)
            self.Qencoder(batch_data.Q)
            self.decoder(batch_data)

            self.adversary_multi(batch_data)

            predict_volatile = torch.argmax(batch_data.predict_volatile,
                                            dim=-1)
            outputs.append(predict_volatile.cpu().data.numpy())

            attn = batch_data.attn_volatile
            attns.append(attn.cpu().data.numpy())

            predict_difference = self.adversary_multi.output_diff(
                batch_data.predict_volatile, batch_data.predict.unsqueeze(1))
            diffs.append(predict_difference.cpu().data.numpy())

        outputs = [x for y in outputs for x in y]
        attns = [x for y in attns for x in y]
        diffs = [x for y in diffs for x in y]

        return outputs, attns, diffs

    def gradient_mem(self, data):
        """Collect gradients of the top-class probability per example.

        Returns:
            dict with keys ``'XxE[X]'`` (gradient times the input embedding,
            summed over the last axis), ``'XxE'`` (gradient times the
            embedding matrix summed over its first axis) and ``'H'``
            (hidden-state gradient summed over the last axis).
        """
        docs = data.P
        questions = data.Q
        entity_masks = data.E

        self.Pencoder.train()
        self.Qencoder.train()
        self.decoder.train()

        bsize = self.bsize
        N = len(questions)

        grads = {'XxE': [], 'XxE[X]': [], 'H': []}

        for n in range(0, N, bsize):
            torch.cuda.empty_cache()
            batch_doc = BatchHolder(docs[n:n + bsize])
            batch_ques = BatchHolder(questions[n:n + bsize])
            batch_entity_masks = entity_masks[n:n + bsize]

            batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques)
            batch_data.entity_mask = torch.ByteTensor(
                np.array(batch_entity_masks)).to(device)

            batch_data.P.keep_grads = True
            batch_data.detach = True

            self.Pencoder(batch_data.P)
            self.Qencoder(batch_data.Q)
            self.decoder(batch_data)

            max_predict = torch.argmax(batch_data.predict, dim=-1)
            prob_predict = nn.Softmax(dim=-1)(batch_data.predict)

            max_class_prob = torch.gather(prob_predict, -1,
                                          max_predict.unsqueeze(-1))
            max_class_prob.sum().backward()

            g = batch_data.P.embedding.grad
            em = batch_data.P.embedding
            g1 = (g * em).sum(-1)
            grads['XxE[X]'].append(g1.cpu().data.numpy())

            g1 = (g * self.Pencoder.embedding.weight.sum(0)).sum(-1)
            grads['XxE'].append(g1.cpu().data.numpy())

            g1 = batch_data.P.hidden.grad.sum(-1)
            grads['H'].append(g1.cpu().data.numpy())

        for k in grads:
            grads[k] = [x for y in grads[k] for x in y]

        return grads

    def remove_and_run(self, data):
        """Measure the output change when each document token is removed.

        For every batch the document is re-encoded with position ``i``
        removed, for ``i`` from 1 to ``maxlen - 2``, and the difference
        against the full-document prediction is recorded.

        Returns:
            A list with one array per example; positions that were not
            removed keep the value 0.
        """
        docs = data.P
        questions = data.Q
        entity_masks = data.E

        self.Pencoder.train()
        self.Qencoder.train()
        self.decoder.train()

        bsize = self.bsize
        N = len(questions)
        output_diffs = []

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = BatchHolder(docs[n:n + bsize])
            batch_ques = BatchHolder(questions[n:n + bsize])
            batch_entity_masks = entity_masks[n:n + bsize]

            batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques)
            batch_data.entity_mask = torch.ByteTensor(
                np.array(batch_entity_masks)).to(device)

            self.Pencoder(batch_data.P)
            self.Qencoder(batch_data.Q)
            self.decoder(batch_data)

            po = np.zeros((batch_data.P.B, batch_data.P.maxlen))

            for i in range(1, batch_data.P.maxlen - 1):
                batch_doc = BatchHolder(docs[n:n + bsize])

                batch_doc.seq = torch.cat(
                    [batch_doc.seq[:, :i], batch_doc.seq[:, i + 1:]], dim=-1)
                batch_doc.lengths = batch_doc.lengths - 1
                batch_doc.masks = torch.cat(
                    [batch_doc.masks[:, :i], batch_doc.masks[:, i + 1:]],
                    dim=-1)

                batch_data_loop = BatchMultiHolder(P=batch_doc, Q=batch_ques)
                batch_data_loop.entity_mask = torch.ByteTensor(
                    np.array(batch_entity_masks)).to(device)

                # The question encoding is reused; only the shortened
                # document is re-encoded.
                self.Pencoder(batch_data_loop.P)
                self.decoder(batch_data_loop)

                predict_difference = self.adversary_multi.output_diff(
                    batch_data_loop.predict, batch_data.predict)
                po[:, i] = predict_difference.squeeze(-1).cpu().data.numpy()

            output_diffs.append(po)

        output_diffs = [x for y in output_diffs for x in y]

        return output_diffs
def main():
    """Train the dual-HRNet damage model, optionally finishing with SWA.

    The config is loaded from ``args.config_path``, which may be a key of
    ``CONFIG_TREATER``, a path ending in ``.yaml``, or a bare experiment name
    resolved to ``experiments/<name>.yaml``.

    When ``config.mode`` starts with ``"single"``, one DataLoader is built per
    disaster and batches from all disasters are interleaved in random order
    each epoch; otherwise a single DataLoader is used.

    With ``config.SWA`` enabled, the optimizer is wrapped in SWA starting
    after the configured number of epochs, training is extended by 40 epochs,
    and the averaged weights are swapped in, batch-norm statistics are
    recomputed, and "preSWA"/"SWA" checkpoints are written at the end.

    Raises:
        ValueError: if ``args.config_path`` is empty.
    """
    if not args.config_path:
        # The original fell through with `config = None` and crashed on the
        # first attribute access; fail with an explicit message.
        raise ValueError("args.config_path is required to load a config")

    if args.config_path in CONFIG_TREATER:
        load_path = CONFIG_TREATER[args.config_path]
    elif args.config_path.endswith(".yaml"):
        load_path = args.config_path
    else:
        # This branch is only reached when the key is NOT in CONFIG_TREATER,
        # so the original lookup here always raised KeyError. Use the given
        # name directly.
        load_path = "experiments/" + args.config_path + ".yaml"
    with open(load_path, 'rb') as fp:
        config = CfgNode.load_cfg(fp)

    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    test_model = None
    max_epoch = config.TRAIN.NUM_EPOCHS
    # Epoch at which the SWA phase begins (equals swa_start expressed in
    # epochs, since every epoch runs loader_len optimizer steps).
    swa_start_epoch = max_epoch
    print('data folder: ', args.data_folder)

    # Distributed training is disabled in this script; DataParallel is used
    # when more than one GPU is visible.
    model = get_model(config)
    model_loss = ModelLossWraper(
        model,
        config.TRAIN.CLASS_WEIGHTS,
        config.MODEL.IS_DISASTER_PRED,
        config.MODEL.IS_SPLIT_LOSS,
    ).cuda()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        model_loss = nn.DataParallel(model_loss)
    model_loss.to(device)
    cpucount = multiprocessing.cpu_count()

    single_mode = config.mode.startswith("single")
    if single_mode:
        trainset_loaders = {}
        loader_len = 0
        for disaster in disaster_list[config.mode[6:]]:
            trainset = XView2Dataset(args.data_folder,
                                     rgb_bgr='rgb',
                                     preprocessing={
                                         'flip': True,
                                         'scale': config.TRAIN.MULTI_SCALE,
                                         'crop': config.TRAIN.CROP_SIZE,
                                     },
                                     mode="singletrain",
                                     single_disaster=disaster)
            if len(trainset) > 0:
                train_sampler = None
                trainset_loader = torch.utils.data.DataLoader(
                    trainset,
                    batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
                    shuffle=train_sampler is None,
                    pin_memory=True,
                    drop_last=True,
                    sampler=train_sampler,
                    num_workers=cpucount if cpucount < 16 else cpucount // 3)
                trainset_loaders[disaster] = trainset_loader
                loader_len += len(trainset_loader)
                print("added disaster {} with {} samples".format(
                    disaster, len(trainset)))
            else:
                print("skipping disaster ", disaster)
    else:
        trainset = XView2Dataset(args.data_folder,
                                 rgb_bgr='rgb',
                                 preprocessing={
                                     'flip': True,
                                     'scale': config.TRAIN.MULTI_SCALE,
                                     'crop': config.TRAIN.CROP_SIZE,
                                 },
                                 mode=config.mode)
        train_sampler = None
        trainset_loader = torch.utils.data.DataLoader(
            trainset,
            batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
            shuffle=train_sampler is None,
            pin_memory=True,
            drop_last=True,
            sampler=train_sampler,
            num_workers=multiprocessing.cpu_count())
        loader_len = len(trainset_loader)

    model.train()

    lr_init = config.TRAIN.LR
    optimizer = torch.optim.SGD(
        [{
            'params': filter(lambda p: p.requires_grad, model.parameters()),
            'lr': lr_init
        }],
        lr=lr_init,
        momentum=0.9,
        weight_decay=0.,
        nesterov=False,
    )

    num_iters = max_epoch * loader_len

    if config.SWA:
        swa_start = num_iters
        optimizer = SWA(optimizer,
                        swa_start=swa_start,
                        swa_freq=4 * loader_len,
                        swa_lr=0.001)
        # Extra epochs during which SWA collects its averages.
        # NOTE(review): the flattened source does not show whether this
        # extension was inside the SWA branch; it is placed here because
        # the extra epochs only serve the SWA phase - confirm.
        max_epoch = max_epoch + 40

    start_epoch = 0
    losses = AverageMeter()
    model.train()
    cur_iters = 0 if start_epoch == 0 else None

    for epoch in range(start_epoch, max_epoch):
        # The original compared `epoch` with `swa_start`, which counts
        # optimizer steps, so the guard never switched off the schedule
        # during the SWA phase. Compare epochs with epochs.
        use_schedule = not config.SWA or epoch < swa_start_epoch

        if single_mode:
            all_batches = []
            for disaster in sorted(trainset_loaders):
                all_batches += [
                    (disaster, idx)
                    for idx in range(len(trainset_loaders[disaster]))
                ]
            all_batches = random.sample(all_batches, len(all_batches))
            iterators = {
                disaster: iter(loader)
                for disaster, loader in trainset_loaders.items()
            }
            # NOTE(review): on a fresh start this advances cur_iters before
            # the first epoch, so the schedule runs one epoch ahead of the
            # non-single branch. Kept as in the original - confirm intent.
            if cur_iters is not None:
                cur_iters += len(all_batches)
            else:
                cur_iters = epoch * len(all_batches)

            for i, (disaster, idx) in enumerate(all_batches):
                lr = optimizer.param_groups[0]['lr']
                if use_schedule:
                    lr = adjust_learning_rate(optimizer, lr_init, num_iters,
                                              i + cur_iters)
                samples = next(iterators[disaster])
                inputs_pre = samples['pre_img'].to(device)
                inputs_post = samples['post_img'].to(device)
                target = samples['mask_img'].to(device)

                loss = model_loss(inputs_pre, inputs_post, target)
                loss_sum = torch.sum(loss).detach().cpu()
                if np.isnan(loss_sum) or np.isinf(loss_sum):
                    print('check')
                losses.update(loss_sum, 4)  # batch size
                loss = torch.sum(loss)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                if args.local_rank == 0 and i % 10 == 0:
                    logger.info('epoch: {0}\t'
                                'iter: {1}/{2}\t'
                                'lr: {3:.6f}\t'
                                'loss: {loss.val:.4f} ({loss.ema:.4f})\t'
                                'disaster: {dis}'.format(epoch + 1,
                                                         i + 1,
                                                         len(all_batches),
                                                         lr,
                                                         loss=losses,
                                                         dis=disaster))
            del iterators
        else:
            cur_iters = epoch * len(trainset_loader)
            for i, samples in enumerate(trainset_loader):
                lr = optimizer.param_groups[0]['lr']
                if use_schedule:
                    lr = adjust_learning_rate(optimizer, lr_init, num_iters,
                                              i + cur_iters)
                inputs_pre = samples['pre_img'].to(device)
                inputs_post = samples['post_img'].to(device)
                target = samples['mask_img'].to(device)

                loss = model_loss(inputs_pre, inputs_post, target)
                loss_sum = torch.sum(loss).detach().cpu()
                if np.isnan(loss_sum) or np.isinf(loss_sum):
                    print('check')
                losses.update(loss_sum, 4)  # batch size
                loss = torch.sum(loss)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                if args.local_rank == 0 and i % 10 == 0:
                    logger.info('epoch: {0}\t'
                                'iter: {1}/{2}\t'
                                'lr: {3:.6f}\t'
                                'loss: {loss.val:.4f} ({loss.ema:.4f})'.format(
                                    epoch + 1,
                                    i + 1,
                                    len(trainset_loader),
                                    lr,
                                    loss=losses))

        if args.local_rank == 0:
            if (epoch + 1) % 50 == 0 and test_model is None:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }, os.path.join(ckpts_save_dir, 'hrnet_%s' % (epoch + 1)))

    if config.SWA:
        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(ckpts_save_dir, 'hrnet_%s' % ("preSWA")))
        optimizer.swap_swa_sgd()
        # Batch-norm running statistics are recomputed for the averaged
        # weights. In single mode `trainset` is the last dataset built above.
        bn_loader = torch.utils.data.DataLoader(
            trainset,
            batch_size=2,
            shuffle=train_sampler is None,
            pin_memory=True,
            drop_last=True,
            sampler=train_sampler,
            num_workers=multiprocessing.cpu_count())
        bn_update(bn_loader, model, device='cuda')
        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(ckpts_save_dir, 'hrnet_%s' % ("SWA")))
def train_adversarial_model(train_loader, validation_loader, protected_attr, utility_attr_list, model_name, loss_name, f_standard_list, g_standard, epoch_number, learning_rate, lambda_list, target_accuracy, g = None, W = None, EA = False, save_gradient = False, gradient_loss_name = None):
    '''
    Adversarially train a channel W against a privacy classifier g.

    train_loader : training data loader
    validation_loader : data loader used for validation during training
    protected_attr : the privacy label
    utility_attr_list : the utility labels (must be a list)
    model_name : file the trained models are pickled to
    loss_name : file the per-epoch validation metrics are pickled to
    f_standard_list : fixed "standard" utility classifiers, not updated here (must be a list)
    g_standard : fixed "standard" privacy classifier evaluated with the channel during validation
    epoch_number : number of training epochs
    learning_rate : initial learning rate for both optimizers
    lambda_list : weights of the terms in W's loss (a list: k utility terms + privacy term + L1 term)
    target_accuracy : target passed to adjust_lambda_list when lambda_list is re-tuned each epoch
    g : optional pre-trained privacy classifier to start from
    W : optional pre-trained channel to start from
    EA : whether to use the extra-gradient algorithm
    save_gradient : whether to pickle the gradient norms
    gradient_loss_name : file the gradient norms are pickled to

    Returns None; results are written to model_name, loss_name and, when
    save_gradient is set, gradient_loss_name.
    '''
    def channel_loss(x_hat, images, utility_labels, protected_labels, lambdas):
        # W's objective: weighted utility losses minus the weighted privacy
        # loss plus the weighted L1 identity loss.
        f_loss = 0
        for index, f in enumerate(f_standard_list):
            f_loss += loss_calculation(f, x_hat, utility_labels[index])*lambdas[index]
        g_loss = loss_calculation(g, x_hat, protected_labels)*lambdas[-2]
        identity_loss = criterion_L1(images, x_hat)*lambdas[-1]
        return f_loss - g_loss + identity_loss

    def snapshot(module):
        # Detached copies of the current parameter values.
        return [p.detach().clone() for p in module.parameters()]

    # Check that lambda_list matches the number of utilities (warning only).
    utility_number = len(utility_attr_list)
    if len(lambda_list) != utility_number + 2:
        print('The lambda_list number is wrong.')

    # Model initialisation.
    for f in f_standard_list:
        f.cuda()
    if g is None:
        g = Discriminator()
        g.apply(weights_init_normal)
    g_standard.cuda()
    g.cuda()
    if W is None:
        W = Generator()
        W.apply(weights_init_normal)
    W.cuda()

    g_optimizer = SWA(torch.optim.Adam(g.parameters(), lr = learning_rate, betas=(0.5, 0.999)))
    W_optimizer = SWA(torch.optim.Adam(W.parameters(), lr = learning_rate, betas=(0.5, 0.999)))
    lr_scheduler_g = torch.optim.lr_scheduler.StepLR(g_optimizer, step_size = 10, gamma = 0.8)
    lr_scheduler_W = torch.optim.lr_scheduler.StepLR(W_optimizer, step_size = 10, gamma = 0.8)

    training_f_accuracy_list = []
    training_g_accuracy_list = []
    training_g_standard_accuracy_list = []
    L1_loss_list = []
    g_gradient_list = []
    W_gradient_list = []
    model_list = []

    use_extra_gradient = (EA == True)

    # Training loop.
    for epoch in range(epoch_number):
        seed_number = 999
        torch.manual_seed(seed_number + epoch)
        for i, (images, _, paths) in enumerate(train_loader):
            images = to_var(images)
            x_hat = W(images)
            protected_labels = file_name_dealer(paths, attrs = [protected_attr])[0]
            utility_labels = file_name_dealer(paths, attrs = utility_attr_list)

            # Current step.
            # The original stored `module.parameters()` generators as its
            # "snapshots"; those alias the live parameters, so the
            # extra-gradient correction below was always zero. Real copies
            # are taken here.
            if use_extra_gradient:
                g_current_para = snapshot(g)
                W_current_para = snapshot(W)

            # g : privacy classifier training.
            g.zero_grad()
            g_loss = loss_calculation(g, x_hat, protected_labels)
            g_loss.backward(retain_graph = True)
            g_optimizer.step()

            # W : channel training.
            W.zero_grad()
            W_loss = channel_loss(x_hat, images, utility_labels, protected_labels, lambda_list)
            W_loss.backward(retain_graph = True)
            W_optimizer.step()

            # Extra-gradient update.
            if use_extra_gradient:
                # NOTE(review): x_hat is not recomputed after W's first
                # step, as in the original - confirm this is intended.
                # g
                g_next_para = snapshot(g)
                g.zero_grad()
                g_loss = loss_calculation(g, x_hat, protected_labels)
                g_loss.backward(retain_graph = True)
                g_optimizer.step()

                # W
                W_next_para = snapshot(W)
                W.zero_grad()
                W_loss = channel_loss(x_hat, images, utility_labels, protected_labels, lambda_list)
                W_loss.backward(retain_graph = True)
                W_optimizer.step()

                # Undo the displacement of the first step so that only the
                # second step's update is applied to the starting point.
                with torch.no_grad():
                    for gp1, gp2, gp3 in zip(g.parameters(), g_current_para, g_next_para):
                        gp1 -= (gp3 - gp2)
                    for Wp1, Wp2, Wp3 in zip(W.parameters(), W_current_para, W_next_para):
                        Wp1 -= (Wp3 - Wp2)

        ##### model saving & validation #####################################################
        with torch.no_grad():
            # Release cached CUDA memory.
            torch.cuda.empty_cache()
            print('===========================')
            print('Now epoch number is ', epoch + 1)
            g_gradient = gradient_norm(g)
            W_gradient = gradient_norm(W)
            print('g gradient norm:',g_gradient)
            print('W gradient norm:',W_gradient)
            print('Now learning rate :')
            for param_group in g_optimizer.param_groups:
                print(param_group['lr'])

            # Save models.
            print('Models are saving...')

            # Record gradient norms when requested.
            if save_gradient == True:
                g_gradient_list.append(g_gradient)
                W_gradient_list.append(W_gradient)
                gradient_list = [copy.deepcopy(g_gradient_list), copy.deepcopy(W_gradient_list)]
                with open(gradient_loss_name, 'wb') as gradient_file:
                    pickle.dump(gradient_list, gradient_file)

            # Keep a model copy only every 10 epochs to save space.
            if epoch % 10 == 0:
                model_list.append([copy.deepcopy(g), copy.deepcopy(W)])
                with open(model_name, 'wb') as model_file:
                    pickle.dump(model_list, model_file)

            lr_scheduler_g.step()
            lr_scheduler_W.step()

            # Validation: compute accuracies and the L1 loss.
            f_predict_total_list = [[] for _ in range(utility_number)]
            utility_total_list = [[] for _ in range(utility_number)]
            utility_accuracy_total_list = []
            g_predict_list = []
            g_standard_predict_list = []
            protected_labels_list = []
            L1_loss = []

            for (images, _, paths) in validation_loader:
                images = to_var(images)
                x_hat = W(images)
                protected_labels = file_name_dealer(paths, attrs = [protected_attr])[0]
                utility_labels = file_name_dealer(paths, attrs = utility_attr_list)
                for index, f in enumerate(f_standard_list):
                    f_predict_total_list[index] += list(f(x_hat).view(-1).detach().cpu().numpy())
                    utility_total_list[index] += list(utility_labels[index].view(-1).detach().cpu().numpy())
                g_predict_list += list(g(x_hat).view(-1).detach().cpu().numpy())
                g_standard_predict_list += list(g_standard(x_hat).view(-1).detach().cpu().numpy())
                protected_labels_list += list(protected_labels.view(-1).detach().cpu().numpy())
                L1_loss += list(criterion_L1(images, x_hat).view(-1).detach().cpu().numpy())

            for index in range(utility_number):
                utility_accuracy_total_list.append(correct_ratio(f_predict_total_list[index], utility_total_list[index], threshold = 0.5))
            protected_accuracy = correct_ratio(g_predict_list, protected_labels_list, threshold = 0.5)
            protected_standard_accuracy = correct_ratio(g_standard_predict_list, protected_labels_list, threshold = 0.5)
            L1_loss_this_epoch = np.mean(L1_loss)

            for index in range(utility_number):
                print('f_accuracy_'+str(index+1)+':', utility_accuracy_total_list[index])
            print('g_accuracy:'+str(protected_accuracy)+' standard_g_accuracy:'+str(protected_standard_accuracy)+' L1 loss:'+str(L1_loss_this_epoch))

            training_f_accuracy_list.append(utility_accuracy_total_list)
            training_g_accuracy_list.append(protected_accuracy)
            training_g_standard_accuracy_list.append(protected_standard_accuracy)
            L1_loss_list.append(L1_loss_this_epoch)
            # NOTE(review): with save_gradient set these norms were already
            # appended above, so each epoch is recorded twice - kept as in
            # the original; confirm whether that is intended.
            g_gradient_list.append(g_gradient)
            W_gradient_list.append(W_gradient)

            # Re-tune lambda_list for the next epoch.
            lambda_list = adjust_lambda_list(lambda_list, utility_accuracy_total_list, target_accuracy)
            print('lambda_list:', lambda_list)

            # Persist the validation metrics.
            loss_list = [copy.deepcopy(training_f_accuracy_list), copy.deepcopy(training_g_accuracy_list), copy.deepcopy(training_g_standard_accuracy_list), copy.deepcopy(L1_loss_list)]
            with open(loss_name, 'wb') as loss_file:
                pickle.dump(loss_list, loss_file)
def train(train_df, CONFIG):
    """Train the fold selected by ``CONFIG["FOLD_NUM"]`` and save the best model.

    The data frame is wrapped in a ``TweetDataset`` and split with a
    stratified K-fold on the ``sentiment`` column; every fold other than
    ``CONFIG["FOLD_NUM"]`` is skipped.  For the selected fold the function
    builds the model, an AdamW optimizer with four parameter groups
    (optionally wrapped in ``SWA``), a linear warmup schedule, and then loops
    over the training loader until ``EPOCHS * len(train subset)`` samples have
    been consumed.  Validation runs periodically inside the batch loop,
    metrics are sent to MLflow, and the state dict of the model with the best
    validation ``jaccard_text`` is written to ``model.pth`` at the end.

    Args:
        train_df: Data frame handed to ``TweetDataset``; its ``textID`` and
            ``sentiment`` columns are read for the split.
        CONFIG: Mapping of hyper-parameters.  Keys read here: ``SEED``,
            ``TORCH_SEED``, ``FOLD``, ``FOLD_NUM``, ``EPOCHS``,
            ``TRAIN_BATCH_SIZE``, ``VALID_BATCH_SIZE``, ``WEIGHT_DECAY``,
            ``LR``, ``SWA``, ``WARMUP``.

    Returns:
        None.
    """
    # set-up
    seed_everything(CONFIG['SEED'])
    torch.manual_seed(CONFIG['TORCH_SEED'])
    mlflow.log_params(CONFIG)
    TRAIN_LEN = len(train_df)  # not read again in this function
    train_dataset = TweetDataset(train_df, CONFIG)
    CRITERION = define_criterion(CONFIG)
    folds = StratifiedKFold(n_splits=CONFIG["FOLD"],
                            shuffle=True,
                            random_state=CONFIG["SEED"])
    for n_fold, (train_idx, valid_idx) in enumerate(
            folds.split(train_dataset.df['textID'],
                        train_dataset.df['sentiment'])):
        # Only the configured fold is trained; all others are skipped.
        if n_fold != CONFIG["FOLD_NUM"]:
            continue
        ## Define the data subsets (the DataLoaders are built further down)
        # NOTE: the local name `train` shadows this function inside its body.
        train = torch.utils.data.Subset(train_dataset, train_idx)
        valid = torch.utils.data.Subset(train_dataset, valid_idx)
        DATA_IN_EPOCH = len(train)
        # Training length is measured in samples, not in loader epochs.
        TOTAL_DATA = DATA_IN_EPOCH * CONFIG["EPOCHS"]
        # Total number of optimizer steps, used for warmup and SWA gating.
        T_TOTAL = int(CONFIG["EPOCHS"] * DATA_IN_EPOCH /
                      CONFIG["TRAIN_BATCH_SIZE"])
        ## Initialise the model and the optimizer
        model = build_model(CONFIG)
        model.to(DEVICE)
        model.train()
        ## From 20/05/17
        param_optimizer = list(model.named_parameters())
        bert_params = [n for n, p in param_optimizer if "bert" in n]
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        # Four groups: {BERT, other} x {weight decay, no weight decay}.
        # All four currently use the same learning rate (CONFIG['LR'] * 1).
        optimizer_grouped_parameters = [
            ## BERT param
            {
                'params': [
                    p for n, p in param_optimizer
                    if (not any(nd in n for nd in no_decay)) and (n in bert_params)
                ],
                'weight_decay': CONFIG["WEIGHT_DECAY"],
                'lr': CONFIG['LR'] * 1,
            },
            {
                'params': [
                    p for n, p in param_optimizer
                    if (any(nd in n for nd in no_decay)) and (n in bert_params)
                ],
                'weight_decay': 0.0,
                'lr': CONFIG['LR'] * 1,
            },
            ## Other param
            {
                'params': [
                    p for n, p in param_optimizer
                    if (not any(nd in n for nd in no_decay)) and (n not in bert_params)
                ],
                'weight_decay': CONFIG["WEIGHT_DECAY"],
                'lr': CONFIG['LR'] * 1,
            },
            {
                'params': [
                    p for n, p in param_optimizer
                    if (any(nd in n for nd in no_decay)) and (n not in bert_params)
                ],
                'weight_decay': 0.0,
                'lr': CONFIG['LR'] * 1,
            },
        ]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, eps=4e-5)
        if CONFIG['SWA']:
            # Manual-mode SWA wrapper: averaging happens only where
            # update_swa() is called explicitly below.
            optimizer = SWA(optimizer)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(CONFIG["WARMUP"] * T_TOTAL),
            num_training_steps=T_TOTAL)
        train_sampler = SentimentBalanceSampler(train, CONFIG)
        train_loader = DataLoader(train,
                                  batch_size=CONFIG["TRAIN_BATCH_SIZE"],
                                  shuffle=False,
                                  sampler=train_sampler,
                                  collate_fn=TweetCollate(CONFIG),
                                  num_workers=1)
        valid_loader = DataLoader(
            valid,
            batch_size=CONFIG["VALID_BATCH_SIZE"],
            shuffle=False,
            # sampler = valid_sampler,
            collate_fn=TweetCollate(CONFIG),
            num_workers=1)
        n_data = 0  # samples seen so far (drives the outer loop)
        n_e_data = 0  # samples seen modulo one epoch
        best_val = 0.0
        # The three per-sentiment bests below are initialised but never updated.
        best_val_neu = 0.0
        best_val_pos = 0.0
        best_val_neg = 0.0
        t_batch = 0  # optimizer steps taken so far
        while n_data < TOTAL_DATA:
            print(f"Epoch : {int(n_data/DATA_IN_EPOCH)}")
            n_batch = 0
            loss_list = []
            jac_token_list = []
            jac_text_list = []
            jac_sentiment_list = []
            jac_cl_text_list = []
            output_list = []
            target_list = []
            for batch in tqdm(train_loader):
                # Several of the fields unpacked here (text, cl_text,
                # selected_text, cl_selected_text, text_idx, offsets) are not
                # read again in this loop body; calc_jaccard receives the
                # whole batch instead.
                textID = batch['textID']
                text = batch['text']
                sentiment = batch['sentiment']
                cl_text = batch['cl_text']
                selected_text = batch['selected_text']
                cl_selected_text = batch['cl_selected_text']
                text_idx = batch['text_idx']
                offsets = batch['offsets']
                tokenized_text = batch['tokenized_text'].to(DEVICE)
                mask = batch['mask'].to(DEVICE)
                mask_out = batch['mask_out'].to(DEVICE)
                token_type_ids = batch['token_type_ids'].to(DEVICE)
                weight = batch['weight'].to(DEVICE)
                target = batch['target'].to(DEVICE)
                # Epoch index computed BEFORE this batch is counted.
                ep = int(n_data / DATA_IN_EPOCH)
                n_data += len(textID)
                n_e_data += len(textID)
                n_batch += 1
                t_batch += 1
                model.zero_grad()
                # optimizer.zero_grad()
                output = model(input_ids=tokenized_text,
                               attention_mask=mask,
                               token_type_ids=token_type_ids,
                               mask_out=mask_out)
                loss = CRITERION(output, target)
                # Per-sample weighting; the mean of the weighted losses is
                # what gets back-propagated.
                loss = loss * weight
                loss.mean().backward()
                loss = loss.detach().cpu().numpy().tolist()
                optimizer.step()
                # The LR schedule is only advanced during the first half of
                # training; afterwards the learning rate stays frozen.
                if t_batch < T_TOTAL * 0.50:
                    scheduler.step()
                loss_list.extend(loss)
                output = output.detach().cpu().numpy()
                target = target.detach().cpu().numpy()
                jac = calc_jaccard(output, batch, CONFIG)
                jac_token_list.extend(jac['jaccard_token'].tolist())
                jac_cl_text_list.extend(jac['jaccard_cl_text'].tolist())
                jac_text_list.extend(jac['jaccard_text'].tolist())
                jac_sentiment_list.extend(sentiment)
                # Fold the current weights into the SWA average: only with SWA
                # enabled, after the first epoch and in the second half of
                # training, every int(5*32/batch) batches (or on the final
                # batch).
                if ((((ep > 0) &
                      (n_batch %
                       (int(5 * 32 / CONFIG["TRAIN_BATCH_SIZE"])) == 0)) |
                     ((n_batch %
                       (int(50 * 32 / CONFIG["TRAIN_BATCH_SIZE"])) == 0)) |
                     (n_data >= TOTAL_DATA)) and
                    (CONFIG['SWA'] and ep > 0 and t_batch >= T_TOTAL * 0.50)):
                    optimizer.update_swa()
                # Validation trigger: every int(50*32/batch) batches or on the
                # final batch.  The first two alternatives test the same
                # modulus, so the `ep > 0` term adds nothing here.
                if (
                    ((ep > 0) &
                     (n_batch %
                      (int(50 * 32 / CONFIG["TRAIN_BATCH_SIZE"])) == 0)) |
                    ((n_batch %
                      (int(50 * 32 / CONFIG["TRAIN_BATCH_SIZE"])) == 0)) |
                    (n_data >= TOTAL_DATA)
                ):  # ((n_data>=0)&(n_data<=1600)|(n_data>=21000)&(n_data<=23000))&
                    # Evaluate with the averaged weights once SWA is active;
                    # the matching swap-back is at the end of this branch.
                    if CONFIG['SWA'] and ep > 0 and t_batch >= T_TOTAL * 0.50:
                        # optimizer.update_swa()
                        optimizer.swap_swa_sgd()
                    val = create_valid(model, valid_loader, CONFIG)
                    # Training metrics cover the batches since the previous
                    # validation (the lists are reset just below).
                    trn_loss = np.array(loss_list).mean()
                    trn_jac_token = np.array(jac_token_list).mean()
                    trn_jac_cl_text = np.array(jac_cl_text_list).mean()
                    trn_jac_text = np.array(jac_text_list).mean()
                    trn_jac_text_neu = np.array(jac_text_list)[np.array(
                        jac_sentiment_list) == 'neutral'].mean()
                    trn_jac_text_pos = np.array(jac_text_list)[np.array(
                        jac_sentiment_list) == 'positive'].mean()
                    trn_jac_text_neg = np.array(jac_text_list)[np.array(
                        jac_sentiment_list) == 'negative'].mean()
                    val_loss = val['loss'].mean()
                    val_jac_token = val['jaccard_token'].mean()
                    val_jac_cl_text = val['jaccard_cl_text'].mean()
                    val_jac_text = val['jaccard_text'].mean()
                    val_jac_text_neu = val['jaccard_text'][val['sentiment'] ==
                                                           'neutral'].mean()
                    val_jac_text_pos = val['jaccard_text'][val['sentiment'] ==
                                                           'positive'].mean()
                    val_jac_text_neg = val['jaccard_text'][val['sentiment'] ==
                                                           'negative'].mean()
                    loss_list = []
                    jac_token_list = []
                    jac_cl_text_list = []
                    jac_text_list = []
                    jac_sentiment_list = []
                    # mlflow
                    metrics = {
                        "lr": optimizer.param_groups[0]['lr'],
                        "trn_loss": trn_loss,
                        "trn_jac_text_neu": trn_jac_text_neu,
                        "trn_jac_text_pos": trn_jac_text_pos,
                        "trn_jac_text_neg": trn_jac_text_neg,
                        "trn_jac_text": trn_jac_text,
                        "val_loss": val_loss,
                        "val_jac_text_neu": val_jac_text_neu,
                        "val_jac_text_pos": val_jac_text_pos,
                        "val_jac_text_neg": val_jac_text_neg,
                        "val_jac_text": val_jac_text,
                    }
                    mlflow.log_metrics(metrics, step=n_data)
                    # With SWA enabled, checkpoints from the first half of
                    # training are never candidates for "best".
                    if CONFIG['SWA'] and t_batch < T_TOTAL * 0.50:
                        pass
                    else:
                        if best_val < val_jac_text:
                            best_val = val_jac_text
                            best_model = copy.deepcopy(model)
                    # Restore the raw (non-averaged) weights for training.
                    if CONFIG['SWA'] and ep > 0 and t_batch >= T_TOTAL * 0.50:
                        optimizer.swap_swa_sgd()
                if n_e_data >= DATA_IN_EPOCH:
                    n_e_data -= DATA_IN_EPOCH
                if n_data >= TOTAL_DATA:
                    # NOTE(review): best_model is only bound once a validation
                    # score beats best_val; if that never happens this line
                    # appears to raise NameError - confirm and guard.
                    filepath = os.path.join(FILE_DIR, OUTPUT_DIR, "model.pth")
                    torch.save(best_model.state_dict(), filepath)
                    # mlflow
                    mlflow.log_artifact(filepath)
                    break
def train(self, train_inputs):
    """Run the training loop and return the best dev score from early stopping.

    ``train_inputs`` supplies the model (``'model'``), the training and dev
    datasets (``'train_data'``, ``'dev_data'``), the first epoch index
    (``'epoch_start'``) and the dev output file (``'dev_res_file'``).  The
    pretrained part of the model is detached during epoch 0 and re-attached
    afterwards; FGM adversarial training and SWA are applied when enabled in
    ``self.config.fitting``.  The dev set is evaluated on the last step of
    every epoch.
    """
    fit_cfg = self.config.fitting
    model = train_inputs['model']
    train_data = train_inputs['train_data']
    dev_data = train_inputs['dev_data']
    first_epoch = train_inputs['epoch_start']

    # Number of batches per epoch (ceiling division).
    train_steps = int(
        (len(train_data) + fit_cfg.batch_size - 1) / fit_cfg.batch_size)
    train_dataloader = DataLoader(train_data,
                                  batch_size=fit_cfg.batch_size,
                                  collate_fn=self.get_collate_fn('train'),
                                  shuffle=True)

    # One optimizer group per parameter set that has a configured lr.
    params_lr = [{"params": group, 'lr': fit_cfg.lr[name]}
                 for name, group in model.get_params().items()
                 if name in fit_cfg.lr]
    optimizer = SWA(torch.optim.Adam(params_lr))

    early_stopping = EarlyStopping(model, ROOT_WEIGHT, mode='max', patience=3)
    learning_schedual = LearningSchedual(optimizer, fit_cfg.epochs,
                                         fit_cfg.end_epoch, train_steps,
                                         fit_cfg.lr)
    aux = ModelAux(self.config, train_steps)
    moving_log = MovingData(window=100)
    ending_flag = False
    detach_flag = False
    swa_flag = False
    fgm = FGM(model)

    for epoch in range(first_epoch, fit_cfg.epochs):
        for step, (inputs, targets, others) in enumerate(train_dataloader):
            # Each input value is a (payload, move_to_gpu) pair.
            inputs = {
                name: (packed[0].cuda() if packed[1] else packed[0])
                for name, packed in inputs.items()
            }
            targets = {name: tensor.cuda() for name, tensor in targets.items()}

            # At the start of an epoch: detach the pretrained model in epoch 0,
            # re-attach it in every later epoch.
            if step == 0:
                if epoch > 0:
                    model.detach_ptm(False)
                    detach_flag = False
                elif epoch == 0:
                    model.detach_ptm(True)
                    detach_flag = True

            # train ==========================================================
            preds = model(inputs, en_decode=fit_cfg.verbose)
            loss = model.cal_loss(preds, targets, inputs['mask'])
            loss['back'].backward()

            # Adversarial training (FGM)
            if (not detach_flag) and fit_cfg.en_fgm:
                # add an adversarial perturbation to the embedding
                fgm.attack(emb_name='word_embeddings')
                preds_adv = model(inputs, en_decode=False)
                loss_adv = model.cal_loss(preds_adv, targets, inputs['mask'])
                # back-propagate; adversarial gradients are accumulated on top
                # of the regular ones
                loss_adv['back'].backward()
                # restore the embedding parameters
                fgm.restore(emb_name='word_embeddings')

            # torch.nn.utils.clip_grad_norm(model.parameters(), 1)
            optimizer.step()
            optimizer.zero_grad()

            with torch.no_grad():
                logs = {}
                global_step = epoch * train_steps + step
                verbose = fit_cfg.verbose

                entity_counts = {}
                if verbose:
                    pred_entity_point = model.find_entity(
                        preds['pred'], others['raw_text'])
                    cn, pn, tn = self.calculate_f1(pred_entity_point,
                                                   others['raw_entity'])
                    entity_counts = {
                        'correct_num': cn,
                        'pred_num': pn,
                        'true_num': tn
                    }
                metrics_data = {
                    'loss': loss['show'].cpu().numpy(),
                    'sampled_num': 1,
                    **entity_counts
                }
                moving_data = moving_log(global_step, metrics_data)
                logs['loss'] = moving_data['loss'] / moving_data['sampled_num']
                if verbose:
                    logs['precise'], logs['recall'], logs['f1'] = calculate_f1(
                        moving_data['correct_num'],
                        moving_data['pred_num'],
                        moving_data['true_num'],
                        verbose=True)

                # update lr
                logs.update(learning_schedual.update_lr(epoch, step))

                if step + 1 == train_steps:
                    model.eval()
                    aux.new_line()

                    # dev ====================================================
                    dev_result = self.eval({
                        'model': model,
                        'data': dev_data,
                        'type_data': 'dev',
                        'outfile': train_inputs['dev_res_file']
                    })
                    for metric in ('loss', 'precise', 'recall', 'f1'):
                        logs['dev_' + metric] = dev_result[metric]

                    dev_f1 = logs['dev_f1']
                    if dev_f1 > 0.80:
                        torch.save(
                            model.state_dict(),
                            "{}/auto_save_{:.6f}.ckpt".format(
                                ROOT_WEIGHT, dev_f1))

                    # Once averaging has started it continues every epoch.
                    if (epoch > 3 or swa_flag) and fit_cfg.en_swa:
                        optimizer.update_swa()
                        swa_flag = True

                    early_stop, best_score = early_stopping(dev_f1)

                    # test ===================================================
                    last_step = (epoch + 1 == fit_cfg.epochs
                                 and step + 1 == train_steps)
                    if last_step or early_stop:
                        ending_flag = True
                        if swa_flag:
                            # Finish on the averaged weights and refresh the
                            # batch-norm statistics for them.
                            optimizer.swap_swa_sgd()
                            optimizer.bn_update(train_dataloader, model)

                    model.train()

                aux.show_log(epoch, step, logs)
                if ending_flag:
                    return best_score
class Trainer(object):
    """Fine-tune and evaluate a ``BertForSequenceClassification`` model.

    The trainer owns the model, the optimizer (optionally wrapped in ``SWA``)
    and the LR scheduler, and provides ``train``, ``evaluate``, ``save_model``
    and ``load_model``.
    """

    def __init__(self,
                 args,
                 train_dataloader=None,
                 validate_dataloader=None,
                 test_dataloader=None):
        """Build the model configuration and the model, and pick the device.

        Args:
            args: Settings object; attributes read by this class include
                ``num_classes``, ``bert_model_name``, ``cuda``, ``num_epochs``,
                ``use_swa``, ``lr``, ``batch_size``, ``base_dir``,
                ``result_dir``, ``train_id`` and ``prediction_file``.
            train_dataloader: Loader used by ``train``.
            validate_dataloader: Loader used by ``evaluate('validate')``.
            test_dataloader: Loader used by ``evaluate('test')``.
        """
        self.args = args
        self.train_dataloader = train_dataloader
        self.validate_dataloader = validate_dataloader
        self.test_dataloader = test_dataloader

        self.label_lst = list(range(self.args.num_classes))
        self.num_labels = self.args.num_classes

        self.config_class = AutoConfig
        self.model_class = BertForSequenceClassification

        self.config = self.config_class.from_pretrained(
            self.args.bert_model_name,
            num_labels=self.num_labels,
            finetuning_task='nsmc',
            id2label={str(i): label
                      for i, label in enumerate(self.label_lst)},
            label2id={label: i
                      for i, label in enumerate(self.label_lst)})
        self.model = self.model_class.from_pretrained(
            self.args.bert_model_name, config=self.config)

        # Created lazily by train().
        self.optimizer = None
        self.scheduler = None

        # GPU or CPU
        self.device = "cuda" if torch.cuda.is_available(
        ) and args.cuda else "cpu"
        self.model.to(self.device)

    def train(self, alpha, gamma):
        """Train with a focal loss and return the best validation macro-F1.

        After every epoch the model is evaluated on the validation loader and
        a checkpoint is saved.  With ``args.use_swa`` the SWA-averaged weights
        are swapped in for evaluation from epoch index 4 onwards and are left
        in the model when training ends.  The result is appended to
        ``param_seach.txt`` in the run directory.

        Args:
            alpha: ``alpha`` argument passed to ``FocalLoss``.
            gamma: ``gamma`` argument passed to ``FocalLoss``.

        Returns:
            The highest ``f1_macro`` seen across epochs.
        """
        train_dataloader = self.train_dataloader
        t_total = len(train_dataloader) * self.args.num_epochs

        # Prepare optimizer and schedule (linear warmup and decay).
        # Both groups currently use weight_decay 0.0.
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in self.model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.0
        }, {
            'params': [
                p for n, p in self.model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.0
        }]

        if self.args.use_swa:
            base_opt = AdamW(optimizer_grouped_parameters,
                             lr=self.args.lr,
                             eps=1e-8)
            self.optimizer = SWA(base_opt,
                                 swa_start=4 * len(train_dataloader),
                                 swa_freq=100,
                                 swa_lr=5e-5)
            # Mirror the wrapped optimizer's attributes on the wrapper;
            # presumably so the LR scheduler and state_dict() can treat the
            # wrapper like a regular optimizer - TODO confirm.
            self.optimizer.param_groups = self.optimizer.optimizer.param_groups
            self.optimizer.state = self.optimizer.optimizer.state
            self.optimizer.defaults = self.optimizer.optimizer.defaults
        else:
            self.optimizer = AdamW(optimizer_grouped_parameters,
                                   lr=self.args.lr,
                                   eps=1e-8)

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=100,
            num_training_steps=self.args.num_epochs * len(train_dataloader))

        self.criterion = FocalLoss(alpha=alpha, gamma=gamma)

        # Train!
        logger.info("***** Running training *****")
        logger.info(" Num examples = %d",
                    len(self.train_dataloader) * self.args.batch_size)
        logger.info(" Num Epochs = %d", self.args.num_epochs)
        logger.info(" Total train batch size = %d", self.args.batch_size)
        logger.info(" Total optimization steps = %d", t_total)

        global_step = 0
        tr_loss = 0.0
        self.model.zero_grad()
        self.optimizer.zero_grad()

        train_iterator = trange(int(self.args.num_epochs), desc="Epoch")
        f1_max = 0.0
        self.model.train()

        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):
                batch = tuple(t.to(self.device) for t in batch)  # GPU or CPU
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'labels': batch[3],
                    'token_type_ids': batch[2]
                }

                # Custom loss: the model's own loss (output 0) is ignored and
                # a focal loss is computed on sigmoid(logits) against one-hot
                # labels.  Indexing (instead of tuple-unpacking) matches
                # evaluate() and also works when the model returns more than
                # two outputs.
                outputs = self.model(**inputs)
                logits = torch.sigmoid(outputs[1])
                labels = torch.zeros((len(batch[3]), self.num_labels),
                                     device=self.device)
                labels[range(len(batch[3])), batch[3]] = 1
                loss = self.criterion(logits, labels)

                loss.backward()
                self.optimizer.step()
                self.scheduler.step()  # Update learning rate schedule
                self.model.zero_grad()
                self.optimizer.zero_grad()

                tr_loss += loss.item()
                global_step += 1

            logger.info('train loss %f', loss.item())
            logger.info('total train loss %f', tr_loss / global_step)

            # Evaluate and checkpoint with the averaged weights ...
            if epoch >= 4 and self.args.use_swa:
                self.optimizer.swap_swa_sgd()
            fin_result = self.evaluate("validate")
            self.save_model(epoch)
            self.model.train()
            # ... then swap the raw weights back in to continue training.
            if epoch >= 4 and self.args.use_swa:
                self.optimizer.swap_swa_sgd()
            f1_max = max(fin_result['f1_macro'], f1_max)

        # Leave the averaged weights in the model once training is over.
        if epoch >= 4 and self.args.use_swa:
            self.optimizer.swap_swa_sgd()

        with open(os.path.join(self.args.base_dir, self.args.result_dir,
                               self.args.train_id, 'param_seach.txt'),
                  "a",
                  encoding="utf-8") as f:
            f.write('alpha: {}, gamma: {}, f1_macro: {}\n'.format(
                alpha, gamma, f1_max))

        return f1_max

    def evaluate(self, mode='test'):
        """Evaluate on the test or validation loader and return the metrics.

        Predictions (argmax of the logits) are written one per line to
        ``args.prediction_file``.

        Args:
            mode: ``'test'`` or ``'validate'``.

        Returns:
            Dict with ``loss``, the entries of ``compute_metrics`` and the
            macro-averaged ``precision``, ``recall`` and ``f1_macro``.

        Raises:
            Exception: If ``mode`` is neither ``'test'`` nor ``'validate'``.
        """
        if mode == 'test':
            dataloader = self.test_dataloader
        elif mode == 'validate':
            dataloader = self.validate_dataloader
        else:
            raise Exception("Only validate and test dataset available")

        # Eval!
        logger.info("***** Running evaluation on %s dataset *****", mode)
        logger.info(" Num examples = %d",
                    len(dataloader) * self.args.batch_size)
        logger.info(" Batch size = %d", self.args.batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        # Per-batch arrays are collected and concatenated once at the end;
        # appending to a growing array would copy it on every batch.
        logit_chunks = []
        label_chunks = []

        self.model.eval()

        for batch in tqdm(dataloader, desc="Evaluating"):
            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'labels': batch[3],
                    'token_type_ids': batch[2]
                }
                outputs = self.model(**inputs)
                tmp_eval_loss, logits = outputs[:2]
                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1

            logit_chunks.append(logits.detach().cpu().numpy())
            label_chunks.append(inputs['labels'].detach().cpu().numpy())

        eval_loss = eval_loss / nb_eval_steps
        results = {"loss": eval_loss}

        preds = np.argmax(np.concatenate(logit_chunks, axis=0), axis=1)
        out_label_ids = np.concatenate(label_chunks, axis=0)

        result = compute_metrics(preds, out_label_ids)
        results.update(result)

        p_macro, r_macro, f_macro, support_macro \
            = precision_recall_fscore_support(
                y_true=out_label_ids,
                y_pred=preds,
                labels=list(range(self.num_labels)),
                average='macro')
        results.update({
            'precision': p_macro,
            'recall': r_macro,
            'f1_macro': f_macro
        })

        with open(self.args.prediction_file, "w", encoding="utf-8") as f:
            for pred in preds:
                f.write("{}\n".format(pred))

        if mode == 'validate':
            logger.info("***** Eval results *****")
            for key in sorted(results.keys()):
                logger.info(" %s = %s", key, str(results[key]))

        return results

    def save_model(self, num=0):
        """Save model, optimizer and scheduler state as ``epoch_<num>.pth``."""
        state = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()
        }
        torch.save(
            state,
            os.path.join(self.args.base_dir, self.args.result_dir,
                         self.args.train_id, 'epoch_' + str(num) + '.pth'))
        logger.info('model saved')

    def load_model(self, model_name):
        """Load a checkpoint written by ``save_model`` from the run directory.

        Optimizer and scheduler state are restored only when those objects
        already exist on the trainer.
        """
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
        # host (and vice versa).
        state = torch.load(os.path.join(self.args.base_dir,
                                        self.args.result_dir,
                                        self.args.train_id, model_name),
                           map_location=self.device)
        self.model.load_state_dict(state['model'])
        if self.optimizer is not None:
            self.optimizer.load_state_dict(state['optimizer'])
        if self.scheduler is not None:
            self.scheduler.load_state_dict(state['scheduler'])
        logger.info('model loaded')
def main(args):
    """Plot SWA-averaged SGD trajectories on a toy 2-D regression problem.

    Two families of runs are drawn as arrows over a contour plot of the
    IPS-weighted squared loss:

    * IPS-weighted SGD: samples are drawn uniformly and each loss term is
      multiplied by its inverse propensity, for several learning rates.
    * Sample-based SGD ("CounterSample"): samples are drawn proportionally to
      their inverse propensity and every loss term gets the same constant
      weight.

    The figure is written to ``args.out`` in format ``args.format``.
    """
    # Parameters for toy data and experiments
    plt.figure(figsize=(6.4, 3.4))
    rng_seed = 42
    np.random.seed(rng_seed)
    wstar = torch.tensor([[0.973, 1.144]], dtype=torch.float)
    xs = torch.tensor(np.random.randn(50, 2), dtype=torch.float)
    labels = torch.mm(xs, wstar.T)
    p = torch.tensor(np.random.uniform(0.05, 1.0, xs.shape[0]),
                     dtype=torch.float)
    ips = 1.0 / p
    n_iters = 50
    plot_every = n_iters // 10
    arrow_width = 0.012
    legends = {}

    # The loss function we want to optimize
    def loss_fn(out, y, mult):
        return torch.mean(mult * (out - y)**2.0)

    def run_swa_sgd(lr, seed, draw_index, weight_for, label, color):
        """Run one SGD+SWA trajectory and draw an arrow every `plot_every` steps.

        `draw_index()` returns the index of the next training sample and
        `weight_for(i)` the multiplier applied to its loss.
        """
        model = torch.nn.Linear(2, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        optimizer = SWA(optimizer, swa_start=0, swa_freq=1, swa_lr=lr)
        with torch.no_grad():
            model.bias.zero_()
            model.weight.zero_()
        old_weights = np.copy(model.weight.data.numpy())
        np.random.seed(seed)
        arr = None
        for t in range(n_iters):
            i = draw_index()
            optimizer.zero_grad()
            step_loss = loss_fn(model(xs[i, :]), labels[i], weight_for(i))
            step_loss.backward()
            optimizer.step()
            if t % plot_every == 0:
                # Swap the averaged weights in just long enough to copy them;
                # the second swap restores the SGD iterate.
                optimizer.swap_swa_sgd()
                averaged = np.copy(model.weight.data.numpy())
                optimizer.swap_swa_sgd()
                ox, oy = old_weights[0]
                nx, ny = averaged[0]
                arr = plt.arrow(ox,
                                oy,
                                nx - ox,
                                ny - oy,
                                width=arrow_width,
                                length_includes_head=True,
                                color=color,
                                label=label)
                old_weights = averaged
        legends[label] = arr

    # IPS-weighted approach
    for color_index, lr in enumerate([0.01, 0.02, 0.03, 0.05,
                                      0.1]):  # 0.01, 0.03, 0.05, 0.1, 0.3
        color = "C%d" % (color_index + 2) if color_index > 0 else "C1"
        run_swa_sgd(lr=lr,
                    seed=rng_seed + color_index + 1,
                    draw_index=lambda: np.random.randint(xs.shape[0]),
                    weight_for=lambda i: ips[i],
                    label=f"IPS-SGD ($\\eta={lr}$)",
                    color=color)

    # Sample based approach
    for lr in [10.0]:  # lr = 3.0 # 1.0
        sample_probs = np.array(ips / torch.sum(ips))
        Mbar = float(np.mean(sample_probs))
        run_swa_sgd(
            lr=lr,
            seed=rng_seed - 1,
            draw_index=lambda: np.argwhere(
                np.random.multinomial(1, sample_probs) == 1.0)[0, 0],
            weight_for=lambda i: Mbar,
            label=f"\\textsc{{CounterSample}} ($\\eta={lr}$)",
            color="C2")

    # True IPS-weighted loss over all datapoints, used for plotting contour.
    # The targets do not depend on the candidate weights, so compute them once.
    true_targets = torch.mm(xs, wstar.reshape((2, 1)))
    # Column vector so each sample's loss is multiplied by its own weight; a
    # flat (n,) vector would broadcast against the (n, 1) losses to (n, n).
    ips_column = ips.reshape(-1, 1)

    def f(x1, x2):
        w = torch.tensor([[x1], [x2]], dtype=torch.float)
        return float(loss_fn(torch.mm(xs, w), true_targets, ips_column))

    # Compute all useful combinations of weights and compute true loss for each one
    # This will be used to compute a contour plot
    true_x1 = np.linspace(float(wstar[0, 0]) - 1.5,
                          float(wstar[0, 0]) + 0.8)  # - 1.5 / + 1.0
    true_x2 = np.linspace(float(wstar[0, 1]) - 1.5,
                          float(wstar[0, 1]) + 1.2)  # - 1.5 / + 1.0
    true_x1, true_x2 = np.meshgrid(true_x1, true_x2)
    true_y = np.array(
        [[f(true_x1[i1, i2], true_x2[i1, i2]) for i2 in range(len(true_x2))]
         for i1 in range(len(true_x1))])

    # Contour plot with optimum
    plt.plot(wstar[0, 0], wstar[0, 1], marker='o', markersize=3, color="black")
    plt.contour(true_x1, true_x2, true_y, 10, colors="black", alpha=0.35)

    # Generate legends from arrows and make figure
    def make_legend_arrow(legend, orig_handle, xdescent, ydescent, width,
                          height, fontsize):
        return mpatches.FancyArrow(0,
                                   0.5 * height,
                                   width,
                                   0,
                                   length_includes_head=True,
                                   head_width=0.75 * height)

    def sort_op(key):
        # The empty string sorts first, so CounterSample leads the legend.
        if "CounterSample" in key:
            return ""
        else:
            return key

    legend_labels = sorted(legends.keys(), key=sort_op)
    arrows = [legends[key] for key in legend_labels]
    plt.legend(arrows,
               legend_labels,
               ncol=2,
               loc='upper center',
               framealpha=0.95,
               handler_map={
                   mpatches.FancyArrow:
                   HandlerPatch(patch_func=make_legend_arrow)
               },
               bbox_to_anchor=(0.5, 1.0 + 0.03))
    plt.xlabel("$w_1$")
    plt.ylabel("$w_2$")
    plt.tight_layout()
    plt.savefig(args.out, format=args.format)
def train(self):
    """Train ``self.model`` with Adam + SWA, early stopping and dev evaluation.

    The dev set is evaluated in the middle and at the end of every epoch.
    Whenever dev F1 exceeds 0.730 the current weights are folded into the SWA
    average.  When the last step is reached or early stopping triggers, the
    averaged weights are swapped in, batch-norm statistics are refreshed, the
    result is saved as ``swa.ckpt`` and tested, and the method returns.

    Returns:
        None.
    """
    # prepare data
    train_data = self.data('train')
    train_steps = int((len(train_data) + self.config.batch_size - 1) /
                      self.config.batch_size)
    train_dataloader = DataLoader(train_data,
                                  batch_size=self.config.batch_size,
                                  collate_fn=self.get_collate_fn('train'),
                                  shuffle=True,
                                  num_workers=2)

    # prepare optimizer: separate learning rates for BERT and the other layers
    params_lr = [{
        "params": self.model.bert_parameters,
        'lr': self.config.small_lr
    }, {
        "params": self.model.other_parameters,
        'lr': self.config.large_lr
    }]
    optimizer = torch.optim.Adam(params_lr)
    optimizer = SWA(optimizer)

    # prepare early stopping
    early_stopping = EarlyStopping(self.model,
                                   self.config.best_model_path,
                                   big_server=BIG_GPU,
                                   mode='max',
                                   patience=10,
                                   verbose=True)

    # prepare learning schedual
    learning_schedual = LearningSchedual(
        optimizer, self.config.epochs, train_steps,
        [self.config.small_lr, self.config.large_lr])

    # prepare other
    aux = REModelAux(self.config, train_steps)
    moving_log = MovingData(window=500)

    ending_flag = False

    # self.model.load_state_dict(torch.load(ROOT_SAVED_MODEL + 'temp_model.ckpt'))
    #
    # with torch.no_grad():
    #     self.model.eval()
    #     print(self.eval())
    #     return

    # Dev metrics copied verbatim into the log line under a 'dev_' prefix.
    dev_keys = ('loss', 'sbj_precise', 'sbj_recall', 'sbj_f1', 'spo_precise',
                'spo_recall', 'spo_f1', 'precise', 'recall', 'f1')

    for epoch in range(0, self.config.epochs):
        for step, (inputs, y_trues, spo_info) in enumerate(train_dataloader):
            inputs = [aaa.cuda() for aaa in inputs]
            y_trues = [aaa.cuda() for aaa in y_trues]
            # Stop detaching BERT at step 1000 of the first epoch (and keep it
            # attached in every later epoch).
            if epoch > 0 or step == 1000:
                self.model.detach_bert = False

            # train ==========================================================
            preds = self.model(inputs)
            loss = self.calculate_loss(preds, y_trues, inputs[1], inputs[2])
            optimizer.zero_grad()
            loss.backward()
            # In-place variant; the un-suffixed clip_grad_norm is deprecated.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
            optimizer.step()

            with torch.no_grad():
                logs = {'lr0': 0, 'lr1': 0}
                # Training-set F1 tracking is switched off by the trailing
                # `and False`; the branch is kept for easy re-enabling.
                if (epoch > 0 or step > 620) and False:
                    sbj_f1, spo_f1 = self.calculate_train_f1(
                        spo_info[0], preds, spo_info[1:3],
                        inputs[2].cpu().numpy())
                    metrics_data = {
                        'loss': loss.cpu().numpy(),
                        'sampled_num': 1,
                        'sbj_correct_num': sbj_f1[0],
                        'sbj_pred_num': sbj_f1[1],
                        'sbj_true_num': sbj_f1[2],
                        'spo_correct_num': spo_f1[0],
                        'spo_pred_num': spo_f1[1],
                        'spo_true_num': spo_f1[2]
                    }
                    moving_data = moving_log(epoch * train_steps + step,
                                             metrics_data)
                    logs['loss'] = moving_data['loss'] / moving_data[
                        'sampled_num']
                    logs['sbj_precise'], logs['sbj_recall'], logs[
                        'sbj_f1'] = calculate_f1(
                            moving_data['sbj_correct_num'],
                            moving_data['sbj_pred_num'],
                            moving_data['sbj_true_num'],
                            verbose=True)
                    logs['spo_precise'], logs['spo_recall'], logs[
                        'spo_f1'] = calculate_f1(
                            moving_data['spo_correct_num'],
                            moving_data['spo_pred_num'],
                            moving_data['spo_true_num'],
                            verbose=True)
                else:
                    metrics_data = {
                        'loss': loss.cpu().numpy(),
                        'sampled_num': 1
                    }
                    moving_data = moving_log(epoch * train_steps + step,
                                             metrics_data)
                    logs['loss'] = moving_data['loss'] / moving_data[
                        'sampled_num']

                # update lr
                logs['lr0'], logs['lr1'] = learning_schedual.update_lr(
                    epoch, step)

                # Evaluate at mid-epoch and at the end of the epoch.
                if step == int(train_steps / 2) or step + 1 == train_steps:
                    self.model.eval()
                    torch.save(self.model.state_dict(),
                               ROOT_SAVED_MODEL + 'temp_model.ckpt')
                    aux.new_line()

                    # dev ====================================================
                    dev_result = self.eval()
                    for key in dev_keys:
                        logs['dev_' + key] = dev_result[key]

                    # other thing
                    early_stopping(logs['dev_f1'])
                    # Only sufficiently good checkpoints enter the average.
                    if logs['dev_f1'] > 0.730:
                        optimizer.update_swa()

                    # test ===================================================
                    if (epoch + 1 == self.config.epochs and step + 1
                            == train_steps) or early_stopping.early_stop:
                        ending_flag = True
                        # NOTE(review): if dev F1 never exceeded 0.730 no
                        # average exists yet when this swap runs - confirm the
                        # SWA wrapper tolerates that.
                        optimizer.swap_swa_sgd()
                        optimizer.bn_update(train_dataloader, self.model)
                        torch.save(self.model.state_dict(),
                                   ROOT_SAVED_MODEL + 'swa.ckpt')
                        self.test(ROOT_SAVED_MODEL + 'swa.ckpt')

                    self.model.train()

                aux.show_log(epoch, step, logs)
                if ending_flag:
                    return
class Train(object):
    """Training driver for one cross-validation fold.

    Builds the model, optimizer (optionally wrapped in SWA and/or initialised
    for apex mixed precision), EMA tracker, cosine LR schedule and BCE loss
    from the global ``cfg``, and runs the epoch loop in ``custom_loop``.
    """

    def __init__(self, train_ds, val_ds, fold):
        """Set up model, optimizer, scheduler and loss.

        Args:
            train_ds: Training data source.  It is called with no arguments to
                get one ``(images, data, target)`` batch of numpy arrays and
                exposes ``size`` (batches per epoch).
            val_ds: Validation data source with the same interface.
            fold: Fold index, used in log messages and checkpoint names.
        """
        self.fold = fold
        self.init_lr = cfg.TRAIN.init_lr
        self.warup_step = cfg.TRAIN.warmup_step
        self.epochs = cfg.TRAIN.epoch
        self.batch_size = cfg.TRAIN.batch_size
        self.l2_regularization = cfg.TRAIN.weight_decay_factor

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else 'cpu')

        self.model = Net().to(self.device)
        self.load_weight()

        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        # NOTE(review): these groups are built but never handed to an
        # optimizer - both optimizers below receive self.model.parameters(),
        # so cfg.TRAIN.weight_decay_factor is not applied.  Confirm intent
        # before wiring them in, since that would change training behaviour.
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': cfg.TRAIN.weight_decay_factor
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]

        if 'Adamw' in cfg.TRAIN.opt:
            self.optimizer = torch.optim.AdamW(self.model.parameters(),
                                               lr=self.init_lr,
                                               eps=1.e-5)
        else:
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             lr=0.001,
                                             momentum=0.9)

        if cfg.TRAIN.SWA > 0:
            ## use swa
            self.optimizer = SWA(self.optimizer)

        if cfg.TRAIN.mix_precision:
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level="O1")

        self.ema = EMA(self.model, 0.999)
        self.ema.register()

        ### control vars
        self.iter_num = 0

        self.train_ds = train_ds
        self.val_ds = val_ds

        # self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,mode='max', patience=3,verbose=True)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, self.epochs, eta_min=1.e-6)

        self.criterion = nn.BCEWithLogitsLoss().to(self.device)

    def custom_loop(self):
        """Run the full training schedule for this fold.

        For every epoch: train on ``train_ds``, optionally swap in the SWA
        and/or EMA weights, validate on ``val_ds`` every
        ``cfg.TRAIN.test_interval`` epochs, step the LR scheduler, save a
        checkpoint, and restore the raw training weights.

        Returns:
            None.
        """

        def distributed_train_epoch(epoch_num):
            """Train for one epoch; return the loss and accuracy meters."""
            summary_loss = AverageMeter()
            acc_score = ACCMeter()

            self.model.train()

            if cfg.MODEL.freeze_bn:
                for m in self.model.modules():
                    if isinstance(m, nn.BatchNorm2d):
                        m.eval()
                        if cfg.MODEL.freeze_bn_affine:
                            m.weight.requires_grad = False
                            m.bias.requires_grad = False

            for step in range(self.train_ds.size):
                if epoch_num < 10:
                    ### linear warm-up over the first `warup_step` iterations
                    if self.warup_step > 0:
                        if self.iter_num < self.warup_step:
                            for param_group in self.optimizer.param_groups:
                                param_group['lr'] = self.iter_num / float(
                                    self.warup_step) * self.init_lr
                                lr = param_group['lr']
                            logger.info('warm up with learning rate: [%f]' %
                                        (lr))

                start = time.time()

                images, data, target = self.train_ds()
                images = torch.from_numpy(images).to(self.device).float()
                data = torch.from_numpy(data).to(self.device).float()
                target = torch.from_numpy(target).to(self.device).float()

                batch_size = data.shape[0]

                output = self.model(images, data)
                current_loss = self.criterion(output, target)

                summary_loss.update(current_loss.detach().item(), batch_size)
                acc_score.update(target, output)

                self.optimizer.zero_grad()
                if cfg.TRAIN.mix_precision:
                    with amp.scale_loss(current_loss,
                                        self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    current_loss.backward()
                self.optimizer.step()

                if cfg.MODEL.ema:
                    self.ema.update()

                self.iter_num += 1
                time_cost_per_batch = time.time() - start
                images_per_sec = cfg.TRAIN.batch_size / time_cost_per_batch

                if self.iter_num % cfg.TRAIN.log_interval == 0:
                    log_message = '[fold %d], '\
                                  'Train Step %d, ' \
                                  'summary_loss: %.6f, ' \
                                  'accuracy: %.6f, ' \
                                  'time: %.6f, '\
                                  'speed %d images/persec' % (
                                      self.fold, step, summary_loss.avg,
                                      acc_score.avg, time.time() - start,
                                      images_per_sec)
                    logger.info(log_message)

            # One SWA update per epoch once the SWA start epoch is reached.
            if cfg.TRAIN.SWA > 0 and epoch_num >= cfg.TRAIN.SWA:
                self.optimizer.update_swa()

            return summary_loss, acc_score

        def distributed_test_epoch(epoch_num):
            """Evaluate on the validation set; return loss/accuracy meters."""
            summary_loss = AverageMeter()
            acc_score = ACCMeter()

            self.model.eval()
            t = time.time()
            with torch.no_grad():
                for step in range(self.val_ds.size):
                    # Draw from the validation source: the loop length comes
                    # from val_ds, so the batches must come from it as well.
                    images, data, target = self.val_ds()
                    images = torch.from_numpy(images).to(self.device).float()
                    data = torch.from_numpy(data).to(self.device).float()
                    target = torch.from_numpy(target).to(self.device).float()

                    batch_size = data.shape[0]

                    output = self.model(images, data)
                    loss = self.criterion(output, target)

                    summary_loss.update(loss.detach().item(), batch_size)
                    acc_score.update(target, output)

                    if step % cfg.TRAIN.log_interval == 0:
                        log_message = '[fold %d], '\
                                      'Val Step %d, ' \
                                      'summary_loss: %.6f, ' \
                                      'acc: %.6f, ' \
                                      'time: %.6f' % (
                                          self.fold, step, summary_loss.avg,
                                          acc_score.avg, time.time() - t)
                        logger.info(log_message)

            return summary_loss, acc_score

        for epoch in range(self.epochs):
            for param_group in self.optimizer.param_groups:
                lr = param_group['lr']
            logger.info('learning rate: [%f]' % (lr))
            t = time.time()

            summary_loss, acc_score = distributed_train_epoch(epoch)

            train_epoch_log_message = '[fold %d], '\
                                      '[RESULT]: Train. Epoch: %d,' \
                                      ' summary_loss: %.5f,' \
                                      ' acuracy: %.5f,' \
                                      ' time:%.5f' % (
                                          self.fold, epoch, summary_loss.avg,
                                          acc_score.avg, (time.time() - t))
            logger.info(train_epoch_log_message)

            if cfg.TRAIN.SWA > 0 and epoch >= cfg.TRAIN.SWA:
                ### switch to avg model
                self.optimizer.swap_swa_sgd()

            ## switch to the EMA weights
            if cfg.MODEL.ema:
                self.ema.apply_shadow()

            if epoch % cfg.TRAIN.test_interval == 0:
                summary_loss, acc_score = distributed_test_epoch(epoch)

                val_epoch_log_message = '[fold %d], '\
                                        '[RESULT]: VAL. Epoch: %d,' \
                                        ' summary_loss: %.5f,' \
                                        ' accuracy: %.5f,' \
                                        ' time:%.5f' % (
                                            self.fold, epoch, summary_loss.avg,
                                            acc_score.avg, (time.time() - t))
                logger.info(val_epoch_log_message)

            self.scheduler.step()
            # self.scheduler.step(final_scores.avg)

            #### save model
            if not os.access(cfg.MODEL.model_path, os.F_OK):
                os.mkdir(cfg.MODEL.model_path)

            #### save the model at the end of every epoch.  The name carries
            #### the validation loss when validation ran this epoch, otherwise
            #### the training loss.
            # NOTE(review): the directory ensured above is
            # cfg.MODEL.model_path but the file goes to './models/' - confirm
            # they are the same location.
            current_model_saved_name = './models/fold%d_epoch_%d_val_loss%.6f.pth' % (
                self.fold, epoch, summary_loss.avg)
            logger.info('A model saved to %s' % current_model_saved_name)
            torch.save(self.model.state_dict(), current_model_saved_name)

            #### switch back from the EMA weights
            if cfg.MODEL.ema:
                self.ema.restore()

            # save_checkpoint({
            #           'state_dict': self.model.state_dict(),
            #           },iters=epoch,tag=current_model_saved_name)

            if cfg.TRAIN.SWA > 0 and epoch > cfg.TRAIN.SWA:
                ### switch back to plain model to train next epoch
                self.optimizer.swap_swa_sgd()

    def load_weight(self):
        """Load ``cfg.MODEL.pretrained_model`` into the model, if configured.

        Loading is non-strict, so missing or unexpected keys are tolerated.
        """
        if cfg.MODEL.pretrained_model is not None:
            state_dict = torch.load(cfg.MODEL.pretrained_model,
                                    map_location=self.device)
            self.model.load_state_dict(state_dict, strict=False)
step = epoch * len(train_dataloader) + i writer.add_scalar('Loss/train', total_loss / args.log_interval, step) total_loss = 0.0 # log gradient norm if args.log_grad_norm: for p in net.parameters(): norm = p.grad.data.norm(2) total_norm += norm.item() ** 2 norm = total_norm ** 0.5 writer.add_scalar('Gradient 2-Norm/train', norm, step) total_norm = 0.0 # accumulate gradients if (i + 1) % args.grad_accumulate_batches == 0: optimizer.step() net.zero_grad() # ------------- validation ------------- pbar = tqdm(test_dataloader) pbar.set_description('Validation') total_loss = 0.0 # total_pesq = 0.0 num_test_data = len(test_dataloader) with torch.no_grad(): net.eval() for i, batch in enumerate(pbar): x, y = map(lambda x: x.to(device), batch) y_hat = net(x) loss = criterion(y_hat, y) # pesq_score = evaluate(e, c, l, fn=cal_pesq)
train_bar = tqdm(train_loader) train_bar.set_description_str(desc=f"N epochs - {epoch}") for step, batch in enumerate(train_bar): global_step += 1 image = batch['image'].to(device) label4class = batch['label0'].to(device) label = batch['label1'].to(device) output4class, output = model(image) loss4class = criterion4class(output4class, label4class) loss = criterion(output.squeeze(), label) swa.zero_grad() total_loss = loss4class*0.5 + loss*0.5 total_loss.backward() swa.step() train_writer.add_scalar(tag="learning_rate", scalar_value=scheduler.get_lr()[0], global_step=global_step) train_writer.add_scalar(tag="BinaryLoss", scalar_value=loss.item(), global_step=global_step) train_writer.add_scalar(tag="SoftMaxLoss", scalar_value=loss4class.item(), global_step=global_step) train_bar.set_postfix_str(f"Loss = {loss.item()}") try: train_writer.add_scalar(tag="idrnd_score", scalar_value=idrnd_score_pytorch(label, output), global_step=global_step) train_writer.add_scalar(tag="far_score", scalar_value=far_score(label, output), global_step=global_step) train_writer.add_scalar(tag="frr_score", scalar_value=frr_score(label, output), global_step=global_step) train_writer.add_scalar(tag="accuracy", scalar_value=bce_accuracy(label, output), global_step=global_step) except Exception: pass if (epoch > config['swa_start'] and epoch % 2 == 0) or (epoch == config['number_epochs']-1): swa.swap_swa_sgd() swa.bn_update(train_loader, model, device)
def train(self, train_loader, eval_loader, epoch):
    """Train ``self.model`` for ``self.args.train_steps`` optimisation steps.

    Batches are drawn endlessly from ``train_loader`` (via ``cycle``), so the
    number of steps is controlled by ``self.args.train_steps`` and not by the
    loader length.  Every ``self.args.eval_steps`` steps the running training
    loss is logged, ``self.result`` is refreshed and ``self.evaluate`` is run.

    Args:
        train_loader: iterable of ``(input_ids, input_mask, segment_ids,
            label_ids)`` batches.
        eval_loader: loader forwarded unchanged to ``self.evaluate``.
        epoch: epoch index; only used for logging, ``self.result`` and the
            ``self.evaluate`` call.
    """
    # Total number of scheduler steps the schedules are sized for.
    total_steps = (self.args.epochs * len(train_loader) /
                   self.args.batch_accumulation)

    # Build the optimizer and its learning-rate schedule.
    if self.args.swa:
        logger.info('SWA training')
        base_opt = torch.optim.SGD(self.model.parameters(),
                                   lr=self.args.learning_rate)
        optimizer = SWA(base_opt,
                        swa_start=self.args.swa_start,
                        swa_freq=self.args.swa_freq,
                        swa_lr=self.args.swa_lr)
        scheduler = CyclicLR(optimizer,
                             base_lr=5e-5,
                             max_lr=7e-5,
                             step_size_up=total_steps,
                             cycle_momentum=False)
    else:
        logger.info('Adam training')
        optimizer = torch.optim.Adam(self.model.parameters(),
                                     lr=self.args.learning_rate)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.args.warmup,
            num_training_steps=total_steps)

    bar = tqdm(range(self.args.train_steps), total=self.args.train_steps)
    train_batches = cycle(train_loader)
    loss_sum = 0.0
    start = time.time()
    self.model.train()

    for step in bar:
        batch = next(train_batches)
        input_ids, input_mask, segment_ids, label_ids = [
            t.to(self.device) for t in batch
        ]

        loss, _ = self.model(input_ids=input_ids,
                             token_type_ids=segment_ids,
                             attention_mask=input_mask,
                             labels=label_ids)
        # With several GPUs the model returns one loss per replica.
        if self.gpu_num > 1:
            loss = loss.mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # `.item()` already yields a Python float; no explicit `.cpu()` needed.
        loss_sum += loss.item()
        train_loss = loss_sum / (step + 1)
        bar.set_description("loss {}".format(train_loss))

        if (step + 1) % self.args.eval_steps == 0:
            logger.info("***** Training result *****")
            logger.info('  time %.2fs ', time.time() - start)
            logger.info("  %s = %s", 'global_step', str(step + 1))
            logger.info("  %s = %s", 'train loss', str(train_loss))

            # Run one evaluation every `eval_steps` steps.
            self.result = {
                'epoch': epoch,
                'global_step': step + 1,
                'loss': train_loss
            }
            # Swap before evaluating, then swap back so that training
            # continues from the non-averaged weights.
            if self.args.swa:
                optimizer.swap_swa_sgd()
            self.evaluate(eval_loader, epoch)
            if self.args.swa:
                optimizer.swap_swa_sgd()

    # Final swap after the last step (not swapped back).
    if self.args.swa:
        optimizer.swap_swa_sgd()

    # Use the same `logger` as the rest of the method; the original called
    # the `logging` module directly here, bypassing the configured logger.
    logger.info('The training of epoch ' + str(epoch + 1) +
                ' has finished.')
def train_model(self, model, criterion, lr_scheduler):
    """Train ``model`` with SGD wrapped in SWA and keep the best weights.

    For every epoch the training data is reloaded through ``self.loaddata``,
    the model is trained on it, validated with ``self.test_model`` and saved
    to ``<out_dir>/<model_name>/<model_name>_<epoch>.pth``.  After the loop
    the best validation weights (if any were recorded) are restored and the
    whole model object is saved as ``<model_name>_best.pth``.

    Args:
        model: network to train; it must expose ``.module`` (the per-epoch
            checkpoint saves ``model.module.state_dict()``).
        criterion: loss callable taking ``(outputs, labels)``.
        lr_scheduler: currently unused; kept for interface compatibility
            (the scheduler hooks are disabled in this implementation).
    """
    import copy  # local import: needed only for the best-weights snapshot

    self.logger.info('Using: {}'.format(self.model_name))
    self.logger.info('Using the GPU: {}'.format(self.gpu_id))
    self.logger.info('start training...')

    train_loss = []
    since = time.time()
    best_acc = 0.0
    # Stays None until an epoch qualifies as "best"; the original left the
    # name unbound in that case and crashed with NameError after the loop.
    best_model_wts = None
    # Computed up front so that it is defined even when the loop breaks on
    # its first epoch (or does not run at all).
    save_dir = os.path.join(self.out_dir, self.model_name)

    model.train(True)
    base_optimizer = optim.SGD((model.parameters()),
                               lr=self.lr,
                               momentum=self.momentum,
                               weight_decay=self.weight_decay)
    optimizer = SWA(base_optimizer, swa_start=10, swa_freq=5, swa_lr=0.001)

    for epoch in range(self.num_epochs):
        begin_time = time.time()
        # Default for the (degenerate) case where no batch triggers the
        # periodic log below.
        spend_time = 0.0
        data_loaders, dset_sizes = self.loaddata(
            train_dir=self.train_dir,
            batch_size=self.train_batch_size,
            shuffle=True,
            is_train=True)
        self.logger.info('-' * 10)
        self.logger.info('Epoch {}/{}'.format(epoch, self.num_epochs - 1))
        self.logger.info('learning rate:{}'.format(
            optimizer.param_groups[-1]['lr']))
        self.logger.info('-' * 10)

        running_loss = 0.0
        running_corrects = 0
        count = 0
        for i, data in enumerate(data_loaders):
            count += 1
            inputs, labels = data
            labels = labels.type(torch.LongTensor)
            inputs, labels = inputs.cuda(), labels.cuda()

            optimizer.zero_grad()
            # The model returns a pair; only the second element is used.
            _, outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs.data, 1)
            loss.backward()
            optimizer.step()

            # Log every `print_interval` batches and on a short (last) batch.
            if i % self.print_interval == 0 or outputs.size(
            )[0] < self.train_batch_size:
                spend_time = time.time() - begin_time
                self.logger.info(' Epoch:{}({}/{}) loss:{:.3f} '.format(
                    epoch, count, dset_sizes // self.train_batch_size,
                    loss.item()))
                train_loss.append(loss.item())

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        val_acc = self.test_model(model, criterion)
        epoch_loss = running_loss / dset_sizes
        epoch_acc = running_corrects.double() / dset_sizes
        self.logger.info(
            'Epoch:[{}/{}] Loss={:.5f} Acc={:.3f} Epoch_Time:{} min: ETA: {} hours'
            .format(epoch, self.num_epochs - 1, epoch_loss, epoch_acc,
                    spend_time / 60,
                    (self.num_epochs - epoch) * spend_time / 3600))

        if val_acc > best_acc and epoch > self.min_save_epoch:
            best_acc = val_acc
            # state_dict() returns references to the live parameter tensors,
            # so a deep copy is required; otherwise the "best" snapshot
            # keeps changing as training continues.
            best_model_wts = copy.deepcopy(model.state_dict())
        if val_acc > 0.999:
            break

        model_out_path = save_dir + "/" + '{}_'.format(
            self.model_name) + str(epoch) + '.pth'
        torch.save(model.module.state_dict(), model_out_path)

    # Save the best model.
    self.logger.info('Best Accuracy: {:.3f}'.format(best_acc))
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)
    model_out_path = save_dir + "/" + '{}_best.pth'.format(self.model_name)
    torch.save(model, model_out_path)

    time_elapsed = time.time() - since
    self.logger.info('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
class Trainer():
    """Training driver for a segmentation model.

    Builds datasets, loaders, optimizer (optionally wrapped in SWA), loss and
    metric objects from a configuration file.  Calling the instance runs the
    train/validate loop with checkpointing and early stopping.
    """

    def __init__(self, config_path):
        """Load the configuration at ``config_path`` and build all components."""
        self.image_config, self.model_config, self.run_config = LoadConfig(
            config_path=config_path).train_config()
        # `is_available` must be *called*: the bare function object is always
        # truthy, which selected a CUDA device even on CPU-only machines.
        self.device = torch.device(
            'cuda:%d' % self.run_config['device_ids'][0]
            if torch.cuda.is_available() else 'cpu')
        self.model = getModel(self.model_config)
        os.makedirs(self.run_config['model_save_path'], exist_ok=True)
        # Scale the configured worker count by the number of devices.
        self.run_config['num_workers'] = self.run_config['num_workers'] * len(
            self.run_config['device_ids'])

        self.train_set = Data(root=self.image_config['image_path'],
                              phase='train',
                              data_name=self.image_config['data_name'],
                              img_mode=self.image_config['image_mode'],
                              n_classes=self.model_config['num_classes'],
                              size=self.image_config['image_size'],
                              scale=self.image_config['image_scale'])
        self.valid_set = Data(root=self.image_config['image_path'],
                              phase='valid',
                              data_name=self.image_config['data_name'],
                              img_mode=self.image_config['image_mode'],
                              n_classes=self.model_config['num_classes'],
                              size=self.image_config['image_size'],
                              scale=self.image_config['image_scale'])
        self.className = self.valid_set.className

        self.train_loader = DataLoader(
            self.train_set,
            batch_size=self.run_config['batch_size'],
            shuffle=True,
            num_workers=self.run_config['num_workers'],
            pin_memory=True,
            drop_last=False)
        self.valid_loader = DataLoader(
            self.valid_set,
            batch_size=self.run_config['batch_size'],
            shuffle=True,
            num_workers=self.run_config['num_workers'],
            pin_memory=True,
            drop_last=False)

        train_params = self.model.parameters()
        # SECURITY NOTE(review): `eval` executes arbitrary code taken from the
        # configuration file.  It is kept because configs may contain
        # expressions, but it must only ever see trusted config files.
        self.optimizer = RAdam(train_params,
                               lr=eval(self.run_config['lr']),
                               weight_decay=eval(
                                   self.run_config['weight_decay']))
        if self.run_config['swa']:
            self.optimizer = SWA(self.optimizer,
                                 swa_start=10,
                                 swa_freq=5,
                                 swa_lr=0.005)

        # Set up the learning-rate adjustment strategy.
        self.lr_scheduler = utils.adjustLR.AdjustLr(self.optimizer)

        if self.run_config['use_weight_balance']:
            weight = utils.weight_balance.getWeight(
                self.run_config['weights_file'])
        else:
            weight = None

        # NOTE(review): `cuda=True` is hard-coded although `self.device` may
        # now be the CPU -- confirm how SegmentationLosses uses this flag.
        self.Criterion = SegmentationLosses(weight=weight,
                                            cuda=True,
                                            device=self.device,
                                            batch_average=False)
        self.metric = utils.metrics.MetricMeter(
            self.model_config['num_classes'])

    @logger.catch  # record uncaught errors in the log
    def __call__(self):
        """Run the full training loop with checkpointing and early stopping."""
        # Configure logging.
        self.global_name = self.model_config['model_name']
        logger.add(os.path.join(
            self.image_config['image_path'], 'log',
            'log_' + self.global_name + '/train_{time}.log'),
                   format="{time} {level} {message}",
                   level="INFO",
                   encoding='utf-8')
        self.writer = SummaryWriter(logdir=os.path.join(
            self.image_config['image_path'], 'run',
            'runs_' + self.global_name))
        logger.info("image_config: {} \n model_config: {} \n run_config: {}",
                    self.image_config, self.model_config, self.run_config)

        # Use data parallelism when more than one device is configured.
        if len(self.run_config['device_ids']) > 1:
            self.model = nn.DataParallel(
                self.model, device_ids=self.run_config['device_ids'])
        self.model.to(device=self.device)

        cnt = 0  # consecutive epochs without a new best validation mIoU

        # Load the pretrained checkpoint if one is configured.
        if self.run_config['pretrain'] != '':
            logger.info("loading pretrain %s" % self.run_config['pretrain'])
            try:
                self.load_checkpoint(use_optimizer=True,
                                     use_epoch=True,
                                     use_miou=True)
            except Exception:
                # Strict loading failed; fall back to loading only the
                # matching weights.  (`except Exception` instead of a bare
                # `except` so Ctrl-C / SystemExit are not swallowed.)
                print('load model with channed!!!!!')
                self.load_checkpoint_with_changed(use_optimizer=False,
                                                  use_epoch=False,
                                                  use_miou=False)

        logger.info("start training")

        for epoch in range(self.run_config['start_epoch'],
                           self.run_config['epoch']):
            lr = self.optimizer.param_groups[0]['lr']
            print('epoch=%d, lr=%.8f' % (epoch, lr))
            self.train_epoch(epoch, lr)
            valid_miou = self.valid_epoch(epoch)
            # Choose which learning-rate adjustment strategy to apply.
            self.lr_scheduler.LambdaLR_(milestone=5,
                                        gamma=0.92).step(epoch=epoch)
            self.save_checkpoint(epoch, valid_miou,
                                 'last_' + self.global_name)
            if valid_miou > self.run_config['best_miou']:
                cnt = 0
                self.save_checkpoint(epoch, valid_miou,
                                     'best_' + self.global_name)
                logger.info("############# %d saved ##############" % epoch)
                self.run_config['best_miou'] = valid_miou
            else:
                cnt += 1
                if cnt == self.run_config['early_stop']:
                    logger.info("early stop")
                    break
        self.writer.close()

    def train_epoch(self, epoch, lr):
        """Train for one epoch and log loss/metrics.

        Args:
            epoch: epoch index, used for logging and as the TensorBoard step.
            lr: current learning rate; only logged.
        """
        self.metric.reset()
        train_loss = 0.0
        train_miou = 0.0

        tbar = tqdm(self.train_loader)
        self.model.train()
        for i, (image, mask, edge) in enumerate(tbar):
            tbar.set_description('train_miou:%.6f' % train_miou)
            tbar.set_postfix({"train_loss": train_loss})

            image = image.to(self.device)
            mask = mask.to(self.device)
            edge = edge.to(self.device)

            self.optimizer.zero_grad()
            out = self.model(image)
            # A tuple output carries (auxiliary, final) predictions.
            if isinstance(out, tuple):
                aux_out, final_out = out[0], out[1]
            else:
                aux_out, final_out = None, out

            if self.model_config['model_name'] == 'ocrnet':
                aux_loss = self.Criterion.build_loss(mode='rmi')(aux_out,
                                                                 mask)
                cls_loss = self.Criterion.build_loss(mode='ce')(final_out,
                                                                mask)
                loss = 0.4 * aux_loss + cls_loss
                loss = loss.mean()
            elif self.model_config['model_name'] == 'hrnet_duc':
                loss_body = self.Criterion.build_loss(
                    mode=self.run_config['loss_type'])(final_out, mask)
                loss_edge = self.Criterion.build_loss(mode='dice')(
                    aux_out.squeeze(), edge)
                loss = loss_body + loss_edge
                loss = loss.mean()
            else:
                loss = self.Criterion.build_loss(
                    mode=self.run_config['loss_type'])(final_out, mask)

            loss.backward()
            self.optimizer.step()
            # NOTE(review): this swaps SWA/SGD weights after every step --
            # confirm that this is intended.
            if self.run_config['swa']:
                self.optimizer.swap_swa_sgd()

            with torch.no_grad():
                # Running mean of the per-batch loss.
                train_loss = ((train_loss * i) + loss.item()) / (i + 1)
                _, pred = torch.max(final_out, dim=1)
                self.metric.add(pred.cpu().numpy(), mask.cpu().numpy())
                train_miou, train_ious = self.metric.miou()

        train_fwiou = self.metric.fw_iou()
        train_accu = self.metric.pixel_accuracy()
        train_fwaccu = self.metric.pixel_accuracy_class()
        logger.info(
            "Epoch:%2d\t lr:%.8f\t Train loss:%.4f\t Train FWiou:%.4f\t Train Miou:%.4f\t Train accu:%.4f\t "
            "Train fwaccu:%.4f" % (epoch, lr, train_loss, train_fwiou,
                                   train_miou, train_accu, train_fwaccu))

        # Build one "<class>:<iou>" entry per class for the log line.
        cls = ""
        ious = list()
        ious_dict = OrderedDict()
        for i, c in enumerate(self.className):
            ious_dict[c] = train_ious[i]
            ious.append(ious_dict[c])
            cls += "%s:" % c + "%.4f "
        ious = tuple(ious)
        logger.info(cls % ious)

        # tensorboard
        self.writer.add_scalar("lr", lr, epoch)
        self.writer.add_scalar("loss/train_loss", train_loss, epoch)
        self.writer.add_scalar("miou/train_miou", train_miou, epoch)
        self.writer.add_scalar("fwiou/train_fwiou", train_fwiou, epoch)
        self.writer.add_scalar("accuracy/train_accu", train_accu, epoch)
        self.writer.add_scalar("fwaccuracy/train_fwaccu", train_fwaccu,
                               epoch)
        self.writer.add_scalars("ious/train_ious", ious_dict, epoch)

    def valid_epoch(self, epoch):
        """Evaluate on the validation loader and return the validation mIoU."""
        self.metric.reset()
        valid_loss = 0.0
        valid_miou = 0.0

        tbar = tqdm(self.valid_loader)
        self.model.eval()
        with torch.no_grad():
            for i, (image, mask, edge) in enumerate(tbar):
                tbar.set_description('valid_miou:%.6f' % valid_miou)
                tbar.set_postfix({"valid_loss": valid_loss})

                image = image.to(self.device)
                mask = mask.to(self.device)
                edge = edge.to(self.device)

                out = self.model(image)
                if isinstance(out, tuple):
                    aux_out, final_out = out[0], out[1]
                else:
                    aux_out, final_out = None, out

                if self.model_config['model_name'] == 'ocrnet':
                    aux_loss = self.Criterion.build_loss(mode='rmi')(aux_out,
                                                                     mask)
                    cls_loss = self.Criterion.build_loss(mode='ce')(
                        final_out, mask)
                    loss = 0.4 * aux_loss + cls_loss
                    loss = loss.mean()
                elif self.model_config['model_name'] == 'hrnet_duc':
                    loss_body = self.Criterion.build_loss(
                        mode=self.run_config['loss_type'])(final_out, mask)
                    loss_edge = self.Criterion.build_loss(mode='dice')(
                        aux_out.squeeze(), edge)
                    # Unlike training, no `.mean()` is applied here.
                    loss = loss_body + loss_edge
                else:
                    loss = self.Criterion.build_loss(mode='ce')(final_out,
                                                                mask)

                # Running mean of the per-batch loss.
                valid_loss = ((valid_loss * i) + float(loss)) / (i + 1)
                _, pred = torch.max(final_out, dim=1)
                self.metric.add(pred.cpu().numpy(), mask.cpu().numpy())
                valid_miou, valid_ious = self.metric.miou()

        valid_fwiou = self.metric.fw_iou()
        valid_accu = self.metric.pixel_accuracy()
        valid_fwaccu = self.metric.pixel_accuracy_class()
        logger.info(
            "epoch:%d\t valid loss:%.4f\t valid fwiou:%.4f\t valid miou:%.4f valid accu:%.4f\t "
            "valid fwaccu:%.4f\t" % (epoch, valid_loss, valid_fwiou,
                                     valid_miou, valid_accu, valid_fwaccu))

        ious = list()
        cls = ""
        ious_dict = OrderedDict()
        for i, c in enumerate(self.className):
            ious_dict[c] = valid_ious[i]
            ious.append(ious_dict[c])
            cls += "%s:" % c + "%.4f "
        ious = tuple(ious)
        logger.info(cls % ious)

        self.writer.add_scalar("loss/valid_loss", valid_loss, epoch)
        self.writer.add_scalar("miou/valid_miou", valid_miou, epoch)
        self.writer.add_scalar("fwiou/valid_fwiou", valid_fwiou, epoch)
        self.writer.add_scalar("accuracy/valid_accu", valid_accu, epoch)
        self.writer.add_scalar("fwaccuracy/valid_fwaccu", valid_fwaccu,
                               epoch)
        self.writer.add_scalars("ious/valid_ious", ious_dict, epoch)
        return valid_miou

    def save_checkpoint(self, epoch, best_miou, flag):
        """Save model/optimizer state as ``<model_save_path>/<flag>.pth``."""
        meta = {
            'epoch': epoch,
            'model': self.model.state_dict(),
            'optim': self.optimizer.state_dict(),
            'bmiou': best_miou
        }
        path = os.path.join(self.run_config['model_save_path'],
                            '%s.pth' % flag)
        try:
            # Prefer the legacy (non-zipfile) serialization format.
            torch.save(meta, path, _use_new_zipfile_serialization=False)
        except Exception:
            # Retry with default arguments, e.g. for torch versions that do
            # not accept the keyword above.
            torch.save(meta, path)

    def load_checkpoint(self, use_optimizer, use_epoch, use_miou):
        """Strictly load the checkpoint at ``run_config['pretrain']``.

        The flags select which extra pieces (optimizer state, start epoch,
        best mIoU) are restored in addition to the model weights.
        """
        state_dict = torch.load(self.run_config['pretrain'],
                                map_location=self.device)
        self.model.load_state_dict(state_dict['model'])
        if use_optimizer:
            self.optimizer.load_state_dict(state_dict['optim'])
        if use_epoch:
            self.run_config['start_epoch'] = state_dict['epoch'] + 1
        if use_miou:
            self.run_config['best_miou'] = state_dict['bmiou']

    def load_checkpoint_with_changed(self, use_optimizer, use_epoch,
                                     use_miou):
        """Load a checkpoint whose architecture only partly matches.

        Only entries whose key exists in the current model and does not
        contain ``'edge'`` are copied; all other weights keep their current
        values.
        """
        state_dict = torch.load(self.run_config['pretrain'],
                                map_location=self.device)
        pretrain_dict = state_dict['model']
        model_dict = self.model.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and 'edge' not in k
        }
        model_dict.update(pretrain_dict)
        self.model.load_state_dict(model_dict)
        if use_optimizer:
            self.optimizer.load_state_dict(state_dict['optim'])
        if use_epoch:
            self.run_config['start_epoch'] = state_dict['epoch'] + 1
        if use_miou:
            self.run_config['best_miou'] = state_dict['bmiou']
def train_standard_model(train_loader, validation_loader, attrs, model_name, loss_name, epoch_number = 50, learning_rate = 0.0003, W = None):
    '''
    Train a standard classifier and periodically checkpoint/validate it.

    train_loader : training data loader
    validation_loader : data loader used for validation during training
    attrs : label attribute to extract from the file paths
    model_name : file the list of model snapshots is pickled to
    loss_name : file the validation ROC/accuracy history is pickled to
    epoch_number : number of training epochs
    learning_rate : initial learning rate of the model
    W : optional distortion channel; when given, every image batch is passed
        through it before reaching the classifier
    '''
    # Initialise the model (SWA stands for Stochastic Weight Averaging).
    f = Discriminator()
    f.cuda()
    f.apply(weights_init_normal)
    f_optimizer = SWA(torch.optim.Adam(f.parameters(), lr = learning_rate, betas=(0.5, 0.999)))

    # The learning rate is multiplied by 0.8 every 25 scheduler steps.
    lr_scheduler_f = torch.optim.lr_scheduler.StepLR(f_optimizer, step_size = 25, gamma = 0.8)

    # Report whether the model is trained with a distortion channel.
    if W is not None:
        print('Distortion channel is injected.')

    # Lists that collect the training results.
    model_standard_list = []
    training_f_roc_list = []
    training_f_accuracy_list = []

    criterion_BCE = nn.BCELoss()

    # An epoch is large, so the model and metrics are saved every 1/10 epoch.
    # Clamp to 1: for loaders with fewer than 10 batches the integer division
    # gives 0 and the modulo below would raise ZeroDivisionError.
    each_save_number = max(1, len(train_loader) // 10)

    # Early-stopping flag, initially False.
    stop_flag = False

    for epoch in range(epoch_number):
        seed_number = 999
        # Reseed every epoch so the batches differ between epochs.
        torch.manual_seed(seed_number + epoch)

        print('Now epoch number is ', epoch + 1)

        for i, (images, _, paths) in enumerate(train_loader):
            # Read the labels.
            labels = file_name_dealer(paths, attrs = [attrs])[0]
            # Convert the data to Variables.
            images = to_var(images)
            # Apply the distortion channel when one is given.
            if W is not None:
                images = W(images)

            ###############################################################
            # classifier training
            # Clear the gradients first.
            f.zero_grad()
            # The standard classifier is trained with binary cross entropy.
            f_loss = criterion_BCE(f(images), labels)
            # Compute the gradients.
            f_loss.backward(retain_graph=True)
            # Update the model parameters.
            f_optimizer.step()
            ################################################################

            # Every 1/10 epoch: save the model and run validation.
            if i % each_save_number == 0:
                print((i // each_save_number)*10,'%')

                ## save model ################################
                print('Standard Models are saving...')
                # Append a snapshot of the model to the model list.
                model_standard_list.append(copy.deepcopy(f))
                # Persist the whole list under `model_name`; the context
                # manager guarantees the file is flushed and closed.
                with open(model_name, 'wb') as model_file:
                    pickle.dump(model_standard_list, model_file)

                # Learning-rate update.
                lr_scheduler_f.step()

                ## validation ################################
                f_predict_list = []
                f_labels_list = []
                # Separate loop-variable names so that the training loop's
                # `i`, `images` and `paths` are not shadowed.
                for val_images, _, val_paths in validation_loader:
                    val_labels = file_name_dealer(val_paths, attrs = [attrs])[0]
                    val_images = to_var(val_images)
                    if W is not None:
                        val_images = W(val_images)
                    # Store the predictions and the true labels.
                    f_predict_list += list(f(val_images).view(-1).detach().cpu().numpy())
                    f_labels_list += list(val_labels.view(-1).detach().cpu().numpy())

                # Compute validation ROC score and accuracy.
                f_score = roc_score(f_predict_list, f_labels_list)
                f_accuracy = correct_ratio(f_predict_list, f_labels_list, threshold = 0.5)
                print('auc:',f_score , 'accuracy:',f_accuracy)

                # Store and persist these values.
                training_f_roc_list.append(f_score)
                training_f_accuracy_list.append(f_accuracy)
                with open(loss_name, 'wb') as loss_file:
                    pickle.dump([training_f_roc_list, training_f_accuracy_list], loss_file)

                # Stop training when the early-stopping condition is met.
                stop_flag = stop_running(training_f_accuracy_list, number_threshold = 5)
                if stop_flag:
                    break

        if stop_flag:
            break
class Model():
    """Encoder + attention-decoder classifier with optional SWA training.

    ``configuration['training']['swa']`` is a sequence read as follows in
    this class: ``[0]`` enables SWA, ``[1]`` is the step count after which
    SWA updates/swaps may happen, ``[2]`` is the update period in steps and
    ``[3]`` selects the update rule (``0``: always update, ``2``: use the max
    instead of the mean of the parameter/buffer distances).
    """

    def __init__(self, configuration, pre_embed=None):
        """Build encoder, decoder, optimizers and loss from ``configuration``.

        Args:
            configuration: nested dict with ``'model'`` and ``'training'``
                sections; it is deep-copied, the caller's dict is untouched.
            pre_embed: optional value forwarded to the encoder as
                ``pre_embed``.
        """
        configuration = deepcopy(configuration)
        # Keep a pristine copy for `save_values`; the working copy below is
        # modified (pre_embed, hidden_size).
        self.configuration = deepcopy(configuration)

        configuration['model']['encoder']['pre_embed'] = pre_embed
        self.encoder = Encoder.from_params(
            Params(configuration['model']['encoder'])).to(device)

        configuration['model']['decoder'][
            'hidden_size'] = self.encoder.output_size
        self.decoder = AttnDecoder.from_params(
            Params(configuration['model']['decoder'])).to(device)

        # Parameter groups: encoder, attention part of the decoder, and the
        # remaining decoder parameters.
        self.encoder_params = list(self.encoder.parameters())
        self.attn_params = [
            v for k, v in self.decoder.named_parameters() if 'attention' in k
        ]
        self.decoder_params = [
            v for k, v in self.decoder.named_parameters()
            if 'attention' not in k
        ]

        self.bsize = configuration['training']['bsize']

        weight_decay = configuration['training'].get('weight_decay', 1e-5)
        self.encoder_optim = torch.optim.Adam(self.encoder_params,
                                              lr=0.001,
                                              weight_decay=weight_decay,
                                              amsgrad=True)
        self.attn_optim = torch.optim.Adam(self.attn_params,
                                           lr=0.001,
                                           weight_decay=0,
                                           amsgrad=True)
        self.decoder_optim = torch.optim.Adam(self.decoder_params,
                                              lr=0.001,
                                              weight_decay=weight_decay,
                                              amsgrad=True)
        self.adversarymulti = AdversaryMulti(decoder=self.decoder)

        # Single optimizer over all parameters; this is the one `train` uses.
        self.all_params = self.encoder_params + self.attn_params + self.decoder_params
        self.all_optim = torch.optim.Adam(self.all_params,
                                          lr=0.001,
                                          weight_decay=weight_decay,
                                          amsgrad=True)

        pos_weight = configuration['training'].get(
            'pos_weight', [1.0] * self.decoder.output_size)
        self.pos_weight = torch.Tensor(pos_weight).to(device)
        self.criterion = nn.BCEWithLogitsLoss(reduction='none').to(device)
        self.swa_settings = configuration['training']['swa']

        import time
        dirname = configuration['training']['exp_dirname']
        basepath = configuration['training'].get('basepath', 'outputs')
        self.time_str = time.ctime().replace(' ', '_')
        self.dirname = os.path.join(basepath, dirname, self.time_str)

        self.temperature = configuration['training']['temperature']
        self.train_losses = []

        # History of parameter/buffer distances used by
        # `check_and_update_swa`.
        self.running_norms = []
        if self.swa_settings[0]:
            self.swa_all_optim = SWA(self.all_optim)

    @classmethod
    def init_from_config(cls, dirname, **kwargs):
        """Rebuild a model from ``<dirname>/config.json`` and load its weights.

        ``kwargs`` override top-level configuration entries.
        """
        # Context manager so the config file handle is closed deterministically.
        with open(dirname + '/config.json', 'r') as config_file:
            config = json.load(config_file)
        config.update(kwargs)
        obj = cls(config)
        obj.load_values(dirname)
        return obj

    def get_param_buffer_norms(self):
        """Return the distance between current weights and their SWA buffers.

        Only the parameters at positions 6 and 9 of the first param group are
        compared.  The result is the max of the two L2 distances when
        ``swa_settings[3] == 2`` and their mean otherwise.
        """
        # Make sure every parameter has an SWA buffer before reading it.
        for p in self.swa_all_optim.param_groups[0]['params']:
            param_state = self.swa_all_optim.state[p]
            if 'swa_buffer' not in param_state:
                self.swa_all_optim.update_swa()

        norms = []
        for p in np.array(self.swa_all_optim.param_groups[0]['params'])[[6, 9]]:
            param_state = self.swa_all_optim.state[p]
            buf = np.squeeze(param_state['swa_buffer'].cpu().numpy())
            cur_state = np.squeeze(p.data.cpu().numpy())
            norm = np.linalg.norm(buf - cur_state)
            norms.append(norm)

        if self.swa_settings[3] == 2:
            return np.max(norms)
        return np.mean(norms)

    def total_iter_num(self):
        """Return the SWA optimizer's step counter of the first param group."""
        return self.swa_all_optim.param_groups[0]['step_counter']

    def iter_for_swa_update(self, iter_num):
        """Return True when ``iter_num`` is a step at which SWA may update."""
        return iter_num > self.swa_settings[1] \
            and iter_num % self.swa_settings[2] == 0

    def check_and_update_swa(self):
        """Update the SWA average if the current step qualifies.

        With ``swa_settings[3] == 0`` the average is updated on every
        eligible step.  Otherwise it is updated only when the current
        parameter/buffer distance exceeds the mean of the distances recorded
        since the last update.
        """
        if self.iter_for_swa_update(self.total_iter_num()):
            cur_step_diff_norm = self.get_param_buffer_norms()
            if self.swa_settings[3] == 0:
                self.swa_all_optim.update_swa()
                return

            if not self.running_norms:
                running_mean_norm = 0
            else:
                running_mean_norm = np.mean(self.running_norms)

            if cur_step_diff_norm > running_mean_norm:
                self.swa_all_optim.update_swa()
                self.running_norms = [cur_step_diff_norm]
            elif cur_step_diff_norm > 0:
                self.running_norms.append(cur_step_diff_norm)

    def train(self, data_in, target_in, train=True):
        """Run one pass over the data and return the scaled total loss.

        Args:
            data_in: sequence of documents.
            target_in: targets aligned with ``data_in``.
            train: when False, losses are computed but no optimizer step is
                taken.

        Returns:
            ``loss_total * bsize / N`` where ``N = len(data_in)``; raises
            ZeroDivisionError for empty input.
        """
        # Sort by (noised) length so each batch holds similar-length items.
        sorting_idx = get_sorting_index_with_noise_from_lengths(
            [len(x) for x in data_in], noise_frac=0.1)
        data = [data_in[i] for i in sorting_idx]
        target = [target_in[i] for i in sorting_idx]

        self.encoder.train()
        self.decoder.train()
        bsize = self.bsize
        N = len(data)
        loss_total = 0

        batches = list(range(0, N, bsize))
        batches = shuffle(batches)

        for n in tqdm(batches):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            batch_target = target[n:n + bsize]
            batch_target = torch.Tensor(batch_target).to(device)

            if len(batch_target.shape) == 1:  #(B, )
                batch_target = batch_target.unsqueeze(-1)  #(B, 1)

            bce_loss = self.criterion(batch_data.predict / self.temperature,
                                      batch_target)
            # Positive targets are weighted by pos_weight, negatives by 1.
            weight = batch_target * self.pos_weight + (1 - batch_target)
            bce_loss = (bce_loss * weight).mean(1).sum()

            loss = bce_loss
            # `+ 0` stores a fresh array rather than the converted tensor's
            # own array.
            self.train_losses.append(bce_loss.detach().cpu().numpy() + 0)

            if hasattr(batch_data, 'reg_loss'):
                loss += batch_data.reg_loss

            if train:
                if self.swa_settings[0]:
                    self.check_and_update_swa()
                    self.swa_all_optim.zero_grad()
                    loss.backward()
                    self.swa_all_optim.step()
                else:
                    self.all_optim.zero_grad()
                    loss.backward()
                    self.all_optim.step()

            loss_total += float(loss.data.cpu().item())

        # After the pass, swap via the SWA optimizer once enough steps have
        # been taken, and restart the distance history.
        if self.swa_settings[0] and self.swa_all_optim.param_groups[0][
                'step_counter'] > self.swa_settings[1]:
            print("\nSWA swapping\n")
            self.swa_all_optim.swap_swa_sgd()
            self.running_norms = []

        return loss_total * bsize / N

    def predictor(self, inp_text_permutations):
        """Return an (N, 2) array ``[p, 1 - p]`` for whitespace-tokenised texts.

        NaN predictions are replaced by 0.5.  Used as the LIME classifier
        function in `get_lime_explanations`.
        """
        text_permutations = [
            dataset_vec.map2idxs(x.split()) for x in inp_text_permutations
        ]

        outputs = []
        bsize = 512
        N = len(text_permutations)
        for n in range(0, N, bsize):
            torch.cuda.empty_cache()
            batch_doc = text_permutations[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            batch_data.predict = torch.sigmoid(batch_data.predict)
            pred = batch_data.predict.cpu().data.numpy()
            for i in range(len(pred)):
                if math.isnan(pred[i][0]):
                    pred[i][0] = 0.5

            outputs.extend(pred)

        ret_val = [[output_i[0], 1 - output_i[0]] for output_i in outputs]
        ret_val = np.array(ret_val)

        return ret_val

    def evaluate(self, data):
        """Return ``(outputs, attns)`` for ``data`` in evaluation mode.

        ``attns`` stays an empty list when the decoder does not use attention.
        """
        self.encoder.eval()
        self.decoder.eval()
        bsize = self.bsize
        N = len(data)

        outputs = []
        attns = []

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            batch_data.predict = torch.sigmoid(batch_data.predict /
                                               self.temperature)
            if self.decoder.use_attention:
                attn = batch_data.attn.cpu().data.numpy()
                attns.append(attn)

            predict = batch_data.predict.cpu().data.numpy()
            outputs.append(predict)

        # Flatten the per-batch lists into per-example lists.
        outputs = [x for y in outputs for x in y]
        if self.decoder.use_attention:
            attns = [x for y in attns for x in y]

        return outputs, attns

    def get_lime_explanations(self, data):
        """Return one LIME explanation (list form) per document in ``data``."""
        explanations = []
        explainer = LimeTextExplainer(class_names=["A", "B"])
        for data_i in data:
            sentence = ' '.join(dataset_vec.map2words(data_i))
            exp = explainer.explain_instance(text_instance=sentence,
                                             classifier_fn=self.predictor,
                                             num_features=len(data_i),
                                             num_samples=5000).as_list()
            explanations.append(exp)

        return explanations

    def gradient_mem(self, data):
        """Return gradient-based attributions per example.

        The result is a dict with keys ``'XxE'``, ``'XxE[X]'`` and ``'H'``,
        each a per-example list of gradients of every sigmoid output with
        respect to the embeddings / hidden states.
        """
        self.encoder.train()
        self.decoder.train()
        bsize = self.bsize
        N = len(data)

        grads = {'XxE': [], 'XxE[X]': [], 'H': []}

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]

            grads_xxe = []
            grads_xxex = []
            grads_H = []

            # One forward/backward pass per output unit.
            for i in range(self.decoder.output_size):
                batch_data = BatchHolder(batch_doc)
                batch_data.keep_grads = True
                batch_data.detach = True

                self.encoder(batch_data)
                self.decoder(batch_data)

                torch.sigmoid(batch_data.predict[:, i]).sum().backward()
                g = batch_data.embedding.grad
                em = batch_data.embedding
                g1 = (g * em).sum(-1)
                grads_xxex.append(g1.cpu().data.numpy())

                g1 = (g * self.encoder.embedding.weight.sum(0)).sum(-1)
                grads_xxe.append(g1.cpu().data.numpy())

                g1 = batch_data.hidden.grad.sum(-1)
                grads_H.append(g1.cpu().data.numpy())

            # Move the batch axis first (outputs were collected first).
            grads_xxe = np.array(grads_xxe).swapaxes(0, 1)
            grads_xxex = np.array(grads_xxex).swapaxes(0, 1)
            grads_H = np.array(grads_H).swapaxes(0, 1)

            # A leftover `ipdb.set_trace()` breakpoint was removed here: it
            # halted every call and required the ipdb package at runtime.
            grads['XxE'].append(grads_xxe)
            grads['XxE[X]'].append(grads_xxex)
            grads['H'].append(grads_H)

        for k in grads:
            grads[k] = [x for y in grads[k] for x in y]

        return grads

    def remove_and_run(self, data):
        """Return per-example outputs with each interior token removed in turn.

        Each example yields an array of shape (maxlen, output_size); rows 0
        and ``maxlen - 1`` stay zero because the first and last positions are
        never removed.
        """
        self.encoder.train()
        self.decoder.train()
        bsize = self.bsize
        N = len(data)

        outputs = []

        for n in tqdm(range(0, N, bsize)):
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)
            po = np.zeros(
                (batch_data.B, batch_data.maxlen, self.decoder.output_size))

            for i in range(1, batch_data.maxlen - 1):
                batch_data = BatchHolder(batch_doc)

                # Drop position i from the sequences and their masks.
                batch_data.seq = torch.cat(
                    [batch_data.seq[:, :i], batch_data.seq[:, i + 1:]],
                    dim=-1)
                batch_data.lengths = batch_data.lengths - 1
                batch_data.masks = torch.cat(
                    [batch_data.masks[:, :i], batch_data.masks[:, i + 1:]],
                    dim=-1)

                self.encoder(batch_data)
                self.decoder(batch_data)

                po[:, i] = torch.sigmoid(batch_data.predict).cpu().data.numpy()

            outputs.append(po)

        outputs = [x for y in outputs for x in y]

        return outputs

    def permute_attn(self, data, num_perm=100):
        """Return per-example outputs under ``num_perm`` attention permutations.

        Each example yields an array of shape (num_perm, output_size).
        """
        self.encoder.train()
        self.decoder.train()
        bsize = self.bsize
        N = len(data)

        permutations = []

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            batch_perms = np.zeros(
                (batch_data.B, num_perm, self.decoder.output_size))

            self.encoder(batch_data)
            self.decoder(batch_data)

            for i in range(num_perm):
                # The decoder is re-run with the `permute` flag set.
                batch_data.permute = True
                self.decoder(batch_data)
                output = torch.sigmoid(batch_data.predict)
                batch_perms[:, i] = output.cpu().data.numpy()

            permutations.append(batch_perms)

        permutations = [x for y in permutations for x in y]

        return permutations

    def save_values(self,
                    use_dirname=None,
                    save_model=True,
                    append_to_dir_name=''):
        """Save the configuration (and optionally the weights); return the dir.

        The target is ``use_dirname`` when given, otherwise ``self.dirname``
        with ``append_to_dir_name`` appended; the latter is also remembered
        in ``self.last_epch_dirname``.
        """
        if use_dirname is not None:
            dirname = use_dirname
        else:
            dirname = self.dirname + append_to_dir_name
            self.last_epch_dirname = dirname
        os.makedirs(dirname, exist_ok=True)
        # `file_name` is defined outside this class (not visible here).
        shutil.copy2(file_name, dirname + '/')
        # Context manager guarantees the config is flushed and closed.
        with open(dirname + '/config.json', 'w') as config_file:
            json.dump(self.configuration, config_file)

        if save_model:
            torch.save(self.encoder.state_dict(), dirname + '/enc.th')
            torch.save(self.decoder.state_dict(), dirname + '/dec.th')

        return dirname

    def load_values(self, dirname):
        """Load encoder/decoder weights from ``dirname`` (cuda:1 -> cuda:0)."""
        self.encoder.load_state_dict(
            torch.load(dirname + '/enc.th',
                       map_location={'cuda:1': 'cuda:0'}))
        self.decoder.load_state_dict(
            torch.load(dirname + '/dec.th',
                       map_location={'cuda:1': 'cuda:0'}))

    def adversarial_multi(self, data):
        """Return ``(adverse_output, adverse_attn)`` from the multi-adversary.

        Note: this freezes all encoder/decoder parameters
        (``requires_grad = False``) and does not unfreeze them afterwards.
        """
        self.encoder.eval()
        self.decoder.eval()

        for p in self.encoder.parameters():
            p.requires_grad = False

        for p in self.decoder.parameters():
            p.requires_grad = False

        bsize = self.bsize
        N = len(data)

        adverse_attn = []
        adverse_output = []

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            self.adversarymulti(batch_data)

            attn_volatile = batch_data.attn_volatile.cpu().data.numpy(
            )  #(B, 10, L)
            predict_volatile = batch_data.predict_volatile.cpu().data.numpy(
            )  #(B, 10, O)

            adverse_attn.append(attn_volatile)
            adverse_output.append(predict_volatile)

        adverse_output = [x for y in adverse_output for x in y]
        adverse_attn = [x for y in adverse_attn for x in y]

        return adverse_output, adverse_attn

    def logodds_attention(self, data, logodds_map: Dict):
        """Return outputs/attention when attention is derived from log-odds.

        Args:
            data: sequence of documents.
            logodds_map: maps vocabulary index to a log-odds value or None;
                None entries become ``-inf``, others use the absolute value.
        """
        self.encoder.eval()
        self.decoder.eval()

        bsize = self.bsize
        N = len(data)

        adverse_attn = []
        adverse_output = []

        logodds = np.zeros((self.encoder.vocab_size, ))
        for k, v in logodds_map.items():
            if v is not None:
                logodds[k] = abs(v)
            else:
                logodds[k] = float('-inf')
        logodds = torch.Tensor(logodds).to(device)

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            # Look up the log-odds of every token in the batch.
            batch_data.attn_logodds = logodds[batch_data.seq]
            self.decoder.get_output_from_logodds(batch_data)

            attn_volatile = batch_data.attn_volatile.cpu().data.numpy(
            )  #(B, L)
            predict_volatile = torch.sigmoid(
                batch_data.predict_volatile).cpu().data.numpy()  #(B, O)

            adverse_attn.append(attn_volatile)
            adverse_output.append(predict_volatile)

        adverse_output = [x for y in adverse_output for x in y]
        adverse_attn = [x for y in adverse_attn for x in y]

        return adverse_output, adverse_attn

    def logodds_substitution(self, data, top_logodds_words: Dict):
        """Replace the 5 most-attended tokens with opposite-class words.

        Args:
            data: sequence of documents.
            top_logodds_words: ``top_logodds_words[0][0]`` and ``[0][1]`` are
                the word-index lists used for predicted class 1 and 0
                respectively (selected via ``1 - predicted_class``).

        Returns:
            ``(adverse_output, adverse_attn, adverse_X)`` per example.
        """
        self.encoder.eval()
        self.decoder.eval()

        bsize = self.bsize
        N = len(data)

        adverse_X = []
        adverse_attn = []
        adverse_output = []

        # `.to(device)` instead of `.cuda()` so the tensors live on the same
        # device as the rest of the model, like everywhere else in the class.
        words_neg = torch.Tensor(
            top_logodds_words[0][0]).long().to(device).unsqueeze(0)
        words_pos = torch.Tensor(
            top_logodds_words[0][1]).long().to(device).unsqueeze(0)

        words_to_select = torch.cat([words_neg, words_pos], dim=0)  #(2, 5)

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)
            predict_class = (torch.sigmoid(batch_data.predict).squeeze(-1) >
                             0.5) * 1  #(B,)

            attn = batch_data.attn  #(B, L)
            _, top_idx = torch.topk(attn, 5, dim=-1)
            subs_words = words_to_select[1 - predict_class.long()]  #(B, 5)

            # Overwrite the top-attended positions in place, then re-run.
            batch_data.seq.scatter_(1, top_idx, subs_words)

            self.encoder(batch_data)
            self.decoder(batch_data)

            attn_volatile = batch_data.attn.cpu().data.numpy()  #(B, L)
            predict_volatile = torch.sigmoid(
                batch_data.predict).cpu().data.numpy()  #(B, O)
            X_volatile = batch_data.seq.cpu().data.numpy()

            adverse_X.append(X_volatile)
            adverse_attn.append(attn_volatile)
            adverse_output.append(predict_volatile)

        adverse_X = [x for y in adverse_X for x in y]
        adverse_output = [x for y in adverse_output for x in y]
        adverse_attn = [x for y in adverse_attn for x in y]

        return adverse_output, adverse_attn, adverse_X

    def predict(self, batch_data, lengths, masks):
        """Return the raw (pre-sigmoid) decoder prediction for a batch."""
        batch_holder = BatchHolderIndentity(batch_data, lengths, masks)
        self.encoder(batch_holder)
        self.decoder(batch_holder)
        predict = batch_holder.predict
        return predict
class Train(object):
    """Build data loaders, model and optimizer from ``cfg`` and train CenterNet.

    Everything is configured through the module-level ``cfg`` object; the
    constructor takes no arguments. Call :meth:`custom_loop` to run training.
    """

    def __init__(self, ):
        trainds = AlaskaDataIter(cfg.DATA.root_path,
                                 cfg.DATA.train_txt_path,
                                 training_flag=True)
        self.train_ds = DataLoader(trainds,
                                   cfg.TRAIN.batch_size,
                                   num_workers=cfg.TRAIN.process_num,
                                   shuffle=True)

        valds = AlaskaDataIter(cfg.DATA.root_path,
                               cfg.DATA.val_txt_path,
                               training_flag=False)
        self.val_ds = DataLoader(valds,
                                 cfg.TRAIN.batch_size,
                                 num_workers=cfg.TRAIN.process_num,
                                 shuffle=False)

        self.init_lr = cfg.TRAIN.init_lr
        self.warup_step = cfg.TRAIN.warmup_step
        self.epochs = cfg.TRAIN.epoch
        self.batch_size = cfg.TRAIN.batch_size
        self.l2_regularization = cfg.TRAIN.weight_decay_factor

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else 'cpu')

        self.model = CenterNet().to(self.device)
        self.load_weight()

        if 'Adamw' in cfg.TRAIN.opt:
            self.optimizer = torch.optim.AdamW(
                self.model.parameters(),
                lr=self.init_lr,
                eps=1.e-5,
                weight_decay=self.l2_regularization)
        else:
            self.optimizer = torch.optim.SGD(
                self.model.parameters(),
                lr=self.init_lr,
                momentum=0.9,
                weight_decay=self.l2_regularization)

        if cfg.TRAIN.SWA > 0:
            # Wrap the base optimizer so weight averages can be tracked.
            self.optimizer = SWA(self.optimizer)

        if cfg.TRAIN.mix_precision:
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level="O1")

        self.model = nn.DataParallel(self.model)

        # The EMA tracker is always created; it is only updated / applied
        # when cfg.TRAIN.ema is set.
        self.ema = EMA(self.model, 0.999)
        self.ema.register()

        # Global step counter across epochs (drives warm-up and logging).
        self.iter_num = 0

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, self.epochs, eta_min=1.e-6)

        self.criterion = CenterNetLoss().to(self.device)

    def _warm_up_lr(self):
        """Linearly scale the learning rate while ``iter_num < warup_step``."""
        if self.warup_step <= 0 or self.iter_num >= self.warup_step:
            return
        lr = self.iter_num / float(self.warup_step) * self.init_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        logger.info('warm up with learning rate: [%f]' % (lr))

    def _show_batch(self, image, hm_target, wh_target):
        """Display each sample of a batch with OpenCV (blocks on a key press)."""
        for img_t, hm_t, wh_t in zip(image, hm_target, wh_target):
            img = np.transpose(img_t.numpy(), axes=[1, 2, 0])
            hm = hm_t.numpy()
            wh = wh_t.numpy()

            if cfg.DATA.use_int8_data:
                hm = hm[:, :, 0].astype(np.uint8)
                wh = wh[:, :, 0]
            else:
                hm = hm[:, :, 0].astype(np.float32)
                wh = wh[:, :, 0].astype(np.float32)

            cv2.namedWindow('s_hm', 0)
            cv2.imshow('s_hm', hm)
            cv2.namedWindow('s_wh', 0)
            cv2.imshow('s_wh', wh + 1)
            cv2.namedWindow('img', 0)
            cv2.imshow('img', img)
            cv2.waitKey(0)

    def _prepare_batch(self, image, hm_target, wh_target, weights):
        """Move one batch to ``self.device`` as float tensors.

        When ``cfg.DATA.use_int8_data`` is set the heat-map target is divided
        by ``cfg.DATA.use_int8_enlarge``.
        """
        data = image.to(self.device).float()
        hm_target = hm_target.to(self.device).float()
        if cfg.DATA.use_int8_data:
            hm_target = hm_target / cfg.DATA.use_int8_enlarge
        wh_target = wh_target.to(self.device).float()
        weights = weights.to(self.device).float()
        return data, hm_target, wh_target, weights

    def _train_epoch(self, epoch_num):
        """Train for one epoch and return ``(cls_meter, wh_meter)``."""
        summary_loss_cls = AverageMeter()
        summary_loss_wh = AverageMeter()

        self.model.train()
        if cfg.MODEL.freeze_bn:
            # Put BatchNorm layers back in eval mode after model.train().
            for m in self.model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    if cfg.MODEL.freeze_bn_affine:
                        m.weight.requires_grad = False
                        m.bias.requires_grad = False

        for image, hm_target, wh_target, weights in self.train_ds:
            # Warm-up is only considered during the first 10 epochs.
            if epoch_num < 10:
                self._warm_up_lr()

            start = time.time()

            if cfg.TRAIN.vis:
                # Visualisation mode: show the batch, do not train on it.
                self._show_batch(image, hm_target, wh_target)
                continue

            data, hm_target, wh_target, weights = self._prepare_batch(
                image, hm_target, wh_target, weights)
            batch_size = data.shape[0]

            cls, wh = self.model(data)
            cls_loss, wh_loss = self.criterion(
                [cls, wh], [hm_target, wh_target, weights])
            current_loss = cls_loss + wh_loss

            summary_loss_cls.update(cls_loss.detach().item(), batch_size)
            summary_loss_wh.update(wh_loss.detach().item(), batch_size)

            self.optimizer.zero_grad()
            if cfg.TRAIN.mix_precision:
                with amp.scale_loss(current_loss,
                                    self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                current_loss.backward()
            self.optimizer.step()

            if cfg.TRAIN.ema:
                self.ema.update()

            self.iter_num += 1
            time_cost_per_batch = time.time() - start

            if self.iter_num % cfg.TRAIN.log_interval == 0:
                images_per_sec = (cfg.TRAIN.batch_size * cfg.TRAIN.num_gpu /
                                  time_cost_per_batch)
                log_message = '[TRAIN], '\
                              'Epoch %d Step %d, ' \
                              'summary_loss: %.6f, ' \
                              'cls_loss: %.6f, '\
                              'wh_loss: %.6f, ' \
                              'time: %.6f, '\
                              'speed %d images/persec'% (
                                  epoch_num,
                                  self.iter_num,
                                  summary_loss_cls.avg + summary_loss_wh.avg,
                                  summary_loss_cls.avg,
                                  summary_loss_wh.avg,
                                  time.time() - start,
                                  images_per_sec)
                logger.info(log_message)

        if cfg.TRAIN.SWA > 0 and epoch_num >= cfg.TRAIN.SWA:
            self.optimizer.update_swa()

        return summary_loss_cls, summary_loss_wh

    def _test_epoch(self, epoch_num):
        """Evaluate on the validation loader and return ``(cls_meter, wh_meter)``."""
        summary_loss_cls = AverageMeter()
        summary_loss_wh = AverageMeter()

        self.model.eval()
        t = time.time()

        with torch.no_grad():
            for step, (image, hm_target, wh_target,
                       weights) in enumerate(self.val_ds):
                data, hm_target, wh_target, weights = self._prepare_batch(
                    image, hm_target, wh_target, weights)
                batch_size = data.shape[0]

                cls, wh = self.model(data)
                cls_loss, wh_loss = self.criterion(
                    [cls, wh], [hm_target, wh_target, weights])

                summary_loss_cls.update(cls_loss.detach().item(), batch_size)
                summary_loss_wh.update(wh_loss.detach().item(), batch_size)

                if step % cfg.TRAIN.log_interval == 0:
                    log_message = '[VAL], '\
                                  'Epoch %d Step %d, ' \
                                  'summary_loss: %.6f, ' \
                                  'cls_loss: %.6f, '\
                                  'wh_loss: %.6f, ' \
                                  'time: %.6f' % (
                                      epoch_num,
                                      step,
                                      summary_loss_cls.avg + summary_loss_wh.avg,
                                      summary_loss_cls.avg,
                                      summary_loss_wh.avg,
                                      time.time() - t)
                    logger.info(log_message)

        return summary_loss_cls, summary_loss_wh

    def custom_loop(self):
        """Run the full train / validate / checkpoint loop.

        For each of ``self.epochs`` epochs this trains on ``self.train_ds``,
        optionally switches to SWA and/or EMA weights, validates every
        ``cfg.TRAIN.test_interval`` epochs, steps the LR scheduler, saves the
        model's ``state_dict`` under ``cfg.MODEL.model_path`` and finally
        switches back to the raw training weights.

        Takes no arguments and returns ``None``.
        """
        for epoch in range(self.epochs):
            # Report the learning rate of the last parameter group.
            lr = self.optimizer.param_groups[-1]['lr']
            logger.info('learning rate: [%f]' % (lr))

            t = time.time()

            summary_loss_cls, summary_loss_wh = self._train_epoch(epoch)

            train_epoch_log_message = '[centernet], '\
                                      '[RESULT]: Train. Epoch: %d,' \
                                      ' summary_loss: %.5f,' \
                                      ' cls_loss: %.6f, ' \
                                      ' wh_loss: %.6f, ' \
                                      ' time:%.5f' % (
                                          epoch,
                                          summary_loss_cls.avg + summary_loss_wh.avg,
                                          summary_loss_cls.avg,
                                          summary_loss_wh.avg,
                                          (time.time() - t))
            logger.info(train_epoch_log_message)

            # One flag drives both the swap-in and the swap-back so the two
            # calls are always paired.
            use_swa_weights = cfg.TRAIN.SWA > 0 and epoch >= cfg.TRAIN.SWA
            if use_swa_weights:
                # Switch to the averaged model for evaluation / saving.
                self.optimizer.swap_swa_sgd()

            if cfg.TRAIN.ema:
                self.ema.apply_shadow()

            if epoch % cfg.TRAIN.test_interval == 0:
                summary_loss_cls, summary_loss_wh = self._test_epoch(epoch)

                val_epoch_log_message = '[centernet], '\
                                        '[RESULT]: VAL. Epoch: %d,' \
                                        ' summary_loss: %.5f,' \
                                        ' cls_loss: %.6f, ' \
                                        ' wh_loss: %.6f, ' \
                                        ' time:%.5f' % (
                                            epoch,
                                            summary_loss_cls.avg + summary_loss_wh.avg,
                                            summary_loss_cls.avg,
                                            summary_loss_wh.avg,
                                            (time.time() - t))
                logger.info(val_epoch_log_message)

            self.scheduler.step()

            # Save a checkpoint at the end of every epoch, inside the
            # configured directory (created on demand). On epochs where
            # validation was skipped the loss in the file name is the
            # training loss.
            os.makedirs(cfg.MODEL.model_path, exist_ok=True)
            current_model_saved_name = os.path.join(
                cfg.MODEL.model_path,
                'centernet_epoch_%d_val_loss%.6f.pth' %
                (epoch, summary_loss_cls.avg + summary_loss_wh.avg))

            logger.info('A model saved to %s' % current_model_saved_name)
            torch.save(self.model.module.state_dict(),
                       current_model_saved_name)

            # Switch back to the raw weights before the next training epoch.
            if cfg.TRAIN.ema:
                self.ema.restore()

            if use_swa_weights:
                self.optimizer.swap_swa_sgd()

    def load_weight(self):
        """Load ``cfg.MODEL.pretrained_model`` into the model, if configured.

        Loading is non-strict, so missing or unexpected keys are tolerated.
        """
        if cfg.MODEL.pretrained_model is not None:
            state_dict = torch.load(cfg.MODEL.pretrained_model,
                                    map_location=self.device)
            self.model.load_state_dict(state_dict, strict=False)