Code Example #1
File: run.py  Project: NMADALI97/Pytorch-PointCNN
def evaluate(data_loader,
             net: nn.Module,
             calc_confusion_matrix=False,
             rtn_features=False,
             html_path="training_output"):
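    """Run one evaluation pass over ``data_loader`` with ``net`` in eval mode.

    Reports loss and top-1/top-5 accuracy; for segmentation it also tracks
    balanced accuracy and shape IoU, and writes a small random subset of
    ground-truth/predicted point clouds as Plotly HTML files under ``html_path``.
    """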

    if config.task == "seg":
        Iou_meter = meter.AverageValueMeter()
        avg_acc_meter = meter.AverageValueMeter()
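        # Randomly pick ~5% of the validation clouds; those will be exported
        # below as interactive HTML scatter plots (ground truth and prediction).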
        N = len(data_loader.dataset)
        n_sample = int(0.05 * N)
        idx_samples = set(
            np.random.choice(np.arange(N), size=n_sample, replace=False))

        if not os.path.exists(html_path):
            os.makedirs(html_path)

    criterion = nn.CrossEntropyLoss()

    loss_meter = meter.AverageValueMeter()
    batch_time = meter.TimeMeter(True)
    epoch_time = meter.TimeMeter(True)
    acc_meter = meter.ClassErrorMeter(topk=[1, 5], accuracy=True)
    all_features = []
    all_labels = []
    num_classes = 10
    confusion_matrix_meter = None
    print("calc_confusion_matrix", calc_confusion_matrix)
    if calc_confusion_matrix:
        confusion_matrix_meter = meter.ConfusionMeter(num_classes,
                                                      normalized=True)

    net.eval()
    for i, sample in enumerate(data_loader):
        batch_data = sample[0]
        batch_labels = sample[1]
        if config.task == "seg":
            data_num = sample[2]
            tmp_set = set(
                np.arange(config.validation.batch_size * i,
                          (config.validation.batch_size * i) +
                          batch_data.size(0)))
            tmp_set = list(idx_samples.intersection(tmp_set))
        batch_time.reset()

        batch_data = batch_data.to(config.device)
        batch_labels = batch_labels.to(config.device)

        if rtn_features:
            raw_out, return_intermediate = net.forward(batch_data, True)
            all_features.append(
                return_intermediate.view(-1, 192).detach().cpu().numpy())
            all_labels.extend(batch_labels.cpu().numpy().reshape(-1).tolist())
        else:
            raw_out = net.forward(batch_data)

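        # Classification: the network emits one logit vector per sampled point,
        # so the shape label is repeated across points before computing the loss.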
        if config.task == "cls":
            sample_num = raw_out.shape[1]
            raw_out = raw_out.view(-1, raw_out.shape[-1])
            batch_labels = batch_labels.view(-1, 1).repeat(
                1, sample_num).view(-1).long()
            loss = criterion(raw_out, batch_labels)
            if confusion_matrix_meter is not None:
                confusion_matrix_meter.add(raw_out.cpu(), target=batch_labels)
        elif config.task == "seg":
            pred_choice = raw_out.data.max(2)[1]
            xyz_points = batch_data.cpu().numpy()
            if xyz_points.shape[-1] > 3:
                xyz_points = xyz_points[:, :, :3]
            seg_label_pred = pred_choice.cpu().numpy()
            seg_label_gt = batch_labels.cpu().numpy()
            if len(tmp_set) > 0:
                all_idx = [u % config.validation.batch_size for u in tmp_set]
                for kk, idx in enumerate(all_idx):

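                    # Write the ground-truth labelling as an interactive 3D scatter.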
                    x, y, z = xyz_points[idx].T
                    rgb = seg_label_gt[idx]
                    fig = go.Figure(data=[
                        go.Scatter3d(x=x,
                                     y=y,
                                     z=z,
                                     mode='markers',
                                     marker=dict(size=2,
                                                 color=rgb,
                                                 colorscale='Viridis',
                                                 opacity=0.8))
                    ])
                    fig.write_html(
                        os.path.join(html_path,
                                     "file" + str(tmp_set[kk]) + "_gt.html"))

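                    # Write the predicted labelling next to the ground-truth file.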
                    x, y, z = xyz_points[idx].T
                    rgb = seg_label_pred[idx]

                    fig = go.Figure(data=[
                        go.Scatter3d(x=x,
                                     y=y,
                                     z=z,
                                     mode='markers',
                                     marker=dict(size=2,
                                                 color=rgb,
                                                 colorscale='Viridis',
                                                 opacity=0.8))
                    ])
                    fig.write_html(
                        os.path.join(html_path,
                                     "file" + str(tmp_set[kk]) + "_pred.html"))

            raw_out = raw_out.view(-1, raw_out.shape[-1])
            loss = criterion(raw_out, batch_labels.view(-1).long())

        loss_meter.add(loss.item())
        acc_meter.add(raw_out.detach(), batch_labels.view(-1).long().detach())

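        # Segmentation metrics: per-batch balanced accuracy and mean shape IoU.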
        if config.task == "seg":
            seg_np = batch_labels.cpu().numpy()
            pred_np = pred_choice.detach().cpu().numpy()

            avg_acc_meter.add(
                metrics.balanced_accuracy_score(seg_np.reshape(-1),
                                                pred_np.reshape(-1)))
            Iou_meter.add(np.mean(calculate_shape_IoU(pred_np, seg_np, None)))

        if i % config.print_freq == 0:
            print('[{}/{}]\t'.format(i, len(data_loader)) +
                  'Batch Time %.1f\t' % batch_time.value() +
                  'Epoch Time %.1f\t' % epoch_time.value() +
                  'Loss %.4f\t' % loss_meter.value()[0] +
                  'acc(c) %.3f' % acc_meter.value(1))

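    # Compose the return value: loss and accuracy, plus the confusion matrix
    # and/or intermediate features when requested.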
    if config.task == "cls":
        print('[ Validation summary ] category acc: {}'.format(
            acc_meter.value(1)))

        if calc_confusion_matrix and not rtn_features:
            rst = loss_meter.value()[0], acc_meter.value(
                1), confusion_matrix_meter.value()

        elif rtn_features:
            if calc_confusion_matrix:
                rst = loss_meter.value()[0], acc_meter.value(
                    1), confusion_matrix_meter.value(), np.concatenate(
                        all_features, axis=0), np.array(all_labels).reshape(-1)
            else:
                rst = loss_meter.value()[0], acc_meter.value(
                    1), np.concatenate(
                        all_features, axis=0), np.array(all_labels).reshape(-1)
        else:
            rst = loss_meter.value()[0], acc_meter.value(1)

    else:
        train_acc = acc_meter.value(1)
        avg_per_class_acc = avg_acc_meter.value()[0]

        train_ious = Iou_meter.value()[0]

        outstr = '[ Validation summary ] loss: %.6f, train acc: %.6f, train avg acc: %.6f, train iou: %.6f' % (
            loss_meter.value()[0], train_acc, avg_per_class_acc,
            np.mean(train_ious))
        print(outstr)
        rst = loss_meter.value()[0], train_acc, avg_per_class_acc, np.mean(
            train_ious)

    return rst
Code Example #2
File: run.py  Project: NMADALI97/Pytorch-PointCNN
def train_epoch(data_loader, net: nn.Module, criterion, optimizer, epoch):
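    """Train ``net`` for one epoch, accumulating loss and accuracy
    (plus balanced accuracy and shape IoU for the segmentation task)."""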
    global global_step
    if config.task == "seg":
        Iou_meter = meter.AverageValueMeter()
        avg_acc_meter = meter.AverageValueMeter()
    batch_time = meter.TimeMeter(True)
    epoch_time = meter.TimeMeter(True)
    loss_meter = meter.AverageValueMeter()
    acc_meter = meter.ClassErrorMeter(topk=[1], accuracy=True)

    net.train(True)

    #################################for loop#################################
    for i, sample in enumerate(data_loader):

        batch_time.reset()

        batch_data = sample[0]
        batch_labels = sample[1]
        if config.task == "seg":
            data_num = sample[2]

        batch_data = batch_data.to(config.device)
        batch_labels = batch_labels.to(config.device)

        raw_out = net.forward(batch_data)

        if config.task == "cls":
            sample_num = raw_out.shape[1]
            raw_out = raw_out.view(-1, raw_out.shape[-1])
            batch_labels = batch_labels.view(-1, 1).repeat(
                1, sample_num).view(-1).long()
            loss = criterion(raw_out, batch_labels)
        elif config.task == "seg":
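            # Per-point predicted label: argmax over the class dimension.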
            pred_choice = raw_out.data.max(2)[1]

            raw_out = raw_out.view(-1, raw_out.shape[-1])
            loss = criterion(raw_out, batch_labels.view(-1).long())

        loss_meter.add(loss.item())
        acc_meter.add(raw_out.detach(), batch_labels.view(-1).long().detach())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

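        # Accumulate segmentation metrics on this batch.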
        if config.task == "seg":
            seg_np = batch_labels.cpu().numpy()
            pred_np = pred_choice.detach().cpu().numpy()

            avg_acc_meter.add(
                metrics.balanced_accuracy_score(seg_np.reshape(-1),
                                                pred_np.reshape(-1)))
            Iou_meter.add(np.mean(calculate_shape_IoU(pred_np, seg_np, None)))

        if i % config.print_freq == 0:
            print('Epoch: [{}][{}/{}]\t'.format(epoch, i, len(data_loader)) +
                  'Batch Time %.1f\t' % batch_time.value() +
                  'Epoch Time %.1f\t' % epoch_time.value() +
                  'Loss %.4f\t' % loss_meter.value()[0] +
                  'Acc(c) %.3f' % acc_meter.value(1))

        global_step = global_step + 1
    #################################for loop#################################
    if config.task == "cls":
        print('[ TRAIN summary ] epoch {}:\n'.format(epoch) +
              'category acc: {}'.format(acc_meter.value(1)))
        return loss_meter.value()[0], acc_meter.value(1)
    else:

        train_acc = acc_meter.value(1)
        avg_per_class_acc = avg_acc_meter.value()[0]

        train_ious = Iou_meter.value()[0]

        outstr = 'Train %d, loss: %.6f, train acc: %.6f, train avg acc: %.6f, train iou: %.6f' % (
            epoch, loss_meter.value()[0], train_acc, avg_per_class_acc,
            np.mean(train_ious))
        print(outstr)
        return loss_meter.value()[0], train_acc, avg_per_class_acc, np.mean(
            train_ious)
Code Example #3
def train_epoch(data_loader, net: nn.Module, criterion, optimizer, epoch):
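    """Train ``net`` for one epoch with step-wise learning-rate decay and
    on-the-fly point sampling and augmentation."""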
    global global_step
    batch_time = meter.TimeMeter(True)
    epoch_time = meter.TimeMeter(True)
    loss_meter = meter.AverageValueMeter()
    acc_meter = meter.ClassErrorMeter(topk=[1], accuracy=True)

    net.train(True)

    #################################for loop#################################
    for i, sample in enumerate(data_loader):

        #Adjust the lr dynamically:
        lr = config.train.learning_rate_base * (math.pow(
            config.train.decay_rate, global_step // config.train.decay_steps))
        if lr < config.train.learning_rate_min:
            lr = config.train.learning_rate_min
        for g in optimizer.param_groups:
            g['lr'] = lr

        batch_time.reset()
        loss_meter.reset()

        batch_data = sample[0]
        batch_labels = sample[1]
        if config.task == "seg":
            data_num = sample[2]
        batch_time.reset()
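        # Sample a fixed number of points per cloud: pf.get_indices yields
        # (batch, point) pairs that are flattened into row indices for the
        # flattened view of batch_data.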
        if config.task == "cls":
            shape = batch_data.shape
            indices = pf.get_indices(shape[0],
                                     config.dataset_setting["sample_num"],
                                     shape[1])
            indices = indices.reshape(indices.size // 2, 2)
            indices = indices[:, 0] * batch_data.shape[1] + indices[:, 1]
            indices = indices.astype(int)
            pts_fts_sampled = batch_data.view(
                -1, batch_data.shape[-1])[indices].view(
                    shape[0], config.dataset_setting["sample_num"], -1)
        else:
            shape = batch_data.shape
            indices = pf.get_indices(shape[0],
                                     config.dataset_setting["sample_num"],
                                     data_num)
            indices = indices.reshape(indices.size // 2, 2)
            indices = indices[:, 0] * batch_data.shape[1] + indices[:, 1]
            indices = indices.astype(int)
            pts_fts_sampled = batch_data.view(
                -1, batch_data.shape[-1])[indices].view(
                    shape[0], config.dataset_setting["sample_num"], -1)
            batch_labels = batch_labels.view(-1, 1)[indices].view(
                shape[0], config.dataset_setting["sample_num"])
        features_augmented = None

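        # Build per-batch rotation/scaling transforms; points get jittered and
        # normal features (if present) are rotated with the same transforms.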
        xforms, rotations = pf.get_xforms(
            config.train.batch_size, config.dataset_setting["rotation_range"],
            config.dataset_setting["scaling_range"],
            config.dataset_setting["rotation_order"])
        if config.dataset_setting["data_dim"] > 3:
            points_sampled = pts_fts_sampled[:, :, :3]
            features_sampled = pts_fts_sampled[:, :, 3:]
            if config.dataset_setting["use_extra_features"]:
                if config.dataset_setting["with_normal_feature"]:
                    if config.dataset_setting["data_dim"] < 6:
                        print('Only 3D normals are supported!')
                        # exit()
                    elif config.dataset_setting["data_dim"] == 6:
                        features_augmented = pf.augment(
                            features_sampled, rotations)
                    else:
                        normals = features_sampled[:, :, :3]
                        rest = features_sampled[:, :, 3:]
                        normals_augmented = pf.augment(normals, rotations)
                        features_augmented = torch.cat(
                            (normals_augmented, rest), dim=-1)
                else:
                    features_augmented = features_sampled
        else:
            points_sampled = pts_fts_sampled

        jitter_range = config.dataset_setting["jitter"]
        points_augmented = pf.augment(points_sampled, xforms, jitter_range)

        if (features_augmented is None):
            batch_data = points_augmented
        else:
            batch_data = torch.cat((points_augmented, features_augmented),
                                   dim=-1)

        batch_data = batch_data.to(config.device)
        batch_labels = batch_labels.to(config.device)

        raw_out = net.forward(batch_data)
        if config.task == "cls":
            sample_num = raw_out.shape[1]
            raw_out = raw_out.view(-1, raw_out.shape[-1])
            batch_labels = batch_labels.view(-1, 1).repeat(
                1, sample_num).view(-1).long()
            loss = criterion(raw_out, batch_labels)
        elif config.task == "seg":
            raw_out = raw_out.view(-1, raw_out.shape[-1])
            batch_labels = batch_labels.view(-1).long()
            loss = criterion(raw_out, batch_labels)
        loss_meter.add(loss.item())
        acc_meter.add(raw_out.detach(), batch_labels.detach())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % config.print_freq == 0:
            print('Epoch: [{}][{}/{}]\t'.format(epoch, i, len(data_loader)) +
                  'Batch Time %.1f\t' % batch_time.value() +
                  'Epoch Time %.1f\t' % epoch_time.value() +
                  'Loss %.4f\t' % loss_meter.value()[0] +
                  'Acc(c) %.3f' % acc_meter.value(1))

        global_step = global_step + 1
    #################################for loop#################################

    print('[ TRAIN summary ] epoch {}:\n'.format(epoch) +
          'category acc: {}'.format(acc_meter.value(1)))
Code Example #4
def evaluate(data_loader,
             net: nn.Module,
             calc_confusion_matrix=False,
             rtn_features=False):
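    """Evaluate ``net`` on ``data_loader`` with validation-time sampling and
    augmentation, reporting top-1/top-5 accuracy."""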
    batch_time = meter.TimeMeter(True)
    epoch_time = meter.TimeMeter(True)
    acc_meter = meter.ClassErrorMeter(topk=[1, 5], accuracy=True)
    all_features = []
    all_labels = []
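    # Note: all_features / all_labels are never populated in this variant.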
    num_classes = 40
    confusion_matrix_meter = None
    if calc_confusion_matrix:
        confusion_matrix_meter = meter.ConfusionMeter(num_classes,
                                                      normalized=True)

    net.eval()
    for i, sample in enumerate(data_loader):
        batch_data = sample[0]
        batch_labels = sample[1]
        if config.task == "seg":
            data_num = sample[2]
        batch_time.reset()
        if config.task == "cls":
            shape = batch_data.shape
            indices = pf.get_indices(shape[0],
                                     config.dataset_setting["sample_num"],
                                     shape[1])
            indices = indices.reshape(indices.size // 2, 2)
            indices = indices[:, 0] * batch_data.shape[1] + indices[:, 1]
            indices = indices.astype(int)
            pts_fts_sampled = batch_data.view(
                -1, batch_data.shape[-1])[indices].view(
                    shape[0], config.dataset_setting["sample_num"], -1)
        else:
            shape = batch_data.shape
            indices = pf.get_indices(shape[0],
                                     config.dataset_setting["sample_num"],
                                     data_num)
            indices = indices.reshape(indices.size // 2, 2)
            indices = indices[:, 0] * batch_data.shape[1] + indices[:, 1]
            indices = indices.astype(int)
            pts_fts_sampled = batch_data.view(
                -1, batch_data.shape[-1])[indices].view(
                    shape[0], config.dataset_setting["sample_num"], -1)
            batch_labels = batch_labels.view(-1, 1)[indices].view(
                shape[0], config.dataset_setting["sample_num"])
        features_augmented = None
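        # Validation-time transforms use the *_val rotation/scaling/jitter ranges.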
        xforms, rotations = pf.get_xforms(
            shape[0], config.dataset_setting["rotation_range_val"],
            config.dataset_setting["scaling_range_val"],
            config.dataset_setting["rotation_order"])
        if config.dataset_setting["data_dim"] > 3:
            points_sampled = pts_fts_sampled[:, :, :3]
            features_sampled = pts_fts_sampled[:, :, 3:]
            if config.dataset_setting["use_extra_features"]:
                if config.dataset_setting["with_normal_feature"]:
                    if config.dataset_setting["data_dim"] < 6:
                        print('Only 3D normals are supported!')
                        exit()
                    elif config.dataset_setting["data_dim"] == 6:
                        features_augmented = pf.augment(
                            features_sampled, rotations)
                    else:
                        normals = features_sampled[:, :, :3]
                        rest = features_sampled[:, :, 3:]
                        normals_augmented = pf.augment(normals, rotations)
                        features_augmented = torch.cat(
                            (normals_augmented, rest), dim=-1)
                else:
                    features_augmented = features_sampled
        else:
            points_sampled = pts_fts_sampled

        jitter_range_val = config.dataset_setting["jitter_val"]
        points_augmented = pf.augment(points_sampled, xforms, jitter_range_val)

        if (features_augmented is None):
            batch_data = points_augmented
        else:
            batch_data = torch.cat((points_augmented, features_augmented),
                                   dim=-1)

        batch_data = batch_data.to(config.device)
        batch_labels = batch_labels.to(config.device)

        raw_out = net(batch_data)
        final_sample_num = raw_out.shape[1]
        raw_out = raw_out.view(-1, raw_out.shape[-1])
        if config.task == "cls":
            batch_labels = batch_labels.view(-1, 1).repeat(
                1, final_sample_num).view(-1).long()
        elif config.task == "seg":
            batch_labels = batch_labels.view(-1).long()
        acc_meter.add(raw_out.detach(), batch_labels.detach())

        if confusion_matrix_meter is not None:
            confusion_matrix_meter.add(raw_out.cpu(), target=batch_labels)

        if i % config.print_freq == 0:
            print('[{}/{}]\t'.format(i, len(data_loader)) +
                  'Batch Time %.1f\t' % batch_time.value() +
                  'Epoch Time %.1f\t' % epoch_time.value() +
                  'acc(c) %.3f' % acc_meter.value(1))
    print('[ summary ]:\n' + 'classification: {}\t'.format(acc_meter.value(1)))
    rst = acc_meter.value(1)
    if calc_confusion_matrix:
        rst = rst, confusion_matrix_meter.value()
    if rtn_features:
        # Guard: all_features is never filled above; np.concatenate would fail on an empty list.
        features = np.concatenate(all_features, axis=0) if all_features else np.empty((0,))
        rst = rst, features, all_labels
    return rst
Code Example #5
File: run.py  Project: NMADALI97/PointCNN_Pytorch
def train_epoch(data_loader, net: nn.Module, criterion, optimizer, epoch):
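    """Train ``net`` for one epoch: decay the learning rate per step, sample and
    augment the input point clouds, then optimise with ``criterion``."""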
    global global_step
    if config.task == "seg":
        Iou_meter = meter.AverageValueMeter()
        avg_acc_meter = meter.AverageValueMeter()
    batch_time = meter.TimeMeter(True)
    epoch_time = meter.TimeMeter(True)
    loss_meter = meter.AverageValueMeter()
    acc_meter = meter.ClassErrorMeter(topk=[1], accuracy=True)

    net.train(True)

    #################################for loop#################################
    for i, sample in enumerate(data_loader):

        #Adjust the lr dynamically:
        lr = config.train.learning_rate_base * (math.pow(
            config.train.decay_rate, global_step // config.train.decay_steps))
        if lr < config.train.learning_rate_min:
            lr = config.train.learning_rate_min
        for g in optimizer.param_groups:
            g['lr'] = lr

        batch_time.reset()

        batch_data = sample[0]
        batch_labels = sample[1]
        if config.task == "seg":
            data_num = sample[2]
        batch_time.reset()
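        # Sample a fixed number of points per cloud (segmentation uses the
        # per-shape point counts in data_num).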
        if config.task == "cls":
            shape = batch_data.shape
            indices = pf.get_indices(shape[0],
                                     config.dataset_setting["sample_num"],
                                     shape[1])
            indices = indices.reshape(indices.size // 2, 2)
            indices = indices[:, 0] * batch_data.shape[1] + indices[:, 1]
            indices = indices.astype(int)
            pts_fts_sampled = batch_data.view(
                -1, batch_data.shape[-1])[indices].view(
                    shape[0], config.dataset_setting["sample_num"], -1)
        else:
            shape = batch_data.shape
            indices = pf.get_indices(shape[0],
                                     config.dataset_setting["sample_num"],
                                     data_num)
            indices = indices.reshape(indices.size // 2, 2)
            indices = indices[:, 0] * batch_data.shape[1] + indices[:, 1]
            indices = indices.astype(int)
            pts_fts_sampled = batch_data.view(
                -1, batch_data.shape[-1])[indices].view(
                    shape[0], config.dataset_setting["sample_num"], -1)
            batch_labels = batch_labels.view(-1, 1)[indices].view(
                shape[0], config.dataset_setting["sample_num"])
        features_augmented = None

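        # Rotation/scaling transforms for this batch; points are jittered and
        # normals rotated consistently below.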
        xforms, rotations = pf.get_xforms(
            config.train.batch_size, config.dataset_setting["rotation_range"],
            config.dataset_setting["scaling_range"],
            config.dataset_setting["rotation_order"])
        if config.dataset_setting["data_dim"] > 3:
            points_sampled = pts_fts_sampled[:, :, :3]
            features_sampled = pts_fts_sampled[:, :, 3:]
            if config.dataset_setting["use_extra_features"]:
                if config.dataset_setting["with_normal_feature"]:
                    if config.dataset_setting["data_dim"] < 6:
                        print('Only 3D normals are supported!')
                        # exit()
                    elif config.dataset_setting["data_dim"] == 6:
                        features_augmented = pf.augment(
                            features_sampled, rotations)
                    else:
                        normals = features_sampled[:, :, :3]
                        rest = features_sampled[:, :, 3:]
                        normals_augmented = pf.augment(normals, rotations)
                        features_augmented = torch.cat(
                            (normals_augmented, rest), dim=-1)
                else:
                    features_augmented = features_sampled
        else:
            points_sampled = pts_fts_sampled

        jitter_range = config.dataset_setting["jitter"]
        points_augmented = pf.augment(points_sampled, xforms, jitter_range)

        if (features_augmented is None):
            batch_data = points_augmented
        else:
            batch_data = torch.cat((points_augmented, features_augmented),
                                   dim=-1)

        batch_data = batch_data.to(config.device)
        batch_labels = batch_labels.to(config.device)

        raw_out = net.forward(batch_data)

        if config.task == "cls":
            sample_num = raw_out.shape[1]
            raw_out = raw_out.view(-1, raw_out.shape[-1])
            batch_labels = batch_labels.view(-1, 1).repeat(
                1, sample_num).view(-1).long()
            loss = criterion(raw_out, batch_labels)
        elif config.task == "seg":
            pred_choice = raw_out.data.max(2)[1]

            raw_out = raw_out.view(-1, raw_out.shape[-1])
            loss = criterion(raw_out, batch_labels.view(-1).long())

        loss_meter.add(loss.item())
        acc_meter.add(raw_out.detach(), batch_labels.view(-1).long().detach())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if config.task == "seg":
            seg_np = batch_labels.cpu().numpy()
            pred_np = pred_choice.detach().cpu().numpy()

            avg_acc_meter.add(
                metrics.balanced_accuracy_score(seg_np.reshape(-1),
                                                pred_np.reshape(-1)))
            Iou_meter.add(np.mean(calculate_shape_IoU(pred_np, seg_np, None)))

        if i % config.print_freq == 0:
            print('Epoch: [{}][{}/{}]\t'.format(epoch, i, len(data_loader)) +
                  'Batch Time %.1f\t' % batch_time.value() +
                  'Epoch Time %.1f\t' % epoch_time.value() +
                  'Loss %.4f\t' % loss_meter.value()[0] +
                  'Acc(c) %.3f' % acc_meter.value(1))

        global_step = global_step + 1
    #################################for loop#################################
    if config.task == "cls":
        print('[ TRAIN summary ] epoch {}:\n'.format(epoch) +
              'category acc: {}'.format(acc_meter.value(1)))
        return loss_meter.value()[0], acc_meter.value(1)
    else:

        train_acc = acc_meter.value(1)
        avg_per_class_acc = avg_acc_meter.value()[0]

        train_ious = Iou_meter.value()[0]

        outstr = 'Train %d, loss: %.6f, train acc: %.6f, train avg acc: %.6f, train iou: %.6f' % (
            epoch, loss_meter.value()[0], train_acc, avg_per_class_acc,
            np.mean(train_ious))
        print(outstr)
        return loss_meter.value()[0], train_acc, avg_per_class_acc, np.mean(
            train_ious)