def _main(_):
    # Data
    batch_size = config.batch_size
    memory_size = config.memory_size
    terminating_learning_rate = config.terminating_learning_rate
    data = prepare_data(FLAGS.data_path)
    vocab_size = data["vocab_size"]
    print('vocab_size = {}'.format(vocab_size))

    inputs = tf.placeholder(tf.int32, [None, memory_size], name="inputs")
    targets = tf.placeholder(tf.int32, [None], name="targets")

    # Model architecture
    initializer = tf.random_normal_initializer(
        stddev=config.initialize_stddev)
    with tf.variable_scope("model", initializer=initializer):
        memnet = tx.modules.MemNetRNNLike(raw_memory_dim=vocab_size,
                                          hparams=config.memnet)
        queries = tf.fill([tf.shape(inputs)[0], config.dim],
                          config.query_constant)
        logits = memnet(inputs, queries)

    # Losses & train ops
    mle_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=targets, logits=logits)
    mle_loss = tf.reduce_sum(mle_loss)

    # Use global_step to pass the epoch, for lr decay
    lr = config.opt["optimizer"]["kwargs"]["learning_rate"]
    learning_rate = tf.placeholder(tf.float32, [], name="learning_rate")
    global_step = tf.Variable(0, dtype=tf.int32, name="global_step")
    increment_global_step = tf.assign_add(global_step, 1)
    train_op = tx.core.get_train_op(mle_loss,
                                    learning_rate=learning_rate,
                                    global_step=global_step,
                                    increment_global_step=False,
                                    hparams=config.opt)

    def _run_epoch(sess, data_iter, epoch, is_train=False):
        loss = 0.
        iters = 0

        fetches = {"mle_loss": mle_loss}
        if is_train:
            fetches["train_op"] = train_op

        mode = (tf.estimator.ModeKeys.TRAIN
                if is_train
                else tf.estimator.ModeKeys.EVAL)

        for _, (x, y) in enumerate(data_iter):
            batch_size = x.shape[0]  # the last batch may be smaller
            feed_dict = {
                inputs: x, targets: y, learning_rate: lr,
                tx.global_mode(): mode,
            }

            rets = sess.run(fetches, feed_dict)
            loss += rets["mle_loss"]
            iters += batch_size

        ppl = np.exp(loss / iters)
        return ppl

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        try:
            saver.restore(sess, "ckpt/model.ckpt")
            print('restored checkpoint.')
        except (tf.errors.OpError, ValueError):
            # No (valid) checkpoint found; train from scratch.
            print('restore checkpoint failed.')

        last_valid_ppl = None
        heuristic_lr_decay = (hasattr(config, 'heuristic_lr_decay')
                              and config.heuristic_lr_decay)

        while True:
            if lr < terminating_learning_rate:
                break
            epoch = sess.run(global_step)
            if epoch >= config.num_epochs:
                print('Too many epochs!')
                break
            print('epoch: {} learning_rate: {:.6f}'.format(epoch, lr))

            # Train
            train_data_iter = ptb_iterator(data["train_text_id"],
                                           batch_size, memory_size)
            train_ppl = _run_epoch(sess, train_data_iter, epoch,
                                   is_train=True)
            print("Train Perplexity: {:.3f}".format(train_ppl))
            sess.run(increment_global_step)

            # Checkpoint
            if epoch % 5 == 0:
                try:
                    saver.save(sess, "ckpt/model.ckpt")
                    print("saved checkpoint.")
                except tf.errors.OpError:
                    print("save checkpoint failed.")

            # Valid
            valid_data_iter = ptb_iterator(data["valid_text_id"],
                                           batch_size, memory_size)
            valid_ppl = _run_epoch(sess, valid_data_iter, epoch)
            print("Valid Perplexity: {:.3f}".format(valid_ppl))

            # Learning rate decay
            if last_valid_ppl:
                if heuristic_lr_decay:
                    # Decay lr in proportion to how much validation ppl
                    # exceeds the smoothed previous value.
                    if valid_ppl > last_valid_ppl * config.heuristic_threshold:
                        lr /= 1. + (valid_ppl / last_valid_ppl
                                    - config.heuristic_threshold) \
                            * config.heuristic_rate
                    last_valid_ppl = last_valid_ppl \
                        * (1 - config.heuristic_smooth_rate) \
                        + valid_ppl * config.heuristic_smooth_rate
                else:
                    if valid_ppl > last_valid_ppl:
                        lr /= config.learning_rate_anneal_factor
                    last_valid_ppl = valid_ppl
            else:
                last_valid_ppl = valid_ppl
            print("last_valid_ppl: {:.6f}".format(last_valid_ppl))

        epoch = sess.run(global_step)
        print('Terminate after epoch ', epoch)

        # Test
        test_data_iter = ptb_iterator(data["test_text_id"], 1, memory_size)
        test_ppl = _run_epoch(sess, test_data_iter, 0)
        print("Test Perplexity: {:.3f}".format(test_ppl))
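# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original): the script above references
# `tf`, `tx`, `np`, `FLAGS`, and `config` without defining them. A plausible
# module-level preamble is shown below; the flag names and the config-module
# loading are assumptions. The later scripts additionally need
# `import time` and `import horovod.tensorflow as hvd`.
# ---------------------------------------------------------------------------
import importlib

import numpy as np
import tensorflow as tf
import texar.tf as tx  # older Texar releases use `import texar as tx`

flags = tf.flags
flags.DEFINE_string("data_path", "./",
                    "Directory containing the PTB raw data files.")
flags.DEFINE_string("config", "config", "The config module to load.")
FLAGS = flags.FLAGS

config = importlib.import_module(FLAGS.config)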
def _main(_):
    # Data
    batch_size = config.batch_size
    num_steps = config.num_steps
    data = prepare_data(FLAGS.data_path)
    vocab_size = data["vocab_size"]

    inputs = tf.placeholder(tf.int32, [batch_size, num_steps])
    targets = tf.placeholder(tf.int32, [batch_size, num_steps])

    # Model architecture
    initializer = tf.random_uniform_initializer(
        -config.init_scale, config.init_scale)
    with tf.variable_scope("model", initializer=initializer):
        embedder = tx.modules.WordEmbedder(vocab_size=vocab_size,
                                           hparams=config.emb)
        emb_inputs = embedder(inputs)
        if config.keep_prob < 1:
            emb_inputs = tf.nn.dropout(
                emb_inputs, tx.utils.switch_dropout(config.keep_prob))

        decoder = tx.modules.BasicRNNDecoder(
            vocab_size=vocab_size, hparams={"rnn_cell": config.cell})
        initial_state = decoder.zero_state(batch_size, tf.float32)
        outputs, final_state, seq_lengths = decoder(
            decoding_strategy="train_greedy",
            impute_finished=True,
            inputs=emb_inputs,
            sequence_length=[num_steps] * batch_size,
            initial_state=initial_state)

    # Losses & train ops
    mle_loss = tx.losses.sequence_sparse_softmax_cross_entropy(
        labels=targets,
        logits=outputs.logits,
        sequence_length=seq_lengths)

    # Use global_step to pass the epoch, for lr decay
    global_step = tf.placeholder(tf.int32)
    train_op = tx.core.get_train_op(mle_loss,
                                    global_step=global_step,
                                    increment_global_step=False,
                                    hparams=config.opt)

    def _run_epoch(sess, data_iter, epoch, is_train=False, verbose=False):
        start_time = time.time()
        loss = 0.
        iters = 0
        state = sess.run(initial_state)

        fetches = {
            "mle_loss": mle_loss,
            "final_state": final_state,
        }
        if is_train:
            fetches["train_op"] = train_op
            epoch_size = (len(data["train_text_id"]) // batch_size - 1) \
                // num_steps

        mode = (tf.estimator.ModeKeys.TRAIN
                if is_train
                else tf.estimator.ModeKeys.EVAL)

        for step, (x, y) in enumerate(data_iter):
            feed_dict = {
                inputs: x, targets: y, global_step: epoch,
                tx.global_mode(): mode,
            }
            # Carry the RNN state across mini-batches (truncated BPTT).
            for i, (c, h) in enumerate(initial_state):
                feed_dict[c] = state[i].c
                feed_dict[h] = state[i].h

            rets = sess.run(fetches, feed_dict)
            loss += rets["mle_loss"]
            state = rets["final_state"]
            iters += num_steps

            ppl = np.exp(loss / iters)
            if verbose and is_train and step % (epoch_size // 10) == 10:
                print("%.3f perplexity: %.3f speed: %.0f wps" %
                      ((step + 1) * 1.0 / epoch_size, ppl,
                       iters * batch_size / (time.time() - start_time)))

        ppl = np.exp(loss / iters)
        return ppl

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        for epoch in range(config.num_epochs):
            # Train
            train_data_iter = ptb_iterator(data["train_text_id"],
                                           config.batch_size, num_steps)
            train_ppl = _run_epoch(sess, train_data_iter, epoch,
                                   is_train=True, verbose=True)
            print("Epoch: %d Train Perplexity: %.3f" % (epoch, train_ppl))
            # Valid
            valid_data_iter = ptb_iterator(data["valid_text_id"],
                                           config.batch_size, num_steps)
            valid_ppl = _run_epoch(sess, valid_data_iter, epoch)
            print("Epoch: %d Valid Perplexity: %.3f" % (epoch, valid_ppl))

        # Test
        test_data_iter = ptb_iterator(data["test_text_id"],
                                      batch_size, num_steps)
        test_ppl = _run_epoch(sess, test_data_iter, 0)
        print("Test Perplexity: %.3f" % test_ppl)
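# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original): a minimal `ptb_iterator`, modeled
# on the classic TensorFlow PTB reader, showing the (inputs, targets) batches
# the scripts consume. The actual helper may differ -- the memnet script
# passes `memory_size` in place of `num_steps` and expects 1-D targets, and
# the Horovod script adds an `is_train` argument.
# ---------------------------------------------------------------------------
def ptb_iterator(raw_data, batch_size, num_steps):
    """Yields (x, y) where y is x shifted one token to the left."""
    raw_data = np.array(raw_data, dtype=np.int32)
    batch_len = len(raw_data) // batch_size
    data = np.reshape(raw_data[:batch_size * batch_len],
                      [batch_size, batch_len])
    epoch_size = (batch_len - 1) // num_steps
    for i in range(epoch_size):
        x = data[:, i * num_steps:(i + 1) * num_steps]
        y = data[:, i * num_steps + 1:(i + 1) * num_steps + 1]
        yield x, y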
def _main(_):
    # Data
    tf.logging.set_verbosity(tf.logging.INFO)

    # 1. Initialize Horovod
    hvd.init()

    batch_size = config.batch_size
    num_steps = config.num_steps
    data = prepare_data(FLAGS.data_path)
    vocab_size = data["vocab_size"]

    inputs = tf.placeholder(tf.int32, [None, num_steps], name='inputs')
    targets = tf.placeholder(tf.int32, [None, num_steps], name='targets')

    # Model architecture
    initializer = tf.random_uniform_initializer(
        -config.init_scale, config.init_scale)
    with tf.variable_scope("model", initializer=initializer):
        embedder = tx.modules.WordEmbedder(vocab_size=vocab_size,
                                           hparams=config.emb)
        emb_inputs = embedder(inputs)
        if config.keep_prob < 1:
            emb_inputs = tf.nn.dropout(
                emb_inputs, tx.utils.switch_dropout(config.keep_prob))

        decoder = tx.modules.BasicRNNDecoder(
            vocab_size=vocab_size, hparams={"rnn_cell": config.cell})

        # In distributed training this dynamic _batch_size equals
        # batch_size // hvd.size(), since each mini-batch is split
        # across the participating GPUs.
        _batch_size = tf.shape(inputs)[0]
        initial_state = decoder.zero_state(_batch_size, tf.float32)
        seq_length = tf.broadcast_to([num_steps], (_batch_size,))
        outputs, final_state, seq_lengths = decoder(
            decoding_strategy="train_greedy",
            impute_finished=True,
            inputs=emb_inputs,
            sequence_length=seq_length,
            initial_state=initial_state)

    # Losses & train ops
    mle_loss = tx.losses.sequence_sparse_softmax_cross_entropy(
        labels=targets,
        logits=outputs.logits,
        sequence_length=seq_lengths)

    # Use global_step to pass the epoch, for lr decay
    global_step = tf.placeholder(tf.int32)
    opt = tx.core.get_optimizer(global_step=global_step, hparams=config.opt)

    # 2. Wrap the optimizer so gradients are averaged across workers
    opt = hvd.DistributedOptimizer(opt)

    train_op = tx.core.get_train_op(loss=mle_loss,
                                    optimizer=opt,
                                    global_step=global_step,
                                    learning_rate=None,
                                    increment_global_step=False,
                                    hparams=config.opt)

    def _run_epoch(sess, data_iter, epoch, is_train=False, verbose=False):
        start_time = time.time()
        loss = 0.
        iters = 0

        fetches = {
            "mle_loss": mle_loss,
            "final_state": final_state,
        }
        if is_train:
            fetches["train_op"] = train_op
            epoch_size = (len(data["train_text_id"]) // batch_size - 1) \
                // num_steps

        mode = (tf.estimator.ModeKeys.TRAIN
                if is_train
                else tf.estimator.ModeKeys.EVAL)

        for step, (x, y) in enumerate(data_iter):
            if step == 0:
                # The state shape depends on the actual batch size, so the
                # zero state is fetched with the first batch fed in.
                state = sess.run(initial_state, feed_dict={inputs: x})

            feed_dict = {
                inputs: x, targets: y, global_step: epoch,
                tx.global_mode(): mode,
            }
            # Carry the RNN state across mini-batches (truncated BPTT).
            for i, (c, h) in enumerate(initial_state):
                feed_dict[c] = state[i].c
                feed_dict[h] = state[i].h

            rets = sess.run(fetches, feed_dict)
            loss += rets["mle_loss"]
            state = rets["final_state"]
            iters += num_steps

            ppl = np.exp(loss / iters)
            if verbose and is_train and hvd.rank() == 0 \
                    and (step + 1) % (epoch_size // 10) == 0:
                tf.logging.info(
                    "%.3f perplexity: %.3f speed: %.0f wps" %
                    ((step + 1) * 1.0 / epoch_size, ppl,
                     iters * batch_size / (time.time() - start_time)))

        _elapsed_time = time.time() - start_time
        tf.logging.info("epoch time elapsed: %f" % _elapsed_time)

        ppl = np.exp(loss / iters)
        return ppl, _elapsed_time

    # 3. Broadcast the rank-0 initial variable values to all processes
    bcast = hvd.broadcast_global_variables(0)

    # 4. Pin each process to a single visible GPU
    session_config = tf.ConfigProto()
    session_config.gpu_options.visible_device_list = str(hvd.local_rank())

    with tf.Session(config=session_config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        # 5. Run the broadcast_global_variables node before training
        bcast.run()

        _times = []
        for epoch in range(config.num_epochs):
            # Train
            train_data_iter = ptb_iterator(data["train_text_id"],
                                           config.batch_size, num_steps,
                                           is_train=True)
            train_ppl, train_time = _run_epoch(
                sess, train_data_iter, epoch, is_train=True, verbose=True)
            _times.append(train_time)
            tf.logging.info("Epoch: %d Train Perplexity: %.3f"
                            % (epoch, train_ppl))

            # Valid in the rank-0 process only
            if hvd.rank() == 0:
                valid_data_iter = ptb_iterator(data["valid_text_id"],
                                               config.batch_size, num_steps)
                valid_ppl, _ = _run_epoch(sess, valid_data_iter, epoch)
                tf.logging.info("Epoch: %d Valid Perplexity: %.3f"
                                % (epoch, valid_ppl))

        tf.logging.info('train times: %s' % _times)
        tf.logging.info('average train time/epoch %f'
                        % np.mean(np.array(_times)))

        # Test in the rank-0 process only
        if hvd.rank() == 0:
            test_data_iter = ptb_iterator(data["test_text_id"],
                                          batch_size, num_steps)
            test_ppl, _ = _run_epoch(sess, test_data_iter, 0)
            tf.logging.info("Test Perplexity: %.3f" % test_ppl)
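# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original): a typical entry point. The script
# name in the launch command is an assumption.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # For the Horovod version, launch one process per GPU, e.g.:
    #   horovodrun -np 2 python lm_ptb_distributed.py --data_path=./data
    # (`mpirun` with the equivalent options also works.)
    tf.app.run(main=_main)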