Example #1
    def forward(self, x, pairs_info, pairs_info_augmented, image_id, flag_, phase):
        out1 = self.Conv_pretrain(x)  # out1: (batch, 1024, 25, 25) = (batch, channels, height, width), from input (batch, 3, 400, 400)

        # import pdb; pdb.set_trace()  # debug breakpoint, disabled

        rois_people, rois_objects, spatial_locs, union_box = ROI.get_pool_loc(out1, image_id, flag_, size=pool_size,
                                                                              spatial_scale=25,
                                                                              batch_size=len(pairs_info))
        # rois_people  -> ROI feature maps for every person across the batch; each ROI is
        #                 adaptively pooled to a common size (10, 10)
        # rois_objects -> same as rois_people, but for object boxes
        # spatial_locs -> per box (persons first, then objects): [x1, y1, x2, y2, x2-x1, y2-y1],
        #                 rescaled to the 0..25 feature-map grid via the spatial_scale option
        # union_box    -> for each person-object pair in each batch item, a spatial map that
        #                 is 100 inside the union box and 0 elsewhere
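        # Shape sketch (an assumption based on the notes above: pool_size == (10, 10)
        # and 1024 channels):
        #   rois_people:  (num_persons_total, 1024, 10, 10)
        #   rois_objects: (num_objects_total, 1024, 10, 10)
        #   spatial_locs: one 6-vector per box
        #   union_box:    one 25x25 map per person-object pair (exact channel layout
        #                 depends on ROI.get_pool_loc)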


        ### Defining The Pooling Operations #######
        h, w = out1.size()[2], out1.size()[3]  # renamed from x, y to avoid shadowing the input tensor
        hum_pool = nn.AvgPool2d(pool_size, padding=0, stride=(1, 1))
        obj_pool = nn.AvgPool2d(pool_size, padding=0, stride=(1, 1))
        context_pool = nn.AvgPool2d((h, w), padding=0, stride=(1, 1))
        #################################################
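        # Note: context_pool's kernel spans the full (h, w) map, so it reduces out1 to
        # (batch, 1024, 1, 1) -- effectively global average pooling. Likewise, assuming
        # each ROI map is pool_size x pool_size as noted above, hum_pool and obj_pool
        # collapse each ROI to a single 1x1 response.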
        ### Human###
        residual_people = rois_people
        res_people = self.Conv_people(rois_people) + residual_people
        res_av_people = hum_pool(res_people)
        out2_people = self.flat(res_av_people)
        ###########

        ##Objects##
        residual_objects = rois_objects
        res_objects = self.Conv_objects(rois_objects) + residual_objects
        res_av_objects = obj_pool(res_objects)
        out2_objects = self.flat(res_av_objects)
        #############

        #### Context ######
        residual_context = out1
        res_context = self.Conv_context(out1) + residual_context
        res_av_context = context_pool(res_context)
        out2_context = self.flat(res_av_context)
        #################

        ##Attention Features##
        out2_union = self.spmap_up(self.flat(self.conv_sp_map(union_box)))
        ############################
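        # out2_union encodes each pair's union-box map into a feature vector; it is
        # used multiplicatively below (see interaction_prob) as a spatial attention
        # signal over the pair features.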



        pairs, people, objects_only = ROI.pairing(out2_people, out2_objects, out2_context, spatial_locs, pairs_info)

        pairs = self.test_spatial_to_node_feature(pairs)

        node_feature_cat = ROI.get_node_feature(out2_people, out2_objects, out2_context, pairs_info)
        node_feature = node_feature_cat.reshape(node_feature_cat.size()[0], node_feature_cat.size()[1], 1)
        #node_feature = self.learnable_conv(node_feature)
        node_feature = node_feature.reshape(node_feature.size()[0], node_feature.size()[1])
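        # The reshape to (N, C, 1) and back to (N, C) is a no-op while
        # self.learnable_conv stays commented out; it only matters if that conv
        # is re-enabled.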

        # 210121: experiment -- fuse the spatial (pairs) feature into node_feature.
        # ADD was tried first; the elementwise MUL below is the variant currently tested.
        # import pdb; pdb.set_trace()
        node_feature = node_feature * pairs

        test = self.learnable_matrix(node_feature)
        # test2 = self.learnable_conv(node_feature.unsqueeze(2).unsqueeze(2))
        # test2 = self.learnable_conv_matrix(test2)
        test2 = self.learnable_single(node_feature)

        ###Interaction Prob
        interaction_feature = self.interaction_prob_matrix(node_feature_cat)
        interaction_prob = interaction_feature * out2_union
        interaction_score = self.interaction_prob_value(interaction_prob)
        interaction_score = self.sigmoid(interaction_score)
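        # interaction_score: one sigmoid value in (0, 1) per person-object pair,
        # interpretable as the probability that the pair interacts.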



        return [test, test2, interaction_score] # ,lin_obj_ids]
Example #2
    def forward(self, x, pairs_info, pairs_info_augmented, image_id, flag_,
                phase):
        out1 = self.Conv_pretrain(x)  # (batch, 1024, 25, 25), as in Example #1

        rois_people, rois_objects, spatial_locs, union_box = ROI.get_pool_loc(
            out1,
            image_id,
            flag_,
            size=pool_size,
            spatial_scale=25,
            batch_size=len(pairs_info))
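        # rois_people / rois_objects / spatial_locs / union_box have the same
        # structure as described in Example #1.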

        ### Defining The Pooling Operations #######
        h, w = out1.size()[2], out1.size()[3]  # renamed from x, y to avoid shadowing the input tensor
        hum_pool = nn.AvgPool2d(pool_size, padding=0, stride=(1, 1))
        obj_pool = nn.AvgPool2d(pool_size, padding=0, stride=(1, 1))
        context_pool = nn.AvgPool2d((h, w), padding=0, stride=(1, 1))
        #################################################

        ### Human###
        residual_people = rois_people
        res_people = self.Conv_people(rois_people) + residual_people
        res_av_people = hum_pool(res_people)
        out2_people = self.flat(res_av_people)
        ###########

        ##Objects##
        residual_objects = rois_objects
        res_objects = self.Conv_objects(rois_objects) + residual_objects
        res_av_objects = obj_pool(res_objects)
        out2_objects = self.flat(res_av_objects)
        #############

        #### Context ######
        residual_context = out1
        res_context = self.Conv_context(out1) + residual_context
        res_av_context = context_pool(res_context)
        out2_context = self.flat(res_av_context)
        #################

        ##Attention Features##
        out2_union = self.spmap_up(self.flat(self.conv_sp_map(union_box)))
        ############################

        #### Making Essential Pairing##########
        pairs, people, objects_only = ROI.pairing(out2_people, out2_objects,
                                                  out2_context, spatial_locs,
                                                  pairs_info)
        ####################################

        ###### Interaction Probability##########
        lin_single_h = self.lin_single_head(pairs)
        lin_single_t = lin_single_h * out2_union
        lin_single = self.lin_single_tail(lin_single_t)
        interaction_prob = self.sigmoid(lin_single)
        ####################################################
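        # interaction_prob holds one probability per person-object pair; below it is
        # sliced per image and reused as the soft adjacency of the graph.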

        ####### Graph Model Base Structure##################
        people_t = people
        combine_g = []
        people_f = []
        objects_f = []
        pairs_f = []
        start_p = 0
        start_o = 0
        start_c = 0
        for batch_num, l in enumerate(pairs_info):
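            # Inferred layout of pairs_info: l[0] = number of persons in this image,
            # l[1] = number of object slots (slot 0 is the no-object feature), so the
            # image contributes l[0] * l[1] person-object pairs.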

            ####Slicing##########
            people_this_batch = people_t[start_p:start_p + int(l[0])]
            no_peo = len(people_this_batch)
            objects_this_batch = objects_only[start_o:start_o + int(l[1])][1:]
            no_objects_this_batch = objects_only[start_o:start_o +
                                                 int(l[1])][0]
            no_obj = len(objects_this_batch)
            interaction_prob_this_batch = interaction_prob[start_c:start_c +
                                                           int(l[1]) *
                                                           int(l[0])]
            if no_obj == 0:

                people_this_batch_r = people_this_batch

                objects_this_batch_r = no_objects_this_batch.view([1, 1024])
            else:
                peo_to_obj_this_batch = torch.stack([
                    torch.cat((i, j))
                    for ind_p, i in enumerate(people_this_batch)
                    for ind_o, j in enumerate(objects_this_batch)
                ])
                obj_to_peo_this_batch = torch.stack([
                    torch.cat((i, j))
                    for ind_p, i in enumerate(objects_this_batch)
                    for ind_o, j in enumerate(people_this_batch)
                ])
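                # Note: peo_to_obj_this_batch / obj_to_peo_this_batch are built here
                # but never used below; the refinement relies on torch.mm instead.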
                ###################

                ####### Adjacency ###########
                adj_l = []
                adj_po = torch.zeros([no_peo, no_obj]).cuda()
                adj_op = torch.zeros([no_obj, no_peo]).cuda()

                for index_probs, probs in enumerate(
                        interaction_prob_this_batch):
                    if index_probs % (no_obj + 1) != 0:
                        adj_l.append(probs)

                adj_po = torch.cat(adj_l).view(len(adj_l), 1)
                adj_op = adj_po
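                # adj_l keeps only the probs of real objects (every (no_obj + 1)-th
                # entry is the no-object pair and is skipped). The zero-initialised
                # adj_po / adj_op above are overwritten here; adj_op aliases adj_po
                # and is transposed via .view().t() below.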
                ##############################

                ###Finding Out Refined Features######

                people_this_batch_r = people_this_batch + torch.mm(
                    adj_po.view([no_peo, no_obj]),
                    self.peo_to_obj_w(objects_this_batch))

                objects_this_batch_r = objects_this_batch + torch.mm(
                    adj_op.view([no_peo, no_obj]).t(),
                    self.obj_to_peo_w(people_this_batch))
                objects_this_batch_r = torch.cat(
                    (no_objects_this_batch.view([1,
                                                 1024]), objects_this_batch_r))
            #############################
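            # One step of GCN-style message passing: each person feature is refined by
            # an interaction-probability-weighted sum of transformed object features,
            # and vice versa; the no-object feature is re-attached at index 0.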

            #### Restructuring ####
            people_f.append(people_this_batch_r)
            people_t_f = people_this_batch_r
            objects_f.append(objects_this_batch_r)
            objects_t_f = objects_this_batch_r

            pairs_f.append(
                torch.stack([
                    torch.cat((i, j)) for ind_p, i in enumerate(people_t_f)
                    for ind_o, j in enumerate(objects_t_f)
                ]))

            #import pdb;pdb.set_trace()
            ##############################

            ###Loop increment for next batch##
            start_p += int(l[0])
            start_o += int(l[1])
            start_c += int(l[0]) * int(l[1])
            #####################

        people_graph = torch.cat(people_f)
        objects_graph = torch.cat(objects_f)
        pairs_graph = torch.cat(pairs_f)
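        # people_graph and objects_graph gather the refined node features across the
        # batch but are not used in the return below; pairs_graph feeds the graph head.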
        ######################################################################################################################################

        #### Prediction from visual features####
        lin_h = self.lin_visual_head(pairs)
        lin_t = lin_h * out2_union
        lin_visual = self.lin_visual_tail(lin_t)
        ##############################

        #### Prediction from graph features ####

        lin_graph_h = self.lin_graph_head(pairs_graph)
        lin_graph_t = lin_graph_h * out2_union
        lin_graph = self.lin_graph_tail(lin_graph_t)

        ####################################

        ##### Prediction from attention features #######
        lin_att = self.lin_spmap_tail(out2_union)
        #############################

        return [lin_visual, lin_single, lin_graph, lin_att]  #,lin_obj_ids]
Example #3
    def forward(self, x, pairs_info, pairs_info_augmented, image_id, flag_,
                phase):

        out1 = self.Conv_pretrain(x)

        rois_people, rois_objects, spatial_locs, union_box = ROI.get_pool_loc(
            out1,
            image_id,
            flag_,
            size=pool_size,
            spatial_scale=25,
            batch_size=len(pairs_info))

        ### Defining The Pooling Operations #######
        h, w = out1.size()[2], out1.size()[3]  # renamed from x, y to avoid shadowing the input tensor
        hum_pool = nn.AvgPool2d(pool_size, padding=0, stride=(1, 1))
        obj_pool = nn.AvgPool2d(pool_size, padding=0, stride=(1, 1))
        context_pool = nn.AvgPool2d((h, w), padding=0, stride=(1, 1))
        #################################################

        ### Human###
        residual_people = rois_people
        res_people = self.Conv_people(rois_people) + residual_people
        res_av_people = hum_pool(res_people)
        out2_people = self.flat(res_av_people)
        ###########

        ##Objects##
        residual_objects = rois_objects
        res_objects = self.Conv_objects(rois_objects) + residual_objects
        res_av_objects = obj_pool(res_objects)
        out2_objects = self.flat(res_av_objects)
        #############

        #### Context ######
        residual_context = out1
        res_context = self.Conv_context(out1) + residual_context
        res_av_context = context_pool(res_context)
        out2_context = self.flat(res_av_context)
        #################

        ## Attention features ##
        a_ho = self.W_Spat(self.flat(self.Conv_spatial(union_box)))

        ### Making Essential Pairing##########
        pairs, people, objects_only = ROI.pairing(out2_people, out2_objects,
                                                  out2_context, spatial_locs,
                                                  pairs_info)
        ######################################

        ### Interaction Probability ########
        f_Vis = self.W_vis(pairs)
        f_Ref = f_Vis * a_ho
        i_ho = self.W_IP(f_Ref)
        interaction_prob = self.sigmoid(i_ho)
        p_Ref = self.W_Ref(f_Ref)
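        # In short, following the code above: f_Ref = W_vis(pairs) * a_ho,
        # i_ho = W_IP(f_Ref), interaction_prob = sigmoid(i_ho), p_Ref = W_Ref(f_Ref).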

        ### Prediction from attention features

        p_Att = self.W_Att(a_ho)

        ### Graph model base structure
        people_t = people
        combine_g = []
        people_f = []
        objects_f = []
        pairs_f = []
        start_p = 0
        start_o = 0
        start_c = 0

        for batch_num, l in enumerate(pairs_info):

            ### Slicing ###
            people_this_batch = people_t[start_p:start_p + int(l[0])]
            num_peo = len(people_this_batch)

            objects_this_batch = objects_only[start_o:start_o + int(l[1])][1:]
            # (index 0 holds the no-object feature, hence the [1:] above)
            no_objects_this_batch = objects_only[start_o:start_o +
                                                 int(l[1])][0]
            num_obj = len(objects_this_batch)

            interaction_prob_this_batch = interaction_prob[start_c:start_c +
                                                           int(l[0]) * int(l[1])]

            if num_obj == 0:
                people_this_batch_r = people_this_batch  # suffix _r means "refined"
                objects_this_batch_r = no_objects_this_batch.view([1, 1024])
            else:
                peo_to_obj_this_batch = torch.stack([
                    torch.cat((i, j))
                    for ind_p, i in enumerate(people_this_batch)
                    for ind_o, j in enumerate(objects_this_batch)
                ])
                obj_to_peo_this_batch = torch.stack([
                    torch.cat((i, j))
                    for ind_p, i in enumerate(objects_this_batch)
                    for ind_o, j in enumerate(people_this_batch)
                ])
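                # As in Example #2, these stacked pairings are computed but never
                # used; the refinement below relies on torch.mm with the adjacency.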

                ###################

                ## Adjacency ###
                adj_l = []
                adj_po = torch.zeros([num_peo, num_obj]).cuda()
                adj_op = torch.zeros([num_obj, num_peo]).cuda()

                for index_probs, probs in enumerate(
                        interaction_prob_this_batch):
                    if index_probs % (num_obj + 1) != 0:
                        adj_l.append(probs)

                # NOTE(author): unsure whether gradients flow through this cat/view
                adj_po = torch.cat(adj_l).view(len(adj_l), 1)
                adj_op = adj_po

                ### Finding out Refined features ###

                people_this_batch_r = people_this_batch + torch.mm(
                    adj_po.view([num_peo, num_obj]),
                    self.W_oh(objects_this_batch))

                objects_this_batch_r = objects_this_batch + torch.mm(
                    adj_op.view([num_peo, num_obj]).t(),
                    self.W_ho(people_this_batch))
                objects_this_batch_r = torch.cat(
                    (no_objects_this_batch.view([1, 1024]),
                     objects_this_batch_r))

            ### Reconstructing ###
            people_f.append(people_this_batch_r)
            people_t_f = people_this_batch_r
            objects_f.append(objects_this_batch_r)
            objects_t_f = objects_this_batch_r

            pairs_f.append(
                torch.stack([
                    torch.cat((i, j)) for ind_p, i in enumerate(people_t_f)
                    for ind_o, j in enumerate(objects_t_f)
                ]))

            ## loop increment for next batch
            start_p += int(l[0])
            start_o += int(l[1])
            start_c += int(l[0]) * int(l[1])

        people_graph = torch.cat(people_f)
        objects_graph = torch.cat(objects_f)
        pairs_graph = torch.cat(pairs_f)

        p_Graph = self.W_graph(pairs_graph)
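        # Returned heads, per the code above: i_ho (interaction logits), p_Ref
        # (prediction from refined visual features), p_Att (prediction from
        # attention features), p_Graph (prediction from graph features).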

        return i_ho, p_Ref, p_Att, p_Graph