Example #1
    def __init__(self, opt, agents=None, shared=None):
        self.id = opt['task']
        self.opt = copy.deepcopy(opt)
        if shared:
            # Create agents based on shared data.
            self.agents = create_agents_from_shared(shared['agents'])
        else:
            # Add passed in agents to world directly.
            self.agents = agents
        self.max_exs = None
        self.total_exs = 0
        self.total_epochs = 0
        self.total_parleys = 0
        self.time = Timer()
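
A quick usage sketch of the two construction paths above (hypothetical names; `World`, `teacher`, and `agent` stand in for real ParlAI classes, and the agents are assumed to follow ParlAI's share() convention):

# direct construction: pass live agents into the world
world = World(opt, agents=[teacher, agent])

# shared construction: each agent's share() dict lets a clone rebuild
# lightweight copies via create_agents_from_shared()
shared = {'agents': [a.share() for a in world.agents]}
clone = World(opt, shared=shared)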
Example #2
def eval_model(parser, printargs=True):
    random.seed(42)
    opt = parser.parse_args(print_args=False)

    nomodel = False
    # check to make sure the model file exists
    if opt.get('model_file') is None:
        nomodel = True
    elif not os.path.isfile(opt['model_file']):
        raise RuntimeError('WARNING: Model file does not exist, check to make '
                           'sure it is correct: {}'.format(opt['model_file']))

    # Create model and assign it to the specified task
    agent = create_agent(opt)
    if nomodel and hasattr(agent, 'load'):
        # double check that we didn't forget to set model_file on loadable model
        raise RuntimeError('Stopping evaluation because model_file unset but '
                           'model has a `load` function.')
    world = create_task(opt, agent)
    # Show arguments after loading model
    parser.opt = agent.opt
    if printargs:
        parser.print_args()
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = Timer()
    tot_time = 0

    # Show some example dialogs:
    cnt = 0
    while not world.epoch_done():
        cnt += 1
        world.parley()
        if opt['display_examples']:
            print("---")
            print(world.display() + "\n~~")
        if log_time.time() > log_every_n_secs:
            tot_time += log_time.time()
            print(str(int(tot_time)) + "s elapsed: " + str(world.report()))
            log_time.reset()
        if opt['num_examples'] > 0 and cnt >= opt['num_examples']:
            break
    if world.epoch_done():
        print("EPOCH DONE")
    report = world.report()
    print(report)
    return report
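
A minimal sketch of invoking this eval_model, assuming the ParlaiParser setup seen in the training examples below (the exact flag set is illustrative):

if __name__ == '__main__':
    # build a parser with ParlAI and model arguments
    parser = ParlaiParser(True, True)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument('-n', '--num-examples', type=int, default=-1)
    eval_model(parser, printargs=True)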
Example #3
def run_eval(valid_worlds, opt, datatype, max_exs=-1, write_log=False):
    """
    Eval on validation/test data.

    :param valid_worlds:
        list of the pre-created validation worlds.
    :param opt:
        the options that specify the task, eval_task, etc.
    :param datatype:
        the datatype to use, such as "valid" or "test"
    :param int max_exs:
        limits the number of examples if max_exs > 0
    :param bool write_log:
        specifies to write metrics to file if the model_file is set
    """
    if valid_worlds is None:
        # This isn't the primary worker, so we can just skip evaluation
        return None

    print('[ running eval: ' + datatype + ' ]')
    timer = Timer()
    reports = []
    for v_world in valid_worlds:
        task_report = _run_single_eval(opt, v_world,
                                       max_exs / len(valid_worlds))
        reports.append(task_report)

    tasks = [world.opt['task'] for world in valid_worlds]
    report = aggregate_task_reports(reports,
                                    tasks,
                                    micro=opt.get('aggregate_micro', True))

    metrics = f'{datatype}:{report}'
    print(f'[ eval completed in {timer.time():.2f}s ]')
    print(metrics)

    # write out metrics next to the model file
    if write_log and opt.get('model_file'):
        with open(opt['model_file'] + '.' + datatype, 'a+') as f:
            f.write(metrics + '\n')

    return report
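
The micro flag passed to aggregate_task_reports matters when tasks differ in size. A toy illustration of the distinction (not ParlAI's implementation; assumes each report carries an example count 'exs' and an 'accuracy' metric):

def aggregate_sketch(reports, micro=True):
    if micro:
        # micro: weight each task's metric by its number of examples
        total = sum(r['exs'] for r in reports)
        return sum(r['accuracy'] * r['exs'] for r in reports) / total
    # macro: average the per-task metrics equally
    return sum(r['accuracy'] for r in reports) / len(reports)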
Example #4
def eval_ppl(opt, build_dict):
    """Evaluates the the perplexity of a model.

    See the documentation for this file for more info.

    :param opt: option dict
    :param build_dict: function for building official dictionary.
        note that this function does not use the opt passed into eval_ppl,
        but rather should have hardcoded settings for its dictionary.

    """
    dict_agent = build_dict()

    # create agents
    agent = create_agent(opt)
    world = create_task(opt, [agent, dict_agent], default_world=PerplexityWorld)

    # set up logging
    log_time = Timer()
    tot_time = 0

    while not world.epoch_done():
        world.parley()  # process an example

        if log_time.time() > 1:  # log every 1 sec
            tot_time += log_time.time()
            report = world.report()
            print('{}s elapsed, {}% complete, {}'.format(
                int(tot_time),
                round_sigfigs(report['total'] / world.num_examples() * 100, 3),
                report))
            log_time.reset()
    print('EPOCH DONE')
    tot_time += log_time.time()
    final_report = world.report()
    print('{}s elapsed: {}'.format(int(tot_time), final_report))
    print("============================")
    print("FINAL PPL: " +str(final_report['ppl']))
    if final_report.get('ppl', 0) == float('inf'):
        print('Note: you got inf perplexity. Consider adding (or raising) the '
              'minimum probability you assign to each possible word. If you '
              'assign zero probability to the correct token in the evaluation '
              'vocabulary, you get inf probability immediately.')
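
For context on the inf warning above: perplexity is the exponential of the average negative log-likelihood per token, so a single zero-probability gold token drives the whole score to infinity. A minimal sketch:

import math

def perplexity(token_logprobs):
    # token_logprobs: natural-log probabilities the model assigned to the
    # gold tokens; one log-prob of -inf (zero probability) makes ppl inf
    return math.exp(-sum(token_logprobs) / len(token_logprobs))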
Example #5
def eval_ppl(opt):
    """Evaluates the the perplexity and f1 of a model (and hits@1 if model has
    ranking enabled.
    """
    dict_agent = build_dict()

    # create agents
    agent = create_agent(opt)
    world = create_task(opt, [agent, dict_agent],
                        default_world=PerplexityWorld)
    world.dict = dict_agent

    # set up logging
    log_time = Timer()
    tot_time = 0

    while not world.epoch_done():
        world.parley()  # process an example

        if log_time.time() > 1:  # log every 1 sec
            tot_time += log_time.time()
            report = world.report()
            print('{}s elapsed, {}% complete, {}'.format(
                int(tot_time),
                round_sigfigs(report['total'] / world.num_examples() * 100, 3),
                report))
            log_time.reset()
    if world.epoch_done():
        print('EPOCH DONE')
    tot_time += log_time.time()
    final_report = world.report()
    print('{}s elapsed: {}'.format(int(tot_time), final_report))
    print("============================")
    print("FINAL PPL: " + str(final_report['ppl']))
    if final_report.get('ppl', 0) == float('inf'):
        print('Note: you got inf perplexity. Consider adding (or raising) the '
              'minimum probability you assign to each possible word. If you '
              'assign zero probability to the correct token in the evaluation '
              'vocabulary, you get inf probability immediately.')
Example #6
    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)
        if not shared:
            # don't enter this loop for shared instantiations
            self.path = opt['model_file']

            # loss
            self.loss_time = Timer()
            self.losses = []
            self.save_loss_every_n_secs = opt['save_loss_every_n_secs']

            # attention
            if opt['task'] == "babi:All1k":
                self.tasks = {'babi:Task1k:' + str(i): 0 for i in range(1, 21)}
            elif opt['task'] == "babi:All10k":
                self.tasks = {
                    'babi:Task10k:' + str(i): 0
                    for i in range(1, 21)
                }
            else:
                self.tasks = {task: 0 for task in opt['task'].split(',')}
            self.attention_weights = {task: [] for task in self.tasks.keys()}
            self.save_attention_exs = opt['save_attention_exs']
Example #7
def main():
    # Get command line arguments
    parser = ParlaiParser(True, True)
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument('-et',
                       '--evaltask',
                       help=('task to use for valid/test (defaults to the ' +
                             'one used for training if not set)'))
    train.add_argument('-d', '--display-examples', type='bool', default=False)
    train.add_argument('-e', '--num-epochs', type=float, default=-1)
    train.add_argument('-ttim', '--max-train-time', type=float, default=-1)
    train.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    train.add_argument('-lparl',
                       '--log-every-n-parleys',
                       type=int,
                       default=100)
    train.add_argument('-vtim',
                       '--validation-every-n-secs',
                       type=float,
                       default=-1)
    train.add_argument('-vparl',
                       '--validation-every-n-parleys',
                       type=int,
                       default=-1)
    train.add_argument('-vme',
                       '--validation-max-exs',
                       type=int,
                       default=-1,
                       help='max examples to use during validation (default ' +
                       '-1 uses all)')
    train.add_argument(
        '-vp',
        '--validation-patience',
        type=int,
        default=5,
        help=('number of iterations of validation where result ' +
              'does not improve before we stop training'))
    train.add_argument(
        '-vmt',
        '--validation-metric',
        default='accuracy',
        help='key into report table for selecting best validation')
    train.add_argument('-dbf',
                       '--dict-build-first',
                       type='bool',
                       default=True,
                       help='build dictionary first before training agent')
    opt = parser.parse_args()

    # Set logging
    logger = logging.getLogger('DrQA')
    logger.setLevel(logging.INFO)
    fmt = logging.Formatter('%(asctime)s: %(message)s', '%m/%d/%Y %I:%M:%S %p')
    console = logging.StreamHandler()
    console.setFormatter(fmt)
    logger.addHandler(console)
    if 'log_file' in opt:
        logfile = logging.FileHandler(opt['log_file'], 'w')
        logfile.setFormatter(fmt)
        logger.addHandler(logfile)
    logger.info('[ COMMAND: %s ]' % ' '.join(sys.argv))

    # Possibly build a dictionary (not all models do this).
    if opt['dict_build_first'] and 'dict_file' in opt:
        if opt['dict_file'] is None and opt.get('model_file'):
            opt['dict_file'] = opt['model_file'] + '.dict'
        logger.info("[ building dictionary first... ]")
        build_dict.build_dict(opt)

    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    train_time = Timer()
    validate_time = Timer()
    log_time = Timer()
    logger.info('[ training... ]')
    parleys = 0
    total_exs = 0
    max_exs = opt['num_epochs'] * len(world)
    max_parleys = math.ceil(max_exs / opt['batchsize'])
    best_valid = 0
    impatience = 0
    saved = False
    valid_world = None
    best_accuracy = 0

    while True:
        world.parley()
        parleys += 1

        if opt['num_epochs'] > 0 and parleys >= max_parleys:
            logger.info('[ num_epochs completed: {} ]'.format(
                opt['num_epochs']))
            break
        if (opt['max_train_time'] > 0
                and train_time.time() > opt['max_train_time']):
            logger.info('[ max_train_time elapsed: {} ]'.format(
                train_time.time()))
            break

        # log every n parleys instead of every n secs
        if (opt['log_every_n_parleys'] > 0
                and parleys % opt['log_every_n_parleys'] == 0):
            if opt['display_examples']:
                logger.info(world.display() + '\n~~')

            logs = []
            # time elapsed
            logs.append('time:{}s'.format(math.floor(train_time.time())))
            logs.append('parleys:{}'.format(parleys))

            # get report and update total examples seen so far
            if hasattr(agent, 'report'):
                train_report = agent.report()
                agent.reset_metrics()
            else:
                train_report = world.report()
                world.reset_metrics()

            if hasattr(train_report, 'get') and train_report.get('total'):
                total_exs += train_report['total']
                logs.append('total_exs:{}'.format(total_exs))

            # check if we should log amount of time remaining
            time_left = None
            if opt['num_epochs'] > 0:
                exs_per_sec = train_time.time() / total_exs  # secs per example
                time_left = (max_exs - total_exs) * exs_per_sec
            if opt['max_train_time'] > 0:
                other_time_left = opt['max_train_time'] - train_time.time()
                if time_left is not None:
                    time_left = min(time_left, other_time_left)
                else:
                    time_left = other_time_left
            if time_left is not None:
                logs.append('time_left:{}s'.format(math.floor(time_left)))

            # join log string and add full metrics report to end of log
            log = '[ {} ] {}'.format(' '.join(logs), train_report)

            logger.info(log)
            log_time.reset()


        # validate every n parleys instead of every n secs
        if (opt['validation_every_n_parleys'] > 0
                and parleys % opt['validation_every_n_parleys'] == 0):
            valid_report, valid_world = run_eval(agent,
                                                 opt,
                                                 'valid',
                                                 opt['validation_max_exs'],
                                                 logger=logger)
            if valid_report[opt['validation_metric']] > best_accuracy:
                best_accuracy = valid_report[opt['validation_metric']]
                impatience = 0
                logger.info('[ new best accuracy: ' + str(best_accuracy) +
                            ' ]')
                world.save_agents()
                saved = True
                if best_accuracy == 1:
                    logger.info('[ task solved! stopping. ]')
                    break
            else:
                opt['learning_rate'] *= 0.5
                agent.model.set_lrate(opt['learning_rate'])
                logger.info('[ Decrease learning_rate %.2e]' %
                            opt['learning_rate'])
                impatience += 1
                logger.info(
                    '[ did not beat best accuracy: {} impatience: {} ]'.format(
                        round(best_accuracy, 4), impatience))

            validate_time.reset()
            if opt['validation_patience'] > 0 and impatience >= opt[
                    'validation_patience']:
                logger.info('[ ran out of patience! stopping training. ]')
                break
            if opt['learning_rate'] < 1e-6:
                logger.info('[ learning_rate < 1e-6, stopping training. ]')
                break

    world.shutdown()
    if not saved:
        world.save_agents()
    else:
        # reload best validation model
        agent = create_agent(opt)

    run_eval(agent, opt, 'valid', write_log=True, logger=logger)
    run_eval(agent, opt, 'test', write_log=True, logger=logger)
Example #8
    def __init__(self, opt, shared=None):
        opt = copy.deepcopy(opt)
        super().__init__(opt, shared)
        self.use_cuda = not opt['no_cuda'] and torch.cuda.is_available()
        self.is_combine_attr = (hasattr(self, 'other_task_datafiles')
                                and self.other_task_datafiles)
        self.random_policy = opt.get('random_policy', False)
        self.count_sample = opt.get('count_sample', False)
        self.anti = opt.get('anti', False)

        if self.random_policy:
            random.seed(17)

        if not shared:
            if not self.stream and opt.get('pace_by', 'sample') == 'bucket':
                score_list = [episode[0][2] for episode in self.data.data]
                assert score_list == sorted(score_list)
                num_buckets = opt.get('num_buckets',
                                      int(self.num_episodes() / 10))
                lb_indices = [
                    int(len(score_list) * i / num_buckets)
                    for i in range(num_buckets)
                ]
                lbs = [score_list[idx] for idx in lb_indices]
                bucket_ids = [
                    self.sort_into_bucket(ctrl_val, lbs)
                    for ctrl_val in score_list
                ]
                bucket_cnt = [0 for _ in range(num_buckets)]
                for i in range(num_buckets):
                    bucket_cnt[i] = bucket_ids.count(i)
                self.bucket_cnt = bucket_cnt
            self.lastYs = [None] * self.bsz
            # build multiple task data
            self.tasks = [self.data]

            if self.is_combine_attr:
                print('[ build multiple task data ... ]')
                for datafile in self.other_task_datafiles:
                    task_opt = copy.deepcopy(opt)
                    task_opt['datafile'] = datafile
                    self.tasks.append(
                        DialogData(task_opt,
                                   data_loader=self.setup_data,
                                   cands=self.label_candidates()))
                print('[ build multiple task data done! ]')

                # record the selections of each subtasks
                self.subtasks = opt['subtasks'].split(':')
                self.subtask_counter = OrderedDict()
                self.p_selections = OrderedDict()
                self.c_selections = OrderedDict()
                for t in self.subtasks:
                    self.subtask_counter[t] = 0
                    self.p_selections[t] = []
                    self.c_selections[t] = []

                if self.count_sample and not self.stream:
                    self.sample_counter = OrderedDict()
                    for idx, t in enumerate(self.subtasks):
                        self.sample_counter[t] = [
                            0 for _ in self.tasks[idx].data
                        ]

            # setup the tensorboard log
            if opt['tensorboard_log_teacher'] is True:
                opt['tensorboard_tag'] = 'task'
                teacher_metrics = ('reward,policy_loss,critic_loss,'
                                   'mean_advantage_reward,action_ent').split(',')
                opt['tensorboard_metrics'] = ','.join(
                    opt['tensorboard_metrics'].split(',') + teacher_metrics)
                self.writer = TensorboardLogger(opt)

        else:
            self.lastYs = shared['lastYs']
            self.tasks = shared['tasks']
            if not self.stream and opt.get('pace_by', 'sample') == 'bucket':
                self.bucket_cnt = shared['bucket_cnt']
            if 'writer' in shared:
                self.writer = shared['writer']
            if 'subtask_counter' in shared:
                self.subtask_counter = shared['subtask_counter']
            if 'p_selections' in shared:
                self.p_selections = shared['p_selections']
            if 'c_selections' in shared:
                self.c_selections = shared['c_selections']

        # build the policy net, criterion and optimizer here
        self.state_dim = 32 + len(self.tasks)  # hand-craft features
        self.action_dim = len(self.tasks)

        if not shared:
            self.policy = PolicyNet(self.state_dim, self.action_dim)
            self.critic = CriticNet(self.state_dim, self.action_dim)

            init_teacher = get_init_teacher(opt, shared)
            if init_teacher is not None:
                # load teacher parameters if available
                print('[ Loading existing teacher params from {} ]'
                      ''.format(init_teacher))
                states = self.load(init_teacher)
            else:
                states = {}
        else:
            self.policy = shared['policy']
            self.critic = shared['critic']
            states = shared['states']

        if (
                # only build an optimizer if we're training
                'train' in opt.get('datatype', '') and
                # and this is the main model
                shared is None):
            # for policy net
            self.optimizer = self.init_optim(
                [p for p in self.policy.parameters() if p.requires_grad],
                lr=opt['learningrate_teacher'],
                optim_states=states.get('optimizer'),
                saved_optim_type=states.get('optimizer_type'))
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                'min',
                factor=0.8,  # 0.5 --> 0.8
                patience=5,  # 3 -- > 5
                verbose=True)
            if 'lr_scheduler' in states:
                self.scheduler.load_state_dict(states['lr_scheduler'])

            # for critic net
            self.optimizer_critic = self.init_optim(
                [p for p in self.critic.parameters() if p.requires_grad],
                lr=opt['learningrate_teacher_critic'],
                optim_states=states.get('optimizer_critic'),
                saved_optim_type=states.get('optimizer_type'))
            self.scheduler_critic = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer_critic,
                'min',
                factor=0.8,  # 0.5 --> 0.8
                patience=5,  # 3 -- > 5
                verbose=True)
            if 'lr_scheduler_critic' in states:
                self.scheduler_critic.load_state_dict(
                    states['lr_scheduler_critic'])

            self.critic_criterion = torch.nn.SmoothL1Loss()

        self.reward_metric = opt.get('reward_metric', 'total_metric')
        self.reward_metric_mode = opt.get('reward_metric_mode', 'max')

        self.prev_prev_valid_report = states.get('prev_prev_valid_report')
        self.prev_valid_report = states.get('prev_valid_report')
        self.current_valid_report = states.get('current_valid_report')
        self.saved_actions = states.get('saved_actions', OrderedDict())
        self.saved_state_actions = states.get('saved_state_actions',
                                              OrderedDict())
        if self.use_cuda:
            for k, v in self.saved_actions.items():
                self.saved_actions[k] = v.cuda()
            for k, v in self.saved_state_actions.items():
                self.saved_state_actions[k] = v.cuda()
        self._number_teacher_updates = states.get('_number_teacher_updates', 0)

        # enable the batch_act
        self.use_batch_act = self.bsz > 1

        self.T = self.opt.get('T', 1000)
        self.c0 = self.opt.get('c0', 0.01)
        self.p = self.opt.get('p', 2)

        # setup the timer
        self.log_every_n_secs = (opt['log_every_n_secs']
                                 if opt['log_every_n_secs'] > 0 else float('inf'))
        self.action_log_time = Timer()

        self.move_to_cuda()
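
The bucketing logic above assigns each episode's score to one of num_buckets via the lower bounds lbs. A plausible stand-in for the sort_into_bucket helper it calls (an assumption; the real method is defined on the teacher class):

import bisect

def sort_into_bucket(val, lbs):
    # lbs is ascending; return the index of the largest lower bound <= val
    return bisect.bisect_right(lbs, val) - 1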
Example #9
def main():
    # Get command line arguments
    parser = ParlaiParser(True, True)
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument('-et', '--evaltask',
                        help=('task to use for valid/test (defaults to the ' +
                              'one used for training if not set)'))
    train.add_argument('-d', '--display-examples',
                        type='bool', default=False)
    train.add_argument('-e', '--num-epochs', type=int, default=-1)
    train.add_argument('-ttim', '--max-train-time',
                        type=float, default=-1)
    train.add_argument('-ltim', '--log-every-n-secs',
                        type=float, default=2)
    train.add_argument('-vtim', '--validation-every-n-secs',
                        type=float, default=-1)
    train.add_argument('-vme', '--validation-max-exs',
                        type=int, default=-1,
                        help='max examples to use during validation (default ' +
                             '-1 uses all)')
    train.add_argument('-vp', '--validation-patience',
                        type=int, default=5,
                        help=('number of iterations of validation where result '
                              + 'does not improve before we stop training'))
    train.add_argument('-dbf', '--dict-build-first',
                        type='bool', default=True,
                        help='build dictionary first before training agent')
    opt = parser.parse_args()
    # Possibly build a dictionary (not all models do this).
    if opt['dict_build_first'] and 'dict_file' in opt:
        if opt['dict_file'] is None and opt.get('model_file'):
            opt['dict_file'] = opt['model_file'] + '.dict'
        build_dict.build_dict(opt)
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    train_time = Timer()
    validate_time = Timer()
    log_time = Timer()
    print('[ training... ]')
    total_exs = 0
    max_exs = opt['num_epochs'] * len(world)
    best_accuracy = 0
    impatience = 0
    saved = False
    while True:
        world.parley()
        if opt['num_epochs'] > 0 and total_exs >= max_exs:
            print('[ num_epochs completed: {} ]'.format(opt['num_epochs']))
            break
        if opt['max_train_time'] > 0 and train_time.time() > opt['max_train_time']:
            print('[ max_train_time elapsed: {} ]'.format(train_time.time()))
            break
        if opt['log_every_n_secs'] > 0 and log_time.time() > opt['log_every_n_secs']:
            if opt['display_examples']:
                print(world.display() + '\n~~')

            logs = []
            # time elapsed
            logs.append('time:{}s'.format(math.floor(train_time.time())))

            # get report and update total examples seen so far
            if hasattr(agent, 'report'):
                train_report = agent.report()
                agent.reset_metrics()
            else:
                train_report = world.report()
                world.reset_metrics()
            total_exs += train_report['total']
            logs.append('total_exs:{}'.format(total_exs))

            # check if we should log amount of time remaining
            time_left = None
            if opt['num_epochs'] > 0:
                exs_per_sec = train_time.time() / total_exs  # secs per example
                time_left = (max_exs - total_exs) * exs_per_sec
            if opt['max_train_time'] > 0:
                other_time_left = opt['max_train_time'] - train_time.time()
                if time_left is not None:
                    time_left = min(time_left, other_time_left)
                else:
                    time_left = other_time_left
            if time_left is not None:
                logs.append('time_left:{}s'.format(math.floor(time_left)))

            # join log string and add full metrics report to end of log
            log = '[ {} ] {}'.format(' '.join(logs), train_report)

            print(log)
            log_time.reset()

        if (opt['validation_every_n_secs'] > 0 and
                validate_time.time() > opt['validation_every_n_secs']):
            valid_report = run_eval(agent, opt, 'valid', True, opt['validation_max_exs'])
            if valid_report['accuracy'] > best_accuracy:
                best_accuracy = valid_report['accuracy']
                impatience = 0
                print('[ new best accuracy: ' + str(best_accuracy) + ' ]')
                if opt['model_file']:
                    agent.save(opt['model_file'])
                    saved = True
                if best_accuracy == 1:
                    print('[ task solved! stopping. ]')
                    break
            else:
                impatience += 1
                print('[ did not beat best accuracy: {} impatience: {} ]'.format(
                        round(best_accuracy, 4), impatience))
            validate_time.reset()
            if opt['validation_patience'] > 0 and impatience >= opt['validation_patience']:
                print('[ ran out of patience! stopping training. ]')
                break
    world.shutdown()
    if not saved:
        if opt['model_file']:
            agent.save(opt['model_file'])
    else:
        # reload best validation model
        agent = create_agent(opt)

    run_eval(agent, opt, 'valid')
    run_eval(agent, opt, 'test')
Example #10
    def __init__(self, opt):
        if isinstance(opt, ParlaiParser):
            opt = opt.parse_args()
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get(
                    'model_file_transmitter') and opt.get(
                        'model_file_receiver'):
                opt['dict_file'] = opt['model_file_transmitter'] + '_' + opt[
                    'model_file_receiver'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=False)

        # Create model and assign it to the specified task
        print("[ create meta-agent ... ]")
        self.agent = create_agent(opt)
        print("[ create agent A ... ]")
        shared = self.agent.share()
        self.agent_a = create_agent_from_shared(shared)
        self.agent_a.set_id(suffix=' A')
        print("[ create agent B ... ]")
        self.agent_b = create_agent_from_shared(shared)
        # self.agent_b = create_agent(opt)
        self.agent_b.set_id(suffix=' B')
        # self.agent_a.copy(self.agent, 'transmitter')
        # self.agent_b.copy(self.agent, 'transmitter')
        self.world = create_selfplay_world(opt, [self.agent_a, self.agent_b])

        # TODO: if batch, it is also not parallel
        # self.world = BatchSelfPlayWorld(opt, self_play_world)

        self.train_time = Timer()
        self.train_dis_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys_episode = 0
        self.max_num_epochs = (opt['num_epochs']
                               if opt['num_epochs'] > 0 else float('inf'))
        self.max_train_time = (opt['max_train_time']
                               if opt['max_train_time'] > 0 else float('inf'))
        self.log_every_n_secs = (opt['log_every_n_secs']
                                 if opt['log_every_n_secs'] > 0 else float('inf'))
        self.train_dis_every_n_secs = (opt['train_display_every_n_secs']
                                       if opt['train_display_every_n_secs'] > 0
                                       else float('inf'))
        self.val_every_n_secs = (opt['validation_every_n_secs']
                                 if opt['validation_every_n_secs'] > 0
                                 else float('inf'))
        self.save_every_n_secs = (opt['save_every_n_secs']
                                  if opt['save_every_n_secs'] > 0 else float('inf'))
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.best_valid = None
        if opt.get('model_file_transmitter') and os.path.isfile(
                opt['model_file_transmitter'] + '.best_valid'):
            with open(opt['model_file_transmitter'] + '.best_valid', 'r') as f:
                self.best_valid = float(f.readline())
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt
        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)
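
valid_optim folds "higher is better" and "lower is better" validation metrics into one comparison by sign-flipping. A small sketch of the pattern (an assumption based on how best_valid is tracked in these examples):

def is_improvement(new_value, best_value, valid_optim):
    # valid_optim is +1 for max-mode metrics, -1 for min-mode metrics
    if best_value is None:
        return True
    return new_value * valid_optim > best_value * valid_optim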
Example #11
def main():
    # Get command line arguments
    parser = ParlaiParser(True, True)
    parser.add_argument('-d', '--display-examples',
                        type='bool', default=False)
    parser.add_argument('-e', '--num-epochs', type=int, default=1)
    parser.add_argument('-ttim', '--max-train-time',
                        type=float, default=float('inf'))
    parser.add_argument('-ltim', '--log-every-n-secs',
                        type=float, default=1)
    parser.add_argument('-vtim', '--validation-every-n-secs',
                        type=float, default=False)
    parser.add_argument('-vp', '--validation-patience',
                        type=int, default=5,
                        help=('number of iterations of validation where result '
                              + 'does not improve before we stop training'))
    parser.add_argument('-dbf', '--dict_build_first',
                        type='bool', default=True,
                        help='build dictionary first before training agent')
    opt = parser.parse_args()
    # Possibly build a dictionary (not all models do this).
    if opt['dict_build_first']:
        if 'dict_file' not in opt and 'model_file' in opt:
            opt['dict_file'] = opt['model_file'] + '.dict'
        build_dict.build_dict(opt)
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    train_time = Timer()
    validate_time = Timer()
    log_time = Timer()
    print("[ training... ]")
    parleys = 0
    num_parleys = opt['num_epochs'] * int(len(world) / opt['batchsize'])
    best_accuracy = 0
    impatience = 0
    saved = False
    for i in range(num_parleys):
        world.parley()
        parleys = parleys + 1
        if train_time.time() > opt['max_train_time']:
            print("[ max_train_time elapsed: " + str(train_time.time()) + " ]")
            break
        if log_time.time() > opt['log_every_n_secs']:
            if opt['display_examples']:
                print(world.display() + "\n~~")
            parleys_per_sec = train_time.time() / parleys  # secs per parley
            time_left = (num_parleys - parleys) * parleys_per_sec
            log = ("[ time:" + str(math.floor(train_time.time()))
                  + "s parleys:" + str(parleys)
                  + " time_left:"
                  + str(math.floor(time_left))  + "s ]")
            if hasattr(agent, 'report'):
                log = log + str(agent.report())
            else:
                log = log + str(world.report())
                # TODO: world.reset_metrics()
            print(log)
            log_time.reset()
        if (opt['validation_every_n_secs'] and
            validate_time.time() > opt['validation_every_n_secs']):
            valid_report = run_eval(agent, opt, 'valid', True)
            if valid_report['accuracy'] > best_accuracy:
                best_accuracy = valid_report['accuracy']
                impatience = 0
                print("[ new best accuracy: " + str(best_accuracy) +  " ]")
                if opt['model_file']:
                    agent.save(opt['model_file'])
                    saved = True
                if best_accuracy == 1:
                    print('[ task solved! stopping. ]')
                    break
            else:
                impatience += 1
                print("[ did not beat best accuracy: " + str(best_accuracy) +
                      " impatience: " + str(impatience)  + " ]")
            validate_time.reset()
            if impatience >= opt['validation_patience']:
                print('[ ran out of patience! stopping. ]')
                break
    world.shutdown()
    if not saved:
        if opt['model_file']:
            agent.save(opt['model_file'])
    else:
        # reload best validation model
        agent = create_agent(opt)

    run_eval(agent, opt, 'valid')
    run_eval(agent, opt, 'test')
Example #12
def main(parser):
    opt = parser.parse_args()
    # Possibly build a dictionary (not all models do this).
    if opt['dict_build_first'] and 'dict_file' in opt:
        if opt['dict_file'] is None and opt.get('model_file'):
            opt['dict_file'] = opt['model_file'] + '.dict'
        print("[ building dictionary first... ]")
        build_dict.build_dict(opt)
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    train_time = Timer()
    validate_time = Timer()
    log_time = Timer()
    print('[ training... ]')
    parleys = 0
    total_exs = 0
    max_exs = opt['num_epochs'] * len(world)
    max_parleys = math.ceil(max_exs / opt['batchsize'])
    best_valid = 0
    impatience = 0
    saved = False
    valid_world = None
    with world:
        while True:
            world.parley()
            parleys += 1

            if opt['num_epochs'] > 0 and parleys >= max_parleys:
                print('[ num_epochs completed: {} ]'.format(opt['num_epochs']))
                break
            if (opt['max_train_time'] > 0
                    and train_time.time() > opt['max_train_time']):
                print('[ max_train_time elapsed: {} ]'.format(
                    train_time.time()))
                break
            if (opt['log_every_n_secs'] > 0
                    and log_time.time() > opt['log_every_n_secs']):
                if opt['display_examples']:
                    print(world.display() + '\n~~')

                logs = []
                # time elapsed
                logs.append('time:{}s'.format(math.floor(train_time.time())))
                logs.append('parleys:{}'.format(parleys))

                # get report and update total examples seen so far
                if hasattr(agent, 'report'):
                    train_report = agent.report()
                    agent.reset_metrics()
                else:
                    train_report = world.report()
                    world.reset_metrics()

                if hasattr(train_report, 'get') and train_report.get('total'):
                    total_exs += train_report['total']
                    logs.append('total_exs:{}'.format(total_exs))

                # check if we should log amount of time remaining
                time_left = None
                if opt['num_epochs'] > 0:
                    exs_per_sec = train_time.time() / total_exs  # secs per example
                    time_left = (max_exs - total_exs) * exs_per_sec
                if opt['max_train_time'] > 0:
                    other_time_left = opt['max_train_time'] - train_time.time()
                    if time_left is not None:
                        time_left = min(time_left, other_time_left)
                    else:
                        time_left = other_time_left
                if time_left is not None:
                    logs.append('time_left:{}s'.format(math.floor(time_left)))

                # join log string and add full metrics report to end of log
                log = '[ {} ] {}'.format(' '.join(logs), train_report)

                print(log)
                log_time.reset()

            if (opt['validation_every_n_secs'] > 0
                    and validate_time.time() > opt['validation_every_n_secs']):
                valid_report, valid_world = run_eval(agent,
                                                     opt,
                                                     'valid',
                                                     opt['validation_max_exs'],
                                                     valid_world=valid_world)
                if valid_report[opt['validation_metric']] > best_valid:
                    best_valid = valid_report[opt['validation_metric']]
                    impatience = 0
                    print('[ new best {}: {} ]'.format(
                        opt['validation_metric'], best_valid))
                    world.save_agents()
                    saved = True
                    if opt['validation_metric'] == 'accuracy' and best_valid > 99.5:
                        print('[ task solved! stopping. ]')
                        break
                else:
                    impatience += 1
                    print('[ did not beat best {}: {} impatience: {} ]'.format(
                        opt['validation_metric'], round(best_valid, 4),
                        impatience))
                validate_time.reset()
                if opt['validation_patience'] > 0 and impatience >= opt[
                        'validation_patience']:
                    print('[ ran out of patience! stopping training. ]')
                    break
    if not saved:
        # save agent
        world.save_agents()
    elif opt.get('model_file'):
        # reload best validation model
        agent = create_agent(opt)

    run_eval(agent, opt, 'valid', write_log=True)
    run_eval(agent, opt, 'test', write_log=True)
Example #13
def eval_model(opt, printargs=None, print_parser=None):
    """Evaluates a model.

    Arguments:
    opt -- tells the evaluation function how to run
    print_parser -- if provided, prints the options that are set within the
        model after loading the model
    """
    if printargs is not None:
        print('[ Deprecated Warning: eval_model no longer uses `printargs` ]')
        print_parser = printargs
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: eval_model should be passed opt not Parser ]'
        )
        opt = opt.parse_args()

    random.seed(42)

    nomodel = False
    # check to make sure the model file exists
    if opt.get('model_file') is None:
        nomodel = True
    elif not os.path.isfile(opt['model_file']):
        raise RuntimeError('WARNING: Model file does not exist, check to make '
                           'sure it is correct: {}'.format(opt['model_file']))

    # Create model and assign it to the specified task
    agent = create_agent(opt)
    if nomodel and hasattr(agent, 'load'):
        # double check that we didn't forget to set model_file on loadable model
        print('WARNING: model_file unset but model has a `load` function.')
    world = create_task(opt, agent)

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = Timer()
    tot_time = 0

    # Show some example dialogs:
    cnt = 0
    while not world.epoch_done():
        cnt += 1
        world.parley()
        if opt['display_examples']:
            print("---")
            print(world.display() + "\n~~")
        if log_time.time() > log_every_n_secs:
            tot_time += log_time.time()
            print(str(int(tot_time)) + "s elapsed: " + str(world.report()))
            log_time.reset()
        if opt['num_examples'] > 0 and cnt >= opt['num_examples']:
            break
    if world.epoch_done():
        print("EPOCH DONE")
    report = world.report()
    print(report)
    return report
Example #14
def __train_single_model(opt):
    """Train single model.
    opt is a dictionary returned by arg_parse
    """
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)
    print('[ training... ]')

    train_dict = {
        'train_time': Timer(),
        'validate_time': Timer(),
        'log_time': Timer(),
        'new_epoch': None,
        'epochs_done': 0,
        'max_exs': opt['num_epochs'] * len(world),
        'total_exs': 0,
        'parleys': 0,
        'max_parleys':
        math.ceil(opt['num_epochs'] * len(world) / opt['batchsize']),
        'best_metrics': opt['chosen_metrics'],
        'best_metrics_value': 0,
        'impatience': 0,
        'lr_drop_impatience': 0,
        'saved': False,
        'train_report': None,
        'train_report_agent': None,
        'train_report_world': None,
        'break': None
    }

    try:
        while True:
            world.parley()
            train_dict['parleys'] += 1
            train_dict['new_epoch'] = world.epoch_done()
            if train_dict['new_epoch']:
                world.reset()
                train_dict['epochs_done'] += 1
            if opt['num_epochs'] > 0 and train_dict['parleys'] >= train_dict[
                    'max_parleys']:
                print('[ num_epochs completed: {} ]'.format(opt['num_epochs']))
                break
            if 0 < opt['max_train_time'] < train_dict['train_time'].time():
                print('[ max_train_time elapsed: {} ]'.format(
                    train_dict['train_time'].time()))
                break
            world, agent, train_dict = __train_log(opt, world, agent,
                                                   train_dict)
            _, agent, train_dict = __intermediate_validation(
                opt, world, agent, train_dict)

            if train_dict['break']:
                break
    except KeyboardInterrupt:
        print('Stopped training, starting testing')

    if not train_dict['saved']:
        world.save_agents()

    world.shutdown()
    agent.shutdown()

    # reload best validation model
    vopt = copy.deepcopy(opt)
    if vopt.get('evaltask'):
        vopt['task'] = vopt['evaltask']
    vopt['datatype'] = 'valid'
    vopt['pretrained_model'] = vopt['model_file']
    agent = create_agent(vopt)
    valid_world = create_task(vopt, agent)
    metrics, _ = __evaluate_model(valid_world, vopt['batchsize'], 'valid',
                                  vopt['display_examples'],
                                  vopt['validation_max_exs'])
    valid_world.shutdown()
    agent.shutdown()
    return metrics
Example #15
def eval_ppl(opt, build_dict=None, dict_file=None):
    """Evaluates the the perplexity of a model.

    This uses a dictionary which implements the following functions:
    - tokenize(text): splits string up into list of tokens
    - __contains__(token): checks whether the dictionary contains a token
    - keys(): returns an iterator over all tokens in the dictionary

    :param opt: option dict
    :param build_dict: function which returns a dictionary class implementing
        the functions above.
    :param dict_file: file used when loading the dictionary class set via the
        "dictionary_class" argument (defaults to
        parlai.core.dict:DictionaryAgent).

    Either build_dict or dict_file must be set (both default to None) to
    determine the dictionary used for the evaluation.
    """
    if not build_dict and not dict_file:
        raise RuntimeError('eval_ppl script either needs a dictionary build '
                           'function or a dictionary file.')

    if build_dict:
        dict_agent = build_dict()
    else:
        dict_opt = copy.deepcopy(opt)
        dict_opt['model'] = dict_opt.get(
            'dictionary_class', 'parlai.core.dict:DictionaryAgent'
        )
        dict_opt['model_file'] = dict_file
        if 'override' in dict_opt:
            del dict_opt['override']
        dict_agent = create_agent(dict_opt, requireModelExists=True)

    # create agents
    agent = create_agent(opt)
    world = create_task(opt, [agent, dict_agent], default_world=PerplexityWorld)

    # set up logging
    log_time = Timer()
    tot_time = 0

    while not world.epoch_done():
        world.parley()  # process an example

        if log_time.time() > 1:  # log every 1 sec
            tot_time += log_time.time()
            report = world.report()
            print('{}s elapsed, {}% complete, {}'.format(
                int(tot_time),
                round_sigfigs(report['exs'] / world.num_examples() * 100, 3),
                report))
            log_time.reset()
    print('EPOCH DONE')
    tot_time += log_time.time()
    final_report = world.report()
    print('{}s elapsed: {}'.format(int(tot_time), final_report))
    print("============================")
    print("FINAL PPL: " + str(final_report['ppl']))
    if final_report.get('ppl', 0) == float('inf'):
        print('Note: you got inf perplexity. Consider adding (or raising) the '
              'minimum probability you assign to each possible word. If you '
              'assign zero probability to the correct token in the evaluation '
              'vocabulary, you get inf probability immediately.')
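
The docstring above pins down a small dictionary interface. A minimal stand-in that satisfies it (illustrative only; real runs use parlai.core.dict:DictionaryAgent):

class TinyDict:
    def __init__(self, vocab):
        self._vocab = set(vocab)

    def tokenize(self, text):
        # naive whitespace tokenization
        return text.split()

    def __contains__(self, token):
        # backs `token in dictionary` checks
        return token in self._vocab

    def keys(self):
        return iter(self._vocab)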
Example #16
    def __init__(self, opt):
        signal.signal(signal.SIGINT, signal.default_int_handler)

        if isinstance(opt, ParlaiParser):
            opt = opt.parse_args()
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'
        if (opt.get('model_file')
                and isfile(opt['model_file'] + '.checkpoint')):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        else:
            pass
            # TODO for testing only
            # raise RuntimeError('WARNING: Reinforcement learning'
            #                    ' must be initialized by a model.checkpoint '
            #                    'file and {} does not exist.'.format(
            #                        opt['model_file'] + '.checkpoint'))
        # Possibly build a dictionary (not all models do this).
        if (opt['dict_build_first']
                and not (opt.get('dict_file') or opt.get('model_file'))):
            raise RuntimeError('WARNING: For train_model, '
                               'please specify either a '
                               'model_file or dict_file.')

        if opt['dict_build_first'] and 'dict_file' in opt:
            if opt.get('pytorch_teacher_task'):
                opt['dict_file'] = get_pyt_dict_file(opt)
            elif opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)

        # Create model and assign it to the specified task
        self.agent = create_agent(opt)

        # Freeze the model for the static dialogue partner
        static_agent = copy.deepcopy(self.agent)
        self.agent.id = ACTIVE

        static_agent.id = STATIC
        freeze_agent(static_agent)

        self.world = create_task(opt, self.agent, static_agent)

        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')

        self.parleys = 0
        self.max_num_epochs = (opt['num_epochs']
                               if opt['num_epochs'] > 0 else float('inf'))

        self.max_train_time = (opt['max_train_time']
                               if opt['max_train_time'] > 0 else float('inf'))

        self.log_every_n_secs = (opt['log_every_n_secs'] if
                                 opt['log_every_n_secs'] > 0 else float('inf'))

        self.val_every_n_secs = (opt['validation_every_n_secs']
                                 if opt['validation_every_n_secs'] > 0 else
                                 float('inf'))

        self.save_every_n_secs = (opt['save_every_n_secs']
                                  if opt['save_every_n_secs'] > 0 else
                                  float('inf'))

        self.val_every_n_epochs = (opt['validation_every_n_epochs']
                                   if opt['validation_every_n_epochs'] > 0 else
                                   float('inf'))

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {
                'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'
        }:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self.valid_optim = (1
                            if opt['validation_metric_mode'] == 'max' else -1)
        self.valid_reports = []
        self.best_valid = None
        if (opt.get('model_file')
                and isfile(opt['model_file'] + '.best_valid')):
            with open(opt['model_file'] + ".best_valid", 'r') as f:
                x = f.readline()
                self.best_valid = float(x)
                f.close()
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if (opt.get('model_file')
                and isfile(opt['model_file'] + trainstats_suffix)):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)
                self.valid_reports = obj.get('valid_reports', [])

        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)
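
The preemption-recovery block above reads a small JSON '.trainstats' file. A sketch of a matching writer, using only the field names read above (the writer in the actual source may differ):

import json

def save_trainstats(model_file, total_epochs, train_time, impatience,
                    valid_reports):
    with open(model_file + '.trainstats', 'w') as ts:
        json.dump({
            'total_epochs': total_epochs,
            'train_time': train_time,
            'impatience': impatience,
            'valid_reports': valid_reports,
        }, ts)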
Example #17
    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)

        if isinstance(opt, ParlaiParser):
            print(
                '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
            )
            opt = opt.parse_args()
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if (opt['load_from_checkpoint'] and opt.get('model_file')
                and os.path.isfile(opt['model_file'] + '.checkpoint')):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            # If data built via pytorch data teacher, we need to load prebuilt dict
            if opt.get('pytorch_teacher_task'):
                opt['dict_file'] = get_pyt_dict_file(opt)
            elif opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)  # specify model, such as seq2seq
        self.world = create_task(opt, self.agent)  # batch world or other world
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys = 0
        self.max_num_epochs = (opt['num_epochs']
                               if opt['num_epochs'] > 0 else float('inf'))
        self.max_train_time = (opt['max_train_time']
                               if opt['max_train_time'] > 0 else float('inf'))
        self.log_every_n_secs = (opt['log_every_n_secs']
                                 if opt['log_every_n_secs'] > 0 else float('inf'))
        self.val_every_n_secs = (opt['validation_every_n_secs']
                                 if opt['validation_every_n_secs'] > 0
                                 else float('inf'))
        self.save_every_n_secs = (opt['save_every_n_secs']
                                  if opt['save_every_n_secs'] > 0 else float('inf'))
        self.val_every_n_epochs = (opt['validation_every_n_epochs']
                                   if opt['validation_every_n_epochs'] > 0
                                   else float('inf'))

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {
                'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'
        }:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.best_valid = None
        if opt.get('model_file') and os.path.isfile(opt['model_file'] + '.best_valid'):
            # `with` closes the file automatically; no explicit close() needed
            with open(opt['model_file'] + '.best_valid') as f:
                self.best_valid = float(f.readline())
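        # (The matching writer elsewhere in the training loop is assumed to
        # persist a single float, e.g. f.write(str(self.best_valid)), so it
        # can be re-read here after preemption.)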
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if (opt.get('model_file')
                and os.path.isfile(opt['model_file'] + trainstats_suffix)):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)

        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)
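For orientation, the __init__ above builds paths to a family of sidecar files next to model_file; a sketch of the implied layout:

# Files implied by the paths constructed above (all relative to model_file):
#   <model_file>                         saved model weights
#   <model_file>.checkpoint              mid-training checkpoint (used as init_model on resume)
#   <model_file>.checkpoint.trainstats   training stats when resuming from a checkpoint
#   <model_file>.trainstats              training stats (total_epochs, train_time, impatience)
#   <model_file>.best_valid              best validation metric, stored as a single float
#   <model_file>.dict                    dictionary built before training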
Example No. 18
def train_model(opt):
    # Possibly build a dictionary (not all models do this).
    if opt['dict_build_first'] and 'dict_file' in opt:
        if opt['dict_file'] is None and opt.get('pretrained_model'):
            opt['dict_file'] = opt['pretrained_model'] + '.dict'
        if opt['dict_file'] is None and opt.get('model_file'):
            opt['dict_file'] = opt['model_file'] + '.dict'
        print("[ building dictionary first... ]")
        build_dict.build_dict(opt)
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    if opt['datatype'].split(':')[0] == 'train':
        world = create_task(opt, agent)

        train_time = Timer()
        validate_time = Timer()
        log_time = Timer()
        print('[ training... ]')
        parleys = 0
        total_exs = 0
        max_exs = opt['num_epochs'] * len(world)
        epochs_done = 0
        max_parleys = math.ceil(max_exs / opt['batchsize'])
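        # e.g. num_epochs=2 on a 1,000-example task with batchsize 32:
        # max_exs = 2000, max_parleys = ceil(2000 / 32) = 63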
        best_metric_name = opt['chosen_metric']
        best_metric = 0
        impatience = 0
        saved = False
        valid_world = None
        try:
            while True:
                world.parley()
                parleys += 1
                new_epoch = world.epoch_done()
                if new_epoch:
                    world.reset()
                    epochs_done += 1

                if opt['num_epochs'] > 0 and parleys >= max_parleys:
                    print('[ num_epochs completed: {} ]'.format(opt['num_epochs']))
                    break
                if 0 < opt['max_train_time'] < train_time.time():
                    print('[ max_train_time elapsed: {} ]'.format(train_time.time()))
                    break
                if (0 < opt['log_every_n_secs'] < log_time.time()) or \
                        (opt['log_every_n_epochs'] > 0 and new_epoch and
                         (epochs_done % opt['log_every_n_epochs']) == 0):
                    if opt['display_examples']:
                        print(world.display() + '\n~~')

                    logs = list()
                    # time elapsed
                    logs.append('time:{}s'.format(math.floor(train_time.time())))
                    logs.append('parleys:{}'.format(parleys))
                    if epochs_done > 0:
                        logs.append('epochs done:{}'.format(epochs_done))

                    # get report and update total examples seen so far
                    if hasattr(agent, 'report'):
                        train_report = agent.report()
                        agent.reset_metrics()
                    else:
                        train_report = world.report()
                        world.reset_metrics()

                    if hasattr(train_report, 'get') and train_report.get('total'):
                        total_exs += train_report['total']
                        logs.append('total_exs:{}'.format(total_exs))

                    # check if we should log amount of time remaining
                    time_left = None
                    if opt['num_epochs'] > 0 and total_exs > 0:
                        secs_per_ex = train_time.time() / total_exs
                        time_left = (max_exs - total_exs) * secs_per_ex
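                        # e.g. 10,000 of 40,000 exs after 500s of training:
                        # secs_per_ex = 0.05 -> time_left = 30,000 * 0.05 = 1,500s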
                    if opt['max_train_time'] > 0:
                        other_time_left = opt['max_train_time'] - train_time.time()
                        if time_left is not None:
                            time_left = min(time_left, other_time_left)
                        else:
                            time_left = other_time_left
                    if time_left is not None:
                        logs.append('time_left:{}s'.format(math.floor(time_left)))

                    # join log string and add full metrics report to end of log
                    log = '[ {} ] {}'.format(' '.join(logs), train_report)

                    print(log)
                    log_time.reset()

                if (0 < opt['validation_every_n_secs'] < validate_time.time()) or \
                        (opt['validation_every_n_epochs'] > 0 and new_epoch and
                         (epochs_done % opt['validation_every_n_epochs']) == 0):

                    valid_report, valid_world = run_eval(agent, opt, 'valid',
                                                         opt['validation_max_exs'],
                                                         valid_world=valid_world)
                    if best_metric_name not in valid_report and 'accuracy' in valid_report:
                        best_metric_name = 'accuracy'
                    if valid_report[best_metric_name] > best_metric:
                        best_metric = valid_report[best_metric_name]
                        impatience = 0
                        print('[ new best ' + best_metric_name + ': ' + str(best_metric) + ' ]')
                        world.save_agents()
                        saved = True
                        if best_metric == 1:
                            print('[ task solved! stopping. ]')
                            break
                    else:
                        impatience += 1
                        print('[ did not beat best ' + best_metric_name + ': {} impatience: {} ]'.format(
                                round(best_metric, 4), impatience))
                    validate_time.reset()
                    if 0 < opt['validation_patience'] <= impatience:
                        print('[ ran out of patience! stopping training. ]')
                        break
        except KeyboardInterrupt:
            print('Stopped training, starting testing')

        if not saved:
            world.save_agents()
        world.shutdown()

        # reload best validation model
        opt['pretrained_model'] = opt['model_file']
        agent = create_agent(opt)

        run_eval(agent, opt, 'valid', write_log=True)
        run_eval(agent, opt, 'test', write_log=True)
    else:
        run_eval(agent, opt, opt['datatype'], write_log=True)
    agent.shutdown()
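A minimal driver sketch for this train_model variant. The parser wiring is an assumption; option names such as num_epochs, batchsize, chosen_metric, and validation_patience come from the keys the function reads above, and any defaults shown are illustrative only:

from parlai.core.params import ParlaiParser

# Hypothetical invocation; not part of the original script.
if __name__ == '__main__':
    parser = ParlaiParser(add_parlai_args=True, add_model_args=True)
    parser.set_defaults(datatype='train')
    opt = parser.parse_args()
    train_model(opt)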