Example #1
def main(args, unknown_args):
    args, config = parse_args_uargs(args, unknown_args, dump_config=True)
    pprint(args)
    pprint(config)
    set_global_seeds(args.seed)

    assert args.baselogdir is not None or args.logdir is not None

    # Derive a logdir from the model module when one is not given explicitly.
    if args.logdir is None:
        modules_ = prepare_modules(model_dir=args.model_dir)
        logdir = modules_["model"].prepare_logdir(config=config)
        args.logdir = str(pathlib2.Path(args.baselogdir).joinpath(logdir))

    create_if_need(args.logdir)
    save_config(config=config, logdir=args.logdir)
    # Re-import the model/data modules, snapshotting their sources into logdir.
    modules = prepare_modules(model_dir=args.model_dir, dump_dir=args.logdir)

    datasource = modules["data"].DataSource()
    model = modules["model"].prepare_model(config)

    runner = modules["model"].ModelRunner(model=model)
    runner.train(datasource=datasource,
                 args=args,
                 stages_config=config["stages"],
                 verbose=args.verbose)
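A hedged usage sketch of how this entry point might be invoked from the command line. The flag names below are assumptions chosen to match the attributes main() reads (model_dir, logdir, baselogdir, seed, verbose); they are not part of the original snippet.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model-dir", dest="model_dir", required=True)
parser.add_argument("--logdir", default=None)
parser.add_argument("--baselogdir", default=None)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--verbose", action="store_true")

# parse_known_args() yields the unknown extras that
# parse_args_uargs() folds into the config.
args, unknown_args = parser.parse_known_args()
main(args, unknown_args)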
Example #2
def prepare_modules(model_dir, dump_dir=None):
    model_dir = model_dir[:-1] if model_dir.endswith("/") else model_dir
    model_dir_name = model_dir.rsplit("/", 1)[-1]

    new_model_dir = None
    if dump_dir is not None:
        # Snapshot this run's sources under a timestamped directory
        # so the experiment can be reproduced later.
        current_date = datetime.now().strftime("%y-%m-%d-%H-%M-%S-%f")
        new_src_dir = f"src-{current_date}"

        new_model_dir = f"{dump_dir}/{new_src_dir}/{model_dir}"
        create_if_need(new_model_dir)

        # @TODO: hardcoded
        old_pro_dir = os.path.dirname(os.path.abspath(__file__)) + "/../../"
        new_pro_dir = f"{dump_dir}/{new_src_dir}/catalyst/"
        shutil.copytree(old_pro_dir, new_pro_dir)

    pyfiles = [p.stem for p in pathlib2.Path(model_dir).glob("*.py")]

    modules = {}
    for name in pyfiles:
        module_name = f"{model_dir_name}.{name}"
        module_src = f"{model_dir}/{name}.py"

        module = import_module(module_name, module_src)
        modules[name] = module

        if new_model_dir is not None:
            module_dst = f"{new_model_dir}/{name}.py"
            shutil.copy2(module_src, module_dst)

    return modules
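A usage sketch under assumed paths: given a model directory such as experiments/mnist containing model.py and data.py (the paths and a pre-loaded config dict are illustrative, not from the original), prepare_modules imports both files and, because dump_dir is set, also snapshots them next to the logs.

modules = prepare_modules(model_dir="experiments/mnist",
                          dump_dir="logs/run-01")
datasource = modules["data"].DataSource()        # from experiments/mnist/data.py
model = modules["model"].prepare_model(config)   # from experiments/mnist/model.py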
Example #3
    def create_loggers(logdir, loaders):
        create_if_need(logdir)
        # logfile = open("{logdir}/stdout.txt".format(logdir=logdir), "a")
        # sys.stdout = stream_tee(sys.stdout, logfile)

        # One tensorboard logger per loader (e.g. "train", "valid"),
        # keyed by loader name in insertion order.
        loggers = OrderedDict(
            (key, UtilsFactory.create_tflogger(logdir=logdir, name=key))
            for key in loaders)

        return loggers
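An illustrative call; the loader objects are assumptions, and the final line presumes create_tflogger returns a SummaryWriter-style logger exposing add_scalar.

loaders = OrderedDict([("train", train_loader), ("valid", valid_loader)])
loggers = create_loggers(logdir="logs/run-01", loaders=loaders)
loggers["train"].add_scalar("loss", 0.5, global_step=0)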
Example #4
def save_config(config, logdir):
    create_if_need(logdir)
    with open("{}/config.json".format(logdir), "w") as fout:
        json.dump(config, fout, indent=2)
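A quick round-trip check with illustrative values: the config written by save_config can be read straight back with json.load.

config = {"stages": {"stage1": {"epochs": 10}}}
save_config(config=config, logdir="logs/run-01")
with open("logs/run-01/config.json") as fin:
    assert json.load(fin) == config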
Example #5
TIME_TRICK = args.time_trick
CHANGE_TRICK = args.change_trick
ACTION_MIXIN = args.action_mixin
assert ACTION_MIXIN > 0
ACTION_INTERVAL = 1. / ACTION_MIXIN
# linspace gives exactly ACTION_MIXIN + 1 evenly spaced deltas over [0, 1],
# avoiding the float-step endpoint overshoot np.arange can produce.
ACTION_DELTAS = np.linspace(0, 1, ACTION_MIXIN + 1)
ACTION_NOISE = args.action_noise
PARAM_NOISE = args.param_noise
SIDE_TRICK = args.side_trick

IN_CONSENSUS = args.in_consensus
OUT_CONSENSUS = args.out_consensus

NCPU = args.n_cpu
OUTDIR = args.outdir
create_if_need(OUTDIR)

# LOGDIR / LOGDIR_START / LOGDIR_RUN presumably come from args in the
# part of the script elided above, following the same pattern.
assert LOGDIR is not None \
       or (LOGDIR_START is not None and LOGDIR_RUN is not None)
LOGDIR_START = LOGDIR_START or LOGDIR


def algos_by_dir(dir):
    # Collect runs from a directory: each run folder is expected to hold
    # a config.json plus one or more *.pth.tar checkpoints.
    algos = []
    dirs = path.Path(dir).listdir()
    for logpath in dirs:
        config_path = logpath + "/config.json"
        checkpoints = path.Path(logpath).glob("*.pth.tar")
        for checkpoint_path in checkpoints:
            args = argparse.Namespace(config=config_path)
            args, config = parse_args_uargs(args, [])
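            # The example is truncated here; presumably each restored
            # (config, checkpoint_path) pair is turned into an algorithm
            # and appended to algos, with the function finally returning
            # the list.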
Example #6
    def __init__(
            self,
            algorithm,
            state_shape,
            action_shape,
            logdir,
            redis_server=None,
            redis_prefix=None,
            n_workers=2,
            replay_buffer_size=int(1e6),
            batch_size=64,
            start_learning=int(1e3),
            gamma=0.99,
            n_step=1,
            history_len=1,
            epoch_len=int(1e2),
            save_period=10,
            target_update_period=1,
            online_update_period=1,
            weights_sync_period=1,
            max_redis_trials=1000):

        self.algorithm = algorithm

        self.logdir = logdir
        current_date = datetime.now().strftime("%y-%m-%d-%H-%M-%S-%f")
        logpath = f"{logdir}/trainer-{current_date}"
        create_if_need(logpath)
        self.logger = SummaryWriter(logpath)

        self.episodes_queue = mp.Queue()

        self.buffer = BufferDataset(
            state_shape=state_shape,
            action_shape=action_shape,
            max_size=replay_buffer_size,
            history_len=history_len,
            n_step=n_step,
            gamma=gamma)

        self.gamma = gamma
        self.n_step = n_step
        self.history_len = history_len

        self.batch_size = batch_size
        self.n_workers = n_workers

        self.sampler = BufferSampler(
            buffer=self.buffer,
            epoch_len=epoch_len,
            batch_size=batch_size)

        self.loader = DataLoader(
            dataset=self.buffer,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.n_workers,
            pin_memory=torch.cuda.is_available(),
            sampler=self.sampler)

        self.redis_server = redis_server
        self.redis_prefix = redis_prefix
        self.max_redis_trials = max_redis_trials
        self.start_learning = start_learning

        self.epoch = 0
        self.epoch_len = epoch_len

        # (actor_period, critic_period)
        target_update_period = (
            target_update_period
            if isinstance(target_update_period, list)
            else (target_update_period, target_update_period))
        self.actor_update_period, self.critic_update_period = \
            target_update_period

        # (actor_period, critic_period)
        online_update_period = (
            online_update_period
            if isinstance(online_update_period, list)
            else (online_update_period, online_update_period))
        self.actor_grad_period, self.critic_grad_period = \
            online_update_period

        self.save_period = save_period
        self.weights_sync_period = weights_sync_period

        self.actor_updates = 0
        self.critic_updates = 0
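An illustrative construction; the class name Trainer and the algorithm and redis_server objects are assumptions (the snippet never names its class). The list form of target_update_period exercises the (actor_period, critic_period) coercion above.

trainer = Trainer(
    algorithm=algorithm,              # e.g. an off-policy actor-critic algorithm
    state_shape=(24,),
    action_shape=(4,),
    logdir="logs/run-01",
    redis_server=redis_server,        # optional: shared episode storage
    redis_prefix="ddpg",
    batch_size=256,
    target_update_period=[2, 1],      # actor every 2 updates, critic every 1
)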
Example #7
    def __init__(self,
                 actor,
                 env,
                 id,
                 logdir=None,
                 redis_server=None,
                 redis_prefix=None,
                 buffer_size=int(1e4),
                 history_len=1,
                 weights_sync_period=1,
                 mode="infer",
                 resume=None,
                 action_noise_prob=0,
                 action_noise_t=1,
                 random_process=None,
                 param_noise_prob=0,
                 param_noise_d=0.2,
                 param_noise_steps=1000,
                 seeds=None,
                 action_clip=(-1, 1),
                 episode_limit=None,
                 force_store=False,
                 min_episode_steps=None,
                 min_episode_reward=None):

        self._seed = 42 + id
        set_global_seeds(self._seed)

        self._sampler_id = id
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.actor = copy.deepcopy(actor).to(self._device)
        self.env = env
        self.redis_server = redis_server
        self.redis_prefix = redis_prefix or ""
        self.resume = resume
        self.episode_limit = episode_limit or int(2**32 - 2)
        self.force_store = force_store
        self.min_episode_steps = min_episode_steps
        self.min_episode_reward = min_episode_reward
        self.hard_seeds = set()
        # Episode filtering (min steps / min reward) and a fixed seed list
        # are mutually exclusive.
        min_episode_flag_ = \
            min_episode_steps is None and min_episode_reward is None
        assert min_episode_flag_ or seeds is None

        self.min_episode_steps = self.min_episode_steps or -int(1e6)
        self.min_episode_reward = self.min_episode_reward or -int(1e6)

        self.history_len = history_len
        self.buffer_size = buffer_size
        self.weights_sync_period = weights_sync_period
        self.episode_index = 0
        self.action_clip = action_clip

        self.infer = mode == "infer"
        self.seeds = seeds

        self.action_noise_prob = action_noise_prob
        self.action_noise_t = action_noise_t
        self.random_process = random_process or RandomProcess()

        self.param_noise_prob = param_noise_prob
        self.param_noise_d = param_noise_d
        self.param_noise_steps = param_noise_steps

        if self.infer:
            self.action_noise_prob = 0
            self.param_noise_prob = 0

        if logdir is not None:
            current_date = datetime.now().strftime("%y-%m-%d-%H-%M-%S-%f")
            logpath = f"{logdir}/sampler-{mode}-{id}-{current_date}"
            create_if_need(logpath)
            self.logger = SummaryWriter(logpath)
        else:
            self.logger = None

        self.buffer = SamplerBuffer(
            capacity=self.buffer_size,
            state_shape=self.env.observation_space.shape,
            action_shape=self.env.action_space.shape)
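Similarly, an illustrative instantiation; the class name Sampler is an assumption, and actor/env stand for a torch module and a Gym-style environment (the snippet reads env.observation_space.shape and env.action_space.shape).

sampler = Sampler(
    actor=actor,                      # torch.nn.Module policy
    env=env,                          # Gym-style environment
    id=0,                             # also offsets the global seed: 42 + id
    logdir="logs/run-01",
    mode="train",
    action_noise_prob=0.3,
    param_noise_prob=0.3,
)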