Code example #1
def multiprocess_train(
    rank, opt, port=61337, rank_offset=0, gpu=None, hostname='localhost'
):
    """
    Subprocess which initializes distributed training, and begins training.

    This should be launched n times for n GPUs; this is handled either in main
    or via srun.

    :param int rank: This process's rank - 1. (Starts at -1 ... n - 2). See comments.
    :param opt: command line options
    :param int port: A TCP port to use. This will need to be changed to run
        multiple distributed training setups on the same machine.
    :param int gpu: Which GPU to use. Defaults to using rank and local devices,
        but must be manually specified when using multiple hosts.
    :param str hostname: Hostname of the main server.
    """
    # Set per-host options
    opt = copy.deepcopy(opt)
    # we need to manually adjust the rank differently in multiprocessing
    # and distributed train
    rank = rank + rank_offset
    opt['rank'] = rank
    if gpu is None:
        # default assumption is local GPUs
        gpu = rank % torch.cuda.device_count()
    opt['gpu'] = gpu
    # make sure we don't just use whatever GPU was saved in the model file
    if 'override' not in opt:
        opt['override'] = {}
    opt['override']['gpu'] = gpu

    # Suppress output of workers except the main host.
    if opt.get('verbose') or rank != 0:
        print_prefix = '[rank:{:3d}]'.format(rank)
    else:
        print_prefix = None
    suppress_output = not opt.get('verbose') and rank != 0

    with distributed_utils.override_print(suppress_output, print_prefix):
        # perform distributed setup, ensuring all hosts are ready
        torch.cuda.set_device(opt['gpu'])
        dist.init_process_group(
            backend="nccl",
            init_method="tcp://{}:{}".format(hostname, port),
            world_size=opt['distributed_world_size'],
            rank=rank,
        )
        print("Distributed group initialized")

        # manual_seed can be a noop without this
        torch.cuda.init()
        # make sure all parameters will be in sync
        torch.manual_seed(42)
        # force a sync so that no one gets ahead, and all are seeded together
        distributed_utils.sync_object(None)

        # Run the actual training
        return single_train.TrainLoop(opt).train()
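The docstring above notes that multiprocess_train should be launched once per GPU, either from a main launcher or via srun. Below is a minimal local-machine sketch using torch.multiprocessing.spawn, assuming opt can simply carry distributed_world_size; launch_local itself is a hypothetical helper, not part of the project.

import torch
import torch.multiprocessing as mp

def launch_local(opt, port=61337):
    # Hypothetical launcher sketch: one subprocess per local GPU.
    # spawn() calls multiprocess_train(i, opt, port) with i = 0..nprocs-1,
    # so rank_offset keeps its default of 0 and rank 0 acts as the primary.
    world_size = torch.cuda.device_count()
    opt['distributed_world_size'] = world_size
    mp.spawn(multiprocess_train, args=(opt, port), nprocs=world_size, join=True)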
Code example #2
File: train_model.py  Project: ying-A/RED
    def train(self):
        if is_distributed():
            warn_once(
                "Distributed training outputs average-per-worker metrics during "
                "training, and may be slightly distorted. Validation/test are "
                "unadulterated.")
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                world.parley()
                self.parleys += 1
                # print(world.display())

                # get the total training examples done, compute epochs
                self._total_epochs = (
                    self._preempted_epochs +
                    num_workers() * self.world.get_total_epochs())
                exs_per_epoch = self.world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object(
                    (self.train_time.time(), self.log_time.time(),
                     self.validate_time.time()))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, train_time))
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    stop_training = self.validate()
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    print("[ saving model checkpoint: {}.checkpoint".format(
                        opt['model_file']))
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')
        v_report = run_eval(valid_world, opt, 'valid', write_log=True)
        test_world = _maybe_load_eval_world(self.agent, opt, 'test')
        t_report = run_eval(test_world, opt, 'test', write_log=True)
        if valid_world:
            valid_world.shutdown()
        if test_world:
            test_world.shutdown()

        return v_report, t_report
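The timer values at the top of the loop come from sync_object, which makes every worker use the primary worker's readings. A rough sketch of what such a helper could look like with torch.distributed; sync_object_sketch is illustrative only, not the project's implementation.

import torch.distributed as dist

def sync_object_sketch(data, src=0):
    # Illustrative only: broadcast a picklable object from the primary
    # worker so all ranks see the same value (e.g. the timer readings above).
    if not (dist.is_available() and dist.is_initialized()):
        return data
    buf = [data]
    dist.broadcast_object_list(buf, src=src)
    return buf[0]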
Code example #3
File: train_model.py  Project: ying-A/RED
    def validate(self):
        opt = self.opt

        if self.valid_world is None:
            # we need to load the world now
            self.valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')

        # run evaluation on valid set
        valid_report = sync_object(
            run_eval(self.valid_world, opt, 'valid', opt['validation_max_exs'],
                     True))

        # logging
        if opt['tensorboard_log'] is True and is_primary_worker():
            self.writer.add_metrics('valid', int(self.train_time.time()),
                                    valid_report)
        # saving
        if (opt.get('model_file') and opt.get('save_after_valid')
                and is_primary_worker()):
            print("[ saving model checkpoint: " + opt['model_file'] +
                  ".checkpoint ]")
            self.save_model('.checkpoint')

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # check which metric to look at
        if '/' in opt['validation_metric']:
            # if you are multitasking and want your validation metric to be
            # a metric specific to a subtask, specify your validation metric
            # as -vmt subtask/metric
            subtask = opt['validation_metric'].split('/')[0]
            validation_metric = opt['validation_metric'].split('/')[1]
            new_valid = valid_report['tasks'][subtask][validation_metric]
        else:
            new_valid = valid_report[opt['validation_metric']]

        # check if this is the best validation so far
        if (self.best_valid is None or self.valid_optim * new_valid >
                self.valid_optim * self.best_valid):
            print('[ new best {}: {}{} ]'.format(
                opt['validation_metric'], new_valid,
                ' (previous best was {})'.format(self.best_valid)
                if self.best_valid is not None else ''))
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file') and is_primary_worker():
                print("[ saving best valid model: " + opt['model_file'] + " ]")
                self.save_model()
                print("[ saving best valid metric: " + opt['model_file'] +
                      ".best_valid ]")
                save_best_valid(opt['model_file'], self.best_valid)
                self.saved = True
            if (opt['validation_metric'] == 'accuracy'
                    and self.best_valid >= opt['validation_cutoff']):
                print('[ task solved! stopping. ]')
                return True
        else:
            self.impatience += 1
            print('[ did not beat best {}: {} impatience: {} ]'.format(
                opt['validation_metric'], round(self.best_valid, 4),
                self.impatience))
        self.validate_time.reset()

        # check if we are out of patience
        if (opt['validation_patience'] > 0
                and self.impatience >= opt['validation_patience']):
            print('[ ran out of patience! stopping training. ]')
            return True
        return False
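The best-so-far comparison above multiplies both values by self.valid_optim so a single test covers metrics that should be maximized and metrics that should be minimized. A hedged illustration of that convention, assuming valid_optim is +1 or -1 (the excerpt itself does not show how it is set):

def is_improvement(new_valid, best_valid, valid_optim):
    # Assumed convention: valid_optim = +1 to maximize (e.g. accuracy),
    # -1 to minimize (e.g. ppl); one comparison handles both directions.
    return best_valid is None or valid_optim * new_valid > valid_optim * best_valid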
Code example #4
    def train(self):
        if is_distributed():
            warn_once(
                "Distributed training outputs average-per-worker metrics during "
                "training, and may be slightly distorted. Validation/test are "
                "unadulterated."
            )
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                world.parley()
                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = (
                    self._preempted_epochs +
                    num_workers() * self.world.get_total_epochs()
                )
                exs_per_epoch = self.world.num_examples()
                self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object((
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time()
                ))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print('[ num_epochs completed:{} time elapsed:{}s ]'.format(
                        self.max_num_epochs, train_time))
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (
                    validate_time > self.val_every_n_secs or
                    self._total_epochs - self.last_valid_epoch
                        >= self.val_every_n_epochs
                ):
                    stop_training = self.validate()
                    self.last_valid_epoch = self._total_epochs

                    # --------------- change by hengyicai -------------------------
                    # run evaluation on the test data as well
                    test_opt = copy.deepcopy(self.opt)
                    test_opt['display_examples'] = False
                    test_opt['report_freq'] = 0
                    if self.test_world is None:
                        # we need to load the world now
                        self.test_world = _maybe_load_eval_world(self.agent, test_opt, 'test')
                    run_eval(self.test_world, test_opt, 'test', -1, write_log=True)
                    # --------------- change by hengyicai -------------------------
                    if stop_training:
                        break
                if (
                    self.save_time.time() > self.save_every_n_secs and
                    opt.get('model_file') and
                    is_primary_worker()
                ):
                    print("[ saving model checkpoint: {}.checkpoint".format(
                        opt['model_file']
                    ))
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
        v_report = run_eval(valid_world, opt, 'valid', max_exs, write_log=True)
        test_world = _maybe_load_eval_world(self.agent, opt, 'test')
        t_report = run_eval(test_world, opt, 'test', max_exs, write_log=True)
        if valid_world:
            valid_world.shutdown()
        if test_world:
            test_world.shutdown()

        # --------------- change by hengyicai -------------------------
        last_model = opt.get('model_file') + '.checkpoint'
        if os.path.isfile(last_model):
            print('[ Conducting evaluations on valid and test data using the last model. ]')
            last_model_opt = copy.deepcopy(opt)
            last_model_opt['model_file'] = last_model
            last_agent = create_agent(last_model_opt)
            valid_world = _maybe_load_eval_world(last_agent, last_model_opt, 'valid')
            max_exs = last_model_opt['validation_max_exs'] if last_model_opt.get('short_final_eval') else -1
            run_eval(valid_world, last_model_opt, 'valid', max_exs, write_log=True)
            test_world = _maybe_load_eval_world(last_agent, last_model_opt, 'test')
            run_eval(test_world, last_model_opt, 'test', max_exs, write_log=True)
            if valid_world:
                valid_world.shutdown()
            if test_world:
                test_world.shutdown()
        # --------------- change by hengyicai -------------------------

        print_announcements(opt)

        return v_report, t_report
Code example #5
    def validate(self):
        opt = self.opt

        if self.valid_world is None:
            # we need to load the world now
            self.valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')

        # run evaluation on valid set
        valid_report = sync_object(run_eval(
            self.valid_world, opt, 'valid', opt['validation_max_exs'],
        ))
        v = valid_report.copy()
        v['train_time'] = self.train_time.time()
        self.valid_reports.append(v)
        # logging
        if opt['tensorboard_log'] is True and is_primary_worker():
            self.writer.add_metrics('valid', int(self.train_time.time()), valid_report)
        # saving
        if (
            opt.get('model_file') and
            opt.get('save_after_valid') and
            is_primary_worker()
        ):
            print("[ saving model checkpoint: " +
                  opt['model_file'] + ".checkpoint ]")
            self.save_model('.checkpoint')

        # send valid metrics to agent if the agent wants them
        if hasattr(self.agent, 'receive_metrics'):
            self.agent.receive_metrics(valid_report)

        # --------------- change by hengyicai -------------------------
        teacher_agent = self.return_teacher_agent()
        if teacher_agent:
            teacher_agent.receive_metrics(valid_report)
        # --------------- change by hengyicai -------------------------

        # check which metric to look at
        if '/' in opt['validation_metric']:
            # if you are multitasking and want your validation metric to be
            # a metric specific to a subtask, specify your validation metric
            # as -vmt subtask/metric
            subtask = opt['validation_metric'].split('/')[0]
            validation_metric = opt['validation_metric'].split('/')[1]
            new_valid = valid_report['tasks'][subtask][validation_metric]
        else:
            new_valid = valid_report[opt['validation_metric']]

        # check if this is the best validation so far
        if (self.best_valid is None or
                self.valid_optim * new_valid > self.valid_optim * self.best_valid):
            print('[ new best {}: {}{} ]'.format(
                opt['validation_metric'], new_valid,
                ' (previous best was {})'.format(self.best_valid)
                if self.best_valid is not None else ''))
            self.best_valid = new_valid
            self.impatience = 0
            if opt.get('model_file') and is_primary_worker():
                print("[ saving best valid model: " + opt['model_file'] + " ]")
                self.save_model()
                print("[ saving best valid metric: " +
                      opt['model_file'] + ".best_valid ]")
                save_best_valid(opt['model_file'], self.best_valid)
                self.saved = True
            if (opt['validation_metric'] == 'accuracy' and
                    self.best_valid >= opt['validation_cutoff']):
                print('[ task solved! stopping. ]')
                return True
        else:
            self.impatience += 1
            print('[ did not beat best {}: {} impatience: {} ]'.format(
                opt['validation_metric'], round(self.best_valid, 4),
                self.impatience))
        # --------------- change by hengyicai -------------------------
        if self.opt.get('cutoff_metric_name', 'none') != 'none':
            cutoff_metric_name = self.opt['cutoff_metric_name']
            cutoff_metric_val = self.opt['cutoff_metric_val']
            if cutoff_metric_name in valid_report and cutoff_metric_val > 0:
                if valid_report[cutoff_metric_name] >= cutoff_metric_val:
                    print('[ {} >= {}, stopping. ]'.format(
                        cutoff_metric_name, cutoff_metric_val
                    ))
                    return True
            elif cutoff_metric_name not in valid_report:
                warn_once('[ {} is not in the validation report! '
                          'cannot do metric cutoff stopping! ]'.format(cutoff_metric_name))
            else:
                warn_once('[ you asked to do metric cutoff stopping, '
                          'but the cutoff_metric_val <= 0! ]')

        # --------------- change by hengyicai -------------------------
        self.validate_time.reset()

        # check if we are out of patience
        if (opt['validation_patience'] > 0 and
                self.impatience >= opt['validation_patience']):
            print('[ ran out of patience! stopping training. ]')
            return True
        return False
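The metric-cutoff stopping added at the end of this validate() reads two options from opt. A hypothetical configuration follows; the option names come from the code above, while the metric and threshold are made-up values.

# Hypothetical example values for the cutoff-metric early stopping above.
opt['cutoff_metric_name'] = 'f1'   # must appear in the validation report
opt['cutoff_metric_val'] = 0.30    # stop once validation f1 >= 0.30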