def joint_training_2op(self, src_tasks, tgt_tasks, model_path):
    """Jointly train the encoder and domain classifier with two optimizers.

    The encoder starts on SGD and is switched to Adam once the average
    source classification loss drops below ``optional_lr`` (a loss
    threshold, despite the name).  Each episode combines the few-shot
    source classification loss with an adversarial domain loss weighted
    by ``adda_params['alpha']``.  At checkpoint epochs the user is asked
    interactively whether to stop or save.

    :param src_tasks: source-domain tasks used for episodic training
    :param tgt_tasks: target-domain tasks; a prefix (matched to the source
        set size) is used for domain alignment, the full set for validation
    :param model_path: directory where checkpoints are saved via ``self.save``
    """
    self.model.train(), self.d_classifier.train()
    self.model.apply(WEIGHTS_INIT), self.d_classifier.apply(WEIGHTS_INIT)
    optimizer1 = torch.optim.SGD(self.model.parameters(), lr=0.1,
                                 momentum=0.9)  # initial lr empirically set to 0.1-0.2
    optimizer2 = torch.optim.Adam(self.model.parameters(), lr=1e-3)  # SGD is better
    c_optimizer = optimizer1  # for encoder
    d_optimizer = torch.optim.RMSprop(self.d_classifier.parameters(), lr=1e-3,
                                      alpha=0.99)  # works better across domains
    # d_optimizer = torch.optim.Adam(self.d_classifier.parameters(), lr=1e-3, weight_decay=1e-5)
    c_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        c_optimizer, gamma=0.99)  # lr = lr * gamma^epoch
    d_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        d_optimizer, gamma=0.99)  # lr = lr * gamma^epoch
    # =======
    # optional_lr = 0.01  # empirical: 0.001~0.05: 0.02 [to SA/SQ]
    optional_lr = 0.01  # empirical: 0.1~0.2 [CW]; loss threshold for the optimizer switch
    # =======
    # Align the target training split with the source set size; the full
    # target set is kept for per-episode validation.
    tar_tr = tgt_tasks[:, :src_tasks.shape[1]]
    print('source set for training:', src_tasks.shape)
    print('target set for training', tar_tr.shape)
    print('target set for validation', tgt_tasks.shape)
    print('(n_s, n_q)==> ', (self.ns, self.nq))
    epochs = running_params['train_epochs']
    episodes = running_params['train_episodes']
    counter = 0
    draw = False  # t-SNE
    opt_flag = False  # set once the SGD -> Adam switch has happened
    avg_ls = torch.zeros([episodes])
    times = np.zeros([epochs])
    print(
        f'Start to train! {epochs} epochs, {episodes} episodes, {episodes * epochs} steps.\n'
    )
    for ep in range(epochs):
        # if (ep + 1) <= 3 and CHECK_D:
        #     draw = True
        # elif 25 <= (ep + 1) <= 40 and CHECK_D:
        #     draw = True
        draw = True if 30 <= (ep + 1) <= 40 and CHECK_D else False
        # Prompt for stop/save every 10 epochs early on, every 5 later.
        delta = 10 if (ep + 1) <= 30 else 5
        t0 = time.time()
        for epi in range(episodes):
            support, query = sample_task_tr(src_tasks, self.way, self.ns, length=DIM)
            tgt_s, _ = sample_task_tr(tar_tr, self.way, self.ns, length=DIM)
            tgt_v_s, tgt_v_q = sample_task_tr(tgt_tasks, self.way, self.ns, length=DIM)
            src_loss, src_acc, _ = self.model.forward(xs=support, xq=query,
                                                      sne_state=False)
            # GRL-style constant controlling the adversarial signal strength.
            constant = self.get_constant(episodes, ep, epi)
            domain_loss, domain_acc = self.domain_loss(support, tgt_s, constant,
                                                       draw=draw)
            # draw = False  # t-SNE
            d_loss = domain_loss[0] + domain_loss[1]
            loss = src_loss + adda_params['alpha'] * d_loss
            c_optimizer.zero_grad()
            d_optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                parameters=self.model.parameters(), max_norm=0.5)
            # nn.utils.clip_grad_norm_(parameters=self.d_classifier.parameters(), max_norm=0.5)
            # To clip the grads of d_classifier is Not recommended.
            c_optimizer.step()
            d_optimizer.step()
            # Per-episode validation on the full target set (no grads, eval mode).
            self.model.eval()
            with torch.no_grad():
                tgt_loss, tgt_acc, _ = self.model.forward(xs=tgt_v_s, xq=tgt_v_q,
                                                          sne_state=False)
            self.model.train()
            src_ls, src_ac = src_loss.cpu().item(), src_acc.cpu().item()
            tgt_ls, tgt_ac = tgt_loss.cpu().item(), tgt_acc.cpu().item()
            avg_ls[epi] = src_ls
            if (epi + 1) % 5 == 0:
                self.visualization.plot([src_ls, tgt_ls],
                                        ['Source_cls', 'Target_cls'],
                                        counter=counter,
                                        scenario="DASMN_Cls Loss")
                self.visualization.plot([src_ac, tgt_ac], ['C_src', 'C_tgt'],
                                        counter=counter,
                                        scenario="DASMN_Cls_Acc")
                self.visualization.plot([
                    domain_acc[0].cpu().item(), domain_acc[1].cpu().item()
                ],
                                        label=['D_src', 'D_tgt'],
                                        counter=counter,
                                        scenario="DASMN_D_Acc")
                self.visualization.plot([
                    domain_loss[0].cpu().item(), domain_loss[1].cpu().item()
                ],
                                        label=['Source_d', 'Target_d'],
                                        counter=counter,
                                        scenario="DASMN_D_Loss")
                self.visualization.plot([loss.cpu().item()],
                                        label=['All_Loss'],
                                        counter=counter,
                                        scenario="DASMN_All_Loss")
            counter += 1
            # if (epi + 1) % 10 == 0:
            #     print('[epoch {}/{}, episode {}/{}] => loss: {:.8f}, acc: {:.8f}'.format(
            #         ep + 1, epochs, epi + 1, episodes, src_ls, src_ac))
        # epoch
        t1 = time.time()
        times[ep] = t1 - t0
        print('[epoch {}/{}] time: {:.5f} Total: {:.5f}'.format(
            ep + 1, epochs, times[ep], np.sum(times)))
        ls_ = torch.mean(avg_ls).cpu()  # .item()
        print('[epoch {}/{}] avg_loss: {:.8f}\n'.format(ep + 1, epochs, ls_))
        # NOTE(review): both schedulers only step while the encoder is still on
        # SGD; after the switch to Adam the discriminator lr also stops
        # decaying — confirm this is intended.
        if isinstance(c_optimizer, torch.optim.SGD):
            c_scheduler.step()  # ep // 5
            d_scheduler.step()  # ep // 5
        if ls_ < optional_lr and opt_flag is False:
            # if (ep + 1) >= 20 and opt_flag is False:
            c_optimizer = optimizer2
            # c_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
            print('====== Optimizer Switch ======\n')
            opt_flag = True
        # Interactive checkpointing: blocks on stdin at checkpoint epochs.
        if ep + 1 >= CHECK_EPOCH and (ep + 1) % delta == 0:
            flag = input("Shall we stop the training? Y/N\n")
            if flag == 'y' or flag == 'Y':
                print('Training stops!(manually)')
                new_path = os.path.join(model_path, f"final_epoch{ep + 1}")
                self.save(new_path, running_params['train_epochs'])
                break
            else:
                flag = input(f"Save model at epoch {ep + 1}? Y/N\n")
                if flag == 'y' or flag == 'Y':
                    child_path = os.path.join(model_path, f"epoch{ep + 1}")
                    self.save(child_path, ep + 1)
    # self.visualization.plot(data=[1000 * optimizer.param_groups[0]['lr']],
    #                         label=['LR(*0.001)'], counter=ep,
    #                         scenario="SSMN_Dynamic params")
    print("The total time: {:.5f} s\n".format(np.sum(times)))
def test(self, tar_tasks, src_tasks, save_fig_path, mask=False, model_eval=True):
    """Evaluate the model on target-domain episodes; optionally save a t-SNE figure.

    :param tar_tasks: target tasks [way, n, dim] used for few-shot evaluation
    :param src_tasks: source tasks, embedded only for the t-SNE comparison plot
    :param save_fig_path: base path for the saved t-SNE figure (svg)
    :param mask: if True, the "target" embeddings passed to the t-SNE plot are
        replaced by source-query embeddings (visual ablation)
    :param model_eval: run the encoder in eval() mode; False keeps train() mode
    :return: None; prints per-episode/epoch metrics and posts them to visdom
    """
    if model_eval:
        self.model.eval()
    else:
        self.model.train()
    print('target set', tar_tasks.shape)
    print('(n_s, n_q)==> ', (self.ns, self.nq))
    epochs = running_params['test_epochs']
    episodes = running_params['test_episodes']
    # episodes = tar_tasks.shape[1] // self.ns
    # Fixed copy-paste log message: this is the evaluation loop, not training.
    print(
        f'Start to test! {epochs} epochs, {episodes} episodes, {episodes * epochs} steps.\n'
    )
    counter = 0
    avg_acc_all = 0.
    avg_loss_all = 0.
    print('Model.eval() is:', not self.model.training)
    for ep in range(epochs):
        avg_acc_ep = 0.
        avg_loss_ep = 0.
        sne_state = False
        for epi in range(episodes):
            # tar_s, tar_q = sample_task_te(tar_tasks, self.way, self.ns, length=DIM)
            # src_s, src_q = sample_task_te(src_tasks, self.way, self.ns, length=DIM)
            tar_s, tar_q = sample_task_tr(tar_tasks, self.way, self.ns, length=DIM)
            src_s, src_q = sample_task_tr(src_tasks, self.way, self.ns, length=DIM)
            # sne_state = True if epi + 1 == episodes else False
            with torch.no_grad():
                tar_loss, tar_acc, zq_t = self.model.forward(
                    xs=tar_s, xq=tar_q, sne_state=sne_state)
                # NOTE(review): xq=src_s (not src_q) — the source *support* set
                # is embedded for the t-SNE plot; confirm this is intended.
                _, _, zq_s = self.model.forward(xs=src_s, xq=src_s,
                                                sne_state=False)
                if mask:
                    _, _, zq_t = self.model.forward(xs=src_q, xq=src_q,
                                                    sne_state=False)
            tar_ls, tar_ac = tar_loss.cpu().item(), tar_acc.cpu().item()
            avg_acc_ep += tar_ac
            avg_loss_ep += tar_ls
            self.visualization.plot([tar_ac, tar_ls], ['Acc', 'Loss'],
                                    counter=counter,
                                    scenario="DASMN-Test")
            counter += 1
            # if (epi + 1) == episodes:
            print(
                f'[{ep + 1}/{epochs}, {epi + 1}/{episodes}]\ttar_loss: {tar_ls:.8f}\ttar_acc: {tar_ac:.8f}'
            )
            # Show the adaptation plot and (interactively) offer to save it.
            self.plot_adaptation(zq_s, zq_t)
            plt.show()
            order = input('Save fig? Y/N\n')
            if order == 'y' or order == 'Y':
                self.plot_adaptation(zq_s, zq_t)
                new_path = check_creat_new(save_fig_path)
                plt.savefig(new_path,
                            dpi=600,
                            format='svg',
                            bbox_inches='tight',
                            pad_inches=0.01)
                print('Save t-SNE.eps to \n', new_path)
        # epoch
        avg_acc_ep /= episodes
        avg_loss_ep /= episodes
        avg_acc_all += avg_acc_ep
        avg_loss_all += avg_loss_ep
        print(
            f'[epoch {ep + 1}/{epochs}] avg_loss: {avg_loss_ep:.8f}\tavg_acc: {avg_acc_ep:.8f}'
        )
    avg_acc_all /= epochs
    avg_loss_all /= epochs
    print(
        '\n------------------------Average Result----------------------------'
    )
    print('Average Test Loss: {:.6f}'.format(avg_loss_all))
    print('Average Test Accuracy: {:.6f}\n'.format(avg_acc_all))
    vis.text(text='Eval:{} Average Accuracy: {:.6f}'.format(
        not self.model.training, avg_acc_all),
             win='Eval:{} Test result'.format(not self.model.training))
def model_training(self, src_tasks, tgt_tasks, model_path):
    """Source-only episodic training of the encoder (no domain adversary).

    Trains on source tasks with SGD + exponential lr decay, validates each
    episode on target tasks, and at checkpoint epochs interactively asks
    whether to stop or save the model.

    :param src_tasks: source-domain tasks used for episodic training
    :param tgt_tasks: target-domain tasks used only for validation
    :param model_path: directory where checkpoints are written via self.save
    """
    self.model.train()
    self.model.apply(WEIGHTS_INIT)  # weight re-initialization (flagged as not recommended)
    sgd = torch.optim.SGD(self.model.parameters(), lr=0.1, momentum=0.9)
    c_optimizer = sgd  # encoder optimizer
    c_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        c_optimizer, gamma=0.95)  # lr = lr * gamma^epoch

    print('source set for training:', src_tasks.shape)
    print('target set for validation', tgt_tasks.shape)
    print('(n_s, n_q)==> ', (self.ns, self.nq))

    epochs = running_params['train_epochs']
    episodes = running_params['train_episodes']
    step = 0                              # global episode counter for plotting
    draw = False                          # kept for parity with the joint loop; never read
    episode_losses = torch.zeros([episodes])
    elapsed = np.zeros([epochs])
    print(f'Start to train! {epochs} epochs, {episodes} episodes, {episodes * epochs} steps.\n')

    for epoch in range(epochs):
        # prompt every 10 epochs early on, every 5 after epoch 30
        checkpoint_stride = 5 if (epoch + 1) > 30 else 10
        tic = time.time()

        for episode in range(episodes):
            support, query = sample_task_tr(src_tasks, self.way, self.ns, length=DIM)
            val_s, val_q = sample_task_tr(tgt_tasks, self.way, self.ns, length=DIM)

            src_loss, src_acc, _ = self.model.forward(xs=support, xq=query, sne_state=False)
            draw = False
            loss = src_loss
            c_optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters=self.model.parameters(), max_norm=0.5)
            c_optimizer.step()

            # target-domain validation pass: eval mode, no gradients
            self.model.eval()
            with torch.no_grad():
                tgt_loss, tgt_acc, _ = self.model.forward(xs=val_s, xq=val_q, sne_state=False)
            self.model.train()

            src_ls = src_loss.cpu().item()
            src_ac = src_acc.cpu().item()
            tgt_ls = tgt_loss.cpu().item()
            tgt_ac = tgt_acc.cpu().item()
            episode_losses[episode] = src_ls

            if (episode + 1) % 5 == 0:
                self.visualization.plot([src_ls, tgt_ls], ['Source_cls', 'Target_cls'],
                                        counter=step, scenario="proto_Cls Loss")
                self.visualization.plot([src_ac, tgt_ac], ['C_src', 'C_tgt'],
                                        counter=step, scenario="proto_Cls_Acc")
            step += 1

        # end of epoch: timing, mean loss, lr decay
        elapsed[epoch] = time.time() - tic
        print('[epoch {}/{}] time: {:.6f} Total: {:.6f}'.format(
            epoch + 1, epochs, elapsed[epoch], np.sum(elapsed)))
        mean_loss = torch.mean(episode_losses).cpu()
        print('[epoch {}/{}] avg_loss: {:.6f}\n'.format(epoch + 1, epochs, mean_loss))
        c_scheduler.step()

        # interactive checkpointing (blocks on stdin)
        if epoch + 1 >= CHECK_EPOCH and (epoch + 1) % checkpoint_stride == 0:
            answer = input("Shall we stop the training? Y/N\n")
            if answer in ('y', 'Y'):
                print('Training stops!(manually)')
                final_path = os.path.join(model_path, f"final_epoch{epoch + 1}")
                self.save(final_path, running_params['train_epochs'])
                break
            answer = input(f"Save model at epoch {epoch + 1}? Y/N\n")
            if answer in ('y', 'Y'):
                self.save(os.path.join(model_path, f"epoch{epoch + 1}"), epoch + 1)

    print("The total time: {:.5f} s\n".format(np.sum(elapsed)))