Example #1
def load_terminal_design_data(raw_dataset_path, grammar_file):
    graphs = rd.load_graphs(grammar_file)
    rules = [rd.create_rule_from_graph(g) for g in graphs]

    all_labels = set()
    for rule in rules:
        for node in rule.lhs.nodes:
            all_labels.add(node.attrs.label)
    all_labels = sorted(list(all_labels))

    preprocessor = Preprocessor(all_labels=all_labels)

    with open(raw_dataset_path, newline='') as log_file:
        reader = csv.DictReader(log_file)

        all_link_features = []
        all_link_adj = []
        all_results = []
        max_nodes = 0
        for row in reader:
            rule_seq = ast.literal_eval(row['rule_seq'])
            result = float(row['result'])

            all_results.append(result)

            # Build a robot from the rule sequence
            robot_graph = make_initial_graph()
            for r in rule_seq:
                matches = rd.find_matches(rules[r].lhs, robot_graph)
                # Always use the first match
                robot_graph = rd.apply_rule(rules[r], robot_graph, matches[0])

            adj_matrix, link_features, _ = preprocessor.preprocess(robot_graph)

            all_link_features.append(link_features)
            all_link_adj.append(adj_matrix)

            max_nodes = max(max_nodes, adj_matrix.shape[0])

        all_adj_matrix_pad, all_link_features_pad, all_masks = [], [], []
        for adj_matrix, link_features in zip(all_link_adj, all_link_features):
            adj_matrix_pad, link_features_pad, masks = preprocessor.pad_graph(
                adj_matrix, link_features, max_nodes=max_nodes)
            all_adj_matrix_pad.append(adj_matrix_pad)
            all_link_features_pad.append(link_features_pad)
            all_masks.append(masks)

    return all_link_features_pad, all_adj_matrix_pad, all_masks, all_results
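A minimal usage sketch for load_terminal_design_data; the file paths below are hypothetical, and the only assumption is that the returned lists hold NumPy arrays padded to a common max_nodes, as the padding loop above produces.
import numpy as np
import torch

# Hypothetical paths; replace with a real MCTS log (.csv) and grammar (.dot) file.
features, adj, masks, results = load_terminal_design_data(
    'data/designs.csv', 'data/grammar.dot')

# Every entry is padded to the same max_nodes, so the lists stack into batch tensors.
features_t = torch.as_tensor(np.stack(features), dtype=torch.float32)
adj_t = torch.as_tensor(np.stack(adj), dtype=torch.float32)
masks_t = torch.as_tensor(np.stack(masks), dtype=torch.float32)
rewards_t = torch.as_tensor(results, dtype=torch.float32)
print(features_t.shape, adj_t.shape, masks_t.shape, rewards_t.shape)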
Example #2
def build_robot(args):
    graphs = rd.load_graphs(args.grammar_file)
    rules = [rd.create_rule_from_graph(g) for g in graphs]

    rule_sequence = [int(s.strip(",")) for s in args.rule_sequence]

    graph = make_initial_graph()
    for r in rule_sequence:
        matches = rd.find_matches(rules[r].lhs, graph)
        if matches:
            graph = rd.apply_rule(rules[r], graph, matches[0])

    robot = build_normalized_robot(graph)
    finalize_robot(robot)

    return robot
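A hypothetical invocation of build_robot; the grammar path and rule indices below are illustrative placeholders, and rule_sequence is given as strings because the function strips trailing commas before converting each entry to int.
import argparse

# Illustrative arguments only; neither the path nor the rule indices come from the original data.
example_args = argparse.Namespace(
    grammar_file='data/grammar.dot',
    rule_sequence=['0,', '1,', '2'])
robot = build_robot(example_args)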
def sample_design(args, task_id, seed, env, V, eps, results_queue, time_queue,
                  done_event):
    tt0 = time.time()

    random.seed(seed)

    valid = False
    samples = []
    while not valid:
        state = make_initial_graph()
        rule_seq = []
        no_action_flag = False
        for _ in range(args.depth):
            action, step_type = select_action(env, V, state, eps)
            if action is None:
                no_action_flag = True
                break
            rule_seq.append(action)
            next_state = env.transite(state, action)
            state = next_state
            if not has_nonterminals(state):
                break

        valid = env.is_valid(state)

        if not valid:
            # update the invalid sample's count
            if no_action_flag:
                info = 'no_action'
            elif has_nonterminals(state):
                info = 'step_exceeded'
            else:
                info = 'self_collision'
            samples.append(Sample(task_id, rule_seq, -2.0, info))
        else:
            samples.append(
                Sample(task_id, rule_seq, predict(V, state), info='valid'))

    tt = time.time() - tt0
    time_queue.put(tt)

    results_queue.put(samples)

    done_event.wait()
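sample_design above builds Sample records, and the consumer later reads .task_id, .rule_seq, .predicted_reward, and .info; the namedtuple below is a reconstruction from that usage, not the original definition.
from collections import namedtuple

# Field order matches the positional construction Sample(task_id, rule_seq, -2.0, info).
Sample = namedtuple('Sample', ['task_id', 'rule_seq', 'predicted_reward', 'info'])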
Example #4
 def __init__(self,
              task,
              rules,
              seed=0,
              mpc_num_processes=8,
              enable_reward_oracle=False,
              preprocessor=None):
     self.task = task
     self.rules = rules
     self.seed = seed
     self.rng = random.Random(seed)
     self.mpc_num_processes = mpc_num_processes
     self.enable_reward_oracle = enable_reward_oracle
     if self.enable_reward_oracle:
         assert preprocessor is not None
         self.preprocessor = preprocessor
         self.load_reward_oracle()
     self.initial_state = make_initial_graph()
     self.result_cache = dict()
     self.state = None
     self.rule_seq = []
def search_algo_1(args):
    # initialize random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # initialize/load
    # TODO: use 80 to fit the input of trained MPC GNN, use args.depth * 3 later for real mpc
    max_nodes = 80
    task_class = getattr(tasks, args.task)
    task = task_class()
    graphs = rd.load_graphs(args.grammar_file)
    rules = [rd.create_rule_from_graph(g) for g in graphs]

    # state preprocessor
    # Find all possible link labels, so they can be one-hot encoded
    all_labels = set()
    for rule in rules:
        for node in rule.lhs.nodes:
            all_labels.add(node.attrs.require_label)
    all_labels = sorted(list(all_labels))
    global preprocessor
    preprocessor = Preprocessor(max_nodes=max_nodes, all_labels=all_labels)

    # initialize the env
    env = RobotGrammarEnv(task,
                          rules,
                          enable_reward_oracle=True,
                          preprocessor=preprocessor)

    # initialize Value function
    device = 'cpu'
    state = env.reset()
    sample_adj_matrix, sample_features, sample_masks = preprocessor.preprocess(
        state)
    num_features = sample_features.shape[1]
    V = Net(max_nodes=max_nodes, num_channels=num_features,
            num_outputs=1).to(device)

    # load pretrained V function
    if args.load_V_path is not None:
        V.load_state_dict(torch.load(args.load_V_path))
        print_info('Loaded pretrained V function from {}'.format(
            args.load_V_path))

    # initialize target V_hat look up table
    V_hat = dict()

    # load pretrained V_hat
    if args.load_Vhat_path is not None:
        V_hat_fp = open(args.load_Vhat_path, 'rb')
        V_hat = pickle.load(V_hat_fp)
        V_hat_fp.close()
        print_info('Loaded pretrained Vhat from {}'.format(
            args.load_Vhat_path))

    # initialize the seen states pool
    states_pool = StatesPool(capacity=args.states_pool_capacity)
    all_sample_designs = []

    # explored designs
    designs = []
    design_rewards = []

    # load previously explored designs
    if args.load_designs_path is not None:
        fp_csv = open(args.load_designs_path, newline='')
        reader = csv.DictReader(fp_csv)
        for row in reader:
            rule_seq = ast.literal_eval(row['rule_seq'])
            reward = float(row['reward'])
            state = make_initial_graph()
            for i in range(len(rule_seq)):
                state = env.transite(state, rule_seq[i])
            designs.append(state)
            design_rewards.append(reward)
            if not np.isclose(V_hat[hash(state)], reward):
                print(rule_seq)
                print(V_hat[hash(state)], reward)
                print_error("Vhat and designs don't match")
        fp_csv.close()
        print_info('Loaded pretrained designs from {}'.format(
            args.load_designs_path))

    if not args.test:
        # initialize save folders and files
        fp_log = open(os.path.join(args.save_dir, 'log.txt'), 'w')
        fp_log.close()
        fp_eval = open(os.path.join(args.save_dir, 'eval.txt'), 'w')
        fp_eval.close()
        design_csv_path = os.path.join(args.save_dir, 'designs.csv')
        fp_csv = open(design_csv_path, 'w')
        fieldnames = ['rule_seq', 'reward']
        writer = csv.DictWriter(fp_csv, fieldnames=fieldnames)
        writer.writeheader()
        fp_csv.close()

        # initialize the optimizer
        global optimizer
        optimizer = torch.optim.Adam(V.parameters(), lr=args.lr)

        # initialize best design rule sequence
        best_design, best_reward = None, -np.inf

        # reward history
        epoch_rew_his = []
        last_checkpoint = -1

        # recording time
        t_sample_sum = 0.

        # record the count for invalid samples
        no_action_samples, step_exceeded_samples = 0, 0

        for epoch in range(args.num_iterations):
            t_start = time.time()

            V.eval()

            # update eps and eps_sample
            if args.eps_schedule == 'linear-decay':
                eps = args.eps_start + epoch / args.num_iterations * (
                    args.eps_end - args.eps_start)
            elif args.eps_schedule == 'exp-decay':
                eps = args.eps_end + (args.eps_start - args.eps_end) * np.exp(
                    -1.0 * epoch / args.num_iterations / args.eps_decay)

            if args.eps_sample_schedule == 'linear-decay':
                eps_sample = args.eps_sample_start + epoch / args.num_iterations * (
                    args.eps_sample_end - args.eps_sample_start)
            elif args.eps_sample_schedule == 'exp-decay':
                eps_sample = args.eps_sample_end + (
                    args.eps_sample_start - args.eps_sample_end) * np.exp(
                        -1.0 * epoch / args.num_iterations /
                        args.eps_sample_decay)

            t_sample, t_update, t_mpc, t_opt = 0, 0, 0, 0

            best_candidate_design, best_candidate_reward = None, -1.0
            best_candidate_state_seq, best_candidate_rule_seq = None, None

            p = random.random()
            if p < eps_sample:
                num_samples = 1
            else:
                num_samples = args.num_samples

            # use e-greedy to sample a design within maximum #steps.
            for _ in range(num_samples):
                valid = False
                while not valid:
                    t0 = time.time()

                    state = env.reset()
                    rule_seq = []
                    state_seq = [state]
                    random_step_cnt, optimal_step_cnt = 0, 0
                    no_action_flag = False
                    for _ in range(args.depth):
                        action, step_type = select_action(env, V, state, eps)
                        if action is None:
                            no_action_flag = True
                            break
                        if step_type == 'random':
                            random_step_cnt += 1
                        elif step_type == 'optimal':
                            optimal_step_cnt += 1
                        rule_seq.append(action)
                        next_state = env.transite(state, action)
                        state_seq.append(next_state)
                        state = next_state
                        if env.is_valid(next_state):
                            valid = True
                            break

                    t_sample += time.time() - t0

                    t0 = time.time()

                    # update the invalid sample's count
                    if not valid:
                        if no_action_flag:
                            no_action_samples += 1
                        else:
                            step_exceeded_samples += 1

                    # update the Vhat for invalid designs
                    if not valid:
                        update_Vhat(V_hat, state_seq, 0.0)
                        # update states pool
                        update_states_pool(states_pool, state_seq)

                    # if valid but has been explored as a valid design before, then put in state pool but resample it
                    if valid and (hash(state)
                                  in V_hat) and (V_hat[hash(state)] > 1e-3):
                        update_Vhat(V_hat, state_seq, V_hat[hash(state)])
                        update_states_pool(states_pool, state_seq)
                        valid = False

                    # record the sampled design
                    all_sample_designs.append(rule_seq)

                    t_update += time.time() - t0

                predicted_value = predict(V, state)
                if predicted_value > best_candidate_reward:
                    best_candidate_design, best_candidate_reward = state, predicted_value
                    best_candidate_rule_seq, best_candidate_state_seq = rule_seq, state_seq

            t0 = time.time()

            _, reward = env.get_reward(best_candidate_design)

            t_mpc += time.time() - t0

            # save the design and the reward in the list
            designs.append(best_candidate_rule_seq)
            design_rewards.append(reward)

            # update best design
            if reward > best_reward:
                best_design, best_reward = best_candidate_rule_seq, reward
                print_info(
                    'new best: reward = {:.4f}, predicted reward = {:.4f}, num_samples = {}'
                    .format(reward, best_candidate_reward, num_samples))

            t0 = time.time()

            # update V_hat for the valid design
            update_Vhat(V_hat, best_candidate_state_seq, reward)

            # update states pool for the valid design
            update_states_pool(states_pool, best_candidate_state_seq)

            t_update += time.time() - t0

            t0 = time.time()

            # optimize
            V.train()
            total_loss = 0.0
            for _ in range(args.opt_iter):
                minibatch = states_pool.sample(
                    min(len(states_pool), args.batch_size))

                train_adj_matrix, train_features, train_masks, train_reward = [], [], [], []
                for robot_graph in minibatch:
                    hash_key = hash(robot_graph)
                    target_reward = V_hat[hash_key]
                    adj_matrix, features, masks = preprocessor.preprocess(
                        robot_graph)
                    train_adj_matrix.append(adj_matrix)
                    train_features.append(features)
                    train_masks.append(masks)
                    train_reward.append(target_reward)

                train_adj_matrix_torch = torch.tensor(train_adj_matrix)
                train_features_torch = torch.tensor(train_features)
                train_masks_torch = torch.tensor(train_masks)
                train_reward_torch = torch.tensor(train_reward)

                optimizer.zero_grad()
                output, loss_link, loss_entropy = V(train_features_torch,
                                                    train_adj_matrix_torch,
                                                    train_masks_torch)
                loss = F.mse_loss(output[:, 0], train_reward_torch)
                loss.backward()
                total_loss += loss.item()
                optimizer.step()

            t_opt += time.time() - t0

            t_end = time.time()

            t_sample_sum += t_sample

            # logging
            if (epoch + 1
                ) % args.log_interval == 0 or epoch + 1 == args.num_iterations:
                iter_save_dir = os.path.join(args.save_dir,
                                             '{}'.format(epoch + 1))
                os.makedirs(os.path.join(iter_save_dir), exist_ok=True)
                # save model
                save_path = os.path.join(iter_save_dir, 'V_model.pt')
                torch.save(V.state_dict(), save_path)
                # save V_hat
                save_path = os.path.join(iter_save_dir, 'V_hat')
                fp = open(save_path, 'wb')
                pickle.dump(V_hat, fp)
                fp.close()
                # save all_sampled_designs
                save_path = os.path.join(iter_save_dir, 'all_sampled_designs')
                fp = open(save_path, 'wb')
                pickle.dump(all_sample_designs, fp)
                fp.close()
                # save explored design and its reward
                fp_csv = open(design_csv_path, 'a')
                fieldnames = ['rule_seq', 'reward']
                writer = csv.DictWriter(fp_csv, fieldnames=fieldnames)
                for i in range(last_checkpoint + 1, len(designs)):
                    writer.writerow({
                        'rule_seq': str(designs[i]),
                        'reward': design_rewards[i]
                    })
                last_checkpoint = len(designs) - 1
                fp_csv.close()

            epoch_rew_his.append(reward)

            avg_loss = total_loss / args.opt_iter
            len_his = min(len(epoch_rew_his), 30)
            avg_reward = np.sum(epoch_rew_his[-len_his:]) / len_his
            print('Epoch {}: T_sample = {:.2f}, T_update = {:.2f}, T_mpc = {:.2f}, T_opt = {:.2f}, eps = {:.3f}, eps_sample = {:.3f}, #samples = {}, training loss = {:.4f}, predicted_reward = {:.4f}, reward = {:.4f}, last 30 epoch reward = {:.4f}, best reward = {:.4f}'.format(\
                epoch, t_sample, t_update, t_mpc, t_opt, eps, eps_sample, num_samples, \
                avg_loss, best_candidate_reward, reward, avg_reward, best_reward))

            fp_log = open(os.path.join(args.save_dir, 'log.txt'), 'a')
            fp_log.write('eps = {:.4f}, eps_sample = {:.4f}, num_samples = {}, T_sample = {:4f}, T_update = {:4f}, T_mpc = {:.4f}, T_opt = {:.4f}, loss = {:.4f}, predicted_reward = {:.4f}, reward = {:.4f}, avg_reward = {:.4f}\n'.format(\
                eps, eps_sample, num_samples, t_sample, t_update, t_mpc, t_opt, avg_loss, best_candidate_reward, reward, avg_reward))
            fp_log.close()

            if (epoch + 1) % args.log_interval == 0:
                print_info(
                    'Avg sampling time for last {} epoch: {:.4f} second'.
                    format(args.log_interval,
                           t_sample_sum / args.log_interval))
                t_sample_sum = 0.
                invalid_cnt, valid_cnt = 0, 0
                for state in states_pool.pool:
                    if np.isclose(V_hat[hash(state)], 0.):
                        invalid_cnt += 1
                    else:
                        valid_cnt += 1
                print_info(
                    'states_pool size = {}, #valid = {}, #invalid = {}, #valid / #invalid = {}'
                    .format(len(states_pool), valid_cnt, invalid_cnt,
                            valid_cnt / invalid_cnt
                            if invalid_cnt > 0 else np.inf))
                print_info(
                    'Invalid samples: #no_action_samples = {}, #step_exceeded_samples = {}, #no_action / #step_exceeded = {}'
                    .format(no_action_samples, step_exceeded_samples,
                            no_action_samples / step_exceeded_samples
                            if step_exceeded_samples > 0 else np.inf))

            # evaluation
            if args.eval_interval > 0 and (
                (epoch + 1) % args.eval_interval == 0
                    or epoch + 1 == args.num_iterations):
                print_info('-------- Doing evaluation --------')
                print_info('#states = {}'.format(len(states_pool)))
                loss_total = 0.
                for state in states_pool.pool:
                    value = predict(V, state)
                    loss_total += (V_hat[hash(state)] - value)**2
                print_info('Loss = {:.3f}'.format(loss_total /
                                                  len(states_pool)))
                fp_eval = open(os.path.join(args.save_dir, 'eval.txt'), 'a')
                fp_eval.write('epoch = {}, loss = {:.3f}\n'.format(
                    epoch + 1, loss_total / len(states_pool)))
                fp_eval.close()

        save_path = os.path.join(args.save_dir, 'model_state_dict_final.pt')
        torch.save(V.state_dict(), save_path)
    else:
        import IPython
        IPython.embed()

        # test
        V.eval()
        print('Start testing')
        test_epoch = 30
        y0 = []
        y1 = []
        x = []
        for ii in range(0, 11):
            eps = 1.0 - 0.1 * ii

            print('------------------------------------------')
            print('eps = ', eps)

            reward_sum = 0.
            best_reward = -np.inf
            for epoch in range(test_epoch):
                t0 = time.time()

                # use e-greedy to sample a design within maximum #steps.
                valid = False
                while not valid:
                    state = env.reset()
                    rule_seq = []
                    state_seq = [state]
                    for _ in range(args.depth):
                        action, step_type = select_action(env, V, state, eps)
                        if action is None:
                            break
                        rule_seq.append(action)
                        next_state = env.transite(state, action)
                        state_seq.append(next_state)
                        state = next_state
                        if env.is_valid(next_state):
                            valid = True
                            break

                _, reward = env.get_reward(state)
                reward_sum += reward
                best_reward = max(best_reward, reward)
                print(
                    f'design {epoch}: reward = {reward}, time = {time.time() - t0}'
                )

            print('test avg reward = ', reward_sum / test_epoch)
            print('best reward found = ', best_reward)
            x.append(eps)
            y0.append(reward_sum / test_epoch)
            y1.append(best_reward)

        import matplotlib.pyplot as plt
        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        ax[0].plot(x, y0)
        ax[0].set_title('Avg Reward')
        ax[0].set_xlabel('eps')
        ax[0].set_ylabel('reward')

        ax[1].plot(x, y1)
        ax[1].set_title('Best Reward')
        ax[1].set_xlabel('eps')
        ax[1].set_ylabel('reward')

        plt.show()
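The epsilon schedules inside the training loop above can be read in isolation; the helper below restates the two decay formulas as a standalone sketch, with an illustrative name and illustrative arguments.
import numpy as np

def epsilon_at(epoch, num_iterations, eps_start, eps_end,
               schedule='linear-decay', decay=1.0):
    # Linear interpolation from eps_start to eps_end over the whole run.
    if schedule == 'linear-decay':
        return eps_start + epoch / num_iterations * (eps_end - eps_start)
    # Exponential decay toward eps_end, mirroring the formula used in the loop.
    if schedule == 'exp-decay':
        return eps_end + (eps_start - eps_end) * np.exp(
            -1.0 * epoch / num_iterations / decay)
    raise ValueError('unknown schedule: {}'.format(schedule))

# Illustrative values: epsilon halfway through a 2000-iteration linear schedule.
print(epsilon_at(1000, 2000, eps_start=1.0, eps_end=0.1))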
Example #6
def load_partial_design_data(raw_dataset_path, grammar_file):
    graphs = rd.load_graphs(grammar_file)
    rules = [rd.create_rule_from_graph(g) for g in graphs]

    all_labels = set()
    for rule in rules:
        for node in rule.lhs.nodes:
            all_labels.add(node.attrs.label)
    all_labels = sorted(list(all_labels))

    preprocessor = Preprocessor(all_labels=all_labels)

    with open(raw_dataset_path, newline='') as log_file:
        reader = csv.DictReader(log_file)

        memory = dict()
        idx = 0
        for row in reader:
            if idx % 1000 == 0:
                print(f'processing idx = {idx}')
            idx += 1

            rule_seq = ast.literal_eval(row['rule_seq'])
            result = float(row['result'])

            # Build a robot from the rule sequence
            robot_graph = make_initial_graph()
            update_memory(memory, preprocessor, robot_graph, result)
            for r in rule_seq:
                matches = rd.find_matches(rules[r].lhs, robot_graph)
                # Always use the first match
                robot_graph = rd.apply_rule(rules[r], robot_graph, matches[0])
                update_memory(memory, preprocessor, robot_graph, result)

        initial_robot_graph = make_initial_graph()
        print('#hit on initial state: ',
              memory[hash(initial_robot_graph)]['hit'])

        all_link_features = []
        all_link_adj = []
        all_results = []
        max_nodes = 0
        for robot_hash_key in memory:
            adj_matrix, link_features, result = \
                memory[robot_hash_key]['adj_matrix'], memory[robot_hash_key]['link_features'], memory[robot_hash_key]['V']

            all_link_features.append(link_features)
            all_link_adj.append(adj_matrix)
            all_results.append(result)

            max_nodes = max(max_nodes, adj_matrix.shape[0])

        all_adj_matrix_pad, all_link_features_pad, all_masks = [], [], []
        for adj_matrix, link_features in zip(all_link_adj, all_link_features):
            adj_matrix_pad, link_features_pad, masks = preprocessor.pad_graph(
                adj_matrix, link_features, max_nodes=max_nodes)
            all_adj_matrix_pad.append(adj_matrix_pad)
            all_link_features_pad.append(link_features_pad)
            all_masks.append(masks)

    return all_link_features_pad, all_adj_matrix_pad, all_masks, all_results
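load_partial_design_data relies on an update_memory helper that is not shown in this excerpt; the sketch below is one plausible reconstruction, inferred only from the keys read later ('adj_matrix', 'link_features', 'V', 'hit') and from the hash-keyed lookups, and its aggregation rule (keeping the best result seen for a partial design) is an assumption.
def update_memory(memory, preprocessor, robot_graph, result):
    # Hypothetical reconstruction: key partial designs by their graph hash.
    key = hash(robot_graph)
    if key not in memory:
        adj_matrix, link_features, _ = preprocessor.preprocess(robot_graph)
        memory[key] = {'adj_matrix': adj_matrix,
                       'link_features': link_features,
                       'V': result,
                       'hit': 0}
    memory[key]['hit'] += 1
    # Assumed aggregation: a partial design keeps the best result reached from it.
    memory[key]['V'] = max(memory[key]['V'], result)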
def search_algo(args):
    # initialize random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.set_num_threads(1)

    # initialize/load
    # TODO: use 80 to fit the input of trained MPC GNN, use args.depth * 3 later for real mpc
    max_nodes = 80
    task_class = getattr(tasks, args.task)
    if not args.no_noise:
        task = task_class()
    else:
        task = task_class(force_std=0.0, torque_std=0.0)

    graphs = rd.load_graphs(args.grammar_file)
    rules = [rd.create_rule_from_graph(g) for g in graphs]

    # initialize preprocessor
    # Find all possible link labels, so they can be one-hot encoded
    all_labels = set()
    for rule in rules:
        for node in rule.lhs.nodes:
            all_labels.add(node.attrs.require_label)
    all_labels = sorted(list(all_labels))

    global preprocessor
    preprocessor = Preprocessor(max_nodes=max_nodes, all_labels=all_labels)

    # initialize the env
    env = RobotGrammarEnv(task,
                          rules,
                          seed=args.seed,
                          mpc_num_processes=args.mpc_num_processes)

    # initialize Value function
    device = 'cpu'
    state = env.reset()
    sample_adj_matrix, sample_features, sample_masks = preprocessor.preprocess(
        state)
    num_features = sample_features.shape[1]
    V = Net(max_nodes=max_nodes, num_channels=num_features,
            num_outputs=1).to(device)
    # V.share_memory()

    # load pretrained V function
    if args.load_V_path is not None:
        V.load_state_dict(torch.load(args.load_V_path))
        print_info('Loaded pretrained V function from {}'.format(
            args.load_V_path))

    # initialize target V_hat look up table
    V_hat = dict()

    # load pretrained V_hat
    if args.load_Vhat_path is not None:
        V_hat_fp = open(args.load_Vhat_path, 'rb')
        V_hat = pickle.load(V_hat_fp)
        V_hat_fp.close()
        print_info('Loaded pretrained Vhat from {}'.format(
            args.load_Vhat_path))

    # initialize invalid_cnt
    invalid_his = dict()
    num_invalid_samples, num_valid_samples = 0, 0
    repeated_cnt = 0

    # initialize the seen states pool
    states_pool = StatesPool(capacity=args.states_pool_capacity)

    # explored designs
    designs = []
    design_rewards = []
    design_opt_seeds = []

    # record prediction error
    prediction_error_sum = 0.0

    if not args.test:
        # initialize save folders and files
        fp_log = open(os.path.join(args.save_dir, 'log.txt'), 'w')
        fp_log.close()
        fp_eval = open(os.path.join(args.save_dir, 'eval.txt'), 'w')
        fp_eval.close()
        design_csv_path = os.path.join(args.save_dir, 'designs.csv')
        fp_csv = open(design_csv_path, 'w')
        fieldnames = ['rule_seq', 'reward', 'opt_seed']
        writer = csv.DictWriter(fp_csv, fieldnames=fieldnames)
        writer.writeheader()
        fp_csv.close()

        # initialize the optimizer
        global optimizer
        optimizer = torch.optim.Adam(V.parameters(), lr=args.lr)

        # initialize best design rule sequence
        best_design, best_reward = None, -np.inf

        # reward history
        epoch_rew_his = []
        last_checkpoint = -1

        # recording time
        t_sample_sum = 0.

        # record the count for invalid samples
        no_action_samples, step_exceeded_samples, self_collision_samples = 0, 0, 0

        # define state0
        state0 = make_initial_graph()
        for epoch in range(args.num_iterations):
            t_start = time.time()

            V.eval()

            # update eps and eps_sample
            if args.eps_schedule == 'linear-decay':
                eps = args.eps_start + epoch / args.num_iterations * (
                    args.eps_end - args.eps_start)
            elif args.eps_schedule == 'exp-decay':
                eps = args.eps_end + (args.eps_start - args.eps_end) * np.exp(
                    -1.0 * epoch / args.num_iterations / args.eps_decay)

            if args.eps_sample_schedule == 'linear-decay':
                eps_sample = args.eps_sample_start + epoch / args.num_iterations * (
                    args.eps_sample_end - args.eps_sample_start)
            elif args.eps_sample_schedule == 'exp-decay':
                eps_sample = args.eps_sample_end + (
                    args.eps_sample_start - args.eps_sample_end) * np.exp(
                        -1.0 * epoch / args.num_iterations /
                        args.eps_sample_decay)

            t_sample, t_mpc, t_opt = 0, 0, 0

            t0 = time.time()

            selected_design, selected_reward = None, -np.inf
            selected_state_seq, selected_rule_seq = None, None

            p = random.random()
            if p < eps_sample:
                num_samples = 1
            else:
                num_samples = args.num_samples

            # use e-greedy to sample a design within maximum #steps.
            results_queue = Queue()
            done_event = Event()
            time_queue = Queue()
            tt0 = time.time()
            processes = []
            for task_id in range(num_samples):
                seed = random.getrandbits(32)
                p = Process(target=sample_design,
                            args=(args, task_id, seed, env, V, eps,
                                  results_queue, time_queue, done_event))
                p.start()
                processes.append(p)
            t_start = time.time() - tt0

            sampled_rewards = [0.0 for _ in range(num_samples)]
            thread_times = []
            t_update = 0
            for _ in range(num_samples):
                samples = results_queue.get()
                thread_time = time_queue.get()
                thread_times.append(thread_time)
                tt0 = time.time()
                for i in range(len(samples) - 1):
                    assert samples[i].info != 'valid'
                    if samples[i].info == 'no_action':
                        no_action_samples += 1
                    elif samples[i].info == 'step_exceeded':
                        step_exceeded_samples += 1
                    else:
                        self_collision_samples += 1

                    state, state_seq = apply_rules(state0, samples[i].rule_seq,
                                                   env)
                    # update the Vhat for invalid designs
                    update_Vhat(args,
                                V_hat,
                                state_seq,
                                -2.0,
                                invalid=True,
                                invalid_cnt=invalid_his)
                    # update states pool
                    update_states_pool(states_pool, state_seq, V_hat)
                    num_invalid_samples += 1

                assert samples[-1].info == 'valid'
                state, state_seq = apply_rules(state0, samples[-1].rule_seq,
                                               env)
                num_valid_samples += 1
                if samples[-1].predicted_reward > selected_reward:
                    selected_design, selected_reward = state, samples[
                        -1].predicted_reward
                    selected_rule_seq, selected_state_seq = samples[
                        -1].rule_seq, state_seq

                sampled_rewards[
                    samples[-1].task_id] = samples[-1].predicted_reward
                t_update += time.time() - tt0

            done_event.set()

            for p in processes:
                p.join()

            print('thread time = {}'.format(thread_times))
            print('t_update = {}, t_start = {}'.format(t_update, t_start))

            # print('all sampled designs:')
            # print(sampled_rewards)

            t_sample += time.time() - t0

            t0 = time.time()

            repeated = False
            if (hash(selected_design)
                    in V_hat) and (V_hat[hash(selected_design)] > -2.0 + 1e-3):
                repeated = True
                repeated_cnt += 1

            ctrl_seq, reward = env.get_reward(selected_design)

            t_mpc += time.time() - t0

            # save the design and the reward in the list
            designs.append(selected_rule_seq)
            design_rewards.append(reward)
            design_opt_seeds.append(env.last_opt_seed)

            # update best design
            if reward > best_reward:
                best_design, best_reward = selected_rule_seq, reward
                print_info(
                    'new best: reward = {:.4f}, predicted reward = {:.4f}, num_samples = {}'
                    .format(reward, selected_reward, num_samples))

            # update V_hat for the valid design
            update_Vhat(args, V_hat, selected_state_seq, reward)

            # update states pool for the valid design
            update_states_pool(states_pool, selected_state_seq, V_hat)

            t0 = time.time()

            # optimize
            V.train()
            total_loss = 0.0
            for _ in range(args.opt_iter):
                minibatch = states_pool.sample(
                    min(len(states_pool), args.batch_size))

                train_adj_matrix, train_features, train_masks, train_reward = [], [], [], []
                for robot_graph in minibatch:
                    hash_key = hash(robot_graph)
                    target_reward = V_hat[hash_key]
                    adj_matrix, features, masks = preprocessor.preprocess(
                        robot_graph)
                    train_adj_matrix.append(adj_matrix)
                    train_features.append(features)
                    train_masks.append(masks)
                    train_reward.append(target_reward)

                train_adj_matrix_torch = torch.tensor(train_adj_matrix)
                train_features_torch = torch.tensor(train_features)
                train_masks_torch = torch.tensor(train_masks)
                train_reward_torch = torch.tensor(train_reward)

                optimizer.zero_grad()
                output, loss_link, loss_entropy = V(train_features_torch,
                                                    train_adj_matrix_torch,
                                                    train_masks_torch)
                loss = F.mse_loss(output[:, 0], train_reward_torch)
                loss.backward()
                total_loss += loss.item()
                optimizer.step()

            t_opt += time.time() - t0

            t_end = time.time()

            t_sample_sum += t_sample

            # logging
            if (epoch + 1
                ) % args.log_interval == 0 or epoch + 1 == args.num_iterations:
                iter_save_dir = os.path.join(args.save_dir,
                                             '{}'.format(epoch + 1))
                os.makedirs(os.path.join(iter_save_dir), exist_ok=True)
                # save model
                save_path = os.path.join(iter_save_dir, 'V_model.pt')
                torch.save(V.state_dict(), save_path)
                # save V_hat
                save_path = os.path.join(iter_save_dir, 'V_hat')
                fp = open(save_path, 'wb')
                pickle.dump(V_hat, fp)
                fp.close()

            # save explored design and its reward
            fp_csv = open(design_csv_path, 'a')
            fieldnames = ['rule_seq', 'reward', 'opt_seed']
            writer = csv.DictWriter(fp_csv, fieldnames=fieldnames)
            for i in range(last_checkpoint + 1, len(designs)):
                writer.writerow({
                    'rule_seq': str(designs[i]),
                    'reward': design_rewards[i],
                    'opt_seed': design_opt_seeds[i]
                })
            last_checkpoint = len(designs) - 1
            fp_csv.close()

            epoch_rew_his.append(reward)

            avg_loss = total_loss / args.opt_iter
            len_his = min(len(epoch_rew_his), 30)
            avg_reward = np.sum(epoch_rew_his[-len_his:]) / len_his
            prediction_error_sum += (selected_reward - reward)**2
            avg_prediction_error = prediction_error_sum / (epoch + 1)

            if repeated:
                print_white('Epoch {:4}: T_sample = {:5.2f}, T_mpc = {:5.2f}, T_opt = {:5.2f}, eps = {:5.3f}, eps_sample = {:5.3f}, #samples = {:2}, training loss = {:7.4f}, pred_error = {:6.4f}, predicted_reward = {:6.4f}, reward = {:6.4f}, last 30 epoch reward = {:6.4f}, best reward = {:6.4f}'.format(\
                    epoch, t_sample, t_mpc, t_opt, eps, eps_sample, num_samples, \
                    avg_loss, avg_prediction_error, selected_reward, reward, avg_reward, best_reward))
            else:
                print_warning('Epoch {:4}: T_sample = {:5.2f}, T_mpc = {:5.2f}, T_opt = {:5.2f}, eps = {:5.3f}, eps_sample = {:5.3f}, #samples = {:2}, training loss = {:7.4f}, pred_error = {:6.4f}, predicted_reward = {:6.4f}, reward = {:6.4f}, last 30 epoch reward = {:6.4f}, best reward = {:6.4f}'.format(\
                    epoch, t_sample, t_mpc, t_opt, eps, eps_sample, num_samples, \
                    avg_loss, avg_prediction_error, selected_reward, reward, avg_reward, best_reward))

            fp_log = open(os.path.join(args.save_dir, 'log.txt'), 'a')
            fp_log.write('eps = {:.4f}, eps_sample = {:.4f}, num_samples = {}, T_sample = {:4f}, T_mpc = {:.4f}, T_opt = {:.4f}, loss = {:.4f}, predicted_reward = {:.4f}, reward = {:.4f}, avg_reward = {:.4f}\n'.format(\
                eps, eps_sample, num_samples, t_sample, t_mpc, t_opt, avg_loss, selected_reward, reward, avg_reward))
            fp_log.close()

            if (epoch + 1) % args.log_interval == 0:
                print_info(
                    'Avg sampling time for last {} epoch: {:.4f} second'.
                    format(args.log_interval,
                           t_sample_sum / args.log_interval))
                t_sample_sum = 0
                print_info('size of states_pool = {}'.format(len(states_pool)))
                print_info(
                    '#valid samples = {}, #invalid samples = {}, #valid / #invalid = {}'
                    .format(
                        num_valid_samples, num_invalid_samples,
                        num_valid_samples / num_invalid_samples
                        if num_invalid_samples > 0 else 10000.0))
                print_info(
                    'Invalid samples: #no_action_samples = {}, #step_exceeded_samples = {}, #self_collision_samples = {}'
                    .format(no_action_samples, step_exceeded_samples,
                            self_collision_samples))
                max_trials, cnt = 0, 0
                for key in invalid_his.keys():
                    if invalid_his[key] > max_trials:
                        max_trials = invalid_his[key]
                    if invalid_his[key] > args.max_trials:
                        cnt += 1

                print_info(
                    'max invalid_trials = {}, #failed nodes = {}'.format(
                        max_trials, cnt))
                print_info('repeated rate = {}'.format(repeated_cnt /
                                                       (epoch + 1)))

        save_path = os.path.join(args.save_dir, 'model_state_dict_final.pt')
        torch.save(V.state_dict(), save_path)
    else:
        import IPython
        IPython.embed()

        # test
        V.eval()
        print('Start testing')
        test_epoch = 30
        y0 = []
        y1 = []
        x = []
        for ii in range(0, 11):
            eps = 1.0 - 0.1 * ii

            print('------------------------------------------')
            print('eps = ', eps)

            reward_sum = 0.
            best_reward = -np.inf
            for epoch in range(test_epoch):
                t0 = time.time()

                # use e-greedy to sample a design within maximum #steps.
                valid = False
                while not valid:
                    state = env.reset()
                    rule_seq = []
                    state_seq = [state]
                    for _ in range(args.depth):
                        action, step_type = select_action(env, V, state, eps)
                        if action is None:
                            break
                        rule_seq.append(action)
                        next_state = env.transite(state, action)
                        state_seq.append(next_state)
                        state = next_state
                        if not has_nonterminals(next_state):
                            valid = True
                            break

                _, reward = env.get_reward(state)
                reward_sum += reward
                best_reward = max(best_reward, reward)
                print(
                    f'design {epoch}: reward = {reward}, time = {time.time() - t0}'
                )

            print('test avg reward = ', reward_sum / test_epoch)
            print('best reward found = ', best_reward)
            x.append(eps)
            y0.append(reward_sum / test_epoch)
            y1.append(best_reward)

        import matplotlib.pyplot as plt
        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        ax[0].plot(x, y0)
        ax[0].set_title('Avg Reward')
        ax[0].set_xlabel('eps')
        ax[0].set_ylabel('reward')

        ax[1].plot(x, y1)
        ax[1].set_title('Best Reward')
        ax[1].set_xlabel('eps')
        ax[1].set_ylabel('reward')

        plt.show()
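The sampling phase of search_algo fans work out over Process workers that report through a Queue and then block on a shared Event until the parent has drained their results; the toy example below isolates that pattern, with an illustrative worker standing in for sample_design.
from multiprocessing import Event, Process, Queue

def toy_worker(task_id, results_queue, done_event):
    # Stand-in for sample_design: publish one result, then wait to be released.
    results_queue.put((task_id, task_id * task_id))
    done_event.wait()

if __name__ == '__main__':
    results_queue, done_event = Queue(), Event()
    processes = [Process(target=toy_worker, args=(i, results_queue, done_event))
                 for i in range(4)]
    for p in processes:
        p.start()
    # Drain exactly one result per worker before releasing them.
    results = [results_queue.get() for _ in processes]
    done_event.set()
    for p in processes:
        p.join()
    print(sorted(results))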
Example #8
def main(log_file=None, grammar_file=None):
    parser = argparse.ArgumentParser(
        description="Example code for parsing a MCTS log file.")

    if not log_file or not grammar_file:
        parser.add_argument("log_file", type=str, help="Log file (.csv)")
        parser.add_argument("grammar_file",
                            type=str,
                            help="Grammar file (.dot)")
        args = parser.parse_args()
    else:
        args = argparse.Namespace()
        args.grammar_file = grammar_file
        args.log_file = log_file

    graphs = rd.load_graphs(args.grammar_file)
    rules = [rd.create_rule_from_graph(g) for g in graphs]

    # Find all possible link labels, so they can be one-hot encoded
    all_labels = set()
    for rule in rules:
        for node in rule.lhs.nodes:
            all_labels.add(node.attrs.label)
    all_labels = sorted(list(all_labels))

    with open(args.log_file, newline='') as log_file:
        reader = csv.DictReader(log_file)

        all_link_features = []
        all_link_adj = []
        all_results = []
        for row in reader:
            full_rule_seq = ast.literal_eval(row['rule_seq'])
            result = float(row['result'])

            for prefix_len in range(len(full_rule_seq) + 1):
                rule_seq = full_rule_seq[:prefix_len]
                all_results.append(result)

                # Build a robot from the rule sequence
                robot_graph = make_initial_graph()
                for r in rule_seq:
                    matches = rd.find_matches(rules[r].lhs, robot_graph)
                    # Always use the first match
                    robot_graph = rd.apply_rule(rules[r], robot_graph,
                                                matches[0])
                robot = build_normalized_robot(robot_graph)

                # Find the world position and rotation of links
                pos_rot = []
                for i, link in enumerate(robot.links):
                    if link.parent >= 0:
                        parent_pos, parent_rot = pos_rot[link.parent]
                        parent_link_length = robot.links[link.parent].length
                    else:
                        parent_pos, parent_rot = np.zeros(3), np.quaternion(
                            1, 0, 0, 0)
                        parent_link_length = 0

                    offset = np.array(
                        [parent_link_length * link.joint_pos, 0, 0])
                    rel_pos = quaternion.rotate_vectors(parent_rot, offset)
                    rel_rot = np_quaternion(link.joint_rot).conjugate()
                    pos = parent_pos + rel_pos
                    rot = parent_rot * rel_rot
                    pos_rot.append((pos, rot))

                # Generate adjacency matrix
                adj_matrix = np.zeros((len(robot.links), len(robot.links)))
                for i, link in enumerate(robot.links):
                    if link.parent >= 0:
                        adj_matrix[link.parent, i] += 1

                # Generate features for links
                # Note: we can work with either the graph or the robot kinematic tree, but
                # the kinematic tree provides more information
                link_features = []
                for i, link in enumerate(robot.links):
                    world_pos, world_rot = pos_rot[i]
                    world_joint_axis = quaternion.rotate_vectors(
                        world_rot, link.joint_axis)
                    label_vec = np.zeros(len(all_labels))
                    label_vec[all_labels.index(link.label)] = 1

                    link_features.append(
                        np.array([
                            *featurize_link(link), *world_pos,
                            *quaternion_coords(world_rot), *world_joint_axis,
                            *label_vec
                        ]))
                link_features = np.array(link_features)

                all_link_features.append(link_features)
                all_link_adj.append(adj_matrix)

    return all_link_features, all_link_adj, all_results
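The feature loop above assumes two quaternion helpers, np_quaternion and quaternion_coords, that are not defined in this excerpt; the sketches below are plausible versions based on how they are used with the numpy-quaternion package, and the .w/.x/.y/.z attribute access is an assumption about the joint rotation type.
import numpy as np
import quaternion  # numpy-quaternion package

def np_quaternion(q):
    # Assumed conversion: build an np.quaternion from w/x/y/z components.
    return np.quaternion(q.w, q.x, q.y, q.z)

def quaternion_coords(q):
    # Return the (w, x, y, z) coordinates of an np.quaternion as a flat array.
    return quaternion.as_float_array(q)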
Example #9
 best_rule_seq = None
 best_designs = []
 for row in reader:
     N += 1
     design = row['rule_seq']
     reward = float(row['reward'])
     if 'opt_seed' in row:
         opt_seed = row['opt_seed']
     else:
         opt_seed = None
     if design not in memory:
         memory[design] = 0
     memory[design] += 1
     rule_seq = ast.literal_eval(row['rule_seq'])
     rule_seqs.append(rule_seq)
     state = make_initial_graph()
     for rule in rule_seq:
         state = env.transite(state, rule)
     if hash(state) not in design_cnt:
         design_cnt[hash(state)] = [0, reward, reward, 0]
     design_cnt[hash(state)][0] += 1
     design_cnt[hash(state)][1] = max(design_cnt[hash(state)][1], reward)
     design_cnt[hash(state)][3] += reward
     if len(best_reward) == 0:
         best_reward = [reward]
     else:
         if reward > best_reward[-1]:
             best_design = state
             best_designs.append(state)
             best_rule_seq = design
             if design_cnt[hash(state)][0] > 1: