Code example #1
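# Assumed context for this snippet (not shown in the excerpt): import time; import numpy as np; import torch;
# from thop import profile; plus the project's own utils and models modules.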
def run_evaluation(args, model, data_loaders, model_description, n_choices,
                   layers_types, downsample_layers):

    start = time.time()

    num_samples = utils.get_number_of_samples(args.dataset)

    all_values = {}

    device = 'cuda'

    #setting up random seeds
    utils.setup_torch(args.seed)

    #creating model skeleton based on description
    propagate_weights = []
    for layer in model_description:
        cur_weights = [0 for i in range(n_choices)]
        cur_weights[layers_types.index(layer)] = 1
        propagate_weights.append(cur_weights)

    model.propagate = propagate_weights

    #Create a computationally identical model without the multiple-choice blocks (i.e. a single-path net)
    #This is needed to measure MACs correctly
    pruned_model = models.SinglePathSupernet(
        num_classes=utils.get_number_of_classes(args.dataset),
        propagate=propagate_weights,
        put_downsampling=downsample_layers)  # kept on CPU; only used to count MACs
    pruned_model.propagate = propagate_weights
    inputs = torch.randn((1, 3, 32, 32))
    total_ops, total_params = profile(pruned_model, (inputs, ), verbose=True)
    all_values['MMACs'] = np.round(total_ops / (1000.0**2), 2)
    all_values['Params'] = int(total_params)

    del pruned_model
    del inputs

    ################################################
    criterion = torch.nn.CrossEntropyLoss()

    #Initialize batch normalization parameters
    utils.bn_update(device, data_loaders['train_for_bn_recalc'], model)

    val_res = utils.evaluate(device, data_loaders['val'], model, criterion,
                             num_samples['val'])
    test_res = utils.evaluate(device, data_loaders['test'], model, criterion,
                              num_samples['test'])

    all_values['val_loss'] = np.round(val_res['loss'], 3)
    all_values['val_acc'] = np.round(val_res['accuracy'], 3)
    all_values['test_loss'] = np.round(test_res['loss'], 3)
    all_values['test_acc'] = np.round(test_res['accuracy'], 3)

    print(all_values, 'time taken: %.2f sec.' % (time.time() - start))

    utils.save_result(all_values, args.dir, model_description)
Code example #2
def swa_train(model, swa_model, train_iter, valid_iter, optimizer, criterion, pretrain_epochs, swa_epochs, swa_lr, cycle_length, device, writer, cpt_filename):
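    # NOTE: this excerpt relies on names defined elsewhere in its module:
    # cpt_directory, date, lr_init, and columns, plus imports for time, copy,
    # tabulate, and the project's utils.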
    swa_n = 1

    swa_model.load_state_dict(copy.deepcopy(model.state_dict()))

    utils.save_checkpoint(
        cpt_directory,
        1,
        '{}-swa-{:2.4f}-{:03d}-{}'.format(date, swa_lr, cycle_length, cpt_filename),
        state_dict=model.state_dict(),
        swa_state_dict=swa_model.state_dict(),
        swa_n=swa_n,
        optimizer=optimizer.state_dict()
    )

    for e in range(swa_epochs):
        epoch = e + pretrain_epochs
        time_ep = time.time()
        lr = utils.schedule(epoch, cycle_length, lr_init, swa_lr)
        utils.adjust_learning_rate(optimizer, lr)

        train_res = utils.train_epoch(model, train_iter, optimizer, criterion, device)
        valid_res = utils.evaluate(model, valid_iter, criterion, device)

        utils.moving_average(swa_model, model, swa_n)
        swa_n += 1
        utils.bn_update(train_iter, swa_model)
        swa_res = utils.evaluate(swa_model, valid_iter, criterion, device)

        time_ep = time.time() - time_ep
        values = [epoch + 1, lr, swa_lr, cycle_length, train_res['loss'], valid_res['loss'], swa_res['loss'], None, None, time_ep]
        writer.writerow(values)

        table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
        if epoch % 20 == 0:
            table = table.split('\n')
            table = '\n'.join([table[1]] + table)
        else:
            table = table.split('\n')[2]
        print(table)

        utils.save_checkpoint(
            cpt_directory,
            epoch + 1,
            '{}-swa-{:2.4f}-{:03d}-{}'.format(date, swa_lr, cycle_length, cpt_filename),
            state_dict=model.state_dict(),
            swa_state_dict=swa_model.state_dict(),
            swa_n=swa_n,
            optimizer=optimizer.state_dict()
        )
Code example #3
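    # Method excerpt from an evaluator class: assumes self.model and
    # self.loaders were set up in __init__, and that F is torch.nn.functional.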
    def evaluate(self, torch_state_dict, update_buffers=False):
        criterion = F.cross_entropy

        # Recover model from state_dict
        self.model.load_state_dict(torch_state_dict)

        # Update BatchNorm buffers (if any)
        if update_buffers:
            utils.bn_update(self.loaders['train'], self.model)

        # Evaluate on the training and test sets
        train_res = utils.eval(self.loaders['train'], self.model, criterion)
        test_res = utils.eval(self.loaders['test'], self.model, criterion)

        return train_res, test_res
Code example #4
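    # Excerpt from the body of an SWA training loop (the full loop appears in
    # code example #11); assumes time_ep = time.time() was taken at the top of
    # the loop and that swa_model, swa_n, loaders, and args are in scope.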
    lr = schedule(epoch)
    utils.adjust_learning_rate(optimizer, lr)
    train_res = utils.train_epoch(loaders['train'], model, criterion,
                                  optimizer)
    if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
        test_res = utils.eval(loaders['test'], model, criterion)
    else:
        test_res = {'loss': None, 'accuracy': None}

    if args.swa and (epoch + 1) >= args.swa_start and (
            epoch + 1 - args.swa_start) % args.swa_c_epochs == 0:
        utils.moving_average(swa_model, model, 1.0 / (swa_n + 1))
        swa_n += 1
        if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
            utils.bn_update(loaders['train'], swa_model)
            swa_res = utils.eval(loaders['test'], swa_model, criterion)
        else:
            swa_res = {'loss': None, 'accuracy': None}

    if (epoch + 1) % args.save_freq == 0:
        utils.save_checkpoint(
            args.dir,
            epoch + 1,
            state_dict=model.state_dict(),
            swa_state_dict=swa_model.state_dict() if args.swa else None,
            swa_n=swa_n if args.swa else None,
            optimizer=optimizer.state_dict())

    time_ep = time.time() - time_ep
    values = [
Code example #5
def train_main(cfg):
    '''
    Main training function.
    :param cfg: configuration object
    :return:
    '''

    # config
    train_cfg = cfg.train_cfg
    dataset_cfg = cfg.dataset_cfg
    model_cfg = cfg.model_cfg
    is_parallel = cfg.setdefault(key='is_parallel', default=False)
    device = cfg.device
    is_online_train = cfg.setdefault(key='is_online_train', default=False)

    # Configure the logger
    logging.basicConfig(filename=cfg.logfile,
                        filemode='a',
                        level=logging.INFO,
                        format='%(asctime)s\n%(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S')
    logger = logging.getLogger()

    #
    # Build the datasets
    train_dataset = LandDataset(DIR_list=dataset_cfg.train_dir_list,
                                mode='train',
                                input_channel=dataset_cfg.input_channel,
                                transform=dataset_cfg.train_transform)
    split_val_from_train_ratio = dataset_cfg.setdefault(
        key='split_val_from_train_ratio', default=None)
    if split_val_from_train_ratio is None:
        val_dataset = LandDataset(DIR_list=dataset_cfg.val_dir_list,
                                  mode='val',
                                  input_channel=dataset_cfg.input_channel,
                                  transform=dataset_cfg.val_transform)
    else:
        val_size = int(len(train_dataset) * split_val_from_train_ratio)
        train_size = len(train_dataset) - val_size
        train_dataset, val_dataset = random_split(
            train_dataset, [train_size, val_size],
            generator=torch.manual_seed(cfg.random_seed))
        # val_dataset.dataset.transform = dataset_cfg.val_transform  # the val transform still needs to be configured here
        print(f"Splitting a validation set off the training set with ratio {split_val_from_train_ratio}...")

    # Build the dataloaders
    def _init_fn(worker_id):
        # Seed each worker process for reproducibility. The original passed
        # worker_init_fn=_init_fn(), i.e. the call's None return value, so the
        # seeding never ran per worker; pass the function itself instead.
        np.random.seed(cfg.random_seed + worker_id)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=train_cfg.batch_size,
                                  shuffle=True,
                                  num_workers=train_cfg.num_workers,
                                  drop_last=True,
                                  worker_init_fn=_init_fn)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=train_cfg.batch_size,
                                num_workers=train_cfg.num_workers,
                                shuffle=False,
                                drop_last=True,
                                worker_init_fn=_init_fn)

    # Build the model
    if train_cfg.is_swa:
        model = torch.load(train_cfg.check_point_file, map_location=device).to(
            device)  # pass map_location here; otherwise the checkpoint is loaded onto cuda:0 first and only then moved to the target device
        swa_model = torch.load(
            train_cfg.check_point_file, map_location=device).to(
                device)  # same map_location caveat as above
        if is_parallel:
            model = torch.nn.DataParallel(model)
            swa_model = torch.nn.DataParallel(swa_model)
        swa_n = 0
        parameters = swa_model.parameters()
    else:
        model = build_model(model_cfg).to(device)
        if is_parallel:
            model = torch.nn.DataParallel(model)
        parameters = model.parameters()

    # Define the optimizer
    optimizer_cfg = train_cfg.optimizer_cfg
    lr_scheduler_cfg = train_cfg.lr_scheduler_cfg
    if optimizer_cfg.type == 'adam':
        optimizer = optim.Adam(params=parameters,
                               lr=optimizer_cfg.lr,
                               weight_decay=optimizer_cfg.weight_decay)
    elif optimizer_cfg.type == 'adamw':
        optimizer = optim.AdamW(params=parameters,
                                lr=optimizer_cfg.lr,
                                weight_decay=optimizer_cfg.weight_decay)
    elif optimizer_cfg.type == 'sgd':
        optimizer = optim.SGD(params=parameters,
                              lr=optimizer_cfg.lr,
                              momentum=optimizer_cfg.momentum,
                              weight_decay=optimizer_cfg.weight_decay)
    elif optimizer_cfg.type == 'RMS':
        optimizer = optim.RMSprop(params=parameters,
                                  lr=optimizer_cfg.lr,
                                  weight_decay=optimizer_cfg.weight_decay)
    else:
        raise Exception('Unknown optimizer type!')

    if not lr_scheduler_cfg:
        lr_scheduler = None
    elif lr_scheduler_cfg.policy == 'cos':
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            lr_scheduler_cfg.T_0,
            lr_scheduler_cfg.T_mult,
            lr_scheduler_cfg.eta_min,
            last_epoch=lr_scheduler_cfg.last_epoch)
    elif lr_scheduler_cfg.policy == 'LambdaLR':
        import math
        lf = lambda x: (((1 + math.cos(x * math.pi / train_cfg.num_epochs)) / 2
                         )**1.0) * 0.95 + 0.05  # cosine
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                         lr_lambda=lf)
        lr_scheduler.last_epoch = 0
    else:
        lr_scheduler = None

    # Define the loss functions
    DiceLoss_fn = DiceLoss(mode='multiclass')
    SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
    loss_func = L.JointLoss(first=DiceLoss_fn,
                            second=SoftCrossEntropy_fn,
                            first_weight=0.5,
                            second_weight=0.5).to(device)  # use the configured device rather than a hard-coded .cuda()
    # loss_cls_func = torch.nn.BCEWithLogitsLoss()

    # Create the directory for saving checkpoints
    check_point_dir = '/'.join(model_cfg.check_point_file.split('/')[:-1])
    if not os.path.exists(check_point_dir):  # create it if it does not exist
        os.makedirs(check_point_dir)  # makedirs also handles nested paths, unlike mkdir

    # Start training
    auto_save_epoch_list = train_cfg.setdefault(
        key='auto_save_epoch_list',
        default=[5])  # epochs at which to save extra snapshots; must be a list because of the `in` test below (the original int default of 5 would raise a TypeError there)
    train_loss_list = []
    val_loss_list = []
    val_loss_min = 999999
    best_epoch = 0
    best_miou = 0
    train_loss = 10  # initial placeholder value
    logger.info('Starting to train the {1} model on {0}...'.format(device, model_cfg.type))
    logger.info('Additional info: {}\n'.format(cfg.setdefault(key='info', default='None')))
    for epoch in range(train_cfg.num_epochs):
        print()
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        start_time = time.time()
        print(f"正在进行第{epoch}轮训练...")
        logger.info('*' * 10 + f"第{epoch}轮" + '*' * 10)
        #
        # Train for one epoch
        if train_cfg.is_swa:  # SWA training mode
            train_loss = train_epoch(swa_model, optimizer, lr_scheduler,
                                     loss_func, train_dataloader, epoch,
                                     device)
            moving_average(model, swa_model, 1.0 / (swa_n + 1))
            swa_n += 1
            bn_update(train_dataloader, model, device)
        else:
            train_loss = train_epoch(model, optimizer, lr_scheduler, loss_func,
                                     train_dataloader, epoch, device)
            # train_loss = train_unet3p_epoch(model, optimizer, lr_scheduler, loss_func, train_dataloader, epoch, device)

        #
        # Evaluate the model on the validation set
        # val_loss, val_miou = evaluate_unet3p_model(model, val_dataset, loss_func, device,
        #                                     cfg.num_classes, train_cfg.num_workers, batch_size=train_cfg.batch_size)
        if not is_online_train:  # the model only needs to be evaluated when training offline
            val_loss, val_miou = evaluate_model(model, val_dataloader,
                                                loss_func, device,
                                                cfg.num_classes)
        else:
            val_loss = 0
            val_miou = 0

        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)

        # Save the model
        if not is_online_train:  # the best model only needs to be saved when training offline
            if val_loss < val_loss_min:
                val_loss_min = val_loss
                best_epoch = epoch
                best_miou = val_miou
                if is_parallel:
                    torch.save(model.module, model_cfg.check_point_file)
                else:
                    torch.save(model, model_cfg.check_point_file)

        if epoch in auto_save_epoch_list:  # save an extra snapshot at the configured epochs
            model_file = model_cfg.check_point_file.split(
                '.pth')[0] + '-epoch{}.pth'.format(epoch)
            if is_parallel:
                torch.save(model.module, model_file)
            else:
                torch.save(model, model_file)

        # Print intermediate results
        end_time = time.time()
        run_time = int(end_time - start_time)
        m, s = divmod(run_time, 60)
        time_str = "{:02d}分{:02d}秒".format(m, s)
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        out_str = "第{}轮训练完成,耗时{},\t训练集上的loss={:.6f};\t验证集上的loss={:.4f},mIoU={:.6f}\t最好的结果是第{}轮,mIoU={:.6f}" \
            .format(epoch, time_str, train_loss, val_loss, val_miou, best_epoch, best_miou)
        # out_str = "第{}轮训练完成,耗时{},\n训练集上的segm_loss={:.6f},cls_loss{:.6f}\n验证集上的segm_loss={:.4f},cls_loss={:.4f},mIoU={:.6f}\n最好的结果是第{}轮,mIoU={:.6f}" \
        #     .format(epoch, time_str, train_loss, train_cls_loss, val_loss, val_cls_loss, val_miou, best_epoch,
        #             best_miou)
        print(out_str)
        logger.info(out_str + '\n')
Code example #6
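# Excerpt from a half-precision SWA training script: the optimizer construction
# is cut off above, and model, loaders, criterion, schedule, args, time, utils,
# and tabulate are assumed to be in scope. Note that the eval path recalibrates
# BatchNorm and evaluates with use_half=True.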
                            momentum=args.momentum,
                            weight_decay=args.wd)

start_epoch = 0

columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc', 'time']

for epoch in range(start_epoch, args.epochs):
    time_ep = time.time()

    lr = schedule(epoch)
    utils.adjust_learning_rate(optimizer, lr)
    train_res = utils.train_epoch(loaders['train'], model, criterion,
                                  optimizer)
    if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
        utils.bn_update(loaders['train'], model, use_half=True, use_cuda=True)
        test_res = utils.eval(loaders['test'],
                              model,
                              criterion,
                              use_half=True,
                              use_cuda=True)
    else:
        test_res = {'loss': None, 'accuracy': None}

    time_ep = time.time() - time_ep
    values = [
        epoch + 1, lr, train_res['loss'], train_res['accuracy'],
        test_res['loss'], test_res['accuracy'], time_ep
    ]
    table = tabulate.tabulate([values],
                              columns,
Code example #7
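# Excerpt: loads two trained checkpoints (model_1 from 'state_dict', model_2
# from 'swa_state_dict'), recalibrates their BN statistics, and takes the
# parameter-space difference vec_1 - vec_2, e.g. as a direction for
# interpolating between the two solutions. Assumes parameters_to_vector is
# torch.nn.utils.parameters_to_vector and that model_1, model_2, model_temp,
# and loaders are defined elsewhere in the script.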
criterion = F.cross_entropy
# optimizer = torch.optim.SGD(
#     model_temp.parameters(),
#     lr=args.lr_init,
#     momentum=args.momentum,
#     weight_decay=args.wd
# )

start_epoch = 0
if args.model1_resume is not None:
    print('Resume training from %s' % args.model1_resume)
    checkpoint = torch.load(args.model1_resume)
    start_epoch = checkpoint['epoch']
    model_1.load_state_dict(checkpoint['state_dict'])
    utils.bn_update(loaders['train'], model_1)
    print(utils.eval(loaders['train'], model_1, criterion))
vec_1 = parameters_to_vector(model_1.parameters())

if args.model2_resume is not None:
    print('Resume training from %s' % args.model2_resume)
    checkpoint = torch.load(args.model2_resume)
    start_epoch = checkpoint['epoch']
    model_2.load_state_dict(checkpoint['swa_state_dict'])
    model_temp.load_state_dict(checkpoint['state_dict'])
    utils.bn_update(loaders['train'], model_2)
    print(utils.eval(loaders['train'], model_2, criterion))
vec_2 = parameters_to_vector(model_2.parameters())

vec_inter = vec_1 - vec_2
# vec_inter_norm = torch.norm(vec_inter)
Code example #8
def main():
    script_dir = os.path.dirname(__file__)
    module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
    global msglogger

    # Parse arguments
    args = parser.get_parser().parse_args()
    if args.epochs is None:
        args.epochs = 90

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    msglogger = apputils.config_pylogger(
        os.path.join(script_dir, 'logging.conf'), args.name, args.output_dir)

    # Log various details about the execution environment.  It is sometimes useful
    # to refer to past experiment executions and this information may be useful.
    apputils.log_execution_env_state(args.compress,
                                     msglogger.logdir,
                                     gitroot=module_path)
    msglogger.debug("Distiller: %s", distiller.__version__)

    start_epoch = 0
    ending_epoch = args.epochs
    perf_scores_history = []

    if args.evaluate:
        args.deterministic = True
    if args.deterministic:
        # Experiment reproducibility is sometimes important.  Pete Warden expounded about this
        # in his blog: https://petewarden.com/2018/03/19/the-machine-learning-reproducibility-crisis/
        distiller.set_deterministic()  # Use a well-known seed, for repeatability of experiments
    else:
        # Turn on CUDNN benchmark mode for best performance. This is usually "safe" for image
        # classification models, as the input sizes don't change during the run
        # See here: https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3
        cudnn.benchmark = True

    if args.cpu or not torch.cuda.is_available():
        # Set GPU index to -1 if using CPU
        args.device = 'cpu'
        args.gpus = -1
    else:
        args.device = 'cuda'
        if args.gpus is not None:
            try:
                args.gpus = [int(s) for s in args.gpus.split(',')]
            except ValueError:
                raise ValueError(
                    'ERROR: Argument --gpus must be a comma-separated list of integers only'
                )
            available_gpus = torch.cuda.device_count()
            for dev_id in args.gpus:
                if dev_id >= available_gpus:
                    raise ValueError(
                        'ERROR: GPU device ID {0} requested, but only {1} devices available'
                        .format(dev_id, available_gpus))
            # Set default device in case the first one on the list != 0
            torch.cuda.set_device(args.gpus[0])

    # Infer the dataset from the model name
    args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet'
    args.num_classes = 10 if args.dataset == 'cifar10' else 1000

    # Create the model
    model = create_model(args.pretrained,
                         args.dataset,
                         args.arch,
                         parallel=not args.load_serialized,
                         device_ids=args.gpus)

    if args.swa:
        swa_model = create_model(args.pretrained,
                                 args.dataset,
                                 args.arch,
                                 parallel=not args.load_serialized,
                                 device_ids=args.gpus)
        swa_n = 0

    compression_scheduler = None
    # Create a couple of logging backends.  TensorBoardLogger writes log files in a format
    # that can be read by Google's Tensor Board.  PythonLogger writes to the Python logger.
    tflogger = TensorBoardLogger(msglogger.logdir)
    pylogger = PythonLogger(msglogger)

    # TODO(barrh): args.deprecated_resume is deprecated since v0.3.1
    if args.deprecated_resume:
        msglogger.warning(
            'The "--resume" flag is deprecated. Please use "--resume-from=YOUR_PATH" instead.'
        )
        if not args.reset_optimizer:
            msglogger.warning(
                'If you wish to also reset the optimizer, call with: --reset-optimizer'
            )
            args.reset_optimizer = True
        args.resumed_checkpoint_path = args.deprecated_resume

    # We can optionally resume from a checkpoint
    optimizer = None
    # TODO: resume from swa mode
    if args.resumed_checkpoint_path:
        if args.swa:
            model, swa_model, swa_n, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
                model,
                args.resumed_checkpoint_path,
                swa_model=swa_model,
                swa_n=swa_n,
                model_device=args.device)
        else:
            model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
                model, args.resumed_checkpoint_path, model_device=args.device)
    elif args.load_model_path:
        model = apputils.load_lean_checkpoint(model,
                                              args.load_model_path,
                                              model_device=args.device)
    if args.reset_optimizer:
        start_epoch = 0
        if optimizer is not None:
            optimizer = None
            msglogger.info(
                '\nreset_optimizer flag set: Overriding resumed optimizer and resetting epoch count to 0'
            )

    # Define loss function (criterion)
    criterion = nn.CrossEntropyLoss().to(args.device)

    if optimizer is None:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        msglogger.info('Optimizer Type: %s', type(optimizer))
        msglogger.info('Optimizer Args: %s', optimizer.defaults)

    # This sample application can be invoked to produce various summary reports.
    if args.summary:
        return summarize_model(model, args.dataset, which_summary=args.summary)

    activations_collectors = create_activation_stats_collectors(
        model, *args.activation_stats)

    # Load the datasets: the dataset to load is inferred from the model name passed
    # in args.arch.  The default dataset is ImageNet, but if args.arch contains the
    # substring "_cifar", then cifar10 is used.
    train_loader, val_loader, test_loader, _ = apputils.load_data(
        args.dataset, os.path.expanduser(args.data), args.batch_size,
        args.workers, args.validation_split, args.deterministic,
        args.effective_train_size, args.effective_valid_size,
        args.effective_test_size)
    msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d',
                   len(train_loader.sampler), len(val_loader.sampler),
                   len(test_loader.sampler))

    if args.sensitivity is not None:
        sensitivities = np.arange(args.sensitivity_range[0],
                                  args.sensitivity_range[1],
                                  args.sensitivity_range[2])
        return sensitivity_analysis(model, criterion, test_loader, pylogger,
                                    args, sensitivities)

    if args.evaluate:
        return evaluate_model(model, criterion, test_loader, pylogger,
                              activations_collectors, args,
                              compression_scheduler)

    if args.compress:
        # The main use-case for this sample application is CNN compression. Compression
        # requires a compression schedule configuration file in YAML.
        compression_scheduler = distiller.file_config(
            model, optimizer, args.compress, compression_scheduler,
            (start_epoch - 1) if args.resumed_checkpoint_path else None)
        # Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer)
        model.to(args.device)
    elif compression_scheduler is None:
        compression_scheduler = distiller.CompressionScheduler(model)

    if args.thinnify:
        #zeros_mask_dict = distiller.create_model_masks_dict(model)
        assert args.resumed_checkpoint_path is not None, \
            "You must use --resume-from to provide a checkpoint file to thinnify"
        distiller.remove_filters(model,
                                 compression_scheduler.zeros_mask_dict,
                                 args.arch,
                                 args.dataset,
                                 optimizer=None)
        apputils.save_checkpoint(0,
                                 args.arch,
                                 model,
                                 optimizer=None,
                                 scheduler=compression_scheduler,
                                 name="{}_thinned".format(
                                     args.resumed_checkpoint_path.replace(
                                         ".pth.tar", "")),
                                 dir=msglogger.logdir)
        print(
            "Note: your model may have collapsed to random inference, so you may want to fine-tune"
        )
        return

    if args.lr_find:
        lr_finder = distiller.LRFinder(model,
                                       optimizer,
                                       criterion,
                                       device=args.device)
        lr_finder.range_test(train_loader, end_lr=10, num_iter=100)
        lr_finder.plot()
        return

    if start_epoch >= ending_epoch:
        msglogger.error(
            'epoch count is too low, starting epoch is {} but total epochs set to {}'
            .format(start_epoch, ending_epoch))
        raise ValueError('Epochs parameter is too low. Nothing to do.')

    for epoch in range(start_epoch, ending_epoch):
        # This is the main training loop.
        msglogger.info('\n')

        if compression_scheduler:
            compression_scheduler.on_epoch_begin(
                epoch, metrics=(vloss if (epoch != start_epoch) else 10**6))

        # Train for one epoch
        with collectors_context(activations_collectors["train"]) as collectors:
            train(train_loader,
                  model,
                  criterion,
                  optimizer,
                  epoch,
                  compression_scheduler,
                  loggers=[tflogger, pylogger],
                  args=args)
            # distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger])
            # distiller.log_activation_statsitics(epoch, "train", loggers=[tflogger],
            #                                     collector=collectors["sparsity"])
            if args.masks_sparsity:
                msglogger.info(
                    distiller.masks_sparsity_tbl_summary(
                        model, compression_scheduler))

        # evaluate on validation set
        with collectors_context(activations_collectors["valid"]) as collectors:
            top1, top5, vloss = validate(val_loader, model, criterion,
                                         [pylogger], args, epoch)
            msglogger.info('==> Top1: %.3f    Top5: %.3f    Loss: %.3f\n',
                           top1, top5, vloss)
            distiller.log_activation_statsitics(
                epoch,
                "valid",
                loggers=[tflogger],
                collector=collectors["sparsity"])
            save_collectors_data(collectors, msglogger.logdir)

        stats = ('Performance/Validation/',
                 OrderedDict([('Loss', vloss), ('Top1', top1),
                              ('Top5', top5)]))

        if args.swa and (((epoch + 1) >= args.swa_start and
                          (epoch + 1 - args.swa_start) % args.swa_freq == 0)
                         or epoch == ending_epoch - 1):
            # Parentheses added: in the original, the trailing
            # `or epoch == ending_epoch - 1` bound last, so the final epoch
            # entered this block even when --swa was off and crashed on the
            # undefined swa_model.
            utils.moving_average(swa_model, model, 1. / (swa_n + 1))
            swa_n += 1
            utils.bn_update(train_loader, swa_model, args)
            swa_top1, swa_top5, swa_loss = validate(val_loader, swa_model,
                                                    criterion, [pylogger],
                                                    args, epoch)
            msglogger.info(
                '==> SWA_Top1: %.3f    SWA_Top5: %.3f    SWA_Loss: %.3f\n',
                swa_top1, swa_top5, swa_loss)
            swa_res = OrderedDict([('SWA_Loss', swa_loss),
                                   ('SWA_Top1', swa_top1),
                                   ('SWA_Top5', swa_top5)])
            stats[1].update(swa_res)

        distiller.log_training_progress(stats,
                                        None,
                                        epoch,
                                        steps_completed=0,
                                        total_steps=1,
                                        log_freq=1,
                                        loggers=[tflogger])

        if compression_scheduler:
            compression_scheduler.on_epoch_end(epoch, optimizer)

        # Update the list of top scores achieved so far, and save the checkpoint
        update_training_scores_history(perf_scores_history, model, top1, top5,
                                       epoch, args.num_best_scores)
        is_best = epoch == perf_scores_history[0].epoch
        checkpoint_extras = {
            'current_top1': top1,
            'best_top1': perf_scores_history[0].top1,
            'best_epoch': perf_scores_history[0].epoch
        }
        if args.swa:
            apputils.save_checkpoint(epoch,
                                     args.arch,
                                     model,
                                     swa_model,
                                     swa_n,
                                     optimizer=optimizer,
                                     scheduler=compression_scheduler,
                                     extras=checkpoint_extras,
                                     is_best=is_best,
                                     name=args.name,
                                     dir=msglogger.logdir)
        else:
            apputils.save_checkpoint(epoch,
                                     args.arch,
                                     model,
                                     optimizer=optimizer,
                                     scheduler=compression_scheduler,
                                     extras=checkpoint_extras,
                                     is_best=is_best,
                                     name=args.name,
                                     dir=msglogger.logdir)
    # Finally run results on the test set
    test(test_loader,
         model,
         criterion, [pylogger],
         activations_collectors,
         args=args)
    if args.swa:
        test(test_loader,
             swa_model,
             criterion, [pylogger],
             activations_collectors,
             args=args)
Code example #9
File: nas_evaluate.py (project: lilujunai/MacroNAS)
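# Assumed context: import os, time; import numpy as np; import torch;
# from thop import profile; from torchsummary import summary; plus the
# project's utils module. Note that summary() below uses a hard-coded
# (3, 32, 32) input, while profile() uses args['input_channels'] and
# args['img_size'].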
def run_evaluation(model,
                   ensemble_model,
                   data_loaders,
                   args,
                   save_model='',
                   load_model=''):

    all_values = {}
    device = 'cuda'

    utils.setup_torch(args['seed'])

    inputs = torch.randn(
        (1, args['input_channels'], args['img_size'], args['img_size']))
    total_ops, total_params = profile(model, (inputs, ), verbose=True)
    all_values['MMACs'] = np.round(total_ops / (1000.0**2), 2)
    all_values['Params'] = int(total_params)
    print(all_values)

    start = time.time()
    model = model.to(device)
    ensemble_model = ensemble_model.to(device)
    print('models to device', time.time() - start)

    if len(load_model) > 0:
        model.load_state_dict(torch.load(os.path.join(args['dir'],
                                                      load_model)))

    ################################################

    summary(model, (3, 32, 32), batch_size=args['batch_size'], device='cuda')

    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args['lr_init'],
                                momentum=0.9,
                                weight_decay=1e-4)

    lrs = []
    n_models = 0

    all_values['epoch'] = []
    all_values['overall_time'] = []
    all_values['lr'] = []

    all_values['tr_loss'] = []
    all_values['tr_acc'] = []

    all_values['val_loss_single'] = []
    all_values['val_acc_single'] = []
    all_values['val_loss_ensemble'] = []
    all_values['val_acc_ensemble'] = []

    all_values['test_loss_single'] = []
    all_values['test_acc_single'] = []
    all_values['test_loss_ensemble'] = []
    all_values['test_acc_ensemble'] = []

    time_start = time.time()

    for epoch in range(args['epochs']):
        time_ep = time.time()

        lr = utils.get_cyclic_lr(epoch, lrs, args['lr_init'],
                                 args['lr_start_cycle'], args['cycle_period'])
        #print ('lr=%.3f' % lr)
        utils.set_learning_rate(optimizer, lr)
        lrs.append(lr)

        train_res = utils.train_epoch(device, data_loaders['train'], model,
                                      criterion, optimizer,
                                      args['num_samples_train'])

        values = [epoch + 1, lr, train_res['loss'], train_res['accuracy']]

        if (epoch + 1) >= args['lr_start_cycle'] and (
                epoch + 1) % args['cycle_period'] == 0:

            all_values['epoch'].append(epoch + 1)
            all_values['lr'].append(lr)

            all_values['tr_loss'].append(train_res['loss'])
            all_values['tr_acc'].append(train_res['accuracy'])

            val_res = utils.evaluate(device, data_loaders['val'], model,
                                     criterion, args['num_samples_val'])
            test_res = utils.evaluate(device, data_loaders['test'], model,
                                      criterion, args['num_samples_test'])

            all_values['val_loss_single'].append(val_res['loss'])
            all_values['val_acc_single'].append(val_res['accuracy'])
            all_values['test_loss_single'].append(test_res['loss'])
            all_values['test_acc_single'].append(test_res['accuracy'])

            utils.moving_average_ensemble(ensemble_model, model,
                                          1.0 / (n_models + 1))
            utils.bn_update(device, data_loaders['train_for_bn_recalc'],
                            ensemble_model)
            n_models += 1

            val_res = utils.evaluate(device, data_loaders['val'],
                                     ensemble_model, criterion,
                                     args['num_samples_val'])
            test_res = utils.evaluate(device, data_loaders['test'],
                                      ensemble_model, criterion,
                                      args['num_samples_test'])

            all_values['val_loss_ensemble'].append(val_res['loss'])
            all_values['val_acc_ensemble'].append(val_res['accuracy'])
            all_values['test_loss_ensemble'].append(test_res['loss'])
            all_values['test_acc_ensemble'].append(test_res['accuracy'])

            overall_training_time = time.time() - time_start
            all_values['overall_time'].append(overall_training_time)

        #print (epoch, 'epoch_time', time.time() - time_ep)

        overall_training_time = time.time() - time_start
        #print ('overall time', overall_training_time)

        #print (all_values)

    if len(save_model) > 0:
        torch.save(ensemble_model.state_dict(),
                   os.path.join(args['dir'], save_model + '_ensemble'))
        torch.save(model.state_dict(), os.path.join(args['dir'], save_model))

    return all_values
Code example #10
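# Trains a one-shot supernet with a cyclic learning rate and maintains an SWA
# ensemble from the end-of-cycle snapshots. Assumed context: import time;
# import torch; import tabulate; from torchsummary import summary; plus the
# project's utils and models modules.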
def train_oneshot_model(args,
                        data_loaders,
                        n_cells,
                        n_choices,
                        put_downsampling=None):
    # None sentinel avoids a shared mutable default argument
    if put_downsampling is None:
        put_downsampling = []

    num_samples = utils.get_number_of_samples(args.dataset)

    device = 'cuda'

    utils.setup_torch(args.seed)

    print('Initializing model...')

    #Create a supernet skeleton (include all cell types for each position);
    #[1] * n_choices generalizes the original hard-coded [1, 1, 1]
    propagate_weights = [[1] * n_choices for _ in range(n_cells)]
    model_class = models.Supernet  # equivalent to getattr(models, 'Supernet')

    #Create the supernet model  and its SWA ensemble version
    model = model_class(num_classes=utils.get_number_of_classes(args.dataset),
                        propagate=propagate_weights,
                        training=True,
                        n_choices=n_choices,
                        put_downsampling=put_downsampling).to(device)
    ensemble_model = model_class(num_classes=utils.get_number_of_classes(
        args.dataset),
                                 propagate=propagate_weights,
                                 training=True,
                                 n_choices=n_choices,
                                 put_downsampling=put_downsampling).to(device)

    #These summaries are for verification purposes only.
    #However, removing them changes the results: they consume random numbers
    #from the same generators that are used during propagation.
    summary(model, (3, 32, 32), batch_size=args.batch_size, device='cuda')
    summary(ensemble_model, (3, 32, 32),
            batch_size=args.batch_size,
            device='cuda')

    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr_init,
                                momentum=0.9,
                                weight_decay=1e-4)

    start_epoch = 0

    columns = [
        'epoch time', 'overall training time', 'epoch', 'lr', 'train_loss',
        'train_acc', 'val_loss', 'val_acc', 'test_loss', 'test_acc'
    ]

    lrs = []
    n_models = 0

    all_values = {}
    all_values['epoch'] = []
    all_values['lr'] = []

    all_values['tr_loss'] = []
    all_values['tr_acc'] = []

    all_values['val_loss'] = []
    all_values['val_acc'] = []
    all_values['test_loss'] = []
    all_values['test_acc'] = []

    print('Start training...')

    time_start = time.time()
    for epoch in range(start_epoch, args.epochs):
        time_ep = time.time()

        #lr = utils.get_cosine_annealing_lr(epoch, args.lr_init, args.epochs)
        lr = utils.get_cyclic_lr(epoch, lrs, args.lr_init, args.lr_start_cycle,
                                 args.cycle_period)
        utils.set_learning_rate(optimizer, lr)
        lrs.append(lr)

        train_res = utils.train_epoch(device, data_loaders['train'], model,
                                      criterion, optimizer,
                                      num_samples['train'])

        values = [epoch + 1, lr, train_res['loss'], train_res['accuracy']]

        if (epoch + 1) >= args.lr_start_cycle and (epoch +
                                                   1) % args.cycle_period == 0:

            all_values['epoch'].append(epoch + 1)
            all_values['lr'].append(lr)

            all_values['tr_loss'].append(train_res['loss'])
            all_values['tr_acc'].append(train_res['accuracy'])

            val_res = utils.evaluate(device, data_loaders['val'], model,
                                     criterion, num_samples['val'])
            test_res = utils.evaluate(device, data_loaders['test'], model,
                                      criterion, num_samples['test'])

            all_values['val_loss'].append(val_res['loss'])
            all_values['val_acc'].append(val_res['accuracy'])
            all_values['test_loss'].append(test_res['loss'])
            all_values['test_acc'].append(test_res['accuracy'])
            values += [
                val_res['loss'], val_res['accuracy'], test_res['loss'],
                test_res['accuracy']
            ]

            utils.moving_average_ensemble(ensemble_model, model,
                                          1.0 / (n_models + 1))
            utils.bn_update(device, data_loaders['train'], ensemble_model)
            n_models += 1

            print(all_values)

        overall_training_time = time.time() - time_start
        values = [time.time() - time_ep, overall_training_time] + values
        table = tabulate.tabulate([values],
                                  columns,
                                  tablefmt='simple',
                                  floatfmt='8.4f')
        print(table)

    print('Training finished. Saving final nets...')
    utils.save_result(all_values, args.dir, 'model_supernet')

    torch.save(model.state_dict(), args.dir + '/supernet.pth')
    torch.save(ensemble_model.state_dict(), args.dir + '/supernet_swa.pth')
Code example #11
File: train.py (project: lonestar686/swa)
def main():

    ds = getattr(torchvision.datasets, args.dataset)
    path = os.path.join(args.data_path, args.dataset.lower())
    train_set = ds(path,
                   train=True,
                   download=True,
                   transform=model_cfg.transform_train)
    test_set = ds(path,
                  train=False,
                  download=True,
                  transform=model_cfg.transform_test)
    loaders = {
        'train':
        torch.utils.data.DataLoader(train_set,
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    num_workers=args.num_workers,
                                    pin_memory=True),
        'test':
        torch.utils.data.DataLoader(test_set,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=args.num_workers,
                                    pin_memory=True)
    }
    num_classes = len(train_set.classes)  # max(train_set.train_labels) + 1
    print(num_classes)

    print('Preparing model')
    model = model_cfg.base(*model_cfg.args,
                           num_classes=num_classes,
                           **model_cfg.kwargs)
    model.cuda()

    if args.swa:
        print('SWA training')
        swa_model = model_cfg.base(*model_cfg.args,
                                   num_classes=num_classes,
                                   **model_cfg.kwargs)
        swa_model.cuda()
        swa_n = 0
    else:
        print('SGD training')

    def schedule(epoch):
        t = epoch / (args.swa_start if args.swa else args.epochs)
        lr_ratio = args.swa_lr / args.lr_init if args.swa else 0.01
        if t <= 0.5:
            factor = 1.0
        elif t <= 0.9:
            factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
        else:
            factor = lr_ratio
        return args.lr_init * factor

    criterion = F.cross_entropy
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr_init,
                                momentum=args.momentum,
                                weight_decay=args.wd)

    start_epoch = 0
    if args.resume is not None:
        print('Resume training from %s' % args.resume)
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if args.swa:
            swa_state_dict = checkpoint['swa_state_dict']
            if swa_state_dict is not None:
                swa_model.load_state_dict(swa_state_dict)
            swa_n_ckpt = checkpoint['swa_n']
            if swa_n_ckpt is not None:
                swa_n = swa_n_ckpt

    columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc', 'time']
    if args.swa:
        columns = columns[:-1] + ['swa_te_loss', 'swa_te_acc'] + columns[-1:]
        swa_res = {'loss': None, 'accuracy': None}

    utils.save_checkpoint(
        args.dir,
        start_epoch,
        state_dict=model.state_dict(),
        swa_state_dict=swa_model.state_dict() if args.swa else None,
        swa_n=swa_n if args.swa else None,
        optimizer=optimizer.state_dict())

    for epoch in range(start_epoch, args.epochs):
        time_ep = time.time()

        lr = schedule(epoch)
        utils.adjust_learning_rate(optimizer, lr)
        train_res = utils.train_epoch(loaders['train'], model, criterion,
                                      optimizer)
        if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
            test_res = utils.eval(loaders['test'], model, criterion)
        else:
            test_res = {'loss': None, 'accuracy': None}

        if args.swa and (epoch + 1) >= args.swa_start and (
                epoch + 1 - args.swa_start) % args.swa_c_epochs == 0:
            utils.moving_average(swa_model, model, 1.0 / (swa_n + 1))
            swa_n += 1
            if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
                utils.bn_update(loaders['train'], swa_model)
                swa_res = utils.eval(loaders['test'], swa_model, criterion)
            else:
                swa_res = {'loss': None, 'accuracy': None}

        if (epoch + 1) % args.save_freq == 0:
            utils.save_checkpoint(
                args.dir,
                epoch + 1,
                state_dict=model.state_dict(),
                swa_state_dict=swa_model.state_dict() if args.swa else None,
                swa_n=swa_n if args.swa else None,
                optimizer=optimizer.state_dict())

        time_ep = time.time() - time_ep
        values = [
            epoch + 1, lr, train_res['loss'], train_res['accuracy'],
            test_res['loss'], test_res['accuracy'], time_ep
        ]
        if args.swa:
            values = values[:-1] + [swa_res['loss'], swa_res['accuracy']
                                    ] + values[-1:]
        table = tabulate.tabulate([values],
                                  columns,
                                  tablefmt='simple',
                                  floatfmt='8.4f')
        if epoch % 40 == 0:
            table = table.split('\n')
            table = '\n'.join([table[1]] + table)
        else:
            table = table.split('\n')[2]
        print(table)

    if args.epochs % args.save_freq != 0:
        utils.save_checkpoint(
            args.dir,
            args.epochs,
            state_dict=model.state_dict(),
            swa_state_dict=swa_model.state_dict() if args.swa else None,
            swa_n=swa_n if args.swa else None,
            optimizer=optimizer.state_dict())
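
The snippets above all revolve around the same two SWA helpers from their respective utils modules, whose implementations are not shown here. Below is a minimal sketch of what they plausibly look like, modeled on the reference SWA code (github.com/timgaripov/swa); the exact names, signatures, and the momentum=None trick are assumptions about these projects (their call signatures vary slightly across the examples), not their verified code.

import torch


def moving_average(avg_model, model, alpha):
    # Running average of parameters: avg <- (1 - alpha) * avg + alpha * new.
    # Called with alpha = 1.0 / (n_models + 1), as in the examples above, this
    # keeps avg_model equal to the plain mean of all snapshots averaged so far.
    for p_avg, p in zip(avg_model.parameters(), model.parameters()):
        p_avg.data.mul_(1.0 - alpha).add_(p.data, alpha=alpha)


def bn_update(loader, model, device='cuda'):
    # Averaged weights invalidate the BatchNorm running statistics, so they
    # are recomputed with a single forward pass over the training data.
    was_training = model.training
    model.train()  # BN layers only update running stats in train mode
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.reset_running_stats()
            module.momentum = None  # None = exact cumulative average
    with torch.no_grad():
        for inputs, _ in loader:
            model(inputs.to(device))
    model.train(was_training)

With these definitions, the recurring pattern moving_average(swa_model, model, 1.0 / (swa_n + 1)); swa_n += 1; bn_update(loaders['train'], swa_model) maintains an equal-weight average of the per-cycle or per-epoch snapshots and keeps the averaged model's BatchNorm statistics consistent with its weights before evaluation.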