def run_one_epoch(epoch,
                  loader,
                  model,
                  criterion,
                  optimizer,
                  lr_scheduler,
                  ema,
                  meters,
                  max_iter=None,
                  phase='train'):
    """Run one epoch."""
    assert phase in ['train', 'val', 'test', 'bn_calibration'
                     ], "phase must be one of train/val/test/bn_calibration."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()
        if phase == 'bn_calibration':
            model.apply(bn_calibration)

    if FLAGS.use_distributed:
        loader.sampler.set_epoch(epoch)

    results = None
    data_iterator = iter(loader)
    if FLAGS.use_distributed:
        data_fetcher = dataflow.DataPrefetcher(data_iterator)
    else:
        # TODO(meijieru): prefetch for non distributed
        logging.warning('Not using prefetcher')
        data_fetcher = data_iterator

    for batch_idx, (input, target) in enumerate(data_fetcher):
        # used for bn calibration
        if max_iter is not None:
            assert phase == 'bn_calibration'
            if batch_idx >= max_iter:
                break

        target = target.cuda(non_blocking=True)
        if train:
            optimizer.zero_grad()

            loss = mc.forward_loss(model, criterion, input, target, meters)
            loss_l2 = optim.cal_l2_loss(model, FLAGS.weight_decay,
                                        FLAGS.weight_decay_method)
            loss = loss + loss_l2
            loss.backward()
            if FLAGS.use_distributed:
                udist.allreduce_grads(model)

            if FLAGS._global_step % FLAGS.log_interval == 0:
                results = mc.reduce_and_flush_meters(meters)
                if udist.is_master():
                    logging.info(
                        'Epoch {}/{} Iter {}/{} {}: '.format(
                            epoch, FLAGS.num_epochs, batch_idx, len(loader),
                            phase)
                        + ', '.join('{}: {:.4f}'.format(k, v)
                                    for k, v in results.items()))
                    for k, v in results.items():
                        mc.summary_writer.add_scalar('{}/{}'.format(phase, k),
                                                     v, FLAGS._global_step)

            if udist.is_master() and FLAGS._global_step % FLAGS.log_interval == 0:
                mc.summary_writer.add_scalar('train/learning_rate',
                                             optimizer.param_groups[0]['lr'],
                                             FLAGS._global_step)
                mc.summary_writer.add_scalar('train/l2_regularize_loss',
                                             extract_item(loss_l2),
                                             FLAGS._global_step)
                mc.summary_writer.add_scalar(
                    'train/current_epoch',
                    FLAGS._global_step / FLAGS._steps_per_epoch,
                    FLAGS._global_step)
                if FLAGS.data_loader_workers > 0:
                    mc.summary_writer.add_scalar(
                        'data/train/prefetch_size',
                        get_data_queue_size(data_iterator),
                        FLAGS._global_step)

            optimizer.step()
            lr_scheduler.step()

            if FLAGS.use_distributed and FLAGS.allreduce_bn:
                udist.allreduce_bn(model)
            FLAGS._global_step += 1

            # NOTE: after step count update
            if ema is not None:
                model_unwrap = mc.unwrap_model(model)
                ema_names = ema.average_names()
                params = get_params_by_name(model_unwrap, ema_names)
                for name, param in zip(ema_names, params):
                    ema(name, param, FLAGS._global_step)
        else:
            mc.forward_loss(model, criterion, input, target, meters)

    if not train:
        results = mc.reduce_and_flush_meters(meters)
        if udist.is_master():
            logging.info(
                'Epoch {}/{} {}: '.format(epoch, FLAGS.num_epochs, phase)
                + ', '.join('{}: {:.4f}'.format(k, v)
                            for k, v in results.items()))
            for k, v in results.items():
                mc.summary_writer.add_scalar('{}/{}'.format(phase, k), v,
                                             FLAGS._global_step)
    return results
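# A minimal driver-loop sketch (not from the source repositories) showing how the
# variant above might be invoked from a training script; the helper name
# `train_loop_sketch` and the loader/meter arguments are assumptions for
# illustration only.
import torch


def train_loop_sketch(model, criterion, optimizer, lr_scheduler, ema,
                      train_loader, val_loader, train_meters, val_meters,
                      start_epoch=0):
    results = None
    for epoch in range(start_epoch, FLAGS.num_epochs):
        # one pass over the training set (updates weights, LR, EMA, global step)
        run_one_epoch(epoch, train_loader, model, criterion, optimizer,
                      lr_scheduler, ema, train_meters, phase='train')
        # one pass over the validation set (metrics only, no gradient updates)
        with torch.no_grad():
            results = run_one_epoch(epoch, val_loader, model, criterion,
                                    optimizer, lr_scheduler, ema, val_meters,
                                    phase='val')
    return results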
def run_one_epoch(epoch,
                  loader,
                  model,
                  criterion,
                  optimizer,
                  lr_scheduler,
                  ema,
                  rho_scheduler,
                  meters,
                  max_iter=None,
                  phase='train'):
    """Run one epoch."""
    assert phase in ['train', 'val', 'test', 'bn_calibration'] or phase.startswith(
        'prune'), "phase must be one of train/val/test/bn_calibration/prune*."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()
        if phase == 'bn_calibration':
            model.apply(bn_calibration)

    if not FLAGS.use_hdfs:
        if FLAGS.use_distributed:
            loader.sampler.set_epoch(epoch)

    results = None
    data_iterator = iter(loader)
    if not FLAGS.use_hdfs:
        if FLAGS.use_distributed:
            if FLAGS.dataset == 'coco':
                data_fetcher = dataflow.DataPrefetcherKeypoint(data_iterator)
            else:
                data_fetcher = dataflow.DataPrefetcher(data_iterator)
        else:
            logging.warning('Not using prefetcher')
            data_fetcher = data_iterator

    for batch_idx, data in enumerate(data_fetcher):
        if FLAGS.dataset == 'coco':
            input, target, target_weight, meta = data
            # e.g. input (4, 3, 384, 288), target (4, 17, 96, 72), target_weight (4, 17, 1)
        else:
            input, target = data
        # if batch_idx > 400:
        #     break

        # used for bn calibration
        if max_iter is not None:
            assert phase == 'bn_calibration'
            if batch_idx >= max_iter:
                break

        target = target.cuda(non_blocking=True)
        if train:
            optimizer.zero_grad()
            rho = rho_scheduler(FLAGS._global_step)

            if FLAGS.dataset == 'coco':
                outputs = model(input)
                if isinstance(outputs, list):
                    loss = criterion(outputs[0], target, target_weight)
                    for output in outputs[1:]:
                        loss += criterion(output, target, target_weight)
                else:
                    output = outputs
                    loss = criterion(output, target, target_weight)
                _, avg_acc, cnt, pred = accuracy_keypoint(
                    output.detach().cpu().numpy(),
                    target.detach().cpu().numpy())  # cnt = 17
                meters['acc'].cache(avg_acc)
                meters['loss'].cache(loss)
            else:
                loss = mc.forward_loss(model, criterion, input, target, meters,
                                       task=FLAGS.model_kwparams.task,
                                       distill=FLAGS.distill)

            if FLAGS.prune_params['method'] is not None:
                loss_l2 = optim.cal_l2_loss(
                    model, FLAGS.weight_decay,
                    FLAGS.weight_decay_method)  # manual weight decay
                loss_bn_l1 = prune.cal_bn_l1_loss(get_prune_weights(model),
                                                  FLAGS._bn_to_prune.penalty,
                                                  rho)
                if FLAGS.prune_params.use_transformer:
                    transformer_weights = get_prune_weights(model, True)
                    loss_bn_l1 += prune.cal_bn_l1_loss(
                        transformer_weights,
                        FLAGS._bn_to_prune_transformer.penalty, rho)

                    transformer_dict = []
                    for name, weight in zip(
                            FLAGS._bn_to_prune_transformer.weight,
                            transformer_weights):
                        transformer_dict.append(
                            sum(weight > FLAGS.model_shrink_threshold).item())
                    FLAGS._bn_to_prune_transformer.add_info_list(
                        'channels', transformer_dict)
                    FLAGS._bn_to_prune_transformer.update_penalty()
                    if udist.is_master() and FLAGS._global_step % FLAGS.log_interval == 0:
                        logging.info(transformer_dict)
                        # logging.info(FLAGS._bn_to_prune_transformer.penalty)

                meters['loss_l2'].cache(loss_l2)
                meters['loss_bn_l1'].cache(loss_bn_l1)
                loss = loss + loss_l2 + loss_bn_l1

            loss.backward()
            if FLAGS.use_distributed:
                udist.allreduce_grads(model)

            if FLAGS._global_step % FLAGS.log_interval == 0:
                results = mc.reduce_and_flush_meters(meters)
                if udist.is_master():
                    logging.info(
                        'Epoch {}/{} Iter {}/{} Lr: {} {}: '.format(
                            epoch, FLAGS.num_epochs, batch_idx, len(loader),
                            optimizer.param_groups[0]['lr'], phase)
                        + ', '.join('{}: {:.4f}'.format(k, v)
                                    for k, v in results.items()))
                    for k, v in results.items():
                        mc.summary_writer.add_scalar('{}/{}'.format(phase, k),
                                                     v, FLAGS._global_step)

            if udist.is_master() and FLAGS._global_step % FLAGS.log_interval == 0:
                mc.summary_writer.add_scalar('train/learning_rate',
                                             optimizer.param_groups[0]['lr'],
                                             FLAGS._global_step)
                if FLAGS.prune_params['method'] is not None:
                    mc.summary_writer.add_scalar('train/l2_regularize_loss',
                                                 extract_item(loss_l2),
                                                 FLAGS._global_step)
                    mc.summary_writer.add_scalar('train/bn_l1_loss',
                                                 extract_item(loss_bn_l1),
                                                 FLAGS._global_step)
                    mc.summary_writer.add_scalar('prune/rho', rho,
                                                 FLAGS._global_step)
                mc.summary_writer.add_scalar(
                    'train/current_epoch',
                    FLAGS._global_step / FLAGS._steps_per_epoch,
                    FLAGS._global_step)
                if FLAGS.data_loader_workers > 0:
                    mc.summary_writer.add_scalar(
                        'data/train/prefetch_size',
                        get_data_queue_size(data_iterator),
                        FLAGS._global_step)

            if udist.is_master() and FLAGS._global_step % FLAGS.log_interval_detail == 0:
                summary_bn(model, 'train')

            optimizer.step()
            if FLAGS.lr_scheduler == 'poly':
                optim.poly_learning_rate(
                    optimizer, FLAGS.lr,
                    epoch * FLAGS._steps_per_epoch + batch_idx + 1,
                    FLAGS.num_epochs * FLAGS._steps_per_epoch)
            else:
                lr_scheduler.step()

            if FLAGS.use_distributed and FLAGS.allreduce_bn:
                udist.allreduce_bn(model)
            FLAGS._global_step += 1

            # NOTE: after step count update
            if ema is not None:
                model_unwrap = mc.unwrap_model(model)
                ema_names = ema.average_names()
                params = get_params_by_name(model_unwrap, ema_names)
                for name, param in zip(ema_names, params):
                    ema(name, param, FLAGS._global_step)
        else:
            if FLAGS.dataset == 'coco':
                outputs = model(input)
                if isinstance(outputs, list):
                    loss = criterion(outputs[0], target, target_weight)
                    for output in outputs[1:]:
                        loss += criterion(output, target, target_weight)
                else:
                    output = outputs
                    loss = criterion(output, target, target_weight)
                _, avg_acc, cnt, pred = accuracy_keypoint(
                    output.detach().cpu().numpy(),
                    target.detach().cpu().numpy())  # cnt = 17
                meters['acc'].cache(avg_acc)
                meters['loss'].cache(loss)
            else:
                mc.forward_loss(model, criterion, input, target, meters,
                                task=FLAGS.model_kwparams.task, distill=False)

    if not train:
        results = mc.reduce_and_flush_meters(meters)
        if udist.is_master():
            logging.info(
                'Epoch {}/{} {}: '.format(epoch, FLAGS.num_epochs, phase)
                + ', '.join('{}: {:.4f}'.format(k, v)
                            for k, v in results.items()))
            for k, v in results.items():
                mc.summary_writer.add_scalar('{}/{}'.format(phase, k), v,
                                             FLAGS._global_step)
    return results
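# Sketch (not the repository's `optim.poly_learning_rate`) of the standard
# polynomial "poly" decay that the branch above presumably applies per iteration;
# the default power of 0.9 is an assumption.
def poly_learning_rate_sketch(optimizer, base_lr, curr_iter, max_iter, power=0.9):
    # lr = base_lr * (1 - curr_iter / max_iter) ** power
    lr = base_lr * (1 - curr_iter / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr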
def run_one_epoch(epoch, loader, model, criterion, optimizer, meters,
                  phase='train', scheduler=None):
    """Run one epoch for train/val/test."""
    t_start = time.time()
    assert phase in ['train', 'val', 'test'], "phase must be one of train/val/test."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()

    if getattr(FLAGS, 'distributed', False):
        loader.sampler.set_epoch(epoch)

    for batch_idx, (input, target) in enumerate(loader):
        target = target.cuda(non_blocking=True)
        if train:
            if FLAGS.lr_scheduler == 'linear_decaying':
                linear_decaying_per_step = (
                    FLAGS.lr / FLAGS.num_epochs / len(loader.dataset) * FLAGS.batch_size)
                for param_group in optimizer.param_groups:
                    param_group['lr'] -= linear_decaying_per_step
            # For PyTorch 1.1+, comment out the following two lines
            # if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter', 'multistep_iter']:
            #     scheduler.step()
            optimizer.zero_grad()
            if getattr(FLAGS, 'adaptive_training', False):
                for bits_idx, bits in enumerate(FLAGS.bits_list):
                    model.apply(lambda m: setattr(m, 'bits', bits))
                    if is_master():
                        meter = meters[str(bits)]
                    else:
                        meter = None
                    loss = forward_loss(model, criterion, input, target, meter)
                    loss.backward()
            else:
                loss = forward_loss(model, criterion, input, target, meters)
                loss.backward()
            if (getattr(FLAGS, 'distributed', False)
                    and getattr(FLAGS, 'distributed_all_reduce', False)):
                allreduce_grads(model)
            optimizer.step()
            # For PyTorch 1.0 or earlier, comment out the following two lines
            if FLAGS.lr_scheduler in ['exp_decaying_iter', 'cos_annealing_iter',
                                      'multistep_iter']:
                scheduler.step()
        else:  # not train
            if getattr(FLAGS, 'adaptive_training', False):
                for bits_idx, bits in enumerate(FLAGS.bits_list):
                    model.apply(lambda m: setattr(m, 'bits', bits))
                    if is_master() and meters is not None:
                        meter = meters[str(bits)]
                    else:
                        meter = None
                    forward_loss(model, criterion, input, target, meter)
            else:
                forward_loss(model, criterion, input, target, meters)

    val_top1 = None
    if is_master() and meters is not None:
        if getattr(FLAGS, 'adaptive_training', False):
            val_top1_list = []
            for bits in FLAGS.bits_list:
                results = flush_scalar_meters(meters[str(bits)])
                mprint('{:.1f}s\t{}\t{} bits\t{}/{}: '.format(
                    time.time() - t_start, phase, bits, epoch, FLAGS.num_epochs)
                       + ', '.join('{}: {}'.format(k, v)
                                   for k, v in results.items()))
                val_top1_list.append(results['top1_error'])
            val_top1 = np.mean(val_top1_list)
        else:
            results = flush_scalar_meters(meters)
            mprint('{:.1f}s\t{}\t{}/{}: '.format(
                time.time() - t_start, phase, epoch, FLAGS.num_epochs)
                   + ', '.join('{}: {}'.format(k, v)
                               for k, v in results.items()))
            val_top1 = results['top1_error']
    return val_top1
def run_one_epoch(epoch, loader, model, criterion, optimizer, meters,
                  phase='train', ema=None, scheduler=None):
    """Run one epoch for train/val/test/cal."""
    t_start = time.time()
    assert phase in ['train', 'val', 'test', 'cal'], "phase must be one of train/val/test/cal."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()

    if getattr(FLAGS, 'distributed', False):
        loader.sampler.set_epoch(epoch)

    for batch_idx, (input, target) in enumerate(loader):
        if phase == 'cal':
            if batch_idx == getattr(FLAGS, 'bn_cal_batch_num', -1):
                break
        target = target.cuda(non_blocking=True)
        if train:
            if FLAGS.lr_scheduler == 'linear_decaying':
                linear_decaying_per_step = (
                    FLAGS.lr / FLAGS.num_epochs / len(loader.dataset) * FLAGS.batch_size)
                for param_group in optimizer.param_groups:
                    param_group['lr'] -= linear_decaying_per_step
            # For PyTorch 1.1+, comment out the following two lines
            # if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter']:
            #     scheduler.step()
            optimizer.zero_grad()
            loss = forward_loss(model, criterion, input, target, meters)
            if (epoch >= FLAGS.warmup_epochs
                    and not getattr(FLAGS, 'hard_assignment', False)):
                if getattr(FLAGS, 'weight_only', False):
                    loss += getattr(FLAGS, 'kappa', 1.0) * get_model_size_loss(model)
                else:
                    loss += getattr(FLAGS, 'kappa', 1.0) * get_comp_cost_loss(model)
            loss.backward()
            if (getattr(FLAGS, 'distributed', False)
                    and getattr(FLAGS, 'distributed_all_reduce', False)):
                allreduce_grads(model)
            optimizer.step()
            # For PyTorch 1.0 or earlier, comment out the following two lines
            if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter',
                                      'cos_annealing_iter', 'butterworth_iter',
                                      'mixed_iter']:
                scheduler.step()
            if ema:
                ema.shadow_update(model)
                # for name, param in model.named_parameters():
                #     if param.requires_grad:
                #         ema.update(name, param.data)
                # bn_idx = 0
                # for m in model.modules():
                #     if isinstance(m, nn.BatchNorm2d):
                #         ema.update('bn{}_mean'.format(bn_idx), m.running_mean)
                #         ema.update('bn{}_var'.format(bn_idx), m.running_var)
                #         bn_idx += 1
        else:  # not train
            if ema:
                mprint('ema apply')
                ema.shadow_apply(model)
            forward_loss(model, criterion, input, target, meters)
            if ema:
                mprint('ema recover')
                ema.weight_recover(model)

    val_top1 = None
    if is_master():
        results = flush_scalar_meters(meters)
        mprint('{:.1f}s\t{}\t{}/{}: '.format(
            time.time() - t_start, phase, epoch, FLAGS.num_epochs)
               + ', '.join('{}: {}'.format(k, v) for k, v in results.items()))
        val_top1 = results['top1_error']
    return val_top1
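# Minimal sketch (assumption, not the repository's class) of the shadow-EMA
# interface the loop above relies on: `shadow_update` accumulates an exponential
# moving average of trainable parameters, `shadow_apply` swaps the averaged
# weights in for evaluation, and `weight_recover` restores the originals.
class ShadowEMASketch:
    def __init__(self, decay=0.999):
        self.decay = decay
        self.shadow = {}   # EMA copies of parameters
        self.backup = {}   # originals stashed while shadow weights are applied

    def shadow_update(self, model):
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if name not in self.shadow:
                self.shadow[name] = param.data.clone()
            else:
                # shadow = decay * shadow + (1 - decay) * param
                self.shadow[name].mul_(self.decay).add_(param.data,
                                                        alpha=1 - self.decay)

    def shadow_apply(self, model):
        for name, param in model.named_parameters():
            if name in self.shadow:
                self.backup[name] = param.data.clone()
                param.data.copy_(self.shadow[name])

    def weight_recover(self, model):
        for name, param in model.named_parameters():
            if name in self.backup:
                param.data.copy_(self.backup[name])
        self.backup = {}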
def run_one_epoch(epoch, loader, model, criterion, optimizer, meters,
                  phase='train', soft_criterion=None):
    """run one epoch for train/val/test/cal"""
    t_start = time.time()
    assert phase in ['train', 'val', 'test', 'cal'], 'Invalid phase.'
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()
        if phase == 'cal':
            model.apply(bn_calibration_init)

    # change learning rate in each iteration
    if getattr(FLAGS, 'universally_slimmable_training', False):
        max_width = FLAGS.width_mult_range[1]
        min_width = FLAGS.width_mult_range[0]
    elif getattr(FLAGS, 'slimmable_training', False):
        max_width = max(FLAGS.width_mult_list)
        min_width = min(FLAGS.width_mult_list)

    if getattr(FLAGS, 'distributed', False):
        loader.sampler.set_epoch(epoch)

    for batch_idx, (input, target) in enumerate(loader):
        if phase == 'cal':
            if batch_idx == getattr(FLAGS, 'bn_cal_batch_num', -1):
                break
        target = target.cuda(non_blocking=True)
        if train:
            # change learning rate if necessary
            lr_schedule_per_iteration(optimizer, epoch, batch_idx)
            optimizer.zero_grad()
            if getattr(FLAGS, 'slimmable_training', False):
                if getattr(FLAGS, 'universally_slimmable_training', False):
                    # universally slimmable model (us-nets)
                    widths_train = []
                    for _ in range(getattr(FLAGS, 'num_sample_training', 2) - 2):
                        widths_train.append(
                            random.uniform(min_width, max_width))
                    widths_train = [max_width, min_width] + widths_train
                    for width_mult in widths_train:
                        # the sandwich rule
                        if width_mult in [max_width, min_width]:
                            model.apply(
                                lambda m: setattr(m, 'width_mult', width_mult))
                        elif getattr(FLAGS, 'nonuniform', False):
                            model.apply(lambda m: setattr(
                                m, 'width_mult',
                                lambda: random.uniform(min_width, max_width)))
                        else:
                            model.apply(lambda m: setattr(
                                m, 'width_mult', width_mult))

                        # always track largest model and smallest model
                        if is_master() and width_mult in [max_width, min_width]:
                            meter = meters[str(width_mult)]
                        else:
                            meter = None

                        # inplace distillation
                        if width_mult == max_width:
                            loss, soft_target = forward_loss(
                                model, criterion, input, target, meter,
                                return_soft_target=True)
                        else:
                            if getattr(FLAGS, 'inplace_distill', False):
                                loss = forward_loss(
                                    model, criterion, input, target, meter,
                                    soft_target=soft_target.detach(),
                                    soft_criterion=soft_criterion)
                            else:
                                loss = forward_loss(
                                    model, criterion, input, target, meter)
                        loss.backward()
                else:
                    # slimmable model (s-nets)
                    for width_mult in sorted(FLAGS.width_mult_list,
                                             reverse=True):
                        model.apply(
                            lambda m: setattr(m, 'width_mult', width_mult))
                        if is_master():
                            meter = meters[str(width_mult)]
                        else:
                            meter = None
                        if width_mult == max_width:
                            loss, soft_target = forward_loss(
                                model, criterion, input, target, meter,
                                return_soft_target=True)
                        else:
                            if getattr(FLAGS, 'inplace_distill', False):
                                loss = forward_loss(
                                    model, criterion, input, target, meter,
                                    soft_target=soft_target.detach(),
                                    soft_criterion=soft_criterion)
                            else:
                                loss = forward_loss(
                                    model, criterion, input, target, meter)
                        loss.backward()
            else:
                loss = forward_loss(model, criterion, input, target, meters)
                loss.backward()
            if (getattr(FLAGS, 'distributed', False)
                    and getattr(FLAGS, 'distributed_all_reduce', False)):
                allreduce_grads(model)
            optimizer.step()
            if is_master() and getattr(FLAGS, 'slimmable_training', False):
                for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
                    meter = meters[str(width_mult)]
                    meter['lr'].cache(optimizer.param_groups[0]['lr'])
            elif is_master():
                meters['lr'].cache(optimizer.param_groups[0]['lr'])
            else:
                pass
        else:
            if getattr(FLAGS, 'slimmable_training', False):
                for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
                    model.apply(
                        lambda m: setattr(m, 'width_mult', width_mult))
                    if is_master():
                        meter = meters[str(width_mult)]
                    else:
                        meter = None
                    forward_loss(model, criterion, input, target, meter)
            else:
                forward_loss(model, criterion, input, target, meters)

    if is_master() and getattr(FLAGS, 'slimmable_training', False):
        for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
            results = flush_scalar_meters(meters[str(width_mult)])
            print('{:.1f}s\t{}\t{}\t{}/{}: '.format(
                time.time() - t_start, phase, str(width_mult), epoch,
                FLAGS.num_epochs) + ', '.join(
                    '{}: {:.3f}'.format(k, v) for k, v in results.items()))
    elif is_master():
        results = flush_scalar_meters(meters)
        print('{:.1f}s\t{}\t{}/{}: '.format(
            time.time() - t_start, phase, epoch, FLAGS.num_epochs)
              + ', '.join('{}: {:.3f}'.format(k, v)
                          for k, v in results.items()))
    else:
        results = None
    return results
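# Sketch (an assumption, not the repository's `bn_calibration_init`) of what a
# BatchNorm recalibration hook of this kind typically does for the 'cal' phase:
# reset the running statistics so the calibration pass re-estimates them for the
# currently active width/bit-width configuration.
import torch.nn as nn


def bn_calibration_init_sketch(m):
    if isinstance(m, nn.modules.batchnorm._BatchNorm) and m.track_running_stats:
        m.reset_running_stats()   # clear running mean/var
        m.training = True         # let the calibration forward passes update stats
        m.momentum = None         # use a cumulative moving average (assumption)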
def run_one_epoch(epoch, loader, model, criterion, optimizer, meters,
                  phase='train', ema=None, scheduler=None, eta=None,
                  epoch_dict=None, single_sample=False):
    """Run one epoch for train/val/test/cal."""
    t_start = time.time()
    assert phase in ['train', 'val', 'test', 'cal'], "phase must be one of train/val/test/cal."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()
    # if getattr(FLAGS, 'bn_calib', False) and phase == 'val' and epoch < FLAGS.num_epochs - 10:
    #     model.apply(bn_calibration)
    # if getattr(FLAGS, 'bn_calib_stoch_valid', False):
    #     model.apply(bn_calibration)
    if phase == 'cal':
        model.apply(bn_calibration)

    if getattr(FLAGS, 'distributed', False):
        loader.sampler.set_epoch(epoch)

    scale_dict = {}
    if getattr(FLAGS, 'switch_lr', False):
        scale_dict = {
            32: 1.0, 16: 1.0, 8: 1.0, 6: 1.0, 5: 1.0,
            4: 1.02, 3: 1.08, 2: 1.62, 1: 4.83
        }

    for batch_idx, (input, target) in enumerate(loader):
        if phase == 'cal':
            if batch_idx == getattr(FLAGS, 'bn_cal_batch_num', -1):
                break
        target = target.cuda(non_blocking=True)
        if train:
            if FLAGS.lr_scheduler == 'linear_decaying':
                linear_decaying_per_step = (
                    FLAGS.lr / FLAGS.num_epochs / len(loader.dataset) * FLAGS.batch_size)
                for param_group in optimizer.param_groups:
                    param_group['lr'] -= linear_decaying_per_step
            # For PyTorch 1.1+, comment out the following two lines
            # if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter']:
            #     scheduler.step()
            optimizer.zero_grad()
            if getattr(FLAGS, 'quantizable_training', False) and not single_sample:
                for bits_idx, bits in enumerate(FLAGS.bits_list):
                    model.apply(lambda m: setattr(m, 'bits', bits))
                    if is_master():
                        meter = meters[str(bits)]
                    else:
                        meter = None
                    loss = forward_loss(model, criterion, input, target, meter)
                    if eta is not None:
                        # if isinstance(bits, (list, tuple)):
                        #     bitw = bits[0]
                        # else:
                        #     bitw = bits
                        # loss *= eta(bitw)
                        loss *= eta(_pair(bits)[0])
                    if getattr(FLAGS, 'switch_lr', False):
                        # mprint(scale_dict[_pair(bits)[0]])
                        loss *= scale_dict[_pair(bits)[0]]
                    if epoch_dict is None:
                        loss.backward()
                    else:
                        epoch_valid = epoch_dict[_pair(bits)[0]]
                        if isinstance(epoch_valid, (list, tuple)):
                            epoch_start, epoch_end = epoch_valid
                        else:
                            epoch_start = epoch_valid
                            epoch_end = 1.0
                        epoch_start = int(FLAGS.num_epochs * epoch_start)
                        epoch_end = int(FLAGS.num_epochs * epoch_end)
                        if epoch_start <= epoch and epoch < epoch_end:
                            loss.backward()
                    if getattr(FLAGS, 'print_grad_std', False):
                        mprint(f'bits: {bits}')
                        layer_idx = 0
                        grad_std_list = []
                        for m in model.modules():
                            # if getattr(m, 'weight', None) is not None:
                            if isinstance(m, (QuantizableConv2d, QuantizableLinear)):
                                grad_std = torch.std(m.weight.grad)
                                mprint(f'layer_{layer_idx} grad: {grad_std}')  # , module: {m}')
                                grad_std_list.append(grad_std)
                                layer_idx += 1
                        mprint(f'average grad std: {torch.mean(torch.tensor(grad_std_list))}')
            else:
                loss = forward_loss(model, criterion, input, target, meters)
                loss.backward()
            if (getattr(FLAGS, 'distributed', False)
                    and getattr(FLAGS, 'distributed_all_reduce', False)):
                allreduce_grads(model)
            optimizer.step()
            # For PyTorch 1.0 or earlier, comment out the following two lines
            if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter',
                                      'cos_annealing_iter', 'butterworth_iter',
                                      'mixed_iter']:
                scheduler.step()
            if ema:
                ema.shadow_update(model)
                # for name, param in model.named_parameters():
                #     if param.requires_grad:
                #         ema.update(name, param.data)
                # bn_idx = 0
                # for m in model.modules():
                #     if isinstance(m, nn.BatchNorm2d):
                #         ema.update('bn{}_mean'.format(bn_idx), m.running_mean)
                #         ema.update('bn{}_var'.format(bn_idx), m.running_var)
                #         bn_idx += 1
        else:  # not train
            if ema:
                ema.shadow_apply(model)
            if getattr(FLAGS, 'quantizable_training', False) and not single_sample:
                for bits_idx, bits in enumerate(FLAGS.bits_list):
                    model.apply(lambda m: setattr(m, 'bits', bits))
                    # model.apply(
                    #     lambda m: setattr(m, 'threshold', FLAGS.schmitt_threshold * (0.0 * (epoch <= 30) + 0.01 * (30 < epoch <= 60) + 0.1 * (60 < epoch <= 90) + 1.0 * (90 < epoch))))
                    # model.apply(
                    #     lambda m: setattr(m, 'threshold', epoch * FLAGS.schmitt_threshold / FLAGS.num_epochs))
                    if is_master():
                        meter = meters[str(bits)]
                    else:
                        meter = None
                    forward_loss(model, criterion, input, target, meter)
            else:
                forward_loss(model, criterion, input, target, meters)
            if ema:
                ema.weight_recover(model)

    # opt_loss = float('inf')
    # opt_results = None
    val_top1 = None
    if is_master():
        if getattr(FLAGS, 'quantizable_training', False) and not single_sample:
            # results_dict = {}
            val_top1_list = []
            for bits in FLAGS.bits_list:
                results = flush_scalar_meters(meters[str(bits)])
                mprint('{:.1f}s\t{}\t{} bits\t{}/{}: '.format(
                    time.time() - t_start, phase, bits, epoch, FLAGS.num_epochs)
                       + ', '.join('{}: {}'.format(k, v)
                                   for k, v in results.items()))
                # results_dict[str(bits)] = results
                # if results['loss'] < opt_loss:
                #     opt_results = results
                #     opt_loss = results['loss']
                val_top1_list.append(results['top1_error'])
            # results = results_dict
            val_top1 = np.mean(val_top1_list)
        else:
            results = flush_scalar_meters(meters)
            mprint('{:.1f}s\t{}\t{}/{}: '.format(
                time.time() - t_start, phase, epoch, FLAGS.num_epochs)
                   + ', '.join('{}: {}'.format(k, v)
                               for k, v in results.items()))
            # if results['loss'] < opt_loss:
            #     opt_results = results
            #     opt_loss = results['loss']
            val_top1 = results['top1_error']
    # return opt_results
    # return results
    return val_top1
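# Illustration (not from the source) of the `_pair` helper used above, assumed to
# be torch.nn.modules.utils._pair: a single bit-width is expanded to a 2-tuple,
# while an explicit (weight_bits, activation_bits) pair passes through unchanged,
# so `_pair(bits)[0]` always selects the weight bit-width.
from torch.nn.modules.utils import _pair


def _pair_bits_demo():
    assert _pair(4) == (4, 4)        # single int -> same bits for weights and activations
    assert _pair((4, 8)) == (4, 8)   # explicit (weight_bits, activation_bits) kept as-is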