Exemplo n.º 1
0
def main():
    """Evaluate a saved DGCNN point-cloud checkpoint on the test split.

    Builds the test-split loader and a DataParallel-wrapped DGCNN, restores
    the weights stored under ``checkpoint['model']`` in
    ``config.pc_net.ckpt_file``, runs ``validate`` once under
    ``torch.no_grad()`` and prints the resulting accuracy next to the
    ``best_prec1`` recorded in the checkpoint.
    """
    print('Training Process\nInitializing...\n')
    config.init_env()

    val_dataset = data_pth.pc_data(config.pc_net.data_root, status=STATUS_TEST)

    # NOTE(review): with drop_last=True the final partial batch is never
    # evaluated, so the reported accuracy covers only full batches. Confirm
    # whether validate() requires fixed-size batches before changing this.
    val_loader = DataLoader(val_dataset,
                            batch_size=config.pc_net.validation.batch_sz,
                            num_workers=config.num_workers,
                            shuffle=True,
                            drop_last=True)

    # create model
    net = DGCNN(n_neighbor=config.pc_net.n_neighbor,
                num_classes=config.pc_net.num_classes)
    net = torch.nn.DataParallel(net)
    net = net.to(device=config.device)

    print(f'loading pretrained model from {config.pc_net.ckpt_file}')
    # map_location lets a checkpoint written on a GPU be restored onto
    # whatever device config.device names (e.g. a CPU-only machine).
    checkpoint = torch.load(config.pc_net.ckpt_file,
                            map_location=config.device)

    # Weights go into the unwrapped model (net.module). A state_dict saved
    # from a DataParallel wrapper has every key prefixed with 'module.', so
    # drop that prefix when present. Keys without it are kept unchanged.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['model'].items()
    }
    net.module.load_state_dict(state_dict)

    best_prec1 = checkpoint['best_prec1']
    resume_epoch = checkpoint['epoch']

    # NOTE(review): confirm validate() returns a single accuracy value here;
    # it is printed as-is below.
    with torch.no_grad():
        prec1 = validate(val_loader, net, resume_epoch)

    print('curr accuracy: ', prec1)
    print('best accuracy: ', best_prec1)
Exemplo n.º 2
0
def main():
    """Train DGCNN on the point-cloud training split, validating every epoch.

    When ``config.pc_net.train.resume`` is set, model and optimizer state and
    ``best_prec1`` are restored from ``config.pc_net.ckpt_file``. Training
    then starts from the configured ``resume_epoch``, or from the epoch after
    the one stored in the checkpoint. Each epoch trains once, validates under
    ``torch.no_grad()``, and calls ``save_ckpt`` whenever validation accuracy
    beats the best seen so far. It then prints the current accuracy and the
    best accuracy / retrieval mAP.
    """
    print('Training Process\nInitializing...\n')
    config.init_env()

    train_dataset = data_pth.pc_data(config.pc_net.data_root,
                                     status=STATUS_TRAIN)
    val_dataset = data_pth.pc_data(config.pc_net.data_root, status=STATUS_TEST)

    train_loader = DataLoader(train_dataset,
                              batch_size=config.pc_net.train.batch_sz,
                              num_workers=config.num_workers,
                              shuffle=True,
                              drop_last=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=config.pc_net.validation.batch_sz,
                            num_workers=config.num_workers,
                            shuffle=True,
                            drop_last=True)

    best_prec1 = 0
    resume_epoch = 0
    best_map = 0
    # create model
    net = DGCNN(n_neighbor=config.pc_net.n_neighbor,
                num_classes=config.pc_net.num_classes)
    net = torch.nn.DataParallel(net)
    net = net.to(device=config.device)
    optimizer = optim.Adam(net.parameters(),
                           config.pc_net.train.lr,
                           weight_decay=config.pc_net.train.weight_decay)

    if config.pc_net.train.resume:
        print(f'loading pretrained model from {config.pc_net.ckpt_file}')
        # map_location lets a GPU-written checkpoint be restored onto
        # config.device (e.g. a CPU-only machine).
        checkpoint = torch.load(config.pc_net.ckpt_file,
                                map_location=config.device)
        # Weights go into the unwrapped model (net.module), so strip the
        # 'module.' prefix a DataParallel state_dict carries. Unlike a blind
        # k[7:], keys that lack the prefix are left intact.
        prefix = 'module.'
        net.module.load_state_dict({
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in checkpoint['model'].items()
        })
        optimizer.load_state_dict(checkpoint['optimizer'])
        best_prec1 = checkpoint['best_prec1']
        # An explicit resume_epoch in the config overrides the checkpoint's.
        if config.pc_net.train.resume_epoch is not None:
            resume_epoch = config.pc_net.train.resume_epoch
        else:
            resume_epoch = checkpoint['epoch'] + 1
        # NOTE(review): best_map is not stored in the checkpoint, so it
        # restarts from 0 after a resume.

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 20, 0.7)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device=config.device)

    for epoch in range(resume_epoch, config.pc_net.train.max_epoch):

        # Passing the epoch explicitly sets the learning rate for this epoch
        # even when training resumes mid-schedule. NOTE: the `epoch` argument
        # of scheduler.step() is deprecated in newer PyTorch releases.
        lr_scheduler.step(epoch=epoch)
        # train
        train(train_loader, net, criterion, optimizer, epoch)
        # validation
        with torch.no_grad():
            prec1, retrieval_map = validate(val_loader, net, epoch)

        # save a checkpoint only when validation accuracy improves
        if prec1 > best_prec1:
            best_prec1 = prec1
            save_ckpt(epoch, best_prec1, net, optimizer)

        if retrieval_map > best_map:
            best_map = retrieval_map

        print('curr accuracy: ', prec1)
        print('best accuracy: ', best_prec1)
        print('best map: ', best_map)

    print('Train Finished!')