class neural_architecture_search():

    def __init__(self, args):
        self.args = args

        if not torch.cuda.is_available():
            logging.info('no gpu device available')
            sys.exit(1)

        if self.args.distributed:
            # Init distributed environment
            self.rank, self.world_size, self.device = init_dist(
                port=self.args.port)
            self.seed = self.rank * self.args.seed
        else:
            torch.cuda.set_device(self.args.gpu)
            self.device = torch.device("cuda")
            self.rank = 0
            self.seed = self.args.seed
            self.world_size = 1

        if self.args.fix_seedcudnn:
            # fully deterministic cuDNN: slower, but reproducible
            random.seed(self.seed)
            torch.backends.cudnn.deterministic = True
            np.random.seed(self.seed)
            cudnn.benchmark = False
            torch.manual_seed(self.seed)
            cudnn.enabled = True
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)
        else:
            np.random.seed(self.seed)
            cudnn.benchmark = True
            torch.manual_seed(self.seed)
            cudnn.enabled = True
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)

        self.path = os.path.join(generate_date, self.args.save)
        if self.rank == 0:
            utils.create_exp_dir(generate_date, self.path,
                                 scripts_to_save=glob.glob('*.py'))
            logging.basicConfig(stream=sys.stdout,
                                level=logging.INFO,
                                format=log_format,
                                datefmt='%m/%d %I:%M:%S %p')
            fh = logging.FileHandler(os.path.join(self.path, 'log.txt'))
            fh.setFormatter(logging.Formatter(log_format))
            logging.getLogger().addHandler(fh)
            logging.info("self.args = %s", self.args)
            self.logger = tensorboardX.SummaryWriter(
                './runs/' + generate_date + '/nas_{}'.format(self.args.remark))
        else:
            self.logger = None

        # set default resource_lambda for the different methods
        if self.args.resource_efficient:
            if self.args.method == 'policy_gradient':
                default_resource_lambda = 1e-4 if self.args.log_penalty else 1e-5
            elif self.args.method == 'reparametrization':
                default_resource_lambda = 1e-2 if self.args.log_penalty else 1e-5
            elif self.args.method == 'discrete':
                default_resource_lambda = 1e-2 if self.args.log_penalty else 1e-4
            else:
                default_resource_lambda = self.args.resource_lambda
            # only override a lambda the user left at the command-line default
            # (default_lambda is assumed to be a module-level sentinel)
            if self.args.resource_lambda == default_lambda:
                self.args.resource_lambda = default_resource_lambda

        # initialize loss function
        self.criterion = nn.CrossEntropyLoss().to(self.device)

        # initialize model
        self.init_model()

        # calculate model param size
        if self.rank == 0:
            logging.info("param size = %fMB",
                         utils.count_parameters_in_MB(self.model))
        self.model._logger = self.logger
        self.model._logging = logging

        # initialize optimizer
        self.init_optimizer()

        # initialize dataset loaders
        self.init_loaddata()

        self.update_theta = True
        self.update_alpha = True

    def init_model(self):
        self.model = Network(self.args.init_channels, CIFAR_CLASSES,
                             self.args.layers, self.criterion, self.args,
                             self.rank, self.world_size)
        self.model.to(self.device)
        if self.args.distributed:
            broadcast_params(self.model)
        # pre-allocate zero grads so every tensor takes part in gradient reduction
        for v in self.model.parameters():
            if v.requires_grad:
                if v.grad is None:
                    v.grad = torch.zeros_like(v)
        self.model.normal_log_alpha.grad = torch.zeros_like(
            self.model.normal_log_alpha)
        self.model.reduce_log_alpha.grad = torch.zeros_like(
            self.model.reduce_log_alpha)

    def init_optimizer(self):
        if self.args.distributed:
            # SGD for network weights, Adam for the architecture log-alphas
            self.optimizer = torch.optim.SGD(
                [
                    param for name, param in self.model.named_parameters()
                    if name != 'normal_log_alpha' and name != 'reduce_log_alpha'
                ],
                self.args.learning_rate,
                momentum=self.args.momentum,
                weight_decay=self.args.weight_decay)
            self.arch_optimizer = torch.optim.Adam(
                [
                    param for name, param in self.model.named_parameters()
                    if name == 'normal_log_alpha' or name == 'reduce_log_alpha'
                ],
                lr=self.args.arch_learning_rate,
                betas=(0.5, 0.999),
                weight_decay=self.args.arch_weight_decay)
        else:
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             self.args.learning_rate,
                                             momentum=self.args.momentum,
                                             weight_decay=self.args.weight_decay)
            self.arch_optimizer = torch.optim.SGD(
                self.model.arch_parameters(),
                lr=self.args.arch_learning_rate)

    def init_loaddata(self):
        train_transform, valid_transform = utils._data_transforms_cifar10(
            self.args)
        train_data = dset.CIFAR10(root=self.args.data,
                                  train=True,
                                  download=True,
                                  transform=train_transform)
        valid_data = dset.CIFAR10(root=self.args.data,
                                  train=False,
                                  download=True,
                                  transform=valid_transform)

        if self.args.seed:
            # must accept the worker id; re-seeds every dataloader worker and
            # is passed to the loaders below so the seed actually takes effect
            def worker_init_fn(worker_id):
                seed = self.seed
                np.random.seed(seed)
                random.seed(seed)
                torch.manual_seed(seed)
        else:
            worker_init_fn = None

        if self.args.distributed:
            train_sampler = DistributedSampler(train_data)
            valid_sampler = DistributedSampler(valid_data)
            self.train_queue = torch.utils.data.DataLoader(
                train_data,
                batch_size=self.args.batch_size // self.world_size,
                shuffle=False,
                num_workers=0,
                pin_memory=False,
                sampler=train_sampler,
                worker_init_fn=worker_init_fn)
            self.valid_queue = torch.utils.data.DataLoader(
                valid_data,
                batch_size=self.args.batch_size // self.world_size,
                shuffle=False,
                num_workers=0,
                pin_memory=False,
                sampler=valid_sampler,
                worker_init_fn=worker_init_fn)
        else:
            self.train_queue = torch.utils.data.DataLoader(
                train_data,
                batch_size=self.args.batch_size,
                shuffle=True,
                pin_memory=False,
                num_workers=2,
                worker_init_fn=worker_init_fn)
            self.valid_queue = torch.utils.data.DataLoader(
                valid_data,
                batch_size=self.args.batch_size,
                shuffle=False,
                pin_memory=False,
                num_workers=2,
                worker_init_fn=worker_init_fn)

    def main(self):
        # lr scheduler: cosine annealing
        # temp scheduler: linear annealing (self-defined in utils)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            float(self.args.epochs),
            eta_min=self.args.learning_rate_min)
        self.temp_scheduler = utils.Temp_Scheduler(self.args.epochs,
                                                   self.model._temp,
                                                   self.args.temp,
                                                   temp_min=self.args.temp_min)

        for epoch in range(self.args.epochs):
            if self.args.random_sample_pretrain:
                # warm up with uniform random sampling for the first epochs
                if epoch < self.args.random_sample_pretrain_epoch:
                    self.args.random_sample = True
                else:
                    self.args.random_sample = False

            self.scheduler.step()
            if self.args.temp_annealing:
                self.model._temp = self.temp_scheduler.step()
            self.lr = self.scheduler.get_lr()[0]

            if self.rank == 0:
                logging.info('epoch %d lr %e temp %e', epoch, self.lr,
                             self.model._temp)
                self.logger.add_scalar('epoch_temp', self.model._temp, epoch)
                logging.info(self.model.normal_log_alpha)
                logging.info(self.model.reduce_log_alpha)
                logging.info(
                    self.model._get_weights(self.model.normal_log_alpha[0]))
                logging.info(
                    self.model._get_weights(self.model.reduce_log_alpha[0]))

            genotype_edge_all = self.model.genotype_edge_all()

            if self.rank == 0:
                logging.info('genotype_edge_all = %s', genotype_edge_all)
                # create genotypes.txt file
                txt_name = self.args.remark + '_genotype_edge_all_epoch' + str(
                    epoch)
                utils.txt('genotype', self.args.save, txt_name,
                          str(genotype_edge_all), generate_date)

            self.model.train()
            train_acc, loss, error_loss, loss_alpha = self.train(
                epoch, logging)

            if self.rank == 0:
                logging.info('train_acc %f', train_acc)
                self.logger.add_scalar("epoch_train_acc", train_acc, epoch)
                self.logger.add_scalar("epoch_train_error_loss", error_loss,
                                       epoch)
                if self.args.dsnas:
                    self.logger.add_scalar("epoch_train_alpha_loss",
                                           loss_alpha, epoch)

            # validation
            self.model.eval()
            valid_acc, valid_obj = self.infer(epoch)
            if self.args.gen_max_child:
                # also evaluate the argmax (most likely) child network
                self.args.gen_max_child_flag = True
                valid_acc_max_child, valid_obj_max_child = self.infer(epoch)
                self.args.gen_max_child_flag = False
            if self.rank == 0:
                logging.info('valid_acc %f', valid_acc)
                self.logger.add_scalar("epoch_valid_acc", valid_acc, epoch)
                if self.args.gen_max_child:
                    logging.info('valid_acc_argmax_alpha %f',
                                 valid_acc_max_child)
                    self.logger.add_scalar("epoch_valid_acc_argmax_alpha",
                                           valid_acc_max_child, epoch)

                utils.save(self.model, os.path.join(self.path, 'weights.pt'))

        if self.rank == 0:
            logging.info(self.model.normal_log_alpha)
            logging.info(self.model.reduce_log_alpha)
            genotype_edge_all = self.model.genotype_edge_all()
            logging.info('genotype_edge_all = %s', genotype_edge_all)

    def train(self, epoch, logging):
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()
        grad = utils.AvgrageMeter()

        normal_resource_gradient = 0
        reduce_resource_gradient = 0
        normal_loss_gradient = 0
        reduce_loss_gradient = 0
        normal_total_gradient = 0
        reduce_total_gradient = 0

        loss_alpha = None
        count = 0
        for step, (input, target) in enumerate(self.train_queue):
            if self.args.alternate_update:
                # alternate between weight (theta) and architecture (alpha) updates
                if step % 2 == 0:
                    self.update_theta = True
                    self.update_alpha = False
                else:
                    self.update_theta = False
                    self.update_alpha = True

            n = input.size(0)
            input = input.to(self.device)
            target = target.to(self.device, non_blocking=True)

            if self.args.snas:
                logits, logits_aux, penalty, op_normal, op_reduce = self.model(
                    input)
                error_loss = self.criterion(logits, target)
                if self.args.auxiliary:
                    loss_aux = self.criterion(logits_aux, target)
                    error_loss += self.args.auxiliary_weight * loss_aux

            if self.args.dsnas:
                logits, error_loss, loss_alpha, penalty = self.model(
                    input, target, self.criterion)

            num_normal = self.model.num_normal
            num_reduce = self.model.num_reduce
            normal_arch_entropy = self.model._arch_entropy(
                self.model.normal_log_alpha)
            reduce_arch_entropy = self.model._arch_entropy(
                self.model.reduce_log_alpha)

            if self.args.resource_efficient:
                if self.args.method == 'policy_gradient':
                    resource_penalty = (penalty[2]) / 6 + self.args.ratio * (
                        penalty[7]) / 2
                    log_resource_penalty = (
                        penalty[35]) / 6 + self.args.ratio * (penalty[36]) / 2
                elif self.args.method == 'reparametrization':
                    resource_penalty = (penalty[26]) / 6 + self.args.ratio * (
                        penalty[25]) / 2
                    log_resource_penalty = (
                        penalty[37]) / 6 + self.args.ratio * (penalty[38]) / 2
                elif self.args.method == 'discrete':
                    resource_penalty = (penalty[28]) / 6 + self.args.ratio * (
                        penalty[27]) / 2
                    log_resource_penalty = (
                        penalty[39]) / 6 + self.args.ratio * (penalty[40]) / 2
                elif self.args.method == 'none':
                    # TODO
                    resource_penalty = torch.zeros(1).cuda()
                    log_resource_penalty = torch.zeros(1).cuda()
                else:
                    logging.info(
                        "invalid --method, please choose one of "
                        "'policy_gradient', 'discrete', 'reparametrization', "
                        "'none'")
                    sys.exit(1)
            else:
                resource_penalty = torch.zeros(1).cuda()
                log_resource_penalty = torch.zeros(1).cuda()

            if self.args.log_penalty:
                resource_loss = self.model._resource_lambda * log_resource_penalty
            else:
                resource_loss = self.model._resource_lambda * resource_penalty

            if self.args.loss:
                if self.args.snas:
                    loss = resource_loss.clone() + error_loss.clone()
                elif self.args.dsnas:
                    loss = resource_loss.clone()
                else:
                    # NB: child_coef, normal_one_hot_prob and reduce_one_hot_prob
                    # are assumed to be defined elsewhere in the repository; they
                    # are not produced in this excerpt
                    loss = resource_loss.clone() - child_coef * (
                        torch.log(normal_one_hot_prob) +
                        torch.log(reduce_one_hot_prob)).sum()
            else:
                if self.args.snas or self.args.dsnas:
                    loss = error_loss.clone()

            if self.args.distributed:
                loss.div_(self.world_size)
                error_loss.div_(self.world_size)
                resource_loss.div_(self.world_size)
                if self.args.dsnas:
                    loss_alpha.div_(self.world_size)

            # logging gradient
            count += 1
            if self.args.resource_efficient:
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()
                resource_loss.backward(retain_graph=True)
                if not self.args.random_sample:
                    normal_resource_gradient += self.model.normal_log_alpha.grad
                    reduce_resource_gradient += self.model.reduce_log_alpha.grad

            if self.args.snas:
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()
                error_loss.backward(retain_graph=True)
                if not self.args.random_sample:
                    normal_loss_gradient += self.model.normal_log_alpha.grad
                    reduce_loss_gradient += self.model.reduce_log_alpha.grad
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()

            # backprop the total loss for SNAS, or whenever neither random
            # sampling nor DSNAS is active (parentheses make the original
            # operator precedence explicit)
            if self.args.snas or (not self.args.random_sample
                                  and not self.args.dsnas):
                loss.backward()
                if not self.args.random_sample:
                    normal_total_gradient += self.model.normal_log_alpha.grad
                    reduce_total_gradient += self.model.reduce_log_alpha.grad

            if self.args.distributed:
                reduce_tensorgradients(self.model.parameters(), sync=True)
                nn.utils.clip_grad_norm_([
                    param for name, param in self.model.named_parameters()
                    if name != 'normal_log_alpha' and name != 'reduce_log_alpha'
                ], self.args.grad_clip)
                arch_grad_norm = nn.utils.clip_grad_norm_([
                    param for name, param in self.model.named_parameters()
                    if name == 'normal_log_alpha' or name == 'reduce_log_alpha'
                ], 10.)
            else:
                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         self.args.grad_clip)
                arch_grad_norm = nn.utils.clip_grad_norm_(
                    self.model.arch_parameters(), 10.)
            grad.update(arch_grad_norm)

            if not self.args.fix_weight and self.update_theta:
                self.optimizer.step()
            self.optimizer.zero_grad()
            if not self.args.random_sample and self.update_alpha:
                self.arch_optimizer.step()
            self.arch_optimizer.zero_grad()

            if self.rank == 0:
                # all iteration-level scalars share one global step
                global_step = step + len(self.train_queue.dataset) * epoch
                self.logger.add_scalar("iter_train_loss", error_loss,
                                       global_step)
                self.logger.add_scalar("normal_arch_entropy",
                                       normal_arch_entropy, global_step)
                self.logger.add_scalar("reduce_arch_entropy",
                                       reduce_arch_entropy, global_step)
                self.logger.add_scalar(
                    "total_arch_entropy",
                    normal_arch_entropy + reduce_arch_entropy, global_step)

                if self.args.dsnas:
                    # per-edge rewards for the 14 normal and 14 reduce edges
                    for i in range(14):
                        self.logger.add_scalar(
                            "reward_normal_edge_%d" % i,
                            self.model.normal_edge_reward[i], global_step)
                        self.logger.add_scalar(
                            "reward_reduce_edge_%d" % i,
                            self.model.reduce_edge_reward[i], global_step)

                # policy size
                self.logger.add_scalar("iter_normal_size_policy",
                                       penalty[2] / num_normal, global_step)
                self.logger.add_scalar("iter_reduce_size_policy",
                                       penalty[7] / num_reduce, global_step)

                # the size/flops/mac penalty statistics follow a fixed layout,
                # so they are logged with one loop over (tag template, normal
                # indices, reduce indices); the scalars match the original
                # one-call-per-tag code
                penalty_groups = [
                    # baseline: discrete_probability
                    ("iter_{}_{}_baseline", (3, 5, 6), (8, 9, 10)),
                    # R - median(R)
                    ("iter_{}_{}-avg", (60, 61, 62), (63, 64, 65)),
                    # lnR - ln(median)
                    ("iter_{}_ln_{}-ln_avg", (66, 67, 68), (69, 70, 71)),
                    # Monte_Carlo(R_i)
                    ("iter_{}_{}_mc", (29, 30, 31), (32, 33, 34)),
                    # log(|R_i|)
                    ("iter_{}_log_{}", (41, 42, 43), (44, 45, 46)),
                    # log(P)R_i
                    ("iter_{}_logP_{}", (47, 48, 49), (50, 51, 52)),
                    # log(P)log(R_i)
                    ("iter_{}_logP_log_{}", (53, 54, 55), (56, 57, 58)),
                ]
                for tag, normal_idx, reduce_idx in penalty_groups:
                    for metric, n_i, r_i in zip(("size", "flops", "mac"),
                                                normal_idx, reduce_idx):
                        self.logger.add_scalar(tag.format("normal", metric),
                                               penalty[n_i] / num_normal,
                                               global_step)
                        self.logger.add_scalar(tag.format("reduce", metric),
                                               penalty[r_i] / num_reduce,
                                               global_step)
                '''
                # normalized penalties (disabled in the original source)
                self.logger.add_scalar("iter_normal_size_normalized", penalty[17] / 6, global_step)
                self.logger.add_scalar("iter_normal_flops_normalized", penalty[18] / 6, global_step)
                self.logger.add_scalar("iter_normal_mac_normalized", penalty[19] / 6, global_step)
                self.logger.add_scalar("iter_reduce_size_normalized", penalty[20] / 2, global_step)
                self.logger.add_scalar("iter_reduce_flops_normalized", penalty[21] / 2, global_step)
                self.logger.add_scalar("iter_reduce_mac_normalized", penalty[22] / 2, global_step)
                self.logger.add_scalar("iter_normal_penalty_normalized", penalty[23] / 6, global_step)
                self.logger.add_scalar("iter_reduce_penalty_normalized", penalty[24] / 2, global_step)
                '''

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            if self.args.distributed:
                loss = loss.detach()
                dist.all_reduce(error_loss)
                dist.all_reduce(prec1)
                dist.all_reduce(prec5)
                prec1.div_(self.world_size)
                prec5.div_(self.world_size)
                # dist_util.all_reduce([loss, prec1, prec5], 'mean')
            objs.update(error_loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % self.args.report_freq == 0 and self.rank == 0:
                logging.info('train %03d %e %f %f', step, objs.avg, top1.avg,
                             top5.avg)
                self.logger.add_scalar(
                    "iter_train_top1_acc", top1.avg,
                    step + len(self.train_queue.dataset) * epoch)

        if self.rank == 0:
            logging.info('-------resource gradient--------')
            logging.info(normal_resource_gradient / count)
            logging.info(reduce_resource_gradient / count)
            logging.info('-------loss gradient--------')
            logging.info(normal_loss_gradient / count)
            logging.info(reduce_loss_gradient / count)
            logging.info('-------total gradient--------')
            logging.info(normal_total_gradient / count)
            logging.info(reduce_total_gradient / count)

        return top1.avg, loss, error_loss, loss_alpha

    def infer(self, epoch):
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()
        self.model.eval()

        with torch.no_grad():
            for step, (input, target) in enumerate(self.valid_queue):
                input = input.to(self.device)
                target = target.to(self.device)

                if self.args.snas:
                    logits, logits_aux, resource_loss, op_normal, op_reduce = self.model(
                        input)
                    loss = self.criterion(logits, target)
                elif self.args.dsnas:
                    logits, error_loss, loss_alpha, resource_loss = self.model(
                        input, target, self.criterion)
                    loss = error_loss

                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                if self.args.distributed:
                    loss.div_(self.world_size)
                    loss = loss.detach()
                    dist.all_reduce(loss)
                    dist.all_reduce(prec1)
                    dist.all_reduce(prec5)
                    prec1.div_(self.world_size)
                    prec5.div_(self.world_size)
                objs.update(loss.item(), input.size(0))
                top1.update(prec1.item(), input.size(0))
                top5.update(prec5.item(), input.size(0))

                if step % self.args.report_freq == 0 and self.rank == 0:
                    logging.info('valid %03d %e %f %f', step, objs.avg,
                                 top1.avg, top5.avg)
                    self.logger.add_scalar(
                        "iter_valid_loss", loss,
                        step + len(self.valid_queue.dataset) * epoch)
                    self.logger.add_scalar(
                        "iter_valid_top1_acc", top1.avg,
                        step + len(self.valid_queue.dataset) * epoch)

        return top1.avg, objs.avg
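# ---------------------------------------------------------------------------
# Reference sketch: main() above describes utils.Temp_Scheduler as "linear
# annealing (self-defined in utils)". The stand-in below matches the call
# sites (constructor takes total_epochs, the model's current temperature, the
# base temperature and temp_min; .step() is called once per epoch and returns
# the annealed temperature). It is an illustrative assumption, not the
# repository's actual implementation.
class LinearTempSchedulerSketch(object):
    """Linearly anneals a softmax temperature from base_temp to temp_min."""

    def __init__(self, total_epochs, curr_temp, base_temp, temp_min=0.33):
        self.total_epochs = total_epochs
        self.curr_temp = curr_temp      # current temperature (starting value)
        self.base_temp = base_temp      # temperature at epoch 0
        self.temp_min = temp_min        # floor reached at the final epoch
        self.last_epoch = 0

    def step(self):
        # interpolate linearly from base_temp down to temp_min over
        # total_epochs, then hold at temp_min
        frac = min(self.last_epoch / float(self.total_epochs), 1.0)
        self.curr_temp = max(
            self.base_temp - (self.base_temp - self.temp_min) * frac,
            self.temp_min)
        self.last_epoch += 1
        return self.curr_temp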
# ---------------------------------------------------------------------------
# Second variant of the search driver: single-GPU only (no distributed or
# resource-penalty paths), with optional checkpoint resume and child-reward
# statistics.
class neural_architecture_search():

    def __init__(self, args):
        self.args = args

        if not torch.cuda.is_available():
            logging.info('no gpu device available')
            sys.exit(1)

        torch.cuda.set_device(self.args.gpu)
        self.device = torch.device("cuda")
        self.rank = 0
        self.seed = self.args.seed
        self.world_size = 1

        if self.args.fix_cudnn:
            # fully deterministic cuDNN: slower, but reproducible
            random.seed(self.seed)
            torch.backends.cudnn.deterministic = True
            np.random.seed(self.seed)
            cudnn.benchmark = False
            torch.manual_seed(self.seed)
            cudnn.enabled = True
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)
        else:
            np.random.seed(self.seed)
            cudnn.benchmark = True
            torch.manual_seed(self.seed)
            cudnn.enabled = True
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)

        self.path = os.path.join(generate_date, self.args.save)
        if self.rank == 0:
            utils.create_exp_dir(generate_date, self.path,
                                 scripts_to_save=glob.glob('*.py'))
            logging.basicConfig(stream=sys.stdout,
                                level=logging.INFO,
                                format=log_format,
                                datefmt='%m/%d %I:%M:%S %p')
            fh = logging.FileHandler(os.path.join(self.path, 'log.txt'))
            fh.setFormatter(logging.Formatter(log_format))
            logging.getLogger().addHandler(fh)
            logging.info("self.args = %s", self.args)
            self.logger = tensorboardX.SummaryWriter('./runs/' +
                                                     generate_date + '/' +
                                                     self.args.save_log)
        else:
            self.logger = None

        # initialize loss function
        self.criterion = nn.CrossEntropyLoss().to(self.device)

        # initialize model
        self.init_model()
        if self.args.resume:
            self.reload_model()

        # calculate model param size
        if self.rank == 0:
            logging.info("param size = %fMB",
                         utils.count_parameters_in_MB(self.model))
        self.model._logger = self.logger
        self.model._logging = logging

        # initialize optimizer
        self.init_optimizer()

        # initialize dataset loaders
        self.init_loaddata()

        self.update_theta = True
        self.update_alpha = True

    def init_model(self):
        self.model = Network(self.args.init_channels, CIFAR_CLASSES,
                             self.args.layers, self.criterion, self.args,
                             self.rank, self.world_size, self.args.steps,
                             self.args.multiplier)
        self.model.to(self.device)
        # pre-allocate zero grads for all trainable tensors
        for v in self.model.parameters():
            if v.requires_grad:
                if v.grad is None:
                    v.grad = torch.zeros_like(v)
        self.model.normal_log_alpha.grad = torch.zeros_like(
            self.model.normal_log_alpha)
        self.model.reduce_log_alpha.grad = torch.zeros_like(
            self.model.reduce_log_alpha)

    def reload_model(self):
        self.model.load_state_dict(torch.load(self.args.resume_path +
                                              '/weights.pt'),
                                   strict=True)

    def init_optimizer(self):
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         self.args.learning_rate,
                                         momentum=self.args.momentum,
                                         weight_decay=self.args.weight_decay)
        self.arch_optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            lr=self.args.arch_learning_rate,
            betas=(0.5, 0.999),
            weight_decay=self.args.arch_weight_decay)

    def init_loaddata(self):
        train_transform, valid_transform = utils._data_transforms_cifar10(
            self.args)
        train_data = dset.CIFAR10(root=self.args.data,
                                  train=True,
                                  download=True,
                                  transform=train_transform)
        valid_data = dset.CIFAR10(root=self.args.data,
                                  train=False,
                                  download=True,
                                  transform=valid_transform)

        if self.args.seed:
            # must accept the worker id; re-seeds every dataloader worker and
            # is passed to the loaders below so the seed actually takes effect
            def worker_init_fn(worker_id):
                seed = self.seed
                np.random.seed(seed)
                random.seed(seed)
                torch.manual_seed(seed)
        else:
            worker_init_fn = None

        num_train = len(train_data)
        indices = list(range(num_train))

        self.train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=self.args.batch_size,
            shuffle=True,
            pin_memory=False,
            num_workers=2,
            worker_init_fn=worker_init_fn)
        self.valid_queue = torch.utils.data.DataLoader(
            valid_data,
            batch_size=self.args.batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=2,
            worker_init_fn=worker_init_fn)

    def main(self):
        # lr scheduler: cosine annealing
        # temp scheduler: linear annealing (self-defined in utils)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            float(self.args.epochs),
            eta_min=self.args.learning_rate_min)
        self.temp_scheduler = utils.Temp_Scheduler(self.args.epochs,
                                                   self.model._temp,
                                                   self.args.temp,
                                                   temp_min=self.args.temp_min)

        for epoch in range(self.args.epochs):
            if self.args.child_reward_stat:
                # only collect reward statistics: freeze both the weights
                # (theta) and the architecture parameters (alpha)
                self.update_theta = False
                self.update_alpha = False
            if self.args.current_reward:
                self.model.normal_reward_mean = torch.zeros_like(
                    self.model.normal_reward_mean)
                self.model.reduce_reward_mean = torch.zeros_like(
                    self.model.reduce_reward_mean)
                self.model.count = 0
            if epoch < self.args.resume_epoch:
                continue

            self.scheduler.step()
            if self.args.temp_annealing:
                self.model._temp = self.temp_scheduler.step()
            self.lr = self.scheduler.get_lr()[0]

            if self.rank == 0:
                logging.info('epoch %d lr %e temp %e', epoch, self.lr,
                             self.model._temp)
                self.logger.add_scalar('epoch_temp', self.model._temp, epoch)
                logging.info(self.model.normal_log_alpha)
                logging.info(self.model.reduce_log_alpha)
                logging.info(F.softmax(self.model.normal_log_alpha, dim=-1))
                logging.info(F.softmax(self.model.reduce_log_alpha, dim=-1))

            genotype_edge_all = self.model.genotype_edge_all()

            if self.rank == 0:
                logging.info('genotype_edge_all = %s', genotype_edge_all)
                # create genotypes.txt file
                txt_name = self.args.remark + '_genotype_edge_all_epoch' + str(
                    epoch)
                utils.txt('genotype', self.args.save, txt_name,
                          str(genotype_edge_all), generate_date)

            self.model.train()
            train_acc, loss, error_loss, loss_alpha = self.train(
                epoch, logging)

            if self.rank == 0:
                logging.info('train_acc %f', train_acc)
                self.logger.add_scalar("epoch_train_acc", train_acc, epoch)
                self.logger.add_scalar("epoch_train_error_loss", error_loss,
                                       epoch)
                if self.args.dsnas:
                    self.logger.add_scalar("epoch_train_alpha_loss",
                                           loss_alpha, epoch)

            if self.args.dsnas and not self.args.child_reward_stat:
                if self.args.current_reward:
                    logging.info('reward mean stat')
                    logging.info(self.model.normal_reward_mean)
                    logging.info(self.model.reduce_reward_mean)
                    logging.info('count')
                    logging.info(self.model.count)
                else:
                    logging.info('reward mean stat')
                    logging.info(self.model.normal_reward_mean)
                    logging.info(self.model.reduce_reward_mean)
                    if self.model.normal_reward_mean.size(0) > 1:
                        logging.info('reward mean total stat')
                        logging.info(self.model.normal_reward_mean.sum(0))
                        logging.info(self.model.reduce_reward_mean.sum(0))

            if self.args.child_reward_stat:
                logging.info('reward mean stat')
                logging.info(self.model.normal_reward_mean.sum(0))
                logging.info(self.model.reduce_reward_mean.sum(0))
                logging.info('reward var stat')
                logging.info(self.model.normal_reward_mean_square.sum(0) -
                             self.model.normal_reward_mean.sum(0)**2)
                logging.info(self.model.reduce_reward_mean_square.sum(0) -
                             self.model.reduce_reward_mean.sum(0)**2)

            # validation
            self.model.eval()
            valid_acc, valid_obj = self.infer(epoch)
            if self.args.gen_max_child:
                # also evaluate the argmax (most likely) child network
                self.args.gen_max_child_flag = True
                valid_acc_max_child, valid_obj_max_child = self.infer(epoch)
                self.args.gen_max_child_flag = False

            if self.rank == 0:
                logging.info('valid_acc %f', valid_acc)
                self.logger.add_scalar("epoch_valid_acc", valid_acc, epoch)
                if self.args.gen_max_child:
                    logging.info('valid_acc_argmax_alpha %f',
                                 valid_acc_max_child)
                    self.logger.add_scalar("epoch_valid_acc_argmax_alpha",
                                           valid_acc_max_child, epoch)

                utils.save(self.model, os.path.join(self.path, 'weights.pt'))

        if self.rank == 0:
            logging.info(self.model.normal_log_alpha)
            logging.info(self.model.reduce_log_alpha)
            genotype_edge_all = self.model.genotype_edge_all()
            logging.info('genotype_edge_all = %s', genotype_edge_all)

    def train(self, epoch, logging):
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()
        grad = utils.AvgrageMeter()

        normal_loss_gradient = 0
        reduce_loss_gradient = 0
        normal_total_gradient = 0
        reduce_total_gradient = 0

        loss_alpha = None
        train_correct_count = 0
        train_correct_cost = 0
        train_correct_entropy = 0
        train_correct_loss = 0
        train_wrong_count = 0
        train_wrong_cost = 0
        train_wrong_entropy = 0
        train_wrong_loss = 0

        count = 0
        for step, (input, target) in enumerate(self.train_queue):
            n = input.size(0)
            input = input.to(self.device)
            target = target.to(self.device, non_blocking=True)

            if self.args.snas:
                logits, logits_aux = self.model(input)
                error_loss = self.criterion(logits, target)
                if self.args.auxiliary:
                    loss_aux = self.criterion(logits_aux, target)
                    error_loss += self.args.auxiliary_weight * loss_aux

            if self.args.dsnas:
                logits, error_loss, loss_alpha = self.model(
                    input, target, self.criterion,
                    update_theta=self.update_theta,
                    update_alpha=self.update_alpha)
                # per-sample loss/entropy/cost statistics, split by whether
                # the top-1 prediction is correct (softmax dim made explicit;
                # the same tensor is reused for cost, entropy and loss)
                for i in range(logits.size(0)):
                    index = logits[i].topk(5, 0, True, True)[1]
                    discrete_prob = F.softmax(logits[i], dim=-1)
                    if index[0].item() == target[i].item():
                        train_correct_cost += (
                            -logits[i, target[i].item()] +
                            (discrete_prob * logits[i]).sum())
                        train_correct_count += 1
                        train_correct_entropy += -(
                            discrete_prob * torch.log(discrete_prob)).sum(-1)
                        train_correct_loss += -torch.log(
                            discrete_prob)[target[i].item()]
                    else:
                        train_wrong_cost += (
                            -logits[i, target[i].item()] +
                            (discrete_prob * logits[i]).sum())
                        train_wrong_count += 1
                        train_wrong_entropy += -(
                            discrete_prob * torch.log(discrete_prob)).sum(-1)
                        train_wrong_loss += -torch.log(
                            discrete_prob)[target[i].item()]

            num_normal = self.model.num_normal
            num_reduce = self.model.num_reduce

            if self.args.snas or self.args.dsnas:
                loss = error_loss.clone()

            #self.update_lr()

            # logging gradient
            count += 1
            if self.args.snas:
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()
                error_loss.backward(retain_graph=True)
                if not self.args.random_sample:
                    normal_loss_gradient += self.model.normal_log_alpha.grad
                    reduce_loss_gradient += self.model.reduce_log_alpha.grad
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()

            if self.args.snas and (not self.args.random_sample
                                   and not self.args.dsnas):
                loss.backward()
                if not self.args.random_sample:
                    normal_total_gradient += self.model.normal_log_alpha.grad
                    reduce_total_gradient += self.model.reduce_log_alpha.grad

            nn.utils.clip_grad_norm_(self.model.parameters(),
                                     self.args.grad_clip)
            arch_grad_norm = nn.utils.clip_grad_norm_(
                self.model.arch_parameters(), 10.)
            grad.update(arch_grad_norm)

            if not self.args.fix_weight and self.update_theta:
                self.optimizer.step()
            self.optimizer.zero_grad()
            if not self.args.random_sample and self.update_alpha:
                self.arch_optimizer.step()
            self.arch_optimizer.zero_grad()

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            objs.update(error_loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % self.args.report_freq == 0 and self.rank == 0:
                logging.info('train %03d %e %f %f', step, objs.avg, top1.avg,
                             top5.avg)
                self.logger.add_scalar(
                    "iter_train_top1_acc", top1.avg,
                    step + len(self.train_queue.dataset) * epoch)

        if self.rank == 0:
            logging.info('-------loss gradient--------')
            logging.info(normal_loss_gradient / count)
            logging.info(reduce_loss_gradient / count)
            logging.info('-------total gradient--------')
            logging.info(normal_total_gradient / count)
            logging.info(reduce_total_gradient / count)
            logging.info('correct loss ')
            logging.info((train_correct_loss / train_correct_count).item())
            logging.info('correct entropy ')
            logging.info((train_correct_entropy / train_correct_count).item())
            logging.info('correct cost ')
            logging.info((train_correct_cost / train_correct_count).item())
            logging.info('correct count ')
            logging.info(train_correct_count)
            logging.info('wrong loss ')
            logging.info((train_wrong_loss / train_wrong_count).item())
            logging.info('wrong entropy ')
            logging.info((train_wrong_entropy / train_wrong_count).item())
            logging.info('wrong cost ')
            logging.info((train_wrong_cost / train_wrong_count).item())
            logging.info('wrong count ')
            logging.info(train_wrong_count)
            logging.info('total loss ')
            logging.info(((train_correct_loss + train_wrong_loss) /
                          (train_correct_count + train_wrong_count)).item())
            logging.info('total entropy ')
            logging.info(((train_correct_entropy + train_wrong_entropy) /
                          (train_correct_count + train_wrong_count)).item())
            logging.info('total cost ')
            logging.info(((train_correct_cost + train_wrong_cost) /
                          (train_correct_count + train_wrong_count)).item())
            logging.info('total count ')
            logging.info(train_correct_count + train_wrong_count)

        return top1.avg, loss, error_loss, loss_alpha

    def infer(self, epoch):
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()
        self.model.eval()

        with torch.no_grad():
            for step, (input, target) in enumerate(self.valid_queue):
                input = input.to(self.device)
                target = target.to(self.device)

                if self.args.snas:
                    logits, logits_aux = self.model(input)
                    loss = self.criterion(logits, target)
                elif self.args.dsnas:
                    logits, error_loss, loss_alpha = self.model(
                        input, target, self.criterion)
                    loss = error_loss

                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                objs.update(loss.item(), input.size(0))
                top1.update(prec1.item(), input.size(0))
                top5.update(prec5.item(), input.size(0))

                if step % self.args.report_freq == 0 and self.rank == 0:
                    logging.info('valid %03d %e %f %f', step, objs.avg,
                                 top1.avg, top5.avg)
                    self.logger.add_scalar(
                        "iter_valid_loss", loss,
                        step + len(self.valid_queue.dataset) * epoch)
                    self.logger.add_scalar(
                        "iter_valid_top1_acc", top1.avg,
                        step + len(self.valid_queue.dataset) * epoch)

        return top1.avg, objs.avg
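# ---------------------------------------------------------------------------
# Usage sketch: how a driver script might invoke this class. `parse_args()` is
# a hypothetical stand-in for the argparse setup that lives elsewhere in the
# repository (flags such as gpu, seed, epochs, snas/dsnas, batch_size); the
# calls below follow the entry points defined above.
if __name__ == '__main__':
    args = parse_args()                     # hypothetical argparse helper
    nas = neural_architecture_search(args)  # seeds RNGs, builds model, optimizers, loaders
    nas.main()                              # per epoch: train(), then infer() on the valid set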