Example #1
0
    def joint_training_2op(self, src_tasks, tgt_tasks, model_path):
        """Jointly train the encoder (self.model) and the domain classifier
        (self.d_classifier) with a two-optimizer schedule for the encoder.

        The encoder starts on SGD+momentum and is switched to Adam once the
        epoch-average source classification loss drops below ``optional_lr``
        (despite the name, this is a loss threshold, not a learning rate).
        The domain classifier uses RMSprop throughout. Training is
        interactive: past CHECK_EPOCH the user is prompted on stdin every
        ``delta`` epochs to stop and/or checkpoint.

        :param src_tasks: source-domain task set; indexed as
            [way, n, dim] by sample_task_tr — TODO confirm layout.
        :param tgt_tasks: target-domain task set; the first
            src_tasks.shape[1] samples per class feed the adversarial
            (domain) loss, the full set is used for validation.
        :param model_path: directory where checkpoints are written.
        :return: None (metrics are printed and plotted via self.visualization).
        """
        self.model.train(), self.d_classifier.train()
        self.model.apply(WEIGHTS_INIT), self.d_classifier.apply(WEIGHTS_INIT)

        optimizer1 = torch.optim.SGD(self.model.parameters(),
                                     lr=0.1,
                                     momentum=0.9)  # initial lr: 0.1-0.2 works well
        optimizer2 = torch.optim.Adam(self.model.parameters(),
                                      lr=1e-3)  # SGD is better
        c_optimizer = optimizer1  # for encoder
        d_optimizer = torch.optim.RMSprop(self.d_classifier.parameters(),
                                          lr=1e-3,
                                          alpha=0.99)  # better for cross-domain transfer
        # d_optimizer = torch.optim.Adam(self.d_classifier.parameters(), lr=1e-3, weight_decay=1e-5)
        c_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            c_optimizer, gamma=0.99)  # lr = lr * gamma^epoch
        d_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            d_optimizer, gamma=0.99)  # lr = lr * gamma^epoch
        # =======
        # optional_lr = 0.01  # empirical value: 0.001~0.05, typically 0.02 [for SA/SQ]
        optional_lr = 0.01  # empirical value: 0.1~0.2 [CW]; loss threshold for the optimizer switch
        # =======

        # Adversarial-training slice of the target set, matched in size to the source set.
        tar_tr = tgt_tasks[:, :src_tasks.shape[1]]
        print('source set for training:', src_tasks.shape)
        print('target set for training', tar_tr.shape)
        print('target set for validation', tgt_tasks.shape)
        print('(n_s, n_q)==> ', (self.ns, self.nq))

        epochs = running_params['train_epochs']
        episodes = running_params['train_episodes']
        counter = 0
        draw = False  # t-SNE
        opt_flag = False  # becomes True once the encoder optimizer has switched to Adam
        avg_ls = torch.zeros([episodes])  # per-episode source losses for the epoch average
        times = np.zeros([epochs])  # wall-clock seconds per epoch

        print(
            f'Start to train! {epochs} epochs, {episodes} episodes, {episodes * epochs} steps.\n'
        )
        for ep in range(epochs):
            # if (ep + 1) <= 3 and CHECK_D:
            #     draw = True
            # elif 25 <= (ep + 1) <= 40 and CHECK_D:
            #     draw = True
            # t-SNE drawing is only enabled for epochs 30-40 (and only if CHECK_D).
            draw = True if 30 <= (ep + 1) <= 40 and CHECK_D else False

            # Prompt interval: every 10 epochs early on, every 5 after epoch 30.
            delta = 10 if (ep + 1) <= 30 else 5
            t0 = time.time()
            for epi in range(episodes):
                support, query = sample_task_tr(src_tasks,
                                                self.way,
                                                self.ns,
                                                length=DIM)
                tgt_s, _ = sample_task_tr(tar_tr,
                                          self.way,
                                          self.ns,
                                          length=DIM)
                tgt_v_s, tgt_v_q = sample_task_tr(tgt_tasks,
                                                  self.way,
                                                  self.ns,
                                                  length=DIM)

                src_loss, src_acc, _ = self.model.forward(xs=support,
                                                          xq=query,
                                                          sne_state=False)
                # presumably the adversarial/GRL scaling constant, scheduled
                # by training progress — confirm in self.get_constant.
                constant = self.get_constant(episodes, ep, epi)
                domain_loss, domain_acc = self.domain_loss(support,
                                                           tgt_s,
                                                           constant,
                                                           draw=draw)
                # draw = False  # t-SNE

                # Total loss: source classification + weighted (source + target) domain losses.
                d_loss = domain_loss[0] + domain_loss[1]
                loss = src_loss + adda_params['alpha'] * d_loss

                c_optimizer.zero_grad()
                d_optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    parameters=self.model.parameters(), max_norm=0.5)
                # nn.utils.clip_grad_norm_(parameters=self.d_classifier.parameters(), max_norm=0.5)
                # To clip the grads of d_classifier is Not recommended.
                c_optimizer.step()
                d_optimizer.step()

                # Per-episode validation on full target tasks (no grad, eval mode).
                self.model.eval()
                with torch.no_grad():
                    tgt_loss, tgt_acc, _ = self.model.forward(xs=tgt_v_s,
                                                              xq=tgt_v_q,
                                                              sne_state=False)
                self.model.train()

                src_ls, src_ac = src_loss.cpu().item(), src_acc.cpu().item()
                tgt_ls, tgt_ac = tgt_loss.cpu().item(), tgt_acc.cpu().item()
                avg_ls[epi] = src_ls

                # Push metrics to the plotting backend every 5 episodes.
                if (epi + 1) % 5 == 0:
                    self.visualization.plot([src_ls, tgt_ls],
                                            ['Source_cls', 'Target_cls'],
                                            counter=counter,
                                            scenario="DASMN_Cls Loss")
                    self.visualization.plot([src_ac, tgt_ac],
                                            ['C_src', 'C_tgt'],
                                            counter=counter,
                                            scenario="DASMN_Cls_Acc")
                    self.visualization.plot([
                        domain_acc[0].cpu().item(), domain_acc[1].cpu().item()
                    ],
                                            label=['D_src', 'D_tgt'],
                                            counter=counter,
                                            scenario="DASMN_D_Acc")
                    self.visualization.plot([
                        domain_loss[0].cpu().item(),
                        domain_loss[1].cpu().item()
                    ],
                                            label=['Source_d', 'Target_d'],
                                            counter=counter,
                                            scenario="DASMN_D_Loss")
                    self.visualization.plot([loss.cpu().item()],
                                            label=['All_Loss'],
                                            counter=counter,
                                            scenario="DASMN_All_Loss")
                    counter += 1

                # if (epi + 1) % 10 == 0:
                #     print('[epoch {}/{}, episode {}/{}] => loss: {:.8f}, acc: {:.8f}'.format(
                #         ep + 1, epochs, epi + 1, episodes, src_ls, src_ac))
            # epoch
            t1 = time.time()
            times[ep] = t1 - t0
            print('[epoch {}/{}] time: {:.5f} Total: {:.5f}'.format(
                ep + 1, epochs, times[ep], np.sum(times)))
            ls_ = torch.mean(avg_ls).cpu()  # .item()
            print('[epoch {}/{}] avg_loss: {:.8f}\n'.format(
                ep + 1, epochs, ls_))
            # Decay the encoder LR only while on SGD; after the switch to
            # Adam, c_scheduler is deliberately left untouched.
            if isinstance(c_optimizer, torch.optim.SGD):
                c_scheduler.step()  # ep // 5
            d_scheduler.step()  # ep // 5
            # One-time switch from SGD to Adam once the source loss is low enough.
            if ls_ < optional_lr and opt_flag is False:
                # if (ep + 1) >= 20 and opt_flag is False:
                c_optimizer = optimizer2
                #     # c_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
                print('====== Optimizer Switch ======\n')
                opt_flag = True

            # Interactive stop/checkpoint prompt (blocks on stdin).
            if ep + 1 >= CHECK_EPOCH and (ep + 1) % delta == 0:
                flag = input("Shall we stop the training? Y/N\n")
                if flag == 'y' or flag == 'Y':
                    print('Training stops!(manually)')
                    new_path = os.path.join(model_path, f"final_epoch{ep + 1}")
                    self.save(new_path, running_params['train_epochs'])
                    break
                else:
                    flag = input(f"Save model at epoch {ep + 1}? Y/N\n")
                    if flag == 'y' or flag == 'Y':
                        child_path = os.path.join(model_path, f"epoch{ep + 1}")
                        self.save(child_path, ep + 1)

            # self.visualization.plot(data=[1000 * optimizer.param_groups[0]['lr']],
            #                         label=['LR(*0.001)'], counter=ep,
            #                         scenario="SSMN_Dynamic params")
        print("The total time: {:.5f} s\n".format(np.sum(times)))
Example #2
0
    def test(self,
             tar_tasks,
             src_tasks,
             save_fig_path,
             mask=False,
             model_eval=True):
        """Evaluate self.model on target tasks and interactively save
        adaptation (t-SNE-style) figures.

        Averages loss/accuracy over test_epochs x test_episodes episodes,
        plots per-episode metrics, and — after every episode — shows the
        source/target embedding plot and asks on stdin whether to save it.

        :param tar_tasks: target tasks [way, n, dim]
        :param src_tasks: source tasks, forwarded only to obtain zq_s for
            the adaptation plot (for t-sne)
        :param save_fig_path: base path for the saved figure (a fresh
            non-clobbering path is derived via check_creat_new)
        :param mask: if True, zq_t is replaced by an embedding of the
            source query set — presumably to mask the target cloud in the
            figure; TODO confirm intent
        :param model_eval: run the model in eval() mode when True,
            otherwise keep train() mode (e.g. to keep dropout/BN active)
        :return: None (results are printed and reported to visdom)
        """
        if model_eval:
            self.model.eval()
        else:
            self.model.train()
        print('target set', tar_tasks.shape)
        print('(n_s, n_q)==> ', (self.ns, self.nq))

        epochs = running_params['test_epochs']
        episodes = running_params['test_episodes']
        # episodes = tar_tasks.shape[1] // self.ns
        print(
            f'Start to train! {epochs} epochs, {episodes} episodes, {episodes * epochs} steps.\n'
        )
        counter = 0
        avg_acc_all = 0.
        avg_loss_all = 0.

        print('Model.eval() is:', not self.model.training)
        for ep in range(epochs):
            avg_acc_ep = 0.
            avg_loss_ep = 0.
            sne_state = False  # stays False; the commented line below shows the prior behavior
            for epi in range(episodes):
                # tar_s, tar_q = sample_task_te(tar_tasks, self.way, self.ns, length=DIM)
                # src_s, src_q = sample_task_te(src_tasks, self.way, self.ns, length=DIM)
                tar_s, tar_q = sample_task_tr(tar_tasks,
                                              self.way,
                                              self.ns,
                                              length=DIM)
                src_s, src_q = sample_task_tr(src_tasks,
                                              self.way,
                                              self.ns,
                                              length=DIM)

                # sne_state = True if epi + 1 == episodes else False
                with torch.no_grad():
                    tar_loss, tar_acc, zq_t = self.model.forward(
                        xs=tar_s, xq=tar_q, sne_state=sne_state)
                    # NOTE(review): the support set is passed as both xs and
                    # xq here (xq=src_s, not src_q); src_q is unused unless
                    # mask=True. Looks intentional since only the embedding
                    # zq_s is consumed — but confirm.
                    _, _, zq_s = self.model.forward(xs=src_s,
                                                    xq=src_s,
                                                    sne_state=False)
                    if mask:
                        _, _, zq_t = self.model.forward(xs=src_q,
                                                        xq=src_q,
                                                        sne_state=False)

                tar_ls, tar_ac = tar_loss.cpu().item(), tar_acc.cpu().item()
                avg_acc_ep += tar_ac
                avg_loss_ep += tar_ls

                self.visualization.plot([tar_ac, tar_ls], ['Acc', 'Loss'],
                                        counter=counter,
                                        scenario="DASMN-Test")
                counter += 1

                # if (epi + 1) == episodes:
                print(
                    f'[{ep + 1}/{epochs}, {epi + 1}/{episodes}]\ttar_loss: {tar_ls:.8f}\ttar_acc: {tar_ac:.8f}'
                )
                # Show the adaptation plot, then (on request) re-draw and save it.
                self.plot_adaptation(zq_s, zq_t)
                plt.show()
                order = input('Save fig? Y/N\n')
                if order == 'y' or order == 'Y':
                    self.plot_adaptation(zq_s, zq_t)
                    new_path = check_creat_new(save_fig_path)
                    plt.savefig(new_path,
                                dpi=600,
                                format='svg',
                                bbox_inches='tight',
                                pad_inches=0.01)
                    print('Save t-SNE.eps to \n', new_path)

            # epoch
            avg_acc_ep /= episodes
            avg_loss_ep /= episodes
            avg_acc_all += avg_acc_ep
            avg_loss_all += avg_loss_ep
            print(
                f'[epoch {ep + 1}/{epochs}] avg_loss: {avg_loss_ep:.8f}\tavg_acc: {avg_acc_ep:.8f}'
            )
        avg_acc_all /= epochs
        avg_loss_all /= epochs
        print(
            '\n------------------------Average Result----------------------------'
        )
        print('Average Test Loss: {:.6f}'.format(avg_loss_all))
        print('Average Test Accuracy: {:.6f}\n'.format(avg_acc_all))
        vis.text(text='Eval:{} Average Accuracy: {:.6f}'.format(
            not self.model.training, avg_acc_all),
                 win='Eval:{} Test result'.format(not self.model.training))
Example #3
0
    def model_training(self, src_tasks, tgt_tasks, model_path):
        """Train the encoder (self.model) on source tasks only — no domain
        adaptation — while validating each episode on target tasks.

        Uses SGD with momentum and per-epoch exponential LR decay. Past
        CHECK_EPOCH the user is prompted on stdin every ``delta`` epochs to
        stop and/or checkpoint.

        :param src_tasks: source-domain task set; indexed as
            [way, n, dim] by sample_task_tr — TODO confirm layout.
        :param tgt_tasks: target-domain task set, used only for per-episode
            validation.
        :param model_path: directory where checkpoints are written.
        :return: None (metrics are printed and plotted via self.visualization).
        """
        self.model.train()
        self.model.apply(WEIGHTS_INIT)  # not recommended weights initialization

        optimizer1 = torch.optim.SGD(self.model.parameters(), lr=0.1, momentum=0.9)  # initial lr: 0.1-0.2 works well
        c_optimizer = optimizer1  # for encoder
        c_scheduler = torch.optim.lr_scheduler.ExponentialLR(c_optimizer, gamma=0.95)  # lr = lr * gamma^epoch

        print('source set for training:', src_tasks.shape)
        print('target set for validation', tgt_tasks.shape)
        print('(n_s, n_q)==> ', (self.ns, self.nq))

        epochs = running_params['train_epochs']
        episodes = running_params['train_episodes']
        counter = 0
        avg_ls = torch.zeros([episodes])  # per-episode source losses for the epoch average
        times = np.zeros([epochs])  # wall-clock seconds per epoch

        print(f'Start to train! {epochs} epochs, {episodes} episodes, {episodes * epochs} steps.\n')
        for ep in range(epochs):
            # Prompt interval: every 10 epochs early on, every 5 after epoch 30.
            delta = 10 if (ep + 1) <= 30 else 5
            t0 = time.time()
            for epi in range(episodes):
                support, query = sample_task_tr(src_tasks, self.way, self.ns, length=DIM)
                tgt_v_s, tgt_v_q = sample_task_tr(tgt_tasks, self.way, self.ns, length=DIM)

                src_loss, src_acc, _ = self.model.forward(xs=support, xq=query, sne_state=False)
                loss = src_loss

                c_optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(parameters=self.model.parameters(), max_norm=0.5)
                c_optimizer.step()

                # Per-episode validation on target tasks (no grad, eval mode).
                self.model.eval()
                with torch.no_grad():
                    tgt_loss, tgt_acc, _ = self.model.forward(xs=tgt_v_s, xq=tgt_v_q, sne_state=False)
                self.model.train()

                src_ls, src_ac = src_loss.cpu().item(), src_acc.cpu().item()
                tgt_ls, tgt_ac = tgt_loss.cpu().item(), tgt_acc.cpu().item()
                avg_ls[epi] = src_ls

                # Push metrics to the plotting backend every 5 episodes.
                if (epi + 1) % 5 == 0:
                    self.visualization.plot([src_ls, tgt_ls], ['Source_cls', 'Target_cls'],
                                            counter=counter, scenario="proto_Cls Loss")
                    self.visualization.plot([src_ac, tgt_ac], ['C_src', 'C_tgt'],
                                            counter=counter, scenario="proto_Cls_Acc")
                    counter += 1
            # epoch
            t1 = time.time()
            times[ep] = t1 - t0
            print('[epoch {}/{}] time: {:.6f} Total: {:.6f}'.format(ep + 1, epochs, times[ep], np.sum(times)))
            ls_ = torch.mean(avg_ls).cpu()  # .item()
            print('[epoch {}/{}] avg_loss: {:.6f}\n'.format(ep + 1, epochs, ls_))

            c_scheduler.step()

            # Interactive stop/checkpoint prompt (blocks on stdin).
            if ep + 1 >= CHECK_EPOCH and (ep + 1) % delta == 0:
                flag = input("Shall we stop the training? Y/N\n")
                if flag == 'y' or flag == 'Y':
                    print('Training stops!(manually)')
                    new_path = os.path.join(model_path, f"final_epoch{ep+1}")
                    self.save(new_path, running_params['train_epochs'])
                    break
                else:
                    flag = input(f"Save model at epoch {ep+1}? Y/N\n")
                    if flag == 'y' or flag == 'Y':
                        child_path = os.path.join(model_path, f"epoch{ep+1}")
                        self.save(child_path, ep+1)

        print("The total time: {:.5f} s\n".format(np.sum(times)))