Example #1
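# Training entry point for point-cloud semantic segmentation on SemanticKITTI.
# The snippet assumes module-level imports and globals defined elsewhere in the
# original script (e.g. os, math, datetime, logging, shutil, numpy as np, torch,
# torch.nn as nn, torch.nn.parallel as parallel, MinkowskiEngine as ME, tqdm,
# Path, Variable, provider, KittiDataset, FusionNet, iouEval, input_layer,
# parse_args, nPlanes, scale, spatialSize, num_classes, ignore, class_strings,
# class_inv_remap).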
def main(args):
    def log_string(msg):
        logger.info(msg)
        print(msg)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    experiment_dir = Path('./log/')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('part_seg')
    experiment_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    root = '/media/feihu/Storage/kitti_point_cloud/semantic_kitti/'
    file_list = '/media/feihu/Storage/kitti_point_cloud/semantic_kitti/train2.list'
    val_list = '/media/feihu/Storage/kitti_point_cloud/semantic_kitti/val2.list'
    TRAIN_DATASET = KittiDataset(root=root,
                                 file_list=file_list,
                                 npoints=args.npoint,
                                 training=True,
                                 augment=True)
    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  drop_last=True,
                                                  num_workers=2)
    TEST_DATASET = KittiDataset(root=root,
                                file_list=val_list,
                                npoints=args.npoint,
                                training=False,
                                augment=False)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 drop_last=True,
                                                 num_workers=2)
    log_string("The number of training data is: %d" % len(TRAIN_DATASET))
    log_string("The number of test data is: %d" % len(TEST_DATASET))
    #    num_classes = 16
    '''MODEL LOADING'''

    shutil.copy('models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('models/pointnet_util.py', str(experiment_dir))

    num_devices = args.num_gpus  #torch.cuda.device_count()
    #    assert num_devices > 1, "Cannot detect more than 1 GPU."
    #    print(num_devices)
    devices = list(range(num_devices))
    target_device = devices[0]
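    # Model replicas are created on devices[0..num_gpus-1] manually with
    # torch.nn.parallel; target_device collects the gathered loss.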

    #    MODEL = importlib.import_module(args.model)

    net = FusionNet(args.npoint, 4, 20, nPlanes)

    #    net = MODEL.get_model(num_classes, normal_channel=args.normal)
    net = net.to(target_device)

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if m.weight is not None:
                torch.nn.init.xavier_normal_(m.weight.data)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)
        elif classname.find('Linear') != -1:
            if m.weight is not None:
                torch.nn.init.xavier_normal_(m.weight.data)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)

    try:
        checkpoint = torch.load(
            str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['model_state_dict'])
        log_string('Using pretrained model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0
        net = net.apply(weights_init)

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(net.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=1e-1,
                                    momentum=0.9,
                                    weight_decay=1e-4,
                                    nesterov=True)
#        optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9)

    def bn_momentum_adjust(m, momentum):
        if isinstance(m, torch.nn.BatchNorm2d) or isinstance(
                m, torch.nn.BatchNorm1d):
            m.momentum = momentum

    LEARNING_RATE_CLIP = 1e-5
    MOMENTUM_ORIGINAL = 0.1
    MOMENTUM_DECAY = 0.5
    MOMENTUM_DECAY_STEP = 20 / 2  # args.step_size

    best_acc = 0
    global_epoch = 0
    best_class_avg_iou = 0
    best_instance_avg_iou = 0

    #    criterion = MODEL.get_loss()
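    # One CrossEntropyLoss replica per device so the losses can be computed in parallel.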
    criterion = nn.CrossEntropyLoss()
    criterions = parallel.replicate(criterion, devices)

    # The raw version of the parallel_apply
    #    replicas = parallel.replicate(net, devices)
    #    input_coding = scn.InputLayer(dimension, torch.LongTensor(spatialSize), mode=4)

    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))
        '''Adjust learning rate and BN momentum'''

        #        lr = max(args.learning_rate * (args.lr_decay ** (epoch // args.step_size)), LEARNING_RATE_CLIP)
        #        lr = args.learning_rate * \
        #            math.exp((1 - epoch) * args.lr_decay)

        #        log_string('Learning rate:%f' % lr)

        #        for param_group in optimizer.param_groups:
        #            param_group['lr'] = lr
        #        for param_group in optimizer.param_groups:
        #            param_group['lr'] = lr

        mean_correct = []
        momentum = MOMENTUM_ORIGINAL * (
            MOMENTUM_DECAY**(epoch // MOMENTUM_DECAY_STEP))
        momentum = max(momentum, 0.01)
        print('BN momentum updated to: %f' % momentum)
        net = net.apply(lambda x: bn_momentum_adjust(x, momentum))
        '''learning one epoch'''
        net.train()

        #        for iteration, data in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9):
        for iteration, data in enumerate(trainDataLoader):
            # Adjust the learning rate every 320 iterations with exponential decay.
            if iteration % 320 == 0:
                lr_count = epoch * 6 + iteration / 320
                lr = args.learning_rate * math.exp(
                    (1 - lr_count) * args.lr_decay)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                log_string('Learning rate:%f' % lr)

            optimizer.zero_grad()
            if iteration > 1920:
                break
            points, target, ins, mask = data
            #            print(torch.max(points[:, :, :3], 1)[0])
            #            print(torch.min(points[:, :, :3], 1)[0])

            valid = mask > 0
            total_points = valid.sum()
            orgs = points
            points = points.data.numpy()
            #            print(total_points)
            inputs, targets, masks = [], [], []
            coords = []
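            # Split the batch evenly across the GPUs: voxelize each slice with
            # provider.transform_for_sparse and build a sparse tensor plus the
            # original xyz/features for that device.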
            for i in range(num_devices):
                start = int(i * (args.batch_size / num_devices))
                end = int((i + 1) * (args.batch_size / num_devices))
                batch = provider.transform_for_sparse(
                    points[start:end, :, :3], points[start:end, :, 3:],
                    target[start:end, :].data.numpy(),
                    mask[start:end, :].data.numpy(), scale, spatialSize)
                batch['x'][1] = batch['x'][1].type(torch.FloatTensor)
                batch['x'][0] = batch['x'][0].type(torch.IntTensor)
                batch['y'] = batch['y'].type(torch.LongTensor)

                org_xyz = orgs[start:end, :, :3].transpose(1, 2).contiguous()
                org_feas = orgs[start:end, :, 3:].transpose(1, 2).contiguous()

                label = Variable(batch['y'], requires_grad=False)
                maski = batch['mask'].type(torch.IntTensor)
                #                print(torch.max(batch['x'][0], 0)[0])
                #                print(torch.min(batch['x'][0], 0)[0])
                #                locs, feas = input_layer(batch['x'][0].to(devices[i]), batch['x'][1].to(devices[i]))
                locs, feas = input_layer(batch['x'][0].cuda(),
                                         batch['x'][1].cuda())
                #                print(locs.size(), feas.size(), batch['x'][0].size())

                #               print(inputi.size(), batch['x'][1].size())

                with torch.cuda.device(devices[i]):
                    org_coords = batch['x'][0].to(devices[i])
                    inputi = ME.SparseTensor(feas.cpu(), locs).to(
                        devices[i])  #input_coding(batch['x'])
                    org_xyz = org_xyz.to(devices[i])
                    org_feas = org_feas.to(devices[i])
                    maski = maski.to(devices[i])
                    inputs.append(
                        [inputi, org_coords, org_xyz, org_feas, maski])
                    targets.append(label.to(devices[i]))
#                    masks.append(maski.contiguous().to(devices[i]))
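            # Manual data parallelism: replicate the model onto every device and
            # run each replica on its own input slice.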

            replicas = parallel.replicate(net, devices)
            predictions = parallel.parallel_apply(replicas,
                                                  inputs,
                                                  devices=devices)

            count = 0
            #            print("end ...")
            results = []
            labels = []
            match = 0
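            # Collect per-device predictions, drop invalid points (target == 0)
            # and accumulate point-wise accuracy.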

            for i in range(num_devices):
                #               temp = predictions[i]['output1'].F#.view(-1, num_classes)
                temp = predictions[i]
                #                temp = output_layer(locs, predictions[i]['output1'].F, coords[i])
                temp = temp[targets[i] > 0, :]
                results.append(temp)

                temp = targets[i]
                temp = temp[targets[i] > 0]
                labels.append(temp)
                #               print(prediction2[i].size(), prediction1[i].size(), targets[i].size())
                outputi = results[
                    i]  #prediction2[i].contiguous().view(-1, num_classes)
                num_points = labels[i].size(0)
                count += num_points

                _, pred_choice = outputi.data.max(1)  #[1]
                #                print(pred_choice)
                correct = pred_choice.eq(labels[i].data).cpu().sum()
                match += correct.item()
                mean_correct.append(correct.item() / num_points)
#            print(prediction2, labels)
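            # Compute the loss on every device in parallel, then gather and
            # average the results before the backward pass.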
            losses = parallel.parallel_apply(criterions,
                                             tuple(zip(results, labels)),
                                             devices=devices)
            loss = parallel.gather(losses, target_device, dim=0).mean()
            loss.backward()
            optimizer.step()
            #            assert(count1 == count2 and total_points == count1)
            log_string(
                "===> Epoch[{}]({}/{}) Valid points:{}/{} Loss: {:.4f} Accuracy: {:.4f}"
                .format(epoch, iteration, len(trainDataLoader), count,
                        total_points, loss.item(), match / count))
#            sys.stdout.flush()
        train_instance_acc = np.mean(mean_correct)
        log_string('Train accuracy is: %.5f' % train_instance_acc)

        #        continue
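        # Validation: run the same multi-GPU forward pass without gradients and
        # accumulate accuracy/IoU statistics with the iouEval helper.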

        with torch.no_grad():
            net.eval()
            evaluator = iouEval(num_classes, ignore)

            evaluator.reset()
            for iteration, (points, target, ins,
                            mask) in tqdm(enumerate(testDataLoader),
                                          total=len(testDataLoader),
                                          smoothing=0.9):
                cur_batch_size, NUM_POINT, _ = points.size()
                #                points, label, target, mask = points.float().cuda(), label.long().cuda(), target.long().cuda(), mask.float().cuda()
                if iteration > 192:
                    break
                if 0:
                    points = points.data.numpy()
                    points[:, :, 0:3], norm = provider.pc_normalize(
                        points[:, :, :3], mask.data.numpy())
                    points = torch.Tensor(points)
                orgs = points
                points = points.data.numpy()
                inputs, targets, masks = [], [], []
                coords = []
                for i in range(num_devices):
                    start = int(i * (cur_batch_size / num_devices))
                    end = int((i + 1) * (cur_batch_size / num_devices))
                    batch = provider.transform_for_test(
                        points[start:end, :, :3], points[start:end, :, 3:],
                        target[start:end, :].data.numpy(),
                        mask[start:end, :].data.numpy(), scale, spatialSize)
                    batch['x'][1] = batch['x'][1].type(torch.FloatTensor)
                    batch['x'][0] = batch['x'][0].type(torch.IntTensor)
                    batch['y'] = batch['y'].type(torch.LongTensor)

                    org_xyz = orgs[start:end, :, :3].transpose(1,
                                                               2).contiguous()
                    org_feas = orgs[start:end, :,
                                    3:].transpose(1, 2).contiguous()

                    label = Variable(batch['y'], requires_grad=False)
                    maski = batch['mask'].type(torch.IntTensor)
                    locs, feas = input_layer(batch['x'][0].cuda(),
                                             batch['x'][1].cuda())
                    #                print(locs.size(), feas.size(), batch['x'][0].size())

                    #               print(inputi.size(), batch['x'][1].size())
                    with torch.cuda.device(devices[i]):
                        org_coords = batch['x'][0].to(devices[i])
                        inputi = ME.SparseTensor(feas.cpu(), locs).to(
                            devices[i])  #input_coding(batch['x'])
                        org_xyz = org_xyz.to(devices[i])
                        org_feas = org_feas.to(devices[i])
                        maski = maski.to(devices[i])
                        inputs.append(
                            [inputi, org_coords, org_xyz, org_feas, maski])
                        targets.append(label.to(devices[i]))
#                        masks.append(maski.contiguous().to(devices[i]))

                replicas = parallel.replicate(net, devices)
                outputs = parallel.parallel_apply(replicas,
                                                  inputs,
                                                  devices=devices)

                #                net = net.eval()
                #                seg_pred = classifier(points, to_categorical(label, num_classes))
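                # Merge per-device predictions and targets back on the CPU.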
                seg_pred = outputs[0].cpu()
                #                mask = masks[0].cpu()
                target = targets[0].cpu()
                loc = locs[0].cpu()
                for i in range(1, num_devices):
                    seg_pred = torch.cat((seg_pred, outputs[i].cpu()), 0)
                    #                    mask = torch.cat((mask, masks[i].cpu()), 0)
                    target = torch.cat((target, targets[i].cpu()), 0)

                seg_pred = seg_pred[target > 0, :]
                target = target[target > 0]
                _, seg_pred = seg_pred.data.max(1)  #[1]

                target = target.data.numpy()

                evaluator.addBatch(seg_pred, target)

            # When the loop is done, print the overall evaluation.
            m_accuracy = evaluator.getacc()
            m_jaccard, class_jaccard = evaluator.getIoU()

            log_string('Validation set:\n'
                       'Acc avg {m_accuracy:.3f}\n'
                       'IoU avg {m_jaccard:.3f}'.format(m_accuracy=m_accuracy,
                                                        m_jaccard=m_jaccard))
            # print also classwise
            for i, jacc in enumerate(class_jaccard):
                if i not in ignore:
                    log_string(
                        'IoU class {i:} [{class_str:}] = {jacc:.3f}'.format(
                            i=i,
                            class_str=class_strings[class_inv_remap[i]],
                            jacc=jacc))

        log_string('Epoch %d test accuracy: %f  mean IoU: %f' %
                   (epoch + 1, m_accuracy, m_jaccard))
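        # Save a checkpoint whenever the validation mIoU matches or beats the best so far.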
        if (m_jaccard >= best_class_avg_iou):
            #            logger.info('Save model...')
            log_string('Saving model...')
            savepath = str(checkpoints_dir) + '/best_model.pth'
            log_string('Saving at %s' % savepath)
            state = {
                'epoch': epoch,
                'train_acc': train_instance_acc,
                'test_acc': m_accuracy,
                'class_avg_iou': m_jaccard,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(state, savepath)


        if m_accuracy > best_acc:
            best_acc = m_accuracy
        if m_jaccard > best_class_avg_iou:
            best_class_avg_iou = m_jaccard

        log_string('Best accuracy is: %.5f' % best_acc)
        log_string('Best class avg mIOU is: %.5f' % best_class_avg_iou)

        global_epoch += 1
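
# A minimal, hypothetical sketch of the argument parser used above, inferred
# only from the attributes accessed on `args` in this example; the option
# names come from the code, but the defaults are assumptions and not the
# original script's parse_args().
def parse_args():
    import argparse
    parser = argparse.ArgumentParser('training')
    parser.add_argument('--model', type=str, default='fusion_net')  # assumed default
    parser.add_argument('--gpu', type=str, default='0')
    parser.add_argument('--num_gpus', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--npoint', type=int, default=8192)  # assumed default
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--optimizer', type=str, default='Adam')
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--lr_decay', type=float, default=0.5)
    parser.add_argument('--decay_rate', type=float, default=1e-4)
    parser.add_argument('--log_dir', type=str, default=None)
    return parser.parse_args()
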
Example #2
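# Evaluation-only variant: loads the best checkpoint and reports accuracy/IoU
# on the validation split. It assumes the same module-level globals as
# Example #1, with UNet in place of FusionNet and an input_layer that also
# returns voxel labels, masks and offsets.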
def main(args):
    def log_string(msg):
        #        logger.info(msg)
        print(msg)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    experiment_dir = Path('./log/')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('part_seg')
    experiment_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    root = '/media/feihu/Storage/kitti_point_cloud/semantic_kitti/'
    #    file_list = '/media/feihu/Storage/kitti_point_cloud/semantic_kitti/train2.list'
    val_list = '/media/feihu/Storage/kitti_point_cloud/semantic_kitti/val2.list'
    #    TRAIN_DATASET = KittiDataset(root = root, file_list=file_list, npoints=args.npoint, training=True, augment=True)
    #    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=2)
    TEST_DATASET = KittiDataset(root=root,
                                file_list=val_list,
                                npoints=args.npoint,
                                training=False,
                                augment=False)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 drop_last=True,
                                                 num_workers=2)
    #    log_string("The number of training data is: %d" % len(TRAIN_DATASET))
    log_string("The number of test data is: %d" % len(TEST_DATASET))
    #    num_classes = 16

    num_devices = args.num_gpus  #torch.cuda.device_count()
    #    assert num_devices > 1, "Cannot detect more than 1 GPU."
    #    print(num_devices)
    devices = list(range(num_devices))
    target_device = devices[0]

    #    MODEL = importlib.import_module(args.model)

    net = UNet(4, 20, nPlanes)

    #    net = MODEL.get_model(num_classes, normal_channel=args.normal)
    net = net.to(target_device)

    try:
        checkpoint = torch.load(
            str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['model_state_dict'])
        log_string('Using pretrained model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        quit()

    if 1:

        with torch.no_grad():
            net.eval()
            evaluator = iouEval(num_classes, ignore)

            evaluator.reset()
            #            for iteration, (points, target, ins, mask) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
            for iteration, (points, target, ins,
                            mask) in enumerate(testDataLoader):
                evaone = iouEval(num_classes, ignore)
                evaone.reset()
                cur_batch_size, NUM_POINT, _ = points.size()

                if iteration > 128:
                    break

                inputs, targets, masks = [], [], []
                coords = []
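                # Split the batch across the GPUs and voxelize each slice directly
                # on its device; here input_layer also returns the voxel labels,
                # mask and offsets.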
                for i in range(num_devices):
                    start = int(i * (cur_batch_size / num_devices))
                    end = int((i + 1) * (cur_batch_size / num_devices))
                    with torch.cuda.device(devices[i]):
                        pc = points[start:end, :, :].to(devices[i])
                        #feas = points[start:end,:,3:].to(devices[i])
                        targeti = target[start:end, :].to(devices[i])
                        maski = mask[start:end, :].to(devices[i])

                        locs, feas, label, maski, offsets = input_layer(
                            pc, targeti, maski, scale.to(devices[i]),
                            spatialSize.to(devices[i]), True)
                        #                        print(locs.size(), feas.size(), label.size(), maski.size(), offsets.size())
                        org_coords = locs[1]
                        label = Variable(label, requires_grad=False)

                        inputi = ME.SparseTensor(feas.cpu(), locs[0].cpu())
                        inputs.append([inputi.to(devices[i]), org_coords])
                        targets.append(label)
                        masks.append(maski)

                replicas = parallel.replicate(net, devices)
                outputs = parallel.parallel_apply(replicas,
                                                  inputs,
                                                  devices=devices)
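                # Gather per-device outputs, masks and targets on the CPU, keep
                # only valid points and take the argmax as the predicted label.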

                seg_pred = outputs[0].cpu()
                mask = masks[0].cpu()
                target = targets[0].cpu()
                loc = locs[0].cpu()
                for i in range(1, num_devices):
                    seg_pred = torch.cat((seg_pred, outputs[i].cpu()), 0)
                    mask = torch.cat((mask, masks[i].cpu()), 0)
                    target = torch.cat((target, targets[i].cpu()), 0)

                seg_pred = seg_pred[target > 0, :]
                target = target[target > 0]
                _, seg_pred = seg_pred.data.max(1)  #[1]

                target = target.data.numpy()

                evaluator.addBatch(seg_pred, target)
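                # Also compute per-scan accuracy/IoU for quick inspection.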

                evaone.addBatch(seg_pred, target)
                cur_accuracy = evaone.getacc()
                cur_jaccard, class_jaccard = evaone.getIoU()
                print('%.4f %.4f' % (cur_accuracy, cur_jaccard))

            m_accuracy = evaluator.getacc()
            m_jaccard, class_jaccard = evaluator.getIoU()

            log_string('Validation set:\n'
                       'Acc avg {m_accuracy:.3f}\n'
                       'IoU avg {m_jaccard:.3f}'.format(m_accuracy=m_accuracy,
                                                        m_jaccard=m_jaccard))
            # print also classwise
            for i, jacc in enumerate(class_jaccard):
                if i not in ignore:
                    log_string(
                        'IoU class {i:} [{class_str:}] = {jacc:.3f}'.format(
                            i=i,
                            class_str=class_strings[class_inv_remap[i]],
                            jacc=jacc))