def test(self, model: nn.Module) -> Dict[str, Union[int, Meter]]:
    model.eval()
    loss_meter = Meter()
    acc_meter = Meter()
    num_all_samples = 0
    with torch.no_grad():
        for batch_idx, (X, y) in enumerate(self.dataset_loader):
            X, y = X.to(self.device), y.to(self.device)
            pred = model(X)
            loss = self.criterion(pred, y)
            correct = self.count_correct(pred, y)
            num_samples = y.size(0)
            loss_meter.update(loss.item(), n=num_samples)
            acc_meter.update(correct.item() / num_samples, n=num_samples)
            num_all_samples += num_samples
    return {
        'loss_meter': loss_meter,
        'acc_meter': acc_meter,
        'num_samples': num_all_samples
    }
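# --- Illustrative sketch (not part of the original file) ---
# `Meter` is referenced throughout but defined elsewhere in the repo. The
# class below is a minimal sketch of the interface these loops rely on
# (`update(val, n)`, `.val`, `.avg`, an optional name); the details are an
# assumption, not the repo's actual implementation.
class Meter:
    def __init__(self, name: str = 'meter'):
        self.name = name
        self.val = 0.0    # most recently recorded value
        self.sum = 0.0    # sample-weighted sum of recorded values
        self.count = 0    # total number of samples seen so far
        self.avg = 0.0    # sample-weighted running average

    def update(self, val: float, n: int = 1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count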
def solve_epochs_delta(
        self, round_i, model: Module, global_state: STATE_TYPE,
        num_epochs, hide_output: bool = False
) -> Tuple[Dict[str, Union[int, Meter]], Dict[str, torch.Tensor]]:
    loss_meter = Meter()
    dice_meter = Meter('dice_coeff')
    num_all_samples = 0
    optimizer = self.create_optimizer(model)
    optimizer.step_lr_scheduler(round_i=round_i)
    model.train()
    with tqdm.trange(num_epochs, disable=hide_output) as t:
        for epoch in t:
            t.set_description(
                f'Client: {self.id}, Round: {round_i}, Epoch: {epoch}')
            for batch_idx, (X, y) in enumerate(self.dataset_loader):
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                pred = model(X)
                loss = self.criterion(pred, y)
                activated = torch.sigmoid(pred)
                dice_coeff = self.compute_dice_coefficient(activated, y)
                loss.backward()
                # Optional gradient clipping (disabled):
                # torch.nn.utils.clip_grad_norm_(model.parameters(), 60)
                optimizer.step()
                num_samples = y.size(0)
                num_all_samples += num_samples
                loss_meter.update(loss.item(), n=num_samples)
                dice_meter.update(dice_coeff.item(), n=num_samples)
                if batch_idx % 10 == 0:
                    # Plain scalar; this is the batch-mean loss.
                    t.set_postfix(mean_loss=loss.item())
    # Collect the returned statistics.
    state_dict = model.state_dict()
    result = {
        'loss_meter': loss_meter,
        'dice_coeff_meter': dice_meter,
        'num_samples': num_all_samples,
        'lr': optimizer.get_current_lr()
    }
    # Compute the delta: latest - init. Note that sub_ is in-place and
    # state_dict shares storage with the model, so this also modifies
    # the model's weights.
    for k, v in state_dict.items():
        v.sub_(global_state[k])
    return result, state_dict
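# --- Illustrative sketch (not part of the original file) ---
# solve_epochs_delta returns each client's update as a delta (latest - init)
# plus its sample count, which is exactly what a FedAvg-style server needs.
# aggregate_deltas below is a hypothetical server-side helper showing how
# those pairs could be folded into the global state; it is an assumption
# about the surrounding system, not code from this repo.
import torch
from typing import Dict, List, Tuple

def aggregate_deltas(
        global_state: Dict[str, torch.Tensor],
        client_updates: List[Tuple[int, Dict[str, torch.Tensor]]]) -> None:
    """Apply a sample-weighted average of client deltas to global_state."""
    total = sum(n for n, _ in client_updates)
    for k in global_state:
        # Weight each client's delta by its number of local samples.
        weighted = sum(n * delta[k].float() for n, delta in client_updates)
        # Integer buffers (e.g. num_batches_tracked) are truncated here;
        # a real server may want to handle them separately.
        global_state[k] += (weighted / total).to(global_state[k].dtype)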
def solve_epochs_with_global(
        self, round_i, model: Module, global_model: Module,
        num_epochs, hide_output: bool = False
) -> Tuple[Dict[str, Union[int, Meter]], Dict[str, torch.Tensor]]:
    loss_meter = Meter()
    acc_meter = Meter()
    num_all_samples = 0
    optimizer = self.create_optimizer(model)
    # TODO: reference the previous global model directly to avoid the copy;
    # the optimizer's step does not track gradients either.
    optimizer.set_old_weights(old_weights=[p for p in global_model.parameters()])
    model.train()
    with tqdm.trange(num_epochs, disable=hide_output) as t:
        for epoch in t:
            t.set_description(
                f'Client: {self.id}, Round: {round_i}, Epoch: {epoch}')
            for batch_idx, (X, y) in enumerate(self.dataset_loader):
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                pred = model(X)
                loss = self.criterion(pred, y)
                loss.backward()
                # Optional gradient clipping (disabled):
                # torch.nn.utils.clip_grad_norm_(model.parameters(), 60)
                optimizer.step()
                correct_sum = self.count_correct(pred, y)
                num_samples = y.size(0)
                num_all_samples += num_samples
                loss_meter.update(loss.item(), n=num_samples)
                acc_meter.update(correct_sum.item() / num_samples, n=num_samples)
                if batch_idx % 10 == 0:
                    # Plain scalar; this is the batch-mean loss.
                    t.set_postfix(mean_loss=loss.item())
    # Collect the returned statistics and the updated weights.
    result = {
        'loss_meter': loss_meter,
        'acc_meter': acc_meter,
        'num_samples': num_all_samples
    }
    state_dict = model.state_dict()
    return result, state_dict
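# --- Illustrative sketch (not part of the original file) ---
# The set_old_weights call above suggests a proximal-regularized optimizer in
# the style of FedProx, which pulls local weights toward the global model via
# grad <- grad + mu * (w - w_global). The class below sketches such an
# optimizer under that assumption; the repo's create_optimizer may differ.
import torch
from torch.optim import Optimizer

class ProxSGD(Optimizer):
    def __init__(self, params, lr: float = 0.01, mu: float = 0.01):
        super().__init__(params, dict(lr=lr, mu=mu))
        self.old_weights = None

    def set_old_weights(self, old_weights):
        # Assumes old_weights is ordered like the flattened param groups.
        self.old_weights = old_weights

    @torch.no_grad()
    def step(self, closure=None):
        i = 0
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    d_p = p.grad
                    if self.old_weights is not None:
                        # Proximal term: mu * (w - w_global).
                        d_p = d_p.add(p - self.old_weights[i], alpha=group['mu'])
                    p.add_(d_p, alpha=-group['lr'])
                i += 1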
def evaluate_epoch(self, epoch):
    self.logger.show_nl("Epoch: [{0}]".format(epoch))
    losses = Meter()
    len_eval = len(self.eval_loader)
    width = len(str(len_eval))
    start_pattern = "[{{:>{0}}}/{{:>{0}}}]".format(width)
    pb = tqdm(self.eval_loader)
    # Construct metrics
    metrics = (Precision(mode='accum'), Recall(mode='accum'),
               F1Score(mode='accum'), Accuracy(mode='accum'))
    self.model.eval()
    with torch.no_grad():
        for i, (name, t1, t2, tar) in enumerate(pb):
            t1, t2, tar = self._prepare_data(t1, t2, tar)
            batch_size = tar.shape[0]
            fetch_dict = self._set_fetch_dict()
            out_dict = FeatureContainer()
            with HookHelper(self.model, fetch_dict, out_dict, hook_type='forward_out'):
                out = self.model(t1, t2)
            pred = self._process_model_out(out)
            loss = self.criterion(pred, tar)
            losses.update(loss.item(), n=batch_size)
            # Convert to numpy arrays
            prob = self._pred_to_prob(pred)
            prob = prob.cpu().numpy()
            cm = (prob > 0.5).astype('uint8')
            tar = tar.cpu().numpy().astype('uint8')
            for m in metrics:
                m.update(cm, tar, n=batch_size)
            desc = (start_pattern + " Loss: {:.4f} ({:.4f})").format(
                i + 1, len_eval, losses.val, losses.avg)
            for m in metrics:
                desc += " {} {:.4f}".format(m.__name__, m.val)
            pb.set_description(desc)
            dump = not self.is_training or (i % max(1, len_eval // 10) == 0)
            if dump:
                self.logger.dump(desc)
            if self.tb_on:
                if dump:
                    for j in range(batch_size):
                        t1_, t2_ = to_array(t1[j]), to_array(t2[j])
                        t1_, t2_ = self._denorm_image(t1_), self._denorm_image(t2_)
                        t1_, t2_ = self._process_input_pairs(t1_, t2_)
                        self.tb_writer.add_image("Eval/t1", t1_, self.eval_step, dataformats='HWC')
                        self.tb_writer.add_image("Eval/t2", t2_, self.eval_step, dataformats='HWC')
                        self.tb_writer.add_image("Eval/labels", quantize(tar[j]), self.eval_step, dataformats='HW')
                        self.tb_writer.add_image("Eval/prob", to_pseudo_color(quantize(prob[j])), self.eval_step, dataformats='HWC')
                        self.tb_writer.add_image("Eval/cm", quantize(cm[j]), self.eval_step, dataformats='HW')
                        for key, feats in out_dict.items():
                            for idx, feat in enumerate(feats):
                                feat = self._process_fetched_feat(feat[j])
                                self.tb_writer.add_image(f"Eval/{key}_{idx}", feat, self.eval_step, dataformats='HWC')
                        self.eval_step += 1
                else:
                    self.eval_step += batch_size
            if self.save:
                for j in range(batch_size):
                    self.save_image(name[j], quantize(cm[j]), epoch)
    if self.tb_on:
        self.tb_writer.add_scalar("Eval/loss", losses.avg, self.eval_step)
        for m in metrics:
            self.tb_writer.add_scalar(f"Eval/{m.__name__.lower()}", m.val, self.eval_step)
        self.tb_writer.flush()
    return metrics[2].val  # F1-score
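# --- Illustrative sketch (not part of the original file) ---
# Precision, Recall, F1Score and Accuracy are constructed with mode='accum'
# above, i.e. they accumulate statistics across batches rather than averaging
# per-batch scores. The class below sketches one such metric backed by
# running confusion-matrix counts; the interface (update(cm, tar, n), .val,
# __name__) is inferred from the usage above and is an assumption.
import numpy as np

class Accuracy:
    __name__ = 'Accuracy'  # read via m.__name__ in the loops above

    def __init__(self, mode: str = 'accum'):
        self.mode = mode
        self.tp = self.fp = self.fn = self.tn = 0

    def update(self, cm, tar, n: int = 1):
        cm, tar = cm.astype(bool), tar.astype(bool)
        self.tp += int(np.sum(cm & tar))
        self.fp += int(np.sum(cm & ~tar))
        self.fn += int(np.sum(~cm & tar))
        self.tn += int(np.sum(~cm & ~tar))

    @property
    def val(self) -> float:
        total = self.tp + self.fp + self.fn + self.tn
        return (self.tp + self.tn) / total if total else 0.0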
def train_epoch(self, epoch):
    losses = Meter()
    len_train = len(self.train_loader)
    width = len(str(len_train))
    start_pattern = "[{{:>{0}}}/{{:>{0}}}]".format(width)
    pb = tqdm(self.train_loader)
    self.model.train()
    for i, (t1, t2, tar) in enumerate(pb):
        t1, t2, tar = self._prepare_data(t1, t2, tar)
        show_imgs_on_tb = self.tb_on and (i % self.tb_intvl == 0)
        fetch_dict = self._set_fetch_dict()
        out_dict = FeatureContainer()
        with HookHelper(self.model, fetch_dict, out_dict, hook_type='forward_out'):
            out = self.model(t1, t2)
        pred = self._process_model_out(out)
        loss = self.criterion(pred, tar)
        losses.update(loss.item(), n=tar.shape[0])
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        desc = (start_pattern + " Loss: {:.4f} ({:.4f})").format(
            i + 1, len_train, losses.val, losses.avg)
        pb.set_description(desc)
        if i % max(1, len_train // 10) == 0:
            self.logger.dump(desc)
        if self.tb_on:
            # Write to tensorboard
            self.tb_writer.add_scalar("Train/running_loss", losses.val, self.train_step)
            if show_imgs_on_tb:
                t1, t2 = to_array(t1[0]), to_array(t2[0])
                t1, t2 = self._denorm_image(t1), self._denorm_image(t2)
                t1, t2 = self._process_input_pairs(t1, t2)
                self.tb_writer.add_image("Train/t1_picked", t1, self.train_step, dataformats='HWC')
                self.tb_writer.add_image("Train/t2_picked", t2, self.train_step, dataformats='HWC')
                self.tb_writer.add_image("Train/labels_picked", to_array(tar[0]), self.train_step, dataformats='HW')
                for key, feats in out_dict.items():
                    for idx, feat in enumerate(feats):
                        feat = self._process_fetched_feat(feat)
                        self.tb_writer.add_image(f"Train/{key}_{idx}", feat, self.train_step, dataformats='HWC')
            self.tb_writer.flush()
        self.train_step += 1
    if self.tb_on:
        self.tb_writer.add_scalar("Train/loss", losses.avg, self.train_step)
        self.tb_writer.add_scalar("Train/lr", self.lr, self.train_step)
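# --- Illustrative sketch (not part of the original file) ---
# Both loops wrap the forward pass in HookHelper(..., hook_type='forward_out')
# to capture intermediate activations for TensorBoard. The classes below
# sketch how such a hook-based feature grabber can work, assuming fetch_dict
# maps module names to output keys; the repo's HookHelper/FeatureContainer
# may differ in detail.
from collections import defaultdict

class FeatureContainer(defaultdict):
    def __init__(self):
        super().__init__(list)  # key -> list of captured feature tensors

class HookHelper:
    def __init__(self, model, fetch_dict, out_dict, hook_type='forward_out'):
        assert hook_type == 'forward_out'
        self.model = model
        self.fetch_dict = fetch_dict
        self.out_dict = out_dict
        self.handles = []

    def __enter__(self):
        named = dict(self.model.named_modules())
        for module_name, key in self.fetch_dict.items():
            def hook(module, inputs, output, key=key):
                # Detach so captured features don't keep the graph alive.
                self.out_dict[key].append(output.detach())
            self.handles.append(named[module_name].register_forward_hook(hook))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always remove hooks so later forward passes are unaffected.
        for h in self.handles:
            h.remove()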