Example #1
def train(FLAGS):
    """
    FLAGS:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
    """

    # write log of training to file.
    write_log_to_file(
        os.path.join(FLAGS.log_path,
                     "%s.log" % time.strftime("%Y%m%d-%H%M%S")))

    GlobalNames.USE_GPU = FLAGS.use_gpu

    if GlobalNames.USE_GPU:
        CURRENT_DEVICE = "cuda:0"
    else:
        CURRENT_DEVICE = "cpu"

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        configs = yaml.load(f)

    INFO(pretty_configs(configs))

    # Add default configs
    configs = default_configs(configs)
    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    GlobalNames.SEED = training_configs['seed']

    set_seed(GlobalNames.SEED)

    best_model_prefix = os.path.join(
        FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][0])

    train_batch_size = training_configs["batch_size"] * max(
        1, training_configs["update_cycle"])
    train_buffer_size = training_configs["buffer_size"] * max(
        1, training_configs["update_cycle"])

    train_bitext_dataset = ZipDataset(TextLineDataset(
        data_path=data_configs['train_data'][0],
        vocabulary=vocab_tgt,
        max_len=data_configs['max_len'][0],
    ),
                                      shuffle=training_configs['shuffle'])

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(
            data_path=data_configs['valid_data'][0],
            vocabulary=vocab_tgt,
        ))

    training_iterator = DataIterator(
        dataset=train_bitext_dataset,
        batch_size=train_batch_size,
        use_bucket=training_configs['use_bucket'],
        buffer_size=train_buffer_size,
        batching_func=training_configs['batching_key'])

    valid_iterator = DataIterator(
        dataset=valid_bitext_dataset,
        batch_size=training_configs['valid_batch_size'],
        use_bucket=True,
        buffer_size=100000,
        numbering=True)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We will do the steps below one after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    model_collections = Collections()
    checkpoint_saver = Saver(
        save_prefix="{0}.ckpt".format(
            os.path.join(FLAGS.saveto, FLAGS.model_name)),
        num_max_keeping=training_configs['num_kept_checkpoints'])
    best_model_saver = Saver(
        save_prefix=best_model_prefix,
        num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    lm_model = build_model(n_tgt_vocab=vocab_tgt.max_n_words, **model_configs)
    INFO(lm_model)

    params_total = sum([p.numel() for n, p in lm_model.named_parameters()])
    params_without_embedding = sum([
        p.numel() for n, p in lm_model.named_parameters()
        if n.find('embedding') == -1
    ])
    INFO('Total parameters: {}'.format(params_total))
    INFO('Total parameters (excluding word embeddings): {}'.format(
        params_without_embedding))

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])

    INFO(critic)
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        lm_model = lm_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    lm_model.init_parameters(FLAGS.pretrain_path, device=CURRENT_DEVICE)

    # 4. Build optimizer
    INFO('Building Optimizer...')
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=lm_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params'])

    # 5. Build scheduler for optimizer if needed
    if optimizer_configs['schedule_method'] is not None:

        if optimizer_configs['schedule_method'] == "loss":

            scheduler = ReduceOnPlateauScheduler(
                optimizer=optim, **optimizer_configs["scheduler_configs"])

        elif optimizer_configs['schedule_method'] == "noam":
            scheduler = NoamScheduler(optimizer=optim,
                                      **optimizer_configs['scheduler_configs'])
        else:
            WARN(
                "Unknown scheduler name {0}. Do not use lr_scheduling.".format(
                    optimizer_configs['schedule_method']))
            scheduler = None
    else:
        scheduler = None

    # 6. build moving average

    if training_configs['moving_average_method'] is not None:
        ma = MovingAverage(
            moving_average_method=training_configs['moving_average_method'],
            named_params=lm_model.named_parameters(),
            alpha=training_configs['moving_average_alpha'])
    else:
        ma = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if FLAGS.reload:
        checkpoint_saver.load_latest(model=lm_model,
                                     optim=optim,
                                     lr_scheduler=scheduler,
                                     collections=model_collections,
                                     ma=ma)

    # ================================================================================== #
    # Prepare training

    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [0])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]

    summary_writer = SummaryWriter(log_dir=FLAGS.log_path)

    cum_samples = 0
    cum_words = 0
    valid_loss = best_valid_loss = float('inf')  # Max Float
    saving_files = []

    # Timer for computing speed
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    while True:
        summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()
        training_progress_bar = tqdm(desc=' - (Epc {}, Upd {}) '.format(
            eidx, uidx),
                                     total=len(training_iterator),
                                     unit="sents")
        for batch in training_iter:

            uidx += 1

            if optimizer_configs[
                    "schedule_method"] is not None and optimizer_configs[
                        "schedule_method"] != "loss":
                scheduler.step(global_step=uidx)

            seqs_y = batch

            n_samples_t = len(seqs_y)
            n_words_t = sum(len(s) for s in seqs_y)

            cum_samples += n_samples_t
            cum_words += n_words_t

            train_loss = 0.
            optim.zero_grad()
            try:
                # Prepare data
                for (seqs_y_t, ) in split_shard(
                        seqs_y, split_size=training_configs['update_cycle']):
                    y = prepare_data(seqs_y_t, cuda=GlobalNames.USE_GPU)

                    loss = compute_forward(
                        model=lm_model,
                        critic=critic,
                        # seqs_x=x,
                        seqs_y=y,
                        eval=False,
                        normalization=n_samples_t,
                        norm_by_words=training_configs["norm_by_words"])
                    train_loss += loss / y.size(
                        1) if not training_configs["norm_by_words"] else loss
                optim.step()

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                    optim.zero_grad()
                else:
                    raise e

            if ma is not None and eidx >= training_configs[
                    'moving_average_start_epoch']:
                ma.step()

            training_progress_bar.update(n_samples_t)
            training_progress_bar.set_description(
                ' - (Epc {}, Upd {}) '.format(eidx, uidx))
            training_progress_bar.set_postfix_str(
                'TrainLoss: {:.2f}, ValidLoss(best): {:.2f} ({:.2f})'.format(
                    train_loss, valid_loss, best_valid_loss))
            summary_writer.add_scalar("train_loss",
                                      scalar_value=train_loss,
                                      global_step=uidx)

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(
                    uidx, eidx, every_n_step=training_configs['disp_freq']):
                # words per second and sents per second
                words_per_sec = cum_words / (timer.toc(return_seconds=True))
                sents_per_sec = cum_samples / (timer.toc(return_seconds=True))
                lrate = list(optim.get_lrate())[0]

                summary_writer.add_scalar("Speed(words/sec)",
                                          scalar_value=words_per_sec,
                                          global_step=uidx)
                summary_writer.add_scalar("Speed(sents/sen)",
                                          scalar_value=sents_per_sec,
                                          global_step=uidx)
                summary_writer.add_scalar("lrate",
                                          scalar_value=lrate,
                                          global_step=uidx)
                summary_writer.add_scalar("oom_count",
                                          scalar_value=oom_count,
                                          global_step=uidx)

                # Reset timer
                timer.tic()
                cum_words = 0
                cum_samples = 0

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(
                    uidx,
                    eidx,
                    every_n_step=training_configs['save_freq'],
                    debug=FLAGS.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)

                if not is_early_stop:
                    checkpoint_saver.save(global_step=uidx,
                                          model=lm_model,
                                          optim=optim,
                                          lr_scheduler=scheduler,
                                          collections=model_collections,
                                          ma=ma)

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['loss_valid_freq'],
                    debug=FLAGS.debug):

                if ma is not None:
                    origin_state_dict = deepcopy(lm_model.state_dict())
                    lm_model.load_state_dict(ma.export_ma_params(),
                                             strict=False)

                valid_loss = loss_validation(
                    model=lm_model,
                    critic=critic,
                    valid_iterator=valid_iterator,
                    norm_by_words=training_configs["norm_by_words"])

                model_collections.add_to_collection("history_losses",
                                                    valid_loss)

                min_history_loss = np.array(
                    model_collections.get_collection("history_losses")).min()

                summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                summary_writer.add_scalar("best_loss",
                                          min_history_loss,
                                          global_step=uidx)

                if ma is not None:
                    lm_model.load_state_dict(origin_state_dict)
                    del origin_state_dict

                if optimizer_configs["schedule_method"] == "loss":
                    scheduler.step(metric=best_valid_loss)

                # If model get new best valid loss
                if valid_loss < best_valid_loss:
                    bad_count = 0

                    if is_early_stop is False:
                        # 1. save the best model
                        torch.save(lm_model.state_dict(),
                                   best_model_prefix + ".final")

                        # 2. keep track of the several best models
                        best_model_saver.save(global_step=uidx, model=lm_model)
                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs[
                            'early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")

                best_valid_loss = min_history_loss

                summary_writer.add_scalar("bad_count", bad_count, uidx)

                INFO("{0} Loss: {1:.2f} lrate: {2:6f} patience: {3}".format(
                    uidx, valid_loss, lrate, bad_count))

        training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
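
The YAML file loaded at the top of this example is only ever read through the four top-level sections unpacked above. A rough skeleton of that structure, inferred from the keys the snippet accesses (all values, and the Vocabulary keyword arguments, are illustrative placeholders rather than the project's defaults):

# Hypothetical config skeleton inferred from the accesses in train() above.
configs = {
    "data_configs": {
        "vocabularies": [{"dict_path": "vocab.tgt"}],  # kwargs forwarded to Vocabulary(**...)
        "train_data": ["train.tgt"],
        "valid_data": ["valid.tgt"],
        "max_len": [100],
    },
    "model_configs": {"label_smoothing": 0.1},
    "optimizer_configs": {
        "optimizer": "adam",
        "learning_rate": 1e-3,
        "grad_clip": -1.0,
        "optimizer_params": None,
        "schedule_method": "noam",
        "scheduler_configs": {"d_model": 512, "warmup_steps": 4000},
    },
    "training_configs": {
        "seed": 1234,
        "batch_size": 64,
        "update_cycle": 1,
        "buffer_size": 100000,
        "shuffle": True,
        "use_bucket": True,
        "batching_key": "samples",
        "valid_batch_size": 32,
        "norm_by_words": False,
        "moving_average_method": None,
        "moving_average_alpha": 0.0,
        "moving_average_start_epoch": 0,
        "num_kept_checkpoints": 1,
        "num_kept_best_model": 1,
        "disp_freq": 100,
        "save_freq": 1000,
        "loss_valid_freq": 1000,
        "early_stop_patience": 20,
        "max_epochs": 20,
    },
}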
Example #2
    def _init_local_optims(self, rephraser_optimizer_configs):
        """ actor, critic, alpha optimizers and lr scheduler if necessary
        rephraser_optimizer_configs:
            optimizer: "adafactor"
            learning_rate: 0.01
            grad_clip: -1.0
            optimizer_params: ~
            schedule_method: rsqrt
            scheduler_configs:
                d_model: *dim
                warmup_steps: 100
        """
        # initiate local optimizer
        if rephraser_optimizer_configs is None:
            self.actor_optimizer = None
            self.critic_optimizer = None
            self.log_alpha_optimizer = None
            # self.actor_icm_optimizer = None
            self.actor_scheduler = None
            self.critic_scheduler = None
        else:
            self.actor_optimizer = Optimizer(
                name=rephraser_optimizer_configs["optimizer"],
                model=self.actor,
                lr=rephraser_optimizer_configs["learning_rate"],
                grad_clip=rephraser_optimizer_configs["grad_clip"],
                optim_args=rephraser_optimizer_configs["optimizer_params"])
            self.critic_optimizer = Optimizer(
                name=rephraser_optimizer_configs["optimizer"],
                model=self.critic,
                lr=rephraser_optimizer_configs["learning_rate"],
                grad_clip=rephraser_optimizer_configs["grad_clip"],
                optim_args=rephraser_optimizer_configs["optimizer_params"])
            # hardcoded entropy weight updates and icm updates
            self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                        lr=1e-4,
                                                        betas=(0.9, 0.999))
            # self.actor_icm_optimizer = torch.optim.Adam(self.actor.icm.parameters(), lr=1e-3, )

            # Build scheduler for optimizer if needed
            if rephraser_optimizer_configs['schedule_method'] is not None:
                if rephraser_optimizer_configs['schedule_method'] == "loss":
                    self.actor_scheduler = ReduceOnPlateauScheduler(
                        optimizer=self.actor_optimizer,
                        **rephraser_optimizer_configs["scheduler_configs"])
                    self.critic_scheduler = ReduceOnPlateauScheduler(
                        optimizer=self.critic_optimizer,
                        **rephraser_optimizer_configs["scheduler_configs"])
                elif rephraser_optimizer_configs['schedule_method'] == "noam":
                    self.actor_scheduler = NoamScheduler(
                        optimizer=self.actor_optimizer,
                        **rephraser_optimizer_configs["scheduler_configs"])
                    self.critic_scheduler = NoamScheduler(
                        optimizer=self.critic_optimizer,
                        **rephraser_optimizer_configs["scheduler_configs"])
                elif rephraser_optimizer_configs["schedule_method"] == "rsqrt":
                    self.actor_scheduler = RsqrtScheduler(
                        optimizer=self.actor_optimizer,
                        **rephraser_optimizer_configs["scheduler_configs"])
                    self.critic_scheduler = RsqrtScheduler(
                        optimizer=self.critic_optimizer,
                        **rephraser_optimizer_configs["scheduler_configs"])
                else:
                    WARN(
                        "Unknown scheduler name {0}. Do not use lr_scheduling."
                        .format(
                            rephraser_optimizer_configs['schedule_method']))
                    self.actor_scheduler = None
                    self.critic_scheduler = None
            else:
                self.actor_scheduler = None
                self.critic_scheduler = None
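
The schedule_method dispatch above repeats the same if/elif chain for the actor and the critic optimizers (and again in the snippets below). A small helper sketch that states it once, reusing the scheduler classes and the WARN helper already referenced in these snippets:

def build_lr_scheduler(optimizer, optimizer_configs):
    """Return an lr scheduler for `optimizer`, or None if scheduling is disabled/unknown."""
    method = optimizer_configs.get("schedule_method")
    if method is None:
        return None
    scheduler_configs = optimizer_configs.get("scheduler_configs") or {}
    if method == "loss":
        return ReduceOnPlateauScheduler(optimizer=optimizer, **scheduler_configs)
    if method == "noam":
        return NoamScheduler(optimizer=optimizer, **scheduler_configs)
    if method == "rsqrt":
        return RsqrtScheduler(optimizer=optimizer, **scheduler_configs)
    WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(method))
    return None

With such a helper, _init_local_optims could assign self.actor_scheduler = build_lr_scheduler(self.actor_optimizer, rephraser_optimizer_configs) and likewise for the critic.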
Example #3
def run():
    # limit each actor process to a single OpenMP thread
    os.environ["OMP_NUM_THREADS"] = "1"
    mp = _mp.get_context('spawn')
    args = parser.parse_args()
    if not os.path.exists(args.save_to):
        os.mkdir(args.save_to)
    with open(args.config_path, "r") as f, \
            open(os.path.join(args.save_to, "current_attack_configs.yaml"), "w") as current_configs:
        configs = yaml.load(f)
        yaml.dump(configs, current_configs)
    attack_configs = configs["attack_configs"]
    attacker_configs = configs["attacker_configs"]
    attacker_model_configs = attacker_configs["attacker_model_configs"]
    attacker_optimizer_configs = attacker_configs["attacker_optimizer_configs"]
    discriminator_configs = configs["discriminator_configs"]
    # training_configs = configs["training_configs"]

    # initialize the checkpoint saver for the global model
    global_saver = Saver(
        save_prefix="{0}.final".format(os.path.join(args.save_to, "ACmodel")),
        num_max_keeping=attack_configs["num_kept_checkpoints"])
    # the global variable USE_GPU is mainly used by the environments
    GlobalNames.SEED = attack_configs["seed"]
    GlobalNames.USE_GPU = args.use_gpu
    torch.manual_seed(GlobalNames.SEED)

    # build vocabulary and data iterator for env
    with open(attack_configs["victim_configs"], "r") as victim_f:
        victim_configs = yaml.load(victim_f)
    data_configs = victim_configs["data_configs"]
    src_vocab = Vocabulary(**data_configs["vocabularies"][0])
    trg_vocab = Vocabulary(**data_configs["vocabularies"][1])
    data_set = ZipDataset(
        TextLineDataset(
            data_path=data_configs["train_data"][0],
            vocabulary=src_vocab,
        ),
        TextLineDataset(
            data_path=data_configs["train_data"][1],
            vocabulary=trg_vocab,
        ),
        shuffle=attack_configs["shuffle"]
    )  # we build the parallel data sets and iterate inside a thread

    # global model variables (trg network to save the results)
    global_attacker = attacker.Attacker(src_vocab.max_n_words,
                                        **attacker_model_configs)
    global_attacker = global_attacker.cpu()
    global_attacker.share_memory()
    if args.share_optim:
        # initiate optimizer and set to share mode
        optimizer = Optimizer(
            name=attacker_optimizer_configs["optimizer"],
            model=global_attacker,
            lr=attacker_optimizer_configs["learning_rate"],
            grad_clip=attacker_optimizer_configs["grad_clip"],
            optim_args=attacker_optimizer_configs["optimizer_params"])
        optimizer.optim.share_memory()
        # Build scheduler for optimizer if needed
        if attacker_optimizer_configs['schedule_method'] is not None:
            if attacker_optimizer_configs['schedule_method'] == "loss":
                scheduler = ReduceOnPlateauScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs["scheduler_configs"])
            elif attacker_optimizer_configs['schedule_method'] == "noam":
                scheduler = NoamScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs['scheduler_configs'])
            elif attacker_optimizer_configs["schedule_method"] == "rsqrt":
                scheduler = RsqrtScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs["scheduler_configs"])
            else:
                WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".
                     format(attacker_optimizer_configs['schedule_method']))
                scheduler = None
        else:
            scheduler = None
    else:
        optimizer = None
        scheduler = None

    # load from checkpoint for global model
    global_saver.load_latest(model=global_attacker,
                             optim=optimizer,
                             lr_scheduler=scheduler)

    if args.use_gpu:
        # collect available devices and distribute env on the available gpu
        device = "cuda"
        devices = []
        for i in range(torch.cuda.device_count()):
            devices += ["cuda:%d" % i]
        print("available gpus:", devices)
    else:
        device = "cpu"
        devices = [device]

    process = []
    counter = mp.Value("i", 0)
    lock = mp.Lock()  # for multiple attackers update

    INFO("extract near candidates")
    _, _ = load_or_extract_near_vocab(
        config_path=attack_configs["victim_configs"],
        model_path=attack_configs["victim_model"],
        init_perturb_rate=attack_configs["init_perturb_rate"],
        save_to=os.path.join(args.save_to, "near_vocab"),
        save_to_full=os.path.join(args.save_to, "full_near_vocab"),
        top_reserve=12,
        emit_as_id=True)

    # train(0, device, args, counter, lock,
    #       attack_configs, discriminator_configs,
    #       src_vocab, trg_vocab, data_set,
    #       global_attacker, attacker_configs,
    #       optimizer, scheduler,
    #       global_saver)

    # valid(args.n, device, args,
    #      attack_configs, discriminator_configs,
    #      src_vocab, trg_vocab, data_set,
    #      global_attacker, attacker_configs, counter)
    # run multiple training processes of local attackers to update the global one

    for rank in range(args.n):
        print("initialize training thread on cuda:%d" % (rank + 1))
        p = mp.Process(target=train,
                       args=(rank, "cuda:%d" % (rank + 1), args, counter, lock,
                             attack_configs, discriminator_configs, src_vocab,
                             trg_vocab, data_set, global_attacker,
                             attacker_configs, optimizer, scheduler,
                             global_saver))
        p.start()
        process.append(p)
    # run the dev thread for initiation
    print("initialize dev thread on cuda:0")
    p = mp.Process(target=valid,
                   args=(0, "cuda:0", args, attack_configs,
                         discriminator_configs, src_vocab, trg_vocab, data_set,
                         global_attacker, attacker_configs, counter))
    p.start()
    process.append(p)

    for p in process:
        p.join()
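
Stripped of the NMT specifics, run() follows the usual torch.multiprocessing recipe: keep the global model on the CPU in shared memory, spawn one process per worker with a shared counter and lock, then join them. A self-contained sketch of just that skeleton (the worker function and the toy Linear model are placeholders, not the project's train/valid workers):

import torch
import torch.multiprocessing as _mp


def worker(rank, shared_model, counter, lock):
    # each worker updates the shared parameters in place
    with lock:
        counter.value += 1
    with torch.no_grad():
        shared_model.weight.add_(float(rank))  # toy update, for illustration only


if __name__ == "__main__":
    mp = _mp.get_context("spawn")
    shared_model = torch.nn.Linear(4, 4)
    shared_model.share_memory()  # parameters become visible to all workers

    counter = mp.Value("i", 0)
    lock = mp.Lock()

    processes = []
    for rank in range(2):
        p = mp.Process(target=worker, args=(rank, shared_model, counter, lock))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
    print("updates:", counter.value)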
Example #4
def train(rank,
          device,
          args,
          counter,
          lock,
          attack_configs,
          discriminator_configs,
          src_vocab,
          trg_vocab,
          data_set,
          global_attacker,
          attacker_configs,
          optimizer=None,
          scheduler=None,
          saver=None):
    """
    running train process
    #1# train the env_discriminator
    #2# run attacker AC based on rewards from trained env_discriminator
    #3# run training updates attacker AC
    #4#
    :param rank: (int) the rank of the process (from multiprocess)
    :param device: the device of the process
    :param counter: python multiprocess variable
    :param lock: python multiprocess variable
    :param args: global args
    :param attack_configs: attack settings
    :param discriminator_configs: discriminator settings
    :param src_vocab:
    :param trg_vocab:
    :param data_set: (data_iterator object) provide batched data labels
    :param global_attacker: the model to sync from
    :param attacker_configs: local attacker settings
    :param optimizer: shared optimizer for the attacker;
            a local one is built if None
    :param scheduler: shared scheduler for the attacker;
            a local one is built if None
    :param saver: model saver
    :return:
    """
    trust_acc = acc_bound = discriminator_configs["acc_bound"]
    converged_bound = discriminator_configs["converged_bound"]
    patience = discriminator_configs["patience"]
    attacker_model_configs = attacker_configs["attacker_model_configs"]
    attacker_optimizer_configs = attacker_configs["attacker_optimizer_configs"]

    # this is for multi-processing; GlobalNames cannot be directly inherited
    GlobalNames.USE_GPU = args.use_gpu
    GlobalNames.SEED = attack_configs["seed"]
    torch.manual_seed(GlobalNames.SEED + rank)

    # initiate local saver and load checkpoint if possible
    local_saver = Saver(save_prefix="{0}.local".format(
        os.path.join(args.save_to, "train_env%d" % rank, "ACmodel")),
                        num_max_keeping=attack_configs["num_kept_checkpoints"])

    attack_iterator = DataIterator(dataset=data_set,
                                   batch_size=attack_configs["batch_size"],
                                   use_bucket=True,
                                   buffer_size=attack_configs["buffer_size"],
                                   numbering=True)

    summary_writer = SummaryWriter(
        log_dir=os.path.join(args.save_to, "train_env%d" % rank))
    local_attacker = attacker.Attacker(src_vocab.max_n_words,
                                       **attacker_model_configs)
    # build optimizer for attacker
    if optimizer is None:
        optimizer = Optimizer(
            name=attacker_optimizer_configs["optimizer"],
            model=global_attacker,
            lr=attacker_optimizer_configs["learning_rate"],
            grad_clip=attacker_optimizer_configs["grad_clip"],
            optim_args=attacker_optimizer_configs["optimizer_params"])
        # Build scheduler for optimizer if needed
        if attacker_optimizer_configs['schedule_method'] is not None:
            if attacker_optimizer_configs['schedule_method'] == "loss":
                scheduler = ReduceOnPlateauScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs["scheduler_configs"])
            elif attacker_optimizer_configs['schedule_method'] == "noam":
                scheduler = NoamScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs['scheduler_configs'])
            elif attacker_optimizer_configs["schedule_method"] == "rsqrt":
                scheduler = RsqrtScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs["scheduler_configs"])
            else:
                WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".
                     format(attacker_optimizer_configs['schedule_method']))
                scheduler = None
        else:
            scheduler = None

    local_saver.load_latest(model=local_attacker,
                            optim=optimizer,
                            lr_scheduler=scheduler)

    attacker_iterator = attack_iterator.build_generator()
    env = Translate_Env(attack_configs=attack_configs,
                        discriminator_configs=discriminator_configs,
                        src_vocab=src_vocab,
                        trg_vocab=trg_vocab,
                        data_iterator=attacker_iterator,
                        save_to=args.save_to,
                        device=device)
    episode_count = 0
    episode_length = 0
    local_steps = 0  # optimization steps: for learning rate schedules
    patience_t = patience
    while True:  # infinite loop of data set
        # we will continue with a new iterator with refreshed environments
        # whenever the last iterator breaks with "StopIteration"
        attacker_iterator = attack_iterator.build_generator()
        env.reset_data_iter(attacker_iterator)
        padded_src = env.reset()
        padded_src = torch.from_numpy(padded_src)
        if device != "cpu":
            padded_src = padded_src.to(device)
        done = True
        discriminator_base_steps = local_steps

        while True:
            # check for update of discriminator
            # if env.acc_validation(local_attacker, use_gpu=True if env.device != "cpu" else False) < 0.55:
            if episode_count % attacker_configs["attacker_update_steps"] == 0:
                """ stop criterion:
                when updates a discriminator, we check for acc. If acc fails acc_bound,
                we reset the discriminator and try, until acc reaches the bound with patience.
                otherwise the training thread stops
                """
                try:
                    discriminator_base_steps, trust_acc = env.update_discriminator(
                        local_attacker,
                        discriminator_base_steps,
                        min_update_steps=discriminator_configs[
                            "acc_valid_freq"],
                        max_update_steps=discriminator_configs[
                            "discriminator_update_steps"],
                        accuracy_bound=acc_bound,
                        summary_writer=summary_writer)
                except StopIteration:
                    INFO("finish one training epoch, reset data_iterator")
                    break

                discriminator_base_steps += 1  # a flag to label the discriminator updates
                if trust_acc < converged_bound:  # GAN target reached
                    patience_t -= 1
                    INFO(
                        "discriminator reached GAN convergence bound: %d times"
                        % patience_t)
                else:  # reset patience if discriminator is refreshed
                    patience_t = patience

            if saver and local_steps % attack_configs["save_freq"] == 0:
                local_saver.save(global_step=local_steps,
                                 model=local_attacker,
                                 optim=optimizer,
                                 lr_scheduler=scheduler)

                if trust_acc < converged_bound:  # and patience_t == patience-1:
                    # we only save the global params reaching acc_bound
                    torch.save(global_attacker.state_dict(),
                               os.path.join(args.save_to, "ACmodel.final"))
                    # saver.raw_save(model=global_attacker)

            if patience_t == 0:
                WARN("maximum patience reached. Training Thread should stop")
                break

            local_attacker.train()  # switch back to training mode

            # for an initial (reset) attacker from global parameters
            if done:
                INFO("sync from global model")
                local_attacker.load_state_dict(global_attacker.state_dict())
            # move the local attacker params back to device after updates
            local_attacker = local_attacker.to(device)
            values = []  # training critic: network outputs
            log_probs = []
            rewards = []  # actual rewards
            entropies = []

            local_steps += 1
            # run a sequence of attack steps
            try:
                for i in range(args.action_roll_steps):
                    episode_length += 1
                    attack_out, critic_out = local_attacker(
                        padded_src, padded_src[:, env.index - 1:env.index + 2])
                    logit_attack_out = torch.log(attack_out)
                    entropy = -(attack_out *
                                logit_attack_out).sum(dim=-1).mean()

                    summary_writer.add_scalar("action_entropy",
                                              scalar_value=entropy,
                                              global_step=local_steps)
                    entropies.append(entropy)  # for entropy loss
                    actions = attack_out.multinomial(num_samples=1).detach()
                    # only extract the log prob for chosen action (avg over batch)
                    log_attack_out = logit_attack_out.gather(-1,
                                                             actions).mean()
                    padded_src, reward, terminal_signal = env.step(
                        actions.squeeze())
                    done = terminal_signal or episode_length > args.max_episode_lengths

                    with lock:
                        counter.value += 1

                    if done:
                        episode_length = 0
                        padded_src = env.reset()

                    padded_src = torch.from_numpy(padded_src)
                    if device != "cpu":
                        padded_src = padded_src.to(device)

                    values.append(
                        critic_out.mean())  # list of torch variables (scalar)
                    log_probs.append(
                        log_attack_out)  # list of torch variables (scalar)
                    rewards.append(reward)  # list of reward variables

                    if done:
                        episode_count += 1
                        break
            except StopIteration:
                INFO("finish one training epoch, reset data_iterator")
                break

            R = torch.zeros(1, 1)
            gae = torch.zeros(1, 1)
            if device != "cpu":
                R = R.to(device)
                gae = gae.to(device)

            if not done:  # calculate value loss
                value = local_attacker.get_critic(
                    padded_src, padded_src[:, env.index - 1:env.index + 2])
                R = value.mean().detach()

            values.append(R)
            policy_loss = 0
            value_loss = 0

            # collect values for training
            for i in reversed(range(len(rewards))):
                # value loss and policy loss must be clipped to stabilize training
                R = attack_configs["gamma"] * R + rewards[i]
                advantage = R - values[i]
                value_loss = value_loss + 0.5 * advantage.pow(2)

                delta_t = rewards[i] + attack_configs["gamma"] * \
                          values[i + 1] - values[i]
                gae = gae * attack_configs["gamma"] * attack_configs["tau"] + \
                      delta_t
                policy_loss = policy_loss - log_probs[i] * gae.detach() - \
                              attack_configs["entropy_coef"] * entropies[i]
                print("policy_loss", policy_loss)
                print("gae", gae)

            # update with optimizer
            optimizer.zero_grad()
            # we decay the loss according to discriminator's accuracy as a trust region constraint
            summary_writer.add_scalar("policy_loss",
                                      scalar_value=policy_loss * trust_acc,
                                      global_step=local_steps)
            summary_writer.add_scalar("value_loss",
                                      scalar_value=value_loss * trust_acc,
                                      global_step=local_steps)

            total_loss = trust_acc * policy_loss + \
                         trust_acc * attack_configs["value_coef"] * value_loss
            total_loss.backward()

            if attacker_optimizer_configs[
                    "schedule_method"] is not None and attacker_optimizer_configs[
                        "schedule_method"] != "loss":
                scheduler.step(global_step=local_steps)

            # move the model params to CPU and
            # assign local gradients to the global model to update
            local_attacker.to("cpu").ensure_shared_grads(global_attacker)
            optimizer.step()
            print("bingo!")

        if patience_t == 0:
            INFO("Reach maximum Discriminator patience, Finish")
            break
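
The reversed loop near the end of the inner training step is a Generalized Advantage Estimation (GAE) accumulation in the A3C style: the discounted return R drives the value loss, while gae drives the policy gradient together with an entropy bonus. A self-contained sketch of just that computation, with scalar tensors standing in for the attacker's outputs (gamma, tau and entropy_coef mirror the attack_configs keys used above; the concrete numbers are illustrative):

import torch


def a3c_losses(rewards, values, log_probs, entropies, R,
               gamma=0.99, tau=1.0, entropy_coef=0.01):
    """Accumulate policy and value losses over one rollout.

    rewards, log_probs, entropies: lists of length T (floats / scalar tensors)
    values: list of length T + 1, the last entry being the bootstrap value
    R: bootstrap return for the final state (0 if the episode terminated)
    """
    policy_loss = 0.0
    value_loss = 0.0
    gae = torch.zeros(1)
    for i in reversed(range(len(rewards))):
        R = gamma * R + rewards[i]
        advantage = R - values[i]
        value_loss = value_loss + 0.5 * advantage.pow(2)

        # generalized advantage estimate for the policy gradient
        delta_t = rewards[i] + gamma * values[i + 1] - values[i]
        gae = gae * gamma * tau + delta_t
        policy_loss = (policy_loss - log_probs[i] * gae.detach()
                       - entropy_coef * entropies[i])
    return policy_loss, value_loss


# toy rollout of length 3
rewards = [0.1, -0.2, 0.3]
values = [torch.tensor(v) for v in (0.5, 0.4, 0.6, 0.0)]  # T + 1 entries
log_probs = [torch.tensor(v) for v in (-1.0, -0.7, -0.9)]
entropies = [torch.tensor(v) for v in (0.8, 0.6, 0.7)]
policy_loss, value_loss = a3c_losses(rewards, values, log_probs, entropies, R=values[-1])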
Example #5
File: nmt.py  Project: whr94621/ODC-NMT
def train(FLAGS):
    """
    FLAGS:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
    """

    # ================================================================================== #
    # Initialization for training on different devices
    # - CPU/GPU
    # - Single/Distributed
    GlobalNames.USE_GPU = FLAGS.use_gpu

    if FLAGS.multi_gpu:

        if hvd is None or distributed is None:
            ERROR("Distributed training is disable. Please check the installation of Horovod.")

        hvd.init()
        world_size = hvd.size()
        rank = hvd.rank()
        local_rank = hvd.local_rank()
    else:
        world_size = 1
        rank = 0
        local_rank = 0

    if GlobalNames.USE_GPU:
        torch.cuda.set_device(local_rank)
        CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        CURRENT_DEVICE = "cpu"

    # If not root_rank, close logging
    if rank != 0:
        close_logging()

    # write log of training to file.
    if rank == 0:
        write_log_to_file(os.path.join(FLAGS.log_path, "%s.log" % time.strftime("%Y%m%d-%H%M%S")))

    # ================================================================================== #
    # Parsing configuration files

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        configs = yaml.load(f)

    INFO(pretty_configs(configs))

    # Add default configs
    configs = default_baseline_configs(configs)
    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    GlobalNames.SEED = training_configs['seed']

    set_seed(GlobalNames.SEED)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Build source and target vocabularies
    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    actual_buffer_size = training_configs["buffer_size"] * max(1, training_configs["update_cycle"])

    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        ),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        )
    )

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['valid_data'][0],
                        vocabulary=vocab_src,
                        ),
        TextLineDataset(data_path=data_configs['valid_data'][1],
                        vocabulary=vocab_tgt,
                        )
    )

    training_iterator = DataIterator(dataset=train_bitext_dataset,
                                     batch_size=training_configs["batch_size"],
                                     use_bucket=training_configs['use_bucket'],
                                     buffer_size=actual_buffer_size,
                                     batching_func=training_configs['batching_key'],
                                     world_size=world_size,
                                     rank=rank)

    valid_iterator = DataIterator(dataset=valid_bitext_dataset,
                                  batch_size=training_configs['valid_batch_size'],
                                  use_bucket=True, buffer_size=100000, numbering=True,
                                  world_size=world_size, rank=rank)

    bleu_scorer = SacreBLEUScorer(reference_path=data_configs["bleu_valid_reference"],
                                  num_refs=data_configs["num_refs"],
                                  lang_pair=data_configs["lang_pair"],
                                  sacrebleu_args=training_configs["bleu_valid_configs"]['sacrebleu_args'],
                                  postprocess=training_configs["bleu_valid_configs"]['postprocess']
                                  )

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We will do the steps below one after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    model_collections = Collections()
    best_model_prefix = os.path.join(FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    checkpoint_saver = Saver(save_prefix="{0}.ckpt".format(os.path.join(FLAGS.saveto, FLAGS.model_name)),
                             num_max_keeping=training_configs['num_kept_checkpoints']
                             )
    best_model_saver = Saver(save_prefix=best_model_prefix, num_max_keeping=training_configs['num_kept_best_model'])

    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words, **model_configs)
    INFO(nmt_model)

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])

    INFO(critic)
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model, FLAGS.pretrain_path, exclude_prefix=None, device=CURRENT_DEVICE)

    # 4. Build optimizer
    INFO('Building Optimizer...')
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=nmt_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params'],
                      distributed=True if world_size > 1 else False,
                      update_cycle=training_configs['update_cycle']
                      )
    # 5. Build scheduler for optimizer if needed
    if optimizer_configs['schedule_method'] is not None:

        if optimizer_configs['schedule_method'] == "loss":

            scheduler = ReduceOnPlateauScheduler(optimizer=optim,
                                                 **optimizer_configs["scheduler_configs"]
                                                 )

        elif optimizer_configs['schedule_method'] == "noam":
            scheduler = NoamScheduler(optimizer=optim, **optimizer_configs['scheduler_configs'])
        else:
            WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(optimizer_configs['schedule_method']))
            scheduler = None
    else:
        scheduler = None

    # 6. build moving average

    if training_configs['moving_average_method'] is not None:
        ma = MovingAverage(moving_average_method=training_configs['moving_average_method'],
                           named_params=nmt_model.named_parameters(),
                           alpha=training_configs['moving_average_alpha'])
    else:
        ma = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if FLAGS.reload:
        checkpoint_saver.load_latest(model=nmt_model, optim=optim, lr_scheduler=scheduler,
                                     collections=model_collections, ma=ma)

    # broadcast parameters and optimizer states
    if world_size > 1:
        hvd.broadcast_parameters(params=nmt_model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer=optim.optim, root_rank=0)

    # ================================================================================== #
    # Prepare training

    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [1])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]
    cum_n_samples = 0
    cum_n_words = 0
    best_valid_loss = 1.0 * 1e10  # Max Float
    update_cycle = training_configs['update_cycle']
    grad_denom = 0

    if rank == 0:
        summary_writer = SummaryWriter(log_dir=FLAGS.log_path)
    else:
        summary_writer = None

    # Timer for computing speed
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    while True:

        if summary_writer is not None:
            summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()

        if rank == 0:
            training_progress_bar = tqdm(desc='  - (Epoch %d)   ' % eidx,
                                         total=len(training_iterator),
                                         unit="sents"
                                         )
        else:
            training_progress_bar = None

        for batch in training_iter:

            seqs_x, seqs_y = batch

            batch_size = len(seqs_x)

            cum_n_samples += batch_size
            cum_n_words += sum(len(s) for s in seqs_y)

            try:
                # Prepare data
                x, y = prepare_data(seqs_x, seqs_y, cuda=GlobalNames.USE_GPU)

                loss = compute_forward(model=nmt_model,
                                       critic=critic,
                                       seqs_x=x,
                                       seqs_y=y,
                                       eval=False,
                                       normalization=1.0,
                                       norm_by_words=training_configs["norm_by_words"])

                update_cycle -= 1
                grad_denom += batch_size

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                else:
                    raise e

            # When update_cycle reaches 0, one full (accumulated) update is complete. Several things will be done:
            # - update parameters
            # - reset update_cycle and grad_denom
            # - update uidx
            # - update moving average

            if update_cycle == 0:
                if world_size > 1:
                    grad_denom = distributed.all_reduce(grad_denom)

                optim.step(denom=grad_denom)
                optim.zero_grad()

                if training_progress_bar is not None:
                    training_progress_bar.update(grad_denom)

                update_cycle = training_configs['update_cycle']
                grad_denom = 0

                uidx += 1

                if scheduler is None:
                    pass
                elif optimizer_configs["schedule_method"] == "loss":
                    scheduler.step(metric=best_valid_loss)
                else:
                    scheduler.step(global_step=uidx)

                if ma is not None and eidx >= training_configs['moving_average_start_epoch']:
                    ma.step()
            else:
                continue

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['disp_freq']):

                if world_size > 1:
                    cum_n_words = sum(distributed.all_gather(cum_n_words))
                    cum_n_samples = sum(distributed.all_gather(cum_n_samples))

                # words per second and sents per second
                words_per_sec = cum_n_words / (timer.toc(return_seconds=True))
                sents_per_sec = cum_n_samples / (timer.toc(return_seconds=True))
                lrate = list(optim.get_lrate())[0]

                if summary_writer is not None:
                    summary_writer.add_scalar("Speed(words/sec)", scalar_value=words_per_sec, global_step=uidx)
                    summary_writer.add_scalar("Speed(sents/sen)", scalar_value=sents_per_sec, global_step=uidx)
                    summary_writer.add_scalar("lrate", scalar_value=lrate, global_step=uidx)
                    summary_writer.add_scalar("oom_count", scalar_value=oom_count, global_step=uidx)

                # Reset timer
                timer.tic()
                cum_n_words = 0
                cum_n_samples = 0

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(global_step=uidx, n_epoch=eidx, every_n_step=training_configs['loss_valid_freq'],
                                       debug=FLAGS.debug):

                valid_loss = loss_validation(model=nmt_model,
                                             critic=critic,
                                             valid_iterator=valid_iterator,
                                             rank=rank,
                                             world_size=world_size
                                             )

                model_collections.add_to_collection("history_losses", valid_loss)

                min_history_loss = np.array(model_collections.get_collection("history_losses")).min()

                best_valid_loss = min_history_loss

                if summary_writer is not None:
                    summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                    summary_writer.add_scalar("best_loss", min_history_loss, global_step=uidx)

            # ================================================================================== #
            # BLEU Validation & Early Stop

            if should_trigger_by_steps(global_step=uidx, n_epoch=eidx,
                                       every_n_step=training_configs['bleu_valid_freq'],
                                       min_step=training_configs['bleu_valid_warmup'],
                                       debug=FLAGS.debug):

                valid_bleu = bleu_validation(uidx=uidx,
                                             valid_iterator=valid_iterator,
                                             batch_size=training_configs["bleu_valid_batch_size"],
                                             model=nmt_model,
                                             bleu_scorer=bleu_scorer,
                                             vocab_tgt=vocab_tgt,
                                             valid_dir=FLAGS.valid_path,
                                             max_steps=training_configs["bleu_valid_configs"]["max_steps"],
                                             beam_size=training_configs["bleu_valid_configs"]["beam_size"],
                                             alpha=training_configs["bleu_valid_configs"]["alpha"],
                                             world_size=world_size,
                                             rank=rank,
                                             )

                model_collections.add_to_collection(key="history_bleus", value=valid_bleu)

                best_valid_bleu = float(np.array(model_collections.get_collection("history_bleus")).max())

                if summary_writer is not None:
                    summary_writer.add_scalar("bleu", valid_bleu, uidx)
                    summary_writer.add_scalar("best_bleu", best_valid_bleu, uidx)

                # If model get new best valid bleu score
                if valid_bleu >= best_valid_bleu:
                    bad_count = 0

                    if is_early_stop is False:
                        if rank == 0:
                            # 1. save the best model
                            torch.save(nmt_model.state_dict(), best_model_prefix + ".final")

                            # 2. keep track of the several best models
                            best_model_saver.save(global_step=uidx, model=nmt_model, ma=ma)
                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs['early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")

                if summary_writer is not None:
                    summary_writer.add_scalar("bad_count", bad_count, uidx)

                INFO("{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4}".format(
                    uidx, valid_loss, valid_bleu, lrate, bad_count
                ))

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['save_freq'], debug=FLAGS.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)

                if not is_early_stop:
                    if rank == 0:
                        checkpoint_saver.save(global_step=uidx,
                                              model=nmt_model,
                                              optim=optim,
                                              lr_scheduler=scheduler,
                                              collections=model_collections,
                                              ma=ma)

        if training_progress_bar is not None:
            training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
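
The update_cycle / grad_denom bookkeeping above is plain gradient accumulation: losses from update_cycle consecutive batches are backpropagated before a single parameter update, and the step is normalized by the number of accumulated samples (which is what optim.step(denom=grad_denom) does here). A minimal sketch of the same pattern with a stock torch.optim optimizer and hypothetical model / loss_fn names:

import torch


def accumulated_update(model, optimizer, micro_batches, loss_fn):
    """Run one parameter update over a list of (x, y) micro-batches."""
    optimizer.zero_grad()
    grad_denom = 0
    for x, y in micro_batches:
        loss = loss_fn(model(x), y)
        loss.backward()            # gradients sum across micro-batches
        grad_denom += x.size(0)
    # normalize the summed gradients by the number of accumulated samples
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(grad_denom)
    optimizer.step()
    return grad_denom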
Example #6
def train(FLAGS):
    """
    FLAGS:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
    """

    # write log of training to file.
    write_log_to_file(os.path.join(FLAGS.log_path, "%s.log" % time.strftime("%Y%m%d-%H%M%S")))

    GlobalNames.USE_GPU = FLAGS.use_gpu

    if GlobalNames.USE_GPU:
        CURRENT_DEVICE = "cuda:0"
    else:
        CURRENT_DEVICE = "cpu"

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        configs = yaml.load(f)

    INFO(pretty_configs(configs))

    # Add default configs
    configs = default_configs(configs)
    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    GlobalNames.SEED = training_configs['seed']

    set_seed(GlobalNames.SEED)

    best_model_prefix = os.path.join(FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Build source and target vocabularies
    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    train_batch_size = training_configs["batch_size"] * max(1, training_configs["update_cycle"])
    train_buffer_size = training_configs["buffer_size"] * max(1, training_configs["update_cycle"])
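    # With gradient accumulation, each fetched batch holds batch_size * update_cycle
    # sentences; split_shard() in the training loop later splits it back into
    # update_cycle shards so gradients accumulate before a single optimizer step.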

    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        ),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        ),
        shuffle=training_configs['shuffle']
    )

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['valid_data'][0],
                        vocabulary=vocab_src,
                        ),
        TextLineDataset(data_path=data_configs['valid_data'][1],
                        vocabulary=vocab_tgt,
                        )
    )

    training_iterator = DataIterator(dataset=train_bitext_dataset,
                                     batch_size=train_batch_size,
                                     use_bucket=training_configs['use_bucket'],
                                     buffer_size=train_buffer_size,
                                     batching_func=training_configs['batching_key'])

    valid_iterator = DataIterator(dataset=valid_bitext_dataset,
                                  batch_size=training_configs['valid_batch_size'],
                                  use_bucket=True, buffer_size=100000, numbering=True)

    bleu_scorer = SacreBLEUScorer(reference_path=data_configs["bleu_valid_reference"],
                                  num_refs=data_configs["num_refs"],
                                  lang_pair=data_configs["lang_pair"],
                                  sacrebleu_args=training_configs["bleu_valid_configs"]['sacrebleu_args'],
                                  postprocess=training_configs["bleu_valid_configs"]['postprocess']
                                  )

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We will do the steps below one after another:
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    model_collections = Collections()
    checkpoint_saver = Saver(save_prefix="{0}.ckpt".format(os.path.join(FLAGS.saveto, FLAGS.model_name)),
                             num_max_keeping=training_configs['num_kept_checkpoints']
                             )
    best_model_saver = Saver(save_prefix=best_model_prefix, num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words, **model_configs)
    INFO(nmt_model)

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])

    INFO(critic)
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model, FLAGS.pretrain_path, exclude_prefix=None, device=CURRENT_DEVICE)

    # 4. Build optimizer
    INFO('Building Optimizer...')
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=nmt_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params']
                      )
    # 5. Build scheduler for optimizer if needed
    if optimizer_configs['schedule_method'] is not None:

        if optimizer_configs['schedule_method'] == "loss":

            scheduler = ReduceOnPlateauScheduler(optimizer=optim,
                                                 **optimizer_configs["scheduler_configs"]
                                                 )

        elif optimizer_configs['schedule_method'] == "noam":
            scheduler = NoamScheduler(optimizer=optim, **optimizer_configs['scheduler_configs'])
        else:
            WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(optimizer_configs['schedule_method']))
            scheduler = None
    else:
        scheduler = None

    # 6. build EMA
    if training_configs['ema_decay'] > 0.0:
        ema = ExponentialMovingAverage(named_params=nmt_model.named_parameters(), decay=training_configs['ema_decay'])
    else:
        ema = None
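    # When enabled, the EMA keeps a shadow copy of every parameter, typically
    # updated as shadow = decay * shadow + (1 - decay) * param (the exact rule
    # lives in ExponentialMovingAverage); validation below temporarily swaps
    # these shadow weights in.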

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if FLAGS.reload:
        checkpoint_saver.load_latest(model=nmt_model, optim=optim, lr_scheduler=scheduler,
                                     collections=model_collections)

    # ================================================================================== #
    # Prepare training

    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [0])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]

    summary_writer = SummaryWriter(log_dir=FLAGS.log_path)

    cum_samples = 0
    cum_words = 0
    best_valid_loss = 1.0 * 1e10  # a sufficiently large initial value
    saving_files = []

    # Timer for computing speed
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    while True:

        summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()
        training_progress_bar = tqdm(desc='  - (Epoch %d)   ' % eidx,
                                     total=len(training_iterator),
                                     unit="sents"
                                     )
        for batch in training_iter:

            uidx += 1

            if scheduler is None:
                pass
            elif optimizer_configs["schedule_method"] == "loss":
                scheduler.step(metric=best_valid_loss)
            else:
                scheduler.step(global_step=uidx)
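            # Note: the loss-based scheduler (ReduceOnPlateau) is stepped with the
            # best validation loss seen so far, whereas step-based schedulers such
            # as Noam are stepped with the current update count.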

            seqs_x, seqs_y = batch

            n_samples_t = len(seqs_x)
            n_words_t = sum(len(s) for s in seqs_y)

            cum_samples += n_samples_t
            cum_words += n_words_t

            training_progress_bar.update(n_samples_t)

            optim.zero_grad()

            # Prepare data
            for seqs_x_t, seqs_y_t in split_shard(seqs_x, seqs_y, split_size=training_configs['update_cycle']):
                x, y = prepare_data(seqs_x_t, seqs_y_t, cuda=GlobalNames.USE_GPU)

                loss = compute_forward(model=nmt_model,
                                       critic=critic,
                                       seqs_x=x,
                                       seqs_y=y,
                                       eval=False,
                                       normalization=n_samples_t,
                                       norm_by_words=training_configs["norm_by_words"])
            optim.step()
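            # Gradients from all update_cycle shards have accumulated on the model,
            # so a single optimizer step applies the combined update.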

            if ema is not None:
                ema.step()

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['disp_freq']):
                # words per second and sents per second
                words_per_sec = cum_words / (timer.toc(return_seconds=True))
                sents_per_sec = cum_samples / (timer.toc(return_seconds=True))
                lrate = list(optim.get_lrate())[0]

                summary_writer.add_scalar("Speed(words/sec)", scalar_value=words_per_sec, global_step=uidx)
                summary_writer.add_scalar("Speed(sents/sen)", scalar_value=sents_per_sec, global_step=uidx)
                summary_writer.add_scalar("lrate", scalar_value=lrate, global_step=uidx)

                # Reset timer
                timer.tic()
                cum_words = 0
                cum_samples = 0

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['save_freq'], debug=FLAGS.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)

                if not is_early_stop:

                    checkpoint_saver.save(global_step=uidx, model=nmt_model, optim=optim, lr_scheduler=scheduler,
                                          collections=model_collections, ema=ema)

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(global_step=uidx, n_epoch=eidx, every_n_step=training_configs['loss_valid_freq'],
                                       debug=FLAGS.debug):

                if ema is not None:
                    origin_state_dict = deepcopy(nmt_model.state_dict())
                    nmt_model.load_state_dict(ema.state_dict(), strict=False)
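                # Loss validation runs on the EMA (shadow) weights; the original
                # parameters are restored right after it finishes.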

                valid_loss = loss_validation(model=nmt_model,
                                             critic=critic,
                                             valid_iterator=valid_iterator,
                                             )

                model_collections.add_to_collection("history_losses", valid_loss)

                min_history_loss = np.array(model_collections.get_collection("history_losses")).min()

                summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                summary_writer.add_scalar("best_loss", min_history_loss, global_step=uidx)

                best_valid_loss = min_history_loss

                if ema is not None:
                    nmt_model.load_state_dict(origin_state_dict)
                    del origin_state_dict

            # ================================================================================== #
            # BLEU Validation & Early Stop

            if should_trigger_by_steps(global_step=uidx, n_epoch=eidx,
                                       every_n_step=training_configs['bleu_valid_freq'],
                                       min_step=training_configs['bleu_valid_warmup'],
                                       debug=FLAGS.debug):

                if ema is not None:
                    origin_state_dict = deepcopy(nmt_model.state_dict())
                    nmt_model.load_state_dict(ema.state_dict(), strict=False)

                valid_bleu = bleu_validation(uidx=uidx,
                                             valid_iterator=valid_iterator,
                                             batch_size=training_configs["bleu_valid_batch_size"],
                                             model=nmt_model,
                                             bleu_scorer=bleu_scorer,
                                             vocab_tgt=vocab_tgt,
                                             valid_dir=FLAGS.valid_path,
                                             max_steps=training_configs["bleu_valid_configs"]["max_steps"],
                                             beam_size=training_configs["bleu_valid_configs"]["beam_size"],
                                             alpha=training_configs["bleu_valid_configs"]["alpha"]
                                             )

                model_collections.add_to_collection(key="history_bleus", value=valid_bleu)

                best_valid_bleu = float(np.array(model_collections.get_collection("history_bleus")).max())

                summary_writer.add_scalar("bleu", valid_bleu, uidx)
                summary_writer.add_scalar("best_bleu", best_valid_bleu, uidx)

                # If the model gets a new best validation BLEU score
                if valid_bleu >= best_valid_bleu:
                    bad_count = 0

                    if is_early_stop is False:
                        # 1. save the best model
                        torch.save(nmt_model.state_dict(), best_model_prefix + ".final")

                        # 2. keep the several most recent best models
                        best_model_saver.save(global_step=uidx, model=nmt_model)
                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs['early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")

                summary_writer.add_scalar("bad_count", bad_count, uidx)

                if ema is not None:
                    nmt_model.load_state_dict(origin_state_dict)
                    del origin_state_dict

                INFO("{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4}".format(
                    uidx, valid_loss, valid_bleu, lrate, bad_count
                ))

        training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
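
A minimal sketch of how a train() entry point like the one in Example #6 might be launched from the command line. This is illustrative only, not the project's actual CLI: the parser, the flags beyond those listed in the docstring (use_gpu, valid_path, and debug are inferred from the function body), and all defaults are assumptions.

# Hypothetical launcher for the train() example above; flag names and defaults
# are assumptions for illustration only.
import argparse


def parse_flags():
    parser = argparse.ArgumentParser(description="Train an NMT model.")
    parser.add_argument("--config_path", type=str, required=True)
    parser.add_argument("--model_name", type=str, default="nmt")
    parser.add_argument("--saveto", type=str, default="./save")
    parser.add_argument("--log_path", type=str, default="./log")
    parser.add_argument("--valid_path", type=str, default="./valid")
    parser.add_argument("--pretrain_path", type=str, default="")
    parser.add_argument("--reload", action="store_true")
    parser.add_argument("--use_gpu", action="store_true")
    parser.add_argument("--debug", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    train(parse_flags())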