def train(epoch, model, criterion, optimizer, loader, writer, args):
    model.train()
    meters = AverageMeterGroup()
    cur_lr = optimizer.param_groups[0]["lr"]

    for step, (x, y) in enumerate(loader):
        cur_step = len(loader) * epoch + step
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        metrics = accuracy(logits, y)
        metrics["loss"] = loss.item()
        meters.update(metrics)
        writer.add_scalar("lr", cur_lr, global_step=cur_step)
        writer.add_scalar("loss/train", loss.item(), global_step=cur_step)
        writer.add_scalar("acc1/train", metrics["acc1"], global_step=cur_step)
        writer.add_scalar("acc5/train", metrics["acc5"], global_step=cur_step)

        if step % args.log_frequency == 0 or step + 1 == len(loader):
            logger.info("Epoch [%d/%d] Step [%d/%d] %s",
                        epoch + 1, args.epochs, step + 1, len(loader), meters)

    logger.info("Epoch %d training summary: %s", epoch + 1, meters)
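# The helpers used above (`accuracy`, `AverageMeterGroup`) come from the project's utils
# module. A minimal sketch of what `accuracy` is assumed to look like, returning a dict
# keyed "acc1"/"acc5" as indexed by the loops above (hypothetical, for illustration only):
import torch

def accuracy(output, target, topk=(1, 5)):
    """Compute top-k accuracy and return it as a dict keyed "acc{k}"."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = {}
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res["acc{}".format(k)] = (correct_k / batch_size).item()
    return res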
def validate_one_epoch(self, epoch):
    # Return a list of meters (one summary dict per epoch).
    meter_list = []
    self.model.eval()
    self.mutator.eval()
    meters = AverageMeterGroup()
    with torch.no_grad():
        self.mutator.reset()
        for step, (X, y) in enumerate(self.test_loader):
            X, y = X.to(self.device), y.to(self.device)
            logits = self.model(X)
            metrics = self.metrics(logits, y)
            loss = self.loss(logits, y)
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Epoch [%s/%s] Step [%s/%s] %s",
                            epoch + 1, self.num_epochs, step + 1, len(self.test_loader), meters)
        # meters.summary() is assumed to yield comma-separated "key": value pairs,
        # so wrapping it in braces produces a JSON object.
        meter_list.append(json.loads('{' + meters.summary() + '}'))
    return meter_list
def train_one_epoch(self, epoch):
    self.model.train()
    self.mutator.train()
    meters = AverageMeterGroup()
    for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
        if self.debug and step > 0:
            break
        trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
        val_X, val_y = val_X.to(self.device), val_y.to(self.device)

        # phase 1. architecture step
        self.ctrl_optim.zero_grad()
        if self.unrolled:
            self._unrolled_backward(trn_X, trn_y, val_X, val_y)
        else:
            self._backward(val_X, val_y)
        self.ctrl_optim.step()

        # phase 2: child network step
        self.optimizer.zero_grad()
        logits, loss = self._logits_and_loss(trn_X, trn_y)
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 5.)  # gradient clipping
        self.optimizer.step()

        metrics = self.metrics(logits, trn_y)
        metrics["loss"] = loss.item()
        meters.update(metrics)
        if self.log_frequency is not None and step % self.log_frequency == 0:
            self.logger.info("Model Epoch [{}/{}] Step [{}/{}] Model size: {} {}".format(
                epoch + 1, self.num_epochs, step + 1, len(self.train_loader),
                self.model_size(), meters))
    return meters
def train(logger, config, train_loader, model, optimizer, criterion, epoch, main_proc):
    meters = AverageMeterGroup()
    cur_lr = optimizer.param_groups[0]["lr"]
    if main_proc:
        logger.info("Epoch %d LR %.6f", epoch, cur_lr)
    model.train()

    for step, (x, y) in enumerate(train_loader):
        x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
        optimizer.zero_grad()
        logits, aux_logits = model(x)
        loss = criterion(logits, y)
        if config.aux_weight > 0.:
            loss += config.aux_weight * criterion(aux_logits, y)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        optimizer.step()

        prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5))
        metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
        metrics = utils.reduce_metrics(metrics, config.distributed)
        meters.update(metrics)

        if main_proc and (step % config.log_frequency == 0 or step + 1 == len(train_loader)):
            logger.info("Epoch [%d/%d] Step [%d/%d] %s",
                        epoch + 1, config.epochs, step + 1, len(train_loader), meters)

    if main_proc:
        logger.info("Train: [%d/%d] Final Prec@1 %.4f Prec@5 %.4f",
                    epoch + 1, config.epochs, meters.prec1.avg, meters.prec5.avg)
def _warmup(self, phase, epoch):
    assert phase in [PHASE_SMALL, PHASE_LARGE]
    if phase == PHASE_SMALL:
        model, optimizer = self.model_small, self.optimizer_small
    elif phase == PHASE_LARGE:
        model, optimizer = self.model_large, self.optimizer_large
    model.train()
    meters = AverageMeterGroup()

    for step in range(self.steps_per_epoch):
        x, y = next(self.train_loader)
        x, y = x.cuda(), y.cuda()
        optimizer.zero_grad()
        logits_main, _ = model(x)
        loss = self.criterion(logits_main, y)
        loss.backward()
        self._clip_grad_norm(model)
        optimizer.step()

        prec1, prec5 = accuracy(logits_main, y, topk=(1, 5))
        metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
        metrics = reduce_metrics(metrics, self.distributed)
        meters.update(metrics)
        if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
            self.logger.info("Epoch [%d/%d] Step [%d/%d] (%s) %s",
                             epoch + 1, self.epochs, step + 1, self.steps_per_epoch, phase, meters)
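# Several loops here draw batches with next(self.train_loader) for a fixed number of
# steps, which requires the loader to behave as an endless iterator. A minimal sketch of
# such a wrapper (hypothetical; the actual project may wrap its loaders differently):
class CyclicIterator:
    """Wrap a finite DataLoader so that next() never raises StopIteration."""

    def __init__(self, loader):
        self.loader = loader
        self._iterator = iter(loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self._iterator)
        except StopIteration:
            # Restart from the beginning once the underlying loader is exhausted.
            self._iterator = iter(self.loader)
            return next(self._iterator)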
def train_one_epoch(self, epoch):
    self.mutator.train()
    meters = AverageMeterGroup()
    for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.batched_train, self.batched_validate)):
        # Look up pretrained vectors and pad to a common length: (batch, seq_len, emb_dim).
        trn_X = pad_sequence([self.model.vectors[x] for x in trn_X]).permute(1, 0, 2)
        val_X = pad_sequence([self.model.vectors[x] for x in val_X]).permute(1, 0, 2)
        trn_X = trn_X.to(self.device)
        trn_y = torch.stack([y.int() for y in trn_y]).to(self.device)
        val_X = val_X.to(self.device)
        val_y = torch.stack([y.int() for y in val_y]).to(self.device)

        # phase 1. architecture step
        self.ctrl_optim.zero_grad()
        if self.unrolled:
            self._unrolled_backward(trn_X, trn_y, val_X, val_y)
        else:
            self._backward(val_X, val_y)
        self.ctrl_optim.step()

        # phase 2: child network step
        self.optimizer.zero_grad()
        logits, loss = self._logits_and_loss(trn_X, trn_y)
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 5.)  # gradient clipping
        self.optimizer.step()

        metrics = self.metrics(logits, trn_y)
        metrics["loss"] = loss.item()
        meters.update(metrics)
        if self.log_frequency is not None and step % self.log_frequency == 0:
            logger.info("Epoch [%s/%s] Step [%s/%s] %s",
                        epoch + 1, self.num_epochs, step + 1, len(self.batched_train), meters)
def _joint_train(self, epoch):
    self.model_large.train()
    self.model_small.train()
    meters = AverageMeterGroup()
    for step in range(self.steps_per_epoch):
        trn_x, trn_y = next(self.train_loader)
        val_x, val_y = next(self.valid_loader)
        trn_x, trn_y = trn_x.cuda(), trn_y.cuda()
        val_x, val_y = val_x.cuda(), val_y.cuda()

        # step 1. optimize architecture
        self.optimizer_alpha.zero_grad()
        self.optimizer_large.zero_grad()
        reg_decay = max(self.regular_coeff * (1 - float(epoch - self.warmup_epochs) /
                                              ((self.epochs - self.warmup_epochs) * self.regular_ratio)), 0)
        loss_regular = self.mutator_small.reset_with_loss()
        if loss_regular:
            loss_regular *= reg_decay
        logits_search, emsemble_logits_search = self.model_small(val_x)
        logits_main, emsemble_logits_main = self.model_large(val_x)
        loss_cls = (self.criterion(logits_search, val_y) + self.criterion(logits_main, val_y)) / self.loss_alpha
        loss_interactive = self.interactive_loss(emsemble_logits_search, emsemble_logits_main) * (self.loss_T ** 2) * self.loss_alpha
        loss = loss_cls + loss_interactive + loss_regular
        loss.backward()
        self._clip_grad_norm(self.model_large)
        self.optimizer_large.step()
        self.optimizer_alpha.step()
        # NOTE: need to call here `self._reset_nan(self.mutator_small.parameters())` if `cut_choices`

        # step 2. optimize op weights
        self.optimizer_small.zero_grad()
        with torch.no_grad():
            # resample architecture since parameters have been changed
            self.mutator_small.reset_with_loss()
        logits_search_train, _ = self.model_small(trn_x)
        loss_weight = self.criterion(logits_search_train, trn_y)
        loss_weight.backward()
        self._clip_grad_norm(self.model_small)
        self.optimizer_small.step()

        metrics = {"loss_cls": loss_cls, "loss_interactive": loss_interactive,
                   "loss_regular": loss_regular, "loss_weight": loss_weight}
        metrics = reduce_metrics(metrics, self.distributed)
        meters.update(metrics)
        if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
            self.logger.info("Epoch [%d/%d] Step [%d/%d] (joint) %s",
                             epoch + 1, self.epochs, step + 1, self.steps_per_epoch, meters)
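# self.interactive_loss above is assumed to be a temperature-scaled KL divergence between
# the two networks' ensemble logits (hence the self.loss_T ** 2 factor in the caller);
# a minimal sketch under that assumption:
import torch.nn as nn
import torch.nn.functional as F

class InteractiveKLLoss(nn.Module):
    """KL divergence between temperature-softened student and teacher distributions."""

    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student, teacher):
        return self.kl_loss(F.log_softmax(student / self.temperature, dim=1),
                            F.softmax(teacher / self.temperature, dim=1))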
def validate_one_epoch(self, epoch):
    self.model.eval()
    self.mutator.eval()
    meters = AverageMeterGroup()
    with torch.no_grad():
        self.mutator.reset()
        for step, (X, y) in enumerate(self.test_loader):
            X, y = X.to(self.device), y.to(self.device)
            logits = self.model(X)
            metrics = self.metrics(logits, y)
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Epoch [%s/%s] Step [%s/%s] %s",
                            epoch + 1, self.num_epochs, step + 1, len(self.test_loader), meters)
def validate_one_epoch(self, epoch):
    self.model.eval()
    meters = AverageMeterGroup()
    with torch.no_grad():
        for step, (x, y) in enumerate(self.valid_loader):
            self.mutator.reset()
            logits = self.model(x)
            loss = self.loss(logits, y)
            metrics = self.metrics(logits, y)
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s",
                            epoch + 1, self.num_epochs, step + 1, len(self.valid_loader), meters)
def validate_one_epoch(self, epoch):
    self.model.eval()
    self.mutator.eval()
    meters = AverageMeterGroup()
    with torch.no_grad():
        self.mutator.reset()
        for step, (X, y) in enumerate(self.batched_test):
            # Look up pretrained vectors and pad to a common length: (batch, seq_len, emb_dim).
            X = pad_sequence([self.model.vectors[x] for x in X]).permute(1, 0, 2)
            X = X.to(self.device)
            y = torch.stack([y_.int() for y_ in y]).to(self.device)
            logits = self.model(X)
            metrics = self.metrics(logits, y)
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Epoch [%s/%s] Step [%s/%s] %s",
                            epoch + 1, self.num_epochs, step + 1, len(self.batched_test), meters)
def test_acc(model, criterion, log_freq, loader):
    logger.info("Start testing...")
    model.eval()
    meters = AverageMeterGroup()
    start_time = time.time()
    with torch.no_grad():
        for step, (inputs, targets) in enumerate(loader):
            logits = model(inputs)
            loss = criterion(logits, targets)
            metrics = accuracy(logits, targets)
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if step % log_freq == 0 or step + 1 == len(loader):
                logger.info("Valid Step [%d/%d] time %.3fs acc1 %.4f acc5 %.4f loss %.4f",
                            step + 1, len(loader), time.time() - start_time,
                            meters.acc1.avg, meters.acc5.avg, meters.loss.avg)
    return meters.acc1.avg
def train_one_epoch(self, epoch):
    self.model.train()
    meters = AverageMeterGroup()
    for step, (x, y) in enumerate(self.train_loader):
        self.optimizer.zero_grad()
        self.mutator.reset()
        logits = self.model(x)
        loss = self.loss(logits, y)
        loss.backward()
        self.optimizer.step()

        metrics = self.metrics(logits, y)
        metrics["loss"] = loss.item()
        meters.update(metrics)
        if self.log_frequency is not None and step % self.log_frequency == 0:
            logger.info("Epoch [%s/%s] Step [%s/%s] %s",
                        epoch + 1, self.num_epochs, step + 1, len(self.train_loader), meters)
def test_one_epoch(self, epoch):
    self.model.eval()
    self.mutator.eval()
    meters = AverageMeterGroup()
    with torch.no_grad():
        self.mutator.reset()
        for step, (X, y) in enumerate(self.test_loader):
            if self.debug and step > 0:
                break
            X, y = X.to(self.device), y.to(self.device)
            logits = self.model(X)
            metrics = self.metrics(logits, y)
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                self.logger.info("Test: Step [{}/{}] {}".format(step + 1, len(self.test_loader), meters))
    self.logger.info("Final model metric = {}".format(meters.meters['save_metric'].avg))
    return meters
def validate_one_epoch(self, epoch):
    self.model.eval()
    meters = AverageMeterGroup()
    with torch.no_grad():
        for step, (x, y) in enumerate(self.valid_loader):
            self.mutator.reset()
            logits = self.model(x)
            loss = self.val_loss(logits, y)
            prec1, prec5 = accuracy(logits, y, topk=(1, 5))
            metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
            metrics = reduce_metrics(metrics)
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s",
                            epoch + 1, self.num_epochs, step + 1, len(self.valid_loader), meters)
def validate_one_epoch(self, epoch):
    with torch.no_grad():
        for arc_id in range(self.test_arc_per_epoch):
            meters = AverageMeterGroup()
            for x, y in self.test_loader:
                x, y = to_device(x, self.device), to_device(y, self.device)
                self.mutator.reset()
                logits = self.model(x)
                if isinstance(logits, tuple):
                    logits, _ = logits
                metrics = self.metrics(logits, y)
                loss = self.loss(logits, y)
                metrics["loss"] = loss.item()
                meters.update(metrics)
            logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
                        epoch + 1, self.num_epochs, arc_id + 1, self.test_arc_per_epoch,
                        meters.summary())
def validate(epoch, model, criterion, loader, writer, args):
    model.eval()
    meters = AverageMeterGroup()
    with torch.no_grad():
        for step, (x, y) in enumerate(loader):
            logits = model(x)
            loss = criterion(logits, y)
            metrics = accuracy(logits, y)
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if step % args.log_frequency == 0 or step + 1 == len(loader):
                logger.info("Epoch [%d/%d] Validation Step [%d/%d] %s",
                            epoch + 1, args.epochs, step + 1, len(loader), meters)
    writer.add_scalar("loss/test", meters.loss.avg, global_step=epoch)
    writer.add_scalar("acc1/test", meters.acc1.avg, global_step=epoch)
    writer.add_scalar("acc5/test", meters.acc5.avg, global_step=epoch)
    logger.info("Epoch %d validation: top1 = %f, top5 = %f",
                epoch + 1, meters.acc1.avg, meters.acc5.avg)
def retrain_bn(model, criterion, max_iters, log_freq, loader):
    with torch.no_grad():
        logger.info("Clear BN statistics...")
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.running_mean = torch.zeros_like(m.running_mean)
                m.running_var = torch.ones_like(m.running_var)

        logger.info("Train BN with training set (BN sanitize)...")
        model.train()
        meters = AverageMeterGroup()
        for step in range(max_iters):
            inputs, targets = next(loader)
            logits = model(inputs)
            loss = criterion(logits, targets)
            metrics = accuracy(logits, targets)
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if step % log_freq == 0 or step + 1 == max_iters:
                logger.info("Train Step [%d/%d] %s", step + 1, max_iters, meters)
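# Typical usage of the two helpers above when evaluating a sampled architecture
# (hypothetical driver code, assuming an endless iterator over the training set such as
# the CyclicIterator sketched earlier):
#
#     train_iter = CyclicIterator(train_loader)
#     retrain_bn(model, criterion, max_iters=200, log_freq=50, loader=train_iter)
#     top1 = test_acc(model, criterion, log_freq=50, loader=test_loader)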
def validate_one_epoch(self, epoch):
    # Return a list of meters, one summary dict per tested architecture.
    meter_list = []
    with torch.no_grad():
        for arc_id in range(self.test_arc_per_epoch):
            meters = AverageMeterGroup()
            for x, y in self.test_loader:
                x, y = to_device(x, self.device), to_device(y, self.device)
                self.mutator.reset()
                logits = self.model(x)
                if isinstance(logits, tuple):
                    logits, _ = logits
                metrics = self.metrics(logits, y)
                loss = self.loss(logits, y)
                metrics["loss"] = loss.item()
                meters.update(metrics)
            # meters.summary() is assumed to yield comma-separated "key": value pairs,
            # so wrapping it in braces produces a JSON object.
            meter_list.append(json.loads('{' + meters.summary() + '}'))
            logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
                        epoch + 1, self.num_epochs, arc_id + 1, self.test_arc_per_epoch,
                        meters.summary())
    return meter_list
def validate(logger, config, valid_loader, model, criterion, epoch, main_proc):
    meters = AverageMeterGroup()
    model.eval()
    with torch.no_grad():
        for step, (x, y) in enumerate(valid_loader):
            x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
            logits, _ = model(x)
            loss = criterion(logits, y)
            prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5))
            metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
            metrics = utils.reduce_metrics(metrics, config.distributed)
            meters.update(metrics)
            if main_proc and (step % config.log_frequency == 0 or step + 1 == len(valid_loader)):
                logger.info("Epoch [%d/%d] Step [%d/%d] %s",
                            epoch + 1, config.epochs, step + 1, len(valid_loader), meters)
    if main_proc:
        logger.info("Valid: [%d/%d] Final Prec@1 %.4f Prec@5 %.4f",
                    epoch + 1, config.epochs, meters.prec1.avg, meters.prec5.avg)
    return meters.prec1.avg, meters.prec5.avg
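# utils.reduce_metrics above averages the metric tensors across workers when distributed
# training is enabled. A minimal sketch of what such a helper might look like
# (hypothetical; the project's own implementation may differ):
import torch
import torch.distributed as dist

def reduce_metrics(metrics, distributed=False):
    """Average each metric over all workers; on a single process just unwrap tensors."""
    if not distributed:
        return {k: v.item() if torch.is_tensor(v) else v for k, v in metrics.items()}
    reduced = {}
    world_size = dist.get_world_size()
    for k, v in metrics.items():
        t = v.detach().clone() if torch.is_tensor(v) else torch.tensor(float(v))
        t = t.cuda()  # NCCL requires CUDA tensors
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        reduced[k] = (t / world_size).item()
    return reduced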
def train_one_epoch(self, epoch):
    # Train sampler (mutator) only; the shared weights stay frozen.
    self.model.eval()
    self.mutator.train()
    total_loss = 0
    meters = AverageMeterGroup()
    for mutator_step in range(1, self.mutator_steps + 1):
        self.mutator_optim.zero_grad()
        for step in range(1, self.mutator_steps_aggregate + 1):
            x, y = next(self.valid_loader)
            x, y = to_device(x, self.device), to_device(y, self.device)

            self.mutator.reset()
            with torch.no_grad():
                logits = self.model(x)
            self._write_graph_status()

            # Training-free reward: score the sampled architecture from its batch Jacobian.
            jacobian = get_batch_jacobian(self.model, x)
            jacobian = jacobian.reshape(jacobian.size(0), -1)
            reward = eval_score(jacobian)
            total_loss += reward.item()

            if self.entropy_weight:
                reward += self.entropy_weight * self.mutator.sample_entropy.item()
            # REINFORCE with a moving-average baseline, cf. https://arxiv.org/pdf/1707.06347.pdf
            self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
            loss = self.mutator.sample_log_prob * (reward - self.baseline)
            if self.skip_weight:
                loss += self.skip_weight * self.mutator.sample_skip_penalty
            loss /= self.mutator_steps_aggregate
            loss.backward()

            cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
            if self.log_frequency is not None and cur_step % self.log_frequency == 0:
                logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s",
                            epoch + 1, self.num_epochs, mutator_step, self.mutator_steps,
                            step, self.mutator_steps_aggregate, reward.int())

        nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)
        self.mutator_optim.step()

    mlflow.log_metric('Total reward', -total_loss / (self.mutator_steps * self.mutator_steps_aggregate), epoch)
    torch.save({'model': self.mutator.state_dict(),
                'optimizer': self.mutator_optim.state_dict()},
               'mutator_run_stats.pyt')
    mlflow.log_artifact('mutator_run_stats.pyt')
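# get_batch_jacobian / eval_score above follow the training-free scoring idea: rate a
# sampled architecture by the correlation structure of its input-output Jacobian over a
# batch. A minimal sketch under that assumption (hypothetical, for illustration only):
import numpy as np
import torch

def get_batch_jacobian(model, x):
    """Gradient of the summed network output with respect to the input batch."""
    model.zero_grad()
    x = x.clone().detach().requires_grad_(True)
    with torch.enable_grad():
        y = model(x)
        y.backward(torch.ones_like(y))
    return x.grad.detach()

def eval_score(jacobian):
    """Higher score means the per-example Jacobians are less correlated."""
    corrs = np.corrcoef(jacobian.cpu().numpy())
    eigenvalues = np.linalg.eigvalsh(corrs)
    k = 1e-5
    score = -np.sum(np.log(eigenvalues + k) + 1. / (eigenvalues + k))
    # Return a tensor so that .item() / .int() in the caller keep working.
    return torch.tensor(score)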
def train_one_epoch(self, epoch):
    # Sample model and train
    self.model.train()
    self.mutator.eval()
    meters = AverageMeterGroup()
    for step, (x, y) in enumerate(self.train_loader):
        x, y = x.to(self.device), y.to(self.device)
        self.optimizer.zero_grad()

        with torch.no_grad():
            self.mutator.reset()
        logits = self.model(x)

        if isinstance(logits, tuple):
            logits, aux_logits = logits
            aux_loss = self.loss(aux_logits, y)
        else:
            aux_loss = 0.
        metrics = self.metrics(logits, y)
        loss = self.loss(logits, y)
        loss = loss + self.aux_weight * aux_loss
        loss.backward()
        self.optimizer.step()

        metrics["loss"] = loss.item()
        meters.update(metrics)
        if self.log_frequency is not None and step % self.log_frequency == 0:
            logger.info("Model Epoch [%s/%s] Step [%s/%s] %s",
                        epoch + 1, self.num_epochs, step + 1, len(self.train_loader), meters)

    # Train sampler (mutator)
    self.model.eval()
    self.mutator.train()
    meters = AverageMeterGroup()
    mutator_step, total_mutator_steps = 0, self.mutator_steps * self.mutator_steps_aggregate
    while mutator_step < total_mutator_steps:
        for step, (x, y) in enumerate(self.valid_loader):
            x, y = x.to(self.device), y.to(self.device)

            self.mutator.reset()
            with torch.no_grad():
                logits = self.model(x)
            metrics = self.metrics(logits, y)
            reward = self.reward_function(logits, y)
            if self.entropy_weight is not None:
                reward += self.entropy_weight * self.mutator.sample_entropy
            self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
            self.baseline = self.baseline.detach().item()
            loss = self.mutator.sample_log_prob * (reward - self.baseline)
            if self.skip_weight:
                loss += self.skip_weight * self.mutator.sample_skip_penalty

            metrics["reward"] = reward
            metrics["loss"] = loss.item()
            metrics["ent"] = self.mutator.sample_entropy.item()
            metrics["baseline"] = self.baseline
            metrics["skip"] = self.mutator.sample_skip_penalty

            loss = loss / self.mutator_steps_aggregate
            loss.backward()
            meters.update(metrics)

            if mutator_step % self.mutator_steps_aggregate == 0:
                self.mutator_optim.step()
                self.mutator_optim.zero_grad()

            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("RL Epoch [%s/%s] Step [%s/%s] %s",
                            epoch + 1, self.num_epochs,
                            mutator_step // self.mutator_steps_aggregate + 1,
                            self.mutator_steps, meters)

            mutator_step += 1
            if mutator_step >= total_mutator_steps:
                break
def train_one_epoch(self, epoch):
    def get_model(model):
        # The supernet is wrapped in DistributedDataParallel; unwrap it.
        return model.module

    meters = AverageMeterGroup()
    for step, (input_data, target) in enumerate(self.train_loader):
        self.optimizer.zero_grad()
        self.mutator.reset()
        input_data = input_data.cuda()
        target = target.cuda()
        # FLOPs of the currently sampled (student) architecture
        cand_flops = self.est.get_flops(self.mutator._cache)

        # Meta update of the matching network and teacher weights
        if epoch > self.meta_sta_epoch and step > 0 and step % self.update_iter == 0:
            slice_ind = self.slices
            x = deepcopy(input_data[:slice_ind].clone().detach())

            if self.best_children_pool:
                # pick a teacher architecture from the pool
                if self.pick_method == 'top1':
                    meta_value, cand = 1, sorted(self.best_children_pool, reverse=True)[0][3]
                elif self.pick_method == 'meta':
                    meta_value, cand_idx, cand = -1000000000, -1, None
                    for now_idx, item in enumerate(self.best_children_pool):
                        inputx = item['input']
                        output = F.softmax(self.model(inputx), dim=1)
                        weight = get_model(self.model).forward_meta(output - item['feature_map'])
                        if weight > meta_value:
                            meta_value = weight
                            cand_idx = now_idx
                            cand = self.arch_dict[(self.best_children_pool[cand_idx]['acc'],
                                                   self.best_children_pool[cand_idx]['arch_list'])]
                    assert cand is not None
                    meta_value = torch.nn.functional.sigmoid(-weight)
                else:
                    raise ValueError('Method Not supported')

                # one-step student update on the distillation loss
                u_output = self.model(x)
                saved_cache = self.mutator._cache
                self.mutator._cache = cand
                u_teacher_output = self.model(x)
                self.mutator._cache = saved_cache
                u_soft_label = F.softmax(u_teacher_output, dim=1)
                kd_loss = meta_value * self.cross_entropy_loss_with_soft_target(u_output, u_soft_label)
                self.optimizer.zero_grad()

                grad_1 = torch.autograd.grad(
                    kd_loss,
                    get_model(self.model).rand_parameters(self.mutator._cache),
                    create_graph=True)

                def raw_sgd(w, g):
                    return g * self.optimizer.param_groups[-1]['lr'] + w

                students_weight = [raw_sgd(p, grad_item)
                                   for p, grad_item in zip(get_model(self.model).rand_parameters(self.mutator._cache), grad_1)]

                # update student weights
                for weight, grad_item in zip(get_model(self.model).rand_parameters(self.mutator._cache), grad_1):
                    weight.grad = grad_item
                torch.nn.utils.clip_grad_norm_(get_model(self.model).rand_parameters(self.mutator._cache), 1)
                self.optimizer.step()
                for weight, grad_item in zip(get_model(self.model).rand_parameters(self.mutator._cache), grad_1):
                    del weight.grad

                # second-order signal: validation loss of the updated student
                held_out_x = deepcopy(input_data[slice_ind:slice_ind * 2].clone().detach())
                output_2 = self.model(held_out_x)
                valid_loss = self.loss(output_2, target[slice_ind:slice_ind * 2])
                self.optimizer.zero_grad()

                grad_student_val = torch.autograd.grad(
                    valid_loss,
                    get_model(self.model).rand_parameters(self.mutator._cache),
                    retain_graph=True)
                grad_teacher = torch.autograd.grad(
                    students_weight[0],
                    get_model(self.model).rand_parameters(cand, self.pick_method == 'meta'),
                    grad_outputs=grad_student_val)

                # update teacher model
                for weight, grad_item in zip(get_model(self.model).rand_parameters(cand, self.pick_method == 'meta'), grad_teacher):
                    weight.grad = grad_item
                torch.nn.utils.clip_grad_norm_(get_model(self.model).rand_parameters(self.mutator._cache, self.pick_method == 'meta'), 1)
                self.optimizer.step()
                for weight, grad_item in zip(get_model(self.model).rand_parameters(cand, self.pick_method == 'meta'), grad_teacher):
                    del weight.grad

                for item in students_weight:
                    del item
                del grad_teacher, grad_1, grad_student_val, x, held_out_x
                del valid_loss, kd_loss, u_soft_label, u_output, u_teacher_output, output_2
            else:
                raise ValueError("Must 1st or 2nd update teacher weights")

        # get_best_teacher
        if self.best_children_pool:
            if self.pick_method == 'top1':
                meta_value, cand = 0.5, sorted(self.best_children_pool, reverse=True)[0][3]
            elif self.pick_method == 'meta':
                meta_value, cand_idx, cand = -1000000000, -1, None
                for now_idx, item in enumerate(self.best_children_pool):
                    inputx = item['input']
                    output = F.softmax(self.model(inputx), dim=1)
                    weight = get_model(self.model).forward_meta(output - item['feature_map'])
                    if weight > meta_value:
                        meta_value = weight
                        cand_idx = now_idx
                        cand = self.arch_dict[(self.best_children_pool[cand_idx]['acc'],
                                               self.best_children_pool[cand_idx]['arch_list'])]
                assert cand is not None
                meta_value = torch.nn.functional.sigmoid(-weight)
            else:
                raise ValueError('Method Not supported')

        # supernet (student) step, optionally distilled from the teacher
        if not self.best_children_pool:
            output = self.model(input_data)
            loss = self.loss(output, target)
            kd_loss = loss
        elif epoch <= self.meta_sta_epoch:
            output = self.model(input_data)
            loss = self.loss(output, target)
        else:
            output = self.model(input_data)
            with torch.no_grad():
                # save student arch
                saved_cache = self.mutator._cache
                self.mutator._cache = cand
                # forward teacher
                teacher_output = self.model(input_data).detach()
                # restore student arch
                self.mutator._cache = saved_cache
                soft_label = F.softmax(teacher_output, dim=1)
            kd_loss = self.cross_entropy_loss_with_soft_target(output, soft_label)
            valid_loss = self.loss(output, target)
            loss = (meta_value * kd_loss + (2 - meta_value) * valid_loss) / 2

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        prec1, prec5 = self.accuracy(output, target, topk=(1, 5))
        metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
        metrics = self.reduce_metrics(metrics, self.distributed)
        meters.update(metrics)

        # maintain the pool of best child architectures
        if epoch > self.meta_sta_epoch and (
                (len(self.best_children_pool) < self.pool_size)
                or (prec1 > self.best_children_pool[-1]['acc'] + 5)
                or (prec1 > self.best_children_pool[-1]['acc'] and cand_flops < self.best_children_pool[-1]['flops'])):
            val_prec1 = prec1
            training_data = deepcopy(input_data[:self.slices].detach())
            if not self.best_children_pool:
                features = deepcopy(output[:self.slices].detach())
            else:
                features = deepcopy(teacher_output[:self.slices].detach())
            self.best_children_pool.append({
                'acc': val_prec1,
                'accu': prec1,
                'flops': cand_flops,
                'input': training_data,
                'feature_map': F.softmax(features, dim=1)
            })
            self.arch_dict[(val_prec1, cand_flops)] = self.mutator._cache
            self.best_children_pool = sorted(self.best_children_pool, key=lambda x: x['acc'], reverse=True)

        if len(self.best_children_pool) > self.pool_size:
            self.best_children_pool = sorted(self.best_children_pool, key=lambda x: x['acc'], reverse=True)
            del self.best_children_pool[-1]

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        if self.main_proc and self.log_frequency is not None and step % self.log_frequency == 0:
            logger.info("Epoch [%s/%s] Step [%s/%s] %s",
                        epoch + 1, self.num_epochs, step + 1, len(self.train_loader), meters)

    if self.main_proc:
        for idx, i in enumerate(self.best_children_pool):
            logger.info("No.%s acc=%s flops=%s", idx, i['acc'], i['flops'])
def _joint_train(self, epoch):
    meters = AverageMeterGroup()
    for step in range(self.steps_per_epoch):
        totall_lc = 0
        totall_lw = 0
        totall_li = 0
        totall_lr = 0
        loss_regular = self.mutator_small.reset_with_loss()
        reg_decay = max(self.regular_coeff * (1 - float(epoch - self.warmup_epochs) /
                                              ((self.epochs - self.warmup_epochs) * self.regular_ratio)), 0)
        if loss_regular:
            loss_regular *= reg_decay
        samples_x = []
        samples_y = []
        criterion_l = []
        emsemble_logits_l = []

        def trn_l(totall_lc, totall_lw, totall_li, totall_lr):
            # Large (teacher) network step, accumulated over fake_batch sub-batches.
            self.model_large.train()
            self.optimizer_large.zero_grad()
            for fb in range(self.fake_batch):
                val_x, val_y = next(self.valid_loader)
                val_x, val_y = val_x.cuda(), val_y.cuda()
                logits_main, emsemble_logits_main = self.model_large(val_x)
                cel = self.criterion(logits_main, val_y)
                loss_weight = cel / self.fake_batch
                loss_weight.backward(retain_graph=True)
                criterion_l.append(cel.cpu())
                emsemble_logits_l.append(emsemble_logits_main.cpu())
                totall_lw += float(loss_weight)
                samples_x.append(val_x.cpu())
                samples_y.append(val_y.cpu())
            self._clip_grad_norm(self.model_large)
            self.optimizer_large.step()
            self.model_large.train(mode=False)
            return totall_lc, totall_lw, totall_li, totall_lr

        totall_lc, totall_lw, totall_li, totall_lr = trn_l(totall_lc, totall_lw, totall_li, totall_lr)

        def sleep(s):
            # Debugging aid: pause and dump CUDA memory statistics.
            print("--" + str(s))
            time.sleep(2)
            print(torch.cuda.memory_summary())
            print("++" + str(s))

        def trn_s(totall_lc, totall_lw, totall_li, totall_lr):
            # Small (search) network and architecture step, replayed on the cached sub-batches.
            print("sts")
            self.model_small.cuda()
            self.model_small.train()
            self.optimizer_alpha.zero_grad()
            self.optimizer_small.zero_grad()
            i = 0
            ls = []
            els = []
            sleep(0)

            def sc():
                reg_decay = max(self.regular_coeff * (1 - float(epoch - self.warmup_epochs) /
                                                      ((self.epochs - self.warmup_epochs) * self.regular_ratio)), 0)
                loss_regular = self.mutator_small.reset_with_loss()
                if loss_regular:
                    loss_regular *= reg_decay
                    loss_regular.backward()
                    loss_regular = loss_regular.cpu().detach()

            sc()
            sleep(0.5)
            for i in range(len(samples_x)):
                val_x = samples_x[i].cuda()
                val_y = samples_y[i].cuda()
                logits_search, emsemble_logits_search = self.model_small(val_x)
                cls = self.criterion(logits_search, val_y)
                ls.append(cls.cpu())
                els.append(emsemble_logits_search.cpu())
                val_x.cpu().detach()
                val_y.cpu().detach()
            sleep(1)
            for i in range(len(samples_x)):
                criterion_logits_main = criterion_l[i].cuda()
                cls = ls[i].cuda()
                emsemble_logits_search = els[i].cuda()
                loss_weight = cls / self.fake_batch
                totall_lw += float(loss_weight)
                loss_cls = (cls + criterion_logits_main) / self.loss_alpha / self.fake_batch
                loss_cls.backward(retain_graph=True)
                totall_lc += float(loss_cls)
                criterion_logits_main.cpu().detach()
            sleep(2)
            for i in range(len(samples_x)):
                emsemble_logits_main = emsemble_logits_l[i].cuda()
                emsemble_logits_search = els[i].cuda()
                sleep(3)
                loss_interactive = self.interactive_loss(emsemble_logits_search, emsemble_logits_main) * \
                    (self.loss_T ** 2) * self.loss_alpha / self.fake_batch
                loss_interactive.backward(retain_graph=True)
                sleep(5)
                emsemble_logits_search.cpu()
                totall_li += float(loss_interactive)
                totall_lr += float(loss_regular)
                emsemble_logits_search.cpu().detach()
                emsemble_logits_main.cpu().detach()
                sleep(6)
                i = i + 1
            self.optimizer_alpha.step()
            self._clip_grad_norm(self.model_small)
            self.optimizer_small.step()
            self.model_small.train(mode=False)
            samples_x.clear()
            samples_y.clear()
            criterion_l.clear()
            emsemble_logits_l.clear()
            return totall_lc, totall_lw, totall_li, totall_lr

        totall_lc, totall_lw, totall_li, totall_lr = trn_s(totall_lc, totall_lw, totall_li, totall_lr)

        metrics = {"loss_cls": totall_lc, "loss_interactive": totall_li,
                   "loss_regular": totall_lr, "loss_weight": totall_lw}
        # metrics = reduce_metrics(metrics, self.distributed)
        meters.update(metrics)
        if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
            self.logger.info("Epoch [%d/%d] Step [%d/%d] (joint) %s",
                             epoch + 1, self.epochs, step + 1, self.steps_per_epoch, meters)
def train_one_epoch(self, epoch):
    # Sample model and train
    self.model.train()
    self.mutator.eval()
    meters = AverageMeterGroup()
    for step in range(1, self.child_steps + 1):
        x, y = next(self.train_loader)
        x, y = to_device(x, self.device), to_device(y, self.device)
        self.optimizer.zero_grad()

        with torch.no_grad():
            self.mutator.reset()
        logits = self.model(x)

        if isinstance(logits, tuple):
            logits, aux_logits = logits
            aux_loss = self.loss(aux_logits, y)
        else:
            aux_loss = 0.
        metrics = self.metrics(logits, y)
        loss = self.loss(logits, y)
        loss = loss + self.aux_weight * aux_loss
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
        self.optimizer.step()

        metrics["loss"] = loss.item()
        meters.update(metrics)
        if self.log_frequency is not None and step % self.log_frequency == 0:
            logger.info("Model Epoch [%d/%d] Step [%d/%d] %s",
                        epoch + 1, self.num_epochs, step, self.child_steps, meters)

    # Train sampler (mutator)
    self.model.eval()
    self.mutator.train()
    meters = AverageMeterGroup()
    for mutator_step in range(1, self.mutator_steps + 1):
        self.mutator_optim.zero_grad()
        for step in range(1, self.mutator_steps_aggregate + 1):
            x, y = next(self.valid_loader)
            x, y = to_device(x, self.device), to_device(y, self.device)

            self.mutator.reset()
            with torch.no_grad():
                logits = self.model(x)
            metrics = self.metrics(logits, y)
            reward = self.reward_function(logits, y)
            if self.entropy_weight:
                reward += self.entropy_weight * self.mutator.sample_entropy.item()
            self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
            loss = self.mutator.sample_log_prob * (reward - self.baseline)
            if self.skip_weight:
                loss += self.skip_weight * self.mutator.sample_skip_penalty

            metrics["reward"] = reward
            metrics["loss"] = loss.item()
            metrics["ent"] = self.mutator.sample_entropy.item()
            metrics["log_prob"] = self.mutator.sample_log_prob.item()
            metrics["baseline"] = self.baseline
            metrics["skip"] = self.mutator.sample_skip_penalty

            loss /= self.mutator_steps_aggregate
            loss.backward()
            meters.update(metrics)

            cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
            if self.log_frequency is not None and cur_step % self.log_frequency == 0:
                logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s",
                            epoch + 1, self.num_epochs, mutator_step, self.mutator_steps,
                            step, self.mutator_steps_aggregate, meters)

        nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)
        self.mutator_optim.step()
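# self.reward_function above is typically the plain batch accuracy (the same helper is
# quoted in a comment in the next trainer); a sketch under that assumption:
import torch

def reward_accuracy(output, target, topk=(1,)):
    """Fraction of correctly classified samples in the batch, used as the RL reward."""
    batch_size = target.size(0)
    _, predicted = torch.max(output.data, 1)
    return (predicted == target).sum().item() / batch_size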
def train_one_epoch(self, epoch):
    self.current_epoch = epoch
    meters = AverageMeterGroup()
    self.steps_per_epoch = len(self.train_loader)
    for step, (input_data, target) in enumerate(self.train_loader):
        self.mutator.reset()
        self.current_student_arch = self.mutator._cache
        input_data, target = input_data.cuda(), target.cuda()

        # calculate flops of current architecture
        cand_flops = self._get_cand_flops(self.mutator._cache)

        # update meta matching network
        self._run_update(input_data, target, step)

        if self._board_size() > 0:
            # select teacher architecture
            meta_value, teacher_cand = self._select_teacher()
            self.current_teacher_arch = teacher_cand

        # forward supernet
        if self._board_size() == 0 or epoch <= self.meta_sta_epoch:
            self._replace_mutator_cand(self.current_student_arch)
            output = self.model(input_data)
            loss = self.loss(output, target)
            kd_loss, teacher_output, teacher_cand = None, None, None
        else:
            self._replace_mutator_cand(self.current_student_arch)
            output = self.model(input_data)
            gt_loss = self.loss(output, target)

            with torch.no_grad():
                self._replace_mutator_cand(self.current_teacher_arch)
                teacher_output = self.model(input_data).detach()
                soft_label = torch.nn.functional.softmax(teacher_output, dim=1)
            kd_loss = self._cross_entropy_loss_with_soft_target(output, soft_label)
            loss = (meta_value * kd_loss + (2 - meta_value) * gt_loss) / 2

        # update network
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # update metrics
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
        metrics = reduce_metrics(metrics)
        meters.update(metrics)

        # update prioritized board
        self._update_prioritized_board(input_data, teacher_output, output, metrics['prec1'], cand_flops)

        if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
            logger.info("Epoch [%d/%d] Step [%d/%d] %s",
                        epoch + 1, self.num_epochs, step + 1, len(self.train_loader), meters)

    if self.main_proc and self.num_epochs == epoch + 1:
        for idx, i in enumerate(self.prioritized_board):
            logger.info("No.%s %s", idx, i[:4])
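# self._cross_entropy_loss_with_soft_target above is assumed to be the standard
# knowledge-distillation cross entropy against the teacher's softened probabilities;
# a minimal sketch of that formula:
import torch

def cross_entropy_loss_with_soft_target(pred, soft_target):
    """Mean over the batch of -sum_c soft_target[c] * log_softmax(pred)[c]."""
    logsoftmax = torch.nn.LogSoftmax(dim=1)
    return torch.mean(torch.sum(-soft_target * logsoftmax(pred), dim=1))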
def train_one_epoch(self, epoch):
    # Sample model and train
    self.model.train()
    self.mutator.eval()
    meters = AverageMeterGroup()
    # COMMENT: train the child model first
    for step in range(1, self.child_steps + 1):
        x, y = next(self.train_loader)
        x, y = to_device(x, self.device), to_device(y, self.device)
        self.optimizer.zero_grad()

        with torch.no_grad():
            self.mutator.reset()
            self._write_graph_status()
        logits = self.model(x)

        if isinstance(logits, tuple):
            logits, aux_logits = logits
            aux_loss = self.loss(aux_logits, y)
        else:
            aux_loss = 0.
        metrics = self.metrics(logits, y)  # compute accuracy
        loss = self.loss(logits, y)        # compute loss
        loss = loss + self.aux_weight * aux_loss
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
        self.optimizer.step()

        metrics["loss"] = loss.item()
        meters.update(metrics)
        if self.log_frequency is not None and step % self.log_frequency == 0:
            logger.info("Model Epoch [%d/%d] Step [%d/%d] %s",
                        epoch + 1, self.num_epochs, step, self.child_steps, meters)

    # Train sampler (mutator)
    self.model.eval()
    self.mutator.train()  # then train the architecture sampler (mutator)
    meters = AverageMeterGroup()
    for mutator_step in range(1, self.mutator_steps + 1):
        self.mutator_optim.zero_grad()
        for step in range(1, self.mutator_steps_aggregate + 1):
            x, y = next(self.valid_loader)
            x, y = to_device(x, self.device), to_device(y, self.device)

            self.mutator.reset()
            with torch.no_grad():
                logits = self.model(x)
            self._write_graph_status()

            # get accuracy
            metrics = self.metrics(logits, y)
            # get the reward; reward_function is the plain batch accuracy:
            '''
            def reward_accuracy(output, target, topk=(1,)):
                batch_size = target.size(0)
                _, predicted = torch.max(output.data, 1)
                return (predicted == target).sum().item() / batch_size
            '''
            reward = self.reward_function(logits, y)  # fraction correct in the current batch
            if self.entropy_weight:  # entropy weight
                reward += self.entropy_weight * self.mutator.sample_entropy.item()  # add the sample entropy
            self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
            # this is essentially a policy-gradient (REINFORCE) update
            loss = self.mutator.sample_log_prob * (reward - self.baseline)
            if self.skip_weight:
                loss += self.skip_weight * self.mutator.sample_skip_penalty

            metrics["reward"] = reward
            metrics["loss"] = loss.item()
            metrics["ent"] = self.mutator.sample_entropy.item()
            metrics["log_prob"] = self.mutator.sample_log_prob.item()
            metrics["baseline"] = self.baseline
            metrics["skip"] = self.mutator.sample_skip_penalty

            loss /= self.mutator_steps_aggregate
            loss.backward()
            meters.update(metrics)

            cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
            if self.log_frequency is not None and cur_step % self.log_frequency == 0:
                logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s",
                            epoch + 1, self.num_epochs, mutator_step, self.mutator_steps,
                            step, self.mutator_steps_aggregate, meters)

        nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)
        self.mutator_optim.step()
def train(logger, config, train_loader, model, optimizer, criterion, epoch, main_proc, fake_batch=4, steps=128):
    meters = AverageMeterGroup()
    cur_lr = optimizer.param_groups[0]["lr"]
    if main_proc:
        logger.info("Epoch %d LR %.6f", epoch, cur_lr)
    model.train()

    for step in range(steps):
        totall_l = 0
        totall_p = 0

        def top(totall_l, totall_p):
            optimizer.zero_grad()

            def s(totall_l, totall_p):
                # One accumulation sub-step: scale the loss by fake_batch so that the
                # accumulated gradient matches a single large batch.
                x, y = next(train_loader)
                x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
                logits, aux_logits = model(x)
                loss = criterion(logits, y)
                if config.aux_weight > 0.:
                    loss += config.aux_weight * criterion(aux_logits, y)
                loss = loss / fake_batch
                try:
                    prec1, _ = utils.accuracy(logits, y, topk=(1, 1))
                    prec1 = prec1 / fake_batch
                    totall_p += prec1
                    loss.backward(retain_graph=True)
                    totall_l += float(loss)
                except Exception:
                    print("Err")
                return totall_l, totall_p

            for fb in range(fake_batch):
                totall_l, totall_p = s(totall_l, totall_p)
            optimizer.step()
            return totall_l, totall_p

        totall_l, totall_p = top(totall_l, totall_p)
        nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)

        metrics = {"prec1": totall_p, "loss": totall_l}
        # metrics = utils.reduce_metrics(metrics, config.distributed)
        meters.update(metrics)

        if main_proc and (step % config.log_frequency == 0 or step + 1 == steps):
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': totall_l,
                'model': model
            }, 'model_' + str(epoch) + '.pt')
            logger.info("Epoch [%d/%d] Step [%d/%d] %s",
                        epoch + 1, config.epochs, step + 1, steps, meters)

    if main_proc:
        logger.info("Train: [%d/%d] Final Prec@1 %.4f", epoch + 1, config.epochs, meters.prec1.avg)