def main():
    """Train an ASR model and return the directory it was saved to.

    The model is ``Speech2Text`` or, when ``args.enc_type`` contains 'skip',
    ``SkipThought``.  Steps, in order:

    1. Read arguments with ``parse()``; when ``--resume`` is given, every key
       of the ``conf.yml`` next to the checkpoint (except ``resume``)
       overrides the corresponding argument.
    2. Derive the encoder subsampling factors for the main task and the
       optional sub-tasks.
    3. Build the train / dev / eval ``Dataset`` objects.
    4. Create the model, optimizer and ``LRScheduler`` (restoring them from
       the checkpoint when resuming), optionally copy parameters from a
       pre-trained model, and optionally load a teacher ASR model and a
       teacher LM that are passed to the model's forward call.
    5. Run the training loop, evaluating on the dev set from
       ``args.eval_start_epoch`` on, until early stopping or
       ``args.n_epochs`` epochs.

    Returns:
        save_path: directory holding ``conf.yml``, ``train.log`` and the
        checkpoints.
    """
    args = parse()
    # Independent copies that are later overwritten with the conf.yml of the
    # pre-trained model (args_pt) and of the teacher model (args_teacher).
    args_pt = copy.deepcopy(args)
    args_teacher = copy.deepcopy(args)

    # Load a conf file: when resuming, the saved configuration overrides
    # every argument except 'resume' itself.
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)
    # NOTE: vars() returns args.__dict__ itself (not a copy), so recog_params
    # also sees every attribute assigned to args later on (e.g. args.vocab).
    recog_params = vars(args)

    # Automatically reduce batch size in multi-GPU setting
    if args.n_gpus > 1:
        args.batch_size -= 10
        args.print_step //= args.n_gpus

    # Compute subsampling factor
    subsample_factor = 1
    subsample_factor_sub1 = 1
    subsample_factor_sub2 = 1
    subsample = [int(s) for s in args.subsample.split('_')]
    if args.conv_poolings and 'conv' in args.enc_type:
        # Only the first comma-separated number of each '_'-separated pooling
        # spec (with a leading '(' stripped) is multiplied in.
        for p in args.conv_poolings.split('_'):
            subsample_factor *= int(p.split(',')[0].replace('(', ''))
    else:
        subsample_factor = np.prod(subsample)
    if args.train_set_sub1:
        if args.conv_poolings and 'conv' in args.enc_type:
            subsample_factor_sub1 = subsample_factor * np.prod(
                subsample[:args.enc_n_layers_sub1 - 1])
        else:
            subsample_factor_sub1 = subsample_factor
    if args.train_set_sub2:
        if args.conv_poolings and 'conv' in args.enc_type:
            subsample_factor_sub2 = subsample_factor * np.prod(
                subsample[:args.enc_n_layers_sub2 - 1])
        else:
            subsample_factor_sub2 = subsample_factor

    skip_thought = 'skip' in args.enc_type

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        tsv_path_sub1=args.train_set_sub1,
                        tsv_path_sub2=args.train_set_sub2,
                        dict_path=args.dict,
                        dict_path_sub1=args.dict_sub1,
                        dict_path_sub2=args.dict_sub2,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        unit_sub1=args.unit_sub1,
                        unit_sub2=args.unit_sub2,
                        wp_model=args.wp_model,
                        wp_model_sub1=args.wp_model_sub1,
                        wp_model_sub2=args.wp_model_sub2,
                        batch_size=args.batch_size * args.n_gpus,
                        n_epochs=args.n_epochs,
                        min_n_frames=args.min_n_frames,
                        max_n_frames=args.max_n_frames,
                        sort_by='input',
                        short2long=True,
                        sort_stop_epoch=args.sort_stop_epoch,
                        dynamic_batching=args.dynamic_batching,
                        ctc=args.ctc_weight > 0,
                        ctc_sub1=args.ctc_weight_sub1 > 0,
                        ctc_sub2=args.ctc_weight_sub2 > 0,
                        subsample_factor=subsample_factor,
                        subsample_factor_sub1=subsample_factor_sub1,
                        subsample_factor_sub2=subsample_factor_sub2,
                        discourse_aware=args.discourse_aware,
                        skip_thought=skip_thought)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      tsv_path_sub1=args.dev_set_sub1,
                      tsv_path_sub2=args.dev_set_sub2,
                      dict_path=args.dict,
                      dict_path_sub1=args.dict_sub1,
                      dict_path_sub2=args.dict_sub2,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      unit_sub1=args.unit_sub1,
                      unit_sub2=args.unit_sub2,
                      wp_model=args.wp_model,
                      wp_model_sub1=args.wp_model_sub1,
                      wp_model_sub2=args.wp_model_sub2,
                      batch_size=args.batch_size * args.n_gpus,
                      min_n_frames=args.min_n_frames,
                      max_n_frames=args.max_n_frames,
                      ctc=args.ctc_weight > 0,
                      ctc_sub1=args.ctc_weight_sub1 > 0,
                      ctc_sub2=args.ctc_weight_sub2 > 0,
                      subsample_factor=subsample_factor,
                      subsample_factor_sub1=subsample_factor_sub1,
                      subsample_factor_sub2=subsample_factor_sub2,
                      discourse_aware=args.discourse_aware,
                      skip_thought=skip_thought)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [
            Dataset(corpus=args.corpus,
                    tsv_path=s,
                    dict_path=args.dict,
                    nlsyms=args.nlsyms,
                    unit=args.unit,
                    wp_model=args.wp_model,
                    batch_size=1,
                    discourse_aware=args.discourse_aware,
                    skip_thought=skip_thought,
                    is_test=True)
        ]

    # Vocabulary sizes and input dimension come from the training set.
    args.vocab = train_set.vocab
    args.vocab_sub1 = train_set.vocab_sub1
    args.vocab_sub2 = train_set.vocab_sub2
    args.input_dim = train_set.input_dim

    # Load a LM conf file for LM fusion & LM initialization
    if not args.resume and (args.lm_fusion or args.lm_init):
        if args.lm_fusion:
            lm_conf = load_config(
                os.path.join(os.path.dirname(args.lm_fusion), 'conf.yml'))
        elif args.lm_init:
            lm_conf = load_config(
                os.path.join(os.path.dirname(args.lm_init), 'conf.yml'))
        args.lm_conf = argparse.Namespace()
        for k, v in lm_conf.items():
            setattr(args.lm_conf, k, v)
        # The LM must share the ASR model's output unit and vocabulary size.
        assert args.unit == args.lm_conf.unit
        assert args.vocab == args.lm_conf.vocab

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_asr_model_name(args, subsample_factor)
        save_path = mkdir_join(
            args.model_save_dir,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    logger = set_logger(os.path.join(save_path, 'train.log'),
                        key='training',
                        stdout=args.stdout)

    # Model setting
    model = Speech2Text(args, save_path) if not skip_thought else SkipThought(
        args, save_path)

    if args.resume:
        # Set optimizer; the epoch number is the suffix after the last '-'
        # of the checkpoint path.
        epoch = int(args.resume.split('-')[-1])
        optimizer = set_optimizer(
            model,
            'sgd' if epoch > conf['convert_to_sgd_epoch'] else conf['optimizer'],
            conf['lr'], conf['weight_decay'])

        # Wrap optimizer by learning rate scheduler
        noam = 'transformer' in conf['enc_type'] or conf[
            'dec_type'] == 'transformer'
        optimizer = LRScheduler(
            optimizer,
            conf['lr'],
            decay_type=conf['lr_decay_type'],
            decay_start_epoch=conf['lr_decay_start_epoch'],
            decay_rate=conf['lr_decay_rate'],
            decay_patient_n_epochs=conf['lr_decay_patient_n_epochs'],
            early_stop_patient_n_epochs=conf['early_stop_patient_n_epochs'],
            warmup_start_lr=conf['warmup_start_lr'],
            warmup_n_steps=conf['warmup_n_steps'],
            model_size=conf['d_model'],
            factor=conf['lr_factor'],
            noam=noam)

        # Restore the last saved model
        model, optimizer = load_checkpoint(model, args.resume, optimizer,
                                           resume=True)

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch:
        # switch to SGD while keeping the epoch/step counters.
        if epoch == conf['convert_to_sgd_epoch']:
            n_epochs = optimizer.n_epochs
            n_steps = optimizer.n_steps
            optimizer = set_optimizer(model, 'sgd', args.lr,
                                      conf['weight_decay'])
            optimizer = LRScheduler(optimizer, args.lr,
                                    decay_type='always',
                                    decay_start_epoch=0,
                                    decay_rate=0.5)
            optimizer._epoch = n_epochs
            optimizer._step = n_steps
            logger.info('========== Convert to SGD ==========')
    else:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(save_path, 'conf.yml'))
        if args.lm_fusion:
            save_config(args.lm_conf, os.path.join(save_path, 'conf_lm.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
        for sub in ['', '_sub1', '_sub2']:
            if getattr(args, 'dict' + sub):
                shutil.copy(getattr(args, 'dict' + sub),
                            os.path.join(save_path, 'dict' + sub + '.txt'))
            if getattr(args, 'unit' + sub) == 'wp':
                shutil.copy(getattr(args, 'wp_model' + sub),
                            os.path.join(save_path, 'wp' + sub + '.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

        # Initialize with pre-trained model's parameters
        if args.pretrained_model and os.path.isfile(args.pretrained_model):
            # Load the ASR model
            conf_pt = load_config(
                os.path.join(os.path.dirname(args.pretrained_model),
                             'conf.yml'))
            for k, v in conf_pt.items():
                setattr(args_pt, k, v)
            model_pt = Speech2Text(args_pt)
            model_pt = load_checkpoint(model_pt, args.pretrained_model)[0]

            # Overwrite parameters whose name and shape match.  When the
            # encoder depth or output unit differs, or the pre-trained model
            # is CTC-only, only parameters with 'enc' in their name are
            # copied.
            only_enc = (args.enc_n_layers != args_pt.enc_n_layers) or (
                args.unit != args_pt.unit) or args_pt.ctc_weight == 1
            param_dict = dict(model_pt.named_parameters())
            for n, p in model.named_parameters():
                if n in param_dict.keys() and p.size() == param_dict[n].size():
                    if only_enc and 'enc' not in n:
                        continue
                    if args.lm_fusion_type == 'cache' and 'output' in n:
                        continue
                    p.data = param_dict[n].data
                    logger.info('Overwrite %s' % n)

        # Set optimizer
        optimizer = set_optimizer(model, args.optimizer, args.lr,
                                  args.weight_decay)

        # Wrap optimizer by learning rate scheduler
        noam = 'transformer' in args.enc_type or args.dec_type == 'transformer'
        optimizer = LRScheduler(
            optimizer,
            args.lr,
            decay_type=args.lr_decay_type,
            decay_start_epoch=args.lr_decay_start_epoch,
            decay_rate=args.lr_decay_rate,
            decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
            early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
            warmup_start_lr=args.warmup_start_lr,
            warmup_n_steps=args.warmup_n_steps,
            model_size=args.d_model,
            factor=args.lr_factor,
            noam=noam)

    # Load the teacher ASR model
    teacher = None
    if args.teacher and os.path.isfile(args.teacher):
        conf_teacher = load_config(
            os.path.join(os.path.dirname(args.teacher), 'conf.yml'))
        for k, v in conf_teacher.items():
            setattr(args_teacher, k, v)
        # Setting for knowledge distillation
        args_teacher.ss_prob = 0
        # NOTE(review): the student model and conf.yml were already built
        # from args above, so this only has an effect if the model reads
        # args.lsm_prob lazily — confirm.
        args.lsm_prob = 0
        teacher = Speech2Text(args_teacher)
        teacher = load_checkpoint(teacher, args.teacher)[0]

    # Load the teacher LM
    teacher_lm = None
    if args.teacher_lm and os.path.isfile(args.teacher_lm):
        conf_lm = load_config(
            os.path.join(os.path.dirname(args.teacher_lm), 'conf.yml'))
        args_lm = argparse.Namespace()
        for k, v in conf_lm.items():
            setattr(args_lm, k, v)
        teacher_lm = build_lm(args_lm)
        teacher_lm = load_checkpoint(teacher_lm, args.teacher_lm)[0]

    # GPU setting
    if args.n_gpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus, 1)),
                                   deterministic=False,
                                   benchmark=True)
        model.cuda()
        if teacher is not None:
            teacher.cuda()
        if teacher_lm is not None:
            teacher_lm.cuda()

    # Set process name
    logger.info('PID: %s' % os.getpid())
    # NOTE: os.uname()[1] is the host (node) name, not the user name.
    logger.info('USERNAME: %s' % os.uname()[1])
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(save_path, tensorboard=True)

    if args.mtl_per_batch:
        # NOTE: from easier to harder tasks (later prepends run first)
        tasks = []
        if 1 - args.bwd_weight - args.ctc_weight - args.sub1_weight - args.sub2_weight > 0:
            tasks += ['ys']
        if args.bwd_weight > 0:
            tasks = ['ys.bwd'] + tasks
        if args.ctc_weight > 0:
            tasks = ['ys.ctc'] + tasks
        if args.lmobj_weight > 0:
            tasks = ['ys.lmobj'] + tasks
        for sub in ['sub1', 'sub2']:
            if getattr(args, 'train_set_' + sub):
                if getattr(args, sub + '_weight') - getattr(
                        args, 'ctc_weight_' + sub) > 0:
                    tasks = ['ys_' + sub] + tasks
                if getattr(args, 'ctc_weight_' + sub) > 0:
                    tasks = ['ys_' + sub + '.ctc'] + tasks
    else:
        tasks = ['all']

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    pbar_epoch = tqdm(total=len(train_set))
    accum_n_tokens = 0
    while True:
        # Compute loss in the training set
        batch_train, is_new_epoch = train_set.next()
        accum_n_tokens += sum([len(y) for y in batch_train['ys']])

        # Change mini-batch depending on task
        for task in tasks:
            if skip_thought:
                loss, reporter = model(batch_train['ys'],
                                       ys_prev=batch_train['ys_prev'],
                                       ys_next=batch_train['ys_next'],
                                       reporter=reporter)
            else:
                loss, reporter = model(batch_train,
                                       reporter=reporter,
                                       task=task,
                                       teacher=teacher,
                                       teacher_lm=teacher_lm)
            # loss /= args.accum_grad_n_steps
            # With several devices the loss is presumably one value per
            # device, hence the explicit gradient of ones.
            if len(model.device_ids) > 1:
                loss.backward(torch.ones(len(model.device_ids)))
            else:
                loss.backward()
            # NOTE: detach() returns a new tensor; its result is discarded.
            loss.detach()  # Truncate the graph
            # Step only once enough tokens were accumulated (or every batch
            # when accum_grad_n_tokens == 0).
            if args.accum_grad_n_tokens == 0 or accum_n_tokens >= args.accum_grad_n_tokens:
                if args.clip_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(model.module.parameters(),
                                                   args.clip_grad_norm)
                optimizer.step()
                optimizer.zero_grad()
                accum_n_tokens = 0
            loss_train = loss.item()
            del loss
        reporter.step()

        if optimizer.n_steps % args.print_step == 0:
            # Compute loss in the dev set
            batch_dev = dev_set.next()[0]
            # Change mini-batch depending on task
            for task in tasks:
                if skip_thought:
                    loss, reporter = model(batch_dev['ys'],
                                           ys_prev=batch_dev['ys_prev'],
                                           ys_next=batch_dev['ys_next'],
                                           reporter=reporter,
                                           is_eval=True)
                else:
                    loss, reporter = model(batch_dev,
                                           reporter=reporter,
                                           task=task,
                                           is_eval=True)
                loss_dev = loss.item()
                del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            if args.input_type == 'speech':
                xlen = max(len(x) for x in batch_train['xs'])
                ylen = max(len(y) for y in batch_train['ys'])
            elif args.input_type == 'text':
                xlen = max(len(x) for x in batch_train['ys'])
                ylen = max(len(y) for y in batch_train['ys_sub1'])
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.5f/bs:%d/xlen:%d/ylen:%d (%.2f min)"
                % (optimizer.n_steps,
                   optimizer.n_epochs + train_set.epoch_detail, loss_train,
                   loss_dev, optimizer.lr, len(batch_train['utt_ids']), xlen,
                   ylen, duration_step / 60))
            start_time_step = time.time()
        pbar_epoch.update(len(batch_train['utt_ids']))

        # Save figures of loss and accuracy
        if optimizer.n_steps % (args.print_step * 10) == 0:
            reporter.snapshot()
            model.module.plot_attention()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (optimizer.n_epochs + 1, duration_epoch / 60))

            if optimizer.n_epochs + 1 < args.eval_start_epoch:
                optimizer.epoch()
                reporter.epoch()

                # Save the model
                save_checkpoint(model, save_path, optimizer,
                                optimizer.n_epochs,
                                remove_old_checkpoints=not noam)
            else:
                start_time_eval = time.time()
                # dev
                metric_dev = eval_epoch([model.module], dev_set,
                                        recog_params, args,
                                        optimizer.n_epochs + 1, logger)
                optimizer.epoch(metric_dev)
                reporter.epoch(metric_dev)

                if optimizer.is_best:
                    # Save the model
                    save_checkpoint(model, save_path, optimizer,
                                    optimizer.n_epochs,
                                    remove_old_checkpoints=not noam)

                    # test
                    for eval_set in eval_sets:
                        eval_epoch([model.module], eval_set, recog_params,
                                   args, optimizer.n_epochs, logger)

                    # start scheduled sampling
                    if args.ss_prob > 0:
                        model.module.scheduled_sampling_trigger()

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' %
                            (duration_eval / 60))

                # Early stopping
                if optimizer.is_early_stop:
                    break

                # Convert to fine-tuning stage
                if optimizer.n_epochs == args.convert_to_sgd_epoch:
                    n_epochs = optimizer.n_epochs
                    n_steps = optimizer.n_steps
                    optimizer = set_optimizer(model, 'sgd', args.lr,
                                              args.weight_decay)
                    optimizer = LRScheduler(optimizer, args.lr,
                                            decay_type='always',
                                            decay_start_epoch=0,
                                            decay_rate=0.5)
                    optimizer._epoch = n_epochs
                    optimizer._step = n_steps
                    logger.info('========== Convert to SGD ==========')

            # Start a fresh progress bar for the next epoch.  Close the
            # finished one first; previously every per-epoch bar except the
            # last one was left open.
            pbar_epoch.close()
            pbar_epoch = tqdm(total=len(train_set))

            if optimizer.n_epochs == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    if reporter.tensorboard:
        reporter.tf_writer.close()
    pbar_epoch.close()

    return save_path
def main():
    """Train an ASR model with Horovod data parallelism; return the save dir.

    Every process runs this function.  ``hvd.init()`` is called first and the
    process binds itself to GPU ``hvd.local_rank()``.  Logging, writing
    ``conf.yml`` and saving checkpoints are restricted to rank 0.  The dev
    metric is combined across ranks with ``hvd.allreduce``.

    NOTE(review): ``logger`` is only created on rank 0, so every use of it
    must stay behind a ``hvd_rank == 0`` check.

    Returns:
        save_path: directory used for ``train.log``, ``conf.yml`` and the
        checkpoints.  On ranks other than 0 this is the path before
        ``set_save_path`` was applied.
    """
    args = parse()
    hvd.init()
    # Pin this process to the GPU matching its local rank.
    torch.cuda.set_device(hvd.local_rank())
    hvd_rank = hvd.rank()

    # Load a conf file: when resuming, the saved configuration overrides
    # every argument except 'resume' itself.
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)
    # NOTE: vars() returns args.__dict__ itself, so recog_params also sees
    # attributes assigned to args later on (e.g. args.vocab).
    recog_params = vars(args)

    # Compute subsampling factor
    subsample_factor = 1
    subsample = [int(s) for s in args.subsample.split('_')]
    if args.conv_poolings and 'conv' in args.enc_type:
        for p in args.conv_poolings.split('_'):
            subsample_factor *= int(p.split(',')[0].replace('(', ''))
    else:
        subsample_factor = np.prod(subsample)

    skip_thought = 'skip' in args.enc_type

    # Passed to hvd.DistributedOptimizer as backward_passes_per_step below.
    # NOTE(review): Horovod treats that value as the number of local backward
    # passes accumulated before gradients are allreduced; tying it to the
    # batch size looks unintended — confirm.
    batch_per_allreduce = args.batch_size

    # Load dataset (no sorting; batching across ranks is handled by the
    # SeqDataloader objects below)
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        tsv_path_sub1=args.train_set_sub1,
                        tsv_path_sub2=args.train_set_sub2,
                        dict_path=args.dict,
                        dict_path_sub1=args.dict_sub1,
                        dict_path_sub2=args.dict_sub2,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        unit_sub1=args.unit_sub1,
                        unit_sub2=args.unit_sub2,
                        wp_model=args.wp_model,
                        wp_model_sub1=args.wp_model_sub1,
                        wp_model_sub2=args.wp_model_sub2,
                        batch_size=args.batch_size,
                        n_epochs=args.n_epochs,
                        min_n_frames=args.min_n_frames,
                        max_n_frames=args.max_n_frames,
                        sort_by='no_sort',
                        short2long=True,
                        sort_stop_epoch=args.sort_stop_epoch,
                        dynamic_batching=args.dynamic_batching,
                        ctc=args.ctc_weight > 0,
                        ctc_sub1=args.ctc_weight_sub1 > 0,
                        ctc_sub2=args.ctc_weight_sub2 > 0,
                        subsample_factor=subsample_factor,
                        discourse_aware=args.discourse_aware,
                        skip_thought=skip_thought)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      tsv_path_sub1=args.dev_set_sub1,
                      tsv_path_sub2=args.dev_set_sub2,
                      dict_path=args.dict,
                      dict_path_sub1=args.dict_sub1,
                      dict_path_sub2=args.dict_sub2,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      unit_sub1=args.unit_sub1,
                      unit_sub2=args.unit_sub2,
                      wp_model=args.wp_model,
                      wp_model_sub1=args.wp_model_sub1,
                      wp_model_sub2=args.wp_model_sub2,
                      batch_size=args.batch_size,
                      min_n_frames=args.min_n_frames,
                      max_n_frames=args.max_n_frames,
                      ctc=args.ctc_weight > 0,
                      ctc_sub1=args.ctc_weight_sub1 > 0,
                      ctc_sub2=args.ctc_weight_sub2 > 0,
                      subsample_factor=subsample_factor,
                      discourse_aware=args.discourse_aware,
                      skip_thought=skip_thought)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [
            Dataset(corpus=args.corpus,
                    tsv_path=s,
                    dict_path=args.dict,
                    nlsyms=args.nlsyms,
                    unit=args.unit,
                    wp_model=args.wp_model,
                    batch_size=1,
                    discourse_aware=args.discourse_aware,
                    skip_thought=skip_thought,
                    is_test=True)
        ]

    # Vocabulary sizes and input dimension come from the training set.
    args.vocab = train_set.vocab
    args.vocab_sub1 = train_set.vocab_sub1
    args.vocab_sub2 = train_set.vocab_sub2
    args.input_dim = train_set.input_dim

    # Horovod: use DistributedSampler to partition data among workers. Manually specify
    # `num_replicas=hvd.size()` and `rank=hvd.rank()`.
    # NOTE(review): the partitioning itself happens inside SeqDataloader
    # (distributed=True); its internals are not visible here.
    train_loader = SeqDataloader(train_set,
                                 batch_size=args.batch_size,
                                 num_workers=1,
                                 distributed=True,
                                 num_stacks=args.n_stacks,
                                 num_splices=args.n_splices,
                                 num_skips=args.n_skips,
                                 pin_memory=False,
                                 shuffle=False)
    val_loader = SeqDataloader(dev_set,
                               batch_size=args.batch_size,
                               num_workers=1,
                               distributed=True,
                               num_stacks=args.n_stacks,
                               num_splices=args.n_splices,
                               num_skips=args.n_skips,
                               pin_memory=False,
                               shuffle=False)

    # Load a LM conf file for LM fusion & LM initialization
    if not args.resume and (args.lm_fusion or args.lm_init):
        if args.lm_fusion:
            lm_conf = load_config(
                os.path.join(os.path.dirname(args.lm_fusion), 'conf.yml'))
        elif args.lm_init:
            lm_conf = load_config(
                os.path.join(os.path.dirname(args.lm_init), 'conf.yml'))
        args.lm_conf = argparse.Namespace()
        for k, v in lm_conf.items():
            setattr(args.lm_conf, k, v)
        # The LM must share the ASR model's output unit and vocabulary size.
        assert args.unit == args.lm_conf.unit
        assert args.vocab == args.lm_conf.vocab

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_asr_model_name(args, subsample_factor)
        save_path = mkdir_join(
            args.model_save_dir,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        # NOTE(review): only rank 0 applies set_save_path; if it changes the
        # path, other ranks keep the original one — confirm they agree.
        if hvd_rank == 0:
            save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    if hvd_rank == 0:
        logger = set_logger(os.path.join(save_path, 'train.log'),
                            key='training',
                            stdout=args.stdout)
        # Set process name
        logger.info('PID: %s' % os.getpid())
        # NOTE: os.uname()[1] is the host (node) name, not the user name.
        logger.info('USERNAME: %s' % os.uname()[1])
        logger.info('NUMBER_DEVICES: %s' % hvd.size())
        setproctitle(args.job_name if args.job_name else dir_name)

    # Model setting
    model = Speech2Text(args, save_path)

    # GPU setting
    if args.n_gpus >= 1:
        torch.backends.cudnn.benchmark = True
        model.cuda()

    if args.resume:
        # Set optimizer
        # NOTE(review): 'epochs' is unused.  The optimizer is always rebuilt
        # as SGD, its state is not restored (no optimizer is passed to
        # load_checkpoint), and the scheduler's epoch/step counters start
        # from scratch — confirm this is intended.
        epochs = int(args.resume.split('-')[-1])
        #optimizer = set_optimizer(model, 'sgd' if epochs >= conf['convert_to_sgd_epoch'] else conf['optimizer'],
        model, _ = load_checkpoint(model, args.resume, resume=True)
        optimizer = set_optimizer(model, 'sgd', conf['lr'],
                                  conf['weight_decay'])
        # Wrap for distributed gradient averaging, then sync rank 0's
        # parameters and optimizer state to all ranks.
        # NOTE(review): unlike the fresh-start path, no compression or
        # backward_passes_per_step is given here.
        optimizer = hvd.DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters())
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        # Wrap optimizer by learning rate scheduler
        noam = 'transformer' in args.enc_type or args.dec_type == 'transformer'
        optimizer = LRScheduler(
            optimizer,
            args.lr,
            decay_type=args.lr_decay_type,
            decay_start_epoch=args.lr_decay_start_epoch,
            decay_rate=args.lr_decay_rate,
            decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
            early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
            warmup_start_lr=args.warmup_start_lr,
            warmup_n_steps=args.warmup_n_steps,
            model_size=args.d_model,
            factor=args.lr_factor,
            noam=noam)
    else:
        # Save the conf file as a yaml file
        if hvd_rank == 0:
            save_config(vars(args), os.path.join(save_path, 'conf.yml'))
            if args.lm_fusion:
                save_config(args.lm_conf,
                            os.path.join(save_path, 'conf_lm.yml'))

        if hvd_rank == 0:
            for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
                logger.info('%s: %s' % (k, str(v)))

            # Count total parameters
            for n in sorted(list(model.num_params_dict.keys())):
                n_params = model.num_params_dict[n]
                logger.info("%s %d" % (n, n_params))
            logger.info("Total %.2f M parameters" %
                        (model.total_parameters / 1000000))
            logger.info(model)

        # Set optimizer
        optimizer = set_optimizer(model, args.optimizer, args.lr,
                                  args.weight_decay)
        # Wrap for distributed gradient averaging, then sync rank 0's
        # parameters and optimizer state to all ranks.
        optimizer = hvd.DistributedOptimizer(
            optimizer,
            named_parameters=model.named_parameters(),
            compression=hvd.Compression.none,
            backward_passes_per_step=batch_per_allreduce)
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        # Wrap optimizer by learning rate scheduler
        noam = 'transformer' in args.enc_type or args.dec_type == 'transformer'
        optimizer = LRScheduler(
            optimizer,
            args.lr,
            decay_type=args.lr_decay_type,
            decay_start_epoch=args.lr_decay_start_epoch,
            decay_rate=args.lr_decay_rate,
            decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
            early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
            warmup_start_lr=args.warmup_start_lr,
            warmup_n_steps=args.warmup_n_steps,
            model_size=args.d_model,
            factor=args.lr_factor,
            noam=noam)

    # Set reporter
    # NOTE(review): created on every rank with the same save_path.
    reporter = Reporter(save_path)

    if args.mtl_per_batch:
        # NOTE: from easier to harder tasks (later prepends run first)
        tasks = []
        if 1 - args.bwd_weight - args.ctc_weight - args.sub1_weight - args.sub2_weight > 0:
            tasks += ['ys']
        if args.bwd_weight > 0:
            tasks = ['ys.bwd'] + tasks
        if args.ctc_weight > 0:
            tasks = ['ys.ctc'] + tasks
    else:
        tasks = ['all']

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    accum_n_tokens = 0
    # Only rank 0 shows a progress bar.
    verbose = 1 if hvd_rank == 0 else 0
    data_size = len(train_set)
    while True:
        model.train()
        with tqdm(total=data_size // hvd.size(),
                  desc='Train Epoch #{}'.format(optimizer.n_epochs + 1),
                  disable=not verbose) as pbar_epoch:
            # Compute loss in the training set
            for _, batch_train in enumerate(train_loader):
                accum_n_tokens += sum([len(y) for y in batch_train['ys']])
                # Change mini-batch depending on task
                for task in tasks:
                    if skip_thought:
                        loss, reporter = model(batch_train['ys'],
                                               ys_prev=batch_train['ys_prev'],
                                               ys_next=batch_train['ys_next'],
                                               reporter=reporter)
                    else:
                        loss, reporter = model(batch_train, reporter, task)
                    loss.backward()
                    # NOTE: detach() returns a new tensor; its result is
                    # discarded.
                    loss.detach()  # Truncate the graph
                    # Step only once enough tokens were accumulated (or
                    # every batch when accum_grad_n_tokens == 0).
                    if args.accum_grad_n_tokens == 0 or accum_n_tokens >= args.accum_grad_n_tokens:
                        if args.clip_grad_norm > 0:
                            # total_norm is not used afterwards.
                            total_norm = torch.nn.utils.clip_grad_norm_(
                                model.parameters(), args.clip_grad_norm)
                        optimizer.step()
                        optimizer.zero_grad()
                        accum_n_tokens = 0
                    loss_train = loss.item()
                    del loss

                if optimizer.n_steps % args.print_step == 0:
                    # Compute loss in the dev set
                    # NOTE(review): model.train() is only called at the top of
                    # each epoch, so unless the model's forward switches
                    # modes itself, it stays in eval mode after this point
                    # for the rest of the epoch — confirm.
                    model.eval()
                    batch_dev = dev_set.next()[0]
                    # Change mini-batch depending on task
                    for task in tasks:
                        if skip_thought:
                            loss, reporter = model(
                                batch_dev['ys'],
                                ys_prev=batch_dev['ys_prev'],
                                ys_next=batch_dev['ys_next'],
                                reporter=reporter,
                                is_eval=True)
                        else:
                            loss, reporter = model(batch_dev, reporter, task,
                                                   is_eval=True)
                        loss_dev = loss.item()
                        del loss

                    duration_step = time.time() - start_time_step
                    if args.input_type == 'speech':
                        xlen = max(len(x) for x in batch_train['xs'])
                        ylen = max(len(y) for y in batch_train['ys'])
                    elif args.input_type == 'text':
                        xlen = max(len(x) for x in batch_train['ys'])
                        ylen = max(len(y) for y in batch_train['ys_sub1'])
                    if hvd_rank == 0:
                        # The "ep" field estimates fractional epochs from
                        # steps * batch size over this rank's share of data.
                        logger.info(
                            "step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.5f/bs:%d/xlen:%d/ylen:%d (%.2f min)"
                            % (optimizer.n_steps, optimizer.n_steps *
                               args.batch_size / (data_size / hvd.size()),
                               loss_train, loss_dev, optimizer.lr,
                               len(batch_train['utt_ids']), xlen, ylen,
                               duration_step / 60))
                    start_time_step = time.time()
                pbar_epoch.update(len(batch_train['utt_ids']))

                # Save figures of loss and accuracy
                if optimizer.n_steps % (args.print_step * 10) == 0 and hvd.rank() == 0:
                    model.plot_attention()
                    start_time_step = time.time()

        # reset dev set
        dev_set.reset()

        # Save checkpoint and evaluate model per epoch
        duration_epoch = time.time() - start_time_epoch
        if hvd_rank == 0:
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (optimizer.n_epochs + 1, duration_epoch / 60))

        if optimizer.n_epochs + 1 < args.eval_start_epoch:
            optimizer.epoch()
            if hvd_rank == 0:
                save_checkpoint(model, save_path, optimizer,
                                optimizer.n_epochs,
                                remove_old_checkpoints=not noam)
        else:
            start_time_eval = time.time()
            # dev
            metric_dev = eval_epoch([model], val_loader, recog_params, args,
                                    optimizer.n_epochs + 1)
            # Combine the per-rank dev metric (Horovod's allreduce averages
            # by default) so every rank feeds the same value to the scheduler.
            metric_dev = hvd.allreduce(
                np2tensor(np.array([metric_dev], dtype=float),
                          hvd.local_rank()))
            loss_dev = metric_dev.item()
            if hvd_rank == 0:
                logger.info('Loss : %.2f %%' % (loss_dev))
            optimizer.epoch(loss_dev)
            if hvd.rank() == 0:
                save_checkpoint(model, save_path, optimizer,
                                optimizer.n_epochs,
                                remove_old_checkpoints=False)
            # Roll back to the best epoch's checkpoint when this epoch did
            # not improve.
            # NOTE(review): every rank reads the file rank 0 wrote, which
            # relies on all ranks sharing the same save_path.
            if not optimizer.is_best:
                model, _ = load_checkpoint(
                    model,
                    save_path + '/model.epoch-' + str(optimizer.best_epochs))
            duration_eval = time.time() - start_time_eval
            if hvd_rank == 0:
                logger.info('Evaluation time: %.2f min' %
                            (duration_eval / 60))

            # Early stopping
            if optimizer.is_early_stop:
                break

            # Convert to fine-tuning stage: rebuild as distributed SGD while
            # keeping the epoch/step counters.
            if optimizer.n_epochs == args.convert_to_sgd_epoch:
                n_epochs = optimizer.n_epochs
                n_steps = optimizer.n_steps
                optimizer = set_optimizer(model, 'sgd', args.lr,
                                          args.weight_decay)
                optimizer = hvd.DistributedOptimizer(
                    optimizer,
                    named_parameters=model.named_parameters(),
                    compression=hvd.Compression.none,
                    backward_passes_per_step=batch_per_allreduce)
                hvd.broadcast_parameters(model.state_dict(), root_rank=0)
                hvd.broadcast_optimizer_state(optimizer, root_rank=0)
                optimizer = LRScheduler(
                    optimizer,
                    args.lr,
                    decay_type=args.lr_decay_type,
                    decay_start_epoch=args.lr_decay_start_epoch,
                    decay_rate=args.lr_decay_rate,
                    decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
                    early_stop_patient_n_epochs=args.
                    early_stop_patient_n_epochs,
                    warmup_start_lr=args.warmup_start_lr,
                    warmup_n_steps=args.warmup_n_steps,
                    model_size=args.d_model,
                    factor=args.lr_factor,
                    noam=noam)
                optimizer._epoch = n_epochs
                optimizer._step = n_steps
                if hvd_rank == 0:
                    logger.info('========== Convert to SGD ==========')
        if optimizer.n_epochs == args.n_epochs:
            break

        start_time_step = time.time()
        start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    if hvd_rank == 0:
        logger.info('Total time: %.2f hour' % (duration_train / 3600))

    return save_path
def main():
    """Train a language model and return the directory it was saved to.

    Steps, in order:

    1. Read arguments with ``parse()``; when ``--resume`` is given, every key
       of the ``conf.yml`` next to the checkpoint (except ``resume``)
       overrides the corresponding argument.
    2. Set up the save directory and logger, then build the train / dev /
       eval ``Dataset`` objects.
    3. Build the LM with ``build_lm`` plus its optimizer and
       ``LRScheduler``, restoring them from the checkpoint when resuming.
    4. Train with truncated back-propagation: the recurrent state is carried
       across batches and repackaged after each one.  Evaluate dev (and eval)
       perplexity from ``args.eval_start_epoch`` on, until early stopping or
       ``args.n_epochs`` epochs.

    Gradients are accumulated over batches until at least
    ``args.accum_grad_n_tokens`` tokens were seen (every batch when it is 0),
    and only then is an optimizer step taken.

    Returns:
        save_path: directory holding ``conf.yml``, ``train.log`` and the
        checkpoints.
    """
    args = parse()

    # Load a conf file: when resuming, the saved configuration overrides
    # every argument except 'resume' itself.
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_lm_name(args)
        save_path = mkdir_join(
            args.model_save_dir,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    logger = set_logger(os.path.join(save_path, 'train.log'),
                        key='training',
                        stdout=args.stdout)

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=args.batch_size * args.n_gpus,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=args.batch_size * args.n_gpus,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [
            Dataset(corpus=args.corpus,
                    tsv_path=s,
                    dict_path=args.dict,
                    nlsyms=args.nlsyms,
                    unit=args.unit,
                    wp_model=args.wp_model,
                    batch_size=1,
                    bptt=args.bptt,
                    backward=args.backward,
                    serialize=args.serialize)
        ]

    # The vocabulary size comes from the training set.
    args.vocab = train_set.vocab

    # Model setting
    model = build_lm(args, save_path)

    if args.resume:
        # Set optimizer; the epoch number is the suffix after the last '-'
        # of the checkpoint path.
        epoch = int(args.resume.split('-')[-1])
        optimizer = set_optimizer(
            model,
            'sgd' if epoch > conf['convert_to_sgd_epoch'] else conf['optimizer'],
            conf['lr'], conf['weight_decay'])

        # Wrap optimizer by learning rate scheduler
        optimizer = LRScheduler(
            optimizer,
            conf['lr'],
            decay_type=conf['lr_decay_type'],
            decay_start_epoch=conf['lr_decay_start_epoch'],
            decay_rate=conf['lr_decay_rate'],
            decay_patient_n_epochs=conf['lr_decay_patient_n_epochs'],
            early_stop_patient_n_epochs=conf['early_stop_patient_n_epochs'],
            warmup_start_lr=conf['warmup_start_lr'],
            warmup_n_steps=conf['warmup_n_steps'],
            model_size=conf['d_model'],
            factor=conf['lr_factor'],
            noam=conf['lm_type'] == 'transformer')

        # Restore the last saved model
        model, optimizer = load_checkpoint(model, args.resume, optimizer,
                                           resume=True)

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch:
        # switch to SGD while keeping the epoch/step counters.
        if epoch == conf['convert_to_sgd_epoch']:
            n_epochs = optimizer.n_epochs
            n_steps = optimizer.n_steps
            optimizer = set_optimizer(model, 'sgd', args.lr,
                                      conf['weight_decay'])
            optimizer = LRScheduler(optimizer, args.lr,
                                    decay_type='always',
                                    decay_start_epoch=0,
                                    decay_rate=0.5)
            optimizer._epoch = n_epochs
            optimizer._step = n_steps
            logger.info('========== Convert to SGD ==========')
    else:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(save_path, 'conf.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

        # Set optimizer
        optimizer = set_optimizer(model, args.optimizer, args.lr,
                                  args.weight_decay)

        # Wrap optimizer by learning rate scheduler
        optimizer = LRScheduler(
            optimizer,
            args.lr,
            decay_type=args.lr_decay_type,
            decay_start_epoch=args.lr_decay_start_epoch,
            decay_rate=args.lr_decay_rate,
            decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
            early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
            warmup_start_lr=args.warmup_start_lr,
            warmup_n_steps=args.warmup_n_steps,
            model_size=args.d_model,
            factor=args.lr_factor,
            noam=args.lm_type == 'transformer')

    # GPU setting
    if args.n_gpus >= 1:
        torch.backends.cudnn.benchmark = True
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus)))
        model.cuda()

    # Set process name
    logger.info('PID: %s' % os.getpid())
    # NOTE: os.uname()[1] is the host (node) name, not the user name.
    logger.info('USERNAME: %s' % os.uname()[1])
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(save_path)

    # Recurrent state carried across training batches (truncated BPTT).
    hidden = None
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    pbar_epoch = tqdm(total=len(train_set))
    accum_n_tokens = 0
    while True:
        # Compute loss in the training set
        ys_train, is_new_epoch = train_set.next()
        accum_n_tokens += sum([len(y) for y in ys_train])
        # NOTE: gradients are NOT zeroed here.  They must survive across
        # batches until accum_grad_n_tokens tokens were accumulated; the
        # zero_grad() right after optimizer.step() below resets them.
        # (Zeroing before every forward pass discarded all but the last
        # batch's gradient whenever accumulation was enabled.)
        loss, hidden, reporter = model(ys_train, hidden, reporter)
        loss.backward()
        # NOTE: detach() returns a new tensor; its result is discarded.
        loss.detach()  # Truncate the graph
        if args.accum_grad_n_tokens == 0 or accum_n_tokens >= args.accum_grad_n_tokens:
            if args.clip_grad_norm > 0:
                total_norm = torch.nn.utils.clip_grad_norm_(
                    model.module.parameters(), args.clip_grad_norm)
                reporter.add_tensorboard_scalar('total_norm', total_norm)
            optimizer.step()
            optimizer.zero_grad()
            accum_n_tokens = 0
        loss_train = loss.item()
        del loss
        # Cut the recurrent state off the graph before the next batch.
        hidden = model.module.repackage_state(hidden)
        reporter.add_tensorboard_scalar('learning_rate', optimizer.lr)
        # NOTE: loss/acc/ppl are already added in the model
        reporter.step()

        if optimizer.n_steps % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next()[0]
            loss, _, reporter = model(ys_dev, None, reporter, is_eval=True)
            loss_dev = loss.item()
            del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/ppl:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)"
                % (optimizer.n_steps,
                   optimizer.n_epochs + train_set.epoch_detail, loss_train,
                   loss_dev, np.exp(loss_train), np.exp(loss_dev),
                   optimizer.lr, ys_train.shape[0], duration_step / 60))
            start_time_step = time.time()
        pbar_epoch.update(ys_train.shape[0] * (ys_train.shape[1] - 1))

        # Save figures of loss and accuracy
        if optimizer.n_steps % (args.print_step * 10) == 0:
            reporter.snapshot()
            if args.lm_type == 'transformer':
                model.module.plot_attention()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (optimizer.n_epochs + 1, duration_epoch / 60))

            if optimizer.n_epochs + 1 < args.eval_start_epoch:
                optimizer.epoch()  # lr decay
                reporter.epoch()  # plot

                # Save the model
                save_checkpoint(
                    model, save_path, optimizer, optimizer.n_epochs,
                    remove_old_checkpoints=args.lm_type != 'transformer')
            else:
                start_time_eval = time.time()
                # dev
                ppl_dev, _ = eval_ppl([model.module], dev_set,
                                      batch_size=1,
                                      bptt=args.bptt)
                logger.info('PPL (%s, epoch:%d): %.2f' %
                            (dev_set.set, optimizer.n_epochs, ppl_dev))
                optimizer.epoch(ppl_dev)  # lr decay
                reporter.epoch(ppl_dev, name='perplexity')  # plot

                if optimizer.is_best:
                    # Save the model
                    save_checkpoint(
                        model, save_path, optimizer, optimizer.n_epochs,
                        remove_old_checkpoints=args.lm_type != 'transformer')

                    # test
                    ppl_test_avg = 0.
                    for eval_set in eval_sets:
                        ppl_test, _ = eval_ppl([model.module], eval_set,
                                               batch_size=1,
                                               bptt=args.bptt)
                        logger.info(
                            'PPL (%s, epoch:%d): %.2f' %
                            (eval_set.set, optimizer.n_epochs, ppl_test))
                        ppl_test_avg += ppl_test
                    if len(eval_sets) > 0:
                        logger.info('PPL (avg., epoch:%d): %.2f' %
                                    (optimizer.n_epochs,
                                     ppl_test_avg / len(eval_sets)))

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' %
                            (duration_eval / 60))

                # Early stopping
                if optimizer.is_early_stop:
                    break

                # Convert to fine-tuning stage
                if optimizer.n_epochs == args.convert_to_sgd_epoch:
                    n_epochs = optimizer.n_epochs
                    n_steps = optimizer.n_steps
                    optimizer = set_optimizer(model, 'sgd', args.lr,
                                              args.weight_decay)
                    optimizer = LRScheduler(optimizer, args.lr,
                                            decay_type='always',
                                            decay_start_epoch=0,
                                            decay_rate=0.5)
                    optimizer._epoch = n_epochs
                    optimizer._step = n_steps
                    logger.info('========== Convert to SGD ==========')

            # Start a fresh progress bar for the next epoch.  Close the
            # finished one first; previously every per-epoch bar except the
            # last one was left open.
            pbar_epoch.close()
            pbar_epoch = tqdm(total=len(train_set))

            if optimizer.n_epochs == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    # Close the TensorBoard writer only if one exists (the ASR trainer
    # guards this too); an unconditional close fails when it is absent.
    if getattr(reporter, 'tf_writer', None) is not None:
        reporter.tf_writer.close()
    pbar_epoch.close()

    return save_path
def _convert_to_sgd(model, lr, weight_decay, n_epochs, n_steps):
    """Replace the current optimizer with plain SGD for the fine-tuning stage.

    The bare torch optimizer is wrapped by ``hvd.DistributedOptimizer`` and
    synchronized from rank 0 *before* being wrapped by ``LRScheduler``, so
    Horovod always receives a real torch optimizer. The epoch/step counters
    of the previous scheduler are carried over to the new one.

    Args:
        model: the language model being trained
        lr (float): learning rate for SGD
        weight_decay (float): weight decay for SGD
        n_epochs (int): epoch counter to restore on the new scheduler
        n_steps (int): step counter to restore on the new scheduler
    Returns:
        optimizer: ``LRScheduler`` wrapping an ``hvd.DistributedOptimizer``

    """
    optimizer = set_optimizer(model, 'sgd', lr, weight_decay)
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())
    # Horovod collectives: every rank must reach these, in the same order.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)
    optimizer = LRScheduler(optimizer, lr,
                            decay_type='always',
                            decay_start_epoch=0,
                            decay_rate=0.5)
    optimizer._epoch = n_epochs
    optimizer._step = n_steps
    return optimizer


def main():
    """Train a language model with Horovod data parallelism.

    Datasets are built with ``n_customers=hvd.size()`` and loaders with
    ``distributed=True`` (presumably one data shard per worker -- confirm in
    Dataset/ChunkDataloader). Rank 0 alone creates the experiment directory,
    writes the config/dictionaries/checkpoints, and does all logging.
    Training stops at ``args.n_epochs`` or on early stopping.

    Returns:
        save_path (str): experiment directory used by this process

    """
    args = parse()
    hvd.init()
    torch.cuda.set_device(hvd.local_rank())
    hvd_rank = hvd.rank()

    # Load a conf file
    if args.resume:
        conf = load_config(os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=args.batch_size,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        n_customers=hvd.size(),
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=args.batch_size,
                      bptt=args.bptt,
                      n_customers=hvd.size(),
                      backward=args.backward,
                      serialize=args.serialize)
    eval_set = Dataset(corpus=args.corpus,
                       tsv_path=args.eval_set,
                       dict_path=args.dict,
                       nlsyms=args.nlsyms,
                       unit=args.unit,
                       wp_model=args.wp_model,
                       batch_size=args.batch_size,
                       bptt=args.bptt,
                       n_customers=hvd.size(),
                       backward=args.backward,
                       serialize=args.serialize)

    args.vocab = train_set.vocab

    train_loader = ChunkDataloader(train_set, batch_size=1, num_workers=1,
                                   distributed=True, shuffle=False)
    eval_loader = ChunkDataloader(eval_set, batch_size=1, num_workers=1,
                                  distributed=True)

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_lm_name(args)
        save_path = mkdir_join(args.model_save_dir, '_'.join(
            os.path.basename(args.train_set).split('.')[:-1]), dir_name)
        if hvd_rank == 0:
            save_path = set_save_path(save_path)  # avoid overwriting
        # NOTE(review): only rank 0 runs set_save_path(), so other ranks may
        # hold a different save_path (used below by build_lm() and Reporter()).
        # Consider broadcasting rank 0's path to all ranks.

    # Set logger. `logger` exists on rank 0 only, so every use of it below
    # must stay behind a rank-0 check.
    if hvd_rank == 0:
        logger = set_logger(os.path.join(save_path, 'train.log'),
                            key='training', stdout=args.stdout)
        logger.info('PID: %s' % os.getpid())
        logger.info('USERNAME: %s' % os.uname()[1])
        logger.info('NUMBER_DEVICES: %s' % hvd.size())
    # Set process name
    setproctitle(args.job_name if args.job_name else dir_name)

    # Model setting
    model = build_lm(args, save_path)

    # GPU setting
    if args.n_gpus >= 1:
        torch.backends.cudnn.benchmark = True
        model.cuda()

    if args.resume:
        # Set optimizer
        epoch = int(args.resume.split('-')[-1])
        optimizer = set_optimizer(
            model,
            'sgd' if epoch > conf['convert_to_sgd_epoch'] else conf['optimizer'],
            conf['lr'], conf['weight_decay'])

        # Restore the last saved model (read on rank 0, broadcast below)
        if hvd_rank == 0:
            model, optimizer = load_checkpoint(model, args.resume, optimizer,
                                               resume=True)

        # Broadcast the restored state from rank 0 to all ranks
        optimizer = hvd.DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters())
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        # Wrap optimizer by learning rate scheduler
        optimizer = LRScheduler(optimizer, conf['lr'],
                                decay_type=conf['lr_decay_type'],
                                decay_start_epoch=conf['lr_decay_start_epoch'],
                                decay_rate=conf['lr_decay_rate'],
                                decay_patient_n_epochs=conf['lr_decay_patient_n_epochs'],
                                early_stop_patient_n_epochs=conf['early_stop_patient_n_epochs'],
                                warmup_start_lr=conf['warmup_start_lr'],
                                warmup_n_steps=conf['warmup_n_steps'],
                                model_size=conf['d_model'],
                                factor=conf['lr_factor'],
                                noam=conf['lm_type'] == 'transformer')

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch.
        # Uses the same helper as the in-loop conversion so Horovod wraps the
        # bare torch optimizer, never the LRScheduler.
        if epoch == conf['convert_to_sgd_epoch']:
            optimizer = _convert_to_sgd(model, args.lr, conf['weight_decay'],
                                        optimizer.n_epochs, optimizer.n_steps)
            if hvd_rank == 0:
                logger.info('========== Convert to SGD ==========')
    else:
        if hvd_rank == 0:
            # Save the conf file as a yaml file
            save_config(vars(args), os.path.join(save_path, 'conf.yml'))

            # Save the nlsyms, dictionary, and wp_model
            if args.nlsyms:
                shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
            shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
            if args.unit == 'wp':
                shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

            for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
                logger.info('%s: %s' % (k, str(v)))

            # Count total parameters
            for n in sorted(model.num_params_dict.keys()):
                logger.info("%s %d" % (n, model.num_params_dict[n]))
            logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000))
            logger.info(model)

        # Set optimizer (initial weights are synchronized from rank 0 first)
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        optimizer = set_optimizer(model, args.optimizer, args.lr, args.weight_decay)
        optimizer = hvd.DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters())
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        # Wrap optimizer by learning rate scheduler
        optimizer = LRScheduler(optimizer, args.lr,
                                decay_type=args.lr_decay_type,
                                decay_start_epoch=args.lr_decay_start_epoch,
                                decay_rate=args.lr_decay_rate,
                                decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
                                early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
                                warmup_start_lr=args.warmup_start_lr,
                                warmup_n_steps=args.warmup_n_steps,
                                model_size=args.d_model,
                                factor=args.lr_factor,
                                noam=args.lm_type == 'transformer')

    # Set reporter
    reporter = Reporter(save_path)

    hidden = None
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    data_size = len(train_set)
    accum_n_tokens = 0
    verbose = 1 if hvd_rank == 0 else 0
    while True:
        model.train()
        with tqdm(total=data_size / hvd.size(),
                  desc='Train Epoch #{}'.format(optimizer.n_epochs + 1),
                  disable=not verbose) as pbar_epoch:
            # Compute loss in the training set
            for ys_train in train_loader:
                accum_n_tokens += sum(len(y) for y in ys_train)
                # NOTE(review): zeroing before every forward discards gradients
                # of earlier mini-batches, so accum_grad_n_tokens > 0 does not
                # accumulate as intended (Horovod would also need
                # backward_passes_per_step for local accumulation).
                optimizer.zero_grad()
                loss, hidden, reporter = model(ys_train, hidden, reporter)
                loss.backward()
                if args.accum_grad_n_tokens == 0 or accum_n_tokens >= args.accum_grad_n_tokens:
                    if args.clip_grad_norm > 0:
                        # NOTE(review): Horovod recommends optimizer.synchronize()
                        # before clipping so the allreduced gradients are clipped.
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.clip_grad_norm)
                    optimizer.step()
                    optimizer.zero_grad()
                    accum_n_tokens = 0
                loss_train = loss.item()
                del loss
                # Truncate the graph between chunks (presumably detaches the
                # carried-over recurrent state -- see model.repackage_state)
                hidden = model.repackage_state(hidden)

                if optimizer.n_steps % args.print_step == 0:
                    model.eval()
                    # Compute loss in the dev set
                    ys_dev = dev_set.next()[0]
                    loss, _, reporter = model(ys_dev, None, reporter, is_eval=True)
                    loss_dev = loss.item()
                    del loss
                    # Back to training mode for the rest of the epoch
                    model.train()

                    duration_step = time.time() - start_time_step
                    if hvd_rank == 0:
                        logger.info("step:%d(ep:%.2f) loss:%.3f(%.3f)/ppl:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)" %
                                    (optimizer.n_steps, optimizer.n_steps / data_size * hvd.size(),
                                     loss_train, loss_dev,
                                     np.exp(loss_train), np.exp(loss_dev),
                                     optimizer.lr, ys_train.shape[0], duration_step / 60))
                    start_time_step = time.time()
                pbar_epoch.update(1)

        # Save checkpoint and evaluate model per epoch
        duration_epoch = time.time() - start_time_epoch
        if hvd_rank == 0:
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (optimizer.n_epochs + 1, duration_epoch / 60))

        if optimizer.n_epochs + 1 < args.eval_start_epoch:
            # Every rank advances its scheduler so epoch counters (and thus the
            # exit condition and the eval-phase collectives) stay in lockstep.
            optimizer.epoch()  # lr decay
            # Save the model
            if hvd_rank == 0:
                save_checkpoint(model, save_path, optimizer, optimizer.n_epochs,
                                remove_old_checkpoints=args.lm_type != 'transformer')
        else:
            start_time_eval = time.time()
            # dev
            # NOTE(review): this evaluates eval_loader (built from
            # args.eval_set), not dev_set, and the result drives LR decay and
            # model selection -- confirm this is intended.
            model.eval()
            ppl_dev, _ = eval_ppl_parallel([model], eval_loader, optimizer.n_epochs,
                                           batch_size=args.batch_size)
            # Reduce across workers and turn the 1-element tensor into a Python
            # float, so every rank feeds the same number to the scheduler.
            ppl_dev = hvd.allreduce(np2tensor(np.array([ppl_dev], dtype=float),
                                              hvd.local_rank())).item()
            if hvd_rank == 0:
                logger.info('PPL : %.2f' % ppl_dev)
            optimizer.epoch(ppl_dev)  # lr decay
            if optimizer.is_best and hvd_rank == 0:
                # Save the model
                save_checkpoint(model, save_path, optimizer, optimizer.n_epochs,
                                remove_old_checkpoints=args.lm_type != 'transformer')

            duration_eval = time.time() - start_time_eval
            if hvd_rank == 0:
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

            # Early stopping
            if optimizer.is_early_stop:
                break

            # Convert to fine-tuning stage
            if optimizer.n_epochs == args.convert_to_sgd_epoch:
                optimizer = _convert_to_sgd(model, args.lr, args.weight_decay,
                                            optimizer.n_epochs, optimizer.n_steps)
                if hvd_rank == 0:
                    logger.info('========== Convert to SGD ==========')

        if optimizer.n_epochs == args.n_epochs:
            break

        start_time_step = time.time()
        start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    if hvd_rank == 0:
        logger.info('Total time: %.2f hour' % (duration_train / 3600))

    reporter.tf_writer.close()

    return save_path