Example #1
def search(args):
    # initialize the env
    max_nodes = args.depth * 2
    task_class = getattr(tasks, args.task)
    task = task_class()
    graphs = rd.load_graphs(args.grammar_file)
    rules = [rd.create_rule_from_graph(g) for g in graphs]
    env = RobotGrammarEnv(task, rules, seed = args.seed, mpc_num_processes = args.mpc_num_processes)

    # state preprocessor
    # Find all possible link labels, so they can be one-hot encoded
    all_labels = set()
    for rule in rules:
        for node in rule.lhs.nodes:
            all_labels.add(node.attrs.require_label)
    all_labels = sorted(list(all_labels))
    global preprocessor
    preprocessor = Preprocessor(max_nodes = max_nodes, all_labels = all_labels)

    # initialize Q function
    device = 'cpu'
    state = env.reset()
    sample_adj_matrix, sample_features, sample_masks = preprocessor.preprocess(state)
    num_features = sample_features.shape[1]
    Q = Net(max_nodes = max_nodes, num_channels = num_features, num_outputs = len(rules)).to(device)

    # initialize the optimizer
    global optimizer
    optimizer = torch.optim.Adam(Q.parameters(), lr = args.lr)

    # initialize DQN
    memory = ReplayMemory(capacity = 1000000)
    scores = deque(maxlen = 100)
    data = []

    for epoch in range(args.num_iterations):
        done = False
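        # epsilon for e-greedy exploration, annealed linearly from eps_start to eps_end over the training iterations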
        eps = args.eps_start + epoch / args.num_iterations * (args.eps_end - args.eps_start)
        # eps = 1.0
        while not done:
            state = env.reset()
            total_reward = 0.
            rule_seq = []
            state_seq = []
            for i in range(args.depth):
                action = select_action(env, Q, state, eps)
                if action is None:
                    break
                rule_seq.append(action)
                next_state, reward, done = env.step(action)
                state_seq.append((state, action, next_state, reward, done))
                total_reward += reward
                state = next_state
                if done:
                    break
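        # push the transitions of the last completed rollout into the replay memory;
        # data keeps (state, action, episode return) tuples for the sanity check after training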
        for i in range(len(state_seq)):
            memory.push(state_seq[i][0], state_seq[i][1], state_seq[i][2], state_seq[i][3], state_seq[i][4])
            data.append((state_seq[i][0], state_seq[i][1], total_reward))
        scores.append(total_reward)

        loss = 0.0
        for i in range(len(state_seq)):
            loss += optimize(Q, Q, memory, args.batch_size)
        print('epoch ', epoch, ': reward = ', total_reward, ', eps = ', eps, ', Q loss = ', loss)

    # sanity check: compare predicted Q-values against observed episode returns for a few high- and low-reward samples
    cnt = 0
    for i in range(len(data)):
        if data[i][2] > 0.5:
            y_predict, _, _ = predict(Q, data[i][0])
            print('target = ', data[i][2], ', predicted = ', y_predict[0][data[i][1]])
            cnt += 1
            if cnt == 5:
                break
    cnt = 0
    for i in range(len(data)):
        if data[i][2] < 0.5:
            y_predict, _, _ = predict(Q, data[i][0])
            print('target = ', data[i][2], ', predicted = ', y_predict[0][data[i][1]])
            cnt += 1
            if cnt == 5:
                break
def search_algo_2(args):
    # initialize random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # initialize/load
    # TODO: use 80 to fit the input of trained MPC GNN, use args.depth * 3 later for real mpc
    max_nodes = 80
    task_class = getattr(tasks, args.task)
    task = task_class()
    graphs = rd.load_graphs(args.grammar_file)
    rules = [rd.create_rule_from_graph(g) for g in graphs]

    # state preprocessor
    # Find all possible link labels, so they can be one-hot encoded
    all_labels = set()
    for rule in rules:
        for node in rule.lhs.nodes:
            all_labels.add(node.attrs.require_label)
    all_labels = sorted(list(all_labels))
    global preprocessor
    preprocessor = Preprocessor(max_nodes = max_nodes, all_labels = all_labels)

    # initialize the env
    env = RobotGrammarEnv(task, rules, enable_reward_oracle = True, preprocessor = preprocessor)

    # initialize Value function
    device = 'cpu'
    state = env.reset()
    sample_adj_matrix, sample_features, sample_masks = preprocessor.preprocess(state)
    num_features = sample_features.shape[1]
    V = Net(max_nodes = max_nodes, num_channels = num_features, num_outputs = 1).to(device)

    # load pretrained V function
    if args.load_V_path is not None:
        V.load_state_dict(torch.load(args.load_V_path))
        print_info('Loaded pretrained V function from {}'.format(args.load_V_path))

    if not args.test:
        # initialize save folders and files
        fp_log = open(os.path.join(args.save_dir, 'log.txt'), 'w')
        fp_log.close()
        design_csv_path = os.path.join(args.save_dir, 'designs.csv')
        fp_csv = open(design_csv_path, 'w')
        fieldnames = ['rule_seq', 'reward']
        writer = csv.DictWriter(fp_csv, fieldnames=fieldnames)
        writer.writeheader()
        fp_csv.close()

        # initialize the optimizer
        global optimizer
        optimizer = torch.optim.Adam(V.parameters(), lr = args.lr)

        # initialize best design
        best_design, best_reward = None, -np.inf
        
        # initialize the seen states pool
        states_pool = []
        
        # initialize visited states
        state_set = set()

        # TODO: load previously explored designs
        
        # explored designs
        designs = []
        design_rewards = []
        
        # reward history
        epoch_rew_his = []

        for epoch in range(args.num_iterations):
            t_start = time.time()

            V.eval()

            t0 = time.time()

            # use e-greedy to sample a design within maximum #steps.
            if args.eps_schedule == 'linear-decay':
                # linear schedule
                eps = args.eps_start + epoch / args.num_iterations * (args.eps_end - args.eps_start)
            elif args.eps_schedule == 'exp-decay':
                # exp schedule
                eps = args.eps_end + (args.eps_start - args.eps_end) * np.exp(-1.0 * epoch / args.num_iterations / args.eps_decay)

            done = False
            while not done:
                state = env.reset()
                rule_seq = []
                state_seq = [state]
                total_reward = 0.
                for _ in range(args.depth):
                    action = select_action(env, V, state, eps)
                    if action is None:
                        break
                    rule_seq.append(action)
                    next_state, reward, done = env.step(action)
                    total_reward += reward
                    state_seq.append(next_state)
                    state = next_state
                    if done:
                        break
            
            # save the design and the reward in the list
            designs.append(rule_seq)
            design_rewards.append(total_reward)

            # update best design
            if total_reward > best_reward:
                best_design, best_reward = rule_seq, total_reward
            
            # update state pool
            for ancestor in state_seq:
                state_hash_key = hash(ancestor)
                if not (state_hash_key in state_set):
                    state_set.add(state_hash_key)
                    states_pool.append(ancestor)

            t1 = time.time()

            # optimize
            V.train()
            total_loss = 0.0
            for _ in range(args.depth):
                minibatch = random.sample(states_pool, min(len(states_pool), args.batch_size))

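                # regression target for each sampled graph is the value estimate returned by compute_Vhat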
                train_adj_matrix, train_features, train_masks, train_reward = [], [], [], []
                for robot_graph in minibatch:
                    V_hat = compute_Vhat(robot_graph, env, V)
                    adj_matrix, features, masks = preprocessor.preprocess(robot_graph)
                    train_adj_matrix.append(adj_matrix)
                    train_features.append(features)
                    train_masks.append(masks)
                    train_reward.append(V_hat)
                
                train_adj_matrix_torch = torch.tensor(train_adj_matrix)
                train_features_torch = torch.tensor(train_features)
                train_masks_torch = torch.tensor(train_masks)
                train_reward_torch = torch.tensor(train_reward)
                
                optimizer.zero_grad()
                output, loss_link, loss_entropy = V(train_features_torch, train_adj_matrix_torch, train_masks_torch)
                loss = F.mse_loss(output[:, 0], train_reward_torch)
                loss.backward()
                total_loss += loss.item()
                optimizer.step()

            t2 = time.time()

            # logging
            if (epoch + 1) % args.log_interval == 0 or epoch + 1 == args.num_iterations:
                iter_save_dir = os.path.join(args.save_dir, '{}'.format(epoch + 1))
                os.makedirs(os.path.join(iter_save_dir), exist_ok = True)
                # save model
                save_path = os.path.join(iter_save_dir, 'V_model.pt')
                torch.save(V.state_dict(), save_path)
                # save explored designs and their rewards
                fp_csv = open(design_csv_path, 'a')
                fieldnames = ['rule_seq', 'reward']
                writer = csv.DictWriter(fp_csv, fieldnames=fieldnames)
                for i in range(epoch - args.log_interval + 1, epoch + 1):
                    writer.writerow({'rule_seq': str(designs[i]), 'reward': design_rewards[i]})
                fp_csv.close()

            epoch_rew_his.append(total_reward)

            t_end = time.time()
            avg_loss = total_loss / args.depth
            len_his = min(len(epoch_rew_his), 30)
            avg_reward = np.sum(epoch_rew_his[-len_his:]) / len_his
            print('Epoch {}: Time = {:.2f}, T_sample = {:.2f}, T_opt = {:.2f}, eps = {:.3f}, training loss = {:.4f}, reward = {:.4f}, last 30 epoch reward = {:.4f}, best reward = {:.4f}'.format(epoch, t_end - t_start, t1 - t0, t2 - t1, eps, avg_loss, total_reward, avg_reward, best_reward))
            fp_log = open(os.path.join(args.save_dir, 'log.txt'), 'a')
            fp_log.write('eps = {:.4f}, loss = {:.4f}, reward = {:.4f}, avg_reward = {:.4f}\n'.format(eps, avg_loss, total_reward, avg_reward))
            fp_log.close()

        save_path = os.path.join(args.save_dir, 'model_state_dict_final.pt')
        torch.save(V.state_dict(), save_path)
    else:
        import IPython
        IPython.embed()

        # test
        V.eval()
        print('Start testing')
        test_epoch = 30
        y0 = []
        y1 = []
        x = []
        for ii in range(10):
            eps = 1.0 - 0.1 * ii

            print('------------------------------------------')
            print('eps = ', eps)

            reward_sum = 0.
            best_reward = -np.inf
            for epoch in range(test_epoch):
                t0 = time.time()

                # use e-greedy to sample a design within maximum #steps.
                done = False
                while not done:
                    state = env.reset() 
                    rule_seq = []
                    state_seq = [state]
                    total_reward = 0.
                    for _ in range(args.depth):
                        action = select_action(env, V, state, eps)
                        if action is None:
                            break
                        rule_seq.append(action)
                        next_state, reward, done = env.step(action)
                        total_reward += reward
                        state_seq.append(next_state)
                        state = next_state
                        if done:
                            break

                reward_sum += total_reward
                best_reward = max(best_reward, total_reward)
                print(f'design {epoch}: reward = {total_reward}, time = {time.time() - t0}')

            print('test avg reward = ', reward_sum / test_epoch)
            print('best reward found = ', best_reward)
            x.append(eps)
            y0.append(reward_sum / test_epoch)
            y1.append(best_reward)

        import matplotlib.pyplot as plt
        fig, ax = plt.subplots(1, 2, figsize = (10, 5))
        ax[0].plot(x, y0)
        ax[1].plot(x, y1)
        plt.show()
def search_algo(args):
    # initialize random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.set_num_threads(1)

    # initialize/load
    task_class = getattr(tasks, args.task)
    if args.no_noise:
        task = task_class(force_std=0.0, torque_std=0.0)
    else:
        task = task_class()
    graphs = rd.load_graphs(args.grammar_file)
    rules = [rd.create_rule_from_graph(g) for g in graphs]

    # initialize preprocessor
    # Find all possible link labels, so they can be one-hot encoded
    all_labels = set()
    for rule in rules:
        for node in rule.lhs.nodes:
            all_labels.add(node.attrs.require_label)
    all_labels = sorted(list(all_labels))

    # TODO: use 80 to fit the input of trained MPC GNN, use args.depth * 3 later for real mpc
    max_nodes = args.max_nodes

    global preprocessor
    # preprocessor = Preprocessor(max_nodes = max_nodes, all_labels = all_labels)
    preprocessor = Preprocessor(all_labels=all_labels)

    # initialize the env
    env = RobotGrammarEnv(task,
                          rules,
                          seed=args.seed,
                          mpc_num_processes=args.mpc_num_processes)

    # initialize Value function
    device = 'cpu'
    state = env.reset()
    sample_adj_matrix, sample_features, sample_masks = preprocessor.preprocess(
        state)
    num_features = sample_features.shape[1]
    V = Net(max_nodes=max_nodes, num_channels=num_features,
            num_outputs=1).to(device)

    # load pretrained V function
    if args.load_V_path is not None:
        V.load_state_dict(torch.load(args.load_V_path))
        print_info('Loaded pretrained V function from {}'.format(
            args.load_V_path))

    # initialize target V_hat look up table
    V_hat = dict()

    # load pretrained V_hat
    if args.load_Vhat_path is not None:
        V_hat_fp = open(args.load_Vhat_path, 'rb')
        V_hat = pickle.load(V_hat_fp)
        V_hat_fp.close()
        print_info('Loaded pretrained Vhat from {}'.format(
            args.load_Vhat_path))

    if not args.test:
        # initialize save folders and files
        fp_log = open(os.path.join(args.save_dir, 'log.txt'), 'w')
        fp_log.close()
        fp_eval = open(os.path.join(args.save_dir, 'eval.txt'), 'w')
        fp_eval.close()
        design_csv_path = os.path.join(args.save_dir, 'designs.csv')
        fp_csv = open(design_csv_path, 'w')
        fieldnames = ['rule_seq', 'reward', 'opt_seed']
        writer = csv.DictWriter(fp_csv, fieldnames=fieldnames)
        writer.writeheader()
        fp_csv.close()

        # initialize the optimizer
        global optimizer
        optimizer = torch.optim.Adam(V.parameters(), lr=args.lr)

        # initialize the seen states pool
        states_pool = StatesPool(capacity=args.states_pool_capacity)
        states_set = set()

        # explored designs
        designs = []
        design_rewards = []
        design_opt_seeds = []

        # initialize best design rule sequence
        best_design, best_reward = None, -np.inf

        # reward history
        epoch_rew_his = []
        last_checkpoint = -1

        # recording time
        t_sample_sum = 0.
        num_samples_interval = 0
        max_seen_nodes = 0

        # record the count for invalid samples
        no_action_samples, step_exceeded_samples, self_collision_samples = 0, 0, 0

        # initialize stats variables
        num_invalid_samples, num_valid_samples = 0, 0
        repeated_cnt = 0

        # record prediction error
        prediction_error = []

        for epoch in range(args.num_iterations):
            t_start = time.time()

            V.eval()

            # update eps and eps_sample
            if args.eps_schedule == 'linear-decay':
                eps = args.eps_start + epoch / args.num_iterations * (
                    args.eps_end - args.eps_start)
            elif args.eps_schedule == 'exp-decay':
                eps = args.eps_end + (args.eps_start - args.eps_end) * np.exp(
                    -1.0 * epoch / args.num_iterations / args.eps_decay)

            if args.eps_sample_schedule == 'linear-decay':
                eps_sample = args.eps_sample_start + epoch / args.num_iterations * (
                    args.eps_sample_end - args.eps_sample_start)
            elif args.eps_sample_schedule == 'exp-decay':
                eps_sample = args.eps_sample_end + (
                    args.eps_sample_start - args.eps_sample_end) * np.exp(
                        -1.0 * epoch / args.num_iterations /
                        args.eps_sample_decay)

            t_sample, t_update, t_mpc, t_opt = 0, 0, 0, 0

            selected_design, selected_reward = None, -np.inf
            selected_state_seq, selected_rule_seq = None, None

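            # with probability eps_sample draw a single candidate design; otherwise draw
            # args.num_samples candidates and keep the one with the highest predicted value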
            p = random.random()
            if p < eps_sample:
                num_samples = 1
            else:
                num_samples = args.num_samples

            # use e-greedy to sample a design within maximum #steps.
            for _ in range(num_samples):
                valid = False
                while not valid:
                    t0 = time.time()

                    state = env.reset()
                    rule_seq = []
                    state_seq = [state]
                    no_action_flag = False
                    for _ in range(args.depth):
                        action, step_type = select_action(env, V, state, eps)
                        if action is None:
                            no_action_flag = True
                            break
                        rule_seq.append(action)
                        next_state = env.transite(state, action)
                        state_seq.append(next_state)
                        state = next_state
                        if not has_nonterminals(state):
                            break

                    valid = env.is_valid(state)

                    t_sample += time.time() - t0

                    t0 = time.time()

                    if not valid:
                        # update the invalid sample's count
                        if no_action_flag:
                            no_action_samples += 1
                        elif has_nonterminals(state):
                            step_exceeded_samples += 1
                        else:
                            self_collision_samples += 1
                        num_invalid_samples += 1
                    else:
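                        # a design already recorded in V_hat was evaluated before: reuse its
                        # stored value to update the ancestor states, then resample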
                        if hash(state) in V_hat:
                            update_Vhat(args, V_hat, state_seq,
                                        V_hat[hash(state)])
                            update_states_pool(states_pool, state_seq,
                                               states_set)
                            valid = False
                            num_invalid_samples += 1
                        else:
                            num_valid_samples += 1

                    num_samples_interval += 1

                    t_update += time.time() - t0

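                # score the sampled design with the value network and keep the highest-scoring candidate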
                predicted_value = predict(V, state)
                if predicted_value > selected_reward:
                    selected_design, selected_reward = state, predicted_value
                    selected_rule_seq, selected_state_seq = rule_seq, state_seq

            t0 = time.time()

            repeated = False
            if hash(selected_design) in V_hat:
                repeated = True
                repeated_cnt += 1

            reward, best_seed = -np.inf, None

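            # evaluate the selected design args.num_eval times; keep the best reward and the corresponding optimization seed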
            for _ in range(args.num_eval):
                _, rew = env.get_reward(selected_design)
                if rew > reward:
                    reward, best_seed = rew, env.last_opt_seed

            t_mpc += time.time() - t0

            # save the design and the reward in the list
            designs.append(selected_rule_seq)
            design_rewards.append(reward)
            design_opt_seeds.append(best_seed)

            # update best design
            if reward > best_reward:
                best_design, best_reward = selected_rule_seq, reward
                print_info(
                    'new best: reward = {:.4f}, predicted reward = {:.4f}, num_samples = {}'
                    .format(reward, selected_reward, num_samples))

            t0 = time.time()

            # update V_hat for the valid design
            update_Vhat(args, V_hat, selected_state_seq, reward)

            # update states pool for the valid design
            update_states_pool(states_pool, selected_state_seq, states_set)

            t_update += time.time() - t0

            t0 = time.time()

            # optimize
            V.train()
            total_loss = 0.0
            for _ in range(args.opt_iter):
                minibatch = states_pool.sample(
                    min(len(states_pool), args.batch_size))

                train_adj_matrix, train_features, train_masks, train_reward = [], [], [], []
                max_nodes = 0
                for robot_graph in minibatch:
                    hash_key = hash(robot_graph)
                    target_reward = V_hat[hash_key]
                    adj_matrix, features, _ = preprocessor.preprocess(
                        robot_graph)
                    max_nodes = max(max_nodes, len(features))
                    train_adj_matrix.append(adj_matrix)
                    train_features.append(features)
                    train_reward.append(target_reward)

                max_seen_nodes = max(max_seen_nodes, max_nodes)

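                # pad every graph in the minibatch to the size of the largest graph so they can be batched into dense tensors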
                for i in range(len(minibatch)):
                    train_adj_matrix[i], train_features[i], masks = \
                        preprocessor.pad_graph(train_adj_matrix[i], train_features[i], max_nodes)
                    train_masks.append(masks)

                train_adj_matrix_torch = torch.tensor(train_adj_matrix)
                train_features_torch = torch.tensor(train_features)
                train_masks_torch = torch.tensor(train_masks)
                train_reward_torch = torch.tensor(train_reward)

                optimizer.zero_grad()
                output, loss_link, loss_entropy = V(train_features_torch,
                                                    train_adj_matrix_torch,
                                                    train_masks_torch)
                loss = F.mse_loss(output[:, 0], train_reward_torch)
                loss.backward()
                total_loss += loss.item()
                optimizer.step()

            t_opt += time.time() - t0

            t_end = time.time()

            t_sample_sum += t_sample

            # logging
            if (epoch + 1
                ) % args.log_interval == 0 or epoch + 1 == args.num_iterations:
                iter_save_dir = os.path.join(args.save_dir,
                                             '{}'.format(epoch + 1))
                os.makedirs(os.path.join(iter_save_dir), exist_ok=True)
                # save model
                save_path = os.path.join(iter_save_dir, 'V_model.pt')
                torch.save(V.state_dict(), save_path)
                # save V_hat
                save_path = os.path.join(iter_save_dir, 'V_hat')
                fp = open(save_path, 'wb')
                pickle.dump(V_hat, fp)
                fp.close()

            # save explored design and its reward
            fp_csv = open(design_csv_path, 'a')
            fieldnames = ['rule_seq', 'reward', 'opt_seed']
            writer = csv.DictWriter(fp_csv, fieldnames=fieldnames)
            for i in range(last_checkpoint + 1, len(designs)):
                writer.writerow({
                    'rule_seq': str(designs[i]),
                    'reward': design_rewards[i],
                    'opt_seed': design_opt_seeds[i]
                })
            last_checkpoint = len(designs) - 1
            fp_csv.close()

            epoch_rew_his.append(reward)

            avg_loss = total_loss / args.opt_iter
            len_his = min(len(epoch_rew_his), 30)
            avg_reward = np.sum(epoch_rew_his[-len_his:]) / len_his
            prediction_error.append(np.abs(selected_reward - reward))
            avg_prediction_error = np.sum(
                prediction_error[-len_his:]) / len_his

            if repeated:
                print_white('Epoch {:4}: T_sample = {:5.2f}, T_mpc = {:5.2f}, T_opt = {:5.2f}, eps = {:5.3f}, #samples = {:2}, training loss = {:7.4f}, avg_pred_error = {:6.4f}, predicted_reward = {:6.4f}, reward = {:6.4f}, last 30 epoch reward = {:6.4f}, best reward = {:6.4f}'.format(\
                    epoch, t_sample, t_mpc, t_opt, eps, num_samples, \
                    avg_loss, avg_prediction_error, selected_reward, reward, avg_reward, best_reward))
            else:
                print_warning('Epoch {:4}: T_sample = {:5.2f}, T_mpc = {:5.2f}, T_opt = {:5.2f}, eps = {:5.3f}, #samples = {:2}, training loss = {:7.4f}, avg_pred_error = {:6.4f}, predicted_reward = {:6.4f}, reward = {:6.4f}, last 30 epoch reward = {:6.4f}, best reward = {:6.4f}'.format(\
                    epoch, t_sample, t_mpc, t_opt, eps, num_samples, \
                    avg_loss, avg_prediction_error, selected_reward, reward, avg_reward, best_reward))

            fp_log = open(os.path.join(args.save_dir, 'log.txt'), 'a')
            fp_log.write('eps = {:.4f}, eps_sample = {:.4f}, num_samples = {}, T_sample = {:4f}, T_update = {:4f}, T_mpc = {:.4f}, T_opt = {:.4f}, loss = {:.4f}, predicted_reward = {:.4f}, reward = {:.4f}, avg_reward = {:.4f}\n'.format(\
                eps, eps_sample, num_samples, t_sample, t_update, t_mpc, t_opt, avg_loss, selected_reward, reward, avg_reward))
            fp_log.close()

            if (epoch + 1) % args.log_interval == 0:
                print_info(
                    'Avg sampling time for last {} epoch: {:.4f} second'.
                    format(args.log_interval,
                           t_sample_sum / num_samples_interval))
                t_sample_sum = 0.
                num_samples_interval = 0
                print_info('max seen nodes = {}'.format(max_seen_nodes))
                print_info('size of states_pool = {}'.format(len(states_pool)))
                print_info(
                    '#valid samples = {}, #invalid samples = {}, #valid / #invalid = {}'
                    .format(
                        num_valid_samples, num_invalid_samples,
                        num_valid_samples / num_invalid_samples
                        if num_invalid_samples > 0 else 10000.0))
                print_info('Invalid samples: #no_action_samples = {}, #step_exceeded_samples = {}, #self_collision_samples = {}, #repeated_samples = {}'.format(\
                    no_action_samples, step_exceeded_samples, self_collision_samples, num_invalid_samples - no_action_samples - step_exceeded_samples - self_collision_samples))
                print_info('repeated rate = {}'.format(repeated_cnt /
                                                       (epoch + 1)))

        save_path = os.path.join(args.save_dir, 'model_state_dict_final.pt')
        torch.save(V.state_dict(), save_path)
    else:
        import IPython
        IPython.embed()

        # test
        V.eval()
        print('Start testing')
        test_epoch = 30
        y0 = []
        y1 = []
        x = []
        for ii in range(0, 11):
            eps = 1.0 - 0.1 * ii

            print('------------------------------------------')
            print('eps = ', eps)

            reward_sum = 0.
            best_reward = -np.inf
            for epoch in range(test_epoch):
                t0 = time.time()

                # use e-greedy to sample a design within maximum #steps.
                valid = False
                while not valid:
                    state = env.reset()
                    rule_seq = []
                    state_seq = [state]
                    for _ in range(args.depth):
                        action, step_type = select_action(env, V, state, eps)
                        if action is None:
                            break
                        rule_seq.append(action)
                        next_state = env.transite(state, action)
                        state_seq.append(next_state)
                        state = next_state
                        if not has_nonterminals(state):
                            valid = True
                            break

                _, reward = env.get_reward(state)
                reward_sum += reward
                best_reward = max(best_reward, reward)
                print(
                    f'design {epoch}: reward = {reward}, time = {time.time() - t0}'
                )

            print('test avg reward = ', reward_sum / test_epoch)
            print('best reward found = ', best_reward)
            x.append(eps)
            y0.append(reward_sum / test_epoch)
            y1.append(best_reward)

        import matplotlib.pyplot as plt
        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        ax[0].plot(x, y0)
        ax[0].set_title('Avg Reward')
        ax[0].set_xlabel('eps')
        ax[0].set_ylabel('reward')

        ax[1].plot(x, y1)
        ax[1].set_title('Best Reward')
        ax[1].set_xlabel('eps')
        ax[1].set_ylabel('reward')

        plt.show()
Example #4
                        help="Grammar file (.dot)")
    parser.add_argument('--index',
                        type=int,
                        default=None,
                        help='index of the designs to be shown at the end')

    args = parser.parse_args()

    fp = open(args.log_path, newline='')
    reader = csv.DictReader(fp)

    graphs = rd.load_graphs(args.grammar_file)
    rules = [rd.create_rule_from_graph(g) for g in graphs]

    # initialize the env
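    # task is None here: presumably only the grammar rules are needed to replay the logged rule sequences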
    env = RobotGrammarEnv(None, rules)

    design_cnt = dict()
    memory = dict()
    N = 0
    best_reward = []
    rewards = []
    rule_seqs = []
    opt_seeds = []
    best_design = None
    best_rule_seq = None
    best_designs = []
    for row in reader:
        N += 1
        design = row['rule_seq']
        reward = float(row['reward'])