Example #1
    def validate_train_opts(cls, opt):
        if opt.epochs:
            raise AssertionError(
                "-epochs is deprecated please use -train_steps.")
        if opt.truncated_decoder > 0 and max(opt.accum_count) > 1:
            raise AssertionError("BPTT is not compatible with -accum > 1")

        if opt.gpuid:
            raise AssertionError(
                "gpuid is deprecated see world_size and gpu_ranks")
        if torch.cuda.is_available() and not opt.gpu_ranks:
            logger.warning("You have a CUDA device, should run with -gpu_ranks")
        if opt.world_size < len(opt.gpu_ranks):
            raise AssertionError(
                "parameter counts of -gpu_ranks must be less or equal "
                "than -world_size.")
        if opt.world_size == len(opt.gpu_ranks) and \
                min(opt.gpu_ranks) > 0:
            raise AssertionError(
                "-gpu_ranks should have master(=0) rank "
                "unless -world_size is greater than len(gpu_ranks).")
        assert len(opt.data_ids) == len(opt.data_weights), \
            "Please check -data_ids and -data_weights options!"

        assert len(opt.dropout) == len(opt.dropout_steps), \
            "Number of dropout values must match number of dropout_steps"

        assert len(opt.attention_dropout) == len(opt.dropout_steps), \
            "Number of attention_dropout values must match number of dropout_steps"
        assert not opt.train_with_rl, \
            "Don't train with RL in this project"
Example #2
def _check_save_model_path(opt):
    if not hasattr(opt, 'save_model') or opt.save_model is None:
        if hasattr(opt, 'exp_dir'):
            setattr(opt, 'save_model', opt.exp_dir+'/models/model')
        else:
            raise Exception('Neither exp_dir nor save_model is given!')

    save_model_path = os.path.abspath(opt.save_model)
    model_dirname = os.path.dirname(save_model_path)
    if not os.path.exists(model_dirname):
        os.makedirs(model_dirname)

    if not hasattr(opt, 'log_file') or opt.log_file is None or opt.log_file == '':
        if hasattr(opt, 'exp_dir'):
            setattr(opt, 'log_file', opt.exp_dir+'/train.log')
        else:
            logger.warning("opt.log_file is not set")

    if not hasattr(opt, 'tensorboard_log_dir') or opt.tensorboard_log_dir is None:
        if hasattr(opt, 'exp_dir'):
            setattr(opt, 'tensorboard_log_dir', opt.exp_dir+'/logs/tfevents/')
        else:
            logger.warning("opt.tensorboard_log_dir is not set")
    if hasattr(opt, 'tensorboard_log_dir') and not os.path.exists(opt.tensorboard_log_dir):
        os.makedirs(opt.tensorboard_log_dir)

    if not hasattr(opt, 'wandb_log_dir') or opt.wandb_log_dir is None:
        if hasattr(opt, 'exp_dir'):
            # a `/wandb` will be appended by wandb
            setattr(opt, 'wandb_log_dir', opt.exp_dir+'/logs/')
        else:
            logger.warning("opt.wandb_log_dir is not set")
    if hasattr(opt, 'wandb_log_dir') and not os.path.exists(opt.wandb_log_dir):
        os.makedirs(opt.wandb_log_dir)
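
Because `_check_save_model_path` derives every missing path from `opt.exp_dir`, a bare experiment directory is enough to drive it. A small illustrative call, assuming the function above is in scope together with its module-level `os` and `logger`; the paths are examples only.

from argparse import Namespace

opt = Namespace(exp_dir='/tmp/exp1', save_model=None, log_file=None,
                tensorboard_log_dir=None, wandb_log_dir=None)
_check_save_model_path(opt)

# Defaults filled in by the call (directories are created on disk):
#   opt.save_model          -> /tmp/exp1/models/model
#   opt.log_file            -> /tmp/exp1/train.log
#   opt.tensorboard_log_dir -> /tmp/exp1/logs/tfevents/
#   opt.wandb_log_dir       -> /tmp/exp1/logs/
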
Example #3
    def validate_train_opts(cls, opt):
        if opt.epochs:
            raise AssertionError(
                "-epochs is deprecated please use -train_steps.")
        if opt.truncated_decoder > 0 and max(opt.accum_count) > 1:
            raise AssertionError("BPTT is not compatible with -accum > 1")

        if opt.gpuid:
            raise AssertionError(
                "gpuid is deprecated see world_size and gpu_ranks")
        if torch.cuda.is_available() and not opt.gpu_ranks:
            logger.warning("You have a CUDA device, should run with -gpu_ranks")
        if opt.world_size < len(opt.gpu_ranks):
            raise AssertionError(
                "parameter counts of -gpu_ranks must be less or equal "
                "than -world_size.")
        if opt.world_size == len(opt.gpu_ranks) and \
                min(opt.gpu_ranks) > 0:
            raise AssertionError(
                "-gpu_ranks should have master(=0) rank "
                "unless -world_size is greater than len(gpu_ranks).")

        assert len(opt.dropout) == len(opt.dropout_steps), \
            "Number of dropout values must match number of dropout_steps"

        assert len(opt.attention_dropout) == len(opt.dropout_steps), \
            "Number of attention_dropout values must match number of dropout_steps"

        assert len(opt.accum_count) == len(opt.accum_steps), \
            'Number of accum_count values must match number of accum_steps'

        if opt.update_vocab:
            assert opt.train_from, \
                "-update_vocab needs -train_from option"
            assert opt.reset_optim in ['states', 'all'], \
                '-update_vocab needs -reset_optim "states" or "all"'
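
This later variant additionally requires the accumulation schedule lists to pair up and constrains -update_vocab. A sketch of the extra fields it checks, with illustrative values only (the checkpoint path and step values are made up):

from argparse import Namespace

opt_extra = Namespace(
    dropout=[0.3, 0.1], dropout_steps=[0, 10000],    # paired schedules
    attention_dropout=[0.1, 0.1],                    # same length as dropout_steps
    accum_count=[2, 4], accum_steps=[0, 5000],       # paired schedules
    update_vocab=True,
    train_from='checkpoints/model_step_10000.pt',    # required by -update_vocab
    reset_optim='states',                            # must be 'states' or 'all'
)
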
Example #4
        opt.tensorboard_log_dir = join(exp_root, 'log')
        with open(join(exp_root, 'log', 'hps.json'), 'w') as writer:
            json.dump(vars(opt), writer, indent=4)

        # git info
        try:
            logger.info("Waiting on git info....")
            c = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"],
                               timeout=10, stdout=subprocess.PIPE)
            git_branch_name = c.stdout.decode().strip()
            logger.info("Git branch: %s", git_branch_name)
            c = subprocess.run(["git", "rev-parse", "HEAD"],
                               timeout=10, stdout=subprocess.PIPE)
            git_sha = c.stdout.decode().strip()
            logger.info("Git SHA: %s", git_sha)
            git_dir = abspath(dirname(__file__))
            git_status = subprocess.check_output(
                ['git', 'status', '--short'],
                cwd=git_dir, universal_newlines=True).strip()
            with open(join(exp_root, 'log', 'git_info.json'), 'w') as writer:
                json.dump({'branch': git_branch_name,
                           'is_dirty': bool(git_status),
                           'status': git_status,
                           'sha': git_sha},
                          writer, indent=4)
        except subprocess.TimeoutExpired as e:
            logger.exception(e)
            logger.warning("Git info not found. Moving right along...")

    main(opt)
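
The git bookkeeping above leaves a small git_info.json next to the dumped hyper-parameters; reading it back later is straightforward. A hypothetical helper, assuming the exp_root/log layout written by the snippet:

import json
from os.path import join

def load_git_info(exp_root):
    """Load the git_info.json written by the logging block above."""
    with open(join(exp_root, 'log', 'git_info.json')) as reader:
        info = json.load(reader)
    # 'is_dirty' flags uncommitted changes captured via `git status --short`
    return info

# info = load_git_info('/path/to/exp_root')
# print(info['branch'], info['sha'][:8], 'dirty' if info['is_dirty'] else 'clean')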