Example No. 1
def _main(_):
    '''
    # Create output_path
    if os.path.exists(output_path):
        logger.error('output path {} already exists'.format(output_path))
        raise ValueError('output path {} already exists'.format(output_path))
    os.mkdir(output_path)
    os.mkdir('{}/src'.format(output_path))
    os.system('cp *.py {}/src'.format(output_path))
    os.system('cp models/*.py {}/src'.format(output_path))
    os.system('cp utils_data/*.py {}/src'.format(output_path))
    '''
    if not os.path.exists(output_path):
        os.mkdir(output_path)
        os.mkdir('{}/src'.format(output_path))
        os.system('cp *.py {}/src'.format(output_path))
        os.system('cp models/*.py {}/src'.format(output_path))
        os.system('cp utils_data/*.py {}/src'.format(output_path))

    # clean sample_path and checkpoint_path before training
    if tf.gfile.Exists(config.sample_path):
        tf.gfile.DeleteRecursively(config.sample_path)
    if tf.gfile.Exists(config.checkpoint_path):
        tf.gfile.DeleteRecursively(config.checkpoint_path)
    tf.gfile.MakeDirs(config.sample_path)
    tf.gfile.MakeDirs(config.checkpoint_path)

    # Data
    train_data = MultiAlignedNumpyData(config.train_data)
    val_data = MultiAlignedNumpyData(config.val_data)
    test_data = MultiAlignedNumpyData(config.test_data)
    vocab = train_data.vocab(0)

    # Each training batch is used twice: once for updating the generator and
    # once for updating the discriminator. A feedable data iterator is used
    # for this purpose.
    iterator = tx.data.FeedableDataIterator({
        'train': train_data,
        'val': val_data,
        'test': test_data
    })
    batch = iterator.get_next()
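    # The feedable iterator registers each split under a name; at run time the
    # split to read from is selected by feeding that split's string handle,
    # e.g.:
    #   feed_dict = {iterator.handle: iterator.get_handle(sess, 'train')}
    #   sess.run(model.fetches_train_d, feed_dict=feed_dict)
    # as done in _train_epoch and _eval_epoch below.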

    # Model
    model = RELA_CLASS(batch, vocab, config.model)

    def _train_epoch(sess,
                     epoch,
                     adjs_true_list,
                     adjs_preds_list,
                     verbose=True):
        avg_meters_d = tx.utils.AverageRecorder(size=10)

        step = 0
        while True:
            try:
                step += 1
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train'),
                }
                vals_d = sess.run(model.fetches_train_d, feed_dict=feed_dict)
                adjs_truth = np.reshape(vals_d.pop("adjs_truth"),
                                        [-1])  # [128,17,17]
                adjs_preds = np.reshape(vals_d.pop("adjs_preds"), [-1])
                adjs_true_list.extend(adjs_truth)
                adjs_preds_list.extend(adjs_preds)
                avg_meters_d.add(vals_d)

                if verbose and (step == 1 or step % config.display == 0):
                    logger.info('step: {}, {}'.format(step,
                                                      avg_meters_d.to_str(4)))

                if verbose and step % config.display_eval == 0:
                    iterator.restart_dataset(sess, 'val')
                    tmp_a = []
                    tmp_b = []
                    _eval_epoch(sess, epoch, tmp_a, tmp_b)

            except tf.errors.OutOfRangeError:
                logger.info('epoch: {}, {}'.format(epoch,
                                                   avg_meters_d.to_str(4)))
                break

        return adjs_true_list, adjs_preds_list

    def _eval_epoch(sess,
                    epoch,
                    adjs_true_list,
                    adjs_preds_list,
                    val_or_test='val'):
        avg_meters = tx.utils.AverageRecorder()

        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, val_or_test),
                    tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
                }

                vals = sess.run(model.fetches_eval, feed_dict=feed_dict)
                adjs_truth = np.reshape(vals.pop("adjs_truth"),
                                        [-1])  # [128,17,17]
                adjs_preds = np.reshape(vals.pop("adjs_preds"), [-1])
                adjs_true_list.extend(adjs_truth)
                adjs_preds_list.extend(adjs_preds)

                batch_size = vals.pop('batch_size')
                avg_meters.add(vals, weight=batch_size)
                '''
                # Writes samples
                tx.utils.write_paired_text(
                    refs.squeeze(), hyps,
                    os.path.join(config.sample_path, 'val.%d'%epoch),
                    append=True, mode='v')
                '''

            except tf.errors.OutOfRangeError:
                logger.info('{}: {}'.format(val_or_test,
                                            avg_meters.to_str(precision=4)))
                break

        return avg_meters.avg(), adjs_true_list, adjs_preds_list

    tf.gfile.MakeDirs(config.sample_path)
    tf.gfile.MakeDirs(config.checkpoint_path)

    # Runs the logic
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        saver = tf.train.Saver(max_to_keep=None)
        if config.restore:
            logger.info('Restore from: {}'.format(config.restore))
            saver.restore(sess, config.restore)

        iterator.initialize_dataset(sess)

        test_true = []
        test_preds = []
        for epoch in range(1, config.max_nepochs + 1):
            val_adjs_true = []
            val_adjs_preds = []
            test_adjs_true = []
            test_adjs_preds = []
            train_adjs_true = []
            train_adjs_preds = []

            #logger.info('gamma: {}'.format(gamma_))

            # Train
            iterator.restart_dataset(sess, ['train'])
            train_adjs_true, train_adjs_preds = _train_epoch(
                sess, epoch, train_adjs_true, train_adjs_preds)

            # Val
            iterator.restart_dataset(sess, 'val')
            _, val_adjs_true, val_adjs_preds = _eval_epoch(
                sess, epoch, val_adjs_true, val_adjs_preds, 'val')

            saver.save(sess, os.path.join(config.checkpoint_path, 'ckpt'),
                       epoch)

            # Test
            iterator.restart_dataset(sess, 'test')
            _, test_adjs_true, test_adjs_preds = _eval_epoch(
                sess, epoch, test_adjs_true, test_adjs_preds, 'test')

            if epoch == config.max_nepochs:
                test_true = test_adjs_true
                test_preds = test_adjs_preds

            #plot_confusion_matrix(train_adjs_true, train_adjs_preds, classes=['1', '0'],
            #          title='Train Confusion matrix, without normalization')
            #plot_confusion_matrix(val_adjs_true, val_adjs_preds, classes=['1', '0'],
            #          title='Val Confusion matrix, without normalization')
        plot_confusion_matrix(
            test_true,
            test_preds,
            classes=np.array(["non-relevant", "relevant"]),
            normalize=True,
            title='Test Confusion matrix, with normalization')
        plot_confusion_matrix(
            test_true,
            test_preds,
            classes=np.array(["non-relevant", "relevant"]),
            normalize=False,
            title='Test Confusion matrix, without normalization')
        plt.show()
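
plot_confusion_matrix is not defined in this example; below is a minimal sketch of a compatible helper, assuming scikit-learn and matplotlib are available (the example's actual implementation may differ).

# Hypothetical sketch of a plot_confusion_matrix helper compatible with the
# calls above; inferred from the call sites, not from the original source.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(y_true, y_pred, classes, normalize=False, title=None):
    cm = confusion_matrix(y_true, y_pred)
    if normalize:
        # Normalize each row so entries are per-class fractions.
        cm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    fig.colorbar(im, ax=ax)
    ax.set(xticks=np.arange(cm.shape[1]), yticks=np.arange(cm.shape[0]),
           xticklabels=classes, yticklabels=classes,
           ylabel='True label', xlabel='Predicted label', title=title)
    fmt = '.2f' if normalize else 'd'
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt), ha='center', va='center')
    return ax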
Example No. 2
def _main(_):    
    '''
    if config.distributed:
        import horovod.tensorflow as hvd
        hvd.init()
        for i in range(14):
            config.train_data["datasets"][i]["num_shards"] = hvd.size()
            config.train_data["datasets"][i]["shard_id"] = hvd.rank()
        config.train_data["batch_size"] //= hvd.size()
    '''

    # Data
    train_data = MultiAlignedNumpyData(config.train_data)
    val_data = MultiAlignedNumpyData(config.val_data)
    test_data = MultiAlignedNumpyData(config.test_data)
    vocab = train_data.vocab(0)
    ctx_maxSeqLen = config.max_sequence_length

    # Each training batch is used twice: once for updating the generator and
    # once for updating the discriminator. A feedable data iterator is used
    # for this purpose.
    iterator = tx.data.FeedableDataIterator(
        {'train_pre': train_data, 'val': val_data, 'test': test_data})
    batch = iterator.get_next()

    # Model
    #global_step = tf.placeholder(tf.int32) #hvd
    lr = tf.placeholder(dtype=tf.float32, shape=[], name='lr')  # hvd

    if config.model_name == 'EvolveGTAE':
        #model = EvolveGTAE(batch, vocab, ctx_maxSeqLen, hvd, global_step, config.model)
        model = EvolveGTAE(batch, vocab, ctx_maxSeqLen, lr, config.model)
    else:
        logger.error('config.model_name: {} is incorrect'.format(config.model_name))
        raise ValueError('config.model_name: {} is incorrect'.format(config.model_name))

    def _train_epoch(sess, lr_, writer, epoch, flag, verbose=True):
        avg_meters_pre = tx.utils.AverageRecorder(size=10)

        step = 0
        while True:
            try:
                step += 1
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train_pre'),
                    lr: lr_
                }
                vals_pre = sess.run(model.fetches_train_pre, feed_dict=feed_dict)

                merged_summary = vals_pre.pop('merged')

                iteration = (epoch - 1) * (config.train_num // config.batch_size) + step
                if iteration % 20 == 0:
                    writer.add_summary(merged_summary, iteration)

                avg_meters_pre.add(vals_pre)
                
                if verbose and (step == 1 or step % config.display == 0):
                    logger.info('step: {}, {}'.format(step, avg_meters_pre.to_str(4)))
                    sys.stdout.flush()

            except tf.errors.OutOfRangeError:
                logger.info('epoch: {}, {}'.format(epoch, avg_meters_pre.to_str(4)))
                sys.stdout.flush()
                break

    def _eval_epoch(sess, lr_, writer, epoch, val_or_test='val'):##1
        avg_meters = tx.utils.AverageRecorder()
        
        step = 0
        while True:
            try:
                step += 1
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, val_or_test),
                    lr: lr_,
                    tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
                }
                vals = sess.run(model.fetches_eval, feed_dict=feed_dict)

                iteration = (epoch - 1) * (config.dev_num // config.batch_size) + step
                if val_or_test == 'val' and iteration % 10 == 0:
                    merged_summary = vals['merged']
                    writer.add_summary(merged_summary, iteration)
                vals.pop('merged')
                

                batch_size = vals.pop('batch_size')

                # Computes BLEU
                samples = tx.utils.dict_pop(vals, list(model.samples.keys()))
                
                x1x2 = tx.utils.map_ids_to_strs(samples['x1x2'], vocab)
                x1xx2 = tx.utils.map_ids_to_strs(samples['x1xx2'], vocab)

                hyps_y1 = tx.utils.map_ids_to_strs(samples['transferred_yy1_pred'], vocab)
                refs_y1 = tx.utils.map_ids_to_strs(samples['transferred_yy1_gt'], vocab) ## text == ['a sentence', 'parsed from ids']
                origin_y1 = tx.utils.map_ids_to_strs(samples['origin_y1'], vocab)
                refs_y1 = np.expand_dims(refs_y1, axis=1) #[32,1]
                bleu = tx.evals.corpus_bleu_moses(refs_y1, hyps_y1) #[32]
                vals['bleu_y1'] = bleu

                hyps_y2 = tx.utils.map_ids_to_strs(samples['transferred_yy2_pred'], vocab)
                refs_y2 = tx.utils.map_ids_to_strs(samples['transferred_yy2_gt'], vocab)
                origin_y2 = tx.utils.map_ids_to_strs(samples['origin_y2'], vocab)
                refs_y2 = np.expand_dims(refs_y2, axis=1)
                bleu = tx.evals.corpus_bleu_moses(refs_y2, hyps_y2)
                vals['bleu_y2'] = bleu

                hyps_y3 = tx.utils.map_ids_to_strs(samples['transferred_yy3_pred'], vocab)
                refs_y3 = tx.utils.map_ids_to_strs(samples['transferred_yy3_gt'], vocab)
                origin_y3 = tx.utils.map_ids_to_strs(samples['origin_y3'], vocab)
                refs_y3 = np.expand_dims(refs_y3, axis=1)
                bleu = tx.evals.corpus_bleu_moses(refs_y3, hyps_y3)
                vals['bleu_y3'] = bleu

                avg_meters.add(vals, weight=batch_size)

                
                # Writes samples
                if val_or_test == 'test':
                    tx.utils.write_paired_text(
                        x1x2, x1xx2,
                        os.path.join(sample_path, 'val_x.%d'%epoch),
                        append=True, mode='v')
                    tx.utils.write_paired_text(
                        refs_y1.squeeze(), hyps_y1,
                        os.path.join(sample_path, 'val_y1.%d'%epoch),
                        append=True, mode='v')
                    tx.utils.write_paired_text(
                        refs_y2.squeeze(), hyps_y2,
                        os.path.join(sample_path, 'val_y2.%d'%epoch),
                        append=True, mode='v')
                    tx.utils.write_paired_text(
                        refs_y3.squeeze(), hyps_y3,
                        os.path.join(sample_path, 'val_y3.%d'%epoch),
                        append=True, mode='v')

                    tx.utils.write_paired_text(
                        refs_y1.squeeze(), origin_y1,
                        os.path.join(sample_path, 'val_yy1gt_y1.%d'%epoch),
                        append=True, mode='v')
                    tx.utils.write_paired_text(
                        refs_y2.squeeze(), origin_y2,
                        os.path.join(sample_path, 'val_yy2gt_y2.%d'%epoch),
                        append=True, mode='v')
                    tx.utils.write_paired_text(
                        refs_y3.squeeze(), origin_y3,
                        os.path.join(sample_path, 'val_yy3gt_y3.%d'%epoch),
                        append=True, mode='v')

            except tf.errors.OutOfRangeError:
                logger.info('{}: {}'.format(
                    val_or_test, avg_meters.to_str(precision=4)))
                break

        return avg_meters.avg()

    '''
    if config.distributed:
        bcast = hvd.broadcast_global_variables(0)
        tf.ConfigProto().gpu_options.visible_device_list = str(hvd.local_rank())
    '''
    # Runs the logic
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        # visualization: graph
        graph_writer = tf.summary.FileWriter(vis_graph_path, sess.graph)

        # visualization: training summaries
        train_writer = tf.summary.FileWriter(vis_train_path, tf.Graph())
        # visualization: validation summaries
        val_writer = tf.summary.FileWriter(vis_val_path, tf.Graph())

        
        saver = tf.train.Saver(max_to_keep=None)

        if config.restore:
            logger.info('Restore from: {}'.format(config.restore))
            saver.restore(sess, config.restore)

        iterator.initialize_dataset(sess)
        lr_ = config.initial_lr
    
        for epoch in range(0, config.max_nepochs + 1):                               ###modify
            flag = True
            '''
            if epoch<=3:
                lr_=config.initial_lr*(epoch+1)
            if epoch>=10 and epoch %2==0:
                lr_*=0.25
            '''

            logger.info('learning rate: {}'.format(lr_))
            # Train
            iterator.restart_dataset(sess, ['train_pre'])
            
            _train_epoch(sess, lr_, train_writer, epoch, flag)

            # Val
            iterator.restart_dataset(sess, 'val')
            _eval_epoch(sess, lr_, val_writer, epoch, 'val')  ##1

            if epoch % 3 == 0:
                saver.save(
                    sess, os.path.join(checkpoint_path, 'ckpt'), epoch)                 ###modify

            # Test
            iterator.restart_dataset(sess, 'test')
            _eval_epoch(sess, lr_, val_writer, epoch, 'test')

        graph_writer.close()
        train_writer.close()
        val_writer.close()
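
The learning-rate schedule in this example is effectively disabled: lr_ stays at config.initial_lr, and the warmup/decay logic is commented out inside the epoch loop. Below is a standalone sketch of that commented-out schedule, with the same names assumed; the exact constants are whatever the original config used.

# Sketch of the commented-out learning-rate schedule from the epoch loop above:
# linear warmup over the first epochs, then decay by 0.25 every second epoch
# from epoch 10 onward.
def scheduled_lr(epoch, lr_, initial_lr):
    if epoch <= 3:
        lr_ = initial_lr * (epoch + 1)
    if epoch >= 10 and epoch % 2 == 0:
        lr_ *= 0.25
    return lr_

# Usage inside the loop would be, e.g.:
#   lr_ = scheduled_lr(epoch, lr_, config.initial_lr)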
Example No. 3
def _main(_):
    # Create output_path
    if os.path.exists(output_path):
        logger.error('output path {} already exists'.format(output_path))
        raise ValueError('output path {} already exists'.format(output_path))
    os.mkdir(output_path)
    os.mkdir('{}/src'.format(output_path))
    os.system('cp *.py {}/src'.format(output_path))
    os.system('cp models/*.py {}/src'.format(output_path))
    os.system('cp utils_data/*.py {}/src'.format(output_path))

    # clean sample_path and checkpoint_path before training 
    if tf.gfile.Exists(sample_path):
        tf.gfile.DeleteRecursively(sample_path)
    if tf.gfile.Exists(checkpoint_path):
        tf.gfile.DeleteRecursively(checkpoint_path)
    tf.gfile.MakeDirs(sample_path)
    tf.gfile.MakeDirs(checkpoint_path)
    
    # Data
    train_data = MultiAlignedNumpyData(config.train_data)
    val_data = MultiAlignedNumpyData(config.val_data)
    test_data = MultiAlignedNumpyData(config.test_data)
    vocab = train_data.vocab(0)

    # Each training batch is used twice: once for updating the generator and
    # once for updating the discriminator. A feedable data iterator is used
    # for this purpose.
    iterator = tx.data.FeedableDataIterator(
        {'train_g': train_data, 'train_d': train_data,
         'val': val_data, 'test': test_data})
    batch = iterator.get_next()

    # Model
    gamma = tf.placeholder(dtype=tf.float32, shape=[], name='gamma')
    lambda_t_graph = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_t_graph')
    lambda_t_sentence = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_t_sentence')
    
    if config.model_name == 'GTAE':
        model = GTAE(batch, vocab, gamma, lambda_t_graph, lambda_t_sentence, ablation, config.model)
    else:
        logger.error('config.model_name: {} is incorrect'.format(config.model_name))
        raise ValueError('config.model_name: {} is incorrect'.format(config.model_name))

    def _train_epoch(sess, gamma_, lambda_t_graph_, lambda_t_sentence_, epoch, verbose=True):
        avg_meters_d = tx.utils.AverageRecorder(size=10)
        avg_meters_g = tx.utils.AverageRecorder(size=10)

        step = 0
        while True:
            try:
                step += 1
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train_d'),
                    gamma: gamma_,
                    lambda_t_graph: lambda_t_graph_,
                    lambda_t_sentence: lambda_t_sentence_
                }

                vals_d = sess.run(model.fetches_train_d, feed_dict=feed_dict)
                avg_meters_d.add(vals_d)

                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train_g'),
                    gamma: gamma_,
                    lambda_t_graph: lambda_t_graph_,
                    lambda_t_sentence: lambda_t_sentence_
                }
                vals_g = sess.run(model.fetches_train_g, feed_dict=feed_dict)
                avg_meters_g.add(vals_g)

                if verbose and (step == 1 or step % config.display == 0):
                    logger.info('step: {}, {}'.format(step, avg_meters_d.to_str(4)))
                    logger.info('step: {}, {}'.format(step, avg_meters_g.to_str(4)))
                    sys.stdout.flush()

                if verbose and step % config.display_eval == 0:
                    iterator.restart_dataset(sess, 'val')
                    _eval_epoch(sess, gamma_, lambda_t_graph_, lambda_t_sentence_, epoch)

            except tf.errors.OutOfRangeError:
                logger.info('epoch: {}, {}'.format(epoch, avg_meters_d.to_str(4)))
                logger.info('epoch: {}, {}'.format(epoch, avg_meters_g.to_str(4)))
                sys.stdout.flush()
                break

    def _eval_epoch(sess, gamma_, lambda_t_graph_, lambda_t_sentence_, epoch, val_or_test='val'):
        avg_meters = tx.utils.AverageRecorder()

        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, val_or_test),
                    gamma: gamma_,
                    lambda_t_graph: lambda_t_graph_,
                    lambda_t_sentence: lambda_t_sentence_,
                    tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
                }

                vals = sess.run(model.fetches_eval, feed_dict=feed_dict)

                batch_size = vals.pop('batch_size')

                # Computes BLEU
                samples = tx.utils.dict_pop(vals, list(model.samples.keys()))
                hyps = tx.utils.map_ids_to_strs(samples['transferred'], vocab)

                refs = tx.utils.map_ids_to_strs(samples['original'], vocab)
                refs = np.expand_dims(refs, axis=1)

                bleu = tx.evals.corpus_bleu_moses(refs, hyps)
                vals['bleu'] = bleu

                avg_meters.add(vals, weight=batch_size)

                # Writes samples
                tx.utils.write_paired_text(
                    refs.squeeze(), hyps,
                    os.path.join(sample_path, 'val.%d'%epoch),
                    append=True, mode='v')

            except tf.errors.OutOfRangeError:
                logger.info('{}: {}'.format(
                    val_or_test, avg_meters.to_str(precision=4)))
                break

        return avg_meters.avg()

    # Runs the logic
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        saver = tf.train.Saver(max_to_keep=None)
        if config.restore:
            logger.info('Restore from: {}'.format(config.restore))
            saver.restore(sess, config.restore)

        iterator.initialize_dataset(sess)

        gamma_ = 1.
        lambda_t_graph_ = 0.
        lambda_t_sentence_ = 0.
        for epoch in range(1, max_nepochs + 1):
            if epoch > FLAGS.pretrain_nepochs:
                # Anneals the gumbel-softmax temperature
                gamma_ = max(0.001, gamma_ * config.gamma_decay)
                lambda_t_graph_ = FLAGS.lambda_t_graph
                lambda_t_sentence_ = FLAGS.lambda_t_sentence
            logger.info('gamma: {}, lambda_t_graph: {}, lambda_t_sentence: {}'.format(gamma_, lambda_t_graph_, lambda_t_sentence_))

            # Train
            iterator.restart_dataset(sess, ['train_g', 'train_d'])
            _train_epoch(sess, gamma_, lambda_t_graph_, lambda_t_sentence_, epoch)

            # Val
            iterator.restart_dataset(sess, 'val')
            _eval_epoch(sess, gamma_, lambda_t_graph_, lambda_t_sentence_, epoch, 'val')

            if epoch > FLAGS.pretrain_nepochs:
                saver.save(sess, os.path.join(checkpoint_path, 'ckpt'), epoch)

            # Test
            iterator.restart_dataset(sess, 'test')
            _eval_epoch(sess, gamma_, lambda_t_graph_, lambda_t_sentence_, epoch, 'test')
    
    logger.info('tensorflow training process finished successfully!')
    if not os.path.exists('{}.log'.format(output_path)):
        logger.error('cannot find {}.log'.format(output_path))
    else:
        os.system('mv {}.log {}/'.format(output_path, output_path))
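
The temperature annealing above (gamma_ = max(0.001, gamma_ * config.gamma_decay) once pretraining ends) is a geometric decay with a floor. A minimal closed-form sketch, assuming gamma starts at 1.0 and config.gamma_decay < 1:

# Closed-form view of the gumbel-softmax temperature schedule used above;
# equivalent to multiplying by gamma_decay once per post-pretraining epoch.
def annealed_gamma(epoch, pretrain_nepochs, gamma_decay, floor=0.001):
    steps = max(0, epoch - pretrain_nepochs)   # no decay during pretraining
    return max(floor, 1.0 * gamma_decay ** steps)

# e.g. with gamma_decay=0.5 and pretrain_nepochs=10:
#   epoch 10 -> 1.0, epoch 11 -> 0.5, epoch 12 -> 0.25, ...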
Example No. 4
def _main(_):    
    '''
    if config.distributed:
        import horovod.tensorflow as hvd
        hvd.init()
        for i in range(14):
            config.train_data["datasets"][i]["num_shards"] = hvd.size()
            config.train_data["datasets"][i]["shard_id"] = hvd.rank()
        config.train_data["batch_size"] //= hvd.size()
    '''

    # Data
    train_data = MultiAlignedNumpyData(config.train_data)
    val_data = MultiAlignedNumpyData(config.val_data)
    test_data = MultiAlignedNumpyData(config.test_data)
    vocab = train_data.vocab(0)
    ctx_maxSeqLen = config.max_sequence_length

    
    # Each training batch is used twice: once for updating the generator and
    # once for updating the discriminator. A feedable data iterator is used
    # for this purpose.
    iterator = tx.data.FeedableDataIterator(
        {'train_pre': train_data, 'val': val_data, 'test': test_data})
    batch = iterator.get_next()

    node_add_1_sum = 0
    node_add_2_sum = 0
    node_add_3_sum = 0
    node_dec_1_sum = 0
    node_dec_2_sum = 0
    node_dec_3_sum = 0

    edge_add_1_sum = 0
    edge_add_2_sum = 0
    edge_add_3_sum = 0
    edge_dec_1_sum = 0
    edge_dec_2_sum = 0
    edge_dec_3_sum = 0

    edge_C_1_sum = 0
    edge_C_2_sum = 0
    edge_C_3_sum = 0

    samples_num = 0
    
    # Runs the logic
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        iterator.initialize_dataset(sess)

        for mode in ['test']:#, 'val', 'test']:
            # iterate over the selected data split
            iterator.restart_dataset(sess, [mode])
            while True:
                try:
                    
                    feed_dict = {
                        iterator.handle: iterator.get_handle(sess, mode),
                    }
                    fetches_data = {
                        'batch': batch,
                    }
                    inputs = sess.run(fetches_data, feed_dict)['batch']
                    
                    
                    text_ids_y1 = inputs['y1_yy1_text_ids'][:,0,:]   # [batch, maxlen1]
                    text_ids_yy1 = inputs['y1_yy1_text_ids'][:,1,:]  # [batch, maxlen1]
                    text_ids_y2 = inputs['y2_yy2_text_ids'][:,0,:]   # [batch, maxlen2]
                    text_ids_yy2 = inputs['y2_yy2_text_ids'][:,1,:]  # [batch, maxlen2]
                    text_ids_y3 = inputs['y3_yy3_text_ids'][:,0,:]   # [batch, maxlen3]
                    text_ids_yy3 = inputs['y3_yy3_text_ids'][:,1,:]  # [batch, maxlen3]

                    sequence_length_y1 = inputs['y1_yy1_length'][:,0]
                    sequence_length_y2 = inputs['y2_yy2_length'][:,0]
                    sequence_length_y3 = inputs['y3_yy3_length'][:,0]
                    sequence_length_yy1 = inputs['y1_yy1_length'][:,1]
                    sequence_length_yy2 = inputs['y2_yy2_length'][:,1]
                    sequence_length_yy3 = inputs['y3_yy3_length'][:,1]

                    # Unused below; use np.shape on the numpy batch rather than
                    # adding tf.shape ops to the graph on every iteration.
                    enc_shape_y1 = np.shape(text_ids_y1)[1]
                    enc_shape_y2 = np.shape(text_ids_y2)[1]
                    enc_shape_y3 = np.shape(text_ids_y3)[1]

                    adjs_y1_undirt = np.int32(np.reshape(inputs['y1_und_adjs'], [-1,ctx_maxSeqLen+2,ctx_maxSeqLen+2]))
                    adjs_y2_undirt = np.int32(np.reshape(inputs['y2_und_adjs'], [-1,ctx_maxSeqLen+2,ctx_maxSeqLen+2]))
                    adjs_y3_undirt = np.int32(np.reshape(inputs['y3_und_adjs'], [-1,ctx_maxSeqLen+2,ctx_maxSeqLen+2]))
                    adjs_yy1_undirt = np.int32(np.reshape(inputs['yy1_und_adjs'], [-1,ctx_maxSeqLen+2,ctx_maxSeqLen+2]))
                    adjs_yy2_undirt = np.int32(np.reshape(inputs['yy2_und_adjs'], [-1,ctx_maxSeqLen+2,ctx_maxSeqLen+2]))
                    adjs_yy3_undirt = np.int32(np.reshape(inputs['yy3_und_adjs'], [-1,ctx_maxSeqLen+2,ctx_maxSeqLen+2]))
                    #print(np.shape(adjs_yy3_undirt))

                    node_add_1 = np.sum(np.where((sequence_length_yy1 - sequence_length_y1)>0, (sequence_length_yy1 - sequence_length_y1), 0))
                    node_dec_1 = np.sum(np.where((sequence_length_y1 - sequence_length_yy1)>0, (sequence_length_y1 - sequence_length_yy1), 0))
                    node_add_2 = np.sum(np.where((sequence_length_yy2 - sequence_length_y2)>0, (sequence_length_yy2 - sequence_length_y2), 0))
                    node_dec_2 = np.sum(np.where((sequence_length_y2 - sequence_length_yy2)>0, (sequence_length_y2 - sequence_length_yy2), 0))
                    node_add_3 = np.sum(np.where((sequence_length_yy3 - sequence_length_y3)>0, (sequence_length_yy3 - sequence_length_y3), 0))
                    node_dec_3 = np.sum(np.where((sequence_length_y3 - sequence_length_yy3)>0, (sequence_length_y3 - sequence_length_yy3), 0))

                    
                    node_add_1_sum += node_add_1
                    node_add_2_sum += node_add_2
                    node_add_3_sum += node_add_3
                    node_dec_1_sum += node_dec_1
                    node_dec_2_sum += node_dec_2
                    node_dec_3_sum += node_dec_3

                    edge_add_1 = np.sum(cal_edge_2(adjs_y1_undirt*(-1) + adjs_yy1_undirt)[0])
                    edge_add_2 = np.sum(cal_edge_2(adjs_y2_undirt*(-1) + adjs_yy2_undirt)[0])
                    edge_add_3 = np.sum(cal_edge_2(adjs_y3_undirt*(-1) + adjs_yy3_undirt)[0])
                    edge_dec_1 = np.sum(cal_edge_2(adjs_y1_undirt*(-1) + adjs_yy1_undirt)[1])
                    edge_dec_2 = np.sum(cal_edge_2(adjs_y2_undirt*(-1) + adjs_yy2_undirt)[1])
                    edge_dec_3 = np.sum(cal_edge_2(adjs_y3_undirt*(-1) + adjs_yy3_undirt)[1])
                    edge_C_1 = np.sum(cal_edge(np.abs(adjs_y1_undirt*(-1) + adjs_yy1_undirt)))
                    edge_C_2 = np.sum(cal_edge(np.abs(adjs_y2_undirt*(-1) + adjs_yy2_undirt)))
                    edge_C_3 = np.sum(cal_edge(np.abs(adjs_y3_undirt*(-1) + adjs_yy3_undirt)))

                    edge_C_1_sum += edge_C_1
                    edge_C_2_sum += edge_C_2
                    edge_C_3_sum += edge_C_3
                    
                    edge_add_1_sum += edge_add_1
                    edge_add_2_sum += edge_add_2
                    edge_add_3_sum += edge_add_3

                    edge_dec_1_sum += edge_dec_1
                    edge_dec_2_sum += edge_dec_2
                    edge_dec_3_sum += edge_dec_3

                    samples_num += np.shape(adjs_y1_undirt)[0]  # batch size

                    print("samples_num: ",samples_num)#, tx.utils.map_ids_to_strs(text_ids_y1, vocab), '---------', tx.utils.map_ids_to_strs(text_ids_yy1, vocab))
                    #print((-1)*adjs_y1_undirt,'\n')
                    #print(adjs_yy1_undirt, '\n')
                    #print((-1)*adjs_y1_undirt+adjs_yy1_undirt, '\n')
                    print("edge_num_y1: ", cal_edge(adjs_y1_undirt), "edge_num_yy1: ", cal_edge(adjs_yy1_undirt))
                    
                    print("node_add_1: ", node_add_1_sum/samples_num, "node_dec_1: ", node_dec_1_sum/samples_num)
                    print("edge_add_1: ", edge_add_1_sum/samples_num, "edge_dec_1: ", edge_dec_1_sum/samples_num)
                    print("edge_C_1_sum: ", edge_C_1_sum/samples_num)
                    print('\n')
                except tf.errors.OutOfRangeError:
                    break
    
    avg_node_add_1 = node_add_1_sum / float(samples_num)
    avg_node_add_2 = node_add_2_sum / float(samples_num)
    avg_node_add_3 = node_add_3_sum / float(samples_num)
    
    avg_node_dec_1 = node_dec_1_sum / float(samples_num)
    avg_node_dec_2 = node_dec_2_sum / float(samples_num)
    avg_node_dec_3 = node_dec_3_sum / float(samples_num)


    avg_edge_add_1 = edge_add_1_sum / float(samples_num)
    avg_edge_add_2 = edge_add_2_sum / float(samples_num)
    avg_edge_add_3 = edge_add_3_sum / float(samples_num)
    
    avg_edge_dec_1 = edge_dec_1_sum / float(samples_num)
    avg_edge_dec_2 = edge_dec_2_sum / float(samples_num)
    avg_edge_dec_3 = edge_dec_3_sum / float(samples_num)

    avg_edge_C_1 = edge_C_1_sum / float(samples_num)
    avg_edge_C_2 = edge_C_2_sum / float(samples_num)
    avg_edge_C_3 = edge_C_3_sum / float(samples_num)
    
    print("avg_node_add_1: ", avg_node_add_1, "avg_node_dec_1: ", avg_node_dec_1, "avg_edge_add_1: ", avg_edge_add_1, "avg_edge_dec_1: ", avg_edge_dec_1)
    print("avg_node_add_2: ", avg_node_add_2, "avg_node_dec_2: ", avg_node_dec_2, "avg_edge_add_2: ", avg_edge_add_2, "avg_edge_dec_2: ", avg_edge_dec_2)
    print("avg_node_add_3: ", avg_node_add_3, "avg_node_dec_3: ", avg_node_dec_3, "avg_edge_add_3: ", avg_edge_add_3, "avg_edge_dec_3: ", avg_edge_dec_3)
    print("avg_edge_C_1: ", avg_edge_C_1, "avg_edge_C_2: ", avg_edge_C_2, "avg_edge_C_3: ", avg_edge_C_3)
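
cal_edge and cal_edge_2 are not defined in this example. Below is a hypothetical sketch inferred from how they are called, where cal_edge counts edges per undirected adjacency matrix and cal_edge_2 splits an adjacency difference into added and removed edge counts; the actual implementations may differ.

# Hypothetical helpers inferred from the calls above; halving accounts for the
# undirected adjacency matrices counting each edge twice.
import numpy as np

def cal_edge(adjs):
    # adjs: [batch, n, n] undirected adjacency matrices -> edges per sample.
    return np.sum(adjs != 0, axis=(1, 2)) // 2

def cal_edge_2(diff):
    # diff = adjs_target - adjs_source with entries in {-1, 0, +1}.
    added = np.sum(diff > 0, axis=(1, 2)) // 2    # edges present only in target
    removed = np.sum(diff < 0, axis=(1, 2)) // 2  # edges present only in source
    return added, removed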