示例#1
0
def test_gin_config():
    """Check that ``test_files/test.gin`` supplies the expected ``net`` bindings.

    When the Gin integration of ``neuralnet_pytorch`` cannot be imported, a
    notice is printed and the test returns without checking anything.
    """
    import os

    try:
        from neuralnet_pytorch import gin_nnt as gin
    except ImportError:
        print('Please install Gin-config first and run this test again')
        return

    # Registered under the Gin name 'net'; Gin fills these parameters by name
    # from the parsed config. ``optimizer`` and ``scheduler`` must be bound
    # too, but their values are not inspected here.
    @gin.configurable('net')
    def check_bindings(dtype, activation, loss, optimizer, scheduler):
        assert dtype is T.float32
        assert isinstance(activation(), T.nn.Tanh)
        assert isinstance(loss(), T.nn.L1Loss)

    here = os.path.dirname(__file__)
    gin.parse_config_file(os.path.join(here, 'test_files/test.gin'))
    # No arguments are passed: every one of them must come from the config.
    check_bindings()
示例#2
0
                                 shuffle=False,
                                 num_workers=10,
                                 collate_fn=collate)

        mon.set_iter(0)
        mon.clear_scalar_stats(file_cat + '/test chamfer')
        print('Testing...')
        with T.set_grad_enabled(False):
            for itt, batch in enumerate(test_loader):
                init_pc, image, gt_pc = batch
                if nnt.cuda_available:
                    init_pc = init_pc.cuda()
                    image = image.cuda()
                    gt_pc = [pc.cuda() for pc in gt_pc] if isinstance(
                        gt_pc, (list, tuple)) else gt_pc.cuda()

                pred_pc = net(image, init_pc)
                loss = sum([
                    normalized_chamfer_loss(pred[None], gt[None])
                    for pred, gt in zip(pred_pc, gt_pc)
                ]) / (3. * len(gt_pc))
                loss = nnt.utils.to_numpy(loss)
                with mon:
                    mon.plot(file_cat + '/test chamfer', loss)
    print('Testing finished!')


if __name__ == '__main__':
    # Script entry point. The Gin config is parsed first so its bindings are
    # in place before test_each_category() runs. Both `config_file` and
    # `test_each_category` are defined elsewhere in this file (not visible here).
    gin.parse_config_file(config_file)
    test_each_category()
示例#3
0
def eval_queue(q, ckpt, dataset, bs, use_jit, use_amp, opt_level):
    """Evaluate checkpoint weights on the CIFAR test split as items arrive on ``q``.

    The network is built through the Gin configurable ``Classifier``. Its
    config is the single ``*.gin`` file found in the monitor's folder, or the
    module-level ``config_file`` otherwise. The function then loops over
    ``q.get()``:

    * ``'DONE'`` ends the loop;
    * ``None`` is ignored;
    * any other item is unpacked into ``(mon.epoch, mon.iter)``. The monitor
      state and the weights saved in ``tmp.pt`` are reloaded, the test split
      is evaluated, and the mean loss/accuracy are plotted and flushed.

    Args:
        q: object with a ``get()`` method (e.g. a queue) yielding the items above.
        ckpt: folder passed to ``nnt.Monitor(current_folder=...)``.
        dataset: ``'cifar10'`` or ``'cifar100'``.
        bs: evaluation batch size.
        use_jit: trace the network with ``torch.jit.trace`` before evaluating.
        use_amp: initialize the network with apex AMP and restore its state.
        opt_level: apex AMP optimization level (used only when ``use_amp``).
    """
    from contextlib import contextmanager

    assert dataset in ('cifar10', 'cifar100')
    # Resolved up front so `_make_network` does not depend on `dataset`,
    # which used to be rebound to the dataset class further down.
    num_classes = 10 if dataset == 'cifar10' else 100

    @contextmanager
    def _read_locked():
        # Always release the shared read lock, even if the guarded load
        # raises; a leaked read lock could otherwise block writers for good.
        lock.acquire_read()
        try:
            yield
        finally:
            lock.release_read()

    # TODO: passing `model` this way is ugly
    @gin.configurable('Classifier')
    def _make_network(model, **kwargs):
        net = model(num_classes=num_classes, default_init=False)
        net = net.to(eval_device)
        net.eval()
        return net

    with _read_locked():
        mon = nnt.Monitor(current_folder=ckpt)

    # Prefer the Gin config stored next to the checkpoint; fall back to the
    # module-level default unless there is exactly one candidate.
    cfg = glob(os.path.join(mon.file_folder, '*.gin'))
    cfg = cfg[0] if len(cfg) == 1 else config_file
    gin.parse_config_file(cfg)
    net = _make_network()

    dataset_cls = torchvision.datasets.CIFAR10 if dataset == 'cifar10' else torchvision.datasets.CIFAR100
    eval_data = dataset_cls(root='./data', train=False, download=True, transform=transform_test)
    eval_loader = T.utils.data.DataLoader(eval_data, batch_size=bs, shuffle=False)

    if nnt.cuda_available:
        eval_loader = nnt.DataPrefetcher(eval_loader, device=eval_device)

    while True:
        item = q.get()
        if item == 'DONE':
            break
        if item is None:
            continue

        with _read_locked():
            mon.load_state()
        mon.epoch, mon.iter = item
        # TODO: find a better way to load the current state of `monitor`
        with _read_locked():
            states = mon.load('tmp.pt', method='torch')
        net.load_state_dict(states['model_state_dict'])

        # Tracing a ScriptModule again is a no-op that only emits a warning.
        # So trace the eager module once; on later items the new weights are
        # loaded into the traced module by `load_state_dict` above.
        if use_jit and not isinstance(net, T.jit.ScriptModule):
            img = T.rand(1, 3, 32, 32).to(eval_device)
            net = T.jit.trace(net, img)

        if use_amp:
            import apex
            # NOTE(review): apex documents that `amp.initialize` should only be
            # called once. As in the original code, it is re-applied on every
            # evaluation here; confirm this is intended.
            net = apex.amp.initialize(net, opt_level=opt_level)
            apex.amp.load_state_dict(states['amp'])

        with T.set_grad_enabled(False):
            losses, accuracies = [], []
            for batch in eval_loader:
                batch = nnt.utils.batch_to_device(batch, eval_device)
                loss, acc = get_loss(net, *batch)
                losses.append(nnt.utils.to_numpy(loss))
                accuracies.append(nnt.utils.to_numpy(acc))

        mon.plot('test-loss', np.mean(losses))
        mon.plot('test-accuracy', np.mean(accuracies))
        mon.flush()