Code example #1
File: train.py Project: telin0411/dialogue-reinforce
def setup_rl_args():
    parser = setup_args()
    reinforce = parser.add_argument_group('Reinforce Arguments')
    reinforce.add_argument(
        '-dl',
        '--dialog_rounds',
        type=int,
        default=2,
        help='Number of rollouts rounds for estimating the reward.')
    reinforce.add_argument(
        '-db',  # '-dl' is already used above; argparse rejects duplicate option strings
        '--dialog_branches',
        type=int,
        default=5,
        help='Branches of the active agent responses during rollout.')
    reinforce.add_argument('-lmp',
                           '--language_model_path',
                           type=str,
                           default=None,
                           help='Path of the language model for the reward.')
    reinforce.add_argument('-rd',
                           '--reward_decay',
                           type=float,
                           default=0.9,
                           help='Value of the reward decay.')

    return parser
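
The snippet above only constructs the parser. As a rough illustration of the same "argument group plus custom flags" pattern, here is a minimal standalone sketch using plain argparse rather than ParlAI's parser; the flag spellings are illustrative, not taken from the project:

# Standalone sketch of the add_argument_group pattern used by setup_rl_args.
import argparse

parser = argparse.ArgumentParser()
reinforce = parser.add_argument_group('Reinforce Arguments')
reinforce.add_argument('-dr', '--dialog_rounds', type=int, default=2)
reinforce.add_argument('-db', '--dialog_branches', type=int, default=5)

opt = vars(parser.parse_args(['--dialog_rounds', '3']))
print(opt['dialog_rounds'], opt['dialog_branches'])  # -> 3 5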
Code example #2
File: test_torch_agent.py Project: rhamnett/ParlAI
 def get_popt_and_tl(opt):
     parser = tms.setup_args()
     parser.set_params(**opt)
     popt = parser.parse_args([], print_args=False)
     for k, v in opt.items():
         popt[k] = v
     return popt, tms.TrainLoop(popt)
Code example #3
    def test_hogwild_eval(self):
        """Test eval with numthreads > 1 and batchsize in [1,2,3]."""
        parser = setup_args()
        NUM_EXS = 500
        parser.set_defaults(
            task='tasks.repeat:RepeatTeacher:{}'.format(NUM_EXS),
            model='repeat_label',
            datatype='valid',
            num_examples=-1,
            display_examples=False,
        )

        old_out = sys.stdout
        output = display_output()
        try:
            sys.stdout = output
            for nt in [2, 5, 10]:
                parser.set_defaults(numthreads=nt)
                for bs in [1, 2, 3]:
                    parser.set_defaults(batchsize=bs)
                    parser.set_defaults(batch_sort=(bs % 2 == 0))
                    report = eval_model(parser, printargs=False)
                    self.assertEqual(report['total'], NUM_EXS)
        finally:
            # restore sys.stdout
            sys.stdout = old_out
Code example #4
def train_model(opt):
    """
    Runs through a TrainLoop.

    If model_file is not in opt, then this helper will create a temporary
    directory to store the model, dict, etc.

    :return: (stdout, valid_results, test_results)
    :rtype: (str, dict, dict)
    """
    import parlai.scripts.train_model as tms

    with capture_output() as output:
        with tempdir() as tmpdir:
            if 'model_file' not in opt:
                opt['model_file'] = os.path.join(tmpdir, 'model')
            if 'dict_file' not in opt:
                opt['dict_file'] = os.path.join(tmpdir, 'model.dict')
            parser = tms.setup_args()
            parser.set_params(**opt)
            popt = parser.parse_args(print_args=False)
            tl = tms.TrainLoop(popt)
            valid, test = tl.train()

    return (
        output.getvalue(),
        valid,
        test,
    )
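
The docstring above describes the key behavior: temporary model and dict paths are created only when the caller has not supplied them. A standard-library-only sketch of that idiom (the option keys below are placeholders, not ParlAI requirements):

# Fill in temporary paths only when the caller did not provide them.
import os
import tempfile

opt = {'task': 'babi:task1k:1', 'model': 'seq2seq'}  # hypothetical options
with tempfile.TemporaryDirectory() as tmpdir:
    opt.setdefault('model_file', os.path.join(tmpdir, 'model'))
    opt.setdefault('dict_file', os.path.join(tmpdir, 'model.dict'))
    print(opt['model_file'], opt['dict_file'])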
Code example #5
        def get_tl(tmpdir):
            final_opt = Opt({
                'task': 'integration_tests',
                'datatype': 'valid',
                'validation_max_exs': 30,
                'short_final_eval': True,
            })
            final_opt.save(os.path.join(tmpdir, "final_opt.opt"))

            opt = Opt({
                'task': 'integration_tests',
                'validation_max_exs': 10,
                'model': 'repeat_label',
                'model_file': os.path.join(tmpdir, 'model'),
                'short_final_eval': True,
                'num_epochs': 1.0,
                'final_extra_opt': str(os.path.join(tmpdir, "final_opt.opt")),
            })
            parser = tms.setup_args()
            parser.set_params(**opt)
            popt = parser.parse_args([])
            for k, v in opt.items():
                popt[k] = v
            return tms.TrainLoop(popt)
Code example #6
def train_model(opt):
    """
    Run through a TrainLoop.

    If model_file is not in opt, then this helper will create a temporary
    directory to store the model, dict, etc.

    :return: (stdout, valid_results, test_results)
    :rtype: (str, dict, dict)
    """
    import parlai.scripts.train_model as tms

    with capture_output() as output:
        with tempdir() as tmpdir:
            if 'model_file' not in opt:
                opt['model_file'] = os.path.join(tmpdir, 'model')
            if 'dict_file' not in opt:
                opt['dict_file'] = os.path.join(tmpdir, 'model.dict')
            parser = tms.setup_args()
            # needed at the very least to set the overrides.
            parser.set_params(**opt)
            parser.set_params(log_every_n_secs=10)
            popt = parser.parse_args(print_args=False)
            # in some rare cases (for instance, if the model class overrides its
            # own default params), the override set above is not picked up, so
            # force the requested values onto the parsed opt here.
            for k, v in opt.items():
                popt[k] = v
            tl = tms.TrainLoop(popt)
            valid, test = tl.train()

    return (output.getvalue(), valid, test)
Code example #7
def setup_args():
    parser = single_train.setup_args()
    parser.add_distributed_training_args()
    parser.add_argument('--port',
                        type=int,
                        default=61337,
                        help='TCP port number')
    return parser
Code example #8
    def test_train_model_with_no_dict_file(self):
        """Ensure training a model requires a dict_file or model_file."""
        import parlai.scripts.train_model as tms

        with testing_utils.capture_output():
            parser = tms.setup_args()
            parser.set_params(task='babi:task1k:1', model='seq2seq')
            popt = parser.parse_args(print_args=False)
            with self.assertRaises(RuntimeError):
                tms.TrainLoop(popt)
Code example #9
 def test_train_model_with_no_dict_file(self):
     """Check that attempting to train a model without specifying a dict_file
     or model_file fails
     """
     import parlai.scripts.train_model as tms
     with testing_utils.capture_output():
         parser = tms.setup_args()
         parser.set_params(task='babi:task1k:1', model='seq2seq')
         popt = parser.parse_args(print_args=False)
         with self.assertRaises(RuntimeError):
             tms.TrainLoop(popt)
Code example #10
File: test_fairseq.py Project: yueyihua/ParlAI
def _mock_train(**args):
    outdir = tempfile.mkdtemp()
    parser = setup_args()
    parser.set_defaults(
        model_file=os.path.join(outdir, "model"),
        **args,
    )
    stdout = io.StringIO()
    with contextlib.redirect_stdout(stdout):
        tl = TrainLoop(parser.parse_args(print_args=False))
        valid, test = tl.train()

    shutil.rmtree(outdir)
    return stdout.getvalue(), valid, test
Code example #11
File: test_transformers.py Project: mahdimor/ParlAI
def _mock_train(outdir=None, keepoutdir=False, override=None, **args):
    if not outdir:
        outdir = tempfile.mkdtemp()
    parser = setup_args()
    parser.set_defaults(
        model_file=os.path.join(outdir, "model"),
        **args,
    )
    stdout = io.StringIO()
    with contextlib.redirect_stdout(stdout):
        opt = parser.parse_args(print_args=False)
        if override:
            opt['override'] = override
        tl = TrainLoop(opt)
        valid, test = tl.train()
    if not keepoutdir:
        shutil.rmtree(outdir)
    return stdout.getvalue(), valid, test
Code example #12
File: test_train_model.py Project: magicye/ParlAI
 def get_tl(tmpdir):
     opt = {
         'task': 'integration_tests',
         'model': 'parlai.agents.test_agents.test_agents:MockTrainUpdatesAgent',
         'model_file': os.path.join(tmpdir, 'model'),
         'dict_file': os.path.join(tmpdir, 'model.dict'),
         # step opts
         'max_train_steps': num_train_steps,
         'validation_every_n_steps': int(num_train_steps / num_validations),
         'log_every_n_steps': int(num_train_steps / num_logs),
         'update_freq': update_freq,
     }
     parser = tms.setup_args()
     parser.set_params(**opt)
     popt = parser.parse_args([])
     for k, v in opt.items():
         popt[k] = v
     return tms.TrainLoop(popt)
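
Note that num_train_steps, num_validations, num_logs, and update_freq come from the enclosing test and are not defined in this excerpt. With made-up values, the derived step intervals look like this:

# Hypothetical values; the excerpt itself does not define these variables.
num_train_steps = 1000
num_validations = 10
num_logs = 50

print(int(num_train_steps / num_validations))  # validate every 100 steps
print(int(num_train_steps / num_logs))         # log every 20 steps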
Code example #13
File: distributed_train.py Project: shatu/ParlAI
def main():
    # double check we're using SLURM
    node_list = os.environ.get('SLURM_JOB_NODELIST')
    if node_list is None:
        raise RuntimeError(
            'Does not appear to be in a SLURM environment. '
            'You should not call this script directly; see launch_distributed.py'
        )

    parser = single_train.setup_args()
    parser.add_distributed_training_args()
    parser.add_argument('--port',
                        type=int,
                        default=61337,
                        help='TCP port number')
    opt = parser.parse_args(print_args=(os.environ['SLURM_PROCID'] == '0'))

    # We can determine the init method automatically for Slurm.
    try:
        # Figure out the main host, and which rank we are.
        hostnames = subprocess.check_output(
            ['scontrol', 'show', 'hostnames', node_list])
        main_host = hostnames.split()[0].decode('utf-8')
        distributed_rank = int(os.environ['SLURM_PROCID'])
        if opt.get('model_parallel'):
            # -1 signals to multiprocessing_train to use all GPUs available.
            # (A value of None signals to multiprocessing_train to use the GPU
            # corresponding to the rank.)
            device_id = -1
        else:
            device_id = int(os.environ['SLURM_LOCALID'])
        port = opt['port']
        logging.info(
            f'Initializing host {socket.gethostname()} as rank {distributed_rank}, '
            f'main is {main_host}')
        # Begin distributed training
        multiprocess_train(distributed_rank, opt, port, 0, device_id,
                           main_host)
    except subprocess.CalledProcessError as e:
        # scontrol failed
        raise e
    except FileNotFoundError:
        # Slurm is not installed
        raise RuntimeError('SLURM does not appear to be installed.')
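
The script above derives its distributed identity from SLURM environment variables. A standalone sketch of that mapping, with made-up values so it runs without SLURM:

# SLURM_PROCID is the global rank; SLURM_LOCALID indexes the GPU on this node.
import os

os.environ.setdefault('SLURM_PROCID', '3')
os.environ.setdefault('SLURM_LOCALID', '1')

distributed_rank = int(os.environ['SLURM_PROCID'])
device_id = int(os.environ['SLURM_LOCALID'])
print(f'rank {distributed_rank} trains on GPU {device_id}')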
Code example #14
File: test_hogwild.py Project: zwcdp/ParlAI
    def test_hogwild_train(self):
        """Test the trainer eval with numthreads > 1 and batchsize in [1,2,3]."""
        parser = setup_args()
        NUM_EXS = 500
        parser.set_defaults(
            task='tasks.repeat:RepeatTeacher:{}'.format(1),
            evaltask='tasks.repeat:RepeatTeacher:{}'.format(NUM_EXS),
            model='repeat_label',
            num_examples=-1,
            display_examples=False,
            num_epochs=10,
        )

        old_out = sys.stdout
        output = display_output()
        try:
            sys.stdout = output
            for nt in [2, 5, 10]:
                parser.set_defaults(numthreads=nt)
                for bs in [1, 2, 3]:
                    parser.set_defaults(batchsize=bs)
                    parser.set_defaults(batch_sort=(bs % 2 == 0))
                    tl = TrainLoop(parser)
                    report_valid, report_test = tl.train()
                    # test final valid and test evals
                    self.assertEqual(report_valid['exs'], NUM_EXS)
                    self.assertEqual(report_test['exs'], NUM_EXS)

                    report_full, _world = run_eval(tl.agent,
                                                   tl.opt,
                                                   'valid',
                                                   max_exs=-1,
                                                   valid_world=tl.valid_world)
                    self.assertEqual(report_full['exs'], NUM_EXS)
                    report_part, _world = run_eval(tl.agent,
                                                   tl.opt,
                                                   'valid',
                                                   max_exs=NUM_EXS / 5,
                                                   valid_world=tl.valid_world)
                    self.assertTrue(report_part['exs'] < NUM_EXS)
        finally:
            # restore sys.stdout
            sys.stdout = old_out
Code example #15
    def test_output(self):
        f = io.StringIO()
        with redirect_stdout(f):
            try:
                import torch
            except ImportError:
                print('Cannot import torch, skipping test_train_model')
                return
            parser = setup_args()
            parser.set_defaults(
                model='mlb_vqa',
                task='pytorch_teacher',
                pytorch_buildteacher='vqa_v1',
                dataset='parlai.tasks.vqa_v1.agents',
                image_mode='resnet152_spatial',
                image_size=448,
                image_cropsize=448,
                dict_file='/tmp/vqa_v1',
                batchsize=1,
                num_epochs=1,
                no_cuda=True,
                no_hdf5=True,
                pytorch_preprocess=False,
                batch_sort_cache='none',
                numworkers=1,
                unittest=True
            )
            TrainLoop(parser).train()

        str_output = f.getvalue()
        self.assertTrue(len(str_output) > 0, "Output is empty")
        self.assertTrue("[ training... ]" in str_output,
                        "Did not reach training step")
        self.assertTrue("[ running eval: valid ]" in str_output,
                        "Did not reach validation step")
        self.assertTrue("valid:{'total': 10," in str_output,
                        "Did not complete validation")
        self.assertTrue("[ running eval: test ]" in str_output,
                        "Did not reach evaluation step")
        self.assertTrue("test:{'total': 0}" in str_output,
                        "Did not complete evaluation")
Code example #16
def main():
    parser = single_train.setup_args()
    parser.add_distributed_training_args()
    parser.set_defaults(distributed_world_size=torch.cuda.device_count())
    opt = parser.parse_args()

    port = random.randint(32000, 48000)

    # Launch multiple subprocesses
    spawncontext = torch.multiprocessing.spawn(
        multiprocess_train,
        (opt, port),
        nprocs=opt['distributed_world_size'],
        join=False,
    )

    try:
        spawncontext.join()
    except KeyboardInterrupt:
        # tell the subprocesses to stop too
        for p in spawncontext.processes:
            if p.is_alive():
                os.kill(p.pid, signal.SIGINT)
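
The pattern here is to spawn one training process per GPU and, on Ctrl-C, forward SIGINT to any workers still alive. A standard-library sketch of the same shutdown pattern (worker() below is a stand-in for multiprocess_train):

# Spawn workers, join them, and forward Ctrl-C to survivors.
import multiprocessing as mp
import os
import signal
import time

def worker(rank, port):
    time.sleep(0.1)  # placeholder for the real training loop

if __name__ == '__main__':
    procs = [mp.Process(target=worker, args=(rank, 61337)) for rank in range(2)]
    for p in procs:
        p.start()
    try:
        for p in procs:
            p.join()
    except KeyboardInterrupt:
        for p in procs:
            if p.is_alive():
                os.kill(p.pid, signal.SIGINT)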
Code example #17
def setup_args():
    parser = single_train.setup_args()
    parser.add_distributed_training_args()
    parser.set_defaults(distributed_world_size=torch.cuda.device_count())
    return parser
Code example #18
File: train_transformer_rec.py Project: zwcdp/KBRD
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Train a model using parlai's standard training loop.

For documentation, see parlai.scripts.train_model.
"""

from parlai.scripts.train_model import TrainLoop, setup_args

if __name__ == '__main__':
    parser = setup_args()
    parser.set_defaults(
        task='redial',
        model='transformer_rec/generator',
        model_file='saved/transformer_rec',
        dict_tokenizer='nltk',
        dict_lower=True,
        batchsize=64,
        truncate=1024,
        dropout=0.1,
        relu_dropout=0.1,
        n_entity=64368,
        n_relation=214,
        validation_metric='nll_loss',
        validation_metric_mode='min',
        validation_every_n_secs=300,
        validation_patience=5,
        tensorboard_log=True,
        # (further defaults are truncated in the source listing)
    )
    TrainLoop(parser).train()
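
Since the script only calls set_defaults, every value above acts as a project default that the command line can still override. A plain-argparse sketch of that precedence (not ParlAI-specific):

# set_defaults supplies defaults; explicit command-line flags win.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--batchsize', type=int, default=32)
parser.set_defaults(batchsize=64)  # project-specific default

print(parser.parse_args([]).batchsize)                      # -> 64
print(parser.parse_args(['--batchsize', '16']).batchsize)   # -> 16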
Code example #19
    def test_output(self):
        class display_output(object):
            def __init__(self):
                self.data = []

            def write(self, s):
                self.data.append(s)

            def flush(self):
                pass

            def __str__(self):
                return "".join(self.data)

        old_out = sys.stdout
        output = display_output()

        try:
            sys.stdout = output
            try:
                import torch  # noqa: F401
            except ImportError:
                print('Cannot import torch, skipping test_train_model')
                return
            parser = setup_args()
            parser.set_defaults(
                model='memnn',
                task='tasks.repeat:RepeatTeacher:10',
                dict_file='/tmp/repeat',
                batchsize=1,
                numthreads=1,
                validation_every_n_epochs=10,
                validation_patience=5,
                embedding_size=8,
                no_cuda=True,
                validation_share_agent=True,
                num_episodes=10,
            )
            opt = parser.parse_args()
            TrainLoop(opt).train()
        finally:
            # restore sys.stdout
            sys.stdout = old_out

        str_output = str(output)

        self.assertTrue(len(str_output) > 0, "Output is empty")
        self.assertTrue("[ training... ]" in str_output,
                        "Did not reach training step")
        self.assertTrue("[ running eval: valid ]" in str_output,
                        "Did not reach validation step")
        self.assertTrue("valid:{'exs': 10," in str_output,
                        "Did not complete validation")
        self.assertTrue("[ running eval: test ]" in str_output,
                        "Did not reach evaluation step")
        self.assertTrue("test:{'exs': 10," in str_output,
                        "Did not complete evaluation")

        list_output = str_output.split("\n")
        for line in list_output:
            if "test:{" in line:
                score = ast.literal_eval(line.split("test:", 1)[1])
                self.assertTrue(
                    score['accuracy'] > 0.5,
                    'Accuracy not convincing enough, was {}'
                    ''.format(score['accuracy']))