import os
from datetime import datetime
from pathlib import Path

import networkx as nx
import numpy as np
import tensorflow as tf
from scipy.special import expit

# path_to_embedding, read_embedding and save_embedding are project helpers
# assumed to be importable from the surrounding codebase.


def fit(self):
    path = path_to_embedding(root=self.path_to_dumps,
                             method='hist_loss_' +
                             str(self.hist_loss_configuration),
                             name=self.graph_name,
                             dim=self.dim)
    # Reuse a previously dumped embedding when caching is enabled.
    if self.use_cached and Path(path).exists():
        E = read_embedding(path)
        print("Loaded cached embedding from " + path)
        return E
    # Dispatch on the configured encoder architecture.
    if self.hist_loss_configuration.linearity == 'linear':
        E = self.run(should_stop=self.should_stop)
    elif self.hist_loss_configuration.linearity == 'direct':
        E = self.run_direct(should_stop=self.should_stop)
    elif self.hist_loss_configuration.linearity == 'nonlinear2':
        E = self.run_nonlinear2(should_stop=self.should_stop)
    elif self.hist_loss_configuration.linearity == 'nonlinear2-reduce':
        E = self.run_nonlinear2_reduce(should_stop=self.should_stop)
    elif self.hist_loss_configuration.linearity == 'nonlinear3':
        E = self.run_nonlinear3(should_stop=self.should_stop)
    else:
        raise ValueError('Unknown linearity: ' +
                         self.hist_loss_configuration.linearity)
    save_embedding(path, E=np.array(E), normalize=True)
    return E
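
# Usage sketch: the class wrapping fit() is not shown in this excerpt, so
# the name and constructor arguments below are hypothetical.
#
#   model = HistLossEmbedding(graph, graph_name='karate', dim=32,
#                             path_to_dumps='./dumps', use_cached=True,
#                             hist_loss_configuration=config)
#   E = model.fit()  # numpy array with one row per node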
Example 2
    def fit(self):
        path = path_to_embedding(root=self.path_to_dumps,
                                 method='hope',
                                 name=self.graph_name,
                                 dim=self.dim)
        if self.use_cached:
            if Path(path).exists():
                E = read_embedding(path)
                print("Loaded cached embedding from " + path)
                return E

        E = self.learn_embedding()
        save_embedding(path, E=np.array(E))
        return E
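
    # learn_embedding() is the HOPE implementation assumed to live on the
    # same class; only the caching wrapper is shown here.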
    def run_nonlinear3(self,
                       patience=100,
                       patience_delta=0.001,
                       learning_rate=0.1,
                       LOG_DIR='./tensorflow_events/',
                       should_stop=None,
                       save_intermediate=False):
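
        # patience / patience_delta implement early stopping on the loss;
        # should_stop is an optional external callback checked every epoch;
        # save_intermediate dumps an embedding every 50 epochs.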

        tf.reset_default_graph()
        bin_num = 64
        N = self.graph.number_of_nodes()

        batch_size = min(800, N)
        print("Batch size:", batch_size)

        # Placeholders: a batch of rows of the similarity matrix, the indices
        # of those rows, and indices for subsampling negative pairs.
        _A_batched = tf.placeholder(tf.float32, [batch_size, N])
        _batch_indxs = tf.placeholder(tf.int32, [batch_size])
        _neg_sampling_indxs = tf.placeholder(tf.int32, [None])
        _W_1 = tf.Variable(tf.random_uniform([N, N], -1.0, 1.0), name="W1")
        _b_1 = tf.Variable(tf.zeros([N, N]), name="b1")
        _W_2 = tf.Variable(tf.random_uniform([N, N], -1.0, 1.0), name="W2")
        _b_2 = tf.Variable(tf.zeros([N, N]), name="b2")
        _W_3 = tf.Variable(tf.random_uniform([N, self.dim], -1.0, 1.0),
                           name="W3")
        _b_3 = tf.Variable(tf.zeros([N, self.dim]), name="b3")
        _b_1_batched = tf.gather(_b_1, indices=_batch_indxs)
        _b_2_batched = tf.gather(_b_2, indices=_batch_indxs)
        _b_3_batched = tf.gather(_b_3, indices=_batch_indxs)
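
        # Three-layer encoder: two sigmoid hidden layers of width N and a
        # linear output layer of width self.dim, with per-node bias rows
        # gathered for the current batch.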

        _E = tf.matmul(
            tf.nn.sigmoid(
                tf.matmul(
                    tf.nn.sigmoid(tf.matmul(_A_batched, _W_1) +
                                  _b_1_batched), _W_2) + _b_2_batched),
            _W_3) + _b_3_batched
        _E_corr = self.preprocess_embedding(_E)
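        # Restrict the batch rows to the batch columns, then split the
        # pairwise similarities of the batch embeddings into positive
        # (adjacent) and negative (non-adjacent) samples.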
        _A_batched_square = tf.gather(_A_batched, _batch_indxs, axis=1)
        _neg_samples = self.calc_neg_samples(
            _A_batched_square,
            _E_corr,
            method=self.hist_loss_configuration.calc_neg_method)
        _pos_samples = self.calc_pos_samples(
            _A_batched_square,
            _E_corr,
            method=self.hist_loss_configuration.calc_pos_method)

        # For small graphs, subsample the negative pairs using the indices
        # fed from the training loop.
        if N <= 1000:
            _neg_samples = tf.gather(_neg_samples, _neg_sampling_indxs)
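
        # Differentiable histograms over bin_num bins approximate the two
        # similarity distributions; calc_loss scores their separation, so
        # its negation is minimized.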

        _pos_hist = self.calc_hist(
            _pos_samples,
            method=self.hist_loss_configuration.calc_hist_method,
            bin_num=bin_num)
        _neg_hist = self.calc_hist(
            _neg_samples,
            method=self.hist_loss_configuration.calc_hist_method,
            bin_num=bin_num)
        _loss = -self.calc_loss(
            _neg_hist,
            _pos_hist,
            method=self.hist_loss_configuration.loss_method)

        tf.summary.scalar("loss", _loss)
        _optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=learning_rate).minimize(_loss)
        _summary = tf.summary.merge_all()

        G = self.graph
        # Dense similarity matrix derived from the adjacency matrix via the
        # configured simmatrix_method.
        A = nx.adjacency_matrix(G).toarray()
        A = self.np_calc_simmatrix(
            A, method=self.hist_loss_configuration.simmatrix_method)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # One TensorBoard run directory per training run.
            run_name = '{}_{}_{}_{}_{}'.format(datetime.now().timestamp(),
                                               self.dim,
                                               self.hist_loss_configuration,
                                               self.graph_name, batch_size)
            writer = tf.summary.FileWriter(os.path.join(LOG_DIR, run_name),
                                           graph=sess.graph)

            prev_loss = 0
            patience_counter = 0

            for epoch in range(1000):
                # `loss` and the weight snapshots are produced by sess.run
                # below, so the first report happens at epoch 1.
                if epoch % 50 == 1:
                    print('epoch: ' + str(epoch) + ', loss: ' + str(loss))
                    if save_intermediate:
                        # NumPy forward pass with the current weights, using
                        # sigmoid activations to match the graph above.
                        E = np.dot(
                            expit(
                                np.dot(expit(np.dot(A, W_1) + b_1),
                                       W_2) + b_2), W_3) + b_3
                        save_embedding('E_' + str(epoch) + '.txt', np.array(E))
                batch_indxs = np.random.choice(a=N,
                                               size=batch_size).astype('int32')
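                # choice() defaults to replace=True, so a batch can contain
                # duplicate node indices.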
                A_batched = A[batch_indxs]
                pos_count = np.count_nonzero(A_batched[:, batch_indxs])
                neg_count = batch_size * N - pos_count
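                # Sample twice as many negative indices as there are positive
                # pairs, roughly balancing the two histograms.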
                neg_sampling_indxs = np.random.choice(a=neg_count,
                                                      size=pos_count *
                                                      2).astype('int32')

                # Collect full-trace runtime metadata for TensorBoard; note
                # that FULL_TRACE profiling adds overhead on every step.
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
                pos_hist, neg_hist, loss, _, summary, W_1, b_1, W_2, b_2, W_3, b_3 = sess.run(
                    [
                        _pos_hist,
                        _neg_hist,
                        _loss,
                        _optimizer,
                        _summary,
                        _W_1,
                        _b_1,
                        _W_2,
                        _b_2,
                        _W_3,
                        _b_3,
                    ],
                    feed_dict={
                        _A_batched: A_batched,
                        _batch_indxs: batch_indxs,
                        _neg_sampling_indxs: neg_sampling_indxs,
                    },
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_run_metadata(run_metadata, "step_{}".format(epoch))

                writer.add_summary(summary, epoch)

                if epoch > 0 and prev_loss - loss > patience_delta:
                    patience_counter = 0
                else:
                    patience_counter += 1

                if patience_counter > patience:
                    print("\tearly stopping at", epoch)
                    break

                if should_stop and should_stop():
                    break

                prev_loss = loss
            # Final embedding: full-graph forward pass with the trained
            # weights (expit is the NumPy sigmoid).
            E = np.dot(expit(np.dot(expit(np.dot(A, W_1) + b_1), W_2) + b_2),
                       W_3) + b_3
            return E
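
# Minimal standalone sketch of the forward pass that produces E above.
# Shapes and the sigmoid/linear structure mirror the TF graph; the random
# weights here are stand-ins, not trained values.
if __name__ == '__main__':
    rng = np.random.RandomState(0)
    N, dim = 10, 4
    A = (rng.rand(N, N) < 0.3).astype(np.float32)
    W_1, b_1 = rng.uniform(-1, 1, (N, N)), np.zeros((N, N))
    W_2, b_2 = rng.uniform(-1, 1, (N, N)), np.zeros((N, N))
    W_3, b_3 = rng.uniform(-1, 1, (N, dim)), np.zeros((N, dim))
    E = np.dot(expit(np.dot(expit(np.dot(A, W_1) + b_1), W_2) + b_2),
               W_3) + b_3
    print(E.shape)  # (10, 4)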