Example #1
def run_one_epoch(epoch,
                  loader,
                  model,
                  criterion,
                  optimizer,
                  lr_scheduler,
                  ema,
                  meters,
                  max_iter=None,
                  phase='train'):
    """Run one epoch."""
    assert phase in ['train', 'val', 'test', 'bn_calibration'], \
        "phase must be in train/val/test/bn_calibration."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()
    if phase == 'bn_calibration':
        model.apply(bn_calibration)

    if FLAGS.use_distributed:
        loader.sampler.set_epoch(epoch)

    results = None
    data_iterator = iter(loader)
    if FLAGS.use_distributed:
        data_fetcher = dataflow.DataPrefetcher(data_iterator)
    else:
        # TODO(meijieru): prefetch for non distributed
        logging.warning('Not using prefetcher')
        data_fetcher = data_iterator
    for batch_idx, (input, target) in enumerate(data_fetcher):
        # used for bn calibration
        if max_iter is not None:
            assert phase == 'bn_calibration'
            if batch_idx >= max_iter:
                break

        target = target.cuda(non_blocking=True)
        if train:
            optimizer.zero_grad()
            loss = mc.forward_loss(model, criterion, input, target, meters)
            loss_l2 = optim.cal_l2_loss(model, FLAGS.weight_decay,
                                        FLAGS.weight_decay_method)
            loss = loss + loss_l2
            loss.backward()
            if FLAGS.use_distributed:
                udist.allreduce_grads(model)

            if FLAGS._global_step % FLAGS.log_interval == 0:
                results = mc.reduce_and_flush_meters(meters)
                if udist.is_master():
                    logging.info('Epoch {}/{} Iter {}/{} {}: '.format(
                        epoch, FLAGS.num_epochs, batch_idx, len(loader), phase)
                                 + ', '.join('{}: {:.4f}'.format(k, v)
                                             for k, v in results.items()))
                    for k, v in results.items():
                        mc.summary_writer.add_scalar('{}/{}'.format(phase, k),
                                                     v, FLAGS._global_step)
            if (udist.is_master()
                    and FLAGS._global_step % FLAGS.log_interval == 0):
                mc.summary_writer.add_scalar('train/learning_rate',
                                             optimizer.param_groups[0]['lr'],
                                             FLAGS._global_step)
                mc.summary_writer.add_scalar('train/l2_regularize_loss',
                                             extract_item(loss_l2),
                                             FLAGS._global_step)
                mc.summary_writer.add_scalar(
                    'train/current_epoch',
                    FLAGS._global_step / FLAGS._steps_per_epoch,
                    FLAGS._global_step)
                if FLAGS.data_loader_workers > 0:
                    mc.summary_writer.add_scalar(
                        'data/train/prefetch_size',
                        get_data_queue_size(data_iterator), FLAGS._global_step)

            optimizer.step()
            lr_scheduler.step()
            if FLAGS.use_distributed and FLAGS.allreduce_bn:
                udist.allreduce_bn(model)
            FLAGS._global_step += 1

            # NOTE: after step count update
            if ema is not None:
                model_unwrap = mc.unwrap_model(model)
                ema_names = ema.average_names()
                params = get_params_by_name(model_unwrap, ema_names)
                for name, param in zip(ema_names, params):
                    ema(name, param, FLAGS._global_step)
        else:
            mc.forward_loss(model, criterion, input, target, meters)

    if not train:
        results = mc.reduce_and_flush_meters(meters)
        if udist.is_master():
            logging.info(
                'Epoch {}/{} {}: '.format(epoch, FLAGS.num_epochs, phase) +
                ', '.join('{}: {:.4f}'.format(k, v)
                          for k, v in results.items()))
            for k, v in results.items():
                mc.summary_writer.add_scalar('{}/{}'.format(phase, k), v,
                                             FLAGS._global_step)
    return results
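
For context, a caller usually wraps run_one_epoch in an outer loop over epochs, alternating a training pass with a validation pass. The sketch below is illustrative only: the helper name, the meter objects, and the 'top1_error' key are assumptions, not part of Example #1.

# Illustrative driver loop (hypothetical names; only run_one_epoch comes from the example above).
def train_all_epochs(train_loader, val_loader, model, criterion, optimizer,
                     lr_scheduler, ema, train_meters, val_meters):
    best_top1 = None
    for epoch in range(FLAGS.num_epochs):
        # One pass over the training set.
        run_one_epoch(epoch, train_loader, model, criterion, optimizer,
                      lr_scheduler, ema, train_meters, phase='train')
        # One pass over the validation set; returns the flushed meter results.
        results = run_one_epoch(epoch, val_loader, model, criterion, optimizer,
                                lr_scheduler, ema, val_meters, phase='val')
        if results is not None:
            top1 = results.get('top1_error')  # assumed meter key
            if top1 is not None and (best_top1 is None or top1 < best_top1):
                best_top1 = top1
    return best_top1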
Example #2
def run_one_epoch(epoch,
                  loader,
                  model,
                  criterion,
                  optimizer,
                  lr_scheduler,
                  ema,
                  rho_scheduler,
                  meters,
                  max_iter=None,
                  phase='train'):
    """Run one epoch."""
    assert phase in ['train', 'val', 'test', 'bn_calibration'] or \
        phase.startswith('prune'), \
        "phase must be in train/val/test/bn_calibration/prune."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()
    if phase == 'bn_calibration':
        model.apply(bn_calibration)

    if not FLAGS.use_hdfs:
        if FLAGS.use_distributed:
            loader.sampler.set_epoch(epoch)

    results = None
    data_iterator = iter(loader)
    if not FLAGS.use_hdfs:
        if FLAGS.use_distributed:
            if FLAGS.dataset == 'coco':
                data_fetcher = dataflow.DataPrefetcherKeypoint(data_iterator)
            else:
                data_fetcher = dataflow.DataPrefetcher(data_iterator)
        else:
            logging.warning('Not using prefetcher')
            data_fetcher = data_iterator
    else:
        # Fall back to the plain iterator so `data_fetcher` is always defined.
        data_fetcher = data_iterator

    for batch_idx, data in enumerate(data_fetcher):
        if FLAGS.dataset == 'coco':
            input, target, target_weight, meta = data
            # print(input.shape, target.shape, target_weight.shape, meta)
            # (4, 3, 384, 288), (4, 17, 96, 72), (4, 17, 1),
        else:
            input, target = data
        # if batch_idx > 400:
        #     break
        # used for bn calibration
        if max_iter is not None:
            assert phase == 'bn_calibration'
            if batch_idx >= max_iter:
                break

        target = target.cuda(non_blocking=True)
        if train:
            optimizer.zero_grad()
            rho = rho_scheduler(FLAGS._global_step)

            if FLAGS.dataset == 'coco':
                outputs = model(input)
                if isinstance(outputs, list):
                    loss = criterion(outputs[0], target, target_weight)
                    for output in outputs[1:]:
                        loss += criterion(output, target, target_weight)
                else:
                    output = outputs
                    loss = criterion(output, target, target_weight)
                _, avg_acc, cnt, pred = accuracy_keypoint(
                    output.detach().cpu().numpy(),
                    target.detach().cpu().numpy())  # cnt=17
                meters['acc'].cache(avg_acc)
                meters['loss'].cache(loss)
            else:
                loss = mc.forward_loss(model,
                                       criterion,
                                       input,
                                       target,
                                       meters,
                                       task=FLAGS.model_kwparams.task,
                                       distill=FLAGS.distill)
            if FLAGS.prune_params['method'] is not None:
                loss_l2 = optim.cal_l2_loss(
                    model, FLAGS.weight_decay,
                    FLAGS.weight_decay_method)  # manual weight decay
                loss_bn_l1 = prune.cal_bn_l1_loss(get_prune_weights(model),
                                                  FLAGS._bn_to_prune.penalty,
                                                  rho)
                if FLAGS.prune_params.use_transformer:
                    transformer_weights = get_prune_weights(model, True)
                    loss_bn_l1 += prune.cal_bn_l1_loss(
                        transformer_weights,
                        FLAGS._bn_to_prune_transformer.penalty, rho)

                    transformer_dict = []
                    for name, weight in zip(
                            FLAGS._bn_to_prune_transformer.weight,
                            transformer_weights):
                        transformer_dict.append(
                            sum(weight > FLAGS.model_shrink_threshold).item())
                    FLAGS._bn_to_prune_transformer.add_info_list(
                        'channels', transformer_dict)
                    FLAGS._bn_to_prune_transformer.update_penalty()
                    if (udist.is_master()
                            and FLAGS._global_step % FLAGS.log_interval == 0):
                        logging.info(transformer_dict)
                        # logging.info(FLAGS._bn_to_prune_transformer.penalty)

                meters['loss_l2'].cache(loss_l2)
                meters['loss_bn_l1'].cache(loss_bn_l1)
                loss = loss + loss_l2 + loss_bn_l1
            loss.backward()
            if FLAGS.use_distributed:
                udist.allreduce_grads(model)

            if FLAGS._global_step % FLAGS.log_interval == 0:
                results = mc.reduce_and_flush_meters(meters)
                if udist.is_master():
                    logging.info('Epoch {}/{} Iter {}/{} Lr: {} {}: '.format(
                        epoch, FLAGS.num_epochs, batch_idx, len(loader),
                        optimizer.param_groups[0]["lr"], phase) +
                                 ', '.join('{}: {:.4f}'.format(k, v)
                                           for k, v in results.items()))
                    for k, v in results.items():
                        mc.summary_writer.add_scalar('{}/{}'.format(phase, k),
                                                     v, FLAGS._global_step)

            if (udist.is_master()
                    and FLAGS._global_step % FLAGS.log_interval == 0):
                mc.summary_writer.add_scalar('train/learning_rate',
                                             optimizer.param_groups[0]['lr'],
                                             FLAGS._global_step)
                if FLAGS.prune_params['method'] is not None:
                    mc.summary_writer.add_scalar('train/l2_regularize_loss',
                                                 extract_item(loss_l2),
                                                 FLAGS._global_step)
                    mc.summary_writer.add_scalar('train/bn_l1_loss',
                                                 extract_item(loss_bn_l1),
                                                 FLAGS._global_step)
                mc.summary_writer.add_scalar('prune/rho', rho,
                                             FLAGS._global_step)
                mc.summary_writer.add_scalar(
                    'train/current_epoch',
                    FLAGS._global_step / FLAGS._steps_per_epoch,
                    FLAGS._global_step)
                if FLAGS.data_loader_workers > 0:
                    mc.summary_writer.add_scalar(
                        'data/train/prefetch_size',
                        get_data_queue_size(data_iterator), FLAGS._global_step)

            if (udist.is_master()
                    and FLAGS._global_step % FLAGS.log_interval_detail == 0):
                summary_bn(model, 'train')

            optimizer.step()
            if FLAGS.lr_scheduler == 'poly':
                optim.poly_learning_rate(
                    optimizer, FLAGS.lr,
                    epoch * FLAGS._steps_per_epoch + batch_idx + 1,
                    FLAGS.num_epochs * FLAGS._steps_per_epoch)
            else:
                lr_scheduler.step()
            if FLAGS.use_distributed and FLAGS.allreduce_bn:
                udist.allreduce_bn(model)
            FLAGS._global_step += 1

            # NOTE: after step count update
            if ema is not None:
                model_unwrap = mc.unwrap_model(model)
                ema_names = ema.average_names()
                params = get_params_by_name(model_unwrap, ema_names)
                for name, param in zip(ema_names, params):
                    ema(name, param, FLAGS._global_step)
        else:
            if FLAGS.dataset == 'coco':
                outputs = model(input)
                if isinstance(outputs, list):
                    loss = criterion(outputs[0], target, target_weight)
                    for output in outputs[1:]:
                        loss += criterion(output, target, target_weight)
                else:
                    output = outputs
                    loss = criterion(output, target, target_weight)
                _, avg_acc, cnt, pred = accuracy_keypoint(
                    output.detach().cpu().numpy(),
                    target.detach().cpu().numpy())  # cnt=17
                meters['acc'].cache(avg_acc)
                meters['loss'].cache(loss)
            else:
                mc.forward_loss(model,
                                criterion,
                                input,
                                target,
                                meters,
                                task=FLAGS.model_kwparams.task,
                                distill=False)

    if not train:
        results = mc.reduce_and_flush_meters(meters)
        if udist.is_master():
            logging.info(
                'Epoch {}/{} {}: '.format(epoch, FLAGS.num_epochs, phase) +
                ', '.join('{}: {:.4f}'.format(k, v)
                          for k, v in results.items()))
            for k, v in results.items():
                mc.summary_writer.add_scalar('{}/{}'.format(phase, k), v,
                                             FLAGS._global_step)
    return results
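
The pruning branch above attaches an L1 penalty to BatchNorm scale factors, weighted by a per-layer penalty vector and the scheduled rho. The repository's prune.cal_bn_l1_loss is not shown here; the following is only a minimal sketch of what such a penalty typically computes, inferred from the call site.

import torch

def bn_l1_penalty(bn_weights, penalty, rho):
    # Per-tensor L1 norm of BatchNorm scale factors, weighted by a per-tensor
    # penalty coefficient and the scheduled rho (illustrative only).
    total = bn_weights[0].new_zeros(())
    for gamma, pen in zip(bn_weights, penalty):
        total = total + pen * gamma.abs().sum()
    return rho * total

# Example with two fake BN gamma vectors and per-layer penalties.
gammas = [torch.randn(16), torch.randn(32)]
print(bn_l1_penalty(gammas, penalty=[1.0, 0.5], rho=1e-4))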
Example #3
def run_one_epoch(
        epoch, loader, model, criterion, optimizer, meters, phase='train', scheduler=None):
    """run one epoch for train/val/test"""
    t_start = time.time()
    assert phase in ['train', 'val', 'test'], "phase must be in train/val/test."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()

    if getattr(FLAGS, 'distributed', False):
        loader.sampler.set_epoch(epoch)

    for batch_idx, (input, target) in enumerate(loader):
        target = target.cuda(non_blocking=True)
        if train:
            if FLAGS.lr_scheduler == 'linear_decaying':
                linear_decaying_per_step = (
                    FLAGS.lr/FLAGS.num_epochs/len(loader.dataset)*FLAGS.batch_size)
                for param_group in optimizer.param_groups:
                    param_group['lr'] -= linear_decaying_per_step
            # For PyTorch 1.1+, comment out the following two lines
            #if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter', 'multistep_iter']:
            #    scheduler.step()
            optimizer.zero_grad()
            if getattr(FLAGS, 'adaptive_training', False):
                for bits_idx, bits in enumerate(FLAGS.bits_list):
                    model.apply(
                        lambda m: setattr(m, 'bits', bits))
                    if is_master():
                        meter = meters[str(bits)]
                    else:
                        meter = None
                    loss = forward_loss(
                        model, criterion, input, target, meter)
                    loss.backward()
            else:
                loss = forward_loss(
                    model, criterion, input, target, meters)
                loss.backward()
            if getattr(FLAGS, 'distributed', False) and getattr(FLAGS, 'distributed_all_reduce', False):
                allreduce_grads(model)
            optimizer.step()
            # For PyTorch 1.0 or earlier, comment the following two lines
            if FLAGS.lr_scheduler in ['exp_decaying_iter', 'cos_annealing_iter', 'multistep_iter']:
                scheduler.step()
        else:  # not train
            if getattr(FLAGS, 'adaptive_training', False):
                for bits_idx, bits in enumerate(FLAGS.bits_list):
                    model.apply(
                        lambda m: setattr(m, 'bits', bits))
                    if is_master() and meters is not None:
                        meter = meters[str(bits)]
                    else:
                        meter = None
                    forward_loss(
                        model, criterion, input, target, meter)
            else:
                forward_loss(model, criterion, input, target, meters)
                
    val_top1 = None
    if is_master() and meters is not None:
        if getattr(FLAGS, 'adaptive_training', False):
            val_top1_list = []
            for bits in FLAGS.bits_list:
                results = flush_scalar_meters(meters[str(bits)])
                mprint('{:.1f}s\t{}\t{} bits\t{}/{}: '.format(
                    time.time() - t_start, phase, bits, epoch,
                    FLAGS.num_epochs) + ', '.join('{}: {}'.format(k, v)
                                                  for k, v in results.items()))
                val_top1_list.append(results['top1_error'])
            val_top1 = np.mean(val_top1_list)
        else:
            results = flush_scalar_meters(meters)
            mprint('{:.1f}s\t{}\t{}/{}: '.format(
                time.time() - t_start, phase, epoch, FLAGS.num_epochs) +
                  ', '.join('{}: {}'.format(k, v) for k, v in results.items()))
            val_top1 = results['top1_error']
    return val_top1
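
A side note on the 'linear_decaying' branch above: the per-step decrement lr / num_epochs / len(loader.dataset) * batch_size is just the initial learning rate divided by the total number of optimizer steps, so the rate decays linearly to roughly zero by the end of training. A quick check with made-up numbers:

# Illustrative numbers only (not taken from the example above).
lr, num_epochs, dataset_size, batch_size = 0.1, 100, 1_281_167, 256
total_steps = num_epochs * dataset_size / batch_size     # ~500,456 optimizer steps
per_step = lr / num_epochs / dataset_size * batch_size   # the decrement used in the code
assert abs(per_step - lr / total_steps) < 1e-15
print(per_step)  # ~2.0e-7; the learning rate reaches ~0 after the last step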
Example #4
def run_one_epoch(
        epoch, loader, model, criterion, optimizer, meters, phase='train', ema=None, scheduler=None):
    """run one epoch for train/val/test/cal"""
    t_start = time.time()
    assert phase in ['train', 'val', 'test', 'cal'], "phase must be in train/val/test/cal."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()

    if getattr(FLAGS, 'distributed', False):
        loader.sampler.set_epoch(epoch)

    for batch_idx, (input, target) in enumerate(loader):
        if phase == 'cal':
            if batch_idx == getattr(FLAGS, 'bn_cal_batch_num', -1):
                break
        target = target.cuda(non_blocking=True)
        if train:
            if FLAGS.lr_scheduler == 'linear_decaying':
                linear_decaying_per_step = (
                    FLAGS.lr/FLAGS.num_epochs/len(loader.dataset)*FLAGS.batch_size)
                for param_group in optimizer.param_groups:
                    param_group['lr'] -= linear_decaying_per_step
            # For PyTorch 1.1+, comment out the following two lines
            #if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter']:
            #    scheduler.step()
            optimizer.zero_grad()
            loss = forward_loss(
                model, criterion, input, target, meters)
            if (epoch >= FLAGS.warmup_epochs
                    and not getattr(FLAGS, 'hard_assignment', False)):
                if getattr(FLAGS, 'weight_only', False):
                    loss += getattr(FLAGS, 'kappa', 1.0) * get_model_size_loss(model)
                else:
                    loss += getattr(FLAGS, 'kappa', 1.0) * get_comp_cost_loss(model)
            loss.backward()
            if getattr(FLAGS, 'distributed', False) and getattr(FLAGS, 'distributed_all_reduce', False):
                allreduce_grads(model)
            optimizer.step()
            # For PyTorch 1.0 or earlier, comment the following two lines
            if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter']:
                scheduler.step()
            if ema:
                ema.shadow_update(model)
                #for name, param in model.named_parameters():
                #    if param.requires_grad:
                #        ema.update(name, param.data)
                #bn_idx = 0
                #for m in model.modules():
                #    if isinstance(m, nn.BatchNorm2d):
                #        ema.update('bn{}_mean'.format(bn_idx), m.running_mean)
                #        ema.update('bn{}_var'.format(bn_idx), m.running_var)
                #        bn_idx += 1
        else:  # not train
            if ema:
                mprint('ema apply')
                ema.shadow_apply(model)
            forward_loss(model, criterion, input, target, meters)
            if ema:
                mprint('ema recover')
                ema.weight_recover(model)
    val_top1 = None
    if is_master():
        results = flush_scalar_meters(meters)
        mprint('{:.1f}s\t{}\t{}/{}: '.format(
            time.time() - t_start, phase, epoch, FLAGS.num_epochs) +
              ', '.join('{}: {}'.format(k, v) for k, v in results.items()))
        val_top1 = results['top1_error']
    return val_top1
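
The ema object used above exposes shadow_update, shadow_apply, and weight_recover, but its implementation is not part of the example. A minimal sketch of the usual shadow-EMA pattern that would be consistent with those call sites (class name, decay value, and storage layout are assumptions) is:

import torch

class SimpleEMA:
    """Exponential moving average of model parameters kept as a shadow copy.
    Illustrative only; the real ema object's internals are not shown above."""

    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {n: p.detach().clone() for n, p in model.named_parameters()}
        self.backup = {}

    def shadow_update(self, model):
        # shadow <- decay * shadow + (1 - decay) * current weights
        with torch.no_grad():
            for n, p in model.named_parameters():
                self.shadow[n].mul_(self.decay).add_(p, alpha=1.0 - self.decay)

    def shadow_apply(self, model):
        # Swap the EMA weights in for evaluation, remembering the raw weights.
        with torch.no_grad():
            for n, p in model.named_parameters():
                self.backup[n] = p.detach().clone()
                p.copy_(self.shadow[n])

    def weight_recover(self, model):
        # Restore the raw training weights after evaluation.
        with torch.no_grad():
            for n, p in model.named_parameters():
                p.copy_(self.backup[n])
        self.backup = {}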
Example #5
def run_one_epoch(
        epoch, loader, model, criterion, optimizer, meters, phase='train',
        soft_criterion=None):
    """run one epoch for train/val/test/cal"""
    t_start = time.time()
    assert phase in ['train', 'val', 'test', 'cal'], 'Invalid phase.'
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()
        if phase == 'cal':
            model.apply(bn_calibration_init)
    # change learning rate in each iteration
    if getattr(FLAGS, 'universally_slimmable_training', False):
        max_width = FLAGS.width_mult_range[1]
        min_width = FLAGS.width_mult_range[0]
    elif getattr(FLAGS, 'slimmable_training', False):
        max_width = max(FLAGS.width_mult_list)
        min_width = min(FLAGS.width_mult_list)

    if getattr(FLAGS, 'distributed', False):
        loader.sampler.set_epoch(epoch)
    for batch_idx, (input, target) in enumerate(loader):
        if phase == 'cal':
            if batch_idx == getattr(FLAGS, 'bn_cal_batch_num', -1):
                break
        target = target.cuda(non_blocking=True)
        if train:
            # change learning rate if necessary
            lr_schedule_per_iteration(optimizer, epoch, batch_idx)
            optimizer.zero_grad()
            if getattr(FLAGS, 'slimmable_training', False):
                if getattr(FLAGS, 'universally_slimmable_training', False):
                    # universally slimmable model (us-nets)
                    widths_train = []
                    for _ in range(getattr(FLAGS, 'num_sample_training', 2)-2):
                        widths_train.append(
                            random.uniform(min_width, max_width))
                    widths_train = [max_width, min_width] + widths_train
                    for width_mult in widths_train:
                        # the sandwich rule
                        if width_mult in [max_width, min_width]:
                            model.apply(
                                lambda m: setattr(m, 'width_mult', width_mult))
                        elif getattr(FLAGS, 'nonuniform', False):
                            model.apply(lambda m: setattr(
                                m, 'width_mult',
                                lambda: random.uniform(min_width, max_width)))
                        else:
                            model.apply(lambda m: setattr(
                                m, 'width_mult',
                                width_mult))

                        # always track largest model and smallest model
                        if is_master() and width_mult in [
                                max_width, min_width]:
                            meter = meters[str(width_mult)]
                        else:
                            meter = None

                        # inplace distillation
                        if width_mult == max_width:
                            loss, soft_target = forward_loss(
                                model, criterion, input, target, meter,
                                return_soft_target=True)
                        else:
                            if getattr(FLAGS, 'inplace_distill', False):
                                loss = forward_loss(
                                    model, criterion, input, target, meter,
                                    soft_target=soft_target.detach(),
                                    soft_criterion=soft_criterion)
                            else:
                                loss = forward_loss(
                                    model, criterion, input, target, meter)
                        loss.backward()
                else:
                    # slimmable model (s-nets)
                    for width_mult in sorted(
                            FLAGS.width_mult_list, reverse=True):
                        model.apply(
                            lambda m: setattr(m, 'width_mult', width_mult))
                        if is_master():
                            meter = meters[str(width_mult)]
                        else:
                            meter = None
                        if width_mult == max_width:
                            loss, soft_target = forward_loss(
                                model, criterion, input, target, meter,
                                return_soft_target=True)
                        else:
                            if getattr(FLAGS, 'inplace_distill', False):
                                loss = forward_loss(
                                    model, criterion, input, target, meter,
                                    soft_target=soft_target.detach(),
                                    soft_criterion=soft_criterion)
                            else:
                                loss = forward_loss(
                                    model, criterion, input, target, meter)
                        loss.backward()
            else:
                loss = forward_loss(
                    model, criterion, input, target, meters)
                loss.backward()
            if (getattr(FLAGS, 'distributed', False)
                    and getattr(FLAGS, 'distributed_all_reduce', False)):
                allreduce_grads(model)
            optimizer.step()
            if is_master() and getattr(FLAGS, 'slimmable_training', False):
                for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
                    meter = meters[str(width_mult)]
                    meter['lr'].cache(optimizer.param_groups[0]['lr'])
            elif is_master():
                meters['lr'].cache(optimizer.param_groups[0]['lr'])
            else:
                pass
        else:
            if getattr(FLAGS, 'slimmable_training', False):
                for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
                    model.apply(
                        lambda m: setattr(m, 'width_mult', width_mult))
                    if is_master():
                        meter = meters[str(width_mult)]
                    else:
                        meter = None
                    forward_loss(model, criterion, input, target, meter)
            else:
                forward_loss(model, criterion, input, target, meters)
    if is_master() and getattr(FLAGS, 'slimmable_training', False):
        for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
            results = flush_scalar_meters(meters[str(width_mult)])
            print('{:.1f}s\t{}\t{}\t{}/{}: '.format(
                time.time() - t_start, phase, str(width_mult), epoch,
                FLAGS.num_epochs) + ', '.join(
                    '{}: {:.3f}'.format(k, v) for k, v in results.items()))
    elif is_master():
        results = flush_scalar_meters(meters)
        print(
            '{:.1f}s\t{}\t{}/{}: '.format(
                time.time() - t_start, phase, epoch, FLAGS.num_epochs) +
            ', '.join('{}: {:.3f}'.format(k, v) for k, v in results.items()))
    else:
        results = None
    return results
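
The universally slimmable branch above implements the sandwich rule: every step trains the largest width, the smallest width, and a few widths sampled uniformly in between. Isolating just the width sampling (function name and default count are assumptions) gives:

import random

def sample_widths(min_width, max_width, num_sample=4):
    # Sandwich rule: always include the largest and smallest width,
    # plus (num_sample - 2) widths drawn uniformly in between.
    extra = [random.uniform(min_width, max_width) for _ in range(num_sample - 2)]
    return [max_width, min_width] + extra

print(sample_widths(0.25, 1.0))  # e.g. [1.0, 0.25, 0.61..., 0.43...]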
Example #6
def run_one_epoch(epoch,
                  loader,
                  model,
                  criterion,
                  optimizer,
                  meters,
                  phase='train',
                  ema=None,
                  scheduler=None,
                  eta=None,
                  epoch_dict=None,
                  single_sample=False):
    """run one epoch for train/val/test/cal"""
    t_start = time.time()
    assert phase in ['train', 'val', 'test', 'cal'], \
        "phase must be in train/val/test/cal."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()
        #if getattr(FLAGS, 'bn_calib', False) and phase == 'val' and epoch < FLAGS.num_epochs - 10:
        #    model.apply(bn_calibration)
        #if getattr(FLAGS, 'bn_calib_stoch_valid', False):
        #    model.apply(bn_calibration)
        if phase == 'cal':
            model.apply(bn_calibration)

    if getattr(FLAGS, 'distributed', False):
        loader.sampler.set_epoch(epoch)

    scale_dict = {}
    if getattr(FLAGS, 'switch_lr', False):
        scale_dict = {
            32: 1.0,
            16: 1.0,
            8: 1.0,
            6: 1.0,
            5: 1.0,
            4: 1.02,
            3: 1.08,
            2: 1.62,
            1: 4.83
        }

    for batch_idx, (input, target) in enumerate(loader):
        if phase == 'cal':
            if batch_idx == getattr(FLAGS, 'bn_cal_batch_num', -1):
                break
        target = target.cuda(non_blocking=True)
        if train:
            if FLAGS.lr_scheduler == 'linear_decaying':
                linear_decaying_per_step = (FLAGS.lr / FLAGS.num_epochs /
                                            len(loader.dataset) *
                                            FLAGS.batch_size)
                for param_group in optimizer.param_groups:
                    param_group['lr'] -= linear_decaying_per_step
            # For PyTorch 1.1+, comment out the following two lines
            #if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter']:
            #    scheduler.step()
            optimizer.zero_grad()
            if getattr(FLAGS, 'quantizable_training',
                       False) and not single_sample:
                for bits_idx, bits in enumerate(FLAGS.bits_list):
                    model.apply(lambda m: setattr(m, 'bits', bits))
                    if is_master():
                        meter = meters[str(bits)]
                    else:
                        meter = None
                    loss = forward_loss(model, criterion, input, target, meter)
                    if eta is not None:
                        #if isinstance(bits, (list, tuple)):
                        #    bitw = bits[0]
                        #else:
                        #    bitw = bits
                        #loss *= eta(bitw)
                        loss *= eta(_pair(bits)[0])
                    if getattr(FLAGS, 'switch_lr', False):
                        #mprint(scale_dict[_pair(bits)[0]])
                        loss *= scale_dict[_pair(bits)[0]]
                    if epoch_dict is None:
                        loss.backward()
                    else:
                        epoch_valid = epoch_dict[_pair(bits)[0]]
                        if isinstance(epoch_valid, (list, tuple)):
                            epoch_start, epoch_end = epoch_valid
                        else:
                            epoch_start = epoch_valid
                            epoch_end = 1.0
                        epoch_start = int(FLAGS.num_epochs * epoch_start)
                        epoch_end = int(FLAGS.num_epochs * epoch_end)
                        if epoch_start <= epoch and epoch < epoch_end:
                            loss.backward()
                    if getattr(FLAGS, 'print_grad_std', False):
                        mprint(f'bits: {bits}')
                        layer_idx = 0
                        grad_std_list = []
                        for m in model.modules():
                            #if getattr(m, 'weight', None) is not None:
                            if isinstance(
                                    m, (QuantizableConv2d, QuantizableLinear)):
                                grad_std = torch.std(m.weight.grad)
                                mprint(f'layer_{layer_idx} grad: {grad_std}'
                                       )  #, module: {m}')
                                grad_std_list.append(grad_std)
                                layer_idx += 1
                        mprint(
                            f'average grad std: {torch.mean(torch.tensor(grad_std_list))}'
                        )
            else:
                loss = forward_loss(model, criterion, input, target, meters)
                loss.backward()
            if getattr(FLAGS, 'distributed', False) and getattr(
                    FLAGS, 'distributed_all_reduce', False):
                allreduce_grads(model)
            optimizer.step()
            # For PyTorch 1.0 or earlier, comment the following two lines
            if FLAGS.lr_scheduler in [
                    'exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter',
                    'butterworth_iter', 'mixed_iter'
            ]:
                scheduler.step()
            if ema:
                ema.shadow_update(model)
                #for name, param in model.named_parameters():
                #    if param.requires_grad:
                #        ema.update(name, param.data)
                #bn_idx = 0
                #for m in model.modules():
                #    if isinstance(m, nn.BatchNorm2d):
                #        ema.update('bn{}_mean'.format(bn_idx), m.running_mean)
                #        ema.update('bn{}_var'.format(bn_idx), m.running_var)
                #        bn_idx += 1
        else:  # not train
            if ema:
                ema.shadow_apply(model)
            if getattr(FLAGS, 'quantizable_training',
                       False) and not single_sample:
                for bits_idx, bits in enumerate(FLAGS.bits_list):
                    model.apply(lambda m: setattr(m, 'bits', bits))
                    #model.apply(
                    #    lambda m: setattr(m, 'threshold', FLAGS.schmitt_threshold * (0.0 * (epoch <= 30) + 0.01 * (30 < epoch <= 60) + 0.1 * (60 < epoch <= 90) + 1.0 * (90 < epoch))))
                    #model.apply(
                    #    lambda m: setattr(m, 'threshold', epoch * FLAGS.schmitt_threshold / FLAGS.num_epochs))
                    if is_master():
                        meter = meters[str(bits)]
                    else:
                        meter = None
                    forward_loss(model, criterion, input, target, meter)
            else:
                forward_loss(model, criterion, input, target, meters)
            if ema:
                ema.weight_recover(model)
    ##opt_loss = float('inf')
    ##opt_results = None
    val_top1 = None
    if is_master():
        if getattr(FLAGS, 'quantizable_training', False) and not single_sample:
            #results_dict = {}
            val_top1_list = []
            for bits in FLAGS.bits_list:
                results = flush_scalar_meters(meters[str(bits)])
                mprint('{:.1f}s\t{}\t{} bits\t{}/{}: '.format(
                    time.time() -
                    t_start, phase, bits, epoch, FLAGS.num_epochs) +
                       ', '.join('{}: {}'.format(k, v)
                                 for k, v in results.items()))
                #results_dict[str(bits)] = results
                ##if results['loss'] < opt_loss:
                ##    opt_results = results
                ##    opt_loss = results['loss']
                val_top1_list.append(results['top1_error'])
            #results = results_dict
            val_top1 = np.mean(val_top1_list)
        else:
            results = flush_scalar_meters(meters)
            mprint('{:.1f}s\t{}\t{}/{}: '.format(time.time() - t_start, phase,
                                                 epoch, FLAGS.num_epochs) +
                   ', '.join('{}: {}'.format(k, v)
                             for k, v in results.items()))
            ##if results['loss'] < opt_loss:
            ##    opt_results = results
            ##    opt_loss = results['loss']
            val_top1 = results['top1_error']
    ##return opt_results
    #return results
    return val_top1
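
A note on _pair, which Example #6 uses when scaling the loss: an entry in FLAGS.bits_list may be a single integer or a (weight, activation) pair, and _pair from torch.nn.modules.utils normalizes both forms to a 2-tuple, so _pair(bits)[0] always yields the weight bit-width (matching the commented-out bitw = bits[0] logic above):

from torch.nn.modules.utils import _pair

print(_pair(4))       # (4, 4)  -> _pair(4)[0] == 4
print(_pair((4, 8)))  # (4, 8)  -> _pair((4, 8))[0] == 4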