def train(FLAGS):
    """
    FLAGS:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
    """
    # Write log of training to file.
    write_log_to_file(
        os.path.join(FLAGS.log_path,
                     "%s.log" % time.strftime("%Y%m%d-%H%M%S")))

    GlobalNames.USE_GPU = FLAGS.use_gpu

    if GlobalNames.USE_GPU:
        CURRENT_DEVICE = "cuda:0"
    else:
        CURRENT_DEVICE = "cpu"

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        configs = yaml.load(f)

    INFO(pretty_configs(configs))

    # Add default configs
    configs = default_configs(configs)
    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    GlobalNames.SEED = training_configs['seed']
    set_seed(GlobalNames.SEED)

    best_model_prefix = os.path.join(
        FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][0])

    train_batch_size = training_configs["batch_size"] * max(
        1, training_configs["update_cycle"])
    train_buffer_size = training_configs["buffer_size"] * max(
        1, training_configs["update_cycle"])

    train_bitext_dataset = ZipDataset(
        TextLineDataset(
            data_path=data_configs['train_data'][0],
            vocabulary=vocab_tgt,
            max_len=data_configs['max_len'][0],
        ),
        shuffle=training_configs['shuffle'])

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(
            data_path=data_configs['valid_data'][0],
            vocabulary=vocab_tgt,
        ))

    training_iterator = DataIterator(
        dataset=train_bitext_dataset,
        batch_size=train_batch_size,
        use_bucket=training_configs['use_bucket'],
        buffer_size=train_buffer_size,
        batching_func=training_configs['batching_key'])

    valid_iterator = DataIterator(
        dataset=valid_bitext_dataset,
        batch_size=training_configs['valid_batch_size'],
        use_bucket=True,
        buffer_size=100000,
        numbering=True)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We do the steps below one after another:
    #     1. build models & criterion
    #     2. move models & criterion to GPU if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    model_collections = Collections()
    checkpoint_saver = Saver(
        save_prefix="{0}.ckpt".format(
            os.path.join(FLAGS.saveto, FLAGS.model_name)),
        num_max_keeping=training_configs['num_kept_checkpoints'])
    best_model_saver = Saver(
        save_prefix=best_model_prefix,
        num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    lm_model = build_model(n_tgt_vocab=vocab_tgt.max_n_words, **model_configs)
    INFO(lm_model)

    params_total = sum([p.numel() for n, p in lm_model.named_parameters()])
    params_wo_embedding = sum([
        p.numel() for n, p in lm_model.named_parameters()
        if n.find('embedding') == -1
    ])
    INFO('Total parameters: {}'.format(params_total))
    INFO('Total parameters (excluding word embeddings): {}'.format(
        params_wo_embedding))

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])
    INFO(critic)
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        lm_model = lm_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    lm_model.init_parameters(FLAGS.pretrain_path, device=CURRENT_DEVICE)

    # 4. Build optimizer
    INFO('Building Optimizer...')
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=lm_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params'])

    # 5. Build scheduler for optimizer if needed
    if optimizer_configs['schedule_method'] is not None:
        if optimizer_configs['schedule_method'] == "loss":
            scheduler = ReduceOnPlateauScheduler(
                optimizer=optim, **optimizer_configs["scheduler_configs"])
        elif optimizer_configs['schedule_method'] == "noam":
            scheduler = NoamScheduler(optimizer=optim,
                                      **optimizer_configs['scheduler_configs'])
        else:
            WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(
                optimizer_configs['schedule_method']))
            scheduler = None
    else:
        scheduler = None

    # 6. Build moving average
    if training_configs['moving_average_method'] is not None:
        ma = MovingAverage(
            moving_average_method=training_configs['moving_average_method'],
            named_params=lm_model.named_parameters(),
            alpha=training_configs['moving_average_alpha'])
    else:
        ma = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if FLAGS.reload:
        checkpoint_saver.load_latest(model=lm_model,
                                     optim=optim,
                                     lr_scheduler=scheduler,
                                     collections=model_collections,
                                     ma=ma)

    # ================================================================================== #
    # Prepare training

    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [0])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]

    summary_writer = SummaryWriter(log_dir=FLAGS.log_path)

    cum_samples = 0
    cum_words = 0
    valid_loss = best_valid_loss = float('inf')  # Max Float
    saving_files = []

    # Timer for computing speed
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    while True:
        summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()
        training_progress_bar = tqdm(
            desc=' - (Epc {}, Upd {}) '.format(eidx, uidx),
            total=len(training_iterator),
            unit="sents")

        for batch in training_iter:
            uidx += 1

            if optimizer_configs["schedule_method"] is not None \
                    and optimizer_configs["schedule_method"] != "loss":
                scheduler.step(global_step=uidx)

            seqs_y = batch
            n_samples_t = len(seqs_y)
            n_words_t = sum(len(s) for s in seqs_y)
            cum_samples += n_samples_t
            cum_words += n_words_t
            train_loss = 0.

            optim.zero_grad()
            try:
                # Prepare data
                for (seqs_y_t, ) in split_shard(
                        seqs_y, split_size=training_configs['update_cycle']):
                    y = prepare_data(seqs_y_t, cuda=GlobalNames.USE_GPU)

                    loss = compute_forward(
                        model=lm_model,
                        critic=critic,
                        # seqs_x=x,
                        seqs_y=y,
                        eval=False,
                        normalization=n_samples_t,
                        norm_by_words=training_configs["norm_by_words"])
                    if training_configs["norm_by_words"]:
                        train_loss += loss
                    else:
                        train_loss += loss / y.size(1)
                optim.step()
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                    optim.zero_grad()
                else:
                    raise e

            if ma is not None and eidx >= training_configs['moving_average_start_epoch']:
                ma.step()

            training_progress_bar.update(n_samples_t)
            training_progress_bar.set_description(
                ' - (Epc {}, Upd {}) '.format(eidx, uidx))
            training_progress_bar.set_postfix_str(
                'TrainLoss: {:.2f}, ValidLoss(best): {:.2f} ({:.2f})'.format(
                    train_loss, valid_loss, best_valid_loss))
            summary_writer.add_scalar("train_loss",
                                      scalar_value=train_loss,
                                      global_step=uidx)

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(
                    uidx, eidx, every_n_step=training_configs['disp_freq']):
                # words per second and sents per second
                words_per_sec = cum_words / (timer.toc(return_seconds=True))
                sents_per_sec = cum_samples / (timer.toc(return_seconds=True))
                lrate = list(optim.get_lrate())[0]

                summary_writer.add_scalar("Speed(words/sec)",
                                          scalar_value=words_per_sec,
                                          global_step=uidx)
                summary_writer.add_scalar("Speed(sents/sec)",
                                          scalar_value=sents_per_sec,
                                          global_step=uidx)
                summary_writer.add_scalar("lrate",
                                          scalar_value=lrate,
                                          global_step=uidx)
                summary_writer.add_scalar("oom_count",
                                          scalar_value=oom_count,
                                          global_step=uidx)

                # Reset timer
                timer.tic()
                cum_words = 0
                cum_samples = 0

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(
                    uidx, eidx,
                    every_n_step=training_configs['save_freq'],
                    debug=FLAGS.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)

                if not is_early_stop:
                    checkpoint_saver.save(global_step=uidx,
                                          model=lm_model,
                                          optim=optim,
                                          lr_scheduler=scheduler,
                                          collections=model_collections,
                                          ma=ma)

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['loss_valid_freq'],
                    debug=FLAGS.debug):
                if ma is not None:
                    origin_state_dict = deepcopy(lm_model.state_dict())
                    lm_model.load_state_dict(ma.export_ma_params(), strict=False)

                valid_loss = loss_validation(
                    model=lm_model,
                    critic=critic,
                    valid_iterator=valid_iterator,
                    norm_by_words=training_configs["norm_by_words"])

                model_collections.add_to_collection("history_losses", valid_loss)
                min_history_loss = np.array(
                    model_collections.get_collection("history_losses")).min()

                summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                summary_writer.add_scalar("best_loss", min_history_loss,
                                          global_step=uidx)

                if ma is not None:
                    lm_model.load_state_dict(origin_state_dict)
                    del origin_state_dict

                if optimizer_configs["schedule_method"] == "loss":
                    scheduler.step(metric=best_valid_loss)

                # If the model gets a new best valid loss
                if valid_loss < best_valid_loss:
                    bad_count = 0
                    if is_early_stop is False:
                        # 1. Save the best model
                        torch.save(lm_model.state_dict(),
                                   best_model_prefix + ".final")
                        # 2. Record the several best models kept so far
                        best_model_saver.save(global_step=uidx, model=lm_model)
                else:
                    bad_count += 1
                    # At least one epoch should be traversed
                    if bad_count >= training_configs['early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")

                best_valid_loss = min_history_loss

                summary_writer.add_scalar("bad_count", bad_count, uidx)

                INFO("{0} Loss: {1:.2f} lrate: {2:6f} patience: {3}".format(
                    uidx, valid_loss, lrate, bad_count))

        training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
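# --------------------------------------------------------------------------- #
# The training loop above accumulates gradients over `update_cycle` shards of
# one large batch before calling `optim.step()` once. The helper below is an
# illustrative sketch of how `split_shard` could be implemented; it matches
# the call sites in this file but is an assumption, not the repo's actual
# helper.
def _split_shard_sketch(*seqs, split_size=1):
    """Yield tuples of aligned sub-batches, one slice per input list."""
    if split_size <= 1:
        yield seqs
    else:
        size = len(seqs[0])
        shard_size = (size + split_size - 1) // split_size  # ceil division
        for start in range(0, size, shard_size):
            yield tuple(s[start:start + shard_size] for s in seqs)

# Usage mirroring the loops in this file:
#     for (seqs_y_t,) in _split_shard_sketch(seqs_y, split_size=4): ...
#     for seqs_x_t, seqs_y_t in _split_shard_sketch(seqs_x, seqs_y, split_size=4): ...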
def _init_local_optims(self, rephraser_optimizer_configs):
    """
    Build actor, critic, alpha optimizers and lr schedulers if necessary.
    rephraser_optimizer_configs:
        optimizer: "adafactor"
        learning_rate: 0.01
        grad_clip: -1.0
        optimizer_params: ~
        schedule_method: rsqrt
        scheduler_configs:
            d_model: *dim
            warmup_steps: 100
    """
    # Initiate local optimizers
    if rephraser_optimizer_configs is None:
        self.actor_optimizer = None
        self.critic_optimizer = None
        self.log_alpha_optimizer = None
        # self.actor_icm_optimizer = None
        self.actor_scheduler = None
        self.critic_scheduler = None
    else:
        self.actor_optimizer = Optimizer(
            name=rephraser_optimizer_configs["optimizer"],
            model=self.actor,
            lr=rephraser_optimizer_configs["learning_rate"],
            grad_clip=rephraser_optimizer_configs["grad_clip"],
            optim_args=rephraser_optimizer_configs["optimizer_params"])
        self.critic_optimizer = Optimizer(
            name=rephraser_optimizer_configs["optimizer"],
            model=self.critic,
            lr=rephraser_optimizer_configs["learning_rate"],
            grad_clip=rephraser_optimizer_configs["grad_clip"],
            optim_args=rephraser_optimizer_configs["optimizer_params"])

        # Hard-coded optimizers for entropy-weight (alpha) updates and ICM updates
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=1e-4,
                                                    betas=(0.9, 0.999))
        # self.actor_icm_optimizer = torch.optim.Adam(self.actor.icm.parameters(), lr=1e-3)

        # Build schedulers for the optimizers if needed
        if rephraser_optimizer_configs['schedule_method'] is not None:
            if rephraser_optimizer_configs['schedule_method'] == "loss":
                self.actor_scheduler = ReduceOnPlateauScheduler(
                    optimizer=self.actor_optimizer,
                    **rephraser_optimizer_configs["scheduler_configs"])
                self.critic_scheduler = ReduceOnPlateauScheduler(
                    optimizer=self.critic_optimizer,
                    **rephraser_optimizer_configs["scheduler_configs"])
            elif rephraser_optimizer_configs['schedule_method'] == "noam":
                self.actor_scheduler = NoamScheduler(
                    optimizer=self.actor_optimizer,
                    **rephraser_optimizer_configs["scheduler_configs"])
                self.critic_scheduler = NoamScheduler(
                    optimizer=self.critic_optimizer,
                    **rephraser_optimizer_configs["scheduler_configs"])
            elif rephraser_optimizer_configs["schedule_method"] == "rsqrt":
                self.actor_scheduler = RsqrtScheduler(
                    optimizer=self.actor_optimizer,
                    **rephraser_optimizer_configs["scheduler_configs"])
                self.critic_scheduler = RsqrtScheduler(
                    optimizer=self.critic_optimizer,
                    **rephraser_optimizer_configs["scheduler_configs"])
            else:
                WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(
                    rephraser_optimizer_configs['schedule_method']))
                self.actor_scheduler = None
                self.critic_scheduler = None
        else:
            self.actor_scheduler = None
            self.critic_scheduler = None
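# For reference, the YAML fragment in the docstring above parses into a dict
# like the following. The values are illustrative; `d_model` resolves through
# the YAML anchor `*dim`, assumed here to be 512.
_example_rephraser_optimizer_configs = {
    "optimizer": "adafactor",
    "learning_rate": 0.01,
    "grad_clip": -1.0,           # negative value: no gradient clipping
    "optimizer_params": None,    # YAML "~" parses to None
    "schedule_method": "rsqrt",
    "scheduler_configs": {"d_model": 512, "warmup_steps": 100},
}
# self._init_local_optims(_example_rephraser_optimizer_configs)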
def run():
    # Default actor threads to 1
    os.environ["OMP_NUM_THREADS"] = "1"
    mp = _mp.get_context('spawn')
    args = parser.parse_args()

    if not os.path.exists(args.save_to):
        os.mkdir(args.save_to)

    with open(args.config_path, "r") as f, \
            open(os.path.join(args.save_to, "current_attack_configs.yaml"),
                 "w") as current_configs:
        configs = yaml.load(f)
        yaml.dump(configs, current_configs)

    attack_configs = configs["attack_configs"]
    attacker_configs = configs["attacker_configs"]
    attacker_model_configs = attacker_configs["attacker_model_configs"]
    attacker_optimizer_configs = attacker_configs["attacker_optimizer_configs"]
    discriminator_configs = configs["discriminator_configs"]
    # training_configs = configs["training_configs"]

    # Initiate best-model saver for the global model
    global_saver = Saver(
        save_prefix="{0}.final".format(os.path.join(args.save_to, "ACmodel")),
        num_max_keeping=attack_configs["num_kept_checkpoints"])

    # The global variable USE_GPU is mainly used for the environments
    GlobalNames.SEED = attack_configs["seed"]
    GlobalNames.USE_GPU = args.use_gpu
    torch.manual_seed(GlobalNames.SEED)

    # Build vocabularies and the data set for the env
    with open(attack_configs["victim_configs"], "r") as victim_f:
        victim_configs = yaml.load(victim_f)
    data_configs = victim_configs["data_configs"]
    src_vocab = Vocabulary(**data_configs["vocabularies"][0])
    trg_vocab = Vocabulary(**data_configs["vocabularies"][1])

    # We build the parallel data sets here and iterate inside each thread
    data_set = ZipDataset(
        TextLineDataset(
            data_path=data_configs["train_data"][0],
            vocabulary=src_vocab,
        ),
        TextLineDataset(
            data_path=data_configs["train_data"][1],
            vocabulary=trg_vocab,
        ),
        shuffle=attack_configs["shuffle"]
    )

    # Global model variables (the target network that collects the results)
    global_attacker = attacker.Attacker(src_vocab.max_n_words,
                                        **attacker_model_configs)
    global_attacker = global_attacker.cpu()
    global_attacker.share_memory()

    if args.share_optim:
        # Initiate the optimizer and set it to shared mode
        optimizer = Optimizer(
            name=attacker_optimizer_configs["optimizer"],
            model=global_attacker,
            lr=attacker_optimizer_configs["learning_rate"],
            grad_clip=attacker_optimizer_configs["grad_clip"],
            optim_args=attacker_optimizer_configs["optimizer_params"])
        optimizer.optim.share_memory()

        # Build scheduler for optimizer if needed
        if attacker_optimizer_configs['schedule_method'] is not None:
            if attacker_optimizer_configs['schedule_method'] == "loss":
                scheduler = ReduceOnPlateauScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs["scheduler_configs"])
            elif attacker_optimizer_configs['schedule_method'] == "noam":
                scheduler = NoamScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs['scheduler_configs'])
            elif attacker_optimizer_configs["schedule_method"] == "rsqrt":
                scheduler = RsqrtScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs["scheduler_configs"])
            else:
                WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(
                    attacker_optimizer_configs['schedule_method']))
                scheduler = None
        else:
            scheduler = None
    else:
        optimizer = None
        scheduler = None

    # Load from checkpoint for the global model
    global_saver.load_latest(model=global_attacker,
                             optim=optimizer,
                             lr_scheduler=scheduler)

    if args.use_gpu:
        # Collect available devices and distribute envs over the available GPUs
        device = "cuda"
        devices = []
        for i in range(torch.cuda.device_count()):
            devices += ["cuda:%d" % i]
        print("available gpus:", devices)
    else:
        device = "cpu"
        devices = [device]

    process = []
    counter = mp.Value("i", 0)
    lock = mp.Lock()  # for updates from multiple attackers

    INFO("extract near candidates")
    _, _ = load_or_extract_near_vocab(
        config_path=attack_configs["victim_configs"],
        model_path=attack_configs["victim_model"],
        init_perturb_rate=attack_configs["init_perturb_rate"],
        save_to=os.path.join(args.save_to, "near_vocab"),
        save_to_full=os.path.join(args.save_to, "full_near_vocab"),
        top_reserve=12,
        emit_as_id=True)

    # train(0, device, args, counter, lock,
    #       attack_configs, discriminator_configs,
    #       src_vocab, trg_vocab, data_set,
    #       global_attacker, attacker_configs,
    #       optimizer, scheduler,
    #       global_saver)
    # valid(args.n, device, args,
    #       attack_configs, discriminator_configs,
    #       src_vocab, trg_vocab, data_set,
    #       global_attacker, attacker_configs, counter)

    # Run multiple training processes of local attackers to update the global one
    for rank in range(args.n):
        print("initialize training thread on cuda:%d" % (rank + 1))
        p = mp.Process(target=train,
                       args=(rank, "cuda:%d" % (rank + 1), args, counter, lock,
                             attack_configs, discriminator_configs,
                             src_vocab, trg_vocab, data_set,
                             global_attacker, attacker_configs,
                             optimizer, scheduler, global_saver))
        p.start()
        process.append(p)

    # Run the dev thread for validation
    print("initialize dev thread on cuda:0")
    p = mp.Process(target=valid,
                   args=(0, "cuda:0", args,
                         attack_configs, discriminator_configs,
                         src_vocab, trg_vocab, data_set,
                         global_attacker, attacker_configs, counter))
    p.start()
    process.append(p)

    for p in process:
        p.join()
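# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the repo): the A3C-style launch skeleton
# behind run() — a global model kept in shared memory plus a shared counter
# and lock handed to every child process. All names here are stand-ins.
def _sketch_worker(rank, shared_model, counter, lock):
    local = torch.nn.Linear(4, 2)
    local.load_state_dict(shared_model.state_dict())  # sync from global
    with lock:
        counter.value += 1  # shared step counter


def _sketch_launch(n_workers=2):
    ctx = _mp.get_context("spawn")
    shared_model = torch.nn.Linear(4, 2)
    shared_model.share_memory()  # params become visible to child processes
    counter, lock = ctx.Value("i", 0), ctx.Lock()
    workers = [ctx.Process(target=_sketch_worker,
                           args=(r, shared_model, counter, lock))
               for r in range(n_workers)]
    for p in workers:
        p.start()
    for p in workers:
        p.join()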
def train(rank, device, args, counter, lock,
          attack_configs, discriminator_configs,
          src_vocab, trg_vocab, data_set,
          global_attacker, attacker_configs,
          optimizer=None, scheduler=None, saver=None):
    """
    Run a training process:
        #1# train the env_discriminator
        #2# run the attacker AC based on rewards from the trained env_discriminator
        #3# run training updates of the attacker AC
    :param rank: (int) the rank of the process (from multiprocess)
    :param device: the device of the process
    :param counter: python multiprocess variable
    :param lock: python multiprocess variable
    :param args: global args
    :param attack_configs: attack settings
    :param discriminator_configs: discriminator settings
    :param src_vocab:
    :param trg_vocab:
    :param data_set: (data_iterator object) provides batched data and labels
    :param global_attacker: the model to sync from
    :param attacker_configs: local attacker settings
    :param optimizer: shared optimizer for the attacker; build a local one if None
    :param scheduler: shared scheduler for the attacker; build a local one if None
    :param saver: model saver
    :return:
    """
    trust_acc = acc_bound = discriminator_configs["acc_bound"]
    converged_bound = discriminator_configs["converged_bound"]
    patience = discriminator_configs["patience"]
    attacker_model_configs = attacker_configs["attacker_model_configs"]
    attacker_optimizer_configs = attacker_configs["attacker_optimizer_configs"]

    # This is for multi-processing: GlobalNames cannot be inherited directly
    GlobalNames.USE_GPU = args.use_gpu
    GlobalNames.SEED = attack_configs["seed"]
    torch.manual_seed(GlobalNames.SEED + rank)

    # Initiate a local saver and load checkpoint if possible
    local_saver = Saver(
        save_prefix="{0}.local".format(
            os.path.join(args.save_to, "train_env%d" % rank, "ACmodel")),
        num_max_keeping=attack_configs["num_kept_checkpoints"])

    attack_iterator = DataIterator(dataset=data_set,
                                   batch_size=attack_configs["batch_size"],
                                   use_bucket=True,
                                   buffer_size=attack_configs["buffer_size"],
                                   numbering=True)
    summary_writer = SummaryWriter(
        log_dir=os.path.join(args.save_to, "train_env%d" % rank))

    local_attacker = attacker.Attacker(src_vocab.max_n_words,
                                       **attacker_model_configs)

    # Build optimizer for the attacker if no shared one is given
    if optimizer is None:
        optimizer = Optimizer(
            name=attacker_optimizer_configs["optimizer"],
            model=global_attacker,
            lr=attacker_optimizer_configs["learning_rate"],
            grad_clip=attacker_optimizer_configs["grad_clip"],
            optim_args=attacker_optimizer_configs["optimizer_params"])

        # Build scheduler for optimizer if needed
        if attacker_optimizer_configs['schedule_method'] is not None:
            if attacker_optimizer_configs['schedule_method'] == "loss":
                scheduler = ReduceOnPlateauScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs["scheduler_configs"])
            elif attacker_optimizer_configs['schedule_method'] == "noam":
                scheduler = NoamScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs['scheduler_configs'])
            elif attacker_optimizer_configs["schedule_method"] == "rsqrt":
                scheduler = RsqrtScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs["scheduler_configs"])
            else:
                WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(
                    attacker_optimizer_configs['schedule_method']))
                scheduler = None
        else:
            scheduler = None

    local_saver.load_latest(model=local_attacker,
                            optim=optimizer,
                            lr_scheduler=scheduler)

    attacker_iterator = attack_iterator.build_generator()
    env = Translate_Env(attack_configs=attack_configs,
                        discriminator_configs=discriminator_configs,
                        src_vocab=src_vocab,
                        trg_vocab=trg_vocab,
                        data_iterator=attacker_iterator,
                        save_to=args.save_to,
                        device=device)
    episode_count = 0
    episode_length = 0
    local_steps = 0  # optimization steps: for learning rate schedules
    patience_t = patience

    while True:  # infinite loop over the data set
        # We continue with a new iterator (and refreshed environments)
        # whenever the last iterator breaks with StopIteration
        attacker_iterator = attack_iterator.build_generator()
        env.reset_data_iter(attacker_iterator)
        padded_src = env.reset()
        padded_src = torch.from_numpy(padded_src)
        if device != "cpu":
            padded_src = padded_src.to(device)
        done = True
        discriminator_base_steps = local_steps

        while True:
            # Check for update of the discriminator
            # if env.acc_validation(local_attacker, use_gpu=True if env.device != "cpu" else False) < 0.55:
            if episode_count % attacker_configs["attacker_update_steps"] == 0:
                """
                Stop criterion: when updating the discriminator, we check its
                accuracy. If the accuracy falls below acc_bound, we reset the
                discriminator and retry until the accuracy reaches the bound,
                within the patience budget; otherwise the training thread stops.
                """
                try:
                    discriminator_base_steps, trust_acc = env.update_discriminator(
                        local_attacker,
                        discriminator_base_steps,
                        min_update_steps=discriminator_configs["acc_valid_freq"],
                        max_update_steps=discriminator_configs["discriminator_update_steps"],
                        accuracy_bound=acc_bound,
                        summary_writer=summary_writer)
                except StopIteration:
                    INFO("finish one training epoch, reset data_iterator")
                    break

                discriminator_base_steps += 1  # a flag to label the discriminator updates

                if trust_acc < converged_bound:  # GAN convergence target reached
                    patience_t -= 1
                    INFO("discriminator reached GAN convergence bound: %d times" % patience_t)
                else:
                    # Reset patience if the discriminator is refreshed
                    patience_t = patience

            if saver and local_steps % attack_configs["save_freq"] == 0:
                local_saver.save(global_step=local_steps,
                                 model=local_attacker,
                                 optim=optimizer,
                                 lr_scheduler=scheduler)
                if trust_acc < converged_bound:  # and patience_t == patience - 1:
                    # We only save the global params that reached acc_bound
                    torch.save(global_attacker.state_dict(),
                               os.path.join(args.save_to, "ACmodel.final"))
                    # saver.raw_save(model=global_attacker)

            if patience_t == 0:
                WARN("maximum patience reached. Training thread should stop")
                break

            local_attacker.train()  # switch back to training mode

            # For an initial (reset) attacker, sync from the global parameters
            if done:
                INFO("sync from global model")
                local_attacker.load_state_dict(global_attacker.state_dict())
            # Move the local attacker params back to device after updates
            local_attacker = local_attacker.to(device)

            values = []     # critic outputs
            log_probs = []  # log-probabilities of the chosen actions
            rewards = []    # actual rewards
            entropies = []
            local_steps += 1

            # Run several steps of attack
            try:
                for i in range(args.action_roll_steps):
                    episode_length += 1
                    attack_out, critic_out = local_attacker(
                        padded_src,
                        padded_src[:, env.index - 1:env.index + 2])
                    logit_attack_out = torch.log(attack_out)
                    entropy = -(attack_out * logit_attack_out).sum(dim=-1).mean()
                    summary_writer.add_scalar("action_entropy",
                                              scalar_value=entropy,
                                              global_step=local_steps)
                    entropies.append(entropy)  # for the entropy loss

                    actions = attack_out.multinomial(num_samples=1).detach()
                    # Only extract the log prob of the chosen action (avg over batch)
                    log_attack_out = logit_attack_out.gather(-1, actions).mean()
                    padded_src, reward, terminal_signal = env.step(actions.squeeze())
                    done = terminal_signal or episode_length > args.max_episode_lengths

                    with lock:
                        counter.value += 1

                    if done:
                        episode_length = 0
                        padded_src = env.reset()
                        padded_src = torch.from_numpy(padded_src)
                        if device != "cpu":
                            padded_src = padded_src.to(device)

                    values.append(critic_out.mean())  # list of torch scalars
                    log_probs.append(log_attack_out)  # list of torch scalars
                    rewards.append(reward)            # list of reward variables

                    if done:
                        episode_count += 1
                        break
            except StopIteration:
                INFO("finish one training epoch, reset data_iterator")
                break

            R = torch.zeros(1, 1)
            gae = torch.zeros(1, 1)
            if device != "cpu":
                R = R.to(device)
                gae = gae.to(device)
            if not done:
                # Bootstrap the return from the critic for an unfinished episode
                value = local_attacker.get_critic(
                    padded_src,
                    padded_src[:, env.index - 1:env.index + 2])
                R = value.mean().detach()
            values.append(R)

            policy_loss = 0
            value_loss = 0
            # Collect losses over the rollout, back to front
            for i in reversed(range(len(rewards))):
                # value loss and policy loss must be clipped to stabilize training
                R = attack_configs["gamma"] * R + rewards[i]
                advantage = R - values[i]
                value_loss = value_loss + 0.5 * advantage.pow(2)

                delta_t = rewards[i] + attack_configs["gamma"] * \
                    values[i + 1] - values[i]
                gae = gae * attack_configs["gamma"] * attack_configs["tau"] + \
                    delta_t
                policy_loss = policy_loss - log_probs[i] * gae.detach() - \
                    attack_configs["entropy_coef"] * entropies[i]

            print("policy_loss", policy_loss)
            print("gae", gae)

            # Update with the optimizer
            optimizer.zero_grad()
            # We decay the loss according to the discriminator's accuracy,
            # as a trust-region constraint
            summary_writer.add_scalar("policy_loss",
                                      scalar_value=policy_loss * trust_acc,
                                      global_step=local_steps)
            summary_writer.add_scalar("value_loss",
                                      scalar_value=value_loss * trust_acc,
                                      global_step=local_steps)
            total_loss = trust_acc * policy_loss + \
                trust_acc * attack_configs["value_coef"] * value_loss
            total_loss.backward()

            if attacker_optimizer_configs["schedule_method"] is not None \
                    and attacker_optimizer_configs["schedule_method"] != "loss":
                scheduler.step(global_step=local_steps)

            # Move the model params to CPU and assign the local gradients
            # to the global model before updating it
            local_attacker.to("cpu").ensure_shared_grads(global_attacker)
            optimizer.step()
            print("bingo!")

        if patience_t == 0:
            INFO("Reached maximum discriminator patience, finish")
            break
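# --------------------------------------------------------------------------- #
# `ensure_shared_grads` is called above but defined elsewhere. In classic
# A3C implementations it hands the local gradients to the shared global model
# exactly once per update; the sketch below shows that common pattern (an
# assumption about this repo's version, which may differ).
def _ensure_shared_grads_sketch(model, shared_model):
    for param, shared_param in zip(model.parameters(),
                                   shared_model.parameters()):
        if shared_param.grad is not None:
            # another process already filled the shared grads this step
            return
        shared_param._grad = param.grad  # expose the local grad to the optimizer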
def train(FLAGS):
    """
    FLAGS:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
    """
    # ================================================================================== #
    # Initialization for training on different devices
    # - CPU/GPU
    # - Single/Distributed

    GlobalNames.USE_GPU = FLAGS.use_gpu

    if FLAGS.multi_gpu:
        if hvd is None or distributed is None:
            ERROR("Distributed training is disabled. Please check the installation of Horovod.")
        hvd.init()
        world_size = hvd.size()
        rank = hvd.rank()
        local_rank = hvd.local_rank()
    else:
        world_size = 1
        rank = 0
        local_rank = 0

    if GlobalNames.USE_GPU:
        torch.cuda.set_device(local_rank)
        CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        CURRENT_DEVICE = "cpu"

    # If not root rank, close logging
    if rank != 0:
        close_logging()

    # Write log of training to file.
    if rank == 0:
        write_log_to_file(
            os.path.join(FLAGS.log_path,
                         "%s.log" % time.strftime("%Y%m%d-%H%M%S")))

    # ================================================================================== #
    # Parsing configuration files

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        configs = yaml.load(f)

    INFO(pretty_configs(configs))

    # Add default configs
    configs = default_baseline_configs(configs)
    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    GlobalNames.SEED = training_configs['seed']
    set_seed(GlobalNames.SEED)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate source and target dictionaries
    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    actual_buffer_size = training_configs["buffer_size"] * max(
        1, training_configs["update_cycle"])

    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        ),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        )
    )

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['valid_data'][0],
                        vocabulary=vocab_src,
                        ),
        TextLineDataset(data_path=data_configs['valid_data'][1],
                        vocabulary=vocab_tgt,
                        )
    )

    training_iterator = DataIterator(dataset=train_bitext_dataset,
                                     batch_size=training_configs["batch_size"],
                                     use_bucket=training_configs['use_bucket'],
                                     buffer_size=actual_buffer_size,
                                     batching_func=training_configs['batching_key'],
                                     world_size=world_size,
                                     rank=rank)

    valid_iterator = DataIterator(dataset=valid_bitext_dataset,
                                  batch_size=training_configs['valid_batch_size'],
                                  use_bucket=True,
                                  buffer_size=100000,
                                  numbering=True,
                                  world_size=world_size,
                                  rank=rank)

    bleu_scorer = SacreBLEUScorer(reference_path=data_configs["bleu_valid_reference"],
                                  num_refs=data_configs["num_refs"],
                                  lang_pair=data_configs["lang_pair"],
                                  sacrebleu_args=training_configs["bleu_valid_configs"]['sacrebleu_args'],
                                  postprocess=training_configs["bleu_valid_configs"]['postprocess']
                                  )

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We do the steps below one after another:
    #     1. build models & criterion
    #     2. move models & criterion to GPU if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    model_collections = Collections()
    best_model_prefix = os.path.join(
        FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    checkpoint_saver = Saver(
        save_prefix="{0}.ckpt".format(os.path.join(FLAGS.saveto, FLAGS.model_name)),
        num_max_keeping=training_configs['num_kept_checkpoints'])
    best_model_saver = Saver(save_prefix=best_model_prefix,
                             num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words,
                            **model_configs)
    INFO(nmt_model)

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])
    INFO(critic)
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model, FLAGS.pretrain_path,
                          exclude_prefix=None, device=CURRENT_DEVICE)

    # 4. Build optimizer
    INFO('Building Optimizer...')
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=nmt_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params'],
                      distributed=True if world_size > 1 else False,
                      update_cycle=training_configs['update_cycle'])

    # 5. Build scheduler for optimizer if needed
    if optimizer_configs['schedule_method'] is not None:
        if optimizer_configs['schedule_method'] == "loss":
            scheduler = ReduceOnPlateauScheduler(
                optimizer=optim, **optimizer_configs["scheduler_configs"])
        elif optimizer_configs['schedule_method'] == "noam":
            scheduler = NoamScheduler(optimizer=optim,
                                      **optimizer_configs['scheduler_configs'])
        else:
            WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(
                optimizer_configs['schedule_method']))
            scheduler = None
    else:
        scheduler = None

    # 6. Build moving average
    if training_configs['moving_average_method'] is not None:
        ma = MovingAverage(
            moving_average_method=training_configs['moving_average_method'],
            named_params=nmt_model.named_parameters(),
            alpha=training_configs['moving_average_alpha'])
    else:
        ma = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if FLAGS.reload:
        checkpoint_saver.load_latest(model=nmt_model,
                                     optim=optim,
                                     lr_scheduler=scheduler,
                                     collections=model_collections,
                                     ma=ma)

    # Broadcast parameters and optimizer states
    if world_size > 1:
        hvd.broadcast_parameters(params=nmt_model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer=optim.optim, root_rank=0)

    # ================================================================================== #
    # Prepare training

    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [1])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]

    cum_n_samples = 0
    cum_n_words = 0
    best_valid_loss = 1.0 * 1e10  # Max Float

    update_cycle = training_configs['update_cycle']
    grad_denom = 0

    if rank == 0:
        summary_writer = SummaryWriter(log_dir=FLAGS.log_path)
    else:
        summary_writer = None

    # Timer for computing speed
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    while True:
        if summary_writer is not None:
            summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()
        if rank == 0:
            training_progress_bar = tqdm(desc=' - (Epoch %d) ' % eidx,
                                         total=len(training_iterator),
                                         unit="sents")
        else:
            training_progress_bar = None

        for batch in training_iter:
            seqs_x, seqs_y = batch
            batch_size = len(seqs_x)
            cum_n_samples += batch_size
            cum_n_words += sum(len(s) for s in seqs_y)

            try:
                # Prepare data
                x, y = prepare_data(seqs_x, seqs_y, cuda=GlobalNames.USE_GPU)

                loss = compute_forward(model=nmt_model,
                                       critic=critic,
                                       seqs_x=x,
                                       seqs_y=y,
                                       eval=False,
                                       normalization=1.0,
                                       norm_by_words=training_configs["norm_by_words"])
                update_cycle -= 1
                grad_denom += batch_size
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                else:
                    raise e

            # When update_cycle reaches 0, one full batch has been accumulated.
            # Several things will be done:
            # - update parameters
            # - reset update_cycle and grad_denom
            # - update uidx
            # - update moving average
            if update_cycle == 0:
                if world_size > 1:
                    grad_denom = distributed.all_reduce(grad_denom)

                optim.step(denom=grad_denom)
                optim.zero_grad()

                if training_progress_bar is not None:
                    training_progress_bar.update(grad_denom)

                update_cycle = training_configs['update_cycle']
                grad_denom = 0
                uidx += 1

                if scheduler is None:
                    pass
                elif optimizer_configs["schedule_method"] == "loss":
                    scheduler.step(metric=best_valid_loss)
                else:
                    scheduler.step(global_step=uidx)

                if ma is not None and eidx >= training_configs['moving_average_start_epoch']:
                    ma.step()
            else:
                continue

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(
                    uidx, eidx, every_n_step=training_configs['disp_freq']):
                if world_size > 1:
                    cum_n_words = sum(distributed.all_gather(cum_n_words))
                    cum_n_samples = sum(distributed.all_gather(cum_n_samples))

                # words per second and sents per second
                words_per_sec = cum_n_words / (timer.toc(return_seconds=True))
                sents_per_sec = cum_n_samples / (timer.toc(return_seconds=True))
                lrate = list(optim.get_lrate())[0]

                if summary_writer is not None:
                    summary_writer.add_scalar("Speed(words/sec)",
                                              scalar_value=words_per_sec,
                                              global_step=uidx)
                    summary_writer.add_scalar("Speed(sents/sec)",
                                              scalar_value=sents_per_sec,
                                              global_step=uidx)
                    summary_writer.add_scalar("lrate",
                                              scalar_value=lrate,
                                              global_step=uidx)
                    summary_writer.add_scalar("oom_count",
                                              scalar_value=oom_count,
                                              global_step=uidx)

                # Reset timer
                timer.tic()
                cum_n_words = 0
                cum_n_samples = 0

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['loss_valid_freq'],
                    debug=FLAGS.debug):
                valid_loss = loss_validation(model=nmt_model,
                                             critic=critic,
                                             valid_iterator=valid_iterator,
                                             rank=rank,
                                             world_size=world_size)

                model_collections.add_to_collection("history_losses", valid_loss)
                min_history_loss = np.array(
                    model_collections.get_collection("history_losses")).min()
                best_valid_loss = min_history_loss

                if summary_writer is not None:
                    summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                    summary_writer.add_scalar("best_loss", min_history_loss,
                                              global_step=uidx)

            # ================================================================================== #
            # BLEU Validation & Early Stop
            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['bleu_valid_freq'],
                    min_step=training_configs['bleu_valid_warmup'],
                    debug=FLAGS.debug):
                valid_bleu = bleu_validation(
                    uidx=uidx,
                    valid_iterator=valid_iterator,
                    batch_size=training_configs["bleu_valid_batch_size"],
                    model=nmt_model,
                    bleu_scorer=bleu_scorer,
                    vocab_tgt=vocab_tgt,
                    valid_dir=FLAGS.valid_path,
                    max_steps=training_configs["bleu_valid_configs"]["max_steps"],
                    beam_size=training_configs["bleu_valid_configs"]["beam_size"],
                    alpha=training_configs["bleu_valid_configs"]["alpha"],
                    world_size=world_size,
                    rank=rank,
                )

                model_collections.add_to_collection(key="history_bleus",
                                                    value=valid_bleu)
                best_valid_bleu = float(
                    np.array(model_collections.get_collection("history_bleus")).max())

                if summary_writer is not None:
                    summary_writer.add_scalar("bleu", valid_bleu, uidx)
                    summary_writer.add_scalar("best_bleu", best_valid_bleu, uidx)

                # If the model gets a new best valid BLEU score
                if valid_bleu >= best_valid_bleu:
                    bad_count = 0
                    if is_early_stop is False:
                        if rank == 0:
                            # 1. Save the best model
                            torch.save(nmt_model.state_dict(),
                                       best_model_prefix + ".final")
                            # 2. Record the several best models kept so far
                            best_model_saver.save(global_step=uidx,
                                                  model=nmt_model,
                                                  ma=ma)
                else:
                    bad_count += 1
                    # At least one epoch should be traversed
                    if bad_count >= training_configs['early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")

                if summary_writer is not None:
                    summary_writer.add_scalar("bad_count", bad_count, uidx)

                INFO("{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4}".format(
                    uidx, valid_loss, valid_bleu, lrate, bad_count))

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(
                    uidx, eidx,
                    every_n_step=training_configs['save_freq'],
                    debug=FLAGS.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)

                if not is_early_stop:
                    if rank == 0:
                        checkpoint_saver.save(global_step=uidx,
                                              model=nmt_model,
                                              optim=optim,
                                              lr_scheduler=scheduler,
                                              collections=model_collections,
                                              ma=ma)

        if training_progress_bar is not None:
            training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
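# --------------------------------------------------------------------------- #
# `distributed.all_reduce` / `distributed.all_gather` above are repo-internal
# wrappers. Assuming they sit on top of Horovod collectives for Python
# scalars (an assumption; the repo's real helpers may differ), minimal
# versions could look like this:
def _all_reduce_sketch(value):
    # sum one scalar across all workers
    return hvd.allreduce(torch.tensor(float(value)), average=False).item()


def _all_gather_sketch(value):
    # collect one scalar per worker into a flat list
    return hvd.allgather(torch.tensor([float(value)])).tolist()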
def train(FLAGS):
    """
    FLAGS:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
    """
    # Write log of training to file.
    write_log_to_file(
        os.path.join(FLAGS.log_path,
                     "%s.log" % time.strftime("%Y%m%d-%H%M%S")))

    GlobalNames.USE_GPU = FLAGS.use_gpu

    if GlobalNames.USE_GPU:
        CURRENT_DEVICE = "cuda:0"
    else:
        CURRENT_DEVICE = "cpu"

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        configs = yaml.load(f)

    INFO(pretty_configs(configs))

    # Add default configs
    configs = default_configs(configs)
    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    GlobalNames.SEED = training_configs['seed']
    set_seed(GlobalNames.SEED)

    best_model_prefix = os.path.join(
        FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate source and target dictionaries
    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    train_batch_size = training_configs["batch_size"] * max(
        1, training_configs["update_cycle"])
    train_buffer_size = training_configs["buffer_size"] * max(
        1, training_configs["update_cycle"])

    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        ),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        ),
        shuffle=training_configs['shuffle']
    )

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['valid_data'][0],
                        vocabulary=vocab_src,
                        ),
        TextLineDataset(data_path=data_configs['valid_data'][1],
                        vocabulary=vocab_tgt,
                        )
    )

    training_iterator = DataIterator(dataset=train_bitext_dataset,
                                     batch_size=train_batch_size,
                                     use_bucket=training_configs['use_bucket'],
                                     buffer_size=train_buffer_size,
                                     batching_func=training_configs['batching_key'])

    valid_iterator = DataIterator(dataset=valid_bitext_dataset,
                                  batch_size=training_configs['valid_batch_size'],
                                  use_bucket=True,
                                  buffer_size=100000,
                                  numbering=True)

    bleu_scorer = SacreBLEUScorer(reference_path=data_configs["bleu_valid_reference"],
                                  num_refs=data_configs["num_refs"],
                                  lang_pair=data_configs["lang_pair"],
                                  sacrebleu_args=training_configs["bleu_valid_configs"]['sacrebleu_args'],
                                  postprocess=training_configs["bleu_valid_configs"]['postprocess']
                                  )

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We do the steps below one after another:
    #     1. build models & criterion
    #     2. move models & criterion to GPU if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    model_collections = Collections()
    checkpoint_saver = Saver(
        save_prefix="{0}.ckpt".format(os.path.join(FLAGS.saveto, FLAGS.model_name)),
        num_max_keeping=training_configs['num_kept_checkpoints'])
    best_model_saver = Saver(save_prefix=best_model_prefix,
                             num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words,
                            **model_configs)
    INFO(nmt_model)

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])
    INFO(critic)
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model, FLAGS.pretrain_path,
                          exclude_prefix=None, device=CURRENT_DEVICE)

    # 4. Build optimizer
    INFO('Building Optimizer...')
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=nmt_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params'])

    # 5. Build scheduler for optimizer if needed
    if optimizer_configs['schedule_method'] is not None:
        if optimizer_configs['schedule_method'] == "loss":
            scheduler = ReduceOnPlateauScheduler(
                optimizer=optim, **optimizer_configs["scheduler_configs"])
        elif optimizer_configs['schedule_method'] == "noam":
            scheduler = NoamScheduler(optimizer=optim,
                                      **optimizer_configs['scheduler_configs'])
        else:
            WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(
                optimizer_configs['schedule_method']))
            scheduler = None
    else:
        scheduler = None

    # 6. Build EMA
    if training_configs['ema_decay'] > 0.0:
        ema = ExponentialMovingAverage(named_params=nmt_model.named_parameters(),
                                       decay=training_configs['ema_decay'])
    else:
        ema = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if FLAGS.reload:
        checkpoint_saver.load_latest(model=nmt_model,
                                     optim=optim,
                                     lr_scheduler=scheduler,
                                     collections=model_collections)

    # ================================================================================== #
    # Prepare training

    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [0])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]

    summary_writer = SummaryWriter(log_dir=FLAGS.log_path)

    cum_samples = 0
    cum_words = 0
    best_valid_loss = 1.0 * 1e10  # Max Float
    saving_files = []

    # Timer for computing speed
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    while True:
        summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()
        training_progress_bar = tqdm(desc=' - (Epoch %d) ' % eidx,
                                     total=len(training_iterator),
                                     unit="sents")

        for batch in training_iter:
            uidx += 1

            if scheduler is None:
                pass
            elif optimizer_configs["schedule_method"] == "loss":
                scheduler.step(metric=best_valid_loss)
            else:
                scheduler.step(global_step=uidx)

            seqs_x, seqs_y = batch
            n_samples_t = len(seqs_x)
            n_words_t = sum(len(s) for s in seqs_y)
            cum_samples += n_samples_t
            cum_words += n_words_t

            training_progress_bar.update(n_samples_t)

            optim.zero_grad()
            # Prepare data
            for seqs_x_t, seqs_y_t in split_shard(
                    seqs_x, seqs_y, split_size=training_configs['update_cycle']):
                x, y = prepare_data(seqs_x_t, seqs_y_t, cuda=GlobalNames.USE_GPU)

                loss = compute_forward(model=nmt_model,
                                       critic=critic,
                                       seqs_x=x,
                                       seqs_y=y,
                                       eval=False,
                                       normalization=n_samples_t,
                                       norm_by_words=training_configs["norm_by_words"])
            optim.step()

            if ema is not None:
                ema.step()

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(
                    uidx, eidx, every_n_step=training_configs['disp_freq']):
                # words per second and sents per second
                words_per_sec = cum_words / (timer.toc(return_seconds=True))
                sents_per_sec = cum_samples / (timer.toc(return_seconds=True))
                lrate = list(optim.get_lrate())[0]

                summary_writer.add_scalar("Speed(words/sec)",
                                          scalar_value=words_per_sec,
                                          global_step=uidx)
                summary_writer.add_scalar("Speed(sents/sec)",
                                          scalar_value=sents_per_sec,
                                          global_step=uidx)
                summary_writer.add_scalar("lrate",
                                          scalar_value=lrate,
                                          global_step=uidx)

                # Reset timer
                timer.tic()
                cum_words = 0
                cum_samples = 0

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(
                    uidx, eidx,
                    every_n_step=training_configs['save_freq'],
                    debug=FLAGS.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)

                if not is_early_stop:
                    checkpoint_saver.save(global_step=uidx,
                                          model=nmt_model,
                                          optim=optim,
                                          lr_scheduler=scheduler,
                                          collections=model_collections,
                                          ema=ema)

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['loss_valid_freq'],
                    debug=FLAGS.debug):
                if ema is not None:
                    origin_state_dict = deepcopy(nmt_model.state_dict())
                    nmt_model.load_state_dict(ema.state_dict(), strict=False)

                valid_loss = loss_validation(model=nmt_model,
                                             critic=critic,
                                             valid_iterator=valid_iterator,
                                             )

                model_collections.add_to_collection("history_losses", valid_loss)
                min_history_loss = np.array(
                    model_collections.get_collection("history_losses")).min()

                summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                summary_writer.add_scalar("best_loss", min_history_loss,
                                          global_step=uidx)

                best_valid_loss = min_history_loss

                if ema is not None:
                    nmt_model.load_state_dict(origin_state_dict)
                    del origin_state_dict

            # ================================================================================== #
            # BLEU Validation & Early Stop
            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['bleu_valid_freq'],
                    min_step=training_configs['bleu_valid_warmup'],
                    debug=FLAGS.debug):
                if ema is not None:
                    origin_state_dict = deepcopy(nmt_model.state_dict())
                    nmt_model.load_state_dict(ema.state_dict(), strict=False)

                valid_bleu = bleu_validation(
                    uidx=uidx,
                    valid_iterator=valid_iterator,
                    batch_size=training_configs["bleu_valid_batch_size"],
                    model=nmt_model,
                    bleu_scorer=bleu_scorer,
                    vocab_tgt=vocab_tgt,
                    valid_dir=FLAGS.valid_path,
                    max_steps=training_configs["bleu_valid_configs"]["max_steps"],
                    beam_size=training_configs["bleu_valid_configs"]["beam_size"],
                    alpha=training_configs["bleu_valid_configs"]["alpha"]
                )

                model_collections.add_to_collection(key="history_bleus",
                                                    value=valid_bleu)
                best_valid_bleu = float(
                    np.array(model_collections.get_collection("history_bleus")).max())

                summary_writer.add_scalar("bleu", valid_bleu, uidx)
                summary_writer.add_scalar("best_bleu", best_valid_bleu, uidx)

                # If the model gets a new best valid BLEU score
                if valid_bleu >= best_valid_bleu:
                    bad_count = 0
                    if is_early_stop is False:
                        # 1. Save the best model
                        torch.save(nmt_model.state_dict(),
                                   best_model_prefix + ".final")
                        # 2. Record the several best models kept so far
                        best_model_saver.save(global_step=uidx, model=nmt_model)
                else:
                    bad_count += 1
                    # At least one epoch should be traversed
                    if bad_count >= training_configs['early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")

                summary_writer.add_scalar("bad_count", bad_count, uidx)

                if ema is not None:
                    nmt_model.load_state_dict(origin_state_dict)
                    del origin_state_dict

                INFO("{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4}".format(
                    uidx, valid_loss, valid_bleu, lrate, bad_count))

        training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
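# --------------------------------------------------------------------------- #
# `ExponentialMovingAverage` above exposes `step()` and `state_dict()` (the
# latter is loaded with strict=False, so only the shadowed parameters need to
# be present). A minimal shadow-parameter implementation consistent with that
# interface is sketched below; it is an assumption, not the repo's actual class.
class _EMASketch(object):
    def __init__(self, named_params, decay=0.9999):
        self.decay = decay
        self.params = {n: p for n, p in named_params if p.requires_grad}
        self.shadow = {n: p.detach().clone() for n, p in self.params.items()}

    def step(self):
        # shadow = decay * shadow + (1 - decay) * param
        with torch.no_grad():
            for name, param in self.params.items():
                self.shadow[name].mul_(self.decay).add_(param,
                                                        alpha=1.0 - self.decay)

    def state_dict(self):
        # only the shadowed entries; loaded above with strict=False
        return self.shadow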