logger.new_epoch()

# train
classifier.train()
epoch_trn_loss = []
epoch_vld_loss = []
epoch_vld_recall_g, epoch_vld_recall_v, epoch_vld_recall_c, epoch_vld_recall_all = [], [], [], []

for j, (trn_imgs_batch, trn_lbls_batch) in enumerate(training_loader):
    # move to device
    trn_imgs_batch_device = trn_imgs_batch.cuda()
    trn_lbls_batch_device = trn_lbls_batch.cuda()

    # lr scheduler step
    lr_scheduler.step(i + j / nsteps)
    cur_lr = lr_scheduler.get_lr()

    # mixup
    trn_imgs_batch_device_mixup, trn_lbls_batch_device_shfl, gamma = mixup(
        trn_imgs_batch_device, trn_lbls_batch_device, 1.)

    # forward pass
    logits_g, logits_v, logits_c = classifier(trn_imgs_batch_device_mixup)
    loss_g = mixup_loss(logits_g, trn_lbls_batch_device[:, 0],
                        trn_lbls_batch_device_shfl[:, 0], gamma)
    loss_v = mixup_loss(logits_v, trn_lbls_batch_device[:, 1],
                        trn_lbls_batch_device_shfl[:, 1], gamma)
    loss_c = mixup_loss(logits_c, trn_lbls_batch_device[:, 2],
                        trn_lbls_batch_device_shfl[:, 2], gamma)
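
# NOTE: `mixup` and `mixup_loss` are not defined in this snippet. The sketch
# below is a minimal, assumed implementation of standard mixup (Zhang et al.,
# "mixup: Beyond Empirical Risk Minimization") that matches the call
# signatures used above; the original helpers may differ in detail.
import torch
import torch.nn.functional as F
from torch.distributions.beta import Beta


def mixup(images, labels, alpha):
    # sample a mixing coefficient and a random pairing within the batch
    gamma = Beta(alpha, alpha).sample().item()
    perm = torch.randperm(images.size(0), device=images.device)
    mixed_images = gamma * images + (1. - gamma) * images[perm]
    return mixed_images, labels[perm], gamma


def mixup_loss(logits, labels, labels_shuffled, gamma):
    # cross-entropy interpolated with the same coefficient as the images
    return (gamma * F.cross_entropy(logits, labels)
            + (1. - gamma) * F.cross_entropy(logits, labels_shuffled))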
from math import ceil
from pathlib import Path
from typing import Callable, Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from sklearn.metrics import confusion_matrix
from torch import nn
from torch.nn.functional import softmax
from torch.optim import Optimizer
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Project-local names assumed to be available elsewhere in the repository:
# Classifier, TinyImagenetDataset, Stopper, AvgMoving, smooth_one_hot,
# cross_entropy_with_probs (sketches of the small helpers follow the class).


class Trainer:
    _classifier: Classifier
    _train_set: TinyImagenetDataset
    _test_set: TinyImagenetDataset
    _results_path: Path
    _device: torch.device
    _batch_size: int
    _num_workers: int
    _num_visual: int
    _aug_degree: Dict
    _lr: float
    _lr_min: float
    _stopper: Stopper
    _labels_num2txt: Dict
    _freeze: Dict
    _weight_decay: float
    _label_smooth: float
    _period_cosine: int
    _net_path: Path
    _tensorboard_path: Path
    _writer: SummaryWriter
    _optimizer: Optimizer
    _scheduler: CosineAnnealingWarmRestarts
    _loss_func_smoothed: Callable
    _loss_func_one_hot: nn.Module
    _curr_epoch: int
    _vis_per_batch: int

    def __init__(self, classifier: Classifier,
                 train_set: TinyImagenetDataset,
                 test_set: TinyImagenetDataset,
                 results_path: Path, device: torch.device,
                 batch_size: int, num_workers: int, num_visual: int,
                 aug_degree: Dict, lr: float, lr_min: float,
                 stopper: Stopper, labels_num2txt: Dict, freeze: Dict,
                 weight_decay: float, label_smooth: float,
                 period_cosine: int):
        self._classifier = classifier
        self._train_set = train_set
        self._test_set = test_set
        self._results_path = results_path
        self._device = device
        self._batch_size = batch_size
        self._num_workers = num_workers
        self._num_visual = num_visual
        self._aug_degree = aug_degree
        self._lr = lr
        self._lr_min = lr_min
        self._stopper = stopper
        self._labels_num2txt = labels_num2txt
        self._freeze = freeze
        self._weight_decay = weight_decay
        self._label_smooth = label_smooth
        self._period_cosine = period_cosine

        self._classifier.to(self._device)

        self._net_path = self._results_path / 'net'
        self._tensorboard_path = self._results_path / 'tensorboard'
        for folder in [self._net_path, self._tensorboard_path]:
            folder.mkdir(exist_ok=True, parents=True)

        self._writer = SummaryWriter(log_dir=str(self._tensorboard_path))

        self._loss_func_smoothed = cross_entropy_with_probs
        self._loss_func_one_hot = nn.CrossEntropyLoss()

        self._optimizer = torch.optim.SGD(self._classifier.parameters(),
                                          lr=self._lr, momentum=0.9,
                                          weight_decay=self._weight_decay)
        self._scheduler = CosineAnnealingWarmRestarts(
            self._optimizer, T_0=self._period_cosine,
            eta_min=self._lr_min, T_mult=2)

        self._curr_epoch = 0
        self._vis_per_batch = self._calc_vis_per_batch()

    def _calc_vis_per_batch(self) -> int:
        # number of samples to visualize from each test batch
        num_batch = ceil(len(self._test_set) / self._batch_size)
        return ceil(self._num_visual / num_batch)

    def train(self, num_epoch: int) -> Tuple[float, float, int]:
        accuracy_max = 0
        accuracy_test = 0
        loss_test = 0
        for epoch in range(num_epoch):
            self._curr_epoch = epoch
            self._set_freezing()
            accuracy_train, loss_train = self._train_epoch()
            accuracy_test, loss_test = self.test()
            self._writer.flush()
            self._stopper.update(accuracy_test)
            meta = {
                'epoch': self._curr_epoch,
                'accuracy_test': accuracy_test,
                'accuracy_train': accuracy_train,
                'loss_test': loss_test,
                'loss_train': loss_train,
                'lr': self._scheduler.get_last_lr()[0]
            }
            self._classifier.save(
                self._net_path / f'net_epoch_{self._curr_epoch}.pth', meta)
            if accuracy_test > accuracy_max:
                accuracy_max = accuracy_test
                self._classifier.save(self._net_path / 'best.pth', meta)
            if self._stopper.is_need_stop():
                break
        self._writer.close()
        return accuracy_test, loss_test, self._curr_epoch

    def _train_epoch(self) -> Tuple[float, float]:
        self._classifier.train()
        self._set_aug_train()
        train_loader = DataLoader(self._train_set,
                                  batch_size=self._batch_size,
                                  num_workers=self._num_workers,
                                  shuffle=True)
        num_batches = len(train_loader)
        correct = 0
        train_tqdm = tqdm(train_loader, desc=f'train_{self._curr_epoch}')
        loss_avg = AvgMoving()
        for curr_batch, (images, labels, _) in enumerate(train_tqdm):
            self._optimizer.zero_grad()
            images = images.to(self._device)
            labels = labels.to(self._device)
            output = self._classifier(images)

            if self._label_smooth > 0:
                smoothed_labels = smooth_one_hot(
                    labels=labels, num_classes=200,
                    smoothing=self._label_smooth)
                loss = self._loss_func_smoothed(output, smoothed_labels)
            else:
                loss = self._loss_func_one_hot(output, labels)
            loss.backward()
            self._optimizer.step()
            # fractional-epoch step so the cosine schedule anneals per batch;
            # the optimizer must step before the scheduler in PyTorch >= 1.1
            self._scheduler.step(self._curr_epoch +
                                 (curr_batch + 1) / num_batches)

            loss = loss.item()
            loss_avg.add(loss)
            pred = output.argmax(dim=1)
            correct += torch.eq(pred, labels).sum().item()
            train_tqdm.set_postfix({'Avg train loss': round(loss_avg.avg, 4)})

        accuracy = correct / len(train_loader.dataset)
        self._add_writer_metrics(loss_avg.avg, accuracy, 'train')
        return accuracy, loss_avg.avg

    def test(self) -> Tuple[float, float]:
        self._classifier.eval()
        with torch.no_grad():
            test_loader = DataLoader(self._test_set,
                                     batch_size=self._batch_size,
                                     num_workers=self._num_workers,
                                     shuffle=True)
            correct = 0
            test_tqdm = tqdm(test_loader, desc=f'test_{self._curr_epoch}',
                             leave=False)
            loss_avg = AvgMoving()
            labels_pred_list = []
            worst_pred_list = []
            best_pred_list = []
            some_pred_list = []
            for images, labels, paths in test_tqdm:
                images = images.to(self._device)
                labels = labels.to(self._device)
                logits = self._classifier(images)
                # CrossEntropyLoss expects raw logits, so compute the loss
                # before applying softmax
                loss_avg.add(self._loss_func_one_hot(logits, labels).item())
                output = softmax(logits, dim=1)
                pred = output.argmax(dim=1)
                correct += torch.eq(pred, labels).sum().item()

                labels = labels.detach().cpu().numpy()
                output = output.detach().cpu().numpy()
                pred = pred.detach().cpu().numpy()
                prob = output[np.arange(np.size(labels)), pred]

                # prepare data for confusion matrix and hists
                labels_pred = np.hstack(
                    (labels[:, np.newaxis], pred[:, np.newaxis]))
                labels_pred_list.append(np.copy(labels_pred))

                # prepare data for visualization
                prob_in_labels = output[np.arange(np.size(labels)), labels]
                worst_pred, best_pred, some_pred = self._prepare_data_for_vis(
                    np.copy(pred), np.copy(prob), np.copy(labels),
                    np.copy(prob_in_labels), paths)
                worst_pred_list.append(worst_pred)
                best_pred_list.append(best_pred)
                some_pred_list.append(some_pred)

            accuracy = correct / len(test_loader.dataset)
            self._add_writer_metrics(loss_avg.avg, accuracy, 'test')
            self._visual_confusion_and_hists(labels_pred_list)
            self._visual_gt_and_pred(worst_pred_list, 'Worst_pred_pic')
            self._visual_gt_and_pred(best_pred_list, 'Best_pred_pic')
            self._visual_gt_and_pred(some_pred_list, 'Some_pred_pic')
        self._classifier.train()
        return accuracy, loss_avg.avg

    def _freeze_except_k_last(self, k_last: int) -> None:
        # freeze every child module except the last k_last ones
        num_layers = len(list(self._classifier._net.children()))
        for i, layer in enumerate(self._classifier._net.children()):
            if i + k_last < num_layers:
                for param in layer.parameters():
                    param.requires_grad = False

    def _unfreeze_all(self) -> None:
        for param in self._classifier.parameters():
            param.requires_grad = True

    def _set_freezing(self) -> None:
        curr_epoch = str(self._curr_epoch)
        if curr_epoch in tuple(self._freeze.keys()):
            self._unfreeze_all()
            self._freeze_except_k_last(self._freeze[curr_epoch])
            # rebuild the optimizer and scheduler over trainable params only
            self._optimizer = torch.optim.SGD(
                filter(lambda p: p.requires_grad,
                       self._classifier.parameters()),
                lr=self._lr, momentum=0.9, weight_decay=self._weight_decay)
            self._scheduler = CosineAnnealingWarmRestarts(
                self._optimizer, T_0=self._period_cosine,
                eta_min=self._lr_min, T_mult=2)

    def _set_aug_train(self) -> None:
        curr_epoch = str(self._curr_epoch)
        if curr_epoch in tuple(self._aug_degree.keys()):
            self._train_set.set_transforms(self._aug_degree[curr_epoch])

    def _add_writer_metrics(self, loss: float, accuracy: float,
                            mode: str) -> None:
        self._writer.add_scalars('loss', {f'loss_{mode}': loss},
                                 self._curr_epoch)
        self._writer.add_scalars('accuracy', {f'accuracy_{mode}': accuracy},
                                 self._curr_epoch)

    # VISUALIZATION

    def _prepare_data_for_vis(
            self, pred: np.ndarray, prob: np.ndarray, labels: np.ndarray,
            prob_in_labels: np.ndarray, paths: List[Path]
    ) -> List[Tuple[List[Path], np.ndarray, np.ndarray]]:
        prob_idx = np.argsort(prob_in_labels)
        prob_idx_list = list()
        prob_idx_list.append(prob_idx[0:self._vis_per_batch])  # worst preds
        prob_idx_list.append(
            prob_idx[-1:-self._vis_per_batch - 1:-1])  # best preds
        prob_idx_list.append(
            np.random.choice(prob_idx, self._vis_per_batch))  # some preds

        data_list = []
        for idx in prob_idx_list:
            pred_and_prob = np.hstack(
                (pred[idx, np.newaxis], prob[idx, np.newaxis]))
            labels_and_prob_in_labels = np.hstack(
                (labels[idx, np.newaxis], prob_in_labels[idx, np.newaxis]))
            data = ([paths[k] for k in idx], pred_and_prob,
                    labels_and_prob_in_labels)
            data_list.append(data)
        return data_list

    def _visual_gt_and_pred(
            self,
            data_list: List[Tuple[List[Path], np.ndarray, np.ndarray]],
            txt: str) -> None:
        num_batch = ceil(len(self._test_set) / self._batch_size)
        num_elem = num_batch * self._vis_per_batch
        images_array = np.zeros((num_elem, 64, 64, 3), int)
        pred_and_prob_array = np.zeros((num_elem, 2))
        labels_and_prob_in_labels_array = np.zeros((num_elem, 2))
        filenames = []
        k_im = 0
        for i, (paths, pred_and_prob,
                labels_and_prob_in_labels) in enumerate(data_list):
            filenames += paths
            for path in paths:
                images_array[k_im, :] = np.array(
                    Image.open(path).convert('RGB'))
                k_im += 1
            pred_and_prob_array[
                i * self._vis_per_batch:
                (i + 1) * self._vis_per_batch, :] = pred_and_prob
            labels_and_prob_in_labels_array[
                i * self._vis_per_batch:
                (i + 1) * self._vis_per_batch, :] = labels_and_prob_in_labels

        idx = np.random.choice(range(num_elem), self._num_visual,
                               replace=False)
        images_array = images_array[idx, :]
        pred_and_prob_array = pred_and_prob_array[idx, :]
        labels_and_prob_in_labels_array = \
            labels_and_prob_in_labels_array[idx, :]
        filenames = [Path(filenames[idx_curr]).stem for idx_curr in idx]

        height_fig = self._num_visual
        width_fig = 3
        height_cell = 0.95 * (height_fig / self._num_visual) / height_fig
        width_im_cell = 1 / width_fig
        left_im_cell = 0 / width_fig
        width_txt_cell = 1.8 / width_fig
        left_txt_cell = 1.1 / width_fig
        bottom_cell = [x / height_fig for x in range(height_fig)]

        fig = plt.figure(figsize=(width_fig, height_fig), tight_layout=False)
        for k in range(self._num_visual):
            fig.add_axes(
                (left_im_cell, bottom_cell[k], width_im_cell, height_cell))
            plt.axis('off')
            plt.imshow(images_array[k, :], aspect='auto')
            fig.add_axes(
                (left_txt_cell, bottom_cell[k], width_txt_cell, height_cell))
            str_pic = f'{filenames[k]} \n' \
                f'gt: {round(labels_and_prob_in_labels_array[k, 1], 2)}\n' \
                f'({int(labels_and_prob_in_labels_array[k, 0])}) ' \
                f'{self._labels_num2txt[labels_and_prob_in_labels_array[k, 0]]}\n' \
                f'pred: {round(pred_and_prob_array[k, 1], 2)}\n' \
                f'({int(pred_and_prob_array[k, 0])}) ' \
                f'{self._labels_num2txt[pred_and_prob_array[k, 0]]}\n'
            plt.text(0, 0.5, str_pic, verticalalignment='center')
            plt.axis('off')
        self._writer.add_figure(txt, fig, self._curr_epoch)
        plt.close(fig)

    def _visual_confusion_and_hists(
            self, labels_pred_list: List[np.ndarray]) -> None:
        labels = np.zeros(len(self._test_set))
        pred = np.zeros(len(self._test_set))
        k_batch = 0
        for labels_pred in labels_pred_list:
            diap = min((self._batch_size, np.shape(labels_pred)[0]))
            labels[k_batch * self._batch_size:
                   k_batch * self._batch_size + diap] = labels_pred[:, 0]
            pred[k_batch * self._batch_size:
                 k_batch * self._batch_size + diap] = labels_pred[:, 1]
            k_batch += 1
        confusion_matrix_array = confusion_matrix(
            y_pred=pred, y_true=labels).astype(float)
        # Tiny ImageNet has 50 validation images per class, so this turns
        # counts into per-class rates
        confusion_matrix_array /= 50

        fig = plt.figure(figsize=(12, 12))
        conf_map = plt.imshow(confusion_matrix_array, cmap="gist_heat",
                              interpolation="nearest")
        plt.colorbar(mappable=conf_map)
        self._writer.add_figure('Confusion_matrix', fig, self._curr_epoch)
        plt.close(fig)
        self._visual_hists(confusion_matrix_array)

    def _visual_hists(self, confusion_matrix_array: np.ndarray) -> None:
        num_col = 20
        correct = np.diag(confusion_matrix_array) * 100
        idx_correct = np.argsort(correct)
        idx_best = idx_correct[-1:-num_col - 1:-1]
        idx_worst = idx_correct[0:num_col]
        self._visual_hist(np.copy(correct[idx_best]), idx_best,
                          ylabel='Correct predicts, %',
                          title='Best predicts', tag='Best_predicts_hist')
        self._visual_hist(np.copy(correct[idx_worst]), idx_worst,
                          ylabel='Correct predicts, %',
                          title='Worst predicts', tag='Worst_predicts_hist')

    def _visual_hist(self, data: np.ndarray, labels: np.ndarray, ylabel: str,
                     title: str, tag: str) -> None:
        num_col = np.size(data)
        len_txt = 20
        xlim = (0, num_col + 1)
        ylim = (0, max(np.max(data) * 1.1, 1e-5))
        fig, ax = plt.subplots(figsize=(7, 7), facecolor='white')
        ax.set(ylabel=ylabel, ylim=ylim, xlim=xlim)
        labels_on_graph = []
        for k in range(np.size(labels)):
            # pad or trim every tick label to a fixed width
            str_tick = self._labels_num2txt[labels[k]]
            len_str_tick = len(str_tick)
            if len_str_tick < len_txt:
                str_tick = str_tick + ' ' * (len_txt - len_str_tick)
            elif len_str_tick > len_txt:
                str_tick = str_tick[0:len_txt]
            labels_on_graph.append(str_tick)
        for i in range(num_col):
            val = data[i]
            ax.text(i + 1, val + ylim[1] * 0.01,
                    np.round(val).astype(int),
                    horizontalalignment='center')
            ax.vlines(x=i + 1, ymin=0, ymax=val, color='firebrick',
                      alpha=0.7, linewidth=20)
        plt.xticks(range(1, num_col + 1), labels_on_graph, rotation=-90)
        plt.title(title)
        fig.tight_layout()
        self._writer.add_figure(tag, fig, self._curr_epoch)
        plt.close(fig)
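
# The Trainer above relies on several small project-local helpers that are
# not shown in this section. The sketches below are minimal, assumed
# implementations consistent with how they are called; the real code may
# differ.
import torch
import torch.nn.functional as F


class AvgMoving:
    """Running average of a stream of values: add() each one, read .avg."""

    def __init__(self):
        self.avg = 0.
        self._n = 0

    def add(self, value: float) -> None:
        self._n += 1
        self.avg += (value - self.avg) / self._n


class Stopper:
    """Early stopping: stop when accuracy has not improved for `patience`
    consecutive epochs (patience is an assumed parameter name)."""

    def __init__(self, patience: int):
        self._patience = patience
        self._best = 0.
        self._bad_epochs = 0

    def update(self, accuracy: float) -> None:
        if accuracy > self._best:
            self._best = accuracy
            self._bad_epochs = 0
        else:
            self._bad_epochs += 1

    def is_need_stop(self) -> bool:
        return self._bad_epochs >= self._patience


def smooth_one_hot(labels: torch.Tensor, num_classes: int,
                   smoothing: float) -> torch.Tensor:
    """One-hot targets with `smoothing` probability mass spread over the
    other classes; each row still sums to 1."""
    off_value = smoothing / (num_classes - 1)
    targets = torch.full((labels.size(0), num_classes), off_value,
                         device=labels.device)
    targets.scatter_(1, labels.unsqueeze(1), 1. - smoothing)
    return targets


def cross_entropy_with_probs(logits: torch.Tensor,
                             target_probs: torch.Tensor) -> torch.Tensor:
    """Cross-entropy against soft target distributions instead of indices."""
    return -(target_probs * F.log_softmax(logits, dim=1)).sum(dim=1).mean()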
import matplotlib.pyplot as plt
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

initial_lr = 0.1   # assumed starting lr for this demo
total_epoch = 100  # assumed number of epochs to plot

net = model()  # `model` is assumed to be defined above; any nn.Module works
optimizer = torch.optim.Adam(net.parameters(), lr=initial_lr)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
# scheduler = StepLR(optimizer, initial_lr, total_epoch)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5)
# scheduler = LambdaLR(optimizer, lambda step: (1.0 - step / total_epoch), last_epoch=-1)
# scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)
# scheduler = torch.optim.lr_scheduler.LambdaLR(
#     optimizer,
#     lambda step: (1.0 - step / total_epoch) if step <= total_epoch else 0,
#     last_epoch=-1)

print("Initial learning rate:", optimizer.defaults['lr'])

lr_list = []  # record every lr that was used so we can plot how it changes
for epoch in range(1, total_epoch):
    # dummy step: in a real loop, loss.backward() would come first
    optimizer.zero_grad()
    optimizer.step()

    print("Learning rate at epoch %d: %f" % (epoch, optimizer.param_groups[0]['lr']))
    print(scheduler.get_last_lr())
    lr_list.append(optimizer.param_groups[0]['lr'])
    # lr_list.append(scheduler.get_last_lr()[0])
    scheduler.step()
    # scheduler(epoch)

# plot how the lr changes
plt.plot(list(range(1, total_epoch)), lr_list)
plt.xlabel("epoch")
plt.ylabel("lr")
plt.title("learning rate's curve changes as epoch goes on!")
plt.show()
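
# CosineAnnealingWarmRestarts also accepts a fractional epoch argument, which
# is how the Trainer above anneals the lr inside an epoch (this usage follows
# the PyTorch docs for the scheduler). A minimal sketch, assuming `iters`
# batches per epoch:
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)
iters = 10  # assumed number of batches per epoch
for epoch in range(total_epoch):
    for i in range(iters):
        optimizer.zero_grad()
        optimizer.step()  # dummy step, as in the demo above
        # stepping with epoch + fraction anneals the lr batch-by-batch
        scheduler.step(epoch + (i + 1) / iters)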