Example #1
# Imports needed by the example below. The project-local helpers (utils,
# Network, init_dist, broadcast_params, reduce_tensorgradients) and the
# module-level constants (CIFAR_CLASSES, log_format, generate_date,
# default_lambda) are assumptions inferred from how they are used here.
import glob
import logging
import os
import random
import sys

import numpy as np
import tensorboardX
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dset
from torch.utils.data.distributed import DistributedSampler

class neural_architecture_search():
    def __init__(self, args):
        self.args = args

        if not torch.cuda.is_available():
            logging.info('no gpu device available')
            sys.exit(1)

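        # process-group setup: under --distributed, rank / world size / device
        # come from init_dist() and each rank derives its own seed; otherwise
        # pin a single GPU and behave like a world of size 1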
        if self.args.distributed:
            # Init distributed environment
            self.rank, self.world_size, self.device = init_dist(
                port=self.args.port)
            self.seed = self.rank * self.args.seed
        else:
            torch.cuda.set_device(self.args.gpu)
            self.device = torch.device("cuda")
            self.rank = 0
            self.seed = self.args.seed
            self.world_size = 1

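        # seeding: the fix_seedcudnn path trades speed for reproducibility
        # (deterministic kernels, cudnn.benchmark off); otherwise keep
        # benchmark mode enabled for faster convolutions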
        if self.args.fix_seedcudnn:
            random.seed(self.seed)
            torch.backends.cudnn.deterministic = True
            np.random.seed(self.seed)
            cudnn.benchmark = False
            torch.manual_seed(self.seed)
            cudnn.enabled = True
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)
        else:
            np.random.seed(self.seed)
            cudnn.benchmark = True
            torch.manual_seed(self.seed)
            cudnn.enabled = True
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)

        self.path = os.path.join(generate_date, self.args.save)
        if self.rank == 0:
            utils.create_exp_dir(generate_date,
                                 self.path,
                                 scripts_to_save=glob.glob('*.py'))
            logging.basicConfig(stream=sys.stdout,
                                level=logging.INFO,
                                format=log_format,
                                datefmt='%m/%d %I:%M:%S %p')
            fh = logging.FileHandler(os.path.join(self.path, 'log.txt'))
            fh.setFormatter(logging.Formatter(log_format))
            logging.getLogger().addHandler(fh)
            logging.info("self.args = %s", self.args)
            self.logger = tensorboardX.SummaryWriter(
                './runs/' + generate_date + '/nas_{}'.format(self.args.remark))
        else:
            self.logger = None

        # set default resource_lambda for different methods
        if self.args.resource_efficient:
            if self.args.method == 'policy_gradient':
                default_resource_lambda = 1e-4 if self.args.log_penalty else 1e-5
            elif self.args.method == 'reparametrization':
                default_resource_lambda = 1e-2 if self.args.log_penalty else 1e-5
            elif self.args.method == 'discrete':
                default_resource_lambda = 1e-2 if self.args.log_penalty else 1e-4
            else:
                default_resource_lambda = self.args.resource_lambda
            # default_lambda is assumed to be the module-level argparse default
            # of --resource_lambda, so the override only fires when the user
            # left the flag untouched
            if self.args.resource_lambda == default_lambda:
                self.args.resource_lambda = default_resource_lambda

        #initialize loss function
        self.criterion = nn.CrossEntropyLoss().to(self.device)

        #initialize model
        self.init_model()

        #calculate model param size
        if self.rank == 0:
            logging.info("param size = %fMB",
                         utils.count_parameters_in_MB(self.model))
            self.model._logger = self.logger
            self.model._logging = logging

        #initialize optimizer
        self.init_optimizer()

        #initialize dataset loader
        self.init_loaddata()

        self.update_theta = True
        self.update_alpha = True

    def init_model(self):

        self.model = Network(self.args.init_channels, CIFAR_CLASSES,
                             self.args.layers, self.criterion, self.args,
                             self.rank, self.world_size)
        self.model.to(self.device)
        if self.args.distributed:
            broadcast_params(self.model)
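        # pre-allocate zero gradient buffers (including the architecture
        # log-alphas below) so that, in the distributed case, every rank
        # reduces an identical set of tensors even before the first backward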
        for v in self.model.parameters():
            if v.requires_grad:
                if v.grad is None:
                    v.grad = torch.zeros_like(v)
        self.model.normal_log_alpha.grad = torch.zeros_like(
            self.model.normal_log_alpha)
        self.model.reduce_log_alpha.grad = torch.zeros_like(
            self.model.reduce_log_alpha)

    def init_optimizer(self):

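        # network weights are trained with SGD; the architecture parameters
        # (normal_log_alpha / reduce_log_alpha) get their own optimizer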
        if self.args.distributed:
            self.optimizer = torch.optim.SGD(
                [
                    param for name, param in self.model.named_parameters() if
                    name != 'normal_log_alpha' and name != 'reduce_log_alpha'
                ],
                self.args.learning_rate,
                momentum=self.args.momentum,
                weight_decay=self.args.weight_decay)
            self.arch_optimizer = torch.optim.Adam(
                [
                    param for name, param in self.model.named_parameters()
                    if name == 'normal_log_alpha' or name == 'reduce_log_alpha'
                ],
                lr=self.args.arch_learning_rate,
                betas=(0.5, 0.999),
                weight_decay=self.args.arch_weight_decay)
        else:
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             self.args.learning_rate,
                                             momentum=self.args.momentum,
                                             weight_decay=self.args.weight_decay)

            self.arch_optimizer = torch.optim.SGD(
                self.model.arch_parameters(), lr=self.args.arch_learning_rate)

    def init_loaddata(self):

        train_transform, valid_transform = utils._data_transforms_cifar10(
            self.args)
        train_data = dset.CIFAR10(root=self.args.data,
                                  train=True,
                                  download=True,
                                  transform=train_transform)
        valid_data = dset.CIFAR10(root=self.args.data,
                                  train=False,
                                  download=True,
                                  transform=valid_transform)

        if self.args.seed:

            def worker_init_fn(worker_id):
                # the DataLoader passes a per-worker id; reseed every worker
                # so the augmentation pipeline is reproducible
                seed = self.seed
                np.random.seed(seed)
                random.seed(seed)
                torch.manual_seed(seed)
                return
        else:
            worker_init_fn = None

        if self.args.distributed:
            train_sampler = DistributedSampler(train_data)
            valid_sampler = DistributedSampler(valid_data)

            self.train_queue = torch.utils.data.DataLoader(
                train_data,
                batch_size=self.args.batch_size // self.world_size,
                shuffle=False,
                num_workers=0,
                pin_memory=False,
                sampler=train_sampler)
            self.valid_queue = torch.utils.data.DataLoader(
                valid_data,
                batch_size=self.args.batch_size // self.world_size,
                shuffle=False,
                num_workers=0,
                pin_memory=False,
                sampler=valid_sampler)

        else:
            self.train_queue = torch.utils.data.DataLoader(
                train_data,
                batch_size=self.args.batch_size,
                shuffle=True,
                pin_memory=False,
                num_workers=2,
                worker_init_fn=worker_init_fn)

            self.valid_queue = torch.utils.data.DataLoader(
                valid_data,
                batch_size=self.args.batch_size,
                shuffle=False,
                pin_memory=False,
                num_workers=2,
                worker_init_fn=worker_init_fn)

    def main(self):
        # lr scheduler: cosine annealing
        # temp scheduler: linear annealing (self-defined in utils)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            float(self.args.epochs),
            eta_min=self.args.learning_rate_min)

        self.temp_scheduler = utils.Temp_Scheduler(self.args.epochs,
                                                   self.model._temp,
                                                   self.args.temp,
                                                   temp_min=self.args.temp_min)

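        # per-epoch flow: log the current alphas and the derived genotype,
        # train, validate (optionally also the argmax child), then checkpoint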
        for epoch in range(self.args.epochs):
            if self.args.random_sample_pretrain:
                if epoch < self.args.random_sample_pretrain_epoch:
                    self.args.random_sample = True
                else:
                    self.args.random_sample = False

            self.scheduler.step()
            if self.args.temp_annealing:
                self.model._temp = self.temp_scheduler.step()
            self.lr = self.scheduler.get_lr()[0]

            if self.rank == 0:
                logging.info('epoch %d lr %e temp %e', epoch, self.lr,
                             self.model._temp)
                self.logger.add_scalar('epoch_temp', self.model._temp, epoch)
                logging.info(self.model.normal_log_alpha)
                logging.info(self.model.reduce_log_alpha)
                logging.info(
                    self.model._get_weights(self.model.normal_log_alpha[0]))
                logging.info(
                    self.model._get_weights(self.model.reduce_log_alpha[0]))

            genotype_edge_all = self.model.genotype_edge_all()

            if self.rank == 0:
                logging.info('genotype_edge_all = %s', genotype_edge_all)
                # create genotypes.txt file
                txt_name = self.args.remark + '_genotype_edge_all_epoch' + str(
                    epoch)
                utils.txt('genotype', self.args.save, txt_name,
                          str(genotype_edge_all), generate_date)

            self.model.train()
            train_acc, loss, error_loss, loss_alpha = self.train(
                epoch, logging)
            if self.rank == 0:
                logging.info('train_acc %f', train_acc)
                self.logger.add_scalar("epoch_train_acc", train_acc, epoch)
                self.logger.add_scalar("epoch_train_error_loss", error_loss,
                                       epoch)
                if self.args.dsnas:
                    self.logger.add_scalar("epoch_train_alpha_loss",
                                           loss_alpha, epoch)

            # validation
            self.model.eval()
            valid_acc, valid_obj = self.infer(epoch)
            if self.args.gen_max_child:
                self.args.gen_max_child_flag = True
                valid_acc_max_child, valid_obj_max_child = self.infer(epoch)
                self.args.gen_max_child_flag = False

            if self.rank == 0:
                logging.info('valid_acc %f', valid_acc)
                self.logger.add_scalar("epoch_valid_acc", valid_acc, epoch)
                if self.args.gen_max_child:
                    logging.info('valid_acc_argmax_alpha %f',
                                 valid_acc_max_child)
                    self.logger.add_scalar("epoch_valid_acc_argmax_alpha",
                                           valid_acc_max_child, epoch)

                utils.save(self.model, os.path.join(self.path, 'weights.pt'))

        if self.rank == 0:
            logging.info(self.model.normal_log_alpha)
            logging.info(self.model.reduce_log_alpha)
            genotype_edge_all = self.model.genotype_edge_all()
            logging.info('genotype_edge_all = %s', genotype_edge_all)

    def train(self, epoch, logging):
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()
        grad = utils.AvgrageMeter()

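        # running sums of alpha gradients, reported as per-iteration averages
        # at the end of the epoch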
        normal_resource_gradient = 0
        reduce_resource_gradient = 0
        normal_loss_gradient = 0
        reduce_loss_gradient = 0
        normal_total_gradient = 0
        reduce_total_gradient = 0

        loss_alpha = None

        count = 0
        for step, (input, target) in enumerate(self.train_queue):
            if self.args.alternate_update:
                if step % 2 == 0:
                    self.update_theta = True
                    self.update_alpha = False
                else:
                    self.update_theta = False
                    self.update_alpha = True

            n = input.size(0)
            input = input.to(self.device)
            target = target.to(self.device, non_blocking=True)
            if self.args.snas:
                logits, logits_aux, penalty, op_normal, op_reduce = self.model(
                    input)
                error_loss = self.criterion(logits, target)
                if self.args.auxiliary:
                    loss_aux = self.criterion(logits_aux, target)
                    error_loss += self.args.auxiliary_weight * loss_aux

            if self.args.dsnas:
                logits, error_loss, loss_alpha, penalty = self.model(
                    input, target, self.criterion)

            num_normal = self.model.num_normal
            num_reduce = self.model.num_reduce
            normal_arch_entropy = self.model._arch_entropy(
                self.model.normal_log_alpha)
            reduce_arch_entropy = self.model._arch_entropy(
                self.model.reduce_log_alpha)

            if self.args.resource_efficient:
                if self.args.method == 'policy_gradient':
                    resource_penalty = (penalty[2]) / 6 + self.args.ratio * (
                        penalty[7]) / 2
                    log_resource_penalty = (
                        penalty[35]) / 6 + self.args.ratio * (penalty[36]) / 2
                elif self.args.method == 'reparametrization':
                    resource_penalty = (penalty[26]) / 6 + self.args.ratio * (
                        penalty[25]) / 2
                    log_resource_penalty = (
                        penalty[37]) / 6 + self.args.ratio * (penalty[38]) / 2
                elif self.args.method == 'discrete':
                    resource_penalty = (penalty[28]) / 6 + self.args.ratio * (
                        penalty[27]) / 2
                    log_resource_penalty = (
                        penalty[39]) / 6 + self.args.ratio * (penalty[40]) / 2
                elif self.args.method == 'none':
                    # TODO
                    resource_penalty = torch.zeros(1).cuda()
                    log_resource_penalty = torch.zeros(1).cuda()
                else:
                    logging.info(
                        "invalid --method '%s'; choose one of 'policy_gradient', "
                        "'discrete', 'reparametrization', 'none'",
                        self.args.method)
                    sys.exit(1)
            else:
                resource_penalty = torch.zeros(1).cuda()
                log_resource_penalty = torch.zeros(1).cuda()

            if self.args.log_penalty:
                resource_loss = self.model._resource_lambda * log_resource_penalty
            else:
                resource_loss = self.model._resource_lambda * resource_penalty

            if self.args.loss:
                if self.args.snas:
                    loss = resource_loss.clone() + error_loss.clone()
                elif self.args.dsnas:
                    loss = resource_loss.clone()
                else:
                    loss = resource_loss.clone() + -child_coef * (
                        torch.log(normal_one_hot_prob) +
                        torch.log(reduce_one_hot_prob)).sum()
            else:
                if self.args.snas or self.args.dsnas:
                    loss = error_loss.clone()

            if self.args.distributed:
                loss.div_(self.world_size)
                error_loss.div_(self.world_size)
                resource_loss.div_(self.world_size)
                if self.args.dsnas:
                    loss_alpha.div_(self.world_size)

            # logging gradient
            count += 1
            if self.args.resource_efficient:
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()
                resource_loss.backward(retain_graph=True)
                if not self.args.random_sample:
                    normal_resource_gradient += self.model.normal_log_alpha.grad
                    reduce_resource_gradient += self.model.reduce_log_alpha.grad
            if self.args.snas:
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()
                error_loss.backward(retain_graph=True)
                if not self.args.random_sample:
                    normal_loss_gradient += self.model.normal_log_alpha.grad
                    reduce_loss_gradient += self.model.reduce_log_alpha.grad
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()

            if self.args.snas or (not self.args.random_sample and not self.args.dsnas):
                loss.backward()
            if not self.args.random_sample:
                normal_total_gradient += self.model.normal_log_alpha.grad
                reduce_total_gradient += self.model.reduce_log_alpha.grad

            if self.args.distributed:
                reduce_tensorgradients(self.model.parameters(), sync=True)
                nn.utils.clip_grad_norm_([
                    param for name, param in self.model.named_parameters() if
                    name != 'normal_log_alpha' and name != 'reduce_log_alpha'
                ], self.args.grad_clip)
                arch_grad_norm = nn.utils.clip_grad_norm_([
                    param for name, param in self.model.named_parameters()
                    if name == 'normal_log_alpha' or name == 'reduce_log_alpha'
                ], 10.)
            else:
                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         self.args.grad_clip)
                arch_grad_norm = nn.utils.clip_grad_norm_(
                    self.model.arch_parameters(), 10.)

            grad.update(arch_grad_norm)
            if not self.args.fix_weight and self.update_theta:
                self.optimizer.step()
            self.optimizer.zero_grad()
            if not self.args.random_sample and self.update_alpha:
                self.arch_optimizer.step()
            self.arch_optimizer.zero_grad()

            if self.rank == 0:
                global_step = step + len(self.train_queue.dataset) * epoch
                self.logger.add_scalar("iter_train_loss", error_loss,
                                       global_step)
                self.logger.add_scalar("normal_arch_entropy",
                                       normal_arch_entropy, global_step)
                self.logger.add_scalar("reduce_arch_entropy",
                                       reduce_arch_entropy, global_step)
                self.logger.add_scalar(
                    "total_arch_entropy",
                    normal_arch_entropy + reduce_arch_entropy, global_step)
                if self.args.dsnas:
                    # per-edge rewards for the normal and reduce cells
                    for i in range(14):
                        self.logger.add_scalar(
                            "reward_normal_edge_{}".format(i),
                            self.model.normal_edge_reward[i], global_step)
                    for i in range(14):
                        self.logger.add_scalar(
                            "reward_reduce_edge_{}".format(i),
                            self.model.reduce_edge_reward[i], global_step)
                # resource statistics, logged as per-cell averages;
                # each tuple is (tag, index into penalty, cell count)
                resource_stats = [
                    # policy size
                    ("iter_normal_size_policy", 2, num_normal),
                    ("iter_reduce_size_policy", 7, num_reduce),
                    # baseline: discrete_probability
                    ("iter_normal_size_baseline", 3, num_normal),
                    ("iter_normal_flops_baseline", 5, num_normal),
                    ("iter_normal_mac_baseline", 6, num_normal),
                    ("iter_reduce_size_baseline", 8, num_reduce),
                    ("iter_reduce_flops_baseline", 9, num_reduce),
                    ("iter_reduce_mac_baseline", 10, num_reduce),
                    # R - median(R)
                    ("iter_normal_size-avg", 60, num_normal),
                    ("iter_normal_flops-avg", 61, num_normal),
                    ("iter_normal_mac-avg", 62, num_normal),
                    ("iter_reduce_size-avg", 63, num_reduce),
                    ("iter_reduce_flops-avg", 64, num_reduce),
                    ("iter_reduce_mac-avg", 65, num_reduce),
                    # lnR - ln(median)
                    ("iter_normal_ln_size-ln_avg", 66, num_normal),
                    ("iter_normal_ln_flops-ln_avg", 67, num_normal),
                    ("iter_normal_ln_mac-ln_avg", 68, num_normal),
                    ("iter_reduce_ln_size-ln_avg", 69, num_reduce),
                    ("iter_reduce_ln_flops-ln_avg", 70, num_reduce),
                    ("iter_reduce_ln_mac-ln_avg", 71, num_reduce),
                    # normalized penalties, disabled in the original:
                    # ("iter_normal_size_normalized", 17, 6),
                    # ("iter_normal_flops_normalized", 18, 6),
                    # ("iter_normal_mac_normalized", 19, 6),
                    # ("iter_reduce_size_normalized", 20, 2),
                    # ("iter_reduce_flops_normalized", 21, 2),
                    # ("iter_reduce_mac_normalized", 22, 2),
                    # ("iter_normal_penalty_normalized", 23, 6),
                    # ("iter_reduce_penalty_normalized", 24, 2),
                    # Monte_Carlo(R_i)
                    ("iter_normal_size_mc", 29, num_normal),
                    ("iter_normal_flops_mc", 30, num_normal),
                    ("iter_normal_mac_mc", 31, num_normal),
                    ("iter_reduce_size_mc", 32, num_reduce),
                    ("iter_reduce_flops_mc", 33, num_reduce),
                    ("iter_reduce_mac_mc", 34, num_reduce),
                    # log(|R_i|)
                    ("iter_normal_log_size", 41, num_normal),
                    ("iter_normal_log_flops", 42, num_normal),
                    ("iter_normal_log_mac", 43, num_normal),
                    ("iter_reduce_log_size", 44, num_reduce),
                    ("iter_reduce_log_flops", 45, num_reduce),
                    ("iter_reduce_log_mac", 46, num_reduce),
                    # log(P)R_i
                    ("iter_normal_logP_size", 47, num_normal),
                    ("iter_normal_logP_flops", 48, num_normal),
                    ("iter_normal_logP_mac", 49, num_normal),
                    ("iter_reduce_logP_size", 50, num_reduce),
                    ("iter_reduce_logP_flops", 51, num_reduce),
                    ("iter_reduce_logP_mac", 52, num_reduce),
                    # log(P)log(R_i)
                    ("iter_normal_logP_log_size", 53, num_normal),
                    ("iter_normal_logP_log_flops", 54, num_normal),
                    ("iter_normal_logP_log_mac", 55, num_normal),
                    ("iter_reduce_logP_log_size", 56, num_reduce),
                    ("iter_reduce_logP_log_flops", 57, num_reduce),
                    ("iter_reduce_logP_log_mac", 58, num_reduce),
                ]
                for tag, idx, denom in resource_stats:
                    self.logger.add_scalar(tag, penalty[idx] / denom,
                                           global_step)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))

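            # average loss and accuracy across ranks before updating the meters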
            if self.args.distributed:
                loss = loss.detach()
                dist.all_reduce(error_loss)
                dist.all_reduce(prec1)
                dist.all_reduce(prec5)
                prec1.div_(self.world_size)
                prec5.div_(self.world_size)
                #dist_util.all_reduce([loss, prec1, prec5], 'mean')
            objs.update(error_loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % self.args.report_freq == 0 and self.rank == 0:
                logging.info('train %03d %e %f %f', step, objs.avg, top1.avg,
                             top5.avg)
                self.logger.add_scalar(
                    "iter_train_top1_acc", top1.avg,
                    step + len(self.train_queue.dataset) * epoch)

        if self.rank == 0:
            logging.info('-------resource gradient--------')
            logging.info(normal_resource_gradient / count)
            logging.info(reduce_resource_gradient / count)
            logging.info('-------loss gradient--------')
            logging.info(normal_loss_gradient / count)
            logging.info(reduce_loss_gradient / count)
            logging.info('-------total gradient--------')
            logging.info(normal_total_gradient / count)
            logging.info(reduce_total_gradient / count)

        return top1.avg, loss, error_loss, loss_alpha

    def infer(self, epoch):
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()

        self.model.eval()
        with torch.no_grad():
            for step, (input, target) in enumerate(self.valid_queue):
                input = input.to(self.device)
                target = target.to(self.device)
                if self.args.snas:
                    logits, logits_aux, resource_loss, op_normal, op_reduce = self.model(
                        input)
                    loss = self.criterion(logits, target)
                elif self.args.dsnas:
                    logits, error_loss, loss_alpha, resource_loss = self.model(
                        input, target, self.criterion)
                    loss = error_loss

                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))

                if self.args.distributed:
                    loss.div_(self.world_size)
                    loss = loss.detach()
                    dist.all_reduce(loss)
                    dist.all_reduce(prec1)
                    dist.all_reduce(prec5)
                    prec1.div_(self.world_size)
                    prec5.div_(self.world_size)
                objs.update(loss.item(), input.size(0))
                top1.update(prec1.item(), input.size(0))
                top5.update(prec5.item(), input.size(0))

                if step % self.args.report_freq == 0 and self.rank == 0:
                    logging.info('valid %03d %e %f %f', step, objs.avg,
                                 top1.avg, top5.avg)
                    self.logger.add_scalar(
                        "iter_valid_loss", loss,
                        step + len(self.valid_queue.dataset) * epoch)
                    self.logger.add_scalar(
                        "iter_valid_top1_acc", top1.avg,
                        step + len(self.valid_queue.dataset) * epoch)

        return top1.avg, objs.avg

Example #2
class neural_architecture_search():
    def __init__(self, args):
        self.args = args

        if not torch.cuda.is_available():
            logging.info('no gpu device available')
            sys.exit(1)

        torch.cuda.set_device(self.args.gpu)
        self.device = torch.device("cuda")
        self.rank = 0
        self.seed = self.args.seed
        self.world_size = 1

        if self.args.fix_cudnn:
            random.seed(self.seed)
            torch.backends.cudnn.deterministic = True
            np.random.seed(self.seed)
            cudnn.benchmark = False
            torch.manual_seed(self.seed)
            cudnn.enabled = True
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)
        else:
            np.random.seed(self.seed)
            cudnn.benchmark = True
            torch.manual_seed(self.seed)
            cudnn.enabled = True
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)

        self.path = os.path.join(generate_date, self.args.save)
        if self.rank == 0:
            utils.create_exp_dir(generate_date,
                                 self.path,
                                 scripts_to_save=glob.glob('*.py'))
            logging.basicConfig(stream=sys.stdout,
                                level=logging.INFO,
                                format=log_format,
                                datefmt='%m/%d %I:%M:%S %p')
            fh = logging.FileHandler(os.path.join(self.path, 'log.txt'))
            fh.setFormatter(logging.Formatter(log_format))
            logging.getLogger().addHandler(fh)
            logging.info("self.args = %s", self.args)
            self.logger = tensorboardX.SummaryWriter('./runs/' +
                                                     generate_date + '/' +
                                                     self.args.save_log)
        else:
            self.logger = None

        #initialize loss function
        self.criterion = nn.CrossEntropyLoss().to(self.device)

        #initialize model
        self.init_model()
        if self.args.resume:
            self.reload_model()

        #calculate model param size
        if self.rank == 0:
            logging.info("param size = %fMB",
                         utils.count_parameters_in_MB(self.model))
            self.model._logger = self.logger
            self.model._logging = logging

        #initialize optimizer
        self.init_optimizer()

        #initialize dataset loader
        self.init_loaddata()

        self.update_theta = True
        self.update_alpha = True

    def init_model(self):

        self.model = Network(self.args.init_channels, CIFAR_CLASSES,
                             self.args.layers, self.criterion, self.args,
                             self.rank, self.world_size, self.args.steps,
                             self.args.multiplier)
        self.model.to(self.device)
        for v in self.model.parameters():
            if v.requires_grad:
                if v.grad is None:
                    v.grad = torch.zeros_like(v)
        self.model.normal_log_alpha.grad = torch.zeros_like(
            self.model.normal_log_alpha)
        self.model.reduce_log_alpha.grad = torch.zeros_like(
            self.model.reduce_log_alpha)

    def reload_model(self):
        weights = os.path.join(self.args.resume_path, 'weights.pt')
        self.model.load_state_dict(torch.load(weights), strict=True)

    def init_optimizer(self):

        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         self.args.learning_rate,
                                         momentum=self.args.momentum,
                                         weight_decay=self.args.weight_decay)

        self.arch_optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            lr=self.args.arch_learning_rate,
            betas=(0.5, 0.999),
            weight_decay=self.args.arch_weight_decay)

    def init_loaddata(self):

        train_transform, valid_transform = utils._data_transforms_cifar10(
            self.args)
        train_data = dset.CIFAR10(root=self.args.data,
                                  train=True,
                                  download=True,
                                  transform=train_transform)
        valid_data = dset.CIFAR10(root=self.args.data,
                                  train=False,
                                  download=True,
                                  transform=valid_transform)

        if self.args.seed:

            def worker_init_fn(worker_id):
                # the DataLoader passes a per-worker id; reseed every worker
                # so the augmentation pipeline is reproducible
                seed = self.seed
                np.random.seed(seed)
                random.seed(seed)
                torch.manual_seed(seed)
                return
        else:
            worker_init_fn = None

        self.train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=self.args.batch_size,
            shuffle=True,
            pin_memory=False,
            num_workers=2,
            worker_init_fn=worker_init_fn)

        self.valid_queue = torch.utils.data.DataLoader(
            valid_data,
            batch_size=self.args.batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=2,
            worker_init_fn=worker_init_fn)

    def main(self):
        # lr scheduler: cosine annealing
        # temp scheduler: linear annealing (self-defined in utils)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            float(self.args.epochs),
            eta_min=self.args.learning_rate_min)

        self.temp_scheduler = utils.Temp_Scheduler(self.args.epochs,
                                                   self.model._temp,
                                                   self.args.temp,
                                                   temp_min=self.args.temp_min)

        for epoch in range(self.args.epochs):
            if self.args.child_reward_stat:
                self.update_theta = False
                self.update_alpha = False

            if self.args.current_reward:
                self.model.normal_reward_mean = torch.zeros_like(
                    self.model.normal_reward_mean)
                self.model.reduce_reward_mean = torch.zeros_like(
                    self.model.reduce_reward_mean)
                self.model.count = 0

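            # when resuming, fast-forward past epochs that were already
            # trained so the lr and temperature schedules stay aligned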
            if epoch < self.args.resume_epoch:
                continue
            self.scheduler.step()
            if self.args.temp_annealing:
                self.model._temp = self.temp_scheduler.step()
            self.lr = self.scheduler.get_lr()[0]

            if self.rank == 0:
                logging.info('epoch %d lr %e temp %e', epoch, self.lr,
                             self.model._temp)
                self.logger.add_scalar('epoch_temp', self.model._temp, epoch)
                logging.info(self.model.normal_log_alpha)
                logging.info(self.model.reduce_log_alpha)
                logging.info(F.softmax(self.model.normal_log_alpha, dim=-1))
                logging.info(F.softmax(self.model.reduce_log_alpha, dim=-1))

            genotype_edge_all = self.model.genotype_edge_all()

            if self.rank == 0:
                logging.info('genotype_edge_all = %s', genotype_edge_all)
                # create genotypes.txt file
                txt_name = self.args.remark + '_genotype_edge_all_epoch' + str(
                    epoch)
                utils.txt('genotype', self.args.save, txt_name,
                          str(genotype_edge_all), generate_date)

            self.model.train()
            train_acc, loss, error_loss, loss_alpha = self.train(
                epoch, logging)
            if self.rank == 0:
                logging.info('train_acc %f', train_acc)
                self.logger.add_scalar("epoch_train_acc", train_acc, epoch)
                self.logger.add_scalar("epoch_train_error_loss", error_loss,
                                       epoch)
                if self.args.dsnas:
                    self.logger.add_scalar("epoch_train_alpha_loss",
                                           loss_alpha, epoch)

                if self.args.dsnas and not self.args.child_reward_stat:
                    if self.args.current_reward:
                        logging.info('reward mean stat')
                        logging.info(self.model.normal_reward_mean)
                        logging.info(self.model.reduce_reward_mean)
                        logging.info('count')
                        logging.info(self.model.count)
                    else:
                        logging.info('reward mean stat')
                        logging.info(self.model.normal_reward_mean)
                        logging.info(self.model.reduce_reward_mean)
                        if self.model.normal_reward_mean.size(0) > 1:
                            logging.info('reward mean total stat')
                            logging.info(self.model.normal_reward_mean.sum(0))
                            logging.info(self.model.reduce_reward_mean.sum(0))

                if self.args.child_reward_stat:
                    logging.info('reward mean stat')
                    logging.info(self.model.normal_reward_mean.sum(0))
                    logging.info(self.model.reduce_reward_mean.sum(0))
                    logging.info('reward var stat')
                    logging.info(
                        self.model.normal_reward_mean_square.sum(0) -
                        self.model.normal_reward_mean.sum(0)**2)
                    logging.info(
                        self.model.reduce_reward_mean_square.sum(0) -
                        self.model.reduce_reward_mean.sum(0)**2)

            # validation
            self.model.eval()
            valid_acc, valid_obj = self.infer(epoch)
            if self.args.gen_max_child:
                self.args.gen_max_child_flag = True
                valid_acc_max_child, valid_obj_max_child = self.infer(epoch)
                self.args.gen_max_child_flag = False

            if self.rank == 0:
                logging.info('valid_acc %f', valid_acc)
                self.logger.add_scalar("epoch_valid_acc", valid_acc, epoch)
                if self.args.gen_max_child:
                    logging.info('valid_acc_argmax_alpha %f',
                                 valid_acc_max_child)
                    self.logger.add_scalar("epoch_valid_acc_argmax_alpha",
                                           valid_acc_max_child, epoch)

                utils.save(self.model, os.path.join(self.path, 'weights.pt'))

        if self.rank == 0:
            logging.info(self.model.normal_log_alpha)
            logging.info(self.model.reduce_log_alpha)
            genotype_edge_all = self.model.genotype_edge_all()
            logging.info('genotype_edge_all = %s', genotype_edge_all)

    def train(self, epoch, logging):
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()
        grad = utils.AvgrageMeter()

        normal_loss_gradient = 0
        reduce_loss_gradient = 0
        normal_total_gradient = 0
        reduce_total_gradient = 0

        loss_alpha = None

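        # per-epoch statistics over individual examples, split by whether the
        # top-1 prediction was correct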
        train_correct_count = 0
        train_correct_cost = 0
        train_correct_entropy = 0
        train_correct_loss = 0
        train_wrong_count = 0
        train_wrong_cost = 0
        train_wrong_entropy = 0
        train_wrong_loss = 0

        count = 0
        for step, (input, target) in enumerate(self.train_queue):

            n = input.size(0)
            input = input.to(self.device)
            target = target.to(self.device, non_blocking=True)
            if self.args.snas:
                logits, logits_aux = self.model(input)
                error_loss = self.criterion(logits, target)
                if self.args.auxiliary:
                    loss_aux = self.criterion(logits_aux, target)
                    error_loss += self.args.auxiliary_weight * loss_aux

            if self.args.dsnas:
                logits, error_loss, loss_alpha = self.model(
                    input,
                    target,
                    self.criterion,
                    update_theta=self.update_theta,
                    update_alpha=self.update_alpha)

            for i in range(logits.size(0)):
                index = logits[i].topk(5, 0, True, True)[1]
                discrete_prob = F.softmax(logits[i], dim=-1)
                cost = (-logits[i, target[i].item()] +
                        (discrete_prob * logits[i]).sum())
                entropy = -(discrete_prob * torch.log(discrete_prob)).sum(-1)
                log_loss = -torch.log(discrete_prob)[target[i].item()]
                if index[0].item() == target[i].item():
                    train_correct_count += 1
                    train_correct_cost += cost
                    train_correct_entropy += entropy
                    train_correct_loss += log_loss
                else:
                    train_wrong_count += 1
                    train_wrong_cost += cost
                    train_wrong_entropy += entropy
                    train_wrong_loss += log_loss

            num_normal = self.model.num_normal
            num_reduce = self.model.num_reduce

            if self.args.snas or self.args.dsnas:
                loss = error_loss.clone()

            #self.update_lr()

            # logging gradient
            count += 1
            if self.args.snas:
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()
                error_loss.backward(retain_graph=True)
                if not self.args.random_sample:
                    normal_loss_gradient += self.model.normal_log_alpha.grad
                    reduce_loss_gradient += self.model.reduce_log_alpha.grad
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()

            if self.args.snas and (not self.args.random_sample
                                   and not self.args.dsnas):
                loss.backward()

            if not self.args.random_sample:
                normal_total_gradient += self.model.normal_log_alpha.grad
                reduce_total_gradient += self.model.reduce_log_alpha.grad

            nn.utils.clip_grad_norm_(self.model.parameters(),
                                     self.args.grad_clip)
            arch_grad_norm = nn.utils.clip_grad_norm_(
                self.model.arch_parameters(), 10.)

            grad.update(arch_grad_norm)
            if not self.args.fix_weight and self.update_theta:
                self.optimizer.step()
            self.optimizer.zero_grad()

            if not self.args.random_sample and self.update_alpha:
                self.arch_optimizer.step()
            self.arch_optimizer.zero_grad()

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))

            objs.update(error_loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % self.args.report_freq == 0 and self.rank == 0:
                logging.info('train %03d %e %f %f', step, objs.avg, top1.avg,
                             top5.avg)
                self.logger.add_scalar(
                    "iter_train_top1_acc", top1.avg,
                    step + len(self.train_queue.dataset) * epoch)

        if self.rank == 0:
            logging.info('-------loss gradient--------')
            logging.info(normal_loss_gradient / count)
            logging.info(reduce_loss_gradient / count)
            logging.info('-------total gradient--------')
            logging.info(normal_total_gradient / count)
            logging.info(reduce_total_gradient / count)

        logging.info('correct loss ')
        logging.info((train_correct_loss / train_correct_count).item())
        logging.info('correct entropy ')
        logging.info((train_correct_entropy / train_correct_count).item())
        logging.info('correct cost ')
        logging.info((train_correct_cost / train_correct_count).item())
        logging.info('correct count ')
        logging.info(train_correct_count)

        logging.info('wrong loss ')
        logging.info((train_wrong_loss / train_wrong_count).item())
        logging.info('wrong entropy ')
        logging.info((train_wrong_entropy / train_wrong_count).item())
        logging.info('wrong cost ')
        logging.info((train_wrong_cost / train_wrong_count).item())
        logging.info('wrong count ')
        logging.info(train_wrong_count)

        logging.info('total loss ')
        logging.info(((train_correct_loss + train_wrong_loss) /
                      (train_correct_count + train_wrong_count)).item())
        logging.info('total entropy ')
        logging.info(((train_correct_entropy + train_wrong_entropy) /
                      (train_correct_count + train_wrong_count)).item())
        logging.info('total cost ')
        logging.info(((train_correct_cost + train_wrong_cost) /
                      (train_correct_count + train_wrong_count)).item())
        logging.info('total count ')
        logging.info(train_correct_count + train_wrong_count)

        return top1.avg, loss, error_loss, loss_alpha

    def infer(self, epoch):
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()

        self.model.eval()
        with torch.no_grad():
            for step, (input, target) in enumerate(self.valid_queue):
                input = input.to(self.device)
                target = target.to(self.device)
                if self.args.snas:
                    logits, logits_aux = self.model(input)
                    loss = self.criterion(logits, target)
                elif self.args.dsnas:
                    logits, error_loss, loss_alpha = self.model(
                        input, target, self.criterion)
                    loss = error_loss

                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))

                objs.update(loss.item(), input.size(0))
                top1.update(prec1.item(), input.size(0))
                top5.update(prec5.item(), input.size(0))

                if step % self.args.report_freq == 0 and self.rank == 0:
                    logging.info('valid %03d %e %f %f', step, objs.avg,
                                 top1.avg, top5.avg)
                    self.logger.add_scalar(
                        "iter_valid_loss", loss,
                        step + len(self.valid_queue.dataset) * epoch)
                    self.logger.add_scalar(
                        "iter_valid_top1_acc", top1.avg,
                        step + len(self.valid_queue.dataset) * epoch)

        return top1.avg, objs.avg
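
# A minimal usage sketch (an assumption, not part of the original example):
# both classes expect an argparse.Namespace exposing the flags referenced
# above (gpu, seed, epochs, batch_size, snas/dsnas, method, ...), built by a
# project-specific parser that is omitted here.
#
#     if __name__ == '__main__':
#         args = parser.parse_args()
#         nas = neural_architecture_search(args)
#         nas.main()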