import numpy as np
import pytest
import torch

from geoopt import Euclidean, ProductManifold, Sphere


def test_reshaping():
    pman = ProductManifold((Sphere(), 10), (Sphere(), (3, 2)),
                           (Euclidean(), ()))
    point = [
        Sphere().random_uniform(5, 10),
        Sphere().random_uniform(5, 3, 2),
        Euclidean().random_normal(5),
    ]
    tensor = pman.pack_point(*point)
    assert tensor.shape == (5, 10 + 3 * 2 + 1)
    point_new = pman.unpack_tensor(tensor)
    for old, new in zip(point, point_new):
        np.testing.assert_allclose(old, new)


def test_component_inner_product():
    pman = ProductManifold((Sphere(), 10), (Sphere(), (3, 2)),
                           (Euclidean(), ()))
    point = [
        Sphere().random_uniform(5, 10),
        Sphere().random_uniform(5, 3, 2),
        Euclidean().random_normal(5),
    ]
    tensor = pman.pack_point(*point)
    tangent = torch.randn_like(tensor)
    tangent = pman.proju(tensor, tangent)

    inner = pman.component_inner(tensor, tangent)
    assert inner.shape == (5, pman.n_elements)


def test_unpack_tensor_method():
    pman = ProductManifold((Sphere(), 10), (Sphere(), (3, 2)),
                           (Euclidean(), ()))
    point = pman.random(4, pman.n_elements)
    parts = point.unpack_tensor()
    assert parts[0].shape == (4, 10)
    assert parts[1].shape == (4, 3, 2)
    assert parts[2].shape == (4,)


def test_from_point():
    point = [
        Sphere().random_uniform(5, 10),
        Sphere().random_uniform(5, 3, 2),
        Euclidean().random_normal(5),
    ]
    pman = ProductManifold.from_point(*point, batch_dims=1)
    assert pman.n_elements == (10 + 3 * 2 + 1)


def test_from_point_checks_shapes():
    point = [
        Sphere().random_uniform(5, 10),
        Sphere().random_uniform(3, 3, 2),
        Euclidean().random_normal(5),
    ]
    pman = ProductManifold.from_point(*point)
    assert pman.n_elements == (5 * 10 + 3 * 3 * 2 + 5 * 1)
    with pytest.raises(ValueError) as e:
        _ = ProductManifold.from_point(*point, batch_dims=1)
    assert e.match("Not all parts have same batch shape")
Example #6
    def _build_optimizer(self):
        # params = list(
        #     filter(
        #         lambda p: p.requires_grad,
        #         chain(self.model.parameters(), self.criterion.parameters()),
        #     )
        # )
        params_dict = {}
        _default_manifold = Euclidean()
        for name, p in chain(self.model.named_parameters(),
                             self.criterion.named_parameters()):
            if not p.requires_grad:
                continue
            if isinstance(p, (ManifoldParameter, ManifoldTensor)):
                _manifold = p.manifold
            else:
                _manifold = _default_manifold
            _manifold_name = _manifold.__class__.__name__
            if _manifold_name not in params_dict:
                # Reference Riemannian gradient for a unit Euclidean gradient;
                # its reciprocal rectifies the learning rate so that parameters
                # on different manifolds receive comparable step sizes.
                ref_grad = _manifold.egrad2rgrad(p.new_zeros(1), p.new_ones(1))
                params_dict[_manifold_name] = dict(
                    params=[],
                    lr_rectifier=ref_grad.reciprocal().item())
            params_dict[_manifold_name]['params'].append(p)
        params = params_dict.values()

        if self.args.fp16:
            if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
                logger.info(
                    "NOTE: your device does NOT support faster training with --fp16, "
                    "please switch to FP32 which is likely to be faster")
            if self.args.memory_efficient_fp16:
                self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
                    self.args, params)
            else:
                self._optimizer = optim.FP16Optimizer.build_optimizer(
                    self.args, params)
        else:
            if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
                logger.info(
                    "NOTE: your device may support faster training with --fp16"
                )
            self._optimizer = optim.build_optimizer(self.args, params)

        if self.args.use_bmuf:
            self._optimizer = optim.FairseqBMUF(self.args, self._optimizer)

        # We should initialize the learning rate scheduler immediately after
        # building the optimizer, so that the initial learning rate is set.
        self._lr_scheduler = lr_scheduler.build_lr_scheduler(
            self.args, self.optimizer)
        self._lr_scheduler.step_update(0)
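The grouping above buckets parameters by manifold class and records an lr_rectifier equal to the reciprocal of a reference Riemannian gradient (via egrad2rgrad). How that rectifier is consumed depends on the fairseq optimizer wrappers built here; the sketch below shows one plausible way to turn such groups into per-group learning rates with a plain torch optimizer. The helper name and the lr scaling are assumptions for illustration, not the wrappers' actual behavior.

import torch

def make_param_groups(params_dict, base_lr):
    # Hypothetical helper: one param group per manifold, scaling the base
    # learning rate by that group's rectifier so step sizes stay comparable
    # across manifolds. The real fairseq wrappers may apply it differently.
    return [
        {"params": g["params"], "lr": base_lr * g["lr_rectifier"]}
        for g in params_dict.values()
    ]

# optimizer = torch.optim.Adam(make_param_groups(params_dict, base_lr=1e-3))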