def evaluate(model, config, val_loader):
    model.eval()
    opt = config['opt']
    pad_label_id = config['pad_label_id']
    eval_loss = 0.
    criterion = nn.CrossEntropyLoss(ignore_index=pad_label_id).to(opt.device)
    n_batches = len(val_loader)
    prog = Progbar(target=n_batches)
    preds = None
    ys = None
    with torch.no_grad():
        for i, (x, y) in enumerate(val_loader):
            x = to_device(x, opt.device)
            y = to_device(y, opt.device)
            if opt.use_crf:
                logits, prediction = model(x)
                mask = torch.sign(torch.abs(x[0])).to(torch.uint8).to(opt.device)
                log_likelihood = model.crf(logits, y, mask=mask, reduction='mean')
                loss = -1 * log_likelihood
            else:
                logits = model(x)
                loss = criterion(logits.view(-1, model.label_size), y.view(-1))
            if preds is None:
                if opt.use_crf:
                    preds = to_numpy(prediction)
                else:
                    preds = to_numpy(logits)
                ys = to_numpy(y)
            else:
                if opt.use_crf:
                    preds = np.append(preds, to_numpy(prediction), axis=0)
                else:
                    preds = np.append(preds, to_numpy(logits), axis=0)
                ys = np.append(ys, to_numpy(y), axis=0)
            eval_loss += loss.item()
            prog.update(i + 1, [('eval curr loss', loss.item())])
    eval_loss = eval_loss / n_batches
    if not opt.use_crf:
        preds = np.argmax(preds, axis=2)
    # compute measure using seqeval
    labels = model.labels
    ys_lbs = [[] for _ in range(ys.shape[0])]
    preds_lbs = [[] for _ in range(ys.shape[0])]
    for i in range(ys.shape[0]):      # foreach sentence
        for j in range(ys.shape[1]):  # foreach token
            if ys[i][j] != pad_label_id:
                ys_lbs[i].append(labels[ys[i][j]])
                preds_lbs[i].append(labels[preds[i][j]])
    ret = {
        "loss": eval_loss,
        "precision": precision_score(ys_lbs, preds_lbs),
        "recall": recall_score(ys_lbs, preds_lbs),
        "f1": f1_score(ys_lbs, preds_lbs),
        "report": classification_report(ys_lbs, preds_lbs, digits=4),
    }
    print(ret['report'])
    return ret
def forward_core(self, input):
    initial_h = to_device(torch.zeros(input.size(1), self.hidden_size))
    initial_c = to_device(torch.zeros(input.size(1), self.hidden_size))
    # lstm and dense pass for prediction
    lstm_out = self.lstm(input, initial_h, initial_c)[0]
    dense_out = self.forward_dense(lstm_out)
    return dense_out
def evaluate(model, config, valid_loader, eval_device=None):
    args = config['args']
    total_loss = 0.
    total_examples = 0
    correct = 0
    criterion = torch.nn.CrossEntropyLoss()
    preds = None
    ys = None
    with torch.no_grad():
        iterator = tqdm(valid_loader, total=len(valid_loader), desc=f"Evaluate")
        for i, (x, y) in enumerate(iterator):
            if eval_device:
                x = to_device(x, args.device)
                y = to_device(y, args.device)
            model.eval()
            logits = model(x)
            loss = criterion(logits, y)
            # softmax after computing cross entropy loss
            logits = torch.softmax(logits, dim=-1)
            logits = logits.cpu().numpy()
            y = y.cpu().numpy()
            if preds is None:
                preds = logits
                ys = y
            else:
                preds = np.append(preds, logits, axis=0)
                ys = np.append(ys, y, axis=0)
            predicted = np.argmax(logits, axis=1)
            correct += np.sum(np.equal(predicted, y).astype(int))
            cur_examples = y.size
            total_loss += (loss.item() * cur_examples)
            total_examples += cur_examples
    # generate report
    labels = config['labels']
    label_names = [v for k, v in sorted(labels.items(), key=lambda x: x[0])]
    preds_ids = np.argmax(preds, axis=1)
    try:
        print(classification_report(ys, preds_ids, target_names=label_names, digits=4))
        print(labels)
        print(confusion_matrix(ys, preds_ids))
    except Exception as e:
        logger.warn(str(e))
    cur_loss = total_loss / total_examples
    cur_acc = correct / total_examples
    return cur_loss, cur_acc
def forward_core(self, input):
    batch_size = input.size(1)
    initial_h = to_device(torch.zeros(batch_size, self.hidden_size))
    # initialize eligibility vectors only on first run
    if type(self.eligibility_vectors) == type(None):
        self.initial_c = to_device(
            torch.zeros(batch_size, self.hidden_size)).requires_grad_()
        self.eligibility_vectors = [
            to_device(torch.zeros(batch_size, 3 * self.hidden_size,
                                  self.input_size, requires_grad=False)),
            to_device(torch.zeros(batch_size, 3 * self.hidden_size,
                                  self.hidden_size, requires_grad=False)),
            to_device(torch.zeros(batch_size, 3 * self.hidden_size, 1,
                                  requires_grad=False))
        ]
    initial_c = self.initial_c
    # lstm and dense pass for prediction
    lstm_out, final_c, self.eligibility_vectors = self.lstm(
        input, initial_h.detach(), initial_c, self.eligibility_vectors)
    lstm_out = self.forward_dense(lstm_out)
    # take the last output of the network to let the synth grad network predict the synthetic gradient
    synth_grad = self.synthetic_gradient_net(lstm_out[:, -1, :].detach())
    # set the gradient of the last internal state equal to the synthetic gradient
    self.initial_c, lstm_out, _ = SyntheticGradient.apply(
        final_c, lstm_out, synth_grad.detach())
    # detach initial state from the compute graph of [t_{m-1}+1, ..., t_{m}] but
    # build a new compute graph ...
    self.initial_c = self.initial_c.detach().requires_grad_()
    # ... and make sure to also detach eligibility vectors
    self.eligibility_vectors = [
        self.eligibility_vectors[0].detach(),
        self.eligibility_vectors[1].detach(),
        self.eligibility_vectors[2].detach()
    ]
    # return both the output as well as the synthetic gradient
    return lstm_out, initial_c, synth_grad
def generate_single_lable_memory_data(num_observations, sequence_length, time_delta=None):
    '''
    Generates num_observations sequences of length sequence_length where the entire
    sequence is filled with 0s, except for one singular signal at a random position
    inside the sequence. The label of a sequence is the singular signal value minus 1,
    because the loss used to train the network expects labels in the range 0 to
    num_classes - 1 instead of 1 to num_classes.
    '''
    size = (num_observations, sequence_length, 1)
    data = np.zeros(size)
    labels = np.zeros((num_observations, 1))
    for i, row in enumerate(data):
        signal = np.random.randint(1, MEM_NUM_CLASSES, 1)
        last_possible_signal = sequence_length if not time_delta else sequence_length - time_delta
        column = np.random.randint(0, last_possible_signal, 1)
        row[column] = signal
        labels[i] = signal - 1
    return to_device(torch.from_numpy(data).float()), to_device(torch.from_numpy(labels).long())
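# Illustrative smoke test for the memory-task generator above (not part of the original
# code; a minimal sketch assuming MEM_NUM_CLASSES and to_device are available in this
# module). It makes the tensor shapes described in the docstring explicit.
x, y = generate_single_lable_memory_data(num_observations=32, sequence_length=20)
assert tuple(x.shape) == (32, 20, 1)   # one scalar input per time step
assert tuple(y.shape) == (32, 1)       # one class label per sequence (signal value - 1)
assert (x != 0).sum().item() == 32     # exactly one non-zero signal per sequence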
def evaluate(model, config, val_loader):
    opt = config['opt']
    model.eval()
    total_loss = 0.
    total_examples = 0
    correct = 0
    criterion = torch.nn.CrossEntropyLoss().to(opt.device)
    preds = None
    ys = None
    with torch.no_grad():
        for i, (x, y) in tqdm(enumerate(val_loader), total=len(val_loader)):
            x = to_device(x, opt.device)
            y = to_device(y, opt.device)
            logits = model(x)
            loss = criterion(logits, y)
            if preds is None:
                preds = to_numpy(logits)
                ys = to_numpy(y)
            else:
                preds = np.append(preds, to_numpy(logits), axis=0)
                ys = np.append(ys, to_numpy(y), axis=0)
            predicted = logits.argmax(1)
            correct += (predicted == y).sum().item()
            cur_examples = y.size(0)
            total_loss += (loss.item() * cur_examples)
            total_examples += cur_examples
    # generate report
    labels = model.labels
    label_names = [v for k, v in sorted(labels.items(), key=lambda x: x[0])]
    preds_ids = np.argmax(preds, axis=1)
    try:
        print(classification_report(ys, preds_ids, target_names=label_names, digits=4))
        print(labels)
        print(confusion_matrix(ys, preds_ids))
    except Exception as e:
        logger.warn(str(e))
    cur_loss = total_loss / total_examples
    cur_acc = correct / total_examples
    return cur_loss, cur_acc
def forward(self, input, initial_h, initial_c, eligibility_vectors=[]):
    # input (seq_len x batch_size x input_size)
    # initial_hidden (batch x hidden_size)
    # initial_state (batch x hidden_size)
    inputs = input.unbind(0)
    input_size = input.size(2)
    hidden_size = initial_h.size(1)
    batch_size = input.size(1)
    hx = initial_h
    cx = initial_c
    if len(eligibility_vectors) == 0:
        ev_w_ih_x = to_device(torch.zeros(batch_size, 3 * hidden_size, input_size,
                                          requires_grad=False))
        ev_w_hh_x = to_device(torch.zeros(batch_size, 3 * hidden_size, hidden_size,
                                          requires_grad=False))
        ev_b_x = to_device(torch.zeros(batch_size, 3 * hidden_size, 1,
                                       requires_grad=False))
    else:
        ev_w_ih_x = eligibility_vectors[0]
        ev_w_hh_x = eligibility_vectors[1]
        ev_b_x = eligibility_vectors[2]
    forgetgate = to_device(torch.zeros(batch_size, hidden_size, 1))
    outputs = []
    for i in range(len(inputs)):
        hx, cx, ev_w_ih_x, ev_w_hh_x, ev_b_x, new_forgetgate = self.cell(
            inputs[i], hx, cx, ev_w_ih_x, ev_w_hh_x, ev_b_x, forgetgate)
        _, forgetgate = ForgetGate.apply(forgetgate, new_forgetgate)
        outputs += [hx]
    return torch.stack(outputs), cx, [ev_w_ih_x, ev_w_hh_x, ev_b_x]
def validate(data_loader, model, best_score, global_step, cfg):
    """
    Runs evaluation on the validation set, logs accuracy/F1/precision/recall,
    and saves a checkpoint whenever the accuracy improves over best_score.

    :param data_loader: validation DataLoader
    :param model: model to evaluate
    :param best_score: best validation accuracy seen so far
    :param global_step: current training step, stored with the checkpoint
    :param cfg: configuration object
    :return: updated best_score
    """
    model.eval()
    gts, predictions = [], []
    log.info("Validation started...")
    for data in data_loader:
        imgs, labels = data
        imgs = util.to_device(imgs, gpu=cfg.gpu)
        with torch.no_grad():
            logits = model(imgs)
            probs = model.module.probability(logits)
            preds = torch.argmax(probs, dim=1).cpu().numpy()
        labels = labels.cpu().detach().numpy()
        predictions.extend(preds)
        gts.extend(labels)
    predictions = np.array(predictions, dtype=np.int32)
    gts = np.array(gts, dtype=np.int32)
    acc, f1, prec, rec = util.clf_metrics(predictions=predictions,
                                          targets=gts,
                                          average="macro")
    report = classification_report(gts, predictions, output_dict=True)
    log.info("VALIDATION | Accuracy {:.4f} | F1 {:.4f} | Precision {:.4f} | "
             "Recall {:.4f}".format(acc, f1, prec, rec))
    if acc > best_score:
        save_config = {
            'name': config.name,
            'save_dir': config.ckpts_dir,
            'global_step': global_step,
            'clf_report': report
        }
        save_model(model=model, config=save_config)
        best_score = acc
    log.info("Validation end")
    model.train()
    return best_score
def preprocess(self, data):
    config = self.config
    opt = config['opt']
    logger.info("data: %s", data)
    text = data[0].get('data')
    if text is None:
        text = data[0].get('body')
    if text:
        text = text.decode('utf-8')
    logger.info("[Received text] %s", text)
    x = self.encode_text(text)
    x = to_device(x, opt.device)
    return x, text
def generate_store_and_recall_data(num_observations, sequence_length,
                                   recall_repetition=0, time_delta=None):
    '''
    Generates num_observations sequences of length sequence_length.
    The input consists of three different signals:
    1. Data signal: a stream of data points with a value between 1 and SR_NUM_CLASSES (inclusive)
    2. Store signal: either 0 if NO storage is required, or 1 if the current data point
       is supposed to be stored by the network
    3. Recall signal: either 0 if NO recall is required, or 1 if the last stored data point
       is supposed to be recalled by the network.
    The label is a sequence of length sequence_length that is 0 where the recall signal is 0,
    and has the value of the last stored data point where the corresponding recall signal is 1.
    '''
    size = (num_observations, sequence_length, 1)
    data_stream = np.random.randint(1, SR_NUM_CLASSES + 1, size)
    store_signal = np.zeros(size)
    recall_signal = np.zeros(size)
    labels = np.zeros(size)
    data = []
    time_delta = time_delta if time_delta else 1
    for i, row in enumerate(data_stream):
        # select time step where store signal is sent
        store = np.random.choice(list(range(sequence_length - time_delta - recall_repetition)))
        store_signal[i][store] = 1
        # select time step where recall signal is sent
        recall = np.random.choice(list(range(store + time_delta, sequence_length - recall_repetition)))
        for j in range(0, recall_repetition + 1):
            recall_signal[i][recall + j] = 1
            labels[i][recall + j] = row[store]
        data.append(np.hstack((row, store_signal[i], recall_signal[i])))
    data = np.array(data)
    return to_device(torch.from_numpy(data).float()), to_device(torch.from_numpy(labels).long())
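# Illustrative smoke test for the store-and-recall generator above (not part of the
# original code; a minimal sketch assuming SR_NUM_CLASSES and to_device are available
# in this module). Each input step is [data, store_flag, recall_flag]; the label is
# non-zero only at recall positions, where it equals the value stored at the store step.
x, y = generate_store_and_recall_data(num_observations=16, sequence_length=30)
assert tuple(x.shape) == (16, 30, 3)   # data, store and recall signals stacked per time step
assert tuple(y.shape) == (16, 30, 1)   # per-step labels, 0 everywhere except at recall steps
assert (y != 0).sum().item() == 16     # with recall_repetition=0, exactly one recall per sequence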
def chose_task(memory_task, training_algorithm):
    # Choose the task and corresponding model:
    if memory_task == config.MEMORY:
        generate_data = generate_single_lable_memory_data
        input_size = config.MEM_INPUT_SIZE
        hidden_size = config.MEM_HIDEN_SIZE
        output_size = config.MEM_NUM_CLASSES
        single_output = True
        loss_function = nn.CrossEntropyLoss()
    elif memory_task == config.STORE_RECALL:
        generate_data = generate_store_and_recall_data
        input_size = config.SR_INPUT_SIZE
        hidden_size = config.SR_HIDEN_SIZE
        output_size = config.SR_NUM_CLASSES + 1
        single_output = False
        loss_function = nn.CrossEntropyLoss()

    if training_algorithm == BPTT:
        model_constructor = BPTT_LSTM
        train_function = lambda model, optimizer, loss_func, batch_x, batch_y: train_bptt(
            model, optimizer, loss_func, batch_x, batch_y)
    elif training_algorithm == EPROP_1:
        model_constructor = EPROP1_LSTM
        train_function = lambda model, optimizer, loss_func, batch_x, batch_y: train_bptt(
            model, optimizer, loss_func, batch_x, batch_y)
    elif training_algorithm == EPROP_3:
        model_constructor = lambda in_size, h_size, o_size, single_output: EPROP3_LSTM(
            in_size, h_size, o_size, single_output=single_output)
        train_function = lambda model, optimizer, loss_func, batch_x, batch_y: train_eprop3(
            model, optimizer, loss_func, batch_x, batch_y, config.TRUNCATION_DELTA)

    model = to_device(model_constructor(
        input_size, hidden_size, output_size, single_output=single_output))
    return generate_data, model, loss_function, train_function
def validate(data_loader, model, best_score, global_step, cfg):
    model.eval()
    gts, predictions = [], []
    log.info("Validation started...")
    for data in data_loader:
        imgs, labels = data
        imgs = util.to_device(imgs, gpu=cfg.gpu)
        with torch.no_grad():
            logits = model(imgs)
            probs = model.module.probability(logits)
            preds = torch.argmax(probs, dim=1).cpu().numpy()
        labels = labels.cpu().detach().numpy()
        predictions.extend(preds)
        gts.extend(labels)
    predictions = np.array(predictions, dtype=np.int32)
    gts = np.array(gts, dtype=np.int32)
    acc, f1, prec, rec = util.clf_metrics(predictions=predictions,
                                          targets=gts,
                                          average="macro")
    report = classification_report(gts, predictions, output_dict=True)
    log.info("VALIDATION | Accuracy {:.4f} | F1 {:.4f} | Precision {:.4f} | "
             "Recall {:.4f}".format(acc, f1, prec, rec))
    if f1 > best_score:
        save_config = {
            'name': config.name,
            'save_dir': config.ckpts_dir,
            'global_step': global_step,
            'clf_report': report
        }
        save_model(model=model, config=save_config)
        best_score = f1
    log.info("Validation end")
    model.train()
    return best_score
# ---------------------------------------
summ_counter = 0
mean_losses = np.zeros(3)
mean_metrics = np.zeros(5)
for epoch in tqdm(range(epoch, epoch_total), initial=epoch, total=epoch_total,
                  leave=False, dynamic_ncols=True):
    for idx, batch in enumerate(
            tqdm(BackgroundGenerator(dataloader), total=len(dataloader),
                 leave=False, dynamic_ncols=True)):
        text_data, text_pos, text_len, text_mask, mel_data, mel_pos, mel_len, \
            mel_mask, gate, text_data_ = to_device(batch, device)
        text_out, gate_out, att_heads_enc, att_heads_dec, att_heads = model(
            mel_data, mel_pos, mel_mask, text_data_, text_pos, text_mask)
        loss_text = F.cross_entropy(text_out.view(-1, nr_symbols),
                                    text_data.view(-1),
                                    weight=weights,
                                    ignore_index=0)
        loss_gate = F.binary_cross_entropy(gate_out, gate)
        loss = loss_text + loss_gate
        optimizer.zero_grad()
        loss.backward()
        # print(mel_mask.grad)
        grad_norm_enc = nn.utils.clip_grad_norm_(model.encoder.parameters(),
def train_epoch(model, config, train_loader, val_loader, epoch_i, best_eval_f1):
    opt = config['opt']
    optimizer = config['optimizer']
    scheduler = config['scheduler']
    writer = config['writer']
    scaler = config['scaler']
    pad_label_id = config['pad_label_id']
    criterion = nn.CrossEntropyLoss(ignore_index=pad_label_id).to(opt.device)
    n_batches = len(train_loader)
    # train one epoch
    train_loss = 0.
    avg_loss = 0.
    local_best_eval_loss = float('inf')
    local_best_eval_f1 = 0
    st_time = time.time()
    optimizer.zero_grad()
    epoch_iterator = tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch_i}")
    for local_step, (x, y) in enumerate(epoch_iterator):
        model.train()
        global_step = (len(train_loader) * epoch_i) + local_step
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        if opt.use_crf:
            with autocast(enabled=opt.use_amp):
                if opt.use_profiler:
                    with profiler.profile(profile_memory=True, record_shapes=True) as prof:
                        logits, prediction = model(x)
                    print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
                else:
                    logits, prediction = model(x)
                mask = torch.sign(torch.abs(x[0])).to(torch.uint8).to(opt.device)
                log_likelihood = model.crf(logits, y, mask=mask, reduction='mean')
                loss = -1 * log_likelihood
                if opt.gradient_accumulation_steps > 1:
                    loss = loss / opt.gradient_accumulation_steps
        else:
            with autocast(enabled=opt.use_amp):
                if opt.use_profiler:
                    with profiler.profile(profile_memory=True, record_shapes=True) as prof:
                        logits = model(x)
                    print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
                else:
                    logits = model(x)
                # reshape for computing loss
                logits_view = logits.view(-1, model.label_size)
                y_view = y.view(-1)
                loss = criterion(logits_view, y_view)
                if opt.gradient_accumulation_steps > 1:
                    loss = loss / opt.gradient_accumulation_steps
        # back-propagation - begin
        scaler.scale(loss).backward()
        if (local_step + 1) % opt.gradient_accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()
            curr_lr = scheduler.get_last_lr()[0] if scheduler else optimizer.param_groups[0]['lr']
            epoch_iterator.set_description(
                f"Epoch {epoch_i}, local_step: {local_step}, loss: {loss:.3f}, curr_lr: {curr_lr:.7f}")
            if opt.eval_and_save_steps > 0 and global_step != 0 and global_step % opt.eval_and_save_steps == 0:
                # evaluate
                eval_ret = evaluate(model, config, val_loader)
                eval_loss = eval_ret['loss']
                eval_f1 = eval_ret['f1']
                if local_best_eval_loss > eval_loss:
                    local_best_eval_loss = eval_loss
                if local_best_eval_f1 < eval_f1:
                    local_best_eval_f1 = eval_f1
                if writer:
                    writer.add_scalar('Loss/valid', eval_loss, global_step)
                    writer.add_scalar('F1/valid', eval_f1, global_step)
                    writer.add_scalar('LearningRate/train', curr_lr, global_step)
                if eval_f1 > best_eval_f1:
                    best_eval_f1 = eval_f1
                    if opt.save_path and not opt.hp_search_optuna:
                        logger.info("[Best model saved] : {}, {}".format(eval_loss, eval_f1))
                        save_model(config, model)
                        # save finetuned bert model/config/tokenizer
                        if config['emb_class'] not in ['glove', 'elmo']:
                            if not os.path.exists(opt.bert_output_dir):
                                os.makedirs(opt.bert_output_dir)
                            model.bert_tokenizer.save_pretrained(opt.bert_output_dir)
                            model.bert_model.save_pretrained(opt.bert_output_dir)
        # back-propagation - end
        train_loss += loss.item()
        if writer:
            writer.add_scalar('Loss/train', loss.item(), global_step)
    avg_loss = train_loss / n_batches
    # evaluate at the end of epoch
    eval_ret = evaluate(model, config, val_loader)
    eval_loss = eval_ret['loss']
    eval_f1 = eval_ret['f1']
    if local_best_eval_loss > eval_loss:
        local_best_eval_loss = eval_loss
    if local_best_eval_f1 < eval_f1:
        local_best_eval_f1 = eval_f1
    if writer:
        writer.add_scalar('Loss/valid', eval_loss, global_step)
        writer.add_scalar('F1/valid', eval_f1, global_step)
        writer.add_scalar('LearningRate/train', curr_lr, global_step)
    if eval_f1 > best_eval_f1:
        best_eval_f1 = eval_f1
        if opt.save_path and not opt.hp_search_optuna:
            logger.info("[Best model saved] : {}, {}".format(eval_loss, eval_f1))
            save_model(config, model)
            # save finetuned bert model/config/tokenizer
            if config['emb_class'] not in ['glove', 'elmo']:
                if not os.path.exists(opt.bert_output_dir):
                    os.makedirs(opt.bert_output_dir)
                model.bert_tokenizer.save_pretrained(opt.bert_output_dir)
                model.bert_model.save_pretrained(opt.bert_output_dir)
    curr_time = time.time()
    elapsed_time = (curr_time - st_time) / 60
    st_time = curr_time
    logs = {
        'epoch': epoch_i,
        'local_step': local_step + 1,
        'epoch_step': len(train_loader),
        'avg_loss': avg_loss,
        'local_best_eval_loss': local_best_eval_loss,
        'local_best_eval_f1': local_best_eval_f1,
        'best_eval_f1': best_eval_f1,
        'elapsed_time': elapsed_time
    }
    logger.info(json.dumps(logs, indent=4, ensure_ascii=False, sort_keys=True))
    return local_best_eval_loss, local_best_eval_f1, best_eval_f1
def prune_rewire(config, model, eval_loader, use_tqdm=True):
    args = config['opt']
    bert_model = model.bert_model
    # get the model ffn weights and biases
    inter_weights = torch.zeros(bert_model.config.num_hidden_layers,
                                bert_model.config.intermediate_size,
                                bert_model.config.hidden_size).to(args.device)
    inter_biases = torch.zeros(bert_model.config.num_hidden_layers,
                               bert_model.config.intermediate_size).to(args.device)
    output_weights = torch.zeros(bert_model.config.num_hidden_layers,
                                 bert_model.config.hidden_size,
                                 bert_model.config.intermediate_size).to(args.device)
    layers = bert_model.base_model.encoder.layer
    head_importance = torch.zeros(bert_model.config.num_hidden_layers,
                                  bert_model.config.num_attention_heads).to(args.device)
    ffn_importance = torch.zeros(bert_model.config.num_hidden_layers,
                                 bert_model.config.intermediate_size).to(args.device)
    for layer_num in range(bert_model.config.num_hidden_layers):
        inter_weights[layer_num] = layers._modules[str(layer_num)].intermediate.dense.weight.detach().to(args.device)
        inter_biases[layer_num] = layers._modules[str(layer_num)].intermediate.dense.bias.detach().to(args.device)
        output_weights[layer_num] = layers._modules[str(layer_num)].output.dense.weight.detach().to(args.device)
    head_mask = torch.ones(bert_model.config.num_hidden_layers,
                           bert_model.config.num_attention_heads,
                           requires_grad=True).to(args.device)
    # Eval!
    logger.info(f"***** Running evaluation for pruning *****")
    logger.info("  Num batches = %d", len(eval_loader))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    criterion = torch.nn.CrossEntropyLoss().to(args.device)
    eval_loader = tqdm(eval_loader, desc="Evaluating") if use_tqdm else eval_loader
    tot_tokens = 0.0
    for x, y in eval_loader:
        model.eval()
        x = to_device(x, args.device)
        y = to_device(y, args.device)
        logits, bert_outputs = model(x, return_bert_outputs=True, head_mask=head_mask)
        tmp_eval_loss = criterion(logits, y)
        eval_loss += tmp_eval_loss.mean().item()
        # for preventing head_mask.grad is None
        head_mask.retain_grad()
        # TODO accumulate? absolute value sum?
        tmp_eval_loss.backward()
        # collect attention confidence scores
        head_importance += head_mask.grad.abs().detach()
        # collect gradients of linear layers
        for layer_num in range(bert_model.config.num_hidden_layers):
            ffn_importance[layer_num] += torch.abs(
                torch.sum(layers._modules[str(layer_num)].intermediate.dense.weight.grad.detach() * inter_weights[layer_num], 1)
                + layers._modules[str(layer_num)].intermediate.dense.bias.grad.detach() * inter_biases[layer_num])
        attention_mask = x[1]
        tot_tokens += attention_mask.float().detach().sum().data
        nb_eval_steps += 1
    head_importance /= tot_tokens
    # Layerwise importance normalization
    if not args.dont_normalize_importance_by_layer:
        exponent = 2
        norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
        head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20
    # rewire the network
    head_importance = head_importance.cpu()
    ffn_importance = ffn_importance.cpu()
    num_heads = bert_model.config.num_attention_heads
    head_size = bert_model.config.hidden_size / num_heads
    for layer_num in range(bert_model.config.num_hidden_layers):
        # load query, key, value weights
        query_weight = layers._modules[str(layer_num)].attention.self.query.weight
        query_bias = layers._modules[str(layer_num)].attention.self.query.bias
        key_weight = layers._modules[str(layer_num)].attention.self.key.weight
        key_bias = layers._modules[str(layer_num)].attention.self.key.bias
        value_weight = layers._modules[str(layer_num)].attention.self.value.weight
        value_bias = layers._modules[str(layer_num)].attention.self.value.bias
        # sort query, key, value based on the confidence scores
        query_weight, query_bias = sort_by_importance(query_weight, query_bias,
                                                      head_importance[layer_num],
                                                      args.target_num_heads, head_size)
        print('query_weight = ', query_weight.shape)
        print('query_bias = ', query_bias.shape)
        layers._modules[str(layer_num)].attention.self.query.weight = torch.nn.Parameter(query_weight)
        layers._modules[str(layer_num)].attention.self.query.bias = torch.nn.Parameter(query_bias)
        key_weight, key_bias = sort_by_importance(key_weight, key_bias,
                                                  head_importance[layer_num],
                                                  args.target_num_heads, head_size)
        print('key_weight = ', key_weight.shape)
        print('key_bias = ', key_bias.shape)
        layers._modules[str(layer_num)].attention.self.key.weight = torch.nn.Parameter(key_weight)
        layers._modules[str(layer_num)].attention.self.key.bias = torch.nn.Parameter(key_bias)
        value_weight, value_bias = sort_by_importance(value_weight, value_bias,
                                                      head_importance[layer_num],
                                                      args.target_num_heads, head_size)
        print('value_weight = ', value_weight.shape)
        print('value_bias = ', value_bias.shape)
        layers._modules[str(layer_num)].attention.self.value.weight = torch.nn.Parameter(value_weight)
        layers._modules[str(layer_num)].attention.self.value.bias = torch.nn.Parameter(value_bias)
        # output matrix
        weight_sorted, _ = sort_by_importance(
            layers._modules[str(layer_num)].attention.output.dense.weight.transpose(0, 1),
            None,
            head_importance[layer_num],
            args.target_num_heads,
            head_size)
        weight_sorted = weight_sorted.transpose(0, 1)
        print('attention.output.dense.weight = ', weight_sorted.shape)
        layers._modules[str(layer_num)].attention.output.dense.weight = torch.nn.Parameter(weight_sorted)
        weight_sorted, bias_sorted = sort_by_importance(
            layers._modules[str(layer_num)].intermediate.dense.weight,
            layers._modules[str(layer_num)].intermediate.dense.bias,
            ffn_importance[layer_num],
            args.target_ffn_dim,
            1)
        layers._modules[str(layer_num)].intermediate.dense.weight = torch.nn.Parameter(weight_sorted)
        layers._modules[str(layer_num)].intermediate.dense.bias = torch.nn.Parameter(bias_sorted)
        # ffn output matrix input side
        weight_sorted, _ = sort_by_importance(
            layers._modules[str(layer_num)].output.dense.weight.transpose(0, 1),
            None,
            ffn_importance[layer_num],
            args.target_ffn_dim,
            1)
        weight_sorted = weight_sorted.transpose(0, 1)
        print('output.dense.weight = ', weight_sorted.shape)
        layers._modules[str(layer_num)].output.dense.weight = torch.nn.Parameter(weight_sorted)
    # set bert model's config for pruned model
    bert_model.config.num_attention_heads = min([num_heads, args.target_num_heads])
    bert_model.config.intermediate_size = layers._modules['0'].intermediate.dense.weight.size(0)
def train_epoch(model, config, train_loader, val_loader, epoch_i):
    optimizer = config['optimizer']
    scheduler = config['scheduler']
    writer = config['writer']
    scaler = config['scaler']
    opt = config['opt']
    if opt.criterion == 'MSELoss':
        criterion = torch.nn.MSELoss(reduction='sum').to(opt.device)
    else:
        criterion = torch.nn.CrossEntropyLoss().to(opt.device)
    # train one epoch
    model.train()
    total_loss = 0.
    final_val_loss = 0.
    total_examples = 0
    st_time = time.time()
    optimizer.zero_grad()
    for local_step, (x, y) in tqdm(enumerate(train_loader), total=len(train_loader)):
        global_step = (len(train_loader) * epoch_i) + local_step
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        with autocast(enabled=opt.use_amp):
            if opt.use_profiler:
                with profiler.profile(profile_memory=True, record_shapes=True) as prof:
                    output = model(x)
                print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
            else:
                output = model(x)
            loss = criterion(output, y)
            if opt.gradient_accumulation_steps > 1:
                loss = loss / opt.gradient_accumulation_steps
        # back-propagation - begin
        scaler.scale(loss).backward()
        if (local_step + 1) % opt.gradient_accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if opt.use_transformers_optimizer:
                scheduler.step()
        # back-propagation - end
        cur_examples = y.size(0)
        total_examples += cur_examples
        total_loss += (loss.item() * cur_examples)
        if writer:
            writer.add_scalar('Loss/train', loss.item(), global_step)
    cur_loss = total_loss / total_examples
    # evaluate
    eval_loss, eval_acc = evaluate(model, config, val_loader)
    curr_time = time.time()
    elapsed_time = (curr_time - st_time) / 60
    st_time = curr_time
    curr_lr = scheduler.get_last_lr()[0] if scheduler else optimizer.param_groups[0]['lr']
    logger.info('{:3d} epoch | {:5d}/{:5d} | train loss : {:6.3f}, valid loss {:6.3f}, valid acc {:.4f}| lr :{:7.6f} | {:5.2f} min elapsed'
                .format(epoch_i, local_step + 1, len(train_loader), cur_loss, eval_loss, eval_acc, curr_lr, elapsed_time))
    if writer:
        writer.add_scalar('Loss/valid', eval_loss, global_step)
        writer.add_scalar('Acc/valid', eval_acc, global_step)
        writer.add_scalar('LearningRate/train', curr_lr, global_step)
    return eval_loss, eval_acc
def evaluate(args):
    # set config
    config = load_config(args)
    if args.num_threads > 0:
        torch.set_num_threads(args.num_threads)
    config['args'] = args
    logger.info("%s", config)
    # set path
    set_path(config)
    # prepare test dataset
    test_loader = prepare_datasets(config)
    # load pytorch model checkpoint
    checkpoint = load_checkpoint(args.model_path, device=args.device)
    # prepare model and load parameters
    model = load_model(config, checkpoint)
    model.eval()
    # convert to onnx format
    if args.convert_onnx:
        # FIXME not working for --use_crf
        batch = next(iter(test_loader))
        if config['emb_class'] not in ['glove', 'elmo']:
            x, y, gy = batch
        else:
            x, y = batch
        x = to_device(x, args.device)
        convert_onnx(config, model, x)
        check_onnx(config)
        logger.info("[ONNX model saved] : {}".format(args.onnx_path))
        # quantize onnx
        if args.quantize_onnx:
            quantize_onnx(args.onnx_path, args.quantized_onnx_path)
            logger.info("[Quantized ONNX model saved] : {}".format(args.quantized_onnx_path))
        return
    # load onnx model for using onnxruntime
    if args.enable_ort:
        import onnxruntime as ort
        sess_options = ort.SessionOptions()
        sess_options.inter_op_num_threads = args.num_threads
        sess_options.intra_op_num_threads = args.num_threads
        ort_session = ort.InferenceSession(args.onnx_path, sess_options=sess_options)
    # enable to use dynamic quantized model (pytorch>=1.3.0)
    if args.enable_dqm:
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
        print(model)
    # evaluation
    preds = None
    ys = None
    gpreds = None
    gys = None
    n_batches = len(test_loader)
    total_examples = 0
    whole_st_time = time.time()
    first_time = time.time()
    first_examples = 0
    total_duration_time = 0.0
    with torch.no_grad():
        for i, batch in enumerate(tqdm(test_loader, total=n_batches)):
            start_time = time.time()
            if config['emb_class'] not in ['glove', 'elmo']:
                x, y, gy = batch
                gy = to_device(gy, args.device)
            else:
                x, y = batch
            x = to_device(x, args.device)
            y = to_device(y, args.device)
            if args.enable_ort:
                ort_inputs = build_onnx_input(config, ort_session, x)
                if args.use_crf:
                    # FIXME not working for --use_crf
                    if args.bert_use_mtl:
                        logits, prediction, glogits = ort_session.run(None, ort_inputs)
                        glogits = to_device(torch.tensor(glogits), args.device)
                    else:
                        logits, prediction = ort_session.run(None, ort_inputs)
                    prediction = to_device(torch.tensor(prediction), args.device)
                    logits = to_device(torch.tensor(logits), args.device)
                else:
                    if args.bert_use_mtl:
                        logits, glogits = ort_session.run(None, ort_inputs)
                        glogits = to_device(torch.tensor(glogits), args.device)
                    else:
                        logits = ort_session.run(None, ort_inputs)[0]
                    logits = to_device(torch.tensor(logits), args.device)
                    logits = torch.softmax(logits, dim=-1)
            else:
                if args.use_crf:
                    if args.bert_use_mtl:
                        logits, prediction, glogits = model(x)
                    else:
                        logits, prediction = model(x)
                else:
                    if args.bert_use_mtl:
                        logits, glogits = model(x)
                    else:
                        logits = model(x)
                    logits = torch.softmax(logits, dim=-1)
            if preds is None:
                if args.use_crf:
                    preds = to_numpy(prediction)
                else:
                    preds = to_numpy(logits)
                ys = to_numpy(y)
            else:
                if args.use_crf:
                    preds = np.append(preds, to_numpy(prediction), axis=0)
                else:
                    preds = np.append(preds, to_numpy(logits), axis=0)
                ys = np.append(ys, to_numpy(y), axis=0)
            if args.bert_use_mtl:
                glogits = torch.softmax(glogits, dim=-1)
                if gpreds is None:
                    gpreds = to_numpy(glogits)
                    gys = to_numpy(gy)
                else:
                    gpreds = np.append(gpreds, to_numpy(glogits), axis=0)
                    gys = np.append(gys, to_numpy(gy), axis=0)
            cur_examples = y.size(0)
            total_examples += cur_examples
            if i == 0:  # first one may take longer time, so ignore in computing duration.
                first_time = float((time.time() - first_time) * 1000)
                first_examples = cur_examples
            if args.num_examples != 0 and total_examples >= args.num_examples:
                logger.info("[Stop Evaluation] : up to the {} examples".format(total_examples))
                break
            duration_time = float((time.time() - start_time) * 1000)
            if i != 0:
                total_duration_time += duration_time
            '''
            logger.info("[Elapsed Time] : {}ms".format(duration_time))
            '''
    whole_time = float((time.time() - whole_st_time) * 1000)
    avg_time = (whole_time - first_time) / (total_examples - first_examples)
    # generate report for token classification
    if not args.use_crf:
        preds = np.argmax(preds, axis=2)
    # compute measure using seqeval
    labels = config['labels']
    ys_lbs = [[] for _ in range(ys.shape[0])]
    preds_lbs = [[] for _ in range(ys.shape[0])]
    pad_label_id = config['pad_label_id']
    for i in range(ys.shape[0]):      # foreach sentence
        for j in range(ys.shape[1]):  # foreach token
            if ys[i][j] != pad_label_id:
                ys_lbs[i].append(labels[ys[i][j]])
                preds_lbs[i].append(labels[preds[i][j]])
    ret = {
        "precision": precision_score(ys_lbs, preds_lbs),
        "recall": recall_score(ys_lbs, preds_lbs),
        "f1": f1_score(ys_lbs, preds_lbs),
        "report": classification_report(ys_lbs, preds_lbs, digits=4),
    }
    print(ret['report'])
    # write predicted labels to file
    write_prediction(config, model, ys, preds, labels)
    # generate report for sequence classification
    if args.bert_use_mtl:
        glabels = config['glabels']
        glabel_names = [v for k, v in sorted(glabels.items(), key=lambda x: x[0])]
        glabel_ids = [k for k in glabels.keys()]
        gpreds_ids = np.argmax(gpreds, axis=1)
        try:
            g_report = sequence_classification_report(gys, gpreds_ids,
                                                      target_names=glabel_names,
                                                      labels=glabel_ids,
                                                      digits=4)
            g_report_dict = sequence_classification_report(gys, gpreds_ids,
                                                           target_names=glabel_names,
                                                           labels=glabel_ids,
                                                           output_dict=True)
            g_matrix = confusion_matrix(gys, gpreds_ids)
            ret['g_report'] = g_report
            ret['g_report_dict'] = g_report_dict
            ret['g_f1'] = g_report_dict['micro avg']['f1-score']
            ret['g_matrix'] = g_matrix
        except Exception as e:
            logger.warn(str(e))
        print(ret['g_report'])
        print(ret['g_f1'])
        print(ret['g_matrix'])
        logger.info("[sequence classification F1] : {}, {}".format(ret['g_f1'], total_examples))
        # write predicted glabels to file
        write_gprediction(args, gpreds, glabels)
    logger.info("[token classification F1] : {}, {}".format(ret['f1'], total_examples))
    logger.info("[Elapsed Time] : {} examples, {}ms, {}ms on average".format(total_examples, whole_time, avg_time))
    logger.info("[Elapsed Time(total_duration_time, average)] : {}ms, {}ms".format(
        total_duration_time, total_duration_time / (total_examples - 1)))
                        num_workers=0,
                        drop_last=True)
writer = tensorboard.SummaryWriter(log_dir=f'test/{logs_idx}')
if len(saves) != 0:
    saves.sort(key=os.path.getmtime)
    checkpoint = torch.load(saves[-1])
    text_embedding.load_state_dict(checkpoint['text_embedding'])
    text_embedding.eval()
    encoder.load_state_dict(checkpoint['encoder'])
    encoder.eval()
    decoder.load_state_dict(checkpoint['decoder'])
    decoder.eval()
for idx, batch in enumerate(
        tqdm(BackgroundGenerator(dataloader), total=len(dataloader))):
    text_data, text_pos, text_mask, mel_data, mel_pos, mel_mask, gate = to_device(
        batch, device)
    # audio_data = F.avg_pool1d(audio_data, kernel_size=2, padding=1)
    with torch.no_grad():
        text_emb = text_embedding(text_data)
        text_pos_emb = pos_embedding_(text_pos)
        enc_out, att_heads_enc = encoder(text_emb, text_mask, text_pos_emb)
        mel_pos = torch.arange(1, 512).view(1, 511).to(device)
        mel_pos_emb_ = pos_embedding(mel_pos)
        mel_mask_ = torch.triu(torch.ones(511, 511, dtype=torch.bool), 1).unsqueeze(0).to(device)
        # [B, T, C], [B, T, C], [B, T, 1], [B, T, T_text]
        mel = torch.zeros(1, 511, 80).to(device)
        for pos_idx in tqdm(range(511)):
            mel_pos_emb = mel_pos_emb_[:, :pos_idx + 1]
            mel_mask = mel_mask_[:, :pos_idx + 1, :pos_idx + 1]
def train_epoch(model, config, train_loader, val_loader, epoch_i):
    opt = config['opt']
    optimizer = config['optimizer']
    scheduler = config['scheduler']
    writer = config['writer']
    scaler = config['scaler']
    pad_label_id = config['pad_label_id']
    criterion = nn.CrossEntropyLoss(ignore_index=pad_label_id).to(opt.device)
    n_batches = len(train_loader)
    prog = Progbar(target=n_batches)
    # train one epoch
    model.train()
    train_loss = 0.
    st_time = time.time()
    optimizer.zero_grad()
    for local_step, (x, y) in enumerate(train_loader):
        global_step = (len(train_loader) * epoch_i) + local_step
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        if opt.use_crf:
            with autocast(enabled=opt.use_amp):
                if opt.use_profiler:
                    with profiler.profile(profile_memory=True, record_shapes=True) as prof:
                        logits, prediction = model(x)
                    print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
                else:
                    logits, prediction = model(x)
                mask = torch.sign(torch.abs(x[0])).to(torch.uint8).to(opt.device)
                log_likelihood = model.crf(logits, y, mask=mask, reduction='mean')
                loss = -1 * log_likelihood
                if opt.gradient_accumulation_steps > 1:
                    loss = loss / opt.gradient_accumulation_steps
        else:
            with autocast(enabled=opt.use_amp):
                if opt.use_profiler:
                    with profiler.profile(profile_memory=True, record_shapes=True) as prof:
                        logits = model(x)
                    print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
                else:
                    logits = model(x)
                # reshape for computing loss
                logits_view = logits.view(-1, model.label_size)
                y_view = y.view(-1)
                loss = criterion(logits_view, y_view)
                if opt.gradient_accumulation_steps > 1:
                    loss = loss / opt.gradient_accumulation_steps
        # back-propagation - begin
        scaler.scale(loss).backward()
        if (local_step + 1) % opt.gradient_accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if opt.use_transformers_optimizer:
                scheduler.step()
        # back-propagation - end
        train_loss += loss.item()
        if writer:
            writer.add_scalar('Loss/train', loss.item(), global_step)
        curr_lr = scheduler.get_last_lr()[0] if scheduler else optimizer.param_groups[0]['lr']
        prog.update(local_step + 1,
                    [('global step', global_step),
                     ('train curr loss', loss.item()),
                     ('lr', curr_lr)])
    train_loss = train_loss / n_batches
    # evaluate
    eval_ret = evaluate(model, config, val_loader)
    eval_loss = eval_ret['loss']
    eval_f1 = eval_ret['f1']
    curr_time = time.time()
    elapsed_time = (curr_time - st_time) / 60
    st_time = curr_time
    curr_lr = scheduler.get_last_lr()[0] if scheduler else optimizer.param_groups[0]['lr']
    logger.info('{:3d} epoch | {:5d}/{:5d} | train loss : {:10.6f}, valid loss {:10.6f}, valid f1 {:.4f}| lr :{:7.6f} | {:5.2f} min elapsed'
                .format(epoch_i, local_step + 1, len(train_loader), train_loss, eval_loss, eval_f1, curr_lr, elapsed_time))
    if writer:
        writer.add_scalar('Loss/valid', eval_loss, global_step)
        writer.add_scalar('F1/valid', eval_f1, global_step)
        writer.add_scalar('LearningRate/train', curr_lr, global_step)
    return eval_loss, eval_f1
def train_eprop3(model, optimizer, loss_function, batch_x, batch_y, truncation_delta):
    seq_len = batch_x.shape[1]
    loss_function_synth = nn.MSELoss()
    # reset model
    model.reset()
    # implementation of algo on page 33 of eprop paper
    for i, start in enumerate(range(truncation_delta, seq_len, truncation_delta)):
        # reset gradient
        optimizer.zero_grad()
        # select [t_{m-1}+1, ..., t_{m}]
        first_batch_x = batch_x[:, start - truncation_delta:start, :].clone()
        first_batch_y = batch_y[:, start - truncation_delta:start, :].clone()
        # simulate network over [t_{m-1}+1, ..., t_{m}] and backprop using the synthetic gradient
        prediction, _, first_synth_grad = model(first_batch_x)
        pred, gt = format_pred_and_gt(prediction, first_batch_y, model.single_output)
        loss = loss_function(pred, gt)
        loss.backward()
        # select [t_{m}+1, ..., t_{m+1}] (the next truncated time interval)
        second_batch_x = batch_x[:, start:start + truncation_delta, :].clone()
        second_batch_y = batch_y[:, start:start + truncation_delta, :].clone()
        # simulate and backprop using second synthetic gradient
        prediction, second_initial_state, second_synth_grad = model(second_batch_x)
        # retain grad of the initial hidden state of the second interval ...
        second_initial_state.retain_grad()
        pred, gt = format_pred_and_gt(prediction, second_batch_y, model.single_output)
        loss = loss_function(pred, gt)
        loss.backward()
        # ... and store it ...
        real_grad_x = second_initial_state.grad.detach()
        # ... to optimize the synth grad network using MSE
        loss = loss_function_synth(first_synth_grad, real_grad_x)
        loss.backward()
        real_grad_x_shape = real_grad_x.shape
        # train the final synthetic gradient to be close to 0
        if start + truncation_delta == seq_len:
            zeros = to_device(torch.zeros(real_grad_x_shape, requires_grad=False))
            loss = loss_function_synth(second_synth_grad, zeros)
            loss.backward()
        optimizer.step()
    with torch.no_grad():
        prediction, _, _ = model(batch_x)
        pred, gt = format_pred_and_gt(prediction, batch_y, model.single_output)
        loss = loss_function(pred, gt)
    return loss.item()
def train_epoch(model, config, train_loader, val_loader, epoch_i):
    optimizer = config['optimizer']
    scheduler = config['scheduler']
    writer = config['writer']
    opt = config['opt']
    local_rank = opt.local_rank
    use_amp = opt.use_amp
    if opt.criterion == 'MSELoss':
        criterion = torch.nn.MSELoss(reduction='sum').to(opt.device)
    else:
        criterion = torch.nn.CrossEntropyLoss().to(opt.device)
    # train one epoch
    model.train()
    optimizer.zero_grad()
    total_loss = 0.
    final_val_loss = 0.
    total_examples = 0
    st_time = time.time()
    for local_step, (x, y) in tqdm(enumerate(train_loader), total=len(train_loader)):
        global_step = (len(train_loader) * epoch_i) + local_step
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        output = model(x)
        loss = criterion(output, y)
        # back-propagation - begin
        if opt.gradient_accumulation_steps > 1:
            loss = loss / opt.gradient_accumulation_steps
        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                try:
                    scaled_loss.backward()
                except Exception as e:
                    print(e)
        else:
            loss.backward()
        if (local_step + 1) % opt.gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)
            optimizer.step()
            if opt.use_transformers_optimizer:
                scheduler.step()
            optimizer.zero_grad()
        # back-propagation - end
        cur_examples = y.size(0)
        total_examples += cur_examples
        total_loss += (loss.item() * cur_examples)
        if local_rank == 0 and writer:
            writer.add_scalar('Loss/train', loss.item(), global_step)
    cur_loss = total_loss / total_examples
    # evaluate
    eval_loss, eval_acc = evaluate(model, config, val_loader)
    curr_time = time.time()
    elapsed_time = (curr_time - st_time) / 60
    st_time = curr_time
    curr_lr = scheduler.get_lr()[0] if scheduler else optimizer.param_groups[0]['lr']
    if local_rank == 0:
        logger.info('{:3d} epoch | {:5d}/{:5d} | train loss : {:6.3f}, valid loss {:6.3f}, valid acc {:.4f}| lr :{:7.6f} | {:5.2f} min elapsed'
                    .format(epoch_i, local_step + 1, len(train_loader), cur_loss, eval_loss, eval_acc, curr_lr, elapsed_time))
        if writer:
            writer.add_scalar('Loss/valid', eval_loss, global_step)
            writer.add_scalar('Acc/valid', eval_acc, global_step)
            writer.add_scalar('LearningRate/train', curr_lr, global_step)
    return eval_loss, eval_acc
def distill(teacher_config, teacher_model,
            student_config, student_model,
            train_loader, eval_loader,
            best_eval_metric=None, mpl_loader=None):
    args = teacher_config['opt']
    teacher_layer_num = teacher_model.bert_model.config.num_hidden_layers
    student_layer_num = student_model.bert_model.config.num_hidden_layers
    # create teacher optimizer with larger L2 norm
    teacher_optimizer, _, _, _ = prepare_osws(teacher_config, teacher_model, train_loader,
                                              lr=args.mpl_learning_rate,
                                              weight_decay=args.mpl_weight_decay)
    # create student optimizer, scheduler, summary writer
    student_optimizer, student_scheduler, writer, _ = prepare_osws(student_config, student_model, train_loader,
                                                                   lr=args.lr,
                                                                   weight_decay=args.weight_decay)

    # prepare loss functions
    def soft_cross_entropy(predicts, targets):
        likelihood = F.log_softmax(predicts, dim=-1)
        targets_prob = F.softmax(targets, dim=-1)
        return (-targets_prob * likelihood).sum(dim=-1).mean()

    loss_mse_sum = MSELoss(reduction='sum').to(args.device)
    loss_mse = MSELoss().to(args.device)
    loss_cs = CosineSimilarity(dim=2).to(args.device)
    loss_cs_att = CosineSimilarity(dim=3).to(args.device)

    logger.info("***** Running distillation training *****")
    logger.info("  Num Batches = %d", len(train_loader))
    logger.info("  Num Epochs = %d", args.epoch)
    logger.info("  batch size = %d", args.batch_size)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    tr_loss, logging_loss = 0.0, 0.0
    tr_att_loss = 0.
    tr_rep_loss = 0.
    tr_cls_loss = 0.
    teacher_model.zero_grad()
    student_model.zero_grad()
    epoch_iterator = range(epochs_trained, int(args.epoch))
    # for reproducibility
    set_seed(args)
    for epoch_n in epoch_iterator:
        tr_att_loss = 0.
        tr_rep_loss = 0.
        tr_cls_loss = 0.
        train_iterator = tqdm(train_loader, desc=f"Epoch {epoch_n}")
        for step, (x, y) in enumerate(train_iterator):
            x = to_device(x, args.device)
            y = to_device(y, args.device)
            # ---------------------------------------------------------------------------------------
            # teacher -> student, teaching with teacher_model.eval(), student_model.train()
            # ---------------------------------------------------------------------------------------
            att_loss = 0.
            rep_loss = 0.
            cls_loss = 0.
            # teacher model output
            teacher_model.eval()
            with torch.no_grad():
                output_teacher, teacher_bert_outputs = teacher_model(x, return_bert_outputs=True)
            # student model output
            student_model.train()
            output_student, student_bert_outputs = student_model(x, return_bert_outputs=True)
            # Knowledge Distillation loss
            # 1) logits distillation
            '''
            kd_loss = soft_cross_entropy(output_student, output_teacher)
            '''
            kd_loss = loss_mse_sum(output_student, output_teacher)
            loss = kd_loss
            tr_cls_loss += loss.item()
            # 2) embedding and last hidden state distillation
            if args.state_loss_ratio > 0.0:
                teacher_reps = teacher_bert_outputs.hidden_states
                student_reps = student_bert_outputs.hidden_states
                new_teacher_reps = [teacher_reps[0], teacher_reps[teacher_layer_num]]
                new_student_reps = [student_reps[0], student_reps[student_layer_num]]
                for student_rep, teacher_rep in zip(new_student_reps, new_teacher_reps):
                    # cosine similarity loss
                    if args.state_distill_cs:
                        tmp_loss = 1.0 - loss_cs(student_rep, teacher_rep).mean()
                    # MSE loss
                    else:
                        tmp_loss = loss_mse(student_rep, teacher_rep)
                    rep_loss += tmp_loss
                loss += args.state_loss_ratio * rep_loss
                tr_rep_loss += rep_loss.item()
            # 3) Attentions distillation
            if args.att_loss_ratio > 0.0:
                teacher_atts = teacher_bert_outputs.attentions
                student_atts = student_bert_outputs.attentions
                assert teacher_layer_num == len(teacher_atts)
                assert student_layer_num == len(student_atts)
                assert teacher_layer_num % student_layer_num == 0
                layers_per_block = int(teacher_layer_num / student_layer_num)
                new_teacher_atts = [teacher_atts[i * layers_per_block + layers_per_block - 1]
                                    for i in range(student_layer_num)]
                for student_att, teacher_att in zip(student_atts, new_teacher_atts):
                    student_att = torch.where(student_att <= -1e2,
                                              torch.zeros_like(student_att).to(args.device),
                                              student_att)
                    teacher_att = torch.where(teacher_att <= -1e2,
                                              torch.zeros_like(teacher_att).to(args.device),
                                              teacher_att)
                    tmp_loss = 1.0 - loss_cs_att(student_att, teacher_att).mean()
                    att_loss += tmp_loss
                loss += args.att_loss_ratio * att_loss
                tr_att_loss += att_loss.item()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            # back propagate through student model
            loss.backward()
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(student_model.parameters(), args.max_grad_norm)
                student_optimizer.step()   # update student model
                student_scheduler.step()   # Update learning rate schedule
                student_model.zero_grad()
                global_step += 1
            # ---------------------------------------------------------------------------------------

            # ---------------------------------------------------------------------------------------
            # student -> teacher, performance feedback/update with student_model.eval(), teacher_model.train()
            # ---------------------------------------------------------------------------------------
            mpl_loss = 0.0
            if mpl_loader and global_step > args.mpl_warmup_steps:
                loss_cross_entropy = torch.nn.CrossEntropyLoss().to(args.device)
                mpl_iterator = iter(mpl_loader)
                try:
                    (x, y) = next(mpl_iterator)  # draw random sample
                except StopIteration as e:
                    mpl_iterator = iter(mpl_loader)
                    (x, y) = next(mpl_iterator)  # draw random sample
                x = to_device(x, args.device)
                y = to_device(y, args.device)
                # teacher model output
                teacher_model.train()
                output_teacher, teacher_bert_outputs = teacher_model(x, return_bert_outputs=True)
                # student model output
                student_model.eval()  # updated student model
                output_student, student_bert_outputs = student_model(x, return_bert_outputs=True)
                # the loss is the performance of the student on the labeled data.
                # additionally, we add the loss of the teacher on the labeled data for avoiding overfitting.
                mpl_loss = loss_cross_entropy(output_student, y) / 2 + loss_cross_entropy(output_teacher, y) / 2
                if args.gradient_accumulation_steps > 1:
                    mpl_loss = mpl_loss / args.gradient_accumulation_steps
                # back propagate through teacher model
                mpl_loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(teacher_model.parameters(), args.max_grad_norm)
                    teacher_optimizer.step()  # update teacher model
                    teacher_model.zero_grad()
                    student_model.zero_grad()  # clear gradient info which was generated during forward computation.
            # ---------------------------------------------------------------------------------------

            train_iterator.set_description(f"Epoch {epoch_n} loss: {loss:.3f}, mpl loss: {mpl_loss:.3f}")
            if writer:
                writer.add_scalar('loss', loss, global_step)
                writer.add_scalar('mpl_loss', mpl_loss, global_step)

            # ---------------------------------------------------------------------------------------
            # evaluate student, save model
            # ---------------------------------------------------------------------------------------
            flag_eval = False
            logs = {}
            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                flag_eval = True
            if flag_eval:
                if args.log_evaluate_during_training:
                    eval_loss, eval_acc = evaluate(student_model, student_config, eval_loader)
                    logs['eval_loss'] = eval_loss
                    logs['eval_acc'] = eval_acc
                    if writer:
                        writer.add_scalar('eval_loss', eval_loss, global_step)
                        writer.add_scalar('eval_acc', eval_acc, global_step)
                cls_loss = tr_cls_loss / (step + 1)
                att_loss = tr_att_loss / (step + 1)
                rep_loss = tr_rep_loss / (step + 1)
                loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                learning_rate_scalar = student_scheduler.get_last_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                logs["avg_loss_since_last_log"] = loss_scalar
                logs['cls_loss'] = cls_loss
                logs['att_loss'] = att_loss
                logs['rep_loss'] = rep_loss
                logging_loss = tr_loss
                logging.info(json.dumps({**logs, **{"step": global_step}}))
                if writer:
                    writer.add_scalar('learning_rate', learning_rate_scalar, global_step)
                    writer.add_scalar('avg_loss_since_last_log', loss_scalar, global_step)
                    writer.add_scalar('cls_loss', cls_loss, global_step)
                    writer.add_scalar('att_loss', att_loss, global_step)
                    writer.add_scalar('rep_loss', rep_loss, global_step)
            flag_eval = False
            if step == 0 and epoch_n != 0:
                flag_eval = True  # every epoch
            if args.eval_and_save_steps > 0 and global_step % args.eval_and_save_steps == 0:
                flag_eval = True
            if flag_eval:
                eval_loss, eval_acc = evaluate(student_model, student_config, eval_loader)
                logs['eval_loss'] = eval_loss
                logs['eval_acc'] = eval_acc
                logger.info(json.dumps({**logs, **{"step": global_step}}))
                if writer:
                    writer.add_scalar('eval_loss', eval_loss, global_step)
                    writer.add_scalar('eval_acc', eval_acc, global_step)
                # measured by accuracy
                curr_eval_metric = eval_acc
                if best_eval_metric is None or curr_eval_metric > best_eval_metric:
                    # save model to '--save_path', '--bert_output_dir'
                    save_model(student_config, student_model, save_path=args.save_path)
                    student_model.bert_tokenizer.save_pretrained(args.bert_output_dir)
                    student_model.bert_model.save_pretrained(args.bert_output_dir)
                    best_eval_metric = curr_eval_metric
                    logger.info("[Best student model saved] : {:10.6f}, {}, {}".format(
                        best_eval_metric, args.bert_output_dir, args.save_path))
            # ---------------------------------------------------------------------------------------
    return global_step, tr_loss / global_step, best_eval_metric
def evaluate(opt):
    # set config
    config = load_config(opt)
    if opt.num_threads > 0:
        torch.set_num_threads(opt.num_threads)
    config['opt'] = opt
    logger.info("%s", config)
    # set path
    set_path(config)
    # prepare test dataset
    test_loader = prepare_datasets(config)
    # load pytorch model checkpoint
    checkpoint = load_checkpoint(opt.model_path, device=opt.device)
    # prepare model and load parameters
    model = load_model(config, checkpoint)
    model.eval()
    # convert to onnx
    if opt.convert_onnx:
        (x, y) = next(iter(test_loader))
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        convert_onnx(config, model, x)
        check_onnx(config)
        logger.info("[ONNX model saved] : {}".format(opt.onnx_path))
        # quantize onnx
        if opt.quantize_onnx:
            quantize_onnx(opt.onnx_path, opt.quantized_onnx_path)
            logger.info("[Quantized ONNX model saved] : {}".format(opt.quantized_onnx_path))
        return
    # load onnx model for using onnxruntime
    if opt.enable_ort:
        import onnxruntime as ort
        sess_options = ort.SessionOptions()
        sess_options.inter_op_num_threads = opt.num_threads
        sess_options.intra_op_num_threads = opt.num_threads
        ort_session = ort.InferenceSession(opt.onnx_path, sess_options=sess_options)
    # convert to tvm format
    if opt.convert_tvm:
        (x, y) = next(iter(test_loader))
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        convert_tvm(config, model, x)
        logger.info("[TVM model saved] : {}".format(opt.tvm_path))
        return
    # enable to use dynamic quantized model (pytorch>=1.3.0)
    if opt.enable_dqm and opt.device == 'cpu':
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
        print(model)
    # evaluation
    preds = None
    ys = None
    correct = 0
    n_batches = len(test_loader)
    total_examples = 0
    whole_st_time = time.time()
    first_time = time.time()
    first_examples = 0
    total_duration_time = 0.0
    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(test_loader, total=n_batches)):
            start_time = time.time()
            x = to_device(x, opt.device)
            y = to_device(y, opt.device)
            if opt.enable_ort:
                x = to_numpy(x)
                if config['emb_class'] == 'glove':
                    ort_inputs = {ort_session.get_inputs()[0].name: x}
                else:
                    if config['emb_class'] in ['roberta', 'distilbert', 'bart']:
                        ort_inputs = {
                            ort_session.get_inputs()[0].name: x[0],
                            ort_session.get_inputs()[1].name: x[1]
                        }
                    else:
                        ort_inputs = {
                            ort_session.get_inputs()[0].name: x[0],
                            ort_session.get_inputs()[1].name: x[1],
                            ort_session.get_inputs()[2].name: x[2]
                        }
                logits = ort_session.run(None, ort_inputs)[0]
                logits = to_device(torch.tensor(logits), opt.device)
            else:
                logits = model(x)
            if preds is None:
                preds = to_numpy(logits)
                ys = to_numpy(y)
            else:
                preds = np.append(preds, to_numpy(logits), axis=0)
                ys = np.append(ys, to_numpy(y), axis=0)
            predicted = logits.argmax(1)
            correct += (predicted == y).sum().item()
            cur_examples = y.size(0)
            total_examples += cur_examples
            if i == 0:  # first one may take longer time, so ignore in computing duration.
                first_time = float((time.time() - first_time) * 1000)
                first_examples = cur_examples
            if opt.num_examples != 0 and total_examples >= opt.num_examples:
                logger.info("[Stop Evaluation] : up to the {} examples".format(total_examples))
                break
            duration_time = float((time.time() - start_time) * 1000)
            if i != 0:
                total_duration_time += duration_time
            '''
            logger.info("[Elapsed Time] : {}ms".format(duration_time))
            '''
    # generate report
    labels = config['labels']
    label_names = [v for k, v in sorted(labels.items(), key=lambda x: x[0])]
    preds_ids = np.argmax(preds, axis=1)
    try:
        print(classification_report(ys, preds_ids, target_names=label_names, digits=4))
        print(labels)
        print(confusion_matrix(ys, preds_ids))
    except Exception as e:
        logger.warn(str(e))
    acc = correct / total_examples
    whole_time = float((time.time() - whole_st_time) * 1000)
    avg_time = (whole_time - first_time) / (total_examples - first_examples)
    # write predictions to file
    write_prediction(opt, preds, labels)
    logger.info("[Accuracy] : {:.4f}, {:5d}/{:5d}".format(acc, correct, total_examples))
    logger.info("[Elapsed Time] : {}ms, {}ms on average".format(whole_time, avg_time))
    logger.info("[Elapsed Time(total_duration_time, average)] : {}ms, {}ms".format(
        total_duration_time, total_duration_time / (total_examples - 1)))
def main():
    if config.gpu and not torch.cuda.is_available():
        raise ValueError("GPU not supported or enabled on this system.")
    use_gpu = config.gpu

    log.info("Loading train dataset")
    train_dataset = COVIDxFolder(config.train_imgs, config.train_labels,
                                 transforms.train_transforms(config.width, config.height))
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=config.n_threads,
                              pin_memory=use_gpu)
    log.info("Number of training examples {}".format(len(train_dataset)))

    log.info("Loading val dataset")
    val_dataset = COVIDxFolder(config.val_imgs, config.val_labels,
                               transforms.val_transforms(config.width, config.height))
    val_loader = DataLoader(val_dataset,
                            batch_size=config.batch_size,
                            shuffle=False,
                            num_workers=config.n_threads,
                            pin_memory=use_gpu)
    log.info("Number of validation examples {}".format(len(val_dataset)))

    if config.weights:
        state = torch.load(config.weights)
        log.info("Loaded model weights from: {}".format(config.weights))
    else:
        state = None

    state_dict = state["state_dict"] if state else None
    model = architecture.COVIDNext50(n_classes=config.n_classes)
    if state_dict:
        model = util.load_model_weights(model=model, state_dict=state_dict)

    if use_gpu:
        model.cuda()
        model = torch.nn.DataParallel(model)
    optim_layers = filter(lambda p: p.requires_grad, model.parameters())

    # optimizer and lr scheduler
    optimizer = Adam(optim_layers, lr=config.lr, weight_decay=config.weight_decay)
    scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                  factor=config.lr_reduce_factor,
                                  patience=config.lr_reduce_patience,
                                  mode='max',
                                  min_lr=1e-7)

    # Load the last global_step from the checkpoint if existing
    global_step = 0 if state is None else state['global_step'] + 1

    class_weights = util.to_device(torch.FloatTensor(config.loss_weights), gpu=use_gpu)
    loss_fn = CrossEntropyLoss()

    # Reset the best metric score
    best_score = -1
    for epoch in range(config.epochs):
        log.info("Started epoch {}/{}".format(epoch + 1, config.epochs))
        for data in train_loader:
            imgs, labels = data
            imgs = util.to_device(imgs, gpu=use_gpu)
            labels = util.to_device(labels, gpu=use_gpu)

            logits = model(imgs)
            loss = loss_fn(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if global_step % config.log_steps == 0 and global_step > 0:
                probs = model.module.probability(logits)
                preds = torch.argmax(probs, dim=1).detach().cpu().numpy()
                labels = labels.cpu().detach().numpy()
                acc, f1, _, _ = util.clf_metrics(preds, labels)
                lr = util.get_learning_rate(optimizer)
                log.info("Step {} | TRAINING batch: Loss {:.4f} | F1 {:.4f} | "
                         "Accuracy {:.4f} | LR {:.2e}".format(global_step, loss.item(), f1, acc, lr))

            if global_step % config.eval_steps == 0 and global_step > 0:
                best_score = validate(val_loader,
                                      model,
                                      best_score=best_score,
                                      global_step=global_step,
                                      cfg=config)
                scheduler.step(best_score)
            global_step += 1
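The `validate()` helper called above is not shown in this excerpt. A plausible sketch, assuming it scores the model on the validation loader with `util.clf_metrics` and checkpoints when F1 improves; the function body, checkpoint layout, and file name are assumptions for illustration, not the actual implementation:

def validate(data_loader, model, best_score, global_step, cfg):
    # hypothetical sketch of the helper referenced in main(); assumes the same
    # module-level imports (torch, util, log) as the training script
    model.eval()
    preds, gts = [], []
    with torch.no_grad():
        for imgs, labels in data_loader:
            imgs = util.to_device(imgs, gpu=cfg.gpu)
            logits = model(imgs)
            # model is wrapped in DataParallel only when running on GPU (see main())
            probs = model.module.probability(logits) if cfg.gpu else model.probability(logits)
            preds.append(torch.argmax(probs, dim=1).cpu())
            gts.append(labels.cpu())
    preds = torch.cat(preds).numpy()
    gts = torch.cat(gts).numpy()
    acc, f1, _, _ = util.clf_metrics(preds, gts)
    log.info("VALIDATION | Accuracy {:.4f} | F1 {:.4f}".format(acc, f1))
    if f1 > best_score:
        best_score = f1
        # checkpoint fields and file name are assumptions
        torch.save({"state_dict": model.state_dict(), "global_step": global_step},
                   "best_model_step_{}.pth".format(global_step))
    model.train()
    return best_score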
def evaluate(opt):
    # set config
    config = load_config(opt)
    if opt.num_threads > 0:
        torch.set_num_threads(opt.num_threads)
    config['opt'] = opt
    logger.info("%s", config)

    # set path
    set_path(config)

    # prepare test dataset
    test_loader = prepare_datasets(config)

    # load pytorch model checkpoint
    checkpoint = load_checkpoint(config)

    # prepare model and load parameters
    model = load_model(config, checkpoint)
    model.eval()

    # convert to onnx format
    if opt.convert_onnx:
        (x, y) = next(iter(test_loader))
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        convert_onnx(config, model, x)
        check_onnx(config)
        logger.info("[ONNX model saved at {}]".format(opt.onnx_path))
        # quantize onnx
        if opt.quantize_onnx:
            quantize_onnx(opt.onnx_path, opt.quantized_onnx_path)
            logger.info("[Quantized ONNX model saved at {}]".format(opt.quantized_onnx_path))
        return

    # load onnx model for using onnxruntime
    if opt.enable_ort:
        import onnxruntime as ort
        sess_options = ort.SessionOptions()
        sess_options.inter_op_num_threads = opt.num_threads
        sess_options.intra_op_num_threads = opt.num_threads
        ort_session = ort.InferenceSession(opt.onnx_path, sess_options=sess_options)

    # enable to use dynamic quantized model (pytorch>=1.3.0)
    if opt.enable_dqm and opt.device == 'cpu':
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
        print(model)

    # evaluation
    preds = None
    ys = None
    n_batches = len(test_loader)
    total_examples = 0
    whole_st_time = time.time()
    first_time = time.time()
    first_examples = 0
    total_duration_time = 0.0
    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(test_loader, total=n_batches)):
            start_time = time.time()
            x = to_device(x, opt.device)
            y = to_device(y, opt.device)
            if opt.enable_ort:
                x = to_numpy(x)
                if config['emb_class'] == 'glove':
                    ort_inputs = {ort_session.get_inputs()[0].name: x[0],
                                  ort_session.get_inputs()[1].name: x[1]}
                    if opt.use_char_cnn:
                        ort_inputs[ort_session.get_inputs()[2].name] = x[2]
                if config['emb_class'] in ['bert', 'distilbert', 'albert', 'roberta', 'bart', 'electra']:
                    if config['emb_class'] in ['distilbert', 'bart']:
                        ort_inputs = {ort_session.get_inputs()[0].name: x[0],
                                      ort_session.get_inputs()[1].name: x[1]}
                    else:
                        ort_inputs = {ort_session.get_inputs()[0].name: x[0],
                                      ort_session.get_inputs()[1].name: x[1],
                                      ort_session.get_inputs()[2].name: x[2]}
                        if opt.bert_use_pos:
                            ort_inputs[ort_session.get_inputs()[3].name] = x[3]
                if opt.use_crf:
                    logits, prediction = ort_session.run(None, ort_inputs)
                    prediction = to_device(torch.tensor(prediction), opt.device)
                    logits = to_device(torch.tensor(logits), opt.device)
                else:
                    logits = ort_session.run(None, ort_inputs)[0]
                    logits = to_device(torch.tensor(logits), opt.device)
            else:
                if opt.use_crf:
                    logits, prediction = model(x)
                else:
                    logits = model(x)
            if preds is None:
                if opt.use_crf:
                    preds = to_numpy(prediction)
                else:
                    preds = to_numpy(logits)
                ys = to_numpy(y)
            else:
                if opt.use_crf:
                    preds = np.append(preds, to_numpy(prediction), axis=0)
                else:
                    preds = np.append(preds, to_numpy(logits), axis=0)
                ys = np.append(ys, to_numpy(y), axis=0)
            cur_examples = y.size(0)
            total_examples += cur_examples
            if i == 0:
                # first one may take longer time, so ignore in computing duration.
                first_time = float((time.time() - first_time) * 1000)
                first_examples = cur_examples
            if opt.num_examples != 0 and total_examples >= opt.num_examples:
                logger.info("[Stop Evaluation] : up to the {} examples".format(total_examples))
                break
            duration_time = float((time.time() - start_time) * 1000)
            if i != 0:
                total_duration_time += duration_time
            '''
            logger.info("[Elapsed Time] : {}ms".format(duration_time))
            '''
    whole_time = float((time.time() - whole_st_time) * 1000)
    avg_time = (whole_time - first_time) / (total_examples - first_examples)

    if not opt.use_crf:
        preds = np.argmax(preds, axis=2)
    # compute measure using seqeval
    labels = model.labels
    ys_lbs = [[] for _ in range(ys.shape[0])]
    preds_lbs = [[] for _ in range(ys.shape[0])]
    pad_label_id = config['pad_label_id']
    for i in range(ys.shape[0]):      # foreach sentence
        for j in range(ys.shape[1]):  # foreach token
            if ys[i][j] != pad_label_id:
                ys_lbs[i].append(labels[ys[i][j]])
                preds_lbs[i].append(labels[preds[i][j]])
    ret = {
        "precision": precision_score(ys_lbs, preds_lbs),
        "recall": recall_score(ys_lbs, preds_lbs),
        "f1": f1_score(ys_lbs, preds_lbs),
        "report": classification_report(ys_lbs, preds_lbs, digits=4),
    }
    print(ret['report'])
    f1 = ret['f1']

    # write predicted labels to file
    default_label = config['default_label']
    write_prediction(opt, ys, preds, labels, pad_label_id, default_label)

    logger.info("[F1] : {}, {}".format(f1, total_examples))
    logger.info("[Elapsed Time] : {} examples, {}ms, {}ms on average".format(
        total_examples, whole_time, avg_time))
    logger.info("[Elapsed Time(total_duration_time, average)] : {}ms, {}ms".format(
        total_duration_time, total_duration_time / (total_examples - 1)))
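The span-level metrics above come from seqeval, applied to the per-sentence label lists built in `ys_lbs`/`preds_lbs`. A tiny standalone example of how seqeval scores such lists (the two BIO-tagged sentences are made up):

from seqeval.metrics import precision_score, recall_score, f1_score, classification_report

# two made-up sentences in BIO tagging, already stripped of pad positions
y_true = [['B-PER', 'I-PER', 'O', 'B-LOC'], ['O', 'B-ORG', 'O']]
y_pred = [['B-PER', 'I-PER', 'O', 'O'],     ['O', 'B-ORG', 'O']]

print(precision_score(y_true, y_pred))   # entity-level precision
print(recall_score(y_true, y_pred))      # entity-level recall
print(f1_score(y_true, y_pred))          # entity-level F1
print(classification_report(y_true, y_pred, digits=4))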
def forward(ctx, ev_w_ih_x, ev_w_hh_x, ev_b_x, forgetgate_x, input_data, hx, cx,
            weight_ih, weight_hh, bias_ih=None, bias_hh=None):
    # calculate gates ...
    gates = (torch.mm(input_data, weight_ih.t()) + bias_ih +
             torch.mm(hx, weight_hh.t()) + bias_hh)
    ingate, forgetgate_y, cellgate, outgate = gates.chunk(4, 1)

    # ... and gate activations
    ingate = torch.sigmoid(ingate)
    forgetgate_y = torch.sigmoid(forgetgate_y)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate_y * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    # TODO: calculate new eligibility vector and trace
    # There exist distinct eligibility traces and vectors for the following parts of the LSTM cell:
    # - input to hidden connections, hidden to hidden connections, bias
    # all for each:
    # - ... inputgate, forgetgate and cellgate
    # => overall 3 * 3 = 9 eligibility traces
    hidden_size = hy.size(1)
    batch_size = input_data.size(0)
    input_size = input_data.size(1)

    ingate = ingate.unsqueeze(2)
    forgetgate_y = forgetgate_y.unsqueeze(2)
    cellgate = cellgate.unsqueeze(2)
    outgate = outgate.unsqueeze(2)
    input_data = input_data.unsqueeze(1)
    hx = hx.unsqueeze(2)
    cx = cx.unsqueeze(2)
    forgetgate_x = forgetgate_x.repeat(1, 3, 1)
    ones = to_device(torch.ones(ingate.size()))

    # the new eligibility vectors ...
    ev_w_ih_y = ev_w_ih_x * forgetgate_x
    ev_w_hh_y = ev_w_hh_x * forgetgate_x
    ev_b_y = ev_b_x * forgetgate_x

    # ingate
    base = ingate * (ones - ingate) * cellgate
    ev_w_hh_y[:, :hidden_size, :] += base * hx
    ev_w_ih_y[:, :hidden_size, :] += base * input_data
    ev_b_y[:, :hidden_size, :] += base

    # forgetgate
    #base = forgetgate_y * (ones - forgetgate_y) * cellgate
    base = forgetgate_y * (ones - forgetgate_y) * cx
    ev_w_hh_y[:, hidden_size:(2 * hidden_size), :] += base * hx
    ev_w_ih_y[:, hidden_size:(2 * hidden_size), :] += base * input_data
    ev_b_y[:, hidden_size:(2 * hidden_size), :] += base

    # cellgate
    base = ingate * (ones - cellgate**2)
    ev_w_hh_y[:, (2 * hidden_size):(3 * hidden_size), :] += base * hx
    ev_w_ih_y[:, (2 * hidden_size):(3 * hidden_size), :] += base * input_data
    ev_b_y[:, (2 * hidden_size):(3 * hidden_size), :] += base

    # ... and eligibility traces
    et_w_ih_y = to_device(torch.zeros(batch_size, 4 * hidden_size, input_size, requires_grad=False))
    et_w_hh_y = to_device(torch.zeros(batch_size, 4 * hidden_size, hidden_size, requires_grad=False))
    et_b_y = to_device(torch.zeros(batch_size, 4 * hidden_size, 1, requires_grad=False))

    # calculate eligibility traces by multiplying the eligibility vectors with the outgate
    tmp_outgate = outgate.repeat(1, 3, 1)
    et_w_ih_y[:, :3 * hidden_size, :] = ev_w_ih_y * tmp_outgate
    et_w_hh_y[:, :3 * hidden_size, :] = ev_w_hh_y * tmp_outgate
    et_b_y[:, :3 * hidden_size, :] = ev_b_y * tmp_outgate

    # The gradient of the output gate is only dependent on the observable state
    # => just use normal gradient calculation of dE/dh * dh/dweight
    # => calculate second part of that equation now for input to hidden, hidden to hidden
    #    and bias connections and multiply in the backward pass
    base = outgate * (ones - outgate) * cy.unsqueeze(2)
    et_w_hh_y[:, (3 * hidden_size):(4 * hidden_size)] = base * hx
    et_w_ih_y[:, (3 * hidden_size):(4 * hidden_size)] = base * input_data
    et_b_y[:, (3 * hidden_size):(4 * hidden_size)] = base

    ctx.save_for_backward(et_w_ih_y, et_w_hh_y, et_b_y, weight_hh,
                          cx.clone().squeeze(), cy.clone().squeeze(),
                          outgate.squeeze(), ingate.squeeze(),
                          cellgate.squeeze(), forgetgate_y.squeeze())

    return hy, cy, ev_w_ih_y, ev_w_hh_y, ev_b_y, forgetgate_y
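The eligibility-vector updates above rely on broadcasting to form per-example outer products: `base` has shape (batch, hidden, 1) and `input_data` has shape (batch, 1, input) after the unsqueeze calls, so `base * input_data` yields a (batch, hidden, input) block of the input-to-hidden eligibility vector. A small shape check with arbitrary sizes:

import torch

batch, hidden, inp = 2, 5, 3                 # arbitrary sizes for illustration
base = torch.randn(batch, hidden, 1)         # e.g. ingate * (1 - ingate) * cellgate
input_data = torch.randn(batch, 1, inp)      # input after .unsqueeze(1)

# broadcasting gives a per-example outer product, matching the
# (batch, hidden, input) slice updated in the forward pass above
outer = base * input_data
print(outer.shape)  # torch.Size([2, 5, 3])

# equivalent explicit computation for one example: (hidden, 1) @ (1, input)
assert torch.allclose(outer[0], base[0] @ input_data[0])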
def inference(opt):
    # set config
    config = load_config(opt)
    if opt.num_threads > 0:
        torch.set_num_threads(opt.num_threads)
    config['opt'] = opt

    # set path: opt.embedding_path, opt.vocab_path, opt.label_path
    set_path(config)

    # load pytorch model checkpoint
    checkpoint = load_checkpoint(opt.model_path, device=opt.device)

    # prepare model and load parameters
    model = load_model(config, checkpoint)
    model.eval()

    # load onnx model for using onnxruntime
    if opt.enable_ort:
        import onnxruntime as ort
        sess_options = ort.SessionOptions()
        sess_options.inter_op_num_threads = opt.num_threads
        sess_options.intra_op_num_threads = opt.num_threads
        ort_session = ort.InferenceSession(opt.onnx_path, sess_options=sess_options)

    # enable to use dynamic quantized model (pytorch>=1.3.0)
    if opt.enable_dqm and opt.device == 'cpu':
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
        print(model)

    # prepare tokenizer
    tokenizer = prepare_tokenizer(config, model)

    # prepare labels
    labels = config['labels']

    # inference
    f_out = open(opt.test_path + '.inference', 'w', encoding='utf-8')
    total_examples = 0
    total_duration_time = 0.0
    with torch.no_grad(), open(opt.test_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            start_time = time.time()
            sent, label = line.strip().split('\t')
            x_raw = sent.split()
            y_raw = label
            text = ' '.join(x_raw)
            x = encode_text(config, tokenizer, text)
            x = to_device(x, opt.device)
            if opt.enable_ort:
                x = to_numpy(x)
                if config['emb_class'] == 'glove':
                    ort_inputs = {ort_session.get_inputs()[0].name: x}
                else:
                    if config['emb_class'] in ['roberta', 'distilbert', 'bart']:
                        ort_inputs = {ort_session.get_inputs()[0].name: x[0],
                                      ort_session.get_inputs()[1].name: x[1]}
                    else:
                        ort_inputs = {ort_session.get_inputs()[0].name: x[0],
                                      ort_session.get_inputs()[1].name: x[1],
                                      ort_session.get_inputs()[2].name: x[2]}
                logits = ort_session.run(None, ort_inputs)[0]
                logits = to_device(torch.tensor(logits), opt.device)
            else:
                logits = model(x)
            predicted = logits.argmax(1)
            predicted = to_numpy(predicted)[0]
            predicted_raw = labels[predicted]
            f_out.write(text + '\t' + y_raw + '\t' + predicted_raw + '\n')
            total_examples += 1
            if opt.num_examples != 0 and total_examples >= opt.num_examples:
                logger.info("[Stop Inference] : up to the {} examples".format(total_examples))
                break
            duration_time = float((time.time() - start_time) * 1000)
            if i != 0:
                total_duration_time += duration_time
            logger.info("[Elapsed Time] : {}ms".format(duration_time))
    f_out.close()
    logger.info("[Elapsed Time(total_duration_time, average)] : {}ms, {}ms".format(
        total_duration_time, total_duration_time / (total_examples - 1)))
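The `--enable_ort` path above assumes a model already exported by `convert_onnx()` (not shown in this excerpt). A minimal end-to-end sketch of exporting a toy classifier and running it with onnxruntime, analogous to the glove branch where a single input tensor is fed; the file name and toy model are placeholders, not the project's actual export helper:

import torch
import numpy as np
import onnxruntime as ort

# toy stand-in for a classifier taking a single float input, as in the glove branch
model = torch.nn.Sequential(torch.nn.Linear(100, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
model.eval()
dummy = torch.randn(1, 100)
torch.onnx.export(model, dummy, "toy.onnx", input_names=["input"], output_names=["logits"])

# same session setup pattern as inference()/evaluate()
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = 1
ort_session = ort.InferenceSession("toy.onnx", sess_options=sess_options)

x = np.random.randn(1, 100).astype(np.float32)
ort_inputs = {ort_session.get_inputs()[0].name: x}
logits = ort_session.run(None, ort_inputs)[0]
print(logits.shape)  # (1, 2)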
def train_epoch(model, config, train_loader, valid_loader, epoch_i, best_eval_measure):
    optimizer = config['optimizer']
    scheduler = config['scheduler']
    writer = config['writer']
    scaler = config['scaler']
    opt = config['opt']

    if opt.criterion == 'MSELoss':
        criterion = torch.nn.MSELoss(reduction='sum').to(opt.device)
    elif opt.criterion == 'KLDivLoss':
        criterion = torch.nn.KLDivLoss(reduction='sum').to(opt.device)
    elif opt.criterion == 'LabelSmoothingCrossEntropy':
        criterion = LabelSmoothingCrossEntropy(reduction='sum').to(opt.device)
    else:
        criterion = torch.nn.CrossEntropyLoss().to(opt.device)

    # train one epoch
    total_loss = 0
    avg_loss = 0
    local_best_eval_loss = float('inf')
    local_best_eval_acc = 0
    total_examples = 0
    st_time = time.time()
    optimizer.zero_grad()
    epoch_iterator = tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch_i}")
    for local_step, (x, y) in enumerate(epoch_iterator):
        model.train()
        global_step = (len(train_loader) * epoch_i) + local_step
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)

        with autocast(enabled=opt.use_amp):
            if opt.use_profiler:
                with profiler.profile(profile_memory=True, record_shapes=True) as prof:
                    output = model(x)
                print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
            else:
                output = model(x)
            if opt.criterion == 'KLDivLoss':
                loss = criterion(F.log_softmax(output, dim=1), y)
            else:
                loss = criterion(output, y)
            if opt.gradient_accumulation_steps > 1:
                loss = loss / opt.gradient_accumulation_steps

        # back-propagation - begin
        if opt.device == 'cpu':
            loss.backward()
        else:
            scaler.scale(loss).backward()
        if (local_step + 1) % opt.gradient_accumulation_steps == 0:
            if opt.device == 'cpu':
                optimizer.step()
            else:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
            optimizer.zero_grad()
            scheduler.step()
            curr_lr = scheduler.get_last_lr()[0] if scheduler else optimizer.param_groups[0]['lr']
            epoch_iterator.set_description(
                f"Epoch {epoch_i}, local_step: {local_step}, loss: {loss:.3f}, curr_lr: {curr_lr:.7f}")
            if opt.eval_and_save_steps > 0 and global_step != 0 and global_step % opt.eval_and_save_steps == 0:
                # evaluate
                eval_loss, eval_acc = evaluate(model, config, valid_loader)
                if local_best_eval_loss > eval_loss:
                    local_best_eval_loss = eval_loss
                if local_best_eval_acc < eval_acc:
                    local_best_eval_acc = eval_acc
                if writer:
                    writer.add_scalar('Loss/valid', eval_loss, global_step)
                    writer.add_scalar('Acc/valid', eval_acc, global_step)
                    writer.add_scalar('LearningRate/train', curr_lr, global_step)
                if opt.measure == 'loss':
                    eval_measure = eval_loss
                else:
                    eval_measure = eval_acc
                if opt.measure == 'loss':
                    is_best = eval_measure < best_eval_measure
                else:
                    is_best = eval_measure > best_eval_measure
                if is_best:
                    best_eval_measure = eval_measure
                    if opt.save_path and not opt.hp_search_optuna and not opt.hp_search_nni:
                        logger.info("[Best model saved] : {}, {}".format(eval_loss, eval_acc))
                        save_model(config, model, valid_loader=valid_loader)
                        # save finetuned bert model/config/tokenizer
                        if config['emb_class'] not in ['glove']:
                            if not os.path.exists(opt.bert_output_dir):
                                os.makedirs(opt.bert_output_dir)
                            model.bert_tokenizer.save_pretrained(opt.bert_output_dir)
                            model.bert_model.save_pretrained(opt.bert_output_dir)
        # back-propagation - end

        cur_examples = y.size(0)
        total_examples += cur_examples
        total_loss += (loss.item() * cur_examples)
        if writer:
            writer.add_scalar('Loss/train', loss.item(), global_step)
        avg_loss = total_loss / total_examples

    # evaluate at the end of epoch
    eval_loss, eval_acc = evaluate(model, config, valid_loader)
    if local_best_eval_loss > eval_loss:
        local_best_eval_loss = eval_loss
    if local_best_eval_acc < eval_acc:
        local_best_eval_acc = eval_acc
    if writer:
        writer.add_scalar('Loss/valid', eval_loss, global_step)
        writer.add_scalar('Acc/valid', eval_acc, global_step)
        writer.add_scalar('LearningRate/train', curr_lr, global_step)
    if opt.measure == 'loss':
        eval_measure = eval_loss
    else:
        eval_measure = eval_acc
    if opt.measure == 'loss':
        is_best = eval_measure < best_eval_measure
    else:
        is_best = eval_measure > best_eval_measure
    if is_best:
        best_eval_measure = eval_measure
        if opt.save_path and not opt.hp_search_optuna and not opt.hp_search_nni:
            logger.info("[Best model saved] : {}, {}".format(eval_loss, eval_acc))
            save_model(config, model, valid_loader=valid_loader)
            # save finetuned bert model/config/tokenizer
            if config['emb_class'] not in ['glove']:
                if not os.path.exists(opt.bert_output_dir):
                    os.makedirs(opt.bert_output_dir)
                model.bert_tokenizer.save_pretrained(opt.bert_output_dir)
                model.bert_model.save_pretrained(opt.bert_output_dir)

    curr_time = time.time()
    elapsed_time = (curr_time - st_time) / 60
    st_time = curr_time
    logs = {
        'epoch': epoch_i,
        'local_step': local_step + 1,
        'epoch_step': len(train_loader),
        'avg_loss': avg_loss,
        'local_best_eval_loss': local_best_eval_loss,
        'local_best_eval_acc': local_best_eval_acc,
        'best_eval_measure': best_eval_measure,
        'elapsed_time': elapsed_time
    }
    logger.info(json.dumps(logs, indent=4, ensure_ascii=False, sort_keys=True))
    return local_best_eval_loss, local_best_eval_acc, best_eval_measure
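Stripped of evaluation, checkpointing, and profiling, the GPU branch of train_epoch() follows the standard mixed-precision / gradient-accumulation pattern. A minimal sketch of that pattern in isolation (model, loader, and the accumulation step count are placeholders; the real function additionally handles CPU training and several loss criteria):

import torch
from torch.cuda.amp import autocast, GradScaler

def train_steps(model, loader, optimizer, scheduler,
                accumulation_steps=2, max_grad_norm=1.0, device='cuda'):
    # sketch of the GPU path: scale the loss, accumulate over several batches,
    # unscale before clipping, then step optimizer/scaler/scheduler together
    criterion = torch.nn.CrossEntropyLoss().to(device)
    scaler = GradScaler()
    optimizer.zero_grad()
    for step, (x, y) in enumerate(loader):
        model.train()
        x, y = x.to(device), y.to(device)
        with autocast():
            loss = criterion(model(x), y) / accumulation_steps
        scaler.scale(loss).backward()
        if (step + 1) % accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()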