Example #1
class Seq2SeqPredictor:

    def __init__(self, model: Model,
                 data_reader: SummDataReader,
                 batch_size: int,
                 cuda_device: int):
        self.cuda_device = cuda_device
        self.iterator = BucketIterator(batch_size=batch_size,
                                       sorting_keys=[("source_tokens", "num_tokens")])
        self.model = model
        self.data_reader = data_reader

    def _extract_data(self, batch) -> Dict[str, torch.Tensor]:
        # The forward pass returns the model's output dictionary, not a numpy array.
        out_dict = self.model(**batch)
        return out_dict

    def predict(self, file_path: str, vocab_path: str):
        ds = self.data_reader.read(file_path)
        vocab = Vocabulary.from_files(vocab_path)
        self.iterator.index_with(vocab)
        self.model.eval()
        pred_generator = self.iterator(ds, num_epochs=1, shuffle=False)
        pred_generator_tqdm = tqdm(pred_generator,
                                   total=self.iterator.get_num_batches(ds))
        preds = []
        with torch.no_grad():
            for batch in pred_generator_tqdm:
                batch = util.move_to_device(batch, self.cuda_device)
                preds.append(self._extract_data(batch))
        return preds
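A brief usage sketch for the class above; it assumes a trained AllenNLP model and a SummDataReader instance are already available, and the paths and batch size below are placeholders rather than values from the original project:

# Hypothetical usage of Seq2SeqPredictor; model, reader, and paths are placeholders.
predictor = Seq2SeqPredictor(model=model,         # a trained AllenNLP Model
                             data_reader=reader,   # a SummDataReader instance
                             batch_size=32,
                             cuda_device=-1)       # -1 runs on the CPU
preds = predictor.predict('data/test.jsonl', 'vocab/')  # list of per-batch output dicts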
Example #2
File: unc_s2s.py  Project: zxteloiv/AdaNSP
def run_model(args):
    st_ds_conf = get_updated_settings(args)
    reader = data_adapter.GeoQueryDatasetReader()
    training_set = reader.read(config.DATASETS[args.dataset].train_path)
    try:
        validation_set = reader.read(config.DATASETS[args.dataset].dev_path)
    except Exception:  # fall back to no validation set if the dev split is unavailable
        validation_set = None

    vocab = allennlp.data.Vocabulary.from_instances(training_set)
    model = get_model(vocab, st_ds_conf)
    device_tag = "cpu" if config.DEVICE < 0 else f"cuda:{config.DEVICE}"
    if args.models:
        model.load_state_dict(
            torch.load(args.models[0], map_location=device_tag))

    if not args.test or not args.models:
        iterator = BucketIterator(sorting_keys=[("source_tokens", "num_tokens")],
                                  batch_size=st_ds_conf['batch_sz'])
        iterator.index_with(vocab)

        optim = torch.optim.Adam(model.parameters(),
                                 lr=config.ADAM_LR,
                                 betas=config.ADAM_BETAS,
                                 eps=config.ADAM_EPS)
        if args.fine_tune:
            optim = torch.optim.SGD(model.parameters(), lr=config.SGD_LR)

        savepath = os.path.join(
            config.SNAPSHOT_PATH, args.dataset, 'unc_s2s',
            datetime.datetime.now().strftime('%Y%m%d-%H%M%S') + "--" +
            args.memo)
        if not os.path.exists(savepath):
            os.makedirs(savepath, mode=0o755)

        trainer = allennlp.training.Trainer(
            model=model,
            optimizer=optim,
            iterator=iterator,
            train_dataset=training_set,
            validation_dataset=validation_set,
            serialization_dir=savepath,
            cuda_device=config.DEVICE,
            num_epochs=config.TRAINING_LIMIT,
            grad_clipping=config.GRAD_CLIPPING,
            num_serialized_models_to_keep=-1,
        )

        trainer.train()

    else:
        if args.test_on_val:
            testing_set = reader.read(config.DATASETS[args.dataset].dev_path)
        else:
            testing_set = reader.read(config.DATASETS[args.dataset].test_path)

        model.eval()
        model.skip_loss = True  # skip loss computation on testing set for faster evaluation

        if config.DEVICE > -1:
            model = model.cuda(config.DEVICE)

        # batch testing
        iterator = BucketIterator(sorting_keys=[("source_tokens", "num_tokens")],
                                  batch_size=st_ds_conf['batch_sz'])
        iterator.index_with(vocab)
        eval_generator = iterator(testing_set, num_epochs=1, shuffle=False)
        for batch in tqdm.tqdm(eval_generator,
                               total=iterator.get_num_batches(testing_set)):
            batch = move_to_device(batch, config.DEVICE)
            output = model(**batch)
        metrics = model.get_metrics()
        print(metrics)

        if args.dump_test:

            predictor = allennlp.predictors.SimpleSeq2SeqPredictor(
                model, reader)

            for instance in tqdm.tqdm(testing_set, total=len(testing_set)):
                print('SRC: ', instance.fields['source_tokens'].tokens)
                print(
                    'GOLD:', ' '.join(
                        str(x) for x in
                        instance.fields['target_tokens'].tokens[1:-1]))
                del instance.fields['target_tokens']
                output = predictor.predict_instance(instance)
                print('PRED:', ' '.join(output['predicted_tokens']))
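For reference, a hedged sketch of argparse wiring that would be consistent with the attributes run_model reads (args.dataset, args.models, args.test, args.test_on_val, args.dump_test, args.fine_tune, args.memo); the real project defines its flags elsewhere, so the names and defaults below are assumptions:

import argparse

def parse_args():
    # Hypothetical flag definitions matching the attribute names used above.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', required=True)
    parser.add_argument('--models', nargs='*', default=[],
                        help='optional checkpoint path(s) to load')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--test-on-val', dest='test_on_val', action='store_true')
    parser.add_argument('--dump-test', dest='dump_test', action='store_true')
    parser.add_argument('--fine-tune', dest='fine_tune', action='store_true')
    parser.add_argument('--memo', default='')
    return parser.parse_args()

if __name__ == '__main__':
    run_model(parse_args())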
Example #3
books_train_dataset = reader.read('./data/mtl-dataset/books.task.train')
books_validation_dataset = reader.read('./data/mtl-dataset/books.task.test')
imdb_train_dataset = reader.read('./data/mtl-dataset/imdb.task.train')
imdb_test_dataset = reader.read('./data/mtl-dataset/imdb.task.test')

vocab = Vocabulary.from_instances(books_train_dataset +
                                  books_validation_dataset)
iterator = BucketIterator(batch_size=128,
                          sorting_keys=[("tokens", "num_tokens")])
iterator.index_with(vocab)
print(vocab._index_to_token)
# print(vocab.__getstate__()['_token_to_index']['labels'])
# for batch in iterator(books_train_dataset, num_epochs=1, shuffle=True):
#     print(batch['tokens']['tokens'], batch['label'])

print(iterator.get_num_batches(books_train_dataset))

books_iter = iter(iterator._create_batches(books_train_dataset, shuffle=True))
print(len(books_train_dataset))

print(next(books_iter).as_tensor_dict())
'''
EMBEDDING_DIM = 300

token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=EMBEDDING_DIM,
                            pretrained_file='/media/sihui/000970CB000A4CA8/Sentiment-Analysis/embeddings/glove.42B.300d.txt',
                            trainable=False)
# character_embedding = TokenCharactersEncoder(embedding=Embedding(num_embeddings=vocab.get_vocab_size('tokens_characters'), embedding_dim=8),
#                                              encoder=CnnEncoder(embedding_dim=8, num_filters=100, ngram_filter_sizes=[5]), dropout=0.2)
word_embeddings = BasicTextFieldEmbedder({'tokens': token_embedding})
'''
Example #4
def train_epoch(model, train_dataset, validation_dataset, batch_size,
                optimizer, log_period, validation_period, save_dir, log_dir,
                cuda):
    """
    Train the model for one epoch.
    """
    # Set model to train mode (turns on dropout and such).
    model.train()
    # Create objects for calculating metrics.
    span_start_accuracy_metric = CategoricalAccuracy()
    span_end_accuracy_metric = CategoricalAccuracy()
    span_accuracy_metric = BooleanAccuracy()
    squad_metrics = SquadEmAndF1()
    # Create Tensorboard logger.
    writer = SummaryWriter(log_dir)

    # Build iterator, and have it bucket batches by passage / question length.
    iterator = BucketIterator(batch_size=batch_size,
                              sorting_keys=[("passage", "num_tokens"),
                                            ("question", "num_tokens")])
    num_training_batches = iterator.get_num_batches(train_dataset)
    # Get a generator of train batches.
    train_generator = tqdm(iterator(train_dataset,
                                    num_epochs=1,
                                    cuda_device=0 if cuda else -1),
                           total=num_training_batches,
                           leave=False)
    log_period_losses = 0

    for batch in train_generator:
        # Extract the relevant data from the batch.
        passage = batch["passage"]["tokens"]
        question = batch["question"]["tokens"]
        span_start = batch["span_start"]
        span_end = batch["span_end"]
        metadata = batch.get("metadata", {})

        # Run data through model to get start and end logits.
        output_dict = model(passage, question)
        start_logits = output_dict["start_logits"]
        end_logits = output_dict["end_logits"]
        softmax_start_logits = output_dict["softmax_start_logits"]
        softmax_end_logits = output_dict["softmax_end_logits"]

        # Calculate loss for start and end indices.
        loss = nll_loss(softmax_start_logits, span_start.view(-1))
        loss += nll_loss(softmax_end_logits, span_end.view(-1))
        log_period_losses += loss.item()

        # Backprop and take a gradient step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        model.global_step += 1

        # Calculate categorical span start and end accuracy.
        span_start_accuracy_metric(start_logits, span_start.view(-1))
        span_end_accuracy_metric(end_logits, span_end.view(-1))
        # Compute the best span, and calculate overall span accuracy.
        best_span = get_best_span(start_logits, end_logits)
        span_accuracy_metric(best_span, torch.stack([span_start, span_end],
                                                    -1))
        # Calculate EM and F1 scores
        calculate_em_f1(best_span, metadata, passage.size(0), squad_metrics)

        if model.global_step % log_period == 0:
            # Calculate metrics on train set.
            loss = log_period_losses / log_period
            span_start_accuracy = span_start_accuracy_metric.get_metric(
                reset=True)
            span_end_accuracy = span_end_accuracy_metric.get_metric(reset=True)
            span_accuracy = span_accuracy_metric.get_metric(reset=True)
            em, f1 = squad_metrics.get_metric(reset=True)
            tqdm_description = _make_tqdm_description(loss, em, f1)
            # Log training statistics to progress bar
            train_generator.set_description(tqdm_description)
            # Log training statistics to Tensorboard
            log_to_tensorboard(writer, model.global_step, "train", loss,
                               span_start_accuracy, span_end_accuracy,
                               span_accuracy, em, f1)
            log_period_losses = 0

        if model.global_step % validation_period == 0:
            # Calculate metrics on validation set.
            (loss, span_start_accuracy, span_end_accuracy, span_accuracy, em,
             f1) = evaluate(model, validation_dataset, batch_size, cuda)
            # Save a checkpoint.
            save_name = ("{}_step_{}_loss_{:.3f}_"
                         "em_{:.3f}_f1_{:.3f}.pth".format(
                             model.__class__.__name__, model.global_step, loss,
                             em, f1))
            save_model(model, save_dir, save_name)
            # Log validation statistics to Tensorboard.
            log_to_tensorboard(writer, model.global_step, "validation", loss,
                               span_start_accuracy, span_end_accuracy,
                               span_accuracy, em, f1)
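A minimal call sketch for train_epoch, just to make the parameter roles concrete; every value below is a placeholder, and the model, datasets, and optimizer are assumed to be built elsewhere:

# Hypothetical invocation of train_epoch; all values are placeholders.
train_epoch(model, train_dataset, validation_dataset,
            batch_size=40,
            optimizer=optimizer,
            log_period=50,           # log train metrics every 50 steps
            validation_period=500,   # evaluate and checkpoint every 500 steps
            save_dir='checkpoints/',
            log_dir='runs/',
            cuda=torch.cuda.is_available())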
Example #5
File: job.py  Project: h0m3brew/BabyBertSRL
def main(param2val):

    # params
    params = Params.from_param2val(param2val)
    print(params, flush=True)

    #  paths
    project_path = Path(param2val['project_path'])
    save_path = Path(param2val['save_path'])
    srl_eval_path = project_path / 'perl' / 'srl-eval.pl'
    data_path_mlm = project_path / 'data' / 'training' / f'{params.corpus_name}_mlm.txt'
    data_path_train_srl = project_path / 'data' / 'training' / f'{params.corpus_name}_no-dev_srl.txt'
    data_path_devel_srl = project_path / 'data' / 'training' / f'human-based-2018_srl.txt'
    data_path_test_srl = project_path / 'data' / 'training' / f'human-based-2008_srl.txt'
    childes_vocab_path = project_path / 'data' / f'{params.corpus_name}_vocab.txt'
    google_vocab_path = project_path / 'data' / 'bert-base-cased.txt'  # to get word pieces

    # word-piece tokenizer - defines input vocabulary
    vocab = load_vocab(childes_vocab_path, google_vocab_path,
                       params.vocab_size)
    # TODO testing google vocab with wordpieces

    assert vocab['[PAD]'] == 0  # AllenNLP expects this
    assert vocab['[UNK]'] == 1  # AllenNLP expects this
    assert vocab['[CLS]'] == 2
    assert vocab['[SEP]'] == 3
    assert vocab['[MASK]'] == 4
    wordpiece_tokenizer = WordpieceTokenizer(vocab)
    print(f'Number of types in vocab={len(vocab):,}')

    # load utterances for MLM task
    utterances = load_utterances_from_file(data_path_mlm)
    train_utterances, devel_utterances, test_utterances = split(utterances)

    # load propositions for SRL task
    propositions = load_propositions_from_file(data_path_train_srl)
    train_propositions, devel_propositions, test_propositions = split(
        propositions)
    if data_path_devel_srl.is_file():  # use human-annotated data as devel split
        print(f'Using {data_path_devel_srl.name} as SRL devel split')
        devel_propositions = load_propositions_from_file(data_path_devel_srl)
    if data_path_test_srl.is_file():  # use human-annotated data as test split
        print(f'Using {data_path_test_srl.name} as SRL test split')
        test_propositions = load_propositions_from_file(data_path_test_srl)

    # converters handle conversion from text to instances
    converter_mlm = ConverterMLM(params, wordpiece_tokenizer)
    converter_srl = ConverterSRL(params, wordpiece_tokenizer)

    # get output_vocab
    # note: the AllenNLP vocab holds labels; wordpiece_tokenizer.vocab holds input tokens
    # what from_instances() does:
    # 1. it iterates over all instances, all fields, and all token indexers
    # 2. each token indexer updates the vocabulary counts, skipping words whose text_id is already set
    # 3. a PADDING and a MASK symbol are added to the 'tokens' namespace, resulting in a vocab size of 2
    # input tokens are not indexed here, as they are already indexed by the BERT tokenizer vocab;
    # this ensures that the model is built with inputs for all vocab words,
    # so that words occurring in only the LM or the SRL task can still be input

    # make instances once - this allows iterating multiple times (required when num_epochs > 1)
    train_instances_mlm = converter_mlm.make_instances(train_utterances)
    devel_instances_mlm = converter_mlm.make_instances(devel_utterances)
    test_instances_mlm = converter_mlm.make_instances(test_utterances)
    train_instances_srl = converter_srl.make_instances(train_propositions)
    devel_instances_srl = converter_srl.make_instances(devel_propositions)
    test_instances_srl = converter_srl.make_instances(test_propositions)
    all_instances_mlm = chain(train_instances_mlm, devel_instances_mlm,
                              test_instances_mlm)
    all_instances_srl = chain(train_instances_srl, devel_instances_srl,
                              test_instances_srl)

    # make vocab from all instances
    output_vocab_mlm = Vocabulary.from_instances(all_instances_mlm)
    output_vocab_srl = Vocabulary.from_instances(all_instances_srl)
    # print(f'mlm vocab size={output_vocab_mlm.get_vocab_size()}')  # contain just 2 tokens
    # print(f'srl vocab size={output_vocab_srl.get_vocab_size()}')  # contain just 2 tokens
    assert (output_vocab_mlm.get_vocab_size('tokens') ==
            output_vocab_srl.get_vocab_size('tokens'))

    # BERT
    print('Preparing Multi-task BERT...')
    input_vocab_size = len(converter_mlm.wordpiece_tokenizer.vocab)
    bert_config = BertConfig(
        vocab_size_or_config_json_file=input_vocab_size,  # was 32K
        hidden_size=params.hidden_size,  # was 768
        num_hidden_layers=params.num_layers,  # was 12
        num_attention_heads=params.num_attention_heads,  # was 12
        intermediate_size=params.intermediate_size)  # was 3072
    bert_model = BertModel(config=bert_config)
    # Multi-tasking BERT
    mt_bert = MTBert(vocab_mlm=output_vocab_mlm,
                     vocab_srl=output_vocab_srl,
                     bert_model=bert_model,
                     embedding_dropout=params.embedding_dropout)
    mt_bert.cuda()
    num_params = sum(p.numel() for p in mt_bert.parameters()
                     if p.requires_grad)
    print('Number of model parameters: {:,}'.format(num_params), flush=True)

    # optimizers
    optimizer_mlm = BertAdam(params=mt_bert.parameters(), lr=params.lr)
    optimizer_srl = BertAdam(params=mt_bert.parameters(), lr=params.lr)
    move_optimizer_to_cuda(optimizer_mlm)
    move_optimizer_to_cuda(optimizer_srl)

    # batching
    bucket_batcher_mlm = BucketIterator(batch_size=params.batch_size,
                                        sorting_keys=[('tokens', 'num_tokens')])
    bucket_batcher_mlm.index_with(output_vocab_mlm)
    bucket_batcher_srl = BucketIterator(batch_size=params.batch_size,
                                        sorting_keys=[('tokens', 'num_tokens')])
    bucket_batcher_srl.index_with(output_vocab_srl)

    # big batcher to speed evaluation - 1024 is too big
    bucket_batcher_mlm_large = BucketIterator(batch_size=512,
                                              sorting_keys=[('tokens', 'num_tokens')])
    bucket_batcher_srl_large = BucketIterator(batch_size=512,
                                              sorting_keys=[('tokens', 'num_tokens')])
    bucket_batcher_mlm_large.index_with(output_vocab_mlm)
    bucket_batcher_srl_large.index_with(output_vocab_srl)

    # init performance collection
    name2col = {
        'devel_pps': [],
        'devel_f1s': [],
    }

    # init
    eval_steps = []
    train_start = time.time()
    loss_mlm = None
    no_mlm_batches = False
    step = 0

    # generators
    train_generator_mlm = bucket_batcher_mlm(train_instances_mlm,
                                             num_epochs=params.num_mlm_epochs)
    train_generator_srl = bucket_batcher_srl(
        train_instances_srl, num_epochs=None)  # infinite generator
    num_train_mlm_batches = bucket_batcher_mlm.get_num_batches(
        train_instances_mlm)
    if params.srl_interleaved:
        max_step = num_train_mlm_batches
    else:
        max_step = num_train_mlm_batches * 2
    print(f'Will stop training at step={max_step:,}')

    while step < max_step:

        # TRAINING
        if step != 0:  # otherwise evaluation at step 0 is influenced by training on one batch
            mt_bert.train()

            # masked language modeling task
            try:
                batch_mlm = next(train_generator_mlm)
            except StopIteration:
                if params.srl_interleaved:
                    break
                else:
                    no_mlm_batches = True
            else:
                loss_mlm = mt_bert.train_on_batch('mlm', batch_mlm,
                                                  optimizer_mlm)

            # semantic role labeling task
            if params.srl_interleaved:
                if random.random() < params.srl_probability:
                    batch_srl = next(train_generator_srl)
                    mt_bert.train_on_batch('srl', batch_srl, optimizer_srl)
            elif no_mlm_batches:
                batch_srl = next(train_generator_srl)
                mt_bert.train_on_batch('srl', batch_srl, optimizer_srl)

        # EVALUATION
        if step % config.Eval.interval == 0:
            mt_bert.eval()
            eval_steps.append(step)

            # evaluate perplexity
            devel_generator_mlm = bucket_batcher_mlm_large(devel_instances_mlm,
                                                           num_epochs=1)
            devel_pp = evaluate_model_on_pp(mt_bert, devel_generator_mlm)
            name2col['devel_pps'].append(devel_pp)
            print(f'devel-pp={devel_pp}', flush=True)

            # test sentences
            if config.Eval.test_sentences:
                test_generator_mlm = bucket_batcher_mlm_large(
                    test_instances_mlm, num_epochs=1)
                out_path = save_path / f'test_split_mlm_results_{step}.txt'
                predict_masked_sentences(mt_bert, test_generator_mlm, out_path)

            # probing - test sentences for specific syntactic tasks
            for name in config.Eval.probing_names:
                # prepare data
                probing_data_path_mlm = project_path / 'data' / 'probing' / f'{name}.txt'
                if not probing_data_path_mlm.exists():
                    print(f'WARNING: {probing_data_path_mlm} does not exist')
                    continue
                probing_utterances_mlm = load_utterances_from_file(
                    probing_data_path_mlm)
                # check that probing words are in vocab
                for u in probing_utterances_mlm:
                    # print(u)
                    for w in u:
                        if w == '[MASK]':
                            continue  # not in output vocab
                        # print(w)
                        assert output_vocab_mlm.get_token_index(
                            w, namespace='labels'), w
                # probing + save results to text
                probing_instances_mlm = converter_mlm.make_probing_instances(
                    probing_utterances_mlm)
                probing_generator_mlm = bucket_batcher_mlm(
                    probing_instances_mlm, num_epochs=1)
                out_path = save_path / f'probing_{name}_results_{step}.txt'
                predict_masked_sentences(mt_bert,
                                         probing_generator_mlm,
                                         out_path,
                                         print_gold=False,
                                         verbose=True)

            # evaluate devel f1
            devel_generator_srl = bucket_batcher_srl_large(devel_instances_srl,
                                                           num_epochs=1)
            devel_f1 = evaluate_model_on_f1(mt_bert, srl_eval_path,
                                            devel_generator_srl)

            name2col['devel_f1s'].append(devel_f1)
            print(f'devel-f1={devel_f1}', flush=True)

            # console
            min_elapsed = (time.time() - train_start) // 60
            pp = torch.exp(loss_mlm) if loss_mlm is not None else np.nan
            print(
                f'step {step:<6,}: pp={pp :2.4f} total minutes elapsed={min_elapsed:<3}',
                flush=True)

        # only increment step once in each iteration of the loop, otherwise evaluation may never happen
        step += 1

    # evaluate train perplexity
    if config.Eval.train_split:
        generator_mlm = bucket_batcher_mlm_large(train_instances_mlm,
                                                 num_epochs=1)
        train_pp = evaluate_model_on_pp(mt_bert, generator_mlm)
    else:
        train_pp = np.nan
    print(f'train-pp={train_pp}', flush=True)

    # evaluate train f1
    if config.Eval.train_split:
        generator_srl = bucket_batcher_srl_large(train_instances_srl,
                                                 num_epochs=1)
        train_f1 = evaluate_model_on_f1(mt_bert,
                                        srl_eval_path,
                                        generator_srl,
                                        print_tag_metrics=True)
    else:
        train_f1 = np.nan
    print(f'train-f1={train_f1}', flush=True)

    # test sentences
    if config.Eval.test_sentences:
        test_generator_mlm = bucket_batcher_mlm(test_instances_mlm,
                                                num_epochs=1)
        out_path = save_path / f'test_split_mlm_results_{step}.txt'
        predict_masked_sentences(mt_bert, test_generator_mlm, out_path)

    # probing - test sentences for specific syntactic tasks
    for name in config.Eval.probing_names:
        # prepare data
        probing_data_path_mlm = project_path / 'data' / 'probing' / f'{name}.txt'
        if not probing_data_path_mlm.exists():
            print(f'WARNING: {probing_data_path_mlm} does not exist')
            continue
        probing_utterances_mlm = load_utterances_from_file(
            probing_data_path_mlm)
        probing_instances_mlm = converter_mlm.make_probing_instances(
            probing_utterances_mlm)
        # batch and do inference
        probing_generator_mlm = bucket_batcher_mlm(probing_instances_mlm,
                                                   num_epochs=1)
        out_path = save_path / f'probing_{name}_results_{step}.txt'
        predict_masked_sentences(mt_bert,
                                 probing_generator_mlm,
                                 out_path,
                                 print_gold=False,
                                 verbose=True)

    # put train-pp and train-f1 into pandas Series
    s1 = pd.Series([train_pp], index=[eval_steps[-1]])
    s1.name = 'train_pp'
    s2 = pd.Series([train_f1], index=[eval_steps[-1]])
    s2.name = 'train_f1'

    # return performance as pandas Series
    series_list = [s1, s2]
    for name, col in name2col.items():
        print(f'Making pandas series with name={name} and length={len(col)}')
        s = pd.Series(col, index=eval_steps)
        s.name = name
        series_list.append(s)

    return series_list
Example #6
SOURCE_FIELD_NAME = 'source_tokens'
TARGET_FIELD_NAME = 'target_tokens'

if __name__ == '__main__':
    print('Reading...')
    train = lfds.SmallParallelEnJa('train') \
        .to_allennlp(source_field_name=SOURCE_FIELD_NAME, target_field_name=TARGET_FIELD_NAME).all()
    validation = lfds.SmallParallelEnJa('dev') \
        .to_allennlp(source_field_name=SOURCE_FIELD_NAME, target_field_name=TARGET_FIELD_NAME).all()

    if not osp.exists('./enja_vocab'):
        print('Building vocabulary...')
        vocab = Vocabulary.from_instances(train + validation,
                                          max_vocab_size=50000)
        print(f'Vocab Size: {vocab.get_vocab_size()}')

        print('Saving...')
        vocab.save_to_files('./enja_vocab')
    else:
        print('Loading vocabulary...')
        vocab = Vocabulary.from_files('./enja_vocab')

    iterator = BucketIterator(sorting_keys=[(SOURCE_FIELD_NAME, 'num_tokens')],
                              batch_size=32)
    iterator.index_with(vocab)

    num_batches = iterator.get_num_batches(train)

    for batch in Tqdm.tqdm(iterator(train, num_epochs=1), total=num_batches):
        ...
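The loop body above is elided in the original; as a sketch only, here is one way the padded batch could be consumed, assuming the readers use the default single-id 'tokens' indexer (as the other examples on this page do):

# Hypothetical loop body, not part of the original file.
for batch in Tqdm.tqdm(iterator(train, num_epochs=1), total=num_batches):
    source_ids = batch[SOURCE_FIELD_NAME]['tokens']  # LongTensor (batch_size, max_source_len)
    target_ids = batch[TARGET_FIELD_NAME]['tokens']  # LongTensor (batch_size, max_target_len)
    # a model forward pass would go here, e.g. output = model(**batch)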
Example #7
def train(args):
    source_reader = ACSADatasetReader(max_sequence_len=args.max_seq_len)
    target_reader = ABSADatasetReader(max_sequence_len=args.max_seq_len)

    source_dataset_train = source_reader.read('./data/MGAN/data/restaurant/train.txt')
    source_dataset_dev = source_reader.read('./data/MGAN/data/restaurant/test.txt')

    target_dataset_train = target_reader.read('/media/sihui/000970CB000A4CA8/Sentiment-Analysis/data/semeval14/Restaurants_Train.xml.seg')
    target_dataset_dev = target_reader.read('/media/sihui/000970CB000A4CA8/Sentiment-Analysis/data/semeval14/Restaurants_Test_Gold.xml.seg')

    vocab = Vocabulary.from_instances(source_dataset_train + source_dataset_dev + target_dataset_train + target_dataset_dev)
    word2idx = vocab.get_token_to_index_vocabulary()
    print(word2idx)
    embedding_matrix = build_embedding_matrix(word2idx, 300, './embedding/embedding_res_res.dat', '/media/sihui/000970CB000A4CA8/Sentiment-Analysis/embeddings/glove.42B.300d.txt')

    iterator = BucketIterator(batch_size=args.batch_size, sorting_keys=[('text', 'num_tokens'), ('aspect', 'num_tokens')])
    iterator.index_with(vocab)

    my_net = ACSA2ABSA(args, word_embeddings=embedding_matrix)

    optimizer = optim.Adam(my_net.parameters(), lr=args.learning_rate)
    loss_class = torch.nn.CrossEntropyLoss()
    loss_domain = torch.nn.CrossEntropyLoss()

    my_net = my_net.to(args.device)
    loss_class = loss_class.to(args.device)
    loss_domain = loss_domain.to(args.device)

    n_epoch = args.epoch

    max_test_acc = 0
    best_epoch = 0

    data_target_iter = iter(iterator(target_dataset_train, shuffle=True))
    # iterate over it forever

    for epoch in range(n_epoch):
        len_target_dataloader = iterator.get_num_batches(target_dataset_train)
        len_source_dataloader = iterator.get_num_batches(source_dataset_train)
        data_source_iter = iter(iterator._create_batches(source_dataset_train, shuffle=True))
        # data_target_iter = iter(iterator._create_batches(target_dataset_train, shuffle=True))
        s_correct, s_total = 0, 0
        i = 0
        while i < len_source_dataloader:
            my_net.train()
            p = float(i + epoch * len_target_dataloader) / n_epoch / len_target_dataloader
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            # train model using source data
            data_source = next(data_source_iter).as_tensor_dict()
            s_text, s_aspect, s_label = data_source['text']['tokens'], data_source['aspect']['tokens'], data_source['label']
            batch_size = len(s_label)

            s_domain_label = torch.zeros(batch_size).long().to(args.device)

            my_net.zero_grad()

            s_text, s_aspect, s_label = s_text.to(args.device), s_aspect.to(args.device), s_label.to(args.device)
            s_class_output, s_domain_output = my_net(s_text, s_aspect, alpha)

            err_s_label = loss_class(s_class_output, s_label)
            # err_s_domain = loss_domain(s_domain_output, s_domain_label)

            # training model using target data
            # data_target = next(data_target_iter).as_tensor_dict()
            '''
            data_target = next(data_target_iter)
            t_text, t_aspect, t_label = data_target['text']['tokens'], data_target['aspect']['tokens'], data_target['label']

            batch_size = len(t_label)
            t_domain_label = torch.ones(batch_size).long().to(args.device)

            t_text, t_aspect, t_label = t_text.to(args.device), t_aspect.to(args.device), t_label.to(args.device)

            t_class_output, t_domain_output = my_net(t_text, t_aspect, alpha)
            # err_t_domain = loss_domain(t_domain_output, t_domain_label)
            '''
            # loss = err_t_domain + err_s_domain + err_s_label
            loss = err_s_label
            loss.backward()

            if args.use_grad_clip:
                clip_grad_norm_(my_net.parameters(), args.grad_clip)

            optimizer.step()

            i += 1

            s_correct += (torch.argmax(s_class_output, -1) == s_label).sum().item()
            s_total += len(s_class_output)
            train_acc = s_correct / s_total

            # evaluate every 100 batches
            if i % 100 == 0:
                my_net.eval()
                # evaluate model on source test data
                s_test_correct, s_test_total = 0, 0
                s_targets_all, s_output_all = None, None
                with torch.no_grad():
                    for i_batch, s_test_batch in enumerate(iterator(source_dataset_dev, num_epochs=1, shuffle=False)):
                        s_test_text = s_test_batch['text']['tokens'].to(args.device)
                        s_test_aspect = s_test_batch['aspect']['tokens'].to(args.device)
                        s_test_label = s_test_batch['label'].to(args.device)

                        s_test_output, _ = my_net(s_test_text, s_test_aspect, alpha)

                        s_test_correct += (torch.argmax(s_test_output, -1) == s_test_label).sum().item()
                        s_test_total += len(s_test_label)

                        if s_targets_all is None:
                            s_targets_all = s_test_label
                            s_output_all = s_test_output
                        else:
                            s_targets_all = torch.cat((s_targets_all, s_test_label), dim=0)
                            s_output_all = torch.cat((s_output_all, s_test_output), dim=0)

                s_test_acc = s_test_correct / s_test_total
                if s_test_acc > max_test_acc:
                    max_test_acc = s_test_acc
                    best_epoch = epoch
                    if not os.path.exists('state_dict'):
                        os.mkdir('state_dict')
                    if s_test_acc > 0.868:
                        path = 'state_dict/source_test_epoch{0}_acc_{1}'.format(epoch, round(s_test_acc, 4))
                        torch.save(my_net.state_dict(), path)

                print('epoch: %d, [iter: %d / all %d], loss_s_label: %f, '
                      's_train_acc: %f, s_test_acc: %f' % (
                          epoch, i, len_source_dataloader,
                          err_s_label.cpu().item(),
                          # err_s_domain.cpu().item(),
                          # err_t_domain.cpu().item(),
                          train_acc, s_test_acc))
    print('max_test_acc: {0} in epoch: {1}'.format(max_test_acc, best_epoch))