Ejemplo n.º 1
0
def _train_one_epoch(cfg, epoch_idx, grnet, grnet_optimizer, chamfer_dist,
                     train_data_loader, train_writer, f_record):
    """Run one optimisation pass over ``train_data_loader``.

    For every batch the sparse and dense outputs of ``grnet`` are compared
    with ``data['gtcloud']`` using ``chamfer_dist``; the sum of both terms is
    back-propagated. Per-batch losses (scaled by 1000) are written to
    ``train_writer``, to ``f_record`` and to the ``logging`` module.

    Args:
        cfg: Configuration object; only ``cfg.TRAIN.N_EPOCHS`` is read here.
        epoch_idx: 1-based index of the current epoch.
        grnet: The network being trained.
        grnet_optimizer: Optimizer stepping ``grnet``'s parameters.
        chamfer_dist: Callable loss taking ``(prediction, ground_truth)``.
        train_data_loader: Iterable yielding
            ``(taxonomy_ids, model_ids, data)`` with ``data`` a dict.
        train_writer: Tensorboard writer for the per-batch scalars.
        f_record: Open text file receiving one line per batch.

    Returns:
        The ``AverageMeter`` that accumulated ``[sparse, dense]`` losses
        (each multiplied by 1000) over the epoch.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(['SparseLoss', 'DenseLoss'])

    grnet.train()

    batch_end_time = time()
    n_batches = len(train_data_loader)
    for batch_idx, (_taxonomy_ids, _model_ids,
                    data) in enumerate(train_data_loader):
        data_time.update(time() - batch_end_time)
        for k, v in data.items():
            data[k] = utils.helpers.var_or_cuda(v)

        sparse_ptcloud, dense_ptcloud = grnet(data)
        sparse_loss = chamfer_dist(sparse_ptcloud, data['gtcloud'])
        dense_loss = chamfer_dist(dense_ptcloud, data['gtcloud'])
        _loss = sparse_loss + dense_loss

        # Losses are reported multiplied by 1000 everywhere (meter, scalars).
        sparse_value = sparse_loss.item() * 1000
        dense_value = dense_loss.item() * 1000
        losses.update([sparse_value, dense_value])

        grnet.zero_grad()
        _loss.backward()
        grnet_optimizer.step()

        n_itr = (epoch_idx - 1) * n_batches + batch_idx
        train_writer.add_scalar('Loss/Batch/Sparse', sparse_value, n_itr)
        train_writer.add_scalar('Loss/Batch/Dense', dense_value, n_itr)

        batch_time.update(time() - batch_end_time)
        batch_end_time = time()

        # Build the message once; it goes to both the record file and the log.
        batch_msg = (
            '[Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) Losses = %s'
            % (epoch_idx, cfg.TRAIN.N_EPOCHS, batch_idx + 1, n_batches,
               batch_time.val(), data_time.val(),
               ['%.4f' % l for l in losses.val()]))
        f_record.write('\n' + batch_msg)
        logging.info(batch_msg)

    return losses


def train_net(cfg):
    """Train GRNet with a sparse + dense Chamfer-distance objective.

    Builds the train/validation data loaders, the network, the Adam optimizer
    and the multi-step LR scheduler from ``cfg``, optionally loads weights
    from ``cfg.CONST.WEIGHTS``, then trains for ``cfg.TRAIN.N_EPOCHS`` epochs.
    Every ``cfg.TRAIN.SAVE_FREQ`` epochs the model is validated with
    ``test_net`` and a checkpoint is written to ``cfg.DIR.CHECKPOINTS``.

    Side effects:
        * Sets ``cfg.DIR.CHECKPOINTS`` and ``cfg.DIR.LOGS`` to timestamped
          directories below ``cfg.DIR.OUT_PATH``.
        * Writes a plain-text training record to ``<out>/txt/<ts>/record.txt``.
        * Writes tensorboard event files below ``cfg.DIR.LOGS``.

    Args:
        cfg: Configuration object providing the ``DATASET``, ``TRAIN``,
            ``CONST``, ``DIR``, ``NETWORK`` and ``TEST`` sections read below.
    """
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up data loaders; the dataset classes are selected by name from cfg.
    train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TRAIN_DATASET](cfg)
    test_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset_loader.get_dataset(
            utils.data_loaders.DatasetSubset.TRAIN),
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.CONST.NUM_WORKERS,
        collate_fn=utils.data_loaders.collate_fn,
        pin_memory=True,
        shuffle=True,
        drop_last=True)
    val_data_loader = torch.utils.data.DataLoader(
        dataset=test_dataset_loader.get_dataset(
            utils.data_loaders.DatasetSubset.VAL),
        batch_size=1,
        num_workers=cfg.CONST.NUM_WORKERS,
        collate_fn=utils.data_loaders.collate_fn,
        pin_memory=True,
        shuffle=False)

    # Set up folders for logs and checkpoints. The '%s' placeholder is filled
    # with the sub-folder kind so all outputs of a run share one timestamp.
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s',
                              datetime.now().isoformat())
    cfg.DIR.CHECKPOINTS = output_dir % 'checkpoints'
    cfg.DIR.LOGS = output_dir % 'logs'
    txt_dir = output_dir % 'txt'
    os.makedirs(txt_dir, exist_ok=True)
    os.makedirs(cfg.DIR.CHECKPOINTS, exist_ok=True)

    # Create the networks
    grnet = GRNet(cfg)
    grnet.apply(utils.helpers.init_weights)
    logging.debug('Parameters in GRNet: %d.' %
                  utils.helpers.count_parameters(grnet))

    # Move the network to GPU if possible
    if torch.cuda.is_available():
        grnet = torch.nn.DataParallel(grnet).cuda()

    # Create the optimizers
    grnet_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              grnet.parameters()),
                                       lr=cfg.TRAIN.LEARNING_RATE,
                                       weight_decay=cfg.TRAIN.WEIGHT_DECAY,
                                       betas=cfg.TRAIN.BETAS)
    grnet_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        grnet_optimizer,
        milestones=cfg.TRAIN.LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)

    # Set up loss functions. The gridding loss is instantiated but is not part
    # of the objective optimised below.
    chamfer_dist = ChamferDistance()
    gridding_loss = GriddingLoss(  # lgtm [py/unused-local-variable]
        scales=cfg.NETWORK.GRIDDING_LOSS_SCALES,
        alphas=cfg.NETWORK.GRIDDING_LOSS_ALPHAS)

    # Load pretrained model if exists
    init_epoch = 0  # starting epoch offset; not restored from the checkpoint
    best_metrics = None
    if 'WEIGHTS' in cfg.CONST:
        logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        best_metrics = Metrics(cfg.TEST.METRIC_NAME,
                               checkpoint['best_metrics'])
        grnet.load_state_dict(checkpoint['grnet'])
        logging.info(
            'Recover complete. Current epoch = #%d; best metrics = %s.' %
            (init_epoch, best_metrics))

    # Training/Testing the network. The record file and both tensorboard
    # writers are released even if training is interrupted by an exception.
    with open(os.path.join(txt_dir, 'record.txt'), 'w') as f_record:
        train_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'train'))
        val_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'test'))
        try:
            for epoch_idx in range(init_epoch + 1, cfg.TRAIN.N_EPOCHS + 1):
                epoch_start_time = time()
                losses = _train_one_epoch(cfg, epoch_idx, grnet,
                                          grnet_optimizer, chamfer_dist,
                                          train_data_loader, train_writer,
                                          f_record)
                grnet_lr_scheduler.step()
                epoch_end_time = time()

                train_writer.add_scalar('Loss/Epoch/Sparse', losses.avg(0),
                                        epoch_idx)
                train_writer.add_scalar('Loss/Epoch/Dense', losses.avg(1),
                                        epoch_idx)
                epoch_msg = (
                    '[Epoch %d/%d] EpochTime = %.3f (s) Losses = %s' %
                    (epoch_idx, cfg.TRAIN.N_EPOCHS,
                     epoch_end_time - epoch_start_time,
                     ['%.4f' % l for l in losses.avg()]))
                f_record.write('\n' + epoch_msg)
                logging.info(epoch_msg)

                # Seed the comparison baseline after the first trained epoch,
                # but only when no best metrics were restored from a
                # checkpoint, so restored metrics are not overwritten.
                metrics = None
                if best_metrics is None:
                    metrics = test_net(cfg, epoch_idx, val_data_loader,
                                       val_writer, grnet)
                    best_metrics = metrics

                if epoch_idx % cfg.TRAIN.SAVE_FREQ == 0:
                    # Reuse the baseline evaluation when it was just computed
                    # for this very epoch instead of validating twice.
                    if metrics is None:
                        metrics = test_net(cfg, epoch_idx, val_data_loader,
                                           val_writer, grnet)
                    is_best = metrics.better_than(best_metrics)
                    # The file name carries the same epoch index that is
                    # stored inside the checkpoint.
                    file_name = ('best-ckpt.pth'
                                 if is_best else 'epoch-%03d.pth' % epoch_idx)
                    output_path = os.path.join(cfg.DIR.CHECKPOINTS, file_name)
                    torch.save({
                        'epoch_index': epoch_idx,
                        'best_metrics': metrics.state_dict(),
                        'grnet': grnet.state_dict()
                    }, output_path)  # yapf: disable

                    logging.info('Saved checkpoint to %s ...' % output_path)
                    if is_best:
                        best_metrics = metrics
        finally:
            train_writer.close()
            val_writer.close()
Ejemplo n.º 2
0
def _crop_nearest_points(data, viewpoints, crop_point_num):
    """Return a float32 copy of ``data`` with a region of each cloud zeroed.

    For every sample one viewpoint is drawn at random from ``viewpoints`` and
    the ``crop_point_num`` points closest to it (squared Euclidean distance)
    are overwritten with zeros, which makes the cloud incomplete.

    Args:
        data: Point-cloud batch indexed as ``data[sample]`` -> ``(points, 3)``;
            it is not modified.
        viewpoints: Sequence of 3-element tensors on the same device as
            ``data``.
        crop_point_num: Number of nearest points to zero out per sample.

    Returns:
        A new float32 tensor with the same shape and device as ``data``.
    """
    cropped = data.detach().clone().float()
    for m in range(data.size()[0]):
        # random.sample (not random.choice) keeps the RNG consumption of the
        # previous implementation, so seeded runs pick the same viewpoints.
        p_center = random.sample(viewpoints, 1)[0]
        distances = torch.sum((data[m] - p_center)**2, dim=1)
        nearest = torch.argsort(distances)[:crop_point_num]
        cropped[m, nearest] = 0.0
    return cropped


def train_net_new(cfg):
    """Train GRNet for joint point-cloud completion and part segmentation.

    Uses the ShapeNet part dataset (class "Pistol", 2048 points per shape).
    Each batch is made incomplete by zeroing the 512 points nearest to a
    random viewpoint; the network then predicts sparse and dense completions
    plus per-point segmentation logits.

    The loss is the sum of sparse Chamfer, dense Chamfer and gridding loss.
    When the module-level flag ``train_seg`` is set, this sum is weighted by
    0.8 and the segmentation cross-entropy by 0.2. Segmentation supervision on
    the sparse output starts at epoch 5 and on the dense output at epoch 7.

    This function relies on names that are not defined inside it and must be
    available at module scope: ``seg_class_no``, ``train_seg``,
    ``save_crop_mode``, ``draw_mode``, ``get_data_seg``, ``get_seg_gts``,
    ``plot_ptcloud`` and ``test_net_new``.

    Side effects:
        * Sets ``cfg.DIR.CHECKPOINTS`` and ``cfg.DIR.LOGS``.
        * Writes tensorboard event files and a checkpoint file.
        * Calls ``sys.exit()`` after the first batch when ``save_crop_mode``
          or ``draw_mode`` is set (debug paths).

    Args:
        cfg: Configuration object providing the ``DIR``, ``TRAIN``, ``CONST``
            and ``NETWORK`` sections read below.
    """
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Set up data loader
    pnum = 2048
    crop_point_num = 512
    workers = 1
    batchSize = 16

    class_name = "Pistol"

    train_dataset_loader = shapenet_part_loader.PartDataset(
        root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
        classification=False,
        class_choice=class_name,
        npoints=pnum,
        split='train')
    train_data_loader = torch.utils.data.DataLoader(train_dataset_loader,
                                                    batch_size=batchSize,
                                                    shuffle=True,
                                                    num_workers=int(workers))

    test_dataset_loader = shapenet_part_loader.PartDataset(
        root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
        classification=False,
        class_choice=class_name,
        npoints=pnum,
        split='test')
    val_data_loader = torch.utils.data.DataLoader(test_dataset_loader,
                                                  batch_size=batchSize,
                                                  shuffle=True,
                                                  num_workers=int(workers))

    # Set up folders for logs and checkpoints
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s',
                              datetime.now().isoformat())
    cfg.DIR.CHECKPOINTS = output_dir % 'checkpoints'
    cfg.DIR.LOGS = output_dir % 'logs'
    os.makedirs(cfg.DIR.CHECKPOINTS, exist_ok=True)

    # Create the networks
    grnet = GRNet(cfg, seg_class_no)
    grnet.apply(utils.helpers.init_weights)
    logging.debug('Parameters in GRNet: %d.' %
                  utils.helpers.count_parameters(grnet))

    # Move the network to GPU if possible
    grnet = grnet.to(device)

    # Create the optimizers
    grnet_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              grnet.parameters()),
                                       lr=cfg.TRAIN.LEARNING_RATE,
                                       weight_decay=cfg.TRAIN.WEIGHT_DECAY,
                                       betas=cfg.TRAIN.BETAS)
    grnet_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        grnet_optimizer,
        milestones=cfg.TRAIN.LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)

    # Set up loss functions
    chamfer_dist = ChamferDistance()
    gridding_loss = GriddingLoss(
        scales=cfg.NETWORK.GRIDDING_LOSS_SCALES,
        alphas=cfg.NETWORK.GRIDDING_LOSS_ALPHAS)
    # Follow the selected device instead of forcing CUDA, so the function
    # also runs on CPU-only hosts like the rest of this setup.
    seg_criterion = torch.nn.CrossEntropyLoss().to(device)

    # Load pretrained model if exists
    init_epoch = 0
    best_metrics = None
    if 'WEIGHTS' in cfg.CONST:
        logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS))
        # map_location lets a checkpoint saved on GPU load on the active device.
        checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=device)
        grnet.load_state_dict(checkpoint['grnet'])
        logging.info(
            'Recover complete. Current epoch = #%d; best metrics = %s.' %
            (init_epoch, best_metrics))

    train_seg_on_sparse = False
    train_seg_on_dense = False

    miou = 0

    # Candidate viewpoints for cropping; built and moved to the device once
    # rather than on every batch.
    viewpoints = [
        torch.Tensor([1, 0, 0]).to(device),
        torch.Tensor([0, 0, 1]).to(device),
        torch.Tensor([1, 0, 1]).to(device),
        torch.Tensor([-1, 0, 0]).to(device),
        torch.Tensor([-1, 1, 0]).to(device),
    ]

    # Weight of the completion terms; (1 - lamb) weights the segmentation.
    lamb = 0.8

    # Create tensorboard writers
    train_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'train'))
    val_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'test'))

    # Training/Testing the network. The writers are closed in ``finally`` so
    # they are flushed on exceptions and on the sys.exit() debug paths too.
    try:
        for epoch_idx in range(init_epoch + 1, cfg.TRAIN.N_EPOCHS + 1):
            epoch_start_time = time()

            batch_time = AverageMeter()
            data_time = AverageMeter()
            losses = AverageMeter(['SparseLoss', 'DenseLoss'])

            grnet.train()

            # Segmentation supervision on the predicted clouds is phased in.
            if epoch_idx == 5:
                train_seg_on_sparse = True

            if epoch_idx == 7:
                train_seg_on_dense = True

            batch_end_time = time()
            n_batches = len(train_data_loader)
            for batch_idx, (data, seg,
                            model_ids) in enumerate(train_data_loader):
                data_time.update(time() - batch_end_time)

                data = data.to(device)
                seg = seg.to(device)

                # remove points to make input incomplete
                input_cropped1 = _crop_nearest_points(data, viewpoints,
                                                      crop_point_num)

                if save_crop_mode:
                    # Debug path: dump one original/cropped pair and stop.
                    np.save(class_name + "_orig",
                            data[0].detach().cpu().numpy())
                    np.save(class_name + "_cropped",
                            input_cropped1[0].detach().cpu().numpy())
                    sys.exit()

                (sparse_ptcloud, dense_ptcloud, sparse_seg, full_seg,
                 dense_seg) = grnet(input_cropped1)

                data_seg = get_data_seg(data, full_seg)
                seg_loss = seg_criterion(torch.transpose(data_seg, 1, 2), seg)
                if train_seg_on_sparse and train_seg:
                    gt_seg = get_seg_gts(seg, data, sparse_ptcloud)
                    seg_loss += seg_criterion(
                        torch.transpose(sparse_seg, 1, 2), gt_seg)
                    seg_loss /= 2

                if train_seg_on_dense and train_seg:
                    gt_seg = get_seg_gts(seg, data, dense_ptcloud)
                    dense_seg_loss = seg_criterion(
                        torch.transpose(dense_seg, 1, 2), gt_seg)
                    print(dense_seg_loss.item())

                if draw_mode:
                    # Debug path: plot one sample of each cloud and stop.
                    plot_ptcloud(data[0], seg[0], "orig")
                    plot_ptcloud(input_cropped1[0], seg[0], "cropped")
                    plot_ptcloud(sparse_ptcloud[0],
                                 torch.argmax(sparse_seg[0], dim=1),
                                 "sparse_pred")
                    if not train_seg_on_sparse:
                        gt_seg = get_seg_gts(seg, data, sparse_ptcloud)
                    print(dense_seg.size())
                    plot_ptcloud(dense_ptcloud[0],
                                 torch.argmax(dense_seg[0], dim=1),
                                 "dense_pred")
                    sys.exit()

                print(seg_loss.item())

                sparse_loss = chamfer_dist(sparse_ptcloud, data).to(device)
                dense_loss = chamfer_dist(dense_ptcloud, data).to(device)
                grid_loss = gridding_loss(sparse_ptcloud, data).to(device)
                if train_seg:
                    _loss = lamb * (sparse_loss + dense_loss +
                                    grid_loss) + (1 - lamb) * seg_loss
                else:
                    _loss = (sparse_loss + dense_loss + grid_loss)
                if train_seg_on_dense and train_seg:
                    _loss += (1 - lamb) * dense_seg_loss
                losses.update(
                    [sparse_loss.item() * 1000,
                     dense_loss.item() * 1000])

                grnet.zero_grad()
                _loss.backward()
                grnet_optimizer.step()

                n_itr = (epoch_idx - 1) * n_batches + batch_idx
                train_writer.add_scalar('Loss/Batch/Sparse',
                                        sparse_loss.item() * 1000, n_itr)
                train_writer.add_scalar('Loss/Batch/Dense',
                                        dense_loss.item() * 1000, n_itr)

                batch_time.update(time() - batch_end_time)
                batch_end_time = time()
                logging.info(
                    '[Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) Losses = %s'
                    % (epoch_idx, cfg.TRAIN.N_EPOCHS, batch_idx + 1,
                       n_batches, batch_time.val(), data_time.val(),
                       ['%.4f' % l for l in losses.val()]))

            # Validate the current model
            if train_seg:
                miou_new = test_net_new(cfg, epoch_idx, val_data_loader,
                                        val_writer, grnet)
            else:
                miou_new = 0

            grnet_lr_scheduler.step()
            epoch_end_time = time()
            train_writer.add_scalar('Loss/Epoch/Sparse', losses.avg(0),
                                    epoch_idx)
            train_writer.add_scalar('Loss/Epoch/Dense', losses.avg(1),
                                    epoch_idx)
            logging.info('[Epoch %d/%d] EpochTime = %.3f (s) Losses = %s' %
                         (epoch_idx, cfg.TRAIN.N_EPOCHS,
                          epoch_end_time - epoch_start_time,
                          ['%.4f' % l for l in losses.avg()]))

            # Without segmentation training every epoch overwrites the
            # checkpoint; with it, only an improved validation score does.
            if not train_seg or miou_new > miou:
                file_name = class_name + 'noseg-ckpt-epoch.pth'
                output_path = os.path.join(cfg.DIR.CHECKPOINTS, file_name)
                torch.save({
                    'epoch_index': epoch_idx,
                    'grnet': grnet.state_dict()
                }, output_path)  # yapf: disable

                logging.info('Saved checkpoint to %s ...' % output_path)
                miou = miou_new
    finally:
        train_writer.close()
        val_writer.close()