def __init__(
    self,
    cfg: Wav2BartChrConfig,
    dictionary=None,
    embed_tokens=None,
    no_encoder_attn=False,
):
    """Decoder that clones a pretrained BART decoder, then re-targets its
    output layer onto the task dictionary.

    Args:
        cfg: configuration carrying ``bart_path`` (checkpoint directory).
        dictionary: target (task) dictionary; sizes the new output layer.
        embed_tokens: unused here, kept for interface compatibility.
        no_encoder_attn: unused here, kept for interface compatibility.
    """
    super().__init__(dictionary)
    self.cfg = cfg

    from fairseq.models.bart import BARTModel

    # Prefer the checkpoint configured in cfg; fall back to the relative default.
    if os.path.isfile(os.path.join(cfg.bart_path, 'model.pt')):
        print('loading bart from cfg path')
        bart = BARTModel.from_pretrained(cfg.bart_path, checkpoint_file='model.pt')
    else:
        print('loading bart from relative path')
        bart = BARTModel.from_pretrained('models/bart.base', checkpoint_file='model.pt')

    pretrained_decoder = bart.model.decoder
    bart_dictionary_size = len(pretrained_decoder.dictionary)

    # Rebuild an identical TransformerDecoder and copy the pretrained weights in.
    self.decoder = TransformerDecoder(pretrained_decoder.args,
                                      pretrained_decoder.dictionary,
                                      pretrained_decoder.embed_tokens)
    self.decoder.load_state_dict(pretrained_decoder.state_dict())

    # "Dirty hack": untie input/output embeddings and map the decoder's
    # BART-vocab-sized output onto the task dictionary size.
    self.decoder.share_input_output_embed = False
    self.output_projection = nn.Linear(bart_dictionary_size, len(dictionary),
                                       bias=False)
    nn.init.normal_(self.output_projection.weight,
                    mean=0,
                    std=bart_dictionary_size ** -0.5)
def __init__(
    self,
    cfg: Wav2BartPoolConfig,
    dictionary=None,
    embed_tokens=None,
    no_encoder_attn=False,
):
    """Decoder that clones a pretrained BART decoder (architecture + weights)
    into a fresh TransformerDecoder.

    Args:
        cfg: configuration carrying ``bart_path`` (checkpoint directory).
        dictionary: task dictionary forwarded to the base class.
        embed_tokens: unused here, kept for interface compatibility.
        no_encoder_attn: unused here, kept for interface compatibility.
    """
    super().__init__(dictionary)
    self.cfg = cfg

    from fairseq.models.bart import BARTModel

    # Prefer the checkpoint configured in cfg; fall back to the relative default.
    if os.path.isfile(os.path.join(cfg.bart_path, 'model.pt')):
        print('loading bart from cfg path')
        bart = BARTModel.from_pretrained(cfg.bart_path, checkpoint_file='model.pt')
    else:
        print('loading bart from relative path')
        bart = BARTModel.from_pretrained('models/bart.base', checkpoint_file='model.pt')

    pretrained_decoder = bart.model.decoder
    self.decoder = TransformerDecoder(pretrained_decoder.args,
                                      pretrained_decoder.dictionary,
                                      pretrained_decoder.embed_tokens)
    self.decoder.load_state_dict(pretrained_decoder.state_dict())
def __init__(self, args):
    """Set up the topic classifier, entity extractor and generation models.

    Args:
        args: namespace carrying model types/paths and decoding
            hyper-parameters (length, temperature, top_k, top_p, ...).
    """
    self.gen_model_type = args.gen_model_type.lower()
    self.gen_model_path = args.gen_model_path
    self.conv_line_path = args.conv_line_path
    self.gen_length = args.length
    self.temperature = args.temperature
    self.top_k = args.top_k
    self.top_p = args.top_p
    self.stop_token = args.stop_token
    self.repetition_penalty = args.repetition_penalty
    self.device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")

    # Topic id -> human-readable topic label.
    self.lookup = {
        '1': 'Fashion',
        '2': 'Politics',
        '3': 'Books',
        '4': 'Sports',
        '5': 'General Entertainment',
        '6': 'Music',
        '7': 'Science & Technology',
        '8': 'Movie',
        '9': 'General'
    }

    self.topic_cls = BertClassificationPredictor(
        model_path=args.topic_cls_path,
        label_path=args.label_dir,  # directory for labels.csv file
        multi_label=False,
        model_type='bert',
        do_lower_case=True)

    self.entity_ext_model = AutoModelForTokenClassification.from_pretrained(
        "dbmdz/bert-large-cased-finetuned-conll03-english")
    self.entity_ext_model.to(self.device)
    self.entity_ext_tokenizer = AutoTokenizer.from_pretrained(
        "bert-base-cased")

    if self.gen_model_type == 'dialogpt':
        self.gen_tokenizer = AutoTokenizer.from_pretrained(
            self.gen_model_path)
        self.gen_model = AutoModelWithLMHead.from_pretrained(
            self.gen_model_path)
        # BUGFIX: was .cuda(), which crashes on CPU-only hosts even though
        # self.device already falls back to "cpu".
        self.gen_model.to(self.device)
        self.gen_model.eval()
    elif self.gen_model_type == 'bart':
        self.gen_model = BARTModel.from_pretrained(
            self.gen_model_path,
            checkpoint_file='checkpoint_best.pt',
            data_name_or_path=self.gen_model_path)
        self.gen_model.to(self.device)
        self.gen_model.eval()

    self.conv_line = BARTModel.from_pretrained(
        self.conv_line_path,
        checkpoint_file='checkpoint_best.pt',
        data_name_or_path=self.conv_line_path)
    self.conv_line.to(self.device)
    self.conv_line.eval()
def sanity_check(model_name_or_path, checkpoint_file):
    """Assert the loaded multilingual checkpoint has the expected dictionary:
    50008 symbols, with specials and language tags at fixed indices."""
    from fairseq.models.bart import BARTModel

    bart = BARTModel.from_pretrained(
        model_name_or_path,
        checkpoint_file=checkpoint_file,
        data_name_or_path='../data/processed/binary',
        user_dir='../../source',
        task="translation_multi_simple_epoch_extended",
        decoder_langtok=True,
        lang_pairs='java-en_XX',
        lang_dict='lang_dict.txt')

    src_dict = bart.task.source_dictionary
    assert len(src_dict) == 50008

    # Index -> expected symbol (specials first, then language tags).
    expected_symbols = {
        0: '<s>',
        1: '<pad>',
        2: '</s>',
        3: '<unk>',
        50001: '__java__',
        50002: '__python__',
        50003: '__en_XX__',
        50004: '__javascript__',
        50005: '__php__',
        50006: '__ruby__',
        50007: '__go__',
    }
    for index, symbol in expected_symbols.items():
        assert src_dict[index] == symbol
def __init__(self, squad_dir='./qa_models/squad1.0',
             bart_qa_dir='./bart_qg/checkpoints/', use_gpu=False):
    """Load the BART question-generation model plus NLP tooling, and build
    the SQuAD evaluation command template.

    Args:
        squad_dir: directory holding run_squad.py and the QA model.
        bart_qa_dir: directory holding the BART QG checkpoint.
        use_gpu: when True, move the QG model to GPU in fp16.
    """
    self.qg_model = BARTModel.from_pretrained(
        bart_qa_dir, checkpoint_file='checkpoint_best.pt')
    if use_gpu:
        self.qg_model.cuda()
        self.qg_model.half()
    self.qg_model.eval()

    # Decoding hyper-parameters.
    self.batch_size = 64
    self.beam_size = 10
    self.max_length = 100

    self.nlp = spacy.load('en_core_web_sm')
    self.parser = benepar.Parser("benepar_en2")
    self.stop_words = set(stopwords.words('english'))

    # Shell command template; the two bare '{}' slots (predict file,
    # output dir) are filled in later.
    self.squad_cmd = ' '.join([
        'python {}/run_squad.py'.format(squad_dir),
        '--model_type bert',
        '--model_name_or_path {}'.format(squad_dir),
        '--do_eval',
        '--overwrite_cache',
        '--do_lower_case',
        '--predict_file {}',
        '--per_gpu_train_batch_size 12',
        '--max_seq_length 384',
        '--doc_stride 128',
        '--output_dir {}',
    ])
def __init__(self, device='cpu', qa_model_name="deepset/minilm-uncased-squad2",
             qg_model_dir='../feqa/bart_qg/checkpoints/'):
    """Load the BART question generator and a HuggingFace QA pipeline.

    Args:
        device: 'cpu' or 'cuda'; on 'cuda' the QG model also runs in fp16.
        qa_model_name: HF model id used for both QA model and tokenizer.
        qg_model_dir: directory holding the BART QG checkpoint.
    """
    self.qg_model = BARTModel.from_pretrained(
        qg_model_dir, checkpoint_file='checkpoint_best.pt')
    if device == 'cuda':
        self.qg_model.to(device)
        self.qg_model.half()
    self.qg_model.eval()

    # Decoding hyper-parameters.
    self.batch_size = 1
    self.beam_size = 10
    self.max_length = 100

    self.nlp = spacy.load('en_core_web_sm')
    self.stop_words = set(stopwords.words('english'))

    # Below this answer score, the generated question is considered too vague.
    self.qa_threshold = 0.1
    self.qa_pipeline = pipeline('question-answering',
                                model=qa_model_name,
                                tokenizer=qa_model_name)
def get_model():
    """Load the fine-tuned BART model (with its sentencepiece BPE) from
    ./model and return it in evaluation mode."""
    base_dir = './model'
    model = BARTModel.from_pretrained(
        base_dir,
        checkpoint_file='checkpoint_best.pt',
        bpe='sentencepiece',
        sentencepiece_model=base_dir + '/sentence.bpe.model',
    )
    model.eval()
    return model
def build_model(cls, cfg: WavBart2BartConfig, task: FairseqTask):
    """Build a new model instance: a pretrained BART wired into a
    wav-to-text encoder/decoder pair."""
    from fairseq.models.bart import BARTModel

    # Prefer the checkpoint configured in cfg; fall back to the relative default.
    if os.path.isfile(os.path.join(cfg.bart_path, 'model.pt')):
        print('loading bart from cfg path')
        bart = BARTModel.from_pretrained(cfg.bart_path, checkpoint_file='model.pt')
    else:
        print('loading bart from relative path')
        bart = BARTModel.from_pretrained('models/bart.base', checkpoint_file='model.pt')

    assert cfg.autoregressive, "Please set task.autoregressive=true for seq2seq asr models"

    src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
    encoder = cls.build_encoder(cfg, bart)
    decoder = cls.build_decoder(cfg, bart)
    return WavBart2Bart(encoder, decoder)
def main():
    """
    Usage::

        python examples/bart/summarize.py \
            --model-dir $HOME/bart.large.cnn \
            --model-file model.pt \
            --src $HOME/data-bin/cnn_dm/test.source
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-dir",
        required=True,
        type=str,
        default="bart.large.cnn/",
        help="path containing model file and src_dict.txt",
    )
    parser.add_argument(
        "--model-file",
        default="checkpoint_best.pt",
        help="where in model_dir are weights saved",
    )
    parser.add_argument(
        "--src", default="test.source", help="text to summarize", type=str
    )
    parser.add_argument(
        "--out", default="test.hypo", help="where to save summaries", type=str
    )
    # BUGFIX: help text was copy-pasted from --out ("where to save summaries").
    parser.add_argument("--bsz", default=32, help="batch size", type=int)
    parser.add_argument(
        "--n", default=None, help="how many examples to summarize", type=int
    )
    parser.add_argument(
        "--xsum-kwargs",
        action="store_true",
        default=False,
        help="if true use XSUM_KWARGS else CNN_KWARGS",
    )
    args = parser.parse_args()
    eval_kwargs = XSUM_KWARGS if args.xsum_kwargs else CNN_KWARGS

    # "pytorch/fairseq" means: fetch a published model via torch.hub instead
    # of loading a local checkpoint directory.
    if args.model_dir == "pytorch/fairseq":
        bart = torch.hub.load("pytorch/fairseq", args.model_file)
    else:
        bart = BARTModel.from_pretrained(
            args.model_dir,
            checkpoint_file=args.model_file,
            data_name_or_path=args.model_dir,
        )
    bart = bart.eval()
    if torch.cuda.is_available():
        bart = bart.cuda().half()
    generate(
        bart, args.src, bsz=args.bsz, n_obs=args.n, outfile=args.out, **eval_kwargs
    )
def __init__(self, model_path):
    """Load a BART summarization model and initialise batching state.

    Args:
        model_path: directory containing model.pt.
    """
    self.model = BARTModel.from_pretrained(model_path, checkpoint_file='model.pt')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.model.to(device)
    self.model.eval()
    # BUGFIX: only switch to fp16 on CUDA — half precision on CPU is poorly
    # supported (many ops have no half kernels) and was applied unconditionally.
    if device.type == "cuda":
        self.model.half()

    # Batching state for incremental summarization.
    self.count = 1
    self.bsz = 2
    self.summary_list = []
    self.slines = []
def build_model(cls, cfg: Wav2Vec2BartConfig, task: FairseqTask):
    """Build a new bart model instance with a shared vocab-transform layer."""
    from fairseq.models.bart import BARTModel

    # Prefer the checkpoint configured in cfg; fall back to the relative default.
    if os.path.isfile(os.path.join(cfg.bart_path, 'model.pt')):
        print('loading bart from cfg path')
        bart = BARTModel.from_pretrained(cfg.bart_path, checkpoint_file='model.pt')
    else:
        print('loading bart from relative path')
        bart = BARTModel.from_pretrained('models/bart.base', checkpoint_file='model.pt')

    assert cfg.autoregressive, "Please set task.autoregressive=true for seq2seq asr models"

    src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

    # Linear map from BART's vocabulary size onto the target dictionary size,
    # shared between encoder and decoder.
    bart_dict = bart.model.decoder.dictionary
    transform_embed = Linear(len(bart_dict), len(tgt_dict))

    encoder = cls.build_encoder(cfg, tgt_dict, transform_embed, bart)
    decoder = cls.build_decoder(cfg, transform_embed, bart)
    model = Wav2Vec2Bart(encoder, decoder)
    model.transform_embed = transform_embed
    return model
def __init__(self, init, text_logger=None):
    """Download (on first use) and wrap a public pretrained BART checkpoint.

    Args:
        init: one of 'bart.large' / 'bart.large.cnn'.
        text_logger: optional logger stored for later use.
    """
    super(BART, self).__init__()
    assert init in ['bart.large', 'bart.large.cnn']

    cache_dir = f'{os.getenv("HOME")}/.cache/'
    # Fetch and unpack the checkpoint into the cache if it is not there yet.
    if not os.path.exists(f'{cache_dir}/{init}'):
        os.system(f'wget https://dl.fbaipublicfiles.com/fairseq/models/'
                  f'{init}.tar.gz -P {cache_dir}')
        os.system(f'tar -xzvf {cache_dir}/{init}.tar.gz -C {cache_dir}')

    # Keep only the underlying torch module, not the hub wrapper.
    self._model = BARTModel.from_pretrained(f'{cache_dir}/{init}').model
    self._hparams = None
    self._text_logger = text_logger
def __init__(self, task, classification_head_name, actor_path, actor_file,
             max_tokens):
    """Criterion setup: load a frozen 'actor' BART model used for scoring.

    Args:
        task: fairseq task forwarded to the base criterion.
        classification_head_name: head used when scoring with the actor.
        actor_path: directory containing the actor checkpoint.
        actor_file: checkpoint filename inside actor_path.
        max_tokens: batching limit applied by the criterion.
    """
    super().__init__(task)
    self.regression_target = False
    self.classification_head_name = classification_head_name
    # Only the underlying torch module is needed, not the hub wrapper.
    self.actor = BARTModel.from_pretrained(
        actor_path, checkpoint_file=actor_file).model
    self.actor.eval()  # scoring only — keep dropout disabled
    self.max_tokens = max_tokens
def __init__(self, config, x_embed):
    """Sentence encoder backed by a pretrained BART model.

    Args:
        config: carries ``pretrained_weights`` (BART checkpoint directory).
        x_embed: unused here, kept for interface compatibility.
    """
    super().__init__()
    pretrained_weights = "xlnet-base-cased"
    # BUGFIX: an XLNetModel was previously loaded into self.model and then
    # immediately overwritten by the BART model below — a wasted download and
    # memory spike. Only the XLNet config is kept.
    self.pretrained_config = XLNetConfig.from_pretrained(
        pretrained_weights)
    self.model = BARTModel.from_pretrained(config.pretrained_weights,
                                           checkpoint_file='model.pt')
    self.model.eval(
    )  # disable dropout (or leave in train mode to finetune)
    # Output feature size exposed to downstream layers.
    self.encoder_out_size = 768
def __init__(self, model_name, num_labels, num_hidden_layers=12,
             hidden_size=1024):
    """BART backbone with learned per-layer weights and a classification head.

    Args:
        model_name: BART checkpoint directory/name for from_pretrained.
        num_labels: number of output classes.
        num_hidden_layers: number of hidden layers to weight.
        hidden_size: feature size feeding the classifier.
    """
    super(CustomBART, self).__init__()
    self.num_labels = num_labels
    self.bart = BARTModel.from_pretrained(model_name)
    self.dropout = nn.Dropout(p=0.2)
    self.high_dropout = nn.Dropout(p=0.5)

    # One weight per hidden layer plus one extra; all but the last start
    # at -3 (presumably normalised elsewhere — not visible in this block).
    n_weights = num_hidden_layers + 1
    weights_init = torch.zeros(n_weights).float()
    weights_init.data[:-1] = -3
    self.layer_weights = torch.nn.Parameter(weights_init)

    self.classifier = nn.Linear(hidden_size, num_labels)
def generate_TLDRs(bsz, count, datadir, outdir, checkpoint_dir, checkpoint_file,
                   test_fname, beam, lenpen, max_len_b, min_len,
                   no_repeat_ngram_size):
    """Batch-decode TLDR summaries for ``datadir/test.source`` and write one
    hypothesis per line to ``outdir/test_fname``.

    Args:
        bsz: batch size for decoding.
        count: running line counter (controls when a batch is flushed).
        datadir/outdir: input and output directories.
        checkpoint_dir/checkpoint_file: BART checkpoint location.
        test_fname: output filename.
        beam, lenpen, max_len_b, min_len, no_repeat_ngram_size: decoding params.
    """
    bart = BARTModel.from_pretrained(checkpoint_dir,
                                     checkpoint_file=checkpoint_file,
                                     data_name_or_path=datadir + '-bin',
                                     task='translation')
    if torch.cuda.is_available():
        bart.cuda()
        bart.half()
    bart.eval()

    def _decode_and_write(batch, fout):
        # One decode call per batch. BUGFIX: the batched path previously wrote
        # raw hypotheses, so an embedded newline broke the one-line-per-example
        # alignment; strip newlines on both paths (the tail path already did).
        with torch.no_grad():
            hypotheses_batch = bart.sample(
                batch,
                beam=beam,
                lenpen=lenpen,
                max_len_b=max_len_b,
                min_len=min_len,
                no_repeat_ngram_size=no_repeat_ngram_size)
        for hypothesis in hypotheses_batch:
            fout.write(hypothesis.replace('\n', ' ') + '\n')
        fout.flush()

    source_fname = join(datadir, 'test.source')
    pred_fname = join(outdir, test_fname)
    with open(source_fname, encoding="utf-8") as source, open(
            pred_fname, 'w', encoding="utf-8") as fout:
        sline = source.readline().strip()
        slines = [sline]
        for sline in tqdm(source):
            if count % bsz == 0:
                _decode_and_write(slines, fout)
                slines = []
            slines.append(sline.strip())
            count += 1
        # Flush the final partial batch.
        if slines:
            _decode_and_write(slines, fout)
def __init__(
    self,
    cfg: AudioPretrainingConfig,
):
    """Audio task init: loads a pretrained BART and adopts its target
    dictionary into the task state."""
    super().__init__(cfg)

    if cfg.eval_wer:
        assert cfg.labels is not None, "eval_wer can only be set during fine-tuning"
    self.blank_symbol = "<s>"

    # Debug output of the incoming configuration.
    print('cfg', cfg)
    print('cfg', cfg.bart_path)
    self.bart = BARTModel.from_pretrained(cfg.bart_path, checkpoint_file='model.pt')
    # Expose BART's target dictionary through the task state.
    self.state.merge_state_dict(
        {'target_dictionary': self.bart.task.target_dictionary})
def decode(args):
    """Decode ``args.data_dir/test.source`` with a fine-tuned BART model and
    write one stripped hypothesis per line to ``args.output_file``."""
    bart = BARTModel.from_pretrained(args.checkpoint_dir,
                                     checkpoint_file=args.checkpoint_file,
                                     data_name_or_path=args.data_name_or_path)
    bart.cuda()
    bart.eval()
    bart.half()

    seen = 1
    total_lines = count_file_lines('{}/test.source'.format(args.data_dir))
    with open('{}/test.source'.format(args.data_dir)) as source, \
            open(args.output_file, 'w') as fout:
        # Seed the batch with the first line, then stream the rest.
        batch = [source.readline().strip()]
        for raw_line in tqdm(source, total=total_lines):
            if seen % args.batch_size == 0:
                with torch.no_grad():
                    hypotheses = bart.sample(
                        batch,
                        beam=args.beam_size,
                        lenpen=args.lenpen,
                        max_len_b=args.max_len_b,
                        min_len=args.min_len,
                        no_repeat_ngram_size=args.no_repeat_ngram_size)
                for hypothesis in hypotheses:
                    fout.write(hypothesis.strip() + '\n')
                fout.flush()
                batch = []
            batch.append(raw_line.strip())
            seen += 1
        # Flush the final partial batch.
        if batch:
            hypotheses = bart.sample(
                batch,
                beam=args.beam_size,
                lenpen=args.lenpen,
                max_len_b=args.max_len_b,
                min_len=args.min_len,
                no_repeat_ngram_size=args.no_repeat_ngram_size)
            for hypothesis in hypotheses:
                fout.write(hypothesis.strip() + '\n')
            fout.flush()
def main():
    """Run correction inference over output/input.txt and print each input
    next to its reordered prediction."""
    bart = BARTModel.from_pretrained("ckpt", checkpoint_file="checkpoint_best.pt")
    bart.cuda()
    bart.half()
    bart.eval()

    with open("output/input.txt") as source:
        lines = [line.replace("\n", "") for line in source.readlines()]

    with torch.no_grad():
        preds = bart.sample(lines)

    for i, (line, pred) in enumerate(zip(lines, preds)):
        print(f"[ori] ({i+1}): {line}")
        print(f"[cor] ({i+1}): {reorder(pred)}")
        print()
def main():
    """Decode output/input.txt with the BART checkpoint and print each input
    alongside its completion."""
    bart = BARTModel.from_pretrained('ckpt_bart', checkpoint_file='checkpoint_best.pt')
    bart.cuda()
    bart.half()
    bart.eval()

    with open('output/input.txt') as source:
        lines = [line.replace("\n", "") for line in source.readlines()]

    with torch.no_grad():
        preds = bart.sample(lines, beam=4, lenpen=2.0,
                            no_repeat_ngram_size=2, temperature=0.9)

    for i, (line, pred) in enumerate(zip(lines, preds)):
        print(f"[ori] ({i+1}): {line}")
        print(f"[com] ({i+1}): {pred}")
        print()
def __init__(self, config, x_embed):
    """Sentence encoder backed by a pretrained BART model.

    Args:
        config: carries ``pretrained_weights`` (BART checkpoint directory).
        x_embed: unused here, kept for interface compatibility.
    """
    super().__init__()
    pretrained_weights = "xlnet-base-cased"
    # BUGFIX: an XLNetModel was previously loaded into self.model and then
    # immediately overwritten by the BART model below — a wasted download and
    # memory spike. Only the XLNet config is kept.
    self.pretrained_config = XLNetConfig.from_pretrained(
        pretrained_weights)
    self.model = BARTModel.from_pretrained(config.pretrained_weights,
                                           checkpoint_file='model.pt')
    self.model.eval(
    )  # disable dropout (or leave in train mode to finetune)
    # if config.use_gpu:
    #     self.model = self.model.to(device=torch.device("cuda"))
    # if config.use_parallel:
    #     self.model = torch.nn.DataParallel(self.model)
    # Output feature size exposed to downstream layers.
    self.encoder_out_size = 768
def evaluate():
    """Generate hypotheses for ./data/val.source and write one lower-cased
    hypothesis per line to ./data/val.hypo."""
    logger.info('***** Begin Evalution *****')
    bart = BARTModel.from_pretrained('checkpoints/',
                                     checkpoint_file='checkpoint_best.pt',
                                     data_name_or_path='bin')
    bart.cuda()
    bart.eval()
    with open('./data/val.source') as source, open('./data/val.hypo', 'w') as fout:
        for sline in source:
            sline = sline.strip().lower()
            with torch.no_grad():
                # BUGFIX: sample() expects a list of sentences; the raw string
                # was passed before, which iterates characters. Take the single
                # hypothesis back out of the returned batch.
                hypo = bart.sample([sline],
                                   beam=5,
                                   lenpen=2.0,
                                   max_len_b=100,
                                   min_len=20,
                                   no_repeat_ngram_size=3)[0]
            # BUGFIX: a newline separator was missing, so all hypotheses ran
            # together on one output line.
            fout.write(hypo.lower() + '\n')
def get_bart(folder_path, checkpoint_file):
    """Return a pretrained BART model in evaluation mode.

    Args:
        folder_path: str, path to BART's model, containing the checkpoint.
        checkpoint_file: str, name of BART's checkpoint file (starting from
            BART's folder).
    """
    from fairseq.models.bart import BARTModel

    bart = BARTModel.from_pretrained(model_name_or_path=folder_path + '/',
                                     checkpoint_file=checkpoint_file)
    # Use the GPU when one is available.
    if torch.cuda.is_available():
        bart.cuda()
        print("Using BART on GPU...")
    bart.eval()
    print("BART loaded (in evaluation mode).\n")
    return bart
def __init__(self, task, encoder_json, vocab_bpe, critic_path, critic_file,
             sentence_avg, max_tokens, print_update, critic_weight,
             label_smoothing, use_reward, rewarder_file, rewarder_weight):
    """Criterion with a frozen fp16 'critic' BART and an optional rewarder.

    Args:
        task: fairseq task forwarded to the base criterion.
        encoder_json/vocab_bpe: GPT-2 BPE resources.
        critic_path/critic_file: critic checkpoint location.
        sentence_avg, max_tokens, print_update, critic_weight,
        label_smoothing: training hyper-parameters.
        use_reward, rewarder_file, rewarder_weight: reward configuration.
    """
    super().__init__(task)
    self.eps = label_smoothing
    self.debugCount = 0
    self.bpe = GPT2BPE_modified(encoder_json, vocab_bpe)
    self.sentence_avg = sentence_avg
    self.max_tokens = max_tokens
    self.print_update = print_update
    self.critic_weight = critic_weight

    # Frozen critic; only the underlying torch module is kept, in fp16.
    self.critic = BARTModel.from_pretrained(
        critic_path, checkpoint_file=critic_file).model
    self.critic.half()
    self.critic.eval()

    self.use_rewarder = use_reward
    self.rewarder_weight = rewarder_weight
    if use_reward:
        self.padder = Padder(1024)
def convert_fairseq_model(args):
    """Convert a fairseq BART checkpoint into GluonNLP format, save config and
    parameters, then rename outputs to the hash-based naming convention."""
    if not args.save_dir:
        args.save_dir = os.path.basename(args.fairseq_model_path) + '_gluon'
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    fairseq_bart = fairseq_BARTModel.from_pretrained(
        args.fairseq_model_path, checkpoint_file='model.pt')

    # Vocabulary and model configuration.
    vocab_size = convert_vocab(args, fairseq_bart)
    gluon_cfg = convert_config(fairseq_bart.args, vocab_size,
                               BartModel.get_cfg().clone())
    with open(os.path.join(args.save_dir, 'model.yml'), 'w') as of:
        of.write(gluon_cfg.dump())

    # Parameter conversion (optionally numerically checked against the source).
    ctx = mx.gpu(args.gpu) if args.gpu is not None else mx.cpu()
    gluon_bart = convert_params(fairseq_bart, gluon_cfg, ctx)
    if args.test:
        test_model(fairseq_bart, gluon_bart, args.gpu)
    gluon_bart.save_parameters(os.path.join(args.save_dir, 'model.params'),
                               deduplicate=True)
    logging.info('Convert the BART MLM model in {} to {}'.format(
        os.path.join(args.fairseq_model_path, 'model.pt'),
        os.path.join(args.save_dir, 'model.params')))
    logging.info('Conversion finished!')
    logging.info('Statistics:')

    # Rename every produced file to <name>-<hash> form and log its size.
    for old_name in os.listdir(args.save_dir):
        new_name, long_hash = naming_convention(args.save_dir, old_name)
        old_path = os.path.join(args.save_dir, old_name)
        new_path = os.path.join(args.save_dir, new_name)
        shutil.move(old_path, new_path)
        file_size = os.path.getsize(new_path)
        logging.info('\t{}/{} {} {}'.format(args.save_dir, new_name,
                                            long_hash, file_size))
# Evaluation setup: parse CLI options, load the actor checkpoint, and count
# the number of test examples.
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument('--checkpoint-path', default='checkpoints/actor',
                    metavar='DIR', help='path to load checkpoint')
parser.add_argument('--checkpoint-file', default='checkpoint_best.pt',
                    metavar='DIR', help='file to load actor')
parser.add_argument('--dataset', default='cnn_dm',
                    metavar='DIR', help='dataset to evaluate')
parser.add_argument('--cpu', action='store_true',
                    help='use CPU instead of CUDA')
args, _ = parser.parse_known_args()

bart = BARTModel.from_pretrained(
    args.checkpoint_path,
    checkpoint_file=args.checkpoint_file,
    data_name_or_path=args.dataset + '-bin'
)
# fp16 only makes sense on CUDA, so it is skipped in --cpu mode.
if args.cpu:
    bart.cpu()
else:
    bart.cuda()
    bart.half()
bart.eval()

count = 1
bsz = 32
# BUGFIX: `sum(1 for _ in open(...))` left the file handle open; close it
# deterministically with a context manager.
with open(args.dataset + '/test.source') as _source:
    num_lines = sum(1 for _ in _source)
from fairseq.models.bart import BARTModel

# Load the fine-tuned sentence-pair classification checkpoint.
bart = BARTModel.from_pretrained(
    'checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='task-1-bin'
)


def label_fn(label):
    # Map a raw head index back to its label string, skipping the
    # dictionary's special symbols.
    return bart.task.label_dictionary.string(
        [label + bart.task.label_dictionary.nspecial]
    )


ncorrect, nsamples = 0, 0
bart.cuda()
bart.eval()

with open('./task_1_data/dev.tsv') as fin:
    fin.readline()  # skip the header row
    for index, line in enumerate(fin):
        tokens = line.strip().split('\t')
        idx, sent1, sent2, target = tokens
        tokens = bart.encode(sent1, sent2)
        prediction = bart.predict('sentence_classification_head',
                                  tokens).argmax().item()
        prediction_label = label_fn(prediction)
        if prediction_label == target:
            ncorrect += 1
        else:
            # Show the misclassified pair for inspection.
            print(sent1 + '\n' + sent2)
#!/usr/bin/env python """ created at: Mon 24 Aug 2020 04:35:36 AM EDT created by: Priyam Tejaswin (ptejaswi) Inference on test data. """ import pdb import torch from fairseq.models.bart import BARTModel from tqdm import tqdm bart = BARTModel.from_pretrained( './checkpoints/', checkpoint_file='checkpoint_best.pt', data_name_or_path= '/projects/metis1/users/ptejaswi/multistep-retrieve-summarize/models/bart.base/' ) bart.cuda() bart.eval() bart.half() count = 1 bsz = 32 with open('cnn_dm/test.source') as source, open('cnn_dm/test.hypo', 'w') as fout: sline = source.readline().strip() slines = [sline] for sline in tqdm(source): if count % bsz == 0: with torch.no_grad():
import torch
from fairseq.models.bart import BARTModel
import csv

# Load the subtask-C generation checkpoint in fp16 on GPU.
bart = BARTModel.from_pretrained(
    'checkpoints/new',
    checkpoint_file='checkpoint4.pt',
    data_name_or_path='preprocess_taskC_data/subtaskC_data-bin'
)
bart.cuda()
bart.eval()
bart.half()

count = 1
bsz = 32

with open('data/Test Data/subtaskC_test_data_new_plusplus.csv') as source, \
        open('subtaskC_generated/subtaskC_answers.csv', 'w') as fout:
    # Alternative inputs kept from earlier experiments:
    # with open('data/Dev Data/subtaskC_dev_data_new_plusplus.csv') as source, open('subtaskC_generated/trial_hypo.csv', 'w') as fout:
    # with open('data/Trial Data/taskC_trial_data.csv') as source, open('subtaskC_generated/trial_hypo.csv', 'w') as fout:
    source = csv.reader(source)
    fout = csv.writer(fout)
    next(source)  # skip the header row
    idx, false_sent, true_sent, evidence_from_wik = next(source)
    # idx, false_sent = next(source)
    # Normalise all-caps sentences before prompting the model.
    if false_sent.isupper():
        false_sent = false_sent.capitalize()
    if true_sent.isupper():
        true_sent = true_sent.capitalize()
    sline = false_sent
    # Prompt variants tried previously (should be modified as needed):
    # sline = 'The statement "' + false_sent + '" is absurd, because:'
    # sline = 'Context: ' + evidence_from_wik + ' | ' + 'The statement: ' + false_sent
    # sline = 'Context: ' + evidence_from_wik + ' | ' + 'The statement "' + false_sent + '" is absurd, because:'
# Remaining CLI options, checkpoint selection, and noising hyper-parameters.
parser.add_argument("--high-random-prob", type=float, default=0.4)
parser.add_argument("--random-word-span", type=int, default=0)
parser.add_argument("--gen-with-mbart", type=int, default=0)
parser.add_argument("--batch-size", type=int, default=96)
parser.add_argument("--model-path", type=str, default=None)
args = parser.parse_args()

# Fill your mbart or bart checkpoint path
if args.gen_with_mbart:
    bart = BARTModel.from_pretrained(
        args.model_path,
        checkpoint_file='model.pt',
        data_name_or_path=args.model_path,
        bpe="sentencepiece",
        layernorm_embedding=True,
    )
else:
    bart = BARTModel.from_pretrained(
        os.path.join('/checkpoint/chuntinz/container/bart.large'),
        checkpoint_file='model.pt',
        data_name_or_path=os.path.join('/checkpoint/chuntinz/container/gpt2_bpe')
    )

# BART-style span-masking noise configuration, taken from the CLI.
noise_params = {
    'mask_length': args.mask_length,
    'mask_ratio': args.mask_ratio,
    'random_ratio': args.random_ratio,
    'poisson_lambda': args.poisson_lambda,
}