def __init__(self, DIM, batchsize_para, optimizee, train_steps,
             retain_graph_flag=False, reset_theta=False,
             reset_function_from_IID_distribution=True):
    self.optimizee = optimizee
    self.beta = 1
    self.train_steps = train_steps
    self.retain_graph_flag = retain_graph_flag
    self.reset_theta = reset_theta
    self.reset_function_from_IID_distribution = reset_function_from_IID_distribution
    self.state = None
    self.DIM = DIM
    self.batchsize_para = batchsize_para
    self.retraction = Retraction(1)

    # Sanity check: print the parameter sums of the optimizee.
    for parameters in optimizee.parameters():
        print(torch.sum(parameters))

    # Initialize every matrix in the batch to the DIM x DIM identity.
    self.M = torch.randn(self.batchsize_para, self.DIM, self.DIM)
    for i in range(self.batchsize_para):
        self.M[i] = torch.eye(self.DIM)
    self.M = self.M.cuda()
    self.M.requires_grad = True

    # Momentum buffer living in the tangent space at M.
    self.P_tangent = torch.zeros(self.batchsize_para, self.DIM, self.DIM).cuda()
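    # Note (illustrative sketch, not part of the original flow): the loop above
    # just builds a batch of identity matrices, which the same shapes admit in
    # one vectorized call:
    #
    #   M = torch.eye(DIM).unsqueeze(0).repeat(batchsize_para, 1, 1).cuda()
    #   M.requires_grad = True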
def __init__(self, opt, DIM, outputDIM, batchsize_para, optimizee, train_steps,
             retain_graph_flag=False, reset_theta=False,
             reset_function_from_IID_distribution=True):
    self.criterion = ContrastiveLoss().cuda()
    self.retraction = Retraction(1)
    self.optimizee = optimizee
    self.opt = opt
    self.beta = 1
    self.eiglayer1 = EigLayer()
    self.mexp = M_Exp()
    self.train_steps = train_steps
    self.retain_graph_flag = retain_graph_flag
    self.reset_theta = reset_theta
    self.reset_function_from_IID_distribution = reset_function_from_IID_distribution
    self.state = None
    self.DIM = DIM
    self.outputDIM = outputDIM
    self.batchsize_para = batchsize_para
    self.global_loss_graph = 0  # accumulated global loss for optimizing the LSTM
    self.losses = []            # keeps each loss across all epochs

    # Sanity check: print the parameter sums of the optimizee.
    for parameters in optimizee.parameters():
        print(torch.sum(parameters))

    # Initialize each DIM x outputDIM matrix with orthonormal columns,
    # i.e. a point on the Stiefel manifold.
    self.M = torch.randn(self.batchsize_para, self.DIM, self.outputDIM)
    for i in range(self.batchsize_para):
        nn.init.orthogonal_(self.M[i])
    self.M = self.M.cuda()
    self.M.requires_grad = True
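    # Sketch (illustrative only, never executed here): nn.init.orthogonal_
    # gives each slice orthonormal columns, so M[i].t() @ M[i] is (close to)
    # the identity. For a toy 5 x 3 matrix:
    #
    #   m = torch.empty(5, 3)
    #   nn.init.orthogonal_(m)
    #   print(torch.allclose(m.t() @ m, torch.eye(3), atol=1e-5))  # True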
import math
import random
import time
from timeit import default_timer as timer

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

from losses.LOSS import ContrastiveLoss
from ReplyBuffer import ReplayBuffer
from retraction import Retraction

retraction = Retraction(1)

grad_list = []


def print_grad(grad):
    # Backward hook: record every gradient that flows through the hooked tensor.
    grad_list.append(grad)


def nll_loss(data, label):
    # Sum of |data[i, label[i]]| over the batch.
    n = data.shape[0]
    L = 0
    for i in range(n):
        L = L + torch.abs(data[i][label[i]])
    return L


def f(inputs, target, M):
    loss_function = torch.nn.CrossEntropyLoss(reduction='sum')
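    # Usage sketch for print_grad (hypothetical, shown only as comments and
    # assuming f returns the scalar loss): register the hook on a leaf tensor
    # such as M so its gradient is appended to grad_list during backward.
    #
    #   M.register_hook(print_grad)
    #   loss = f(inputs, target, M)
    #   loss.backward()
    #   # grad_list[-1] now holds dLoss/dM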