Example #1
    def _warmup(self, phase, epoch):
        assert phase in [PHASE_SMALL, PHASE_LARGE]
        if phase == PHASE_SMALL:
            model, optimizer = self.model_small, self.optimizer_small
        elif phase == PHASE_LARGE:
            model, optimizer = self.model_large, self.optimizer_large
        model.train()
        meters = AverageMeterGroup()
        for step in range(self.steps_per_epoch):
            x, y = next(self.train_loader)
            x, y = x.cuda(), y.cuda()

            optimizer.zero_grad()
            logits_main, _ = model(x)
            loss = self.criterion(logits_main, y)
            loss.backward()

            self._clip_grad_norm(model)
            optimizer.step()
            prec1, prec5 = accuracy(logits_main, y, topk=(1, 5))
            metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
            metrics = reduce_metrics(metrics, self.distributed)
            meters.update(metrics)
            if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
                self.logger.info("Epoch [%d/%d] Step [%d/%d] (%s)  %s", epoch + 1, self.epochs,
                                 step + 1, self.steps_per_epoch, phase, meters)
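
Every example on this page builds an AverageMeterGroup, feeds it a dict of per-step metrics via update(), and relies on its string form for logging. For orientation, here is a minimal sketch of the interface those snippets assume; the names come from the snippets themselves, but the implementation below is an illustrative approximation, not NNI's actual source (some projects instead use a tuple-returning accuracy(..., topk=(1, 5)) variant, as in Example #1).

# Illustrative sketch only; approximates the helpers used throughout these examples.
from collections import OrderedDict

import torch


class AverageMeter:
    """Running average of one scalar metric."""

    def __init__(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val):
        val = val.item() if torch.is_tensor(val) else float(val)
        self.sum += val
        self.count += 1
        self.avg = self.sum / self.count


class AverageMeterGroup:
    """One AverageMeter per metric name; supports meters.update(dict) and meters.acc1.avg."""

    def __init__(self):
        self.meters = OrderedDict()

    def update(self, metrics):
        for name, val in metrics.items():
            self.meters.setdefault(name, AverageMeter()).update(val)

    def __getattr__(self, name):
        try:
            return self.meters[name]
        except KeyError:
            raise AttributeError(name)

    def __str__(self):
        return "  ".join("{}: {:.4f}".format(n, m.avg) for n, m in self.meters.items())

    def summary(self):
        # '"name": value' pairs, so that '{' + summary() + '}' parses as JSON (see Examples #2 and #3)
        return ", ".join('"{}": {:.6f}'.format(n, m.avg) for n, m in self.meters.items())


def accuracy(output, target, topk=(1, 5)):
    """Top-k accuracy as a dict with 'acc1'/'acc5' keys, as used in Example #4."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return {"acc{}".format(k): correct[:k].reshape(-1).float().sum().item() * 100.0 / batch_size
            for k in topk}
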
Example #2
    def validate_one_epoch(self, epoch):

        # return list of meters
        meter_list = []

        self.model.eval()
        self.mutator.eval()
        meters = AverageMeterGroup()
        with torch.no_grad():
            self.mutator.reset()
            for step, (X, y) in enumerate(self.test_loader):
                X, y = X.to(self.device), y.to(self.device)
                logits = self.model(X)
                metrics = self.metrics(logits, y)
                loss = self.loss(logits, y)
                metrics["loss"] = loss.item()
                meters.update(metrics)

                if self.log_frequency is not None and step % self.log_frequency == 0:
                    logger.info("Epoch [%s/%s] Step [%s/%s]  %s", epoch + 1,
                                self.num_epochs, step + 1, len(self.test_loader), meters)
            
            # parse the epoch summary into a dict
            # (assumes meters.summary() emits JSON key/value pairs)
            meter_dict = json.loads('{' + meters.summary() + '}')
            meter_list.append(meter_dict)
                
        return meter_list
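
A note on the JSON round-trip above: it only works if meters.summary() happens to emit comma-separated key/value pairs that are valid JSON once wrapped in braces. If the group exposes its underlying meters directly (Example #12 below reads meters.meters['save_metric'].avg), the dict can be built without any string parsing; a hypothetical equivalent for the last two lines of the method:

            # assumes AverageMeterGroup keeps an internal dict of meters, each with an .avg field
            meter_dict = {name: meter.avg for name, meter in meters.meters.items()}
            meter_list.append(meter_dict)
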
Example #3
    def validate_one_epoch(self, epoch):

        # return a list of meters
        meter_list = []

        with torch.no_grad():
            for arc_id in range(self.test_arc_per_epoch):
                meters = AverageMeterGroup()
                for x, y in self.test_loader:
                    x, y = to_device(x, self.device), to_device(y, self.device)
                    self.mutator.reset()
                    logits = self.model(x)
                    if isinstance(logits, tuple):
                        logits, _ = logits
                    metrics = self.metrics(logits, y)
                    loss = self.loss(logits, y)
                    metrics["loss"] = loss.item()
                    meters.update(metrics)

                # parse the per-architecture summary into a dict
                # (assumes meters.summary() emits JSON key/value pairs)
                meter_dict = json.loads('{' + meters.summary() + '}')
                meter_list.append(meter_dict)

                logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary  %s",
                            epoch + 1, self.num_epochs, arc_id + 1,
                            self.test_arc_per_epoch, meters.summary())

        return meter_list
Example #4
def train(epoch, model, criterion, optimizer, loader, writer, args):
    model.train()
    meters = AverageMeterGroup()
    cur_lr = optimizer.param_groups[0]["lr"]

    for step, (x, y) in enumerate(loader):
        cur_step = len(loader) * epoch + step
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        metrics = accuracy(logits, y)
        metrics["loss"] = loss.item()
        meters.update(metrics)

        writer.add_scalar("lr", cur_lr, global_step=cur_step)
        writer.add_scalar("loss/train", loss.item(), global_step=cur_step)
        writer.add_scalar("acc1/train", metrics["acc1"], global_step=cur_step)
        writer.add_scalar("acc5/train", metrics["acc5"], global_step=cur_step)

        if step % args.log_frequency == 0 or step + 1 == len(loader):
            logger.info("Epoch [%d/%d] Step [%d/%d]  %s", epoch + 1,
                        args.epochs, step + 1, len(loader), meters)

    logger.info("Epoch %d training summary: %s", epoch + 1, meters)
Example #5
    def train_one_epoch(self, epoch):
        self.model.train()
        self.mutator.train()
        meters = AverageMeterGroup()
        for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
            if self.debug and step > 0:
                break
            trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
            val_X, val_y = val_X.to(self.device), val_y.to(self.device)

            # phase 1. architecture step
            self.ctrl_optim.zero_grad()
            if self.unrolled:
                self._unrolled_backward(trn_X, trn_y, val_X, val_y)
            else:
                self._backward(val_X, val_y)
            self.ctrl_optim.step()

            # phase 2: child network step
            self.optimizer.zero_grad()
            logits, loss = self._logits_and_loss(trn_X, trn_y)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)  # gradient clipping
            self.optimizer.step()

            metrics = self.metrics(logits, trn_y)
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                self.logger.info("Model Epoch [{}/{}] Step [{}/{}] Model size: {} {}".format(
                    epoch + 1, self.num_epochs, step + 1, len(self.train_loader), self.model_size(), meters))

        return meters
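
Example #5 delegates to helper methods that are not shown here. In the non-unrolled (first-order) case they plausibly reduce to a plain forward pass plus a backward pass on the validation loss, along the lines of the sketch below; the method names come from the snippet, but the bodies are an assumption rather than the project's exact code.

    # Hypothetical first-order helpers for Example #5 (sketch, not the original implementation).
    def _logits_and_loss(self, X, y):
        logits = self.model(X)
        loss = self.loss(logits, y)
        return logits, loss

    def _backward(self, val_X, val_y):
        # first-order approximation: architecture gradient from the validation loss only
        _, loss = self._logits_and_loss(val_X, val_y)
        loss.backward()
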
Example #6
    def train_one_epoch(self, epoch):
        self.mutator.train()
        meters = AverageMeterGroup()
        for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.batched_train, self.batched_validate)):
            trn_X = pad_sequence([self.model.vectors[x] for x in trn_X]).permute(1,0,2)
            val_X = pad_sequence([self.model.vectors[x] for x in val_X]).permute(1,0,2)

            trn_X = trn_X.to(self.device)
            trn_y = torch.stack([y.int() for y in trn_y]).to(self.device)
            val_X = val_X.to(self.device)
            val_y = torch.stack([y.int() for y in val_y]).to(self.device)

            # phase 1. architecture step
            self.ctrl_optim.zero_grad()
            if self.unrolled:
                self._unrolled_backward(trn_X, trn_y, val_X, val_y)
            else:
                self._backward(val_X, val_y)
            self.ctrl_optim.step()

            # phase 2: child network step
            self.optimizer.zero_grad()
            logits, loss = self._logits_and_loss(trn_X, trn_y)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)  # gradient clipping
            self.optimizer.step()

            metrics = self.metrics(logits, trn_y)
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Epoch [%s/%s] Step [%s/%s]  %s", epoch + 1,
                            self.num_epochs, step + 1, len(self.train_loader), meters)
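
The pad_sequence calls in Example #6 (and in Example #10 below) stack variable-length token sequences; torch.nn.utils.rnn.pad_sequence returns a (max_len, batch, embed_dim) tensor by default, which is why the result is permuted back to batch-first. A small illustrative shape check:

import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.randn(5, 8), torch.randn(3, 8), torch.randn(7, 8)]  # (seq_len, embed_dim) each
padded = pad_sequence(seqs)            # shape (7, 3, 8): max_len first by default
batch_first = padded.permute(1, 0, 2)  # shape (3, 7, 8), same as pad_sequence(seqs, batch_first=True)
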
Example #7
    def _joint_train(self, epoch):
        self.model_large.train()
        self.model_small.train()
        meters = AverageMeterGroup()
        for step in range(self.steps_per_epoch):
            trn_x, trn_y = next(self.train_loader)
            val_x, val_y = next(self.valid_loader)
            trn_x, trn_y = trn_x.cuda(), trn_y.cuda()
            val_x, val_y = val_x.cuda(), val_y.cuda()

            # step 1. optimize architecture
            self.optimizer_alpha.zero_grad()
            self.optimizer_large.zero_grad()
            reg_decay = max(
                self.regular_coeff *
                (1 - float(epoch - self.warmup_epochs) /
                 ((self.epochs - self.warmup_epochs) * self.regular_ratio)), 0)
            loss_regular = self.mutator_small.reset_with_loss()
            if loss_regular:
                loss_regular *= reg_decay
            logits_search, emsemble_logits_search = self.model_small(val_x)
            logits_main, emsemble_logits_main = self.model_large(val_x)
            loss_cls = (self.criterion(logits_search, val_y) +
                        self.criterion(logits_main, val_y)) / self.loss_alpha
            loss_interactive = self.interactive_loss(
                emsemble_logits_search,
                emsemble_logits_main) * (self.loss_T**2) * self.loss_alpha
            loss = loss_cls + loss_interactive + loss_regular
            loss.backward()
            self._clip_grad_norm(self.model_large)
            self.optimizer_large.step()
            self.optimizer_alpha.step()
            # NOTE: need to call here `self._reset_nan(self.mutator_small.parameters())` if `cut_choices`

            # step 2. optimize op weights
            self.optimizer_small.zero_grad()
            with torch.no_grad():
                # resample architecture since parameters have been changed
                self.mutator_small.reset_with_loss()
            logits_search_train, _ = self.model_small(trn_x)
            loss_weight = self.criterion(logits_search_train, trn_y)
            loss_weight.backward()
            self._clip_grad_norm(self.model_small)
            self.optimizer_small.step()

            metrics = {
                "loss_cls": loss_cls,
                "loss_interactive": loss_interactive,
                "loss_regular": loss_regular,
                "loss_weight": loss_weight
            }
            metrics = reduce_metrics(metrics, self.distributed)
            meters.update(metrics)

            if self.main_proc and (step % self.log_frequency == 0
                                   or step + 1 == self.steps_per_epoch):
                self.logger.info("Epoch [%d/%d] Step [%d/%d] (joint)  %s",
                                 epoch + 1, self.epochs, step + 1,
                                 self.steps_per_epoch, meters)
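
Example #7 does not show interactive_loss. Given the loss_T ** 2 scaling applied by the caller, a temperature-softened KL divergence between the two ensemble outputs is the usual formulation; the sketch below is an assumption about that helper, not the project's exact definition.

import torch.nn.functional as F

def interactive_loss(ensemble_logits_search, ensemble_logits_main, T=2.0):
    # KL divergence between temperature-softened distributions; the caller in
    # Example #7 multiplies the result by loss_T ** 2 and loss_alpha.
    student = F.log_softmax(ensemble_logits_search / T, dim=1)
    teacher = F.softmax(ensemble_logits_main / T, dim=1)
    return F.kl_div(student, teacher, reduction="batchmean")
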
Example #8
 def validate_one_epoch(self, epoch):
     self.model.eval()
     self.mutator.eval()
     meters = AverageMeterGroup()
     with torch.no_grad():
         self.mutator.reset()
         for step, (X, y) in enumerate(self.test_loader):
             X, y = X.to(self.device), y.to(self.device)
             logits = self.model(X)
             metrics = self.metrics(logits, y)
             meters.update(metrics)
             if self.log_frequency is not None and step % self.log_frequency == 0:
                 logger.info("Epoch [%s/%s] Step [%s/%s]  %s", epoch + 1,
                             self.num_epochs, step + 1, len(self.test_loader), meters)
Example #9
 def validate_one_epoch(self, epoch):
     self.model.eval()
     meters = AverageMeterGroup()
     with torch.no_grad():
         for step, (x, y) in enumerate(self.valid_loader):
             self.mutator.reset()
             logits = self.model(x)
             loss = self.loss(logits, y)
             metrics = self.metrics(logits, y)
             metrics["loss"] = loss.item()
             meters.update(metrics)
             if self.log_frequency is not None and step % self.log_frequency == 0:
                 logger.info("Epoch [%s/%s] Validation Step [%s/%s]  %s", epoch + 1,
                             self.num_epochs, step + 1, len(self.valid_loader), meters)
Example #10
 def validate_one_epoch(self, epoch):
     self.model.eval()
     self.mutator.eval()
     meters = AverageMeterGroup()
     with torch.no_grad():
         self.mutator.reset()
         for step, (X, y) in enumerate(self.batched_test):
             X = pad_sequence([self.model.vectors[x] for x in X]).permute(1,0,2)
             X = X.to(self.device)
             y = torch.stack([y_.int() for y_ in y]).to(self.device)
             logits = self.model(X)
             metrics = self.metrics(logits, y)
             meters.update(metrics)
             if self.log_frequency is not None and step % self.log_frequency == 0:
                 logger.info("Epoch [%s/%s] Step [%s/%s]  %s", epoch + 1,
                              self.num_epochs, step + 1, len(self.batched_test), meters)
Example #11
File: tester.py  Project: JSong-Jia/nni-1
def test_acc(model, criterion, log_freq, loader):
    logger.info("Start testing...")
    model.eval()
    meters = AverageMeterGroup()
    start_time = time.time()
    with torch.no_grad():
        for step, (inputs, targets) in enumerate(loader):
            logits = model(inputs)
            loss = criterion(logits, targets)
            metrics = accuracy(logits, targets)
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if step % log_freq == 0 or step + 1 == len(loader):
                logger.info("Valid Step [%d/%d] time %.3fs acc1 %.4f acc5 %.4f loss %.4f",
                            step + 1, len(loader), time.time() - start_time,
                            meters.acc1.avg, meters.acc5.avg, meters.loss.avg)
    return meters.acc1.avg
Example #12
 def test_one_epoch(self, epoch):
     self.model.eval()
     self.mutator.eval()
     meters = AverageMeterGroup()
     with torch.no_grad():
         self.mutator.reset()
         for step, (X, y) in enumerate(self.test_loader):
             if self.debug and step > 0:
                 break
             X, y = X.to(self.device), y.to(self.device)
             logits = self.model(X)
             metrics = self.metrics(logits, y)
             meters.update(metrics)
             if self.log_frequency is not None and step % self.log_frequency == 0:
                 self.logger.info("Test: Step [{}/{}]  {}".format(step + 1, len(self.test_loader), meters))
         self.logger.info("Final model metric = {}".format(meters.meters['save_metric'].avg))
     return meters
Example #13
    def validate_one_epoch(self, epoch):
        self.model.eval()
        meters = AverageMeterGroup()
        with torch.no_grad():
            for step, (x, y) in enumerate(self.valid_loader):
                self.mutator.reset()
                logits = self.model(x)
                loss = self.val_loss(logits, y)
                prec1, prec5 = accuracy(logits, y, topk=(1, 5))
                metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
                metrics = reduce_metrics(metrics)
                meters.update(metrics)

                if self.log_frequency is not None and step % self.log_frequency == 0:
                    logger.info("Epoch [%s/%s] Validation Step [%s/%s]  %s",
                                epoch + 1, self.num_epochs, step + 1,
                                len(self.valid_loader), meters)
Example #14
    def train_one_epoch(self, epoch):
        self.model.train()
        meters = AverageMeterGroup()
        for step, (x, y) in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            self.mutator.reset()
            logits = self.model(x)
            loss = self.loss(logits, y)
            loss.backward()
            self.optimizer.step()

            metrics = self.metrics(logits, y)
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Epoch [%s/%s] Step [%s/%s]  %s", epoch + 1,
                            self.num_epochs, step + 1, len(self.train_loader), meters)
Example #15
    def validate_one_epoch(self, epoch):
        with torch.no_grad():
            for arc_id in range(self.test_arc_per_epoch):
                meters = AverageMeterGroup()
                for x, y in self.test_loader:
                    x, y = to_device(x, self.device), to_device(y, self.device)
                    self.mutator.reset()
                    logits = self.model(x)
                    if isinstance(logits, tuple):
                        logits, _ = logits
                    metrics = self.metrics(logits, y)
                    loss = self.loss(logits, y)
                    metrics["loss"] = loss.item()
                    meters.update(metrics)

                logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary  %s",
                            epoch + 1, self.num_epochs, arc_id + 1,
                            self.test_arc_per_epoch, meters.summary())
Example #16
def validate(epoch, model, criterion, loader, writer, args):
    model.eval()
    meters = AverageMeterGroup()
    with torch.no_grad():
        for step, (x, y) in enumerate(loader):
            logits = model(x)
            loss = criterion(logits, y)
            metrics = accuracy(logits, y)
            metrics["loss"] = loss.item()
            meters.update(metrics)

            if step % args.log_frequency == 0 or step + 1 == len(loader):
                logger.info("Epoch [%d/%d] Validation Step [%d/%d]  %s", epoch + 1,
                            args.epochs, step + 1, len(loader), meters)

    writer.add_scalar("loss/test", meters.loss.avg, global_step=epoch)
    writer.add_scalar("acc1/test", meters.acc1.avg, global_step=epoch)
    writer.add_scalar("acc5/test", meters.acc5.avg, global_step=epoch)

    logger.info("Epoch %d validation: top1 = %f, top5 = %f", epoch + 1, meters.acc1.avg, meters.acc5.avg)
Example #17
File: tester.py  Project: JSong-Jia/nni-1
def retrain_bn(model, criterion, max_iters, log_freq, loader):
    with torch.no_grad():
        logger.info("Clear BN statistics...")
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.running_mean = torch.zeros_like(m.running_mean)
                m.running_var = torch.ones_like(m.running_var)

        logger.info("Train BN with training set (BN sanitize)...")
        model.train()
        meters = AverageMeterGroup()
        for step in range(max_iters):
            inputs, targets = next(loader)
            logits = model(inputs)
            loss = criterion(logits, targets)
            metrics = accuracy(logits, targets)
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if step % log_freq == 0 or step + 1 == max_iters:
                logger.info("Train Step [%d/%d] %s", step + 1, max_iters, meters)
Example #18
File: trainer.py  Project: zsjtoby/nni
    def train_one_epoch(self, epoch):
        # Sample model and train
        self.model.train()
        self.mutator.eval()
        meters = AverageMeterGroup()
        for step, (x, y) in enumerate(self.train_loader):
            x, y = x.to(self.device), y.to(self.device)
            self.optimizer.zero_grad()

            with torch.no_grad():
                self.mutator.reset()
            logits = self.model(x)

            if isinstance(logits, tuple):
                logits, aux_logits = logits
                aux_loss = self.loss(aux_logits, y)
            else:
                aux_loss = 0.
            metrics = self.metrics(logits, y)
            loss = self.loss(logits, y)
            loss = loss + self.aux_weight * aux_loss
            loss.backward()
            self.optimizer.step()
            metrics["loss"] = loss.item()
            meters.update(metrics)

            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Model Epoch [%s/%s] Step [%s/%s]  %s",
                            epoch + 1, self.num_epochs, step + 1,
                            len(self.train_loader), meters)

        # Train sampler (mutator)
        self.model.eval()
        self.mutator.train()
        meters = AverageMeterGroup()
        mutator_step, total_mutator_steps = 0, self.mutator_steps * self.mutator_steps_aggregate
        while mutator_step < total_mutator_steps:
            for step, (x, y) in enumerate(self.valid_loader):
                x, y = x.to(self.device), y.to(self.device)

                self.mutator.reset()
                with torch.no_grad():
                    logits = self.model(x)
                metrics = self.metrics(logits, y)
                reward = self.reward_function(logits, y)
                if self.entropy_weight is not None:
                    reward += self.entropy_weight * self.mutator.sample_entropy
                self.baseline = self.baseline * self.baseline_decay + reward * (
                    1 - self.baseline_decay)
                self.baseline = self.baseline.detach().item()
                loss = self.mutator.sample_log_prob * (reward - self.baseline)
                if self.skip_weight:
                    loss += self.skip_weight * self.mutator.sample_skip_penalty
                metrics["reward"] = reward
                metrics["loss"] = loss.item()
                metrics["ent"] = self.mutator.sample_entropy.item()
                metrics["baseline"] = self.baseline
                metrics["skip"] = self.mutator.sample_skip_penalty

                loss = loss / self.mutator_steps_aggregate
                loss.backward()
                meters.update(metrics)

                if mutator_step % self.mutator_steps_aggregate == 0:
                    self.mutator_optim.step()
                    self.mutator_optim.zero_grad()

                if self.log_frequency is not None and step % self.log_frequency == 0:
                    logger.info(
                        "RL Epoch [%s/%s] Step [%s/%s]  %s", epoch + 1,
                        self.num_epochs,
                        mutator_step // self.mutator_steps_aggregate + 1,
                        self.mutator_steps, meters)
                mutator_step += 1
                if mutator_step >= total_mutator_steps:
                    break
Example #19
    def train_one_epoch(self, epoch):
        # Sample model and train
        self.model.train()
        self.mutator.eval()
        meters = AverageMeterGroup()
        for step in range(1, self.child_steps + 1):
            x, y = next(self.train_loader)
            x, y = to_device(x, self.device), to_device(y, self.device)
            self.optimizer.zero_grad()

            with torch.no_grad():
                self.mutator.reset()
            logits = self.model(x)

            if isinstance(logits, tuple):
                logits, aux_logits = logits
                aux_loss = self.loss(aux_logits, y)
            else:
                aux_loss = 0.
            metrics = self.metrics(logits, y)
            loss = self.loss(logits, y)
            loss = loss + self.aux_weight * aux_loss
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
            self.optimizer.step()
            metrics["loss"] = loss.item()
            meters.update(metrics)

            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Model Epoch [%d/%d] Step [%d/%d]  %s", epoch + 1,
                            self.num_epochs, step, self.child_steps, meters)

        # Train sampler (mutator)
        self.model.eval()
        self.mutator.train()
        meters = AverageMeterGroup()
        for mutator_step in range(1, self.mutator_steps + 1):
            self.mutator_optim.zero_grad()
            for step in range(1, self.mutator_steps_aggregate + 1):
                x, y = next(self.valid_loader)
                x, y = to_device(x, self.device), to_device(y, self.device)

                self.mutator.reset()
                with torch.no_grad():
                    logits = self.model(x)
                metrics = self.metrics(logits, y)
                reward = self.reward_function(logits, y)
                if self.entropy_weight:
                    reward += self.entropy_weight * self.mutator.sample_entropy.item(
                    )
                self.baseline = self.baseline * self.baseline_decay + reward * (
                    1 - self.baseline_decay)
                loss = self.mutator.sample_log_prob * (reward - self.baseline)
                if self.skip_weight:
                    loss += self.skip_weight * self.mutator.sample_skip_penalty
                metrics["reward"] = reward
                metrics["loss"] = loss.item()
                metrics["ent"] = self.mutator.sample_entropy.item()
                metrics["log_prob"] = self.mutator.sample_log_prob.item()
                metrics["baseline"] = self.baseline
                metrics["skip"] = self.mutator.sample_skip_penalty

                loss /= self.mutator_steps_aggregate
                loss.backward()
                meters.update(metrics)

                cur_step = step + (mutator_step -
                                   1) * self.mutator_steps_aggregate
                if self.log_frequency is not None and cur_step % self.log_frequency == 0:
                    logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d]  %s",
                                epoch + 1, self.num_epochs, mutator_step,
                                self.mutator_steps, step,
                                self.mutator_steps_aggregate, meters)

            nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)
            self.mutator_optim.step()
Example #20
File: trainer.py  Project: Tudor33/nni
    def train_one_epoch(self, epoch):
        def get_model(model):
            return model.module

        meters = AverageMeterGroup()
        for step, (input_data, target) in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            self.mutator.reset()

            input_data = input_data.cuda()
            target = target.cuda()

            cand_flops = self.est.get_flops(self.mutator._cache)

            if epoch > self.meta_sta_epoch and step > 0 and step % self.update_iter == 0:

                slice_ind = self.slices
                x = deepcopy(input_data[:slice_ind].clone().detach())

                if self.best_children_pool:
                    if self.pick_method == 'top1':
                        meta_value, cand = 1, sorted(self.best_children_pool,
                                                     reverse=True)[0][3]
                    elif self.pick_method == 'meta':
                        meta_value, cand_idx, cand = -1000000000, -1, None
                        for now_idx, item in enumerate(
                                self.best_children_pool):
                            inputx = item['input']
                            output = F.softmax(self.model(inputx), dim=1)
                            weight = get_model(
                                self.model).forward_meta(output -
                                                         item['feature_map'])
                            if weight > meta_value:
                                meta_value = weight  # deepcopy(torch.nn.functional.sigmoid(weight))
                                cand_idx = now_idx
                                cand = self.arch_dict[(
                                    self.best_children_pool[cand_idx]['acc'],
                                    self.best_children_pool[cand_idx]
                                    ['arch_list'])]
                        assert cand is not None
                        meta_value = torch.nn.functional.sigmoid(-weight)
                    else:
                        raise ValueError('Method Not supported')

                    u_output = self.model(x)

                    saved_cache = self.mutator._cache
                    self.mutator._cache = cand
                    u_teacher_output = self.model(x)
                    self.mutator._cache = saved_cache

                    u_soft_label = F.softmax(u_teacher_output, dim=1)
                    kd_loss = meta_value * self.cross_entropy_loss_with_soft_target(
                        u_output, u_soft_label)
                    self.optimizer.zero_grad()

                    grad_1 = torch.autograd.grad(
                        kd_loss,
                        get_model(self.model).rand_parameters(
                            self.mutator._cache),
                        create_graph=True)

                    def raw_sgd(w, g):
                        return g * self.optimizer.param_groups[-1]['lr'] + w

                    students_weight = [
                        raw_sgd(p, grad_item) for p, grad_item in zip(
                            get_model(self.model).rand_parameters(
                                self.mutator._cache), grad_1)
                    ]

                    # update student weights
                    for weight, grad_item in zip(
                            get_model(self.model).rand_parameters(
                                self.mutator._cache), grad_1):
                        weight.grad = grad_item
                    torch.nn.utils.clip_grad_norm_(
                        get_model(self.model).rand_parameters(
                            self.mutator._cache), 1)
                    self.optimizer.step()
                    for weight, grad_item in zip(
                            get_model(self.model).rand_parameters(
                                self.mutator._cache), grad_1):
                        del weight.grad

                    held_out_x = deepcopy(input_data[slice_ind:slice_ind *
                                                     2].clone().detach())
                    output_2 = self.model(held_out_x)
                    valid_loss = self.loss(output_2,
                                           target[slice_ind:slice_ind * 2])
                    self.optimizer.zero_grad()

                    grad_student_val = torch.autograd.grad(
                        valid_loss,
                        get_model(self.model).rand_parameters(
                            self.mutator._cache),
                        retain_graph=True)

                    grad_teacher = torch.autograd.grad(
                        students_weight[0],
                        get_model(self.model).rand_parameters(
                            cand, self.pick_method == 'meta'),
                        grad_outputs=grad_student_val)

                    # update teacher model
                    for weight, grad_item in zip(
                            get_model(self.model).rand_parameters(
                                cand, self.pick_method == 'meta'),
                            grad_teacher):
                        weight.grad = grad_item
                    torch.nn.utils.clip_grad_norm_(
                        get_model(self.model).rand_parameters(
                            self.mutator._cache, self.pick_method == 'meta'),
                        1)
                    self.optimizer.step()
                    for weight, grad_item in zip(
                            get_model(self.model).rand_parameters(
                                cand, self.pick_method == 'meta'),
                            grad_teacher):
                        del weight.grad

                    for item in students_weight:
                        del item
                    del grad_teacher, grad_1, grad_student_val, x, held_out_x
                    del valid_loss, kd_loss, u_soft_label, u_output, u_teacher_output, output_2

                else:
                    raise ValueError("Must 1nd or 2nd update teacher weights")

            # get_best_teacher
            if self.best_children_pool:
                if self.pick_method == 'top1':
                    meta_value, cand = 0.5, sorted(self.best_children_pool,
                                                   reverse=True)[0][3]
                elif self.pick_method == 'meta':
                    meta_value, cand_idx, cand = -1000000000, -1, None
                    for now_idx, item in enumerate(self.best_children_pool):
                        inputx = item['input']
                        output = F.softmax(self.model(inputx), dim=1)
                        weight = get_model(
                            self.model).forward_meta(output -
                                                     item['feature_map'])
                        if weight > meta_value:
                            meta_value = weight
                            cand_idx = now_idx
                            cand = self.arch_dict[(
                                self.best_children_pool[cand_idx]['acc'],
                                self.best_children_pool[cand_idx]['arch_list']
                            )]
                    assert cand is not None
                    meta_value = torch.nn.functional.sigmoid(-weight)
                else:
                    raise ValueError('Method Not supported')
            if not self.best_children_pool:
                output = self.model(input_data)
                loss = self.loss(output, target)
                kd_loss = loss
            elif epoch <= self.meta_sta_epoch:
                output = self.model(input_data)
                loss = self.loss(output, target)
            else:
                output = self.model(input_data)
                with torch.no_grad():
                    # save student arch
                    saved_cache = self.mutator._cache
                    self.mutator._cache = cand

                    # forward
                    teacher_output = self.model(input_data).detach()

                    # restore student arch
                    self.mutator._cache = saved_cache
                    soft_label = F.softmax(teacher_output, dim=1)
                kd_loss = self.cross_entropy_loss_with_soft_target(
                    output, soft_label)
                valid_loss = self.loss(output, target)
                loss = (meta_value * kd_loss +
                        (2 - meta_value) * valid_loss) / 2

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            prec1, prec5 = self.accuracy(output, target, topk=(1, 5))
            metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
            metrics = self.reduce_metrics(metrics, self.distributed)
            meters.update(metrics)

            if epoch > self.meta_sta_epoch and (
                (len(self.best_children_pool) < self.pool_size) or
                (prec1 > self.best_children_pool[-1]['acc'] + 5) or
                (prec1 > self.best_children_pool[-1]['acc']
                 and cand_flops < self.best_children_pool[-1]['flops'])):
                val_prec1 = prec1
                training_data = deepcopy(input_data[:self.slices].detach())
                if not self.best_children_pool:
                    features = deepcopy(output[:self.slices].detach())
                else:
                    features = deepcopy(teacher_output[:self.slices].detach())
                self.best_children_pool.append({
                    'acc':
                    val_prec1,
                    'accu':
                    prec1,
                    'flops':
                    cand_flops,
                    'input':
                    training_data,
                    'feature_map':
                    F.softmax(features, dim=1)
                })
                self.arch_dict[(val_prec1, cand_flops)] = self.mutator._cache
                self.best_children_pool = sorted(self.best_children_pool,
                                                 key=lambda x: x['acc'],
                                                 reverse=True)

            if len(self.best_children_pool) > self.pool_size:
                self.best_children_pool = sorted(self.best_children_pool,
                                                 key=lambda x: x['acc'],
                                                 reverse=True)
                del self.best_children_pool[-1]

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            if self.main_proc and self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Epoch [%s/%s] Step [%s/%s]  %s",
                            epoch + 1, self.num_epochs, step + 1,
                            len(self.train_loader), meters)

        if self.main_proc:
            for idx, i in enumerate(self.best_children_pool):
                logger.info("No.%s %s", idx, i[:4])
Example #21
    def train_one_epoch(self, epoch):
        self.current_epoch = epoch
        meters = AverageMeterGroup()
        self.steps_per_epoch = len(self.train_loader)
        for step, (input_data, target) in enumerate(self.train_loader):
            self.mutator.reset()
            self.current_student_arch = self.mutator._cache

            input_data, target = input_data.cuda(), target.cuda()

            # calculate flops of current architecture
            cand_flops = self._get_cand_flops(self.mutator._cache)

            # update meta matching network
            self._run_update(input_data, target, step)

            if self._board_size() > 0:
                # select teacher architecture
                meta_value, teacher_cand = self._select_teacher()
                self.current_teacher_arch = teacher_cand

            # forward supernet
            if self._board_size() == 0 or epoch <= self.meta_sta_epoch:
                self._replace_mutator_cand(self.current_student_arch)
                output = self.model(input_data)

                loss = self.loss(output, target)
                kd_loss, teacher_output, teacher_cand = None, None, None
            else:
                self._replace_mutator_cand(self.current_student_arch)
                output = self.model(input_data)

                gt_loss = self.loss(output, target)

                with torch.no_grad():
                    self._replace_mutator_cand(self.current_teacher_arch)
                    teacher_output = self.model(input_data).detach()

                    soft_label = torch.nn.functional.softmax(teacher_output,
                                                             dim=1)
                kd_loss = self._cross_entropy_loss_with_soft_target(
                    output, soft_label)

                loss = (meta_value * kd_loss + (2 - meta_value) * gt_loss) / 2

            # update network
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # update metrics
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
            metrics = reduce_metrics(metrics)
            meters.update(metrics)

            # update prioritized board
            self._update_prioritized_board(input_data, teacher_output, output,
                                           metrics['prec1'], cand_flops)

            if self.main_proc and (step % self.log_frequency == 0
                                   or step + 1 == self.steps_per_epoch):
                logger.info("Epoch [%d/%d] Step [%d/%d] %s",
                            epoch + 1, self.num_epochs, step + 1,
                            len(self.train_loader), meters)

        if self.main_proc and self.num_epochs == epoch + 1:
            for idx, i in enumerate(self.prioritized_board):
                logger.info("No.%s %s", idx, i[:4])
Example #22
    def train_one_epoch(self, epoch):
        # Sample model and train
        self.model.train()
        self.mutator.eval()
        meters = AverageMeterGroup()
        # COMMENT: train the child model first
        for step in range(1, self.child_steps + 1):
            x, y = next(self.train_loader)
            x, y = to_device(x, self.device), to_device(y, self.device)
            self.optimizer.zero_grad()

            with torch.no_grad():
                self.mutator.reset()
            self._write_graph_status()
            logits = self.model(x)

            if isinstance(logits, tuple):
                logits, aux_logits = logits
                aux_loss = self.loss(aux_logits, y)
            else:
                aux_loss = 0.
            metrics = self.metrics(logits, y)
            # compute accuracy

            loss = self.loss(logits, y)
            # compute the loss

            loss = loss + self.aux_weight * aux_loss

            loss.backward()

            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)

            self.optimizer.step()

            metrics["loss"] = loss.item()
            meters.update(metrics)

            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Model Epoch [%d/%d] Step [%d/%d]  %s", epoch + 1,
                            self.num_epochs, step, self.child_steps, meters)

        # Train sampler (mutator)
        self.model.eval()
        self.mutator.train()
        # then train the sampler (mutator)
        meters = AverageMeterGroup()
        for mutator_step in range(1, self.mutator_steps + 1):
            self.mutator_optim.zero_grad()
            for step in range(1, self.mutator_steps_aggregate + 1):
                x, y = next(self.valid_loader)
                x, y = to_device(x, self.device), to_device(y, self.device)

                self.mutator.reset()

                with torch.no_grad():
                    logits = self.model(x)

                self._write_graph_status()

                # compute accuracy metrics
                metrics = self.metrics(logits, y)
                # compute the reward

                '''
                def reward_accuracy(output, target, topk=(1,)):
                    batch_size = target.size(0)
                    
                    _, predicted = torch.max(output.data, 1)
                    return (predicted == target).sum().item() / batch_size
                '''

                reward = self.reward_function(logits, y)  # fraction of correct predictions in this batch

                if self.entropy_weight:  # entropy bonus weight
                    reward += self.entropy_weight * self.mutator.sample_entropy.item()  # add the sample entropy
                
                self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
                # moving-average baseline; this has a policy-gradient feel
                
                loss = self.mutator.sample_log_prob * (reward - self.baseline)

                if self.skip_weight:
                    loss += self.skip_weight * self.mutator.sample_skip_penalty
                
                metrics["reward"] = reward
                metrics["loss"] = loss.item()
                metrics["ent"] = self.mutator.sample_entropy.item()
                metrics["log_prob"] = self.mutator.sample_log_prob.item()
                metrics["baseline"] = self.baseline
                metrics["skip"] = self.mutator.sample_skip_penalty

                loss /= self.mutator_steps_aggregate
                loss.backward()
                meters.update(metrics)

                cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
                if self.log_frequency is not None and cur_step % self.log_frequency == 0:
                    logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d]  %s", epoch + 1, self.num_epochs,
                                mutator_step, self.mutator_steps, step, self.mutator_steps_aggregate,
                                meters)

            nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)
            self.mutator_optim.step()
Example #23
    def _joint_train(self, epoch):
        meters = AverageMeterGroup()
        for step in range(self.steps_per_epoch):
            totall_lc = 0
            totall_lw = 0
            totall_li = 0
            totall_lr = 0

            loss_regular = self.mutator_small.reset_with_loss()
            reg_decay = max(self.regular_coeff * (1 - float(epoch - self.warmup_epochs) / (
                    (self.epochs - self.warmup_epochs) * self.regular_ratio)), 0)
            if loss_regular:
                loss_regular *= reg_decay

            samples_x = []
            samples_y = []
            criterion_l = []
            emsemble_logits_l = []

            def trn_l(totall_lc, totall_lw, totall_li, totall_lr):

                self.model_large.train()
                self.optimizer_large.zero_grad()

                for fb in range(self.fake_batch):
                    val_x, val_y = next(self.valid_loader)
                    val_x, val_y = val_x.cuda(), val_y.cuda()

                    logits_main, emsemble_logits_main = self.model_large(val_x)
                    cel = self.criterion(logits_main, val_y)
                    loss_weight = cel / (self.fake_batch)
                    loss_weight.backward(retain_graph=True)

                    criterion_l.append(cel.cpu())
                    emsemble_logits_l.append(emsemble_logits_main.cpu())

                    totall_lw += float(loss_weight)
                    samples_x.append(val_x.cpu())
                    samples_y.append(val_y.cpu())

                self._clip_grad_norm(self.model_large)
                self.optimizer_large.step()
                self.model_large.train(mode=False)

                return totall_lc, totall_lw, totall_li, totall_lr

            totall_lc, totall_lw, totall_li, totall_lr = trn_l(totall_lc, totall_lw, totall_li, totall_lr)
            def sleep(s):
                print("--" + str(s))
                time.sleep(2)
                print(torch.cuda.memory_summary())
                print("++" + str(s))

            def trn_s(totall_lc, totall_lw, totall_li, totall_lr):
                print("sts")
                self.model_small.cuda()
                self.model_small.train()
                self.optimizer_alpha.zero_grad()
                self.optimizer_small.zero_grad()
                ls = []
                els = []
                sleep(0)
                def sc():
                    reg_decay = max(self.regular_coeff * (1 - float(epoch - self.warmup_epochs) / (
                            (self.epochs - self.warmup_epochs) * self.regular_ratio)), 0)
                    loss_regular = self.mutator_small.reset_with_loss()
                    if loss_regular:
                        loss_regular *= reg_decay
                    loss_regular.backward()
                    loss_regular = loss_regular.cpu().detach()
                sc()
                sleep(0.5)
                for i in range(len(samples_x)):
                    val_x = samples_x[i]
                    val_x = val_x.cuda()
                    val_y = samples_y[i]
                    val_y = val_y.cuda()


                    logits_search, emsemble_logits_search = self.model_small(val_x)
                    cls = self.criterion(logits_search, val_y)

                    ls.append(cls.cpu())
                    els.append(emsemble_logits_search.cpu())
                    val_x.cpu().detach()
                    val_y.cpu().detach()

                sleep(1)
                for i in range(len(samples_x)):
                    criterion_logits_main = criterion_l[i].cuda()
                    cls = ls[i].cuda()
                    emsemble_logits_search = els[i].cuda()
                    loss_weight = cls / (self.fake_batch)
                    totall_lw += float(loss_weight)
                    loss_cls = (cls + criterion_logits_main) / self.loss_alpha / self.fake_batch
                    loss_cls.backward(retain_graph=True)
                    totall_lc += float(loss_cls)
                    criterion_logits_main.cpu().detach()

                sleep(2)
                for i in range(len(samples_x)):
                    emsemble_logits_main = emsemble_logits_l[i].cuda()
                    emsemble_logits_search = els[i].cuda()
                    sleep(3)
                    loss_interactive = self.interactive_loss(emsemble_logits_search, emsemble_logits_main) * (
                                self.loss_T ** 2) * self.loss_alpha / self.fake_batch
                    loss_interactive.backward(retain_graph=True)
                    sleep(5)
                    emsemble_logits_search.cpu()
                    totall_li += float(loss_interactive)
                    totall_lr += float(loss_regular)
                    emsemble_logits_search.cpu().detach()
                    emsemble_logits_main.cpu().detach()
                    sleep(6)


                self.optimizer_alpha.step()
                self._clip_grad_norm(self.model_small)
                self.optimizer_small.step()
                self.model_small.train(mode=False)
                samples_x.clear()
                samples_y.clear()
                criterion_l.clear()
                emsemble_logits_l.clear()
                return totall_lc, totall_lw, totall_li, totall_lr

            totall_lc, totall_lw, totall_li, totall_lr = trn_s(totall_lc, totall_lw, totall_li, totall_lr)



            metrics = {"loss_cls": totall_lc, "loss_interactive": totall_li,
                       "loss_regular": totall_lr, "loss_weight": totall_lw}
            #metrics = reduce_metrics(metrics, self.distributed)
            meters.update(metrics)

            if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
                self.logger.info("Epoch [%d/%d] Step [%d/%d] (joint)  %s", epoch + 1, self.epochs,
                                 step + 1, self.steps_per_epoch, meters)
Example #24
def train(logger,
          config,
          train_loader,
          model,
          optimizer,
          criterion,
          epoch,
          main_proc,
          fake_batch=4,
          steps=128):
    meters = AverageMeterGroup()
    cur_lr = optimizer.param_groups[0]["lr"]
    if main_proc:
        logger.info("Epoch %d LR %.6f", epoch, cur_lr)

    model.train()
    for step in range(steps):
        totall_l = 0
        totall_p = 0

        def top(totall_l, totall_p):
            optimizer.zero_grad()

            def s(totall_l, totall_p):
                x, y = next(train_loader)
                x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
                logits, aux_logits = model(x)
                loss = criterion(logits, y)
                if config.aux_weight > 0.:
                    loss += config.aux_weight * criterion(aux_logits, y)
                loss = loss / fake_batch
                try:
                    prec1, _ = utils.accuracy(logits, y, topk=(1, 1))
                    prec1 = prec1 / fake_batch
                    totall_p += prec1
                    loss.backward(retain_graph=True)
                    totall_l += float(loss)
                except Exception:
                    print("Err")
                return totall_l, totall_p

            for fb in range(fake_batch):
                totall_l, totall_p = s(totall_l, totall_p)
            # clip the accumulated gradients before stepping (clipping after the step has no effect)
            nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
            optimizer.step()
            return totall_l, totall_p

        totall_l, totall_p = top(totall_l, totall_p)
        metrics = {"prec1": totall_p, "loss": totall_l}
        # metrics = utils.reduce_metrics(metrics, config.distributed)
        meters.update(metrics)

        if main_proc and (step % config.log_frequency == 0
                          or step + 1 == steps):
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': totall_l,
                    'model': model
                }, 'model_' + str(epoch) + '.pt')
            logger.info("Epoch [%d/%d] Step [%d/%d]  %s", epoch + 1,
                        config.epochs, step + 1, steps, meters)

    if main_proc:
        logger.info("Train: [%d/%d] Final Prec@1 %.4f", epoch + 1,
                    config.epochs, meters.prec1.avg)
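
Example #24 emulates a larger batch by accumulating gradients over fake_batch sub-batches before a single optimizer step. For comparison, a more compact sketch of the same accumulation pattern (the function and argument names here are hypothetical):

import torch.nn as nn

def accumulated_step(model, criterion, optimizer, loader, accum_steps, grad_clip):
    # one optimizer step built from `accum_steps` small batches
    optimizer.zero_grad()
    total_loss = 0.0
    for _ in range(accum_steps):
        x, y = next(loader)
        x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
        logits, _ = model(x)                       # model returns (logits, aux_logits) as in Example #24
        loss = criterion(logits, y) / accum_steps  # scale so accumulated gradients average out
        loss.backward()                            # each sub-batch has its own graph
        total_loss += loss.item()
    nn.utils.clip_grad_norm_(model.parameters(), grad_clip)  # clip before stepping
    optimizer.step()
    return total_loss
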