Exemplo n.º 1
0
def main():
    """Copy adapter parameters between domains and save the new checkpoint.

    Reads the YAML config whose path is ``sys.argv[1]``. Its ``copy_adapter``
    section supplies ``model_name``, ``load_model_path``, ``copy_dict`` (a
    list of ``{'src': ..., 'trg': ...}`` items) and ``save_path``.

    For every copy item, each state-dict entry whose name contains the ``trg``
    string is overwritten with the entry whose name is obtained by replacing
    ``trg`` with ``src``. Items are applied in config order, so a later item
    sees the values written by an earlier one.
    """
    torch.manual_seed(3333)
    np.random.seed(3333)

    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
        # and raises TypeError in PyYAML >= 6; the config is plain data.
        config = yaml.safe_load(config_file)
    copy_config = config['copy_adapter']

    # ================================================================================== #
    # set the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # the vocabulary is required by build_model
    mt_data_loader = MTDataLoader(config)
    mt_data_loader.build_vocab()
    vocab = mt_data_loader.vocab

    model_builder = ModelBuilder()
    model = model_builder.build_model(
        model_name=copy_config['model_name'],
        model_config=config['Model'],
        vocab=vocab,
        device=device,
        load_pretrained=False,
        pretrain_path=None)
    model_dict = model.state_dict()

    # Load the trained weights on top of the freshly built model. map_location
    # lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    load_model_dict = torch.load(copy_config['load_model_path'],
                                 map_location=device)
    model_dict.update(load_model_dict)

    # copy adapter parameters according to copy_dict
    for copy_item in copy_config['copy_dict']:
        src_adapter_domain = copy_item['src']
        trg_adapter_domain = copy_item['trg']

        # iterate over a snapshot of the keys while assigning into the dict
        for parameter_name in list(model_dict.keys()):
            if trg_adapter_domain not in parameter_name:
                continue
            src_adapter_parameter_name = parameter_name.replace(
                trg_adapter_domain, src_adapter_domain)
            model_dict[parameter_name] = model_dict[src_adapter_parameter_name]
            print(parameter_name, src_adapter_parameter_name)

    model.load_state_dict(model_dict)
    create_path(get_path_prefix(copy_config['save_path']))
    torch.save(model.state_dict(), copy_config['save_path'])
Exemplo n.º 2
0
    def __init__(self, config, device, model_name):
        """Build the test-time vocabulary, test iterators and trained model.

        Args:
            config: parsed config dict. The ``Model`` section is passed to
                ``ModelBuilder.build_model``. The ``Test`` section supplies
                ``model_path``, ``save_in_multi_gpu``, ``target_language``,
                ``refs``, ``detruecase_script`` and ``detokenize_script``.
            device: torch device used for the test iterators and the model.
            model_name: model name passed to ``ModelBuilder.build_model``.
        """
        from collections import OrderedDict

        self.config = config
        self.device = device

        mt_data_loader = MTDataLoader(config)
        mt_data_loader.load_datasets(load_train=False,
                                     load_dev=False,
                                     load_test=True)
        mt_data_loader.build_vocab()
        mt_data_loader.build_iterators(device=device,
                                       build_train=False,
                                       build_dev=False,
                                       build_test=True)

        self.vocab = mt_data_loader.vocab
        self.test_iterators = mt_data_loader.test_iterators

        model_builder = ModelBuilder()
        model = model_builder.build_model(model_name=model_name,
                                          model_config=config['Model'],
                                          vocab=self.vocab,
                                          device=device,
                                          load_pretrained=False,
                                          pretrain_path=None)

        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
        save_model_state_dict = torch.load(config['Test']['model_path'],
                                           map_location=device)
        if config['Test']['save_in_multi_gpu']:
            # Checkpoints saved through a multi-GPU wrapper carry a 'module.'
            # key prefix. Strip it only where present; slicing every key
            # unconditionally would corrupt keys that lack the prefix.
            prefix = 'module.'
            save_model_state_dict = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in save_model_state_dict.items())
        model.load_state_dict(save_model_state_dict)

        self.target_language = config['Test']['target_language']
        self.test_ref_file_paths = config['Test']['refs']
        self.detruecase_script = config['Test']['detruecase_script']
        self.detokenize_script = config['Test']['detokenize_script']

        model = model.to(device)
        model.domain_mask = model.domain_mask.to(device)
        self.model = model
Exemplo n.º 3
0
def main():
    """Merge several trained checkpoints into one model and save it.

    Reads the YAML config whose path is ``sys.argv[1]``. Its
    ``load_multiple_model`` section supplies ``model_name``, ``load_path`` (a
    list of checkpoint paths) and ``save_path``. The checkpoints are checked
    with ``check_inconsistent`` and then applied in list order on top of a
    freshly built model, so later checkpoints win on shared keys.
    """
    torch.manual_seed(3333)
    np.random.seed(3333)

    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
        # and raises TypeError in PyYAML >= 6; the config is plain data.
        config = yaml.safe_load(config_file)
    merge_config = config['load_multiple_model']

    # ================================================================================== #
    # set the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # the vocabulary is required by build_model
    mt_data_loader = MTDataLoader(config)
    mt_data_loader.build_vocab()
    vocab = mt_data_loader.vocab

    model_builder = ModelBuilder()
    model = model_builder.build_model(model_name=merge_config['model_name'],
                                      model_config=config['Model'],
                                      vocab=vocab,
                                      device=device,
                                      load_pretrained=False,
                                      pretrain_path=None)
    model_dict = model.state_dict()

    # map_location lets checkpoints saved on GPU be loaded on a CPU-only machine
    load_model_dicts = [torch.load(model_path, map_location=device)
                        for model_path in merge_config['load_path']]
    check_inconsistent(load_model_dicts)

    for load_model_dict in load_model_dicts:
        model_dict.update(load_model_dict)

    model.load_state_dict(model_dict)
    create_path(get_path_prefix(merge_config['save_path']))
    torch.save(model.state_dict(), merge_config['save_path'])
Exemplo n.º 4
0
def main():
    """Train only the parallel-adapter parameters of a single domain.

    Reads the YAML config whose path is ``sys.argv[1]`` and builds a
    ``transformer_with_parallel_adapter`` model. Every parameter is frozen
    except those whose name contains both ``'adapter'`` and
    ``config['Train']['training_domain']``, and ``MultiAdapterTrainer`` is
    run. A copy of the config is written to
    ``<Record.training_record_path>/model_config.txt``.
    """
    import shutil

    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
        # and raises TypeError in PyYAML >= 6; the config is plain data.
        config = yaml.safe_load(config_file)
    create_path(config['Record']['training_record_path'])

    # ================================================================================== #
    # set the random seed and device
    set_random_seed(config['Train']['random_seed'])
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load dataset
    print('load dataset ...')
    mt_data_loader = MTDataLoader(config)
    mt_data_loader.load_datasets(load_train=True,
                                 load_dev=True,
                                 load_test=False)
    mt_data_loader.build_vocab()
    mt_data_loader.build_iterators(device=device,
                                   build_train=True,
                                   build_dev=True,
                                   build_test=False)

    vocab = mt_data_loader.vocab
    train_iterators = mt_data_loader.train_iterators
    dev_iterators = mt_data_loader.dev_iterators
    dev_test_iterators = mt_data_loader.dev_test_iterators

    # make model and criterion
    model_builder = ModelBuilder()
    model = model_builder.build_model(
        model_name='transformer_with_parallel_adapter',
        model_config=config['Model'],
        vocab=vocab,
        device=device,
        load_pretrained=config['Train']['load_exist_model'],
        pretrain_path=config['Train']['model_load_path'])
    criterion = model_builder.build_criterion(
        criterion_config=config['Criterion'], vocab=vocab)

    # only the adapters of the training domain are updated
    training_domain = config['Train']['training_domain']
    for name, param in model.named_parameters():
        param.requires_grad = 'adapter' in name and training_domain in name

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.shape)

    # make optimizer over the trainable parameters only
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = model_builder.build_optimizer(
        parameters=parameters,
        optimizer_config=config['Optimizer'],
        load_pretrained=config['Train']['load_optimizer'],
        pretrain_path=config['Train']['optimizer_path'])
    # make lr scheduler
    lr_scheduler = model_builder.build_lr_scheduler(
        optimizer=optimizer,
        lr_scheduler_config=config['Optimizer']['lr_scheduler'],
        load_pretrained=config['Train']['load_lr_scheduler'],
        pretrain_path=config['Train']['lr_scheduler_path'])

    # Keep a copy of the config next to the training records. shutil avoids
    # the shell, so paths with spaces or metacharacters work, and it raises
    # on failure instead of silently ignoring cp's exit status.
    shutil.copyfile(
        config_file_path,
        os.path.join(config['Record']['training_record_path'],
                     'model_config.txt'))

    trainer = MultiAdapterTrainer(
        model=model,
        criterion=criterion,
        vocab=vocab,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_iterators=train_iterators,
        validation_iterators=dev_iterators,
        validation_test_iterators=dev_test_iterators,
        optimizer_config=config['Optimizer'],
        train_config=config['Train'],
        validation_config=config['Validation'],
        record_config=config['Record'],
        device=device,
        used_domain_list=config['Train']['used_domain_list'],
    )

    trainer.train()
Exemplo n.º 5
0
def _write_generated_adapter_params(layers, src_adapter_name, trg_adapter_name):
    """Write the parameters generated by the ``src`` adapter into the ``trg`` adapter, layer by layer.

    For every layer, ``generate_param`` is called on the source adapter with
    the layer's whole ``adapter_layers`` collection. It returns
    ``(w1_weight, w1_bias, w2_weight, w2_bias)``. Those values are assigned to
    the target adapter's ``w_1``/``w_2`` ``.data``, with the weights transposed
    (dim 0 <-> dim 1) first. The layout generate_param returns is not visible
    here — TODO confirm.
    """
    for i, layer in enumerate(layers):
        print(i)
        adapter_layers = layer.adapters.adapter_layers
        w1_weight, w1_bias, w2_weight, w2_bias = adapter_layers[
            src_adapter_name].generate_param(adapter_layers)
        trg_adapter = adapter_layers[trg_adapter_name]
        trg_adapter.w_1.weight.data = w1_weight.transpose(0, 1)
        trg_adapter.w_1.bias.data = w1_bias
        trg_adapter.w_2.weight.data = w2_weight.transpose(0, 1)
        trg_adapter.w_2.bias.data = w2_bias


def main():
    """Materialize generated adapter weights into plain adapters and save them.

    Reads the YAML config whose path is ``sys.argv[1]``. Its
    ``generate_adapter`` section supplies ``model_name``, ``load_model_path``,
    ``generate_dict`` (a list of ``{'src': ..., 'trg': ...}`` items) and
    ``save_path``. The steps are:

    1. For every item, in both the encoder and decoder stacks, the ``trg``
       adapter receives the weights produced by the ``src`` adapter's
       ``generate_param``.
    2. ``sublayer_connection`` entries of ``trg`` (excluding names containing
       ``'generate'``) are copied from the matching ``src`` entries.
    3. All state-dict entries whose names contain ``'generate'`` are dropped
       before saving.
    """
    torch.manual_seed(3333)
    np.random.seed(3333)

    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
        # and raises TypeError in PyYAML >= 6; the config is plain data.
        config = yaml.safe_load(config_file)
    generate_config = config['generate_adapter']

    # ================================================================================== #
    # set the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # the vocabulary is required by build_model
    mt_data_loader = MTDataLoader(config)
    mt_data_loader.build_vocab()
    vocab = mt_data_loader.vocab

    model_builder = ModelBuilder()
    model = model_builder.build_model(
        model_name=generate_config['model_name'],
        model_config=config['Model'],
        vocab=vocab,
        device=device,
        load_pretrained=False,
        pretrain_path=None)
    model_dict = model.state_dict()

    # Load from the trained model. map_location lets a checkpoint saved on GPU
    # be loaded on a CPU-only machine.
    load_model_dict = torch.load(generate_config['load_model_path'],
                                 map_location=device)
    model_dict.update(load_model_dict)
    model.load_state_dict(model_dict)

    # from xx-generate to xx. Each stack is iterated over its own layers; the
    # decoder must not be bounded by the encoder's layer count.
    for generate_item in generate_config['generate_dict']:
        src_adapter_name = generate_item['src']
        trg_adapter_name = generate_item['trg']
        _write_generated_adapter_params(model.encoder.layers,
                                        src_adapter_name, trg_adapter_name)
        _write_generated_adapter_params(model.decoder.layers,
                                        src_adapter_name, trg_adapter_name)

    create_path(get_path_prefix(generate_config['save_path']))

    model_dict = model.state_dict()

    # copy the sublayer_connection parameters of src into trg
    for generate_item in generate_config['generate_dict']:
        src_adapter_name = generate_item['src']
        trg_adapter_name = generate_item['trg']
        for parameter_name in list(model_dict.keys()):
            if (trg_adapter_name in parameter_name
                    and 'sublayer_connection' in parameter_name
                    and 'generate' not in parameter_name):
                src_adapter_parameter_name = parameter_name.replace(
                    trg_adapter_name, src_adapter_name)
                print(parameter_name, src_adapter_parameter_name)
                model_dict[parameter_name] = model_dict[
                    src_adapter_parameter_name]

    # the generator parameters themselves are not saved
    model_dict = {k: v for k, v in model_dict.items() if 'generate' not in k}
    torch.save(model_dict, generate_config['save_path'])
Exemplo n.º 6
0
def main():
    """Train a ``transformer_with_split_position`` model with ``Split_Position_Trainer``.

    Reads the YAML config whose path is ``sys.argv[1]``. All model parameters
    are passed to the optimizer. A copy of the config is written to
    ``<Record.training_record_path>/model_config.txt``.
    """
    import shutil

    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
        # and raises TypeError in PyYAML >= 6; the config is plain data.
        config = yaml.safe_load(config_file)
    create_path(config['Record']['training_record_path'])

    # set random seed
    set_random_seed(config['Train']['random_seed'])

    # ================================================================================== #
    # set the device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load dataset
    print('load dataset ...')
    mt_data_loader = MTDataLoader(config)
    mt_data_loader.load_datasets(load_train=True,
                                 load_dev=True,
                                 load_test=False)
    mt_data_loader.build_vocab()
    mt_data_loader.build_iterators(device=device,
                                   build_train=True,
                                   build_dev=True,
                                   build_test=False)

    vocab = mt_data_loader.vocab
    train_iterators = mt_data_loader.train_iterators
    dev_iterators = mt_data_loader.dev_iterators
    dev_test_iterators = mt_data_loader.dev_test_iterators

    # make model and criterion
    model_builder = ModelBuilder()
    model = model_builder.build_model(
        model_name='transformer_with_split_position',
        model_config=config['Model'],
        vocab=vocab,
        device=device,
        load_pretrained=config['Train']['load_exist_model'],
        pretrain_path=config['Train']['model_load_path'])
    criterion = model_builder.build_criterion(
        criterion_config=config['Criterion'], vocab=vocab)

    # make optimizer
    optimizer = model_builder.build_optimizer(
        parameters=model.parameters(),
        optimizer_config=config['Optimizer'],
        load_pretrained=config['Train']['load_optimizer'],
        pretrain_path=config['Train']['optimizer_path'])
    # make lr scheduler
    lr_scheduler = model_builder.build_lr_scheduler(
        optimizer=optimizer,
        lr_scheduler_config=config['Optimizer']['lr_scheduler'],
        load_pretrained=config['Train']['load_lr_scheduler'],
        pretrain_path=config['Train']['lr_scheduler_path'])

    # Keep a copy of the config next to the training records. shutil avoids
    # the shell, so paths with spaces or metacharacters work, and it raises
    # on failure instead of silently ignoring cp's exit status.
    shutil.copyfile(
        config_file_path,
        os.path.join(config['Record']['training_record_path'],
                     'model_config.txt'))

    trainer = Split_Position_Trainer(
        model=model,
        criterion=criterion,
        vocab=vocab,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_iterators=train_iterators,
        validation_iterators=dev_iterators,
        validation_test_iterators=dev_test_iterators,
        optimizer_config=config['Optimizer'],
        train_config=config['Train'],
        validation_config=config['Validation'],
        record_config=config['Record'],
        device=device,
    )

    trainer.train()
def main():
    """Train a ``transformer_with_mix_adapter_update`` model on a filtered parameter subset.

    Reads the YAML config whose path is ``sys.argv[1]``. A parameter stays
    trainable only if its name passes every filter in
    ``config['Train']['params']``:

    - a ``str`` filter must be a substring of the name;
    - a ``list`` filter needs at least one entry that is a substring of the
      name.

    Training is run by ``Mix_Adapter_Trainer``. A copy of the config is
    written to ``<Record.training_record_path>/model_config.txt``.
    """
    import shutil

    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
        # and raises TypeError in PyYAML >= 6; the config is plain data.
        config = yaml.safe_load(config_file)
    create_path(config['Record']['training_record_path'])

    # ================================================================================== #
    # set the random seed and device
    set_random_seed(config['Train']['random_seed'])
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load dataset
    print('load dataset ...')
    mt_data_loader = MTDataLoader(config)
    mt_data_loader.load_datasets(load_train=True,
                                 load_dev=True,
                                 load_test=False)
    mt_data_loader.build_vocab()
    mt_data_loader.build_iterators(device=device,
                                   build_train=True,
                                   build_dev=True,
                                   build_test=False)

    vocab = mt_data_loader.vocab
    train_iterators = mt_data_loader.train_iterators
    dev_iterators = mt_data_loader.dev_iterators
    dev_test_iterators = mt_data_loader.dev_test_iterators

    model_builder = ModelBuilder()
    model = model_builder.build_model(
        model_name='transformer_with_mix_adapter_update',
        model_config=config['Model'],
        vocab=vocab,
        device=device,
        load_pretrained=config['Train']['load_exist_model'],
        pretrain_path=config['Train']['model_load_path'])

    # the label_smoothing of the validation criterion is always 0
    criterion = model_builder.build_criterion(
        criterion_config=config['Criterion'], vocab=vocab)
    validation_criterion = model_builder.build_criterion(
        criterion_config={
            'name': 'kl_divergence',
            'label_smoothing': 0,
        },
        vocab=vocab)

    train_params = config['Train']['params']

    def _passes_param_filters(name):
        """Return True if ``name`` passes every filter in ``train_params``."""
        for param_filter in train_params:
            if isinstance(param_filter, str) and param_filter not in name:
                return False
            if isinstance(param_filter, list) and not any(
                    domain in name for domain in param_filter):
                return False
        return True

    # requires_grad is decided once per parameter from all filters combined
    for name, param in model.named_parameters():
        param.requires_grad = _passes_param_filters(name)

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.shape)

    # make optimizer over the trainable parameters only
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = model_builder.build_optimizer(
        parameters=parameters,
        optimizer_config=config['Optimizer'],
        load_pretrained=config['Train']['load_optimizer'],
        pretrain_path=config['Train']['optimizer_path'])
    # make lr scheduler
    lr_scheduler = model_builder.build_lr_scheduler(
        optimizer=optimizer,
        lr_scheduler_config=config['Optimizer']['lr_scheduler'],
        load_pretrained=config['Train']['load_lr_scheduler'],
        pretrain_path=config['Train']['lr_scheduler_path'])

    # Keep a copy of the config next to the training records. shutil avoids
    # the shell, so paths with spaces or metacharacters work, and it raises
    # on failure instead of silently ignoring cp's exit status.
    shutil.copyfile(
        config_file_path,
        os.path.join(config['Record']['training_record_path'],
                     'model_config.txt'))

    trainer = Mix_Adapter_Trainer(
        model=model,
        criterion=criterion,
        validation_criterion=validation_criterion,
        vocab=vocab,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_iterators=train_iterators,
        validation_iterators=dev_iterators,
        validation_test_iterators=dev_test_iterators,
        optimizer_config=config['Optimizer'],
        train_config=config['Train'],
        validation_config=config['Validation'],
        record_config=config['Record'],
        device=device,
    )

    trainer.train()
Exemplo n.º 8
0
def main():
    """Train a ``transformer`` with knowledge distillation from a reference model.

    Reads the YAML config whose path is ``sys.argv[1]``. A second
    ``transformer`` is built from ``config['Train']['ref_model_load_path']``
    and passed to ``Kd_Trainer`` together with ``ref_temperature`` and
    ``ref_factor``.

    If ``config['Train']`` has a ``params`` list, a parameter stays trainable
    only if its name passes every filter:

    - a ``str`` filter must be a substring of the name;
    - a ``list`` filter needs at least one entry that is a substring.

    A copy of the config is written to
    ``<Record.training_record_path>/model_config.txt``.
    """
    import shutil

    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
        # and raises TypeError in PyYAML >= 6; the config is plain data.
        config = yaml.safe_load(config_file)
    create_path(config['Record']['training_record_path'])

    # set random seed
    set_random_seed(config['Train']['random_seed'])

    # ================================================================================== #
    # set the device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load dataset
    print('load dataset ...')
    mt_data_loader = MTDataLoader(config)
    mt_data_loader.load_datasets(load_train=True, load_dev=True, load_test=False)
    mt_data_loader.build_vocab()
    mt_data_loader.build_iterators(device=device, build_train=True, build_dev=True, build_test=False)

    vocab = mt_data_loader.vocab
    train_iterators = mt_data_loader.train_iterators
    dev_iterators = mt_data_loader.dev_iterators
    dev_test_iterators = mt_data_loader.dev_test_iterators

    # make model
    model_builder = ModelBuilder()
    model = model_builder.build_model(model_name='transformer',
                                      model_config=config['Model'],
                                      vocab=vocab,
                                      device=device,
                                      load_pretrained=config['Train']['load_exist_model'],
                                      pretrain_path=config['Train']['model_load_path'])
    print('trained parameters: ')

    # reference (teacher) model for distillation
    ref_model = model_builder.build_model(model_name='transformer',
                                          model_config=config['Model'],
                                          vocab=vocab,
                                          device=device,
                                          load_pretrained=True,
                                          pretrain_path=config['Train']['ref_model_load_path'])

    if 'params' in config['Train']:
        train_params = config['Train']['params']

        def _passes_param_filters(name):
            """Return True if ``name`` passes every filter in ``train_params``."""
            for param_filter in train_params:
                if isinstance(param_filter, str) and param_filter not in name:
                    return False
                if isinstance(param_filter, list) and not any(
                        domain in name for domain in param_filter):
                    return False
            return True

        # requires_grad is decided once per parameter from all filters combined
        for name, param in model.named_parameters():
            param.requires_grad = _passes_param_filters(name)

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.shape)

    # make criterion
    # the label_smoothing of validation criterion is always set to 0
    criterion = model_builder.build_criterion(criterion_config=config['Criterion'], vocab=vocab)
    validation_criterion = model_builder.build_criterion(criterion_config={
        'name': 'kl_divergence',
        'label_smoothing': 0,
    }, vocab=vocab)

    # make optimizer. NOTE(review): unlike the adapter scripts, all parameters
    # are passed here, including frozen ones. Kept as is so that a previously
    # saved optimizer state still matches its parameter groups.
    optimizer = model_builder.build_optimizer(parameters=model.parameters(),
                                              optimizer_config=config['Optimizer'],
                                              load_pretrained=config['Train']['load_optimizer'],
                                              pretrain_path=config['Train']['optimizer_path'])
    # make lr scheduler
    lr_scheduler = model_builder.build_lr_scheduler(optimizer=optimizer,
                                                    lr_scheduler_config=config['Optimizer']['lr_scheduler'],
                                                    load_pretrained=config['Train']['load_lr_scheduler'],
                                                    pretrain_path=config['Train']['lr_scheduler_path'])

    # Keep a copy of the config next to the training records. shutil avoids
    # the shell, so paths with spaces or metacharacters work, and it raises
    # on failure instead of silently ignoring cp's exit status.
    shutil.copyfile(config_file_path,
                    os.path.join(config['Record']['training_record_path'], 'model_config.txt'))

    trainer = Kd_Trainer(
        model=model,
        criterion=criterion,
        validation_criterion=validation_criterion,
        vocab=vocab,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_iterators=train_iterators,
        validation_iterators=dev_iterators,
        validation_test_iterators=dev_test_iterators,
        optimizer_config=config['Optimizer'],
        train_config=config['Train'],
        validation_config=config['Validation'],
        record_config=config['Record'],
        device=device,
        ref_model=ref_model,
        ref_temperature=config['Train']['ref_temperature'],
        ref_factor=config['Train']['ref_factor'],
    )

    trainer.train()
def main():
    """Train a ``transformer_with_mix_adapter`` model in the configured stage.

    Reads the YAML config whose path is ``sys.argv[1]``.
    ``config['Train']['stage']`` selects both the trainable parameters and
    the trainer:

    - ``'classify'``: params named ``*classifier*``; ``ClassifierTrainer``.
    - ``'mixture_of_experts'``: params named ``*adapter*``;
      ``Mix_Adapter_Trainer``.
    - ``'train_gate'``: ``*inner_gate*`` params of the training domains;
      ``Mix_Adapter_Trainer``.
    - ``'mix_adapter_translation'``: ``*adapter*`` params of the training
      domains; ``Mix_Adapter_Trainer``.
    - ``'kd'``: ``*adapter*`` params of the training domains;
      ``Kd_Adapter_Trainer``.
    - any other stage: ``*adapter*`` params of the training domains;
      ``Adapter_Trainer``.

    "Training domains" means names containing any entry of
    ``config['Dataset']['train_dataset_domain']``. A copy of the config is
    written to ``<Record.training_record_path>/model_config.txt``.
    """
    import shutil

    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
        # and raises TypeError in PyYAML >= 6; the config is plain data.
        config = yaml.safe_load(config_file)
    create_path(config['Record']['training_record_path'])

    # ================================================================================== #
    # set the random seed and device
    set_random_seed(config['Train']['random_seed'])
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load dataset
    print('load dataset ...')
    mt_data_loader = MTDataLoader(config)
    mt_data_loader.load_datasets(load_train=True,
                                 load_dev=True,
                                 load_test=False)
    mt_data_loader.build_vocab()
    mt_data_loader.build_iterators(device=device,
                                   build_train=True,
                                   build_dev=True,
                                   build_test=False)

    vocab = mt_data_loader.vocab
    train_iterators = mt_data_loader.train_iterators
    train_iterator_domain = config['Dataset']['train_dataset_domain']
    dev_iterators = mt_data_loader.dev_iterators
    dev_test_iterators = mt_data_loader.dev_test_iterators
    dev_iterator_domain = config['Dataset']['dev_dataset_domain']

    model_builder = ModelBuilder()
    model = model_builder.build_model(
        model_name='transformer_with_mix_adapter',
        model_config=config['Model'],
        vocab=vocab,
        device=device,
        load_pretrained=config['Train']['load_exist_model'],
        pretrain_path=config['Train']['model_load_path'])
    model.classify_domain_mask = model.classify_domain_mask.to(device)

    # the label_smoothing of the validation criterion is always 0
    criterion = model_builder.build_criterion(
        criterion_config=config['Criterion'], vocab=vocab)
    validation_criterion = model_builder.build_criterion(
        criterion_config={
            'name': 'kl_divergence',
            'label_smoothing': 0,
        },
        vocab=vocab)

    training_stage = config['Train']['stage']

    def _in_train_domain(name):
        """Return True if ``name`` mentions any of the training domains."""
        return any(domain in name for domain in train_iterator_domain)

    def _is_trainable(name):
        """Return whether parameter ``name`` is updated in ``training_stage``."""
        if training_stage == 'classify':
            return 'classifier' in name
        if training_stage == 'mixture_of_experts':
            return 'adapter' in name
        if training_stage == 'train_gate':
            return 'inner_gate' in name and _in_train_domain(name)
        # any other stage: update the adapters of the training domains
        return 'adapter' in name and _in_train_domain(name)

    for name, param in model.named_parameters():
        param.requires_grad = _is_trainable(name)

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.shape)

    # make optimizer over the trainable parameters only
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = model_builder.build_optimizer(
        parameters=parameters,
        optimizer_config=config['Optimizer'],
        load_pretrained=config['Train']['load_optimizer'],
        pretrain_path=config['Train']['optimizer_path'])
    # make lr scheduler
    lr_scheduler = model_builder.build_lr_scheduler(
        optimizer=optimizer,
        lr_scheduler_config=config['Optimizer']['lr_scheduler'],
        load_pretrained=config['Train']['load_lr_scheduler'],
        pretrain_path=config['Train']['lr_scheduler_path'])

    # Keep a copy of the config next to the training records. shutil avoids
    # the shell, so paths with spaces or metacharacters work, and it raises
    # on failure instead of silently ignoring cp's exit status.
    shutil.copyfile(
        config_file_path,
        os.path.join(config['Record']['training_record_path'],
                     'model_config.txt'))

    if training_stage == 'classify':
        trainer = ClassifierTrainer(
            model=model,
            criterion=criterion,
            vocab=vocab,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            train_iterators=train_iterators,
            train_iterators_domain_list=train_iterator_domain,
            validation_iterators=dev_iterators,
            validation_iterators_domain_list=dev_iterator_domain,
            domain_dict=config['Model']['domain_dict'],
            optimizer_config=config['Optimizer'],
            train_config=config['Train'],
            validation_config=config['Validation'],
            record_config=config['Record'],
            device=device,
        )
    elif training_stage in ('mix_adapter_translation', 'train_gate',
                            'mixture_of_experts'):
        trainer = Mix_Adapter_Trainer(
            model=model,
            criterion=criterion,
            validation_criterion=validation_criterion,
            vocab=vocab,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            train_iterators=train_iterators,
            validation_iterators=dev_iterators,
            validation_test_iterators=dev_test_iterators,
            optimizer_config=config['Optimizer'],
            train_config=config['Train'],
            validation_config=config['Validation'],
            record_config=config['Record'],
            device=device,
            target_domain=config['Train']['target_domain'],
            used_domain_list=config['Train']['used_domain_list'],
            used_inner_gate=config['Train']['used_inner_gate'],
        )
    elif training_stage == 'kd':
        trainer = Kd_Adapter_Trainer(
            model=model,
            criterion=criterion,
            validation_criterion=validation_criterion,
            vocab=vocab,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            train_iterators=train_iterators,
            validation_iterators=dev_iterators,
            validation_test_iterators=dev_test_iterators,
            optimizer_config=config['Optimizer'],
            train_config=config['Train'],
            validation_config=config['Validation'],
            record_config=config['Record'],
            device=device,
            target_domain=config['Train']['target_domain'],
            ref_domain_dict=config['Train']['kd_ref_domain'])
    else:
        trainer = Adapter_Trainer(
            model=model,
            criterion=criterion,
            validation_criterion=validation_criterion,
            vocab=vocab,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            train_iterators=train_iterators,
            validation_iterators=dev_iterators,
            validation_test_iterators=dev_test_iterators,
            optimizer_config=config['Optimizer'],
            train_config=config['Train'],
            validation_config=config['Validation'],
            record_config=config['Record'],
            device=device,
            target_domain=config['Train']['target_domain'],
        )

    trainer.train()
Exemplo n.º 10
0
def main():
    """Train a ``transformer_language_model`` with the generic ``Trainer``.

    Reads the YAML config whose path is ``sys.argv[1]``, loads the LM
    datasets, and prints the text of up to five examples from the first
    training dataset as a sanity check. It then trains all model parameters.
    A copy of the config is written to
    ``<Record.training_record_path>/model_config.txt``.
    """
    import shutil
    from itertools import islice

    config_file_path = sys.argv[1]

    print('read config')
    with open(config_file_path, 'r') as config_file:
        # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
        # and raises TypeError in PyYAML >= 6; the config is plain data.
        config = yaml.safe_load(config_file)
    create_path(config['Record']['training_record_path'])

    # set random seed
    set_random_seed(config['Train']['random_seed'])

    # ================================================================================== #
    # set the device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load dataset
    print('load dataset ...')
    lm_data_loader = LMDataLoader(config)
    lm_data_loader.load_datasets(load_train=True,
                                 load_dev=True,
                                 load_test=False)
    lm_data_loader.build_vocab()
    lm_data_loader.build_iterators(device=device,
                                   build_train=True,
                                   build_dev=True,
                                   build_test=False)

    # Show a few examples; islice avoids an IndexError when the dataset has
    # fewer than five of them.
    for example in islice(lm_data_loader.train_datasets[0].examples, 5):
        print(example.text)

    vocab = lm_data_loader.vocab
    train_iterators = lm_data_loader.train_iterators
    dev_iterators = lm_data_loader.dev_iterators

    # make model
    model_builder = ModelBuilder()
    model = model_builder.build_model(
        model_name='transformer_language_model',
        model_config=config['Model'],
        vocab=vocab,
        device=device,
        load_pretrained=config['Train']['load_exist_model'],
        pretrain_path=config['Train']['model_load_path'])
    print('trained parameters: ')
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.shape)

    # make criterion
    # the label_smoothing of validation criterion is always set to 0
    criterion = model_builder.build_criterion(
        criterion_config=config['Criterion'], vocab=vocab)
    validation_criterion = model_builder.build_criterion(
        criterion_config={
            'name': 'kl_divergence',
            'label_smoothing': 0,
        },
        vocab=vocab)

    # make optimizer
    optimizer = model_builder.build_optimizer(
        parameters=model.parameters(),
        optimizer_config=config['Optimizer'],
        load_pretrained=config['Train']['load_optimizer'],
        pretrain_path=config['Train']['optimizer_path'])
    # make lr scheduler
    lr_scheduler = model_builder.build_lr_scheduler(
        optimizer=optimizer,
        lr_scheduler_config=config['Optimizer']['lr_scheduler'],
        load_pretrained=config['Train']['load_lr_scheduler'],
        pretrain_path=config['Train']['lr_scheduler_path'])

    # Keep a copy of the config next to the training records. shutil avoids
    # the shell, so paths with spaces or metacharacters work, and it raises
    # on failure instead of silently ignoring cp's exit status.
    shutil.copyfile(
        config_file_path,
        os.path.join(config['Record']['training_record_path'],
                     'model_config.txt'))

    trainer = Trainer(
        model=model,
        criterion=criterion,
        validation_criterion=validation_criterion,
        vocab=vocab,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_iterators=train_iterators,
        validation_iterators=dev_iterators,
        optimizer_config=config['Optimizer'],
        train_config=config['Train'],
        validation_config=config['Validation'],
        record_config=config['Record'],
        device=device,
    )

    trainer.train()