import os
import time

import numpy as np
import torch
from torchvision.utils import make_grid

# TensorboardWriter and MetricTracker are project-level helpers (tensorboard
# logging and metric bookkeeping); import them from wherever they are defined
# in this code base.


class Learning(object):
    def __init__(self, model, criterion, optimizer, scheduler, metric_ftns, device,
                 num_epoch, grad_clipping, grad_accumulation_steps, early_stopping,
                 validation_frequency, tensorboard, checkpoint_dir, resume_path):
        self.device, device_ids = self._prepare_device(device)
        # move the model to the target device before resuming or wrapping it
        self.model = model.to(self.device)

        self.start_epoch = 1
        self.best_epoch = 1
        self.best_score = 0
        if resume_path is not None:
            self._resume_checkpoint(resume_path)

        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
            # cudnn.benchmark = True

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer
        self.num_epoch = num_epoch
        self.scheduler = scheduler
        self.grad_clipping = grad_clipping
        self.grad_accumulation_steps = grad_accumulation_steps
        self.early_stopping = early_stopping
        self.validation_frequency = validation_frequency
        self.checkpoint_dir = checkpoint_dir

        self.writer = TensorboardWriter(os.path.join(checkpoint_dir, 'tensorboard'), tensorboard)
        self.train_metrics = MetricTracker('loss', writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer
        )

    def train(self, train_dataloader):
        score = 0
        for epoch in range(self.start_epoch, self.num_epoch + 1):
            print("{} epoch: \t start training....".format(epoch))
            start = time.time()
            train_result = self._train_epoch(epoch, train_dataloader)
            train_result.update({'time': time.time() - start})

            for key, value in train_result.items():
                print('    {:15s}: {}'.format(str(key), value))

            # if (epoch + 1) % self.validation_frequency != 0:
            #     print("skip validation....")
            #     continue
            # print('{} epoch: \t start validation....'.format(epoch))
            # start = time.time()
            # valid_result = self._valid_epoch(epoch, valid_dataloader)
            # valid_result.update({'time': time.time() - start})
            # for key, value in valid_result.items():
            #     if 'score' in key:
            #         score = value
            #     print('    {:15s}: {}'.format(str(key), value))

            score += 0.001
            self.post_processing(score, epoch)

            if epoch - self.best_epoch > self.early_stopping:
                print('WARNING: EARLY STOPPING')
                break

    def _train_epoch(self, epoch, data_loader):
        self.model.train()
        self.optimizer.zero_grad()
        self.train_metrics.reset()

        for idx, (data, target) in enumerate(data_loader):
            data = data.to(self.device)
            target = [ann.to(self.device) for ann in target]

            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()

            self.writer.set_step((epoch - 1) * len(data_loader) + idx)
            self.train_metrics.update('loss', loss.item())

            # accumulate gradients over several batches, then clip and step
            if (idx + 1) % self.grad_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clipping)
                self.optimizer.step()
                self.optimizer.zero_grad()

            if (idx + 1) % int(np.sqrt(len(data_loader))) == 0:
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

        return self.train_metrics.result()

    def _valid_epoch(self, epoch, data_loader):
        self.valid_metrics.reset()
        self.model.eval()

        with torch.no_grad():
            for idx, (data, target) in enumerate(data_loader):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                self.writer.set_step((epoch - 1) * len(data_loader) + idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(output, target))
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

        # add histograms of model parameters to tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        return self.valid_metrics.result()

    def post_processing(self, score, epoch):
        best = False
        if score > self.best_score:
            self.best_score = score
            self.best_epoch = epoch
            best = True
            print("best model: {} epoch - {:.5}".format(epoch, score))

        self._save_checkpoint(epoch=epoch, save_best=best)

        if self.scheduler.__class__.__name__ == 'ReduceLROnPlateau':
            self.scheduler.step(score)
        else:
            self.scheduler.step()

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, also save the checkpoint as 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.get_state_dict(self.model),
            'best_score': self.best_score
        }
        filename = os.path.join(self.checkpoint_dir, 'checkpoint_epoch{}.pth'.format(epoch))
        torch.save(state, filename)
        print("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = os.path.join(self.checkpoint_dir, 'model_best.pth')
            torch.save(state, best_path)
            print("Saving current best: model_best.pth ...")

    @staticmethod
    def get_state_dict(model):
        # unwrap DataParallel so the checkpoint can be loaded without it
        if isinstance(model, torch.nn.DataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        return state_dict

    def _resume_checkpoint(self, resume_path):
        resume_path = str(resume_path)
        print("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path, map_location=lambda storage, loc: storage)
        self.start_epoch = checkpoint['epoch'] + 1
        self.best_epoch = checkpoint['epoch']
        self.best_score = checkpoint['best_score']
        self.model.load_state_dict(checkpoint['state_dict'])
        print("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))

    @staticmethod
    def _prepare_device(device):
        n_gpu_use = len(device)
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            print("Warning: there's no GPU available on this machine, "
                  "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            print("Warning: the number of GPUs configured to use is {}, "
                  "but only {} are available on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        list_ids = device
        device = torch.device('cuda:{}'.format(device[0]) if n_gpu_use > 0 else 'cpu')
        return device, list_ids
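# ---------------------------------------------------------------------------
# Standalone sketch (not part of the pipeline above) of the gradient
# accumulation + clipping pattern used in Learning._train_epoch: gradients
# are summed over `grad_accumulation_steps` batches, clipped by norm, and
# only then applied. The toy model and random data below are made up purely
# for illustration.
# ---------------------------------------------------------------------------
import torch


def accumulation_demo(grad_accumulation_steps=4, grad_clipping=1.0, num_batches=16):
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = torch.nn.CrossEntropyLoss()

    optimizer.zero_grad()
    for idx in range(num_batches):
        data = torch.randn(4, 8)
        target = torch.randint(0, 2, (4,))
        loss = criterion(model(data), target)
        loss.backward()  # gradients accumulate across batches
        if (idx + 1) % grad_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clipping)
            optimizer.step()  # one parameter update per accumulation window
            optimizer.zero_grad()


accumulation_demo()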
class Trainer:
    """
    Training pipeline

    Parameters
    ----------
    num_epochs : int
        We should train the model for __ epochs
    start_epoch : int
        We should start training the model from __th epoch
    train_loader : DataLoader
        DataLoader for training data
    model : nn.Module
        Model
    model_name : str
        Name of the model
    loss_function : nn.Module
        Loss function (cross entropy)
    optimizer : optim.Optimizer
        Optimizer (Adam)
    lr_decay : float
        A factor in interval (0, 1) to multiply the learning rate with
    dataset_name : str
        Name of the dataset
    word_map : Dict[str, int]
        Word2id map
    grad_clip : float, optional
        Gradient threshold in clip gradients
    print_freq : int
        Print training status every __ batches
    checkpoint_path : str, optional
        Path to the folder to save checkpoints
    checkpoint_basename : str, optional, default='checkpoint'
        Basename of the checkpoint
    tensorboard : bool, optional, default=False
        Enable tensorboard or not?
    log_dir : str, optional
        Path to the folder to save logs for tensorboard
    """
    def __init__(
        self,
        num_epochs: int,
        start_epoch: int,
        train_loader: DataLoader,
        model: nn.Module,
        model_name: str,
        loss_function: nn.Module,
        optimizer,
        lr_decay: float,
        dataset_name: str,
        word_map: Dict[str, int],
        grad_clip: Optional[float] = None,
        print_freq: int = 100,
        checkpoint_path: Optional[str] = None,
        checkpoint_basename: str = 'checkpoint',
        tensorboard: bool = False,
        log_dir: Optional[str] = None
    ) -> None:
        self.num_epochs = num_epochs
        self.start_epoch = start_epoch
        self.train_loader = train_loader

        self.model = model
        self.model_name = model_name
        self.loss_function = loss_function
        self.optimizer = optimizer
        self.lr_decay = lr_decay

        self.dataset_name = dataset_name
        self.word_map = word_map
        self.print_freq = print_freq
        self.grad_clip = grad_clip

        self.checkpoint_path = checkpoint_path
        self.checkpoint_basename = checkpoint_basename

        # setup visualization writer instance
        self.writer = TensorboardWriter(log_dir, tensorboard)
        self.len_epoch = len(self.train_loader)

    def train(self, epoch: int) -> None:
        """
        Train an epoch

        Parameters
        ----------
        epoch : int
            Current number of epoch
        """
        self.model.train()  # training mode enables dropout

        batch_time = AverageMeter()  # forward prop. + back prop. time per batch
        data_time = AverageMeter()  # data loading time per batch
        losses = AverageMeter(tag='loss', writer=self.writer)  # cross entropy loss
        accs = AverageMeter(tag='acc', writer=self.writer)  # accuracies

        start = time.time()

        # batches
        for i, batch in enumerate(self.train_loader):
            data_time.update(time.time() - start)

            # `device` is expected to be defined at module level (GPU / CPU)
            if self.model_name in ['han']:
                documents, sentences_per_document, words_per_sentence, labels = batch

                documents = documents.to(device)  # (batch_size, sentence_limit, word_limit)
                sentences_per_document = sentences_per_document.squeeze(1).to(device)  # (batch_size)
                words_per_sentence = words_per_sentence.to(device)  # (batch_size, sentence_limit)
                labels = labels.squeeze(1).to(device)  # (batch_size)

                # forward
                scores, _, _ = self.model(
                    documents, sentences_per_document, words_per_sentence
                )  # (n_documents, n_classes), (n_documents, max_doc_len_in_batch, max_sent_len_in_batch), (n_documents, max_doc_len_in_batch)
            else:
                sentences, words_per_sentence, labels = batch

                sentences = sentences.to(device)  # (batch_size, word_limit)
                words_per_sentence = words_per_sentence.squeeze(1).to(device)  # (batch_size)
                labels = labels.squeeze(1).to(device)  # (batch_size)

                # for torchtext
                # sentences = batch.text[0].to(device)  # (batch_size, word_limit)
                # words_per_sentence = batch.text[1].to(device)  # (batch_size)
                # labels = batch.label.to(device)  # (batch_size)

                scores = self.model(sentences, words_per_sentence)  # (batch_size, n_classes)

            # calc loss
            loss = self.loss_function(scores, labels)  # scalar

            # backward
            self.optimizer.zero_grad()
            loss.backward()

            # clip gradients
            if self.grad_clip is not None:
                clip_gradient(self.optimizer, self.grad_clip)

            # update weights
            self.optimizer.step()

            # find accuracy
            _, predictions = scores.max(dim=1)  # (n_documents)
            correct_predictions = torch.eq(predictions, labels).sum().item()
            accuracy = correct_predictions / labels.size(0)

            # set step for tensorboard
            step = (epoch - 1) * self.len_epoch + i
            self.writer.set_step(step=step, mode='train')

            # keep track of metrics
            batch_time.update(time.time() - start)
            losses.update(loss.item(), labels.size(0))
            accs.update(accuracy, labels.size(0))

            start = time.time()

            # print training status
            if i % self.print_freq == 0:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
                        epoch, i, len(self.train_loader),
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses,
                        acc=accs
                    )
                )

    def run_train(self):
        start = time.time()

        # epochs
        for epoch in range(self.start_epoch, self.num_epochs):
            # train an epoch
            self.train(epoch=epoch)

            # time per epoch
            epoch_time = time.time() - start
            print('Epoch: [{0}] finished, time consumed: {epoch_time:.3f}'.format(
                epoch, epoch_time=epoch_time))

            # decay learning rate every epoch
            adjust_learning_rate(self.optimizer, self.lr_decay)

            # save checkpoint
            if self.checkpoint_path is not None:
                save_checkpoint(
                    epoch=epoch,
                    model=self.model,
                    model_name=self.model_name,
                    optimizer=self.optimizer,
                    dataset_name=self.dataset_name,
                    word_map=self.word_map,
                    checkpoint_path=self.checkpoint_path,
                    checkpoint_basename=self.checkpoint_basename
                )

            start = time.time()
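# ---------------------------------------------------------------------------
# Minimal sketch of an averaging meter consistent with how AverageMeter is
# used in the Trainer above (`update(value, n)`, `.val`, `.avg`). The
# `tag`/`writer` parameters mirror the calls above but are only stored here;
# the project's real AverageMeter presumably also logs values to tensorboard.
# This is an illustration, not the project's implementation.
# ---------------------------------------------------------------------------
class SimpleAverageMeter:
    """Tracks the latest value and the running average of a metric."""

    def __init__(self, tag=None, writer=None):
        self.tag = tag
        self.writer = writer
        self.val = 0.0    # most recent value
        self.sum = 0.0    # weighted sum of all values
        self.count = 0    # total weight (e.g. number of samples)
        self.avg = 0.0    # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        # the real AverageMeter presumably also writes `val` to tensorboard
        # through `self.writer` under `self.tag`; omitted in this sketch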
class Trainer:
    """
    Encoder-decoder pipeline. Teacher forcing is used during training and validation.

    Parameters
    ----------
    caption_model : str
        Type of the caption model
    epochs : int
        We should train the model for __ epochs
    device : torch.device
        Use GPU or not
    word_map : Dict[str, int]
        Word2id map
    rev_word_map : Dict[int, str]
        Id2word map
    start_epoch : int
        We should start training the model from __th epoch
    epochs_since_improvement : int
        Number of epochs since last improvement in BLEU-4 score
    best_bleu4 : float
        Best BLEU-4 score until now
    train_loader : DataLoader
        DataLoader for training data
    val_loader : DataLoader
        DataLoader for validation data
    encoder : nn.Module
        Encoder (based on CNN)
    decoder : nn.Module
        Decoder (based on LSTM)
    encoder_optimizer : optim.Optimizer
        Optimizer for encoder (Adam) (if fine-tune)
    decoder_optimizer : optim.Optimizer
        Optimizer for decoder (Adam)
    loss_function : nn.Module
        Loss function (cross entropy)
    grad_clip : float
        Gradient threshold in clip gradients
    tau : float
        Penalty term τ for doubly stochastic attention in paper: show, attend and tell
    fine_tune_encoder : bool
        Fine-tune encoder or not
    tensorboard : bool, optional, default=False
        Enable tensorboard or not?
    log_dir : str, optional
        Path to the folder to save logs for tensorboard
    """
    def __init__(
        self,
        caption_model: str,
        epochs: int,
        device: torch.device,
        word_map: Dict[str, int],
        rev_word_map: Dict[int, str],
        start_epoch: int,
        epochs_since_improvement: int,
        best_bleu4: float,
        train_loader: DataLoader,
        val_loader: DataLoader,
        encoder: nn.Module,
        decoder: nn.Module,
        encoder_optimizer: optim.Optimizer,
        decoder_optimizer: optim.Optimizer,
        loss_function: nn.Module,
        grad_clip: float,
        tau: float,
        fine_tune_encoder: bool,
        tensorboard: bool = False,
        log_dir: Optional[str] = None
    ) -> None:
        self.device = device  # GPU / CPU

        self.caption_model = caption_model
        self.epochs = epochs
        self.word_map = word_map
        self.rev_word_map = rev_word_map

        self.start_epoch = start_epoch
        self.epochs_since_improvement = epochs_since_improvement
        self.best_bleu4 = best_bleu4

        self.train_loader = train_loader
        self.val_loader = val_loader
        self.encoder = encoder
        self.decoder = decoder
        self.encoder_optimizer = encoder_optimizer
        self.decoder_optimizer = decoder_optimizer
        self.loss_function = loss_function

        self.tau = tau
        self.grad_clip = grad_clip
        self.fine_tune_encoder = fine_tune_encoder

        self.print_freq = 100  # print training/validation stats every __ batches

        # setup visualization writer instance
        self.writer = TensorboardWriter(log_dir, tensorboard)
        self.len_epoch = len(self.train_loader)

    def train(self, epoch: int) -> None:
        """
        Train an epoch

        Parameters
        ----------
        epoch : int
            Current number of epoch
        """
        self.decoder.train()  # train mode (dropout and batchnorm is used)
        self.encoder.train()

        batch_time = AverageMeter()  # forward prop. + back prop. time
        data_time = AverageMeter()  # data loading time
        losses = AverageMeter(tag='loss', writer=self.writer)  # loss (per word decoded)
        top5accs = AverageMeter(tag='top5acc', writer=self.writer)  # top5 accuracy

        start = time.time()

        # batches
        for i, (imgs, caps, caplens) in enumerate(self.train_loader):
            data_time.update(time.time() - start)

            # move to GPU, if available
            imgs = imgs.to(self.device)
            caps = caps.to(self.device)
            caplens = caplens.to(self.device)

            # forward encoder
            imgs = self.encoder(imgs)

            # forward decoder
            if self.caption_model == 'att2all':
                scores, caps_sorted, decode_lengths, alphas, sort_ind = self.decoder(imgs, caps, caplens)
            else:
                scores, caps_sorted, decode_lengths, sort_ind = self.decoder(imgs, caps, caplens)

            # since we decoded starting with <start>, the targets are all words after <start>, up to <end>
            targets = caps_sorted[:, 1:]

            # remove timesteps that we didn't decode at, or are pads
            # pack_padded_sequence is an easy trick to do this
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)[0]
            targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)[0]

            # calc loss
            loss = self.loss_function(scores, targets)

            # doubly stochastic attention regularization (in paper: show, attend and tell)
            if self.caption_model == 'att2all':
                loss += self.tau * ((1. - alphas.sum(dim=1)) ** 2).mean()

            # clear gradient of last batch
            self.decoder_optimizer.zero_grad()
            if self.encoder_optimizer is not None:
                self.encoder_optimizer.zero_grad()

            # backward
            loss.backward()

            # clip gradients
            if self.grad_clip is not None:
                clip_gradient(self.decoder_optimizer, self.grad_clip)
                if self.encoder_optimizer is not None:
                    clip_gradient(self.encoder_optimizer, self.grad_clip)

            # update weights
            self.decoder_optimizer.step()
            if self.encoder_optimizer is not None:
                self.encoder_optimizer.step()

            # set step for tensorboard
            step = (epoch - 1) * self.len_epoch + i
            self.writer.set_step(step=step, mode='train')

            # keep track of metrics
            top5 = accuracy(scores, targets, 5)
            losses.update(loss.item(), sum(decode_lengths))
            top5accs.update(top5, sum(decode_lengths))
            batch_time.update(time.time() - start)

            start = time.time()

            # print status
            if i % self.print_freq == 0:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(
                        epoch, i, len(self.train_loader),
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses,
                        top5=top5accs
                    )
                )

    def validate(self) -> float:
        """
        Validate an epoch.

        Returns
        -------
        bleu4 : float
            BLEU-4 score
        """
        self.decoder.eval()  # eval mode (no dropout or batchnorm)
        if self.encoder is not None:
            self.encoder.eval()

        batch_time = AverageMeter()
        losses = AverageMeter()
        top5accs = AverageMeter()

        start = time.time()

        ground_truth = list()  # ground truth (true captions) for calculating BLEU-4 score
        prediction = list()  # prediction (predicted captions)

        # explicitly disable gradient calculation to avoid CUDA memory error
        # solves the issue #57
        with torch.no_grad():
            # batches
            for i, (imgs, caps, caplens, allcaps) in enumerate(self.val_loader):
                # move to device, if available
                imgs = imgs.to(self.device)
                caps = caps.to(self.device)
                caplens = caplens.to(self.device)

                # forward encoder
                if self.encoder is not None:
                    imgs = self.encoder(imgs)

                # forward decoder
                if self.caption_model == 'att2all':
                    scores, caps_sorted, decode_lengths, alphas, sort_ind = self.decoder(imgs, caps, caplens)
                else:
                    scores, caps_sorted, decode_lengths, sort_ind = self.decoder(imgs, caps, caplens)

                # since we decoded starting with <start>, the targets are all words after <start>, up to <end>
                targets = caps_sorted[:, 1:]

                # remove timesteps that we didn't decode at, or are pads
                # pack_padded_sequence is an easy trick to do this
                scores_copy = scores.clone()
                scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)[0]
                targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)[0]

                # calc loss
                loss = self.loss_function(scores, targets)

                # doubly stochastic attention regularization (in paper: show, attend and tell)
                if self.caption_model == 'att2all':
                    loss += self.tau * ((1. - alphas.sum(dim=1)) ** 2).mean()

                # keep track of metrics
                losses.update(loss.item(), sum(decode_lengths))
                top5 = accuracy(scores, targets, 5)
                top5accs.update(top5, sum(decode_lengths))
                batch_time.update(time.time() - start)

                start = time.time()

                if i % self.print_freq == 0:
                    print(
                        'Validation: [{0}/{1}]\t'
                        'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(
                            i, len(self.val_loader),
                            batch_time=batch_time,
                            loss=losses,
                            top5=top5accs
                        )
                    )

                # store ground truth captions and predicted captions of each image
                # for n images, each of them has one prediction and multiple ground truths (a, b, c...):
                # prediction = [ [hyp1], [hyp2], ..., [hypn] ]
                # ground_truth = [ [ [ref1a], [ref1b], [ref1c] ], ..., [ [refna], [refnb] ] ]

                # ground truth
                allcaps = allcaps[sort_ind]  # because images were sorted in the decoder
                for j in range(allcaps.shape[0]):
                    img_caps = allcaps[j].tolist()
                    img_captions = list(
                        map(
                            lambda c: [w for w in c if w not in {self.word_map['<start>'], self.word_map['<pad>']}],
                            img_caps
                        )
                    )  # remove <start> and pads
                    ground_truth.append(img_captions)

                # prediction
                _, preds = torch.max(scores_copy, dim=2)
                preds = preds.tolist()
                temp_preds = list()
                for j, p in enumerate(preds):
                    temp_preds.append(preds[j][:decode_lengths[j]])  # remove pads
                preds = temp_preds
                prediction.extend(preds)

                assert len(ground_truth) == len(prediction)

        # calc BLEU-4 and CIDEr score
        metrics = Metrics(ground_truth, prediction, self.rev_word_map)
        bleu4 = metrics.belu[3]  # BLEU-4
        cider = metrics.cider  # CIDEr

        print(
            '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}, CIDEr - {cider}\n'.format(
                loss=losses,
                top5=top5accs,
                bleu=bleu4,
                cider=cider
            )
        )

        return bleu4

    def run_train(self) -> None:
        # epochs
        for epoch in range(self.start_epoch, self.epochs):
            # decay learning rate if there is no improvement for 8 consecutive epochs
            # terminate training if there is no improvement for 20 consecutive epochs
            if self.epochs_since_improvement == 20:
                break
            if self.epochs_since_improvement > 0 and self.epochs_since_improvement % 8 == 0:
                adjust_learning_rate(self.decoder_optimizer, 0.8)
                if self.fine_tune_encoder:
                    adjust_learning_rate(self.encoder_optimizer, 0.8)

            # train an epoch
            self.train(epoch=epoch)

            # validate an epoch
            recent_bleu4 = self.validate()

            # epochs num since last improvement
            is_best = recent_bleu4 > self.best_bleu4
            self.best_bleu4 = max(recent_bleu4, self.best_bleu4)
            if not is_best:
                self.epochs_since_improvement += 1
                print("\nEpochs since last improvement: %d\n" % (self.epochs_since_improvement,))
            else:
                self.epochs_since_improvement = 0

            # save checkpoint
            save_checkpoint(
                epoch=epoch,
                epochs_since_improvement=self.epochs_since_improvement,
                encoder=self.encoder,
                decoder=self.decoder,
                encoder_optimizer=self.encoder_optimizer,
                decoder_optimizer=self.decoder_optimizer,
                caption_model=self.caption_model,
                bleu4=recent_bleu4,
                is_best=is_best
            )
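# ---------------------------------------------------------------------------
# Standalone sketch of the pack_padded_sequence trick used in train() and
# validate() above: scores and targets are flattened so that padded timesteps
# are dropped before computing the cross-entropy loss. The shapes and toy
# tensors below are made up for illustration.
# ---------------------------------------------------------------------------
import torch
from torch.nn.utils.rnn import pack_padded_sequence


def pack_trick_demo():
    batch_size, max_len, vocab_size = 3, 5, 10
    scores = torch.randn(batch_size, max_len, vocab_size)      # decoder outputs
    targets = torch.randint(0, vocab_size, (batch_size, max_len))
    decode_lengths = [5, 3, 2]                                  # true lengths, sorted descending

    # keep only the first decode_lengths[i] timesteps of each sequence
    packed_scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
    packed_targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

    loss = torch.nn.functional.cross_entropy(packed_scores, packed_targets)
    # packed shapes: (sum(decode_lengths), vocab_size) and (sum(decode_lengths),)
    print(packed_scores.shape, packed_targets.shape, loss.item())


pack_trick_demo()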