Example #1
 def __init__(self, optAgent, shared=None):
     init_model, is_finetune = self._get_init_model(optAgent, shared)
     super().__init__(optAgent, shared)
     if optAgent.get('numthreads', 1) > 1:
         torch.set_num_threads(1)
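     # NOTE (editorial assumption): 'opt' and 'args' below appear to be
     # module-level config namespaces shared across this file, distinct from
     # the per-agent 'optAgent' dict passed to __init__.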
     optAgent['gradient_clip'] = opt.maxgrad
     self.criterion = opt.criterion
     self.loss = opt.loss
     self.drawVars = opt.drawVars
     opt.edim = optAgent['embeddingsize']
     opt.vocabsize = len(self.dict)
     opt.__dict__.update(optAgent)
     opt.agent = self
     opt.fp16 = self.fp16
     torch.manual_seed(args.rank)
     np.random.seed(args.rank)
     self.writeVars = 0
     self.vars = {}
     if optAgent['tensorboard_log']:
         self.writeVars, *_ = getWriter(writer=TensorboardLogger(optAgent))
     if self.fp16:
         try:
             from apex import amp
         except ImportError:
             raise ImportError(
                 'No fp16 support without apex. Please install it from '
                 'https://github.com/NVIDIA/apex')
         self.getParameters = lambda: amp.master_params(self.optimizer)
         self.amp = amp
     else:
         self.getParameters = lambda: self.model.parameters()
     if not shared:
         model = Model(opt)
         self.model = model
         if init_model:
             print('Loading existing model parameters from ' + init_model)
             states = self.load(init_model)
         else:
             states = {}
             initParameters(opt, self.model)
         if self.use_cuda:
             self.model.cuda()
         self.model.train()
         if optAgent.get('numthreads', 1) > 1:
             self.model.share_memory()
         paramOptions = getParamOptions(opt, self.model)
         self.init_optim(paramOptions, states.get('optimizer'),
                         states.get('saved_optim_type', None))
         self.build_lr_scheduler(states, hard_reset=is_finetune)
         if is_distributed():
             self.model = nn.parallel.DistributedDataParallel(
                 self.model,
                 device_ids=[self.opt['gpu']],
                 broadcast_buffers=False)
         self.reset()
     else:
         self.model = shared['model']
         self.dict = shared['dict']
         if 'optimizer' in shared:
             self.optimizer = shared['optimizer']
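The fp16 branch in Example #1 only imports apex and keeps a handle to amp.master_params; the model and optimizer still need to be wrapped once they exist. A minimal sketch of that wiring, assuming apex is installed and model/optimizer are already built (the 'O1' opt level and the clip value are illustrative choices, not taken from the source):

    import torch
    from apex import amp

    # patch model and optimizer for mixed-precision training
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    # gradient clipping should then act on the fp32 master weights
    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)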
Example #2
 def __init__(self, opt):
     if isinstance(opt, ParlaiParser):
         print(
             '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
         )
         opt = opt.parse_args()
     # Possibly load from checkpoint
     if opt['load_from_checkpoint'] and opt.get(
             'model_file') and os.path.isfile(opt['model_file'] +
                                              '.checkpoint'):
         opt['init_model'] = opt['model_file'] + '.checkpoint'
     # Possibly build a dictionary (not all models do this).
     if opt['dict_build_first'] and 'dict_file' in opt:
         # If data built via pytorch data teacher, we need to load prebuilt dict
         if opt.get('pytorch_teacher_task'):
             opt['dict_file'] = get_pyt_dict_file(opt)
         elif opt['dict_file'] is None and opt.get('model_file'):
             opt['dict_file'] = opt['model_file'] + '.dict'
         print("[ building dictionary first... ]")
         build_dict(opt, skip_if_built=True)
     # Create model and assign it to the specified task
     self.agent = create_agent(opt)
     self.world = create_task(opt, self.agent)
     self.train_time = Timer()
     self.validate_time = Timer()
     self.log_time = Timer()
     self.save_time = Timer()
     print('[ training... ]')
     self.parleys = 0
     self.max_num_epochs = (
         opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf'))
     self.max_train_time = (
         opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf'))
     self.log_every_n_secs = (
         opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf'))
     self.val_every_n_secs = (
         opt['validation_every_n_secs']
         if opt['validation_every_n_secs'] > 0 else float('inf'))
     self.save_every_n_secs = (
         opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf'))
     self.val_every_n_epochs = (
         opt['validation_every_n_epochs']
         if opt['validation_every_n_epochs'] > 0 else float('inf'))
     self.last_valid_epoch = 0
     self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
     self.best_valid = None
     if opt.get('model_file') and os.path.isfile(opt['model_file'] +
                                                 '.best_valid'):
         with open(opt['model_file'] + ".best_valid", 'r') as f:
             x = f.readline()
             self.best_valid = float(x)
     self.impatience = 0
     self.saved = False
     self.valid_world = None
     self.opt = opt
     if opt['tensorboard_log'] is True:
         self.writer = TensorboardLogger(opt)
Example #3
    def __init__(self, opt):
        if isinstance(opt, ParlaiParser):
            opt = opt.parse_args()
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file_transmitter') and opt.get('model_file_receiver'):
                opt['dict_file'] = opt['model_file_transmitter'] + '_' + opt['model_file_receiver'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=False)

        # Create model and assign it to the specified task
        print("[ create meta-agent ... ]")
        self.agent = create_agent(opt)
        print("[ create agent A ... ]")
        shared = self.agent.share()
        self.agent_a = create_agent_from_shared(shared)
        self.agent_a.set_id(suffix=' A')
        print("[ create agent B ... ]")
        self.agent_b = create_agent_from_shared(shared)
        # self.agent_b = create_agent(opt)
        self.agent_b.set_id(suffix=' B')
        # self.agent_a.copy(self.agent, 'transmitter')
        # self.agent_b.copy(self.agent, 'transmitter')
        self.world = create_selfplay_world(opt, [self.agent_a, self.agent_b])

        # TODO: if batch, it is also not parallel
        # self.world = BatchSelfPlayWorld(opt, self_play_world)

        self.train_time = Timer()
        self.train_dis_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys_episode = 0
        self.max_num_epochs = opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf')
        self.max_train_time = opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf')
        self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf')
        self.train_dis_every_n_secs = opt['train_display_every_n_secs'] if opt['train_display_every_n_secs'] > 0 else float('inf')
        self.val_every_n_secs = opt['validation_every_n_secs'] if opt['validation_every_n_secs'] > 0 else float('inf')
        self.save_every_n_secs = opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf')
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.best_valid = None
        if opt.get('model_file_transmitter') and os.path.isfile(opt['model_file_transmitter'] + '.best_valid'):
            with open(opt['model_file_transmitter'] + ".best_valid", 'r') as f:
                x = f.readline()
                self.best_valid = float(x)
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt
        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)
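Several of these loops restore self.best_valid from a plain-text '<model_file>.best_valid' file but never show the writing side. A minimal sketch of the presumed counterpart (save_best_valid is a hypothetical name; the format, one float on one line, matches the f.readline()/float() read above):

    def save_best_valid(model_file, best_valid):
        # single float on a single line, read back with float(f.readline())
        with open(model_file + '.best_valid', 'w') as f:
            f.write(str(best_valid))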
Example #4
    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)

        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)

        # average the ranking loss over the batch (modern equivalent of the
        # deprecated reduce=True, size_average=True arguments)
        self.rank_loss = torch.nn.CrossEntropyLoss(reduction='mean')
        # anomaly detection helps locate NaNs but slows training noticeably
        torch.autograd.set_detect_anomaly(True)
        torch.manual_seed(123)
Example #5
    def __init__(self, opt, shared=None):

        super().__init__(opt, shared)

        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)

        self.dictionary_size = 177
        self.embedding_dim = 100
        self.batch_size = opt["batchsize"]

        self.criterion = nn.CrossEntropyLoss()

        def weight_init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)

        self.recurrent_entity_network = RecurrentEntityNetwork(
            self.dictionary_size, self.embedding_dim, sequence_length=7)
        self.recurrent_entity_network.apply(weight_init)
        self.optimizer = optim.Adam(self.recurrent_entity_network.parameters())
        #self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 25, 0.5)
        self.batch_iter = 0
Example #6
    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if (
            opt['load_from_checkpoint']
            and opt.get('model_file')
            and PathManager.exists(opt['model_file'] + '.checkpoint')
        ):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if not (opt.get('dict_file') or opt.get('model_file')):
            raise RuntimeError(
                'WARNING: For train_model, please specify either a '
                'model_file or dict_file.'
            )
        if 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            logging.info("building dictionary first...")
            build_dict(opt, skip_if_built=True)

        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.agent.opt.log()
        self.world = create_task(opt, self.agent)
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()

        self.parleys = 0
        self._train_steps = 0
        self._last_log_steps = 0
        self.update_freq = opt.get('update_freq', 1)

        self.max_num_epochs = _num_else_inf(opt, 'num_epochs', distributed_warn=True)
        self.max_train_time = _num_else_inf(
            opt, 'max_train_time', distributed_warn=True
        )
        self.max_train_steps = _num_else_inf(opt, 'max_train_steps')
        self.log_every_n_secs = _num_else_inf(
            opt, 'log_every_n_secs', distributed_warn=True
        )
        self.log_every_n_steps = _num_else_inf(opt, 'log_every_n_steps')
        self.val_every_n_secs = _num_else_inf(
            opt, 'validation_every_n_secs', distributed_warn=True
        )
        self.val_every_n_epochs = _num_else_inf(
            opt, 'validation_every_n_epochs', distributed_warn=True
        )
        self.val_every_n_steps = _num_else_inf(opt, 'validation_every_n_steps')
        self.save_every_n_secs = _num_else_inf(
            opt, 'save_every_n_secs', distributed_warn=True
        )

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self._last_valid_steps = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.train_reports = []
        self.valid_reports = []
        self.final_valid_report = {}
        self.final_test_report = {}
        self.final_extra_valid_report = {}
        self.best_valid = None

        self.impatience = 0
        self.saved = False
        self.valid_worlds = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file') and PathManager.exists(
            opt['model_file'] + trainstats_suffix
        ):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with PathManager.open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self.parleys = obj.get('parleys', 0)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self._train_steps = obj.get('train_steps', 0)
                self.impatience = obj.get('impatience', 0)
                self.valid_reports = obj.get('valid_reports', [])
                if self.valid_reports:
                    self.last_valid_epoch = self.valid_reports[-1].get(
                        'total_epochs', 0.0
                    )
                self.train_reports = obj.get('train_reports', [])
                if 'best_valid' in obj:
                    self.best_valid = obj['best_valid']
                else:
                    # old method
                    if opt.get('model_file') and PathManager.exists(
                        opt['model_file'] + '.best_valid'
                    ):
                        with PathManager.open(
                            opt['model_file'] + ".best_valid", 'r'
                        ) as f:
                            x = f.readline()
                            self.best_valid = float(x)

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger = TensorboardLogger(opt)
        if opt['wandb_log'] and is_primary_worker():
            model = self.agent.model if hasattr(self.agent, 'model') else None
            self.wb_logger = WandbLogger(opt, model)
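Example #6 leans on a _num_else_inf helper that is not part of the excerpt. Judging from how the earlier examples inline the same pattern (use the option when positive, otherwise never trigger), a plausible sketch:

    def _num_else_inf(opt, key, distributed_warn=False):
        # positive values are taken as-is; zero or negative means "disabled",
        # which the timer checks treat as an infinite interval
        if opt[key] > 0:
            return opt[key]
        # distributed_warn presumably flags wall-clock options that behave
        # unreliably under distributed training; the warning itself is omitted
        return float('inf')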
Example #7
    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)

        if isinstance(opt, ParlaiParser):
            print(
                '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
            )
            opt = opt.parse_args()
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if opt['load_from_checkpoint'] and opt.get(
                'model_file') and os.path.isfile(opt['model_file'] +
                                                 '.checkpoint'):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            # If data built via pytorch data teacher, we need to load prebuilt dict
            if opt.get('pytorch_teacher_task'):
                opt['dict_file'] = get_pyt_dict_file(opt)
            elif opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)  # specify the model, e.g. seq2seq
        self.world = create_task(opt, self.agent)  # batch world or other world
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys = 0
        self.max_num_epochs = (
            opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf'))
        self.max_train_time = (
            opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf'))
        self.log_every_n_secs = (
            opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf'))
        self.val_every_n_secs = (
            opt['validation_every_n_secs']
            if opt['validation_every_n_secs'] > 0 else float('inf'))
        self.save_every_n_secs = (
            opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf'))
        self.val_every_n_epochs = (
            opt['validation_every_n_epochs']
            if opt['validation_every_n_epochs'] > 0 else float('inf'))

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {
                'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'
        }:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.best_valid = None
        if opt.get('model_file') and os.path.isfile(opt['model_file'] +
                                                    '.best_valid'):
            with open(opt['model_file'] + ".best_valid", 'r') as f:
                x = f.readline()
                self.best_valid = float(x)
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if (opt.get('model_file')
                and os.path.isfile(opt['model_file'] + trainstats_suffix)):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)

        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)
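The preemption-recovery block above reads '<model_file>.trainstats' as JSON. A sketch of the corresponding save side, assuming it runs at checkpoint time and writes the keys this reader expects (save_trainstats is a hypothetical name):

    import json

    def save_trainstats(model_file, train_time, total_epochs, impatience):
        with open(model_file + '.trainstats', 'w') as ts:
            json.dump({
                'train_time': train_time,      # seconds, restored into Timer.total
                'total_epochs': total_epochs,  # restored as _preempted_epochs
                'impatience': impatience,      # validation patience counter
            }, ts)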
Example #8
    def __init__(self, opt, shared=None):
        opt = copy.deepcopy(opt)
        super().__init__(opt, shared)
        self.use_cuda = not opt['no_cuda'] and torch.cuda.is_available()
        self.is_combine_attr = (hasattr(self, 'other_task_datafiles')
                                and self.other_task_datafiles)
        self.random_policy = opt.get('random_policy', False)
        self.count_sample = opt.get('count_sample', False)
        self.anti = opt.get('anti', False)

        if self.random_policy:
            random.seed(17)

        if not shared:
            if not self.stream and opt.get('pace_by', 'sample') == 'bucket':
                score_list = [episode[0][2] for episode in self.data.data]
                assert score_list == sorted(score_list)
                num_buckets = opt.get('num_buckets',
                                      int(self.num_episodes() / 10))
                lb_indices = [
                    int(len(score_list) * i / num_buckets)
                    for i in range(num_buckets)
                ]
                lbs = [score_list[idx] for idx in lb_indices]
                bucket_ids = [
                    self.sort_into_bucket(ctrl_val, lbs)
                    for ctrl_val in score_list
                ]
                bucket_cnt = [0 for _ in range(num_buckets)]
                for i in range(num_buckets):
                    bucket_cnt[i] = bucket_ids.count(i)
                self.bucket_cnt = bucket_cnt
            self.lastYs = [None] * self.bsz
            # build multiple task data
            self.tasks = [self.data]

            if self.is_combine_attr:
                print('[ build multiple task data ... ]')
                for datafile in self.other_task_datafiles:
                    task_opt = copy.deepcopy(opt)
                    task_opt['datafile'] = datafile
                    self.tasks.append(
                        DialogData(task_opt,
                                   data_loader=self.setup_data,
                                   cands=self.label_candidates()))
                print('[ build multiple task data done! ]')

                # record the selections of each subtasks
                self.subtasks = opt['subtasks'].split(':')
                self.subtask_counter = OrderedDict()
                self.p_selections = OrderedDict()
                self.c_selections = OrderedDict()
                for t in self.subtasks:
                    self.subtask_counter[t] = 0
                    self.p_selections[t] = []
                    self.c_selections[t] = []

                if self.count_sample and not self.stream:
                    self.sample_counter = OrderedDict()
                    for idx, t in enumerate(self.subtasks):
                        self.sample_counter[t] = [
                            0 for _ in self.tasks[idx].data
                        ]

            # setup the tensorboard log
            if opt['tensorboard_log_teacher'] is True:
                opt['tensorboard_tag'] = 'task'
                teacher_metrics = 'reward,policy_loss,critic_loss,mean_advantage_reward,action_ent'.split(
                    ',')
                opt['tensorboard_metrics'] = ','.join(
                    opt['tensorboard_metrics'].split(',') + teacher_metrics)
                self.writer = TensorboardLogger(opt)

        else:
            self.lastYs = shared['lastYs']
            self.tasks = shared['tasks']
            if not self.stream and opt.get('pace_by', 'sample') == 'bucket':
                self.bucket_cnt = shared['bucket_cnt']
            if 'writer' in shared:
                self.writer = shared['writer']
            if 'subtask_counter' in shared:
                self.subtask_counter = shared['subtask_counter']
            if 'p_selections' in shared:
                self.p_selections = shared['p_selections']
            if 'c_selections' in shared:
                self.c_selections = shared['c_selections']

        # build the policy net, criterion and optimizer here
        self.state_dim = 32 + len(self.tasks)  # hand-craft features
        self.action_dim = len(self.tasks)

        if not shared:
            self.policy = PolicyNet(self.state_dim, self.action_dim)
            self.critic = CriticNet(self.state_dim, self.action_dim)

            init_teacher = get_init_teacher(opt, shared)
            if init_teacher is not None:
                # load teacher parameters if available
                print('[ Loading existing teacher params from {} ]'
                      ''.format(init_teacher))
                states = self.load(init_teacher)
            else:
                states = {}
        else:
            self.policy = shared['policy']
            self.critic = shared['critic']
            states = shared['states']

        if (
                # only build an optimizer if we're training
                'train' in opt.get('datatype', '') and
                # and this is the main model
                shared is None):
            # for policy net
            self.optimizer = self.init_optim(
                [p for p in self.policy.parameters() if p.requires_grad],
                lr=opt['learningrate_teacher'],
                optim_states=states.get('optimizer'),
                saved_optim_type=states.get('optimizer_type'))
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                'min',
                factor=0.8,  # 0.5 --> 0.8
                patience=5,  # 3 --> 5
                verbose=True)
            if 'lr_scheduler' in states:
                self.scheduler.load_state_dict(states['lr_scheduler'])

            # for critic net
            self.optimizer_critic = self.init_optim(
                [p for p in self.critic.parameters() if p.requires_grad],
                lr=opt['learningrate_teacher_critic'],
                optim_states=states.get('optimizer_critic'),
                saved_optim_type=states.get('optimizer_type'))
            self.scheduler_critic = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer_critic,
                'min',
                factor=0.8,  # 0.5 --> 0.8
                patience=5,  # 3 --> 5
                verbose=True)
            if 'lr_scheduler_critic' in states:
                self.scheduler_critic.load_state_dict(
                    states['lr_scheduler_critic'])

            self.critic_criterion = torch.nn.SmoothL1Loss()
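            # (editorial note) plateau schedulers like the two above are
            # typically advanced once per validation with the monitored value,
            # e.g. self.scheduler.step(valid_loss); that call sits outside
            # this excerpt.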

        self.reward_metric = opt.get('reward_metric', 'total_metric')
        self.reward_metric_mode = opt.get('reward_metric_mode', 'max')

        self.prev_prev_valid_report = states.get('prev_prev_valid_report')
        self.prev_valid_report = states.get('prev_valid_report')
        self.current_valid_report = states.get('current_valid_report')
        self.saved_actions = states.get('saved_actions', OrderedDict())
        self.saved_state_actions = states.get('saved_state_actions', OrderedDict())
        if self.use_cuda:
            for k, v in self.saved_actions.items():
                self.saved_actions[k] = v.cuda()
            for k, v in self.saved_state_actions.items():
                self.saved_state_actions[k] = v.cuda()
        self._number_teacher_updates = states.get('_number_teacher_updates', 0)

        # enable the batch_act
        self.use_batch_act = self.bsz > 1

        self.T = self.opt.get('T', 1000)
        self.c0 = self.opt.get('c0', 0.01)
        self.p = self.opt.get('p', 2)

        # setup the timer
        self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 \
            else float('inf')
        self.action_log_time = Timer()

        self.move_to_cuda()
Example #9
    def __init__(self, opt):
        signal.signal(signal.SIGINT, signal.default_int_handler)

        if isinstance(opt, ParlaiParser):
            opt = opt.parse_args()
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'
        if (opt.get('model_file')
                and isfile(opt['model_file'] + '.checkpoint')):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        else:
            pass
            # TODO for testing only
            # raise RuntimeError('WARNING: Reinforcement learning'
            #                    ' must be initialized by a model.checkpoint '
            #                    'file and {} does not exist.'.format(
            #                        opt['model_file'] + '.checkpoint'))
        # Possibly build a dictionary (not all models do this).
        if (opt['dict_build_first']
                and not (opt.get('dict_file') or opt.get('model_file'))):
            raise RuntimeError('WARNING: For train_model, '
                               'please specify either a '
                               'model_file or dict_file.')

        if opt['dict_build_first'] and 'dict_file' in opt:
            if opt.get('pytorch_teacher_task'):
                opt['dict_file'] = get_pyt_dict_file(opt)
            elif opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)

        # Create model and assign it to the specified task
        self.agent = create_agent(opt)

        # Freeze the model for the static dialogue partner
        static_agent = copy.deepcopy(self.agent)
        self.agent.id = ACTIVE

        static_agent.id = STATIC
        freeze_agent(static_agent)

        self.world = create_task(opt, self.agent, static_agent)

        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')

        self.parleys = 0
        self.max_num_epochs = (opt['num_epochs']
                               if opt['num_epochs'] > 0 else float('inf'))

        self.max_train_time = (opt['max_train_time']
                               if opt['max_train_time'] > 0 else float('inf'))

        self.log_every_n_secs = (opt['log_every_n_secs'] if
                                 opt['log_every_n_secs'] > 0 else float('inf'))

        self.val_every_n_secs = (opt['validation_every_n_secs']
                                 if opt['validation_every_n_secs'] > 0 else
                                 float('inf'))

        self.save_every_n_secs = (opt['save_every_n_secs']
                                  if opt['save_every_n_secs'] > 0 else
                                  float('inf'))

        self.val_every_n_epochs = (opt['validation_every_n_epochs']
                                   if opt['validation_every_n_epochs'] > 0 else
                                   float('inf'))

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {
                'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'
        }:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self.valid_optim = (1
                            if opt['validation_metric_mode'] == 'max' else -1)
        self.valid_reports = []
        self.best_valid = None
        if (opt.get('model_file')
                and isfile(opt['model_file'] + '.best_valid')):
            with open(opt['model_file'] + ".best_valid", 'r') as f:
                x = f.readline()
                self.best_valid = float(x)
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if (opt.get('model_file')
                and isfile(opt['model_file'] + trainstats_suffix)):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)
                self.valid_reports = obj.get('valid_reports', [])

        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)
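All of these loops collapse validation_metric_mode into valid_optim (+1 for 'max', -1 for 'min') so that an improvement check needs only one comparison. A sketch of how such a flag is typically consumed (_is_improvement is a hypothetical name):

    def _is_improvement(self, new_valid):
        # multiplying both sides by valid_optim flips the inequality for
        # 'min' metrics, so "greater is better" holds in either mode
        if self.best_valid is None:
            return True
        return self.valid_optim * new_valid > self.valid_optim * self.best_valid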