Example #1
import logging

import torch

# save_checkpoint and AverageMeter are repo-local helpers (sketched later on this page).


def train(model, train_loader, optimizer, scheduler, criterion, opt):
    model.train()
    for i, data in enumerate(train_loader):
        opt.iter += 1
        if not opt.multi_gpus:
            data = data.to(opt.device)
            gt = data.y
        else:
            gt = torch.cat([data_batch.y for data_batch in data],
                           0).to(opt.device)

        # ------------------ zero, output, loss
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, gt)

        # ------------------ optimization
        loss.backward()
        optimizer.step()

        opt.losses.update(loss.item())
        # ------------------ show information
        if opt.iter % opt.print_freq == 0:
            logging.info(
                'Epoch:{}\t Iter: {}\t [{}/{}]\t Loss: {Losses.avg: .4f}'.
                format(opt.epoch,
                       opt.iter,
                       i + 1,
                       len(train_loader),
                       Losses=opt.losses))
            opt.losses.reset()

        # ------------------ tensor board log
        info = {
            'loss': loss.item(),  # log a plain float, not the graph-attached tensor
            'test_value': opt.test_value,
            'lr': scheduler.get_last_lr()[0]  # get_lr() is deprecated outside step()
        }
        for tag, value in info.items():
            opt.writer.scalar_summary(tag, value, opt.iter)

    # ------------------ save checkpoints
    # lower test_value is better for this metric, so track the minimum
    is_best = (opt.test_value < opt.best_value)
    opt.best_value = min(opt.test_value, opt.best_value)

    model_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
    # optim_cpu = {k: v.cpu() for k, v in optimizer.state_dict().items()}
    save_checkpoint(
        {
            'epoch': opt.epoch,
            'state_dict': model_cpu,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_value': opt.best_value,
        }, is_best, opt.save_path, opt.post)
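
All four examples drive an opt.losses meter through update(), avg, and reset(). The AverageMeter class behind it is not shown on this page; the sketch below follows the common PyTorch-examples utility of that name, so treat it as an assumption about the repo rather than its exact code.

class AverageMeter(object):
    """Tracks a running sum and count; .avg is the mean since the last reset."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val is the latest scalar (e.g. loss.item()), n its batch weight
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count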
Example #2
import torch
from tqdm import tqdm


def train(model, train_loader, optimizer, scheduler, criterion, opt):
    opt.losses.reset()
    model.train()
    with tqdm(train_loader) as tqdm_loader:
        for i, data in enumerate(tqdm_loader):
            opt.iter += 1

            # tqdm progress bar
            desc = 'Epoch:{}  Iter:{}  [{}/{}]  Loss:{Losses.avg: .4f}'\
                .format(opt.epoch, opt.iter, i + 1, len(train_loader), Losses=opt.losses)
            tqdm_loader.set_description(desc)

            if not opt.multi_gpus:
                data = data.to(opt.device)
            inputs = torch.cat((data.pos.transpose(2, 1).unsqueeze(3),
                                data.x.transpose(2, 1).unsqueeze(3)), 1)
            gt = data.y.to(opt.device)
            # ------------------ zero, output, loss
            optimizer.zero_grad()
            out = model(inputs)
            loss = criterion(out, gt)

            # ------------------ optimization
            loss.backward()
            optimizer.step()

            opt.losses.update(loss.item())

            # ------------------ tensor board log
            info = {
                'loss': loss.item(),  # log a plain float, not the graph-attached tensor
                'test_value': opt.test_value,
                'lr': scheduler.get_last_lr()[0]  # get_lr() is deprecated outside step()
            }
            for tag, value in info.items():
                opt.logger.scalar_summary(tag, value, opt.iter)

    # ------------------ save checkpoints
    # higher test_value is better for this metric, so track the maximum
    is_best = (opt.test_value > opt.best_value)
    opt.best_value = max(opt.test_value, opt.best_value)

    model_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
    save_checkpoint(
        {
            'epoch': opt.epoch,
            'state_dict': model_cpu,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_value': opt.best_value,
        }, is_best, opt.ckpt_dir, opt.exp_name)
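
The checkpoint dictionary written above can be restored with the standard PyTorch state-dict API. A minimal resume sketch, assuming model, optimizer, and scheduler are already constructed as in Example #4 (the 'checkpoint.pth' path is a placeholder, not a value from these examples):

import torch

ckpt = torch.load('checkpoint.pth', map_location='cpu')  # placeholder path
model.load_state_dict(ckpt['state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
scheduler.load_state_dict(ckpt['scheduler_state_dict'])
start_epoch = ckpt['epoch']        # resume epoch counting from here
best_value = ckpt['best_value']    # restore the tracked metric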
Example #3
def save_ckpt(model, optimizer, scheduler, opt):
    # ------------------ save ckpt
    is_best = (opt.test_value > opt.best_value)
    if opt.save_best_only:
        save_flag = is_best
    else:
        save_flag = True
    if save_flag:
        opt.best_value = max(opt.test_value, opt.best_value)
        model_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
        save_checkpoint(
            {
                'epoch': opt.epoch,
                'state_dict': model_cpu,
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_value': opt.best_value,
            }, is_best, opt.save_path, opt.post)
        opt.losses.reset()
        opt.test_value = 0.
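
save_checkpoint itself is repo-local and not shown on this page. Judging from its call sites, save_checkpoint(state, is_best, save_path, post), it presumably serializes the state dict and duplicates a best-model file. A hedged sketch under that assumption, with the file-name pattern invented for illustration:

import os
import shutil

import torch


def save_checkpoint(state, is_best, save_path, postname):
    # Assumed behavior: always write the latest checkpoint, then duplicate
    # it as a *_best.pth file when the tracked metric improved.
    os.makedirs(save_path, exist_ok=True)
    filename = os.path.join(save_path,
                            '{}_ckpt_{}.pth'.format(postname, state['epoch']))
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename,
                        os.path.join(save_path, '{}_best.pth'.format(postname)))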
Example #4
import logging

import torch
import torch_geometric.datasets as GeoData
import torch_geometric.transforms as T
from torch_geometric.data import DenseDataLoader
from torch_geometric.nn import DataParallel

# DenseDeepGCN, OptInit, AverageMeter, load_pretrained_models,
# load_pretrained_optimizer, save_checkpoint, train and test are
# repo-local helpers whose imports are omitted on this page.


def main():
    opt = OptInit().get_args()
    logging.info('===> Creating dataloader ...')
    train_dataset = GeoData.S3DIS(opt.data_dir,
                                  opt.area,
                                  train=True,
                                  pre_transform=T.NormalizeScale())
    train_loader = DenseDataLoader(train_dataset,
                                   batch_size=opt.batch_size,
                                   shuffle=True,
                                   num_workers=4)
    test_dataset = GeoData.S3DIS(opt.data_dir,
                                 opt.area,
                                 train=False,
                                 pre_transform=T.NormalizeScale())
    test_loader = DenseDataLoader(test_dataset,
                                  batch_size=opt.batch_size,
                                  shuffle=False,
                                  num_workers=0)
    opt.n_classes = train_loader.dataset.num_classes

    logging.info('===> Loading the network ...')
    model = DenseDeepGCN(opt).to(opt.device)
    if opt.multi_gpus:
        model = DataParallel(model).to(opt.device)

    logging.info('===> Loading pre-trained model ...')
    model, opt.best_value, opt.epoch = load_pretrained_models(
        model, opt.pretrained_model, opt.phase)
    logging.info(model)

    logging.info('===> Init the optimizer ...')
    criterion = torch.nn.CrossEntropyLoss().to(opt.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_adjust_freq,
                                                opt.lr_decay_rate)
    optimizer, scheduler, opt.lr = load_pretrained_optimizer(
        opt.pretrained_model, optimizer, scheduler, opt.lr)

    logging.info('===> Init Metric ...')
    opt.losses = AverageMeter()
    opt.test_value = 0.
    opt.iter = 0  # train() increments this; ensure it exists even if OptInit does not set it

    logging.info('===> start training ...')
    for _ in range(opt.epoch, opt.total_epochs):
        opt.epoch += 1
        logging.info('Epoch:{}'.format(opt.epoch))
        train(model, train_loader, optimizer, scheduler, criterion, opt)
        if opt.epoch % opt.eval_freq == 0 and opt.eval_freq != -1:
            test(model, test_loader, opt)
        scheduler.step()

        # ------------------ save checkpoints
        # higher test_value is better for this metric, so track the maximum
        is_best = (opt.test_value > opt.best_value)
        opt.best_value = max(opt.test_value, opt.best_value)
        model_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
        save_checkpoint(
            {
                'epoch': opt.epoch,
                'state_dict': model_cpu,
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_value': opt.best_value,
            }, is_best, opt.ckpt_dir, opt.exp_name)

        # ------------------ tensorboard log
        info = {
            'loss': opt.losses.avg,
            'test_value': opt.test_value,
            'lr': scheduler.get_last_lr()[0]
        }
        opt.writer.add_scalars('epoch', info, opt.iter)

    logging.info('Saving the final model. Finish!')
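
The opt.writer.add_scalars('epoch', info, opt.iter) call in Example #4 matches the torch.utils.tensorboard.SummaryWriter API, so the writer is presumably attached to opt during OptInit. A minimal sketch of that assumed setup and call (the log_dir is a placeholder):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/example')  # placeholder log directory
# add_scalars groups several curves under one main tag on the same step:
writer.add_scalars('epoch',
                   {'loss': 0.42, 'test_value': 0.0, 'lr': 1e-3},
                   global_step=0)
writer.close()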