def train_translation_model(data_dir, arch, extra_flags=None, task='translation', run_validation=False,
                            lang_flags=None, extra_valid_flags=None):
    """Train a model on a translation-style task for one epoch; optionally validate it.

    Args:
        data_dir: passed both as the positional data argument and as ``--save-dir``.
        arch: value given to ``--arch``.
        extra_flags: optional list of additional flags appended after the
            built-in training flags.
        task: value given to ``--task`` for both training and validation.
        run_validation: when true, ``validate.main`` is run afterwards against
            ``checkpoint_last.pt`` inside ``data_dir``.
        lang_flags: optional list of language flags; when ``None`` it defaults
            to ``--source-lang in --target-lang out``.
        extra_valid_flags: optional list of additional flags appended to the
            validation flags.
    """
    if lang_flags is None:
        lang_flags = ['--source-lang', 'in', '--target-lang', 'out']

    parser = options.get_training_parser()
    training_cli = [
        '--task', task,
        data_dir,
        '--save-dir', data_dir,
        '--arch', arch,
        '--optimizer', 'nag',
        '--lr', '0.05',
        '--max-tokens', '500',
        '--max-epoch', '1',
        '--no-progress-bar',
        '--distributed-world-size', '1',
        '--num-workers', '0',
    ] + lang_flags + (extra_flags or [])
    train.main(options.parse_args_and_arch(parser, training_cli))

    if not run_validation:
        return

    # Validation pass over the 'valid' subset using the last checkpoint path.
    valid_parser = options.get_validation_parser()
    last_checkpoint = os.path.join(data_dir, 'checkpoint_last.pt')
    validation_cli = [
        '--task', task,
        data_dir,
        '--path', last_checkpoint,
        '--valid-subset', 'valid',
        '--max-tokens', '500',
        '--no-progress-bar',
        '--num-workers', '0',
    ] + lang_flags + (extra_valid_flags or [])
    validate.main(options.parse_args_and_arch(valid_parser, validation_cli))
def train_language_model(data_dir, arch, extra_flags=None, run_validation=False, extra_valid_flags=None):
    """Train a language model for one epoch with adaptive loss; optionally validate it.

    Args:
        data_dir: passed both as the positional data argument and as ``--save-dir``.
        arch: value given to ``--arch``.
        extra_flags: optional list of additional flags appended after the
            built-in training flags.
        run_validation: when true, ``validate.main`` is run afterwards against
            ``checkpoint_last.pt`` inside ``data_dir``.
        extra_valid_flags: optional list of additional flags appended to the
            validation flags. Defaults to none, so existing callers are
            unaffected.
    """
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            '--task', 'language_modeling',
            data_dir,
            '--arch', arch,
            '--optimizer', 'adam',
            '--lr', '0.0001',
            '--criterion', 'adaptive_loss',
            '--adaptive-softmax-cutoff', '5,10,15',
            '--max-tokens', '500',
            '--tokens-per-sample', '500',
            '--save-dir', data_dir,
            '--max-epoch', '1',
            '--no-progress-bar',
            '--distributed-world-size', '1',
            '--ddp-backend', 'no_c10d',
            # Same worker setting the translation helper in this file passes.
            '--num-workers', '0',
        ] + (extra_flags or []),
    )
    train.main(train_args)

    if not run_validation:
        return

    # Validation pass over the 'valid' subset using the last checkpoint path.
    validate_parser = options.get_validation_parser()
    validate_args = options.parse_args_and_arch(
        validate_parser,
        [
            '--task', 'language_modeling',
            data_dir,
            '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
            '--valid-subset', 'valid',
            '--max-tokens', '500',
            '--no-progress-bar',
            '--num-workers', '0',
        ] + (extra_valid_flags or []),
    )
    validate.main(validate_args)
def train_language_model(
    data_dir,
    arch,
    extra_flags=None,
    run_validation=False,
    extra_valid_flags=None,
    task="language_modeling",
    world_size=1,
):
    """Train a language-modeling style task for one epoch; optionally validate it.

    Args:
        data_dir: passed both as the positional data argument and as ``--save-dir``.
        arch: value given to ``--arch``.
        extra_flags: optional list of additional flags appended after the
            built-in training flags.
        run_validation: when true, ``validate.main`` is run afterwards against
            ``checkpoint_last.pt`` inside ``data_dir``.
        extra_valid_flags: optional list of additional flags appended to the
            validation flags.
        task: value given to ``--task`` for both training and validation.
        world_size: stringified and passed as ``--distributed-world-size``.
    """
    parser = options.get_training_parser()
    training_cli = [
        "--task", task,
        data_dir,
        "--arch", arch,
        "--optimizer", "adam",
        "--lr", "0.0001",
        "--max-tokens", "500",
        "--tokens-per-sample", "500",
        "--save-dir", data_dir,
        "--max-epoch", "1",
        "--no-progress-bar",
        "--distributed-world-size", str(world_size),
        "--ddp-backend", "no_c10d",
        "--num-workers", "0",
    ] + (extra_flags or [])
    parsed_train = options.parse_args_and_arch(parser, training_cli)

    # Training is launched through call_main with the converted config.
    distributed_utils.call_main(convert_namespace_to_omegaconf(parsed_train), train.main)

    if not run_validation:
        return

    # Validation pass over the 'valid' subset using the last checkpoint path.
    valid_parser = options.get_validation_parser()
    last_checkpoint = os.path.join(data_dir, "checkpoint_last.pt")
    validation_cli = [
        "--task", task,
        data_dir,
        "--path", last_checkpoint,
        "--valid-subset", "valid",
        "--max-tokens", "500",
        "--no-progress-bar",
        "--num-workers", "0",
    ] + (extra_valid_flags or [])
    validate.main(options.parse_args_and_arch(valid_parser, validation_cli))
def train_translation_model(
    data_dir,
    arch,
    extra_flags=None,
    task="translation",
    run_validation=False,
    lang_flags=None,
    extra_valid_flags=None,
    world_size=1,
):
    """Train a translation-style task for one epoch; optionally validate it.

    Args:
        data_dir: passed both as the positional data argument and as ``--save-dir``.
        arch: value given to ``--arch``.
        extra_flags: optional list of additional flags appended after the
            built-in training flags.
        task: value given to ``--task`` for both training and validation.
        run_validation: when true, ``validate.main`` is run afterwards against
            ``checkpoint_last.pt`` inside ``data_dir``.
        lang_flags: optional list of language flags; when ``None`` it defaults
            to ``--source-lang in --target-lang out``.
        extra_valid_flags: optional list of additional flags appended to the
            validation flags.
        world_size: stringified and passed as ``--distributed-world-size``.
    """
    if lang_flags is None:
        lang_flags = ["--source-lang", "in", "--target-lang", "out"]

    parser = options.get_training_parser()
    training_cli = [
        "--task", task,
        data_dir,
        "--save-dir", data_dir,
        "--arch", arch,
        "--optimizer", "nag",
        "--lr", "0.05",
        "--max-tokens", "500",
        "--max-epoch", "1",
        "--no-progress-bar",
        "--distributed-world-size", str(world_size),
        "--num-workers", "0",
    ] + lang_flags + (extra_flags or [])
    parsed_train = options.parse_args_and_arch(parser, training_cli)

    # Training is launched through call_main with the converted config.
    distributed_utils.call_main(convert_namespace_to_omegaconf(parsed_train), train.main)

    if not run_validation:
        return

    # Validation pass over the 'valid' subset using the last checkpoint path.
    valid_parser = options.get_validation_parser()
    last_checkpoint = os.path.join(data_dir, "checkpoint_last.pt")
    validation_cli = [
        "--task", task,
        data_dir,
        "--path", last_checkpoint,
        "--valid-subset", "valid",
        "--max-tokens", "500",
        "--no-progress-bar",
        "--num-workers", "0",
    ] + lang_flags + (extra_valid_flags or [])
    validate.main(options.parse_args_and_arch(valid_parser, validation_cli))