# Example 1
def main(opt):
    """Train an I3D edge model, checkpointing after every epoch.

    Args:
        opt: parsed options object; reads ``optim``, ``lr``, ``weight_decay``,
            ``momentum``, ``batchsize``, ``epochs`` and ``output_dir``.

    Returns:
        Path of the final ("last") checkpoint file.

    Raises:
        NotImplementedError: if ``opt.optim`` is not one of
            ``'adam'``, ``'adamW'``, ``'sdg'``.
    """
    # build dataloader
    train_loader = build_loader(opt)
    n_data = len(train_loader.dataset)
    logger.info(f"length of training dataset: {n_data}")

    # build model and loss; the loss is BCE-with-logits (the original local
    # was misleadingly named `CE`)
    model = build_model(opt)
    criterion = torch.nn.BCEWithLogitsLoss().cuda()

    # build optimizer
    if opt.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     opt.lr,
                                     weight_decay=opt.weight_decay)
    elif opt.optim == 'adamW':
        optimizer = torch.optim.AdamW(model.parameters(),
                                      opt.lr,
                                      weight_decay=opt.weight_decay)
    elif opt.optim == 'sdg':
        # NOTE(review): 'sdg' looks like a typo for 'sgd'; the string is kept
        # as-is because existing callers may already pass it.
        # SGD LR is scaled by batch size relative to lr/10 — presumably a
        # linear-scaling heuristic; confirm against the training recipe.
        optimizer = torch.optim.SGD(model.parameters(),
                                    opt.lr / 10.0 * opt.batchsize,
                                    momentum=opt.momentum,
                                    weight_decay=opt.weight_decay)
    else:
        raise NotImplementedError(f"Optimizer {opt.optim} not supported")
    scheduler = get_scheduler(optimizer, len(train_loader), opt)

    # training routine: save a checkpoint after every epoch.
    # (The original guard `if (epoch) % 1 == 0` was always true, so the
    # save is now unconditional.)
    for epoch in range(1, opt.epochs + 1):
        tic = time.time()
        train(train_loader, model, optimizer, criterion, scheduler, epoch, opt)
        logger.info('epoch {}, total time {:.2f}, learning_rate {}'.format(
            epoch, (time.time() - tic), optimizer.param_groups[0]['lr']))
        ckpt_path = os.path.join(opt.output_dir,
                                 f"I3D_edge_epoch_{epoch}_ckpt.pth")
        torch.save(model.state_dict(), ckpt_path)
        logger.info("model saved {}!".format(ckpt_path))

    # save the final weights once more under a stable name and return its path
    last_path = os.path.join(opt.output_dir, "I3D_edge_last_ckpt.pth")
    torch.save(model.state_dict(), last_path)
    logger.info("model saved {}!".format(last_path))
    return last_path
# Example 2
def main(config):
    """Distributed training entry point for multi-part segmentation.

    Builds train/val/test loaders, the model and criterion, an optimizer
    (with batch-size-scaled LR for sgd/adam), wraps the model in DDP,
    optionally resumes from ``config.load_path``, then trains for
    ``config.epochs`` epochs with periodic validation. Rank 0 also writes
    TensorBoard scalars and saves checkpoints.

    Raises:
        NotImplementedError: if ``config.optimizer`` is unsupported.
    """
    # NOTE(review): the original declared `global best_acc` / `global
    # best_epoch` but never assigned them in this function, so the
    # declarations had no effect and were removed.
    train_loader, val_loader, test_loader = get_loader(config)
    n_data = len(train_loader.dataset)
    logger.info(f"length of training dataset: {n_data}")
    n_data = len(val_loader.dataset)
    logger.info(f"length of validation dataset: {n_data}")
    n_data = len(test_loader.dataset)
    logger.info(f"length of testing dataset: {n_data}")

    model, criterion = build_multi_part_segmentation(config)
    model.cuda()
    criterion.cuda()

    # Linear LR scaling with the global batch size (reference batch = 16);
    # hoisted so the expression is written once for both sgd and adam.
    scaled_lr = (config.batch_size * dist.get_world_size() / 16 *
                 config.base_learning_rate)
    if config.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=scaled_lr,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay)
    elif config.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=scaled_lr,
                                     weight_decay=config.weight_decay)
    elif config.optimizer == 'adamW':
        # adamW deliberately uses the unscaled base LR (per the original code)
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=config.base_learning_rate,
                                      weight_decay=config.weight_decay)
    else:
        raise NotImplementedError(
            f"Optimizer {config.optimizer} not supported")

    scheduler = get_scheduler(optimizer, len(train_loader), config)

    model = DistributedDataParallel(model,
                                    device_ids=[config.local_rank],
                                    broadcast_buffers=False)

    # optionally resume from a checkpoint and sanity-check it on val/test
    if config.load_path:
        assert os.path.isfile(config.load_path)
        load_checkpoint(config, model, optimizer, scheduler)
        logger.info("==> checking loaded ckpt")
        validate('resume', 'val', val_loader, model, criterion, config)
        validate('resume', 'test', test_loader, model, criterion, config)

    # tensorboard: only rank 0 writes event files
    if dist.get_rank() == 0:
        summary_writer = SummaryWriter(log_dir=config.log_dir)
    else:
        summary_writer = None

    # routine
    for epoch in range(config.start_epoch, config.epochs + 1):
        # reshuffle the distributed sampler deterministically per epoch
        train_loader.sampler.set_epoch(epoch)

        tic = time.time()
        loss = train(epoch, train_loader, model, criterion, optimizer,
                     scheduler, config)

        logger.info('epoch {}, total time {:.2f}, lr {:.5f}'.format(
            epoch, (time.time() - tic), optimizer.param_groups[0]['lr']))
        if epoch % config.val_freq == 0:
            # full validation (default voting)
            validate(epoch, 'val', val_loader, model, criterion, config)
            validate(epoch, 'test', test_loader, model, criterion, config)
        else:
            # quick validation with a single vote on off-cycle epochs
            validate(epoch,
                     'val',
                     val_loader,
                     model,
                     criterion,
                     config,
                     num_votes=1)
            validate(epoch,
                     'test',
                     test_loader,
                     model,
                     criterion,
                     config,
                     num_votes=1)

        if dist.get_rank() == 0:
            # save model (rank 0 only)
            save_checkpoint(config, epoch, model, optimizer, scheduler)

        if summary_writer is not None:
            # tensorboard logger
            summary_writer.add_scalar('ins_loss', loss, epoch)
            summary_writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'], epoch)
def main(config):
    """Distributed training entry point for point-cloud semantic segmentation.

    Selects a PointNet-family model from ``config.model_name``, builds the
    optimizer/scheduler, wraps the model in DDP, optionally resumes from
    ``config.load_path``, trains for ``config.epochs`` epochs with periodic
    voted validation, and finishes with a heavily-voted final validation.

    Raises:
        NotImplementedError: if ``config.model_name`` or
            ``config.optimizer`` is unsupported.
    """
    train_loader, val_loader = get_loader(config)
    n_data = len(train_loader.dataset)
    logger.info(f"length of training dataset: {n_data}")
    n_data = len(val_loader.dataset)
    logger.info(f"length of validation dataset: {n_data}")

    # model selection by name
    if config.model_name == 'pointnet':
        model = PointNetSemSeg(config, config.input_features_dim)
    elif config.model_name == 'pointnet2_ssg':
        model = PointNet2SSGSemSeg(config, config.input_features_dim)
    elif config.model_name == 'pointnet2_msg':
        model = PointNet2MSGSemSeg(config, config.input_features_dim)
    else:
        # was NotImplementedError("error") -- say which name was rejected
        raise NotImplementedError(
            f"Model {config.model_name} not supported")

    criterion = get_masked_CE_loss()

    model.cuda()
    criterion.cuda()

    if config.optimizer == 'sgd':
        # linear LR scaling with global batch size (reference batch = 8)
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=config.batch_size *
                                    dist.get_world_size() / 8 *
                                    config.base_learning_rate,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay)
    elif config.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config.base_learning_rate,
                                     weight_decay=config.weight_decay)
    elif config.optimizer == 'adamW':
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=config.base_learning_rate,
                                      weight_decay=config.weight_decay)
    else:
        raise NotImplementedError(
            f"Optimizer {config.optimizer} not supported")

    scheduler = get_scheduler(optimizer, len(train_loader), config)

    # add find_unused_parameters=True to overcome the error "RuntimeError:
    # Expected to have finished reduction in the prior iteration before
    # starting a new one"
    model = DistributedDataParallel(model,
                                    device_ids=[config.local_rank],
                                    broadcast_buffers=False,
                                    find_unused_parameters=True)

    # accumulated per-point class logits for multi-vote evaluation
    # (original name had a typo: `runing_vote_logits`)
    running_vote_logits = [
        np.zeros((config.num_classes, l.shape[0]), dtype=np.float32)
        for l in val_loader.dataset.sub_clouds_points_labels
    ]

    # optionally resume from a checkpoint and sanity-check it
    if config.load_path:
        assert os.path.isfile(config.load_path)
        load_checkpoint(config, model, optimizer, scheduler)
        logger.info("==> checking loaded ckpt")
        validate('resume',
                 val_loader,
                 model,
                 criterion,
                 running_vote_logits,
                 config,
                 num_votes=2)

    # tensorboard: only rank 0 writes event files
    if dist.get_rank() == 0:
        summary_writer = SummaryWriter(log_dir=config.log_dir)
    else:
        summary_writer = None

    # routine
    for epoch in range(config.start_epoch, config.epochs + 1):
        # reshuffle samplers deterministically per epoch
        train_loader.sampler.set_epoch(epoch)
        val_loader.sampler.set_epoch(epoch)
        train_loader.dataset.epoch = epoch - 1
        tic = time.time()
        loss = train(epoch, train_loader, model, criterion, optimizer,
                     scheduler, config)

        logger.info('epoch {}, total time {:.2f}, lr {:.5f}'.format(
            epoch, (time.time() - tic), optimizer.param_groups[0]['lr']))
        if epoch % config.val_freq == 0:
            validate(epoch,
                     val_loader,
                     model,
                     criterion,
                     running_vote_logits,
                     config,
                     num_votes=2)

        if dist.get_rank() == 0:
            # save model (rank 0 only)
            save_checkpoint(config, epoch, model, optimizer, scheduler)

        if summary_writer is not None:
            # tensorboard logger
            summary_writer.add_scalar('ins_loss', loss, epoch)
            summary_writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'], epoch)

    # final evaluation with heavy voting
    validate('Last',
             val_loader,
             model,
             criterion,
             running_vote_logits,
             config,
             num_votes=20)
# Example 4
    # Parse the run configuration (CLI/file) into a config object.
    config = config_args()

    # Build the training dataset and its DataLoader from TRAIN.* settings.
    dataset = Dataset_PSE(config=config)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=config.TRAIN.BATCH,
                            shuffle=config.TRAIN.SHUFFLE,
                            num_workers=config.TRAIN.WORKERS)

    # PSENet model pinned to the configured GPU, its loss, optimizer and
    # LR scheduler. NOTE(review): config.CUDA.GPU must be a string here,
    # since it is concatenated into 'cuda:<id>' -- confirm at the caller.
    psenet = PSENET(config=config).to(torch.device('cuda:' + config.CUDA.GPU))
    pseloss = PSELOSS(config=config)
    optimizer = torch.optim.Adam(psenet.parameters(), lr=config.TRAIN.LR)

    # NOTE(review): 'schedular' is a typo for 'scheduler'; kept because the
    # rest of this (unseen) function may reference it by this name.
    schedular = get_scheduler(config, optimizer)

    # Set up run logging and the timestamped output directory tree:
    # <MODEL_SAVE_DIR>/<timestamp>/ and <MODEL_SAVE_DIR>/<timestamp>/runs/.
    loggerinfo = LoggerInfo(config, num_dataset=len(dataset))
    if os.path.exists(config.MODEL.MODEL_SAVE_DIR) == False:
        os.mkdir(config.MODEL.MODEL_SAVE_DIR)
    else:
        # Directory already exists -- presumably validates/cleans previous
        # outputs; verify against check_outputs' definition.
        check_outputs(config.MODEL.MODEL_SAVE_DIR)
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d@%H:%M:%S')
    os.mkdir(os.path.join(config.MODEL.MODEL_SAVE_DIR, nowtime))
    os.mkdir(os.path.join(config.MODEL.MODEL_SAVE_DIR + '/' + nowtime, 'runs'))
    # File logger writing to <MODEL_SAVE_DIR>/<timestamp>/log.txt.
    logger = setup_logger(
        log_file_path=os.path.join(config.MODEL.MODEL_SAVE_DIR + '/' +
                                   nowtime, 'log.txt'))
    logger.info(print_config(config))