Example #1
    def __init__(self, obs_shape, action_shape, args):
        super().__init__(obs_shape, action_shape, args)
        self.aux_update_freq = args.aux_update_freq
        self.soda_batch_size = args.soda_batch_size
        self.soda_tau = args.soda_tau
        self.use_intrinsic = args.use_intrinsic
        self.in_gamma = args.in_gamma
        self.in_decay = args.in_decay
        self.args = args

        shared_cnn = self.critic.encoder.shared_cnn
        aux_cnn = self.critic.encoder.head_cnn
        soda_encoder = m.Encoder(
            shared_cnn, aux_cnn,
            m.SODAMLP(aux_cnn.out_shape[0], args.projection_dim,
                      args.projection_dim))

        self.predictor = m.SODAPredictor(soda_encoder,
                                         args.projection_dim).cuda()
        self.predictor_target = deepcopy(self.predictor)

        self.soda_optimizer = torch.optim.Adam(self.predictor.parameters(),
                                               lr=args.aux_lr,
                                               betas=(args.aux_beta, 0.999))
        self.train()
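
Note: `self.predictor_target = deepcopy(self.predictor)` together with `soda_tau` suggests a Polyak (EMA) update of the target predictor during the auxiliary step. The sketch below is a minimal illustration of such a soft update; the helper name `soft_update_params` and the convention that `tau` weights the online network are assumptions, not necessarily the exact utility used in the source.

import torch

def soft_update_params(net: torch.nn.Module,
                       target_net: torch.nn.Module,
                       tau: float) -> None:
    # Polyak averaging: target <- tau * online + (1 - tau) * target.
    with torch.no_grad():
        for param, target_param in zip(net.parameters(),
                                       target_net.parameters()):
            target_param.data.copy_(tau * param.data +
                                    (1 - tau) * target_param.data)

# e.g. once per auxiliary update:
# soft_update_params(self.predictor, self.predictor_target, self.soda_tau)
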
Example #2
    def __init__(self, obs_shape, action_shape, args):
        self.discount = args.discount
        self.critic_tau = args.critic_tau
        self.encoder_tau = args.encoder_tau
        self.actor_update_freq = args.actor_update_freq
        self.critic_target_update_freq = args.critic_target_update_freq

        shared_cnn = m.SharedCNN(obs_shape, args.num_shared_layers,
                                 args.num_filters).cuda()
        rl_cnn = m.HeadCNN(shared_cnn.out_shape, args.num_head_layers,
                           args.num_filters).cuda()
        actor_encoder = m.Encoder(
            shared_cnn, rl_cnn,
            m.RLProjection(rl_cnn.out_shape, args.projection_dim))
        critic_encoder = m.Encoder(
            shared_cnn, rl_cnn,
            m.RLProjection(rl_cnn.out_shape, args.projection_dim))

        self.actor = m.Actor(actor_encoder, action_shape, args.hidden_dim,
                             args.actor_log_std_min,
                             args.actor_log_std_max).cuda()
        self.critic = m.Critic(critic_encoder, action_shape,
                               args.hidden_dim).cuda()
        self.critic_target = deepcopy(self.critic)

        self.log_alpha = torch.tensor(np.log(args.init_temperature)).cuda()
        self.log_alpha.requires_grad = True
        self.target_entropy = -np.prod(action_shape)

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=args.actor_lr,
                                                betas=(args.actor_beta, 0.999))
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=args.critic_lr,
                                                 betas=(args.critic_beta,
                                                        0.999))
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=args.alpha_lr,
                                                    betas=(args.alpha_beta,
                                                           0.999))

        self.train()
        self.critic_target.train()
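
Note: since `log_alpha` is created with `requires_grad = True` and paired with `target_entropy` and its own optimizer, the entropy temperature is presumably tuned with the standard SAC objective. A minimal sketch under that assumption (the function name and signature are illustrative only):

import torch

def update_temperature(log_alpha: torch.Tensor,
                       log_pi: torch.Tensor,
                       target_entropy: float,
                       optimizer: torch.optim.Optimizer) -> torch.Tensor:
    # Standard SAC temperature loss: alpha * (-log_pi - target_entropy),
    # with the policy term detached so only log_alpha receives gradients.
    alpha = log_alpha.exp()
    alpha_loss = (alpha * (-log_pi - target_entropy).detach()).mean()
    optimizer.zero_grad()
    alpha_loss.backward()
    optimizer.step()
    return alpha_loss.detach()
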
Example #3
    def __init__(self, obs_shape, action_shape, args):
        super().__init__(obs_shape, action_shape, args)
        self.aux_update_freq = args.aux_update_freq
        self.aux_lr = args.aux_lr
        self.aux_beta = args.aux_beta
        self.soda_batch_size = args.soda_batch_size

        shared_cnn = self.critic.encoder.shared_cnn
        aux_cnn = m.HeadCNN(shared_cnn.out_shape, args.num_head_layers,
                            args.num_filters).cuda()
        aux_encoder = m.Encoder(
            shared_cnn, aux_cnn,
            m.RLProjection(aux_cnn.out_shape, args.projection_dim)).cuda()

        self.autoEncoder = m.Decoder(aux_encoder, obs_shape=obs_shape).cuda()

        self.ccm_lambda = args.ccm_lambda
        self.ccm_head = m.CCMHead(aux_encoder, args.hidden_dim).cuda()
        self.bn = nn.BatchNorm1d(args.hidden_dim, affine=False).cuda()

        self.pad_head = m.InverseDynamics(aux_encoder, action_shape,
                                          args.hidden_dim).cuda()

        # NEG
        self.soda_tau = args.soda_tau
        self.predictor = m.SODAPredictor(aux_encoder,
                                         args.projection_dim).cuda()
        self.predictor_target = deepcopy(self.predictor)
        self.neg_optimizer = torch.optim.Adam(self.predictor.parameters(),
                                              lr=1e-4,
                                              betas=(args.aux_beta, 0.999))

        # rad
        # self.data_augs = data_augs
        # self.augs_funcs = {}
        self.augs_funcs = {
            'crop': rad.random_crop,
            'grayscale': rad.random_grayscale,
            'cutout': rad.random_cutout,
            'cutout_color': rad.random_cutout_color,
            # 'flip':rad.random_flip,
            # 'rotate':rad.random_rotation,
            'rand_conv': rad.random_convolution,
            'color_jitter': rad.random_color_jitter,
            'translate': rad.random_translate,
            'no_aug': rad.no_aug,
        }

        # for aug_name in self.data_augs.split('-'):
        #     assert aug_name in aug_to_func, 'invalid data aug string'
        #     self.augs_funcs[aug_name] = aug_to_func[aug_name]

        self.init_optimizer()
        self.train()
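
Note: the combination of `ccm_head`, a `BatchNorm1d(affine=False)`, and `ccm_lambda` resembles a Barlow Twins-style cross-correlation objective between two augmented views. The sketch below is a hedged illustration of that kind of loss; the actual CCM objective in the source may differ.

import torch
import torch.nn as nn

def cross_correlation_loss(z1: torch.Tensor, z2: torch.Tensor,
                           bn: nn.BatchNorm1d,
                           ccm_lambda: float) -> torch.Tensor:
    # Cross-correlation matrix of the batch-normalized embeddings.
    n = z1.size(0)
    c = bn(z1).T @ bn(z2) / n
    # Pull diagonal entries toward 1, push off-diagonal entries toward 0.
    on_diag = (torch.diagonal(c) - 1).pow(2).sum()
    off_diag = (c - torch.diag(torch.diagonal(c))).pow(2).sum()
    return on_diag + ccm_lambda * off_diag
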
Example #4
	def __init__(self, obs_shape, action_shape, args):
		super().__init__(obs_shape, action_shape, args)
		self.aux_update_freq = args.aux_update_freq
		self.aux_lr = args.aux_lr
		self.aux_beta = args.aux_beta

		shared_cnn = self.critic.encoder.shared_cnn
		aux_cnn = m.HeadCNN(shared_cnn.out_shape, args.num_head_layers, args.num_filters).cuda()
		aux_encoder = m.Encoder(
			shared_cnn,
			aux_cnn,
			m.RLProjection(aux_cnn.out_shape, args.projection_dim)
		)
		self.pad_head = m.InverseDynamics(aux_encoder, action_shape, args.hidden_dim).cuda()
		self.init_pad_optimizer()
		self.train()
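
Note: `m.InverseDynamics` paired with `init_pad_optimizer` points to a PAD-style auxiliary task of predicting the action that connects consecutive observations. A minimal sketch of such a loss, assuming the head takes (obs, next_obs) and returns a predicted action (the exact interface in `m` may differ):

import torch
import torch.nn.functional as F

def inverse_dynamics_loss(pad_head: torch.nn.Module,
                          obs: torch.Tensor,
                          next_obs: torch.Tensor,
                          action: torch.Tensor) -> torch.Tensor:
    # Predict the action taken between obs and next_obs, regress with MSE.
    pred_action = pad_head(obs, next_obs)
    return F.mse_loss(pred_action, action)
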