Example #1
class Predictor:
    def __init__(self, buffer, cfg, device="cuda"):
        self.device = device
        self.buffer = buffer

        self.model = PredictorModel(cfg["agent"]["rnn_size"])
        self.model = self.model.to(device).train()
        lr = cfg["self_sup"]["lr"]
        self.optim = ParamOptim(params=self.model.parameters(), lr=lr)
        self.ri_mean = self.ri_std = None
        self.ri_momentum = cfg["self_sup"]["ri_momentum"]

    def get_error(self, batch, hx=None, update_stats=False):
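        # predict the next latent state from the current latents and actions;
        # the squared prediction error doubles as the intrinsic reward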
        z = batch["obs"].float()
        action = batch["action"][1:]
        done = batch["done"][:-1]
        z_pred, hx = self.model(z[:-1], action, done, hx)
        err = (z[1:] - z_pred).pow(2).mean(2)

        ri = err.detach()
        if update_stats:
            if self.ri_mean is None:
                self.ri_mean = ri.mean()
                self.ri_std = ri.std()
            else:
                m = self.ri_momentum
                self.ri_mean = m * self.ri_mean + (1 - m) * ri.mean()
                self.ri_std = m * self.ri_std + (1 - m) * ri.std()
        if self.ri_mean is not None:
            ri = (ri[..., None] - self.ri_mean) / self.ri_std
        else:
            ri = 0
        return err.mean(), ri, hx

    def train(self):
        # used only for pretraining; the main training loop is in the DQN learner
        batch_size = 64
        sample_steps = 100
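        # exclude start indices that do not have sample_steps valid preceding steps
        # (start of the buffer, or the region just past the write cursor after wrap-around)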
        if len(self.buffer) < self.buffer.maxlen:
            no_prev = set(range(sample_steps))
        else:
            no_prev = set((self.buffer.cursor + i) % self.buffer.maxlen
                          for i in range(sample_steps))
        all_idx = list(set(range(len(self.buffer))) - no_prev)
        idx0 = torch.tensor(random.choices(all_idx, k=batch_size))
        idx1 = torch.tensor(
            random.choices(range(self.buffer.num_env), k=batch_size))
        batch = self.buffer.query(idx0, idx1, sample_steps)
        loss = self.get_error(batch, update_stats=True)[0]
        self.optim.step(loss)
        return {"loss_predictor": loss.item()}

    def load(self):
        cp = torch.load("models/predictor.pt", map_location=self.device)
        self.ri_mean, self.ri_std, model = cp
        self.model.load_state_dict(model)

    def save(self):
        data = [self.ri_mean, self.ri_std, self.model.state_dict()]
        torch.save(data, "models/predictor.pt")
Example #2
class Learner:
    def __init__(self, model, buffer, predictor, cfg):
        model_t = deepcopy(model)
        model_t = model_t.cuda().eval()
        self.model, self.model_t = model, model_t
        self.buffer = buffer
        self.predictor = predictor
        self.optim = ParamOptim(params=model.parameters(), **cfg["optim"])

        self.batch_size = cfg["agent"]["batch_size"]
        self.unroll = cfg["agent"]["unroll"]
        self.unroll_prefix = cfg["agent"]["burnin"] + 1
        self.sample_steps = self.unroll_prefix + self.unroll

        self.target_tau = cfg["agent"]["target_tau"]
        self.td_error = partial(get_td_error, model=model, model_t=model_t, cfg=cfg)
        self.add_ri = cfg["add_ri"]

    def _update_target(self):
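        # soft (Polyak) update: the target network slowly tracks the online network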
        for t, s in zip(self.model_t.parameters(), self.model.parameters()):
            t.data.copy_(t.data * (1.0 - self.target_tau) + s.data * self.target_tau)

    def loss_uniform(self):
        if len(self.buffer) < self.buffer.maxlen:
            no_prev = set(range(self.sample_steps))
        else:
            no_prev = set(
                (self.buffer.cursor + i) % self.buffer.maxlen
                for i in range(self.sample_steps)
            )
        all_idx = list(set(range(len(self.buffer))) - no_prev)
        idx0 = torch.tensor(random.choices(all_idx, k=self.batch_size))
        idx1 = torch.tensor(
            random.choices(range(self.buffer.num_env), k=self.batch_size)
        )
        batch = self.buffer.query(idx0, idx1, self.sample_steps)
        loss_pred, ri, _ = self.predictor.get_error(batch, update_stats=True)
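        # the normalized prediction error can be added to the reward as an intrinsic bonus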
        if self.add_ri:
            batch["reward"][1:] += ri
        td_error, log = self.td_error(batch, None)
        loss = td_error.pow(2).sum(0)
        return loss, loss_pred, ri, log

    def train(self, need_stat=True):
        loss, loss_pred, ri, log = self.loss_uniform()
        self.optim.step(loss.mean())
        self.predictor.optim.step(loss_pred)
        self._update_target()

        if need_stat:
            log.update(
                {
                    "ri_std": ri.std(),
                    "ri_mean": ri.mean(),
                    "ri_run_mean": self.predictor.ri_mean,
                    "ri_run_std": self.predictor.ri_std,
                    "loss_predictor": loss_pred.mean().detach(),
                }
            )
        return log
Example #3
class CPC:
    buffer: Buffer
    num_action: int
    frame_stack: int = 1
    batch_size: int = 32
    unroll: int = 32
    emb_size: int = 32
    lr: float = 5e-4
    device: str = "cuda"

    def __post_init__(self):
        self.model = CPCModel(self.num_action, self.emb_size, self.frame_stack)
        self.model = self.model.train().to(self.device)
        self.optim = ParamOptim(params=self.model.parameters(), lr=self.lr)
        self.target = torch.arange(self.batch_size * self.unroll).to(
            self.device)

    def train(self):
        # burnin = 2, fstack = 4, unroll = 2
        # idx 0 1 2 3 4 5 6 7
        # bin p p p b b b
        #             a a
        #             hx
        # rol     p p p o o o
        #                 a a

        sample_steps = self.frame_stack + self.unroll

        if len(self.buffer) < self.buffer.maxlen:
            no_prev = set(range(sample_steps))
        else:
            no_prev = set((self.buffer.cursor + i) % self.buffer.maxlen
                          for i in range(sample_steps))
        all_idx = list(set(range(len(self.buffer))) - no_prev)
        idx0 = torch.tensor(random.choices(all_idx, k=self.batch_size))
        idx1 = torch.tensor(
            random.choices(range(self.buffer.num_env), k=self.batch_size))
        batch = self.buffer.query(idx0, idx1, sample_steps)

        obs = batch["obs"]
        action = batch["action"][self.frame_stack:]
        done = batch["done"]
        z, z_pred, _ = self.model(obs, action, done)

        size = self.batch_size * self.unroll
        z = z.view(size, self.emb_size)
        z_pred = z_pred.view(size, self.emb_size)
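        # InfoNCE: score every embedding against every prediction; the matching
        # pair for row i is column i (self.target = arange)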
        logits = z @ z_pred.t()
        loss = cross_entropy(logits, self.target)
        acc = (logits.argmax(-1) == self.target).float().mean()
        self.optim.step(loss)
        return {"loss_cpc": loss.item(), "acc_cpc": acc.item()}

    def load(self):
        cp = torch.load("models/cpc.pt", map_location=self.device)
        self.model.load_state_dict(cp)

    def save(self):
        torch.save(self.model.state_dict(), "models/cpc.pt")
Example #4
class IDF:
    buffer: Buffer
    num_action: int
    emb_size: int = 32
    batch_size: int = 256
    lr: float = 5e-4
    frame_stack: int = 1
    device: str = "cuda"

    def __post_init__(self):
        self.encoder = mnih_cnn(self.frame_stack, self.emb_size)
        self.encoder = self.encoder.to(self.device).train()
        self.clf = nn.Sequential(
            nn.Linear(self.emb_size * 2, 128),
            nn.ReLU(),
            nn.Linear(128, self.num_action),
        )
        self.clf = self.clf.to(self.device).train()
        params = chain(self.encoder.parameters(), self.clf.parameters())
        self.optim = ParamOptim(lr=self.lr, params=params)

    def train(self):
        # 0 1 2 3 4
        # p p p o o
        #         a

        sample_steps = self.frame_stack + 1
        if len(self.buffer) < self.buffer.maxlen:
            no_prev = set(range(sample_steps))
        else:
            no_prev = set((self.buffer.cursor + i) % self.buffer.maxlen
                          for i in range(sample_steps))
        all_idx = list(set(range(len(self.buffer))) - no_prev)
        idx0 = torch.tensor(random.choices(all_idx, k=self.batch_size))
        idx1 = torch.tensor(
            random.choices(range(self.buffer.num_env), k=self.batch_size))
        batch = self.buffer.query(idx0, idx1, sample_steps)
        obs = prepare_obs(batch["obs"], batch["done"], self.frame_stack)
        action = batch["action"][-1, :, 0]

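        # inverse dynamics: embed the two consecutive (frame-stacked) observations
        # and classify the action taken between them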
        x0, x1 = self.encoder(obs[0]), self.encoder(obs[1])
        x = torch.cat([x0, x1], dim=-1)
        y = self.clf(relu(x))
        loss_idf = cross_entropy(y, action)
        acc_idf = (y.argmax(-1) == action).float().mean()

        self.optim.step(loss_idf)
        return {"loss_idf": loss_idf, "acc_idf": acc_idf}

    def load(self):
        cp = torch.load("models/idf.pt", map_location=self.device)
        self.encoder.load_state_dict(cp[0])
        self.clf.load_state_dict(cp[1])

    def save(self):
        cp = [self.encoder.state_dict(), self.clf.state_dict()]
        torch.save(cp, "models/idf.pt")
Example #5
def train(cfg_name, resume):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'running on {device}')
    cfg = load_cfg(cfg_name)
    log = Logger(device=device)
    envs = make_vec_envs(**cfg['env'])
    model, n_start = init_model(cfg, envs, device, resume)
    runner = EnvRunner(rollout_size=cfg['train']['rollout_size'],
                       envs=envs,
                       model=model,
                       device=device)
    optim = ParamOptim(**cfg['optimizer'], params=model.parameters())
    agent = Agent(model=model, optim=optim, **cfg['agent'])

    cp_iter = cfg['train']['checkpoint_every']
    log_iter = cfg['train']['log_every']
    n_end = cfg['train']['steps']
    cp_name = cfg['train']['checkpoint_name']

    for n_iter, rollout in zip(trange(n_start, n_end), runner):
        agent_log = agent.update(rollout)

        if n_iter % log_iter == 0:
            log.output({**agent_log, **runner.get_logs()}, n_iter)

        if n_iter > n_start and n_iter % cp_iter == 0:
            f = cp_name.format(n_iter=n_iter // cp_iter)
            torch.save(model.state_dict(), f)
Example #6
def train(cfg_name, env_name):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'running on {device}')
    cfg = load_cfg(cfg_name)
    log = Logger(device=device)
    if env_name == 'OT':
        envs = make_obstacle_tower(cfg['train']['num_env'])
    else:
        envs = make_vec_envs(env_name + 'NoFrameskip-v4',
                             cfg['train']['num_env'])

    emb = cfg['embedding']
    model = ActorCritic(output_size=envs.action_space.n,
                        device=device,
                        emb_size=emb['size'])
    model.train().to(device=device)

    runner = EnvRunner(
        rollout_size=cfg['train']['rollout_size'],
        envs=envs,
        model=model,
        device=device,
        emb_stack=emb['history_size'],
    )

    optim = ParamOptim(**cfg['optimizer'], params=model.parameters())
    agent = Agent(model=model, optim=optim, **cfg['agent'])

    n_start = 0
    log_iter = cfg['train']['log_every']
    n_end = cfg['train']['steps']

    log.log.add_text('env', env_name)

    for n_iter, rollout in zip(trange(n_start, n_end), runner):
        progress = n_iter / n_end
        optim.update(progress)
        agent_log = agent.update(rollout, progress)
        if n_iter % log_iter == 0:
            log.output({**agent_log, **runner.get_logs()}, n_iter)

    reward = eval_model(model, envs, emb['history_size'], emb['size'], device)
    reward_str = f'{reward.mean():.2f} ± {reward.std():.2f}'
    log.log.add_text('final', reward_str)
    log.log.close()
Example #7
class STDIM(BaseEncoder):
    def __post_init__(self):
        num_layer = 64
        self.encoder = Conv(self.emb_size, num_layer).to(self.device)
        self.classifier1 = nn.Linear(self.emb_size, num_layer).to(self.device)
        self.classifier2 = nn.Linear(num_layer, num_layer).to(self.device)
        self.encoder.train()
        self.classifier1.train()
        self.classifier2.train()
        self.target = torch.arange(self.batch_size).to(self.device)

        params = (list(self.encoder.parameters()) +
                  list(self.classifier1.parameters()) +
                  list(self.classifier2.parameters()))
        self.optim = ParamOptim(lr=self.lr, params=params)

    def _step(self, x1, x2):
        x1_loc, x1_glob = self.encoder.forward_blocks(x1)
        x2_loc = self.encoder.forward_block1(x2)
        sy, sx = x1_loc.shape[2:]
        loss_loc, loss_glob = 0, 0
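        # for each spatial position, the x2 feature is the positive; classifier1 scores it
        # against the global feature of x1, classifier2 against the local feature of x1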
        for y in range(sy):
            for x in range(sx):
                positive = x2_loc[:, :, y, x]

                predictions = self.classifier1(x1_glob)
                logits = torch.matmul(predictions, positive.t())
                loss_glob += F.cross_entropy(logits, self.target)

                predictions = self.classifier2(x1_loc[:, :, y, x])
                logits = torch.matmul(predictions, positive.t())
                loss_loc += F.cross_entropy(logits, self.target)
        loss_loc /= (sx * sy)
        loss_glob /= (sx * sy)
        return {
            'sum': self.optim.step(loss_glob + loss_loc),
            'glob': loss_glob,
            'loc': loss_loc,
        }
Example #8
class IIC(BaseEncoder):
    def __post_init__(self):
        num_heads = 6
        self.encoder = Encoder(self.emb_size, num_heads).to(self.device)
        self.encoder.train()
        self.optim = ParamOptim(lr=self.lr, params=self.encoder.parameters())
        self.head_loss = [deque(maxlen=256) for _ in range(num_heads)]

    def _step(self, x1, x2):
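        # IIC objective: each head produces cluster assignments for the two views;
        # IID_loss pushes the two assignments toward maximal mutual information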
        heads = zip(self.encoder.forward_heads(x1),
                    self.encoder.forward_heads(x2))
        losses = {}
        for i, (h1, h2) in enumerate(heads):
            losses[f'head_{i}'] = loss = IID_loss(h1, h2)
            self.head_loss[i].append(loss.item())
        loss_mean = sum(losses.values()) / len(losses)
        losses['mean'] = self.optim.step(loss_mean)
        return losses

    def select_head(self):
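        # choose the head with the lowest recent loss; head 0 is kept as auxiliary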
        x = torch.tensor(self.head_loss).mean(-1)
        x[0] = 1  # 0 head is auxiliary
        self.encoder.head_main = x.argmin().item()
        return self.encoder.head_main
Example #9
class Learner:
    def __init__(self, model, buffer, predictor, cfg):
        num_env = cfg["agent"]["actors"]
        model_t = deepcopy(model)
        model_t = model_t.cuda().eval()
        self.model, self.model_t = model, model_t
        self.buffer = buffer
        self.predictor = predictor
        self.optim = ParamOptim(params=model.parameters(), **cfg["optim"])

        self.batch_size = cfg["agent"]["batch_size"]
        self.unroll = cfg["agent"]["unroll"]
        self.unroll_prefix = (cfg["agent"]["burnin"] + cfg["agent"]["n_step"] +
                              cfg["agent"]["frame_stack"] - 1)
        self.sample_steps = self.unroll_prefix + self.unroll
        self.hx_shift = cfg["agent"]["frame_stack"] - 1
        num_unrolls = (self.buffer.maxlen - self.unroll_prefix) // self.unroll

        if cfg["buffer"]["prior_exp"] > 0:
            self.sampler = Sampler(
                num_env=num_env,
                maxlen=num_unrolls,
                prior_exp=cfg["buffer"]["prior_exp"],
                importance_sampling_exp=cfg["buffer"]["importance_sampling_exp"],
            )
            self.s2b = torch.empty(num_unrolls, dtype=torch.long)
            self.hxs = torch.empty(num_unrolls, num_env, 512, device="cuda")
            self.hx_cursor = 0
        else:
            self.sampler = None

        self.target_tau = cfg["agent"]["target_tau"]
        self.td_error = partial(get_td_error,
                                model=model,
                                model_t=model_t,
                                cfg=cfg)

    def _update_target(self):
        for t, s in zip(self.model_t.parameters(), self.model.parameters()):
            t.data.copy_(t.data * (1.0 - self.target_tau) +
                         s.data * self.target_tau)

    def append(self, step, hx, n_iter):
        self.buffer.append(step)

        if self.sampler is not None:
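            # remember the RNN state taken at the start of each unroll window; once a
            # full window is in the buffer, set its initial priority from the TD error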
            if (n_iter + 1) % self.unroll == self.hx_shift:
                self.hxs[self.hx_cursor] = hx
                self.hx_cursor = (self.hx_cursor + 1) % len(self.hxs)

            k = n_iter - self.unroll_prefix
            if k > 0 and (k + 1) % self.unroll == 0:
                self.s2b[self.sampler.cursor] = self.buffer.cursor - 1
                x = self.buffer.get_recent(self.sample_steps)
                hx = self.hxs[self.sampler.cursor]
                with torch.no_grad():
                    loss, _ = self.td_error(x, hx)
                self.sampler.append(tde_to_prior(loss))

                if len(self.sampler) == self.sampler.maxlen:
                    idx_new = self.s2b[self.sampler.cursor - 1]
                    idx_old = self.s2b[self.sampler.cursor]
                    d = (idx_old - idx_new) % self.buffer.maxlen
                    assert self.unroll_prefix + self.unroll <= d
                    assert d < self.unroll_prefix + self.unroll * 2

    def loss_sampler(self, need_stat):
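        # prioritized replay: draw unroll windows according to their TD-error priority
        # and weight the losses by the importance-sampling correction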
        idx0, idx1, weights = self.sampler.sample(self.batch_size)
        weights = weights.cuda()
        batch = self.buffer.query(self.s2b[idx0], idx1, self.sample_steps)
        hx = self.hxs[idx0, idx1]
        loss_pred, ri, _ = self.predictor.get_error(batch, update_stats=True)
        batch["reward"][1:] += ri
        td_error, log = self.td_error(batch, hx, need_stat=need_stat)
        self.sampler.update_prior(idx0, idx1, tde_to_prior(td_error))
        loss = td_error.pow(2).sum(0) * weights[..., None]
        loss_pred = loss_pred.sum(0) * weights[..., None]
        return loss, loss_pred, ri, log

    def loss_uniform(self, need_stat):
        if len(self.buffer) < self.buffer.maxlen:
            no_prev = set(range(self.sample_steps))
        else:
            no_prev = set((self.buffer.cursor + i) % self.buffer.maxlen
                          for i in range(self.sample_steps))
        all_idx = list(set(range(len(self.buffer))) - no_prev)
        idx0 = torch.tensor(random.choices(all_idx, k=self.batch_size))
        idx1 = torch.tensor(
            random.choices(range(self.buffer.num_env), k=self.batch_size))
        batch = self.buffer.query(idx0, idx1, self.sample_steps)
        loss_pred, ri, _ = self.predictor.get_error(batch, update_stats=True)
        batch["reward"][1:] += ri
        td_error, log = self.td_error(batch, None, need_stat=need_stat)
        loss = td_error.pow(2).sum(0)
        loss_pred = loss_pred.sum(0)
        return loss, loss_pred, ri, log

    def train(self, need_stat=True):
        loss_f = self.loss_uniform if self.sampler is None else self.loss_sampler
        loss, loss_pred, ri, log = loss_f(need_stat)
        self.optim.step(loss.mean())
        self.predictor.optim.step(loss_pred.mean())
        self._update_target()

        if need_stat:
            log.update({
                "ri_std": ri.std(),
                "ri_mean": ri.mean(),
                "ri_run_mean": self.predictor.ri_mean,
                "ri_run_std": self.predictor.ri_std,
                "loss_predictor": loss_pred.mean(),
            })
            if self.sampler is not None:
                log.update(self.sampler.stats())
        return log
Example #10
class Predictor:
    def __init__(self, buffer, encoder, num_action, cfg, device="cuda"):
        self.device = device
        self.buffer = buffer
        self.encoder = encoder

        self.frame_stack = cfg["w_mse"]["frame_stack"]
        self.emb_size = cfg["w_mse"]["emb_size"]
        self.rnn_size = cfg["w_mse"]["rnn_size"]

        self.model = PredictorModel(
            num_action, self.frame_stack, self.emb_size, self.rnn_size
        )
        self.model = self.model.to(device).train()
        lr = cfg["w_mse"]["lr"]
        self.optim = ParamOptim(params=self.model.parameters(), lr=lr)
        self.ri_mean = self.ri_std = None
        self.ri_momentum = cfg["w_mse"]["ri_momentum"]
        self.ri_clamp = cfg["w_mse"].get("ri_clamp")
        self.ri_scale = cfg["ri_scale"]

    def get_error(self, batch, hx=None, update_stats=False):
        # p p p o o o
        #         a a
        obs = prepare_obs(batch["obs"], batch["done"], self.frame_stack)
        steps, batch_size, *obs_shape = obs.shape
        obs = obs.view(batch_size * steps, *obs_shape)
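        # embed the observations with the fixed encoder; no gradients flow into it here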
        with torch.no_grad():
            z = self.encoder(obs)
        z = z.view(steps, batch_size, self.emb_size)

        action = batch["action"][self.frame_stack :]
        done = batch["done"][self.frame_stack - 1 : -1]
        z_pred, hx = self.model(z[:-1], action, done, hx)
        err = (z[1:] - z_pred).pow(2).mean(2)

        ri = err.detach()
        if update_stats:
            if self.ri_mean is None:
                self.ri_mean = ri.mean()
                self.ri_std = ri.std()
            else:
                m = self.ri_momentum
                self.ri_mean = m * self.ri_mean + (1 - m) * ri.mean()
                self.ri_std = m * self.ri_std + (1 - m) * ri.std()
        if self.ri_mean is not None:
            ri = (ri[..., None] - self.ri_mean) / self.ri_std
            if self.ri_clamp is not None:
                ri.clamp_(-self.ri_clamp, self.ri_clamp)
            ri *= self.ri_scale
        else:
            ri = 0
        return err, ri, hx

    def train(self):
        # used only for pretraining; the main training loop is in the DQN learner
        batch_size = 16
        sample_steps = self.frame_stack - 1 + 100
        if len(self.buffer) < self.buffer.maxlen:
            no_prev = set(range(sample_steps))
        else:
            no_prev = set(
                (self.buffer.cursor + i) % self.buffer.maxlen
                for i in range(sample_steps)
            )
        all_idx = list(set(range(len(self.buffer))) - no_prev)
        idx0 = torch.tensor(random.choices(all_idx, k=batch_size))
        idx1 = torch.tensor(random.choices(range(self.buffer.num_env), k=batch_size))
        batch = self.buffer.query(idx0, idx1, sample_steps)
        er = self.get_error(batch, update_stats=True)[0]
        loss = er.sum(0).mean()
        self.optim.step(loss)
        return {"loss_predictor": loss.item()}

    def load(self):
        cp = torch.load("models/predictor.pt", map_location=self.device)
        self.ri_mean, self.ri_std, model = cp
        self.model.load_state_dict(model)

    def save(self):
        data = [self.ri_mean, self.ri_std, self.model.state_dict()]
        torch.save(data, "models/predictor.pt")