def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'compute statistics from model predictions')
    DictionaryAgent.add_cmdline_args(parser)
    # Get command line arguments
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.add_argument('-ed', '--external-dict', type=str, default=None,
                        help='External dictionary for stat computation')
    parser.add_argument('-fb', '--freq-bins', type=str, default='0,100,1000,10000',
                        help='Bin boundaries for rare words stat')
    parser.add_argument('-dup', '--dump-predictions-path', type=str, default=None,
                        help='Dump predictions into file')
    # use ParlAI's string 'bool' type: a plain bool would parse any
    # non-empty string (including 'False') as True
    parser.add_argument('-cun', '--compute-unique', type='bool', default=True,
                        help='Compute %% of unique responses from the model')
    parser.set_defaults(datatype='valid', model='repeat_label')
    TensorboardLogger.add_cmdline_args(parser)
    return parser
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    parser.add_pytorch_datateacher_args()
    # Get command line arguments
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='If multitasking, average metrics over the number of examples. '
             'If false, averages over the number of tasks.',
    )
    parser.add_argument(
        '--metrics',
        type=str,
        default='all',
        help='list of metrics to show/compute, e.g. ppl, f1, accuracy, hits@1. '
             'If `all` is specified [default] all are shown.',
    )
    TensorboardLogger.add_cmdline_args(parser)
    parser.set_defaults(datatype='valid')
    return parser
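# A minimal sketch (assumption: this module exposes an eval entry point,
# here a placeholder named `eval_model`) of how setup_args variants like the
# ones above are typically consumed:
if __name__ == '__main__':
    parser = setup_args()
    opt = parser.parse_args()
    eval_model(opt)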
def setup_args(parser=None) -> ParlaiParser:
    """
    Build the ParlAI parser, adding command line args if necessary.

    :param ParlaiParser parser:
        Preexisting parser to append options to. Will be created if needed.

    :returns: the ParlaiParser with CLI options added.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Train a model')
    train = parser.add_argument_group('FTML Training Loop Arguments')
    # train.add_argument('--n_grad', type='bool', default=False, hidden=True)
    # train.add_argument('-ngrad', '--num-grad', type=int, default=2)
    # train.add_argument('-nadd', '--num-added-data', type=int, default=50)
    train.add_argument('-mbchsztr', '--meta-batchsize_tr', type=int, default=10)
    train.add_argument('-mbchszval', '--meta-batchsize_val', type=int, default=10)
    train.add_argument('-nmmetastep', '--num-meta-steps', type=int, default=50)
    train.add_argument('-nepb', '--num-episode-batch', type=int, default=3)
    train.add_argument('-ebs', '--eval-batch-size', type=int, default=32)
    train.add_argument('-mnts', '--max-num-turns', type=int, default=15)
    TensorboardLogger.add_cmdline_args(parser)
    parser = setup_train_args(parser)
    return parser
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    parser.add_pytorch_datateacher_args()
    # Other command line arguments
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='If multitasking, average metrics over the number of examples. '
             'If false, averages over the number of tasks.',
    )
    parser.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default, '
             'or give a list split by , like ppl,f1,accuracy,hits@1,rouge,bleu. '
             'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l.',
    )
    TensorboardLogger.add_cmdline_args(parser)
    parser.set_defaults(datatype='valid')
    return parser
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True)
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument('-dbf', '--dict-build-first', type='bool', default=True,
                       help='build dictionary first before training agent')
    train.add_argument('-eps', '--num-epochs', type=float, default=-1)
    train.add_argument('-ttim', '--max-train-time', type=float, default=-1)
    train.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    train.add_argument('-tden', '--train-display-every-n-secs', type=float, default=-1,
                       help='Display training information every n seconds')
    train.add_argument('-vtim', '--validation-every-n-secs', type=float, default=-1,
                       help='Validate every n seconds. Whenever the best '
                            'validation metric is found, saves the model to '
                            'the model_file path if set.')
    train.add_argument('-stim', '--save-every-n-secs', type=float, default=-1,
                       help='Saves the model to model_file.checkpoint after '
                            'every n seconds (default -1, never).')
    TensorboardLogger.add_cmdline_args(parser)
    parser = setup_dict_args(parser)
    return parser
def __init__(self, opt):
    if isinstance(opt, ParlaiParser):
        print('[ Deprecated Warning: TrainLoop should be passed opt not Parser ]')
        opt = opt.parse_args()
    # Possibly load from checkpoint
    if (opt['load_from_checkpoint'] and opt.get('model_file')
            and os.path.isfile(opt['model_file'] + '.checkpoint')):
        opt['init_model'] = opt['model_file'] + '.checkpoint'
    # Possibly build a dictionary (not all models do this).
    if opt['dict_build_first'] and 'dict_file' in opt:
        # If data built via pytorch data teacher, we need to load prebuilt dict
        if opt.get('pytorch_teacher_task'):
            opt['dict_file'] = get_pyt_dict_file(opt)
        elif opt['dict_file'] is None and opt.get('model_file'):
            opt['dict_file'] = opt['model_file'] + '.dict'
        print("[ building dictionary first... ]")
        build_dict(opt, skip_if_built=True)
    # Create model and assign it to the specified task
    self.agent = create_agent(opt)
    self.world = create_task(opt, self.agent)
    self.train_time = Timer()
    self.validate_time = Timer()
    self.log_time = Timer()
    self.save_time = Timer()
    print('[ training... ]')
    self.parleys = 0
    self.max_num_epochs = opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf')
    self.max_train_time = (
        opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf')
    )
    self.log_every_n_secs = (
        opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')
    )
    self.val_every_n_secs = (
        opt['validation_every_n_secs']
        if opt['validation_every_n_secs'] > 0
        else float('inf')
    )
    self.save_every_n_secs = (
        opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf')
    )
    self.val_every_n_epochs = (
        opt['validation_every_n_epochs']
        if opt['validation_every_n_epochs'] > 0
        else float('inf')
    )
    self.last_valid_epoch = 0
    self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
    self.best_valid = None
    if opt.get('model_file') and os.path.isfile(opt['model_file'] + '.best_valid'):
        with open(opt['model_file'] + '.best_valid', 'r') as f:
            self.best_valid = float(f.readline())
    self.impatience = 0
    self.saved = False
    self.valid_world = None
    self.opt = opt
    if opt['tensorboard_log'] is True:
        self.writer = TensorboardLogger(opt)
def __init__(self, opt):
    if isinstance(opt, ParlaiParser):
        opt = opt.parse_args()
    # Possibly build a dictionary (not all models do this).
    if opt['dict_build_first'] and 'dict_file' in opt:
        if (opt['dict_file'] is None and opt.get('model_file_transmitter')
                and opt.get('model_file_receiver')):
            opt['dict_file'] = (
                opt['model_file_transmitter'] + '_'
                + opt['model_file_receiver'] + '.dict'
            )
        print("[ building dictionary first... ]")
        build_dict(opt, skip_if_built=False)
    # Create model and assign it to the specified task
    print("[ create meta-agent ... ]")
    self.agent = create_agent(opt)
    print("[ create agent A ... ]")
    shared = self.agent.share()
    self.agent_a = create_agent_from_shared(shared)
    self.agent_a.set_id(suffix=' A')
    print("[ create agent B ... ]")
    self.agent_b = create_agent_from_shared(shared)
    # self.agent_b = create_agent(opt)
    self.agent_b.set_id(suffix=' B')
    # self.agent_a.copy(self.agent, 'transmitter')
    # self.agent_b.copy(self.agent, 'transmitter')
    self.world = create_selfplay_world(opt, [self.agent_a, self.agent_b])
    # TODO: if batch, it is also not parallel
    # self.world = BatchSelfPlayWorld(opt, self_play_world)
    self.train_time = Timer()
    self.train_dis_time = Timer()
    self.validate_time = Timer()
    self.log_time = Timer()
    self.save_time = Timer()
    print('[ training... ]')
    self.parleys_episode = 0
    self.max_num_epochs = opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf')
    self.max_train_time = (
        opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf')
    )
    self.log_every_n_secs = (
        opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')
    )
    self.train_dis_every_n_secs = (
        opt['train_display_every_n_secs']
        if opt['train_display_every_n_secs'] > 0
        else float('inf')
    )
    self.val_every_n_secs = (
        opt['validation_every_n_secs']
        if opt['validation_every_n_secs'] > 0
        else float('inf')
    )
    self.save_every_n_secs = (
        opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf')
    )
    self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
    self.best_valid = None
    if (opt.get('model_file_transmitter')
            and os.path.isfile(opt['model_file_transmitter'] + '.best_valid')):
        with open(opt['model_file_transmitter'] + '.best_valid', 'r') as f:
            self.best_valid = float(f.readline())
    self.impatience = 0
    self.saved = False
    self.valid_world = None
    self.opt = opt
    if opt['tensorboard_log'] is True:
        self.writer = TensorboardLogger(opt)
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True)
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument('-et', '--evaltask',
                       help='task to use for valid/test (defaults to the '
                            'one used for training if not set)')
    train.add_argument('--display-examples', type='bool', default=False)
    train.add_argument('-eps', '--num-epochs', type=float, default=-1)
    train.add_argument('-ttim', '--max-train-time', type=float, default=-1)
    train.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    train.add_argument('-vtim', '--validation-every-n-secs', type=float, default=-1,
                       help='Validate every n seconds. Whenever the best '
                            'validation metric is found, saves the model to '
                            'the model_file path if set.')
    train.add_argument('-stim', '--save-every-n-secs', type=float, default=-1,
                       help='Saves the model to model_file.checkpoint after '
                            'every n seconds (default -1, never).')
    train.add_argument('-sval', '--save-after-valid', type='bool', default=False,
                       help='Saves the model to model_file.checkpoint after '
                            'every validation (default %(default)s).')
    train.add_argument('-vme', '--validation-max-exs', type=int, default=-1,
                       help='max examples to use during validation (default '
                            '-1 uses all)')
    train.add_argument('-vp', '--validation-patience', type=int, default=10,
                       help='number of iterations of validation where result '
                            'does not improve before we stop training')
    train.add_argument('-vmt', '--validation-metric', default='accuracy',
                       help='key into report table for selecting best '
                            'validation')
    train.add_argument('-vmm', '--validation-metric-mode', default='max', type=str,
                       choices=['max', 'min'],
                       help='how to optimize validation metric (max or min)')
    train.add_argument('-vcut', '--validation-cutoff', type=float, default=1.0,
                       help='value at which training will stop if exceeded by '
                            'training metric')
    train.add_argument('-dbf', '--dict-build-first', type='bool', default=True,
                       help='build dictionary first before training agent')
    train.add_argument('-lfc', '--load-from-checkpoint', type='bool', default=False,
                       help='load model from checkpoint if available')
    TensorboardLogger.add_cmdline_args(parser)
    parser = setup_dict_args(parser)
    return parser
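# Hedged end-to-end sketch of how this setup_args pairs with the TrainLoop
# classes defined later in this file. The task and model values are
# placeholders; --tensorboard-log is the flag registered by
# TensorboardLogger.add_cmdline_args above.
parser = setup_args()
opt = parser.parse_args([
    '--task', 'babi:task10k:1',  # placeholder task
    '--model', 'seq2seq',        # placeholder model
    '--model-file', '/tmp/model',
    '--tensorboard-log', 'true',
])
TrainLoop(opt).train()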
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    parser.add_pytorch_datateacher_args()
    # Probing command line arguments
    parser.add_argument(
        '--probe',
        type=str,
        default=None,
        choices=['word_embeddings', 'encoder_state', 'combined'],
        help="Specify the type of representations to generate for probing. "
             "See 'Probing Neural Dialog for Conversational Understanding' "
             "for more details.",
    )
    parser.add_argument(
        '-t',
        '--tasks',
        type=str,
        nargs='+',
        required=True,
        help='Usage: -t trecquestion or -t trecquestion wnli multiwoz\n'
             'Only compatible with names in probing/tasks',
    )
    # Other command line arguments
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='If multitasking, average metrics over the number of examples. '
             'If false, averages over the number of tasks.',
    )
    parser.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default, '
             'or give a list split by , like ppl,f1,accuracy,hits@1,rouge,bleu. '
             'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l.',
    )
    TensorboardLogger.add_cmdline_args(parser)
    parser.set_defaults(datatype='valid')
    parser.set_defaults(batchsize=256)
    return parser
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    parser.add_pytorch_datateacher_args()
    # Get command line arguments
    parser.add_argument('-rp', '--report', type=str, default='/tmp/eval_model.json')
    parser.add_argument(
        '-rf',
        '--report-filename',
        type=str,
        default='',
        help='Saves a json file of the evaluation report either as an '
             'extension to the model-file (if begins with a ".") or a whole '
             'file path. Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--save-world-logs',
        type='bool',
        default=False,
        help='Saves a jsonl file containing all of the task examples and '
             'model replies. Must also specify --report-filename.',
    )
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='If multitasking, average metrics over the number of examples. '
             'If false, averages over the number of tasks.',
    )
    parser.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default, '
             'or give a list split by , like ppl,f1,accuracy,hits@1,rouge,bleu. '
             'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l.',
    )
    WorldLogger.add_cmdline_args(parser)
    TensorboardLogger.add_cmdline_args(parser)
    parser.set_defaults(datatype='valid')
    return parser
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    # Get command line arguments
    parser.add_argument(
        '-rf',
        '--report-filename',
        type=str,
        default='',
        help='Saves a json file of the evaluation report either as an '
             'extension to the model-file (if begins with a ".") or a whole '
             'file path. Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--world-logs',
        type=str,
        default='',
        help='Saves a jsonl file of the world logs. '
             'Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--save-format',
        type=str,
        default='conversations',
        choices=['conversations', 'parlai'],
    )
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=10)
    parser.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default, '
             'or give a list split by , like ppl,f1,accuracy,hits@1,rouge,bleu. '
             'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l.',
    )
    parser.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='Report micro-averaged metrics instead of macro averaged metrics.',
        recommended=False,
    )
    WorldLogger.add_cmdline_args(parser, partial_opt=None)
    TensorboardLogger.add_cmdline_args(parser, partial_opt=None)
    parser.set_params(datatype='valid')
    return parser
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    parser.add_pytorch_datateacher_args()
    # Get command line arguments
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.add_argument('--metrics', type=str, default='all',
                        help='list of metrics to show/compute, e.g. '
                             'ppl,f1,accuracy,hits@1. '
                             "If 'all' is specified [default] all are shown.")
    TensorboardLogger.add_cmdline_args(parser)
    parser.set_defaults(datatype='valid')
    return parser
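# Illustrative helper (the name and exact behavior are assumptions, not
# ParlAI API) showing how a comma-separated --metrics value like the ones
# registered above could be expanded; 'all' and 'default' act as sentinels.
def parse_metrics_flag(value):
    if value in ('all', 'default'):
        return value
    return [m.strip() for m in value.split(',') if m.strip()]


assert parse_metrics_flag('ppl,f1,accuracy,hits@1') == ['ppl', 'f1', 'accuracy', 'hits@1']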
def __init__(self, optAgent, shared=None):
    init_model, is_finetune = self._get_init_model(optAgent, shared)
    super().__init__(optAgent, shared)
    if optAgent.get('numthreads', 1) > 1:
        torch.set_num_threads(1)
    # NOTE: `opt` below is a module-level config namespace (distinct from the
    # ParlAI opt dict `optAgent`), and `args` is assumed to carry the
    # distributed rank; neither is defined in this excerpt.
    optAgent['gradient_clip'] = opt.maxgrad
    self.criterion = opt.criterion
    self.loss = opt.loss
    self.drawVars = opt.drawVars
    opt.edim = optAgent['embeddingsize']
    opt.vocabsize = len(self.dict)
    opt.__dict__.update(optAgent)
    opt.agent = self
    opt.fp16 = self.fp16
    torch.manual_seed(args.rank)
    np.random.seed(args.rank)
    self.writeVars = 0
    self.vars = {}
    if optAgent['tensorboard_log']:
        self.writeVars, *_ = getWriter(writer=TensorboardLogger(optAgent))
    if self.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                'No fp16 support without apex. Please install it from '
                'https://github.com/NVIDIA/apex'
            )
        self.getParameters = lambda: amp.master_params(self.optimizer)
        self.amp = amp
    else:
        self.getParameters = lambda: self.model.parameters()
    if not shared:
        model = Model(opt)
        self.model = model
        if init_model:
            print('Loading existing model parameters from ' + init_model)
            states = self.load(init_model)
        else:
            states = {}
        initParameters(opt, self.model)
        if self.use_cuda:
            self.model.cuda()
        self.model.train()
        if optAgent.get('numthreads', 1) > 1:
            self.model.share_memory()
        paramOptions = getParamOptions(opt, self.model)
        self.init_optim(paramOptions, states.get('optimizer'),
                        states.get('saved_optim_type', None))
        self.build_lr_scheduler(states, hard_reset=is_finetune)
        if is_distributed():
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.opt['gpu']],
                broadcast_buffers=False,
            )
        self.reset()
    else:
        self.model = shared['model']
        self.dict = shared['dict']
        if 'optimizer' in shared:
            self.optimizer = shared['optimizer']
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'compute statistics from model predictions')
    DictionaryAgent.add_cmdline_args(parser)
    # These defaults can be overridden by both .opt file and user's command line flags
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.add_argument(
        '-ed',
        '--external-dict',
        type=str,
        default=None,
        help='External dictionary for stat computation',
    )
    parser.add_argument(
        '-fb',
        '--freq-bins',
        type=str,
        default='0,100,1000,10000',
        help='Bin boundaries for rare words stat',
    )
    # use ParlAI's string 'bool' type so '--gold-response False' parses as False
    parser.add_argument(
        '-gr',
        '--gold-response',
        type='bool',
        default=False,
        help='Compute stats for gold response',
    )
    # These settings override .opt file but not user's command line flags
    parser.set_params(
        datatype='valid',
        task='projects.controllable_dialogue.tasks.agents',
        model='projects.controllable_dialogue.controllable_seq2seq.controllable_seq2seq:ControllableSeq2seqAgent',  # noqa: E501
        batchsize=64,
        beam_size=20,
        beam_min_n_best=10,
        use_reply='model',
    )
    TensorboardLogger.add_cmdline_args(parser)
    return parser
def __init__(self, opt, shared=None):
    super().__init__(opt, shared)
    if opt['tensorboard_log'] is True:
        self.writer = TensorboardLogger(opt)
    # average the loss over the batch; reduction='mean' is the modern
    # spelling of the deprecated reduce=True/size_average=True flags
    self.rank_loss = torch.nn.CrossEntropyLoss(reduction='mean')
    # anomaly detection aids debugging of backward passes but slows training
    torch.autograd.set_detect_anomaly(True)
    torch.manual_seed(123)
def setup_args(parser=None) -> ParlaiParser:
    """
    Build the ParlAI parser, adding command line args if necessary.

    :param ParlaiParser parser:
        Preexisting parser to append options to. Will be created if needed.

    :returns: the ParlaiParser with CLI options added.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Train a model')
    train = parser.add_argument_group('FTML Training Loop Arguments')
    # train.add_argument('--n_grad', type='bool', default=False, hidden=True)
    train.add_argument('-nomt', '--no-multi-task', type='bool', default=False)
    train.add_argument('-nepb', '--num-episode-batch', type=int, default=3)
    train.add_argument('-ebs', '--eval-batch-size', type=int, default=32)
    TensorboardLogger.add_cmdline_args(parser)
    parser = setup_train_args(parser)
    return parser
def __init__(self, opt, shared=None):
    super().__init__(opt, shared)
    if opt['tensorboard_log'] is True:
        self.writer = TensorboardLogger(opt)
    self.dictionnary_size = 177
    self.embedding_dim = 100
    self.batch_size = opt['batchsize']
    self.criterion = nn.CrossEntropyLoss()

    def weight_init(m):
        # Xavier-initialize all linear layers
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight.data)

    self.recurrent_entity_network = RecurrentEntityNetwork(
        self.dictionnary_size, self.embedding_dim, sequence_length=7
    )
    self.recurrent_entity_network.apply(weight_init)
    self.optimizer = optim.Adam(self.recurrent_entity_network.parameters())
    # self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 25, 0.5)
    self.batch_iter = 0
class TrainLoop:
    """
    TrainLoop contains the core training loop logic.
    """

    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions
        # are not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if (
            opt['load_from_checkpoint']
            and opt.get('model_file')
            and PathManager.exists(opt['model_file'] + '.checkpoint')
        ):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if not (opt.get('dict_file') or opt.get('model_file')):
            raise RuntimeError(
                'WARNING: For train_model, please specify either a '
                'model_file or dict_file.'
            )
        if 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            logging.info("building dictionary first...")
            build_dict(opt, skip_if_built=True)

        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.agent.opt.log()
        self.world = create_task(opt, self.agent)
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()

        self.parleys = 0
        self._train_steps = 0
        self._last_log_steps = 0
        self.update_freq = opt.get('update_freq', 1)

        self.max_num_epochs = _num_else_inf(opt, 'num_epochs', distributed_warn=True)
        self.max_train_time = _num_else_inf(
            opt, 'max_train_time', distributed_warn=True
        )
        self.max_train_steps = _num_else_inf(opt, 'max_train_steps')
        self.log_every_n_secs = _num_else_inf(
            opt, 'log_every_n_secs', distributed_warn=True
        )
        self.log_every_n_steps = _num_else_inf(opt, 'log_every_n_steps')
        self.val_every_n_secs = _num_else_inf(
            opt, 'validation_every_n_secs', distributed_warn=True
        )
        self.val_every_n_epochs = _num_else_inf(
            opt, 'validation_every_n_epochs', distributed_warn=True
        )
        self.val_every_n_steps = _num_else_inf(opt, 'validation_every_n_steps')
        self.save_every_n_secs = _num_else_inf(
            opt, 'save_every_n_secs', distributed_warn=True
        )

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self._last_valid_steps = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.train_reports = []
        self.valid_reports = []
        self.final_valid_report = {}
        self.final_test_report = {}
        self.final_extra_valid_report = {}
        self.best_valid = None

        self.impatience = 0
        self.saved = False
        self.valid_worlds = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file') and PathManager.exists(
            opt['model_file'] + trainstats_suffix
        ):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with PathManager.open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self.parleys = obj.get('parleys', 0)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self._train_steps = obj.get('train_steps', 0)
                self.impatience = obj.get('impatience', 0)
                self.valid_reports = obj.get('valid_reports', [])
                if self.valid_reports:
                    self.last_valid_epoch = self.valid_reports[-1].get(
                        'total_epochs', 0.0
                    )
                self.train_reports = obj.get('train_reports', [])
                if 'best_valid' in obj:
                    self.best_valid = obj['best_valid']
                else:
                    # old method
                    if opt.get('model_file') and PathManager.exists(
                        opt['model_file'] + '.best_valid'
                    ):
                        with PathManager.open(
                            opt['model_file'] + '.best_valid', 'r'
                        ) as f:
                            self.best_valid = float(f.readline())

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger = TensorboardLogger(opt)
        if opt['wandb_log'] and is_primary_worker():
            model = self.agent.model if hasattr(self.agent, 'model') else None
            self.wb_logger = WandbLogger(opt, model)

    def save_model(self, suffix=None):
        """
        Save the model to disk, possibly with a suffix.
        """
        if not self.opt.get('model_file'):
            # nothing to save to, just exit
            return

        fn = self.opt['model_file']
        if suffix:
            fn += suffix

        if not is_primary_worker():
            # never do IO as a non-primary worker
            if hasattr(self.agent, 'save_nonprimary'):
                self.agent.save_nonprimary(fn)
            return

        while True:
            # don't ever let a ctrl-c interrupt saving
            try:
                self.agent.save(fn)
                self._save_train_stats(suffix)
                break
            except KeyboardInterrupt:
                pass

    def _save_train_stats(self, suffix=None):
        if not is_primary_worker():
            # never do IO as a non-primary worker
            return
        fn = self.opt.get('model_file', None)
        if not fn:
            return
        if suffix:
            fn += suffix
        fn += '.trainstats'
        with PathManager.open(fn, 'w') as f:
            json.dump(
                {
                    'parleys': self.parleys,
                    'train_time': self.train_time.time(),
                    'train_steps': self._train_steps,
                    'total_epochs': self._total_epochs,
                    'train_reports': self.train_reports,
                    'valid_reports': self.valid_reports,
                    'best_valid': self.best_valid,
                    'impatience': self.impatience,
                    'final_valid_report': dict_report(self.final_valid_report),
                    'final_test_report': dict_report(self.final_test_report),
                    'final_extra_valid_report': dict_report(
                        self.final_extra_valid_report
                    ),
                },
                f,
                indent=4,
            )

    def validate(self):
        """
        Perform a validation run, checking whether we should stop training.

        :return: boolean indicating whether training should stop
        :rtype: bool
        """
        opt = self.opt

        if self.valid_worlds is None:
            # we need to load the world now
            self.valid_worlds = load_eval_worlds(self.agent, opt, 'valid')

        # run evaluation on valid set
        valid_report = self._run_eval(
            self.valid_worlds, opt, 'valid', opt['validation_max_exs']
        )
        v = dict_report(valid_report)
        v['train_time'] = self.train_time.time()
        v['parleys'] = self.parleys
        v['train_steps'] = self._train_steps
        v['total_exs'] = self._total_exs
        v['total_epochs'] = self._total_epochs
        self.valid_reports.append(v)
        # logging
        if opt['tensorboard_log'] and is_primary_worker():
            valid_report['total_exs'] = self._total_exs
            self.tb_logger.log_metrics('valid', self.parleys, valid_report)
            # flush on a validation
            self.tb_logger.flush()
        if opt['wandb_log'] and is_primary_worker():
            valid_report['total_exs'] = self._total_exs
            self.wb_logger.log_metrics('valid', self.parleys, valid_report)

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        new_valid = valid_report[opt['validation_metric']]

        if isinstance(new_valid, Metric):
            new_valid = new_valid.value()

        # check if this is the best validation so far
        if (
            self.best_valid is None
            or self.valid_optim * new_valid > self.valid_optim * self.best_valid
        ):
            logging.success(
                'new best {}: {:.4g}{}'.format(
                    opt['validation_metric'],
                    new_valid,
                    ' (previous best was {:.4g})'.format(self.best_valid)
                    if self.best_valid is not None
                    else '',
                )
            )
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file'):
                logging.info(f"saving best valid model: {opt['model_file']}")
                self.save_model()
                self.saved = True
            if (
                opt['validation_metric_mode'] == 'max'
                and self.best_valid >= opt['validation_cutoff']
            ) or (
                opt['validation_metric_mode'] == 'min'
                and self.best_valid <= opt['validation_cutoff']
            ):
                logging.info('task solved! stopping.')
                return True
        else:
            self.impatience += 1
            logging.report(
                'did not beat best {}: {} impatience: {}'.format(
                    opt['validation_metric'], round(self.best_valid, 4), self.impatience
                )
            )
        self.validate_time.reset()

        # saving
        if opt.get('model_file') and opt.get('save_after_valid'):
            logging.info(f"saving model checkpoint: {opt['model_file']}.checkpoint")
            self.save_model('.checkpoint')

        # check if we are out of patience
        if (
            opt['validation_patience'] > 0
            and self.impatience >= opt['validation_patience']
        ):
            logging.info('ran out of patience! stopping training.')
            return True
        return False

    def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask, task):
        # run evaluation on a single world
        valid_world.reset()

        world_logger = None
        task_opt = opt.copy()
        # set up world logger for the "test" fold
        if opt['world_logs'] and datatype == 'test':
            task_opt['world_logs'] = get_task_world_logs(
                task, opt['world_logs'], is_multitask
            )
            world_logger = WorldLogger(task_opt)

        cnt = 0
        max_cnt = max_exs if max_exs > 0 else float('inf')
        while not valid_world.epoch_done() and cnt < max_cnt:
            valid_world.parley()
            if world_logger is not None:
                world_logger.log(valid_world)
            if cnt == 0 and opt['display_examples']:
                print(valid_world.display() + '\n~~')
                print(valid_world.report())
            cnt = valid_world.report().get('exs') or 0

        if world_logger is not None:
            # dump world acts to file
            world_logger.reset()  # add final acts to logs
            if is_distributed():
                rank = get_rank()
                base_outfile, extension = os.path.splitext(task_opt['world_logs'])
                outfile = base_outfile + f'_{rank}' + extension
            else:
                outfile = task_opt['world_logs']
            world_logger.write(outfile, valid_world, file_format=opt['save_format'])

        valid_report = valid_world.report()
        if opt.get('validation_share_agent', False):
            valid_world.reset()  # make sure world doesn't remember valid data

        return valid_report

    def _run_eval(
        self,
        valid_worlds,
        opt,
        datatype,
        max_exs=-1,
        write_log=False,
        extra_log_suffix="",
    ):
        """
        Eval on validation/test data.

        :param valid_worlds:
            list of the pre-created validation worlds.
        :param opt:
            the options that specify the task, eval_task, etc
        :param datatype:
            the datatype to use, such as "valid" or "test"
        :param bool write_log:
            specifies to write metrics to file if the model_file is set
        :param int max_exs:
            limits the number of examples if max_exs > 0
        """
        logging.info(f'running eval: {datatype}')
        timer = Timer()
        reports = []

        max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
        is_multitask = len(valid_worlds) > 1
        for index, v_world in enumerate(valid_worlds):
            if opt.get('evaltask'):
                task = opt['evaltask'].split(',')[index]
            else:
                task = opt['task'].split(',')[index]
            task_report = self._run_single_eval(
                opt, v_world, max_exs_per_worker, datatype, is_multitask, task
            )
            reports.append(task_report)

        tasks = [world.getID() for world in valid_worlds]
        named_reports = dict(zip(tasks, reports))
        report = aggregate_named_reports(
            named_reports, micro_average=self.opt.get('aggregate_micro', False)
        )
        # get the results from all workers
        report = self._sync_metrics(report)

        metrics = f'{datatype}:\n{nice_report(report)}\n'
        logging.info(f'eval completed in {timer.time():.2f}s')
        logging.report(metrics)

        # write to file
        if write_log and opt.get('model_file') and is_primary_worker():
            # Write out metrics
            with PathManager.open(
                opt['model_file'] + extra_log_suffix + '.' + datatype, 'a'
            ) as f:
                f.write(f'{metrics}\n')

        return report

    def _run_final_extra_eval(self, opt):
        final_valid_opt = copy.deepcopy(opt)
        final_valid_opt_raw = Opt.load_init(opt['final_extra_opt'])
        final_datatype = final_valid_opt_raw["datatype"]
        for k, v in final_valid_opt_raw.items():
            final_valid_opt[k] = v
        final_max_exs = (
            final_valid_opt['validation_max_exs']
            if final_valid_opt.get('short_final_eval')
            else -1
        )
        final_valid_world = load_eval_worlds(
            self.agent, final_valid_opt, final_datatype
        )
        final_valid_report = self._run_eval(
            final_valid_world,
            final_valid_opt,
            final_datatype,
            final_max_exs,
            write_log=True,
            extra_log_suffix="_extra",
        )
        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_final(final_datatype, final_valid_report)

        return final_valid_report

    def _sync_metrics(self, metrics):
        """
        Sync training metrics across workers.

        A handful of special cases are handled as exceptions, and the remaining
        metrics are simply averaged across workers.
        """
        if not is_distributed():
            # nothing special needed
            return metrics
        all_versions = all_gather_list(metrics)
        return aggregate_unnamed_reports(all_versions)

    def _compute_eta(
        self, epochs_completed: float, time_elapsed: float, steps_taken: int
    ):
        """
        Compute the estimated seconds remaining in training.

        :param float epochs_completed: number of epochs already completed.
        :param float time_elapsed: total time spent already, in seconds.
        :return: ETA in seconds, or None if not computable
        """
        # start off with no estimate
        eta = None

        # Determine time_left and num_epochs
        max_epochs = self.opt.get('num_epochs', 0)
        if max_epochs > 0 and epochs_completed > 0:
            epoch_progress = epochs_completed / max_epochs
            eta = (1 - epoch_progress) * time_elapsed / epoch_progress

        max_training_time = self.opt.get('max_training_time', -1)
        if max_training_time > 0:
            time_left = max_training_time - time_elapsed
            if eta is None or time_left < eta:
                eta = time_left

        max_train_steps = self.opt.get('max_train_steps', -1)
        if max_train_steps > 0 and steps_taken > 0:
            steps_progress = steps_taken / max_train_steps
            eta = (1 - steps_progress) * time_elapsed / steps_progress

        return eta

    def _get_time(self, world: World) -> Tuple[float, float, float]:
        """
        Return train, log, and validate timing.

        If relying on the time for validation/logging/max train time purposes,
        we sync and return primary worker's time. Otherwise, it's not super
        relevant what we do here.

        **SIDE EFFECT**: Update _total_epochs trained.

        :param world: current running world
        :return (train, log, valid):
            return time for each of train, log, and validation
        """
        if (
            self.max_train_time < float('inf')
            or self.log_every_n_secs < float('inf')
            or self.val_every_n_secs < float('inf')
            or self.val_every_n_epochs < float('inf')
            or self.max_num_epochs < float('inf')
        ):
            self._total_epochs = self._preempted_epochs + sum(
                all_gather_list(world.get_total_epochs())
            )
            train_time, log_time, validate_time, save_time = sync_object(
                (
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                    self.save_time.time(),
                )
            )
        else:
            train_time, log_time, validate_time, save_time = (
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
                self.save_time.time(),
            )
            self._total_epochs = self._preempted_epochs + (
                num_workers() * world.get_total_epochs()
            )

        return train_time, log_time, validate_time, save_time

    def log(self):
        """
        Output a training log entry.
        """
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self.world.report()
        train_report = self._sync_metrics(train_report)
        self.world.reset_metrics()

        train_report_trainstats = dict_report(train_report)
        train_report_trainstats['total_epochs'] = self._total_epochs
        train_report_trainstats['total_exs'] = self._total_exs
        train_report_trainstats['parleys'] = self.parleys
        train_report_trainstats['train_steps'] = self._train_steps
        train_report_trainstats['train_time'] = self.train_time.time()
        self.train_reports.append(train_report_trainstats)

        # time elapsed
        logs.append(f'time:{self.train_time.time():.0f}s')
        logs.append(f'total_exs:{self._total_exs}')
        logs.append(f'total_steps:{self._train_steps}')

        if self._total_epochs >= 0:
            # only if it's unbounded
            logs.append(f'epochs:{self._total_epochs:.2f}')

        time_left = self._compute_eta(
            self._total_epochs, self.train_time.time(), self._train_steps
        )
        if time_left is not None:
            logs.append(f'time_left:{max(0, time_left):.0f}s')

        log = '{}\n{}\n'.format(' '.join(logs), nice_report(train_report))
        logging.info(log)
        self.log_time.reset()
        self._last_log_steps = 0

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('train', self.parleys, train_report)
        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_metrics('train', self.parleys, train_report)

        return train_report

    def train_steps(self):
        """
        Core training loop.

        Yields a metrics dict with each log.
        """
        logging.info('training...')
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException as e:
                    logging.info(f"Stopping from {e}")
                    break

                self.parleys += 1
                self._train_steps = self.parleys // self.update_freq
                self._last_log_steps += 1 / self.update_freq

                # the following additionally updates self._total_epochs
                train_time, log_time, validate_time, save_time = self._get_time(world)
                # get the total training examples done, compute epochs
                exs_per_epoch = world.num_examples()
                self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    yield self.log()
                    logging.info(
                        f'num_epochs completed:{self.max_num_epochs} '
                        f'time elapsed:{train_time}s'
                    )
                    break
                if train_time > self.max_train_time:
                    logging.info(f'max_train_time elapsed:{train_time}s')
                    break
                if self._train_steps >= self.max_train_steps:
                    logging.info(
                        f'max_train_steps elapsed:{self._train_steps} '
                        f'time elapsed:{train_time}s'
                    )
                    break
                if (
                    log_time > self.log_every_n_secs
                    or self._last_log_steps >= self.log_every_n_steps
                ):
                    yield self.log()
                if (
                    validate_time > self.val_every_n_secs
                    or self._total_epochs - self.last_valid_epoch
                    >= self.val_every_n_epochs
                    or self._train_steps - self._last_valid_steps
                    >= self.val_every_n_steps
                ):
                    try:
                        # log before we validate
                        if self._last_log_steps:
                            yield self.log()
                        world.reset_metrics()
                        stop_training = self.validate()
                    except StopTrainException:
                        break
                    # reset the log time because we logged right before validating
                    self.log_time.reset()
                    self.last_valid_epoch = self._total_epochs
                    self._last_valid_steps = self._train_steps
                    if stop_training:
                        break
                    # make sure metrics are clean before we log
                    world.reset_metrics()
                if save_time > self.save_every_n_secs and opt.get('model_file'):
                    logging.info(
                        f"saving model checkpoint: {opt['model_file']}.checkpoint"
                    )
                    if opt['tensorboard_log'] and is_primary_worker():
                        self.tb_logger.flush()
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not sync_object(self.saved):
            # save agent
            self.save_model()

        # there's a rare edge case where we never saved the model, and we try
        # to reload it. This sync_object ensures all workers wait for the
        # primary worker to finish flushing before loading from disk.
        sync_object(None)
        if opt.get('model_file'):
            # clean up all our memory, just to make sure we don't OOM on GPU
            # when reloading the world
            del world
            del self.world
            del self.agent
            del self.valid_worlds
            # reload best validation model
            self.agent = create_agent(opt)

    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        opt = self.opt
        for _train_log in self.train_steps():
            # we've already done what we need in these
            pass

        # perform final validation/testing
        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
        self.final_valid_report = self._run_eval(
            valid_worlds, opt, 'valid', max_exs, write_log=True
        )
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        self.final_test_report = self._run_eval(
            test_worlds, opt, 'test', max_exs, write_log=True
        )

        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_final('valid', self.final_valid_report)
            self.wb_logger.log_final('test', self.final_test_report)
            self.wb_logger.finish()

        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        if opt['final_extra_opt'] != '':
            self.final_extra_valid_report = self._run_final_extra_eval(opt)

        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.finish()

        self._save_train_stats()

        return self.final_valid_report, self.final_test_report
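# _num_else_inf is referenced by the TrainLoop above but is not defined in
# this excerpt. A minimal sketch of its presumed behavior, mirroring the
# older `x if x > 0 else float('inf')` pattern used elsewhere in this file
# (the distributed_warn argument is accepted but ignored here):
def _num_else_inf(opt, key, distributed_warn=False):
    val = opt[key]
    if val <= 0:
        # non-positive values mean "no limit"
        val = float('inf')
    return val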
def __init__(self, opt):
    # if python is called from a non-interactive shell, like a bash script,
    # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
    # not produced. This line brings them back
    signal.signal(signal.SIGINT, signal.default_int_handler)
    if isinstance(opt, ParlaiParser):
        print('[ Deprecated Warning: TrainLoop should be passed opt not Parser ]')
        opt = opt.parse_args()
    # Possibly load from checkpoint
    trainstats_suffix = '.trainstats'  # we might load training statistics from here
    if (opt['load_from_checkpoint'] and opt.get('model_file')
            and os.path.isfile(opt['model_file'] + '.checkpoint')):
        opt['init_model'] = opt['model_file'] + '.checkpoint'
        trainstats_suffix = '.checkpoint.trainstats'
    # Possibly build a dictionary (not all models do this).
    if opt['dict_build_first'] and 'dict_file' in opt:
        # If data built via pytorch data teacher, we need to load prebuilt dict
        if opt.get('pytorch_teacher_task'):
            opt['dict_file'] = get_pyt_dict_file(opt)
        elif opt['dict_file'] is None and opt.get('model_file'):
            opt['dict_file'] = opt['model_file'] + '.dict'
        print("[ building dictionary first... ]")
        build_dict(opt, skip_if_built=True)
    # Create model and assign it to the specified task
    self.agent = create_agent(opt)  # specify model such as seq2seq
    self.world = create_task(opt, self.agent)  # batch world or other world
    # set up timers
    self.train_time = Timer()
    self.validate_time = Timer()
    self.log_time = Timer()
    self.save_time = Timer()
    print('[ training... ]')
    self.parleys = 0
    self.max_num_epochs = opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf')
    self.max_train_time = (
        opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf')
    )
    self.log_every_n_secs = (
        opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')
    )
    self.val_every_n_secs = (
        opt['validation_every_n_secs']
        if opt['validation_every_n_secs'] > 0
        else float('inf')
    )
    self.save_every_n_secs = (
        opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf')
    )
    self.val_every_n_epochs = (
        opt['validation_every_n_epochs']
        if opt['validation_every_n_epochs'] > 0
        else float('inf')
    )
    # smart defaults for --validation-metric-mode
    if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
        opt['validation_metric_mode'] = 'min'
    elif opt['validation_metric'] in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
        opt['validation_metric_mode'] = 'max'
    if opt.get('validation_metric_mode') is None:
        opt['validation_metric_mode'] = 'max'
    self.last_valid_epoch = 0
    self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
    self.best_valid = None
    if opt.get('model_file') and os.path.isfile(opt['model_file'] + '.best_valid'):
        with open(opt['model_file'] + '.best_valid', 'r') as f:
            self.best_valid = float(f.readline())
    self.impatience = 0
    self.saved = False
    self.valid_world = None
    self.opt = opt
    # we may have been preempted, make sure we note that amount
    self._preempted_epochs = 0.0
    if (opt.get('model_file')
            and os.path.isfile(opt['model_file'] + trainstats_suffix)):
        # looks like we were preempted. make sure we load up our total
        # training stats, etc
        with open(opt['model_file'] + trainstats_suffix) as ts:
            obj = json.load(ts)
            self._preempted_epochs = obj.get('total_epochs', 0)
            self.train_time.total = obj.get('train_time', 0)
            self.impatience = obj.get('impatience', 0)
    if opt['tensorboard_log'] is True:
        self.writer = TensorboardLogger(opt)
class TrainLoop():
    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions
        # are not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)
        if isinstance(opt, ParlaiParser):
            print('[ Deprecated Warning: TrainLoop should be passed opt not Parser ]')
            opt = opt.parse_args()
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if (opt['load_from_checkpoint'] and opt.get('model_file')
                and os.path.isfile(opt['model_file'] + '.checkpoint')):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            # If data built via pytorch data teacher, we need to load prebuilt dict
            if opt.get('pytorch_teacher_task'):
                opt['dict_file'] = get_pyt_dict_file(opt)
            elif opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)  # specify model such as seq2seq
        self.world = create_task(opt, self.agent)  # batch world or other world
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys = 0
        self.max_num_epochs = (
            opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf')
        )
        self.max_train_time = (
            opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf')
        )
        self.log_every_n_secs = (
            opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')
        )
        self.val_every_n_secs = (
            opt['validation_every_n_secs']
            if opt['validation_every_n_secs'] > 0
            else float('inf')
        )
        self.save_every_n_secs = (
            opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf')
        )
        self.val_every_n_epochs = (
            opt['validation_every_n_epochs']
            if opt['validation_every_n_epochs'] > 0
            else float('inf')
        )
        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'
        self.last_valid_epoch = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.best_valid = None
        if opt.get('model_file') and os.path.isfile(opt['model_file'] + '.best_valid'):
            with open(opt['model_file'] + '.best_valid', 'r') as f:
                self.best_valid = float(f.readline())
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt
        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if (opt.get('model_file')
                and os.path.isfile(opt['model_file'] + trainstats_suffix)):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)
        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)

    def save_model(self, suffix=None):
        if not is_primary_worker():
            # never do IO as a non-primary worker
            return
        if not self.opt.get('model_file'):
            # nothing to save to, just exit
            return
        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        while True:
            # don't ever let a ctrl-c interrupt saving
            try:
                self.agent.save(fn)
                self._save_train_stats(suffix)
                break
            except KeyboardInterrupt:
                pass

    def _save_train_stats(self, suffix=None):
        fn = self.opt['model_file']
        if suffix:
            fn += suffix
        fn += '.trainstats'
        with open(fn, 'w') as f:
            json.dump(
                {
                    'train_time': self.train_time.time(),
                    'total_epochs': (
                        self._preempted_epochs
                        + num_workers() * self.world.get_total_epochs()
                    ),
                    'impatience': self.impatience,
                },
                f,
            )

    def validate(self):
        opt = self.opt

        if self.valid_world is None:
            # we need to load the world now
            self.valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')

        # run evaluation on valid set
        valid_report = sync_object(
            run_eval(self.valid_world, opt, 'valid', opt['validation_max_exs'], True)
        )

        # logging
        if opt['tensorboard_log'] is True and is_primary_worker():
            self.writer.add_metrics('valid', int(self.train_time.time()), valid_report)

        # saving
        if (opt.get('model_file') and opt.get('save_after_valid')
                and is_primary_worker()):
            print("[ saving model checkpoint: " + opt['model_file'] + ".checkpoint ]")
            self.save_model('.checkpoint')

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        if '/' in opt['validation_metric']:
            # if you are multitasking and want your validation metric to be
            # a metric specific to a subtask, specify your validation metric
            # as -vmt subtask/metric
            subtask = opt['validation_metric'].split('/')[0]
            validation_metric = opt['validation_metric'].split('/')[1]
            new_valid = valid_report['tasks'][subtask][validation_metric]
        else:
            new_valid = valid_report[opt['validation_metric']]

        # check if this is the best validation so far
        if (self.best_valid is None
                or self.valid_optim * new_valid > self.valid_optim * self.best_valid):
            print('[ new best {}: {}{} ]'.format(
                opt['validation_metric'], new_valid,
                ' (previous best was {})'.format(self.best_valid)
                if self.best_valid is not None else ''))
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file') and is_primary_worker():
                print("[ saving best valid model: " + opt['model_file'] + " ]")
                self.save_model()
                print("[ saving best valid metric: "
                      + opt['model_file'] + ".best_valid ]")
                save_best_valid(opt['model_file'], self.best_valid)
                self.saved = True
            if (opt['validation_metric'] == 'accuracy'
                    and self.best_valid >= opt['validation_cutoff']):
                print('[ task solved! stopping. ]')
                return True
        else:
            self.impatience += 1
            print('[ did not beat best {}: {} impatience: {} ]'.format(
                opt['validation_metric'], round(self.best_valid, 4),
                self.impatience))
        self.validate_time.reset()

        # check if we are out of patience
        if (opt['validation_patience'] > 0
                and self.impatience >= opt['validation_patience']):
            print('[ ran out of patience! stopping training. ]')
            return True
        return False

    def _average_dicts(self, all_versions):
        # instead of a list-of-dicts with like keys, make a dict-of-lists with
        # keys to reduce
        to_reduce = {}
        for d in all_versions:
            for k, v in d.items():
                to_reduce.setdefault(k, []).append(v)
        # now perform the reduction
        finalized = {}
        for k, values in to_reduce.items():
            if k == 'exs' or k == 'total_skipped_batches':
                # sum across workers
                finalized[k] = np.sum(values)
            elif isinstance(values[0], dict):
                # do the same procedure recursively
                finalized[k] = self._average_dicts(values)
            else:
                # all other cases, take the mean across the workers
                finalized[k] = np.mean(values)
        return finalized

    def _sync_training_metrics(self, metrics):
        """
        Sync training metrics across workers.

        A handful of special cases are handled as exceptions, and the
        remaining metrics are simply averaged across workers.
        """
        if not is_distributed():
            # nothing special needed
            return metrics
        all_versions = all_gather_list(metrics)
        return self._average_dicts(all_versions)

    def _nice_format(self, dictionary):
        rounded = {}
        for k, v in dictionary.items():
            if isinstance(v, dict):
                rounded[k] = self._nice_format(v)
            elif isinstance(v, float):
                rounded[k] = round_sigfigs(v, 4)
            else:
                rounded[k] = v
        return rounded

    def _compute_eta(self, epochs_completed, time_elapsed):
        """
        Compute the estimated seconds remaining in training.

        :param float epochs_completed: number of epochs already completed.
        :param float time_elapsed: total time spent already, in seconds.
        :return: ETA in seconds, or None if not computable
        """
        # start off with no estimate
        eta = None

        # Determine time_left and num_epochs
        max_epochs = self.opt.get('num_epochs', 0)
        if max_epochs > 0 and epochs_completed > 0:
            epoch_progress = epochs_completed / max_epochs
            eta = (1 - epoch_progress) * time_elapsed / epoch_progress

        max_training_time = self.opt.get('max_training_time', -1)
        if max_training_time > 0:
            time_left = max_training_time - time_elapsed
            if eta is None or time_left < eta:
                eta = time_left

        return eta

    def log(self):
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self._sync_training_metrics(self.world.report())
        self.world.reset_metrics()

        # time elapsed
        logs.append('time:{}s'.format(np.floor(self.train_time.time())))
        logs.append('total_exs:{}'.format(self._total_exs))

        if self._total_epochs >= 0:
            # only if it's unbounded
            logs.append('epochs:{}'.format(round(self._total_epochs, 2)))

        time_left = self._compute_eta(self._total_epochs, self.train_time.time())
        if time_left is not None:
            logs.append('time_left:{}s'.format(max(0, np.ceil(time_left))))

        log = '[ {} ] {}'.format(' '.join(logs), self._nice_format(train_report))
        print(log)
        self.log_time.reset()

        if opt['tensorboard_log'] is True and is_primary_worker():
            self.writer.add_metrics('train', self._total_exs, train_report)

    def train(self):
        if is_distributed():
            warn_once(
                "Distributed training outputs average-per-worker metrics during "
                "training, and may be slightly distorted. Validation/test are "
                "unadulterated.")
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                world.parley()
                self.parleys += 1
                # print(world.display())

                # get the total training examples done, compute epochs
                self._total_epochs = (
                    self._preempted_epochs
                    + num_workers() * self.world.get_total_epochs()
                )
                exs_per_epoch = self.world.num_examples()
                self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object(
                    (self.train_time.time(), self.log_time.time(),
                     self.validate_time.time()))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print('[ num_epochs completed:{} time elapsed:{}s ]'.format(
                        self.max_num_epochs, train_time))
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch
                        >= self.val_every_n_epochs):
                    stop_training = self.validate()
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    print("[ saving model checkpoint: {}.checkpoint ]".format(
                        opt['model_file']))
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')
        v_report = run_eval(valid_world, opt, 'valid', write_log=True)
        test_world = _maybe_load_eval_world(self.agent, opt, 'test')
        t_report = run_eval(test_world, opt, 'test', write_log=True)
        if valid_world:
            valid_world.shutdown()
        if test_world:
            test_world.shutdown()

        return v_report, t_report
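# To make the reduction in _average_dicts concrete, a small standalone check
# (same logic as the method above, written as a free function):
import numpy as np

def average_dicts(all_versions):
    to_reduce = {}
    for d in all_versions:
        for k, v in d.items():
            to_reduce.setdefault(k, []).append(v)
    finalized = {}
    for k, values in to_reduce.items():
        if k in ('exs', 'total_skipped_batches'):
            finalized[k] = np.sum(values)         # example counts are summed
        elif isinstance(values[0], dict):
            finalized[k] = average_dicts(values)  # recurse into sub-reports
        else:
            finalized[k] = np.mean(values)        # everything else is averaged
    return finalized

# two workers' reports: 'exs' sums to 24, 'loss' averages to 3.0
assert average_dicts([{'exs': 10, 'loss': 2.0}, {'exs': 14, 'loss': 4.0}]) == \
    {'exs': 24, 'loss': 3.0}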
def setup_args(parser=None) -> ParlaiParser:
    """
    Build the ParlAI parser, adding command line args if necessary.

    :param ParlaiParser parser:
        Preexisting parser to append options to. Will be created if needed.

    :returns: the ParlaiParser with CLI options added.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Train a model')
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument(
        '-et',
        '--evaltask',
        help='task to use for valid/test (defaults to the one used for training)',
    )
    train.add_argument(
        '--eval-batchsize',
        type=int,
        hidden=True,
        help='Eval time batch size (defaults to same as -bs)',
    )
    train.add_argument('--display-examples', type='bool', default=False, hidden=True)
    train.add_argument('-eps', '--num-epochs', type=float, default=-1)
    train.add_argument('-ttim', '--max-train-time', type=float, default=-1)
    train.add_argument('-ltim', '--log-every-n-secs', type=float, default=10)
    train.add_argument(
        '-vtim',
        '--validation-every-n-secs',
        type=float,
        default=-1,
        help='Validate every n seconds. Saves model to model_file '
             '(if set) whenever best val metric is found',
    )
    train.add_argument(
        '-stim',
        '--save-every-n-secs',
        type=float,
        default=-1,
        help='Saves the model to model_file.checkpoint after '
             'every n seconds (default -1, never).',
    )
    train.add_argument(
        '-sval',
        '--save-after-valid',
        type='bool',
        default=False,
        help='Saves the model to model_file.checkpoint after '
             'every validation (default %(default)s).',
    )
    train.add_argument(
        '-veps',
        '--validation-every-n-epochs',
        type=float,
        default=-1,
        help='Validate every n epochs. Saves model to model_file '
             '(if set) whenever best val metric is found',
    )
    train.add_argument(
        '-vme',
        '--validation-max-exs',
        type=int,
        default=-1,
        hidden=True,
        help='max examples to use during validation (default -1 uses all)',
    )
    train.add_argument(
        '--short-final-eval',
        default=False,
        hidden=True,
        type='bool',
        help='If true, obeys --validation-max-exs in the final '
             'validation and test evaluations.',
    )
    train.add_argument(
        '-vp',
        '--validation-patience',
        type=int,
        default=10,
        help='number of iterations of validation where result '
             'does not improve before we stop training',
    )
    train.add_argument(
        '-vmt',
        '--validation-metric',
        default='accuracy',
        help='key into report table for selecting best validation',
    )
    train.add_argument(
        '-vmm',
        '--validation-metric-mode',
        type=str,
        choices=['max', 'min'],
        help='how to optimize validation metric (max or min)',
    )
    train.add_argument(
        '-vcut',
        '--validation-cutoff',
        type=float,
        default=1.0,
        hidden=True,
        help='value at which training will stop if exceeded by metric',
    )
    train.add_argument(
        '-lfc',
        '--load-from-checkpoint',
        type='bool',
        default=True,
        hidden=True,
        help='load model from checkpoint if available',
    )
    train.add_argument(
        '-vshare',
        '--validation-share-agent',
        default=False,
        hidden=True,
        help='use a shared copy of the agent for validation. '
             'this will eventually default to True, but '
             'currently defaults to False.',
    )
    train.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default, '
             'or give a list split by , like ppl,f1,accuracy,hits@1,rouge,bleu. '
             'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l.',
    )
    train.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='Report micro-averaged metrics instead of macro averaged metrics.',
        recommended=False,
    )
    TensorboardLogger.add_cmdline_args(parser)
    parser = setup_dict_args(parser)
    return parser
def __init__(self, opt, shared=None): opt = copy.deepcopy(opt) super().__init__(opt, shared) self.use_cuda = not opt['no_cuda'] and torch.cuda.is_available() self.is_combine_attr = (hasattr(self, 'other_task_datafiles') and self.other_task_datafiles) self.random_policy = opt.get('random_policy', False) self.count_sample = opt.get('count_sample', False) self.anti = opt.get('anti', False) if self.random_policy: random.seed(17) if not shared: if not self.stream and opt.get('pace_by', 'sample') == 'bucket': score_list = [episode[0][2] for episode in self.data.data] assert score_list == sorted(score_list) num_buckets = opt.get('num_buckets', int(self.num_episodes() / 10)) lb_indices = [ int(len(score_list) * i / num_buckets) for i in range(num_buckets) ] lbs = [score_list[idx] for idx in lb_indices] bucket_ids = [ self.sort_into_bucket(ctrl_val, lbs) for ctrl_val in score_list ] bucket_cnt = [0 for _ in range(num_buckets)] for i in range(num_buckets): bucket_cnt[i] = bucket_ids.count(i) self.bucket_cnt = bucket_cnt self.lastYs = [None] * self.bsz # build multiple task data self.tasks = [self.data] if self.is_combine_attr: print('[ build multiple task data ... ]') for datafile in self.other_task_datafiles: task_opt = copy.deepcopy(opt) task_opt['datafile'] = datafile self.tasks.append( DialogData(task_opt, data_loader=self.setup_data, cands=self.label_candidates())) print('[ build multiple task data done! ]') # record the selections of each subtasks self.subtasks = opt['subtasks'].split(':') self.subtask_counter = OrderedDict() self.p_selections = OrderedDict() self.c_selections = OrderedDict() for t in self.subtasks: self.subtask_counter[t] = 0 self.p_selections[t] = [] self.c_selections[t] = [] if self.count_sample and not self.stream: self.sample_counter = OrderedDict() for idx, t in enumerate(self.subtasks): self.sample_counter[t] = [ 0 for _ in self.tasks[idx].data ] # setup the tensorboard log if opt['tensorboard_log_teacher'] is True: opt['tensorboard_tag'] = 'task' teacher_metrics = 'reward,policy_loss,critic_loss,mean_advantage_reward,action_ent'.split( ',') opt['tensorboard_metrics'] = ','.join( opt['tensorboard_metrics'].split(',') + teacher_metrics) self.writer = TensorboardLogger(opt) else: self.lastYs = shared['lastYs'] self.tasks = shared['tasks'] if not self.stream and opt.get('pace_by', 'sample') == 'bucket': self.bucket_cnt = shared['bucket_cnt'] if 'writer' in shared: self.writer = shared['writer'] if 'subtask_counter' in shared: self.subtask_counter = shared['subtask_counter'] if 'p_selections' in shared: self.p_selections = shared['p_selections'] if 'c_selections' in shared: self.c_selections = shared['c_selections'] # build the policy net, criterion and optimizer here self.state_dim = 32 + len(self.tasks) # hand-craft features self.action_dim = len(self.tasks) if not shared: self.policy = PolicyNet(self.state_dim, self.action_dim) self.critic = CriticNet(self.state_dim, self.action_dim) init_teacher = get_init_teacher(opt, shared) if init_teacher is not None: # load teacher parameters if available print('[ Loading existing teacher params from {} ]' ''.format(init_teacher)) states = self.load(init_teacher) else: states = {} else: self.policy = shared['policy'] self.critic = shared['critic'] states = shared['states'] if ( # only build an optimizer if we're training 'train' in opt.get('datatype', '') and # and this is the main model shared is None): # for policy net self.optimizer = self.init_optim( [p for p in self.policy.parameters() if p.requires_grad], 
lr=opt['learningrate_teacher'], optim_states=states.get('optimizer'), saved_optim_type=states.get('optimizer_type')) self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( self.optimizer, 'min', factor=0.8, # 0.5 --> 0.8 patience=5, # 3 -- > 5 verbose=True) if 'lr_scheduler' in states: self.scheduler.load_state_dict(states['lr_scheduler']) # for critic net self.optimizer_critic = self.init_optim( [p for p in self.critic.parameters() if p.requires_grad], lr=opt['learningrate_teacher_critic'], optim_states=states.get('optimizer_critic'), saved_optim_type=states.get('optimizer_type')) self.scheduler_critic = optim.lr_scheduler.ReduceLROnPlateau( self.optimizer_critic, 'min', factor=0.8, # 0.5 --> 0.8 patience=5, # 3 -- > 5 verbose=True) if 'lr_scheduler_critic' in states: self.scheduler_critic.load_state_dict( states['lr_scheduler_critic']) self.critic_criterion = torch.nn.SmoothL1Loss() self.reward_metric = opt.get('reward_metric', 'total_metric') self.reward_metric_mode = opt.get('reward_metric_mode', 'max') self.prev_prev_valid_report = states[ 'prev_prev_valid_report'] if 'prev_prev_valid_report' in states else None self.prev_valid_report = states[ 'prev_valid_report'] if 'prev_valid_report' in states else None self.current_valid_report = states[ 'current_valid_report'] if 'current_valid_report' in states else None self.saved_actions = states[ 'saved_actions'] if 'saved_actions' in states else OrderedDict() self.saved_state_actions = states[ 'saved_state_actions'] if 'saved_state_actions' in states else OrderedDict( ) if self.use_cuda: for k, v in self.saved_actions.items(): self.saved_actions[k] = v.cuda() for k, v in self.saved_state_actions.items(): self.saved_state_actions[k] = v.cuda() self._number_teacher_updates = states[ '_number_teacher_updates'] if '_number_teacher_updates' in states else 0 # enable the batch_act self.use_batch_act = self.bsz > 1 self.T = self.opt.get('T', 1000) self.c0 = self.opt.get('c0', 0.01) self.p = self.opt.get('p', 2) # setup the timer self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 \ else float('inf') self.action_log_time = Timer() self.move_to_cuda()
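# Self-contained sketch of the two-optimizer setup built above, with PolicyNet
# and CriticNet replaced by toy nn.Linear modules. The dimensions are assumed
# example values (34 = 32 hand-crafted state features + 2 tasks).
import torch
import torch.nn as nn
import torch.optim as optim

policy = nn.Linear(34, 2)   # stand-in for PolicyNet(state_dim, action_dim)
critic = nn.Linear(36, 1)   # stand-in for CriticNet(state_dim, action_dim)

opt_policy = optim.SGD([p for p in policy.parameters() if p.requires_grad], lr=0.01)
opt_critic = optim.SGD([p for p in critic.parameters() if p.requires_grad], lr=0.01)

# same factor/patience as above: shrink lr by 0.8 after 5 non-improving steps
sched_policy = optim.lr_scheduler.ReduceLROnPlateau(opt_policy, 'min', factor=0.8, patience=5)

for step in range(3):
    loss = policy(torch.randn(1, 34)).abs().sum()
    opt_policy.zero_grad()
    loss.backward()
    opt_policy.step()
    sched_policy.step(loss.item())  # the scheduler watches the loss it should minimize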
class DefaultTeacher(FbDialogTeacher): def __init__(self, opt, shared=None): opt = copy.deepcopy(opt) super().__init__(opt, shared) self.use_cuda = not opt['no_cuda'] and torch.cuda.is_available() self.is_combine_attr = (hasattr(self, 'other_task_datafiles') and self.other_task_datafiles) self.random_policy = opt.get('random_policy', False) self.count_sample = opt.get('count_sample', False) self.anti = opt.get('anti', False) if self.random_policy: random.seed(17) if not shared: if not self.stream and opt.get('pace_by', 'sample') == 'bucket': score_list = [episode[0][2] for episode in self.data.data] assert score_list == sorted(score_list) num_buckets = opt.get('num_buckets', int(self.num_episodes() / 10)) lb_indices = [ int(len(score_list) * i / num_buckets) for i in range(num_buckets) ] lbs = [score_list[idx] for idx in lb_indices] bucket_ids = [ self.sort_into_bucket(ctrl_val, lbs) for ctrl_val in score_list ] bucket_cnt = [0 for _ in range(num_buckets)] for i in range(num_buckets): bucket_cnt[i] = bucket_ids.count(i) self.bucket_cnt = bucket_cnt self.lastYs = [None] * self.bsz # build multiple task data self.tasks = [self.data] if self.is_combine_attr: print('[ build multiple task data ... ]') for datafile in self.other_task_datafiles: task_opt = copy.deepcopy(opt) task_opt['datafile'] = datafile self.tasks.append( DialogData(task_opt, data_loader=self.setup_data, cands=self.label_candidates())) print('[ build multiple task data done! ]') # record the selections of each subtasks self.subtasks = opt['subtasks'].split(':') self.subtask_counter = OrderedDict() self.p_selections = OrderedDict() self.c_selections = OrderedDict() for t in self.subtasks: self.subtask_counter[t] = 0 self.p_selections[t] = [] self.c_selections[t] = [] if self.count_sample and not self.stream: self.sample_counter = OrderedDict() for idx, t in enumerate(self.subtasks): self.sample_counter[t] = [ 0 for _ in self.tasks[idx].data ] # setup the tensorboard log if opt['tensorboard_log_teacher'] is True: opt['tensorboard_tag'] = 'task' teacher_metrics = 'reward,policy_loss,critic_loss,mean_advantage_reward,action_ent'.split( ',') opt['tensorboard_metrics'] = ','.join( opt['tensorboard_metrics'].split(',') + teacher_metrics) self.writer = TensorboardLogger(opt) else: self.lastYs = shared['lastYs'] self.tasks = shared['tasks'] if not self.stream and opt.get('pace_by', 'sample') == 'bucket': self.bucket_cnt = shared['bucket_cnt'] if 'writer' in shared: self.writer = shared['writer'] if 'subtask_counter' in shared: self.subtask_counter = shared['subtask_counter'] if 'p_selections' in shared: self.p_selections = shared['p_selections'] if 'c_selections' in shared: self.c_selections = shared['c_selections'] # build the policy net, criterion and optimizer here self.state_dim = 32 + len(self.tasks) # hand-craft features self.action_dim = len(self.tasks) if not shared: self.policy = PolicyNet(self.state_dim, self.action_dim) self.critic = CriticNet(self.state_dim, self.action_dim) init_teacher = get_init_teacher(opt, shared) if init_teacher is not None: # load teacher parameters if available print('[ Loading existing teacher params from {} ]' ''.format(init_teacher)) states = self.load(init_teacher) else: states = {} else: self.policy = shared['policy'] self.critic = shared['critic'] states = shared['states'] if ( # only build an optimizer if we're training 'train' in opt.get('datatype', '') and # and this is the main model shared is None): # for policy net self.optimizer = self.init_optim( [p for p in self.policy.parameters() if 
p.requires_grad], lr=opt['learningrate_teacher'], optim_states=states.get('optimizer'), saved_optim_type=states.get('optimizer_type')) self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( self.optimizer, 'min', factor=0.8, # 0.5 --> 0.8 patience=5, # 3 -- > 5 verbose=True) if 'lr_scheduler' in states: self.scheduler.load_state_dict(states['lr_scheduler']) # for critic net self.optimizer_critic = self.init_optim( [p for p in self.critic.parameters() if p.requires_grad], lr=opt['learningrate_teacher_critic'], optim_states=states.get('optimizer_critic'), saved_optim_type=states.get('optimizer_type')) self.scheduler_critic = optim.lr_scheduler.ReduceLROnPlateau( self.optimizer_critic, 'min', factor=0.8, # 0.5 --> 0.8 patience=5, # 3 -- > 5 verbose=True) if 'lr_scheduler_critic' in states: self.scheduler_critic.load_state_dict( states['lr_scheduler_critic']) self.critic_criterion = torch.nn.SmoothL1Loss() self.reward_metric = opt.get('reward_metric', 'total_metric') self.reward_metric_mode = opt.get('reward_metric_mode', 'max') self.prev_prev_valid_report = states[ 'prev_prev_valid_report'] if 'prev_prev_valid_report' in states else None self.prev_valid_report = states[ 'prev_valid_report'] if 'prev_valid_report' in states else None self.current_valid_report = states[ 'current_valid_report'] if 'current_valid_report' in states else None self.saved_actions = states[ 'saved_actions'] if 'saved_actions' in states else OrderedDict() self.saved_state_actions = states[ 'saved_state_actions'] if 'saved_state_actions' in states else OrderedDict( ) if self.use_cuda: for k, v in self.saved_actions.items(): self.saved_actions[k] = v.cuda() for k, v in self.saved_state_actions.items(): self.saved_state_actions[k] = v.cuda() self._number_teacher_updates = states[ '_number_teacher_updates'] if '_number_teacher_updates' in states else 0 # enable the batch_act self.use_batch_act = self.bsz > 1 self.T = self.opt.get('T', 1000) self.c0 = self.opt.get('c0', 0.01) self.p = self.opt.get('p', 2) # setup the timer self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 \ else float('inf') self.action_log_time = Timer() self.move_to_cuda() def move_to_cuda(self): if self.use_cuda: self.policy.cuda() self.critic.cuda() @classmethod def optim_opts(self): """ Fetch optimizer selection. By default, collects everything in torch.optim, as well as importing: - qhm / qhmadam if installed from github.com/facebookresearch/qhoptim Override this (and probably call super()) to add your own optimizers. """ # first pull torch.optim in optims = { k.lower(): v for k, v in optim.__dict__.items() if not k.startswith('__') and k[0].isupper() } try: import apex.optimizers.fused_adam as fused_adam optims['fused_adam'] = fused_adam.FusedAdam except ImportError: pass try: # https://openreview.net/pdf?id=S1fUpoR5FQ from qhoptim.pyt import QHM, QHAdam optims['qhm'] = QHM optims['qhadam'] = QHAdam except ImportError: # no QHM installed pass return optims def init_optim(self, params, lr, optim_states=None, saved_optim_type=None): """ Initialize optimizer with teacher parameters. 
:param params: parameters from the teacher :param optim_states: optional argument providing states of optimizer to load :param saved_optim_type: type of optimizer being loaded, if changed will skip loading optimizer states """ opt = self.opt # set up optimizer args kwargs = {'lr': lr} if opt.get('momentum_teacher') > 0 and opt['optimizer_teacher'] in [ 'sgd', 'rmsprop', 'qhm' ]: # turn on momentum for optimizers that use it kwargs['momentum'] = opt['momentum_teacher'] if opt['optimizer_teacher'] == 'sgd' and opt.get( 'nesterov_teacher', True): # for sgd, maybe nesterov kwargs['nesterov'] = opt.get('nesterov_teacher', True) elif opt['optimizer_teacher'] == 'qhm': # qhm needs a nu kwargs['nu'] = opt.get('nus_teacher', (0.7, ))[0] elif opt['optimizer_teacher'] == 'adam': # turn on amsgrad for adam # amsgrad paper: https://openreview.net/forum?id=ryQu7f-RZ kwargs['amsgrad'] = True elif opt['optimizer_teacher'] == 'qhadam': # set nus for qhadam kwargs['nus'] = opt.get('nus_teacher', (0.7, 1.0)) if opt['optimizer_teacher'] in [ 'adam', 'sparseadam', 'adamax', 'qhadam' ]: # set betas for optims that use it kwargs['betas'] = opt.get('betas_teacher', (0.9, 0.999)) optim_class = self.optim_opts()[opt['optimizer_teacher']] optimizer = optim_class(params, **kwargs) if optim_states: if saved_optim_type != opt['optimizer_teacher']: print('WARNING: not loading optim state since optim class ' 'changed.') else: try: optimizer.load_state_dict(optim_states) except ValueError: print('WARNING: not loading optim state since model ' 'params changed.') if self.use_cuda: for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.cuda() return optimizer def load(self, path): """ Return opt and teacher states. TODO: load behaviors should be consistent with function state_dict(). """ states = torch.load(path, map_location=lambda cpu, _: cpu) if 'policy' in states: self.policy.load_state_dict(states['policy']) if 'critic' in states: self.critic.load_state_dict(states['critic']) if 'optimizer' in states and hasattr(self, 'optimizer'): self.optimizer.load_state_dict(states['optimizer']) if 'optimizer_critic' in states and hasattr(self, 'optimizer_critic'): self.optimizer_critic.load_state_dict(states['optimizer_critic']) return states def share(self): shared = super().share() if hasattr(self, 'bucket_cnt'): shared['bucket_cnt'] = self.bucket_cnt shared['tasks'] = self.tasks shared['policy'] = self.policy shared['critic'] = self.critic shared['states'] = { 'optimizer_type': self.opt['optimizer_teacher'], 'prev_prev_valid_report': self.prev_prev_valid_report, 'prev_valid_report': self.prev_valid_report, 'current_valid_report': self.current_valid_report, 'saved_actions': self.saved_actions, 'saved_state_actions': self.saved_state_actions, } if hasattr(self, 'writer'): shared['writer'] = self.writer if hasattr(self, 'subtask_counter'): shared['subtask_counter'] = self.subtask_counter if hasattr(self, 'p_selections'): shared['p_selections'] = self.p_selections if hasattr(self, 'c_selections'): shared['c_selections'] = self.c_selections return shared @staticmethod def sort_into_bucket(val, bucket_lbs): """ Returns the highest bucket such that val >= lower bound for that bucket. Inputs: val: float. The value to be sorted into a bucket. bucket_lbs: list of floats, sorted ascending. Returns: bucket_id: int in range(num_buckets); the bucket that val belongs to. 
""" num_buckets = len(bucket_lbs) for bucket_id in range(num_buckets - 1, -1, -1): # iterate descending lb = bucket_lbs[bucket_id] if val >= lb: return bucket_id raise ValueError('val %f is not >= any of the lower bounds: %s' % (val, bucket_lbs)) def pace_function(self, states, sum_num, T=1000, c0=0.01, p=2): train_step = states['train_step'] progress = self.root_p_pace(train_step, T, c0, p) return int(sum_num * progress) @staticmethod def root_p_pace(timestep, T=1000, c0=0.01, p=2): root_p = math.pow( timestep * (1 - math.pow(c0, p)) / T + math.pow(c0, p), 1.0 / p) return min(1.0, root_p) def act(self, observation=None, task_idx=0): """Send new dialog message.""" if not hasattr(self, 'epochDone'): # reset if haven't yet self.reset() # get next example, action is episode_done dict if already out of exs action, self.epochDone = self.next_example(observation=observation, task_idx=task_idx) action['id'] = self.getID() # remember correct answer if available self.lastY = action.get('labels', action.get('eval_labels', None)) if ((not self.datatype.startswith('train') or 'evalmode' in self.datatype) and 'labels' in action): # move labels to eval field so not used for training # but this way the model can use the labels for perplexity or loss action = action.copy() labels = action.pop('labels') if not self.opt.get('hide_labels', False): action['eval_labels'] = labels return action def _cry_for_missing_in_obs(self, something): raise RuntimeError( "{} is needed to include in observations to build states!".format( something)) def _build_states(self, observations): for key in ['train_step', 'train_report', 'loss_desc', 'prob_desc']: if key not in observations[0]: self._cry_for_missing_in_obs(key) train_step = observations[0]['train_step'] # scala train_step = min(train_step / self.T, 1) train_report = observations[0]['train_report'] nll_loss = train_report.get('nll_loss', 0) / 10 # scala loss_desc = observations[0]['loss_desc'] loss_desc = F.normalize(loss_desc, p=2, dim=-1) prob_desc = observations[0]['prob_desc'] prob_desc = F.normalize(prob_desc, p=2, dim=-1) if hasattr(self, 'subtask_counter'): subtask_progress = self.subtask_counter.values() max_min = max(subtask_progress) - min(subtask_progress) subtask_progress = [ (item - min(subtask_progress)) / max_min if max_min > 0 else 0 for item in subtask_progress ] else: subtask_progress = [0] subtask_progress = torch.FloatTensor(subtask_progress) if self.use_cuda: subtask_progress = subtask_progress.cuda() prev_valid_report = self.prev_valid_report if prev_valid_report is None: prev_valid_report = {} bleu = prev_valid_report.get('bleu', 0) valid_nll_loss = prev_valid_report.get('nll_loss', 0) / 10 dist_1_ratio = prev_valid_report.get('dist_1_ratio', 0) dist_2_ratio = prev_valid_report.get('dist_2_ratio', 0) dist_3_ratio = prev_valid_report.get('dist_3_ratio', 0) embed_avg = prev_valid_report.get('embed_avg', 0) embed_greedy = prev_valid_report.get('embed_greedy', 0) embed_extrema = prev_valid_report.get('embed_extrema', 0) embed_coh = prev_valid_report.get('embed_coh', 0) intra_dist_1 = prev_valid_report.get('intra_dist_1', 0) / 10 intra_dist_2 = prev_valid_report.get('intra_dist_2', 0) / 10 intra_dist_3 = prev_valid_report.get('intra_dist_3', 0) / 10 response_length = prev_valid_report.get( 'response_length', 0) / self.opt.get('label_truncate', 100) # sent_entropy_uni = prev_valid_report.get('sent_entropy_uni', 0) / 100 # sent_entropy_bi = prev_valid_report.get('sent_entropy_bi', 0) / 100 # sent_entropy_tri = prev_valid_report.get('sent_entropy_tri', 
0) / 100 word_entropy_uni = prev_valid_report.get('word_entropy_uni', 0) / 100 word_entropy_bi = prev_valid_report.get('word_entropy_bi', 0) / 100 word_entropy_tri = prev_valid_report.get('word_entropy_tri', 0) / 100 states = torch.FloatTensor([ train_step, nll_loss, bleu, valid_nll_loss, dist_1_ratio, dist_2_ratio, dist_3_ratio, embed_avg, embed_greedy, embed_extrema, embed_coh, intra_dist_1, intra_dist_2, intra_dist_3, response_length, # sent_entropy_uni, sent_entropy_bi, sent_entropy_tri, word_entropy_uni, word_entropy_bi, word_entropy_tri ]) if self.use_cuda: states = states.cuda() states = torch.cat([states, loss_desc, prob_desc, subtask_progress], dim=-1).unsqueeze(dim=0) return states def __uniform_weights(self): w = 1 / len(self.tasks) weights = torch.FloatTensor([w] * len(self.tasks)) if self.use_cuda: weights = weights.cuda() return weights.unsqueeze(dim=0) def __load_training_batch(self, observations): if observations and len( observations) > 0 and observations[0] and self.is_combine_attr: if not self.random_policy: with torch.no_grad(): current_states = self._build_states(observations) action_probs = self.policy(current_states) if self.action_log_time.time() > self.log_every_n_secs and len( self.tasks) > 1: with torch.no_grad(): # log the action distributions action_p = ','.join([ str(round_sigfigs(x, 4)) for x in action_probs[0].data.tolist() ]) log = '[ {} {} ]'.format('Action probs:', action_p) print(log) self.action_log_time.reset() sample_from = Categorical(action_probs[0]) action = sample_from.sample() train_step = observations[0]['train_step'] self.saved_actions[train_step] = sample_from.log_prob(action) self.saved_state_actions[train_step] = torch.cat( [current_states, action_probs], dim=1) selected_task = action.item() self.subtask_counter[self.subtasks[selected_task]] += 1 probs = action_probs[0].tolist() selection_report = {} for idx, t in enumerate(self.subtasks): selection_report['p_{}'.format(t)] = probs[idx] self.p_selections[t].append(probs[idx]) selection_report['c_{}'.format( t)] = self.subtask_counter[t] self.c_selections[t].append(self.subtask_counter[t]) self.writer.add_metrics(setting='Teacher/task_selection', step=train_step, report=selection_report) else: selected_task = random.choice(range(len(self.tasks))) self.subtask_counter[self.subtasks[selected_task]] += 1 else: selected_task = 0 return self.__load_batch(observations, task_idx=selected_task) def __load_batch(self, observations, task_idx=0): if observations is None: observations = [None] * self.bsz bsz = len(observations) batch = [] # Sample from multiple tasks using the policy net for idx in range(bsz): batch.append(self.act(observations[idx], task_idx=task_idx)) return batch def batch_act(self, observations): """ Returns an entire batch of examples instead of just one. 
""" if not hasattr(self, 'epochDone'): # reset if haven't yet self.reset() if self.opt['datatype'] == 'train': batch = self.__load_training_batch(observations) else: batch = self.__load_batch(observations) # pad batch if len(batch) < self.bsz: batch += [{ 'episode_done': True, 'id': self.getID() }] * (self.bsz - len(batch)) # remember correct answer if available (for padding, None) for i, ex in enumerate(batch): if 'labels' in ex: labels = ex['labels'] self.lastYs[i] = labels if not self.datatype.startswith( 'train') or 'evalmode' in self.datatype: del ex['labels'] if not self.opt.get('hide_labels', False): ex['eval_labels'] = labels else: self.lastYs[i] = ex.get('eval_labels', None) return batch def next_example(self, observation=None, task_idx=0): """ Returns the next example. If there are multiple examples in the same episode, returns the next one in that episode. If that episode is over, gets a new episode index and returns the first example of that episode. """ if self.stream: action, epoch_done = self.tasks[task_idx].get() else: if self.episode_done: self.episode_idx = self.next_episode_idx() self.entry_idx = 0 else: self.entry_idx += 1 if self.episode_idx >= self.num_episodes(): return {'episode_done': True}, True if observation is None or self.opt['datatype'] != 'train': # The first step of the training or validation mode sampled_episode_idx = self.episode_idx sampled_entry_idx = self.entry_idx else: # --------------- pick the sample according to the pace function ----------- pace_by = self.opt.get('pace_by', 'sample') if pace_by == 'sample': sum_num = self.num_episodes() elif pace_by == 'bucket': sum_num = len(self.bucket_cnt) else: raise ValueError('pace_by must be {} or {}!'.format( 'sample', 'bucket')) states4pace_func = observation if hasattr(self, 'subtask_counter'): states4pace_func = { 'train_step': self.subtask_counter[self.subtasks[task_idx]] } threshold = self.pace_function(states4pace_func, sum_num, self.T, self.c0, self.p) if pace_by == 'sample': stop_step = threshold elif pace_by == 'bucket': stop_step = sum(self.bucket_cnt[:threshold]) else: raise ValueError('pace_by must be {} or {}!'.format( 'sample', 'bucket')) stop_step = self.num_episodes( ) if stop_step > self.num_episodes() else stop_step # sampled_episode_idx = random.choice(list(range(self.num_episodes()))[:stop_step]) sampled_episode_idx = np.random.choice(stop_step) sampled_entry_idx = 0 # make sure the episode only contains one entry if self.anti: sampled_episode_idx = self.num_episodes( ) - 1 - sampled_episode_idx if self.count_sample: self.sample_counter[ self.subtasks[task_idx]][sampled_episode_idx] += 1 ex = self.get(sampled_episode_idx, sampled_entry_idx, task_idx=task_idx) if observation is None or self.opt['datatype'] != 'train': self.episode_done = ex.get('episode_done', False) if (not self.random and self.episode_done and self.episode_idx + self.opt.get("batchsize", 1) >= self.num_episodes()): epoch_done = True else: epoch_done = False else: # in the setting of curriculum leaning, samples are not uniformly # picked from the training set, so, the epoch records here make no sense. 
epoch_done = False action = ex return action, epoch_done def get(self, episode_idx, entry_idx=0, task_idx=0): return self.tasks[task_idx].get(episode_idx, entry_idx)[0] def update_params(self): self._number_teacher_updates += 1 if self.opt.get('gradient_clip_teacher', -1) > 0: torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.opt['gradient_clip_teacher']) self.optimizer.step() def update_critic_params(self): if self.opt.get('gradient_clip_teacher', -1) > 0: torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.opt['gradient_clip_teacher']) self.optimizer_critic.step() def receive_metrics(self, metrics_dict): if self.is_combine_attr and not self.random_policy: assert self.reward_metric in metrics_dict, '{} is not in the metrics_dict!'.format( self.reward_metric) self.prev_prev_valid_report = self.prev_valid_report self.prev_valid_report = self.current_valid_report self.current_valid_report = metrics_dict delt_reward = None if self.prev_prev_valid_report and self.prev_valid_report and self.current_valid_report: delt_reward1 = self.current_valid_report[ self.reward_metric] - self.prev_valid_report[ self.reward_metric] delt_reward0 = self.prev_valid_report[ self.reward_metric] - self.prev_prev_valid_report[ self.reward_metric] if self.reward_metric_mode == 'min': delt_reward1 = -delt_reward1 delt_reward0 = -delt_reward0 delt_reward = delt_reward1 / (delt_reward0 + 1e-6) - 1 if delt_reward and len(self.saved_actions) > 0 and len( self.saved_state_actions) > 0: reward = torch.clamp(torch.FloatTensor([delt_reward]), -10, 10) if self.use_cuda: reward = reward.cuda() with torch.no_grad(): batch_state_actions = torch.cat(list( self.saved_state_actions.values()), dim=0) if self.use_cuda: batch_state_actions = batch_state_actions.cuda() estimate_rewards = self.critic( batch_state_actions).squeeze() advantages = reward - estimate_rewards # rescale the rewards by ranking episode_len = len(advantages) ranks = torch.FloatTensor( list( reversed( ss.rankdata(advantages.cpu(), method='dense')))).unsqueeze(dim=1) rescaled_rewards = torch.sigmoid( 12 * (0.5 - ranks / episode_len)) rescaled_rewards = [r.item() for r in rescaled_rewards] policy_loss = [] idx = 0 for model_train_step, log_prob in self.saved_actions.items(): policy_loss.append(-log_prob.unsqueeze(dim=0) * rescaled_rewards[idx]) idx += 1 policy_loss = torch.cat(policy_loss).sum() # regularization term regarding action distribution bsz = batch_state_actions.size(0) action_probs = torch.cat(list( self.saved_state_actions.values()), dim=0).narrow(1, self.state_dim, self.action_dim) action_ent = torch.sum( -action_probs * torch.log(action_probs)) / bsz self.policy.train() self.optimizer.zero_grad() policy_loss = policy_loss + self.opt.get('reg_action', 0.001) * (-action_ent) policy_loss.backward() self.update_params() # lr_scheduler step on teacher loss policy_loss_item = policy_loss.item() if self.opt.get('optimizer_teacher', '') == 'sgd': self.scheduler.step(policy_loss_item) # training on the critic self.critic.train() self.optimizer_critic.zero_grad() batch_values = self.critic(batch_state_actions) critic_target = torch.FloatTensor(bsz, 1) critic_target = critic_target.fill_(reward.item()) if self.use_cuda: critic_target = critic_target.cuda() critic_loss = self.critic_criterion(batch_values, critic_target) critic_loss.backward() self.update_critic_params() critic_loss_item = critic_loss.item() if self.opt.get('optimizer_teacher', '') == 'sgd': self.scheduler_critic.step(critic_loss_item) # log something print( '[ reward: {}; 
mean_advantage_reward: {}; policy loss: {};' ' critic loss: {}; action ent: {}; episode length: {} ]'. format(reward.item(), np.mean(advantages.tolist()), policy_loss_item, critic_loss_item, action_ent.item(), len(self.saved_actions))) report = { 'reward': reward.item(), 'mean_advantage_reward': np.mean(advantages.tolist()), 'policy_loss': policy_loss_item, 'critic_loss': critic_loss_item, 'action_ent': action_ent.item(), } self.writer.add_metrics(setting='Teacher/receive_metrics', step=self._number_teacher_updates, report=report) # clear history actions self.saved_actions.clear() self.saved_state_actions.clear() def state_dict(self): """ Get the state dict for saving TODO: save more teacher-related states for reloading """ states = {} if hasattr(self, 'policy'): # save model params if hasattr(self.policy, 'module'): # did we wrap in a DistributedDataParallel states['policy'] = self.policy.module.state_dict() else: states['policy'] = self.policy.state_dict() if hasattr(self, 'critic'): # save model params if hasattr(self.critic, 'module'): # did we wrap in a DistributedDataParallel states['critic'] = self.critic.module.state_dict() else: states['critic'] = self.critic.state_dict() if hasattr(self, 'optimizer'): # save optimizer params states['optimizer'] = self.optimizer.state_dict() states['optimizer_type'] = self.opt['optimizer_teacher'] if hasattr(self, 'optimizer_critic'): states['optimizer_critic'] = self.optimizer_critic.state_dict() if getattr(self, 'scheduler', None): states['lr_scheduler'] = self.scheduler.state_dict() if getattr(self, 'scheduler_critic', None): states['lr_scheduler_critic'] = self.scheduler_critic.state_dict() states['prev_prev_valid_report'] = self.prev_prev_valid_report states['prev_valid_report'] = self.prev_valid_report states['current_valid_report'] = self.current_valid_report states['saved_actions'] = self.saved_actions states['saved_state_actions'] = self.saved_state_actions states['_number_teacher_updates'] = self._number_teacher_updates return states def save(self, path=None): if path: teacher_path = path else: model_file = self.opt.get('model_file', None) if model_file: teacher_path = model_file + '.teacher' else: teacher_path = None if teacher_path: states = self.state_dict() if states: with open(teacher_path, 'wb') as write: torch.save(states, write) # save opt file with open(teacher_path + '.opt', 'w', encoding='utf-8') as handle: json.dump(self.opt, handle) # for convenience of working with jq, make sure there's a newline handle.write('\n') if self.count_sample: # save sample count info for task_name, task_val in self.sample_counter.items(): with open(teacher_path + '.sample_count.{}'.format(task_name), 'w', encoding='utf-8') as f: f.write('\n'.join([str(item) for item in task_val])) self.write_selections('p_selections', teacher_path) self.write_selections('c_selections', teacher_path) def write_selections(self, selections, teacher_path): if hasattr(self, selections): with open(teacher_path + '.{}'.format(selections), 'w', encoding='utf-8') as f: f.write('\t'.join(self.subtasks)) f.write('\n') for idx in range( len(getattr(self, selections)[self.subtasks[0]])): p_line = [] for t in self.subtasks: p_line.append(str(getattr(self, selections)[t][idx])) f.write('\t'.join(p_line)) f.write('\n')
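# Self-contained sketch of the rank-based reward rescaling in receive_metrics
# above: advantages are mapped to bounded values in (0, 1) via a sigmoid over
# their dense ranks. (The original code additionally reverses the rank list
# before pairing it with the saved actions.)
import torch
import scipy.stats as ss

advantages = torch.tensor([0.3, -1.2, 0.7, 0.1])
episode_len = len(advantages)
ranks = torch.FloatTensor(ss.rankdata(advantages, method='dense'))  # 1 = smallest
rewards = torch.sigmoid(12 * (0.5 - ranks / episode_len))
# rank 1 -> sigmoid(3) ~ 0.95, rank 4 -> sigmoid(-6) ~ 0.002: low ranks map
# near 1 and high ranks near 0, independent of the advantages' absolute scale.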
class TrainLoop(): def __init__(self, opt): if isinstance(opt, ParlaiParser): print( '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]' ) opt = opt.parse_args() # Possibly load from checkpoint if opt['load_from_checkpoint'] and opt.get( 'model_file') and os.path.isfile(opt['model_file'] + '.checkpoint'): opt['init_model'] = opt['model_file'] + '.checkpoint' # Possibly build a dictionary (not all models do this). if opt['dict_build_first'] and 'dict_file' in opt: if opt['dict_file'] is None and opt.get('model_file'): opt['dict_file'] = opt['model_file'] + '.dict' print("[ building dictionary first... ]") build_dict(opt, skip_if_built=True) # Create model and assign it to the specified task self.agent = create_agent(opt) self.world = create_task(opt, self.agent) self.train_time = Timer() self.validate_time = Timer() self.log_time = Timer() self.save_time = Timer() print('[ training... ]') self.parleys = 0 self.max_num_epochs = opt[ 'num_epochs'] if opt['num_epochs'] > 0 else float('inf') self.max_train_time = opt[ 'max_train_time'] if opt['max_train_time'] > 0 else float('inf') self.log_every_n_secs = opt['log_every_n_secs'] if opt[ 'log_every_n_secs'] > 0 else float('inf') self.val_every_n_secs = opt['validation_every_n_secs'] if opt[ 'validation_every_n_secs'] > 0 else float('inf') self.save_every_n_secs = opt['save_every_n_secs'] if opt[ 'save_every_n_secs'] > 0 else float('inf') self.val_every_n_epochs = opt['validation_every_n_epochs'] if opt[ 'validation_every_n_epochs'] > 0 else float('inf') self.last_valid_epoch = 0 self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1 self.best_valid = None if opt.get('model_file') and os.path.isfile(opt['model_file'] + '.best_valid'): with open(opt['model_file'] + ".best_valid", 'r') as f: x = f.readline() self.best_valid = float(x) f.close() self.impatience = 0 self.saved = False self.valid_world = None self.opt = opt if opt['tensorboard_log'] is True: self.writer = TensorboardLogger(opt) def validate(self): opt = self.opt # run evaluation on valid set valid_report, self.valid_world = run_eval(self.agent, opt, 'valid', opt['validation_max_exs'], valid_world=self.valid_world) # logging if opt['tensorboard_log'] is True: self.writer.add_metrics('valid', int(math.floor(self.train_time.time())), valid_report) # saving if opt.get('model_file') and opt.get('save_after_valid'): print("[ saving model checkpoint: " + opt['model_file'] + ".checkpoint ]") self.agent.save(opt['model_file'] + '.checkpoint') # send valid metrics to agent if the agent wants them if hasattr(self.agent, 'receive_metrics'): self.agent.receive_metrics(valid_report) # check which metric to look at if '/' in opt['validation_metric']: # if you are multitasking and want your validation metric to be # a metric specific to a subtask, specify your validation metric # as -vmt subtask/metric subtask = opt['validation_metric'].split('/')[0] validation_metric = opt['validation_metric'].split('/')[1] new_valid = valid_report['tasks'][subtask][validation_metric] else: new_valid = valid_report[opt['validation_metric']] # check if this is the best validation so far if self.best_valid is None or self.valid_optim * new_valid > self.valid_optim * self.best_valid: print('[ new best {}: {}{} ]'.format( opt['validation_metric'], new_valid, ' (previous best was {})'.format(self.best_valid) if self.best_valid is not None else '')) self.best_valid = new_valid self.impatience = 0 if opt.get('model_file'): print("[ saving best valid model: " + opt['model_file'] + " ]") 
self.agent.save(opt['model_file']) print("[ saving best valid metric: " + opt['model_file'] + ".best_valid ]") save_best_valid(opt['model_file'], self.best_valid) self.saved = True if opt['validation_metric'] == 'accuracy' and self.best_valid >= opt[ 'validation_cutoff']: print('[ task solved! stopping. ]') return True else: self.impatience += 1 print('[ did not beat best {}: {} impatience: {} ]'.format( opt['validation_metric'], round(self.best_valid, 4), self.impatience)) self.validate_time.reset() # check if we are out of patience if opt['validation_patience'] > 0 and self.impatience >= opt[ 'validation_patience']: print('[ ran out of patience! stopping training. ]') return True return False def log(self): opt = self.opt if opt['display_examples']: print(self.world.display() + '\n~~') logs = [] # get report train_report = self.world.report(compute_time=True) self.world.reset_metrics() # time elapsed logs.append('time:{}s'.format(math.floor(self.train_time.time()))) total_exs = self.world.get_total_exs() logs.append('total_exs:{}'.format(total_exs)) exs_per_ep = self.world.num_examples() if exs_per_ep: logs.append('epochs:{}'.format(round(total_exs / exs_per_ep, 2))) if 'time_left' in train_report: logs.append('time_left:{}s'.format( math.floor(train_report.pop('time_left', "")))) log = '[ {} ] {}'.format(' '.join(logs), train_report) print(log) self.log_time.reset() if opt['tensorboard_log'] is True: self.writer.add_metrics('train', int(logs[1].split(":")[1]), train_report) def train(self): opt = self.opt world = self.world with world: while True: # do one example / batch of examples world.parley() self.parleys += 1 # check counters and timers if world.get_total_epochs() >= self.max_num_epochs: self.log() print( '[ num_epochs completed:{} time elapsed:{}s ]'.format( self.max_num_epochs, self.train_time.time())) break if self.train_time.time() > self.max_train_time: print('[ max_train_time elapsed:{}s ]'.format( self.train_time.time())) break if self.log_time.time() > self.log_every_n_secs: self.log() if self.validate_time.time() > self.val_every_n_secs: stop_training = self.validate() if stop_training: break if world.get_total_epochs( ) - self.last_valid_epoch >= self.val_every_n_epochs: stop_training = self.validate() self.last_valid_epoch = world.get_total_epochs() if stop_training: break if self.save_time.time() > self.save_every_n_secs and opt.get( 'model_file'): print("[ saving model checkpoint: " + opt['model_file'] + ".checkpoint ]") self.agent.save(opt['model_file'] + '.checkpoint') self.save_time.reset() if not self.saved: # save agent self.agent.save(opt['model_file']) elif opt.get('model_file'): # reload best validation model self.agent = create_agent(opt) v_report, v_world = run_eval(self.agent, opt, 'valid', write_log=True) t_report, t_world = run_eval(self.agent, opt, 'test', write_log=True) v_world.shutdown() t_world.shutdown() return v_report, t_report
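# Sketch of the '<model_file>.best_valid' bookkeeping used by validate() above.
# save_best_valid is assumed here to write the metric as a single line, which
# mirrors the read path in __init__ (a readline() followed by float()).
def save_best_valid(model_file, best_valid):
    with open(model_file + '.best_valid', 'w') as f:
        f.write(str(best_valid))

def load_best_valid(model_file):
    try:
        with open(model_file + '.best_valid', 'r') as f:
            return float(f.readline())
    except FileNotFoundError:
        return None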
class TrainLoop: """ TrainLoop contains the core training loop logic. """ def __init__(self, opt): # if python is called from a non-interactive shell, like a bash script, # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are # not produced. This line brings them back signal.signal(signal.SIGINT, signal.default_int_handler) if isinstance(opt, ParlaiParser): print( '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]' ) opt = opt.parse_args() # Possibly load from checkpoint trainstats_suffix = '.trainstats' # we might load training statistics from here if (opt['load_from_checkpoint'] and opt.get('model_file') and os.path.isfile(opt['model_file'] + '.checkpoint')): opt['init_model'] = opt['model_file'] + '.checkpoint' trainstats_suffix = '.checkpoint.trainstats' # Possibly build a dictionary (not all models do this). if not (opt.get('dict_file') or opt.get('model_file')): raise RuntimeError( 'WARNING: For train_model, please specify either a ' 'model_file or dict_file.') if 'dict_file' in opt: if opt['dict_file'] is None and opt.get('model_file'): opt['dict_file'] = opt['model_file'] + '.dict' print("[ building dictionary first... ]") build_dict(opt, skip_if_built=True) # Create model and assign it to the specified task self.agent = create_agent(opt) self.world = create_task(opt, self.agent) # set up timers self.train_time = Timer() self.validate_time = Timer() self.log_time = Timer() self.save_time = Timer() print('[ training... ]') self.parleys = 0 self.max_num_epochs = (opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf')) self.max_train_time = (opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf')) self.log_every_n_secs = (opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')) self.val_every_n_secs = (opt['validation_every_n_secs'] if opt['validation_every_n_secs'] > 0 else float('inf')) self.save_every_n_secs = (opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf')) self.val_every_n_epochs = (opt['validation_every_n_epochs'] if opt['validation_every_n_epochs'] > 0 else float('inf')) # smart defaults for --validation-metric-mode if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}: opt['validation_metric_mode'] = 'min' elif opt['validation_metric'] in { 'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu' }: opt['validation_metric_mode'] = 'max' if opt.get('validation_metric_mode') is None: opt['validation_metric_mode'] = 'max' self.last_valid_epoch = 0 self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1 self.valid_reports = [] self.best_valid = None self.impatience = 0 self.saved = False self.valid_worlds = None self.opt = opt # we may have been preempted, make sure we note that amount self._preempted_epochs = 0.0 if opt.get('model_file') and os.path.isfile(opt['model_file'] + trainstats_suffix): # looks like we were preempted. 
make sure we load up our total # training stats, etc with open(opt['model_file'] + trainstats_suffix) as ts: obj = json.load(ts) self.parleys = obj.get('parleys', 0) self._preempted_epochs = obj.get('total_epochs', 0) self.train_time.total = obj.get('train_time', 0) self.impatience = obj.get('impatience', 0) self.valid_reports = obj.get('valid_reports', []) if 'best_valid' in obj: self.best_valid = obj['best_valid'] else: # old method if opt.get('model_file') and os.path.isfile( opt['model_file'] + '.best_valid'): with open(opt['model_file'] + ".best_valid", 'r') as f: x = f.readline() self.best_valid = float(x) f.close() if opt['tensorboard_log'] and is_primary_worker(): self.tb_logger = TensorboardLogger(opt) def save_model(self, suffix=None): """ Save the model to disk, possibly with a suffix. """ if not is_primary_worker(): # never do IO as a non-primary worker return if not self.opt.get('model_file'): # nothing to save to, just exit return fn = self.opt['model_file'] if suffix: fn += suffix while True: # don't ever let a ctrl-c interrupt saving try: self.agent.save(fn) self._save_train_stats(suffix) break except KeyboardInterrupt: pass def _safe_report(self, report): return { k: v.value() if isinstance(v, Metric) else v for k, v in report.items() } def _save_train_stats(self, suffix=None): fn = self.opt['model_file'] if suffix: fn += suffix fn += '.trainstats' with open(fn, 'w') as f: json.dump( { 'parleys': self.parleys, 'train_time': self.train_time.time(), 'total_epochs': (self._preempted_epochs + num_workers() * self.world.get_total_epochs()), 'impatience': self.impatience, 'valid_reports': [self._safe_report(v) for v in self.valid_reports], 'best_valid': self.best_valid, }, f, ) def validate(self): """ Perform a validation run, checking whether we should stop training. :return: boolean indicating whether training should stop :rtype: bool """ opt = self.opt if self.valid_worlds is None: # we need to load the world now self.valid_worlds = load_eval_worlds(self.agent, opt, 'valid') # run evaluation on valid set # TODO(MW): replace sync_object with self._sync_metrics. 
You'll need some # logic to handle 'validation_max_exs' properly valid_report = run_eval(self.valid_worlds, opt, 'valid', opt['validation_max_exs']) v = valid_report.copy() v['train_time'] = self.train_time.time() self.valid_reports.append(v) # logging if opt['tensorboard_log'] and is_primary_worker(): self.tb_logger.log_metrics('valid', self.parleys, valid_report) # flush on a validation self.tb_logger.flush() # saving if (opt.get('model_file') and opt.get('save_after_valid') and is_primary_worker()): print("[ saving model checkpoint: " + opt['model_file'] + ".checkpoint ]") self.save_model('.checkpoint') # send valid metrics to agent if the agent wants them if hasattr(self.agent, 'receive_metrics'): self.agent.receive_metrics(valid_report) # check which metric to look at new_valid = valid_report[opt['validation_metric']] if isinstance(new_valid, Metric): new_valid = new_valid.value() # check if this is the best validation so far if (self.best_valid is None or self.valid_optim * new_valid > self.valid_optim * self.best_valid): print('[ new best {}: {}{} ]'.format( opt['validation_metric'], new_valid, ' (previous best was {})'.format(self.best_valid) if self.best_valid is not None else '', )) self.best_valid = new_valid self.impatience = 0 if opt.get('model_file') and is_primary_worker(): print("[ saving best valid model: " + opt['model_file'] + " ]") self.save_model() self.saved = True if (opt['validation_metric'] == 'accuracy' and self.best_valid >= opt['validation_cutoff']): print('[ task solved! stopping. ]') return True else: self.impatience += 1 print('[ did not beat best {}: {} impatience: {} ]'.format( opt['validation_metric'], round(self.best_valid, 4), self.impatience)) self.validate_time.reset() # check if we are out of patience if (opt['validation_patience'] > 0 and self.impatience >= opt['validation_patience']): print('[ ran out of patience! stopping training. ]') return True return False def _sync_metrics(self, metrics): """ Sync training metrics across workers. A handful of special cases are handled as exceptions, and the remaining metrics are simply averaged across workers. """ if not is_distributed(): # nothing special needed return metrics all_versions = all_gather_list(metrics) return aggregate_unnamed_reports(all_versions) def _compute_eta(self, epochs_completed, time_elapsed): """ Compute the estimated seconds remaining in training. :param float epochs_completed: number of epochs already completed. :param float time_elapsed: total time spent already, in seconds. :return: ETA in seconds, or None if not computable """ # start off with no estimate eta = None # Determine time_left and num_epochs max_epochs = self.opt.get('num_epochs', 0) if max_epochs > 0 and epochs_completed > 0: epoch_progress = epochs_completed / max_epochs eta = (1 - epoch_progress) * time_elapsed / epoch_progress max_training_time = self.opt.get('max_training_time', -1) if max_training_time > 0: time_left = max_training_time - time_elapsed if eta is None or time_left < eta: eta = time_left return eta def log(self): """ Output a training log entry. 
""" opt = self.opt if opt['display_examples']: print(self.world.display() + '\n~~') logs = [] # get report train_report = self.world.report() train_report = self._sync_metrics(train_report) self.world.reset_metrics() # time elapsed logs.append('time:{}s'.format(np.floor(self.train_time.time()))) logs.append('total_exs:{}'.format(self._total_exs)) if self._total_epochs >= 0: # only if it's unbounded logs.append('epochs:{}'.format(round(self._total_epochs, 2))) time_left = self._compute_eta(self._total_epochs, self.train_time.time()) if time_left is not None: logs.append('time_left:{}s'.format(max(0, np.ceil(time_left)))) log = '[ {} ] {}'.format(' '.join(logs), nice_report(train_report)) print(log) self.log_time.reset() if opt['tensorboard_log'] and is_primary_worker(): self.tb_logger.log_metrics('train', self.parleys, train_report) def train(self): """ Perform a training run. :return: tuple of reports (validation_report, test_report) """ opt = self.opt world = self.world count = 0 with world: while True: # do one example / batch of examples try: world.parley() except StopTrainException: if is_distributed(): raise RuntimeError( "StopTrainException not supported for " "distributed mode") break self.parleys += 1 # get the total training examples done, compute epochs self._total_epochs = ( self._preempted_epochs + num_workers() * self.world.get_total_epochs()) exs_per_epoch = self.world.num_examples() self._total_exs = int( np.round(self._total_epochs * exs_per_epoch)) # and use the primary worker's timings for everything train_time, log_time, validate_time = sync_object(( self.train_time.time(), self.log_time.time(), self.validate_time.time(), )) # check counters and timers if self._total_epochs >= self.max_num_epochs: self.log() print( '[ num_epochs completed:{} time elapsed:{}s ]'.format( self.max_num_epochs, train_time)) break if train_time > self.max_train_time: print('[ max_train_time elapsed:{}s ]'.format(train_time)) break if log_time > self.log_every_n_secs: self.log() if (validate_time > self.val_every_n_secs or self._total_epochs - self.last_valid_epoch >= self.val_every_n_epochs): try: stop_training = self.validate() except StopTrainException: if is_distributed(): raise RuntimeError( "StopTrainException not " "supported for distributed mode") break self.last_valid_epoch = self._total_epochs if stop_training: break if (self.save_time.time() > self.save_every_n_secs and opt.get('model_file') and is_primary_worker()): print("[ saving model checkpoint: {}.checkpoint".format( opt['model_file'])) self.save_model('.checkpoint') self.save_time.reset() if not self.saved and is_primary_worker(): # save agent self.save_model() elif opt.get('model_file'): # reload best validation model self.agent = create_agent(opt) valid_worlds = load_eval_worlds(self.agent, opt, 'valid') max_exs = opt['validation_max_exs'] if opt.get( 'short_final_eval') else -1 v_report = run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True) test_worlds = load_eval_worlds(self.agent, opt, 'test') t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True) if valid_worlds: for valid_world in valid_worlds: valid_world.shutdown() if test_worlds: for test_world in test_worlds: test_world.shutdown() print_announcements(opt) return v_report, t_report
def setup_args(parser=None): if parser is None: parser = ParlaiParser(True, True, 'Evaluate a model') # Get command line arguments parser.add_argument( '-rf', '--report-filename', type=str, default='', help='Saves a json file of the evaluation report either as an ' 'extension to the model-file (if begins with a ".") or a whole ' 'file path. Set to the empty string to not save at all.', ) parser.add_argument( '--world-logs', type=str, default='', help='Saves a jsonl file of the world logs.' 'Set to the empty string to not save at all.', ) parser.add_argument( '--save-format', type=str, default='conversations', choices=['conversations', 'parlai'], ) parser.add_argument( '--area-under-curve-digits', '-auc', type=int, default=-1, help= 'a positive number indicates to calculate the area under the roc curve and it also determines how many decimal digits of the predictions to keep (higher numbers->more precise); also used to determine whether or not to calculate the AUC metric', ) parser.add_argument( '--area-under-curve-class', '-auclass', type=str, default=None, nargs='*', help='the name(s) of the class to calculate the auc for', ) parser.add_argument('-ne', '--num-examples', type=int, default=-1) parser.add_argument('-d', '--display-examples', type='bool', default=False) parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=10) parser.add_argument( '-mcs', '--metrics', type=str, default='default', help='list of metrics to show/compute, e.g. all, default,' 'or give a list split by , like ' 'ppl,f1,accuracy,hits@1,rouge,bleu' 'the rouge metrics will be computed as rouge-1, rouge-2 and rouge-l', ) parser.add_argument( '-micro', '--aggregate-micro', type='bool', default=False, help='Report micro-averaged metrics instead of macro averaged metrics.', recommended=False, ) WorldLogger.add_cmdline_args(parser, partial_opt=None) TensorboardLogger.add_cmdline_args(parser, partial_opt=None) parser.set_params(datatype='valid') return parser
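# The '--report-filename' help above describes a convention: a leading '.'
# extends the model file, anything else is treated as a full path. The helper
# below is hypothetical and only illustrates that convention.
def resolve_report_path(model_file, report_filename):
    if report_filename.startswith('.'):
        return model_file + report_filename
    return report_filename

assert resolve_report_path('/tmp/model', '.report') == '/tmp/model.report'
assert resolve_report_path('/tmp/model', '/tmp/out.json') == '/tmp/out.json'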
def __init__(self, opt): # if python is called from a non-interactive shell, like a bash script, # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are # not produced. This line brings them back signal.signal(signal.SIGINT, signal.default_int_handler) # Possibly load from checkpoint trainstats_suffix = '.trainstats' # we might load training statistics from here if ( opt['load_from_checkpoint'] and opt.get('model_file') and PathManager.exists(opt['model_file'] + '.checkpoint') ): opt['init_model'] = opt['model_file'] + '.checkpoint' trainstats_suffix = '.checkpoint.trainstats' # Possibly build a dictionary (not all models do this). if not (opt.get('dict_file') or opt.get('model_file')): raise RuntimeError( 'WARNING: For train_model, please specify either a ' 'model_file or dict_file.' ) if 'dict_file' in opt: if opt['dict_file'] is None and opt.get('model_file'): opt['dict_file'] = opt['model_file'] + '.dict' logging.info("building dictionary first...") build_dict(opt, skip_if_built=True) # Create model and assign it to the specified task self.agent = create_agent(opt) self.agent.opt.log() self.world = create_task(opt, self.agent) # set up timers self.train_time = Timer() self.validate_time = Timer() self.log_time = Timer() self.save_time = Timer() self.parleys = 0 self._train_steps = 0 self._last_log_steps = 0 self.update_freq = opt.get('update_freq', 1) self.max_num_epochs = _num_else_inf(opt, 'num_epochs', distributed_warn=True) self.max_train_time = _num_else_inf( opt, 'max_train_time', distributed_warn=True ) self.max_train_steps = _num_else_inf(opt, 'max_train_steps') self.log_every_n_secs = _num_else_inf( opt, 'log_every_n_secs', distributed_warn=True ) self.log_every_n_steps = _num_else_inf(opt, 'log_every_n_steps') self.val_every_n_secs = _num_else_inf( opt, 'validation_every_n_secs', distributed_warn=True ) self.val_every_n_epochs = _num_else_inf( opt, 'validation_every_n_epochs', distributed_warn=True ) self.val_every_n_steps = _num_else_inf(opt, 'validation_every_n_steps') self.save_every_n_secs = _num_else_inf( opt, 'save_every_n_secs', distributed_warn=True ) # smart defaults for --validation-metric-mode if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}: opt['validation_metric_mode'] = 'min' elif opt['validation_metric'] in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}: opt['validation_metric_mode'] = 'max' if opt.get('validation_metric_mode') is None: opt['validation_metric_mode'] = 'max' self.last_valid_epoch = 0 self._last_valid_steps = 0 self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1 self.train_reports = [] self.valid_reports = [] self.final_valid_report = {} self.final_test_report = {} self.final_extra_valid_report = {} self.best_valid = None self.impatience = 0 self.saved = False self.valid_worlds = None self.opt = opt # we may have been preempted, make sure we note that amount self._preempted_epochs = 0.0 if opt.get('model_file') and PathManager.exists( opt['model_file'] + trainstats_suffix ): # looks like we were preempted. 
make sure we load up our total # training stats, etc with PathManager.open(opt['model_file'] + trainstats_suffix) as ts: obj = json.load(ts) self.parleys = obj.get('parleys', 0) self._preempted_epochs = obj.get('total_epochs', 0) self.train_time.total = obj.get('train_time', 0) self._train_steps = obj.get('train_steps', 0) self.impatience = obj.get('impatience', 0) self.valid_reports = obj.get('valid_reports', []) if self.valid_reports: self.last_valid_epoch = self.valid_reports[-1].get( 'total_epochs', 0.0 ) self.train_reports = obj.get('train_reports', []) if 'best_valid' in obj: self.best_valid = obj['best_valid'] else: # old method if opt.get('model_file') and PathManager.exists( opt['model_file'] + '.best_valid' ): with PathManager.open( opt['model_file'] + ".best_valid", 'r' ) as f: x = f.readline() self.best_valid = float(x) f.close() if opt['tensorboard_log'] and is_primary_worker(): self.tb_logger = TensorboardLogger(opt) if opt['wandb_log'] and is_primary_worker(): model = self.agent.model if hasattr(self.agent, 'model') else None self.wb_logger = WandbLogger(opt, model)
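# Sketch of the '.trainstats' JSON consumed by the resume path above; the keys
# mirror the obj.get(...) calls there (values are toy numbers).
import json

stats = {
    'parleys': 12000,
    'total_epochs': 3.2,
    'train_time': 5400.0,
    'train_steps': 3000,
    'impatience': 1,
    'valid_reports': [{'ppl': 21.3, 'total_epochs': 3.0}],
    'train_reports': [],
    'best_valid': 21.3,
}
with open('model.checkpoint.trainstats', 'w') as ts:
    json.dump(stats, ts)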
class TrainLoop:
    def __init__(self, opt):
        if isinstance(opt, ParlaiParser):
            opt = opt.parse_args()
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            if (
                opt['dict_file'] is None
                and opt.get('model_file_transmitter')
                and opt.get('model_file_receiver')
            ):
                opt['dict_file'] = (
                    opt['model_file_transmitter']
                    + '_'
                    + opt['model_file_receiver']
                    + '.dict'
                )
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=False)

        # Create the meta-agent, then share its parameters with the two
        # self-play agents so they stay in sync during training.
        print("[ create meta-agent ... ]")
        self.agent = create_agent(opt)
        print("[ create agent A ... ]")
        shared = self.agent.share()
        self.agent_a = create_agent_from_shared(shared)
        self.agent_a.set_id(suffix=' A')
        print("[ create agent B ... ]")
        self.agent_b = create_agent_from_shared(shared)
        self.agent_b.set_id(suffix=' B')

        # TODO: if batch, it is also not parallel
        self.world = create_selfplay_world(opt, [self.agent_a, self.agent_b])

        self.train_time = Timer()
        self.train_dis_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')

        self.parleys_episode = 0
        self.max_num_epochs = (
            opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf')
        )
        self.max_train_time = (
            opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf')
        )
        self.log_every_n_secs = (
            opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')
        )
        self.train_dis_every_n_secs = (
            opt['train_display_every_n_secs']
            if opt['train_display_every_n_secs'] > 0
            else float('inf')
        )
        self.val_every_n_secs = (
            opt['validation_every_n_secs']
            if opt['validation_every_n_secs'] > 0
            else float('inf')
        )
        self.save_every_n_secs = (
            opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf')
        )
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.best_valid = None
        if opt.get('model_file_transmitter') and os.path.isfile(
            opt['model_file_transmitter'] + '.best_valid'
        ):
            with open(opt['model_file_transmitter'] + '.best_valid', 'r') as f:
                x = f.readline()
                self.best_valid = float(x)
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt
        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)

    def validate(self):
        opt = self.opt
        valid_report, self.valid_world = run_eval(
            self.agent,
            opt,
            'valid',
            opt['validation_max_exs'],
            valid_world=self.valid_world,
        )
        if opt['tensorboard_log'] is True:
            self.writer.add_metrics('valid', self.parleys_episode, valid_report)
        if opt.get('model_file_transmitter') and opt.get('save_after_valid'):
            print(
                "[ saving transmitter checkpoint: "
                + opt['model_file_transmitter']
                + ".checkpoint ]"
            )
            self.agent.save(component='transmitter')
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)
        if '/' in opt['validation_metric']:
            # if you are multitasking and want your validation metric to be
            # a metric specific to a subtask, specify your validation metric
            # as -vmt subtask/metric
            subtask, validation_metric = opt['validation_metric'].split('/')
            new_valid = valid_report['tasks'][subtask][validation_metric]
        else:
            new_valid = valid_report[opt['validation_metric']]
        if (
            self.best_valid is None
            or self.valid_optim * new_valid > self.valid_optim * self.best_valid
        ):
            print(
                '[ new best {}: {}{} ]'.format(
                    opt['validation_metric'],
                    new_valid,
                    ' (previous best was {})'.format(self.best_valid)
                    if self.best_valid is not None
                    else '',
                )
            )
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file'):
                print("[ saving best valid model: " + opt['model_file'] + " ]")
                # the fine-tuned transmitter part is actually what we want
                # for the PSquare bot
                self.agent.save()
                print(
                    "[ saving best valid metric: "
                    + opt['model_file']
                    + ".best_valid ]"
                )
                save_best_valid(opt['model_file'], self.best_valid)
                self.saved = True
            if (
                opt['validation_metric'] == 'accuracy'
                and self.best_valid >= opt['validation_cutoff']
            ):
                print('[ task solved! stopping. ]')
                return True
        else:
            self.impatience += 1
            print(
                '[ did not beat best {}: {} impatience: {} ]'.format(
                    opt['validation_metric'],
                    round(self.best_valid, 4),
                    self.impatience,
                )
            )
        self.validate_time.reset()
        if 0 < opt['validation_patience'] <= self.impatience:
            print('[ ran out of patience! stopping training. ]')
            return True
        return False

    def log(self):
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self.world.report()
        self.world.reset_metrics()
        # time elapsed
        logs.append('time:{}s'.format(math.floor(self.train_time.time())))
        logs.append('parleys:{}'.format(self.parleys_episode))
        if 'time_left' in train_report:
            logs.append(
                'time_left:{}s'.format(math.floor(train_report.pop('time_left', '')))
            )
        if 'num_epochs' in train_report:
            logs.append('num_epochs:{}'.format(train_report.pop('num_epochs', '')))
        log = '[ {} ] {}'.format(' '.join(logs), train_report)
        print(log)
        self.log_time.reset()
        if opt['tensorboard_log'] is True:
            self.writer.add_metrics('train', self.parleys_episode, train_report)

    def train(self):
        opt = self.opt
        world = self.world
        with world:
            while True:
                self.parleys_episode += 1
                if self.parleys_episode % 100 == 0:
                    print('#### Training {} episode '.format(self.parleys_episode))

                if self.train_dis_time.time() > self.train_dis_every_n_secs:
                    is_display = True
                    # reset the display timer
                    self.train_dis_time.reset()
                else:
                    is_display = False

                world.parley_episode(is_training=True, is_display=is_display)

                if world.get_total_epochs() >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, self.train_time.time()
                        )
                    )
                    break
                if self.train_time.time() > self.max_train_time:
                    print(
                        '[ max_train_time elapsed:{}s ]'.format(
                            self.train_time.time()
                        )
                    )
                    break
                if self.log_time.time() > self.log_every_n_secs:
                    self.log()
                if self.validate_time.time() > self.val_every_n_secs:
                    print(
                        '#### Validating at {} training episode '.format(
                            self.parleys_episode
                        )
                    )
                    stop_training = self.validate()
                    if stop_training:
                        break
                if self.save_time.time() > self.save_every_n_secs:
                    if opt.get('model_file_transmitter'):
                        print(
                            "[ saving transmitter checkpoint: "
                            + opt['model_file_transmitter']
                            + ".checkpoint ]"
                        )
                        self.agent.save(
                            opt['model_file_transmitter'] + '.checkpoint',
                            component='transmitter',
                        )
                    if opt.get('model_file_receiver'):
                        print(
                            "[ saving receiver checkpoint: "
                            + opt['model_file_receiver']
                            + ".checkpoint ]"
                        )
                        self.agent.save(
                            opt['model_file_receiver'] + '.checkpoint',
                            component='receiver',
                        )
                    self.save_time.reset()

        if not self.saved:
            # save agent
            # TODO: API for saving all components
            self.agent.save()
        elif opt.get('model_file_transmitter') and opt.get('model_file_receiver'):
            # TODO: check if both components are necessary
            # reload best validation model
            self.agent = create_agent(opt)

        v_report, v_world = run_eval(self.agent, opt, 'valid', write_log=True)
        t_report, t_world = run_eval(self.agent, opt, 'test', write_log=True)
        v_world.shutdown()
        t_world.shutdown()

        return v_report, t_report
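# A hedged usage sketch (not from the source): how a driver script might run
# this self-play loop end to end. `setup_args` stands in for a parser-building
# function like the ones elsewhere in this file.
if __name__ == '__main__':
    parser = setup_args()
    opt = parser.parse_args()
    # TrainLoop accepts either a ParlaiParser or an already-parsed opt dict.
    valid_report, test_report = TrainLoop(opt).train()
    print(valid_report)
    print(test_report)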
def setup_args(parser=None) -> ParlaiParser:
    """
    Build the ParlAI parser, adding command line args if necessary.

    :param ParlaiParser parser:
        Preexisting parser to append options to. Will be created if needed.

    :returns:
        the ParlaiParser with CLI options added.
    """
    if parser is None:
        parser = ParlaiParser(True, True, 'Train a model')
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument(
        '-et',
        '--evaltask',
        help='task to use for valid/test (defaults to the one used for training)',
    )
    train.add_argument(
        '--final-extra-opt',
        type=str,
        default='',
        help="A '.opt' file that is used for final eval. Useful for setting "
        "skip-generation to false. 'datatype' must be included as part of "
        "the opt.",
    )
    train.add_argument(
        '--eval-batchsize',
        type=int,
        hidden=True,
        help='Eval time batch size (defaults to same as -bs)',
    )
    train.add_argument(
        '--eval-dynamic-batching',  # FIXME: see https://github.com/facebookresearch/ParlAI/issues/3367
        default=None,
        type='nonestr',
        choices={None, 'off', 'full', 'batchsort'},
        help=(
            'Set dynamic batching at evaluation time. Set to off for '
            'train-only dynamic batching. Set to none (default) to use same '
            'setting as --dynamic-batching.'
        ),
    )
    train.add_argument(
        '--num-workers',
        default=0,
        type=int,
        help='Number of background workers (training only)',
    )
    train.add_argument('--display-examples', type='bool', default=False, hidden=True)
    train.add_argument('-eps', '--num-epochs', type=float, default=-1)
    train.add_argument('-ttim', '--max-train-time', type=float, default=-1)
    train.add_argument(
        '-tstep',
        '--max-train-steps',
        '--max-lr-steps',
        type=int,
        default=-1,
        help='End training after n model updates',
    )
    train.add_argument('-ltim', '--log-every-n-secs', type=float, default=-1)
    train.add_argument(
        '-lstep',
        '--log-every-n-steps',
        type=int,
        default=50,
        help='Log every n training steps',
    )
    train.add_argument(
        '-vtim',
        '--validation-every-n-secs',
        type=float,
        default=-1,
        help='Validate every n seconds. Saves model to model_file '
        '(if set) whenever best val metric is found',
    )
    train.add_argument(
        '-vstep',
        '--validation-every-n-steps',
        type=int,
        default=-1,
        help='Validate every n training steps. Saves model to model_file '
        '(if set) whenever best val metric is found',
    )
    train.add_argument(
        '-stim',
        '--save-every-n-secs',
        type=float,
        default=-1,
        help='Saves the model to model_file.checkpoint after '
        'every n seconds (default -1, never).',
    )
    train.add_argument(
        '-sval',
        '--save-after-valid',
        type='bool',
        default=False,
        help='Saves the model to model_file.checkpoint after '
        'every validation (default %(default)s).',
    )
    train.add_argument(
        '-veps',
        '--validation-every-n-epochs',
        type=float,
        default=-1,
        help='Validate every n epochs. Saves model to model_file '
        '(if set) whenever best val metric is found',
    )
    train.add_argument(
        '-vme',
        '--validation-max-exs',
        type=int,
        default=-1,
        hidden=True,
        help='max examples to use during validation (default -1 uses all)',
    )
    train.add_argument(
        '--short-final-eval',
        default=False,
        hidden=True,
        type='bool',
        help='If true, obeys --validation-max-exs in the final '
        'validation and test evaluations.',
    )
    train.add_argument(
        '-vp',
        '--validation-patience',
        type=int,
        default=10,
        help=(
            'number of iterations of validation where result'
            ' does not improve before we stop training'
        ),
    )
    train.add_argument(
        '-vmt',
        '--validation-metric',
        default='accuracy',
        help='key into report table for selecting best validation',
    )
    train.add_argument(
        '-vmm',
        '--validation-metric-mode',
        type=str,
        choices=['max', 'min'],
        help='the direction in which to optimize the validation metric, '
        'i.e. maximize or minimize',
    )
    train.add_argument(
        '-vcut',
        '--validation-cutoff',
        type=float,
        default=1.0,
        hidden=True,
        help='value at which training will stop if exceeded by metric',
    )
    train.add_argument(
        '-lfc',
        '--load-from-checkpoint',
        type='bool',
        default=True,
        hidden=True,
        help='load model from checkpoint if available',
    )
    train.add_argument(
        '-vshare',
        '--validation-share-agent',
        default=False,
        hidden=True,
        help='use a shared copy of the agent for validation. '
        'this will eventually default to True, but '
        'currently defaults to False.',
    )
    train.add_argument(
        '-mcs',
        '--metrics',
        type=str,
        default='default',
        help='list of metrics to show/compute, e.g. all, default, '
        'or give a list split by , like '
        'ppl,f1,accuracy,hits@1,rouge,bleu. '
        'The rouge metrics will be computed as rouge-1, rouge-2 and rouge-l',
    )
    train.add_argument(
        '-micro',
        '--aggregate-micro',
        type='bool',
        default=False,
        help='Report micro-averaged metrics instead of macro averaged metrics.',
        recommended=False,
    )
    train.add_argument(
        '--world-logs',
        type=str,
        default='',
        help='Saves a jsonl file of the world logs. '
        'Set to the empty string to not save at all.',
    )
    train.add_argument(
        '--save-format',
        type=str,
        default='conversations',
        choices=['conversations', 'parlai'],
    )
    WorldLogger.add_cmdline_args(parser, partial_opt=None)
    TensorboardLogger.add_cmdline_args(parser, partial_opt=None)
    WandbLogger.add_cmdline_args(parser, partial_opt=None)
    parser = setup_dict_args(parser)
    return parser
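# A minimal sketch of exercising this parser; the task and flag values below
# are placeholders for illustration, not defaults from the source. From the
# command line the equivalent would be something like:
#
#   python train_model.py -t babi:task10k:1 -m seq2seq -mf /tmp/model \
#       -vtim 30 -vmt accuracy -vp 10
example_parser = setup_args()
example_opt = example_parser.parse_args(['-t', 'babi:task10k:1', '-vp', '5'])
assert example_opt['validation_patience'] == 5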
class EntNetAgent(TorchAgent):
    @staticmethod
    def add_cmdline_args(argparser):
        TorchAgent.add_cmdline_args(argparser)
        agent = argparser.add_argument_group("EntNet Arguments")
        agent.add_argument(
            "-wt",
            "--weight-tying",
            type=str,
            default="layer-wise",
            help="Type of weight tying",
        )
        agent.add_argument(
            "-nmh",
            "--num-memory-hops",
            type=int,
            default=3,
            help="Number of memory hops",
        )
        EntNetAgent.dictionary_class().add_cmdline_args(argparser)
        return agent

    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)
        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)
        self.dictionary_size = 177
        self.embedding_dim = 100
        self.batch_size = opt["batchsize"]
        self.criterion = nn.CrossEntropyLoss()

        def weight_init(m):
            # Xavier-initialize all linear layers.
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)

        self.recurrent_entity_network = RecurrentEntityNetwork(
            self.dictionary_size, self.embedding_dim, sequence_length=7
        )
        self.recurrent_entity_network.apply(weight_init)
        self.optimizer = optim.Adam(self.recurrent_entity_network.parameters())
        self.batch_iter = 0

    def vectorize(self, *args, **kwargs):
        """Override options in vectorize from parent."""
        kwargs['add_start'] = False
        kwargs['add_end'] = False
        kwargs['split_lines'] = True
        return super().vectorize(*args, **kwargs)

    def train_step(self, batch):
        self.recurrent_entity_network.train()
        questions, answers = batch.text_vec, batch.label_vec
        contexts = padded_3d(batch.memory_vecs)

        self.optimizer.zero_grad()
        output = self.recurrent_entity_network(questions, contexts)
        pred = output.argmax(dim=1)
        loss = self.criterion(output, answers.squeeze(1))

        # log the loss and parameter histograms (only if tensorboard is on;
        # self.writer does not exist otherwise)
        if hasattr(self, 'writer'):
            self.writer.add_scalar("data/loss", loss, self.batch_iter)
            for name, param in self.recurrent_entity_network.named_parameters():
                self.writer.add_histogram(
                    name, param.clone().cpu().data.numpy(), self.batch_iter
                )

        loss.backward(retain_graph=True)
        self.optimizer.step()
        self.batch_iter += 1
        return Output(self.dict.vec2txt(pred).split(" "))

    def eval_step(self, batch):
        questions = batch.text_vec
        contexts = padded_3d(batch.memory_vecs)
        if contexts.shape[0] != self.batch_size:
            # fall back to random predictions for ragged final batches
            return Output(
                self.dict.vec2txt(
                    np.random.choice(self.dictionary_size, size=contexts.shape[0])
                ).split(" ")
            )
        output = self.recurrent_entity_network(questions, contexts)
        pred = output.argmax(dim=1)
        return Output(self.dict.vec2txt(pred).split(" "))
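# A hedged usage sketch (not from the source): EntNetAgent is a TorchAgent, so
# it can be trained with ParlAI's generic training script once the agent module
# is importable. The model path ('entnet') and task below are illustrative
# assumptions, not values confirmed by the source:
#
#   python examples/train_model.py -m entnet -t babi:task10k:1 -bs 32 \
#       --tensorboard-log true
#
# Note that __init__ reads opt['tensorboard_log'] unconditionally, so that
# option must be present in opt (TensorboardLogger.add_cmdline_args supplies
# it); likewise batch.memory_vecs must be populated, which requires a task
# whose observations carry memories (e.g. the bAbI tasks).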