Example #1
    def test(self):
        '''
        Testing. Do not override.
        '''
        self.build_vocabulary()
        self.build_models()
        pprint(self.base_models)
        pprint(self.train_models)
        if len(self.base_models) > 0:
            self.init_base_model_params()

        _nbatch = create_batch_file(path_data=self.args.data_dir,
                                    path_work=os.path.join(
                                        '..', 'nats_results'),
                                    is_shuffle=False,
                                    fkey_=self.args.task,
                                    file_=self.args.file_test,
                                    batch_size=self.args.test_batch_size)
        print('The number of batches (test): {}'.format(_nbatch))

        for model_name in self.base_models:
            self.base_models[model_name].eval()
        for model_name in self.train_models:
            self.train_models[model_name].eval()
        with torch.no_grad():
            if self.args.use_optimal_model:
                model_valid_file = os.path.join('..', 'nats_results',
                                                'model_validate.txt')
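                # model_validate.txt is written by validate() sorted by
                # validation loss, so its first line names the best checkpoint.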
                fp = open(model_valid_file, 'r')
                for line in fp:
                    arr = re.split(r'\s', line[:-1])
                    model_optimal_key = ''.join(
                        ['_', arr[1], '_', arr[2], '.model'])
                    break
                fp.close()
            else:
                arr = re.split(r'\D', self.args.model_optimal_key)
                model_optimal_key = ''.join(
                    ['_', arr[0], '_', arr[1], '.model'])
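            # model_optimal_key has the form '_<epoch>_<batch>.model', matching
            # the checkpoint file names written during training.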
            print("Using *{} for decoding.".format(model_optimal_key))

            for model_name in self.train_models:
                model_optimal_file = os.path.join(
                    '..', 'nats_results', model_name + model_optimal_key)
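                # map_location keeps all tensors on the CPU, so checkpoints
                # saved from GPU runs load without a GPU being present.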
                self.train_models[model_name].load_state_dict(
                    torch.load(model_optimal_file,
                               map_location=lambda storage, loc: storage))

            self.test_worker(_nbatch)
            print()
Example #2
    def train(self):
        '''
        Training. Do not override.
        '''
        self.build_vocabulary()
        self.build_models()
        pprint(self.train_models)
        # Collect the parameters of all trainable models into a single list.
        params = []
        for model_name in self.train_models:
            params += list(self.train_models[model_name].parameters())
        print('Total number of parameters: {}.'.format(
            sum([para.numel() for para in params])))
        # define optimizer
        optimizer = self.build_optimizer(params)
        # load checkpoint
        uf_epoch = 0
        uf_model = [0, -1]  # [epoch, batch] of the checkpoint to resume from
        out_dir = os.path.join('..', 'nats_results')
        if not os.path.exists(out_dir):
            os.mkdir(out_dir)
        if self.args.continue_training:
            model_para_files = glob.glob(os.path.join(out_dir, '*.model'))
            if len(model_para_files) > 0:
                uf_model = []
                for fl_ in model_para_files:
                    arr = re.split(r'\/', fl_)[-1]
                    arr = re.split(r'\_|\.', arr)
                    arr = [int(arr[-3]), int(arr[-2])]
                    if arr not in uf_model:
                        uf_model.append(arr)
                cc_model = sorted(uf_model)[-1]
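                # cc_model is the most recent (epoch, batch) checkpoint found.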
                try:
                    print("Try *_{}_{}.model".format(cc_model[0], cc_model[1]))
                    for model_name in self.train_models:
                        fl_ = os.path.join(
                            out_dir, model_name+'_'+str(cc_model[0])+'_'+str(cc_model[1])+'.model')
                        self.train_models[model_name].load_state_dict(
                            torch.load(fl_, map_location=lambda storage, loc: storage))
                except:
                    cc_model = sorted(uf_model)[-2]
                    print("Try *_{}_{}.model".format(cc_model[0], cc_model[1]))
                    for model_name in self.train_models:
                        fl_ = os.path.join(
                            out_dir, model_name+'_'+str(cc_model[0])+'_'+str(cc_model[1])+'.model')
                        self.train_models[model_name].load_state_dict(
                            torch.load(fl_, map_location=lambda storage, loc: storage))
                print(
                    'Continue training with *_{}_{}.model'.format(cc_model[0], cc_model[1]))
                uf_model = cc_model
                uf_epoch = cc_model[0]  # resume from the checkpointed epoch
        else:
            shutil.rmtree(out_dir)
            os.mkdir(out_dir)
        # train models
        fout = open('../nats_results/args.pickled', 'wb')
        pickle.dump(self.args, fout)
        fout.close()
        start_time = time.time()
        cclb = 0
        if uf_epoch < 0:
            uf_epoch = 0
        for epoch in range(uf_epoch, self.args.n_epoch):
            n_batch = create_batch_file(
                path_data=self.args.data_dir,
                path_work=os.path.join('..', 'nats_results'),
                is_shuffle=True,
                fkey_=self.args.task,
                file_=self.args.file_train,
                batch_size=self.args.batch_size,
                is_lower=self.args.is_lower)
            print('The number of batches: {}'.format(n_batch))
            self.global_steps = n_batch * max(0, epoch)
            for batch_id in range(n_batch):
                self.global_steps += 1
                learning_rate = self.args.learning_rate
                if self.args.lr_schedule == 'warm-up':
                    learning_rate = 2.0 * \
                        (self.args.model_size ** (-0.5) *
                         min(self.global_steps ** (-0.5),
                             self.global_steps * self.args.warmup_step**(-1.5)))
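                    # Transformer ("Noam") warm-up schedule with factor 2:
                    # lr = 2 * d_model^(-0.5)
                    #        * min(step^(-0.5), step * warmup_steps^(-1.5))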
                    for p in optimizer.param_groups:
                        p['lr'] = learning_rate
                elif self.args.lr_schedule == 'build-in':
                    for p in optimizer.param_groups:
                        learning_rate = p['lr']
                        break
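                # When resuming from a checkpoint, skip the batches that were
                # already processed (cclb == 0 until the first trained batch).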
                if cclb == 0 and batch_id < n_batch-1 and batch_id <= uf_model[1]:
                    continue
                else:
                    cclb += 1
                self.build_batch(batch_id)
                loss = self.build_pipelines()

                if loss != loss:  # NaN check: NaN is the only value != itself
                    raise ValueError('Loss is NaN; stopping training.')

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(params, self.args.grad_clip)
                optimizer.step()

                end_time = time.time()
                if batch_id % self.args.checkpoint == 0:
                    for model_name in self.train_models:
                        fmodel = open(os.path.join(
                            out_dir, model_name+'_'+str(epoch)+'_'+str(batch_id)+'.model'), 'wb')
                        torch.save(
                            self.train_models[model_name].state_dict(), fmodel)
                        fmodel.close()
                    if not os.path.exists(os.path.join(out_dir, 'model')):
                        os.mkdir(os.path.join(out_dir, 'model'))
                    for model_name in self.train_models:
                        fmodel = open(os.path.join(
                            out_dir, 'model', model_name+'.model'), 'wb')
                        torch.save(
                            self.train_models[model_name].state_dict(), fmodel)
                        fmodel.close()
                if batch_id % 1 == 0:
                    end_time = time.time()
                    print('epoch={}, batch={}, lr={}, loss={}, time={}h'.format(
                        epoch, batch_id, np.around(learning_rate, 6),
                        np.round(float(loss.data.cpu().numpy()), 6),
                        np.round((end_time-start_time)/3600.0, 4)))
                del loss
Example #3
    def train(self):
        '''
        Training. Do not override.
        '''
        self.build_vocabulary()
        self.build_models()
        print(self.base_models)
        print(self.train_models)
        if len(self.base_models) > 0:
            self.init_base_model_params()
        # Collect the parameters of the trainable models (and, optionally, the
        # base models) into a single list for the optimizer.
        params = []
        for model_name in self.train_models:
            params += list(self.train_models[model_name].parameters())
        if self.args.train_base_model:
            for model_name in self.base_models:
                params += list(self.base_models[model_name].parameters())
        # define optimizer
        optimizer = self.build_optimizer(params)
        # load checkpoint
        uf_model = [0, -1]
        out_dir = os.path.join('..', 'nats_results')
        if not os.path.exists(out_dir):
            os.mkdir(out_dir)
        if self.args.continue_training:
            model_para_files = glob.glob(os.path.join(out_dir, '*.model'))
            if len(model_para_files) > 0:
                uf_model = []
                for fl_ in model_para_files:
                    arr = re.split(r'\/', fl_)[-1]
                    arr = re.split(r'\_|\.', arr)
                    arr = [int(arr[-3]), int(arr[-2])]
                    if arr not in uf_model:
                        uf_model.append(arr)
                cc_model = sorted(uf_model)[-1]
                try:
                    print("Try *_{}_{}.model".format(cc_model[0], cc_model[1]))
                    for model_name in self.train_models:
                        fl_ = os.path.join(
                            out_dir, model_name + '_' + str(cc_model[0]) +
                            '_' + str(cc_model[1]) + '.model')
                        self.train_models[model_name].load_state_dict(
                            torch.load(
                                fl_,
                                map_location=lambda storage, loc: storage))
                except:
                    cc_model = sorted(uf_model)[-2]
                    print("Try *_{}_{}.model".format(cc_model[0], cc_model[1]))
                    for model_name in self.train_models:
                        fl_ = os.path.join(
                            out_dir, model_name + '_' + str(cc_model[0]) +
                            '_' + str(cc_model[1]) + '.model')
                        self.train_models[model_name].load_state_dict(
                            torch.load(
                                fl_,
                                map_location=lambda storage, loc: storage))
                print('Continue training with *_{}_{}.model'.format(
                    cc_model[0], cc_model[1]))
                uf_model = cc_model

        else:
            shutil.rmtree(out_dir)
            os.mkdir(out_dir)
        # train models
        fout = open('../nats_results/args.pickled', 'wb')
        pickle.dump(self.args, fout)
        fout.close()
        start_time = time.time()
        cclb = 0
        for epoch in range(uf_model[0], self.args.n_epoch):
            n_batch = create_batch_file(path_data=self.args.data_dir,
                                        path_work=os.path.join(
                                            '..', 'nats_results'),
                                        is_shuffle=True,
                                        fkey_=self.args.task,
                                        file_=self.args.file_train,
                                        batch_size=self.args.batch_size,
                                        is_lower=self.args.is_lower)
            print('The number of batches: {}'.format(n_batch))
            self.global_steps = n_batch * max(0, epoch)
            for batch_id in range(n_batch):
                self.global_steps += 1
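                # When resuming, skip batches already covered by the loaded
                # checkpoint (cclb == 0 until the first trained batch).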
                if cclb == 0 and batch_id <= uf_model[1]:
                    continue
                else:
                    cclb += 1

                self.build_batch(batch_id)
                loss = self.build_pipelines()

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(params, self.args.grad_clip)
                optimizer.step()

                end_time = time.time()
                if batch_id % self.args.checkpoint == 0:
                    for model_name in self.train_models:
                        fmodel = open(
                            os.path.join(
                                out_dir, model_name + '_' + str(epoch) + '_' +
                                str(batch_id) + '.model'), 'wb')
                        torch.save(self.train_models[model_name].state_dict(),
                                   fmodel)
                        fmodel.close()
                if batch_id % 1 == 0:
                    end_time = time.time()
                    print('epoch={}, batch={}, loss={}, time_elapsed={}s={}h'.
                          format(epoch, batch_id,
                                 np.round(float(loss.data.cpu().numpy()), 4),
                                 np.round(end_time - start_time, 2),
                                 np.round((end_time - start_time) / 3600.0,
                                          4)))
                    self.print_info_train()
                del loss

            for model_name in self.train_models:
                fmodel = open(
                    os.path.join(
                        out_dir, model_name + '_' + str(epoch) + '_' +
                        str(batch_id) + '.model'), 'wb')
                torch.save(self.train_models[model_name].state_dict(), fmodel)
                fmodel.close()
Example #4
    def test(self):
        '''
        Testing. Do not override.
        '''
        self.args.batch_size = 1

        self.build_vocabulary()
        self.build_models()
        pprint(self.base_models)
        pprint(self.train_models)
        if len(self.base_models) > 0:
            self.init_base_model_params()

        _nbatch = create_batch_file(path_data=self.args.data_dir,
                                    path_work=os.path.join(
                                        '..', 'nats_results'),
                                    is_shuffle=False,
                                    fkey_=self.args.task,
                                    file_=self.args.file_test,
                                    batch_size=self.args.batch_size)
        print('The number of samples (test): {}'.format(_nbatch))

        for model_name in self.base_models:
            self.base_models[model_name].eval()
        for model_name in self.train_models:
            self.train_models[model_name].eval()
        with torch.no_grad():
            if self.args.use_optimal_model:
                model_valid_file = os.path.join('..', 'nats_results',
                                                'model_validate.txt')
                fp = open(model_valid_file, 'r')
                for line in fp:
                    arr = re.split(r'\s', line[:-1])
                    model_optimal_key = ''.join(
                        ['_', arr[1], '_', arr[2], '.model'])
                    break
                fp.close()
            else:
                arr = re.split(r'\D', self.args.model_optimal_key)
                model_optimal_key = ''.join(
                    ['_', arr[0], '_', arr[1], '.model'])
            print("Using *{} for decoding.".format(model_optimal_key))

            for model_name in self.train_models:
                model_optimal_file = os.path.join(
                    '..', 'nats_results', model_name + model_optimal_key)
                self.train_models[model_name].load_state_dict(
                    torch.load(model_optimal_file,
                               map_location=lambda storage, loc: storage))

            start_time = time.time()
            output_file = os.path.join('..', 'nats_results',
                                       self.args.file_output)
            data_check = []
            if not hasattr(self.args, 'continue_decoding'):
                self.args.continue_decoding = True
            if os.path.exists(output_file) and self.args.continue_decoding:
                fchk = open(output_file, 'r')
                for line in fchk:
                    data_check.append(line)
                fchk.close()
                data_check = data_check[:-1]
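                # Drop the last line: it may be a partially written record
                # from an interrupted decoding run.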
                fchk = open(output_file, 'w')
                for line in data_check:
                    fchk.write(line)
                fchk.close()
            else:
                fout = open(output_file, 'w')
                fout.close()
            try:
                fout = open(output_file, 'a')
            except:
                fout = open(output_file, 'w')
            for batch_id in range(_nbatch):
                if batch_id < len(data_check):
                    continue
                self.build_batch(batch_id)
                self.test_worker()
                json.dump(self.test_data, fout)
                fout.write('\n')

                end_time = time.time()
                show_progress(batch_id, _nbatch,
                              str((end_time - start_time) / 3600)[:8] + "h")
            fout.close()
            print()
Example #5
    def validate(self):
        '''
        Validation. Do not override.
        '''
        self.build_vocabulary()
        self.build_models()
        pprint(self.base_models)
        pprint(self.train_models)
        if len(self.base_models) > 0:
            self.init_base_model_params()

        best_arr = []
        val_file = os.path.join('..', 'nats_results', 'model_validate.txt')
        if os.path.exists(val_file):
            fp = open(val_file, 'r')
            for line in fp:
                arr = re.split(r'\s', line[:-1])
                best_arr.append(
                    [arr[0], arr[1], arr[2],
                     float(arr[3]),
                     float(arr[4])])
            fp.close()

        for model_name in self.base_models:
            self.base_models[model_name].eval()
        for model_name in self.train_models:
            self.train_models[model_name].eval()
        with torch.no_grad():
            while True:
                model_para_files = glob.glob(
                    os.path.join(
                        '..', 'nats_results',
                        sorted(list(self.train_models))[0] + '*.model'))
                for j in range(len(model_para_files)):
                    arr = re.split(r'\_|\.', model_para_files[j])
                    arr = [int(arr[-3]), int(arr[-2]), model_para_files[j]]
                    model_para_files[j] = arr
                model_para_files = sorted(model_para_files)

                for fl_ in model_para_files:
                    best_model = {itm[0]: itm[3] for itm in best_arr}
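                    # Skip checkpoints that already have a validation record.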
                    if fl_[-1] in best_model:
                        continue
                    print('Validate *_{}_{}.model'.format(fl_[0], fl_[1]))

                    losses = []
                    start_time = time.time()
                    if os.path.exists(fl_[-1]):
                        time.sleep(3)
                        try:
                            for model_name in self.train_models:
                                fl_tmp = os.path.join(
                                    '..', 'nats_results', model_name + '_' +
                                    str(fl_[0]) + '_' + str(fl_[1]) + '.model')
                                self.train_models[model_name].load_state_dict(
                                    torch.load(fl_tmp,
                                               map_location=lambda storage,
                                               loc: storage))
                        except:
                            continue
                    else:
                        continue
                    val_batch = create_batch_file(
                        path_data=self.args.data_dir,
                        path_work=os.path.join('..', 'nats_results'),
                        is_shuffle=True,
                        fkey_=self.args.task,
                        file_=self.args.file_val,
                        batch_size=self.args.batch_size)
                    print('The number of batches (val): {}'.format(val_batch))
                    if self.args.val_num_batch > val_batch:
                        self.args.val_num_batch = val_batch
                    for batch_id in range(self.args.val_num_batch):

                        self.build_batch(batch_id)
                        loss = self.build_pipelines()

                        losses.append(loss.data.cpu().numpy())
                        show_progress(batch_id + 1, self.args.val_num_batch)
                    print()
                    losses = np.array(losses)
                    end_time = time.time()
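                    # Optionally smooth the validation loss with an exponential
                    # moving average (decay 0.9) across checkpoints.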
                    if self.args.use_move_avg:
                        try:
                            losses_out = 0.9*losses_out + \
                                0.1*np.average(losses)
                        except:
                            losses_out = np.average(losses)
                    else:
                        losses_out = np.average(losses)
                    best_arr.append([
                        fl_[2], fl_[0], fl_[1], losses_out,
                        end_time - start_time
                    ])
                    best_arr = sorted(best_arr, key=lambda bb: bb[3])
                    if best_arr[0][0] == fl_[2]:
                        out_dir = os.path.join('..', 'nats_results', 'model')
                        try:
                            shutil.rmtree(out_dir)
                        except:
                            pass
                        os.mkdir(out_dir)
                        for model_name in self.base_models:
                            fmodel = open(
                                os.path.join(out_dir, model_name + '.model'),
                                'wb')
                            torch.save(
                                self.base_models[model_name].state_dict(),
                                fmodel)
                            fmodel.close()
                        for model_name in self.train_models:
                            fmodel = open(
                                os.path.join(out_dir, model_name + '.model'),
                                'wb')
                            torch.save(
                                self.train_models[model_name].state_dict(),
                                fmodel)
                            fmodel.close()
                        try:
                            shutil.copy2(
                                os.path.join(self.args.data_dir,
                                             self.args.file_vocab), out_dir)
                        except:
                            pass
                    for itm in best_arr[:self.args.nbestmodel]:
                        print('model={}_{}, loss={}, time={}'.format(
                            itm[1], itm[2], np.round(float(itm[3]), 4),
                            np.round(float(itm[4]), 4)))

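                    # Delete checkpoint files that fall outside the n-best
                    # list; checkpoints saved at batch 0 are kept.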
                    for itm in best_arr[self.args.nbestmodel:]:
                        tarr = re.split(r'_|\.', itm[0])
                        if tarr[-2] == '0':
                            continue
                        if os.path.exists(itm[0]):
                            for model_name in self.train_models:
                                fl_tmp = os.path.join(
                                    '..', 'nats_results', model_name + '_' +
                                    str(itm[1]) + '_' + str(itm[2]) + '.model')
                                os.unlink(fl_tmp)
                    fout = open(val_file, 'w')
                    for itm in best_arr:
                        if len(itm) == 0:
                            continue
                        fout.write(' '.join([
                            itm[0],
                            str(itm[1]),
                            str(itm[2]),
                            str(itm[3]),
                            str(itm[4])
                        ]) + '\n')
                    fout.close()
Example #6
    def train(self):
        '''
        Training. Do not override.
        '''
        self.build_vocabulary()
        self.build_models()
        pprint(self.train_models)
        # Collect the parameters of all trainable models into a single list.
        params = []
        for model_name in self.train_models:
            params += list(self.train_models[model_name].parameters())
        print('Total number of parameters: {}.'.format(
            sum([para.numel() for para in params])))
        # define optimizer
        optimizer = self.build_optimizer(params)
        # load checkpoint
        uf_epoch = 0
        out_dir = os.path.join('..', 'nats_results')
        if not os.path.exists(out_dir):
            os.mkdir(out_dir)
        if self.args.continue_training:
            model_para_files = glob.glob(os.path.join(out_dir, '*.model'))
            if len(model_para_files) > 0:
                for fl_ in model_para_files:
                    arr = re.split(r'\/', fl_)[-1]
                    arr = re.split(r'\_|\.', arr)
                    epoch_idx = int(arr[-2])
                    if epoch_idx > uf_epoch:
                        uf_epoch = epoch_idx
                try:
                    print("Try *_{}.model".format(uf_epoch))
                    for model_name in self.train_models:
                        fl_ = os.path.join(
                            out_dir,
                            '{}_{}.model'.format(model_name, uf_epoch))
                        self.train_models[model_name].load_state_dict(
                            torch.load(
                                fl_,
                                map_location=lambda storage, loc: storage))
                except:
                    uf_epoch -= 1
                    if uf_epoch == -1:
                        uf_epoch = 0
                    else:
                        print("Try *_{}.model".format(uf_epoch))
                        for model_name in self.train_models:
                            fl_ = os.path.join(
                                out_dir,
                                '{}_{}.model'.format(model_name, uf_epoch))
                            self.train_models[model_name].load_state_dict(
                                torch.load(
                                    fl_,
                                    map_location=lambda storage, loc: storage))

                print('Continue training with *_{}.model'.format(uf_epoch))
        else:
            shutil.rmtree(out_dir)
            os.mkdir(out_dir)
        # train models
        fout = open('../nats_results/args.pickled', 'wb')
        pickle.dump(self.args, fout)
        fout.close()
        start_time = time.time()
        cclb = 0
        if uf_epoch < 0:
            uf_epoch = 0
        for epoch in range(uf_epoch, self.args.n_epoch):
            n_batch = create_batch_file(path_data=self.args.data_dir,
                                        path_work=os.path.join(
                                            '..', 'nats_results'),
                                        is_shuffle=True,
                                        fkey_=self.args.task,
                                        file_=self.args.file_train,
                                        batch_size=self.args.batch_size,
                                        is_lower=self.args.is_lower)
            print('The number of batches: {}'.format(n_batch))
            for batch_id in range(n_batch):

                self.build_batch(batch_id)
                loss = self.build_pipelines()

                if loss != loss:  # NaN check: NaN is the only value != itself
                    raise ValueError('Loss is NaN; stopping training.')

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(params, self.args.grad_clip)
                optimizer.step()

                end_time = time.time()
                if batch_id % self.args.checkpoint == 0:
                    for model_name in self.train_models:
                        fmodel = open(
                            os.path.join(
                                out_dir, model_name + '_' + str(epoch) + '_' +
                                str(batch_id) + '.model'), 'wb')
                        torch.save(self.train_models[model_name].state_dict(),
                                   fmodel)
                        fmodel.close()
                    if not os.path.exists(os.path.join(out_dir, 'model')):
                        os.mkdir(os.path.join(out_dir, 'model'))
                    for model_name in self.train_models:
                        fmodel = open(
                            os.path.join(out_dir, 'model',
                                         model_name + '.model'), 'wb')
                        torch.save(self.train_models[model_name].state_dict(),
                                   fmodel)
                        fmodel.close()
                if batch_id % 1 == 0:
                    end_time = time.time()
                    print('epoch={}, batch={}, loss={}, time_elapsed={}s={}h'.
                          format(epoch, batch_id,
                                 np.round(float(loss.data.cpu().numpy()), 4),
                                 np.round(end_time - start_time, 2),
                                 np.round((end_time - start_time) / 3600.0,
                                          4)))
                del loss
Example #7
    def validate(self):
        '''
        Validation. Do not override.
        '''
        self.build_vocabulary()
        self.build_models()
        pprint(self.base_models)
        pprint(self.train_models)
        self.init_base_model_params()

        for model_name in self.base_models:
            self.base_models[model_name].eval()
        for model_name in self.train_models:
            self.train_models[model_name].eval()
        with torch.no_grad():
            model_para_files = glob.glob(
                os.path.join('../nats_results',
                             sorted(list(self.train_models))[0] + '*.model'))
            for j in range(len(model_para_files)):
                arr = re.split(r'\_|\.', model_para_files[j])
                arr = [int(arr[-3]), int(arr[-2]), model_para_files[j]]
                model_para_files[j] = arr
            model_para_files = sorted(model_para_files)
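            # Validate checkpoints in ascending (epoch, batch) order.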

            if not os.path.exists(self.args.optimal_model_dir):
                os.mkdir(self.args.optimal_model_dir)
            best_f1 = 0
            for fl_ in model_para_files:
                print('Validate *_{}_{}.model'.format(fl_[0], fl_[1]))
                try:
                    for model_name in self.train_models:
                        fl_tmp = os.path.join(
                            '../nats_results', model_name + '_' + str(fl_[0]) +
                            '_' + str(fl_[1]) + '.model')
                        self.train_models[model_name].load_state_dict(
                            torch.load(
                                fl_tmp,
                                map_location=lambda storage, loc: storage))
                except:
                    print('Models could not be loaded.')
                    continue
                val_batch = create_batch_file(path_data=self.args.data_dir,
                                              path_work='../nats_results',
                                              is_shuffle=False,
                                              fkey_=self.args.task,
                                              file_=self.args.file_dev,
                                              batch_size=self.args.batch_size)
                print('The number of batches (Dev): {}'.format(val_batch))

                val_results = []
                for batch_id in range(val_batch):
                    start_time = time.time()
                    self.build_batch(batch_id)
                    self.test_worker()
                    val_results += self.test_data
                    self.test_data = []
                    end_time = time.time()
                    show_progress(batch_id + 1, val_batch,
                                  str((end_time - start_time))[:8] + "s")
                print()

                f1 = self.evaluate_worker(val_results)
                print('Best f1: {}; Current f1: {}.'.format(best_f1, f1))

                if f1 > best_f1:
                    for model_name in self.train_models:
                        fmodel = open(
                            os.path.join(self.args.optimal_model_dir,
                                         '{}.model'.format(model_name)), 'wb')
                        torch.save(self.train_models[model_name].state_dict(),
                                   fmodel)
                        fmodel.close()
                    best_f1 = f1
Example #8
    def test(self):
        '''
        Testing. Do not override.
        '''
        self.build_vocabulary()
        self.build_models()
        pprint(self.base_models)
        pprint(self.train_models)
        if len(self.base_models) > 0:
            self.init_base_model_params()

        _nbatch = create_batch_file(path_data=self.args.data_dir,
                                    path_work=os.path.join(
                                        '..', 'nats_results'),
                                    is_shuffle=False,
                                    fkey_=self.args.task,
                                    file_=self.args.file_test,
                                    batch_size=self.args.batch_size)
        print('The number of batches (test): {}'.format(_nbatch))

        for model_name in self.base_models:
            self.base_models[model_name].eval()
        for model_name in self.train_models:
            self.train_models[model_name].eval()
        with torch.no_grad():
            if self.args.use_optimal_model:
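                # Load the best checkpoints that validate() saved under
                # args.optimal_model_dir (one file per trainable model).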
                for model_name in self.train_models:
                    fl_ = os.path.join(self.args.optimal_model_dir,
                                       '{}.model'.format(model_name))
                    self.train_models[model_name].load_state_dict(
                        torch.load(fl_,
                                   map_location=lambda storage, loc: storage))
            else:
                arr = re.split(r'\D', self.args.model_optimal_key)
                model_optimal_key = ''.join(
                    ['_', arr[0], '_', arr[1], '.model'])
                print("Using *{} for decoding.".format(model_optimal_key))

                for model_name in self.train_models:
                    model_optimal_file = os.path.join(
                        '../nats_results', model_name + model_optimal_key)
                    self.train_models[model_name].load_state_dict(
                        torch.load(model_optimal_file,
                                   map_location=lambda storage, loc: storage))

            start_time = time.time()
            output_file = os.path.join('../nats_results',
                                       self.args.file_output)

            fout = open(output_file, 'w')
            self.aspect_worker()
            for batch_id in range(_nbatch):
                self.build_batch(batch_id)
                self.test_worker()
                for itm in self.test_data:
                    json.dump(itm, fout)
                    fout.write('\n')
                self.test_data = []
                end_time = time.time()
                show_progress(batch_id + 1, _nbatch,
                              str((end_time - start_time) / 3600)[:8] + "h")
            fout.close()
            print()