Example #1
    def test(self, model: nn.Module) -> Dict[str, Union[int, Meter]]:
        model.eval()

        loss_meter = Meter()
        acc_meter = Meter()
        num_all_samples = 0

        with torch.no_grad():

            for batch_idx, (X, y) in enumerate(self.dataset_loader):
                X, y = X.to(self.device), y.to(self.device)
                pred = model(X)
                loss = self.criterion(pred, y)
                correct = self.count_correct(pred, y)
                # Accumulate running statistics, weighted by the batch size
                num_samples = y.size(0)
                loss_meter.update(loss.item(), n=num_samples)
                acc_meter.update(correct.item() / num_samples, n=num_samples)
                num_all_samples += num_samples

        return {
            'loss_meter': loss_meter,
            'acc_meter': acc_meter,
            'num_samples': num_all_samples
        }
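
All of these examples rely on a `Meter` accumulator exposing `update(value, n=...)`, `.val`, and `.avg` (and an optional name, as in `Meter('dice_coeff')` below). A minimal sketch of such a class, assuming the common running-average pattern; the actual `Meter` in the source project may track more:

    class Meter:
        """Running weighted average of a scalar metric (minimal sketch)."""

        def __init__(self, name: str = 'meter'):
            self.name = name
            self.val = 0.0   # last value passed to update()
            self.sum = 0.0   # weighted sum of all values
            self.count = 0   # total weight (number of samples)
            self.avg = 0.0   # running weighted average

        def update(self, val: float, n: int = 1):
            self.val = val
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count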
Example #2
    def solve_epochs_delta(
        self,
        round_i,
        model: Module,
        global_state: STATE_TYPE,
        num_epochs,
        hide_output: bool = False
    ) -> Tuple[Dict[str, Union[int, Meter]], Dict[str, torch.Tensor]]:
        loss_meter = Meter()
        dice_meter = Meter('dice_coeff')
        num_all_samples = 0
        optimizer = self.create_optimizer(model)
        optimizer.step_lr_scheduler(round_i=round_i)
        model.train()

        with tqdm.trange(num_epochs, disable=hide_output) as t:

            for epoch in t:
                t.set_description(
                    f'Client: {self.id}, Round: {round_i}, Epoch: {epoch}')
                for batch_idx, (X, y) in enumerate(self.dataset_loader):
                    X, y = X.to(self.device), y.to(self.device)

                    optimizer.zero_grad()
                    pred = model(X)

                    loss = self.criterion(pred, y)
                    # Apply a sigmoid so the Dice coefficient is computed on probabilities
                    activated = torch.sigmoid(pred)
                    dice_coeff = self.compute_dice_coefficient(activated, y)
                    loss.backward()
                    # Optional gradient clipping, e.g. torch.nn.utils.clip_grad_norm_(model.parameters(), 60)
                    optimizer.step()
                    num_samples = y.size(0)
                    num_all_samples += num_samples
                    loss_meter.update(loss.item(), n=num_samples)
                    dice_meter.update(dice_coeff.item(), n=num_samples)
                    if batch_idx % 10 == 0:
                        # Plain scalar; show the batch-mean loss on the progress bar
                        t.set_postfix(mean_loss=loss.item())
        # Collect the results to return
        state_dict = model.state_dict()
        result = {
            'loss_meter': loss_meter,
            'dice_coeff_meter': dice_meter,
            'num_samples': num_all_samples,
            'lr': optimizer.get_current_lr()
        }
        # Compute the delta in place: latest - init
        for k, v in state_dict.items():
            v.sub_(global_state[k])
        # Return the metrics and the state delta
        return result, state_dict
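
On the server side, the `(result, state_dict)` pairs returned here lend themselves to FedAvg-style aggregation, weighting each client's delta by its `num_samples`. A minimal sketch of such an aggregation step, which is not part of this example; `aggregate_deltas` and `client_results` are hypothetical names, and floating-point state tensors are assumed:

    import torch

    def aggregate_deltas(global_state, client_results):
        # client_results: list of (result, delta) pairs from solve_epochs_delta
        total = sum(res['num_samples'] for res, _ in client_results)
        with torch.no_grad():
            for key in global_state:
                # Weighted average of the client deltas, applied to the global state
                update = sum(delta[key] * (res['num_samples'] / total)
                             for res, delta in client_results)
                global_state[key].add_(update)
        return global_state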
Example #3
    def solve_epochs_with_global(
        self,
        round_i,
        model: Module,
        global_model: Module,
        num_epochs,
        hide_output: bool = False
    ) -> Tuple[Dict[str, Union[int, Meter]], Dict[str, torch.Tensor]]:
        loss_meter = Meter()
        acc_meter = Meter()
        num_all_samples = 0
        optimizer = self.create_optimizer(model)
        # TODO: reference the previous global model directly to avoid copying;
        # the optimizer's step() does not record gradients either
        optimizer.set_old_weights(old_weights=list(global_model.parameters()))

        model.train()

        with tqdm.trange(num_epochs, disable=hide_output) as t:

            for epoch in t:
                t.set_description(f'Client: {self.id}, Round: {round_i}, Epoch: {epoch}')
                for batch_idx, (X, y) in enumerate(self.dataset_loader):
                    X, y = X.to(self.device), y.to(self.device)

                    optimizer.zero_grad()
                    pred = model(X)

                    loss = self.criterion(pred, y)
                    loss.backward()
                    # Optional gradient clipping, e.g. torch.nn.utils.clip_grad_norm_(model.parameters(), 60)
                    optimizer.step()

                    correct_sum = self.count_correct(pred, y)
                    num_samples = y.size(0)
                    num_all_samples += num_samples
                    loss_meter.update(loss.item(), n=num_samples)
                    acc_meter.update(correct_sum.item() / num_samples, n=num_samples)
                    if batch_idx % 10 == 0:
                        # Plain scalar; show the batch-mean loss on the progress bar
                        t.set_postfix(mean_loss=loss.item())
        # Collect the results to return
        result = {
            'loss_meter': loss_meter,
            'acc_meter': acc_meter,
            'num_samples': num_all_samples
        }
        state_dict = model.state_dict()
        # Return the metrics and the model state
        return result, state_dict
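
The `set_old_weights` call suggests an optimizer that pulls local weights toward the last global model inside `step()`, in the style of FedProx. A minimal sketch of such an optimizer, assuming a proximal coefficient `mu`; the class name and the exact update rule in the source project may differ:

    import torch
    from torch.optim import SGD

    class ProxSGD(SGD):
        """SGD whose step adds a proximal term mu * (w - w_global) to each gradient."""

        def __init__(self, params, lr, mu=0.01, **kwargs):
            super().__init__(params, lr=lr, **kwargs)
            self.mu = mu
            self.old_weights = None

        def set_old_weights(self, old_weights):
            # Store references only; no copy is made and no gradients are recorded in step()
            self.old_weights = old_weights

        def step(self, closure=None):
            if self.old_weights is not None:
                params = (p for group in self.param_groups for p in group['params'])
                with torch.no_grad():
                    for p, p_old in zip(params, self.old_weights):
                        if p.grad is not None:
                            p.grad.add_(p - p_old, alpha=self.mu)
            return super().step(closure)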
Example #4
    def evaluate_epoch(self, epoch):
        self.logger.show_nl("Epoch: [{0}]".format(epoch))
        losses = Meter()
        len_eval = len(self.eval_loader)
        width = len(str(len_eval))
        start_pattern = "[{{:>{0}}}/{{:>{0}}}]".format(width)
        pb = tqdm(self.eval_loader)

        # Construct metrics
        metrics = (Precision(mode='accum'), Recall(mode='accum'),
                   F1Score(mode='accum'), Accuracy(mode='accum'))

        self.model.eval()

        with torch.no_grad():
            for i, (name, t1, t2, tar) in enumerate(pb):
                t1, t2, tar = self._prepare_data(t1, t2, tar)
                batch_size = tar.shape[0]

                fetch_dict = self._set_fetch_dict()
                out_dict = FeatureContainer()

                with HookHelper(self.model,
                                fetch_dict,
                                out_dict,
                                hook_type='forward_out'):
                    out = self.model(t1, t2)

                pred = self._process_model_out(out)

                loss = self.criterion(pred, tar)
                losses.update(loss.item(), n=batch_size)

                # Convert to numpy arrays
                prob = self._pred_to_prob(pred)
                prob = prob.cpu().numpy()
                cm = (prob > 0.5).astype('uint8')
                tar = tar.cpu().numpy().astype('uint8')

                for m in metrics:
                    m.update(cm, tar, n=batch_size)

                desc = (start_pattern + " Loss: {:.4f} ({:.4f})").format(
                    i + 1, len_eval, losses.val, losses.avg)
                for m in metrics:
                    desc += " {} {:.4f}".format(m.__name__, m.val)

                pb.set_description(desc)
                dump = not self.is_training or (i % max(1, len_eval // 10) == 0)
                if dump:
                    self.logger.dump(desc)

                if self.tb_on:
                    if dump:
                        for j in range(batch_size):
                            t1_, t2_ = to_array(t1[j]), to_array(t2[j])
                            t1_, t2_ = self._denorm_image(
                                t1_), self._denorm_image(t2_)
                            t1_, t2_ = self._process_input_pairs(t1_, t2_)
                            self.tb_writer.add_image("Eval/t1",
                                                     t1_,
                                                     self.eval_step,
                                                     dataformats='HWC')
                            self.tb_writer.add_image("Eval/t2",
                                                     t2_,
                                                     self.eval_step,
                                                     dataformats='HWC')
                            self.tb_writer.add_image("Eval/labels",
                                                     quantize(tar[j]),
                                                     self.eval_step,
                                                     dataformats='HW')
                            self.tb_writer.add_image("Eval/prob",
                                                     to_pseudo_color(
                                                         quantize(prob[j])),
                                                     self.eval_step,
                                                     dataformats='HWC')
                            self.tb_writer.add_image("Eval/cm",
                                                     quantize(cm[j]),
                                                     self.eval_step,
                                                     dataformats='HW')
                            for key, feats in out_dict.items():
                                for idx, feat in enumerate(feats):
                                    feat = self._process_fetched_feat(feat[j])
                                    self.tb_writer.add_image(
                                        f"Eval/{key}_{idx}",
                                        feat,
                                        self.eval_step,
                                        dataformats='HWC')
                            self.eval_step += 1
                    else:
                        self.eval_step += batch_size

                if self.save:
                    for j in range(batch_size):
                        self.save_image(name[j], quantize(cm[j]), epoch)

        if self.tb_on:
            self.tb_writer.add_scalar("Eval/loss", losses.avg, self.eval_step)
            for m in metrics:
                self.tb_writer.add_scalar(f"Eval/{m.__name__.lower()}", m.val,
                                          self.eval_step)
            self.tb_writer.flush()

        return metrics[2].val  # F1-score
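
The `Precision`, `Recall`, `F1Score`, and `Accuracy` objects are built with `mode='accum'` and expose `update(cm, tar, n=...)`, `.val`, and `__name__`, so they appear to accumulate confusion counts across batches rather than averaging per-batch scores. A minimal sketch of one such metric for binary maps; the source project's implementation may differ:

    import numpy as np

    class F1Score:
        __name__ = 'F1Score'

        def __init__(self, mode='accum'):
            assert mode == 'accum'
            self.tp = self.fp = self.fn = 0

        def update(self, cm, tar, n=1):
            # cm and tar are binary uint8 arrays; accumulate confusion counts
            cm, tar = cm.astype(bool), tar.astype(bool)
            self.tp += np.count_nonzero(cm & tar)
            self.fp += np.count_nonzero(cm & ~tar)
            self.fn += np.count_nonzero(~cm & tar)

        @property
        def val(self):
            prec = self.tp / max(self.tp + self.fp, 1)
            rec = self.tp / max(self.tp + self.fn, 1)
            return 2 * prec * rec / max(prec + rec, 1e-8)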
Example #5
    def train_epoch(self, epoch):
        losses = Meter()
        len_train = len(self.train_loader)
        width = len(str(len_train))
        start_pattern = "[{{:>{0}}}/{{:>{0}}}]".format(width)
        pb = tqdm(self.train_loader)

        self.model.train()

        for i, (t1, t2, tar) in enumerate(pb):
            t1, t2, tar = self._prepare_data(t1, t2, tar)

            show_imgs_on_tb = self.tb_on and (i % self.tb_intvl == 0)

            fetch_dict = self._set_fetch_dict()
            out_dict = FeatureContainer()

            with HookHelper(self.model,
                            fetch_dict,
                            out_dict,
                            hook_type='forward_out'):
                out = self.model(t1, t2)

            pred = self._process_model_out(out)

            loss = self.criterion(pred, tar)
            losses.update(loss.item(), n=tar.shape[0])

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            desc = (start_pattern + " Loss: {:.4f} ({:.4f})").format(
                i + 1, len_train, losses.val, losses.avg)

            pb.set_description(desc)
            if i % max(1, len_train // 10) == 0:
                self.logger.dump(desc)

            if self.tb_on:
                # Write to tensorboard
                self.tb_writer.add_scalar("Train/running_loss", losses.val,
                                          self.train_step)
                if show_imgs_on_tb:
                    t1, t2 = to_array(t1[0]), to_array(t2[0])
                    t1, t2 = self._denorm_image(t1), self._denorm_image(t2)
                    t1, t2 = self._process_input_pairs(t1, t2)
                    self.tb_writer.add_image("Train/t1_picked",
                                             t1,
                                             self.train_step,
                                             dataformats='HWC')
                    self.tb_writer.add_image("Train/t2_picked",
                                             t2,
                                             self.train_step,
                                             dataformats='HWC')
                    self.tb_writer.add_image("Train/labels_picked",
                                             to_array(tar[0]),
                                             self.train_step,
                                             dataformats='HW')
                    for key, feats in out_dict.items():
                        for idx, feat in enumerate(feats):
                            feat = self._process_fetched_feat(feat)
                            self.tb_writer.add_image(f"Train/{key}_{idx}",
                                                     feat,
                                                     self.train_step,
                                                     dataformats='HWC')
                    self.tb_writer.flush()
                self.train_step += 1

        if self.tb_on:
            self.tb_writer.add_scalar("Train/loss", losses.avg,
                                      self.train_step)
            self.tb_writer.add_scalar("Train/lr", self.lr, self.train_step)