Example #1
    def testCompatibleNames(self):
        with self.test_session(use_gpu=self._use_gpu, graph=ops.Graph()):
            cell = core_rnn_cell_impl.LSTMCell(10)
            pcell = core_rnn_cell_impl.LSTMCell(10, use_peepholes=True)
            inputs = [array_ops.zeros([4, 5])] * 6
            core_rnn.static_rnn(cell,
                                inputs,
                                dtype=dtypes.float32,
                                scope="basic")
            core_rnn.static_rnn(pcell,
                                inputs,
                                dtype=dtypes.float32,
                                scope="peephole")
            basic_names = {
                v.name: v.get_shape()
                for v in variables.trainable_variables()
            }

        with self.test_session(use_gpu=self._use_gpu, graph=ops.Graph()):
            cell = lstm_ops.LSTMBlockCell(10)
            pcell = lstm_ops.LSTMBlockCell(10, use_peephole=True)
            inputs = [array_ops.zeros([4, 5])] * 6
            core_rnn.static_rnn(cell,
                                inputs,
                                dtype=dtypes.float32,
                                scope="basic")
            core_rnn.static_rnn(pcell,
                                inputs,
                                dtype=dtypes.float32,
                                scope="peephole")
            block_names = {
                v.name: v.get_shape()
                for v in variables.trainable_variables()
            }

        with self.test_session(use_gpu=self._use_gpu, graph=ops.Graph()):
            cell = lstm_ops.LSTMBlockFusedCell(10)
            pcell = lstm_ops.LSTMBlockFusedCell(10, use_peephole=True)
            inputs = [array_ops.zeros([4, 5])] * 6
            cell(inputs, dtype=dtypes.float32, scope="basic/lstm_cell")
            pcell(inputs, dtype=dtypes.float32, scope="peephole/lstm_cell")
            fused_names = {
                v.name: v.get_shape()
                for v in variables.trainable_variables()
            }

        self.assertEqual(basic_names, block_names)
        self.assertEqual(basic_names, fused_names)
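Example #1 checks that LSTMBlockCell and LSTMBlockFusedCell create variables with the same names and shapes as the canonical LSTMCell, so checkpoints written with one can be restored with the other. A minimal sketch of the drop-in swap (assuming TensorFlow 1.x with tf.contrib; shapes here are illustrative):

import tensorflow as tf

# LSTMBlockCell is a fused-kernel drop-in for tf.nn.rnn_cell.LSTMCell:
# only the constructor changes; the created variables keep the same names.
inputs = tf.placeholder(tf.float32, [None, 20, 5])  # batch x time x features
cell = tf.contrib.rnn.LSTMBlockCell(10)             # was: tf.nn.rnn_cell.LSTMCell(10)
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

print([v.name for v in tf.trainable_variables()])
# ['rnn/lstm_cell/kernel:0', 'rnn/lstm_cell/bias:0']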
Example #2
    def testLSTMBlockCell(self):
        with self.session(use_gpu=True, graph=ops.Graph()) as sess:
            with variable_scope.variable_scope(
                    "root", initializer=init_ops.constant_initializer(0.5)):
                x = array_ops.zeros([1, 2])
                m0 = array_ops.zeros([1, 2])
                m1 = array_ops.zeros([1, 2])
                m2 = array_ops.zeros([1, 2])
                m3 = array_ops.zeros([1, 2])
                g, ((out_m0, out_m1),
                    (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
                        [lstm_ops.LSTMBlockCell(2) for _ in range(2)],
                        state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
                sess.run([variables.global_variables_initializer()])
                res = sess.run(
                    [g, out_m0, out_m1, out_m2, out_m3], {
                        x.name: np.array([[1., 1.]]),
                        m0.name: 0.1 * np.ones([1, 2]),
                        m1.name: 0.1 * np.ones([1, 2]),
                        m2.name: 0.1 * np.ones([1, 2]),
                        m3.name: 0.1 * np.ones([1, 2])
                    })
                self.assertEqual(len(res), 5)
                self.assertAllClose(res[0], [[0.24024698, 0.24024698]])
                # These numbers are from testBasicLSTMCell and only test c/h.
                self.assertAllClose(res[1], [[0.68967271, 0.68967271]])
                self.assertAllClose(res[2], [[0.44848421, 0.44848421]])
                self.assertAllClose(res[3], [[0.39897051, 0.39897051]])
                self.assertAllClose(res[4], [[0.24024698, 0.24024698]])
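Example #2 drives a two-layer MultiRNNCell of LSTMBlockCells one step and checks the outputs against reference numbers from testBasicLSTMCell. The state plumbing, sketched minimally (TF 1.x):

import tensorflow as tf

# A 2-layer MultiRNNCell of LSTMBlockCells takes and returns a tuple of
# per-layer LSTMStateTuple(c, h) states when state_is_tuple=True.
cells = [tf.contrib.rnn.LSTMBlockCell(2) for _ in range(2)]
multi = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)
x = tf.zeros([1, 2])
state = multi.zero_state(batch_size=1, dtype=tf.float32)
output, new_state = multi(x, state)  # new_state: ((c0, h0), (c1, h1))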
Example #3
  def benchmarkTfRNNLSTMBlockCellTraining(self):
    test_configs = self._GetTestConfig()
    for config_name, config in test_configs.items():
      num_layers = config["num_layers"]
      num_units = config["num_units"]
      batch_size = config["batch_size"]
      seq_length = config["seq_length"]

      with ops.Graph().as_default(), ops.device("/gpu:0"):
        inputs = seq_length * [
            array_ops.zeros([batch_size, num_units], dtypes.float32)
        ]
        cell = lambda: lstm_ops.LSTMBlockCell(num_units=num_units)  # pylint: disable=cell-var-from-loop

        multi_cell = core_rnn_cell_impl.MultiRNNCell(
            [cell() for _ in range(num_layers)])
        outputs, final_state = core_rnn.static_rnn(
            multi_cell, inputs, dtype=dtypes.float32)
        trainable_variables = ops.get_collection(
            ops.GraphKeys.TRAINABLE_VARIABLES)
        gradients = gradients_impl.gradients([outputs, final_state],
                                             trainable_variables)
        training_op = control_flow_ops.group(*gradients)
        self._BenchmarkOp(training_op, "tf_rnn_lstm_block_cell %s %s" %
                          (config_name, self._GetConfigDesc(config)))
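The training benchmark above never applies an optimizer: it groups the raw gradient tensors into a single op so that one Session.run forces a complete forward and backward pass. A stripped-down sketch of the same trick:

import tensorflow as tf

x = tf.Variable(tf.random_normal([4, 4]))
y = tf.reduce_sum(tf.matmul(x, x))
grads = tf.gradients(y, [x])
bench_op = tf.group(*grads)  # running this computes gradients, updates nothing

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(bench_op)       # time this call in a loop to benchmark bprop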
Example #4
    def _create_equivalent_canonical_rnn(self,
                                         cudnn_model,
                                         inputs,
                                         use_block_cell,
                                         scope="rnn"):
        if cudnn_model.rnn_mode != "lstm":
            raise ValueError("%s is not supported!" % cudnn_model.rnn_mode)

        num_units = cudnn_model.num_units
        num_layers = cudnn_model.num_layers

        # To reuse cuDNN-trained models, forget_bias and clip_cell must be
        # set to 0 and False. In LSTMCell and LSTMBlockCell, forget_bias is
        # added on top of the learned bias, whereas cuDNN does not apply the
        # additional bias.
        if use_block_cell:
            # pylint: disable=g-long-lambda
            single_cell = lambda: lstm_ops.LSTMBlockCell(
                num_units, forget_bias=0, clip_cell=False)
            # pylint: enable=g-long-lambda
        else:
            single_cell = lambda: rnn_cell_impl.LSTMCell(num_units,
                                                         forget_bias=0)
        cell = rnn_cell_impl.MultiRNNCell(
            [single_cell() for _ in range(num_layers)])
        return rnn.dynamic_rnn(cell,
                               inputs,
                               dtype=dtypes.float32,
                               time_major=True,
                               scope=scope)
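The comment in Example #4 is the key compatibility detail: both LSTMCell and LSTMBlockCell add forget_bias on top of the learned bias at every step, while cuDNN folds any forget bias into the trained bias vector itself. A sketch of the setting, assuming TF 1.x:

import tensorflow as tf

# To evaluate cuDNN-trained weights with LSTMBlockCell, disable the extra
# forget-gate bias (it defaults to 1.0) so the cell math matches cuDNN's.
cell = tf.contrib.rnn.LSTMBlockCell(128, forget_bias=0.0)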
Example #5
    def testLSTMBasicToBlockCell(self):
        with self.session(use_gpu=True) as sess:
            x = array_ops.zeros([1, 2])
            x_values = np.random.randn(1, 2)

            m0_val = 0.1 * np.ones([1, 2])
            m1_val = -0.1 * np.ones([1, 2])
            m2_val = -0.2 * np.ones([1, 2])
            m3_val = 0.2 * np.ones([1, 2])

            initializer = init_ops.random_uniform_initializer(-0.01,
                                                              0.01,
                                                              seed=19890212)
            with variable_scope.variable_scope("basic",
                                               initializer=initializer):
                m0 = array_ops.zeros([1, 2])
                m1 = array_ops.zeros([1, 2])
                m2 = array_ops.zeros([1, 2])
                m3 = array_ops.zeros([1, 2])
                g, ((out_m0, out_m1),
                    (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
                        [
                            rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
                            for _ in range(2)
                        ],
                        state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
                sess.run([variables.global_variables_initializer()])
                basic_res = sess.run(
                    [g, out_m0, out_m1, out_m2, out_m3], {
                        x.name: x_values,
                        m0.name: m0_val,
                        m1.name: m1_val,
                        m2.name: m2_val,
                        m3.name: m3_val
                    })

            with variable_scope.variable_scope("block",
                                               initializer=initializer):
                m0 = array_ops.zeros([1, 2])
                m1 = array_ops.zeros([1, 2])
                m2 = array_ops.zeros([1, 2])
                m3 = array_ops.zeros([1, 2])
                g, ((out_m0, out_m1),
                    (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
                        [lstm_ops.LSTMBlockCell(2) for _ in range(2)],
                        state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
                sess.run([variables.global_variables_initializer()])
                block_res = sess.run(
                    [g, out_m0, out_m1, out_m2, out_m3], {
                        x.name: x_values,
                        m0.name: m0_val,
                        m1.name: m1_val,
                        m2.name: m2_val,
                        m3.name: m3_val
                    })

            self.assertEqual(len(basic_res), len(block_res))
            for basic, block in zip(basic_res, block_res):
                self.assertAllClose(basic, block)
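The equivalence check in Example #5 hinges on both variable scopes sharing one seeded initializer, so the two cell implementations start from identical weights. A small sketch of that property (assuming TF 1.x op-level seed semantics):

import tensorflow as tf

init = tf.random_uniform_initializer(-0.01, 0.01, seed=19890212)
with tf.variable_scope("basic", initializer=init):
    w_basic = tf.get_variable("w", [4, 8])
with tf.variable_scope("block", initializer=init):
    w_block = tf.get_variable("w", [4, 8])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    a, b = sess.run([w_basic, w_block])
    assert (a == b).all()  # same op-level seed -> same initial values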
Example #6
  def benchmarkLSTMBlockCellBpropWithDynamicRNN(self):
    print("BlockLSTMCell backward propagation via dynamic_rnn().")
    print("--------------------------------------------------------------")
    print("LSTMBlockCell Seconds per inference.")
    print("batch_size,cell_size,input_size,time_steps,use_gpu,wall_time")
    iters = 10
    for config in benchmarking.dict_product({
        "batch_size": [1, 8, 13, 32, 67, 128],
        "cell_size": [128, 250, 512, 650, 1024, 1350],
        "time_steps": [40],
        "use_gpu": [True, False]
    }):
      with ops.Graph().as_default():
        with benchmarking.device(use_gpu=config["use_gpu"]):
          time_steps = config["time_steps"]
          batch_size = config["batch_size"]
          cell_size = input_size = config["cell_size"]
          inputs = variable_scope.get_variable(
              "x", [time_steps, batch_size, cell_size],
              trainable=False,
              dtype=dtypes.float32)
          with variable_scope.variable_scope(
              "rnn", reuse=variable_scope.AUTO_REUSE):
            w = variable_scope.get_variable(
                "rnn/lstm_cell/kernel",
                shape=[input_size + cell_size, cell_size * 4],
                dtype=dtypes.float32)
            b = variable_scope.get_variable(
                "rnn/lstm_cell/bias",
                shape=[cell_size * 4],
                dtype=dtypes.float32,
                initializer=init_ops.zeros_initializer())
            cell = lstm_ops.LSTMBlockCell(cell_size)
            outputs = rnn.dynamic_rnn(
                cell, inputs, time_major=True, dtype=dtypes.float32)
          grads = gradients_impl.gradients(outputs, [inputs, w, b])
          init_op = variables.global_variables_initializer()

        with session.Session() as sess:
          sess.run(init_op)
          wall_time = benchmarking.seconds_per_run(grads, sess, iters)

        # Print to stdout. If the TEST_REPORT_FILE_PREFIX environment variable
        # is set, this will produce a copy-paste-able CSV file.
        print(",".join(
            map(str, [
                batch_size, cell_size, cell_size, time_steps, config["use_gpu"],
                wall_time
            ])))
        benchmark_name_template = "_".join([
            "LSTMBlockCell_bprop", "BS%(batch_size)i", "CS%(cell_size)i",
            "IS%(cell_size)i", "TS%(time_steps)i", "gpu_%(use_gpu)s"
        ])

        self.report_benchmark(
            name=benchmark_name_template % config,
            iters=iters,
            wall_time=wall_time,
            extras=config)
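Example #6 pre-creates rnn/lstm_cell/kernel and rnn/lstm_cell/bias under AUTO_REUSE so it holds direct handles to the variables dynamic_rnn will build, letting it take tf.gradients with respect to them. A minimal sketch of the reuse mechanism:

import tensorflow as tf

with tf.variable_scope("rnn", reuse=tf.AUTO_REUSE):
    w = tf.get_variable("w", [3, 3])
with tf.variable_scope("rnn", reuse=tf.AUTO_REUSE):
    w_again = tf.get_variable("w", [3, 3])  # reused, not re-created

assert w is w_again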
Example #7
    def benchmarkLSTMBlockCellFpropWithDynamicRNN(self):
        print("BlockLSTMCell forward propagation via dynamic_rnn().")
        print("--------------------------------------------------------------")
        print("LSTMBlockCell Seconds per inference.")
        print("batch_size,cell_size,input_size,time_steps,use_gpu,wall_time")
        iters = 10
        for config in benchmarking.dict_product({
                "batch_size": [1, 8, 13, 32, 67, 128],
                "cell_size": [128, 250, 512, 650, 1024, 1350],
                "time_steps": [40],
                "use_gpu": [True, False],
                "dtype": ["float32", "float16"],
        }):
            dtype = (dtypes.float32
                     if config["dtype"] == "float32" else dtypes.float16)
            with ops.Graph().as_default():
                with benchmarking.device(use_gpu=config["use_gpu"]):
                    inputs = variable_scope.get_variable(
                        "x",
                        dtype=dtype,
                        shape=[
                            config["time_steps"], config["batch_size"],
                            config["cell_size"]
                        ])
                    cell = lstm_ops.LSTMBlockCell(config["cell_size"],
                                                  dtype=dtype)
                    outputs = rnn.dynamic_rnn(cell,
                                              inputs,
                                              time_major=True,
                                              dtype=dtype)
                    init_op = variables.global_variables_initializer()

                with session.Session() as sess:
                    sess.run(init_op)
                    wall_time = benchmarking.seconds_per_run(
                        outputs, sess, iters)

                # Print to stdout. If the TEST_REPORT_FILE_PREFIX environment variable
                # is set, this will produce a copy-paste-able CSV file.
                print(",".join(
                    map(str, [
                        config["dtype"], config["batch_size"],
                        config["cell_size"], config["cell_size"],
                        config["time_steps"], config["use_gpu"], wall_time
                    ])))
                benchmark_name_template = "_".join([
                    "LSTMBlockCell_fprop", "DT_%(dtype)s", "BS%(batch_size)i",
                    "CS%(cell_size)i", "IS%(cell_size)i", "TS%(time_steps)i",
                    "gpu_%(use_gpu)s"
                ])

                self.report_benchmark(name=benchmark_name_template % config,
                                      iters=iters,
                                      wall_time=wall_time,
                                      extras=config)
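benchmarking.dict_product is a helper from the surrounding test module; it presumably yields one config dict per element of the cartesian product of the value lists. An equivalent sketch with itertools (this dict_product is illustrative, not the library's implementation):

import itertools

def dict_product(d):
    """Yield one dict per element of the cartesian product of d's values."""
    keys = list(d)
    for values in itertools.product(*(d[k] for k in keys)):
        yield dict(zip(keys, values))

for cfg in dict_product({"batch_size": [1, 8], "use_gpu": [True, False]}):
    print(cfg)  # e.g. {'batch_size': 1, 'use_gpu': True}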
Example #8
  def testNoneDimsWithDynamicRNN(self):
    with self.test_session(use_gpu=self._use_gpu, graph=ops.Graph()) as sess:
      batch_size = 4
      num_steps = 5
      input_dim = 6
      cell_size = 7

      cell = lstm_ops.LSTMBlockCell(cell_size)
      x = array_ops.placeholder(dtypes.float32, shape=(None, None, input_dim))

      output, _ = rnn.dynamic_rnn(
          cell, x, time_major=True, dtype=dtypes.float32)
      sess.run(variables.global_variables_initializer())
      feed = {}
      feed[x] = np.random.randn(num_steps, batch_size, input_dim)
      sess.run(output, feed)
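Example #8 works because dynamic_rnn unrolls with a tf.while_loop, so the time and batch dimensions may stay unknown; only the feature dimension must be static for the kernel shape to be inferred. A compact sketch:

import numpy as np
import tensorflow as tf

cell = tf.contrib.rnn.LSTMBlockCell(7)
x = tf.placeholder(tf.float32, [None, None, 6])  # time x batch x features
out, _ = tf.nn.dynamic_rnn(cell, x, time_major=True, dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(out, {x: np.random.randn(5, 4, 6)})  # shapes chosen at feed time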
Example #9
def run(**args):
    tf.reset_default_graph()
    tf.set_random_seed(20160408)
    random.seed(20160408)
    exp_name = 'train_data'+str(args['train_data'])+'batch_size'+str(args['batch_size'])+\
               '_dropout'+str(args['dropout'])+'_epochs'+str(args['num_epochs'])+\
               '_px'+str(args['lambda_px'])+'_cpr'+str(args['lambda_cpr'])+'_lr'+str(args['learning_rate'])+\
               '_mode'+args['mode']+'_cube'+args['cube']+'_layer1_init'+args['layer1_init']+\
               '_layer2_init'+args['layer2_init']+'lambda_v'+str(args['lambda_value'])+'delta_acti'+str(args['delta_acti'])+\
               'delta_initv'+str(args['layer2_init_value'])+'loss'+str(args['loss'])+'seed'+str(args['lstm_seed'])
    print('parameters', exp_name)
    start_time = time.time()

    # Read Training/Dev/Test data
    data_dir = './data/' + args['data_dir'] + '/'
    np_matrix, index = Sparse.read_glove_vectors('./data/' + args['data_dir'] +
                                                 '/' + args['vector_file'])

    if args['method'] == 'train':
        train_1, train_2, train_xlabels, train_ylabels, train_xylabels, train_cpr_labels, train_lens1, train_lens2, maxlength, train_phrase1, train_phrase2, train_labels = Sparse.gzread_cpr(
            data_dir + args['train_data'], index)
        dev_1, dev_2, dev_xlabels, dev_ylabels, dev_xylabels, dev_cpr_labels, dev_lens1, dev_lens2, _, dev_phrase1, dev_phrase2, dev_labels = Sparse.gzread_cpr(
            data_dir + args['dev_data'], index)
        test_1, test_2, test_xlabels, test_ylabels, test_xylabels, test_cpr_labels, test_lens1, test_lens2, _, test_phrase1, test_phrase2, test_labels = Sparse.gzread_cpr(
            data_dir + args['test_data'], index)
    elif args['method'] == 'test':
        # predict probabilities on the test file (probability format)
        test_1, test_2, test_xlabels, test_ylabels, test_xylabels, test_cpr_labels, test_lens1, test_lens2, maxlength, test_phrase1, test_phrase2, test_labels = Sparse.gzread_cpr(
            data_dir + args['test_data'], index)

    graph = tf.get_default_graph()

    # Input -> LSTM -> Outstate
    dropout = tf.placeholder(tf.float32)
    inputs1 = tf.placeholder(tf.int32,
                             [args['batch_size'], None])  # batch size * length
    inputs2 = tf.placeholder(tf.int32,
                             [args['batch_size'], None])  # batch size * length
    x_labels = tf.placeholder(tf.float32, [args['batch_size']])
    y_labels = tf.placeholder(tf.float32, [args['batch_size']])
    xy_labels = tf.placeholder(tf.float32, [args['batch_size']],
                               name='xy_labels')
    cpr_labels = tf.placeholder(tf.float32, [args['batch_size']])
    lengths1 = tf.placeholder(tf.int32, [args['batch_size']])
    lengths2 = tf.placeholder(tf.int32, [args['batch_size']])

    # RNN
    with tf.variable_scope('prob',
                           initializer=tf.variance_scaling_initializer(
                               seed=args['lstm_seed'])):
        # with tf.variable_scope('prob', initializer = tf.variance_scaling_initializer(seed = 1012)):
        # with tf.variable_scope('prob'):

        # LSTM
        embeddings = tf.Variable(np_matrix, dtype=tf.float32, trainable=False)
        bilinear_matrix = tf.Variable(tf.orthogonal_initializer(seed=20160408)(
            shape=(args['hidden_dim'], args['hidden_dim'])),
                                      trainable=True)

        # lstm = tf.contrib.rnn.LSTMCell(args['hidden_dim'], state_is_tuple=True)

        # poe model, including kl loss and correlation loss. Both contains kl between marginals
        if args['mode'] == 'poe':
            lstm = lstm_ops.LSTMBlockCell(args['hidden_dim'])
            lstm = tf.contrib.rnn.DropoutWrapper(lstm,
                                                 output_keep_prob=dropout)
            fstate1, fstate2 = get_lstm_input(args['hidden_dim'], embeddings,
                                              inputs1, inputs2, lengths1,
                                              lengths2, dropout, lstm)
            joint_predicted, x_predicted, y_predicted, cpr_predicted, cpr_predicted_reverse = Probability.poe_model(
                args, fstate1, fstate2)

            if args['loss'] == 'kl':
                print('poe_kl')
                cpr_loss = Probability.kl_loss(args['batch_size'],
                                               cpr_predicted, cpr_labels)

            elif args['loss'] == 'corr':
                print('poe_correlation')
                cpr_loss = corr_prob.corr_loss(x_predicted, y_predicted,
                                               joint_predicted, x_labels,
                                               y_labels, xy_labels)
            else:
                print('invalid loss')

        elif 'cube' in args['mode']:
            lstm = lstm_ops.LSTMBlockCell(args['hidden_dim'])
            lstm = tf.contrib.rnn.DropoutWrapper(lstm,
                                                 output_keep_prob=dropout)
            # init first feed forward network parameters
            W1 = Layer.W(args['hidden_dim'], args['output_dim'], 'Output')
            b1 = Layer.layer1_bias(args['output_dim'], args['layer1_init'],
                                   -5.0, 'layer1_b')
            # get lstm output
            fstate1, fstate2 = get_lstm_input(args['hidden_dim'], embeddings,
                                              inputs1, inputs2, lengths1,
                                              lengths2, dropout, lstm)
            # init second feed forward network parameters
            W2 = Layer.W(args['hidden_dim'], args['output_dim'], 'Output1')
            b2 = Layer.layer2_bias(args['output_dim'], args['layer2_init'],
                                   args['layer2_init_value'], 'layer2_b')
            # get box embedding via lstm output
            t1_min_embed, t1_max_embed, t2_min_embed, t2_max_embed = cube_exp_prob.box_model_embed(
                args, fstate1, fstate2, W1, b1, W2, b2)
            join_min, join_max, meet_min, meet_max, not_have_meet = cube_exp_prob.box_model_params(
                t1_min_embed, t1_max_embed, t2_min_embed, t2_max_embed)
            joint_predicted, x_predicted, y_predicted, cpr_predicted, cpr_predicted_reverse = cube_exp_prob.box_prob(
                join_min, join_max, meet_min, meet_max, not_have_meet,
                t1_min_embed, t1_max_embed, t2_min_embed, t2_max_embed)

            if args['loss'] == 'corr':
                print('cube_correlation')
                # calculate the correlation loss; a different, bound-based term
                # is used when the boxes are disjoint and xy_label is greater than 0.0
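                # Note: slicing_where is a project-specific helper, not a TF op;
                # it presumably partitions the batch by `condition`, applies
                # true_branch / false_branch to the matching slices, and stitches
                # the results back in order, so disjoint and overlapping boxes
                # contribute different loss terms within a single batch.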
                cpr_loss = cube_exp_prob.slicing_where(
                    condition=not_have_meet & (xy_labels > 0),
                    full_input=([
                        join_min, join_max, meet_min, meet_max, t1_min_embed,
                        t1_max_embed, t2_min_embed, t2_max_embed, x_labels,
                        y_labels, xy_labels, not_have_meet
                    ]),
                    true_branch=lambda x: corr_prob.lambda_upper_bound(*x),
                    false_branch=lambda x: corr_prob.lambda_corr_loss(*x))

            elif args['loss'] == 'kl':
                # for training
                # calculate the log conditional probability for positive examples,
                # and a negative upper bound if the two boxes are disjoint
                train_cpr_predicted = cube_exp_prob.slicing_where(
                    condition=not_have_meet,
                    full_input=([
                        join_min, join_max, meet_min, meet_max, t1_min_embed,
                        t1_max_embed, t2_min_embed, t2_max_embed
                    ]),
                    true_branch=lambda x: cube_exp_prob.lambda_batch_log_upper_bound(*x),
                    # true_branch=lambda x: cube_exp_prob.lambda_batch_log_upper_bound_version2(*x),
                    false_branch=lambda x: cube_exp_prob.lambda_batch_log_cube_measure(*x))
                # calculate log(1-p) if overlap, 0 if no overlap
                onem_cpr_predicted = cube_exp_prob.slicing_where(
                    condition=not_have_meet,
                    full_input=tf.tuple([
                        join_min, join_max, meet_min, meet_max, t1_min_embed,
                        t1_max_embed, t2_min_embed, t2_max_embed
                    ]),
                    true_branch=lambda x: cube_exp_prob.lambda_zero_log_upper_bound(*x),
                    false_branch=lambda x: cube_exp_prob.lambda_batch_log_cond_cube_measure(*x))

                whole_cpr_predicted = tf.concat([
                    tf.expand_dims(train_cpr_predicted, 1),
                    tf.expand_dims(onem_cpr_predicted, 1)
                ], 1)
                cpr_loss = tf.nn.softmax_cross_entropy_with_logits(
                    logits=whole_cpr_predicted,
                    labels=cube_exp_prob.create_distribution(
                        cpr_labels, args['batch_size']))

        elif args['mode'] == 'bilinear':
            print('bilinear')
            lstm = lstm_ops.LSTMBlockCell(args['hidden_dim'])
            lstm = tf.contrib.rnn.DropoutWrapper(lstm,
                                                 output_keep_prob=dropout)
            fstate1, fstate2 = get_lstm_input(args['hidden_dim'], embeddings,
                                              inputs1, inputs2, lengths1,
                                              lengths2, dropout, lstm)
            joint_predicted, x_predicted, y_predicted, cpr_predicted, cpr_predicted_reverse = Bilinear.bilinear_model(
                args, fstate1, fstate2, bilinear_matrix)
            cpr_loss = tf.nn.softmax_cross_entropy_with_logits(
                logits=Bilinear.create_log_distribution(
                    cpr_predicted, args['batch_size']),
                labels=Bilinear.create_distribution(cpr_labels,
                                                    args['batch_size']))
        else:
            print('mode is wrong')

        x_loss = tf.nn.softmax_cross_entropy_with_logits(
            logits=cube_exp_prob.create_log_distribution(
                x_predicted, args['batch_size']),
            labels=cube_exp_prob.create_distribution(x_labels,
                                                     args['batch_size']))
        y_loss = tf.nn.softmax_cross_entropy_with_logits(
            logits=cube_exp_prob.create_log_distribution(
                y_predicted, args['batch_size']),
            labels=cube_exp_prob.create_distribution(y_labels,
                                                     args['batch_size']))
        mean_loss = tf.reduce_mean(args['lambda_px'] * (x_loss + y_loss) +
                                   args['lambda_cpr'] * cpr_loss)

    ## Learning ##
    optimizer = tf.train.AdamOptimizer(args['learning_rate'])
    varlist = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='prob')
    gradient = optimizer.compute_gradients(mean_loss, var_list=varlist)
    # train_op = optimizer.apply_gradients(gradient)

    train_op = optimizer.minimize(mean_loss, var_list=varlist)

    tf.set_random_seed(20160408)
    saver = tf.train.Saver(max_to_keep=10)
    with tf.Session() as sess:
        tf.set_random_seed(20160408)
        if args['method'] == 'train':
            sess.run(tf.global_variables_initializer())
            Trainer = Training(sess, train_op, mean_loss, x_loss, cpr_loss,
                               x_predicted, y_predicted, joint_predicted,
                               cpr_predicted, cpr_predicted_reverse, inputs1,
                               inputs2, x_labels, y_labels, xy_labels,
                               cpr_labels, lengths1, lengths2,
                               args['batch_size'], maxlength, dropout,
                               args['dropout'], gradient)
            best_dev = float('inf')
            for e in range(args['num_epochs']):
                print("Outer epoch %d" % e)
                kl_div = Trainer.train(train_1, train_2, train_xlabels,
                                       train_ylabels, train_xylabels,
                                       train_cpr_labels, dev_1, dev_2,
                                       dev_xlabels, dev_ylabels, dev_xylabels,
                                       dev_cpr_labels, train_lens1,
                                       train_lens2, dev_lens1, dev_lens2)
                if kl_div < best_dev:
                    save_path = saver.save(sess,
                                           "./tmp/" + exp_name + "_best.ckpt")
                    print("Best model saved in file: %s" % save_path)
                    best_dev = kl_div
                print("--------------- %s seconds ---------------" %
                      (time.time() - start_time))
                save_path = saver.save(
                    sess, "./tmp/" + exp_name + "_" + str(e) + ".ckpt")
                print("Model saved in file: %s" % save_path)
        elif args['method'] == 'test':
            saver.restore(sess, "./tmp/" + exp_name + "_best.ckpt")
            Trainer = Training(sess, train_op, mean_loss, x_loss, cpr_loss,
                               x_predicted, y_predicted, joint_predicted,
                               cpr_predicted, cpr_predicted_reverse, inputs1,
                               inputs2, x_labels, y_labels, xy_labels,
                               cpr_labels, lengths1, lengths2,
                               args['batch_size'], maxlength, dropout,
                               args['dropout'], gradient)
            test_loss, x_pred, x_corr, x_kl, y_pred, y_corr, y_kl, xy_pred, xy_corr, cpr_pred, cpr_xy_corr, cpr_kl, cpr_pred_reverse, corr_loss, tp, tn, fp, fn = Trainer.eval(
                test_1, test_2, test_xlabels, test_ylabels, test_xylabels,
                test_cpr_labels, test_lens1, test_lens2)
            #test_loss, x_pred, x_corr, y_pred, y_corr, xy_pred, xy_corr, cpr_pred, cpr_xy_corr, cpr_kl, cpr_pred_reverse, corr_loss, tp, tn, fp, fn = Trainer.eval(test_1, test_2, test_xlabels, test_ylabels, test_xylabels, test_cpr_labels, test_lens1, test_lens2)
            out_file = open(
                data_dir + 'after_respon_' + exp_name + "_" +
                args['test_data'].split(".")[0] + "_pred_prob.txt", "w")
            print('tp', tp)
            print('tn', tn)
            print('fp', fp)
            print('fn', fn)
            neg_count = 0
            ind_count = 0
            for idx, cpr_prob in enumerate(cpr_pred):
                cpr_prob = np.exp(cpr_prob)
                cpr_prob_rev = np.exp(cpr_pred_reverse[idx])
                x_prob = np.exp(x_pred[idx])
                y_prob = np.exp(y_pred[idx])
                xy_prob = np.exp(xy_pred[idx])
                pmi = np.log(xy_prob / (x_prob * y_prob)) / -np.log(xy_prob)
                s1 = [str(a) for a in test_1[idx]]
                s2 = [str(a) for a in test_2[idx]]
                p1 = test_phrase1[idx]
                p2 = test_phrase2[idx]
                out_file.write(
                    "%f\t%f\t%f\t%f\t%f\t%f\t%s\t%s\t%s\t%s" %
                    (x_prob, y_prob, xy_prob, pmi, cpr_prob, cpr_prob_rev,
                     " ".join(s1), p1, " ".join(s2), p2))
                # second, str()-based copy of the same record
                out_file.write(
                    str(x_prob) + "\t" + str(y_prob) + "\t" + str(xy_prob) +
                    "\t" + str(pmi) + "\t" + str(cpr_prob) + "\t" +
                    str(cpr_prob_rev) + "\t" + " ".join(s1) + "\t" + p1 + "\t" +
                    " ".join(s2) + "\t" + p2)

                if len(test_labels) == len(cpr_pred):
                    out_file.write("\t%s" % test_labels[idx])
                out_file.write("\n")
            out_file.close()
            print("X Prediction correlation: %f" % x_corr)
            print("X KL divergence: %f" % x_kl)
            print("Y Prediction correlation: %f" % y_corr)
            print("Y KL divergence: %f" % y_kl)

            print("Prediction correlation: %f" % cpr_xy_corr)
            print("KL divergence: %f" % cpr_kl)
            print("Number of negative correlation", neg_count)
            print("Number of independence", ind_count)
            print("--- %s seconds ---" % (time.time() - start_time))