Example #1
def train(dataset, vectors_path, lr_file, ckpt_dir, checkpoint, idx2vocab,
          vocab_unigrams, embedding_size, neg_sampled, distortion_power,
          batch_size, initial_learning_rate, decay_epochs, decay_rate,
          iter_epochs, allow_soft_placement, log_device_placement,
          gpu_memory_fraction, using_gpu, allow_growth, loss_interval,
          summary_steps, ckpt_interval, ckpt_epochs, summary_interval,
          decay_interval, train_workers):

    num_steps_per_epoch = int(dataset.num_examples / batch_size)
    iter_steps = iter_epochs * num_steps_per_epoch
    decay_steps = int(decay_epochs * num_steps_per_epoch)
    ckpt_steps = int(ckpt_epochs * num_steps_per_epoch)

    LR = utils.LearningRateGenerator(
        initial_learning_rate=initial_learning_rate,
        initial_steps=0,
        decay_rate=decay_rate,
        decay_steps=decay_steps)

    with tf.Graph().as_default(), tf.device(
            '/gpu:0' if using_gpu else '/cpu:0'):

        global_step = tf.Variable(0, trainable=False, name="global_step")

        inputs = tf.placeholder(tf.int32, shape=[batch_size], name='inputs')
        labels = tf.placeholder(tf.int32, shape=[batch_size], name='labels')
        learning_rate = tf.placeholder(tf.float32, name='learning_rate')

        model = Word2Vec(vocab_size=len(idx2vocab),
                         embedding_size=embedding_size,
                         vocab_unigrams=vocab_unigrams,
                         neg_sampled=neg_sampled,
                         distortion_power=distortion_power,
                         batch_size=batch_size)

        train_op, loss = model.train(inputs, labels, global_step,
                                     learning_rate)

        # Create a saver.
        saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=5)

        summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init_op = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU implementations.
        config = tf.ConfigProto(allow_soft_placement=allow_soft_placement,
                                log_device_placement=log_device_placement)
        config.gpu_options.per_process_gpu_memory_fraction = gpu_memory_fraction
        config.gpu_options.allow_growth = allow_growth
        # config.gpu_options.visible_device_list = visible_device_list

        with tf.Session(config=config) as sess:
            # first_step = 0
            if checkpoint == '0':  # new train
                sess.run(init_op)
            elif checkpoint == '-1':  # choose the latest one
                ckpt = tf.train.get_checkpoint_state(ckpt_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    # new_saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
                    # Restores from checkpoint
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # global_step_for_restore = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    # first_step = int(global_step_for_restore) + 1
                else:
                    logger.warning('No checkpoint file found')
                    return
            else:
                if os.path.exists(
                        os.path.join(ckpt_dir,
                                     'model.ckpt-' + checkpoint + '.index')):
                    # new_saver = tf.train.import_meta_graph(
                    #     os.path.join(ckpt_dir, 'model.ckpt-' + checkpoint + '.meta'))
                    saver.restore(
                        sess, os.path.join(ckpt_dir,
                                           'model.ckpt-' + checkpoint))
                    # first_step = int(checkpoint) + 1
                else:
                    logger.warning(
                        'checkpoint {} not found'.format(checkpoint))
                    return

            summary_writer = tf.summary.FileWriter(ckpt_dir, sess.graph)

            ## train
            executor_workers = train_workers - 1
            if executor_workers > 0:
                executor = ThreadPoolExecutor(max_workers=executor_workers)
                for _ in range(executor_workers):
                    executor.submit(_train_thread_body, dataset, batch_size,
                                    inputs, labels, sess, train_op, iter_steps,
                                    global_step, learning_rate, LR)

            last_loss_time = time.time() - loss_interval
            last_summary_time = time.time() - summary_interval
            last_decay_time = last_checkpoint_time = time.time()
            last_decay_step = last_summary_step = last_checkpoint_step = 0
            while True:
                start_time = time.time()
                batch_data, batch_labels = dataset.next_batch(
                    batch_size, keep_strict_batching=True)
                feed_dict = {
                    inputs: batch_data,
                    labels: batch_labels,
                    learning_rate: LR.learning_rate
                }
                _, loss_value, cur_step = sess.run(
                    [train_op, loss, global_step], feed_dict=feed_dict)
                now = time.time()

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                epoch, epoch_step = divmod(cur_step, num_steps_per_epoch)

                if now - last_loss_time >= loss_interval:
                    format_str = '%s: step=%d(%d/%d), lr=%.6f, loss=%.6f, duration/step=%.4fs'
                    logger.info(format_str %
                                (time.strftime('%Y-%m-%d %H:%M:%S',
                                               time.localtime(time.time())),
                                 cur_step, epoch_step, epoch, LR.learning_rate,
                                 loss_value, now - start_time))
                    last_loss_time = time.time()
                if now - last_summary_time >= summary_interval or cur_step - last_summary_step >= summary_steps or cur_step >= iter_steps:
                    summary_str = sess.run(summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, cur_step)
                    last_summary_time = time.time()
                    last_summary_step = cur_step
                ckpted = False
                # Save the model checkpoint periodically. (named 'model.ckpt-global_step.meta')
                if now - last_checkpoint_time >= ckpt_interval or cur_step - last_checkpoint_step >= ckpt_steps or cur_step >= iter_steps:
                    checkpoint_path = os.path.join(ckpt_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=cur_step)
                    # embedding_vectors = sess.run(model.vectors, feed_dict=feed_dict)
                    vecs, weights, biases = sess.run(
                        [model.vectors, model.context_weights, model.context_biases],
                        feed_dict=feed_dict)
                    save_word2vec_format(vectors_path, vecs, idx2vocab)
                    np.savetxt(vectors_path + ".contexts", weights)
                    np.savetxt(vectors_path + ".context_biases", biases)
                    last_checkpoint_time = time.time()
                    last_checkpoint_step = cur_step
                    ckpted = True
                # update learning rate
                if ckpted or now - last_decay_time >= decay_interval or cur_step - last_decay_step >= decay_steps:
                    lr_info = np.loadtxt(lr_file, dtype=float)
                    if np.abs(lr_info[1] - decay_epochs) >= 1e-7:
                        decay_epochs = lr_info[1]
                        decay_steps = int(decay_epochs * num_steps_per_epoch)
                    if np.abs(lr_info[2] - decay_rate) >= 1e-7:
                        decay_rate = lr_info[2]
                    if np.abs(lr_info[0] - initial_learning_rate) < 1e-7:
                        LR.exponential_decay(cur_step,
                                             decay_rate=decay_rate,
                                             decay_steps=decay_steps)
                    else:
                        initial_learning_rate = lr_info[0]
                        LR.reset(initial_learning_rate=initial_learning_rate,
                                 initial_steps=cur_step,
                                 decay_rate=decay_rate,
                                 decay_steps=decay_steps)
                    last_decay_time = time.time()
                    last_decay_step = cur_step

                if cur_step >= iter_steps:
                    break
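
Example #1 hands _train_thread_body to a ThreadPoolExecutor, but that helper is not shown here. Below is a minimal sketch of what such a worker loop might look like, assuming the same dataset.next_batch and LR.learning_rate interfaces used above; the name and argument order are taken from the call site, and the body itself is an assumption.

def _train_thread_body(dataset, batch_size, inputs, labels, sess, train_op,
                       iter_steps, global_step, learning_rate, LR):
    # Assumed worker loop: each background thread pulls batches and runs
    # training steps until the shared global_step reaches iter_steps.
    while True:
        batch_data, batch_labels = dataset.next_batch(batch_size,
                                                      keep_strict_batching=True)
        feed_dict = {
            inputs: batch_data,
            labels: batch_labels,
            learning_rate: LR.learning_rate
        }
        _, cur_step = sess.run([train_op, global_step], feed_dict=feed_dict)
        if cur_step >= iter_steps:
            break
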
Example #2
def train(net, vectors_path, lr_file, ckpt_dir, checkpoint, embedding_size,
          neg_sampled, order, distortion_power, iter_epochs, batch_size,
          initial_learning_rate, decay_epochs, decay_interval, decay_rate,
          allow_soft_placement, log_device_placement, gpu_memory_fraction,
          using_gpu, allow_growth, loss_interval, summary_steps,
          summary_interval, ckpt_epochs, ckpt_interval, train_workers):
    edge_sampler = Edge_sampler(net, batch_size)
    edges_size = edge_sampler.edges_size
    nodes_size = net.get_nodes_size()
    num_steps_per_epoch = int(edges_size / batch_size)
    iter_steps = round(
        iter_epochs *
        num_steps_per_epoch)  # iter_epochs should be big enough to converge.
    decay_steps = round(decay_epochs * num_steps_per_epoch)
    ckpt_steps = round(ckpt_epochs * num_steps_per_epoch)

    nodes_degrees = [net.get_degrees(v) for v in range(nodes_size)]

    LR = utils.LearningRateGenerator(
        initial_learning_rate=initial_learning_rate,
        initial_steps=0,
        decay_rate=decay_rate,
        decay_steps=decay_steps,
        iter_steps=iter_steps)

    with tf.Graph().as_default(), tf.device(
            '/gpu:0' if using_gpu else '/cpu:0'):

        inputs = tf.placeholder(tf.int32, shape=[batch_size], name='inputs')
        labels = tf.placeholder(tf.int32, shape=[batch_size], name='labels')
        learning_rate = tf.placeholder(tf.float32, name='learning_rate')

        model_list = []
        trains_list = []
        if order == "1":
            with tf.name_scope("1st_order"):
                model = SGNS(vocab_size=nodes_size,
                             embedding_size=embedding_size,
                             vocab_unigrams=nodes_degrees,
                             distortion_power=distortion_power,
                             neg_sampled=neg_sampled,
                             batch_size=batch_size,
                             order=1)
                global_step = tf.Variable(0,
                                          trainable=False,
                                          name="global_step")
            train_op, loss = model.train(inputs, labels, global_step,
                                         learning_rate)
            model_list.append(model)
            trains_list.append((train_op, loss, global_step))
        elif order == "2":
            with tf.name_scope("2st_order"):
                model = SGNS(vocab_size=nodes_size,
                             embedding_size=embedding_size,
                             vocab_unigrams=nodes_degrees,
                             distortion_power=distortion_power,
                             neg_sampled=neg_sampled,
                             batch_size=batch_size,
                             order=2)
                global_step = tf.Variable(0,
                                          trainable=False,
                                          name="global_step")
            train_op, loss = model.train(inputs, labels, global_step,
                                         learning_rate)
            model_list.append(model)
            trains_list.append((train_op, loss, global_step))
        elif order == "3":
            with tf.name_scope("1st_order"):
                model = SGNS(vocab_size=nodes_size,
                             embedding_size=embedding_size // 2,
                             vocab_unigrams=nodes_degrees,
                             distortion_power=distortion_power,
                             neg_sampled=neg_sampled,
                             batch_size=batch_size,
                             order=1)
                global_step = tf.Variable(0,
                                          trainable=False,
                                          name="global_step")
            train_op, loss = model.train(inputs, labels, global_step,
                                         learning_rate)
            model_list.append(model)
            trains_list.append((train_op, loss, global_step))
            with tf.name_scope("2st_order"):
                model = SGNS(vocab_size=nodes_size,
                             embedding_size=embedding_size // 2,
                             vocab_unigrams=nodes_degrees,
                             distortion_power=distortion_power,
                             neg_sampled=neg_sampled,
                             batch_size=batch_size,
                             order=2)
                global_step = tf.Variable(0,
                                          trainable=False,
                                          name="global_step")
            train_op, loss = model.train(inputs, labels, global_step,
                                         learning_rate)
            model_list.append(model)
            trains_list.append((train_op, loss, global_step))
        else:
            logger.error("unvalid order in LINE: '%s'. " % order)
            sys.exit()

        # Create a saver.
        saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=5)

        summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init_op = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU implementations.
        config = tf.ConfigProto(allow_soft_placement=allow_soft_placement,
                                log_device_placement=log_device_placement)
        config.gpu_options.per_process_gpu_memory_fraction = gpu_memory_fraction
        config.gpu_options.allow_growth = allow_growth
        # config.gpu_options.visible_device_list = visible_device_list

        with tf.Session(config=config) as sess:
            # first_step = 0
            if checkpoint == '0':  # new train
                sess.run(init_op)
            elif checkpoint == '-1':  # choose the latest one
                ckpt = tf.train.get_checkpoint_state(ckpt_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    # new_saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
                    # Restores from checkpoint
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # global_step_for_restore = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    # first_step = int(global_step_for_restore) + 1
                else:
                    logger.warning('No checkpoint file found')
                    return
            else:
                if os.path.exists(
                        os.path.join(ckpt_dir,
                                     'model.ckpt-' + checkpoint + '.index')):
                    # new_saver = tf.train.import_meta_graph(
                    #     os.path.join(ckpt_dir, 'model.ckpt-' + checkpoint + '.meta'))
                    saver.restore(
                        sess, os.path.join(ckpt_dir,
                                           'model.ckpt-' + checkpoint))
                    # first_step = int(checkpoint) + 1
                else:
                    logger.warning(
                        'checkpoint {} not found'.format(checkpoint))
                    return

            summary_writer = tf.summary.FileWriter(ckpt_dir, sess.graph)

            ## train
            executor_workers = train_workers - 1
            if executor_workers > 0:
                futures = set()
                executor = ThreadPoolExecutor(max_workers=executor_workers)
                for _ in range(executor_workers):
                    future = executor.submit(_train_thread_body, edge_sampler,
                                             inputs, labels, sess, trains_list,
                                             learning_rate, LR)
                    logger.info("open a new training thread: %s" % future)
                    futures.add(future)
            last_loss_time = time.time() - loss_interval
            last_summary_time = time.time() - summary_interval
            last_decay_time = last_checkpoint_time = time.time()
            last_decay_step = last_summary_step = last_checkpoint_step = 0
            while True:
                start_time = time.time()
                batch_data, batch_labels = edge_sampler.next_batch()
                feed_dict = {
                    inputs: batch_data,
                    labels: batch_labels,
                    learning_rate: LR.learning_rate
                }
                loss_value_list = []
                for train_op, loss, global_step in trains_list:
                    _, loss_value, cur_step = sess.run(
                        [train_op, loss, global_step], feed_dict=feed_dict)
                    assert not np.isnan(
                        loss_value), 'Model diverged with loss = NaN'
                    loss_value_list.append(loss_value)
                now = time.time()

                epoch, epoch_step = divmod(cur_step, num_steps_per_epoch)

                if now - last_loss_time >= loss_interval:
                    if len(loss_value_list) == 1:
                        loss_str = "%.6f" % loss_value_list[0]
                    else:
                        loss_str = "[%.6f, %.6f]" % (loss_value_list[0],
                                                     loss_value_list[1])
                    format_str = '%s: step=%d(%d/%d), lr=%.6f, loss=%s, duration/step=%.4fs'
                    logger.info(format_str %
                                (time.strftime('%Y-%m-%d %H:%M:%S',
                                               time.localtime(time.time())),
                                 cur_step, epoch_step, epoch, LR.learning_rate,
                                 loss_str, now - start_time))
                    last_loss_time = time.time()
                if now - last_summary_time >= summary_interval or cur_step - last_summary_step >= summary_steps or cur_step >= iter_steps:
                    summary_str = sess.run(summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, cur_step)
                    last_summary_time = time.time()
                    last_summary_step = cur_step
                ckpted = False
                # Save the model checkpoint periodically. (named 'model.ckpt-global_step.meta')
                if now - last_checkpoint_time >= ckpt_interval or cur_step - last_checkpoint_step >= ckpt_steps or cur_step >= iter_steps:
                    # embedding_vectors = sess.run(model.vectors, feed_dict=feed_dict)
                    vecs_list = []
                    for model in model_list:
                        vecs = sess.run(model.vectors, feed_dict=feed_dict)
                        vecs_list.append(vecs)
                    vecs = np.concatenate(vecs_list, axis=1)
                    checkpoint_path = os.path.join(ckpt_dir, 'model.ckpt')
                    utils.save_word2vec_format_and_ckpt(
                        vectors_path, vecs, checkpoint_path, sess, saver,
                        cur_step)
                    last_checkpoint_time = time.time()
                    last_checkpoint_step = cur_step
                    ckpted = True
                # update learning rate
                if ckpted or now - last_decay_time >= decay_interval or (
                        decay_steps > 0
                        and cur_step - last_decay_step >= decay_steps):
                    lr_info = np.loadtxt(lr_file, dtype=float)
                    if np.abs(lr_info[1] - decay_epochs) > 1e-6:
                        decay_epochs = lr_info[1]
                        decay_steps = round(decay_epochs * num_steps_per_epoch)
                    if np.abs(lr_info[2] - decay_rate) > 1e-6:
                        decay_rate = lr_info[2]
                    if np.abs(lr_info[3] - iter_epochs) > 1e-6:
                        iter_epochs = lr_info[3]
                        iter_steps = round(iter_epochs * num_steps_per_epoch)
                    if np.abs(lr_info[0] - initial_learning_rate) > 1e-6:
                        initial_learning_rate = lr_info[0]
                        LR.reset(initial_learning_rate=initial_learning_rate,
                                 initial_steps=cur_step,
                                 decay_rate=decay_rate,
                                 decay_steps=decay_steps,
                                 iter_steps=iter_steps)
                    else:
                        LR.exponential_decay(cur_step,
                                             decay_rate=decay_rate,
                                             decay_steps=decay_steps,
                                             iter_steps=iter_steps)
                    last_decay_time = time.time()
                    last_decay_step = cur_step
                if cur_step >= LR.iter_steps:
                    break

            summary_writer.close()
            if executor_workers > 0:
                logger.info("waiting the training threads finished:")
                try:
                    for future in as_completed(futures):
                        logger.info(future)
                except KeyboardInterrupt:
                    print("stopped by hand.")
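
Both examples above (and the ones below) drive the learning-rate schedule through utils.LearningRateGenerator, reading LR.learning_rate and LR.iter_steps and calling reset and exponential_decay whenever lr_file changes. The class itself is not shown; the following is a minimal sketch consistent with those call sites, assuming a plain (non-staircase) exponential decay. Everything beyond the attribute and method names visible above is an assumption.

class LearningRateGenerator(object):
    # Assumed helper: holds the current learning rate and supports both
    # exponential decay and a hot reset (used when lr_file is edited).
    def __init__(self, initial_learning_rate, initial_steps, decay_rate,
                 decay_steps, iter_steps=None):
        self.reset(initial_learning_rate, initial_steps, decay_rate,
                   decay_steps, iter_steps)

    def reset(self, initial_learning_rate, initial_steps, decay_rate,
              decay_steps, iter_steps=None):
        self._initial_learning_rate = initial_learning_rate
        self._initial_steps = initial_steps
        self._decay_rate = decay_rate
        self._decay_steps = decay_steps
        self.iter_steps = iter_steps
        self.learning_rate = initial_learning_rate

    def exponential_decay(self, global_step, decay_rate, decay_steps,
                          iter_steps=None):
        self._decay_rate = decay_rate
        self._decay_steps = decay_steps
        if iter_steps is not None:
            self.iter_steps = iter_steps
        exponent = (global_step - self._initial_steps) / float(self._decay_steps)
        self.learning_rate = (self._initial_learning_rate *
                              self._decay_rate ** exponent)
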
Example #3
File: TF_gcn.py Project: RingBDStack/RWNE
def train(dataset, lr_file, ckpt_dir, checkpoint, options):
    nodes_size = dataset._nodes_size
    num_steps_per_epoch = int(nodes_size / options.batch_size)
    iter_epochs = options.iter_epoches
    iter_steps = round(
        iter_epochs *
        num_steps_per_epoch)  # iter_epoches should be big enough to converge.
    decay_epochs = options.decay_epochs
    decay_steps = round(decay_epochs * num_steps_per_epoch)
    ckpt_steps = round(options.ckpt_epochs * num_steps_per_epoch)
    initial_learning_rate = options.learning_rate
    decay_rate = options.decay_rate

    LR = utils.LearningRateGenerator(
        initial_learning_rate=initial_learning_rate,
        initial_steps=0,
        decay_rate=decay_rate,
        decay_steps=decay_steps,
        iter_steps=iter_steps)

    with tf.Graph().as_default(), tf.device(
            '/gpu:0' if options.using_gpu else '/cpu:0'):

        global_step = tf.Variable(0, trainable=False, name="global_step")
        learning_rate = tf.placeholder(tf.float32, name='learning_rate')
        inputs = tf.placeholder(tf.float32,
                                shape=[None, options.feature_size],
                                name='inputs')
        laplacian = tf.placeholder(tf.float32, [None, None],
                                   name="laplacian_matrix")
        if options.using_label:
            labels = tf.placeholder(tf.int32,
                                    shape=[None, options.label_size],
                                    name='labels')
        else:
            labels = tf.placeholder(tf.int32,
                                    shape=[None, None],
                                    name='adjacency')

        model = GCN(dropout=options.dropout,
                    feature_size=options.feature_size,
                    using_label=options.using_label,
                    embedding_size=options.embedding_size,
                    hidden_size_list=options.hidden_size_list,
                    label_size=options.label_size,
                    weight_decay=options.weight_decay)
        train_op, loss = model.train(inputs, laplacian, labels, global_step,
                                     learning_rate)

        # Create a saver.
        saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=6)

        summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init_op = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU implementations.
        config = tf.ConfigProto(
            allow_soft_placement=options.allow_soft_placement,
            log_device_placement=options.log_device_placement)
        config.gpu_options.per_process_gpu_memory_fraction = options.gpu_memory_fraction
        config.gpu_options.allow_growth = options.allow_growth
        # config.gpu_options.visible_device_list = visible_device_list

        with tf.Session(config=config) as sess:
            # first_step = 0
            if checkpoint == '0':  # new train
                sess.run(init_op)

            elif checkpoint == '-1':  # choose the latest one
                ckpt = tf.train.get_checkpoint_state(ckpt_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    # new_saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
                    # Restores from checkpoint
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # global_step_for_restore = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    # first_step = int(global_step_for_restore) + 1
                else:
                    logger.warning('No checkpoint file found')
                    return
            else:
                if os.path.exists(
                        os.path.join(ckpt_dir,
                                     'model.ckpt-' + checkpoint + '.index')):
                    # new_saver = tf.train.import_meta_graph(
                    #     os.path.join(ckpt_dir, 'model.ckpt-' + checkpoint + '.meta'))
                    saver.restore(
                        sess, os.path.join(ckpt_dir,
                                           'model.ckpt-' + checkpoint))
                    # first_step = int(checkpoint) + 1
                else:
                    logger.warning(
                        'checkpoint {} not found'.format(checkpoint))
                    return

            summary_writer = tf.summary.FileWriter(ckpt_dir, sess.graph)

            last_loss_time = time.time() - options.loss_interval
            last_summary_time = time.time() - options.summary_interval
            last_decay_time = last_checkpoint_time = time.time()
            last_decay_step = last_summary_step = last_checkpoint_step = 0
            while True:
                start_time = time.time()
                batch_features, batch_adj, batch_labels = dataset.next_batch(
                    options.batch_size)
                feed_dict = {
                    inputs: batch_features,
                    laplacian: batch_adj,
                    labels: batch_labels,
                    learning_rate: LR.learning_rate
                }
                _, loss_value, cur_step = sess.run(
                    [train_op, loss, global_step], feed_dict=feed_dict)
                now = time.time()

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                epoch, epoch_step = divmod(cur_step, num_steps_per_epoch)

                if now - last_loss_time >= options.loss_interval:
                    format_str = '%s: step=%d(%d/%d), lr=%.6f, loss=%.6f, duration/step=%.4fs'
                    logger.info(format_str %
                                (time.strftime('%Y-%m-%d %H:%M:%S',
                                               time.localtime(time.time())),
                                 cur_step, epoch_step, epoch, LR.learning_rate,
                                 loss_value, now - start_time))
                    last_loss_time = time.time()
                if now - last_summary_time >= options.summary_interval or cur_step - last_summary_step >= options.summary_steps or cur_step >= iter_steps:
                    summary_str = sess.run(summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, cur_step)
                    last_summary_time = time.time()
                    last_summary_step = cur_step
                ckpted = False
                # Save the model checkpoint periodically. (named 'model.ckpt-global_step.meta')
                if now - last_checkpoint_time >= options.ckpt_interval or cur_step - last_checkpoint_step >= ckpt_steps or cur_step >= iter_steps:
                    if options.batch_size == nodes_size:
                        batch_features, batch_adj, batch_labels = dataset.get_full()
                        feed_dict = {
                            inputs: batch_features,
                            laplacian: batch_adj,
                            labels: batch_labels,
                            learning_rate: LR.learning_rate
                        }
                        vecs = sess.run(model.vectors, feed_dict=feed_dict)
                    else:
                        vecs = []
                        start = 0
                        while start < nodes_size:
                            end = min(nodes_size, start + options.batch_size)
                            index = np.arange(start, end)
                            start = end
                            batch_features, batch_adj, batch_labels = dataset.get_batch(
                                index)
                            feed_dict = {
                                inputs: batch_features,
                                laplacian: batch_adj,
                                labels: batch_labels,
                                learning_rate: LR.learning_rate
                            }
                            batch_embeddings = sess.run(model.vectors,
                                                        feed_dict=feed_dict)
                            vecs.append(batch_embeddings)
                        vecs = np.concatenate(vecs, axis=0)
                    checkpoint_path = os.path.join(ckpt_dir, 'model.ckpt')
                    utils.save_word2vec_format_and_ckpt(
                        options.vectors_path, vecs, checkpoint_path, sess,
                        saver, cur_step)
                    last_checkpoint_time = time.time()
                    last_checkpoint_step = cur_step
                    ckpted = True
                # update learning rate
                if ckpted or now - last_decay_time >= options.decay_interval or (
                        decay_steps > 0
                        and cur_step - last_decay_step >= decay_steps):
                    lr_info = np.loadtxt(lr_file, dtype=float)
                    if np.abs(lr_info[1] - decay_epochs) > 1e-6:
                        decay_epochs = lr_info[1]
                        decay_steps = round(decay_epochs * num_steps_per_epoch)
                    if np.abs(lr_info[2] - decay_rate) > 1e-6:
                        decay_rate = lr_info[2]
                    if np.abs(lr_info[3] - iter_epochs) > 1e-6:
                        iter_epochs = lr_info[3]
                        iter_steps = round(iter_epochs * num_steps_per_epoch)
                    if np.abs(lr_info[0] - initial_learning_rate) > 1e-6:
                        initial_learning_rate = lr_info[0]
                        LR.reset(initial_learning_rate=initial_learning_rate,
                                 initial_steps=cur_step,
                                 decay_rate=decay_rate,
                                 decay_steps=decay_steps,
                                 iter_steps=iter_steps)
                    else:
                        LR.exponential_decay(cur_step,
                                             decay_rate=decay_rate,
                                             decay_steps=decay_steps,
                                             iter_steps=iter_steps)
                    last_decay_time = time.time()
                    last_decay_step = cur_step
                if cur_step >= LR.iter_steps:
                    break
            summary_writer.close()
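
All of the examples re-read lr_file with np.loadtxt(lr_file, dtype=float) during training and apply any changes to the schedule on the fly. Judging from the indexing above, the file holds whitespace-separated floats in the order initial_learning_rate, decay_epochs, decay_rate and (for Examples #2 through #5) iter_epochs; Example #1 reads only the first three. A hypothetical file with that layout could be written as below; the path and values are illustrative, not taken from the source.

# lr_info[0] = initial_learning_rate
# lr_info[1] = decay_epochs
# lr_info[2] = decay_rate
# lr_info[3] = iter_epochs  (not read by Example #1)
with open("learning_rate.txt", "w") as f:
    f.write("0.025 1.0 0.95 100.0\n")
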
Example #4
def train(walker, lr_file, ckpt_dir, checkpoint, options):
    vocab_size = walker.nodes_size
    types_size = walker.node_types_size
    num_steps_per_epoch = int(
        vocab_size * options.train_workers /
        options.batch_size)  # a rough estimate of steps per epoch for RWR
    iter_epochs = options.iter_epoches
    iter_steps = round(
        iter_epochs *
        num_steps_per_epoch)  # iter_epoches should be big enough to converge.
    decay_epochs = options.decay_epochs
    decay_steps = round(decay_epochs * num_steps_per_epoch)
    ckpt_steps = round(options.ckpt_epochs * num_steps_per_epoch)
    initial_learning_rate = options.learning_rate
    decay_rate = options.decay_rate

    LR = utils.LearningRateGenerator(
        initial_learning_rate=initial_learning_rate,
        initial_steps=0,
        decay_rate=decay_rate,
        decay_steps=decay_steps,
        iter_steps=iter_steps)

    with tf.Graph().as_default(), tf.device(
            '/gpu:0' if options.using_gpu else '/cpu:0'):

        global_step = tf.Variable(0, trainable=False, name="global_step")
        # inputs(center_nodes), labels(context_nodes), labels_type(context_nodes_type), neg_labels(neg_nodes)
        inputs = tf.placeholder(tf.int32, name='inputs')  # center_nodes
        labels = [
            tf.placeholder(tf.int32,
                           shape=[None],
                           name='labels_T{}'.format(type_i))
            for type_i in range(types_size)
        ]
        labels_mask = [
            tf.placeholder(tf.float32, name='labels_mask_T{}'.format(type_i))
            for type_i in range(types_size)
        ]
        neg_labels = [
            tf.placeholder(tf.int32,
                           shape=[None],
                           name='neg_labels_T{}'.format(type_i))
            for type_i in range(types_size)
        ]
        learning_rate = tf.placeholder(tf.float32, name='learning_rate')

        model = SGNS(vocab_size=vocab_size,
                     embedding_size=options.embedding_size,
                     type_size=types_size)

        train_op, loss = model.train(inputs, labels, labels_mask, neg_labels,
                                     global_step, learning_rate)

        # Create a saver.
        saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=6)

        summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init_op = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU implementations.
        config = tf.ConfigProto(
            allow_soft_placement=options.allow_soft_placement,
            log_device_placement=options.log_device_placement)
        config.gpu_options.per_process_gpu_memory_fraction = options.gpu_memory_fraction
        config.gpu_options.allow_growth = options.allow_growth
        # config.gpu_options.visible_device_list = visible_device_list

        with tf.Session(config=config) as sess:
            # first_step = 0
            if checkpoint == '0':  # new train
                sess.run(init_op)

            elif checkpoint == '-1':  # choose the latest one
                ckpt = tf.train.get_checkpoint_state(ckpt_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    # new_saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
                    # Restores from checkpoint
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # global_step_for_restore = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    # first_step = int(global_step_for_restore) + 1
                else:
                    logger.warning('No checkpoint file found')
                    return
            else:
                if os.path.exists(
                        os.path.join(ckpt_dir,
                                     'model.ckpt-' + checkpoint + '.index')):
                    # new_saver = tf.train.import_meta_graph(
                    #     os.path.join(ckpt_dir, 'model.ckpt-' + checkpoint + '.meta'))
                    saver.restore(
                        sess, os.path.join(ckpt_dir,
                                           'model.ckpt-' + checkpoint))
                    # first_step = int(checkpoint) + 1
                else:
                    logger.warning(
                        'checkpoint {} not found'.format(checkpoint))
                    return

            summary_writer = tf.summary.FileWriter(ckpt_dir, sess.graph)

            last_loss_time = time.time() - options.loss_interval
            last_summary_time = time.time() - options.summary_interval
            last_decay_time = last_checkpoint_time = time.time()
            last_decay_step = last_summary_step = last_checkpoint_step = 0
            rwrgenerator = RWRGenerator(walker=walker,
                                        walk_times=options.walk_times)
            while True:
                start_time = time.time()
                (batch_inputs, batch_labels, batch_labels_mask,
                 batch_neg_labels) = rwrgenerator.next_batch()
                feed_dict = {
                    inputs: batch_inputs,
                    learning_rate: LR.learning_rate
                }
                for type_i in range(types_size):
                    feed_dict[labels[type_i]] = batch_labels[type_i]
                    feed_dict[labels_mask[type_i]] = batch_labels_mask[type_i]
                    feed_dict[neg_labels[type_i]] = batch_neg_labels[type_i]
                _, loss_value, cur_step = sess.run(
                    [train_op, loss, global_step], feed_dict=feed_dict)
                now = time.time()

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                epoch, epoch_step = divmod(cur_step, num_steps_per_epoch)

                if now - last_loss_time >= options.loss_interval:
                    format_str = '%s: step=%d(%d/%d), lr=%.6f, loss=%.6f, duration/step=%.4fs'
                    logger.info(format_str %
                                (time.strftime('%Y-%m-%d %H:%M:%S',
                                               time.localtime(time.time())),
                                 cur_step, epoch_step, epoch, LR.learning_rate,
                                 loss_value, now - start_time))
                    last_loss_time = time.time()
                if now - last_summary_time >= options.summary_interval or cur_step - last_summary_step >= options.summary_steps or cur_step >= iter_steps:
                    summary_str = sess.run(summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, cur_step)
                    last_summary_time = time.time()
                    last_summary_step = cur_step
                ckpted = False
                # Save the model checkpoint periodically. (named 'model.ckpt-global_step.meta')
                if now - last_checkpoint_time >= options.ckpt_interval or cur_step - last_checkpoint_step >= ckpt_steps or cur_step >= iter_steps:
                    vecs, global_step_value = sess.run(
                        [model.vectors, global_step], feed_dict=feed_dict)
                    # vecs,weights,biases = sess.run([model.vectors,model.context_weights,model.context_biases],
                    #                              feed_dict=feed_dict)
                    checkpoint_path = os.path.join(ckpt_dir, 'model.ckpt')
                    utils.save_word2vec_format_and_ckpt(
                        options.vectors_path, vecs, checkpoint_path, sess,
                        saver, global_step_value, types_size)
                    # save_word2vec_format(vectors_path+".contexts", weights, walker.idx_nodes)
                    # save_word2vec_format(vectors_path+".context_biases", np.reshape(biases,[-1,1]), walker.idx_nodes)
                    last_checkpoint_time = time.time()
                    last_checkpoint_step = global_step_value
                    ckpted = True
                # update learning rate
                if ckpted or now - last_decay_time >= options.decay_interval or (
                        decay_steps > 0
                        and cur_step - last_decay_step >= decay_steps):
                    lr_info = np.loadtxt(lr_file, dtype=float)
                    if np.abs(lr_info[1] - decay_epochs) > 1e-6:
                        decay_epochs = lr_info[1]
                        decay_steps = round(decay_epochs * num_steps_per_epoch)
                    if np.abs(lr_info[2] - decay_rate) > 1e-6:
                        decay_rate = lr_info[2]
                    if np.abs(lr_info[3] - iter_epochs) > 1e-6:
                        iter_epochs = lr_info[3]
                        iter_steps = round(iter_epochs * num_steps_per_epoch)
                    if np.abs(lr_info[0] - initial_learning_rate) > 1e-6:
                        initial_learning_rate = lr_info[0]
                        LR.reset(initial_learning_rate=initial_learning_rate,
                                 initial_steps=cur_step,
                                 decay_rate=decay_rate,
                                 decay_steps=decay_steps,
                                 iter_steps=iter_steps)
                    else:
                        LR.exponential_decay(cur_step,
                                             decay_rate=decay_rate,
                                             decay_steps=decay_steps,
                                             iter_steps=iter_steps)
                    last_decay_time = time.time()
                    last_decay_step = cur_step
                if cur_step >= LR.iter_steps:
                    break

            summary_writer.close()
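
save_word2vec_format (Example #1) and utils.save_word2vec_format_and_ckpt (Examples #2 through #5) are also external helpers not shown here. Assuming they write the standard word2vec text format that vectors_path suggests, a writer would look roughly like the sketch below; only the file format is standard, the function body is an assumption.

def save_word2vec_format(vectors_path, vecs, idx2vocab):
    # word2vec text format: a "count dim" header line, then one
    # "token v1 v2 ... v_dim" line per embedding row.
    with open(vectors_path, "w") as f:
        f.write("%d %d\n" % (vecs.shape[0], vecs.shape[1]))
        for idx in range(vecs.shape[0]):
            values = " ".join("%f" % v for v in vecs[idx])
            f.write("%s %s\n" % (idx2vocab[idx], values))
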
Example #5
def train(dataset,
          vectors_path,
          lr_file,
          ckpt_dir,
          checkpoint,
          embedding_size,
          struct,
          alpha,
          beta,
          gamma,
          reg,
          sparse_dot,
          iter_epochs,
          batch_size,
          initial_learning_rate,
          decay_epochs,
          decay_interval,
          decay_rate,
          allow_soft_placement,
          log_device_placement,
          gpu_memory_fraction,
          using_gpu,
          allow_growth,
          loss_interval,
          summary_steps,
          summary_interval,
          ckpt_epochs,
          ckpt_interval,
          dbn_initial,
          dbn_epochs,
          dbn_batchsize,
          dbn_learning_rate,
          active_function="sigmoid"):
    actv_func = {
        'sigmoid': tf.sigmoid,
        'tanh': tf.tanh,
        'relu': tf.nn.relu,
        'leaky_relu': tf.nn.leaky_relu
    }[active_function]
    nodes_size = dataset.nodes_size
    num_steps_per_epoch = int(nodes_size / batch_size)
    iter_steps = round(
        iter_epochs *
        num_steps_per_epoch)  # iter_epochs should be big enough to converge.
    decay_steps = round(decay_epochs * num_steps_per_epoch)
    ckpt_steps = round(ckpt_epochs * num_steps_per_epoch)

    LR = utils.LearningRateGenerator(
        initial_learning_rate=initial_learning_rate,
        initial_steps=0,
        decay_rate=decay_rate,
        decay_steps=decay_steps,
        iter_steps=iter_steps)

    with tf.Graph().as_default(), tf.device(
            '/gpu:0' if using_gpu else '/cpu:0'):

        global_step = tf.Variable(0, trainable=False, name="global_step")
        adj_matrix = tf.placeholder(tf.float32, [None, None])
        if sparse_dot:
            inputs_sp_indices = tf.placeholder(tf.int64)
            inputs_sp_ids_val = tf.placeholder(tf.float32)
            inputs_sp_shape = tf.placeholder(tf.int64)
            inputs = tf.SparseTensor(inputs_sp_indices, inputs_sp_ids_val,
                                     inputs_sp_shape)
        else:
            inputs = tf.placeholder(tf.float32, [None, nodes_size])
        learning_rate = tf.placeholder(tf.float32, name='learning_rate')

        model = SDNE(nodes_size=nodes_size,
                     struct=struct,
                     embedding_size=embedding_size,
                     alpha=alpha,
                     beta=beta,
                     gamma=gamma,
                     reg=reg,
                     sparse_dot=sparse_dot,
                     active_function=actv_func)

        train_op, loss, embeddings = model.train(inputs, adj_matrix,
                                                 global_step, learning_rate)

        # Create a saver.
        saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=5)

        summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init_op = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU implementations.
        config = tf.ConfigProto(allow_soft_placement=allow_soft_placement,
                                log_device_placement=log_device_placement)
        config.gpu_options.per_process_gpu_memory_fraction = gpu_memory_fraction
        config.gpu_options.allow_growth = allow_growth
        # config.gpu_options.visible_device_list = visible_device_list

        with tf.Session(config=config) as sess:
            # first_step = 0
            if checkpoint == '0':  # new train
                sess.run(init_op)
                if dbn_initial:
                    time_start = time.time()
                    logger.info("DBN initial start ...")
                    RBMs = []
                    for i in range(len(model._struct) - 1):
                        RBM = rbm(model._struct[i],
                                  model._struct[i + 1],
                                  batchsize=dbn_batchsize,
                                  learning_rate=dbn_learning_rate,
                                  config=config)
                        logger.info("create rbm {}-{}".format(
                            model._struct[i], model._struct[i + 1]))
                        RBMs.append(RBM)
                        for epoch in range(dbn_epochs):
                            error = 0
                            for batch in range(0, nodes_size, batch_size):
                                # This line is unchanged.
                                # Does it iterate over all nodes globally?
                                mini_batch, _ = dataset.next_batch(batch_size)
                                for k in range(len(RBMs) - 1):
                                    mini_batch = RBMs[k].getH(mini_batch)
                                error += RBM.fit(mini_batch)
                            logger.info("rbm_" + str(len(RBMs)) + " epochs:" +
                                        str(epoch) + " error: " + str(error))

                        W, bv, bh = RBM.getWb()
                        name = "encoder" + str(i)

                        def assign(var, value, session):
                            session.run(var.assign(value))

                        assign(model._weights[name], W, sess)
                        assign(model._bias[name], bh, sess)

                        name = "decoder" + str(len(model._struct) - i - 2)
                        assign(model._weights[name], W.transpose(), sess)
                        assign(model._bias[name], bv, sess)
                    logger.info(
                        "dbn_init finished in {}s.".format(time.time() -
                                                           time_start))

                vecs = []
                start = 0
                while start < nodes_size:
                    end = min(nodes_size, start + batch_size)
                    index = np.arange(start, end)
                    start = end
                    batch_input, batch_adj = dataset.get_batch(index)
                    if sparse_dot:
                        batch_input_ind = np.vstack(
                            np.where(batch_input)).astype(np.int64).T
                        batch_input_shape = np.array(batch_input.shape).astype(
                            np.int64)
                        batch_input_val = batch_input[np.where(batch_input)]
                        feed_dict = {
                            inputs_sp_indices: batch_input_ind,
                            inputs_sp_shape: batch_input_shape,
                            inputs_sp_ids_val: batch_input_val,
                            adj_matrix: batch_adj,
                            learning_rate: LR.learning_rate
                        }
                    else:
                        feed_dict = {
                            inputs: batch_input,
                            adj_matrix: batch_adj,
                            learning_rate: LR.learning_rate
                        }
                    batch_embeddings = sess.run(embeddings,
                                                feed_dict=feed_dict)
                    vecs.append(batch_embeddings)
                vecs = np.concatenate(vecs, axis=0)
                checkpoint_path = os.path.join(ckpt_dir, 'model.ckpt')
                utils.save_word2vec_format_and_ckpt(vectors_path, vecs,
                                                    checkpoint_path, sess,
                                                    saver, 0)

            elif checkpoint == '-1':  # load the latest one
                ckpt = tf.train.get_checkpoint_state(ckpt_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    # new_saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
                    # Restores from checkpoint
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # global_step_for_restore = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    # first_step = int(global_step_for_restore) + 1
                else:
                    logger.warning('No checkpoint file found')
                    return
            else:
                if os.path.exists(
                        os.path.join(ckpt_dir,
                                     'model.ckpt-' + checkpoint + '.index')):
                    # new_saver = tf.train.import_meta_graph(
                    #     os.path.join(ckpt_dir, 'model.ckpt-' + checkpoint + '.meta'))
                    saver.restore(
                        sess, os.path.join(ckpt_dir,
                                           'model.ckpt-' + checkpoint))
                    # first_step = int(checkpoint) + 1
                else:
                    logger.warning(
                        'checkpoint {} not found'.format(checkpoint))
                    return

            summary_writer = tf.summary.FileWriter(ckpt_dir, sess.graph)

            ## train
            last_loss_time = time.time() - loss_interval
            last_summary_time = time.time() - summary_interval
            last_decay_time = last_checkpoint_time = time.time()
            last_decay_step = last_summary_step = last_checkpoint_step = 0
            while True:
                start_time = time.time()
                batch_input, batch_adj = dataset.next_batch(
                    batch_size, keep_strict_batching=True)
                if sparse_dot:
                    batch_input_ind = np.vstack(np.where(batch_input)).astype(
                        np.int64).T
                    batch_input_shape = np.array(batch_input.shape).astype(
                        np.int64)
                    batch_input_val = batch_input[np.where(batch_input)]
                    feed_dict = {
                        inputs_sp_indices: batch_input_ind,
                        inputs_sp_shape: batch_input_shape,
                        inputs_sp_ids_val: batch_input_val,
                        adj_matrix: batch_adj,
                        learning_rate: LR.learning_rate
                    }
                else:
                    feed_dict = {
                        inputs: batch_input,
                        adj_matrix: batch_adj,
                        learning_rate: LR.learning_rate
                    }

                _, loss_value, cur_step = sess.run(
                    [train_op, loss, global_step], feed_dict=feed_dict)
                now = time.time()

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                epoch, epoch_step = divmod(cur_step, num_steps_per_epoch)

                if now - last_loss_time >= loss_interval:
                    format_str = '%s: step=%d(%d/%d), lr=%.6f, loss=%.6f, duration/step=%.4fs'
                    logger.info(format_str %
                                (time.strftime('%Y-%m-%d %H:%M:%S',
                                               time.localtime(time.time())),
                                 cur_step, epoch_step, epoch, LR.learning_rate,
                                 loss_value, now - start_time))
                    last_loss_time = time.time()
                if now - last_summary_time >= summary_interval or cur_step - last_summary_step >= summary_steps or cur_step >= iter_steps:
                    summary_str = sess.run(summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, cur_step)
                    last_summary_time = time.time()
                    last_summary_step = cur_step
                ckpted = False
                # Save the model checkpoint periodically. (named 'model.ckpt-global_step.meta')
                if now - last_checkpoint_time >= ckpt_interval or cur_step - last_checkpoint_step >= ckpt_steps or cur_step >= iter_steps:
                    vecs = []
                    start = 0
                    while start < nodes_size:
                        end = min(nodes_size, start + batch_size)
                        index = np.arange(start, end)
                        start = end
                        batch_input, batch_adj = dataset.get_batch(index)
                        if sparse_dot:
                            batch_input_ind = np.vstack(
                                np.where(batch_input)).astype(np.int64).T
                            batch_input_shape = np.array(
                                batch_input.shape).astype(np.int64)
                            batch_input_val = batch_input[np.where(
                                batch_input)]
                            feed_dict = {
                                inputs_sp_indices: batch_input_ind,
                                inputs_sp_shape: batch_input_shape,
                                inputs_sp_ids_val: batch_input_val,
                                adj_matrix: batch_adj,
                                learning_rate: LR.learning_rate
                            }
                        else:
                            feed_dict = {
                                inputs: batch_input,
                                adj_matrix: batch_adj,
                                learning_rate: LR.learning_rate
                            }
                        batch_embeddings = sess.run(embeddings,
                                                    feed_dict=feed_dict)
                        vecs.append(batch_embeddings)
                    vecs = np.concatenate(vecs, axis=0)
                    checkpoint_path = os.path.join(ckpt_dir, 'model.ckpt')
                    utils.save_word2vec_format_and_ckpt(
                        vectors_path, vecs, checkpoint_path, sess, saver,
                        cur_step)
                    last_checkpoint_time = time.time()
                    last_checkpoint_step = cur_step
                    ckpted = True
                # update learning rate
                if ckpted or now - last_decay_time >= decay_interval or (
                        decay_steps > 0
                        and cur_step - last_decay_step >= decay_steps):
                    lr_info = np.loadtxt(lr_file, dtype=float)
                    if np.abs(lr_info[1] - decay_epochs) > 1e-6:
                        decay_epochs = lr_info[1]
                        decay_steps = round(decay_epochs * num_steps_per_epoch)
                    if np.abs(lr_info[2] - decay_rate) > 1e-6:
                        decay_rate = lr_info[2]
                    if np.abs(lr_info[3] - iter_epochs) > 1e-6:
                        iter_epochs = lr_info[3]
                        iter_steps = round(iter_epochs * num_steps_per_epoch)
                    if np.abs(lr_info[0] - initial_learning_rate) > 1e-6:
                        initial_learning_rate = lr_info[0]
                        LR.reset(initial_learning_rate=initial_learning_rate,
                                 initial_steps=cur_step,
                                 decay_rate=decay_rate,
                                 decay_steps=decay_steps,
                                 iter_steps=iter_steps)
                    else:
                        LR.exponential_decay(cur_step,
                                             decay_rate=decay_rate,
                                             decay_steps=decay_steps,
                                             iter_steps=iter_steps)
                    last_decay_time = time.time()
                    last_decay_step = cur_step
                if cur_step >= LR.iter_steps:
                    break
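
Example #5 builds the sparse feed for the tf.SparseTensor placeholders in three separate places. The repeated pattern amounts to one small conversion helper; the following sketch is hypothetical and not part of the original code.

import numpy as np

def dense_to_sparse_feed(batch_input):
    # Convert a dense numpy batch into the (indices, values, shape) triple
    # expected by inputs_sp_indices, inputs_sp_ids_val and inputs_sp_shape.
    nonzero = np.where(batch_input)
    indices = np.vstack(nonzero).astype(np.int64).T
    values = batch_input[nonzero]
    shape = np.array(batch_input.shape).astype(np.int64)
    return indices, values, shape
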