Example #1
            eval_prediction, weights = single_mt.forward(eval_x)

            eval_metrics = metric_set(eval_prediction, eval_y)
            torch.save(single_mt.state_dict(),
                       args.model_dir + '/train-{}.pth'.format(e))
            if b == 0:
                train_summary_writer.add_histogram("target_analysis",
                                                   batch_y,
                                                   global_step=e)
                train_summary_writer.add_histogram("source_analysis",
                                                   batch_x,
                                                   global_step=e)
                for i, weight in enumerate(weights):
                    attn_log_name = "attn/layer-{}".format(i)
                    utils.attention_image_summary(attn_log_name,
                                                  weight,
                                                  step=idx,
                                                  writer=eval_summary_writer)

            eval_summary_writer.add_scalar('loss',
                                           eval_metrics['loss'],
                                           global_step=idx)
            eval_summary_writer.add_scalar('accuracy',
                                           eval_metrics['accuracy'],
                                           global_step=idx)
            eval_summary_writer.add_histogram("logits_bucket",
                                              eval_metrics['bucket'],
                                              global_step=idx)

            print('\n====================================================')
            print('Epoch/Batch: {}/{}'.format(e, b))
            print('Train >>>> Loss: {:6.6}, Accuracy: {}'.format(
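
Example #1 logs one attention image per decoder layer through utils.attention_image_summary, whose implementation is not shown on this page. A minimal PyTorch sketch of a compatible helper, assuming weight is an attention tensor of shape [batch, heads, q_len, k_len] and writer is a torch.utils.tensorboard.SummaryWriter (the actual helper may differ):

from torch.utils.tensorboard import SummaryWriter  # expected type of writer

def attention_image_summary(name, weight, step, writer):
    # Hypothetical sketch, not the repo's actual helper.
    # Average over the batch, then log each head as one grayscale image.
    heads = weight.mean(dim=0)  # [heads, q_len, k_len]
    for h in range(heads.size(0)):
        img = heads[h].unsqueeze(0)  # CHW layout with a single channel
        # Normalize to [0, 1] so add_image renders the full dynamic range.
        img = (img - img.min()) / (img.max() - img.min() + 1e-9)
        writer.add_image("{}/head-{}".format(name, h), img, global_step=step)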
Example #2
                if b == 0:
                    tf.summary.histogram("target_analysis", batch_y, step=e)
                    tf.summary.histogram("source_analysis", batch_x, step=e)

                tf.summary.scalar('loss', result_metrics[0], step=idx)
                tf.summary.scalar('accuracy', result_metrics[1], step=idx)

            with eval_summary_writer.as_default():
                if b == 0:
                    mt.sanity_check(eval_x, eval_y, step=e)

                tf.summary.scalar('loss', eval_result_metrics[0], step=idx)
                tf.summary.scalar('accuracy', eval_result_metrics[1], step=idx)
                for i, weight in enumerate(weights):
                    with tf.name_scope("layer_%d" % i):
                        with tf.name_scope("w"):
                            utils.attention_image_summary(weight, step=idx)
                # for i, weight in enumerate(weights):
                #     with tf.name_scope("layer_%d" % i):
                #         with tf.name_scope("_w0"):
                #             utils.attention_image_summary(weight[0])
                #         with tf.name_scope("_w1"):
                #             utils.attention_image_summary(weight[1])
            idx += 1
            print('\n====================================================')
            print('Epoch/Batch: {}/{}'.format(e, b))
            print('Train >>>> Loss: {:6.6}, Accuracy: {}'.format(result_metrics[0], result_metrics[1]))
            print('Eval >>>> Loss: {:6.6}, Accuracy: {}'.format(eval_result_metrics[0], eval_result_metrics[1]))
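
Example #2 onward is the TensorFlow 2 variant, where the helper is called inside an active default writer and a name scope. A minimal sketch under the same shape assumption ([batch, heads, q_len, k_len]); the actual utils.attention_image_summary may render heads differently:

import tensorflow as tf

def attention_image_summary(weight, step):
    # Hypothetical sketch. Fold the head axis into the batch axis and add
    # a channel axis so each head renders as one grayscale image.
    shape = tf.shape(weight)
    images = tf.reshape(weight, [-1, shape[2], shape[3], 1])
    tf.summary.image("attention", images, step=step, max_outputs=4)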


Example #3
def train(num_layers, length, rate, batch, epochs, load_path, save_path,
          preproc_dir):
    if rate is None:
        rate = callback.CustomSchedule(par.embedding_dim)
    preproc_dir = Path(preproc_dir)

    model = MusicTransformerDecoder(
        embedding_dim=256,
        vocab_size=par.vocab_size,
        num_layer=num_layers,
        max_seq=length,
        dropout=0.2,
        debug=False,
        loader_path=load_path,
    )
    model.compile(
        optimizer=Adam(rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9),
        loss=callback.transformer_dist_train_loss,
    )

    time = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
    train_summary_writer = tf.summary.create_file_writer(
        f"logs/mt_decoder/{time}/train")
    eval_summary_writer = tf.summary.create_file_writer(
        f"logs/mt_decoder/{time}/eval")

    dataset = Data(preproc_dir)

    idx = 0
    with click.progressbar(length=epochs) as prog:
        for e in prog:
            model.reset_metrics()
            with click.progressbar(length=len(dataset.files) //
                                   batch) as prog2:
                for b in prog2:
                    batch_x, batch_y = dataset.slide_seq2seq_batch(
                        batch, length)
                    loss, acc = model.train_on_batch(batch_x, batch_y)

                    if b % 100 == 0:
                        eval_x, eval_y = dataset.slide_seq2seq_batch(
                            batch, length, "eval")
                        (eloss, eacc), weights = model.evaluate(eval_x, eval_y)
                        if save_path is not None:
                            model.save(save_path)

                        with train_summary_writer.as_default():
                            if b == 0:
                                tf.summary.histogram("target_analysis",
                                                     batch_y,
                                                     step=e)
                                tf.summary.histogram("source_analysis",
                                                     batch_x,
                                                     step=e)

                            tf.summary.scalar("loss", loss, step=idx)
                            tf.summary.scalar("accuracy", acc, step=idx)

                        with eval_summary_writer.as_default():
                            if b == 0:
                                model.sanity_check(eval_x, eval_y, step=e)

                            tf.summary.scalar("loss", eloss, step=idx)
                            tf.summary.scalar("accuracy", eacc, step=idx)

                            for i, weight in enumerate(weights):
                                with tf.name_scope("layer_%d" % i):
                                    with tf.name_scope("w"):
                                        utils.attention_image_summary(weight,
                                                                      step=idx)

                        print(
                            f"Loss: {loss:6.6} (e: {eloss:6.6}), Accuracy: {acc} (e: {eacc})"
                        )
                        idx += 1
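
This example builds its learning rate with callback.CustomSchedule(par.embedding_dim), which is not shown here. Assuming it matches the identically named schedule from the TensorFlow Transformer tutorial (the warmup schedule of "Attention Is All You Need"), a sketch looks like:

import tensorflow as tf

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    # lr = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)                     # decay phase
        arg2 = step * (self.warmup_steps ** -1.5)      # linear warmup
        return tf.math.rsqrt(self.d_model) * tf.minimum(arg1, arg2)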
Example #4
        try:
            batch_x, batch_y = dataset.seq2seq_batch(batch_size, max_seq)
        except Exception:
            # skip batches the loader could not assemble
            continue
        result_metrics = mt.train_on_batch(batch_x, batch_y)

        if b % 100 == 0:
            eval_x, eval_y = dataset.seq2seq_batch(batch_size, max_seq, 'eval')
            eval_result_metrics, weights = mt.evaluate(eval_x, eval_y)
            mt.save(save_path)
            with train_summary_writer.as_default():
                tf.summary.scalar('loss', result_metrics[0], step=idx)
                tf.summary.scalar('accuracy', result_metrics[1], step=idx)
                for i, weight in enumerate(weights):
                    with tf.name_scope("layer_%d" % i):
                        with tf.name_scope("_w0"):
                            utils.attention_image_summary(weight[0])
                        with tf.name_scope("_w1"):
                            utils.attention_image_summary(weight[1])

            with eval_summary_writer.as_default():
                tf.summary.scalar('loss', eval_result_metrics[0], step=idx)
                tf.summary.scalar('accuracy', eval_result_metrics[1], step=idx)
            idx += 1
            print('\n====================================================')
            print('Epoch/Batch: {}/{}'.format(e, b))
            print('Train >>>> Loss: {:6.6}, Accuracy: {}'.format(
                result_metrics[0], result_metrics[1]))
            print('Eval >>>> Loss: {:6.6}, Accuracy: {}'.format(
                eval_result_metrics[0], eval_result_metrics[1]))
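
Note that this example, unlike the others, logs two tensors per layer (weight[0] and weight[1]) and passes no step argument. If the helper forwards the step to tf.summary.image, TF2 falls back to the default step, which must be set explicitly each iteration or the summary call fails; a guess at the companion call this loop would need:

import tensorflow as tf
# Assumption: set the global default step once per logging iteration so
# summaries written without an explicit step= still land at idx.
tf.summary.experimental.set_step(idx)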
Example #5
def train(input_path, save_path, l_r=None, batch_size=2,
          max_seq=1024, epochs=100,
          load_path=None, num_layer=6, log_dir='/pfs/out/logs'):
    # load data
    dataset = Data(input_path)
    print('dataset', dataset)


    # load model
    learning_rate = callback.CustomSchedule(par.embedding_dim) if l_r is None else l_r
    opt = Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)


    # define model
    mt = MusicTransformerDecoder(
                embedding_dim=256,
                vocab_size=par.vocab_size,
                num_layer=num_layer,
                max_seq=max_seq,
                dropout=0.2,
                debug=False, loader_path=load_path)
    mt.compile(optimizer=opt, loss=callback.transformer_dist_train_loss)


    # define tensorboard writer
    current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    train_log_dir = '{log_dir}/{time}/train'.format(log_dir=log_dir, time=current_time)
    eval_log_dir = '{log_dir}/{time}/eval'.format(log_dir=log_dir, time=current_time)
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    eval_summary_writer = tf.summary.create_file_writer(eval_log_dir)


    # Train Start
    idx = 0
    batchings = len(dataset.files) // batch_size
    how_often_to_print = 50
    for e in tqdm(range(epochs), desc='epochs'):
        mt.reset_metrics()
        for b in tqdm(range(batchings), desc='batches'):
            try:
                batch_x, batch_y = dataset.slide_seq2seq_batch(batch_size, max_seq)
            except Exception:
                # skip batches the loader could not assemble
                continue
            result_metrics = mt.train_on_batch(batch_x, batch_y)
            if b % how_often_to_print == 0:
                eval_x, eval_y = dataset.slide_seq2seq_batch(batch_size, max_seq, 'eval')
                eval_result_metrics, weights = mt.evaluate(eval_x, eval_y)
                mt.save(save_path)
                with train_summary_writer.as_default():
                    if b == 0:
                        tf.summary.histogram("target_analysis", batch_y, step=e)
                        tf.summary.histogram("source_analysis", batch_x, step=e)

                    tf.summary.scalar('loss', result_metrics[0], step=idx)
                    tf.summary.scalar('accuracy', result_metrics[1], step=idx)

                with eval_summary_writer.as_default():
                    if b == 0:
                        mt.sanity_check(eval_x, eval_y, step=e)

                    tf.summary.scalar('loss', eval_result_metrics[0], step=idx)
                    tf.summary.scalar('accuracy', eval_result_metrics[1], step=idx)
                    for i, weight in enumerate(weights):
                        with tf.name_scope("layer_%d" % i):
                            with tf.name_scope("w"):
                                utils.attention_image_summary(weight, step=idx)
                    # for i, weight in enumerate(weights):
                    #     with tf.name_scope("layer_%d" % i):
                    #         with tf.name_scope("_w0"):
                    #             utils.attention_image_summary(weight[0])
                    #         with tf.name_scope("_w1"):
                    #             utils.attention_image_summary(weight[1])
                idx += 1
                print('\n====================================================')
                print('Epoch/Batch: {}/{}'.format(e, b))
                print('Train >>>> Loss: {:6.6}, Accuracy: {}'.format(result_metrics[0], result_metrics[1]))
                print('Eval >>>> Loss: {:6.6}, Accuracy: {}'.format(eval_result_metrics[0], eval_result_metrics[1]))
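
The TF2 examples all draw batches with dataset.slide_seq2seq_batch, whose implementation is not shown. A hedged sketch, assuming it samples one random window per sequence and returns next-token-shifted targets (names and layout are guesses, not the repo's code):

import random
import numpy as np

def slide_seq2seq_batch(sequences, batch_size, max_seq):
    # Hypothetical sketch of the sliding-window batching scheme.
    xs, ys = [], []
    candidates = [s for s in sequences if len(s) > max_seq]  # long enough
    for _ in range(batch_size):
        seq = random.choice(candidates)
        i = random.randint(0, len(seq) - max_seq - 1)
        window = seq[i:i + max_seq + 1]
        xs.append(window[:-1])   # inputs
        ys.append(window[1:])    # targets, shifted one step ahead
    return np.array(xs), np.array(ys)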