Example #1
def build(opt):
    version = 'v1.0'
    dpath = os.path.join(opt['datapath'], 'QACNN')

    if not build_data.built(dpath, version):
        print('[building data: ' + dpath + ']')
        if build_data.built(dpath):
            # An older version exists, so remove these outdated files.
            build_data.remove_dir(dpath)
        build_data.make_dir(dpath)

        # Download the data.
        fname = 'cnn.tgz'
        gd_id = '0BwmD_VLjROrfTTljRDVZMFJnVWM'
        build_data.download_from_google_drive(gd_id, os.path.join(dpath, fname))
        build_data.untar(dpath, fname)

        create_fb_format(dpath, 'train',
                         os.path.join(dpath, 'cnn', 'questions', 'training'))
        create_fb_format(dpath, 'valid',
                         os.path.join(dpath, 'cnn', 'questions', 'validation'))
        create_fb_format(dpath, 'test',
                         os.path.join(dpath, 'cnn', 'questions', 'test'))

        # Mark the data as built.
        build_data.mark_done(dpath, version)
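These build functions all gate the download on build_data.built() and finish with build_data.mark_done(). Roughly, those helpers read and write a small '.built' marker file inside the data directory. The sketch below is a simplified approximation of that convention (the real functions live in parlai.core.build_data and also record a timestamp), not their actual source:

import datetime
import os


def built(path, version_string=None):
    """Return True if `path` already holds a matching '.built' marker."""
    fname = os.path.join(path, '.built')
    if not os.path.isfile(fname):
        return False
    if version_string is None:
        return True
    with open(fname) as f:
        # In this sketch the marker stores the version string on its last line.
        return f.read().splitlines()[-1] == version_string


def mark_done(path, version_string=None):
    """Drop a '.built' marker once the download and extraction succeed."""
    with open(os.path.join(path, '.built'), 'w') as f:
        f.write(str(datetime.datetime.today()))
        if version_string is not None:
            f.write('\n' + version_string)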
Example #2
def build(opt):
    version = 'v1.0'
    dpath = os.path.join(opt['datapath'], 'QADailyMail')

    if not build_data.built(dpath, version):
        print('[building data: ' + dpath + ']')
        if build_data.built(dpath):
            # An older version exists, so remove these outdated files.
            build_data.remove_dir(dpath)
        build_data.make_dir(dpath)

        # Download the data.
        fname = 'qadailymail.tar.gz'
        gd_id = '0BwmD_VLjROrfN0xhTDVteGQ3eG8'
        build_data.download_from_google_drive(gd_id,
                                              os.path.join(dpath, fname))
        build_data.untar(dpath, fname)

        ext = os.path.join('dailymail', 'questions')
        create_fb_format(dpath, 'train', os.path.join(dpath, ext, 'training'))
        create_fb_format(dpath, 'valid', os.path.join(dpath, ext,
                                                      'validation'))
        create_fb_format(dpath, 'test', os.path.join(dpath, ext, 'test'))

        # Mark the data as built.
        build_data.mark_done(dpath, version)
Example #3
def build(opt):
    version = 'v1.0'
    dpath = os.path.join(opt['datapath'], 'QACNN')

    if not build_data.built(dpath, version):
        print('[building data: ' + dpath + ']')
        if build_data.built(dpath):
            # An older version exists, so remove these outdated files.
            build_data.remove_dir(dpath)
        build_data.make_dir(dpath)

        # Download the data.
        fname = 'cnn.tgz'
        gd_id = '0BwmD_VLjROrfTTljRDVZMFJnVWM'
        build_data.download_from_google_drive(gd_id,
                                              os.path.join(dpath, fname))
        build_data.untar(dpath, fname)

        create_fb_format(dpath, 'train',
                         os.path.join(dpath, 'cnn', 'questions', 'training'))
        create_fb_format(dpath, 'valid',
                         os.path.join(dpath, 'cnn', 'questions', 'validation'))
        create_fb_format(dpath, 'test',
                         os.path.join(dpath, 'cnn', 'questions', 'test'))

        # Mark the data as built.
        build_data.mark_done(dpath, version)
Example #4
def build(opt):
    # get path to data directory
    dpath = os.path.join(opt['datapath'], 'dialogue_nli')
    # define version if any
    version = '1.0'

    # check if data had been previously built
    if not build_data.built(dpath, version_string=version):
        print('[building data: ' + dpath + ']')

        # make a clean directory if needed
        if build_data.built(dpath):
            # an older version exists, so remove these outdated files.
            build_data.remove_dir(dpath)
        build_data.make_dir(dpath)

        # download the data.
        fname = 'dialogue_nli.zip'
        gd_id = '1WtbXCv3vPB5ql6w0FVDmAEMmWadbrCuG'
        build_data.download_from_google_drive(gd_id,
                                              os.path.join(dpath, fname))

        # uncompress it
        build_data.unzip(dpath, fname)

        # mark the data as built
        build_data.mark_done(dpath, version_string=version)
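The gd_id strings above are Google Drive file ids. For reference, a Drive download helper has to request the export URL, pick up the large-file confirmation token from the response cookies, and stream the payload to disk. The sketch below only illustrates that idea; it is not the source of parlai.core.build_data.download_from_google_drive, and Drive's confirmation flow has changed over time:

import requests

GDRIVE_URL = 'https://docs.google.com/uc?export=download'


def download_from_google_drive_sketch(gd_id, destination):
    """Illustrative helper: stream a Drive file identified by gd_id to disk."""
    with requests.Session() as session:
        response = session.get(GDRIVE_URL, params={'id': gd_id}, stream=True)
        # Large files are gated behind a confirmation token set in a cookie.
        token = next(
            (v for k, v in response.cookies.items()
             if k.startswith('download_warning')),
            None,
        )
        if token:
            response = session.get(
                GDRIVE_URL, params={'id': gd_id, 'confirm': token}, stream=True
            )
        with open(destination, 'wb') as f:
            for chunk in response.iter_content(chunk_size=32768):
                if chunk:  # skip keep-alive chunks
                    f.write(chunk)
        response.close()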
Example #5
def build(opt):
    dpath = os.path.join(opt['datapath'], 'CNN_DM')
    version = None

    if not build_data.built(dpath, version_string=version):
        print('[building data: ' + dpath + ']')
        if build_data.built(dpath):
            # An older version exists, so remove these outdated files.
            build_data.remove_dir(dpath)
        build_data.make_dir(dpath)

        # Download the data.

        cnn_fname = 'cnn_stories.tgz'
        cnn_gd_id = '0BwmD_VLjROrfTHk4NFg2SndKcjQ'
        build_data.download_from_google_drive(cnn_gd_id,
                                              os.path.join(dpath, cnn_fname))
        build_data.untar(dpath, cnn_fname)

        dm_fname = 'dm_stories.tgz'
        dm_gd_id = '0BwmD_VLjROrfM1BxdkxVaTY2bWs'
        build_data.download_from_google_drive(dm_gd_id,
                                              os.path.join(dpath, dm_fname))
        build_data.untar(dpath, dm_fname)

        for dt in CNN_FNAMES:
            fname = CNN_FNAMES[dt]
            url = CNN_ROOT + fname
            build_data.download(url, dpath, fname)
            urls_fname = os.path.join(dpath, fname)
            split_fname = os.path.join(dpath, dt + '.txt')
            with open(urls_fname) as urls_file, open(split_fname,
                                                     'a') as split_file:
                for url in urls_file:
                    file_name = hashlib.sha1(
                        url.strip().encode('utf-8')).hexdigest()
                    split_file.write(
                        "cnn/stories/{}.story\n".format(file_name))

        for dt in DM_FNAMES:
            fname = DM_FNAMES[dt]
            url = DM_ROOT + fname
            build_data.download(url, dpath, fname)
            urls_fname = os.path.join(dpath, fname)
            split_fname = os.path.join(dpath, dt + '.txt')
            with open(urls_fname) as urls_file, open(split_fname,
                                                     'a') as split_file:
                for url in urls_file:
                    file_name = hashlib.sha1(
                        url.strip().encode('utf-8')).hexdigest()
                    split_file.write(
                        "dailymail/stories/{}.story\n".format(file_name))

        # Mark the data as built.
        build_data.mark_done(dpath, version_string=version)
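CNN_FNAMES, DM_FNAMES, CNN_ROOT and DM_ROOT above are module-level constants (mappings from data split to a URL-list filename, plus their download roots) that are not shown in this excerpt. The split files it writes refer to each story by a SHA-1 hash of the article URL, which is the naming scheme the CNN/DailyMail story dumps use. A standalone check of that mapping, with a made-up URL, looks like this:

import hashlib

url = 'http://www.cnn.com/2015/04/20/some-article/index.html'  # illustrative only
story_name = hashlib.sha1(url.strip().encode('utf-8')).hexdigest()
print('cnn/stories/{}.story'.format(story_name))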
Example #6
    def download_if_not_existing(data_path):
        dpath = os.path.join(data_path, 'qangaroo')
        if not os.path.isdir(dpath):
            os.mkdir(dpath)
            fname = 'qangaroo.zip'
            g_ID = "1ytVZ4AhubFDOEL7o7XrIRIyhU8g9wvKA"

            print("downloading ...")
            build_data.download_from_google_drive(g_ID,
                                                  os.path.join(dpath, fname))
            build_data.untar(dpath, fname)
Example #7
    def _load_dnli_model(self):
        # Download the pretrained weights
        dnli_model_fname = os.path.join(self.opt['datapath'], 'dnli_model.bin')
        if not os.path.exists(dnli_model_fname):
            print(f"[ Download pretrained dnli model params to {dnli_model_fname}]")
            download_from_google_drive(
                '1Qawz1pMcV0aGLVYzOgpHPgG5vLSKPOJ1',
                dnli_model_fname
            )

        # Load the pretrained weights
        print(f"[ Load pretrained dnli model from {dnli_model_fname}]")
        model_state_dict = torch.load(dnli_model_fname)
        dnli_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', state_dict=model_state_dict, num_labels=3)
        if self.use_cuda:
            dnli_model.cuda()
        dnli_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

        return dnli_model, dnli_tokenizer
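One possible way to use the returned pair, assuming the older pytorch_pretrained_bert-style interface in which the forward pass returns raw logits (with the current transformers package the call and return type differ slightly). The helper name and the label handling are illustrative, not part of the original code:

import torch


def dnli_score(dnli_model, dnli_tokenizer, premise, hypothesis, use_cuda=False):
    premise_tokens = dnli_tokenizer.tokenize(premise)
    hypothesis_tokens = dnli_tokenizer.tokenize(hypothesis)
    tokens = ['[CLS]'] + premise_tokens + ['[SEP]'] + hypothesis_tokens + ['[SEP]']
    segment_ids = [0] * (len(premise_tokens) + 2) + [1] * (len(hypothesis_tokens) + 1)

    input_ids = torch.tensor([dnli_tokenizer.convert_tokens_to_ids(tokens)])
    segment_tensor = torch.tensor([segment_ids])
    if use_cuda:
        input_ids, segment_tensor = input_ids.cuda(), segment_tensor.cuda()

    dnli_model.eval()
    with torch.no_grad():
        logits = dnli_model(input_ids, token_type_ids=segment_tensor)
    # Probabilities over the 3 NLI labels; the label order depends on how the
    # checkpoint was trained.
    return torch.softmax(logits, dim=-1)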
Example #8
def build(opt):
    dpath = os.path.join(opt['datapath'], 'qangaroo')
    version = 'v1.1'

    if not build_data.built(dpath, version_string=version):
        print('[building data: ' + dpath + ']')
        if build_data.built(dpath):
            # An older version exists, so remove these outdated files.
            build_data.remove_dir(dpath)
        build_data.make_dir(dpath)

        # Download the data.
        fname = 'qangaroo.zip'
        g_ID = "1ytVZ4AhubFDOEL7o7XrIRIyhU8g9wvKA"

        print("downloading ...")
        build_data.download_from_google_drive(g_ID, os.path.join(dpath, fname))
        build_data.untar(dpath, fname)

        # Mark the data as built.
        build_data.mark_done(dpath, version_string=version)
Example #9
def build(opt):
    dpath = os.path.join(opt['datapath'], 'CNN_DM')
    version = None

    if not build_data.built(dpath, version_string=version):
        print('[building data: ' + dpath + ']')
        if build_data.built(dpath):
            # An older version exists, so remove these outdated files.
            build_data.remove_dir(dpath)
        build_data.make_dir(dpath)

        # Download the data.
        cnn_fname = 'cnn_stories.tgz'
        cnn_gd_id = '0BwmD_VLjROrfTHk4NFg2SndKcjQ'
        build_data.download_from_google_drive(cnn_gd_id, os.path.join(dpath, cnn_fname))
        build_data.untar(dpath, cnn_fname)

        dm_fname = 'dm_stories.tgz'
        dm_gd_id = '0BwmD_VLjROrfM1BxdkxVaTY2bWs'
        build_data.download_from_google_drive(dm_gd_id, os.path.join(dpath, dm_fname))
        build_data.untar(dpath, dm_fname)

        # Mark the data as built.
        build_data.mark_done(dpath, version_string=version)
Example #10
def build(opt):
    dpath = os.path.join(opt['datapath'], FOLDER_NAME)
    # version 1.0: initial release
    version = '1.0'

    # check whether data had been previously built
    if not build_data.built(dpath, version_string=version):
        print('[building data: ' + dpath + ']')

        # make a clean directory if needed
        if build_data.built(dpath):
            # if an older version exists, remove those outdated files.
            build_data.remove_dir(dpath)
        build_data.make_dir(dpath)

        #########################
        # ConvAI2 (PersonaChat)
        #########################
        fname = 'data_v1.tar.gz'
        url = 'https://parl.ai/downloads/controllable_dialogue/' + fname
        build_data.download(url, dpath, fname)
        build_data.untar(dpath, fname)

        fname = 'convai2_fix_723.tgz'
        url = 'http://parl.ai/downloads/convai2/' + fname
        build_data.download(url, dpath, fname)
        build_data.untar(dpath, fname)

        #########################
        # Dialogue NLI
        #########################
        fname = 'dialogue_nli.zip'
        gd_id = '1WtbXCv3vPB5ql6w0FVDmAEMmWadbrCuG'
        build_data.download_from_google_drive(gd_id,
                                              os.path.join(dpath, fname))
        build_data.untar(dpath, fname)

        fname = 'dialogue_nli_evaluation.zip'
        gd_id = '1sllq30KMJzEVQ4C0-a9ShSLSPIZc3iMi'
        build_data.download_from_google_drive(gd_id,
                                              os.path.join(dpath, fname))
        build_data.untar(dpath, fname)

        #########################
        # Distractor personas
        #########################
        fname = 'train_sorted_50_personas.json'
        gd_id = '1SGFdJqyNYeepKFqwMLv4Ym717QQTtpi8'
        build_data.download_from_google_drive(gd_id,
                                              os.path.join(dpath, fname))
        fname = 'valid_sorted_50_personas.json'
        gd_id = '1A7oVKmjJ1EZTh6-3Gio4XQo81QgnTGGi'
        build_data.download_from_google_drive(gd_id,
                                              os.path.join(dpath, fname))
        fname = 'dnli_sorted_50_personas.json'
        gd_id = '1wlIkVcBZoGQd3rbI7XWNhuq4rvw9FyoP'
        build_data.download_from_google_drive(gd_id,
                                              os.path.join(dpath, fname))

        print("Data has been placed in " + dpath)

        build_data.mark_done(dpath, version)
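build_data.download(url, dpath, fname) above fetches a plain HTTP(S) URL into the data directory. A bare-bones stand-in, without the retry and progress-reporting logic of the real helper, is just a streamed requests download:

import os
import requests


def download_sketch(url, path, fname):
    """Illustrative stand-in: stream `url` into path/fname and return the path."""
    outfile = os.path.join(path, fname)
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(outfile, 'wb') as f:
            for chunk in response.iter_content(chunk_size=32768):
                if chunk:
                    f.write(chunk)
    return outfile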