def test_create_batches_groups_correctly(self):
    iterator = BucketIterator(batch_size=2,
                              padding_noise=0,
                              sorting_keys=[('text', 'num_tokens')])
    batches = list(iterator._create_batches(self.instances, shuffle=False))
    grouped_instances = [batch.instances for batch in batches]
    assert grouped_instances == [[self.instances[4], self.instances[2]],
                                 [self.instances[0], self.instances[1]],
                                 [self.instances[3]]]
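# The assertion above only holds for a fixture in which the 'text' fields of
# self.instances are ordered by length as 4 < 2 < 0 == 1 < 3. A minimal sketch
# of such a setUp, assuming a hypothetical create_instance helper that wraps a
# list of words in an Instance with a single 'text' TextField; the sentences
# are illustrative, not the original fixture:
def setUp(self):
    super().setUp()
    self.instances = [
        self.create_instance(['this', 'is', 'a', 'sentence']),        # 0: 4 tokens
        self.create_instance(['this', 'is', 'another', 'sentence']),  # 1: 4 tokens
        self.create_instance(['yet', 'another', 'sentence']),         # 2: 3 tokens
        self.create_instance(['this', 'is', 'a', 'very', 'very',
                              'very', 'very', 'long', 'sentence']),   # 3: 9 tokens
        self.create_instance(['sentence']),                           # 4: 1 token
    ]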
def test_biggest_batch_first_works(self):
    iterator = BucketIterator(batch_size=2,
                              padding_noise=0,
                              sorting_keys=[('text', 'num_tokens')],
                              biggest_batch_first=True)
    grouped_instances = iterator._create_batches(self.dataset, shuffle=False)
    assert grouped_instances == [[self.instances[3]],
                                 [self.instances[0], self.instances[1]],
                                 [self.instances[4], self.instances[2]]]
def test_biggest_batch_first_works(self):
    iterator = BucketIterator(batch_size=2,
                              padding_noise=0,
                              sorting_keys=[(u'text', u'num_tokens')],
                              biggest_batch_first=True)
    batches = list(iterator._create_batches(self.instances, shuffle=False))
    grouped_instances = [batch.instances for batch in batches]
    assert grouped_instances == [[self.instances[3]],
                                 [self.instances[0], self.instances[1]],
                                 [self.instances[4], self.instances[2]]]
def test_biggest_batch_first_works(self):
    iterator = BucketIterator(batch_size=2,
                              padding_noise=0,
                              sorting_keys=[('text', 'num_tokens')],
                              biggest_batch_first=True)
    iterator.index_with(self.vocab)
    batches = list(iterator._create_batches(self.instances, shuffle=False))
    grouped_instances = [batch.instances for batch in batches]
    assert grouped_instances == [[self.instances[3]],
                                 [self.instances[0], self.instances[1]],
                                 [self.instances[4], self.instances[2]]]
def test_create_batches_groups_correctly(self):
    iterator = BucketIterator(
        batch_size=2, padding_noise=0, sorting_keys=[("text", "tokens___tokens")]
    )
    iterator.index_with(self.vocab)
    batches = list(iterator._create_batches(self.instances, shuffle=False))
    grouped_instances = [batch.instances for batch in batches]
    assert grouped_instances == [
        [self.instances[4], self.instances[2]],
        [self.instances[0], self.instances[1]],
        [self.instances[3]],
    ]
def test_skip_smaller_batches_works(self):
    iterator = BucketIterator(batch_size=2,
                              padding_noise=0,
                              sorting_keys=[('text', 'num_tokens')],
                              skip_smaller_batches=True)
    iterator.index_with(self.vocab)
    batches = list(iterator._create_batches(self.instances, shuffle=False))
    stats = self.get_batches_stats(batches)

    # all batches have length batch_size
    assert all(batch_len == 2 for batch_len in stats['batch_lengths'])

    # we should have lost one instance by skipping the last batch
    assert stats['total_instances'] == len(self.instances) - 1
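# Several of these tests call self.get_batches_stats, which is not shown in
# this snippet. A minimal sketch of what such a helper could look like,
# inferred from how the assertions use it (batch sizes, instance total, and a
# padded-token count per batch); treat it as an assumption, not the original
# helper:
def get_batches_stats(self, batches):
    grouped_instances = [batch.instances for batch in batches]
    batch_lengths = [len(group) for group in grouped_instances]
    sample_sizes = []
    for batch in batches:
        # longest 'text' field in the batch times the number of instances,
        # mirroring the maximum_samples_per_batch accounting
        longest = max(instance.get_padding_lengths()['text']['num_tokens']
                      for instance in batch.instances)
        sample_sizes.append(longest * len(batch.instances))
    return {
        'batch_lengths': batch_lengths,
        'total_instances': sum(batch_lengths),
        'sample_sizes': sample_sizes,
    }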
def test_create_batches_groups_correctly_with_max_instances(self):
    # If we knew all the instances, the correct order is 4 -> 2 -> 0 -> 1 -> 3.
    # Here max_instances_in_memory is 3, so we load instances [0, 1, 2]
    # and then bucket them by size into batches of size 2 to get [2, 0] -> [1].
    # Then we load the remaining instances and bucket them by size to get [4, 3].
    iterator = BucketIterator(batch_size=2,
                              padding_noise=0,
                              sorting_keys=[('text', 'num_tokens')],
                              max_instances_in_memory=3)

    for test_instances in (self.instances, self.lazy_instances):
        batches = list(iterator._create_batches(test_instances, shuffle=False))
        grouped_instances = [batch.instances for batch in batches]
        assert grouped_instances == [[self.instances[2], self.instances[0]],
                                     [self.instances[1]],
                                     [self.instances[4], self.instances[3]]]
def test_bucket_iterator_maximum_samples_per_batch(self):
    iterator = BucketIterator(batch_size=3,
                              padding_noise=0,
                              sorting_keys=[('text', 'num_tokens')],
                              maximum_samples_per_batch=['num_tokens', 9])
    iterator.index_with(self.vocab)
    batches = list(iterator._create_batches(self.instances, shuffle=False))
    stats = self.get_batches_stats(batches)

    # ensure all instances are in a batch
    assert stats['total_instances'] == len(self.instances)

    # ensure correct batch sizes
    assert stats['batch_lengths'] == [2, 2, 1]

    # ensure correct sample sizes (<= 9)
    assert stats['sample_sizes'] == [6, 8, 9]
def test_bucket_iterator_maximum_samples_per_batch(self):
    iterator = BucketIterator(batch_size=3,
                              padding_noise=0,
                              sorting_keys=[('text', 'num_tokens')],
                              maximum_samples_per_batch=['num_tokens', 9])
    batches = list(iterator._create_batches(self.instances, shuffle=False))

    # ensure all instances are in a batch
    grouped_instances = [batch.instances for batch in batches]
    num_instances = sum(len(group) for group in grouped_instances)
    assert num_instances == len(self.instances)

    # ensure all batches are sufficiently small
    for batch in batches:
        batch_sequence_length = max([
            instance.get_padding_lengths()['text']['num_tokens']
            for instance in batch.instances
        ])
        assert batch_sequence_length * len(batch.instances) <= 9
def test_maximum_samples_per_batch_packs_tightly(self):
    token_counts = [10, 4, 3]
    test_instances = self.create_instances_from_token_counts(token_counts)

    iterator = BucketIterator(batch_size=3,
                              padding_noise=0,
                              sorting_keys=[('text', 'num_tokens')],
                              maximum_samples_per_batch=['num_tokens', 11])
    iterator.index_with(self.vocab)
    batches = list(iterator._create_batches(test_instances, shuffle=False))
    stats = self.get_batches_stats(batches)

    # ensure all instances are in a batch
    assert stats['total_instances'] == len(test_instances)

    # ensure correct batch sizes
    assert stats['batch_lengths'] == [2, 1]

    # ensure correct sample sizes (<= 11)
    assert stats['sample_sizes'] == [8, 10]
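# create_instances_from_token_counts (used in the test above) is also not
# shown here. A plausible sketch, assuming a hypothetical create_instance
# helper that turns a list of words into an Instance with a single 'text'
# TextField:
def create_instances_from_token_counts(self, token_counts):
    return [self.create_instance(['word'] * count) for count in token_counts]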
imdb_train_dataset = reader.read('./data/mtl-dataset/imdb.task.train')
imdb_test_dataset = reader.read('./data/mtl-dataset/imdb.task.test')

vocab = Vocabulary.from_instances(books_train_dataset + books_validation_dataset)
iterator = BucketIterator(batch_size=128, sorting_keys=[("tokens", "num_tokens")])
iterator.index_with(vocab)
print(vocab._index_to_token)
# print(vocab.__getstate__()['_token_to_index']['labels'])
# for batch in iterator(books_train_dataset, num_epochs=1, shuffle=True):
#     print(batch['tokens']['tokens'], batch['label'])
print(iterator.get_num_batches(books_train_dataset))
books_iter = iter(iterator._create_batches(books_train_dataset, shuffle=True))
print(len(books_train_dataset))
print(next(books_iter).as_tensor_dict())
'''
EMBEDDING_DIM = 300
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=EMBEDDING_DIM,
                            pretrained_file='/media/sihui/000970CB000A4CA8/Sentiment-Analysis/embeddings/glove.42B.300d.txt',
                            trainable=False)
# character_embedding = TokenCharactersEncoder(embedding=Embedding(num_embeddings=vocab.get_vocab_size('tokens_characters'), embedding_dim=8),
#                                              encoder=CnnEncoder(embedding_dim=8, num_filters=100, ngram_filter_sizes=[5]), dropout=0.2)
word_embeddings = BasicTextFieldEmbedder({'tokens': token_embedding})
# lstm = PytorchSeq2SeqWrapper(nn.LSTM(input_size=308, hidden_size=100, num_layers=1, bidirectional=True, batch_first=True))
import os

import numpy as np
import torch
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

from allennlp.data.iterators import BucketIterator
from allennlp.data.vocabulary import Vocabulary

# ACSADatasetReader, ABSADatasetReader, ACSA2ABSA and build_embedding_matrix are
# project-specific and assumed to be imported from elsewhere in the repository.


def train(args):
    source_reader = ACSADatasetReader(max_sequence_len=args.max_seq_len)
    target_reader = ABSADatasetReader(max_sequence_len=args.max_seq_len)
    source_dataset_train = source_reader.read('./data/MGAN/data/restaurant/train.txt')
    source_dataset_dev = source_reader.read('./data/MGAN/data/restaurant/test.txt')
    target_dataset_train = target_reader.read('/media/sihui/000970CB000A4CA8/Sentiment-Analysis/data/semeval14/Restaurants_Train.xml.seg')
    target_dataset_dev = target_reader.read('/media/sihui/000970CB000A4CA8/Sentiment-Analysis/data/semeval14/Restaurants_Test_Gold.xml.seg')

    vocab = Vocabulary.from_instances(source_dataset_train + source_dataset_dev
                                      + target_dataset_train + target_dataset_dev)
    word2idx = vocab.get_token_to_index_vocabulary()
    print(word2idx)
    embedding_matrix = build_embedding_matrix(word2idx, 300,
                                              './embedding/embedding_res_res.dat',
                                              '/media/sihui/000970CB000A4CA8/Sentiment-Analysis/embeddings/glove.42B.300d.txt')

    iterator = BucketIterator(batch_size=args.batch_size,
                              sorting_keys=[('text', 'num_tokens'), ('aspect', 'num_tokens')])
    iterator.index_with(vocab)

    my_net = ACSA2ABSA(args, word_embeddings=embedding_matrix)
    optimizer = optim.Adam(my_net.parameters(), lr=args.learning_rate)
    loss_class = torch.nn.CrossEntropyLoss()
    loss_domain = torch.nn.CrossEntropyLoss()

    my_net = my_net.to(args.device)
    loss_class = loss_class.to(args.device)
    loss_domain = loss_domain.to(args.device)

    n_epoch = args.epoch
    max_test_acc = 0
    best_epoch = 0

    data_target_iter = iter(iterator(target_dataset_train, shuffle=True))  # iterate over it forever

    for epoch in range(n_epoch):
        len_target_dataloader = iterator.get_num_batches(target_dataset_train)
        len_source_dataloader = iterator.get_num_batches(source_dataset_train)

        data_source_iter = iter(iterator._create_batches(source_dataset_train, shuffle=True))
        # data_target_iter = iter(iterator._create_batches(target_dataset_train, shuffle=True))

        s_correct, s_total = 0, 0
        i = 0
        while i < len_source_dataloader:
            my_net.train()
            p = float(i + epoch * len_target_dataloader) / n_epoch / len_target_dataloader
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            # train model using source data
            data_source = next(data_source_iter).as_tensor_dict()
            s_text, s_aspect, s_label = data_source['text']['tokens'], data_source['aspect']['tokens'], data_source['label']
            batch_size = len(s_label)
            s_domain_label = torch.zeros(batch_size).long().to(args.device)

            my_net.zero_grad()
            s_text, s_aspect, s_label = s_text.to(args.device), s_aspect.to(args.device), s_label.to(args.device)
            s_class_output, s_domain_output = my_net(s_text, s_aspect, alpha)
            err_s_label = loss_class(s_class_output, s_label)
            # err_s_domain = loss_domain(s_domain_output, s_domain_label)

            # train model using target data
            # data_target = next(data_target_iter).as_tensor_dict()
            '''
            data_target = next(data_target_iter)
            t_text, t_aspect, t_label = data_target['text']['tokens'], data_target['aspect']['tokens'], data_target['label']
            batch_size = len(t_label)
            t_domain_label = torch.ones(batch_size).long().to(args.device)
            t_text, t_aspect, t_label = t_text.to(args.device), t_aspect.to(args.device), t_label.to(args.device)
            t_class_output, t_domain_output = my_net(t_text, t_aspect, alpha)
            # err_t_domain = loss_domain(t_domain_output, t_domain_label)
            '''

            # loss = err_t_domain + err_s_domain + err_s_label
            loss = err_s_label
            loss.backward()
            if args.use_grad_clip:
                clip_grad_norm_(my_net.parameters(), args.grad_clip)
            optimizer.step()

            i += 1
            s_correct += (torch.argmax(s_class_output, -1) == s_label).sum().item()
            s_total += len(s_class_output)
            train_acc = s_correct / s_total

            # evaluate every 100 batches
            if i % 100 == 0:
                my_net.eval()
                # evaluate model on source test data
                s_test_correct, s_test_total = 0, 0
                s_targets_all, s_output_all = None, None
                with torch.no_grad():
                    for i_batch, s_test_batch in enumerate(iterator(source_dataset_dev, num_epochs=1, shuffle=False)):
                        s_test_text = s_test_batch['text']['tokens'].to(args.device)
                        s_test_aspect = s_test_batch['aspect']['tokens'].to(args.device)
                        s_test_label = s_test_batch['label'].to(args.device)
                        s_test_output, _ = my_net(s_test_text, s_test_aspect, alpha)
                        s_test_correct += (torch.argmax(s_test_output, -1) == s_test_label).sum().item()
                        s_test_total += len(s_test_label)
                        if s_targets_all is None:
                            s_targets_all = s_test_label
                            s_output_all = s_test_output
                        else:
                            s_targets_all = torch.cat((s_targets_all, s_test_label), dim=0)
                            s_output_all = torch.cat((s_output_all, s_test_output), dim=0)
                s_test_acc = s_test_correct / s_test_total
                if s_test_acc > max_test_acc:
                    max_test_acc = s_test_acc
                    best_epoch = epoch
                    if not os.path.exists('state_dict'):
                        os.mkdir('state_dict')
                    if s_test_acc > 0.868:
                        path = 'state_dict/source_test_epoch{0}_acc_{1}'.format(epoch, round(s_test_acc, 4))
                        torch.save(my_net.state_dict(), path)
                print('epoch: %d, [iter: %d / all %d], loss_s_label: %f, '
                      's_train_acc: %f, s_test_acc: %f'
                      % (epoch, i, len_source_dataloader,
                         err_s_label.cpu().item(),
                         # err_s_domain.cpu().item(),
                         # err_t_domain.cpu().item(),
                         train_acc, s_test_acc))

    print('max_test_acc: {0} in epoch: {1}'.format(max_test_acc, best_epoch))
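# A minimal sketch of how train(args) might be invoked from the command line.
# The argument names mirror the attributes the function reads (max_seq_len,
# batch_size, learning_rate, epoch, device, use_grad_clip, grad_clip); the
# default values are assumptions, not values from the original project.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_seq_len', type=int, default=80)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--use_grad_clip', action='store_true')
    parser.add_argument('--grad_clip', type=float, default=5.0)
    args = parser.parse_args()

    # pick a GPU when available; train() moves the model and losses to args.device
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train(args)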