예제 #1
0
    def test_gpt2standin(self):
        """
        Compare the ``slow_bytelevel_bpe`` dictionary with the ``bytelevelbpe`` one.

        A dictionary file is first built on the ``babi`` task using the
        ``bytelevelbpe`` options. Both dictionary agents are then loaded from
        that same file and handed to ``self._run_test``. Finally the slow
        dictionary is rebuilt with ``bpe_add_prefix_space`` switched on and
        passed to ``self._run_prefix_space_test``.
        """
        with testing_utils.tempdir() as workdir:
            reference_opt = self._get_dict_opt('bytelevelbpe')
            standin_opt = self._get_dict_opt('slow_bytelevel_bpe')
            dict_path = os.path.join(workdir, "dict")

            # The dictionary file has to exist before either agent loads it.
            dict_parser = build_dict.setup_args()
            dict_parser.set_defaults(**reference_opt)
            dict_parser.set_defaults(task='babi')
            build_opt = dict_parser.parse_args([])
            build_opt['dict_file'] = dict_path
            build_dict.build_dict(build_opt)

            reference_opt['dict_file'] = dict_path
            reference_dict = DictionaryAgent(reference_opt)

            standin_opt['dict_file'] = dict_path
            standin_dict = DictionaryAgent(standin_opt)
            self._run_test(standin_dict, reference_dict)

            # Same dictionary file, now with a prefix space added.
            standin_opt['bpe_add_prefix_space'] = True
            self._run_prefix_space_test(DictionaryAgent(standin_opt))
예제 #2
0
 def verify_batch_lengths(defaults):
     """
     Assert that each collected batch holds examples of similar length.

     A training world is created from ``defaults`` (which receives temporary
     ``model_file`` / ``dict_file`` entries) and stepped for two epochs. For
     each retained batch, the spread between the longest and shortest
     ``batch_sort_field`` value must not exceed ``max_range``.

     ``self`` and ``max_range`` are not defined here; they are read from the
     enclosing scope.
     """
     with testing_utils.capture_output() as _, testing_utils.tempdir() as tmpdir:
         # Get processed act from agent
         arg_parser = train_setup_args()
         defaults['model_file'] = os.path.join(tmpdir, 'model')
         defaults['dict_file'] = os.path.join(tmpdir, 'model.dict')
         arg_parser.set_defaults(**defaults)
         parsed = arg_parser.parse_args()
         build_dict(parsed)
         world = create_task(parsed, create_agent(parsed))
         collected = []

         def parley_until(target):
             # Step the world until `target` batches have been gathered.
             while len(collected) < target:
                 world.parley()
                 collected.append(world.acts[0])

         # first epoch
         parley_until(900/50)
         world.world.get_agents()[0].reset_data()
         # second epoch
         parley_until(1800/50)
         world.shutdown()

     sort_field = defaults['batch_sort_field']
     # The slice leaves out the final two collected batches.
     per_batch_lengths = []
     for batch in collected[:-2]:
         per_batch_lengths.append(
             [ep_length(ex[sort_field]) for ex in batch if sort_field in ex])

     # verify batch lengths
     for sizes in per_batch_lengths:
         spread = max(sizes) - min(sizes)
         failure_note = ('PytorchDataTeacher batching does not give '
                         'batches with similar sized examples, when '
                         'sorting by `{}` field.'.format(sort_field))
         self.assertLessEqual(spread, max_range, failure_note)
예제 #3
0
        def verify_batch_lengths(defaults):
            """
            Check that batches sorted by ``batch_sort_field`` have a small spread.

            With stdout redirected into a buffer, a training world is created
            from ``defaults`` and stepped for two epochs. For each retained
            batch, the difference between the largest and smallest example
            length must not exceed ``max_range``. ``self`` and ``max_range``
            are read from the enclosing scope.
            """
            sink = io.StringIO()

            with redirect_stdout(sink):
                # Get processed act from agent
                arg_parser = train_setup_args()
                set_model_file(defaults)
                arg_parser.set_defaults(**defaults)
                parsed = arg_parser.parse_args()
                build_dict(parsed)
                world = create_task(parsed, create_agent(parsed))

                collected = []
                # One target per epoch: the cumulative number of batches to
                # have gathered by the end of that epoch.
                for epoch, target in enumerate((900 / 50, 1800 / 50)):
                    if epoch:
                        # the teacher is rewound before the second epoch
                        world.world.get_agents()[0].reset_data()
                    while len(collected) < target:
                        world.parley()
                        collected.append(world.acts[0])

            sort_field = defaults['batch_sort_field']
            # The slice leaves out the final two collected batches.
            per_batch_lengths = [
                [ep_length(ex[sort_field]) for ex in batch if sort_field in ex]
                for batch in collected[:-2]
            ]

            # verify batch lengths
            for sizes in per_batch_lengths:
                spread = max(sizes) - min(sizes)
                failure_note = ('PytorchDataTeacher batching does not give '
                                'batches with similar sized examples, when '
                                'sorting by `{}` field.'.format(sort_field))
                self.assertLessEqual(spread, max_range, failure_note)
예제 #4
0
 def __init__(self, opt):
     """
     Set up the agent, world, timers and schedule limits for training.

     ``opt`` is modified in place: ``init_model`` is pointed at the checkpoint
     when one is resumed, and ``dict_file`` may be filled in before the
     dictionary is built.

     :param opt: parsed options, accessed like a dict. A ``ParlaiParser`` is
         still accepted; it triggers a deprecation notice and is parsed here.
     """
     if isinstance(opt, ParlaiParser):
         print(
             '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
         )
         opt = opt.parse_args()

     def positive_or_inf(key):
         # A non-positive setting disables the limit, expressed as infinity.
         value = opt[key]
         return value if value > 0 else float('inf')

     # Possibly load from checkpoint
     if (opt['load_from_checkpoint'] and opt.get('model_file')
             and os.path.isfile(opt['model_file'] + '.checkpoint')):
         opt['init_model'] = opt['model_file'] + '.checkpoint'

     # Possibly build a dictionary (not all models do this).
     if opt['dict_build_first'] and 'dict_file' in opt:
         # If data built via pytorch data teacher, we need to load prebuilt dict
         if opt.get('pytorch_teacher_task'):
             opt['dict_file'] = get_pyt_dict_file(opt)
         elif opt['dict_file'] is None and opt.get('model_file'):
             opt['dict_file'] = opt['model_file'] + '.dict'
         print("[ building dictionary first... ]")
         build_dict(opt, skip_if_built=True)

     # Create model and assign it to the specified task
     self.agent = create_agent(opt)
     self.world = create_task(opt, self.agent)
     self.train_time = Timer()
     self.validate_time = Timer()
     self.log_time = Timer()
     self.save_time = Timer()
     print('[ training... ]')
     self.parleys = 0

     # Schedule limits; each one falls back to infinity when not positive.
     self.max_num_epochs = positive_or_inf('num_epochs')
     self.max_train_time = positive_or_inf('max_train_time')
     self.log_every_n_secs = positive_or_inf('log_every_n_secs')
     self.val_every_n_secs = positive_or_inf('validation_every_n_secs')
     self.save_every_n_secs = positive_or_inf('save_every_n_secs')
     self.val_every_n_epochs = positive_or_inf('validation_every_n_epochs')

     self.last_valid_epoch = 0
     # +1 when the validation metric is maximised, -1 otherwise.
     self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
     self.best_valid = None
     if opt.get('model_file') and os.path.isfile(
             opt['model_file'] + '.best_valid'):
         # Restore the best validation score saved next to the model; the
         # `with` block closes the file on exit.
         with open(opt['model_file'] + ".best_valid", 'r') as f:
             self.best_valid = float(f.readline())
     self.impatience = 0
     self.saved = False
     self.valid_world = None
     self.opt = opt
     if opt['tensorboard_log'] is True:
         self.writer = TensorboardLogger(opt)
예제 #5
0
 def get_teacher_act(defaults, teacher_processed=False, agent_to=None):
     """
     Produce one teacher act together with the agent created for it.

     :param defaults: parser defaults applied before parsing an empty
         argument list.
     :param teacher_processed: when true the teacher's act is returned
         untouched; otherwise it is first passed through ``agent.observe``.
     :param agent_to: accepted but not used by this function.
     :return: ``(act_or_observation, agent)``.
     """
     arg_parser = train_setup_args()
     arg_parser.set_defaults(**defaults)
     parsed = arg_parser.parse_args([])
     build_dict(parsed)

     task_teacher = create_task_agent_from_taskname(parsed)[0]
     model_agent = create_agent(parsed)
     teacher_act = task_teacher.act()

     if not teacher_processed:
         teacher_act = model_agent.observe(teacher_act)
     return teacher_act, model_agent
예제 #6
0
    def _distributed_train_model(self, opt):
        """
        Train through ``mp_train.launch_and_train`` inside a temporary directory.

        :param opt: option overrides. ``model_file`` and ``dict_file`` default
            to paths inside the temporary directory when absent. The mapping
            passed in is not modified.
        :return: ``(valid, test)`` as returned by ``mp_train.launch_and_train``.
        """
        # Work on a copy: the temporary paths filled in below are only
        # meaningful while the tempdir context is open, so they must not
        # leak into the caller's dict and be picked up by a later call.
        opt = dict(opt)
        with testing_utils.tempdir() as tmpdir:
            opt.setdefault('model_file', os.path.join(tmpdir, 'model'))
            opt.setdefault('dict_file', os.path.join(tmpdir, 'model.dict'))

            parser = mp_train.setup_args()
            popt = _forced_parse(parser, opt)

            # we need a prebuilt dictionary
            build_dict.build_dict(popt)

            valid, test = mp_train.launch_and_train(popt, 31337)

        return (valid, test)
예제 #7
0
    def _distributed_train_model(self, **overrides):
        """
        Train through ``mp_train.launch_and_train`` inside a temporary directory.

        :param overrides: options layered on top of ``self.base_config``.
            ``model_file`` and ``dict_file`` default to paths inside the
            temporary directory when neither source provides them.
        :return: ``(valid, test)`` as returned by ``mp_train.launch_and_train``.
        """
        opt = {**self.base_config, **overrides}
        with testing_utils.tempdir() as tmpdir:
            opt.setdefault('model_file', os.path.join(tmpdir, 'model'))
            opt.setdefault('dict_file', os.path.join(tmpdir, 'model.dict'))

            parser = mp_train.setup_args()
            popt = parser.parse_kwargs(**opt)

            # we need a prebuilt dictionary
            build_dict.build_dict(popt)

            valid, test = mp_train.launch_and_train(popt)

        return (valid, test)
예제 #8
0
def build_smoother_dict():
    """
    Build the dictionary for the smoother's ``BothOriginalTeacher`` task.

    A ``ParlaiParser`` is given fixed defaults (train datatype, lower-casing,
    ``split`` tokenizer, validation data included, no example cap), parsed
    with an empty argument list, and the resulting options are passed to
    ``build_dict``. The dictionary path is relative to the working directory.
    """
    dict_defaults = {
        'task': 'tasks.convai2smoother.agents:BothOriginalTeacher',
        # 'dict_initpath' is intentionally left unset here
        # (previously '../../tmp/dict/init_receiver.dict').
        'datatype': 'train',
        'dict_lower': True,
        'dict_file': '../../tmp/dict/smoother_origin.dict',
        'dict_tokenizer': 'split',
        'dict_language': 'english',
        'dict_include_valid': True,
        'dict_maxexs': -1,
    }

    arg_parser = ParlaiParser(add_model_args=True, add_parlai_args=True)
    arg_parser.set_defaults(**dict_defaults)

    # An explicit empty list keeps the real command line out of the parse.
    build_dict(arg_parser.parse_args(args=[]))
예제 #9
0
파일: test_gpt2.py 프로젝트: donshen/ParlAI
    def _distributed_train_model(self, opt):
        """
        Train through ``mp_train.launch_and_train`` inside a temporary directory.

        :param opt: option overrides. ``model_file`` and ``dict_file`` default
            to paths inside the temporary directory when absent. The mapping
            passed in is not modified.
        :return: ``(valid, test)`` as returned by ``mp_train.launch_and_train``.
        """
        # Work on a copy: the temporary paths filled in below are only
        # meaningful while the tempdir context is open, so they must not
        # leak into the caller's dict and be picked up by a later call.
        opt = dict(opt)
        with testing_utils.tempdir() as tmpdir:
            opt.setdefault('model_file', os.path.join(tmpdir, 'model'))
            opt.setdefault('dict_file', os.path.join(tmpdir, 'model.dict'))

            parser = mp_train.setup_args()
            # TODO: Kill this after dictionaries build correctly
            popt = self._forced_parse(parser, opt)

            # we need a prebuilt dictionary
            build_dict.build_dict(popt)

            try:
                valid, test = mp_train.launch_and_train(popt, 31338)
            finally:
                # Tear the process group down even when training raises, so a
                # failure here does not leave it behind for whatever runs next.
                # NOTE(review): assumes `dist` is `torch.distributed` - confirm.
                if dist.is_initialized():
                    dist.destroy_process_group()

        return (valid, test)
예제 #10
0
 def get_acts_epochs_1_and_2(defaults):
     """
     Run two full epochs and return each epoch's examples sorted by text.

     ``parser`` is read from the enclosing scope. The world is stepped until
     ``epoch_done()``, reset, and stepped to completion again. Each epoch's
     batches are then flattened, examples lacking a ``'text'`` key are
     dropped, and the rest are sorted on that key.

     :param defaults: parser defaults applied before parsing.
     :return: ``(acts_epoch_1, acts_epoch_2)``.
     """
     parser.set_defaults(**defaults)
     opt = parser.parse_args()
     build_dict(opt)
     world = create_task(opt, create_agent(opt))

     def run_epoch():
         # Collect the first agent's act from every parley of one epoch.
         batches = []
         while not world.epoch_done():
             world.parley()
             batches.append(world.acts[0])
         return batches

     def texts_sorted(batches):
         # Flatten the batches, keep examples with text, order them by text.
         with_text = [ex for batch in batches for ex in batch if 'text' in ex]
         with_text.sort(key=lambda ex: ex.get('text'))
         return with_text

     first_epoch = run_epoch()
     world.reset()
     second_epoch = run_epoch()
     return texts_sorted(first_epoch), texts_sorted(second_epoch)
예제 #11
0
    def _distributed_train_model(self, opt):
        """
        Train through ``mp_train.launch_and_train`` while capturing its output.

        :param opt: option overrides. ``model_file`` and ``dict_file`` default
            to paths inside a temporary directory when absent. The mapping
            passed in is not modified.
        :return: ``(captured_output, valid, test)``.
        """
        # we have to delay our import to here, because the set_spawn_method call
        # inside multiprocessing_train will break the multithreading tests, even
        # when we skip the test.
        import parlai.scripts.multiprocessing_train as mp_train

        # Work on a copy: the temporary paths filled in below are only
        # meaningful while the tempdir context is open, so they must not
        # leak into the caller's dict and be picked up by a later call.
        opt = dict(opt)
        with testing_utils.capture_output() as output:
            with testing_utils.tempdir() as tmpdir:
                opt.setdefault('model_file', os.path.join(tmpdir, 'model'))
                opt.setdefault('dict_file', os.path.join(tmpdir, 'model.dict'))

                parser = mp_train.setup_args()
                popt = _forced_parse(parser, opt)

                # we need a prebuilt dictionary
                build_dict.build_dict(popt)

                valid, test = mp_train.launch_and_train(popt, 31337)

        return (output.getvalue(), valid, test)
예제 #12
0
def build_transmitter_dict():
    """
    Build the dictionary for the transmitter's ``BothTeacher`` task.

    A ``ParlaiParser`` is given fixed defaults (initial dictionary path,
    special tokens taken from ``SpecialToken``, ``split`` tokenizer, minimum
    frequency of 2, validation data included, no example cap), parsed with an
    empty argument list, and the resulting options are passed to
    ``build_dict``. Both paths are relative to the working directory.
    """
    dict_defaults = {
        'task': 'tasks.convai2transmitter.agents:BothTeacher',
        'dict_initpath': '../../tmp/dict/init_transmitter.dict',
        'datatype': 'train',
        'dict_lower': True,
        'dict_file': '../../tmp/dict/convai2_self_seq2seq_model.dict',
        'dict_nulltoken': SpecialToken.pad,
        'dict_starttoken': SpecialToken.start,
        'dict_endtoken': SpecialToken.end,
        'dict_unktoken': SpecialToken.unk,
        'dict_tokenizer': 'split',
        'dict_language': 'english',
        'dict_include_valid': True,
        'dict_minfreq': 2,
        'dict_maxexs': -1,
    }

    arg_parser = ParlaiParser(add_model_args=True, add_parlai_args=True)
    arg_parser.set_defaults(**dict_defaults)

    # An explicit empty list keeps the real command line out of the parse.
    build_dict(arg_parser.parse_args(args=[]))
예제 #13
0
    def __init__(self, opt):
        """
        Prepare a training run from parsed options.

        Installs the default SIGINT handler, optionally resumes from
        ``<model_file>.checkpoint``, builds the dictionary, creates the agent
        and world, reads the schedule limits, and restores saved training
        statistics when a trainstats file exists next to the model.

        ``opt`` is modified in place: ``init_model``, ``dict_file`` and
        ``validation_metric_mode`` may be overwritten.

        :param opt: parsed options, accessed like a dict.
        :raises RuntimeError: if neither ``dict_file`` nor ``model_file`` is
            set.
        """
        # if python is called from a non-interactive shell, like a bash script,
        # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if (
            opt['load_from_checkpoint']
            and opt.get('model_file')
            and PathManager.exists(opt['model_file'] + '.checkpoint')
        ):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if not (opt.get('dict_file') or opt.get('model_file')):
            raise RuntimeError(
                'WARNING: For train_model, please specify either a '
                'model_file or dict_file.'
            )
        if 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            logging.info("building dictionary first...")
            build_dict(opt, skip_if_built=True)

        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.agent.opt.log()
        self.world = create_task(opt, self.agent)
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()

        self.parleys = 0
        self._train_steps = 0
        self._last_log_steps = 0
        self.update_freq = opt.get('update_freq', 1)

        self.max_num_epochs = _num_else_inf(opt, 'num_epochs', distributed_warn=True)
        self.max_train_time = _num_else_inf(
            opt, 'max_train_time', distributed_warn=True
        )
        self.max_train_steps = _num_else_inf(opt, 'max_train_steps')
        self.log_every_n_secs = _num_else_inf(
            opt, 'log_every_n_secs', distributed_warn=True
        )
        self.log_every_n_steps = _num_else_inf(opt, 'log_every_n_steps')
        self.val_every_n_secs = _num_else_inf(
            opt, 'validation_every_n_secs', distributed_warn=True
        )
        self.val_every_n_epochs = _num_else_inf(
            opt, 'validation_every_n_epochs', distributed_warn=True
        )
        self.val_every_n_steps = _num_else_inf(opt, 'validation_every_n_steps')
        self.save_every_n_secs = _num_else_inf(
            opt, 'save_every_n_secs', distributed_warn=True
        )

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'}:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self._last_valid_steps = 0
        # +1 when the validation metric is maximised, -1 otherwise.
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.train_reports = []
        self.valid_reports = []
        self.final_valid_report = {}
        self.final_test_report = {}
        self.final_extra_valid_report = {}
        self.best_valid = None

        self.impatience = 0
        self.saved = False
        self.valid_worlds = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file') and PathManager.exists(
            opt['model_file'] + trainstats_suffix
        ):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with PathManager.open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self.parleys = obj.get('parleys', 0)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self._train_steps = obj.get('train_steps', 0)
                self.impatience = obj.get('impatience', 0)
                self.valid_reports = obj.get('valid_reports', [])
                if self.valid_reports:
                    self.last_valid_epoch = self.valid_reports[-1].get(
                        'total_epochs', 0.0
                    )
                self.train_reports = obj.get('train_reports', [])
                if 'best_valid' in obj:
                    self.best_valid = obj['best_valid']
                elif PathManager.exists(opt['model_file'] + '.best_valid'):
                    # old method: the score lives in its own file. model_file
                    # is already known to be set by the enclosing check, and
                    # the `with` block closes the file on exit.
                    with PathManager.open(
                        opt['model_file'] + ".best_valid", 'r'
                    ) as f:
                        self.best_valid = float(f.readline())

        # Loggers are only created on the primary worker.
        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger = TensorboardLogger(opt)
        if opt['wandb_log'] and is_primary_worker():
            model = self.agent.model if hasattr(self.agent, 'model') else None
            self.wb_logger = WandbLogger(opt, model)
예제 #14
0
파일: train_model.py 프로젝트: ying-A/RED
    def __init__(self, opt):
        """
        Prepare a training run from parsed options.

        Installs the default SIGINT handler, optionally resumes from
        ``<model_file>.checkpoint``, builds the dictionary when
        ``dict_build_first`` is set, creates the agent and world, reads the
        schedule limits, and restores saved training statistics when a
        trainstats file exists next to the model.

        ``opt`` is modified in place: ``init_model``, ``dict_file`` and
        ``validation_metric_mode`` may be overwritten.

        :param opt: parsed options, accessed like a dict. A ``ParlaiParser``
            is still accepted; it triggers a deprecation notice and is parsed
            here.
        """
        # if python is called from a non-interactive shell, like a bash script,
        # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)

        if isinstance(opt, ParlaiParser):
            print(
                '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
            )
            opt = opt.parse_args()

        def positive_or_inf(key):
            # A non-positive setting disables the limit, expressed as infinity.
            value = opt[key]
            return value if value > 0 else float('inf')

        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if (opt['load_from_checkpoint'] and opt.get('model_file')
                and os.path.isfile(opt['model_file'] + '.checkpoint')):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            # If data built via pytorch data teacher, we need to load prebuilt dict
            if opt.get('pytorch_teacher_task'):
                opt['dict_file'] = get_pyt_dict_file(opt)
            elif opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)  # the model named in opt, e.g. seq2seq
        self.world = create_task(opt, self.agent)  # batch world or another world
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys = 0

        # Schedule limits; each one falls back to infinity when not positive.
        self.max_num_epochs = positive_or_inf('num_epochs')
        self.max_train_time = positive_or_inf('max_train_time')
        self.log_every_n_secs = positive_or_inf('log_every_n_secs')
        self.val_every_n_secs = positive_or_inf('validation_every_n_secs')
        self.save_every_n_secs = positive_or_inf('save_every_n_secs')
        self.val_every_n_epochs = positive_or_inf('validation_every_n_epochs')

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {
                'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'
        }:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        # +1 when the validation metric is maximised, -1 otherwise.
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.best_valid = None
        if opt.get('model_file') and os.path.isfile(
                opt['model_file'] + '.best_valid'):
            # Restore the best validation score saved next to the model; the
            # `with` block closes the file on exit.
            with open(opt['model_file'] + ".best_valid", 'r') as f:
                self.best_valid = float(f.readline())
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if (opt.get('model_file')
                and os.path.isfile(opt['model_file'] + trainstats_suffix)):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)

        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)
예제 #15
0
def build_data(opt):
    """
    Locate or build the serialized examples described by ``opt``.

    Without a ``pytorch_teacher_task`` the function only resolves the path in
    ``pytorch_datapath`` and checks that the file exists. Otherwise it streams
    the teacher once, writes one JSON object per example to ``data``, and
    stores the character offset of every example in ``char_index`` and the
    totals in ``data_length``, all inside the returned directory. If
    ``data_length`` is already present the directory is returned untouched.

    :param opt: options, accessed like a dict. Modified in place: ``model``
        defaults to ``'repeat_label'`` and ``dict_file`` is overwritten with
        ``get_pyt_dict_file(opt)``.
    :return: the existing data file path (no teacher task given) or the data
        directory path.
    :raises Exception: when no teacher task is given and ``pytorch_datapath``
        is unset or does not resolve to an existing file.
    """
    if not opt.get('model', False):
        opt['model'] = 'repeat_label'
    preprocess = opt.get('pytorch_preprocess', True)
    opt['dict_file'] = get_pyt_dict_file(opt)
    dictionary = None
    if 'dict_maxexs' in opt:
        # Note: only build dictionary if dict loop args specified
        dictionary = build_dict(opt, skip_if_built=True)
    agent = create_agent(opt)
    # If build teacher not specified, we are simply looking for the file
    if not opt.get('pytorch_teacher_task', None):
        df = opt.get('pytorch_datapath')
        # check if the user set a datafile
        if not df:
            raise Exception(
                'Tried to find data but `--pytorch-datapath` is not set')
        # check if the user provided the already built file
        if 'pytorch' not in df:
            df += '.pytorch' + (agent.getID() if preprocess else '')
        if not os.path.isfile(df):
            raise Exception('Tried to find data but it is not built, please '
                            'specify `--pytorch-teacher-task`')
        return df

    ordered_opt = copy.deepcopy(opt)
    # we use streaming to build the data
    dt = opt['datatype'].split(':')[0]
    ordered_opt['datatype'] = dt + ':ordered:stream'
    ordered_opt['numthreads'] = 1
    ordered_opt['batchsize'] = 1
    ordered_opt['task'] = ordered_opt['pytorch_teacher_task']
    # The key is dropped when present; a missing key is not an error.
    ordered_opt.pop('pytorch_teacher_dataset', None)
    ordered_opt['no_cuda'] = True
    world_data = create_task(ordered_opt, agent)
    teacher = world_data.agents[0]
    agent = world_data.agents[1]
    datapath = os.path.join(
        opt.get('datapath', '.'),
        '{}_pyt_data'.format(ordered_opt['task'].replace(':', '_')), dt)
    if preprocess:
        datapath += '_{}_preprocess'.format(agent.getID().replace(':', '_'))
    if os.path.isdir(datapath) and 'data_length' in os.listdir(datapath):
        # Data already built
        print("[ pytorch data already built, at {}. ]".format(datapath))
        return datapath
    print('----------\n[ setting up pytorch data, saving to {}/ ]\n----------'.
          format(datapath))
    os.makedirs(datapath, exist_ok=True)
    num_eps = 0
    num_exs = 0
    current = []
    episode_done = False
    include_labels = opt.get('pytorch_include_labels', True)
    context_length = opt.get('pytorch_context_length', -1)
    # A non-positive context length means the context is unbounded.
    context = deque(maxlen=context_length if context_length > 0 else None)
    total_exs = world_data.num_examples()
    pbar = tqdm.tqdm(total=total_exs,
                     unit='ex',
                     unit_scale=True,
                     desc='Building pytorch data')
    idx_to_char = []
    cumulative_char_len = 0
    # pass examples to dictionary
    try:
        with open(os.path.join(datapath, 'data'), 'w') as pytorch_data:
            while num_exs < total_exs:
                # Pull one whole episode from the teacher.
                while not episode_done:
                    action = teacher.act()
                    current.append(action)
                    episode_done = action.get('episode_done', False)

                # build separate episodes
                for ex in current:
                    context.append(ex.get('text', ''))
                    if len(context) > 1:
                        ex['text'] = '\n'.join(context)
                    ex['episode_done'] = True
                    labels = ex.get('labels', ex.get('eval_labels', None))
                    # An empty label list is skipped: random.choice cannot
                    # pick from it.
                    if labels and include_labels:
                        context.append(random.choice(labels))
                    # generate observation from new example
                    if preprocess:
                        ex = agent.observe(ex)
                        ex.pop('label_candidates', '')
                        ex['preprocessed'] = True
                    num_eps += 1
                    num_exs += 1
                    pbar.update(1)
                    # write() returns the number of characters written, which
                    # gives the start offset of the next example.
                    ex_len = pytorch_data.write(
                        json.dumps(make_serializable(ex)) + "\n")
                    idx_to_char.append(cumulative_char_len)
                    cumulative_char_len += ex_len
                # reset
                episode_done = False
                current.clear()
                context.clear()
    finally:
        # Close the progress bar even if streaming or writing fails.
        pbar.close()
    with open(os.path.join(datapath, 'char_index'), 'w') as char_index:
        json.dump(idx_to_char, char_index)
    with open(os.path.join(datapath, 'data_length'), 'w') as pytorch_data_len:
        pytorch_data_len.write(
            json.dumps({
                'num_eps': num_eps,
                'num_exs': num_exs
            }))
    if dictionary:
        dictionary.save(get_pyt_dict_file(opt), sort=True)

    print('[ pytorch data built. ]')
    return datapath
예제 #16
0
def _build_dict(opt) -> None:
    """Call ``build_dict`` with ``opt``; whatever it returns is discarded."""
    build_dict(opt)
예제 #17
0
    def __init__(self, opt):
        """
        Prepare a training run with an active agent and a frozen copy of it.

        Installs the default SIGINT handler, resumes from
        ``<model_file>.checkpoint`` when that file exists, builds the
        dictionary when ``dict_build_first`` is set, then creates the agent.
        A deep copy of the agent is frozen via ``freeze_agent`` and both are
        handed to ``create_task``. Schedule limits are read and saved training
        statistics are restored when a trainstats file exists.

        ``opt`` is modified in place: ``init_model``, ``dict_file`` and
        ``validation_metric_mode`` may be overwritten.

        :param opt: parsed options, accessed like a dict; a ``ParlaiParser``
            is also accepted and parsed here.
        :raises RuntimeError: if ``dict_build_first`` is set and neither
            ``dict_file`` nor ``model_file`` is given.
        """
        signal.signal(signal.SIGINT, signal.default_int_handler)

        if isinstance(opt, ParlaiParser):
            opt = opt.parse_args()

        def positive_or_inf(key):
            # A non-positive setting disables the limit, expressed as infinity.
            value = opt[key]
            return value if value > 0 else float('inf')

        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'
        if (opt.get('model_file')
                and isfile(opt['model_file'] + '.checkpoint')):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # TODO: a missing checkpoint is tolerated for testing only; the
        # RuntimeError that would require the .checkpoint file is disabled.

        # Possibly build a dictionary (not all models do this).
        if (opt['dict_build_first']
                and not (opt.get('dict_file') or opt.get('model_file'))):
            raise RuntimeError('WARNING: For train_model, '
                               'please specify either a '
                               'model_file or dict_file.')

        if opt['dict_build_first'] and 'dict_file' in opt:
            if opt.get('pytorch_teacher_task'):
                opt['dict_file'] = get_pyt_dict_file(opt)
            elif opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)

        # Create model and assign it to the specified task
        self.agent = create_agent(opt)

        # Freeze the model for the static dialogue partner
        static_agent = copy.deepcopy(self.agent)
        self.agent.id = ACTIVE

        static_agent.id = STATIC
        freeze_agent(static_agent)

        self.world = create_task(opt, self.agent, static_agent)

        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')

        self.parleys = 0

        # Schedule limits; each one falls back to infinity when not positive.
        self.max_num_epochs = positive_or_inf('num_epochs')
        self.max_train_time = positive_or_inf('max_train_time')
        self.log_every_n_secs = positive_or_inf('log_every_n_secs')
        self.val_every_n_secs = positive_or_inf('validation_every_n_secs')
        self.save_every_n_secs = positive_or_inf('save_every_n_secs')
        self.val_every_n_epochs = positive_or_inf('validation_every_n_epochs')

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {
                'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'
        }:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        # +1 when the validation metric is maximised, -1 otherwise.
        self.valid_optim = (1
                            if opt['validation_metric_mode'] == 'max' else -1)
        self.valid_reports = []
        self.best_valid = None
        if (opt.get('model_file')
                and isfile(opt['model_file'] + '.best_valid')):
            # Restore the best validation score saved next to the model; the
            # `with` block closes the file on exit.
            with open(opt['model_file'] + ".best_valid", 'r') as f:
                self.best_valid = float(f.readline())
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if (opt.get('model_file')
                and isfile(opt['model_file'] + trainstats_suffix)):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)
                self.valid_reports = obj.get('valid_reports', [])

        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)
예제 #18
0
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Generates a dictionary file from the training data.

For more documentation, see `parlai.scripts.build_dict`.
"""

from parlai.scripts.build_dict import setup_args, build_dict

if __name__ == '__main__':
    # Parse the command line with the dictionary-building parser and hand the
    # resulting options straight to the builder.
    build_dict(setup_args().parse_args())