def test_create_batches_groups_correctly(self):
     iterator = BucketIterator(batch_size=2, padding_noise=0, sorting_keys=[('text', 'num_tokens')])
     batches = list(iterator._create_batches(self.instances, shuffle=False))
     grouped_instances = [batch.instances for batch in batches]
     assert grouped_instances == [[self.instances[4], self.instances[2]],
                                  [self.instances[0], self.instances[1]],
                                  [self.instances[3]]]
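These tests come from a test class whose fixture (self.instances, self.vocab, self.lazy_instances) is not shown in the snippets. A minimal sketch of how such a fixture could be built, assuming each instance holds a single 'text' TextField whose length drives the ('text', 'num_tokens') sorting key; the helper name and token counts below are illustrative only, not taken from the original test suite.

from allennlp.data import Instance, Token, Vocabulary
from allennlp.data.fields import TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer

def make_instances(token_counts=(4, 3, 1, 5, 2)):
    # One dummy instance per token count, so bucketing by length is observable.
    indexers = {'tokens': SingleIdTokenIndexer()}
    return [Instance({'text': TextField([Token('dummy')] * count, indexers)})
            for count in token_counts]

instances = make_instances()
vocab = Vocabulary.from_instances(instances)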
Example #2
 def test_create_batches_groups_correctly(self):
     iterator = BucketIterator(batch_size=2, padding_noise=0, sorting_keys=[('text', 'num_tokens')])
     batches = list(iterator._create_batches(self.instances, shuffle=False))
     grouped_instances = [batch.instances for batch in batches]
     assert grouped_instances == [[self.instances[4], self.instances[2]],
                                  [self.instances[0], self.instances[1]],
                                  [self.instances[3]]]
 def test_biggest_batch_first_works(self):
     iterator = BucketIterator(batch_size=2,
                               padding_noise=0,
                               sorting_keys=[('text', 'num_tokens')],
                               biggest_batch_first=True)
     batches = list(iterator._create_batches(self.dataset, shuffle=False))
     grouped_instances = [batch.instances for batch in batches]
     assert grouped_instances == [[self.instances[3]],
                                  [self.instances[0], self.instances[1]],
                                  [self.instances[4], self.instances[2]]]
Example #4
 def test_biggest_batch_first_works(self):
     iterator = BucketIterator(batch_size=2,
                               padding_noise=0,
                               sorting_keys=[(u'text', u'num_tokens')],
                               biggest_batch_first=True)
     batches = list(iterator._create_batches(self.instances, shuffle=False))
     grouped_instances = [batch.instances for batch in batches]
     assert grouped_instances == [[self.instances[3]],
                                  [self.instances[0], self.instances[1]],
                                  [self.instances[4], self.instances[2]]]
 def test_biggest_batch_first_works(self):
     iterator = BucketIterator(batch_size=2,
                               padding_noise=0,
                               sorting_keys=[('text', 'num_tokens')],
                               biggest_batch_first=True)
     iterator.index_with(self.vocab)
     batches = list(iterator._create_batches(self.instances, shuffle=False))
     grouped_instances = [batch.instances for batch in batches]
     assert grouped_instances == [[self.instances[3]],
                                  [self.instances[0], self.instances[1]],
                                  [self.instances[4], self.instances[2]]]
Example #6
 def test_create_batches_groups_correctly(self):
     iterator = BucketIterator(
         batch_size=2, padding_noise=0, sorting_keys=[("text", "tokens___tokens")]
     )
     iterator.index_with(self.vocab)
     batches = list(iterator._create_batches(self.instances, shuffle=False))
     grouped_instances = [batch.instances for batch in batches]
     assert grouped_instances == [
         [self.instances[4], self.instances[2]],
         [self.instances[0], self.instances[1]],
         [self.instances[3]],
     ]
    def test_skip_smaller_batches_works(self):
        iterator = BucketIterator(batch_size=2,
                                  padding_noise=0,
                                  sorting_keys=[('text', 'num_tokens')],
                                  skip_smaller_batches=True)
        iterator.index_with(self.vocab)
        batches = list(iterator._create_batches(self.instances, shuffle=False))
        stats = self.get_batches_stats(batches)

        # all batches have length batch_size
        assert all(batch_len == 2 for batch_len in stats['batch_lengths'])

        # we should have lost one instance by skipping the last batch
        assert stats['total_instances'] == len(self.instances) - 1
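The get_batches_stats helper used above (and in several later examples) belongs to the test class, not to BucketIterator. A plausible sketch of it, assuming each Batch exposes .instances and each instance reports its padded length under ('text', 'num_tokens'); this is an inferred reconstruction, not the original helper.

def get_batches_stats(self, batches):
    # Hypothetical helper: summarises a list of Batch objects into the
    # statistics asserted on above.
    batch_lengths = [len(batch.instances) for batch in batches]
    sample_sizes = []
    for batch in batches:
        # padded sequence length of the batch times its number of instances
        max_length = max(instance.get_padding_lengths()['text']['num_tokens']
                         for instance in batch.instances)
        sample_sizes.append(max_length * len(batch.instances))
    return {'batch_lengths': batch_lengths,
            'total_instances': sum(batch_lengths),
            'sample_sizes': sample_sizes}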
 def test_create_batches_groups_correctly_with_max_instances(self):
     # If we knew all the instances, the correct order is 4 -> 2 -> 0 -> 1 -> 3.
     # Here max_instances_in_memory is 3, so we load instances [0, 1, 2]
     # and then bucket them by size into batches of size 2 to get [2, 0] -> [1].
     # Then we load the remaining instances and bucket them by size to get [4, 3].
     iterator = BucketIterator(batch_size=2,
                               padding_noise=0,
                               sorting_keys=[('text', 'num_tokens')],
                               max_instances_in_memory=3)
     for test_instances in (self.instances, self.lazy_instances):
         batches = list(iterator._create_batches(test_instances, shuffle=False))
         grouped_instances = [batch.instances for batch in batches]
         assert grouped_instances == [[self.instances[2], self.instances[0]],
                                      [self.instances[1]],
                                      [self.instances[4], self.instances[3]]]
Example #9
 def test_create_batches_groups_correctly_with_max_instances(self):
     # If we knew all the instances, the correct order is 4 -> 2 -> 0 -> 1 -> 3.
     # Here max_instances_in_memory is 3, so we load instances [0, 1, 2]
     # and then bucket them by size into batches of size 2 to get [2, 0] -> [1].
     # Then we load the remaining instances and bucket them by size to get [4, 3].
     iterator = BucketIterator(batch_size=2,
                               padding_noise=0,
                               sorting_keys=[(u'text', u'num_tokens')],
                               max_instances_in_memory=3)
     for test_instances in (self.instances, self.lazy_instances):
         batches = list(iterator._create_batches(test_instances, shuffle=False))
         grouped_instances = [batch.instances for batch in batches]
         assert grouped_instances == [[self.instances[2], self.instances[0]],
                                      [self.instances[1]],
                                      [self.instances[4], self.instances[3]]]
    def test_bucket_iterator_maximum_samples_per_batch(self):
        iterator = BucketIterator(batch_size=3,
                                  padding_noise=0,
                                  sorting_keys=[('text', 'num_tokens')],
                                  maximum_samples_per_batch=['num_tokens', 9])
        iterator.index_with(self.vocab)
        batches = list(iterator._create_batches(self.instances, shuffle=False))
        stats = self.get_batches_stats(batches)

        # ensure all instances are in a batch
        assert stats['total_instances'] == len(self.instances)

        # ensure correct batch sizes
        assert stats['batch_lengths'] == [2, 2, 1]

        # ensure correct sample sizes (<= 9)
        assert stats['sample_sizes'] == [6, 8, 9]
    def test_bucket_iterator_maximum_samples_per_batch(self):
        iterator = BucketIterator(batch_size=3,
                                  padding_noise=0,
                                  sorting_keys=[('text', 'num_tokens')],
                                  maximum_samples_per_batch=['num_tokens', 9])
        batches = list(iterator._create_batches(self.instances, shuffle=False))

        # ensure all instances are in a batch
        grouped_instances = [batch.instances for batch in batches]
        num_instances = sum(len(group) for group in grouped_instances)
        assert num_instances == len(self.instances)

        # ensure all batches are sufficiently small
        for batch in batches:
            batch_sequence_length = max([
                instance.get_padding_lengths()['text']['num_tokens']
                for instance in batch.instances
            ])
            assert batch_sequence_length * len(batch.instances) <= 9
    def test_bucket_iterator_maximum_samples_per_batch(self):
        iterator = BucketIterator(
                batch_size=3,
                padding_noise=0,
                sorting_keys=[('text', 'num_tokens')],
                maximum_samples_per_batch=['num_tokens', 9]
        )
        iterator.index_with(self.vocab)
        batches = list(iterator._create_batches(self.instances, shuffle=False))
        stats = self.get_batches_stats(batches)

        # ensure all instances are in a batch
        assert stats['total_instances'] == len(self.instances)

        # ensure correct batch sizes
        assert stats['batch_lengths'] == [2, 2, 1]

        # ensure correct sample sizes (<= 9)
        assert stats['sample_sizes'] == [6, 8, 9]
    def test_bucket_iterator_maximum_samples_per_batch(self):
        iterator = BucketIterator(
                batch_size=3, padding_noise=0,
                sorting_keys=[('text', 'num_tokens')],
                maximum_samples_per_batch=['num_tokens', 9]
        )
        batches = list(iterator._create_batches(self.instances, shuffle=False))

        # ensure all instances are in a batch
        grouped_instances = [batch.instances for batch in batches]
        num_instances = sum(len(group) for group in grouped_instances)
        assert num_instances == len(self.instances)

        # ensure all batches are sufficiently small
        for batch in batches:
            batch_sequence_length = max(
                    [instance.get_padding_lengths()['text']['num_tokens']
                     for instance in batch.instances]
            )
            assert batch_sequence_length * len(batch.instances) <= 9
    def test_maximum_samples_per_batch_packs_tightly(self):
        token_counts = [10, 4, 3]
        test_instances = self.create_instances_from_token_counts(token_counts)

        iterator = BucketIterator(batch_size=3,
                                  padding_noise=0,
                                  sorting_keys=[('text', 'num_tokens')],
                                  maximum_samples_per_batch=['num_tokens', 11])
        iterator.index_with(self.vocab)
        batches = list(iterator._create_batches(test_instances, shuffle=False))
        stats = self.get_batches_stats(batches)

        # ensure all instances are in a batch
        assert stats['total_instances'] == len(test_instances)

        # ensure correct batch sizes
        assert stats['batch_lengths'] == [2, 1]

        # ensure correct sample sizes (<= 11)
        assert stats['sample_sizes'] == [8, 10]
    def test_maximum_samples_per_batch_packs_tightly(self):
        token_counts = [10, 4, 3]
        test_instances = self.create_instances_from_token_counts(token_counts)

        iterator = BucketIterator(
                batch_size=3,
                padding_noise=0,
                sorting_keys=[('text', 'num_tokens')],
                maximum_samples_per_batch=['num_tokens', 11]
        )
        iterator.index_with(self.vocab)
        batches = list(iterator._create_batches(test_instances, shuffle=False))
        stats = self.get_batches_stats(batches)

        # ensure all instances are in a batch
        assert stats['total_instances'] == len(test_instances)

        # ensure correct batch sizes
        assert stats['batch_lengths'] == [2, 1]

        # ensure correct sample sizes (<= 11)
        assert stats['sample_sizes'] == [8, 10]
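create_instances_from_token_counts is another test-class helper that the snippets do not show. A minimal sketch under the same assumption of a single 'text' TextField per instance; the token content and indexer choice are illustrative.

from allennlp.data import Instance, Token
from allennlp.data.fields import TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer

def create_instances_from_token_counts(self, token_counts):
    # One instance per requested length, built from repeated dummy tokens.
    indexers = {'tokens': SingleIdTokenIndexer()}
    return [Instance({'text': TextField([Token('dummy')] * count, indexers)})
            for count in token_counts]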
Example #16
imdb_train_dataset = reader.read('./data/mtl-dataset/imdb.task.train')
imdb_test_dataset = reader.read('./data/mtl-dataset/imdb.task.test')

vocab = Vocabulary.from_instances(books_train_dataset +
                                  books_validation_dataset)
iterator = BucketIterator(batch_size=128,
                          sorting_keys=[("tokens", "num_tokens")])
iterator.index_with(vocab)
print(vocab._index_to_token)
# print(vocab.__getstate__()['_token_to_index']['labels'])
# for batch in iterator(books_train_dataset, num_epochs=1, shuffle=True):
#     print(batch['tokens']['tokens'], batch['label'])

print(iterator.get_num_batches(books_train_dataset))

books_iter = iter(iterator._create_batches(books_train_dataset, shuffle=True))
print(len(books_train_dataset))

print(next(books_iter).as_tensor_dict())
'''
EMBEDDING_DIM = 300

token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=EMBEDDING_DIM,
                            pretrained_file='/media/sihui/000970CB000A4CA8/Sentiment-Analysis/embeddings/glove.42B.300d.txt',
                            trainable=False)
# character_embedding = TokenCharactersEncoder(embedding=Embedding(num_embeddings=vocab.get_vocab_size('tokens_characters'), embedding_dim=8),
#                                              encoder=CnnEncoder(embedding_dim=8, num_filters=100, ngram_filter_sizes=[5]), dropout=0.2)
word_embeddings = BasicTextFieldEmbedder({'tokens': token_embedding})

# lstm = PytorchSeq2SeqWrapper(nn.LSTM(input_size=308, hidden_size=100, num_layers=1, bidirectional=True, batch_first=True))
Example #17
def train(args):
    source_reader = ACSADatasetReader(max_sequence_len=args.max_seq_len)
    target_reader = ABSADatasetReader(max_sequence_len=args.max_seq_len)

    source_dataset_train = source_reader.read('./data/MGAN/data/restaurant/train.txt')
    source_dataset_dev = source_reader.read('./data/MGAN/data/restaurant/test.txt')

    target_dataset_train = target_reader.read('/media/sihui/000970CB000A4CA8/Sentiment-Analysis/data/semeval14/Restaurants_Train.xml.seg')
    target_dataset_dev = target_reader.read('/media/sihui/000970CB000A4CA8/Sentiment-Analysis/data/semeval14/Restaurants_Test_Gold.xml.seg')

    vocab = Vocabulary.from_instances(source_dataset_train + source_dataset_dev + target_dataset_train + target_dataset_dev)
    word2idx = vocab.get_token_to_index_vocabulary()
    print(word2idx)
    embedding_matrix = build_embedding_matrix(word2idx, 300, './embedding/embedding_res_res.dat', '/media/sihui/000970CB000A4CA8/Sentiment-Analysis/embeddings/glove.42B.300d.txt')

    iterator = BucketIterator(batch_size=args.batch_size, sorting_keys=[('text', 'num_tokens'), ('aspect', 'num_tokens')])
    iterator.index_with(vocab)

    my_net = ACSA2ABSA(args, word_embeddings=embedding_matrix)

    optimizer = optim.Adam(my_net.parameters(), lr=args.learning_rate)
    loss_class = torch.nn.CrossEntropyLoss()
    loss_domain = torch.nn.CrossEntropyLoss()

    my_net = my_net.to(args.device)
    loss_class = loss_class.to(args.device)
    loss_domain = loss_domain.to(args.device)

    n_epoch = args.epoch

    max_test_acc = 0
    best_epoch = 0

    data_target_iter = iter(iterator(target_dataset_train, shuffle=True))
    # iterate over it forever (num_epochs is left unset, so batches are yielded indefinitely)

    for epoch in range(n_epoch):
        len_target_dataloader = iterator.get_num_batches(target_dataset_train)
        len_source_dataloader = iterator.get_num_batches(source_dataset_train)
        data_source_iter = iter(iterator._create_batches(source_dataset_train, shuffle=True))
        # data_target_iter = iter(iterator._create_batches(target_dataset_train, shuffle=True))
        s_correct, s_total = 0, 0
        i = 0
        while i < len_source_dataloader:
            my_net.train()
            p = float(i + epoch * len_target_dataloader) / n_epoch / len_target_dataloader
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            # train model using source data
            data_source = next(data_source_iter).as_tensor_dict()
            s_text, s_aspect, s_label = data_source['text']['tokens'], data_source['aspect']['tokens'], data_source['label']
            batch_size = len(s_label)

            s_domain_label = torch.zeros(batch_size).long().to(args.device)

            my_net.zero_grad()

            s_text, s_aspect, s_label = s_text.to(args.device), s_aspect.to(args.device), s_label.to(args.device)
            s_class_output, s_domain_output = my_net(s_text, s_aspect, alpha)

            err_s_label = loss_class(s_class_output, s_label)
            # err_s_domain = loss_domain(s_domain_output, s_domain_label)

            # training model using target data
            # data_target = next(data_target_iter).as_tensor_dict()
            '''
            data_target = next(data_target_iter)
            t_text, t_aspect, t_label = data_target['text']['tokens'], data_target['aspect']['tokens'], data_target['label']

            batch_size = len(t_label)
            t_domain_label = torch.ones(batch_size).long().to(args.device)

            t_text, t_aspect, t_label = t_text.to(args.device), t_aspect.to(args.device), t_label.to(args.device)

            t_class_output, t_domain_output = my_net(t_text, t_aspect, alpha)
            # err_t_domain = loss_domain(t_domain_output, t_domain_label)
            '''
            # loss = err_t_domain + err_s_domain + err_s_label
            loss = err_s_label
            loss.backward()

            if args.use_grad_clip:
                clip_grad_norm_(my_net.parameters(), args.grad_clip)

            optimizer.step()

            i += 1

            s_correct += (torch.argmax(s_class_output, -1) == s_label).sum().item()
            s_total += len(s_class_output)
            train_acc = s_correct / s_total

            # evaluate every 100 batches
            if i % 100 == 0:
                my_net.eval()
                # evaluate model on source test data
                s_test_correct, s_test_total = 0, 0
                s_targets_all, s_output_all = None, None
                with torch.no_grad():
                    for i_batch, s_test_batch in enumerate(iterator(source_dataset_dev, num_epochs=1, shuffle=False)):
                        s_test_text = s_test_batch['text']['tokens'].to(args.device)
                        s_test_aspect = s_test_batch['aspect']['tokens'].to(args.device)
                        s_test_label = s_test_batch['label'].to(args.device)

                        s_test_output, _ = my_net(s_test_text, s_test_aspect, alpha)

                        s_test_correct += (torch.argmax(s_test_output, -1) == s_test_label).sum().item()
                        s_test_total += len(s_test_label)

                        if s_targets_all is None:
                            s_targets_all = s_test_label
                            s_output_all = s_test_output
                        else:
                            s_targets_all = torch.cat((s_targets_all, s_test_label), dim=0)
                            s_output_all = torch.cat((s_output_all, s_test_output), dim=0)

                s_test_acc = s_test_correct / s_test_total
                if s_test_acc > max_test_acc:
                    max_test_acc = s_test_acc
                    best_epoch = epoch
                    if not os.path.exists('state_dict'):
                        os.mkdir('state_dict')
                    if s_test_acc > 0.868:
                        path = 'state_dict/source_test_epoch{0}_acc_{1}'.format(epoch, round(s_test_acc, 4))
                        torch.save(my_net.state_dict(), path)

                print('epoch: %d, [iter: %d / all %d], loss_s_label: %f, '
                      's_train_acc: %f, s_test_acc: %f' % (epoch, i, len_source_dataloader,
                                                           err_s_label.cpu().item(),
                                                           # err_s_domain.cpu().item(),
                                                           # err_t_domain.cpu().item(),
                                                           train_acc,
                                                           s_test_acc))
    print('max_test_acc: {0} in epoch: {1}'.format(max_test_acc, best_epoch))
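train(args) reads its hyperparameters from an argparse-style namespace, and the snippet omits the module-level imports (numpy, os, torch.optim, clip_grad_norm_, and the project-specific readers and model). A minimal sketch of the entry point it appears to expect; the flag names are inferred from the attributes accessed inside the function, and the defaults are purely illustrative.

import argparse
import torch

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_seq_len', type=int, default=80)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--use_grad_clip', action='store_true')
    parser.add_argument('--grad_clip', type=float, default=5.0)
    args = parser.parse_args()
    # Pick GPU when available; train() moves the model and losses to this device.
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train(args)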