Example #1
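The training loop in Jittor. The imports are omitted on this page; below is a minimal sketch of what the snippet assumes (the project-local imports at the bottom are hypothetical names inferred from the code itself):

import os
import logging
from datetime import datetime
from time import time

import jittor
from jittor import nn
from tqdm import tqdm
from tensorboardX import SummaryWriter  # assumption: any SummaryWriter-compatible writer works

# Hypothetical project-local imports (names taken from the code below):
# import dataloader_jt
# from model import Model
# from loss_utils import chamfer
# from average_meter import AverageMeter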
def train_net(cfg):
    # Set up the Jittor data loaders for training and validation
    train_dataset_loader = dataloader_jt.DATASET_LOADER_MAPPING[
        cfg.DATASET.TRAIN_DATASET](cfg)
    test_dataset_loader = dataloader_jt.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)

    train_data_loader = train_dataset_loader.get_dataset(
        dataloader_jt.DatasetSubset.TRAIN,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.CONST.NUM_WORKERS,
        shuffle=True)
    val_data_loader = test_dataset_loader.get_dataset(
        dataloader_jt.DatasetSubset.VAL,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.CONST.NUM_WORKERS,
        shuffle=False)

    # Set up folders for logs and checkpoints
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s',
                              datetime.now().isoformat())
    cfg.DIR.CHECKPOINTS = output_dir % 'checkpoints'
    cfg.DIR.LOGS = output_dir % 'logs'
    if not os.path.exists(cfg.DIR.CHECKPOINTS):
        os.makedirs(cfg.DIR.CHECKPOINTS)

    # Create tensorboard writers
    train_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'train'))
    val_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'test'))
    model = Model(dataset=cfg.DATASET.TRAIN_DATASET)
    init_epoch = 0
    best_metrics = float('inf')

    optimizer = nn.Adam(model.parameters(),
                        lr=cfg.TRAIN.LEARNING_RATE,
                        weight_decay=cfg.TRAIN.WEIGHT_DECAY,
                        betas=cfg.TRAIN.BETAS)
    lr_scheduler = jittor.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=cfg.TRAIN.LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA,
        last_epoch=init_epoch)

    # Training/Testing the network
    for epoch_idx in range(init_epoch + 1, cfg.TRAIN.N_EPOCHS + 1):
        epoch_start_time = time()

        model.train()

        loss_metric = AverageMeter()
        n_batches = len(train_data_loader)
        print('epoch: ', epoch_idx, 'lr: ', lr_scheduler.get_lr())
        with tqdm(train_data_loader) as t:
            for batch_idx, (taxonomy_ids, model_ids, data) in enumerate(t):
                partial = jittor.array(data['partial_cloud'])
                gt = jittor.array(data['gtcloud'])
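                # The model returns point clouds from three refinement stages plus the per-step point displacements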
                pcds, deltas = model(partial)

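                # Chamfer distance between each stage's output and the ground-truth cloud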
                cd1 = chamfer(pcds[0], gt)
                cd2 = chamfer(pcds[1], gt)
                cd3 = chamfer(pcds[2], gt)
                loss_cd = cd1 + cd2 + cd3

                delta_losses = []
                for delta in deltas:
                    delta_losses.append(jittor.sum(delta**2))

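                # PMD regularizer: average squared point displacement over the three moving steps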
                loss_pmd = jittor.sum(jittor.stack(delta_losses)) / 3

                loss = loss_cd * cfg.TRAIN.LAMBDA_CD + loss_pmd * cfg.TRAIN.LAMBDA_PMD
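                # Jittor's optimizer.step(loss) clears gradients, runs backward, and applies the update in one call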
                optimizer.step(loss)

                loss_item = loss.item()
                loss_metric.update(loss_item)

                jittor.sync_all()

                t.set_description(
                    '[Epoch %d/%d][Batch %d/%d]' %
                    (epoch_idx, cfg.TRAIN.N_EPOCHS, batch_idx + 1, n_batches))
                t.set_postfix(loss='%.4f' % loss_item)

        lr_scheduler.step()
        epoch_end_time = time()
        train_writer.add_scalar('Loss/Epoch/loss', loss_metric.avg(),
                                epoch_idx)
        logging.info(
            '[Epoch %d/%d] EpochTime = %.3f (s) Loss = %.4f' %
            (epoch_idx, cfg.TRAIN.N_EPOCHS, epoch_end_time - epoch_start_time,
             loss_metric.avg()))

        # Validate the current model
        cd_eval = test_net(cfg, epoch_idx, val_data_loader, val_writer, model)

        # Save checkpoints
        if epoch_idx % cfg.TRAIN.SAVE_FREQ == 0 or cd_eval < best_metrics:
            file_name = 'ckpt-best.pkl' if cd_eval < best_metrics else 'ckpt-epoch-%03d.pkl' % epoch_idx
            output_path = os.path.join(cfg.DIR.CHECKPOINTS, file_name)

            model.save(output_path)

            logging.info('Saved checkpoint to %s ...' % output_path)
            if cd_eval < best_metrics:
                best_metrics = cd_eval

    train_writer.close()
    val_writer.close()
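For reference, a hypothetical sketch of the config keys train_net() reads (the real project builds cfg from its config files; every value below is a placeholder):

from easydict import EasyDict  # assumption: cfg behaves like an attribute dict

cfg = EasyDict()
cfg.DATASET = EasyDict(TRAIN_DATASET='ShapeNet', TEST_DATASET='ShapeNet')
cfg.CONST = EasyDict(NUM_WORKERS=8)
cfg.DIR = EasyDict(OUT_PATH='./output')
cfg.TRAIN = EasyDict(BATCH_SIZE=32, LEARNING_RATE=1e-3, WEIGHT_DECAY=0,
                     BETAS=(0.9, 0.999), LR_MILESTONES=[50, 100], GAMMA=0.5,
                     N_EPOCHS=150, LAMBDA_CD=1.0, LAMBDA_PMD=1e-2, SAVE_FREQ=25)

train_net(cfg)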
Example #2
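The PyTorch counterpart of Example #1. A minimal sketch of the imports it assumes (the project-local imports are hypothetical names inferred from the code; lr_lambda is an LR schedule function defined elsewhere in the repo):

import os
import logging
from datetime import datetime
from time import time

import torch
from tqdm import tqdm
from tensorboardX import SummaryWriter  # assumption: torch.utils.tensorboard's writer also works

# Hypothetical project-local imports (names taken from the code below):
# import utils.data_loaders
# import utils.helpers
# from model import Model
# from loss_utils import chamfer
# from average_meter import AverageMeter
# from utils.misc import random_subsample, lr_lambda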
def train_net(cfg):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TRAIN_DATASET](cfg)
    test_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)

    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset_loader.get_dataset(
            utils.data_loaders.DatasetSubset.TRAIN),
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.CONST.NUM_WORKERS,
        collate_fn=utils.data_loaders.collate_fn,
        pin_memory=True,
        shuffle=True,
        drop_last=True)
    val_data_loader = torch.utils.data.DataLoader(
        dataset=test_dataset_loader.get_dataset(
            utils.data_loaders.DatasetSubset.TEST),
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.CONST.NUM_WORKERS // 2,
        collate_fn=utils.data_loaders.collate_fn,
        pin_memory=True,
        shuffle=False)

    # Set up folders for logs and checkpoints
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s',
                              datetime.now().isoformat())
    cfg.DIR.CHECKPOINTS = output_dir % 'checkpoints'
    cfg.DIR.LOGS = output_dir % 'logs'
    if not os.path.exists(cfg.DIR.CHECKPOINTS):
        os.makedirs(cfg.DIR.CHECKPOINTS)

    # Create tensorboard writers
    train_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'train'))
    val_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'test'))

    model = Model(dataset=cfg.DATASET.TRAIN_DATASET)
    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model).cuda()

    # Create the optimizers
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=cfg.TRAIN.LEARNING_RATE,
                                 weight_decay=cfg.TRAIN.WEIGHT_DECAY,
                                 betas=cfg.TRAIN.BETAS)
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lr_lambda=lr_lambda)

    init_epoch = 0
    best_metrics = float('inf')

    if 'WEIGHTS' in cfg.CONST and cfg.CONST.WEIGHTS:
        logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        init_epoch = checkpoint['epoch_index']  # resume the epoch counter saved with the checkpoint
        best_metrics = checkpoint['best_metrics']
        model.load_state_dict(checkpoint['model'])
        logging.info(
            'Recover complete. Current epoch = #%d; best metrics = %s.' %
            (init_epoch, best_metrics))

    # Training/Testing the network
    for epoch_idx in range(init_epoch + 1, cfg.TRAIN.N_EPOCHS + 1):
        epoch_start_time = time()

        batch_time = AverageMeter()
        data_time = AverageMeter()

        model.train()

        total_cd1 = 0
        total_cd2 = 0
        total_cd3 = 0
        total_pmd = 0

        batch_end_time = time()
        n_batches = len(train_data_loader)
        with tqdm(train_data_loader) as t:
            for batch_idx, (taxonomy_ids, model_ids, data) in enumerate(t):
                data_time.update(time() - batch_end_time)
                for k, v in data.items():
                    data[k] = utils.helpers.var_or_cuda(v)
                partial = random_subsample(data['partial_cloud'])
                gt = random_subsample(data['gtcloud'])

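                # The model returns point clouds from three refinement stages plus the per-step point displacements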
                pcds, deltas = model(partial)

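                # Chamfer distance between each stage's output and the ground-truth cloud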
                cd1 = chamfer(pcds[0], gt)
                cd2 = chamfer(pcds[1], gt)
                cd3 = chamfer(pcds[2], gt)
                loss_cd = cd1 + cd2 + cd3

                delta_losses = []
                for delta in deltas:
                    delta_losses.append(torch.sum(delta**2))

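                # PMD regularizer: average squared point displacement over the three moving steps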
                loss_pmd = torch.sum(torch.stack(delta_losses)) / 3

                loss = loss_cd * cfg.TRAIN.LAMBDA_CD + loss_pmd * cfg.TRAIN.LAMBDA_PMD

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                cd1_item = cd1.item() * 1e3
                total_cd1 += cd1_item
                cd2_item = cd2.item() * 1e3
                total_cd2 += cd2_item
                cd3_item = cd3.item() * 1e3
                total_cd3 += cd3_item
                pmd_item = loss_pmd.item()
                total_pmd += pmd_item
                n_itr = (epoch_idx - 1) * n_batches + batch_idx
                train_writer.add_scalar('Loss/Batch/cd1', cd1_item, n_itr)
                train_writer.add_scalar('Loss/Batch/cd2', cd2_item, n_itr)
                train_writer.add_scalar('Loss/Batch/cd3', cd3_item, n_itr)
                train_writer.add_scalar('Loss/Batch/pmd', pmd_item, n_itr)
                batch_time.update(time() - batch_end_time)
                batch_end_time = time()
                t.set_description(
                    '[Epoch %d/%d][Batch %d/%d]' %
                    (epoch_idx, cfg.TRAIN.N_EPOCHS, batch_idx + 1, n_batches))
                t.set_postfix(loss='%s' % [
                    '%.4f' % l
                    for l in [cd1_item, cd2_item, cd3_item, pmd_item]
                ])

        avg_cd1 = total_cd1 / n_batches
        avg_cd2 = total_cd2 / n_batches
        avg_cd3 = total_cd3 / n_batches
        avg_pmd = total_pmd / n_batches

        lr_scheduler.step()
        epoch_end_time = time()
        train_writer.add_scalar('Loss/Epoch/cd1', avg_cd1, epoch_idx)
        train_writer.add_scalar('Loss/Epoch/cd2', avg_cd2, epoch_idx)
        train_writer.add_scalar('Loss/Epoch/cd3', avg_cd3, epoch_idx)
        train_writer.add_scalar('Loss/Epoch/pmd', avg_pmd, epoch_idx)
        logging.info(
            '[Epoch %d/%d] EpochTime = %.3f (s) Losses = %s' %
            (epoch_idx, cfg.TRAIN.N_EPOCHS, epoch_end_time - epoch_start_time,
             ['%.4f' % l for l in [avg_cd1, avg_cd2, avg_cd3, avg_pmd]]))

        # Validate the current model
        cd_eval = test_net(cfg, epoch_idx, val_data_loader, val_writer, model)

        # Save checkpoints
        if epoch_idx % cfg.TRAIN.SAVE_FREQ == 0 or cd_eval < best_metrics:
            is_best = cd_eval < best_metrics
            if is_best:
                best_metrics = cd_eval  # update before saving so the checkpoint records the new best
            file_name = 'ckpt-best.pth' if is_best else 'ckpt-epoch-%03d.pth' % epoch_idx
            output_path = os.path.join(cfg.DIR.CHECKPOINTS, file_name)
            torch.save(
                {
                    'epoch_index': epoch_idx,
                    'best_metrics': best_metrics,
                    'model': model.state_dict()
                }, output_path)

            logging.info('Saved checkpoint to %s ...' % output_path)

    train_writer.close()
    val_writer.close()
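A hedged resume sketch: pointing cfg.CONST.WEIGHTS at a saved checkpoint before the call makes the recovery branch above restore the weights, epoch counter, and best metric (the path is a placeholder):

cfg.CONST.WEIGHTS = './output/checkpoints/ckpt-best.pth'  # hypothetical path
train_net(cfg)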
Example #3
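The Jittor evaluation routine paired with Example #1. A minimal sketch of the imports it assumes (the project-local imports are hypothetical names inferred from the code):

import jittor
from tqdm import tqdm

# Hypothetical project-local imports (names taken from the code below):
# import dataloader_jt
# from model import Model
# from loss_utils import chamfer
# from average_meter import AverageMeter
# from metrics import Metrics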
def test_net(cfg,
             epoch_idx=-1,
             test_data_loader=None,
             test_writer=None,
             model=None):

    if test_data_loader is None:
        # Set up data loader
        dataset_loader = dataloader_jt.DATASET_LOADER_MAPPING[
            cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = dataset_loader.get_dataset(
            dataloader_jt.DatasetSubset.VAL, batch_size=1, shuffle=False)

    # Setup networks and initialize networks
    if model is None:
        model = Model(dataset=cfg.DATASET.TEST_DATASET)

        assert 'WEIGHTS' in cfg.CONST and cfg.CONST.WEIGHTS
        print('loading: ', cfg.CONST.WEIGHTS)
        model.load(cfg.CONST.WEIGHTS)

    # Switch models to evaluation mode
    model.eval()

    n_samples = len(test_data_loader)
    test_losses = AverageMeter(['cd1', 'cd2', 'cd3', 'pmd'])
    test_metrics = AverageMeter(Metrics.names())
    category_metrics = dict()

    # Testing loop
    with tqdm(test_data_loader) as t:
        for model_idx, (taxonomy_id, model_id, data) in enumerate(t):
            taxonomy_id = taxonomy_id[0] if isinstance(
                taxonomy_id[0], str) else taxonomy_id[0].item()
            model_id = model_id[0]

            partial = jittor.array(data['partial_cloud'])
            gt = jittor.array(data['gtcloud'])

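            # Predict the three-stage completions and the per-step displacements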
            pcds, deltas = model(partial)

            cd1 = chamfer(pcds[0], gt).item() * 1e3
            cd2 = chamfer(pcds[1], gt).item() * 1e3
            cd3 = chamfer(pcds[2], gt).item() * 1e3

            # pmd loss
            pmd_losses = []
            for delta in deltas:
                pmd_losses.append(jittor.sum(delta**2))

            pmd = jittor.sum(jittor.stack(pmd_losses)) / 3

            pmd_item = pmd.item()

            _metrics = [pmd_item, cd3]
            test_losses.update([cd1, cd2, cd3, pmd_item])

            test_metrics.update(_metrics)
            if taxonomy_id not in category_metrics:
                category_metrics[taxonomy_id] = AverageMeter(Metrics.names())
            category_metrics[taxonomy_id].update(_metrics)

            t.set_description(
                'Test[%d/%d] Taxonomy = %s Sample = %s Losses = %s Metrics = %s'
                %
                (model_idx + 1, n_samples, taxonomy_id, model_id,
                 ['%.4f' % l
                  for l in test_losses.val()], ['%.4f' % m for m in _metrics]))

    # Print testing results
    print(
        '============================ TEST RESULTS ============================'
    )
    print('Taxonomy', end='\t')
    print('#Sample', end='\t')
    for metric in test_metrics.items:
        print(metric, end='\t')
    print()

    for taxonomy_id in category_metrics:
        print(taxonomy_id, end='\t')
        print(category_metrics[taxonomy_id].count(0), end='\t')
        for value in category_metrics[taxonomy_id].avg():
            print('%.4f' % value, end='\t')
        print()

    print('Overall', end='\t\t\t')
    for value in test_metrics.avg():
        print('%.4f' % value, end='\t')
    print('\n')

    # Add testing results to TensorBoard
    if test_writer is not None:
        test_writer.add_scalar('Loss/Epoch/cd1', test_losses.avg(0), epoch_idx)
        test_writer.add_scalar('Loss/Epoch/cd2', test_losses.avg(1), epoch_idx)
        test_writer.add_scalar('Loss/Epoch/cd3', test_losses.avg(2), epoch_idx)
        test_writer.add_scalar('Loss/Epoch/delta', test_losses.avg(3),
                               epoch_idx)
        for i, metric in enumerate(test_metrics.items):
            test_writer.add_scalar('Metric/%s' % metric, test_metrics.avg(i),
                                   epoch_idx)
    model.train()
    return test_losses.avg(2)
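Called with no loader and no model, test_net() builds its own VAL loader and loads weights from cfg.CONST.WEIGHTS; a minimal standalone-evaluation sketch (the path is a placeholder):

cfg.CONST.WEIGHTS = './output/checkpoints/ckpt-best.pkl'  # hypothetical path
mean_cd3 = test_net(cfg)  # returns the average stage-3 Chamfer distance (scaled by 1e3)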