class Tester(BaseTester):
    """Runs a trained model over a test set, tracking loss, metrics, and
    saving one image per configured plot function."""

    def __init__(self, model, criterion, metric_ftns, plot_ftns, config, data_loader):
        super().__init__(model, criterion, metric_ftns, plot_ftns, config)
        self.config = config
        self.data_loader = data_loader
        metric_names = [m.__name__ for m in self.metric_ftns]
        self.test_metrics = MetricTracker('loss', *metric_names)

    def _test(self):
        """
        Test logic

        :return: A log that contains information about testing
        """
        self.model.eval()
        self.test_metrics.reset()
        with torch.no_grad():
            all_outputs, all_targets = [], []
            for batch_idx, (data, target) in enumerate(tqdm(self.data_loader)):
                data = data.to(self.device, non_blocking=self.non_blocking)
                target = target.to(self.device, non_blocking=self.non_blocking)
                output = self.model(data)
                batch_loss = self.criterion(output, target)
                all_outputs.append(output)
                all_targets.append(target)
                self.test_metrics.update('loss', batch_loss.item())
            all_outputs = torch.cat(all_outputs)
            all_targets = torch.cat(all_targets)
            # metrics are computed once over the whole set, not batch-averaged
            for met in self.metric_ftns:
                self.test_metrics.update(met.__name__, met(all_outputs, all_targets))
            # each plot function produces one image saved under the log dir
            for plot_fn in self.plot_ftns:
                image_path = self.config.log_dir / (plot_fn.__name__ + '.png')
                torchvision.utils.save_image(plot_fn(all_outputs, all_targets).float(),
                                             image_path, normalize=True)
        return self.test_metrics.result()
class Eval:
    """Ensemble evaluator.

    Averages the sigmoid outputs of several models, each evaluated on the
    input and on its horizontally flipped copy (test-time augmentation),
    then scores the collected predictions with the configured metrics.
    """

    def __init__(self, models, criterion, metrics, device):
        self.criterion = criterion
        self.models = models    # iterable of trained models to ensemble
        self.device = device
        self.metrics = metrics  # metric callables taking (outputs, targets) numpy arrays
        self.valid_metrics = MetricTracker('loss',
                                           *[m.__name__ for m in self.metrics],
                                           writer=None)
        self.logger = logging.getLogger()

    def eval(self, valid_data_loader):
        """Run the ensemble over ``valid_data_loader``.

        :param valid_data_loader: yields (data, target) batches
        :return: dict of tracked loss/metric averages (``MetricTracker.result()``)
        """
        for model in self.models:
            model.eval()
        self.valid_metrics.reset()
        outputs = []
        targets = []
        with torch.no_grad():
            tk = tqdm(enumerate(valid_data_loader), total=len(valid_data_loader))
            for batch_idx, (data, target) in tk:
                data, target = data.to(self.device), target.to(self.device)
                for model in self.models:
                    output = model(data)
                    # TTA: second forward pass on the horizontally flipped input
                    output2 = model(data.flip(-1))
                    loss = self.criterion(output, target)
                    # average the plain and flipped sigmoid outputs;
                    # one entry per model per batch
                    outputs.append(
                        (output.sigmoid().detach().cpu().numpy() +
                         output2.sigmoid().detach().cpu().numpy()) / 2)
                    # targets appended once per model so they stay aligned
                    # with the per-model output entries
                    targets.append(target.cpu().numpy())
                # NOTE(review): 'loss' is the LAST model's un-flipped loss for
                # this batch — presumably just a progress signal; confirm
                self.valid_metrics.update('loss', loss.item())
                tk.set_description("loss: %.6f" % loss.item())
        outputs = np.concatenate(outputs)
        targets = np.concatenate(targets)
        # metrics computed once over all accumulated predictions
        for met in self.metrics:
            self.valid_metrics.update(met.__name__, met(outputs, targets))
        self.logger.info(self.valid_metrics.result())
        return self.valid_metrics.result()
class Trainer(BaseTrainer):
    """
    Trainer for a sequential latent-variable model (DMM-style): tracks
    loss/nll/kl, KL annealing, gradient histograms, reconstruction figures,
    and optional validation/test rounds.
    """

    def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader,
                 valid_data_loader=None, test_data_loader=None, lr_scheduler=None,
                 len_epoch=None, overfit_single_batch=False):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        # when overfitting a single batch there is no point in val/test rounds
        self.valid_data_loader = valid_data_loader if not overfit_single_batch else None
        self.test_data_loader = test_data_loader if not overfit_single_batch else None
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.overfit_single_batch = overfit_single_batch

        # -------------------------------------------------
        # add flexibility to allow no metric in config.json
        self.log_loss = ['loss', 'nll', 'kl']
        if self.metric_ftns is None:
            self.train_metrics = MetricTracker(*self.log_loss, writer=self.writer)
            self.valid_metrics = MetricTracker(*self.log_loss, writer=self.writer)
        else:
            self.train_metrics = MetricTracker(
                *self.log_loss, *[m.__name__ for m in self.metric_ftns],
                writer=self.writer)
            self.valid_metrics = MetricTracker(
                *self.log_loss, *[m.__name__ for m in self.metric_ftns],
                writer=self.writer)
            self.test_metrics = MetricTracker(
                *[m.__name__ for m in self.metric_ftns], writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()

        # pre-allocate per-parameter gradient traces to log after the epoch
        dict_grad = {}
        for name, p in self.model.named_parameters():
            if p.requires_grad and 'bias' not in name:
                dict_grad[name] = np.zeros(self.len_epoch)

        for batch_idx, batch in enumerate(self.data_loader):
            x, x_reversed, x_mask, x_seq_lengths = batch
            x = x.to(self.device)
            x_reversed = x_reversed.to(self.device)
            x_mask = x_mask.to(self.device)
            x_seq_lengths = x_seq_lengths.to(self.device)

            self.optimizer.zero_grad()
            x_recon, z_q_seq, z_p_seq, mu_q_seq, logvar_q_seq, mu_p_seq, logvar_p_seq = \
                self.model(x, x_reversed, x_seq_lengths)
            # ramp the KL weight from min_anneal_factor up to 1 over anneal_update steps
            kl_annealing_factor = \
                determine_annealing_factor(self.config['trainer']['min_anneal_factor'],
                                           self.config['trainer']['anneal_update'],
                                           epoch - 1, self.len_epoch, batch_idx)
            kl_raw, nll_raw, kl_fr, nll_fr, kl_m, nll_m, loss = \
                self.criterion(x, x_recon, mu_q_seq, logvar_q_seq,
                               mu_p_seq, logvar_p_seq,
                               kl_annealing_factor, x_mask)
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)

            # accumulate gradients that are to be logged later after epoch ends
            for name, p in self.model.named_parameters():
                if p.requires_grad and 'bias' not in name:
                    val = 0 if p.grad is None else p.grad.abs().mean()
                    dict_grad[name][batch_idx] = val

            self.optimizer.step()

            # 'loss'/'nll'/'kl' tracked in this fixed order
            for l_i, l_i_val in zip(self.log_loss, [loss, nll_m, kl_m]):
                self.train_metrics.update(l_i, l_i_val.item())
            if self.metric_ftns is not None:
                for met in self.metric_ftns:
                    if met.__name__ == 'bound_eval':
                        self.train_metrics.update(
                            met.__name__,
                            met([x_recon, mu_q_seq, logvar_q_seq],
                                [x, mu_p_seq, logvar_p_seq], mask=x_mask))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))

            if batch_idx == self.len_epoch or self.overfit_single_batch:
                break

        if self.writer is not None:
            self.writer.set_step(epoch, 'train')
            # log losses
            for l_i in self.log_loss:
                self.train_metrics.write_to_logger(l_i)
            # log metrics
            # BUG FIX: the original referenced the loop variable `met` leaked
            # from the batch loop without iterating metric_ftns; mirror the
            # loop used in _valid_epoch
            if self.metric_ftns is not None:
                for met in self.metric_ftns:
                    if met.__name__ == 'bound_eval':
                        self.train_metrics.write_to_logger(met.__name__)
            # log gradients
            for name, p in dict_grad.items():
                self.writer.add_histogram(name + '/grad', p, bins='auto')
            # log parameters
            for name, p in self.model.named_parameters():
                self.writer.add_histogram(name, p, bins='auto')
            # log kl annealing factors
            self.writer.add_scalar('anneal_factor', kl_annealing_factor)

        # periodically log a reconstruction figure from the last batch
        if epoch % 50 == 0:
            fig = create_reconstruction_figure(x, torch.sigmoid(x_recon))
            # debug_fig = create_debug_figure(x, x_reversed, x_mask)
            # debug_fig_loss = create_debug_loss_figure(kl_raw, nll_raw, kl_fr, nll_fr, kl_m, nll_m, x_mask)
            self.writer.set_step(epoch, 'train')
            self.writer.add_figure('reconstruction', fig)
            # self.writer.add_figure('debug', debug_fig)
            # self.writer.add_figure('debug_loss', debug_fig_loss)

        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.do_test and epoch % 50 == 0:
            test_log = self._test_epoch(epoch)
            log.update(**{'test_' + k: v for k, v in test_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.valid_data_loader):
                x, x_reversed, x_mask, x_seq_lengths = batch
                x = x.to(self.device)
                x_reversed = x_reversed.to(self.device)
                x_mask = x_mask.to(self.device)
                x_seq_lengths = x_seq_lengths.to(self.device)

                x_recon, z_q_seq, z_p_seq, mu_q_seq, logvar_q_seq, mu_p_seq, logvar_p_seq = \
                    self.model(x, x_reversed, x_seq_lengths)
                # annealing factor fixed at 1 for validation
                kl_raw, nll_raw, kl_fr, nll_fr, kl_m, nll_m, loss = \
                    self.criterion(x, x_recon, mu_q_seq, logvar_q_seq,
                                   mu_p_seq, logvar_p_seq, 1, x_mask)

                for l_i, l_i_val in zip(self.log_loss, [loss, nll_m, kl_m]):
                    self.valid_metrics.update(l_i, l_i_val.item())
                if self.metric_ftns is not None:
                    for met in self.metric_ftns:
                        if met.__name__ == 'bound_eval':
                            self.valid_metrics.update(
                                met.__name__,
                                met([x_recon, mu_q_seq, logvar_q_seq],
                                    [x, mu_p_seq, logvar_p_seq], mask=x_mask))

        if self.writer is not None:
            self.writer.set_step(epoch, 'valid')
            for l_i in self.log_loss:
                self.valid_metrics.write_to_logger(l_i)
            if self.metric_ftns is not None:
                for met in self.metric_ftns:
                    if met.__name__ == 'bound_eval':
                        self.valid_metrics.write_to_logger(met.__name__)

        if epoch % 10 == 0:
            # torch.nn.functional.sigmoid is deprecated; torch.sigmoid is equivalent
            x_recon = torch.sigmoid(x_recon.view(x.size(0), x.size(1), -1))
            fig = create_reconstruction_figure(x, x_recon)
            # debug_fig = create_debug_figure(x, x_reversed_unpack, x_mask)
            # debug_fig_loss = create_debug_loss_figure(kl_raw, nll_raw, kl_fr, nll_fr, kl_m, nll_m, x_mask)
            self.writer.set_step(epoch, 'valid')
            self.writer.add_figure('reconstruction', fig)
            # self.writer.add_figure('debug', debug_fig)
            # self.writer.add_figure('debug_loss', debug_fig_loss)

        return self.valid_metrics.result()

    def _test_epoch(self, epoch):
        """Evaluate bound/importance-sample metrics on the test set and log
        a few generated sequences.

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about testing
        """
        self.model.eval()
        self.test_metrics.reset()
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.test_data_loader):
                x, x_reversed, x_mask, x_seq_lengths = batch
                x = x.to(self.device)
                x_reversed = x_reversed.to(self.device)
                x_mask = x_mask.to(self.device)
                x_seq_lengths = x_seq_lengths.to(self.device)

                x_recon, z_q_seq, z_p_seq, mu_q_seq, logvar_q_seq, mu_p_seq, logvar_p_seq = \
                    self.model(x, x_reversed, x_seq_lengths)

                if self.metric_ftns is not None:
                    for met in self.metric_ftns:
                        if met.__name__ == 'bound_eval':
                            self.test_metrics.update(
                                met.__name__,
                                met([x_recon, mu_q_seq, logvar_q_seq],
                                    [x, mu_p_seq, logvar_p_seq], mask=x_mask))
                        if met.__name__ == 'importance_sample':
                            self.test_metrics.update(
                                met.__name__,
                                met(batch_idx, self.model, x, x_reversed,
                                    x_seq_lengths, x_mask, n_sample=500))

            if self.writer is not None:
                self.writer.set_step(epoch, 'test')
                if self.metric_ftns is not None:
                    for met in self.metric_ftns:
                        self.test_metrics.write_to_logger(met.__name__)
                # log a few unconditioned generations as images
                n_sample = 3
                output_seq, z_p_seq, mu_p_seq, logvar_p_seq = self.model.generate(
                    n_sample, 100)
                output_seq = torch.sigmoid(output_seq)
                plt.close()
                fig, ax = plt.subplots(n_sample, 1, figsize=(10, n_sample * 10))
                for i in range(n_sample):
                    ax[i].imshow(output_seq[i].T.cpu().detach().numpy(), origin='lower')
                self.writer.add_figure('generation', fig)

        return self.test_metrics.result()

    def _progress(self, batch_idx):
        """Format training progress as '[current/total (pct%)]'."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
class KDTrainer(TrainerBase):
    """Knowledge-distillation trainer.

    Trains a student model against the hard labels and the temperature-softened
    outputs of a fixed teacher model, streaming progress to a Streamlit UI.

    ``sts`` is ``[stop, st_empty, save_dir]``: a stop flag, an empty Streamlit
    container for the live chart/progress/table widgets, and the save directory.
    """

    def __init__(self, s_model, t_model, epoch, criterion, metrics, optimizer,
                 device, data_loader, valid_data_loader=None, lr_scheduler=None,
                 len_epoch=None, checkpoint=None, sts=()):
        # sts=[stop, st_empty, save_dir]
        # (default changed from the mutable [] to an immutable tuple; the
        # parameter is required in practice — sts[0]/sts[2] are read below)
        super().__init__(s_model, criterion, metrics, optimizer, epoch,
                         checkpoint, save_dir=sts[2], st_stop=sts[0])
        # NOTE(review): scaler is created but AMP is not wired into the loop
        self.scaler = GradScaler()
        self.device = device
        self.s_model = self.model
        self.s_model = self.s_model.to(device)
        self.t_model = t_model
        self.t_model = self.t_model.to(device)
        # FIX: size_average=False is deprecated; reduction='sum' is the
        # documented equivalent (loss is summed, then divided by batch size
        # in _kd_loss)
        self.kd_criterion = nn.KLDivLoss(reduction='sum')
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        # Streamlit widgets for live feedback
        self.st_empty = sts[1]
        self.st_container = self.st_empty.beta_container()
        self.lossChart = self.st_container.line_chart()
        self.processBar = self.st_container.progress(0)
        self.epochResult = self.st_container.table()
        self.train_idx = 0
        self.log_step = 100
        self.train_metrics = MetricTracker('loss',
                                           *[m.__name__ for m in self.metrics],
                                           writer=None)
        self.valid_metrics = MetricTracker('loss',
                                           *[m.__name__ for m in self.metrics],
                                           writer=None)

    def _train_epoch(self, epoch: int) -> dict:
        """Run one training epoch; returns merged train (and val) metrics."""
        self.model.train()
        self.train_metrics.reset()
        outputs = []
        targets = []
        for batch_idx, (data, target) in enumerate(tqdm(self.data_loader)):
            if self.st_stop:  # user pressed stop in the UI
                break
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            output, loss = self._calculate_loss(data, target)
            outputs.append(output.sigmoid().detach().cpu().numpy())
            targets.append(target.cpu().numpy())
            self.train_metrics.update('loss', loss.item())
            loss.backward()
            self.optimizer.step()
            if batch_idx % self.log_step == 0:
                self.lossChart.add_rows(
                    pd.DataFrame(self.train_metrics.result(),
                                 index=[self.train_idx]))
                self.train_idx = self.train_idx + 1
            self.processBar.progress(batch_idx / self.len_epoch)
            if batch_idx == self.len_epoch:
                break
        outputs = np.concatenate(outputs)
        targets = np.concatenate(targets)
        # metrics computed once over the accumulated epoch predictions
        for i, met in enumerate(self.metrics):
            self.train_metrics.update(met.__name__, met(outputs, targets))
        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
        if self.lr_scheduler is not None:
            # isinstance instead of type(...) ==, so subclasses behave the same
            if isinstance(self.lr_scheduler,
                          torch.optim.lr_scheduler.ReduceLROnPlateau):
                self.lr_scheduler.step(log["val_loss"])
            else:
                self.lr_scheduler.step()
        st_res = log.copy()
        self.epochResult.add_rows(pd.DataFrame(st_res, index=[epoch]))
        self.logger.info(log)
        return log

    def _valid_epoch(self, epoch: int) -> dict:
        """Run one validation epoch and return its metric results."""
        self.model.eval()
        self.valid_metrics.reset()
        outputs = []
        targets = []
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output, loss = self._calculate_loss(data, target)
                outputs.append(output.sigmoid().detach().cpu().numpy())
                targets.append(target.cpu().numpy())
                self.valid_metrics.update('loss', loss.item())
        outputs = np.concatenate(outputs)
        targets = np.concatenate(targets)
        for met in self.metrics:
            self.valid_metrics.update(met.__name__, met(outputs, targets))
        return self.valid_metrics.result()

    def _kd_loss(self, out_s, out_t, target):
        """Hinton-style distillation loss.

        Blends the hard-label loss with the KL divergence between the
        student's and teacher's temperature-softened distributions:
        (1-alpha)*CE + alpha*T^2*KL.
        """
        alpha = 0.5
        T = 4
        loss = self.criterion(out_s, target)
        batch_size = target.shape[0]
        s_max = F.log_softmax(out_s / T, dim=1)
        t_max = F.softmax(out_t / T, dim=1)
        # summed KL normalized per sample
        loss_kd = self.kd_criterion(s_max, t_max) / batch_size
        loss = (1 - alpha) * loss + alpha * T * T * loss_kd
        return loss

    def _calculate_loss(self, data, target):
        """Forward both models and return (student_logits, distillation loss)."""
        out_s = self.s_model(data)
        out_t = self.t_model(data)
        loss = self._kd_loss(out_s, out_t, target)
        return out_s, loss
class Trainer(BaseTrainer):
    """
    Trainer class

    Domain-adversarial (DANN-style) trainer: trains a classifier on the
    source domain while a gradient-reversal-layer (GRL) branch learns to be
    domain-invariant between source (domain label 0) and target (label 1).
    """
    def __init__(self, model, loss_fn_class, loss_fn_domain, metric_ftns,
                 optimizer, config, device, data_loader_source,
                 valid_data_loader_source=None, data_loader_target=None,
                 valid_data_loader_target=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, metric_ftns, optimizer, config)
        self.config = config
        self.device = device
        self.loss_fn_class = loss_fn_class    # label-classification loss
        self.loss_fn_domain = loss_fn_domain  # domain-discrimination loss
        self.data_loader_source = data_loader_source
        self.valid_data_loader_source = valid_data_loader_source
        self.data_loader_target = data_loader_target
        self.valid_data_loader_target = valid_data_loader_target
        self.model.to(self.device)
        if len_epoch is None:
            # epoch-based training: one epoch = shorter of the two loaders
            self.len_epoch = min(len(self.data_loader_source),
                                 len(self.data_loader_target))
        else:
            # FIXME: implement source/target style training or remove this feature
            # iteration-based training
            # NOTE(review): `data_loader` is not a parameter of this __init__ —
            # this branch raises NameError if taken; matches the FIXME above
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        # FIXME: handle validation round
        self.valid_data_loader = valid_data_loader_source
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = 64
        self.train_metrics = MetricTracker(
            'loss', 'class_loss', 'domain_loss',
            *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', 'class_loss', 'domain_loss',
            *[m.__name__ for m in self.metric_ftns], writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        # Setting model into train mode, required_grad
        self.model.train()
        # Reset all metric in metric dataframe
        self.train_metrics.reset()
        batch_idx = 0
        for source, target in zip(self.data_loader_source,
                                  self.data_loader_target):
            # source, target = source.to(self.device), target.to(self.device)

            # Calculate training progress and GRL λ
            # (λ ramps from 0 to 1 following the DANN schedule;
            # self.epochs presumably set by BaseTrainer — confirm)
            p = float(batch_idx + (epoch-1) * self.len_epoch) / \
                (self.epochs * self.len_epoch)
            λ = 2. / (1. + np.exp(-10 * p)) - 1

            # === Train on source domain
            X_source, y_source = source
            X_source, y_source = X_source.to(self.device), y_source.to(
                self.device)
            # generate source domain labels: 0
            y_s_domain = torch.zeros(X_source.shape[0], dtype=torch.float32)
            y_s_domain = y_s_domain.to(self.device)

            class_pred_source, domain_pred_source = self.model(X_source, λ)
            # source classification loss
            loss_s_label = self.loss_fn_class(class_pred_source.squeeze(),
                                              y_source)
            # Compress from tensor size batch*1*1*1 => batch
            domain_pred_source = torch.squeeze(domain_pred_source)
            loss_s_domain = self.loss_fn_domain(
                domain_pred_source, y_s_domain)  # source domain loss (via GRL)

            # === Train on target domain
            X_target, _ = target
            # generate source domain labels: 0
            # (NOTE(review): comment kept from original — these are in fact
            # TARGET domain labels, value 1)
            y_t_domain = torch.ones(X_target.shape[0], dtype=torch.float32)
            X_target = X_target.to(self.device)
            y_t_domain = y_t_domain.to(self.device)

            _, domain_pred_target = self.model(X_target, λ)
            domain_pred_target = torch.squeeze(domain_pred_target)
            loss_t_domain = self.loss_fn_domain(
                domain_pred_target, y_t_domain)  # source domain loss (via GRL)

            # === Optimizer ====
            self.optimizer.zero_grad()
            # NOTE(review): taking log of the class loss is unusual —
            # confirm this transform is intentional
            loss_s_label = torch.log(loss_s_label + 1e-9)
            loss = loss_t_domain + loss_s_domain + loss_s_label
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            self.train_metrics.update('class_loss', loss_s_label.item())
            self.train_metrics.update('domain_loss', loss_s_domain.item())
            # metrics are computed on the source-domain class predictions
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__,
                                          met(class_pred_source, y_source))

            if batch_idx % self.log_step == 0:
                self.logger.debug(
                    f'Train Epoch: {epoch} {self._progress(batch_idx)} Loss: {loss.item():.4f} Source class loss: {loss_s_label.item():3f} Source domain loss {loss_s_domain.item():3f}'
                )
                self.writer.add_image(
                    'input', make_grid(X_source.cpu(), nrow=4, normalize=True))

            batch_idx += 1
            if batch_idx == self.len_epoch:
                break

        # Average the accumulated result to log the result
        log = self.train_metrics.result()
        # update lambda value to metric tracker
        log["lambda"] = λ

        # Run validation after each epoch if validation dataloader is available.
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        # Set model to evaluation mode, required_grad = False
        # disables dropout and has batch norm use the entire population statistics
        self.model.eval()
        # Reset validation metrics in dataframe for a new validation round
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)

                # ignore labmda value
                output, _ = self.model(data, 1)
                loss = self.loss_fn_class(output.squeeze(), target)

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(output, target))
                self.writer.add_image(
                    'input', make_grid(data.cpu(), nrow=4, normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        """Format training progress as '[current/total (pct%)]'."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader_source, 'n_samples'):
            current = batch_idx * self.data_loader_source.batch_size
            total = self.data_loader_source.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
class Trainer(BaseTrainer):
    """
    Trainer for handwriting-generation models. The Unconditional, Conditional,
    Seq2Seq and Graves variants share the same loop; they differ only in how
    the loss is computed and which sub-modules receive gradient clipping, so
    both concerns are factored into helpers shared by train and validation.
    """

    def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
                 data_loader, valid_data_loader=None, lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config, device)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker('loss',
                                           *[m.__name__ for m in self.metric_ftns],
                                           writer=self.writer)
        self.valid_metrics = MetricTracker('loss',
                                           *[m.__name__ for m in self.metric_ftns],
                                           writer=self.writer)

    def _forward_loss(self, sentences, sentences_mask, strokes, strokes_mask):
        """Run one forward pass and return the loss for the current model type.

        :raises NotImplementedError: if the model name is not recognised.
            (Bug fix: the original constructed NotImplementedError without
            raising it, so an unknown model fell through and later crashed
            with a NameError on the undefined ``loss``.)
        """
        model_name = str(self.model)
        if model_name.startswith(('Unconditional', 'Conditional', 'Graves')):
            output_network = self.model(sentences, sentences_mask,
                                        strokes, strokes_mask)
            gaussian_params = self.model.compute_gaussian_parameters(output_network)
            return self.criterion(gaussian_params, strokes, strokes_mask)
        if model_name.startswith('Seq2Seq'):
            output_network = self.model(sentences, sentences_mask,
                                        strokes, strokes_mask)
            # Seq2Seq is scored against the sentences, not the strokes
            return self.criterion(output_network, sentences, sentences_mask)
        raise NotImplementedError("Not a valid model name")

    def _clip_gradients(self):
        """Clip gradients (max-norm 10) of the sub-modules for the model type."""
        model_name = str(self.model)
        if model_name.startswith('Unconditional'):
            clip_grad_norm_(self.model.rnn_1.parameters(), 10)
            clip_grad_norm_(self.model.rnn_2.parameters(), 10)
            clip_grad_norm_(self.model.rnn_3.parameters(), 10)
        elif model_name.startswith('Conditional'):
            clip_grad_norm_(self.model.rnn_1_with_gaussian_attention.lstm_cell.parameters(), 10)
            clip_grad_norm_(self.model.rnn_2.parameters(), 10)
            clip_grad_norm_(self.model.rnn_3.parameters(), 10)
        elif model_name.startswith('Seq2Seq'):
            clip_grad_norm_(self.model.parameters(), 10)
        elif model_name.startswith('Graves'):
            clip_grad_norm_(self.model.rnn_1.parameters(), 10)
            clip_grad_norm_(self.model.rnn_2_with_gaussian_attention.lstm_cell.parameters(), 10)
            clip_grad_norm_(self.model.rnn_3.parameters(), 10)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (sentences, sentences_mask, strokes, strokes_mask) in \
                enumerate(self.data_loader):
            # Moving input data to device
            sentences, sentences_mask = sentences.to(self.device), sentences_mask.to(self.device)
            strokes, strokes_mask = strokes.to(self.device), strokes_mask.to(self.device)

            # Compute the loss and perform an optimization step
            self.optimizer.zero_grad()
            loss = self._forward_loss(sentences, sentences_mask, strokes, strokes_mask)
            loss.backward()
            self._clip_gradients()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))

            if batch_idx == self.len_epoch:
                break

        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_'+k : v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (sentences, sentences_mask, strokes, strokes_mask) in \
                    enumerate(self.valid_data_loader):
                # Moving input data to device
                sentences, sentences_mask = sentences.to(self.device), sentences_mask.to(self.device)
                strokes, strokes_mask = strokes.to(self.device), strokes_mask.to(self.device)

                # Compute the loss
                loss = self._forward_loss(sentences, sentences_mask, strokes, strokes_mask)

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss.item())

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        """Format training progress as '[current/total (pct%)]'."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
class Trainer(BaseTrainer):
    """
    Trainer class

    K-fold BERT-style text classifier trainer with optional adversarial
    training (FGM or PGD perturbations on the embeddings), loss cutting,
    and a final inference pass that writes test predictions to a JSONL file.
    """
    def __init__(self, model, criterion, metric_ftns, optimizer, config, i_fold,
                 data_loader, valid_data_loader=None, test_data_loader=None,
                 lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.i_fold = i_fold  # index of the current cross-validation fold
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.test_data_loader = test_data_loader
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.zero_grad()
        self.train_metrics.reset()
        # adversarial-training helper (FGM/PGD) built from config
        adv_train = self.config.init_obj('adversarial_training',
                                         module_adversarial,
                                         model=self.model)
        K = 3  # number of PGD attack iterations
        for batch_idx, data in enumerate(self.data_loader):
            self.model.train()
            ids, texts, input_ids, attention_masks, text_lengths, labels = data
            if 'cuda' == self.device.type:
                input_ids = input_ids.cuda(self.device)
                attention_masks = attention_masks.cuda(self.device)
                labels = labels.cuda(self.device)
            preds, cls_embedding = self.model(input_ids, attention_masks,
                                              text_lengths)
            loss = self.criterion[0](preds, labels)
            # loss cutting: zero out losses at or below the configured
            # 'loss_cut' threshold (keeps only losses strictly above it)
            loss_zeros = torch.zeros_like(loss)
            loss = torch.where(
                loss > float(self.config.config['loss']['loss_cut']), loss,
                loss_zeros)
            loss.backward()

            if self.config.config['trainer'][
                    'is_adversarial_training'] and self.config.config[
                        'adversarial_training']['type'] == 'FGM':
                # adversarial training (FGM): single perturbation step
                adv_train.attack()
                adv_preds, adv_cls_embedding = self.model(
                    input_ids, attention_masks, text_lengths)
                adv_loss = self.criterion[0](adv_preds, labels)
                adv_loss.backward()
                adv_train.restore()
            elif self.config.config['trainer'][
                    'is_adversarial_training'] and self.config.config[
                        'adversarial_training']['type'] == 'PGD':
                adv_train.backup_grad()
                # adversarial training (PGD): K projected perturbation steps
                for t in range(K):
                    # add adversarial perturbation on the embeddings;
                    # back up param.data on the first attack
                    adv_train.attack(is_first_attack=(t == 0))
                    if t != K - 1:
                        self.model.zero_grad()
                    else:
                        adv_train.restore_grad()
                    adv_preds, adv_cls_embedding = self.model(
                        input_ids, attention_masks, text_lengths)
                    adv_loss = self.criterion[0](adv_preds, labels)
                    # backprop, accumulating the adversarial gradients on
                    # top of the normal gradients
                    adv_loss.backward()
                adv_train.restore()  # restore the embedding parameters

            if self.config.config['trainer']['clip_grad']:
                # gradient clipping
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config.config['trainer']['max_grad_norm'])
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.model.zero_grad()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(preds, labels))

            if batch_idx % self.log_step == 0:
                self.logger.debug(
                    'Train Epoch: {} {} Loss: {:.3f} lr: {}'.format(
                        epoch, self._progress(batch_idx), loss.item(),
                        self.optimizer.param_groups[0]['lr']))

            if batch_idx == self.len_epoch:
                break

        log = self.train_metrics.result()

        if self.valid_data_loader:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, data in enumerate(self.valid_data_loader):
                ids, texts, input_ids, attention_masks, text_lengths, labels = data
                if 'cuda' == self.device.type:
                    input_ids = input_ids.cuda(self.device)
                    attention_masks = attention_masks.cuda(self.device)
                    labels = labels.cuda(self.device)
                preds, cls_embedding = self.model(input_ids, attention_masks,
                                                  text_lengths)
                # one-time model-graph export to tensorboard
                # (self.add_graph presumably initialised by BaseTrainer — confirm)
                if self.add_graph:
                    input_model = self.model.module if (len(
                        self.config.config['device_id']) > 1) else self.model
                    self.writer.writer.add_graph(
                        input_model,
                        [input_ids, attention_masks, text_lengths])
                    self.add_graph = False
                loss = self.criterion[0](preds, labels)

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(preds, labels))
        log = self.valid_metrics.result()
        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return log

    def _inference(self):
        """
        Inference after training: reload the best checkpoint, report
        validation accuracy, then write test-set predictions to a JSONL
        result file named after the experiment/fold/accuracy.

        :return: None (results are written to disk and logged)
        """
        # self.best_path presumably maintained by BaseTrainer — confirm
        checkpoint = torch.load(self.best_path)
        self.logger.info("load best mode {} ...".format(self.best_path))
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.eval()
        ps = []
        ls = []
        with torch.no_grad():
            for batch_idx, data in enumerate(self.valid_data_loader):
                ids, texts, input_ids, attention_masks, text_lengths, labels = data
                if 'cuda' == self.device.type:
                    input_ids = input_ids.cuda(self.device)
                    attention_masks = attention_masks.cuda(self.device)
                    labels = labels.cuda(self.device)
                preds, cls_embedding = self.model(input_ids, attention_masks,
                                                  text_lengths)
                ps.append(preds)
                ls.append(labels)
        ps = torch.cat(ps, dim=0)
        ls = torch.cat(ls, dim=0)
        acc = module_mertric.binary_accuracy(ps, ls)
        self.logger.info('\toverall acc :{}'.format(acc))

        result_file = self.test_data_loader.dataset.data_dir.parent / 'result' / '{}-{}-{}-{}-{}.jsonl'.format(
            self.config.config['experiment_name'],
            self.test_data_loader.dataset.transformer_model,
            self.config.config['k_fold'], self.i_fold, acc)
        if not result_file.parent.exists():
            result_file.parent.mkdir()
        result_writer = result_file.open('w')
        with torch.no_grad():
            for batch_idx, data in enumerate(self.test_data_loader):
                ids, texts, input_ids, attention_masks, text_lengths, labels = data
                if 'cuda' == self.device.type:
                    input_ids = input_ids.cuda(self.device)
                    attention_masks = attention_masks.cuda(self.device)
                preds, cls_embedding = self.model(input_ids, attention_masks,
                                                  text_lengths)
                # binarize: sigmoid then round to 0/1
                preds = torch.round(
                    torch.sigmoid(preds)).cpu().detach().numpy()
                for pred, item_id, text in zip(preds, ids, texts):
                    result_writer.write(
                        json.dumps(
                            {
                                "id": item_id,
                                "text": text,
                                "labels": int(pred)
                            }, ensure_ascii=False) + '\n')
        result_writer.close()
        self.logger.info('result saving to {}'.format(result_file))

    def _progress(self, batch_idx):
        """Format training progress as '[current/total (pct%)]'."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
class Trainer(BaseTrainer):
    """
    Trainer with curriculum-difficulty control.

    Trains a seq2seq-style model and, after each epoch, raises the data
    generator's difficulty once the validation and train accuracies clear
    their thresholds.
    """

    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer: Optimizer,
                 config,
                 data_loader,
                 valid_data_loader,
                 len_epoch=None,
                 log_step=2):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.log_step = log_step
        if len_epoch is None:
            # Epoch-based training: iterate each finite loader once per epoch.
            self.data_loader_iter = data_loader
            self.len_epoch = len(self.data_loader)
            # BUGFIX: these two were previously set only in the
            # iteration-based branch below, so epoch-based training with
            # validation raised AttributeError inside _valid_epoch.
            if self.do_validation:
                self.valid_loader_iter = self.valid_data_loader
                self.valid_len_epoch = len(self.valid_data_loader)
        else:
            # Iteration-based training: wrap loaders in infinite iterators
            # and cap each epoch at a fixed number of batches.
            self.data_loader_iter = inf_loop(self.data_loader)
            self.valid_loader_iter = inf_loop(self.valid_data_loader)
            self.len_epoch = len_epoch
            # Historical fixed validation length; kept for compatibility.
            self.valid_len_epoch = 53
        self.train_metrics = MetricTracker(
            'train_loss',
            *['train_' + m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'val_loss',
            *['val_' + m.__name__ for m in self.metric_ftns],
            writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        self.logger.info(epoch)
        self.logger.info("Current difficulty: {}".format(
            self.data_loader.gen.difficulty))
        for batch_idx, batch in enumerate(self.data_loader_iter):
            (input_variables, input_lengths, target) = batch
            self.optimizer.zero_grad()
            output, _, sequence_info = self.model.forward(
                input=input_variables,
                input_lens=input_lengths,
                target=target,
                teacher_forcing_ratio=0.5)
            loss = self.criterion.__call__(output, target)
            loss.backward()
            self.optimizer.step()
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            # set train metrics
            self.train_metrics.update('train_loss', loss.item())
            for metric in self.metric_ftns:
                self.train_metrics.update(
                    'train_' + metric.__name__,
                    metric(output, target, sequence_info))
            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.8f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))
                # BUGFIX: guard against a missing 'train_Accuracy' key —
                # .get() returns None and `None >= 0.93` raises TypeError.
                # NOTE(review): this LR step was reconstructed as living
                # inside the log-step branch — confirm against history.
                if (self.train_metrics.result().get('train_Accuracy')
                        or 0) >= 0.93:
                    self.optimizer.step_lr(
                        self.train_metrics.result().get('train_loss'), epoch)
            if batch_idx == self.len_epoch:
                break
        history = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            # Raise generator difficulty once both accuracies clear their
            # thresholds (guarded against missing keys, see above).
            if (val_log.get('val_Accuracy') or 0) >= 0.95 and \
                    (self.train_metrics.result().get('train_Accuracy')
                     or 0) >= 0.80:
                self.logger.info("Increasing difficulty")
                self.data_loader.gen.increase_difficulty()
                self.valid_data_loader.gen.increase_difficulty()
                self.logger.info("Current difficulty: {}".format(
                    self.data_loader.gen.difficulty))
            history.update(**{k: v for k, v in val_log.items()})
        self.optimizer.step_lr(history['train_loss'], epoch)
        return history

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.valid_loader_iter):
                (input_variables, input_lengths, target) = batch
                # No teacher forcing during validation.
                output, _, sequence_info = self.model.forward(
                    input=input_variables,
                    input_lens=input_lengths,
                    target=target)
                loss = self.criterion.__call__(output, target)
                # set writer step
                self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx,
                                     'valid')
                # set val metrics
                self.valid_metrics.update('val_loss', loss.item())
                for metric in self.metric_ftns:
                    self.valid_metrics.update(
                        'val_' + metric.__name__,
                        metric(output, target, sequence_info))
                if batch_idx == self.valid_len_epoch:
                    break
        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        """Format a '[current/total (pct%)]' progress string."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
class ISBITrainer(BaseTrainer):
    """
    Abstract trainer base: subclasses supply the per-batch loop via
    _process(); this class owns the epoch orchestration, metric tracking
    and tensorboard logging.
    """

    def __init__(self,
                 model,
                 loss,
                 metric_ftns,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, loss, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # BUGFIX: previously there was no else branch, so passing an
            # explicit len_epoch left self.len_epoch unset and any later
            # access raised AttributeError.
            self.len_epoch = len_epoch
        self.len_epoch_val = len(
            self.valid_data_loader) if self.do_validation else 0
        self.lr_scheduler = lr_scheduler
        # Log roughly every sqrt(batch_size) batches.
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

    @abstractmethod
    def _process(self, epoch, data_loader, metrics, mode: Mode = Mode.TRAIN):
        """Run one pass over data_loader, updating `metrics`. Must be
        implemented by subclasses."""
        raise NotImplementedError(
            'Method _process() from ISBITrainer class has to be implemented!')

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        self._process(epoch, self.data_loader, self.train_metrics, Mode.TRAIN)
        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            self._process(epoch, self.valid_data_loader, self.valid_metrics,
                          Mode.VAL)
        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def log_scalars(self, metrics, step, output, target, loss,
                    mode=Mode.TRAIN):
        """Record loss and all configured metrics at the given writer step."""
        self.writer.set_step(step, mode)
        metrics.update('loss', loss.item())
        for met in self.metric_ftns:
            metrics.update(met.__name__, met(output, target))

    @staticmethod
    def _progress(data_loader, batch_idx, batches):
        """Format a '[current/total (pct%)]' progress string."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(data_loader, 'n_samples'):
            current = batch_idx * data_loader.batch_size
            total = data_loader.n_samples
        else:
            current = batch_idx
            total = batches
        return base.format(current, total, 100.0 * current / total)

    @staticmethod
    def get_step(batch_idx, epoch, len_epoch):
        """Global step index for a batch given a 1-based epoch number."""
        return (epoch - 1) * len_epoch + batch_idx
class Trainer(BaseTrainer):
    """
    Multi-label sentence-classification trainer with optional fp16 (apex
    amp) and gradient accumulation hooks, currently hard-wired off.
    """

    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        # Log four times per epoch.
        self.log_step = int(self.len_epoch / 4)  # int(np.sqrt(data_loader.batch_size))
        self.train_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

    def get_lr(self):
        """Return the learning rate of the first optimizer param group."""
        for param_group in self.optimizer.param_groups:
            return param_group['lr']

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        # NOTE: fp16 and accumulation are local switches, not config-driven;
        # fp16 requires apex `amp` to be importable when enabled.
        fp16 = False
        gradient_accumulation_steps = 1
        self.logger.info("Current gradient_accumulation_steps: {}".format(
            gradient_accumulation_steps))
        self.logger.info("Current learning rate: {}".format(self.get_lr()))
        self.model.train()
        self.train_metrics.reset()
        trange = tqdm(enumerate(self.data_loader),
                      total=self.len_epoch,
                      desc="training")
        for batch_idx, batch in trange:
            data = batch["sentence"]
            target = batch["label"]
            if not isinstance(data, list):  # check if type is list
                data = data.to(self.device)
            if not isinstance(target, list):  # check if type is list
                target = target.to(self.device)
            output = self.model(data)
            # BUGFIX: was .cuda(), which broke CPU runs and disagreed with
            # _valid_epoch; use the configured device instead.
            if isinstance(output, list):
                output = torch.cat(output, dim=0).to(self.device)
            if isinstance(target, list):
                target = torch.cat(target, dim=0).to(self.device)
            # (debug print() calls removed from the fp16 path)
            if fp16:
                loss = self.criterion(output, target).half()
            else:
                loss = self.criterion(output, target)
            if gradient_accumulation_steps > 1:
                # Scale so accumulated gradients average over the steps.
                loss = loss / gradient_accumulation_steps
            if fp16:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                if fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optimizer), 1.0)
                else:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   1.0)
                self.optimizer.step()
                self.optimizer.zero_grad()
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item(), output.size(0))
            predict = (output >= 0.5)
            maxclass = torch.argmax(
                output, dim=1
            )  # make sure every sentence predicted to at least one class
            for i in range(len(predict)):
                predict[i][maxclass[i].item()] = 1
            predict = predict.type(torch.LongTensor).to(self.device)
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(predict, target),
                                          predict.size(0))
            # if batch_idx % self.log_step == 0:
            #     self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
            #         epoch, self._progress(batch_idx), loss.item()))
            #     self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
            if batch_idx == self.len_epoch:
                break
            trange.set_postfix(loss=loss.item())
        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.valid_data_loader):
                data = batch["sentence"]
                target = batch["label"]
                if not isinstance(data, list):
                    data = data.to(self.device)
                if not isinstance(target, list):
                    target = target.to(self.device)
                output = self.model(data)
                if isinstance(output, list):
                    output = torch.cat(output, dim=0).to(self.device)
                if isinstance(target, list):
                    target = torch.cat(target, dim=0).to(self.device)
                loss = self.criterion(output, target)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item(),
                                          output.size(0))
                predict = (output >= 0.5)
                maxclass = torch.argmax(
                    output, dim=1
                )  # make sure every sentence predicted to at least one class
                for i in range(len(predict)):
                    predict[i][maxclass[i].item()] = 1
                predict = predict.type(torch.LongTensor).to(self.device)
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__,
                                              met(predict, target),
                                              predict.size(0))
                #self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        """Format a '[current/total (pct%)]' progress string."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
class Trainer(BaseTrainer):
    """
    Classification trainer that, on top of the standard train/validate
    loop, renders validation-time visualizations: multi-class ROC curves,
    a confusion matrix, and Grad-CAM / Grad-CAM++ heatmaps for a fixed
    random subset of validation samples.
    """

    def __init__(self, model, criterion, metric_ftns, optimizer, config,
                 data_loader, valid_data_loader=None, lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        # Log roughly every sqrt(batch_size) batches.
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        # Choose at most 16 validation sample indices, fixed for the whole
        # run, whose heatmaps are rendered after every validation epoch.
        # NOTE(review): np.random.randint may draw duplicates — confirm
        # that repeated indices are acceptable here.
        if hasattr(self.data_loader, 'n_valid_samples'):
            validation_samples = self.data_loader.n_valid_samples
        else:
            validation_samples = self.valid_data_loader.n_samples
        self.heatmap_sample_indices = np.sort(np.random.randint(validation_samples, size=min(16, validation_samples)))

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, target))
            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        # Accumulate per-batch labels/outputs for the epoch-level plots.
        y = []
        y_hat = []
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(output, target))
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
                y.append(target.cpu().numpy())
                y_hat.append(output.detach().cpu().numpy())
        y = np.concatenate(y)
        y_hat = np.concatenate(y_hat)
        self._do_validation_visualizations(epoch, y, y_hat)
        # # add histogram of model parameters to the tensorboard
        # for name, p in self.model.named_parameters():
        #     self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _do_validation_visualizations(self, epoch, y, y_hat):
        """
        Render epoch-level validation figures (ROC, confusion matrix,
        heatmaps) to tensorboard via external helper functions.
        """
        # show multi class AUC-ROC
        roc = get_auc_roc_curve(y, y_hat, len(self.data_loader.dataset.classes), labels=self.data_loader.dataset.classes)
        self.writer.add_figure("metric/roc", roc)
        # show confusion_matrix
        cm = get_confusion_matrix_figure(y, y_hat, len(self.data_loader.dataset.classes), labels=self.data_loader.dataset.classes)
        self.writer.add_figure("metric/confusion_matrix", cm)
        self._do_heatmaps(epoch)
        return

    def _do_heatmaps(self, epoch):
        """
        Render Grad-CAM and Grad-CAM++ heatmaps for the fixed subset of
        validation samples chosen at construction time.
        """
        images = []
        targets = []
        for idx in self.heatmap_sample_indices:
            img, label = self.valid_data_loader.dataset.__getitem__(idx)
            images.append(img)
            targets.append(label)
        gradcam, gradcam_pp = get_heatmap_tensors(images, self.model, self.config, self.checkpoint_dir, epoch, save_images_to_dir=True)
        self.writer.add_images("saliency/gradcam", gradcam)
        self.writer.add_images("saliency/gradcam_pp", gradcam_pp)
        return

    def _progress(self, batch_idx):
        """Format a '[current/total (pct%)]' progress string."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
class SegmentationTrainer(BaseTrainer):
    """
    Trainer for segmentation models with separate train/validation/test
    loss and metric trackers, optional per-batch output saving, and
    plateau-based LR scheduling.

    NOTE(review): several attributes used below (train_data_loader,
    log_step, save_for_track, metric_names, len_epoch, do_validation,
    do_validation_interval, checkpoint_dir, device, logger, writer) are
    presumably set by BaseTrainer — confirm against the base class.
    """

    def __init__(self, model, criterion, metrics, optimizer, config, lr_scheduler=None):
        super().__init__(model, criterion, metrics, optimizer, config)
        self.lr_scheduler = lr_scheduler
        self.loss_name = 'supervised_loss'
        # Metrics
        # NOTE(review): the writer is passed positionally here, while the
        # other trainer classes in this file use writer= as a keyword —
        # confirm this matches MetricTracker's signature, otherwise the
        # writer object is interpreted as an extra metric key.
        # Train
        self.train_loss = MetricTracker(self.loss_name, self.writer)
        self.train_metrics = MetricTracker(*self.metric_names, self.writer)
        # Validation
        self.valid_loss = MetricTracker(self.loss_name, self.writer)
        self.valid_metrics = MetricTracker(*self.metric_names, self.writer)
        # Test
        self.test_loss = MetricTracker(self.loss_name, self.writer)
        self.test_metrics = MetricTracker(*self.metric_names, self.writer)
        # Wrap the criterion too when the model is data-parallel so the
        # loss is computed on the same replicas.
        if isinstance(self.model, nn.DataParallel):
            self.criterion = nn.DataParallel(self.criterion)
        # Resume checkpoint if path is available in config
        cp_path = self.config['trainer'].get('resume_path')
        if cp_path:
            super()._resume_checkpoint()

    def reset_scheduler(self):
        """Reset every loss/metric tracker before a new epoch."""
        self.train_loss.reset()
        self.train_metrics.reset()
        self.valid_loss.reset()
        self.valid_metrics.reset()
        self.test_loss.reset()
        self.test_metrics.reset()
        # if isinstance(self.lr_scheduler, MyReduceLROnPlateau):
        #     self.lr_scheduler.reset()

    def prepare_train_epoch(self, epoch):
        """Log the epoch number and reset all trackers."""
        self.logger.info('EPOCH: {}'.format(epoch))
        self.reset_scheduler()

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch; returns a dict of train (and
        optionally validation) losses and metrics.
        """
        self.model.train()
        self.prepare_train_epoch(epoch)
        for batch_idx, (data, target, image_name) in enumerate(self.train_data_loader):
            data, target = data.to(self.device), target.to(self.device)
            output = self.model(data)
            loss = self.criterion(output, target)
            # For debug model: checkpoint the state that produced a NaN loss.
            if torch.isnan(loss):
                super()._save_checkpoint(epoch)
            self.model.zero_grad()
            loss.backward()
            self.optimizer.step()
            # Update train loss, metrics
            self.train_loss.update(self.loss_name, loss.item())
            for metric in self.metrics:
                self.train_metrics.update(metric.__name__, metric(output, target), n=output.shape[0])
            if batch_idx % self.log_step == 0:
                self.log_for_step(epoch, batch_idx)
            # Periodically dump raw outputs for offline tracking.
            if self.save_for_track and (batch_idx % self.save_for_track == 0):
                save_output(output, image_name, epoch, self.checkpoint_dir)
            if batch_idx == self.len_epoch:
                break
        log = self.train_loss.result()
        log.update(self.train_metrics.result())
        if self.do_validation and (epoch % self.do_validation_interval == 0):
            val_log = self._valid_epoch(epoch)
            log.update(val_log)
        # step lr scheduler
        if isinstance(self.lr_scheduler, MyReduceLROnPlateau):
            self.lr_scheduler.step(self.valid_loss.avg(self.loss_name))
        return log

    @staticmethod
    def get_metric_message(metrics, metric_names):
        """Build a 'name: avg, ...' summary string for the given tracker."""
        metrics_avg = [metrics.avg(name) for name in metric_names]
        message_metrics = ', '.join(['{}: {:.6f}'.format(x, y) for x, y in zip(metric_names, metrics_avg)])
        return message_metrics

    def log_for_step(self, epoch, batch_idx):
        """Log running train loss and metric averages for one step."""
        message_loss = 'Train Epoch: {} [{}]/[{}] Dice Loss: {:.6f}'.format(epoch, batch_idx, self.len_epoch, self.train_loss.avg(self.loss_name))
        message_metrics = SegmentationTrainer.get_metric_message(self.train_metrics, self.metric_names)
        self.logger.info(message_loss)
        self.logger.info(message_metrics)

    def _valid_epoch(self, epoch, save_result=False, save_for_visual=False):
        """
        Validate after training an epoch; returns metrics prefixed with
        'val_'. Optionally saves raw outputs and mask images to disk.
        """
        self.model.eval()
        self.valid_loss.reset()
        self.valid_metrics.reset()
        self.logger.info('Validation: ')
        with torch.no_grad():
            for batch_idx, (data, target, image_name) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                self.valid_loss.update(self.loss_name, loss.item())
                for metric in self.metrics:
                    self.valid_metrics.update(metric.__name__, metric(output, target), n=output.shape[0])
                if save_result:
                    save_output(output, image_name, epoch, os.path.join(self.checkpoint_dir, 'tracker'), percent=1)
                if save_for_visual:
                    save_mask2image(output, image_name, os.path.join(self.checkpoint_dir, 'output'))
                    save_mask2image(target, image_name, os.path.join(self.checkpoint_dir, 'target'))
                if batch_idx % self.log_step == 0:
                    self.logger.debug('{}/{}'.format(batch_idx, len(self.valid_data_loader)))
                    self.logger.debug('{}: {}'.format(self.loss_name, self.valid_loss.avg(self.loss_name)))
                    self.logger.debug(SegmentationTrainer.get_metric_message(self.valid_metrics, self.metric_names))
        log = self.valid_loss.result()
        log.update(self.valid_metrics.result())
        val_log = {'val_{}'.format(k): v for k, v in log.items()}
        return val_log

    def _test_epoch(self, epoch, save_result=False, save_for_visual=False):
        """
        Evaluate on the test set; mirrors _valid_epoch but prefixes the
        returned keys with 'test_'.
        """
        self.model.eval()
        self.test_loss.reset()
        self.test_metrics.reset()
        self.logger.info('Test: ')
        with torch.no_grad():
            for batch_idx, (data, target, image_name) in enumerate(self.test_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                self.writer.set_step((epoch - 1) * len(self.test_data_loader) + batch_idx, 'test')
                self.test_loss.update(self.loss_name, loss.item())
                for metric in self.metrics:
                    self.test_metrics.update(metric.__name__, metric(output, target), n=output.shape[0])
                if save_result:
                    save_output(output, image_name, epoch, os.path.join(self.checkpoint_dir, 'tracker'), percent=1)
                if save_for_visual:
                    save_mask2image(output, image_name, os.path.join(self.checkpoint_dir, 'output'))
                    save_mask2image(target, image_name, os.path.join(self.checkpoint_dir, 'target'))
                if batch_idx % self.log_step == 0:
                    self.logger.debug('{}/{}'.format(batch_idx, len(self.test_data_loader)))
                    self.logger.debug('{}: {}'.format(self.loss_name, self.test_loss.avg(self.loss_name)))
                    self.logger.debug(SegmentationTrainer.get_metric_message(self.test_metrics, self.metric_names))
        log = self.test_loss.result()
        log.update(self.test_metrics.result())
        test_log = {'test_{}'.format(k): v for k, v in log.items()}
        return test_log
class QuicknatLIDCTrainer(BaseTrainer):
    """
    Trainer for a QuickNAT-style model on LIDC data with multiple raters:
    each target carries 4 annotations, one of which is sampled at random
    per batch. Validation draws Monte-Carlo dropout samples from the model
    and scores the sample set against all annotations.
    """

    def __init__(self, model, criterion, metric_ftns, optimizer, config,
                 data_loader, valid_data_loader=None, lr_scheduler=None,
                 len_epoch=None, experiment=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config,
                         experiment)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        # Log roughly every sqrt(batch_size) batches.
        self.log_step = int(np.sqrt(data_loader.batch_size))
        # Training only tracks the loss; per-metric scoring happens in
        # validation against the dropout samples.
        self.train_metrics = MetricTracker('loss', writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        # Number of MC-dropout samples drawn per validation batch.
        self.metrics_sample_count = config['trainer']['metrics_sample_count']

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        # Keep dropout active (MC dropout) even while training.
        self.model.enable_test_dropout()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            # shape data: [B x 1 x H x W]
            # shape target: [B x 4 x H x W]
            data, target = data.to(self.device), target.to(self.device)
            # Randomly select one of the 4 rater annotations per batch.
            rand_idx = np.random.randint(0, 4)
            target = target[:, rand_idx, ...]
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            # self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))
            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        # Dropout stays on so _sample() draws distinct MC samples.
        self.model.enable_test_dropout()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data,
                            targets) in enumerate(self.valid_data_loader):
                data, targets = data.to(self.device), targets.to(self.device)
                # Loss uses one random annotation; metrics use all of them.
                rand_idx = np.random.randint(0, 4)
                target = targets[:, rand_idx, ...]
                # Insert a channel dim so targets align with the sampled
                # outputs' [B x S x C x H x W] layout.
                targets = targets.unsqueeze(2)
                # self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                # Loss
                output = self.model(data)
                loss = self.criterion(output, target)
                self.valid_metrics.update('loss', loss.item())
                # Sampling
                samples = self._sample(
                    self.model, data)  # [BATCH_SIZE x SAMPLE_SIZE x NUM_CHANNELS x H x W]
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__,
                                              met(samples, targets))
                self._visualize_batch(batch_idx, samples, targets)
        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _sample(self, model, data):
        """
        Draw metrics_sample_count MC-dropout predictions for a batch.

        Each sample is the per-pixel argmax of one stochastic forward
        pass; returns a [B x S x 1 x H x W] tensor on self.device.
        """
        num_samples = self.metrics_sample_count
        batch_size, num_channels, image_size = data.shape[0], 1, tuple(
            data.shape[2:])
        samples = torch.zeros((batch_size, num_samples, num_channels,
                               *image_size)).to(self.device)
        for i in range(num_samples):
            output = model(data)
            max_val, idx = torch.max(output, 1)
            sample = idx.unsqueeze(dim=1)
            samples[:, i, ...] = sample
        return samples

    def _progress(self, batch_idx):
        """Format a '[current/total (pct%)]' progress string."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _visualize_batch(self, batch_idx, samples, targets):
        """
        Write a grid image juxtaposing all ground-truth annotations and
        all drawn samples for one validation batch.
        """
        gt_titles = [f'GT_{i}' for i in range(targets.shape[1])]
        s_titles = [f'S_{i}' for i in range(self.metrics_sample_count)]
        titles = gt_titles + s_titles
        vis_data = torch.cat((targets, samples), dim=1)
        img_metric_grid = visualization.make_image_metric_grid(
            vis_data, enable_helper_dots=True, titles=titles)
        self.writer.add_image(f'segmentations_batch_idx_{batch_idx}',
                              img_metric_grid.cpu())
class TrainerVd(BaseTrainer):
    """
    Trainer for a Bayesian (variational-dropout style) regression model.

    The loss is an ELBO-style sum of a KL term (tkl, scaled by the number
    of batches) and a prediction term, both averaged over multiple
    stochastic forward passes through the model.
    """

    def __init__(self, model, criterion, metric_ftns, optimizer, config,
                 data_loader, valid_data_loader=None, lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        # Log roughly every sqrt(batch_size) batches.
        self.log_step = int(np.sqrt(data_loader.batch_size))
        # Batch counts used to scale the KL term down to a per-batch
        # contribution (minibatch ELBO weighting).
        self.n_batches = data_loader.n_samples / data_loader.batch_size
        self.n_batches_valid = valid_data_loader.n_samples / valid_data_loader.batch_size
        self.train_metrics = MetricTracker('loss', 'kl_cost', 'pred_cost',
                                           *[m.__name__ for m in self.metric_ftns],
                                           writer=self.writer)
        self.valid_metrics = MetricTracker('loss', 'kl_cost', 'pred_cost',
                                           *[m.__name__ for m in self.metric_ftns],
                                           writer=self.writer)
        # NOTE(review): self.keys / self.log are presumably initialized by
        # BaseTrainer — confirm against the base class.
        self.keys.extend(['kl_cost', 'pred_cost'])
        if self.do_validation:
            keys_val = ['val_' + k for k in self.keys]
            for key in self.keys + keys_val:
                self.log[key] = []

    def _train_epoch(self, epoch, samples=10):
        """
        Training logic for an epoch.

        :param epoch: Integer, current training epoch.
        :param samples: number of stochastic forward passes per batch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            # Collect every stochastic output: [B x output_dim x samples].
            outputs = torch.zeros(data.shape[0], self.model.output_dim,
                                  samples).to(self.device)
            if samples == 1:
                out, tkl = self.model(data)
                mlpdw = self._compute_loss(out, target)
                Edkl = tkl / self.n_batches
                outputs[:, :, 0] = out
            elif samples > 1:
                # Average prediction loss and KL over the sampled passes.
                mlpdw_cum = 0
                Edkl_cum = 0
                for i in range(samples):
                    out, tkl = self.model(data, sample=True)
                    mlpdw_i = self._compute_loss(out, target)
                    Edkl_i = tkl / self.n_batches
                    mlpdw_cum = mlpdw_cum + mlpdw_i
                    Edkl_cum = Edkl_cum + Edkl_i
                    outputs[:, :, i] = out
                mlpdw = mlpdw_cum / samples
                Edkl = Edkl_cum / samples
            mean = torch.mean(outputs, dim=2)
            # ELBO-style objective: KL term + prediction term.
            loss = Edkl + mlpdw
            loss.backward()
            self.optimizer.step()
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item(), n=len(target))
            self.train_metrics.update('kl_cost', Edkl.item(), n=len(target))
            self.train_metrics.update('pred_cost', mlpdw.item(),
                                      n=len(target))
            for met in self.metric_ftns:
                self._compute_metric(self.train_metrics, met, outputs, target)
            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch, samples=100):
        """
        Validate after training an epoch.

        :param epoch: Integer, current training epoch.
        :param samples: number of stochastic forward passes per batch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                loss = 0
                outputs = torch.zeros(data.shape[0], self.model.output_dim,
                                      samples).to(self.device)
                if samples == 1:
                    out, tkl = self.model(data)
                    mlpdw = self._compute_loss(out, target)
                    Edkl = tkl / self.n_batches_valid
                    outputs[:, :, 0] = out
                elif samples > 1:
                    mlpdw_cum = 0
                    Edkl_cum = 0
                    for i in range(samples):
                        out, tkl = self.model(data, sample=True)
                        mlpdw_i = self._compute_loss(out, target)
                        Edkl_i = tkl / self.n_batches_valid
                        mlpdw_cum = mlpdw_cum + mlpdw_i
                        Edkl_cum = Edkl_cum + Edkl_i
                        outputs[:, :, i] = out
                    mlpdw = mlpdw_cum / samples
                    Edkl = Edkl_cum / samples
                mean = torch.mean(outputs, dim=2)
                loss = Edkl + mlpdw
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item(), n=len(target))
                self.valid_metrics.update('kl_cost', Edkl.item(),
                                          n=len(target))
                self.valid_metrics.update('pred_cost', mlpdw.item(),
                                          n=len(target))
                for met in self.metric_ftns:
                    self._compute_metric(self.valid_metrics, met, outputs,
                                         target)
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        """Format a '[current/total (pct%)]' progress string."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _compute_loss(self, output, target):
        """
        Dispatch to the right criterion signature for the model's
        regression type.

        NOTE(review): the literal 'h**o' looks mangled (likely 'homo' for
        homoscedastic); it must match the values model.regression_type
        actually takes — verify before changing, since it is compared at
        runtime.
        """
        if self.model.regression_type == 'h**o':
            loss = self.criterion(output, target,
                                  self.model.log_noise.exp(),
                                  self.model.output_dim)
        elif self.model.regression_type == 'hetero':
            loss = self.criterion(output, target, self.model.output_dim/2)
        else:
            loss = self.criterion(output, target)
        return loss

    def _compute_metric(self, metrics, met, output, target, type="VD"):
        """
        Update one metric, passing the learned noise term alongside the
        outputs for the homoscedastic regression type.
        """
        if self.model.regression_type == 'h**o':
            metrics.update(met.__name__,
                           met([output, self.model.log_noise.exp()],
                               target, type))
        else:
            metrics.update(met.__name__, met(output, target, type))
class OPUSMultitaskTrainer(BaseTrainer):
    """
    Trainer for the OPUS multitask model (segmentation + classification).

    In addition to the standard epoch loop, __init__ re-registers the
    model's 16 cross-stitch tensors as trainable Parameters and adds them
    to the optimizer under a (much) larger learning rate.
    """
    def __init__(self, model, criterion, metric_ftns, optimizer, config,
                 data_loader, valid_data_loader=None, lr_scheduler=None,
                 len_epoch=None, experiment=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config,
                         experiment)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training: wrap loader so it repeats forever
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        # log roughly sqrt(batch_size) times per epoch
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        # Grab the base learning rate from the optimizer (last group wins);
        # used below to scale the cross-stitch parameters' learning rate.
        for param_group in optimizer.param_groups:
            lr = param_group['lr']
        # Re-wrap every cross-stitch tensor as a trainable Parameter placed
        # on the training device so the optimizer can update it.
        model.cross1ss = torch.nn.Parameter(data=model.cross1ss.to(
            self.device), requires_grad=True)
        model.cross1sc = torch.nn.Parameter(data=model.cross1sc.to(
            self.device), requires_grad=True)
        model.cross1cc = torch.nn.Parameter(data=model.cross1cc.to(
            self.device), requires_grad=True)
        model.cross1cs = torch.nn.Parameter(data=model.cross1cs.to(
            self.device), requires_grad=True)
        model.cross2ss = torch.nn.Parameter(data=model.cross2ss.to(
            self.device), requires_grad=True)
        model.cross2sc = torch.nn.Parameter(data=model.cross2sc.to(
            self.device), requires_grad=True)
        model.cross2cc = torch.nn.Parameter(data=model.cross2cc.to(
            self.device), requires_grad=True)
        model.cross2cs = torch.nn.Parameter(data=model.cross2cs.to(
            self.device), requires_grad=True)
        model.cross3ss = torch.nn.Parameter(data=model.cross3ss.to(
            self.device), requires_grad=True)
        model.cross3sc = torch.nn.Parameter(data=model.cross3sc.to(
            self.device), requires_grad=True)
        model.cross3cc = torch.nn.Parameter(data=model.cross3cc.to(
            self.device), requires_grad=True)
        model.cross3cs = torch.nn.Parameter(data=model.cross3cs.to(
            self.device), requires_grad=True)
        model.crossbss = torch.nn.Parameter(data=model.crossbss.to(
            self.device), requires_grad=True)
        model.crossbsc = torch.nn.Parameter(data=model.crossbsc.to(
            self.device), requires_grad=True)
        model.crossbcc = torch.nn.Parameter(data=model.crossbcc.to(
            self.device), requires_grad=True)
        model.crossbcs = torch.nn.Parameter(data=model.crossbcs.to(
            self.device), requires_grad=True)
        # Hack: Set a different learning rate for the cross-stitch parameters
        optimizer.add_param_group({
            'params': [
                model.cross1ss, model.cross1sc, model.cross1cs,
                model.cross1cc, model.cross2ss, model.cross2sc,
                model.cross2cs, model.cross2cc, model.cross3ss,
                model.cross3sc, model.cross3cs, model.cross3cc,
                model.crossbss, model.crossbsc, model.crossbcs,
                model.crossbcc
            ],
            'lr': lr * 250
        })

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target_seg,
                        target_class) in enumerate(self.data_loader):
            data, target_seg, target_class = data.to(
                self.device), target_seg.to(self.device), target_class.to(
                    self.device)
            self.optimizer.zero_grad()
            # Model returns a (segmentation, classification) pair.
            output_seg, output_class = self.model(data)
            loss = self.criterion((output_seg, output_class), target_seg,
                                  target_class, epoch)
            loss.backward()
            self.optimizer.step()
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                # "accuracy" is evaluated on the classification head; every
                # other metric on the segmentation head.
                if met.__name__ == "accuracy":
                    self.train_metrics.update(met.__name__,
                                              met(output_class, target_class))
                else:
                    self.train_metrics.update(met.__name__,
                                              met(output_seg, target_seg))
            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))
                self._visualize_input(data.cpu())
            # Needed for iteration-based training with an infinite loader.
            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target_seg,
                            target_class) in enumerate(self.valid_data_loader):
                data, target_seg, target_class = data.to(
                    self.device), target_seg.to(self.device), target_class.to(
                        self.device)
                output_seg, output_class = self.model(data)
                loss = self.criterion((output_seg, output_class), target_seg,
                                      target_class, epoch)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    if met.__name__ == "accuracy":
                        self.valid_metrics.update(
                            met.__name__, met(output_class, target_class))
                    else:
                        self.valid_metrics.update(met.__name__,
                                                  met(output_seg, target_seg))
                data_cpu = data.cpu()
                self._visualize_input(data_cpu)
                self._visualize_prediction(data_cpu, output_seg.cpu(),
                                           target_seg.cpu())
        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        """Return a '[current/total (pct%)]' progress string for logging."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _visualize_input(self, input):
        """format and display input data on tensorboard"""
        self.writer.add_image(
            'input', make_grid(input[0, 0, :, :], nrow=8, normalize=True))

    def _visualize_prediction(self, input, output, target):
        """format and display output and target data on tensorboard"""
        out_b1 = binary(output)
        # Overlay predicted and target labels on the first input channel.
        out_b1 = impose_labels_on_image(input[0, 0, :, :], target[0, :, :],
                                        out_b1[0, 1, :, :])
        self.writer.add_image('output',
                              make_grid(out_b1, nrow=8, normalize=False))
class TrainerDeEnsemble(BaseTrainerEnsemble):
    """
    Trainer for a deep ensemble of models with adversarial (FGSM-style)
    training: each member is trained on a convex combination of the clean
    NLL loss and the loss on a gradient-sign-perturbed input.
    """
    def __init__(self, models, criterion, metric_ftns, optimizers, config,
                 data_loader, valid_data_loader=None, lr_schedulers=None,
                 len_epoch=None):
        super().__init__(models, criterion, metric_ftns, optimizers, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_schedulers = lr_schedulers
        self.log_step = int(np.sqrt(data_loader.batch_size))
        # NOTE(review): float division — n_batches is fractional unless
        # n_samples divides evenly by batch_size.
        self.n_batches = data_loader.n_samples / data_loader.batch_size
        # One 'loss_i' / 'metric_i' tracker entry per ensemble member.
        self.train_metrics = MetricTracker(
            *['loss_' + str(i) for i in range(self.n_ensembles)], *[
                m.__name__ + '_' + str(i) for m in self.metric_ftns
                for i in range(self.n_ensembles)
            ],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            *['loss_' + str(i) for i in range(self.n_ensembles)], *[
                m.__name__ + '_' + str(i) for m in self.metric_ftns
                for i in range(self.n_ensembles)
            ],
            writer=self.writer)
        if self.do_validation:
            # self.keys / self.log are assumed to come from
            # BaseTrainerEnsemble — TODO confirm.
            keys_val = ['val_' + k for k in self.keys]
            for key in self.keys + keys_val:
                self.log[key] = []
        cfg_loss = config['trainer']['loss']
        # alpha: weight between clean and adversarial loss;
        # epsilon: FGSM perturbation magnitude.
        self.alpha = cfg_loss['alpha']
        self.epsilon = cfg_loss['epsilon']

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.train_metrics.reset()
        for i, (model, optimizer, lr_scheduler) in enumerate(
                zip(self.models, self.optimizers, self.lr_schedulers)):
            model.train()
            for batch_idx, (data, target) in enumerate(self.data_loader):
                data, target = data.to(self.device), target.to(self.device)
                # Gradients w.r.t. the input are needed for the FGSM step.
                data.requires_grad = True
                optimizer.zero_grad()
                output = model(data)
                nll_loss = self.criterion(output, target)
                # create_graph=True so the adversarial loss can backprop
                # through the gradient computation itself.
                nll_grad = grad(self.alpha * nll_loss,
                                data,
                                retain_graph=True,
                                create_graph=True)[0]
                # Fast-gradient-sign perturbation of the input.
                x_at = data + self.epsilon * torch.sign(nll_grad)
                out_at = model(x_at)
                nll_loss_at = self.criterion(out_at, target)
                loss = self.alpha * nll_loss + (1 - self.alpha) * nll_loss_at
                loss.backward()
                optimizer.step()
                self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
                self.train_metrics.update('loss_' + str(i), loss.item())
                for met in self.metric_ftns:
                    self.train_metrics.update(met.__name__ + '_' + str(i),
                                              met(output, target, type="DE"))
                if batch_idx % self.log_step == 0:
                    self.logger.debug(
                        'Train Epoch: {} {} {} Loss: {:.6f}'.format(
                            'Net_' + str(i), epoch, self._progress(batch_idx),
                            loss.item()))
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
                if batch_idx == self.len_epoch:
                    break
            if lr_scheduler is not None:
                lr_scheduler.step()
        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.valid_metrics.reset()
        for i, (model, optimizer, lr_scheduler) in enumerate(
                zip(self.models, self.optimizers, self.lr_schedulers)):
            model.eval()
            with torch.no_grad():
                for batch_idx, (data,
                                target) in enumerate(self.valid_data_loader):
                    data, target = data.to(self.device), target.to(self.device)
                    output = model(data)
                    loss = self.criterion(output, target)
                    self.writer.set_step(
                        (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                        'valid')
                    self.valid_metrics.update('loss_' + str(i), loss.item())
                    for met in self.metric_ftns:
                        self.valid_metrics.update(
                            met.__name__ + '_' + str(i),
                            met(output, target, type="DE"))
                    # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
            # add histogram of model parameters to the tensorboard
            for name, p in model.named_parameters():
                self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        """Return a '[current/total (pct%)]' progress string for logging."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
class Trainer(BaseTrainer):
    """
    Class implementation for trainers. The class is inherited from the class
    BaseTrainer.
    """
    def __init__(self, model, criterion, metricFunction, optimizer,
                 configuration, device, dataLoader, validationDataLoader=None,
                 learningRateScheduler=None, epochLength=None):
        """
        Method to initialize an object of type Trainer.

        Parameters
        ----------
        model : torch.nn.Module
            Model to be trained
        criterion : callable
            Loss function to be minimized
        metricFunction : callable
            Metric functions to evaluate model performance
        optimizer : torch.optim
            Optimizer to be used during training
        configuration : dict
            Trainer configuration
        device : torch.device
            Device on which the training would be performed
        dataLoader : torch.utils.data.DataLoader
            Dataset sampler to load training data for model training
        validationDataLoader : torch.utils.data.DataLoader
            Dataset sampler to load validation data (Default value: None)
        learningRateScheduler : torch.optim.lr_scheduler
            Method to adjust learning rate (Default value: None)
        epochLength : int
            Number of batches per epoch; None means one pass over dataLoader
            (Default value: None)
        """
        # Initialize BaseTrainer class
        super().__init__(model, criterion, metricFunction, optimizer,
                         configuration)
        # Save trainer configuration, device, dataLoaders,
        # learningRateScheduler and loggingStep
        self.configuration = configuration
        self.device = device
        self.dataLoader = dataLoader
        if epochLength is None:
            # Epoch-based training: one pass over the loader per epoch.
            self.epochLength = len(self.dataLoader)
        else:
            # Iteration-based training: wrap the loader so it repeats
            # forever. FIX: was `infinte_loop` (undefined name); every other
            # trainer in this file uses `inf_loop`.
            self.dataLoader = inf_loop(dataLoader)
            self.epochLength = epochLength
        self.validationDataLoader = validationDataLoader
        self.performValidation = (self.validationDataLoader is not None)
        self.learningRateScheduler = learningRateScheduler
        self.loggingStep = int(np.sqrt(dataLoader.batch_size))
        # Set up training and validation metrics
        self.trainingMetrics = MetricTracker(
            "loss", *[
                individualMetricFunction.__name__
                for individualMetricFunction in self.metricFunction
            ],
            writer=self.writer)
        self.validationMetrics = MetricTracker(
            "loss", *[
                individualMetricFunction.__name__
                for individualMetricFunction in self.metricFunction
            ],
            writer=self.writer)

    def train_epoch(self, epoch):
        """
        Method to train a single epoch.

        Parameters
        ----------
        epoch : int
            Current epoch number

        Returns
        -------
        log : dict
            Average of all the metrics in a dictionary
        """
        # Set the model to training mode and start training the model
        self.model.train()
        self.trainingMetrics.reset()
        for batchId, (data, target) in enumerate(self.dataLoader):
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(data)
            # FIX: was `self.criteria(...)` — undefined attribute; the loss
            # attribute set by BaseTrainer and used in validate_epoch is
            # `self.criterion`.
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            # Update training metrics
            self.writer.set_step((epoch - 1) * self.epochLength + batchId)
            self.trainingMetrics.update("loss", loss.item())
            for individualMetric in self.metricFunction:
                self.trainingMetrics.update(individualMetric.__name__,
                                            individualMetric(output, target))
            if batchId % self.loggingStep == 0:
                self.logger.debug("Training Epoch: {} {} Loss: {}".format(
                    epoch, self.progress(batchId), loss.item()))
                self.writer.add_image(
                    "input", make_grid(data.cpu(), nrow=8, normalize=True))
            # Needed for iteration-based training with an infinite loader.
            if batchId == self.epochLength:
                break
        log = self.trainingMetrics.result()
        if self.performValidation:
            validationLog = self.validate_epoch(epoch)
            log.update(
                **
                {"val_" + key: value
                 for key, value in validationLog.items()})
        if self.learningRateScheduler is not None:
            self.learningRateScheduler.step()
        return log

    def validate_epoch(self, epoch):
        """
        Method to validate a single epoch.

        Parameters
        ----------
        epoch : int
            Current epoch number

        Returns
        -------
        log : dict
            Average of all the metrics in a dictionary
        """
        # Set the model to evaluation mode and start validating the model
        self.model.eval()
        self.validationMetrics.reset()
        with torch.no_grad():
            for batchId, (data,
                          target) in enumerate(self.validationDataLoader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                # Update validation metrics
                self.writer.set_step(
                    (epoch - 1) * len(self.validationDataLoader) + batchId,
                    "valid")
                self.validationMetrics.update("loss", loss.item())
                for individualMetric in self.metricFunction:
                    self.validationMetrics.update(
                        individualMetric.__name__,
                        individualMetric(output, target))
                self.writer.add_image(
                    "input", make_grid(data.cpu(), nrow=8, normalize=True))
        # Update TensorBoardWriter
        for name, parameter in self.model.named_parameters():
            self.writer.add_histogram(name, parameter, bins="auto")
        return self.validationMetrics.result()

    def progress(self, batchId):
        """
        Method to calculate progress of training or validation.

        Parameters
        ----------
        batchId : int
            Current batch ID

        Returns
        -------
        progress : str
            Amount of progress
        """
        base = "[{}/{} ({:.0f}%)]"
        if hasattr(self.dataLoader, "numberOfSamples"):
            current = batchId * self.dataLoader.batch_size
            total = self.dataLoader.numberOfSamples
        else:
            current = batchId
            total = self.epochLength
        return base.format(current, total, 100.0 * current / total)
class Trainer(BaseTrainer):
    """
    Trainer for a drug-synergy model: each sample indexes a (cell, drug1,
    drug2) triple whose multi-hop neighbor sets are looked up and fed to the
    model alongside the raw indices.
    """
    def __init__(self, model, criterion, metric_fns, optimizer, config,
                 data_loader, feature_index, cell_neighbor_set,
                 drug_neighbor_set, valid_data_loader=None,
                 test_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_fns, optimizer, config)
        self.config = config
        # for data
        self.data_loader = data_loader
        # Mapping entity-id -> per-hop neighbor lists; consumed by
        # _get_feed_dict.
        self.cell_neighbor_set = cell_neighbor_set
        self.drug_neighbor_set = drug_neighbor_set
        # Column positions of 'cell', 'drug1', 'drug2' in each data row.
        self.feature_index = feature_index
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_fns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_fns],
            writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            target = target.to(self.device)
            # Model returns predictions plus an embedding regularization
            # loss, which is added to the criterion loss.
            output, emb_loss = self.model(*self._get_feed_dict(data))
            loss = self.criterion(output, target.squeeze()) + emb_loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            with torch.no_grad():
                # Metrics operate on sigmoid probabilities as numpy arrays.
                y_pred = torch.sigmoid(output)
                y_pred = y_pred.cpu().detach().numpy()
                y_true = target.cpu().detach().numpy()
                for met in self.metric_fns:
                    self.train_metrics.update(met.__name__,
                                              met(y_pred, y_true))
            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))
            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()
        # Flat metrics plus a nested copy under 'train' / 'validation' keys.
        log['train'] = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
            log['validation'] = {'val_' + k: v for k, v in val_log.items()}
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """Run one validation pass and return the averaged metrics."""
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                target = target.to(self.device)
                output, emb_loss = self.model(*self._get_feed_dict(data))
                loss = self.criterion(output, target.squeeze()) + emb_loss
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item())
                y_pred = torch.sigmoid(output)
                y_pred = y_pred.cpu().detach().numpy()
                y_true = target.cpu().detach().numpy()
                for met in self.metric_fns:
                    self.valid_metrics.update(met.__name__,
                                              met(y_pred, y_true))
        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def test(self):
        """Evaluate on the test set; return sample-weighted loss/metric sums."""
        self.model.eval()
        total_loss = 0.0
        total_metrics = torch.zeros(len(self.metric_fns))
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.test_data_loader):
                target = target.to(self.device)
                output, emb_loss = self.model(*self._get_feed_dict(data))
                loss = self.criterion(output, target.squeeze()) + emb_loss
                batch_size = data.shape[0]
                # Weight per-batch values by batch size; caller divides by
                # n_samples to get averages.
                total_loss += loss.item() * batch_size
                y_pred = torch.sigmoid(output)
                y_pred = y_pred.cpu().detach().numpy()
                y_true = target.cpu().detach().numpy()
                for i, metric in enumerate(self.metric_fns):
                    total_metrics[i] += metric(y_pred, y_true) * batch_size
        test_output = {
            'n_samples': len(self.test_data_loader.sampler),
            'total_loss': total_loss,
            'total_metrics': total_metrics
        }
        return test_output

    def get_save(self, save_files):
        """Convert tensors (possibly nested one dict deep) to numpy arrays."""
        result = dict()
        for key, value in save_files.items():
            if type(value) == dict:
                temp = dict()
                for k, v in value.items():
                    temp[k] = v.cpu().detach().numpy()
            else:
                temp = value.cpu().detach().numpy()
            result[key] = temp
        return result

    def _get_feed_dict(self, data):
        """
        Build the model's positional inputs from a raw batch: entity index
        tensors plus per-hop neighbor index tensors, all on self.device.
        """
        # [batch_size]
        cells = data[:, self.feature_index['cell']]
        drugs1 = data[:, self.feature_index['drug1']]
        drugs2 = data[:, self.feature_index['drug2']]
        cells_neighbors, drugs1_neighbors, drugs2_neighbors = [], [], []
        # One neighbor tensor per hop level of the model.
        for hop in range(self.model.n_hop):
            cells_neighbors.append(torch.LongTensor([self.cell_neighbor_set[c][hop] \
                                                     for c in cells.numpy()]).to(self.device))
            drugs1_neighbors.append(torch.LongTensor([self.drug_neighbor_set[d][hop] \
                                                      for d in drugs1.numpy()]).to(self.device))
            drugs2_neighbors.append(torch.LongTensor([self.drug_neighbor_set[d][hop] \
                                                      for d in drugs2.numpy()]).to(self.device))
        return cells.to(self.device), drugs1.to(self.device), drugs2.to(self.device), \
               cells_neighbors, drugs1_neighbors, drugs2_neighbors

    def _progress(self, batch_idx):
        """Return a '[current/total (pct%)]' progress string for logging."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
class Trainer(BaseTrainer):
    """
    Trainer for a text model consuming (input_ids, attention_masks,
    text_lengths, labels) batches, with optional validation and test
    (inference) passes after every epoch.
    """
    def __init__(self, model, criterion, metric_ftns, optimizer, config,
                 data_loader, valid_data_loader=None, test_data_loader=None,
                 lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_inference = self.test_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.test_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, data in enumerate(self.data_loader):
            self.optimizer.zero_grad()
            input_ids, attention_masks, text_lengths, labels = data
            if 'cuda' == self.device.type:
                input_ids = input_ids.cuda()
                # attention masks are optional (e.g. non-transformer models)
                if attention_masks is not None:
                    attention_masks = attention_masks.cuda()
                text_lengths = text_lengths.cuda()
                labels = labels.cuda()
            preds, embedding = self.model(input_ids, attention_masks,
                                          text_lengths)
            preds = preds.squeeze()
            # criterion is a sequence; the first entry is the training loss
            loss = self.criterion[0](preds, labels)
            loss.backward()
            self.optimizer.step()
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(preds, labels))
            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.3f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))
            # Needed for iteration-based training with an infinite loader.
            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
        if self.do_inference:
            test_log = self._inference_epoch(epoch)
            log.update(**{'test_' + k: v for k, v in test_log.items()})
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, data in enumerate(self.valid_data_loader):
                input_ids, attention_masks, text_lengths, labels = data
                if 'cuda' == self.device.type:
                    input_ids = input_ids.cuda()
                    if attention_masks is not None:
                        attention_masks = attention_masks.cuda()
                    text_lengths = text_lengths.cuda()
                    labels = labels.cuda()
                preds, embedding = self.model(input_ids, attention_masks,
                                              text_lengths)
                preds = preds.squeeze()
                # One-shot graph export on the first validation batch.
                # NOTE(review): self.add_graph is not set in this __init__ —
                # presumably initialized by BaseTrainer; confirm.
                if self.add_graph:
                    # Unwrap DataParallel when more than one device is used.
                    input_model = self.model.module if (len(
                        self.config.config['device_id']) > 1) else self.model
                    self.writer.writer.add_graph(
                        input_model,
                        [input_ids, attention_masks, text_lengths])
                    self.add_graph = False
                loss = self.criterion[0](preds, labels)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__,
                                              met(preds, labels))
        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _inference_epoch(self, epoch):
        """
        Inference after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about the test set
        """
        self.model.eval()
        self.test_metrics.reset()
        with torch.no_grad():
            for batch_idx, data in enumerate(self.test_data_loader):
                input_ids, attention_masks, text_lengths, labels = data
                if 'cuda' == self.device.type:
                    input_ids = input_ids.cuda()
                    if attention_masks is not None:
                        attention_masks = attention_masks.cuda()
                    text_lengths = text_lengths.cuda()
                    labels = labels.cuda()
                preds, embedding = self.model(input_ids, attention_masks,
                                              text_lengths)
                preds = preds.squeeze()
                loss = self.criterion[0](preds, labels)
                # FIX: step offset previously used len(self.valid_data_loader);
                # this loop iterates the test loader, so its length must be
                # used or TensorBoard steps are wrong whenever the two
                # loaders differ in size.
                self.writer.set_step(
                    (epoch - 1) * len(self.test_data_loader) + batch_idx,
                    'test')
                self.test_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.test_metrics.update(met.__name__, met(preds, labels))
        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.test_metrics.result()

    def _progress(self, batch_idx):
        """Return a '[current/total (pct%)]' progress string for logging."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
class Trainer(BaseTrainer):
    """
    Trainer for a Pyro probabilistic model: optimizes the model/guide pair
    with stochastic variational inference (SVI) and evaluates held-out data
    with an ELBO loss plus importance-sampling estimates of the
    log-likelihood and log-marginal.
    """
    def __init__(self, model, metric_ftns, optimizer, config, data_loader,
                 valid_data_loader=None, lr_scheduler=None, len_epoch=None,
                 jit=False, log_images=True):
        super().__init__(model, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        # NOTE(review): this loader exposes `batch_length` rather than
        # `batch_size` — confirm against the loader implementation.
        self.log_step = int(np.sqrt(data_loader.batch_length))
        self.jit = jit
        self.log_images = log_images
        self.train_metrics = MetricTracker('loss',
                                           *[m.__name__ for m in self.metric_ftns],
                                           writer=self.writer)
        self.valid_metrics = MetricTracker('loss', 'log_likelihood',
                                           'log_marginal',
                                           *[m.__name__ for m in self.metric_ftns],
                                           writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        # Fresh SVI object per epoch with a 4-particle ELBO estimator.
        elbo = TraceGraph_ELBO(vectorize_particles=False, num_particles=4)
        svi = SVI(self.model.model, self.model.guide, self.optimizer,
                  loss=elbo)
        self.model.train()
        self.train_metrics.reset()
        # Running count of samples seen; used for the progress string.
        current = 0
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)
            # svi.step returns the summed ELBO loss; normalize per sample.
            loss = svi.step(observations=data) / data.shape[0]
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss)
            for met in self.metric_ftns:
                # Metrics take (model, guide, data, target, num_samples).
                metric_val = met(self.model.model, self.model.guide, data,
                                 target, 4)
                self.train_metrics.update(met.__name__, metric_val)
            current += len(target)
            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx, current=current), loss))
                if self.log_images:
                    self.writer.add_image('input',
                                          make_grid(data.cpu(), nrow=8,
                                                    normalize=True))
            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_'+k : v for k, v in val_log.items()})
        if self.lr_scheduler is not None:
            # Plateau-style scheduler keyed on the validation loss.
            self.lr_scheduler.step(val_log['loss'])
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        elbo = TraceGraph_ELBO(vectorize_particles=False, num_particles=4)
        svi = SVI(self.model.model, self.model.guide, self.optimizer,
                  loss=elbo)
        imps = ImportanceSampler(self.model.model, self.model.guide,
                                 num_samples=4)
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                # evaluate_loss computes the ELBO without taking a gradient
                # step; per-sample normalization as in training.
                loss = svi.evaluate_loss(observations=data) / data.shape[0]
                imps.sample(observations=data)
                log_likelihood = imps.get_log_likelihood().item() / data.shape[0]
                log_marginal = imps.get_log_normalizer().item() / data.shape[0]
                self.writer.set_step((epoch - 1) * len(self.valid_data_loader)
                                     + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss)
                self.valid_metrics.update('log_likelihood', log_likelihood)
                self.valid_metrics.update('log_marginal', log_marginal)
                for met in self.metric_ftns:
                    metric_val = met(self.model.model, self.model.guide, data,
                                     target, 4)
                    self.valid_metrics.update(met.__name__, metric_val)
                if self.log_images:
                    self.writer.add_image('input',
                                          make_grid(data.cpu(), nrow=8,
                                                    normalize=True))
        return self.valid_metrics.result()

    def _progress(self, batch_idx, current=None):
        """Return a '[current/total (pct%)]' progress string; `current`
        overrides the computed sample count when provided."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            if current is None:
                current = batch_idx * self.data_loader.batch_length
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
class Trainer(BaseTrainer):
    """
    Trainer with linear learning-rate warm-up over the first `warm_up`
    epochs and optional Weights & Biases logging of the epoch results.
    """
    def __init__(self, model, criterion, metric_ftns, optimizer, config,
                 data_loader, valid_data_loader=None, lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        # Target learning rate reached at the end of warm-up.
        self.init_lr = config['optimizer']['args']['lr']
        # Number of warm-up epochs.
        self.warm_up = config['trainer']['warm_up']
        self.train_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns])
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns])

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        # FIX: default value so the 'lr' log entry below cannot raise
        # NameError when the loader yields no batches.
        lr = get_lr(self.optimizer)
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)
            # Linear Learning Rate Warm-up: ramp lr from 0 to init_lr over
            # the first warm_up epochs, per global batch index.
            full_batch_idx = ((epoch - 1) * len(self.data_loader) + batch_idx)
            if epoch - 1 < self.warm_up:
                for params in self.optimizer.param_groups:
                    params['lr'] = self.init_lr / (
                        self.warm_up * len(self.data_loader)) * full_batch_idx
            lr = get_lr(self.optimizer)
            # -------- TRAINING LOOP --------
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            # -------------------------------
            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, target))
            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))
            # Needed for iteration-based training with an infinite loader.
            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        log.update({'lr': lr})
        # Add log to WandB
        if not self.config['debug']:
            wandb.log(log)
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__,
                                              met(output, target))
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        """Return a '[current/total (pct%)]' progress string for logging."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
class Learning(object):
    """
    Self-contained training driver: runs training epochs, checkpoints the best
    model, and handles LR scheduling, gradient accumulation/clipping and early
    stopping.

    Fixes over the previous revision:
      * ``self.model`` is assigned *before* ``_resume_checkpoint`` is called
        (previously resuming raised AttributeError because ``self.model`` did
        not exist yet).
      * The ``DataParallel`` wrapper is no longer clobbered by a later
        unconditional ``self.model = model.cuda()``.
      * The model is moved with ``.to(self.device)`` (as the old commented-out
        line intended), so CPU-only machines work instead of crashing in
        ``.cuda()``.
    """

    def __init__(self, model, criterion, optimizer, scheduler, metric_ftns, device,
                 num_epoch, grad_clipping, grad_accumulation_steps, early_stopping,
                 validation_frequency, tensorboard, checkpoint_dir, resume_path):
        self.device, device_ids = self._prepare_device(device)
        # Move the model to the chosen device first; resume loads into the
        # *unwrapped* model (checkpoints are saved without the DataParallel
        # 'module.' prefix, see get_state_dict).
        self.model = model.to(self.device)
        self.start_epoch = 1
        if resume_path is not None:
            self._resume_checkpoint(resume_path)
        if len(device_ids) > 1:
            # Wrap only after any checkpoint has been loaded.
            self.model = torch.nn.DataParallel(self.model)
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer
        self.num_epoch = num_epoch
        self.scheduler = scheduler
        self.grad_clipping = grad_clipping
        self.grad_accumulation_steps = grad_accumulation_steps
        self.early_stopping = early_stopping
        self.validation_frequency = validation_frequency
        self.checkpoint_dir = checkpoint_dir
        self.best_epoch = 1
        self.best_score = 0
        self.writer = TensorboardWriter(os.path.join(checkpoint_dir, 'tensorboard'), tensorboard)
        self.train_metrics = MetricTracker('loss', writer=self.writer)
        self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns],
                                           writer=self.writer)

    def train(self, train_dataloader):
        """Run the full training loop over ``num_epoch`` epochs.

        :param train_dataloader: iterable of (data, target) batches.
        """
        score = 0
        for epoch in range(self.start_epoch, self.num_epoch + 1):
            print("{} epoch: \t start training....".format(epoch))
            start = time.time()
            train_result = self._train_epoch(epoch, train_dataloader)
            train_result.update({'time': time.time() - start})
            for key, value in train_result.items():
                print(' {:15s}: {}'.format(str(key), value))
            # NOTE(review): validation is currently disabled; `score` is a dummy
            # monotonically-increasing value so checkpointing still saves the
            # most recent epoch as "best".
            score += 0.001
            self.post_processing(score, epoch)
            if epoch - self.best_epoch > self.early_stopping:
                print('WARNING: EARLY STOPPING')
                break

    def _train_epoch(self, epoch, data_loader):
        """Train for one epoch with gradient accumulation and clipping.

        :param epoch: 1-based epoch number (used for the tensorboard step).
        :param data_loader: iterable of (data, target) batches; ``target`` is a
            sequence of per-sample annotations (each movable with ``.to``).
        :return: averaged training metrics for the epoch.
        """
        self.model.train()
        self.optimizer.zero_grad()
        self.train_metrics.reset()
        for idx, (data, target) in enumerate(data_loader):
            # Variable(...) is deprecated; .to(self.device) is the modern
            # equivalent and also works on CPU.
            data = data.to(self.device)
            target = [ann.to(self.device) for ann in target]
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.writer.set_step((epoch - 1) * len(data_loader) + idx)
            self.train_metrics.update('loss', loss.item())
            # Step the optimizer only every grad_accumulation_steps batches so
            # gradients accumulate in between.
            if (idx + 1) % self.grad_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clipping)
                self.optimizer.step()
                self.optimizer.zero_grad()
            # Log an input grid roughly sqrt(len)-times per epoch.
            if (idx + 1) % int(np.sqrt(len(data_loader))) == 0:
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
        return self.train_metrics.result()

    def _valid_epoch(self, epoch, data_loader):
        """Evaluate on ``data_loader`` without gradients.

        :return: averaged validation metrics for the epoch.
        """
        self.valid_metrics.reset()
        self.model.eval()
        with torch.no_grad():
            for idx, (data, target) in enumerate(data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                self.writer.set_step((epoch - 1) * len(data_loader) + idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(output, target))
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
        # Parameter histograms once per validation pass.
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def post_processing(self, score, epoch):
        """Track the best score, checkpoint, and step the LR scheduler.

        ReduceLROnPlateau needs the score; other schedulers step unconditionally.
        """
        best = False
        if score > self.best_score:
            self.best_score = score
            self.best_epoch = epoch
            best = True
            print("best model: {} epoch - {:.5}".format(epoch, score))
        self._save_checkpoint(epoch=epoch, save_best=best)
        if self.scheduler.__class__.__name__ == 'ReduceLROnPlateau':
            self.scheduler.step(score)
        else:
            self.scheduler.step()

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.get_state_dict(self.model),
            'best_score': self.best_score
        }
        filename = os.path.join(self.checkpoint_dir, 'checkpoint_epoch{}.pth'.format(epoch))
        torch.save(state, filename)
        print("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = os.path.join(self.checkpoint_dir, 'model_best.pth')
            torch.save(state, best_path)
            print("Saving current best: model_best.pth ...")

    @staticmethod
    def get_state_dict(model):
        """Return the state dict of ``model``, unwrapping DataParallel so the
        saved keys never carry the 'module.' prefix."""
        if isinstance(model, torch.nn.DataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        return state_dict

    def _resume_checkpoint(self, resume_path):
        """Load epoch counters, best score and model weights from a checkpoint.

        Weights are loaded onto CPU first (map_location) and end up on
        self.device because self.model was already moved there.
        """
        resume_path = str(resume_path)
        print("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path, map_location=lambda storage, loc: storage)
        self.start_epoch = checkpoint['epoch'] + 1
        self.best_epoch = checkpoint['epoch']
        self.best_score = checkpoint['best_score']
        self.model.load_state_dict(checkpoint['state_dict'])
        print("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))

    @staticmethod
    def _prepare_device(device):
        """Resolve the requested GPU id list against the available hardware.

        :param device: list of GPU indices the caller wants to use.
        :return: (torch.device, list_of_gpu_ids); CPU when no GPU is usable.
        """
        n_gpu_use = len(device)
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            print("Warning: There\'s no GPU available on this machine, training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            print("Warning: The number of GPU\'s configured to use is {}, but only {} are available on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        list_ids = device
        device = torch.device('cuda:{}'.format(device[0]) if n_gpu_use > 0 else 'cpu')
        return device, list_ids
class TrainerRetrievalAux(BaseTrainerRetrieval):
    """
    Trainer class for retrieval with classification as extra info.

    Jointly optimizes a retrieval loss (image embedding vs. text embedding of
    the retrieval target) and a sum of per-factor classification losses, one
    per entry of the module-level ``_FACTORS_IN_ORDER`` list.
    """

    def __init__(self, model, model_text, criterion, criterion_ret, metric_ftns,
                 optimizer, config, data_loader, font_type,
                 valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, model_text, criterion, criterion_ret, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        # Font used by add_margin() when rendering label/prediction text onto
        # the sample image grids.
        self.font_type = font_type
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.no_tasks = len(_FACTORS_IN_ORDER)
        # NOTE(review): list_metrics and list_losses are built but never used —
        # the MetricTracker calls below hard-code the per-factor names instead.
        list_metrics = []
        for m in self.metric_ftns:
            for i in range(0, self.no_tasks):
                metric_task = f"{m.__name__}_{_FACTORS_IN_ORDER[i]}"
                list_metrics.append(metric_task)
        list_losses = []
        for i in range(0, self.no_tasks):
            list_losses.append(f"loss_{_FACTORS_IN_ORDER[i]}")
        self.train_metrics = MetricTracker('loss_classification', 'accuracy_retrieval', 'loss_floor_hue', 'loss_wall_hue', 'loss_object_hue', 'loss_retrieval', 'loss_tot', 'loss_scale', 'loss_shape', 'loss_orientation', 'accuracy_floor_hue', 'accuracy_wall_hue', 'accuracy_object_hue', 'accuracy_scale', 'accuracy_shape', 'accuracy_orientation', 'accuracy', writer=self.writer)
        self.valid_metrics = MetricTracker('loss_classification', 'accuracy_retrieval', 'loss_floor_hue', 'loss_wall_hue', 'loss_object_hue', 'loss_retrieval', 'loss_tot', 'loss_scale', 'loss_shape', 'loss_orientation', 'accuracy_floor_hue', 'accuracy_wall_hue', 'accuracy_object_hue', 'accuracy_scale', 'accuracy_shape', 'accuracy_orientation', 'accuracy', writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model_text.train()
        self.model.train()
        self.train_metrics.reset()
        if epoch == 1:
            # One Counter per factor, used to plot the label distribution of
            # the training data once (first epoch only).
            list_of_counters = []
            for i in range(0, 6):
                list_of_counters.append(Counter())
        for batch_idx, (data, target_ret, target_init) in enumerate(self.data_loader):
            # import pdb; pdb.set_trace()
            data, target_ret = data.to(self.device), target_ret.to(self.device)
            target_init = target_init.to(self.device)
            self.optimizer.zero_grad()
            # Text branch embeds the retrieval target; image branch returns
            # the retrieval embedding plus per-factor classification heads.
            text_output = self.model_text(target_ret.float())
            output_ret, output_init = self.model(data)
            # NOTE(review): the third argument (20 here, 10 in validation) is
            # passed positionally to criterion_ret — presumably a scaling /
            # temperature factor; confirm against the criterion definition.
            loss_ret = self.criterion_ret(output_ret, text_output, 20)
            # target_init is indexed [:, i] below, so one column per factor.
            no_tasks = len(target_init[0])
            loss_classification = 0
            for i in range(0, no_tasks):
                output_task = output_init[i]
                target_task = target_init[:, i]
                if epoch == 1:
                    list_of_counters[i] += Counter(target_task.tolist())
                # Render the first 8 samples with label/prediction margins for
                # tensorboard inspection.
                new_org = add_margin(img_list=data[0:8, :, :],
                                     labels=target_task,
                                     predictions=output_task,
                                     margins=5,
                                     idx2label=self.data_loader.idx2label_init[i],
                                     font=self.font_type,
                                     )
                self.writer.add_image(f"Image_train_marg_{_FACTORS_IN_ORDER[i]}_{epoch}", torchvision.utils.make_grid(new_org), epoch)
                loss_task = self.criterion(output_task, target_task)
                loss_classification += loss_task
                loss_title = f"loss_{_FACTORS_IN_ORDER[i]}"
                self.train_metrics.update(loss_title, loss_task.item())
                for met in self.metric_ftns:
                    metric_title = f"{met.__name__}_{_FACTORS_IN_ORDER[i]}"
                    self.train_metrics.update(metric_title, met(output_task, target_task))
            # Joint objective: retrieval loss + summed classification losses.
            loss_tot = loss_ret + loss_classification
            loss_tot.backward()
            self.optimizer.step()
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            # NOTE(review): loss_ret.item() only raises AttributeError when
            # loss_ret is not a tensor — presumably a degenerate-batch case.
            try:
                self.train_metrics.update('loss_retrieval', loss_ret.item())
            except AttributeError:
                print("Not enough data")
            self.train_metrics.update('accuracy_retrieval', accuracy_retrieval(output_ret, text_output))
            self.train_metrics.update('loss_classification', loss_classification.item())
            self.train_metrics.update('loss_tot', loss_tot.item())
            # Three identical log_step gates; each logs one of the loss terms
            # and (redundantly) re-writes the same input grid.
            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss tot: {:.6f}'.format(
                    epoch,
                    self._progress(batch_idx),
                    loss_tot.item()))
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss ret: {:.6f}'.format(
                    epoch,
                    self._progress(batch_idx),
                    loss_ret.item()))
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss classification: {:.6f}'.format(
                    epoch,
                    self._progress(batch_idx),
                    loss_classification.item()))
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
            if batch_idx == self.len_epoch:
                break
        #add histograms for data distribution
        if epoch == 1:
            histogram_distribution(list_of_counters, 'train')
        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_'+k : v for k, v in val_log.items()})
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.model_text.eval()
        self.valid_metrics.reset()
        if epoch == 1:
            # Same first-epoch label-distribution bookkeeping as in training,
            # NOTE(review): but histogram_distribution is never called here.
            list_of_counters = []
            for i in range(0, 6):
                list_of_counters.append(Counter())
        with torch.no_grad():
            for batch_idx, (data, target_ret, target_init) in enumerate(self.valid_data_loader):
                data, target_ret = data.to(self.device), target_ret.to(self.device)
                target_init = target_init.to(self.device)
                text_output = self.model_text(target_ret.float())
                output_ret, output_init = self.model(data)
                no_tasks = len(target_init[0])
                loss_ret = self.criterion_ret(output_ret, text_output, 10)
                loss_classification = 0
                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                for i in range(0, no_tasks):
                    output_task = output_init[i]
                    target_task = target_init[:, i]
                    if epoch == 1:
                        list_of_counters[i] += Counter(target_task.tolist())
                    new_org = add_margin(img_list=data[0:8, :, :],
                                         labels=target_task,
                                         predictions=output_task,
                                         margins=5,
                                         idx2label=self.data_loader.idx2label_init[i],
                                         font=self.font_type,
                                         )
                    self.writer.add_image(f"Image_val_marg_{_FACTORS_IN_ORDER[i]}_{epoch}", torchvision.utils.make_grid(new_org), epoch)
                    loss_task = self.criterion(output_task, target_task)
                    loss_classification += loss_task
                    loss_title = f"loss_{_FACTORS_IN_ORDER[i]}"
                    self.valid_metrics.update(loss_title, loss_task.item())
                    for met in self.metric_ftns:
                        metric_title = f"{met.__name__}_{_FACTORS_IN_ORDER[i]}"
                        self.valid_metrics.update(metric_title, met(output_task, target_task))
                self.valid_metrics.update('loss_classification', loss_classification.item())
                loss_tot = loss_ret + loss_classification
                self.valid_metrics.update('loss_tot', loss_tot.item())
                # for met in self.metric_ftns:
                #     self.valid_metrics.update(met.__name__, met(output, target, no_tasks))
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
                try:
                    self.valid_metrics.update('loss_retrieval', loss_ret.item())
                except AttributeError:
                    print("Not enough data")
                self.valid_metrics.update('accuracy_retrieval', accuracy_retrieval(output_ret, text_output))
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        for name, p in self.model_text.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        # Human-readable progress string; uses sample counts when the loader
        # exposes n_samples, batch counts otherwise.
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
class MEmoRTrainer(BaseTrainer):
    """
    Trainer for the MEmoR multimodal model.

    Each batch is a tuple of multimodal tensors; training logs loss/metrics
    per step, while validation computes the metrics once over the outputs of
    the whole validation set concatenated together.
    """

    def __init__(self, model, criterion, metric_ftns, config, data_loader,
                 valid_data_loader=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # Epoch-based training: one full pass over the loader per epoch.
            self.len_epoch = len(self.data_loader)
        else:
            # Iteration-based training: recycle the loader forever and stop
            # after len_epoch batches.
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, self.optimizer)
        self.log_step = 200
        tracked_names = ['loss'] + [fn.__name__ for fn in self.metric_ftns]
        self.train_metrics = MetricTracker(*tracked_names, writer=self.writer)
        self.valid_metrics = MetricTracker(*tracked_names, writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Run one training epoch.

        :param epoch: Integer, current training epoch.
        :return: A log dict of averaged loss/metrics (plus 'val_*' entries
                 when validation is enabled).
        """
        self.model.train()
        self.train_metrics.reset()
        for step, batch in enumerate(self.data_loader):
            (target, U_v, U_a, U_t, U_p, M_v, M_a, M_t,
             target_loc, umask, seg_len, n_c) = [item.to(self.device) for item in batch]

            self.optimizer.zero_grad()
            # Per-sample sequence length: index of the last umask==1 entry + 1.
            seq_lengths = [(row == 1).nonzero().tolist()[-1][0] + 1 for row in umask]
            output = self.model(U_v, U_a, U_t, U_p, M_v, M_a, M_t, seq_lengths,
                                target_loc, seg_len, n_c)
            assert output.shape[0] == target.shape[0]
            target = target.squeeze(1)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + step)
            self.train_metrics.update('loss', loss.item())
            for metric in self.metric_ftns:
                self.train_metrics.update(metric.__name__, metric(output, target))

            if step % self.log_step == 0:
                self.logger.debug(
                    'Train Epoch: {} {} Loss: {:.6f} Time:{}'.format(
                        epoch, self._progress(step), loss.item(),
                        datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
            if step == self.len_epoch:
                break

        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update({'val_' + key: value for key, value in val_log.items()})
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Evaluate on the validation set after a training epoch.

        :param epoch: Integer, current training epoch.
        :return: A log dict of validation loss and metrics; metrics are
                 computed over the concatenated outputs of the whole set.
        """
        self.model.eval()
        self.valid_metrics.reset()
        collected_outputs, collected_targets = [], []
        with torch.no_grad():
            for step, batch in enumerate(self.valid_data_loader):
                (target, U_v, U_a, U_t, U_p, M_v, M_a, M_t,
                 target_loc, umask, seg_len, n_c) = [item.to(self.device) for item in batch]
                seq_lengths = [(row == 1).nonzero().tolist()[-1][0] + 1 for row in umask]
                output = self.model(U_v, U_a, U_t, U_p, M_v, M_a, M_t, seq_lengths,
                                    target_loc, seg_len, n_c)
                target = target.squeeze(1)
                loss = self.criterion(output, target)
                collected_outputs.append(output.detach())
                collected_targets.append(target.detach())
                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + step, 'valid')
                self.valid_metrics.update('loss', loss.item())

        # Metrics over the full validation set, not per batch.
        all_outputs = torch.cat(collected_outputs, dim=0)
        all_targets = torch.cat(collected_targets, dim=0)
        for metric in self.metric_ftns:
            self.valid_metrics.update(metric.__name__, metric(all_outputs, all_targets))
        # Parameter histograms for tensorboard.
        for name, parameter in self.model.named_parameters():
            self.writer.add_histogram(name, parameter, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        """Render a '[current/total (pct%)]' progress string."""
        loader = self.data_loader
        if hasattr(loader, 'n_samples'):
            done, total = batch_idx * loader.batch_size, loader.n_samples
        else:
            done, total = batch_idx, self.len_epoch
        return '[{}/{} ({:.0f}%)]'.format(done, total, 100.0 * done / total)
class Trainer(BaseTrainer):
    """
    Trainer class for an object-centric video model: batches are merged with
    overlap_objects_from_batch and used as their own reconstruction target,
    and intermediate model outputs are rendered to tensorboard via _show.
    """
    def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader,
                 valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        # Explicit log_step from config wins; otherwise scale with batch size.
        if self.config['log_step'] is not None:
            self.log_step = self.config['log_step']
        else:
            self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, data in enumerate(self.data_loader):
            data = overlap_objects_from_batch(data, self.config['n_objects'])
            # Autoencoding-style objective: the (overlapped) input is also the
            # reconstruction target.
            target = data  # Is data a variable?
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            # The model receives the epoch, and the criterion an
            # (epoch, iteration) pair — presumably for annealing schedules;
            # TODO(review): confirm whether (epoch + 1)*batch_idx is the
            # intended global-iteration formula.
            output = self.model(data, epoch)
            loss, loss_particles = self.criterion(output, target,
                                                  epoch_iter=(epoch, (epoch + 1)*batch_idx),
                                                  lambd=self.config["trainer"]["lambd"])
            loss = loss.mean()
            # Note: from space implementation
            # optimizer_fg.zero_grad()
            # optimizer_bg.zero_grad()
            # loss.backward()
            # if cfg.train.clip_norm:
            #     clip_grad_norm_(model.parameters(), cfg.train.clip_norm)
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1000)
            self.optimizer.step()
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, target))
            if batch_idx % self.log_step == 0:
                # One '<name>: <value>, ' fragment per component of the
                # criterion's per-term loss dict.
                loss_particles_str = " ".join([key + ': {:.6f}, '.format(loss_particles[key].item()) for key in loss_particles])
                self.logger.debug('Train Epoch: {} {} '.format(epoch, self._progress(batch_idx)) + loss_particles_str
                                  + 'Loss: {:.6f}'.format(loss.item()))
                self._show(data, output)
            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_'+k : v for k, v in val_log.items()})
        if self.lr_scheduler is not None:
            # Scheduler is stepped with the last batch loss (ReduceLROnPlateau
            # style); see the commented alternative for argument-less schedulers.
            self.lr_scheduler.step(loss)
            # self.lr_scheduler.step() #Note: If it doesn't require argument.
        self.writer.add_scalar('LR', self.optimizer.param_groups[0]['lr'])
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, data in enumerate(self.valid_data_loader):
                data = overlap_objects_from_batch(data, self.config['n_objects'])
                target = data  # Is data a variable?
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data, epoch=epoch)
                loss, loss_particles = self.criterion(output, target,
                                                      epoch_iter=(epoch, (epoch + 1)*batch_idx),
                                                      lambd=self.config["trainer"]["lambd"])
                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(output, target))
                self._show(data, output, train=False)
        # add histogram of model parameters to the tensorboard
        # for name, p in self.model.named_parameters():
        #     self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        # Human-readable progress string; sample-based when the loader exposes
        # n_samples, batch-based otherwise.
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _show(self, data, output, train=True):
        """Render intermediate model outputs to tensorboard.

        ``output`` is a positional tuple from the model; from the indexing
        below: [0] reconstruction, [1] prediction, [2] latent representation
        (first output[0].shape[1] columns = reconstruction part, rest =
        prediction part), [4]/[5] matrices (A/B), [6] an optional input/control
        signal — TODO(review): confirm against the model's forward().
        """
        g_plot = plot_representation(output[2][:,:output[0].shape[1]].cpu())
        g_plot_pred = plot_representation(output[2][:,output[0].shape[1]:].cpu())
        A_plot = plot_matrix(output[4])
        if output[5] is not None:
            B_plot = plot_matrix(output[5])
            self.writer.add_image('B', make_grid(B_plot, nrow=1, normalize=False))
        if output[6] is not None:
            u_plot = plot_representation(output[6][:, :output[6].shape[1]].cpu())
            self.writer.add_image('u', make_grid(to_tensor(u_plot), nrow=1, normalize=False))
        # if output[10] is not None:  # TODO: re-enable this later
        #     # print(output[10][0].max(), output[-1][0].min())
        #     shape = output[10][0].shape
        #     self.writer.add_image('objects', make_grid(output[10][0].permute(1, 2, 0, 3, 4).reshape(*shape[1:-2], -1, shape[-1]).cpu(), nrow=output[0].shape[1], normalize=True))
        self.writer.add_image('A', make_grid(A_plot, nrow=1, normalize=False))
        self.writer.add_image('g_repr', make_grid(to_tensor(g_plot), nrow=1, normalize=False))
        self.writer.add_image('g_repr_pred', make_grid(to_tensor(g_plot_pred), nrow=1, normalize=False))
        self.writer.add_image('input', make_grid(data[0].cpu(), nrow=data.shape[1], normalize=True))
        self.writer.add_image('output_0rec', make_grid(output[0][0].cpu(), nrow=output[0].shape[1], normalize=True))
        self.writer.add_image('output_1pred', make_grid(output[1][0].cpu(), nrow=output[1].shape[1], normalize=True))
class Trainer(BaseTrainer):
    """
    Trainer for a two-stream (face + context) classification model.

    Each batch supplies a dict with 'face' and 'context' tensors plus labels;
    both streams are fed to the model together. Sample grids and gradient
    histograms are written to tensorboard every log_step batches.
    """

    def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader,
                 valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # Epoch-based training: one full pass over the loader per epoch.
            self.len_epoch = len(self.data_loader)
        else:
            # Iteration-based training: loop the loader indefinitely.
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        tracked_names = ['loss'] + [fn.__name__ for fn in self.metric_ftns]
        self.train_metrics = MetricTracker(*tracked_names, writer=self.writer)
        self.valid_metrics = MetricTracker(*tracked_names, writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Run one training epoch and report its wall-clock duration.

        :param epoch: Integer, current training epoch.
        :return: A log dict of averaged training (and optionally 'val_*') metrics.
        """
        epoch_start = time.time()
        self.model.train()
        self.train_metrics.reset()
        for step, (inputs, labels) in enumerate(self.data_loader):
            face = inputs['face'].to(self.device)
            context = inputs['context'].to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(face, context)
            loss = self.criterion(output, labels)
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + step)
            self.train_metrics.update('loss', loss.item())
            for metric in self.metric_ftns:
                self.train_metrics.update(metric.__name__, metric(output, labels))

            if step % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(step), loss.item()))
                self.writer.add_image(
                    'face', make_grid(face.cpu(), nrow=4, normalize=True))
                self.writer.add_image(
                    'context', make_grid(context.cpu(), nrow=2, normalize=True))
                # Gradient histograms for every trainable parameter that
                # received a gradient this step.
                for name, parameter in self.model.named_parameters():
                    if parameter.requires_grad and parameter.grad is not None:
                        self.writer.add_histogram('grad_' + name, parameter.grad, bins='auto')
            if step == self.len_epoch:
                break

        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update({'val_' + key: value for key, value in val_log.items()})
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        elapsed = time.time() - epoch_start
        print('Epoch completes in {:.0f}m {:.0f}s'.format(elapsed // 60, elapsed % 60))
        return log

    def _valid_epoch(self, epoch):
        """
        Evaluate on the validation set after a training epoch.

        :param epoch: Integer, current training epoch.
        :return: A log dict of averaged validation loss and metrics.
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for step, (inputs, labels) in enumerate(self.valid_data_loader):
                face = inputs['face'].to(self.device)
                context = inputs['context'].to(self.device)
                labels = labels.to(self.device)

                output = self.model(face, context)
                loss = self.criterion(output, labels)

                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + step, 'valid')
                self.valid_metrics.update('loss', loss.item())
                for metric in self.metric_ftns:
                    self.valid_metrics.update(metric.__name__, metric(output, labels))
                self.writer.add_image(
                    'face', make_grid(face.cpu(), nrow=4, normalize=True))
                self.writer.add_image(
                    'context', make_grid(context.cpu(), nrow=2, normalize=True))

        # Parameter histograms for tensorboard, once per validation pass.
        for name, parameter in self.model.named_parameters():
            self.writer.add_histogram(name, parameter, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        """Render a '[current/total (pct%)]' progress string."""
        loader = self.data_loader
        if hasattr(loader, 'n_samples'):
            done, total = batch_idx * loader.batch_size, loader.n_samples
        else:
            done, total = batch_idx, self.len_epoch
        return '[{}/{} ({:.0f}%)]'.format(done, total, 100.0 * done / total)
class ResNetTrainer(BaseTrainer):
    """
    Trainer class for a 3-class ResNet classifier; tracks per-epoch confusion
    matrices and keeps the best validation accuracy seen so far for display.
    """
    def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader,
                 valid_data_loader=None, lr_scheduler=None, len_epoch=None, experiment=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config, experiment)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        # Running best validation accuracy across epochs (see _valid_epoch).
        self.best_val_accuracy = 0

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        # 3x3 confusion matrix indexed [predicted, actual].
        train_confusion_matrix = torch.zeros(3, 3, dtype=torch.long)
        print('train epoch: ', epoch)
        for batch_idx, (data, label, target_class, idx) in enumerate(self.data_loader):
            print('train batch, item: ', batch_idx, ', ', idx)
            # NOTE(review): `label` from the loader is unused; training targets
            # are `target_class`.
            data, target_class = data.to(self.device), target_class.to(
                self.device)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target_class)
            loss.backward()
            self.optimizer.step()
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, target_class))
            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))
                self._visualize_input(data.cpu())
            # Accumulate the confusion matrix from the argmax predictions.
            p_cls = torch.argmax(output, dim=1)
            for i, t_cl in enumerate(target_class):
                train_confusion_matrix[p_cls[i], t_cl] += 1
            if batch_idx == self.len_epoch:
                break
        print('train confusion matrix:')
        print(train_confusion_matrix)
        self._visualize_prediction(train_confusion_matrix)
        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            val_confusion_matrix = torch.zeros(3, 3, dtype=torch.long)
            print('val epoch: ', epoch)
            for batch_idx, (data, label, target_class, idx) in enumerate(self.valid_data_loader):
                print('val batch, item: ', batch_idx, ', ', idx)
                data, target_class = data.to(self.device), target_class.to(
                    self.device)
                output = self.model(data)
                loss = self.criterion(output, target_class)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(output, target_class))
                self._visualize_input(data.cpu())
                #prediction = torch.argmax(output)
                #self.logger.debug('val class prediction, actual: {}, {}'.format(prediction, target_class))
                p_cls = torch.argmax(output, dim=1)
                for i, t_cl in enumerate(target_class):
                    val_confusion_matrix[p_cls[i], t_cl] += 1
        print('val confusion matrix:')
        print(val_confusion_matrix)
        self._visualize_prediction(val_confusion_matrix)
        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        val_log = self.valid_metrics.result()
        # TODO: Super hacky way to display best val dice score. Better way possible?
        # NOTE(review): relies on `batch_idx` surviving past the loop — raises
        # NameError if the validation loader is empty; also the update() below
        # feeds the best-so-far accuracy back into the tracker after result()
        # was taken, so it only affects the 'best_valid' tensorboard series,
        # not the returned val_log.
        self.writer.set_step(
            (epoch - 1) * len(self.valid_data_loader) + batch_idx, 'best_valid')
        val_scores = {k: v for k, v in val_log.items()}
        current_val_accuracy = val_scores['accuracy']
        if current_val_accuracy > self.best_val_accuracy:
            self.best_val_accuracy = current_val_accuracy
        self.valid_metrics.update('accuracy', self.best_val_accuracy)
        return val_log

    def _progress(self, batch_idx):
        # Human-readable progress string; sample-based when the loader exposes
        # n_samples, batch-based otherwise.
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _visualize_input(self, input):
        """format and display input data on tensorboard"""
        self.writer.add_image(
            'input', make_grid(input[0, 0, :, :], nrow=8, normalize=True))

    def _visualize_prediction(self, matrix):
        """format and display output and target data on tensorboard"""
        out = draw_confusion_matrix(matrix)
        self.writer.add_image('output', make_grid(out, nrow=8, normalize=False))
class LayerwiseTrainer(BaseTrainer):
    """Knowledge-distillation trainer that progressively replaces teacher layers
    with depthwise-separable student layers according to an epoch-indexed
    pruning plan, training with a hint (intermediate-feature) loss.

    Configuration it reads (under ``config['trainer']`` / ``config['pruning']``):
    pruning_plan / hint / unfreeze lists (each entry carries an ``epoch`` key),
    ``do_validation_interval``, ``log_step`` and optionally ``len_epoch``.
    """

    def __init__(self,
                 model: DepthwiseStudent,
                 criterions,
                 metric_ftns,
                 optimizer,
                 config,
                 train_data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 weight_scheduler=None):
        # criterion is passed as None: this trainer uses the `criterions` list
        # (indices: 0 = supervised, 1 = knowledge distillation, 2 = hint loss).
        super().__init__(model, None, metric_ftns, optimizer, config)
        self.config = config
        self.train_data_loader = train_data_loader
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_validation_interval = self.config['trainer']['do_validation_interval']
        self.lr_scheduler = lr_scheduler
        self.weight_scheduler = weight_scheduler
        self.log_step = config['trainer']['log_step']
        if "len_epoch" in self.config['trainer']:
            # iteration-based training: wrap loader so it cycles forever
            self.train_data_loader = inf_loop(train_data_loader)
            self.len_epoch = self.config['trainer']['len_epoch']
        else:
            # epoch-based training: one pass over the loader per epoch
            self.len_epoch = len(self.train_data_loader)
        # Metric trackers.
        # Train
        self.train_metrics = MetricTracker(
            'loss', 'supervised_loss', 'kd_loss', 'hint_loss', 'teacher_loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.train_iou_metrics = CityscapesMetricTracker(writer=self.writer)
        self.train_teacher_iou_metrics = CityscapesMetricTracker(writer=self.writer)
        # Valid
        self.valid_metrics = MetricTracker(
            'loss', 'supervised_loss', 'kd_loss', 'hint_loss', 'teacher_loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_iou_metrics = CityscapesMetricTracker(writer=self.writer)
        # Test (also tracks teacher-prefixed copies of each metric)
        self.test_metrics = MetricTracker(
            'loss', 'supervised_loss', 'kd_loss', 'hint_loss', 'teacher_loss',
            *[m.__name__ for m in self.metric_ftns],
            *['teacher_' + m.__name__ for m in self.metric_ftns],
            writer=self.writer,
        )
        self.test_iou_metrics = CityscapesMetricTracker(writer=self.writer)
        # Tracker for early stop if val mIoU doesn't increase
        self.val_iou_tracker = EarlyStopTracker('best', 'max', 0.01, 'rel')
        # Only the list of criterions is used; remove the unused inherited
        # `criterion` attribute so it cannot be referenced by mistake.
        self.criterions = criterions
        self.criterions = nn.ModuleList(self.criterions).to(self.device)
        if isinstance(self.model, nn.DataParallel):
            self.criterions = nn.DataParallel(self.criterions)
        del self.criterion
        # Resume checkpoint if a path is available in config
        if 'resume_path' in self.config['trainer']:
            self.resume(self.config['trainer']['resume_path'])

    def prepare_train_epoch(self, epoch, config=None):
        """
        Prepare before training an epoch: prune new layers, unfreeze some
        layers, register hint layers, create/extend the optimizer.

        :param epoch: int - which epoch the trainer is in
        :param config: config object containing pruning_plan, hint and
            unfreeze information. If None (normal training) the trainer's own
            config is used; when resuming a checkpoint the saved config is
            passed instead so the student is rebuilt with an architecture
            identical to the checkpoint's.
        :return: None
        """
        if config is None:
            config = self.config

        # reset schedulers/metrics/trackers for the new phase
        self.reset_scheduler()

        # If no layer is ever replaced/unfrozen/hinted, train the whole
        # student (identical architecture to the teacher) from epoch 1.
        if (epoch == 1) and ((len(config['pruning']['pruning_plan']) +
                             len(config['pruning']['hint']) +
                             len(config['pruning']['unfreeze'])) == 0):
            self.logger.debug('Train a student with identical architecture with teacher')
            # unfreeze everything
            for param in self.model.student.parameters():
                param.requires_grad = True
            # debug
            self.logger.info(self.model.dump_trainable_params())
            # create optimizer for the network
            self.create_new_optimizer()
            # nothing else to prepare
            return

        # Epochs at which *any* update (prune/hint/unfreeze) is scheduled.
        epochs = list(
            map(
                lambda x: x['epoch'],
                config['pruning']['pruning_plan'] + config['pruning']['hint'] +
                config['pruning']['unfreeze']))
        # if there isn't any update for this epoch then simply return
        if epoch not in epochs:
            self.logger.info('EPOCH: ' + str(epoch))
            self.logger.info('There is no update ...')
            return

        # layers that would be replaced by depthwise separable conv this epoch
        replaced_layers = list(
            filter(lambda x: x['epoch'] == epoch,
                   config['pruning']['pruning_plan']))
        # layers whose outputs will be used for the hint loss
        hint_layers = list(
            map(
                lambda x: x['name'],
                filter(lambda x: x['epoch'] == epoch,
                       config['pruning']['hint'])))
        # layers that would be trained starting this epoch
        unfreeze_layers = list(
            map(
                lambda x: x['name'],
                filter(lambda x: x['epoch'] == epoch,
                       config['pruning']['unfreeze'])))
        self.logger.info('EPOCH: ' + str(epoch))
        self.logger.info('Replaced layers: ' + str(replaced_layers))
        self.logger.info('Hint layers: ' + str(hint_layers))
        self.logger.info('Unfreeze layers: ' + str(unfreeze_layers))
        # Avoid errors when loading deprecated checkpoints that don't have
        # 'args' in config.pruning (older format stored them under 'pruner').
        if 'args' in config['pruning']:
            kwargs = config['pruning']['args']
        else:
            self.logger.warning('Using deprecate checkpoint...')
            kwargs = config['pruning']['pruner']
        self.model.replace(replaced_layers, **kwargs)  # replace those layers with depthwise separable conv
        self.model.register_hint_layers(hint_layers)  # outputs of these layers feed the hint loss
        self.model.unfreeze(unfreeze_layers)  # unfreeze chosen layers

        if epoch == 1:
            self.create_new_optimizer()  # fresh optimizer: drop stale momentum
        else:
            self.update_optimizer(
                list(
                    filter(lambda x: x['epoch'] == epoch,
                           config['pruning']['unfreeze'])))

        self.logger.info(self.model.dump_trainable_params())
        self.logger.info(self.model.dump_student_teacher_blocks_info())

    def update_optimizer(self, unfreeze_config):
        """
        Add param groups to the optimizer for layers unfrozen this epoch.

        :param unfreeze_config: list of dicts with format
            {'name': 'layer1', 'epoch': 1, 'lr' (optional): 0.01}
        :return: None
        """
        if len(unfreeze_config) > 0:
            self.logger.debug('Updating optimizer for new layer')
        for config in unfreeze_config:
            layer_name = config['name']  # layer that will be unfrozen
            self.logger.debug('Add parameters of layer: {} to optimizer'.format(layer_name))
            layer = self.model.get_block(layer_name, self.model.student)  # actual nn.Module obj
            optimizer_arg = self.config['optimizer']['args']  # default optimizer args
            # a layer-wise learning rate may be specified per entry
            # NOTE(review): this mutates the shared config dict, so a custom
            # 'lr' leaks into later param groups — confirm intended.
            if "lr" in config:
                optimizer_arg['lr'] = config['lr']
            # add the unfrozen layer's parameters to the optimizer
            self.optimizer.add_param_group({
                'params': layer.parameters(),
                **optimizer_arg
            })

    def create_new_optimizer(self):
        """Create a fresh optimizer (and lr_scheduler) over the currently
        trainable student parameters. Used at epoch 1; later epochs extend the
        existing optimizer via update_optimizer instead."""
        self.logger.debug('Creating new optimizer ...')
        self.optimizer = self.config.init_obj(
            'optimizer', optim_module,
            list(
                filter(lambda x: x.requires_grad,
                       self.model.student.parameters())))
        self.lr_scheduler = self.config.init_obj('lr_scheduler',
                                                 optim_module.lr_scheduler,
                                                 self.optimizer)

    def reset_scheduler(self):
        """
        Reset all schedulers, metrics and trackers when new layers are
        unfrozen (a new training phase starts).
        :return: None
        """
        self.weight_scheduler.reset()  # weight between losses
        self.val_iou_tracker.reset()  # verifies val IoU increases each phase
        self.train_metrics.reset()  # loss metrics, training phase
        self.valid_metrics.reset()  # loss metrics, validation phase
        self.train_iou_metrics.reset()  # train IoU of student
        self.valid_iou_metrics.reset()  # val IoU of student
        self.train_teacher_iou_metrics.reset()  # train IoU of teacher
        if isinstance(self.lr_scheduler, MyReduceLROnPlateau):
            self.lr_scheduler.reset()

    def _train_epoch(self, epoch):
        """
        Training logic for 1 epoch. Returns a dict of averaged metrics
        (plus val_* entries when validation ran this epoch).
        """
        # Prepare the network: unfreeze new layers, replace layers with
        # depthwise separable convs, update the optimizer, ...
        self.prepare_train_epoch(epoch)
        # FIXME:
        # as the teacher network contains batchnorm layers and our resources
        # are limited to small batch sizes, we ALWAYS keep BN in training mode
        # to prevent instability; hence model.train() is intentionally absent.
        # self.model.train()
        self.train_iou_metrics.reset()
        self.train_teacher_iou_metrics.reset()
        self._clean_cache()
        for batch_idx, (data, target) in enumerate(self.train_data_loader):
            data, target = data.to(self.device), target.to(self.device)
            output_st, output_tc = self.model(data)
            # losses are pre-divided by accumulation_steps so the accumulated
            # gradient matches a single large-batch step
            # NOTE(review): self.accumulation_steps is assumed to be set by
            # BaseTrainer — not visible here; confirm.
            supervised_loss = self.criterions[0](output_st, target) / self.accumulation_steps
            kd_loss = self.criterions[1](output_st, output_tc) / self.accumulation_steps
            teacher_loss = self.criterions[0](output_tc, target)  # for comparison only
            # hint loss: sum of criterions[2] over paired student/teacher
            # hidden outputs captured by the forward pass
            hint_loss = reduce(
                lambda acc, elem: acc + self.criterions[2](elem[0], elem[1]),
                zip(self.model.student_hidden_outputs,
                    self.model.teacher_hidden_outputs),
                0) / self.accumulation_steps
            # Only the hint loss is optimized; supervised/kd losses are
            # computed for logging.
            loss = hint_loss
            loss.backward()
            # NOTE(review): this fires at batch_idx == 0 too, i.e. a step on
            # the first (partial) accumulation window — confirm intended.
            if batch_idx % self.accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            # update metrics (multiplied back to undo the pre-division)
            self.train_metrics.update('loss', loss.item() * self.accumulation_steps)
            self.train_metrics.update('supervised_loss', supervised_loss.item() * self.accumulation_steps)
            self.train_metrics.update('kd_loss', kd_loss.item() * self.accumulation_steps)
            self.train_metrics.update('hint_loss', hint_loss.item() * self.accumulation_steps)
            self.train_metrics.update('teacher_loss', teacher_loss.item())
            self.train_iou_metrics.update(output_st.detach().cpu(), target.cpu())
            self.train_teacher_iou_metrics.update(output_tc.cpu(), target.cpu())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output_st, target))

            if batch_idx % self.log_step == 0:
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
                # st_masks = visualize.viz_pred_cityscapes(output_st)
                # tc_masks = visualize.viz_pred_cityscapes(output_tc)
                # self.writer.add_image('st_pred', make_grid(st_masks, nrow=8, normalize=False))
                # self.writer.add_image('tc_pred', make_grid(tc_masks, nrow=8, normalize=False))
                self.logger.info(
                    'Train Epoch: {} [{}]/[{}] Loss: {:.6f} mIoU: {:.6f} Teacher mIoU: {:.6f} Supervised Loss: {:.6f} '
                    'Knowledge Distillation loss: '
                    '{:.6f} Hint Loss: {:.6f} Teacher Loss: {:.6f}'.format(
                        epoch, batch_idx, self.len_epoch,
                        self.train_metrics.avg('loss'),
                        self.train_iou_metrics.get_iou(),
                        self.train_teacher_iou_metrics.get_iou(),
                        self.train_metrics.avg('supervised_loss'),
                        self.train_metrics.avg('kd_loss'),
                        self.train_metrics.avg('hint_loss'),
                        self.train_metrics.avg('teacher_loss'),
                    ))
            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()
        log.update({'train_teacher_mIoU': self.train_teacher_iou_metrics.get_iou()})
        log.update({'train_student_mIoU': self.train_iou_metrics.get_iou()})

        if self.do_validation and (
                (epoch % self.config["trainer"]["do_validation_interval"]) == 0):
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
            log.update(**{'val_mIoU': self.valid_iou_metrics.get_iou()})
            self.val_iou_tracker.update(self.valid_iou_metrics.get_iou())

        self._teacher_student_iou_gap = self.train_teacher_iou_metrics.get_iou() - self.train_iou_metrics.get_iou()

        # step lr scheduler (OneCycle steps elsewhere, per-iteration)
        if (self.lr_scheduler is not None) and (not isinstance(
                self.lr_scheduler, MyOneCycleLR)):
            if isinstance(self.lr_scheduler, MyReduceLROnPlateau):
                self.lr_scheduler.step(self.train_metrics.avg('loss'))
            else:
                self.lr_scheduler.step()
            self.logger.debug('stepped lr')
            for param_group in self.optimizer.param_groups:
                self.logger.debug(param_group['lr'])

        # anneal weights between losses
        self.weight_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log dict with validation metrics plus 'mIoU'.
        """
        self._clean_cache()
        # FIXME:
        # as with training, BN stays in training mode because of small batch
        # sizes; hence model.eval() is intentionally absent.
        # self.model.eval()
        self.model.save_hidden = False  # stop capturing hidden outputs
        self.valid_metrics.reset()
        self.valid_iou_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model.inference(data)
                supervised_loss = self.criterions[0](output, target)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('supervised_loss', supervised_loss.item())
                self.valid_iou_metrics.update(output.detach().cpu(), target)
                self.logger.debug(str(batch_idx) + " : " + str(self.valid_iou_metrics.get_iou()))
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(output, target))
        result = self.valid_metrics.result()
        result['mIoU'] = self.valid_iou_metrics.get_iou()
        return result

    def _test_epoch(self, epoch):
        """Run the student alone on the validation loader (teacher moved to
        CPU), optionally saving per-image outputs for submission.
        Returns a metrics dict including 'mIoU'."""
        # clean up memory first: only the student fits on the device
        self._clean_cache()
        # self.model.eval()
        self.model.save_hidden = False
        self.model.cpu()
        self.model.student.to(self.device)
        # prepare before running submission
        self.test_metrics.reset()
        self.test_iou_metrics.reset()
        args = self.config['test']['args']
        save_4_sm = self.config['submission']['save_output']
        path_output = self.config['submission']['path_output']
        if save_4_sm and not os.path.exists(path_output):
            os.mkdir(path_output)
        n_samples = len(self.valid_data_loader)
        # NOTE(review): this loader is expected to yield (img_name, data,
        # target) triples here, unlike _valid_epoch's pairs — confirm the
        # dataset used in test mode.
        with torch.no_grad():
            for batch_idx, (img_name, data, target) in enumerate(self.valid_data_loader):
                self.logger.info('{}/{}'.format(batch_idx, n_samples))
                data, target = data.to(self.device), target.to(self.device)
                output = self.model.inference_test(data, args)
                if save_4_sm:
                    self.save_for_submission(output, img_name[0])
                supervised_loss = self.criterions[0](output, target)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'test')
                self.test_metrics.update('supervised_loss', supervised_loss.item())
                self.test_iou_metrics.update(output.detach().cpu(), target)
                for met in self.metric_ftns:
                    self.test_metrics.update(met.__name__, met(output, target))
        result = self.test_metrics.result()
        result['mIoU'] = self.test_iou_metrics.get_iou()
        return result

    def save_for_submission(self, output, image_name, img_type=np.uint8):
        """Save the argmax prediction for one batch item as a submission
        image, remapping train ids back to dataset label ids."""
        args = self.config['submission']
        path_output = args['path_output']
        image_save = '{}.{}'.format(image_name, args['ext'])
        path_save = os.path.join(path_output, image_save)
        result = torch.argmax(output, dim=1)
        result_mapped = self.re_map_for_submission(result)
        # drop the batch dimension for single-image batches
        if output.size()[0] == 1:
            result_mapped = result_mapped[0]
        save_image(result_mapped.cpu().numpy().astype(img_type), path_save)
        print('Saved output of test data: {}'.format(image_save))

    def re_map_for_submission(self, output):
        """Invert the dataset's id_to_trainid mapping: pixels predicted as
        train id v are written back as label id k. Unmapped pixels become 0."""
        mapping = self.valid_data_loader.dataset.id_to_trainid
        cp_output = torch.zeros(output.size())
        for k, v in mapping.items():
            cp_output[output == v] = k
        return cp_output

    def _clean_cache(self):
        """Drop captured hidden outputs and release cached GPU memory."""
        self.model.student_hidden_outputs, self.model.teacher_hidden_outputs = list(), list()
        gc.collect()
        torch.cuda.empty_cache()

    def resume(self, checkpoint_path):
        """Resume training from a checkpoint: replay the pruning schedule so
        the student architecture matches the checkpoint, then restore model
        (and, when compatible, optimizer) state.

        :param checkpoint_path: path to a checkpoint produced by this trainer
        """
        self.logger.info("Loading checkpoint: {} ...".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']
        config = checkpoint['config']  # config of checkpoint
        epoch = checkpoint['epoch']  # stopped epoch
        # first, align the network by replaying layer replacements up to the
        # stopped epoch, using the checkpoint's own config
        for i in range(1, epoch + 1):
            self.prepare_train_epoch(i, config)
        # load weights (forgiving: ignores missing/unexpected keys)
        forgiving_state_restore(self.model, checkpoint['state_dict'])
        self.logger.info("Loaded model's state dict")
        # load optimizer state only when the optimizer type is unchanged
        if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
            self.logger.warning(
                "Warning: Optimizer type given in config file is different from that of checkpoint. "
                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.logger.info("Loaded optimizer state dict")
class Trainer(BaseTrainer):
    """Standard single-model trainer: supervised training with optional
    per-epoch validation and LR scheduling."""

    def __init__(self, model, criterion, metric_ftns, optimizer, config,
                 data_loader, valid_data_loader=None, lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is not None:
            # iteration-based training: cycle the loader indefinitely
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        else:
            # epoch-based training: one full pass per epoch
            self.len_epoch = len(self.data_loader)
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        # log roughly every sqrt(batch_size) batches
        self.log_step = int(np.sqrt(data_loader.batch_size))

        metric_names = [m.__name__ for m in self.metric_ftns]
        self.train_metrics = MetricTracker('loss', *metric_names, writer=self.writer)
        self.valid_metrics = MetricTracker('loss', *metric_names, writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metrics for this epoch
            (validation results merged in under 'val_' keys when enabled).
        """
        self.model.train()
        self.train_metrics.reset()
        for step, (data, target) in enumerate(self.data_loader):
            data = data.to(self.device)
            target = target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            # point the writer at this global step before logging metrics
            self.writer.set_step((epoch - 1) * self.len_epoch + step)
            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, target))

            if step % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(step), loss.item()))
                self.writer.add_image(
                    'input', make_grid(data.cpu(), nrow=8, normalize=True))

            if step == self.len_epoch:
                break

        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + key: val for key, val in val_log.items()})
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation.
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for step, (data, target) in enumerate(self.valid_data_loader):
                data = data.to(self.device)
                target = target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + step, 'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(output, target))
                self.writer.add_image(
                    'input', make_grid(data.cpu(), nrow=8, normalize=True))

        # add histograms of model parameters to tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        """Render a '[current/total (pct%)]' progress string for logging."""
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current, total = batch_idx, self.len_epoch
        return '[{}/{} ({:.0f}%)]'.format(current, total,
                                          100.0 * current / total)
class MAMOTrainer(BaseTrainer):
    """Multi-objective trainer.

    Trains against a list of criterions (index 0: logloss, index 1: weighted
    logloss). Depending on ``config['opt_losses']`` it either optimizes a
    single loss (0 or 1) or combines all losses into a common descent vector
    (MultiObjectiveCDV) computed from per-loss gradients.

    Fixes over the previous revision:
    - ``inspect.getargspec`` (removed in Python 3.11) replaced by
      ``inspect.getfullargspec``.
    - Max-empirical-loss estimation now runs under ``torch.no_grad()``
      (gradients were never used there).
    - Deprecated no-op ``torch.autograd.Variable`` wrappers removed.
    """

    def __init__(self, model, criterion, metric_ftns, optimizer, config,
                 data_loader, trainable_params, valid_data_loader=None,
                 lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.losses_num = len(self.criterion)
        self.max_empirical_losses = self._compute_max_expirical_losses()
        copsolver = AnalyticalSolver()
        self.common_descent_vector = MultiObjectiveCDV(
            copsolver=copsolver,
            max_empirical_losses=self.max_empirical_losses,
            normalized=True)
        self.trainable_params = trainable_params
        # 0 or 1: optimize only that loss; anything else: multi-objective
        self.opt_losses = self.config['opt_losses']
        self.train_metrics = MetricTracker(
            'loss', 'weighted_loss', *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', 'weighted_loss', *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

    @staticmethod
    def _num_params(fn):
        """Number of declared positional parameters of ``fn``.

        Uses getfullargspec: getargspec was deprecated since 3.0 and removed
        in Python 3.11.
        """
        return len(inspect.getfullargspec(fn).args)

    def _compute_max_expirical_losses(self):
        """Estimate the mean of each loss over one pass of the training data
        (used by MultiObjectiveCDV for normalization).

        NOTE(review): despite the name, the running-average formula computes a
        mean, not a maximum — kept as-is for backward compatibility.
        :return: list of floats, one per criterion
        """
        max_losses = [0] * self.losses_num
        cnt = 0
        # no gradients are needed here; no_grad avoids building autograd graphs
        with torch.no_grad():
            for batch_idx, (data, target, price) in enumerate(self.data_loader):
                data, target, price = data.to(self.device), target.to(
                    self.device), price.to(self.device)
                cnt += 1
                output = self.model(data)
                for i in range(self.losses_num):
                    l = self._cal_loss(self.criterion[i], output, target, price)
                    # incremental running mean
                    max_losses[i] = (cnt - 1) / cnt * \
                        max_losses[i] + 1 / cnt * l.item()
        return max_losses

    def _cal_loss(self, c, output, target, price):
        """Call criterion ``c`` with the arguments its signature declares:
        2 params -> c(output, target), 3 params -> c(output, target, price).
        Returns None for any other arity (unchanged legacy behavior)."""
        para_nums = self._num_params(c)
        if para_nums == 2:
            return c(output, target.float())
        elif para_nums == 3:
            return c(output, target.float(), price)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metrics for this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        average_alpha = [0] * self.losses_num
        cnt = 0
        for batch_idx, (data, target, price) in enumerate(self.data_loader):
            cnt += 1
            data = data.to(self.device)
            target = target.to(self.device)
            price = price.to(self.device)
            losses_computed = []
            if self.opt_losses == 0 or self.opt_losses == 1:
                # single-objective: back-propagate only the selected loss
                output = self.model(data)
                for loss in self.criterion:
                    losses_computed.append(
                        self._cal_loss(loss, output, target, price))
                self.optimizer.zero_grad()
                losses_computed[self.opt_losses].backward()
                self.optimizer.step()
            else:
                # multi-objective: gather per-loss gradients first
                gradients = []
                for i, loss in enumerate(self.criterion):
                    # forward pass
                    output = self.model(data)
                    L = self._cal_loss(loss, output, target, price)
                    self.optimizer.zero_grad()
                    L.backward()
                    # capture this objective's gradient
                    gradients.append(self.optimizer.get_gradient())
                # recompute losses on a fresh forward pass
                output = self.model(data)
                for i, loss in enumerate(self.criterion):
                    L = self._cal_loss(loss, output, target, price)
                    losses_computed.append(L)
                # combine into a single common-descent-vector loss
                final_loss, alphas = self.common_descent_vector.get_descent_vector(
                    losses_computed, gradients)
                # moving average of the mixing weights, for reporting
                for i, alpha in enumerate(alphas):
                    average_alpha[i] = (cnt - 1) / cnt * \
                        average_alpha[i] + 1 / cnt * alpha
                self.optimizer.zero_grad()
                final_loss.backward()
                self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', losses_computed[0].item())
            self.train_metrics.update('weighted_loss', losses_computed[1].item())
            for met in self.metric_ftns:
                # metrics, like criterions, may or may not take `price`
                para_nums = self._num_params(met)
                if para_nums == 2:
                    self.train_metrics.update(met.__name__, met(output, target))
                elif para_nums == 3:
                    self.train_metrics.update(met.__name__,
                                              met(output, target, price))
            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx),
                    losses_computed[0].item()))
            if batch_idx == self.len_epoch:
                break

        if self.opt_losses == 0:
            print("Optimize only logloss")
        elif self.opt_losses == 1:
            print("Optimize only weighted logloss")
        else:
            print("Optimize both logloss and weighted logloss")
        print(average_alpha)
        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation.
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target, price) in enumerate(self.valid_data_loader):
                # NOTE(review): `price` is intentionally left on its original
                # device in the legacy code; criterion[1] below still receives
                # it — confirm device expectations of the weighted loss.
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion[0](output, target.float())
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item())
                w_loss = self.criterion[1](output, target.float(), price)
                self.valid_metrics.update('weighted_loss', w_loss.item())
                for met in self.metric_ftns:
                    para_nums = self._num_params(met)
                    if para_nums == 2:
                        self.valid_metrics.update(met.__name__,
                                                  met(output, target))
                    elif para_nums == 3:
                        self.valid_metrics.update(met.__name__,
                                                  met(output, target, price))
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        """Render a '[current/total (pct%)]' progress string for logging."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)