def edit(name=None):
    if request.method == 'GET':
        logging.info('loading transformer %s', name)
        if not name:
            return redirect(url_for('.new'))

        transformer = load(name)

        code = transformer.code
        editors = transformer.editors
    else:
        def get_editors():
            editors = request.form.get('editors', '').split(',')
            return User.get_by_username(editors)

        code = request.form['code']

        name = request.form.get('name')

        if name == 'new':
            name = None

        editors = get_editors()

        if name:
            Transformer.create_or_update(name, code, editors)
            return redirect(url_for('.list'))

    return render_template('transformer/create_or_edit.html',
                           name=name,
                           code=code,
                           editors=editors)
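
# A minimal sketch (an assumption for illustration, not taken from the original
# project) of how a view like the one above is typically wired to a Flask
# blueprint; url_for('.new') / url_for('.list') imply sibling 'new' and 'list'
# endpoints registered on the same blueprint:
from flask import Blueprint

transformer_bp = Blueprint('transformer', __name__, url_prefix='/transformer')
transformer_bp.add_url_rule('/edit/', 'edit', edit, defaults={'name': None},
                            methods=['GET', 'POST'])
transformer_bp.add_url_rule('/edit/<name>', 'edit', edit,
                            methods=['GET', 'POST'])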
Example #2
def train(args):
    vocab = data_utils.get_vocab(vocab_file=args.vocab_file, min_freq=args.min_vocab_freq)
    # vocab = {}
    # with open(args.vocab_file, mode='r') as infile:
    #     for line in infile:
    #         w, w_id = line.split('\t')
    #         vocab[w] = int(w_id)

    print('Vocab loaded...')
    print('VOCAB SIZE = ', len(vocab))

    if args.model_type == 'transformer':
        transformer = Transformer(args=args, vocab=vocab)
        transformer.train_generator()
    elif args.model_type == 'rnn':
        rnn_params = {'rec_cell': 'lstm',
                      'encoder_dim': 800,
                      'decoder_dim': 800,
                      'num_encoder_layers': 2,
                      'num_decoder_layers': 2}
        rnn = RNNSeq2Seq(args=args, rnn_params=rnn_params, vocab=vocab)
        # rnn.train()
        rnn.train_keras()
    elif args.model_type == 'han_rnn':
        han_rnn = HanRnnSeq2Seq(args=args, vocab=vocab)
        han_rnn.train()
    elif args.model_type == 'cnn':
        cnn = ConvSeq2Seq(args=args, vocab=vocab)
        cnn.train_keras()

    return
Example #3
    def valid_step(inp, targ):
        tar_inp = targ[:, :-1]
        tar_real = targ[:, 1:]
        end = tf.cast(
            tf.math.logical_not(
                tf.math.equal(tar_inp, tar_tokenizer.word_index['<end>'])),
            tf.int32)
        tar_inp *= end
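        # tar_inp drops the last target token and tar_real drops the first, so
        # at every position the model predicts the next token (teacher forcing).
        # The line above additionally replaces <end> in tar_inp with the padding
        # id 0, so the decoder is never fed <end> as context while tar_real
        # still keeps it as a prediction target.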
        # create mask
        enc_padding_mask = Transformer.create_padding_mask(inp)

        # mask for first attention block in decoder
        look_ahead_mask = Transformer.create_seq_mask(tf.shape(tar_inp)[1])
        dec_target_padding_mask = Transformer.create_padding_mask(tar_inp)
        combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

        # mask for "enc_dec" multihead attention
        dec_padding_mask = Transformer.create_padding_mask(inp)

        # feed input into encoder
        predictions = model(inp, tar_inp, False, enc_padding_mask,
                            combined_mask, dec_padding_mask)
        val_loss = loss_fn(tar_real, predictions)
        train_accuracy(tar_real, predictions)
        return val_loss
Example #4
    def train_step(inp, targ):
        tar_inp = targ[:, :-1]
        tar_real = targ[:, 1:]
        end = tf.cast(
            tf.math.logical_not(
                tf.math.equal(tar_inp, tar_tokenizer.word_index['<end>'])),
            tf.int32)
        tar_inp *= end
        # tf.print("tar inp", tar_inp)
        # tf.print("tar real", tar_real)
        # create mask
        enc_padding_mask = Transformer.create_padding_mask(inp)

        # mask for first attention block in decoder
        look_ahead_mask = Transformer.create_seq_mask(tf.shape(tar_inp)[1])
        dec_target_padding_mask = Transformer.create_padding_mask(tar_inp)
        combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

        # mask for "enc_dec" multihead attention
        dec_padding_mask = Transformer.create_padding_mask(inp)

        with tf.GradientTape() as tape:
            # feed input into encoder
            predictions = model(inp, tar_inp, True, enc_padding_mask,
                                combined_mask, dec_padding_mask)
            train_loss = loss_fn(tar_real, predictions)

            # optimize step
            gradients = tape.gradient(train_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients,
                                          model.trainable_variables))
        return train_loss
Example #5
def translate_batch(model, inp, batch_size, tar_tokenizer):
    batch_max_length = tf.shape(inp)[1]
    decoder_input = tf.expand_dims([tar_tokenizer.word_index['<start>']] *
                                   batch_size,
                                   axis=1)
    output = decoder_input

    for i in range(batch_max_length + 50):
        # create mask
        enc_padding_mask = Transformer.create_padding_mask(inp)
        # mask for first attention block in decoder
        look_ahead_mask = Transformer.create_seq_mask(tf.shape(output)[1])
        dec_target_padding_mask = Transformer.create_padding_mask(output)
        combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

        # mask for "enc_dec" multihead attention
        dec_padding_mask = Transformer.create_padding_mask(inp)

        # predictions.shape == (batch_size, seq_len, vocab_size)
        predictions = model(inp, output, False, enc_padding_mask,
                            combined_mask, dec_padding_mask)

        predictions = predictions[:, -1:, :]  # (batch_size, 1, vocab_size)
        predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)
        if (predicted_id == tar_tokenizer.word_index['<end>']).numpy().all():
            break
        output = tf.concat([output, predicted_id], axis=-1)
    pred_sentences = tar_tokenizer.sequences_to_texts(output.numpy())
    pred_sentences = [
        x.split("<end>")[0].replace("<start>", "").strip()
        for x in pred_sentences
    ]
    return pred_sentences
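# Example usage (a sketch: model, src_batch and tar_tokenizer stand for a
# trained Transformer, a padded batch of source token ids and the fitted target
# tokenizer, all produced elsewhere in these examples):
#
#   batch_size = tf.shape(src_batch)[0].numpy()
#   sentences = translate_batch(model, src_batch, batch_size, tar_tokenizer)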
def delete(name):
    query = Transformer.find(name)
    if query:
        query.delete()

    json_data = Transformer.dumps({'status': 'ok'})
    return Response(json_data,  mimetype='application/json')
    def valid_step(inp):
        # set target
        tar_inp = inp[:, :-2]
        tar_real = inp[:, 2:]
        # remember the padding
        pad = tf.cast(tf.math.logical_not(tf.math.equal(inp, 0)), tf.int32)
        en = tf.math.equal(inp, sp.piece_to_id('<En>'))
        fr = tf.math.equal(inp, sp.piece_to_id('<Fr>'))
        # token masking
        mask = tf.random.uniform(tf.shape(inp))
        mask = tf.math.less(mask, 0.2)
        mask = tf.math.logical_or(tf.math.logical_not(mask),
                                  tf.math.logical_or(en, fr))
        mask = tf.cast(mask, tf.int32)
        # [MASK] token index is 1
        inp = tf.math.maximum(inp * mask, 1)
        inp *= pad

        # create mask
        enc_padding_mask = Transformer.create_padding_mask(inp)

        # mask for first attention block in decoder
        look_ahead_mask = Transformer.create_seq_mask(tf.shape(tar_inp)[1])
        dec_target_padding_mask = Transformer.create_padding_mask(tar_inp)
        combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

        # mask for "enc_dec" multihead attention
        dec_padding_mask = Transformer.create_padding_mask(inp)

        # feed input into encoder
        predictions = model(inp, tar_inp, False, enc_padding_mask,
                            combined_mask, dec_padding_mask)
        val_loss = loss_fn(tar_real, predictions)
        train_accuracy(tar_real, predictions)
        return val_loss
    def __init__(self, config):
        self.config = config
        self.prepare_dataloaders(config['data'])

        # self.model = MLP(config['MLP'])
        # self.model = MLP_3D(config['MLP'])
        # self.model = LSTM(config['LSTM'])
        self.model = Transformer(config['Trans'])
        print(self.model)

        self.model_name = config['train']['model_name']

        self.checkpoint_dir = './checkpoint_dir/{}/'.format(self.model_name)
        if not os.path.exists(self.checkpoint_dir):
            os.mkdir(self.checkpoint_dir)
        self.tb_log_dir = './tb_log/{}/'.format(self.model_name)
        if not os.path.exists(self.tb_log_dir):
            os.mkdir(self.tb_log_dir)

        self.optimal_metric = 100000
        self.cur_metric = 100000

        self.loss = nn.MSELoss()
        self.optim = optim.Adam(self.model.parameters(),
                                lr=self.config['train']['lr'],
                                betas=(0.5, 0.999))
Example #9
def change_max_pos_embd(args, new_mpe_size, n_classes):
    config = TransformerConfig(size=args.transformer_size,
                               max_position_embeddings=new_mpe_size)
    if args.use_cnn:
        config.input_size = CnnConfig.output_dim
    model = Transformer(config=config, n_classes=n_classes)
    model = model.to(device)
    return model
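# Brief usage sketch (the argument values are assumptions for illustration;
# max_position_embeddings=256 and n_classes=263 appear elsewhere in these
# examples for the "include" dataset):
#
#   model = change_max_pos_embd(args, new_mpe_size=256, n_classes=263)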
Example #10
    def __init__(self):
        super(TMLU, self).__init__()

        self.global_step = 0

        # Data Pipeline
        data_pipeline = Preprocess(cfg)
        self.train_dataset, self.val_dataset = data_pipeline.get_data()

        self.tokenizer_pt = data_pipeline.tokenizer_pt
        self.tokenizer_en = data_pipeline.tokenizer_en

        cfg.input_vocab_size = self.tokenizer_pt.vocab_size + 2
        cfg.target_vocab_size = self.tokenizer_en.vocab_size + 2

        # Model
        self.transformer = Transformer(cfg)

        # Optimizer
        learning_rate = CustomSchedule(cfg.d_model)
        self.optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

        # Loss and Metrics
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

        # Build writers for logging
        self.build_writers()

        checkpoint_path = "./checkpoints/train"
        self.ckpt = tf.train.Checkpoint(transformer=self.transformer, optimizer=self.optimizer)
        self.ckpt_manager = tf.train.CheckpointManager(self.ckpt, checkpoint_path, max_to_keep=5)
Example #11
def evaluate(args):
    label_map = load_label_map(args.dataset)
    n_classes = 50
    if args.dataset == "include":
        n_classes = 263

    if args.use_cnn:
        dataset = FeaturesDatset(
            features_dir=os.path.join(args.data_dir,
                                      f"{args.dataset}_test_features"),
            label_map=label_map,
            mode="test",
        )

    else:
        dataset = KeypointsDataset(
            keypoints_dir=os.path.join(args.data_dir,
                                       f"{args.dataset}_test_keypoints"),
            use_augs=False,
            label_map=label_map,
            mode="test",
            max_frame_len=169,
        )

    dataloader = data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )

    if args.model == "lstm":
        config = LstmConfig()
        if args.use_cnn:
            config.input_size = CnnConfig.output_dim
        model = LSTM(config=config, n_classes=n_classes)
    else:
        config = TransformerConfig(size=args.transformer_size)
        if args.use_cnn:
            config.input_size = CnnConfig.output_dim
        model = Transformer(config=config, n_classes=n_classes)

    model = model.to(device)

    if args.use_pretrained == "evaluate":
        model, _, _ = load_pretrained(args, n_classes, model)
        print("### Model loaded ###")

    else:
        exp_name = get_experiment_name(args)
        model_path = os.path.join(args.save_path, exp_name) + ".pth"
        ckpt = torch.load(model_path)
        model.load_state_dict(ckpt["model"])
        print("### Model loaded ###")

    test_loss, test_acc = validate(dataloader, model, device)
    print("Evaluation Results:")
    print(f"Loss: {test_loss}, Accuracy: {test_acc}")
def load(name):
    logging.info('load transformer %s', name)

    transformer = Transformer.find(name)

    if not transformer:
        return None

    return transformer
    def train_step(inp):
        # set target
        tar_inp = inp[:, :-2]
        tar_real = inp[:, 2:]
        # remember the padding
        pad = tf.cast(tf.math.logical_not(tf.math.equal(inp, 0)), tf.int32)
        en = tf.math.equal(inp, sp.piece_to_id('<En>'))
        fr = tf.math.equal(inp, sp.piece_to_id('<Fr>'))
        # token masking
        mask = tf.random.uniform(tf.shape(inp))
        mask = tf.math.less(mask, 0.2)
        mask = tf.math.logical_or(tf.math.logical_not(mask),
                                  tf.math.logical_or(en, fr))
        mask = tf.cast(mask, tf.int32)
        # [MASK] token index is 1
        inp = tf.math.maximum(inp * mask, 1)
        inp *= pad
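        # At this point roughly 20% of the ordinary tokens in inp have been
        # replaced by the [MASK] id 1; padding (id 0) and the <En>/<Fr> language
        # tags are never masked. tar_inp / tar_real were sliced from inp before
        # the corruption, so the decoder learns to reconstruct the original
        # sequence from the masked encoder input (an mBART-style denoising
        # objective).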

        # tf.print("tar inp", tar_inp)
        # tf.print("tar real", tar_real)
        # create mask
        enc_padding_mask = Transformer.create_padding_mask(inp)

        # mask for first attention block in decoder
        look_ahead_mask = Transformer.create_seq_mask(tf.shape(tar_inp)[1])
        dec_target_padding_mask = Transformer.create_padding_mask(tar_inp)
        combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

        # mask for "enc_dec" multihead attention
        dec_padding_mask = Transformer.create_padding_mask(inp)

        with tf.GradientTape() as tape:
            # feed input into encoder
            predictions = model(inp, tar_inp, True, enc_padding_mask,
                                combined_mask, dec_padding_mask)
            train_loss = loss_fn(tar_real, predictions)

            # optimize step
            gradients = tape.gradient(train_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients,
                                          model.trainable_variables))
        return train_loss
    def valid_step(inp, targ):
        tar_inp = targ[:, :-1]
        tar_real = targ[:, 1:]

        # create mask
        enc_padding_mask = Transformer.create_padding_mask(inp)

        # mask for first attention block in decoder
        look_ahead_mask = Transformer.create_seq_mask(tf.shape(tar_inp)[1])
        dec_target_padding_mask = Transformer.create_padding_mask(tar_inp)
        combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

        # mask for "enc_dec" multihead attention
        dec_padding_mask = Transformer.create_padding_mask(inp)

        # feed input into encoder
        predictions, att = model(inp, tar_inp, False, enc_padding_mask, combined_mask, dec_padding_mask)
        val_loss = loss_fn(tar_real, predictions)
        train_accuracy(tar_real, predictions)
        return val_loss
Example #15
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input",
                        default=None,
                        type=str,
                        help="run demo chatbot")

    args = parser.parse_args()

    input_sentence = args.input

    tokenizer = tfds.features.text.SubwordTextEncoder.load_from_file(
        vocab_filename)
    # Vocabulary size plus start and end token
    VOCAB_SIZE = tokenizer.vocab_size + 2

    model = Transformer(num_layers=NUM_LAYERS,
                        units=UNITS,
                        d_model=D_MODEL,
                        num_heads=NUM_HEADS,
                        vocab_size=VOCAB_SIZE,
                        dropout=DROPOUT,
                        name='transformer')
    demo_sentence = 'How are you'
    predict(demo_sentence, tokenizer, model, True)

    model.load_weights(save_weight_path)

    model.summary()
    tf.keras.utils.plot_model(model,
                              to_file='transformer.png',
                              show_shapes=True)

    predict(input_sentence, tokenizer, model)
Example #16
def main():
    Place = paddle.fluid.CUDAPlace(0)
    with fluid.dygraph.guard(Place):
        model = Transformer(image_size=512,
                            num_classes=15,
                            hidden_unit_num=1024,
                            layer_num=2,
                            head_num=16,
                            dropout=0.8,
                            decoder_name='PUP',
                            hyber=True,
                            visualable=False)
        preprocess = Transform(512)
        dataloader_1 = Dataloader('/home/aistudio/dataset',
                                  '/home/aistudio/dataset/val_list.txt',
                                  transform=preprocess,
                                  shuffle=True)
        val_load = fluid.io.DataLoader.from_generator(capacity=1,
                                                      use_multiprocess=False)
        val_load.set_sample_generator(dataloader_1, batch_size=1, places=Place)
        model_dic, optic_dic = load_dygraph(
            "./output/SETR-NotZero-Epoch-2-Loss-0.161517-MIOU-0.325002")
        model.load_dict(model_dic)
        model.eval()
        '''result = get_infer_data("/home/aistudio/dataset/infer")
        infer_load  = Load_infer('/home/aistudio/dataset', result, transform=preprocess, shuffle=False)
        loader_infer= fluid.io.DataLoader.from_generator(capacity=1, use_multiprocess=False)
        loader_infer.set_sample_generator(infer_load, batch_size=1, places=Place)
        process_image(model, loader_infer, result)'''
        validation(val_load, model, 15)
Example #17
def create_model_transformer_b2a(arg, devices_list, eval=False):
    from models import Transformer
    resume_dataset = arg.eval_dataset_transformer if eval else arg.dataset
    resume_b = arg.eval_split_source_trasformer if eval else arg.split_source
    resume_a = arg.eval_split_trasformer if eval else arg.split
    resume_epoch = arg.eval_epoch_transformer if eval else arg.resume_epoch

    transformer = Transformer(in_channels=boundary_num,
                              out_channels=boundary_num)

    if resume_epoch > 0:
        load_path = arg.resume_folder + 'transformer_' + resume_dataset + '_' + resume_b + '2' + resume_a + '_' + str(
            resume_epoch) + '.pth'
        print('Loading Transformer from ' + load_path)
        transformer = load_weights(transformer, load_path, devices_list[0])
    else:
        init_weights(transformer, init_type='transformer')
        # init_weights(transformer)

    if arg.cuda:
        transformer = transformer.cuda(device=devices_list[0])

    return transformer
Example #18
def train(args):
    src_root = args.src_root
    sr = args.sample_rate
    dt = args.delta_time
    batch_size = args.batch_size
    model_type = args.model_type
    params = {'N_CLASSES':len(os.listdir(args.src_root)),
              'SR':sr,
              'DT':dt}
    models = {'conv1d':Conv1D(**params),
              'conv2d':Conv2D(**params),
              'lstm':  LSTM(**params),
              'transformer': Transformer(**params),
              'ViT': ViT(**params)}

    assert model_type in models.keys(), '{} not an available model'.format(model_type)
    csv_path = os.path.join('logs', '{}_history.csv'.format(model_type))

    wav_paths = glob('{}/**'.format(src_root), recursive=True)
    wav_paths = [x.replace(os.sep, '/') for x in wav_paths if '.wav' in x]
    classes = sorted(os.listdir(args.src_root))
    le = LabelEncoder()
    le.fit(classes)
    labels = [os.path.split(x)[0].split('/')[-1] for x in wav_paths]
    labels = le.transform(labels)
    wav_train, wav_val, label_train, label_val = train_test_split(wav_paths,
                                                                  labels,
                                                                  test_size=0.1,
                                                                  random_state=0)

    assert len(label_train) >= args.batch_size, 'Number of train samples must be >= batch_size'
    if len(set(label_train)) != params['N_CLASSES']:
        warnings.warn('Found {}/{} classes in training data. Increase data size or change random_state.'.format(len(set(label_train)), params['N_CLASSES']))
    if len(set(label_val)) != params['N_CLASSES']:
        warnings.warn('Found {}/{} classes in validation data. Increase data size or change random_state.'.format(len(set(label_val)), params['N_CLASSES']))

    tg = DataGenerator(wav_train, label_train, sr, dt,
                       params['N_CLASSES'], batch_size=batch_size)
    vg = DataGenerator(wav_val, label_val, sr, dt,
                       params['N_CLASSES'], batch_size=batch_size)
    model = models[model_type]
    model.summary()
    cp = ModelCheckpoint('models/{}.h5'.format(model_type), monitor='val_loss',
                         save_best_only=True, save_weights_only=False,
                         mode='auto', save_freq='epoch', verbose=1)
    csv_logger = CSVLogger(csv_path, append=False)
    model.fit(tg, validation_data=vg,
              epochs=40, verbose=1,
              callbacks=[csv_logger])
Example #19
    def __init__(self,
                 voc_size_src,
                 voc_size_tar,
                 max_pe,
                 num_encoders,
                 num_decoders,
                 emb_size,
                 num_head,
                 ff_inner=2048,
                 p_dropout=0.1):
        super(mBART, self).__init__()
        self.transformer = Transformer.Transformer(
            voc_size_src, voc_size_tar, max_pe, num_encoders, num_decoders,
            emb_size, num_head, ff_inner, p_dropout)
def list():
    return render_template('transformer/list.html', transformers=Transformer.all())
def generate_predictions(ckpt, path_src, path_tar, input_file_path: str,
                         pred_file_path: str, num_sync):
    """Generates predictions for the machine translation task (EN->FR).
    You are allowed to modify this function as needed, but once again, you cannot
    modify any other part of this file. We will be importing only this function
    in our final evaluation script. Since you will most definitely need to import
    modules for your code, you must import these inside the function itself.
    Args:
        input_file_path: the file path that contains the input data.
        pred_file_path: the file path where to store the predictions.
    Returns: None
    """
    # load input file => create test dataloader => (spm encode)
    from data.dataloaders import prepare_test
    BATCH_SIZE = 128
    TOTAL_ITER = int(num_sync / BATCH_SIZE)

    # load  tokenizer of train to tokenize test data
    import pickle
    f_src = open(path_src, 'rb')
    f_tar = open(path_tar, 'rb')
    src_tokenizer = pickle.load(f_src)
    tar_tokenizer = pickle.load(f_tar)

    test_dataset, test_max_length = prepare_test(input_file_path,
                                                 src_tokenizer,
                                                 batch_size=BATCH_SIZE)
    # create model
    from models import Transformer
    import tensorflow as tf
    src_vocsize = len(src_tokenizer.word_index) + 1
    tar_vocsize = len(tar_tokenizer.word_index) + 1
    # create model instance
    optimizer = tf.keras.optimizers.Adam()
    model = Transformer.Transformer(voc_size_src=src_vocsize,
                                    voc_size_tar=tar_vocsize,
                                    max_pe=10000,
                                    num_encoders=4,
                                    num_decoders=4,
                                    emb_size=512,
                                    num_head=8,
                                    ff_inner=1024)

    # Load CheckPoint
    ckpt_dir = ckpt
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    status = checkpoint.restore(tf.train.latest_checkpoint(ckpt_dir))
    status.assert_existing_objects_matched()

    # Greedy Search / Beam Search and write to pred_file_path
    import time
    from translate import translate_batch
    start = time.time()
    count = 0
    with open(pred_file_path, 'w', encoding='utf-8') as pred_file:
        for (batch, (inp)) in enumerate(test_dataset):
            print("Evaluating Batch: %s" % batch)
            batch_size = tf.shape(inp)[0].numpy()
            translation = translate_batch(model, inp, batch_size,
                                          tar_tokenizer)
            for sentence in translation:
                pred_file.write(sentence.strip() + '\n')
                pred_file.flush()
            count += 1
            if count > TOTAL_ITER:
                break
    end = time.time()
    print("Translation finish in %s s" % (end - start))
Example #22
def fit(args):
    exp_name = get_experiment_name(args)
    logging_path = os.path.join(args.save_path, exp_name) + ".log"
    logging.basicConfig(filename=logging_path,
                        level=logging.INFO,
                        format="%(message)s")
    seed_everything(args.seed)
    label_map = load_label_map(args.dataset)

    if args.use_cnn:
        train_dataset = FeaturesDatset(
            features_dir=os.path.join(args.data_dir,
                                      f"{args.dataset}_train_features"),
            label_map=label_map,
            mode="train",
        )
        val_dataset = FeaturesDatset(
            features_dir=os.path.join(args.data_dir,
                                      f"{args.dataset}_val_features"),
            label_map=label_map,
            mode="val",
        )

    else:
        train_dataset = KeypointsDataset(
            keypoints_dir=os.path.join(args.data_dir,
                                       f"{args.dataset}_train_keypoints"),
            use_augs=args.use_augs,
            label_map=label_map,
            mode="train",
            max_frame_len=169,
        )
        val_dataset = KeypointsDataset(
            keypoints_dir=os.path.join(args.data_dir,
                                       f"{args.dataset}_val_keypoints"),
            use_augs=False,
            label_map=label_map,
            mode="val",
            max_frame_len=169,
        )

    train_dataloader = data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )
    val_dataloader = data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )

    n_classes = 50
    if args.dataset == "include":
        n_classes = 263

    if args.model == "lstm":
        config = LstmConfig()
        if args.use_cnn:
            config.input_size = CnnConfig.output_dim
        model = LSTM(config=config, n_classes=n_classes)
    else:
        config = TransformerConfig(size=args.transformer_size)
        if args.use_cnn:
            config.input_size = CnnConfig.output_dim
        model = Transformer(config=config, n_classes=n_classes)

    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=args.learning_rate,
                                  weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode="max",
                                                           factor=0.2)

    if args.use_pretrained == "resume_training":
        model, optimizer, scheduler = load_pretrained(args, n_classes, model,
                                                      optimizer, scheduler)

    model_path = os.path.join(args.save_path, exp_name) + ".pth"
    es = EarlyStopping(patience=15, mode="max")
    for epoch in range(args.epochs):
        print(f"Epoch: {epoch+1}/{args.epochs}")
        train_loss, train_acc = train(train_dataloader, model, optimizer,
                                      device)
        val_loss, val_acc = validate(val_dataloader, model, device)
        logging.info(
            "Epoch: {}, train loss: {}, train acc: {}, val loss: {}, val acc: {}"
            .format(epoch + 1, train_loss, train_acc, val_loss, val_acc))
        scheduler.step(val_acc)
        es(
            model_path=model_path,
            epoch_score=val_acc,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
        )
        if es.early_stop:
            print("Early stopping")
            break

    print("### Training Complete ###")
Example #23
                        connection.url.startswith('http://') or \
                        connection.url.startswith('https://'):
                    description, data, columns_order = query_google_data_source(
                        connection,
                        sql,
                        fields_meta_vars,
                        query_vars)
                else:
                    description, data, columns_order = query_execute_sql(
                        connection,
                        sql,
                        fields_meta_vars,
                        query_vars)

                if transform:
                    data = Transformer.execute(transform, data, True)

                if is_format_request('json'):
                    json_data = json.dumps(data, default=handle_datetime)
                elif len(data) > 0:
                    data_table = data_to_datatable(description, data)
                    json_data = data_table.ToJSon(columns_order=columns_order)

            if not json_data:
                json_data = json.dumps([])
                data_table = None

        except Exception, ex:
            logging.exception("Failed to execute query %s", ex)
            error = str(ex)
Example #24
def main():
    questions, answers = load_conversations()
    # Build tokenizer using tfds for both questions and answers
    tokenizer = tfds.features.text.SubwordTextEncoder.build_from_corpus(
        questions + answers, target_vocab_size=2**13)

    tokenizer.save_to_file(vocab_filename)

    # Vocabulary size plus start and end token
    VOCAB_SIZE = tokenizer.vocab_size + 2

    questions, answers = tokenize_and_filter(questions, answers, tokenizer)
    print('Vocab size: {}'.format(VOCAB_SIZE))
    print('Number of samples: {}'.format(len(questions)))
    # decoder inputs use the previous target as input
    # remove START_TOKEN from targets
    dataset = tf.data.Dataset.from_tensor_slices((
        {
            'inputs': questions
        },
        {
            'outputs': answers
        },
    ))

    dataset = dataset.cache()
    dataset = dataset.shuffle(BUFFER_SIZE)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    print(dataset)

    model = Transformer(num_layers=NUM_LAYERS,
                        units=UNITS,
                        d_model=D_MODEL,
                        num_heads=NUM_HEADS,
                        vocab_size=VOCAB_SIZE,
                        dropout=DROPOUT,
                        name='transformer')

    learning_rate = CustomSchedule(D_MODEL)

    optimizer = tf.keras.optimizers.Adam(learning_rate,
                                         beta_1=0.9,
                                         beta_2=0.98,
                                         epsilon=1e-9)

    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)

    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              checkpoint_path,
                                              max_to_keep=5)

    # if a checkpoint exists, restore the latest checkpoint.
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        print('Latest checkpoint restored!!')

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='train_accuracy')

    for epoch in range(EPOCHS):
        start = time.time()

        train_loss.reset_states()
        train_accuracy.reset_states()

        for (batch, (inp, tar)) in enumerate(dataset):

            train_step(inp, tar, model, optimizer, train_loss, train_accuracy)

            if batch % 500 == 0:
                print('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(
                    epoch + 1, batch, train_loss.result(),
                    train_accuracy.result()))

        if (epoch + 1) % 5 == 0:
            ckpt_save_path = ckpt_manager.save()
            print('Saving checkpoint for epoch {} at {}'.format(
                epoch + 1, ckpt_save_path))

        print('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(
            epoch + 1, train_loss.result(), train_accuracy.result()))

        print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

    model.save_weights(save_weight_path)
    #model.summary()
    input_sentence = 'Where have you been?'
    predict(input_sentence, tokenizer, model)

    sentence = 'I am not crazy, my mother had me tested.'
    for _ in range(5):
        sentence = predict(sentence, tokenizer, model)
        print('')
Example #25
# embeddings are attributes
if args.dataset_name == 'aids':
    EMBED_DIM = 4
    num_classes = 2
    num_heads = 8
    depth = 6
    p, q = 1, 1
elif args.dataset_name == 'coildel':
    EMBED_DIM = 2
    num_classes = 100
    num_heads = 8
    depth = 6
    p, q = 1, 1

# k, num_heads, depth, seq_length, num_tokens, num_
model = Transformer(EMBED_DIM, num_heads, test_dataset.walklength, depth,
                    num_classes).to(device)

lr_warmup = 10000

lr = 1e-3

opt = torch.optim.Adam(lr=lr, params=model.parameters())
sch = torch.optim.lr_scheduler.LambdaLR(
    opt, lambda i: min(i / (lr_warmup / args.batch_size), 1.0))
loss_func = nn.NLLLoss()
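# The LambdaLR schedule above implements linear learning-rate warmup: the
# multiplier grows from 0 to 1 over the first lr_warmup / batch_size scheduler
# steps and then stays at 1.0, so the effective rate ramps up to lr and holds.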


def train_validate(model, loader, opt, loss_func, train, device):

    if train:
        model.train()
Example #26
def generate_predictions(input_file_path: str, pred_file_path: str):
    """Generates predictions for the machine translation task (EN->FR).
    You are allowed to modify this function as needed, but once again, you cannot
    modify any other part of this file. We will be importing only this function
    in our final evaluation script. Since you will most definitely need to import
    modules for your code, you must import these inside the function itself.
    Args:
        input_file_path: the file path that contains the input data.
        pred_file_path: the file path where to store the predictions.
    Returns: None
    """
    # ---------------------------------------------------------------------
    # Include essential module for evaluation
    import os
    import json
    import pickle
    import time
    import tensorflow as tf
    from data.dataloaders import prepare_training_pairs, prepare_test
    from models import Transformer
    from translate import translate_batch
    from definition import ROOT_DIR
    CONFIG = "eval_cfg.json"
    # ---------------------------------------------------------------------
    # Load setting in json file

    with open(os.path.join(ROOT_DIR, CONFIG)) as f:
        para = json.load(f)
    batch_size = para["batch_size"]
    source = para["src"]
    target = para["tar"]
    ckpt_dir = para["ckpt"]

    # ---------------------------------------------------------------------
    # Create test dataloader from input file (tokenized and map to sequence)

    # Todo: the source and target tokenizers fitted during training are needed here,
    #  so that the same tokenization is applied to the test data (no dictionary file was built).
    f_src = open(source, 'rb')
    f_tar = open(target, 'rb')
    src_tokenizer = pickle.load(f_src)
    tar_tokenizer = pickle.load(f_tar)

    test_dataset, test_max_length = prepare_test(input_file_path,
                                                 src_tokenizer,
                                                 batch_size=batch_size)
    # calculate vocabulary size
    src_vocsize = len(src_tokenizer.word_index) + 1
    tar_vocsize = len(tar_tokenizer.word_index) + 1
    # ---------------------------------------------------------------------
    # Create the instance of model to load checkpoints
    # Todo: Define the model that fit the checkpoints you want to load
    optimizer = tf.keras.optimizers.Adam()
    model = Transformer.Transformer(voc_size_src=src_vocsize,
                                    voc_size_tar=tar_vocsize,
                                    max_pe=10000,
                                    num_encoders=4,
                                    num_decoders=4,
                                    emb_size=512,
                                    num_head=8,
                                    ff_inner=1024)

    # ---------------------------------------------------------------------
    # Load CheckPoint
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    status = checkpoint.restore(tf.train.latest_checkpoint(ckpt_dir))
    # check if loading is successful
    status.assert_existing_objects_matched()

    # ---------------------------------------------------------------------
    # Use Greedy Search to generate prediction and write to pred_file_path
    start = time.time()
    with open(pred_file_path, 'w', encoding='utf-8') as pred_file:
        for (batch, (inp)) in enumerate(test_dataset):
            if batch % 5 == 0:
                print("Evaluating Batch: %s" % batch)
            batch_size = tf.shape(inp)[0].numpy()
            translation = translate_batch(model, inp, batch_size,
                                          tar_tokenizer)
            for sentence in translation:
                pred_file.write(sentence.strip() + '\n')
    end = time.time()
    print("Translation finish in %s s" % (end - start))
Example #27
plt.ylabel("Learning Rate")
plt.xlabel("Train Step")


# Loss and Metrics
# -------------------------------------------------------------------------------------------------------
print("Loss and Metrics\n------------------------------------------")
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')


# Training and Checkpoints
# -------------------------------------------------------------------------------------------------------
print("Training and Checkpoints\n------------------------------------------")
transformer = Transformer(num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, dropout_rate)

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(transformer=transformer, optimizer=optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

for epoch in range(EPOCHS):
    start = time.time()
    
Example #28
def edit(name=None):

    user_access_token = request.args.get('access_token', request.form.get('access_token'))

    if not user_access_token and g.user is None:
        return redirect(url_for('user.login', next=request.path))

    def is_format_request(formats):
        if isinstance(formats, basestring):
            return request.args.get(formats, None) is not None

        for format in formats:
            if is_format_request(format):
                return True

        return False

    def is_data_request():
        return is_format_request(['gwiz', 'json', 'csv', 'html', 'gwiz_json'])

    execute_sql = request.method == 'POST' or is_data_request()

    query = None

    if request.method == 'POST':
        def extract_meta_var_fields():
            index = 0

            meta = []
            while True:
                name = request.form.get('var-name%d' % index)
                if not name:
                    return meta

                type = request.form.get('var-type%d' % index, 'string')
                default = request.form.get('var-default%d' % index, 'string')

                meta.append((name, {'default': default, 'type': type}))

                index += 1

        def vars_from_request(meta_vars, empty_vars):
            vars = []
            for name, options in meta_vars:
                value = request.form.get(name)

                if value:
                    value = convert_value(value, options.get('type'))
                elif not empty_vars:
                    continue

                vars.append((name, value))

            return vars

        def get_editors():
            editors = request.form.getlist('editors')
            if editors:
                users = User.get_by_username(editors)
                return users

            return []


        meta_vars = extract_meta_var_fields()

        sql = request.form['sql']

        name = request.form.get('query-name')

        if name == 'new':
            name = None

        if not name and g.user is None:
            return redirect(url_for('user.login', next=request.path))

        connection_string = request.form.get('connection-string')
        connection = ConnectionString.find(connection_string, True)

        editors = get_editors()

        if name and request.form.get('user-action') == 'Save':
            if g.user is None:
                return redirect(url_for('user.login', next=request.path))

            query, created = save(name, sql, meta_vars, connection, editors)

            if created:
                full_vars = vars_from_request(meta_vars, False)
                return redirect(url_for('.edit', name = name, **dict(full_vars)))

        vars = vars_from_request(meta_vars, True)

    else:
        if not name:
            return redirect(url_for('.new'))

        vars = []
        for key, value in request.args.iteritems():
            if not value:
                continue

            vars.append((key, value))

        query = Query.find(name, access_token=user_access_token)

        if not query:
            return redirect(url_for('.new'))

        sql = query.sql
        meta_vars = query.meta_vars
        connection = query.connection
        editors = query.editors

    data_table = None
    error = None
    json_data = None
    access_token = None
    if execute_sql:
        transform = request.args.get('transform', None)

        if not transform:
            transform = request.args.get('transformer', None)

        try:
            json_data = None

            if connection:
                if connection.url.startswith('google://') or connection.url.startswith('http://') or connection.url.startswith('https://'):
                    description, data, columns_order = query_google_data_source(connection, sql, meta_vars, vars)
                else:
                    description, data, columns_order = query_execute_sql(connection, sql, meta_vars, vars)

                if transform:
                    data = Transformer.execute(transform, data, True)

                if is_format_request('json'):
                    json_data = json.dumps(data, default=handle_datetime)
                elif len(data) > 0:
                    data_table = data_to_datatable(description, data)
                    json_data = data_table.ToJSon(columns_order=columns_order)

            if not json_data:
                json_data = json.dumps([])
                data_table = None

        except Exception, ex:
            logging.exception("Failed to execute query %s", ex)
            error = str(ex)

        if is_format_request('gwiz'):
            if error:
                return Response(json.dumps({"error": error}), mimetype='application/json')

            return Response(data_table.ToJSonResponse(columns_order=columns_order),  mimetype='application/json')
        if is_format_request('gwiz_json'):
            if error:
                return Response(json.dumps({"error": error}), mimetype='application/json')

            if data_table:
                return Response(data_table.ToJSon(columns_order=columns_order), mimetype='application/json')
            else:
                return Response(json.dumps({"info": 'No results returned'}), mimetype='application/json')

        if is_format_request('json'):
            if error:
                return Response(json.dumps({"error": error}), mimetype='application/json')

            return Response(json_data,  mimetype='application/json')
        elif is_format_request('html'):
            return Response(data_table.ToHtml(columns_order=columns_order))
        elif is_format_request('csv'):
            return Response(data_table.ToCsv(columns_order=columns_order), mimetype='text/csv')
print(src_tokenizer.index_word)

src_vocsize = len(src_tokenizer.word_index) + 1
tar_vocsize = len(tar_tokenizer.word_index) + 1
print("Source Language voc size: %s" % src_vocsize)
print("Target Language voc size: %s" % tar_vocsize)

print("Source Language max length: %s" % source_max_length)
print("Target Language max length: %s" % target_max_length)

model = Transformer.Transformer(voc_size_src=src_vocsize,
                                voc_size_tar=tar_vocsize,
                                src_max_length=source_max_length,
                                tar_max_length=target_max_length,
                                num_encoders=1,
                                num_decoders=1,
                                emb_size=8,
                                num_head=2,
                                ff_inner=1024)

tf.random.set_seed(12)
for src, tar in train_dataset:
    # create mask
    enc_padding_mask = Transformer.create_padding_mask(src)
    print("enc_padding_mask", enc_padding_mask.numpy())

    # mask for first attention block in decoder
    look_ahead_mask = Transformer.create_seq_mask(target_max_length)
    print("look ahead mask", look_ahead_mask.numpy())
    dec_target_padding_mask = Transformer.create_padding_mask(tar)
Example #30
    #print(df[df["T (degC)"] < 0]["T (degC)"])
    df = df.drop(["max. wv (m/s)", "wv (m/s)"], axis=1)

    day = 24 * 60 * 60
    year = (365.2425) * day

    timestamp_s = date_time.map(datetime.datetime.timestamp)
    df['Day sin'] = np.sin(timestamp_s * (2 * np.pi / day))
    df['Day cos'] = np.cos(timestamp_s * (2 * np.pi / day))
    df['Year sin'] = np.sin(timestamp_s * (2 * np.pi / year))
    df['Year cos'] = np.cos(timestamp_s * (2 * np.pi / year))

    print(df.head(20))

    # create the model instance
    model = Transformer()
    #model = LSTMM()
    loss_object = tf.keras.losses.BinaryCrossentropy()
    optimizer = tf.keras.optimizers.Adam(0.01)
    #
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.BinaryAccuracy()

    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy = tf.keras.metrics.BinaryAccuracy()

    @tf.function
    def train_step(seq, labels):
        with tf.GradientTape() as tape:
            predictions = model(seq)
            tf.print(predictions[0])
Example #31
    word2id_fr = pickle.load(f)
with open('./data/europarl/word2id_en.pickle', 'rb') as f:
    word2id_en = pickle.load(f)
with open('./data/europarl/input_sentences.pickle', 'rb') as f:
    input_sentences = pickle.load(f)
with open('./data/europarl/output_sentences.pickle', 'rb') as f:
    output_sentences = pickle.load(f)

n = input_sentences.shape[0]
n_train = int(0.9 * n)
perm = np.random.permutation(n)
train_in = input_sentences[perm[0:n_train]]
train_out = output_sentences[perm[0:n_train]]
val_in = input_sentences[perm[n_train:n]].values
val_out = output_sentences[perm[n_train:n]].values

model = Transformer(in_voc=(id2word_fr, word2id_fr),
                    out_voc=(id2word_en, word2id_en),
                    hidden_size=50,
                    lr=1e-3,
                    batch_size=128,
                    beam_size=10,
                    nb_epochs=10,
                    nb_heads=4,
                    pos_enc=True,
                    nb_layers=1)

model.fit(train_in, train_out)
model.save("./model/transformer.ckpt")
# model.evaluate(valid_data=(val_in, val_out)) # requires nltk
Example #32
    # model parameters
    dim_input = 6
    output_sequence_length = Y_DAYS
    dec_seq_len = Y_DAYS
    dim_val = 64
    dim_attn = 12
    n_heads = 8
    n_encoder_layers = 4
    n_decoder_layers = 2

    # paths
    PATHS = crypto_data_paths()
    MODEL_PATH = 'weights/trans/stock/3days/{e}_{d}_{v}_{y}_seed{seed}'.format(e=n_encoder_layers, d=n_decoder_layers, v=dim_val, y=Y_DAYS, seed=SEED)

    #init network
    net = Transformer(dim_val, dim_attn, dim_input, dec_seq_len, output_sequence_length, n_decoder_layers, n_encoder_layers, n_heads)
    #net.load_state_dict(torch.load(MODEL_PATH))
    # load the dataset
    X_train, y_train, X_test, y_test = create_input_data(PATHS, N_LAGS, Y_DAYS)
    train_dataset = StockDataset(X_train, y_train)
    train_loader = DataLoader(dataset=train_dataset,     
                            batch_size=BATCH_SIZE)
    test_dataset = StockDataset(X_test, y_test)
    test_loader = DataLoader(dataset=test_dataset,     
                            batch_size=BATCH_SIZE)

    train(net, N_EPOCHS, train_loader, LR, MODEL_PATH)
    eval(net, MODEL_PATH, test_loader)
    #tensorboard --logdir=runs

    #The MSE is  0.0010176260894923762 0
Example #33
    no_up = 0
    # define the computation graph
    print('---start graph---')
    with tf.Graph().as_default():

        session_conf = tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False)
        session_conf.gpu_options.allow_growth = True
        session_conf.gpu_options.per_process_gpu_memory_fraction = 0.9  # cap GPU memory usage per process at 90%

        sess = tf.Session(config=session_conf)

        # define the session
        with sess.as_default():

            transformer = Transformer(config, wordEmbedding)

            globalStep = tf.Variable(0, name="globalStep", trainable=False)
            # define the optimizer, passing in the learning rate
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=config.training.learningRate, momentum=0.1)  #
            #optimizer = tf.train.AdamOptimizer(config.training.learningRate,beta1=0.9,beta2=0.999,epsilon=1e-08,)
            #optimizer = tf.keras.optimizers.SGD(learning_rate=config.training.learningRate, momentum=0.1,nesterov=False)
            #optimizer = tf.train.GradientDescentOptimizer(learning_rate=config.training.learningRate)
            #adadelta
            #optimizer = tf.train.RMSPropOptimizer(config.training.learningRate, decay=0.9, momentum=0.1, epsilon=1e-10,)

            gradsAndVars = optimizer.compute_gradients(transformer.loss)
            '''
            mean_grad = tf.zeros(())
            for grad, var in gradsAndVars:
Example #34
    keypoints_dir=save_dir,
    max_frame_len=169,
)

dataloader = data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)
label_map = dict(zip(label_map.values(), label_map.keys()))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = TransformerConfig(size="large", max_position_embeddings=256)
model = Transformer(config=config, n_classes=263)
model = model.to(device)

pretrained_model_name = "include_no_cnn_transformer_large.pth"
pretrained_model_links = load_json("pretrained_links.json")
if not os.path.isfile(pretrained_model_name):
    link = pretrained_model_links[pretrained_model_name]
    torch.hub.download_url_to_file(link, pretrained_model_name, progress=True)

ckpt = torch.load(pretrained_model_name)
model.load_state_dict(ckpt["model"])
print("### Model loaded ###")

preds = inference(dataloader, model, device, label_map)
print(json.dumps(preds, indent=2))
Example #35
def main(argv):
    # Creating dataloaders for training and validation
    logging.info("Creating the source dataloader from: %s" % FLAGS.source)
    logging.info("Creating the target dataloader from: %s" % FLAGS.target)
    train_dataset, valid_dataset, src_tokenizer, \
    tar_tokenizer, size_train, size_val = prepare_training_pairs(FLAGS.source,
                                                                 FLAGS.target,
                                                                 FLAGS.syn_src,
                                                                 FLAGS.syn_tar,
                                                                 batch_size=FLAGS.batch_size,
                                                                 valid_ratio=0.1,
                                                                 name="ENFR")

    # calculate vocabulary size
    src_vocsize = len(src_tokenizer.word_index) + 1
    tar_vocsize = len(tar_tokenizer.word_index) + 1
    # ----------------------------------------------------------------------------------
    # Creating the instance of the model specified.
    logging.info("Create Transformer Model")
    optimizer = tf.keras.optimizers.Adam()
    model = Transformer.Transformer(voc_size_src=src_vocsize,
                                    voc_size_tar=tar_vocsize,
                                    max_pe=10000,
                                    num_encoders=FLAGS.num_enc,
                                    num_decoders=FLAGS.num_dec,
                                    emb_size=FLAGS.emb_size,
                                    num_head=FLAGS.num_head,
                                    ff_inner=FLAGS.ffnn_dim)

    # load pretrained mBart
    if FLAGS.load_mBart:
        print("Load Pretraining mBART...")
        mbart_ckpt_dir = FLAGS.mbartckpt
        latest = tf.train.latest_checkpoint(mbart_ckpt_dir)

        checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
        status = checkpoint.restore(tf.train.latest_checkpoint(mbart_ckpt_dir))
        status.assert_existing_objects_matched()

    # ----------------------------------------------------------------------------------
    # Choose the Optimizor, Loss Function, and Metrics
    # create custom learning rate schedule
    class transformer_lr_schedule(
            tf.keras.optimizers.schedules.LearningRateSchedule):
        def __init__(self, emb_size, warmup_steps=4000):
            super(transformer_lr_schedule, self).__init__()
            self.emb_size = tf.cast(emb_size, tf.float32)
            self.warmup_steps = warmup_steps

        def __call__(self, step):
            lr_option1 = tf.math.rsqrt(step)
            lr_option2 = step * (self.warmup_steps**-1.5)
            return tf.math.rsqrt(self.emb_size) * tf.math.minimum(
                lr_option1, lr_option2)

    learning_rate = transformer_lr_schedule(FLAGS.emb_size)
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate,
                                         beta_1=0.9,
                                         beta_2=0.98,
                                         epsilon=1e-9)

    # Todo: figure out why SparseCategoricalCrossentropy is used here
    criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,
                                                              reduction='none')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='train_accuracy')

    def loss_fn(label, pred):
        """
        The criterion above computes the loss at every position; the loss at
        positions that are only padding in the label must be masked out.
        """
        mask = tf.math.logical_not(tf.math.equal(label, 0))
        loss = criterion(label, pred)

        # convert the mask from Bool to float
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask

        return tf.reduce_sum(loss) / tf.reduce_sum(mask)
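
    # For example, with label = [[5, 7, 0]] the mask is [1., 1., 0.], so the
    # padded third position contributes nothing and the result is the mean
    # cross-entropy over the two real tokens only.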

    # ----------------------------------------------------------------------------------
    # train/valid function
    # Todo: need to understand this
    train_step_signature = [
        tf.TensorSpec(shape=[None, None], dtype=tf.int32),
        tf.TensorSpec(shape=[None, None], dtype=tf.int32)
    ]

    @tf.function(input_signature=train_step_signature)
    def train_step(inp, targ):
        tar_inp = targ[:, :-1]
        tar_real = targ[:, 1:]
        end = tf.cast(
            tf.math.logical_not(
                tf.math.equal(tar_inp, tar_tokenizer.word_index['<end>'])),
            tf.int32)
        tar_inp *= end
        # tf.print("tar inp", tar_inp)
        # tf.print("tar real", tar_real)
        # create mask
        enc_padding_mask = Transformer.create_padding_mask(inp)

        # mask for first attention block in decoder
        look_ahead_mask = Transformer.create_seq_mask(tf.shape(tar_inp)[1])
        dec_target_padding_mask = Transformer.create_padding_mask(tar_inp)
        combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

        # mask for "enc_dec" multihead attention
        dec_padding_mask = Transformer.create_padding_mask(inp)

        with tf.GradientTape() as tape:
            # feed input into encoder
            predictions = model(inp, tar_inp, True, enc_padding_mask,
                                combined_mask, dec_padding_mask)
            train_loss = loss_fn(tar_real, predictions)

            # optimize step
            gradients = tape.gradient(train_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients,
                                          model.trainable_variables))
        return train_loss

    @tf.function(input_signature=train_step_signature)
    def valid_step(inp, targ):
        tar_inp = targ[:, :-1]
        tar_real = targ[:, 1:]
        end = tf.cast(
            tf.math.logical_not(
                tf.math.equal(tar_inp, tar_tokenizer.word_index['<end>'])),
            tf.int32)
        tar_inp *= end
        # create mask
        enc_padding_mask = Transformer.create_padding_mask(inp)

        # mask for first attention block in decoder
        look_ahead_mask = Transformer.create_seq_mask(tf.shape(tar_inp)[1])
        dec_target_padding_mask = Transformer.create_padding_mask(tar_inp)
        combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

        # mask for "enc_dec" multihead attention
        dec_padding_mask = Transformer.create_padding_mask(inp)

        # feed input into encoder
        predictions = model(inp, tar_inp, False, enc_padding_mask,
                            combined_mask, dec_padding_mask)
        val_loss = loss_fn(tar_real, predictions)
        train_accuracy(tar_real, predictions)
        return val_loss

    # ----------------------------------------------------------------------------------
    # Set up Checkpoints, so as to resume training if something interrupt, and save results
    ckpt_prefix = os.path.join(FLAGS.ckpt, "ckpt_BT_ENFR_transformer")
    ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
    manager = tf.train.CheckpointManager(ckpt,
                                         directory=FLAGS.ckpt,
                                         max_to_keep=2)
    # restore from latest checkpoint and iteration
    if not FLAGS.load_mBart:
        print("Load previous checkpoints...")
        status = ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            logging.info("Restored from {}".format(manager.latest_checkpoint))
            status.assert_existing_objects_matched()
        else:
            logging.info("Initializing from scratch.")

    # ----------------------------------------------------------------------------------
    # Setup the TensorBoard for better visualization
    logging.info("Setup the TensorBoard...")
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_log_dir = './logs/gradient_tape/' + current_time + '/BT_ENFR_transformer_train'
    test_log_dir = './logs/gradient_tape/' + current_time + '/BT_ENFR_transformer_test'
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    test_summary_writer = tf.summary.create_file_writer(test_log_dir)

    # ----------------------------------------------------------------------------------
    # Start Training Process
    EPOCHS = FLAGS.epochs

    for epoch in range(EPOCHS):
        start = time.time()
        total_train_loss = 0.
        total_val_loss = 0.

        # train
        for (inp, targ) in train_dataset:
            train_loss = train_step(inp, targ)
            total_train_loss += train_loss

        # save checkpoint
        if (epoch + 1) % 5 == 0:
            ckpt.save(file_prefix=ckpt_prefix)

        # validation
        for (inp, tar) in valid_dataset:
            val_loss = valid_step(inp, tar)
            total_val_loss += val_loss

        # average loss
        total_train_loss /= (size_train / FLAGS.batch_size)
        total_val_loss /= (size_val / FLAGS.batch_size)

        # Write loss to Tensorborad
        with train_summary_writer.as_default():
            tf.summary.scalar('Train loss', total_train_loss, step=epoch)

        with test_summary_writer.as_default():
            tf.summary.scalar('Valid loss', total_val_loss, step=epoch)

        logging.info(
            'Epoch {} Train Loss {:.4f} Valid loss {:.4f} Valid Accuracy {:.4f}'
            .format(epoch + 1, total_train_loss, total_val_loss,
                    train_accuracy.result()))

        logging.info(
            'Time taken for 1 train_step {} sec\n'.format(time.time() - start))

if len(sys.argv) > 1:
    maxlen = int(sys.argv[1])
else:
    maxlen = 16

args = subword_batches(zip(in_texts, tar_texts), maxlen)
args = subword_batches(zip(in_valid, tar_valid), maxlen, args)
dg = args[params.train_generator]
vdg = args[params.valid_generator]
input_vocab_size = args[params.input_vocab_size]
target_vocab_size = args[params.target_vocab_size]

transformer = Transformer(num_layers,
                          d_model,
                          num_heads,
                          dff,
                          input_vocab_size,
                          target_vocab_size,
                          pe_input=input_vocab_size,
                          pe_target=target_vocab_size,
                          target_len=args[params.valid_seq_len],
                          rate=dropout_rate)

transformer.compile(optimizer=optimizer,
                    loss=loss_function,
                    metrics=[accuracy_function, bleu_score])
history = transformer.fit(dg, epochs=params.epochs, validation_data=vdg)

transformer.save_weights(weights_dir + '/w' + str(maxlen) + '_ex' +
                         str(params.train_size))
json.dump(
    history.history,
    open(history_dir + '/h' + str(maxlen) + '_ex' + str(params.train_size),