class Trainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
                 data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.device = device
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.best_valid = 0

        self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.file_metrics = str(self.checkpoint_dir / 'metrics.csv')

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, target))

            if batch_idx % self.log_step == 0:
                # self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                #     epoch, self._progress(batch_idx), loss.item()))
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        # Re-run the last training batch and log its loss to tensorboard.
        output = self.model(data)
        loss = self.criterion(output, target)
        self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
        self.writer.add_scalar('Loss', loss)

        self._save_csv(epoch, log)
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(output, target))
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        # we added some custom logging here: re-run the last validation batch and log its loss
        output = self.model(data)
        loss = self.criterion(output, target)
        self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
        self.writer.add_scalar('Loss', loss)

        val_log = self.valid_metrics.result()
        actual_accu = val_log['accuracy']
        if actual_accu - self.best_valid > 0.0025 and self.save:
            self.best_valid = actual_accu
            if self.tensorboard:  # if this is true you can use tensorboard
                self._save_checkpoint(epoch, save_best=True)
            filename = str(self.checkpoint_dir / 'checkpoint-best-epoch.pth')
            torch.save(self.model.state_dict(), filename)
            self.logger.info("Saving checkpoint: {} ...".format(filename))

        return val_log

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _save_csv(self, epoch, log):
        """
        Append the epoch's metrics to a CSV file

        :param epoch: current epoch number
        :param log: logging information of the epoch
        """
        with open(self.file_metrics, "a") as fichier:
            if epoch == 1:
                fichier.write("epoch,")
                for key in log:
                    fichier.write(str(key) + ",")
                fichier.write("\n")
            fichier.write(str(epoch) + ",")
            for key in log:
                fichier.write(str(log[key]) + ",")
            fichier.write("\n")
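# Note: the epoch-based trainers in this section fall back to an `inf_loop` helper for
# iteration-based training. Its definition is not part of this section; the sketch below
# only illustrates the assumed behaviour (endlessly cycling a DataLoader), following the
# common pytorch-template idiom.
from itertools import repeat

def inf_loop(data_loader):
    """Wrap a DataLoader so iteration never stops at the end of the dataset."""
    for loader in repeat(data_loader):
        yield from loader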
class Seq2SeqSimpleTrainer(BaseTrainer):
    """
    Trainer for a simple seq2seq model.
    """
    def __init__(self, model, criterion, train_metric_ftns, eval_metric_fns, optimizer, config, device,
                 data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None, validate_only=False):
        """
        :param model: The model to train.
        :param criterion: The loss criterion to use.
        :param train_metric_ftns: The metric function names to use for training.
        :param eval_metric_fns: The metric function names to use for evaluating.
        :param optimizer: The optimizer to use.
        :param config: The configuration file for the run.
        :param device: The device to train on.
        :param data_loader: The training data loader to use.
        :param valid_data_loader: The validation data loader to use.
        :param lr_scheduler: scheduler for the learning rate.
        :param len_epoch: The amount of examples in an epoch.
        :param validate_only: use if resumed, only run validation on the last resumed checkpoint.
        """
        self.vocab = model.vocab
        self.pad_idx = self.vocab['<pad>']
        self.criterion = criterion
        super().__init__(model, self.criterion, train_metric_ftns, eval_metric_fns, optimizer, config, device,
                         data_loader, valid_data_loader, lr_scheduler)

        self.question_pad_length = config['data_loader']['question_pad_length']
        self.qdmr_pad_length = config['data_loader']['qdmr_pad_length']
        self.lexicon_pad_length = config['data_loader']['lexicon_pad_length']

        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.do_validation = self.valid_data_loader is not None
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.train_metric_ftns], writer=self.writer)

        # Define evaluator.
        self.evaluator = Seq2SeqSimpleTester(self.model, self.criterion, self.eval_metric_ftns, self.config,
                                             self.device, self.valid_data_loader, True)

        # Run validation and exit.
        if validate_only:
            val_log = self.evaluator.test()
            log = {'val_' + k: round(v, 5) for k, v in val_log.items()}
            print(log)
            exit()

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        # Sets the model to training mode.
        self.model.train()
        self.train_metrics.reset()
        convert_to_program = self.data_loader.gold_type_is_qdmr()

        with tqdm(total=len(self.data_loader)) as progbar:
            for batch_idx, (_, data, target, lexicon_str) in enumerate(self.data_loader):
                data, mask_data = batch_to_tensor(self.vocab, data, self.question_pad_length, self.device)
                target, mask_target = batch_to_tensor(self.vocab, target, self.qdmr_pad_length, self.device)
                lexicon_ids, mask_lexicon = tokenize_lexicon_str(self.vocab, lexicon_str,
                                                                 self.qdmr_pad_length, self.device)

                # Run the model on the batch
                self.optimizer.zero_grad()
                # out shape is (batch_size, seq_len, output_size)
                output, mask_output = self.model(data, target, lexicon_ids)
                # CEloss expects (minibatch, classes, seq_len)
                # out after transpose is (batch_size, output_size, seq_len)
                # output = torch.transpose(output, 1, 2)

                # Calculate the loss and perform optimization step.
                # TODO test properly use of masks
                # output dims should be (batch_size, num_decoding_steps, num_classes)
                loss = self.criterion(output, mask_output, target, mask_target)
                loss.backward()
                self.optimizer.step()

                self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
                with torch.no_grad():
                    pred = torch.argmax(output, dim=1)
                    # data_str = batch_to_str(self.vocab, data, mask_data, convert_to_program=False)
                    # target_str = batch_to_str(self.vocab, target, mask_target, convert_to_program=convert_to_program)
                    # pred_str = pred_batch_to_str(self.vocab, pred, convert_to_program=convert_to_program)

                # Update metrics
                self.train_metrics.update('loss', loss.item())
                # for met in self.metric_ftns:
                #     self.train_metrics.update(met.__name__, met(pred_str, target_str, data_str))

                # Log progress
                if batch_idx % self.log_step == 0:
                    self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                        epoch, self._progress(batch_idx), loss.item()))
                    # TODO set this to write the text examples or remove
                    # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

                if batch_idx == self.len_epoch:
                    break

                # Update the progress bar.
                progbar.update(1)
                epoch_part = str(epoch) + '/' + str(self.epochs)
                progbar.set_postfix(epoch=epoch_part, LOSS=loss.item(),
                                    batch_size=self.data_loader.init_kwargs['batch_size'],
                                    samples=self.data_loader.n_samples)

        # Save the calculated metrics for that epoch.
        log = self.train_metrics.result()

        # If validation split exists, evaluate on validation set as well.
        if self.do_validation:
            # TODO print epoch stuff and add epoch to writer
            # TODO self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
            val_log = self.evaluator.test()
            log.update(**{'val_' + k: round(v, 5) for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
class MNISTTrainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self, model, criterion, metric_fns, optimizer, config, device,
                 data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_fns, optimizer, config, device,
                         data_loader, valid_data_loader, lr_scheduler)
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.do_validation = self.valid_data_loader is not None
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

        # Define evaluator
        self.evaluator = MNISTTester(self.model, self.criterion, self.metric_ftns, self.config,
                                     self.device, self.valid_data_loader, True)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        with tqdm(total=self.data_loader.n_samples) as progbar:
            for batch_idx, (data, target) in enumerate(self.data_loader):
                data, target = data.to(self.device), target.to(self.device)

                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()

                self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
                self.train_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.train_metrics.update(met.__name__, met(output, target))

                if batch_idx % self.log_step == 0:
                    self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                        epoch, self._progress(batch_idx), loss.item()))
                    self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

                if batch_idx == self.len_epoch:
                    break

                progbar.update(self.data_loader.init_kwargs['batch_size'])
                epoch_part = str(epoch) + '/' + str(self.epochs)
                progbar.set_postfix(epoch=epoch_part, NLL=loss.item())

        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self.evaluator.test()
            log.update(**{'val_' + k: round(v, 5) for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
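# Note: every trainer and tester in this section depends on a MetricTracker that exposes
# reset/update/avg/result. Its definition is not included here; the following is only a
# minimal sketch consistent with those calls (per-key running averages, optionally mirrored
# to the tensorboard writer), in the style of the pytorch-template utility.
import pandas as pd

class MetricTracker:
    def __init__(self, *keys, writer=None):
        self.writer = writer
        # One row per tracked key, with running total, count and average.
        self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average'])
        self.reset()

    def reset(self):
        for col in self._data.columns:
            self._data[col].values[:] = 0

    def update(self, key, value, n=1):
        if self.writer is not None:
            self.writer.add_scalar(key, value)
        self._data.total[key] += value * n
        self._data.counts[key] += n
        self._data.average[key] = self._data.total[key] / self._data.counts[key]

    def avg(self, key):
        return self._data.average[key]

    def result(self):
        # Dict of key -> running average, as consumed by the trainers above.
        return dict(self._data.average)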
class Trainer:
    def __init__(self, model, criterion, metrics_name, optimizer, train_loader, logger, log_dir, nb_epochs,
                 save_dir, device="cuda:0", log_step=10, start_epoch=0, enable_tensorboard=True,
                 valid_loader=None, lr_scheduler=None, monitor="min val_loss", early_stop=10,
                 save_epoch_period=1, resume=""):
        self.model = model
        self.criterion = criterion
        self.metrics_name = metrics_name
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.len_epoch = len(self.train_loader)
        self.do_validation = (self.valid_loader is not None)
        self.lr_scheduler = lr_scheduler
        self.log_step = log_step
        self.epochs = nb_epochs
        self.start_epoch = start_epoch + 1
        self.logger = logger
        self.device = device
        self.save_period = save_epoch_period
        self.writer = TensorboardWriter(log_dir, self.logger, enable_tensorboard)
        self.train_metrics = MetricTracker('loss', *self.metrics_name, writer=self.writer)
        self.valid_metrics = MetricTracker('loss', *self.metrics_name, writer=self.writer)
        self.checkpoint_dir = save_dir

        if monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = monitor.split()
            assert self.mnt_mode in ['min', 'max']
            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = early_stop

        if resume != "":
            self._resume_checkpoint(resume_path=resume)
        self.model.to(self.device)

    def train(self):
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)
            log = {'epoch': epoch}
            log.update(result)

            self.logger.info(' {:15s}: {}'.format(str("mnt best"), self.mnt_best))
            for key, value in log.items():
                self.logger.info(' {:15s}: {}'.format(str(key), value))

            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric (mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] < self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] > self.mnt_best)
                except KeyError:
                    self.logger.warning("Warning: Metric '{}' is not found. "
                                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if (not_improved_count > self.early_stop) and (self.early_stop > 0):
                    self.logger.info("Validation performance didn't improve for {} epochs. "
                                     "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, best)

    def _train_epoch(self, epoch):
        self.model.train()
        self.train_metrics.reset()
        start_time = time.time()
        for batch_idx, sample in enumerate(self.train_loader):
            data = sample['image']
            target = sample['mask']
            data, target = data.to(self.device), target.to(self.device)
            current_lr = self.lr_scheduler(self.optimizer, batch_idx, epoch)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met_name in self.metrics_name:
                self.train_metrics.update(met_name, getattr(metrics, met_name)(output, target))

            if batch_idx % self.log_step == 0:
                time_to_run = time.time() - start_time
                start_time = time.time()
                speed = self.log_step / time_to_run
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f} LR: {:.6f} Speed: {:.4f}iters/s'
                                  .format(epoch, self._progress(batch_idx), loss.item(), current_lr, speed))
                for met_name in self.metrics_name:
                    self.writer.add_scalar(met_name, self.train_metrics.avg(met_name))
                self.writer.add_scalar('loss', self.train_metrics.avg('loss'))
                self.writer.add_scalar("lr", current_lr)
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

            assert batch_idx <= self.len_epoch

        log = self.train_metrics.result()

        if self.do_validation:
            print("Start validation")
            val_log, iou_classes = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
            for key, value in iou_classes.items():
                log.update({key: value})
        return log

    def _valid_epoch(self, epoch):
        self.model.eval()
        self.valid_metrics.reset()
        iou_tracker = metrics.IoU(2)
        with torch.no_grad():
            for batch_idx, sample in enumerate(self.valid_loader):
                data = sample['image']
                target = sample['mask']
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                self.writer.set_step((epoch - 1) * len(self.valid_loader) + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

                target = target.cpu().numpy()
                output = output[:, 0]
                output = output.data.cpu().numpy()
                pred = np.zeros_like(output)
                pred[output > 0.5] = 1
                pred = pred.astype(np.int64)
                for i in range(len(target)):
                    iou_tracker.add_batch(target[i], pred[i])

        iou_classes = iou_tracker.get_iou()
        for key, value in iou_classes.items():
            self.writer.add_scalar(key, value)
        self.writer.add_scalar('val_loss', self.valid_metrics.avg('loss'))
        for met_name in self.metrics_name:
            self.writer.add_scalar(met_name, self.valid_metrics.avg(met_name))

        # for name, p in self.model.named_parameters():
        #     print(name, p)
        #     self.writer.add_histogram(name, p.cpu().data.numpy(), bins='auto')
        return self.valid_metrics.result(), iou_classes

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        current = batch_idx
        total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            # 'config': self.config
        }
        filename = str(self.checkpoint_dir / 'checkpoint-epoch{:06d}.pth'.format(epoch))
        torch.save(state, filename)
        self.delete_checkpoint()
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def delete_checkpoint(self):
        # Keep only the five most recent periodic checkpoints.
        checkpoints_file = list(self.checkpoint_dir.glob("checkpoint-epoch*.pth"))
        checkpoints_file.sort()
        for checkpoint_file in checkpoints_file[:-5]:
            os.remove(str(checkpoint_file.absolute()))

    def _resume_checkpoint(self, resume_path):
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
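# Note: unlike the BaseTrainer subclasses above, the Trainer just defined calls its
# lr_scheduler once per batch as lr_scheduler(optimizer, batch_idx, epoch) and expects the
# new learning rate back, rather than calling .step() on a torch scheduler. A hypothetical
# callable matching that contract (poly decay; every name here is illustrative only):
def make_poly_lr_scheduler(base_lr, nb_epochs, iters_per_epoch, power=0.9):
    def scheduler(optimizer, batch_idx, epoch):
        # Global progress in [0, 1) over the whole run (epoch numbering starts at 1).
        step = (epoch - 1) * iters_per_epoch + batch_idx
        progress = step / float(nb_epochs * iters_per_epoch)
        lr = base_lr * (1 - progress) ** power
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return lr
    return scheduler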
def main(config, use_transformers):
    logger = config.get_logger('test')

    # 1
    # device = torch.device('cpu')
    # 2
    device = torch.device('cuda:{}'.format(config.config['device_id']) if config.config['n_gpu'] > 0 else 'cpu')

    # test-set corpus
    test_dataset = config.init_obj('test_dataset', module_data_process, device=device)

    if use_transformers:
        test_dataloader = config.init_obj('test_data_loader', module_dataloader, dataset=test_dataset.data_set,
                                          collate_fn=test_dataset.bert_collate_fn_4_inference)
        model = config.init_obj('model_arch', module_arch, word_embedding=None)
    else:
        # raw corpus: only the dataset is needed (no dataloader) in order to get
        # dataset.word_embedding, which only the plain (non-BERT) networks require
        dataset = config.init_obj('dataset', module_data_process, device=device)
        test_dataloader = config.init_obj('test_data_loader', module_dataloader, dataset=test_dataset.data_set,
                                          collate_fn=test_dataset.collate_fn_4_inference)
        model = config.init_obj('model_arch', module_arch, word_embedding=dataset.word_embedding)

    if config['n_gpu'] > 1:
        device_ids = list(map(lambda x: int(x), config.config['device_id'].split(',')))
        model = torch.nn.DataParallel(model, device_ids=device_ids)

    logger.info('Loading checkpoint: {} ...'.format(config.resume))
    # checkpoint = torch.load(pathlib2.PureWindowsPath(config.resume))
    # checkpoint = torch.load(config.resume.replace('\\', '/'))
    # checkpoint = torch.load("\\saved\\text_cnn_1d\\models\\0706_122111\\checkpoint-epoch15.pth")
    # checkpoint = torch.load(pathlib2.PureWindowsPath(str(config.resume)))
    # checkpoint = torch.load(pathlib.PurePath(config.resume))
    # checkpoint = torch.load(pathlib.PureWindowsPath(config.resume))
    # checkpoint = torch.load(str(pathlib.PureWindowsPath(config.resume)))
    # checkpoint = torch.load(pathlib.PureWindowsPath(os.path.join(str(config.resume))))
    # checkpoint = torch.load(os.path.join(str(config.resume)))
    # checkpoint = torch.load(open(os.path.join(str(config.resume)), 'rb'))
    # checkpoint = torch.load(open(pathlib.joinpath(str(config.resume)), 'rb'))
    checkpoint = torch.load(config.resume)
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict)

    # 2
    model = model.cuda()
    model.eval()

    metric_ftns = [getattr(module_metric, met) for met in config['metrics']]
    test_metrics = MetricTracker(*[m.__name__ for m in metric_ftns])

    with torch.no_grad():
        for i, batch_data in enumerate(test_dataloader):
            # one batch of 128 reviews
            input_token_ids, _, seq_lens, class_labels, texts = batch_data
            # model output
            output = model(input_token_ids, _, seq_lens).squeeze(1)
            # ground-truth labels
            class_labels = class_labels

            # BERT case: add a boolean flag here eventually; for now the trailing
            # 6222 % 128 = 78 examples at the end are simply not counted
            if (i + 1) % 8 == 1:
                output_one = output.clone()
                class_labels_one = class_labels.clone()
            elif (i + 1) % 8 == 2:
                output_two = output.clone()
                class_labels_two = class_labels.clone()
            elif (i + 1) % 8 == 3:
                output_three = output.clone()
                class_labels_three = class_labels.clone()
            elif (i + 1) % 8 == 4:
                output_four = output.clone()
                class_labels_four = class_labels.clone()
            elif (i + 1) % 8 == 5:
                output_five = output.clone()
                class_labels_five = class_labels.clone()
            elif (i + 1) % 8 == 6:
                output_six = output.clone()
                class_labels_six = class_labels.clone()
            elif (i + 1) % 8 == 7:
                output_seven = output.clone()
                class_labels_seven = class_labels.clone()
            else:
                pred_tensor = torch.cat((output_one, output_two, output_three, output_four,
                                         output_five, output_six, output_seven, output), 0)
                label_tensor = torch.cat((class_labels_one, class_labels_two, class_labels_three,
                                          class_labels_four, class_labels_five, class_labels_six,
                                          class_labels_seven, class_labels), 0)
                for met in metric_ftns:
                    test_metrics.update(met.__name__, met(pred_tensor, label_tensor))

    # # plain (non-BERT) case
    # for met in metric_ftns:
    #     test_metrics.update(met.__name__, met(output, class_labels))

    test_log = test_metrics.result()
    for k, v in test_log.items():
        logger.info(' {:25s}: {}'.format(str(k), v))
    print(test_log['binary_auc'])
    return test_log['binary_auc']
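# Note: the eight-way if/elif accumulation above could equally be written by collecting
# every batch in lists and concatenating once, which is what the other test script in this
# section does with its outputs/targets lists. A hypothetical helper sketching that
# alternative (unlike the modulo-8 version, it also counts the trailing partial group):
import torch

def collect_predictions(model, test_dataloader):
    outputs, labels = [], []
    with torch.no_grad():
        for input_token_ids, _, seq_lens, class_labels, texts in test_dataloader:
            outputs.append(model(input_token_ids, _, seq_lens).squeeze(1))
            labels.append(class_labels)
    # Concatenate once over the whole test set before computing metrics.
    return torch.cat(outputs, 0), torch.cat(labels, 0)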
class MNISTTester(BaseTester):
    """
    Tester for the MNIST model.
    """
    def __init__(self, model, criterion, metric_fns, config, device, data_loader, evaluation=True):
        """
        :param model: The model to test.
        :param criterion: The loss criterion.
        :param metric_fns: The metric functions to evaluate.
        :param config: The configuration for the run.
        :param device: The device to test on.
        :param data_loader: The data loader to use for testing.
        :param evaluation: whether gold targets are available for evaluation.
        """
        super().__init__(model, criterion, metric_fns, config, device, data_loader, evaluation)
        self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

    def _evaluate(self):
        """
        Validate after training an epoch. Used with gold target

        :return: A log that contains information about validation
        """
        # Sets the model to evaluation mode.
        self.valid_metrics.reset()
        total_loss = 0.0
        total_metrics = torch.zeros(len(self.metric_ftns))
        for i, (data, target) in enumerate(tqdm(self.data_loader)):
            data, target = data.to(self.device), target.to(self.device)
            output = self.model(data)

            # computing loss, metrics on test set
            loss = self.criterion(output, target)
            batch_size = data.shape[0]
            total_loss += loss.item() * batch_size
            for j, metric in enumerate(self.metric_ftns):
                total_metrics[j] += metric(output, target) * batch_size

            self.valid_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.valid_metrics.update(met.__name__, met(output, target))
            self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _predict_without_target(self):
        return self._evaluate()
def main(config):
    # logger = config.get_logger('test')

    # setup data_loader instances
    data_loader = getattr(module_data, config['data_loader']['type'])(
        mode="test",
        data_root="/root/userfolder/Dataset/ImagesAnnotations_aug/",
        fold=0,
        num_workers=4,
        batch_size=96)

    # build model architecture
    model = config.init_obj('arch', module_arch)
    params = compute_params(model)
    print(model)
    print('the params of model is: ', params)
    # logger.info(model)

    # get function handles of loss and metrics
    # loss_fn = getattr(module_loss, config['loss'])
    loss_fn = nn.BCEWithLogitsLoss()
    metric_fns = [getattr(module_metric, met) for met in config['metrics']]

    resume_path = os.path.join(config['project_root'], config['trainer']['resume_path'])
    checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))
    # logger.info('Loading checkpoint: {} ...'.format(resume_path))
    print('Loading checkpoint: {} ...'.format(resume_path))
    state_dict = checkpoint['state_dict']
    gpus = config['gpu_device']
    if len(gpus) > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # prepare model for testing
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = torch.device('cuda:{}'.format(gpus[0]) if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    total_loss = 0.0
    total_metrics = torch.zeros(len(metric_fns))
    outputs = []
    targets = []
    test_metrics = MetricTracker('loss', 'time', *[m.__name__ for m in metric_fns], writer=None)
    image_shape = None

    f_dir, f_name = os.path.split(resume_path)
    csv_path = os.path.join(f_dir, 'prediction.csv')
    f = open(csv_path, 'w')
    csv_writer = csv.writer(f)
    keys = ['label', 'pred']
    values = []
    csv_writer.writerow(keys)

    with torch.no_grad():
        for i, (data, target) in enumerate(tqdm(data_loader.test_loader)):
            data, target = data.to(device), target.to(device).float()
            # data, target = data.cuda(), target.cuda().float()
            image_shape = [data.shape[2], data.shape[3]]

            torch.cuda.synchronize(device)
            start = time.time()
            # with torch.autograd.profiler.profile(use_cuda=True) as prof:
            output = model(data)
            torch.cuda.synchronize(device)
            end = time.time()
            # print('time:', end - start)

            pred = output.clone()  # [batch, c]
            pred_list = torch.sigmoid(pred).squeeze().tolist()
            label = target.clone()  # [batch]
            label_list = label.squeeze().tolist()
            _ = [values.append([label_list[index], pred_list[index]]) for index in range(len(pred_list))]

            output = output.unsqueeze(dim=2).unsqueeze(dim=3)
            target = target.unsqueeze(dim=2)
            outputs.append(output.clone())
            targets.append(target.clone())

            loss = loss_fn(output.squeeze(dim=1), target)
            total_loss += loss.item()
            test_metrics.update('time', end - start)
            # for i, metric in enumerate(metric_fns):
            #     total_metrics[i] += metric(output, target, apply_nonlin=True)
            # print(prof)

    csv_writer.writerows(values)
    f.close()

    outputs = torch.cat(outputs, dim=0)  # [steps*batch, 1, 1, 1]
    targets = torch.cat(targets, dim=0)
    for met in metric_fns:
        test_metrics.update(met.__name__, met(outputs, targets))

    log = test_metrics.result()
    print(log)
    # summary(model, (1, 496, 384))
    time_results = compute_precise_time(model, [496, 384], 96, loss_fn, device)
    print(time_results)
    reset_bn_stats(model)
    return
class Seq2SeqSimpleTester(BaseTester):
    """
    Tester for a simple seq2seq model.
    """
    def __init__(self, model, criterion, metric_ftns, config, device, data_loader, evaluation=True):
        """
        :param model: A model to test.
        :param criterion: The loss criterion to use.
        :param metric_ftns: The names of the metric functions to use.
        :param config: The configuration.
        :param device: The device to use for the testing.
        :param data_loader: The dataloader to use for loading the testing data.
        :param evaluation: whether gold targets are available for evaluation.
        """
        self.vocab = model.vocab
        self.question_pad_length = config['data_loader']['question_pad_length']
        self.qdmr_pad_length = config['data_loader']['qdmr_pad_length']
        self.lexicon_pad_length = config['data_loader']['lexicon_pad_length']
        self.pad_idx = self.vocab['<pad>']

        # Overriding the criterion (currently disabled).
        # self.criterion = CrossEntropyLoss(ignore_index=self.pad_idx)
        self.criterion = criterion
        super().__init__(model, self.criterion, metric_ftns, config, device, data_loader, evaluation)
        self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

    def _evaluate(self):
        """
        Validate after training an epoch. Used with gold target

        :return: A log that contains information about validation
        """
        # Choose 2 random examples from the dev set and print their prediction.
        batch_index1 = random.randint(0, len(self.data_loader) - 1) - 1
        example_index1 = random.randint(0, self.data_loader.batch_size - 1)
        batch_index2 = random.randint(0, len(self.data_loader) - 1) - 1
        example_index2 = random.randint(0, self.data_loader.batch_size - 1)
        questions = []
        decompositions = []
        targets = []
        convert_to_program = self.data_loader.gold_type_is_qdmr()

        # Sets the model to evaluation mode.
        self.valid_metrics.reset()
        with tqdm(total=len(self.data_loader)) as progbar:
            for batch_idx, (_, data, target, lexicon_str) in enumerate(self.data_loader):
                data, mask_data = batch_to_tensor(self.vocab, data, self.question_pad_length, self.device)
                target, mask_target = batch_to_tensor(self.vocab, target, self.qdmr_pad_length, self.device)
                lexicon_ids, mask_lexicon = tokenize_lexicon_str(self.vocab, lexicon_str,
                                                                 self.qdmr_pad_length, self.device)
                start = time.time()

                # Run the model on the batch and calculate the loss
                output, mask_output = self.model(data, target, lexicon_ids, evaluation_mode=True)
                loss = self.criterion(output, mask_output, target, mask_target)

                output = torch.transpose(output, 1, 2)
                pred = torch.argmax(output, dim=1)
                start = time.time()

                # Convert the predictions/targets/questions from tensor of token_ids to list of strings.
                # TODO do we need to convert here or can we use the originals? (for data and target)
                data_str = batch_to_str(self.vocab, data, mask_data, convert_to_program=False)
                target_str = batch_to_str(self.vocab, target, mask_target, convert_to_program=convert_to_program)
                pred_str = pred_batch_to_str(self.vocab, pred, convert_to_program=convert_to_program)

                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(pred_str, target_str, data_str))

                # Print example for predictions.
                if batch_idx == batch_index1:
                    questions.append(data_str[example_index1])
                    decompositions.append(pred_str[example_index1])
                    targets.append(target_str[example_index1])
                if batch_idx == batch_index2:
                    questions.append(data_str[example_index2])
                    decompositions.append(pred_str[example_index2])
                    targets.append(target_str[example_index2])

                # Update the progress bar.
                progbar.update(1)
                progbar.set_postfix(LOSS=loss.item(),
                                    batch_size=self.data_loader.init_kwargs['batch_size'],
                                    samples=self.data_loader.n_samples)

        # Print example predictions.
        for question, decomposition, target in zip(questions, decompositions, targets):
            print('\ndecomposition example:')
            print('question:\t\t', question)
            print('decomposition:\t', decomposition)
            print('target:\t\t\t', target)
            print()

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _predict_without_target(self):
        """
        Get model predictions for testing. Used without targets

        :return: A log that contains information about predictions
        """
        qid_col = []
        pred_col = []
        question_col = []
        convert_to_program = self.data_loader.gold_type_is_qdmr()

        # Sets the model to evaluation mode.
        self.valid_metrics.reset()
        with tqdm(total=len(self.data_loader)) as progbar:
            for batch_idx, (question_ids, data, target, lexicon_str) in enumerate(self.data_loader):
                data, mask_data = batch_to_tensor(self.vocab, data, self.question_pad_length, self.device)
                target, mask_target = batch_to_tensor(self.vocab, target, self.qdmr_pad_length, self.device)
                lexicon_ids, mask_lexicon = tokenize_lexicon_str(self.vocab, lexicon_str,
                                                                 self.qdmr_pad_length, self.device)
                start = time.time()

                # Run the model on the batch and calculate the loss
                output, mask_output = self.model(data, target, lexicon_ids, evaluation_mode=True)
                loss = self.criterion(output, mask_output, target, mask_target)

                output = torch.transpose(output, 1, 2)
                pred = torch.argmax(output, dim=1)
                start = time.time()

                # Convert the predictions/targets/questions from tensor of token_ids to list of strings.
                # TODO do we need to convert here or can we use the originals? (for data and target)
                data_str = batch_to_str(self.vocab, data, mask_data, convert_to_program=False)
                target_str = batch_to_str(self.vocab, target, mask_target, convert_to_program=convert_to_program)
                pred_str = pred_batch_to_str(self.vocab, pred, convert_to_program=convert_to_program)

                for i, question_id in enumerate(question_ids):
                    self.logger.info('{}:{}'.format(question_id, data_str[i]))
                qid_col.extend(question_ids)
                pred_col.extend(pred_str)
                question_col.extend(data_str)

                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(pred_str, target_str, data_str))

                # Update the progress bar.
                progbar.update(1)
                progbar.set_postfix(LOSS=loss.item(),
                                    batch_size=self.data_loader.init_kwargs['batch_size'],
                                    samples=self.data_loader.n_samples)

        d = {'question_id': qid_col, 'question_text': question_col, 'decomposition': pred_col}
        programs_df = pd.DataFrame(data=d)
        programs_df.to_csv(self.predictions_file_name, index=False, encoding='utf-8')

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()