class Predictor:
    def __init__(self, buffer, cfg, device="cuda"):
        self.device = device
        self.buffer = buffer
        self.model = PredictorModel(cfg["agent"]["rnn_size"])
        self.model = self.model.to(device).train()
        lr = cfg["self_sup"]["lr"]
        self.optim = ParamOptim(params=self.model.parameters(), lr=lr)
        self.ri_mean = self.ri_std = None
        self.ri_momentum = cfg["self_sup"]["ri_momentum"]

    def get_error(self, batch, hx=None, update_stats=False):
        z = batch["obs"].float()
        action = batch["action"][1:]
        done = batch["done"][:-1]
        z_pred, hx = self.model(z[:-1], action, done, hx)
        # per-step squared prediction error
        err = (z[1:] - z_pred).pow(2).mean(2)
        ri = err.detach()
        if update_stats:
            if self.ri_mean is None:
                self.ri_mean = ri.mean()
                self.ri_std = ri.std()
            else:
                m = self.ri_momentum
                self.ri_mean = m * self.ri_mean + (1 - m) * ri.mean()
                self.ri_std = m * self.ri_std + (1 - m) * ri.std()
        if self.ri_mean is not None:
            # normalize the error with running statistics to get the intrinsic reward
            ri = (ri[..., None] - self.ri_mean) / self.ri_std
        else:
            ri = 0
        return err.mean(), ri, hx

    def train(self):
        # used only for pretraining; the main training loop is in the DQN learner
        batch_size = 64
        sample_steps = 100
        if len(self.buffer) < self.buffer.maxlen:
            no_prev = set(range(sample_steps))
        else:
            no_prev = set(
                (self.buffer.cursor + i) % self.buffer.maxlen
                for i in range(sample_steps)
            )
        all_idx = list(set(range(len(self.buffer))) - no_prev)
        idx0 = torch.tensor(random.choices(all_idx, k=batch_size))
        idx1 = torch.tensor(
            random.choices(range(self.buffer.num_env), k=batch_size))
        batch = self.buffer.query(idx0, idx1, sample_steps)
        loss = self.get_error(batch, update_stats=True)[0]
        self.optim.step(loss)
        return {"loss_predictor": loss.item()}

    def load(self):
        cp = torch.load("models/predictor.pt", map_location=self.device)
        self.ri_mean, self.ri_std, model = cp
        self.model.load_state_dict(model)

    def save(self):
        data = [self.ri_mean, self.ri_std, self.model.state_dict()]
        torch.save(data, "models/predictor.pt")
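# --- Illustration (not part of the original module) ---------------------------
# A minimal, self-contained sketch of the intrinsic-reward normalization used in
# Predictor.get_error above: the squared prediction error is tracked with an
# exponential moving average of its mean and std, and the normalized value is
# used as the intrinsic reward. Plain PyTorch; all names here are illustrative.
import torch


class RunningNormSketch:
    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.mean = self.std = None

    def __call__(self, err, update_stats=True):
        # err: tensor of per-step prediction errors, detached from the graph
        if update_stats:
            if self.mean is None:
                self.mean, self.std = err.mean(), err.std()
            else:
                m = self.momentum
                self.mean = m * self.mean + (1 - m) * err.mean()
                self.std = m * self.std + (1 - m) * err.std()
        if self.mean is None:
            return err
        return (err - self.mean) / self.std


# usage: ri = RunningNormSketch()(prediction_error.detach())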
class Learner:
    def __init__(self, model, buffer, predictor, cfg):
        model_t = deepcopy(model)
        model_t = model_t.cuda().eval()
        self.model, self.model_t = model, model_t
        self.buffer = buffer
        self.predictor = predictor
        self.optim = ParamOptim(params=model.parameters(), **cfg["optim"])
        self.batch_size = cfg["agent"]["batch_size"]
        self.unroll = cfg["agent"]["unroll"]
        self.unroll_prefix = cfg["agent"]["burnin"] + 1
        self.sample_steps = self.unroll_prefix + self.unroll
        self.target_tau = cfg["agent"]["target_tau"]
        self.td_error = partial(get_td_error, model=model, model_t=model_t, cfg=cfg)
        self.add_ri = cfg["add_ri"]

    def _update_target(self):
        for t, s in zip(self.model_t.parameters(), self.model.parameters()):
            t.data.copy_(t.data * (1.0 - self.target_tau) + s.data * self.target_tau)

    def loss_uniform(self):
        if len(self.buffer) < self.buffer.maxlen:
            no_prev = set(range(self.sample_steps))
        else:
            no_prev = set(
                (self.buffer.cursor + i) % self.buffer.maxlen
                for i in range(self.sample_steps)
            )
        all_idx = list(set(range(len(self.buffer))) - no_prev)
        idx0 = torch.tensor(random.choices(all_idx, k=self.batch_size))
        idx1 = torch.tensor(
            random.choices(range(self.buffer.num_env), k=self.batch_size)
        )
        batch = self.buffer.query(idx0, idx1, self.sample_steps)
        loss_pred, ri, _ = self.predictor.get_error(batch, update_stats=True)
        if self.add_ri:
            batch["reward"][1:] += ri
        td_error, log = self.td_error(batch, None)
        loss = td_error.pow(2).sum(0)
        return loss, loss_pred, ri, log

    def train(self, need_stat=True):
        loss, loss_pred, ri, log = self.loss_uniform()
        self.optim.step(loss.mean())
        self.predictor.optim.step(loss_pred)
        self._update_target()
        if need_stat:
            log.update(
                {
                    "ri_std": ri.std(),
                    "ri_mean": ri.mean(),
                    "ri_run_mean": self.predictor.ri_mean,
                    "ri_run_std": self.predictor.ri_std,
                    "loss_predictor": loss_pred.mean().detach(),
                }
            )
        return log
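# --- Illustration (not part of the original module) ---------------------------
# Learner._update_target above performs a soft ("Polyak") update of the target
# network. A minimal sketch of the same parameter interpolation on two toy
# networks; names are illustrative.
import torch
from torch import nn

online = nn.Linear(4, 2)
target = nn.Linear(4, 2)
target.load_state_dict(online.state_dict())

tau = 0.01  # mixing coefficient, analogous to cfg["agent"]["target_tau"]
with torch.no_grad():
    for t, s in zip(target.parameters(), online.parameters()):
        # target <- (1 - tau) * target + tau * online
        t.data.copy_(t.data * (1.0 - tau) + s.data * tau)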
@dataclass
class CPC:
    buffer: Buffer
    num_action: int
    frame_stack: int = 1
    batch_size: int = 32
    unroll: int = 32
    emb_size: int = 32
    lr: float = 5e-4
    device: str = "cuda"

    def __post_init__(self):
        self.model = CPCModel(self.num_action, self.emb_size, self.frame_stack)
        self.model = self.model.train().to(self.device)
        self.optim = ParamOptim(params=self.model.parameters(), lr=self.lr)
        self.target = torch.arange(self.batch_size * self.unroll).to(self.device)

    def train(self):
        # burnin = 2, fstack = 4, unroll = 2
        # idx 0 1 2 3 4 5 6 7
        # bin p p p b b b
        # a a
        # hx
        # rol p p p o o o
        # a a
        sample_steps = self.frame_stack + self.unroll
        if len(self.buffer) < self.buffer.maxlen:
            no_prev = set(range(sample_steps))
        else:
            no_prev = set(
                (self.buffer.cursor + i) % self.buffer.maxlen
                for i in range(sample_steps)
            )
        all_idx = list(set(range(len(self.buffer))) - no_prev)
        idx0 = torch.tensor(random.choices(all_idx, k=self.batch_size))
        idx1 = torch.tensor(
            random.choices(range(self.buffer.num_env), k=self.batch_size))
        batch = self.buffer.query(idx0, idx1, sample_steps)
        obs = batch["obs"]
        action = batch["action"][self.frame_stack:]
        done = batch["done"]
        z, z_pred, _ = self.model(obs, action, done)
        size = self.batch_size * self.unroll
        z = z.view(size, self.emb_size)
        z_pred = z_pred.view(size, self.emb_size)
        logits = z @ z_pred.t()
        loss = cross_entropy(logits, self.target)
        acc = (logits.argmax(-1) == self.target).float().mean()
        self.optim.step(loss)
        return {"loss_cpc": loss.item(), "acc_cpc": acc}

    def load(self):
        cp = torch.load("models/cpc.pt", map_location=self.device)
        self.model.load_state_dict(cp)

    def save(self):
        torch.save(self.model.state_dict(), "models/cpc.pt")
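# --- Illustration (not part of the original module) ---------------------------
# The contrastive objective in CPC.train above: embeddings z and their
# predictions z_pred are flattened to (batch * unroll, emb), the full score
# matrix z @ z_pred.T is built, and each row's positive is the matching index,
# so the targets are simply arange(batch * unroll). Minimal sketch with random
# embeddings; names are illustrative.
import torch
from torch.nn.functional import cross_entropy

size, emb = 8, 32                   # stands in for batch_size * unroll, emb_size
z = torch.randn(size, emb)          # encoder outputs
z_pred = torch.randn(size, emb)     # autoregressive predictions
logits = z @ z_pred.t()             # (size, size) similarity matrix
target = torch.arange(size)         # positives on the diagonal
loss = cross_entropy(logits, target)
acc = (logits.argmax(-1) == target).float().mean()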
@dataclass
class IDF:
    buffer: Buffer
    num_action: int
    emb_size: int = 32
    batch_size: int = 256
    lr: float = 5e-4
    frame_stack: int = 1
    device: str = "cuda"

    def __post_init__(self):
        self.encoder = mnih_cnn(self.frame_stack, self.emb_size)
        self.encoder = self.encoder.to(self.device).train()
        self.clf = nn.Sequential(
            nn.Linear(self.emb_size * 2, 128),
            nn.ReLU(),
            nn.Linear(128, self.num_action),
        )
        self.clf = self.clf.to(self.device).train()
        params = chain(self.encoder.parameters(), self.clf.parameters())
        self.optim = ParamOptim(lr=self.lr, params=params)

    def train(self):
        # 0 1 2 3 4
        # p p p o o
        # a
        sample_steps = self.frame_stack + 1
        if len(self.buffer) < self.buffer.maxlen:
            no_prev = set(range(sample_steps))
        else:
            no_prev = set(
                (self.buffer.cursor + i) % self.buffer.maxlen
                for i in range(sample_steps)
            )
        all_idx = list(set(range(len(self.buffer))) - no_prev)
        idx0 = torch.tensor(random.choices(all_idx, k=self.batch_size))
        idx1 = torch.tensor(
            random.choices(range(self.buffer.num_env), k=self.batch_size))
        batch = self.buffer.query(idx0, idx1, sample_steps)
        obs = prepare_obs(batch["obs"], batch["done"], self.frame_stack)
        action = batch["action"][-1, :, 0]
        x0, x1 = self.encoder(obs[0]), self.encoder(obs[1])
        x = torch.cat([x0, x1], dim=-1)
        y = self.clf(relu(x))
        loss_idf = cross_entropy(y, action)
        acc_idf = (y.argmax(-1) == action).float().mean()
        self.optim.step(loss_idf)
        return {"loss_idf": loss_idf, "acc_idf": acc_idf}

    def load(self):
        cp = torch.load("models/idf.pt", map_location=self.device)
        self.encoder.load_state_dict(cp[0])
        self.clf.load_state_dict(cp[1])

    def save(self):
        cp = [self.encoder.state_dict(), self.clf.state_dict()]
        torch.save(cp, "models/idf.pt")
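# --- Illustration (not part of the original module) ---------------------------
# The inverse-dynamics objective in IDF.train above: two consecutive
# observations are encoded, the embeddings are concatenated, and a classifier
# predicts which action was taken between them. Minimal sketch with a toy
# linear encoder standing in for mnih_cnn; all names are illustrative.
import torch
from torch import nn
from torch.nn.functional import cross_entropy, relu

num_action, emb, batch = 6, 32, 16
encoder = nn.Linear(84 * 84, emb)                 # toy stand-in for the CNN encoder
clf = nn.Sequential(nn.Linear(emb * 2, 128), nn.ReLU(), nn.Linear(128, num_action))

obs_t = torch.randn(batch, 84 * 84)               # observation at step t
obs_tp1 = torch.randn(batch, 84 * 84)             # observation at step t + 1
action = torch.randint(num_action, (batch,))      # action taken at step t

x = torch.cat([encoder(obs_t), encoder(obs_tp1)], dim=-1)
logits = clf(relu(x))
loss = cross_entropy(logits, action)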
class STDIM(BaseEncoder):
    def __post_init__(self):
        num_layer = 64
        self.encoder = Conv(self.emb_size, num_layer).to(self.device)
        self.classifier1 = nn.Linear(self.emb_size, num_layer).to(self.device)
        self.classifier2 = nn.Linear(num_layer, num_layer).to(self.device)
        self.encoder.train()
        self.classifier1.train()
        self.classifier2.train()
        self.target = torch.arange(self.batch_size).to(self.device)
        params = (
            list(self.encoder.parameters())
            + list(self.classifier1.parameters())
            + list(self.classifier2.parameters())
        )
        self.optim = ParamOptim(lr=self.lr, params=params)

    def _step(self, x1, x2):
        x1_loc, x1_glob = self.encoder.forward_blocks(x1)
        x2_loc = self.encoder.forward_block1(x2)
        sy, sx = x1_loc.shape[2:]
        loss_loc, loss_glob = 0, 0
        for y in range(sy):
            for x in range(sx):
                positive = x2_loc[:, :, y, x]
                predictions = self.classifier1(x1_glob)
                logits = torch.matmul(predictions, positive.t())
                loss_glob += F.cross_entropy(logits, self.target)

                predictions = self.classifier2(x1_loc[:, :, y, x])
                logits = torch.matmul(predictions, positive.t())
                loss_loc += F.cross_entropy(logits, self.target)
        loss_loc /= sx * sy
        loss_glob /= sx * sy
        return {
            'sum': self.optim.step(loss_glob + loss_loc),
            'glob': loss_glob,
            'loc': loss_loc,
        }
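# --- Illustration (not part of the original module) ---------------------------
# The global-local contrastive score in STDIM._step above: the global embedding
# of one frame is projected and matched against the local feature map of the
# next frame at each spatial position, with the other items in the batch acting
# as negatives. Minimal sketch of the global branch with random features; names
# are illustrative.
import torch
from torch import nn
import torch.nn.functional as F

batch, emb, num_layer, sy, sx = 8, 32, 64, 2, 2
x1_glob = torch.randn(batch, emb)                 # global embedding of frame t
x2_loc = torch.randn(batch, num_layer, sy, sx)    # local features of frame t + 1
classifier1 = nn.Linear(emb, num_layer)
target = torch.arange(batch)

loss_glob = 0.0
for y in range(sy):
    for x in range(sx):
        positive = x2_loc[:, :, y, x]             # (batch, num_layer)
        logits = classifier1(x1_glob) @ positive.t()
        loss_glob = loss_glob + F.cross_entropy(logits, target)
loss_glob = loss_glob / (sy * sx)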
class IIC(BaseEncoder):
    def __post_init__(self):
        num_heads = 6
        self.encoder = Encoder(self.emb_size, num_heads).to(self.device)
        self.encoder.train()
        self.optim = ParamOptim(lr=self.lr, params=self.encoder.parameters())
        self.head_loss = [deque(maxlen=256) for _ in range(num_heads)]

    def _step(self, x1, x2):
        heads = zip(self.encoder.forward_heads(x1), self.encoder.forward_heads(x2))
        losses = {}
        for i, (x1, x2) in enumerate(heads):
            losses[f'head_{i}'] = loss = IID_loss(x1, x2)
            self.head_loss[i].append(loss.item())
        loss_mean = sum(losses.values()) / len(losses)
        losses['mean'] = self.optim.step(loss_mean)
        return losses

    def select_head(self):
        x = torch.tensor(self.head_loss).mean(-1)
        x[0] = 1  # head 0 is auxiliary
        self.encoder.head_main = x.argmin().item()
        return self.encoder.head_main
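# --- Illustration (not part of the original module) ---------------------------
# IID_loss is defined elsewhere in the repo; the sketch below shows the usual
# IIC objective it is assumed to implement: build the joint distribution of the
# two heads' soft cluster assignments, symmetrize it, and minimize the negative
# mutual information. This is an assumption about the loss, not a copy of the
# repo's implementation; all names are illustrative.
import torch


def iid_loss_sketch(x1, x2, eps=1e-8):
    # x1, x2: (batch, num_clusters) softmax outputs for two views of the same state
    p = x1.t() @ x2 / x1.shape[0]         # joint distribution over cluster pairs
    p = (p + p.t()) / 2                    # enforce symmetry
    p = p.clamp(min=eps)
    pi = p.sum(dim=1, keepdim=True)        # marginal of the first view
    pj = p.sum(dim=0, keepdim=True)        # marginal of the second view
    # negative mutual information I(z1; z2)
    return (p * (torch.log(pi) + torch.log(pj) - torch.log(p))).sum()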
class Predictor:
    def __init__(self, buffer, encoder, num_action, cfg, device="cuda"):
        self.device = device
        self.buffer = buffer
        self.encoder = encoder
        self.frame_stack = cfg["w_mse"]["frame_stack"]
        self.emb_size = cfg["w_mse"]["emb_size"]
        self.rnn_size = cfg["w_mse"]["rnn_size"]
        self.model = PredictorModel(
            num_action, self.frame_stack, self.emb_size, self.rnn_size
        )
        self.model = self.model.to(device).train()
        lr = cfg["w_mse"]["lr"]
        self.optim = ParamOptim(params=self.model.parameters(), lr=lr)
        self.ri_mean = self.ri_std = None
        self.ri_momentum = cfg["w_mse"]["ri_momentum"]
        self.ri_clamp = cfg["w_mse"].get("ri_clamp")
        self.ri_scale = cfg["ri_scale"]

    def get_error(self, batch, hx=None, update_stats=False):
        # p p p o o o
        # a a
        obs = prepare_obs(batch["obs"], batch["done"], self.frame_stack)
        steps, batch_size, *obs_shape = obs.shape
        obs = obs.view(batch_size * steps, *obs_shape)
        with torch.no_grad():
            z = self.encoder(obs)
        z = z.view(steps, batch_size, self.emb_size)
        action = batch["action"][self.frame_stack:]
        done = batch["done"][self.frame_stack - 1:-1]
        z_pred, hx = self.model(z[:-1], action, done, hx)
        # per-step squared prediction error in the frozen encoder's embedding space
        err = (z[1:] - z_pred).pow(2).mean(2)
        ri = err.detach()
        if update_stats:
            if self.ri_mean is None:
                self.ri_mean = ri.mean()
                self.ri_std = ri.std()
            else:
                m = self.ri_momentum
                self.ri_mean = m * self.ri_mean + (1 - m) * ri.mean()
                self.ri_std = m * self.ri_std + (1 - m) * ri.std()
        if self.ri_mean is not None:
            ri = (ri[..., None] - self.ri_mean) / self.ri_std
            if self.ri_clamp is not None:
                ri.clamp_(-self.ri_clamp, self.ri_clamp)
            ri *= self.ri_scale
        else:
            ri = 0
        return err, ri, hx

    def train(self):
        # used only for pretraining; the main training loop is in the DQN learner
        batch_size = 16
        sample_steps = self.frame_stack - 1 + 100
        if len(self.buffer) < self.buffer.maxlen:
            no_prev = set(range(sample_steps))
        else:
            no_prev = set(
                (self.buffer.cursor + i) % self.buffer.maxlen
                for i in range(sample_steps)
            )
        all_idx = list(set(range(len(self.buffer))) - no_prev)
        idx0 = torch.tensor(random.choices(all_idx, k=batch_size))
        idx1 = torch.tensor(random.choices(range(self.buffer.num_env), k=batch_size))
        batch = self.buffer.query(idx0, idx1, sample_steps)
        er = self.get_error(batch, update_stats=True)[0]
        loss = er.sum(0).mean()
        self.optim.step(loss)
        return {"loss_predictor": loss.item()}

    def load(self):
        cp = torch.load("models/predictor.pt", map_location=self.device)
        self.ri_mean, self.ri_std, model = cp
        self.model.load_state_dict(model)

    def save(self):
        data = [self.ri_mean, self.ri_std, self.model.state_dict()]
        torch.save(data, "models/predictor.pt")
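# --- Illustration (not part of the original module) ---------------------------
# prepare_obs is defined elsewhere in the repo; the sketch below shows only the
# index alignment that Predictor.get_error above relies on (the "p p p o o o"
# comment): with frame_stack = 4, the first stacked observation covers raw steps
# 0..3, so actions are taken from index frame_stack onward and done flags from
# frame_stack - 1 to -1. The unfold-based stacking here is an assumption, not
# the repo's implementation; all names are illustrative.
import torch

frame_stack, steps, num_env = 4, 6, 2
obs = torch.randn(steps, num_env, 1, 84, 84)       # raw single frames
action = torch.randint(6, (steps, num_env, 1))
done = torch.zeros(steps, num_env, 1, dtype=torch.bool)

# stack consecutive frames: (steps - frame_stack + 1, num_env, frame_stack, 84, 84)
stacked = obs.squeeze(2).unfold(0, frame_stack, 1).permute(0, 1, 4, 2, 3)
action_aligned = action[frame_stack:]              # action taken after each stacked obs
done_aligned = done[frame_stack - 1:-1]            # episode ends between stacked steps
assert stacked.shape[0] == action_aligned.shape[0] + 1 == done_aligned.shape[0] + 1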
class Learner:
    def __init__(self, model, buffer, predictor, cfg):
        num_env = cfg["agent"]["actors"]
        model_t = deepcopy(model)
        model_t = model_t.cuda().eval()
        self.model, self.model_t = model, model_t
        self.buffer = buffer
        self.predictor = predictor
        self.optim = ParamOptim(params=model.parameters(), **cfg["optim"])
        self.batch_size = cfg["agent"]["batch_size"]
        self.unroll = cfg["agent"]["unroll"]
        self.unroll_prefix = (
            cfg["agent"]["burnin"]
            + cfg["agent"]["n_step"]
            + cfg["agent"]["frame_stack"]
            - 1
        )
        self.sample_steps = self.unroll_prefix + self.unroll
        self.hx_shift = cfg["agent"]["frame_stack"] - 1
        num_unrolls = (self.buffer.maxlen - self.unroll_prefix) // self.unroll
        if cfg["buffer"]["prior_exp"] > 0:
            self.sampler = Sampler(
                num_env=num_env,
                maxlen=num_unrolls,
                prior_exp=cfg["buffer"]["prior_exp"],
                importance_sampling_exp=cfg["buffer"]["importance_sampling_exp"],
            )
            self.s2b = torch.empty(num_unrolls, dtype=torch.long)
            self.hxs = torch.empty(num_unrolls, num_env, 512, device="cuda")
            self.hx_cursor = 0
        else:
            self.sampler = None
        self.target_tau = cfg["agent"]["target_tau"]
        self.td_error = partial(get_td_error, model=model, model_t=model_t, cfg=cfg)

    def _update_target(self):
        for t, s in zip(self.model_t.parameters(), self.model.parameters()):
            t.data.copy_(t.data * (1.0 - self.target_tau) + s.data * self.target_tau)

    def append(self, step, hx, n_iter):
        self.buffer.append(step)
        if self.sampler is not None:
            if (n_iter + 1) % self.unroll == self.hx_shift:
                self.hxs[self.hx_cursor] = hx
                self.hx_cursor = (self.hx_cursor + 1) % len(self.hxs)
            k = n_iter - self.unroll_prefix
            if k > 0 and (k + 1) % self.unroll == 0:
                self.s2b[self.sampler.cursor] = self.buffer.cursor - 1
                x = self.buffer.get_recent(self.sample_steps)
                hx = self.hxs[self.sampler.cursor]
                with torch.no_grad():
                    loss, _ = self.td_error(x, hx)
                self.sampler.append(tde_to_prior(loss))
                if len(self.sampler) == self.sampler.maxlen:
                    idx_new = self.s2b[self.sampler.cursor - 1]
                    idx_old = self.s2b[self.sampler.cursor]
                    d = (idx_old - idx_new) % self.buffer.maxlen
                    assert self.unroll_prefix + self.unroll <= d
                    assert d < self.unroll_prefix + self.unroll * 2

    def loss_sampler(self, need_stat):
        idx0, idx1, weights = self.sampler.sample(self.batch_size)
        weights = weights.cuda()
        batch = self.buffer.query(self.s2b[idx0], idx1, self.sample_steps)
        hx = self.hxs[idx0, idx1]
        loss_pred, ri, _ = self.predictor.get_error(batch, update_stats=True)
        batch["reward"][1:] += ri
        td_error, log = self.td_error(batch, hx, need_stat=need_stat)
        self.sampler.update_prior(idx0, idx1, tde_to_prior(td_error))
        loss = td_error.pow(2).sum(0) * weights[..., None]
        loss_pred = loss_pred.sum(0) * weights[..., None]
        return loss, loss_pred, ri, log

    def loss_uniform(self, need_stat):
        if len(self.buffer) < self.buffer.maxlen:
            no_prev = set(range(self.sample_steps))
        else:
            no_prev = set(
                (self.buffer.cursor + i) % self.buffer.maxlen
                for i in range(self.sample_steps)
            )
        all_idx = list(set(range(len(self.buffer))) - no_prev)
        idx0 = torch.tensor(random.choices(all_idx, k=self.batch_size))
        idx1 = torch.tensor(
            random.choices(range(self.buffer.num_env), k=self.batch_size))
        batch = self.buffer.query(idx0, idx1, self.sample_steps)
        loss_pred, ri, _ = self.predictor.get_error(batch, update_stats=True)
        batch["reward"][1:] += ri
        td_error, log = self.td_error(batch, None, need_stat=need_stat)
        loss = td_error.pow(2).sum(0)
        loss_pred = loss_pred.sum(0)
        return loss, loss_pred, ri, log

    def train(self, need_stat=True):
        loss_f = self.loss_uniform if self.sampler is None else self.loss_sampler
        loss, loss_pred, ri, log = loss_f(need_stat)
        self.optim.step(loss.mean())
        self.predictor.optim.step(loss_pred.mean())
        self._update_target()
        if need_stat:
            log.update({
                "ri_std": ri.std(),
                "ri_mean": ri.mean(),
                "ri_run_mean": self.predictor.ri_mean,
                "ri_run_std": self.predictor.ri_std,
                "loss_predictor": loss_pred.mean(),
            })
            if self.sampler is not None:
                log.update(self.sampler.stats())
        return log
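# --- Illustration (not part of the original module) ---------------------------
# Sampler and tde_to_prior are defined elsewhere in the repo; the sketch below
# shows the standard prioritized-replay arithmetic they are assumed to follow
# (as in PER/R2D2): priorities derived from TD errors, sampling probabilities
# proportional to priority^alpha, and importance-sampling weights
# (N * P(i))^(-beta) normalized by their maximum. All names are illustrative.
import torch


def sample_prioritized_sketch(td_errors, batch_size, alpha=0.9, beta=0.6, eps=1e-6):
    # td_errors: per-sequence TD-error magnitudes, shape (num_sequences,)
    prior = td_errors.abs() + eps
    probs = prior.pow(alpha)
    probs = probs / probs.sum()
    idx = torch.multinomial(probs, batch_size, replacement=True)
    weights = (len(probs) * probs[idx]).pow(-beta)
    weights = weights / weights.max()        # keep the loss scale bounded
    return idx, weights


# usage: idx, w = sample_prioritized_sketch(torch.rand(1000), batch_size=64)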