Example #1
def multiprocess_train(
    rank, opt, port=61337, rank_offset=0, gpu=None, hostname='localhost'
):
    """
    Subprocess which initializes distributed training, and begins training.

    This should be launched n times for n GPUs; this is handled either in main
    or via srun.

    :param int rank: This process's rank - 1. (Starts at -1 ... n - 2). See comments.
    :param opt: command line options
    :param int port: A TCP port to use. This will need to be changed to run
        multiple distributed training setups on the same machine.
    :param int rank_offset: Offset added to the rank; used to adjust the rank
        differently for multiprocessing vs. distributed launches (see comments).
    :param int gpu: Which GPU to use. Defaults to using rank and local devices,
        but must be manually specified when using multiple hosts.
    :param str hostname: Hostname of the main server.
    """
    # Set per-host options
    opt = copy.deepcopy(opt)
    # we need to manually adjust the rank differently in multiprocessing
    # and distributed train
    rank = rank + rank_offset
    opt['rank'] = rank
    if gpu is None:
        # default assumption is local GPUs
        gpu = rank % torch.cuda.device_count()
    opt['gpu'] = gpu
    # make sure we don't just use whatever GPU was saved in the model file
    if 'override' not in opt:
        opt['override'] = {}
    opt['override']['gpu'] = gpu

    # Suppress output of workers except the main host.
    if opt.get('verbose') or rank != 0:
        print_prefix = '[rank:{:3d}]'.format(rank)
    else:
        print_prefix = None
    suppress_output = not opt.get('verbose') and rank != 0

    with distributed_utils.override_print(suppress_output, print_prefix):
        # perform distributed setup, ensuring all hosts are ready
        if opt['gpu'] != -1:
            torch.cuda.set_device(opt['gpu'])
        dist.init_process_group(
            backend="nccl",
            init_method="tcp://{}:{}".format(hostname, port),
            world_size=opt['distributed_world_size'],
            rank=rank,
        )
        logging.info("Distributed group initialized")

        # manual_seed can be a noop without this
        torch.cuda.init()
        # make sure all parameters will be in sync
        torch.manual_seed(42)
        # force a sync so that no one gets ahead, and all are seeded together
        distributed_utils.sync_object(None)

        # Run the actual training
        return single_train.TrainLoop(opt).train()
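
The docstring above notes that this subprocess must be launched once per GPU. Below is a minimal, hypothetical launcher sketch using torch.multiprocessing.spawn; it is not ParlAI's actual entry point, and it assumes opt is an already-parsed options dict, that opt['distributed_world_size'] equals the local GPU count, and that ranks simply run 0..n-1 with no offset.

import torch
import torch.multiprocessing as mp


def launch_multiprocess(opt, port=61337):
    # One child process per visible GPU; mp.spawn passes the process index as
    # the first positional argument, which multiprocess_train treats as the rank.
    nprocs = torch.cuda.device_count()
    mp.spawn(
        multiprocess_train,
        args=(opt, port),
        nprocs=nprocs,
        join=True,
    )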
Example #2
def run_eval(valid_worlds, opt, datatype, max_exs=-1, write_log=False):
    """
    Eval on validation/test data.

    :param valid_worlds:
        list of the pre-created validation worlds.
    :param opt:
        the options that specify the task, eval_task, etc.
    :param datatype:
        the datatype to use, such as "valid" or "test"
    :param int max_exs:
        limits the number of examples if max_exs > 0
    :param bool write_log:
        specifies to write metrics to file if the model_file is set
    """
    if valid_worlds is None:
        # This isn't the primary worker, so we can just skip evaluation
        return sync_object(None)

    print('[ running eval: ' + datatype + ' ]')
    timer = Timer()
    reports = []
    for v_world in valid_worlds:
        task_report = _run_single_eval(opt, v_world,
                                       max_exs / len(valid_worlds))
        reports.append(task_report)

    tasks = [world.getID() for world in valid_worlds]
    named_reports = dict(zip(tasks, reports))
    report = aggregate_named_reports(named_reports)

    metrics = f'{datatype}:{nice_report(report)}'
    print(f'[ eval completed in {timer.time():.2f}s ]')
    print(metrics)

    # write to file
    if write_log and opt.get('model_file'):
        # Write out metrics
        with open(opt['model_file'] + '.' + datatype, 'a+') as f:
            f.write(f'{metrics}\n')

    return sync_object(report)
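
A short usage sketch for run_eval follows; the load_eval_worlds helper and the agent/opt variables are assumed to be set up as in the training-loop examples later in this list.

# evaluate on the validation split, capping at 1000 examples and appending the
# metrics line to <model_file>.valid
valid_worlds = load_eval_worlds(agent, opt, 'valid')
report = run_eval(valid_worlds, opt, 'valid', max_exs=1000, write_log=True)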
Example #3
    def _get_time(self, world: World) -> Tuple[float, float, float, float]:
        """
        Return train, log, validate, and save timing.

        If relying on the time for validation/logging/max train time purposes,
        we sync and return primary worker's time.

        Otherwise, it's not super relevant what we do here.

        **SIDE EFFECT**: Update _total_epochs trained.

        :param world:
            current running world

        :return (train, log, valid, save):
            return time for each of train, log, validation, and save
        """
        if (
            self.max_train_time < float('inf')
            or self.log_every_n_secs < float('inf')
            or self.val_every_n_secs < float('inf')
            or self.val_every_n_epochs < float('inf')
            or self.max_num_epochs < float('inf')
        ):
            self._total_epochs = self._preempted_epochs + sum(
                all_gather_list(world.get_total_epochs())
            )
            train_time, log_time, validate_time, save_time = sync_object(
                (
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                    self.save_time.time(),
                )
            )
        else:
            train_time, log_time, validate_time, save_time = (
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
                self.save_time.time(),
            )
            self._total_epochs = self._preempted_epochs + (
                num_workers() * world.get_total_epochs()
            )

        return train_time, log_time, validate_time, save_time
Example #4
    def train_steps(self):
        """
        Core training loop.

        Yields a metrics dict with each log.
        """
        logging.info('training...')
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException as e:
                    logging.info(f"Stopping from {e}")
                    break

                self.parleys += 1
                self._train_steps = self.parleys // self.update_freq
                self._last_log_steps += 1 / self.update_freq

                # the following additionally updates self._total_epochs
                train_time, log_time, validate_time, save_time = self._get_time(world)
                # get the total training examples done, compute epochs
                exs_per_epoch = world.num_examples()
                self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    yield self.log()
                    logging.info(
                        f'num_epochs completed:{self.max_num_epochs} time elapsed:{train_time}s'
                    )
                    break
                if train_time > self.max_train_time:
                    logging.info(f'max_train_time elapsed:{train_time}s')
                    break
                if self._train_steps >= self.max_train_steps:
                    logging.info(
                        f'max_train_steps elapsed:{self._train_steps} '
                        f'time elapsed:{train_time}s'
                    )
                    break
                if (
                    log_time > self.log_every_n_secs
                    or self._last_log_steps >= self.log_every_n_steps
                ):
                    yield self.log()
                if (
                    validate_time > self.val_every_n_secs
                    or self._total_epochs - self.last_valid_epoch
                    >= self.val_every_n_epochs
                    or self._train_steps - self._last_valid_steps
                    >= self.val_every_n_steps
                ):
                    try:
                        # log before we validate
                        if self._last_log_steps:
                            yield self.log()
                        world.reset_metrics()
                        stop_training = self.validate()
                    except StopTrainException:
                        break
                    # reset the log time because we logged right before validating
                    self.log_time.reset()
                    self.last_valid_epoch = self._total_epochs
                    self._last_valid_steps = self._train_steps
                    if stop_training:
                        break
                    # make sure metrics are clean before we log
                    world.reset_metrics()
                if save_time > self.save_every_n_secs and opt.get('model_file'):
                    logging.info(
                        f"saving model checkpoint: {opt['model_file']}.checkpoint"
                    )
                    if opt['tensorboard_log'] and is_primary_worker():
                        self.tb_logger.flush()
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not sync_object(self.saved):
            # save agent
            self.save_model()

        # There's a rare edge case where we never saved the model and then try to
        # reload it. This sync_object ensures all workers wait for the primary
        # worker to finish flushing before loading from disk.
        sync_object(None)
        if opt.get('model_file'):
            # clean up all our memory, just to make sure we don't OOM on GPU when
            # reloading the world
            del world
            del self.world
            del self.agent
            del self.valid_worlds
            # reload best validation model
            self.agent = create_agent(opt)
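
Because train_steps() yields a metrics dict on every logging event, a caller can drive reporting externally. A hypothetical consumption sketch (the TrainLoop construction mirrors example #1; handle_metrics is an assumed callback):

loop = single_train.TrainLoop(opt)
for metrics in loop.train_steps():
    # each yielded dict corresponds to one log during training
    handle_metrics(metrics)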
Example #5
    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        logging.info('training...')
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException as e:
                    logging.info(f"Stopping from {e}")
                    break

                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = self._preempted_epochs + sum(
                    all_gather_list(world.get_total_epochs()))
                exs_per_epoch = world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))
                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object((
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                ))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    logging.info(
                        f'num_epochs completed:{self.max_num_epochs} time elapsed:{train_time}s'
                    )
                    break
                if train_time > self.max_train_time:
                    logging.info(f'max_train_time elapsed:{train_time}s')
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    try:
                        # log before we validate
                        self.log()
                        world.reset_metrics()
                        stop_training = self.validate()
                    except StopTrainException:
                        break
                    # reset the log time because we logged right before validating
                    self.log_time.reset()
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                    # make sure metrics are clean before we log
                    world.reset_metrics()
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    logging.info(
                        f"saving model checkpoint: {opt['model_file']}.checkpoint"
                    )
                    if opt['tensorboard_log'] and is_primary_worker():
                        self.tb_logger.flush()
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()

        # There's a rare edge case where we never saved the model and then try to
        # reload it. This sync_object ensures all workers wait for the primary
        # worker to finish flushing before loading from disk.
        sync_object(None)
        if opt.get('model_file'):
            # clean up all our memory, just to make sure we don't OOM on GPU when
            # reloading the world
            del world
            del self.world
            del self.agent
            del self.valid_worlds
            # reload best validation model
            self.agent = create_agent(opt)

        # perform final validation/testing
        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get(
            'short_final_eval') else -1
        v_report = self._run_eval(valid_worlds,
                                  opt,
                                  'valid',
                                  max_exs,
                                  write_log=True)
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        t_report = self._run_eval(test_worlds,
                                  opt,
                                  'test',
                                  max_exs,
                                  write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        return v_report, t_report
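
As in example #1, the entry point is simply constructing the loop and calling train(), which returns the final validation and test reports.

# minimal usage sketch, reusing the TrainLoop name from example #1
v_report, t_report = single_train.TrainLoop(opt).train()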
Example #6
    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        opt = self.opt
        world = self.world
        count = 0
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException:
                    if is_distributed():
                        raise RuntimeError(
                            "StopTrainException not supported for "
                            "distributed mode")
                    break

                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = (
                    self._preempted_epochs +
                    num_workers() * self.world.get_total_epochs())
                exs_per_epoch = self.world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object((
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                ))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, train_time))
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    try:
                        stop_training = self.validate()
                    except StopTrainException:
                        if is_distributed():
                            raise RuntimeError(
                                "StopTrainException not "
                                "supported for distributed mode")
                        break
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    print("[ saving model checkpoint: {}.checkpoint".format(
                        opt['model_file']))
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get(
            'short_final_eval') else -1
        v_report = run_eval(valid_worlds,
                            opt,
                            'valid',
                            max_exs,
                            write_log=True)
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        return v_report, t_report
Example #7
File: train_model.py  Project: omry/ParlAI
    def validate(self):
        """
        Perform a validation run, checking whether we should stop training.

        :return: boolean indicating whether training should stop
        :rtype: bool
        """
        opt = self.opt

        if self.valid_worlds is None:
            # we need to load the world now
            self.valid_worlds = _maybe_load_eval_worlds(
                self.agent, opt, 'valid')

        # run evaluation on valid set
        valid_report = sync_object(
            run_eval(self.valid_worlds, opt, 'valid',
                     opt['validation_max_exs']))
        v = valid_report.copy()
        v['train_time'] = self.train_time.time()
        self.valid_reports.append(v)
        # logging
        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('valid', self.parleys, valid_report)
        # saving
        if (opt.get('model_file') and opt.get('save_after_valid')
                and is_primary_worker()):
            print("[ saving model checkpoint: " + opt['model_file'] +
                  ".checkpoint ]")
            self.save_model('.checkpoint')

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        if 'tasks' in valid_report and '/' in opt['validation_metric']:
            # if you are multitasking and want your validation metric to be
            # a metric specific to a subtask, specify your validation metric
            # as -vmt subtask/metric
            subtask = opt['validation_metric'].split('/')[0]
            validation_metric = opt['validation_metric'].split('/')[1]
            new_valid = valid_report['tasks'][subtask][validation_metric]
        else:
            new_valid = valid_report[opt['validation_metric']]

        # check if this is the best validation so far
        if (self.best_valid is None or self.valid_optim * new_valid >
                self.valid_optim * self.best_valid):
            print('[ new best {}: {}{} ]'.format(
                opt['validation_metric'],
                new_valid,
                ' (previous best was {})'.format(self.best_valid)
                if self.best_valid is not None else '',
            ))
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file') and is_primary_worker():
                print("[ saving best valid model: " + opt['model_file'] + " ]")
                self.save_model()
                print("[ saving best valid metric: " + opt['model_file'] +
                      ".best_valid ]")
                _save_best_valid(opt['model_file'], self.best_valid)
                self.saved = True
            if (opt['validation_metric'] == 'accuracy'
                    and self.best_valid >= opt['validation_cutoff']):
                print('[ task solved! stopping. ]')
                return True
        else:
            self.impatience += 1
            print('[ did not beat best {}: {} impatience: {} ]'.format(
                opt['validation_metric'], round(self.best_valid, 4),
                self.impatience))
        self.validate_time.reset()

        # check if we are out of patience
        if (opt['validation_patience'] > 0
                and self.impatience >= opt['validation_patience']):
            print('[ ran out of patience! stopping training. ]')
            return True
        return False
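
The subtask branch above expects the validation metric to be written as subtask/metric. A hypothetical options sketch (the task and metric names are invented for illustration):

# validate on the f1 score of the 'squad' subtask when multitasking,
# e.g. via the -vmt flag mentioned in the comment above
opt['validation_metric'] = 'squad/f1'
subtask, metric = opt['validation_metric'].split('/')
# validate() then reads valid_report['tasks']['squad']['f1']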
Example #8
    def _save_outputs(self, opt, world, logger, episode_metrics):
        if is_distributed():  # flatten everything intelligently if need be
            world_report = aggregate_unnamed_reports(
                all_gather_list(world.report()))
            episode_metrics_unflattened = all_gather_list(episode_metrics)
            flattened = []
            for rank_elem in episode_metrics_unflattened:
                for elem in rank_elem:
                    flattened.append(elem)
            episode_metrics = flattened
        else:
            world_report = world.report()
        logging.report("Final report:\n" + nice_report(world_report))

        report = dict_report(world_report)

        def get_episode_report(goal, episode_metric):
            metrics_dict = dict_report(episode_metric.report())
            metrics_dict["goal"] = goal
            return metrics_dict

        report["tod_metrics"] = [
            get_episode_report(g, e) for g, e in episode_metrics
        ]

        if "report_filename" in opt and opt["report_filename"] is not None:
            if len(world_report) == 0:
                logging.warning("Report is empty; not saving report")

            report_fname = f"{opt['report_filename']}.json"
            # Save report
            if not is_distributed() or is_primary_worker():
                with PathManager.open(report_fname, "w") as f:
                    logging.info(f"Saving model report to {report_fname}")
                    json.dump({"opt": opt, "report": report}, f, indent=4)
                    f.write("\n")  # for jq

        if "world_logs" in opt and opt["world_logs"] is not None:
            if is_distributed():  # Save separately, then aggregate together
                rank = get_rank()
                log_outfile_part = (
                    f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl")
                logger.write(log_outfile_part,
                             world,
                             file_format=opt["save_format"])
                sync_object(None)
                if is_primary_worker():
                    log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl"
                    log_outfile_metadata = (
                        f"{opt['world_logs']}_{opt['save_format']}.metadata")
                    with open(log_outfile, "w+") as outfile:
                        for rank in range(num_workers()):
                            log_outfile_part = (
                                f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl"
                            )
                            with open(log_outfile_part) as infile:
                                for line in infile:
                                    json_blob = json.loads(line.strip())
                                    # skip when we don't have generation
                                    if len(json_blob["dialog"]) < 2:
                                        continue
                                    json_blob["metadata_path"] = log_outfile_metadata
                                    outfile.write(json.dumps(json_blob))
                                    outfile.write("\n")
                            log_output_part_metadata = f"{opt['world_logs']}_{opt['save_format']}_{rank}.metadata"
                            if rank == 0:
                                copyfile(log_output_part_metadata,
                                         log_outfile_metadata)
                            os.remove(log_outfile_part)
                            os.remove(log_output_part_metadata)
            else:
                log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl"
                logger.write(log_outfile,
                             world,
                             file_format=opt["save_format"])

        return report
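
For the distributed world_logs branch above, the shard naming mirrors the f-strings in the code; the values below are hypothetical.

world_logs, save_format, rank = 'run1', 'conversations', 0      # assumed option values
shard = f"{world_logs}_{save_format}_{rank}.jsonl"               # per-rank shard
shard_meta = f"{world_logs}_{save_format}_{rank}.metadata"       # per-rank metadata
merged = f"{world_logs}_{save_format}.jsonl"                     # primary worker's merge target
merged_meta = f"{world_logs}_{save_format}.metadata"             # copied from rank 0's metadata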
Example #9
    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        if is_distributed():
            warn_once(
                "Distributed training outputs average-per-worker metrics during "
                "training, and may be slightly distorted. Validation/test are "
                "unadulterated.")
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                world.parley()
                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = (
                    self._preempted_epochs +
                    num_workers() * self.world.get_total_epochs())
                exs_per_epoch = self.world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object((
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                ))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, train_time))
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    stop_training = self.validate()
                    self.last_valid_epoch = self._total_epochs

                    # --------------- change by hengyicai -------------------------
                    if opt.get('run_test_after_validation', False):
                        # run evaluation on the test data as well
                        test_opt = copy.deepcopy(self.opt)
                        test_opt['display_examples'] = False
                        test_opt['report_freq'] = 0
                        if self.test_worlds is None:
                            # we need to load the world now
                            self.test_worlds = _maybe_load_eval_worlds(
                                self.agent, test_opt, 'test')
                        run_eval(self.test_worlds,
                                 test_opt,
                                 'test',
                                 -1,
                                 write_log=True)
                    # --------------- change by hengyicai -------------------------
                    if stop_training:
                        break
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    print("[ saving model checkpoint: {}.checkpoint".format(
                        opt['model_file']))
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_worlds = _maybe_load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get(
            'short_final_eval') else -1
        v_report = run_eval(valid_worlds,
                            opt,
                            'valid',
                            max_exs,
                            write_log=True)
        test_worlds = _maybe_load_eval_worlds(self.agent, opt, 'test')
        t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        # --------------- change by hengyicai -------------------------
        last_model = opt.get('model_file') + '.checkpoint'
        if os.path.isfile(last_model):
            print(
                '[ Conducting evaluations on valid and test data using the last model. ]'
            )
            last_model_opt = copy.deepcopy(opt)
            last_model_opt['model_file'] = last_model
            last_agent = create_agent(last_model_opt)
            valid_worlds = _maybe_load_eval_worlds(last_agent, last_model_opt,
                                                   'valid')
            max_exs = last_model_opt[
                'validation_max_exs'] if last_model_opt.get(
                    'short_final_eval') else -1
            run_eval(valid_worlds,
                     last_model_opt,
                     'valid',
                     max_exs,
                     write_log=True)
            test_worlds = _maybe_load_eval_worlds(last_agent, last_model_opt,
                                                  'test')
            run_eval(test_worlds,
                     last_model_opt,
                     'test',
                     max_exs,
                     write_log=True)
            if valid_worlds:
                for valid_world in valid_worlds:
                    valid_world.shutdown()
            if test_worlds:
                for test_world in test_worlds:
                    test_world.shutdown()
        # --------------- change by hengyicai -------------------------
        print_announcements(opt)

        return v_report, t_report
Example #10
    def train(self):
        """
        Perform ftml training run.
        :return: tuple of reports (validation_report, test_report)
        """
        logging.info('training...')
        opt = self.opt
        world = self.world
        teacher = world.agents[0]
        student = world.agents[1]
        more_data_in_domain = True

        eval_data = {x: [] for x in teacher.domains}

        with world:
            shuffled_domains = [
                x for x in teacher.domains if x not in ['police', 'hospital']
            ]
            random.shuffle(shuffled_domains)

            for d, domain in enumerate(shuffled_domains):

                N = len(teacher.domain_convo_inds[domain])
                teacher.add_domain(domain)
                teacher.add_all_domain_data(domain)

            self.best_valid = None
            stop_training = False

            if opt['no_multi_task']:
                # only fine-tune to each domain, so don't enter the multi-task while loop
                stop_training = True
                self._total_epochs = 0
                self._total_exs = 0

            while not stop_training:  # This is for multi-tasking a global model
                for _ in range(
                        int(teacher.num_episodes() /
                            opt['num_episode_batch'])):
                    world.batch_parley()
                    self.parleys += 1

                    # TODO: I think this should be set correctly. Helps tracking, needed for termination?
                    self._total_epochs = 0
                    self._total_exs = 0

                    train_time, log_time, validate_time = sync_object((
                        self.train_time.time(),
                        self.log_time.time(),
                        self.validate_time.time(),
                    ))

                    if log_time > self.log_every_n_secs:
                        self.log()
                self.write_log('finished %s parleys' % self.parleys)
                self.write_log(
                    'Learning rate before valid: %s '
                    % world.agents[1].optimizer.state_dict()['param_groups'][0]['lr']
                )
                # validation_decreasing = # todo
                # todo: add validation here to tell when to stop updating the meta model.
                # This is harder to do, as the validation is of the meta model....
                for w in self.valid_worlds:
                    for dd in teacher.added_domains():
                        # Should also reset the teacher.index.value --> -1, but keep the domain fixed.
                        w.reset()
                        w.agents[0].add_domain(dd)
                        w.agents[0].add_all_domain_data(dd)
                    # Fix validation teacher to domains the training teacher has seen.
                    w.agents[0].fix_teacher_domain(teacher.added_domains())
                    # reset index because we'll stream through the training data
                    w.agents[0].index.value = -1
                    w.agents[0].entry_idx = 0

                if not opt['validation_metric'].startswith('bleu'):
                    student.skip_generation = True
                stop_training = self.validate()
                logging.info('Multi-task model validation value: %s ' %
                             self.best_valid)

            # After the multi-task model is trained, fine tune model for each domain.
            M = copy.deepcopy(world.agents[1].model.state_dict())
            optim_state = copy.deepcopy(world.agents[1].optimizer.state_dict())

            for dd in teacher.domains:

                # Restrict valid_world teachers to chosen domain for fine-tuning
                for w in self.valid_worlds:
                    # Should also reset the teacher.index.value --> -1, but keep the domain fixed.
                    w.reset()
                    w.agents[0].add_domain(dd)
                    w.agents[0].add_all_domain_data(dd)
                    w.agents[0].fix_teacher_domain([dd])
                    # reset index because we'll stream through the data
                    w.agents[0].index.value = -1
                    w.agents[0].entry_idx = 0

                # Restrict test worlds to chosen domain for fine-tuning
                for w in self.test_worlds:
                    # Should also reset the teacher.index.value --> -1, but keep the domain fixed.
                    w.reset()
                    w.agents[0].add_domain(dd)
                    w.agents[0].add_all_domain_data(dd)
                    w.agents[0].fix_teacher_domain([dd])
                    # reset index because we'll stream through the data
                    w.agents[0].index.value = -1
                    w.agents[0].entry_idx = 0

                    # note the appropriate state_dict should be loaded, as the agent should
                    # be shared by reference in the training and the testing worlds.

                self.write_log("STARTING Fine-tuning OF STUDENT HERE on %s " %
                               dd)
                if self.test_worlds[0].agents[
                        0].num_episodes_in_restricted_domain() > 0:

                    # make sure the meta parameters are loaded before evaluating another training domain
                    world.agents[1].model.load_state_dict(M)
                    world.agents[1].optimizer.load_state_dict(optim_state)

                    # Restrict training world to fine-tuning domain
                    teacher.fix_teacher_domain([dd])
                    teacher.index.value = -1  # reset index because we'll stream through the training data.
                    teacher.entry_idx = 0

                    logging.info('Fine-tuning to: %s' % dd)
                    self.write_log('fine-tuning epoch size: %s' %
                                   teacher.num_episodes_in_restricted_domain())

                    self.best_valid = None
                    stop_training = False
                    self.tune_parley_epochs = 0

                    # Fine-tune model to single domain
                    while not stop_training:
                        # fine-tune for one epoch over training
                        self.write_log(
                            'Learning rate : %s '
                            % world.agents[1].optimizer.state_dict()['param_groups'][0]['lr']
                        )

                        # loop for one epoch over the domain training data
                        # (epoch episodes, as each full episode is processed)
                        for n in range(
                            int(teacher.num_episodes_in_restricted_domain()
                                / opt['num_episode_batch'])
                        ):
                            # Note the updating is fixed to the domain training data only.
                            world.batch_parley()

                        # fine-tune until validation on the domain stops decreasing
                        if not opt['validation_metric'].startswith('bleu'):
                            student.skip_generation = True
                        stop_training = self.validate()
                        self.tune_parley_epochs += 1

                        logging.info('Best valid: %s' % self.best_valid)
                        self.write_log('Best fine-tune valid: %s' %
                                       self.best_valid)
                        self.write_log('Finished %s tune_parley epochs' %
                                       self.tune_parley_epochs)

                    # Evaluate on domain test set.
                    student.skip_generation = False
                    max_exs = -1
                    t_report = self._run_eval(self.test_worlds,
                                              opt,
                                              'test',
                                              max_exs,
                                              write_log=True)
                    logging.info('on domain %s: test report: ' % dd)
                    logging.info(t_report)
                    eval_data[dd] = {
                        'domain': dd,
                        'test_report': t_report,
                        'num_parleys': self.parleys,
                        'tune_epochs': self.tune_parley_epochs,
                        'mt_epoch_size': teacher.num_episodes(),
                        'domain_epoch_size': teacher.num_episodes_in_restricted_domain(),
                    }

        import datetime
        stamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M')
        if opt['no_multi_task']:
            locationname = '/home/oademasi/transfer-learning-conv-ai/ParlAI/parlai_internal/eval_data_ft_%s.pkl' % stamp
        else:
            locationname = '/home/oademasi/transfer-learning-conv-ai/ParlAI/parlai_internal/eval_data_mtft_%s.pkl' % stamp
        pickle.dump(eval_data, open(locationname, 'wb'))
        print('wrote to: ', locationname)
        v_report = None
        t_report = None
        return v_report, t_report
    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException:
                    if is_distributed():
                        raise RuntimeError(
                            "StopTrainException not supported for "
                            "distributed mode")
                    break

                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = self._preempted_epochs + sum(
                    all_gather_list(world.get_total_epochs()))
                exs_per_epoch = world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))
                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object((
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                ))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, train_time))
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    try:
                        # log before we validate
                        self.log()
                        world.reset_metrics()
                        stop_training = self.validate()
                    except StopTrainException:
                        if is_distributed():
                            raise RuntimeError(
                                "StopTrainException not supported for distributed mode"
                            )
                        break
                    # reset the log time because we logged right before validating
                    self.log_time.reset()
                    self.last_valid_epoch = self._total_epochs

                    # --------------- change by hengyicai -------------------------
                    if opt.get('run_test_after_validation', False):
                        # run evaluation on the test data as well
                        test_opt = copy.deepcopy(self.opt)
                        test_opt['display_examples'] = False
                        test_opt['report_freq'] = 0
                        if self.test_worlds is None:
                            # we need to load the world now
                            self.test_worlds = load_eval_worlds(
                                self.agent, test_opt, 'test')
                        run_eval(self.test_worlds,
                                 test_opt,
                                 'test',
                                 -1,
                                 write_log=True)
                    # --------------- change by hengyicai -------------------------
                    if stop_training:
                        break
                    # make sure metrics are clean before we log
                    world.reset_metrics()
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    print("[ saving model checkpoint: {}.checkpoint".format(
                        opt['model_file']))
                    if opt['tensorboard_log'] and is_primary_worker():
                        self.tb_logger.flush()
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get(
            'short_final_eval') else -1
        v_report = run_eval(valid_worlds,
                            opt,
                            'valid',
                            max_exs,
                            write_log=True)
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        t_report = run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        # --------------- change by hengyicai -------------------------
        last_model = opt.get('model_file') + '.checkpoint'
        if os.path.isfile(last_model):
            print(
                '[ Conducting evaluations on valid and test data using the last model. ]'
            )
            last_model_opt = copy.deepcopy(opt)
            last_model_opt['model_file'] = last_model
            last_agent = create_agent(last_model_opt)
            valid_worlds = load_eval_worlds(last_agent, last_model_opt,
                                            'valid')
            max_exs = last_model_opt[
                'validation_max_exs'] if last_model_opt.get(
                    'short_final_eval') else -1
            run_eval(valid_worlds,
                     last_model_opt,
                     'valid',
                     max_exs,
                     write_log=True)
            test_worlds = load_eval_worlds(last_agent, last_model_opt, 'test')
            run_eval(test_worlds,
                     last_model_opt,
                     'test',
                     max_exs,
                     write_log=True)
            if valid_worlds:
                for valid_world in valid_worlds:
                    valid_world.shutdown()
            if test_worlds:
                for test_world in test_worlds:
                    test_world.shutdown()
        # --------------- change by hengyicai -------------------------
        print_announcements(opt)

        return v_report, t_report