Example #1
def build_vocab(self, embed_file: str = None) -> Vocab:
    word_counts = Counter()
    count_words(word_counts, [src + tgt for src, tgt in self.pairs])
    vocab = Vocab()
    for word, count in word_counts.most_common(config.max_vocab_size):
        vocab.add_words([word])
    if embed_file is not None:
        count = vocab.load_embeddings(embed_file)
        print("%d pre-trained embeddings loaded." % count)
    return vocab
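
This snippet relies on a project-specific Vocab class with add_words and load_embeddings methods and on a global config.max_vocab_size, none of which are shown here. A minimal self-contained sketch of the same counting-and-adding idea, with hypothetical names, might look like:

from collections import Counter

class MiniVocab:
    """Toy stand-in for the Vocab interface assumed above."""
    def __init__(self):
        self.word2index = {}
        self.index2word = []

    def add_words(self, words):
        for w in words:
            if w not in self.word2index:
                self.word2index[w] = len(self.index2word)
                self.index2word.append(w)

pairs = [(["the", "cat"], ["le", "chat"])]        # (src, tgt) token lists
counts = Counter(w for src, tgt in pairs for w in src + tgt)
vocab = MiniVocab()
for word, _ in counts.most_common(50000):         # 50000 stands in for config.max_vocab_size
    vocab.add_words([word])
print(len(vocab.index2word))                      # -> 4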
Example #2
def main():
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument('--config_path', required=True)
    args = args_parser.parse_args()
    config = Config(None)
    config.load_config(args.config_path)

    logger = get_logger("RSTParser (Top-Down) RUN", config.use_dynamic_oracle,
                        config.model_path)
    word_alpha, tag_alpha, gold_action_alpha, action_label_alpha, relation_alpha, nuclear_alpha, nuclear_relation_alpha, etype_alpha = create_alphabet(
        None, config.alphabet_path, logger)
    vocab = Vocab(word_alpha, tag_alpha, etype_alpha, gold_action_alpha,
                  action_label_alpha, relation_alpha, nuclear_alpha,
                  nuclear_relation_alpha)

    network = MainArchitecture(vocab, config)

    if config.use_gpu:
        network.load_state_dict(torch.load(config.model_name))
        network = network.cuda()
    else:
        network.load_state_dict(
            torch.load(config.model_name, map_location=torch.device('cpu')))

    network.eval()

    logger.info('Reading dev instance, and predict...')
    reader = Reader(config.dev_path, config.dev_syn_feat_path)
    dev_instances = reader.read_data()
    predict(network, dev_instances, vocab, config, logger)

    logger.info('Reading test instance, and predict...')
    reader = Reader(config.test_path, config.test_syn_feat_path)
    test_instances = reader.read_data()
    predict(network, test_instances, vocab, config, logger)
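
Both this example and the next one load a saved state_dict either directly onto the GPU or, via map_location, onto the CPU. A minimal self-contained sketch of that checkpoint-loading pattern, with a stand-in nn.Linear in place of MainArchitecture, might look like:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                            # stand-in for MainArchitecture
torch.save(model.state_dict(), "network.pt")       # mimic a previously saved checkpoint

if torch.cuda.is_available():
    model.load_state_dict(torch.load("network.pt"))
    model = model.cuda()
else:
    model.load_state_dict(
        torch.load("network.pt", map_location=torch.device("cpu")))
model.eval()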
Example #3
def main():
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument('--config_path', required=True)
    args = args_parser.parse_args()
    config = Config(None)
    config.load_config(args.config_path)

    logger = get_logger("RSTParser RUN", config.use_dynamic_oracle,
                        config.model_path)
    word_alpha, tag_alpha, gold_action_alpha, action_label_alpha, etype_alpha = create_alphabet(
        None, config.alphabet_path, logger)
    vocab = Vocab(word_alpha, tag_alpha, etype_alpha, gold_action_alpha,
                  action_label_alpha)

    network = MainArchitecture(vocab, config)
    network.load_state_dict(torch.load(config.model_name))

    if config.use_gpu:
        network = network.cuda()
    network.eval()

    logger.info('Reading test instance')
    reader = Reader(config.test_path, config.test_syn_feat_path)
    test_instances = reader.read_data()
    time_start = datetime.now()
    batch_size = config.batch_size
    span = Metric()
    nuclear = Metric()
    relation = Metric()
    full = Metric()
    predictions = []
    total_data_test = len(test_instances)
    for i in range(0, total_data_test, batch_size):
        end_index = i + batch_size
        if end_index > total_data_test:
            end_index = total_data_test
        indices = np.array(range(i, end_index))
        subset_data_test = batch_data_variable(test_instances, indices, vocab,
                                               config)
        prediction_of_subtrees = network.loss(subset_data_test, None)
        predictions += prediction_of_subtrees
    for i in range(total_data_test):
        span, nuclear, relation, full = test_instances[i].evaluate(
            predictions[i], span, nuclear, relation, full)
    time_elapsed = datetime.now() - time_start
    m, s = divmod(time_elapsed.seconds, 60)
    logger.info('TEST is finished in {} mins {} secs'.format(m, s))
    logger.info("S: " + span.print_metric())
    logger.info("N: " + nuclear.print_metric())
    logger.info("R: " + relation.print_metric())
    logger.info("F: " + full.print_metric())

    import ipdb
    ipdb.set_trace()
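
The test loop above walks the instances in consecutive index slices and accumulates per-instance predictions before scoring them. Stripped of the RST-specific classes, the batching pattern alone can be sketched as:

import numpy as np

def iter_batches(num_items, batch_size):
    # Yield index arrays covering 0..num_items-1 in consecutive slices.
    for start in range(0, num_items, batch_size):
        end = min(start + batch_size, num_items)
        yield np.arange(start, end)

predictions = []
for indices in iter_batches(num_items=10, batch_size=4):
    predictions += list(indices * 2)    # stand-in for network.loss(batch, None)
print(len(predictions))                 # -> 10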
Example #4
def prepare(args):
    """
    Checks the data, creates the directories, and prepares the vocabulary and embeddings.
    """
    logger = logging.getLogger("brc")
    logger.info('Checking the data files...')
    for data_path in args.train_files + args.dev_files + args.test_files:
        assert os.path.exists(data_path), '{} file does not exist.'.format(data_path)
    logger.info('Preparing the directories...')
    for dir_path in [args.vocab_dir, args.model_dir, args.result_dir, args.summary_dir]:
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)

    logger.info('Building vocabulary...')
    word_vocab = Vocab(lower=True)
    char_vocab = Vocab(lower=True)
    '''
        for word in brc_data.word_iter('train'):
        word_vocab.add(word)
        charlst = seg_char(word)
        for char in charlst:
            char_vocab.add(char)
    unfiltered_word_vocab_size = word_vocab.size()
    unfiltered_char_vocab_size = char_vocab.size()
    word_vocab.filter_tokens_by_cnt(min_cnt=4)
    char_vocab.filter_tokens_by_cnt(min_cnt=4)
    filtered_num = unfiltered_word_vocab_size - word_vocab.size()
    logger.info('After filter {} tokens, the final word_vocab size is {}'.format(filtered_num,
                                                                            word_vocab.size()))
    filtered_num = unfiltered_char_vocab_size - char_vocab.size()
    logger.info('After filter {} tokens, the final char_vocab size is {}'.format(filtered_num,
                                                                            char_vocab.size()))
    logger.info('Assigning embeddings...')
    word_vocab.randomly_init_embeddings(args.embed_size)
    char_vocab.randomly_init_embeddings(args.embed_size)

    '''
    word_vocab.load_pretrained_embeddings(args.vocab_dir+'word.emb')
    char_vocab.load_pretrained_embeddings(args.vocab_dir+'char.emb')
    logger.info('Saving word_vocab...')
    with open(os.path.join(args.vocab_dir, 'word_vocab.data'), 'wb') as fout:
        pickle.dump(word_vocab, fout)

    logger.info('Saving char_vocab...')
    with open(os.path.join(args.vocab_dir, 'char_vocab.data'), 'wb') as fout:
        pickle.dump(char_vocab, fout)

    logger.info('Done with preparing!')
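
The vocabularies are persisted with pickle and are presumably read back the same way at training time. A self-contained sketch of that round trip, using a plain dict in place of the Vocab object and a hypothetical directory name, might look like:

import os
import pickle

vocab_dir = './vocab_demo'                      # hypothetical directory
os.makedirs(vocab_dir, exist_ok=True)

word_vocab = {'the': 0, 'cat': 1}               # stand-in for the Vocab object
with open(os.path.join(vocab_dir, 'word_vocab.data'), 'wb') as fout:
    pickle.dump(word_vocab, fout)

with open(os.path.join(vocab_dir, 'word_vocab.data'), 'rb') as fin:
    restored = pickle.load(fin)
print(restored == word_vocab)                   # -> True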
Example #5
            highlights.append(word2id(line))
        else:
            article_lines.append(word2id(line))

    # Make the abstract into a single string, putting <s> and </s> tags around the sentences
    abstract = ' '.join(["%s %s %s" % (vocab.word2id('<s>'), sent, vocab.word2id('</s>')) for sent in highlights])
    abstract = '%s %s %s' % (vocab.word2id(vocab.DECODING_START), abstract, vocab.word2id(vocab.DECODING_STOP))
    
    article = ' '.join(["%s %s %s" % (vocab.word2id('<s>'), sent, vocab.word2id('</s>')) for sent in article_lines])
    article = '%s %s %s' % (vocab.word2id(vocab.DECODING_START), article, vocab.word2id(vocab.DECODING_STOP))
    
    return article, abstract


if __name__ == '__main__':
    vocab = Vocab('../data/vocab')
    tokenized_path = '/home/xuyang/data/tokenized/cnn_stories_tokenized'
    article_path = '/home/xuyang/data/articles'
    abstract_path = '/home/xuyang/data/abstracts'
    file_list = os.listdir(tokenized_path)

    for file in tqdm(file_list, ncols=10):
        articles_line, abstracts_line = get_art_abs(os.path.join(tokenized_path, file))
        
        abstracts = list(map(int, abstracts_line.split()))
        try:
            articles = list(map(int, articles_line.split()))
        except ValueError:
            print('#'+articles_line+'#')
            break
        pickle.dump(articles, open(os.path.join(article_path, file),'wb'))
Example #6
import pickle
import torch
from models.vocab import Vocab
from models.rouge_calc import rouge_1, rouge_2, rouge_3, rouge_n
from models.objective import *
import numpy as np

# Test with toy dataset
with open('../toy_dataset/toy_dataset_word_lvl.p', 'rb') as pickle_file:
    data = pickle.load(pickle_file)

summaries = data[0]
documents = data[1]

print(documents)

vocab = Vocab.load('../toy_dataset/toy_vocab.json')
device = torch.device('cpu')
doc_indices = vocab.to_input_tensor(documents, device)

print(doc_indices)

T = 3
batch_size, doc_len, _ = doc_indices.size()
print('batch_size: %d' % batch_size)
torch.manual_seed(0)
sents_scores = torch.rand(
    batch_size, T, doc_len,
    requires_grad=True)  # shape: (batch_size, T, doc_len)

pred_sents_indices = torch.tensor([[0, 2, 1], [2, 1, 0]],
                                  device=device)  #shape (batch_size, T)
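
This example builds a (batch_size, T, doc_len) tensor of per-step sentence scores and a (batch_size, T) tensor of predicted sentence indices; the objective functions imported from models.objective are not shown. A minimal sketch of how such scores could be reduced to T sentence indices per document (a simple greedy argmax, not necessarily the project's actual decoding rule) is:

import torch

batch_size, T, doc_len = 2, 3, 5
sents_scores = torch.rand(batch_size, T, doc_len)   # one score per sentence per extraction step
pred_sents_indices = sents_scores.argmax(dim=-1)    # greedy pick, shape (batch_size, T)
print(pred_sents_indices.shape)                     # -> torch.Size([2, 3])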
Example #7
def main():
    start_a = time.time()

    args_parser = argparse.ArgumentParser()
    args_parser.add_argument('--word_embedding',
                             default='glove',
                             help='Embedding for words')
    args_parser.add_argument('--word_embedding_file',
                             default=main_path +
                             'Data/NeuralRST/glove.6B.200d.txt.gz')
    args_parser.add_argument('--train',
                             default=main_path + 'Data/NeuralRST/rst.train312')
    args_parser.add_argument('--test',
                             default=main_path + 'Data/NeuralRST/rst.test38')
    args_parser.add_argument('--dev',
                             default=main_path + 'Data/NeuralRST/rst.dev35')
    args_parser.add_argument(
        '--train_syn_feat',
        default=main_path +
        'Data/NeuralRST/SyntaxBiaffine/train.conll.dump.results')
    args_parser.add_argument(
        '--test_syn_feat',
        default=main_path +
        'Data/NeuralRST/SyntaxBiaffine/test.conll.dump.results')
    args_parser.add_argument(
        '--dev_syn_feat',
        default=main_path +
        'Data/NeuralRST/SyntaxBiaffine/dev.conll.dump.results')
    args_parser.add_argument('--model_path',
                             default=main_path +
                             'Workspace/NeuralRST/experiment')
    args_parser.add_argument('--experiment',
                             help='Name of your experiment',
                             required=True)
    args_parser.add_argument('--model_name', default='network.pt')
    args_parser.add_argument('--max_iter',
                             type=int,
                             default=1000,
                             help='maximum epoch')

    args_parser.add_argument('--word_dim',
                             type=int,
                             default=200,
                             help='Dimension of word embeddings')
    args_parser.add_argument('--tag_dim',
                             type=int,
                             default=200,
                             help='Dimension of POS tag embeddings')
    args_parser.add_argument('--etype_dim',
                             type=int,
                             default=100,
                             help='Dimension of Etype embeddings')
    args_parser.add_argument('--syntax_dim',
                             type=int,
                             default=1200,
                             help='Dimension of syntax embeddings')
    args_parser.add_argument(
        '--freeze',
        default=True,
        help='freeze the word embedding (disable fine-tuning).')

    args_parser.add_argument('--max_sent_size',
                             type=int,
                             default=20,
                             help='maximum word size in 1 edu')
    args_parser.add_argument('--max_edu_size',
                             type=int,
                             default=120,
                             help='maximum edu size')
    args_parser.add_argument('--max_state_size',
                             type=int,
                             default=1024,
                             help='maximum decoding steps')
    args_parser.add_argument('--hidden_size', type=int, default=200, help='')

    args_parser.add_argument('--drop_prob',
                             type=float,
                             default=0.2,
                             help='default drop_prob')
    args_parser.add_argument('--num_layers',
                             type=int,
                             default=1,
                             help='number of RNN layers')

    args_parser.add_argument('--batch_size',
                             type=int,
                             default=8,
                             help='Number of sentences in each batch')
    args_parser.add_argument('--lr',
                             type=float,
                             default=0.001,
                             help='Learning rate')
    args_parser.add_argument('--ada_eps',
                             type=float,
                             default=1e-8,
                             help='epsilon for adam or adamax')
    args_parser.add_argument(
        '--opt',
        default='adam',
        help='Optimization, choose between adam, sgd, and adamax')
    args_parser.add_argument('--start_decay', type=int, default=0, help='')

    args_parser.add_argument('--beta1',
                             type=float,
                             default=0.9,
                             help='beta1 for adam')
    args_parser.add_argument('--beta2',
                             type=float,
                             default=0.999,
                             help='beta2 for adam')
    args_parser.add_argument('--gamma',
                             type=float,
                             default=2e-6,
                             help='weight for regularization')
    args_parser.add_argument('--clip',
                             type=float,
                             default=10.0,
                             help='gradient clipping')

    args_parser.add_argument('--decay', type=int, default=0, help='')
    args_parser.add_argument('--oracle_prob',
                             type=float,
                             default=0.66666,
                             help='')
    args_parser.add_argument('--start_dynamic_oracle',
                             type=int,
                             default=20,
                             help='')
    args_parser.add_argument('--use_dynamic_oracle',
                             type=int,
                             default=0,
                             help='')
    args_parser.add_argument('--early_stopping', type=int, default=50, help='')

    args = args_parser.parse_args()
    config = Config(args)

    torch.manual_seed(123)
    if config.use_gpu:
        torch.cuda.manual_seed_all(999)

    if not os.path.exists(config.model_path):
        os.makedirs(config.model_path)

    logger = get_logger("RSTParser", config.use_dynamic_oracle,
                        config.model_path)
    if config.use_dynamic_oracle:
        logger.info(
            "This is using DYNAMIC oracle, and will be activated at Epoch %d" %
            (config.start_dynamic_oracle))
        model_name = 'dynamic_' + config.model_name
    else:
        logger.info("This is using STATIC oracle")
        model_name = 'static_' + config.model_name

    logger.info("Load word embedding")
    pretrained_embed, word_dim = load_embedding_dict(
        config.word_embedding, config.word_embedding_file)
    assert (word_dim == config.word_dim)

    logger.info("Reading Train start")
    reader = Reader(config.train_path, config.train_syn_feat_path)
    train_instances = reader.read_data()
    logger.info('Finish reading training instances: ' +
                str(len(train_instances)))
    # config.max_edu_size, config.max_sent_size, config.max_state_size = get_max_parameter (train_instances)
    logger.info('Max edu size: ' + str(config.max_edu_size))
    logger.info('Max sentence size: ' + str(config.max_sent_size))
    logger.info('Max gold action / state size: ' + str(config.max_state_size))

    logger.info('Creating Alphabet....')
    config.model_name = os.path.join(config.model_path, config.model_name)
    word_alpha, tag_alpha, gold_action_alpha, action_label_alpha, etype_alpha = create_alphabet(
        train_instances, config.alphabet_path, logger)
    vocab = Vocab(word_alpha, tag_alpha, etype_alpha, gold_action_alpha,
                  action_label_alpha)
    set_label_action(action_label_alpha.alpha2id, train_instances)

    logger.info('Checking Gold Actions....')
    validate_gold_actions(train_instances, config.max_state_size)
    word_table = construct_embedding_table(word_alpha, config.word_dim,
                                           config.freeze, pretrained_embed)
    tag_table = construct_embedding_table(tag_alpha, config.tag_dim,
                                          config.freeze)
    etype_table = construct_embedding_table(etype_alpha, config.etype_dim,
                                            config.freeze)

    logger.info("Finish reading train data by: " + str(time.time() - start_a))

    # DEV data processing
    reader = Reader(config.dev_path, config.dev_syn_feat_path)
    dev_instances = reader.read_data()
    logger.info('Finish reading dev instances')

    # TEST data processing
    reader = Reader(config.test_path, config.test_syn_feat_path)
    test_instances = reader.read_data()
    logger.info('Finish reading test instances')

    torch.set_num_threads(4)
    network = MainArchitecture(vocab, config, word_table, tag_table,
                               etype_table)

    if config.freeze:
        network.word_embedd.freeze()
    if config.use_gpu:
        network.cuda()

    # Set-up Optimizer
    def generate_optimizer(config, params):
        params = filter(lambda param: param.requires_grad, params)
        if config.opt == 'adam':
            return Adam(params,
                        lr=config.lr,
                        betas=config.betas,
                        weight_decay=config.gamma,
                        eps=config.ada_eps)
        elif config.opt == 'sgd':
            return SGD(params,
                       lr=config.lr,
                       momentum=config.momentum,
                       weight_decay=config.start_decay,
                       nesterov=True)
        elif config.opt == 'adamax':
            return Adamax(params,
                          lr=config.lr,
                          betas=config.betas,
                          weight_decay=config.start_decay,
                          eps=config.ada_eps)
        else:
            raise ValueError('Unknown optimization algorithm: %s' % config.opt)

    optim = generate_optimizer(config, network.parameters())
    opt_info = 'opt: %s, ' % config.opt
    if config.opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e, lr=%.2f, weight_decay=%.1e' % (
            config.betas, config.ada_eps, config.lr, config.gamma)
    elif config.opt == 'sgd':
        opt_info += 'momentum=%.2f' % config.momentum
    elif config.opt == 'adamax':
        opt_info += 'betas=%s, eps=%.1e, lr=%f' % (config.betas,
                                                   config.ada_eps, config.lr)

    logger.info(opt_info)

    def get_subtrees(data, indices):
        subtrees = []
        for i in indices:
            subtrees.append(data[i].result)
        return subtrees

    # START TRAINING
    config.save()
    batch_size = config.batch_size
    logger.info('Start doing training....')
    total_data = len(train_instances)
    logger.info('Batch size: %d' % batch_size)
    num_batch = total_data // batch_size + 1
    es_counter = 0
    best_S = 0
    best_N = 0
    best_R = 0
    best_F = 0
    iteration = -1

    for epoch in range(0, config.max_iter):
        logger.info('Epoch %d ' % (epoch))
        logger.info("Current learning rate: %.4f" % (config.lr))

        if epoch == config.start_dynamic_oracle and config.use_dynamic_oracle:
            logger.info("In this epoch, dynamic oracle is activated!")
            config.flag_oracle = True

        permutation = torch.randperm(total_data).long()
        network.metric.reset()
        time_start = datetime.now()
        for i in range(0, total_data, batch_size):
            network.train()
            network.training = True

            indices = permutation[i:i + batch_size]
            # subset_data =  words_var, tags_var, etypes_var, edu_mask_var, word_mask_var, gold_actions_var, len_edus, word_denominator, syntax
            subset_data = batch_data_variable(train_instances, indices, vocab,
                                              config)
            gold_subtrees = get_subtrees(train_instances, indices)

            cost, cost_val = network.loss(subset_data, gold_subtrees)
            cost.backward()
            clip_grad_norm(network.parameters(), config.clip)
            optim.step()
            network.zero_grad()
            time_elapsed = datetime.now() - time_start
            m, s = divmod(time_elapsed.seconds, 60)
            logger.info(
                'Epoch %d, Batch %d, Cost: %.2f, Correct: %.2f, {} mins {} secs'
                .format(m, s) % (epoch, (i + batch_size) / batch_size,
                                 cost_val, network.metric.get_accuracy()))
        logger.info('Batch ends, performing test for DEV set')

        # START EVALUATING DEV:
        network.eval()
        network.training = False
        time_start = datetime.now()
        span = Metric()
        nuclear = Metric()
        relation = Metric()
        full = Metric()
        predictions = []
        total_data_dev = len(dev_instances)
        for i in range(0, total_data_dev, batch_size):
            end_index = i + batch_size
            if end_index > total_data_dev:
                end_index = total_data_dev
            indices = np.array((range(i, end_index)))
            subset_data_dev = batch_data_variable(dev_instances, indices,
                                                  vocab, config)
            prediction_of_subtrees = network.loss(subset_data_dev, None)
            predictions += prediction_of_subtrees
        for i in range(total_data_dev):
            span, nuclear, relation, full = dev_instances[i].evaluate(
                predictions[i], span, nuclear, relation, full)
        time_elapsed = datetime.now() - time_start
        m, s = divmod(time_elapsed.seconds, 60)
        logger.info('DEV is finished in {} mins {} secs'.format(m, s))
        logger.info("S: " + span.print_metric())
        logger.info("N: " + nuclear.print_metric())
        logger.info("R: " + relation.print_metric())
        logger.info("F: " + full.print_metric())

        if best_F < full.get_f_measure():
            best_S = span.get_f_measure()
            best_N = nuclear.get_f_measure()
            best_R = relation.get_f_measure()
            best_F = full.get_f_measure()
            iteration = epoch
            #save the model
            config.save()
            torch.save(network.state_dict(), config.model_name)
            logger.info('Model is successfully saved')
            es_counter = 0
        else:
            logger.info(
                "NOT exceed best Full F-score: history = %.2f, current = %.2f"
                % (best_F, full.get_f_measure()))
            logger.info(
                "Best dev performance in Iteration %d with result S: %.4f, N: %.4f, R: %.4f, F: %.4f"
                % (iteration, best_S, best_N, best_R, best_F))
            if es_counter > config.early_stopping:
                logger.info(
                    'Early stopping after getting lower DEV performance in %d consecutive epoch. BYE, Assalamualaikum!'
                    % (es_counter))
                sys.exit()
            es_counter += 1
        # # START EVALUATING TEST:
        time_start = datetime.now()
        span = Metric()
        nuclear = Metric()
        relation = Metric()
        full = Metric()
        predictions = []
        total_data_test = len(test_instances)
        for i in range(0, total_data_test, batch_size):
            end_index = i + batch_size
            if end_index > total_data_test:
                end_index = total_data_test
            indices = np.array(range(i, end_index))
            subset_data_test = batch_data_variable(test_instances, indices,
                                                   vocab, config)
            prediction_of_subtrees = network.loss(subset_data_test, None)
            predictions += prediction_of_subtrees
        for i in range(total_data_test):
            span, nuclear, relation, full = test_instances[i].evaluate(
                predictions[i], span, nuclear, relation, full)
        time_elapsed = datetime.now() - time_start
        m, s = divmod(time_elapsed.seconds, 60)
        logger.info('TEST is finished in {} mins {} secs'.format(m, s))
        logger.info("S: " + span.print_metric())
        logger.info("N: " + nuclear.print_metric())
        logger.info("R: " + relation.print_metric())
        logger.info("F: " + full.print_metric())
Example #8
import os
import pickle

from tqdm import trange
from models.vocab import Vocab
# `Model` and `cpu` are not imported in this snippet; `ctx=cpu()` and `.asscalar()`
# suggest an MXNet context and a model class defined elsewhere in the project.

model_param = {'emb_size': 200, 'hidden_size': 200}
vocab_path = 'data/vocab'
source_path = 'data/cnn_articles'
target_path = 'data/cnn_abstracts'
lda_path = 'data/cnn_head_lda'

model = Model(model_param,
              vocab_path,
              mode='decode',
              head_attention=True,
              decoder_cell='dlstm',
              ctx=cpu())

res = model.decode(source_path, lda_path, 'best.model')

res = [int(i.asscalar()) for i in res.tokens]

vocab = Vocab(vocab_path)

res = [vocab.id2word(i) for i in res]
print(' '.join(res))

abstract = pickle.load(
    open(os.path.join(target_path,
                      os.listdir(target_path)[0]), 'rb'))
res = [vocab.id2word(i) for i in abstract]
print(' '.join(res))
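
This example assumes a Vocab(vocab_path) object exposing id2word; the final decoding step just maps integer ids back to tokens and joins them. With a toy dictionary in place of the real vocabulary, that step amounts to:

id2word = {0: '<s>', 1: 'the', 2: 'cat', 3: '</s>'}   # toy vocabulary
ids = [0, 1, 2, 3]
print(' '.join(id2word[i] for i in ids))              # -> <s> the cat </s>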
Example #9
def main(emb_path='glove.6B.100d.txt', data_path='data/msdialogue/'):
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    print(f'DEVICE : {device}')
    params = {'batch_size': 128, 'shuffle': True}

    # 1) Data loading
    # # Kept like this for debugging, for now
    # X, y = load_from_json(data_path)
    # # 1. One-Hot Encode
    # labels = {'O': 0, 'FQ': 1, 'IR': 2,
    #           'OQ': 3, 'GG': 4, 'FD': 5,
    #           'JK': 6, 'NF': 7, 'PF': 8,
    #           'RQ': 9, 'CQ': 10, 'PA': 11}
    # y_train = []
    # for l in y:
    #     l = l.split('_')
    #     cur_y = [0] * len(labels)
    #     for un_l in l:
    #         cur_y[labels[un_l]] = 1
    #     y_train.append(cur_y)
    # y_train = torch.tensor(y_train)
    # # 2. Reshape into the required form
    # X_train = []
    # for i in range(len(X)):
    #     for j in range(len(X[i])):
    #         X_train.append(X[i][j])
    print('Building Embedding')
    if emb_path == 'glove.6B.100d.txt':
        tmp_file = get_tmpfile("test_word2vec.txt")
        _ = glove2word2vec(emb_path, tmp_file)
        word2vec = KeyedVectors.load_word2vec_format(tmp_file)
    else:
        word2vec = gensim.models.KeyedVectors.load_word2vec_format(emb_path,
                                                                   binary=True)
    EMB_DIM = word2vec.vectors.shape[1]
    word2vec.add('<UNK>', np.mean(word2vec.vectors.astype('float32'), axis=0))
    word2vec.add('<PAD>', np.array(np.zeros(EMB_DIM)))
    tokenizer = Vocab()
    tokenizer.build(word2vec)

    print('Loading Data')
    X_train = pd.read_csv(data_path + "train.tsv",
                          sep="\t",
                          header=None,
                          index_col=None)
    y_train = encode_label(X_train[0].to_numpy())
    X_train = tokenizer.tokenize(X_train[1].to_numpy(), max_len=MAX_SEQ_LEN)

    X_val = pd.read_csv(data_path + "valid.tsv",
                        sep="\t",
                        header=None,
                        index_col=None)
    y_val = encode_label(X_val[0].to_numpy())
    X_val = tokenizer.tokenize(X_val[1].to_numpy(), max_len=MAX_SEQ_LEN)

    X_test = pd.read_csv(data_path + "test.tsv",
                         sep="\t",
                         header=None,
                         index_col=None)
    y_test = encode_label(X_test[0].to_numpy())
    X_test = tokenizer.tokenize(X_test[1].to_numpy(), max_len=MAX_SEQ_LEN)

    # 2. padding
    pad_val = tokenizer.get_pad()
    X_train = pad_sequence(
        X_train, batch_first=True, padding_value=pad_val).to(
            torch.long)[1:, :MAX_SEQ_LEN]  # size: tensor(batch, max_seq_len)
    X_val = pad_sequence(X_val, batch_first=True, padding_value=pad_val).to(
        torch.long)[1:, :MAX_SEQ_LEN]
    X_test = pad_sequence(X_test, batch_first=True, padding_value=pad_val).to(
        torch.long)[1:, :MAX_SEQ_LEN]

    # 3) Batch iterator
    training = data.DataLoader(MSDialog(X_train, y_train), **params)
    validation = data.DataLoader(MSDialog(X_val, y_val), **params)
    testing = data.DataLoader(MSDialog(X_test, y_test), **params)

    # 4) Model, criterion and optimizer
    model = BaseCNN(word2vec, tokenizer.get_pad(), emb_dim=EMB_DIM).to(device)
    optimizer = Adam(model.parameters(),
                     lr=0.001,
                     betas=(0.9, 0.999),
                     eps=1e-08)
    criterion = nn.BCELoss()
    # 5) training process
    threshold = 0.5
    print('Train')
    #     for X, y in training:
    #         X, y = X.to(device), y.to(device)
    #         break
    for ep in range(N_EPOCHS):
        if ep == 10:
            optimizer = Adam(model.parameters(),
                             lr=0.0001,
                             betas=(0.9, 0.999),
                             eps=1e-08)
        print(f'epoch: {ep}')
        #         j = 0
        #         # model.train()
        #         losses = []
        #         for i in range(50):
        #             optimizer.zero_grad()

        #             output = model(X)
        #             loss = torch.tensor(0.0).to(output)
        #             for i in range(output.shape[1]):
        #                 criterion = nn.BCELoss()
        #                 loss += criterion(output[:, i].unsqueeze(1), y[:, i].unsqueeze(1).to(torch.float32))
        #             losses.append(float(loss.cpu())/output.shape[1])
        #             loss.backward()
        #             optimizer.step()

        #             # print(f'iter: {j}, loss: {loss}')
        #             j += 1
        #         print(f'train loss={np.mean(losses)}')

        j = 0
        model.train()
        losses = []
        for X, y in training:
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            output = model(X)
            loss = torch.tensor(0.0).to(output)
            for i in range(output.shape[1]):
                criterion = nn.BCELoss()
                loss += criterion(output[:, i].unsqueeze(1),
                                  y[:, i].unsqueeze(1).to(torch.float32))
            loss.backward()
            losses.append(float(loss.cpu()))
            optimizer.step()

            # print(f'iter: {j}, loss: {loss}')
            j += 1
        print(f'train loss={np.mean(losses)}')
        with torch.no_grad():
            model.eval()
            # print('EVALUATION________')
            losses = []
            f1_scores = []
            precisions = []
            recalls = []
            accuracies = []
            for X, y in validation:
                criterion = nn.MultiLabelSoftMarginLoss()

                X, y = X.to(device), y.to(device)

                output = model(X)
                loss = torch.tensor(0.0).to(output)
                for i in range(output.shape[1]):
                    criterion = nn.BCELoss()
                    loss += criterion(output[:, i].unsqueeze(1),
                                      y[:, i].unsqueeze(1).to(torch.float32))
                losses.append(float(loss.cpu()))
                output = output.cpu().numpy()
                for i in range(len(output)):
                    pred = output[i] > threshold
                    if sum(pred) == 0:
                        pred = output[i].max(axis=0, keepdims=1) == output[i]
                    output[i] = pred
                precisions.append(get_f1(y, output)[0])
                recalls.append(get_f1(y, output)[1])
                f1_scores.append(get_f1(y, output)[2])
                accuracies.append(get_accuracy(y, output))

            print('VAL:')
            print(f'val_loss={np.mean(losses)}')
            print(f'accuracy={np.mean(accuracies)}')
            print(f'precision={np.mean(precisions)}')
            print(f'recall={np.mean(recalls)}')
            print(f'f1-score={np.mean(f1_scores)}')

            print('__________________')
    torch.save(model.state_dict(), SAVE_PATH)
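
During validation the script thresholds each row of multi-label scores and falls back to the single highest-scoring label when nothing passes the threshold. A self-contained sketch of just that decision rule is:

import numpy as np

def binarize(row, threshold=0.5):
    # Threshold multi-label scores; fall back to the best single label if none pass.
    pred = row > threshold
    if pred.sum() == 0:
        pred = row == row.max()
    return pred.astype(int)

print(binarize(np.array([0.2, 0.7, 0.6])))   # -> [0 1 1]
print(binarize(np.array([0.2, 0.3, 0.1])))   # -> [0 1 0]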
Example #10
def main():
    start_a = time.time()
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument('--word_embedding',
                             default='glove',
                             help='Embedding for words')
    args_parser.add_argument('--word_embedding_file',
                             default=main_path +
                             'Data/NeuralRST/glove.6B.200d.txt.gz')
    args_parser.add_argument('--train',
                             default=main_path + 'Data/NeuralRST/rst.train312')
    args_parser.add_argument('--test',
                             default=main_path + 'Data/NeuralRST/rst.test38')
    args_parser.add_argument('--dev',
                             default=main_path + 'Data/NeuralRST/rst.dev35')
    args_parser.add_argument(
        '--train_syn_feat',
        default=main_path +
        'Data/NeuralRST/SyntaxBiaffine/train.conll.dump.results')
    args_parser.add_argument(
        '--test_syn_feat',
        default=main_path +
        'Data/NeuralRST/SyntaxBiaffine/test.conll.dump.results')
    args_parser.add_argument(
        '--dev_syn_feat',
        default=main_path +
        'Data/NeuralRST/SyntaxBiaffine/dev.conll.dump.results')
    args_parser.add_argument('--model_path',
                             default=main_path + 'Data/NeuralRST/experiment')
    args_parser.add_argument('--experiment',
                             help='Name of your experiment',
                             required=True)
    args_parser.add_argument('--model_name', default='network.pt')
    args_parser.add_argument('--max_iter',
                             type=int,
                             default=1000,
                             help='maximum epoch')

    args_parser.add_argument('--word_dim',
                             type=int,
                             default=200,
                             help='Dimension of word embeddings')
    args_parser.add_argument('--tag_dim',
                             type=int,
                             default=200,
                             help='Dimension of POS tag embeddings')
    args_parser.add_argument('--etype_dim',
                             type=int,
                             default=100,
                             help='Dimension of Etype embeddings')
    args_parser.add_argument('--syntax_dim',
                             type=int,
                             default=1200,
                             help='Dimension of syntax embeddings')
    args_parser.add_argument(
        '--freeze',
        default=True,
        help='freeze the word embedding (disable fine-tuning).')

    args_parser.add_argument('--max_sent_size',
                             type=int,
                             default=20,
                             help='maximum word size in 1 edu')
    # two args below are not used in top-down discourse parsing
    args_parser.add_argument('--max_edu_size',
                             type=int,
                             default=400,
                             help='maximum edu size')
    args_parser.add_argument('--max_state_size',
                             type=int,
                             default=1024,
                             help='maximum decoding steps')

    args_parser.add_argument('--hidden_size', type=int, default=200, help='')
    args_parser.add_argument('--hidden_size_tagger',
                             type=int,
                             default=100,
                             help='')

    args_parser.add_argument('--drop_prob',
                             type=float,
                             default=0.5,
                             help='default drop_prob')
    args_parser.add_argument('--num_layers',
                             type=int,
                             default=1,
                             help='number of RNN layers')

    args_parser.add_argument('--batch_size',
                             type=int,
                             default=4,
                             help='Number of sentences in each batch')
    args_parser.add_argument('--lr',
                             type=float,
                             default=0.001,
                             help='Learning rate')
    args_parser.add_argument('--ada_eps',
                             type=float,
                             default=1e-6,
                             help='epsilon for adam or adamax')
    args_parser.add_argument(
        '--opt',
        default='adam',
        help='Optimization, choose between adam, sgd, and adamax')
    args_parser.add_argument('--start_decay', type=int, default=0, help='')
    args_parser.add_argument('--grad_accum',
                             type=int,
                             default=2,
                             help='gradient accumulation setting')

    args_parser.add_argument('--beta1',
                             type=float,
                             default=0.9,
                             help='beta1 for adam')
    args_parser.add_argument('--beta2',
                             type=float,
                             default=0.999,
                             help='beta2 for adam')
    args_parser.add_argument('--gamma',
                             type=float,
                             default=1e-6,
                             help='weight for regularization')
    args_parser.add_argument('--clip',
                             type=float,
                             default=10.0,
                             help='gradient clipping')

    args_parser.add_argument('--loss_nuc_rel',
                             type=float,
                             default=1.0,
                             help='weight for nucleus and relation loss')
    args_parser.add_argument('--loss_seg',
                             type=float,
                             default=1.0,
                             help='weight for segmentation loss')
    args_parser.add_argument('--activate_nuc_rel_loss',
                             type=int,
                             default=0,
                             help='set the starting epoch for nuclear loss')
    args_parser.add_argument(
        '--activate_seg_loss',
        type=int,
        default=0,
        help='set the starting epoch for segmentation loss')

    args_parser.add_argument('--decay', type=int, default=0, help='')
    args_parser.add_argument('--oracle_prob',
                             type=float,
                             default=0.66666,
                             help='')
    args_parser.add_argument('--start_dynamic_oracle',
                             type=int,
                             default=20,
                             help='')
    args_parser.add_argument('--use_dynamic_oracle',
                             type=int,
                             default=0,
                             help='')
    args_parser.add_argument('--early_stopping', type=int, default=50, help='')

    args_parser.add_argument('--beam_search',
                             type=int,
                             default=1,
                             help='assign parameter k for beam search')
    args_parser.add_argument(
        '--depth_alpha',
        type=float,
        default=0.0,
        help='multiplier of loss based on depth of the subtree')
    args_parser.add_argument(
        '--elem_alpha',
        type=float,
        default=0.0,
        help='multiplier of loss based on number of element in a subtree')
    args_parser.add_argument('--seed',
                             type=int,
                             default=999,
                             help='random seed')

    args = args_parser.parse_args()
    config = Config(args)

    torch.manual_seed(config.seed)
    if config.use_gpu:
        torch.cuda.manual_seed_all(config.seed)

    if not os.path.exists(config.model_path):
        os.makedirs(config.model_path)

    logger = get_logger("RSTParser", config.use_dynamic_oracle,
                        config.model_path)
    if config.use_dynamic_oracle:
        logger.info(
            "This is using DYNAMIC oracle, and will be activated at Epoch %d" %
            (config.start_dynamic_oracle))
        model_name = 'dynamic_' + config.model_name
    else:
        logger.info("This is using STATIC oracle")
        model_name = 'static_' + config.model_name

    logger.info("Load word embedding, will take 2 minutes")
    pretrained_embed, word_dim = load_embedding_dict(
        config.word_embedding, config.word_embedding_file)
    assert (word_dim == config.word_dim)

    logger.info("Reading Train start")
    reader = Reader(config.train_path, config.train_syn_feat_path)
    train_instances = reader.read_data()
    logger.info('Finish reading training instances: ' +
                str(len(train_instances)))
    logger.info('Max sentence size: ' + str(config.max_sent_size))

    logger.info('Creating Alphabet....')
    config.model_name = os.path.join(config.model_path, config.model_name)
    word_alpha, tag_alpha, gold_action_alpha, action_label_alpha, relation_alpha, nuclear_alpha, nuclear_relation_alpha, etype_alpha = create_alphabet(
        train_instances, config.alphabet_path, logger)
    vocab = Vocab(word_alpha, tag_alpha, etype_alpha, gold_action_alpha,
                  action_label_alpha, relation_alpha, nuclear_alpha,
                  nuclear_relation_alpha)
    set_label_action(action_label_alpha.alpha2id, train_instances)

    # logger.info('Checking Gold Actions for transition-based parser....')
    # validate_gold_actions(train_instances, config.max_state_size)
    logger.info('Checking Gold Labels for top-down parser....')
    validate_gold_top_down(train_instances)

    word_table = construct_embedding_table(word_alpha, config.word_dim,
                                           config.freeze, pretrained_embed)
    tag_table = construct_embedding_table(tag_alpha, config.tag_dim,
                                          config.freeze)
    etype_table = construct_embedding_table(etype_alpha, config.etype_dim,
                                            config.freeze)

    logger.info("Finish reading train data by:" +
                str(round(time.time() - start_a, 2)) + 'sec')

    # DEV data processing
    reader = Reader(config.dev_path, config.dev_syn_feat_path)
    dev_instances = reader.read_data()
    logger.info('Finish reading dev instances')

    # TEST data processing
    reader = Reader(config.test_path, config.test_syn_feat_path)
    test_instances = reader.read_data()
    logger.info('Finish reading test instances')

    torch.set_num_threads(4)
    network = MainArchitecture(vocab, config, word_table, tag_table,
                               etype_table)

    if config.freeze:
        network.word_embedd.freeze()
    if config.use_gpu:
        network.cuda()

    # Set-up Optimizer
    def generate_optimizer(config, params):
        params = filter(lambda param: param.requires_grad, params)
        if config.opt == 'adam':
            return Adam(params,
                        lr=config.lr,
                        betas=config.betas,
                        weight_decay=config.gamma,
                        eps=config.ada_eps)
        elif config.opt == 'sgd':
            return SGD(params,
                       lr=config.lr,
                       momentum=config.momentum,
                       weight_decay=config.start_decay,
                       nesterov=True)
        elif config.opt == 'adamax':
            return Adamax(params,
                          lr=config.lr,
                          betas=config.betas,
                          weight_decay=config.start_decay,
                          eps=config.ada_eps)
        else:
            raise ValueError('Unknown optimization algorithm: %s' % config.opt)

    optim = generate_optimizer(config, network.parameters())
    opt_info = 'opt: %s, ' % config.opt
    if config.opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e, lr=%.5f, weight_decay=%.1e' % (
            config.betas, config.ada_eps, config.lr, config.gamma)
    elif config.opt == 'sgd':
        opt_info += 'momentum=%.2f' % config.momentum
    elif config.opt == 'adamax':
        opt_info += 'betas=%s, eps=%.1e, lr=%f' % (config.betas,
                                                   config.ada_eps, config.lr)

    logger.info(opt_info)

    def get_subtrees(data, indices):
        subtrees = []
        for i in indices:
            subtrees.append(data[i].result)
        return subtrees

    # START TRAINING
    config.save()
    batch_size = config.batch_size
    logger.info('Start doing training....')
    total_data = len(train_instances)
    logger.info('Batch size: %d' % batch_size)
    num_batch = total_data // batch_size + 1
    es_counter = 0
    best_S = 0
    best_S_ori = 0
    best_N = 0
    best_N_ori = 0
    best_R = 0
    best_R_ori = 0
    best_F = 0
    best_F_ori = 0
    iteration = -1

    for epoch in range(0, config.max_iter):
        logger.info('Epoch %d ' % (epoch))
        logger.info("Current learning rate: %.5f" % (config.lr))

        if epoch == config.start_dynamic_oracle and config.use_dynamic_oracle:
            logger.info("In this epoch, dynamic oracle is activated!")
            config.flag_oracle = True

        permutation = torch.randperm(total_data).long()
        network.metric_span.reset()
        network.metric_nuclear_relation.reset()
        time_start = datetime.now()
        costs = []
        counter_acc = 0
        for i in range(0, total_data, batch_size):
            network.train()
            network.training = True

            indices = permutation[i:i + batch_size]
            subset_data = batch_data_variable(train_instances, indices, vocab,
                                              config)
            gold_subtrees = get_subtrees(train_instances, indices)

            cost, cost_val = network.loss(subset_data,
                                          gold_subtrees,
                                          epoch=epoch)
            costs.append(cost_val)
            cost.backward()
            counter_acc += 1

            if config.grad_accum > 1 and counter_acc == config.grad_accum:
                clip_grad_norm_(network.parameters(), config.clip)
                optim.step()
                network.zero_grad()
                counter_acc = 0
            elif config.grad_accum == 1:
                optim.step()
                network.zero_grad()
                counter_acc = 0

            time_elapsed = datetime.now() - time_start
            m, s = divmod(time_elapsed.seconds, 60)
            logger.info(
                'Epoch %d, Batch %d, AvgCost: %.2f, CorrectSpan: %.2f, CorrectNuclearRelation: %.2f - {} mins {} secs'
                .format(m, s) %
                (epoch, (i + batch_size) / batch_size, np.mean(costs),
                 network.metric_span.get_accuracy(),
                 network.metric_nuclear_relation.get_accuracy()))

        # Perform evaluation if the span accuracy is at least 0.8 OR when the dynamic oracle is activated
        if network.metric_span.get_accuracy() < 0.8 and not config.flag_oracle:
            logger.info(
                'We only perform test for DEV and TEST set if the span accuracy >= 0.80'
            )
            continue

        logger.info('Batch ends, performing test for DEV and TEST set')
        # START EVALUATING DEV:
        network.eval()
        network.training = False

        logger.info('Evaluate DEV:')
        span, nuclear, relation, full, span_ori, nuclear_ori, relation_ori, full_ori =\
                predict(network, dev_instances, vocab, config, logger)

        if best_F < full.get_f_measure():
            best_S = span.get_f_measure()
            best_S_ori = span_ori.get_f_measure()
            best_N = nuclear.get_f_measure()
            best_N_ori = nuclear_ori.get_f_measure()
            best_R = relation.get_f_measure()
            best_R_ori = relation_ori.get_f_measure()
            best_F = full.get_f_measure()
            best_F_ori = full_ori.get_f_measure()
            iteration = epoch
            #save the model
            config.save()
            torch.save(network.state_dict(), config.model_name)
            logger.info('Model is successfully saved')
            es_counter = 0
        else:
            logger.info(
                "NOT exceed best Full F-score: history = %.2f, current = %.2f"
                % (best_F, full.get_f_measure()))
            logger.info(
                "Best dev performance in Iteration %d with result S (rst): %.4f, N (rst): %.4f, R (rst): %.4f, F (rst): %.4f"
                % (iteration, best_S, best_N, best_R, best_F))
            #logger.info("Best dev performance in Iteration %d with result S (ori): %.4f, N (ori): %.4f, R (ori): %.4f, F (ori): %.4f" %(iteration, best_S_ori, best_N_ori, best_R_ori, best_F_ori))
            if es_counter > config.early_stopping:
                logger.info(
                    'Early stopping after getting lower DEV performance in %d consecutive epoch. BYE, Assalamualaikum!'
                    % (es_counter))
                sys.exit()
            es_counter += 1

        # START EVALUATING TEST:
        logger.info('Evaluate TEST:')
        span, nuclear, relation, full, span_ori, nuclear_ori, relation_ori, full_ori =\
                predict(network, test_instances, vocab, config, logger)