# Imports reconstructed from usage; the rlpyt paths follow that library's
# layout.
import copy
from collections import namedtuple

import torch
import torch.nn.functional as F

from rlpyt.algos.utils import valid_from_done
from rlpyt.models.conv2d import Conv2dModel
from rlpyt.models.mlp import MlpModel
from rlpyt.models.running_mean_std import RunningMeanStdModel
from rlpyt.models.utils import update_state_dict
from rlpyt.utils.buffer import buffer_to
from rlpyt.utils.collections import namedarraytuple
from rlpyt.utils.logging import logger
from rlpyt.utils.quick_args import save__init__args
from rlpyt.utils.tensor import (infer_leading_dims, restore_leading_dims,
    valid_mean)

# Assumed available from elsewhere in this repo (paths not shown in the
# source): weight_init, DmlabConv2dModel, BaseUlAlgorithm,
# UlForRlReplayBuffer, EncoderModel, random_shift.


class AtariPgModel(torch.nn.Module):
    """Can feed in conv and/or fc1 layer from a pre-trained model, or have it
    initialize new ones (if initializing new, must provide image_shape)."""

    def __init__(
            self,
            image_shape,
            action_size,
            hidden_sizes=512,
            stop_conv_grad=False,
            channels=None,  # Defaults below.
            kernel_sizes=None,
            strides=None,
            paddings=None,
            kiaming_init=True,  # [sic] Kaiming weight initialization.
            normalize_conv_out=False,
            ):
        super().__init__()
        c, h, w = image_shape
        self.conv = Conv2dModel(
            in_channels=c,
            channels=channels or [32, 64, 64],
            kernel_sizes=kernel_sizes or [8, 4, 3],
            strides=strides or [4, 2, 1],
            paddings=paddings,
        )
        self._conv_out_size = self.conv.conv_out_size(h=h, w=w)
        self.pi_v_mlp = MlpModel(
            input_size=self._conv_out_size,
            hidden_sizes=hidden_sizes,
            output_size=action_size + 1,  # Policy logits plus value.
        )
        if kiaming_init:
            self.apply(weight_init)
        self.stop_conv_grad = stop_conv_grad
        logger.log("Model stopping gradient at CONV." if stop_conv_grad
            else "Model using gradients on all parameters.")
        if normalize_conv_out:  # Haven't seen this make a difference yet.
            logger.log("Model normalizing conv output across all pixels.")
            self.conv_rms = RunningMeanStdModel((1,))
            self.var_clip = 1e-6
        self.normalize_conv_out = normalize_conv_out

    def forward(self, observation, prev_action, prev_reward):
        if observation.dtype == torch.uint8:
            img = observation.type(torch.float)
            img = img.mul_(1.0 / 255)
        else:
            img = observation
        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
        conv = self.conv(img.view(T * B, *img_shape))
        if self.stop_conv_grad:
            conv = conv.detach()
        if self.normalize_conv_out:
            conv_var = self.conv_rms.var
            conv_var = torch.clamp(conv_var, min=self.var_clip)
            # Stddev of uniform [a,b] is (b-a)/sqrt(12), and 1/sqrt(12)~0.29;
            # then clamp the normalized output into [0, 10].
            conv = torch.clamp(0.29 * conv / conv_var.sqrt(), 0, 10)
        pi_v = self.pi_v_mlp(conv.view(T * B, -1))
        pi = F.softmax(pi_v[:, :-1], dim=-1)
        v = pi_v[:, -1]
        pi, v, conv = restore_leading_dims((pi, v, conv), lead_dim, T, B)
        return pi, v, conv

    def update_conv_rms(self, observation):
        if self.normalize_conv_out:
            with torch.no_grad():
                if observation.dtype == torch.uint8:
                    img = observation.type(torch.float)
                    img = img.mul_(1.0 / 255)
                else:
                    img = observation
                lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
                conv = self.conv(img.view(T * B, *img_shape))
                self.conv_rms.update(conv.view(-1, 1))

    def parameters(self):
        if not self.stop_conv_grad:
            yield from self.conv.parameters()
        yield from self.pi_v_mlp.parameters()

    def named_parameters(self):
        if not self.stop_conv_grad:
            yield from self.conv.named_parameters()
        yield from self.pi_v_mlp.named_parameters()

    @property
    def conv_out_size(self):
        return self._conv_out_size
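# A minimal forward-pass sketch (assumed usage, not from the source).
# AtariPgModel accepts uint8 frames with leading dims [T,B], [B], or none,
# via infer_leading_dims/restore_leading_dims; kiaming_init=False here only
# to avoid depending on the repo's weight_init helper.
def _demo_atari_pg_model():
    model = AtariPgModel(image_shape=(4, 84, 84), action_size=6,
        kiaming_init=False)
    obs = torch.randint(0, 256, size=(5, 2, 4, 84, 84), dtype=torch.uint8)
    pi, v, conv = model(obs, prev_action=None, prev_reward=None)
    assert pi.shape == (5, 2, 6) and v.shape == (5, 2)  # [T,B] restored.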
# Assumed definition, following rlpyt's convention for recurrent model states.
RnnState = namedarraytuple("RnnState", ["h", "c"])


class DmlabPgLstmModel(torch.nn.Module):

    def __init__(
            self,
            image_shape,
            output_size,
            lstm_size,
            skip_connections=True,
            hidden_sizes=None,
            kiaming_init=True,
            stop_conv_grad=False,
            skip_lstm=True,
            ):
        super().__init__()
        c, h, w = image_shape
        self.conv = DmlabConv2dModel(
            in_channels=c,
            use_fourth_layer=True,
            use_maxpool=False,
            skip_connections=skip_connections,
        )
        self._conv_out_size = self.conv.output_size(h=h, w=w)
        self.fc1 = torch.nn.Linear(
            in_features=self._conv_out_size,
            out_features=lstm_size,
        )
        # LSTM input: fc1 features + onehot prev_action + prev_reward.
        self.lstm = torch.nn.LSTM(lstm_size + output_size + 1, lstm_size)
        self.pi_v_head = MlpModel(
            input_size=lstm_size,
            hidden_sizes=hidden_sizes,
            output_size=output_size + 1,
        )
        if kiaming_init:
            self.apply(weight_init)
        self.stop_conv_grad = stop_conv_grad
        logger.log("Model stopping gradient at CONV." if stop_conv_grad
            else "Model using gradients on all parameters.")
        self._skip_lstm = skip_lstm

    def forward(self, observation, prev_action, prev_reward, init_rnn_state):
        if observation.dtype == torch.uint8:
            img = observation.type(torch.float)
            img = img.mul_(1.0 / 255)
        else:
            img = observation
        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
        conv = self.conv(img.view(T * B, *img_shape))
        if self.stop_conv_grad:
            conv = conv.detach()
        fc1 = F.relu(self.fc1(conv.view(T * B, -1)))
        lstm_input = torch.cat(
            [
                fc1.view(T, B, -1),
                prev_action.view(T, B, -1),  # Assumed onehot.
                prev_reward.view(T, B, 1),
            ],
            dim=2,
        )
        init_rnn_state = None if init_rnn_state is None else tuple(
            init_rnn_state)
        lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)
        if self._skip_lstm:
            # Residual connection around the LSTM.
            lstm_out = lstm_out.view(T * B, -1) + fc1
        pi_v = self.pi_v_head(lstm_out.view(T * B, -1))
        pi = F.softmax(pi_v[:, :-1], dim=-1)
        v = pi_v[:, -1]
        pi, v, conv = restore_leading_dims((pi, v, conv), lead_dim, T, B)
        next_rnn_state = RnnState(h=hn, c=cn)
        return pi, v, next_rnn_state, conv

    def parameters(self):
        if not self.stop_conv_grad:
            yield from self.conv.parameters()
        yield from self.fc1.parameters()
        yield from self.lstm.parameters()
        yield from self.pi_v_head.parameters()

    def named_parameters(self):
        if not self.stop_conv_grad:
            yield from self.conv.named_parameters()
        yield from self.fc1.named_parameters()
        yield from self.lstm.named_parameters()
        yield from self.pi_v_head.named_parameters()

    @property
    def conv_out_size(self):
        return self._conv_out_size
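# Input-prep sketch (assumed; function name hypothetical): the forward pass
# above expects prev_action already one-hot encoded, shape [T,B,output_size],
# and prev_reward with shape [T,B] (reshaped to [T,B,1] internally).
def _demo_dmlab_lstm_inputs(T=20, B=8, action_size=15):
    prev_action = F.one_hot(torch.randint(action_size, (T, B)),
        num_classes=action_size).float()
    prev_reward = torch.zeros(T, B)
    return prev_action, prev_reward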
# Assumed definitions: the OptInfo/ValInfo namedtuples are not shown in the
# source; their fields are inferred from how optimize() and validation()
# populate them below.
OptInfo = namedtuple("OptInfo",
    ["atsLoss", "activationLoss", "gradNorm", "convActivation"])
ValInfo = namedtuple("ValInfo", ["atsLoss", "convActivation"])


class AugmentedTemporalSimilarity(BaseUlAlgorithm):
    """Similarity loss (as in BYOL) against one future time step, using a
    momentum encoder for the target."""

    opt_info_fields = tuple(f for f in OptInfo._fields)  # Copy.

    def __init__(
            self,
            replay_filepath,
            ReplayCls=UlForRlReplayBuffer,
            delta_T=1,
            batch_T=1,
            batch_B=256,
            learning_rate=1e-3,
            learning_rate_anneal=None,  # cosine
            learning_rate_warmup=0,  # number of updates
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            clip_grad_norm=10.,
            target_update_tau=0.01,  # 1 for hard update
            target_update_interval=1,
            EncoderCls=EncoderModel,
            encoder_kwargs=None,
            latent_size=256,
            anchor_hidden_sizes=512,
            initial_state_dict=None,
            random_shift_prob=1.,
            random_shift_pad=4,
            activation_loss_coefficient=0.,  # Rarely if ever used.
            validation_split=0.0,
            n_validation_batches=0,  # Usually don't do it.
            ):
        encoder_kwargs = dict() if encoder_kwargs is None else encoder_kwargs
        save__init__args(locals())
        assert learning_rate_anneal in [None, "cosine"]
        self.batch_size = batch_B * batch_T  # For logging only.
        self._replay_T = batch_T + delta_T

    def initialize(self, n_updates, cuda_idx=None):
        self.device = torch.device("cpu") if cuda_idx is None else \
            torch.device("cuda", index=cuda_idx)
        examples = self.load_replay()  # Presumably from BaseUlAlgorithm.
        self.encoder = self.EncoderCls(
            image_shape=examples.observation.shape,
            latent_size=self.latent_size,
            **self.encoder_kwargs,
        )
        self.target_encoder = copy.deepcopy(self.encoder)
        self.predictor = MlpModel(
            input_size=self.latent_size,
            hidden_sizes=self.anchor_hidden_sizes,
            output_size=self.latent_size,
        )
        self.encoder.to(self.device)
        self.target_encoder.to(self.device)
        self.predictor.to(self.device)
        self.optim_initialize(n_updates)
        if self.initial_state_dict is not None:
            self.load_state_dict(self.initial_state_dict)

    def optimize(self, itr):
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        samples = self.replay_buffer.sample_batch(self.batch_B)
        if self.lr_scheduler is not None:
            self.lr_scheduler.step(itr)  # Step every itr, not every epoch.
        self.optimizer.zero_grad()
        ats_loss, conv_output = self.ats_loss(samples)
        act_loss = self.activation_loss(conv_output)
        loss = ats_loss + act_loss
        loss.backward()
        if self.clip_grad_norm is None:
            grad_norm = torch.tensor(0.)  # Tensor, so .item() below works.
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.parameters(), self.clip_grad_norm)
        self.optimizer.step()
        opt_info.atsLoss.append(ats_loss.item())
        opt_info.activationLoss.append(act_loss.item())
        opt_info.gradNorm.append(grad_norm.item())
        opt_info.convActivation.append(
            conv_output[0].detach().cpu().view(-1).numpy())  # Keep 1 full one.
        if itr % self.target_update_interval == 0:
            update_state_dict(self.target_encoder, self.encoder.state_dict(),
                self.target_update_tau)
        return opt_info

    def ats_loss(self, samples):
        anchor = (samples.observation if self.delta_T == 0 else
            samples.observation[:-self.delta_T])
        positive = samples.observation[self.delta_T:]
        t, b, c, h, w = anchor.shape
        anchor = anchor.view(t * b, c, h, w)  # Treat all T,B as separate.
        positive = positive.view(t * b, c, h, w)
        if self.random_shift_prob > 0.:
            anchor = random_shift(
                imgs=anchor,
                pad=self.random_shift_pad,
                prob=self.random_shift_prob,
            )
            positive = random_shift(
                imgs=positive,
                pad=self.random_shift_pad,
                prob=self.random_shift_prob,
            )
        anchor, positive = buffer_to((anchor, positive), device=self.device)
        with torch.no_grad():
            z_positive, _ = self.target_encoder(positive)
        z_anchor, conv_output = self.encoder(anchor)
        q_anchor = self.predictor(z_anchor)
        q = F.normalize(q_anchor, dim=-1, p=2)
        z = F.normalize(z_positive, dim=-1, p=2)
        ats_losses = 2. - 2 * (q * z).sum(dim=-1)  # From BYOL.
        valid = valid_from_done(samples.done.type(torch.bool))
        valid = valid[self.delta_T:].reshape(-1)
        valid = valid.to(self.device)
        ats_loss = valid_mean(ats_losses, valid)
        return ats_loss, conv_output

    def validation(self, itr):
        logger.log("Computing validation loss...")
        val_info = ValInfo(*([] for _ in range(len(ValInfo._fields))))
        self.optimizer.zero_grad()
        for _ in range(self.n_validation_batches):
            samples = self.replay_buffer.sample_batch(self.batch_B,
                validation=True)
            with torch.no_grad():
                ats_loss, conv_output = self.ats_loss(samples)
            val_info.atsLoss.append(ats_loss.item())
            val_info.convActivation.append(
                conv_output[0].detach().cpu().view(-1).numpy())
        self.optimizer.zero_grad()
        logger.log("...validation loss completed.")
        return val_info

    def state_dict(self):
        return dict(
            encoder=self.encoder.state_dict(),
            target_encoder=self.target_encoder.state_dict(),
            predictor=self.predictor.state_dict(),
            optimizer=self.optimizer.state_dict(),
        )

    def load_state_dict(self, state_dict):
        self.encoder.load_state_dict(state_dict["encoder"])
        self.target_encoder.load_state_dict(state_dict["target_encoder"])
        self.predictor.load_state_dict(state_dict["predictor"])
        self.optimizer.load_state_dict(state_dict["optimizer"])

    def parameters(self):
        yield from self.encoder.parameters()
        yield from self.predictor.parameters()

    def named_parameters(self):
        """To allow filtering by name in weight decay."""
        yield from self.encoder.named_parameters()
        yield from self.predictor.named_parameters()

    def eval(self):
        self.encoder.eval()  # In case of batch norm.
        self.predictor.eval()

    def train(self):
        self.encoder.train()
        self.predictor.train()
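# The per-sample loss in ats_loss() is BYOL's similarity loss: for
# L2-normalized q and z, ||q - z||^2 = 2 - 2 * (q . z). A quick numeric
# check of that identity (shapes arbitrary):
def _check_byol_loss_identity():
    q = F.normalize(torch.randn(256, 128), dim=-1, p=2)
    z = F.normalize(torch.randn(256, 128), dim=-1, p=2)
    lhs = (q - z).pow(2).sum(dim=-1)
    rhs = 2. - 2 * (q * z).sum(dim=-1)
    assert torch.allclose(lhs, rhs, atol=1e-5)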