import os

import torch

# project-local modules (datasets, models, trainer) are assumed importable.


def main(config: dict):
    ''' entry point '''
    # print a tag: script name and the active config.
    print(10 * '*', __file__, config)
    os.makedirs(config['saved_dir'], exist_ok=True)

    # create the dataset; it provides both training and evaluation data.
    dataset = OverallZeshelDataset(
        '../datasets/mel_zeshel_overall/', {
            'TRAIN_TASKS_NUM': 3000,
            'VALID_TASKS_NUM': 300,
            'WAYS_NUM_PRE_TASK': config['way_num'][0],
            'SUPPORT_USING_CONTEXT': config['context'],
            'QUERY_USING_CONTEXT': config['context']
        })

    # tensorizer: converts a single example into tensors.
    tensorizer = EasyBertTokenizer.from_pretrained(
        '../pretrain/uncased_L-12_H-768_A-12', {
            'FIXED_LEN': 72,
            'DO_LOWER_CASE': True
        })

    # adapter: calls the tensorizer to collate a batch of examples into batched tensors.
    adapter = MelAdapter(tensorizer, tensorizer)

    # embedding model, used for prediction.
    bert = Bert.from_pretrained('../pretrain/uncased_L-12_H-768_A-12', {
        'POOLING_METHOD': 'avg',
        'FINETUNE_LAYER_RANGE': '9:12'
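        # presumably restricts fine-tuning to the top transformer layers
        # (9 through 12); lower layers stay frozen.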
    })

    # prototypical network for training.
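    # note: the same `bert` instance is passed for both arms below, so the
    # support and query encoders share weights.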
    model = PrototypicalNetwork(bert, bert)

    # trainer for the siamese BERT.
    trainer = Trainer({
        'dataset': dataset,
        'adapter': adapter,
        'model': model,
        'DEVICE': torch.device(config['device']),
        'TRAIN_BATCH_SIZE': 2,
        'VALID_BATCH_SIZE': 5,
        'ROUND': 3
    })

    # training starts here.
    trainer.train()

    # test across the configured way numbers.
    for way_num in config['way_num']:
        test_dataset = OverallZeshelDataset(
            '../datasets/mel_zeshel_overall/', {
                'TRAIN_TASKS_NUM': 0,
                'VALID_TASKS_NUM': 0,
                'WAYS_NUM_PRE_TASK': way_num,
                'SUPPORT_USING_CONTEXT': config['context'],
                'QUERY_USING_CONTEXT': config['context']
            })
        trainer.test(test_dataset.test_data())

    # save the fine-tuned model.
    bert.save_pretrained(config['saved_dir'])
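
# A minimal usage sketch, not part of the original example: the config keys are
# inferred from their usage in main() above, and every value (paths, way
# numbers, device string) is an illustrative assumption.
if __name__ == '__main__':
    main({
        'saved_dir': './saved/zeshel_proto/',  # hypothetical output directory
        'way_num': [5, 10],  # first entry is used for training; all are tested
        'context': True,     # use context for support and query mentions
        'device': 'cuda:0'   # or 'cpu'
    })
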
def main(config: dict):
    ''' entry point '''
    # print a tag: script name and the active config.
    print(10 * '*', __file__, config)
    os.makedirs(config['saved_dir'], exist_ok=True)

    # create the dataset; it provides both training and evaluation data.
    dataset = NormalCrosselDataset(
        '../datasets/mel_crossel_normal/',
        {
            'TRAIN_TASKS_NUM': 3000,
            'VALID_TASKS_NUM': 300,
            'WAYS_NUM_PRE_TASK': config['way_num'][0],
            'SHOTS_NUM_PRE_WAYS': config['shot_num'],
            'TRAIN_WAY_PORTION': 0.9,
            'TEST_LANGUAGE': 'bn'  # placeholder; no testing happens here.
        })

    # tensorizer: converts a single example into tensors.
    tensorizer = EasyBertTokenizer.from_pretrained(
        '../pretrain/multi_cased_L-12_H-768_A-12', {
            'FIXED_LEN': 16,
            'DO_LOWER_CASE': True
        })

    # adapter: calls the tensorizer to collate a batch of examples into batched tensors.
    adapter = MelAdapter(tensorizer, tensorizer)

    # embedding model, used for prediction.
    bert = Bert.from_pretrained('../pretrain/multi_cased_L-12_H-768_A-12', {
        'POOLING_METHOD': 'avg',
        'FINETUNE_LAYER_RANGE': '9:12'
    })

    # prototypical network for training.
    model = PrototypicalNetwork(bert, bert)

    # trainer for the siamese BERT.
    trainer = Trainer({
        'dataset': dataset,
        'adapter': adapter,
        'model': model,
        'DEVICE': torch.device(config['device']),
        'TRAIN_BATCH_SIZE': 2,
        'VALID_BATCH_SIZE': 5,
        'ROUND': 3
    })

    # training starts here.
    trainer.train()

    # test across way numbers and languages.
    for way_num in config['way_num']:
        for lan in config['lan']:
            test_dataset = NormalCrosselDataset(
                '../datasets/mel_crossel_normal/',
                {
                    'TRAIN_TASKS_NUM': 0,
                    'VALID_TASKS_NUM': 0,
                    'WAYS_NUM_PRE_TASK': way_num,
                    'SHOTS_NUM_PRE_WAYS': config['shot_num'],
                    'TEST_LANGUAGE': lan  # we test here.
                })
            print(f'[TEST]: WAY_NUM: {way_num}, LANGUAGE: {lan}')
            trainer.test(test_dataset.test_data())

    # save the fine-tuned model.
    bert.save_pretrained(config['saved_dir'])
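
# A minimal usage sketch, not part of the original example: config keys are
# inferred from main() above and all values are illustrative assumptions.
if __name__ == '__main__':
    main({
        'saved_dir': './saved/crossel_proto/',  # hypothetical output directory
        'way_num': [5, 10],   # first entry trains; every entry is tested
        'shot_num': 1,        # shots per way
        'lan': ['bn', 'uk'],  # test languages
        'device': 'cuda:0'
    })
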
def main(config: dict):
    '''
        zel pipeline example.
    '''
    # print a tag: script name and the active config.
    print(10 * '*', __file__, config)
    saved_dir = os.path.dirname(config['saved_file'])
    os.makedirs(saved_dir, exist_ok=True)

    # create the dataset; it provides both training and evaluation data.
    dataset = Crossel(
        '../datasets/crossel',
        {
            'TRAIN_WAY_PORTION': 0.9,
            'CANDIDATE_USING_TEXT': config['context'],
            'TEST_LANGUAGE': 'uk'  # placeholder; the real test languages are set below.
        })

    # tensorizer: converts a single example into tensors.
    tensorizer = EasyBertTokenizer.from_pretrained(
        '../pretrain/multi_cased_L-12_H-768_A-12', {
            'FIXED_LEN': config['fixed_len'],
            'DO_LOWER_CASE': True
        })

    # adapter: calls the tensorizer to collate a batch of examples into batched tensors.
    adapter = ZelAdapter(tensorizer, tensorizer)

    # embedding model, used for prediction.
    bert = Bert.from_pretrained('../pretrain/multi_cased_L-12_H-768_A-12', {
        'POOLING_METHOD': 'avg',
        'FINETUNE_LAYER_RANGE': '9:12'
    })

    # siamese BERT for training.
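    # note: the same `bert` instance backs both branches below, so the two
    # encoder towers share weights.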
    model = SimilarNet(bert, bert, bert.config.hidden_size, {
        'DROP_OUT_PROB': 0.1,
        'ACT_NAME': 'relu',
        'USE_BIAS': False
    })

    # trainer for the siamese BERT.
    trainer = Trainer({
        'dataset': dataset,
        'adapter': adapter,
        'model': model,
        'DEVICE': torch.device(config['device']),
        'TRAIN_BATCH_SIZE': 150,
        'VALID_BATCH_SIZE': 500,
        'ROUND': 10
    })

    # training starts here.
    trainer.train()

    # training done; wrap the trained model for prediction.
    tester = ZelPredictor(
        model, adapter, {
            'TEST_BATCH_SIZE': 200,
            'EMB_BATCH_SIZE': 1000,
            'DEVICE': torch.device(config['device'])
        })
    # add candidates.
    tester.set_candidates(dataset.all_candidates())
    tester.save(config['saved_file'])

    # testing starts here.
    for lan in config['lan']:
        print(f'we are testing lan: {lan}')
        test_dataset = Crossel('../datasets/crossel', {'TEST_LANGUAGE': lan})
        test_data = test_dataset.test_data()
        for i in config['top_what']:
            print(f'now top-{i} ACC:')
            tester.test(test_data, i)
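
# A minimal usage sketch, not part of the original example: config keys are
# inferred from main() above and all values are illustrative assumptions.
if __name__ == '__main__':
    main({
        'saved_file': './saved/zel/predictor.bin',  # hypothetical output path
        'context': True,      # candidates carry their description text
        'fixed_len': 64,      # max token length for the tokenizer
        'lan': ['uk', 'bn'],  # test languages
        'top_what': [1, 10],  # report top-k accuracy for these k values
        'device': 'cuda:0'
    })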