class Model(): def __init__(self, configuration, pre_embed=None): configuration = deepcopy(configuration) self.configuration = deepcopy(configuration) configuration['model']['encoder']['pre_embed'] = pre_embed self.encoder = Encoder.from_params( Params(configuration['model']['encoder'])).to(device) configuration['model']['decoder'][ 'hidden_size'] = self.encoder.output_size self.decoder = AttnDecoder.from_params( Params(configuration['model']['decoder'])).to(device) self.encoder_params = list(self.encoder.parameters()) self.attn_params = list([ v for k, v in self.decoder.named_parameters() if 'attention' in k ]) self.decoder_params = list([ v for k, v in self.decoder.named_parameters() if 'attention' not in k ]) self.bsize = configuration['training']['bsize'] weight_decay = configuration['training'].get('weight_decay', 1e-5) self.encoder_optim = torch.optim.Adam(self.encoder_params, lr=0.001, weight_decay=weight_decay, amsgrad=True) self.attn_optim = torch.optim.Adam(self.attn_params, lr=0.001, weight_decay=0, amsgrad=True) self.decoder_optim = torch.optim.Adam(self.decoder_params, lr=0.001, weight_decay=weight_decay, amsgrad=True) self.adversarymulti = AdversaryMulti(decoder=self.decoder) self.all_params = self.encoder_params + self.attn_params + self.decoder_params self.all_optim = torch.optim.Adam(self.all_params, lr=0.001, weight_decay=weight_decay, amsgrad=True) # self.all_optim = adagrad.Adagrad(self.all_params, weight_decay=weight_decay) pos_weight = configuration['training'].get('pos_weight', [1.0] * self.decoder.output_size) self.pos_weight = torch.Tensor(pos_weight).to(device) self.criterion = nn.BCEWithLogitsLoss(reduction='none').to(device) self.swa_settings = configuration['training']['swa'] import time dirname = configuration['training']['exp_dirname'] basepath = configuration['training'].get('basepath', 'outputs') self.time_str = time.ctime().replace(' ', '_') self.dirname = os.path.join(basepath, dirname, self.time_str) self.temperature = configuration['training']['temperature'] self.train_losses = [] if self.swa_settings[0]: # self.attn_optim = SWA(self.attn_optim, swa_start=3, swa_freq=1, swa_lr=0.05) # self.decoder_optim = SWA(self.decoder_optim, swa_start=3, swa_freq=1, swa_lr=0.05) # self.encoder_optim = SWA(self.encoder_optim, swa_start=3, swa_freq=1, swa_lr=0.05) self.swa_all_optim = SWA(self.all_optim) self.running_norms = [] @classmethod def init_from_config(cls, dirname, **kwargs): config = json.load(open(dirname + '/config.json', 'r')) config.update(kwargs) obj = cls(config) obj.load_values(dirname) return obj def get_param_buffer_norms(self): for p in self.swa_all_optim.param_groups[0]['params']: param_state = self.swa_all_optim.state[p] if 'swa_buffer' not in param_state: self.swa_all_optim.update_swa() norms = [] # for p in np.array(self.swa_all_optim.param_groups[0]['params'])[[1, 2, 5, 6, 9]]: for p in np.array(self.swa_all_optim.param_groups[0]['params'])[[6, 9]]: param_state = self.swa_all_optim.state[p] buf = np.squeeze(param_state['swa_buffer'].cpu().numpy()) cur_state = np.squeeze(p.data.cpu().numpy()) norm = np.linalg.norm(buf - cur_state) norms.append(norm) if self.swa_settings[3] == 2: return np.max(norms) return np.mean(norms) def total_iter_num(self): return self.swa_all_optim.param_groups[0]['step_counter'] def iter_for_swa_update(self, iter_num): return iter_num > self.swa_settings[1] \ and iter_num % self.swa_settings[2] == 0 def check_and_update_swa(self): if self.iter_for_swa_update(self.total_iter_num()): cur_step_diff_norm = 
self.get_param_buffer_norms() if self.swa_settings[3] == 0: self.swa_all_optim.update_swa() return if not self.running_norms: running_mean_norm = 0 else: running_mean_norm = np.mean(self.running_norms) if cur_step_diff_norm > running_mean_norm: self.swa_all_optim.update_swa() self.running_norms = [cur_step_diff_norm] elif cur_step_diff_norm > 0: self.running_norms.append(cur_step_diff_norm) def train(self, data_in, target_in, train=True): sorting_idx = get_sorting_index_with_noise_from_lengths( [len(x) for x in data_in], noise_frac=0.1) data = [data_in[i] for i in sorting_idx] target = [target_in[i] for i in sorting_idx] self.encoder.train() self.decoder.train() bsize = self.bsize N = len(data) loss_total = 0 batches = list(range(0, N, bsize)) batches = shuffle(batches) for n in tqdm(batches): torch.cuda.empty_cache() batch_doc = data[n:n + bsize] batch_data = BatchHolder(batch_doc) self.encoder(batch_data) self.decoder(batch_data) batch_target = target[n:n + bsize] batch_target = torch.Tensor(batch_target).to(device) if len(batch_target.shape) == 1: #(B, ) batch_target = batch_target.unsqueeze(-1) #(B, 1) bce_loss = self.criterion(batch_data.predict / self.temperature, batch_target) weight = batch_target * self.pos_weight + (1 - batch_target) bce_loss = (bce_loss * weight).mean(1).sum() loss = bce_loss self.train_losses.append(bce_loss.detach().cpu().numpy() + 0) if hasattr(batch_data, 'reg_loss'): loss += batch_data.reg_loss if train: if self.swa_settings[0]: self.check_and_update_swa() self.swa_all_optim.zero_grad() loss.backward() self.swa_all_optim.step() else: # self.encoder_optim.zero_grad() # self.decoder_optim.zero_grad() # self.attn_optim.zero_grad() self.all_optim.zero_grad() loss.backward() # self.encoder_optim.step() # self.decoder_optim.step() # self.attn_optim.step() self.all_optim.step() loss_total += float(loss.data.cpu().item()) if self.swa_settings[0] and self.swa_all_optim.param_groups[0][ 'step_counter'] > self.swa_settings[1]: print("\nSWA swapping\n") # self.attn_optim.swap_swa_sgd() # self.encoder_optim.swap_swa_sgd() # self.decoder_optim.swap_swa_sgd() self.swa_all_optim.swap_swa_sgd() self.running_norms = [] return loss_total * bsize / N def predictor(self, inp_text_permutations): text_permutations = [ dataset_vec.map2idxs(x.split()) for x in inp_text_permutations ] outputs = [] bsize = 512 N = len(text_permutations) for n in range(0, N, bsize): torch.cuda.empty_cache() batch_doc = text_permutations[n:n + bsize] batch_data = BatchHolder(batch_doc) self.encoder(batch_data) self.decoder(batch_data) batch_data.predict = torch.sigmoid(batch_data.predict) pred = batch_data.predict.cpu().data.numpy() for i in range(len(pred)): if math.isnan(pred[i][0]): pred[i][0] = 0.5 outputs.extend(pred) ret_val = [[output_i[0], 1 - output_i[0]] for output_i in outputs] ret_val = np.array(ret_val) return ret_val def evaluate(self, data): self.encoder.eval() self.decoder.eval() bsize = self.bsize N = len(data) outputs = [] attns = [] for n in tqdm(range(0, N, bsize)): torch.cuda.empty_cache() batch_doc = data[n:n + bsize] batch_data = BatchHolder(batch_doc) self.encoder(batch_data) self.decoder(batch_data) batch_data.predict = torch.sigmoid(batch_data.predict / self.temperature) if self.decoder.use_attention: attn = batch_data.attn.cpu().data.numpy() attns.append(attn) predict = batch_data.predict.cpu().data.numpy() outputs.append(predict) outputs = [x for y in outputs for x in y] if self.decoder.use_attention: attns = [x for y in attns for x in y] return outputs, attns def 
get_lime_explanations(self, data): explanations = [] explainer = LimeTextExplainer(class_names=["A", "B"]) for data_i in data: sentence = ' '.join(dataset_vec.map2words(data_i)) exp = explainer.explain_instance(text_instance=sentence, classifier_fn=self.predictor, num_features=len(data_i), num_samples=5000).as_list() explanations.append(exp) return explanations def gradient_mem(self, data): self.encoder.train() self.decoder.train() bsize = self.bsize N = len(data) grads = {'XxE': [], 'XxE[X]': [], 'H': []} for n in tqdm(range(0, N, bsize)): torch.cuda.empty_cache() batch_doc = data[n:n + bsize] grads_xxe = [] grads_xxex = [] grads_H = [] for i in range(self.decoder.output_size): batch_data = BatchHolder(batch_doc) batch_data.keep_grads = True batch_data.detach = True self.encoder(batch_data) self.decoder(batch_data) torch.sigmoid(batch_data.predict[:, i]).sum().backward() g = batch_data.embedding.grad em = batch_data.embedding g1 = (g * em).sum(-1) grads_xxex.append(g1.cpu().data.numpy()) g1 = (g * self.encoder.embedding.weight.sum(0)).sum(-1) grads_xxe.append(g1.cpu().data.numpy()) g1 = batch_data.hidden.grad.sum(-1) grads_H.append(g1.cpu().data.numpy()) grads_xxe = np.array(grads_xxe).swapaxes(0, 1) grads_xxex = np.array(grads_xxex).swapaxes(0, 1) grads_H = np.array(grads_H).swapaxes(0, 1) import ipdb ipdb.set_trace() grads['XxE'].append(grads_xxe) grads['XxE[X]'].append(grads_xxex) grads['H'].append(grads_H) for k in grads: grads[k] = [x for y in grads[k] for x in y] return grads def remove_and_run(self, data): self.encoder.train() self.decoder.train() bsize = self.bsize N = len(data) outputs = [] for n in tqdm(range(0, N, bsize)): batch_doc = data[n:n + bsize] batch_data = BatchHolder(batch_doc) po = np.zeros( (batch_data.B, batch_data.maxlen, self.decoder.output_size)) for i in range(1, batch_data.maxlen - 1): batch_data = BatchHolder(batch_doc) batch_data.seq = torch.cat( [batch_data.seq[:, :i], batch_data.seq[:, i + 1:]], dim=-1) batch_data.lengths = batch_data.lengths - 1 batch_data.masks = torch.cat( [batch_data.masks[:, :i], batch_data.masks[:, i + 1:]], dim=-1) self.encoder(batch_data) self.decoder(batch_data) po[:, i] = torch.sigmoid(batch_data.predict).cpu().data.numpy() outputs.append(po) outputs = [x for y in outputs for x in y] return outputs def permute_attn(self, data, num_perm=100): self.encoder.train() self.decoder.train() bsize = self.bsize N = len(data) permutations = [] for n in tqdm(range(0, N, bsize)): torch.cuda.empty_cache() batch_doc = data[n:n + bsize] batch_data = BatchHolder(batch_doc) batch_perms = np.zeros( (batch_data.B, num_perm, self.decoder.output_size)) self.encoder(batch_data) self.decoder(batch_data) for i in range(num_perm): batch_data.permute = True self.decoder(batch_data) output = torch.sigmoid(batch_data.predict) batch_perms[:, i] = output.cpu().data.numpy() permutations.append(batch_perms) permutations = [x for y in permutations for x in y] return permutations def save_values(self, use_dirname=None, save_model=True, append_to_dir_name=''): if use_dirname is not None: dirname = use_dirname else: dirname = self.dirname + append_to_dir_name self.last_epch_dirname = dirname os.makedirs(dirname, exist_ok=True) shutil.copy2(file_name, dirname + '/') json.dump(self.configuration, open(dirname + '/config.json', 'w')) if save_model: torch.save(self.encoder.state_dict(), dirname + '/enc.th') torch.save(self.decoder.state_dict(), dirname + '/dec.th') return dirname def load_values(self, dirname): self.encoder.load_state_dict( torch.load(dirname + 
'/enc.th', map_location={'cuda:1': 'cuda:0'})) self.decoder.load_state_dict( torch.load(dirname + '/dec.th', map_location={'cuda:1': 'cuda:0'})) def adversarial_multi(self, data): self.encoder.eval() self.decoder.eval() for p in self.encoder.parameters(): p.requires_grad = False for p in self.decoder.parameters(): p.requires_grad = False bsize = self.bsize N = len(data) adverse_attn = [] adverse_output = [] for n in tqdm(range(0, N, bsize)): torch.cuda.empty_cache() batch_doc = data[n:n + bsize] batch_data = BatchHolder(batch_doc) self.encoder(batch_data) self.decoder(batch_data) self.adversarymulti(batch_data) attn_volatile = batch_data.attn_volatile.cpu().data.numpy( ) #(B, 10, L) predict_volatile = batch_data.predict_volatile.cpu().data.numpy( ) #(B, 10, O) adverse_attn.append(attn_volatile) adverse_output.append(predict_volatile) adverse_output = [x for y in adverse_output for x in y] adverse_attn = [x for y in adverse_attn for x in y] return adverse_output, adverse_attn def logodds_attention(self, data, logodds_map: Dict): self.encoder.eval() self.decoder.eval() bsize = self.bsize N = len(data) adverse_attn = [] adverse_output = [] logodds = np.zeros((self.encoder.vocab_size, )) for k, v in logodds_map.items(): if v is not None: logodds[k] = abs(v) else: logodds[k] = float('-inf') logodds = torch.Tensor(logodds).to(device) for n in tqdm(range(0, N, bsize)): torch.cuda.empty_cache() batch_doc = data[n:n + bsize] batch_data = BatchHolder(batch_doc) self.encoder(batch_data) self.decoder(batch_data) attn = batch_data.attn #(B, L) batch_data.attn_logodds = logodds[batch_data.seq] self.decoder.get_output_from_logodds(batch_data) attn_volatile = batch_data.attn_volatile.cpu().data.numpy( ) #(B, L) predict_volatile = torch.sigmoid( batch_data.predict_volatile).cpu().data.numpy() #(B, O) adverse_attn.append(attn_volatile) adverse_output.append(predict_volatile) adverse_output = [x for y in adverse_output for x in y] adverse_attn = [x for y in adverse_attn for x in y] return adverse_output, adverse_attn def logodds_substitution(self, data, top_logodds_words: Dict): self.encoder.eval() self.decoder.eval() bsize = self.bsize N = len(data) adverse_X = [] adverse_attn = [] adverse_output = [] words_neg = torch.Tensor( top_logodds_words[0][0]).long().cuda().unsqueeze(0) words_pos = torch.Tensor( top_logodds_words[0][1]).long().cuda().unsqueeze(0) words_to_select = torch.cat([words_neg, words_pos], dim=0) #(2, 5) for n in tqdm(range(0, N, bsize)): torch.cuda.empty_cache() batch_doc = data[n:n + bsize] batch_data = BatchHolder(batch_doc) self.encoder(batch_data) self.decoder(batch_data) predict_class = (torch.sigmoid(batch_data.predict).squeeze(-1) > 0.5) * 1 #(B,) attn = batch_data.attn #(B, L) top_val, top_idx = torch.topk(attn, 5, dim=-1) subs_words = words_to_select[1 - predict_class.long()] #(B, 5) batch_data.seq.scatter_(1, top_idx, subs_words) self.encoder(batch_data) self.decoder(batch_data) attn_volatile = batch_data.attn.cpu().data.numpy() #(B, L) predict_volatile = torch.sigmoid( batch_data.predict).cpu().data.numpy() #(B, O) X_volatile = batch_data.seq.cpu().data.numpy() adverse_X.append(X_volatile) adverse_attn.append(attn_volatile) adverse_output.append(predict_volatile) adverse_X = [x for y in adverse_X for x in y] adverse_output = [x for y in adverse_output for x in y] adverse_attn = [x for y in adverse_attn for x in y] return adverse_output, adverse_attn, adverse_X def predict(self, batch_data, lengths, masks): batch_holder = BatchHolderIndentity(batch_data, lengths, masks) 
        self.encoder(batch_holder)
        self.decoder(batch_holder)
        # batch_holder.predict = torch.sigmoid(batch_holder.predict)
        predict = batch_holder.predict
        return predict
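# The SWA-update rule used by get_param_buffer_norms()/check_and_update_swa() above can be hard
# to follow in flattened form, so here is a minimal, self-contained sketch of the same idea:
# take an SWA snapshot only when the distance between the current weights and the running SWA
# average exceeds the mean of recently observed distances. This is an illustrative
# re-implementation, not the class's own code; `SWA` is torchcontrib.optim.SWA and the model,
# optimizer, and loop below are toy placeholders.

import numpy as np
import torch
import torch.nn as nn
from torchcontrib.optim import SWA


def swa_distance(swa_optim):
    """Mean L2 distance between current params and their SWA buffers."""
    norms = []
    for p in swa_optim.param_groups[0]['params']:
        state = swa_optim.state[p]
        if 'swa_buffer' not in state:        # no snapshot taken yet
            swa_optim.update_swa()
            state = swa_optim.state[p]
        buf = state['swa_buffer'].cpu().numpy()
        cur = p.data.cpu().numpy()
        norms.append(np.linalg.norm(buf - cur))
    return float(np.mean(norms))


def maybe_update_swa(swa_optim, running_norms):
    """Snapshot into the SWA average only if the model moved farther than usual."""
    dist = swa_distance(swa_optim)
    if not running_norms or dist > np.mean(running_norms):
        swa_optim.update_swa()
        running_norms[:] = [dist]
    elif dist > 0:
        running_norms.append(dist)


if __name__ == '__main__':
    model = nn.Linear(4, 1)
    opt = SWA(torch.optim.Adam(model.parameters(), lr=1e-3))
    running = []
    for _ in range(20):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        maybe_update_swa(opt, running)
        opt.zero_grad()
        loss.backward()
        opt.step()
    opt.swap_swa_sgd()                       # load the averaged weights before evaluation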
def train(self): # prepare data train_data = self.data('train') train_steps = int((len(train_data) + self.config.batch_size - 1) / self.config.batch_size) train_dataloader = DataLoader(train_data, batch_size=self.config.batch_size, collate_fn=self.get_collate_fn('train'), shuffle=True, num_workers=2) # prepare optimizer params_lr = [{ "params": self.model.bert_parameters, 'lr': self.config.small_lr }, { "params": self.model.other_parameters, 'lr': self.config.large_lr }] optimizer = torch.optim.Adam(params_lr) optimizer = SWA(optimizer) # prepare early stopping early_stopping = EarlyStopping(self.model, self.config.best_model_path, big_server=BIG_GPU, mode='max', patience=10, verbose=True) # prepare learning schedual learning_schedual = LearningSchedual( optimizer, self.config.epochs, train_steps, [self.config.small_lr, self.config.large_lr]) # prepare other aux = REModelAux(self.config, train_steps) moving_log = MovingData(window=500) ending_flag = False # self.model.load_state_dict(torch.load(ROOT_SAVED_MODEL + 'temp_model.ckpt')) # # with torch.no_grad(): # self.model.eval() # print(self.eval()) # return for epoch in range(0, self.config.epochs): for step, (inputs, y_trues, spo_info) in enumerate(train_dataloader): inputs = [aaa.cuda() for aaa in inputs] y_trues = [aaa.cuda() for aaa in y_trues] if epoch > 0 or step == 1000: self.model.detach_bert = False # train ================================================================================================ preds = self.model(inputs) loss = self.calculate_loss(preds, y_trues, inputs[1], inputs[2]) optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm(self.model.parameters(), 1) optimizer.step() with torch.no_grad(): logs = {'lr0': 0, 'lr1': 0} if (epoch > 0 or step > 620) and False: sbj_f1, spo_f1 = self.calculate_train_f1( spo_info[0], preds, spo_info[1:3], inputs[2].cpu().numpy()) metrics_data = { 'loss': loss.cpu().numpy(), 'sampled_num': 1, 'sbj_correct_num': sbj_f1[0], 'sbj_pred_num': sbj_f1[1], 'sbj_true_num': sbj_f1[2], 'spo_correct_num': spo_f1[0], 'spo_pred_num': spo_f1[1], 'spo_true_num': spo_f1[2] } moving_data = moving_log(epoch * train_steps + step, metrics_data) logs['loss'] = moving_data['loss'] / moving_data[ 'sampled_num'] logs['sbj_precise'], logs['sbj_recall'], logs[ 'sbj_f1'] = calculate_f1( moving_data['sbj_correct_num'], moving_data['sbj_pred_num'], moving_data['sbj_true_num'], verbose=True) logs['spo_precise'], logs['spo_recall'], logs[ 'spo_f1'] = calculate_f1( moving_data['spo_correct_num'], moving_data['spo_pred_num'], moving_data['spo_true_num'], verbose=True) else: metrics_data = { 'loss': loss.cpu().numpy(), 'sampled_num': 1 } moving_data = moving_log(epoch * train_steps + step, metrics_data) logs['loss'] = moving_data['loss'] / moving_data[ 'sampled_num'] # update lr logs['lr0'], logs['lr1'] = learning_schedual.update_lr( epoch, step) if step == int(train_steps / 2) or step + 1 == train_steps: self.model.eval() torch.save(self.model.state_dict(), ROOT_SAVED_MODEL + 'temp_model.ckpt') aux.new_line() # dev ========================================================================================== dev_result = self.eval() logs['dev_loss'] = dev_result['loss'] logs['dev_sbj_precise'] = dev_result['sbj_precise'] logs['dev_sbj_recall'] = dev_result['sbj_recall'] logs['dev_sbj_f1'] = dev_result['sbj_f1'] logs['dev_spo_precise'] = dev_result['spo_precise'] logs['dev_spo_recall'] = dev_result['spo_recall'] logs['dev_spo_f1'] = dev_result['spo_f1'] logs['dev_precise'] = dev_result['precise'] 
                        logs['dev_recall'] = dev_result['recall']
                        logs['dev_f1'] = dev_result['f1']

                        # other thing
                        early_stopping(logs['dev_f1'])
                        if logs['dev_f1'] > 0.730:
                            optimizer.update_swa()

                        # test =========================================================================================
                        if (epoch + 1 == self.config.epochs and step + 1 == train_steps) or early_stopping.early_stop:
                            ending_flag = True
                            optimizer.swap_swa_sgd()
                            optimizer.bn_update(train_dataloader, self.model)
                            torch.save(self.model.state_dict(), ROOT_SAVED_MODEL + 'swa.ckpt')
                            self.test(ROOT_SAVED_MODEL + 'swa.ckpt')
                        self.model.train()

                aux.show_log(epoch, step, logs)

            if ending_flag:
                return
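# For reference, the SWA lifecycle in the training loop above (torchcontrib's manual mode)
# boils down to three calls: wrap the optimizer, call update_swa() whenever the snapshot
# condition holds (here a dev-F1 threshold), and call swap_swa_sgd() plus bn_update() once
# before the final save/test so the averaged weights and matching BatchNorm statistics are
# the ones evaluated. Minimal sketch with a toy model and loader; `dev_f1` is a placeholder
# for a real dev evaluation, not this repo's API.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchcontrib.optim import SWA

model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 1))
optimizer = SWA(torch.optim.Adam(model.parameters(), lr=1e-3))
loader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16)
loss_fn = nn.MSELoss()

for epoch in range(3):
    for x, y in loader:
        loss = loss_fn(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    dev_f1 = 0.75                      # placeholder for a real dev evaluation
    if dev_f1 > 0.730:                 # same kind of threshold gate as above
        optimizer.update_swa()         # fold the current weights into the running average

optimizer.swap_swa_sgd()               # copy the SWA average into the model
optimizer.bn_update(loader, model)     # recompute BatchNorm running stats for the average
torch.save(model.state_dict(), 'swa.ckpt')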
def train(model, device, trainloader, testloader, optimizer, criterion, metric, epochs, learning_rate, swa=True, enable_scheduler=True, model_arch=''): ''' Function to perform model training. ''' model.to(device) steps = 0 running_loss = 0 running_metric = 0 print_every = 100 train_losses = [] test_losses = [] train_metrics = [] test_metrics = [] if swa: # initialize stochastic weight averaging opt = SWA(optimizer) else: opt = optimizer # learning rate cosine annealing if enable_scheduler: scheduler = lr_scheduler.CosineAnnealingLR(optimizer, len(trainloader), eta_min=0.0000001) for epoch in range(epochs): if enable_scheduler: scheduler.step() for inputs, labels in trainloader: steps += 1 # Move input and label tensors to the default device inputs, labels = inputs.to(device), labels.to(device) opt.zero_grad() outputs = model.forward(inputs) loss = criterion(outputs, labels.float()) loss.backward() opt.step() running_loss += loss running_metric += metric(outputs, labels.float()) if steps % print_every == 0: test_loss = 0 test_metric = 0 model.eval() with torch.no_grad(): for inputs, labels in testloader: inputs, labels = inputs.to(device), labels.to(device) outputs = model.forward(inputs) test_loss += criterion(outputs, labels.float()) test_metric += metric(outputs, labels.float()) print(f"Epoch {epoch+1}/{epochs}.. " f"Train loss: {running_loss/print_every:.3f}.. " f"Test loss: {test_loss/len(testloader):.3f}.. " f"Train metric: {running_metric/print_every:.3f}.. " f"Test metric: {test_metric/len(testloader):.3f}.. ") train_losses.append(running_loss / print_every) test_losses.append(test_loss / len(testloader)) train_metrics.append(running_metric / print_every) test_metrics.append(test_metric / len(testloader)) running_loss = 0 running_metric = 0 model.train() if swa: opt.update_swa() save_model(model, model_arch, learning_rate, epochs, train_losses, test_losses, train_metrics, test_metrics, filepath='models_checkpoints') if swa: opt.swap_swa_sgd() return model, train_losses, test_losses, train_metrics, test_metrics
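# Note: the SWA(optimizer) wrapper used in train() above comes from the archived torchcontrib
# package. Recent PyTorch ships the same functionality in torch.optim.swa_utils; a rough
# equivalent of the loop above looks like the sketch below (toy data, assumed
# hyper-parameters such as swa_start and the SWA learning rate).

import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
trainloader = DataLoader(TensorDataset(torch.randn(128, 10), torch.rand(128, 1)), batch_size=32)
criterion = nn.BCEWithLogitsLoss()

swa_model = AveragedModel(model)                       # keeps the running weight average
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(trainloader), eta_min=1e-7)
swa_scheduler = SWALR(optimizer, swa_lr=5e-3)          # constant SWA learning rate
swa_start = 3                                          # assumed: start averaging at epoch 3

for epoch in range(5):
    for inputs, labels in trainloader:
        loss = criterion(model(inputs), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if epoch >= swa_start:
        swa_model.update_parameters(model)             # analogous to opt.update_swa()
        swa_scheduler.step()
    else:
        scheduler.step()

update_bn(trainloader, swa_model)                      # analogous to opt.bn_update(...)
# evaluate / save swa_model instead of swapping weights in place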
class Train(object): """Train class. """ def __init__(self, train_ds, val_ds, fold): self.fold = fold self.init_lr = cfg.TRAIN.init_lr self.warup_step = cfg.TRAIN.warmup_step self.epochs = cfg.TRAIN.epoch self.batch_size = cfg.TRAIN.batch_size self.l2_regularization = cfg.TRAIN.weight_decay_factor self.device = torch.device( "cuda" if torch.cuda.is_available() else 'cpu') self.model = Net().to(self.device) self.load_weight() param_optimizer = list(self.model.named_parameters()) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [{ 'params': [ p for n, p in param_optimizer if not any(nd in n for nd in no_decay) ], 'weight_decay': cfg.TRAIN.weight_decay_factor }, { 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0 }] if 'Adamw' in cfg.TRAIN.opt: self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.init_lr, eps=1.e-5) else: self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9) if cfg.TRAIN.SWA > 0: ##use swa self.optimizer = SWA(self.optimizer) if cfg.TRAIN.mix_precision: self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O1") self.ema = EMA(self.model, 0.999) self.ema.register() ###control vars self.iter_num = 0 self.train_ds = train_ds self.val_ds = val_ds # self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,mode='max', patience=3,verbose=True) self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( self.optimizer, self.epochs, eta_min=1.e-6) self.criterion = nn.BCEWithLogitsLoss().to(self.device) def custom_loop(self): """Custom training and testing loop. Args: train_dist_dataset: Training dataset created using strategy. test_dist_dataset: Testing dataset created using strategy. strategy: Distribution strategy. 
Returns: train_loss, train_accuracy, test_loss, test_accuracy """ def distributed_train_epoch(epoch_num): summary_loss = AverageMeter() acc_score = ACCMeter() self.model.train() if cfg.MODEL.freeze_bn: for m in self.model.modules(): if isinstance(m, nn.BatchNorm2d): m.eval() if cfg.MODEL.freeze_bn_affine: m.weight.requires_grad = False m.bias.requires_grad = False for step in range(self.train_ds.size): if epoch_num < 10: ###excute warm up in the first epoch if self.warup_step > 0: if self.iter_num < self.warup_step: for param_group in self.optimizer.param_groups: param_group['lr'] = self.iter_num / float( self.warup_step) * self.init_lr lr = param_group['lr'] logger.info('warm up with learning rate: [%f]' % (lr)) start = time.time() images, data, target = self.train_ds() images = torch.from_numpy(images).to(self.device).float() data = torch.from_numpy(data).to(self.device).float() target = torch.from_numpy(target).to(self.device).float() batch_size = data.shape[0] output = self.model(images, data) current_loss = self.criterion(output, target) summary_loss.update(current_loss.detach().item(), batch_size) acc_score.update(target, output) self.optimizer.zero_grad() if cfg.TRAIN.mix_precision: with amp.scale_loss(current_loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: current_loss.backward() self.optimizer.step() if cfg.MODEL.ema: self.ema.update() self.iter_num += 1 time_cost_per_batch = time.time() - start images_per_sec = cfg.TRAIN.batch_size / time_cost_per_batch if self.iter_num % cfg.TRAIN.log_interval == 0: log_message = '[fold %d], '\ 'Train Step %d, ' \ 'summary_loss: %.6f, ' \ 'accuracy: %.6f, ' \ 'time: %.6f, '\ 'speed %d images/persec'% ( self.fold, step, summary_loss.avg, acc_score.avg, time.time() - start, images_per_sec) logger.info(log_message) if cfg.TRAIN.SWA > 0 and epoch_num >= cfg.TRAIN.SWA: self.optimizer.update_swa() return summary_loss, acc_score def distributed_test_epoch(epoch_num): summary_loss = AverageMeter() acc_score = ACCMeter() self.model.eval() t = time.time() with torch.no_grad(): for step in range(self.val_ds.size): images, data, target = self.train_ds() images = torch.from_numpy(images).to(self.device).float() data = torch.from_numpy(data).to(self.device).float() target = torch.from_numpy(target).to(self.device).float() batch_size = data.shape[0] output = self.model(images, data) loss = self.criterion(output, target) summary_loss.update(loss.detach().item(), batch_size) acc_score.update(target, output) if step % cfg.TRAIN.log_interval == 0: log_message = '[fold %d], '\ 'Val Step %d, ' \ 'summary_loss: %.6f, ' \ 'acc: %.6f, ' \ 'time: %.6f' % ( self.fold,step, summary_loss.avg, acc_score.avg, time.time() - t) logger.info(log_message) return summary_loss, acc_score for epoch in range(self.epochs): for param_group in self.optimizer.param_groups: lr = param_group['lr'] logger.info('learning rate: [%f]' % (lr)) t = time.time() summary_loss, acc_score = distributed_train_epoch(epoch) train_epoch_log_message = '[fold %d], '\ '[RESULT]: Train. Epoch: %d,' \ ' summary_loss: %.5f,' \ ' acuracy: %.5f,' \ ' time:%.5f' % ( self.fold,epoch, summary_loss.avg,acc_score.avg, (time.time() - t)) logger.info(train_epoch_log_message) if cfg.TRAIN.SWA > 0 and epoch >= cfg.TRAIN.SWA: ###switch to avg model self.optimizer.swap_swa_sgd() ##switch eam weighta if cfg.MODEL.ema: self.ema.apply_shadow() if epoch % cfg.TRAIN.test_interval == 0: summary_loss, acc_score = distributed_test_epoch(epoch) val_epoch_log_message = '[fold %d], '\ '[RESULT]: VAL. 
Epoch: %d,' \ ' summary_loss: %.5f,' \ ' accuracy: %.5f,' \ ' time:%.5f' % ( self.fold,epoch, summary_loss.avg,acc_score.avg, (time.time() - t)) logger.info(val_epoch_log_message) self.scheduler.step() # self.scheduler.step(final_scores.avg) #### save model if not os.access(cfg.MODEL.model_path, os.F_OK): os.mkdir(cfg.MODEL.model_path) ###save the best auc model #### save the model every end of epoch current_model_saved_name = './models/fold%d_epoch_%d_val_loss%.6f.pth' % ( self.fold, epoch, summary_loss.avg) logger.info('A model saved to %s' % current_model_saved_name) torch.save(self.model.state_dict(), current_model_saved_name) ####switch back if cfg.MODEL.ema: self.ema.restore() # save_checkpoint({ # 'state_dict': self.model.state_dict(), # },iters=epoch,tag=current_model_saved_name) if cfg.TRAIN.SWA > 0 and epoch > cfg.TRAIN.SWA: ###switch back to plain model to train next epoch self.optimizer.swap_swa_sgd() def load_weight(self): if cfg.MODEL.pretrained_model is not None: state_dict = torch.load(cfg.MODEL.pretrained_model, map_location=self.device) self.model.load_state_dict(state_dict, strict=False)
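# The Train class above also relies on an EMA helper (EMA(self.model, 0.999) with register(),
# update(), apply_shadow(), restore()) whose definition is not shown in this file. Below is a
# minimal sketch of what such a helper usually looks like; this is an assumed implementation
# matching those method names, not necessarily the one used in this repo.

class EMA:
    """Exponential moving average of model parameters."""

    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}    # running EMA of each parameter
        self.backup = {}    # original weights while the EMA is applied

    def register(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        # shadow <- decay * shadow + (1 - decay) * current, after every optimizer step
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name].mul_(self.decay).add_(param.data, alpha=1.0 - self.decay)

    def apply_shadow(self):
        # swap the EMA weights in for evaluation / checkpointing
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self):
        # swap the original training weights back in
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data = self.backup[name]
        self.backup = {}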
def train(train_df, CONFIG): # set-up seed_everything(CONFIG['SEED']) torch.manual_seed(CONFIG['TORCH_SEED']) mlflow.log_params(CONFIG) TRAIN_LEN = len(train_df) train_dataset = TweetDataset(train_df, CONFIG) CRITERION = define_criterion(CONFIG) folds = StratifiedKFold(n_splits=CONFIG["FOLD"], shuffle=True, random_state=CONFIG["SEED"]) for n_fold, (train_idx, valid_idx) in enumerate( folds.split(train_dataset.df['textID'], train_dataset.df['sentiment'])): if n_fold != CONFIG["FOLD_NUM"]: continue ## DataLoaderの定義 train = torch.utils.data.Subset(train_dataset, train_idx) valid = torch.utils.data.Subset(train_dataset, valid_idx) DATA_IN_EPOCH = len(train) TOTAL_DATA = DATA_IN_EPOCH * CONFIG["EPOCHS"] T_TOTAL = int(CONFIG["EPOCHS"] * DATA_IN_EPOCH / CONFIG["TRAIN_BATCH_SIZE"]) ## modelとoptimizerの初期化 model = build_model(CONFIG) model.to(DEVICE) model.train() ## From 20/05/17 param_optimizer = list(model.named_parameters()) bert_params = [n for n, p in param_optimizer if "bert" in n] no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [ ## BERT param { 'params': [ p for n, p in param_optimizer if (not any(nd in n for nd in no_decay)) and (n in bert_params) ], 'weight_decay': CONFIG["WEIGHT_DECAY"], 'lr': CONFIG['LR'] * 1, }, { 'params': [ p for n, p in param_optimizer if (any(nd in n for nd in no_decay)) and (n in bert_params) ], 'weight_decay': 0.0, 'lr': CONFIG['LR'] * 1, }, ## Other param { 'params': [ p for n, p in param_optimizer if (not any(nd in n for nd in no_decay)) and (n not in bert_params) ], 'weight_decay': CONFIG["WEIGHT_DECAY"], 'lr': CONFIG['LR'] * 1, }, { 'params': [ p for n, p in param_optimizer if (any(nd in n for nd in no_decay)) and (n not in bert_params) ], 'weight_decay': 0.0, 'lr': CONFIG['LR'] * 1, }, ] optimizer = torch.optim.AdamW(optimizer_grouped_parameters, eps=4e-5) if CONFIG['SWA']: optimizer = SWA(optimizer) scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=int(CONFIG["WARMUP"] * T_TOTAL), num_training_steps=T_TOTAL) train_sampler = SentimentBalanceSampler(train, CONFIG) train_loader = DataLoader(train, batch_size=CONFIG["TRAIN_BATCH_SIZE"], shuffle=False, sampler=train_sampler, collate_fn=TweetCollate(CONFIG), num_workers=1) valid_loader = DataLoader( valid, batch_size=CONFIG["VALID_BATCH_SIZE"], shuffle=False, # sampler = valid_sampler, collate_fn=TweetCollate(CONFIG), num_workers=1) n_data = 0 n_e_data = 0 best_val = 0.0 best_val_neu = 0.0 best_val_pos = 0.0 best_val_neg = 0.0 t_batch = 0 while n_data < TOTAL_DATA: print(f"Epoch : {int(n_data/DATA_IN_EPOCH)}") n_batch = 0 loss_list = [] jac_token_list = [] jac_text_list = [] jac_sentiment_list = [] jac_cl_text_list = [] output_list = [] target_list = [] for batch in tqdm(train_loader): textID = batch['textID'] text = batch['text'] sentiment = batch['sentiment'] cl_text = batch['cl_text'] selected_text = batch['selected_text'] cl_selected_text = batch['cl_selected_text'] text_idx = batch['text_idx'] offsets = batch['offsets'] tokenized_text = batch['tokenized_text'].to(DEVICE) mask = batch['mask'].to(DEVICE) mask_out = batch['mask_out'].to(DEVICE) token_type_ids = batch['token_type_ids'].to(DEVICE) weight = batch['weight'].to(DEVICE) target = batch['target'].to(DEVICE) ep = int(n_data / DATA_IN_EPOCH) n_data += len(textID) n_e_data += len(textID) n_batch += 1 t_batch += 1 model.zero_grad() # optimizer.zero_grad() output = model(input_ids=tokenized_text, attention_mask=mask, token_type_ids=token_type_ids, mask_out=mask_out) loss = CRITERION(output, 
target) loss = loss * weight loss.mean().backward() loss = loss.detach().cpu().numpy().tolist() optimizer.step() if t_batch < T_TOTAL * 0.50: scheduler.step() loss_list.extend(loss) output = output.detach().cpu().numpy() target = target.detach().cpu().numpy() jac = calc_jaccard(output, batch, CONFIG) jac_token_list.extend(jac['jaccard_token'].tolist()) jac_cl_text_list.extend(jac['jaccard_cl_text'].tolist()) jac_text_list.extend(jac['jaccard_text'].tolist()) jac_sentiment_list.extend(sentiment) if ((((ep > 0) & (n_batch % (int(5 * 32 / CONFIG["TRAIN_BATCH_SIZE"])) == 0)) | ((n_batch % (int(50 * 32 / CONFIG["TRAIN_BATCH_SIZE"])) == 0)) | (n_data >= TOTAL_DATA)) and (CONFIG['SWA'] and ep > 0 and t_batch >= T_TOTAL * 0.50)): optimizer.update_swa() if ( ((ep > 0) & (n_batch % (int(50 * 32 / CONFIG["TRAIN_BATCH_SIZE"])) == 0)) | ((n_batch % (int(50 * 32 / CONFIG["TRAIN_BATCH_SIZE"])) == 0)) | (n_data >= TOTAL_DATA) ): # ((n_data>=0)&(n_data<=1600)|(n_data>=21000)&(n_data<=23000))& if CONFIG['SWA'] and ep > 0 and t_batch >= T_TOTAL * 0.50: # optimizer.update_swa() optimizer.swap_swa_sgd() val = create_valid(model, valid_loader, CONFIG) trn_loss = np.array(loss_list).mean() trn_jac_token = np.array(jac_token_list).mean() trn_jac_cl_text = np.array(jac_cl_text_list).mean() trn_jac_text = np.array(jac_text_list).mean() trn_jac_text_neu = np.array(jac_text_list)[np.array( jac_sentiment_list) == 'neutral'].mean() trn_jac_text_pos = np.array(jac_text_list)[np.array( jac_sentiment_list) == 'positive'].mean() trn_jac_text_neg = np.array(jac_text_list)[np.array( jac_sentiment_list) == 'negative'].mean() val_loss = val['loss'].mean() val_jac_token = val['jaccard_token'].mean() val_jac_cl_text = val['jaccard_cl_text'].mean() val_jac_text = val['jaccard_text'].mean() val_jac_text_neu = val['jaccard_text'][val['sentiment'] == 'neutral'].mean() val_jac_text_pos = val['jaccard_text'][val['sentiment'] == 'positive'].mean() val_jac_text_neg = val['jaccard_text'][val['sentiment'] == 'negative'].mean() loss_list = [] jac_token_list = [] jac_cl_text_list = [] jac_text_list = [] jac_sentiment_list = [] # mlflow metrics = { "lr": optimizer.param_groups[0]['lr'], "trn_loss": trn_loss, "trn_jac_text_neu": trn_jac_text_neu, "trn_jac_text_pos": trn_jac_text_pos, "trn_jac_text_neg": trn_jac_text_neg, "trn_jac_text": trn_jac_text, "val_loss": val_loss, "val_jac_text_neu": val_jac_text_neu, "val_jac_text_pos": val_jac_text_pos, "val_jac_text_neg": val_jac_text_neg, "val_jac_text": val_jac_text, } mlflow.log_metrics(metrics, step=n_data) if CONFIG['SWA'] and t_batch < T_TOTAL * 0.50: pass else: if best_val < val_jac_text: best_val = val_jac_text best_model = copy.deepcopy(model) if CONFIG['SWA'] and ep > 0 and t_batch >= T_TOTAL * 0.50: optimizer.swap_swa_sgd() if n_e_data >= DATA_IN_EPOCH: n_e_data -= DATA_IN_EPOCH if n_data >= TOTAL_DATA: filepath = os.path.join(FILE_DIR, OUTPUT_DIR, "model.pth") torch.save(best_model.state_dict(), filepath) # mlflow mlflow.log_artifact(filepath) break
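# The optimizer_grouped_parameters construction used above (and in the other BERT training
# loops in this file) follows the usual transformers recipe: no weight decay for biases and
# LayerNorm weights, and optionally a different learning rate for the pretrained encoder than
# for the task head. A compact, self-contained version of that recipe; the hyper-parameter
# values and the helper name build_param_groups are illustrative only.

import torch

def build_param_groups(model, weight_decay=0.01, bert_lr=3e-5, other_lr=1e-3):
    no_decay = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')
    groups = []
    for in_bert in (True, False):
        lr = bert_lr if in_bert else other_lr
        for decayed in (True, False):
            params = [
                p for n, p in model.named_parameters()
                if (('bert' in n) == in_bert)
                and (not any(nd in n for nd in no_decay)) == decayed
            ]
            if params:
                groups.append({
                    'params': params,
                    'weight_decay': weight_decay if decayed else 0.0,
                    'lr': lr,
                })
    return groups

# usage: optimizer = torch.optim.AdamW(build_param_groups(model), eps=1e-8)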
def fit( model, train_dataset, val_dataset, optimizer_name="adam", samples_per_player=0, epochs=50, batch_size=32, val_bs=32, warmup_prop=0.1, lr=1e-3, acc_steps=1, swa_first_epoch=50, num_classes_aux=0, aux_mode="sigmoid", verbose=1, first_epoch_eval=0, device="cuda", ): """ Fitting function for the classification task. Args: model (torch model): Model to train. train_dataset (torch dataset): Dataset to train with. val_dataset (torch dataset): Dataset to validate with. optimizer_name (str, optional): Optimizer name. Defaults to 'adam'. samples_per_player (int, optional): Number of images to use per player. Defaults to 0. epochs (int, optional): Number of epochs. Defaults to 50. batch_size (int, optional): Training batch size. Defaults to 32. val_bs (int, optional): Validation batch size. Defaults to 32. warmup_prop (float, optional): Warmup proportion. Defaults to 0.1. lr (float, optional): Learning rate. Defaults to 1e-3. acc_steps (int, optional): Accumulation steps. Defaults to 1. swa_first_epoch (int, optional): Epoch to start applying SWA from. Defaults to 50. num_classes_aux (int, optional): Number of auxiliary classes. Defaults to 0. aux_mode (str, optional): Mode for auxiliary classification. Defaults to 'sigmoid'. verbose (int, optional): Period (in epochs) to display logs at. Defaults to 1. first_epoch_eval (int, optional): Epoch to start evaluating at. Defaults to 0. device (str, optional): Device for torch. Defaults to "cuda". Returns: numpy array [len(val_dataset)]: Last predictions on the validation data. numpy array [len(val_dataset) x num_classes_aux]: Last aux predictions on the val data. """ optimizer = define_optimizer(optimizer_name, model.parameters(), lr=lr) if swa_first_epoch <= epochs: optimizer = SWA(optimizer) loss_fct = nn.BCEWithLogitsLoss() loss_fct_aux = nn.BCEWithLogitsLoss( ) if aux_mode == "sigmoid" else nn.CrossEntropyLoss() aux_loss_weight = 1 if num_classes_aux else 0 if samples_per_player: sampler = PlayerSampler( RandomSampler(train_dataset), train_dataset.players, batch_size=batch_size, drop_last=True, samples_per_player=samples_per_player, ) train_loader = DataLoader( train_dataset, batch_sampler=sampler, num_workers=NUM_WORKERS, pin_memory=True, ) print( f"Using {len(train_loader)} out of {len(train_dataset) // batch_size} " f"batches by limiting to {samples_per_player} samples per player.\n" ) else: train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=NUM_WORKERS, pin_memory=True, ) val_loader = DataLoader( val_dataset, batch_size=val_bs, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, ) num_training_steps = int(epochs * len(train_loader)) num_warmup_steps = int(warmup_prop * num_training_steps) scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps) for epoch in range(epochs): model.train() start_time = time.time() optimizer.zero_grad() avg_loss = 0 if epoch + 1 > swa_first_epoch: optimizer.swap_swa_sgd() for batch in train_loader: images = batch[0].to(device) y_batch = batch[1].to(device).view(-1).float() y_batch_aux = batch[2].to(device).float() y_batch_aux = y_batch_aux.float( ) if aux_mode == "sigmoid" else y_batch_aux.long() y_pred, y_pred_aux = model(images) loss = loss_fct(y_pred.view(-1), y_batch) if aux_loss_weight: loss += aux_loss_weight * loss_fct_aux(y_pred_aux, y_batch_aux) loss.backward() avg_loss += loss.item() / len(train_loader) optimizer.step() scheduler.step() for param in model.parameters(): param.grad = None if epoch + 1 >= 
swa_first_epoch: optimizer.update_swa() optimizer.swap_swa_sgd() preds = np.empty(0) preds_aux = np.empty((0, num_classes_aux)) model.eval() avg_val_loss, auc, scores_aux = 0., 0., 0. if epoch + 1 >= first_epoch_eval or epoch + 1 == epochs: with torch.no_grad(): for batch in val_loader: images = batch[0].to(device) y_batch = batch[1].to(device).view(-1).float() y_aux = batch[2].to(device).float() y_batch_aux = y_aux.float( ) if aux_mode == "sigmoid" else y_aux.long() y_pred, y_pred_aux = model(images) loss = loss_fct(y_pred.detach().view(-1), y_batch) if aux_loss_weight: loss += aux_loss_weight * loss_fct_aux( y_pred_aux.detach(), y_batch_aux) avg_val_loss += loss.item() / len(val_loader) y_pred = torch.sigmoid(y_pred).view(-1) preds = np.concatenate( [preds, y_pred.detach().cpu().numpy()]) if num_classes_aux: y_pred_aux = (y_pred_aux.sigmoid() if aux_mode == "sigmoid" else y_pred_aux.softmax(-1)) preds_aux = np.concatenate( [preds_aux, y_pred_aux.detach().cpu().numpy()]) auc = roc_auc_score(val_dataset.labels, preds) if num_classes_aux: if aux_mode == "sigmoid": scores_aux = np.round( [ roc_auc_score(val_dataset.aux_labels[:, i], preds_aux[:, i]) for i in range(num_classes_aux) ], 3, ).tolist() else: scores_aux = np.round( [ roc_auc_score((val_dataset.aux_labels == i).astype(int), preds_aux[:, i]) for i in range(num_classes_aux) ], 3, ).tolist() else: scores_aux = 0 elapsed_time = time.time() - start_time if (epoch + 1) % verbose == 0: elapsed_time = elapsed_time * verbose lr = scheduler.get_last_lr()[0] print( f"Epoch {epoch + 1:02d}/{epochs:02d} \t lr={lr:.1e}\t t={elapsed_time:.0f}s \t" f"loss={avg_loss:.3f}", end="\t", ) if epoch + 1 >= first_epoch_eval: print( f"val_loss={avg_val_loss:.3f} \t auc={auc:.3f}\t aucs_aux={scores_aux}" ) else: print("") del val_loader, train_loader, y_pred torch.cuda.empty_cache() return preds, preds_aux
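# The SWA bookkeeping in fit() above can look inverted at first: at the start of every epoch
# after swa_first_epoch the raw SGD weights are swapped back in for training, and at the end
# of each epoch the average is updated and swapped in so that validation runs on the averaged
# weights. A stripped-down sketch of just that dance, with a toy model/loader and assumed
# epoch counts.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchcontrib.optim import SWA

model = nn.Linear(8, 1)
optimizer = SWA(torch.optim.Adam(model.parameters(), lr=1e-3))
loader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16)
loss_fn = nn.MSELoss()
epochs, swa_first_epoch = 5, 3

for epoch in range(epochs):
    if epoch + 1 > swa_first_epoch:
        optimizer.swap_swa_sgd()        # restore the raw weights before training resumes
    model.train()
    for x, y in loader:
        loss = loss_fn(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if epoch + 1 >= swa_first_epoch:
        optimizer.update_swa()          # fold the current weights into the average
        optimizer.swap_swa_sgd()        # validate / checkpoint on the averaged weights
    model.eval()
    # ... run validation here on the (possibly averaged) weights ...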
def train(self, train_inputs): config = self.config.fitting model = train_inputs['model'] train_data = train_inputs['train_data'] dev_data = train_inputs['dev_data'] epoch_start = train_inputs['epoch_start'] train_steps = int((len(train_data) + config.batch_size - 1) / config.batch_size) train_dataloader = DataLoader(train_data, batch_size=config.batch_size, collate_fn=self.get_collate_fn('train'), shuffle=True) params_lr = [] for key, value in model.get_params().items(): if key in config.lr: params_lr.append({"params": value, 'lr': config.lr[key]}) optimizer = torch.optim.Adam(params_lr) optimizer = SWA(optimizer) early_stopping = EarlyStopping(model, ROOT_WEIGHT, mode='max', patience=3) learning_schedual = LearningSchedual(optimizer, config.epochs, config.end_epoch, train_steps, config.lr) aux = ModelAux(self.config, train_steps) moving_log = MovingData(window=100) ending_flag = False detach_flag = False swa_flag = False fgm = FGM(model) for epoch in range(epoch_start, config.epochs): for step, (inputs, targets, others) in enumerate(train_dataloader): inputs = dict([(key, value[0].cuda() if value[1] else value[0]) for key, value in inputs.items()]) targets = dict([(key, value.cuda()) for key, value in targets.items()]) if epoch > 0 and step == 0: model.detach_ptm(False) detach_flag = False if epoch == 0 and step == 0: model.detach_ptm(True) detach_flag = True # train ================================================================================================ preds = model(inputs, en_decode=config.verbose) loss = model.cal_loss(preds, targets, inputs['mask']) loss['back'].backward() # 对抗训练 if (not detach_flag) and config.en_fgm: fgm.attack(emb_name='word_embeddings') # 在embedding上添加对抗扰动 preds_adv = model(inputs, en_decode=False) loss_adv = model.cal_loss(preds_adv, targets, inputs['mask']) loss_adv['back'].backward() # 反向传播,并在正常的grad基础上,累加对抗训练的梯度 fgm.restore(emb_name='word_embeddings') # 恢复embedding参数 # torch.nn.utils.clip_grad_norm(model.parameters(), 1) optimizer.step() optimizer.zero_grad() with torch.no_grad(): logs = {} if config.verbose: pred_entity_point = model.find_entity(preds['pred'], others['raw_text']) cn, pn, tn = self.calculate_f1(pred_entity_point, others['raw_entity']) metrics_data = {'loss': loss['show'].cpu().numpy(), 'sampled_num': 1, 'correct_num': cn, 'pred_num': pn, 'true_num': tn} moving_data = moving_log(epoch * train_steps + step, metrics_data) logs['loss'] = moving_data['loss'] / moving_data['sampled_num'] logs['precise'], logs['recall'], logs['f1'] = calculate_f1(moving_data['correct_num'], moving_data['pred_num'], moving_data['true_num'], verbose=True) else: metrics_data = {'loss': loss['show'].cpu().numpy(), 'sampled_num': 1} moving_data = moving_log(epoch * train_steps + step, metrics_data) logs['loss'] = moving_data['loss'] / moving_data['sampled_num'] # update lr lr_data = learning_schedual.update_lr(epoch, step) logs.update(lr_data) if step + 1 == train_steps: model.eval() aux.new_line() # dev ========================================================================================== eval_inputs = {'model': model, 'data': dev_data, 'type_data': 'dev', 'outfile': train_inputs['dev_res_file']} dev_result = self.eval(eval_inputs) logs['dev_loss'] = dev_result['loss'] logs['dev_precise'] = dev_result['precise'] logs['dev_recall'] = dev_result['recall'] logs['dev_f1'] = dev_result['f1'] if logs['dev_f1'] > 0.80: torch.save(model.state_dict(), "{}/auto_save_{:.6f}.ckpt".format(ROOT_WEIGHT, logs['dev_f1'])) if (epoch > 3 or swa_flag) and config.en_swa: 
                            optimizer.update_swa()
                            swa_flag = True
                        early_stop, best_score = early_stopping(logs['dev_f1'])

                        # test =========================================================================================
                        if (epoch + 1 == config.epochs and step + 1 == train_steps) or early_stop:
                            ending_flag = True
                            if swa_flag:
                                optimizer.swap_swa_sgd()
                                optimizer.bn_update(train_dataloader, model)
                        model.train()

                aux.show_log(epoch, step, logs)

            if ending_flag:
                return best_score
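# The adversarial-training step above relies on an FGM helper (fgm.attack(...) / fgm.restore(...))
# whose definition is not shown here. The standard FGM recipe backs up the embedding weights,
# perturbs them along the gradient direction, runs a second forward/backward pass so the
# adversarial gradients accumulate on top of the normal ones, and then restores the original
# embeddings. A minimal sketch of such a helper; this is an assumed implementation and the
# epsilon value is illustrative.

import torch

class FGM:
    """Fast Gradient Method: adversarial perturbation of embedding parameters."""

    def __init__(self, model, epsilon=1.0):
        self.model = model
        self.epsilon = epsilon
        self.backup = {}

    def attack(self, emb_name='word_embeddings'):
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name and param.grad is not None:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    param.data.add_(self.epsilon * param.grad / norm)

    def restore(self, emb_name='word_embeddings'):
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name and name in self.backup:
                param.data = self.backup[name]
        self.backup = {}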
def main(): maxIOU = 0.0 assert torch.cuda.is_available() torch.backends.cudnn.benchmark = True model_fname = '../data/model_swa_8/deeplabv3_{0}_epoch%d.pth'.format( 'crops') focal_loss = FocalLoss2d() train_dataset = CropSegmentation(train=True, crop_size=args.crop_size) # test_dataset = CropSegmentation(train=False, crop_size=args.crop_size) model = torchvision.models.segmentation.deeplabv3_resnet50( pretrained=False, progress=True, num_classes=5, aux_loss=True) if args.train: weight = np.ones(4) weight[2] = 5 weight[3] = 5 w = torch.FloatTensor(weight).cuda() criterion = nn.CrossEntropyLoss() #ignore_index=255 weight=w model = nn.DataParallel(model).cuda() for param in model.parameters(): param.requires_grad = True optimizer1 = optim.SGD(model.parameters(), lr=config.lr, momentum=0.9, weight_decay=1e-4) scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=(args.epochs // 9) + 1) optimizer = SWA(optimizer1) dataset_loader = torch.utils.data.DataLoader( train_dataset, batch_size=args.batch_size, shuffle=args.train, pin_memory=True, num_workers=args.workers) max_iter = args.epochs * len(dataset_loader) losses = AverageMeter() start_epoch = 0 if args.resume: if os.path.isfile(args.resume): print('=> loading checkpoint {0}'.format(args.resume)) checkpoint = torch.load(args.resume) start_epoch = checkpoint['epoch'] model.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) print('=> loaded checkpoint {0} (epoch {1})'.format( args.resume, checkpoint['epoch'])) else: print('=> no checkpoint found at {0}'.format(args.resume)) for epoch in range(start_epoch, args.epochs): scheduler.step(epoch) model.train() for i, (inputs, target) in enumerate(dataset_loader): inputs = Variable(inputs.cuda()) target = Variable(target.cuda()) outputs = model(inputs) loss1 = focal_loss(outputs['out'], target) loss2 = focal_loss(outputs['aux'], target) loss01 = loss1 + 0.1 * loss2 loss3 = lovasz_softmax(outputs['out'], target) loss4 = lovasz_softmax(outputs['aux'], target) loss02 = loss3 + 0.1 * loss4 loss = loss01 + loss02 if np.isnan(loss.item()) or np.isinf(loss.item()): pdb.set_trace() losses.update(loss.item(), args.batch_size) loss.backward() optimizer.step() optimizer.zero_grad() if i > 10 and i % 5 == 0: optimizer.update_swa() print('epoch: {0}\t' 'iter: {1}/{2}\t' 'loss: {loss.val:.4f} ({loss.ema:.4f})'.format( epoch + 1, i + 1, len(dataset_loader), loss=losses)) if epoch > 5: torch.save( { 'epoch': epoch + 1, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), }, model_fname % (epoch + 1)) optimizer.swap_swa_sgd() torch.save( { 'epoch': epoch + 1, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), }, model_fname % (665 + 1))
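# In main() above the SWA average is refreshed by hand every 5 iterations
# (`if i > 10 and i % 5 == 0: optimizer.update_swa()`). torchcontrib's SWA also has an "auto"
# mode that does the same bookkeeping inside optimizer.step(); a minimal sketch with a toy
# model and assumed swa_start / swa_freq / swa_lr values.

import torch
import torch.nn as nn
from torchcontrib.optim import SWA

model = nn.Linear(16, 5)
base_opt = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4)
optimizer = SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=5e-3)  # auto mode
criterion = nn.CrossEntropyLoss()

for i in range(100):
    inputs = torch.randn(8, 16)
    target = torch.randint(0, 5, (8,))
    loss = criterion(model(inputs), target)
    loss.backward()
    optimizer.step()            # after step 10, every 5th step folds weights into the average
    optimizer.zero_grad()

optimizer.swap_swa_sgd()        # load the averaged weights before saving / evaluating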
class Model() : def __init__(self, configuration, pre_embed=None) : configuration = deepcopy(configuration) self.configuration = deepcopy(configuration) configuration['model']['encoder']['pre_embed'] = pre_embed encoder_copy = deepcopy(configuration['model']['encoder']) self.Pencoder = Encoder.from_params(Params(configuration['model']['encoder'])).to(device) self.Qencoder = Encoder.from_params(Params(encoder_copy)).to(device) configuration['model']['decoder']['hidden_size'] = self.Pencoder.output_size self.decoder = AttnDecoderQA.from_params(Params(configuration['model']['decoder'])).to(device) self.bsize = configuration['training']['bsize'] self.adversary_multi = AdversaryMulti(self.decoder) weight_decay = configuration['training'].get('weight_decay', 1e-5) self.params = list(self.Pencoder.parameters()) + list(self.Qencoder.parameters()) + list(self.decoder.parameters()) self.optim = torch.optim.Adam(self.params, weight_decay=weight_decay, amsgrad=True) # self.optim = torch.optim.Adagrad(self.params, lr=0.05, weight_decay=weight_decay) self.criterion = nn.CrossEntropyLoss() import time dirname = configuration['training']['exp_dirname'] basepath = configuration['training'].get('basepath', 'outputs') self.time_str = time.ctime().replace(' ', '_') self.dirname = os.path.join(basepath, dirname, self.time_str) self.swa_settings = configuration['training']['swa'] if self.swa_settings[0]: self.swa_all_optim = SWA(self.optim) self.running_norms = [] @classmethod def init_from_config(cls, dirname, **kwargs) : config = json.load(open(dirname + '/config.json', 'r')) config.update(kwargs) obj = cls(config) obj.load_values(dirname) return obj def get_param_buffer_norms(self): for p in self.swa_all_optim.param_groups[0]['params']: param_state = self.swa_all_optim.state[p] if 'swa_buffer' not in param_state: self.swa_all_optim.update_swa() norms = [] for p in np.array(self.swa_all_optim.param_groups[0]['params'])[ [1, 2, 5, 6, 10, 11, 14, 15, 18, 20, 24, 26]]: param_state = self.swa_all_optim.state[p] buf = np.squeeze( param_state['swa_buffer'].cpu().numpy()) cur_state = np.squeeze(p.data.cpu().numpy()) norm = np.linalg.norm(buf - cur_state) norms.append(norm) if self.swa_settings[3] == 2: return np.max(norms) return np.mean(norms) def total_iter_num(self): return self.swa_all_optim.param_groups[0]['step_counter'] def iter_for_swa_update(self, iter_num): return iter_num > self.swa_settings[1] \ and iter_num % self.swa_settings[2] == 0 def check_and_update_swa(self): if self.iter_for_swa_update(self.total_iter_num()): cur_step_diff_norm = self.get_param_buffer_norms() if self.swa_settings[3] == 0: self.swa_all_optim.update_swa() return if not self.running_norms: running_mean_norm = 0 else: running_mean_norm = np.mean(self.running_norms) if cur_step_diff_norm > running_mean_norm: self.swa_all_optim.update_swa() self.running_norms = [cur_step_diff_norm] elif cur_step_diff_norm > 0: self.running_norms.append(cur_step_diff_norm) def train(self, train_data, train=True) : docs_in = train_data.P question_in = train_data.Q entity_masks_in = train_data.E target_in = train_data.A sorting_idx = get_sorting_index_with_noise_from_lengths([len(x) for x in docs_in], noise_frac=0.1) docs = [docs_in[i] for i in sorting_idx] questions = [question_in[i] for i in sorting_idx] entity_masks = [entity_masks_in[i] for i in sorting_idx] target = [target_in[i] for i in sorting_idx] self.Pencoder.train() self.Qencoder.train() self.decoder.train() bsize = self.bsize N = len(questions) loss_total = 0 batches = list(range(0, N, 
bsize)) batches = shuffle(batches) for n in tqdm(batches) : torch.cuda.empty_cache() batch_doc = docs[n:n+bsize] batch_ques = questions[n:n+bsize] batch_entity_masks = entity_masks[n:n+bsize] batch_doc = BatchHolder(batch_doc) batch_ques = BatchHolder(batch_ques) batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques) batch_data.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device) self.Pencoder(batch_data.P) self.Qencoder(batch_data.Q) self.decoder(batch_data) batch_target = target[n:n+bsize] batch_target = torch.LongTensor(batch_target).to(device) ce_loss = self.criterion(batch_data.predict, batch_target) loss = ce_loss if hasattr(batch_data, 'reg_loss') : loss += batch_data.reg_loss if train : if self.swa_settings[0]: self.check_and_update_swa() self.swa_all_optim.zero_grad() loss.backward() self.swa_all_optim.step() else: self.optim.zero_grad() loss.backward() self.optim.step() loss_total += float(loss.data.cpu().item()) if self.swa_settings[0] and self.swa_all_optim.param_groups[0][ 'step_counter'] > self.swa_settings[1]: print("\nSWA swapping\n") # self.attn_optim.swap_swa_sgd() # self.encoder_optim.swap_swa_sgd() # self.decoder_optim.swap_swa_sgd() self.swa_all_optim.swap_swa_sgd() self.running_norms = [] return loss_total*bsize/N def evaluate(self, data) : docs = data.P questions = data.Q entity_masks = data.E self.Pencoder.train() self.Qencoder.train() self.decoder.train() bsize = self.bsize N = len(questions) outputs = [] attns = [] scores = [] for n in tqdm(range(0, N, bsize)) : torch.cuda.empty_cache() batch_doc = docs[n:n+bsize] batch_ques = questions[n:n+bsize] batch_entity_masks = entity_masks[n:n+bsize] batch_doc = BatchHolder(batch_doc) batch_ques = BatchHolder(batch_ques) batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques) batch_data.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device) self.Pencoder(batch_data.P) self.Qencoder(batch_data.Q) self.decoder(batch_data) prediction_scores = batch_data.predict.cpu().data.numpy() batch_data.predict = torch.argmax(batch_data.predict, dim=-1) if self.decoder.use_attention : attn = batch_data.attn attns.append(attn.cpu().data.numpy()) predict = batch_data.predict.cpu().data.numpy() outputs.append(predict) scores.append(prediction_scores) outputs = [x for y in outputs for x in y] attns = [x for y in attns for x in y] scores = [x for y in scores for x in y] return outputs, attns, scores def save_values(self, use_dirname=None, save_model=True) : if use_dirname is not None : dirname = use_dirname else : dirname = self.dirname os.makedirs(dirname, exist_ok=True) shutil.copy2(file_name, dirname + '/') json.dump(self.configuration, open(dirname + '/config.json', 'w')) if save_model : torch.save(self.Pencoder.state_dict(), dirname + '/encP.th') torch.save(self.Qencoder.state_dict(), dirname + '/encQ.th') torch.save(self.decoder.state_dict(), dirname + '/dec.th') return dirname def load_values(self, dirname) : self.Pencoder.load_state_dict(torch.load(dirname + '/encP.th')) self.Qencoder.load_state_dict(torch.load(dirname + '/encQ.th')) self.decoder.load_state_dict(torch.load(dirname + '/dec.th')) def permute_attn(self, data, num_perm=100) : docs = data.P questions = data.Q entity_masks = data.E self.Pencoder.train() self.Qencoder.train() self.decoder.train() bsize = self.bsize N = len(questions) permutations_predict = [] permutations_diff = [] for n in tqdm(range(0, N, bsize)) : torch.cuda.empty_cache() batch_doc = docs[n:n+bsize] batch_ques = questions[n:n+bsize] batch_entity_masks = 
entity_masks[n:n+bsize] batch_doc = BatchHolder(batch_doc) batch_ques = BatchHolder(batch_ques) batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques) batch_data.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device) self.Pencoder(batch_data.P) self.Qencoder(batch_data.Q) self.decoder(batch_data) predict_true = batch_data.predict.clone().detach() batch_perms_predict = np.zeros((batch_data.P.B, num_perm)) batch_perms_diff = np.zeros((batch_data.P.B, num_perm)) for i in range(num_perm) : batch_data.permute = True self.decoder(batch_data) predict = torch.argmax(batch_data.predict, dim=-1) batch_perms_predict[:, i] = predict.cpu().data.numpy() predict_difference = self.adversary_multi.output_diff(batch_data.predict, predict_true) batch_perms_diff[:, i] = predict_difference.squeeze(-1).cpu().data.numpy() permutations_predict.append(batch_perms_predict) permutations_diff.append(batch_perms_diff) permutations_predict = [x for y in permutations_predict for x in y] permutations_diff = [x for y in permutations_diff for x in y] return permutations_predict, permutations_diff def adversarial_multi(self, data) : docs = data.P questions = data.Q entity_masks = data.E self.Pencoder.eval() self.Qencoder.eval() self.decoder.eval() print(self.adversary_multi.K) self.params = list(self.Pencoder.parameters()) + list(self.Qencoder.parameters()) + list(self.decoder.parameters()) for p in self.params : p.requires_grad = False bsize = self.bsize N = len(questions) batches = list(range(0, N, bsize)) outputs, attns, diffs = [], [], [] for n in tqdm(batches) : torch.cuda.empty_cache() batch_doc = docs[n:n+bsize] batch_ques = questions[n:n+bsize] batch_entity_masks = entity_masks[n:n+bsize] batch_doc = BatchHolder(batch_doc) batch_ques = BatchHolder(batch_ques) batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques) batch_data.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device) self.Pencoder(batch_data.P) self.Qencoder(batch_data.Q) self.decoder(batch_data) self.adversary_multi(batch_data) predict_volatile = torch.argmax(batch_data.predict_volatile, dim=-1) outputs.append(predict_volatile.cpu().data.numpy()) attn = batch_data.attn_volatile attns.append(attn.cpu().data.numpy()) predict_difference = self.adversary_multi.output_diff(batch_data.predict_volatile, batch_data.predict.unsqueeze(1)) diffs.append(predict_difference.cpu().data.numpy()) outputs = [x for y in outputs for x in y] attns = [x for y in attns for x in y] diffs = [x for y in diffs for x in y] return outputs, attns, diffs def gradient_mem(self, data) : docs = data.P questions = data.Q entity_masks = data.E self.Pencoder.train() self.Qencoder.train() self.decoder.train() bsize = self.bsize N = len(questions) grads = {'XxE' : [], 'XxE[X]' : [], 'H' : []} for n in range(0, N, bsize) : torch.cuda.empty_cache() batch_doc = docs[n:n+bsize] batch_ques = questions[n:n+bsize] batch_entity_masks = entity_masks[n:n+bsize] batch_doc = BatchHolder(batch_doc) batch_ques = BatchHolder(batch_ques) batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques) batch_data.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device) batch_data.P.keep_grads = True batch_data.detach = True self.Pencoder(batch_data.P) self.Qencoder(batch_data.Q) self.decoder(batch_data) max_predict = torch.argmax(batch_data.predict, dim=-1) prob_predict = nn.Softmax(dim=-1)(batch_data.predict) max_class_prob = torch.gather(prob_predict, -1, max_predict.unsqueeze(-1)) max_class_prob.sum().backward() g = batch_data.P.embedding.grad em = 
batch_data.P.embedding g1 = (g * em).sum(-1) grads['XxE[X]'].append(g1.cpu().data.numpy()) g1 = (g * self.Pencoder.embedding.weight.sum(0)).sum(-1) grads['XxE'].append(g1.cpu().data.numpy()) g1 = batch_data.P.hidden.grad.sum(-1) grads['H'].append(g1.cpu().data.numpy()) for k in grads : grads[k] = [x for y in grads[k] for x in y] return grads def remove_and_run(self, data) : docs = data.P questions = data.Q entity_masks = data.E self.Pencoder.train() self.Qencoder.train() self.decoder.train() bsize = self.bsize N = len(questions) output_diffs = [] for n in tqdm(range(0, N, bsize)) : torch.cuda.empty_cache() batch_doc = docs[n:n+bsize] batch_ques = questions[n:n+bsize] batch_entity_masks = entity_masks[n:n+bsize] batch_doc = BatchHolder(batch_doc) batch_ques = BatchHolder(batch_ques) batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques) batch_data.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device) self.Pencoder(batch_data.P) self.Qencoder(batch_data.Q) self.decoder(batch_data) po = np.zeros((batch_data.P.B, batch_data.P.maxlen)) for i in range(1, batch_data.P.maxlen - 1) : batch_doc = BatchHolder(docs[n:n+bsize]) batch_doc.seq = torch.cat([batch_doc.seq[:, :i], batch_doc.seq[:, i+1:]], dim=-1) batch_doc.lengths = batch_doc.lengths - 1 batch_doc.masks = torch.cat([batch_doc.masks[:, :i], batch_doc.masks[:, i+1:]], dim=-1) batch_data_loop = BatchMultiHolder(P=batch_doc, Q=batch_ques) batch_data_loop.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device) self.Pencoder(batch_data_loop.P) self.decoder(batch_data_loop) predict_difference = self.adversary_multi.output_diff(batch_data_loop.predict, batch_data.predict) po[:, i] = predict_difference.squeeze(-1).cpu().data.numpy() output_diffs.append(po) output_diffs = [x for y in output_diffs for x in y] return output_diffs
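# gradient_mem() above scores tokens with the gradient-times-input heuristic: back-propagate
# the predicted class probability to the word embeddings and sum grad * embedding over the
# embedding dimension. A self-contained sketch of that attribution on a toy classifier; the
# model below is illustrative, not this repo's encoder/decoder.

import torch
import torch.nn as nn

class ToyClassifier(nn.Module):
    def __init__(self, vocab_size=100, emb_dim=16, n_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.fc = nn.Linear(emb_dim, n_classes)

    def forward(self, tokens):
        emb = self.embedding(tokens)           # (B, L, E)
        emb.retain_grad()                      # keep gradients on the embedding activations
        self._emb = emb
        return self.fc(emb.mean(dim=1))        # (B, C)

model = ToyClassifier()
tokens = torch.randint(0, 100, (4, 12))        # (B, L)

logits = model(tokens)
probs = torch.softmax(logits, dim=-1)
max_class_prob = probs.gather(-1, probs.argmax(-1, keepdim=True))
max_class_prob.sum().backward()

grad_x_input = (model._emb.grad * model._emb).sum(-1)   # (B, L) per-token attribution
print(grad_x_input.shape)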
def train(opt):
    model = www_model_jamo_vertical.STR(opt, device)
    print(
        'model parameters. height {}, width {}, num of fiducial {}, input channel {}, output channel {}, hidden size {}, batch max length {}'
        .format(opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel,
                opt.output_channel, opt.hidden_size, opt.batch_max_length))

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # load pretrained model
    if opt.saved_model != '':
        base_path = './models'
        print(f'looking for pretrained model from {os.path.join(base_path, opt.saved_model)}')
        try:
            model.load_state_dict(torch.load(os.path.join(base_path, opt.saved_model)))
            print('loading complete')
        except Exception as e:
            print(e)
            print('could not find model')

    # data parallel for multi GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()

    # loss
    criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    log_avg = utils.Averager()

    # keep only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params : ', sum(params_num))

    # optimizer
    # base_opt = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    base_opt = torch.optim.Adam(filtered_parameters, lr=0.001)
    optimizer = SWA(base_opt)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='max',
                                                           verbose=True,
                                                           patience=2,
                                                           factor=0.5)
    # optimizer = adabound.AdaBound(filtered_parameters, lr=1e-3, final_lr=0.1)

    # opt log
    with open(f'./models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '---------------------Options-----------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)} : {str(v)}\n'
        opt_log += '---------------------------------------------\n'
        opt_file.write(opt_log)

    # start training
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    swa_count = 0

    for n_epoch, epoch in enumerate(range(opt.num_epoch)):
        for n_iter, data_point in enumerate(data_loader):
            image_tensors, top, mid, bot = data_point
            image = image_tensors.to(device)
            text_top, length_top = top_converter.encode(top, batch_max_length=opt.batch_max_length)
            text_mid, length_mid = middle_converter.encode(mid, batch_max_length=opt.batch_max_length)
            text_bot, length_bot = bottom_converter.encode(bot, batch_max_length=opt.batch_max_length)
            batch_size = image.size(0)

            pred_top, pred_mid, pred_bot = model(image, text_top[:, :-1], text_mid[:, :-1], text_bot[:, :-1])

            # cost_top = criterion(pred_top.view(-1, pred_top.shape[-1]), text_top[:, 1:].contiguous().view(-1))
            # cost_mid = criterion(pred_mid.view(-1, pred_mid.shape[-1]), text_mid[:, 1:].contiguous().view(-1))
            # cost_bot = criterion(pred_bot.view(-1, pred_bot.shape[-1]), text_bot[:, 1:].contiguous().view(-1))

            # alternate between reduced focal loss and class-balanced focal loss
            if n_iter % 2 == 0:
                cost_top = utils.reduced_focal_loss(pred_top.view(-1, pred_top.shape[-1]),
                                                    text_top[:, 1:].contiguous().view(-1),
                                                    ignore_index=0, gamma=2, alpha=0.25, threshold=0.5)
                cost_mid = utils.reduced_focal_loss(pred_mid.view(-1, pred_mid.shape[-1]),
                                                    text_mid[:, 1:].contiguous().view(-1),
                                                    ignore_index=0, gamma=2, alpha=0.25, threshold=0.5)
                cost_bot = utils.reduced_focal_loss(pred_bot.view(-1, pred_bot.shape[-1]),
                                                    text_bot[:, 1:].contiguous().view(-1),
                                                    ignore_index=0, gamma=2, alpha=0.25, threshold=0.5)
            else:
                cost_top = utils.CB_loss(text_top[:, 1:].contiguous().view(-1),
                                         pred_top.view(-1, pred_top.shape[-1]),
                                         top_per_cls, opt.top_n_cls, 'focal', 0.999, 0.5)
                cost_mid = utils.CB_loss(text_mid[:, 1:].contiguous().view(-1),
                                         pred_mid.view(-1, pred_mid.shape[-1]),
                                         mid_per_cls, opt.middle_n_cls, 'focal', 0.999, 0.5)
                cost_bot = utils.CB_loss(text_bot[:, 1:].contiguous().view(-1),
                                         pred_bot.view(-1, pred_bot.shape[-1]),
                                         bot_per_cls, opt.bottom_n_cls, 'focal', 0.999, 0.5)

            cost = cost_top * 0.33 + cost_mid * 0.33 + cost_bot * 0.33

            loss_avg = utils.Averager()
            loss_avg.add(cost)

            model.zero_grad()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5
            optimizer.step()
            print(loss_avg.val())

            # validation
            if (n_iter % opt.valInterval == 0) and (n_iter != 0):
                elapsed_time = time.time() - start_time
                with open(f'./models/{opt.experiment_name}/log_train.txt', 'a') as log:
                    model.eval()
                    with torch.no_grad():
                        valid_loss, current_accuracy, current_norm_ED, pred_top_str, pred_mid_str, pred_bot_str, \
                            label_top, label_mid, label_bot, infer_time, length_of_data = evaluate.validation_jamo(
                                model, criterion, valid_loader, top_converter, middle_converter, bottom_converter, opt)
                    scheduler.step(current_accuracy)
                    model.train()

                    present_time = time.localtime()
                    loss_log = f'[epoch : {n_epoch}/{opt.num_epoch}] [iter : {n_iter*opt.batch_size} / {int(len(data) * 0.95)}]\n' + \
                        f'Train loss : {loss_avg.val():0.5f}, Valid loss : {valid_loss:0.5f}, Elapsed time : {elapsed_time:0.5f}, ' + \
                        f'Present time : {present_time[1]}/{present_time[2]}, {present_time[3]+9} : {present_time[4]}'
                    loss_avg.reset()

                    current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"current_norm_ED":17s}: {current_norm_ED:0.2f}'

                    # keep the best
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(model.module.state_dict(),
                                   f'./models/{opt.experiment_name}/best_accuracy_{round(current_accuracy, 2)}.pth')
                    if current_norm_ED > best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(model.module.state_dict(),
                                   f'./models/{opt.experiment_name}/best_norm_ED.pth')
                    best_model_log = f'{"Best accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                    loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                    print(loss_model_log)
                    log.write(loss_model_log + '\n')

                    dashed_line = '-' * 80
                    head = f'{"Ground Truth":25s} | {"Prediction":25s} | T/F'
                    predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'

                    random_idx = np.random.choice(range(len(label_top)), size=5, replace=False)
                    label_concat = np.concatenate([
                        np.asarray(label_top).reshape(1, -1),
                        np.asarray(label_mid).reshape(1, -1),
                        np.asarray(label_bot).reshape(1, -1)
                    ], axis=0).reshape(3, -1)
                    pred_concat = np.concatenate([
                        np.asarray(pred_top_str).reshape(1, -1),
                        np.asarray(pred_mid_str).reshape(1, -1),
                        np.asarray(pred_bot_str).reshape(1, -1)
                    ], axis=0).reshape(3, -1)

                    for i in random_idx:
                        label_sample = label_concat[:, i]
                        pred_sample = pred_concat[:, i]
                        gt_str = utils.str_combine(label_sample[0], label_sample[1], label_sample[2])
                        pred_str = utils.str_combine(pred_sample[0], pred_sample[1], pred_sample[2])
                        predicted_result_log += f'{gt_str:25s} | {pred_str:25s} | \t{str(pred_str == gt_str)}\n'
                    predicted_result_log += f'{dashed_line}'
                    print(predicted_result_log)
                    log.write(predicted_result_log + '\n')

                # stochastic weight averaging
                optimizer.update_swa()
                swa_count += 1
                if swa_count % 5 == 0:
                    optimizer.swap_swa_sgd()
                    torch.save(model.module.state_dict(),
                               f'./models/{opt.experiment_name}/swa_{swa_count}.pth')

        if n_epoch % 5 == 0:
            torch.save(model.module.state_dict(),
                       f'./models/{opt.experiment_name}/{n_epoch}.pth')
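# Editor's note: the training loops in this file all drive SWA "manually"
# (no swa_start/swa_freq/swa_lr passed to the wrapper), so the running average
# only changes when update_swa() is called and weights are only exchanged by
# swap_swa_sgd(). Below is a minimal, self-contained sketch of that pattern,
# assuming the torchcontrib.optim.SWA wrapper that the code above uses as SWA.

import torch
from torchcontrib.optim import SWA

model = torch.nn.Linear(10, 2)
base_opt = torch.optim.Adam(model.parameters(), lr=1e-3)
opt = SWA(base_opt)  # manual mode: no automatic averaging schedule

for step in range(100):
    x, y = torch.randn(8, 10), torch.randn(8, 2)
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()
    if step > 50 and step % 5 == 0:
        opt.update_swa()   # fold the current weights into the running average

opt.swap_swa_sgd()         # replace the model weights with the SWA average (e.g. before eval/saving)
# calling swap_swa_sgd() again restores the original weights, which is how the
# loops above switch back to the plain model before continuing training.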
class TPUFitter:

    def __init__(self, model, device, config, base_model_path='/', model_name='unnamed',
                 model_prefix='roberta', model_version='v1', out_path='/', log_path='/'):
        self.log_path = Path(log_path, 'log').with_suffix('.txt')
        self.log('TPUFitter started to initialize.', direct_out=True)
        self.config = config
        self.epoch = 0

        self.base_model_path = base_model_path
        self.model_name = model_name
        self.model_version = model_version
        self.model_path = Path(self.base_model_path, self.model_name, self.model_version)
        self.out_path = out_path
        self.node_path = Path(self.out_path, 'node_submissions')
        self.create_dir_structure()

        self.model = model
        self.device = device

        # whether to use stochastic weight averaging
        self.use_SWA = config.use_SWA
        # whether to use different lr for backbone and classifier head
        self.use_diff_lr = config.use_diff_lr

        self._set_optimizer_scheduler()
        self.criterion = config.criterion
        self.best_score = -1.0
        self.log(f'Fitter prepared. Device is {self.device}', direct_out=True)

    def create_dir_structure(self):
        self.node_path.mkdir(parents=True, exist_ok=True)
        self.log('**** Directory structure created ****', direct_out=True)

    def _set_optimizer_scheduler(self):
        self.log('Optimizer and scheduler started to initialize.', direct_out=True)

        def is_backbone(n):
            return 'backbone' in n

        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

        # use a different learning rate for the backbone transformer and the classifier head
        if self.use_diff_lr:
            backbone_lr, head_lr = self.config.lr * xm.xrt_world_size(), self.config.lr * xm.xrt_world_size() * 500
            optimizer_grouped_parameters = [
                # {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
                # {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
                {"params": [p for n, p in param_optimizer if is_backbone(n)], "lr": backbone_lr},
                {"params": [p for n, p in param_optimizer if not is_backbone(n)], "lr": head_lr}
            ]
            self.log(f'Different learning rate for backbone: {backbone_lr} head: {head_lr}')
        else:
            optimizer_grouped_parameters = [
                {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
                {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
            ]

        try:
            self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.config.lr * xm.xrt_world_size())
            # self.optimizer = SGD(optimizer_grouped_parameters, lr=self.config.lr*xm.xrt_world_size(), momentum=0.9)
        except:
            param_g_1 = [p for n, p in param_optimizer if is_backbone(n)]
            param_g_2 = [p for n, p in param_optimizer if not is_backbone(n)]
            param_intersect = list(set(param_g_1) & set(param_g_2))
            self.log(f'intersect: {param_intersect}', direct_out=True)

        if self.use_SWA:
            self.optimizer = SWA(self.optimizer)

        if 'num_training_steps' in self.config.scheduler_params:
            num_training_steps = int(self.config.train_lenght / self.config.batch_size / xm.xrt_world_size() * self.config.n_epochs)
            self.log(f'Number of training steps: {num_training_steps}', direct_out=True)
            self.config.scheduler_params['num_training_steps'] = num_training_steps

        self.scheduler = self.config.SchedulerClass(self.optimizer, **self.config.scheduler_params)

    def fit(self, train_loader, validation_loader, n_epochs=None):
        self.log('**** Fitting process has been started ****', direct_out=True)
        if n_epochs is None:
            n_epochs = self.config.n_epochs

        for e in range(n_epochs):
            if self.config.verbose:
                lr = self.optimizer.param_groups[0]['lr']
                timestamp = datetime.utcnow().isoformat()
                self.log(f'\n{timestamp}\nLR: {lr} \nEpoch: {e}')

            t = time.time()
            para_loader = pl.ParallelLoader(train_loader, [self.device])
            losses, final_scores = self.train_one_epoch(para_loader.per_device_loader(self.device), e)
            self.log(f'[RESULT]: Train. Epoch: {self.epoch}, loss: {losses.avg:.5f}, '
                     f'final_score: {final_scores.avg:.5f}, time: {(time.time() - t):.5f}')

            t = time.time()
            para_loader = pl.ParallelLoader(validation_loader, [self.device])

            # swap in SWA weights for validation
            if self.use_SWA:
                self.log('Swapping SWA weights for validation', direct_out=True)
                self.optimizer.swap_swa_sgd()

            losses, final_scores, threshold = self.validation(para_loader.per_device_loader(self.device))
            self.log(f'[RESULT]: Validation. Epoch: {self.epoch}, loss: {losses.avg:.5f}, '
                     f'final_score: {final_scores.avg:.5f}, best_th: {threshold.find:.3f}, time: {(time.time() - t):.5f}')

            # swap back to the original weights to continue training
            if self.use_SWA:
                self.log('Swapping back to original weights to continue training', direct_out=True)
                self.optimizer.swap_swa_sgd()

            if final_scores.avg > self.best_score:
                self.best_score = final_scores.avg
                self.save('best_model')
                self.log('Best model has been updated', direct_out=True)

            # after each epoch, fold the current weights into the SWA running average
            if self.use_SWA:
                self.optimizer.update_swa()
                self.log('SWA model weights have been updated', direct_out=True)

            if self.config.validation_scheduler:
                # self.scheduler.step(metrics=final_scores.avg)
                self.scheduler.step()

            self.epoch += 1

    def run_tuning_and_inference(self, test_loader, validation_loader, validation_tune_loader, n_epochs):
        self.log('****** Validation tuning and inference has started ******', direct_out=True)
        self.run_validation_tuning(validation_loader, validation_tune_loader, n_epochs)
        para_loader = pl.ParallelLoader(test_loader, [self.device])
        self.run_inference(para_loader.per_device_loader(self.device))

    def run_validation_tuning(self, validation_loader, validation_tune_loader, n_epochs):
        self.log('****** Validation tuning has started ******', direct_out=True)
        # self.optimizer.param_groups[0]['lr'] = self.config.lr*xm.xrt_world_size() / (epoch + 1)
        self.fit(validation_tune_loader, validation_loader, n_epochs)

    def validation(self, val_loader):
        self.log('**** Validation process has been started ****', direct_out=True)
        self.model.eval()
        losses = AverageMeter()
        final_scores = RocAucMeter()
        threshold = ThresholdMeter()
        t = time.time()
        for step, (targets, inputs, attention_masks) in enumerate(val_loader):
            self.log(
                f'Valid Step {step}, loss: '
                f'{losses.avg:.5f}, final_score: {final_scores.avg:.5f}, '
                f'time: {(time.time() - t):.5f}', step=step
            )
            with torch.no_grad():
                inputs = inputs.to(self.device, dtype=torch.long)
                attention_masks = attention_masks.to(self.device, dtype=torch.long)
                targets = targets.to(self.device, dtype=torch.float)

                outputs = self.model(inputs, attention_masks)
                loss = self.criterion(outputs, targets)

                batch_size = inputs.size(0)
                final_scores.update(targets, outputs)
                losses.update(loss.detach().item(), batch_size)
                threshold.update(targets, outputs)
        return losses, final_scores, threshold

    def train_one_epoch(self, train_loader, epoch):
        self.log(f'**** Epoch training has started: {epoch} ****', direct_out=True)
        self.model.train()
        losses = AverageMeter()
        final_scores = RocAucMeter()
        t = time.time()
        for step, (targets, inputs, attention_masks) in enumerate(train_loader):
            self.log(
                f'Train Step {step}, loss: '
                f'{losses.avg:.5f}, final_score: {final_scores.avg:.5f}, '
                f'time: {(time.time() - t):.5f}', step=step
            )
            inputs = inputs.to(self.device, dtype=torch.long)
            attention_masks = attention_masks.to(self.device, dtype=torch.long)
            targets = targets.to(self.device, dtype=torch.float)

            self.optimizer.zero_grad()
            outputs = self.model(inputs, attention_masks)
            loss = self.criterion(outputs, targets)

            batch_size = inputs.size(0)
            final_scores.update(targets, outputs)
            losses.update(loss.detach().item(), batch_size)

            loss.backward()
            xm.optimizer_step(self.optimizer)

            if self.config.step_scheduler:
                self.scheduler.step()
        return losses, final_scores

    def run_inference(self, test_loader):
        self.log('**** Inference process has been started ****', direct_out=True)
        self.model.eval()
        result = {'id': [], 'toxic': []}
        t = time.time()
        for step, (ids, inputs, attention_masks) in enumerate(test_loader):
            self.log(f'Prediction Step {step}, time: {(time.time() - t):.5f}', step=step)
            with torch.no_grad():
                inputs = inputs.to(self.device, dtype=torch.long)
                attention_masks = attention_masks.to(self.device, dtype=torch.long)
                outputs = self.model(inputs, attention_masks)
                toxics = torch.nn.functional.softmax(outputs, dim=1).data.cpu().numpy()[:, 1]
            result['id'].extend(ids.cpu().numpy())
            result['toxic'].extend(toxics)

        result = pd.DataFrame(result)
        print(f'Node path is: {self.node_path}')
        node_count = len(list(self.node_path.glob('*.csv')))
        result.to_csv(self.node_path / f'submission_{node_count}_{datetime.utcnow().microsecond}_{random.random()}.csv', index=False)

    def run_pseudolabeling(self, test_loader, epoch):
        losses = AverageMeter()
        final_scores = RocAucMeter()
        self.model.eval()
        t = time.time()
        for step, (ids, inputs, attention_masks) in enumerate(test_loader):
            inputs = inputs.to(self.device, dtype=torch.long)
            attention_masks = attention_masks.to(self.device, dtype=torch.long)
            outputs = self.model(inputs, attention_masks)
            # print(f'Inputs: {inputs} size: {inputs.size()}')
            # print(f'outputs: {outputs} size: {outputs.size()}')
            toxics = torch.nn.functional.softmax(outputs, dim=1)[:, 1]

            # keep only confident predictions on either side of the decision boundary
            toxic_mask = (toxics <= 0.4) | (toxics >= 0.8)
            # print(attention_masks.size())
            toxics = toxics[toxic_mask]
            inputs = inputs[toxic_mask]
            attention_masks = attention_masks[toxic_mask]
            # print(f'toxics: {toxics.size()}')
            # print(f'inputs: {inputs.size()}')

            if toxics.nelement() != 0:
                targets_int = (toxics > self.config.pseudolabeling_threshold).int()
                targets = torch.stack([onehot(2, target) for target in targets_int])
                # print(targets_int)
                self.model.train()
                self.log(
                    f'Pseudolabeling Step {step}, loss: '
                    f'{losses.avg:.5f}, final_score: {final_scores.avg:.5f}, '
                    f'time: {(time.time() - t):.5f}', step=step
                )
                targets = targets.to(self.device, dtype=torch.float)

                self.optimizer.zero_grad()
                outputs = self.model(inputs, attention_masks)
                loss = self.criterion(outputs, targets)

                batch_size = inputs.size(0)
                final_scores.update(targets, outputs)
                losses.update(loss.detach().item(), batch_size)

                loss.backward()
                xm.optimizer_step(self.optimizer)

                if self.config.step_scheduler:
                    self.scheduler.step()

        self.log(f'[RESULT]: Pseudolabeling. Epoch: {epoch}, loss: {losses.avg:.5f}, '
                 f'final_score: {final_scores.avg:.5f}, time: {(time.time() - t):.5f}')

    def get_submission(self, out_dir):
        submission = pd.concat([pd.read_csv(path) for path in (out_dir / 'node_submissions').glob('*.csv')]).groupby('id').mean()
        return submission

    def save(self, name):
        self.model_path.mkdir(parents=True, exist_ok=True)
        path = (self.model_path / name).with_suffix('.bin')
        if self.use_SWA:
            self.optimizer.swap_swa_sgd()
        xm.save(self.model.state_dict(), path)
        self.log('Model has been saved')

    def log(self, message, step=None, direct_out=False):
        if direct_out or self.config.verbose:
            if direct_out or step is None or (step is not None and step % self.config.verbose_step == 0):
                xm.master_print(message)
                with open(self.log_path, 'a+') as logger:
                    xm.master_print(f'{message}', logger)
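# Editor's note: TPUFitter assumes it is already running inside one TPU process
# (it relies on xm.* and pl.ParallelLoader). A hedged sketch of how such a fitter
# is typically launched with torch_xla multiprocessing is shown below; _mp_fn,
# model_factory, TrainingConfig and the loader variables are illustrative
# assumptions and not part of the original code.

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(rank, flags):
    device = xm.xla_device()
    net = model_factory()  # hypothetical: build the classifier for this process
    fitter = TPUFitter(model=net.to(device), device=device, config=TrainingConfig,
                       base_model_path='models', model_name='toxic', out_path='output')
    # train_loader / validation_loader are assumed to be built per process
    fitter.fit(train_loader, validation_loader)


if __name__ == '__main__':
    # one process per TPU core; 'fork' keeps the notebook-style globals visible
    xmp.spawn(_mp_fn, args=({},), nprocs=8, start_method='fork')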
class Train(object):
    """Train class."""

    def __init__(self):
        trainds = AlaskaDataIter(cfg.DATA.root_path, cfg.DATA.train_txt_path, training_flag=True)
        self.train_ds = DataLoader(trainds,
                                   cfg.TRAIN.batch_size,
                                   num_workers=cfg.TRAIN.process_num,
                                   shuffle=True)

        valds = AlaskaDataIter(cfg.DATA.root_path, cfg.DATA.val_txt_path, training_flag=False)
        self.val_ds = DataLoader(valds,
                                 cfg.TRAIN.batch_size,
                                 num_workers=cfg.TRAIN.process_num,
                                 shuffle=False)

        self.init_lr = cfg.TRAIN.init_lr
        self.warmup_step = cfg.TRAIN.warmup_step
        self.epochs = cfg.TRAIN.epoch
        self.batch_size = cfg.TRAIN.batch_size
        self.l2_regularization = cfg.TRAIN.weight_decay_factor

        self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

        self.model = CenterNet().to(self.device)
        self.load_weight()

        if 'Adamw' in cfg.TRAIN.opt:
            self.optimizer = torch.optim.AdamW(self.model.parameters(),
                                               lr=self.init_lr,
                                               eps=1.e-5,
                                               weight_decay=self.l2_regularization)
        else:
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             lr=self.init_lr,
                                             momentum=0.9,
                                             weight_decay=self.l2_regularization)

        if cfg.TRAIN.SWA > 0:
            ## use swa
            self.optimizer = SWA(self.optimizer)

        if cfg.TRAIN.mix_precision:
            self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O1")

        self.model = nn.DataParallel(self.model)

        self.ema = EMA(self.model, 0.999)
        self.ema.register()

        ### control vars
        self.iter_num = 0

        # self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='max', patience=3, verbose=True)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, self.epochs, eta_min=1.e-6)

        self.criterion = CenterNetLoss().to(self.device)

    def custom_loop(self):
        """Custom training and testing loop.

        Returns:
            the per-epoch classification and size losses via the inner
            train_epoch/test_epoch helpers.
        """

        def train_epoch(epoch_num):
            summary_loss_cls = AverageMeter()
            summary_loss_wh = AverageMeter()
            self.model.train()

            if cfg.MODEL.freeze_bn:
                for m in self.model.modules():
                    if isinstance(m, nn.BatchNorm2d):
                        m.eval()
                        if cfg.MODEL.freeze_bn_affine:
                            m.weight.requires_grad = False
                            m.bias.requires_grad = False

            for image, hm_target, wh_target, weights in self.train_ds:
                if epoch_num < 10:
                    ### execute warm-up during the first epochs
                    if self.warmup_step > 0:
                        if self.iter_num < self.warmup_step:
                            for param_group in self.optimizer.param_groups:
                                param_group['lr'] = self.iter_num / float(self.warmup_step) * self.init_lr
                                lr = param_group['lr']
                            logger.info('warm up with learning rate: [%f]' % (lr))

                start = time.time()

                if cfg.TRAIN.vis:
                    for i in range(image.shape[0]):
                        img = image[i].numpy()
                        img = np.transpose(img, axes=[1, 2, 0])
                        hm = hm_target[i].numpy()
                        wh = wh_target[i].numpy()
                        if cfg.DATA.use_int8_data:
                            hm = hm[:, :, 0].astype(np.uint8)
                            wh = wh[:, :, 0]
                        else:
                            hm = hm[:, :, 0].astype(np.float32)
                            wh = wh[:, :, 0].astype(np.float32)

                        cv2.namedWindow('s_hm', 0)
                        cv2.imshow('s_hm', hm)
                        cv2.namedWindow('s_wh', 0)
                        cv2.imshow('s_wh', wh + 1)
                        cv2.namedWindow('img', 0)
                        cv2.imshow('img', img)
                        cv2.waitKey(0)
                else:
                    data = image.to(self.device).float()
                    if cfg.DATA.use_int8_data:
                        hm_target = hm_target.to(self.device).float() / cfg.DATA.use_int8_enlarge
                    else:
                        hm_target = hm_target.to(self.device).float()
                    wh_target = wh_target.to(self.device).float()
                    weights = weights.to(self.device).float()

                    batch_size = data.shape[0]

                    cls, wh = self.model(data)
                    cls_loss, wh_loss = self.criterion([cls, wh], [hm_target, wh_target, weights])
                    current_loss = cls_loss + wh_loss

                    summary_loss_cls.update(cls_loss.detach().item(), batch_size)
                    summary_loss_wh.update(wh_loss.detach().item(), batch_size)

                    self.optimizer.zero_grad()
                    if cfg.TRAIN.mix_precision:
                        with amp.scale_loss(current_loss, self.optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        current_loss.backward()
                    self.optimizer.step()

                    if cfg.TRAIN.ema:
                        self.ema.update()

                    self.iter_num += 1
                    time_cost_per_batch = time.time() - start
                    images_per_sec = cfg.TRAIN.batch_size * cfg.TRAIN.num_gpu / time_cost_per_batch

                    if self.iter_num % cfg.TRAIN.log_interval == 0:
                        log_message = '[TRAIN], ' \
                                      'Epoch %d Step %d, ' \
                                      'summary_loss: %.6f, ' \
                                      'cls_loss: %.6f, ' \
                                      'wh_loss: %.6f, ' \
                                      'time: %.6f, ' \
                                      'speed %d images/sec' % (
                                          epoch_num, self.iter_num,
                                          summary_loss_cls.avg + summary_loss_wh.avg,
                                          summary_loss_cls.avg, summary_loss_wh.avg,
                                          time.time() - start, images_per_sec)
                        logger.info(log_message)

            if cfg.TRAIN.SWA > 0 and epoch_num >= cfg.TRAIN.SWA:
                self.optimizer.update_swa()

            return summary_loss_cls, summary_loss_wh

        def test_epoch(epoch_num):
            summary_loss_cls = AverageMeter()
            summary_loss_wh = AverageMeter()
            self.model.eval()
            t = time.time()

            with torch.no_grad():
                for step, (image, hm_target, wh_target, weights) in enumerate(self.val_ds):
                    data = image.to(self.device).float()
                    if cfg.DATA.use_int8_data:
                        hm_target = hm_target.to(self.device).float() / cfg.DATA.use_int8_enlarge
                    else:
                        hm_target = hm_target.to(self.device).float()
                    wh_target = wh_target.to(self.device).float()
                    weights = weights.to(self.device).float()

                    batch_size = data.shape[0]

                    with torch.no_grad():
                        cls, wh = self.model(data)
                        cls_loss, wh_loss = self.criterion([cls, wh], [hm_target, wh_target, weights])

                    summary_loss_cls.update(cls_loss.detach().item(), batch_size)
                    summary_loss_wh.update(wh_loss.detach().item(), batch_size)

                    if step % cfg.TRAIN.log_interval == 0:
                        log_message = '[VAL], ' \
                                      'Epoch %d Step %d, ' \
                                      'summary_loss: %.6f, ' \
                                      'cls_loss: %.6f, ' \
                                      'wh_loss: %.6f, ' \
                                      'time: %.6f' % (
                                          epoch_num, step,
                                          summary_loss_cls.avg + summary_loss_wh.avg,
                                          summary_loss_cls.avg, summary_loss_wh.avg,
                                          time.time() - t)
                        logger.info(log_message)

            return summary_loss_cls, summary_loss_wh

        for epoch in range(self.epochs):
            for param_group in self.optimizer.param_groups:
                lr = param_group['lr']
            logger.info('learning rate: [%f]' % (lr))
            t = time.time()

            summary_loss_cls, summary_loss_wh = train_epoch(epoch)

            train_epoch_log_message = '[centernet], ' \
                                      '[RESULT]: Train. Epoch: %d,' \
                                      ' summary_loss: %.5f,' \
                                      ' cls_loss: %.6f, ' \
                                      ' wh_loss: %.6f, ' \
                                      ' time:%.5f' % (
                                          epoch,
                                          summary_loss_cls.avg + summary_loss_wh.avg,
                                          summary_loss_cls.avg, summary_loss_wh.avg,
                                          (time.time() - t))
            logger.info(train_epoch_log_message)

            if cfg.TRAIN.SWA > 0 and epoch >= cfg.TRAIN.SWA:
                ### switch to the averaged model
                self.optimizer.swap_swa_sgd()

            ## switch to ema weights
            if cfg.TRAIN.ema:
                self.ema.apply_shadow()

            if epoch % cfg.TRAIN.test_interval == 0:
                summary_loss_cls, summary_loss_wh = test_epoch(epoch)
                val_epoch_log_message = '[centernet], ' \
                                        '[RESULT]: VAL. Epoch: %d,' \
                                        ' summary_loss: %.5f,' \
                                        ' cls_loss: %.6f, ' \
                                        ' wh_loss: %.6f, ' \
                                        ' time:%.5f' % (
                                            epoch,
                                            summary_loss_cls.avg + summary_loss_wh.avg,
                                            summary_loss_cls.avg, summary_loss_wh.avg,
                                            (time.time() - t))
                logger.info(val_epoch_log_message)

            self.scheduler.step()
            # self.scheduler.step(final_scores.avg)

            #### save model
            if not os.access(cfg.MODEL.model_path, os.F_OK):
                os.mkdir(cfg.MODEL.model_path)

            #### save the model at the end of every epoch
            current_model_saved_name = './model/centernet_epoch_%d_val_loss%.6f.pth' % (
                epoch, summary_loss_cls.avg + summary_loss_wh.avg)
            logger.info('A model saved to %s' % current_model_saved_name)
            torch.save(self.model.module.state_dict(), current_model_saved_name)

            #### switch back
            if cfg.TRAIN.ema:
                self.ema.restore()

            if cfg.TRAIN.SWA > 0 and epoch > cfg.TRAIN.SWA:
                ### switch back to the plain model to train the next epoch
                self.optimizer.swap_swa_sgd()

    def load_weight(self):
        if cfg.MODEL.pretrained_model is not None:
            state_dict = torch.load(cfg.MODEL.pretrained_model, map_location=self.device)
            self.model.load_state_dict(state_dict, strict=False)
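# Editor's note: the Train class above calls EMA(self.model, 0.999) with
# register()/update()/apply_shadow()/restore(), but the EMA helper itself is not
# shown in this file. The sketch below is a common implementation of that
# interface; it is an assumption, not necessarily the project's own EMA class.

import torch


class EMA:
    """Exponential moving average of model parameters.

    register() snapshots the current weights, update() moves the shadow copy
    toward the live weights, apply_shadow()/restore() swap the shadow weights
    in and out around evaluation or checkpointing.
    """

    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

    def register(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]

    def apply_shadow(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data = self.backup[name]
        self.backup = {}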