def _train_one_epoch(sess, model, data, config, batch):
    """Run one pass over the training partition and return the updated batch counter.

    Args:
        sess: active ``tf.Session``.
        model: training ``Model`` instance whose ``get_loss`` is called per batch.
        data: training ``Dataset``.
        config: ``Config`` supplying ``batch_size``, ``max_same``, ``max_diff``
            and ``log_interval``.
        batch: running batch counter carried across epochs; it decides when the
            average loss is printed.

    Returns:
        The batch counter after this epoch.
    """
    losses = []
    for x, ts, same, diff in data.batch(config.batch_size, config.max_same,
                                        config.max_diff):
        # NOTE(review): the discarded first return value presumably comes from
        # the training op run alongside the loss -- confirm in Model.get_loss.
        _, loss = model.get_loss(sess, x, ts, same, diff)
        losses.append(loss)
        if batch % config.log_interval == 0:
            # Average over (at most) the last log_interval losses of this epoch.
            print("avg batch loss: %.4f" % np.mean(losses[-config.log_interval:]))
        batch += 1
    return batch


def _dev_average_precision(sess, model, data, batch_size):
    """Embed the whole dev partition and return its average precision.

    Returns ``None`` when the dev partition yields no batches, because
    ``np.concatenate`` cannot join an empty list.
    """
    embeddings, labels = [], []
    for x, ts, ids in data.batch(batch_size):
        embeddings.append(model.get_embeddings(sess, x, ts))
        labels.append(ids)
    if not embeddings:
        return None
    return average_precision(np.concatenate(embeddings), np.concatenate(labels))


def main(dataset_path='swbd_test2.npy'):
    """Train the model, reporting dev average precision and checkpointing every epoch.

    Restores the latest checkpoint in ``config.ckptdir`` if one exists. It then
    trains for epochs ``config.current_epoch`` .. ``config.num_epochs - 1``, saving
    a checkpoint tagged with the epoch number after each one.

    Args:
        dataset_path: ``.npy`` file containing the full dataset as a single
            pickled Python object (presumably a dict of partitions -- confirm
            against ``Dataset``).
    """
    config = Config()
    # The dataset is stored as a 0-d object array, so it must be unpickled.
    # NumPy >= 1.16.3 refuses unless allow_pickle=True. Unpickling can execute
    # arbitrary code: only load files from a trusted source.
    full_dataset = np.load(dataset_path, allow_pickle=True).item()
    train_data = Dataset(full_dataset, partition="train", config=config)
    dev_data = Dataset(full_dataset, partition="dev", config=config)

    # The dev model reuses the training model's variables (reuse=True), so it
    # must be built after it.
    train_model = Model(is_train=True, config=config, reuse=None)
    dev_model = Model(is_train=False, config=config, reuse=True)

    # Create the Saver after both graphs exist so it covers all their variables.
    saver = tf.train.Saver()
    proto = tf.ConfigProto(intra_op_parallelism_threads=2)
    with tf.Session(config=proto) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        ckpt = tf.train.get_checkpoint_state(config.ckptdir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("restored from %s" % ckpt.model_checkpoint_path)

        # Global batch counter: intentionally not reset between epochs.
        batch = 0
        for epoch in range(config.current_epoch, config.num_epochs):
            print("epoch: ", epoch)
            batch = _train_one_epoch(sess, train_model, train_data, config, batch)

            ap = _dev_average_precision(sess, dev_model, dev_data,
                                        config.batch_size)
            if ap is None:
                print("ap: skipped (empty dev partition)")
            else:
                print("ap: %.4f" % ap)

            saver.save(sess, path.join(config.ckptdir, "model"), global_step=epoch)