def __init__(self,
                 DIM,
                 batchsize_para,
                 optimizee,
                 train_steps,
                 retain_graph_flag=False,
                 reset_theta=False,
                 reset_function_from_IID_distirbution=True):
        self.optimizee = optimizee
        self.beta = 1
        self.train_steps = train_steps
        self.retain_graph_flag = retain_graph_flag
        self.reset_theta = reset_theta
        self.reset_function_from_IID_distirbution = reset_function_from_IID_distirbution
        self.state = None

        self.DIM = DIM
        self.batchsize_para = batchsize_para
        self.retraction = Retraction(1)

        # Print the sum of every optimizee parameter tensor (a quick sanity check).
        for parameters in optimizee.parameters():
            print(torch.sum(parameters))

        # M: one DIM x DIM matrix per problem in the batch, initialised to the identity.
        self.M = torch.randn(self.batchsize_para, self.DIM, self.DIM)
        for i in range(self.batchsize_para):
            self.M[i] = torch.eye(self.DIM)

        self.M = self.M.cuda()
        self.M.requires_grad = True

        # Tangent-direction buffer with the same shape as M, initialised to zeros.
        self.P_tangent = torch.zeros(self.batchsize_para, self.DIM,
                                     self.DIM).cuda()
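
# Standalone sketch (added for illustration, not part of the snippet above): the loop
# that overwrites a torch.randn buffer with identity matrices is equivalent to the
# batched-eye idiom shown here. Sizes are hypothetical and everything stays on the CPU
# so the check runs anywhere.
import torch

batchsize_para, DIM = 4, 3
M = torch.eye(DIM).unsqueeze(0).repeat(batchsize_para, 1, 1)  # (4, 3, 3) stack of identities
assert all(torch.equal(M[i], torch.eye(DIM)) for i in range(batchsize_para))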
Example #2
    def __init__(self,
                 opt,
                 DIM,
                 outputDIM,
                 batchsize_para,
                 optimizee,
                 train_steps,
                 retain_graph_flag=False,
                 reset_theta=False,
                 reset_function_from_IID_distirbution=True):
        self.criterion = ContrastiveLoss().cuda()
        self.retraction = Retraction(1)
        self.optimizee = optimizee

        self.opt = opt

        self.beta = 1
        self.eiglayer1 = EigLayer()
        self.mexp = M_Exp()

        self.train_steps = train_steps
        #self.num_roll=num_roll
        self.retain_graph_flag = retain_graph_flag
        self.reset_theta = reset_theta
        self.reset_function_from_IID_distirbution = reset_function_from_IID_distirbution
        self.state = None

        self.DIM = DIM
        self.outputDIM = outputDIM
        self.batchsize_para = batchsize_para

        self.global_loss_graph = 0  # global loss used to optimize the LSTM optimizer
        self.losses = []  # keep the loss of every epoch

        # Print the sum of every optimizee parameter tensor (a quick sanity check).
        for parameters in optimizee.parameters():
            print(torch.sum(parameters))

        # M: one DIM x outputDIM matrix per problem in the batch; each slice is given
        # orthonormal columns in the loop below.
        self.M = torch.randn(self.batchsize_para, self.DIM,
                             self.outputDIM).cuda()
        for i in range(self.batchsize_para):
            '''
            U = torch.empty(self.DIM, self.DIM)
            nn.init.orthogonal_(U)
            D=torch.abs(torch.diag(torch.rand(self.DIM)+0.6))
            self.M[i]=(U.mm(D)).mm(U.t())
            #self.M[i]=self.M[i]+0.001*torch.trace(self.M[i])*torch.eye(self.DIM)
            '''
            nn.init.orthogonal_(self.M[i])

        self.M = self.M.cuda()
        self.M.requires_grad = True
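
# Standalone sketch (added for illustration, not part of the snippet above):
# nn.init.orthogonal_ fills a tall matrix with orthonormal columns, so M_i.T @ M_i is
# numerically the identity. Sizes are hypothetical and the check runs on the CPU.
import torch
import torch.nn as nn

DIM, outputDIM = 5, 3
M_i = torch.empty(DIM, outputDIM)
nn.init.orthogonal_(M_i)
assert torch.allclose(M_i.t() @ M_i, torch.eye(outputDIM), atol=1e-5)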
import torch
import torch.nn as nn
from torch.autograd import Variable
from timeit import default_timer as timer
import time
import math, random
import numpy as np

from losses.LOSS import ContrastiveLoss
from ReplyBuffer import ReplayBuffer
from retraction import Retraction

retraction = Retraction(1)
grad_list = []


def print_grad(grad):
    # Backward-hook callback: record every gradient that reaches the hooked tensor.
    grad_list.append(grad)
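
# Usage sketch (added for illustration, not in the original module): print_grad is meant
# to be registered as a backward hook, so every gradient flowing through the hooked
# tensor gets appended to grad_list.
#
#     x = torch.randn(3, requires_grad=True)   # x is a hypothetical tensor
#     x.register_hook(print_grad)
#     x.sum().backward()
#     # grad_list[-1] now holds dL/dx, i.e. a tensor of ones with x's shape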


def nll_loss(data, label):
    # Sum of |data[i, label[i]]| over the batch (loop form; a vectorised equivalent
    # is sketched below).
    n = data.shape[0]
    L = 0
    for i in range(n):
        L = L + torch.abs(data[i][label[i]])
    return L
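

# Vectorised equivalent (added for illustration, not in the original module): gathers
# data[i, label[i]] for every row with one advanced-indexing op instead of a Python loop.
def nll_loss_vectorized(data, label):
    return torch.abs(data[torch.arange(data.shape[0]), label]).sum()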


def f(inputs, target, M):
    loss_function = torch.nn.CrossEntropyLoss(reduction='sum')