def train(args, model, train_dataset, eval_dataset):
    """Train `model` on `train_dataset` with BCE loss.

    Every `args.logging_step` optimisation steps the model is evaluated on
    `eval_dataset`; whenever the validation AUC improves, the model and
    optimizer state dicts are written to `args.out/args.save`.

    Args:
        args: hyper-parameter namespace (batch_size, lr, epochs, device,
            logging_step, out, save).
        model: module called as model(xarray, position, token_type_list, mask);
            output is assumed to be sigmoid-activated since BCELoss is used —
            TODO confirm against the model definition.
        train_dataset / eval_dataset: torch Datasets yielding
            (xarray, position, token_type_list, mask, ylabel) tuples.
    """
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=8)
    eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size,
                                 num_workers=8)
    loss_fct = BCELoss()
    optimizer = AdamW(model.parameters(), lr=args.lr)

    print("***** Running training *****")
    print(" Num examples = %d" % (len(train_dataset)))
    print(" Num Val examples = %d" % (len(eval_dataset)))
    print(" Num Epochs = %d" % (args.epochs))
    print(" Batch Size = %d" % (args.batch_size))

    output_dir = join(args.out, args.save)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    global_step = 0
    best_val_auc = 0.0
    running_loss = 0.0
    model.zero_grad()
    # FIX: the log file was previously opened without ever being closed; the
    # context manager guarantees it is flushed/closed even if training raises.
    with open(join(output_dir, 'log'), 'w') as log_file:
        for epoch in range(args.epochs):
            for step, batch in enumerate(train_dataloader):
                model.train()
                xarray, position, token_type_list, mask, ylabel = batch
                xarray = xarray.to(args.device)
                position = position.to(args.device)
                token_type_list = token_type_list.to(args.device)
                mask = mask.to(args.device)
                ylabel = ylabel.to(args.device)

                output = model(xarray, position, token_type_list, mask)
                # BCELoss needs matching float tensors; flatten both to 1-D.
                loss = loss_fct(output.view(-1).to(torch.float32),
                                ylabel.view(-1).to(torch.float32))
                loss.backward()
                optimizer.step()
                model.zero_grad()

                running_loss += loss.item()
                global_step += 1

                # Evaluate/log every `logging_step` optimisation steps.
                # (global_step starts at 1 here, so the old `!= 0` guard was redundant.)
                if global_step % args.logging_step == 0:
                    # NOTE: `eval` is a project-local evaluation function
                    # imported elsewhere in this file, not the builtin.
                    eval_result = eval(args, model, eval_dataloader)
                    msg = ('Epoch: %d, Global Step: %d, Loss: %.3f, Eval Loss: %.3f, '
                           'Eval F1score: %.3f, Eval AUC: %.3f'
                           % (epoch + 1, global_step, running_loss / args.logging_step,
                              eval_result['loss'], eval_result['f1'], eval_result['auc']))
                    print(msg)
                    log_file.write(msg + ' \n')
                    running_loss = 0.0
                    # If eval AUC improves, snapshot model + optimizer state.
                    if eval_result['auc'] > best_val_auc:
                        best_val_auc = eval_result['auc']
                        torch.save(model.state_dict(),
                                   os.path.join(output_dir, "model_state_dict.pt"))
                        torch.save(optimizer.state_dict(),
                                   os.path.join(output_dir, "optimizer.pt"))
def run_training(args, ls):
    """Train the AMR `Parser` model end-to-end.

    Sets up vocabs, an optional frozen BERT encoder, the model and an AdamW
    optimizer with manual (warmup-style) LR updates, then runs the epoch loop
    with gradient accumulation. Periodically evaluates with Smatch and saves
    checkpoints to `args.model_dir`.

    Args:
        args: namespace of model/optimisation hyper-parameters and paths.
        ls: project logger object exposing a print() method.
    """
    ls.print('Training started: ' + datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

    # Misc setup
    os.makedirs(args.model_dir, exist_ok=True)
    # cnn_filters is given flat on the command line; pair it up: [a,b,c,d] -> [(a,b),(c,d)]
    assert len(args.cnn_filters) % 2 == 0
    args.cnn_filters = list(zip(args.cnn_filters[:-1:2], args.cnn_filters[1::2]))

    # Load the vocabs
    vocabs = get_vocabs(os.path.join(args.model_dir, args.vocab_dir))
    bert_tokenizer = None
    if args.with_bert:
        bert_tokenizer = BertEncoderTokenizer.from_pretrained(args.bert_path, do_lower_case=False)
        vocabs['bert_tokenizer'] = bert_tokenizer
    for name in vocabs:
        if name == 'bert_tokenizer':
            continue
        ls.print('Vocab %-20s size %5d coverage %.3f' %
                 (name, vocabs[name].size, vocabs[name].coverage))

    # Setup BERT encoder (frozen: used as a feature extractor only)
    bert_encoder = None
    if args.with_bert:
        bert_encoder = BertEncoder.from_pretrained(args.bert_path)
        for p in bert_encoder.parameters():
            p.requires_grad = False

    # Device and random setup
    torch.manual_seed(19940117)
    torch.cuda.manual_seed_all(19940117)
    random.seed(19940117)
    device = torch.device(args.device)

    # Create the model
    ls.print('Setting up the model')
    model = Parser(vocabs, args.word_char_dim, args.word_dim, args.pos_dim,
                   args.ner_dim, args.concept_char_dim, args.concept_dim,
                   args.cnn_filters, args.char2word_dim, args.char2concept_dim,
                   args.embed_dim, args.ff_embed_dim, args.num_heads,
                   args.dropout, args.snt_layers, args.graph_layers,
                   args.inference_layers, args.rel_dim, device,
                   args.pretrained_file, bert_encoder,)
    model = model.to(device)

    # Optimizer: no weight decay on biases and layer-norm parameters
    weight_decay_params = []
    no_weight_decay_params = []
    for name, param in model.named_parameters():
        if name.endswith('bias') or 'layer_norm' in name:
            no_weight_decay_params.append(param)
        else:
            weight_decay_params.append(param)
    grouped_params = [{'params': weight_decay_params, 'weight_decay': 1e-4},
                      {'params': no_weight_decay_params, 'weight_decay': 0.}]
    # lr=1.0 placeholder: the real LR is set each update by update_lr() below.
    optimizer = AdamW(grouped_params, 1., betas=(0.9, 0.999), eps=1e-6)

    # Re-load an existing model if requested
    used_batches = 0
    batches_acm = 0
    if args.resume_ckpt:
        ls.print('Resuming from checkpoint', args.resume_ckpt)
        ckpt = torch.load(args.resume_ckpt)
        model.load_state_dict(ckpt['model'])
        if ckpt.get('optimizer', {}):
            optimizer.load_state_dict(ckpt['optimizer'])
        else:
            ls.print('No optimizer state saved in checkpoint, using default initial optimizer')
        batches_acm = ckpt['batches_acm']
        start_epoch = ckpt['epoch'] + 1
        del ckpt  # free checkpoint memory before training
    else:
        start_epoch = 1  # don't start at 0

    # Load data
    ls.print('Loading training data')
    train_data = DataLoader(vocabs, args.train_data, args.train_batch_size, for_train=True)
    train_data.set_unk_rate(args.unk_rate)

    # Train
    ls.print('Training')
    epoch, loss_avg, concept_loss_avg, arc_loss_avg, rel_loss_avg = 0, 0, 0, 0, 0
    # FIX: `lr` is used in the end-of-epoch summary but was only assigned inside
    # the accumulation branch; with fewer batches than batches_per_update it
    # would be unbound (NameError). Initialise it explicitly.
    lr = 0.0
    for epoch in range(start_epoch, args.epochs + 1):
        st = time.time()
        for batch in train_data:
            model.train()
            batch = move_to_device(batch, model.device)
            concept_loss, arc_loss, rel_loss, graph_arc_loss = model(batch)
            # Scale by the accumulation factor so gradients average correctly.
            loss = (concept_loss + arc_loss + rel_loss) / args.batches_per_update
            loss_value = loss.item()
            concept_loss_value = concept_loss.item()
            arc_loss_value = arc_loss.item()
            rel_loss_value = rel_loss.item()
            # Exponential moving averages for logging.
            loss_avg = loss_avg * args.batches_per_update * 0.8 + 0.2 * loss_value
            concept_loss_avg = concept_loss_avg * 0.8 + 0.2 * concept_loss_value
            arc_loss_avg = arc_loss_avg * 0.8 + 0.2 * arc_loss_value
            rel_loss_avg = rel_loss_avg * 0.8 + 0.2 * rel_loss_value
            loss.backward()
            used_batches += 1
            # Only step the optimizer every batches_per_update-th batch.
            if not (used_batches % args.batches_per_update == -1 % args.batches_per_update):
                continue
            batches_acm += 1
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            lr = update_lr(optimizer, args.lr_scale, args.embed_dim,
                           batches_acm, args.warmup_steps)
            optimizer.step()
            optimizer.zero_grad()

        # Summary at the end of the epoch
        dur = time.time() - st
        ls.print('Epoch %4d, Batch %5d, LR %.6f, conc_loss %.3f, arc_loss %.3f, rel_loss %.3f, duration %.1f seconds' %
                 (epoch, batches_acm, lr, concept_loss_avg, arc_loss_avg, rel_loss_avg, dur))

        # Evaluate and save the data every so often
        if (epoch > args.skip_evals or args.resume_ckpt is not None) and epoch % args.eval_every == 0:
            model.eval()
            ls.print('Evaluating and saving the model')
            fname = '%s/epoch%d.pt' % (args.model_dir, epoch)
            optim = optimizer.state_dict() if args.save_optimizer else {}
            torch.save({'args': vars(args), 'model': model.state_dict(),
                        'batches_acm': batches_acm, 'optimizer': optim,
                        'epoch': epoch}, fname)
            try:
                out_fn = 'epoch%d.pt.dev_generated' % (epoch)
                inference = Inference.build_from_model(model, vocabs)
                f_score, ctr = inference.reparse_annotated_file('.', args.dev_data,
                                                                args.model_dir, out_fn,
                                                                print_summary=False)
                ls.print('Smatch F: %.3f. Wrote %d AMR graphs to %s' %
                         (f_score, ctr, os.path.join(args.model_dir, out_fn)))
            # FIX: a bare `except:` also swallowed KeyboardInterrupt/SystemExit,
            # making the run impossible to stop during evaluation.
            except Exception:
                ls.print('Exception during generation')
                traceback.print_exc()
            model.train()

    # End time-stamp
    ls.print('Training finished: ' + datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
class Trainer(object):
    """Full-batch trainer for a graph node-classification model.

    Runs the train/validate loop, checkpoints the optimiser state, keeps the
    best model (by validation loss) and optionally early-stops.

    NOTE(review): logits are indexed with ``node_id + vocab_size`` throughout;
    presumably the first ``vocab_size`` output rows belong to vocabulary nodes
    rather than documents — confirm against the model definition.
    """

    def __init__(
        self,
        model: nn.Module,
        learning_rate: float,
        device: torch.device,
        train_nodes: torch.LongTensor,
        val_nodes: torch.LongTensor,
        test_nodes: torch.LongTensor,
        vocab_size: int,
        results_dir: str,
        validate_every_n_epochs: int,
        save_after_n_epochs: int,
        checkpoint_every_n_epochs: int,
        use_early_stopping: bool,
        early_stopping_epochs: int,
        autodelete_checkpoints: bool,
    ):
        self.device = device
        self.model = model
        self.model.to(self.device)
        self.optimiser = AdamW(
            params=model.parameters(),
            lr=learning_rate,
        )
        self.loss_fn = nn.CrossEntropyLoss()
        # The node splits must be pairwise disjoint (train vs val/test).
        assert (
            len(set(train_nodes).intersection(set(val_nodes))) == 0
        ), f'There are overlapping nodes: {len(set(train_nodes).intersection(set(val_nodes)))}'
        assert (
            len(set(train_nodes).intersection(set(test_nodes))) == 0
        ), f'There are overlapping nodes: {len(set(train_nodes).intersection(set(test_nodes)))}'
        self.train_nodes = train_nodes
        self.val_nodes = val_nodes
        self.test_nodes = test_nodes
        self.vocab_size = vocab_size
        print(f'Vocabulary offset: {vocab_size}')
        self.results_dir = results_dir
        self.validate_every_n_epochs = validate_every_n_epochs
        self.save_after_n_epochs = save_after_n_epochs
        self.checkpoint_every_n_epochs = checkpoint_every_n_epochs
        self.use_early_stopping = use_early_stopping
        self.early_stopping_epochs = early_stopping_epochs
        self.has_saved_metric = False
        self._setup_dirs()
        # Lower is better for the default metric ('val loss').
        self.metric_of_interest = 'val loss'
        self.best_metric = math.inf
        self.last_epoch_with_improvement = 1
        self.autodelete_checkpoints = autodelete_checkpoints

    def _setup_dirs(self):
        """Create the checkpoint / best-model / best-predictions directories."""
        self.ckpt_dir = os.path.join(self.results_dir, 'ckpt')
        self.best_model_dir = os.path.join(self.results_dir, 'best', 'models')
        self.best_preds_dir = os.path.join(self.results_dir, 'best', 'predictions')
        os.makedirs(self.ckpt_dir, exist_ok=True)
        os.makedirs(self.best_model_dir, exist_ok=True)
        os.makedirs(self.best_preds_dir, exist_ok=True)

    def __call__(
        self,
        input_features: torch.FloatTensor,
        adjacency: torch.sparse.FloatTensor,
        labels: torch.LongTensor,
        num_epochs: int,
    ):
        """Run `num_epochs` of training with periodic validation/checkpointing."""
        with trange(num_epochs, desc='Training progress: ') as t:
            for epoch_num in range(1, num_epochs + 1):
                train_metrics = self._train_epoch(input_features, adjacency, labels)
                if (epoch_num % self.validate_every_n_epochs) == 0 or epoch_num == 1:
                    # Validate and save metrics
                    val_metrics = self._val_epoch(input_features, adjacency, labels)
                    save_metrics(
                        file_path=os.path.join(self.results_dir, 'train-log.jsonl'),
                        epoch_num=epoch_num,
                        train_metrics=train_metrics,
                        val_metrics=val_metrics,
                        is_first_metric_save=not self.has_saved_metric,
                    )
                    self.has_saved_metric = True
                    # FIX: _is_best() mutates self.best_metric, so it must be
                    # evaluated exactly once per validation. It used to be
                    # called twice (once for saving, once for early stopping),
                    # which corrupted the improvement bookkeeping.
                    is_best = self._is_best(val_metrics)
                    if epoch_num > self.save_after_n_epochs and (
                            epoch_num % self.checkpoint_every_n_epochs) == 0:
                        # Save model
                        self._checkpoint_model(epoch_num)
                    if is_best:
                        self._save_best_model(epoch_num)
                        self._save_test_predictions(
                            input_features, adjacency, labels, epoch_num)
                    if self.use_early_stopping:
                        if is_best:
                            self.last_epoch_with_improvement = epoch_num
                        if epoch_num > self.last_epoch_with_improvement + self.early_stopping_epochs:
                            note = (f'Breaking on epoch {epoch_num} after no improvement '
                                    f'since epoch {self.last_epoch_with_improvement}')
                            print(note)
                            save_training_notes(
                                file_path=os.path.join(self.results_dir, 'training-notes.jsonl'),
                                epoch_num=epoch_num,
                                note=note,
                            )
                            break
                else:
                    # if we haven't validated, create an empty val metric dict
                    val_metrics = {'val loss': None}
                t.set_postfix(train_loss=train_metrics['train loss'],
                              val_loss=val_metrics['val loss'])
                t.update()
        return None

    def _train_epoch(
        self,
        input_features: torch.FloatTensor,
        adjacency: torch.sparse.FloatTensor,
        labels: torch.LongTensor,
    ) -> Dict[str, Any]:
        """Run one full-batch training step.

        NOTE: Although we pass in all input features and labels, we only
        evaluate the loss on training set node indices.
        """
        start_time = time.time()
        self.model.train()
        self.optimiser.zero_grad()
        logits = self.model(input_features, adjacency)
        train_loss = self.loss_fn(logits[self.train_nodes + self.vocab_size],
                                  labels[self.train_nodes])
        train_loss.backward()
        self.optimiser.step()
        duration = time.time() - start_time
        return {
            'train epoch duration': duration,
            'train loss': train_loss.item()
        }

    def _val_epoch(
        self,
        input_features: torch.FloatTensor,
        adjacency: torch.sparse.FloatTensor,
        labels: torch.LongTensor,
    ) -> Dict[str, Any]:
        """Compute loss and accuracy on the validation split."""
        self.model.eval()
        # FIX: no gradients are needed for validation; no_grad avoids building
        # the autograd graph for the full-batch forward pass.
        with torch.no_grad():
            logits = self.model(input_features, adjacency)
            val_loss = self.loss_fn(logits[self.val_nodes + self.vocab_size],
                                    labels[self.val_nodes])
            val_accuracy = accuracy(logits[self.val_nodes + self.vocab_size],
                                    labels[self.val_nodes],
                                    is_logit_output=True)
        return {
            'val loss': val_loss.item(),
            'F-score': None,
            'Accuracy': val_accuracy
        }

    def _checkpoint_model(self, epoch: int) -> None:
        """Checkpoint to resume training."""
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimiser_state_dict': self.optimiser.state_dict(),
                'loss': self.loss_fn,
            },
            os.path.join(self.ckpt_dir, f'model-{epoch}.pt'),
        )
        # delete existing checkpoints (except for the one we just saved)
        if self.autodelete_checkpoints:
            checkpoints = glob(os.path.join(self.ckpt_dir, 'model-*.pt'))
            old_checkpoints = [
                checkpoint for checkpoint in checkpoints
                if f'model-{epoch}' not in checkpoint
            ]
            for old_checkpoint in old_checkpoints:
                os.remove(old_checkpoint)

    def _save_best_model(self, epoch: int) -> None:
        """Save best model for inference."""
        torch.save(self.model.state_dict(),
                   os.path.join(self.best_model_dir, f'model-{epoch}.pt'))
        remove_previous_best_model(self.best_model_dir, epoch)

    def _save_test_predictions(
        self,
        input_features: torch.FloatTensor,
        adjacency: torch.sparse.FloatTensor,
        labels: torch.LongTensor,
        epoch: int,
    ) -> None:
        """Save test set predictions for the best model."""
        self.model.eval()
        with torch.no_grad():
            logits = self.model(input_features, adjacency)
            predictions = get_predictions(logits[self.test_nodes + self.vocab_size],
                                          labels[self.test_nodes],
                                          is_logit_output=True)
        torch.save(
            predictions,
            os.path.join(self.best_preds_dir, f'predictions-{epoch}.pt'))
        remove_previous_best_predictions(self.best_preds_dir, epoch)

    def _is_best(self, val_metrics: Dict[str, float]) -> bool:
        """Return True (and record the new best) if the tracked metric improved.

        Loss-like metrics improve downwards (ties count as improvement);
        anything else is assumed to improve strictly upwards.
        """
        if 'loss' in self.metric_of_interest:
            if val_metrics[self.metric_of_interest] <= self.best_metric:
                self.best_metric = val_metrics[self.metric_of_interest]
                return True
            else:
                return False
        else:
            # Assume we want to maximise it if it is not a loss
            if val_metrics[self.metric_of_interest] > self.best_metric:
                self.best_metric = val_metrics[self.metric_of_interest]
                return True
            else:
                return False

    def save_test_metrics(
        self,
        input_features: torch.FloatTensor,
        adjacency: torch.sparse.FloatTensor,
        labels: torch.LongTensor,
    ) -> None:
        """Score the saved best-model predictions on the test split and log them."""
        files_in_dir = os.listdir(self.best_preds_dir)
        assert len(
            files_in_dir
        ) == 1, f'Found more than one prediction file in:\n{files_in_dir}'
        test_predictions = torch.load(
            os.path.join(self.best_preds_dir, files_in_dir[0]))
        test_labels = labels[self.test_nodes]
        num_correct = float(
            torch.sum(
                torch.eq(test_predictions.type_as(test_labels), test_labels)))
        test_accuracy = num_correct / len(test_labels)
        test_macro_f1 = f1_score(labels[self.test_nodes],
                                 test_predictions,
                                 average='macro')
        save_dict_to_json(
            {
                'test-accuracy': test_accuracy,
                'test_macro_f1': test_macro_f1
            },
            os.path.join(self.results_dir, 'test-log.jsonl'),
        )
class Trainer():
    """Trainer for the Classifier_M3 audio model.

    Owns the model, AdamW optimizer, cosine-annealing LR scheduler and
    BCEWithLogitsLoss criterion, and tracks per-epoch loss / F1 histories
    for both the train and test loaders.
    """

    def __init__(self, train_dataloader, test_dataloader, lr, betas,
                 weight_decay, log_freq, with_cuda, model=None):
        """Set up model, optimizer and (optionally) restore a checkpoint.

        :param train_dataloader: DataLoader for the training split
        :param test_dataloader: DataLoader for the test split
        :param lr: AdamW learning rate
        :param betas: AdamW beta coefficients
        :param weight_decay: AdamW weight decay
        :param log_freq: logging frequency (stored; used by callers)
        :param with_cuda: allow CUDA if it is available
        :param model: optional path to a checkpoint file to resume from
        """
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda" if cuda_condition else "cpu")
        print("Use:", "cuda:0" if cuda_condition else "cpu")
        self.model = Classifier_M3().to(self.device)
        self.optim = AdamW(self.model.parameters(), lr=lr, betas=betas,
                           weight_decay=weight_decay)
        self.scheduler = lr_scheduler.CosineAnnealingLR(self.optim, 5)
        self.criterion = nn.BCEWithLogitsLoss()
        if model is not None:  # FIX: identity comparison, not `!= None`
            checkpoint = torch.load(model)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optim.load_state_dict(checkpoint['optimizer_state_dict'])
            self.epoch = checkpoint['epoch']
            # FIX: save() writes the loss under the key 'criterion', but this
            # loader used to read checkpoint['loss'] and raised KeyError on
            # checkpoints produced by save(). Accept either key.
            self.criterion = checkpoint.get('criterion',
                                            checkpoint.get('loss', self.criterion))
        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
            print("Using %d GPUS for Converter" % torch.cuda.device_count())
        self.train_data = train_dataloader
        self.test_data = test_dataloader
        self.log_freq = log_freq
        print("Total Parameters:",
              sum([p.nelement() for p in self.model.parameters()]))
        # Per-epoch history, appended by iteration().
        self.test_loss = []
        self.train_loss = []
        self.train_f1_score = []
        self.test_f1_score = []

    def train(self, epoch):
        """Run one training epoch."""
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        """Run one evaluation epoch (no parameter updates)."""
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """Run one epoch over `data_loader`.

        :param epoch: current epoch number (for logging)
        :param data_loader: torch.utils.data.DataLoader yielding
            (spectrogram, one_hot_label, label) triples
        :param train: True for a training epoch, False for a test epoch
        """
        str_code = "train" if train else "test"
        data_iter = tqdm(enumerate(data_loader),
                         desc="EP_%s:%d" % (str_code, epoch),
                         total=len(data_loader),
                         bar_format="{l_bar}{r_bar}")
        total_element = 0
        loss_store = 0.0
        f1_score_store = 0.0
        total_correct = 0
        for i, data in data_iter:
            specgram = data[0].to(self.device)
            label = data[2].to(self.device)
            one_hot_label = data[1].to(self.device)
            # NOTE(review): the model receives the `train` flag directly;
            # no model.eval()/no_grad() is used for test epochs — confirm
            # this is intentional (e.g. dropout handled inside the model).
            predict_label = self.model(specgram, train)
            predict_f1_score = get_F1_score(
                label.cpu().detach().numpy(),
                convert_label(predict_label.cpu().detach().numpy()),
                average='micro')
            loss = self.criterion(predict_label, one_hot_label)
            if train:
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                # NOTE(review): scheduler stepped per batch, not per epoch —
                # with CosineAnnealingLR(T_max=5) the LR cycles every 5
                # batches; verify this is the intended schedule.
                self.scheduler.step()
            loss_store += loss.item()
            f1_score_store += predict_f1_score
            self.avg_loss = loss_store / (i + 1)
            self.avg_f1_score = f1_score_store / (i + 1)
            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": round(self.avg_loss, 5),
                "loss": round(loss.item(), 5),
                "avg_f1_score": round(self.avg_f1_score, 5)
            }
            data_iter.write(str(post_fix))
        # Record the epoch averages in the matching history list.
        if train:
            self.train_loss.append(self.avg_loss)
            self.train_f1_score.append(self.avg_f1_score)
        else:
            self.test_loss.append(self.avg_loss)
            self.test_f1_score.append(self.avg_f1_score)

    def save(self, epoch, file_path="../models/2k/"):
        """Serialise model/optimizer/criterion state and return the path.

        The model is temporarily moved to CPU so the checkpoint is
        device-independent, then moved back to self.device.
        """
        output_path = file_path + f"crnn_ep{epoch}.model"
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': self.model.cpu().state_dict(),
                'optimizer_state_dict': self.optim.state_dict(),
                'criterion': self.criterion
            }, output_path)
        self.model.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path

    def export_log(self, epoch, file_path="../../logs/2k/"):
        """Dump the accumulated loss/F1 histories to a CSV log file.

        NOTE(review): requires equally many train and test epochs so far,
        otherwise the DataFrame constructor raises on unequal lengths.
        """
        df = pd.DataFrame({
            "train_loss": self.train_loss,
            "test_loss": self.test_loss,
            "train_F1_score": self.train_f1_score,
            "test_F1_score": self.test_f1_score
        })
        output_path = file_path + "loss_timestrech.log"
        print("EP:%d logs Saved on:" % epoch, output_path)
        df.to_csv(output_path)
def train(args, train_dataset, model, tokenizer, writer):
    """Fine-tune a span-prediction (QA) model on `train_dataset`.

    Supports gradient accumulation, optional apex fp16, resuming from a
    "checkpoint-<step>" directory, periodic evaluation and checkpointing.

    Args:
        args: hyper-parameter namespace (batch sizes, lr, fp16 flags, paths,
            logging/save step counts, ...).
        train_dataset: pre-tensorised dataset; batches collated by collate_fn.
        model: model whose forward returns (loss, ...) given QA inputs.
        tokenizer: tokenizer; its vocabulary is saved with each checkpoint.
        writer: TensorBoard SummaryWriter.

    Returns:
        (global_step, average loss per optimisation step)
    """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate_fn)
    # Total number of optimisation steps (for the LR schedule).
    train_total = len(
        train_dataloader
    ) // args.gradient_accumulation_steps * args.num_train_epochs

    # No weight decay on biases and LayerNorm weights.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=train_total)

    # Restore optimizer/scheduler state when resuming.
    if os.path.isfile(os.path.join(
            args.pretrain_model_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.pretrain_model_path, "scheduler.pt")):
        optimizer.load_state_dict(
            torch.load(os.path.join(args.pretrain_model_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.pretrain_model_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    print("***** Running training *****")
    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Resume bookkeeping from a "checkpoint-<global_step>" directory name.
    if os.path.exists(args.pretrain_model_path
                      ) and "checkpoint" in args.pretrain_model_path:
        global_step = int(
            args.pretrain_model_path.split("-")[-1].split("/")[0])
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

    train_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    # FIX: start at epochs_trained so resuming does not replay epochs that
    # were already completed (epochs_trained was computed but never used).
    for _ in range(epochs_trained, int(args.num_train_epochs)):
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        for step, batch in enumerate(train_dataloader):
            # Skip steps already done in the partially-trained epoch.
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "start_positions": batch[3],
                "end_positions": batch[4]
            }
            inputs["token_type_ids"] = (batch[2]
                                        if args.model_type in ["bert"] else None)
            outputs = model(**inputs)
            loss = outputs[0]
            # FIX: log against global_step — the per-epoch `step` made the
            # TensorBoard x-axis restart (and overwrite) every epoch.
            writer.add_scalar("Train_loss", loss.item(), global_step)
            if args.n_gpu > 1:
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            pbar(step, {'loss': loss.item()})
            train_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                # FIX: optimizer.step() must precede scheduler.step();
                # the old order updates the LR before the parameter step,
                # so every step ran at the wrong scheduled LR.
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    if args.local_rank == -1:
                        evaluate(args, model, tokenizer, writer)
                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint (+ optimizer/scheduler state).
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (model.module
                                     if hasattr(model, "module") else model)
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    tokenizer.save_vocabulary(output_dir)
                    # FIX: the %s placeholder was never formatted (logger-style
                    # args were passed to print()).
                    print("Saving model checkpoint to %s" % output_dir)
                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
        print(" ")
        if 'cuda' in str(args.device):
            torch.cuda.empty_cache()
    # Guard the degenerate case of zero optimisation steps.
    return global_step, train_loss / max(global_step, 1)
def train():
    """Train the ParrotModel with CTC loss using the parameters from TrainConfig.

    Creates checkpoint and TensorBoard folders named after the experiment,
    trains for cfg.epochs epochs, periodically checkpoints, renders a test
    mp3 and runs validation, and writes per-step / per-epoch losses to
    TensorBoard.
    """
    print('Initialising ...')
    cfg = TrainConfig()

    checkpoint_folder = 'checkpoints/{}/'.format(cfg.experiment_name)
    if not os.path.exists(checkpoint_folder):
        os.makedirs(checkpoint_folder)
    tb_folder = 'tb/{}/'.format(cfg.experiment_name)
    if not os.path.exists(tb_folder):
        os.makedirs(tb_folder)

    writer = SummaryWriter(logdir=tb_folder, flush_secs=30)
    model = ParrotModel().cuda().train()
    optimiser = AdamW(model.parameters(),
                      lr=cfg.initial_lr,
                      weight_decay=cfg.weight_decay)

    train_dataset = ParrotDataset(cfg.train_labels, cfg.mp3_folder)
    train_loader = DataLoader(train_dataset,
                              batch_size=cfg.batch_size,
                              num_workers=cfg.workers,
                              collate_fn=parrot_collate_function,
                              pin_memory=True)
    val_dataset = ParrotDataset(cfg.val_labels, cfg.mp3_folder)
    val_loader = DataLoader(val_dataset,
                            batch_size=cfg.batch_size,
                            num_workers=cfg.workers,
                            collate_fn=parrot_collate_function,
                            shuffle=False,
                            pin_memory=True)

    epochs = cfg.epochs
    init_loss, step = 0., 0
    avg_loss = AverageMeter()
    print('Starting training')
    for epoch in range(epochs):
        loader_length = len(train_loader)
        epoch_start = time.time()
        for batch_idx, batch in enumerate(train_loader):
            optimiser.zero_grad()

            # VRAM control by skipping long examples
            if batch['spectrograms'].shape[-1] > cfg.max_time:
                continue

            # inference
            target = batch['targets'].cuda()
            model_input = batch['spectrograms'].cuda()
            model_output = model(model_input)

            # loss
            input_lengths = batch['input_lengths'].cuda()
            target_lengths = batch['target_lengths'].cuda()
            loss = ctc_loss(model_output, target, input_lengths, target_lengths)
            loss.backward()
            if epoch == 0 and batch_idx == 0:
                # FIX: store a plain float; keeping the loss *tensor* here
                # pinned GPU memory for the whole run.
                init_loss = loss.item()

            # logging
            elapsed = time.time() - epoch_start
            progress = batch_idx / loader_length
            est = datetime.timedelta(
                seconds=int(elapsed / progress)) if progress > 0.001 else '-'
            # FIX: accumulate floats, not live tensors (same values, no graph refs).
            avg_loss.update(loss.item())
            suffix = '\tloss {:.4f}/{:.4f}\tETA [{}/{}]'.format(
                avg_loss.avg, init_loss,
                datetime.timedelta(seconds=int(elapsed)), est)
            printProgressBar(batch_idx,
                             loader_length,
                             suffix=suffix,
                             prefix='Epoch [{}/{}]\tStep [{}/{}]'.format(
                                 epoch, epochs, batch_idx, loader_length))
            writer.add_scalar('Steps/train_loss', loss.item(), step)

            # saving the model
            if step % cfg.checkpoint_every == 0:
                test_name = '{}/test_epoch{}.mp3'.format(
                    checkpoint_folder, epoch)
                test_mp3_file(cfg.test_mp3, model, test_name)
                checkpoint_name = '{}/epoch_{}.pth'.format(
                    checkpoint_folder, epoch)
                torch.save(
                    {
                        'model': model.state_dict(),
                        'epoch': epoch,
                        'batch_idx': loader_length,
                        'step': step,
                        'optimiser': optimiser.state_dict()
                    }, checkpoint_name)

            # validating
            if step % cfg.val_every == 0:
                val(model, val_loader, writer, step)
                model = model.train()
            step += 1
            # NOTE(review): the optimiser step intentionally stays after the
            # checkpoint/validation block (state is saved pre-step); moving it
            # would change what the checkpoints contain.
            optimiser.step()

        # end of epoch
        print('')
        writer.add_scalar('Epochs/train_loss', avg_loss.avg, epoch)
        avg_loss.reset()
        test_name = '{}/test_epoch{}.mp3'.format(checkpoint_folder, epoch)
        test_mp3_file(cfg.test_mp3, model, test_name)
        checkpoint_name = '{}/epoch_{}.pth'.format(checkpoint_folder, epoch)
        torch.save(
            {
                'model': model.state_dict(),
                'epoch': epoch,
                'batch_idx': loader_length,
                'step': step,
                'optimiser': optimiser.state_dict()
            }, checkpoint_name)

    # finished training
    writer.close()
    print('Training finished :)')
def train(args, train_dataset, model, tokenizer):
    """ Train the model

    Standard transformers-style training loop: gradient accumulation,
    optional apex fp16, DataParallel / DistributedDataParallel, linear
    warmup schedule, resume-from-checkpoint, periodic evaluation and
    checkpointing. Returns (global_step, average loss per step).
    """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    # Random sampling on a single process; DistributedSampler shards per rank.
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)
    # t_total = total optimisation steps; if max_steps caps training, derive
    # the epoch count from it instead.
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay);
    # biases and LayerNorm weights are exempt from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        except ValueError:
            # Path does not end in "-<step>": treat as a fresh start.
            global_step = 0
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)
        logger.info(
            " Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info(" Continuing training from epoch %d", epochs_trained)
        logger.info(" Continuing training from global step %d", global_step)
        logger.info(" Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            inputs["token_type_ids"] = (
                batch[2] if args.model_type in ["bert", "xlnet", "albert"]
                else None
            )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)
            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            if step % 10 == 0:
                # Lightweight progress trace of the (scaled) batch loss.
                print(step, loss.item())
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= args.gradient_accumulation_steps and
                (step + 1) == len(epoch_iterator)):
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value
                    # Mean loss since the previous logging step.
                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss
                    print(json.dumps({**logs, **{"step": global_step}}))
                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)
                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)
            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break
    return global_step, tr_loss / global_step
def main():
    """Train an image-classification model on the product dataset.

    Workflow: parse CLI args, build train/validation datasets and loaders,
    load a pretrained backbone, then run an epoch loop with optional periodic
    "cleanup" (self-training style filtering of low-confidence samples).
    Checkpoints every 5 epochs; final weights and history saved at the end.
    """
    args = parseArguments()
    os.makedirs(args.modelDir, exist_ok=True)
    checkpointDir = os.path.join(args.modelDir, 'checkpoints')
    os.makedirs(checkpointDir, exist_ok=True)
    os.makedirs(args.ensembleDir, exist_ok=True)
    with EventTimer('Preparing for dataset / dataloader'):
        # Train and validation share the same image root; the second path
        # argument selects which image list each split uses.
        trainDataset = ProductDataset(os.path.join(args.dataDir, 'train'),
                                      os.path.join(args.trainImages),
                                      transform=trainingPreprocessing)
        validDataset = ProductDataset(os.path.join(args.dataDir, 'train'),
                                      os.path.join(args.validImages),
                                      transform=inferencePreprocessing)
        trainDataloader = DataLoader(trainDataset,
                                     batch_size=args.batchSize,
                                     num_workers=args.numWorkers,
                                     shuffle=True)
        validDataloader = DataLoader(validDataset,
                                     batch_size=args.batchSize,
                                     num_workers=args.numWorkers,
                                     shuffle=False)
        print(f'> Training dataset:\t{len(trainDataset)}')
        print(f'> Validation dataset:\t{len(validDataset)}')
    with EventTimer(f'Load pretrained model - {args.pretrainModel}'):
        # fcDims + [42]: the final FC layer is forced to 42 output classes.
        model = models.GetPretrainedModel(args.pretrainModel,
                                          fcDims=args.fcDims + [42])
        print(model)
    #torchsummary will crash under densenet, skip the summary.
    #torchsummary.summary(model, (3, 224, 224), device='cpu')
    with EventTimer(f'Train model'):
        model.cuda()
        criterion = CrossEntropyLoss()
        optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.l2)
        scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)
        history = []
        # Resume from checkpoint-<retrain>.pt when --retrain is nonzero.
        if args.retrain != 0:
            checkpoint = torch.load(
                os.path.join(checkpointDir,
                             f'checkpoint-{args.retrain:03d}.pt'))
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            history = checkpoint['history']

        def runEpoch(dataloader, train=False, name=''):
            """Run one pass over `dataloader`; returns (mean_loss, mean_accuracy).

            Closes over `model`, `criterion`, `optimizer` defined above.
            When train=True, performs backprop/optimizer steps per batch.
            """
            # For empty validation dataloader
            if len(dataloader) == 0:
                return 0, 0
            # Enable grad only in training mode; eval runs under no_grad.
            with (torch.enable_grad() if train else torch.no_grad()):
                if train:
                    model.train()
                else:
                    model.eval()
                losses = []
                for img, label, imgPath in tqdm(dataloader, desc=name, ncols=80):
                    if train:
                        optimizer.zero_grad()
                    output = model(img.cuda()).cpu()
                    loss = criterion(output, label)
                    if train:
                        loss.backward()
                        optimizer.step()
                    accu = accuracy(output.data.numpy(), label.numpy())
                    losses.append((loss.item(), accu))
                # zip(*losses) -> (all losses, all accuracies); mean of each.
                return map(np.mean, zip(*losses))

        def cleanUp():
            """Filter the training set, keeping samples whose max predicted
            score is above the 5th-percentile threshold; returns a new loader.

            NOTE(review): inference here runs without torch.no_grad(), so
            gradients are tracked needlessly — confirm whether intentional.
            """
            model.eval()
            # Buffer sized to a full final batch; if the last batch is short,
            # trailing entries stay 0.
            # NOTE(review): those padding zeros participate in the percentile
            # computation below and can bias the threshold low — verify.
            train_pred = np.zeros((trainDataloader.__len__()) * args.batchSize)
            cnt = 0
            for i, (data, label, path) in enumerate(trainDataloader):
                test_pred = model(data.cuda())
                pred = np.max(test_pred.cpu().data.numpy(), axis=1)
                train_pred[cnt:cnt + len(pred)] = pred
                cnt += len(pred)
            # In-place sort; threshold = 5th percentile of max scores.
            sorted_pred = train_pred
            sorted_pred.sort()
            threshold = sorted_pred[(len(sorted_pred) // 20)]
            # data_set[0]: kept image paths; data_set[1]: kept labels.
            data_set = [[], []]
            for i, (data, label, path) in enumerate(trainDataloader):
                test_pred = model(data.cuda())
                pred = np.max(test_pred.cpu().data.numpy(), axis=1)
                for j in range(len(pred)):
                    if pred[j] >= threshold:
                        data_set[0].append(path[j])
                        data_set[1].append(label[j])
            newDataset = ProductDataset(os.path.join(args.dataDir, 'train'),
                                        os.path.join(args.trainImages),
                                        transform=trainingPreprocessing,
                                        data=data_set)
            newDataloader = DataLoader(newDataset,
                                       batch_size=args.batchSize,
                                       num_workers=args.numWorkers,
                                       shuffle=True)
            print(
                f"{newDataloader.__len__() * args.batchSize} images remain after cleanup"
            )
            return newDataloader

        # Epoch loop; starts after the resumed epoch when retraining.
        for epoch in range(args.retrain + 1, args.epochs + 1):
            with EventTimer(verbose=False) as et:
                print(f'====== Epoch {epoch:3d} / {args.epochs:3d} ======')
                trainLoss, trainAccu = runEpoch(trainDataloader,
                                                train=True,
                                                name='training  ')
                validLoss, validAccu = runEpoch(validDataloader,
                                                name='validation')
                history.append(
                    ((trainLoss, trainAccu), (validLoss, validAccu)))
                scheduler.step()
                print(
                    f'[{et.gettime():.4f}s] Training: {trainLoss:.6f} / {trainAccu:.4f} ; Validation {validLoss:.6f} / {validAccu:.4f}'
                )
            # Periodically replace the train loader with the cleaned-up one.
            if args.cleanup and epoch % args.cleanup_epoch == 0:
                with EventTimer('Cleaning Training Set'):
                    trainDataloader = cleanUp()
            # Checkpoint every 5 epochs (model + optimizer + scheduler + history).
            if epoch % 5 == 0:
                torch.save(
                    {
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'history': history,
                    },
                    os.path.join(checkpointDir, f'checkpoint-{epoch:03d}.pt'))
        # save model as its coressponding name
        torch.save(model.state_dict(),
                   os.path.join(args.modelDir, 'model-weights.pt'))
        utils.pickleSave(history, os.path.join(args.modelDir, 'history.pkl'))
class Trainer():
    """Trainer for a Seq2Seq spelling/auto-correction model.

    Trains on n-grams with synthetically injected noise (via SynthesizeData);
    configuration comes from module-level constants (DEVICE, NUM_ITERS,
    BATCH_SIZE, ...). Tracks losses per iteration and exports the best
    weights by full-sequence accuracy.
    """

    def __init__(self, alphabets_, list_ngram):
        # alphabets_: character set for the vocabulary;
        # list_ngram: corpus of n-gram phrases to train/validate on.
        self.vocab = Vocab(alphabets_)
        self.synthesizer = SynthesizeData(vocab_path="")
        self.list_ngrams_train, self.list_ngrams_valid = self.train_test_split(
            list_ngram, test_size=0.1)
        print("Loaded data!!!")
        print("Total training samples: ", len(self.list_ngrams_train))
        print("Total valid samples: ", len(self.list_ngrams_valid))
        # Character-level model: input and output vocabularies are identical.
        INPUT_DIM = self.vocab.__len__()
        OUTPUT_DIM = self.vocab.__len__()
        # Hyperparameters from module-level constants.
        self.device = DEVICE
        self.num_iters = NUM_ITERS
        self.beamsearch = BEAM_SEARCH
        self.batch_size = BATCH_SIZE
        self.print_every = PRINT_PER_ITER
        self.valid_every = VALID_PER_ITER
        self.checkpoint = CHECKPOINT
        self.export_weights = EXPORT
        # NOTE(review): despite the name, `metrics` is used as the max number
        # of validation samples (see validate()/precision()).
        self.metrics = MAX_SAMPLE_VALID
        logger = LOG
        # self.logger exists only when LOG is truthy; train() assumes it does.
        if logger:
            self.logger = Logger(logger)
        self.iter = 0
        self.model = Seq2Seq(input_dim=INPUT_DIM,
                             output_dim=OUTPUT_DIM,
                             encoder_embbeded=ENC_EMB_DIM,
                             decoder_embedded=DEC_EMB_DIM,
                             encoder_hidden=ENC_HID_DIM,
                             decoder_hidden=DEC_HID_DIM,
                             encoder_dropout=ENC_DROPOUT,
                             decoder_dropout=DEC_DROPOUT)
        self.optimizer = AdamW(self.model.parameters(),
                               betas=(0.9, 0.98),
                               eps=1e-09)
        # One-cycle schedule spanning exactly num_iters optimizer steps.
        self.scheduler = OneCycleLR(self.optimizer,
                                    total_steps=self.num_iters,
                                    pct_start=PCT_START,
                                    max_lr=MAX_LR)
        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)
        self.train_gen = self.data_gen(self.list_ngrams_train,
                                       self.synthesizer,
                                       self.vocab,
                                       is_train=True)
        self.valid_gen = self.data_gen(self.list_ngrams_valid,
                                       self.synthesizer,
                                       self.vocab,
                                       is_train=False)
        self.train_losses = []
        # to device
        self.model.to(self.device)
        self.criterion.to(self.device)

    def train_test_split(self, list_phrases, test_size=0.1):
        """Sequential (non-shuffled) split of phrases into train/valid lists."""
        list_phrases = list_phrases
        train_idx = int(len(list_phrases) * (1 - test_size))
        list_phrases_train = list_phrases[:train_idx]
        list_phrases_valid = list_phrases[train_idx:]
        return list_phrases_train, list_phrases_valid

    def data_gen(self, list_ngrams_np, synthesizer, vocab, is_train=True):
        """Build a DataLoader over an AutoCorrectDataset.

        Shuffles only for training; noise injection is delegated to the
        dataset via transform_noise.
        """
        dataset = AutoCorrectDataset(list_ngrams_np,
                                     transform_noise=synthesizer,
                                     vocab=vocab,
                                     maxlen=MAXLEN)
        shuffle = True if is_train else False
        gen = DataLoader(dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=shuffle,
                         drop_last=False)
        return gen

    def step(self, batch):
        """One optimizer step on a single batch; returns the scalar loss."""
        self.model.train()
        batch = self.batch_to_device(batch)
        src, tgt = batch['src'], batch['tgt']
        src, tgt = src.transpose(1, 0), tgt.transpose(
            1, 0)  # batch x src_len -> src_len x batch
        outputs = self.model(
            src, tgt)  # src : src_len x B, outpus : B x tgt_len x vocab
        # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
        outputs = outputs.view(-1, outputs.size(2))  # flatten(0, 1)
        tgt_output = tgt.transpose(0, 1).reshape(
            -1)  # flatten() # tgt: tgt_len xB , need convert to B x tgt_len
        loss = self.criterion(outputs, tgt_output)
        self.optimizer.zero_grad()
        loss.backward()
        # Clip to unit grad norm for stability.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
        self.optimizer.step()
        self.scheduler.step()
        loss_item = loss.item()
        return loss_item

    def train(self):
        """Main training loop over num_iters iterations.

        Cycles the train generator indefinitely; logs train loss every
        print_every iters, validates every valid_every iters, and exports
        weights whenever full-sequence accuracy improves.
        """
        print("Begin training from iter: ", self.iter)
        total_loss = 0
        total_loader_time = 0
        total_gpu_time = 0
        best_acc = -1
        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1
            start = time.time()
            # Restart the generator when a pass over the data completes.
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(self.train_gen)
                batch = next(data_iter)
            total_loader_time += time.time() - start
            start = time.time()
            loss = self.step(batch)
            total_gpu_time += time.time() - start
            total_loss += loss
            self.train_losses.append((self.iter, loss))
            if self.iter % self.print_every == 0:
                info = 'iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)
                # Reset accumulators for the next logging window.
                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                print(info)
                self.logger.log(info)
            if self.iter % self.valid_every == 0:
                val_loss, preds, actuals, inp_sents = self.validate()
                acc_full_seq, acc_per_char, cer = self.precision(self.metrics)
                info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f} - CER: {:.4f} '.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char, cer)
                print(info)
                print("--- Sentence predict ---")
                for pred, inp, label in zip(preds, inp_sents, actuals):
                    infor_predict = 'Pred: {} - Inp: {} - Label: {}'.format(
                        pred, inp, label)
                    print(infor_predict)
                    self.logger.log(infor_predict)
                self.logger.log(info)
                # Export weights only on improvement; checkpoint every time.
                if acc_full_seq > best_acc:
                    self.save_weights(self.export_weights)
                    best_acc = acc_full_seq
                self.save_checkpoint(self.checkpoint)

    def validate(self):
        """Compute validation loss (teacher forcing off) and sample predictions.

        Returns (mean_loss, first 3 predictions, first 3 labels, first 3 inputs).
        NOTE(review): self.predict(5) is re-run on every validation batch
        inside the loop — likely wasteful; confirm whether intentional.
        """
        self.model.eval()
        total_loss = []
        # Cap validation work at ~self.metrics samples (float division).
        max_step = self.metrics / self.batch_size
        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                src, tgt = batch['src'], batch['tgt']
                src, tgt = src.transpose(1, 0), tgt.transpose(1, 0)
                outputs = self.model(src, tgt, 0)  # turn off teaching force
                outputs = outputs.flatten(0, 1)
                tgt_output = tgt.flatten()
                loss = self.criterion(outputs, tgt_output)
                total_loss.append(loss.item())
                preds, actuals, inp_sents, probs = self.predict(5)
                del outputs
                del loss
                if step > max_step:
                    break
        total_loss = np.mean(total_loss)
        self.model.train()
        return total_loss, preds[:3], actuals[:3], inp_sents[:3]

    def predict(self, sample=None):
        """Decode validation batches into sentences.

        Uses beam search or greedy translate per self.beamsearch. Stops after
        `sample` sentences when given. Returns (preds, labels, inputs, prob),
        where prob is from the LAST processed batch only (None for beam search).
        """
        pred_sents = []
        actual_sents = []
        inp_sents = []
        for batch in self.valid_gen:
            batch = self.batch_to_device(batch)
            if self.beamsearch:
                translated_sentence = batch_translate_beam_search(
                    batch['src'], self.model)
                prob = None
            else:
                translated_sentence, prob = translate(batch['src'], self.model)
            pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
            actual_sent = self.vocab.batch_decode(batch['tgt'].tolist())
            inp_sent = self.vocab.batch_decode(batch['src'].tolist())
            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)
            inp_sents.extend(inp_sent)
            if sample is not None and len(pred_sents) > sample:
                break
        return pred_sents, actual_sents, inp_sents, prob

    def precision(self, sample=None):
        """Return (full-sequence accuracy, per-char accuracy, CER) on up to
        `sample` validation sentences."""
        pred_sents, actual_sents, _, _ = self.predict(sample=sample)
        acc_full_seq = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='per_char')
        cer = compute_accuracy(actual_sents, pred_sents, mode='CER')
        return acc_full_seq, acc_per_char, cer

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16):
        """Collect (and optionally filter to mispredicted) samples for display.

        NOTE(review): prepares fontdict/img_files but contains no plotting
        calls in this version — appears unfinished; confirm.
        """
        pred_sents, actual_sents, img_files, probs = self.predict(sample)
        if errorcase:
            # Keep only indices where prediction differs from the label.
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i] != actual_sents[i]:
                    wrongs.append(i)
            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]
            probs = [probs[i] for i in wrongs]
        img_files = img_files[:sample]
        fontdict = {'family': fontname, 'size': fontsize}

    def visualize_dataset(self, sample=16, fontname='serif'):
        """Iterate up to `sample` training items, decoding their targets.

        NOTE(review): img/sent are computed but never displayed here —
        appears unfinished; also assumes batches carry 'img'/'tgt_input'
        keys, unlike the 'src'/'tgt' keys used elsewhere — confirm.
        """
        n = 0
        for batch in self.train_gen:
            for i in range(self.batch_size):
                img = batch['img'][i].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(batch['tgt_input'].T[i].tolist())
                n += 1
                if n >= sample:
                    return

    def load_checkpoint(self, filename):
        """Restore model/optimizer/scheduler state and training progress."""
        checkpoint = torch.load(filename)
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']
        self.train_losses = checkpoint['train_losses']

    def save_checkpoint(self, filename):
        """Save full training state (counterpart of load_checkpoint)."""
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'scheduler': self.scheduler.state_dict()
        }
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)
        torch.save(state, filename)

    def load_weights(self, filename):
        """Load model weights, skipping missing or shape-mismatched tensors."""
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))
        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print('{} not found'.format(name))
            elif state_dict[name].shape != param.shape:
                print('{} missmatching shape, required {} but found {}'.format(
                    name, param.shape, state_dict[name].shape))
                # Drop incompatible tensor so strict=False can ignore it.
                del state_dict[name]
        self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        """Save only the model weights (no optimizer/scheduler state)."""
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)
        torch.save(self.model.state_dict(), filename)

    def batch_to_device(self, batch):
        """Move a {'src','tgt'} batch to self.device (non-blocking copies)."""
        src = batch['src'].to(self.device, non_blocking=True)
        tgt = batch['tgt'].to(self.device, non_blocking=True)
        batch = {'src': src, 'tgt': tgt}
        return batch
class Trainer:
    """
    Handles model training and evaluation.

    Arguments:
    ----------
    config: A dictionary of training parameters, likely from a .yaml file
    model: A pytorch segmentation model (e.g. DeepLabV3)
    trn_data: A pytorch dataloader object that will return pairs of images
    and segmentation masks from a training dataset
    val_data: A pytorch dataloader object that will return pairs of images
    and segmentation masks from a validation dataset.
    """

    def __init__(self, config, model, trn_data, val_data=None):
        self.config = config
        self.model = model.cuda()
        self.trn_data = DataFetcher(trn_data)
        self.val_data = val_data

        #create the optimizer
        if config['optim'] == 'SGD':
            self.optimizer = SGD(model.parameters(),
                                 lr=config['lr'],
                                 momentum=config['momentum'],
                                 weight_decay=config['wd'])
        elif config['optim'] == 'AdamW':
            self.optimizer = AdamW(
                model.parameters(), lr=config['lr'],
                weight_decay=config['wd'])  #momentum is default
        else:
            optim = config['optim']
            raise Exception(
                f'Optimizer {optim} is not supported! Must be SGD or AdamW')

        #create the learning rate scheduler
        schedule = config['lr_policy']
        if schedule == 'OneCycle':
            # OneCycle is stepped per-iteration; total_steps = config['iters'].
            self.scheduler = OneCycleLR(self.optimizer,
                                        config['lr'],
                                        total_steps=config['iters'])
        elif schedule == 'MultiStep':
            self.scheduler = MultiStepLR(self.optimizer,
                                         milestones=config['lr_decay_epochs'])
        elif schedule == 'Poly':
            # Polynomial decay to 0 over config['iters'] iterations.
            func = lambda iteration: (1 - (iteration / config['iters'])
                                      )**config['power']
            self.scheduler = LambdaLR(self.optimizer, func)
        else:
            lr_policy = config['lr_policy']
            raise Exception(
                f'Policy {lr_policy} is not supported! Must be OneCycle, MultiStep or Poly'
            )

        #create the loss criterion
        if config['num_classes'] > 1:
            #load class weights if they were given in the config file
            if 'class_weights' in config:
                weight = torch.Tensor(config['class_weights']).float().cuda()
            else:
                weight = None
            self.criterion = nn.CrossEntropyLoss(weight=weight).cuda()
        else:
            # Binary segmentation: single-logit output with BCE-with-logits.
            self.criterion = nn.BCEWithLogitsLoss().cuda()

        #define train and validation metrics and class names
        class_names = config['class_names']

        #make training metrics using the EMAMeter. this meter gives extra
        #weight to the most recent metric values calculated during training
        #this gives a better reflection of how well the model is performing
        #when the metrics are printed
        trn_md = {
            name: metric_lookup[name](EMAMeter())
            for name in config['metrics']
        }
        self.trn_metrics = ComposeMetrics(trn_md, class_names)
        self.trn_loss_meter = EMAMeter()

        #the only difference between train and validation metrics
        #is that we use the AverageMeter. this is because there are
        #no weight updates during evaluation, so all batches should
        #count equally
        val_md = {
            name: metric_lookup[name](AverageMeter())
            for name in config['metrics']
        }
        self.val_metrics = ComposeMetrics(val_md, class_names)
        self.val_loss_meter = AverageMeter()

        self.logging = config['logging']

        #now, if we're resuming from a previous run we need to load
        #the state for the model, optimizer, and schedule and resume
        #the mlflow run (if there is one and we're using logging)
        if config['resume']:
            self.resume(config['resume'])
        elif self.logging:
            #if we're not resuming, but are logging, then we
            #need to setup mlflow with a new experiment
            #everytime that Trainer is instantiated we want to
            #end the current active run and let a new one begin
            mlflow.end_run()
            #extract the experiment name from config so that
            #we know where to save our files, if experiment name
            #already exists, we'll use it, otherwise we create a
            #new experiment
            mlflow.set_experiment(self.config['experiment_name'])
            #add the config file as an artifact
            mlflow.log_artifact(config['config_file'])
            #we don't want to add everything in the config
            #to mlflow parameters, we'll just add the most
            #likely to change parameters
            mlflow.log_param('lr_policy', config['lr_policy'])
            mlflow.log_param('optim', config['optim'])
            mlflow.log_param('lr', config['lr'])
            mlflow.log_param('wd', config['wd'])
            mlflow.log_param('bsz', config['bsz'])
            mlflow.log_param('momentum', config['momentum'])
            mlflow.log_param('iters', config['iters'])
            mlflow.log_param('epochs', config['epochs'])
            mlflow.log_param('encoder', config['encoder'])
            mlflow.log_param('finetune_layer', config['finetune_layer'])
            mlflow.log_param('pretraining', config['pretraining'])

    def resume(self, checkpoint_fpath):
        """
        Sets model parameters, scheduler and optimizer states to the
        last recorded values in the given checkpoint file.
        """
        checkpoint = torch.load(checkpoint_fpath, map_location='cpu')
        self.model.load_state_dict(checkpoint['state_dict'])
        # Only restore optimizer/scheduler when continuing the same schedule.
        if not self.config['restart_training']:
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            # Reattach to the original mlflow run when one was recorded.
            if self.logging and 'run_id' in checkpoint:
                mlflow.start_run(run_id=checkpoint['run_id'])
        print(f'Loaded state from {checkpoint_fpath}')
        print(f'Resuming from epoch {self.scheduler.last_epoch}...')

    def log_metrics(self, step, dataset):
        """Log the current loss and per-class metrics to mlflow.

        dataset: 'train' or 'valid' — selects which meters to read and is
        used as a prefix on every logged metric name.
        """
        #get the corresponding losses and metrics dict for
        #either train or validation sets
        if dataset == 'train':
            losses = self.trn_loss_meter
            metric_dict = self.trn_metrics.metrics_dict
        elif dataset == 'valid':
            losses = self.val_loss_meter
            metric_dict = self.val_metrics.metrics_dict

        #log the last loss, using the dataset name as a prefix
        mlflow.log_metric(dataset + '_loss', losses.avg, step=step)

        #log all the metrics in our dict, using dataset as a prefix
        metrics = {}
        for k, v in metric_dict.items():
            values = v.meter.avg
            for class_name, val in zip(self.trn_metrics.class_names, values):
                metrics[dataset + '_' + class_name + '_' + k] = float(
                    val.item())
        mlflow.log_metrics(metrics, step=step)

    def train(self):
        """
        Defines a pytorch style training loop for the model withtqdm progress
        bar for each epoch and handles printing loss/metrics at the end of
        each epoch.

        epochs: Number of epochs to train model
        train_iters_per_epoch: Number of training iterations is each epoch.
        Reducing this number will give more frequent updates but result in
        slower training time.

        Results:
        ----------
        After train_iters_per_epoch iterations are completed, it will evaluate
        the model on val_data if there is any, then prints loss and metrics for
        train and validation datasets.
        """
        #set the inner and outer training loop as either
        #iterations or epochs depending on our scheduler
        if self.config['lr_policy'] != 'MultiStep':
            # Iteration-based schedules: each "epoch" is a single iteration.
            last_epoch = self.scheduler.last_epoch + 1
            total_epochs = self.config['iters']
            iters_per_epoch = 1
            outer_loop = tqdm(range(last_epoch, total_epochs + 1),
                              file=sys.stdout,
                              initial=last_epoch,
                              total=total_epochs)
            inner_loop = range(iters_per_epoch)
        else:
            # Epoch-based schedule: inner loop walks the whole dataloader.
            last_epoch = self.scheduler.last_epoch + 1
            total_epochs = self.config['epochs']
            iters_per_epoch = len(self.trn_data)
            outer_loop = range(last_epoch, total_epochs + 1)
            inner_loop = tqdm(range(iters_per_epoch), file=sys.stdout)

        #determine the epochs at which to print results
        eval_epochs = total_epochs // self.config['num_prints']
        save_epochs = total_epochs // self.config['num_save_checkpoints']

        #the cudnn.benchmark flag speeds up performance
        #when the model input size is constant. See:
        #https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        cudnn.benchmark = True

        #perform training over the outer and inner loops
        for epoch in outer_loop:
            for iteration in inner_loop:
                #load the next batch of training data
                images, masks = self.trn_data.load()
                #run the training iteration
                loss, output = self._train_1_iteration(images, masks)
                #record the loss and evaluate metrics
                self.trn_loss_meter.update(loss)
                self.trn_metrics.evaluate(output, masks)

            #when we're at an eval_epoch we want to print
            #the training results so far and then evaluate
            #the model on the validation data
            if epoch % eval_epochs == 0:
                #before printing results let's record everything in mlflow
                #(if we're using logging)
                if self.logging:
                    self.log_metrics(epoch, dataset='train')
                print('\n')  #print a new line to give space from progess bar
                print(f'train_loss: {self.trn_loss_meter.avg:.3f}')
                self.trn_loss_meter.reset()
                #prints and automatically resets the metric averages to 0
                self.trn_metrics.print()

                #run evaluation if we have validation data
                if self.val_data is not None:
                    #before evaluation we want to turn off cudnn
                    #benchmark because the input sizes of validation
                    #images are not necessarily constant
                    cudnn.benchmark = False
                    self.evaluate()
                    if self.logging:
                        self.log_metrics(epoch, dataset='valid')
                    print(
                        '\n')  #print a new line to give space from progess bar
                    print(f'valid_loss: {self.val_loss_meter.avg:.3f}')
                    self.val_loss_meter.reset()
                    #prints and automatically resets the metric averages to 0
                    self.val_metrics.print()
                    #turn cudnn.benchmark back on before returning to training
                    cudnn.benchmark = True

            #update the optimizer schedule
            self.scheduler.step()

            #the last step is to save the training state if
            #at a checkpoint
            if epoch % save_epochs == 0:
                self.save_state(epoch)

    def _train_1_iteration(self, images, masks):
        """One forward/backward/step on a single batch.

        Returns (scalar loss value, detached model output).
        """
        #run a training step
        self.model.train()
        self.optimizer.zero_grad()
        #forward pass
        output = self.model(images)
        loss = self.criterion(output, masks)
        #backward pass
        loss.backward()
        self.optimizer.step()
        #return the loss value and the output
        return loss.item(), output.detach()

    def evaluate(self):
        """
        Evaluation method used at the end of each epoch. Not intended to
        generate predictions for validation dataset, it only returns average
        loss and stores metrics for validaiton dataset. Use Validator class
        for generating masks on a dataset.
        """
        #set the model into eval mode
        self.model.eval()
        val_iter = DataFetcher(self.val_data)
        for _ in range(len(val_iter)):
            with torch.no_grad():
                #load batch of data
                images, masks = val_iter.load()
                output = self.model.eval()(images)
                loss = self.criterion(output, masks)
                self.val_loss_meter.update(loss.item())
                self.val_metrics.evaluate(output.detach(), masks)
        #loss and metrics are updated inplace, so there's nothing to return
        return None

    def save_state(self, epoch):
        """
        Saves the self.model state dict

        Arguments:
        ------------
        save_path: Path of .pt file for saving

        Example:
        ----------
        trainer = Trainer(...)
        trainer.save_model(model_path + 'new_model.pt')
        """
        #save the state together with the norms that we're using
        state = {
            'state_dict': self.model.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'norms': self.config['training_norms']
        }
        # Record the mlflow run id so resume() can reattach to this run.
        if self.logging:
            state['run_id'] = mlflow.active_run().info.run_id

        #the last step is to create the name of the file to save
        #the format is: name-of-experiment_pretraining_epoch.pth
        model_dir = self.config['model_dir']
        exp_name = self.config['experiment_name']
        pretraining = self.config['pretraining']
        ft_layer = self.config['finetune_layer']
        if self.config['lr_policy'] != 'MultiStep':
            total_epochs = self.config['iters']
        else:
            total_epochs = self.config['epochs']
        if os.path.isfile(pretraining):
            #this is slightly clunky, but it handles the case
            #of using custom pretrained weights from a file
            #usually there aren't any '.'s other than the file
            #extension
            pretraining = pretraining.split('/')[-2]  #.split('.')[0]
        save_path = os.path.join(
            model_dir,
            f'{exp_name}-{pretraining}_ft_{ft_layer}_epoch{epoch}_of_{total_epochs}.pth'
        )
        torch.save(state, save_path)
class Detector(object):
    """Object-detection training wrapper.

    Wraps a torchvision-style detection model (returns a loss dict when given
    images + targets in train mode, and per-image prediction dicts in eval
    mode) with AdamW, a OneCycleLR schedule, gradient accumulation, mixup
    training and mAP evaluation.
    """

    def __init__(self, cfg):
        # cfg: dict with at least 'device', 'network' and 'nepochs'.
        self.device = cfg["device"]
        self.model = Models().get_model(cfg["network"])  # cfg.network
        self.model.to(self.device)
        # Optimize only trainable parameters (frozen backbone layers skipped).
        params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = AdamW(params, lr=0.00001)
        # NOTE(review): steps_per_epoch is hardcoded to 169
        # (len(dataloader)/accumulations per the original comment) — this
        # silently desyncs from any other dataset/batch size; confirm.
        self.lr_scheduler = OneCycleLR(self.optimizer,
                                       max_lr=1e-4,
                                       epochs=cfg["nepochs"],
                                       steps_per_epoch=169,  # len(dataloader)/accumulations
                                       div_factor=25,  # for initial lr, default: 25
                                       final_div_factor=1e3,  # for final lr, default: 1e4
                                       )

    def fit(self, data_loader, accumulation_steps=4, wandb=None):
        """Train for one epoch with gradient accumulation.

        Returns ({'train_avg_loss': avg}, dict of averaged per-loss values).
        The optimizer (and scheduler) step every `accumulation_steps` batches.
        NOTE(review): the accumulated loss is not divided by
        accumulation_steps before backward — effective gradients are
        accumulation_steps times larger than per-batch averaging; confirm
        whether intentional. `wandb` parameter is unused here.
        """
        self.model.train()
        # metric_logger = utils.MetricLogger(delimiter="  ")
        # metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
        avg_loss = MetricLogger('scalar')
        total_loss = MetricLogger('dict')
        lr_log = MetricLogger('list')
        self.optimizer.zero_grad()
        device = self.device
        for i, (images, targets) in enumerate(data_loader):
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()}
                       for t in targets]
            # In train mode the model returns a dict of component losses.
            loss_dict = self.model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.detach().item()
            # Abort on NaN/inf loss rather than corrupting the weights.
            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)
            losses.backward()
            # Step only every accumulation_steps batches.
            if (i+1) % accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
                    lr_log.update(self.lr_scheduler.get_last_lr())
            print(f"\rTrain iteration: [{i+1}/{len(data_loader)}]", end="")
            avg_loss.update(loss_value)
            total_loss.update(loss_dict)
            # metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
            # metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        print()
        #print(loss_dict)
        return {"train_avg_loss": avg_loss.avg}, total_loss.avg

    def mixup_fit(self, data_loader, accumulation_steps=4, wandb=None):
        """Train for one epoch on mixup pairs.

        Expects `data_loader` to yield PAIRS of batches; images are blended
        via mixup_images and targets merged via merge_targets, then the loop
        matches fit(). Same un-normalized accumulation caveat as fit().
        NOTE(review): the progress print hardcodes 674 as the total — confirm.
        """
        self.model.train()
        torch.cuda.empty_cache()
        # metric_logger = utils.MetricLogger(delimiter="  ")
        # metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
        avg_loss = MetricLogger('scalar')
        total_loss = MetricLogger('dict')
        #lr_log = MetricLogger('list')
        self.optimizer.zero_grad()
        device = self.device
        for i, (batch1, batch2) in enumerate(data_loader):
            images1, targets1 = batch1
            images2, targets2 = batch2
            images = mixup_images(images1, images2)
            targets = merge_targets(targets1, targets2)
            # Free the un-mixed tensors before moving data to the GPU.
            del images1, images2, targets1, targets2, batch1, batch2
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()}
                       for t in targets]
            loss_dict = self.model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.detach().item()
            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)
            losses.backward()
            if (i+1) % accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
                    #lr_log.update(self.lr_scheduler.get_last_lr())
            print(f"Train iteration: [{i+1}/{674}]\r", end="")
            avg_loss.update(loss_value)
            total_loss.update(loss_dict)
            # metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
            # metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        print()
        #print(loss_dict)
        return {"train_avg_loss": avg_loss.avg}, total_loss.avg

    def evaluate(self, val_dataloader):
        """Evaluate mAP on the validation set.

        Keeps boxes with score > 0.6 and scores them against ground truth at
        IoU thresholds 0.5–0.75. Returns {'validation_mAP_score': avg}.
        """
        device = self.device
        torch.cuda.empty_cache()
        # self.model.to(device)
        self.model.eval()
        mAp_logger = MetricLogger('list')
        with torch.no_grad():
            for (j, batch) in enumerate(val_dataloader):
                print(f"\rValidation: [{j+1}/{len(val_dataloader)}]", end="")
                images, targets = batch
                del batch
                images = [img.to(device) for img in images]
                # targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
                predictions = self.model(images)#, targets)
                for i, pred in enumerate(predictions):
                    probas = pred["scores"].detach().cpu().numpy()
                    # Confidence filter before matching against ground truth.
                    mask = probas > 0.6
                    preds = pred["boxes"].detach().cpu().numpy()[mask]
                    gts = targets[i]["boxes"].detach().cpu().numpy()
                    # `score` (the aggregate) is unused; only the per-threshold
                    # list feeds the logger.
                    score, scores = map_score(
                        gts, preds, thresholds=[.5, .55, .6, .65, .7, .75])
                    mAp_logger.update(scores)
        print()
        return {"validation_mAP_score": mAp_logger.avg}

    def get_checkpoint(self):
        """Return a dict with model and optimizer state (scheduler state
        intentionally excluded — see commented-out code)."""
        self.model.eval()
        model_state = self.model.state_dict()
        optimizer_state = self.optimizer.state_dict()
        checkpoint = {'model_state_dict': model_state,
                      'optimizer_state_dict': optimizer_state
                      }
        # if self.lr_scheduler:
        #     scheduler_state = self.lr_scheduler.state_dict()
        #     checkpoint['lr_scheduler_state_dict'] = scheduler_state
        return checkpoint

    def load_checkpoint(self, checkpoint):
        """Restore model and optimizer state from get_checkpoint() output."""
        self.model.eval()
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
def main() -> None:
    """Entry point for training/validating the WindowedMLP ancestry model.

    Parses CLI args, builds datasets from a VCF file, optionally resumes from
    a checkpoint, then either runs validation only (--validate) or trains for
    args.epochs epochs, checkpointing every args.save_freq epochs and on the
    final epoch. Tracks the best (lowest) loss in the module-global
    `best_loss`, which must be initialized at module level.
    """
    global best_loss
    args = parser.parse_args()

    # Optional determinism: seed both RNGs and force deterministic cuDNN.
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    start_epoch = 0
    vcf_reader = VCFReader(args.train_data, args.classification_map,
                           args.chromosome, args.class_hierarchy)
    vcf_writer = vcf_reader.get_vcf_writer()
    train_dataset, validation_dataset = vcf_reader.get_datasets(
        args.validation_split)
    # Samplers group each batch by label.
    train_sampler = BatchByLabelRandomSampler(args.batch_size,
                                              train_dataset.labels)
    train_loader = DataLoader(train_dataset, batch_sampler=train_sampler)
    if args.validation_split != 0:
        validation_sampler = BatchByLabelRandomSampler(
            args.batch_size, validation_dataset.labels)
        validation_loader = DataLoader(validation_dataset,
                                       batch_sampler=validation_sampler)

    # Model kwargs are saved into checkpoints so a resumed run can verify
    # architecture compatibility.
    kwargs = {
        'total_size': vcf_reader.positions.shape[0],
        'window_size': args.window_size,
        'num_layers': args.layers,
        'num_classes': len(vcf_reader.label_encoder.classes_),
        'num_super_classes': len(vcf_reader.super_label_encoder.classes_)
    }
    model = WindowedMLP(**kwargs)
    model.to(get_device(args))
    optimizer = AdamW(model.parameters(), lr=args.learning_rate)

    # Resume from a previous checkpoint, validating that the checkpoint
    # matches the freshly-built model and data.
    if args.resume_path is not None:
        if os.path.isfile(args.resume_path):
            print("=> loading checkpoint '{}'".format(args.resume_path))
            checkpoint = torch.load(args.resume_path)
            if kwargs != checkpoint['model_kwargs']:
                raise ValueError(
                    'The checkpoint\'s kwargs don\'t match the ones used to initialize the model'
                )
            if vcf_reader.snps.shape[0] != checkpoint['vcf_writer'].snps.shape[
                    0]:
                raise ValueError(
                    'The data on which the checkpoint was trained had a different number of snp positions'
                )
            start_epoch = checkpoint['epoch']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume_path, checkpoint['epoch']))
        else:
            # BUG FIX: this previously formatted `args.resume`, an attribute
            # that doesn't exist (every other reference uses
            # `args.resume_path`), so a missing checkpoint file raised
            # AttributeError instead of printing this message.
            print("=> no checkpoint found at '{}'".format(args.resume_path))

    # Validation-only mode: evaluate and exit.
    if args.validate:
        validate(validation_loader, model,
                 nn.functional.binary_cross_entropy_with_logits,
                 len(vcf_reader.label_encoder.classes_),
                 len(vcf_reader.super_label_encoder.classes_), vcf_reader.maf,
                 args)
        return

    for epoch in range(start_epoch, args.epochs + start_epoch):
        loss = train(train_loader, model,
                     nn.functional.binary_cross_entropy_with_logits,
                     optimizer, len(vcf_reader.label_encoder.classes_),
                     len(vcf_reader.super_label_encoder.classes_),
                     vcf_reader.maf, epoch, args)
        # Checkpoint on the save frequency and always on the final epoch.
        if epoch % args.save_freq == 0 or epoch == args.epochs + start_epoch - 1:
            if args.validation_split != 0:
                # Track best by validation loss when a validation split exists.
                validation_loss = validate(
                    validation_loader, model,
                    nn.functional.binary_cross_entropy_with_logits,
                    len(vcf_reader.label_encoder.classes_),
                    len(vcf_reader.super_label_encoder.classes_),
                    vcf_reader.maf, args)
                is_best = validation_loss < best_loss
                best_loss = min(validation_loss, best_loss)
            else:
                # No validation data: fall back to training loss.
                is_best = loss < best_loss
                best_loss = min(loss, best_loss)
            # Everything needed to resume or to decode predictions later is
            # bundled into the checkpoint.
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'model_kwargs': kwargs,
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                    'vcf_writer': vcf_writer,
                    'label_encoder': vcf_reader.label_encoder,
                    'super_label_encoder': vcf_reader.super_label_encoder,
                    'maf': vcf_reader.maf
                }, is_best, args.chromosome, args.model_name, args.model_dir)
class SAJEM():
    '''
    Self-Attention based Joint Embedding Model.

    Consists of 2 branches that encode images and text into a common
    embedding space, trained with a bidirectional margin ranking loss.
    BERT is kept frozen during epoch 1 (text features detached) and
    unfrozen from epoch 2 onward.
    '''

    def __init__(self,
                 image_encoder,
                 text_encoder,
                 image_mha,
                 bert_model,
                 optimizer='adam',
                 lr=1e-3,
                 l2_regularization=1e-2,
                 margin_loss=1e-2,
                 max_violation=True,
                 cost_style='mean',
                 use_lr_scheduler=False,
                 grad_clip=0,
                 num_training_steps=30000,
                 device='cuda'):
        """Wire up the sub-modules, optimizer, schedulers and loss.

        Note: ``lr`` and ``l2_regularization`` are accepted for interface
        compatibility but the per-group learning rates below are fixed
        (3e-5 for BERT, 1e-4 for the rest), as in the original code.
        """
        self.image_mha = image_mha
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.bert_model = bert_model
        self.device = device
        self.use_lr_scheduler = use_lr_scheduler

        # Flat parameter list, used only for gradient clipping in train().
        self.params = list(self.image_mha.parameters())
        self.params += list(self.text_encoder.parameters())
        self.params += list(self.image_encoder.parameters())
        self.params += list(self.bert_model.parameters())

        self.grad_clip = grad_clip
        self.frozen = False

        # Two parameter groups: small LR for the pretrained BERT encoder,
        # larger LR for the randomly-initialized projection heads.
        param_groups = [{
            'params': list(self.bert_model.parameters()),
            'lr': 3e-5
        }, {
            'params':
            list(self.image_encoder.parameters()) +
            list(self.text_encoder.parameters()) +
            list(self.image_mha.parameters()),
            'lr': 1e-4
        }]
        if optimizer == 'adamW':
            self.optimizer = AdamW(param_groups)
        elif optimizer == 'adam':
            self.optimizer = torch.optim.Adam(param_groups)

        if self.use_lr_scheduler:
            # lr_scheduler: linear warmup/decay used once BERT is unfrozen;
            # lr_scheduler_0: constant schedule used while BERT is frozen.
            self.lr_scheduler = get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=100,
                num_training_steps=num_training_steps)
            self.lr_scheduler_0 = get_constant_schedule(self.optimizer)

        # Bidirectional margin ranking loss over the joint embedding space.
        self.mrl_loss = MarginRankingLoss(margin=margin_loss,
                                          max_violation=max_violation,
                                          cost_style=cost_style,
                                          direction='bidir')

    def forward(self, image_feature, image_attention_mask, input_ids,
                attention_mask, epoch):
        """Encode one batch of (image, text) into the common space.

        Returns (image_to_common, text_to_common). During epoch 1 the text
        features are detached (BERT frozen); from epoch 2 the freeze is
        lifted and the warm-up scheduler is discarded.
        """
        if epoch > 1 and self.frozen:
            self.frozen = False
            # FIX: lr_scheduler_0 only exists when use_lr_scheduler=True;
            # the original unconditional `del` raised AttributeError
            # otherwise.
            if self.use_lr_scheduler:
                del self.lr_scheduler_0
            torch.cuda.empty_cache()
        image_feature = l2norm(image_feature).detach()
        final_image_features = l2norm(
            self.image_mha(image_feature, image_attention_mask))
        text_feature = self.bert_model(input_ids,
                                       attention_mask=attention_mask)
        text_feature = l2norm(text_feature)
        if epoch == 1:
            # Freeze BERT for the first epoch by cutting its gradient path.
            text_feature = text_feature.detach()
            self.frozen = True
        image_to_common = self.image_encoder(final_image_features)
        text_to_common = self.text_encoder(text_feature)
        return image_to_common, text_to_common

    def save_network(self, folder):
        """Persist all sub-module, optimizer and scheduler states to folder."""
        torch.save(self.image_mha.state_dict(),
                   os.path.join(folder, 'image_mha.pth'))
        torch.save(self.text_encoder.state_dict(),
                   os.path.join(folder, 'text_encoder.pth'))
        torch.save(self.image_encoder.state_dict(),
                   os.path.join(folder, 'image_encoder.pth'))
        torch.save(self.bert_model.state_dict(),
                   os.path.join(folder, 'bert_model.pth'))
        torch.save(self.optimizer.state_dict(),
                   os.path.join(folder, 'optimizer.pth'))
        if self.use_lr_scheduler:
            torch.save(self.lr_scheduler.state_dict(),
                       os.path.join(folder, 'scheduler.pth'))

    def switch_to_train(self):
        """Put every sub-module in training mode."""
        self.image_mha.train()
        self.text_encoder.train()
        self.image_encoder.train()
        self.bert_model.train()

    def switch_to_eval(self):
        """Put every sub-module in eval mode."""
        self.image_mha.eval()
        self.text_encoder.eval()
        self.image_encoder.eval()
        self.bert_model.eval()

    def train(self, image_features, image_attention_mask, input_ids,
              attention_mask, epoch):
        """One optimization step on a batch; returns the scalar loss."""
        self.switch_to_train()
        image_to_common, text_to_common = self.forward(image_features,
                                                       image_attention_mask,
                                                       input_ids,
                                                       attention_mask, epoch)
        self.optimizer.zero_grad()
        # Compute loss
        loss = self.mrl_loss(text_to_common, image_to_common)
        loss.backward()
        if self.grad_clip > 0:
            torch.nn.utils.clip_grad.clip_grad_norm_(self.params,
                                                     self.grad_clip)
        self.optimizer.step()
        return loss.item()

    def step_scheduler(self):
        """Advance the active LR schedule (constant while frozen, linear after).

        FIX: made a no-op when use_lr_scheduler is False — the original
        else-branch dereferenced lr_scheduler_0, which is never created in
        that configuration.
        """
        if not self.use_lr_scheduler:
            return
        if self.frozen:
            self.lr_scheduler_0.step()
        else:
            self.lr_scheduler.step()

    def evaluate(self, val_image_dataloader, val_text_dataloader, k):
        """Recall@k of text->image retrieval over the validation loaders.

        Image embeddings are computed once; each text query is then scored
        against all of them and counted as a hit if its ground-truth image id
        appears in the top-k.
        """
        self.switch_to_eval()
        with torch.no_grad():
            # Encode every validation image into the common space.
            image_features = []
            image_ids = []
            for ids, features, image_attention_mask in val_image_dataloader:
                image_ids.append(torch.stack(ids))
                features = torch.stack(features).to(self.device)
                image_attention_mask = torch.stack(image_attention_mask).to(
                    self.device)
                features = l2norm(features).detach()
                mha_features = l2norm(
                    self.image_mha(features, image_attention_mask))
                image_features.append(self.image_encoder(mha_features))
            image_features = torch.cat(image_features, dim=0)
            image_ids = torch.cat(image_ids, dim=0).to(self.device)

            # Score each text query against all image embeddings.
            recall = 0
            total_query = 0
            pbar = tqdm(enumerate(val_text_dataloader),
                        total=len(val_text_dataloader),
                        leave=False,
                        position=0,
                        file=sys.stdout)
            for i, (image_files, input_ids, attention_mask) in pbar:
                input_ids = input_ids.to(self.device)
                attention_mask = attention_mask.to(self.device)
                text_features = self.bert_model(input_ids,
                                                attention_mask=attention_mask)
                text_features = l2norm(text_features)
                text_features = self.text_encoder(text_features)
                # Filenames carry a 12-digit (COCO-style) numeric id.
                # FIX: was `.to(device)` — `device` is not defined in this
                # scope; the instance attribute self.device is the intended
                # target.
                image_files = torch.tensor(
                    list(
                        map(lambda x: int(re.findall(r'\d{12}', x)[0]),
                            image_files))).to(self.device)
                top_k = get_top_k_eval(text_features, image_features, k)
                for idx, indices in enumerate(top_k):
                    total_query += 1
                    true_image_id = image_files[idx]
                    top_k_images = torch.gather(image_ids, 0, indices)
                    if (top_k_images == true_image_id).nonzero().numel() > 0:
                        recall += 1
            recall = recall / total_query
        return recall
def main(args):
    """Train a basecaller model.

    Loads chunked training data, builds the model from its TOML config,
    optionally resumes optimizer/model state from ``workdir``, then runs the
    epoch loop, saving weights/optimizer snapshots per epoch and appending
    metrics to ``training.csv``.
    """
    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists, use -f to force continue training." % workdir)
        exit(1)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    chunks, targets, lengths = load_data(
        limit=args.chunks, shuffle=True, directory=args.directory
    )
    # The first `validation_split` fraction of chunks trains; the rest tests.
    split = np.floor(chunks.shape[0] * args.validation_split).astype(np.int32)
    train_dataset = ChunkDataSet(chunks[:split], targets[:split], lengths[:split])
    test_dataset = ChunkDataSet(chunks[split:], targets[split:], lengths[split:])
    train_loader = DataLoader(train_dataset, batch_size=args.batch,
                              shuffle=True, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch,
                             num_workers=4, pin_memory=True)

    config = toml.load(args.config)
    argsdict = dict(training=vars(args))

    chunk_config = {}
    chunk_config_file = os.path.join(args.directory, 'config.toml')
    if os.path.isfile(chunk_config_file):
        # FIX: was toml.load(os.path.join(chunk_config_file)) — a redundant
        # single-argument os.path.join wrapper.
        chunk_config = toml.load(chunk_config_file)

    os.makedirs(workdir, exist_ok=True)
    # FIX: close the merged-config file deterministically (the original
    # passed a bare open() handle to toml.dump and leaked it).
    with open(os.path.join(workdir, 'config.toml'), 'w') as conf_fh:
        toml.dump({**config, **argsdict, **chunk_config}, conf_fh)

    print("[loading model]")
    model = load_symbol(config, 'Model')(config)
    optimizer = AdamW(model.parameters(), amsgrad=False, lr=args.lr)
    last_epoch = load_state(workdir, args.device, model, optimizer,
                            use_amp=args.amp)

    # Cosine decay 1.0 -> 0.1 over the whole run with a 500-step warmup,
    # resumed from the right step when continuing training.
    lr_scheduler = func_scheduler(
        optimizer, cosine_decay_schedule(1.0, 0.1),
        args.epochs * len(train_loader),
        warmup_steps=500,
        start_step=last_epoch * len(train_loader)
    )

    if args.multi_gpu:
        from torch.nn import DataParallel
        model = DataParallel(model)
        # Re-expose attributes that DataParallel hides behind .module.
        model.decode = model.module.decode
        model.alphabet = model.module.alphabet

    if hasattr(model, 'seqdist'):
        criterion = model.seqdist.ctc_loss
    else:
        criterion = None

    for epoch in range(1 + last_epoch, args.epochs + 1 + last_epoch):
        try:
            train_loss, duration = train(
                model, device, train_loader, optimizer,
                criterion=criterion, use_amp=args.amp, lr_scheduler=lr_scheduler
            )
            val_loss, val_mean, val_median = test(
                model, device, test_loader, criterion=criterion
            )
        except KeyboardInterrupt:
            # Allow a clean manual stop between epochs.
            break

        print("[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%".format(
            epoch, workdir, val_loss, val_mean, val_median
        ))

        model_state = model.state_dict() if not args.multi_gpu else model.module.state_dict()
        torch.save(model_state, os.path.join(workdir, "weights_%s.tar" % epoch))
        torch.save(optimizer.state_dict(), os.path.join(workdir, "optim_%s.tar" % epoch))

        with open(os.path.join(workdir, 'training.csv'), 'a', newline='') as csvfile:
            csvw = csv.writer(csvfile, delimiter=',')
            # FIX: write the header whenever the file is empty (the original
            # `epoch == 1` check skipped the header after resuming into a
            # fresh file). In append mode tell() reports the current size.
            if csvfile.tell() == 0:
                csvw.writerow([
                    'time', 'duration', 'epoch', 'train_loss',
                    'validation_loss', 'validation_mean', 'validation_median'
                ])
            csvw.writerow([
                datetime.today(), int(duration), epoch,
                train_loss, val_loss, val_mean, val_median,
            ])
global_step) logger.add_images("test/2_output_outline", unnormalize(output) * (1 - outline), global_step) # Log these only once if first_run: logger.add_images("test/1_target", unnormalize(target), global_step) logger.add_images("test/3_target_outline", unnormalize(target) * (1 - outline), global_step) logger.add_images("test/4_input_morphed", unnormalize(GMM_morph), global_step) logger.add_images("test/5_input_outline", outline, global_step) first_run = False if global_step % save_interval == 0: output_path = Path( f"./training/checkpoints/{run_id}/E{e}_L{loss_G.item()}.pth" ) output_path.parent.mkdir(parents=True, exist_ok=True) torch.save( { "G": G.state_dict(), "e": e, "i": i, "run_id": run_id, "optimizer_G": optimizer_G.state_dict() }, output_path) print(f"Saved {output_path.stem}.")
class Seq2seqKpGen(object):
    """High level model that handles intializing the underlying network
    architecture, saving, updating examples, and predicting examples.

    Wraps a ``Sequence2Sequence`` network for keyphrase generation: owns the
    optimizer/scheduler, optional apex fp16 (AMP) setup, training updates,
    beam-free prediction with present/absent keyphrase splitting, and
    (de)serialization of model + training state.
    """

    # --------------------------------------------------------------------------
    # Initialization
    # --------------------------------------------------------------------------

    def __init__(self, args, word_dict, state_dict=None):
        """Build the underlying network; optionally restore its weights.

        Args:
            args: run configuration namespace (mutated: vocab_size is set).
            word_dict: vocabulary mapping used by the network.
            state_dict: optional pre-trained network weights to load.
        """
        # Book-keeping.
        self.args = args
        self.word_dict = word_dict
        self.args.vocab_size = len(word_dict)
        self.updates = 0  # number of optimizer steps taken so far
        self.network = Sequence2Sequence(self.args, self.word_dict)
        if state_dict:
            self.network.load_state_dict(state_dict)

    def activate_fp16(self):
        """Switch to fp16: plain .half() for inference, apex AMP for training.

        If no optimizer has been initialized yet, the model is assumed to be
        in test-only mode and simply halved; otherwise apex AMP wraps both
        network and optimizer.
        """
        if not hasattr(self, 'optimizer'):
            self.network.half()  # for testing only
            return
        try:
            # `global amp` so update() can reference the module later.
            global amp
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        # https://github.com/NVIDIA/apex/issues/227
        assert self.optimizer is not None
        self.network, self.optimizer = amp.initialize(self.network,
                                                      self.optimizer,
                                                      opt_level=self.args.fp16_opt_level)

    def init_optimizer(self, optim_state=None, sched_state=None):
        """Create AdamW + constant-with-warmup schedule; optionally restore state.

        Args:
            optim_state: optional optimizer state_dict to load (tensors are
                moved to args.device when running on CUDA).
            sched_state: optional scheduler state_dict to load.
        """

        def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1):
            # Linear warmup from 0 to the base LR, then constant.
            def lr_lambda(current_step: int):
                if current_step < num_warmup_steps:
                    return float(current_step) / float(max(1.0, num_warmup_steps))
                return 1.0

            return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

        # Standard transformer practice: no weight decay on biases/LayerNorm.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.network.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                "weight_decay": self.args.weight_decay,
            },
            {
                "params": [p for n, p in self.network.named_parameters()
                           if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0
            },
        ]
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=self.args.learning_rate)
        self.scheduler = get_constant_schedule_with_warmup(self.optimizer,
                                                           self.args.warmup_steps)
        if optim_state:
            self.optimizer.load_state_dict(optim_state)
            if self.args.device.type == 'cuda':
                # Restored optimizer tensors land on CPU; move them to the GPU.
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.to(self.args.device)
        if sched_state:
            self.scheduler.load_state_dict(sched_state)

    # --------------------------------------------------------------------------
    # Learning
    # --------------------------------------------------------------------------

    def update(self, ex):
        """Forward a batch of examples; step the optimizer to update weights.

        Args:
            ex: batch dict with source/target tensors (and copy-attention
                extras when args.copy_attn is set).
        Returns:
            dict with 'ml_loss' (scalar) and 'perplexity' (exp of the
            per-token loss, clamped at exp(10) to avoid overflow).
        """
        # NOTE(review): assumes init_optimizer() was called first; if the
        # attribute doesn't exist this raises AttributeError, not the
        # RuntimeError below — confirm intended call order.
        if not self.optimizer:
            raise RuntimeError('No optimizer set.')
        # Train mode
        self.network.train()

        source_map, alignment = None, None
        if self.args.copy_attn:
            source_map = make_src_map(ex['src_map']).to(self.args.device)
            alignment = align(ex['alignment']).to(self.args.device)

        source_rep = ex['source_rep'].to(self.args.device)
        source_len = ex['source_len'].to(self.args.device)
        target_rep = ex['target_rep'].to(self.args.device)
        target_len = ex['target_len'].to(self.args.device)

        # Run forward
        ml_loss, loss_per_token = self.network(source=source_rep,
                                               source_len=source_len,
                                               target=target_rep,
                                               target_len=target_len,
                                               src_map=source_map,
                                               alignment=alignment)

        # Average across GPUs under DataParallel.
        loss = ml_loss.mean() if self.args.n_gpu > 1 else ml_loss
        if self.args.fp16:
            global amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            clip_grad_norm_(amp.master_params(self.optimizer),
                            self.args.grad_clipping)
        else:
            loss.backward()
            clip_grad_norm_(self.network.parameters(), self.args.grad_clipping)

        self.updates += 1
        self.optimizer.step()
        self.scheduler.step()  # Update learning rate schedule
        self.optimizer.zero_grad()

        loss_per_token = loss_per_token.mean() if self.args.n_gpu > 1 \
            else loss_per_token
        loss_per_token = loss_per_token.item()
        # Clamp before exp() so perplexity stays finite/representable.
        loss_per_token = 10 if loss_per_token > 10 else loss_per_token
        perplexity = math.exp(loss_per_token)

        return {
            'ml_loss': loss.item(),
            'perplexity': perplexity
        }

    # --------------------------------------------------------------------------
    # Prediction
    # --------------------------------------------------------------------------

    def predict(self, ex, replace_unk=False):
        """Forward a batch of examples only to get predictions.

        Args:
            ex: the batch examples
            replace_unk: replace `unk` tokens while generating predictions
            src_raw: raw source (passage); required to replace `unk` term
        Output:
            predictions: #batch predicted sequences

        Returns a dict splitting each prediction at PRESENT_EOS into present
        keyphrases (everything before the last segment) and absent keyphrases
        (the last segment), each with a per-phrase normalized product score.
        """

        def convert_text_to_string(text):
            """ Converts a sequence of tokens (string) in a single string. """
            # Undo WordPiece-style "##" subword joins.
            out_string = text.replace(" ##", "").strip()
            return out_string

        self.network.eval()

        source_map, alignment = None, None
        blank, fill = None, None
        if self.args.copy_attn:
            source_map = make_src_map(ex['src_map']).to(self.args.device)
            alignment = align(ex['alignment']).to(self.args.device)
            blank, fill = collapse_copy_scores(self.word_dict, ex['src_vocab'])

        source_rep = ex['source_rep'].to(self.args.device)
        source_len = ex['source_len'].to(self.args.device)

        # target=None puts the network in free-running generation mode.
        decoder_out = self.network(source=source_rep,
                                   source_len=source_len,
                                   target=None,
                                   target_len=None,
                                   src_map=source_map,
                                   alignment=alignment,
                                   max_len=self.args.max_tgt_len,
                                   tgt_dict=self.word_dict,
                                   blank=blank,
                                   fill=fill,
                                   source_vocab=ex['src_vocab'])

        dec_probs = torch.exp(decoder_out['dec_log_probs'])
        predictions, scores = tens2sen_score(decoder_out['predictions'],
                                             dec_probs,
                                             self.word_dict,
                                             ex['src_vocab'])
        if replace_unk:
            # Replace each <unk> with the source token of highest attention.
            for i in range(len(predictions)):
                enc_dec_attn = decoder_out['attentions'][i]
                if self.args.model_type == 'transformer':
                    # tgt_len x num_heads x src_len
                    assert enc_dec_attn.dim() == 3
                    enc_dec_attn = enc_dec_attn.mean(1)
                predictions[i] = replace_unknown(predictions[i],
                                                 enc_dec_attn,
                                                 src_raw=ex['source'][i].tokens)

        # Stringify scores, mirroring separator tokens so that predictions
        # and scores can be split identically below.
        for bidx in range(ex['batch_size']):
            for i in range(len(predictions[bidx])):
                if predictions[bidx][i] == constants.KP_SEP:
                    scores[bidx][i] = constants.KP_SEP
                elif predictions[bidx][i] == constants.PRESENT_EOS:
                    scores[bidx][i] = constants.PRESENT_EOS
                else:
                    assert isinstance(scores[bidx][i], float)
                    scores[bidx][i] = str(scores[bidx][i])

        predictions = [' '.join(item) for item in predictions]
        scores = [' '.join(item) for item in scores]

        present_kps = []
        absent_kps = []
        present_kp_scores = []
        absent_kp_scores = []
        for bidx in range(ex['batch_size']):
            # Everything before the final PRESENT_EOS segment is "present";
            # the final segment holds the "absent" keyphrases.
            keyphrases = predictions[bidx].split(constants.PRESENT_EOS)
            kp_scores = scores[bidx].split(constants.PRESENT_EOS)
            pkps = (' %s ' % constants.KP_SEP).join(keyphrases[:-1])
            pkp_scores = (' %s ' % constants.KP_SEP).join(kp_scores[:-1])
            akps = keyphrases[-1]
            akp_scores = kp_scores[-1]

            pre_kps = []
            pre_kp_scores = []
            for pkp, pkp_s in zip(pkps.split(constants.KP_SEP),
                                  pkp_scores.split(constants.KP_SEP)):
                pkp = pkp.strip()
                if pkp:
                    pre_kps.append(convert_text_to_string(pkp))
                    # Phrase score = product of token probs / phrase length.
                    t_scores = [float(i) for i in pkp_s.strip().split()]
                    _score = np.prod(t_scores) / len(t_scores)
                    pre_kp_scores.append(_score)
            present_kps.append(pre_kps)
            present_kp_scores.append(pre_kp_scores)

            abs_kps = []
            abs_kp_scores = []
            for akp, akp_s in zip(akps.split(constants.KP_SEP),
                                  akp_scores.split(constants.KP_SEP)):
                akp = akp.strip()
                if akp:
                    abs_kps.append(convert_text_to_string(akp))
                    t_scores = [float(i) for i in akp_s.strip().split()]
                    _score = np.prod(t_scores) / len(t_scores)
                    abs_kp_scores.append(_score)
            absent_kps.append(abs_kps)
            absent_kp_scores.append(abs_kp_scores)

        return {
            'present_kps': present_kps,
            'absent_kps': absent_kps,
            'present_kp_scores': present_kp_scores,
            'absent_kp_scores': absent_kp_scores
        }

    # --------------------------------------------------------------------------
    # Saving and loading
    # --------------------------------------------------------------------------

    def save(self, filename):
        """Save weights + vocab + args (no optimizer state); never raises."""
        # Unwrap DataParallel before saving.
        network = self.network.module if hasattr(self.network, "module") \
            else self.network
        state_dict = copy.copy(network.state_dict())
        params = {
            'state_dict': state_dict,
            'word_dict': self.word_dict,
            'args': self.args,
        }
        try:
            torch.save(params, filename)
        except BaseException:
            logger.warning('WARN: Saving failed... continuing anyway.')

    def checkpoint(self, filename, epoch):
        """Save a full resumable checkpoint (weights, optimizer, scheduler)."""
        network = self.network.module if hasattr(self.network, "module") \
            else self.network
        params = {
            'state_dict': network.state_dict(),
            'word_dict': self.word_dict,
            'args': self.args,
            'epoch': epoch,
            'updates': self.updates,
            'optim_dict': self.optimizer.state_dict(),
            'sched_dict': self.scheduler.state_dict(),
        }
        try:
            torch.save(params, filename)
        except BaseException:
            logger.warning('WARN: Saving failed... continuing anyway.')

    @staticmethod
    def load(filename, new_args=None):
        """Load a saved model (weights only); new_args may override args."""
        logger.info('Loading model %s' % filename)
        saved_params = torch.load(
            filename, map_location=lambda storage, loc: storage
        )
        word_dict = saved_params['word_dict']
        state_dict = saved_params['state_dict']
        args = saved_params['args']
        if new_args:
            args = override_model_args(args, new_args)
        return Seq2seqKpGen(args, word_dict, state_dict)

    @staticmethod
    def load_checkpoint(filename):
        """Load a full checkpoint and rebuild optimizer/scheduler state.

        Returns:
            (model, epoch) — the restored model and the epoch to resume at.
        """
        logger.info('Loading model %s' % filename)
        saved_params = torch.load(
            filename, map_location=lambda storage, loc: storage
        )
        word_dict = saved_params['word_dict']
        state_dict = saved_params['state_dict']
        epoch = saved_params['epoch']
        updates = saved_params['updates']
        optim_dict = saved_params['optim_dict']
        sched_dict = saved_params['sched_dict']
        args = saved_params['args']
        model = Seq2seqKpGen(args, word_dict, state_dict)
        model.updates = updates
        model.init_optimizer(optim_dict, sched_dict)
        return model, epoch

    # --------------------------------------------------------------------------
    # Runtime
    # --------------------------------------------------------------------------

    def to(self, device):
        """Move the network to the given device."""
        self.network = self.network.to(device)

    def parallelize(self):
        """Wrap the network in DataParallel for multi-GPU training."""
        self.network = torch.nn.DataParallel(self.network)
class Trainer():
    """OCR model trainer (seq2seq/transformer or CRNN backend).

    Owns the model, vocab, optimizer/scheduler, LMDB-backed data loaders,
    TensorBoard writer and file logger; drives the iteration-based training
    loop with periodic validation, checkpointing and visualization.
    """

    def __init__(self, config, pretrained=True, augmentor=ImgAugTransform()):
        # NOTE(review): `augmentor=ImgAugTransform()` is evaluated once at
        # class-definition time and shared across Trainer instances — confirm
        # the transform is stateless. `pretrained` is unused here; weights
        # come from config['trainer']['pretrained'] instead.
        self.config = config
        self.model, self.vocab = build_model(config)

        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.train_lmdb = config['dataset']['train_lmdb']
        self.valid_lmdb = config['dataset']['valid_lmdb']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.image_aug = config['aug']['image_aug']
        self.masked_language_model = config['aug']['masked_language_model']
        self.metrics = config['trainer']['metrics']
        self.is_padding = config['dataset']['is_padding']

        self.tensorboard_dir = config['monitor']['log_dir']
        if not os.path.exists(self.tensorboard_dir):
            os.makedirs(self.tensorboard_dir, exist_ok=True)
        self.writer = SummaryWriter(self.tensorboard_dir)

        # LOGGER
        self.logger = Logger(config['monitor']['log_dir'])
        self.logger.info(config)

        self.iter = 0
        self.best_acc = 0

        # Scheduler only exists when training from scratch; fine-tuning uses
        # a fixed small LR (see step()).
        self.scheduler = None
        self.is_finetuning = config['trainer']['is_finetuning']
        if self.is_finetuning:
            self.logger.info("Finetuning model ---->")
            if self.model.seq_modeling == 'crnn':
                self.optimizer = Adam(lr=0.0001,
                                      params=self.model.parameters(),
                                      betas=(0.5, 0.999))
            else:
                self.optimizer = AdamW(lr=0.0001,
                                       params=self.model.parameters(),
                                       betas=(0.9, 0.98),
                                       eps=1e-09)
        else:
            self.optimizer = AdamW(self.model.parameters(),
                                   betas=(0.9, 0.98),
                                   eps=1e-09)
            self.scheduler = OneCycleLR(self.optimizer,
                                        total_steps=self.num_iters,
                                        **config['optimizer'])

        # CRNN uses CTC; seq2seq/transformer uses label smoothing.
        if self.model.seq_modeling == 'crnn':
            self.criterion = torch.nn.CTCLoss(self.vocab.pad,
                                              zero_infinity=True)
        else:
            self.criterion = LabelSmoothingLoss(len(self.vocab),
                                                padding_idx=self.vocab.pad,
                                                smoothing=0.1)

        # Pretrained model
        if config['trainer']['pretrained']:
            self.load_weights(config['trainer']['pretrained'])
            self.logger.info("Loaded trained model from: {}".format(
                config['trainer']['pretrained']))
        # Resume
        elif config['trainer']['resume_from']:
            self.load_checkpoint(config['trainer']['resume_from'])
            # Restored optimizer tensors may be on CPU; move to target device.
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(torch.device(self.device))
            self.logger.info("Resume training from {}".format(
                config['trainer']['resume_from']))

        # DATASET
        transforms = None
        if self.image_aug:
            transforms = augmentor

        train_lmdb_paths = [
            os.path.join(self.data_root, lmdb_path)
            for lmdb_path in self.train_lmdb
        ]
        self.train_gen = self.data_gen(
            lmdb_paths=train_lmdb_paths,
            data_root=self.data_root,
            annotation=self.train_annotation,
            masked_language_model=self.masked_language_model,
            transform=transforms,
            is_train=True)
        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                lmdb_paths=[os.path.join(self.data_root, self.valid_lmdb)],
                data_root=self.data_root,
                annotation=self.valid_annotation,
                masked_language_model=False)

        self.train_losses = []
        self.logger.info("Number batch samples of training: %d" %
                         len(self.train_gen))
        # NOTE(review): self.valid_gen only exists when valid_annotation is
        # truthy — the next line raises AttributeError otherwise; confirm
        # valid_annotation is always configured.
        self.logger.info("Number batch samples of valid: %d" %
                         len(self.valid_gen))

        config_savepath = os.path.join(self.tensorboard_dir, "config.yml")
        if not os.path.exists(config_savepath):
            self.logger.info("Saving config file at: %s" % config_savepath)
            Cfg(config).save(config_savepath)

    def train(self):
        """Main loop: run num_iters steps with periodic logging/validation.

        Every print_every iters the averaged train loss is logged; every
        valid_every iters (when valid_annotation is set) the model is
        validated, best weights and a resumable checkpoint are saved, and
        losses/WER are pushed to TensorBoard.
        """
        total_loss = 0

        total_loader_time = 0
        total_gpu_time = 0

        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1

            start = time.time()
            try:
                batch = next(data_iter)
            except StopIteration:
                # Dataset exhausted: restart the iterator (epoch boundary).
                data_iter = iter(self.train_gen)
                batch = next(data_iter)
            total_loader_time += time.time() - start

            start = time.time()
            # LOSS
            loss = self.step(batch)
            total_loss += loss
            self.train_losses.append((self.iter, loss))
            total_gpu_time += time.time() - start

            if self.iter % self.print_every == 0:
                info = 'Iter: {:06d} - Train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)
                lastest_loss = total_loss / self.print_every
                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                self.logger.info(info)

            if self.valid_annotation and self.iter % self.valid_every == 0:
                val_time = time.time()
                val_loss = self.validate()
                acc_full_seq, acc_per_char, wer = self.precision(self.metrics)
                self.logger.info("Iter: {:06d}, start validating".format(
                    self.iter))
                info = 'Iter: {:06d} - Valid loss: {:.3f} - Acc full seq: {:.4f} - Acc per char: {:.4f} - WER: {:.4f} - Time: {:.4f}'.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char, wer,
                    time.time() - val_time)
                self.logger.info(info)

                if acc_full_seq > self.best_acc:
                    self.save_weights(self.tensorboard_dir + "/best.pt")
                    self.best_acc = acc_full_seq
                self.logger.info("Iter: {:06d} - Best acc: {:.4f}".format(
                    self.iter, self.best_acc))

                filename = 'last.pt'
                filepath = os.path.join(self.tensorboard_dir, filename)
                self.logger.info("Save checkpoint %s" % filename)
                self.save_checkpoint(filepath)

                # NOTE(review): `lastest_loss` is only bound after the first
                # print_every block — if valid_every fires first this raises
                # UnboundLocalError. Confirm valid_every is a multiple of
                # print_every.
                log_loss = {'train loss': lastest_loss, 'val loss': val_loss}
                self.writer.add_scalars('Loss', log_loss, self.iter)
                self.writer.add_scalar('WER', wer, self.iter)

    def validate(self):
        """One no-grad pass over valid_gen; returns the mean batch loss."""
        self.model.eval()
        total_loss = []
        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                img, tgt_input, tgt_output, tgt_padding_mask = batch[
                    'img'], batch['tgt_input'], batch['tgt_output'], batch[
                        'tgt_padding_mask']

                outputs = self.model(img, tgt_input, tgt_padding_mask)
                # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
                if self.model.seq_modeling == 'crnn':
                    # CTC needs per-sample input lengths and label lengths.
                    length = batch['labels_len']
                    preds_size = torch.autograd.Variable(
                        torch.IntTensor([outputs.size(0)] * self.batch_size))
                    loss = self.criterion(outputs, tgt_output, preds_size,
                                          length)
                else:
                    outputs = outputs.flatten(0, 1)
                    tgt_output = tgt_output.flatten()
                    loss = self.criterion(outputs, tgt_output)

                total_loss.append(loss.item())
                del outputs
                del loss

        total_loss = np.mean(total_loss)
        self.model.train()

        return total_loss

    def predict(self, sample=None):
        """Decode the validation set; optionally stop after `sample` items.

        Returns (pred_sents, actual_sents, img_files, probs_sents,
        imgs_sents). Also logs a grid of the first predictions to
        TensorBoard on the first batch.
        """
        pred_sents = []
        actual_sents = []
        img_files = []
        probs_sents = []
        imgs_sents = []

        for idx, batch in enumerate(tqdm.tqdm(self.valid_gen)):
            batch = self.batch_to_device(batch)

            if self.model.seq_modeling != 'crnn':
                if self.beamsearch:
                    translated_sentence = batch_translate_beam_search(
                        batch['img'], self.model)
                    prob = None
                else:
                    translated_sentence, prob = translate(
                        batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist())
            else:
                translated_sentence, prob = translate_crnn(
                    batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist(), crnn=True)

            actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())

            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)
            imgs_sents.extend(batch['img'])
            img_files.extend(batch['filenames'])
            probs_sents.extend(prob)

            # Visualize in tensorboard
            if idx == 0:
                try:
                    num_samples = self.config['monitor']['num_samples']
                    fig = plt.figure(figsize=(12, 15))
                    imgs_samples = imgs_sents[:num_samples]
                    preds_samples = pred_sents[:num_samples]
                    actuals_samples = actual_sents[:num_samples]
                    probs_samples = probs_sents[:num_samples]
                    for id_img in range(len(imgs_samples)):
                        img = imgs_samples[id_img]
                        img = img.permute(1, 2, 0)
                        img = img.cpu().detach().numpy()
                        ax = fig.add_subplot(num_samples,
                                             1,
                                             id_img + 1,
                                             xticks=[],
                                             yticks=[])
                        plt.imshow(img)
                        ax.set_title(
                            "LB: {} \n Pred: {:.4f}-{}".format(
                                actuals_samples[id_img],
                                probs_samples[id_img],
                                preds_samples[id_img]),
                            color=('green'
                                   if actuals_samples[id_img] ==
                                   preds_samples[id_img] else 'red'),
                            fontdict={
                                'fontsize': 18,
                                'fontweight': 'medium'
                            })
                    self.writer.add_figure('predictions vs. actuals',
                                           fig,
                                           global_step=self.iter)
                except Exception as error:
                    # Visualization is best-effort; never abort prediction.
                    print(error)
                    continue

            if sample != None and len(pred_sents) > sample:
                break

        return pred_sents, actual_sents, img_files, probs_sents, imgs_sents

    def precision(self, sample=None, measure_time=True):
        """Compute full-seq accuracy, per-char accuracy and WER on predictions."""
        t1 = time.time()
        pred_sents, actual_sents, _, _, _ = self.predict(sample=sample)
        time_predict = time.time() - t1
        sensitive_case = self.config['predictor']['sensitive_case']
        acc_full_seq = compute_accuracy(actual_sents,
                                        pred_sents,
                                        sensitive_case,
                                        mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents,
                                        pred_sents,
                                        sensitive_case,
                                        mode='per_char')
        wer = compute_accuracy(actual_sents,
                               pred_sents,
                               sensitive_case,
                               mode='wer')
        if measure_time:
            # Average decode time per validation sentence.
            print("Time: {:.4f}".format(time_predict / len(actual_sents)))
        return acc_full_seq, acc_per_char, wer

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16,
                             save_fig=False):
        """Plot a grid of predictions vs labels (optionally only errors)."""
        pred_sents, actual_sents, img_files, probs, imgs = self.predict(sample)

        if errorcase:
            # Keep only mispredicted samples.
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i] != actual_sents[i]:
                    wrongs.append(i)
            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]
            probs = [probs[i] for i in wrongs]
            imgs = [imgs[i] for i in wrongs]

        img_files = img_files[:sample]

        fontdict = {'family': fontname, 'size': fontsize}

        ncols = 5
        nrows = int(math.ceil(len(img_files) / ncols))
        fig, ax = plt.subplots(nrows, ncols, figsize=(12, 15))
        for vis_idx in range(0, len(img_files)):
            row = vis_idx // ncols
            col = vis_idx % ncols
            pred_sent = pred_sents[vis_idx]
            actual_sent = actual_sents[vis_idx]
            prob = probs[vis_idx]
            img = imgs[vis_idx].permute(1, 2, 0).cpu().detach().numpy()
            ax[row, col].imshow(img)
            ax[row, col].set_title(
                "Pred: {: <2} \n Actual: {} \n prob: {:.2f}".format(
                    pred_sent, actual_sent, prob),
                fontname=fontname,
                color='r' if pred_sent != actual_sent else 'g')
            ax[row, col].get_xaxis().set_ticks([])
            ax[row, col].get_yaxis().set_ticks([])

        plt.subplots_adjust()
        if save_fig:
            fig.savefig('vis_prediction.png')
        plt.show()

    def log_prediction(self, sample=16, csv_file='model.csv'):
        """Dump up to `sample` predictions to a CSV file."""
        pred_sents, actual_sents, img_files, probs, imgs = self.predict(sample)
        save_predictions(csv_file, pred_sents, actual_sents, img_files)

    def vis_data(self, sample=20):
        """Save a grid of `sample` training images + labels to vis_dataset.png."""
        ncols = 5
        nrows = int(math.ceil(sample / ncols))
        fig, ax = plt.subplots(nrows, ncols, figsize=(12, 12))
        num_plots = 0
        for idx, batch in enumerate(self.train_gen):
            for vis_idx in range(self.batch_size):
                row = num_plots // ncols
                col = num_plots % ncols
                img = batch['img'][vis_idx].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(
                    batch['tgt_input'].T[vis_idx].tolist())
                ax[row, col].imshow(img)
                ax[row, col].set_title("Label: {: <2}".format(sent),
                                       fontsize=16,
                                       color='g')
                ax[row, col].get_xaxis().set_ticks([])
                ax[row, col].get_yaxis().set_ticks([])
                num_plots += 1
                if num_plots >= sample:
                    plt.subplots_adjust()
                    fig.savefig('vis_dataset.png')
                    return

    def load_checkpoint(self, filename):
        """Restore model/optimizer/scheduler + bookkeeping from a checkpoint."""
        checkpoint = torch.load(filename)
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']
        self.train_losses = checkpoint['train_losses']
        if self.scheduler is not None:
            self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.best_acc = checkpoint['best_acc']

    def save_checkpoint(self, filename):
        """Write a resumable checkpoint (model, optimizer, scheduler, stats)."""
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'scheduler':
            None if self.scheduler is None else self.scheduler.state_dict(),
            'best_acc': self.best_acc
        }

        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)
        torch.save(state, filename)

    def load_weights(self, filename):
        """Load model weights from either a checkpoint or a bare state_dict.

        For bare state_dicts, keys that are missing or shape-mismatched are
        dropped and loading proceeds non-strictly.
        """
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))
        if self.is_checkpoint(state_dict):
            self.model.load_state_dict(state_dict['state_dict'])
        else:
            for name, param in self.model.named_parameters():
                if name not in state_dict:
                    print('{} not found'.format(name))
                elif state_dict[name].shape != param.shape:
                    print('{} missmatching shape, required {} but found {}'.
                          format(name, param.shape, state_dict[name].shape))
                    del state_dict[name]
            self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        """Save only the model's state_dict (no optimizer)."""
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)
        torch.save(self.model.state_dict(), filename)

    def is_checkpoint(self, checkpoint):
        """True if the loaded object is a full checkpoint dict, not a bare state_dict."""
        try:
            checkpoint['state_dict']
        except:
            return False
        else:
            return True

    def batch_to_device(self, batch):
        """Move the tensor fields of a batch dict onto self.device."""
        img = batch['img'].to(self.device, non_blocking=True)
        tgt_input = batch['tgt_input'].to(self.device, non_blocking=True)
        tgt_output = batch['tgt_output'].to(self.device, non_blocking=True)
        tgt_padding_mask = batch['tgt_padding_mask'].to(self.device,
                                                        non_blocking=True)

        batch = {
            'img': img,
            'tgt_input': tgt_input,
            'tgt_output': tgt_output,
            'tgt_padding_mask': tgt_padding_mask,
            'filenames': batch['filenames'],
            'labels_len': batch['labels_len']
        }

        return batch

    def data_gen(self,
                 lmdb_paths,
                 data_root,
                 annotation,
                 masked_language_model=True,
                 transform=None,
                 is_train=False):
        """Build a DataLoader over one or more LMDB OCR datasets.

        NOTE(review): when len(self.train_lmdb) <= 1, `dataset` stays bound
        to the last loop variable rather than a ConcatDataset — confirm this
        single-dataset path is intended. Also, when is_padding is False a
        sampler is passed together with shuffle=is_train; PyTorch rejects
        sampler + shuffle=True — confirm is_train is False on that path.
        """
        datasets = []
        for lmdb_path in lmdb_paths:
            dataset = OCRDataset(
                lmdb_path=lmdb_path,
                root_dir=data_root,
                annotation_path=annotation,
                vocab=self.vocab,
                transform=transform,
                image_height=self.config['dataset']['image_height'],
                image_min_width=self.config['dataset']['image_min_width'],
                image_max_width=self.config['dataset']['image_max_width'],
                separate=self.config['dataset']['separate'],
                batch_size=self.batch_size,
                is_padding=self.is_padding)
            datasets.append(dataset)
        if len(self.train_lmdb) > 1:
            dataset = torch.utils.data.ConcatDataset(datasets)

        if self.is_padding:
            sampler = None
        else:
            # Cluster samples of similar width into batches.
            sampler = ClusterRandomSampler(dataset, self.batch_size, True)

        collate_fn = Collator(masked_language_model)

        gen = DataLoader(dataset,
                         batch_size=self.batch_size,
                         sampler=sampler,
                         collate_fn=collate_fn,
                         shuffle=is_train,
                         drop_last=self.model.seq_modeling == 'crnn',
                         **self.config['dataloader'])

        return gen

    def step(self, batch):
        """One forward/backward/optimizer step; returns the scalar loss."""
        self.model.train()

        batch = self.batch_to_device(batch)
        img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch[
            'tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']

        outputs = self.model(img,
                             tgt_input,
                             tgt_key_padding_mask=tgt_padding_mask)
        # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
        if self.model.seq_modeling == 'crnn':
            length = batch['labels_len']
            preds_size = torch.autograd.Variable(
                torch.IntTensor([outputs.size(0)] * self.batch_size))
            loss = self.criterion(outputs, tgt_output, preds_size, length)
        else:
            outputs = outputs.view(
                -1, outputs.size(2))  # flatten(0, 1) # B*S x N_class
            tgt_output = tgt_output.view(-1)  # flatten() # B*S
            loss = self.criterion(outputs, tgt_output)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
        self.optimizer.step()
        # OneCycleLR steps per-iteration; fine-tuning has no scheduler.
        if not self.is_finetuning:
            self.scheduler.step()

        loss_item = loss.item()

        return loss_item

    def count_parameters(self, model):
        """Number of trainable parameters in `model`."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    def gen_pseudo_labels(self, outfile=None):
        """Decode valid_gen and write '||||'-separated pseudo-labels to outfile."""
        pred_sents = []
        img_files = []
        probs_sents = []
        for idx, batch in enumerate(tqdm.tqdm(self.valid_gen)):
            batch = self.batch_to_device(batch)

            if self.model.seq_modeling != 'crnn':
                if self.beamsearch:
                    translated_sentence = batch_translate_beam_search(
                        batch['img'], self.model)
                    prob = None
                else:
                    translated_sentence, prob = translate(
                        batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist())
            else:
                translated_sentence, prob = translate_crnn(
                    batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist(), crnn=True)

            pred_sents.extend(pred_sent)
            img_files.extend(batch['filenames'])
            probs_sents.extend(prob)

        assert len(pred_sents) == len(img_files) and len(img_files) == len(
            probs_sents)
        with open(outfile, 'w', encoding='utf-8') as f:
            for anno in zip(img_files, pred_sents, probs_sents):
                f.write('||||'.join([anno[0], anno[1],
                                     str(float(anno[2]))]) + '\n')
class Trainer():
    """Iteration-based OCR training driver.

    Builds the model/vocab from `config`, optionally loads pretrained
    weights, and trains with AdamW + OneCycleLR and a label-smoothing
    criterion. Validation runs every `valid_every` iterations; the weights
    with the best full-sequence accuracy are exported.
    """

    def __init__(self, config, pretrained=True):
        self.config = config
        self.model, self.vocab = build_model(config)
        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.checkpoint = config['trainer']['checkpoint']
        self.export_weights = config['trainer']['export']
        self.metrics = config['trainer']['metrics']
        logger = config['trainer']['log']

        # Logger is only created when a log path is configured; other
        # methods call self.logger unconditionally — confirm logging is
        # always enabled in practice.
        if logger:
            self.logger = Logger(logger)

        if pretrained:
            weight_file = download_weights(**config['pretrain'],
                                           quiet=config['quiet'])
            self.load_weights(weight_file)

        self.iter = 0

        self.optimizer = AdamW(self.model.parameters(),
                               betas=(0.9, 0.98),
                               eps=1e-09)
        self.scheduler = OneCycleLR(self.optimizer, **config['optimizer'])
        # self.optimizer = ScheduledOptim(
        #    Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
        #    #config['transformer']['d_model'],
        #    512,
        #    **config['optimizer'])

        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)

        transforms = ImgAugTransform()

        self.train_gen = self.data_gen('train_{}'.format(self.dataset_name),
                                       self.data_root,
                                       self.train_annotation,
                                       transform=transforms)
        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                'valid_{}'.format(self.dataset_name), self.data_root,
                self.valid_annotation)

        self.train_losses = []

    def train(self):
        """Main training loop over `num_iters` iterations.

        Cycles the train loader indefinitely (re-creating the iterator on
        StopIteration), prints/logs a running loss every `print_every`
        iterations, and validates every `valid_every` iterations, exporting
        weights when full-sequence accuracy improves.
        """
        total_loss = 0

        total_loader_time = 0
        total_gpu_time = 0
        best_acc = 0

        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1

            start = time.time()

            try:
                batch = next(data_iter)
            except StopIteration:
                # Exhausted the loader: restart it for the next pass.
                data_iter = iter(self.train_gen)
                batch = next(data_iter)

            total_loader_time += time.time() - start

            start = time.time()
            loss = self.step(batch)
            total_gpu_time += time.time() - start

            total_loss += loss
            self.train_losses.append((self.iter, loss))

            if self.iter % self.print_every == 0:
                info = 'iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)

                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                print(info)
                self.logger.log(info)

            if self.valid_annotation and self.iter % self.valid_every == 0:
                val_loss = self.validate()
                acc_full_seq, acc_per_char = self.precision(self.metrics)

                info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f}'.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char)
                print(info)
                self.logger.log(info)

                if acc_full_seq > best_acc:
                    self.save_weights(self.export_weights)
                    best_acc = acc_full_seq

    def validate(self):
        """Compute the mean validation loss over the validation loader."""
        self.model.eval()

        total_loss = []

        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                img, tgt_input, tgt_output, tgt_padding_mask = batch[
                    'img'], batch['tgt_input'], batch['tgt_output'], batch[
                        'tgt_padding_mask']

                # NOTE(review): passes the padding mask POSITIONALLY while
                # step() passes it as tgt_key_padding_mask= — confirm the
                # model's third positional parameter is the same mask.
                outputs = self.model(img, tgt_input, tgt_padding_mask)
                # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))

                outputs = outputs.flatten(0, 1)
                tgt_output = tgt_output.flatten()
                loss = self.criterion(outputs, tgt_output)

                total_loss.append(loss.item())

                del outputs
                del loss

        total_loss = np.mean(total_loss)
        self.model.train()

        return total_loss

    def predict(self, sample=None):
        """Decode predictions (and ground truth) over the validation loader.

        Returns (pred_sents, actual_sents, img_files). If `sample` is given,
        stops once more than `sample` predictions are collected.
        """
        pred_sents = []
        actual_sents = []
        img_files = []

        for batch in self.valid_gen:
            batch = self.batch_to_device(batch)

            if self.beamsearch:
                translated_sentence = batch_translate_beam_search(
                    batch['img'], self.model)
            else:
                translated_sentence = translate(batch['img'], self.model)

            pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
            actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())

            img_files.extend(batch['filenames'])

            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)

            if sample != None and len(pred_sents) > sample:
                break

        return pred_sents, actual_sents, img_files

    def precision(self, sample=None):
        """Return (full-sequence accuracy, per-character accuracy)."""
        pred_sents, actual_sents, _ = self.predict(sample=sample)

        acc_full_seq = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='per_char')

        return acc_full_seq, acc_per_char

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16):
        """Plot up to `sample` validation images with predicted vs actual text.

        With errorcase=True, only mispredicted samples are shown.
        """
        pred_sents, actual_sents, img_files = self.predict(sample)

        if errorcase:
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i] != actual_sents[i]:
                    wrongs.append(i)

            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]

        img_files = img_files[:sample]

        fontdict = {'family': fontname, 'size': fontsize}

        for vis_idx in range(0, len(img_files)):
            img_path = img_files[vis_idx]
            pred_sent = pred_sents[vis_idx]
            actual_sent = actual_sents[vis_idx]

            img = Image.open(open(img_path, 'rb'))
            plt.figure()
            plt.imshow(img)
            plt.title('pred: {} - actual: {}'.format(pred_sent, actual_sent),
                      loc='left',
                      fontdict=fontdict)
            plt.axis('off')

        plt.show()

    def visualize_dataset(self, sample=16, fontname='serif'):
        """Plot up to `sample` raw training images with their target text."""
        n = 0
        for batch in self.train_gen:
            for i in range(self.batch_size):
                # CHW tensor -> HWC numpy image for matplotlib.
                img = batch['img'][i].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(batch['tgt_input'].T[i].tolist())

                plt.figure()
                plt.title('sent: {}'.format(sent),
                          loc='center',
                          fontname=fontname)
                plt.imshow(img)
                plt.axis('off')

                n += 1
                if n >= sample:
                    plt.show()
                    return

    def load_checkpoint(self, filename):
        """Restore optimizer/model/iteration/loss-history from a checkpoint.

        NOTE(review): `optim` is constructed but never used — the state dict
        is loaded into self.optimizer (AdamW), not into this ScheduledOptim.
        Looks like dead code left from the old optimizer; confirm and remove.
        """
        checkpoint = torch.load(filename)

        optim = ScheduledOptim(
            Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
            self.config['transformer']['d_model'], **self.config['optimizer'])

        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']
        self.train_losses = checkpoint['train_losses']

    def save_checkpoint(self, filename):
        """Save iteration, model & optimizer state, and loss history."""
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses
        }

        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(state, filename)

    def load_weights(self, filename):
        """Load weights non-strictly, skipping missing/shape-mismatched keys."""
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))

        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print('{} not found'.format(name))
            elif state_dict[name].shape != param.shape:
                print('{} missmatching shape'.format(name))
                del state_dict[name]

        self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        """Save only the model's state dict (creates parent dirs)."""
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(self.model.state_dict(), filename)

    def batch_to_device(self, batch):
        """Move the tensor fields of a batch dict to self.device."""
        img = batch['img'].to(self.device, non_blocking=True)
        tgt_input = batch['tgt_input'].to(self.device, non_blocking=True)
        tgt_output = batch['tgt_output'].to(self.device, non_blocking=True)
        tgt_padding_mask = batch['tgt_padding_mask'].to(self.device,
                                                        non_blocking=True)

        batch = {
            'img': img,
            'tgt_input': tgt_input,
            'tgt_output': tgt_output,
            'tgt_padding_mask': tgt_padding_mask,
            'filenames': batch['filenames']
        }

        return batch

    def data_gen(self, lmdb_path, data_root, annotation, transform=None):
        """Build a DataLoader over an OCRDataset with cluster-random sampling."""
        dataset = OCRDataset(
            lmdb_path=lmdb_path,
            root_dir=data_root,
            annotation_path=annotation,
            vocab=self.vocab,
            transform=transform,
            image_height=self.config['dataset']['image_height'],
            image_min_width=self.config['dataset']['image_min_width'],
            image_max_width=self.config['dataset']['image_max_width'])

        # Sampler groups images of similar width; shuffle must stay False
        # because the sampler already randomizes order.
        sampler = ClusterRandomSampler(dataset, self.batch_size, True)
        gen = DataLoader(dataset,
                         batch_size=self.batch_size,
                         sampler=sampler,
                         collate_fn=collate_fn,
                         shuffle=False,
                         drop_last=False,
                         **self.config['dataloader'])

        return gen

    def data_gen_v1(self, lmdb_path, data_root, annotation):
        """Legacy data generator (CPU-side DataGen) kept for compatibility."""
        data_gen = DataGen(
            data_root,
            annotation,
            self.vocab,
            'cpu',
            image_height=self.config['dataset']['image_height'],
            image_min_width=self.config['dataset']['image_min_width'],
            image_max_width=self.config['dataset']['image_max_width'])

        return data_gen

    def step(self, batch):
        """One forward/backward/optimizer step; returns the scalar loss."""
        self.model.train()

        batch = self.batch_to_device(batch)
        img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch[
            'tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']

        outputs = self.model(img,
                             tgt_input,
                             tgt_key_padding_mask=tgt_padding_mask)
        # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
        outputs = outputs.view(-1, outputs.size(2))  #flatten(0, 1)
        tgt_output = tgt_output.view(-1)  #flatten()

        loss = self.criterion(outputs, tgt_output)

        self.optimizer.zero_grad()
        loss.backward()
        # Clip gradient norm to 1 to stabilize training.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
        self.optimizer.step()
        self.scheduler.step()

        loss_item = loss.item()

        return loss_item
class Training:
    """Mixed-precision training driver for a detection model.

    Wraps one model + optimizer + scheduler for a single fold: runs AMP
    training epochs, validation epochs, checkpointing (last/best by
    validation loss), and early stopping.
    """

    def __init__(self, model, device, config, name, fold_num, imsize):
        """
        Args:
            model: wrapper whose forward(images, target_dict) returns a dict
                   with a 'loss' entry; underlying net at model.model.
            device: torch device to train on.
            config: object with lr, SchedulerClass, scheduler_params,
                    patience, n_epochs, val_scheduler attributes.
            name: experiment name used in checkpoint filenames.
            fold_num: fold index used in checkpoint filenames.
            imsize: square image size fed to the model.
        """
        self.config = config
        self.epoch = 0
        self.base_dir = './models/'
        os.makedirs('./models', exist_ok=True)
        self.model = model
        self.best_loss = 10**5
        self.device = device
        self.name = name
        self.fold_num = fold_num
        self.imsize = imsize

        # Optimizer: exclude biases and LayerNorm parameters from weight
        # decay, as is conventional for transformer-style backbones.
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.001
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay': 0.00
        }]
        # FIX: the grouped parameters were built but discarded — AdamW was
        # constructed from model.parameters(), so the no-decay split had no
        # effect. Pass the groups so the decay policy actually applies.
        self.optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr)
        self.scheduler = config.SchedulerClass(self.optimizer,
                                               **config.scheduler_params)
        # Early stopping patience (epochs without improvement).
        self.patience = config.patience
        # GradScaler for mixed-precision (autocast) training.
        self.scaler = GradScaler()

    def train_one_epoch(self, train_loader):
        """Run one AMP training epoch; returns the loss accumulator."""
        self.model.train()
        showloss = Showloss()
        for step, (images, targets) in tqdm(enumerate(train_loader),
                                            total=len(train_loader)):
            self.optimizer.zero_grad()
            with autocast():
                # Stack per-sample images into a batch (default dim=0):
                # [B, C, H, W].
                images = torch.stack(images)
                images = images.to(self.device).float()
                batch_size = images.shape[0]
                boxes = [
                    target['bbox'].to(self.device).float()
                    for target in targets
                ]
                labels = [
                    target['cls'].to(self.device).float()
                    for target in targets
                ]
                img_scale = torch.tensor([
                    target['img_scale'].to(self.device).float()
                    for target in targets
                ])
                img_size = torch.tensor([
                    (self.imsize, self.imsize) for target in targets
                ]).to(self.device).float()

                # Since the effdet API update, forward takes the image batch
                # plus a target dict.
                target_res = {}
                target_res['bbox'] = boxes
                target_res['cls'] = labels
                target_res['img_scale'] = img_scale
                target_res['img_size'] = img_size

                output = self.model(images, target_res)
                loss = output['loss']
                showloss.update(loss.detach().item(), batch_size)

            # Scaled backward + optimizer step for AMP.
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        return showloss

    def val_one_epoch(self, val_loader):
        """Run one validation epoch (no grads); returns the loss accumulator."""
        self.model.eval()
        showloss = Showloss()
        for step, (images, targets) in tqdm(enumerate(val_loader),
                                            total=len(val_loader)):
            with torch.no_grad():
                images = torch.stack(images)
                batch_size = images.shape[0]
                images = images.to(self.device).float()
                boxes = [
                    target['bbox'].to(self.device).float()
                    for target in targets
                ]
                labels = [
                    target['cls'].to(self.device).float()
                    for target in targets
                ]
                img_scale = torch.tensor([
                    target['img_scale'].to(self.device).float()
                    for target in targets
                ])
                img_size = torch.tensor([
                    (self.imsize, self.imsize) for target in targets
                ]).to(self.device).float()

                target_res = {}
                target_res['bbox'] = boxes
                target_res['cls'] = labels
                target_res['img_scale'] = img_scale
                target_res['img_size'] = img_size

                # loss, _, _ = self.model(images, boxes, labels)
                output = self.model(images, target_res)
                loss = output['loss']
                showloss.update(loss.detach().item(), batch_size)
        return showloss

    def save(self, path):
        """Save model/optimizer/scheduler state plus best val loss and epoch."""
        self.model.eval()
        torch.save(
            {
                'model_state_dict': self.model.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict(),
                'loss': self.best_loss,  # best validation loss so far
                'epoch': self.epoch,
            }, path)

    def load(self, path):
        """Restore state saved by save(); resumes at the following epoch."""
        checkpoint = torch.load(path)
        self.model.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # FIX: save() stores the best validation loss under the key 'loss';
        # reading 'best_loss' raised KeyError on every resume. Accept the
        # legacy key first for forward compatibility.
        self.best_loss = checkpoint.get('best_loss', checkpoint['loss'])
        self.epoch = checkpoint['epoch'] + 1

    def fit(self, train_loader, val_loader):
        """Full training run with checkpointing and early stopping."""
        early_stopping = EarlyStopping(self.patience)
        for epoch in range(self.config.n_epochs):
            print('{} / {} Epoch'.format(epoch, self.config.n_epochs))
            train_loss = self.train_one_epoch(train_loader)
            print('[Train] loss: {}'.format(train_loss.avg))
            # Always keep a rolling "last" checkpoint.
            self.save(self.base_dir +
                      '{}_{}_last.pt'.format(self.name, self.fold_num))
            val_loss = self.val_one_epoch(val_loader)
            print('[Valid] loss: {}'.format(val_loss.avg))
            if val_loss.avg < self.best_loss:
                self.best_loss = val_loss.avg
                self.save(self.base_dir +
                          '{}_{}_best.pt'.format(self.name, self.fold_num))

            # Early stopping on validation loss.
            early_stopping(val_loss.avg, self.best_loss)
            if early_stopping.early_stop:
                break

            if self.config.val_scheduler:
                # e.g. ReduceLROnPlateau keyed on validation loss.
                self.scheduler.step(metrics=val_loss.avg)

            self.epoch += 1
class TrainLoop:
    """Distributed (DDP) training loop with optional manual fp16 loss scaling,
    EMA parameter tracking, microbatched gradient accumulation, and
    checkpoint resume for a diffusion model.
    """

    def __init__(
        self,
        *,
        model,
        diffusion,
        data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
    ):
        self.model = model
        self.diffusion = diffusion
        self.data = data
        self.batch_size = batch_size
        # microbatch <= 0 means "no microbatching": one chunk per batch.
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        # ema_rate may be a single float or a comma-separated string of rates.
        self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")]
        )
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size * dist.get_world_size()

        # In fp16 mode, master_params become fp32 copies of the fp16 model
        # params; in fp32 mode they alias the model params directly.
        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
        self.sync_cuda = th.cuda.is_available()

        self._load_and_sync_parameters()
        if self.use_fp16:
            self._setup_fp16()

        self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
        if self.resume_step:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            self.ema_params = [
                copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
            ]

        if th.cuda.is_available():
            self.use_ddp = True
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            if dist.get_world_size() > 1:
                logger.warn(
                    "Distributed training requires CUDA. "
                    "Gradients will not be synchronized properly!"
                )
            self.use_ddp = False
            self.ddp_model = self.model

    def _load_and_sync_parameters(self):
        """Load model weights from a resume checkpoint (rank 0) and broadcast."""
        resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint

        if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            if dist.get_rank() == 0:
                logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
                self.model.load_state_dict(
                    dist_util.load_state_dict(
                        resume_checkpoint, map_location=dist_util.dev()
                    )
                )

        dist_util.sync_params(self.model.parameters())

    def _load_ema_parameters(self, rate):
        """Load EMA params for `rate` from checkpoint, else copy master params."""
        ema_params = copy.deepcopy(self.master_params)

        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
        if ema_checkpoint:
            if dist.get_rank() == 0:
                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
                state_dict = dist_util.load_state_dict(
                    ema_checkpoint, map_location=dist_util.dev()
                )
                ema_params = self._state_dict_to_master_params(state_dict)

        dist_util.sync_params(ema_params)
        return ema_params

    def _load_optimizer_state(self):
        """Restore optimizer state saved next to the resumed model checkpoint."""
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(
                opt_checkpoint, map_location=dist_util.dev()
            )
            self.opt.load_state_dict(state_dict)

    def _setup_fp16(self):
        """Make fp32 master copies of the params and convert the model to fp16."""
        self.master_params = make_master_params(self.model_params)
        self.model.convert_to_fp16()

    def run_loop(self):
        """Train until lr_anneal_steps is reached (or forever if it is 0)."""
        while (
            not self.lr_anneal_steps
            or self.step + self.resume_step < self.lr_anneal_steps
        ):
            batch, cond = next(self.data)
            self.run_step(batch, cond)
            if self.step % self.log_interval == 0:
                logger.dumpkvs()
            if self.step % self.save_interval == 0:
                self.save()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                    return
            self.step += 1
        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def run_step(self, batch, cond):
        """One training step: forward/backward, then fp16- or fp32-optimize."""
        self.forward_backward(batch, cond)
        if self.use_fp16:
            self.optimize_fp16()
        else:
            self.optimize_normal()
        self.log_step()

    def forward_backward(self, batch, cond):
        """Accumulate gradients over microbatches of (batch, cond)."""
        zero_grad(self.model_params)
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to(dist_util.dev())
            micro_cond = {
                k: v[i : i + self.microbatch].to(dist_util.dev())
                for k, v in cond.items()
            }
            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,
                t,
                model_kwargs=micro_cond,
            )

            # Only sync DDP gradients on the final microbatch; earlier
            # microbatches accumulate locally under no_sync().
            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()

            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )

            loss = (losses["loss"] * weights).mean()
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            if self.use_fp16:
                # Manual loss scaling: scale up before backward to keep fp16
                # gradients representable; undone in optimize_fp16().
                loss_scale = 2 ** self.lg_loss_scale
                (loss * loss_scale).backward()
            else:
                loss.backward()

    def optimize_fp16(self):
        """fp16 optimizer step with dynamic loss-scale adjustment.

        On non-finite gradients, halve the loss scale and skip the step;
        otherwise unscale, step, update EMAs, mirror master params back into
        the fp16 model, and slowly grow the scale.
        """
        if any(not th.isfinite(p.grad).all() for p in self.model_params):
            self.lg_loss_scale -= 1
            logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            return

        model_grads_to_master_grads(self.model_params, self.master_params)
        # master_params is a single flattened tensor in fp16 mode, so
        # unscaling its .grad unscales everything.
        self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)
        master_params_to_model_params(self.model_params, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth

    def optimize_normal(self):
        """Plain fp32 optimizer step plus EMA updates."""
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)

    def _log_grad_norm(self):
        """Log the global L2 norm of the master-parameter gradients."""
        sqsum = 0.0
        for p in self.master_params:
            sqsum += (p.grad ** 2).sum().item()
        logger.logkv_mean("grad_norm", np.sqrt(sqsum))

    def _anneal_lr(self):
        """Linearly anneal lr to 0 over lr_anneal_steps (no-op if unset)."""
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        """Record step/sample counters (and the loss scale in fp16 mode)."""
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
        if self.use_fp16:
            logger.logkv("lg_loss_scale", self.lg_loss_scale)

    def save(self):
        """Save the model, each EMA copy, and (rank 0) the optimizer state."""

        def save_checkpoint(rate, params):
            # rate 0/falsy means the live model; otherwise an EMA copy.
            state_dict = self._master_params_to_state_dict(params)
            if dist.get_rank() == 0:
                logger.log(f"saving model {rate}...")
                if not rate:
                    filename = f"model{(self.step+self.resume_step):06d}.pt"
                else:
                    filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                    th.save(state_dict, f)

        save_checkpoint(0, self.master_params)
        for rate, params in zip(self.ema_rate, self.ema_params):
            save_checkpoint(rate, params)

        if dist.get_rank() == 0:
            with bf.BlobFile(
                bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
                "wb",
            ) as f:
                th.save(self.opt.state_dict(), f)

        dist.barrier()

    def _master_params_to_state_dict(self, master_params):
        """Map master params (flattened in fp16 mode) back onto a state dict."""
        if self.use_fp16:
            master_params = unflatten_master_params(
                self.model.parameters(), master_params
            )
        state_dict = self.model.state_dict()
        for i, (name, _value) in enumerate(self.model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[i]
        return state_dict

    def _state_dict_to_master_params(self, state_dict):
        """Inverse of _master_params_to_state_dict (re-flattens in fp16 mode)."""
        params = [state_dict[name] for name, _ in self.model.named_parameters()]
        if self.use_fp16:
            return make_master_params(params)
        else:
            return params
def train_runner(model: nn.Module,
                 model_name: str,
                 results_dir: str,
                 experiment: str = '',
                 debug: bool = False,
                 img_size: int = IMG_SIZE,
                 learning_rate: float = 1e-2,
                 fold: int = 0,
                 checkpoint: str = '',
                 epochs: int = 15,
                 batch_size: int = 8,
                 num_workers: int = 4,
                 start_epoch: int = 0,
                 save_oof: bool = False,
                 save_train_oof: bool = False,
                 gpu_number: int = 1):
    """
    Model training runner

    Args:
        model         : PyTorch model
        model_name    : string name for model for checkpoints saving
        results_dir   : directory to save results
        experiment    : string name for naming experiments
        debug         : if True, runs the debugging on few images
        img_size      : size of images for training
        learning_rate : initial learning rate (default = 1e-2)
        fold          : training fold (default = 0)
        checkpoint    : path to a checkpoint to resume from ('' = fresh run)
        epochs        : number of epochs to train
        batch_size    : number of images in batch
        num_workers   : number of workers available
        start_epoch   : number of epoch to continue training
        save_oof      : saves oof validation predictions. Default = False
        save_train_oof: saves oof train predictions. Default = False
        gpu_number    : CUDA device index to train on
    """
    device = torch.device(
        f'cuda:{gpu_number}' if torch.cuda.is_available() else 'cpu')
    print(device)

    # Load model weights to continue training.
    # FIX: dropped the unused `moiu`/`loss` locals that were read from the
    # checkpoint and never used.
    if checkpoint != '':
        model, ckpt = load_model(model, checkpoint)
        start_epoch = ckpt['epoch'] + 1
        print('Loaded model from {}, epoch {}'.format(checkpoint, start_epoch))
    model.to(device)

    # Create directories for checkpoints, tensorboard and predictions.
    checkpoints_dir = f'{results_dir}rgb/checkpoints/{model_name}{experiment}'
    predictions_dir = f'{results_dir}rgb/oof/{model_name}{experiment}'
    validations_dir = f'{results_dir}rgb/oof_val/{model_name}{experiment}'
    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(predictions_dir, exist_ok=True)
    os.makedirs(validations_dir, exist_ok=True)
    print('\n', model_name, '\n')

    # Datasets for train and validation, split by fold.
    df = pd.read_csv(f'{TRAIN_DIR}folds.csv')
    df_train = df[df.fold != fold]
    df_val = df[df.fold == fold]
    print(f'Train images: {len(df_train.ImageId.values)}, '
          f'valid images {len(df_val.ImageId.values)}')

    train_dataset = RGBDataset(
        images_dir=TRAIN_RGB,
        masks_dir=TRAIN_MASKS,
        labels_df=df_train,
        img_size=img_size,
        transforms=TRANSFORMS["medium"],
        normalise=True,
        debug=debug,
    )
    valid_dataset = RGBDataset(
        images_dir=TRAIN_RGB,
        masks_dir=TRAIN_MASKS,
        labels_df=df_val,
        img_size=img_size,
        transforms=TRANSFORMS["d4"],
        normalise=True,
        debug=debug,
    )

    # Dataloaders for train and validation.
    dataloader_train = DataLoader(train_dataset,
                                  num_workers=num_workers,
                                  batch_size=batch_size,
                                  shuffle=True)
    dataloader_valid = DataLoader(valid_dataset,
                                  num_workers=num_workers,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  drop_last=True)
    print('{} training images, {} validation images'.format(
        len(train_dataset), len(valid_dataset)))

    # Optimizer and scheduler (plateau on validation mIoU).
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    #optimizer = RAdam(model.parameters(), lr=learning_rate)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               mode='max',
                                               patience=2,
                                               verbose=True,
                                               factor=0.2,
                                               min_lr=1e-6)
    # load optimizer state continue training
    #if checkpoint != '':
    #    optimizer = load_optim(optimizer, checkpoint, device)

    # Criteria: BCE for plain validation loss, BCE+Jaccard for training.
    criterion1 = nn.BCEWithLogitsLoss()
    criterion = BCEJaccardLoss(bce_weight=2,
                               jaccard_weight=0.5,
                               log_loss=False,
                               log_sigmoid=True)
    #criterion = JaccardLoss(log_sigmoid=True, log_loss=False)

    # Basic logging.
    report_batch = 200
    report_epoch = 20
    log_file = os.path.join(checkpoints_dir, f'fold_{fold}.log')
    logging.basicConfig(filename=log_file, filemode="w", level=logging.DEBUG)
    # FIX: the last line previously logged "checkpoint: {start_epoch}";
    # report the actual checkpoint path.
    logging.info(
        f'Parameters:\n model_name: {model_name}\n, results_dir: {results_dir}\n, experiment: {experiment}\n, img_size: {img_size}\n, \
        learning_rate: {learning_rate}\n, fold: {fold}\n, epochs: {epochs}\n, batch_size: {batch_size}\n, num_workers: {num_workers}\n, \
        start_epoch: {start_epoch}\n, save_oof: {save_oof}\n, optimizer: {optimizer}\n, scheduler: {scheduler} \n, checkpoint: {checkpoint} \n'
    )

    train_losses = []
    best_val_loss = 1e+5
    best_val_metric = 0

    def _save_checkpoint(filepath: str, epoch: int, epoch_losses,
                         valid_metrics) -> None:
        # Persist model & optimizer state together with the latest metrics.
        # (Extracted: this dict was duplicated verbatim three times.)
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'loss': np.mean(epoch_losses),
                'valid_loss': valid_metrics['val_loss'],
                'valid_miou': valid_metrics['miou'],
            }, filepath)

    # Training cycle.
    print("Start training")
    # NOTE(review): inclusive upper bound trains `epochs + 1` epochs —
    # kept as in the original; confirm whether this is intended.
    for epoch in range(start_epoch, start_epoch + epochs + 1):
        print("Epoch", epoch)
        epoch_losses = []
        progress_bar = tqdm(dataloader_train, total=len(dataloader_train))
        progress_bar.set_description('Epoch {}'.format(epoch))
        # (FIX: removed a no-op `with torch.set_grad_enabled(True):` wrapper —
        # gradients are already enabled during training.)
        for batch_num, (img, target, _) in enumerate(progress_bar):
            img = img.to(device)
            target = target.float().to(device)
            prediction = model(img)
            loss = criterion(prediction, target)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 3)
            optimizer.step()
            epoch_losses.append(loss.detach().cpu().numpy())
            if batch_num and batch_num % report_batch == 0:
                neptune.log_metric('Train loss', np.mean(epoch_losses))

        # Log epoch loss history.
        # FIX: the original referenced the loop variable `batch_num` here,
        # after the loop — a NameError if the dataloader is empty.
        epoch_loss = np.mean(epoch_losses)
        print("Epoch {}, Train Loss: {}".format(epoch, epoch_loss))
        train_losses.append(epoch_loss)
        neptune.log_metric('Train loss', epoch_loss)
        logging.info(f'epoch: {epoch}; loss: {epoch_loss} \n')

        # Validate model.
        val_loss = validate_loss(model, dataloader_valid, criterion1, epoch,
                                 validations_dir, device)
        valid_metrics = validate(model, dataloader_valid, criterion, epoch,
                                 validations_dir, save_oof, device)

        # Log validation metrics.
        neptune.log_metric('bce_loss_valid', val_loss)
        neptune.log_metric('loss_valid', valid_metrics['val_loss'])
        neptune.log_metric('miou_valid', valid_metrics['miou'])

        # Get current learning rate (single param group).
        lr = optimizer.param_groups[0]['lr']
        print(f'learning_rate: {lr}')
        logging.info(f'learning_rate: {lr}\n')
        neptune.log_metric('lr', lr)
        scheduler.step(valid_metrics['miou'])

        # Save the checkpoint with the best validation metric.
        if valid_metrics['miou'] > best_val_metric:
            best_val_metric = valid_metrics['miou']
            print(
                f"Saving model with the best val metric {valid_metrics['miou']}, epoch {epoch}"
            )
            _save_checkpoint(
                os.path.join(checkpoints_dir,
                             f"{model_name}_best_val_miou.pth"), epoch,
                epoch_losses, valid_metrics)

        # Save the checkpoint with the best validation loss, or a periodic one.
        if valid_metrics['val_loss'] < best_val_loss:
            best_val_loss = valid_metrics['val_loss']
            print(
                f"Saving model with the best val loss {valid_metrics['val_loss']}, epoch {epoch}"
            )
            _save_checkpoint(
                os.path.join(checkpoints_dir,
                             "{}_best_val_loss.pth".format(model_name)),
                epoch, epoch_losses, valid_metrics)
        elif epoch % report_epoch == 0:
            print(
                f"Saving model at epoch {epoch}, val loss {valid_metrics['val_loss']}"
            )
            _save_checkpoint(
                os.path.join(checkpoints_dir,
                             "{}_epoch_{}.pth".format(model_name, epoch)),
                epoch, epoch_losses, valid_metrics)
def run_training(opt):
    """Train a document-image classifier for one cross-validation fold.

    Builds train/validation dataloaders from ``opt.train_csv``, instantiates a
    ``Classifier``, then alternates ``train_one_epoch`` and validation for up
    to ``opt.epochs`` epochs using AdamW + cosine warm restarts with AMP grad
    scaling.  Saves ``last.pt`` every epoch and ``best.pt`` whenever the
    validation loss improves; stops early after ``opt.patience`` epochs
    without improvement.

    Args:
        opt: argparse-style namespace.  Fields used: work_dir, epochs,
            train_bs, valid_bs, weights, train_csv, data_dir, fold, img_size,
            num_workers, model_name, resume, ovr_val, patience, and (when
            ``ovr_val`` is false) handwritten_csv / printed_csv.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    work_dir, epochs, train_batch, valid_batch, weights = \
        opt.work_dir, opt.epochs, opt.train_bs, opt.valid_bs, opt.weights

    # Checkpoint paths (work_dir itself is created below, before first save).
    last = os.path.join(work_dir, 'last.pt')
    best = os.path.join(work_dir, 'best.pt')

    # --------------------------------------
    # Setup train and validation set
    # --------------------------------------
    data = pd.read_csv(opt.train_csv)
    images_path = opt.data_dir
    n_classes = 6  # fixed coding :V
    # categ maps the raw class label to a contiguous integer id.
    data['class'] = data.apply(lambda row: categ[row["class"]], axis=1)
    train_loader, val_loader = prepare_dataloader(data,
                                                  opt.fold,
                                                  train_batch,
                                                  valid_batch,
                                                  opt.img_size,
                                                  opt.num_workers,
                                                  data_root=images_path)

    # BUG FIX: the per-domain validation loaders were commented out, but the
    # not-ovr_val branch below still passed them to valid_one_epoch, raising
    # NameError.  Build them whenever split (handwritten/printed) validation
    # is requested.
    if not opt.ovr_val:
        handwritten_data = pd.read_csv(opt.handwritten_csv)
        printed_data = pd.read_csv(opt.printed_csv)
        handwritten_data['class'] = handwritten_data.apply(
            lambda row: categ[row["class"]], axis=1)
        printed_data['class'] = printed_data.apply(
            lambda row: categ[row["class"]], axis=1)
        _, handwritten_val_loader = prepare_dataloader(
            handwritten_data, opt.fold, train_batch, valid_batch,
            opt.img_size, opt.num_workers, data_root=images_path)
        _, printed_val_loader = prepare_dataloader(
            printed_data, opt.fold, train_batch, valid_batch,
            opt.img_size, opt.num_workers, data_root=images_path)

    # --------------------------------------
    # Models
    # --------------------------------------
    model = Classifier(model_name=opt.model_name,
                       n_classes=n_classes,
                       pretrained=True).to(device)
    if weights is not None:
        # Warm-start from a user-supplied checkpoint.
        cp = torch.load(weights)
        model.load_state_dict(cp['model'])

    # -------------------------------------------
    # Setup optimizer, scheduler, criterion loss
    # -------------------------------------------
    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)
    scheduler = CosineAnnealingWarmRestarts(optimizer,
                                            T_0=10,
                                            T_mult=1,
                                            eta_min=1e-6,
                                            last_epoch=-1)
    scaler = GradScaler()
    loss_tr = nn.CrossEntropyLoss().to(device)
    loss_fn = nn.CrossEntropyLoss().to(device)

    # --------------------------------------
    # Setup training
    # --------------------------------------
    os.makedirs(work_dir, exist_ok=True)  # replaces exists() == False + mkdir
    best_loss = 1e5
    start_epoch = 0
    best_epoch = 0  # for early stopping
    if opt.resume:
        checkpoint = torch.load(last)
        # 'epoch' is saved at the END of an epoch, so resume at the next one
        # (the original restarted the already-completed epoch).
        start_epoch = checkpoint["epoch"] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint["scheduler"])
        best_loss = checkpoint["best_loss"]
        # Reset the patience reference point; with best_epoch left at 0 the
        # early-stop test below could fire immediately after resuming.
        best_epoch = start_epoch

    def _state():
        # Snapshot everything needed to resume training.
        return {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'best_loss': best_loss
        }

    # --------------------------------------
    # Start training
    # --------------------------------------
    print("[INFO] Start training...")
    for epoch in range(start_epoch, epochs):
        train_one_epoch(epoch,
                        model,
                        loss_tr,
                        optimizer,
                        train_loader,
                        device,
                        scheduler=scheduler,
                        scaler=scaler)
        with torch.no_grad():
            if opt.ovr_val:
                val_loss = valid_one_epoch_overall(epoch, model, loss_fn,
                                                   val_loader, device,
                                                   scheduler=None)
            else:
                val_loss = valid_one_epoch(epoch, model, loss_fn,
                                           handwritten_val_loader,
                                           printed_val_loader, device,
                                           scheduler=None)
        if val_loss < best_loss:
            best_loss = val_loss
            best_epoch = epoch
            torch.save(_state(), best)  # os.path.join(best) was a no-op
            print('best model found for epoch {}'.format(epoch + 1))
        torch.save(_state(), last)
        if epoch - best_epoch > opt.patience:
            print("Early stop achieved at", epoch + 1)
            break

    del model, optimizer, train_loader, val_loader, scheduler, scaler
    torch.cuda.empty_cache()
# ---------------------------------------------------------------------------
# NOTE(review): end-of-training fragment — total_train_loss, avg_val_loss,
# avg_val_accuracy, avg_bleuscore, validation_time, epoch_i, training_stats,
# model, optimizer, train_dataloader, format_time and t0 are all defined
# earlier in the file; the append/save section presumably sits inside the
# epoch loop and the final prints after it.
# ---------------------------------------------------------------------------

# Mean training loss per *training* batch.
# BUG FIX: the original divided by len(validation_dataloader); the total was
# accumulated over the training loader (whose batch count is also saved as
# 'step' below), so divide by len(train_dataloader).
avg_train_loss = total_train_loss / len(train_dataloader)

print(" Validation Loss: {0:.2f}".format(avg_val_loss))
print(" Validation took: {:}".format(validation_time))

# Per-epoch summary record.
training_stats.append({
    'Avg Accuracy': avg_val_accuracy,
    'Bleu Score': avg_bleuscore,
    'Training Loss': avg_train_loss,
    'Valid. Loss': avg_val_loss,
    'Validation Time': validation_time
})

# Checkpoint everything needed to resume.
# NOTE(review): 'epoch' is stored as epoch_i + 4 while the filename uses
# epoch_i + 1 — presumably a manual offset from an earlier resumed run;
# confirm the two are meant to differ.
torch.save(
    {
        'epoch': epoch_i + 4,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'total_train_loss': total_train_loss,
        'step': len(train_dataloader),
        'training_stats': training_stats
    },
    "/global/cscratch1/sd/ajaybati/model_ckptDS" + str(epoch_i + 1) + ".pickle")

print(training_stats)
print("")
print("Total training took {:} (h:mm:ss)".format(format_time(time.time() - t0)))
print("done completely")