Example #1
import os

from torch.nn.parallel import DistributedDataParallel

# get_loader, build_multi_part_segmentation, load_checkpoint, validate,
# and logger are project-local helpers assumed to be importable here.

def main(config):
    val_loader, test_loader = get_loader(config)
    n_data = len(val_loader.dataset)
    logger.info(f"length of validation dataset: {n_data}")
    n_data = len(test_loader.dataset)
    logger.info(f"length of testing dataset: {n_data}")

    model, criterion = build_multi_part_segmentation(config)
    model.cuda()
    criterion.cuda()

    model = DistributedDataParallel(model,
                                    device_ids=[config.local_rank],
                                    broadcast_buffers=False)

    # optionally resume from a checkpoint
    if config.load_path:
        assert os.path.isfile(config.load_path)
        load_checkpoint(config, model)
        logger.info("==> checking loaded ckpt")
        validate('resume', 'val', val_loader, model, criterion, config)
        validate('resume', 'test', test_loader, model, criterion, config)
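This entry point assumes a torch.distributed process group is already initialized and that config.local_rank was set by the launcher before main(config) runs. A minimal launch sketch under that assumption (the flag name and config construction are illustrative, not from the original example):

import argparse

import torch
import torch.distributed as dist

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # torch.distributed.launch passes --local_rank to each worker process
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)   # bind this process to one GPU
    dist.init_process_group(backend='nccl')  # NCCL backend for multi-GPU training
    # a config object would be built here (e.g. from a YAML file) before main(config)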
Example #2

# Same imports and project-local helpers as in Example #1.
def main(config):
    test_loader = get_loader(config)
    n_data = len(test_loader.dataset)
    logger.info(f"length of testing dataset: {n_data}")

    model, criterion = build_multi_part_segmentation(config)
    model.cuda()
    criterion.cuda()

    model = DistributedDataParallel(model,
                                    device_ids=[config.local_rank],
                                    broadcast_buffers=False)

    # load the checkpoint to be evaluated (required for this eval-only entry point)
    assert os.path.isfile(config.load_path)
    load_checkpoint(config, model)
    logger.info("==> checking loaded ckpt")
    # evaluate with an increasing number of test-time votes (1 to 4)
    for i in range(4):
        acc, msIoU, mIoU = validate(f'V{i + 1}',
                                    test_loader,
                                    model,
                                    criterion,
                                    config,
                                    num_votes=i + 1)
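The loop above evaluates with 1 to 4 test-time votes. The repository's validate is not shown here; a common voting implementation (a sketch under that assumption, with vote_predict and augment as hypothetical names) averages class probabilities over several randomly augmented forward passes:

import torch

def vote_predict(model, points, num_votes, augment):
    # Hypothetical helper: average softmax probabilities over `num_votes`
    # randomly augmented forward passes (test-time augmentation).
    model.eval()
    probs = None
    with torch.no_grad():
        for _ in range(num_votes):
            logits = model(augment(points))
            p = torch.softmax(logits, dim=-1)
            probs = p if probs is None else probs + p
    return probs / num_votes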
Example #3
import os
import time

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter

# get_loader, build_multi_part_segmentation, get_scheduler, load_checkpoint,
# save_checkpoint, train, validate, and logger are project-local helpers
# assumed to be importable here.

def main(config):
    global best_acc
    global best_epoch
    train_loader, val_loader, test_loader = get_loader(config)
    n_data = len(train_loader.dataset)
    logger.info(f"length of training dataset: {n_data}")
    n_data = len(val_loader.dataset)
    logger.info(f"length of validation dataset: {n_data}")
    n_data = len(test_loader.dataset)
    logger.info(f"length of testing dataset: {n_data}")

    model, criterion = build_multi_part_segmentation(config)
    model.cuda()
    criterion.cuda()

    # linear scaling rule: the base learning rate is defined for a global
    # batch of 16 and scaled by (batch_size * world_size) / 16
    if config.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=config.batch_size *
                                    dist.get_world_size() / 16 *
                                    config.base_learning_rate,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay)
    elif config.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config.batch_size *
                                     dist.get_world_size() / 16 *
                                     config.base_learning_rate,
                                     weight_decay=config.weight_decay)
    elif config.optimizer == 'adamW':
        optimizer = torch.optim.AdamW(
            model.parameters(),
            # lr=config.batch_size * dist.get_world_size() / 16 * config.base_learning_rate,
            lr=config.base_learning_rate,
            weight_decay=config.weight_decay)
    else:
        raise NotImplementedError(
            f"Optimizer {config.optimizer} not supported")

    scheduler = get_scheduler(optimizer, len(train_loader), config)

    model = DistributedDataParallel(model,
                                    device_ids=[config.local_rank],
                                    broadcast_buffers=False)

    # optionally resume from a checkpoint
    if config.load_path:
        assert os.path.isfile(config.load_path)
        load_checkpoint(config, model, optimizer, scheduler)
        logger.info("==> checking loaded ckpt")
        validate('resume', 'val', val_loader, model, criterion, config)
        validate('resume', 'test', test_loader, model, criterion, config)

    # tensorboard
    if dist.get_rank() == 0:
        summary_writer = SummaryWriter(log_dir=config.log_dir)
    else:
        summary_writer = None

    # routine
    for epoch in range(config.start_epoch, config.epochs + 1):
        train_loader.sampler.set_epoch(epoch)  # reshuffle the DistributedSampler

        tic = time.time()
        loss = train(epoch, train_loader, model, criterion, optimizer,
                     scheduler, config)

        logger.info(f"epoch {epoch}, total time {time.time() - tic:.2f}, "
                    f"lr {optimizer.param_groups[0]['lr']:.5f}")
        if epoch % config.val_freq == 0:
            # scheduled epochs: validate with its default number of votes
            validate(epoch, 'val', val_loader, model, criterion, config)
            validate(epoch, 'test', test_loader, model, criterion, config)
        else:
            # all other epochs: cheaper single-vote validation
            validate(epoch, 'val', val_loader, model, criterion, config,
                     num_votes=1)
            validate(epoch, 'test', test_loader, model, criterion, config,
                     num_votes=1)

        if dist.get_rank() == 0:
            # save model
            save_checkpoint(config, epoch, model, optimizer, scheduler)

        if summary_writer is not None:
            # tensorboard logger
            summary_writer.add_scalar('ins_loss', loss, epoch)
            summary_writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'], epoch)
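The sgd and adam branches in Example #3 apply that linear scaling rule directly in the lr argument. A quick check of the arithmetic (the values below are illustrative, not from the original config):

# Illustrative values only:
batch_size = 8             # per-GPU batch size (config.batch_size)
world_size = 4             # dist.get_world_size()
base_learning_rate = 0.01  # config.base_learning_rate

# global batch = 8 * 4 = 32, so the base rate is scaled by 32 / 16 = 2
lr = batch_size * world_size / 16 * base_learning_rate
print(lr)  # 0.02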