def train(num_layers, length, rate, batch, epochs, load_path, save_path, preproc_dir):
    """Train a MusicTransformerDecoder on preprocessed data, logging to TensorBoard.

    On batch 0 and every 100th batch of each epoch the model is evaluated on
    an ``"eval"`` batch, optionally saved, and train/eval loss and accuracy
    plus per-layer attention images are written under
    ``logs/mt_decoder/<timestamp>/{train,eval}``.

    Args:
        num_layers: Number of decoder layers (passed as ``num_layer``).
        length: Sequence length; used as the model's ``max_seq`` and as the
            slice length requested from ``Data.slide_seq2seq_batch``.
        rate: Learning rate (or schedule) for Adam. ``None`` selects
            ``callback.CustomSchedule(par.embedding_dim)``.
        batch: Batch size for both training and evaluation batches.
        epochs: Number of passes; each pass runs
            ``len(dataset.files) // batch`` training steps.
        load_path: Forwarded to the model as ``loader_path``.
        save_path: Where to save the model at each evaluation; ``None``
            disables saving.
        preproc_dir: Directory handed to ``Data`` (wrapped in ``Path``).
    """
    # NOTE(review): a second ``def train`` appears later in this module and
    # rebinds the name at import time; confirm which one callers should get.
    if rate is None:
        rate = callback.CustomSchedule(par.embedding_dim)
    preproc_dir = Path(preproc_dir)

    # NOTE(review): the LR schedule above is keyed on par.embedding_dim while
    # the model hard-codes embedding_dim=256; confirm the two agree.
    model = MusicTransformerDecoder(
        embedding_dim=256,
        vocab_size=par.vocab_size,
        num_layer=num_layers,
        max_seq=length,
        dropout=0.2,
        debug=False,
        loader_path=load_path,
    )
    model.compile(
        optimizer=Adam(rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9),
        loss=callback.transformer_dist_train_loss,
    )

    # ':' is a reserved filename character on Windows, so the run directory
    # name must be colon-free for the summary writers to be creatable there.
    run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_summary_writer = tf.summary.create_file_writer(
        f"logs/mt_decoder/{run_stamp}/train")
    eval_summary_writer = tf.summary.create_file_writer(
        f"logs/mt_decoder/{run_stamp}/eval")

    dataset = Data(preproc_dir)

    idx = 0  # TensorBoard scalar step; advanced once per evaluation
    with click.progressbar(length=epochs) as epoch_bar:
        for e in epoch_bar:
            model.reset_metrics()
            with click.progressbar(length=len(dataset.files) // batch) as batch_bar:
                for b in batch_bar:
                    batch_x, batch_y = dataset.slide_seq2seq_batch(batch, length)
                    loss, acc = model.train_on_batch(batch_x, batch_y)

                    if b % 100 != 0:
                        continue

                    eval_x, eval_y = dataset.slide_seq2seq_batch(batch, length, "eval")
                    (eloss, eacc), weights = model.evaluate(eval_x, eval_y)
                    if save_path is not None:
                        model.save(save_path)

                    with train_summary_writer.as_default():
                        if b == 0:
                            # Input/target histograms once per epoch.
                            tf.summary.histogram("target_analysis", batch_y, step=e)
                            tf.summary.histogram("source_analysis", batch_x, step=e)
                        tf.summary.scalar("loss", loss, step=idx)
                        tf.summary.scalar("accuracy", acc, step=idx)

                    with eval_summary_writer.as_default():
                        if b == 0:
                            model.sanity_check(eval_x, eval_y, step=e)
                        tf.summary.scalar("loss", eloss, step=idx)
                        tf.summary.scalar("accuracy", eacc, step=idx)
                        for i, weight in enumerate(weights):
                            with tf.name_scope("layer_%d" % i):
                                with tf.name_scope("w"):
                                    utils.attention_image_summary(weight, step=idx)

                    print(
                        f"Loss: {loss:6.6} (e: {eloss:6.6}), Accuracy: {acc} (e: {eacc})"
                    )
                    idx += 1
# NOTE(review): orphaned fragment. This line looks like a whitespace-collapsed
# copy of a training-loop body. It is not inside any visible `def`. It uses
# `mt`, `b`, `e`, `idx`, `dataset`, `batch_size`, `max_seq`, `save_path`,
# `batch_x`, `batch_y` and the summary writers, none of which are module-level
# names in the visible source. Written on one line it is also not valid Python
# (`x = ... if cond: ...`). Confirm whether it is a paste artifact to delete.
result_metrics = mt.train_on_batch(batch_x, batch_y) if b % 100 == 0: eval_x, eval_y = dataset.slide_seq2seq_batch(batch_size, max_seq, 'eval') eval_result_metrics, weights = mt.evaluate(eval_x, eval_y) mt.save(save_path) with train_summary_writer.as_default(): if b == 0: tf.summary.histogram("target_analysis", batch_y, step=e) tf.summary.histogram("source_analysis", batch_x, step=e) tf.summary.scalar('loss', result_metrics[0], step=idx) tf.summary.scalar('accuracy', result_metrics[1], step=idx) with eval_summary_writer.as_default(): if b == 0: mt.sanity_check(eval_x, eval_y, step=e) tf.summary.scalar('loss', eval_result_metrics[0], step=idx) tf.summary.scalar('accuracy', eval_result_metrics[1], step=idx) for i, weight in enumerate(weights): with tf.name_scope("layer_%d" % i): with tf.name_scope("w"): utils.attention_image_summary(weight, step=idx) # for i, weight in enumerate(weights): # with tf.name_scope("layer_%d" % i): # with tf.name_scope("_w0"): # utils.attention_image_summary(weight[0]) # with tf.name_scope("_w1"): # utils.attention_image_summary(weight[1]) idx += 1 print('\n====================================================')
def train(input_path, save_path, l_r=None, batch_size=2, max_seq=1024, epochs=100,
          load_path=None, num_layer=6, log_dir='/pfs/out/logs'):
    """Train a MusicTransformerDecoder and log progress to TensorBoard.

    Training batches are drawn with ``dataset.slide_seq2seq_batch``. A batch
    whose sampling raises is skipped, and the number of skipped batches is
    reported at the end of each epoch. On batch 0 and every 50th batch of each
    epoch, the function does the following:

    * evaluates the model on an ``'eval'`` batch;
    * saves it to ``save_path``;
    * writes train/eval loss and accuracy plus per-layer attention images
      under ``<log_dir>/<timestamp>/{train,eval}``.

    Args:
        input_path: Passed to ``Data`` to load the dataset.
        save_path: Destination for ``mt.save`` at every evaluation step.
        l_r: Learning rate (or schedule) for Adam; ``None`` selects
            ``callback.CustomSchedule(par.embedding_dim)``.
        batch_size: Sequences per training and evaluation batch.
        max_seq: Sequence length for the model and for sampled batches.
        epochs: Number of passes; each runs ``len(dataset.files) // batch_size``
            steps.
        load_path: Forwarded to the model as ``loader_path``.
        num_layer: Number of decoder layers.
        log_dir: Root directory for the TensorBoard summary writers.
    """
    # NOTE(review): an earlier ``def train`` in this module is shadowed by
    # this one at import time; confirm that is intended.

    # load data
    dataset = Data(input_path)
    print('dataset', dataset)

    # load model
    learning_rate = callback.CustomSchedule(par.embedding_dim) if l_r is None else l_r
    opt = Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

    # define model
    # NOTE(review): the LR schedule is keyed on par.embedding_dim while the
    # model hard-codes embedding_dim=256; confirm the two agree.
    mt = MusicTransformerDecoder(
        embedding_dim=256,
        vocab_size=par.vocab_size,
        num_layer=num_layer,
        max_seq=max_seq,
        dropout=0.2,
        debug=False,
        loader_path=load_path)
    mt.compile(optimizer=opt, loss=callback.transformer_dist_train_loss)

    # define tensorboard writer
    current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    train_log_dir = f'{log_dir}/{current_time}/train'
    eval_log_dir = f'{log_dir}/{current_time}/eval'
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    eval_summary_writer = tf.summary.create_file_writer(eval_log_dir)

    # Train Start
    idx = 0  # TensorBoard scalar step; advanced once per evaluation
    batchings = len(dataset.files) // batch_size
    how_often_to_print = 50
    for e in tqdm(range(epochs), desc='epochs'):
        mt.reset_metrics()
        skipped = 0
        last_error = None
        for b in tqdm(range(batchings), desc='batches'):
            try:
                batch_x, batch_y = dataset.slide_seq2seq_batch(batch_size, max_seq)
            except Exception as err:
                # Best-effort: skip a batch that cannot be sampled, but keep a
                # record of it. Catching Exception (not a bare except) lets
                # KeyboardInterrupt / SystemExit stop training as expected.
                skipped += 1
                last_error = err
                continue
            result_metrics = mt.train_on_batch(batch_x, batch_y)

            if b % how_often_to_print != 0:
                continue

            eval_x, eval_y = dataset.slide_seq2seq_batch(batch_size, max_seq, 'eval')
            eval_result_metrics, weights = mt.evaluate(eval_x, eval_y)
            mt.save(save_path)

            with train_summary_writer.as_default():
                if b == 0:
                    # Input/target histograms once per epoch.
                    tf.summary.histogram("target_analysis", batch_y, step=e)
                    tf.summary.histogram("source_analysis", batch_x, step=e)
                tf.summary.scalar('loss', result_metrics[0], step=idx)
                tf.summary.scalar('accuracy', result_metrics[1], step=idx)

            with eval_summary_writer.as_default():
                if b == 0:
                    mt.sanity_check(eval_x, eval_y, step=e)
                tf.summary.scalar('loss', eval_result_metrics[0], step=idx)
                tf.summary.scalar('accuracy', eval_result_metrics[1], step=idx)
                for i, weight in enumerate(weights):
                    with tf.name_scope("layer_%d" % i):
                        with tf.name_scope("w"):
                            utils.attention_image_summary(weight, step=idx)
            idx += 1

            print('\n====================================================')
            print('Epoch/Batch: {}/{}'.format(e, b))
            print('Train >>>> Loss: {:6.6}, Accuracy: {}'.format(
                result_metrics[0], result_metrics[1]))
            print('Eval >>>> Loss: {:6.6}, Accuracy: {}'.format(
                eval_result_metrics[0], eval_result_metrics[1]))

        if skipped:
            print(f'Epoch {e}: skipped {skipped} of {batchings} batches because '
                  f'sampling failed (last error: {last_error!r})')