Example #1
	def __init__(self):

		args = self.get_args()
		self.args = args
		self.args['crop_size'] = 112
		self.args['num_frames_per_clip'] = 16
		
		#Building Graph
		self.place_holders = PlaceHolders(args)

		inputs = self.place_holders.inference()
		# inputs: [0] image_sequence, [1] raw_image, [2] depth_image, [3] seg_image, [4] speed,
		#         [5] collision, [6] intersection, [7] control, [8] reward, [9] transition

		self.c3d_encoder = C3D_Encoder(args,'c3d_encoder', inputs[0])
		self.c3d_future = C3D_Encoder(args,'c3d_encoder', inputs[9], reuse=True)

		# self.vae = VAE(args, self.c3d_encoder.inference())
		# self.future_vae = VAE(args, self.c3d_future.inference())

		z = self.c3d_encoder.inference()
		self.z = z

		self.raw_decoder = ImageDecoder(args, 'raw_image', z, last=3)
		self.raw_decoder_loss = MSELoss(args, 'raw_image', self.raw_decoder.inference(), inputs[1])

		self.seg_decoder = ImageDecoder(args, 'seg', z, last=13)
		self.seg_decoder_loss = CrossEntropyLoss(args, 'seg', self.seg_decoder.inference(), inputs[3])
		
		self.depth_decoder = ImageDecoder(args, 'depth', z, last=1)
		self.depth_decoder_loss = MSELoss(args, 'depth', self.depth_decoder.inference(), inputs[2])

		self.speed_prediction = MLP(args, 'speed', z, 1, 101)
		self.speed_loss = MSELoss(args, 'speed', self.speed_prediction.inference(), inputs[4])        

		self.collision_prediction = MLP(args, 'collision', z, 1, 101)
		self.collision_loss = MSELoss(args, 'collision', self.collision_prediction.inference(), inputs[5])

		self.intersection_prediction = MLP(args, 'intersection', z, 1, 101)
		self.intersection_loss = MSELoss(args, 'intersection', self.intersection_prediction.inference(), inputs[6])

		self.policy = PG(args, 'policy', z, 13)
		self.log_probs = self.policy.inference()
		self.policy_loss = PGLoss(args, 'policy', inputs[7], inputs[8], self.log_probs)


		# self.value = MLP(args, 'value', z, 1, 101)

		self.transition = MLP(args, 'transition', tf.concat([z, self.log_probs],1), 101, 101)
		self.transition_loss = MSELoss(args, 'transition', self.transition.inference(), self.c3d_future.inference())

		# self.imitation_loss = CrossEntropyLoss(args, self.policy.inference(), inputs[7])
		# self.reward_loss = MSELoss(args, self.value.inference(), inputs[8])


		# # MCTS
		# self.z_mcts = tf.placeholder(tf.float32, shape=(1, 100))
		# self.policy_mcts = MLP(args, 'policy', self.z_mcts, 36, 100).inference()
		# self.value_mcts = MLP(args, 'value', self.z_mcts, 1, 100).inference()
		# self.transition_mcts = MLP(args, 'transition', self.z_mcts, 100, 100).inference()

		# self.mcts = MCTS('mcts', self.policy_inference, self.value_inference, self.transition_inference)
		# self.action = self.mcts.inference()
		#Structures with variables    
		# self.intersection_lane = MLP('intersection_lane')
		# self.intersection_offroad = MLP('intersection_offroad') 

		# Process Steps
		# self.mcts = MCTS('mcts')

		# self.transition = TransitionNetwork('transition')
		# self.policy = PolicyNetwork('policy')
		# self.safety = ValueNetwork('safety')
		# self.goal = ValueNetwork('goal')

		self.variable_parts = [self.c3d_encoder, self.raw_decoder, self.seg_decoder, self.depth_decoder, \
			self.speed_prediction, self.collision_prediction, self.intersection_prediction, self.policy, self.transition]

		self.loss_parts = self.collision_loss.inference() + self.intersection_loss.inference() + self.speed_loss.inference() + self.depth_decoder_loss.inference() + \
			self.raw_decoder_loss.inference() + self.seg_decoder_loss.inference() + self.policy_loss.inference() + self.transition_loss.inference()
				
		weight_decay_loss = tf.get_collection('weightdecay_losses')
		total_loss = self.loss_parts + weight_decay_loss
		tf.summary.scalar('total_loss', tf.reduce_mean(total_loss))

		self.final_ops = []
		for part in self.variable_parts:
			self.final_ops.append(part.optimize(total_loss))
		self.final_ops = tf.group(*self.final_ops)  # unpack the list of ops into tf.group

		config = tf.ConfigProto(allow_soft_placement = True)
		self.sess = tf.Session(config = config)
		self.sess.run(tf.global_variables_initializer())

		print('Restoring!')

		for part in self.variable_parts:
			part.variable_restore(self.sess)

		# Create summary writer
		merged = tf.summary.merge_all()
		writer = tf.summary.FileWriter('./logs/candy', self.sess.graph)

		print('Model Started!')
Example #2
class Machine(object):
	def __init__(self):

		args = self.get_args()
		self.args = args
		self.args['crop_size'] = 112
		self.args['num_frames_per_clip'] = 16
		
		#Building Graph
		self.place_holders = PlaceHolders(args)

		inputs = self.place_holders.inference()
		# inputs: [0] image_sequence, [1] raw_image, [2] depth_image, [3] seg_image, [4] speed,
		#         [5] collision, [6] intersection, [7] control, [8] reward, [9] transition

		self.c3d_encoder = C3D_Encoder(args,'c3d_encoder', inputs[0])
		self.c3d_future = C3D_Encoder(args,'c3d_encoder', inputs[9], reuse=True)

		# self.vae = VAE(args, self.c3d_encoder.inference())
		# self.future_vae = VAE(args, self.c3d_future.inference())

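		# Shared latent code from the C3D encoder; every decoder and prediction
		# head below is built on top of this single representation.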
		z = self.c3d_encoder.inference()
		self.z = z

		self.raw_decoder = ImageDecoder(args, 'raw_image', z, last=3)
		self.raw_decoder_loss = MSELoss(args, 'raw_image', self.raw_decoder.inference(), inputs[1])

		self.seg_decoder = ImageDecoder(args, 'seg', z, last=13)
		self.seg_decoder_loss = CrossEntropyLoss(args, 'seg', self.seg_decoder.inference(), inputs[3])
		
		self.depth_decoder = ImageDecoder(args, 'depth', z, last=1)
		self.depth_decoder_loss = MSELoss(args, 'depth', self.depth_decoder.inference(), inputs[2])

		self.speed_prediction = MLP(args, 'speed', z, 1, 101)
		self.speed_loss = MSELoss(args, 'speed', self.speed_prediction.inference(), inputs[4])        

		self.collision_prediction = MLP(args, 'collision', z, 1, 101)
		self.collision_loss = MSELoss(args, 'collision', self.collision_prediction.inference(), inputs[5])

		self.intersection_prediction = MLP(args, 'intersection', z, 1, 101)
		self.intersection_loss = MSELoss(args, 'intersection', self.intersection_prediction.inference(), inputs[6])

		self.policy = PG(args, 'policy', z, 13)
		self.log_probs = self.policy.inference()
		self.policy_loss = PGLoss(args, 'policy', inputs[7], inputs[8], self.log_probs)


		# self.value = MLP(args, 'value', z, 1, 101)

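		# The transition head predicts the next latent code from the current code
		# concatenated with the policy log-probabilities; its regression target is
		# the C3D encoding of the future clip (inputs[9]).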
		self.transition = MLP(args, 'transition', tf.concat([z, self.log_probs],1), 101, 101)
		self.transition_loss = MSELoss(args, 'transition', self.transition.inference(), self.c3d_future.inference())

		# self.imitation_loss = CrossEntropyLoss(args, self.policy.inference(), inputs[7])
		# self.reward_loss = MSELoss(args, self.value.inference(), inputs[8])


		# # MCTS
		# self.z_mcts = tf.placeholder(tf.float32, shape=(1, 100))
		# self.policy_mcts = MLP(args, 'policy', self.z_mcts, 36, 100).inference()
		# self.value_mcts = MLP(args, 'value', self.z_mcts, 1, 100).inference()
		# self.transition_mcts = MLP(args, 'transition', self.z_mcts, 100, 100).inference()

		# self.mcts = MCTS('mcts', self.policy_inference, self.value_inference, self.transition_inference)
		# self.action = self.mcts.inference()
		#Structures with variables    
		# self.intersection_lane = MLP('intersection_lane')
		# self.intersection_offroad = MLP('intersection_offroad') 

		# Process Steps
		# self.mcts = MCTS('mcts')

		# self.transition = TransitionNetwork('transition')
		# self.policy = PolicyNetwork('policy')
		# self.safety = ValueNetwork('safety')
		# self.goal = ValueNetwork('goal')

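		# Every sub-module holding trainable variables; each contributes its own
		# optimizer op below, and the ops are grouped into one training op.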
		self.variable_parts = [self.c3d_encoder, self.raw_decoder, self.seg_decoder, self.depth_decoder, \
			self.speed_prediction, self.collision_prediction, self.intersection_prediction, self.policy, self.transition]

		self.loss_parts = self.collision_loss.inference() + self.intersection_loss.inference() + self.speed_loss.inference() + self.depth_decoder_loss.inference() + \
			self.raw_decoder_loss.inference() + self.seg_decoder_loss.inference() + self.policy_loss.inference() + self.transition_loss.inference()
				
		weight_decay_loss = tf.get_collection('weightdecay_losses')
		total_loss = self.loss_parts + weight_decay_loss
		tf.summary.scalar('total_loss', tf.reduce_mean(total_loss))

		self.final_ops = []
		for part in self.variable_parts:
			self.final_ops.append(part.optimize(total_loss))
		self.final_ops = tf.group(*self.final_ops)  # unpack the list of ops into tf.group

		config = tf.ConfigProto(allow_soft_placement = True)
		self.sess = tf.Session(config = config)
		self.sess.run(tf.global_variables_initializer())

		print('Restoring!')

		for part in self.variable_parts:
			part.variable_restore(self.sess)

		# Create summary writer
		merged = tf.summary.merge_all()
		writer = tf.summary.FileWriter('./logs/candy', self.sess.graph)

		print('Model Started!')

	def get_args(self):
		with open("args.yaml", 'r') as f:
			try:
				t = yaml.safe_load(f)  # safe_load avoids the unsafe default Loader
				return t
			except yaml.YAMLError as exc:
				print(exc)
		


	def train(self, inputs, global_step):
		self.sess.run(self.final_ops, feed_dict=self.place_holders.get_feed_dict_train(inputs))
		print('Start Saving')
		# for i in self.variable_parts:
		# 	i.saver.save(self.sess, 'my-model', global_step=global_step)
		print('Saving Done.')



	def inference(self, inputs):
		log_probs = self.sess.run(self.log_probs, feed_dict=self.place_holders.get_feed_dict_inference(inputs))
		print(len(log_probs))

		def softmax(x):
			# Shift by the max for numerical stability; the result is mathematically unchanged.
			e_x = np.exp(x - np.max(x))
			return e_x / np.sum(e_x, axis=0)

		log_probs = softmax(log_probs[0])
		action = np.random.choice(range(log_probs.shape[0]), p=log_probs.ravel())  # sample an action according to the predicted probabilities
		return action
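
A minimal sketch of how this Machine class might be driven from a training script, assuming a hypothetical get_batch() helper that returns data in the layout expected by PlaceHolders.get_feed_dict_train / get_feed_dict_inference; the helper and the step count are illustrative, not part of the original project:

# Hypothetical driver loop; get_batch() is an assumed helper, not from the project.
machine = Machine()
for step in range(1000):
	batch = get_batch()                    # must match the PlaceHolders feed-dict layout
	machine.train(batch, global_step=step)
	action = machine.inference(batch)      # index of the action sampled from the policy head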
Example #3
def main(_):
    '''
    Train for one epoch to get the supernet, then randomly sample 50 architectures for fine-tuning.
    This structure is basically the same as train_search.py
    TODO: Add PGD here and calculate FSP
    '''
    # init
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu

    logger = tf.get_logger()
    logger.disabled = True
    logger.setLevel(logging.FATAL)
    set_memory_growth()

    cfg = load_yaml(FLAGS.cfg_path)

    # define network
    sna = SearchNetArch(cfg)
    sna.model.summary(line_length=80)
    print("param size = {:f}MB".format(count_parameters_in_MB(sna.model)))

    # load dataset
    t_split = f"train[0%:{int(cfg['train_portion'] * 100)}%]"
    v_split = f"train[{int(cfg['train_portion'] * 100)}%:100%]"
    train_dataset = load_cifar10_dataset(
        cfg['batch_size'],
        split=t_split,
        shuffle=True,
        drop_remainder=True,
        using_normalize=cfg['using_normalize'],
        using_crop=cfg['using_crop'],
        using_flip=cfg['using_flip'],
        using_cutout=cfg['using_cutout'],
        cutout_length=cfg['cutout_length'])
    val_dataset = load_cifar10_dataset(cfg['batch_size'],
                                       split=v_split,
                                       shuffle=True,
                                       drop_remainder=True,
                                       using_normalize=cfg['using_normalize'],
                                       using_crop=cfg['using_crop'],
                                       using_flip=cfg['using_flip'],
                                       using_cutout=cfg['using_cutout'],
                                       cutout_length=cfg['cutout_length'])

    # define optimizer
    steps_per_epoch = int(cfg['dataset_len'] * cfg['train_portion'] //
                          cfg['batch_size'])
    learning_rate = CosineAnnealingLR(initial_learning_rate=cfg['init_lr'],
                                      t_period=cfg['epoch'] * steps_per_epoch,
                                      lr_min=cfg['lr_min'])
    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
                                        momentum=cfg['momentum'])
    optimizer_arch = tf.keras.optimizers.Adam(
        learning_rate=cfg['arch_learning_rate'], beta_1=0.5, beta_2=0.999)

    # define losses function
    criterion = CrossEntropyLoss()

    # load checkpoint
    checkpoint_dir = './checkpoints/' + cfg['sub_name']
    checkpoint = tf.train.Checkpoint(step=tf.Variable(0, name='step'),
                                     optimizer=optimizer,
                                     optimizer_arch=optimizer_arch,
                                     model=sna.model,
                                     alphas_normal=sna.alphas_normal,
                                     alphas_reduce=sna.alphas_reduce,
                                     betas_normal=sna.betas_normal,
                                     betas_reduce=sna.betas_reduce)
    manager = tf.train.CheckpointManager(checkpoint=checkpoint,
                                         directory=checkpoint_dir,
                                         max_to_keep=3)
    if manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
        print('[*] load ckpt from {} at step {}.'.format(
            manager.latest_checkpoint, checkpoint.step.numpy()))
    else:
        print("[*] training from scratch.")
    print(f"[*] searching model after {cfg['start_search_epoch']} epochs.")

    # define training step function for model
    @tf.function
    def train_step(inputs, labels):
        with tf.GradientTape() as tape:
            logits = sna.model((inputs, *sna.arch_parameters), training=True)

            losses = {}
            losses['reg'] = tf.reduce_sum(sna.model.losses)
            losses['ce'] = criterion(labels, logits)
            total_loss = tf.add_n([l for l in losses.values()])

        grads = tape.gradient(total_loss, sna.model.trainable_variables)
        grads = [(tf.clip_by_norm(grad, cfg['grad_clip'])) for grad in grads]
        optimizer.apply_gradients(zip(grads, sna.model.trainable_variables))

        return logits, total_loss, losses

    # define training step function for arch_parameters
    @tf.function
    def train_step_arch(inputs, labels):
        with tf.GradientTape() as tape:
            logits = sna.model((inputs, *sna.arch_parameters), training=True)

            losses = {}
            losses['reg'] = cfg['arch_weight_decay'] * tf.add_n(
                [tf.reduce_sum(p**2) for p in sna.arch_parameters])
            losses['ce'] = criterion(labels, logits)
            total_loss = tf.add_n([l for l in losses.values()])

        grads = tape.gradient(total_loss, sna.arch_parameters)
        optimizer_arch.apply_gradients(zip(grads, sna.arch_parameters))

        return losses

    summary_writer = tf.summary.create_file_writer('./logs/' + cfg['sub_name'])

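    # Note: train_step and train_step_arch are defined above but are not invoked
    # in this script; it only restores the supernet checkpoint (if one exists)
    # and samples architectures from it.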
    print("[*] finished searching for one epoch")

    print("[*] Start sampling architetures")

    prog_bar = ProgressBar(50, 0)

    # Start sampling for 50 archs
    for geno_num in range(50):
        genotype = sna.get_genotype(random_search_flag=True)
        prog_bar.update(f"\n Sampled {geno_num}th arch: {genotype}")
        # print(f"\n Sampled {geno_num}th arch: {genotype}")
        with open(
                os.path.join('./logs', cfg['sub_name'],
                             'search_random_arch_genotype.py'), 'a') as f:
            f.write(f"\n{cfg['sub_name']}_{geno_num} = {genotype}\n")

    print("Sampling done!")
    debugpy.wait_for_client()
Example #4
def main(_):
    # init
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu

    logger = tf.get_logger()
    logger.disabled = True
    logger.setLevel(logging.FATAL)
    set_memory_growth()

    cfg = load_yaml(FLAGS.cfg_path)

    # define training step function
    @tf.function
    def train_step(inputs, labels, drop_path_prob):
        with tf.GradientTape() as tape:
            logits, logits_aux = model((inputs, drop_path_prob), training=True)

            losses = {}
            losses['reg'] = tf.reduce_sum(model.losses)
            losses['ce'] = criterion(labels, logits)
            losses['ce_auxiliary'] = \
                cfg['auxiliary_weight'] * criterion(labels, logits_aux)
            total_loss = tf.add_n([l for l in losses.values()])

        grads = tape.gradient(total_loss, model.trainable_variables)
        grads = [(tf.clip_by_norm(grad, cfg['grad_clip'])) for grad in grads]
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        return logits, total_loss, losses

    # Used to store the final accuracy for every arch
    final_acc = pd.DataFrame(data=None, columns=['arch_name', 'acc'])

    loop_num = 50

    if Debug:
        # debugpy.wait_for_client()
        loop_num = 1
    # define network
    for arch_num in range(loop_num):
        # read the arch
        arch = str(f"{cfg['sub_name']}_{arch_num}")
        cfg['arch'] = arch

        model = CifarModel(cfg, training=True, file_name=FLAGS.file_name)
        if Debug:
            model.summary(line_length=80)
            print("param size = {:f}MB".format(count_parameters_in_MB(model)))

        # load dataset
        train_dataset = load_cifar10_dataset(
            cfg['batch_size'],
            split='train',
            shuffle=True,
            drop_remainder=True,
            using_normalize=cfg['using_normalize'],
            using_crop=cfg['using_crop'],
            using_flip=cfg['using_flip'],
            using_cutout=cfg['using_cutout'],
            cutout_length=cfg['cutout_length'])
        val_dataset = load_cifar10_dataset(
            cfg['val_batch_size'],
            split='test',
            shuffle=False,
            drop_remainder=False,
            using_normalize=cfg['using_normalize'],
            using_crop=False,
            using_flip=False,
            using_cutout=False)

        # define optimizer
        steps_per_epoch = cfg['dataset_len'] // cfg['batch_size']
        learning_rate = CosineAnnealingLR(initial_learning_rate=cfg['init_lr'],
                                          t_period=cfg['epoch'] *
                                          steps_per_epoch,
                                          lr_min=cfg['lr_min'])
        optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
                                            momentum=cfg['momentum'])

        # define losses function
        criterion = CrossEntropyLoss()

        # load checkpoint
        checkpoint_dir = './checkpoints/' + arch
        checkpoint = tf.train.Checkpoint(step=tf.Variable(0, name='step'),
                                         optimizer=optimizer,
                                         model=model)
        manager = tf.train.CheckpointManager(checkpoint=checkpoint,
                                             directory=checkpoint_dir,
                                             max_to_keep=3)
        if manager.latest_checkpoint:
            checkpoint.restore(manager.latest_checkpoint)
            print('[*] load ckpt from {} at step {}.'.format(
                manager.latest_checkpoint, checkpoint.step.numpy()))
        else:
            print("[*] training from scratch.")

        # training loop
        summary_writer = tf.summary.create_file_writer('./logs/' +
                                                       cfg['sub_name'])
        total_steps = steps_per_epoch * cfg['epoch']
        remain_steps = max(total_steps - checkpoint.step.numpy(), 0)
        prog_bar = ProgressBar(steps_per_epoch,
                               checkpoint.step.numpy() % steps_per_epoch)

        train_acc = AvgrageMeter()
        val_acc = AvgrageMeter()
        best_acc = 0.
        for inputs, labels in train_dataset.take(remain_steps):
            checkpoint.step.assign_add(1)
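            # Linearly ramp drop_path_prob from 0 up to cfg['drop_path_prob'] over training.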
            drop_path_prob = cfg['drop_path_prob'] * (
                tf.cast(checkpoint.step, tf.float32) / total_steps)
            steps = checkpoint.step.numpy()
            epochs = ((steps - 1) // steps_per_epoch) + 1

            logits, total_loss, losses = train_step(inputs, labels,
                                                    drop_path_prob)
            train_acc.update(
                accuracy(logits.numpy(), labels.numpy())[0], cfg['batch_size'])

            prog_bar.update(
                "epoch={}/{}, loss={:.4f}, acc={:.2f}, lr={:.2e}".format(
                    epochs, cfg['epoch'], total_loss.numpy(), train_acc.avg,
                    optimizer.lr(steps).numpy()))

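            # Periodically evaluate on the test split and keep the best-performing weights.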
            if steps % cfg['val_steps'] == 0 and steps > 1:
                print("\n[*] validate...", end='')
                val_acc.reset()
                for inputs_val, labels_val in val_dataset:
                    logits_val, _ = model((inputs_val, tf.constant([0.])))
                    val_acc.update(
                        accuracy(logits_val.numpy(), labels_val.numpy())[0],
                        inputs_val.shape[0])

                if val_acc.avg > best_acc:
                    best_acc = val_acc.avg
                    model.save_weights(
                        f"checkpoints/{cfg['sub_name']}/best.ckpt")

                val_str = " val acc {:.2f}%, best acc {:.2f}%"
                print(val_str.format(val_acc.avg, best_acc), end='')

            if steps % 10 == 0:
                with summary_writer.as_default():
                    tf.summary.scalar('acc/train', train_acc.avg, step=steps)
                    tf.summary.scalar('acc/val', val_acc.avg, step=steps)

                    tf.summary.scalar('loss/total_loss',
                                      total_loss,
                                      step=steps)
                    for k, l in losses.items():
                        tf.summary.scalar('loss/{}'.format(k), l, step=steps)
                    tf.summary.scalar('learning_rate',
                                      optimizer.lr(steps),
                                      step=steps)

            if steps % cfg['save_steps'] == 0:
                manager.save()
                print("\n[*] save ckpt file at {}".format(
                    manager.latest_checkpoint))

            if steps % steps_per_epoch == 0:
                train_acc.reset()

        manager.save()
        print("\n[*] training one arch done! save ckpt file at {}".format(
            manager.latest_checkpoint))
        final_acc.loc[arch_num] = [arch, best_acc]
    print("Whole training ended, the best result is:")
    print("\t", final_acc.iloc[final_acc['acc'].idxmax()])
Example #5
def main(_):
    # init
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu

    logger = tf.get_logger()
    logger.disabled = True
    logger.setLevel(logging.FATAL)
    set_memory_growth()

    cfg = load_yaml(FLAGS.cfg_path)

    # define network
    sna = SearchNetArch(cfg)
    sna.model.summary(line_length=80)
    print("param size = {:f}MB".format(count_parameters_in_MB(sna.model)))

    # load dataset
    t_split = f"train[0%:{int(cfg['train_portion'] * 100)}%]"
    v_split = f"train[{int(cfg['train_portion'] * 100)}%:100%]"
    train_dataset = load_cifar10_dataset(
        cfg['batch_size'],
        split=t_split,
        shuffle=True,
        drop_remainder=True,
        using_normalize=cfg['using_normalize'],
        using_crop=cfg['using_crop'],
        using_flip=cfg['using_flip'],
        using_cutout=cfg['using_cutout'],
        cutout_length=cfg['cutout_length'])
    val_dataset = load_cifar10_dataset(cfg['batch_size'],
                                       split=v_split,
                                       shuffle=True,
                                       drop_remainder=True,
                                       using_normalize=cfg['using_normalize'],
                                       using_crop=cfg['using_crop'],
                                       using_flip=cfg['using_flip'],
                                       using_cutout=cfg['using_cutout'],
                                       cutout_length=cfg['cutout_length'])

    # define optimizer
    steps_per_epoch = int(cfg['dataset_len'] * cfg['train_portion'] //
                          cfg['batch_size'])
    learning_rate = CosineAnnealingLR(initial_learning_rate=cfg['init_lr'],
                                      t_period=cfg['epoch'] * steps_per_epoch,
                                      lr_min=cfg['lr_min'])
    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
                                        momentum=cfg['momentum'])
    optimizer_arch = tf.keras.optimizers.Adam(
        learning_rate=cfg['arch_learning_rate'], beta_1=0.5, beta_2=0.999)

    # define losses function
    criterion = CrossEntropyLoss()

    # load checkpoint
    checkpoint_dir = './checkpoints/' + cfg['sub_name']
    checkpoint = tf.train.Checkpoint(step=tf.Variable(0, name='step'),
                                     optimizer=optimizer,
                                     optimizer_arch=optimizer_arch,
                                     model=sna.model,
                                     alphas_normal=sna.alphas_normal,
                                     alphas_reduce=sna.alphas_reduce,
                                     betas_normal=sna.betas_normal,
                                     betas_reduce=sna.betas_reduce)
    manager = tf.train.CheckpointManager(checkpoint=checkpoint,
                                         directory=checkpoint_dir,
                                         max_to_keep=3)
    if manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
        print('[*] load ckpt from {} at step {}.'.format(
            manager.latest_checkpoint, checkpoint.step.numpy()))
    else:
        print("[*] training from scratch.")
    print(f"[*] searching model after {cfg['start_search_epoch']} epochs.")

    # define training step function for model
    @tf.function
    def train_step(inputs, labels):
        with tf.GradientTape() as tape:
            logits = sna.model((inputs, *sna.arch_parameters), training=True)

            losses = {}
            losses['reg'] = tf.reduce_sum(sna.model.losses)
            losses['ce'] = criterion(labels, logits)
            total_loss = tf.add_n([l for l in losses.values()])

        grads = tape.gradient(total_loss, sna.model.trainable_variables)
        grads = [(tf.clip_by_norm(grad, cfg['grad_clip'])) for grad in grads]
        optimizer.apply_gradients(zip(grads, sna.model.trainable_variables))

        return logits, total_loss, losses

    # define training step function for arch_parameters
    @tf.function
    def train_step_arch(inputs, labels):
        with tf.GradientTape() as tape:
            logits = sna.model((inputs, *sna.arch_parameters), training=True)

            losses = {}
            losses['reg'] = cfg['arch_weight_decay'] * tf.add_n(
                [tf.reduce_sum(p**2) for p in sna.arch_parameters])
            losses['ce'] = criterion(labels, logits)
            total_loss = tf.add_n([l for l in losses.values()])

        grads = tape.gradient(total_loss, sna.arch_parameters)
        optimizer_arch.apply_gradients(zip(grads, sna.arch_parameters))

        return losses

    # training loop
    summary_writer = tf.summary.create_file_writer('./logs/' + cfg['sub_name'])
    total_steps = steps_per_epoch * cfg['epoch']
    remain_steps = max(total_steps - checkpoint.step.numpy(), 0)
    prog_bar = ProgressBar(steps_per_epoch,
                           checkpoint.step.numpy() % steps_per_epoch)

    train_acc = AvgrageMeter()
    for inputs, labels in train_dataset.take(remain_steps):
        checkpoint.step.assign_add(1)
        steps = checkpoint.step.numpy()
        epochs = ((steps - 1) // steps_per_epoch) + 1

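        # After the warm-up epochs, update the architecture parameters on a batch
        # from the validation split before updating the model weights (DARTS-style
        # bilevel optimization).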
        if epochs > cfg['start_search_epoch']:
            inputs_val, labels_val = next(iter(val_dataset))
            arch_losses = train_step_arch(inputs_val, labels_val)

        logits, total_loss, losses = train_step(inputs, labels)
        train_acc.update(
            accuracy(logits.numpy(), labels.numpy())[0], cfg['batch_size'])

        prog_bar.update(
            "epoch={:d}/{:d}, loss={:.4f}, acc={:.2f}, lr={:.2e}".format(
                epochs, cfg['epoch'], total_loss.numpy(), train_acc.avg,
                optimizer.lr(steps).numpy()))

        if steps % 10 == 0:
            with summary_writer.as_default():
                tf.summary.scalar('acc/train', train_acc.avg, step=steps)

                tf.summary.scalar('loss/total_loss', total_loss, step=steps)
                for k, l in losses.items():
                    tf.summary.scalar('loss/{}'.format(k), l, step=steps)
                tf.summary.scalar('learning_rate',
                                  optimizer.lr(steps),
                                  step=steps)

                if epochs > cfg['start_search_epoch']:
                    for k, l in arch_losses.items():
                        tf.summary.scalar('arch_losses/{}'.format(k),
                                          l,
                                          step=steps)
                    tf.summary.scalar('arch_learning_rate',
                                      cfg['arch_learning_rate'],
                                      step=steps)

        if steps % cfg['save_steps'] == 0:
            manager.save()
            print("\n[*] save ckpt file at {}".format(
                manager.latest_checkpoint))

        if steps % steps_per_epoch == 0:
            train_acc.reset()
            if epochs > cfg['start_search_epoch']:
                genotype = sna.get_genotype()
                print(f"\nsearch arch: {genotype}")
                with open(
                        os.path.join('./logs', cfg['sub_name'],
                                     'search_arch_genotype.py'), 'a') as f:
                    f.write(f"\n{cfg['sub_name']}_{epochs} = {genotype}\n")

    manager.save()
    print("\n[*] training done! save ckpt file at {}".format(
        manager.latest_checkpoint))