Example #1
    def search(self, train_x, train_y, valid_x, valid_y, metadata):

        np.random.seed(self.seed)
        cudnn.benchmark = True
        torch.manual_seed(self.seed)
        cudnn.enabled = True
        torch.cuda.manual_seed(self.seed)
        is_multi_gpu = False

        helper_function()
        n_classes = metadata['n_classes']

        # check that a GPU is available
        if not torch.cuda.is_available():
            logging.info('no gpu device available')
            sys.exit(1)

        # loading criterion
        criterion = nn.CrossEntropyLoss()
        criterion = criterion.cuda()

        train_pack = list(zip(train_x, train_y))
        valid_pack = list(zip(valid_x, valid_y))

        data_channel = np.array(train_x).shape[1]

        train_loader = torch.utils.data.DataLoader(train_pack,
                                                   int(self.batch_size),
                                                   pin_memory=True,
                                                   num_workers=4)
        valid_loader = torch.utils.data.DataLoader(valid_pack,
                                                   int(self.batch_size),
                                                   pin_memory=True,
                                                   num_workers=4)

        model = Network(self.init_channels, data_channel, n_classes,
                        self.layers, criterion)
        model = model.cuda()

        # the submission server does not support multi-GPU,
        # so is_multi_gpu stays False
        if is_multi_gpu:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
            model = nn.DataParallel(model)

        # exclude the architecture parameters (alphas) from the weight
        # optimizer by filtering on parameter ids
        if is_multi_gpu:
            arch_parameters = model.module.arch_parameters()
            parameters = model.module.parameters()
        else:
            arch_parameters = model.arch_parameters()
            parameters = model.parameters()
        arch_params = list(map(id, arch_parameters))
        weight_params = filter(lambda p: id(p) not in arch_params, parameters)

        optimizer = torch.optim.SGD(weight_params,
                                    self.learning_rate,
                                    momentum=self.momentum,
                                    weight_decay=self.weight_decay)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(self.epochs), eta_min=self.learning_rate_min)
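        # CosineAnnealingLR above anneals the learning rate as
        #   lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2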

        architect = Architect(is_multi_gpu, model, criterion, self.momentum,
                              self.weight_decay, self.arch_learning_rate,
                              self.arch_weight_decay)

        best_accuracy = 0
        best_accuracy_different_cnn_counts = dict()

        for epoch in range(self.epochs):
            lr = scheduler.get_lr()[0]
            logging.info('epoch %d lr %e', epoch, lr)

            # training
            objs = utils.AvgrageMeter()
            top1 = utils.AvgrageMeter()
            top5 = utils.AvgrageMeter()

            train_batch = time.time()

            for step, (input, target) in enumerate(train_loader):

                # logging.info("epoch %d, step %d START" % (epoch, step))
                model.train()
                n = input.size(0)

                input = input.cuda()
                target = target.cuda()

                # intended to draw a random minibatch from the search queue;
                # note iter() is rebuilt each step, so without shuffling this
                # always returns the first validation batch
                input_search, target_search = next(iter(valid_loader))
                input_search = input_search.cuda()
                target_search = target_search.cuda()

                # Update architecture alpha by Adam-SGD
                # logging.info("step %d. update architecture by Adam. START" % step)
                # if args.optimization == "DARTS":
                #     architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)
                # else:
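                # MiLeNAS-style mixed-level step: the architecture update
                # uses both the current training batch and a validation
                # batch; the trailing (1, 1) arguments appear to weight the
                # two gradient terms (an assumption; they are undocumented
                # in this snippet)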
                architect.step_milenas_2ndorder(input, target, input_search,
                                                target_search, lr, optimizer,
                                                1, 1)

                # logging.info("step %d. update architecture by Adam. FINISH" % step)
                # Update weights w by SGD, ignore the weights that gained during architecture training

                # logging.info("step %d. update weight by SGD. START" % step)
                optimizer.zero_grad()
                logits = model(input)
                loss = criterion(logits, target)

                loss.backward()
                # note: as in the original code, this clips the architecture
                # parameters even though the optimizer below steps the
                # network weights
                arch_parameters = (model.module.arch_parameters()
                                   if is_multi_gpu else model.arch_parameters())
                nn.utils.clip_grad_norm_(arch_parameters, self.grad_clip)
                optimizer.step()

                # logging.info("step %d. update weight by SGD. FINISH\n" % step)

                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                objs.update(loss.item(), n)
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)

                # torch.cuda.empty_cache()

                if step % self.report_freq == 0:
                    average_batch_t = (time.time() - train_batch) / (step + 1)
                    print("Epoch: {}, Step: {}, Top1: {}, Top5: {}, T: {}".
                          format(
                              epoch, step, top1.avg, top5.avg,
                              show_time(average_batch_t *
                                        (len(train_loader) - step))))

            model.eval()

            # validation
            with torch.no_grad():
                objs = utils.AvgrageMeter()
                top1 = utils.AvgrageMeter()
                top5 = utils.AvgrageMeter()

                for step, (input, target) in enumerate(valid_loader):
                    input = input.cuda()
                    target = target.cuda()

                    logits = model(input)
                    loss = criterion(logits, target)

                    prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                    n = input.size(0)
                    objs.update(loss.item(), n)
                    top1.update(prec1.item(), n)
                    top5.update(prec5.item(), n)

                    if step % self.report_freq == 0:
                        print("Epoch: {}, Step: {}, Top1: {}, Top5: {}".format(
                            epoch, step, top1.avg, top5.avg))

            scheduler.step()

            # save the structure
            genotype, normal_cnn_count, reduce_cnn_count = (
                model.module.genotype() if is_multi_gpu else model.genotype())
            print("(n:%d,r:%d)" % (normal_cnn_count, reduce_cnn_count))
            # print(F.softmax(model.module.alphas_normal if is_multi_gpu else model.alphas_normal, dim=-1))
            # print(F.softmax(model.module.alphas_reduce if is_multi_gpu else model.alphas_reduce, dim=-1))
            # logging.info('genotype = %s', genotype)

        return model
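
A minimal driver sketch for the search() method above. The data shapes and the `searcher` object are assumptions for illustration; the method only needs array-like inputs in NCHW layout plus a metadata dict carrying 'n_classes'.

import numpy as np

# tiny random dataset; search() wraps these in DataLoaders itself
train_x = np.random.rand(512, 3, 32, 32).astype(np.float32)
train_y = np.random.randint(0, 10, size=512)
valid_x = np.random.rand(128, 3, 32, 32).astype(np.float32)
valid_y = np.random.randint(0, 10, size=128)
metadata = {'n_classes': 10}

# `searcher` stands for whatever object carries the hyperparameters the
# method reads (seed, batch_size, init_channels, layers, learning_rate,
# momentum, weight_decay, learning_rate_min, epochs, grad_clip,
# report_freq); that class is not shown in the example.
# model = searcher.search(train_x, train_y, valid_x, valid_y, metadata)
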
Example #2
class neural_architecture_search():
    def __init__(self, args):
        self.args = args

        if not torch.cuda.is_available():
            logging.info('no gpu device available')
            sys.exit(1)

        if self.args.distributed:
            # Init distributed environment
            self.rank, self.world_size, self.device = init_dist(
                port=self.args.port)
            self.seed = self.rank * self.args.seed
        else:
            torch.cuda.set_device(self.args.gpu)
            self.device = torch.device("cuda")
            self.rank = 0
            self.seed = self.args.seed
            self.world_size = 1

        if self.args.fix_seedcudnn:
            random.seed(self.seed)
            torch.backends.cudnn.deterministic = True
            np.random.seed(self.seed)
            cudnn.benchmark = False
            torch.manual_seed(self.seed)
            cudnn.enabled = True
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)
        else:
            np.random.seed(self.seed)
            cudnn.benchmark = True
            torch.manual_seed(self.seed)
            cudnn.enabled = True
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)

        self.path = os.path.join(generate_date, self.args.save)
        if self.rank == 0:
            utils.create_exp_dir(generate_date,
                                 self.path,
                                 scripts_to_save=glob.glob('*.py'))
            logging.basicConfig(stream=sys.stdout,
                                level=logging.INFO,
                                format=log_format,
                                datefmt='%m/%d %I:%M:%S %p')
            fh = logging.FileHandler(os.path.join(self.path, 'log.txt'))
            fh.setFormatter(logging.Formatter(log_format))
            logging.getLogger().addHandler(fh)
            logging.info("self.args = %s", self.args)
            self.logger = tensorboardX.SummaryWriter(
                './runs/' + generate_date + '/nas_{}'.format(self.args.remark))
        else:
            self.logger = None

        # set the default resource_lambda for each method
        if self.args.resource_efficient:
            if self.args.method == 'policy_gradient':
                default_resource_lambda = 1e-4 if self.args.log_penalty else 1e-5
            elif self.args.method == 'reparametrization':
                default_resource_lambda = 1e-2 if self.args.log_penalty else 1e-5
            elif self.args.method == 'discrete':
                default_resource_lambda = 1e-2 if self.args.log_penalty else 1e-4
            # default_lambda is assumed to be a module-level sentinel for an
            # unset --resource_lambda
            if self.args.resource_lambda == default_lambda:
                self.args.resource_lambda = default_resource_lambda

        #initialize loss function
        self.criterion = nn.CrossEntropyLoss().to(self.device)

        #initialize model
        self.init_model()

        #calculate model param size
        if self.rank == 0:
            logging.info("param size = %fMB",
                         utils.count_parameters_in_MB(self.model))
            self.model._logger = self.logger
            self.model._logging = logging

        #initialize optimizer
        self.init_optimizer()

        #initialize dataset loader
        self.init_loaddata()

        self.update_theta = True
        self.update_alpha = True

    def init_model(self):

        self.model = Network(self.args.init_channels, CIFAR_CLASSES,
                             self.args.layers, self.criterion, self.args,
                             self.rank, self.world_size)
        self.model.to(self.device)
        if self.args.distributed:
            broadcast_params(self.model)
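        # pre-allocate zero .grad tensors so every parameter participates in
        # gradient all-reduce from the first step (an inference from the
        # distributed setup here, not documented in the source)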
        for v in self.model.parameters():
            if v.requires_grad:
                if v.grad is None:
                    v.grad = torch.zeros_like(v)
        self.model.normal_log_alpha.grad = torch.zeros_like(
            self.model.normal_log_alpha)
        self.model.reduce_log_alpha.grad = torch.zeros_like(
            self.model.reduce_log_alpha)

    def init_optimizer(self):

        if self.args.distributed:
            self.optimizer = torch.optim.SGD(
                [
                    param for name, param in self.model.named_parameters() if
                    name != 'normal_log_alpha' and name != 'reduce_log_alpha'
                ],
                self.args.learning_rate,
                momentum=self.args.momentum,
                weight_decay=self.args.weight_decay)
            self.arch_optimizer = torch.optim.Adam(
                [
                    param for name, param in self.model.named_parameters()
                    if name == 'normal_log_alpha' or name == 'reduce_log_alpha'
                ],
                lr=self.args.arch_learning_rate,
                betas=(0.5, 0.999),
                weight_decay=self.args.arch_weight_decay)
        else:
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             self.args.learning_rate,
                                             momentum=self.args.momentum,
                                             weight_decay=self.args.weight_decay)

            self.arch_optimizer = torch.optim.SGD(
                self.model.arch_parameters(), lr=self.args.arch_learning_rate)

    def init_loaddata(self):

        train_transform, valid_transform = utils._data_transforms_cifar10(
            self.args)
        train_data = dset.CIFAR10(root=self.args.data,
                                  train=True,
                                  download=True,
                                  transform=train_transform)
        valid_data = dset.CIFAR10(root=self.args.data,
                                  train=False,
                                  download=True,
                                  transform=valid_transform)

        if self.args.seed:

            # re-seed each dataloader worker for reproducibility; note the
            # function is defined but never passed to the DataLoaders below
            def worker_init_fn(worker_id=0):
                seed = self.seed
                np.random.seed(seed)
                random.seed(seed)
                torch.manual_seed(seed)
        else:
            worker_init_fn = None

        if self.args.distributed:
            train_sampler = DistributedSampler(train_data)
            valid_sampler = DistributedSampler(valid_data)

            self.train_queue = torch.utils.data.DataLoader(
                train_data,
                batch_size=self.args.batch_size // self.world_size,
                shuffle=False,
                num_workers=0,
                pin_memory=False,
                sampler=train_sampler)
            self.valid_queue = torch.utils.data.DataLoader(
                valid_data,
                batch_size=self.args.batch_size // self.world_size,
                shuffle=False,
                num_workers=0,
                pin_memory=False,
                sampler=valid_sampler)

        else:
            self.train_queue = torch.utils.data.DataLoader(
                train_data,
                batch_size=self.args.batch_size,
                shuffle=True,
                pin_memory=False,
                num_workers=2)

            self.valid_queue = torch.utils.data.DataLoader(
                valid_data,
                batch_size=self.args.batch_size,
                shuffle=False,
                pin_memory=False,
                num_workers=2)

    def main(self):
        # lr scheduler: cosine annealing
        # temp scheduler: linear annealing (self-defined in utils)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            float(self.args.epochs),
            eta_min=self.args.learning_rate_min)

        self.temp_scheduler = utils.Temp_Scheduler(self.args.epochs,
                                                   self.model._temp,
                                                   self.args.temp,
                                                   temp_min=self.args.temp_min)

        for epoch in range(self.args.epochs):
            if self.args.random_sample_pretrain:
                if epoch < self.args.random_sample_pretrain_epoch:
                    self.args.random_sample = True
                else:
                    self.args.random_sample = False

            self.scheduler.step()
            if self.args.temp_annealing:
                self.model._temp = self.temp_scheduler.step()
            self.lr = self.scheduler.get_lr()[0]

            if self.rank == 0:
                logging.info('epoch %d lr %e temp %e', epoch, self.lr,
                             self.model._temp)
                self.logger.add_scalar('epoch_temp', self.model._temp, epoch)
                logging.info(self.model.normal_log_alpha)
                logging.info(self.model.reduce_log_alpha)
                logging.info(
                    self.model._get_weights(self.model.normal_log_alpha[0]))
                logging.info(
                    self.model._get_weights(self.model.reduce_log_alpha[0]))

            genotype_edge_all = self.model.genotype_edge_all()

            if self.rank == 0:
                logging.info('genotype_edge_all = %s', genotype_edge_all)
                # create genotypes.txt file
                txt_name = self.args.remark + '_genotype_edge_all_epoch' + str(
                    epoch)
                utils.txt('genotype', self.args.save, txt_name,
                          str(genotype_edge_all), generate_date)

            self.model.train()
            train_acc, loss, error_loss, loss_alpha = self.train(
                epoch, logging)
            if self.rank == 0:
                logging.info('train_acc %f', train_acc)
                self.logger.add_scalar("epoch_train_acc", train_acc, epoch)
                self.logger.add_scalar("epoch_train_error_loss", error_loss,
                                       epoch)
                if self.args.dsnas:
                    self.logger.add_scalar("epoch_train_alpha_loss",
                                           loss_alpha, epoch)

            # validation
            self.model.eval()
            valid_acc, valid_obj = self.infer(epoch)
            if self.args.gen_max_child:
                self.args.gen_max_child_flag = True
                valid_acc_max_child, valid_obj_max_child = self.infer(epoch)
                self.args.gen_max_child_flag = False

            if self.rank == 0:
                logging.info('valid_acc %f', valid_acc)
                self.logger.add_scalar("epoch_valid_acc", valid_acc, epoch)
                if self.args.gen_max_child:
                    logging.info('valid_acc_argmax_alpha %f',
                                 valid_acc_max_child)
                    self.logger.add_scalar("epoch_valid_acc_argmax_alpha",
                                           valid_acc_max_child, epoch)

                utils.save(self.model, os.path.join(self.path, 'weights.pt'))

        if self.rank == 0:
            logging.info(self.model.normal_log_alpha)
            logging.info(self.model.reduce_log_alpha)
            genotype_edge_all = self.model.genotype_edge_all()
            logging.info('genotype_edge_all = %s', genotype_edge_all)

    def train(self, epoch, logging):
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()
        grad = utils.AvgrageMeter()

        normal_resource_gradient = 0
        reduce_resource_gradient = 0
        normal_loss_gradient = 0
        reduce_loss_gradient = 0
        normal_total_gradient = 0
        reduce_total_gradient = 0

        loss_alpha = None

        count = 0
        for step, (input, target) in enumerate(self.train_queue):
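            # alternate_update flips between updating the network weights
            # (theta) on even steps and the architecture parameters (alpha)
            # on odd steps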
            if self.args.alternate_update:
                if step % 2 == 0:
                    self.update_theta = True
                    self.update_alpha = False
                else:
                    self.update_theta = False
                    self.update_alpha = True

            n = input.size(0)
            input = input.to(self.device)
            target = target.to(self.device, non_blocking=True)
            if self.args.snas:
                logits, logits_aux, penalty, op_normal, op_reduce = self.model(
                    input)
                error_loss = self.criterion(logits, target)
                if self.args.auxiliary:
                    loss_aux = self.criterion(logits_aux, target)
                    error_loss += self.args.auxiliary_weight * loss_aux

            if self.args.dsnas:
                logits, error_loss, loss_alpha, penalty = self.model(
                    input, target, self.criterion)

            num_normal = self.model.num_normal
            num_reduce = self.model.num_reduce
            normal_arch_entropy = self.model._arch_entropy(
                self.model.normal_log_alpha)
            reduce_arch_entropy = self.model._arch_entropy(
                self.model.reduce_log_alpha)

            if self.args.resource_efficient:
                if self.args.method == 'policy_gradient':
                    resource_penalty = (penalty[2]) / 6 + self.args.ratio * (
                        penalty[7]) / 2
                    log_resource_penalty = (
                        penalty[35]) / 6 + self.args.ratio * (penalty[36]) / 2
                elif self.args.method == 'reparametrization':
                    resource_penalty = (penalty[26]) / 6 + self.args.ratio * (
                        penalty[25]) / 2
                    log_resource_penalty = (
                        penalty[37]) / 6 + self.args.ratio * (penalty[38]) / 2
                elif self.args.method == 'discrete':
                    resource_penalty = (penalty[28]) / 6 + self.args.ratio * (
                        penalty[27]) / 2
                    log_resource_penalty = (
                        penalty[39]) / 6 + self.args.ratio * (penalty[40]) / 2
                elif self.args.method == 'none':
                    # TODO
                    resource_penalty = torch.zeros(1).cuda()
                    log_resource_penalty = torch.zeros(1).cuda()
                else:
                    logging.info(
                        "invalid --method; choose from 'policy_gradient', "
                        "'discrete', 'reparametrization', 'none'")
                    sys.exit(1)
            else:
                resource_penalty = torch.zeros(1).cuda()
                log_resource_penalty = torch.zeros(1).cuda()

            if self.args.log_penalty:
                resource_loss = self.model._resource_lambda * log_resource_penalty
            else:
                resource_loss = self.model._resource_lambda * resource_penalty

            if self.args.loss:
                if self.args.snas:
                    loss = resource_loss.clone() + error_loss.clone()
                elif self.args.dsnas:
                    loss = resource_loss.clone()
                else:
                    # child_coef and *_one_hot_prob come from elsewhere in
                    # the original source; they are not defined in this
                    # snippet
                    loss = resource_loss.clone() + -child_coef * (
                        torch.log(normal_one_hot_prob) +
                        torch.log(reduce_one_hot_prob)).sum()
            else:
                if self.args.snas or self.args.dsnas:
                    loss = error_loss.clone()

            if self.args.distributed:
                loss.div_(self.world_size)
                error_loss.div_(self.world_size)
                resource_loss.div_(self.world_size)
                if self.args.dsnas:
                    loss_alpha.div_(self.world_size)

            # logging gradient
            count += 1
            if self.args.resource_efficient:
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()
                resource_loss.backward(retain_graph=True)
                if not self.args.random_sample:
                    normal_resource_gradient += self.model.normal_log_alpha.grad
                    reduce_resource_gradient += self.model.reduce_log_alpha.grad
            if self.args.snas:
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()
                error_loss.backward(retain_graph=True)
                if not self.args.random_sample:
                    normal_loss_gradient += self.model.normal_log_alpha.grad
                    reduce_loss_gradient += self.model.reduce_log_alpha.grad
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()

            if self.args.snas or (not self.args.random_sample
                                  and not self.args.dsnas):
                loss.backward()
            if not self.args.random_sample:
                normal_total_gradient += self.model.normal_log_alpha.grad
                reduce_total_gradient += self.model.reduce_log_alpha.grad

            if self.args.distributed:
                reduce_tensorgradients(self.model.parameters(), sync=True)
                nn.utils.clip_grad_norm_([
                    param for name, param in self.model.named_parameters() if
                    name != 'normal_log_alpha' and name != 'reduce_log_alpha'
                ], self.args.grad_clip)
                arch_grad_norm = nn.utils.clip_grad_norm_([
                    param for name, param in self.model.named_parameters()
                    if name == 'normal_log_alpha' or name == 'reduce_log_alpha'
                ], 10.)
            else:
                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         self.args.grad_clip)
                arch_grad_norm = nn.utils.clip_grad_norm_(
                    self.model.arch_parameters(), 10.)

            grad.update(arch_grad_norm)
            if not self.args.fix_weight and self.update_theta:
                self.optimizer.step()
            self.optimizer.zero_grad()
            if not self.args.random_sample and self.update_alpha:
                self.arch_optimizer.step()
            self.arch_optimizer.zero_grad()

            if self.rank == 0:
                global_step = step + len(self.train_queue.dataset) * epoch
                self.logger.add_scalar("iter_train_loss", error_loss,
                                       global_step)
                self.logger.add_scalar("normal_arch_entropy",
                                       normal_arch_entropy, global_step)
                self.logger.add_scalar("reduce_arch_entropy",
                                       reduce_arch_entropy, global_step)
                self.logger.add_scalar(
                    "total_arch_entropy",
                    normal_arch_entropy + reduce_arch_entropy, global_step)
                if self.args.dsnas:
                    # per-edge rewards for the 14 edges of each cell type
                    for i in range(14):
                        self.logger.add_scalar(
                            "reward_normal_edge_%d" % i,
                            self.model.normal_edge_reward[i], global_step)
                        self.logger.add_scalar(
                            "reward_reduce_edge_%d" % i,
                            self.model.reduce_edge_reward[i], global_step)
                # resource statistics, indexed into the penalty tensor and
                # averaged over the ops of the corresponding cell type
                penalty_scalars = [
                    # policy size
                    ("iter_normal_size_policy", 2, num_normal),
                    ("iter_reduce_size_policy", 7, num_reduce),
                    # baseline: discrete_probability
                    ("iter_normal_size_baseline", 3, num_normal),
                    ("iter_normal_flops_baseline", 5, num_normal),
                    ("iter_normal_mac_baseline", 6, num_normal),
                    ("iter_reduce_size_baseline", 8, num_reduce),
                    ("iter_reduce_flops_baseline", 9, num_reduce),
                    ("iter_reduce_mac_baseline", 10, num_reduce),
                    # R - median(R)
                    ("iter_normal_size-avg", 60, num_normal),
                    ("iter_normal_flops-avg", 61, num_normal),
                    ("iter_normal_mac-avg", 62, num_normal),
                    ("iter_reduce_size-avg", 63, num_reduce),
                    ("iter_reduce_flops-avg", 64, num_reduce),
                    ("iter_reduce_mac-avg", 65, num_reduce),
                    # lnR - ln(median)
                    ("iter_normal_ln_size-ln_avg", 66, num_normal),
                    ("iter_normal_ln_flops-ln_avg", 67, num_normal),
                    ("iter_normal_ln_mac-ln_avg", 68, num_normal),
                    ("iter_reduce_ln_size-ln_avg", 69, num_reduce),
                    ("iter_reduce_ln_flops-ln_avg", 70, num_reduce),
                    ("iter_reduce_ln_mac-ln_avg", 71, num_reduce),
                    # normalized variants, disabled in the original:
                    # ("iter_normal_size_normalized", 17, 6),
                    # ("iter_normal_flops_normalized", 18, 6),
                    # ("iter_normal_mac_normalized", 19, 6),
                    # ("iter_reduce_size_normalized", 20, 2),
                    # ("iter_reduce_flops_normalized", 21, 2),
                    # ("iter_reduce_mac_normalized", 22, 2),
                    # ("iter_normal_penalty_normalized", 23, 6),
                    # ("iter_reduce_penalty_normalized", 24, 2),
                    # Monte_Carlo(R_i)
                    ("iter_normal_size_mc", 29, num_normal),
                    ("iter_normal_flops_mc", 30, num_normal),
                    ("iter_normal_mac_mc", 31, num_normal),
                    ("iter_reduce_size_mc", 32, num_reduce),
                    ("iter_reduce_flops_mc", 33, num_reduce),
                    ("iter_reduce_mac_mc", 34, num_reduce),
                    # log(|R_i|)
                    ("iter_normal_log_size", 41, num_normal),
                    ("iter_normal_log_flops", 42, num_normal),
                    ("iter_normal_log_mac", 43, num_normal),
                    ("iter_reduce_log_size", 44, num_reduce),
                    ("iter_reduce_log_flops", 45, num_reduce),
                    ("iter_reduce_log_mac", 46, num_reduce),
                    # log(P)R_i
                    ("iter_normal_logP_size", 47, num_normal),
                    ("iter_normal_logP_flops", 48, num_normal),
                    ("iter_normal_logP_mac", 49, num_normal),
                    ("iter_reduce_logP_size", 50, num_reduce),
                    ("iter_reduce_logP_flops", 51, num_reduce),
                    ("iter_reduce_logP_mac", 52, num_reduce),
                    # log(P)log(R_i)
                    ("iter_normal_logP_log_size", 53, num_normal),
                    ("iter_normal_logP_log_flops", 54, num_normal),
                    ("iter_normal_logP_log_mac", 55, num_normal),
                    ("iter_reduce_logP_log_size", 56, num_reduce),
                    ("iter_reduce_logP_log_flops", 57, num_reduce),
                    ("iter_reduce_logP_log_mac", 58, num_reduce),
                ]
                for name, idx, denom in penalty_scalars:
                    self.logger.add_scalar(name, penalty[idx] / denom,
                                           global_step)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))

            if self.args.distributed:
                loss = loss.detach()
                dist.all_reduce(error_loss)
                dist.all_reduce(prec1)
                dist.all_reduce(prec5)
                prec1.div_(self.world_size)
                prec5.div_(self.world_size)
                #dist_util.all_reduce([loss, prec1, prec5], 'mean')
            objs.update(error_loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % self.args.report_freq == 0 and self.rank == 0:
                logging.info('train %03d %e %f %f', step, objs.avg, top1.avg,
                             top5.avg)
                self.logger.add_scalar("iter_train_top1_acc", top1.avg,
                                       global_step)

        if self.rank == 0:
            logging.info('-------resource gradient--------')
            logging.info(normal_resource_gradient / count)
            logging.info(reduce_resource_gradient / count)
            logging.info('-------loss gradient--------')
            logging.info(normal_loss_gradient / count)
            logging.info(reduce_loss_gradient / count)
            logging.info('-------total gradient--------')
            logging.info(normal_total_gradient / count)
            logging.info(reduce_total_gradient / count)

        return top1.avg, loss, error_loss, loss_alpha

    def infer(self, epoch):
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()

        self.model.eval()
        with torch.no_grad():
            for step, (input, target) in enumerate(self.valid_queue):
                input = input.to(self.device)
                target = target.to(self.device)
                if self.args.snas:
                    logits, logits_aux, resource_loss, op_normal, op_reduce = self.model(
                        input)
                    loss = self.criterion(logits, target)
                elif self.args.dsnas:
                    logits, error_loss, loss_alpha, resource_loss = self.model(
                        input, target, self.criterion)
                    loss = error_loss

                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))

                if self.args.distributed:
                    loss.div_(self.world_size)
                    loss = loss.detach()
                    dist.all_reduce(loss)
                    dist.all_reduce(prec1)
                    dist.all_reduce(prec5)
                    prec1.div_(self.world_size)
                    prec5.div_(self.world_size)
                objs.update(loss.item(), input.size(0))
                top1.update(prec1.item(), input.size(0))
                top5.update(prec5.item(), input.size(0))

                if step % self.args.report_freq == 0 and self.rank == 0:
                    logging.info('valid %03d %e %f %f', step, objs.avg,
                                 top1.avg, top5.avg)
                    self.logger.add_scalar(
                        "iter_valid_loss", loss,
                        step + len(self.valid_queue.dataset) * epoch)
                    self.logger.add_scalar(
                        "iter_valid_top1_acc", top1.avg,
                        step + len(self.valid_queue.dataset) * epoch)

        return top1.avg, objs.avg

Example #3
    def search(self, train_x, train_y, valid_x, valid_y, metadata):

        np.random.seed(self.seed)
        cudnn.benchmark = True
        torch.manual_seed(self.seed)
        cudnn.enabled = True
        torch.cuda.manual_seed(self.seed)

        helpers.helper_function()
        n_classes = metadata['n_classes']

        #         reshape it to this dataset
        #         model = torchvision.models.resnet18()
        #         model.conv1 = nn.Conv2d(train_x.shape[1], 64, kernel_size=(7, 7), stride=1, padding=3)
        #         model.fc = nn.Linear(model.fc.in_features, n_classes, bias=True)
        #         return model

        if not torch.cuda.is_available():
            logging.info('no gpu device available')
            sys.exit(1)

        criterion = nn.CrossEntropyLoss()
        criterion = criterion.cuda()

        model = Network(self.init_channels, n_classes, self.layers, criterion)
        model = model.cuda()

        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

        optimizer = torch.optim.SGD(model.parameters(),
                                    self.learning_rate,
                                    momentum=self.momentum,
                                    weight_decay=self.weight_decay)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(self.epochs), eta_min=self.learning_rate_min)

        architect = Architect(model)

        train_pack = list(zip(train_x, train_y))
        valid_pack = list(zip(valid_x, valid_y))

        train_loader = torch.utils.data.DataLoader(train_pack,
                                                   int(self.batch_size),
                                                   shuffle=False)
        valid_loader = torch.utils.data.DataLoader(valid_pack,
                                                   int(self.batch_size))

        for epoch in range(self.epochs):
            scheduler.step()
            lr = scheduler.get_lr()[0]
            logging.info('epoch %d lr %e', epoch, lr)

            genotype = model.genotype()
            logging.info('genotype = %s', genotype)

            #             print(F.softmax(model.alphas_normal, dim=-1))
            #             print(F.softmax(model.alphas_reduce, dim=-1))

            # training
            print("++++++Start training+++++++")
            for step, (input, target) in enumerate(train_loader):
                model.train()
                n = input.size(0)

                # Variable is a no-op wrapper since PyTorch 0.4; plain
                # tensors are moved to the GPU directly
                input = input.cuda()
                target = target.cuda(non_blocking=True)

                # draw a minibatch from the search queue (see the iterator
                # sketch after this example for a properly random variant)
                input_search, target_search = next(iter(valid_loader))
                input_search = input_search.cuda()
                target_search = target_search.cuda(non_blocking=True)

                architect.step(input,
                               target,
                               input_search,
                               target_search,
                               lr,
                               optimizer,
                               unrolled=self.unrolled)

                optimizer.zero_grad()
                logits = model(input)
                loss = criterion(logits, target)

                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
                optimizer.step()

                if step % self.report_freq == 0:
                    prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                    print(step, loss, prec1, prec5)

            # validation
            print("++++++Start validation+++++++")
            with torch.no_grad():
                for step, (input, target) in enumerate(valid_loader):
                    input = input.cuda()
                    target = target.cuda(non_blocking=True)

                    model.eval()

                    logits = model(input)
                    loss = criterion(logits, target)

                    if step % self.report_freq == 0:
                        prec1, prec5 = utils.accuracy(logits,
                                                      target,
                                                      topk=(1, 5))
                        print(step, loss, prec1, prec5)

        return model
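
Both search() variants above rebuild iter(valid_loader) on every step, so without a shuffling loader they fetch the same first batch rather than a random one. A sketch of the usual fix, a persistent cycling iterator (an addition, not part of the original examples):

def infinite_batches(loader):
    # cycle a DataLoader (or any iterable) forever; with shuffle=True each
    # pass over the data is reshuffled
    while True:
        for batch in loader:
            yield batch

# inside the epoch loop, built once instead of per step:
#     search_iter = infinite_batches(valid_loader)
#     input_search, target_search = next(search_iter)
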
Example #4
)

state_dict = torch.load(args.trained_model,
                        map_location=lambda storage, loc: storage)
print("Pretrained model loading OK...")
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    if "auxiliary" not in k:
        name = k[7:]  # strip the "module." prefix added by nn.DataParallel
        new_state_dict[name] = v
    else:
        # auxiliary-head weights are skipped: the auxiliary loss is only
        # used when retraining
        print("Auxiliary loss is used when retraining.")

net.load_state_dict(new_state_dict)
net.cuda()
net.eval()
print("Finished loading model!")

transform = TestBaseTransform((104, 117, 123))


def preprocess(img):
    x = torch.from_numpy(transform(img)[0]).permute(2, 0, 1)
    x = x.unsqueeze(0).cuda()
    return x
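
# A minimal single-image inference sketch (not in the original snippet);
# cv2 and the detections format are assumptions here:
#
#     import cv2
#     img = cv2.imread("face.png")    # BGR image, as the transform expects
#     x = preprocess(img)
#     with torch.no_grad():
#         detections = net(x)         # output layout depends on the detector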


save_path = args.path + "_res"
os.makedirs(save_path, exist_ok=True)

files = glob.glob(os.path.join(args.path, "*.png")) + glob.glob(

Example #5
class neural_architecture_search():
    def __init__(self, args):
        self.args = args

        if not torch.cuda.is_available():
            logging.info('no gpu device available')
            sys.exit(1)

        torch.cuda.set_device(self.args.gpu)
        self.device = torch.device("cuda")
        self.rank = 0
        self.seed = self.args.seed
        self.world_size = 1

        if self.args.fix_cudnn:
            random.seed(self.seed)
            torch.backends.cudnn.deterministic = True
            np.random.seed(self.seed)
            cudnn.benchmark = False
            torch.manual_seed(self.seed)
            cudnn.enabled = True
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)
        else:
            np.random.seed(self.seed)
            cudnn.benchmark = True
            torch.manual_seed(self.seed)
            cudnn.enabled = True
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)

        self.path = os.path.join(generate_date, self.args.save)
        if self.rank == 0:
            utils.create_exp_dir(generate_date,
                                 self.path,
                                 scripts_to_save=glob.glob('*.py'))
            logging.basicConfig(stream=sys.stdout,
                                level=logging.INFO,
                                format=log_format,
                                datefmt='%m/%d %I:%M:%S %p')
            fh = logging.FileHandler(os.path.join(self.path, 'log.txt'))
            fh.setFormatter(logging.Formatter(log_format))
            logging.getLogger().addHandler(fh)
            logging.info("self.args = %s", self.args)
            self.logger = tensorboardX.SummaryWriter('./runs/' +
                                                     generate_date + '/' +
                                                     self.args.save_log)
        else:
            self.logger = None

        #initialize loss function
        self.criterion = nn.CrossEntropyLoss().to(self.device)

        #initialize model
        self.init_model()
        if self.args.resume:
            self.reload_model()

        #calculate model param size
        if self.rank == 0:
            logging.info("param size = %fMB",
                         utils.count_parameters_in_MB(self.model))
            self.model._logger = self.logger
            self.model._logging = logging

        #initialize optimizer
        self.init_optimizer()

        #iniatilize dataset loader
        self.init_loaddata()

        self.update_theta = True
        self.update_alpha = True

    def init_model(self):

        self.model = Network(self.args.init_channels, CIFAR_CLASSES,
                             self.args.layers, self.criterion, self.args,
                             self.rank, self.world_size, self.args.steps,
                             self.args.multiplier)
        self.model.to(self.device)
        for v in self.model.parameters():
            if v.requires_grad:
                if v.grad is None:
                    v.grad = torch.zeros_like(v)
        self.model.normal_log_alpha.grad = torch.zeros_like(
            self.model.normal_log_alpha)
        self.model.reduce_log_alpha.grad = torch.zeros_like(
            self.model.reduce_log_alpha)

    def reload_model(self):
        self.model.load_state_dict(torch.load(self.args.resume_path +
                                              '/weights.pt'),
                                   strict=True)

    def init_optimizer(self):

        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         self.args.learning_rate,
                                         momentum=self.args.momentum,
                                         weight_decay=self.args.weight_decay)

        self.arch_optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            lr=self.args.arch_learning_rate,
            betas=(0.5, 0.999),
            weight_decay=self.args.arch_weight_decay)

    def init_loaddata(self):

        train_transform, valid_transform = utils._data_transforms_cifar10(
            self.args)
        train_data = dset.CIFAR10(root=self.args.data,
                                  train=True,
                                  download=True,
                                  transform=train_transform)
        valid_data = dset.CIFAR10(root=self.args.data,
                                  train=False,
                                  download=True,
                                  transform=valid_transform)

        if self.args.seed:

            # re-seed each dataloader worker for reproducibility; note the
            # function is defined but never passed to the DataLoaders below
            def worker_init_fn(worker_id=0):
                seed = self.seed
                np.random.seed(seed)
                random.seed(seed)
                torch.manual_seed(seed)
        else:
            worker_init_fn = None

        # num_train and indices are computed but unused in this snippet
        num_train = len(train_data)
        indices = list(range(num_train))

        self.train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=self.args.batch_size,
            shuffle=True,
            pin_memory=False,
            num_workers=2)

        self.valid_queue = torch.utils.data.DataLoader(
            valid_data,
            batch_size=self.args.batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=2)

    def main(self):
        # lr scheduler: cosine annealing
        # temp scheduler: linear annealing (self-defined in utils)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            float(self.args.epochs),
            eta_min=self.args.learning_rate_min)

        self.temp_scheduler = utils.Temp_Scheduler(self.args.epochs,
                                                   self.model._temp,
                                                   self.args.temp,
                                                   temp_min=self.args.temp_min)

        for epoch in range(self.args.epochs):
            if self.args.child_reward_stat:
                self.update_theta = False
                self.update_alpha = False

            if self.args.current_reward:
                self.model.normal_reward_mean = torch.zeros_like(
                    self.model.normal_reward_mean)
                self.model.reduce_reward_mean = torch.zeros_like(
                    self.model.reduce_reward_mean)
                self.model.count = 0

            # note: epochs skipped when resuming do not advance the lr/temp
            # schedulers stepped below
            if epoch < self.args.resume_epoch:
                continue
            self.scheduler.step()
            if self.args.temp_annealing:
                self.model._temp = self.temp_scheduler.step()
            self.lr = self.scheduler.get_lr()[0]

            if self.rank == 0:
                logging.info('epoch %d lr %e temp %e', epoch, self.lr,
                             self.model._temp)
                self.logger.add_scalar('epoch_temp', self.model._temp, epoch)
                logging.info(self.model.normal_log_alpha)
                logging.info(self.model.reduce_log_alpha)
                logging.info(F.softmax(self.model.normal_log_alpha, dim=-1))
                logging.info(F.softmax(self.model.reduce_log_alpha, dim=-1))
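                # softmax over the log-alphas gives each edge's
                # operation-selection probabilities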

            genotype_edge_all = self.model.genotype_edge_all()

            if self.rank == 0:
                logging.info('genotype_edge_all = %s', genotype_edge_all)
                # create genotypes.txt file
                txt_name = remark + '_genotype_edge_all_epoch' + str(epoch)
                utils.txt('genotype', self.args.save, txt_name,
                          str(genotype_edge_all), generate_date)

            self.model.train()
            train_acc, loss, error_loss, loss_alpha = self.train(
                epoch, logging)
            if self.rank == 0:
                logging.info('train_acc %f', train_acc)
                self.logger.add_scalar("epoch_train_acc", train_acc, epoch)
                self.logger.add_scalar("epoch_train_error_loss", error_loss,
                                       epoch)
                if self.args.dsnas:
                    self.logger.add_scalar("epoch_train_alpha_loss",
                                           loss_alpha, epoch)

                if self.args.dsnas and not self.args.child_reward_stat:
                    if self.args.current_reward:
                        logging.info('reward mean stat')
                        logging.info(self.model.normal_reward_mean)
                        logging.info(self.model.reduce_reward_mean)
                        logging.info('count')
                        logging.info(self.model.count)
                    else:
                        logging.info('reward mean stat')
                        logging.info(self.model.normal_reward_mean)
                        logging.info(self.model.reduce_reward_mean)
                        if self.model.normal_reward_mean.size(0) > 1:
                            logging.info('reward mean total stat')
                            logging.info(self.model.normal_reward_mean.sum(0))
                            logging.info(self.model.reduce_reward_mean.sum(0))

                if self.args.child_reward_stat:
                    logging.info('reward mean stat')
                    logging.info(self.model.normal_reward_mean.sum(0))
                    logging.info(self.model.reduce_reward_mean.sum(0))
                    logging.info('reward var stat')
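                    # per-edge variance via the identity
                    # Var[X] = E[X^2] - (E[X])^2, aggregated with .sum(0)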
                    logging.info(
                        self.model.normal_reward_mean_square.sum(0) -
                        self.model.normal_reward_mean.sum(0)**2)
                    logging.info(
                        self.model.reduce_reward_mean_square.sum(0) -
                        self.model.reduce_reward_mean.sum(0)**2)

            # validation
            self.model.eval()
            valid_acc, valid_obj = self.infer(epoch)
            if self.args.gen_max_child:
                self.args.gen_max_child_flag = True
                valid_acc_max_child, valid_obj_max_child = self.infer(epoch)
                self.args.gen_max_child_flag = False

            if self.rank == 0:
                logging.info('valid_acc %f', valid_acc)
                self.logger.add_scalar("epoch_valid_acc", valid_acc, epoch)
                if self.args.gen_max_child:
                    logging.info('valid_acc_argmax_alpha %f',
                                 valid_acc_max_child)
                    self.logger.add_scalar("epoch_valid_acc_argmax_alpha",
                                           valid_acc_max_child, epoch)

                utils.save(self.model, os.path.join(self.path, 'weights.pt'))

        if self.rank == 0:
            logging.info(self.model.normal_log_alpha)
            logging.info(self.model.reduce_log_alpha)
            genotype_edge_all = self.model.genotype_edge_all()
            logging.info('genotype_edge_all = %s', genotype_edge_all)

    def train(self, epoch, logging):
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()
        grad = utils.AvgrageMeter()

        normal_loss_gradient = 0
        reduce_loss_gradient = 0
        normal_total_gradient = 0
        reduce_total_gradient = 0

        loss_alpha = None

        train_correct_count = 0
        train_correct_cost = 0
        train_correct_entropy = 0
        train_correct_loss = 0
        train_wrong_count = 0
        train_wrong_cost = 0
        train_wrong_entropy = 0
        train_wrong_loss = 0

        count = 0
        for step, (input, target) in enumerate(self.train_queue):

            n = input.size(0)
            input = input.to(self.device)
            target = target.to(self.device, non_blocking=True)
            if self.args.snas:
                logits, logits_aux = self.model(input)
                error_loss = self.criterion(logits, target)
                if self.args.auxiliary:
                    loss_aux = self.criterion(logits_aux, target)
                    error_loss += self.args.auxiliary_weight * loss_aux

            if self.args.dsnas:
                logits, error_loss, loss_alpha = self.model(
                    input,
                    target,
                    self.criterion,
                    update_theta=self.update_theta,
                    update_alpha=self.update_alpha)
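                # the dsnas forward above is assumed to compute the task
                # loss and the alpha (policy-gradient) loss internally,
                # gated by update_theta/update_alpha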

            # per-sample stats, split by whether the top-1 prediction is
            # correct: a cost term, predictive entropy and NLL; detached so
            # the autograd graph is not retained across the whole epoch
            for i in range(logits.size(0)):
                index = logits[i].topk(5, 0, True, True)[1]
                discrete_prob = F.softmax(logits[i], dim=-1).detach()
                sample_cost = (-logits[i, target[i].item()] +
                               (discrete_prob * logits[i]).sum()).detach()
                sample_entropy = -(discrete_prob *
                                   torch.log(discrete_prob)).sum(-1)
                sample_loss = -torch.log(discrete_prob)[target[i].item()]
                if index[0].item() == target[i].item():
                    train_correct_count += 1
                    train_correct_cost += sample_cost
                    train_correct_entropy += sample_entropy
                    train_correct_loss += sample_loss
                else:
                    train_wrong_count += 1
                    train_wrong_cost += sample_cost
                    train_wrong_entropy += sample_entropy
                    train_wrong_loss += sample_loss

            num_normal = self.model.num_normal
            num_reduce = self.model.num_reduce

            if self.args.snas or self.args.dsnas:
                # clone so error_loss stays intact for logging and for the
                # return value while loss drives the backward pass below
                loss = error_loss.clone()

            #self.update_lr()

            # logging gradient
            count += 1
            if self.args.snas:
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()
                error_loss.backward(retain_graph=True)
                if not self.args.random_sample:
                    normal_loss_gradient += self.model.normal_log_alpha.grad
                    reduce_loss_gradient += self.model.reduce_log_alpha.grad
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()

            if self.args.snas and (not self.args.random_sample
                                   and not self.args.dsnas):
                loss.backward()

            if not self.args.random_sample:
                normal_total_gradient += self.model.normal_log_alpha.grad
                reduce_total_gradient += self.model.reduce_log_alpha.grad

            nn.utils.clip_grad_norm_(self.model.parameters(),
                                     self.args.grad_clip)
            arch_grad_norm = nn.utils.clip_grad_norm_(
                self.model.arch_parameters(), 10.)

            grad.update(arch_grad_norm)
            if not self.args.fix_weight and self.update_theta:
                self.optimizer.step()
            self.optimizer.zero_grad()

            if not self.args.random_sample and self.update_alpha:
                self.arch_optimizer.step()
            self.arch_optimizer.zero_grad()

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))

            objs.update(error_loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % self.args.report_freq == 0 and self.rank == 0:
                logging.info('train %03d %e %f %f', step, objs.avg, top1.avg,
                             top5.avg)
                self.logger.add_scalar(
                    "iter_train_top1_acc", top1.avg,
                    step + len(self.train_queue.dataset) * epoch)

        if self.rank == 0:
            logging.info('-------loss gradient--------')
            logging.info(normal_loss_gradient / count)
            logging.info(reduce_loss_gradient / count)
            logging.info('-------total gradient--------')
            logging.info(normal_total_gradient / count)
            logging.info(reduce_total_gradient / count)

        # guard against empty groups to avoid division by zero
        if train_correct_count > 0:
            logging.info('correct loss ')
            logging.info((train_correct_loss / train_correct_count).item())
            logging.info('correct entropy ')
            logging.info((train_correct_entropy / train_correct_count).item())
            logging.info('correct cost ')
            logging.info((train_correct_cost / train_correct_count).item())
            logging.info('correct count ')
            logging.info(train_correct_count)

        if train_wrong_count > 0:
            logging.info('wrong loss ')
            logging.info((train_wrong_loss / train_wrong_count).item())
            logging.info('wrong entropy ')
            logging.info((train_wrong_entropy / train_wrong_count).item())
            logging.info('wrong cost ')
            logging.info((train_wrong_cost / train_wrong_count).item())
            logging.info('wrong count ')
            logging.info(train_wrong_count)

        logging.info('total loss ')
        logging.info(((train_correct_loss + train_wrong_loss) /
                      (train_correct_count + train_wrong_count)).item())
        logging.info('total entropy ')
        logging.info(((train_correct_entropy + train_wrong_entropy) /
                      (train_correct_count + train_wrong_count)).item())
        logging.info('total cost ')
        logging.info(((train_correct_cost + train_wrong_cost) /
                      (train_correct_count + train_wrong_count)).item())
        logging.info('total count ')
        logging.info(train_correct_count + train_wrong_count)

        return top1.avg, loss, error_loss, loss_alpha

    def infer(self, epoch):
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()

        self.model.eval()
        with torch.no_grad():
            for step, (input, target) in enumerate(self.valid_queue):
                input = input.to(self.device)
                target = target.to(self.device)
                if self.args.snas:
                    logits, logits_aux = self.model(input)
                    loss = self.criterion(logits, target)
                elif self.args.dsnas:
                    logits, error_loss, loss_alpha = self.model(
                        input, target, self.criterion)
                    loss = error_loss

                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))

                objs.update(loss.item(), input.size(0))
                top1.update(prec1.item(), input.size(0))
                top5.update(prec5.item(), input.size(0))

                if step % self.args.report_freq == 0 and self.rank == 0:
                    logging.info('valid %03d %e %f %f', step, objs.avg,
                                 top1.avg, top5.avg)
                    self.logger.add_scalar(
                        "iter_valid_loss", loss,
                        step + len(self.valid_queue.dataset) * epoch)
                    self.logger.add_scalar(
                        "iter_valid_top1_acc", top1.avg,
                        step + len(self.valid_queue.dataset) * epoch)

        return top1.avg, objs.avg
Example No. 6
def train():

    use_gpu = cfg.MODEL.DEVICE == "cuda"
    # 1. make dataloader
    train_loader, val_loader, test_loader, num_query, num_class = darts_make_data_loader(
        cfg)

    # 2. make model
    model = Network(num_class, cfg)
    # tensor = torch.randn(2, 3, 256, 128)
    # res = model(tensor)
    # print(res[0].size()) [2, 751]

    # 3. make optimizer
    optimizer = make_optimizer(cfg, model)
    arch_optimizer = torch.optim.Adam(
        model._arch_parameters(),
        lr=cfg.SOLVER.ARCH_LR,
        betas=(0.5, 0.999),
        weight_decay=cfg.SOLVER.ARCH_WEIGHT_DECAY)

    # 4. make lr schedulers
    lr_scheduler = make_lr_scheduler(cfg, optimizer)
    arch_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        arch_optimizer, [80, 160], 0.1)
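    # the architecture lr is multiplied by 0.1 at epochs 80 and 160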

    # 5. make loss
    loss_fn = darts_make_loss(cfg)

    # model._set_loss(loss_fn, compute_loss_acc)

    # 6. make architect
    # architect = Architect(model, cfg)

    # get parameters (use_gpu was already derived from cfg.MODEL.DEVICE above)
    device = cfg.MODEL.DEVICE
    pretrained = cfg.MODEL.PRETRAINED != ""

    log_period = cfg.OUTPUT.LOG_PERIOD
    ckpt_period = cfg.OUTPUT.CKPT_PERIOD
    eval_period = cfg.OUTPUT.EVAL_PERIOD
    output_dir = cfg.OUTPUT.DIRS
    ckpt_save_path = output_dir + cfg.OUTPUT.CKPT_DIRS

    epochs = cfg.SOLVER.MAX_EPOCHS
    batch_size = cfg.SOLVER.BATCH_SIZE
    grad_clip = cfg.SOLVER.GRAD_CLIP

    batch_num = len(train_loader)
    # at least 1 so the modulo below never divides by zero when
    # log_period > batch_num
    log_iters = max(1, batch_num // log_period)

    if not os.path.exists(ckpt_save_path):
        os.makedirs(ckpt_save_path)

    # create *_result.xlsx
    # save the result for analyze
    name = (cfg.OUTPUT.LOG_NAME).split(".")[0] + ".xlsx"
    result_path = cfg.OUTPUT.DIRS + name

    wb = xl.Workbook()
    sheet = wb.worksheets[0]
    # one [acc, mAP, r1, r5, r10, loss] group is appended to the result
    # row at each recorded check epoch (see check_epochs below)
    titles = [
        'size/M', 'speed/ms', 'final_planes', 'acc', 'mAP', 'r1', 'r5', 'r10',
        'loss', 'acc', 'mAP', 'r1', 'r5', 'r10', 'loss', 'acc', 'mAP', 'r1',
        'r5', 'r10', 'loss'
    ]
    sheet.append(titles)
    check_epochs = [40, 80, 120, 160, 200, 240, 280, 320, 360, epochs]
    values = []

    logger = logging.getLogger("CSNet_Search.train")
    size = count_parameters(model)
    values.append(format(size, '.2f'))
    values.append(model.final_planes)

    logger.info("the param number of the model is {:.2f} M".format(size))

    logger.info("Starting Search CDNetwork")

    best_mAP, best_r1 = 0., 0.
    is_best = False
    avg_loss, avg_acc = RunningAverageMeter(), RunningAverageMeter()
    avg_time, global_avg_time = AverageMeter(), AverageMeter()

    if use_gpu:
        model = model.to(device)

    if pretrained:
        logger.info("load self pretrained chekpoint to init")
        model.load_pretrained_model(cfg.MODEL.PRETRAINED)
    else:
        logger.info("use kaiming init to init the model")
        model.kaiming_init_()
    for epoch in range(epochs):
        model.set_tau(cfg.MODEL.TAU_MAX -
                      (cfg.MODEL.TAU_MAX - cfg.MODEL.TAU_MIN) * epoch /
                      (epochs - 1))
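        # linear annealing of tau, presumably the Gumbel-softmax
        # temperature (GDAS-style): epoch 0 gives TAU_MAX, the final epoch
        # gives TAU_MIN; e.g. TAU_MAX=10, TAU_MIN=0.1, epochs=200 yields
        # tau = 10 - 9.9 * epoch / 199 (example values assumed)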
        lr_scheduler.step()
        lr = lr_scheduler.get_lr()[0]
        # architect lr.step
        arch_lr_scheduler.step()

        # if the checkpoint was saved at epoch k, resume from epoch k+1
        if pretrained and epoch < model.start_epoch:
            continue

        model.train()
        avg_loss.reset()
        avg_acc.reset()
        avg_time.reset()

        for i, batch in enumerate(train_loader):

            t0 = time.time()
            imgs, labels = batch
            val_imgs, val_labels = next(iter(val_loader))

            if use_gpu:
                imgs = imgs.to(device)
                labels = labels.to(device)
                val_imgs = val_imgs.to(device)
                val_labels = val_labels.to(device)

            # 1. update the weights
            optimizer.zero_grad()
            res = model(imgs)

            # loss = loss_fn(scores, feats, labels)
            loss, acc = compute_loss_acc(res, labels, loss_fn)
            loss.backward()

            if grad_clip != 0:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()

            # 2. update the alphas
            arch_optimizer.zero_grad()
            res = model(val_imgs)

            val_loss, val_acc = compute_loss_acc(res, val_labels, loss_fn)
            val_loss.backward()
            arch_optimizer.step()
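            # first-order alternating scheme: a weight step on the train
            # batch, then an alpha step on a validation batch drawn (with
            # replacement) via next(iter(val_loader)) above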

            # compute the acc
            # acc = (scores.max(1)[1] == labels).float().mean()

            t1 = time.time()
            avg_time.update((t1 - t0) / batch_size)
            # .item() so the running meter does not retain the autograd graph
            avg_loss.update(loss.item())
            avg_acc.update(acc)

            # log info
            if (i + 1) % log_iters == 0:
                logger.info(
                    "epoch {}: {}/{} with loss {:.5f} and acc {:.3f}".format(
                        epoch + 1, i + 1, batch_num, avg_loss.avg,
                        avg_acc.avg))

        logger.info(
            "end of epoch {}/{} with lr {:.5f} and avg_time {:.3f} ms".format(
                epoch + 1, epochs, lr, avg_time.avg * 1000))
        global_avg_time.update(avg_time.avg)

        # test the model
        if (epoch + 1) % eval_period == 0 or (epoch + 1) in check_epochs:

            model.eval()
            metrics = R1_mAP(num_query, use_gpu=use_gpu)
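            # R1_mAP accumulates (features, pids, camids) over the test
            # set and computes the CMC curve and mAP in compute()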

            with torch.no_grad():
                for vi, batch in enumerate(test_loader):
                    imgs, labels, camids = batch
                    if use_gpu:
                        imgs = imgs.to(device)

                    feats = model(imgs)
                    metrics.update((feats, labels, camids))

                # compute CMC and mAP
                cmc, mAP = metrics.compute()
                logger.info("validation results at epoch {}".format(epoch + 1))
                logger.info("mAP:{:2%}".format(mAP))
                for r in [1, 5, 10]:
                    logger.info("CMC curve, Rank-{:<3}:{:.2%}".format(
                        r, cmc[r - 1]))

                # determine whether current model is the best
                if mAP > best_mAP:
                    is_best = True
                    best_mAP = mAP
                    logger.info("Get a new best mAP")
                if cmc[0] > best_r1:
                    is_best = True
                    best_r1 = cmc[0]
                    logger.info("Get a new best r1")

                # add the result to sheet
                if (epoch + 1) in check_epochs:
                    val = [avg_acc.avg, mAP, cmc[0], cmc[4], cmc[9]]
                    change = [format(v * 100, '.2f') for v in val]
                    change.append(format(avg_loss.avg, '.3f'))
                    values.extend(change)

        # whether to save the model
        if (epoch + 1) % ckpt_period == 0 or is_best:
            torch.save(model.state_dict(),
                       ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))
            model._parse_genotype(file=ckpt_save_path +
                                  "genotype_{}.json".format(epoch + 1))
            logger.info("checkpoint {} was saved".format(epoch + 1))

            if is_best:
                torch.save(model.state_dict(),
                           ckpt_save_path + "best_ckpt.pth")
                model._parse_genotype(file=ckpt_save_path +
                                      "best_genotype.json")
                logger.info("best_checkpoint was saved")
                is_best = False

    values.insert(1, format(global_avg_time.avg * 1000, '.2f'))
    sheet.append(values)
    wb.save(result_path)

    logger.info("Ending Search GDAS_Search")