コード例 #1
0
ファイル: train.py プロジェクト: potsawee/seq2seq
def train(config):
    """Train an EncoderDecoder model and checkpoint it after every epoch.

    The function builds the training batches and the model, optionally loads
    pre-trained embeddings (fresh run) or restores a checkpoint (resumed run),
    then runs ``config['num_epochs']`` epochs. The learning rate is multiplied
    by ``config['decay_rate']`` after each epoch. Training stops early,
    without saving that epoch, if the accumulated epoch loss is NaN.

    Args:
        config: mapping of settings. Keys read directly here are 'save',
            'load', 'random_seed', 'learning_rate', 'decay_rate', 'use_gpu',
            'load_embedding_src', 'load_embedding_tgt',
            'max_sentence_length', 'num_epochs' and 'dropout'. The whole
            mapping is also passed on to ``write_config``,
            ``construct_training_data_batches`` and ``EncoderDecoder``.

    Returns:
        None. Progress is printed to stdout and checkpoints are written under
        ``config['save']``.
    """
    # --------- configurations --------- #
    save_path = config['save']  # directory that receives config + checkpoints
    saved_model = config['load']  # None, or checkpoint prefix to resume from
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    # ---------------------------------- #
    write_config(os.path.join(save_path, 'config.txt'), config)

    # NOTE: only Python's `random` module is seeded in this function; it
    # drives the per-epoch batch shuffling below.
    random.seed(config['random_seed'])

    batches, vocab_size, src_word2id, tgt_word2id = construct_training_data_batches(
        config)

    # Order the target words by their id so that tgt_id2word[i] is the word
    # with id i, independent of the dict's iteration order.
    # Assumes target ids are the contiguous range 0..len-1 -- TODO confirm
    # against construct_training_data_batches.
    tgt_id2word = sorted(tgt_word2id, key=tgt_word2id.get)

    params = {
        'vocab_src_size': vocab_size['src'],
        'vocab_tgt_size': vocab_size['tgt'],
        'go_id': tgt_word2id['<go>'],
        'eos_id': tgt_word2id['</s>']
    }

    model = EncoderDecoder(config, params)
    model.build_network()

    learning_rate = config['learning_rate']
    decay_rate = config['decay_rate']

    # List the trainable variables of the freshly built graph for inspection.
    for variable in tf.trainable_variables():
        print(variable)

    # save & restore model (only the most recent checkpoint is kept)
    saver = tf.train.Saver(max_to_keep=1)

    if config['use_gpu']:
        if 'X_SGE_CUDA_DEVICE' in os.environ:
            print('running on the stack...')
            cuda_device = os.environ['X_SGE_CUDA_DEVICE']
            print('X_SGE_CUDA_DEVICE is set to {}'.format(cuda_device))
            os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device

        else:  # development only e.g. air202
            print('running locally...')
            os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # choose the device (GPU) here

        sess_config = tf.ConfigProto(allow_soft_placement=True)
        # Whether the GPU memory usage can grow dynamically.
        sess_config.gpu_options.allow_growth = True
        # The fraction of GPU memory that the process can use.
        sess_config.gpu_options.per_process_gpu_memory_fraction = 0.95

    else:
        # An empty device list hides every GPU from this process.
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
        sess_config = tf.ConfigProto()

    with tf.Session(config=sess_config) as sess:

        if saved_model is None:
            sess.run(tf.global_variables_initializer())
            # ------------ load pre-trained embeddings ------------ #
            if config['load_embedding_src'] is not None:
                src_embedding = sess.run(model.src_word_embeddings)
                src_embedding_matrix = load_pretrained_embedding(
                    src_word2id, src_embedding, config['load_embedding_src'])
                sess.run(
                    model.src_word_embeddings.assign(src_embedding_matrix))
            if config['load_embedding_tgt'] is not None:
                if config['load_embedding_tgt'] == config['load_embedding_src']:
                    # Same embedding file on both sides: reuse the matrix
                    # that was just built for the source embeddings.
                    sess.run(
                        model.tgt_word_embeddings.assign(src_embedding_matrix))
                else:
                    tgt_embedding = sess.run(model.tgt_word_embeddings)
                    tgt_embedding_matrix = load_pretrained_embedding(
                        tgt_word2id, tgt_embedding,
                        config['load_embedding_tgt'])
                    sess.run(
                        model.tgt_word_embeddings.assign(tgt_embedding_matrix))
            # ----------------------------------------------------- #
        else:
            # The graph was already built by model.build_network(), so the
            # checkpoint is restored through the existing saver rather than
            # importing a second copy of the graph from the .meta file.
            saver.restore(sess, saved_model)
            print('loaded model...', saved_model)

        # ------------ To print out some output -------------------- #
        # Fixed sample sentences that are translated periodically so the
        # training progress can be eyeballed in the log.
        my_sentences = [
            'this is test . </s>',
            'this is confirm my reservation at hotel . </s>',
            'playing tennis good for you . </s>',
            'when talking about successful longterm business relationships customer services are important element </s>'
        ]
        max_len = config['max_sentence_length']
        pad_id = src_word2id['</s>']  # rows are padded with the </s> id
        my_sent_ids = []
        my_sent_len = []
        for my_sentence in my_sentences:
            ids = [
                src_word2id[word] if word in src_word2id else src_word2id['<unk>']
                for word in my_sentence.split()
            ]
            # Truncate over-long samples so every row has exactly max_len
            # entries; otherwise the fed matrix would be ragged.
            ids = ids[:max_len]
            my_sent_len.append(len(ids))
            my_sent_ids.append(ids + [pad_id] * (max_len - len(ids)))
        infer_dict = {
            model.src_word_ids: my_sent_ids,
            model.src_sentence_lengths: my_sent_len,
            model.dropout: 0.0,
            model.learning_rate: learning_rate
        }
        # ---------------------------------------------------------- #

        num_epochs = config['num_epochs']
        for epoch in range(num_epochs):
            print("num_batches = ", len(batches))

            random.shuffle(batches)

            epoch_loss = 0

            for i, batch in enumerate(batches):

                feed_dict = {
                    model.src_word_ids: batch['src_word_ids'],
                    model.tgt_word_ids: batch['tgt_word_ids'],
                    model.src_sentence_lengths: batch['src_sentence_lengths'],
                    model.tgt_sentence_lengths: batch['tgt_sentence_lengths'],
                    model.dropout: config['dropout'],
                    model.learning_rate: learning_rate
                }

                _, loss = sess.run([model.train_op, model.train_loss],
                                   feed_dict=feed_dict)
                epoch_loss += loss

                if i % 100 == 0:
                    # Running average of the loss over this epoch so far.
                    print("batch: {} --- avg train loss: {:.5f}".format(
                        i, epoch_loss / (i + 1)))

                    sys.stdout.flush()

                if i % 500 == 0:
                    # Translate the fixed sample sentences with the current
                    # weights and print them as words.
                    my_translations = sess.run(model.translations,
                                               feed_dict=infer_dict)
                    for my_sent in my_translations:
                        print(' '.join(
                            tgt_id2word[word_id] for word_id in my_sent))

            # Per-epoch bookkeeping inside the model (semantics are defined
            # by EncoderDecoder.increment_counter).
            model.increment_counter()
            learning_rate *= decay_rate

            print("---------------------------------------------------")
            print("epoch {} done".format(epoch + 1))
            print("total training loss = {}".format(epoch_loss))
            print("---------------------------------------------------")

            if math.isnan(epoch_loss):
                # Do not checkpoint a diverged model.
                print("stop training - loss/gradient exploded")
                break

            saver.save(sess, os.path.join(save_path, 'model'),
                       global_step=epoch)
コード例 #2
0
ファイル: adapt-nmt.py プロジェクト: potsawee/py-tools-ged
def adapt(config):
    """Continue training a saved EncoderDecoder model through its adapt op.

    The model graph is rebuilt, the variables named in ``param_names`` (the
    decoder's dense bias and kernel) are handed to ``model.adapt_weights``,
    a checkpoint from ``config['load']`` is restored, and ``model.adapt_op``
    is run over the training batches for 10 epochs. A checkpoint is written
    to ``config['save']`` after every epoch.

    NOTE(review): presumably ``adapt_op`` updates only the variables passed
    to ``adapt_weights`` -- confirm in EncoderDecoder.

    Args:
        config: mapping of settings. Keys read directly here are 'save',
            'load', 'model_number', 'num_epochs', 'dropout' and
            'decoding_method'. The whole mapping is also passed on to
            ``construct_training_data_batches``, ``EncoderDecoder`` and
            ``write_config``.

    Returns:
        None. Progress is printed to stdout and checkpoints are written under
        ``config['save']``.
    """
    if 'X_SGE_CUDA_DEVICE' in os.environ:
        print('running on the stack...')
        cuda_device = os.environ['X_SGE_CUDA_DEVICE']
        print('X_SGE_CUDA_DEVICE is set to {}'.format(cuda_device))
        os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device

    else:  # development only e.g. air202
        print('running locally...')
        os.environ['CUDA_VISIBLE_DEVICES'] = '3'  # choose the device (GPU) here

    sess_config = tf.ConfigProto()

    # The vocab-size entry returned by the helper is not needed here: the
    # sizes are taken from the word-to-id mappings themselves.
    batches, _, src_word2id, tgt_word2id = construct_training_data_batches(
        config)

    params = {
        'vocab_src_size': len(src_word2id),
        'vocab_tgt_size': len(tgt_word2id),
        'go_id': tgt_word2id['<go>'],
        'eos_id': tgt_word2id['</s>']
    }

    # build the model
    model = EncoderDecoder(config, params)
    model.build_network()

    # -------- Adaption work -------- #
    bias_name = 'decoder/decode_with_shared_attention/decoder/dense/bias:0'
    weight_name = 'decoder/decode_with_shared_attention/decoder/dense/kernel:0'
    param_names = [bias_name, weight_name]
    model.adapt_weights(param_names)
    # ------------------------------- #

    new_save_path = config['save']
    if not os.path.exists(new_save_path):
        os.makedirs(new_save_path)
    write_config(new_save_path + '/config.txt', config)

    # save & restore model (only the most recent checkpoint is kept)
    saver = tf.train.Saver(max_to_keep=1)
    save_path = config['load']
    # Without an explicit model number, fall back to the checkpoint written
    # for the last of config['num_epochs'] epochs (step index num_epochs - 1).
    model_number = config['model_number']
    if model_number is None:
        model_number = config['num_epochs'] - 1
    full_save_path_to_model = save_path + '/model-' + str(model_number)

    with tf.Session(config=sess_config) as sess:
        # Restore variables from disk.
        saver.restore(sess, full_save_path_to_model)

        for epoch in range(10):
            print("num_batches = ", len(batches))

            random.shuffle(batches)

            for i, batch in enumerate(batches):
                feed_dict = {
                    model.src_word_ids: batch['src_word_ids'],
                    model.tgt_word_ids: batch['tgt_word_ids'],
                    model.src_sentence_lengths: batch['src_sentence_lengths'],
                    model.tgt_sentence_lengths: batch['tgt_sentence_lengths'],
                    model.dropout: config['dropout']
                }

                sess.run([model.adapt_op], feed_dict=feed_dict)

                if i % 100 == 0:
                    # Report the loss on the current batch. This is a
                    # separate run, so it reflects the weights after the
                    # adapt step above.
                    if config['decoding_method'] != 'beamsearch':
                        train_loss, infer_loss = sess.run(
                            [model.train_loss, model.infer_loss],
                            feed_dict=feed_dict)
                        print(
                            "batch: {} --- train_loss: {:.5f} | inf_loss: {:.5f}"
                            .format(i, train_loss, infer_loss))

                    else:
                        # --- beam search --- #
                        # model.infer_loss is not fetched in this mode.
                        [train_loss] = sess.run([model.train_loss],
                                                feed_dict=feed_dict)
                        print("BEAMSEARCH - batch: {} --- train_loss: {:.5f}".
                              format(i, train_loss))

                    sys.stdout.flush()

            # Per-epoch bookkeeping inside the model (semantics are defined
            # by EncoderDecoder.increment_counter).
            model.increment_counter()
            print("################## EPOCH {} done ##################".format(
                epoch))
            saver.save(sess, new_save_path + '/model', global_step=epoch)