Example #1
def update_auxgradient(model, amodel, projection='ARGMAX', beta=1.):
    """ update gradient of u by approximate gradient of argmax-projection
        g_u = (g_u * (d P(tu)/d (tu)), approximate using straight-through est.
    """
    if model.qlevels != 2 and projection == 'ARGMAX':
        print('Only binary labels are supported, exiting ...')
        exit()

    for i, (name, p) in enumerate(model.named_parameters()):
        w = su.view_w_as_u(
            amodel.auxparams[i].data,
            model.qlevels)  # convert to desired matrix format (m x d)
        g = su.view_w_as_u(
            p.grad.data,
            model.qlevels)  # convert to desired matrix format (m x d)

        if projection == 'ARGMAX':  # grad through argmax using grad of sign function
            dw = w[:, 0] - w[:, 1]
            ind = torch.ones(dw.size(), device=dw.device, dtype=dw.dtype)
            ind[dw.lt(-1.0)] = 0.
            ind[dw.gt(1.0)] = 0.
            ind.unsqueeze_(dim=1)
            ind = torch.cat((ind, -ind), dim=1)  # m x 2
            ind = torch.cat((ind, -ind), dim=1)  # m x 4
            ind = ind.view(-1, 2, 2)  # m x d x d (d==2)
            g = torch.bmm(g.unsqueeze_(dim=1),
                          ind.div_(2.))  # (m x 1 x d) x (m x d x d)
            g.squeeze_()

        elif projection == 'SOFTMAX':  # grad through softmax
            w = su.view_w_as_u(
                p.data,
                model.qlevels)  # convert to desired matrix format (m x d)
            ww = torch.bmm(w.unsqueeze(dim=2),
                           w.unsqueeze(dim=1))  # (m x d x 1) x (m x 1 x d)
            ww = torch.eye(w.size(1), dtype=w.dtype, device=w.device) - ww
            g = torch.bmm(g.unsqueeze_(dim=1), ww)  # (m x 1 x d) x (m x d x d)
            g.mul_(beta)
            g.squeeze_()

        elif projection == 'EUCLIDEAN':  # grad through sparsemax
            pass

        else:
            print('Projection type "{0}" not recognized in update gradient, exiting ...'.format(
                projection))
            exit()

        p.grad.data = su.view_u_as_w(
            g, p.grad.data)  # convert N x d to original format
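A minimal self-contained sketch of the binary (d == 2) ARGMAX branch above, on toy tensors. Here `w` and `g` stand in for the (m x d) views produced by su.view_w_as_u, which is project-specific and not reproduced; the estimator is clipped outside [-1, 1], like the derivative of a hard tanh.

import torch

m = 4
w = torch.randn(m, 2)  # auxiliary simplex coordinates (m x 2)
g = torch.randn(m, 2)  # incoming gradient w.r.t. the projected point (m x 2)

dw = w[:, 0] - w[:, 1]
ind = torch.ones_like(dw)
ind[dw.lt(-1.0)] = 0.  # zero the estimator outside [-1, 1]
ind[dw.gt(1.0)] = 0.
ind = ind.unsqueeze(dim=1)
ind = torch.cat((ind, -ind), dim=1)  # m x 2
ind = torch.cat((ind, -ind), dim=1)  # m x 4
jac = ind.view(-1, 2, 2).div_(2.)  # m x 2 x 2 straight-through Jacobian

g_u = torch.bmm(g.unsqueeze(dim=1), jac).squeeze(dim=1)  # vector-Jacobian product
print(g_u.shape)  # torch.Size([4, 2])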
Example #2
File: pgdsimplex.py Project: tajanthan/pmf
def doround(model,
            device,
            scheme='ARGMAX',
            data=None,
            target=None,
            optimizer=None,
            criterion=None):
    """ do rounding given the feasible point in the polytope
    """
    if scheme == 'ARGMAX':
        for i, p in enumerate(model.parameters()):
            w = su.view_w_as_u(
                p.data,
                model.qlevels)  # convert to desired matrix format (N x d)

            wi = w.argmax(dim=1, keepdim=True)
            ids = torch.arange(model.qlevels, dtype=torch.long, device=device)
            ids = ids.repeat(w.size(0), 1)
            w = ids.eq(wi)
            w = w.float()
            p.data = su.view_u_as_w(w,
                                    p.data)  # convert N x d to original format
    else:
        print('Rounding type "{0}" not recognized, returning ...'.format(
            scheme))
        return
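The argmax rounding above snaps every row of the (N x d) matrix to a one-hot vertex of the simplex. A standalone sketch on a toy matrix, independent of the model and the su helpers:

import torch

w = torch.tensor([[0.2, 0.8],
                  [0.9, 0.1],
                  [0.6, 0.4]])
qlevels = w.size(1)

wi = w.argmax(dim=1, keepdim=True)  # index of the winning label per row
ids = torch.arange(qlevels).repeat(w.size(0), 1)  # N x d index grid
w_rounded = ids.eq(wi).float()  # one-hot vertices
print(w_rounded)  # tensor([[0., 1.], [1., 0.], [1., 0.]])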
Example #3
def doround(model,
            device,
            scheme='ARGMAX',
            data=None,
            target=None,
            optimizer=None,
            criterion=None):
    """ do rounding given the feasible point in the polytope
    """
    if scheme == 'ARGMAX':
        for i, (name, p) in enumerate(model.named_parameters()):
            w = su.view_w_as_u(
                p.data,
                model.qlevels)  # convert to desired matrix format (N x d)

            wi = w.argmax(dim=1, keepdim=True)
            ids = torch.arange(model.qlevels, dtype=torch.long, device=device)
            ids = ids.repeat(w.size(0), 1)
            w = ids.eq(wi)
            w = w.float()
            p.data = su.view_u_as_w(w,
                                    p.data)  # convert N x d to original format
    elif scheme == 'ICM':
        if data is None or target is None or optimizer is None or criterion is None:
            pass  # use last minibatch gradient
        else:
            # compute full gradient
            pass
        for i, (name, p) in enumerate(model.named_parameters()):
            p.data = su.icm(model.qlevels, p.grad.data)

    else:
        print('Rounding type "{0}" not recognized, returning ...'.format(
            scheme))
        return
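The su.icm helper used in the ICM branch is not shown in these examples. A plausible reading, consistent with its (qlevels, grad) call above, is that it picks, per row, the label that minimizes the linearized loss. The sketch below is an assumption for illustration, not the project's implementation.

import torch

def icm_round(qlevels, grad_u):
    """Hypothetical ICM-style rounding: per row of the (N x d) gradient,
    pick the vertex minimizing the first-order (linearized) loss g . w."""
    gi = grad_u.argmin(dim=1, keepdim=True)  # steepest-descent label per row
    ids = torch.arange(qlevels, device=grad_u.device).repeat(grad_u.size(0), 1)
    return ids.eq(gi).to(grad_u.dtype)  # one-hot vertices

g = torch.randn(3, 2)
print(icm_round(2, g))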
Example #4
File: pgdsimplex.py Project: tajanthan/pmf
def simplex(model, device, projection='SOFTMAX', beta=1):
    """ project the parameters to the feasible polytope
    """
    if projection == 'ARGMAX':  # argmax based projection: rounding!
        doround(model, device, scheme='ARGMAX')
        return

    for i, p in enumerate(model.parameters()):
        w = su.view_w_as_u(
            p.data, model.qlevels)  # convert to desired matrix format (N x d)

        if projection == 'SOFTMAX':  # softmax based simplex projection
            w = F.softmax(w * beta, dim=1)
        elif projection == 'EUCLIDEAN':  # condat based (Euclidean) simplex projection
            w = su.sparsemax(w * beta, model.qlevels)
        else:
            print('Projection type "{0}" not implemented, returning ...'.format(
                projection))

        p.data = su.view_u_as_w(w, p.data)  # convert N x d to original format
        assert (su.isfeasible(model.qlevels, p.data))
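A standalone sketch of the SOFTMAX branch on a toy (N x d) matrix, with the feasibility check made explicit; the EUCLIDEAN branch relies on su.sparsemax, which is project-specific and not reproduced here. Raising beta sharpens each row toward a one-hot vertex, which is why the ARGMAX case reduces to plain rounding.

import torch
import torch.nn.functional as F

w = torch.randn(4, 2)
beta = 5.0
u = F.softmax(w * beta, dim=1)  # each row now lies in the probability simplex

assert torch.allclose(u.sum(dim=1), torch.ones(4))  # rows sum to one
assert (u >= 0).all()  # and are non-negative
print(u)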
Example #5
File: pgdsimplex.py Project: tajanthan/pmf
def update_auxgradient(model, amodel, projection='ARGMAX', beta=1.):
    """ update gradient of u by approximate gradient of argmax-projection
        g_u = (g_u * (d P(tu)/d (tu)), approximate using straight-through est.
    """
    if model.qlevels != 2 and projection == 'ARGMAX':
        print('Only binary labels are supported, exiting ...')
        exit()

    for i, p in enumerate(model.parameters()):
        w = su.view_w_as_u(
            amodel.auxparams[i].data,
            model.qlevels)  # convert to desired matrix format (m x d)
        g = su.view_w_as_u(
            p.grad.data,
            model.qlevels)  # convert to desired matrix format (m x d)

        if projection == 'ARGMAX':  # grad through argmax using grad of sign function
            dw = w[:, 0] - w[:, 1]
            ind = torch.ones(dw.size(), device=dw.device, dtype=dw.dtype)
            ind[dw.lt(-1.0)] = 0.
            ind[dw.gt(1.0)] = 0.
            ind.unsqueeze_(dim=1)
            ind = torch.cat((ind, -ind), dim=1)  # m x 2
            ind = torch.cat((ind, -ind), dim=1)  # m x 4
            ind = ind.view(-1, 2, 2)  # m x d x d (d==2)
            g = torch.bmm(g.unsqueeze_(dim=1),
                          ind.div_(2.))  # (m x 1 x d) x (m x d x d)
            g.squeeze_()

        elif projection == 'SOFTMAX':  # grad through softmax
            w = su.view_w_as_u(
                p.data,
                model.qlevels)  # convert to desired matrix format (m x d)
            w = F.softmax(w * beta, dim=1)  # NOTE: p.data already holds the projected point, so this softmax call could be skipped
            ww = torch.bmm(w.unsqueeze(dim=2),
                           w.unsqueeze(dim=1))  # (m x d x 1) x (m x 1 x d)
            # This approximation works slightly better than the true derivative,
            # and approaches it as beta --> BETAMAX.
            ww = torch.eye(w.size(1), dtype=w.dtype, device=w.device) - ww
            # ww = torch.diag_embed(w) - ww    # true Jacobian: du_i/dtu_j = u_i * (1(i=j) - u_j)
            g = torch.bmm(g.unsqueeze_(dim=1), ww)  # (m x 1 x d) x (m x d x d)
            g.mul_(beta)
            g.squeeze_()

        elif projection == 'EUCLIDEAN':  # grad through sparsemax
            w = su.view_w_as_u(
                p.data,
                model.qlevels)  # convert to desired matrix format (m x d)
            s = w > 0  # indicator to the support set (m x d)
            s = s.to(dtype=w.dtype, device=w.device)
            sd = torch.sum(s, dim=1)  # support size |S(u)| per row, (m,)
            sd = torch.div(1., sd)  # reciprocal of the support size, (m,)
            sd = sd.unsqueeze(dim=1).expand(s.size(0),
                                            s.size(1)).unsqueeze(dim=2).expand(
                                                s.size(0), s.size(1),
                                                s.size(1))  # (m x d x d)

            ss = torch.bmm(s.unsqueeze(dim=2),
                           s.unsqueeze(dim=1))  # (m x d x 1) x (m x 1 x d)
            ss = torch.diag_embed(s) - torch.mul(
                ss, sd)  # du_i/dtu_j = s_i * 1(i=j) - s_i * s_j / |S(u)|
            g = torch.bmm(g.unsqueeze_(dim=1), ss)  # (m x 1 x d) x (m x d x d)
            g.mul_(beta)
            g.squeeze_()

        else:
            print('Projection type "{0}" not recognized in update gradient, exiting ...'.format(
                projection))
            exit()

        p.grad.data = su.view_u_as_w(
            g, p.grad.data)  # convert N x d to original format
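A standalone sketch of the SOFTMAX branch's vector-Jacobian product on toy tensors, contrasting the I - u u^T approximation used above with the true softmax Jacobian diag(u) - u u^T from the commented-out line. The toy shapes and values are assumptions; only the two Jacobian forms come from the example.

import torch
import torch.nn.functional as F

m, d, beta = 4, 3, 2.0
w = torch.randn(m, d)  # unprojected parameters (m x d)
g = torch.randn(m, d)  # incoming gradient (m x d)

u = F.softmax(w * beta, dim=1)
uu = torch.bmm(u.unsqueeze(dim=2), u.unsqueeze(dim=1))  # u u^T, (m x d x d)

approx = torch.eye(d) - uu  # approximation used above: I - u u^T
true_jac = torch.diag_embed(u) - uu  # true Jacobian: diag(u) - u u^T

g_approx = beta * torch.bmm(g.unsqueeze(dim=1), approx).squeeze(dim=1)
g_true = beta * torch.bmm(g.unsqueeze(dim=1), true_jac).squeeze(dim=1)
print((g_approx - g_true).abs().max())  # elementwise gap between the two estimates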