def update_global(self, b_i=0):
    """Snapshot the current style parameters (and optional bias) into the
    per-task global history, then advance the task counter.

    Parameters
    ----------
    b_i : Tensor or int, optional
        Bias snapshot to store when ``is_use_bias`` is set (default 0).
    """
    if is_use_bias:
        self.global_b.append(my_copy(b_i))
    if self.global_num_task < 0:
        return
    with torch.no_grad():
        # Store squeezed copies of gamma/beta side by side.
        for history, live in ((self.global_gamma, self.style_gama),
                              (self.global_beta, self.style_beta)):
            history.append(my_copy(live.squeeze()))
        self.global_num_task += 1
def Para2List(L_P):
    """Deep-copy a sequence of parameter tensors and accumulate their total.

    Parameters
    ----------
    L_P : sequence of Tensor
        Parameter tensors to copy.

    Returns
    -------
    tuple
        ``(copies, total)`` where ``copies`` is a list holding ``my_copy``
        of each input tensor and ``total`` is the detached sum of all
        elements across every tensor (0.0 when ``L_P`` is empty).
    """
    # Idiomatic rewrite: iterate directly instead of range(L_P.__len__()),
    # and drop the redundant `.data * 1.0` after `.detach()`.
    L_t = [my_copy(p) for p in L_P]
    s_sum = 0.0
    for p in L_P:
        s_sum += p.sum().detach()
    return L_t, s_sum
def forward(self, input, W, b, b_i=0, task_id=-1, UPDATE_GLOBAL=False):
    """Apply a task-conditioned convolution, or archive the current task's
    style parameters when ``UPDATE_GLOBAL`` is set.

    task_id >= 0 selects a stored (historical) task's gamma/beta;
    task_id == -1 uses the live, trainable gamma/beta of the current task.
    Returns the convolution output (None in the UPDATE_GLOBAL branch).
    """
    if UPDATE_GLOBAL:
        # Freeze the live style parameters into the per-task history.
        self.global_gamma.append(my_copy(self.style_gama))
        self.global_beta.append(my_copy(self.style_beta))
        if is_use_bias:
            self.global_b.append(my_copy(self.b))
        return None

    if task_id >= 0:
        # Replay a previously consolidated task.
        scale, shift = self.global_gamma[task_id], self.global_beta[task_id]
        if is_use_bias:
            b_i = self.global_b[task_id]
    elif task_id == -1:
        # Train with the current task's live parameters.
        scale, shift = self.style_gama, self.style_beta
        if is_use_bias:
            b_i = self.b
    adapted_W = W * scale + shift
    out = F_conv(input, adapted_W, bias=b + b_i, stride=1, padding=1)
    return out
def forward(self, input=0, W=0, b=0, b_i=0, task_id=-1, UPDATE_GLOBAL=False):
    """Task-conditioned linear layer: modulate the (transposed) weight with
    per-task gamma/beta, or archive the current task's parameters when
    ``UPDATE_GLOBAL`` is set.

    task_id >= 0 replays a stored task; task_id == -1 uses the live,
    trainable gamma/beta. Returns the linear output (None when archiving).
    """
    if UPDATE_GLOBAL:
        # Freeze the live modulation parameters into the history lists.
        self.global_gamma.append(my_copy(self.gamma))
        self.global_beta.append(my_copy(self.beta))
        if is_use_bias:
            self.global_b.append(my_copy(self.b))
        return None

    if task_id >= 0:
        # Replay a previously consolidated task.
        weight = W.t() * self.global_gamma[task_id] + self.global_beta[task_id]
        out = input.mm(weight) + b
        if is_use_bias:
            out += self.global_b[task_id]
    elif task_id == -1:
        # Train with the current task's live parameters.
        weight = W.t() * self.gamma + self.beta
        out = input.mm(weight) + b
        if is_use_bias:
            out = out + self.b
    return out
def update_global(self, in_channel, out_channel, task_id=0, FRAC0=0.5, FRAC1=0.9, n_ch=256, device=None):
    """Consolidate the current task's style parameters into the global
    stores via truncated SVD, then re-initialise the trainables.

    For each of ``style_gama`` and ``style_beta`` this method:
      1. picks an energy threshold FRAC from a per-task-count table,
      2. runs an SVD and keeps the fewest singular values whose cumulative
         magnitude (seeded with the magnitude already stored from earlier
         tasks) reaches FRAC,
      3. appends the kept U/S/V factors to the global lists — or appends
         the string sentinel 'none' for U/V and a zero count when the
         already-stored energy alone exceeds FRAC,
      4. rebuilds ``gamma_s1``/``beta_s1`` as fresh nn.Parameters loaded
         from the latest accumulated singular values, and zeroes the
         residual ``style_gama``/``style_beta`` matrices.

    Parameters
    ----------
    in_channel, out_channel : int
        Shape used when re-zeroing the residual style matrices.
    task_id, FRAC0, FRAC1, n_ch :
        Unused in this body — presumably kept for signature compatibility
        with sibling modules; TODO confirm.
    device : torch.device or None
        Device for the re-initialised parameters.
    """
    # FRAC0, FRAC1 = 0.5, 0.99
    # Threshold tables indexed implicitly by task count (all identical).
    p = [1.0, 1.0, 0.95, 0.9, 0.8]
    p2 = [1.0, 1.0, 0.95, 0.9, 0.8]
    p2 = [1.0, 1.0, 0.95, 0.9, 0.8]  # NOTE(review): p2 assigned twice; second assignment is dead (was one meant to be p3?)
    p4 = [1.0, 1.0, 0.95, 0.9, 0.8]
    if is_use_bias:
        self.global_b.append(my_copy(self.b))
    # NOTE(review): this guard is presumably always true once the counter
    # starts at 0 — confirm the initial value of global_num_task.
    if self.global_num_task >= 0:
        with torch.no_grad():
            # ---- gamma branch ----
            # Seed s_sum with the magnitude of singular values already kept
            # from previous tasks (zero for the very first task).
            if self.global_num_task == 0:
                s_sum = 0.0
                FRAC = chanel_percent(self.style_gama.shape[0], p=p)
            elif self.global_num_task >= 4:
                s_sum = self.gamma_s1.abs().sum()
                FRAC = chanel_percent(self.style_gama.shape[0], p=p4)
            else:
                s_sum = self.gamma_s1.abs().sum()
                FRAC = chanel_percent(self.style_gama.shape[0], p=p2)  # [0.95, 0.95, 0.9, 0.9, 0.9]
            # FRAC = chanel_percent(self.style_gama.shape[0])
            Ua, Sa, Va = self.style_gama.squeeze().svd()
            # ii: singular-value magnitudes sorted descending; jj: their indices.
            ii, jj = Sa.abs().sort(descending=True)
            ii[0] += s_sum  # fold previously kept energy into the cumulative total
            ii_acsum = ii.cumsum(dim=0)
            if s_sum / ii_acsum[-1] < FRAC:
                # Keep the smallest prefix whose cumulative fraction reaches FRAC.
                num = (~(ii_acsum / ii_acsum[-1] >= FRAC)).sum() + 1
                if self.global_num_task == 0:
                    s_all = my_copy(Sa[jj[:num]])
                else:
                    s_all = torch.cat((my_copy(self.gamma_s1), my_copy(Sa[jj[:num]])), 0)
                self.global_gamma_s.append(s_all)
                self.global_gamma_u.append(my_copy(Ua[:, jj[:num]]))
                self.global_gamma_v.append(my_copy(Va[:, jj[:num]]))
                self.global_num_gamma.append(num)
            else:
                # Stored energy alone already satisfies FRAC: keep nothing new.
                num = jj[0]-jj[0]  # zero tensor (same dtype/device as jj)
                self.global_num_gamma.append(num)
                self.global_gamma_s.append(my_copy(self.gamma_s1))
                self.global_gamma_u.append('none')
                self.global_gamma_v.append('none')
            # ---- beta branch (mirrors the gamma branch) ----
            if self.global_num_task == 0:
                s_sum = 0.0
                FRAC = chanel_percent(self.style_beta.shape[0], p=p)
            elif self.global_num_task >= 4:
                s_sum = self.beta_s1.abs().sum()
                # NOTE(review): uses style_gama.shape here (not style_beta) —
                # looks like a copy-paste slip; harmless only if both shapes match.
                FRAC = chanel_percent(self.style_gama.shape[0], p=p4)
            else:
                s_sum = self.beta_s1.abs().sum()
                FRAC = chanel_percent(self.style_gama.shape[0], p=p2)  # [0.95, 0.95, 0.9, 0.9, 0.9]; NOTE(review): same style_gama/style_beta shape question as above
            # FRAC = chanel_percent(self.style_beta.shape[0])
            Ua, Sa, Va = self.style_beta.squeeze().svd()
            ii, jj = Sa.abs().sort(descending=True)
            ii[0] += s_sum
            ii_acsum = ii.cumsum(dim=0)
            if s_sum / ii_acsum[-1] < FRAC:
                num = (~(ii_acsum / ii_acsum[-1] >= FRAC)).sum() + 1
                if self.global_num_task == 0:
                    s_all = my_copy(Sa[jj[:num]])
                else:
                    s_all = torch.cat((my_copy(self.beta_s1), my_copy(Sa[jj[:num]])), 0)
                # NOTE(review): the task counter is only advanced inside this
                # success branch — if the beta 'else' path is taken the counter
                # is never incremented for this task. Confirm this is intended.
                self.global_num_task = self.global_num_task + 1
                self.global_beta_s.append(s_all)
                self.global_beta_u.append(my_copy(Ua[:, jj[:num]]))
                self.global_beta_v.append(my_copy(Va[:, jj[:num]]))
                self.global_num_beta.append(num)
            else:
                num = jj[0]-jj[0]
                self.global_num_beta.append(num)
                self.global_beta_s.append(my_copy(self.beta_s1))
                self.global_beta_u.append('none')
                self.global_beta_v.append('none')
            # ---- re-initialise trainable parameters for the next task ----
            # self.gamma_s1 = nn.Parameter(self.global_gamma_s[0][0])
            # for ii in range(1, self.global_num_task):
            #     self.gamma_s1.append(nn.Parameter(stdd3 * torch.rand(self.global_num_gamma[ii], device=device)))
            # Fresh parameter sized to hold all kept gamma singular values,
            # then overwritten with the latest accumulated values.
            self.gamma_s1 = nn.Parameter(stdd4*torch.randn(sum(self.global_num_gamma), device=device))
            # self.gamma_s1.data[:self.global_num_gamma[0]] = self.global_gamma_s[0].data
            self.gamma_s1.data[:sum(self.global_num_gamma)] = self.global_gamma_s[-1].data
            # self.beta_s1 = nn.Parameter(self.global_beta_s[0][0])
            # for ii in range(1, self.global_num_task):
            #     self.beta_s1.append(nn.Parameter(stdd3 * torch.rand(self.global_num_beta[ii], device=device)))
            self.beta_s1 = nn.Parameter(stdd4*torch.randn(sum(self.global_num_beta), device=device))
            # self.beta_s1.data[:self.global_num_beta[0]] = self.global_beta_s[0].data
            self.beta_s1.data[:sum(self.global_num_beta)] = self.global_beta_s[-1].data
            # self.gamma = nn.Parameter(1/(in_channel**0.5) * torch.randn(in_channel, out_channel))
            # self.beta = nn.Parameter(stdd2 + stdd2 * torch.randn(in_channel, out_channel))
            # Zero the residual style matrices so the next task learns a delta.
            self.style_gama.data = torch.zeros(in_channel, out_channel).to(device)
            self.style_beta.data = torch.zeros(in_channel, out_channel).to(device)
            if is_use_bias:
                # self.b = nn.Parameter(torch.zeros(in_channel))
                # NOTE(review): restores the FIRST task's bias ([0]), not the
                # latest ([-1]) — confirm this asymmetry is intentional.
                self.b.data = my_copy(torch.tensor(self.global_b[0]).clone()).to(device)