Ejemplo n.º 1
0
    def keywords(self):
        '''
        Keywords SelfAttention.

        Runs the test split through the models in inference mode, collects
        keyword records plus predicted/gold labels, and dumps everything to
        ``../nats_results/<keywords_output_dir>/<file_test>_<best_epoch>.pickled``.
        '''
        self.build_vocabulary()
        self.build_models()
        print(self.base_models)
        print(self.train_models)
        if len(self.base_models) > 0:
            self.init_base_model_params()
        if len(self.train_models) > 0:
            self.init_train_model_params()

        self.test_data = create_batch_memory(path_=self.args.data_dir,
                                             file_=self.args.file_test,
                                             is_shuffle=False,
                                             batch_size=self.args.batch_size,
                                             is_lower=self.args.is_lower)

        output_dir = '../nats_results/' + self.args.keywords_output_dir
        # makedirs(exist_ok=True) avoids the exists-check/mkdir race of the
        # original two-step version.
        os.makedirs(output_dir, exist_ok=True)

        # Switch every sub-model to inference mode (disables dropout etc.).
        for model_name in self.base_models:
            self.base_models[model_name].eval()
        for model_name in self.train_models:
            self.train_models[model_name].eval()

        with torch.no_grad():
            print('Begin Testing: {}'.format(self.args.file_test))
            test_batch = len(self.test_data)
            print('The number of batches (testing): {}'.format(test_batch))
            pred_data = []
            true_data = []
            keywords_data = []
            if self.args.debug:
                # Only run a few batches in debug mode.
                test_batch = 3
            for test_id in range(test_batch):
                self.build_batch(self.test_data[test_id])
                ratePred, rateTrue = self.test_worker()
                output = self.keywords_worker()
                keywords_data += output

                pred_data += ratePred
                true_data += rateTrue

                show_progress(test_id + 1, test_batch)
            print()

            # Attach the predicted / gold labels to each keywords record.
            for k in range(len(keywords_data)):
                keywords_data[k]['pred_label'] = pred_data[k]
                keywords_data[k]['gold_label'] = true_data[k]

            # Bug fix: the original opened the pickle file without ``with``,
            # leaking the handle if pickle.dump raised.
            outfile = os.path.join(
                output_dir, '{}_{}.pickled'.format(self.args.file_test,
                                                   self.args.best_epoch))
            with open(outfile, 'wb') as fout:
                pickle.dump(keywords_data, fout)
Ejemplo n.º 2
0
    def app(self):
        '''
        Visualization / app evaluation across checkpoints.

        For every epoch, loads that epoch's saved parameters, predicts over
        the app split, and writes ``app_pred_<epoch>.txt`` and
        ``app_true_<epoch>.txt`` into ``../nats_results``.
        '''
        self.build_vocabulary()
        self.build_models()
        print(self.base_models)
        print(self.train_models)
        if len(self.base_models) > 0:
            self.init_base_model_params()

        self.app_data = create_batch_memory(path_=self.args.data_dir,
                                            file_=self.args.file_app,
                                            is_shuffle=False,
                                            batch_size=self.args.batch_size,
                                            is_lower=self.args.is_lower)

        # Inference mode for every sub-model.
        for name in self.base_models:
            self.base_models[name].eval()
        for name in self.train_models:
            self.train_models[name].eval()

        with torch.no_grad():
            for epoch in range(self.args.n_epoch):
                # Load the checkpoint written after ``epoch + 1`` passes.
                self.init_app_model_params(epoch + 1)

                print('Begin Testing')
                n_batch = len(self.app_data)
                print('The number of batches (App): {}'.format(n_batch))
                self.pred_data = []
                self.true_data = []
                for batch_id in range(n_batch):
                    self.build_batch(self.app_data[batch_id])
                    ratePred, rateTrue = self.test_worker()
                    self.pred_data += ratePred
                    self.true_data += rateTrue
                    show_progress(batch_id + 1, n_batch)
                print()

                # Labels are written as integer text files, one per epoch.
                self.pred_data = np.array(self.pred_data).astype(int)
                np.savetxt(
                    os.path.join('..', 'nats_results',
                                 'app_pred_{}.txt'.format(epoch + 1)),
                    self.pred_data, fmt='%d')

                self.true_data = np.array(self.true_data).astype(int)
                np.savetxt(
                    os.path.join('..', 'nats_results',
                                 'app_true_{}.txt'.format(epoch + 1)),
                    self.true_data, fmt='%d')
                            
Ejemplo n.º 3
0
    def keyword_extraction(self):
        '''
        Generate per-task keyword files from accumulated attention.

        Scores each word by its attention mass relative to its corpus
        frequency and writes the top words of every task to
        ``../nats_results/attn_keywords/<task>.txt``.
        '''
        self.build_vocabulary()
        self.build_models()
        print(self.base_models)
        print(self.train_models)
        if len(self.base_models) > 0:
            self.init_base_model_params()
        if len(self.train_models) > 0:
            self.init_train_model_params()

        self.vis_data = create_batch_memory(path_=self.args.data_dir,
                                            file_=self.args.file_vis,
                                            is_shuffle=False,
                                            batch_size=self.args.batch_size,
                                            is_lower=self.args.is_lower)

        key_dir = '../nats_results/attn_keywords'
        # Start from an empty output directory on every run.
        if os.path.exists(key_dir):
            shutil.rmtree(key_dir)
        os.mkdir(key_dir)

        # Bug fix: every sibling inference method switches the models to
        # eval mode before predicting; this one did not, leaving dropout
        # active during keyword extraction.
        for model_name in self.base_models:
            self.base_models[model_name].eval()
        for model_name in self.train_models:
            self.train_models[model_name].eval()

        with torch.no_grad():

            print('Begin Generate Keywords')
            n_batch = len(self.vis_data)
            print('The number of batches (keywords): {}'.format(n_batch))
            for batch_id in range(n_batch):
                self.build_batch(self.vis_data[batch_id])
                # keyword_worker accumulates self.keywords1 / self.wd_freq.
                self.keyword_worker(batch_id, key_dir)
                show_progress(batch_id + 1, n_batch)
            print()

            for task_id in range(self.args.n_tasks):
                # Smoothed attention score: 100 * attn / (freq + 100).
                key_arr = [[
                    wd, 100 * self.keywords1[task_id][wd] /
                    (self.wd_freq[wd] + 100)
                ] for wd in self.keywords1[task_id]]
                # Highest-scoring words first (was ``sorted(...)[::-1]``;
                # also renamed the lambda arg, which shadowed the loop var).
                key_arr = sorted(key_arr, key=lambda itm: itm[1],
                                 reverse=True)
                # Repeat each word proportionally to its score; drop stop
                # words, short words, and the <unk> token.
                key_arr = [[itm[0]] * int(round(itm[1])) for itm in key_arr
                           if (itm[0] not in stop_words) and (
                               len(itm[0]) > 3) and (itm[0] != '<unk>')]
                key_arr = key_arr[:100]
                key_arr = list(itertools.chain(*key_arr))
                # ``with`` guarantees the handle is closed even on error.
                with open(os.path.join(key_dir, str(task_id) + '.txt'),
                          'w') as fout:
                    fout.write(' '.join(key_arr) + '\n')
Ejemplo n.º 4
0
    def test_penultimate(self):
        '''
        Collect penultimate-layer logits for the test split.

        Dumps them to ``<file_test>_pred_<best_epoch>.pickled`` inside
        ``../nats_results/<test_output_dir>_penultimate``.
        '''
        self.build_vocabulary()
        self.build_models()
        print(self.base_models)
        print(self.train_models)
        if len(self.base_models) > 0:
            self.init_base_model_params()
        if len(self.train_models) > 0:
            self.init_train_model_params()

        self.test_data = create_batch_memory(path_=self.args.data_dir,
                                             file_=self.args.file_test,
                                             is_shuffle=False,
                                             batch_size=self.args.batch_size,
                                             is_lower=self.args.is_lower)

        output_dir = '../nats_results/' + \
            self.args.test_output_dir + '_penultimate'
        # makedirs(exist_ok=True) avoids the exists-check/mkdir race.
        os.makedirs(output_dir, exist_ok=True)

        # Inference mode for every sub-model.
        for model_name in self.base_models:
            self.base_models[model_name].eval()
        for model_name in self.train_models:
            self.train_models[model_name].eval()

        with torch.no_grad():
            print('Begin Testing: {}'.format(self.args.file_test))
            test_batch = len(self.test_data)
            print('The number of batches (testing): {}'.format(test_batch))
            self.pred_data = []
            if self.args.debug:
                # Only run a few batches in debug mode.
                test_batch = 3
            for test_id in range(test_batch):
                self.build_batch(self.test_data[test_id])
                logits = self.test_penultimate_worker()
                self.pred_data += logits.data.cpu().numpy().tolist()

                show_progress(test_id + 1, test_batch)
            print()
            # save testing data.
            # Bug fix: the original opened the pickle file without ``with``,
            # leaking the handle if pickle.dump raised.
            outfile = os.path.join(
                output_dir,
                '{}_pred_{}.pickled'.format(self.args.file_test,
                                            self.args.best_epoch))
            with open(outfile, 'wb') as fout:
                pickle.dump(self.pred_data, fout)
    def visualization(self):
        '''
        Render attention visualizations for the visualization split.

        Recreates ``../nats_results/attn_vis`` from scratch and writes one
        visualization per batch via ``visualization_worker``.
        '''
        self.build_vocabulary()
        self.build_models()
        print(self.base_models)
        print(self.train_models)
        if len(self.base_models) > 0:
            self.init_base_model_params()
        if len(self.train_models) > 0:
            self.init_train_model_params()

        self.vis_data = create_batch_memory(path_=self.args.data_dir,
                                            file_=self.args.file_vis,
                                            is_shuffle=False,
                                            batch_size=self.args.batch_size,
                                            is_lower=self.args.is_lower)

        # Always start from an empty output directory.
        vis_dir = '../nats_results/attn_vis'
        if os.path.exists(vis_dir):
            shutil.rmtree(vis_dir)
        os.mkdir(vis_dir)

        # Inference mode for every sub-model.
        for name in self.base_models:
            self.base_models[name].eval()
        for name in self.train_models:
            self.train_models[name].eval()

        with torch.no_grad():
            print('Begin Visualization')
            total = len(self.vis_data)
            print('The number of batches (visualization): {}'.format(total))
            for idx in range(total):
                self.build_batch(self.vis_data[idx])
                self.visualization_worker(idx, vis_dir)
                show_progress(idx + 1, total)
            print()
Ejemplo n.º 6
0
    def visualization(self):
        '''
        Render per-aspect attention visualizations for the test split.

        For every test example, writes an HTML file highlighting the
        attention weights of the selected aspects, named
        ``<idx>_<len>_<#errors>_<gold...>_<pred...>.html`` under
        ``../nats_results/visualization_<file_test>``.
        '''
        self.build_vocabulary()
        self.build_models()
        print(self.base_models)
        print(self.train_models)
        if len(self.base_models) > 0:
            self.init_base_model_params()
        if len(self.train_models) > 0:
            self.init_train_model_params()

        self.test_data = create_batch_memory(path_=self.args.data_dir,
                                             file_=self.args.file_test,
                                             is_shuffle=False,
                                             batch_size=self.args.batch_size,
                                             is_lower=self.args.is_lower)

        # Inference mode for every sub-model.
        for model_name in self.base_models:
            self.base_models[model_name].eval()
        for model_name in self.train_models:
            self.train_models[model_name].eval()

        output_dir = '../nats_results/visualization_{}'.format(
            self.args.file_test)
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)

        # aspect_label: comma-separated aspect names;
        # visualization_aspect: comma-separated 1-based indices of the
        # aspects to render (converted to 0-based here).
        aspect_label = self.args.aspect_label.split(',')
        vis_label = [
            int(wd) - 1 for wd in self.args.visualization_aspect.split(',')
        ]
        data_aspect = [aspect_label[idx] for idx in vis_label]
        print('You will visualize Aspects: {}'.format(', '.join(data_aspect)))

        with torch.no_grad():
            print('Begin Testing: {}'.format(self.args.file_test))
            test_batch = len(self.test_data)
            print('The number of batches (testing): {}'.format(test_batch))
            pred_data = []
            true_data = []
            keywords_data = []
            if self.args.debug:
                # Only run a few batches in debug mode.
                test_batch = 3
            for test_id in range(test_batch):
                self.build_batch(self.test_data[test_id])
                ratePred, rateTrue = self.test_worker()
                # visualization_worker returns per-example records; the code
                # below reads their 'text' and 'weights' entries.
                output = self.visualization_worker()
                keywords_data += output

                pred_data += ratePred
                true_data += rateTrue

                show_progress(test_id + 1, test_batch)
            print()
            for k in range(len(keywords_data)):
                # Keep only the selected aspects' labels on each record.
                keywords_data[k]['pred_label'] = [
                    pred_data[k][idx] for idx in vis_label
                ]
                keywords_data[k]['gold_label'] = [
                    true_data[k][idx] for idx in vis_label
                ]
                len_txt = len(keywords_data[k]['text'][0].split())
                # diff counts aspects that are NOT a correct positive hit:
                # a 0 is appended only when pred == gold AND gold > 0.
                diff = []
                for j in range(len(pred_data[k])):
                    if pred_data[k][j] == true_data[k][
                            j] and true_data[k][j] > 0:
                        diff.append(0)
                    else:
                        diff.append(1)
                diff = np.sum(diff)
                # Filename encodes example id, text length, miss count, and
                # the full gold/pred label vectors.
                ftxt = '_'.join(
                    map(str, [k, len_txt, diff] + true_data[k] + pred_data[k]))
                file_output = os.path.join(output_dir, '{}.html'.format(ftxt))
                keywords_data[k]['text'] = [
                    keywords_data[k]['text'][idx] for idx in vis_label
                ]
                keywords_data[k]['weights'] = [
                    keywords_data[k]['weights'][idx] for idx in vis_label
                ]
                createHTML(data_aspect, keywords_data[k], file_output)
Ejemplo n.º 7
0
    def test_uncertainty(self):
        '''
        Testing Uncertainty.

        Runs the test split ``uncertainty_total_samples`` times and saves,
        for each pass, the penultimate vectors (pickle) plus the predicted
        and gold labels (integer text files, or pickles when the labels are
        not plain ints).

        NOTE(review): unlike the other inference methods, the models are
        never switched to eval mode here — presumably deliberate so dropout
        stays active across samples (the output path embeds drop_option and
        drop_rate); confirm before "fixing".
        '''
        self.build_vocabulary()
        self.build_models()
        print(self.base_models)
        print(self.train_models)
        if len(self.base_models) > 0:
            self.init_base_model_params()
        if len(self.train_models) > 0:
            self.init_train_model_params()

        self.test_data = create_batch_memory(path_=self.args.data_dir,
                                             file_=self.args.file_test,
                                             is_shuffle=False,
                                             batch_size=self.args.batch_size,
                                             is_lower=self.args.is_lower)

        output_dir = '../nats_results/{}_uncertainty_{}_{}'.format(
            self.args.test_output_dir, self.args.drop_option,
            self.args.drop_rate)
        # makedirs(exist_ok=True) avoids the exists-check/mkdir race.
        os.makedirs(output_dir, exist_ok=True)

        with torch.no_grad():
            for k_unc in range(self.args.uncertainty_total_samples):

                print('Begin Testing: {}, {}'.format(self.args.file_test,
                                                     k_unc))
                test_batch = len(self.test_data)
                print('The number of batches (testing): {}'.format(test_batch))
                self.pred_data = []
                self.true_data = []
                self.vector_data = []
                if self.args.debug:
                    # Only run a few batches in debug mode.
                    test_batch = 3
                for test_id in range(test_batch):
                    self.build_batch(self.test_data[test_id])
                    logits = self.test_penultimate_worker()
                    ratePred, rateTrue = self.test_worker()

                    self.vector_data += logits.data.cpu().numpy().tolist()
                    self.pred_data += ratePred
                    self.true_data += rateTrue

                    show_progress(test_id + 1, test_batch)
                print()
                # save the penultimate vectors for this uncertainty sample.
                outfile = os.path.join(
                    output_dir,
                    '{}_vector_{}_{}.pickled'.format(self.args.file_test,
                                                     self.args.best_epoch,
                                                     k_unc))
                with open(outfile, 'wb') as fout:
                    pickle.dump(self.vector_data, fout)
                try:
                    # Labels are normally plain integers -> text output.
                    self.pred_data = np.array(self.pred_data).astype(int)
                    np.savetxt(os.path.join(
                        output_dir,
                        '{}_pred_{}_unc_{}.txt'.format(self.args.file_test,
                                                       self.args.best_epoch,
                                                       k_unc)),
                               self.pred_data,
                               fmt='%d')
                    self.true_data = np.array(self.true_data).astype(int)
                    np.savetxt(os.path.join(
                        output_dir,
                        '{}_true_{}_unc_{}.txt'.format(self.args.file_test,
                                                       self.args.best_epoch,
                                                       k_unc)),
                               self.true_data,
                               fmt='%d')
                except Exception:  # was a bare except; non-int labels land here
                    # Bug fix: the pred filename used args.file_best while
                    # the paired true filename used args.file_test; unified
                    # on file_test.
                    outfile = os.path.join(
                        output_dir, '{}_pred_{}.pickled'.format(
                            self.args.file_test, self.args.best_epoch))
                    with open(outfile, 'wb') as fout:
                        pickle.dump(self.pred_data, fout)
                    outfile = os.path.join(
                        output_dir, '{}_true_{}.pickled'.format(
                            self.args.file_test, self.args.best_epoch))
                    with open(outfile, 'wb') as fout:
                        pickle.dump(self.true_data, fout)
Ejemplo n.º 8
0
    def test(self):
        '''
        Testing.

        Evaluates the test split with the loaded checkpoint and saves the
        predicted / gold labels as integer text files, falling back to
        pickle files when the labels are not plain ints, under
        ``../nats_results/<test_output_dir>``.
        '''
        self.build_vocabulary()
        self.build_models()
        print(self.base_models)
        print(self.train_models)
        if len(self.base_models) > 0:
            self.init_base_model_params()
        if len(self.train_models) > 0:
            self.init_train_model_params()

        self.test_data = create_batch_memory(path_=self.args.data_dir,
                                             file_=self.args.file_test,
                                             is_shuffle=False,
                                             batch_size=self.args.batch_size,
                                             is_lower=self.args.is_lower)

        output_dir = '../nats_results/' + self.args.test_output_dir
        # makedirs(exist_ok=True) avoids the exists-check/mkdir race.
        os.makedirs(output_dir, exist_ok=True)

        # Inference mode for every sub-model.
        for model_name in self.base_models:
            self.base_models[model_name].eval()
        for model_name in self.train_models:
            self.train_models[model_name].eval()

        with torch.no_grad():
            print('Begin Testing: {}'.format(self.args.file_test))
            test_batch = len(self.test_data)
            print('The number of batches (testing): {}'.format(test_batch))
            self.pred_data = []
            self.true_data = []
            if self.args.debug:
                # Only run a few batches in debug mode.
                test_batch = 3
            for test_id in range(test_batch):
                self.build_batch(self.test_data[test_id])
                ratePred, rateTrue = self.test_worker()

                self.pred_data += ratePred
                self.true_data += rateTrue

                show_progress(test_id + 1, test_batch)
            print()
            # save testing data.
            try:
                # Labels are normally plain integers -> text output.
                self.pred_data = np.array(self.pred_data).astype(int)
                np.savetxt(os.path.join(
                    output_dir,
                    '{}_pred_{}.txt'.format(self.args.file_test,
                                            self.args.best_epoch)),
                           self.pred_data,
                           fmt='%d')
                self.true_data = np.array(self.true_data).astype(int)
                np.savetxt(os.path.join(
                    output_dir,
                    '{}_true_{}.txt'.format(self.args.file_test,
                                            self.args.best_epoch)),
                           self.true_data,
                           fmt='%d')
            except Exception:  # was a bare except; non-int labels land here
                # Bug fix: the pred filename used args.file_best while the
                # paired true filename used args.file_test; unified on
                # file_test. ``with`` also replaces the leaked handles.
                outfile = os.path.join(
                    output_dir,
                    '{}_pred_{}.pickled'.format(self.args.file_test,
                                                self.args.best_epoch))
                with open(outfile, 'wb') as fout:
                    pickle.dump(self.pred_data, fout)
                outfile = os.path.join(
                    output_dir,
                    '{}_true_{}.pickled'.format(self.args.file_test,
                                                self.args.best_epoch))
                with open(outfile, 'wb') as fout:
                    pickle.dump(self.true_data, fout)
    def train(self):
        '''
        training here.
        Don't overwrite.

        Full training loop: builds models and optimizer, optionally resumes
        from the newest checkpoint in ``../nats_results``, trains for
        ``n_epoch`` epochs with periodic checkpointing, then evaluates the
        validation and test splits after every epoch.
        '''
        self.build_vocabulary()
        self.build_models()
        print(self.base_models)
        print(self.train_models)
        if len(self.base_models) > 0:
            self.init_base_model_params()
        # Collect every trainable parameter into one flat list (the original
        # relied on a NameError try/except to initialize ``params``).
        params = []
        for model_name in self.train_models:
            params += list(self.train_models[model_name].parameters())
        if self.args.train_base_model:
            for model_name in self.base_models:
                params += list(self.base_models[model_name].parameters())
        # define optimizer
        optimizer = self.build_optimizer(params)
        # The scheduler is optional; subclasses may not provide one.
        try:
            scheduler = self.build_scheduler(optimizer)
        except Exception:
            scheduler = None
        # load checkpoint
        cc_model = 0
        out_dir = os.path.join('..', 'nats_results')
        if not os.path.exists(out_dir):
            os.mkdir(out_dir)
        if self.args.continue_training:
            model_para_files = glob.glob(os.path.join(out_dir, '*.model'))
            if len(model_para_files) > 0:
                # Checkpoints are named <model>_<epoch>.model; collect the
                # distinct epoch ids present on disk.
                uf_model = []
                for fl_ in model_para_files:
                    # basename + raw-string regex replace the fragile
                    # ``re.split('\/', ...)`` / ``'\_|\.'`` escapes.
                    arr = re.split(r'[_.]', os.path.basename(fl_))
                    epoch_id = int(arr[-2])
                    # Bug fix: the original tested ``arr not in uf_model``
                    # (a token list against a list of ints), so duplicates
                    # were never filtered and the [-2] fallback below could
                    # pick the same broken epoch twice.
                    if epoch_id not in uf_model:
                        uf_model.append(epoch_id)
                cc_model = sorted(uf_model)[-1]
                try:
                    print("Try *_{}.model".format(cc_model))
                    for model_name in self.train_models:
                        fl_ = os.path.join(
                            out_dir,
                            model_name + '_' + str(cc_model) + '.model')
                        self.train_models[model_name].load_state_dict(
                            torch.load(
                                fl_,
                                map_location=lambda storage, loc: storage))
                except Exception:
                    # The newest checkpoint may be partially written; fall
                    # back to the second newest epoch.
                    cc_model = sorted(uf_model)[-2]
                    print("Try *_{}.model".format(cc_model))
                    for model_name in self.train_models:
                        fl_ = os.path.join(
                            out_dir,
                            model_name + '_' + str(cc_model) + '.model')
                        self.train_models[model_name].load_state_dict(
                            torch.load(
                                fl_,
                                map_location=lambda storage, loc: storage))
                print('Continue training with *_{}.model'.format(cc_model))
        else:
            # Fresh run: wipe all previous results.
            shutil.rmtree(out_dir)
            os.mkdir(out_dir)

        self.val_data = create_batch_memory(path_=self.args.data_dir,
                                            file_=self.args.file_val,
                                            is_shuffle=False,
                                            batch_size=self.args.batch_size,
                                            is_lower=self.args.is_lower)
        self.test_data = create_batch_memory(path_=self.args.data_dir,
                                             file_=self.args.file_test,
                                             is_shuffle=False,
                                             batch_size=self.args.batch_size,
                                             is_lower=self.args.is_lower)
        # train models
        if cc_model > 0:
            # Re-run the resumed epoch from its beginning.
            cc_model -= 1
        for epoch in range(cc_model, self.args.n_epoch):
            '''
            Train
            '''
            for model_name in self.base_models:
                self.base_models[model_name].train()
            for model_name in self.train_models:
                self.train_models[model_name].train()
            print('====================================')
            print('Training Epoch: {}'.format(epoch + 1))
            # Reshuffle the training data every epoch.
            self.train_data = create_batch_memory(
                path_=self.args.data_dir,
                file_=self.args.file_train,
                is_shuffle=True,
                batch_size=self.args.batch_size,
                is_lower=self.args.is_lower)
            n_batch = len(self.train_data)
            print('The number of batches (training): {}'.format(n_batch))
            self.global_steps = max(0, epoch) * n_batch
            # The original wrapped scheduler.step() in a bare try/except,
            # which also hid real scheduler errors; only skip when absent.
            if scheduler is not None:
                scheduler.step()
            if self.args.debug:
                n_batch = 1
            loss_arr = []
            for batch_id in range(n_batch):
                self.global_steps += 1

                self.build_batch(self.train_data[batch_id])
                loss = self.build_pipelines()

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(params, self.args.grad_clip)
                optimizer.step()

                # Periodic mid-epoch checkpoint of the train models.
                if batch_id % self.args.checkpoint == 0:
                    for model_name in self.train_models:
                        fname = os.path.join(
                            out_dir,
                            model_name + '_' + str(epoch + 1) + '.model')
                        with open(fname, 'wb') as fmodel:
                            torch.save(
                                self.train_models[model_name].state_dict(),
                                fmodel)
                show_progress(batch_id + 1, n_batch)
                loss_arr.append(loss.data.cpu().numpy())
            print()
            # write models
            print('Training Loss = {}.'.format(np.average(loss_arr)))
            for model_name in self.train_models:
                fname = os.path.join(
                    out_dir, model_name + '_' + str(epoch + 1) + '.model')
                with open(fname, 'wb') as fmodel:
                    torch.save(self.train_models[model_name].state_dict(),
                               fmodel)

            for model_name in self.base_models:
                self.base_models[model_name].eval()
            for model_name in self.train_models:
                self.train_models[model_name].eval()
            with torch.no_grad():
                '''
                Validate
                '''
                print('Begin Validation')
                n_batch = len(self.val_data)
                print('The number of batches (validation): {}'.format(n_batch))
                self.pred_data = []
                self.true_data = []
                if self.args.debug:
                    n_batch = 1
                for batch_id in range(n_batch):

                    self.build_batch(self.val_data[batch_id])
                    ratePred, rateTrue = self.test_worker()

                    self.pred_data += ratePred
                    self.true_data += rateTrue

                    show_progress(batch_id + 1, n_batch)
                print()
                self.pred_data = np.array(self.pred_data).astype(int)
                np.savetxt(os.path.join(
                    '..', 'nats_results',
                    'validate_pred_{}.txt'.format(epoch + 1)),
                           self.pred_data,
                           fmt='%d')

                self.true_data = np.array(self.true_data).astype(int)
                np.savetxt(os.path.join(
                    '..', 'nats_results',
                    'validate_true_{}.txt'.format(epoch + 1)),
                           self.true_data,
                           fmt='%d')

                self.run_evaluation()
                '''
                Testing
                '''
                print('Begin Testing')
                n_batch = len(self.test_data)
                print('The number of batches (testing): {}'.format(n_batch))
                self.pred_data = []
                self.true_data = []
                if self.args.debug:
                    n_batch = 1
                for batch_id in range(n_batch):

                    self.build_batch(self.test_data[batch_id])
                    ratePred, rateTrue = self.test_worker()

                    self.pred_data += ratePred
                    self.true_data += rateTrue

                    show_progress(batch_id + 1, n_batch)
                print()
                self.pred_data = np.array(self.pred_data).astype(int)
                np.savetxt(os.path.join('..', 'nats_results',
                                        'test_pred_{}.txt'.format(epoch + 1)),
                           self.pred_data,
                           fmt='%d')

                self.true_data = np.array(self.true_data).astype(int)
                np.savetxt(os.path.join('..', 'nats_results',
                                        'test_true_{}.txt'.format(epoch + 1)),
                           self.true_data,
                           fmt='%d')

                self.run_evaluation()
Ejemplo n.º 10
0
    def train(self):
        '''
        training here.
        Don't overwrite.
        '''
        # ------------------------------------------------------------------
        # Body of the training driver (the enclosing `def` is above this
        # chunk).  Builds vocabulary and models, collects trainable
        # parameters, optionally resumes from the newest checkpoint in
        # ../nats_results, then runs the epoch/batch loop with periodic
        # validation, checkpointing, and testing.
        # ------------------------------------------------------------------
        self.build_vocabulary()
        self.build_models()
        pprint(self.base_models)
        pprint(self.train_models)
        # Base (pre-trained) models load their saved weights here.
        if len(self.base_models) > 0:
            self.init_base_model_params()
        # Collect all trainable parameters into one flat list; list() is
        # required because .parameters() yields a generator, so it cannot be
        # appended directly.
        params = []
        for model_name in self.train_models:
            params += list(self.train_models[model_name].parameters())
        # Optionally fine-tune the base models as well.
        if self.args.train_base_model:
            for model_name in self.base_models:
                params += list(self.base_models[model_name].parameters())
        print('Total number of parameters: {}.'.format(
            sum([para.numel() for para in params])))
        # define optimizer
        optimizer = self.build_optimizer(params)
        # 'build-in' delegates to a torch lr scheduler (stepped once per
        # epoch at the bottom of the loop); 'warm-up' is computed manually
        # inside the batch loop below.
        if self.args.lr_schedule == 'build-in':
            scheduler = self.build_scheduler(optimizer)
        # load checkpoint
        cc_model = 0  # 1-based epoch number of the checkpoint to resume from
        out_dir = os.path.join('..', 'nats_results')
        if not os.path.exists(out_dir):
            os.mkdir(out_dir)
        if self.args.continue_training:
            model_para_files = glob.glob(os.path.join(out_dir, '*.model'))
            if len(model_para_files) > 0:
                # Recover the epoch numbers embedded in the checkpoint file
                # names (pattern: <model_name>_<epoch>.model).
                uf_model = []
                for fl_ in model_para_files:
                    arr = re.split(r'\/', fl_)[-1]
                    arr = re.split(r'\_|\.', arr)
                    # NOTE(review): `arr` is a list of strings while
                    # `uf_model` holds ints, so this membership test is
                    # always True and duplicate epochs can be appended.
                    # Harmless for the sorted()[-1] pick below, but it looks
                    # unintended — confirm.
                    if arr not in uf_model:
                        uf_model.append(int(arr[-2]))
                cc_model = sorted(uf_model)[-1]
                try:
                    print("Try *_{}.model".format(cc_model))
                    for model_name in self.train_models:
                        fl_ = os.path.join(
                            out_dir,
                            model_name + '_' + str(cc_model) + '.model')
                        # map_location keeps CPU-only machines working even
                        # when the checkpoint was saved from a GPU run.
                        self.train_models[model_name].load_state_dict(
                            torch.load(
                                fl_,
                                map_location=lambda storage, loc: storage))
                # NOTE(review): bare except hides the real failure; the
                # fallback also assumes a second-newest checkpoint exists
                # (IndexError when only one epoch was ever saved).
                except:
                    cc_model = sorted(uf_model)[-2]
                    print("Try *_{}.model".format(cc_model))
                    for model_name in self.train_models:
                        fl_ = os.path.join(
                            out_dir,
                            model_name + '_' + str(cc_model) + '.model')
                        self.train_models[model_name].load_state_dict(
                            torch.load(
                                fl_,
                                map_location=lambda storage, loc: storage))
                print('Continue training with *_{}.model'.format(cc_model))
        else:
            # Fresh run: wipe every previous result/checkpoint in out_dir.
            shutil.rmtree(out_dir)
            os.mkdir(out_dir)

        # Validation and test batches are built once, unshuffled.
        self.val_data = create_batch_memory(path_=self.args.data_dir,
                                            file_=self.args.file_val,
                                            is_shuffle=False,
                                            batch_size=self.args.batch_size,
                                            is_lower=self.args.is_lower)
        self.test_data = create_batch_memory(path_=self.args.data_dir,
                                             file_=self.args.file_test,
                                             is_shuffle=False,
                                             batch_size=self.args.batch_size,
                                             is_lower=self.args.is_lower)
        # train models
        # Persist the run configuration next to the checkpoints so results
        # can be reproduced / inspected later.
        fout = open('../nats_results/args.pickled', 'wb')
        pickle.dump(self.args, fout)
        fout.close()
        # Checkpoint N is written during (1-based) epoch N, so resuming
        # re-runs that epoch: 0-based start index is N - 1.
        if cc_model > 0:
            cc_model -= 1
        for epoch in range(cc_model, self.args.n_epoch):
            # Training
            print('====================================')
            print('Training Epoch: {}'.format(epoch + 1))
            # Training batches are rebuilt (and reshuffled) every epoch.
            self.train_data = create_batch_memory(
                path_=self.args.data_dir,
                file_=self.args.file_train,
                is_shuffle=True,
                batch_size=self.args.batch_size,
                is_lower=self.args.is_lower)
            n_batch = len(self.train_data)
            print('The number of batches (training): {}'.format(n_batch))
            # Global step drives the warm-up schedule; reconstructed from
            # the epoch index when resuming mid-run.
            self.global_steps = max(0, epoch) * n_batch
            if self.args.debug:
                n_batch = 3
            loss_arr = []
            # NOTE(review): accu_best resets at the top of every epoch, so
            # checkpoints/result files track a per-epoch best rather than a
            # global best across epochs — confirm this is intended.
            accu_best = 0
            for batch_id in range(n_batch):
                self.global_steps += 1
                learning_rate = self.args.learning_rate
                if self.args.lr_schedule == 'warm-up':
                    # Noam-style warm-up: lr rises for the first
                    # warmup_step steps, then decays as 1/sqrt(step).
                    learning_rate = 2.0 * \
                        (self.args.model_size ** (-0.5) *
                        min(self.global_steps ** (-0.5),
                        self.global_steps * self.args.warmup_step**(-1.5)))
                    for p in optimizer.param_groups:
                        p['lr'] = learning_rate
                elif self.args.lr_schedule == 'build-in':
                    # Only read the scheduler-managed lr for display.
                    for p in optimizer.param_groups:
                        learning_rate = p['lr']
                        break
                # print(learning_rate)
                # ccnt: position inside the current checkpoint window; the
                # progress bar resets every args.checkpoint batches.
                ccnt = batch_id % self.args.checkpoint
                if batch_id > 0 and batch_id % self.args.checkpoint == 0:
                    ccnt = self.args.checkpoint

                self.build_batch(self.train_data[batch_id])
                loss = self.build_pipelines()

                # NaN guard: NaN is the only value not equal to itself.
                if loss != loss:
                    raise ValueError

                optimizer.zero_grad()
                loss.backward()
                # Clip gradient norm to protect against exploding gradients.
                torch.nn.utils.clip_grad_norm_(params, self.args.grad_clip)
                optimizer.step()

                loss_arr.append(loss.data.cpu().numpy().tolist())
                message = 'lr={}, Loss={}, AvgLoss={}{}'.format(
                    np.around(learning_rate, 6),
                    np.around(loss.data.cpu().numpy().tolist(), 4),
                    np.around(np.average(loss_arr), 4), ' ' * 10)
                show_progress(ccnt, min(n_batch, self.args.checkpoint),
                              message)

                # Every args.checkpoint batches (and at epoch end): run
                # validation; when it improves, save predictions, model
                # checkpoints, and run the test set.
                if (batch_id % self.args.checkpoint == 0
                        and batch_id != 0) or batch_id == n_batch - 1:
                    print()
                    print('Training Loss = {}.'.format(np.average(loss_arr)))
                    # Switch all models to eval mode for the held-out runs.
                    for model_name in self.base_models:
                        self.base_models[model_name].eval()
                    for model_name in self.train_models:
                        self.train_models[model_name].eval()
                    with torch.no_grad():
                        # validate
                        print('Begin Validation')
                        val_batch = len(self.val_data)
                        print('The number of batches (validation): {}'.format(
                            val_batch))
                        self.pred_data = []
                        self.true_data = []
                        if self.args.debug:
                            val_batch = 3
                        for val_id in range(val_batch):
                            self.build_batch(self.val_data[val_id])
                            ratePred, rateTrue = self.test_worker()

                            self.pred_data += ratePred
                            self.true_data += rateTrue

                            show_progress(val_id + 1, val_batch)
                        print()
                        # evaluate development
                        accu = self.run_evaluation()
                        print('Best Results: {}'.format(np.round(accu_best,
                                                                 4)))
                        if accu >= accu_best:
                            accu_best = accu
                            # save results.
                            # Integer-castable labels go to plain text;
                            # anything numpy cannot cast falls back to
                            # pickle in the except branch.
                            try:
                                self.pred_data = np.array(
                                    self.pred_data).astype(int)
                                np.savetxt(os.path.join(
                                    '..', 'nats_results',
                                    'validate_pred_{}.txt'.format(epoch + 1)),
                                           self.pred_data,
                                           fmt='%d')
                                self.true_data = np.array(
                                    self.true_data).astype(int)
                                np.savetxt(os.path.join(
                                    '..', 'nats_results',
                                    'validate_true_{}.txt'.format(epoch + 1)),
                                           self.true_data,
                                           fmt='%d')
                            # NOTE(review): bare except — narrow this to
                            # ValueError/TypeError if possible.
                            except:
                                fout = open(
                                    os.path.join(
                                        '..', 'nats_results',
                                        'validate_pred_{}.pickled'.format(
                                            epoch + 1)), 'wb')
                                pickle.dump(self.pred_data, fout)
                                fout.close()
                                fout = open(
                                    os.path.join(
                                        '..', 'nats_results',
                                        'validate_true_{}.pickled'.format(
                                            epoch + 1)), 'wb')
                                pickle.dump(self.true_data, fout)
                                fout.close()
                            # save models
                            # One <name>_<epoch>.model file per trainable
                            # model; the resume logic above parses these
                            # names to find the latest epoch.
                            for model_name in self.train_models:
                                fmodel = open(
                                    os.path.join(
                                        out_dir, model_name + '_' +
                                        str(epoch + 1) + '.model'), 'wb')
                                torch.save(
                                    self.train_models[model_name].state_dict(),
                                    fmodel)
                                fmodel.close()
                            # testing
                            print('Begin Testing')
                            test_batch = len(self.test_data)
                            print('The number of batches (testing): {}'.format(
                                test_batch))
                            self.pred_data = []
                            self.true_data = []
                            if self.args.debug:
                                test_batch = 3
                            for test_id in range(test_batch):
                                self.build_batch(self.test_data[test_id])
                                ratePred, rateTrue = self.test_worker()

                                self.pred_data += ratePred
                                self.true_data += rateTrue

                                show_progress(test_id + 1, test_batch)
                            print()
                            # save testing data.
                            # Same text-first / pickle-fallback scheme as
                            # for validation above.
                            try:
                                self.pred_data = np.array(
                                    self.pred_data).astype(int)
                                np.savetxt(os.path.join(
                                    '..', 'nats_results',
                                    'test_pred_{}.txt'.format(epoch + 1)),
                                           self.pred_data,
                                           fmt='%d')
                                self.true_data = np.array(
                                    self.true_data).astype(int)
                                np.savetxt(os.path.join(
                                    '..', 'nats_results',
                                    'test_true_{}.txt'.format(epoch + 1)),
                                           self.true_data,
                                           fmt='%d')
                            except:
                                fout = open(
                                    os.path.join(
                                        '..', 'nats_results',
                                        'test_pred_{}.pickled'.format(epoch +
                                                                      1)),
                                    'wb')
                                pickle.dump(self.pred_data, fout)
                                fout.close()
                                fout = open(
                                    os.path.join(
                                        '..', 'nats_results',
                                        'test_true_{}.pickled'.format(epoch +
                                                                      1)),
                                    'wb')
                                pickle.dump(self.true_data, fout)
                                fout.close()
                            # evaluate testing
                            self.run_evaluation()
                        print('====================================')

                    # Restore train mode before resuming optimization.
                    for model_name in self.base_models:
                        self.base_models[model_name].train()
                    for model_name in self.train_models:
                        self.train_models[model_name].train()
            # The built-in scheduler advances once per epoch.
            if self.args.lr_schedule == 'build-in':
                scheduler.step()