def main():

    args = parse()

    # Load a conf file
    if args.resume:
        conf = load_config(os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=args.batch_size * args.n_gpus,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=args.batch_size * args.n_gpus,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [Dataset(corpus=args.corpus,
                              tsv_path=s,
                              dict_path=args.dict,
                              nlsyms=args.nlsyms,
                              unit=args.unit,
                              wp_model=args.wp_model,
                              batch_size=1,
                              bptt=args.bptt,
                              backward=args.backward,
                              serialize=args.serialize)]

    args.vocab = train_set.vocab

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = make_model_name(args)
        save_path = mkdir_join(
            args.model, '_'.join(os.path.basename(args.train_set).split('.')[:-1]), dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    logger = set_logger(os.path.join(save_path, 'train.log'), key='training')

    # Model setting
    if 'gated_conv' in args.lm_type:
        model = GatedConvLM(args)
    else:
        model = RNNLM(args)
    model.save_path = save_path

    if args.resume:
        # Set optimizer
        epoch = int(args.resume.split('-')[-1])
        model.set_optimizer(
            optimizer='sgd' if epoch > conf['convert_to_sgd_epoch'] + 1 else conf['optimizer'],
            learning_rate=float(conf['learning_rate']),  # on-the-fly
            weight_decay=float(conf['weight_decay']))

        # Restore the last saved model
        model, checkpoint = load_checkpoint(model, args.resume, resume=True)
        lr_controller = checkpoint['lr_controller']
        epoch = checkpoint['epoch']
        step = checkpoint['step']
        ppl_dev_best = checkpoint['metric_dev_best']

        # Resume between convert_to_sgd_epoch and convert_to_sgd_epoch + 1
        if epoch == conf['convert_to_sgd_epoch'] + 1:
            model.set_optimizer(optimizer='sgd',
                                learning_rate=args.learning_rate,
                                weight_decay=float(conf['weight_decay']))
            logger.info('========== Convert to SGD ==========')
    else:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(model.save_path, 'conf.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(model.save_path, 'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(model.save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(model.save_path, 'wp.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            nparams = model.num_params_dict[n]
            logger.info("%s %d" % (n, nparams))
        logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000))
        logger.info(model)

        # Set optimizer
        model.set_optimizer(optimizer=args.optimizer,
                            learning_rate=float(args.learning_rate),
                            weight_decay=float(args.weight_decay))

        epoch, step = 1, 1
        ppl_dev_best = 10000

        # Set learning rate controller
        lr_controller = Controller(learning_rate=float(args.learning_rate),
                                   decay_type=args.decay_type,
                                   decay_start_epoch=args.decay_start_epoch,
                                   decay_rate=args.decay_rate,
                                   decay_patient_n_epochs=args.decay_patient_n_epochs,
                                   lower_better=True,
                                   best_value=ppl_dev_best)

    train_set.epoch = epoch - 1  # start from index:0

    # GPU setting
    if args.n_gpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus, 1)),
                                   deterministic=False,
                                   benchmark=True)
        model.cuda()

    logger.info('PID: %s' % os.getpid())
    logger.info('USERNAME: %s' % os.uname()[1])

    # Set process name
    if args.job_name:
        setproctitle(args.job_name)
    else:
        setproctitle(dir_name)

    # Set reporter
    reporter = Reporter(model.module.save_path, tensorboard=True)

    hidden = None
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0
    pbar_epoch = tqdm(total=len(train_set))
    while True:
        # Compute loss in the training set
        ys_train, is_new_epoch = train_set.next()

        model.module.optimizer.zero_grad()
        loss, hidden, reporter = model(ys_train, hidden, reporter)
        if len(model.device_ids) > 1:
            loss.backward(torch.ones(len(model.device_ids)))
        else:
            loss.backward()
        loss.detach()  # Truncate the graph
        if args.clip_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.module.parameters(), args.clip_grad_norm)
        model.module.optimizer.step()
        loss_train = loss.item()
        del loss
        if 'gated_conv' not in args.lm_type:
            hidden = model.module.repackage_hidden(hidden)
        reporter.step(is_eval=False)

        if step % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next()[0]
            loss, _, reporter = model(ys_dev, None, reporter, is_eval=True)
            loss_dev = loss.item()
            del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            logger.info("step:%d(ep:%.2f) loss:%.3f(%.3f)/ppl:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)" %
                        (step, train_set.epoch_detail,
                         loss_train, loss_dev,
                         np.exp(loss_train), np.exp(loss_dev),
                         lr_controller.lr, ys_train.shape[0], duration_step / 60))
            start_time_step = time.time()
        step += args.n_gpus
        pbar_epoch.update(ys_train.shape[0] * (ys_train.shape[1] - 1))

        # Save figures of loss and accuracy
        if step % (args.print_step * 10) == 0:
            reporter.snapshot()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' % (epoch, duration_epoch / 60))

            if epoch < args.eval_start_epoch:
                # Save the model
                save_checkpoint(model.module, model.module.save_path, lr_controller,
                                epoch, step - 1, ppl_dev_best,
                                remove_old_checkpoints=True)
            else:
                start_time_eval = time.time()
                # dev
                ppl_dev, _ = eval_ppl([model.module], dev_set,
                                      batch_size=1, bptt=args.bptt)
                logger.info('PPL (%s): %.2f' % (dev_set.set, ppl_dev))

                # Update learning rate
                model.module.optimizer = lr_controller.decay(
                    model.module.optimizer, epoch=epoch, value=ppl_dev)

                if ppl_dev < ppl_dev_best:
                    ppl_dev_best = ppl_dev
                    not_improved_epoch = 0
                    logger.info('||||| Best Score |||||')

                    # Save the model
                    save_checkpoint(model.module, model.module.save_path, lr_controller,
                                    epoch, step - 1, ppl_dev_best,
                                    remove_old_checkpoints=True)

                    # test
                    ppl_test_avg = 0.
                    for eval_set in eval_sets:
                        ppl_test, _ = eval_ppl([model.module], eval_set,
                                               batch_size=1, bptt=args.bptt)
                        logger.info('PPL (%s): %.2f' % (eval_set.set, ppl_test))
                        ppl_test_avg += ppl_test
                    if len(eval_sets) > 0:
                        logger.info('PPL (avg.): %.2f' % (ppl_test_avg / len(eval_sets)))
                else:
                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == args.not_improved_patient_n_epochs:
                    break

                # Convert to fine-tuning stage
                if epoch == args.convert_to_sgd_epoch:
                    model.module.set_optimizer('sgd',
                                               learning_rate=args.learning_rate,
                                               weight_decay=float(args.weight_decay))
                    lr_controller = Controller(learning_rate=args.learning_rate,
                                               decay_type='epoch',
                                               decay_start_epoch=epoch,
                                               decay_rate=0.5,
                                               lower_better=True)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))

            if epoch == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    if reporter.tensorboard:
        reporter.tf_writer.close()
    pbar_epoch.close()

    return model.module.save_path
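# --- Illustration (not part of the repository) ------------------------------
# The loop above calls `model.module.repackage_hidden(hidden)` to cut the BPTT
# graph between updates, but the helper itself is not shown in this excerpt.
# Below is a minimal sketch of what such a helper usually does, assuming the
# hidden state is a tensor or a (possibly nested) tuple of tensors; the name
# `repackage_hidden_sketch` is hypothetical, not the repository's API.
import torch

def repackage_hidden_sketch(hidden):
    """Detach hidden states from their history so gradients stop at this step."""
    if hidden is None:
        return None
    if isinstance(hidden, torch.Tensor):
        return hidden.detach()
    return tuple(repackage_hidden_sketch(h) for h in hidden)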
def main():

    args = parse()

    # Load a conf file
    if args.resume:
        conf = load_config(os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=args.batch_size * args.n_gpus,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=args.batch_size * args.n_gpus,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [Dataset(corpus=args.corpus,
                              tsv_path=s,
                              dict_path=args.dict,
                              nlsyms=args.nlsyms,
                              unit=args.unit,
                              wp_model=args.wp_model,
                              batch_size=1,
                              bptt=args.bptt,
                              backward=args.backward,
                              serialize=args.serialize)]

    args.vocab = train_set.vocab

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_lm_name(args)
        save_path = mkdir_join(
            args.model_save_dir, '_'.join(os.path.basename(args.train_set).split('.')[:-1]), dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    logger = set_logger(os.path.join(save_path, 'train.log'), key='training', stdout=args.stdout)

    # Model setting
    model = select_lm(args, save_path)

    if args.resume:
        # Set optimizer
        epoch = int(args.resume.split('-')[-1])
        optimizer = set_optimizer(model,
                                  'sgd' if epoch > conf['convert_to_sgd_epoch'] else conf['optimizer'],
                                  conf['lr'], conf['weight_decay'])

        # Restore the last saved model
        model, optimizer = load_checkpoint(model, args.resume, optimizer, resume=True)

        # Resume between convert_to_sgd_epoch - 1 and convert_to_sgd_epoch
        if epoch == conf['convert_to_sgd_epoch']:
            optimizer = set_optimizer(model, 'sgd', args.lr, conf['weight_decay'])
            optimizer = LRScheduler(optimizer, args.lr,
                                    decay_type='always',
                                    decay_start_epoch=0,
                                    decay_rate=0.5)
            logger.info('========== Convert to SGD ==========')
    else:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(save_path, 'conf.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000))
        logger.info(model)

        # Set optimizer
        optimizer = set_optimizer(model, args.optimizer, args.lr, args.weight_decay)

        # Wrap optimizer by learning rate scheduler
        optimizer = LRScheduler(optimizer, args.lr,
                                decay_type=args.lr_decay_type,
                                decay_start_epoch=args.lr_decay_start_epoch,
                                decay_rate=args.lr_decay_rate,
                                decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
                                early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
                                warmup_start_lr=args.warmup_start_lr,
                                warmup_n_steps=args.warmup_n_steps,
                                model_size=args.d_model,
                                factor=args.lr_factor,
                                noam=args.lm_type == 'transformer')

    # GPU setting
    if args.n_gpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus, 1)),
                                   deterministic=False,
                                   benchmark=True)
        model.cuda()

    # Set process name
    logger.info('PID: %s' % os.getpid())
    logger.info('USERNAME: %s' % os.uname()[1])
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(save_path, tensorboard=True)

    hidden = None
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    pbar_epoch = tqdm(total=len(train_set))
    accum_n_tokens = 0
    while True:
        # Compute loss in the training set
        ys_train, is_new_epoch = train_set.next()
        accum_n_tokens += sum([len(y) for y in ys_train])
        optimizer.zero_grad()
        loss, hidden, reporter = model(ys_train, hidden, reporter)
        # loss /= args.accum_grad_n_steps
        if len(model.device_ids) > 1:
            loss.backward(torch.ones(len(model.device_ids)))
        else:
            loss.backward()
        loss.detach()  # Truncate the graph
        if args.accum_grad_n_tokens == 0 or accum_n_tokens >= args.accum_grad_n_tokens:
            if args.clip_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.module.parameters(), args.clip_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            accum_n_tokens = 0
        loss_train = loss.item()
        del loss
        hidden = model.module.repackage_state(hidden)
        reporter.step()

        if optimizer.n_steps % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next()[0]
            loss, _, reporter = model(ys_dev, None, reporter, is_eval=True)
            loss_dev = loss.item()
            del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            logger.info("step:%d(ep:%.2f) loss:%.3f(%.3f)/ppl:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)" %
                        (optimizer.n_steps, optimizer.n_epochs + train_set.epoch_detail,
                         loss_train, loss_dev,
                         np.exp(loss_train), np.exp(loss_dev),
                         optimizer.lr, ys_train.shape[0], duration_step / 60))
            start_time_step = time.time()
        pbar_epoch.update(ys_train.shape[0] * (ys_train.shape[1] - 1))

        # Save figures of loss and accuracy
        if optimizer.n_steps % (args.print_step * 10) == 0:
            reporter.snapshot()
            if args.lm_type == 'transformer':
                model.module.plot_attention()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (optimizer.n_epochs + 1, duration_epoch / 60))

            if optimizer.n_epochs + 1 < args.eval_start_epoch:
                optimizer.epoch()
                reporter.epoch()

                # Save the model
                save_checkpoint(model, save_path, optimizer, optimizer.n_epochs,
                                remove_old_checkpoints=args.lm_type != 'transformer')
            else:
                start_time_eval = time.time()
                # dev
                ppl_dev, _ = eval_ppl([model.module], dev_set,
                                      batch_size=1, bptt=args.bptt)
                logger.info('PPL (%s): %.2f' % (dev_set.set, ppl_dev))
                optimizer.epoch(ppl_dev)
                reporter.epoch(ppl_dev, name='perplexity')

                if optimizer.is_best:
                    # Save the model
                    save_checkpoint(model, save_path, optimizer, optimizer.n_epochs,
                                    remove_old_checkpoints=args.lm_type != 'transformer')

                    # test
                    ppl_test_avg = 0.
                    for eval_set in eval_sets:
                        ppl_test, _ = eval_ppl([model.module], eval_set,
                                               batch_size=1, bptt=args.bptt)
                        logger.info('PPL (%s): %.2f' % (eval_set.set, ppl_test))
                        ppl_test_avg += ppl_test
                    if len(eval_sets) > 0:
                        logger.info('PPL (avg.): %.2f' % (ppl_test_avg / len(eval_sets)))

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if optimizer.is_early_stop:
                    break

                # Convert to fine-tuning stage
                if optimizer.n_epochs == args.convert_to_sgd_epoch:
                    optimizer = set_optimizer(model, 'sgd', args.lr, args.weight_decay)
                    optimizer = LRScheduler(optimizer, args.lr,
                                            decay_type='always',
                                            decay_start_epoch=0,
                                            decay_rate=0.5)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))

            if optimizer.n_epochs == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    if reporter.tensorboard:
        reporter.tf_writer.close()
    pbar_epoch.close()

    return save_path
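# --- Illustration (not part of the repository) ------------------------------
# The training loop above defers `optimizer.step()` until `accum_n_tokens`
# reaches `args.accum_grad_n_tokens`, i.e. gradients are accumulated over
# mini-batches until roughly a fixed number of target tokens has been seen.
# A stripped-down sketch of that pattern; `model`, `optimizer`, `batches` and
# `accum_grad_n_tokens` are placeholders, not the repository's objects.
import torch

def train_with_token_accumulation(model, optimizer, batches, accum_grad_n_tokens, clip=5.0):
    accum_n_tokens = 0
    optimizer.zero_grad()
    for ys in batches:
        accum_n_tokens += sum(len(y) for y in ys)
        loss = model(ys)              # forward pass on one mini-batch
        loss.backward()               # gradients add up across mini-batches
        if accum_n_tokens >= accum_grad_n_tokens:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()          # one update for the accumulated gradients
            optimizer.zero_grad()
            accum_n_tokens = 0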
def main():

    # Load a config file
    if args.resume_model is None:
        config = load_config(args.config)
    else:
        # Restart from the last checkpoint
        config = load_config(os.path.join(args.resume_model, 'config.yml'))

    # Check differences between args and yaml configuration
    for k, v in vars(args).items():
        if k not in config.keys():
            warnings.warn("key %s is automatically set to %s" % (k, str(v)))

    # Merge config with args
    for k, v in config.items():
        setattr(args, k, v)

    # Load dataset
    train_set = Dataset(csv_path=args.train_set,
                        dict_path=args.dict,
                        label_type=args.label_type,
                        batch_size=args.batch_size * args.ngpus,
                        bptt=args.bptt,
                        eos=args.eos,
                        max_epoch=args.num_epochs,
                        shuffle=True)
    dev_set = Dataset(csv_path=args.dev_set,
                      dict_path=args.dict,
                      label_type=args.label_type,
                      batch_size=args.batch_size * args.ngpus,
                      bptt=args.bptt,
                      eos=args.eos,
                      shuffle=True)
    eval_sets = []
    for set in args.eval_sets:
        eval_sets += [Dataset(csv_path=set,
                              dict_path=args.dict,
                              label_type=args.label_type,
                              batch_size=1,
                              bptt=args.bptt,
                              eos=args.eos,
                              is_test=True)]

    args.num_classes = train_set.num_classes

    # Model setting
    model = RNNLM(args)
    model.name = args.rnn_type
    model.name += str(args.num_units) + 'H'
    model.name += str(args.num_projs) + 'P'
    model.name += str(args.num_layers) + 'L'
    model.name += '_emb' + str(args.emb_dim)
    model.name += '_' + args.optimizer
    model.name += '_lr' + str(args.learning_rate)
    model.name += '_bs' + str(args.batch_size)
    if args.tie_weights:
        model.name += '_tie'
    if args.residual:
        model.name += '_residual'
    if args.backward:
        model.name += '_bwd'

    if args.resume_model is None:
        # Set save path
        save_path = mkdir_join(
            args.model, '_'.join(os.path.basename(args.train_set).split('.')[:-1]), model.name)
        model.set_save_path(save_path)  # avoid overwriting

        # Save the config file as a yaml file
        save_config(vars(args), model.save_path)

        # Save the dictionary & wp_model
        shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
        if args.label_type == 'wordpiece':
            shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

        # Setting for logging
        logger = set_logger(os.path.join(model.save_path, 'train.log'), key='training')

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            logger.info("%s %d" % (name, num_params))
        logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000))

        # Set optimizer
        model.set_optimizer(optimizer=args.optimizer,
                            learning_rate_init=float(args.learning_rate),
                            weight_decay=float(args.weight_decay),
                            clip_grad_norm=args.clip_grad_norm,
                            lr_schedule=False,
                            factor=args.decay_rate,
                            patience_epoch=args.decay_patient_epoch)

        epoch, step = 1, 0
        learning_rate = float(args.learning_rate)
        metric_dev_best = 10000
    else:
        raise NotImplementedError()

    train_set.epoch = epoch - 1

    # GPU setting
    if args.ngpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.ngpus, 1)),
                                   deterministic=True,
                                   benchmark=False)
        model.cuda()

    logger.info('PID: %s' % os.getpid())
    logger.info('USERNAME: %s' % os.uname()[1])

    # Set process name
    # setproctitle(args.job_name)

    # Set learning rate controller
    lr_controller = Controller(learning_rate_init=learning_rate,
                               decay_type=args.decay_type,
                               decay_start_epoch=args.decay_start_epoch,
                               decay_rate=args.decay_rate,
                               decay_patient_epoch=args.decay_patient_epoch,
                               lower_better=True,
                               best_value=metric_dev_best)

    # Set reporter
    reporter = Reporter(model.module.save_path, max_loss=10)

    # Set the updater
    updater = Updater(args.clip_grad_norm)

    # Setting for tensorboard
    tf_writer = SummaryWriter(model.module.save_path)

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0
    loss_train_mean, acc_train_mean = 0., 0.
    pbar_epoch = tqdm(total=len(train_set))
    pbar_all = tqdm(total=len(train_set) * args.num_epochs)
    while True:
        # Compute loss in the training set (including parameter update)
        ys_train, is_new_epoch = train_set.next()
        model, loss_train, acc_train = updater(model, ys_train, args.bptt)
        loss_train_mean += loss_train
        acc_train_mean += acc_train
        pbar_epoch.update(np.sum([len(y) for y in ys_train]))

        if (step + 1) % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next()[0]
            model, loss_dev, acc_dev = updater(model, ys_dev, args.bptt, is_eval=True)

            loss_train_mean /= args.print_step
            acc_train_mean /= args.print_step
            reporter.step(step, loss_train_mean, loss_dev, acc_train_mean, acc_dev)

            # Logging by tensorboard
            tf_writer.add_scalar('train/loss', loss_train_mean, step + 1)
            tf_writer.add_scalar('dev/loss', loss_dev, step + 1)
            for n, p in model.module.named_parameters():
                n = n.replace('.', '/')
                if p.grad is not None:
                    tf_writer.add_histogram(n, p.data.cpu().numpy(), step + 1)
                    tf_writer.add_histogram(n + '/grad', p.grad.data.cpu().numpy(), step + 1)

            duration_step = time.time() - start_time_step
            logger.info("...Step:%d(ep:%.2f) loss:%.2f(%.2f)/acc:%.2f(%.2f)/ppl:%.2f(%.2f)/lr:%.5f/bs:%d (%.2f min)" %
                        (step + 1, train_set.epoch_detail,
                         loss_train_mean, loss_dev, acc_train_mean, acc_dev,
                         math.exp(loss_train_mean), math.exp(loss_dev),
                         learning_rate, len(ys_train), duration_step / 60))
            start_time_step = time.time()
            loss_train_mean, acc_train_mean = 0., 0.
        step += args.ngpus

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('===== EPOCH:%d (%.2f min) =====' % (epoch, duration_epoch / 60))

            # Save figures of loss and accuracy
            reporter.epoch()

            if epoch < args.eval_start_epoch:
                # Save the model
                model.module.save_checkpoint(model.module.save_path, epoch, step,
                                             learning_rate, metric_dev_best)
            else:
                start_time_eval = time.time()
                # dev
                ppl_dev = eval_ppl([model.module], dev_set, args.bptt)
                logger.info(' PPL (%s): %.3f' % (dev_set.set, ppl_dev))

                if ppl_dev < metric_dev_best:
                    metric_dev_best = ppl_dev
                    not_improved_epoch = 0
                    logger.info('||||| Best Score |||||')

                    # Update learning rate
                    model.module.optimizer, learning_rate = lr_controller.decay_lr(
                        optimizer=model.module.optimizer,
                        learning_rate=learning_rate,
                        epoch=epoch,
                        value=ppl_dev)

                    # Save the model
                    model.module.save_checkpoint(model.module.save_path, epoch, step,
                                                 learning_rate, metric_dev_best)

                    # test
                    ppl_test_mean = 0.
                    for eval_set in eval_sets:
                        ppl_test = eval_ppl([model.module], eval_set, args.bptt)
                        logger.info(' PPL (%s): %.3f' % (eval_set.set, ppl_test))
                        ppl_test_mean += ppl_test
                    if len(eval_sets) > 0:
                        logger.info(' PPL (mean): %.3f' % (ppl_test_mean / len(eval_sets)))
                else:
                    # Update learning rate
                    model.module.optimizer, learning_rate = lr_controller.decay_lr(
                        optimizer=model.module.optimizer,
                        learning_rate=learning_rate,
                        epoch=epoch,
                        value=ppl_dev)

                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == args.not_improved_patient_epoch:
                    break

                if epoch == args.convert_to_sgd_epoch:
                    # Convert to fine-tuning stage
                    model.module.set_optimizer('sgd',
                                               learning_rate_init=float(args.learning_rate),  # TODO: ?
                                               weight_decay=float(args.weight_decay),
                                               clip_grad_norm=args.clip_grad_norm,
                                               lr_schedule=False,
                                               factor=args.decay_rate,
                                               patience_epoch=args.decay_patient_epoch)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))
            pbar_all.update(len(train_set))

            if epoch == args.num_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    tf_writer.close()
    pbar_epoch.close()
    pbar_all.close()

    return model.module.save_path
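# --- Illustration (not part of the repository) ------------------------------
# All three training scripts report perplexity as the exponential of the mean
# token-level cross-entropy (in nats), which is exactly what the
# `np.exp(loss)` / `math.exp(loss)` calls in the logging lines compute.
import math

def perplexity(mean_cross_entropy):
    """ppl = exp(mean negative log-likelihood per token)."""
    return math.exp(mean_cross_entropy)

# Example: a mean cross-entropy of 4.6 nats corresponds to a perplexity of ~99.5.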
def main():

    args = parse()

    # Load a conf file
    dir_name = os.path.dirname(args.recog_model[0])
    conf = load_config(os.path.join(dir_name, 'conf.yml'))

    # Overwrite conf
    for k, v in conf.items():
        if 'recog' not in k:
            setattr(args, k, v)

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    logger = set_logger(os.path.join(args.recog_dir, 'plot.log'), key='decoding')

    for i, s in enumerate(args.recog_sets):
        # Load dataset
        dataset = Dataset(corpus=args.corpus,
                          tsv_path=s,
                          dict_path=os.path.join(dir_name, 'dict.txt'),
                          wp_model=os.path.join(dir_name, 'wp.model'),
                          unit=args.unit,
                          batch_size=args.recog_batch_size,
                          bptt=args.bptt,
                          serialize=args.serialize,
                          is_test=True)

        if i == 0:
            # Load the LM
            if args.lm_type == 'gated_cnn':
                model = GatedConvLM(args)
            else:
                model = RNNLM(args)
            epoch = model.load_checkpoint(args.recog_model[0])['epoch']
            model.save_path = dir_name

            logger.info('epoch: %d' % (epoch - 1))
            logger.info('batch size: %d' % args.recog_batch_size)
            # logger.info('recog unit: %s' % args.recog_unit)
            # logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('BPTT: %d' % (args.bptt))
            logger.info('cache size: %d' % (args.recog_n_caches))
            logger.info('cache theta: %.3f' % (args.recog_cache_theta))
            logger.info('cache lambda: %.3f' % (args.recog_cache_lambda))
            model.cache_theta = args.recog_cache_theta
            model.cache_lambda = args.recog_cache_lambda

            # GPU setting
            model.cuda()

        assert args.recog_n_caches > 0
        save_path = mkdir_join(args.recog_dir, 'cache')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        if args.unit == 'word':
            idx2token = dataset.idx2word
        elif args.unit == 'wp':
            idx2token = dataset.idx2wp
        elif args.unit == 'char':
            idx2token = dataset.idx2char
        elif args.unit == 'phone':
            idx2token = dataset.idx2phone
        else:
            raise NotImplementedError(args.unit)

        hidden = None
        fig_count = 0
        token_count = 0
        n_tokens = args.recog_n_caches
        while True:
            ys, is_new_epoch = dataset.next()

            for t in range(ys.shape[1] - 1):
                loss, hidden = model(ys[:, t:t + 2], hidden,
                                     is_eval=True,
                                     n_caches=args.recog_n_caches)[:2]

                if len(model.cache_attn) > 0:
                    if token_count == n_tokens:
                        tokens_keys = idx2token(model.cache_ids[:args.recog_n_caches], return_list=True)
                        tokens_query = idx2token(model.cache_ids[-n_tokens:], return_list=True)

                        # Slide attention matrix
                        n_keys = len(tokens_keys)
                        n_queries = len(tokens_query)
                        cache_probs = np.zeros((n_keys, n_queries))  # `[n_keys, n_queries]`
                        mask = np.zeros((n_keys, n_queries))
                        for i, aw in enumerate(model.cache_attn[-n_tokens:]):
                            cache_probs[:(n_keys - n_queries + i + 1), i] = aw[0, -(n_keys - n_queries + i + 1):]
                            mask[(n_keys - n_queries + i + 1):, i] = 1

                        plot_cache_weights(cache_probs,
                                           keys=tokens_keys,
                                           queries=tokens_query,
                                           save_path=mkdir_join(save_path, str(fig_count) + '.png'),
                                           figsize=(40, 16),
                                           mask=mask)
                        token_count = 0
                        fig_count += 1
                    else:
                        token_count += 1

            if is_new_epoch:
                break
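# --- Illustration (not part of the repository) ------------------------------
# The plotting script sets `model.cache_theta` and `model.cache_lambda`, which
# in continuous-cache language models (after Grave et al., 2017) typically
# control the sharpness of the cache attention and the interpolation weight
# between the cache distribution and the LM softmax. A rough sketch of that
# mixture; every name below is a placeholder, not the repository's API.
import numpy as np

def cache_interpolate(p_model, h, cache_h, cache_ids, theta, lam, vocab_size):
    scores = np.exp(theta * cache_h.dot(h))      # attention over cached hidden states
    p_cache = np.zeros(vocab_size)
    np.add.at(p_cache, cache_ids, scores)        # scatter scores onto cached token ids
    p_cache /= p_cache.sum()
    return (1 - lam) * p_model + lam * p_cache   # mix cache and model distributions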