Пример #1
0
    def test(self):
        """Run domain classification over every test iterator and write the
        predicted domain names, one per line, to the matching output path
        from config['Test']['output_path'].
        """
        for test_iterator, test_output_path in \
                zip(self.test_iterators, self.config['Test']['output_path']):

            create_path(get_path_prefix(test_output_path))
            self.model.eval()
            with torch.no_grad():

                hypotheses = []
                # Inverse map (class index -> domain name) built once instead
                # of scanning self.domain_dict for every single prediction.
                # NOTE(review): assumes the ids in self.domain_dict are
                # unique, which the original linear scan also implies.
                index_to_domain = {
                    index: domain
                    for domain, index in self.domain_dict.items()
                }

                with tqdm(test_iterator) as bar:
                    bar.set_description("inference")

                    for batch in bar:
                        # [batch size, max len]
                        new_batch = SrcTestBatch(
                            batch.src, self.vocab['src'].stoi['<pad>'])

                        result = self.model.classify_forward(
                            new_batch.src,
                            new_batch.src_mask,
                            None,
                            train=False)
                        logits = result['emb_classify_logits']
                        # argmax over the class dimension
                        prediction = torch.max(logits, dim=-1)[1]
                        for i in range(0, prediction.size(0)):
                            predict = prediction[i].item()
                            # Unknown ids are silently skipped, matching the
                            # original loop's behaviour.
                            if predict in index_to_domain:
                                hypotheses.append(index_to_domain[predict])

                with open(test_output_path, 'w', encoding='utf-8') as f:
                    f.write('\n'.join(hypotheses))
Пример #2
0
def main():
    """Copy trained adapter parameters between domains.

    Usage: ``script <config.yaml>``. Builds the model, loads a trained
    state dict, and for every src/trg pair in
    ``config['copy_adapter']['copy_dict']`` overwrites each parameter whose
    name contains the trg domain string with the value of the matching
    src-domain parameter, then saves the resulting state dict.
    """
    # fixed seeds for reproducibility
    torch.manual_seed(3333)
    np.random.seed(3333)

    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # safe_load: parses plain YAML only, without constructing arbitrary
        # Python objects (yaml.load without a Loader is deprecated/unsafe).
        config = yaml.safe_load(config_file)

    # ================================================================================== #
    # set the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # set the data fields
    mt_data_loader = MTDataLoader(config)
    mt_data_loader.build_vocab()
    vocab = mt_data_loader.vocab

    model_builder = ModelBuilder()

    model = model_builder.build_model(
        model_name=config['copy_adapter']['model_name'],
        model_config=config['Model'],
        vocab=vocab,
        device=device,
        load_pretrained=False,
        pretrain_path=None)
    model_dict = model.state_dict()

    # load from trained model
    load_model_dict = torch.load(config['copy_adapter']['load_model_path'])
    model_dict.update(load_model_dict)

    # copy adapter parameters according to dict
    for copy_item in config['copy_adapter']['copy_dict']:
        src_adapter_domain = copy_item['src']
        trg_adapter_domain = copy_item['trg']

        for parameter_name in model_dict.keys():
            # substring match: every parameter belonging to the trg adapter
            if trg_adapter_domain in parameter_name:
                src_adapter_parameter_name = parameter_name.replace(
                    trg_adapter_domain, src_adapter_domain)
                # copy value from the src-domain twin parameter
                model_dict[parameter_name] = model_dict[
                    src_adapter_parameter_name]
                print(parameter_name, src_adapter_parameter_name)

    model.load_state_dict(model_dict)
    create_path(get_path_prefix(config['copy_adapter']['save_path']))
    torch.save(model.state_dict(), config['copy_adapter']['save_path'])
Пример #3
0
def main():
    """Merge several saved state dicts into one model and save the result.

    Usage: ``script <config.yaml>``. Loads every state dict listed in
    ``config['load_multiple_model']['load_path']``, checks them for
    inconsistencies, applies them in order (later dicts win on overlapping
    keys) and saves the merged model.
    """
    # fixed seeds for reproducibility
    torch.manual_seed(3333)
    np.random.seed(3333)

    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # safe_load: parses plain YAML only, without constructing arbitrary
        # Python objects (yaml.load without a Loader is deprecated/unsafe).
        config = yaml.safe_load(config_file)

    # ================================================================================== #
    # set the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # set the data fields
    mt_data_loader = MTDataLoader(config)
    mt_data_loader.build_vocab()
    vocab = mt_data_loader.vocab

    model_builder = ModelBuilder()

    model = model_builder.build_model(model_name=config['load_multiple_model']['model_name'],
                                      model_config=config['Model'],
                                      vocab=vocab,
                                      device=device,
                                      load_pretrained=False,
                                      pretrain_path=None)
    model_dict = model.state_dict()

    load_model_dicts = [torch.load(model_path) for model_path in config['load_multiple_model']['load_path']]
    check_inconsistent(load_model_dicts)

    # later dicts overwrite earlier ones on shared keys
    for load_model_dict in load_model_dicts:
        model_dict.update(load_model_dict)

    model.load_state_dict(model_dict)
    create_path(get_path_prefix(config['load_multiple_model']['save_path']))
    torch.save(model.state_dict(), config['load_multiple_model']['save_path'])
Пример #4
0
def main():
    """Score test essays per essay set with previously trained models.

    Usage: ``script <config.yaml>``. For every essay set flagged in
    ``config['need_test']``: loads the set's prompt and essays, tokenizes
    them with BERT, restores the saved model for that set, predicts scores
    batch by batch and writes the rounded predictions to a TSV file.
    """
    print('cuda is available: ', torch.cuda.is_available())
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # safe_load: parses plain YAML only, without constructing arbitrary
        # Python objects (yaml.load without a Loader is deprecated/unsafe).
        config = yaml.safe_load(config_file)

    score_ranges = config['score_ranges']

    bert_model_path = config['bert_model_path']
    tokenizer = BertTokenizer.from_pretrained(bert_model_path)

    print('read dataset')

    test_dataset_file = config['test_dataset_file']
    test_dataset = pd.read_csv(test_dataset_file, delimiter='\t', usecols=['essay_set', 'essay_id', 'essay'])

    # NOTE(review): set iteration order is not a language guarantee, yet
    # set_id below indexes parallel config lists — this relies on the essay
    # set ids enumerating in a stable (ascending) order. Consider
    # sorted(set(...)) if the mapping ever misaligns.
    essay_set = set(test_dataset['essay_set'])

    for set_id, essay_set_id in enumerate(essay_set):
        if config['need_test'][set_id] is False:
            continue

        print('begin set ', essay_set_id, 'processing')

        # prompt text is processed once per set (split into segments of at
        # most 300 tokens)
        with open(config['essay_prompt'][set_id]) as f:
            prompt = [f.read()]
            prompt_process = process_data(prompt, tokenizer, True, 300)
            prompt_inputs = prompt_process['inputs']
            prompt_sent_count = prompt_process['sent_count']
            prompt_sent_length = prompt_process['sent_length']
            prompt_mask = prompt_process['attention_mask']

        test_dataset_in_set = test_dataset[test_dataset.essay_set == essay_set_id]

        test_essays = test_dataset_in_set.essay.values

        test_dataset_process = process_data(test_essays, tokenizer, config['split_segment'], config['segment_max_len'])

        ids = test_dataset_in_set.essay_id.values
        test_features = get_feature_from_test_ids(ids, config['test_feature'])

        test_inputs = test_dataset_process['inputs']
        test_sent_count = test_dataset_process['sent_count']
        test_sent_length = test_dataset_process['sent_length']
        test_masks = test_dataset_process['attention_mask']

        # move everything to the target device once, before batching
        test_inputs = torch.tensor(test_inputs).to(device)
        # test_labels = torch.tensor(test_labels).to(device)
        test_masks = torch.tensor(test_masks).to(device)
        test_sent_count = torch.tensor(test_sent_count).to(device)
        test_sent_length = torch.tensor(test_sent_length).to(device)
        test_features = torch.tensor(test_features).to(device)

        prompt_inputs = torch.tensor(prompt_inputs).to(device)
        prompt_mask = torch.tensor(prompt_mask).to(device)
        prompt_sent_count = torch.tensor(prompt_sent_count).to(device)
        prompt_sent_length = torch.tensor(prompt_sent_length).to(device)

        # SequentialSampler keeps predictions aligned with `ids`
        test_data = TensorDataset(test_inputs, test_masks,
                                  test_sent_count, test_sent_length, test_features)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data, sampler=test_sampler,
                                     batch_size=config['batch_size'][set_id])

        print('begin set ', essay_set_id, 'setup model')
        model = make_model(config, device, set_id)

        # restore the trained weights for this essay set
        model.load_state_dict(torch.load(config['model_save_path'][set_id]))

        print('begin set ', essay_set_id, 'begin test')

        # evaluation
        model.eval()
        with torch.no_grad():
            dev_predict = []
            for batch in test_dataloader:
                batch_inputs, batch_masks, batch_sent_count, batch_sent_length, batch_feature = batch

                # classifier models take no score range; regression models do
                if 'classifier' in config['model']:
                    result = model(batch_inputs, batch_masks, batch_sent_count, batch_sent_length,
                                   prompt_inputs, prompt_mask, prompt_sent_count, prompt_sent_length, batch_feature,
                                   None)
                else:
                    result = model(batch_inputs, batch_masks, batch_sent_count, batch_sent_length,
                                   prompt_inputs, prompt_mask, prompt_sent_count, prompt_sent_length,
                                   score_ranges[set_id][0], score_ranges[set_id][1], batch_feature,
                                   None)
                prediction = result['prediction']
                dev_predict.append(prediction)

        dev_predict = torch.cat(dev_predict, dim=0)

        # one output row per essay: rounded score + identifiers
        samples = []
        for i in range(0, len(ids)):
            samples.append({})
            samples[i]['domain1_score'] = np.around(dev_predict[i].item())
            samples[i]['essay_id'] = ids[i]
            samples[i]['essay_set'] = essay_set_id
        create_path(get_path_prefix(config['test_output_path'][set_id]))
        save_to_tsv(samples, config['test_output_path'][set_id])

        # free GPU memory before the next essay set
        del model
        del test_inputs
        del test_masks
        torch.cuda.empty_cache()
Пример #5
0
def _generate_adapter_params(layers, src_adapter_name, trg_adapter_name):
    """For every layer in `layers`, run the src adapter's generate_param()
    and write the produced w_1/w_2 weights and biases into the trg adapter
    of the same layer."""
    for i, layer in enumerate(layers):
        print(i)
        adapter_layers = layer.adapters.adapter_layers
        w1_weight, w1_bias, w2_weight, w2_bias = adapter_layers[
            src_adapter_name].generate_param(adapter_layers)
        trg_adapter = adapter_layers[trg_adapter_name]
        # generate_param appears to return (in, out)-shaped weights while
        # nn.Linear stores (out, in), hence the transpose — TODO confirm
        # against generate_param's implementation
        trg_adapter.w_1.weight.data = w1_weight.transpose(0, 1)
        trg_adapter.w_1.bias.data = w1_bias
        trg_adapter.w_2.weight.data = w2_weight.transpose(0, 1)
        trg_adapter.w_2.bias.data = w2_bias


def main():
    """Materialize generated adapters into plain adapters and save them.

    Usage: ``script <config.yaml>``. For every src/trg pair in
    ``config['generate_adapter']['generate_dict']``: the src ("generate")
    adapter produces concrete w_1/w_2 parameters that are written into the
    trg adapter of every encoder and decoder layer; sublayer_connection
    parameters are copied by name; finally all 'generate' parameters are
    dropped from the saved state dict.
    """
    # fixed seeds for reproducibility
    torch.manual_seed(3333)
    np.random.seed(3333)

    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # safe_load: parses plain YAML only, without constructing arbitrary
        # Python objects (yaml.load without a Loader is deprecated/unsafe).
        config = yaml.safe_load(config_file)

    # ================================================================================== #
    # set the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # set the data fields
    mt_data_loader = MTDataLoader(config)
    mt_data_loader.build_vocab()
    vocab = mt_data_loader.vocab

    model_builder = ModelBuilder()

    model = model_builder.build_model(
        model_name=config['generate_adapter']['model_name'],
        model_config=config['Model'],
        vocab=vocab,
        device=device,
        load_pretrained=False,
        pretrain_path=None)
    model_dict = model.state_dict()

    # load from trained model
    load_model_dict = torch.load(config['generate_adapter']['load_model_path'])
    model_dict.update(load_model_dict)
    model.load_state_dict(model_dict)

    # from xx-generate to xx
    for generate_item in config['generate_adapter']['generate_dict']:
        src_adapter_name = generate_item['src']
        trg_adapter_name = generate_item['trg']

        # BUGFIX: the original iterated the decoder layers with the
        # *encoder's* layer count; each stack now uses its own length.
        _generate_adapter_params(model.encoder.layers, src_adapter_name,
                                 trg_adapter_name)
        _generate_adapter_params(model.decoder.layers, src_adapter_name,
                                 trg_adapter_name)

    create_path(get_path_prefix(config['generate_adapter']['save_path']))

    model_dict = model.state_dict()

    # copy the sublayer_connection (normalization) parameters of the src
    # adapter onto the trg adapter by name, skipping generator parameters
    for generate_item in config['generate_adapter']['generate_dict']:
        src_adapter_name = generate_item['src']
        trg_adapter_name = generate_item['trg']
        for parameter_name in model_dict.keys():
            if trg_adapter_name in parameter_name and 'sublayer_connection' in parameter_name and 'generate' not in parameter_name:

                src_adapter_parameter_name = parameter_name.replace(
                    trg_adapter_name, src_adapter_name)
                # copy value
                print(parameter_name, src_adapter_parameter_name)
                model_dict[parameter_name] = model_dict[
                    src_adapter_parameter_name]

    # drop every generator parameter from the saved dict
    model_dict = {k: v for k, v in model_dict.items() if 'generate' not in k}
    torch.save(model_dict, config['generate_adapter']['save_path'])
Пример #6
0
    def decoding(self):
        """Translate every test set, detokenize the output with an external
        script and print corpus BLEU plus a few hypothesis/reference pairs.

        The target domain is either taken from config['Test']['target_domain']
        or, when that is None, predicted per sentence by the model's domain
        classifier (low-confidence non-domain-1 predictions are forced to
        domain 1).
        """
        for test_iterator, test_output_path, test_ref_file_path in \
                zip(self.test_iterators, self.config['Test']['output_path'], self.test_ref_file_paths):

            create_path(get_path_prefix(test_output_path))
            self.model.eval()
            with torch.no_grad():

                hypotheses = []
                with open(test_ref_file_path, 'r', encoding='utf-8') as f:
                    references = f.read().splitlines()

                with tqdm(test_iterator) as bar:
                    bar.set_description("inference")

                    for batch in bar:
                        # [batch size, max len]

                        if self.config['Test']['target_domain'] is None:
                            # no fixed target domain: let the classifier pick
                            # one per sentence
                            new_batch = SrcTestBatch(
                                batch.src, self.vocab['src'].stoi['<pad>'])
                            result = self.model.classify_forward(
                                new_batch.src, new_batch.src_mask)
                            logits = result['emb_classify_logits']
                            logits = torch.softmax(logits, dim=-1)
                            target_domain_prob, target_domain = torch.max(
                                logits, -1)
                            # fall back to domain 1 when the classifier is
                            # not confident (prob < 0.90) and did not already
                            # pick domain 1
                            for i in range(0, target_domain_prob.size(0)):
                                if target_domain_prob[
                                        i] < 0.90 and target_domain[i].item(
                                        ) != 1:
                                    print('change')
                                    target_domain[i] = 1
                        else:
                            target_domain = self.config['Test'][
                                'target_domain']

                        search_results = self.decoding_step(
                            batch, target_domain)
                        prediction = search_results['prediction']

                        for i in range(prediction.size(0)):
                            hypotheses.append(
                                tensor2str(prediction[i], self.vocab['trg']))

                if self.config['Vocab']['use_bpe']:
                    # undo BPE segmentation before detokenization
                    hypotheses = [de_bpe(sent) for sent in hypotheses]

                # Write the raw (tokenized) output first; the external script
                # then detokenizes it into the final output file.
                test_initial_output_path = test_output_path + '.initial'
                with open(test_initial_output_path, 'w',
                          encoding='utf-8') as f:
                    f.write("\n".join(hypotheses))
                # BUGFIX: the detokenize command previously ran *inside* the
                # `with` block, i.e. before the file handle was closed, so
                # buffered content could be missing from disk when the
                # external script read it. It now runs after the file is
                # closed and flushed.
                # NOTE(review): os.system builds a shell command by string
                # concatenation — paths containing spaces or shell
                # metacharacters would break; consider subprocess.run with an
                # argument list.
                os.system(self.detokenize_script + ' -l ' +
                          self.target_language + ' < ' +
                          test_initial_output_path + ' > ' +
                          test_output_path)
                with open(test_output_path, 'r', encoding='utf-8') as f:
                    hypotheses = f.read().splitlines()

                bleu_score = sacrebleu.corpus_bleu(
                    hypotheses, [references],
                    tokenize=self.config['Test']['tokenize'])

                print('some examples')
                # assumes at least 3 test sentences — TODO confirm
                for i in range(3):
                    print("hyp: ", hypotheses[i])
                    print("ref: ", references[i])

                print()
                print('bleu scores: ', bleu_score)
                print()
Пример #7
0
    def __init__(self,
                 model,
                 criterion,
                 vocab,
                 optimizer,
                 lr_scheduler,
                 train_iterators,
                 train_iterators_domain_list,
                 validation_iterators,
                 validation_iterators_domain_list,
                 domain_dict,
                 optimizer_config,
                 train_config,
                 validation_config,
                 record_config,
                 device,
                 ):
        """Store the training components and unpack the optimizer / train /
        validation / record config dicts into attributes.

        Side effects: creates the record directories on disk and opens a
        TensorBoard SummaryWriter under
        ``record_config['training_record_path'] + '/visualization'``.

        Args:
            model: the model to train.
            criterion: training loss criterion.
            vocab: vocabulary object(s) used by the model.
            optimizer: optimizer instance.
            lr_scheduler: learning-rate scheduler.
            train_iterators / validation_iterators: per-domain data iterators.
            train_iterators_domain_list / validation_iterators_domain_list:
                domain labels aligned with the iterator lists — presumably
                parallel to them; verify against the caller.
            domain_dict: mapping from domain name to domain id.
            optimizer_config / train_config / validation_config /
                record_config: sub-dicts of the YAML config; the exact keys
                read are visible below.
            device: torch device used for training.
        """

        self.device = device
        self.model = model
        self.vocab = vocab
        self.criterion = criterion

        # iterators
        self.train_iterators = train_iterators
        self.validation_iterators = validation_iterators
        self.train_iterators_domain_list = train_iterators_domain_list
        self.validation_iterators_domain_list = validation_iterators_domain_list
        self.domain_dict = domain_dict

        # optimizer
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.grad_clip = optimizer_config['grad_clip']

        # train process
        self.epoch_num = train_config['epoch_num']
        self.update_batch_count = train_config['update_batch_count']
        self.use_multiple_gpu = train_config['use_multiple_gpu']
        self.current_epoch = 0
        self.current_step = 0

        # loss_validation:
        self.start_loss_validation_on_steps = validation_config['loss_validation']['start_on_steps']
        self.loss_validation_frequency = validation_config['loss_validation']['frequency']

        # record
        # training record
        self.training_record_path = record_config['training_record_path']
        self.writer = SummaryWriter(self.training_record_path + '/visualization')
        self.average_train_loss = 0
        # sentinel "infinity": any real validation loss will be smaller
        self.best_validation_loss = 10000.

        # best model save
        self.model_record_path = record_config['model_record']['path']
        create_path(self.model_record_path + '/loss_best')
        self.best_loss_model_path = self.model_record_path + '/loss_best/model'
        self.best_loss_optimizer_path = self.model_record_path + '/loss_best/optimizer'
        self.best_loss_lr_scheduler_path = self.model_record_path + '/loss_best/lr_scheduler'

        self.save_loss_best = record_config['model_record']['best_model_save']['loss_best']
        self.best_model_save_optimizer = record_config['model_record']['best_model_save']['save_optimizer']
        self.best_model_save_lr_scheduler = record_config['model_record']['best_model_save']['save_lr_scheduler']

        # checkpoint save
        # rotating checkpoints: checkpoint0 .. checkpoint<N-1> directories,
        # each holding model / optimizer / lr_scheduler files
        self.save_checkpoint_start_on_steps = record_config['model_record']['last_checkpoint_save']['start_on_steps']
        self.save_checkpoint_frequency = record_config['model_record']['last_checkpoint_save']['frequency']
        self.checkpoint_num = record_config['model_record']['last_checkpoint_save']['save_checkpoint_count']
        # create_path(self.model_record_path + '/checkpoint')
        self.checkpoint_path = [self.model_record_path + '/checkpoint' + str(i) for i in range(0, self.checkpoint_num)]
        for checkpoint_path in self.checkpoint_path:
            create_path(checkpoint_path)
        self.checkpoint_model_path = [path + '/model' for path in self.checkpoint_path]
        self.checkpoint_optimizer_path = [path + '/optimizer' for path in self.checkpoint_path]
        self.checkpoint_lr_scheduler_path = [path + '/lr_scheduler' for path in self.checkpoint_path]

        self.checkpoint_save_optimizer = record_config['model_record']['last_checkpoint_save']['save_optimizer']
        self.checkpoint_save_lr_scheduler = record_config['model_record']['last_checkpoint_save']['save_lr_scheduler']

        # Option
        # running accumulators used during training
        self.step_total_loss = 0
        self.step_total_samples = 0
        self.batch_count = 0
        # index of the next rotating checkpoint slot to overwrite
        self.current_checkpoint_index = 0
Пример #8
0
def main():
    """Train a transformer-with-split-position model.

    Usage: ``script <config.yaml>``. Loads the datasets, builds model /
    criterion / optimizer / lr scheduler from the config, snapshots the
    config file into the record directory and runs the trainer.
    """
    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # safe_load: parses plain YAML only, without constructing arbitrary
        # Python objects (yaml.load without a Loader is deprecated/unsafe).
        config = yaml.safe_load(config_file)
        create_path(config['Record']['training_record_path'])

    # set random seed
    set_random_seed(config['Train']['random_seed'])

    # ================================================================================== #
    # set the device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # set the data fields dict['src': (name, field), 'trg': (name, field)]

    # load dataset
    print('load dataset ...')
    mt_data_loader = MTDataLoader(config)
    mt_data_loader.load_datasets(load_train=True,
                                 load_dev=True,
                                 load_test=False)
    mt_data_loader.build_vocab()
    mt_data_loader.build_iterators(device=device,
                                   build_train=True,
                                   build_dev=True,
                                   build_test=False)

    vocab = mt_data_loader.vocab
    train_iterators = mt_data_loader.train_iterators
    dev_iterators = mt_data_loader.dev_iterators
    dev_test_iterators = mt_data_loader.dev_test_iterators

    model_builder = ModelBuilder()
    model = model_builder.build_model(
        model_name='transformer_with_split_position',
        model_config=config['Model'],
        vocab=vocab,
        device=device,
        load_pretrained=config['Train']['load_exist_model'],
        pretrain_path=config['Train']['model_load_path'])
    criterion = model_builder.build_criterion(
        criterion_config=config['Criterion'], vocab=vocab)
    # make optimizer
    optimizer = model_builder.build_optimizer(
        parameters=model.parameters(),
        optimizer_config=config['Optimizer'],
        load_pretrained=config['Train']['load_optimizer'],
        pretrain_path=config['Train']['optimizer_path'])
    # make lr scheduler
    lr_scheduler = model_builder.build_lr_scheduler(
        optimizer=optimizer,
        lr_scheduler_config=config['Optimizer']['lr_scheduler'],
        load_pretrained=config['Train']['load_lr_scheduler'],
        pretrain_path=config['Train']['lr_scheduler_path'])

    # snapshot the config next to the training record;
    # shutil.copy avoids spawning a shell (the original used `os.system('cp ...')`)
    import shutil
    shutil.copy(config_file_path,
                config['Record']['training_record_path'] + '/model_config.txt')

    # parameters=filter(lambda p: p.requires_grad, model.parameters()))

    trainer = Split_Position_Trainer(
        model=model,
        criterion=criterion,
        vocab=vocab,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_iterators=train_iterators,
        validation_iterators=dev_iterators,
        validation_test_iterators=dev_test_iterators,
        optimizer_config=config['Optimizer'],
        train_config=config['Train'],
        validation_config=config['Validation'],
        record_config=config['Record'],
        device=device,
    )

    trainer.train()
Пример #9
0
def main():
    """Train a transformer language model.

    Usage: ``script <config.yaml>``. Loads the LM datasets, builds model /
    criteria / optimizer / lr scheduler from the config, snapshots the
    config file into the record directory and runs the trainer.
    """
    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # safe_load: parses plain YAML only, without constructing arbitrary
        # Python objects (yaml.load without a Loader is deprecated/unsafe).
        config = yaml.safe_load(config_file)
        create_path(config['Record']['training_record_path'])

    # set random seed
    set_random_seed(config['Train']['random_seed'])

    # ================================================================================== #
    # set the device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load dataset
    print('load dataset ...')
    lm_data_loader = LMDataLoader(config)
    lm_data_loader.load_datasets(load_train=True,
                                 load_dev=True,
                                 load_test=False)
    lm_data_loader.build_vocab()
    lm_data_loader.build_iterators(device=device,
                                   build_train=True,
                                   build_dev=True,
                                   build_test=False)

    # sanity check: print a few training examples
    for i in range(5):
        print(lm_data_loader.train_datasets[0].examples[i].text)

    vocab = lm_data_loader.vocab
    train_iterators = lm_data_loader.train_iterators
    dev_iterators = lm_data_loader.dev_iterators

    # make model
    model_builder = ModelBuilder()
    model = model_builder.build_model(
        model_name='transformer_language_model',
        model_config=config['Model'],
        vocab=vocab,
        device=device,
        load_pretrained=config['Train']['load_exist_model'],
        pretrain_path=config['Train']['model_load_path'])
    print('trained parameters: ')
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.shape)

    # make criterion
    # the label_smoothing of validation criterion is always set to 0
    criterion = model_builder.build_criterion(
        criterion_config=config['Criterion'], vocab=vocab)
    validation_criterion = model_builder.build_criterion(criterion_config={
        'name': 'kl_divergence',
        'label_smoothing': 0,
    },
                                                         vocab=vocab)

    # make optimizer
    optimizer = model_builder.build_optimizer(
        parameters=model.parameters(),
        optimizer_config=config['Optimizer'],
        load_pretrained=config['Train']['load_optimizer'],
        pretrain_path=config['Train']['optimizer_path'])
    # make lr scheduler
    lr_scheduler = model_builder.build_lr_scheduler(
        optimizer=optimizer,
        lr_scheduler_config=config['Optimizer']['lr_scheduler'],
        load_pretrained=config['Train']['load_lr_scheduler'],
        pretrain_path=config['Train']['lr_scheduler_path'])

    # snapshot the config next to the training record;
    # shutil.copy avoids spawning a shell (the original used `os.system('cp ...')`)
    import shutil
    shutil.copy(config_file_path,
                config['Record']['training_record_path'] + '/model_config.txt')

    # parameters=filter(lambda p: p.requires_grad, model.parameters()))

    trainer = Trainer(
        model=model,
        criterion=criterion,
        validation_criterion=validation_criterion,
        vocab=vocab,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_iterators=train_iterators,
        validation_iterators=dev_iterators,
        optimizer_config=config['Optimizer'],
        train_config=config['Train'],
        validation_config=config['Validation'],
        record_config=config['Record'],
        device=device,
    )

    trainer.train()
Пример #10
0
def main():
    """Train one prompt-aware, domain-adversarial essay scorer per essay set.

    Usage: ``python <script> <config.yaml>``.

    For every essay set whose ``config['need_training']`` flag is on, essays
    from all sets listed in ``config['used_set'][set_id]`` are pooled; each
    contributing set becomes one integer "domain" label for the adversarial
    head, and each pooled essay keeps its own set's score range.  The model
    is trained with a DANN-style gradient-reversal ``alpha`` schedule and the
    checkpoint with the lowest validation loss is saved to
    ``config['model_save_path'][set_id]``.
    """
    torch.manual_seed(0)

    print('cuda is available: ', torch.cuda.is_available())
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # yaml.load without an explicit Loader is deprecated since PyYAML 5.1
        # and raises TypeError on PyYAML >= 6; safe_load is the drop-in fix
        # for plain configuration files.
        config = yaml.safe_load(config_file)

    bert_model_path = config['bert_model_path']
    tokenizer = BertTokenizer.from_pretrained(bert_model_path)

    print('read dataset')
    train_dataset_file = config['train_dataset_file']
    train_dataset = pd.read_csv(
        train_dataset_file,
        delimiter='\t',
        usecols=['essay_set', 'essay_id', 'essay', 'domain1_score'])
    dev_dataset_file = config['dev_dataset_file']
    dev_dataset = pd.read_csv(
        dev_dataset_file,
        delimiter='\t',
        usecols=['essay_set', 'essay_id', 'essay', 'domain1_score'])

    essay_set = set(train_dataset['essay_set'])

    # TensorBoard writer shared by every essay set's train/dev curves.
    create_path(config['record_path'])
    writer = SummaryWriter(config['record_path'])

    for set_id, essay_set_id in enumerate(essay_set):
        if config['need_training'][set_id] is False:
            continue

        print('begin set ', essay_set_id, 'processing')

        # Pooled, multi-domain accumulators for this target set.
        train_essays = []
        validation_essays = []
        train_ids = []
        validation_ids = []
        train_features = []
        validation_features = []
        train_labels = []
        validation_labels = []
        train_prompt_essays = []
        validation_prompt_essays = []
        train_max_scores = []
        train_min_scores = []
        validation_max_scores = []
        validation_min_scores = []
        train_domain_label = []
        validation_domain_label = []

        # Each contributing essay set gets a distinct integer domain id.
        current_domain = 0
        for current_set_id, current_essay_set_id in enumerate(essay_set):
            if current_essay_set_id not in config['used_set'][set_id]:
                continue
            print('used set', current_essay_set_id)

            current_train_dataset_in_set = train_dataset[
                train_dataset.essay_set == current_essay_set_id]
            current_validation_dataset_int_set = dev_dataset[
                dev_dataset.essay_set == current_essay_set_id]

            current_train_essays = current_train_dataset_in_set.essay.values
            train_essays.extend(current_train_essays)
            current_validation_essays = current_validation_dataset_int_set.essay.values
            validation_essays.extend(current_validation_essays)

            current_train_ids = current_train_dataset_in_set.essay_id.values
            train_ids.extend(current_train_ids)
            current_validation_ids = current_validation_dataset_int_set.essay_id.values
            validation_ids.extend(current_validation_ids)

            # Hand-crafted per-essay features, looked up by essay id.
            current_train_features = get_feature_from_ids(
                current_train_ids, config['train_feature'])
            train_features.extend(current_train_features)
            current_validation_features = get_feature_from_ids(
                current_validation_ids, config['validation_feature'])
            validation_features.extend(current_validation_features)

            current_train_labels = current_train_dataset_in_set.domain1_score.values
            train_labels.extend(current_train_labels)
            current_validation_labels = current_validation_dataset_int_set.domain1_score.values
            validation_labels.extend(current_validation_labels)

            # Per-essay score bounds: every pooled essay carries the
            # [min, max] range of the set it came from.
            current_train_max_scores = [
                config['score_ranges'][current_set_id][1]
            ] * len(current_train_ids)
            current_train_min_scores = [
                config['score_ranges'][current_set_id][0]
            ] * len(current_train_ids)
            current_validation_max_scores = [
                config['score_ranges'][current_set_id][1]
            ] * len(current_validation_ids)
            current_validation_min_scores = [
                config['score_ranges'][current_set_id][0]
            ] * len(current_validation_ids)
            train_max_scores.extend(current_train_max_scores)
            train_min_scores.extend(current_train_min_scores)
            validation_max_scores.extend(current_validation_max_scores)
            validation_min_scores.extend(current_validation_min_scores)

            current_train_domain_label = [current_domain] * len(current_train_ids)
            current_validation_domain_label = [current_domain] * len(current_validation_ids)
            train_domain_label.extend(current_train_domain_label)
            validation_domain_label.extend(current_validation_domain_label)
            current_domain += 1

            # NOTE(review): the prompt file is indexed with the *target*
            # set_id, not current_set_id, so every pooled essay is paired
            # with the target set's prompt — confirm this is intentional.
            with open(config['essay_prompt'][set_id]) as f:
                current_prompt_essays = [f.read()]
            current_train_prompt_essays = current_prompt_essays * len(
                current_train_ids)
            train_prompt_essays.extend(current_train_prompt_essays)
            current_validation_prompt_essays = current_prompt_essays * len(
                current_validation_ids)
            validation_prompt_essays.extend(current_validation_prompt_essays)

        # Tokenize/segment the (repeated) prompts and the essays.
        train_prompt_process = process_data(train_prompt_essays, tokenizer,
                                            config['split_segment'],
                                            config['segment_max_len'])
        validation_prompt_process = process_data(validation_prompt_essays,
                                                 tokenizer,
                                                 config['split_segment'],
                                                 config['segment_max_len'])
        train_prompt_inputs = train_prompt_process['inputs']
        train_prompt_sent_count = train_prompt_process['sent_count']
        train_prompt_sent_length = train_prompt_process['sent_length']
        train_prompt_mask = train_prompt_process['attention_mask']

        validation_prompt_inputs = validation_prompt_process['inputs']
        validation_prompt_sent_count = validation_prompt_process['sent_count']
        validation_prompt_sent_length = validation_prompt_process[
            'sent_length']
        validation_prompt_mask = validation_prompt_process['attention_mask']

        train_dataset_process = process_data(train_essays, tokenizer,
                                             config['split_segment'],
                                             config['segment_max_len'])
        dev_dataset_process = process_data(validation_essays, tokenizer,
                                           config['split_segment'],
                                           config['segment_max_len'])

        train_inputs = train_dataset_process['inputs']
        train_sent_count = train_dataset_process['sent_count']
        train_sent_length = train_dataset_process['sent_length']
        train_masks = train_dataset_process['attention_mask']

        validation_inputs = dev_dataset_process['inputs']
        validation_sent_count = dev_dataset_process['sent_count']
        validation_sent_length = dev_dataset_process['sent_length']
        validation_masks = dev_dataset_process['attention_mask']

        # Move everything onto the training device as tensors.
        train_inputs = torch.tensor(train_inputs).to(device)
        validation_inputs = torch.tensor(validation_inputs).to(device)

        train_labels = torch.tensor(train_labels).to(device)
        validation_labels = torch.tensor(validation_labels).to(device)

        train_masks = torch.tensor(train_masks).to(device)
        validation_masks = torch.tensor(validation_masks).to(device)

        train_sent_counts = torch.tensor(train_sent_count).to(device)
        validation_sent_counts = torch.tensor(validation_sent_count).to(device)

        train_sent_length = torch.tensor(train_sent_length).to(device)
        validation_sent_length = torch.tensor(validation_sent_length).to(
            device)

        train_features = torch.tensor(train_features).to(device)
        validation_features = torch.tensor(validation_features).to(device)

        train_max_scores = torch.tensor(train_max_scores).to(device)
        train_min_scores = torch.tensor(train_min_scores).to(device)
        validation_max_scores = torch.tensor(validation_max_scores).to(device)
        validation_min_scores = torch.tensor(validation_min_scores).to(device)

        train_domain_label = torch.tensor(train_domain_label).to(device)
        validation_domain_label = torch.tensor(validation_domain_label).to(
            device)

        train_prompt_inputs = torch.tensor(train_prompt_inputs).to(device)
        train_prompt_mask = torch.tensor(train_prompt_mask).to(device)
        train_prompt_sent_count = torch.tensor(train_prompt_sent_count).to(
            device)
        train_prompt_sent_length = torch.tensor(train_prompt_sent_length).to(
            device)

        validation_prompt_inputs = torch.tensor(validation_prompt_inputs).to(
            device)
        validation_prompt_mask = torch.tensor(validation_prompt_mask).to(
            device)
        validation_prompt_sent_count = torch.tensor(
            validation_prompt_sent_count).to(device)
        validation_prompt_sent_length = torch.tensor(
            validation_prompt_sent_length).to(device)

        # Field order here must match the batch unpacking in the loops below.
        train_data = TensorDataset(
            train_inputs,
            train_masks,
            train_labels,
            train_sent_counts,
            train_sent_length,
            train_features,
            train_prompt_inputs,
            train_prompt_mask,
            train_prompt_sent_count,
            train_prompt_sent_length,
            train_max_scores,
            train_min_scores,
            train_domain_label,
        )
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=config['batch_size'][set_id])

        validation_data = TensorDataset(
            validation_inputs,
            validation_masks,
            validation_labels,
            validation_sent_counts,
            validation_sent_length,
            validation_features,
            validation_prompt_inputs,
            validation_prompt_mask,
            validation_prompt_sent_count,
            validation_prompt_sent_length,
            validation_max_scores,
            validation_min_scores,
            validation_domain_label,
        )

        validation_sampler = SequentialSampler(validation_data)
        validation_dataloader = DataLoader(
            validation_data,
            sampler=validation_sampler,
            batch_size=config['batch_size'][set_id])

        print('begin set ', essay_set_id, 'setup model')
        print('model: ', config['model'])
        model = make_model(config, device, set_id)

        # Fine-tune only the last BERT layer and the non-BERT parameters.
        # NOTE: the '11' substring test matches any parameter name that
        # contains "11", not just layer 11 — confirm if more layers appear.
        for name, param in model.named_parameters():
            if 'bert' in name \
                    and 'pooler' not in name \
                    and '11' not in name:
                param.requires_grad = False

        # Materialize the filter: optim.Adam consumes the iterator, and the
        # original code then passed the exhausted iterator to
        # clip_grad_norm_, which silently clipped nothing.
        parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
        optimizer = optim.Adam(parameters, lr=0.00005)

        # begin training
        print('begin set ', essay_set_id, 'begin training')
        create_path(get_path_prefix(config['model_save_path'][set_id]))
        epoch = config['epoch_num'][set_id]
        best_validation_loss = 100000
        # Batches per epoch, computed once so the gradient-reversal schedule
        # is not skewed by the smaller final batch (the original derived it
        # from the current batch's size on every step).
        total_batch_count = len(train_dataloader)
        for current_epoch in trange(epoch, desc='Epoch'):

            train_loss = []
            for step, batch in enumerate(tqdm(train_dataloader)):
                model.train()
                # Unpack into batch_* names so the dataset-level
                # train_domain_label tensor is not clobbered (the original
                # shadowed it here).
                (batch_inputs, batch_masks, batch_labels, batch_sent_count,
                 batch_sent_length, batch_features, prompt_inputs,
                 prompt_mask, prompt_sent_count, prompt_sent_length,
                 max_scores, min_scores, batch_domain_label) = batch

                optimizer.zero_grad()

                # DANN-style schedule: p is overall training progress in
                # [0, 1); alpha ramps smoothly from 0 towards 1.
                p = (step + current_epoch *
                     total_batch_count) / epoch / total_batch_count
                alpha = 2. / (1. + np.exp(-10 * p)) - 1

                result = model(batch_inputs,
                               batch_masks,
                               batch_sent_count,
                               batch_sent_length,
                               prompt_inputs,
                               prompt_mask,
                               prompt_sent_count,
                               prompt_sent_length,
                               min_scores,
                               max_scores,
                               batch_features,
                               batch_labels,
                               domain_label=batch_domain_label,
                               alpha=alpha)

                result['loss'].backward()
                nn.utils.clip_grad_norm_(parameters, 1.0)

                train_loss.append(result['loss'].item())
                optimizer.step()

            # ---- end-of-epoch evaluation ----
            model.eval()
            dev_loss = []
            with torch.no_grad():
                for batch in validation_dataloader:
                    (batch_inputs, batch_masks, batch_labels,
                     batch_sent_count, batch_sent_length, batch_features,
                     prompt_inputs, prompt_mask, prompt_sent_count,
                     prompt_sent_length, max_scores, min_scores,
                     domain_label) = batch

                    # domain_label/alpha are None at eval time: no
                    # adversarial loss on the dev set.
                    result = model(batch_inputs,
                                   batch_masks,
                                   batch_sent_count,
                                   batch_sent_length,
                                   prompt_inputs,
                                   prompt_mask,
                                   prompt_sent_count,
                                   prompt_sent_length,
                                   min_scores,
                                   max_scores,
                                   batch_features,
                                   batch_labels,
                                   domain_label=None,
                                   alpha=None)

                    dev_loss.append(result['loss'].item())

            # NOTE(review): dividing a sum of batch losses by the essay
            # count assumes result['loss'] is summed (not averaged) over
            # the batch — confirm against the model implementation.
            dev_loss = np.sum(dev_loss) / len(validation_ids)
            writer.add_scalar(tag='set' + str(essay_set_id) +
                              '_epoch_dev_loss',
                              scalar_value=dev_loss,
                              global_step=current_epoch)

            # Keep the checkpoint with the lowest validation loss.
            if dev_loss < best_validation_loss:
                print('get better result save')
                best_validation_loss = dev_loss
                torch.save(model.state_dict(),
                           config['model_save_path'][set_id])

            print('dev loss ', dev_loss)

            writer.add_scalar('set' + str(essay_set_id) +
                              '_epoch_avg_train_loss',
                              scalar_value=np.sum(train_loss) / len(train_ids),
                              global_step=current_epoch)
            print('average train loss: ', np.sum(train_loss) / len(train_ids))
            print()

        # Free GPU memory before processing the next essay set.
        del model
        del train_inputs
        del validation_inputs
        del train_masks
        del validation_masks
        del train_prompt_inputs
        del train_prompt_mask
        del validation_prompt_inputs
        del validation_prompt_mask
        torch.cuda.empty_cache()
Пример #11
0
def main():
    """Train a transformer MT model with knowledge distillation.

    Usage: ``python <script> <config.yaml>``.

    Builds the student transformer (optionally warm-started from a
    checkpoint) and a pretrained reference (teacher) model, optionally
    restricts which student parameters are trainable via
    ``config['Train']['params']``, then delegates the training loop to
    ``Kd_Trainer``.
    """
    import shutil  # local import: used only for the one config-copy below

    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # yaml.load without an explicit Loader is deprecated since PyYAML 5.1
        # and raises TypeError on PyYAML >= 6; safe_load is the drop-in fix.
        config = yaml.safe_load(config_file)
        create_path(config['Record']['training_record_path'])

    # set random seed
    set_random_seed(config['Train']['random_seed'])

    # set the device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load dataset, vocab and the train/dev iterators
    print('load dataset ...')
    mt_data_loader = MTDataLoader(config)
    mt_data_loader.load_datasets(load_train=True, load_dev=True, load_test=False)
    mt_data_loader.build_vocab()
    mt_data_loader.build_iterators(device=device, build_train=True, build_dev=True, build_test=False)

    vocab = mt_data_loader.vocab
    train_iterators = mt_data_loader.train_iterators
    dev_iterators = mt_data_loader.dev_iterators
    dev_test_iterators = mt_data_loader.dev_test_iterators

    # student model (optionally warm-started)
    model_builder = ModelBuilder()
    model = model_builder.build_model(model_name='transformer',
                                      model_config=config['Model'],
                                      vocab=vocab,
                                      device=device,
                                      load_pretrained=config['Train']['load_exist_model'],
                                      pretrain_path=config['Train']['model_load_path'])

    # reference (teacher) model — always loaded from its checkpoint
    ref_model = model_builder.build_model(model_name='transformer',
                                          model_config=config['Model'],
                                          vocab=vocab,
                                          device=device,
                                          load_pretrained=True,
                                          pretrain_path=config['Train']['ref_model_load_path'])

    if 'params' in config['Train']:
        train_params = config['Train']['params']

        # A parameter stays trainable only if it satisfies every filter:
        # a string filter must be a substring of the parameter name, and a
        # list filter must contain at least one matching substring.
        for name, param in model.named_parameters():
            tag = True
            for param_filter in train_params:
                if isinstance(param_filter, str):
                    if param_filter not in name:
                        tag = False
                if isinstance(param_filter, list):
                    if not any(domain in name for domain in param_filter):
                        tag = False
            param.requires_grad = tag

    print('trained parameters: ')
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.shape)

    # make criterion
    # the label_smoothing of validation criterion is always set to 0
    criterion = model_builder.build_criterion(criterion_config=config['Criterion'], vocab=vocab)
    validation_criterion = model_builder.build_criterion(criterion_config={
        'name': 'kl_divergence',
        'label_smoothing': 0,
    }, vocab=vocab)

    # make optimizer and lr scheduler (optionally restored from checkpoints)
    optimizer = model_builder.build_optimizer(parameters=model.parameters(),
                                              optimizer_config=config['Optimizer'],
                                              load_pretrained=config['Train']['load_optimizer'],
                                              pretrain_path=config['Train']['optimizer_path'])
    lr_scheduler = model_builder.build_lr_scheduler(optimizer=optimizer,
                                                    lr_scheduler_config=config['Optimizer']['lr_scheduler'],
                                                    load_pretrained=config['Train']['load_lr_scheduler'],
                                                    pretrain_path=config['Train']['lr_scheduler_path']
                                                    )

    # Keep a copy of the config next to the training records.  shutil.copy
    # replaces the shell-dependent, unquoted `os.system('cp ...')`.
    shutil.copy(config_file_path,
                os.path.join(config['Record']['training_record_path'],
                             'model_config.txt'))

    trainer = Kd_Trainer(
        model=model,
        criterion=criterion,
        validation_criterion=validation_criterion,
        vocab=vocab,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_iterators=train_iterators,
        validation_iterators=dev_iterators,
        validation_test_iterators=dev_test_iterators,
        optimizer_config=config['Optimizer'],
        train_config=config['Train'],
        validation_config=config['Validation'],
        record_config=config['Record'],
        device=device,
        ref_model=ref_model,
        ref_temperature=config['Train']['ref_temperature'],
        ref_factor=config['Train']['ref_factor'],
    )

    trainer.train()
Пример #12
0
def main():
    """Train one prompt-aware essay scorer per essay set (single-domain).

    Usage: ``python <script> <config.yaml>``.

    For every essay set whose ``config['need_training']`` flag is on, the
    set's own train/dev essays are tokenized with BERT and a model is
    trained on them; the checkpoint with the best quadratic-weighted kappa
    on the dev split is saved to ``config['model_save_path'][set_id]``.
    """
    torch.manual_seed(0)

    print('cuda is available: ', torch.cuda.is_available())
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # yaml.load without an explicit Loader is deprecated since PyYAML 5.1
        # and raises TypeError on PyYAML >= 6; safe_load is the drop-in fix.
        config = yaml.safe_load(config_file)

    score_ranges = config['score_ranges']

    bert_model_path = config['bert_model_path']
    tokenizer = BertTokenizer.from_pretrained(bert_model_path)

    print('read dataset')
    train_dataset_file = config['train_dataset_file']
    train_dataset = pd.read_csv(train_dataset_file, delimiter='\t', usecols=['essay_set', 'essay_id', 'essay', 'domain1_score'])
    dev_dataset_file = config['dev_dataset_file']
    dev_dataset = pd.read_csv(dev_dataset_file, delimiter='\t', usecols=['essay_set', 'essay_id', 'essay', 'domain1_score'])

    essay_set = set(train_dataset['essay_set'])

    # TensorBoard writer shared by every essay set's train/dev curves.
    create_path(config['record_path'])
    writer = SummaryWriter(config['record_path'])

    for set_id, essay_set_id in enumerate(essay_set):
        if config['need_training'][set_id] is False:
            continue

        print('begin set ', essay_set_id, 'processing')

        # The prompt text is tokenized once per set and shared by every
        # batch of this set's essays.
        with open(config['essay_prompt'][set_id]) as f:
            prompt = [f.read()]
            prompt_process = process_data(prompt, tokenizer, config['split_segment'], config['segment_max_len'])
            prompt_inputs = prompt_process['inputs']
            prompt_sent_count = prompt_process['sent_count']
            prompt_sent_length = prompt_process['sent_length']
            prompt_mask = prompt_process['attention_mask']

        train_dataset_in_set = train_dataset[train_dataset.essay_set == essay_set_id]
        dev_dataset_in_set = dev_dataset[dev_dataset.essay_set == essay_set_id]

        # essays
        train_essays = train_dataset_in_set.essay.values
        dev_essays = dev_dataset_in_set.essay.values

        # ids
        train_ids = train_dataset_in_set.essay_id.values
        validation_ids = dev_dataset_in_set.essay_id.values

        # hand-crafted per-essay features, looked up by essay id
        train_features = get_feature_from_ids(train_ids, config['train_feature'])
        validation_features = get_feature_from_ids(validation_ids, config['validation_feature'])

        train_labels = train_dataset_in_set.domain1_score.values
        validation_labels = dev_dataset_in_set.domain1_score.values

        train_dataset_process = process_data(train_essays, tokenizer, config['split_segment'], config['segment_max_len'])
        dev_dataset_process = process_data(dev_essays, tokenizer, config['split_segment'], config['segment_max_len'])

        train_inputs = train_dataset_process['inputs']
        train_sent_count = train_dataset_process['sent_count']
        train_sent_length = train_dataset_process['sent_length']
        train_masks = train_dataset_process['attention_mask']

        validation_inputs = dev_dataset_process['inputs']
        validation_sent_count = dev_dataset_process['sent_count']
        validation_sent_length = dev_dataset_process['sent_length']
        validation_masks = dev_dataset_process['attention_mask']

        # Move everything onto the training device as tensors.
        train_inputs = torch.tensor(train_inputs).to(device)
        validation_inputs = torch.tensor(validation_inputs).to(device)

        train_labels = torch.tensor(train_labels).to(device)
        validation_labels = torch.tensor(validation_labels).to(device)

        train_masks = torch.tensor(train_masks).to(device)
        validation_masks = torch.tensor(validation_masks).to(device)

        train_sent_counts = torch.tensor(train_sent_count).to(device)
        validation_sent_counts = torch.tensor(validation_sent_count).to(device)

        train_sent_length = torch.tensor(train_sent_length).to(device)
        validation_sent_length = torch.tensor(validation_sent_length).to(device)

        train_features = torch.tensor(train_features).to(device)
        validation_features = torch.tensor(validation_features).to(device)

        prompt_inputs = torch.tensor(prompt_inputs).to(device)
        prompt_mask = torch.tensor(prompt_mask).to(device)
        prompt_sent_count = torch.tensor(prompt_sent_count).to(device)
        prompt_sent_length = torch.tensor(prompt_sent_length).to(device)

        # Field order here must match the batch unpacking in the loops below.
        train_data = TensorDataset(train_inputs, train_masks, train_labels,
                                   train_sent_counts, train_sent_length, train_features)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=config['batch_size'][set_id])

        validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels,
                                        validation_sent_counts, validation_sent_length, validation_features)
        validation_sampler = SequentialSampler(validation_data)
        validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=config['batch_size'][set_id])

        print('begin set ', essay_set_id, 'setup model')
        print('model: ', config['model'])
        model = make_model(config, device, set_id)

        # Fine-tune only the last BERT layer and the non-BERT parameters.
        # NOTE: the '11' substring test matches any parameter name that
        # contains "11", not just layer 11 — confirm if more layers appear.
        for name, param in model.named_parameters():
            if 'bert' in name \
                    and 'pooler' not in name \
                    and '11' not in name:
                param.requires_grad = False

        # Materialize the filter: optim.Adam consumes the iterator, and the
        # original code then passed the exhausted iterator to
        # clip_grad_norm_, which silently clipped nothing.
        parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
        optimizer = optim.Adam(parameters, lr=0.00005)

        # begin training
        print('begin set ', essay_set_id, 'begin training')
        create_path(get_path_prefix(config['model_save_path'][set_id]))
        epoch = config['epoch_num'][set_id]
        best_validation_kappa = 0
        for current_epoch in trange(epoch, desc='Epoch'):

            train_loss = []
            for step, batch in enumerate(tqdm(train_dataloader)):
                model.train()
                batch_inputs, batch_masks, batch_labels, batch_sent_count, batch_sent_length, batch_features = batch
                optimizer.zero_grad()

                # classifier models take no score range; others receive the
                # set's [min, max] bounds for score scaling
                if 'classifier' in config['model']:
                    result = model(batch_inputs, batch_masks, batch_sent_count, batch_sent_length,
                                   prompt_inputs, prompt_mask, prompt_sent_count, prompt_sent_length, batch_features,
                                   batch_labels)
                else:
                    result = model(batch_inputs, batch_masks, batch_sent_count, batch_sent_length,
                                   prompt_inputs, prompt_mask, prompt_sent_count, prompt_sent_length,
                                   score_ranges[set_id][0], score_ranges[set_id][1], batch_features,
                                   batch_labels)

                result['loss'].backward()
                nn.utils.clip_grad_norm_(parameters, 1.0)

                # per-essay loss for logging
                train_loss.append(result['loss'].item() / batch_inputs.shape[0])
                optimizer.step()

            # ---- end-of-epoch evaluation ----
            dev_true = []
            dev_predict = []
            model.eval()
            dev_loss = []
            with torch.no_grad():
                for batch in validation_dataloader:
                    batch_inputs, batch_masks, batch_labels, batch_sent_count, batch_sent_length, batch_features = batch

                    if 'classifier' in config['model']:
                        result = model(batch_inputs, batch_masks, batch_sent_count, batch_sent_length,
                                       prompt_inputs, prompt_mask, prompt_sent_count, prompt_sent_length, batch_features,
                                       batch_labels)
                    else:
                        result = model(batch_inputs, batch_masks, batch_sent_count, batch_sent_length,
                                       prompt_inputs, prompt_mask, prompt_sent_count, prompt_sent_length,
                                       score_ranges[set_id][0], score_ranges[set_id][1], batch_features,
                                       batch_labels)

                    prediction = result['prediction']

                    dev_loss.append(result['loss'].item())
                    dev_true.append(batch_labels)
                    dev_predict.append(prediction)

                dev_true = torch.cat(dev_true, dim=0)
                dev_predict = torch.cat(dev_predict, dim=0)

                dev_kappa = kappa(y_true=dev_true, y_pred=dev_predict, weights='quadratic')
                writer.add_scalar(tag='set' + str(essay_set_id) + '_epoch_dev_kappa', scalar_value=dev_kappa,
                                  global_step=current_epoch)

            # NOTE(review): dividing a sum of batch losses by the essay
            # count assumes result['loss'] is summed (not averaged) over
            # the batch — confirm against the model implementation.
            dev_loss = np.sum(dev_loss) / validation_ids.shape[0]
            writer.add_scalar(tag='set' + str(essay_set_id) + '_epoch_dev_loss', scalar_value=dev_loss,
                              global_step=current_epoch)

            # Keep the checkpoint with the best quadratic-weighted kappa.
            if dev_kappa > best_validation_kappa:
                print('get better kappa result, save')
                best_validation_kappa = dev_kappa
                torch.save(model.state_dict(), config['model_save_path'][set_id])

            print('dev_kappa is', dev_kappa)
            print('dev loss ', dev_loss)

            writer.add_scalar('set' + str(essay_set_id) + '_epoch_avg_train_loss',
                              scalar_value=np.average(train_loss), global_step=current_epoch)
            print('average train loss: ', np.average(train_loss))
            print()

        # Free GPU memory before processing the next essay set.
        del model
        del train_inputs
        del validation_inputs
        del train_masks
        del validation_masks
        torch.cuda.empty_cache()
Пример #13
0
def main():
    """Score the test essays of every enabled essay set and write TSV output.

    Usage: ``script <config.yaml>``.  For each essay set flagged in
    ``config['need_test']`` this restores the per-set model checkpoint,
    runs the model over the test split and saves one row per essay
    (``essay_id``, ``essay_set``, ``domain1_score``) to
    ``config['test_output_path'][set_id]``.
    """
    print('cuda is available: ', torch.cuda.is_available())
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # safe_load: yaml.load without an explicit Loader is deprecated
        # (PyYAML >= 5.1) and can construct arbitrary Python objects.
        config = yaml.safe_load(config_file)

    bert_model_path = config['bert_model_path']
    tokenizer = BertTokenizer.from_pretrained(bert_model_path)

    print('read dataset')

    test_dataset_file = config['test_dataset_file']
    test_dataset = pd.read_csv(test_dataset_file, delimiter='\t',
                               usecols=['essay_set', 'essay_id', 'essay'])

    essay_set = set(test_dataset['essay_set'])

    # NOTE(review): set_id indexes the per-set config lists; this assumes the
    # iteration order of the essay-set ids matches the config order — confirm.
    for set_id, essay_set_id in enumerate(essay_set):
        if config['need_test'][set_id] is False:
            continue

        print('begin set ', essay_set_id, 'processing')

        test_dataset_in_set = test_dataset[test_dataset.essay_set == essay_set_id]

        test_essays = test_dataset_in_set.essay.values

        test_dataset_process = process_data(test_essays, tokenizer,
                                            config['split_segment'],
                                            config['segment_max_len'])

        ids = test_dataset_in_set.essay_id.values
        test_features = get_feature_from_test_ids(ids, config['test_feature'])

        test_inputs = test_dataset_process['inputs']
        test_sent_count = test_dataset_process['sent_count']
        test_sent_length = test_dataset_process['sent_length']
        test_masks = test_dataset_process['attention_mask']

        test_inputs = torch.tensor(test_inputs).to(device)
        test_masks = torch.tensor(test_masks).to(device)
        test_sent_count = torch.tensor(test_sent_count).to(device)
        test_sent_length = torch.tensor(test_sent_length).to(device)
        test_features = torch.tensor(test_features).to(device)

        # The prompt text is shared by every essay of the set: replicate it so
        # it can be batched alongside the essays.
        with open(config['essay_prompt'][set_id]) as f:
            test_prompt = [f.read()] * len(ids)
        test_prompt_process = process_data(test_prompt, tokenizer,
                                           config['split_segment'],
                                           config['segment_max_len'])
        test_prompt_inputs = torch.tensor(test_prompt_process['inputs']).to(device)
        test_prompt_mask = torch.tensor(test_prompt_process['attention_mask']).to(device)
        test_prompt_sent_count = torch.tensor(test_prompt_process['sent_count']).to(device)
        test_prompt_sent_length = torch.tensor(test_prompt_process['sent_length']).to(device)

        # Per-essay copies of the set's score range (the model presumably
        # rescales its output into [min, max] — confirm against the model).
        test_max_scores = torch.tensor(
            [config['score_ranges'][set_id][1]] * len(ids)).to(device)
        test_min_scores = torch.tensor(
            [config['score_ranges'][set_id][0]] * len(ids)).to(device)

        test_data = TensorDataset(test_inputs, test_masks, test_sent_count, test_sent_length, test_features,
                                  test_prompt_inputs, test_prompt_mask, test_prompt_sent_count, test_prompt_sent_length,
                                  test_max_scores, test_min_scores)
        # SequentialSampler keeps predictions aligned with `ids`.
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data, sampler=test_sampler,
                                     batch_size=config['batch_size'][set_id])

        print('begin set ', essay_set_id, 'setup model')
        model = make_model(config, device, set_id)

        # restore the checkpoint saved during training for this essay set
        model.load_state_dict(torch.load(config['model_save_path'][set_id]))

        print('begin set ', essay_set_id, 'begin test')

        # evaluation: no gradients, no dropout
        model.eval()
        with torch.no_grad():
            dev_predict = []
            for batch in test_dataloader:
                batch_inputs, batch_masks, batch_sent_count, batch_sent_length, batch_feature, \
                    prompt_inputs, prompt_mask, prompt_sent_count, prompt_sent_length, \
                    batch_max_scores, batch_min_scores = batch

                result = model(batch_inputs, batch_masks, batch_sent_count, batch_sent_length,
                               prompt_inputs, prompt_mask, prompt_sent_count, prompt_sent_length,
                               batch_min_scores, batch_max_scores, batch_feature,
                               None, None, None)
                # keep only the first prediction column (overall score)
                prediction = result['prediction'][:, 0]
                dev_predict.append(prediction)

        dev_predict = torch.cat(dev_predict, dim=0).tolist()

        samples = []
        for i in range(0, len(ids)):
            samples.append({
                'domain1_score': dev_predict[i],
                'essay_id': ids[i],
                'essay_set': essay_set_id,
            })
        create_path(get_path_prefix(config['test_output_path'][set_id]))
        save_to_tsv(samples, config['test_output_path'][set_id])

        # free GPU memory before moving to the next essay set
        del model
        del test_inputs
        del test_masks
        torch.cuda.empty_cache()
def main():
    """Train the mix-adapter translation model for one training stage.

    Usage: ``script <config.yaml>``.  Builds the data iterators and the
    ``transformer_with_mix_adapter`` model, freezes every parameter except
    those selected by ``config['Train']['stage']``, then dispatches to the
    trainer class matching that stage and runs training.
    """
    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # safe_load: yaml.load without an explicit Loader is deprecated
        # (PyYAML >= 5.1) and can construct arbitrary Python objects.
        config = yaml.safe_load(config_file)
        create_path(config['Record']['training_record_path'])

    # ================================================================================== #
    # set the device
    set_random_seed(config['Train']['random_seed'])
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load dataset
    print('load dataset ...')
    mt_data_loader = MTDataLoader(config)
    mt_data_loader.load_datasets(load_train=True,
                                 load_dev=True,
                                 load_test=False)
    mt_data_loader.build_vocab()
    mt_data_loader.build_iterators(device=device,
                                   build_train=True,
                                   build_dev=True,
                                   build_test=False)

    vocab = mt_data_loader.vocab
    train_iterators = mt_data_loader.train_iterators
    train_iterator_domain = config['Dataset']['train_dataset_domain']
    dev_iterators = mt_data_loader.dev_iterators
    dev_test_iterators = mt_data_loader.dev_test_iterators
    dev_iterator_domain = config['Dataset']['dev_dataset_domain']

    model_builder = ModelBuilder()
    model = model_builder.build_model(
        model_name='transformer_with_mix_adapter',
        model_config=config['Model'],
        vocab=vocab,
        device=device,
        load_pretrained=config['Train']['load_exist_model'],
        pretrain_path=config['Train']['model_load_path'])
    model.classify_domain_mask = model.classify_domain_mask.to(device)

    criterion = model_builder.build_criterion(
        criterion_config=config['Criterion'], vocab=vocab)
    # validation always uses plain KL divergence (no label smoothing) so
    # dev losses are comparable across stages
    validation_criterion = model_builder.build_criterion(criterion_config={
        'name': 'kl_divergence',
        'label_smoothing': 0,
    },
                                                         vocab=vocab)

    training_stage = config['Train']['stage']

    # Freeze everything except the parameter group the stage trains:
    #   classify            -> the domain classifier
    #   mixture_of_experts  -> all adapters
    #   train_gate          -> inner gates of the training domains
    #   (default)           -> adapters of the training domains
    if training_stage == 'classify':
        for name, param in model.named_parameters():
            param.requires_grad = 'classifier' in name

    elif training_stage == 'mixture_of_experts':
        for name, param in model.named_parameters():
            param.requires_grad = 'adapter' in name

    elif training_stage == 'train_gate':
        for name, param in model.named_parameters():
            param.requires_grad = ('inner_gate' in name and any(
                domain in name for domain in train_iterator_domain))

    else:
        # update specific domain adapter
        for name, param in model.named_parameters():
            param.requires_grad = ('adapter' in name and any(
                domain in name for domain in train_iterator_domain))

    # log the parameters that will actually be updated
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.shape)

    parameters = filter(lambda p: p.requires_grad, model.parameters())

    optimizer = model_builder.build_optimizer(
        parameters=parameters,
        optimizer_config=config['Optimizer'],
        load_pretrained=config['Train']['load_optimizer'],
        pretrain_path=config['Train']['optimizer_path'])
    lr_scheduler = model_builder.build_lr_scheduler(
        optimizer=optimizer,
        lr_scheduler_config=config['Optimizer']['lr_scheduler'],
        load_pretrained=config['Train']['load_lr_scheduler'],
        pretrain_path=config['Train']['lr_scheduler_path'])

    # keep a copy of the config next to the training record for reproducibility
    os.system('cp ' + config_file_path + ' ' +
              config['Record']['training_record_path'] + '/model_config.txt')

    # pick the trainer implementation matching the stage
    if training_stage == 'classify':
        trainer = ClassifierTrainer(
            model=model,
            criterion=criterion,
            vocab=vocab,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            train_iterators=train_iterators,
            train_iterators_domain_list=train_iterator_domain,
            validation_iterators=dev_iterators,
            validation_iterators_domain_list=dev_iterator_domain,
            domain_dict=config['Model']['domain_dict'],
            optimizer_config=config['Optimizer'],
            train_config=config['Train'],
            validation_config=config['Validation'],
            record_config=config['Record'],
            device=device,
        )
    elif training_stage in ('mix_adapter_translation', 'train_gate',
                            'mixture_of_experts'):
        trainer = Mix_Adapter_Trainer(
            model=model,
            criterion=criterion,
            validation_criterion=validation_criterion,
            vocab=vocab,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            train_iterators=train_iterators,
            validation_iterators=dev_iterators,
            validation_test_iterators=dev_test_iterators,
            optimizer_config=config['Optimizer'],
            train_config=config['Train'],
            validation_config=config['Validation'],
            record_config=config['Record'],
            device=device,
            target_domain=config['Train']['target_domain'],
            used_domain_list=config['Train']['used_domain_list'],
            used_inner_gate=config['Train']['used_inner_gate'],
        )
    elif training_stage == 'kd':
        trainer = Kd_Adapter_Trainer(
            model=model,
            criterion=criterion,
            validation_criterion=validation_criterion,
            vocab=vocab,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            train_iterators=train_iterators,
            validation_iterators=dev_iterators,
            validation_test_iterators=dev_test_iterators,
            optimizer_config=config['Optimizer'],
            train_config=config['Train'],
            validation_config=config['Validation'],
            record_config=config['Record'],
            device=device,
            target_domain=config['Train']['target_domain'],
            ref_domain_dict=config['Train']['kd_ref_domain'])
    else:
        trainer = Adapter_Trainer(
            model=model,
            criterion=criterion,
            validation_criterion=validation_criterion,
            vocab=vocab,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            train_iterators=train_iterators,
            validation_iterators=dev_iterators,
            validation_test_iterators=dev_test_iterators,
            optimizer_config=config['Optimizer'],
            train_config=config['Train'],
            validation_config=config['Validation'],
            record_config=config['Record'],
            device=device,
            target_domain=config['Train']['target_domain'],
        )

    trainer.train()
Пример #15
0
def main():
    """Meta-train the adapter parameters of a translation model.

    Usage: ``script <config.yaml>``.  Runs a first-order meta-learning loop:
    for each task, take one inner-loop step on a support batch, measure the
    loss on a query batch, then restore the initial weights and apply the
    summed query losses as the meta-gradient.  Periodically validates with
    loss and BLEU and checkpoints the best model.
    """
    # fixed seeds + deterministic cuDNN for reproducibility
    torch.manual_seed(3)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(3)
    random.seed(3)
    np.random.seed(3)
    torch.backends.cudnn.deterministic = True

    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # safe_load: yaml.load without an explicit Loader is deprecated
        # (PyYAML >= 5.1) and can construct arbitrary Python objects.
        config = yaml.safe_load(config_file)
        create_path(config['Record']['path'])

    # ================================================================================== #
    # set the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # set the data fields dict['src': (name, field), 'trg': (name, field)]
    data_fields = mt_data_fields()

    # load dataset
    print('load dataset ...')
    meta_train_datasets = load_datasets(
        paths=config['Dataset']['meta_train_dataset_path'],
        data_fields=data_fields,
        filter_len=config['Dataset']['filter_len'])

    meta_dev_datasets = load_datasets(
        paths=config['Dataset']['meta_dev_dataset_path'],
        data_fields=data_fields,
        filter_len=config['Dataset']['filter_len'])

    validation_dataset = load_datasets(
        paths=config['Dataset']['validation_dataset_path'],
        data_fields=data_fields)[0]

    # load vocab
    print('begin load vocab ...')
    vocab = build_vocabs(
        data_fields={
            'src': data_fields[0][1],
            'trg': data_fields[1][1]
        },  # list( tuple(name, field) )
        path={
            'src': config['Vocab']['src']['file'],
            'trg': config['Vocab']['trg']['file']
        },
        max_size={
            'src': config['Vocab']['src']['max_size'],
            'trg': config['Vocab']['trg']['max_size']
        },
        special_tokens={
            'src': ['<unk>', '<pad>', '<sos>', '<eos>'],
            'trg': ['<unk>', '<pad>', '<sos>', '<eos>']
        })

    # one iterator per meta-train / meta-dev domain
    meta_train_iterators = [
        MyIterator(dataset=dataset,
                   batch_size=config['meta_train']['batch_size'],
                   device=device,
                   repeat=False,
                   sort_key=lambda x: (len(x.src), len(x.trg)),
                   batch_size_fn=batch_size_fn,
                   train=True,
                   shuffle=True) for dataset in meta_train_datasets
    ]

    meta_dev_iterators = [
        MyIterator(dataset=dataset,
                   batch_size=config['meta_dev']['batch_size'],
                   device=device,
                   repeat=False,
                   sort_key=lambda x: (len(x.src), len(x.trg)),
                   batch_size_fn=batch_size_fn,
                   train=True,
                   shuffle=True) for dataset in meta_dev_datasets
    ]

    validation_iterator = get_dev_iterator(
        dataset=validation_dataset,
        batch_size=config['Validation']['batch_size'],
        device=device)
    validation_test_iterator = get_test_iterator(
        dataset=validation_dataset,
        batch_size=config['Validation']['batch_size'],
        device=device)

    # init or load the model
    model = make_model(vocab=vocab, model_config=config['Model'])

    if config['meta_train']['load_exist_model']:
        # partial load: only overwrite the keys present in the checkpoint
        model_dict = model.state_dict()
        pretrained_dict = torch.load(config['meta_train']['model_load_path'])
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    model.current_domain = config['meta_train']['current_domain']

    # freeze everything except the adapter parameters
    for name, param in model.named_parameters():
        param.requires_grad = 'adapter' in name

    create_path(config['Record']['path'] + '/visualization')
    writer = SummaryWriter(config['Record']['path'] + '/visualization')

    # NOTE(review): both optimizers are built over model.parameters();
    # frozen parameters receive no gradients, so only adapters are updated.
    meta_optimizer = torch.optim.Adam(
        model.parameters(), lr=config['Optimizer']['meta_learning_rate'])
    model_optimizer = torch.optim.Adam(
        model.parameters(), lr=config['Optimizer']['model_learning_rate'])

    # keep a copy of the config next to the record for reproducibility
    os.system('cp ' + config_file_path + ' ' + config['Record']['path'] +
              '/model_config.txt')

    turn_num = config['meta_train']['turn_num']
    task_num = config['meta_train']['task_num']
    detruecase_script = config['Dataset']['detruecase_script']
    detokenize_script = config['Dataset']['detokenize_script']
    validation_per_step = config['Validation']['per_steps']
    validation_ref = config['Validation']['ref']
    best_bleu_score = 0.0
    create_path(config['Record']['path'] + '/output')
    create_path(config['Record']['model_record_path'] + '/model')
    best_model_path = config['Record'][
        'model_record_path'] + '/model/best_model'

    model.train()
    for i in range(0, turn_num):

        loss_tasks = []
        # snapshot of the weights the meta-step starts from
        init_state = copy.deepcopy(model.state_dict())

        for j in range(0, task_num):
            for k in range(0, len(meta_train_iterators)):

                # restart every task from the same initial weights
                model.load_state_dict(init_state)
                model_optimizer.zero_grad()

                current_meta_train_iterator = meta_train_iterators[k]
                current_meta_test_iterator = meta_dev_iterators[k]

                # draw one support and one query batch for this task
                support_set_batch = next(iter(current_meta_train_iterator))
                query_set_batch = next(iter(current_meta_test_iterator))

                # inner-loop step on the support batch
                support_loss, n_tokens = step(
                    model,
                    support_set_batch,
                    pad_token=vocab['trg'].stoi['<pad>'])
                support_loss = support_loss / n_tokens
                support_loss.backward()
                model_optimizer.step()

                # evaluate the adapted weights on the query batch
                model.eval()
                query_loss, n_tokens = step(
                    model,
                    query_set_batch,
                    pad_token=vocab['trg'].stoi['<pad>'])
                query_loss = query_loss / n_tokens
                loss_tasks.append(query_loss)
                model.train()

        # outer (meta) step from the initial weights on the summed query losses
        model.load_state_dict(init_state)
        meta_optimizer.zero_grad()

        meta_loss = torch.stack(loss_tasks).sum(0)
        meta_loss.backward()
        print('meta loss', meta_loss.item())

        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        meta_optimizer.step()

        if i % validation_per_step == 0:
            avg_valid_loss = \
                loss_validation(model, validation_iterator, pad_token=vocab['trg'].stoi['<pad>'])
            bleu_score = bleu_validation(model, validation_test_iterator,
                                         vocab, validation_ref,
                                         config['Record']['path'] + '/output',
                                         detruecase_script, detokenize_script,
                                         config['Dataset']['target_language'])
            if bleu_score > best_bleu_score:
                best_bleu_score = bleu_score
                torch.save(model.state_dict(), best_model_path)

            writer.add_scalar(tag='meta_loss',
                              scalar_value=meta_loss.item(),
                              global_step=i)
            writer.add_scalar(tag='bleu_score',
                              scalar_value=bleu_score,
                              global_step=i)
            writer.add_scalar(tag='validation_loss',
                              scalar_value=avg_valid_loss,
                              global_step=i)

            # validation helpers switch to eval mode; switch back
            model.train()
Пример #16
0
def main():
    """Train the parallel-adapter translation model for one domain.

    Usage: ``script <config.yaml>``.  Builds iterators and the
    ``transformer_with_parallel_adapter`` model, freezes everything except
    the adapter parameters of ``config['Train']['training_domain']``, and
    runs the ``MultiAdapterTrainer``.
    """
    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # safe_load: yaml.load without an explicit Loader is deprecated
        # (PyYAML >= 5.1) and can construct arbitrary Python objects.
        config = yaml.safe_load(config_file)
        create_path(config['Record']['training_record_path'])

    # ================================================================================== #
    # set the device
    set_random_seed(config['Train']['random_seed'])
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load dataset
    print('load dataset ...')
    mt_data_loader = MTDataLoader(config)
    mt_data_loader.load_datasets(load_train=True,
                                 load_dev=True,
                                 load_test=False)
    mt_data_loader.build_vocab()
    mt_data_loader.build_iterators(device=device,
                                   build_train=True,
                                   build_dev=True,
                                   build_test=False)

    vocab = mt_data_loader.vocab
    train_iterators = mt_data_loader.train_iterators
    dev_iterators = mt_data_loader.dev_iterators
    dev_test_iterators = mt_data_loader.dev_test_iterators

    model_builder = ModelBuilder()
    model = model_builder.build_model(
        model_name='transformer_with_parallel_adapter',
        model_config=config['Model'],
        vocab=vocab,
        device=device,
        load_pretrained=config['Train']['load_exist_model'],
        pretrain_path=config['Train']['model_load_path'])
    criterion = model_builder.build_criterion(
        criterion_config=config['Criterion'], vocab=vocab)

    # only the adapter parameters of the training domain stay trainable
    training_domain = config['Train']['training_domain']
    for name, param in model.named_parameters():
        param.requires_grad = ('adapter' in name and training_domain in name)

    # log the parameters that will actually be updated
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.shape)

    parameters = filter(lambda p: p.requires_grad, model.parameters())

    optimizer = model_builder.build_optimizer(
        parameters=parameters,
        optimizer_config=config['Optimizer'],
        load_pretrained=config['Train']['load_optimizer'],
        pretrain_path=config['Train']['optimizer_path'])
    lr_scheduler = model_builder.build_lr_scheduler(
        optimizer=optimizer,
        lr_scheduler_config=config['Optimizer']['lr_scheduler'],
        load_pretrained=config['Train']['load_lr_scheduler'],
        pretrain_path=config['Train']['lr_scheduler_path'])

    # keep a copy of the config next to the training record for reproducibility
    os.system('cp ' + config_file_path + ' ' +
              config['Record']['training_record_path'] + '/model_config.txt')

    trainer = MultiAdapterTrainer(
        model=model,
        criterion=criterion,
        vocab=vocab,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_iterators=train_iterators,
        validation_iterators=dev_iterators,
        validation_test_iterators=dev_test_iterators,
        optimizer_config=config['Optimizer'],
        train_config=config['Train'],
        validation_config=config['Validation'],
        record_config=config['Record'],
        device=device,
        used_domain_list=config['Train']['used_domain_list'],
    )

    trainer.train()
def main():
    """Train the updated mix-adapter translation model.

    Usage: ``script <config.yaml>``.  Builds iterators and the
    ``transformer_with_mix_adapter_update`` model, freezes every parameter
    that does not match ALL filters in ``config['Train']['params']``, and
    runs the ``Mix_Adapter_Trainer``.
    """
    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # safe_load: yaml.load without an explicit Loader is deprecated
        # (PyYAML >= 5.1) and can construct arbitrary Python objects.
        config = yaml.safe_load(config_file)
        create_path(config['Record']['training_record_path'])

    # ================================================================================== #
    # set the device
    set_random_seed(config['Train']['random_seed'])
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load dataset
    print('load dataset ...')
    mt_data_loader = MTDataLoader(config)
    mt_data_loader.load_datasets(load_train=True,
                                 load_dev=True,
                                 load_test=False)
    mt_data_loader.build_vocab()
    mt_data_loader.build_iterators(device=device,
                                   build_train=True,
                                   build_dev=True,
                                   build_test=False)

    vocab = mt_data_loader.vocab
    train_iterators = mt_data_loader.train_iterators
    train_iterator_domain = config['Dataset']['train_dataset_domain']
    dev_iterators = mt_data_loader.dev_iterators
    dev_test_iterators = mt_data_loader.dev_test_iterators
    dev_iterator_domain = config['Dataset']['dev_dataset_domain']

    model_builder = ModelBuilder()
    model = model_builder.build_model(
        model_name='transformer_with_mix_adapter_update',
        model_config=config['Model'],
        vocab=vocab,
        device=device,
        load_pretrained=config['Train']['load_exist_model'],
        pretrain_path=config['Train']['model_load_path'])

    criterion = model_builder.build_criterion(
        criterion_config=config['Criterion'], vocab=vocab)
    # validation always uses plain KL divergence (no label smoothing) so
    # dev losses are comparable across runs
    validation_criterion = model_builder.build_criterion(criterion_config={
        'name': 'kl_divergence',
        'label_smoothing': 0,
    },
                                                         vocab=vocab)

    train_params = config['Train']['params']

    # A parameter is trainable only if it matches EVERY filter:
    # a string filter must be a substring of the parameter name, and a list
    # filter must have at least one of its entries in the name.
    # (The assignment belongs after the filter loop, not inside it.)
    for name, param in model.named_parameters():
        trainable = True
        for param_filter in train_params:
            if isinstance(param_filter, str):
                if param_filter not in name:
                    trainable = False
            if isinstance(param_filter, list):
                if not any(domain in name for domain in param_filter):
                    trainable = False
        param.requires_grad = trainable

    # log the parameters that will actually be updated
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.shape)

    parameters = filter(lambda p: p.requires_grad, model.parameters())

    optimizer = model_builder.build_optimizer(
        parameters=parameters,
        optimizer_config=config['Optimizer'],
        load_pretrained=config['Train']['load_optimizer'],
        pretrain_path=config['Train']['optimizer_path'])
    lr_scheduler = model_builder.build_lr_scheduler(
        optimizer=optimizer,
        lr_scheduler_config=config['Optimizer']['lr_scheduler'],
        load_pretrained=config['Train']['load_lr_scheduler'],
        pretrain_path=config['Train']['lr_scheduler_path'])

    # keep a copy of the config next to the training record for reproducibility
    os.system('cp ' + config_file_path + ' ' +
              config['Record']['training_record_path'] + '/model_config.txt')

    trainer = Mix_Adapter_Trainer(
        model=model,
        criterion=criterion,
        validation_criterion=validation_criterion,
        vocab=vocab,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_iterators=train_iterators,
        validation_iterators=dev_iterators,
        validation_test_iterators=dev_test_iterators,
        optimizer_config=config['Optimizer'],
        train_config=config['Train'],
        validation_config=config['Validation'],
        record_config=config['Record'],
        device=device,
    )

    trainer.train()