Example #1
    def make_M(self, coords):

        g = self.W[coords]
        m, n = self.tile_rows, self.tile_cols

        def makea(i):
            return torch.diag(g[i,:]) \
                 + torch.diag(torch.cat((self.g_wl, self.g_wl * 2 * torch.ones(n-2), self.g_wl))) \
                 + torch.diag(self.g_wl * -1 * torch.ones(n-1), diagonal = 1) \
                 + torch.diag(self.g_wl * -1 * torch.ones(n-1), diagonal = -1) \
                 + torch.diag(torch.cat((self.g_s_wl_in[i].view(1), torch.zeros(n - 2), self.g_s_wl_out[i].view(1))))

        def makec(j):
            return torch.zeros(m, m * n).index_put(
                (torch.arange(m), torch.arange(m) * n + j), g[:, j])

        def maked(j):
            d = torch.zeros(m, m * n)

            # First row of bit line j.
            i = 0
            d[i, j] = -self.g_s_bl_in[j] - self.g_bl - g[i, j]
            d[i, n * (i + 1) + j] = self.g_bl

            # Interior rows couple to the nodes directly above and below.
            for i in range(1, m - 1):
                d[i, n * (i - 1) + j] = self.g_bl
                d[i, n * i + j] = -self.g_bl - g[i, j] - self.g_bl
                d[i, n * (i + 1) + j] = self.g_bl

            # Last row of bit line j.
            i = m - 1
            d[i, n * (i - 1) + j] = self.g_bl
            d[i, n * i + j] = -self.g_s_bl_out[j] - g[i, j] - self.g_bl

            return d

        A = torch.block_diag(*tuple(makea(i) for i in range(m)))
        B = torch.block_diag(*tuple(-torch.diag(g[i, :]) for i in range(m)))
        C = torch.cat([makec(j) for j in range(n)], dim=0)
        D = torch.cat([maked(j) for j in range(n)], dim=0)

        M = torch.cat((torch.cat((A, B), dim=1), torch.cat((C, D), dim=1)),
                      dim=0)

        M = torch.inverse(M)

        self.saved_tiles[str(coords)] = M

        return M
Example #2
def build_diag_block(blk_list):
    """Build a block diagonal Tensor from a list of Tensors."""
    if blk_list[0].ndim == 2:
        return torch.block_diag(*blk_list)
    elif blk_list[0].ndim == 3:
        blks = []
        for idx_ant in range(blk_list[0].shape[-1]):
            blks_per_ant = []
            for idx_link in range(len(blk_list)):
                blks_per_ant.append(blk_list[idx_link][:, :, idx_ant])
            blks.append(torch.block_diag(*blks_per_ant))
        return torch.dstack(blks)
    else:
        raise ValueError("Invalid input dimension")
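A minimal usage sketch for build_diag_block above (shapes chosen arbitrarily for illustration): the 2-D branch delegates directly to torch.block_diag, while the 3-D branch assembles one block-diagonal matrix per trailing ("antenna") index and stacks the results with torch.dstack.

import torch

# Two links, each a 2x2 block over 3 antennas (trailing dimension).
blk_list = [torch.randn(2, 2, 3) for _ in range(2)]
out = build_diag_block(blk_list)
print(out.shape)  # torch.Size([4, 4, 3]): one 4x4 block-diagonal matrix per antenna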
Example #3
def model_enu_to_ned(enu_model, scale, zlimit=None):
    device = torch.device('cuda')
    id_ = torch.tensor([[1]], dtype=torch.float32, device=device)
    neg_yxz = torch.tensor([[0, -1, 0], [-1, 0, 0], [0, 0, -1]],
                           dtype=torch.float32,
                           device=device) * scale
    # torch.block_diag is variadic, so the nested calls can be flattened.
    neg_yxz_tri = torch.block_diag(id_, neg_yxz, neg_yxz, neg_yxz)
    ned_model = torch.matmul(
        torch.from_numpy(enu_model).to(device), neg_yxz_tri)
    # Only filter on z when a limit is given (zlimit defaults to None).
    if zlimit is not None:
        idx = torch.logical_and(
            torch.logical_and(ned_model[:, 3] < zlimit,
                              ned_model[:, 6] < zlimit),
            ned_model[:, 9] < zlimit)
        ned_model = ned_model[idx, :]
    return ned_model.cpu().numpy()
Example #4
    def forward(self, g, h):
        feat = self.feat_gc(
            g, h)  # size = (sum_N, F_out), sum_N is num of nodes in this batch
        device = feat.device
        assign_tensor = self.pool_gc(
            g, h)  # size = (sum_N, N_a), N_a is num of nodes in pooled graph.
        assign_tensor = F.softmax(assign_tensor, dim=1)
        assign_tensor = torch.split(assign_tensor,
                                    g.batch_num_nodes().tolist())
        assign_tensor = torch.block_diag(
            *assign_tensor)  # size = (sum_N, batch_size * N_a)

        h = torch.matmul(torch.t(assign_tensor), feat)
        adj = g.adjacency_matrix(transpose=False, ctx=device)
        adj_new = torch.sparse.mm(adj, assign_tensor)
        adj_new = torch.mm(torch.t(assign_tensor), adj_new)

        if self.link_pred:
            current_lp_loss = torch.norm(adj.to_dense() - torch.mm(
                assign_tensor, torch.t(assign_tensor))) / np.power(
                    g.number_of_nodes(), 2)
            self.loss_log['LinkPredLoss'] = current_lp_loss

        for loss_layer in self.reg_loss:
            loss_name = str(type(loss_layer).__name__)
            self.loss_log[loss_name] = loss_layer(adj, adj_new, assign_tensor)

        return adj_new, h
Example #5
    def loss(self, *views):
        # https://www.uta.edu/math/_docs/preprint/2014/rep2014_04.pdf
        # H is n_views * n_samples * k

        # Subtract the mean from each output
        views = [view - view.mean(dim=0) for view in views]

        # Concatenate all views and from this get the cross-covariance matrix
        all_views = torch.cat(views, dim=1)
        C = torch.matmul(all_views.T, all_views)

        # Get the block covariance matrix, placing X_i^T X_i on the diagonal
        D = torch.block_diag(*[torch.matmul(view.T, view) for view in views])

        # In MCCA the generalized eigenvalue problem is Cv = lambda Dv

        # Use the Cholesky factor D = R^T R to whiten C: R^{-T} C R^{-1} w = lambda w
        R = torch.linalg.cholesky(D).mT  # upper-triangular (torch.cholesky is removed)

        C_whitened = torch.inverse(R.T) @ C @ torch.inverse(R)

        eigvals = torch.linalg.eigvalsh(C_whitened)  # replaces the removed torch.symeig

        # Sort eigenvalues so the largest come first
        idx = torch.argsort(eigvals, descending=True)

        # Sum the first #latent_dims values (after subtracting 1).
        corr = (eigvals[idx][:self.latent_dims] - 1).sum()

        return -corr
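An illustrative check of the whitening step, not part of the original code: with the Cholesky factorization D = R^T R, the eigenvalues of R^{-T} C R^{-1} coincide with the generalized eigenvalues of Cv = lambda Dv, which can be compared against the spectrum of D^{-1} C on small random matrices.

import torch

torch.manual_seed(0)
X = torch.randn(10, 4)
C = X.T @ X + torch.eye(4)           # symmetric positive definite
D = torch.diag(torch.rand(4) + 1.0)  # symmetric positive definite
R = torch.linalg.cholesky(D).mT      # upper-triangular, D = R^T R
lam_white = torch.linalg.eigvalsh(torch.inverse(R.T) @ C @ torch.inverse(R))
lam_gen = torch.linalg.eigvals(torch.inverse(D) @ C).real.sort().values
print(torch.allclose(lam_white, lam_gen, atol=1e-4))  # True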
Example #6
    def gen_training_tensors(self):
        """Make PyTorch x and y tensors for training DisjointDomainNet"""

        item_mat, context_mat, attr_mat = dd.make_io_mats(
            ctx_per_domain=self.ctx_per_domain,
            attrs_per_context=self.attrs_per_context,
            attrs_set_per_item=self.attrs_set_per_item,
            n_domains=self.n_domains,
            cluster_info=self.cluster_info,
            last_domain_cluster_info=self.last_domain_cluster_info,
            repeat_attrs_over_domains=self.repeat_attrs_over_domains,
            share_ctx=self.share_ctx,
            share_attr_units_in_domain=self.share_attr_units_in_domain,
            padding_attrs=self.padding_attrs)

        x_item = torch.tensor(item_mat, dtype=self.torchfp, device=self.device)
        x_context = torch.tensor(context_mat,
                                 dtype=self.torchfp,
                                 device=self.device)
        y = torch.tensor(attr_mat, dtype=self.torchfp, device=self.device)

        y_domain_mask = torch.block_diag(*[
            torch.ones([s // self.n_domains for s in y.shape],
                       dtype=self.torchfp,
                       device=self.device) for _ in range(self.n_domains)
        ])

        return x_item, x_context, y, y_domain_mask
Example #7
def measure_block_diag_perf(num_mats, mat_dim_size, iters, dtype, device):
    if dtype in [torch.float32, torch.float64]:
        mats = [
            torch.rand(mat_dim_size, mat_dim_size, dtype=dtype, device=device)
            for i in range(num_mats)
        ]
    else:
        mats = [
            torch.randint(0xdeadbeef, (mat_dim_size, mat_dim_size),
                          dtype=dtype,
                          device=device) for i in range(num_mats)
        ]

    # Run the timing loop twice; the first pass serves as a warmup and its
    # timings are overwritten by the second pass.
    for _ in range(2):
        torch_time_start = time.time()
        for i in range(iters):
            torch_result = torch.block_diag(*mats)
        torch_time = time.time() - torch_time_start

        workaround_time_start = time.time()
        for i in range(iters):
            workaround_result = block_diag_workaround(*mats)
        workaround_time = time.time() - workaround_time_start

    if not torch_result.equal(workaround_result):
        print("Results do not match!!")
        exit(1)
    return torch_time, workaround_time
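The benchmark above compares torch.block_diag against a block_diag_workaround that is not shown; a minimal sketch of what such a workaround could look like (an assumption, not the original implementation) preallocates the output and copies each 2-D block into its diagonal slot.

import torch

def block_diag_workaround(*mats):
    # Hypothetical stand-in: build the block-diagonal result manually.
    rows = sum(m.shape[0] for m in mats)
    cols = sum(m.shape[1] for m in mats)
    out = torch.zeros(rows, cols, dtype=mats[0].dtype, device=mats[0].device)
    r = c = 0
    for m in mats:
        out[r:r + m.shape[0], c:c + m.shape[1]] = m
        r += m.shape[0]
        c += m.shape[1]
    return out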
Example #8
    def __init__(
        self,
        in_hidden_channels,
        mid_hidden_channels,
        sphere_size_lat,
        sphere_size_long,
        sphere_message,
        act,
        lmax,
    ):
        super(SpinConvBlock, self).__init__()
        self.in_hidden_channels = in_hidden_channels
        self.mid_hidden_channels = mid_hidden_channels
        self.sphere_size_lat = sphere_size_lat
        self.sphere_size_long = sphere_size_long
        self.sphere_message = sphere_message
        self.act = act
        self.lmax = lmax
        self.num_groups = self.in_hidden_channels // 8

        self.ProjectLatLongSphere = ProjectLatLongSphere(
            sphere_size_lat, sphere_size_long)
        assert self.sphere_message in [
            "fullconv",
            "rotspharmwd",
        ]
        if self.sphere_message in ["rotspharmwd"]:
            self.sph_froms2grid = FromS2Grid(
                (self.sphere_size_lat, self.sphere_size_long), self.lmax)
            self.mlp = nn.Linear(
                self.in_hidden_channels * (self.lmax + 1)**2,
                self.mid_hidden_channels,
            )
            self.sphlength = (self.lmax + 1)**2
            rotx = torch.zeros(
                self.sphere_size_long) + (2 * math.pi / self.sphere_size_long)
            roty = torch.zeros(self.sphere_size_long)
            rotz = torch.zeros(self.sphere_size_long)

            self.wigner = []
            for xrot, yrot, zrot in zip(rotx, roty, rotz):
                _blocks = []
                for l_degree in range(self.lmax + 1):
                    _blocks.append(o3.wigner_D(l_degree, xrot, yrot, zrot))
                self.wigner.append(torch.block_diag(*_blocks))

        if self.sphere_message == "fullconv":
            padding = self.sphere_size_long // 2
            self.conv1 = nn.Conv1d(
                self.in_hidden_channels * self.sphere_size_lat,
                self.mid_hidden_channels,
                self.sphere_size_long,
                groups=self.in_hidden_channels // 8,
                padding=padding,
                padding_mode="circular",
            )
            self.pool = nn.AvgPool1d(sphere_size_long)

        self.GroupNorm = nn.GroupNorm(self.num_groups,
                                      self.mid_hidden_channels)
Example #9
    def __init__(self, inputs):
        Xs = [b.X for b in inputs]
        As = [b.A for b in inputs]
        Ys = [b.y for b in inputs]

        self.A = torch.block_diag(*As).cuda()
        self.X = torch.cat(Xs).cuda()
        self.Y = Ys
        self.graph_sizes = [len(x) for x in Xs]
Example #10
 def forward(self):
     """
     Forward computation
     Returns
     -------
     weight : torch.Tensor
         Block-diagonal weight tensor of shape self.shape.
     """
     # Create matrix from blocks
     return torch.block_diag(*torch.unbind(self.weight, dim=0),
                             self.last_weight)
Example #11
def get_tree_masks(n_tokens, n_transformers):
    masks = torch.empty((n_transformers, n_tokens, n_tokens),
                        dtype=torch.float32)
    partitions = min(n_tokens // 2, 2**(n_transformers - 1))
    for i in range(n_transformers):
        assert n_tokens % partitions == 0, "n_tokens must be divisible by partitions (a power of 2)"
        blocks = [torch.ones(n_tokens // partitions, n_tokens // partitions)
                  ] * partitions
        masks[i] = torch.block_diag(*blocks)
        if partitions > 1:
            partitions //= 2
    return masks.unsqueeze(1)
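For illustration (values chosen arbitrarily), get_tree_masks(8, 3) yields a coarsening hierarchy: four 2x2 all-ones blocks in the first mask, two 4x4 blocks in the second, and a single 8x8 block in the last.

masks = get_tree_masks(n_tokens=8, n_transformers=3)
print(masks.shape)  # torch.Size([3, 1, 8, 8])
print(masks[0, 0])  # block-diagonal mask with four 2x2 blocks of ones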
Example #12
 def forward(self, x_cond, x_to_film):
     gb = self.gb_weights(x_cond).unsqueeze(1)
     gamma, beta = torch.chunk(gb, 2, dim=-1)
     out = (1 + gamma) * x_to_film + beta
     if self.layer_norm is not None:
         out = self.layer_norm(out)
     out = [
         torch.block_diag(*list(out_b.chunk(self.blocks, 0)))
         for out_b in out
     ]
     out = torch.stack(out)
     return out[:, :, :out.size(1)]
Example #13
 def _sparsify(self, tensor, rank_j, k_split, scheme):
     """Sparsify to the desired number of features (rank_j)."""
     # on some rare occasions Messi fails
     # --> then fall back to the grouped projective sparsifier and stitch the result
     try:
         return self._sparsify_with_messi(tensor, rank_j, k_split, scheme)
     except ValueError:
         weights_hat = super()._sparsify(tensor, rank_j, k_split, scheme)
         u_chunks, v_chunks = zip(*weights_hat)
         print(
             "Messi failed."
             " Falling back to grouped decomposition sparsification"
         )
         return [(torch.cat(u_chunks, dim=1), torch.block_diag(*v_chunks))]
Example #14
    def loss(self, *views):
        # Subtract the mean from each output
        views = _demean(*views)

        # Concatenate all views and from this get the cross-covariance matrix
        all_views = torch.cat(views, dim=1)
        C = all_views.T @ all_views

        # Get the block covariance matrix, placing X_i^T X_i on the diagonal
        D = torch.block_diag(
            *[
                (1 - self.r) * m.T @ m + self.r * torch.eye(m.shape[1], device=m.device)
                for m in views
            ]
        )

        C = C - torch.block_diag(*[view.T @ view for view in views]) + D

        R = mat_pow(D, -0.5, self.eps)

        # In MCCA our eigenvalue problem Cv = lambda Dv
        C_whitened = R @ C @ R.T

        eigvals = torch.linalg.eigvalsh(C_whitened)

        # Sort eigenvalues so the largest come first
        idx = torch.argsort(eigvals, descending=True)

        eigvals = eigvals[idx[: self.latent_dims]]

        # LeakyReLU encourages the gradient to be driven by positively correlated
        # dimensions while nudging dimensions with spurious negative correlations
        # to become more positive.
        eigvals = torch.nn.LeakyReLU()(eigvals[torch.gt(eigvals, 0)] - 1)

        corr = eigvals.sum()

        return -corr
Example #15
def optimal_JJT_blk():
    jac_list = 0
    bc = BATCH_SIZE * num_classes
    # L = []

    with backpack(TRIAL(MODE)):
        loss = loss_function(output, y)
        loss.backward(retain_graph=True)
    for name, param in model.named_parameters():
        trial_vals = param.trial
        # L.append([trial_vals / BATCH_SIZE, name])
        jac_list += torch.block_diag(*trial_vals)
        param.trial = None
    JJT = jac_list / BATCH_SIZE
    return JJT
Example #16
    def __init__(self, in_features, out_features, num_of_modules, bias=True):
        """
        extended torch.nn module which mask connection.
        Argumens
        ------------------
        mask [torch.tensor]:
            the shape is (n_input_feature, n_output_feature).
            the elements are 0 or 1 which declare un-connected or
            connected.
        bias [bool]:
            flg of bias.
        """
        super(MaskedLinear, self).__init__()
        mask = torch.block_diag(
            *([torch.ones(in_features, out_features)] * num_of_modules))
        self.input_features = mask.shape[0]
        self.output_features = mask.shape[1]
        if isinstance(mask, torch.Tensor):
            self.mask = mask.type(torch.float).t()
        else:
            self.mask = torch.tensor(mask, dtype=torch.float).t()

        self.mask = nn.Parameter(self.mask, requires_grad=False)

        # nn.Parameter is a special kind of Tensor, that will get
        # automatically registered as Module's parameter once it's assigned
        # as an attribute. Parameters and buffers need to be registered, or
        # they won't appear in .parameters() (doesn't apply to buffers), and
        # won't be converted when e.g. .cuda() is called. You can use
        # .register_buffer() to register buffers.
        # nn.Parameters require gradients by default.
        self.weight = nn.Parameter(
            torch.Tensor(self.output_features, self.input_features))

        if bias:
            self.bias = nn.Parameter(torch.Tensor(self.output_features))
        else:
            # You should always register all possible parameters, but the
            # optional ones can be None if you want.
            self.register_parameter('bias', None)
        self.reset_parameters()

        # mask weight
        self.weight.data = self.weight.data * self.mask
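For illustration, the connectivity mask the constructor builds for in_features=3, out_features=2, num_of_modules=2, shown standalone so it runs without the rest of the class:

import torch

mask = torch.block_diag(*([torch.ones(3, 2)] * 2))
print(mask.shape)  # torch.Size([6, 4]); transposed to (4, 6) inside the module
print(mask.t())    # each module's outputs connect only to its own inputs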
Example #17
def test_attn_mask():

    torch.set_default_dtype(torch.float64)

    T, N, D = 8, 1, 20

    attn_mask = torch.triu(torch.ones(T, T), diagonal=1) * -1e12

    x = torch.randn(T * N * D).requires_grad_(True)
    mha = L2MultiheadAttention(D, 1)

    y = mha(x.reshape(T, N, D), attn_mask=attn_mask)
    yhat = mha(x.reshape(T, N, D), attn_mask=attn_mask, rm_nonself_grads=True)
    print(torch.norm(y - yhat))

    # Construct full Jacobian.
    def func(x):
        return mha(x.reshape(T, N, D), attn_mask=attn_mask).reshape(-1)

    jac = torch.autograd.functional.jacobian(func, x)

    # Exact diagonal block of Jacobian.
    jac = jac.reshape(T, D, T, D)
    blocks = []
    for i in range(T):
        blocks.append(jac[i, :, i, :])
    jac_block_diag = torch.block_diag(*blocks)

    # Simulated diagonal block of Jacobian.
    def selfonly_func(x):
        return mha(x.reshape(T, N, D), attn_mask=attn_mask, rm_nonself_grads=True).reshape(-1)
    simulated_jac_block_diag = torch.autograd.functional.jacobian(selfonly_func, x)

    print(torch.norm(simulated_jac_block_diag - jac_block_diag))

    import matplotlib.pyplot as plt

    fig, axs = plt.subplots(1, 3)
    axs[0].imshow(jac_block_diag)
    axs[1].imshow(simulated_jac_block_diag)
    axs[2].imshow(torch.abs(simulated_jac_block_diag - jac_block_diag))
    plt.savefig("jacobian.png")
Example #18
    def forward(self, z):
        rho, n_elements = self.rho, self.n_elements
        # flatten latent variables & get dims
        z = z.view(z.size(0), -1)
        batch_size, z_dim = z.size()
        assert z_dim % n_elements == 0
        groups = z_dim // n_elements

        # apply mask to extract groups
        mask = torch.block_diag(
            *[torch.ones(n_elements) for i in range(groups)]
        ).unsqueeze(0).repeat(batch_size, 1, 1).to(z.device)
        z_groups = z.unsqueeze(1).repeat(1, groups, 1)
        masked_z = z_groups * mask
        rho_hat = masked_z.norm(p=2, dim=-1)

        sparsity_penalty = (
            (rho * (torch.log(rho) - torch.log(rho_hat))) + (
                    (1 - rho) * (torch.log(1 - rho) - torch.log(1 - rho_hat)))).sum(-1)
        return sparsity_penalty.mean()
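A small standalone demonstration of the mask construction above (group sizes chosen arbitrarily): torch.block_diag treats 1-D inputs as rows, so the mask has shape (groups, z_dim) and row g selects exactly group g's elements.

import torch

mask = torch.block_diag(*[torch.ones(2) for _ in range(4)])
print(mask.shape)  # torch.Size([4, 8])
print(mask)        # row g is ones on group g's two elements, zeros elsewhere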
Example #19
    def get_original_module(self):
        """Return an "unprojected" version of the module."""
        encoding = self.encoding
        decoding = self.decoding

        # get groups
        num_groups = self.groups_enc.item()

        # get scheme
        scheme = FoldScheme(self.scheme_value.item())

        # get resulting encoding and decoding weights in right shape
        weight_enc = torch.block_diag(*[
            scheme.fold(w_enc)
            for w_enc in torch.chunk(encoding.weight, num_groups, dim=0)
        ])
        weight_dec = scheme.fold(decoding.weight)

        # retrieve module kwargs from scheme and kernel_size
        kwargs_original = scheme.compose_kwargs(decoding, encoding)
        try:
            k_size = kwargs_original["kernel_size"]
        except KeyError:
            k_size = ()

        # build original weights
        w_original = scheme.unfold(weight_dec @ weight_enc, k_size)

        # build original module
        kwargs_weight = self._get_init_kwargs(w_original, 1)
        kwargs_weight.pop("groups")
        module_original = self._ungrouped_module_type(**kwargs_weight,
                                                      **kwargs_original)
        module_original.weight = nn.Parameter(w_original)
        bias = decoding.bias
        module_original.bias = None if bias is None else nn.Parameter(bias)

        return module_original
Example #20
    def __init__(self,
                 input_size,
                 hidden_size,
                 bias=True,
                 blocksize=128,
                 sparsity=0.5,
                 mode='sparse'):
        super(GatedLSTMCell, self).__init__(input_size,
                                            hidden_size,
                                            bias,
                                            num_chunks=4)
        self.block_size = blocksize

        if mode == 'dynamic_torch':
            self.blockmul = self.blockmul1
            self.ih_nblock = int(
                (input_size * hidden_size * 4) / (blocksize * blocksize))
            self.hh_nblock = int(
                (hidden_size * hidden_size * 4) / (blocksize * blocksize))
            self.g_ih = nn.Sequential(nn.Linear(input_size, self.ih_nblock),
                                      Sparsify1D(sparsity))
            self.g_hh = nn.Sequential(nn.Linear(hidden_size, self.hh_nblock),
                                      Sparsify1D(sparsity))

        elif mode == 'static':
            # todo:@amir: clean it
            self.blockmul = self.blockmul3
            if sparsity == 0.5:
                block = torch.ones(
                    (int(((hidden_size / blocksize) / 2) * blocksize),
                     int((((hidden_size * 4) / blocksize) / 2) * blocksize)),
                    dtype=torch.float)
                self.g_ih = torch.block_diag(*([block] * 2))
                self.g_hh = torch.block_diag(*([block] * 2))
            elif sparsity == 0.75:
                block = torch.ones(
                    (int(((hidden_size / blocksize) / 4) * blocksize),
                     int((((hidden_size * 4) / blocksize) / 4) * blocksize)),
                    dtype=torch.float)
                self.g_ih = torch.block_diag(*([block] * 4))
                self.g_hh = torch.block_diag(*([block] * 4))
            elif sparsity == 0.9:
                block = torch.ones(
                    (int(((hidden_size / blocksize) / 12) * blocksize),
                     int((((hidden_size * 4) / blocksize) / 12) * blocksize)),
                    dtype=torch.float)
                self.g_ih = torch.block_diag(*([block] * 12))
                self.g_hh = torch.block_diag(*([block] * 12))
        elif mode == 'dynamic':
            self.blockmul = self.blockmul2
            self.ih_nblock = int(self.input_size / self.block_size)
            self.hh_nblock = int(self.hidden_size / self.block_size)
            self.g_ih = nn.Sequential(
                nn.Linear(input_size, self.ih_nblock * self.hh_nblock * 4),
                Sparsify1D(sparsity))
            self.g_hh = nn.Sequential(
                nn.Linear(hidden_size, self.hh_nblock * self.hh_nblock * 4),
                Sparsify1D(sparsity))
        else:
            raise NotImplementedError(f'unsupported mode: {mode}')
        self.epoch = 1
Example #21
    def __init__(self, **net_params):
        super(FamilyTreeNet, self).__init__()

        # Merge default params with overrides and make them properties
        net_params = {**net_defaults, **net_params}
        for key, val in net_params.items():
            setattr(self, key, val)

        self.device, self.torchfp, self.zeros_fn = util.init_torch(self.device, self.torchfp)
        if self.device.type == 'cuda':
            print('Using CUDA')
        else:
            print('Using CPU')
        
        if self.seed is None:
            self.seed = torch.seed()
        else:
            torch.manual_seed(self.seed)
        
        # Get training tensors
        self.n_trees = len(self.trees)
        assert self.n_trees > 0, 'Must learn at least one tree'
        self.person1_mat = self.zeros_fn((0, 0))
        self.person2_mat = self.zeros_fn((0, 0))
        self.p2_tree_mask = self.zeros_fn((0, 0))
        self.rel_mat = self.zeros_fn((0, 12 if self.share_rel_units else 0))
        self.full_tree = familytree.FamilyTree([], [])
        self.each_tree = []
        
        for i, tree_name in enumerate(self.trees):
            this_tree = familytree.get_tree(name=tree_name)
            self.full_tree += this_tree
            self.each_tree.append(this_tree)
            this_p1, this_rels, this_p2 = this_tree.get_io_mats(zeros_fn=self.zeros_fn, cat_fn=torch.cat) 
            
            self.person1_mat = torch.block_diag(self.person1_mat, this_p1)
            self.person2_mat = torch.block_diag(self.person2_mat, this_p2)
            self.p2_tree_mask = torch.block_diag(self.p2_tree_mask,
                                                   torch.ones_like(this_p2))
            if self.share_rel_units:
                self.rel_mat = torch.cat((self.rel_mat, this_rels), 0)
            else:
                self.rel_mat = torch.block_diag(self.rel_mat, this_rels)
            
            if i == 0:
                self.n_inputs_first, self.n_people_first = this_p1.shape
        
        self.n_inputs, self.person1_units = self.person1_mat.shape
        self.rel_units = self.rel_mat.shape[1]
        self.person2_units = self.person2_mat.shape[1] 
                    
        # Make layers
        def make_layer(in_size, out_size):
            return nn.Linear(in_size, out_size, bias=self.use_biases).to(self.device)

        self.person1_to_repr = make_layer(self.person1_units, self.person1_repr_units)
        self.rel_to_repr = make_layer(self.rel_units, self.rel_repr_units)
        total_repr_units = self.person1_repr_units + self.rel_repr_units
        self.repr_to_hidden = make_layer(total_repr_units, self.hidden_units)

        if self.use_preoutput:
            self.hidden_to_preoutput = make_layer(self.hidden_units, self.preoutput_units)
            self.preoutput_to_person2 = make_layer(self.preoutput_units, self.person2_units)
        else:
            self.hidden_to_person2 = make_layer(self.hidden_units, self.person2_units)
        
        # Initialize with small random weights
        def init_uniform(param, offset, prange):
            a = offset - prange/2
            b = offset + prange/2
            nn.init.uniform_(param.data, a=a, b=b)

        def init_normal(param, offset, prange):
            nn.init.normal_(param.data, mean=offset, std=prange/2)

        def init_default(*_):
            pass

        init_fns = {'default': init_default, 'uniform': init_uniform, 'normal': init_normal}

        with torch.no_grad():
            for layer in self.children():
                try:
                    init_fns[self.weight_init_type](layer.weight, self.weight_init_offset, self.weight_init_range)
                except KeyError:
                    raise ValueError('Weight initialization type not recognized')

                if layer.bias is not None:
                    try:
                        init_fns[self.bias_init_type](layer.bias, self.bias_init_offset, self.bias_init_range)
                    except KeyError:
                        raise ValueError('Bias initialization type not recognized')

        # For simplicity, instead of using the "liberal" loss function described in the paper, make the targets 
        # 0.1 (for false) and 0.9 (for true) and use regular mean squared error.
        if self.act_fn == torch.sigmoid:
            self.person2_train_target = (1-self.target_offset) * self.person2_mat + self.target_offset/2
        else:
            self.person2_train_target = self.person2_mat

        self.criterion = self.loss_fn(reduction=self.loss_reduction)
Example #22
 def f(x):
     # Test an op with TensorList input
     y = torch.block_diag(x, x)
     return y
Example #23
        ### Blocked NGD version
        # start_time = time.time()
        # JJT_opt_blk = optimal_JJT_blk()
        # print(torch.norm(JJT_opt))
        # print(JJT_opt)
        # time_opt = time.time() - start_time

        # plotting NGD kernel for some iterations
        if PLOT and batch_idx in [2, 10, 50, 600]:
            # JJT_opt_blk = optimal_JJT_blk()

            JJT_opt, L, _, _ = optimal_JJT(True)
            x = torch.ones(1, BATCH_SIZE, BATCH_SIZE)
            x = x.repeat(num_classes, 1, 1)
            eye_blk = torch.block_diag(*x)
            diff = JJT_opt - JJT_opt * eye_blk
            # u, s, vh = torch.linalg.svd(diff)
            # s_normal = torch.cumsum(s, dim = 0)/torch.sum(s)
            # print(s_normal.numpy())
            # fig, ax = plt.subplots()
            # im = ax.plot(s_normal)
            # print(s)
            # fig.colorbar(im,  orientation='horizontal')
            # plt.show()

            fig, ax = plt.subplots()
            im = ax.imshow(JJT_opt - JJT_opt * eye_blk, cmap='viridis')
            fig.colorbar(im, orientation='horizontal')

            plt.show()
Example #24
 def other_ops(self):
     a = torch.randn(4)
     b = torch.randn(4)
     c = torch.randint(0, 8, (5, ), dtype=torch.int64)
     e = torch.randn(4, 3)
     f = torch.randn(4, 4, 4)
     size = [0, 1]
     dims = [0, 1]
     return (
         torch.atleast_1d(a),
         torch.atleast_2d(a),
         torch.atleast_3d(a),
         torch.bincount(c),
         torch.block_diag(a),
         torch.broadcast_tensors(a),
         torch.broadcast_to(a, (4,)),
         # torch.broadcast_shapes(a),
         torch.bucketize(a, b),
         torch.cartesian_prod(a),
         torch.cdist(e, e),
         torch.clone(a),
         torch.combinations(a),
         torch.corrcoef(a),
         # torch.cov(a),
         torch.cross(e, e),
         torch.cummax(a, 0),
         torch.cummin(a, 0),
         torch.cumprod(a, 0),
         torch.cumsum(a, 0),
         torch.diag(a),
         torch.diag_embed(a),
         torch.diagflat(a),
         torch.diagonal(e),
         torch.diff(a),
         torch.einsum("iii", f),
         torch.flatten(a),
         torch.flip(e, dims),
         torch.fliplr(e),
         torch.flipud(e),
         torch.kron(a, b),
         torch.rot90(e),
         torch.gcd(c, c),
         torch.histc(a),
         torch.histogram(a),
         torch.meshgrid(a),
         torch.lcm(c, c),
         torch.logcumsumexp(a, 0),
         torch.ravel(a),
         torch.renorm(e, 1, 0, 5),
         torch.repeat_interleave(c),
         torch.roll(a, 1, 0),
         torch.searchsorted(a, b),
         torch.tensordot(e, e),
         torch.trace(e),
         torch.tril(e),
         torch.tril_indices(3, 3),
         torch.triu(e),
         torch.triu_indices(3, 3),
         torch.vander(a),
         torch.view_as_real(torch.randn(4, dtype=torch.cfloat)),
         torch.view_as_complex(torch.randn(4, 2)),
         torch.resolve_conj(a),
         torch.resolve_neg(a),
     )
Example #25
    def circuit_solve(self, conductances,  v_wl_in, v_bl_in, v_bl_out, v_wl_out):

        g_wl, g_bl = self.g_wl, self.g_bl
        g_s_wl_in, g_s_wl_out = torch.ones(self.tile_rows) * 1, torch.ones(self.tile_rows) * 1e-9
        # Bit-line source conductances are per column (they are indexed by j below).
        g_s_bl_in, g_s_bl_out = torch.ones(self.tile_cols) * 1e-9, torch.ones(self.tile_cols) * 1
        m, n = self.tile_rows, self.tile_cols
        
        A = torch.block_diag(*tuple(torch.diag(conductances[i,:])
                          + torch.diag(torch.cat((g_wl, g_wl * 2 * torch.ones(n-2), g_wl)))
                          + torch.diag(g_wl * -1 *torch.ones(n-1), diagonal = 1)
                          + torch.diag(g_wl * -1 *torch.ones(n-1), diagonal = -1)
                          + torch.diag(torch.cat((g_s_wl_in[i].view(1), torch.zeros(n - 2), g_s_wl_out[i].view(1))))
                                   for i in range(m)))

        B = torch.block_diag(*tuple(-torch.diag(conductances[i,:]) for i in range(m)))
        
        def makec(j):
            c = torch.zeros(m, m*n)
            for i in range(m):
                c[i,n*(i) + j] = conductances[i,j]
            return c
  
        C = torch.cat([makec(j) for j in range(n)],dim=0)
        
        def maked(j):
            d = torch.zeros(m, m*n)

            def c(k): 
                return(k - 1)
            
            i = 1
            d[c(i),c(j)] = -g_s_bl_in[c(j)] - g_bl - conductances[c(i),c(j)]
            d[c(i), n*i + c(j)] = g_bl

            i = m
            d[c(i), n*(i-2) + c(j)] = g_bl
            d[c(i), n*(i-1) + c(j)] = -g_s_bl_out[c(j)] - conductances[c(i),c(j)] - g_bl

            for i in range(2, m):
                d[c(i), n*(i-2) + c(j)] = g_bl
                d[c(i), n*(i-1) + c(j)] = -g_bl - conductances[c(i),c(j)] - g_bl
                d[c(i), n*(i) + c(j)] = g_bl

            return d

        D = torch.cat([maked(j) for j in range(1,n+1)], dim=0)

        E = torch.cat([torch.cat(((v_wl_in[i]*g_s_wl_in[i]).view(1), #EW
                                  torch.zeros(n-2),
                                  (v_wl_out[i]*g_s_wl_out[i]).view(1)))
                                 for i in range(m)] +
                      [torch.cat(((-v_bl_in[i]*g_s_bl_in[i]).view(1), #EB
                                  torch.zeros(m-2),
                                  (-v_bl_out[i]*g_s_bl_out[i]).view(1)))  # v_bl_out at the output end
                                 for i in range(n)]
        ).view(-1, 1)

        M = torch.cat((torch.cat((A,B),dim=1), torch.cat((C,D),dim=1)), dim=0)
        
        # torch.solve was removed; torch.linalg.solve(M, E) solves M @ V = E.
        V = torch.chunk(torch.linalg.solve(M, E), 2)

        return torch.sum((V[1] - V[0]).view(m, n) * conductances, dim=0)