Code Example #1
File: train_model.py  Project: magicye/ParlAI
    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        opt = self.opt
        for _train_log in self.train_steps():
            # we've already done what we need in these
            pass

        # perform final validation/testing
        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
        v_report = self._run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True)
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        t_report = self._run_eval(test_worlds, opt, 'test', max_exs, write_log=True)

        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_final('valid', v_report)
            self.wb_logger.log_final('test', t_report)
            self.wb_logger.finish()

        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        return v_report, t_report
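
The train() method above is normally reached through ParlAI's train_model script. As a rough usage sketch (assuming the parlai.scripts.train_model module layout used by the examples on this page, and treating the task/model flags below as placeholders), the loop can also be driven directly:

from parlai.scripts.train_model import TrainLoop, setup_args

# Build an opt dict with the script's own parser; the flag values are placeholders.
parser = setup_args()
opt = parser.parse_args(
    ['--task', 'babi:task10k:1', '--model', 'seq2seq',
     '--model-file', '/tmp/babi_model']
)

# TrainLoop sets up the agent, world and timers; train() returns the
# (validation_report, test_report) tuple documented in the docstring above.
valid_report, test_report = TrainLoop(opt).train()
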
Code Example #2
def eval_model(opt, print_parser=None):
    """Evaluates a model.

    :param opt: tells the evaluation function how to run
    :param bool print_parser: if provided, prints the options that are set within the
        model after loading the model
    :return: the final result of calling report()
    """
    random.seed(42)

    # load model and possibly print opt
    agent = create_agent(opt, requireModelExists=True)
    if print_parser:
        # show args after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    tasks = opt['task'].split(',')
    reports = []
    for task in tasks:
        task_report = _eval_single_world(opt, agent, task)
        reports.append(task_report)

    report = aggregate_task_reports(reports,
                                    tasks,
                                    micro=opt.get('aggregate_micro', True))

    # print announcements and report
    print_announcements(opt)
    print('[ Finished evaluating tasks {} using datatype {} ]'.format(
        tasks, opt.get('datatype', 'N/A')))
    print(report)

    return report
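
A minimal sketch of calling this eval_model directly rather than through the command line, assuming the setup_args() helper that parlai.scripts.eval_model provides in the versions shown here, with placeholder flag values:

from parlai.scripts.eval_model import eval_model, setup_args

parser = setup_args()
# --task and --model-file feed opt['task'] and the model loading performed by
# create_agent(); the paths below are placeholders.
opt = parser.parse_args(['--task', 'babi:task10k:1',
                         '--model-file', '/tmp/babi_model'])
report = eval_model(opt, print_parser=parser)  # aggregated report across tasks
print(report)
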
Code Example #3
def eval_model(opt, print_parser=None):
    """
    Evaluates a model.

    :param opt: tells the evaluation function how to run
    :param bool print_parser: if provided, prints the options that are set within the
        model after loading the model
    :return: the final result of calling report()
    """
    random.seed(42)
    if 'train' in opt['datatype'] and 'evalmode' not in opt['datatype']:
        raise ValueError(
            'You should use --datatype train:evalmode if you want to evaluate on '
            'the training set.'
        )

    if opt['save_world_logs'] and not opt['report_filename']:
        raise RuntimeError(
            'In order to save model replies, please specify the save path '
            'with --report-filename'
        )

    # load model and possibly print opt
    agent = create_agent(opt, requireModelExists=True)
    if print_parser:
        # show args after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    tasks = opt['task'].split(',')
    reports = []
    for task in tasks:
        task_report = _eval_single_world(opt, agent, task)
        reports.append(task_report)

    report = aggregate_named_reports(
        dict(zip(tasks, reports)), micro_average=opt.get('aggregate_micro', False)
    )

    # print announcements and report
    print_announcements(opt)
    print(
        '[ Finished evaluating tasks {} using datatype {} ]'.format(
            tasks, opt.get('datatype', 'N/A')
        )
    )
    print(nice_report(report))
    _save_eval_stats(opt, report)
    return report
Code Example #4
def eval_model(opt):
    """
    Evaluates a model.

    :param opt: tells the evaluation function how to run
    :return: the final result of calling report()
    """
    random.seed(42)
    if 'train' in opt['datatype'] and 'evalmode' not in opt['datatype']:
        raise ValueError(
            'You should use --datatype train:evalmode if you want to evaluate on '
            'the training set.')

    if opt['save_world_logs'] and not opt['report_filename']:
        raise RuntimeError(
            'In order to save model replies, please specify the save path '
            'with --report-filename')

    # load model and possibly print opt
    agent = create_agent(opt, requireModelExists=True)
    agent.opt.log()

    tasks = opt['task'].split(',')
    reports = []
    for task in tasks:
        task_report = _eval_single_world(opt, agent, task)
        reports.append(task_report)

    report = aggregate_named_reports(dict(zip(tasks, reports)),
                                     micro_average=opt.get(
                                         'aggregate_micro', False))

    # print announcements and report
    print_announcements(opt)
    logging.info(
        f'Finished evaluating tasks {tasks} using datatype {opt.get("datatype")}'
    )

    print(nice_report(report))
    _save_eval_stats(opt, report)
    return report
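
The aggregate_named_reports call above switches between micro and macro averaging via opt['aggregate_micro']. The following self-contained sketch (plain dicts rather than ParlAI Metric objects, purely illustrative) shows what that flag changes when two tasks differ in size:

def aggregate(reports, micro=False):
    # each report is assumed to look like {'exs': int, 'accuracy': float}
    if micro:
        # weight every example equally across tasks
        total = sum(r['exs'] for r in reports)
        acc = sum(r['accuracy'] * r['exs'] for r in reports) / total
    else:
        # weight every task equally, regardless of its size
        acc = sum(r['accuracy'] for r in reports) / len(reports)
    return {'exs': sum(r['exs'] for r in reports), 'accuracy': acc}

tasks = {'taskA': {'exs': 100, 'accuracy': 0.90},
         'taskB': {'exs': 10, 'accuracy': 0.50}}
print(aggregate(list(tasks.values()), micro=True))   # accuracy ~0.864
print(aggregate(list(tasks.values()), micro=False))  # accuracy 0.70
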
Code Example #5
    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        logging.info('training...')
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException as e:
                    logging.info(f"Stopping from {e}")
                    break

                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = self._preempted_epochs + sum(
                    all_gather_list(world.get_total_epochs()))
                exs_per_epoch = world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))
                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object((
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                ))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    logging.info(
                        f'num_epochs completed:{self.max_num_epochs} time elapsed:{train_time}s'
                    )
                    break
                if train_time > self.max_train_time:
                    logging.info(f'max_train_time elapsed:{train_time}s')
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    try:
                        # log before we validate
                        self.log()
                        world.reset_metrics()
                        stop_training = self.validate()
                    except StopTrainException:
                        break
                    # reset the log time because we logged right before validating
                    self.log_time.reset()
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                    # make sure metrics are clean before we log
                    world.reset_metrics()
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    logging.info(
                        f"saving model checkpoint: {opt['model_file']}.checkpoint"
                    )
                    if opt['tensorboard_log'] and is_primary_worker():
                        self.tb_logger.flush()
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()

        # there's a rare edge case where we never saved the model, and we try
        # to reload it. This sync_object ensures all workers wait for the primary
        # worker to finish flushing before loading from disk.
        sync_object(None)
        if opt.get('model_file'):
            # clean up all our memory, just to make sure we don't OOM on GPU when
            # reloading the world
            del world
            del self.world
            del self.agent
            del self.valid_worlds
            # reload best validation model
            self.agent = create_agent(opt)

        # perform final validation/testing
        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get(
            'short_final_eval') else -1
        v_report = self._run_eval(valid_worlds,
                                  opt,
                                  'valid',
                                  max_exs,
                                  write_log=True)
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        t_report = self._run_eval(test_worlds,
                                  opt,
                                  'test',
                                  max_exs,
                                  write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        return v_report, t_report
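
This variant wraps world.parley() and validate() in try/except StopTrainException so an agent or teacher can end training early. A schematic, self-contained sketch of that pattern (the names below are local stand-ins, not ParlAI's actual classes):

class StopTrainException(Exception):
    pass

def parley(step):
    # stand-in for world.parley(); an agent decides training should end early
    if step == 3:
        raise StopTrainException('early stop requested by agent')

steps = 0
while True:
    try:
        parley(steps)
    except StopTrainException as e:
        print(f'Stopping from {e}')
        break
    steps += 1
print(f'ran {steps} parleys')
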
Code Example #6
File: train_model.py  Project: kotenev/ParlAI
    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        if is_distributed():
            warn_once(
                "Distributed training outputs average-per-worker metrics during "
                "training, and may be slightly distorted. Validation/test are "
                "unadulterated.")
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                world.parley()
                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = (
                    self._preempted_epochs +
                    num_workers() * self.world.get_total_epochs())
                exs_per_epoch = self.world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object((
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                ))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, train_time))
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    stop_training = self.validate()
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    print("[ saving model checkpoint: {}.checkpoint".format(
                        opt['model_file']))
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_worlds = _maybe_load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get(
            'short_final_eval') else -1
        v_report = run_eval(valid_worlds,
                            opt,
                            'valid',
                            max_exs,
                            write_log=True)
        test_worlds = _maybe_load_eval_worlds(self.agent, opt, 'test')
        t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        return v_report, t_report
Code Example #7
    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException:
                    if is_distributed():
                        raise RuntimeError(
                            "StopTrainException not supported for "
                            "distributed mode")
                    break

                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = self._preempted_epochs + sum(
                    all_gather_list(world.get_total_epochs()))
                exs_per_epoch = world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))
                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object((
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                ))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, train_time))
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    try:
                        # log before we validate
                        self.log()
                        world.reset_metrics()
                        stop_training = self.validate()
                    except StopTrainException:
                        if is_distributed():
                            raise RuntimeError(
                                "StopTrainException not supported for distributed mode"
                            )
                        break
                    # reset the log time because we logged right before validating
                    self.log_time.reset()
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                    # make sure metrics are clean before we log
                    world.reset_metrics()
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    print("[ saving model checkpoint: {}.checkpoint".format(
                        opt['model_file']))
                    if opt['tensorboard_log'] and is_primary_worker():
                        self.tb_logger.flush()
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get(
            'short_final_eval') else -1
        v_report = self._run_eval(valid_worlds,
                                  opt,
                                  'valid',
                                  max_exs,
                                  write_log=True)
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        t_report = self._run_eval(test_worlds,
                                  opt,
                                  'test',
                                  max_exs,
                                  write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        return v_report, t_report
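
The checkpointing branch near the end of the loop is purely time-based: a timer is polled after every parley and, once save_every_n_secs has elapsed on the primary worker, the model is written and the timer reset. A self-contained sketch of that pattern (Timer below is a local stand-in for ParlAI's timer utility):

import time

class Timer:
    def __init__(self):
        self.start = time.time()
    def time(self):
        return time.time() - self.start
    def reset(self):
        self.start = time.time()

save_every_n_secs = 1.5
save_time = Timer()
for step in range(5):
    time.sleep(1)  # stand-in for world.parley()
    if save_time.time() > save_every_n_secs:
        print(f'step {step}: saving checkpoint')
        save_time.reset()
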
Code Example #8
def eval_model(opt, printargs=None, print_parser=None):
    """Evaluates a model.

    :param opt: tells the evaluation function how to run
    :param bool print_parser: if provided, prints the options that are set within the
        model after loading the model
    :return: the final result of calling report()
    """
    if printargs is not None:
        print('[ Deprecated Warning: eval_model no longer uses `printargs` ]')
        print_parser = printargs
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: eval_model should be passed opt not Parser ]'
        )
        opt = opt.parse_args()

    random.seed(42)

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    # Show some example dialogs:
    cnt = 0
    while not world.epoch_done():
        cnt += opt.get('batchsize', 1)
        world.parley()
        if opt['display_examples']:
            print(world.display() + "\n~~")
        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(report['exs'], world.num_examples(),
                                        report)
            print(text)
        if opt['num_examples'] > 0 and cnt >= opt['num_examples']:
            break
    if world.epoch_done():
        print("EPOCH DONE")
    print('finished evaluating task {} using datatype {}'.format(
        opt['task'], opt.get('datatype', 'N/A')))
    report = world.report()
    print(report)

    print_announcements(opt)

    return report
Code Example #9
    def train(self):
        if is_distributed():
            warn_once(
                "Distributed training outputs average-per-worker metrics during "
                "training, and may be slightly distorted. Validation/test are "
                "unadulterated."
            )
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                world.parley()
                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = (
                    self._preempted_epochs +
                    num_workers() * self.world.get_total_epochs()
                )
                exs_per_epoch = self.world.num_examples()
                self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object((
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time()
                ))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print('[ num_epochs completed:{} time elapsed:{}s ]'.format(
                        self.max_num_epochs, train_time))
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (
                    validate_time > self.val_every_n_secs or
                    self._total_epochs - self.last_valid_epoch
                        >= self.val_every_n_epochs
                ):
                    stop_training = self.validate()
                    self.last_valid_epoch = self._total_epochs

                    # --------------- change by hengyicai -------------------------
                    # run evaluation on the test data as well
                    test_opt = copy.deepcopy(self.opt)
                    test_opt['display_examples'] = False
                    test_opt['report_freq'] = 0
                    if self.test_world is None:
                        # we need to load the world now
                        self.test_world = _maybe_load_eval_world(self.agent, test_opt, 'test')
                    run_eval(self.test_world, test_opt, 'test', -1, write_log=True)
                    # --------------- change by hengyicai -------------------------
                    if stop_training:
                        break
                if (
                    self.save_time.time() > self.save_every_n_secs and
                    opt.get('model_file') and
                    is_primary_worker()
                ):
                    print("[ saving model checkpoint: {}.checkpoint".format(
                        opt['model_file']
                    ))
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
        v_report = run_eval(valid_world, opt, 'valid', max_exs, write_log=True)
        test_world = _maybe_load_eval_world(self.agent, opt, 'test')
        t_report = run_eval(test_world, opt, 'test', max_exs, write_log=True)
        if valid_world:
            valid_world.shutdown()
        if test_world:
            test_world.shutdown()

        # --------------- change by hengyicai -------------------------
        last_model = opt.get('model_file') + '.checkpoint'
        if os.path.isfile(last_model):
            print('[ Conducting evaluations on valid and test data using the last model. ]')
            last_model_opt = copy.deepcopy(opt)
            last_model_opt['model_file'] = last_model
            last_agent = create_agent(last_model_opt)
            valid_world = _maybe_load_eval_world(last_agent, last_model_opt, 'valid')
            max_exs = last_model_opt['validation_max_exs'] if last_model_opt.get('short_final_eval') else -1
            run_eval(valid_world, last_model_opt, 'valid', max_exs, write_log=True)
            test_world = _maybe_load_eval_world(last_agent, last_model_opt, 'test')
            run_eval(test_world, last_model_opt, 'test', max_exs, write_log=True)
            if valid_world:
                valid_world.shutdown()
            if test_world:
                test_world.shutdown()
        # --------------- change by hengyicai -------------------------

        print_announcements(opt)

        return v_report, t_report