    def _valid_epoch(self,
                     eval_dataloader,
                     epoch_idx,
                     batches_seen=None,
                     loss_func=None):
        """
        完成模型一个轮次的评估

        Args:
            eval_dataloader: 评估数据
            epoch_idx: 轮次数
            batches_seen: 全局batch数
            loss_func: 损失函数

        Returns:
            float: 评估数据的平均损失值
        """
        with torch.no_grad():
            self.model.eval()
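            # fall back to the model's built-in loss when none is supplied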
            loss_func = loss_func if loss_func is not None else self.model.calculate_loss
            losses = []
            for batch in eval_dataloader:
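                # move the batch's tensors to the model's device, then compute the loss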
                batch.to_tensor(self.device)
                loss = loss_func(batch)
                self._logger.debug(loss.item())
                losses.append(loss.item())
            mean_loss = np.mean(losses)
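            # log the mean eval loss to TensorBoard, indexed by the global batch count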
            self._writer.add_scalar('eval loss', mean_loss, batches_seen)
            return mean_loss
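
    # A minimal sketch of how _valid_epoch is typically driven (hypothetical
    # caller code; `best_val_loss` and the checkpointing step are placeholders,
    # not part of this class):
    #
    #     val_loss = self._valid_epoch(eval_dataloader, epoch_idx,
    #                                  batches_seen=batches_seen)
    #     if val_loss < best_val_loss:
    #         best_val_loss = val_loss
    #         # save a checkpoint here (executor-specific)
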
    def _train_epoch(self, train_dataloader, epoch_idx, loss_func=None):
        """
        完成模型一个轮次的训练

        Args:
            train_dataloader: 训练数据
            epoch_idx: 轮次数
            loss_func: 损失函数

        Returns:
            list: 每个batch的损失的数组
        """
        self.model.train()
        loss_func = loss_func if loss_func is not None else self.model.calculate_loss
        losses = []
        for batch in train_dataloader:
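            # clear gradients accumulated from the previous batch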
            self.optimizer.zero_grad()
            batch.to_tensor(self.device)
            loss = loss_func(batch)
            self._logger.debug(loss.item())
            losses.append(loss.item())
            loss.backward()
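            # optionally clip the gradient norm to max_grad_norm for stability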
            if self.clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.max_grad_norm)
            self.optimizer.step()
        return losses
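
    # Note: the variant below re-defines _train_epoch with node-subset
    # (sub-graph) training; the two versions presumably belong to different
    # executor classes. A rough sketch of a full training loop combining the
    # routines above (hypothetical driver code; `max_epoch` is a placeholder):
    #
    #     for epoch_idx in range(max_epoch):
    #         train_losses = self._train_epoch(train_dataloader, epoch_idx)
    #         val_loss = self._valid_epoch(eval_dataloader, epoch_idx)
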
    def _train_epoch(self,
                     train_dataloader,
                     epoch_idx,
                     batches_seen=None,
                     loss_func=None):
        """
        完成模型一个轮次的训练

        Args:
            train_dataloader: 训练数据
            epoch_idx: 轮次数
            batches_seen: 全局batch数
            loss_func: 损失函数

        Returns:
            tuple: tuple contains
                losses(list): 每个batch的损失的数组 \n
                batches_seen(int): 全局batch数
        """
        self.model.train()
        loss_func = loss_func if loss_func is not None else self.model.calculate_loss
        losses = []
        for iter_, batch in enumerate(train_dataloader):
            batch.to_tensor(self.device)
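            # re-shuffle the node ordering every step_size2 iterations so the
            # node subsets change over the course of training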
            if iter_ % self.step_size2 == 0:
                perm = np.random.permutation(self.num_nodes)
            num_sub = self.num_nodes // self.num_split
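            # train on each of the num_split node groups in turn; the last
            # group absorbs any remainder nodes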
            for j in range(self.num_split):
                if j != self.num_split - 1:
                    idx = perm[j * num_sub:(j + 1) * num_sub]
                else:
                    idx = perm[j * num_sub:]
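                # each split is an independent optimization step on subset `idx`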
                self.optimizer.zero_grad()
                loss = loss_func(batch, idx=idx, batches_seen=batches_seen)
                self._logger.debug(loss.item())
                losses.append(loss.item())
                batches_seen += 1
                loss.backward()
                if self.clip_grad_norm:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.max_grad_norm)
                self.optimizer.step()
        return losses, batches_seen
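
    # Usage sketch for the sub-graph variant (hypothetical driver code;
    # `max_epoch` is a placeholder): `batches_seen` threads through the epochs
    # so the loss function and TensorBoard logging see the global batch count.
    #
    #     batches_seen = 0
    #     for epoch_idx in range(max_epoch):
    #         losses, batches_seen = self._train_epoch(
    #             train_dataloader, epoch_idx, batches_seen=batches_seen)
    #         val_loss = self._valid_epoch(eval_dataloader, epoch_idx,
    #                                      batches_seen=batches_seen)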