Example #1
def reduce_and_flush_meters(meters, method='avg'):
    """Sync and flush meters."""
    if not FLAGS.use_distributed:
        results = flush_scalar_meters(meters)
    else:
        results = {}
        assert isinstance(meters, dict), "meters should be a dict."
        # NOTE: Ensure same order, otherwise may deadlock
        for name in sorted(meters.keys()):
            meter = meters[name]
            if not isinstance(meter, ScalarMeter):
                continue
            if method == 'avg':
                method_fun = torch.mean
            elif method == 'sum':
                method_fun = torch.sum
            elif method == 'max':
                method_fun = torch.max
            elif method == 'min':
                method_fun = torch.min
            else:
                raise NotImplementedError(
                    'flush method: {} is not yet implemented.'.format(method))
            tensor = torch.tensor(meter.values).cuda()
            gather_tensors = [
                torch.ones_like(tensor) for _ in range(udist.get_world_size())
            ]
            dist.all_gather(gather_tensors, tensor)
            value = method_fun(torch.cat(gather_tensors))
            meter.flush(value)
            results[name] = value
    return results
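The distributed branch above leans on torch.distributed.all_gather: every rank contributes its meter values, and the reduction (mean/sum/max/min) runs on the concatenation. A minimal, self-contained sketch of that pattern, using a single-process gloo group on CPU instead of the multi-GPU setup the original assumes (the function name and values are illustrative only):

import os
import torch
import torch.distributed as dist

def gather_and_reduce(values, method_fun=torch.mean):
    """Gather a 1-D tensor from every rank, then reduce the concatenation."""
    tensor = torch.tensor(values, dtype=torch.float32)
    gathered = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, tensor)
    return method_fun(torch.cat(gathered))

if __name__ == '__main__':
    # Single-process group, just enough to exercise the collective locally.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)
    print(gather_and_reduce([0.3, 0.5, 0.7]))  # ~tensor(0.5)
    dist.destroy_process_group()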
Example #2
def run_one_epoch(
        epoch, loader, model, criterion, optimizer, meters, phase='train'):
    """run one epoch for train/val/test"""
    t_start = time.time()
    assert phase in ['train', 'val', 'test'], "phase must be one of train/val/test."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()
    if getattr(FLAGS, 'slimmable_sample_training', False):
        max_width = max(FLAGS.width_mult_list)
        min_width = min(FLAGS.width_mult_list)
        other_widths = FLAGS.width_mult_list.copy()
        other_widths.remove(max_width)
        other_widths.remove(min_width)
    if train and FLAGS.lr_scheduler == 'linear_decaying':
        linear_decaying_per_step = (
            FLAGS.lr / FLAGS.num_epochs /
            len(loader.dataset) * FLAGS.batch_size)
    for batch_idx, (input, target) in enumerate(loader):
        target = target.cuda(non_blocking=True)
        if train:
            if FLAGS.lr_scheduler == 'linear_decaying':
                for param_group in optimizer.param_groups:
                    param_group['lr'] -= linear_decaying_per_step
            optimizer.zero_grad()
            if getattr(FLAGS, 'slimmable_training', False):
                for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
                    model.apply(
                        lambda m: setattr(m, 'width_mult', width_mult))
                    loss = forward_loss(
                        model, criterion, input, target,
                        meters[str(width_mult)])
                    loss.backward()
            else:
                loss = forward_loss(
                    model, criterion, input, target, meters)
                loss.backward()
            optimizer.step()
        else:
            if getattr(FLAGS, 'slimmable_training', False):
                for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
                    model.apply(
                        lambda m: setattr(m, 'width_mult', width_mult))
                    forward_loss(
                        model, criterion, input, target,
                        meters[str(width_mult)])
            else:
                forward_loss(model, criterion, input, target, meters)
    if getattr(FLAGS, 'slimmable_training', False):
        for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
            results = flush_scalar_meters(meters[str(width_mult)])
            print('{:.1f}s\t{}\t{}\t{}/{}: '.format(
                time.time() - t_start, phase, str(width_mult), epoch,
                FLAGS.num_epochs) + ', '.join('{}: {:.3f}'.format(k, v)
                                              for k, v in results.items()))
    else:
        results = flush_scalar_meters(meters)
        print('{:.1f}s\t{}\t{}/{}: '.format(time.time() - t_start,
                                            phase, epoch, FLAGS.num_epochs) +
              ', '.join('{}: {:.3f}'.format(k, v) for k, v in results.items()))
    return results
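The slimmable branches above switch the active width by broadcasting a width_mult attribute to every submodule through Module.apply. A tiny sketch showing that apply reaches each layer (the layer class here is a stand-in, not the project's slimmable ops):

import torch.nn as nn

class DummySlimmableLayer(nn.Module):
    """Stand-in for a slimmable layer that would read self.width_mult in forward()."""
    def __init__(self):
        super().__init__()
        self.width_mult = 1.0

model = nn.Sequential(DummySlimmableLayer(), DummySlimmableLayer())
for width_mult in sorted([0.25, 0.5, 0.75, 1.0], reverse=True):
    # Module.apply(fn) calls fn on every submodule, so each layer sees the new width.
    model.apply(lambda m: setattr(m, 'width_mult', width_mult))
    print([m.width_mult for m in model])  # [1.0, 1.0], then [0.75, 0.75], ...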
Example #3
File: train.py Project: phuocphn/AdaBits
def run_one_epoch(
        epoch, loader, model, criterion, optimizer, meters, phase='train', scheduler=None):
    """run one epoch for train/val/test"""
    t_start = time.time()
    assert phase in ['train', 'val', 'test'], "phase must be one of train/val/test."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()

    if getattr(FLAGS, 'distributed', False):
        loader.sampler.set_epoch(epoch)

    for batch_idx, (input, target) in enumerate(loader):
        target = target.cuda(non_blocking=True)
        if train:
            if FLAGS.lr_scheduler == 'linear_decaying':
                linear_decaying_per_step = (
                    FLAGS.lr/FLAGS.num_epochs/len(loader.dataset)*FLAGS.batch_size)
                for param_group in optimizer.param_groups:
                    param_group['lr'] -= linear_decaying_per_step
            # For PyTorch 1.1+, comment out the following two lines
            #if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter', 'multistep_iter']:
            #    scheduler.step()
            optimizer.zero_grad()
            if getattr(FLAGS, 'adaptive_training', False):
                for bits_idx, bits in enumerate(FLAGS.bits_list):
                    model.apply(
                        lambda m: setattr(m, 'bits', bits))
                    if is_master():
                        meter = meters[str(bits)]
                    else:
                        meter = None
                    loss = forward_loss(
                        model, criterion, input, target, meter)
                    loss.backward()
            else:
                loss = forward_loss(
                    model, criterion, input, target, meters)
                loss.backward()
            if getattr(FLAGS, 'distributed', False) and getattr(FLAGS, 'distributed_all_reduce', False):
                allreduce_grads(model)
            optimizer.step()
            # For PyTorch 1.0 or earlier, comment the following two lines
            if FLAGS.lr_scheduler in ['exp_decaying_iter', 'cos_annealing_iter', 'multistep_iter']:
                scheduler.step()
        else: #not train
            if getattr(FLAGS, 'adaptive_training', False):
                for bits_idx, bits in enumerate(FLAGS.bits_list):
                    model.apply(
                        lambda m: setattr(m, 'bits', bits))
                    if is_master() and meters is not None:
                        meter = meters[str(bits)]
                    else:
                        meter = None
                    forward_loss(
                        model, criterion, input, target, meter)
            else:
                forward_loss(model, criterion, input, target, meters)
                
    val_top1 = None
    if is_master() and meters is not None:
        if getattr(FLAGS, 'adaptive_training', False):
            val_top1_list = []
            for bits in FLAGS.bits_list:
                results = flush_scalar_meters(meters[str(bits)])
                mprint('{:.1f}s\t{}\t{} bits\t{}/{}: '.format(
                    time.time() - t_start, phase, bits, epoch,
                    FLAGS.num_epochs) + ', '.join('{}: {}'.format(k, v)
                                                  for k, v in results.items()))
                val_top1_list.append(results['top1_error'])
            val_top1 = np.mean(val_top1_list)
        else:
            results = flush_scalar_meters(meters)
            mprint('{:.1f}s\t{}\t{}/{}: '.format(
                time.time() - t_start, phase, epoch, FLAGS.num_epochs) +
                  ', '.join('{}: {}'.format(k, v) for k, v in results.items()))
            val_top1 = results['top1_error']
    return val_top1
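The 'linear_decaying' branch in examples #2-#5 subtracts a fixed amount from the learning rate on every iteration; the expression lr / num_epochs / len(dataset) * batch_size is simply lr divided by the total number of training steps. A quick arithmetic check with made-up numbers:

lr = 0.1
num_epochs = 100
dataset_size = 1_281_167   # e.g. ImageNet train split
batch_size = 256

steps_per_epoch = dataset_size / batch_size
per_step = lr / num_epochs / dataset_size * batch_size
total_decay = per_step * steps_per_epoch * num_epochs

print(per_step)     # ~2.0e-07 subtracted per optimizer step
print(total_decay)  # ~0.1, i.e. the LR reaches ~0 at the end of training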
Example #4
def run_one_epoch(
        epoch, loader, model, criterion, optimizer, meters, phase='train', ema=None, scheduler=None, eta=None, epoch_dict=None, single_sample=False):
    """run one epoch for train/val/test/cal"""
    t_start = time.time()
    assert phase in ['train', 'val', 'test', 'cal'], "phase must be one of train/val/test/cal."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()
        #if getattr(FLAGS, 'bn_calib', False) and phase == 'val' and epoch < FLAGS.num_epochs - 10:
        #    model.apply(bn_calibration)
        #if getattr(FLAGS, 'bn_calib_stoch_valid', False):
        #    model.apply(bn_calibration)
        if phase == 'cal':
            model.apply(bn_calibration)

    if getattr(FLAGS, 'distributed', False):
        # DistributedSampler shuffles with a per-epoch seed; calling set_epoch
        # here gives every epoch a different (and rank-consistent) shuffle.
        loader.sampler.set_epoch(epoch)

    scale_dict = {}
    if getattr(FLAGS, 'switch_lr', False):
        scale_dict = {32: 1.0, 16: 1.0, 8: 1.0, 6: 1.0, 5: 1.0, 4: 1.02, 3: 1.08, 2: 1.62, 1: 4.83}

    for batch_idx, (input, target) in enumerate(loader):
        if phase == 'cal':
            if batch_idx == getattr(FLAGS, 'bn_cal_batch_num', -1):
                break
        target = target.cuda(non_blocking=True)
        if train:
            if FLAGS.lr_scheduler == 'linear_decaying':
                linear_decaying_per_step = (
                    FLAGS.lr/FLAGS.num_epochs/len(loader.dataset)*FLAGS.batch_size)
                for param_group in optimizer.param_groups:
                    param_group['lr'] -= linear_decaying_per_step
            # For PyTorch 1.1+, comment out the following two lines
            #if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter']:
            #    scheduler.step()
            optimizer.zero_grad()
            if getattr(FLAGS, 'quantizable_training', False) and not single_sample:
                for bits_idx, bits in enumerate(FLAGS.bits_list):
                    model.apply(
                        lambda m: setattr(m, 'bits', bits))
                    if is_master():
                        meter = meters[str(bits)]
                    else:
                        meter = None
                    loss = forward_loss(
                        model, criterion, input, target, meter)
                    if eta is not None:
                        #if isinstance(bits, (list, tuple)):
                        #    bitw = bits[0]
                        #else:
                        #    bitw = bits
                        #loss *= eta(bitw)
                        loss *= eta(_pair(bits)[0])
                    if getattr(FLAGS, 'switch_lr', False):
                        #mprint(scale_dict[_pair(bits)[0]])
                        loss *= scale_dict[_pair(bits)[0]]
                    if epoch_dict is None:
                        loss.backward()
                    else:
                        epoch_valid = epoch_dict[_pair(bits)[0]]
                        if isinstance(epoch_valid, (list, tuple)):
                            epoch_start, epoch_end = epoch_valid
                        else:
                            epoch_start = epoch_valid
                            epoch_end = 1.0
                        epoch_start = int(FLAGS.num_epochs * epoch_start)
                        epoch_end = int(FLAGS.num_epochs * epoch_end)
                        if epoch_start <= epoch and epoch < epoch_end:
                            loss.backward()
                    if getattr(FLAGS, 'print_grad_std', False):
                        mprint(f'bits: {bits}')
                        layer_idx = 0
                        grad_std_list = []
                        for m in model.modules():
                            #if getattr(m, 'weight', None) is not None:
                            if isinstance(m, (QuantizableConv2d, QuantizableLinear)):
                                grad_std = torch.std(m.weight.grad)
                                mprint(f'layer_{layer_idx} grad: {grad_std}') #, module: {m}')
                                grad_std_list.append(grad_std)
                                layer_idx += 1
                        mprint(f'average grad std: {torch.mean(torch.tensor(grad_std_list))}')
            else:
                loss = forward_loss(
                    model, criterion, input, target, meters)
                loss.backward()
            #if getattr(FLAGS, 'distributed', False) and getattr(FLAGS, 'distributed_all_reduce', False):
            #    allreduce_grads(model)
            optimizer.step()
            # For PyTorch 1.0 or earlier, comment the following two lines
            if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter']:
                scheduler.step()
            if ema:
                ema.shadow_update(model)
                #for name, param in model.named_parameters():
                #    if param.requires_grad:
                #        ema.update(name, param.data)
                #bn_idx = 0
                #for m in model.modules():
                #    if isinstance(m, nn.BatchNorm2d):
                #        ema.update('bn{}_mean'.format(bn_idx), m.running_mean)
                #        ema.update('bn{}_var'.format(bn_idx), m.running_var)
                #        bn_idx += 1
        else: #not train
            if ema:
                ema.shadow_apply(model)
            if getattr(FLAGS, 'quantizable_training', False) and not single_sample:
                for bits_idx, bits in enumerate(FLAGS.bits_list):
                    model.apply(
                        lambda m: setattr(m, 'bits', bits))
                    #model.apply(
                    #    lambda m: setattr(m, 'threshold', FLAGS.schmitt_threshold * (0.0 * (epoch <= 30) + 0.01 * (30 < epoch <= 60) + 0.1 * (60 < epoch <= 90) + 1.0 * (90 < epoch))))
                    #model.apply(
                    #    lambda m: setattr(m, 'threshold', epoch * FLAGS.schmitt_threshold / FLAGS.num_epochs))
                    if is_master():
                        meter = meters[str(bits)]
                    else:
                        meter = None
                    forward_loss(
                        model, criterion, input, target, meter)
            else:
                forward_loss(model, criterion, input, target, meters)
            if ema:
                ema.weight_recover(model)
    ##opt_loss = float('inf')
    ##opt_results = None
    val_top1 = None
    if is_master():
        if getattr(FLAGS, 'quantizable_training', False) and not single_sample:
            #results_dict = {}
            val_top1_list = []
            for bits in FLAGS.bits_list:
                results = flush_scalar_meters(meters[str(bits)])
                mprint('{:.1f}s\t{}\t{} bits\t{}/{}: '.format(
                    time.time() - t_start, phase, bits, epoch,
                    FLAGS.num_epochs) + ', '.join('{}: {}'.format(k, v)
                                                  for k, v in results.items()))
                #results_dict[str(bits)] = results
                ##if results['loss'] < opt_loss:
                ##    opt_results = results
                ##    opt_loss = results['loss']
                val_top1_list.append(results['top1_error'])
            #results = results_dict
            val_top1 = np.mean(val_top1_list)
        else:
            results = flush_scalar_meters(meters)
            mprint('{:.1f}s\t{}\t{}/{}: '.format(
                time.time() - t_start, phase, epoch, FLAGS.num_epochs) +
                  ', '.join('{}: {}'.format(k, v) for k, v in results.items()))
            ##if results['loss'] < opt_loss:
            ##    opt_results = results
            ##    opt_loss = results['loss']
            val_top1 = results['top1_error']
    ##return opt_results
    #return results
    return val_top1
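In example #4, epoch_dict gates when each bit-width's loss is backpropagated: an entry is either a start fraction of the total epochs or a (start, end) fraction pair. The same windowing logic in isolation (the schedule values below are invented for illustration, and the keys are plain ints rather than the _pair(bits)[0] lookup used above):

def backward_enabled(epoch, num_epochs, epoch_valid):
    """Return True if this bit-width's loss should be backpropagated at `epoch`."""
    if isinstance(epoch_valid, (list, tuple)):
        epoch_start, epoch_end = epoch_valid
    else:
        epoch_start, epoch_end = epoch_valid, 1.0
    epoch_start = int(num_epochs * epoch_start)
    epoch_end = int(num_epochs * epoch_end)
    return epoch_start <= epoch < epoch_end

# Hypothetical schedule: 8-bit trains throughout, 4-bit after 25% of epochs,
# 2-bit only during the middle half of training.
epoch_dict = {8: 0.0, 4: 0.25, 2: (0.25, 0.75)}
for bits, window in epoch_dict.items():
    print(bits, [e for e in (0, 30, 80) if backward_enabled(e, 100, window)])
# 8 [0, 30, 80]
# 4 [30, 80]
# 2 [30]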
Example #5
File: train.py Project: oj9040/FracBits
def run_one_epoch(
        epoch, loader, model, criterion, optimizer, meters, phase='train', ema=None, scheduler=None):
    """run one epoch for train/val/test/cal"""
    t_start = time.time()
    assert phase in ['train', 'val', 'test', 'cal'], "phase must be one of train/val/test/cal."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()

    if getattr(FLAGS, 'distributed', False):
        loader.sampler.set_epoch(epoch)

    for batch_idx, (input, target) in enumerate(loader):
        if phase == 'cal':
            if batch_idx == getattr(FLAGS, 'bn_cal_batch_num', -1):
                break
        target = target.cuda(non_blocking=True)
        if train:
            if FLAGS.lr_scheduler == 'linear_decaying':
                linear_decaying_per_step = (
                    FLAGS.lr/FLAGS.num_epochs/len(loader.dataset)*FLAGS.batch_size)
                for param_group in optimizer.param_groups:
                    param_group['lr'] -= linear_decaying_per_step
            # For PyTorch 1.1+, comment out the following two lines
            #if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter']:
            #    scheduler.step()
            optimizer.zero_grad()
            loss = forward_loss(
                model, criterion, input, target, meters)
            if epoch >= FLAGS.warmup_epochs and not getattr(FLAGS, 'hard_assignment', False):
                if getattr(FLAGS, 'weight_only', False):
                    loss += getattr(FLAGS, 'kappa', 1.0) * get_model_size_loss(model)
                else:
                    loss += getattr(FLAGS, 'kappa', 1.0) * get_comp_cost_loss(model)
            loss.backward()
            if getattr(FLAGS, 'distributed', False) and getattr(FLAGS, 'distributed_all_reduce', False):
                allreduce_grads(model)
            optimizer.step()
            # For PyTorch 1.0 or earlier, comment the following two lines
            if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter']:
                scheduler.step()
            if ema:
                ema.shadow_update(model)
                #for name, param in model.named_parameters():
                #    if param.requires_grad:
                #        ema.update(name, param.data)
                #bn_idx = 0
                #for m in model.modules():
                #    if isinstance(m, nn.BatchNorm2d):
                #        ema.update('bn{}_mean'.format(bn_idx), m.running_mean)
                #        ema.update('bn{}_var'.format(bn_idx), m.running_var)
                #        bn_idx += 1
        else: #not train
            if ema:
                mprint('ema apply')
                ema.shadow_apply(model)
            forward_loss(model, criterion, input, target, meters)
            if ema:
                mprint('ema recover')
                ema.weight_recover(model)
    val_top1 = None
    if is_master():
        results = flush_scalar_meters(meters)
        mprint('{:.1f}s\t{}\t{}/{}: '.format(
            time.time() - t_start, phase, epoch, FLAGS.num_epochs) +
              ', '.join('{}: {}'.format(k, v) for k, v in results.items()))
        val_top1 = results['top1_error']
    return val_top1
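Examples #4 and #5 call ema.shadow_update after each optimizer step and wrap evaluation in ema.shadow_apply / ema.weight_recover. The EMA class itself is project code; a minimal sketch of what a shadow-weight EMA with those method names typically looks like (the internals and the decay value are assumptions, not the repos' implementation):

import copy
import torch

class ShadowEMA:
    """Keeps an exponential moving average ("shadow" copy) of model weights."""
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
        self.backup = {}

    @torch.no_grad()
    def shadow_update(self, model):
        # shadow = decay * shadow + (1 - decay) * live weights
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(self.decay).add_(v, alpha=1 - self.decay)
            else:
                self.shadow[k].copy_(v)

    def shadow_apply(self, model):
        # Swap the shadow weights in for evaluation, backing up the live ones.
        self.backup = copy.deepcopy(model.state_dict())
        model.load_state_dict(self.shadow)

    def weight_recover(self, model):
        # Restore the live training weights after evaluation.
        model.load_state_dict(self.backup)
        self.backup = {}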
Example #6
def run_one_epoch(
        epoch, loader, model, criterion, optimizer, meters, phase='train',
        soft_criterion=None):
    """run one epoch for train/val/test/cal"""
    t_start = time.time()
    assert phase in ['train', 'val', 'test', 'cal'], 'Invalid phase.'
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()
        if phase == 'cal':
            model.apply(bn_calibration_init)
    # change learning rate in each iteration
    if getattr(FLAGS, 'universally_slimmable_training', False):
        max_width = FLAGS.width_mult_range[1]
        min_width = FLAGS.width_mult_range[0]
    elif getattr(FLAGS, 'slimmable_training', False):
        max_width = max(FLAGS.width_mult_list)
        min_width = min(FLAGS.width_mult_list)

    if getattr(FLAGS, 'distributed', False):
        loader.sampler.set_epoch(epoch)
    for batch_idx, (input, target) in enumerate(loader):
        if phase == 'cal':
            if batch_idx == getattr(FLAGS, 'bn_cal_batch_num', -1):
                break
        target = target.cuda(non_blocking=True)
        if train:
            # change learning rate if necessary
            lr_schedule_per_iteration(optimizer, epoch, batch_idx)
            optimizer.zero_grad()
            if getattr(FLAGS, 'slimmable_training', False):
                if getattr(FLAGS, 'universally_slimmable_training', False):
                    # universally slimmable model (us-nets)
                    widths_train = []
                    for _ in range(getattr(FLAGS, 'num_sample_training', 2)-2):
                        widths_train.append(
                            random.uniform(min_width, max_width))
                    widths_train = [max_width, min_width] + widths_train
                    for width_mult in widths_train:
                        # the sandwich rule
                        if width_mult in [max_width, min_width]:
                            model.apply(
                                lambda m: setattr(m, 'width_mult', width_mult))
                        elif getattr(FLAGS, 'nonuniform', False):
                            model.apply(lambda m: setattr(
                                m, 'width_mult',
                                lambda: random.uniform(min_width, max_width)))
                        else:
                            model.apply(lambda m: setattr(
                                m, 'width_mult',
                                width_mult))

                        # always track largest model and smallest model
                        if is_master() and width_mult in [
                                max_width, min_width]:
                            meter = meters[str(width_mult)]
                        else:
                            meter = None

                        # inplace distillation
                        if width_mult == max_width:
                            loss, soft_target = forward_loss(
                                model, criterion, input, target, meter,
                                return_soft_target=True)
                        else:
                            if getattr(FLAGS, 'inplace_distill', False):
                                loss = forward_loss(
                                    model, criterion, input, target, meter,
                                    soft_target=soft_target.detach(),
                                    soft_criterion=soft_criterion)
                            else:
                                loss = forward_loss(
                                    model, criterion, input, target, meter)
                        loss.backward()
                else:
                    # slimmable model (s-nets)
                    for width_mult in sorted(
                            FLAGS.width_mult_list, reverse=True):
                        model.apply(
                            lambda m: setattr(m, 'width_mult', width_mult))
                        if is_master():
                            meter = meters[str(width_mult)]
                        else:
                            meter = None
                        if width_mult == max_width:
                            loss, soft_target = forward_loss(
                                model, criterion, input, target, meter,
                                return_soft_target=True)
                        else:
                            if getattr(FLAGS, 'inplace_distill', False):
                                loss = forward_loss(
                                    model, criterion, input, target, meter,
                                    soft_target=soft_target.detach(),
                                    soft_criterion=soft_criterion)
                            else:
                                loss = forward_loss(
                                    model, criterion, input, target, meter)
                        loss.backward()
            else:
                loss = forward_loss(
                    model, criterion, input, target, meters)
                loss.backward()
            if (getattr(FLAGS, 'distributed', False)
                    and getattr(FLAGS, 'distributed_all_reduce', False)):
                allreduce_grads(model)
            optimizer.step()
            if is_master() and getattr(FLAGS, 'slimmable_training', False):
                for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
                    meter = meters[str(width_mult)]
                    meter['lr'].cache(optimizer.param_groups[0]['lr'])
            elif is_master():
                meters['lr'].cache(optimizer.param_groups[0]['lr'])
            else:
                pass
        else:
            if getattr(FLAGS, 'slimmable_training', False):
                for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
                    model.apply(
                        lambda m: setattr(m, 'width_mult', width_mult))
                    if is_master():
                        meter = meters[str(width_mult)]
                    else:
                        meter = None
                    forward_loss(model, criterion, input, target, meter)
            else:
                forward_loss(model, criterion, input, target, meters)
    if is_master() and getattr(FLAGS, 'slimmable_training', False):
        for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
            results = flush_scalar_meters(meters[str(width_mult)])
            print('{:.1f}s\t{}\t{}\t{}/{}: '.format(
                time.time() - t_start, phase, str(width_mult), epoch,
                FLAGS.num_epochs) + ', '.join(
                    '{}: {:.3f}'.format(k, v) for k, v in results.items()))
    elif is_master():
        results = flush_scalar_meters(meters)
        print(
            '{:.1f}s\t{}\t{}/{}: '.format(
                time.time() - t_start, phase, epoch, FLAGS.num_epochs) +
            ', '.join('{}: {:.3f}'.format(k, v) for k, v in results.items()))
    else:
        results = None
    return results
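Example #6's inplace distillation trains the smaller widths against the largest width's predictions: the max-width pass returns its soft_target, and the sub-networks are trained with soft_criterion against the detached target. The actual soft_criterion is defined elsewhere in the project; a common cross-entropy-with-soft-labels formulation, as a hedged sketch:

import torch
import torch.nn.functional as F

def soft_cross_entropy(student_logits, soft_target):
    """Cross-entropy against soft labels: batch mean of -sum(p_teacher * log p_student)."""
    return torch.mean(
        torch.sum(-soft_target * F.log_softmax(student_logits, dim=1), dim=1))

# Hypothetical usage: the max-width output supervises a smaller width.
teacher_logits = torch.randn(4, 10)
student_logits = torch.randn(4, 10, requires_grad=True)
soft_target = F.softmax(teacher_logits, dim=1).detach()
loss = soft_cross_entropy(student_logits, soft_target)
loss.backward()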