class Seq2SeqPredictor:
    def __init__(self, model: Model, data_reader: SummDataReader, batch_size: int, cuda_device: int):
        self.cuda_device = cuda_device
        # bucket batches by source length to minimize padding
        self.iterator = BucketIterator(batch_size=batch_size,
                                       sorting_keys=[("source_tokens", "num_tokens")])
        self.model = model
        self.data_reader = data_reader

    def _extract_data(self, batch) -> Dict[str, torch.Tensor]:
        # run the model's forward pass and return its output dictionary
        out_dict = self.model(**batch)
        return out_dict

    def predict(self, file_path: str, vocab_path: str):
        ds = self.data_reader.read(file_path)
        vocab = Vocabulary.from_files(vocab_path)
        self.iterator.index_with(vocab)
        self.model.eval()
        pred_generator = self.iterator(ds, num_epochs=1, shuffle=False)
        pred_generator_tqdm = tqdm(pred_generator, total=self.iterator.get_num_batches(ds))
        preds = []
        with torch.no_grad():
            for batch in pred_generator_tqdm:
                batch = util.move_to_device(batch, self.cuda_device)
                preds.append(self._extract_data(batch))
        return preds
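# Hypothetical usage sketch for the predictor above (not part of the original
# snippet): build_model() and the file paths are placeholders for however the
# model, reader, and data are actually created in the surrounding project.
reader = SummDataReader()                   # dataset reader used by the class above
model = build_model()                       # placeholder: a trained AllenNLP Model
predictor = Seq2SeqPredictor(model, reader, batch_size=16, cuda_device=-1)
batch_outputs = predictor.predict('test.jsonl', 'vocab_dir/')  # placeholder paths
print(f'collected {len(batch_outputs)} batches of model outputs')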
def run_model(args):
    st_ds_conf = get_updated_settings(args)

    reader = data_adapter.GeoQueryDatasetReader()
    training_set = reader.read(config.DATASETS[args.dataset].train_path)
    try:
        validation_set = reader.read(config.DATASETS[args.dataset].dev_path)
    except Exception:  # dataset may have no dev split
        validation_set = None

    vocab = allennlp.data.Vocabulary.from_instances(training_set)
    model = get_model(vocab, st_ds_conf)
    device_tag = "cpu" if config.DEVICE < 0 else f"cuda:{config.DEVICE}"
    if args.models:
        model.load_state_dict(torch.load(args.models[0], map_location=device_tag))

    if not args.test or not args.models:
        iterator = BucketIterator(sorting_keys=[("source_tokens", "num_tokens")],
                                  batch_size=st_ds_conf['batch_sz'])
        iterator.index_with(vocab)

        optim = torch.optim.Adam(model.parameters(), lr=config.ADAM_LR,
                                 betas=config.ADAM_BETAS, eps=config.ADAM_EPS)
        if args.fine_tune:
            optim = torch.optim.SGD(model.parameters(), lr=config.SGD_LR)

        savepath = os.path.join(config.SNAPSHOT_PATH, args.dataset, 'unc_s2s',
                                datetime.datetime.now().strftime('%Y%m%d-%H%M%S') + "--" + args.memo)
        if not os.path.exists(savepath):
            os.makedirs(savepath, mode=0o755)

        trainer = allennlp.training.Trainer(
            model=model,
            optimizer=optim,
            iterator=iterator,
            train_dataset=training_set,
            validation_dataset=validation_set,
            serialization_dir=savepath,
            cuda_device=config.DEVICE,
            num_epochs=config.TRAINING_LIMIT,
            grad_clipping=config.GRAD_CLIPPING,
            num_serialized_models_to_keep=-1,
        )
        trainer.train()
    else:
        if args.test_on_val:
            testing_set = reader.read(config.DATASETS[args.dataset].dev_path)
        else:
            testing_set = reader.read(config.DATASETS[args.dataset].test_path)

        model.eval()
        model.skip_loss = True  # skip loss computation on testing set for faster evaluation
        if config.DEVICE > -1:
            model = model.cuda(config.DEVICE)

        # batch testing
        iterator = BucketIterator(sorting_keys=[("source_tokens", "num_tokens")],
                                  batch_size=st_ds_conf['batch_sz'])
        iterator.index_with(vocab)
        eval_generator = iterator(testing_set, num_epochs=1, shuffle=False)
        for batch in tqdm.tqdm(eval_generator, total=iterator.get_num_batches(testing_set)):
            batch = move_to_device(batch, config.DEVICE)
            output = model(**batch)
        metrics = model.get_metrics()
        print(metrics)

        if args.dump_test:
            predictor = allennlp.predictors.SimpleSeq2SeqPredictor(model, reader)
            for instance in tqdm.tqdm(testing_set, total=len(testing_set)):
                print('SRC: ', instance.fields['source_tokens'].tokens)
                print('GOLD:', ' '.join(str(x) for x in instance.fields['target_tokens'].tokens[1:-1]))
                del instance.fields['target_tokens']
                output = predictor.predict_instance(instance)
                print('PRED:', ' '.join(output['predicted_tokens']))
books_train_dataset = reader.read('./data/mtl-dataset/books.task.train')
books_validation_dataset = reader.read('./data/mtl-dataset/books.task.test')
imdb_train_dataset = reader.read('./data/mtl-dataset/imdb.task.train')
imdb_test_dataset = reader.read('./data/mtl-dataset/imdb.task.test')

vocab = Vocabulary.from_instances(books_train_dataset + books_validation_dataset)
iterator = BucketIterator(batch_size=128, sorting_keys=[("tokens", "num_tokens")])
iterator.index_with(vocab)

print(vocab._index_to_token)
# print(vocab.__getstate__()['_token_to_index']['labels'])
# for batch in iterator(books_train_dataset, num_epochs=1, shuffle=True):
#     print(batch['tokens']['tokens'], batch['label'])
print(iterator.get_num_batches(books_train_dataset))
books_iter = iter(iterator._create_batches(books_train_dataset, shuffle=True))
print(len(books_train_dataset))
print(next(books_iter).as_tensor_dict())

'''
EMBEDDING_DIM = 300
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=EMBEDDING_DIM,
                            pretrained_file='/media/sihui/000970CB000A4CA8/Sentiment-Analysis/embeddings/glove.42B.300d.txt',
                            trainable=False)
# character_embedding = TokenCharactersEncoder(embedding=Embedding(num_embeddings=vocab.get_vocab_size('tokens_characters'), embedding_dim=8),
#                                              encoder=CnnEncoder(embedding_dim=8, num_filters=100, ngram_filter_sizes=[5]), dropout=0.2)
word_embeddings = BasicTextFieldEmbedder({'tokens': token_embedding})
def train_epoch(model, train_dataset, validation_dataset, batch_size, optimizer,
                log_period, validation_period, save_dir, log_dir, cuda):
    """Train the model for one epoch."""
    # Set model to train mode (turns on dropout and such).
    model.train()

    # Create objects for calculating metrics.
    span_start_accuracy_metric = CategoricalAccuracy()
    span_end_accuracy_metric = CategoricalAccuracy()
    span_accuracy_metric = BooleanAccuracy()
    squad_metrics = SquadEmAndF1()

    # Create Tensorboard logger.
    writer = SummaryWriter(log_dir)

    # Build iterator, and have it bucket batches by passage / question length.
    iterator = BucketIterator(batch_size=batch_size,
                              sorting_keys=[("passage", "num_tokens"), ("question", "num_tokens")])
    num_training_batches = iterator.get_num_batches(train_dataset)

    # Get a generator of train batches.
    train_generator = tqdm(iterator(train_dataset, num_epochs=1,
                                    cuda_device=0 if cuda else -1),
                           total=num_training_batches, leave=False)

    log_period_losses = 0
    for batch in train_generator:
        # Extract the relevant data from the batch.
        passage = batch["passage"]["tokens"]
        question = batch["question"]["tokens"]
        span_start = batch["span_start"]
        span_end = batch["span_end"]
        metadata = batch.get("metadata", {})

        # Run data through model to get start and end logits.
        output_dict = model(passage, question)
        start_logits = output_dict["start_logits"]
        end_logits = output_dict["end_logits"]
        softmax_start_logits = output_dict["softmax_start_logits"]
        softmax_end_logits = output_dict["softmax_end_logits"]

        # Calculate loss for start and end indices
        # (nll_loss expects log-probabilities; see the note after this function).
        loss = nll_loss(softmax_start_logits, span_start.view(-1))
        loss += nll_loss(softmax_end_logits, span_end.view(-1))
        log_period_losses += loss.data[0]

        # Backprop and take a gradient step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        model.global_step += 1

        # Calculate categorical span start and end accuracy.
        span_start_accuracy_metric(start_logits, span_start.view(-1))
        span_end_accuracy_metric(end_logits, span_end.view(-1))
        # Compute the best span, and calculate overall span accuracy.
        best_span = get_best_span(start_logits, end_logits)
        span_accuracy_metric(best_span, torch.stack([span_start, span_end], -1))
        # Calculate EM and F1 scores.
        calculate_em_f1(best_span, metadata, passage.size(0), squad_metrics)

        if model.global_step % log_period == 0:
            # Calculate metrics on train set.
            loss = log_period_losses / log_period
            span_start_accuracy = span_start_accuracy_metric.get_metric(reset=True)
            span_end_accuracy = span_end_accuracy_metric.get_metric(reset=True)
            span_accuracy = span_accuracy_metric.get_metric(reset=True)
            em, f1 = squad_metrics.get_metric(reset=True)
            tqdm_description = _make_tqdm_description(loss, em, f1)
            # Log training statistics to progress bar.
            train_generator.set_description(tqdm_description)
            # Log training statistics to Tensorboard.
            log_to_tensorboard(writer, model.global_step, "train", loss,
                               span_start_accuracy, span_end_accuracy,
                               span_accuracy, em, f1)
            log_period_losses = 0

        if model.global_step % validation_period == 0:
            # Calculate metrics on validation set.
            (loss, span_start_accuracy, span_end_accuracy,
             span_accuracy, em, f1) = evaluate(model, validation_dataset, batch_size, cuda)
            # Save a checkpoint.
            save_name = ("{}_step_{}_loss_{:.3f}_"
                         "em_{:.3f}_f1_{:.3f}.pth".format(
                             model.__class__.__name__, model.global_step,
                             loss, em, f1))
            save_model(model, save_dir, save_name)
            # Log validation statistics to Tensorboard.
            log_to_tensorboard(writer, model.global_step, "validation", loss,
                               span_start_accuracy, span_end_accuracy,
                               span_accuracy, em, f1)
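# Standalone note (not part of the function above): torch.nn.functional.nll_loss
# expects *log*-probabilities, so the span loss above is only correct if the
# model's "softmax_start_logits" / "softmax_end_logits" are log_softmax outputs.
# A minimal sketch of the equivalence, with made-up shapes:
import torch
import torch.nn.functional as F

logits = torch.randn(4, 50)            # (batch, passage_length) start logits
targets = torch.randint(0, 50, (4,))   # gold start indices
loss_a = F.nll_loss(F.log_softmax(logits, dim=-1), targets)
loss_b = F.cross_entropy(logits, targets)
print(torch.allclose(loss_a, loss_b))  # True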
def main(param2val):
    # params
    params = Params.from_param2val(param2val)
    print(params, flush=True)

    # paths
    project_path = Path(param2val['project_path'])
    save_path = Path(param2val['save_path'])
    srl_eval_path = project_path / 'perl' / 'srl-eval.pl'
    data_path_mlm = project_path / 'data' / 'training' / f'{params.corpus_name}_mlm.txt'
    data_path_train_srl = project_path / 'data' / 'training' / f'{params.corpus_name}_no-dev_srl.txt'
    data_path_devel_srl = project_path / 'data' / 'training' / f'human-based-2018_srl.txt'
    data_path_test_srl = project_path / 'data' / 'training' / f'human-based-2008_srl.txt'
    childes_vocab_path = project_path / 'data' / f'{params.corpus_name}_vocab.txt'
    google_vocab_path = project_path / 'data' / 'bert-base-cased.txt'  # to get word pieces

    # word-piece tokenizer - defines input vocabulary
    vocab = load_vocab(childes_vocab_path, google_vocab_path, params.vocab_size)
    # TODO testing google vocab with wordpieces
    assert vocab['[PAD]'] == 0  # AllenNLP expects this
    assert vocab['[UNK]'] == 1  # AllenNLP expects this
    assert vocab['[CLS]'] == 2
    assert vocab['[SEP]'] == 3
    assert vocab['[MASK]'] == 4
    wordpiece_tokenizer = WordpieceTokenizer(vocab)
    print(f'Number of types in vocab={len(vocab):,}')

    # load utterances for MLM task
    utterances = load_utterances_from_file(data_path_mlm)
    train_utterances, devel_utterances, test_utterances = split(utterances)

    # load propositions for SRL task
    propositions = load_propositions_from_file(data_path_train_srl)
    train_propositions, devel_propositions, test_propositions = split(propositions)

    if data_path_devel_srl.is_file():  # use human-annotated data as devel split
        print(f'Using {data_path_devel_srl.name} as SRL devel split')
        devel_propositions = load_propositions_from_file(data_path_devel_srl)
    if data_path_test_srl.is_file():  # use human-annotated data as test split
        print(f'Using {data_path_test_srl.name} as SRL test split')
        test_propositions = load_propositions_from_file(data_path_test_srl)

    # converters handle conversion from text to instances
    converter_mlm = ConverterMLM(params, wordpiece_tokenizer)
    converter_srl = ConverterSRL(params, wordpiece_tokenizer)

    # get output_vocab
    # note: the AllenNLP vocab holds labels, wordpiece_tokenizer.vocab holds input tokens
    # what from_instances() does (see the standalone sketch after this function):
    # 1. it iterates over all instances, all fields, and all token indexers
    # 2. the token indexer is used to update vocabulary counts, skipping words whose text_id is already set
    # 4. a PADDING and MASK symbol are added to the 'tokens' namespace, resulting in a vocab size of 2
    # input tokens are not indexed here, as they are already indexed by the BERT tokenizer vocab.
    # this ensures that the model is built with inputs for all vocab words,
    # such that words that occur only in the LM or SRL task can still be input

    # make instances once - this allows iterating multiple times (required when num_epochs > 1)
    train_instances_mlm = converter_mlm.make_instances(train_utterances)
    devel_instances_mlm = converter_mlm.make_instances(devel_utterances)
    test_instances_mlm = converter_mlm.make_instances(test_utterances)
    train_instances_srl = converter_srl.make_instances(train_propositions)
    devel_instances_srl = converter_srl.make_instances(devel_propositions)
    test_instances_srl = converter_srl.make_instances(test_propositions)
    all_instances_mlm = chain(train_instances_mlm, devel_instances_mlm, test_instances_mlm)
    all_instances_srl = chain(train_instances_srl, devel_instances_srl, test_instances_srl)

    # make vocab from all instances
    output_vocab_mlm = Vocabulary.from_instances(all_instances_mlm)
    output_vocab_srl = Vocabulary.from_instances(all_instances_srl)
    # print(f'mlm vocab size={output_vocab_mlm.get_vocab_size()}')  # contains just 2 tokens
    # print(f'srl vocab size={output_vocab_srl.get_vocab_size()}')  # contains just 2 tokens
    assert output_vocab_mlm.get_vocab_size('tokens') == output_vocab_srl.get_vocab_size('tokens')

    # BERT
    print('Preparing Multi-task BERT...')
    input_vocab_size = len(converter_mlm.wordpiece_tokenizer.vocab)
    bert_config = BertConfig(vocab_size_or_config_json_file=input_vocab_size,  # was 32K
                             hidden_size=params.hidden_size,                   # was 768
                             num_hidden_layers=params.num_layers,              # was 12
                             num_attention_heads=params.num_attention_heads,   # was 12
                             intermediate_size=params.intermediate_size)       # was 3072
    bert_model = BertModel(config=bert_config)

    # Multi-tasking BERT
    mt_bert = MTBert(vocab_mlm=output_vocab_mlm,
                     vocab_srl=output_vocab_srl,
                     bert_model=bert_model,
                     embedding_dropout=params.embedding_dropout)
    mt_bert.cuda()
    num_params = sum(p.numel() for p in mt_bert.parameters() if p.requires_grad)
    print('Number of model parameters: {:,}'.format(num_params), flush=True)

    # optimizers
    optimizer_mlm = BertAdam(params=mt_bert.parameters(), lr=params.lr)
    optimizer_srl = BertAdam(params=mt_bert.parameters(), lr=params.lr)
    move_optimizer_to_cuda(optimizer_mlm)
    move_optimizer_to_cuda(optimizer_srl)

    # batching
    bucket_batcher_mlm = BucketIterator(batch_size=params.batch_size,
                                        sorting_keys=[('tokens', 'num_tokens')])
    bucket_batcher_mlm.index_with(output_vocab_mlm)
    bucket_batcher_srl = BucketIterator(batch_size=params.batch_size,
                                        sorting_keys=[('tokens', 'num_tokens')])
    bucket_batcher_srl.index_with(output_vocab_srl)

    # bigger batches to speed up evaluation - 1024 is too big
    bucket_batcher_mlm_large = BucketIterator(batch_size=512,
                                              sorting_keys=[('tokens', 'num_tokens')])
    bucket_batcher_srl_large = BucketIterator(batch_size=512,
                                              sorting_keys=[('tokens', 'num_tokens')])
    bucket_batcher_mlm_large.index_with(output_vocab_mlm)
    bucket_batcher_srl_large.index_with(output_vocab_srl)

    # init performance collection
    name2col = {
        'devel_pps': [],
        'devel_f1s': [],
    }

    # init
    eval_steps = []
    train_start = time.time()
    loss_mlm = None
    no_mlm_batches = False
    step = 0

    # generators
    train_generator_mlm = bucket_batcher_mlm(train_instances_mlm, num_epochs=params.num_mlm_epochs)
    train_generator_srl = bucket_batcher_srl(train_instances_srl, num_epochs=None)  # infinite generator
    num_train_mlm_batches = bucket_batcher_mlm.get_num_batches(train_instances_mlm)
    if params.srl_interleaved:
        max_step = num_train_mlm_batches
    else:
        max_step = num_train_mlm_batches * 2
    print(f'Will stop training at step={max_step:,}')

    while step < max_step:

        # TRAINING
        if step != 0:  # otherwise evaluation at step 0 is influenced by training on one batch
            mt_bert.train()

            # masked language modeling task
            try:
                batch_mlm = next(train_generator_mlm)
            except StopIteration:
                if params.srl_interleaved:
                    break
                else:
                    no_mlm_batches = True
            else:
                loss_mlm = mt_bert.train_on_batch('mlm', batch_mlm, optimizer_mlm)

            # semantic role labeling task
            if params.srl_interleaved:
                if random.random() < params.srl_probability:
                    batch_srl = next(train_generator_srl)
                    mt_bert.train_on_batch('srl', batch_srl, optimizer_srl)
            elif no_mlm_batches:
                batch_srl = next(train_generator_srl)
                mt_bert.train_on_batch('srl', batch_srl, optimizer_srl)

        # EVALUATION
        if step % config.Eval.interval == 0:
            mt_bert.eval()
            eval_steps.append(step)

            # evaluate perplexity
            devel_generator_mlm = bucket_batcher_mlm_large(devel_instances_mlm, num_epochs=1)
            devel_pp = evaluate_model_on_pp(mt_bert, devel_generator_mlm)
            name2col['devel_pps'].append(devel_pp)
            print(f'devel-pp={devel_pp}', flush=True)

            # test sentences
            if config.Eval.test_sentences:
                test_generator_mlm = bucket_batcher_mlm_large(test_instances_mlm, num_epochs=1)
                out_path = save_path / f'test_split_mlm_results_{step}.txt'
                predict_masked_sentences(mt_bert, test_generator_mlm, out_path)

            # probing - test sentences for specific syntactic tasks
            for name in config.Eval.probing_names:
                # prepare data
                probing_data_path_mlm = project_path / 'data' / 'probing' / f'{name}.txt'
                if not probing_data_path_mlm.exists():
                    print(f'WARNING: {probing_data_path_mlm} does not exist')
                    continue
                probing_utterances_mlm = load_utterances_from_file(probing_data_path_mlm)
                # check that probing words are in vocab
                for u in probing_utterances_mlm:
                    # print(u)
                    for w in u:
                        if w == '[MASK]':
                            continue  # not in output vocab
                        # print(w)
                        assert output_vocab_mlm.get_token_index(w, namespace='labels'), w
                # probing + save results to text
                probing_instances_mlm = converter_mlm.make_probing_instances(probing_utterances_mlm)
                probing_generator_mlm = bucket_batcher_mlm(probing_instances_mlm, num_epochs=1)
                out_path = save_path / f'probing_{name}_results_{step}.txt'
                predict_masked_sentences(mt_bert, probing_generator_mlm, out_path,
                                         print_gold=False, verbose=True)

            # evaluate devel f1
            devel_generator_srl = bucket_batcher_srl_large(devel_instances_srl, num_epochs=1)
            devel_f1 = evaluate_model_on_f1(mt_bert, srl_eval_path, devel_generator_srl)
            name2col['devel_f1s'].append(devel_f1)
            print(f'devel-f1={devel_f1}', flush=True)

            # console
            min_elapsed = (time.time() - train_start) // 60
            pp = torch.exp(loss_mlm) if loss_mlm is not None else np.nan
            print(f'step {step:<6,}: pp={pp:2.4f} total minutes elapsed={min_elapsed:<3}', flush=True)

        # only increment step once in each iteration of the loop, otherwise evaluation may never happen
        step += 1

    # evaluate train perplexity
    if config.Eval.train_split:
        generator_mlm = bucket_batcher_mlm_large(train_instances_mlm, num_epochs=1)
        train_pp = evaluate_model_on_pp(mt_bert, generator_mlm)
    else:
        train_pp = np.nan
    print(f'train-pp={train_pp}', flush=True)

    # evaluate train f1
    if config.Eval.train_split:
        generator_srl = bucket_batcher_srl_large(train_instances_srl, num_epochs=1)
        train_f1 = evaluate_model_on_f1(mt_bert, srl_eval_path, generator_srl,
                                        print_tag_metrics=True)
    else:
        train_f1 = np.nan
    print(f'train-f1={train_f1}', flush=True)

    # test sentences
    if config.Eval.test_sentences:
        test_generator_mlm = bucket_batcher_mlm(test_instances_mlm, num_epochs=1)
        out_path = save_path / f'test_split_mlm_results_{step}.txt'
        predict_masked_sentences(mt_bert, test_generator_mlm, out_path)

    # probing - test sentences for specific syntactic tasks
    for name in config.Eval.probing_names:
        # prepare data
        probing_data_path_mlm = project_path / 'data' / 'probing' / f'{name}.txt'
        if not probing_data_path_mlm.exists():
            print(f'WARNING: {probing_data_path_mlm} does not exist')
            continue
        probing_utterances_mlm = load_utterances_from_file(probing_data_path_mlm)
        probing_instances_mlm = converter_mlm.make_probing_instances(probing_utterances_mlm)
        # batch and do inference
        probing_generator_mlm = bucket_batcher_mlm(probing_instances_mlm, num_epochs=1)
        out_path = save_path / f'probing_{name}_results_{step}.txt'
        predict_masked_sentences(mt_bert, probing_generator_mlm, out_path,
                                 print_gold=False, verbose=True)

    # put train-pp and train-f1 into pandas Series
    s1 = pd.Series([train_pp], index=[eval_steps[-1]])
    s1.name = 'train_pp'
    s2 = pd.Series([train_f1], index=[eval_steps[-1]])
    s2.name = 'train_f1'

    # return performance as pandas Series
    series_list = [s1, s2]
    for name, col in name2col.items():
        print(f'Making pandas series with name={name} and length={len(col)}')
        s = pd.Series(col, index=eval_steps)
        s.name = name
        series_list.append(s)

    return series_list
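# Standalone sketch (not part of the script above) illustrating the
# Vocabulary.from_instances() behavior described in the comments in main():
# it iterates over all instances, fields, and token indexers, counts tokens
# per namespace, and adds padding/unknown symbols to padded namespaces only.
from allennlp.data import Instance, Vocabulary
from allennlp.data.fields import TextField, LabelField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

indexers = {'tokens': SingleIdTokenIndexer()}
instances = [
    Instance({'tokens': TextField([Token('the'), Token('dog')], indexers),
              'label': LabelField('animal')}),
    Instance({'tokens': TextField([Token('the'), Token('car')], indexers),
              'label': LabelField('vehicle')}),
]
vocab = Vocabulary.from_instances(instances)
print(vocab.get_vocab_size('tokens'))  # 3 word types + padding + unknown = 5
print(vocab.get_vocab_size('labels'))  # 2 ('labels' is a non-padded namespace)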
SOURCE_FIELD_NAME = 'source_tokens'
TARGET_FIELD_NAME = 'target_tokens'

if __name__ == '__main__':
    print('Reading...')
    train = lfds.SmallParallelEnJa('train') \
        .to_allennlp(source_field_name=SOURCE_FIELD_NAME, target_field_name=TARGET_FIELD_NAME).all()
    validation = lfds.SmallParallelEnJa('dev') \
        .to_allennlp(source_field_name=SOURCE_FIELD_NAME, target_field_name=TARGET_FIELD_NAME).all()

    if not osp.exists('./enja_vocab'):
        print('Building vocabulary...')
        vocab = Vocabulary.from_instances(train + validation, max_vocab_size=50000)
        print(f'Vocab Size: {vocab.get_vocab_size()}')
        print('Saving...')
        vocab.save_to_files('./enja_vocab')
    else:
        print('Loading vocabulary...')
        vocab = Vocabulary.from_files('./enja_vocab')

    iterator = BucketIterator(sorting_keys=[(SOURCE_FIELD_NAME, 'num_tokens')], batch_size=32)
    iterator.index_with(vocab)

    num_batches = iterator.get_num_batches(train)
    for batch in Tqdm.tqdm(iterator(train, num_epochs=1), total=num_batches):
        ...
def train(args):
    source_reader = ACSADatasetReader(max_sequence_len=args.max_seq_len)
    target_reader = ABSADatasetReader(max_sequence_len=args.max_seq_len)
    source_dataset_train = source_reader.read('./data/MGAN/data/restaurant/train.txt')
    source_dataset_dev = source_reader.read('./data/MGAN/data/restaurant/test.txt')
    target_dataset_train = target_reader.read('/media/sihui/000970CB000A4CA8/Sentiment-Analysis/data/semeval14/Restaurants_Train.xml.seg')
    target_dataset_dev = target_reader.read('/media/sihui/000970CB000A4CA8/Sentiment-Analysis/data/semeval14/Restaurants_Test_Gold.xml.seg')

    vocab = Vocabulary.from_instances(source_dataset_train + source_dataset_dev +
                                      target_dataset_train + target_dataset_dev)
    word2idx = vocab.get_token_to_index_vocabulary()
    print(word2idx)
    embedding_matrix = build_embedding_matrix(word2idx, 300,
                                              './embedding/embedding_res_res.dat',
                                              '/media/sihui/000970CB000A4CA8/Sentiment-Analysis/embeddings/glove.42B.300d.txt')

    iterator = BucketIterator(batch_size=args.batch_size,
                              sorting_keys=[('text', 'num_tokens'), ('aspect', 'num_tokens')])
    iterator.index_with(vocab)

    my_net = ACSA2ABSA(args, word_embeddings=embedding_matrix)
    optimizer = optim.Adam(my_net.parameters(), lr=args.learning_rate)
    loss_class = torch.nn.CrossEntropyLoss()
    loss_domain = torch.nn.CrossEntropyLoss()

    my_net = my_net.to(args.device)
    loss_class = loss_class.to(args.device)
    loss_domain = loss_domain.to(args.device)

    n_epoch = args.epoch
    max_test_acc = 0
    best_epoch = 0
    data_target_iter = iter(iterator(target_dataset_train, shuffle=True))  # iterate over it forever

    for epoch in range(n_epoch):
        len_target_dataloader = iterator.get_num_batches(target_dataset_train)
        len_source_dataloader = iterator.get_num_batches(source_dataset_train)
        data_source_iter = iter(iterator._create_batches(source_dataset_train, shuffle=True))
        # data_target_iter = iter(iterator._create_batches(target_dataset_train, shuffle=True))
        s_correct, s_total = 0, 0

        i = 0
        while i < len_source_dataloader:
            my_net.train()
            p = float(i + epoch * len_target_dataloader) / n_epoch / len_target_dataloader
            # ramps from 0 toward 1 over training (see the sketch after this function)
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            # train model using source data
            data_source = next(data_source_iter).as_tensor_dict()
            s_text, s_aspect, s_label = data_source['text']['tokens'], data_source['aspect']['tokens'], data_source['label']
            batch_size = len(s_label)
            s_domain_label = torch.zeros(batch_size).long().to(args.device)

            my_net.zero_grad()
            s_text, s_aspect, s_label = s_text.to(args.device), s_aspect.to(args.device), s_label.to(args.device)
            s_class_output, s_domain_output = my_net(s_text, s_aspect, alpha)
            err_s_label = loss_class(s_class_output, s_label)
            # err_s_domain = loss_domain(s_domain_output, s_domain_label)

            # train model using target data
            # data_target = next(data_target_iter).as_tensor_dict()
            '''
            data_target = next(data_target_iter)
            t_text, t_aspect, t_label = data_target['text']['tokens'], data_target['aspect']['tokens'], data_target['label']
            batch_size = len(t_label)
            t_domain_label = torch.ones(batch_size).long().to(args.device)
            t_text, t_aspect, t_label = t_text.to(args.device), t_aspect.to(args.device), t_label.to(args.device)
            t_class_output, t_domain_output = my_net(t_text, t_aspect, alpha)
            # err_t_domain = loss_domain(t_domain_output, t_domain_label)
            '''

            # loss = err_t_domain + err_s_domain + err_s_label
            loss = err_s_label
            loss.backward()
            if args.use_grad_clip:
                clip_grad_norm_(my_net.parameters(), args.grad_clip)
            optimizer.step()
            i += 1

            s_correct += (torch.argmax(s_class_output, -1) == s_label).sum().item()
            s_total += len(s_class_output)
            train_acc = s_correct / s_total

            # evaluate every 100 batches
            if i % 100 == 0:
                my_net.eval()
                # evaluate model on source test data
                s_test_correct, s_test_total = 0, 0
                s_targets_all, s_output_all = None, None
                with torch.no_grad():
                    for i_batch, s_test_batch in enumerate(iterator(source_dataset_dev, num_epochs=1, shuffle=False)):
                        s_test_text = s_test_batch['text']['tokens'].to(args.device)
                        s_test_aspect = s_test_batch['aspect']['tokens'].to(args.device)
                        s_test_label = s_test_batch['label'].to(args.device)
                        s_test_output, _ = my_net(s_test_text, s_test_aspect, alpha)
                        s_test_correct += (torch.argmax(s_test_output, -1) == s_test_label).sum().item()
                        s_test_total += len(s_test_label)
                        if s_targets_all is None:
                            s_targets_all = s_test_label
                            s_output_all = s_test_output
                        else:
                            s_targets_all = torch.cat((s_targets_all, s_test_label), dim=0)
                            s_output_all = torch.cat((s_output_all, s_test_output), dim=0)

                s_test_acc = s_test_correct / s_test_total
                if s_test_acc > max_test_acc:
                    max_test_acc = s_test_acc
                    best_epoch = epoch
                if not os.path.exists('state_dict'):
                    os.mkdir('state_dict')
                if s_test_acc > 0.868:
                    path = 'state_dict/source_test_epoch{0}_acc_{1}'.format(epoch, round(s_test_acc, 4))
                    torch.save(my_net.state_dict(), path)

                print('epoch: %d, [iter: %d / all %d], loss_s_label: %f, '
                      's_train_acc: %f, s_test_acc: %f'
                      % (epoch, i, len_source_dataloader,
                         err_s_label.cpu().item(),
                         # err_s_domain.cpu().item(),
                         # err_t_domain.cpu().item(),
                         train_acc, s_test_acc))

    print('max_test_acc: {0} in epoch: {1}'.format(max_test_acc, best_epoch))
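# Standalone sketch (not part of the training function above): the schedule
# alpha = 2 / (1 + exp(-10 * p)) - 1 used there ramps the adaptation weight
# smoothly from 0 at the start of training (p = 0) toward ~1 at the end (p = 1).
import numpy as np

for p in (0.0, 0.1, 0.25, 0.5, 1.0):
    alpha = 2. / (1. + np.exp(-10 * p)) - 1
    print(f'p={p:.2f} -> alpha={alpha:.3f}')
# prints roughly: 0.000, 0.462, 0.848, 0.987, 1.000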