Ejemplo n.º 1
0
Archivo: child.py Proyecto: ND-SCL/NAQS
 def validate(self, quantize=False, verbosity=False):
     """Run a single validation pass of ``self.model`` on ``self.dataset``.

     The data is loaded unshuffled with a batch size of 128. Only the second
     loader returned by ``data.get_data`` is handed to ``backend.fit``
     (as ``val_data``, with no ``train_data``), together with
     ``self.quan_paras`` and ``epochs=1``.

     Args:
         quantize: accepted but currently not read (see note below).
         verbosity: forwarded unchanged to ``backend.fit``.

     Returns:
         Whatever ``backend.fit`` returns for this call.
     """
     # NOTE(review): `quantize` is never consulted -- self.quan_paras is
     # always forwarded, whereas the sibling `fit` method only forwards it
     # when its `quantize` flag is set. Confirm which behaviour callers expect.
     _unused_train, held_out = data.get_data(self.dataset,
                                             self.device,
                                             shuffle=False,
                                             batch_size=128)
     return backend.fit(self.model,
                        self.optimizer,
                        val_data=held_out,
                        quan_paras=self.quan_paras,
                        epochs=1,
                        verbosity=verbosity)
Ejemplo n.º 2
0
Archivo: child.py Proyecto: ND-SCL/NAQS
 def train(self, batch_size=128, epochs=40, verbosity=True, validate=False):
     """Train ``self.model`` on ``self.dataset`` for ``epochs`` epochs.

     The data is loaded shuffled and augmented with the given ``batch_size``.
     The validation loader is passed to ``backend.fit`` only when
     ``validate`` is something other than the literal ``False``.

     Args:
         batch_size: mini-batch size passed to ``data.get_data``.
         epochs: number of epochs forwarded to ``backend.fit``.
         verbosity: forwarded unchanged to ``backend.fit``.
         validate: anything other than ``False`` also forwards the
             validation loader.

     Returns:
         The value returned by ``backend.fit``.
     """
     # NOTE(review): the sibling `fit` method unpacks backend.fit's result as
     # (loss, acc), so this probably returns that pair rather than a bare
     # accuracy -- confirm with callers.
     train_loader, held_out = data.get_data(self.dataset,
                                            self.device,
                                            shuffle=True,
                                            batch_size=batch_size,
                                            augment=True)
     # Identity test (not truthiness) on purpose: only the literal False
     # disables validation, so e.g. validate=0 still validates.
     eval_loader = held_out if validate is not False else None
     return backend.fit(self.model,
                        self.optimizer,
                        train_data=train_loader,
                        val_data=eval_loader,
                        epochs=epochs,
                        verbosity=verbosity)
 def fit(self, validate=False, quantize=False, verbosity=0, epochs=40):
     """Train ``self.model`` and return the ``(loss, acc)`` pair.

     The data is loaded shuffled and augmented with a batch size of 128.
     The validation loader and ``self.quan_paras`` are forwarded to
     ``backend.fit`` only when ``validate`` / ``quantize`` are something
     other than the literal ``False``.

     Args:
         validate: anything other than ``False`` also forwards the
             validation loader.
         quantize: anything other than ``False`` forwards
             ``self.quan_paras``.
         verbosity: forwarded unchanged to ``backend.fit``.
         epochs: number of epochs forwarded to ``backend.fit``.

     Returns:
         A 2-tuple ``(loss, acc)`` unpacked from ``backend.fit``.
     """
     train_loader, held_out = data.get_data(
         self.dataset, self.device,
         shuffle=True,
         batch_size=128,
         augment=True)
     # Identity tests (not truthiness) keep the contract that only the
     # literal False switches a feature off.
     eval_loader = held_out
     if validate is False:
         eval_loader = None
     quant_cfg = None
     if quantize is not False:
         quant_cfg = self.quan_paras
     final_loss, final_acc = backend.fit(
         self.model, self.optimizer,
         train_data=train_loader,
         val_data=eval_loader,
         epochs=epochs,
         verbosity=verbosity,
         quan_paras=quant_cfg)
     return final_loss, final_acc
Ejemplo n.º 4
0
                                         device,
                                         shuffle=True,
                                         batch_size=128)
    # NOTE(review): this snippet starts mid-call; the enclosing function
    # header and the opening of the data.get_data(...) call that produces
    # train_data / val_data are not visible here.
    # input_shape / num_classes are not used in the visible code below.
    input_shape, num_classes = data.get_info(dataset)
    # Build SimpleNet on `device`; on CUDA devices wrap it in
    # torch.nn.DataParallel.
    model = SimpleNet().to(device)
    if device.type == 'cuda':
        print("using parallel data")
        model = torch.nn.DataParallel(model)
    # SGD with Nesterov momentum and weight decay; the Adam configuration
    # below is kept commented out as an alternative.
    optimizer = optim.SGD(model.parameters(),
                          lr=0.01,
                          momentum=0.9,
                          weight_decay=1e-4,
                          nesterov=True)
    # optimizer = optim.Adam(
    #     model.parameters(),
    #     lr=0.001,
    #     betas=(0.9, 0.999),
    #     eps=1e-8,
    #     weight_decay=0.0,
    #     amsgrad=True
    # )
    # Train for 200 epochs and print the elapsed wall-clock time (seconds).
    start = time.time()
    backend.fit(model,
                optimizer,
                train_data,
                val_data,
                epochs=200,
                verbosity=1)
    end = time.time()
    print(end - start)
        'anchor_point': [0, 0, 0, 1]
    }, {
        'num_filters': 128,
        'filter_height': 3,
        'filter_width': 3,
        'pool_size': 2,
        'anchor_point': [0, 0, 0, 0, 1]
    }]
    # NOTE(review): this snippet starts inside the `arch_paras` list literal;
    # its opening and the earlier layer dicts are not visible here.
    # Remove the 'anchor_point' key from every layer dict before the
    # dicts are handed to child.get_model.
    for c in arch_paras:
        c.pop('anchor_point')
    model = child.get_model(input_shape, arch_paras, num_classes, device)
    optimizer = child.get_optimizer(model, 'SGD')
    # One dict per layer with the same fixed precision: 7 integer and 7
    # fractional bits for both activations and weights (`l` is unused).
    # NOTE(review): quan_paras is built here but backend.fit below is called
    # with quan_paras=None, so this list is never used -- confirm intended.
    quan_paras = []
    for l in range(len(arch_paras)):
        layer = {}
        layer['act_num_int_bits'] = 7
        layer['act_num_frac_bits'] = 7
        layer['weight_num_int_bits'] = 7
        layer['weight_num_frac_bits'] = 7
        quan_paras.append(layer)
    # Train for 40 epochs and print the elapsed wall-clock time (seconds).
    start = time.time()
    backend.fit(model,
                optimizer,
                train_data,
                val_data,
                epochs=40,
                verbosity=1,
                quan_paras=None)
    end = time.time()
    print(end - start)
Ejemplo n.º 6
0
Archivo: main.py Proyecto: ND-SCL/NAQS
def quantization_search(device, dir='experiment'):
    """Explore the quantization space for one fixed child architecture.

    A hard-coded 6-layer architecture is built (``do_bn=False``) and trained
    once for ``args.epochs`` epochs without ``quan_paras``. Each of the
    ``args.episodes`` episodes then samples per-layer quantization
    parameters from an ``Agent`` over ``QUAN_SPACE``. The sample is checked
    with ``FPGAModel.validate``. If it is feasible, the reward is the second
    value returned by ``backend.fit`` on ``val_data`` with those
    ``quan_paras``; otherwise the reward is 0.

    Settings are logged, and one CSV row per episode is written to
    ``<dir>/rLut=<rLUT>, rThroughput=<rThroughput>/``.

    Args:
        device: torch device passed to the data loaders and the child model.
        dir: root output directory; the per-run sub-directory is created
            if it does not already exist.
    """
    dir = os.path.join(dir,
                       f"rLut={args.rLUT}, rThroughput={args.rThroughput}")
    # exist_ok=True replaces the racy exists()-then-makedirs() check.
    os.makedirs(dir, exist_ok=True)
    filepath = os.path.join(dir, f"quantization ({args.episodes} episodes)")
    logger = get_logger(filepath)
    # `with` guarantees the CSV is flushed and closed even if training or
    # evaluation raises part-way through the search.
    with open(filepath + '.csv', mode='w+', newline='') as csvfile:
        writer = csv.writer(csvfile)
        logger.info(f"INFORMATION")
        logger.info(f"mode: \t\t\t\t\t {'quantization'}")
        logger.info(f"dataset: \t\t\t\t {args.dataset}")
        logger.info(f"number of child network layers: \t {args.layers}")
        logger.info(f"include stride: \t\t\t {not args.no_stride}")
        logger.info(f"include pooling: \t\t\t {not args.no_pooling}")
        logger.info(f"skip connection: \t\t\t {args.skip}")
        logger.info(f"required # LUTs: \t\t\t {args.rLUT}")
        logger.info(f"required throughput: \t\t\t {args.rThroughput}")
        logger.info(f"Assumed frequency: \t\t\t {CLOCK_FREQUENCY}")
        logger.info(f"training epochs: \t\t\t {args.epochs}")
        logger.info(f"data augmentation: \t\t\t {args.augment}")
        logger.info(f"batch size: \t\t\t\t {args.batch_size}")
        logger.info(f"controller learning rate: \t\t {args.learning_rate}")
        logger.info(f"architecture episodes: \t\t\t {args.episodes}")
        logger.info(f"using multi gpus: \t\t\t {args.multi_gpu}")
        # The architecture is hard-coded below, so only this header line is
        # logged for the architecture space.
        logger.info(f"architecture space: ")
        logger.info(f"quantization space: ")
        for name, value in QUAN_SPACE.items():
            logger.info(name + f": \t\t\t {value}")
        # The controller runs on the CPU; only the child network uses `device`.
        agent = Agent(QUAN_SPACE,
                      args.layers,
                      lr=args.learning_rate,
                      device=torch.device('cpu'),
                      skip=False)
        train_data, val_data = data.get_data(args.dataset,
                                             device,
                                             shuffle=True,
                                             batch_size=args.batch_size,
                                             augment=args.augment)
        input_shape, num_classes = data.get_info(args.dataset)
        writer.writerow(["ID"] +
                        ["Layer {}".format(i)
                         for i in range(args.layers)] + ["Accuracy"] + [
                             "Partition (Tn, Tm)", "Partition (#LUTs)",
                             "Partition (#cycles)", "Total LUT",
                             "Total Throughput"
                         ] + ["Time"])
        child_id, total_time = 0, 0
        logger.info('=' * 50 + "Start exploring quantization space" +
                    '=' * 50)
        best_samples = BestSamples(5)
        # Fixed child architecture whose quantization is being searched.
        arch_paras = [{
            'filter_height': 3,
            'filter_width': 3,
            'stride_height': 1,
            'stride_width': 1,
            'num_filters': 64,
            'pool_size': 1
        }, {
            'filter_height': 7,
            'filter_width': 5,
            'stride_height': 1,
            'stride_width': 1,
            'num_filters': 48,
            'pool_size': 1
        }, {
            'filter_height': 5,
            'filter_width': 5,
            'stride_height': 2,
            'stride_width': 1,
            'num_filters': 48,
            'pool_size': 1
        }, {
            'filter_height': 3,
            'filter_width': 5,
            'stride_height': 1,
            'stride_width': 1,
            'num_filters': 64,
            'pool_size': 1
        }, {
            'filter_height': 5,
            'filter_width': 7,
            'stride_height': 1,
            'stride_width': 1,
            'num_filters': 36,
            'pool_size': 1
        }, {
            'filter_height': 3,
            'filter_width': 1,
            'stride_height': 1,
            'stride_width': 2,
            'num_filters': 64,
            'pool_size': 2
        }]
        model, optimizer = child.get_model(input_shape,
                                           arch_paras,
                                           num_classes,
                                           device,
                                           multi_gpu=args.multi_gpu,
                                           do_bn=False)
        # Train once (no quan_paras). Each episode below then calls
        # backend.fit on this model with only val_data + quan_paras
        # (presumably evaluation only -- confirm against backend.fit).
        _, val_acc = backend.fit(model,
                                 optimizer,
                                 train_data=train_data,
                                 val_data=val_data,
                                 epochs=args.epochs,
                                 verbosity=args.verbosity)
        print(val_acc)
        for e in range(args.episodes):
            logger.info('-' * 130)
            child_id += 1
            start = time.time()
            quan_rollout, quan_paras = agent.rollout()
            logger.info(
                "Sample Quantization ID: {}, Sampled actions: {}".format(
                    child_id, quan_rollout))
            fpga_model = FPGAModel(rLUT=args.rLUT,
                                   rThroughput=args.rThroughput,
                                   arch_paras=arch_paras,
                                   quan_paras=quan_paras)
            # Quantizations rejected by the FPGA model get reward 0 and are
            # not evaluated.
            if fpga_model.validate():
                _, reward = backend.fit(model,
                                        optimizer,
                                        val_data=val_data,
                                        quan_paras=quan_paras,
                                        epochs=1,
                                        verbosity=args.verbosity)
            else:
                reward = 0
            agent.store_rollout(quan_rollout, reward)
            end = time.time()
            ep_time = end - start
            total_time += ep_time
            best_samples.register(child_id, quan_rollout, reward)
            writer.writerow([child_id] +
                            [str(quan_paras[i]) for i in range(args.layers)] +
                            [reward] + list(fpga_model.get_info()) +
                            [ep_time])
            logger.info(f"Reward: {reward}, " +
                        f"Elasped time: {ep_time}, " +
                        f"Average time: {total_time/(e+1)}")
            logger.info(f"Best Reward: {best_samples.reward_list[0]}, " +
                        f"ID: {best_samples.id_list[0]}, " +
                        f"Rollout: {best_samples.rollout_list[0]}")
        logger.info('=' * 50 + "Quantization sapce exploration finished" +
                    '=' * 50)
        logger.info(f"Total elasped time: {total_time}")
        logger.info(f"Best samples: {best_samples}")
Ejemplo n.º 7
0
Archivo: main.py Proyecto: ND-SCL/NAQS
def nested_search(device, dir='experiment'):
    """Outer architecture search with an inner quantization search per arch.

    Each of the ``args.episodes1`` outer episodes works as follows:
    an ``Agent`` over ``ARCH_SPACE`` samples an architecture, which is built
    (``do_bn=False``) and trained for ``args.epochs`` epochs. A fresh
    ``Agent`` over ``QUAN_SPACE`` then samples ``args.episodes2``
    quantizations. Each one is checked with ``FPGAModel.validate``; if it
    passes, it is scored with ``backend.fit`` on ``val_data``, otherwise its
    score is 0. The best quantization score becomes the architecture's
    reward. The CSV row for the architecture records the best quantization
    together with the FPGA info of that same quantization.

    Args:
        device: torch device passed to the data loaders and child models.
        dir: root output directory; the per-run sub-directory is created
            if it does not already exist.

    Raises:
        ValueError: if ``args.episodes2`` is less than 1, since then no
            quantization is sampled and there is no best one to record.
    """
    # Fail fast: with no inner episodes, best_quan_* and the FPGA model used
    # for the CSV row would never be defined (NameError after training).
    if args.episodes2 < 1:
        raise ValueError(
            f"args.episodes2 must be at least 1, got {args.episodes2}")
    dir = os.path.join(dir,
                       f"rLut={args.rLUT}, rThroughput={args.rThroughput}")
    # exist_ok=True replaces the racy exists()-then-makedirs() check.
    os.makedirs(dir, exist_ok=True)
    # NOTE(review): the file name uses args.episodes, while the outer loop
    # runs args.episodes1 times -- confirm the two are meant to agree.
    filepath = os.path.join(dir, f"nested ({args.episodes} episodes)")
    logger = get_logger(filepath)
    # `with` guarantees the CSV is flushed and closed even on an exception.
    with open(filepath + '.csv', mode='w+', newline='') as csvfile:
        writer = csv.writer(csvfile)
        logger.info(f"INFORMATION")
        logger.info(f"mode: \t\t\t\t\t {'nested'}")
        logger.info(f"dataset: \t\t\t\t {args.dataset}")
        logger.info(f"number of child network layers: \t {args.layers}")
        logger.info(f"include stride: \t\t\t {not args.no_stride}")
        logger.info(f"include pooling: \t\t\t {not args.no_pooling}")
        logger.info(f"skip connection: \t\t\t {args.skip}")
        logger.info(f"required # LUTs: \t\t\t {args.rLUT}")
        logger.info(f"required throughput: \t\t\t {args.rThroughput}")
        logger.info(f"Assumed frequency: \t\t\t {CLOCK_FREQUENCY}")
        logger.info(f"training epochs: \t\t\t {args.epochs}")
        logger.info(f"data augmentation: \t\t\t {args.augment}")
        logger.info(f"batch size: \t\t\t\t {args.batch_size}")
        logger.info(f"controller learning rate: \t\t {args.learning_rate}")
        logger.info(f"architecture episodes: \t\t\t {args.episodes1}")
        logger.info(f"quantization episodes: \t\t\t {args.episodes2}")
        logger.info(f"using multi gpus: \t\t\t {args.multi_gpu}")
        logger.info(f"architecture space: ")
        for name, value in ARCH_SPACE.items():
            logger.info(name + f": \t\t\t\t {value}")
        logger.info(f"quantization space: ")
        for name, value in QUAN_SPACE.items():
            logger.info(name + f": \t\t\t {value}")
        train_data, val_data = data.get_data(args.dataset,
                                             device,
                                             shuffle=True,
                                             batch_size=args.batch_size,
                                             augment=args.augment)
        input_shape, num_classes = data.get_info(args.dataset)
        writer.writerow(["ID"] +
                        ["Layer {}".format(i)
                         for i in range(args.layers)] + ["Accuracy"] + [
                             "Partition (Tn, Tm)", "Partition (#LUTs)",
                             "Partition (#cycles)", "Total LUT",
                             "Total Throughput"
                         ] + ["Time"])
        arch_agent = Agent(ARCH_SPACE,
                           args.layers,
                           lr=args.learning_rate,
                           device=torch.device('cpu'),
                           skip=args.skip)
        arch_id, total_time = 0, 0
        logger.info('=' * 50 + "Start exploring architecture space" +
                    '=' * 50)
        best_arch = BestSamples(5)
        for e1 in range(args.episodes1):
            logger.info('-' * 130)
            arch_id += 1
            start = time.time()
            arch_rollout, arch_paras = arch_agent.rollout()
            logger.info("Sample Architecture ID: {}, Sampled arch: {}".format(
                arch_id, arch_rollout))
            model, optimizer = child.get_model(input_shape,
                                               arch_paras,
                                               num_classes,
                                               device,
                                               multi_gpu=args.multi_gpu,
                                               do_bn=False)
            backend.fit(model,
                        optimizer,
                        train_data,
                        val_data,
                        epochs=args.epochs,
                        verbosity=args.verbosity)
            best_quan_reward = -1
            logger.info('=' * 50 + "Start exploring quantization space" +
                        '=' * 50)
            # A fresh quantization controller for every architecture.
            quan_agent = Agent(QUAN_SPACE,
                               args.layers,
                               lr=args.learning_rate,
                               device=torch.device('cpu'),
                               skip=False)
            for quan_id in range(1, args.episodes2 + 1):
                quan_rollout, quan_paras = quan_agent.rollout()
                fpga_model = FPGAModel(rLUT=args.rLUT,
                                       rThroughput=args.rThroughput,
                                       arch_paras=arch_paras,
                                       quan_paras=quan_paras)
                if fpga_model.validate():
                    _, quan_reward = backend.fit(model,
                                                 optimizer,
                                                 val_data=val_data,
                                                 quan_paras=quan_paras,
                                                 epochs=1,
                                                 verbosity=args.verbosity)
                else:
                    quan_reward = 0
                logger.info(
                    "Sample Quantization ID: {}, Sampled Quantization: {}, reward: {}"
                    .format(quan_id, quan_rollout, quan_reward))
                quan_agent.store_rollout(quan_rollout, quan_reward)
                if quan_reward > best_quan_reward:
                    best_quan_reward = quan_reward
                    best_quan_rollout, best_quan_paras = quan_rollout, quan_paras
                    # Keep the FPGA model of the best quantization so that the
                    # CSV row's partition/LUT/throughput columns describe
                    # best_quan_paras rather than the last sampled one.
                    best_fpga_model = fpga_model
            logger.info('=' * 50 + "Quantization space exploration finished" +
                        '=' * 50)
            arch_reward = best_quan_reward
            arch_agent.store_rollout(arch_rollout, arch_reward)
            end = time.time()
            ep_time = end - start
            total_time += ep_time
            best_arch.register(
                arch_id,
                utility.combine_rollout(arch_rollout, best_quan_rollout,
                                        args.layers), arch_reward)
            writer.writerow([arch_id] + [
                str(arch_paras[i]) + '\n' + str(best_quan_paras[i])
                for i in range(args.layers)
            ] + [arch_reward] + list(best_fpga_model.get_info()) + [ep_time])
            logger.info(f"Reward: {arch_reward}, " +
                        f"Elasped time: {ep_time}, " +
                        f"Average time: {total_time/(e1+1)}")
            logger.info(f"Best Reward: {best_arch.reward_list[0]}, " +
                        f"ID: {best_arch.id_list[0]}, " +
                        f"Rollout: {best_arch.rollout_list[0]}")
        logger.info('=' * 50 +
                    "Architecture & quantization sapce exploration finished" +
                    '=' * 50)
        logger.info(f"Total elasped time: {total_time}")
        logger.info(f"Best samples: {best_arch}")
Ejemplo n.º 8
0
Archivo: main.py Proyecto: ND-SCL/NAQS
def nas(device, dir='experiment'):
    """Architecture-only neural architecture search (no quantization).

    Each of the ``args.episodes`` episodes works as follows: an ``Agent``
    over ``ARCH_SPACE`` samples an architecture, which is built
    (``do_bn=True``) and trained for ``args.epochs`` epochs. The second value
    returned by ``backend.fit`` is used as the reward. Settings are logged,
    and one CSV row per episode is written next to the log file.

    Args:
        device: torch device passed to the data loaders and child models.
        dir: output directory; created if it does not already exist.
    """
    # Create the output directory, as the other search modes do; otherwise
    # opening the CSV below fails when `dir` does not exist yet.
    if dir:
        os.makedirs(dir, exist_ok=True)
    filepath = os.path.join(dir, f"nas ({args.episodes} episodes)")
    logger = get_logger(filepath)
    # `with` guarantees the CSV is flushed and closed even on an exception.
    with open(filepath + '.csv', mode='w+', newline='') as csvfile:
        writer = csv.writer(csvfile)
        logger.info(f"INFORMATION")
        logger.info(f"mode: \t\t\t\t\t {'nas'}")
        logger.info(f"dataset: \t\t\t\t {args.dataset}")
        logger.info(f"number of child network layers: \t {args.layers}")
        logger.info(f"include stride: \t\t\t {not args.no_stride}")
        logger.info(f"include pooling: \t\t\t {not args.no_pooling}")
        logger.info(f"skip connection: \t\t\t {args.skip}")
        logger.info(f"training epochs: \t\t\t {args.epochs}")
        logger.info(f"data augmentation: \t\t\t {args.augment}")
        logger.info(f"batch size: \t\t\t\t {args.batch_size}")
        logger.info(f"controller learning rate: \t\t {args.learning_rate}")
        logger.info(f"architecture episodes: \t\t\t {args.episodes}")
        logger.info(f"using multi gpus: \t\t\t {args.multi_gpu}")
        logger.info(f"architecture space: ")
        for name, value in ARCH_SPACE.items():
            logger.info(name + f": \t\t\t\t {value}")
        agent = Agent(ARCH_SPACE,
                      args.layers,
                      lr=args.learning_rate,
                      device=torch.device('cpu'),
                      skip=args.skip)
        train_data, val_data = data.get_data(args.dataset,
                                             device,
                                             shuffle=True,
                                             batch_size=args.batch_size,
                                             augment=args.augment)
        input_shape, num_classes = data.get_info(args.dataset)
        writer.writerow(["ID"] +
                        ["Layer {}".format(i)
                         for i in range(args.layers)] + ["Accuracy", "Time"])
        arch_id, total_time = 0, 0
        logger.info('=' * 50 + "Start exploring architecture space" +
                    '=' * 50)
        logger.info('-' * len("Start exploring architecture space"))
        best_samples = BestSamples(5)
        for e in range(args.episodes):
            arch_id += 1
            start = time.time()
            arch_rollout, arch_paras = agent.rollout()
            logger.info(
                "Sample Architecture ID: {}, Sampled actions: {}".format(
                    arch_id, arch_rollout))
            model, optimizer = child.get_model(input_shape,
                                               arch_paras,
                                               num_classes,
                                               device,
                                               multi_gpu=args.multi_gpu,
                                               do_bn=True)
            _, arch_reward = backend.fit(model,
                                         optimizer,
                                         train_data,
                                         val_data,
                                         epochs=args.epochs,
                                         verbosity=args.verbosity)
            agent.store_rollout(arch_rollout, arch_reward)
            end = time.time()
            ep_time = end - start
            total_time += ep_time
            best_samples.register(arch_id, arch_rollout, arch_reward)
            writer.writerow([arch_id] +
                            [str(arch_paras[i]) for i in range(args.layers)] +
                            [arch_reward] + [ep_time])
            logger.info(f"Architecture Reward: {arch_reward}, " +
                        f"Elasped time: {ep_time}, " +
                        f"Average time: {total_time/(e+1)}")
            logger.info(f"Best Reward: {best_samples.reward_list[0]}, " +
                        f"ID: {best_samples.id_list[0]}, " +
                        f"Rollout: {best_samples.rollout_list[0]}")
            logger.info('-' * len("Start exploring architecture space"))
        logger.info('=' * 50 + "Architecture sapce exploration finished" +
                    '=' * 50)
        logger.info(f"Total elasped time: {total_time}")
        logger.info(f"Best samples: {best_samples}")
def sync_search(device, dir='experiment'):
    """Joint architecture + quantization search driven by one controller.

    A single ``Agent`` samples every layer from the merged
    ``{**ARCH_SPACE, **QUAN_SPACE}`` space, and ``utility.split_paras``
    separates the architecture and quantization parameters. A sample
    rejected by ``FPGAModel.validate`` gets reward 0 without building or
    training a model. Otherwise the child model is built, trained with
    ``quan_paras`` for ``args.epochs`` epochs, and the second value returned
    by ``backend.fit`` (``val_acc``) is the reward.

    Outputs, all under ``<dir>/<cleaned rLut/rThroughput name>/``:
    - a log of the settings;
    - one CSV row per episode;
    - TensorBoard scalars and, whenever the reward improves and a model
      was built in that episode, the model graph.

    Args:
        device: torch device passed to the data loaders and child models.
        dir: root output directory; the per-run sub-directory is created
            if it does not already exist.
    """
    import contextlib

    dir = os.path.join(
        dir,
        utility.cleanText(f"rLut-{args.rLUT}_rThroughput-{args.rThroughput}"))
    # exist_ok=True replaces the racy exists()-then-makedirs() check.
    os.makedirs(dir, exist_ok=True)
    filepath = os.path.join(
        dir, utility.cleanText(f"joint_{args.episodes}-episodes"))
    logger = utility.get_logger(filepath)
    # Both outputs are closed (TensorBoard writer first, then the CSV) even
    # if training raises part-way through the search.
    with open(filepath + '.csv', mode='w+', newline='') as csvfile, \
            contextlib.closing(SummaryWriter(filepath)) as tb_writer:
        writer = csv.writer(csvfile)

        logger.info(f"INFORMATION")
        logger.info(f"mode: \t\t\t\t\t {'joint'}")
        logger.info(f"dataset: \t\t\t\t {args.dataset}")
        logger.info(f"number of child network layers: \t {args.layers}")
        logger.info(f"seed: \t\t\t\t {args.seed}")
        logger.info(f"gpu: \t\t\t\t {args.gpu}")
        logger.info(f"include batchnorm: \t\t\t {args.batchnorm}")
        logger.info(f"include stride: \t\t\t {not args.no_stride}")
        logger.info(f"include pooling: \t\t\t {not args.no_pooling}")
        logger.info(f"skip connection: \t\t\t {args.skip}")
        logger.info(f"required # LUTs: \t\t\t {args.rLUT}")
        logger.info(f"required throughput: \t\t\t {args.rThroughput}")
        logger.info(f"Assumed frequency: \t\t\t {CLOCK_FREQUENCY}")
        logger.info(f"training epochs: \t\t\t {args.epochs}")
        logger.info(f"data augmentation: \t\t\t {args.augment}")
        logger.info(f"batch size: \t\t\t\t {args.batch_size}")
        logger.info(f"controller learning rate: \t\t {args.learning_rate}")
        logger.info(f"architecture episodes: \t\t\t {args.episodes}")
        logger.info(f"using multi gpus: \t\t\t {args.multi_gpu}")
        logger.info(f"architecture space: ")
        for name, value in ARCH_SPACE.items():
            logger.info(name + f": \t\t\t\t {value}")
        logger.info(f"quantization space: ")
        for name, value in QUAN_SPACE.items():
            logger.info(name + f": \t\t\t {value}")

        agent = Agent({
            **ARCH_SPACE,
            **QUAN_SPACE
        },
                      args.layers,
                      lr=args.learning_rate,
                      device=torch.device('cpu'),
                      skip=args.skip)

        train_data, val_data = data.get_data(args.dataset,
                                             device,
                                             shuffle=True,
                                             batch_size=args.batch_size,
                                             augment=args.augment)

        input_shape, num_classes = data.get_info(args.dataset)
        ## (3,32,32) -> (1,3,32,32) add batch dimension
        sample_input = utility.get_sample_input(device, input_shape)

        writer.writerow(["ID"] +
                        ["Layer {}".format(i)
                         for i in range(args.layers)] + ["Accuracy"] + [
                             "Partition (Tn, Tm)", "Partition (#LUTs)",
                             "Partition (#cycles)", "Total LUT",
                             "Total Throughput"
                         ] + ["Time"])

        arch_id, total_time = 0, 0
        best_reward = float('-inf')

        logger.info('=' * 50 +
                    "Start exploring architecture & quantization space" +
                    '=' * 50)
        best_samples = BestSamples(5)

        for e in range(args.episodes):
            logger.info('-' * 130)
            arch_id += 1
            start = time.time()
            rollout, paras = agent.rollout()
            logger.info(
                "Sample Architecture ID: {}, Sampled actions: {}".format(
                    arch_id, rollout))
            arch_paras, quan_paras = utility.split_paras(paras)

            fpga_model = FPGAModel(rLUT=args.rLUT,
                                   rThroughput=args.rThroughput,
                                   arch_paras=arch_paras,
                                   quan_paras=quan_paras)

            # Reset every episode: a model is only built when the FPGA
            # constraints are met, and a model left over from an earlier
            # episode must not be logged as this episode's graph.
            model = None
            if fpga_model.validate():
                model, optimizer = child.get_model(input_shape,
                                                   arch_paras,
                                                   num_classes,
                                                   device,
                                                   multi_gpu=args.multi_gpu,
                                                   do_bn=args.batchnorm)

                if args.verbosity > 1:
                    print(model)
                    torchsummary.summary(model, input_shape)

                if args.adapt:
                    num_w = utility.get_net_param(model)
                    macs = utility.get_net_macs(model, sample_input)
                    tb_writer.add_scalar('num_param', num_w, arch_id)
                    tb_writer.add_scalar('macs', macs, arch_id)
                    if args.verbosity > 1:
                        print(f"# of param: {num_w}, macs: {macs}")

                _, val_acc = backend.fit(model,
                                         optimizer,
                                         train_data,
                                         val_data,
                                         quan_paras=quan_paras,
                                         epochs=args.epochs,
                                         verbosity=args.verbosity)
            else:
                val_acc = 0

            ## TODO: how to make arch_reward function with macs and latency?
            # For now the reward is val_acc whether or not args.adapt is set.
            arch_reward = val_acc

            agent.store_rollout(rollout, arch_reward)
            end = time.time()
            ep_time = end - start
            total_time += ep_time
            best_samples.register(arch_id, rollout, arch_reward)

            tb_writer.add_scalar('val_acc', val_acc, arch_id)
            tb_writer.add_scalar('arch_reward', arch_reward, arch_id)

            if arch_reward > best_reward:
                best_reward = arch_reward
                tb_writer.add_scalar('best_reward', best_reward, arch_id)
                # Only graph a model that was actually built this episode.
                if model is not None:
                    tb_writer.add_graph(model.eval(), (sample_input, ))

            writer.writerow([arch_id] +
                            [str(paras[i])
                             for i in range(args.layers)] + [arch_reward] +
                            list(fpga_model.get_info()) + [ep_time])
            logger.info(f"Reward: {arch_reward}, " +
                        f"Elasped time: {ep_time}, " +
                        f"Average time: {total_time/(e+1)}")
            logger.info(f"Best Reward: {best_samples.reward_list[0]}, " +
                        f"ID: {best_samples.id_list[0]}, " +
                        f"Rollout: {best_samples.rollout_list[0]}")
        logger.info('=' * 50 +
                    "Architecture & quantization sapce exploration finished" +
                    '=' * 50)
        logger.info(f"Total elasped time: {total_time}")
        logger.info(f"Best samples: {best_samples}")
def nas(device, dir='experiment'):
    """Architecture-only neural architecture search with TensorBoard output.

    Each of the ``args.episodes`` episodes works as follows: an ``Agent``
    over ``ARCH_SPACE`` samples an architecture, which is built
    (``do_bn=args.batchnorm``) and trained for ``args.epochs`` epochs. The
    second value returned by ``backend.fit`` (``val_acc``) is the reward.
    With ``args.adapt``, the parameter count and MACs are also logged to
    TensorBoard and the CSV.

    Args:
        device: torch device passed to the data loaders and child models.
        dir: output directory; created if it does not already exist.
    """
    import contextlib

    # Create the output directory, as the other search modes do; otherwise
    # opening the CSV below fails when `dir` does not exist yet.
    if dir:
        os.makedirs(dir, exist_ok=True)
    filepath = os.path.join(dir,
                            utility.cleanText(f"nas_{args.episodes}-episodes"))
    logger = utility.get_logger(filepath)
    # Both outputs are closed (TensorBoard writer first, then the CSV) even
    # if training raises part-way through the search.
    with open(filepath + '.csv', mode='w+', newline='') as csvfile, \
            contextlib.closing(SummaryWriter(filepath)) as tb_writer:
        writer = csv.writer(csvfile)
        logger.info(f"INFORMATION")
        logger.info(f"mode: \t\t\t\t\t {'nas'}")
        logger.info(f"dataset: \t\t\t\t {args.dataset}")
        logger.info(f"seed: \t\t\t\t {args.seed}")
        logger.info(f"gpu: \t\t\t\t {args.gpu}")
        logger.info(f"number of child network layers: \t {args.layers}")
        logger.info(f"include batchnorm: \t\t\t {args.batchnorm}")
        logger.info(f"include stride: \t\t\t {not args.no_stride}")
        logger.info(f"include pooling: \t\t\t {not args.no_pooling}")
        logger.info(f"skip connection: \t\t\t {args.skip}")
        logger.info(f"training epochs: \t\t\t {args.epochs}")
        logger.info(f"data augmentation: \t\t\t {args.augment}")
        logger.info(f"batch size: \t\t\t\t {args.batch_size}")
        logger.info(f"controller learning rate: \t\t {args.learning_rate}")
        logger.info(f"architecture episodes: \t\t\t {args.episodes}")
        logger.info(f"using multi gpus: \t\t\t {args.multi_gpu}")
        logger.info(f"architecture space: ")
        for name, value in ARCH_SPACE.items():
            logger.info(name + f": \t\t\t\t {value}")

        agent = Agent(ARCH_SPACE,
                      args.layers,
                      lr=args.learning_rate,
                      device=torch.device('cpu'),
                      skip=args.skip)
        train_data, val_data = data.get_data(args.dataset,
                                             device,
                                             shuffle=True,
                                             batch_size=args.batch_size,
                                             augment=args.augment)

        input_shape, num_classes = data.get_info(args.dataset)
        ## (3,32,32) -> (1,3,32,32) add batch dimension
        sample_input = utility.get_sample_input(device, input_shape)

        ## write header
        if args.adapt:
            writer.writerow(["ID"] +
                            ["Layer {}".format(i) for i in range(args.layers)] +
                            ["Accuracy", "Time", "params", "macs", "reward"])
        else:
            writer.writerow(["ID"] +
                            ["Layer {}".format(i)
                             for i in range(args.layers)] + ["Accuracy", "Time"])

        arch_id, total_time = 0, 0
        best_reward = float('-inf')
        logger.info('=' * 50 + "Start exploring architecture space" +
                    '=' * 50)
        logger.info('-' * len("Start exploring architecture space"))
        best_samples = BestSamples(5)

        for e in range(args.episodes):
            arch_id += 1
            start = time.time()
            arch_rollout, arch_paras = agent.rollout()
            logger.info(
                "Sample Architecture ID: {}, Sampled actions: {}".format(
                    arch_id, arch_rollout))
            ## get model
            model, optimizer = child.get_model(input_shape,
                                               arch_paras,
                                               num_classes,
                                               device,
                                               multi_gpu=args.multi_gpu,
                                               do_bn=args.batchnorm)

            if args.verbosity > 1:
                print(model)
                torchsummary.summary(model, input_shape)

            # num_w / macs only exist when args.adapt is set; they are only
            # read under the same condition below.
            if args.adapt:
                num_w = utility.get_net_param(model)
                macs = utility.get_net_macs(model, sample_input)
                tb_writer.add_scalar('num_param', num_w, arch_id)
                tb_writer.add_scalar('macs', macs, arch_id)
                if args.verbosity > 1:
                    print(f"# of param: {num_w}, macs: {macs}")

            ## train model and get val_acc
            _, val_acc = backend.fit(model,
                                     optimizer,
                                     train_data,
                                     val_data,
                                     epochs=args.epochs,
                                     verbosity=args.verbosity)

            ## TODO: how to model arch_reward?? with num_w and macs?
            # For now the reward is val_acc whether or not args.adapt is set.
            arch_reward = val_acc

            agent.store_rollout(arch_rollout, arch_reward)

            end = time.time()
            ep_time = end - start
            total_time += ep_time

            tb_writer.add_scalar('val_acc', val_acc, arch_id)
            tb_writer.add_scalar('arch_reward', arch_reward, arch_id)

            if arch_reward > best_reward:
                best_reward = arch_reward
                tb_writer.add_scalar('best_reward', best_reward, arch_id)
                tb_writer.add_graph(model.eval(), (sample_input, ))

            best_samples.register(arch_id, arch_rollout, arch_reward)
            if args.adapt:
                writer.writerow([arch_id] +
                                [str(arch_paras[i])
                                 for i in range(args.layers)] + [val_acc] +
                                [ep_time] + [num_w] + [macs] + [arch_reward])
            else:
                writer.writerow([arch_id] +
                                [str(arch_paras[i]) for i in range(args.layers)] +
                                [val_acc] + [ep_time])
            logger.info(f"Architecture Reward: {arch_reward}, " +
                        f"Elasped time: {ep_time}, " +
                        f"Average time: {total_time/(e+1)}")
            logger.info(f"Best Reward: {best_samples.reward_list[0]}, " +
                        f"ID: {best_samples.id_list[0]}, " +
                        f"Rollout: {best_samples.rollout_list[0]}")

            logger.info('-' * len("Start exploring architecture space"))
        logger.info('=' * 50 + "Architecture sapce exploration finished" +
                    '=' * 50)
        logger.info(f"Total elasped time: {total_time}")
        logger.info(f"Best samples: {best_samples}")