Example #1
    def __init__(self, config, train_set, test_set, use_cuda=False):

        self.NUM_EPOCHS = config.NUM_EPOCHS
        self.ALPHA = config.ALPHA
        self.BATCH_SIZE = config.BATCH_SIZE # number of models to generate for each action
        self.HIDDEN_SIZE = config.HIDDEN_SIZE
        self.BETA = config.BETA
        self.GAMMA = config.GAMMA
        self.DEVICE = torch.device('cuda' if torch.cuda.is_available() and use_cuda else 'cpu')
        self.INPUT_SIZE = config.INPUT_SIZE
        self.NUM_STEPS = config.NUM_STEPS
        self.ACTION_SPACE = config.ACTION_SPACE

        self.train = train_set
        self.test = test_set

        # instantiate the tensorboard writer
        self.writer = SummaryWriter(comment=f'_PG_CP_Gamma={self.GAMMA},'
                                            f'LR={self.ALPHA},'
                                            f'BS={self.BATCH_SIZE},'
                                            f'NH={self.HIDDEN_SIZE},'
                                            f'BETA={self.BETA}')

        # the agent driven by a neural network architecture
        if use_cuda:
            self.agent = Agent(self.INPUT_SIZE, self.HIDDEN_SIZE, self.NUM_STEPS, device=self.DEVICE).cuda()
        else:
            self.agent = Agent(self.INPUT_SIZE, self.HIDDEN_SIZE, self.NUM_STEPS, device=self.DEVICE)
        self.adam = optim.Adam(params=self.agent.parameters(), lr=self.ALPHA)
        self.total_rewards = deque([], maxlen=100)
Example #2
    def __init__(self,
                 input_size,
                 hidden_size,
                 num_steps,
                 action_space,
                 learning_rate=0.001,
                 beta=0.1):
        self.agent = Agent(input_size, hidden_size, num_steps)
        self.optimizer = torch.optim.Adam(self.agent.parameters(),
                                          learning_rate)
        self.beta = beta
        self.action_space = action_space
Example #3
class PolicyGradient:
    def __init__(self, config, train_set, test_set, use_cuda=False):

        self.NUM_EPOCHS = config.NUM_EPOCHS
        self.ALPHA = config.ALPHA
        self.BATCH_SIZE = config.BATCH_SIZE # number of models to generate for each action
        self.HIDDEN_SIZE = config.HIDDEN_SIZE
        self.BETA = config.BETA
        self.GAMMA = config.GAMMA
        self.DEVICE = torch.device('cuda' if torch.cuda.is_available() and use_cuda else 'cpu')
        self.INPUT_SIZE = config.INPUT_SIZE
        self.NUM_STEPS = config.NUM_STEPS
        self.ACTION_SPACE = config.ACTION_SPACE

        self.train = train_set
        self.test = test_set

        # instantiate the tensorboard writer
        self.writer = SummaryWriter(comment=f'_PG_CP_Gamma={self.GAMMA},'
                                            f'LR={self.ALPHA},'
                                            f'BS={self.BATCH_SIZE},'
                                            f'NH={self.HIDDEN_SIZE},'
                                            f'BETA={self.BETA}')

        # the agent driven by a neural network architecture
        if use_cuda:
            self.agent = Agent(self.INPUT_SIZE, self.HIDDEN_SIZE, self.NUM_STEPS, device=self.DEVICE).cuda()
        else:
            self.agent = Agent(self.INPUT_SIZE, self.HIDDEN_SIZE, self.NUM_STEPS, device=self.DEVICE)
        self.adam = optim.Adam(params=self.agent.parameters(), lr=self.ALPHA)
        self.total_rewards = deque([], maxlen=100)


    def solve_environment(self):
        """
            The main interface for the Policy Gradient solver
        """
        # init the episode and the epoch
        epoch = 0

        while epoch < self.NUM_EPOCHS:
            # init the epoch arrays
            # used for entropy calculation
            epoch_logits = torch.empty(size=(0, self.ACTION_SPACE), device=self.DEVICE)
            epoch_weighted_log_probs = torch.empty(size=(0,), dtype=torch.float, device=self.DEVICE)

            # Sample BATCH_SIZE models and do average
            for i in range(self.BATCH_SIZE):
                # play an episode of the environment
                (episode_weighted_log_prob_trajectory,
                 episode_logits,
                 sum_of_episode_rewards) = self.play_episode()

                # after each episode append the sum of total rewards to the deque
                self.total_rewards.append(sum_of_episode_rewards)

                # append the weighted log-probabilities of actions
                epoch_weighted_log_probs = torch.cat((epoch_weighted_log_probs, episode_weighted_log_prob_trajectory),
                                                     dim=0)
                # append the logits - needed for the entropy bonus calculation
                epoch_logits = torch.cat((epoch_logits, episode_logits), dim=0)

            # calculate the loss
            loss, entropy = self.calculate_loss(epoch_logits=epoch_logits,
                                                weighted_log_probs=epoch_weighted_log_probs)

            # zero the gradient
            self.adam.zero_grad()

            # backprop
            loss.backward()

            # update the parameters
            self.adam.step()

            # feedback
            print("\r", f"Epoch: {epoch}, Avg Return per Epoch: {np.mean(self.total_rewards):.3f}",
                  end="",
                  flush=True)

            self.writer.add_scalar(tag='Average Return over 100 episodes',
                                   scalar_value=np.mean(self.total_rewards),
                                   global_step=epoch)

            self.writer.add_scalar(tag='Entropy',
                                   scalar_value=entropy,
                                   global_step=epoch)
            # check if solved
            # if np.mean(self.total_rewards) > 200:
            #     print('\nSolved!')
            #     break
            epoch += 1
        # close the writer
        self.writer.close()

    def play_episode(self):
        """
            Plays an episode of the environment.
            episode: the episode counter
            Returns:
                sum_weighted_log_probs: the sum of the log-prob of an action multiplied by the reward-to-go from that state
                episode_logits: the logits of every step of the episode - needed to compute entropy for entropy bonus
                finished_rendering_this_epoch: pass-through rendering flag
                sum_of_rewards: sum of the rewards for the episode - needed for the average over 200 episode statistic
        """
        # Init state
        init_state = [[3, 8, 16]]

        # get the action logits from the agent - (preferences)
        episode_logits = self.agent(torch.tensor(init_state).float().to(self.DEVICE))

        # sample an action according to the action distribution
        action_index = Categorical(logits=episode_logits).sample().unsqueeze(1)

        # one-hot mask over the sampled actions; built from the squeezed indices so it
        # has shape (num_steps, ACTION_SPACE) and lines up with the log-softmax below
        mask = one_hot(action_index.squeeze(1), num_classes=self.ACTION_SPACE)

        episode_log_probs = torch.sum(mask.float() * log_softmax(episode_logits, dim=1), dim=1)

        # append the action to the episode action list to obtain the trajectory
        # we need to store the actions and logits so we could calculate the gradient of the performance
        #episode_actions = torch.cat((episode_actions, action_index), dim=0)

        # map each sampled index to its concrete action value from the action space
        action_space = torch.tensor([[3, 5, 7], [8, 16, 32], [3, 5, 7], [8, 16, 32]], device=self.DEVICE)
        action = torch.gather(action_space, 1, action_index).squeeze(1)
        # generate a submodel given predicted actions
        net = NASModel(action)
        #net = Net()

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

        for epoch in range(2):  # loop over the dataset multiple times

            running_loss = 0.0
            for i, data in enumerate(self.train, 0):
                # get the inputs; data is a list of [inputs, labels]
                inputs, labels = data

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # print statistics
                running_loss += loss.item()
                if i % 2000 == 1999:  # print every 2000 mini-batches
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i + 1, running_loss / 2000))
                    running_loss = 0.0

        print('Finished Training')

        # load best performance epoch in this training session
        # model.load_weights('weights/temp_network.h5')

        # evaluate the model
        correct = 0
        total = 0
        with torch.no_grad():
            for data in self.test:
                images, labels = data
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        acc = 100 * correct / total
        print('Accuracy of the network on the 10000 test images: {}'.format(acc))

        # compute the reward
        reward = acc

        episode_weighted_log_probs = episode_log_probs * reward
        sum_weighted_log_probs = torch.sum(episode_weighted_log_probs).unsqueeze(dim=0)

        return sum_weighted_log_probs, episode_logits, reward

    def calculate_loss(self, epoch_logits: torch.Tensor, weighted_log_probs: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        """
            Calculates the policy "loss" and the entropy bonus
            Args:
                epoch_logits: logits of the policy network we have collected over the epoch
                weighted_log_probs: the log-probability of each action taken, weighted by the episode reward
            Returns:
                policy loss + the entropy bonus
                entropy: needed for logging
        """
        policy_loss = -1 * torch.mean(weighted_log_probs)

        # add the entropy bonus
        p = softmax(epoch_logits, dim=1)
        log_p = log_softmax(epoch_logits, dim=1)
        entropy = -1 * torch.mean(torch.sum(p * log_p, dim=1), dim=0)
        entropy_bonus = -1 * self.BETA * entropy

        return policy_loss + entropy_bonus, entropy
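
For reference, here is a minimal self-contained sketch of what play_episode and calculate_loss above compute (the logits and the reward value are made up for illustration; only torch is assumed). It builds the one-hot-mask log-probabilities, weights them by the reward, and adds the entropy bonus; the allclose check confirms that the mask trick matches Categorical.log_prob.

import torch
from torch.distributions import Categorical
from torch.nn.functional import one_hot, log_softmax, softmax

ACTION_SPACE = 3
BETA = 0.1

logits = torch.randn(4, ACTION_SPACE)            # 4 decision steps, 3 choices each
actions = Categorical(logits=logits).sample()    # shape: (4,)

# log-probability of each sampled action via the one-hot mask, as in play_episode
mask = one_hot(actions, num_classes=ACTION_SPACE)
log_probs = torch.sum(mask.float() * log_softmax(logits, dim=1), dim=1)
assert torch.allclose(log_probs, Categorical(logits=logits).log_prob(actions))

reward = 92.5                                    # e.g. a child network's test accuracy
weighted_log_probs = log_probs * reward

# policy "loss" plus entropy bonus, as in calculate_loss
policy_loss = -torch.mean(weighted_log_probs)
p, log_p = softmax(logits, dim=1), log_softmax(logits, dim=1)
entropy = -torch.mean(torch.sum(p * log_p, dim=1))
loss = policy_loss - BETA * entropy
print(loss.item(), entropy.item())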
Example #4
File: main.py  Project: ND-SCL/NAQS
def quantization_search(device, dir='experiment'):
    dir = os.path.join(dir,
                       f"rLut={args.rLUT}, rThroughput={args.rThroughput}")
    if os.path.exists(dir) is False:
        os.makedirs(dir)
    filepath = os.path.join(dir, f"quantization ({args.episodes} episodes)")
    logger = get_logger(filepath)
    csvfile = open(filepath + '.csv', mode='w+', newline='')
    writer = csv.writer(csvfile)
    logger.info(f"INFORMATION")
    logger.info(f"mode: \t\t\t\t\t {'quantization'}")
    logger.info(f"dataset: \t\t\t\t {args.dataset}")
    logger.info(f"number of child network layers: \t {args.layers}")
    logger.info(f"include stride: \t\t\t {not args.no_stride}")
    logger.info(f"include pooling: \t\t\t {not args.no_pooling}")
    logger.info(f"skip connection: \t\t\t {args.skip}")
    logger.info(f"required # LUTs: \t\t\t {args.rLUT}")
    logger.info(f"required throughput: \t\t\t {args.rThroughput}")
    logger.info(f"Assumed frequency: \t\t\t {CLOCK_FREQUENCY}")
    logger.info(f"training epochs: \t\t\t {args.epochs}")
    logger.info(f"data augmentation: \t\t\t {args.augment}")
    logger.info(f"batch size: \t\t\t\t {args.batch_size}")
    logger.info(f"controller learning rate: \t\t {args.learning_rate}")
    logger.info(f"architecture episodes: \t\t\t {args.episodes}")
    logger.info(f"using multi gpus: \t\t\t {args.multi_gpu}")
    logger.info(f"architecture space: ")
    # for name, value in ARCH_SPACE.items():
    #     logger.info(name + f": \t\t\t\t {value}")
    logger.info(f"quantization space: ")
    for name, value in QUAN_SPACE.items():
        logger.info(name + f": \t\t\t {value}")
    agent = Agent(QUAN_SPACE,
                  args.layers,
                  lr=args.learning_rate,
                  device=torch.device('cpu'),
                  skip=False)
    train_data, val_data = data.get_data(args.dataset,
                                         device,
                                         shuffle=True,
                                         batch_size=args.batch_size,
                                         augment=args.augment)
    input_shape, num_classes = data.get_info(args.dataset)
    writer.writerow(["ID"] +
                    ["Layer {}".format(i)
                     for i in range(args.layers)] + ["Accuracy"] + [
                         "Partition (Tn, Tm)", "Partition (#LUTs)",
                         "Partition (#cycles)", "Total LUT", "Total Throughput"
                     ] + ["Time"])
    child_id, total_time = 0, 0
    logger.info('=' * 50 + "Start exploring quantization space" + '=' * 50)
    best_samples = BestSamples(5)
    arch_paras = [{
        'filter_height': 3,
        'filter_width': 3,
        'stride_height': 1,
        'stride_width': 1,
        'num_filters': 64,
        'pool_size': 1
    }, {
        'filter_height': 7,
        'filter_width': 5,
        'stride_height': 1,
        'stride_width': 1,
        'num_filters': 48,
        'pool_size': 1
    }, {
        'filter_height': 5,
        'filter_width': 5,
        'stride_height': 2,
        'stride_width': 1,
        'num_filters': 48,
        'pool_size': 1
    }, {
        'filter_height': 3,
        'filter_width': 5,
        'stride_height': 1,
        'stride_width': 1,
        'num_filters': 64,
        'pool_size': 1
    }, {
        'filter_height': 5,
        'filter_width': 7,
        'stride_height': 1,
        'stride_width': 1,
        'num_filters': 36,
        'pool_size': 1
    }, {
        'filter_height': 3,
        'filter_width': 1,
        'stride_height': 1,
        'stride_width': 2,
        'num_filters': 64,
        'pool_size': 2
    }]
    model, optimizer = child.get_model(input_shape,
                                       arch_paras,
                                       num_classes,
                                       device,
                                       multi_gpu=args.multi_gpu,
                                       do_bn=False)
    _, val_acc = backend.fit(model,
                             optimizer,
                             train_data=train_data,
                             val_data=val_data,
                             epochs=args.epochs,
                             verbosity=args.verbosity)
    print(val_acc)
    for e in range(args.episodes):
        logger.info('-' * 130)
        child_id += 1
        start = time.time()
        quan_rollout, quan_paras = agent.rollout()
        logger.info("Sample Quantization ID: {}, Sampled actions: {}".format(
            child_id, quan_rollout))
        fpga_model = FPGAModel(rLUT=args.rLUT,
                               rThroughput=args.rThroughput,
                               arch_paras=arch_paras,
                               quan_paras=quan_paras)
        if fpga_model.validate():
            _, reward = backend.fit(model,
                                    optimizer,
                                    val_data=val_data,
                                    quan_paras=quan_paras,
                                    epochs=1,
                                    verbosity=args.verbosity)
        else:
            reward = 0
        agent.store_rollout(quan_rollout, reward)
        end = time.time()
        ep_time = end - start
        total_time += ep_time
        best_samples.register(child_id, quan_rollout, reward)
        writer.writerow([child_id] +
                        [str(quan_paras[i]) for i in range(args.layers)] +
                        [reward] + list(fpga_model.get_info()) + [ep_time])
        logger.info(f"Reward: {reward}, " + f"Elasped time: {ep_time}, " +
                    f"Average time: {total_time/(e+1)}")
        logger.info(f"Best Reward: {best_samples.reward_list[0]}, " +
                    f"ID: {best_samples.id_list[0]}, " +
                    f"Rollout: {best_samples.rollout_list[0]}")
    logger.info('=' * 50 + "Quantization sapce exploration finished" +
                '=' * 50)
    logger.info(f"Total elasped time: {total_time}")
    logger.info(f"Best samples: {best_samples}")
    csvfile.close()
Example #5
File: main.py  Project: ND-SCL/NAQS
def nested_search(device, dir='experiment'):
    dir = os.path.join(dir,
                       f"rLut={args.rLUT}, rThroughput={args.rThroughput}")
    if os.path.exists(dir) is False:
        os.makedirs(dir)
    filepath = os.path.join(dir, f"nested ({args.episodes} episodes)")
    logger = get_logger(filepath)
    csvfile = open(filepath + '.csv', mode='w+', newline='')
    writer = csv.writer(csvfile)
    logger.info(f"INFORMATION")
    logger.info(f"mode: \t\t\t\t\t {'nested'}")
    logger.info(f"dataset: \t\t\t\t {args.dataset}")
    logger.info(f"number of child network layers: \t {args.layers}")
    logger.info(f"include stride: \t\t\t {not args.no_stride}")
    logger.info(f"include pooling: \t\t\t {not args.no_pooling}")
    logger.info(f"skip connection: \t\t\t {args.skip}")
    logger.info(f"required # LUTs: \t\t\t {args.rLUT}")
    logger.info(f"required throughput: \t\t\t {args.rThroughput}")
    logger.info(f"Assumed frequency: \t\t\t {CLOCK_FREQUENCY}")
    logger.info(f"training epochs: \t\t\t {args.epochs}")
    logger.info(f"data augmentation: \t\t\t {args.augment}")
    logger.info(f"batch size: \t\t\t\t {args.batch_size}")
    logger.info(f"controller learning rate: \t\t {args.learning_rate}")
    logger.info(f"architecture episodes: \t\t\t {args.episodes1}")
    logger.info(f"quantization episodes: \t\t\t {args.episodes2}")
    logger.info(f"using multi gpus: \t\t\t {args.multi_gpu}")
    logger.info(f"architecture space: ")
    for name, value in ARCH_SPACE.items():
        logger.info(name + f": \t\t\t\t {value}")
    logger.info(f"quantization space: ")
    for name, value in QUAN_SPACE.items():
        logger.info(name + f": \t\t\t {value}")
    train_data, val_data = data.get_data(args.dataset,
                                         device,
                                         shuffle=True,
                                         batch_size=args.batch_size,
                                         augment=args.augment)
    input_shape, num_classes = data.get_info(args.dataset)
    writer.writerow(["ID"] +
                    ["Layer {}".format(i)
                     for i in range(args.layers)] + ["Accuracy"] + [
                         "Partition (Tn, Tm)", "Partition (#LUTs)",
                         "Partition (#cycles)", "Total LUT", "Total Throughput"
                     ] + ["Time"])
    arch_agent = Agent(ARCH_SPACE,
                       args.layers,
                       lr=args.learning_rate,
                       device=torch.device('cpu'),
                       skip=args.skip)
    arch_id, total_time = 0, 0
    logger.info('=' * 50 + "Start exploring architecture space" + '=' * 50)
    best_arch = BestSamples(5)
    for e1 in range(args.episodes1):
        logger.info('-' * 130)
        arch_id += 1
        start = time.time()
        arch_rollout, arch_paras = arch_agent.rollout()
        logger.info("Sample Architecture ID: {}, Sampled arch: {}".format(
            arch_id, arch_rollout))
        model, optimizer = child.get_model(input_shape,
                                           arch_paras,
                                           num_classes,
                                           device,
                                           multi_gpu=args.multi_gpu,
                                           do_bn=False)
        backend.fit(model,
                    optimizer,
                    train_data,
                    val_data,
                    epochs=args.epochs,
                    verbosity=args.verbosity)
        quan_id = 0
        best_quan_reward = -1
        logger.info('=' * 50 + "Start exploring quantization space" + '=' * 50)
        quan_agent = Agent(QUAN_SPACE,
                           args.layers,
                           lr=args.learning_rate,
                           device=torch.device('cpu'),
                           skip=False)
        for e2 in range(args.episodes2):
            quan_id += 1
            quan_rollout, quan_paras = quan_agent.rollout()
            fpga_model = FPGAModel(rLUT=args.rLUT,
                                   rThroughput=args.rThroughput,
                                   arch_paras=arch_paras,
                                   quan_paras=quan_paras)
            if fpga_model.validate():
                _, quan_reward = backend.fit(model,
                                             optimizer,
                                             val_data=val_data,
                                             quan_paras=quan_paras,
                                             epochs=1,
                                             verbosity=args.verbosity)
            else:
                quan_reward = 0
            logger.info(
                "Sample Quantization ID: {}, Sampled Quantization: {}, reward: {}"
                .format(quan_id, quan_rollout, quan_reward))
            quan_agent.store_rollout(quan_rollout, quan_reward)
            if quan_reward > best_quan_reward:
                best_quan_reward = quan_reward
                best_quan_rollout, best_quan_paras = quan_rollout, quan_paras
        logger.info('=' * 50 + "Quantization space exploration finished" +
                    '=' * 50)
        arch_reward = best_quan_reward
        arch_agent.store_rollout(arch_rollout, arch_reward)
        end = time.time()
        ep_time = end - start
        total_time += ep_time
        best_arch.register(
            arch_id,
            utility.combine_rollout(arch_rollout, best_quan_rollout,
                                    args.layers), arch_reward)
        writer.writerow([arch_id] + [
            str(arch_paras[i]) + '\n' + str(best_quan_paras[i])
            for i in range(args.layers)
        ] + [arch_reward] + list(fpga_model.get_info()) + [ep_time])
        logger.info(f"Reward: {arch_reward}, " + f"Elasped time: {ep_time}, " +
                    f"Average time: {total_time/(e1+1)}")
        logger.info(f"Best Reward: {best_arch.reward_list[0]}, " +
                    f"ID: {best_arch.id_list[0]}, " +
                    f"Rollout: {best_arch.rollout_list[0]}")
    logger.info('=' * 50 +
                "Architecture & quantization sapce exploration finished" +
                '=' * 50)
    logger.info(f"Total elasped time: {total_time}")
    logger.info(f"Best samples: {best_arch}")
    csvfile.close()
Example #6
File: main.py  Project: ND-SCL/NAQS
def nas(device, dir='experiment'):
    filepath = os.path.join(dir, f"nas ({args.episodes} episodes)")
    logger = get_logger(filepath)
    csvfile = open(filepath + '.csv', mode='w+', newline='')
    writer = csv.writer(csvfile)
    logger.info(f"INFORMATION")
    logger.info(f"mode: \t\t\t\t\t {'nas'}")
    logger.info(f"dataset: \t\t\t\t {args.dataset}")
    logger.info(f"number of child network layers: \t {args.layers}")
    logger.info(f"include stride: \t\t\t {not args.no_stride}")
    logger.info(f"include pooling: \t\t\t {not args.no_pooling}")
    logger.info(f"skip connection: \t\t\t {args.skip}")
    logger.info(f"training epochs: \t\t\t {args.epochs}")
    logger.info(f"data augmentation: \t\t\t {args.augment}")
    logger.info(f"batch size: \t\t\t\t {args.batch_size}")
    logger.info(f"controller learning rate: \t\t {args.learning_rate}")
    logger.info(f"architecture episodes: \t\t\t {args.episodes}")
    logger.info(f"using multi gpus: \t\t\t {args.multi_gpu}")
    logger.info(f"architecture space: ")
    for name, value in ARCH_SPACE.items():
        logger.info(name + f": \t\t\t\t {value}")
    agent = Agent(ARCH_SPACE,
                  args.layers,
                  lr=args.learning_rate,
                  device=torch.device('cpu'),
                  skip=args.skip)
    train_data, val_data = data.get_data(args.dataset,
                                         device,
                                         shuffle=True,
                                         batch_size=args.batch_size,
                                         augment=args.augment)
    input_shape, num_classes = data.get_info(args.dataset)
    writer.writerow(["ID"] +
                    ["Layer {}".format(i)
                     for i in range(args.layers)] + ["Accuracy", "Time"])
    arch_id, total_time = 0, 0
    logger.info('=' * 50 + "Start exploring architecture space" + '=' * 50)
    logger.info('-' * len("Start exploring architecture space"))
    best_samples = BestSamples(5)
    for e in range(args.episodes):
        arch_id += 1
        start = time.time()
        arch_rollout, arch_paras = agent.rollout()
        logger.info("Sample Architecture ID: {}, Sampled actions: {}".format(
            arch_id, arch_rollout))
        model, optimizer = child.get_model(input_shape,
                                           arch_paras,
                                           num_classes,
                                           device,
                                           multi_gpu=args.multi_gpu,
                                           do_bn=True)
        _, arch_reward = backend.fit(model,
                                     optimizer,
                                     train_data,
                                     val_data,
                                     epochs=args.epochs,
                                     verbosity=args.verbosity)
        agent.store_rollout(arch_rollout, arch_reward)
        end = time.time()
        ep_time = end - start
        total_time += ep_time
        best_samples.register(arch_id, arch_rollout, arch_reward)
        writer.writerow([arch_id] +
                        [str(arch_paras[i]) for i in range(args.layers)] +
                        [arch_reward] + [ep_time])
        logger.info(f"Architecture Reward: {arch_reward}, " +
                    f"Elasped time: {ep_time}, " +
                    f"Average time: {total_time/(e+1)}")
        logger.info(f"Best Reward: {best_samples.reward_list[0]}, " +
                    f"ID: {best_samples.id_list[0]}, " +
                    f"Rollout: {best_samples.rollout_list[0]}")
        logger.info('-' * len("Start exploring architecture space"))
    logger.info('=' * 50 + "Architecture sapce exploration finished" +
                '=' * 50)
    logger.info(f"Total elasped time: {total_time}")
    logger.info(f"Best samples: {best_samples}")
    csvfile.close()
Example #7
class PolicyNetwork:
    # hidden size, number of layers, dropout_probs
    ACTION_SPACE = torch.tensor([[1, 2, 3, 4], [10, 50, 100, 150],
                                 [0.2, 0.4, 0.5, 0.7]])

    def __init__(self,
                 input_size,
                 hidden_size,
                 num_steps,
                 action_space,
                 learning_rate=0.001,
                 beta=0.1):
        self.agent = Agent(input_size, hidden_size, num_steps)
        self.optimizer = torch.optim.Adam(self.agent.parameters(),
                                          learning_rate)
        self.beta = beta
        self.action_space = action_space

    def __str__(self):
        obj_dict = vars(self)
        return pprint.pformat(obj_dict, indent=4)

    def _predict(self, s):
        """
        Compute the action probabilities of state s using the learning model
        """
        return self.agent(torch.tensor(s).float())

    def _policy_loss(self, logits, current_action_probs, is_entropy=True):
        """
        Entropy will add exploration benefits.
        (more elaboration : https://github.com/dennybritz/reinforcement-learning/issues/34)

        At the beginning, all actions have the same probability of occurrence.
        After some episodes of learning, a few of the choices (actions) may have a
        higher probability of selection and start to dominate the selection process.
        Entropy adds a cost to those dominating actions, allowing the algorithm to keep exploring new action pairs.
        Entropy keeps decreasing over time.

        beta controls the level of regularization.
        Too low a beta -> the entropy term is too small to push the policy in another direction (explore)
        Too high a beta -> adds too much randomness and may yield a suboptimal policy
        Args:
            logits : for each child network, a new set of logits representing the actions
            current_action_probs : log-prob times return (reward) for a given child network
        """
        self.loss = -1 * torch.mean(current_action_probs)
        if is_entropy:
            self.probs = softmax(logits, dim=1) + 1e-8
            # entropy = -sum(p * log p); probs are already normalized, so take their log directly
            self.entropy = -1 * torch.sum(self.probs * torch.log(self.probs),
                                          dim=1)
            self.entropy_mean = torch.mean(self.entropy, dim=0)
            self.entropy_bonus = -1 * self.beta * self.entropy_mean
            self.loss += self.entropy_bonus

    def update(self, logits, log_probs, is_entropy=True):
        """
        Update the weights of the policy network.

        """
        self._policy_loss(logits, log_probs, is_entropy)

        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()

    def get_action(self, s):
        """
        Estimate the policy and sample an action, compute its log probability
        """
        # probs = self.predict(s)
        # action = torch.multinomial(probs, 1).item()
        # log_prob = torch.log(probs[action])
        # return action, log_prob
        # TODO : test with multinomial ???

        logits = self._predict(s)
        ind = Categorical(logits=logits).sample().unsqueeze(1)
        action_mask = one_hot(ind, num_classes=self.action_space)
        action_selection_prob = log_softmax(logits, dim=1)
        log_prob = torch.sum(action_mask.float() * action_selection_prob,
                             dim=1)

        action = torch.gather(self.ACTION_SPACE, 1, ind).squeeze(1).numpy()

        print('logits', logits)
        print('ind', ind)
        print('mask', action_mask)
        print('action selection prob', action_selection_prob)
        print('current log prob', log_prob)
        print("current action is", action)
        action_dict = {
            "n_hidden": action[0],
            "n_layers": action[1],
            "dropout_prob": action[2],
        }
        print('selected action', action_dict)
        return action_dict, log_prob, logits

    @classmethod
    def from_dict(cls, d):
        return cls(**d)
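
As a rough illustration of the beta trade-off described in _policy_loss (the logits and the weighted log-prob below are made up): for fixed logits, a larger beta makes the entropy bonus a larger share of the loss, which pulls the policy back toward more uniform action probabilities.

import torch
from torch.nn.functional import softmax, log_softmax

logits = torch.tensor([[2.0, 0.1, 0.1, 0.1]])    # one nearly deterministic decision step
probs = softmax(logits, dim=1)
entropy = -torch.sum(probs * log_softmax(logits, dim=1), dim=1).mean()

weighted_log_probs = torch.tensor([-1.5])        # stand-in for log-prob times return
policy_loss = -torch.mean(weighted_log_probs)

for beta in (0.01, 0.1, 1.0):
    loss = policy_loss - beta * entropy
    print(f"beta={beta}: entropy={entropy.item():.3f}, loss={loss.item():.3f}")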
def sync_search(device, dir='experiment'):
    dir = os.path.join(
        dir,
        utility.cleanText(f"rLut-{args.rLUT}_rThroughput-{args.rThroughput}"))
    if os.path.exists(dir) is False:
        os.makedirs(dir)
    filepath = os.path.join(
        dir, utility.cleanText(f"joint_{args.episodes}-episodes"))
    logger = utility.get_logger(filepath)
    csvfile = open(filepath + '.csv', mode='w+', newline='')
    writer = csv.writer(csvfile)
    tb_writer = SummaryWriter(filepath)

    logger.info(f"INFORMATION")
    logger.info(f"mode: \t\t\t\t\t {'joint'}")
    logger.info(f"dataset: \t\t\t\t {args.dataset}")
    logger.info(f"number of child network layers: \t {args.layers}")
    logger.info(f"seed: \t\t\t\t {args.seed}")
    logger.info(f"gpu: \t\t\t\t {args.gpu}")
    logger.info(f"include batchnorm: \t\t\t {args.batchnorm}")
    logger.info(f"include stride: \t\t\t {not args.no_stride}")
    logger.info(f"include pooling: \t\t\t {not args.no_pooling}")
    logger.info(f"skip connection: \t\t\t {args.skip}")
    logger.info(f"required # LUTs: \t\t\t {args.rLUT}")
    logger.info(f"required throughput: \t\t\t {args.rThroughput}")
    logger.info(f"Assumed frequency: \t\t\t {CLOCK_FREQUENCY}")
    logger.info(f"training epochs: \t\t\t {args.epochs}")
    logger.info(f"data augmentation: \t\t\t {args.augment}")
    logger.info(f"batch size: \t\t\t\t {args.batch_size}")
    logger.info(f"controller learning rate: \t\t {args.learning_rate}")
    logger.info(f"controller learning rate: \t\t {args.learning_rate}")
    logger.info(f"architecture episodes: \t\t\t {args.episodes}")
    logger.info(f"using multi gpus: \t\t\t {args.multi_gpu}")
    logger.info(f"architecture space: ")
    for name, value in ARCH_SPACE.items():
        logger.info(name + f": \t\t\t\t {value}")
    logger.info(f"quantization space: ")
    for name, value in QUAN_SPACE.items():
        logger.info(name + f": \t\t\t {value}")

    agent = Agent({
        **ARCH_SPACE,
        **QUAN_SPACE
    },
                  args.layers,
                  lr=args.learning_rate,
                  device=torch.device('cpu'),
                  skip=args.skip)

    train_data, val_data = data.get_data(args.dataset,
                                         device,
                                         shuffle=True,
                                         batch_size=args.batch_size,
                                         augment=args.augment)

    input_shape, num_classes = data.get_info(args.dataset)
    ## (3,32,32) -> (1,3,32,32) add batch dimension
    sample_input = utility.get_sample_input(device, input_shape)

    writer.writerow(["ID"] +
                    ["Layer {}".format(i)
                     for i in range(args.layers)] + ["Accuracy"] + [
                         "Partition (Tn, Tm)", "Partition (#LUTs)",
                         "Partition (#cycles)", "Total LUT", "Total Throughput"
                     ] + ["Time"])

    arch_id, total_time = 0, 0
    best_reward = float('-inf')

    logger.info('=' * 50 +
                "Start exploring architecture & quantization space" + '=' * 50)
    best_samples = BestSamples(5)

    for e in range(args.episodes):
        logger.info('-' * 130)
        arch_id += 1
        start = time.time()
        rollout, paras = agent.rollout()
        logger.info("Sample Architecture ID: {}, Sampled actions: {}".format(
            arch_id, rollout))
        arch_paras, quan_paras = utility.split_paras(paras)

        fpga_model = FPGAModel(rLUT=args.rLUT,
                               rThroughput=args.rThroughput,
                               arch_paras=arch_paras,
                               quan_paras=quan_paras)

        if fpga_model.validate():

            model, optimizer = child.get_model(input_shape,
                                               arch_paras,
                                               num_classes,
                                               device,
                                               multi_gpu=args.multi_gpu,
                                               do_bn=args.batchnorm)

            if args.verbosity > 1:
                print(model)
                torchsummary.summary(model, input_shape)

            if args.adapt:
                num_w = utility.get_net_param(model)
                macs = utility.get_net_macs(model, sample_input)
                tb_writer.add_scalar('num_param', num_w, arch_id)
                tb_writer.add_scalar('macs', macs, arch_id)
                if args.verbosity > 1:
                    print(f"# of param: {num_w}, macs: {macs}")

            _, val_acc = backend.fit(model,
                                     optimizer,
                                     train_data,
                                     val_data,
                                     quan_paras=quan_paras,
                                     epochs=args.epochs,
                                     verbosity=args.verbosity)
        else:
            val_acc = 0

        if args.adapt:
            ## TODO: how to make arch_reward function with macs and latency?
            arch_reward = val_acc
        else:
            arch_reward = val_acc

        agent.store_rollout(rollout, arch_reward)
        end = time.time()
        ep_time = end - start
        total_time += ep_time
        best_samples.register(arch_id, rollout, arch_reward)

        tb_writer.add_scalar('val_acc', val_acc, arch_id)
        tb_writer.add_scalar('arch_reward', arch_reward, arch_id)

        if arch_reward > best_reward:
            best_reward = arch_reward
            tb_writer.add_scalar('best_reward', best_reward, arch_id)
            tb_writer.add_graph(model.eval(), (sample_input, ))

        writer.writerow([arch_id] +
                        [str(paras[i])
                         for i in range(args.layers)] + [arch_reward] +
                        list(fpga_model.get_info()) + [ep_time])
        logger.info(f"Reward: {arch_reward}, " + f"Elasped time: {ep_time}, " +
                    f"Average time: {total_time/(e+1)}")
        logger.info(f"Best Reward: {best_samples.reward_list[0]}, " +
                    f"ID: {best_samples.id_list[0]}, " +
                    f"Rollout: {best_samples.rollout_list[0]}")
    logger.info('=' * 50 +
                "Architecture & quantization sapce exploration finished" +
                '=' * 50)
    logger.info(f"Total elasped time: {total_time}")
    logger.info(f"Best samples: {best_samples}")
    tb_writer.close()
    csvfile.close()
def nas(device, dir='experiment'):
    filepath = os.path.join(dir,
                            utility.cleanText(f"nas_{args.episodes}-episodes"))
    logger = utility.get_logger(filepath)
    csvfile = open(filepath + '.csv', mode='w+', newline='')
    writer = csv.writer(csvfile)
    tb_writer = SummaryWriter(filepath)
    logger.info(f"INFORMATION")
    logger.info(f"mode: \t\t\t\t\t {'nas'}")
    logger.info(f"dataset: \t\t\t\t {args.dataset}")
    logger.info(f"seed: \t\t\t\t {args.seed}")
    logger.info(f"gpu: \t\t\t\t {args.gpu}")
    logger.info(f"number of child network layers: \t {args.layers}")
    logger.info(f"include batchnorm: \t\t\t {args.batchnorm}")
    logger.info(f"include stride: \t\t\t {not args.no_stride}")
    logger.info(f"include pooling: \t\t\t {not args.no_pooling}")
    logger.info(f"skip connection: \t\t\t {args.skip}")
    logger.info(f"training epochs: \t\t\t {args.epochs}")
    logger.info(f"data augmentation: \t\t\t {args.augment}")
    logger.info(f"batch size: \t\t\t\t {args.batch_size}")
    logger.info(f"controller learning rate: \t\t {args.learning_rate}")
    logger.info(f"controller learning rate: \t\t {args.learning_rate}")
    logger.info(f"architecture episodes: \t\t\t {args.episodes}")
    logger.info(f"using multi gpus: \t\t\t {args.multi_gpu}")
    logger.info(f"architecture space: ")
    for name, value in ARCH_SPACE.items():
        logger.info(name + f": \t\t\t\t {value}")

    agent = Agent(ARCH_SPACE,
                  args.layers,
                  lr=args.learning_rate,
                  device=torch.device('cpu'),
                  skip=args.skip)
    train_data, val_data = data.get_data(args.dataset,
                                         device,
                                         shuffle=True,
                                         batch_size=args.batch_size,
                                         augment=args.augment)

    input_shape, num_classes = data.get_info(args.dataset)
    ## (3,32,32) -> (1,3,32,32) add batch dimension
    sample_input = utility.get_sample_input(device, input_shape)

    ## write header
    if args.adapt:
        writer.writerow(["ID"] +
                        ["Layer {}".format(i) for i in range(args.layers)] +
                        ["Accuracy", "Time", "params", "macs", "reward"])

    else:
        writer.writerow(["ID"] +
                        ["Layer {}".format(i)
                         for i in range(args.layers)] + ["Accuracy", "Time"])

    arch_id, total_time = 0, 0
    best_reward = float('-inf')
    logger.info('=' * 50 + "Start exploring architecture space" + '=' * 50)
    logger.info('-' * len("Start exploring architecture space"))
    best_samples = BestSamples(5)

    for e in range(args.episodes):
        arch_id += 1
        start = time.time()
        arch_rollout, arch_paras = agent.rollout()
        logger.info("Sample Architecture ID: {}, Sampled actions: {}".format(
            arch_id, arch_rollout))
        ## get model
        model, optimizer = child.get_model(input_shape,
                                           arch_paras,
                                           num_classes,
                                           device,
                                           multi_gpu=args.multi_gpu,
                                           do_bn=args.batchnorm)

        if args.verbosity > 1:
            print(model)
            torchsummary.summary(model, input_shape)

        if args.adapt:
            num_w = utility.get_net_param(model)
            macs = utility.get_net_macs(model, sample_input)
            tb_writer.add_scalar('num_param', num_w, arch_id)
            tb_writer.add_scalar('macs', macs, arch_id)
            if args.verbosity > 1:
                print(f"# of param: {num_w}, macs: {macs}")

        ## train model and get val_acc
        _, val_acc = backend.fit(model,
                                 optimizer,
                                 train_data,
                                 val_data,
                                 epochs=args.epochs,
                                 verbosity=args.verbosity)

        if args.adapt:
            ## TODO: how to model arch_reward?? with num_w and macs?
            arch_reward = val_acc
        else:
            arch_reward = val_acc

        agent.store_rollout(arch_rollout, arch_reward)

        end = time.time()
        ep_time = end - start
        total_time += ep_time

        tb_writer.add_scalar('val_acc', val_acc, arch_id)
        tb_writer.add_scalar('arch_reward', arch_reward, arch_id)

        if arch_reward > best_reward:
            best_reward = arch_reward
            tb_writer.add_scalar('best_reward', best_reward, arch_id)
            tb_writer.add_graph(model.eval(), (sample_input, ))

        best_samples.register(arch_id, arch_rollout, arch_reward)
        if args.adapt:
            writer.writerow([arch_id] +
                            [str(arch_paras[i])
                             for i in range(args.layers)] + [val_acc] +
                            [ep_time] + [num_w] + [macs] + [arch_reward])
        else:
            writer.writerow([arch_id] +
                            [str(arch_paras[i]) for i in range(args.layers)] +
                            [val_acc] + [ep_time])
        logger.info(f"Architecture Reward: {arch_reward}, " +
                    f"Elasped time: {ep_time}, " +
                    f"Average time: {total_time/(e+1)}")
        logger.info(f"Best Reward: {best_samples.reward_list[0]}, " +
                    f"ID: {best_samples.id_list[0]}, " +
                    f"Rollout: {best_samples.rollout_list[0]}")

        logger.info('-' * len("Start exploring architecture space"))
    logger.info('=' * 50 + "Architecture sapce exploration finished" +
                '=' * 50)
    logger.info(f"Total elasped time: {total_time}")
    logger.info(f"Best samples: {best_samples}")
    tb_writer.close()
    csvfile.close()