Example #1
    def step(self):
        if len(self.clauses) == 0:
            return torch.tensor(0, dtype=torch.float32)

        # binding_dict = query(self.graph.scene, self.clauses, self.config)
        next_clause = self.clauses[-1]
        useful = self.interp.useful_check(self.interp_state, next_clause)
        logging.info(f"useful: {useful}")

        binding_dict, new_state = self.interp.state_query(
            self.interp_state, next_clause)
        self.interp_state = new_state

        # update the success and possible fields
        self.check_possible(binding_dict)

        if not self.possible:
            self.obj_poss_left.append(0)
        elif not "var_0" in binding_dict.keys():
            self.obj_poss_left.append(self.obj_nums)
        else:
            self.obj_poss_left.append(len(binding_dict["var_0"]))
        self.check_success(binding_dict)

        done = self.success or (not self.possible)

        if not useful:
            reward = torch.tensor(-1, dtype=torch.float32)
        else:
            reward = torch.tensor(get_reward(self), dtype=torch.float32)

        return reward
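# Note: get_reward is defined elsewhere in the project and is not shown in this
# example. The sketch below is purely hypothetical and only illustrates one
# reward shaping that is consistent with how step() uses the success/possible
# flags and the shrinking candidate counts in obj_poss_left; the real
# implementation may differ.
def get_reward_sketch(env):
    if env.success:
        return 1.0
    if not env.possible:
        return -1.0
    # small positive reward when the latest clause narrowed down the candidates
    if len(env.obj_poss_left) >= 2 and env.obj_poss_left[-1] < env.obj_poss_left[-2]:
        return 0.2
    return 0.0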
Example #2
def episode(policy, target, data_point, graph, config, attr_encoder, memory,
            total_count, optimizer):
    env = Env(data_point, graph, config, attr_encoder)
    for iter_count in range(cmd_args.episode_length):

        state = env.get_state()
        action = select_action(policy, state)

        logging.info(f"selected clause: {env.actions[action]}")
        next_state, reward, done, _ = env.step(action)

        if done:
            next_state = None

        # could not find the result within the step limit
        if iter_count == cmd_args.episode_length - 1:
            reward = get_final_reward(env)

        logging.info(f"reward: {reward}")
        memory.push(state, action, next_state, reward)
        optimize_model_DQ(memory, policy, target, optimizer)

        if done:
            break

        if total_count % cmd_args.target_update == 0:
            target.load_state_dict(policy.state_dict())

    return env.success
def fit_one(refrl, data_point, graph, eps):
    global sub_ct
    sub_ct += 1

    if sub_ct > cmd_args.max_sub_prob:
        return None, None

    env, retrain_list = refrl.episode(data_point, graph, eps,
                                      refrl.dataset.attr_encoder)
    loss = policy_gradient_loss(refrl.policy.reward_history,
                                refrl.policy.prob_history)

    if cmd_args.sub_loss:
        sub_loss = []
        for data_point, clauses in retrain_list:
            logging.info(f"clauses in env: {clauses}")
            e, r = fit_one(refrl, data_point, env.graph, eps)
            if e is None:
                print("Oops! ran out of sub-problem budget")
                return env, loss

            sub_loss.append(r)
        loss += sum(sub_loss)

    return env, loss
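# Note: policy_gradient_loss is imported from elsewhere in the project. The
# version below is only a minimal REINFORCE-style sketch, assuming
# reward_history holds per-step rewards and prob_history holds the
# probabilities of the chosen clauses; it shows the expected interface, not the
# project's actual implementation.
import torch

def policy_gradient_loss_sketch(reward_history, prob_history, gamma=0.99):
    # discounted returns, accumulated backwards over the episode
    returns = []
    running = 0.0
    for r in reversed(list(reward_history)):
        running = float(r) + gamma * running
        returns.insert(0, running)
    returns = torch.tensor(returns)
    log_probs = torch.log(torch.as_tensor(prob_history, dtype=torch.float32))
    # REINFORCE: maximize sum(log_prob * return), i.e. minimize its negative
    return -(log_probs * returns).sum()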
    def step(self, action_idx):

        is_uncertain = self.is_uncertain
        next_clause = self.actions[action_idx]

        if self.ref_flag:
            for element in next_clause:
                if isinstance(element, int) and element not in self.ref:
                    self.ref.append(element)

            self.unreachable = self.unreachable_dict[str(sorted(self.ref))]

        self.idx_selected.append(action_idx)
        self.clauses.append(next_clause)

        useful = self.interp.useful_check(self.state, next_clause)
        logging.info(f"useful: {useful}")
        binding_dict, new_state = self.interp.state_query(
            self.state, next_clause)
        self.state = new_state

        # update the success and possible fields
        self.check_possible(binding_dict)

        if not self.possible:
            self.obj_poss_left.append(0)
        elif not "var_0" in binding_dict.keys():
            self.obj_poss_left.append(self.obj_nums)
            self.update_data(binding_dict)
        else:
            self.obj_poss_left.append(len(binding_dict["var_0"]))
            self.update_data(binding_dict)

        # update whether done or not
        self.check_success(binding_dict)
        done = self.success or (not self.possible)

        if not useful:
            reward = torch.tensor(-1, dtype=torch.float32)
        else:
            reward = torch.tensor(get_reward(self), dtype=torch.float32)

        self.update_data(binding_dict)
        info = {}

        logging.info(f"selected: {self.idx_selected}")
        logging.info(f"done: {done}")
        logging.info(f"success: {self.success}")

        if self.ref_flag:
            state = self.get_state(self.unreachable)
        else:
            state = self.get_state()

        return state, reward, done, info
Example #5
    def episode(self, data_point, graph, eps, attr_encoder):
        self.policy.reset(eps)
        retrain_list = []

        env = Env(data_point, graph, self.config, attr_encoder)
        iter_count = 0

        while not env.is_finished():
            # could not find the clauses within the step limit
            if iter_count > cmd_args.episode_length:
                if cmd_args.reward_type == "only_success":
                    final_reward = get_final_reward(env)
                    if final_reward == -1:
                        self.policy.reward_history = [0.0] * len(
                            self.policy.reward_history)
                        self.policy.reward_history.append(-1.0)
                    else:
                        self.policy.reward_history.append(1.0)
                    self.policy.reward_history = torch.tensor(
                        self.policy.reward_history)
                else:
                    self.policy.reward_history.append(get_final_reward(env))
                    self.policy.prob_history = torch.cat(
                        [self.policy.prob_history,
                         torch.tensor([1.0])])
                break

            iter_count += 1
            env = self.policy(env)

        logging.info(self.policy.reward_history)
        logging.info(self.policy.prob_history)
        logging.info(env.clauses)
        return env, retrain_list
def select_action(policy_net, state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (cmd_args.eps - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    logging.info(f"eps_threshold: {eps_threshold}")
    clauses_idx = state.actions
    avoid_idx = state.idx_selected

    selections = []
    # avoid repeated selections
    for idx in range(len(clauses_idx)):
        if idx not in avoid_idx:
            selections.append(idx)

    action_values = policy_net(state)
    # logging.info(f"all probs: {action_values}")

    if sample > eps_threshold:
        with torch.no_grad():
            # max(1) returns (values, indices) over each row; we take the
            # index of the largest Q-value among the remaining selections,
            # i.e. the action with the largest expected reward.
            logging.info("max")

            action_idx = action_values[:, selections].max(1)[1]
            action_idx = torch.tensor(selections[action_idx]).view(1, 1)
            action_value = action_values[:, selections].max(1)[0]
            logging.info(f"clause prob: {action_value}")
            return action_idx
    else:
        logging.info("random")
        select_id = random.choice(selections)
        action_value = action_values[0][select_id]
        logging.info(f"clause prob: {action_value}")
        return torch.tensor([[select_id]], dtype=torch.long)
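# For reference, the epsilon threshold above decays exponentially from
# cmd_args.eps towards EPS_END as steps_done grows. A quick way to inspect the
# schedule (the constants here are placeholder values, not the project's):
import math

EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 200
for steps_done in (0, 100, 500, 2000):
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    print(steps_done, round(eps_threshold, 3))
# At steps_done = 0 the threshold equals EPS_START, and it approaches EPS_END
# as steps_done grows, so random exploration becomes rarer over time.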
Example #7
def fit(refrl):
    refrl.policy.train()
    # refrl.train_data.shuffle()
    total_ct = 0
    data_loader = DataLoader(refrl.train_data)
    eps = cmd_args.eps
    # with autograd.detect_anomaly():
    for it in range(cmd_args.episode_iter):

        logging.info(f"training iteration: {it}")
        success = 0

        total_loss = 0.0
        if refrl.iteration > it:
            continue

        for data_point, ct in zip(data_loader, tqdm(range(len(data_loader)))):
            global sub_ct
            sub_ct = 0
            total_ct += 1
            logging.info(ct)

            graph = refrl.graphs[data_point.graph_id]
            env, loss = fit_one(refrl, data_point, graph, eps)

            total_loss += loss
            if env.success:
                success += 1

            loss.backward()

            # accumulate gradients over batch_size samples before stepping
            if total_ct % cmd_args.batch_size == 0:
                refrl.optimizer.step()
                refrl.optimizer.zero_grad()

            if total_ct % cmd_args.save_num == 0 and total_ct != 0:
                torch.save(refrl, cmd_args.model_path)

        logging.info(
            f"at train iter {it}, success num {success}, avg loss {total_loss / (ct + 1)}"
        )
        refrl.iteration += 1
        eps = cmd_args.eps_decay * eps
def test(refrl, split="test"):

    logging.info(f"testing on {split} data")
    refrl.policy.eval()

    if split == "train":
        data_loader = DataLoader(refrl.train_data)
    else:
        data_loader = DataLoader(refrl.test_data)

    success = 0
    avg_loss = 0
    eps = 0
    total_ct = 0

    for it in range(cmd_args.test_iter):
        logging.info(f"testing iteration: {it}")
        with torch.no_grad():
            for data_point, ct in zip(data_loader,
                                      tqdm(range(len(data_loader)))):

                logging.info(ct)
                total_ct += 1

                graph = refrl.graphs[data_point.graph_id]
                env, _ = refrl.episode(data_point,
                                       graph,
                                       eps,
                                       refrl.dataset.attr_encoder,
                                       phase="test")
                loss = policy_gradient_loss(refrl.policy.reward_history,
                                            refrl.policy.prob_history)
                avg_loss += loss

                if env.success:
                    success += 1

    avg_loss /= total_ct
    logging.info(
        f"Testing {split}: success {success} out of {total_ct}, average loss is {avg_loss}"
    )
def episode(policy,
            data_point,
            graph,
            eps,
            attr_encoder,
            config,
            phase="train"):
    policy.reset(eps)
    retrain_list = []

    env = Env(data_point, graph, config, attr_encoder)
    iter_count = 0

    while not env.is_finished():
        # could not find the clauses within the step limit
        if iter_count > cmd_args.episode_length:
            if cmd_args.reward_type == "only_success":
                final_reward = get_final_reward(env)
                if final_reward == -1:
                    policy.reward_history = [0.0] * len(policy.reward_history)
                    policy.reward_history.append(-1.0)
                else:
                    policy.reward_history.append(1.0)
                policy.reward_history = torch.tensor(policy.reward_history)
            else:
                policy.reward_history.append(get_final_reward(env))
                policy.prob_history = torch.cat(
                    [policy.prob_history,
                     torch.tensor([1.0])])
            break

        iter_count += 1
        env = policy(env)

        if cmd_args.sub_loss:
            if not env.is_finished():
                retrain_prob = (env.obj_poss_left[-2] -
                                env.obj_poss_left[-1]) / env.obj_poss_left[0]
                retrain_prob = max(0.0, retrain_prob)
                decision = np.random.choice([0, 1],
                                            1,
                                            p=[1 - retrain_prob, retrain_prob])
                if decision[0] == 1:
                    retrain_list.append(
                        (deepcopy(env.data), deepcopy(env.clauses)))

    logging.info(policy.reward_history)
    logging.info(policy.prob_history)
    logging.info(env.clauses)
    return env, retrain_list
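# Quick illustration of the sub-problem sampling above: retrain_prob is the
# fraction of the initial candidate objects eliminated by the latest clause.
# With hypothetical counts obj_poss_left = [10, 6, 3], the last clause removed
# (6 - 3) of the original 10 objects, so the current (data, clauses) pair is
# queued for retraining with probability (6 - 3) / 10 = 0.3.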
Example #10
def optimize_model_DQ(memory, policy_net, target_net, optimizer):

    batch_size = cmd_args.batch_size
    gamma = cmd_args.gamma
    if len(memory) < cmd_args.batch_size:
        return

    transitions = memory.sample(batch_size)
    # logging.info(f"Transitions: {transitions}")

    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # A final state (next_state is None) is the one after which the episode
    # ended; its bootstrapped value is taken to be 0 below.

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = []
    for state, action, reward in zip(batch.state, batch.action, batch.reward):
        state_action_values.append(policy_net(state).gather(1, action))
    state_action_values = torch.cat(state_action_values)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.

    next_state_values = []
    for state in batch.next_state:

        if state is None:
            next_state_values.append(torch.tensor([0.0]))
        else:
            unary_clauses_idx, binary_clauses_idx = state.get_clauses_idx()
            clauses_idx = unary_clauses_idx + binary_clauses_idx
            next_idx_selected = state.idx_selected

            selections = []
            for idx in range(len(clauses_idx)):
                if idx not in next_idx_selected:
                    selections.append(idx)

            next_state_values.append(
                target_net(state)[:, selections].max(1)[0].detach())
    next_state_values = torch.cat(next_state_values)

    # Compute the expected Q values
    expected_state_action_values = (next_state_values * gamma) + torch.tensor(
        batch.reward)

    # Compute Huber loss
    loss = F.smooth_l1_loss(state_action_values,
                            expected_state_action_values.view(-1, 1))
    logging.info(f"loss: {loss}")

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for name, param in policy_net.named_parameters():
        if param.grad is not None:
            param.grad.data.clamp_(-1, 1)
            # logging.info(f"name: {name}, grad: {param.grad.data}")
    optimizer.step()
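# optimize_model_DQ assumes a Transition namedtuple and a ReplayMemory buffer
# that are defined elsewhere in the project. A minimal sketch in the spirit of
# the standard PyTorch DQN tutorial, with field names inferred from the usage
# above (so treat it as an assumption, not the project's actual code):
import random
from collections import deque, namedtuple

Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))

class ReplayMemory:
    def __init__(self, capacity):
        # drop the oldest transitions once capacity is reached
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)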
Example #11
def train(dataset, graphs, config):

    construct_new = False

    if os.path.exists(cmd_args.model_save_dir):
        list_of_files = glob.glob(cmd_args.model_save_dir + '/*')
        if list_of_files:
            latest_file = max(list_of_files, key=os.path.getctime)

            model = torch.load(latest_file)
            DC.steps_done = model.steps_done
        else:
            construct_new = True
    else:
        construct_new = True

    if construct_new:
        decoder = ClauseDecoder()
        policy = DQPolicy(dataset, decoder)
        current_it = 0
        memory = ReplayMemory(10000)
        optimizer = optim.RMSprop(policy.parameters(), lr=cmd_args.lr)
        DC.steps_done = 0
        model = Learning_Model(decoder, policy, memory, optimizer, current_it,
                               DC.steps_done)

    target_decoder = type(model.decoder)()
    target = DQPolicy(dataset, target_decoder)
    target.load_state_dict(model.policy.state_dict())
    target.eval()

    data_loader = DataLoader(dataset)

    for it in range(cmd_args.episode_iter):
        logging.info(f"training iteration: {it}")
        success_ct = 0
        total_loss = 0.0
        total_ct = 0

        if model.current_it > it:
            continue

        for data_point, ct in zip(data_loader, tqdm(range(len(data_loader)))):

            logging.info(f"task ct: {ct}")

            graph = graphs[data_point.graph_id]
            suc = episode(model.policy, target, data_point, graph, config,
                          dataset.attr_encoder, model.memory, total_ct,
                          model.optimizer)

            total_ct += 1
            if suc:
                success_ct += 1

        logging.info(f"success count: {success_ct}")

        if it % cmd_args.save_num == 0:
            model.steps_done = DC.steps_done
            # model.eps = eps
            model.current_it = it
            model_name = f"model_{it}.pkl"
            model_path = os.path.join(cmd_args.model_save_dir, model_name)
            torch.save(model, model_path)

    print('Complete')
Example #12
    cts = []
    progs = []
    suc = 0
    total = 0

    scene_path = os.path.join(
        os.path.abspath(__file__ + "../../../../data/processed_dataset/raw/"),
        "3_1_1_1_1_testing.json")
    with open(scene_path, "r") as scene_file:
        scenes = json.load(scene_file)

    start = time()
    for sct, scene in enumerate(scenes):
        print(f"scene ct: {sct}")
        logging.info(f"scene ct: {sct}")
        for target in range(len(scene["objects"])):
            logging.info(f"total ct: {total}")
            total += 1

            # prog = dfs_prog(scene, config, target, max_depth)
            prog = eu_solve_prog(scene, config, target, max_depth)

            logging.info(prog)
            print(f"suc {not type(prog) == type(None)}")

            end = time()
            print(f"time used {end - start}")
            logging.info(f"time used {end - start}")

            if prog is not None:
Example #13
    cont_res_dir = os.path.abspath(
        os.path.join(data_dir, f"eval_result/{cmd_args.cont_res_name}"))
    if not os.path.exists(cont_res_dir):
        os.mkdir(cont_res_dir)

    raw_path = os.path.abspath(
        os.path.join(data_dir, "./processed_dataset/raw"))
    scenes_path = os.path.abspath(
        os.path.join(raw_path, cmd_args.scene_file_name))
    graphs_path = os.path.join(raw_path, cmd_args.graph_file_name)

    # update the cmd_args corresponding to the info we have
    model = torch.load(cmd_args.model_path, map_location=cmd_args.device)

    graphs, scene_dataset = create_dataset(data_dir, scenes_path, graphs_path)
    logging.info("start cont training")

    dataloader = DataLoader(scene_dataset)
    start_time = time.time()
    for ct, datapoint in enumerate(dataloader):
        cont_single(ct,
                    datapoint,
                    model,
                    graphs,
                    cont_res_dir,
                    scene_dataset.attr_encoder,
                    scene_dataset.config,
                    save=True,
                    n=10)
    end_time = time.time()
Example #14
    attr_encoder = Encoder(config)

    for scene in scenes:
        for target_id in range(len(scene["objects"])):
            graph = Graph(config, scene, target_id)
            graphs.append(graph)
    
    with open(graphs_path, 'wb') as graphs_file:
        pickle.dump(graphs, graphs_file) 

    root = os.path.join(data_dir, "./processed_dataset")
    scene_dataset = SceneDataset(root, config)

    if os.path.exists(cmd_args.model_path) and os.path.getsize(cmd_args.model_path) > 0:
        refrl = torch.load(cmd_args.model_path)
        logging.info("Loaded refrl model")
    else:
        refrl = RefRL(scene_dataset, config, graphs)
        logging.info("Constructed refrl model")

    start_time = time.time()
    logging.info(f"Start {cmd_args.phase}")
    if cmd_args.phase == "training":
        fit(refrl)
    else:
        test(refrl, "train")
        test(refrl, "test")
    end_time = time.time()
    logging.info(f"Finished {cmd_args.phase} in {end_time - start_time}")

    print("Done")
def cont_single(file_name,
                datapoint,
                policy,
                graphs,
                save_dir,
                attr_encoder,
                config,
                save=True,
                n=10):

    policy = deepcopy(policy)
    policy.train()
    optimizer = torch.optim.Adam(policy.parameters(), lr=cmd_args.lr)

    success_progs = []
    all_progs = []

    global sub_ct
    eps = cmd_args.eps
    graph = graphs[datapoint.graph_id]

    for it in range(cmd_args.episode_iter):
        sub_ct = 0
        env, loss = fit_one(policy, datapoint, graph, eps, attr_encoder,
                            config)
        if env.success:
            success_progs.append(env.clauses)

        if not env.possible:
            all_progs.append((env.clauses[:-1], env.obj_poss_left[-2]))

        loss.backward()

        # accumulate gradients over batch_size episodes before stepping
        if it % cmd_args.batch_size == 0:
            optimizer.step()
            optimizer.zero_grad()

        if len(success_progs) > 0:
            break

    if save:
        file_path = os.path.join(save_dir, str(file_name))
        with open(file_path, 'w') as suc_prog_file:
            json.dump(success_progs, suc_prog_file)

    logging.info(f"success: {success_progs}")
    # print(f"all: {all_progs}")
    # if not successful, exploit the best remaining selections.
    scene_interpreter = SceneInterp(graph.scene, config)

    if len(success_progs) == 0:
        logging.info("start exploit")
        # sort candidates by the number of objects left (second element)
        candidates = sorted(all_progs, key=lambda x: x[1])
        for prog, obj_left in candidates:
            # no progress
            if len(prog) == 0:
                logging.info("empty prog to exploit, stop")
                break

            logging.info(f"prog: {prog}")
            logging.info(f"obj left: {obj_left}")
            prog_left = scene_interpreter.interp_enum(prog, datapoint.y)
            if prog_left is not None:
                print(prog_left)
                prog.append(prog_left)
                success_progs.append(prog)
                logging.info(f"We found: {success_progs}")
                break

    return success_progs
Example #16
def cont_single(file_name,
                datapoint,
                model,
                graphs,
                save_dir,
                attr_encoder,
                config,
                save=True,
                n=10):

    file_path = os.path.join(save_dir, str(file_name))
    exist = os.path.exists(file_path)
    logging.info(f"task: {file_name}, {exist}")
    if exist:
        logging.info("skip")
        return

    total_ct = 0
    memory = ReplayMemory(10000)
    method = None

    policy = copy.deepcopy(model.policy)
    target = copy.deepcopy(policy)
    target.eval()

    optimizer = optim.RMSprop(policy.parameters(), lr=cmd_args.lr)

    success_progs = []
    all_progs = []

    eps = cmd_args.eps
    graph = graphs[datapoint.graph_id]

    for it in range(cmd_args.test_iter):

        suc, prog, obj_left = episode(policy, target, datapoint, graph, config,
                                      attr_encoder, memory, total_ct,
                                      optimizer)
        total_ct += 1

        if suc:
            success_progs.append(prog)
            method = "model"
            break

        if not suc:
            all_progs.append((prog, obj_left))

    logging.info(f"success: {success_progs}")

    # if not successful, exploit the best remaining selections.
    scene_interpreter = SceneInterp(graph.ground_truth_scene, config)

    if len(success_progs) == 0:
        logging.info("start exploit")

        # sort candidates by the number of objects left (second element)
        candidates = sorted(all_progs, key=lambda x: x[1])
        for prog, obj_left in candidates:
            # no progress
            if len(prog) == 0:
                logging.info("empty prog to exploit, stop")
                break

            logging.info(f"prog: {prog}")
            logging.info(f"obj left: {obj_left}")
            prog_left = scene_interpreter.interp_enum(prog,
                                                      datapoint.y,
                                                      max_depth=2)

            if prog_left is not None:
                print(prog_left)
                prog.append(prog_left)
                success_progs.append(prog)
                method = f"exploit: {len(prog_left)}"
                logging.info(
                    f"We found: {success_progs} at depth {len(prog_left)}")
                break

    if save:
        file_path = os.path.join(save_dir, str(file_name))
        with open(file_path, 'w') as suc_prog_file:
            res = {}
            res['prog'] = success_progs
            res['method'] = method
            json.dump(res, suc_prog_file)

    return success_progs
Example #17
def train(dataset, graphs, config):

    if os.path.exists(
            cmd_args.model_path) and os.path.getsize(cmd_args.model_path) > 0:
        model = torch.load(cmd_args.model_path)
        DC.steps_done = model.steps_done
    else:
        decoder_name = "GlobalDecoder"
        policy = DQPolicy(dataset, decoder_name)
        decoder = policy.decoder
        current_it = 0
        memory = ReplayMemory(10000)
        optimizer = optim.RMSprop(policy.parameters(), lr=cmd_args.lr)
        DC.steps_done = 0
        model = Learning_Model(decoder, policy, memory, optimizer, current_it,
                               DC.steps_done)

    decoder_name = str(type(model.policy.decoder).__name__)
    target = DQPolicy(dataset, decoder_name)
    target_decoder = target.decoder
    target.load_state_dict(model.policy.state_dict())
    target.eval()

    data_loader = DataLoader(dataset)

    for it in range(cmd_args.episode_iter):
        logging.info(f"training iteration: {it}")
        success_ct = 0
        total_loss = 0.0
        total_ct = 0

        if model.current_it > it:
            continue

        for data_point, ct in zip(data_loader, tqdm(range(len(data_loader)))):

            logging.info(f"task ct: {ct}")

            graph = graphs[data_point.graph_id]
            suc = episode(model.policy, target, data_point, graph, config,
                          dataset.attr_encoder, model.memory, total_ct,
                          model.optimizer)

            total_ct += 1
            if suc:
                success_ct += 1

        logging.info(f"success count: {success_ct}")

        if it % cmd_args.save_num == 0:
            model.steps_done = DC.steps_done
            # model.eps = eps
            model.current_it = it
            model_name = f"model_{it}.pkl"
            model_path = os.path.join(cmd_args.model_save_dir, model_name)
            torch.save(model, model_path)

    print('Complete')