# Imports assumed by the snippets in this section; `utils` holds the
# project-local helpers, and the model classes (NAO, Graph2Seq) are defined
# elsewhere in the repo.
import logging
import os
import random
import sys

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F

import utils  # project-local helpers
from nasbench import api  # NASBench-101 API


def generate_synthetic_controller_data(nasbench, model, base_arch=None,
                                       random_arch=0, direction='+'):
    """Label randomly generated architectures with the EPD's predicted value."""
    # NOTE: `direction` is kept for interface compatibility but is unused here.
    base_arch = base_arch or []  # guard: `seq not in None` would raise
    random_synthetic_input = []
    random_synthetic_target = []
    if random_arch > 0:
        while len(random_synthetic_input) < random_arch:
            seq = utils.generate_arch(1, nasbench)[1][0]
            # Skip duplicates and anything already in the labeled pool.
            if seq not in random_synthetic_input and seq not in base_arch:
                random_synthetic_input.append(seq)
        controller_synthetic_dataset = utils.ControllerDataset(
            random_synthetic_input, None, False)
        controller_synthetic_queue = torch.utils.data.DataLoader(
            controller_synthetic_dataset,
            batch_size=len(controller_synthetic_dataset),
            shuffle=False,
            pin_memory=True)
        with torch.no_grad():
            model.eval()
            for sample in controller_synthetic_queue:
                encoder_input = sample['encoder_input'].cuda()
                _, _, _, predict_value = model.encoder(encoder_input)
                random_synthetic_target += predict_value.data.squeeze().tolist()
        assert len(random_synthetic_input) == len(random_synthetic_target)
    synthetic_input = random_synthetic_input
    synthetic_target = random_synthetic_target
    assert len(synthetic_input) == len(synthetic_target)
    return synthetic_input, synthetic_target
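# A minimal sketch of how the synthetic labels extend the labeled pool,
# mirroring the mixing logic used in main() below. `build_epd_training_set`
# is a hypothetical helper, and the 10x up-sample factor is illustrative,
# not a repo default.
def build_epd_training_set(nasbench, model, seed_seqs, seed_targets,
                           random_arch=1000, up_sample_ratio=10):
    syn_input, syn_target = generate_synthetic_controller_data(
        nasbench, model, base_arch=seed_seqs, random_arch=random_arch)
    # Repeat the real data so it is not swamped by synthetic samples.
    all_input = seed_seqs * up_sample_ratio + syn_input
    all_target = seed_targets * up_sample_ratio + syn_target
    return all_input, all_target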
def train_controller(model, train_input, train_target, epochs):
    logging.info('Train data: {}'.format(len(train_input)))
    controller_train_dataset = utils.ControllerDataset(train_input,
                                                       train_target, True)
    controller_train_queue = torch.utils.data.DataLoader(
        controller_train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.l2_reg)
    for epoch in range(1, epochs + 1):
        loss, mse, ce = controller_train(controller_train_queue, model,
                                         optimizer)
        logging.info("epoch %04d train loss %.6f mse %.6f ce %.6f", epoch,
                     loss, mse, ce)
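# `controller_train` is defined elsewhere in the repo; for reference, a
# minimal sketch of one epoch in the style of the reference NAO
# implementation, assuming the forward pass returns
# (predict_value, log_prob, arch) and that `args.trade_off` weights the MSE
# (performance prediction) and NLL (sequence reconstruction) losses. The
# batch keys and signature here are assumptions, not this repo's exact code.
def controller_train_sketch(train_queue, model, optimizer):
    objs = utils.AvgrageMeter()
    mse = utils.AvgrageMeter()
    nll = utils.AvgrageMeter()
    model.train()
    for sample in train_queue:
        encoder_input = sample['encoder_input'].cuda()
        encoder_target = sample['encoder_target'].cuda()
        decoder_input = sample['decoder_input'].cuda()
        decoder_target = sample['decoder_target'].cuda()
        optimizer.zero_grad()
        predict_value, log_prob, _ = model(encoder_input, decoder_input)
        # Regression loss on the normalized accuracy target.
        loss_1 = F.mse_loss(predict_value.squeeze(), encoder_target.squeeze())
        # Reconstruction loss over the decoded architecture sequence.
        loss_2 = F.nll_loss(log_prob.contiguous().view(-1, log_prob.size(-1)),
                            decoder_target.view(-1))
        loss = args.trade_off * loss_1 + (1 - args.trade_off) * loss_2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_bound)
        optimizer.step()
        n = encoder_input.size(0)
        objs.update(loss.data, n)
        mse.update(loss_1.data, n)
        nll.update(loss_2.data, n)
    return objs.avg, mse.avg, nll.avg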
def generate_synthetic_controller_data(nasbench, model, base_arch=None,
                                       random_arch=0):
    """Variant that samples architectures directly from the NASBench index."""
    base_arch = base_arch or []  # guard: `seq not in None` would raise
    random_synthetic_input = []
    random_synthetic_target = []
    if random_arch > 0:
        all_keys = list(nasbench.hash_iterator())
        np.random.shuffle(all_keys)
        i = 0
        while len(random_synthetic_input) < random_arch:
            key = all_keys[i]
            fs, cs = nasbench.get_metrics_from_hash(key)
            seq = utils.convert_arch_to_seq(fs['module_adjacency'],
                                            fs['module_operations'])
            if seq not in random_synthetic_input and seq not in base_arch:
                random_synthetic_input.append(seq)
            i += 1
        nao_synthetic_dataset = utils.ControllerDataset(random_synthetic_input,
                                                        None, False)
        nao_synthetic_queue = torch.utils.data.DataLoader(
            nao_synthetic_dataset,
            batch_size=len(nao_synthetic_dataset),
            shuffle=False,
            pin_memory=True)
        with torch.no_grad():
            model.eval()
            for sample in nao_synthetic_queue:
                encoder_input = sample['encoder_input'].cuda()
                _, _, _, predict_value = model.encoder(encoder_input)
                random_synthetic_target += predict_value.data.squeeze().tolist()
        assert len(random_synthetic_input) == len(random_synthetic_target)
    synthetic_input = random_synthetic_input
    synthetic_target = random_synthetic_target
    assert len(synthetic_input) == len(synthetic_target)
    return synthetic_input, synthetic_target
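# Why the mains below set source_length = (n + 2) * (n - 1) // 2: assuming
# the usual NAO flattening (as in utils.convert_arch_to_seq), an n-node cell
# becomes n*(n-1)/2 upper-triangular connection tokens plus n-1 operation
# tokens (the input node has no operation), and
#     n*(n-1)/2 + (n-1) == (n + 2) * (n - 1) / 2.
# Quick check for the 7-node NASBench cell (21 edge bits + 6 op tokens):
_n = 7
assert _n * (_n - 1) // 2 + (_n - 1) == (_n + 2) * (_n - 1) // 2 == 27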
def main(mode, data_file_path):
    # e.g. mode = "train", data_file_path = "data/train_data.json"
    random.seed(conf.seed)
    np.random.seed(conf.seed)
    torch.manual_seed(conf.seed)
    logging.info("conf = %s", conf)
    conf.source_length = conf.encoder_length = conf.decoder_length = (
        (conf.graph_size + 2) * (conf.graph_size - 1) // 2)
    epochs = conf.epochs
    model = Graph2Seq(mode=mode, conf=conf)

    # Load data.
    dataset = utils.ControllerDataset(data_file_path)
    queue = torch.utils.data.DataLoader(dataset,
                                        batch_size=conf.batch_size,
                                        shuffle=True,
                                        pin_memory=True,
                                        collate_fn=utils.collate_fn)

    if mode == "train":
        model.train()
        logging.info('Train data: {}'.format(len(dataset)))
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=conf.learning_rate,
                                     weight_decay=conf.l2_reg)

        def train_step(train_queue, optimizer):
            objs = utils.AvgrageMeter()
            nll = utils.AvgrageMeter()
            for step, sample in enumerate(train_queue):
                fw_adjs = sample['fw_adjs']
                bw_adjs = sample['bw_adjs']
                operations = sample['operations']
                num_nodes = sample['num_nodes']
                sequence = sample['sequence']
                optimizer.zero_grad()
                log_prob, predicted_value = model(fw_adjs, bw_adjs,
                                                  operations, num_nodes,
                                                  targets=sequence)
                # Token-level negative log-likelihood over the flattened
                # output sequence.
                loss = F.nll_loss(
                    log_prob.contiguous().view(-1, log_prob.size(-1)),
                    sequence.view(-1))
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               conf.grad_bound)
                optimizer.step()
                n = sequence.size(0)
                objs.update(loss.data, n)
                nll.update(loss.data, n)
                # logging.info("step %04d objs %.6f nll %.6f",
                #              step, objs.avg, nll.avg)
            return objs.avg, nll.avg

        for epoch in range(1, epochs + 1):
            loss, ce = train_step(queue, optimizer)
            logging.info("epoch %04d train loss %.6f ce %.6f", epoch, loss, ce)
        # Save trainable parameters.
        torch.save(model.state_dict(), conf.model_path)

    if mode == "test":
        model.load_state_dict(torch.load(conf.model_path))
        model.eval()

        def test_step(test_queue):
            # Exact-match accuracy: a sample counts only if every token of
            # the decoded sequence equals the target.
            match = 0
            total = 0
            with torch.no_grad():
                for step, sample in enumerate(test_queue):
                    fw_adjs = sample['fw_adjs']
                    bw_adjs = sample['bw_adjs']
                    operations = sample['operations']
                    num_nodes = sample['num_nodes']
                    sequence = sample['sequence']
                    log_prob, predicted_value = model(fw_adjs, bw_adjs,
                                                      operations, num_nodes)
                    # torch.equal returns a single bool, so compare
                    # element-wise and reduce per sample instead.
                    match += (predicted_value == sequence).all(
                        dim=1).sum().item()
                    total += len(num_nodes)
            return match / total

        logging.info('Test data: {}'.format(len(dataset)))
        # Evaluation is deterministic, so a single pass suffices.
        accuracy = test_step(queue)
        logging.info("test accuracy %.6f", accuracy)
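# `utils.collate_fn` is referenced above but not shown here. A minimal sketch
# of what it needs to do: pad variable-size graphs and sequences to batch
# maxima. Item keys ('fw_adj', 'operations', ...) and the use of the
# transposed forward adjacency as the backward adjacency are assumptions,
# not the repo's exact code.
def collate_fn_sketch(batch):
    max_nodes = max(item['num_nodes'] for item in batch)
    max_len = max(len(item['sequence']) for item in batch)
    fw_adjs, bw_adjs, operations, num_nodes, sequences = [], [], [], [], []
    for item in batch:
        n = item['num_nodes']
        adj = torch.zeros(max_nodes, max_nodes)
        adj[:n, :n] = torch.as_tensor(item['fw_adj'], dtype=torch.float)
        fw_adjs.append(adj)
        bw_adjs.append(adj.t())  # reversed edges for the backward GNN pass
        ops = torch.zeros(max_nodes, dtype=torch.long)
        ops[:n] = torch.as_tensor(item['operations'])
        operations.append(ops)
        num_nodes.append(n)
        seq = torch.zeros(max_len, dtype=torch.long)
        seq[:len(item['sequence'])] = torch.as_tensor(item['sequence'])
        sequences.append(seq)
    return {
        'fw_adjs': torch.stack(fw_adjs),
        'bw_adjs': torch.stack(bw_adjs),
        'operations': torch.stack(operations),
        'num_nodes': torch.tensor(num_nodes),
        'sequence': torch.stack(sequences),
    }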
def main():
    if not torch.cuda.is_available():
        logging.info('No GPU found!')
        sys.exit(1)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = True
    logging.info("Args = %s", args)

    args.source_length = args.encoder_length = args.decoder_length = (
        (args.nodes + 2) * (args.nodes - 1) // 2)
    nasbench = api.NASBench(os.path.join(args.data, 'nasbench_full.tfrecord'))

    controller = NAO(
        args.encoder_layers,
        args.decoder_layers,
        args.mlp_layers,
        args.mlp_hidden_size,
        args.hidden_size,
        args.vocab_size,
        args.dropout,
        args.source_length,
        args.encoder_length,
        args.decoder_length,
    )
    logging.info("param size = %d", utils.count_parameters(controller))
    controller = controller.cuda()

    child_arch_pool, child_seq_pool, child_arch_pool_valid_acc = utils.generate_arch(
        args.seed_arch, nasbench, need_perf=True)

    arch_pool = []
    seq_pool = []
    arch_pool_valid_acc = []
    for i in range(args.iteration + 1):
        logging.info('Iteration {}'.format(i + 1))
        if not child_arch_pool_valid_acc:
            for arch in child_arch_pool:
                data = nasbench.query(arch)
                child_arch_pool_valid_acc.append(data['validation_accuracy'])

        arch_pool += child_arch_pool
        arch_pool_valid_acc += child_arch_pool_valid_acc
        seq_pool += child_seq_pool

        # Sort the pool by validation accuracy, best first.
        arch_pool_valid_acc_sorted_indices = np.argsort(
            arch_pool_valid_acc)[::-1]
        arch_pool = [arch_pool[j] for j in arch_pool_valid_acc_sorted_indices]
        seq_pool = [seq_pool[j] for j in arch_pool_valid_acc_sorted_indices]
        arch_pool_valid_acc = [
            arch_pool_valid_acc[j] for j in arch_pool_valid_acc_sorted_indices
        ]

        with open(os.path.join(args.output_dir, 'arch_pool.{}'.format(i)),
                  'w') as fa:
            for arch, seq, valid_acc in zip(arch_pool, seq_pool,
                                            arch_pool_valid_acc):
                fa.write('{}\t{}\t{}\t{}\n'.format(arch.matrix, arch.ops, seq,
                                                   valid_acc))

        print('Top 10 architectures:')
        for arch_index in range(10):
            print('Architecture connection:{}'.format(
                arch_pool[arch_index].matrix))
            print('Architecture operations:{}'.format(
                arch_pool[arch_index].ops))
            print('Valid accuracy:{}'.format(arch_pool_valid_acc[arch_index]))

        if i == args.iteration:
            print('Final architectures:')
            for arch_index in range(10):
                print('Architecture connection:{}'.format(
                    arch_pool[arch_index].matrix))
                print('Architecture operations:{}'.format(
                    arch_pool[arch_index].ops))
                print('Valid accuracy:{}'.format(
                    arch_pool_valid_acc[arch_index]))
                fs, cs = nasbench.get_metrics_from_spec(arch_pool[arch_index])
                # Average the final test accuracy over the three runs at
                # 108 epochs.
                test_acc = np.mean(
                    [cs[108][j]['final_test_accuracy'] for j in range(3)])
                print('Mean test accuracy:{}'.format(test_acc))
            break

        train_encoder_input = seq_pool
        # Min-max normalize validation accuracies to [0, 1] for the EPD
        # target.
        min_val = min(arch_pool_valid_acc)
        max_val = max(arch_pool_valid_acc)
        train_encoder_target = [(acc - min_val) / (max_val - min_val)
                                for acc in arch_pool_valid_acc]

        # Pre-train
        logging.info('Pre-train EPD')
        train_controller(controller, train_encoder_input,
                         train_encoder_target, args.pretrain_epochs)
        logging.info('Finish pre-training EPD')

        # Generate synthetic data
        logging.info('Generate synthetic data for EPD')
        synthetic_encoder_input, synthetic_encoder_target = generate_synthetic_controller_data(
            nasbench, controller, train_encoder_input, args.random_arch)
        if args.up_sample_ratio is None:
            up_sample_ratio = int(
                np.ceil(args.random_arch / len(train_encoder_input)))
        else:
            up_sample_ratio = args.up_sample_ratio

        # Up-sample the real data so it is not swamped by synthetic samples.
        all_encoder_input = (train_encoder_input * up_sample_ratio +
                             synthetic_encoder_input)
        all_encoder_target = (train_encoder_target * up_sample_ratio +
                              synthetic_encoder_target)

        # Train
        logging.info('Train EPD')
        train_controller(controller, all_encoder_input, all_encoder_target,
                         args.epochs)
        logging.info('Finish training EPD')

        new_archs = []
        new_seqs = []
        predict_step_size = 0
        # Rank the union of real and synthetic samples by target value and
        # keep the top-k as starting points for gradient-based inference.
        unique_input = train_encoder_input + synthetic_encoder_input
        unique_target = train_encoder_target + synthetic_encoder_target
        unique_indices = np.argsort(unique_target)[::-1]
        unique_input = [unique_input[j] for j in unique_indices]
        topk_archs = unique_input[:args.k]
        controller_infer_dataset = utils.ControllerDataset(topk_archs, None,
                                                           False)
        controller_infer_queue = torch.utils.data.DataLoader(
            controller_infer_dataset,
            batch_size=len(controller_infer_dataset),
            shuffle=False,
            pin_memory=True)

        while len(new_archs) < args.new_arch:
            predict_step_size += 1
            logging.info('Generate new architectures with step size %d',
                         predict_step_size)
            new_seq, new_perfs = controller_infer(controller_infer_queue,
                                                  controller,
                                                  predict_step_size,
                                                  direction='+')
            for seq in new_seq:
                matrix, ops = utils.convert_seq_to_arch(seq)
                arch = api.ModelSpec(matrix=matrix, ops=ops)
                # Keep only valid 7-node specs that are not already known.
                if (nasbench.is_valid(arch) and len(arch.ops) == 7
                        and seq not in train_encoder_input
                        and seq not in new_seqs):
                    new_archs.append(arch)
                    new_seqs.append(seq)
                if len(new_seqs) >= args.new_arch:
                    break
            logging.info('%d new archs generated now', len(new_archs))
            if predict_step_size > args.max_step_size:
                break

        child_arch_pool = new_archs
        child_seq_pool = new_seqs
        child_arch_pool_valid_acc = []
        child_arch_pool_test_acc = []
        logging.info("Generate %d new archs", len(child_arch_pool))
        print(nasbench.get_budget_counters())
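# `controller_infer` is defined elsewhere in the repo. Conceptually (per the
# NAO method), it moves each top-k architecture embedding along the gradient
# of the predicted accuracy and decodes the shifted embedding back into a
# sequence; `step` scales the move. A sketch under that assumption, assuming
# the model exposes a `generate_new_arch(encoder_input, step, direction)`
# helper as in the reference NAO implementation:
def controller_infer_sketch(queue, model, step, direction='+'):
    new_seq_list = []
    new_predict_values = []
    model.eval()  # disable dropout; gradients w.r.t. the embedding still flow
    for sample in queue:
        encoder_input = sample['encoder_input'].cuda()
        model.zero_grad()
        new_seq, new_value = model.generate_new_arch(encoder_input, step,
                                                     direction=direction)
        new_seq_list.extend(new_seq.data.squeeze().tolist())
        new_predict_values.extend(new_value.data.squeeze().tolist())
    return new_seq_list, new_predict_values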