Example no. 1
    def __init__(
        self,
        cfg: Wav2BartChrConfig,
        dictionary=None,
        embed_tokens=None,
        no_encoder_attn=False,
    ):
        super().__init__(dictionary)
        self.cfg = cfg
        # bart = torch.hub.load('pytorch/fairseq', 'bart.base')
        from fairseq.models.bart import BARTModel
        if os.path.isfile(os.path.join(cfg.bart_path, 'model.pt')):
            print('loading bart from cfg path')
            bart = BARTModel.from_pretrained(cfg.bart_path, checkpoint_file='model.pt')
        else:
            print('loading bart from relative path')
            bart = BARTModel.from_pretrained('models/bart.base', checkpoint_file='model.pt')
        
        bart_decoder = bart.model.decoder
        bart_dictionary_size = len(bart_decoder.dictionary)
        self.decoder = TransformerDecoder(bart_decoder.args, bart_decoder.dictionary, bart_decoder.embed_tokens)
        self.decoder.load_state_dict(bart_decoder.state_dict())

        # self.output_embed_dim = cfg.decoder_embed_dim

        # NOTE: dirty hack to swap out the decoder's output embedding so that
        # logits over BART's vocabulary can be projected onto the target dictionary
        self.decoder.share_input_output_embed = False

        self.output_projection = nn.Linear(
            bart_dictionary_size, len(dictionary), bias=False
        )
        nn.init.normal_(
            self.output_projection.weight, mean=0, std=bart_dictionary_size ** -0.5
        )
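For context, a minimal standalone sketch of how such a vocabulary-remapping projection behaves at the tensor level; the shapes and variable names here are illustrative assumptions, not part of the original model:

import torch
import torch.nn as nn

# Assumed sizes: BART's subword vocabulary projected onto a small
# character-level target dictionary, mirroring the constructor above.
bart_vocab, target_vocab = 51201, 32
output_projection = nn.Linear(bart_vocab, target_vocab, bias=False)
nn.init.normal_(output_projection.weight, mean=0, std=bart_vocab ** -0.5)

decoder_logits = torch.randn(2, 10, bart_vocab)    # (batch, tgt_len, bart_vocab)
target_logits = output_projection(decoder_logits)  # (batch, tgt_len, target_vocab)
print(target_logits.shape)                         # torch.Size([2, 10, 32])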
Example no. 2
    def __init__(
        self,
        cfg: Wav2BartPoolConfig,
        dictionary=None,
        embed_tokens=None,
        no_encoder_attn=False,
    ):
        super().__init__(dictionary)
        self.cfg = cfg
        # bart = torch.hub.load('pytorch/fairseq', 'bart.base')
        from fairseq.models.bart import BARTModel
        if os.path.isfile(os.path.join(cfg.bart_path, 'model.pt')):
            print('loading bart from cfg path')
            bart = BARTModel.from_pretrained(cfg.bart_path,
                                             checkpoint_file='model.pt')
        else:
            print('loading bart from relative path')
            bart = BARTModel.from_pretrained('models/bart.base',
                                             checkpoint_file='model.pt')

        bart_decoder = bart.model.decoder
        self.decoder = TransformerDecoder(bart_decoder.args,
                                          bart_decoder.dictionary,
                                          bart_decoder.embed_tokens)
        self.decoder.load_state_dict(bart_decoder.state_dict())
Example no. 3
    def __init__(self, args):
        self.gen_model_type = args.gen_model_type.lower()
        self.gen_model_path = args.gen_model_path
        self.conv_line_path = args.conv_line_path
        self.gen_length = args.length
        self.temperature = args.temperature
        self.top_k = args.top_k
        self.top_p = args.top_p
        self.stop_token = args.stop_token
        self.repetition_penalty = args.repetition_penalty
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.lookup = {
            '1': 'Fashion',
            '2': 'Politics',
            '3': 'Books',
            '4': 'Sports',
            '5': 'General Entertainment',
            '6': 'Music',
            '7': 'Science & Technology',
            '8': 'Movie',
            '9': 'General'
        }
        self.topic_cls = BertClassificationPredictor(
            model_path=args.topic_cls_path,
            label_path=args.label_dir,  # sys.argv[2]: directory for labels.csv file
            multi_label=False,
            model_type='bert',
            do_lower_case=True)

        self.entity_ext_model = AutoModelForTokenClassification.from_pretrained(
            "dbmdz/bert-large-cased-finetuned-conll03-english")
        self.entity_ext_model.to(self.device)
        self.entity_ext_tokenizer = AutoTokenizer.from_pretrained(
            "bert-base-cased")

        if self.gen_model_type == 'dialogpt':
            self.gen_tokenizer = AutoTokenizer.from_pretrained(
                self.gen_model_path)
            self.gen_model = AutoModelWithLMHead.from_pretrained(
                self.gen_model_path)
            self.gen_model.to(self.device)  # honor the CPU fallback chosen above
            self.gen_model.eval()
        elif self.gen_model_type == 'bart':
            self.gen_model = BARTModel.from_pretrained(
                self.gen_model_path,
                checkpoint_file='checkpoint_best.pt',
                data_name_or_path=self.gen_model_path)
            self.gen_model.to(self.device)
            self.gen_model.eval()

        self.conv_line = BARTModel.from_pretrained(
            self.conv_line_path,
            checkpoint_file='checkpoint_best.pt',
            data_name_or_path=self.conv_line_path)
        self.conv_line.to(self.device)
        self.conv_line.eval()
Example no. 4
def sanity_check(model_name_or_path, checkpoint_file):
    from fairseq.models.bart import BARTModel

    bart = BARTModel.from_pretrained(
        model_name_or_path,
        checkpoint_file=checkpoint_file,
        data_name_or_path='../data/processed/binary',
        user_dir='../../source',
        task="translation_multi_simple_epoch_extended",
        decoder_langtok=True,
        lang_pairs='java-en_XX',
        lang_dict='lang_dict.txt')

    assert len(bart.task.source_dictionary) == 50008
    assert bart.task.source_dictionary[0] == '<s>'
    assert bart.task.source_dictionary[1] == '<pad>'
    assert bart.task.source_dictionary[2] == '</s>'
    assert bart.task.source_dictionary[3] == '<unk>'
    assert bart.task.source_dictionary[50001] == '__java__'
    assert bart.task.source_dictionary[50002] == '__python__'
    assert bart.task.source_dictionary[50003] == '__en_XX__'
    assert bart.task.source_dictionary[50004] == '__javascript__'
    assert bart.task.source_dictionary[50005] == '__php__'
    assert bart.task.source_dictionary[50006] == '__ruby__'
    assert bart.task.source_dictionary[50007] == '__go__'
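A hypothetical invocation of the check above (both arguments are assumptions that depend on your checkout layout):

if __name__ == '__main__':
    sanity_check('../checkpoints', 'checkpoint_best.pt')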
Example no. 5
    def __init__(self,
                 squad_dir='./qa_models/squad1.0',
                 bart_qa_dir='./bart_qg/checkpoints/',
                 use_gpu=False):
        self.qg_model = BARTModel.from_pretrained(
            bart_qa_dir, checkpoint_file='checkpoint_best.pt')

        if use_gpu:
            self.qg_model.cuda()
            self.qg_model.half()
        self.qg_model.eval()

        self.batch_size = 64
        self.beam_size = 10
        self.max_length = 100

        self.nlp = spacy.load('en_core_web_sm')
        self.parser = benepar.Parser("benepar_en2")
        self.stop_words = set(stopwords.words('english'))

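        # The '{}' placeholders (--predict_file, --output_dir) are left to be
        # filled in later via str.format at call time.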
        self.squad_cmd = [
            'python {}/run_squad.py'.format(squad_dir), '--model_type bert',
            '--model_name_or_path {}'.format(squad_dir), '--do_eval',
            '--overwrite_cache', '--do_lower_case', '--predict_file {}',
            '--per_gpu_train_batch_size 12', '--max_seq_length 384',
            '--doc_stride 128', '--output_dir {}'
        ]

        self.squad_cmd = ' '.join(self.squad_cmd)
Example no. 6
    def __init__(self,
                 device='cpu',
                 qa_model_name="deepset/minilm-uncased-squad2",
                 qg_model_dir='../feqa/bart_qg/checkpoints/'):

        self.qg_model = BARTModel.from_pretrained(
            qg_model_dir, checkpoint_file='checkpoint_best.pt')

        if device == 'cuda':
            self.qg_model.to(device)  #.cuda()
            self.qg_model.half()
        self.qg_model.eval()

        self.batch_size = 1  #64
        self.beam_size = 10
        self.max_length = 100

        self.nlp = spacy.load('en_core_web_sm')
        #self.parser = benepar.Parser("benepar_en2")
        self.stop_words = set(stopwords.words('english'))

        self.qa_threshold = 0.1  # below this threshold, the generated question is considered too vague
        self.qa_pipeline = pipeline('question-answering',
                                    model=qa_model_name,
                                    tokenizer=qa_model_name)
Example no. 7
def get_model():
    BASEDIR = './model'
    bart = BARTModel.from_pretrained(BASEDIR,
                                     checkpoint_file='checkpoint_best.pt',
                                     bpe='sentencepiece',
                                     sentencepiece_model=BASEDIR + '/sentence.bpe.model')
    bart.eval()
    return bart
Example no. 8
    def build_model(cls, cfg: WavBart2BartConfig, task: FairseqTask):
        """Build a new model instance."""
        from fairseq.models.bart import BARTModel
        if os.path.isfile(os.path.join(cfg.bart_path, 'model.pt')):
            print('loading bart from cfg path')
            bart = BARTModel.from_pretrained(cfg.bart_path,
                                             checkpoint_file='model.pt')
        else:
            print('loading bart from relative path')
            bart = BARTModel.from_pretrained('models/bart.base',
                                             checkpoint_file='model.pt')

        assert cfg.autoregressive, "Please set task.autoregressive=true for seq2seq asr models"

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
        encoder = cls.build_encoder(cfg, bart)
        decoder = cls.build_decoder(cfg, bart)
        model = WavBart2Bart(encoder, decoder)
        return model
Example no. 9
def main():
    """
    Usage::

         python examples/bart/summarize.py \
            --model-dir $HOME/bart.large.cnn \
            --model-file model.pt \
            --src $HOME/data-bin/cnn_dm/test.source
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-dir",
        required=True,
        type=str,
        default="bart.large.cnn/",
        help="path containing model file and src_dict.txt",
    )
    parser.add_argument(
        "--model-file",
        default="checkpoint_best.pt",
        help="where in model_dir are weights saved",
    )
    parser.add_argument(
        "--src", default="test.source", help="text to summarize", type=str
    )
    parser.add_argument(
        "--out", default="test.hypo", help="where to save summaries", type=str
    )
    parser.add_argument("--bsz", default=32, help="batch size", type=int)
    parser.add_argument(
        "--n", default=None, help="how many examples to summarize", type=int
    )
    parser.add_argument(
        "--xsum-kwargs",
        action="store_true",
        default=False,
        help="if true use XSUM_KWARGS else CNN_KWARGS",
    )
    args = parser.parse_args()
    eval_kwargs = XSUM_KWARGS if args.xsum_kwargs else CNN_KWARGS
    if args.model_dir == "pytorch/fairseq":
        bart = torch.hub.load("pytorch/fairseq", args.model_file)
    else:
        bart = BARTModel.from_pretrained(
            args.model_dir,
            checkpoint_file=args.model_file,
            data_name_or_path=args.model_dir,
        )
    bart = bart.eval()
    if torch.cuda.is_available():
        bart = bart.cuda().half()
    generate(
        bart, args.src, bsz=args.bsz, n_obs=args.n, outfile=args.out, **eval_kwargs
    )
Example no. 10
    def __init__(self, model_path):
        # self.model = torch.hub.load('pytorch/fairseq', 'bart.large.cnn')
        self.model = BARTModel.from_pretrained(model_path,
                                               checkpoint_file='model.pt')
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        self.model.eval()
        if device.type == "cuda":
            self.model.half()  # fp16 is only useful on GPU
        self.count = 1
        self.bsz = 2
        self.summary_list = []
        self.slines = []
Example no. 11
    def build_model(cls, cfg: Wav2Vec2BartConfig, task: FairseqTask):
        """Build a new bart model instance."""
        from fairseq.models.bart import BARTModel
        if os.path.isfile(os.path.join(cfg.bart_path, 'model.pt')):
            print('loading bart from cfg path')
            bart = BARTModel.from_pretrained(cfg.bart_path, checkpoint_file='model.pt')
        else:
            print('loading bart from relative path')
            bart = BARTModel.from_pretrained('models/bart.base', checkpoint_file='model.pt')

        assert cfg.autoregressive, "Please set task.autoregressive=true for seq2seq asr models"

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        bart_dict = bart.model.decoder.dictionary
        transform_embed = Linear(len(bart_dict), len(tgt_dict))  # shared projection from BART vocab to target vocab

        encoder = cls.build_encoder(cfg, tgt_dict, transform_embed, bart)
        decoder = cls.build_decoder(cfg, transform_embed, bart)
        model = Wav2Vec2Bart(encoder, decoder)
        model.transform_embed = transform_embed
        return model
Example no. 12
    def __init__(self, init, text_logger=None):
        super(BART, self).__init__()

        assert init in ['bart.large', 'bart.large.cnn']

        cache_dir = f'{os.getenv("HOME")}/.cache/'
        if not os.path.exists(f'{cache_dir}/{init}'):
            os.system(f'wget https://dl.fbaipublicfiles.com/fairseq/models/'
                      f'{init}.tar.gz -P {cache_dir}')
            os.system(f'tar -xzvf {cache_dir}/{init}.tar.gz -C {cache_dir}')

        self._model = BARTModel.from_pretrained(f'{cache_dir}/{init}').model

        self._hparams = None

        self._text_logger = text_logger
Example no. 13
    def __init__(self, task, classification_head_name, actor_path, actor_file,
                 max_tokens):
        super().__init__(task)
        # self.eps = label_smoothing
        # self.task = task
        # self.debugCount = 0
        # args.bpe = 'gpt2'
        # self.bpe = encoders.build_bpe(args)
        # print(args.actor_path)
        self.regression_target = False
        self.classification_head_name = classification_head_name
        self.actor = BARTModel.from_pretrained(
            actor_path, checkpoint_file=actor_file).model
        # self.actor.half()
        self.actor.eval()
        self.max_tokens = max_tokens
Example no. 14
    def __init__(self, config, x_embed):
        super().__init__()

        pretrained_weights = "xlnet-base-cased"
        self.model = XLNetModel.from_pretrained(pretrained_weights)
        self.pretrained_config = XLNetConfig.from_pretrained(
            pretrained_weights)

        self.model = BARTModel.from_pretrained(config.pretrained_weights,
                                               checkpoint_file='model.pt')
        self.model.eval()  # disable dropout (or leave in train mode to fine-tune)

        self.encoder_out_size = 768
Example no. 15
    def __init__(self,
                 model_name,
                 num_labels,
                 num_hidden_layers=12,
                 hidden_size=1024):
        super(CustomBART, self).__init__()
        self.num_labels = num_labels
        self.bart = BARTModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(p=0.2)
        self.high_dropout = nn.Dropout(p=0.5)

        n_weights = num_hidden_layers + 1
        weights_init = torch.zeros(n_weights).float()
        weights_init.data[:-1] = -3
        self.layer_weights = torch.nn.Parameter(weights_init)
        self.classifier = nn.Linear(hidden_size, num_labels)
Example no. 16
def generate_TLDRs(bsz, count, datadir, outdir, checkpoint_dir,
                   checkpoint_file, test_fname, beam, lenpen, max_len_b,
                   min_len, no_repeat_ngram_size):
    bart = BARTModel.from_pretrained(checkpoint_dir,
                                     checkpoint_file=checkpoint_file,
                                     data_name_or_path=datadir + '-bin',
                                     task='translation')
    if torch.cuda.is_available():
        bart.cuda()
        bart.half()
    bart.eval()
    source_fname = join(datadir, 'test.source')
    pred_fname = join(outdir, test_fname)
    with open(source_fname,
              encoding="utf-8") as source, open(pred_fname,
                                                'w',
                                                encoding="utf-8") as fout:
        sline = source.readline().strip()
        slines = [sline]
        for sline in tqdm(source):
            if count % bsz == 0:
                with torch.no_grad():
                    hypotheses_batch = bart.sample(
                        slines,
                        beam=beam,
                        lenpen=lenpen,
                        max_len_b=max_len_b,
                        min_len=min_len,
                        no_repeat_ngram_size=no_repeat_ngram_size)
                for hypothesis in hypotheses_batch:
                    fout.write(hypothesis.replace('\n', ' ') + '\n')
                    fout.flush()
                slines = []

            slines.append(sline.strip())
            count += 1
        if slines != []:
            hypotheses_batch = bart.sample(
                slines,
                beam=beam,
                lenpen=lenpen,
                max_len_b=max_len_b,
                min_len=min_len,
                no_repeat_ngram_size=no_repeat_ngram_size)
            for hypothesis in hypotheses_batch:
                fout.write(hypothesis.replace('\n', ' ') + '\n')
                fout.flush()
Example no. 17
    def __init__(
        self,
        cfg: AudioPretrainingConfig,
    ):
        super().__init__(cfg)
        if cfg.eval_wer:
            assert cfg.labels is not None, "eval_wer can only be set during fine-tuning"
        self.blank_symbol = "<s>"

        # self.bart = torch.hub.load('pytorch/fairseq', 'bart.base')
        print('cfg', cfg)
        print('bart_path', cfg.bart_path)
        self.bart = BARTModel.from_pretrained(cfg.bart_path,
                                              checkpoint_file='model.pt')
        self.state.merge_state_dict(
            {'target_dictionary': self.bart.task.target_dictionary})
Example no. 18
def decode(args):
    bart = BARTModel.from_pretrained(args.checkpoint_dir,
                                     checkpoint_file=args.checkpoint_file,
                                     data_name_or_path=args.data_name_or_path)

    bart.cuda()
    bart.eval()
    bart.half()
    count = 1

    line_count = count_file_lines('{}/test.source'.format(args.data_dir))
    with open('{}/test.source'.format(args.data_dir)) as source, \
            open(args.output_file, 'w') as fout:
        sline = source.readline().strip()
        slines = [sline]
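        # Read-ahead batching: the first source line is primed above; a full
        # batch is decoded whenever count reaches a multiple of batch_size,
        # and the trailing `if slines != []` block drains the last partial batch.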
        for sline in tqdm(source, total=line_count):
            if count % args.batch_size == 0:
                with torch.no_grad():
                    hypotheses_batch = bart.sample(
                        slines,
                        beam=args.beam_size,
                        lenpen=args.lenpen,
                        max_len_b=args.max_len_b,
                        min_len=args.min_len,
                        no_repeat_ngram_size=args.no_repeat_ngram_size)

                for hypothesis in hypotheses_batch:
                    fout.write(hypothesis.strip() + '\n')
                    fout.flush()
                slines = []

            slines.append(sline.strip())
            count += 1

        if slines != []:
            hypotheses_batch = bart.sample(
                slines,
                beam=args.beam_size,
                lenpen=args.lenpen,
                max_len_b=args.max_len_b,
                min_len=args.min_len,
                no_repeat_ngram_size=args.no_repeat_ngram_size)
            for hypothesis in hypotheses_batch:
                fout.write(hypothesis.strip() + '\n')
                fout.flush()
Example no. 19
def main():
    bart = BARTModel.from_pretrained("ckpt",
                                     checkpoint_file="checkpoint_best.pt")
    bart.cuda()
    bart.half()
    bart.eval()

    with open("output/input.txt") as source:
        lines = source.readlines()
        lines = [line.replace("\n", "") for line in lines]

        with torch.no_grad():
            preds = bart.sample(lines)

            for i, (line, pred) in enumerate(zip(lines, preds)):
                print(f"[ori] ({i+1}): {line}")
                print(f"[cor] ({i+1}): {reorder(pred)}")
                print()
Example no. 20
def main():
    bart = BARTModel.from_pretrained('ckpt_bart',
                                     checkpoint_file='checkpoint_best.pt')
    bart.cuda()
    bart.half()
    bart.eval()

    with open('output/input.txt') as source:
        lines = [line.replace("\n", "") for line in source.readlines()]

        with torch.no_grad():
            preds = bart.sample(lines,
                                beam=4,
                                lenpen=2.0,
                                no_repeat_ngram_size=2,
                                temperature=0.9)
            for i, (line, pred) in enumerate(zip(lines, preds)):
                print(f"[ori] ({i+1}): {line}")
                print(f"[com] ({i+1}): {pred}")
                print()
Example no. 21
    def __init__(self, config, x_embed):
        super().__init__()

        pretrained_weights = "xlnet-base-cased"
        self.model = XLNetModel.from_pretrained(pretrained_weights)
        self.pretrained_config = XLNetConfig.from_pretrained(
            pretrained_weights)

        # pretrained_weights = "/hits/basement/nlp/jeonso/pretrained/bart.large"
        self.model = BARTModel.from_pretrained(config.pretrained_weights,
                                               checkpoint_file='model.pt')
        self.model.eval()  # disable dropout (or leave in train mode to fine-tune)

        # if config.use_gpu:
        #   self.model = self.model.to(device=torch.device("cuda"))
        # if config.use_parallel:
        #   self.model = torch.nn.DataParallel(self.model)

        self.encoder_out_size = 768
Example no. 22
def evaluate():
    logger.info('***** Begin Evaluation *****')

    bart = BARTModel.from_pretrained('checkpoints/',
                                     checkpoint_file='checkpoint_best.pt',
                                     data_name_or_path='bin')

    bart.cuda()
    bart.eval()

    with open('./data/val.source') as source, open('./data/val.hypo',
                                                   'w') as fout:
        for sline in source:
            sline = sline.strip().lower()
            with torch.no_grad():
                # sample() expects a list of sentences and returns a list
                hypos = bart.sample([sline],
                                    beam=5,
                                    lenpen=2.0,
                                    max_len_b=100,
                                    min_len=20,
                                    no_repeat_ngram_size=3)
            fout.write(hypos[0].lower() + '\n')
Example no. 23
def get_bart(folder_path, checkpoint_file):
    """
    Returns a pretrained BART model.

    Args:
        folder_path: str, path to BART's model, containing the checkpoint.
        checkpoint_file: str, name of BART's checkpoint file (starting from BART's folder).
    """

    from fairseq.models.bart import BARTModel

    bart = BARTModel.from_pretrained(model_name_or_path=folder_path + '/',
                                     checkpoint_file=checkpoint_file)

    if torch.cuda.is_available():
        bart.cuda()
        print("Using BART on GPU...")

    bart.eval()

    print("BART loaded (in evaluation mode).\n")

    return bart
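A hypothetical call, assuming a local checkpoint directory:

bart = get_bart('./checkpoints/bart.large.cnn', 'checkpoint_best.pt')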
Example no. 24
    def __init__(self, task, encoder_json, vocab_bpe, critic_path, critic_file,
                 sentence_avg, max_tokens, print_update, critic_weight,
                 label_smoothing, use_reward, rewarder_file, rewarder_weight):
        super().__init__(task)
        self.eps = label_smoothing
        # self.task = task
        self.debugCount = 0
        self.bpe = GPT2BPE_modified(encoder_json, vocab_bpe)
        self.sentence_avg = sentence_avg
        self.max_tokens = max_tokens
        self.print_update = print_update
        self.critic_weight = critic_weight
        # print(critic_path)
        self.critic = BARTModel.from_pretrained(
            critic_path, checkpoint_file=critic_file).model
        self.critic.half()
        # self.critic = self.critic.cpu()
        # self.critic.short()
        self.critic.eval()
        self.use_rewarder = use_reward
        self.rewarder_weight = rewarder_weight
        if use_reward:
            self.padder = Padder(1024)
Example no. 25
def convert_fairseq_model(args):
    if not args.save_dir:
        args.save_dir = os.path.basename(args.fairseq_model_path) + '_gluon'
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    fairseq_bart = fairseq_BARTModel.from_pretrained(
        args.fairseq_model_path, checkpoint_file='model.pt')
    vocab_size = convert_vocab(args, fairseq_bart)
    gluon_cfg = convert_config(fairseq_bart.args, vocab_size,
                               BartModel.get_cfg().clone())
    with open(os.path.join(args.save_dir, 'model.yml'), 'w') as of:
        of.write(gluon_cfg.dump())

    ctx = mx.gpu(args.gpu) if args.gpu is not None else mx.cpu()
    gluon_bart = convert_params(fairseq_bart, gluon_cfg, ctx)
    if args.test:
        test_model(fairseq_bart, gluon_bart, args.gpu)

    gluon_bart.save_parameters(os.path.join(args.save_dir, 'model.params'),
                               deduplicate=True)
    logging.info('Convert the BART MLM model in {} to {}'.format(
        os.path.join(args.fairseq_model_path, 'model.pt'),
        os.path.join(args.save_dir, 'model.params')))

    logging.info('Conversion finished!')
    logging.info('Statistics:')
    old_names = os.listdir(args.save_dir)
    for old_name in old_names:
        new_name, long_hash = naming_convention(args.save_dir, old_name)
        old_path = os.path.join(args.save_dir, old_name)
        new_path = os.path.join(args.save_dir, new_name)
        shutil.move(old_path, new_path)
        file_size = os.path.getsize(new_path)
        logging.info('\t{}/{} {} {}'.format(args.save_dir, new_name, long_hash,
                                            file_size))
Example no. 26

parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument('--checkpoint-path', default='checkpoints/actor', metavar='DIR',
                    help='path to load checkpoint')
parser.add_argument('--checkpoint-file', default='checkpoint_best.pt', metavar='DIR',
                    help='file to load actor')
parser.add_argument('--dataset', default='cnn_dm', metavar='DIR',
                    help='dataset to evaluate')
parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')

args, _ = parser.parse_known_args()

bart = BARTModel.from_pretrained(
    args.checkpoint_path,
    checkpoint_file=args.checkpoint_file,
    data_name_or_path=args.dataset+'-bin'
)

if args.cpu:
    bart.cpu()
else:
    bart.cuda()
    bart.half()

bart.eval()
# if torch.cuda.device_count() > 1:
#     bart.model = torch.nn.DataParallel(bart.model)
count = 1
bsz = 32
num_lines = sum(1 for _ in open(args.dataset+'/test.source'))
Example no. 27
from fairseq.models.bart import BARTModel

bart = BARTModel.from_pretrained(
    'checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='task-1-bin'
)

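# Map a predicted class index back to its label string; nspecial skips the
# dictionary's special symbols (<s>, <pad>, </s>, <unk>).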
label_fn = lambda label: bart.task.label_dictionary.string(
    [label + bart.task.label_dictionary.nspecial]
)

ncorrect, nsamples = 0, 0
bart.cuda()
bart.eval()

with open('./task_1_data/dev.tsv') as fin:
    fin.readline()
    for index, line in enumerate(fin):
        tokens = line.strip().split('\t')
        idx, sent1, sent2, target = tokens
        
        tokens = bart.encode(sent1, sent2)
        
        prediction = bart.predict('sentence_classification_head', tokens).argmax().item()
        prediction_label = label_fn(prediction)
        
        if prediction_label == target:
            ncorrect += 1
        else:
            print(sent1 + '\n' + sent2)
Example no. 28
#!/usr/bin/env python
"""
created at: Mon 24 Aug 2020 04:35:36 AM EDT
created by: Priyam Tejaswin (ptejaswi)

Inference on test data.
"""

import pdb
import torch
from fairseq.models.bart import BARTModel
from tqdm import tqdm

bart = BARTModel.from_pretrained(
    './checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='/projects/metis1/users/ptejaswi/multistep-retrieve-summarize/models/bart.base/'
)

bart.cuda()
bart.eval()
bart.half()
count = 1
bsz = 32
with open('cnn_dm/test.source') as source, open('cnn_dm/test.hypo',
                                                'w') as fout:
    sline = source.readline().strip()
    slines = [sline]
    for sline in tqdm(source):
        if count % bsz == 0:
            with torch.no_grad():
Example no. 29
import torch
from fairseq.models.bart import BARTModel
import csv

bart = BARTModel.from_pretrained(
    'checkpoints/new',
    checkpoint_file='checkpoint4.pt',
    data_name_or_path='preprocess_taskC_data/subtaskC_data-bin'
)

bart.cuda()
bart.eval()
bart.half()
count = 1
bsz = 32
with open('data/Test Data/subtaskC_test_data_new_plusplus.csv') as source, open('subtaskC_generated/subtaskC_answers.csv', 'w') as fout:
#with open('data/Dev Data/subtaskC_dev_data_new_plusplus.csv') as source, open('subtaskC_generated/trial_hypo.csv', 'w') as fout:
#with open('data/Trial Data/taskC_trial_data.csv') as source, open('subtaskC_generated/trial_hypo.csv', 'w') as fout:
    source = csv.reader(source)
    fout = csv.writer(fout)
    next(source)
    idx, false_sent, true_sent, evidence_from_wik = next(source)
    #idx,false_sent = next(source)
    if false_sent.isupper():
        false_sent = false_sent.capitalize()
    if true_sent.isupper():
        true_sent = true_sent.capitalize()
    sline = false_sent
    #sline='The statement "'+false_sent+'" is absurd, because:'
    #sline='Context: '+evidence_from_wik+' | '+'The statement: '+false_sent
    #sline='Context: '+evidence_from_wik+' | '+'The statement "'+false_sent+'" is absurd, because:' #should be modified
Example no. 30
parser.add_argument("--high-random-prob", type=float, default=0.4)

parser.add_argument("--random-word-span", type=int, default=0)
parser.add_argument("--gen-with-mbart", type=int, default=0)
parser.add_argument("--batch-size", type=int, default=96)

parser.add_argument("--model-path", type=str, default=None)

args = parser.parse_args()

# Fill your mbart or bart checkpoint path
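# mBART checkpoints use sentencepiece BPE and were trained with
# layernorm_embedding, hence the extra kwargs on the mbart branch below.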
if args.gen_with_mbart:
    bart = BARTModel.from_pretrained(
        args.model_path,
        checkpoint_file='model.pt',
        data_name_or_path=args.model_path,
        bpe="sentencepiece",
        layernorm_embedding=True,
    )
else:
    bart = BARTModel.from_pretrained(
        os.path.join('/checkpoint/chuntinz/container/bart.large'),
        checkpoint_file='model.pt',
        data_name_or_path=os.path.join('/checkpoint/chuntinz/container/gpt2_bpe')
    )

noise_params = dict()
noise_params['mask_length'] = args.mask_length
noise_params['mask_ratio'] = args.mask_ratio
noise_params['random_ratio'] = args.random_ratio
noise_params['poisson_lambda'] = args.poisson_lambda