Exemplo n.º 1
0
 def __init__(self, hp):
     """Build the query network, the stacked attention layers and the output projection.

     Args:
         hp: hyperparameter container; this constructor reads ``nz_enc``,
             ``nz_attn_key`` and ``n_attention_layers``.
     """
     super().__init__()
     self._hp = hp
     # Input width is twice ``nz_enc`` (presumably two concatenated encodings -- TODO confirm).
     input_size = hp.nz_enc * 2
     self.query_net = Predictor(hp, input_size, hp.nz_attn_key)
     # One MultiheadAttention module and one 2-layer Predictor per attention layer.
     self.attention_layers = nn.ModuleList([MultiheadAttention(hp) for _ in range(hp.n_attention_layers)])
     self.predictor_layers = nn.ModuleList([Predictor(hp, hp.nz_enc, hp.nz_attn_key, num_layers=2)
                                            for _ in range(hp.n_attention_layers)])
     self.out = nn.Linear(hp.nz_enc, hp.nz_enc)
Exemplo n.º 2
0
    def build_network(self):
        """Extend the parent network with a single-output, non-spatial distance predictor.

        The predictor consumes ``nz_enc * 2`` input features.
        """
        super().build_network()
        # Read the width only after the parent has finished building.
        pair_dim = self._hp.nz_enc * 2
        self.distance_predictor = Predictor(self._hp, pair_dim, 1, spatial=False)
Exemplo n.º 3
0
    def build_network(self):
        """Build the inference, subgoal-prediction, binding and optional index-regression modules.

        Sets ``self.prior`` and ``self.inference``; ``self.inference`` is an
        ``AttentiveInference`` when ``hp.attentive_inference`` is set, otherwise an
        ``Inference``. Sets ``self.subgoal_pred``: a tree LSTM plus
        ``self.lstm_initializer`` when ``hp.tree_lstm`` is set, otherwise a
        ``GeneralizedPredictorModel``. Then calls ``self.build_binding()`` and adds
        ``self.index_predictor`` when ``hp.regress_index`` is set.
        """
        hp = self._hp

        q, self.prior = setup_variational_inference(hp, hp.nz_enc, hp.nz_enc * 2)
        inference_cls = AttentiveInference if hp.attentive_inference else Inference
        self.inference = inference_cls(hp, q)

        # Subgoal-predictor input width. Another ``nz_enc * 2`` block is appended when
        # context is fed at every step.  todo clean this up with subclassing?
        pair_dim = hp.nz_enc * 2
        pred_inp_dim = pair_dim + hp.nz_vae
        if hp.context_every_step:
            pred_inp_dim += pair_dim

        if not hp.tree_lstm:
            self.subgoal_pred = GeneralizedPredictorModel(
                hp, input_dim=pred_inp_dim, output_dims=[hp.nz_enc], activations=[None])
        else:
            self.subgoal_pred, self.lstm_initializer = build_tree_lstm(hp, pred_inp_dim, hp.nz_enc)

        self.build_binding()

        if self._hp.regress_index:
            self.index_predictor = Predictor(
                self._hp, self._hp.nz_enc * 2, self._hp.max_seq_len, detached=False, spatial=False)
Exemplo n.º 4
0
 def build_network(self, build_encoder=True):
     """Create the encoder and a policy head mapping ``2 * nz_enc`` features to ``n_actions``.

     The policy is a ``Predictor`` when ``hp.reactive`` is set, otherwise a
     ``RecurrentPolicyModule``. ``build_encoder`` is accepted but not read here.
     """
     hp = self._hp
     self.encoder = Encoder(hp)
     policy_cls = Predictor if hp.reactive else RecurrentPolicyModule
     self.policy = policy_cls(hp, 2 * hp.nz_enc, hp.n_actions)
Exemplo n.º 5
0
 def __init__(self, hp, cell, input_sz):
     """Initialise the base class, then build an MLP emitting ``2 * self._hidden_size`` values.

     ``self._hidden_size`` is not set here; it is presumably set by the base-class
     constructor (TODO confirm).
     """
     super().__init__(hp, cell)
     # Local import: a module-level import would create an import cycle.
     from blox.torch.subnetworks import Predictor
     out_dim = 2 * self._hidden_size
     self.net = Predictor(
         self._hp,
         input_sz,
         output_dim=out_dim,
         spatial=hp.use_conv_lstm,
         num_layers=self._hp.init_mlp_layers,
         mid_size=self._hp.init_mlp_mid_sz,
     )
Exemplo n.º 6
0
    def build_network(self, build_encoder=True):
        """Create the encoder and an action predictor on top of ``nz_enc`` features.

        The predictor outputs ``n_actions`` values, plus ``state_dim`` more when
        ``hp.pred_states`` is set. ``build_encoder`` is not read here.
        """
        # NOTE: mutates the shared hyperparameters before the encoder is built.
        # 6 presumably means two stacked 3-channel images -- TODO confirm.
        self._hp.input_nc = 6
        self.encoder = Encoder(self._hp)

        hp = self._hp
        out_dim = hp.n_actions + hp.state_dim if hp.pred_states else hp.n_actions
        self.action_pred = Predictor(hp, hp.nz_enc, out_dim, 3)
Exemplo n.º 7
0
 def build_inference_encoder(self):
     """Build ``self.inf_encoder`` and ``self.inf_key_encoder``.

     ``inf_encoder`` is chosen as follows:
       * a ``Predictor`` from ``nz_enc + 2`` to ``nz_enc`` when ``hp.states_inference`` is set;
       * otherwise the result of ``build_act_cond_inf_encoder()`` when
         ``hp.act_cond_inference`` is set;
       * otherwise the result of ``build_inf_encoder()``.

     ``inf_key_encoder`` always chains a separate ``build_inf_encoder()`` result with
     an ``AttnKeyEncodingModule`` (``add_time=False``).
     """
     hp = self._hp
     if hp.states_inference:
         encoder = Predictor(hp, hp.nz_enc + 2, hp.nz_enc)
     elif hp.act_cond_inference:
         encoder = self.build_act_cond_inf_encoder()
     else:
         encoder = self.build_inf_encoder()
     self.inf_encoder = encoder

     key_base = self.build_inf_encoder()
     self.inf_key_encoder = nn.Sequential(key_base, AttnKeyEncodingModule(self._hp, add_time=False))
Exemplo n.º 8
0
 def __init__(self, hp):
     """Store ``hp`` and build ``self.net`` as a conv or MLP encoder.

     A ``ConvEncoder`` is used when ``hp.builder.use_convs`` is set. Otherwise a
     ``Predictor`` maps ``state_dim`` inputs to ``nz_enc`` features, using
     ``hp.builder.get_num_layers()`` layers.
     """
     super().__init__()
     self._hp = hp
     if not hp.builder.use_convs:
         n_layers = hp.builder.get_num_layers()
         self.net = Predictor(hp, hp.state_dim, hp.nz_enc, num_layers=n_layers)
     else:
         self.net = ConvEncoder(hp)
Exemplo n.º 9
0
    def build_network(self):
        """Build inference, subgoal prediction, matching and optional auxiliary heads.

        Construction order:
          1. Variational inference (``self.prior``, ``self.inference``).
          2. ``self.subgoal_pred``, plus ``self.lstm_initializer`` when
             ``hp.tree_lstm`` is set.
          3. ``self.criterion`` and the matcher.
          4. Optional ``self.fraction_pred`` (sigmoid output) and ``self.index_predictor``.

        Raises:
            ValueError: if ``hp.tree_lstm`` is set but is not ``'sum'``, ``'linear'``
                or ``'split_linear'``.
        """
        hp = self._hp

        q, self.prior = setup_variational_inference(self._hp, self._hp.nz_enc, self._hp.nz_enc * 2)
        self.inference = AttentiveInference(self._hp, q, Attention(self._hp))

        # todo clean this up with subclassing?
        pred_inp_dim = hp.nz_enc * 2 + hp.nz_vae
        # Compare by value. ``is`` against a str literal tests object identity: it depends
        # on interning (e.g. is False for strings loaded from a config) and raises a
        # SyntaxWarning on Python 3.8+.
        if self._hp.var_inf == '2layer':
            pred_inp_dim = pred_inp_dim + hp.nz_vae2
        if self._hp.context_every_step:
            pred_inp_dim = pred_inp_dim + hp.nz_enc * 2

        if hp.tree_lstm:
            if hp.tree_lstm == 'sum':
                cls = SumTreeHiddenStatePredictorModel
            elif hp.tree_lstm == 'linear':
                cls = LinTreeHiddenStatePredictorModel
            elif hp.tree_lstm == 'split_linear':
                cls = SplitLinTreeHiddenStatePredictorModel
            else:
                raise ValueError("don't know this TreeLSTM type")

            self.subgoal_pred = cls(hp, input_dim=pred_inp_dim, output_dim=hp.nz_enc)
            self.lstm_initializer = self._get_lstm_initializer(self.subgoal_pred)
        else:
            self.subgoal_pred = GeneralizedPredictorModel(hp, input_dim=pred_inp_dim, output_dims=[hp.nz_enc],
                                                          activations=[None])

        # TODO this can be moved into matcher
        self.criterion = LossAveragingCriterion(self._hp)
        self.build_matcher()

        if self.predict_fraction:
            # TODO implement the inference side version of this
            # TODO put this inside the matcher
            input_size = hp.nz_enc * 2 if hp.timestep_cond_attention else hp.nz_enc * 3
            self.fraction_pred = Predictor(hp, input_size, output_dim=1, spatial=False,
                                           final_activation=nn.Sigmoid())

        if self._hp.regress_index:
            self.index_predictor = Predictor(
                self._hp, self._hp.nz_enc * 2, self._hp.max_seq_len, detached=False, spatial=False)
Exemplo n.º 10
0
    def __init__(self, hp, regress_actions):
        """Build the decoder network and, optionally, an action-regression head.

        Args:
            hp: hyperparameter container.
            regress_actions: when truthy, also build ``act_net`` (``nz_enc`` to
                ``n_actions``), the ``act_log_sigma`` parameter and its updater.
        """
        # ``_hp`` is stored before the base constructor runs, presumably because
        # ``build_decoder_net`` reads it (TODO confirm).
        self._hp = hp
        super().__init__(hp, self.build_decoder_net())

        self.regress_actions = regress_actions
        if not regress_actions:
            return

        self.act_net = Predictor(hp, hp.nz_enc, hp.n_actions)
        self.act_log_sigma = get_constant_parameter(0, hp.learn_beta)
        # The 20 is passed straight to ConstantUpdater; its meaning is defined there.
        self.act_sigma_updater = ConstantUpdater(self.act_log_sigma, 20, 'decoder_action_sigma')
Exemplo n.º 11
0
    def build_network(self):
        """Create the matching temperature, its optional decay schedule, the distance predictor and the loss.

        ``self.temp`` is a 1-element parameter initialised to ``hp.matching_temp``.
        It is frozen unless ``hp.learn_matching_temp`` is set. When
        ``hp.matching_temp_tenthlife`` is not ``-1``, an ``ExponentialDecayUpdater``
        is attached; the temperature must then not be learned.
        """
        hp = self._hp
        self.temp = nn.Parameter(hp.matching_temp * torch.ones(1))
        if not hp.learn_matching_temp:
            self.temp.requires_grad_(False)

        tenthlife = hp.matching_temp_tenthlife
        if tenthlife != -1:
            # A scheduled decay and a learned temperature are mutually exclusive.
            assert not hp.learn_matching_temp
            self.matching_temp_updater = ExponentialDecayUpdater(
                self.temp, tenthlife, min_limit=hp.matching_temp_min)

        self.distance_predictor = Predictor(hp, hp.nz_enc * 2, 1, spatial=False)
        self.criterion = LossAveragingCriterion(hp)
Exemplo n.º 12
0
    def build_decoder_net(self):
        """Build and return the decoder network selected by the hyperparameters.

        With ``hp.builder.use_convs`` set, returns one of the following, checked in
        this order: ``PixelShiftDecoder``, ``PixelCopyDecoder``, ``ConvDecoder``.
        Otherwise returns an ``AttrDictPredictor`` whose ``'images'`` entry is a
        ``Predictor`` from ``nz_enc`` to ``state_dim`` features.

        Raises:
            AssertionError: if ``pixel_shift_decoder`` and ``add_weighted_pixel_copy``
                are both enabled, or if ``use_skips``, ``add_weighted_pixel_copy`` or
                ``pixel_shift_decoder`` is enabled without convolutions. These are
                ``assert`` statements, so they are skipped under ``python -O``.
        """
        hp = self._hp
        if hp.builder.use_convs:
            # Logical ``and``, not bitwise ``&``. With non-bool truthy flags, ``&`` can be 0
            # (e.g. ``1 & 2``) and let both mutually exclusive options through.
            assert not (hp.add_weighted_pixel_copy and hp.pixel_shift_decoder)
            if hp.pixel_shift_decoder:
                decoder_net = PixelShiftDecoder(hp)
            elif hp.add_weighted_pixel_copy:
                decoder_net = PixelCopyDecoder(hp)
            else:
                decoder_net = ConvDecoder(hp)
        else:
            # These options only make sense for the convolutional decoder.
            assert not hp.use_skips
            assert not hp.add_weighted_pixel_copy
            assert not hp.pixel_shift_decoder
            state_predictor = Predictor(hp,
                                        hp.nz_enc,
                                        hp.state_dim,
                                        num_layers=hp.builder.get_num_layers())
            decoder_net = AttrDictPredictor({'images': state_predictor})

        return decoder_net
Exemplo n.º 13
0
 def build_network(self, build_encoder=True):
     """Build ``self.cost_pred``: ``nz_enc * 2`` input features to one output, with ``detached=True``.

     ``build_encoder`` is accepted but not read here.
     """
     in_dim = self._hp.nz_enc * 2
     self.cost_pred = Predictor(self._hp, in_dim, 1, detached=True)
Exemplo n.º 14
0
 def build_network(self, build_encoder=True):
     """Optionally create the encoder, then an action predictor over concatenated encodings.

     The encoder is built when ``hp.build_encoder`` is set. The predictor input is
     ``nz_enc * 3`` when ``hp.add_lstm_state_enc`` is set, otherwise ``nz_enc * 2``.
     """
     hp = self._hp
     # NOTE(review): the ``build_encoder`` argument is ignored; the hp flag decides.
     if hp.build_encoder:
         self.encoder = Encoder(hp)
     n_chunks = 3 if hp.add_lstm_state_enc else 2
     self.action_pred = Predictor(hp, hp.nz_enc * n_chunks, hp.n_actions)
Exemplo n.º 15
0
 def build_network(self, build_encoder=True):
     """Build ``self.action_pred`` from ``state_dim * 2`` inputs to ``n_actions`` outputs.

     The trailing positional ``3`` is forwarded to ``Predictor`` unchanged.
     ``build_encoder`` is not read here.
     """
     hp = self._hp
     state_pair_dim = hp.state_dim * 2
     self.action_pred = Predictor(hp, state_pair_dim, hp.n_actions, 3)
Exemplo n.º 16
0
 def build_network(self):
     """Build ``self.existence_predictor``: ``nz_enc`` features to a single non-spatial output."""
     hp = self._hp
     self.existence_predictor = Predictor(hp, hp.nz_enc, 1, spatial=False)
Exemplo n.º 17
0
 def build_network(self, input_size, hp):
     """Build ``self.net``: a 1-layer ``Predictor`` from ``input_size`` to ``hp.nz_attn_key`` features."""
     key_dim = hp.nz_attn_key
     self.net = Predictor(hp, input_size, key_dim, num_layers=1)
Exemplo n.º 18
0
 def __init__(self, hp, net):
     """Wrap ``net`` and add ``ac_net``.

     ``ac_net`` maps ``nz_enc + n_actions`` inputs to ``nz_enc`` features.
     """
     super().__init__()
     self.net = net
     joint_dim = hp.nz_enc + hp.n_actions
     self.ac_net = Predictor(hp, joint_dim, hp.nz_enc)
Exemplo n.º 19
0
 def __init__(self, hp):
     """Store ``hp`` and build ``self.p``.

     ``self.p`` maps ``nz_enc * 2`` inputs to ``max_seq_len`` outputs.
     """
     super().__init__()
     self._hp = hp
     in_dim = hp.nz_enc * 2
     self.p = Predictor(hp, in_dim, hp.max_seq_len)