def add_predicates(graph, predicate_data):
    """Populate ``graph`` with the predicates found in ``predicate_data``.

    For every relationship of every image this makes sure that a predicate
    node, a subject entity node and an object entity node exist in ``graph``
    and that the subject node carries a predicate edge
    ``(subject, predicate, object)``.

    Args:
        graph: graph object exposing ``get_predicate_id``,
            ``get_entity_by_name`` and ``get_entity_id``.
        predicate_data: iterable of dicts, each with a ``"relationships"``
            list. Every relationship has a ``"predicate"`` plus ``"subject"``
            and ``"object"`` dicts holding either ``"name"`` or ``"names"``.

    Returns:
        None. ``graph`` is modified in place.
    """

    def _raw_name(entity_info):
        # An entity is labelled either by a single "name" or by a list of
        # "names"; in the latter case the first entry is used.
        if "name" in entity_info:
            return entity_info["name"]
        return entity_info["names"][0]

    for image in predicate_data:
        for relationship_info in image["relationships"]:
            predicate_name = predicate_to_aliases(
                relationship_info["predicate"])
            if graph.get_predicate_id(predicate_name) is None:
                # Constructing the node presumably registers it with `graph`;
                # the id is looked up again further down.
                Predicate_Node(predicate_name, graph)

            raw_subject = _raw_name(relationship_info["subject"])
            raw_object = _raw_name(relationship_info["object"])
            subject_name = entity_to_aliases(raw_subject)
            object_name = entity_to_aliases(raw_object)

            subject_node = graph.get_entity_by_name(subject_name)
            if subject_node is None:
                subject_node = Entity_Node(subject_name, graph)

            object_id = graph.get_entity_id(object_name)
            if object_id is None:
                object_id = Entity_Node(object_name, graph).ID

            predicate_id = graph.get_predicate_id(predicate_name)

            # Only add the edge once per (subject, predicate, object) triple.
            if subject_node.get_predicate_edge(predicate_id,
                                               object_id) is None:
                subject_node.add_predicate_edge(
                    Predicate_Edge(subject_node.ID, predicate_id, object_id))
def add_objects(graph, object_data):
    """Add an entity node to ``graph`` for every object not yet known to it.

    Args:
        graph: graph object exposing ``get_entity_id``.
        object_data: iterable of dicts, each with an ``"objects"`` list whose
            entries hold either ``"name"`` or ``"names"``.

    Returns:
        None. ``graph`` is modified in place.
    """
    for image in object_data:
        for entity in image["objects"]:
            # Note: when only "names" is present the whole value is handed to
            # entity_to_aliases (not just its first element).
            raw_name = entity["name"] if "name" in entity else entity["names"]
            object_name = entity_to_aliases(raw_name)
            if graph.get_entity_id(object_name) is None:
                # Constructing the node presumably registers it with `graph`.
                Entity_Node(object_name, graph)
def add_objects(graph, object_data, entity_counts, min_occurrences):
    """Add entity nodes to ``graph`` for sufficiently frequent objects.

    Args:
        graph: graph object exposing ``get_entity_id``.
        object_data: iterable of dicts, each with an ``"objects"`` list whose
            entries hold either ``"name"`` or ``"names"``.
        entity_counts: mapping from ``str(aliased_name)`` to its occurrence
            count; indexed directly, so a missing key raises unless the
            mapping supplies a default.
        min_occurrences: objects seen fewer times than this are skipped.

    Returns:
        None. ``graph`` is modified in place.
    """
    for image in object_data:
        for entity in image["objects"]:
            # Note: when only "names" is present the whole value is handed to
            # entity_to_aliases (not just its first element).
            raw_name = entity["name"] if "name" in entity else entity["names"]
            object_name = entity_to_aliases(raw_name)
            frequent_enough = (
                entity_counts[str(object_name)] >= min_occurrences)
            if frequent_enough and graph.get_entity_id(object_name) is None:
                # Constructing the node presumably registers it with `graph`.
                Entity_Node(object_name, graph)
def add_attributes(graph, attribute_data, entity_counts, attribute_counts,
                   min_occurrences):
    """Add attribute nodes and entity-to-attribute edges to ``graph``.

    Args:
        graph: graph object exposing ``get_attribute_id`` and
            ``get_entity_by_name``.
        attribute_data: iterable of dicts, each with an ``"attributes"`` list
            of entities. An entity may itself carry an ``"attributes"`` list
            and holds either ``"name"`` or ``"names"``.
        entity_counts: mapping from ``str(aliased_entity_name)`` to its
            occurrence count.
        attribute_counts: mapping from attribute name to its occurrence count.
        min_occurrences: attributes and entities seen fewer times than this
            are skipped.

    Returns:
        None. ``graph`` is modified in place.
    """
    for image in attribute_data:
        for entity in image["attributes"]:
            if "attributes" not in entity:
                continue
            for attribute_name in entity["attributes"]:
                if attribute_counts[attribute_name] < min_occurrences:
                    continue
                # The attribute node is created before the entity frequency
                # check below, so a frequent attribute gets a node even when
                # its entity ends up being skipped.
                if graph.get_attribute_id(attribute_name) is None:
                    Attribute_Node(attribute_name, graph)

                raw_name = (entity["name"]
                            if "name" in entity else entity["names"])
                subject_name = entity_to_aliases(raw_name)
                if entity_counts[str(subject_name)] < min_occurrences:
                    continue

                subject_node = graph.get_entity_by_name(subject_name)
                if subject_node is None:
                    subject_node = Entity_Node(subject_name, graph)

                attribute_id = graph.get_attribute_id(attribute_name)
                # Only add the edge once per (entity, attribute) pair.
                if subject_node.get_attribute_edge(attribute_id) is None:
                    subject_node.add_attribute_edge(
                        Attribute_Edge(subject_node.ID, attribute_id))
Beispiel #5
0
def get_entity_predicate_counts(scene_graph_data):
	"""Tally how often each entity and each predicate occurs.

	Counts are accumulated into ``entity_counts`` and ``predicate_counts``,
	keyed by ``str()`` of the aliased name. Both mappings are free names
	here -- presumably module-level dicts; verify against the full file.
	Nothing is returned.

	Args:
		scene_graph_data: iterable of dicts with an ``"objects"`` list
			(entries hold ``"name"`` or ``"names"``) and a
			``"relationships"`` list (entries hold ``"predicate"``).
	"""
	for image in scene_graph_data:
		# entities first ...
		for entity in image["objects"]:
			raw_entity = entity["name"] if "name" in entity else entity["names"]
			entity_key = str(entity_to_aliases(raw_entity))
			if entity_key not in entity_counts:
				entity_counts[entity_key] = 1
			else:
				entity_counts[entity_key] += 1
		# ... then the predicates of the same image
		for relationship in image["relationships"]:
			predicate_key = str(predicate_to_aliases(relationship["predicate"]))
			if predicate_key not in predicate_counts:
				predicate_counts[predicate_key] = 1
			else:
				predicate_counts[predicate_key] += 1
Beispiel #6
0
def train(semantic_action_graph, parameters, flags, models, dataloaders,
          optimizers, loss_functions, replay_buffer):
    """Train the attribute, predicate and next-object DQNs.

    For every image of the training loader an ``ImageState`` is built (or
    reset, from the second epoch on) from the ground-truth boxes and stepped
    until it reports done. Each step chooses one attribute, predicate and
    next-object action epsilon-greedily, pushes the transition to
    ``replay_buffer`` and, once the buffer can be sampled, fits the three
    main networks on a sampled minibatch. The target networks are synced
    whenever the timestep counter is a multiple of
    ``parameters["target_update_frequency"]``.

    Args:
        semantic_action_graph: graph providing ``attribute_nodes``,
            ``predicate_nodes`` and ``variation_based_traversal``.
        parameters: dict read for ``"num_epochs"``,
            ``"maximum_num_entities_per_image"``, ``"epsilon"``,
            ``"epsilon_end"``, ``"epsilon_anneal_rate"``, ``"batch_size"``,
            ``"discount_factor"`` and ``"target_update_frequency"``.
            ``parameters["epsilon"]`` is decayed in place.
        flags: dict read for ``"skip_thought"`` and
            ``"adaptive_action_sets"``.
        models: dict with ``"im_emb_model"``, ``"skip_thought_encoder"`` and
            the ``"DQN_{next_object,attribute,predicate}_{main,target}"``
            networks (``"model_frcnn"`` is only touched when CUDA is
            available).
        dataloaders: dict; only ``"train"`` is used. It yields
            ``(images, images_orig, gt_scene_graph)`` batches.
        optimizers: dict keyed ``"attribute"``, ``"predicate"``,
            ``"next_object"``.
        loss_functions: dict with the same three keys.
        replay_buffer: object with ``push``, ``can_sample`` and ``sample``.

    Returns:
        None. As a side effect the ground-truth and predicted scene graphs
        of all visited images are pickled to ``image_states.pickle``.
    """
    cuda_available = torch.cuda.is_available()
    print("CUDA Available: " + str(cuda_available))

    # Bind the models unconditionally so the loop below also runs on CPU;
    # previously these names only existed when CUDA was available.
    model_IM_EMB = models["im_emb_model"]
    model_next_object_main = models["DQN_next_object_main"]
    model_next_object_target = models["DQN_next_object_target"]
    model_attribute_main = models["DQN_attribute_main"]
    model_attribute_target = models["DQN_attribute_target"]
    model_predicate_main = models["DQN_predicate_main"]
    model_predicate_target = models["DQN_predicate_target"]
    if cuda_available:
        model_IM_EMB = model_IM_EMB.cuda()
        # The detector is not used below (ground-truth boxes are used
        # instead) but is still moved to the GPU, as before.
        models["model_frcnn"].cuda()
        model_next_object_main = model_next_object_main.cuda()
        model_next_object_target = model_next_object_target.cuda()
        model_attribute_main = model_attribute_main.cuda()
        model_attribute_target = model_attribute_target.cuda()
        model_predicate_main = model_predicate_main.cuda()
        model_predicate_target = model_predicate_target.cuda()

    # keeps track of current scene graphs for images
    image_states = {}
    total_number_timesteps_taken = 0
    number_of_epochs = parameters["num_epochs"]
    data_loader = dataloaders["train"]

    # dictionary for skip-thought: (subject, object) -> recent predicates
    skip_thought_dict = defaultdict(list)

    for epoch in range(number_of_epochs):
        print("Epoch: ", epoch)
        num = -1
        for images, images_orig, gt_scene_graph in data_loader:
            images = torch.autograd.Variable(torch.squeeze(images, 1))
            if cuda_available:
                images = images.cuda()
            # get image features from the image embedding model
            images = model_IM_EMB(images)

            # iterate through images in batch
            for idx in range(images.size(0)):
                num += 1
                print("Image number " + str(num))
                # initializing image state if necessary
                image_name = gt_scene_graph[idx]["image_name"]
                if image_name not in image_states:
                    gt_sg = gt_scene_graph[idx]
                    image_feature = images[idx]
                    max_entities = parameters[
                        "maximum_num_entities_per_image"]

                    # Ground-truth boxes stand in for detector proposals;
                    # every box gets score 1.
                    entity_proposals, entity_scores, entity_classes = [], [], []
                    for obj in gt_sg["labels"]["objects"]:
                        entity_proposals.append([
                            obj["x"], obj["y"], obj["x"] + obj["w"],
                            obj["y"] + obj["h"]
                        ])
                        entity_scores.append(1)
                        if "name" in obj:
                            entity_classes.append(obj["name"])
                        else:
                            entity_classes.append(obj["names"][0])
                    entity_proposals = np.array(
                        entity_proposals)[:max_entities]
                    entity_scores = np.array(entity_scores)[:max_entities]
                    entity_classes = np.array(entity_classes)[:max_entities]

                    # fewer than two entities: skip the image
                    if len(entity_scores) < 2:
                        continue

                    entity_features = []
                    for box in entity_proposals:
                        cropped_entity = torch.autograd.Variable(
                            crop_box(images_orig[idx], box))
                        if cuda_available:
                            cropped_entity = cropped_entity.cuda()
                        entity_features.append(model_IM_EMB(cropped_entity))

                    im_state = ImageState(gt_sg["image_name"], gt_sg,
                                          image_feature, entity_features,
                                          entity_proposals, entity_classes,
                                          entity_scores, semantic_action_graph)
                    im_state.initialize_entities(entity_proposals,
                                                 entity_classes, entity_scores)
                    image_states[image_name] = im_state
                else:
                    # reset image state from last epoch
                    image_states[image_name].reset()

                im_state = image_states[image_name]
                while not im_state.is_done():
                    # compute state vector of image
                    state_vector = create_state_vector(
                        im_state,
                        skip_thought_dict,
                        models["skip_thought_encoder"],
                        semantic_action_graph,
                        use_skip_thought=flags["skip_thought"])
                    subject_id = im_state.current_subject
                    object_id = im_state.current_object
                    if state_vector is None:
                        if im_state.current_subject is None:
                            break
                        # Mark the current subject as explored and let the
                        # state pick a new subject/object pair.
                        im_state.explored_entities.append(
                            im_state.current_subject)
                        im_state.current_subject = None
                        im_state.current_object = None
                        continue

                    # perform variation structured traversal scheme to get
                    # adaptive actions
                    subject_name = entity_to_aliases(
                        im_state.entity_classes[subject_id])
                    object_name = entity_to_aliases(
                        im_state.entity_classes[object_id])
                    subject_bbox = im_state.entity_proposals[subject_id]
                    previously_mined_attributes = im_state.current_scene_graph[
                        "objects"][subject_id]["attributes"]
                    previously_mined_next_objects = (
                        im_state.objects_explored_per_subject[subject_id])

                    if flags["adaptive_action_sets"]:
                        (attribute_adaptive_actions,
                         predicate_adaptive_actions
                         ) = semantic_action_graph.variation_based_traversal(
                             subject_name, object_name,
                             previously_mined_attributes)
                        next_object_adaptive_actions = find_object_neighbors(
                            subject_bbox, im_state.entity_proposals,
                            previously_mined_next_objects)
                    else:
                        attribute_adaptive_actions = range(
                            len(semantic_action_graph.attribute_nodes))
                        predicate_adaptive_actions = range(
                            len(semantic_action_graph.predicate_nodes))
                        next_object_adaptive_actions = range(
                            len(im_state.entity_proposals) - 1)

                    # creating state + action vectors to feed in DQN
                    attribute_state_vectors = create_state_action_vector(
                        state_vector, attribute_adaptive_actions,
                        len(semantic_action_graph.attribute_nodes))
                    predicate_state_vectors = create_state_action_vector(
                        state_vector, predicate_adaptive_actions,
                        len(semantic_action_graph.predicate_nodes))
                    next_object_state_vectors = create_state_action_vector(
                        state_vector, next_object_adaptive_actions,
                        parameters["maximum_num_entities_per_image"])

                    # choose action using epsilon greedy
                    attribute_action, predicate_action, next_object_action = None, None, None
                    if attribute_state_vectors is not None:
                        attribute_action = choose_action_epsilon_greedy(
                            attribute_state_vectors,
                            attribute_adaptive_actions,
                            model_attribute_main,
                            parameters["epsilon"],
                            training=replay_buffer.can_sample())
                    if predicate_state_vectors is not None:
                        predicate_action = choose_action_epsilon_greedy(
                            predicate_state_vectors,
                            predicate_adaptive_actions,
                            model_predicate_main,
                            parameters["epsilon"],
                            training=replay_buffer.can_sample())

                    # update skip thought vector; only the two most recent
                    # predicates are kept per (subject, object) pair
                    pair_key = (im_state.current_subject,
                                im_state.current_object)
                    if predicate_action is not None and flags["skip_thought"]:
                        skip_thought_dict[pair_key].append(predicate_action)
                    if len(skip_thought_dict[pair_key]) > 2:
                        skip_thought_dict[pair_key].pop(0)

                    if next_object_state_vectors is not None:
                        next_object_action = choose_action_epsilon_greedy(
                            next_object_state_vectors,
                            next_object_adaptive_actions,
                            model_next_object_main,
                            parameters["epsilon"],
                            training=replay_buffer.can_sample())

                    # step image_state
                    attribute_reward, predicate_reward, next_object_reward, done = im_state.step(
                        attribute_action, predicate_action, next_object_action)
                    next_state = create_state_vector(
                        im_state,
                        skip_thought_dict,
                        models["skip_thought_encoder"],
                        semantic_action_graph,
                        use_skip_thought=flags["skip_thought"])

                    # decay epsilon
                    if parameters["epsilon"] > parameters["epsilon_end"]:
                        parameters["epsilon"] = parameters[
                            "epsilon"] * parameters["epsilon_anneal_rate"]

                    # Action sets available in the *next* state. These must
                    # be derived from the post-step subject's mined
                    # attributes / explored objects, not from the pre-step
                    # ones (which the earlier code passed by mistake).
                    subject_name_1 = entity_to_aliases(
                        im_state.entity_classes[im_state.current_subject])
                    object_name_1 = entity_to_aliases(
                        im_state.entity_classes[im_state.current_object])
                    previously_mined_attributes_1 = im_state.current_scene_graph[
                        "objects"][im_state.current_subject]["attributes"]
                    previously_mined_next_objects_1 = (
                        im_state.objects_explored_per_subject[
                            im_state.current_subject])
                    (attribute_adaptive_actions_1,
                     predicate_adaptive_actions_1
                     ) = semantic_action_graph.variation_based_traversal(
                         subject_name_1, object_name_1,
                         previously_mined_attributes_1)
                    next_object_adaptive_actions_1 = find_object_neighbors(
                        im_state.entity_proposals[im_state.current_subject],
                        im_state.entity_proposals,
                        previously_mined_next_objects_1)

                    # add transition tuple to replay buffer
                    replay_buffer.push(state_vector, next_state,
                                       attribute_adaptive_actions,
                                       predicate_adaptive_actions,
                                       next_object_adaptive_actions,
                                       attribute_reward, predicate_reward,
                                       next_object_reward,
                                       attribute_adaptive_actions_1,
                                       predicate_adaptive_actions_1,
                                       next_object_adaptive_actions_1, done)

                    # sample minibatch if replay_buffer has enough samples
                    if replay_buffer.can_sample():
                        discount_factor = parameters["discount_factor"]
                        minibatch_transitions = replay_buffer.sample(
                            parameters["batch_size"])
                        for transition in minibatch_transitions:
                            total_number_timesteps_taken += 1
                            target_q_attribute, target_q_predicate, target_q_next_object = None, None, None
                            if transition.done:
                                # Terminal transition: the target is just
                                # the immediate reward.
                                target_q_attribute = transition.attribute_reward
                                target_q_predicate = transition.predicate_reward
                                target_q_next_object = transition.next_object_reward
                            else:
                                next_state_attribute = create_state_action_vector(
                                    transition.next_state,
                                    transition.next_state_attribute_actions,
                                    len(semantic_action_graph.attribute_nodes))
                                next_state_predicate = create_state_action_vector(
                                    transition.next_state,
                                    transition.next_state_predicate_actions,
                                    len(semantic_action_graph.predicate_nodes))
                                next_state_next_object = create_state_action_vector(
                                    transition.next_state,
                                    transition.next_state_next_object_actions,
                                    parameters[
                                        "maximum_num_entities_per_image"])
                                if next_state_attribute is not None:
                                    next_state_attribute.volatile = True
                                    output = torch.max(
                                        model_attribute_target(
                                            next_state_attribute))[0]
                                    target_q_attribute = (
                                        transition.attribute_reward +
                                        discount_factor * output)
                                if next_state_predicate is not None:
                                    next_state_predicate.volatile = True
                                    target_q_predicate = (
                                        transition.predicate_reward +
                                        discount_factor * torch.max(
                                            model_predicate_target(
                                                next_state_predicate))[0])
                                if next_state_next_object is not None:
                                    next_state_next_object.volatile = True
                                    target_q_next_object = (
                                        transition.next_object_reward +
                                        discount_factor * torch.max(
                                            model_next_object_target(
                                                next_state_next_object))[0])

                            # compute loss
                            main_state_attribute = create_state_action_vector(
                                transition.state, transition.attribute_actions,
                                len(semantic_action_graph.attribute_nodes))
                            main_state_predicate = create_state_action_vector(
                                transition.state, transition.predicate_actions,
                                len(semantic_action_graph.predicate_nodes))
                            main_state_next_object = create_state_action_vector(
                                transition.state,
                                transition.next_object_actions,
                                parameters["maximum_num_entities_per_image"])

                            if (main_state_attribute is not None
                                    and target_q_attribute is not None):
                                main_q_attribute = (
                                    transition.attribute_reward +
                                    discount_factor * torch.max(
                                        model_attribute_main(
                                            main_state_attribute)))
                                loss_attribute = loss_functions["attribute"](
                                    main_q_attribute, target_q_attribute)
                                optimizers["attribute"].zero_grad()
                                loss_attribute.backward()
                                # clip gradients element-wise to [-1, 1]
                                for param in model_attribute_main.parameters():
                                    param.grad.data.clamp_(-1, 1)
                                optimizers["attribute"].step()

                            if (main_state_predicate is not None
                                    and target_q_predicate is not None):
                                main_q_predicate = (
                                    transition.predicate_reward +
                                    discount_factor * torch.max(
                                        model_predicate_main(
                                            main_state_predicate)))
                                loss_predicate = loss_functions["predicate"](
                                    main_q_predicate, target_q_predicate)
                                optimizers["predicate"].zero_grad()
                                loss_predicate.backward()
                                for param in model_predicate_main.parameters():
                                    param.grad.data.clamp_(-1, 1)
                                optimizers["predicate"].step()

                            if (main_state_next_object is not None
                                    and target_q_next_object is not None):
                                main_q_next_object = (
                                    transition.next_object_reward +
                                    discount_factor * torch.max(
                                        model_next_object_main(
                                            main_state_next_object)))
                                loss_next_object = loss_functions[
                                    "next_object"](main_q_next_object,
                                                   target_q_next_object)
                                optimizers["next_object"].zero_grad()
                                loss_next_object.backward()
                                for param in model_next_object_main.parameters(
                                ):
                                    param.grad.data.clamp_(-1, 1)
                                optimizers["next_object"].step()

                    # update target weights if it has been tao steps
                    if total_number_timesteps_taken % parameters[
                            "target_update_frequency"] == 0:
                        update_target(model_attribute_main,
                                      model_attribute_target)
                        update_target(model_predicate_main,
                                      model_predicate_target)
                        update_target(model_next_object_main,
                                      model_next_object_target)

    # Persist ground-truth and predicted scene graphs of all visited images.
    gt_graphs = []
    our_graphs = []
    for ims in image_states.values():
        gt_graphs.append(ims.gt_scene_graph)
        our_graphs.append(ims.current_scene_graph)
    with open("image_states.pickle", "wb") as handle:
        pickle.dump({"gt": gt_graphs, "curr": our_graphs}, handle)
Beispiel #7
0
# Layout of the flat Q-value vector produced by the DQN, as sliced when the
# targets are built: [0, 742) attribute actions, [742, 742 + 303) predicate
# actions, and the remainder next-object actions.
_NUM_ATTRIBUTE_ACTIONS = 742
_NUM_PREDICATE_ACTIONS = 303
# Next-object action id that is always appended as the "terminal" action.
_TERMINAL_NEXT_OBJECT_ACTION = 1293


def _new_image_state(gt_sg, feature_map, image_orig):
    """Build an ImageState for one image from its ground-truth boxes.

    The ground-truth objects stand in for detector proposals: each box is
    stored as [x, y, x + w, y + h] with a score of 1, and the lists are
    truncated to ``maximum_num_entities_per_image``.

    Returns None when fewer than two entities remain.
    """
    image_feature = torch.mean(torch.mean(feature_map, 1), 1)

    boxes, scores, classes = [], [], []
    for obj in gt_sg["labels"]["objects"]:
        boxes.append([
            obj["x"], obj["y"], obj["x"] + obj["w"], obj["y"] + obj["h"]
        ])
        scores.append(1)
        classes.append(obj["name"] if "name" in obj else obj["names"][0])
    entity_proposals = np.array(boxes)[:maximum_num_entities_per_image]
    entity_scores = np.array(scores)[:maximum_num_entities_per_image]
    entity_classes = np.array(classes)[:maximum_num_entities_per_image]
    if len(entity_scores) < 2:
        return None

    entity_features = [
        object_features(image_orig, feature_map, box)
        for box in entity_proposals
    ]
    print(len(entity_features))
    im_state = ImageState(gt_sg["image_name"], gt_sg, image_feature,
                          entity_features, entity_proposals, entity_classes,
                          entity_scores, semantic_action_graph)
    im_state.initialize_entities(entity_proposals, entity_classes,
                                 entity_scores)
    return im_state


def _adaptive_action_sets(im_state):
    """Return (attribute, predicate, next_object) action lists for im_state.

    Attribute and predicate actions come from
    ``semantic_action_graph.variation_based_traversal`` for the current
    subject/object pair. Next-object actions are the sorted, de-duplicated
    entity ids of the subject's neighbours that are known to the semantic
    action graph, plus the terminal action.
    """
    subject_id = im_state.current_subject
    object_id = im_state.current_object
    subject_name = entity_to_aliases(im_state.entity_classes[subject_id])
    object_name = entity_to_aliases(im_state.entity_classes[object_id])
    mined_attributes = im_state.current_scene_graph["objects"][subject_id][
        "attributes"]
    mined_next_objects = im_state.objects_explored_per_subject[subject_id]

    attribute_actions, predicate_actions = (
        semantic_action_graph.variation_based_traversal(
            subject_name, object_name, mined_attributes))

    neighbours = find_object_neighbors(im_state.entity_proposals[subject_id],
                                       im_state.entity_proposals,
                                       mined_next_objects)
    neighbour_names = [
        entity_to_aliases(im_state.entity_classes[x]) for x in neighbours
    ]
    name_to_id = semantic_action_graph.entity_name_to_id
    next_object_actions = sorted(
        set([name_to_id[name]
             for name in neighbour_names if name in name_to_id] +
            [_TERMINAL_NEXT_OBJECT_ACTION]))
    return attribute_actions, predicate_actions, next_object_actions


def _fit_minibatch(parameters, model_main, model_target):
    """Sample one minibatch from the replay buffer and take one optimizer step.

    For every transition three (predicted, expected) Q-value pairs are formed,
    one per action head. The expected value is the reward, plus the discounted
    maximum target-network Q-value over the next state's allowed actions when
    the transition is not terminal.
    """
    transitions = replay_buffer.sample(parameters["batch_size"])
    output = model_main(torch.stack([t.state for t in transitions], 0))
    next_output = model_target(
        torch.stack([t.next_state for t in transitions], 0))
    # Copy to host memory first: a CUDA tensor cannot be viewed by numpy.
    next_q = next_output.data.cpu().numpy()

    attribute_end = _NUM_ATTRIBUTE_ACTIONS
    predicate_end = attribute_end + _NUM_PREDICATE_ACTIONS
    discount = parameters["discount_factor"]

    expected, predicted = [], []
    for indx, trans in enumerate(transitions):
        target_q_attribute = trans.attribute_reward
        target_q_predicate = trans.predicate_reward
        target_q_next_object = trans.next_object_reward
        # `not`, never `~`: bitwise inversion of a Python bool gives -1 or -2,
        # both truthy, which would bootstrap terminal transitions as well.
        if not trans.done:
            row = next_q[indx]
            target_q_attribute += discount * np.max(
                row[:attribute_end][trans.next_state_attribute_actions])
            target_q_predicate += discount * np.max(
                row[attribute_end:predicate_end][
                    trans.next_state_predicate_actions])
            target_q_next_object += discount * np.max(
                row[predicate_end:][trans.next_state_next_object_actions])
        expected.append(target_q_attribute)
        expected.append(target_q_predicate)
        expected.append(target_q_next_object)
        predicted.append(output[indx, trans.attribute_action[0]])
        predicted.append(output[indx,
                                attribute_end + trans.predicate_action[0]])
        predicted.append(output[indx,
                                predicate_end + trans.next_object_action[0]])

    expected_tensor = torch.autograd.Variable(
        torch.from_numpy(np.stack(expected, 0).astype(np.float32)))
    if torch.cuda.is_available():
        expected_tensor = expected_tensor.cuda()
    loss = loss_fn(torch.stack(predicted, 0), expected_tensor)

    optimizer.zero_grad()
    loss.backward()
    # Clip every gradient element to [-1, 1] before the update.
    for param in model_main.parameters():
        if param.grad is not None:
            param.grad.data.clamp_(-1, 1)
    optimizer.step()


def train(parameters, train=True):
    """Run the DQN scene-graph loop over the training or validation loader.

    Each image is stepped until its ImageState reports done or
    ``parameters["max_allowed_steps"]`` is exceeded. Every step picks one
    attribute, predicate and next-object action, pushes the transition to the
    replay buffer and, once the buffer can be sampled, fits one minibatch.
    When all epochs are finished, the ground-truth and predicted scene graphs
    of every visited image are pickled to ``../data/image_states.pickle``.

    Args:
        parameters: dict read for "max_allowed_steps", "epsilon",
            "epsilon_end", "epsilon_anneal_rate", "batch_size" and
            "discount_factor". "epsilon" is decayed in place.
        train: when false, a single epoch over ``validation_data_loader`` is
            run; otherwise ``num_epochs`` epochs over ``train_data_loader``.

    Models, data loaders, replay buffer, loss and optimizer are taken from
    module-level names.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    print("CUDA Available: " + str(use_cuda))
    # Bind the models on every code path so a CPU-only run does not hit a
    # NameError; they are moved to the GPU only when one is present.
    model_VGG = model_vgg.cuda() if use_cuda else model_vgg
    model_main = DQN_main.cuda() if use_cuda else DQN_main
    model_target = DQN_target.cuda() if use_cuda else DQN_target

    # keeps track of current scene graphs for images
    image_states = {}
    total_number_timesteps_taken = 0
    if not train:
        number_of_epochs = 1
        data_loader = validation_data_loader
    else:
        number_of_epochs = num_epochs
        data_loader = train_data_loader

    # most recent predicate actions per (subject, object) pair
    skip_thought_dict = defaultdict(list)

    for epoch in range(number_of_epochs):
        print("Epoch: ", epoch)
        num = -1
        for images, images_orig, gt_scene_graph in data_loader:
            images = torch.autograd.Variable(torch.squeeze(images, 1))
            if use_cuda:
                images = images.cuda()
            # get image features from VGG16
            images = model_VGG(images)
            # iterate through images in batch
            for idx in range(images.size(0)):
                num += 1
                print("Image number " + str(num))
                image_name = gt_scene_graph[idx]["image_name"]
                if image_name not in image_states:
                    new_state = _new_image_state(gt_scene_graph[idx],
                                                 images[idx],
                                                 images_orig[idx])
                    if new_state is None:
                        continue
                    image_states[image_name] = new_state
                else:
                    # reset image state from last epoch
                    image_states[image_name].reset()
                im_state = image_states[image_name]

                next_state = create_state_vector(im_state, skip_thought_dict)
                (attribute_adaptive_actions, predicate_adaptive_actions,
                 next_object_adaptive_actions) = _adaptive_action_sets(
                     im_state)

                n_steps = 0
                t1 = time.time()
                while not im_state.is_done():
                    n_steps += 1
                    if n_steps > parameters["max_allowed_steps"]:
                        break
                    state_vector = next_state

                    output = model_main(state_vector)
                    attribute_action, predicate_action, next_object_action = choose_actions(
                        output, attribute_adaptive_actions,
                        predicate_adaptive_actions,
                        next_object_adaptive_actions, parameters["epsilon"])

                    # update skip thought history, keeping at most two entries
                    history = skip_thought_dict[(im_state.current_subject,
                                                 im_state.current_object)]
                    if predicate_action is not None and args.use_skip_thought:
                        history.append(predicate_action)
                    if len(history) > 2:
                        history.pop(0)

                    # step image_state
                    attribute_reward, predicate_reward, next_object_reward, done = im_state.step(
                        attribute_action[0], predicate_action[0],
                        next_object_action[0])
                    next_state = create_state_vector(im_state,
                                                     skip_thought_dict)

                    # decay epsilon
                    if parameters["epsilon"] > parameters["epsilon_end"]:
                        parameters["epsilon"] = parameters[
                            "epsilon"] * parameters["epsilon_anneal_rate"]

                    # Action sets allowed in the state just reached. They are
                    # stored with the transition and reused as the choice set
                    # for the next iteration.
                    (attribute_adaptive_actions, predicate_adaptive_actions,
                     next_object_adaptive_actions) = _adaptive_action_sets(
                         im_state)
                    # An empty set falls back to one random action id. The
                    # next-object set always holds the terminal action, so it
                    # needs no fallback.
                    if len(attribute_adaptive_actions) == 0:
                        attribute_adaptive_actions = [
                            random.randint(0, _NUM_ATTRIBUTE_ACTIONS - 1)
                        ]
                    if len(predicate_adaptive_actions) == 0:
                        predicate_adaptive_actions = [
                            random.randint(0, _NUM_PREDICATE_ACTIONS - 1)
                        ]

                    # add transition tuple to replay buffer
                    replay_buffer.push(state_vector, next_state,
                                       attribute_action, predicate_action,
                                       next_object_action, attribute_reward,
                                       predicate_reward, next_object_reward,
                                       attribute_adaptive_actions,
                                       predicate_adaptive_actions,
                                       next_object_adaptive_actions, done)

                    # fit on a minibatch if replay_buffer has enough samples
                    if replay_buffer.can_sample():
                        _fit_minibatch(parameters, model_main, model_target)

                    # Sync the target network on the very first step and then
                    # every target_update_frequency steps. The counter must
                    # advance, otherwise the modulo test is true on every
                    # step and the target network never lags the main one.
                    if total_number_timesteps_taken % target_update_frequency == 0:
                        update_target(model_main, model_target)
                    total_number_timesteps_taken += 1

                print("Image : " + str(n_steps) + " " +
                      str(time.time() - t1))

    gt_graphs = [ims.gt_scene_graph for ims in image_states.values()]
    our_graphs = [ims.current_scene_graph for ims in image_states.values()]
    with open("../data/image_states.pickle", "wb") as handle:
        pickle.dump({"gt": gt_graphs, "curr": our_graphs}, handle)
Beispiel #8
0
    def step(self, attribute_action, predicate_action, next_object_action):
        """Apply one attribute/predicate/next-object action triple.

        Args:
            attribute_action: index into ``self.graph.attribute_nodes``, or
                None.
            predicate_action: index into ``self.graph.predicate_nodes``, or
                None.
            next_object_action: index into ``self.graph.entity_nodes``, the
                value 1293 (move on to a new subject), or None.

        Returns:
            ``(reward_attribute, reward_predicate, reward_next_object, done)``.
            Every reward starts at -1. The attribute reward becomes 1 when
            the predicted attribute is listed for the current subject in the
            ground truth. The predicate reward becomes 1 when a ground-truth
            relationship matches the predicate and the current
            subject/object ids. The next-object reward becomes 5 when the
            chosen object is neither explored nor queued yet. ``done`` is
            ``self.is_done()`` evaluated after the state update.
        """
        # Debug trace; written as a call so it prints the same text under
        # Python 2 and Python 3.
        print(" ".join(str(value) for value in (
            self.current_subject, self.current_object, attribute_action,
            predicate_action, next_object_action, self.entity_queue)))

        reward_attribute, reward_predicate, reward_next_object = -1, -1, -1

        pred_attribute_name = None
        if attribute_action is not None:
            pred_attribute_name = self.graph.attribute_nodes[
                attribute_action].name
        pred_predicate_name = None
        if predicate_action is not None:
            pred_predicate_name = self.graph.predicate_nodes[
                predicate_action].name
        # Always bound, so later reads cannot hit an unbound local.
        pred_next_action_name = None
        if next_object_action is not None and next_object_action != 1293:
            pred_next_action_name = self.graph.entity_nodes[
                next_object_action].name

        if attribute_action is not None:
            self.add_attribute(self.current_subject, attribute_action)
        if pred_predicate_name is not None:
            self.add_predicate(self.current_subject, pred_predicate_name,
                               self.current_object)

        gt_labels = self.gt_scene_graph["labels"]
        gt_objects = gt_labels["objects"]

        # Entities come from the ground truth, so the current subject index
        # addresses the ground-truth object list directly.
        gt_subject = gt_objects[self.current_subject]
        if "attributes" in gt_subject and pred_attribute_name in gt_subject[
                "attributes"]:
            reward_attribute = 1

        if any(pred_predicate_name == relationship_dict["predicate"]
               and self.current_subject == relationship_dict["subject_id"]
               and self.current_object == relationship_dict["object_id"]
               for relationship_dict in gt_labels["relationships"]):
            reward_predicate = 1

        if next_object_action is not None:
            found = False
            if next_object_action != 1293:
                subject_bbox = self.entity_proposals[self.current_subject]
                x, y, x2, y2 = subject_bbox[0], subject_bbox[1], subject_bbox[
                    2], subject_bbox[3]
                w, h = x2 - x, y2 - y
                explored = self.objects_explored_per_subject[
                    self.current_subject]
                # zip() stops at the shorter sequence, so ground-truth objects
                # without a proposal are skipped instead of raising
                # IndexError.
                for index, (obj_dict, obj) in enumerate(
                        zip(gt_objects, self.entity_proposals)):
                    obj_name = entity_to_aliases(
                        obj_dict["name"] if "name" in
                        obj_dict else obj_dict["names"][0])
                    if pred_next_action_name != obj_name or index in explored:
                        continue
                    # NOTE(review): the subject box is unpacked above as
                    # (x, y, x2, y2), yet obj[2] / obj[3] are added to the
                    # subject's width / height here as if they were the
                    # object's width / height. Confirm the intended
                    # proximity test before changing it.
                    if abs(obj[0] - x) < 0.5 * (obj[2] + w) and abs(
                            obj[1] - y) < 0.5 * (obj[3] + h):
                        self.current_object = index
                        explored.append(index)
                        if index not in self.explored_entities and index not in self.entity_queue:
                            reward_next_object = 5
                            # add to queue of BFS
                            self.entity_queue.append(index)
                        found = True
                        break
            # The terminal action, or no matching nearby object, moves on to
            # a new subject.
            if not found:
                self.get_new_subject()
                self.current_object = None

        return (reward_attribute, reward_predicate, reward_next_object,
                self.is_done())