Example #1
0
    def __init__(self, config_path, model_path, model_type):
        """Build the NMT model described by a config file and load its weights.

        Args:
            config_path: configuration file handed to ``prepare_configs``.
            model_path: parameter file handed to ``load_model_parameters``.
            model_type: stored on the instance as ``self.model_type``.
        """
        # Echo the constructor arguments, one per line.
        for argument in (config_path, model_path, model_type):
            print(argument)
        self.model_type = model_type

        all_cfg = prepare_configs(config_path)
        data_cfg = all_cfg['data_configs']
        model_cfg = all_cfg['model_configs']

        # First vocabulary entry is the source side, second the target side.
        vocab_specs = data_cfg['vocabularies']
        src_vocab = Vocabulary.build_from_file(**vocab_specs[0])
        tgt_vocab = Vocabulary.build_from_file(**vocab_specs[1])

        model = build_model(
            n_src_vocab=src_vocab.max_n_words,
            n_tgt_vocab=tgt_vocab.max_n_words,
            padding_idx=src_vocab.pad,
            **model_cfg
        )

        # Parameters are mapped to CPU when read; the model is moved to the
        # GPU and switched to evaluation mode afterwards.
        model.load_state_dict(
            load_model_parameters(model_path, map_location="cpu"))
        model.cuda()
        model.eval()

        self.model = model
        self.data_configs = data_cfg
        self.model_configs = model_cfg
        self.vocab_src = src_vocab
        self.vocab_tgt = tgt_vocab
Example #2
0
def train(flags):
    """Train a model on a single text corpus with periodic loss validation.

    The model receives ``x[:-1]`` and is trained with cross-entropy against
    ``x[1:]``. Only the first vocabulary, the first training file and the
    first validation file of the data configuration are used.

    Attributes read from ``flags``:
        use_gpu: run on ``cuda:0`` when true, otherwise on CPU
        config_path: str
        predefined_config: forwarded to ``prepare_configs``
        saveto: str, directory part of the saved-model prefix
        model_name: str, file-name part of the saved-model prefix
        log_path: str, location of the log file and of the summary writer
        pretrain_path: str, forwarded to ``load_pretrained_model``
        pretrain_exclude_prefix: forwarded to ``load_pretrained_model``
        debug: forwarded to ``should_trigger_by_steps``

    The process exits through ``exit(0)`` when early stopping triggers;
    otherwise the function returns after ``max_epochs`` is exceeded.
    """

    # ================================================================================== #
    # Initialization for training on different devices
    # - CPU/GPU
    # - Single/Distributed
    Constants.USE_GPU = flags.use_gpu

    # Only a single process is configured here.
    world_size = 1
    rank = 0
    local_rank = 0

    if Constants.USE_GPU:
        torch.cuda.set_device(local_rank)
        Constants.CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        Constants.CURRENT_DEVICE = "cpu"

    # If not root_rank, close logging
    # else write log of training to file.
    if rank == 0:
        write_log_to_file(
            os.path.join(flags.log_path,
                         "%s.log" % time.strftime("%Y%m%d-%H%M%S")))
    else:
        close_logging()

    # ================================================================================== #
    # Parsing configuration files
    # - Load default settings
    # - Load pre-defined settings
    # - Load user-defined settings

    configs = prepare_configs(flags.config_path, flags.predefined_config)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    INFO(pretty_configs(configs))

    Constants.SEED = training_configs['seed']

    set_seed(Constants.SEED)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Build the (single) vocabulary from the first vocabulary entry.
    vocab_src = Vocabulary.build_from_file(**data_configs['vocabularies'][0])

    Constants.EOS = vocab_src.eos
    Constants.PAD = vocab_src.pad
    Constants.BOS = vocab_src.bos

    train_bitext_dataset = TextLineDataset(
        data_path=data_configs['train_data'][0],
        vocabulary=vocab_src,
        max_len=data_configs['max_len'][0],
        is_train_dataset=True)

    valid_bitext_dataset = TextLineDataset(
        data_path=data_configs['valid_data'][0],
        vocabulary=vocab_src,
        is_train_dataset=False)

    training_iterator = DataIterator(
        dataset=train_bitext_dataset,
        batch_size=training_configs["batch_size"],
        use_bucket=training_configs['use_bucket'],
        buffer_size=training_configs['buffer_size'],
        batching_func=training_configs['batching_key'],
        world_size=world_size,
        rank=rank)
    valid_iterator = DataIterator(
        dataset=valid_bitext_dataset,
        batch_size=training_configs['valid_batch_size'],
        use_bucket=True,
        buffer_size=100000,
        numbering=True,
        shuffle=False,
        world_size=world_size,
        rank=rank)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # The steps below are done one after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    # (no learning rate scheduler is built and no checkpoint is reloaded here)

    # 0. Initial

    model_collections = Collections()

    # NOTE: periodic checkpoint saving is disabled (see the commented block at
    # the end of the batch loop), so this saver is currently not used.
    checkpoint_saver = Saver(
        save_prefix="{0}.ckpt".format(
            os.path.join(flags.saveto, flags.model_name)),
        num_max_keeping=training_configs['num_kept_checkpoints'])
    best_model_prefix = os.path.join(
        flags.saveto, flags.model_name + Constants.MY_BEST_MODEL_SUFFIX)
    best_model_saver = Saver(
        save_prefix=best_model_prefix,
        num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(vocab_size=vocab_src.max_n_words,
                            padding_idx=vocab_src.pad,
                            vocab_src=vocab_src,
                            **model_configs)
    INFO(nmt_model)
    # Loss function; padding positions do not contribute.
    critic = torch.nn.CrossEntropyLoss(ignore_index=Constants.PAD)
    INFO(critic)

    # 2. Move to GPU
    if Constants.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model,
                          flags.pretrain_path,
                          exclude_prefix=flags.pretrain_exclude_prefix,
                          device=Constants.CURRENT_DEVICE)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 4. Build optimizer
    INFO('Building Optimizer...')
    optimizer = torch.optim.Adam(nmt_model.parameters(),
                                 lr=optimizer_configs['learning_rate'])

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================================================================== #
    # Prepare training

    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [1])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]
    is_early_stop = model_collections.get_collection("is_early_stop", [
        False,
    ])[-1]

    train_loss_meter = AverageMeter()
    sent_per_sec_meter = TimeMeter()
    tok_per_sec_meter = TimeMeter()

    grad_denom = 0
    train_loss = 0.0
    cum_n_words = 0
    valid_loss = best_valid_loss = float('inf')

    if rank == 0:
        summary_writer = SummaryWriter(log_dir=flags.log_path)
    else:
        summary_writer = None

    sent_per_sec_meter.start()
    tok_per_sec_meter.start()

    INFO('Begin training...')

    while True:

        if summary_writer is not None:
            summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()

        if rank == 0:
            training_progress_bar = tqdm(desc=' - (Epc {}, Upd {}) '.format(
                eidx, uidx),
                                         total=len(training_iterator),
                                         unit="sents")
        else:
            training_progress_bar = None

        for batch in training_iter:
            seqs_x = batch

            batch_size = len(seqs_x)
            cum_n_words = 0.0
            train_loss = 0.0

            try:
                # Prepare data
                grad_denom += batch_size
                x = prepare_data(seqs_x, seqs_y=None, cuda=Constants.USE_GPU)
                nmt_model.train()
                critic.train()
                # Clear the model gradients of the previous update. Zeroing
                # the criterion instead would leave them untouched, so they
                # would keep accumulating from batch to batch.
                optimizer.zero_grad()
                with torch.enable_grad():
                    # Predict every position from the preceding ones.
                    logits = nmt_model(x[:-1])
                    logits = logits.view(-1, vocab_src.max_n_words)
                    trg = x[1:]
                    trg = trg.view(-1)
                    loss = critic(logits, trg)
                    loss.backward()
                    optimizer.step()
                    valid_token = (trg != Constants.PAD).long().sum().item()
                    cum_n_words += valid_token
                    # critic returns a mean; rescale to a sum over tokens.
                    train_loss += loss.item() * valid_token

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                else:
                    raise e

            if training_progress_bar is not None:
                training_progress_bar.update(grad_denom)
                training_progress_bar.set_description(
                    ' - (Epc {}, Upd {}) '.format(eidx, uidx))

                # A skipped (OOM) batch contributes no tokens; avoid a
                # division by zero in that case.
                if cum_n_words > 0:
                    displayed_loss = train_loss / cum_n_words
                else:
                    displayed_loss = float('nan')
                postfix_str = 'TrainLoss: {:.2f}, ValidLoss(best): {:.2f} ({:.2f}), '.format(
                    displayed_loss, valid_loss, best_valid_loss)
                training_progress_bar.set_postfix_str(postfix_str)

            # 4. update meters
            train_loss_meter.update(train_loss, cum_n_words)
            sent_per_sec_meter.update(grad_denom)
            tok_per_sec_meter.update(cum_n_words)

            # 5. reset accumulated variables, update uidx
            grad_denom = 0
            uidx += 1
            cum_n_words = 0.0
            train_loss = 0.0

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(
                    uidx, eidx, every_n_step=training_configs['disp_freq']):

                if summary_writer is not None:
                    summary_writer.add_scalar(
                        "Speed(sents/sec)",
                        scalar_value=sent_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "Speed(words/sec)",
                        scalar_value=tok_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "train_loss",
                        scalar_value=train_loss_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar("oom_count",
                                              scalar_value=oom_count,
                                              global_step=uidx)

                # Reset Meters
                sent_per_sec_meter.reset()
                tok_per_sec_meter.reset()
                train_loss_meter.reset()

            # ================================================================================== #
            # Loss Validation & early stopping
            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['loss_valid_freq'],
                    min_step=training_configs['bleu_valid_warmup'],
                    debug=flags.debug):

                valid_iter = valid_iterator.build_generator()
                valid_loss_sum = 0.0
                total_tokens = 0

                nmt_model.eval()
                critic.eval()
                for valid_batch in valid_iter:
                    # With numbering=True the iterator yields the sequence
                    # numbers first; they are not needed for the average.
                    _, valid_seqs = valid_batch
                    valid_x = prepare_data(valid_seqs,
                                           seqs_y=None,
                                           cuda=Constants.USE_GPU)
                    with torch.no_grad():
                        valid_logits = nmt_model(valid_x[:-1])
                        valid_logits = valid_logits.view(
                            -1, vocab_src.max_n_words)
                        valid_trg = valid_x[1:].view(-1)
                        n_tokens = (
                            valid_trg != Constants.PAD).long().sum().item()
                        if n_tokens == 0:
                            # The mean over zero tokens is undefined.
                            continue
                        batch_loss = critic(valid_logits, valid_trg)
                        # Token-weighted so the final value is a per-token
                        # average over the whole validation set.
                        valid_loss_sum += batch_loss.item() * n_tokens
                        total_tokens += n_tokens

                if total_tokens > 0:
                    valid_loss = valid_loss_sum / total_tokens
                else:
                    valid_loss = float('inf')
                model_collections.add_to_collection("history_losses",
                                                    valid_loss)

                min_history_loss = np.array(
                    model_collections.get_collection("history_losses")).min()
                best_valid_loss = min_history_loss
                if summary_writer is not None:
                    summary_writer.add_scalar("loss",
                                              valid_loss,
                                              global_step=uidx)
                    summary_writer.add_scalar("best_loss",
                                              min_history_loss,
                                              global_step=uidx)
                # The history already contains valid_loss, so this holds
                # exactly when the current loss is the best one so far.
                if valid_loss <= best_valid_loss:
                    bad_count = 0

                    if is_early_stop is False:
                        if rank == 0:
                            # 1. save the best model
                            torch.save(nmt_model.state_dict(),
                                       best_model_prefix + ".final")

                            # 2. record all several best models
                            best_model_saver.save(
                                global_step=uidx,
                                model=nmt_model,
                                optimizer=optimizer,
                                collections=model_collections)
                else:
                    # No improvement: spend one unit of patience.
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs[
                            'early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")
                        exit(0)

                if summary_writer is not None:
                    summary_writer.add_scalar("bad_count", bad_count, uidx)

                INFO("{0} Loss: {1:.2f}  patience: {2}".format(
                    uidx, valid_loss, bad_count))

            # ================================================================================== #
            # # Saving checkpoints
            # if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['save_freq'], debug=flags.debug):
            #     model_collections.add_to_collection("uidx", uidx)
            #     model_collections.add_to_collection("eidx", eidx)
            #     model_collections.add_to_collection("bad_count", bad_count)
            #
            #     if not is_early_stop:
            #         if rank == 0:
            #             checkpoint_saver.save(global_step=uidx,
            #                                   model=nmt_model,
            #                                   optim=optimizer,
            #                                   collections=model_collections)

        if training_progress_bar is not None:
            training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break

    # Flush pending summary events before returning.
    if summary_writer is not None:
        summary_writer.close()
Example #3
0
def train(config_path, model_path, model_type, src_filename, trg_filename):
    """Score a parallel corpus with a trained model (no parameter updates).

    Despite its name, this function only evaluates: it loads the parameters
    found at ``model_path`` and computes, for each sentence pair, the
    negative log-likelihood of the target summed over positions and divided
    by the number of non-padding target labels.

    Args:
        config_path: configuration file handed to ``prepare_configs``.
        model_path: parameter file handed to ``load_pretrained_model``.
        model_type: only printed.
        src_filename: path of the source-side text file.
        trg_filename: path of the target-side text file.

    Returns:
        dict mapping each sequence number yielded by the data iterator to
        its per-token loss (float). Sentences of a batch that was skipped
        because of an out-of-memory error are absent.

    Raises:
        RuntimeError: any runtime error other than an out-of-memory one.
    """

    # ================================================================================== #
    # Initialization
    # - always runs on the first GPU, single process
    Constants.USE_GPU = True
    print(config_path)
    print(model_path)
    print(model_type)

    world_size = 1
    rank = 0
    local_rank = 0

    if Constants.USE_GPU:
        torch.cuda.set_device(local_rank)
        Constants.CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        Constants.CURRENT_DEVICE = "cpu"

    # ================================================================================== #
    # Parsing configuration files

    configs = prepare_configs(config_path)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    training_configs = configs['training_configs']

    INFO(pretty_configs(configs))

    Constants.SEED = training_configs['seed']
    set_seed(Constants.SEED)
    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Source and target vocabularies
    vocab_src = Vocabulary.build_from_file(**data_configs['vocabularies'][0])
    vocab_tgt = Vocabulary.build_from_file(**data_configs['vocabularies'][1])

    Constants.EOS = vocab_src.eos
    Constants.PAD = vocab_src.pad
    Constants.BOS = vocab_src.bos

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(
            data_path=src_filename,
            vocabulary=vocab_src,
            max_len=100,
            is_train_dataset=False,
        ),
        TextLineDataset(
            data_path=trg_filename,
            vocabulary=vocab_tgt,
            is_train_dataset=False,
            max_len=100,
        ))

    valid_iterator = DataIterator(dataset=valid_bitext_dataset,
                                  batch_size=20,
                                  use_bucket=training_configs['use_bucket'],
                                  buffer_size=training_configs['buffer_size'],
                                  numbering=True,
                                  world_size=world_size,
                                  rank=rank)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================ Begin ======================================== #
    # Build Model
    #     1. build model
    #     2. move model to gpu if needed
    #     3. load the trained parameters

    # 1. Build Model
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words,
                            padding_idx=vocab_src.pad,
                            vocab_src=vocab_src,
                            **model_configs)
    INFO(nmt_model)

    # 2. Move to GPU
    if Constants.USE_GPU:
        nmt_model = nmt_model.cuda()

    # 3. Load the trained parameters
    load_pretrained_model(nmt_model,
                          model_path,
                          device=Constants.CURRENT_DEVICE)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================================================================== #
    # Scoring

    INFO('Begin training...')
    eidx = 0
    uidx = 0
    score_result = dict()

    # Build iterator and progress bar
    scoring_iter = valid_iterator.build_generator()

    progress_bar = tqdm(desc=' - (Epc {}, Upd {}) '.format(eidx, uidx),
                        total=len(valid_iterator),
                        unit="sents")

    # Scores must come from the model in evaluation mode, and no gradients
    # are needed for them.
    nmt_model.eval()

    with torch.no_grad():
        for batch in scoring_iter:
            seqs_numbers, seqs_x, seqs_y = batch

            batch_size = len(seqs_x)

            try:
                # Prepare data
                x, y = prepare_data(seqs_x, seqs_y, cuda=Constants.USE_GPU)

                # The model is fed the target without its last position and
                # scored against the target without its first position.
                y_inp = y[:, :-1].contiguous()
                y_label = y[:, 1:].contiguous()  # [batch_size, seq_len]
                log_probs = nmt_model(
                    x, y_inp,
                    log_probs=True)  # [batch_size, seq_len, vocab_size]

                _, seq_len = y_label.shape
                token_loss = F.nll_loss(
                    log_probs.view(-1, vocab_tgt.max_n_words),
                    y_label.view(-1),
                    reduction='none',
                    ignore_index=vocab_tgt.pad)
                sentence_loss = token_loss.view(batch_size, seq_len).sum(-1)

                # Normalise by the number of non-padding labels.
                valid_token = (y_label != vocab_tgt.pad).sum(-1)
                sentence_loss = sentence_loss.double().div(
                    valid_token.double())

                for seq_num, score in zip(seqs_numbers, sentence_loss):
                    assert seq_num not in score_result
                    score_result[seq_num] = score.item()

                uidx += 1

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                else:
                    raise e

            progress_bar.update(batch_size)
            progress_bar.set_description(' - (Epc {}, Upd {}) '.format(
                eidx, uidx))

    progress_bar.close()
    return score_result
Example #4
0
def train2(flags):
    """Write a mean-pooled encoder representation for every training sentence.

    No parameters are updated. The encoder of the (optionally pre-trained)
    model is run over the training set and two files are written:

    * ``/home/wangdq/encoder.mean.output``: one line per sentence holding
      the space-separated components of its mean encoder state;
    * ``/home/wangdq/seq_numbers.output``: a single line with the
      space-separated sequence numbers, in the order the rows were written.

    Batches that run out of memory are skipped in both files.

    Attributes read from ``flags``:
        use_gpu: run on ``cuda:0`` when true, otherwise on CPU
        config_path: str
        predefined_config: forwarded to ``prepare_configs``
        pretrain_path: str, forwarded to ``load_pretrained_model``
        pretrain_exclude_prefix: forwarded to ``load_pretrained_model``

    Raises:
        RuntimeError: any runtime error other than an out-of-memory one.
    """

    # ================================================================================== #
    # Initialization
    # - CPU/GPU
    # - single process only
    Constants.USE_GPU = flags.use_gpu

    world_size = 1
    rank = 0
    local_rank = 0

    if Constants.USE_GPU:
        torch.cuda.set_device(local_rank)
        Constants.CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        Constants.CURRENT_DEVICE = "cpu"

    # ================================================================================== #
    # Parsing configuration files
    # - Load default settings
    # - Load pre-defined settings
    # - Load user-defined settings

    configs = prepare_configs(flags.config_path, flags.predefined_config)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    training_configs = configs['training_configs']
    bt_configs = configs['bt_configs'] if 'bt_configs' in configs else None
    if bt_configs is not None:
        print("btconfigs ", bt_configs)
        if 'bt_attribute_data' not in bt_configs:
            Constants.USE_BT = False
            bt_configs = None
        else:
            Constants.USE_BT = True
            Constants.USE_BTTAG = bt_configs['use_bttag']
            Constants.USE_CONFIDENCE = bt_configs['use_confidence']
    INFO(pretty_configs(configs))

    Constants.SEED = training_configs['seed']

    set_seed(Constants.SEED)

    timer = Timer()

    # ================================================================================== #
    # Load Data
    INFO('Loading data...')
    timer.tic()

    # Source and target vocabularies
    vocab_src = Vocabulary.build_from_file(**data_configs['vocabularies'][0])
    vocab_tgt = Vocabulary.build_from_file(**data_configs['vocabularies'][1])

    Constants.EOS = vocab_src.eos
    Constants.PAD = vocab_src.pad
    Constants.BOS = vocab_src.bos

    # bt tag dataset
    if Constants.USE_BT:
        if Constants.USE_BTTAG:
            Constants.BTTAG = vocab_src.bttag
        train_bitext_dataset = ZipDataset(
            TextLineDataset(data_path=data_configs['train_data'][0],
                            vocabulary=vocab_src,
                            max_len=data_configs['max_len'][0],
                            is_train_dataset=True
                            ),
            TextLineDataset(data_path=data_configs['train_data'][1],
                            vocabulary=vocab_tgt,
                            max_len=data_configs['max_len'][1],
                            is_train_dataset=True
                            ),
            AttributeDataset(data_path=bt_configs['bt_attribute_data'], is_train_dataset=True)
        )
    else:
        train_bitext_dataset = ZipDataset(
            TextLineDataset(data_path=data_configs['train_data'][0],
                            vocabulary=vocab_src,
                            max_len=data_configs['max_len'][0],
                            is_train_dataset=True
                            ),
            TextLineDataset(data_path=data_configs['train_data'][1],
                            vocabulary=vocab_tgt,
                            max_len=data_configs['max_len'][1],
                            is_train_dataset=True
                            )
        )

    training_iterator = DataIterator(dataset=train_bitext_dataset,
                                     batch_size=training_configs["batch_size"],
                                     use_bucket=training_configs['use_bucket'],
                                     buffer_size=training_configs['buffer_size'],
                                     batching_func=training_configs['batching_key'],
                                     world_size=world_size, numbering=True,
                                     rank=rank)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================ Begin ======================================== #
    # Build Model
    #     1. build model
    #     2. move model to gpu if needed
    #     3. load pre-trained model if needed, then keep only its encoder

    # 1. Build Model
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words, padding_idx=vocab_src.pad, vocab_src=vocab_src,
                            vocab_tgt=vocab_tgt,
                            **model_configs)
    INFO(nmt_model)

    # 2. Move to GPU
    if Constants.USE_GPU:
        nmt_model = nmt_model.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model, flags.pretrain_path, exclude_prefix=flags.pretrain_exclude_prefix,
                          device=Constants.CURRENT_DEVICE)
    # Only the encoder is needed from here on.
    nmt_model = nmt_model.encoder
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('Begin training...')

    # Compute a mean-pooled representation for every training sentence.
    training_iter = training_iterator.build_generator()
    nmt_model.eval()

    all_seq_numbers = []
    # NOTE(review): output locations are hard-coded to one user's home
    # directory; consider making them configurable.
    encoder_filename = "/home/wangdq/encoder.mean.output"
    seq_numbers_filename = '/home/wangdq/seq_numbers.output'

    processed = 0

    with open(encoder_filename, 'w') as f_encoder, open(seq_numbers_filename, 'w') as f_seq_numbers:

        for batch in training_iter:
            bt_attrib = None
            # bt attrib data
            if Constants.USE_BT:
                # The original note says seq_numbers are numbered from 0.
                seq_numbers, seqs_x, seqs_y, bt_attrib = batch
            else:
                seq_numbers, seqs_x, seqs_y = batch

            x = prepare_data(seqs_x, seqs_y=None, cuda=Constants.USE_GPU, bt_attrib=bt_attrib)

            try:
                with torch.no_grad():
                    encoder_hidden, mask = nmt_model(x)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    # Really skip: without this the code below would use
                    # undefined or stale tensors from the previous batch.
                    continue
                else:
                    raise e

            # Positions where mask is False are the ones kept for pooling.
            # Follow the encoder output's device instead of assuming CUDA.
            keep = (mask == False).float().to(encoder_hidden.device)  # noqa: E712
            sum_encoder_hidden = (encoder_hidden * keep.unsqueeze(-1)).sum(dim=1)
            # Float count so the division does not mix float and integer tensors.
            n_kept = keep.sum(-1, keepdim=True)
            mean_encoder_hidden = sum_encoder_hidden.float().div(n_kept)

            all_seq_numbers.extend(seq_numbers)

            rows = mean_encoder_hidden.cpu().numpy().tolist()
            f_encoder.writelines(
                ' '.join(str(value) for value in row) + '\n' for row in rows)

            processed += len(seq_numbers)
            print(processed)

        f_seq_numbers.write(' '.join(str(number) for number in all_seq_numbers))
Example #5
0
def test_data(flags):
    """Encode the validation source sentences and mean-pool the encoder states.

    Builds the vocabularies, the validation iterator and the model from the
    configuration files, loads the pretrained weights, keeps only the model's
    encoder and runs it (in eval mode, without gradients) over every
    validation batch.

    Args:
        flags: parsed command-line options. The attributes read here are
            ``use_gpu``, ``config_path``, ``predefined_config``,
            ``pretrain_path`` and ``pretrain_exclude_prefix``.

    Returns:
        A tuple ``(all_mean_encoder_hidden, all_seq_numbers)``:

        * ``all_mean_encoder_hidden`` -- the per-sentence mean-pooled encoder
          outputs of all successfully encoded batches concatenated along
          dim 0, or ``None`` when no batch was encoded.
        * ``all_seq_numbers`` -- the sequence numbers yielded by the iterator,
          in the same order as the rows of ``all_mean_encoder_hidden``.

    Raises:
        RuntimeError: any runtime error raised by the encoder other than an
            out-of-memory error (those batches are skipped with a warning).
    """
    Constants.USE_GPU = flags.use_gpu

    # This routine always runs as a single, non-distributed process.
    world_size = 1
    rank = 0
    local_rank = 0

    if Constants.USE_GPU:
        torch.cuda.set_device(local_rank)
        Constants.CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        Constants.CURRENT_DEVICE = "cpu"

    # ================================================================================== #
    # Parsing configuration files
    # - Load default settings
    # - Load pre-defined settings
    # - Load user-defined settings

    configs = prepare_configs(flags.config_path, flags.predefined_config)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    training_configs = configs['training_configs']
    bt_configs = configs['bt_configs'] if 'bt_configs' in configs else None
    if bt_configs is not None:
        print("btconfigs ", bt_configs)
        # The back-translation switches are only enabled when the attribute
        # data is configured.
        if 'bt_attribute_data' not in bt_configs:
            Constants.USE_BT = False
        else:
            Constants.USE_BT = True
            Constants.USE_BTTAG = bt_configs['use_bttag']
            Constants.USE_CONFIDENCE = bt_configs['use_confidence']
    INFO(pretty_configs(configs))

    Constants.SEED = training_configs['seed']

    set_seed(Constants.SEED)

    timer = Timer()

    # ================================================================================== #
    # Load Data
    INFO('Loading data...')
    timer.tic()

    # Load the source and target vocabularies.
    vocab_src = Vocabulary.build_from_file(**data_configs['vocabularies'][0])
    vocab_tgt = Vocabulary.build_from_file(**data_configs['vocabularies'][1])

    Constants.EOS = vocab_src.eos
    Constants.PAD = vocab_src.pad
    Constants.BOS = vocab_src.bos

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['valid_data'][0],
                        vocabulary=vocab_src,
                        is_train_dataset=False,
                        ),
        TextLineDataset(data_path=data_configs['valid_data'][1],
                        vocabulary=vocab_tgt,
                        is_train_dataset=False
                        )
    )

    # numbering=True makes every batch carry the sequence numbers of its
    # sentences, so the pooled vectors can be mapped back to the input lines.
    valid_iterator = DataIterator(dataset=valid_bitext_dataset,
                                  batch_size=training_configs['valid_batch_size'],
                                  use_bucket=True, buffer_size=100000, numbering=True,
                                  world_size=world_size, rank=rank, shuffle=False)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================ Begin ======================================== #
    # Build the model, move it to the GPU if requested, load the pretrained
    # weights and keep only the encoder.

    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words, padding_idx=vocab_src.pad, vocab_src=vocab_src,
                            vocab_tgt=vocab_tgt,
                            **model_configs)
    INFO(nmt_model)

    if Constants.USE_GPU:
        nmt_model = nmt_model.cuda()

    load_pretrained_model(nmt_model, flags.pretrain_path, exclude_prefix=flags.pretrain_exclude_prefix,
                          device=Constants.CURRENT_DEVICE)
    # Only the encoder is needed to compute sentence representations.
    nmt_model = nmt_model.encoder
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('Begin training...')
    nmt_model.eval()

    # Compute the representation of every validation sentence: mean pooling
    # of the encoder outputs over the non-masked positions.
    valid_iter = valid_iterator.build_generator()
    all_seq_numbers = []
    pooled_batches = []

    for batch in valid_iter:
        seq_numbers, seqs_x, _ = batch
        x = prepare_data(seqs_x, seqs_y=None, cuda=Constants.USE_GPU, bt_attrib=None)
        try:
            with torch.no_grad():
                encoder_hidden, mask = nmt_model(x)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise
            print('| WARNING: ran out of memory, skipping batch')
            # Really skip the batch: without this the code below would use
            # undefined (first batch) or stale (later batches) encoder outputs.
            continue

        # Positions where ``mask`` is False are the ones counted as valid.
        # ``== False`` is kept (instead of ``~``) so that both bool and byte
        # masks work. Presumably encoder_hidden is (batch, seq, dim) and mask
        # is (batch, seq) -- TODO confirm against the encoder.
        valid = (mask == False)
        # Follow the encoder output's device instead of forcing CUDA so the
        # function also works when flags.use_gpu is false.
        valid_hidden = valid.float().to(encoder_hidden.device)
        sum_encoder_hidden = (encoder_hidden * valid_hidden.unsqueeze(-1)).sum(dim=1)
        valid_tokens = valid.sum(-1)
        mean_encoder_hidden = sum_encoder_hidden.float().div(valid_tokens.unsqueeze(1))

        # Record the numbers only for batches that were actually encoded so
        # they stay aligned with the rows of the returned tensor.
        all_seq_numbers.extend(seq_numbers)
        pooled_batches.append(mean_encoder_hidden)

    # Concatenate once at the end instead of re-copying the growing tensor on
    # every batch.
    all_mean_encoder_hidden = torch.cat(pooled_batches, dim=0) if pooled_batches else None
    return all_mean_encoder_hidden, all_seq_numbers
Example #6
0
    def __init__(self,
                 n_src_vocab,
                 n_tgt_vocab,
                 n_layers=6,
                 n_head=8,
                 d_word_vec=512,
                 d_model=512,
                 d_inner_hid=1024,
                 dim_per_head=None,
                 dropout=0.1,
                 tie_input_output_embedding=True,
                 tie_source_target_embedding=False,
                 padding_idx=PAD,
                 layer_norm_first=True,
                 positional_embedding="sin",
                 generator_bias=False,
                 ffn_activation="relu",
                 vocab_src=None,
                 **kwargs):
        """Assemble the encoder, decoder and generator of the model.

        Args:
            n_src_vocab: source vocabulary size, passed to the encoder.
            n_tgt_vocab: target vocabulary size, passed to the decoder and
                the generator.
            n_layers, n_head, d_word_vec, d_model, d_inner_hid, dim_per_head,
            dropout, padding_idx, layer_norm_first, positional_embedding,
            ffn_activation: forwarded unchanged to both the encoder and the
                decoder. ``d_model`` must equal ``d_word_vec``.
            tie_input_output_embedding: when true, the generator is built
                with the decoder embedding weight as ``shared_weight``.
            tie_source_target_embedding: when true, the encoder embedding
                weight is replaced by the decoder embedding weight; the two
                vocabulary sizes must then be equal.
            generator_bias: forwarded to the generator as ``add_bias``.
            vocab_src: stored as ``self.bpe_vocab``.
            **kwargs: must contain ``'char_vocab'``, a mapping of keyword
                arguments for ``Vocabulary.build_from_file``. Other entries
                are ignored here.
        """
        super(Transformer_Char, self).__init__()

        # Character-level vocabulary; its size is handed to the encoder.
        self.char_vocab = Vocabulary.build_from_file(**kwargs['char_vocab'])

        # Hyper-parameters that the encoder and the decoder receive verbatim.
        stack_options = dict(
            n_layers=n_layers,
            n_head=n_head,
            d_word_vec=d_word_vec,
            d_model=d_model,
            d_inner_hid=d_inner_hid,
            dropout=dropout,
            dim_per_head=dim_per_head,
            padding_idx=padding_idx,
            layer_norm_first=layer_norm_first,
            positional_embedding=positional_embedding,
            ffn_activation=ffn_activation,
        )

        self.encoder = Encoder(n_src_vocab,
                               char_src_vocab=self.char_vocab.max_n_words,
                               **stack_options)
        self.decoder = Decoder(n_tgt_vocab, **stack_options)

        self.dropout = nn.Dropout(dropout)

        assert d_model == d_word_vec, \
            'To facilitate the residual connections, \
             the dimensions of all module output shall be the same.'

        if tie_source_target_embedding:
            assert n_src_vocab == n_tgt_vocab, \
                "source and target vocabulary should have equal size when tying source&target embedding"
            self.encoder.embeddings.embeddings.weight = self.decoder.embeddings.embeddings.weight

        # The generator always uses the module-level PAD constant rather than
        # the ``padding_idx`` argument.
        generator_options = dict(n_words=n_tgt_vocab,
                                 hidden_size=d_word_vec,
                                 padding_idx=PAD,
                                 add_bias=generator_bias)
        if tie_input_output_embedding:
            # Share the output projection with the decoder input embedding.
            generator_options['shared_weight'] = self.decoder.embeddings.embeddings.weight
        self.generator = Generator(**generator_options)

        self.bpe_vocab = vocab_src
Example #7
0
def tune(flags):
    """Fine-tune an NMT model and write its weights to disk after every epoch.

    The function builds the vocabularies, the training iterator, the model,
    the criterion, the optimizer, the learning-rate scheduler and the
    optional moving average from the configuration files, optionally loads
    pretrained weights / the latest checkpoint, freezes parameters according
    to ``flags.froze_config`` and then trains with gradient accumulation.
    No validation is run in this function (``valid_loss`` and
    ``best_valid_loss`` stay at ``inf`` and are only displayed), and the
    periodic checkpoint saving is commented out; the only file written by
    the loop is ``<saveto>/<model_name><MY_BEST_MODEL_SUFFIX>.final``,
    overwritten at the end of each epoch.

    flags (attributes read by this function):
        use_gpu: bool, run on CUDA when true
        multi_gpu: bool, enable distributed training through ``dist``
        shared_dir: passed to ``dist.distributed_init`` (multi_gpu only)
        log_path: str, directory of the log file and the summary writer
        config_path: str
        predefined_config: passed to ``prepare_configs``
        saveto: str, output directory of the savers and the final model
        model_name: str
        pretrain_path: str, default=""
        pretrain_exclude_prefix: passed to ``load_pretrained_model``
        froze_config: passed to ``froze_params``
        reload: store_true, resume from the latest checkpoint
    """

    # ================================================================================== #
    # Initialization for training on different devices
    # - CPU/GPU
    # - Single/Distributed
    Constants.USE_GPU = flags.use_gpu

    if flags.multi_gpu:
        dist.distributed_init(flags.shared_dir)
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        local_rank = dist.get_local_rank()
    else:
        world_size = 1
        rank = 0
        local_rank = 0

    if Constants.USE_GPU:
        torch.cuda.set_device(local_rank)
        Constants.CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        Constants.CURRENT_DEVICE = "cpu"

    # Only the root rank writes the training log to a (timestamped) file;
    # every other rank closes logging.
    if rank == 0:
        write_log_to_file(
            os.path.join(flags.log_path,
                         "%s.log" % time.strftime("%Y%m%d-%H%M%S")))
    else:
        close_logging()

    # ================================================================================== #
    # Parsing configuration files
    # - Load default settings
    # - Load pre-defined settings
    # - Load user-defined settings

    configs = prepare_configs(flags.config_path, flags.predefined_config)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    INFO(pretty_configs(configs))

    Constants.SEED = training_configs['seed']

    set_seed(Constants.SEED)

    timer = Timer()

    # ================================================================================== #
    # Load Data
    INFO('Loading data...')
    timer.tic()

    # Load the source and target vocabularies
    vocab_src = Vocabulary.build_from_file(**data_configs['vocabularies'][0])
    vocab_tgt = Vocabulary.build_from_file(**data_configs['vocabularies'][1])

    # The special-token ids are taken from the source vocabulary.
    Constants.EOS = vocab_src.eos
    Constants.PAD = vocab_src.pad
    Constants.BOS = vocab_src.bos
    # Parallel training corpus: (source, target) line datasets zipped together
    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        is_train_dataset=True),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        is_train_dataset=True))

    training_iterator = DataIterator(
        dataset=train_bitext_dataset,
        batch_size=training_configs["batch_size"],
        use_bucket=training_configs['use_bucket'],
        buffer_size=training_configs['buffer_size'],
        batching_func=training_configs['batching_key'],
        world_size=world_size,
        rank=rank)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We would do the steps below one after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial

    lrate = optimizer_configs['learning_rate']
    model_collections = Collections()

    # checkpoint_saver is only used below to reload the latest checkpoint.
    checkpoint_saver = Saver(
        save_prefix="{0}.ckpt".format(
            os.path.join(flags.saveto, flags.model_name)),
        num_max_keeping=training_configs['num_kept_checkpoints'])
    best_model_prefix = os.path.join(
        flags.saveto, flags.model_name + Constants.MY_BEST_MODEL_SUFFIX)
    # NOTE(review): best_model_saver is constructed but not used again in
    # this function.
    best_model_saver = Saver(
        save_prefix=best_model_prefix,
        num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words,
                            padding_idx=vocab_src.pad,
                            vocab_src=vocab_src,
                            vocab_tgt=vocab_tgt,
                            **model_configs)
    INFO(nmt_model)

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'],
                          padding_idx=vocab_tgt.pad)

    INFO(critic)

    # 2. Move to GPU
    if Constants.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model,
                          flags.pretrain_path,
                          exclude_prefix=flags.pretrain_exclude_prefix,
                          device=Constants.CURRENT_DEVICE)
    # Freeze the parameters selected by flags.froze_config
    froze_params(nmt_model, flags.froze_config)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 4. Build optimizer
    INFO('Building Optimizer...')

    if not flags.multi_gpu:
        optim = Optimizer(name=optimizer_configs['optimizer'],
                          model=nmt_model,
                          lr=lrate,
                          grad_clip=optimizer_configs['grad_clip'],
                          optim_args=optimizer_configs['optimizer_params'],
                          update_cycle=training_configs['update_cycle'])
    else:
        optim = dist.DistributedOptimizer(
            name=optimizer_configs['optimizer'],
            model=nmt_model,
            lr=lrate,
            grad_clip=optimizer_configs['grad_clip'],
            optim_args=optimizer_configs['optimizer_params'],
            device_id=local_rank)

    # 5. Build scheduler for optimizer if needed
    scheduler = build_scheduler(
        schedule_method=optimizer_configs['schedule_method'],
        optimizer=optim,
        scheduler_configs=optimizer_configs['scheduler_configs'])

    # 6. build moving average
    if training_configs['moving_average_method'] is not None:
        ma = MovingAverage(
            moving_average_method=training_configs['moving_average_method'],
            named_params=nmt_model.named_parameters(),
            alpha=training_configs['moving_average_alpha'])
    else:
        ma = None

    # NOTE(review): timer.tic() was last called before building the model,
    # so this elapsed time also covers the model-building step.
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if flags.reload:
        checkpoint_saver.load_latest(model=nmt_model,
                                     optim=optim,
                                     lr_scheduler=scheduler,
                                     collections=model_collections,
                                     ma=ma,
                                     device=Constants.CURRENT_DEVICE)

    # broadcast parameters and optimizer states
    if world_size > 1:
        INFO("Broadcasting model parameters...")
        dist.broadcast_parameters(params=nmt_model.state_dict())
        INFO("Broadcasting optimizer states...")
        dist.broadcast_optimizer_state(optimizer=optim.optim)
        INFO('Done.')

    # ================================================================================== #
    # Prepare training

    # Restore the training counters from the (possibly reloaded) collections;
    # the second argument is the default used when nothing was stored.
    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [1])[-1]
    # NOTE(review): bad_count and is_early_stop are read here but only
    # referenced by the commented-out checkpoint code further down.
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]
    is_early_stop = model_collections.get_collection("is_early_stop", [
        False,
    ])[-1]

    train_loss_meter = AverageMeter()
    sent_per_sec_meter = TimeMeter()
    tok_per_sec_meter = TimeMeter()

    # Accumulators for one parameter update (update_cycle forward passes).
    update_cycle = training_configs['update_cycle']
    grad_denom = 0
    train_loss = 0.0
    cum_n_words = 0
    # Never reassigned in this function: no validation is performed here.
    valid_loss = best_valid_loss = float('inf')

    # Only the root rank writes summaries.
    if rank == 0:
        summary_writer = SummaryWriter(log_dir=flags.log_path)
    else:
        summary_writer = None

    sent_per_sec_meter.start()
    tok_per_sec_meter.start()

    INFO('Begin training...')

    # One iteration of this loop is one epoch.
    while True:

        if summary_writer is not None:
            summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()

        if rank == 0:
            training_progress_bar = tqdm(desc=' - (Epc {}, Upd {}) '.format(
                eidx, uidx),
                                         total=len(training_iterator),
                                         unit="sents")
        else:
            training_progress_bar = None
        for batch in training_iter:
            # Each batch is a pair of (source sequences, target sequences)
            seqs_x, seqs_y = batch

            batch_size = len(seqs_x)
            cum_n_words += sum(len(s) for s in seqs_y)

            try:
                # Prepare data
                x, y = prepare_data(seqs_x, seqs_y, cuda=Constants.USE_GPU)

                loss = compute_forward(
                    model=nmt_model,
                    critic=critic,
                    seqs_x=x,
                    seqs_y=y,
                    eval=False,
                    normalization=1.0,
                    norm_by_words=training_configs["norm_by_words"])

                update_cycle -= 1
                grad_denom += batch_size
                train_loss += loss

            except RuntimeError as e:
                # An out-of-memory batch is counted and dropped; since
                # update_cycle was not decremented, the update below is not
                # triggered by this batch. Any other error is re-raised.
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                else:
                    raise e

            # When update_cycle becomes 0, it means end of one batch. Several things will be done:
            # - update parameters
            # - reset update_cycle and grad_denom, update uidx
            # - learning rate scheduling
            # - update moving average

            if update_cycle == 0:

                # 0. reduce variables
                if world_size > 1:
                    grad_denom = dist.all_reduce_py(grad_denom)
                    train_loss = dist.all_reduce_py(train_loss)
                    cum_n_words = dist.all_reduce_py(cum_n_words)

                # 1. update parameters
                optim.step(denom=grad_denom)
                optim.zero_grad()

                if training_progress_bar is not None:
                    training_progress_bar.update(grad_denom)
                    training_progress_bar.set_description(
                        ' - (Epc {}, Upd {}) '.format(eidx, uidx))

                    postfix_str = 'TrainLoss: {:.2f}, ValidLoss(best): {:.2f} ({:.2f}), '.format(
                        train_loss, valid_loss, best_valid_loss)
                    training_progress_bar.set_postfix_str(postfix_str)

                # 2. learning rate scheduling
                # (skipped when the schedule method is "loss")
                if scheduler is not None and optimizer_configs[
                        "schedule_method"] != "loss":
                    scheduler.step(global_step=uidx)

                # 3. update moving average
                if ma is not None and eidx >= training_configs[
                        'moving_average_start_epoch']:
                    ma.step()

                # 4. update meters
                train_loss_meter.update(train_loss, grad_denom)
                sent_per_sec_meter.update(grad_denom)
                tok_per_sec_meter.update(cum_n_words)

                # 5. reset accumulated variables, update uidx
                update_cycle = training_configs['update_cycle']
                grad_denom = 0
                uidx += 1
                cum_n_words = 0.0
                train_loss = 0.0

            else:
                # No parameter update happened for this batch, so there is
                # nothing new to display yet.
                continue

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(
                    uidx, eidx, every_n_step=training_configs['disp_freq']):

                lrate = list(optim.get_lrate())[0]

                if summary_writer is not None:
                    summary_writer.add_scalar(
                        "Speed(sents/sec)",
                        scalar_value=sent_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "Speed(words/sec)",
                        scalar_value=tok_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "train_loss",
                        scalar_value=train_loss_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar("lrate",
                                              scalar_value=lrate,
                                              global_step=uidx)
                    summary_writer.add_scalar("oom_count",
                                              scalar_value=oom_count,
                                              global_step=uidx)

                # Reset Meters
                sent_per_sec_meter.reset()
                tok_per_sec_meter.reset()
                train_loss_meter.reset()

            # ================================================================================== #
            # Saving checkpoints (currently disabled)
            # if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['save_freq'], debug=flags.debug):
            #     model_collections.add_to_collection("uidx", uidx)
            #     model_collections.add_to_collection("eidx", eidx)
            #     model_collections.add_to_collection("bad_count", bad_count)
            #
            #     if not is_early_stop:
            #         if rank == 0:
            #             checkpoint_saver.save(global_step=uidx,
            #                                   model=nmt_model,
            #                                   optim=optim,
            #                                   lr_scheduler=scheduler,
            #                                   collections=model_collections,
            #                                   ma=ma)

        # End of epoch: overwrite the ".final" file with the current weights.
        # NOTE(review): this runs on every rank, not only rank 0 -- confirm
        # that concurrent writes to the same path are acceptable when
        # multi_gpu is enabled.
        torch.save(nmt_model.state_dict(), best_model_prefix + ".final")

        if training_progress_bar is not None:
            training_progress_bar.close()

        # NOTE(review): the check happens after the increment and uses ">",
        # so starting from eidx == 0 the loop runs max_epochs + 1 epochs --
        # confirm this is intended.
        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
Example #8
0
def train(flags):
    """
    flags:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
    """

    # ================================================================================== #
    # Initialization for training on different devices
    # - CPU/GPU
    # - Single/Distributed
    Constants.USE_GPU = flags.use_gpu

    if flags.multi_gpu:
        dist.distributed_init(flags.shared_dir)
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        local_rank = dist.get_local_rank()
    else:
        world_size = 1
        rank = 0
        local_rank = 0

    if Constants.USE_GPU:
        torch.cuda.set_device(local_rank)
        Constants.CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        Constants.CURRENT_DEVICE = "cpu"

    # If not root_rank, close logging
    # else write log of training to file.
    if rank == 0:
        write_log_to_file(
            os.path.join(flags.log_path,
                         "%s.log" % time.strftime("%Y%m%d-%H%M%S")))
    else:
        close_logging()

    # ================================================================================== #
    # Parsing configuration files
    # - Load default settings
    # - Load pre-defined settings
    # - Load user-defined settings

    configs = prepare_configs(flags.config_path, flags.predefined_config)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    INFO(pretty_configs(configs))

    # use odc
    if training_configs['use_odc'] is True:
        ave_best_k = check_odc_config(training_configs)
    else:
        ave_best_k = 0

    Constants.SEED = training_configs['seed']

    set_seed(Constants.SEED)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_src = Vocabulary.build_from_file(**data_configs['vocabularies'][0])
    vocab_tgt = Vocabulary.build_from_file(**data_configs['vocabularies'][1])

    Constants.EOS = vocab_src.eos
    Constants.PAD = vocab_src.pad
    Constants.BOS = vocab_src.bos

    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        is_train_dataset=True),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        is_train_dataset=True))

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(
            data_path=data_configs['valid_data'][0],
            vocabulary=vocab_src,
            is_train_dataset=False,
        ),
        TextLineDataset(data_path=data_configs['valid_data'][1],
                        vocabulary=vocab_tgt,
                        is_train_dataset=False))

    training_iterator = DataIterator(
        dataset=train_bitext_dataset,
        batch_size=training_configs["batch_size"],
        use_bucket=training_configs['use_bucket'],
        buffer_size=training_configs['buffer_size'],
        batching_func=training_configs['batching_key'],
        world_size=world_size,
        rank=rank)

    valid_iterator = DataIterator(
        dataset=valid_bitext_dataset,
        batch_size=training_configs['valid_batch_size'],
        use_bucket=True,
        buffer_size=100000,
        numbering=True,
        world_size=world_size,
        rank=rank)

    bleu_scorer = SacreBLEUScorer(
        reference_path=data_configs["bleu_valid_reference"],
        num_refs=data_configs["num_refs"],
        lang_pair=data_configs["lang_pair"],
        sacrebleu_args=training_configs["bleu_valid_configs"]
        ['sacrebleu_args'],
        postprocess=training_configs["bleu_valid_configs"]['postprocess'])

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We would do steps below on after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial

    lrate = optimizer_configs['learning_rate']
    model_collections = Collections()

    checkpoint_saver = Saver(
        save_prefix="{0}.ckpt".format(
            os.path.join(flags.saveto, flags.model_name)),
        num_max_keeping=training_configs['num_kept_checkpoints'])

    best_model_prefix = os.path.join(
        flags.saveto, flags.model_name + Constants.MY_BEST_MODEL_SUFFIX)

    best_k_saver = BestKSaver(
        save_prefix="{0}.best_k_ckpt".format(
            os.path.join(flags.saveto, flags.model_name)),
        num_max_keeping=training_configs['num_kept_best_k_checkpoints'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words,
                            padding_idx=vocab_src.pad,
                            vocab_src=vocab_src,
                            **model_configs)
    INFO(nmt_model)

    # build teacher model
    teacher_model, teacher_model_path = get_teacher_model(
        training_configs, model_configs, vocab_src, vocab_tgt, flags)

    # build critic
    critic = CombinationCriterion(model_configs['loss_configs'],
                                  padding_idx=vocab_tgt.pad,
                                  teacher=teacher_model)
    # INFO(critic)
    critic.INFO()

    # 2. Move to GPU
    if Constants.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model,
                          flags.pretrain_path,
                          exclude_prefix=None,
                          device=Constants.CURRENT_DEVICE)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 4. Build optimizer
    INFO('Building Optimizer...')

    if not flags.multi_gpu:
        optim = Optimizer(name=optimizer_configs['optimizer'],
                          model=nmt_model,
                          lr=lrate,
                          grad_clip=optimizer_configs['grad_clip'],
                          optim_args=optimizer_configs['optimizer_params'],
                          update_cycle=training_configs['update_cycle'])
    else:
        optim = dist.DistributedOptimizer(
            name=optimizer_configs['optimizer'],
            model=nmt_model,
            lr=lrate,
            grad_clip=optimizer_configs['grad_clip'],
            optim_args=optimizer_configs['optimizer_params'],
            device_id=local_rank)

    # 5. Build scheduler for optimizer if needed
    scheduler = build_scheduler(
        schedule_method=optimizer_configs['schedule_method'],
        optimizer=optim,
        scheduler_configs=optimizer_configs['scheduler_configs'])

    # 6. build moving average
    ma = build_ma(training_configs, nmt_model.named_parameters())

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if flags.reload:
        checkpoint_saver.load_latest(model=nmt_model,
                                     optim=optim,
                                     lr_scheduler=scheduler,
                                     collections=model_collections,
                                     ma=ma,
                                     device=Constants.CURRENT_DEVICE)

    # broadcast parameters and optimizer states
    if world_size > 1:
        INFO("Broadcasting model parameters...")
        dist.broadcast_parameters(params=nmt_model.state_dict())
        INFO("Broadcasting optimizer states...")
        dist.broadcast_optimizer_state(optimizer=optim.optim)
        INFO('Done.')

    # ================================================================================== #
    # Prepare training

    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [1])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]
    is_early_stop = model_collections.get_collection("is_early_stop", [
        False,
    ])[-1]
    teacher_patience = model_collections.get_collection(
        "teacher_patience", [training_configs['teacher_patience']])[-1]

    train_loss_meter = AverageMeter()
    train_loss_dict_meter = AverageMeterDict(critic.get_critic_name())
    sent_per_sec_meter = TimeMeter()
    tok_per_sec_meter = TimeMeter()

    update_cycle = training_configs['update_cycle']
    grad_denom = 0
    train_loss = 0.0
    cum_n_words = 0
    train_loss_dict = dict()
    valid_loss = best_valid_loss = float('inf')

    if rank == 0:
        summary_writer = SummaryWriter(log_dir=flags.log_path)
    else:
        summary_writer = None

    sent_per_sec_meter.start()
    tok_per_sec_meter.start()

    INFO('Begin training...')

    while True:

        if summary_writer is not None:
            summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()

        if rank == 0:
            training_progress_bar = tqdm(desc=' - (Epc {}, Upd {}) '.format(
                eidx, uidx),
                                         total=len(training_iterator),
                                         unit="sents")
        else:
            training_progress_bar = None

        for batch in training_iter:

            seqs_x, seqs_y = batch

            batch_size = len(seqs_x)
            cum_n_words += sum(len(s) for s in seqs_y)

            try:
                # Prepare data
                x, y = prepare_data(seqs_x, seqs_y, cuda=Constants.USE_GPU)

                loss, loss_dict = compute_forward(
                    model=nmt_model,
                    critic=critic,
                    seqs_x=x,
                    seqs_y=y,
                    eval=False,
                    normalization=1.0,
                    norm_by_words=training_configs["norm_by_words"])

                update_cycle -= 1
                grad_denom += batch_size
                train_loss += loss
                train_loss_dict = add_dict_value(train_loss_dict, loss_dict)

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                else:
                    raise e

            # When update_cycle reaches 0, one batch is complete. The following steps are then performed:
            # - update parameters
            # - reset update_cycle and grad_denom, update uidx
            # - learning rate scheduling
            # - update moving average

            if update_cycle == 0:

                # 0. reduce variables
                if world_size > 1:
                    grad_denom = dist.all_reduce_py(grad_denom)
                    train_loss = dist.all_reduce_py(train_loss)
                    train_loss_dict = dist.all_reduce_py(train_loss_dict)
                    cum_n_words = dist.all_reduce_py(cum_n_words)

                # 1. update parameters
                optim.step(denom=grad_denom)
                optim.zero_grad()

                if training_progress_bar is not None:
                    training_progress_bar.update(grad_denom)
                    training_progress_bar.set_description(
                        ' - (Epc {}, Upd {}) '.format(eidx, uidx))

                    postfix_str = 'TrainLoss: {:.2f}, ValidLoss(best): {:.2f} ({:.2f}), '.format(
                        train_loss, valid_loss, best_valid_loss)
                    for critic_name, loss_value in train_loss_dict.items():
                        postfix_str += (critic_name +
                                        ': {:.2f}, ').format(loss_value)
                    training_progress_bar.set_postfix_str(postfix_str)

                # 2. learning rate scheduling
                if scheduler is not None and optimizer_configs[
                        "schedule_method"] != "loss":
                    scheduler.step(global_step=uidx)

                # 3. update moving average
                if ma is not None and eidx >= training_configs[
                        'moving_average_start_epoch']:
                    ma.step()

                # 4. update meters
                train_loss_meter.update(train_loss, grad_denom)
                train_loss_dict_meter.update(train_loss_dict, grad_denom)
                sent_per_sec_meter.update(grad_denom)
                tok_per_sec_meter.update(cum_n_words)

                # 5. reset accumulated variables, update uidx
                update_cycle = training_configs['update_cycle']
                grad_denom = 0
                uidx += 1
                cum_n_words = 0.0
                train_loss = 0.0
                train_loss_dict = dict()

            else:
                continue

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(
                    uidx, eidx, every_n_step=training_configs['disp_freq']):

                lrate = list(optim.get_lrate())[0]

                if summary_writer is not None:
                    summary_writer.add_scalar(
                        "Speed(sents/sec)",
                        scalar_value=sent_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "Speed(words/sec)",
                        scalar_value=tok_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "train_loss",
                        scalar_value=train_loss_meter.ave,
                        global_step=uidx)
                    # add loss for every critic
                    if flags.display_loss_detail:
                        combination_loss = train_loss_dict_meter.value
                        for key, value in combination_loss.items():
                            summary_writer.add_scalar(key,
                                                      scalar_value=value,
                                                      global_step=uidx)
                    summary_writer.add_scalar("lrate",
                                              scalar_value=lrate,
                                              global_step=uidx)
                    summary_writer.add_scalar("oom_count",
                                              scalar_value=oom_count,
                                              global_step=uidx)

                # Reset Meters
                sent_per_sec_meter.reset()
                tok_per_sec_meter.reset()
                train_loss_meter.reset()
                train_loss_dict_meter.reset()

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['loss_valid_freq'],
                    debug=flags.debug):
                with cache_parameters(nmt_model):

                    valid_loss, valid_loss_dict = loss_evaluation(
                        model=nmt_model,
                        critic=critic,
                        valid_iterator=valid_iterator,
                        rank=rank,
                        world_size=world_size)

                if scheduler is not None and optimizer_configs[
                        "schedule_method"] == "loss":
                    scheduler.step(metric=valid_loss)

                model_collections.add_to_collection("history_losses",
                                                    valid_loss)

                min_history_loss = np.array(
                    model_collections.get_collection("history_losses")).min()
                best_valid_loss = min_history_loss

                if summary_writer is not None:
                    summary_writer.add_scalar("loss",
                                              valid_loss,
                                              global_step=uidx)
                    summary_writer.add_scalar("best_loss",
                                              min_history_loss,
                                              global_step=uidx)

            # ================================================================================== #
            # BLEU Validation & Early Stop
            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['bleu_valid_freq'],
                    min_step=training_configs['bleu_valid_warmup'],
                    debug=flags.debug):

                with cache_parameters(nmt_model):

                    valid_bleu = bleu_evaluation(
                        uidx=uidx,
                        valid_iterator=valid_iterator,
                        batch_size=training_configs["bleu_valid_batch_size"],
                        model=nmt_model,
                        bleu_scorer=bleu_scorer,
                        vocab_src=vocab_src,
                        vocab_tgt=vocab_tgt,
                        valid_dir=flags.valid_path,
                        max_steps=training_configs["bleu_valid_configs"]
                        ["max_steps"],
                        beam_size=training_configs["bleu_valid_configs"]
                        ["beam_size"],
                        alpha=training_configs["bleu_valid_configs"]["alpha"],
                        world_size=world_size,
                        rank=rank,
                    )

                model_collections.add_to_collection(key="history_bleus",
                                                    value=valid_bleu)

                best_valid_bleu = float(
                    np.array(model_collections.get_collection(
                        "history_bleus")).max())

                if summary_writer is not None:
                    summary_writer.add_scalar("bleu", valid_bleu, uidx)
                    summary_writer.add_scalar("best_bleu", best_valid_bleu,
                                              uidx)

                # If the model achieves a new best validation BLEU score
                if valid_bleu >= best_valid_bleu:
                    bad_count = 0

                    if is_early_stop is False:
                        if rank == 0:
                            # 1. save the best model
                            torch.save(nmt_model.state_dict(),
                                       best_model_prefix + ".final")

                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs[
                            'early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")
                        exit(0)

                if rank == 0:
                    best_k_saver.save(global_step=uidx,
                                      metric=valid_bleu,
                                      model=nmt_model,
                                      optim=optim,
                                      lr_scheduler=scheduler,
                                      collections=model_collections,
                                      ma=ma)

                # ODC
                if training_configs['use_odc'] is True:
                    if valid_bleu >= best_valid_bleu:
                        pass

                        # choose method to generate teachers from checkpoints
                        # - best
                        # - ave_k_best
                        # - ma

                        if training_configs['teacher_choice'] == 'ma':
                            teacher_params = ma.export_ma_params()
                        elif training_configs['teacher_choice'] == 'best':
                            teacher_params = nmt_model.state_dict()
                        elif "ave_best" in training_configs['teacher_choice']:
                            if best_k_saver.num_saved >= ave_best_k:
                                teacher_params = average_checkpoints(
                                    best_k_saver.get_all_ckpt_path()
                                    [-ave_best_k:])
                            else:
                                teacher_params = nmt_model.state_dict()
                        else:
                            raise ValueError(
                                "can not support teacher choice %s" %
                                training_configs['teacher_choice'])
                        torch.save(teacher_params, teacher_model_path)
                        del teacher_params
                        teacher_patience = 0
                        critic.set_use_KD(False)
                    else:
                        teacher_patience += 1
                        if teacher_patience >= training_configs[
                                'teacher_refresh_warmup']:
                            teacher_params = torch.load(
                                teacher_model_path,
                                map_location=Constants.CURRENT_DEVICE)
                            teacher_model.load_state_dict(teacher_params,
                                                          strict=False)
                            del teacher_params
                            critic.reset_teacher(teacher_model)
                            critic.set_use_KD(True)

                if summary_writer is not None:
                    summary_writer.add_scalar("bad_count", bad_count, uidx)

                info_str = "{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4} ".format(
                    uidx, valid_loss, valid_bleu, lrate, bad_count)
                for key, value in valid_loss_dict.items():
                    info_str += (key + ': {0:.2f} '.format(value))
                INFO(info_str)

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(
                    uidx,
                    eidx,
                    every_n_step=training_configs['save_freq'],
                    debug=flags.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)
                model_collections.add_to_collection("teacher_patience",
                                                    teacher_patience)
                if not is_early_stop:
                    if rank == 0:
                        checkpoint_saver.save(global_step=uidx,
                                              model=nmt_model,
                                              optim=optim,
                                              lr_scheduler=scheduler,
                                              collections=model_collections,
                                              ma=ma)

        if training_progress_bar is not None:
            training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break