Example #1
class SACAgent(object):
    def __init__(self,
                 device="cpu",
                 d_word_vec=512,
                 d_model=256,
                 limit_dist=0.1,
                 dropout=0.0,
                 reparam_noise=1e-6,
                 **kwargs):
        self.device = device
        self.actor = Rephraser(d_word_vec=d_word_vec,
                               d_model=d_model,
                               limit_dist=limit_dist,
                               dropout=dropout,
                               reparam_noise=reparam_noise).to(device)
        self.critic = CriticNet(d_word_vec=d_word_vec,
                                d_model=d_model,
                                limit_dist=limit_dist,
                                dropout=dropout,
                                reparam_noise=reparam_noise).to(device)
        self.saver = Saver(save_prefix="{0}.ckpt".format(
            os.path.join(kwargs["save_to"], "ACmodel")),
                           num_max_keeping=kwargs["num_kept_checkpoints"])
        self.soft_update_lock = mp.Lock()
        # the entropy regularization weight for SAC learning
        self.learnable_temperature = kwargs["learnable_temperature"]
        self.target_entropy = -d_word_vec  # SAC heuristic: target entropy = -act_dim, with act_dim = d_word_vec
        self.log_alpha = torch.tensor(np.log(kwargs["init_temperature"])).to(
            self.device)
        self.log_alpha.requires_grad = True
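        # alpha = exp(log_alpha) is the SAC temperature; optimizing log_alpha
        # keeps the temperature positive without an explicit constraint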
        # initialize the training mode for the Agent
        self.train()
        self._init_local_optims(kwargs["rephraser_optimizer_configs"])
        # self.load_model()  # always reload model if there is any in the path

    def to(self, device):
        self.actor.to(device)
        self.critic.to(device)
        # Tensor.to() is not in-place; rebind the data so the same leaf tensor
        # (still referenced by the alpha optimizer) moves with the agent
        self.log_alpha.data = self.log_alpha.data.to(device)
        return self

    def share_memory(self):
        # global model needs to share memory with other threads
        self.actor.share_memory()
        self.critic.share_memory()

    @property
    def alpha(self):
        return self.log_alpha.exp()

    def load_model(self, load_final_path: str = None):
        """
        load from path by self.saver
        :param load_final_path: path to the final model file; the final model does not contain optimizer states
        :return: training step count int
        """
        step = 0
        model_collections = Collections()
        if load_final_path:
            # self.saver.load_latest(
            #     actor_model=self.actor, critic_model=self.critic
            # )  # load from the latest ckpt model
            state_dict = torch.load(load_final_path)
            self.actor.load_state_dict(state_dict["actor_model"])
            self.critic.load_state_dict(state_dict["critic_model"])
        else:
            self.saver.load_latest(collections=model_collections,
                                   actor_model=self.actor,
                                   critic_model=self.critic,
                                   actor_optim=self.actor_optimizer,
                                   critic_optim=self.critic_optimizer,
                                   actor_scheduler=self.actor_scheduler,
                                   critic_scheduler=self.critic_scheduler)
            step = model_collections.get_collection("step", [0])[-1]
        return step

    def save_model(
            self,
            step=None,
            save_to_final=None):  # save model parameters, optims, lr_steps
        model_collections = Collections()
        if step is not None:
            model_collections.add_to_collection("step", step)
            self.saver.save(global_step=step,
                            collections=model_collections,
                            actor_model=self.actor,
                            critic_model=self.critic,
                            actor_optim=self.actor_optimizer,
                            critic_optim=self.critic_optimizer,
                            actor_scheduler=self.actor_scheduler,
                            critic_scheduler=self.critic_scheduler)
        else:  # only save the model parameters
            assert save_to_final is not None, "final model saving dir must be provided"
            collection = dict()
            collection["actor_model"] = self.actor.state_dict()
            collection["critic_model"] = self.critic.state_dict()
            torch.save(collection, os.path.join(save_to_final,
                                                "ACmodel.final"))
        return

    def _init_local_optims(self, rephraser_optimizer_configs):
        """ actor, critic, alpha optimizers and lr scheduler if necessary
        rephraser_optimizer_configs:
            optimizer: "adafactor"
            learning_rate: 0.01
            grad_clip: -1.0
            optimizer_params: ~
            schedule_method: rsqrt
            scheduler_configs:
                d_model: *dim
                warmup_steps: 100
        """
        # initiate local optimizer
        if rephraser_optimizer_configs is None:
            self.actor_optimizer = None
            self.critic_optimizer = None
            self.log_alpha_optimizer = None
            # self.actor_icm_optimizer = None
            self.actor_scheduler = None
            self.critic_scheduler = None
        else:
            self.actor_optimizer = Optimizer(
                name=rephraser_optimizer_configs["optimizer"],
                model=self.actor,
                lr=rephraser_optimizer_configs["learning_rate"],
                grad_clip=rephraser_optimizer_configs["grad_clip"],
                optim_args=rephraser_optimizer_configs["optimizer_params"])
            self.critic_optimizer = Optimizer(
                name=rephraser_optimizer_configs["optimizer"],
                model=self.critic,
                lr=rephraser_optimizer_configs["learning_rate"],
                grad_clip=rephraser_optimizer_configs["grad_clip"],
                optim_args=rephraser_optimizer_configs["optimizer_params"])
            # hardcoded entropy weight updates and icm updates
            self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                        lr=1e-4,
                                                        betas=(0.9, 0.999))
            # self.actor_icm_optimizer = torch.optim.Adam(self.actor.icm.parameters(), lr=1e-3, )

            # Build scheduler for optimizer if needed
            if rephraser_optimizer_configs['schedule_method'] is not None:
                if rephraser_optimizer_configs['schedule_method'] == "loss":
                    self.actor_scheduler = ReduceOnPlateauScheduler(
                        optimizer=self.actor_optimizer,
                        **rephraser_optimizer_configs["scheduler_configs"])
                    self.critic_scheduler = ReduceOnPlateauScheduler(
                        optimizer=self.critic_optimizer,
                        **rephraser_optimizer_configs["scheduler_configs"])
                elif rephraser_optimizer_configs['schedule_method'] == "noam":
                    self.actor_scheduler = NoamScheduler(
                        optimizer=self.actor_optimizer,
                        **rephraser_optimizer_configs["scheduler_configs"])
                    self.critic_scheduler = NoamScheduler(
                        optimizer=self.critic_optimizer,
                        **rephraser_optimizer_configs["scheduler_configs"])
                elif rephraser_optimizer_configs["schedule_method"] == "rsqrt":
                    self.actor_scheduler = RsqrtScheduler(
                        optimizer=self.actor_optimizer,
                        **rephraser_optimizer_configs["scheduler_configs"])
                    self.critic_scheduler = RsqrtScheduler(
                        optimizer=self.critic_optimizer,
                        **rephraser_optimizer_configs["scheduler_configs"])
                else:
                    WARN(
                        "Unknown scheduler name {0}. Do not use lr_scheduling."
                        .format(
                            rephraser_optimizer_configs['schedule_method']))
                    self.actor_scheduler = None
                    self.critic_scheduler = None
            else:
                self.actor_scheduler = None
                self.critic_scheduler = None

    def sync_from(self, sac_agent):
        with sac_agent.soft_update_lock, self.soft_update_lock:
            self.actor.sync_from(sac_agent.actor)
            self.critic.sync_from(sac_agent.critic)

    def train(self, training=True):
        # default training is true
        self.training = training
        self.actor.train(training)
        self.critic.train(training)
        return self

    def update_critic(self,
                      states,
                      masks,
                      rephrase_positions,
                      actions,
                      rewards,
                      survive_and_no_maxs,
                      target_critic,
                      update_step,
                      discount_factor,
                      summary_writer,
                      update_trust_region=0.8):
        """
        update critic by using a target_critic net (usually a global critic model) and a buffer
        SARSA for TD learning
        :param states:
        :param masks:
        :param rephrase_positions:
        :param actions: actions
        :param rewards: the rewards
        :param survive_and_no_maxs: able to rollout next step for TD learning

        :param target_critic: provides target value estimation(global model is usually on cpu)
        :param discount_factor: for discounted rewards update
        :param update_step: learning steps
        :param update_trust_region: discount for loss updates
        """

        label_emb = slice_by_indices(
            states, rephrase_positions,
            device=self.device)  # slice the embedding at rephrase_positions as the label embedding
        next_action, log_probs = self.actor.sample_normal(states,
                                                          1. - masks,
                                                          label_emb,
                                                          reparamization=True)
        log_probs = log_probs.sum(dim=-1, keepdim=True)
        next_states = transition(states, masks, actions, rephrase_positions)
        next_rephrase_positions = rephrase_positions + 1
        next_label_emb = slice_by_indices(next_states,
                                          next_rephrase_positions,
                                          device=self.device)

        # # note that with intrinsic curiosity module, the rewards will add curiosity bonus
        # self.actor.icm.eval()
        # rephrase_feature = self.actor.preprocess(states, 1.-masks, label_emb)
        # next_rephrase_feature = self.actor.preprocess(next_states, 1.-masks, next_label_emb)
        # bonus = self.actor.icm.get_surprise_bonus(rephrase_feature, next_rephrase_feature, actions).detach()
        # bonus = 0.01 * bonus * survive_and_no_maxs
        # print("bonus:", bonus.sum())
        # rewards += bonus
        # print("rewards:", rewards.squeeze())

        # note: log_probs has the same dimensionality as the action, so the log-prob of a whole action is the sum over dimensions.
        target_critic.eval()
        target_V = target_critic(
            next_states, 1. - masks, next_label_emb,
            next_action) - log_probs * self.alpha.detach()
        target_Q = rewards + (
            survive_and_no_maxs
        ) * discount_factor * target_V  # bootstrap only when a next state exists for the TD rollout
        target_Q = target_Q.detach()

        # get current Q estimates
        current_Q = self.critic(states, 1. - masks, label_emb, actions)
        critic_loss = F.mse_loss(current_Q, target_Q)
        critic_loss *= update_trust_region
        print("critic_loss", critic_loss.sum())

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        if self.critic_scheduler is not None:
            self.critic_scheduler.step(global_step=update_step)
        critic_loss.backward()
        self.critic_optimizer.step()

        # logging: entropy/target_entropy ratio, critic_loss,
        summary_writer.add_scalar("critic_loss",
                                  scalar_value=critic_loss,
                                  global_step=update_step)

    def update_actor_and_alpha(self,
                               states,
                               masks,
                               rephrase_positions,
                               target_critic,
                               update_step,
                               summary_writer,
                               update_trust_region=0.5):
        """
        :param states: tensor states from the buffer samples
        :param masks: indicats the valid token positions
        :param rephrase_positions: induce the next states by the given states
        :param update_trust_region: current annnunciator trust_acc (valid).
        served as a trust region for RL updates;
        also the weight of rewind or reinforce
        trust_acc * rewind_loss + (1-trust_acc) * policy_loss
        """
        self.actor.train()
        label_emb = slice_by_indices(states,
                                     rephrase_positions,
                                     device=self.device)
        actions, log_probs = self.actor.sample_normal(states,
                                                      1. - masks,
                                                      label_emb,
                                                      reparamization=True)
        log_probs = log_probs.sum(dim=-1, keepdim=True)
        actor_Q = self.critic(states, 1. - masks, label_emb, actions)
        policy_loss = (self.alpha.detach() * log_probs - actor_Q).mean()

        summary_writer.add_scalar('policy_loss', policy_loss, update_step)
        summary_writer.add_scalar('entropy_ratio',
                                  -log_probs.mean() / self.target_entropy,
                                  update_step)

        # policy rewind loss: whether to rewind is decided by the target value estimates
        # (estimated survival + improvement); a negative estimate means a rewind is needed.
        target_Q = target_critic(states, 1. - masks, label_emb,
                                 actions).detach()
        rewind_mask = target_Q.lt(0.).detach().float()  # [batch, 1]
        next_states = transition(states, masks, actions,
                                 rephrase_positions).detach()
        next_label_emb = slice_by_indices(next_states,
                                          rephrase_positions,
                                          device=self.device).detach()
        rewind_action, _ = self.actor.forward(next_states, 1. - masks,
                                              next_label_emb)
        target_actions = -actions.detach()
        rewind_loss = F.mse_loss(
            rewind_action * rewind_mask * self.actor.action_range,
            target_actions * rewind_mask)
        summary_writer.add_scalar('rewind_loss', rewind_loss, update_step)

        # a higher trust region means the perturbations are less indicative, so the policy should focus more on the rewind term.
        actor_loss = (update_trust_region) * rewind_loss + (
            1. - update_trust_region) * policy_loss
        ## update the intrinsic reward module: action reconstruction and feature prediction mse
        # self.actor.icm.train()
        # next_states = transition(states, masks, actions, rephrase_positions)
        # next_label_emb = slice_by_indices(next_states, rephrase_positions, device=self.device)
        # rephrase_feature = self.actor.preprocess(states, 1.0-masks, label_emb)
        # next_rephrase_feature = self.actor.preprocess(next_states, 1.0-masks, next_label_emb)
        # if update_step<3000:
        #     # icm updates does not propagate to the policy on the early stage
        #     icm_loss = self.actor.icm(rephrase_feature.detach(), next_rephrase_feature.detach(), actions)
        # else:
        #     icm_loss = self.actor.icm(rephrase_feature, next_rephrase_feature, actions)
        # summary_writer.add_scalar("intrinsic_curiosity_loss", icm_loss, update_step)
        # # the 0.1 is the setting by Intrinsic curiosity learning
        # actor_loss = actor_loss + icm_loss

        # optimize the actor
        self.actor_optimizer.zero_grad()
        # self.actor_icm_optimizer.zero_grad()
        if self.actor_scheduler is not None:
            self.actor_scheduler.step(global_step=update_step)
        actor_loss.backward()
        self.actor_optimizer.step()
        # self.actor_icm_optimizer.step()

        if self.learnable_temperature:
            self.log_alpha_optimizer.zero_grad()
            alpha_loss = (self.alpha *
                          (-log_probs - self.target_entropy).detach()).mean()
            summary_writer.add_scalar('alpha_loss', alpha_loss, update_step)
            summary_writer.add_scalar('alpha', self.alpha, update_step)
            alpha_loss.backward()
            self.log_alpha_optimizer.step()

    def update_local_net(self,
                         local_agent_configs,
                         replay_buffer,
                         target_critic,
                         update_step,
                         discount_factor,
                         summary_writer,
                         update_trust_region=1.0):
        """
        :param local_agent_configs: provides agent update freq
        :param replay_buffer: provides the SARSA listed below
            [states, masks, actions, rephrase_positions, rewards, terminal_signals]
            states: the embedding as states. [batch, len, emb_dim] float
            masks: the indicator of valid token for embedding. [batch, len] float
            actions: the action embedding at the position. [batch, emb_dim] float
            rephrase_positions: the position to rephrase. [batch, 1]  long
            rewards: the rewards for the transition. [batch, 1] float
            terminal_signals: the terminal signals for the transition. [batch, 1] float
        :param target_critic: provides the global critic
        :param update_step: for lr scheduler and logging
        :param discount_factor: rollout-rewards discount
        :param summary_writer: logging
        :param update_trust_region: discount for loss updates
        """
        learn_batch_size = local_agent_configs["rephraser_learning_batch"]

        states, masks, \
        actions, rephrase_positions, \
        rewards, _, survive_and_no_maxs = replay_buffer.sample(learn_batch_size, device=self.device)
        INFO("update local agent critics on device: %s" % self.device)
        self.update_critic(states, masks, rephrase_positions, actions, rewards,
                           survive_and_no_maxs, target_critic, update_step,
                           discount_factor, summary_writer,
                           update_trust_region)
        if update_step % local_agent_configs["actor_update_freq"] == 0:
            INFO("update local agent policy on device: %s" % self.device)
            self.update_actor_and_alpha(states, masks, rephrase_positions,
                                        target_critic, update_step,
                                        summary_writer, update_trust_region)

    def soft_update_target_net(self, target_SACAgent, tau):
        # soft-update the target network: first move to the target's device (usually CPU), then move back to the local device.
        # take care not to update the global model while local models are reading from or syncing with it.
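        # Polyak averaging: theta_target <- tau * theta + (1 - tau) * theta_target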
        self.to(target_SACAgent.device)
        with target_SACAgent.soft_update_lock:
            for param, target_param in zip(
                    self.critic.parameters(),
                    target_SACAgent.critic.parameters()):
                target_param.data.copy_(tau * param.data +
                                        (1 - tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(),
                                           target_SACAgent.actor.parameters()):
                target_param.data.copy_(tau * param.data +
                                        (1 - tau) * target_param.data)
        self.to(self.device)
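
# A minimal usage sketch (hypothetical; the config values, replay buffer and
# summary writer below are assumptions, not part of the class above):
#
#   agent = SACAgent(device="cuda:0",
#                    save_to="./checkpoints",
#                    num_kept_checkpoints=5,
#                    learnable_temperature=True,
#                    init_temperature=0.1,
#                    rephraser_optimizer_configs=rephraser_optimizer_configs)
#   step = agent.load_model()
#   agent.update_local_net(local_agent_configs, replay_buffer, target_critic,
#                          update_step=step, discount_factor=0.99,
#                          summary_writer=summary_writer)
#   agent.soft_update_target_net(target_agent, tau=0.005)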
Example #2
def train(FLAGS):
    """
    FLAGS:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
    """

    # write log of training to file.
    write_log_to_file(
        os.path.join(FLAGS.log_path,
                     "%s.log" % time.strftime("%Y%m%d-%H%M%S")))

    GlobalNames.USE_GPU = FLAGS.use_gpu

    if GlobalNames.USE_GPU:
        CURRENT_DEVICE = "cuda:0"
    else:
        CURRENT_DEVICE = "cpu"

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        configs = yaml.load(f, Loader=yaml.FullLoader)  # explicit Loader avoids the unsafe-load warning in newer PyYAML

    INFO(pretty_configs(configs))

    # Add default configs
    configs = default_configs(configs)
    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']
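    # A rough sketch of the expected YAML layout (keys inferred from how the
    # configs are consumed below; anything beyond these keys is an assumption):
    #   data_configs:      vocabularies, train_data, valid_data, max_len
    #   model_configs:     label_smoothing, ... (forwarded to build_model)
    #   optimizer_configs: optimizer, learning_rate, grad_clip, optimizer_params,
    #                      schedule_method, scheduler_configs
    #   training_configs:  seed, batch_size, update_cycle, buffer_size, shuffle,
    #                      use_bucket, batching_key, valid_batch_size,
    #                      num_kept_checkpoints, num_kept_best_model,
    #                      moving_average_method, ...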

    GlobalNames.SEED = training_configs['seed']

    set_seed(GlobalNames.SEED)

    best_model_prefix = os.path.join(
        FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][0])

    train_batch_size = training_configs["batch_size"] * max(
        1, training_configs["update_cycle"])
    train_buffer_size = training_configs["buffer_size"] * max(
        1, training_configs["update_cycle"])

    train_bitext_dataset = ZipDataset(TextLineDataset(
        data_path=data_configs['train_data'][0],
        vocabulary=vocab_tgt,
        max_len=data_configs['max_len'][0],
    ),
                                      shuffle=training_configs['shuffle'])

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(
            data_path=data_configs['valid_data'][0],
            vocabulary=vocab_tgt,
        ))

    training_iterator = DataIterator(
        dataset=train_bitext_dataset,
        batch_size=train_batch_size,
        use_bucket=training_configs['use_bucket'],
        buffer_size=train_buffer_size,
        batching_func=training_configs['batching_key'])

    valid_iterator = DataIterator(
        dataset=valid_bitext_dataset,
        batch_size=training_configs['valid_batch_size'],
        use_bucket=True,
        buffer_size=100000,
        numbering=True)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We will do the steps below one after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    model_collections = Collections()
    checkpoint_saver = Saver(
        save_prefix="{0}.ckpt".format(
            os.path.join(FLAGS.saveto, FLAGS.model_name)),
        num_max_keeping=training_configs['num_kept_checkpoints'])
    best_model_saver = Saver(
        save_prefix=best_model_prefix,
        num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    lm_model = build_model(n_tgt_vocab=vocab_tgt.max_n_words, **model_configs)
    INFO(lm_model)

    params_total = sum([p.numel() for n, p in lm_model.named_parameters()])
    params_without_embedding = sum([
        p.numel() for n, p in lm_model.named_parameters()
        if n.find('embedding') == -1
    ])
    INFO('Total parameters: {}'.format(params_total))
    INFO('Total parameters (excluding word embeddings): {}'.format(
        params_without_embedding))

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])

    INFO(critic)
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        lm_model = lm_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    lm_model.init_parameters(FLAGS.pretrain_path, device=CURRENT_DEVICE)

    # 4. Build optimizer
    INFO('Building Optimizer...')
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=lm_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params'])

    # 5. Build scheduler for optimizer if needed
    if optimizer_configs['schedule_method'] is not None:

        if optimizer_configs['schedule_method'] == "loss":

            scheduler = ReduceOnPlateauScheduler(
                optimizer=optim, **optimizer_configs["scheduler_configs"])

        elif optimizer_configs['schedule_method'] == "noam":
            scheduler = NoamScheduler(optimizer=optim,
                                      **optimizer_configs['scheduler_configs'])
        else:
            WARN(
                "Unknown scheduler name {0}. Do not use lr_scheduling.".format(
                    optimizer_configs['schedule_method']))
            scheduler = None
    else:
        scheduler = None

    # 6. build moving average

    if training_configs['moving_average_method'] is not None:
        ma = MovingAverage(
            moving_average_method=training_configs['moving_average_method'],
            named_params=lm_model.named_parameters(),
            alpha=training_configs['moving_average_alpha'])
    else:
        ma = None
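    # (Assumed behaviour) MovingAverage keeps an exponential moving average of the
    # parameters, roughly theta_ma <- alpha * theta_ma + (1 - alpha) * theta; the
    # averaged weights are swapped in via ma.export_ma_params() during validation.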

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if FLAGS.reload:
        checkpoint_saver.load_latest(model=lm_model,
                                     optim=optim,
                                     lr_scheduler=scheduler,
                                     collections=model_collections,
                                     ma=ma)

    # ================================================================================== #
    # Prepare training

    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [0])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]

    summary_writer = SummaryWriter(log_dir=FLAGS.log_path)

    cum_samples = 0
    cum_words = 0
    valid_loss = best_valid_loss = float('inf')  # initialize with +inf
    saving_files = []

    # Timer for computing speed
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    while True:
        summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()
        training_progress_bar = tqdm(desc=' - (Epc {}, Upd {}) '.format(
            eidx, uidx),
                                     total=len(training_iterator),
                                     unit="sents")
        for batch in training_iter:

            uidx += 1

            if optimizer_configs[
                    "schedule_method"] is not None and optimizer_configs[
                        "schedule_method"] != "loss":
                scheduler.step(global_step=uidx)

            seqs_y = batch

            n_samples_t = len(seqs_y)
            n_words_t = sum(len(s) for s in seqs_y)

            cum_samples += n_samples_t
            cum_words += n_words_t

            train_loss = 0.
            optim.zero_grad()
            try:
                # Prepare data
                for (seqs_y_t, ) in split_shard(
                        seqs_y, split_size=training_configs['update_cycle']):
                    y = prepare_data(seqs_y_t, cuda=GlobalNames.USE_GPU)

                    loss = compute_forward(
                        model=lm_model,
                        critic=critic,
                        # seqs_x=x,
                        seqs_y=y,
                        eval=False,
                        normalization=n_samples_t,
                        norm_by_words=training_configs["norm_by_words"])
                    train_loss += loss / y.size(
                        1) if not training_configs["norm_by_words"] else loss
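                # compute_forward (presumably) runs the backward pass per shard, so
                # gradients accumulate across the update_cycle before one optimizer step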
                optim.step()

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                    optim.zero_grad()
                else:
                    raise e

            if ma is not None and eidx >= training_configs[
                    'moving_average_start_epoch']:
                ma.step()

            training_progress_bar.update(n_samples_t)
            training_progress_bar.set_description(
                ' - (Epc {}, Upd {}) '.format(eidx, uidx))
            training_progress_bar.set_postfix_str(
                'TrainLoss: {:.2f}, ValidLoss(best): {:.2f} ({:.2f})'.format(
                    train_loss, valid_loss, best_valid_loss))
            summary_writer.add_scalar("train_loss",
                                      scalar_value=train_loss,
                                      global_step=uidx)

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(
                    uidx, eidx, every_n_step=training_configs['disp_freq']):
                # words per second and sents per second
                words_per_sec = cum_words / (timer.toc(return_seconds=True))
                sents_per_sec = cum_samples / (timer.toc(return_seconds=True))
                lrate = list(optim.get_lrate())[0]

                summary_writer.add_scalar("Speed(words/sec)",
                                          scalar_value=words_per_sec,
                                          global_step=uidx)
                summary_writer.add_scalar("Speed(sents/sen)",
                                          scalar_value=sents_per_sec,
                                          global_step=uidx)
                summary_writer.add_scalar("lrate",
                                          scalar_value=lrate,
                                          global_step=uidx)
                summary_writer.add_scalar("oom_count",
                                          scalar_value=oom_count,
                                          global_step=uidx)

                # Reset timer
                timer.tic()
                cum_words = 0
                cum_samples = 0

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(
                    uidx,
                    eidx,
                    every_n_step=training_configs['save_freq'],
                    debug=FLAGS.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)

                if not is_early_stop:
                    checkpoint_saver.save(global_step=uidx,
                                          model=lm_model,
                                          optim=optim,
                                          lr_scheduler=scheduler,
                                          collections=model_collections,
                                          ma=ma)

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['loss_valid_freq'],
                    debug=FLAGS.debug):

                if ma is not None:
                    origin_state_dict = deepcopy(lm_model.state_dict())
                    lm_model.load_state_dict(ma.export_ma_params(),
                                             strict=False)

                valid_loss = loss_validation(
                    model=lm_model,
                    critic=critic,
                    valid_iterator=valid_iterator,
                    norm_by_words=training_configs["norm_by_words"])

                model_collections.add_to_collection("history_losses",
                                                    valid_loss)

                min_history_loss = np.array(
                    model_collections.get_collection("history_losses")).min()

                summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                summary_writer.add_scalar("best_loss",
                                          min_history_loss,
                                          global_step=uidx)

                if ma is not None:
                    lm_model.load_state_dict(origin_state_dict)
                    del origin_state_dict

                if optimizer_configs["schedule_method"] == "loss":
                    scheduler.step(metric=best_valid_loss)

                # If model get new best valid loss
                if valid_loss < best_valid_loss:
                    bad_count = 0

                    if is_early_stop is False:
                        # 1. save the best model
                        torch.save(lm_model.state_dict(),
                                   best_model_prefix + ".final")

                        # 2. record all several best models
                        best_model_saver.save(global_step=uidx, model=lm_model)
                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs[
                            'early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")

                best_valid_loss = min_history_loss

                summary_writer.add_scalar("bad_count", bad_count, uidx)

                INFO("{0} Loss: {1:.2f} lrate: {2:6f} patience: {3}".format(
                    uidx, valid_loss, lrate, bad_count))

        training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
Example #3
def tune(flags):
    """
    flags:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
    """

    # ================================================================================== #
    # Initialization for training on different devices
    # - CPU/GPU
    # - Single/Distributed
    Constants.USE_GPU = flags.use_gpu

    if flags.multi_gpu:
        dist.distributed_init(flags.shared_dir)
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        local_rank = dist.get_local_rank()
    else:
        world_size = 1
        rank = 0
        local_rank = 0

    if Constants.USE_GPU:
        torch.cuda.set_device(local_rank)
        Constants.CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        Constants.CURRENT_DEVICE = "cpu"

    # If not root_rank, close logging
    # else write log of training to file.
    if rank == 0:
        write_log_to_file(
            os.path.join(flags.log_path,
                         "%s.log" % time.strftime("%Y%m%d-%H%M%S")))
    else:
        close_logging()

    # ================================================================================== #
    # Parsing configuration files
    # - Load default settings
    # - Load pre-defined settings
    # - Load user-defined settings

    configs = prepare_configs(flags.config_path, flags.predefined_config)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    INFO(pretty_configs(configs))

    Constants.SEED = training_configs['seed']

    set_seed(Constants.SEED)

    timer = Timer()

    # ================================================================================== #
    # Load Data
    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_src = Vocabulary.build_from_file(**data_configs['vocabularies'][0])
    vocab_tgt = Vocabulary.build_from_file(**data_configs['vocabularies'][1])

    Constants.EOS = vocab_src.eos
    Constants.PAD = vocab_src.pad
    Constants.BOS = vocab_src.bos
    # bt tag dataset
    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        is_train_dataset=True),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        is_train_dataset=True))

    training_iterator = DataIterator(
        dataset=train_bitext_dataset,
        batch_size=training_configs["batch_size"],
        use_bucket=training_configs['use_bucket'],
        buffer_size=training_configs['buffer_size'],
        batching_func=training_configs['batching_key'],
        world_size=world_size,
        rank=rank)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We will do the steps below one after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial

    lrate = optimizer_configs['learning_rate']
    model_collections = Collections()

    checkpoint_saver = Saver(
        save_prefix="{0}.ckpt".format(
            os.path.join(flags.saveto, flags.model_name)),
        num_max_keeping=training_configs['num_kept_checkpoints'])
    best_model_prefix = os.path.join(
        flags.saveto, flags.model_name + Constants.MY_BEST_MODEL_SUFFIX)
    best_model_saver = Saver(
        save_prefix=best_model_prefix,
        num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words,
                            padding_idx=vocab_src.pad,
                            vocab_src=vocab_src,
                            vocab_tgt=vocab_tgt,
                            **model_configs)
    INFO(nmt_model)

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'],
                          padding_idx=vocab_tgt.pad)

    INFO(critic)

    # 2. Move to GPU
    if Constants.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model,
                          flags.pretrain_path,
                          exclude_prefix=flags.pretrain_exclude_prefix,
                          device=Constants.CURRENT_DEVICE)
    # freeze parameters
    froze_params(nmt_model, flags.froze_config)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 4. Build optimizer
    INFO('Building Optimizer...')

    if not flags.multi_gpu:
        optim = Optimizer(name=optimizer_configs['optimizer'],
                          model=nmt_model,
                          lr=lrate,
                          grad_clip=optimizer_configs['grad_clip'],
                          optim_args=optimizer_configs['optimizer_params'],
                          update_cycle=training_configs['update_cycle'])
    else:
        optim = dist.DistributedOptimizer(
            name=optimizer_configs['optimizer'],
            model=nmt_model,
            lr=lrate,
            grad_clip=optimizer_configs['grad_clip'],
            optim_args=optimizer_configs['optimizer_params'],
            device_id=local_rank)

    # 5. Build scheduler for optimizer if needed
    scheduler = build_scheduler(
        schedule_method=optimizer_configs['schedule_method'],
        optimizer=optim,
        scheduler_configs=optimizer_configs['scheduler_configs'])

    # 6. build moving average
    if training_configs['moving_average_method'] is not None:
        ma = MovingAverage(
            moving_average_method=training_configs['moving_average_method'],
            named_params=nmt_model.named_parameters(),
            alpha=training_configs['moving_average_alpha'])
    else:
        ma = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if flags.reload:
        checkpoint_saver.load_latest(model=nmt_model,
                                     optim=optim,
                                     lr_scheduler=scheduler,
                                     collections=model_collections,
                                     ma=ma,
                                     device=Constants.CURRENT_DEVICE)

    # broadcast parameters and optimizer states
    if world_size > 1:
        INFO("Broadcasting model parameters...")
        dist.broadcast_parameters(params=nmt_model.state_dict())
        INFO("Broadcasting optimizer states...")
        dist.broadcast_optimizer_state(optimizer=optim.optim)
        INFO('Done.')

    # ================================================================================== #
    # Prepare training

    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [1])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]
    is_early_stop = model_collections.get_collection("is_early_stop", [
        False,
    ])[-1]

    train_loss_meter = AverageMeter()
    sent_per_sec_meter = TimeMeter()
    tok_per_sec_meter = TimeMeter()

    update_cycle = training_configs['update_cycle']
    grad_denom = 0
    train_loss = 0.0
    cum_n_words = 0
    valid_loss = best_valid_loss = float('inf')

    if rank == 0:
        summary_writer = SummaryWriter(log_dir=flags.log_path)
    else:
        summary_writer = None

    sent_per_sec_meter.start()
    tok_per_sec_meter.start()

    INFO('Begin training...')

    while True:

        if summary_writer is not None:
            summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()

        if rank == 0:
            training_progress_bar = tqdm(desc=' - (Epc {}, Upd {}) '.format(
                eidx, uidx),
                                         total=len(training_iterator),
                                         unit="sents")
        else:
            training_progress_bar = None
        # INFO(Constants.USE_BT)
        for batch in training_iter:
            # bt attrib data
            seqs_x, seqs_y = batch

            batch_size = len(seqs_x)
            cum_n_words += sum(len(s) for s in seqs_y)

            try:
                # Prepare data
                x, y = prepare_data(seqs_x, seqs_y, cuda=Constants.USE_GPU)

                loss = compute_forward(
                    model=nmt_model,
                    critic=critic,
                    seqs_x=x,
                    seqs_y=y,
                    eval=False,
                    normalization=1.0,
                    norm_by_words=training_configs["norm_by_words"])

                update_cycle -= 1
                grad_denom += batch_size
                train_loss += loss

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                else:
                    raise e

            # When update_cycle reaches 0, one full (accumulated) batch is finished:
            # e.g. with update_cycle = 4, gradients from 4 consecutive micro-batches
            # are accumulated before a single optimizer step. Several things are done:
            # - update parameters
            # - reset update_cycle and grad_denom, update uidx
            # - learning rate scheduling
            # - update moving average

            if update_cycle == 0:

                # 0. reduce variables
                if world_size > 1:
                    grad_denom = dist.all_reduce_py(grad_denom)
                    train_loss = dist.all_reduce_py(train_loss)
                    cum_n_words = dist.all_reduce_py(cum_n_words)

                # 1. update parameters
                optim.step(denom=grad_denom)
                optim.zero_grad()

                if training_progress_bar is not None:
                    training_progress_bar.update(grad_denom)
                    training_progress_bar.set_description(
                        ' - (Epc {}, Upd {}) '.format(eidx, uidx))

                    postfix_str = 'TrainLoss: {:.2f}, ValidLoss(best): {:.2f} ({:.2f}), '.format(
                        train_loss, valid_loss, best_valid_loss)
                    training_progress_bar.set_postfix_str(postfix_str)

                # 2. learning rate scheduling
                if scheduler is not None and optimizer_configs[
                        "schedule_method"] != "loss":
                    scheduler.step(global_step=uidx)

                # 3. update moving average
                if ma is not None and eidx >= training_configs[
                        'moving_average_start_epoch']:
                    ma.step()

                # 4. update meters
                train_loss_meter.update(train_loss, grad_denom)
                sent_per_sec_meter.update(grad_denom)
                tok_per_sec_meter.update(cum_n_words)

                # 5. reset accumulated variables, update uidx
                update_cycle = training_configs['update_cycle']
                grad_denom = 0
                uidx += 1
                cum_n_words = 0.0
                train_loss = 0.0

            else:
                continue

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(
                    uidx, eidx, every_n_step=training_configs['disp_freq']):

                lrate = list(optim.get_lrate())[0]

                if summary_writer is not None:
                    summary_writer.add_scalar(
                        "Speed(sents/sec)",
                        scalar_value=sent_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "Speed(words/sec)",
                        scalar_value=tok_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "train_loss",
                        scalar_value=train_loss_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar("lrate",
                                              scalar_value=lrate,
                                              global_step=uidx)
                    summary_writer.add_scalar("oom_count",
                                              scalar_value=oom_count,
                                              global_step=uidx)

                # Reset Meters
                sent_per_sec_meter.reset()
                tok_per_sec_meter.reset()
                train_loss_meter.reset()

            # ================================================================================== #
            # Saving checkpoints
            # if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['save_freq'], debug=flags.debug):
            #     model_collections.add_to_collection("uidx", uidx)
            #     model_collections.add_to_collection("eidx", eidx)
            #     model_collections.add_to_collection("bad_count", bad_count)
            #
            #     if not is_early_stop:
            #         if rank == 0:
            #             checkpoint_saver.save(global_step=uidx,
            #                                   model=nmt_model,
            #                                   optim=optim,
            #                                   lr_scheduler=scheduler,
            #                                   collections=model_collections,
            #                                   ma=ma)

        torch.save(nmt_model.state_dict(), best_model_prefix + ".final")

        if training_progress_bar is not None:
            training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
Example #4
def train(rank,
          device,
          args,
          counter,
          lock,
          attack_configs,
          discriminator_configs,
          src_vocab,
          trg_vocab,
          data_set,
          global_attacker,
          attacker_configs,
          optimizer=None,
          scheduler=None,
          saver=None):
    """
    Run the training process:
    #1# train the env_discriminator
    #2# run the attacker AC based on rewards from the trained env_discriminator
    #3# run training updates for the attacker AC
    :param rank: (int) the rank of the process (from multiprocess)
    :param device: the device of the process
    :param counter: python multiprocess variable
    :param lock: python multiprocess variable
    :param args: global args
    :param attack_configs: attack settings
    :param discriminator_configs: discriminator settings
    :param src_vocab:
    :param trg_vocab:
    :param data_set: (data_iterator object) provide batched data labels
    :param global_attacker: the model to sync from
    :param attacker_configs: local attacker settings
    :param optimizer: uses shared optimizer for the attacker
            use local one if none
    :param scheduler: uses shared scheduler for the attacker,
            use local one if none
    :param saver: model saver
    :return:
    """
    trust_acc = acc_bound = discriminator_configs["acc_bound"]
    converged_bound = discriminator_configs["converged_bound"]
    patience = discriminator_configs["patience"]
    attacker_model_configs = attacker_configs["attacker_model_configs"]
    attacker_optimizer_configs = attacker_configs["attacker_optimizer_configs"]

    # this is for multiprocessing: GlobalNames cannot be directly inherited across processes
    GlobalNames.USE_GPU = args.use_gpu
    GlobalNames.SEED = attack_configs["seed"]
    torch.manual_seed(GlobalNames.SEED + rank)

    # initiate local saver and load checkpoint if possible
    local_saver = Saver(save_prefix="{0}.local".format(
        os.path.join(args.save_to, "train_env%d" % rank, "ACmodel")),
                        num_max_keeping=attack_configs["num_kept_checkpoints"])

    attack_iterator = DataIterator(dataset=data_set,
                                   batch_size=attack_configs["batch_size"],
                                   use_bucket=True,
                                   buffer_size=attack_configs["buffer_size"],
                                   numbering=True)

    summary_writer = SummaryWriter(
        log_dir=os.path.join(args.save_to, "train_env%d" % rank))
    local_attacker = attacker.Attacker(src_vocab.max_n_words,
                                       **attacker_model_configs)
    # build optimizer for attacker
    if optimizer is None:
        optimizer = Optimizer(
            name=attacker_optimizer_configs["optimizer"],
            model=global_attacker,
            lr=attacker_optimizer_configs["learning_rate"],
            grad_clip=attacker_optimizer_configs["grad_clip"],
            optim_args=attacker_optimizer_configs["optimizer_params"])
        # Build scheduler for optimizer if needed
        if attacker_optimizer_configs['schedule_method'] is not None:
            if attacker_optimizer_configs['schedule_method'] == "loss":
                scheduler = ReduceOnPlateauScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs["scheduler_configs"])
            elif attacker_optimizer_configs['schedule_method'] == "noam":
                scheduler = NoamScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs['scheduler_configs'])
            elif attacker_optimizer_configs["schedule_method"] == "rsqrt":
                scheduler = RsqrtScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs["scheduler_configs"])
            else:
                WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".
                     format(attacker_optimizer_configs['schedule_method']))
                scheduler = None
        else:
            scheduler = None

    local_saver.load_latest(model=local_attacker,
                            optim=optimizer,
                            lr_scheduler=scheduler)

    attacker_iterator = attack_iterator.build_generator()
    env = Translate_Env(attack_configs=attack_configs,
                        discriminator_configs=discriminator_configs,
                        src_vocab=src_vocab,
                        trg_vocab=trg_vocab,
                        data_iterator=attacker_iterator,
                        save_to=args.save_to,
                        device=device)
    episode_count = 0
    episode_length = 0
    local_steps = 0  # optimization steps: for learning rate schedules
    patience_t = patience
    while True:  # infinite loop of data set
        # we will continue with a new iterator with refreshed environments
        # whenever the last iterator breaks with "StopIteration"
        attacker_iterator = attack_iterator.build_generator()
        env.reset_data_iter(attacker_iterator)
        padded_src = env.reset()
        padded_src = torch.from_numpy(padded_src)
        if device != "cpu":
            padded_src = padded_src.to(device)
        done = True
        discriminator_base_steps = local_steps

        while True:
            # check for update of discriminator
            # if env.acc_validation(local_attacker, use_gpu=True if env.device != "cpu" else False) < 0.55:
            if episode_count % attacker_configs["attacker_update_steps"] == 0:
                """ stop criterion:
                when updates a discriminator, we check for acc. If acc fails acc_bound,
                we reset the discriminator and try, until acc reaches the bound with patience.
                otherwise the training thread stops
                """
                try:
                    discriminator_base_steps, trust_acc = env.update_discriminator(
                        local_attacker,
                        discriminator_base_steps,
                        min_update_steps=discriminator_configs[
                            "acc_valid_freq"],
                        max_update_steps=discriminator_configs[
                            "discriminator_update_steps"],
                        accuracy_bound=acc_bound,
                        summary_writer=summary_writer)
                except StopIteration:
                    INFO("finish one training epoch, reset data_iterator")
                    break

                discriminator_base_steps += 1  # a flag to label the discriminator updates
                if trust_acc < converged_bound:  # GAN target reached
                    patience_t -= 1
                    INFO(
                        "discriminator reached GAN convergence bound: %d times"
                        % patience_t)
                else:  # reset patience if discriminator is refreshed
                    patience_t = patience

            if saver and local_steps % attack_configs["save_freq"] == 0:
                local_saver.save(global_step=local_steps,
                                 model=local_attacker,
                                 optim=optimizer,
                                 lr_scheduler=scheduler)

                if trust_acc < converged_bound:  # and patience_t == patience-1:
                    # we only save the global params reaching acc_bound
                    torch.save(global_attacker.state_dict(),
                               os.path.join(args.save_to, "ACmodel.final"))
                    # saver.raw_save(model=global_attacker)

            if patience_t == 0:
                WARN("maximum patience reached. Training Thread should stop")
                break

            local_attacker.train()  # switch back to training mode

            # for an initial (reset) attacker from global parameters
            if done:
                INFO("sync from global model")
                local_attacker.load_state_dict(global_attacker.state_dict())
            # move the local attacker params back to device after updates
            local_attacker = local_attacker.to(device)
            values = []  # training critic: network outputs
            log_probs = []
            rewards = []  # actual rewards
            entropies = []

            local_steps += 1
            # run sequences step of attack
            try:
                for i in range(args.action_roll_steps):
                    episode_length += 1
                    attack_out, critic_out = local_attacker(
                        padded_src, padded_src[:, env.index - 1:env.index + 2])
                    logit_attack_out = torch.log(attack_out)
                    entropy = -(attack_out *
                                logit_attack_out).sum(dim=-1).mean()

                    summary_writer.add_scalar("action_entropy",
                                              scalar_value=entropy,
                                              global_step=local_steps)
                    entropies.append(entropy)  # for entropy loss
                    actions = attack_out.multinomial(num_samples=1).detach()
                    # only extract the log prob for chosen action (avg over batch)
                    log_attack_out = logit_attack_out.gather(-1,
                                                             actions).mean()
                    padded_src, reward, terminal_signal = env.step(
                        actions.squeeze())
                    done = terminal_signal or episode_length > args.max_episode_lengths

                    with lock:
                        counter.value += 1

                    if done:
                        episode_length = 0
                        padded_src = env.reset()

                    padded_src = torch.from_numpy(padded_src)
                    if device != "cpu":
                        padded_src = padded_src.to(device)

                    values.append(
                        critic_out.mean())  # list of torch variables (scalar)
                    log_probs.append(
                        log_attack_out)  # list of torch variables (scalar)
                    rewards.append(reward)  # list of reward variables

                    if done:
                        episode_count += 1
                        break
            except StopIteration:
                INFO("finish one training epoch, reset data_iterator")
                break

            R = torch.zeros(1, 1)
            gae = torch.zeros(1, 1)
            if device != "cpu":
                R = R.to(device)
                gae = gae.to(device)

            if not done:  # bootstrap the return from the critic estimate
                value = local_attacker.get_critic(
                    padded_src, padded_src[:, env.index - 1:env.index + 2])
                R = value.mean().detach()

            values.append(R)
            policy_loss = 0
            value_loss = 0

            # collect discounted returns and advantages for training
            for i in reversed(range(len(rewards))):
                # accumulate the discounted return, value loss, and GAE-based policy loss
                R = attack_configs["gamma"] * R + rewards[i]
                advantage = R - values[i]
                value_loss = value_loss + 0.5 * advantage.pow(2)

                delta_t = rewards[i] + attack_configs["gamma"] * \
                          values[i + 1] - values[i]
                gae = gae * attack_configs["gamma"] * attack_configs["tau"] + \
                      delta_t
                policy_loss = policy_loss - log_probs[i] * gae.detach() - \
                              attack_configs["entropy_coef"] * entropies[i]
                print("policy_loss", policy_loss)
                print("gae", gae)

            # update with optimizer
            optimizer.zero_grad()
            # we decay the loss according to the discriminator's accuracy as a trust-region constraint
            summary_writer.add_scalar("policy_loss",
                                      scalar_value=policy_loss * trust_acc,
                                      global_step=local_steps)
            summary_writer.add_scalar("value_loss",
                                      scalar_value=value_loss * trust_acc,
                                      global_step=local_steps)

            total_loss = trust_acc * policy_loss + \
                         trust_acc * attack_configs["value_coef"] * value_loss
            total_loss.backward()

            if scheduler is not None and \
                    attacker_optimizer_configs["schedule_method"] != "loss":
                scheduler.step(global_step=local_steps)

            # move the model params to CPU and
            # assign local gradients to the global model to update
            local_attacker.to("cpu").ensure_shared_grads(global_attacker)
            optimizer.step()
            print("bingo!")

        if patience_t == 0:
            INFO("Reach maximum Discriminator patience, Finish")
            break
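
A minimal, self-contained sketch of the return/GAE recursion used in the training loop above, assuming 0-dim tensor rewards, values, log-probs, and entropies; the gamma, tau, and entropy_coef names mirror the attack_configs keys, while the function itself is illustrative and not part of the project's API.

import torch

def a3c_losses(rewards, values, log_probs, entropies,
               bootstrap_value, gamma=0.99, tau=1.0, entropy_coef=0.01):
    """Value loss and GAE-weighted policy loss for one rollout of length T.

    rewards, values, log_probs, entropies: lists of 0-dim tensors of length T.
    bootstrap_value: critic estimate for the state after the last step (zeros if terminal).
    """
    values = values + [bootstrap_value]
    R = bootstrap_value
    gae = torch.zeros(1)
    policy_loss, value_loss = 0.0, 0.0
    for i in reversed(range(len(rewards))):
        # discounted return and squared advantage for the critic
        R = gamma * R + rewards[i]
        value_loss = value_loss + 0.5 * (R - values[i]).pow(2)
        # one-step TD error accumulated into the generalized advantage estimate
        delta_t = rewards[i] + gamma * values[i + 1] - values[i]
        gae = gae * gamma * tau + delta_t
        # policy-gradient term plus entropy bonus (gae is treated as a constant)
        policy_loss = policy_loss - log_probs[i] * gae.detach() - entropy_coef * entropies[i]
    return policy_loss, value_loss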
Exemple #5
0
def train(FLAGS):
    """
    FLAGS:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
    """

    # ================================================================================== #
    # Initialization for training on different devices
    # - CPU/GPU
    # - Single/Distributed
    GlobalNames.USE_GPU = FLAGS.use_gpu

    if FLAGS.multi_gpu:

        if hvd is None or distributed is None:
            ERROR("Distributed training is disable. Please check the installation of Horovod.")

        hvd.init()
        world_size = hvd.size()
        rank = hvd.rank()
        local_rank = hvd.local_rank()
    else:
        world_size = 1
        rank = 0
        local_rank = 0

    if GlobalNames.USE_GPU:
        torch.cuda.set_device(local_rank)
        CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        CURRENT_DEVICE = "cpu"

    # If not root_rank, close logging
    if rank != 0:
        close_logging()

    # write log of training to file.
    if rank == 0:
        write_log_to_file(os.path.join(FLAGS.log_path, "%s.log" % time.strftime("%Y%m%d-%H%M%S")))

    # ================================================================================== #
    # Parsing configuration files

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        configs = yaml.load(f, Loader=yaml.FullLoader)

    INFO(pretty_configs(configs))

    # Add default configs
    configs = default_baseline_configs(configs)
    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    GlobalNames.SEED = training_configs['seed']

    set_seed(GlobalNames.SEED)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    actual_buffer_size = training_configs["buffer_size"] * max(1, training_configs["update_cycle"])

    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        ),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        )
    )

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['valid_data'][0],
                        vocabulary=vocab_src,
                        ),
        TextLineDataset(data_path=data_configs['valid_data'][1],
                        vocabulary=vocab_tgt,
                        )
    )

    training_iterator = DataIterator(dataset=train_bitext_dataset,
                                     batch_size=training_configs["batch_size"],
                                     use_bucket=training_configs['use_bucket'],
                                     buffer_size=actual_buffer_size,
                                     batching_func=training_configs['batching_key'],
                                     world_size=world_size,
                                     rank=rank)

    valid_iterator = DataIterator(dataset=valid_bitext_dataset,
                                  batch_size=training_configs['valid_batch_size'],
                                  use_bucket=True, buffer_size=100000, numbering=True,
                                  world_size=world_size, rank=rank)

    bleu_scorer = SacreBLEUScorer(reference_path=data_configs["bleu_valid_reference"],
                                  num_refs=data_configs["num_refs"],
                                  lang_pair=data_configs["lang_pair"],
                                  sacrebleu_args=training_configs["bleu_valid_configs"]['sacrebleu_args'],
                                  postprocess=training_configs["bleu_valid_configs"]['postprocess']
                                  )

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We would do the steps below one after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    model_collections = Collections()
    best_model_prefix = os.path.join(FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    checkpoint_saver = Saver(save_prefix="{0}.ckpt".format(os.path.join(FLAGS.saveto, FLAGS.model_name)),
                             num_max_keeping=training_configs['num_kept_checkpoints']
                             )
    best_model_saver = Saver(save_prefix=best_model_prefix, num_max_keeping=training_configs['num_kept_best_model'])

    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words, **model_configs)
    INFO(nmt_model)

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])

    INFO(critic)
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model, FLAGS.pretrain_path, exclude_prefix=None, device=CURRENT_DEVICE)

    # 4. Build optimizer
    INFO('Building Optimizer...')
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=nmt_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params'],
                      distributed=True if world_size > 1 else False,
                      update_cycle=training_configs['update_cycle']
                      )
    # 5. Build scheduler for optimizer if needed
    if optimizer_configs['schedule_method'] is not None:

        if optimizer_configs['schedule_method'] == "loss":

            scheduler = ReduceOnPlateauScheduler(optimizer=optim,
                                                 **optimizer_configs["scheduler_configs"]
                                                 )

        elif optimizer_configs['schedule_method'] == "noam":
            scheduler = NoamScheduler(optimizer=optim, **optimizer_configs['scheduler_configs'])
        else:
            WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(optimizer_configs['schedule_method']))
            scheduler = None
    else:
        scheduler = None

    # 6. build moving average

    if training_configs['moving_average_method'] is not None:
        ma = MovingAverage(moving_average_method=training_configs['moving_average_method'],
                           named_params=nmt_model.named_parameters(),
                           alpha=training_configs['moving_average_alpha'])
    else:
        ma = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if FLAGS.reload:
        checkpoint_saver.load_latest(model=nmt_model, optim=optim, lr_scheduler=scheduler,
                                     collections=model_collections, ma=ma)

    # broadcast parameters and optimizer states
    if world_size > 1:
        hvd.broadcast_parameters(params=nmt_model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer=optim.optim, root_rank=0)

    # ================================================================================== #
    # Prepare training

    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [1])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]
    cum_n_samples = 0
    cum_n_words = 0
    best_valid_loss = 1.0 * 1e10  # effectively +infinity as the initial best loss
    update_cycle = training_configs['update_cycle']
    grad_denom = 0

    if rank == 0:
        summary_writer = SummaryWriter(log_dir=FLAGS.log_path)
    else:
        summary_writer = None

    # Timer for computing speed
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    while True:

        if summary_writer is not None:
            summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()

        if rank == 0:
            training_progress_bar = tqdm(desc='  - (Epoch %d)   ' % eidx,
                                         total=len(training_iterator),
                                         unit="sents"
                                         )
        else:
            training_progress_bar = None

        for batch in training_iter:

            seqs_x, seqs_y = batch

            batch_size = len(seqs_x)

            cum_n_samples += batch_size
            cum_n_words += sum(len(s) for s in seqs_y)

            try:
                # Prepare data
                x, y = prepare_data(seqs_x, seqs_y, cuda=GlobalNames.USE_GPU)

                loss = compute_forward(model=nmt_model,
                                       critic=critic,
                                       seqs_x=x,
                                       seqs_y=y,
                                       eval=False,
                                       normalization=1.0,
                                       norm_by_words=training_configs["norm_by_words"])

                update_cycle -= 1
                grad_denom += batch_size

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                else:
                    raise e

            # When update_cycle becomes 0, it means end of one batch. Several things will be done:
            # - update parameters
            # - reset update_cycle and grad_denom
            # - update uidx
            # - update moving average

            if update_cycle == 0:
                if world_size > 1:
                    grad_denom = distributed.all_reduce(grad_denom)

                optim.step(denom=grad_denom)
                optim.zero_grad()

                if training_progress_bar is not None:
                    training_progress_bar.update(grad_denom)

                update_cycle = training_configs['update_cycle']
                grad_denom = 0

                uidx += 1

                if scheduler is None:
                    pass
                elif optimizer_configs["schedule_method"] == "loss":
                    scheduler.step(metric=best_valid_loss)
                else:
                    scheduler.step(global_step=uidx)

                if ma is not None and eidx >= training_configs['moving_average_start_epoch']:
                    ma.step()
            else:
                continue

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['disp_freq']):

                if world_size > 1:
                    cum_n_words = sum(distributed.all_gather(cum_n_words))
                    cum_n_samples = sum(distributed.all_gather(cum_n_samples))

                # words per second and sents per second
                words_per_sec = cum_n_words / (timer.toc(return_seconds=True))
                sents_per_sec = cum_n_samples / (timer.toc(return_seconds=True))
                lrate = list(optim.get_lrate())[0]

                if summary_writer is not None:
                    summary_writer.add_scalar("Speed(words/sec)", scalar_value=words_per_sec, global_step=uidx)
                    summary_writer.add_scalar("Speed(sents/sen)", scalar_value=sents_per_sec, global_step=uidx)
                    summary_writer.add_scalar("lrate", scalar_value=lrate, global_step=uidx)
                    summary_writer.add_scalar("oom_count", scalar_value=oom_count, global_step=uidx)

                # Reset timer
                timer.tic()
                cum_n_words = 0
                cum_n_samples = 0

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(global_step=uidx, n_epoch=eidx, every_n_step=training_configs['loss_valid_freq'],
                                       debug=FLAGS.debug):

                valid_loss = loss_validation(model=nmt_model,
                                             critic=critic,
                                             valid_iterator=valid_iterator,
                                             rank=rank,
                                             world_size=world_size
                                             )

                model_collections.add_to_collection("history_losses", valid_loss)

                min_history_loss = np.array(model_collections.get_collection("history_losses")).min()

                best_valid_loss = min_history_loss

                if summary_writer is not None:
                    summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                    summary_writer.add_scalar("best_loss", min_history_loss, global_step=uidx)

            # ================================================================================== #
            # BLEU Validation & Early Stop

            if should_trigger_by_steps(global_step=uidx, n_epoch=eidx,
                                       every_n_step=training_configs['bleu_valid_freq'],
                                       min_step=training_configs['bleu_valid_warmup'],
                                       debug=FLAGS.debug):

                valid_bleu = bleu_validation(uidx=uidx,
                                             valid_iterator=valid_iterator,
                                             batch_size=training_configs["bleu_valid_batch_size"],
                                             model=nmt_model,
                                             bleu_scorer=bleu_scorer,
                                             vocab_tgt=vocab_tgt,
                                             valid_dir=FLAGS.valid_path,
                                             max_steps=training_configs["bleu_valid_configs"]["max_steps"],
                                             beam_size=training_configs["bleu_valid_configs"]["beam_size"],
                                             alpha=training_configs["bleu_valid_configs"]["alpha"],
                                             world_size=world_size,
                                             rank=rank,
                                             )

                model_collections.add_to_collection(key="history_bleus", value=valid_bleu)

                best_valid_bleu = float(np.array(model_collections.get_collection("history_bleus")).max())

                if summary_writer is not None:
                    summary_writer.add_scalar("bleu", valid_bleu, uidx)
                    summary_writer.add_scalar("best_bleu", best_valid_bleu, uidx)

                # If model get new best valid bleu score
                if valid_bleu >= best_valid_bleu:
                    bad_count = 0

                    if is_early_stop is False:
                        if rank == 0:
                            # 1. save the best model
                            torch.save(nmt_model.state_dict(), best_model_prefix + ".final")

                            # 2. record all several best models
                            best_model_saver.save(global_step=uidx, model=nmt_model, ma=ma)
                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs['early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")

                if summary_writer is not None:
                    summary_writer.add_scalar("bad_count", bad_count, uidx)

                INFO("{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4}".format(
                    uidx, valid_loss, valid_bleu, lrate, bad_count
                ))

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['save_freq'], debug=FLAGS.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)

                if not is_early_stop:
                    if rank == 0:
                        checkpoint_saver.save(global_step=uidx,
                                              model=nmt_model,
                                              optim=optim,
                                              lr_scheduler=scheduler,
                                              collections=model_collections,
                                              ma=ma)

        if training_progress_bar is not None:
            training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
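
For reference, the update_cycle bookkeeping in the example above amounts to gradient accumulation: several mini-batches contribute gradients before a single optimizer step. A stripped-down sketch of the same pattern under generic names (model, batches, and compute_loss are placeholders, not this project's API):

import torch

def train_with_accumulation(model, optimizer, batches, compute_loss, update_cycle=4):
    """Accumulate gradients over `update_cycle` mini-batches, then step once."""
    cycle = update_cycle
    grad_denom = 0  # number of samples behind the accumulated gradients
    optimizer.zero_grad()
    for batch in batches:
        loss = compute_loss(model, batch)  # forward pass on one mini-batch
        loss.backward()                    # gradients accumulate across iterations
        cycle -= 1
        grad_denom += len(batch)
        if cycle == 0:
            # normalize by the number of samples seen in this cycle, then update
            for p in model.parameters():
                if p.grad is not None:
                    p.grad.div_(grad_denom)
            optimizer.step()
            optimizer.zero_grad()
            cycle, grad_denom = update_cycle, 0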
Exemple #6
0
def train(FLAGS):
    """
    FLAGS:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
    """

    # write log of training to file.
    write_log_to_file(os.path.join(FLAGS.log_path, "%s.log" % time.strftime("%Y%m%d-%H%M%S")))

    GlobalNames.USE_GPU = FLAGS.use_gpu

    if GlobalNames.USE_GPU:
        CURRENT_DEVICE = "cuda:0"
    else:
        CURRENT_DEVICE = "cpu"

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        configs = yaml.load(f, Loader=yaml.FullLoader)

    INFO(pretty_configs(configs))

    # Add default configs
    configs = default_configs(configs)
    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    GlobalNames.SEED = training_configs['seed']

    set_seed(GlobalNames.SEED)

    best_model_prefix = os.path.join(FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    train_batch_size = training_configs["batch_size"] * max(1, training_configs["update_cycle"])
    train_buffer_size = training_configs["buffer_size"] * max(1, training_configs["update_cycle"])

    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        ),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        ),
        shuffle=training_configs['shuffle']
    )

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['valid_data'][0],
                        vocabulary=vocab_src,
                        ),
        TextLineDataset(data_path=data_configs['valid_data'][1],
                        vocabulary=vocab_tgt,
                        )
    )

    training_iterator = DataIterator(dataset=train_bitext_dataset,
                                     batch_size=train_batch_size,
                                     use_bucket=training_configs['use_bucket'],
                                     buffer_size=train_buffer_size,
                                     batching_func=training_configs['batching_key'])

    valid_iterator = DataIterator(dataset=valid_bitext_dataset,
                                  batch_size=training_configs['valid_batch_size'],
                                  use_bucket=True, buffer_size=100000, numbering=True)

    bleu_scorer = SacreBLEUScorer(reference_path=data_configs["bleu_valid_reference"],
                                  num_refs=data_configs["num_refs"],
                                  lang_pair=data_configs["lang_pair"],
                                  sacrebleu_args=training_configs["bleu_valid_configs"]['sacrebleu_args'],
                                  postprocess=training_configs["bleu_valid_configs"]['postprocess']
                                  )

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We would do the steps below one after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    model_collections = Collections()
    checkpoint_saver = Saver(save_prefix="{0}.ckpt".format(os.path.join(FLAGS.saveto, FLAGS.model_name)),
                             num_max_keeping=training_configs['num_kept_checkpoints']
                             )
    best_model_saver = Saver(save_prefix=best_model_prefix, num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words, **model_configs)
    INFO(nmt_model)

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])

    INFO(critic)
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model, FLAGS.pretrain_path, exclude_prefix=None, device=CURRENT_DEVICE)

    # 4. Build optimizer
    INFO('Building Optimizer...')
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=nmt_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params']
                      )
    # 5. Build scheduler for optimizer if needed
    if optimizer_configs['schedule_method'] is not None:

        if optimizer_configs['schedule_method'] == "loss":

            scheduler = ReduceOnPlateauScheduler(optimizer=optim,
                                                 **optimizer_configs["scheduler_configs"]
                                                 )

        elif optimizer_configs['schedule_method'] == "noam":
            scheduler = NoamScheduler(optimizer=optim, **optimizer_configs['scheduler_configs'])
        else:
            WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(optimizer_configs['schedule_method']))
            scheduler = None
    else:
        scheduler = None

    # 6. build EMA
    if training_configs['ema_decay'] > 0.0:
        ema = ExponentialMovingAverage(named_params=nmt_model.named_parameters(), decay=training_configs['ema_decay'])
    else:
        ema = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if FLAGS.reload:
        checkpoint_saver.load_latest(model=nmt_model, optim=optim, lr_scheduler=scheduler,
                                     collections=model_collections)

    # ================================================================================== #
    # Prepare training

    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [0])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]

    summary_writer = SummaryWriter(log_dir=FLAGS.log_path)

    cum_samples = 0
    cum_words = 0
    best_valid_loss = 1.0 * 1e10  # effectively +infinity as the initial best loss
    saving_files = []

    # Timer for computing speed
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    while True:

        summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()
        training_progress_bar = tqdm(desc='  - (Epoch %d)   ' % eidx,
                                     total=len(training_iterator),
                                     unit="sents"
                                     )
        for batch in training_iter:

            uidx += 1

            if scheduler is None:
                pass
            elif optimizer_configs["schedule_method"] == "loss":
                scheduler.step(metric=best_valid_loss)
            else:
                scheduler.step(global_step=uidx)

            seqs_x, seqs_y = batch

            n_samples_t = len(seqs_x)
            n_words_t = sum(len(s) for s in seqs_y)

            cum_samples += n_samples_t
            cum_words += n_words_t

            training_progress_bar.update(n_samples_t)

            optim.zero_grad()

            # Prepare data
            for seqs_x_t, seqs_y_t in split_shard(seqs_x, seqs_y, split_size=training_configs['update_cycle']):
                x, y = prepare_data(seqs_x_t, seqs_y_t, cuda=GlobalNames.USE_GPU)

                loss = compute_forward(model=nmt_model,
                                       critic=critic,
                                       seqs_x=x,
                                       seqs_y=y,
                                       eval=False,
                                       normalization=n_samples_t,
                                       norm_by_words=training_configs["norm_by_words"])
            optim.step()

            if ema is not None:
                ema.step()

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['disp_freq']):
                # words per second and sents per second
                words_per_sec = cum_words / (timer.toc(return_seconds=True))
                sents_per_sec = cum_samples / (timer.toc(return_seconds=True))
                lrate = list(optim.get_lrate())[0]

                summary_writer.add_scalar("Speed(words/sec)", scalar_value=words_per_sec, global_step=uidx)
                summary_writer.add_scalar("Speed(sents/sen)", scalar_value=sents_per_sec, global_step=uidx)
                summary_writer.add_scalar("lrate", scalar_value=lrate, global_step=uidx)

                # Reset timer
                timer.tic()
                cum_words = 0
                cum_samples = 0

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['save_freq'], debug=FLAGS.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)

                if not is_early_stop:

                    checkpoint_saver.save(global_step=uidx, model=nmt_model, optim=optim, lr_scheduler=scheduler,
                                          collections=model_collections, ema=ema)

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(global_step=uidx, n_epoch=eidx, every_n_step=training_configs['loss_valid_freq'],
                                       debug=FLAGS.debug):

                if ema is not None:
                    origin_state_dict = deepcopy(nmt_model.state_dict())
                    nmt_model.load_state_dict(ema.state_dict(), strict=False)

                valid_loss = loss_validation(model=nmt_model,
                                             critic=critic,
                                             valid_iterator=valid_iterator,
                                             )

                model_collections.add_to_collection("history_losses", valid_loss)

                min_history_loss = np.array(model_collections.get_collection("history_losses")).min()

                summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                summary_writer.add_scalar("best_loss", min_history_loss, global_step=uidx)

                best_valid_loss = min_history_loss

                if ema is not None:
                    nmt_model.load_state_dict(origin_state_dict)
                    del origin_state_dict

            # ================================================================================== #
            # BLEU Validation & Early Stop

            if should_trigger_by_steps(global_step=uidx, n_epoch=eidx,
                                       every_n_step=training_configs['bleu_valid_freq'],
                                       min_step=training_configs['bleu_valid_warmup'],
                                       debug=FLAGS.debug):

                if ema is not None:
                    origin_state_dict = deepcopy(nmt_model.state_dict())
                    nmt_model.load_state_dict(ema.state_dict(), strict=False)

                valid_bleu = bleu_validation(uidx=uidx,
                                             valid_iterator=valid_iterator,
                                             batch_size=training_configs["bleu_valid_batch_size"],
                                             model=nmt_model,
                                             bleu_scorer=bleu_scorer,
                                             vocab_tgt=vocab_tgt,
                                             valid_dir=FLAGS.valid_path,
                                             max_steps=training_configs["bleu_valid_configs"]["max_steps"],
                                             beam_size=training_configs["bleu_valid_configs"]["beam_size"],
                                             alpha=training_configs["bleu_valid_configs"]["alpha"]
                                             )

                model_collections.add_to_collection(key="history_bleus", value=valid_bleu)

                best_valid_bleu = float(np.array(model_collections.get_collection("history_bleus")).max())

                summary_writer.add_scalar("bleu", valid_bleu, uidx)
                summary_writer.add_scalar("best_bleu", best_valid_bleu, uidx)

                # If model get new best valid bleu score
                if valid_bleu >= best_valid_bleu:
                    bad_count = 0

                    if is_early_stop is False:
                        # 1. save the best model
                        torch.save(nmt_model.state_dict(), best_model_prefix + ".final")

                        # 2. record all several best models
                        best_model_saver.save(global_step=uidx, model=nmt_model)
                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs['early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")

                summary_writer.add_scalar("bad_count", bad_count, uidx)

                if ema is not None:
                    nmt_model.load_state_dict(origin_state_dict)
                    del origin_state_dict

                INFO("{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4}".format(
                    uidx, valid_loss, valid_bleu, lrate, bad_count
                ))

        training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
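
The EMA handling in the validation blocks above (swap in the averaged weights, validate, restore the live weights) is a common pattern; a minimal sketch, assuming only that the ema object exposes a state_dict() of averaged parameters like the ExponentialMovingAverage used above:

from copy import deepcopy

def validate_with_ema(model, ema, run_validation):
    """Temporarily load EMA weights for validation, then restore the training weights."""
    if ema is None:
        return run_validation(model)
    origin_state_dict = deepcopy(model.state_dict())
    model.load_state_dict(ema.state_dict(), strict=False)  # averaged weights only
    try:
        return run_validation(model)
    finally:
        model.load_state_dict(origin_state_dict)            # restore live weights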
Exemple #7
0
class Translate_Env(object):
    """
    wrap a translation environment for multiple agents;
    the env needs parallel data to evaluate the final BLEU improvement

    stores the state as [current src embeddings, index] and yields a reward at each step;
    rewards are based on the scorer and, at the end, on sentence-level BLEU
    :return: translations of multiple sentences and the resulting BLEU change
    """
    def __init__(self, reinforce_configs,
                 annunciator_configs,
                 src_vocab, trg_vocab,
                 data_iterator,
                 save_to,
                 device="cpu",
                 ):
        """
        initialize the translation environment; needs a scorer and a translator
        :param reinforce_configs: attack configuration dictionary
        :param annunciator_configs: discriminator or scorer configs (provide survival signals)
        :param save_to: path to save the model
        :param data_iterator: provides data to initialize the environment
        (the directory of the src sentences)
        :param device: (string) device to allocate variables on ("cpu", "cuda:*"),
        defaults to cpu
        """
        # environment devices
        self.device = device
        self.data_iterator = data_iterator
        scorer_model_configs = annunciator_configs["scorer_model_configs"]
        # discriminator_model_configs = annunciator_configs["discriminator_model_configs"]
        annunciator_optim_configs = annunciator_configs["annunciator_optimizer_configs"]

        victim_config_path = reinforce_configs["victim_configs"]
        victim_model_path = reinforce_configs["victim_model"]
        with open(victim_config_path.strip()) as v_f:
            INFO("env open victim configs at %s" % victim_config_path)
            victim_configs = yaml.load(v_f, Loader=yaml.FullLoader)

        # to extract the embedding as representation
        # *vocab and *emb will provide pseudo-reinforced embedding to train the annunciator
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        # translation model for BLEU(take src_embs as inputs) and corresponding embedding layers
        self.src_emb, self.trg_emb, self.translate_model = build_translate_model(
            victim_configs, victim_model_path,
            vocab_src=self.src_vocab, vocab_trg=self.trg_vocab,
            device=self.device)

        self.max_roll_out_step = victim_configs["data_configs"]["max_len"][0]
        self.src_emb.eval()  # source language embeddings
        self.trg_emb.eval()  # target language embeddings
        self.translate_model.eval()

        # the epsilon range used for the action space when perturbing
        _, _, self.limit_dist = load_or_extract_near_vocab(
            config_path=victim_config_path, model_path=victim_model_path,
            init_perturb_rate=reinforce_configs["init_perturb_rate"],
            save_to=os.path.join(save_to, "near_vocab"),
            save_to_full=os.path.join(save_to, "full_near_vocab"),
            top_reserve=12, emit_as_id=True)

        #########################################################
        # scorer(an Annunciator object) provides intrinsic step rewards
        self.annunciator = TransScorer(
            victim_configs, victim_model_path, self.trg_emb,
            **scorer_model_configs)
        self.annunciator.to(self.device)
        # # discriminator(an Annunciator object) provides intrinsic step rewards and terminal signal
        # self.discriminator = TransDiscriminator(
        #     victim_configs, victim_model_path,
        #     **discriminator_model_configs)
        # self.discriminator.to(self.device)
        # Annunciator update configs
        self.acc_bound = annunciator_configs["acc_bound"]
        self.mse_bound = annunciator_configs["mse_bound"]
        self.min_update_steps = annunciator_configs["valid_freq"]
        self.max_update_steps = annunciator_configs["annunciator_update_steps"]
        # the optimizer and schedule used for Annunciator update.
        self.optim_A = Optimizer(
            name=annunciator_optim_configs["optimizer"],
            model=self.annunciator,
            lr=annunciator_optim_configs["learning_rate"],
            grad_clip=annunciator_optim_configs["grad_clip"],
            optim_args=annunciator_optim_configs["optimizer_params"])

        self.scheduler_A = None  # default as None
        if annunciator_optim_configs['schedule_method'] is not None:
            if annunciator_optim_configs['schedule_method'] == "loss":
                self.scheduler_A = ReduceOnPlateauScheduler(optimizer=self.optim_A,
                                                            **annunciator_optim_configs["scheduler_configs"])
            elif annunciator_optim_configs['schedule_method'] == "noam":
                self.scheduler_A = NoamScheduler(optimizer=self.optim_A,
                                                 **annunciator_optim_configs['scheduler_configs'])
            elif annunciator_optim_configs["schedule_method"] == "rsqrt":
                self.scheduler_A = RsqrtScheduler(optimizer=self.optim_A,
                                                  **annunciator_optim_configs["scheduler_configs"])
            else:
                WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(
                    annunciator_optim_configs['schedule_method']))
        self.criterion_A = nn.CrossEntropyLoss()
        ############################################################
        self.adversarial = reinforce_configs["adversarial"]  # adversarial or reinforce as learning objects
        self.r_s_weight = reinforce_configs["r_s_weight"]
        self.r_i_weight = reinforce_configs["r_i_weight"]

    def _init_state(self, rephraser=None):
        """
        initialize batched sentences / origin_bleu / index (starting from the first token, no BOS/EOS),
        i.e. the initial state of the environment (allocated on the env's device)
        :return: env state (the src embeddings, index)
        """
        self.index = 1  # step index for perturbation
        self.origin_bleu = []  # saving origin BLEU

        batch = next(self.data_iterator)
        assert len(batch) == 3, "must be provided with line index (check for data_iterator)"

        # training, parallel trg is provided for evaluation (src grouped by similar length)
        _, seqs_x, self.seqs_y = batch
        self.sent_len = [len(x) for x in seqs_x]
        self.survival_signals = np.array([1] * len(seqs_x))  # the survival signals, 1 when true.

        # for reinforce inputs(embedding level).
        padded_src, padded_trg = self.prepare_data(
            seqs_x=seqs_x, seqs_y=self.seqs_y)
        self.x_emb = self.src_emb(padded_src).detach()  # float
        self.y_emb = self.trg_emb(padded_trg).detach()
        self.x_pad_indicator = padded_src.detach().eq(PAD)  # byte indicating PAD tokens
        self.y_pad_indicator = padded_trg.detach().eq(PAD)

        # randomly choose half of the sequences and perturb them with the given agent
        # for self learning (the rephraser can be on the cpu)
        if rephraser is not None:
            # self.x_emb, mask_to_UNK = rephraser.random_seq_perturb(
            #     self.x_emb, self.x_pad_indicator,
            #     half_mode=True, rand_act=False, enable_UNK=False)
            # self.x_emb = self.x_emb.to(self.device)
            # mask_to_UNK = mask_to_UNK.to(self.device)
            # # print("x_emb shape:", self.x_emb.shape, "mask_to_UNK shape:", mask_to_UNK.shape)
            # self.x_emb = self.x_emb*(1.-mask_to_UNK.float().unsqueeze(dim=2)) + \
            #              self.src_emb((UNK * mask_to_UNK).long())
            # self.x_emb = self.x_emb.detach()
            _, self.x_emb, _ = rephraser.random_seq_perturb(
                self.x_emb, self.x_pad_indicator, half_mode=True,
                rand_act=False)
            self.x_emb = self.x_emb.detach()

        # print(self.x_mask.shape, self.x_emb.shape)
        self.origin_result = self.translate()
        # calculate BLEU scores for the top candidate
        for index, sent_t in enumerate(self.seqs_y):
            bleu_t = bleu.sentence_bleu(references=[sent_t],
                                        hypothesis=self.origin_result[index],
                                        emulate_multibleu=True)
            self.origin_bleu.append(bleu_t)

        INFO("initialize env on: %s"%self.x_emb.device)
        return self.x_emb.cpu().numpy()

    def get_src_vocab(self):
        return self.src_vocab

    def reset(self, rephraser=None):
        """
        reset the environment when its steps are exhausted
        :param rephraser: defaults to None, i.e. no self-improving learning
        :return: the reset environment's embeddings
        """
        return self._init_state(rephraser)

    def reset_data_iter(self, data_iter):  # reset data iterator with provided iterator
        self.data_iterator = data_iter
        return

    def reset_annunciator(self):
        # a backup, would be deprecated
        self.annunciator.reset()

    def prepare_A_data(self, agent,
                       seqs_x, seqs_y,
                       batch_first=True,
                       half_mode=True,
                       rand_act=True):
        """
        use the current rephraser (agent) to generate data for Annunciator training;
        perturbation is applied to a random sequence step
        (all former steps are perturbed to form the origin_emb, and one more step is
        perturbed to form the perturbed_emb);
        this process rephrases the entire batch
        :param agent: the AC agent (actor and critic) used to prepare data for scorer training
        :param seqs_x: list of source sequences
        :param seqs_y: list of target sequences
        :param batch_first: whether the first dimension of the sequences is the batch
        :param half_mode: if True, only a random half of the batch keeps its perturbation
        :param rand_act: sample the actions from the rephraser outputs
        :return: origin_x_emb, perturbed_x_emb, y_emb, x_mask, y_mask, flags
        """
        def _np_pad_batch_2D(samples, pad, batch_first=True):
            # pack seqs into tensor with pads
            batch_size = len(samples)
            sizes = [len(s) for s in samples]
            max_size = max(sizes)
            x_np = np.full((batch_size, max_size), fill_value=pad, dtype='int64')
            for ii in range(batch_size):
                x_np[ii, :sizes[ii]] = samples[ii]
            if batch_first is False:
                x_np = np.transpose(x_np, [1, 0])
            x = torch.tensor(x_np).to(self.device)
            return x
        seqs_x = list(map(lambda s: [BOS] + s + [EOS], seqs_x))
        x = _np_pad_batch_2D(samples=seqs_x, pad=PAD,
                             batch_first=batch_first)
        x_emb = self.src_emb(x).detach().to(self.device)
        x_pad_indicator = x.detach().eq(PAD).to(self.device)

        # # mere actor rollout
        # origin_x_emb, perturbed_x_emb, flags = rephraser.random_seq_perturb(
        #     x_emb, x_pad_indicator,
        #     half_mode=True, rand_act=rand_act)

        # actor rollout w/ critic's restriction
        with torch.no_grad():
            agent.actor.eval()
            agent.critic.eval()
            batch_size, max_seq_len = x_pad_indicator.shape
            perturbed_x_emb = x_emb.detach().clone()
            x_mask = 1 - x_pad_indicator.int()
            for t in range(1, max_seq_len-1):
                former_emb = perturbed_x_emb
                input_emb = former_emb[:, t-1:t+2, :]
                if rand_act:
                    actions, _ = agent.actor.sample_normal(
                        x_emb=former_emb, x_pad_indicator=x_pad_indicator,
                        label_emb=input_emb, reparamization=False)
                else:
                    mu, _ = agent.actor.forward(
                        x_emb=former_emb, x_pad_indicator=x_pad_indicator,
                        label_emb=input_emb)
                    actions = mu * agent.actor.action_range
                # actions shape [batch, emb_dim]
                critique = agent.critic(
                    x_emb=former_emb, x_pad_indicator=x_pad_indicator,
                    label_emb=input_emb, action=actions)
                # actions_masks shape [batch, 1]
                actions_mask = critique.gt(0).int() * x_mask[:, t].unsqueeze(dim=1)
                # mask unnecessary actions
                perturbed_x_emb[:, t, :] += actions * actions_mask

            flags = x_emb.new_ones(batch_size)
            if half_mode:
                flags = torch.bernoulli(0.5 * flags).to(x_emb.device)
            perturbed_x_emb = perturbed_x_emb * flags.unsqueeze(dim=1).unsqueeze(dim=2) \
                              + x_emb * (1 - flags).unsqueeze(dim=1).unsqueeze(dim=2)

        origin_x_emb = x_emb
        seqs_y = list(map(lambda s: [BOS] + s + [EOS], seqs_y))
        y = _np_pad_batch_2D(seqs_y, pad=PAD,
                             batch_first=batch_first)
        y_emb = self.trg_emb(y).detach().to(self.device)
        y_pad_indicator = y.detach().eq(PAD).to(self.device)

        perturbed_x_emb = perturbed_x_emb.detach().to(self.device)
        return origin_x_emb, perturbed_x_emb, y_emb, x_pad_indicator, y_pad_indicator, flags.long()

    def prepare_data(self, seqs_x, seqs_y=None, batch_first=True):
        """
        prepare the batched, padded data with BOS and EOS for translation.
        used in initialization.
        Returns: padded data matrices (batch_size, max_seq_len)
        """
        def _np_pad_batch_2D(samples, pad, batch_first=True):
            batch_size = len(samples)
            sizes = [len(s) for s in samples]
            max_size = max(sizes)
            x_np = np.full((batch_size, max_size), fill_value=pad, dtype='int64')
            for ii in range(batch_size):
                x_np[ii, :sizes[ii]] = samples[ii]
            if batch_first is False:
                x_np = np.transpose(x_np, [1, 0])
            x = torch.tensor(x_np).to(self.device)
            return x

        seqs_x = list(map(lambda s: [BOS] + s + [EOS], seqs_x))
        x = _np_pad_batch_2D(samples=seqs_x, pad=PAD,
                             batch_first=batch_first)
        if seqs_y is None:
            return x
        seqs_y = list(map(lambda s: [BOS] + s + [EOS], seqs_y))
        y = _np_pad_batch_2D(seqs_y, pad=PAD,
                             batch_first=batch_first)
        return x, y

    def ratio_validation(self, agent, overall_contrast=True):
        """
        validate the environment's scorer for the given agent;
        used for checkpoints and other checks
        :param agent: generates the data for validation
        :param overall_contrast: if True, return the contrastive density ratio; otherwise the perturbed score
        :return: the mean score of the current scorer in the environment
        """
        # set victim encoder and scorer to evaluation mode
        self.annunciator.eval()
        # for i in range(5):
        try:
            batch = next(self.data_iterator)
        except StopIteration:
            batch = next(self.data_iterator)
        seq_nums, seqs_x, seqs_y = batch

        origin_x_emb, perturbed_x_emb, y_emb, x_mask, y_mask, flags = self.prepare_A_data(
            agent, seqs_x, seqs_y, half_mode=False, rand_act=False)
        origin_density_score = self.annunciator.get_density_score(
            origin_x_emb, x_mask, seqs_y)
        perturbed_density_score = self.annunciator.get_density_score(
            perturbed_x_emb, x_mask, seqs_y)
        density_score = origin_density_score/(origin_density_score+perturbed_density_score)
        if overall_contrast:
            return density_score.mean().item()
        else:
            return perturbed_density_score.mean().item()

    def acc_validation(self, agent):
        """
        validate the accuracy of the environment's discriminator with a given agent;
        used for checkpoints
        :param agent: generates data for validation
        :return: the accuracy of the discriminator
        """
        self.annunciator.eval()
        acc = 0
        sample_count = 0
        for i in range(5):
            try:
                batch = next(self.data_iterator)
            except StopIteration:
                batch = next(self.data_iterator)
            seq_nums, seqs_x, seqs_y = batch
            origin_x_emb, perturbed_x_emb, y_emb, x_pad_indicator, y_pad_indicator, flags = \
                self.prepare_A_data(agent, seqs_x, seqs_y, half_mode=True)
            with torch.no_grad():
                preds = self.annunciator(perturbed_x_emb, x_pad_indicator,
                                         y_emb, y_pad_indicator).argmax(dim=-1)
                acc += torch.eq(preds, flags).sum()
                sample_count += preds.shape[0]
        acc = acc.float() / sample_count
        return acc.item()

    # def compute_P_forward(self,
    #                       origin_x_emb, perturbed_x_emb, x_mask,
    #                       evaluate=False):
    #     """
    #     process the victim encoder embedding and get CE loss
    #     :param origin_x_emb: float tensor, input embeddings of input tokens
    #     :param perturbed_x_emb: float tensor, perturbed inputs embeddings
    #     :param x_mask: byte tensor, mask of the input tokens
    #     :return: loss value
    #     """
    #     if not evaluate:
    #         # set components to training mode(dropout layers)
    #         self.scorer.train()
    #         with torch.enable_grad():
    #             loss = self.scorer(origin_x_emb, perturbed_x_emb, x_mask).mean()
    #         torch.autograd.backward(loss)
    #         return loss.item()
    #     else:
    #         # set components to evaluation mode(dropout layers)
    #         self.scorer.eval()
    #         with torch.enable_grad():
    #             loss = self.scorer(origin_x_emb, perturbed_x_emb, x_mask).mean()
    #     return loss.item()

    def compute_A_forward(self, x_emb, y_emb, x_pad_indicator, y_pad_indicator, gold_flags,
                          evaluate=False):
        """get loss according to criterion
        :param gold_flags=1 if perturbed, otherwise 0
        :param evaluate: False during training mode
        :return loss value
        """
        if not evaluate:
            # set components to training mode(dropout layers)
            self.annunciator.train()
            self.criterion_A.train()
            with torch.enable_grad():
                class_probs = self.annunciator(
                    x_emb, x_pad_indicator,
                    y_emb, y_pad_indicator)
                loss = self.criterion_A(class_probs, gold_flags)
            torch.autograd.backward(loss)
            return loss.item()
        else:
            # set components to evaluation mode(dropout layers)
            self.annunciator.eval()
            self.criterion_A.eval()
            with torch.no_grad():
                class_probs = self.annunciator(
                    x_emb, x_pad_indicator,
                    y_emb, y_pad_indicator)
                loss = self.criterion_A(class_probs, gold_flags)
        return loss.item()
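    # Usage sketch (illustrative, not part of the original class): a single annunciator
    # training step built on compute_A_forward. Gradients are produced inside the call,
    # so the caller only zeroes grads before and steps the optimizer afterwards, e.g.
    #   self.optim_A.zero_grad()
    #   x_emb, p_emb, y_emb, x_pad, y_pad, flags = self.prepare_A_data(
    #       agent, seqs_x, seqs_y, half_mode=True)
    #   loss = self.compute_A_forward(p_emb, y_emb, x_pad, y_pad, flags, evaluate=False)
    #   self.optim_A.step()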

    def update_annunciator(self,
                           agent,
                           base_steps=0,
                           min_update_steps=1,
                           max_update_steps=300,
                           accuracy_bound=0.8,
                           overall_update_weight=0.5,
                           summary_writer=None):
        """
        update discriminator using given rephraser
        :param agent: AC agent to generate training data for discriminator
        :param base_steps: used for saving
        :param min_update_steps: (integer) minimum update steps,
                    also the discriminator evaluate steps
        :param max_update_steps: (integer) maximum update steps
        :param accuracy_bound: (float) update until accuracy reaches the bound
                    (or max_update_steps)
        :param summary_writer: used to log discriminator learning information
        :return: steps and test accuracy as trust region
        """
        INFO("update annunciator")
        self.optim_A.zero_grad()
        agent.to(self.device)
        step = 0
        while True:
            try:
                batch = next(self.data_iterator)
            except StopIteration:
                batch = next(self.data_iterator)
            # update the discriminator
            step += 1
            if self.scheduler_A is not None:
                # override learning rate in self.optim_A
                self.scheduler_A.step(global_step=step)
            _, seqs_x, seqs_y = batch  # returned tensor type of the data
            try:
                x_emb, perturbed_x_emb, y_emb, x_pad_indicator, y_pad_indicator, flags = \
                    self.prepare_A_data(agent, seqs_x, seqs_y, half_mode=False, rand_act=True)
                loss = self.annunciator(x_emb, perturbed_x_emb, x_pad_indicator, seqs_y, overall_update_weight)
                # for name, p in self.annunciator.named_parameters():
                #     if "weight" in name:
                #         loss += torch.norm(p, 2)  # with l2-norm against overfitting
                torch.autograd.backward(loss)
                self.optim_A.step()
                print("annunciator loss:", loss)
            except RuntimeError as e:
                if "out of memory" in str(e):
                    print("WARNING: out of memory, skipping batch")
                    self.optim_A.zero_grad()
                else:
                    raise e

            # valid for accuracy / check for break (if any)
            if step % min_update_steps == 0:
                perturbed_density = self.ratio_validation(agent, overall_contrast=False)
                overall_density = self.ratio_validation(agent)
                if summary_writer is not None:
                    summary_writer.add_scalar("a_contrast_ratio", scalar_value=overall_density, global_step=base_steps+step)
                    summary_writer.add_scalar("a_ratio_src", scalar_value=perturbed_density, global_step=base_steps+step)
                print("overall density: %2f" % overall_density)
                if accuracy_bound and overall_density > accuracy_bound:
                    INFO("annunciator reached training bound, updated")
                    return base_steps+step, overall_density

            if step > max_update_steps:
                overall_density = self.ratio_validation(agent)
                perturbed_density = self.ratio_validation(agent, overall_contrast=False)
                print("overall density: %2f" % overall_density)
                INFO("Reach maximum annunciator update. Finished.")
                return base_steps+step, overall_density   # stop updates
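    # Usage sketch (illustrative, not from the original loop): update_annunciator is meant
    # to be called from the outer attack loop, e.g.
    #   base_steps, trust_ratio = env.update_annunciator(agent, base_steps=base_steps,
    #                                                    accuracy_bound=0.8,
    #                                                    summary_writer=writer)
    # where `env`, `agent` and `writer` are assumed to already exist; the returned ratio
    # can serve as a trust region for the following rephraser updates.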


    def translate(self, x_emb=None, x_mask=None):
        """
        translate given embeddings
        :param x_emb: if None, translate the embeddings stored in the environment
        :param x_mask: input mask paired with the embeddings
        :return: list of translation results
        """
        if x_emb is None:  # original translation with original embedding
            x_emb = self.x_emb
            x_mask = self.x_pad_indicator

        with torch.no_grad():
            perturbed_results = beam_search(
                self.translate_model,
                beam_size=5, max_steps=150,
                src_embs=x_emb, src_mask=x_mask,
                alpha=-1.0)
        perturbed_results = perturbed_results.cpu().numpy().tolist()
        # only use the top result from the result
        result = []
        for sent in perturbed_results:
            sent = [wid for wid in sent[0] if wid != PAD]
            result.append(sent)
        return result

    def get_state(self):
        """
        retrieve the current states for learning
        :return: the states of the environment
        """
        states = self.x_emb  # current sen embeddings, [batch_size, len, emb_dim]
        masks = 1. - self.x_pad_indicator.float()  # indicates valid tokens [batch, max_len]
        rephrase_positions = torch.tensor(np.array([self.index] * masks.shape[0])).unsqueeze(dim=-1).long()  # current state positions [batch, 1]
        survival_signals = torch.tensor(self.survival_signals).unsqueeze(dim=-1).float()  # [batch_size, 1]
        return states, masks, rephrase_positions, survival_signals

    def step(self, action):
        """
        step update for the environment: finally updates self.index
        this is defined as the inference of the environment
        states are returned as np.array
        :param action: action tensor of shape [batch, dim], applied at the
            current index for the step update
        :return: updated states / rewards / terminal signals from the environment
                 reward (list of float), terminal_signal (list of boolean)
        """
        with torch.no_grad():
            self.annunciator.eval()
            batch_size, _ = action.shape
            batched_rewards = [0.] * batch_size
            if self.device != "cpu" and not action.is_cuda:
                WARN("mismatching action for gpu_env, move actions to %s"%self.device)
                action = action.to(self.device)

            # extract the step mask for actions and rewards
            inputs_mask = 1. - self.x_pad_indicator.float()
            inputs_mask = inputs_mask[:, self.index]  # slice at current step(index), mask of [batch]
            inputs_mask *= inputs_mask.new_tensor(self.survival_signals)  # mask those terminated

            # update current src embedding with action
            origin_emb = self.x_emb.clone().detach()
            # update embedding; cancel modification on PAD
            self.x_emb[:, self.index, :] += (action * inputs_mask.unsqueeze(dim=1))  # actions on PAD is masked

            # update survival_signals, which later determines whether rewards are valid for return
            # 1. mask survival by step and sent-len
            step_reward_mask = [int(self.index <= i) for i in self.sent_len]
            # 2. get batched sentence matching for survival signals on the current src state
            # as the reward process (probs on ``survival'' as rewards)
            d_probs = self.annunciator.get_density_score(self.x_emb, self.x_pad_indicator, self.seqs_y)
            # print("dprobs:",d_probs)
            signals = d_probs.detach().lt(0.5).long().cpu().numpy().tolist()    # 1 as terminate
            # print("signals:", signals)

            if 1 in step_reward_mask:  # at least one sentence is still within its length
                # 0 as survive, 1 as terminate
                probs = d_probs.detach().cpu().numpy()
                discriminate_index = d_probs.detach().lt(0.5).long()
                survival_mask = (1 - discriminate_index).cpu().numpy()
                survival_value = probs * survival_mask
                terminate_punishment = probs * discriminate_index.cpu().numpy()

                # looping for survival signals and step rewards
                origin_survival_signals = self.survival_signals.copy()
                for i in range(batch_size):
                    # update survivals signals
                    self.survival_signals[i] = self.survival_signals[i] * (1-signals[i]) * step_reward_mask[i]
                    if self.survival_signals[i]:
                        batched_rewards[i] += survival_value[i] * self.r_s_weight
                    elif origin_survival_signals[i]:
                        # punish once the survival signal flips
                        batched_rewards[i] -= terminate_punishment[i] * self.r_i_weight
            else:  # all dead, no need to calculate other rewards, it's ok to waste some samples
                return self.x_emb.cpu().numpy(), np.array(batched_rewards), self.survival_signals

            # additional episodic reward for surviving sequences (w/ finished sentence at current step)
            bleu_mask = [int(self.index == i) for i in self.sent_len]
            bleu_mask = [bleu_mask[i]*self.survival_signals[i] for i in range(batch_size)]
            if 1 in bleu_mask:
                # check for the finished line and mask out the others
                perturbed_results = self.translate(self.x_emb, self.x_pad_indicator)
                episodic_rewards = []
                for i, sent in enumerate(self.seqs_y):
                    if bleu_mask[i] == 1:
                        degraded_value = (self.origin_bleu[i]-bleu.sentence_bleu(
                            references=[sent],
                            hypothesis=perturbed_results[i],
                            emulate_multibleu=True
                        ))
                        if self.adversarial:  # relative degradation
                            if self.origin_bleu[i] == 0:
                                relative_degraded_bleu = 0
                            else:
                                relative_degraded_bleu = degraded_value/self.origin_bleu[i]
                            episodic_rewards.append(relative_degraded_bleu)
                        else:  # absolute improvement
                            print("bleu variation:", self.origin_bleu[i],-degraded_value)
                            episodic_rewards.append(-degraded_value)
                    else:
                        episodic_rewards.append(0.0)
                # append additional episodic rewards
                batched_rewards = [batched_rewards[i]+episodic_rewards[i]*self.r_i_weight
                                   for i in range(batch_size)]

        # update sequences' pointer for rephrasing
        self.index += 1

        return self.x_emb.cpu().numpy(), np.array(batched_rewards), self.survival_signals
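
# A minimal rollout sketch (not from the original repo) showing how get_state()/step()
# above are intended to interact with an agent. `env` is assumed to be a constructed
# instance of the environment class above, and `policy_fn` is a hypothetical callable
# mapping (states, masks, positions, survival) to an action tensor of shape [batch, dim].
import torch


def rollout_embedding_episode(env, policy_fn):
    returns = None
    while env.index <= max(env.sent_len):
        states, masks, positions, survival = env.get_state()
        with torch.no_grad():
            actions = policy_fn(states, masks, positions, survival)
        _, rewards, survival_signals = env.step(actions)
        returns = rewards if returns is None else returns + rewards
        if sum(survival_signals) == 0:  # every sequence has been flagged as terminated
            break
    return returns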
Exemple #8
0
def train(flags):
    """
    flags:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
    """

    # ================================================================================== #
    # Initialization for training on different devices
    # - CPU/GPU
    # - Single/Distributed
    Constants.USE_GPU = flags.use_gpu

    if flags.multi_gpu:
        dist.distributed_init(flags.shared_dir)
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        local_rank = dist.get_local_rank()
    else:
        world_size = 1
        rank = 0
        local_rank = 0

    if Constants.USE_GPU:
        torch.cuda.set_device(local_rank)
        Constants.CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        Constants.CURRENT_DEVICE = "cpu"

    # If not root_rank, close logging
    # else write log of training to file.
    if rank == 0:
        write_log_to_file(
            os.path.join(flags.log_path,
                         "%s.log" % time.strftime("%Y%m%d-%H%M%S")))
    else:
        close_logging()

    # ================================================================================== #
    # Parsing configuration files
    # - Load default settings
    # - Load pre-defined settings
    # - Load user-defined settings

    configs = prepare_configs(flags.config_path, flags.predefined_config)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    INFO(pretty_configs(configs))

    # use odc
    if training_configs['use_odc'] is True:
        ave_best_k = check_odc_config(training_configs)
    else:
        ave_best_k = 0

    Constants.SEED = training_configs['seed']

    set_seed(Constants.SEED)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_src = Vocabulary.build_from_file(**data_configs['vocabularies'][0])
    vocab_tgt = Vocabulary.build_from_file(**data_configs['vocabularies'][1])

    Constants.EOS = vocab_src.eos
    Constants.PAD = vocab_src.pad
    Constants.BOS = vocab_src.bos

    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        is_train_dataset=True),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        is_train_dataset=True))

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(
            data_path=data_configs['valid_data'][0],
            vocabulary=vocab_src,
            is_train_dataset=False,
        ),
        TextLineDataset(data_path=data_configs['valid_data'][1],
                        vocabulary=vocab_tgt,
                        is_train_dataset=False))

    training_iterator = DataIterator(
        dataset=train_bitext_dataset,
        batch_size=training_configs["batch_size"],
        use_bucket=training_configs['use_bucket'],
        buffer_size=training_configs['buffer_size'],
        batching_func=training_configs['batching_key'],
        world_size=world_size,
        rank=rank)

    valid_iterator = DataIterator(
        dataset=valid_bitext_dataset,
        batch_size=training_configs['valid_batch_size'],
        use_bucket=True,
        buffer_size=100000,
        numbering=True,
        world_size=world_size,
        rank=rank)

    bleu_scorer = SacreBLEUScorer(
        reference_path=data_configs["bleu_valid_reference"],
        num_refs=data_configs["num_refs"],
        lang_pair=data_configs["lang_pair"],
        sacrebleu_args=training_configs["bleu_valid_configs"]
        ['sacrebleu_args'],
        postprocess=training_configs["bleu_valid_configs"]['postprocess'])

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We will do the steps below one after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial

    lrate = optimizer_configs['learning_rate']
    model_collections = Collections()

    checkpoint_saver = Saver(
        save_prefix="{0}.ckpt".format(
            os.path.join(flags.saveto, flags.model_name)),
        num_max_keeping=training_configs['num_kept_checkpoints'])

    best_model_prefix = os.path.join(
        flags.saveto, flags.model_name + Constants.MY_BEST_MODEL_SUFFIX)

    best_k_saver = BestKSaver(
        save_prefix="{0}.best_k_ckpt".format(
            os.path.join(flags.saveto, flags.model_name)),
        num_max_keeping=training_configs['num_kept_best_k_checkpoints'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words,
                            padding_idx=vocab_src.pad,
                            vocab_src=vocab_src,
                            **model_configs)
    INFO(nmt_model)

    # build teacher model
    teacher_model, teacher_model_path = get_teacher_model(
        training_configs, model_configs, vocab_src, vocab_tgt, flags)

    # build critic
    critic = CombinationCriterion(model_configs['loss_configs'],
                                  padding_idx=vocab_tgt.pad,
                                  teacher=teacher_model)
    # INFO(critic)
    critic.INFO()

    # 2. Move to GPU
    if Constants.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model,
                          flags.pretrain_path,
                          exclude_prefix=None,
                          device=Constants.CURRENT_DEVICE)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 4. Build optimizer
    INFO('Building Optimizer...')

    if not flags.multi_gpu:
        optim = Optimizer(name=optimizer_configs['optimizer'],
                          model=nmt_model,
                          lr=lrate,
                          grad_clip=optimizer_configs['grad_clip'],
                          optim_args=optimizer_configs['optimizer_params'],
                          update_cycle=training_configs['update_cycle'])
    else:
        optim = dist.DistributedOptimizer(
            name=optimizer_configs['optimizer'],
            model=nmt_model,
            lr=lrate,
            grad_clip=optimizer_configs['grad_clip'],
            optim_args=optimizer_configs['optimizer_params'],
            device_id=local_rank)

    # 5. Build scheduler for optimizer if needed
    scheduler = build_scheduler(
        schedule_method=optimizer_configs['schedule_method'],
        optimizer=optim,
        scheduler_configs=optimizer_configs['scheduler_configs'])

    # 6. build moving average
    ma = build_ma(training_configs, nmt_model.named_parameters())

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if flags.reload:
        checkpoint_saver.load_latest(model=nmt_model,
                                     optim=optim,
                                     lr_scheduler=scheduler,
                                     collections=model_collections,
                                     ma=ma,
                                     device=Constants.CURRENT_DEVICE)

    # broadcast parameters and optimizer states
    if world_size > 1:
        INFO("Broadcasting model parameters...")
        dist.broadcast_parameters(params=nmt_model.state_dict())
        INFO("Broadcasting optimizer states...")
        dist.broadcast_optimizer_state(optimizer=optim.optim)
        INFO('Done.')

    # ================================================================================== #
    # Prepare training

    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [1])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]
    is_early_stop = model_collections.get_collection("is_early_stop", [False])[-1]
    teacher_patience = model_collections.get_collection(
        "teacher_patience", [training_configs['teacher_patience']])[-1]

    train_loss_meter = AverageMeter()
    train_loss_dict_meter = AverageMeterDict(critic.get_critic_name())
    sent_per_sec_meter = TimeMeter()
    tok_per_sec_meter = TimeMeter()

    update_cycle = training_configs['update_cycle']
    grad_denom = 0
    train_loss = 0.0
    cum_n_words = 0
    train_loss_dict = dict()
    valid_loss = best_valid_loss = float('inf')

    if rank == 0:
        summary_writer = SummaryWriter(log_dir=flags.log_path)
    else:
        summary_writer = None

    sent_per_sec_meter.start()
    tok_per_sec_meter.start()

    INFO('Begin training...')

    while True:

        if summary_writer is not None:
            summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()

        if rank == 0:
            training_progress_bar = tqdm(desc=' - (Epc {}, Upd {}) '.format(
                eidx, uidx),
                                         total=len(training_iterator),
                                         unit="sents")
        else:
            training_progress_bar = None

        for batch in training_iter:

            seqs_x, seqs_y = batch

            batch_size = len(seqs_x)
            cum_n_words += sum(len(s) for s in seqs_y)

            try:
                # Prepare data
                x, y = prepare_data(seqs_x, seqs_y, cuda=Constants.USE_GPU)

                loss, loss_dict = compute_forward(
                    model=nmt_model,
                    critic=critic,
                    seqs_x=x,
                    seqs_y=y,
                    eval=False,
                    normalization=1.0,
                    norm_by_words=training_configs["norm_by_words"])

                update_cycle -= 1
                grad_denom += batch_size
                train_loss += loss
                train_loss_dict = add_dict_value(train_loss_dict, loss_dict)

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                else:
                    raise e

            # When update_cycle becomes 0, one accumulated update has finished. Several things will be done:
            # - update parameters
            # - reset update_cycle and grad_denom, update uidx
            # - learning rate scheduling
            # - update moving average

            if update_cycle == 0:

                # 0. reduce variables
                if world_size > 1:
                    grad_denom = dist.all_reduce_py(grad_denom)
                    train_loss = dist.all_reduce_py(train_loss)
                    train_loss_dict = dist.all_reduce_py(train_loss_dict)
                    cum_n_words = dist.all_reduce_py(cum_n_words)

                # 1. update parameters
                optim.step(denom=grad_denom)
                optim.zero_grad()

                if training_progress_bar is not None:
                    training_progress_bar.update(grad_denom)
                    training_progress_bar.set_description(
                        ' - (Epc {}, Upd {}) '.format(eidx, uidx))

                    postfix_str = 'TrainLoss: {:.2f}, ValidLoss(best): {:.2f} ({:.2f}), '.format(
                        train_loss, valid_loss, best_valid_loss)
                    for critic_name, loss_value in train_loss_dict.items():
                        postfix_str += (critic_name +
                                        ': {:.2f}, ').format(loss_value)
                    training_progress_bar.set_postfix_str(postfix_str)

                # 2. learning rate scheduling
                if scheduler is not None and optimizer_configs[
                        "schedule_method"] != "loss":
                    scheduler.step(global_step=uidx)

                # 3. update moving average
                if ma is not None and eidx >= training_configs[
                        'moving_average_start_epoch']:
                    ma.step()

                # 4. update meters
                train_loss_meter.update(train_loss, grad_denom)
                train_loss_dict_meter.update(train_loss_dict, grad_denom)
                sent_per_sec_meter.update(grad_denom)
                tok_per_sec_meter.update(cum_n_words)

                # 5. reset accumulated variables, update uidx
                update_cycle = training_configs['update_cycle']
                grad_denom = 0
                uidx += 1
                cum_n_words = 0.0
                train_loss = 0.0
                train_loss_dict = dict()

            else:
                continue

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(
                    uidx, eidx, every_n_step=training_configs['disp_freq']):

                lrate = list(optim.get_lrate())[0]

                if summary_writer is not None:
                    summary_writer.add_scalar(
                        "Speed(sents/sec)",
                        scalar_value=sent_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "Speed(words/sec)",
                        scalar_value=tok_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "train_loss",
                        scalar_value=train_loss_meter.ave,
                        global_step=uidx)
                    # add loss for every critic
                    if flags.display_loss_detail:
                        combination_loss = train_loss_dict_meter.value
                        for key, value in combination_loss.items():
                            summary_writer.add_scalar(key,
                                                      scalar_value=value,
                                                      global_step=uidx)
                    summary_writer.add_scalar("lrate",
                                              scalar_value=lrate,
                                              global_step=uidx)
                    summary_writer.add_scalar("oom_count",
                                              scalar_value=oom_count,
                                              global_step=uidx)

                # Reset Meters
                sent_per_sec_meter.reset()
                tok_per_sec_meter.reset()
                train_loss_meter.reset()
                train_loss_dict_meter.reset()

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['loss_valid_freq'],
                    debug=flags.debug):
                with cache_parameters(nmt_model):

                    valid_loss, valid_loss_dict = loss_evaluation(
                        model=nmt_model,
                        critic=critic,
                        valid_iterator=valid_iterator,
                        rank=rank,
                        world_size=world_size)

                if scheduler is not None and optimizer_configs[
                        "schedule_method"] == "loss":
                    scheduler.step(metric=valid_loss)

                model_collections.add_to_collection("history_losses",
                                                    valid_loss)

                min_history_loss = np.array(
                    model_collections.get_collection("history_losses")).min()
                best_valid_loss = min_history_loss

                if summary_writer is not None:
                    summary_writer.add_scalar("loss",
                                              valid_loss,
                                              global_step=uidx)
                    summary_writer.add_scalar("best_loss",
                                              min_history_loss,
                                              global_step=uidx)

            # ================================================================================== #
            # BLEU Validation & Early Stop
            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['bleu_valid_freq'],
                    min_step=training_configs['bleu_valid_warmup'],
                    debug=flags.debug):

                with cache_parameters(nmt_model):

                    valid_bleu = bleu_evaluation(
                        uidx=uidx,
                        valid_iterator=valid_iterator,
                        batch_size=training_configs["bleu_valid_batch_size"],
                        model=nmt_model,
                        bleu_scorer=bleu_scorer,
                        vocab_src=vocab_src,
                        vocab_tgt=vocab_tgt,
                        valid_dir=flags.valid_path,
                        max_steps=training_configs["bleu_valid_configs"]
                        ["max_steps"],
                        beam_size=training_configs["bleu_valid_configs"]
                        ["beam_size"],
                        alpha=training_configs["bleu_valid_configs"]["alpha"],
                        world_size=world_size,
                        rank=rank,
                    )

                model_collections.add_to_collection(key="history_bleus",
                                                    value=valid_bleu)

                best_valid_bleu = float(
                    np.array(model_collections.get_collection(
                        "history_bleus")).max())

                if summary_writer is not None:
                    summary_writer.add_scalar("bleu", valid_bleu, uidx)
                    summary_writer.add_scalar("best_bleu", best_valid_bleu,
                                              uidx)

                # If model get new best valid bleu score
                if valid_bleu >= best_valid_bleu:
                    bad_count = 0

                    if is_early_stop is False:
                        if rank == 0:
                            # 1. save the best model
                            torch.save(nmt_model.state_dict(),
                                       best_model_prefix + ".final")

                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs[
                            'early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")
                        exit(0)

                if rank == 0:
                    best_k_saver.save(global_step=uidx,
                                      metric=valid_bleu,
                                      model=nmt_model,
                                      optim=optim,
                                      lr_scheduler=scheduler,
                                      collections=model_collections,
                                      ma=ma)

                # ODC
                if training_configs['use_odc'] is True:
                    if valid_bleu >= best_valid_bleu:

                        # choose method to generate teachers from checkpoints
                        # - best
                        # - ave_k_best
                        # - ma

                        if training_configs['teacher_choice'] == 'ma':
                            teacher_params = ma.export_ma_params()
                        elif training_configs['teacher_choice'] == 'best':
                            teacher_params = nmt_model.state_dict()
                        elif "ave_best" in training_configs['teacher_choice']:
                            if best_k_saver.num_saved >= ave_best_k:
                                teacher_params = average_checkpoints(
                                    best_k_saver.get_all_ckpt_path()
                                    [-ave_best_k:])
                            else:
                                teacher_params = nmt_model.state_dict()
                        else:
                            raise ValueError(
                                "can not support teacher choice %s" %
                                training_configs['teacher_choice'])
                        torch.save(teacher_params, teacher_model_path)
                        del teacher_params
                        teacher_patience = 0
                        critic.set_use_KD(False)
                    else:
                        teacher_patience += 1
                        if teacher_patience >= training_configs[
                                'teacher_refresh_warmup']:
                            teacher_params = torch.load(
                                teacher_model_path,
                                map_location=Constants.CURRENT_DEVICE)
                            teacher_model.load_state_dict(teacher_params,
                                                          strict=False)
                            del teacher_params
                            critic.reset_teacher(teacher_model)
                            critic.set_use_KD(True)

                if summary_writer is not None:
                    summary_writer.add_scalar("bad_count", bad_count, uidx)

                info_str = "{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4} ".format(
                    uidx, valid_loss, valid_bleu, lrate, bad_count)
                for key, value in valid_loss_dict.items():
                    info_str += (key + ': {0:.2f} '.format(value))
                INFO(info_str)

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(
                    uidx,
                    eidx,
                    every_n_step=training_configs['save_freq'],
                    debug=flags.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)
                model_collections.add_to_collection("teacher_patience",
                                                    teacher_patience)
                if not is_early_stop:
                    if rank == 0:
                        checkpoint_saver.save(global_step=uidx,
                                              model=nmt_model,
                                              optim=optim,
                                              lr_scheduler=scheduler,
                                              collections=model_collections,
                                              ma=ma)

        if training_progress_bar is not None:
            training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
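
# A hedged invocation sketch (not part of the original script): train() expects a flags
# object exposing the attributes read above; argparse.Namespace is one simple way to
# build it for a quick run. Every value below is a placeholder.
if __name__ == "__main__":
    from argparse import Namespace

    flags = Namespace(
        config_path="configs/transformer_base.yaml",  # placeholder path
        predefined_config="",
        saveto="./save", model_name="transformer",
        log_path="./log", valid_path="./valid",
        pretrain_path="", reload=False,
        use_gpu=False, multi_gpu=False, shared_dir=None,
        debug=False, display_loss_detail=False)
    train(flags)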
Exemple #9
0
class Translate_Env(object):
    """
    wraps the translate environment for multiple agents
    the env needs parallel data to evaluate BLEU degradation
    the state of the env is defined as the batched src labels and the current target index
    the environment yields rewards based on the discriminator and finally on sentence-level BLEU
    :return: translates multiple sentences and returns the change in BLEU
    """
    def __init__(
        self,
        attack_configs,
        discriminator_configs,
        src_vocab,
        trg_vocab,
        data_iterator,
        save_to,
        device="cpu",
    ):
        """
        initiate translation environments, needs a discriminator and translator
        :param attack_configs: attack configuration dictionary
        :param discriminator_configs: discriminator model and optimizer configurations
        :param save_to: directory for saving discriminator models and the near-vocab files
        :param data_iterator: provides data for initiating the environment
        :param device: (string) device to allocate variables ("cpu", "cuda:*"),
        default is cpu
        """
        self.data_iterator = data_iterator
        discriminator_model_configs = discriminator_configs[
            "discriminator_model_configs"]
        discriminator_optim_configs = discriminator_configs[
            "discriminator_optimizer_configs"]
        self.victim_config_path = attack_configs["victim_configs"]
        self.victim_model_path = attack_configs["victim_model"]
        # determine devices
        self.device = device
        with open(self.victim_config_path.strip()) as v_f:
            print("open victim configs...%s" % self.victim_config_path)
            victim_configs = yaml.load(v_f, Loader=yaml.FullLoader)

        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        self.translate_model = build_translate_model(victim_configs,
                                                     self.victim_model_path,
                                                     vocab_src=self.src_vocab,
                                                     vocab_trg=self.trg_vocab,
                                                     device=self.device)
        self.translate_model.eval()
        self.w2p, self.w2vocab = load_or_extract_near_vocab(
            config_path=self.victim_config_path,
            model_path=self.victim_model_path,
            init_perturb_rate=attack_configs["init_perturb_rate"],
            save_to=os.path.join(save_to, "near_vocab"),
            save_to_full=os.path.join(save_to, "full_near_vocab"),
            top_reserve=12,
            emit_as_id=True)
        #########################################################
        # to update discriminator
        # discriminator_data_configs = attack_configs["discriminator_data_configs"]
        self.discriminator = TransDiscriminator(
            n_src_words=self.src_vocab.max_n_words,
            n_trg_words=self.trg_vocab.max_n_words,
            **discriminator_model_configs)
        self.discriminator.to(self.device)

        load_embedding(self.discriminator,
                       model_path=self.victim_model_path,
                       device=self.device)

        self.optim_D = Optimizer(
            name=discriminator_optim_configs["optimizer"],
            model=self.discriminator,
            lr=discriminator_optim_configs["learning_rate"],
            grad_clip=discriminator_optim_configs["grad_clip"],
            optim_args=discriminator_optim_configs["optimizer_params"])
        self.criterion_D = nn.CrossEntropyLoss()  # used in discriminator updates
        self.scheduler_D = None  # default as None
        if discriminator_optim_configs['schedule_method'] is not None:
            if discriminator_optim_configs['schedule_method'] == "loss":
                self.scheduler_D = ReduceOnPlateauScheduler(
                    optimizer=self.optim_D,
                    **discriminator_optim_configs["scheduler_configs"])
            elif discriminator_optim_configs['schedule_method'] == "noam":
                self.scheduler_D = NoamScheduler(
                    optimizer=self.optim_D,
                    **discriminator_optim_configs['scheduler_configs'])
            elif discriminator_optim_configs["schedule_method"] == "rsqrt":
                self.scheduler_D = RsqrtScheduler(
                    optimizer=self.optim_D,
                    **discriminator_optim_configs["scheduler_configs"])
            else:
                WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".
                     format(discriminator_optim_configs['schedule_method']))
        ############################################################
        self._init_state()
        self.adversarial = attack_configs[
            "adversarial"]  # adversarial sample or reinforced samples
        self.r_s_weight = attack_configs["r_s_weight"]
        self.r_d_weight = attack_configs["r_d_weight"]

    def _init_state(self):
        """
        initiate batched sentences / origin_bleu / index (start from first label, no BOS/EOS)
        the initial state of the environment
        :return: env states (the src, index)
        """
        self.index = 1
        self.origin_bleu = []
        batch = next(self.data_iterator)
        assert len(
            batch
        ) == 3, "must be provided with line index (check for data_iterator)"
        # training, parallel trg is provided
        _, seqs_x, self.seqs_y = batch
        self.sent_len = [len(x) for x in seqs_x]  # for terminal signals
        self.terminal_signal = [0] * len(seqs_x)  # for terminal signals

        self.padded_src, self.padded_trg = self.prepare_data(
            seqs_x=seqs_x, seqs_y=self.seqs_y)
        self.origin_result = self.translate()
        # calculate BLEU scores for the top candidate
        for index, sent_t in enumerate(self.seqs_y):
            bleu_t = bleu.sentence_bleu(references=[sent_t],
                                        hypothesis=self.origin_result[index],
                                        emulate_multibleu=True)
            self.origin_bleu.append(bleu_t)
        return self.padded_src.cpu().numpy()

    def get_src_vocab(self):
        return self.src_vocab

    def reset(self):
        return self._init_state()

    def reset_data_iter(
            self, data_iter):  # reset data iterator with provided iterator
        self.data_iterator = data_iter
        return

    def reset_discriminator(self):
        self.discriminator.reset()
        load_embedding(self.discriminator,
                       model_path=self.victim_model_path,
                       device=self.device)

    def prepare_D_data(self, attacker, seqs_x, seqs_y, batch_first=True):
        """
        use the global attacker to generate training data for the discriminator
        :param attacker: the attacker used to perturb the sources
        :param seqs_x: list of sources
        :param seqs_y: corresponding targets
        :param batch_first: whether the first dimension of the seqs is the batch
        :return: perturbed seqs_x, padded seqs_y, flags
        """
        def _np_pad_batch_2D(samples, pad, batch_first=True):
            # pack seqs into tensor with pads
            batch_size = len(samples)
            sizes = [len(s) for s in samples]
            max_size = max(sizes)
            x_np = np.full((batch_size, max_size),
                           fill_value=pad,
                           dtype='int64')
            for ii in range(batch_size):
                x_np[ii, :sizes[ii]] = samples[ii]
            if batch_first is False:
                x_np = np.transpose(x_np, [1, 0])
            x = torch.tensor(x_np).to(self.device)
            return x

        seqs_x = list(map(lambda s: [BOS] + s + [EOS], seqs_x))

        x = _np_pad_batch_2D(samples=seqs_x, pad=PAD, batch_first=batch_first)
        # training mode attack: randomly choose half of the seqs to attack
        attacker.eval()
        x, flags = attacker.seq_attack(x, self.w2vocab, training_mode=True)

        seqs_y = list(map(lambda s: [BOS] + s + [EOS], seqs_y))

        y = _np_pad_batch_2D(seqs_y, pad=PAD, batch_first=batch_first)
        flags = flags.to(self.device)  # .to() is not in-place for tensors

        # # print trace
        # flag_list = flags.cpu().numpy().tolist()
        # x_list = x.cpu().numpy().tolist()
        # for i in range(len(flag_list)):
        #     if flag_list[i]==1:
        #         print(self.src_vocab.ids2sent(seqs_x[i]))
        #         print(self.src_vocab.ids2sent(x_list[i]))
        #         print(self.trg_vocab.ids2sent(seqs_y[i]))
        return x, y, flags

    def prepare_data(self, seqs_x, seqs_y=None, batch_first=True):
        """
        Args:
            seqs_x: list of source id sequences
            seqs_y: optional list of target id sequences
            batch_first ('bool'): whether the batch is the first dimension
        Returns: padded data matrices
        """
        def _np_pad_batch_2D(samples, pad, batch_first=True):
            batch_size = len(samples)
            sizes = [len(s) for s in samples]
            max_size = max(sizes)
            x_np = np.full((batch_size, max_size),
                           fill_value=pad,
                           dtype='int64')
            for ii in range(batch_size):
                x_np[ii, :sizes[ii]] = samples[ii]
            if batch_first is False:
                x_np = np.transpose(x_np, [1, 0])
            x = torch.tensor(x_np).to(self.device)
            return x

        seqs_x = list(map(lambda s: [BOS] + s + [EOS], seqs_x))
        x = _np_pad_batch_2D(samples=seqs_x, pad=PAD, batch_first=batch_first)
        if seqs_y is None:
            return x
        seqs_y = list(map(lambda s: [BOS] + s + [EOS], seqs_y))
        y = _np_pad_batch_2D(seqs_y, pad=PAD, batch_first=batch_first)

        return x, y
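    # Illustrative example (token ids are hypothetical): with BOS=2, EOS=3, PAD=0,
    # seqs_x = [[5, 6], [7]] becomes the int64 tensor
    #   [[2, 5, 6, 3],
    #    [2, 7, 3, 0]]
    # i.e. each sequence is wrapped with BOS/EOS and right-padded to the batch maximum.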

    def acc_validation(self, attacker):
        self.discriminator.eval()
        acc = 0
        sample_count = 0
        for i in range(5):
            try:
                batch = next(self.data_iterator)
            except StopIteration:
                batch = next(self.data_iterator)
            seq_nums, seqs_x, seqs_y = batch
            x, y, flags = self.prepare_D_data(attacker, seqs_x, seqs_y)
            # set components to evaluation mode
            self.discriminator.eval()
            with torch.no_grad():
                preds = self.discriminator(x, y).argmax(dim=-1)
                acc += torch.eq(preds, flags).sum()
                sample_count += preds.size(0)
        acc = acc.float() / sample_count
        return acc.item()

    def compute_D_forward(self, seqs_x, seqs_y, gold_flags, evaluate=False):
        """
        get loss according to criterion
        :param gold_flags: 1 if perturbed, otherwise 0
        :return: loss value
        """
        if not evaluate:
            # set components to training mode(dropout layers)
            self.discriminator.train()
            self.criterion_D.train()
            with torch.enable_grad():
                class_probs = self.discriminator(seqs_x, seqs_y)
                loss = self.criterion_D(class_probs, gold_flags)
            torch.autograd.backward(loss)
            return loss.item()
        else:
            # set components to evaluation mode(dropout layers)
            self.discriminator.eval()
            self.criterion_D.eval()
            with torch.no_grad():
                class_probs = self.discriminator(seqs_x, seqs_y)
                loss = self.criterion_D(class_probs, gold_flags)
        return loss.item()

    def update_discriminator(self,
                             attacker_model,
                             base_steps=0,
                             min_update_steps=20,
                             max_update_steps=300,
                             accuracy_bound=0.8,
                             summary_writer=None):
        """
        update discriminator
        :param attacker_model: attacker to generate training data for discriminator
        :param base_steps: used for saving
        :param min_update_steps: (integer) minimum update steps,
                    also the discriminator evaluate steps
        :param max_update_steps: (integer) maximum update steps
        :param accuracy_bound: (float) update until accuracy reaches the bound
                    (or max_update_steps)
        :param summary_writer: used to log discriminator learning information
        :return: steps and test accuracy as trust region
        """
        INFO("update discriminator")
        self.optim_D.zero_grad()
        attacker_model = attacker_model.to(self.device)
        step = 0
        while True:
            try:
                batch = next(self.data_iterator)
            except StopIteration:
                batch = next(self.data_iterator)
            # update the discriminator
            step += 1
            if self.scheduler_D is not None:
                # override learning rate in self.optim_D
                self.scheduler_D.step(global_step=step)
            _, seqs_x, seqs_y = batch  # returned tensor type of the data
            try:
                x, y, flags = self.prepare_D_data(attacker_model, seqs_x,
                                                  seqs_y)
                loss = self.compute_D_forward(seqs_x=x,
                                              seqs_y=y,
                                              gold_flags=flags)
                self.optim_D.step()
                print("discriminator loss:", loss)
            except RuntimeError as e:
                if "out of memory" in str(e):
                    print("WARNING: out of memory, skipping batch")
                    self.optim_D.zero_grad()
                else:
                    raise e

            # valid for accuracy / check for break (if any)
            if step % min_update_steps == 0:
                acc = self.acc_validation(attacker_model)
                print("discriminator acc: %2f" % acc)
                summary_writer.add_scalar("discriminator",
                                          scalar_value=acc,
                                          global_step=base_steps + step)
                if accuracy_bound and acc > accuracy_bound:
                    INFO("discriminator reached training acc bound, updated.")
                    return base_steps + step, acc

            if step > max_update_steps:
                acc = self.acc_validation(attacker_model)
                print("discriminator acc: %2f" % acc)
                INFO("Reach maximum discriminator update. Finished.")
                return base_steps + step, acc  # stop updates

    def translate(self, inputs=None):
        """
        translate the self.perturbed_src
        :param inputs: if None, translate perturbed sequences stored in the environments
        :return: list of translation results
        """
        if inputs is None:
            inputs = self.padded_src
        with torch.no_grad():
            print(inputs.device)
            perturbed_results = beam_search(
                self.translate_model,
                beam_size=5,
                max_steps=150,
                src_seqs=inputs,
                alpha=-1.0,
            )
        perturbed_results = perturbed_results.cpu().numpy().tolist()
        # only use the top result from the result
        result = []
        for sent in perturbed_results:
            sent = [wid for wid in sent[0] if wid != PAD]
            result.append(sent)

        return result

    def step(self, actions):
        """
        step update for the environment: finally updates self.index
        this is defined as the inference of the environment
        :param actions: whether to perturb (action distribution vector
                    of shape [batch, 1]) on the current index
                 *  result of torch.argmax(actor_output_distribution, dim=-1)
                    test: actions = actor_output_distribution.argmax(dim=-1)
                    or train: actions = actor_output_distribution.multinomial(num_samples=1)
                    can be on cpu or cuda.
        :return: updated states / rewards / terminal signal from the environment
                 reward (float), terminal_signal (boolean)
        """
        with torch.no_grad():
            terminal = False  # default is not terminated
            batch_size = actions.shape[0]
            reward = 0
            inputs = self.padded_src[:, self.index]
            inputs_mask = ~inputs.eq(PAD)
            target_of_step = []
            # modification on sequences (state)
            for batch_index in range(batch_size):
                word_id = inputs[batch_index]
                target_word_id = self.w2vocab[word_id.item()][np.random.choice(
                    len(self.w2vocab[word_id.item()]), 1)[0]]
                target_of_step += [target_word_id]
            if self.device != "cpu" and not actions.is_cuda:
                actions = actions.to(self.device)
                actions *= inputs_mask  # actions on PAD positions are neglected
            # override the state src with random choice from candidates
            self.padded_src[:, self.index] *= (1 - actions)
            adjustification_ = torch.tensor(target_of_step)
            adjustification_ = adjustification_.to(self.device)
            self.padded_src[:, self.index] += adjustification_ * actions

            # update sequences' pointer
            self.index += 1
            """ run discriminator check for terminal signals, update local terminal list
            True: all sentences in the batch is defined as false by self.discriminator
            False: otherwise
            """
            # get discriminator distribution on the current src state
            discriminate_out = self.discriminator(self.padded_src,
                                                  self.padded_trg)
            # element-wise OR: once flagged as perturbed, a sample stays terminated
            detected = discriminate_out.detach().argmax(dim=-1).cpu().numpy().tolist()
            self.terminal_signal = [int(t or d) for t, d in zip(self.terminal_signal, detected)]
            signal = (1 - discriminate_out.argmax(dim=-1)).sum().item()
            if signal == 0 or self.index == self.padded_src.shape[1] - 1:
                terminal = True  # no need to further explore or reached EOS for all src
            """ collect rewards on the current state
            """
            # calculate intermediate survival rewards
            if not terminal:
                # survival rewards for survived objects
                distribution, discriminate_index = discriminate_out.max(dim=-1)
                distribution = distribution.detach().cpu().numpy()
                discriminate_index = (1 - discriminate_index).cpu().numpy()
                survival_value = distribution * discriminate_index * (
                    1 - np.array(self.terminal_signal))
                reward += survival_value.sum() * self.r_s_weight
            else:  # only penalty for overall intermediate termination
                reward = -1 * batch_size

            # only check for finished relative BLEU degradation when survival on the last label
            if self.index == self.padded_src.shape[1] - 1:
                # re-tokenize, ignoring the original UNK, for the victim model
                inputs = self.padded_src.cpu().numpy().tolist()
                new_inputs = []
                for indices in inputs:
                    # remove EOS, BOS, PAD
                    new_line = [
                        word_id for word_id in indices
                        if word_id not in [EOS, BOS, PAD]
                    ]
                    new_line = self.src_vocab.ids2sent(new_line)
                    if not hasattr(self.src_vocab.tokenizer, "bpe"):
                        new_line = new_line.strip().split()
                    else:
                        new_token = []
                        for w in new_line.strip().split():
                            if w != self.src_vocab.id2token(UNK):
                                new_token.append(
                                    self.src_vocab.tokenizer.bpe.segment_word(
                                        w))
                            else:
                                new_token.append([w])
                        new_line = sum(new_token, [])
                    new_line = [self.src_vocab.token2id(t) for t in new_line]
                    new_inputs.append(new_line)
                # translate the re-padded perturbed sources
                perturbed_result = self.translate(
                    self.prepare_data(seqs_x=new_inputs, ))
                # calculate final BLEU degradation:
                episodic_rewards = []
                for i, sent in enumerate(self.seqs_y):
                    # sentence is still surviving
                    if self.index >= self.sent_len[
                            i] - 1 and self.terminal_signal[i] == 0:
                        if self.origin_bleu[i] == 0:
                            # here we want to minimize noise from original bad cases
                            relative_degraded_value = 0
                        else:
                            relative_degraded_value = (
                                self.origin_bleu[i] - bleu.sentence_bleu(
                                    references=[sent],
                                    hypothesis=perturbed_result[i],
                                    emulate_multibleu=True))

                            # print(relative_degraded_value, self.origin_bleu[i])
                            relative_degraded_value /= self.origin_bleu[i]
                        if self.adversarial:
                            episodic_rewards.append(relative_degraded_value)
                        else:
                            episodic_rewards.append(-relative_degraded_value)
                    else:
                        episodic_rewards.append(0.0)
                reward += sum(episodic_rewards) * self.r_d_weight

            reward = reward / batch_size

        return self.padded_src.cpu().numpy(), reward, terminal,
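
# A minimal rollout sketch (not from the original repo) for the environment above.
# `env` is assumed to be a constructed Translate_Env and `attacker` a hypothetical
# policy returning a [batch, 2] action distribution over {keep, perturb} per sentence.
import torch


def rollout_token_episode(env, attacker):
    state = env.reset()  # padded source ids as an np.ndarray of shape [batch, len]
    total_reward, terminal = 0.0, False
    while not terminal:
        src = torch.tensor(state).to(env.device)
        with torch.no_grad():
            actions = attacker(src).argmax(dim=-1)  # 1 = perturb the token at env.index
        state, reward, terminal = env.step(actions)
        total_reward += reward
    return total_reward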