Example #1
    def train(self):
        train_losses = []
        val_losses = []
        model_path = os.path.join(self.model_dir, self.model_file)

        print("Training model...\n")
        timer = Timer()
        timer.tic()

        x = self.data.x.to(self.device)
        train_pos_edge_index = self.data.train_pos_edge_index.to(self.device)

        for epoch in range(self.epochs):
            print("Epoch: {}".format(epoch + 1))
            self.model.train()
            self.optimizer.zero_grad()
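            # Encode the graph and accumulate reconstruction + regularization losses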
            z = self.model.encode(x, train_pos_edge_index)
            loss = self.model.recon_loss(z, train_pos_edge_index)
            if self.model_name == "ARGVA":
                loss = loss + (1 / self.data.num_nodes) * self.model.kl_loss()
            loss += self.dis_loss_para * self.model.discriminator_loss(z) + \
                self.reg_loss_para * self.model.reg_loss(z)
            loss.backward()
            self.optimizer.step()

            # Evaluate on validation data
            self.model.eval()
            with torch.no_grad():
                train_losses.append(loss.cpu().detach().numpy())

                # Compute validation statistics
                val_pos_edge_index = self.data.val_pos_edge_index.to(
                    self.device)
                val_neg_edge_index = self.data.val_neg_edge_index.to(
                    self.device)
                z = self.model.encode(x, train_pos_edge_index)
                val_loss = self.model.recon_loss(
                    z, val_pos_edge_index, val_neg_edge_index)
                if self.model_name == "ARGVA":
                    val_loss += (1 / self.data.num_nodes) * self.model.kl_loss()
                val_loss += self.dis_loss_para * self.model.discriminator_loss(z) + \
                    self.reg_loss_para * self.model.reg_loss(z)
                val_losses.append(val_loss.cpu().detach().numpy())
                if val_losses[-1] == min(val_losses):
                    print("\tSaving model...")
                    torch.save(self.model.state_dict(), model_path)
                    print("\tSaved.")
                print("\ttrain_loss=", "{:.5f}".format(loss), "val_loss=",
                      "{:.5f}".format(val_loss))

        print("Finished training.\n")
        training_time = timer.toc()
        self._plot_losses(train_losses, val_losses)
        self._print_stats(train_losses, val_losses, training_time)
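
The examples in this collection time their work with a Timer helper exposing tic() and toc(); the helper itself is not part of these snippets. The sketch below shows one minimal way such a class could look, assuming toc() returns (and prints) the elapsed wall-clock time in seconds. The names mirror the calls used above, but the behaviour is an assumption, not the original implementation.

import time


class Timer:
    """Minimal tic/toc timer (illustrative sketch, not the original helper)."""

    def __init__(self):
        self._start = None

    def tic(self):
        # Mark the start of the timed block.
        self._start = time.time()

    def toc(self):
        # Return (and print) the seconds elapsed since the last tic().
        if self._start is None:
            raise RuntimeError("toc() called before tic()")
        elapsed = time.time() - self._start
        print("Elapsed time: {:.2f}s".format(elapsed))
        return elapsed
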
Example #2
    def __init__(self, embedding_type, gpu=0):
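        # Expose only the requested GPU to CUDA-dependent code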
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

        self.embedding_type = embedding_type
        self.embeddings_parser = EmbeddingsParser(gpu)
        self.timer = Timer()
        self.path_persistent = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), "..", "..", "..",
            "data", "interim", "han", self.embedding_type)
        if not os.path.exists(self.path_persistent):
            os.makedirs(self.path_persistent)
Example #3
    def __init__(self, embedding_type, graph_type, threshold=2, gpu=0):
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

        self.embedding_type = embedding_type
        self.graph_type = graph_type
        self.threshold = threshold
        self.embeddings_parser = EmbeddingsParser(gpu)
        self.timer = Timer()
        self.path_persistent = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), "..", "..", "..",
            "data", "interim", "graphsage", self.embedding_type,
            self.graph_type)
        if not os.path.isdir(self.path_persistent):
            os.makedirs(self.path_persistent)
Example #4
 def __init__(self):
     self.parser = FileParser()
     self.persistent = {}
     self.timer = Timer()
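     # Map each dataset name to its processing method and its pickle cache file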
     self.processes = {
         "chapters_books": {
             "process_data": "_process_data_chapters_books",
             "persistent_file": os.path.join(self.path,
                                             "chapters_books.pkl")
         },
         "chapters_all_scigraph_citations": {
             "process_data":
             "_process_data_chapters_all_scigraph_citations",
             "persistent_file":
             os.path.join(self.path, "chapters_all_scigraph_citations.pkl")
         },
         "chapters_confproc_scigraph_citations": {
             "process_data":
             "_process_data_chapters_confproc_scigraph_citations",
             "persistent_file":
             os.path.join(self.path,
                          "chapters_confproc_scigraph_citations.pkl")
         },
         "books_conferences": {
             "process_data": "_process_data_books_conferences",
             "persistent_file": os.path.join(self.path,
                                             "books_conferences.pkl")
         },
         "author_id_chapters": {
             "process_data":
             "_process_data_author_id_chapters",
             "persistent_file":
             os.path.join(self.path, "author_id_chapters.pkl")
         },
         "author_name_chapters": {
             "process_data":
             "_process_data_author_name_chapters",
             "persistent_file":
             os.path.join(self.path, "author_name_chapters.pkl")
         },
         "confproc_scigraph_citations_chapters": {
             "process_data":
             "_process_data_confproc_scigraph_citations_chapters",
             "persistent_file":
             os.path.join(self.path,
                          "confproc_scigraph_citations_chapters.pkl")
         }
     }
Example #5
    def __init__(self,
                 embedding_type,
                 dataset,
                 graph_type="directed",
                 threshold=2,
                 gpu=0):
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

        self.embedding_type = embedding_type
        self.dataset = dataset
        self.graph_type = graph_type
        self.threshold = threshold
        self.embeddings_parser = EmbeddingsParser(gpu)
        self.timer = Timer()
        self.path_persistent = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), "..", "..", "..",
            "data", "interim", "gat", self.embedding_type, self.dataset)
        if not os.path.exists(self.path_persistent):
            os.makedirs(self.path_persistent)
Example #6
    def train(self, data):
        if not self._load_model_classifier():
            print("Classifier not trained yet. Training now...")
            timer = Timer()
            timer.tic()

            print("Loading the training embeddings...")
            if not self._load_train_embeddings():
                print("The pretrained embeddings are missing.")
            else:
                print("Loaded.")
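            # Select the pretrained embedding rows for the training chapter ids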
            training_ids = list(data.chapter)
            training_embeddings = self.pretrained_embeddings[[
                self.pretrained_embeddings_id_map[id] for id in training_ids
            ]]

            self.label_encoder = LabelEncoder()
            self.labels = self.label_encoder.fit_transform(
                data.conferenceseries)
            self.classifier.fit(training_embeddings, self.labels)
            self._save_model_classifier()

            print("Training finished.")
            timer.toc()
Example #7
    def train(self):
        print("Loading data...")
        adj_list, features_list, y_train, y_val, train_mask, val_mask = load_data(
            self.embedding_type)
        print("Loaded.")

        nb_nodes = features_list[0].shape[0]
        ft_size = features_list[0].shape[1]
        nb_classes = y_train.shape[1]

        features_list = [features[np.newaxis] for features in features_list]
        y_train = y_train[np.newaxis]
        y_val = y_val[np.newaxis]
        train_mask = train_mask[np.newaxis]
        val_mask = val_mask[np.newaxis]
        biases_list = [preprocess_adj_bias(adj) for adj in adj_list]

        print("Training model...")
        timer = Timer()
        timer.tic()

        print(
            "Parameters: batch size={}, nb_nodes={}, ft_size={}, nb_classes={}\n"
            .format(self.batch_size, nb_nodes, ft_size, nb_classes))

        model = HAN(self.model,
                    self.hid_units,
                    self.n_heads,
                    nb_classes,
                    nb_nodes,
                    l2_coef=self.weight_decay,
                    ffd_drop=self.ffd_drop,
                    attn_drop=self.attn_drop,
                    activation=self.nonlinearity,
                    residual=self.residual)

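        # Early-stopping state: best validation loss/accuracy so far and patience counter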
        vlss_mn = np.inf
        vacc_mx = 0.0
        curr_step = 0

        train_loss_avg = 0
        train_acc_avg = 0
        val_loss_avg = 0
        val_acc_avg = 0

        train_losses = []
        val_losses = []
        train_accuracies = []
        val_accuracies = []

        for epoch in range(self.epochs):
            print("\nEpoch {}".format(epoch))

            # Training
            tr_step = 0
            tr_size = features_list[0].shape[0]
            while tr_step * self.batch_size < tr_size:
                feats_list = [
                    features[tr_step * self.batch_size:(tr_step + 1) *
                             self.batch_size] for features in features_list
                ]

                _, train_embed, att_val, acc_tr, loss_value_tr = self._train(
                    model=model,
                    inputs_list=feats_list,
                    bias_mat_list=biases_list,
                    lbl_in=y_train[tr_step * self.batch_size:(tr_step + 1) *
                                   self.batch_size],
                    msk_in=train_mask[tr_step * self.batch_size:(tr_step + 1) *
                                      self.batch_size])

                train_loss_avg += loss_value_tr
                train_acc_avg += acc_tr
                tr_step += 1

            # Validation
            vl_step = 0
            vl_size = features_list[0].shape[0]

            while vl_step * self.batch_size < vl_size:
                feats_list = [
                    features[vl_step * self.batch_size:(vl_step + 1) *
                             self.batch_size] for features in features_list
                ]

                _, val_embed, att_val, acc_vl, loss_value_vl = self.evaluate(
                    model=model,
                    inputs_list=feats_list,
                    bias_mat_list=biases_list,
                    lbl_in=y_val[vl_step * self.batch_size:(vl_step + 1) *
                                 self.batch_size],
                    msk_in=val_mask[vl_step * self.batch_size:(vl_step + 1) *
                                    self.batch_size])

                val_loss_avg += loss_value_vl
                val_acc_avg += acc_vl
                vl_step += 1

            print(
                'Training: loss = %.5f, acc = %.5f | Val: loss = %.5f, acc = %.5f'
                % (train_loss_avg / tr_step, train_acc_avg / tr_step,
                   val_loss_avg / vl_step, val_acc_avg / vl_step))
            train_losses.append(train_loss_avg / tr_step)
            val_losses.append(val_loss_avg / vl_step)
            train_accuracies.append(train_acc_avg / tr_step)
            val_accuracies.append(val_acc_avg / vl_step)

            # Early Stopping
            if val_acc_avg / vl_step >= vacc_mx or val_loss_avg / vl_step <= vlss_mn:
                if val_acc_avg / vl_step >= vacc_mx and val_loss_avg / vl_step <= vlss_mn:
                    vacc_early_model = val_acc_avg / vl_step
                    vlss_early_model = val_loss_avg / vl_step
                    working_weights = model.get_weights()
                    print(
                        "Minimum validation loss ({}), maximum accuracy ({}) so far at epoch {}."
                        .format(val_loss_avg / vl_step, val_acc_avg / vl_step,
                                epoch))
                    self._save_model(model)
                vacc_mx = np.max((val_acc_avg / vl_step, vacc_mx))
                vlss_mn = np.min((val_loss_avg / vl_step, vlss_mn))
                curr_step = 0
            else:
                curr_step += 1
                if curr_step == self.patience:
                    print("Early stop! Min loss: {}, Max accuracy: {}".format(
                        vlss_mn, vacc_mx))
                    print("Early stop model validation loss: {}, accuracy: {}".
                          format(vlss_early_model, vacc_early_model))
                    model.set_weights(working_weights)
                    break

            train_loss_avg = 0
            train_acc_avg = 0
            val_loss_avg = 0
            val_acc_avg = 0

        print("Training finished.")

        training_time = timer.toc()
        train_losses = [x.numpy() for x in train_losses]
        val_losses = [x.numpy() for x in val_losses]
        train_accuracies = [x.numpy() for x in train_accuracies]
        val_accuracies = [x.numpy() for x in val_accuracies]
        self._plot_losses(train_losses, val_losses)
        self._plot_accuracies(train_accuracies, val_accuracies)
        self._print_stats(train_losses, val_losses, train_accuracies,
                          val_accuracies, training_time)
Example #8
    def inference(self, test_data, gpu_mem_fraction=None):
        print("Inference.")
        timer = Timer()
        timer.tic()

        G = test_data[0]
        features = test_data[1]
        id_map = test_data[2]
        class_map = test_data[4]
        if isinstance(list(class_map.values())[0], list):
            num_classes = len(list(class_map.values())[0])
        else:
            num_classes = len(set(class_map.values()))

        if features is not None:
            # pad with dummy zero vector
            features = np.vstack([features, np.zeros((features.shape[1], ))])

        placeholders = self._construct_placeholders(num_classes)
        minibatch = NodeMinibatchIterator(G,
                                          id_map,
                                          placeholders,
                                          class_map,
                                          num_classes,
                                          batch_size=self.batch_size,
                                          max_degree=self.max_degree)

        adj_info_ph = tf.compat.v1.placeholder(tf.int32,
                                               shape=minibatch.adj.shape)
        adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

        model = self._create_model(num_classes, placeholders, features,
                                   adj_info, minibatch)

        config = tf.compat.v1.ConfigProto(
            log_device_placement=self.log_device_placement)
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True

        # Initialize session
        sess = tf.compat.v1.Session(config=config)
        merged = tf.compat.v1.summary.merge_all()
        #        summary_writer = tf.summary.FileWriter(self._log_dir(), sess.graph)

        # Initialize model saver
        saver = tf.compat.v1.train.Saver(max_to_keep=self.epochs)

        # Init variables
        sess.run(tf.compat.v1.global_variables_initializer(),
                 feed_dict={adj_info_ph: minibatch.adj})

        # Restore model
        print("Restoring trained model.")
        checkpoint_file = os.path.join(self._log_dir(), "model.ckpt")
        ckpt = tf.compat.v1.train.get_checkpoint_state(self._log_dir())
        if ckpt is not None:
            saver.restore(sess, checkpoint_file)
            print("Model restored.")
        else:
            print("This model checkpoint does not exist. The model might " +
                  "not be trained yet or the checkpoint is invalid.")

        val_adj_info = tf.compat.v1.assign(adj_info, minibatch.test_adj)
        sess.run(val_adj_info.op)

        print("Computing predictions...")
        t_test = time.time()
        finished = False
        val_losses = []
        val_preds = []
        nodes = []
        iter_num = 0
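        # Run the restored model over the test nodes in incremental minibatches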
        while not finished:
            feed_dict_val, _, finished, nodes_subset = minibatch.incremental_node_val_feed_dict(
                self.batch_size, iter_num, test=True)
            node_outs_val = sess.run([model.preds, model.loss],
                                     feed_dict=feed_dict_val)
            val_preds.append(node_outs_val[0])
            val_losses.append(node_outs_val[1])
            nodes.extend(nodes_subset)
            iter_num += 1
        val_preds = np.vstack(val_preds)
        print("Computed.")

        # Return only the embeddings of the test nodes
        test_preds_ids = {}
        for i, node in enumerate(nodes):
            test_preds_ids[node] = i
        test_nodes = [n for n in G.nodes() if G.node[n]['test']]
        test_preds = val_preds[[test_preds_ids[id] for id in test_nodes]]
        timer.toc()
        sess.close()
        return test_nodes, test_preds
Example #9
    def train(self, train_data, test_data=None):
        print("Training model...")
        timer = Timer()
        timer.tic()

        G = train_data[0]
        features = train_data[1]
        id_map = train_data[2]
        class_map = train_data[4]
        if isinstance(list(class_map.values())[0], list):
            num_classes = len(list(class_map.values())[0])
        else:
            num_classes = len(set(class_map.values()))

        if features is not None:
            # pad with dummy zero vector
            features = np.vstack([features, np.zeros((features.shape[1], ))])

        placeholders = self._construct_placeholders(num_classes)
        minibatch = NodeMinibatchIterator(G,
                                          id_map,
                                          placeholders,
                                          class_map,
                                          num_classes,
                                          batch_size=self.batch_size,
                                          max_degree=self.max_degree)

        adj_info_ph = tf.compat.v1.placeholder(tf.int32,
                                               shape=minibatch.adj.shape)
        adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

        model = self._create_model(num_classes, placeholders, features,
                                   adj_info, minibatch)

        config = tf.compat.v1.ConfigProto(
            log_device_placement=self.log_device_placement)
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True

        # Initialize session
        sess = tf.compat.v1.Session(config=config)
        merged = tf.compat.v1.summary.merge_all()
        #        summary_writer = tf.summary.FileWriter(self._log_dir(), sess.graph)

        # Initialize model saver
        saver = tf.compat.v1.train.Saver(max_to_keep=self.epochs)

        # Init variables
        sess.run(tf.compat.v1.global_variables_initializer(),
                 feed_dict={adj_info_ph: minibatch.adj})

        # Train model
        total_steps = 0
        avg_time = 0.0
        epoch_val_costs = []

        train_losses = []
        validation_losses = []

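        # Ops that switch adj_info between the training and validation adjacency matrices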
        train_adj_info = tf.compat.v1.assign(adj_info, minibatch.adj)
        val_adj_info = tf.compat.v1.assign(adj_info, minibatch.test_adj)

        for epoch in range(self.epochs):
            minibatch.shuffle()

            iter = 0
            print('Epoch: %04d' % (epoch))
            epoch_val_costs.append(0)
            train_loss_epoch = []
            validation_loss_epoch = []
            while not minibatch.end():
                # Construct feed dictionary
                feed_dict, labels = minibatch.next_minibatch_feed_dict()
                feed_dict.update({placeholders['dropout']: self.dropout})

                t = time.time()
                # Training step
                outs = sess.run(
                    [merged, model.opt_op, model.loss, model.preds],
                    feed_dict=feed_dict)
                train_cost = outs[2]
                train_loss_epoch.append(train_cost)

                if iter % self.validate_iter == 0:
                    # Validation
                    sess.run(val_adj_info.op)
                    if self.validate_batch_size == -1:
                        val_cost, val_f1_mic, val_f1_mac, duration = self._incremental_evaluate(
                            sess, model, minibatch, self.batch_size)
                    else:
                        val_cost, val_f1_mic, val_f1_mac, duration = self._evaluate(
                            sess, model, minibatch, self.validate_batch_size)
                    sess.run(train_adj_info.op)
                    epoch_val_costs[-1] += val_cost
                    validation_loss_epoch.append(val_cost)

                # if total_steps % self.print_every == 0:
                #     summary_writer.add_summary(outs[0], total_steps)

                # Print results
                avg_time = (avg_time * total_steps + time.time() -
                            t) / (total_steps + 1)

                if total_steps % self.print_every == 0:
                    train_f1_mic, train_f1_mac = self._calc_f1(
                        labels, outs[-1])
                    print("Iter:", '%04d' % iter, "train_loss=",
                          "{:.5f}".format(train_cost), "train_f1_mic=",
                          "{:.5f}".format(train_f1_mic), "train_f1_mac=",
                          "{:.5f}".format(train_f1_mac), "val_loss=",
                          "{:.5f}".format(val_cost), "val_f1_mic=",
                          "{:.5f}".format(val_f1_mic), "val_f1_mac=",
                          "{:.5f}".format(val_f1_mac), "time=",
                          "{:.5f}".format(avg_time))

                iter += 1
                total_steps += 1

                if total_steps > self.max_total_steps:
                    break

            # Keep track of train and validation losses per epoch
            train_losses.append(sum(train_loss_epoch) / len(train_loss_epoch))
            validation_losses.append(
                sum(validation_loss_epoch) / len(validation_loss_epoch))

            # If the epoch has the lowest validation loss so far
            if validation_losses[-1] == min(validation_losses):
                print(
                    "Minimum validation loss so far ({}) at epoch {}.".format(
                        validation_losses[-1], epoch))
                # Save model at each epoch
                print("Saving model at epoch {}.".format(epoch))
                saver.save(sess, os.path.join(self._log_dir(), "model.ckpt"))

            if total_steps > self.max_total_steps:
                break

        print("Optimization Finished!")

        training_time = timer.toc()
        self._plot_losses(train_losses, validation_losses)
        self._print_stats(train_losses, validation_losses, training_time)

        sess.run(val_adj_info.op)
        val_cost, val_f1_mic, val_f1_mac, duration = self._incremental_evaluate(
            sess, model, minibatch, self.batch_size)
        print("Full validation stats:", "loss=", "{:.5f}".format(val_cost),
              "f1_micro=", "{:.5f}".format(val_f1_mic), "f1_macro=",
              "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(duration))
        with open(self._log_dir() + "val_stats.txt", "w") as fp:
            fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".
                     format(val_cost, val_f1_mic, val_f1_mac, duration))
Example #10
    def __init__(self):
        self.timer = Timer()
        self.persistent = {}
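        # Map each dataset name to its raw input file, line-processing method, and pickle cache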
        self.processes = {
                # Old datasets
                "old_books": {
                        "filename": os.path.join(self.path_raw,
                                                 old_books_file),
                        "process_line": "_process_line_old_books",
                        "persistent_file": os.path.join(self.path_persistent,
                                                        "old_books.pkl"),
                        "persistent_variable": [],
                        "dataset_format": "ntriples"
                        },
                "old_books_new_books": {
                        "filename": os.path.join(self.path_raw,
                                                 old_books_file),
                        "process_line": "_process_line_old_books_new_books",
                        "persistent_file": os.path.join(
                                self.path_persistent,
                                "old_books_new_books.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "ntriples"
                        },
                "old_books_conferences": {
                        "filename": os.path.join(self.path_raw,
                                                 old_books_file),
                        "process_line": "_process_line_old_books_conferences",
                        "persistent_file": os.path.join(
                                self.path_persistent,
                                "old_books_conferences.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "ntriples"
                        },
                "conferences": {
                        "filename": os.path.join(self.path_raw,
                                                 old_conferences_file),
                        "process_line": "_process_line_conferences",
                        "persistent_file": os.path.join(self.path_persistent,
                                                        "conferences.pkl"),
                        "persistent_variable": [],
                        "dataset_format": "ntriples"
                        },
                "conferences_name": {
                        "filename": os.path.join(self.path_raw,
                                                 old_conferences_file),
                        "process_line": "_process_line_conferences_name",
                        "persistent_file": os.path.join(
                                self.path_persistent, "conferences_name.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "ntriples"
                        },
                "conferences_acronym": {
                        "filename": os.path.join(self.path_raw,
                                                 old_conferences_file),
                        "process_line": "_process_line_conferences_acronym",
                        "persistent_file": os.path.join(
                                self.path_persistent,
                                "conferences_acronym.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "ntriples"
                        },
                "conferences_city": {
                        "filename": os.path.join(self.path_raw,
                                                 old_conferences_file),
                        "process_line": "_process_line_conferences_city",
                        "persistent_file": os.path.join(
                                self.path_persistent,
                                "conferences_city.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "ntriples"
                        },
                "conferences_country": {
                        "filename": os.path.join(self.path_raw,
                                                 old_conferences_file),
                        "process_line": "_process_line_conferences_country",
                        "persistent_file": os.path.join(
                                self.path_persistent,
                                "conferences_country.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "ntriples"
                        },
                "conferences_year": {
                        "filename": os.path.join(self.path_raw,
                                                 old_conferences_file),
                        "process_line": "_process_line_conferences_year",
                        "persistent_file": os.path.join(
                                self.path_persistent,
                                "conferences_year.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "ntriples"
                        },
                "conferences_datestart": {
                        "filename": os.path.join(self.path_raw,
                                                 old_conferences_file),
                        "process_line": "_process_line_conferences_datestart",
                        "persistent_file": os.path.join(
                                self.path_persistent,
                                "conferences_datestart.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "ntriples"
                        },
                "conferences_dateend": {
                        "filename": os.path.join(self.path_raw,
                                                 old_conferences_file),
                        "process_line": "_process_line_conferences_dateend",
                        "persistent_file": os.path.join(
                                self.path_persistent,
                                "conferences_dateend.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "ntriples"
                        },
                "conferences_conferenceseries": {
                        "filename": os.path.join(self.path_raw,
                                                 old_conferences_file),
                        "process_line": "_process_line_conferences_conferenceseries",
                        "persistent_file": os.path.join(
                                self.path_persistent,
                                "conferences_conferenceseries.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "ntriples"
                        },
                "conferenceseries": {
                        "filename": os.path.join(self.path_raw,
                                                 old_conferences_file),
                        "process_line": "_process_line_conferenceseries",
                        "persistent_file": os.path.join(
                                self.path_persistent, "conferenceseries.pkl"),
                        "persistent_variable": [],
                        "dataset_format": "ntriples"
                        },
                "conferenceseries_name": {
                        "filename": os.path.join(self.path_raw,
                                                 old_conferences_file),
                        "process_line": "_process_line_conferenceseries_name",
                        "persistent_file": os.path.join(
                                self.path_persistent,
                                "conferenceseries_name.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "ntriples"
                        },

                # New datasets
                "books": {
                        "filename": os.path.join(self.path_raw, books_file),
                        "process_line": "_process_line_books",
                        "persistent_file": os.path.join(self.path_persistent,
                                                        "books.pkl"),
                        "persistent_variable": [],
                        "dataset_format": "json"
                        },
                "isbn_books": {
                        "filename": os.path.join(self.path_raw, books_file),
                        "process_line": "_process_line_isbn_books",
                        "persistent_file": os.path.join(self.path_persistent,
                                                        "isbn_books.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "json"
                        },
                "authors_name": {
                        "filename": os.path.join(self.path_raw, authors_file),
                        "process_line": "_process_line_authors_name",
                        "persistent_file": os.path.join(self.path_persistent,
                                                        "authors_name.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "json"
                        },
                "chapters": {
                        "filename": os.path.join(self.path_raw, chapters_file),
                        "process_line": "_process_line_chapters",
                        "persistent_file": os.path.join(self.path_persistent,
                                                        "chapters.pkl"),
                        "persistent_variable": [],
                        "dataset_format": "json"
                        },
                "chapters_title": {
                        "filename": os.path.join(self.path_raw, chapters_file),
                        "process_line": "_process_line_chapters_title",
                        "persistent_file": os.path.join(self.path_persistent,
                                                        "chapters_title.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "json"
                        },
                "chapters_year": {
                        "filename": os.path.join(self.path_raw, chapters_file),
                        "process_line": "_process_line_chapters_year",
                        "persistent_file": os.path.join(self.path_persistent,
                                                        "chapters_year.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "json"
                        },
                "chapters_language": {
                        "filename": os.path.join(self.path_raw, chapters_file),
                        "process_line": "_process_line_chapters_language",
                        "persistent_file": os.path.join(
                                self.path_persistent, "chapters_language.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "json"
                        },
                "chapters_abstract": {
                        "filename": os.path.join(self.path_raw, chapters_file),
                        "process_line": "_process_line_chapters_abstract",
                        "persistent_file": os.path.join(
                                self.path_persistent, "chapters_abstract.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "json"
                        },
                "chapters_authors": {
                        "filename": os.path.join(self.path_raw, chapters_file),
                        "process_line": "_process_line_chapters_authors",
                        "persistent_file": os.path.join(
                                self.path_persistent, "chapters_authors.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "json"
                        },
                "chapters_authors_name": {
                        "filename": os.path.join(self.path_raw, chapters_file),
                        "process_line": "_process_line_chapters_authors_name",
                        "persistent_file": os.path.join(
                                self.path_persistent,
                                "chapters_authors_name.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "json"
                        },
                "chapters_all_citations": {
                        "filename": os.path.join(self.path_raw, chapters_file),
                        "process_line": "_process_line_chapters_all_citations",
                        "persistent_file": os.path.join(
                                self.path_persistent,
                                "chapters_all_citations.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "json"
                        },
                "chapters_keywords": {
                        "filename": os.path.join(self.path_raw, chapters_file),
                        "process_line": "_process_line_chapters_keywords",
                        "persistent_file": os.path.join(
                                self.path_persistent, "chapters_keywords.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "json"
                        },
                "chapters_books_isbns": {
                        "filename": os.path.join(self.path_raw, chapters_file),
                        "process_line": "_process_line_chapters_books_isbns",
                        "persistent_file": os.path.join(
                                self.path_persistent,
                                "chapters_books_isbns.pkl"),
                        "persistent_variable": {},
                        "dataset_format": "json"
                        },
                }
Example #11
    def train(self):
        # Make the datasets iterable
        batch_size = 10000

        train_data_loader = torch.utils.data.DataLoader(
            dataset=self.training_data, batch_size=batch_size)
        validation_data_loader = torch.utils.data.DataLoader(
            dataset=self.validation_data, batch_size=batch_size)
        train_labels_loader = torch.utils.data.DataLoader(
            dataset=self.training_labels, batch_size=batch_size)
        validation_labels_loader = torch.utils.data.DataLoader(
            dataset=self.validation_labels, batch_size=batch_size)

        # Train the model
        timer = Timer()
        timer.tic()

        mean_train_losses = []
        mean_validation_losses = []

        for epoch in range(self.epochs):
            print("Epoch: {}".format(epoch + 1))
            train_losses = []
            validation_losses = []
            self.model.train()

            for i, (train_data, train_labels) in enumerate(
                    zip(train_data_loader, train_labels_loader)):
                self.model.train()
                self.optimizer.zero_grad()
                outputs = self.model(train_data)
                loss = self.cross_entropy_loss(outputs.squeeze(), train_labels)
                loss.backward()
                self.optimizer.step()
                train_losses.append(loss.item())

                # Compute validation loss
                self.model.eval()
                with torch.no_grad():
                    for _, (val_data, val_labels) in enumerate(
                            zip(validation_data_loader,
                                validation_labels_loader)):
                        val_pred = self.model(val_data)
                        val_loss = self.cross_entropy_loss(
                            val_pred.squeeze(), val_labels)
                        validation_losses.append(val_loss.item())

            print("\tTrain loss: {}, validation loss: {}".format(
                np.mean(train_losses), np.mean(validation_losses)))
            mean_train_losses.append(np.mean(train_losses))
            mean_validation_losses.append(np.mean(validation_losses))
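            # Checkpoint the model whenever the mean validation loss reaches a new minimum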
            if mean_validation_losses[-1] == min(mean_validation_losses):
                print("\tSaving model...")
                torch.save(self.model.state_dict(), self.model_path)
                print("\tSaved.")

        print("Finished training.")
        training_time = timer.toc()
        self._plot_losses(mean_train_losses, mean_validation_losses)
        self._print_stats(mean_train_losses, mean_validation_losses,
                          training_time)
Example #12
    def predict(self, test_data, model_checkpoint, gpu_mem_fraction=None):
        timer = Timer()
        timer.tic()

        G = test_data[0]
        features = test_data[1]
        id_map = test_data[2]

        if features is not None:
            # pad with dummy zero vector
            features = np.vstack([features, np.zeros((features.shape[1], ))])

        context_pairs = test_data[3] if self.random_context else None
        placeholders = self._construct_placeholders()
        minibatch = EdgeMinibatchIterator(G,
                                          id_map,
                                          placeholders,
                                          batch_size=self.batch_size,
                                          max_degree=self.max_degree,
                                          num_neg_samples=self.neg_sample_size,
                                          context_pairs=context_pairs)

        adj_info_ph = tf.compat.v1.placeholder(tf.int32,
                                               shape=minibatch.adj.shape)
        adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

        model = self._create_model(placeholders, features, adj_info, minibatch)

        config = tf.compat.v1.ConfigProto(
            log_device_placement=self.log_device_placement)
        if gpu_mem_fraction is not None:
            config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_fraction
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True

        # Initialize session
        sess = tf.compat.v1.Session(config=config)
        merged = tf.compat.v1.summary.merge_all()
        #        summary_writer = tf.compat.v1.summary.FileWriter(self._log_dir(),
        #                                                         sess.graph)

        # Initialize model saver
        saver = tf.compat.v1.train.Saver()

        # Init variables
        sess.run(tf.compat.v1.global_variables_initializer(),
                 feed_dict={adj_info_ph: minibatch.adj})

        val_adj_info = tf.compat.v1.assign(adj_info, minibatch.test_adj)

        # Restore model
        print("Restoring trained model.")
        checkpoint_file = os.path.join(self._log_dir(), model_checkpoint)
        ckpt = tf.compat.v1.train.get_checkpoint_state(self._log_dir())
        if ckpt is not None:
            saver.restore(sess, checkpoint_file)
            print("Model restored.")
        else:
            print("This model checkpoint does not exist. The model might " +
                  "not be trained yet or the checkpoint is invalid.")

        # Infer embeddings
        sess.run(val_adj_info.op)
        print("Computing embeddings...")
        val_embeddings = []
        finished = False
        seen = set([])
        nodes = []
        iter_num = 0
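        # Embed the graph in incremental minibatches, keeping the first embedding seen for each node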
        while not finished:
            feed_dict_val, finished, edges = minibatch.incremental_embed_feed_dict(
                self.validate_batch_size, iter_num)
            iter_num += 1
            outs_val = sess.run([model.loss, model.mrr, model.outputs1],
                                feed_dict=feed_dict_val)
            for i, edge in enumerate(edges):
                if edge[0] not in seen:
                    val_embeddings.append(outs_val[-1][i, :])
                    nodes.append(edge[0])
                    seen.add(edge[0])

        val_embeddings = np.vstack(val_embeddings)
        if self.save_embeddings:
            print("Saving embeddings...")
            if not os.path.exists(self._log_dir()):
                os.makedirs(self._log_dir())
            np.save(self._log_dir() + "inferred_embeddings.npy",
                    val_embeddings)
            with open(self._log_dir() + "inferred_embeddings_ids.txt",
                      "w") as fp:
                fp.write("\n".join(map(str, nodes)))
            print("Embeddings saved.\n")

        # Return only the embeddings of the test nodes
        test_embeddings_ids = {}
        for i, node in enumerate(nodes):
            test_embeddings_ids[node] = i
        test_nodes = [n for n in G.nodes() if G.node[n]['test']]
        test_embeddings = val_embeddings[[
            test_embeddings_ids[id] for id in test_nodes
        ]]

        sess.close()
        tf.compat.v1.reset_default_graph()
        timer.toc()
        return test_nodes, test_embeddings
Example #13
    def train(self, train_data):
        print("Training model...")
        timer = Timer()
        timer.tic()

        G = train_data[0]
        features = train_data[1]
        id_map = train_data[2]

        if features is not None:
            # pad with dummy zero vector
            features = np.vstack([features, np.zeros((features.shape[1], ))])

        context_pairs = train_data[3] if self.random_context else None
        placeholders = self._construct_placeholders()
        minibatch = EdgeMinibatchIterator(G,
                                          id_map,
                                          placeholders,
                                          batch_size=self.batch_size,
                                          max_degree=self.max_degree,
                                          num_neg_samples=self.neg_sample_size,
                                          context_pairs=context_pairs)

        adj_info_ph = tf.compat.v1.placeholder(tf.int32,
                                               shape=minibatch.adj.shape)
        adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

        model = self._create_model(placeholders, features, adj_info, minibatch)

        config = tf.compat.v1.ConfigProto(
            log_device_placement=self.log_device_placement)
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True

        # Initialize session
        sess = tf.compat.v1.Session(config=config)
        merged = tf.compat.v1.summary.merge_all()
        #        summary_writer = tf.compat.v1.summary.FileWriter(self._log_dir(),
        #                                                         sess.graph)

        # Initialize model saver
        saver = tf.compat.v1.train.Saver(max_to_keep=self.epochs)

        # Init variables
        sess.run(tf.compat.v1.global_variables_initializer(),
                 feed_dict={adj_info_ph: minibatch.adj})

        # Train model
        train_shadow_mrr = None
        shadow_mrr = None

        total_steps = 0
        avg_time = 0.0
        epoch_val_costs = []

        train_losses = []
        validation_losses = []

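        # Ops that switch adj_info between the training and validation adjacency matrices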
        train_adj_info = tf.compat.v1.assign(adj_info, minibatch.adj)
        val_adj_info = tf.compat.v1.assign(adj_info, minibatch.test_adj)

        for epoch in range(self.epochs):
            minibatch.shuffle()

            iter = 0
            print('Epoch: %04d' % (epoch))
            epoch_val_costs.append(0)
            train_loss_epoch = []
            validation_loss_epoch = []
            while not minibatch.end():
                # Construct feed dictionary
                feed_dict = minibatch.next_minibatch_feed_dict()
                feed_dict.update({placeholders['dropout']: self.dropout})

                t = time.time()
                # Training step
                outs = sess.run([
                    merged, model.opt_op, model.loss, model.ranks,
                    model.aff_all, model.mrr, model.outputs1
                ],
                                feed_dict=feed_dict)

                train_cost = outs[2]
                train_mrr = outs[5]
                train_loss_epoch.append(train_cost)
                if train_shadow_mrr is None:
                    train_shadow_mrr = train_mrr
                else:
                    train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr -
                                                      train_mrr)

                if iter % self.validate_iter == 0:
                    # Validation
                    sess.run(val_adj_info.op)
                    val_cost, ranks, val_mrr, duration = self._evaluate(
                        sess, model, minibatch, size=self.validate_batch_size)
                    sess.run(train_adj_info.op)
                    epoch_val_costs[-1] += val_cost
                    validation_loss_epoch.append(val_cost)
                if shadow_mrr is None:
                    shadow_mrr = val_mrr
                else:
                    shadow_mrr -= (1 - 0.99) * (shadow_mrr - val_mrr)

                # if total_steps % self.print_every == 0:
                #     summary_writer.add_summary(outs[0], total_steps)

                # Print results
                avg_time = (avg_time * total_steps + time.time() -
                            t) / (total_steps + 1)

                if total_steps % self.print_every == 0:
                    print(
                        "Iter: %04d" % iter,
                        "train_loss={:.5f}".format(train_cost),
                        "train_mrr={:.5f}".format(train_mrr),
                        # exponential moving average
                        "train_mrr_ema={:.5f}".format(train_shadow_mrr),
                        "val_loss={:.5f}".format(val_cost),
                        "val_mrr={:.5f}".format(val_mrr),
                        # exponential moving average
                        "val_mrr_ema={:.5f}".format(shadow_mrr),
                        "time={:.5f}".format(avg_time))

                iter += 1
                total_steps += 1

                if total_steps > self.max_total_steps:
                    break

            # Keep track of train and validation losses per epoch
            train_losses.append(sum(train_loss_epoch) / len(train_loss_epoch))
            validation_losses.append(
                sum(validation_loss_epoch) / len(validation_loss_epoch))

            # Save embeddings if the epoch has the lowest validation loss
            # so far
            if self.save_embeddings and validation_losses[-1] == min(
                    validation_losses):
                print(
                    "Minimum validation loss so far ({}) at epoch {}.".format(
                        validation_losses[-1], epoch))
                sess.run(val_adj_info.op)
                self._save_embeddings(sess, model,
                                      minibatch, self.validate_batch_size,
                                      self._log_dir())

            # Save model at each epoch
            print("Saving model at epoch {}.".format(epoch))
            saver.save(sess,
                       os.path.join(self._log_dir(),
                                    "model_epoch_" + str(epoch) + ".ckpt"),
                       global_step=total_steps)

            if total_steps > self.max_total_steps:
                break

        print("Optimization finished!\n")

        training_time = timer.toc()
        self._plot_losses(train_losses, validation_losses)
        self._print_stats(train_losses, validation_losses, training_time)
    def main():
        parser = argparse.ArgumentParser(
            description='Arguments for GraphSAGE concatenated ' +
            'classifier model evaluation.')
        parser.add_argument(
            "classifier_name",
            choices=["KNN", "MLP", "MultinomialLogisticRegression"],
            help="The name of the classifier.")
        parser.add_argument('embedding_type',
                            choices=[
                                "AVG_L", "AVG_2L", "AVG_SUM_L4", "AVG_SUM_ALL",
                                "MAX_2L", "CONC_AVG_MAX_2L",
                                "CONC_AVG_MAX_SUM_L4", "SUM_L", "SUM_2L"
                            ],
                            help="Type of embedding.")
        parser.add_argument('model_checkpoint_citations',
                            help='Name of the GraphSAGE model checkpoint ' +
                            'for the citations graph.')
        parser.add_argument('model_checkpoint_authors',
                            help='Name of the GraphSAGE model checkpoint ' +
                            'for the authors graph.')
        parser.add_argument('train_prefix_citations',
                            help='Name of the object file that stores the ' +
                            'citations training data.')
        parser.add_argument('train_prefix_authors',
                            help='Name of the object file that stores the ' +
                            'authors training data.')
        parser.add_argument('model_name',
                            choices=[
                                "graphsage_mean", "gcn", "graphsage_seq",
                                "graphsage_maxpool", "graphsage_meanpool"
                            ],
                            help="Model names.")
        parser.add_argument('--model_size',
                            choices=["small", "big"],
                            default="small",
                            help="Can be big or small; model specific def'ns")
        parser.add_argument('--learning_rate',
                            type=float,
                            default=0.00001,
                            help='Initial learning rate.')
        parser.add_argument('--epochs',
                            type=int,
                            default=10,
                            help='Number of epochs to train.')
        parser.add_argument('--dropout',
                            type=float,
                            default=0.0,
                            help='Dropout rate (1 - keep probability).')
        parser.add_argument('--weight_decay',
                            type=float,
                            default=0.0,
                            help='Weight for l2 loss on embedding matrix.')
        parser.add_argument('--max_degree',
                            type=int,
                            default=100,
                            help='Maximum node degree.')
        parser.add_argument('--samples_1',
                            type=int,
                            default=25,
                            help='Number of samples in layer 1.')
        parser.add_argument('--samples_2',
                            type=int,
                            default=10,
                            help='Number of samples in layer 2.')
        parser.add_argument('--dim_1',
                            type=int,
                            default=128,
                            help='Size of output dim ' +
                            '(final is 2x this, if using concat)')
        parser.add_argument('--dim_2',
                            type=int,
                            default=128,
                            help='Size of output dim ' +
                            '(final is 2x this, if using concat)')
        parser.add_argument('--random_context',
                            action="store_false",
                            default=True,
                            help='Whether to use random context or direct ' +
                            'edges.')
        parser.add_argument('--neg_sample_size',
                            type=int,
                            default=20,
                            help='Number of negative samples.')
        parser.add_argument('--batch_size',
                            type=int,
                            default=512,
                            help='Minibatch size.')
        parser.add_argument('--identity_dim',
                            type=int,
                            default=0,
                            help='Set to positive value to use identity ' +
                            'embedding features of that dimension.')
        parser.add_argument('--save_embeddings',
                            action="store_true",
                            default=False,
                            help='Whether to save embeddings for all nodes ' +
                            'after training')
        parser.add_argument('--base_log_dir',
                            default='../../../data/processed/graphsage/',
                            help='Base directory for logging and saving ' +
                            'embeddings')
        parser.add_argument('--validate_iter',
                            type=int,
                            default=5000,
                            help='How often to run a validation minibatch.')
        parser.add_argument('--validate_batch_size',
                            type=int,
                            default=256,
                            help='How many nodes per validation sample.')
        parser.add_argument('--gpu',
                            type=int,
                            default=0,
                            help='Which gpu to use.')
        parser.add_argument('--print_every',
                            type=int,
                            default=50,
                            help='How often to print training info.')
        parser.add_argument('--max_total_steps',
                            type=int,
                            default=10**10,
                            help='Maximum total number of iterations.')
        parser.add_argument('--log_device_placement',
                            action="store_true",
                            default=False,
                            help='Whether to log device placement.')
        parser.add_argument('--recs',
                            type=int,
                            default=10,
                            help='Number of recommendations.')
        args = parser.parse_args()
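        # Example invocation (the script name and values are placeholders;
        # flags such as --model_name and the train prefixes are defined
        # earlier in this file):
        #   python evaluate_graphsage_concat.py --epochs 10 --gpu 0 --recs 10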

        print("Starting evaluation...")
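        # Make only the selected GPU visible to CUDA/TensorFlow; this has to
        # happen before TensorFlow initializes the device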
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
        print("Using GPU {}.".format(str(args.gpu)))

        from GraphSAGEClassifierConcatEvaluation import GraphSAGEClassifierConcatEvaluation
        evaluation_model = GraphSAGEClassifierConcatEvaluation(
            args.classifier_name, args.embedding_type, args.model_name,
            args.model_size, args.learning_rate, args.gpu, args.recs)

        # Initialize GraphSAGE models
        graphsage_model_citations = UnsupervisedModel(
            args.train_prefix_citations, args.model_name, args.model_size,
            args.learning_rate, args.epochs, args.dropout, args.weight_decay,
            args.max_degree, args.samples_1, args.samples_2, args.dim_1,
            args.dim_2, args.random_context, args.neg_sample_size,
            args.batch_size, args.identity_dim, args.save_embeddings,
            args.base_log_dir, args.validate_iter, args.validate_batch_size,
            args.gpu, args.print_every, args.max_total_steps,
            args.log_device_placement)
        graphsage_model_authors = UnsupervisedModel(
            args.train_prefix_authors, args.model_name, args.model_size,
            args.learning_rate, args.epochs, args.dropout, args.weight_decay,
            args.max_degree, args.samples_1, args.samples_2, args.dim_1,
            args.dim_2, args.random_context, args.neg_sample_size,
            args.batch_size, args.identity_dim, args.save_embeddings,
            args.base_log_dir, args.validate_iter, args.validate_batch_size,
            args.gpu, args.print_every, args.max_total_steps,
            args.log_device_placement)

        # Train model if needed:
        if not evaluation_model._has_persistent_model():
            print("Classifier not trained yet. Training now...")
            timer = Timer()
            timer.tic()
            evaluation_model.train(graphsage_model_citations,
                                   graphsage_model_authors)
            print("Training finished.")
            timer.toc()
        else:
            evaluation_model._load_model_classifier()

        # Load test data
        print("Loading test data...")
        query_test, query_test_authors, truth = evaluation_model.load_data()
        print("Loaded.")

        # Infer embeddings
        print("Inferring embeddings for citations graph.")
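        # Each inference run happens in its own process and returns the
        # embeddings through a multiprocessing queue; exiting the process
        # also releases TensorFlow's GPU state between the two models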
        queue_citations = mp.Queue()
        process_citations = mp.Process(
            target=evaluation_model.infer_embeddings,
            args=(query_test, None, "citations", graphsage_model_citations,
                  args.model_checkpoint_citations, queue_citations))
        process_citations.start()
        embeddings_citations = queue_citations.get()
        process_citations.join()
        process_citations.terminate()

        print("Inferring embeddings for authors graph.")
        queue_authors = mp.Queue()
        process_authors = mp.Process(target=evaluation_model.infer_embeddings,
                                     args=(query_test, query_test_authors,
                                           "authors", graphsage_model_authors,
                                           args.model_checkpoint_authors,
                                           queue_authors))
        process_authors.start()
        embeddings_authors = queue_authors.get()
        process_authors.join()
        process_authors.terminate()

        # Concatenate embeddings
        test_embeddings = np.concatenate(
            (embeddings_citations, embeddings_authors), axis=1)
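        # test_embeddings now has shape (num_queries, dim_citations + dim_authors):
        # each query is represented by its citation-graph and author-graph
        # embeddings side by side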

        print("Computing predictions...")
        recommendation = evaluation_model.compute_predictions(test_embeddings)
        print("Predictions computed.")

        # Evaluate
        print("Evaluating...")
        evaluation = EvaluationContainer()
        evaluation.evaluate(recommendation, truth)
        print("Finished.")
    def train(self, train_data, sampler_name='Uniform'):
        print("Training model...")
        timer = Timer()
        timer.tic()

        G = train_data[0]
        features = train_data[1]
        id_map = train_data[2]

        if features is not None:
            # pad with dummy zero vector
            features = np.vstack([features, np.zeros((features.shape[1], ))])

        context_pairs = train_data[3] if self.random_context else None
        placeholders = self._construct_placeholders()
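        # The edge minibatch iterator yields batches of co-occurring node
        # pairs (random-walk context pairs if provided, otherwise graph
        # edges) for the unsupervised link-prediction loss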
        minibatch = EdgeMinibatchIterator(G,
                                          id_map,
                                          placeholders,
                                          batch_size=self.batch_size,
                                          max_degree=self.max_degree,
                                          num_neg_samples=self.neg_sample_size,
                                          context_pairs=context_pairs)

        # The adjacency lists live in a non-trainable variable so that the
        # training and validation adjacency can be swapped in via assign ops
        adj_info_ph = tf.compat.v1.placeholder(tf.int32,
                                               shape=minibatch.adj.shape)
        adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
        adj_shape = adj_info.get_shape().as_list()

        model = self._create_model(sampler_name, placeholders, features,
                                   adj_info, minibatch)

        config = tf.compat.v1.ConfigProto(
            log_device_placement=self.log_device_placement)
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True

        # Initialize session
        sess = tf.compat.v1.Session(config=config)
        merged = tf.compat.v1.summary.merge_all()
        #        summary_writer = tf.compat.v1.summary.FileWriter(
        #                self._log_dir(sampler_name), sess.graph)

        # Initialize model saver
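        # (max_to_keep=self.epochs retains one checkpoint per epoch, since
        # the model is saved at the end of every epoch below)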
        saver = tf.compat.v1.train.Saver(max_to_keep=self.epochs)

        # Init variables
        sess.run(tf.compat.v1.global_variables_initializer(),
                 feed_dict={adj_info_ph: minibatch.adj})

        # Restore params of ML sampler model
        if sampler_name == 'ML' or sampler_name == 'FastML':
            sampler_vars = tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope="MLsampler")
            saver_sampler = tf.compat.v1.train.Saver(var_list=sampler_vars)
            sampler_model_path = self._sampler_model_path()
            saver_sampler.restore(sess, sampler_model_path + 'model.ckpt')

        # Loss node path
        loss_node_path = self._loss_node_path(sampler_name)
        if not os.path.exists(loss_node_path):
            os.makedirs(loss_node_path)

        # Train model
        train_shadow_mrr = None
        shadow_mrr = None

        total_steps = 0
        avg_time = 0.0
        epoch_val_costs = []

        # Assign ops that switch adj_info between the training and
        # validation adjacency lists
        train_adj_info = tf.compat.v1.assign(adj_info, minibatch.adj)
        val_adj_info = tf.compat.v1.assign(adj_info, minibatch.test_adj)

        train_losses = []
        validation_losses = []

        val_cost_ = []
        val_mrr_ = []
        shadow_mrr_ = []
        duration_ = []

        ln_acc = sparse.csr_matrix((adj_shape[0], adj_shape[0]),
                                   dtype=np.float32)
        lnc_acc = sparse.csr_matrix((adj_shape[0], adj_shape[0]),
                                    dtype=np.int32)
        ln_acc = ln_acc.tolil()
        lnc_acc = lnc_acc.tolil()

        for epoch in range(self.epochs):
            minibatch.shuffle()

            iter = 0
            print('Epoch: %04d' % (epoch))
            epoch_val_costs.append(0)
            train_loss_epoch = []
            validation_loss_epoch = []

            while not minibatch.end():
                # Construct feed dictionary
                feed_dict = minibatch.next_minibatch_feed_dict()
                feed_dict.update({placeholders['dropout']: self.dropout})
                t = time.time()

                # Training step
                outs = sess.run([
                    merged, model.opt_op, model.loss, model.ranks,
                    model.aff_all, model.mrr, model.outputs1, model.loss_node,
                    model.loss_node_count
                ],
                                feed_dict=feed_dict)
                train_cost = outs[2]
                train_mrr = outs[5]
                train_loss_epoch.append(train_cost)

                # Exponential moving average of the training MRR (decay 0.99):
                # ema <- 0.99 * ema + 0.01 * train_mrr
                if train_shadow_mrr is None:
                    train_shadow_mrr = train_mrr
                else:
                    train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr -
                                                      train_mrr)

                if iter % self.validate_iter == 0:
                    # Validation
                    sess.run(val_adj_info.op)
                    val_cost, ranks, val_mrr, duration = self._evaluate(
                        sess, model, minibatch, size=self.validate_batch_size)
                    sess.run(train_adj_info.op)
                    epoch_val_costs[-1] += val_cost
                    validation_loss_epoch.append(val_cost)

                # val_cost, val_mrr and duration hold the most recent
                # validation results; shadow_mrr is an exponential moving
                # average of val_mrr (decay 0.99)
                if shadow_mrr is None:
                    shadow_mrr = val_mrr
                else:
                    shadow_mrr -= (1 - 0.99) * (shadow_mrr - val_mrr)

                val_cost_.append(val_cost)
                val_mrr_.append(val_mrr)
                shadow_mrr_.append(shadow_mrr)
                duration_.append(duration)

                #                if total_steps % self.print_every == 0:
                #                    summary_writer.add_summary(outs[0], total_steps)

                # Print results
                avg_time = (avg_time * total_steps + time.time() -
                            t) / (total_steps + 1)

                if total_steps % self.print_every == 0:
                    print(
                        "Iter: %04d" % iter,
                        "train_loss={:.5f}".format(train_cost),
                        "train_mrr={:.5f}".format(train_mrr),
                        # exponential moving average
                        "train_mrr_ema={:.5f}".format(train_shadow_mrr),
                        "val_loss={:.5f}".format(val_cost),
                        "val_mrr={:.5f}".format(val_mrr),
                        # exponential moving average
                        "val_mrr_ema={:.5f}".format(shadow_mrr),
                        "time={:.5f}".format(avg_time))

                # Accumulate the sparse per-node loss values (model.loss_node)
                # and counts (model.loss_node_count) returned by the run
                ln = outs[7].values
                ln_idx = outs[7].indices
                ln_acc[ln_idx[:, 0], ln_idx[:, 1]] += ln

                lnc = outs[8].values
                lnc_idx = outs[8].indices
                lnc_acc[lnc_idx[:, 0], lnc_idx[:, 1]] += lnc

                iter += 1
                total_steps += 1

                if total_steps > self.max_total_steps:
                    break

            # Keep track of train and validation losses per epoch
            train_losses.append(sum(train_loss_epoch) / len(train_loss_epoch))
            validation_losses.append(
                sum(validation_loss_epoch) / len(validation_loss_epoch))

            # If the epoch has the lowest validation loss so far
            if validation_losses[-1] == min(validation_losses):
                print(
                    "Minimum validation loss so far ({}) at epoch {}.".format(
                        validation_losses[-1], epoch))
                # Save loss node and count
                sparse.save_npz(loss_node_path + 'loss_node.npz',
                                sparse.csr_matrix(ln_acc))
                sparse.save_npz(loss_node_path + 'loss_node_count.npz',
                                sparse.csr_matrix(lnc_acc))
                # Save embeddings
                if self.save_embeddings and sampler_name != "Uniform":
                    sess.run(val_adj_info.op)
                    self._save_embeddings(sess, model, minibatch,
                                          self.validate_batch_size,
                                          self._log_dir(sampler_name))

            # Save model at each epoch
            print("Saving model at epoch {}.".format(epoch))
            saver.save(
                sess,
                os.path.join(self._log_dir(sampler_name),
                             "model_epoch_" + str(epoch) + ".ckpt"))

            if total_steps > self.max_total_steps:
                break

        print("Optimization Finished!")

        training_time = timer.toc()
        self._plot_losses(train_losses, validation_losses, sampler_name)
        self._print_stats(train_losses, validation_losses, training_time,
                          sampler_name)