def train_bert(train_iter, net, loss, vocab_size, device, num_steps):
    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

    # Resume from an earlier checkpoint if one is available; otherwise keep the
    # freshly initialized parameters.
    try:
        checkpoint_prefix = os.path.join("model_data/model_BERT_pretraining_single.pt")
        checkpoint = torch.load(checkpoint_prefix, map_location=device)
        net.load_state_dict(checkpoint['model_state_dict'])
        net.to(device)
        optimizer.load_state_dict(checkpoint['optimizer'])
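        # Move the restored optimizer state (e.g. Adam's moment estimates)
        # onto the target device so resumed training runs where the model does.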
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)
    except Exception as e:
        print("Could not load the checkpoint, training from scratch:", e)

    checkpoint_prefix = os.path.join("model_data/model_BERT_pretraining_single.pt")


    step, timer = 0, Utility.Timer()
    animator = am.Animator(xlabel='step', ylabel='loss',
                            xlim=[1, num_steps], legend=['mlm'])
    # Sum of masked language modeling losses, no. of sentence pairs
    metric = am.Accumulator(2)
    num_steps_reached = False
    while step < num_steps and not num_steps_reached:
        for tokens_X, segments_X, valid_lens_x, pred_positions_X,\
            mlm_weights_X, mlm_Y in train_iter:
            tokens_X = tokens_X.to(device)
            segments_X = segments_X.to(device)
            valid_lens_x = valid_lens_x.to(device)
            pred_positions_X = pred_positions_X.to(device)
            mlm_weights_X = mlm_weights_X.to(device)
            mlm_Y = mlm_Y.to(device)
            optimizer.zero_grad()
            timer.start()
            l = _get_batch_loss_bert(
                net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,
                pred_positions_X, mlm_weights_X, mlm_Y)
            l.backward()
            optimizer.step()
            timer.stop()
            with torch.no_grad():
                metric.add(l, tokens_X.shape[0])
                animator.add(step + 1,
                            (metric[0] / metric[1]))
            if (step + 1) % 50 == 0:
                torch.save({'model_state_dict': net.state_dict(),
                            'optimizer': optimizer.state_dict()},
                           checkpoint_prefix)

            step += 1
            if step == num_steps:
                num_steps_reached = True
                break

    print(f'MLM loss {metric[0] / metric[1]:.3f}')
    print(f'{metric[1] / timer.sum():.1f} sentence pairs/sec on '
          f'{str(device)}')
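

# The trainer above calls a helper `_get_batch_loss_bert` that is not shown in
# this file. The sketch below illustrates what such a helper could compute for
# the MLM-only setup used here; it assumes `loss` is
# nn.CrossEntropyLoss(reduction='none') and that `net(...)` follows a
# d2l-style BERTModel forward signature returning (encoded_X, mlm_Y_hat, ...).
# It is illustrative, not necessarily the implementation used in this project.
def _get_batch_loss_bert_sketch(net, loss, vocab_size, tokens_X, segments_X,
                                valid_lens_x, pred_positions_X, mlm_weights_X,
                                mlm_Y):
    # Forward pass: predictions for the masked positions only.
    _, mlm_Y_hat, _ = net(tokens_X, segments_X, valid_lens_x.reshape(-1),
                          pred_positions_X)
    # Per-position cross-entropy over the vocabulary.
    mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))
    # Zero out padded prediction positions and average over the real ones.
    weights = mlm_weights_X.reshape(-1)
    return (mlm_l * weights).sum() / (weights.sum() + 1e-8)
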
def train_bert(net, data_iter, lr, num_epochs, batch_size, tgt_vocab, device):
    """Train an encoder-decoder model for sequence-to-sequence learning.

    Note that this definition reuses the name of the MLM pretraining trainer
    above and shadows it when both live in the same module.
    """
    def xavier_init_weights(m):
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    torch.nn.init.xavier_uniform_(m._parameters[param])
    # net.apply(xavier_init_weights)
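    # Resume from an earlier checkpoint if available; otherwise initialize the
    # weights with Xavier and start from scratch.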
    try:
        checkpoint_prefix = os.path.join("model_data/model_bert.pt")
        checkpoint = torch.load(checkpoint_prefix, map_location=device)
        net.load_state_dict(checkpoint['model_state_dict'])
        net.to(device)
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)
        optimizer.load_state_dict(checkpoint['optimizer'])
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)
    except Exception as e:
        net.apply(xavier_init_weights)
        net.to(device)
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)
        print("Could not load the checkpoint, initializing from scratch:", e)
    
    loss = am.MaskedSoftmaxCELoss()
    net.train()
    animator = am.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[1, num_epochs*batch_size])

    
    checkpoint_prefix = os.path.join("model_data/model_bert.pt")
    # ratio = 100 / len(data_iter)
    # print("ratio=", ratio)
    num_trained = 0
    for epoch in range(num_epochs):
        timer = Utility.Timer()
        metric = am.Accumulator(2)  # Sum of training loss, no. of tokens
        # print("epoch ...", epoch)
        for i, batch in enumerate(data_iter):
            # if random.random() < (1 - ratio * 1.5):
            #     continue
            num_trained += 1
            optimizer.zero_grad()
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
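            # Teacher forcing: feed the decoder `<bos>` followed by the
            # ground-truth target shifted right by one position, rather than
            # its own previous predictions.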
            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],
                               device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)  # Teacher forcing
            Y_hat, _ = net(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()  # Make the loss scalar for `backward`
            # Utility.grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), num_tokens)
            # if (i + 1) % 100 == 0:
            # print("    batch>>>", i)
            if (num_trained + 1) % 100 == 0:
                animator.add(num_trained + 1, (metric[0] / metric[1],))
                # print(f'epoch = {epoch}, loss = {metric[0] / metric[1]:.3f}')
                torch.save({'model_state_dict': net.state_dict(),
                            'optimizer': optimizer.state_dict()},
                           checkpoint_prefix)
    print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} '
          f'tokens/sec on {str(device)}')
    
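
# Illustrative invocation of the sequence-to-sequence trainer above; it is left
# commented out because `load_translation_data` and `EncoderDecoder` are
# hypothetical placeholders for whatever data pipeline and model classes this
# project actually provides.
#
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# data_iter, src_vocab, tgt_vocab = load_translation_data(batch_size=64)
# net = EncoderDecoder(...)
# train_bert(net, data_iter, lr=0.005, num_epochs=30, batch_size=64,
#            tgt_vocab=tgt_vocab, device=device)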