Пример #1
0
def convert_to_dialogpt(args):
    """Load a Megatron checkpoint into a GPT-2 model and save its weights.

    A ``GPT2LMHeadModel`` is built from the JSON config at
    ``args.config_path``. The ``'model'`` entry of the checkpoint at
    ``args.megatron_checkpoint_path`` is passed through
    ``fix_state_dict_namespace`` and ``fix_model_shapes`` and then loaded
    into the model. When none of the resulting keys starts with
    ``'transformer.'`` and the model has a ``transformer`` attribute, the
    weights are loaded into (and saved from) ``model.transformer`` only.

    Args:
        args: Namespace-like object. This function reads ``config_path``,
            ``megatron_checkpoint_path`` and ``dialogpt_output_path``; the
            whole object is also forwarded to ``load_model``.

    Returns:
        None. The state dict is written to ``args.dialogpt_output_path``.
    """
    config = GPT2Config.from_json_file(args.config_path)
    model = load_model(GPT2LMHeadModel(config), None, args, verbose=True)

    # Deserialize onto the CPU so that a checkpoint saved from GPU tensors
    # can be converted on a machine without that device. load_state_dict
    # copies the values into the model's own parameters afterwards.
    checkpoint = torch.load(args.megatron_checkpoint_path,
                            map_location='cpu')

    model_state_dict = fix_state_dict_namespace(checkpoint['model'])
    model_state_dict = fix_model_shapes(model_state_dict)

    # If no key carries the 'transformer.' prefix, the weights belong to
    # the inner transformer module rather than to the LM-head wrapper.
    start_model = model
    has_prefixed_key = any(key.startswith('transformer.')
                           for key in model_state_dict)
    if hasattr(model, "transformer") and not has_prefixed_key:
        logger.info('loading transfomer only')
        start_model = model.transformer
    start_model.load_state_dict(model_state_dict)

    torch.save(start_model.state_dict(), args.dialogpt_output_path)
Пример #2
0
# Use the configured log directory when a non-empty one was given;
# otherwise fall back to the output directory.
if args.log_dir is not None and len(args.log_dir) > 0:
    log_dir = args.log_dir
else:
    log_dir = output_dir

# The output directory is created here only for local_rank == -1.
if args.local_rank == -1:
    os.makedirs(output_dir, exist_ok=True)

# Log every attribute of args, one name/value pair per line.
logger.info('Input Argument Information')
args_dict = vars(args)
for a, arg_value in args_dict.items():
    logger.info('%-28s  %s' % (a, arg_value))

#########################################################################
# Prepare Data Set
##########################################################################
# The tokenizer and the model config are both read from
# args.model_name_or_path.
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
config = GPT2Config.from_json_file(
    join(args.model_name_or_path, 'config.json'))

# NOTE(review): for local_rank != -1 nothing is assigned to
# train_dataloader in this section, because the distributed loader below
# is disabled. Confirm that later code does not rely on it in that mode.
if args.local_rank == -1:
    train_dataloader = BucketingDataLoader(
        args.train_input_file,
        args.train_batch_size,
        args.max_seq_length,
    )
# Disabled distributed variant:
# else:
#     train_dataloader = DistributedBucketingDataLoader(
#         get_rank(), get_world_size(),
#         args.train_input_file, args.train_batch_size,
#         args.max_seq_length)

eval_dataloader_loss = DynamicBatchingLoader(args.eval_input_file, enc,
                                             args.normalize_data,
                                             args.eval_batch_size,
Пример #3
0
# Use the configured log directory when a non-empty one was given;
# otherwise fall back to the output directory.
if args.log_dir is not None and len(args.log_dir) > 0:
    log_dir = args.log_dir
else:
    log_dir = output_dir

# Create the output directory for local_rank == -1, or when get_rank()
# reports rank 0.
if args.local_rank == -1 or get_rank() == 0:
    os.makedirs(output_dir, exist_ok=True)

# Log every attribute of args, one name/value pair per line.
logger.info('Input Argument Information')
args_dict = vars(args)
for a, arg_value in args_dict.items():
    logger.info('%-28s  %s' % (a, arg_value))

#########################################################################
# Prepare Data Set
##########################################################################
# The tokenizer is built from a vocabulary file; the model config is read
# from a separate JSON path.
enc = RubertaTokenizer(vocab_file=args.tokenizer_path)
config = GPT2Config.from_json_file(args.config_path)

# Choose the training loader: the distributed variant additionally takes
# the rank and world size; local_rank == -1 uses the plain loader.
if args.local_rank != -1:
    train_dataloader = DistributedBucketingDataLoader(
        get_rank(),
        get_world_size(),
        args.train_input_file,
        args.train_batch_size,
        args.max_seq_length,
    )
else:
    train_dataloader = BucketingDataLoader(
        args.train_input_file,
        args.train_batch_size,
        args.max_seq_length,
    )

eval_dataloader_loss = DynamicBatchingLoader(args.eval_input_file, enc,
                                             args.normalize_data,
                                             args.eval_batch_size,