    def evaluate_one_epoch(self, epoch_count):
        self.current_loss = None
        self.stat_dict = {}  # reset statistics for this evaluation pass
        self.net.eval()  # set model to eval mode (affects BatchNorm and Dropout)
        iou_calc = IoUCalculator(self.config)
        for batch_idx, batch_data in enumerate(self.test_dataloader):
            t_start = time.time()
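            # Move every tensor in the batch to the GPU (some entries are lists of tensors)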
            for key in batch_data:
                if isinstance(batch_data[key], list):
                    for i in range(len(batch_data[key])):
                        batch_data[key][i] = batch_data[key][i].cuda()
                else:
                    batch_data[key] = batch_data[key].cuda()

            xyz = batch_data['xyz']  # (batch,N,3)
            neigh_idx = batch_data['neigh_idx']  # (batch,N,16)
            sub_idx = batch_data['sub_idx']  # (batch,N/4,16)
            interp_idx = batch_data['interp_idx']  # (batch,N,1)
            features = batch_data['features']  # (batch, 3, N)
            labels = batch_data['labels']  # (batch, N)
            input_inds = batch_data['input_inds']  # (batch, N)
            cloud_inds = batch_data['cloud_inds']  # (batch, 1)

            # Forward pass
            with torch.no_grad():
                self.out = self.net(xyz, neigh_idx, sub_idx, interp_idx,
                                    features, labels, input_inds, cloud_inds)

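            # compute_loss returns the loss together with the logits/labels restricted to valid (non-ignored) points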
            self.loss, self.end_points['valid_logits'], self.end_points[
                'valid_labels'] = compute_loss(self.out, labels, self.config)
            self.end_points['loss'] = self.loss
            # self.writer.add_scalar('eval loss', self.loss, (epoch_count* len(self.test_dataloader) + batch_idx))
            self.acc = compute_acc(self.end_points['valid_logits'],
                                   self.end_points['valid_labels'])
            self.end_points['acc'] = self.acc
            # self.writer.add_scalar('eval acc', self.acc, (epoch_count* len(self.test_dataloader) + batch_idx))
            iou_calc.add_data(self.end_points['valid_logits'],
                              self.end_points['valid_labels'])

            # Accumulate statistics and print out
            for key in self.end_points:
                if 'loss' in key or 'acc' in key or 'iou' in key:
                    if key not in self.stat_dict:
                        self.stat_dict[key] = 0
                    self.stat_dict[key] += self.end_points[key].item()

            t_end = time.time()

            batch_interval = 10
            if (batch_idx + 1) % batch_interval == 0:
                log_out(
                    ' ----step %08d batch: %08d ----' %
                    (epoch_count * len(self.test_dataloader) + batch_idx + 1,
                     (batch_idx + 1)), self.f_out)

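        # End-of-epoch summary: log the mean of each tracked statistic over all evaluation batches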
        for key in sorted(self.stat_dict.keys()):
            log_out(
                'mean %s: %f---%f ms' %
                (key, self.stat_dict[key] / float(batch_idx + 1), 1000 *
                 (t_end - t_start)), self.f_out)
            self.writer.add_scalar(
                'eval mean {}'.format(key),
                self.stat_dict[key] / (float(batch_idx + 1)),
                (epoch_count * len(self.test_dataloader)))
        mean_iou, iou_list = iou_calc.compute_iou()
        self.writer.add_scalar('eval mean iou', mean_iou,
                               (epoch_count * len(self.test_dataloader)))
        log_out('eval mean IoU:{:.1f}'.format(mean_iou * 100), self.f_out)
        s = 'eval IoU:'
        for iou_tmp in iou_list:
            s += '{:5.2f} '.format(100 * iou_tmp)
        log_out(s, self.f_out)
        self.writer.flush()  # flush instead of close so later epochs keep writing to this writer

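        # return the mean evaluation loss over the epoch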
        current_loss = self.stat_dict['loss'] / (float(batch_idx + 1))
        return current_loss

    def train_one_epoch(self, epoch_count):
        self.stat_dict = {}  # collect statistics
        self.adjust_learning_rate(epoch_count)
        self.net.train()  # set model to training mode
        iou_calc = IoUCalculator(self.config)
        for batch_idx, batch_data in enumerate(self.train_dataloader):
            t_start = time.time()
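            # Move every tensor in the batch to the GPU (some entries are lists of tensors)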
            for key in batch_data:
                if isinstance(batch_data[key], list):
                    for i in range(len(batch_data[key])):
                        batch_data[key][i] = batch_data[key][i].cuda()
                else:
                    batch_data[key] = batch_data[key].cuda()

            xyz = batch_data['xyz']  # (batch,N,3)
            neigh_idx = batch_data['neigh_idx']  # (batch,N,16)
            sub_idx = batch_data['sub_idx']  # (batch,N/4,16)
            interp_idx = batch_data['interp_idx']  # (batch,N,1)
            features = batch_data['features']  # (batch, 3, N)
            labels = batch_data['labels']  # (batch, N)
            input_inds = batch_data['input_inds']  # (batch, N)
            cloud_inds = batch_data['cloud_inds']  # (batch, 1)

            # Forward pass
            self.optimizer.zero_grad()
            self.out = self.net(xyz, neigh_idx, sub_idx, interp_idx, features,
                                labels, input_inds, cloud_inds)

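            # compute_loss returns the loss together with the logits/labels restricted to valid (non-ignored) points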
            self.loss, self.end_points['valid_logits'], self.end_points[
                'valid_labels'] = compute_loss(self.out, labels, self.config)
            self.end_points['loss'] = self.loss
            # self.writer.add_graph(self.net, input_to_model=[xyz, neigh_idx, sub_idx, interp_idx, features, labels, input_inds, cloud_inds])
            self.writer.add_scalar(
                'training loss', self.loss,
                (epoch_count * len(self.train_dataloader) + batch_idx))

            self.loss.backward()
            self.optimizer.step()

            self.acc = compute_acc(self.end_points['valid_logits'],
                                   self.end_points['valid_labels'])
            self.end_points['acc'] = self.acc
            self.writer.add_scalar(
                'training accuracy', self.acc,
                (epoch_count * len(self.train_dataloader) + batch_idx))
            iou_calc.add_data(self.end_points['valid_logits'],
                              self.end_points['valid_labels'])

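            # Accumulate statistics and print out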
            for key in self.end_points:
                if 'loss' in key or 'acc' in key or 'iou' in key:
                    if key not in self.stat_dict:
                        self.stat_dict[key] = 0
                    self.stat_dict[key] += self.end_points[key].item()
            t_end = time.time()

            batch_interval = 10
            if (batch_idx + 1) % batch_interval == 0:
                log_out(
                    ' ----step %08d batch: %08d ----' %
                    (epoch_count * len(self.train_dataloader) + batch_idx + 1,
                     (batch_idx + 1)), self.f_out)
                for key in sorted(self.stat_dict.keys()):
                    log_out(
                        'mean %s: %f---%f ms' %
                        (key, self.stat_dict[key] / batch_interval, 1000 *
                         (t_end - t_start)), self.f_out)
                    self.writer.add_scalar(
                        'training mean {}'.format(key),
                        self.stat_dict[key] / batch_interval,
                        (epoch_count * len(self.train_dataloader) + batch_idx))
                    self.stat_dict[key] = 0

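            # Log parameter and gradient histograms for this step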
            for name, param in self.net.named_parameters():
                if param.grad is not None:  # frozen / unused parameters have no gradient
                    self.writer.add_histogram(
                        name + '_grad', param.grad,
                        (epoch_count * len(self.train_dataloader) + batch_idx))
                self.writer.add_histogram(
                    name + '_data', param,
                    (epoch_count * len(self.train_dataloader) + batch_idx))
        mean_iou, iou_list = iou_calc.compute_iou()
        self.writer.add_scalar('training mean iou', mean_iou,
                               (epoch_count * len(self.train_dataloader)))
        log_out('training mean IoU:{:.1f}'.format(mean_iou * 100), self.f_out)
        s = 'training IoU:'
        for iou_tmp in iou_list:
            s += '{:5.2f} '.format(100 * iou_tmp)
        log_out(s, self.f_out)
        self.writer.flush()  # flush instead of close so evaluation and later epochs keep writing to this writer