Code Example #1
def BiLinear(repin, repout):
    """ Cheap bilinear layer (adds parameters for each part of the input which can be
        interpreted as a linear map from a part of the input to the output representation)."""
    Wdim, weight_proj = bilinear_weights(repout, repin)
    logging.info(f"BiW components: dim:{Wdim}")
    return _BiLinear(Wdim, weight_proj)
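All four excerpts on this page share one contract: `bilinear_weights(repout, repin)` returns the number of free parameters `Wdim` together with a projection `weight_proj` that turns a flat parameter vector and an input batch into an input-dependent weight matrix. The `_BiLinear` module returned above is not shown on this page, so the following is only a minimal sketch of a forward pass under that assumed contract; the call signature `weight_proj(params, x)` and the plain-array parameter are assumptions, not code from the repository.

import jax.numpy as jnp

class _BiLinear:  # hypothetical reconstruction of the module used in Code Example #1
    def __init__(self, Wdim, weight_proj):
        self.weight_proj = weight_proj
        self.w = jnp.zeros(Wdim)  # flat parameter vector; a trainable variable in the real module

    def __call__(self, x):
        # Project the flat parameters into a weight matrix that depends on x,
        # then apply it as a batched matrix-vector product.
        W = self.weight_proj(self.w, x)
        return (W @ x[..., None])[..., 0]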
Code Example #2
File: haiku.py  Project: mfinzi/equivariant-MLP
def BiLinear(repin, repout):
    """ Cheap bilinear layer (adds parameters for each part of the input which can be
        interpreted as a linear map from a part of the input to the output representation)."""
    Wdim, weight_proj = bilinear_weights(repout, repin)
    return lambda x: hkBiLinear(weight_proj, Wdim)(x)
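Unlike the other variants, this one builds a Haiku module, so it has to run inside `hk.transform`. A minimal usage sketch, assuming `repin` and `repout` are emlp representations and that `repin.size()` gives the flattened input dimension:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    return BiLinear(repin, repout)(x)  # BiLinear from Code Example #2

model = hk.transform(forward)
x = jnp.ones((1, repin.size()))
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, None, x)  # rng=None: the layer itself is deterministic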
Code Example #3
 def __init__(self, repin, repout):
     super().__init__()
     Wdim, weight_proj = bilinear_weights(repout, repin)
     # Wrap the jitted JAX projection so it can be called on torch tensors.
     self.weight_proj = torchify_fn(jit(weight_proj))
     self.bi_params = nn.Parameter(torch.randn(Wdim))
     logging.info(f"BiW components: dim:{Wdim}")
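The excerpt shows only `__init__`; a matching `forward` is a guess. The sketch below assumes the same `weight_proj(params, x)` contract as the JAX versions, with `torchify_fn` letting the jitted function consume torch tensors directly:

 def forward(self, x):
     # Assumed contract: the torchified projection maps (flat params, input)
     # to an input-dependent weight matrix W.
     W = self.weight_proj(self.bi_params, x)
     return (W @ x.unsqueeze(-1)).squeeze(-1)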
Code Example #4
File: objax.py  Project: sanmayphy/equivariant-MLP
 def __init__(self, repin, repout):
     super().__init__()
     Wdim, weight_proj = bilinear_weights(repout,repin)
     self.weight_proj = jit(weight_proj)
     self.w = TrainVar(objax.random.normal((Wdim,)))  # TODO: revert to xavier_normal initialization
     logging.info(f"BiW components: dim:{Wdim}")
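A hedged end-to-end sketch of how one of these layers would be constructed. The `SO`/`V` imports and `repin.size()` follow emlp's public API; the class name `BiLinear` for the objax module in Code Example #4, and its being callable on a batch, are assumptions, since the excerpt omits everything but `__init__`:

import numpy as np
from emlp.groups import SO
from emlp.reps import V

G = SO(3)
repin = 4 * V(G)   # direct sum of four copies of the base representation
repout = V(G)

layer = BiLinear(repin, repout)        # the objax module from Code Example #4
x = np.random.randn(5, repin.size())   # batch of 5 flattened inputs
y = layer(x)                           # expected shape: (5, repout.size())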