def forward(self,
                img_feat,
                navigable_feat,
                pre_feat,
                h_0,
                c_0,
                ctx,
                navigable_index=None,
                ctx_mask=None):
        """ Takes a single step in the decoder LSTM.

        img_feat: batch x 36 x feature_size
        navigable_feat: batch x max_navigable x feature_size
        h_0: batch x hidden_size
        c_0: batch x hidden_size
        ctx: batch x seq_len x dim
        navigable_index: list of list
        ctx_mask: batch x seq_len - indices to be masked
        """
        batch_size, num_imgs, feat_dim = img_feat.size()

        # add 1 because navigable_index does not yet count the "stay" location,
        # while navigable_feat does include it at [:, 0, :]
        index_length = [len(_index) + 1 for _index in navigable_index]
        navigable_mask = create_mask(batch_size, self.max_navigable,
                                     index_length)

        proj_img_feat = proj_masking(img_feat, self.proj_img_mlp)

        proj_navigable_feat = proj_masking(navigable_feat,
                                           self.proj_navigable_mlp,
                                           navigable_mask)

        weighted_img_feat, _ = self.soft_attn(self.h0_fc(h_0), proj_img_feat,
                                              img_feat)

        concat_input = torch.cat((pre_feat, weighted_img_feat), 1)

        h_1, c_1 = self.lstm(self.dropout(concat_input), (h_0, c_0))

        h_1_drop = self.dropout(h_1)

        # use attention on language instruction
        weighted_context, ctx_attn = self.soft_attn(self.h1_fc(h_1_drop),
                                                    self.dropout(ctx),
                                                    mask=ctx_mask)
        h_tilde = self.proj_out(weighted_context)

        logit = torch.bmm(proj_navigable_feat, h_tilde.unsqueeze(2)).squeeze(2)

        return h_1, c_1, ctx_attn, logit, navigable_mask
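Every example on this page leans on two small helpers, create_mask and proj_masking, whose definitions are not shown. Below is a minimal sketch of what they are assumed to do, batch-wise length masking and masked feature projection; the exact signatures are assumptions, not the original implementations.

import torch


def create_mask(batch_size, max_length, length):
    # assumed behavior: 1.0 for valid slots up to each sample's length, 0.0 elsewhere
    mask = torch.zeros(batch_size, max_length)
    for i, l in enumerate(length):
        mask[i, :l] = 1.0
    return mask


def proj_masking(feat, projector, mask=None):
    # assumed behavior: run each (batch, num, feat_dim) slot through the projection MLP,
    # then zero out the slots that the mask marks as invalid
    proj_feat = projector(feat.view(-1, feat.size(2))).view(feat.size(0), feat.size(1), -1)
    if mask is not None:
        return proj_feat * mask.unsqueeze(2).expand_as(proj_feat)
    return proj_feat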
Example #2
    def forward(self, img_feat, navigable_feat, pre_feat, question, h_0, c_0, ctx, pre_ctx_attend,
                s_0, r_t, navigable_index=None, ctx_mask=None):

    #def forward(self, img_feat, navigable_feat, pre_feat, question, h_0, c_0, ctx, pre_ctx_attend,\
               # navigable_index=None, ctx_mask=None, s_0, r_t, config_embedding):

        """ Takes a single step in the decoder LSTM.
        config_embedding: batch x max_config_len x config embedding
        image_feature: batch x 12 images  x image_feature_size
        navigable_index: list of navigable viewstates
        h_t: batch x hidden_size
        c_t: batch x hidden_size
        ctx_mask: batch x seq_len - indices to be masked
        """
        batch_size, num_imgs, feat_dim = img_feat.size()

        index_length = [len(_index) + self.num_predefined_action for _index in navigable_index]
        navigable_mask = create_mask(batch_size, self.max_navigable, index_length)

        proj_navigable_feat = proj_masking(navigable_feat, self.proj_navigable_mlp, navigable_mask)
        proj_pre_feat = self.proj_navigable_mlp(pre_feat)

        weighted_img_feat, img_attn = self.soft_attn(self.h0_fc(h_0), proj_navigable_feat, mask=navigable_mask)

        
        if r_t is None:
            r_t = self.r_linear(torch.cat((weighted_img_feat, h_0), dim=1)) 
            r_t = self.sm(r_t)
        
        

        weighted_ctx, ctx_attn = self.state_attention(s_0, r_t, self.config_fc(ctx), ctx_mask)
        # positioned_ctx = self.lang_position(self.config_fc(ctx))

        # weighted_ctx, ctx_attn = self.soft_attn(self.h1_fc(h_0), positioned_ctx, mask=ctx_mask)

        # merge info into one LSTM to be carried through time
        concat_input = torch.cat((proj_pre_feat, weighted_img_feat, weighted_ctx), 1)

        h_1, c_1 = self.lstm(concat_input, (h_0, c_0))
        h_1_drop = self.dropout(h_1)

        # policy network
        h_tilde = self.logit_fc(torch.cat((weighted_ctx, h_1_drop), dim=1))
        logit = torch.bmm(proj_navigable_feat, h_tilde.unsqueeze(2)).squeeze(2)

        # value estimation
        concat_value_input = self.h2_fc_lstm(torch.cat((h_0, weighted_img_feat), 1))

        h_1_value = self.dropout(torch.sigmoid(concat_value_input) * torch.tanh(c_1))

        value = self.critic(torch.cat((ctx_attn, h_1_value), dim=1))
    
        return h_1, c_1, weighted_ctx, img_attn, ctx_attn, logit, value, navigable_mask
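self.soft_attn appears in every example as a soft dot-product attention: a projected hidden state is the query, a set of (projected) feature vectors is the context, and an optional mask blocks padded or non-navigable slots. A minimal sketch follows, assuming the usual SoftDotAttention formulation from the R2R-style codebases; the module and argument names are assumptions.

import torch
import torch.nn as nn


class SoftDotAttention(nn.Module):
    """Soft dot-product attention over a set of (projected) context vectors."""

    def __init__(self, dim):
        super().__init__()
        self.linear_in = nn.Linear(dim, dim, bias=False)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, h, proj_context, context=None, mask=None):
        # h: (batch, dim) query; proj_context: (batch, n, dim); mask: (batch, n), 1 = valid
        target = self.linear_in(h).unsqueeze(2)                 # (batch, dim, 1)
        attn = torch.bmm(proj_context, target).squeeze(2)       # (batch, n)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -float('inf'))   # drop padded / non-navigable slots
        attn = self.softmax(attn)
        # weight the raw context if it is given (e.g. unprojected img_feat), else the projected one
        weighted = torch.bmm(attn.unsqueeze(1),
                             context if context is not None else proj_context).squeeze(1)
        return weighted, attn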
Example #3
    def forward(self, img_feat, navigable_feat, pre_feat, question, h_0, c_0, ctx, pre_ctx_attend,
               s0, r0, navigable_index=None, ctx_mask=None):
        """ Takes a single step in the decoder

        img_feat: batch x 36 x feature_size
        navigable_feat: batch x max_navigable x feature_size

        pre_feat: previous attended feature, batch x feature_size

        question: this should be a single vector representing instruction

        ctx: batch x seq_len x dim
        navigable_index: list of list
        ctx_mask: batch x seq_len - indices to be masked
        """
        batch_size, num_imgs, feat_dim = img_feat.size()

        index_length = [len(_index) + self.num_predefined_action for _index in navigable_index]
        navigable_mask = create_mask(batch_size, self.max_navigable, index_length)

        proj_navigable_feat = proj_masking(navigable_feat, self.proj_navigable_mlp, navigable_mask)
        proj_pre_feat = self.proj_navigable_mlp(pre_feat)
        #positioned_ctx = self.lang_position(ctx)

        weighted_ctx, ctx_attn = self.text_soft_attn(self.h1_fc(h_0), ctx, mask=ctx_mask)


        weighted_img_feat, img_attn = self.img_soft_attn(self.h0_fc(h_0), proj_navigable_feat, mask=navigable_mask)

        # if r_t is None:
        #     r_t = self.r_linear(torch.cat((weighted_img_feat, h_0), dim=1)) 
        #     r_t = self.sm(r_t)

        # weighted_ctx, ctx_attn = self.state_attention(s_0, r_t, ctx, ctx_mask)

        # merge info into one LSTM to be carried through time
        concat_input = torch.cat((proj_pre_feat, weighted_img_feat, weighted_ctx), 1)

        h_1, c_1 = self.lstm(concat_input, (h_0, c_0))
        h_1_drop = self.dropout(h_1)

        # policy network
        h_tilde = self.logit_fc(torch.cat((weighted_ctx, h_1_drop), dim=1))
        logit = torch.bmm(proj_navigable_feat, h_tilde.unsqueeze(2)).squeeze(2)

        # value estimation
        concat_value_input = self.h2_fc_lstm(torch.cat((h_0, weighted_img_feat), 1))

        h_1_value = self.dropout(torch.sigmoid(concat_value_input) * torch.tanh(c_1))

        value = self.critic(torch.cat((ctx_attn, h_1_value), dim=1))

        return h_1, c_1, weighted_ctx, img_attn, ctx_attn, logit, value, navigable_mask
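For orientation, the dummy inputs below reproduce the shapes described in the docstrings, along with the navigable mask built at the top of each forward; all sizes are illustrative assumptions, not values from the original code.

import torch

# illustrative input shapes for one decoder step (all sizes are dummies)
batch_size, max_navigable, feat_dim, hidden, seq_len, ctx_dim = 4, 16, 2176, 256, 80, 512
num_predefined_action = 1  # the "stay" action occupies slot 0 of navigable_feat

img_feat = torch.randn(batch_size, 36, feat_dim)
navigable_feat = torch.randn(batch_size, max_navigable, feat_dim)
pre_feat = torch.randn(batch_size, feat_dim)
h_0 = torch.zeros(batch_size, hidden)
c_0 = torch.zeros(batch_size, hidden)
ctx = torch.randn(batch_size, seq_len, ctx_dim)
ctx_mask = torch.ones(batch_size, seq_len)
navigable_index = [[1, 2], [1], [1, 2, 3], [1, 2]]  # navigable directions per sample

# the mask built at the top of each forward: the predefined action(s) plus each
# sample's navigable directions count as valid slots
index_length = [len(idx) + num_predefined_action for idx in navigable_index]
navigable_mask = torch.zeros(batch_size, max_navigable)
for i, l in enumerate(index_length):
    navigable_mask[i, :l] = 1.0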
Example #4
    def forward(self, img_feat, navigable_feat, pre_feat, h_0, c_0, ctx, navigable_index=None, ctx_mask=None):
        """ Takes a single step in the decoder LSTM.

        img_feat: batch x 36 x feature_size
        navigable_feat: batch x max_navigable x feature_size
        h_0: batch x hidden_size
        c_0: batch x hidden_size
        ctx: batch x seq_len x dim
        navigable_index: list of list
        ctx_mask: batch x seq_len - indices to be masked
        """
        batch_size, num_imgs, feat_dim = img_feat.size()

        index_length = [len(_index) + self.num_predefined_action for _index in navigable_index]
        navigable_mask = create_mask(batch_size, self.max_navigable, index_length)

        proj_navigable_feat = proj_masking(navigable_feat, self.proj_navigable_mlp, navigable_mask)
        proj_pre_feat = self.proj_navigable_mlp(pre_feat)
        positioned_ctx = self.lang_position(ctx)

        weighted_ctx, ctx_attn = self.soft_attn(self.h1_fc(h_0), positioned_ctx, mask=ctx_mask)

        weighted_img_feat, img_attn = self.soft_attn(self.h0_fc(h_0), proj_navigable_feat, mask=navigable_mask)

        # merge info into one LSTM to be carried through time
        concat_input = torch.cat((proj_pre_feat, weighted_img_feat, weighted_ctx), 1)

        h_1, c_1 = self.lstm(concat_input, (h_0, c_0))
        h_1_drop = self.dropout(h_1)

        # policy network
        h_tilde = self.logit_fc(torch.cat((weighted_ctx, h_1_drop), dim=1))
        logit = torch.bmm(proj_navigable_feat, h_tilde.unsqueeze(2)).squeeze(2)

        # value estimation
        concat_value_input = self.h2_fc_lstm(torch.cat((h_0, weighted_img_feat), 1))

        h_1_value = self.dropout(torch.sigmoid(concat_value_input) * torch.tanh(c_1))

        value = self.critic(torch.cat((ctx_attn, h_1_value), dim=1))

        return h_1, c_1, weighted_ctx, img_attn, ctx_attn, logit, value, navigable_mask
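Example #4 (and the regretful-agent step below) passes the instruction context through self.lang_position / self.positional_encoding before textual grounding. A minimal sketch follows, assuming the standard sinusoidal positional encoding; the dimension names and dropout are assumptions.

import math
import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding added to the instruction embeddings."""

    def __init__(self, d_model, max_len=80, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)   # assumes an even d_model
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))    # (1, max_len, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        return self.dropout(x + self.pe[:, :x.size(1)])

Example #5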
    def forward(self,
                img_feat,
                navigable_feat,
                pre_feat,
                pre_value,
                h_0,
                c_0,
                ctx,
                navigable_index=None,
                navigable_idx_to_previous=None,
                oscillation_index=None,
                block_oscillation_index=None,
                ctx_mask=None,
                seq_lengths=None,
                prevent_oscillation=False,
                prevent_rollback=False,
                is_training=None):
        """
        forward passing the network of regretful agent

        :param img_feat: (batch_size, 36, d-dim feat)
        :param navigable_feat: (batch_size, max number of navigable direction, d-dim)
        :param pre_feat: (batch_size, d-dim)
        :param pre_value: (batch_size, 1)
        :param h_0: (batch_size, d-dim)
        :param c_0: (batch_size, d-dim)
        :param ctx: (batch_size, max instruction length, d-dim)
        :param navigable_index: list of list, index for navigable directions for each sample in the mini-batch
        :param navigable_idx_to_previous: list
        :param oscillation_index: list
        :param block_oscillation_index: list
        :param ctx_mask: (batch_size, max instruction length)
        :param seq_lengths: list
        :param prevent_oscillation: 1 or 0
        :param prevent_rollback: 1 or 0
        :param is_training: True or False
        """

        batch_size, num_imgs, feat_dim = img_feat.size()

        # creating a mask to block out non-navigable directions, due to batch processing
        index_length = [
            len(_index) + self.num_predefined_action
            for _index in navigable_index
        ]
        navigable_mask = create_mask(batch_size, self.max_navigable,
                                     index_length)

        # prevent rollback action as a sanity check. See Table 3 in the paper.
        if prevent_rollback and not is_training:
            navigable_mask[torch.LongTensor(range(batch_size)), np.array(oscillation_index) + 1] = \
                navigable_mask[torch.LongTensor(range(batch_size)), np.array(oscillation_index) + 1] * (
                            1 - (torch.Tensor(index_length) > 2)).float().to(self.device)

        # block the navigable direction that leads to oscillation
        if 1 in block_oscillation_index and prevent_oscillation:
            navigable_mask = self.block_oscillation(batch_size, navigable_mask,
                                                    oscillation_index,
                                                    block_oscillation_index)

        # get navigable features without attached markers for visual grounding
        navigable_feat_no_visited = navigable_feat[:, :, :-self.opts.tiled_len]
        proj_navigable_feat = proj_masking(navigable_feat_no_visited,
                                           self.proj_navigable_mlp,
                                           navigable_mask)
        proj_pre_feat = self.proj_navigable_mlp(
            pre_feat[:, :-self.opts.tiled_len])
        weighted_img_feat, img_attn = self.soft_attn(self.h0_fc(h_0),
                                                     proj_navigable_feat,
                                                     mask=navigable_mask)

        # positional encoding instruction embeddings and textual grounding
        positioned_ctx = self.positional_encoding(ctx)
        weighted_ctx, ctx_attn = self.soft_attn(self.h1_fc(h_0),
                                                positioned_ctx,
                                                mask=ctx_mask)

        # merge info into one LSTM to be carried through time
        concat_input = torch.cat(
            (proj_pre_feat, weighted_img_feat, weighted_ctx), 1)
        h_1, c_1 = self.lstm(concat_input, (h_0, c_0))
        h_1_drop = self.dropout(h_1)

        # =========== forward and rollback embeddings ===========
        m_forward = self.logit_fc(torch.cat((weighted_ctx, h_1_drop), dim=1))
        m_rollback = proj_navigable_feat[torch.LongTensor(range(batch_size)),
                                         np.array(navigable_idx_to_previous) +
                                         1, :]

        # =========== Progress Monitor ===========
        concat_value_input = self.h2_fc_lstm(
            torch.cat((h_0, weighted_img_feat), 1))
        h_1_value = self.dropout(
            torch.sigmoid(concat_value_input) * torch.tanh(c_1))
        critics_input = torch.cat((ctx_attn, h_1_value), dim=1)

        if self.opts.monitor_sigmoid:
            value = self.sigmoid(self.critic_fc(critics_input))
        else:
            value = self.tanh(self.critic_fc(critics_input))

        # =========== Progress Marker ===========
        value_detached = value.detach()
        value_for_marker = value_detached.unsqueeze(1).repeat(1, self.max_navigable, self.opts.tiled_len) * \
                           navigable_mask.unsqueeze(2).repeat(1, 1, self.opts.tiled_len)

        navigable_visited_feat = value_for_marker - navigable_feat[:, :, -self.opts.tiled_len:]

        proj_navigable_feat_visited = torch.cat(
            (proj_navigable_feat, navigable_visited_feat), dim=2)

        # =========== Regret Module ===========
        rollback_forward = torch.cat(
            (m_rollback.unsqueeze(1), m_forward.unsqueeze(1)), dim=1)
        rollback_forward_logit = self.critic_valueDiff_fc(value_detached -
                                                          pre_value)
        rollback_forward_attn = self.softmax(rollback_forward_logit)
        m_forward_rollback = torch.bmm(rollback_forward_attn.unsqueeze(1),
                                       rollback_forward).squeeze(1)

        # =========== Action selection with Progress Marker ===========
        logit = torch.bmm(
            proj_navigable_feat_visited,
            self.move_fc(m_forward_rollback).unsqueeze(2)).squeeze(2)

        # Scene Recognition Auxiliary task
        vp_class = self.viewpoint_fc(h_1_drop)

        return h_1, c_1, img_attn, ctx_attn, rollback_forward_attn, logit, rollback_forward_logit, value, navigable_mask, vp_class
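The regret module above gates between a forward embedding and a rollback embedding using the change in the progress estimate. The standalone sketch below replays that computation with dummy tensors to make the shapes explicit; the Linear layer merely stands in for self.critic_valueDiff_fc.

import torch
import torch.nn as nn

batch, dim = 2, 256
m_forward = torch.randn(batch, dim)     # candidate "keep going" embedding from h_1
m_rollback = torch.randn(batch, dim)    # embedding of the direction leading back
value = torch.rand(batch, 1)            # current progress estimate
pre_value = torch.rand(batch, 1)        # previous progress estimate

critic_valueDiff_fc = nn.Linear(1, 2)   # stand-in for self.critic_valueDiff_fc
rollback_forward = torch.stack((m_rollback, m_forward), dim=1)            # (batch, 2, dim)
rollback_forward_logit = critic_valueDiff_fc(value - pre_value)           # (batch, 2)
rollback_forward_attn = torch.softmax(rollback_forward_logit, dim=1)
m_forward_rollback = torch.bmm(rollback_forward_attn.unsqueeze(1),
                               rollback_forward).squeeze(1)               # (batch, dim)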
Example #6
    def forward(self, navigable_img_feat, navigable_obj_feat, pre_feat, question, h_0, c_0, ctx, pre_ctx_attend, \
                s_0, r_t, navigable_index, ctx_mask):

        """ Takes a single step in the decoder LSTM.
        config_embedding: batch x max_config_len x config embedding
        image_feature: batch x 12 images x 36 boxes x image_feature_size
        navigable_index: list of navigable viewstates
        h_t: batch x hidden_size
        c_t: batch x hidden_size
        ctx_mask: batch x seq_len - indices to be masked
        """
        # input of image_feature should be changed
    
        
        batch_size, num_heading, num_object, object_feat_dim = navigable_obj_feat.size()
        navigable_obj_feat = navigable_obj_feat.view(batch_size, num_heading*num_object, object_feat_dim) #4 x 16*36 x 152
        index_length = [len(_index)+1 for _index in navigable_index]
        
        navigable_mask = create_mask(batch_size, self.max_navigable, index_length)
        navigable_obj_mask = create_mask_for_object(batch_size, self.max_navigable*num_object, index_length) #batch x 16*36

        proj_navigable_obj_feat = proj_masking(navigable_obj_feat, self.proj_navigable_obj_mlp, navigable_obj_mask) # batch x 16*36 x 152 -> batch x 16*36 x 128
        proj_navigable_feat = proj_masking(navigable_img_feat, self.proj_navigable_img_mlp, navigable_mask)

        proj_pre_feat = self.proj_navigable_img_mlp(pre_feat)

        # first, soft attention over the navigable features provides the input for r_t
        pre_weighted_img_feat, img_attn = self.soft_attn(self.h0_fc(h_0), proj_navigable_feat, mask=navigable_mask)
      #  weighted_obj_feat, obj_attn = self.soft_attn(self.h0_fc(h_0), proj_navigable_obj_feat, mask=navigable_obj_mask) # batch x 128

        # if r_t is None:
        #     r_t = self.r_linear(torch.cat((weighted_obj_feat, h_0), dim=1)) [stop, 1/4, 1/2, go]
        #     r_t = self.sm(r_t)

        # if r_t is None:
        #     r_t = self.r_linear(torch.cat((pre_weighted_img_feat, h_0), dim=1))
        #     r_t = self.sm(r_t)

        # fold the 4-way r_t distribution ([stop, 1/4, 1/2, go]) into the 2-way form
        # expected by state_attention
        r_t0 = torch.matmul(r_t, torch.tensor([1.0, 0.0, 0.75, 0.5], device=r_t.device))
        r_t1 = torch.matmul(r_t, torch.tensor([0.0, 1.0, 0.25, 0.5], device=r_t.device))
        new_r_t = torch.stack([r_t0, r_t1], dim=1)

        weighted_ctx, ctx_attn = self.state_attention(s_0, new_r_t, self.config_fc(ctx), ctx_mask)

        # second, use the selected configuration as the attention query to select objects

        conf_obj_feat = self.config_obj_attention(self.config_atten_linear(weighted_ctx), proj_navigable_obj_feat, navigable_mask) # 4 x 16 x 128
        weighted_conf_obj_feat, conf_obj_attn = self.soft_attn(self.h0_fc(h_0), conf_obj_feat, mask=navigable_mask) # 4 x 128
        weighted_img_feat = torch.bmm(conf_obj_attn.unsqueeze(dim=1), self.image_linear(navigable_img_feat)).squeeze(dim=1)# batch x 2176

        # third use the conf_obj_attn to select the image

        # obj_attn = obj_attn.view(batch_size, num_heading, num_object) # batch x 36 x 16
        # obj_attn = torch.sum(obj_attn, dim=2) # batch x 16
        # weighted_img_feat = torch.bmm(obj_attn.unsqueeze(dim=1), self.image_linear(navigable_img_feat)).squeeze(dim=1)# batch x 2176


        # fourth, use the selected object attention to pick the corresponding image
        

        concat_input = torch.cat((proj_pre_feat, weighted_img_feat, weighted_ctx), 1)

        h_1, c_1 = self.lstm(concat_input, (h_0, c_0))
        h_1_drop = self.dropout(h_1)

        # policy network
        h_tilde = self.logit_fc(torch.cat((weighted_ctx, h_1_drop), dim=1))
        logit = torch.bmm(proj_navigable_feat, h_tilde.unsqueeze(2)).squeeze(2)

        # value estimation
        concat_value_input = self.h2_fc_lstm(torch.cat((h_0, weighted_img_feat), 1))

        h_1_value = self.dropout(torch.sigmoid(concat_value_input) * torch.tanh(c_1))

        value = self.critic(torch.cat((ctx_attn, h_1_value), dim=1))
    
        return h_1, c_1, weighted_ctx, conf_obj_attn, ctx_attn, logit, value, navigable_mask
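This example also calls create_mask_for_object to build a box-level mask over max_navigable * num_object slots. Its implementation is not shown, and since only three arguments are passed, the per-view box count must be fixed or inferred inside it; the sketch below is purely an assumption.

import torch


def create_mask_for_object(batch_size, max_slots, length, slots_per_view=36):
    # hypothetical: each navigable view in `length` contributes `slots_per_view`
    # valid object slots, mirroring create_mask at box granularity
    mask = torch.zeros(batch_size, max_slots)
    for i, l in enumerate(length):
        mask[i, :l * slots_per_view] = 1.0
    return mask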
    
        
    def forward(
        self,
        nav_imgs,
        navigable_ang_feat,
        pre_feat,
        question,
        h_0,
        c_0,
        ctx,
        pre_ctx_attend,
        navigable_index=None,
        ctx_mask=None,
    ):
        """ Takes a single step in the decoder

        navigable_feat: batch x max_navigable x feature_size
        nav_img: batch x max_navigable x 3 x H x W

        pre_feat: previous attended feature, batch x feature_size

        question: this should be a single vector representing instruction

        ctx: batch x seq_len x dim
        navigable_index: list of list
        ctx_mask: batch x seq_len - indices to be masked
        """
        batch_size = nav_imgs.shape[0]

        index_length = [
            len(_index) + self.num_predefined_action
            for _index in navigable_index
        ]
        navigable_mask = create_mask(batch_size, self.max_navigable,
                                     index_length)

        # Get nav_feats from FiLM
        nav_imgs = self.resnet(nav_imgs)
        beta_gamma = self.film_gen(h_0)
        nav_imgs = self.film(nav_imgs, beta_gamma)
        navigable_feat = nav_imgs.mean(-1).mean(-1)
        navigable_feat = torch.cat([navigable_feat, navigable_ang_feat], 2)

        proj_navigable_feat = proj_masking(navigable_feat,
                                           self.proj_navigable_mlp,
                                           navigable_mask)
        proj_pre_feat = self.proj_navigable_mlp(pre_feat)
        positioned_ctx = self.lang_position(ctx)

        weighted_ctx, ctx_attn = self.soft_attn(self.h1_fc(h_0),
                                                positioned_ctx,
                                                mask=ctx_mask)

        weighted_img_feat, img_attn = self.soft_attn(self.h0_fc(h_0),
                                                     proj_navigable_feat,
                                                     mask=navigable_mask)

        # merge info into one LSTM to be carried through time
        concat_input = torch.cat(
            (proj_pre_feat, weighted_img_feat, weighted_ctx), 1)

        h_1, c_1 = self.lstm(concat_input, (h_0, c_0))
        h_1_drop = self.dropout(h_1)

        # policy network
        h_tilde = self.logit_fc(torch.cat((weighted_ctx, h_1_drop), dim=1))
        logit = torch.bmm(proj_navigable_feat, h_tilde.unsqueeze(2)).squeeze(2)

        # value estimation
        concat_value_input = self.h2_fc_lstm(
            torch.cat((h_0, weighted_img_feat), 1))

        h_1_value = self.dropout(
            torch.sigmoid(concat_value_input) * torch.tanh(c_1))

        value = self.critic(torch.cat((ctx_attn, h_1_value), dim=1))

        return (
            h_1,
            c_1,
            weighted_ctx,
            img_attn,
            ctx_attn,
            logit,
            value,
            navigable_mask,
            navigable_feat,
        )
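The example above conditions the ResNet feature maps on the hidden state via FiLM (self.film_gen producing per-channel beta/gamma from h_0, self.film applying them). Below is a minimal sketch of the modulation step under those assumptions; the chunk order and broadcasting are guesses, not the original module.

import torch
import torch.nn as nn


class FiLM(nn.Module):
    """Feature-wise linear modulation: out = gamma * x + beta, per channel."""

    def forward(self, x, beta_gamma):
        # x: (batch, max_navigable, C, H, W); beta_gamma: (batch, 2 * C)
        beta, gamma = torch.chunk(beta_gamma, 2, dim=1)
        beta = beta.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)    # broadcast over views and space
        gamma = gamma.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
        return gamma * x + beta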
    def forward(self,
                img_feat,
                navigable_feat,
                pre_feat,
                question,
                object_t,
                place_t,
                h_0,
                c_0,
                ctx,
                pre_ctx_attend,
                navigable_index=None,
                ctx_mask=None):
        """ Takes a single step in the decoder

        img_feat: batch x 36 x feature_size
        navigable_feat: batch x max_navigable x feature_size

        pre_feat: previous attended feature, batch x feature_size

        question: this should be a single vector representing instruction

        ctx: batch x seq_len x dim
        navigable_index: list of list
        ctx_mask: batch x seq_len - indices to be masked
        """

        object_t = object_t.float()
        place_t = place_t.float()

        #print('h_0',h_0.shape)#([4, 256])
        #print('object_t',object_t.shape)#torch.Size([4, 12, 17])

        object_att, alpha_o = self.object_attention_layer(h_0,
                                                          object_t)  # 10, 17
        place_att, alpha_p = self.place_attention_layer(h_0, place_t)  # 10, 10

        #print('object_att,place_att',object_att.shape,place_att.shape)#torch.Size([4, 1, 17]) torch.Size([4, 10])
        concat_o_p = torch.cat((object_att, place_att), 1)
        #print('concat_o_p', concat_o_p.shape)  # 10,12,27

        batch_size, num_imgs, feat_dim = img_feat.size()

        index_length = [
            len(_index) + self.num_predefined_action
            for _index in navigable_index
        ]
        navigable_mask = create_mask(batch_size, self.max_navigable,
                                     index_length)

        proj_navigable_feat = proj_masking(navigable_feat,
                                           self.proj_navigable_mlp,
                                           navigable_mask)
        proj_pre_feat = self.proj_navigable_mlp(
            pre_feat)  # I think this is previous action
        positioned_ctx = self.lang_position(ctx)

        weighted_ctx, ctx_attn = self.soft_attn(self.h1_fc(h_0),
                                                positioned_ctx,
                                                mask=ctx_mask)

        weighted_img_feat, img_attn = self.soft_attn(self.h0_fc(h_0),
                                                     proj_navigable_feat,
                                                     mask=navigable_mask)

        # merge info into one LSTM to be carried through time
        concat_input = torch.cat(
            (proj_pre_feat, weighted_img_feat, weighted_ctx, concat_o_p), 1)
        #print('concat_input',concat_input.shape)#Before : 4, 512 / with concat_o_p : 4, 539

        h_1, c_1 = self.lstm(concat_input, (h_0, c_0))
        h_1_drop = self.dropout(h_1)

        # policy network
        h_tilde = self.logit_fc(torch.cat((weighted_ctx, h_1_drop), dim=1))
        logit = torch.bmm(proj_navigable_feat, h_tilde.unsqueeze(2)).squeeze(2)

        # value estimation
        concat_value_input = self.h2_fc_lstm(
            torch.cat((h_0, weighted_img_feat), 1))

        h_1_value = self.dropout(
            torch.sigmoid(concat_value_input) *
            torch.tanh(c_1))  # h_1_value corresponds to h_t^pm

        value = self.critic(torch.cat(
            (ctx_attn, h_1_value),
            dim=1))  # progress monitor: value corresponds to p_t^pm

        return h_1, c_1, weighted_ctx, img_attn, ctx_attn, logit, value, navigable_mask
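None of the examples shows how the returned logit and navigable_mask are consumed. A hypothetical post-processing step, masking invalid directions before choosing an action, is sketched below; the greedy argmax stands in for whatever sampling or teacher forcing the original agents use.

import torch

batch_size, max_navigable = 4, 16
logit = torch.randn(batch_size, max_navigable)          # as returned by a forward above
navigable_mask = torch.zeros(batch_size, max_navigable)
navigable_mask[:, :3] = 1.0                             # e.g. "stay" plus two navigable directions

logit = logit.masked_fill(navigable_mask == 0, -float('inf'))
probs = torch.softmax(logit, dim=1)                     # distribution over navigable directions
action = probs.argmax(dim=1)                            # greedy decoding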