Example #1
def setup_distributed(num_images=None):
    """Setup distributed related parameters."""
    # init distributed
    if FLAGS.use_distributed:
        udist.init_dist()
        FLAGS.batch_size = udist.get_world_size() * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.per_gpu_batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = \
                FLAGS.bn_calibration_per_gpu_batch_size
        FLAGS.data_loader_workers = round(
            FLAGS.data_loader_workers / udist.get_local_size()
        )  # per-GPU workers; round() returns the nearest integer
    else:
        count = torch.cuda.device_count()
        FLAGS.batch_size = count * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = \
                FLAGS.bn_calibration_per_gpu_batch_size * count
    if hasattr(FLAGS, 'base_lr'):
        FLAGS.lr = FLAGS.base_lr * (FLAGS.batch_size / FLAGS.base_total_batch)
    if num_images:
        # NOTE: the last batch is not dropped, so ceil (the smallest integer
        # not less than x) must be used here; otherwise the learning rate
        # schedule can become negative
        FLAGS._steps_per_epoch = math.ceil(num_images / FLAGS.batch_size)
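The arithmetic above is the linear learning-rate scaling rule plus a ceil for the last partial batch. A minimal sketch with hypothetical numbers (the values below are illustrative, not taken from the snippet):

import math

world_size = 8                # hypothetical number of GPUs / processes
per_gpu_batch_size = 32
base_lr = 0.1
base_total_batch = 256        # batch size the base learning rate was tuned for
num_images = 1281167          # e.g. the ImageNet training set size

batch_size = world_size * per_gpu_batch_size          # 256
lr = base_lr * (batch_size / base_total_batch)        # 0.1 (linear scaling rule)
steps_per_epoch = math.ceil(num_images / batch_size)  # 5005, last batch kept

print(batch_size, lr, steps_per_epoch)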
Example #2
def train_val_test():
    """train and val"""
    torch.backends.cudnn.benchmark = True
    # init distributed
    if getattr(FLAGS, 'distributed', False):
        init_dist()
    # seed
    if getattr(FLAGS, 'use_diff_seed', False):
        print('use diff seed is True')
        while not is_initialized():
            print('Waiting for initialization ...')
            time.sleep(5)
        print('Expected seed: {}'.format(
            getattr(FLAGS, 'random_seed', 0) + get_rank()))
        set_random_seed(getattr(FLAGS, 'random_seed', 0) + get_rank())
    else:
        set_random_seed()
    # experiment setting
    experiment_setting = get_experiment_setting()

    # model
    model, model_wrapper = get_model()
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    if getattr(FLAGS, 'profiling_only', False):
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)
        return

    # data
    train_transforms, val_transforms, test_transforms = data_transforms()
    train_set, val_set, test_set = dataset(train_transforms, val_transforms,
                                           test_transforms)
    train_loader, val_loader, test_loader = data_loader(
        train_set, val_set, test_set)

    log_dir = FLAGS.log_dir
    log_dir = os.path.join(log_dir, experiment_setting)

    checkpoint = torch.load(os.path.join(log_dir, 'best_model.pt'),
                            map_location=lambda storage, loc: storage)
    model_wrapper.load_state_dict(checkpoint['model'])
    # `ema` is passed to run_one_epoch below; restore it from the checkpoint
    # if present, otherwise fall back to None
    ema = checkpoint.get('ema', None)
    optimizer = get_optimizer(model_wrapper)

    mprint('Start testing.')
    test_meters = get_meters('test')
    with torch.no_grad():
        run_one_epoch(-1,
                      test_loader,
                      model_wrapper,
                      criterion,
                      optimizer,
                      test_meters,
                      phase='test',
                      ema=ema)
Example #3
def main():
    """Entry."""
    # init distributed
    global is_root_rank
    if FLAGS.use_distributed:
        udist.init_dist()
        FLAGS.batch_size = udist.get_world_size() * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.per_gpu_batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = FLAGS.bn_calibration_per_gpu_batch_size
        FLAGS.data_loader_workers = round(FLAGS.data_loader_workers /
                                          udist.get_local_size())
        is_root_rank = udist.is_master()
    else:
        count = torch.cuda.device_count()
        FLAGS.batch_size = count * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = FLAGS.bn_calibration_per_gpu_batch_size * count
        is_root_rank = True
    FLAGS.lr = FLAGS.base_lr * (FLAGS.batch_size / FLAGS.base_total_batch)
    # NOTE: the last batch is not dropped, so ceil must be used here;
    # otherwise the learning rate schedule can become negative
    FLAGS._steps_per_epoch = int(np.ceil(NUM_IMAGENET_TRAIN /
                                         FLAGS.batch_size))

    if is_root_rank:
        FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir,
                                       time.strftime("%Y%m%d-%H%M%S"))
        create_exp_dir(
            FLAGS.log_dir,
            FLAGS.config_path,
            blacklist_dirs=[
                'exp',
                '.git',
                'pretrained',
                'tmp',
                'deprecated',
                'bak',
            ],
        )
        setup_logging(FLAGS.log_dir)
        for k, v in _ENV_EXPAND.items():
            logging.info('Env var expand: {} to {}'.format(k, v))
        logging.info(FLAGS)

    set_random_seed(FLAGS.get('random_seed', 0))
    with SummaryWriterManager():
        train_val_test()
Example #4
def get_model():
    """get model"""
    model_lib = importlib.import_module(FLAGS.model)
    model = model_lib.Model(FLAGS.num_classes)
    if getattr(FLAGS, 'distributed', False):
        gpu_id = init_dist()
        if getattr(FLAGS, 'distributed_all_reduce', False):
            model_wrapper = AllReduceDistributedDataParallel(model.cuda())
        else:
            model_wrapper = torch.nn.parallel.DistributedDataParallel(
                model.cuda(), [gpu_id], gpu_id)
    else:
        model_wrapper = torch.nn.DataParallel(model).cuda()
    return model, model_wrapper
Example #5
def get_model():
    """get model"""
    model_lib = importlib.import_module(FLAGS.model)
    model = model_lib.Model(FLAGS.num_classes, input_size=FLAGS.image_size)
    # dry-run a random input to sanity-check the model's output shape
    # (torch.autograd.Variable is no longer needed since PyTorch 0.4)
    inputs = torch.randn(2, 3, 32, 32)
    output = model(inputs)
    print(output)
    print(output.size())
    if getattr(FLAGS, 'distributed', False):
        gpu_id = init_dist()
        if getattr(FLAGS, 'distributed_all_reduce', False):
            # seems faster
            model_wrapper = AllReduceDistributedDataParallel(model.cuda())
        else:
            model_wrapper = torch.nn.parallel.DistributedDataParallel(
                model.cuda(), [gpu_id], gpu_id)
    else:
        model_wrapper = torch.nn.DataParallel(model).cuda()
    return model, model_wrapper
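A toy usage of the (model, model_wrapper) pair returned by get_model: the plain model is typically kept for profiling, while the wrapper runs the actual forward passes. This is a minimal single-process sketch; the distributed branch above additionally requires an initialized process group.

import torch

model = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
if torch.cuda.is_available():
    model_wrapper = torch.nn.DataParallel(model).cuda()
    inputs = torch.randn(2, 3, 32, 32).cuda()
else:
    model_wrapper = model                 # CPU fallback for the sketch
    inputs = torch.randn(2, 3, 32, 32)

output = model_wrapper(inputs)
print(output.size())                      # torch.Size([2, 8, 32, 32])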
Example #6
def train_val_test():
    """train and val"""
    torch.backends.cudnn.benchmark = True
    # init distributed
    if getattr(FLAGS, 'distributed', False):
        init_dist()
    # seed
    if getattr(FLAGS, 'use_diff_seed', False) and not getattr(FLAGS, 'stoch_valid', False):
        print('use diff seed is True')
        while not is_initialized():
            print('Waiting for initialization ...')
            time.sleep(5)
        print('Expected seed: {}'.format(getattr(FLAGS, 'random_seed', 0) + get_rank()))
        set_random_seed(getattr(FLAGS, 'random_seed', 0) + get_rank())
    else:
        set_random_seed()

    # experiment setting
    experiment_setting = get_experiment_setting()

    # model
    model, model_wrapper = get_model()
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    if getattr(FLAGS, 'profiling_only', False):
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)
        return

    # data
    train_transforms, val_transforms, test_transforms = data_transforms()
    train_set, val_set, test_set = dataset(
        train_transforms, val_transforms, test_transforms)
    train_loader, val_loader, test_loader = data_loader(
        train_set, val_set, test_set)

    log_dir = FLAGS.log_dir
    log_dir = os.path.join(log_dir, experiment_setting)

    # full precision pretrained
    if getattr(FLAGS, 'fp_pretrained_file', None):
        checkpoint = torch.load(
            FLAGS.fp_pretrained_file, map_location=lambda storage, loc: storage)
        # update keys from external models
        if type(checkpoint) == dict and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if getattr(FLAGS, 'pretrained_model_remap_keys', False):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                mprint('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_dict = model_wrapper.state_dict()
        #checkpoint = {k: v for k, v in checkpoint.items() if k in model_dict}
        # switch bn
        for k in list(checkpoint.keys()):
            if 'bn' in k:
                # create one copy of every BN entry per bit width in
                # FLAGS.bits_list
                parts = k.split('bn')
                for bn_idx in range(len(FLAGS.bits_list)):
                    k_new = (parts[0] + 'bn' + parts[1][0] + str(bn_idx)
                             + parts[1][2:])
                    mprint(k)
                    mprint(k_new)
                    checkpoint[k_new] = model_dict[k]
        if getattr(FLAGS, 'switch_alpha', False):
            for k, v in checkpoint.items():
                if 'alpha' in k and checkpoint[k].size() != model_dict[k].size():
                    #checkpoint[k] = checkpoint[k].repeat(model_dict[k].size())
                    # stack the single alpha once per switchable branch
                    checkpoint[k] = nn.Parameter(torch.stack(
                        [checkpoint[k] for _ in range(model_dict[k].size()[0])]))
        # remove unexpected keys
        for k in list(checkpoint.keys()):
            if k not in model_dict.keys():
                checkpoint.pop(k)
        model_dict.update(checkpoint)
        model_wrapper.load_state_dict(model_dict)
        mprint('Loaded full precision model {}.'.format(FLAGS.fp_pretrained_file))

    # check pretrained
    if FLAGS.pretrained_file:
        pretrained_dir = FLAGS.pretrained_dir
        pretrained_dir = os.path.join(pretrained_dir, experiment_setting)
        pretrained_file = os.path.join(pretrained_dir, FLAGS.pretrained_file)
        checkpoint = torch.load(
            pretrained_file, map_location=lambda storage, loc: storage)
        # update keys from external models
        if type(checkpoint) == dict and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if getattr(FLAGS, 'pretrained_model_remap_keys', False):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                mprint('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        mprint('Loaded model {}.'.format(pretrained_file))
    optimizer = get_optimizer(model_wrapper)

    if FLAGS.test_only and (test_loader is not None):
        mprint('Start profiling.')
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)
        mprint('Start testing.')
        test_meters = get_meters('test')
        with torch.no_grad():
            run_one_epoch(
                -1, test_loader,
                model_wrapper, criterion, optimizer,
                test_meters, phase='test')
        return

    # check resume training
    if os.path.exists(os.path.join(log_dir, 'latest_checkpoint.pt')):
        checkpoint = torch.load(
            os.path.join(log_dir, 'latest_checkpoint.pt'),
            map_location=lambda storage, loc: storage)
        model_wrapper.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        last_epoch = checkpoint['last_epoch']
        if FLAGS.lr_scheduler in ['exp_decaying_iter', 'cos_annealing_iter', 'multistep_iter']:
            lr_scheduler = get_lr_scheduler(optimizer, len(train_loader))
            lr_scheduler.last_epoch = last_epoch * len(train_loader)
        else:
            lr_scheduler = get_lr_scheduler(optimizer)
            lr_scheduler.last_epoch = last_epoch
        best_val = checkpoint['best_val']
        train_meters, val_meters = checkpoint['meters']
        mprint('Loaded checkpoint {} at epoch {}.'.format(
            log_dir, last_epoch))
    else:
        if FLAGS.lr_scheduler in ['exp_decaying_iter', 'cos_annealing_iter', 'multistep_iter']:
            lr_scheduler = get_lr_scheduler(optimizer, len(train_loader))
        else:
            lr_scheduler = get_lr_scheduler(optimizer)
        last_epoch = lr_scheduler.last_epoch
        best_val = 1.
        train_meters = get_meters('train')
        val_meters = get_meters('val')
        # if start from scratch, print model and do profiling
        mprint(model_wrapper)
        if getattr(FLAGS, 'profiling', False):
            if 'gpu' in FLAGS.profiling:
                profiling(model, use_cuda=True)
            if 'cpu' in FLAGS.profiling:
                profiling(model, use_cuda=False)

    if getattr(FLAGS, 'log_dir', None):
        try:
            os.makedirs(log_dir)
        except OSError:
            pass

    mprint('Start training.')

    for epoch in range(last_epoch+1, FLAGS.num_epochs):
        if FLAGS.lr_scheduler in ['exp_decaying_iter', 'cos_annealing_iter', 'multistep_iter']:
            lr_sched = lr_scheduler
        else:
            lr_sched = None
            # For PyTorch 1.1+, comment the following line
            #lr_scheduler.step()
        # train
        mprint(' train '.center(40, '*'))
        run_one_epoch(
            epoch, train_loader, model_wrapper, criterion, optimizer,
            train_meters, phase='train', scheduler=lr_sched)

        # val
        mprint(' validation '.center(40, '~'))
        if val_meters is not None:
            val_meters['best_val'].cache(best_val)
        with torch.no_grad():
            top1_error = run_one_epoch(
                epoch, val_loader, model_wrapper, criterion, optimizer,
                val_meters, phase='val')
        if is_master():
            if top1_error < best_val:
                best_val = top1_error
                torch.save(
                    {
                        'model': model_wrapper.state_dict(),
                    },
                    os.path.join(log_dir, 'best_model.pt'))
                mprint('New best validation top1 error: {:.3f}'.format(best_val))

            # save latest checkpoint
            torch.save(
                {
                    'model': model_wrapper.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'last_epoch': epoch,
                    'best_val': best_val,
                    'meters': (train_meters, val_meters),
                },
                os.path.join(log_dir, 'latest_checkpoint.pt'))

        # For PyTorch 1.0 or earlier, comment the following two lines
        if FLAGS.lr_scheduler not in ['exp_decaying_iter', 'cos_annealing_iter', 'multistep_iter']:
            lr_scheduler.step()

    if is_master():
        profiling(model, use_cuda=True)

    return
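The resume logic above revolves around a single checkpoint dictionary. A minimal sketch of that round-trip with a toy model and a hypothetical file name (the real checkpoint additionally stores the meters and, in later examples, an EMA object):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# save the latest training state
torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'last_epoch': 3,
            'best_val': 0.25},
           'latest_checkpoint.pt')

# resume: load onto CPU first, then restore model and optimizer
ckpt = torch.load('latest_checkpoint.pt',
                  map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
start_epoch = ckpt['last_epoch'] + 1   # continue from the next epoch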
Example #7
def train_val_test():
    """train and val"""
    torch.backends.cudnn.benchmark = True
    # init distributed
    if getattr(FLAGS, 'distributed', False):
        init_dist()
    # seed
    #if getattr(FLAGS, 'use_diff_seed', False):
    #if getattr(FLAGS, 'use_diff_seed', False) and not FLAGS.test_only:
    if getattr(FLAGS, 'use_diff_seed', False) and not getattr(FLAGS, 'stoch_valid', False):
        print('use diff seed is True')
        while not is_initialized():
            print('Waiting for initialization ...')
            time.sleep(5)
        print('Expected seed: {}'.format(getattr(FLAGS, 'random_seed', 0) + get_rank()))
        set_random_seed(getattr(FLAGS, 'random_seed', 0) + get_rank())
    else:
        set_random_seed()

    # experiment setting
    experiment_setting = get_experiment_setting()

    # model
    model, model_wrapper = get_model()
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    if getattr(FLAGS, 'profiling_only', False):
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)
        return

    # optional exponential moving average of the model parameters
    ema_decay = getattr(FLAGS, 'ema_decay', None)
    if ema_decay:
        ema = EMA(ema_decay)
        ema.shadow_register(model_wrapper)
        #for name, param in model.named_parameters():
        #    if param.requires_grad:
        #        ema.register(name, param.data)
        #bn_idx = 0
        #for m in model.modules():
        #    if isinstance(m, nn.BatchNorm2d):
        #        ema.register('bn{}_mean'.format(bn_idx), m.running_mean)
        #        ema.register('bn{}_var'.format(bn_idx), m.running_var)
        #        bn_idx += 1
    else:
        ema = None

    # data
    train_transforms, val_transforms, test_transforms = data_transforms()
    train_set, val_set, test_set = dataset(
        train_transforms, val_transforms, test_transforms)
    train_loader, val_loader, test_loader = data_loader(
        train_set, val_set, test_set)

    log_dir = FLAGS.log_dir
    log_dir = os.path.join(log_dir, experiment_setting)
    io = UltronIO('hdfs://haruna/home')
    # full precision pretrained
    if getattr(FLAGS, 'fp_pretrained_file', None):
        checkpoint = io.torch_load(
            FLAGS.fp_pretrained_file, map_location=lambda storage, loc: storage)
        # update keys from external models
        if type(checkpoint) == dict and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if getattr(FLAGS, 'pretrained_model_remap_keys', False):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                mprint('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_dict = model_wrapper.state_dict()
        #checkpoint = {k: v for k, v in checkpoint.items() if k in model_dict}
        # remove unexpected keys
        for k in list(checkpoint.keys()):
            if k not in model_dict.keys():
                checkpoint.pop(k)
        model_dict.update(checkpoint)
        model_wrapper.load_state_dict(model_dict)
        mprint('Loaded full precision model {}.'.format(FLAGS.fp_pretrained_file))

    # check pretrained
    if FLAGS.pretrained_file and FLAGS.pretrained_dir:
        pretrained_dir = FLAGS.pretrained_dir
        #pretrained_dir = os.path.join(pretrained_dir, experiment_setting)
        pretrained_file = os.path.join(pretrained_dir, FLAGS.pretrained_file)
        checkpoint = io.torch_load(
            pretrained_file, map_location=lambda storage, loc: storage)
        # update keys from external models
        #if type(checkpoint) == dict and 'model' in checkpoint:
        #    checkpoint = checkpoint['model']
        if getattr(FLAGS, 'pretrained_model_remap_keys', False):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                mprint('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        # expand lamda_w / lamda_a entries so their shapes match the wrapper
        for k, v in checkpoint['model'].items():
            if 'lamda_w' in k or 'lamda_a' in k:
                checkpoint['model'][k] = v.repeat(model_wrapper.state_dict()[k].size())
        model_wrapper.load_state_dict(checkpoint['model'])
        mprint('Loaded model {}.'.format(pretrained_file))
    optimizer = get_optimizer(model_wrapper)

    if FLAGS.test_only and (test_loader is not None):
        mprint('Start testing.')
        ema = checkpoint.get('ema', None)
        test_meters = get_meters('test')
        with torch.no_grad():
            run_one_epoch(
                -1, test_loader,
                model_wrapper, criterion, optimizer,
                test_meters, phase='test', ema=ema)
        return

    # check resume training
    if io.check_path(os.path.join(log_dir, 'latest_checkpoint.pt')):
        checkpoint = io.torch_load(
            os.path.join(log_dir, 'latest_checkpoint.pt'),
            map_location=lambda storage, loc: storage)
        model_wrapper.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        last_epoch = checkpoint['last_epoch']
        if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter']:
            lr_scheduler = get_lr_scheduler(optimizer, len(train_loader))
            lr_scheduler.last_epoch = last_epoch * len(train_loader)
        else:
            lr_scheduler = get_lr_scheduler(optimizer)
            lr_scheduler.last_epoch = last_epoch
        best_val = checkpoint['best_val']
        train_meters, val_meters = checkpoint['meters']
        ema = checkpoint.get('ema', None)
        mprint('Loaded checkpoint {} at epoch {}.'.format(
            log_dir, last_epoch))
    else:
        if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter']:
            lr_scheduler = get_lr_scheduler(optimizer, len(train_loader))
        else:
            lr_scheduler = get_lr_scheduler(optimizer)
        last_epoch = lr_scheduler.last_epoch
        best_val = 1.
        train_meters = get_meters('train')
        val_meters = get_meters('val')
        # if start from scratch, print model and do profiling
        mprint(model_wrapper)
        if getattr(FLAGS, 'profiling', False):
            if 'gpu' in FLAGS.profiling:
                profiling(model, use_cuda=True)
            if 'cpu' in FLAGS.profiling:
                profiling(model, use_cuda=False)

    if getattr(FLAGS, 'log_dir', None):
        try:
            io.create_folder(log_dir)
        except OSError:
            pass

    mprint('Start training.')
    for epoch in range(last_epoch+1, FLAGS.num_epochs):
        if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter']:
            lr_sched = lr_scheduler
        else:
            lr_sched = None
            # For PyTorch 1.1+, comment the following line
            #lr_scheduler.step()
        # train
        mprint(' train '.center(40, '*'))
        run_one_epoch(
            epoch, train_loader, model_wrapper, criterion, optimizer,
            train_meters, phase='train', ema=ema, scheduler=lr_sched)

        # val
        mprint(' validation '.center(40, '~'))
        if val_meters is not None:
            val_meters['best_val'].cache(best_val)
        with torch.no_grad():
            if epoch == getattr(FLAGS, 'hard_assign_epoch', float('inf')):
                mprint('Start to use hard assignment')
                setattr(FLAGS, 'hard_assignment', True)
                lower_offset = -1
                higher_offset = 0
                setattr(FLAGS, 'hard_offset', 0)

                # bisection search on hard_offset until the profiled cost is
                # within +/- with_ratio of the target budget
                with_ratio = 0.01
                bitops, bytesize = profiling(model, use_cuda=True)
                search_trials = 10
                trial = 0
                if getattr(FLAGS, 'weight_only', False):
                    target_bytesize = getattr(FLAGS, 'target_size', 0)
                    while trial < search_trials:
                        trial += 1
                        if bytesize - target_bytesize > with_ratio * target_bytesize:
                            higher_offset = FLAGS.hard_offset
                        elif bytesize - target_bytesize < -with_ratio * target_bytesize:
                            lower_offset = FLAGS.hard_offset
                        else:
                            break
                        FLAGS.hard_offset = (higher_offset + lower_offset) / 2
                        bitops, bytesize = profiling(model, use_cuda=True)
                else:
                    target_bitops = getattr(FLAGS, 'target_bitops', 0)
                    while trial < search_trials:
                        trial += 1
                        if bitops - target_bitops > with_ratio * target_bitops:
                            higher_offset = FLAGS.hard_offset
                        elif bitops - target_bitops < -with_ratio * target_bitops:
                            lower_offset = FLAGS.hard_offset
                        else:
                            break
                        FLAGS.hard_offset = (higher_offset + lower_offset) / 2
                        bitops, bytesize = profiling(model, use_cuda=True)
                bit_discretizing(model_wrapper)
                setattr(FLAGS, 'hard_offset', 0)
            top1_error = run_one_epoch(
                epoch, val_loader, model_wrapper, criterion, optimizer,
                val_meters, phase='val', ema=ema)
        if is_master():
            if top1_error < best_val:
                best_val = top1_error
                io.torch_save(
                    os.path.join(log_dir, 'best_model.pt'),
                    {
                        'model': model_wrapper.state_dict(),
                    }
                    )
                mprint('New best validation top1 error: {:.3f}'.format(best_val))

            # save latest checkpoint
            io.torch_save(
                os.path.join(log_dir, 'latest_checkpoint.pt'),
                {
                    'model': model_wrapper.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'last_epoch': epoch,
                    'best_val': best_val,
                    'meters': (train_meters, val_meters),
                    'ema': ema,
                })

        # For PyTorch 1.0 or earlier, comment the following two lines
        if FLAGS.lr_scheduler not in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter']:
            lr_scheduler.step()

    if is_master():
        profiling(model, use_cuda=True)
        for m in model.modules():
            if hasattr(m, 'alpha'):
                mprint(m, m.alpha)
            if hasattr(m, 'lamda_w'):
                mprint(m, m.lamda_w)
            if hasattr(m, 'lamda_a'):
                mprint(m, m.lamda_a)
    return
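The hard-assignment branch above narrows FLAGS.hard_offset by bisection until the profiled cost lands within one percent of a target budget. A standalone sketch of that search with a stand-in cost function (cost_fn and the toy numbers are assumptions for illustration; the real code re-profiles the model at every step):

def search_offset(cost_fn, target, lower=-1.0, higher=0.0, tol=0.01, trials=10):
    offset = 0.0
    cost = cost_fn(offset)
    for _ in range(trials):
        if cost - target > tol * target:        # still above budget
            higher = offset
        elif cost - target < -tol * target:     # below budget, can afford more
            lower = offset
        else:                                   # within +/- tol of the target
            break
        offset = (higher + lower) / 2
        cost = cost_fn(offset)
    return offset

# toy cost that grows linearly with the offset
print(search_offset(lambda o: 100.0 * (1.0 + o), target=30.0))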
Example #8
def train_val_test():
    """train and val"""
    torch.backends.cudnn.benchmark = True
    # init distributed
    if getattr(FLAGS, 'distributed', False):
        init_dist()
    # seed
    if getattr(FLAGS, 'use_diff_seed',
               False) and not getattr(FLAGS, 'stoch_valid', False):
        print('use diff seed is True')
        while not is_initialized():
            print('Waiting for initialization ...')
            time.sleep(5)
        print('Expected seed: {}'.format(
            getattr(FLAGS, 'random_seed', 0) + get_rank()))
        set_random_seed(getattr(FLAGS, 'random_seed', 0) + get_rank())
    else:
        set_random_seed()

    if getattr(FLAGS, 'adjust_lr', False):
        eta_dict = {
            32: 1.0,
            16: 1.0,
            8: 1.0,
            7: 0.99,
            6: 0.98,
            5: 0.97,
            4: 0.94,
            3: 0.88,
            2: 0.77,
            1: 0.58
        }
        eta = lambda b: eta_dict[b]  # noqa: E731
    else:
        eta = None

    # experiment setting
    experiment_setting = get_experiment_setting()
    mprint('stoch_valid: {}, bn_calib_stoch_valid: {}'.format(
        getattr(FLAGS, 'stoch_valid', False),
        getattr(FLAGS, 'bn_calib_stoch_valid', False)))

    # model
    model, model_wrapper = get_model()
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    if getattr(FLAGS, 'profiling_only', False):
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)
        return

    # optional exponential moving average of the model parameters
    ema_decay = getattr(FLAGS, 'ema_decay', None)
    if ema_decay:
        ema = EMA(ema_decay)
        ema.shadow_register(model_wrapper)
        #for name, param in model.named_parameters():
        #    if param.requires_grad:
        #        ema.register(name, param.data)
        #bn_idx = 0
        #for m in model.modules():
        #    if isinstance(m, nn.BatchNorm2d):
        #        ema.register('bn{}_mean'.format(bn_idx), m.running_mean)
        #        ema.register('bn{}_var'.format(bn_idx), m.running_var)
        #        bn_idx += 1
    else:
        ema = None

    # data
    train_transforms, val_transforms, test_transforms = data_transforms()
    train_set, val_set, test_set = dataset(train_transforms, val_transforms,
                                           test_transforms)
    train_loader, val_loader, test_loader = data_loader(
        train_set, val_set, test_set)

    log_dir = FLAGS.log_dir
    log_dir = os.path.join(log_dir, experiment_setting)

    # check pretrained
    if FLAGS.pretrained_file:
        pretrained_dir = FLAGS.pretrained_dir
        pretrained_dir = os.path.join(pretrained_dir, experiment_setting)
        pretrained_file = os.path.join(pretrained_dir, FLAGS.pretrained_file)
        checkpoint = torch.load(pretrained_file,
                                map_location=lambda storage, loc: storage)
        # update keys from external models
        if type(checkpoint) == dict and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if getattr(FLAGS, 'pretrained_model_remap_keys', False):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                mprint('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        mprint('Loaded model {}.'.format(pretrained_file))
    optimizer = get_optimizer(model_wrapper)
    cal_meters = get_meters('cal', single_sample=True)
    mprint('Start calibration.')
    run_one_epoch(-1,
                  train_loader,
                  model_wrapper,
                  criterion,
                  optimizer,
                  cal_meters,
                  phase='cal',
                  ema=ema,
                  single_sample=True)
    mprint('Start validation after calibration.')
    with torch.no_grad():
        run_one_epoch(-1,
                      val_loader,
                      model_wrapper,
                      criterion,
                      optimizer,
                      cal_meters,
                      phase='val',
                      ema=ema,
                      single_sample=True)
    return
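EMA and shadow_register come from this repository, so their exact behavior is not shown here; the sketch below only illustrates the usual idea of keeping a decayed shadow copy of the parameters, under that assumption.

import torch

class SimpleEMA:
    """Hypothetical stand-in for the repo's EMA class (sketch only)."""

    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}

    def shadow_register(self, model):
        # keep a detached copy of every trainable parameter
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.detach().clone()

    def update(self, model):
        # shadow = decay * shadow + (1 - decay) * param, after each optimizer step
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in self.shadow:
                    self.shadow[name].mul_(self.decay).add_(param, alpha=1 - self.decay)

model = torch.nn.Linear(4, 2)
ema = SimpleEMA(0.999)
ema.shadow_register(model)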
Example #9
def init(config):
    random_seed = config.random_seed
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    random.seed(random_seed)
    dist.init_dist(config.distributed.enable, port=config.port)
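Example #9 seeds Python, NumPy and PyTorch before initializing the distributed backend. A self-contained helper doing the same (CUDA seeding is added here as an assumption; it is not part of the snippet above):

import random

import numpy as np
import torch

def set_all_seeds(seed):
    # seed every RNG the training code may touch
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_all_seeds(0)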