Example #1
def main_test():
    global args, best_prec1, dtype
    best_prec1 = 0
    args = parser.parse_args()
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save == '':
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'),
                  resume=args.resume != '')
    results_path = os.path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
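        # cudnn.benchmark autotunes convolution algorithms, which helps
        # when input shapes stay fixed across iterations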
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    logging.info("creating model %s", args.model)
    model = models.__dict__[args.model]
    model_config = {'input_size': args.input_size, 'dataset': args.dataset}

    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    return model
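
For reference, torch_dtypes used above is assumed to be a plain name-to-dtype mapping along these lines (a minimal sketch; the actual definition in the source repository may differ):

import torch

# Hypothetical mapping from the --dtype argument string to a torch dtype.
torch_dtypes = {
    'float': torch.float32,
    'double': torch.float64,
    'half': torch.float16,
}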
Example #2
def main():
    global args, best_prec1, dtype
    best_prec1 = 0
    args = parser.parse_args()
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save == '':
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'),
                  resume=args.resume != '')
    results_path = os.path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    logging.info("creating model %s", args.model)
    model = models.__dict__[args.model]
    model_config = {'input_size': args.input_size, 'dataset': args.dataset}

    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)", args.evaluate,
                     checkpoint['epoch'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(checkpoint_file,
                                           'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            args.start_epoch = checkpoint['epoch'] - 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info("loaded checkpoint '%s' (epoch %s)", checkpoint_file,
                         checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # Data loading code
    default_transform = {
        'train':
        get_transform(args.dataset, input_size=args.input_size, augment=True),
        'eval':
        get_transform(args.dataset, input_size=args.input_size, augment=False)
    }
    transform = getattr(model, 'input_transform', default_transform)
    regime = getattr(model, 'regime', [{
        'epoch': 0,
        'optimizer': args.optimizer,
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }])
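    # Each regime entry takes effect at its 'epoch' and updates the
    # optimizer settings (optimizer type, lr, momentum, ...) from then on.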

    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    val_data = get_dataset(args.dataset, 'val', transform['eval'])
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    train_data = get_dataset(args.dataset, 'train', transform['train'])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    optimizer = OptimRegime(model.parameters(), regime)
    logging.info('training regime: %s', regime)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train_loss, train_prec1, train_prec5 = train(train_loader, model,
                                                     criterion, epoch,
                                                     optimizer)

        # evaluate on validation set
        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion,
                                                  epoch)

        # remember best prec@1 and save checkpoint
        is_best = val_prec1 > best_prec1
        best_prec1 = max(val_prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'config': args.model_config,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'regime': regime
            },
            is_best,
            path=save_path)
        logging.info('\n Epoch: {0}\t'
                     'Training Loss {train_loss:.4f} \t'
                     'Training Prec@1 {train_prec1:.3f} \t'
                     'Training Prec@5 {train_prec5:.3f} \t'
                     'Validation Loss {val_loss:.4f} \t'
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \n'.format(
                         epoch + 1,
                         train_loss=train_loss,
                         val_loss=val_loss,
                         train_prec1=train_prec1,
                         val_prec1=val_prec1,
                         train_prec5=train_prec5,
                         val_prec5=val_prec5))

        results.add(epoch=epoch + 1,
                    train_loss=train_loss,
                    val_loss=val_loss,
                    train_error1=100 - train_prec1,
                    val_error1=100 - val_prec1,
                    train_error5=100 - train_prec5,
                    val_error5=100 - val_prec5)
        results.plot(x='epoch',
                     y=['train_loss', 'val_loss'],
                     legend=['training', 'validation'],
                     title='Loss',
                     ylabel='loss')
        results.plot(x='epoch',
                     y=['train_error1', 'val_error1'],
                     legend=['training', 'validation'],
                     title='Error@1',
                     ylabel='error %')
        results.plot(x='epoch',
                     y=['train_error5', 'val_error5'],
                     legend=['training', 'validation'],
                     title='Error@5',
                     ylabel='error %')
        results.save()
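
The Prec@1 / Prec@5 values logged above come from a top-k accuracy helper; a minimal sketch of that computation, assuming (N, C) logits and integer class targets (the repository's own helper may differ in detail):

import torch

def topk_accuracy(output, target, topk=(1, 5)):
    # output: (N, C) logits, target: (N,) class indices
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1)      # (N, maxk) top predictions
    correct = pred.eq(target.view(-1, 1))   # (N, maxk) hit mask
    # percentage of samples whose true class appears in the top k
    return [correct[:, :k].any(dim=1).float().mean().item() * 100
            for k in topk]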
Example #3
def main():
    global args, best_prec1, dtype
    best_prec1 = 0
    args = parser.parse_args()
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save == '':
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_init,
                                world_size=args.world_size,
                                rank=args.local_rank)
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            args.device_ids = [args.local_rank]

    if not os.path.exists(save_path) and not (args.distributed
                                              and args.local_rank > 0):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=args.distributed and args.local_rank > 0)

    results_path = os.path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    model = models.__dict__[args.model]
    model_config = {'dataset': args.dataset}

    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)", args.evaluate,
                     checkpoint['epoch'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(checkpoint_file,
                                           'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            if args.start_epoch < 0:  # not explicitly set
                args.start_epoch = checkpoint['epoch'] - 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info("loaded checkpoint '%s' (epoch %s)", checkpoint_file,
                         checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration
    optim_regime = getattr(model, 'regime', [{
        'epoch': 0,
        'optimizer': args.optimizer,
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }])

    optimizer = optim_regime if isinstance(optim_regime, OptimRegime) \
        else OptimRegime(model, optim_regime, use_float_copy='half' in args.dtype)

    trainer = Trainer(model,
                      criterion,
                      optimizer,
                      device_ids=args.device_ids,
                      device=args.device,
                      dtype=dtype,
                      distributed=args.distributed,
                      local_rank=args.local_rank,
                      grad_clip=args.grad_clip,
                      print_freq=args.print_freq,
                      adapt_grad_norm=args.adapt_grad_norm)

    # Evaluation Data loading code
    if args.eval_batch_size <= 0:
        args.eval_batch_size = args.batch_size
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={
                              'datasets_path': args.datasets_dir,
                              'name': args.dataset,
                              'split': 'val',
                              'augment': False,
                              'input_size': args.input_size,
                              'batch_size': args.eval_batch_size,
                              'shuffle': False,
                              'num_workers': args.workers,
                              'pin_memory': True,
                              'drop_last': False
                          })

    if args.evaluate:
        results = trainer.validate(val_data.get_loader())
        logging.info(results)
        return

    # Training Data loading code
    train_data = DataRegime(getattr(model, 'data_regime', None),
                            defaults={
                                'datasets_path': args.datasets_dir,
                                'name': args.dataset,
                                'split': 'train',
                                'augment': True,
                                'input_size': args.input_size,
                                'batch_size': args.batch_size,
                                'shuffle': True,
                                'num_workers': args.workers,
                                'pin_memory': True,
                                'drop_last': True,
                                'distributed': args.distributed,
                                'duplicates': args.duplicates,
                                'cutout': {
                                    'holes': 1,
                                    'length': 16
                                } if args.cutout else None
                            })

    logging.info('optimization regime: %s', optim_regime)
    args.start_epoch = max(args.start_epoch, 0)
    trainer.training_steps = args.start_epoch * len(train_data)
    for epoch in range(args.start_epoch, args.epochs):
        trainer.epoch = epoch
        train_data.set_epoch(epoch)
        val_data.set_epoch(epoch)
        logging.info('\nStarting Epoch: {0}\n'.format(epoch + 1))

        # train for one epoch
        train_results = trainer.train(train_data.get_loader(),
                                      duplicates=train_data.get('duplicates'),
                                      chunk_batch=args.chunk_batch)

        # evaluate on validation set
        val_results = trainer.validate(val_data.get_loader())

        if args.distributed and args.local_rank > 0:
            continue

        # remember best prec@1 and save checkpoint
        is_best = val_results['prec1'] > best_prec1
        best_prec1 = max(val_results['prec1'], best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'config': args.model_config,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1
            },
            is_best,
            path=save_path)

        logging.info('\nResults - Epoch: {0}\n'
                     'Training Loss {train[loss]:.4f} \t'
                     'Training Prec@1 {train[prec1]:.3f} \t'
                     'Training Prec@5 {train[prec5]:.3f} \t'
                     'Validation Loss {val[loss]:.4f} \t'
                     'Validation Prec@1 {val[prec1]:.3f} \t'
                     'Validation Prec@5 {val[prec5]:.3f} \t\n'.format(
                         epoch + 1, train=train_results, val=val_results))

        values = dict(epoch=epoch + 1, steps=trainer.training_steps)
        values.update({'training ' + k: v for k, v in train_results.items()})
        values.update({'validation ' + k: v for k, v in val_results.items()})
        results.add(**values)

        results.plot(x='epoch',
                     y=['training loss', 'validation loss'],
                     legend=['training', 'validation'],
                     title='Loss',
                     ylabel='loss')
        results.plot(x='epoch',
                     y=['training error1', 'validation error1'],
                     legend=['training', 'validation'],
                     title='Error@1',
                     ylabel='error %')
        results.plot(x='epoch',
                     y=['training error5', 'validation error5'],
                     legend=['training', 'validation'],
                     title='Error@5',
                     ylabel='error %')
        if 'grad' in train_results.keys():
            results.plot(x='epoch',
                         y=['training grad'],
                         legend=['gradient L2 norm'],
                         title='Gradient Norm',
                         ylabel='value')
        results.save()
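
Example #3 passes smooth_eps to the criterion when label smoothing is enabled; a minimal sketch of label-smoothed cross-entropy under the usual formulation (assumed semantics; the repository's CrossEntropyLoss may implement it differently):

import torch.nn.functional as F

def smoothed_cross_entropy(logits, target, smooth_eps=0.1):
    # Blend the one-hot target with a uniform distribution over classes.
    log_probs = F.log_softmax(logits, dim=-1)
    nll = -log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
    uniform = -log_probs.mean(dim=-1)
    return ((1.0 - smooth_eps) * nll + smooth_eps * uniform).mean()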
Example #4
def main_worker(args):
    global best_prec1, dtype
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save == '':
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if not os.path.exists(save_path) and not (args.distributed
                                              and args.local_rank > 0):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=args.distributed and args.local_rank > 0)

    results_path = os.path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    if not os.path.isfile(args.evaluate):
        parser.error('invalid checkpoint: {}'.format(args.evaluate))
    checkpoint = torch.load(args.evaluate, map_location="cpu")
    # Override configuration with checkpoint info
    args.model = checkpoint.get('model', args.model)
    args.model_config = checkpoint.get('config', args.model_config)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    # create model
    model = models.__dict__[args.model]
    model_config = {'dataset': args.dataset}

    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # load checkpoint

    if args.pretrained:
        # download pretrained weights for the chosen architecture
        state_dict = load_state_dict_from_url(model_urls[args.model],
                                              progress=True)
        model.load_state_dict(state_dict)
        # model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)  # from distiller
    else:
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)", args.evaluate,
                     checkpoint['epoch'])

    if args.absorb_bn:
        search_absorb_bn(model, remove_bn=not args.calibrate_bn, verbose=True)

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    trainer = Trainer(model,
                      criterion,
                      device_ids=args.device_ids,
                      device=args.device,
                      dtype=dtype,
                      mixup=args.mixup,
                      print_freq=args.print_freq)

    # Evaluation Data loading code
    val_data = DataRegime(None,
                          defaults={
                              'datasets_path': args.datasets_dir,
                              'name': args.dataset,
                              'split': 'val',
                              'augment': args.augment,
                              'input_size': args.input_size,
                              'batch_size': args.batch_size,
                              'shuffle': False,
                              'duplicates': args.duplicates,
                              'autoaugment': args.autoaugment,
                              'cutout': {
                                  'holes': 1,
                                  'length': 16
                              } if args.cutout else None,
                              'num_workers': args.workers,
                              'pin_memory': True,
                              'drop_last': False
                          })

    if args.calibrate_bn:
        train_data = DataRegime(None,
                                defaults={
                                    'datasets_path': args.datasets_dir,
                                    'name': args.dataset,
                                    'split': 'train',
                                    'augment': True,
                                    'input_size': args.input_size,
                                    'batch_size': args.batch_size,
                                    'shuffle': True,
                                    'num_workers': args.workers,
                                    'pin_memory': True,
                                    'drop_last': False
                                })
        trainer.calibrate_bn(train_data.get_loader(), num_steps=200)
    results = trainer.validate(val_data.get_loader(),
                               average_output=args.avg_out)
    logging.info(results)
    print(results)
    return results
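
Example #4 calls trainer.calibrate_bn to refresh batch-norm statistics; a minimal sketch of what such a calibration pass might do, assuming it only re-estimates running statistics (the actual Trainer method may do more):

import torch
import torch.nn as nn

@torch.no_grad()
def calibrate_bn(model, loader, num_steps=200, device='cuda'):
    # Reset running stats, then re-estimate them with forward passes only.
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()
            m.momentum = None  # None means cumulative moving average
    model.train()  # BN updates running stats only in train mode
    for step, (inputs, _) in enumerate(loader):
        if step >= num_steps:
            break
        model(inputs.to(device))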
Example #5
def main_worker(args):
    global best_prec1, dtype
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save == '':
        args.save = time_stamp
    save_path = path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_init,
                                world_size=args.world_size,
                                rank=args.local_rank)
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            args.device_ids = [args.local_rank]

    # if not (args.distributed and args.local_rank > 0):
    if not path.exists(save_path):
        makedirs(save_path)
    dump_args(args, path.join(save_path, 'args.txt'))

    setup_logging(path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=False)

    results_path = path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # All parameters to the model should be passed via this dict.
    model_config = {
        'dataset': args.dataset,
        'dp_type': args.dropout_type,
        'dp_percentage': args.dropout_perc,
        'dropout': args.drop_rate,
        'device': args.device
    }

    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    # create Resnet model
    model = resnet(**model_config)
    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # # optionally resume from a checkpoint
    # if args.evaluate:
    #     if not path.isfile(args.evaluate):
    #         parser.error('invalid checkpoint: {}'.format(args.evaluate))
    #     checkpoint = torch.load(args.evaluate, map_location="cpu")
    #     # Overrride configuration with checkpoint info
    #     args.model = checkpoint.get('model', args.model)
    #     args.model_config = checkpoint.get('config', args.model_config)
    #     # load checkpoint
    #     model.load_state_dict(checkpoint['state_dict'])
    #     logging.info("loaded checkpoint '%s' (epoch %s)",
    #                  args.evaluate, checkpoint['epoch'])
    #
    # if args.resume:
    #     checkpoint_file = args.resume
    #     if path.isdir(checkpoint_file):
    #         results.load(path.join(checkpoint_file, 'results.csv'))
    #         checkpoint_file = path.join(
    #             checkpoint_file, 'model_best.pth.tar')
    #     if path.isfile(checkpoint_file):
    #         logging.info("loading checkpoint '%s'", args.resume)
    #         checkpoint = torch.load(checkpoint_file, map_location="cpu")
    #         if args.start_epoch < 0:  # not explicitly set
    #             args.start_epoch = checkpoint['epoch']
    #         best_prec1 = checkpoint['best_prec1']
    #         model.load_state_dict(checkpoint['state_dict'])
    #         optim_state_dict = checkpoint.get('optim_state_dict', None)
    #         logging.info("loaded checkpoint '%s' (epoch %s)",
    #                      checkpoint_file, checkpoint['epoch'])
    #     else:
    #         logging.error("no checkpoint found at '%s'", args.resume)
    # else:
    #     optim_state_dict = None

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration
    optim_regime = getattr(model, 'regime', [{
        'epoch': 0,
        'optimizer': args.optimizer,
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }])

    optimizer = optim_regime if isinstance(optim_regime, OptimRegime) \
        else OptimRegime(model, optim_regime, use_float_copy='half' in args.dtype)

    # if optim_state_dict is not None:
    #     optimizer.load_state_dict(optim_state_dict)

    trainer = Trainer(model,
                      criterion,
                      optimizer,
                      device_ids=args.device_ids,
                      device=args.device,
                      dtype=dtype,
                      distributed=args.distributed,
                      local_rank=args.local_rank,
                      mixup=args.mixup,
                      loss_scale=args.loss_scale,
                      grad_clip=args.grad_clip,
                      print_freq=args.print_freq,
                      adapt_grad_norm=args.adapt_grad_norm)
    if args.tensorwatch:
        trainer.set_watcher(filename=path.abspath(
            path.join(save_path, 'tensorwatch.log')),
                            port=args.tensorwatch_port)

    # Evaluation Data loading code
    if args.eval_batch_size <= 0:
        args.eval_batch_size = args.batch_size
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={
                              'datasets_path': args.datasets_dir,
                              'name': args.dataset,
                              'split': 'val',
                              'augment': False,
                              'input_size': args.input_size,
                              'batch_size': args.eval_batch_size,
                              'shuffle': False,
                              'num_workers': args.workers,
                              'pin_memory': True,
                              'drop_last': False
                          })

    if args.evaluate:
        results = trainer.validate(val_data.get_loader())
        logging.info(results)
        return

    # Training Data loading code
    train_data_defaults = {
        'datasets_path': args.datasets_dir,
        'name': args.dataset,
        'split': 'train',
        'augment': True,
        'input_size': args.input_size,
        'batch_size': args.batch_size,
        'shuffle': True,
        'num_workers': args.workers,
        'pin_memory': True,
        'drop_last': True,
        'distributed': args.distributed,
        'duplicates': args.duplicates,
        'autoaugment': args.autoaugment,
        'cutout': {
            'holes': 1,
            'length': 16
        } if args.cutout else None
    }

    if hasattr(model, 'sampled_data_regime'):
        sampled_data_regime = model.sampled_data_regime
        probs, regime_configs = zip(*sampled_data_regime)
        regimes = []
        for config in regime_configs:
            defaults = {**train_data_defaults}
            defaults.update(config)
            regimes.append(DataRegime(None, defaults=defaults))
        train_data = SampledDataRegime(regimes, probs)
    else:
        train_data = DataRegime(getattr(model, 'data_regime', None),
                                defaults=train_data_defaults)

    logging.info('optimization regime: %s', optim_regime)
    args.start_epoch = max(args.start_epoch, 0)
    trainer.training_steps = args.start_epoch * len(train_data)
    for epoch in range(args.start_epoch, args.epochs):
        trainer.epoch = epoch
        train_data.set_epoch(epoch)
        val_data.set_epoch(epoch)
        logging.info('\nStarting Epoch: {0}\n'.format(epoch + 1))

        # train for one epoch
        train_results = trainer.train(train_data.get_loader(),
                                      chunk_batch=args.chunk_batch)

        # evaluate on validation set
        val_results = trainer.validate(val_data.get_loader())

        # # save weights heatmap
        # w = model._modules['layer3']._modules['5']._modules['conv2']._parameters['weight'].view(64, -1).cpu().detach().numpy()
        # heat_maps_dir = 'C:\\Users\\Pavel\\Desktop\\targeted_dropout_pytorch\\pics\\experiment_0'
        # plot = sns.heatmap(w, center=0)
        # name = str(datetime.now()).replace(':', '_').replace('-', '_').replace('.', '_').replace(' ', '_') + '.png'
        # plot.get_figure().savefig(path.join(heat_maps_dir, name))
        # plt.clf()

        if args.distributed and args.local_rank > 0:
            continue

        # remember best prec@1 and save checkpoint
        is_best = val_results['prec1'] > best_prec1
        best_prec1 = max(val_results['prec1'], best_prec1)

        if args.drop_optim_state:
            optim_state_dict = None
        else:
            optim_state_dict = optimizer.state_dict()

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'config': args.model_config,
                'state_dict': model.state_dict(),
                'optim_state_dict': optim_state_dict,
                'best_prec1': best_prec1
            },
            is_best,
            path=save_path,
            save_all=False)

        logging.info('\nResults - Epoch: {0}\n'
                     'Training Loss {train[loss]:.4f} \t'
                     'Training Prec@1 {train[prec1]:.3f} \t'
                     'Training Prec@5 {train[prec5]:.3f} \t'
                     'Validation Loss {val[loss]:.4f} \t'
                     'Validation Prec@1 {val[prec1]:.3f} \t'
                     'Validation Prec@5 {val[prec5]:.3f} \t\n'.format(
                         epoch + 1, train=train_results, val=val_results))

        values = dict(epoch=epoch + 1, steps=trainer.training_steps)
        values.update({'training ' + k: v for k, v in train_results.items()})
        values.update({'validation ' + k: v for k, v in val_results.items()})
        results.add(**values)

        results.plot(x='epoch',
                     y=['training loss', 'validation loss'],
                     legend=['training', 'validation'],
                     title='Loss',
                     ylabel='loss')
        results.plot(x='epoch',
                     y=['training error1', 'validation error1'],
                     legend=['training', 'validation'],
                     title='Error@1',
                     ylabel='error %')
        results.plot(x='epoch',
                     y=['training error5', 'validation error5'],
                     legend=['training', 'validation'],
                     title='Error@5',
                     ylabel='error %')
        if 'grad' in train_results.keys():
            results.plot(x='epoch',
                         y=['training grad'],
                         legend=['gradient L2 norm'],
                         title='Gradient Norm',
                         ylabel='value')
        results.save()
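
Several of these examples pass mixup=args.mixup to the Trainer; a minimal sketch of the mixup augmentation presumably applied per batch (a hypothetical helper, not the repository's exact implementation):

import numpy as np
import torch

def mixup_batch(inputs, targets, alpha=0.2):
    # Convex-combine each sample with a randomly paired one; the loss is
    # then lam * loss(targets) + (1 - lam) * loss(permuted targets).
    lam = float(np.random.beta(alpha, alpha))
    perm = torch.randperm(inputs.size(0))
    mixed = lam * inputs + (1.0 - lam) * inputs[perm]
    return mixed, targets, targets[perm], lam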
Example #6
def main_worker(args, ml_logger):
    global best_prec1, dtype
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_init,
                                world_size=args.world_size,
                                rank=args.local_rank)
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            args.device_ids = [args.local_rank]

    if not (args.distributed and args.local_rank > 0):
        if not path.exists(args.save_path):
            makedirs(args.save_path)
        export_args_namespace(args, path.join(args.save_path, 'config.json'))

    setup_logging(path.join(args.save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=args.distributed and args.local_rank > 0)

    results_path = path.join(args.save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", args.save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    model = models.__dict__[args.model]
    model_config = {'dataset': args.dataset}

    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    if args.resume:
        checkpoint_file = args.resume
        if path.isdir(checkpoint_file):
            results.load(path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = path.join(checkpoint_file, 'model_best.pth.tar')
        if path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file, map_location="cpu")
            if args.start_epoch < 0:  # not explicitly set
                args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optim_state_dict = checkpoint.get('optim_state_dict', None)
            logging.info("loaded checkpoint '%s' (epoch %s)", checkpoint_file,
                         checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)
    else:
        optim_state_dict = None

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)

    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration
    optim_regime = getattr(model, 'regime', [{
        'epoch': 0,
        'optimizer': args.optimizer,
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }])

    optimizer = optim_regime if isinstance(optim_regime, OptimRegime) \
        else OptimRegime(model, optim_regime, use_float_copy='half' in args.dtype)

    if optim_state_dict is not None:
        optimizer.load_state_dict(optim_state_dict)

    trainer = Trainer(
        model,
        criterion,
        optimizer,
        device_ids=args.device_ids,
        device=args.device,
        dtype=dtype,
        print_freq=args.print_freq,
        distributed=args.distributed,
        local_rank=args.local_rank,
        mixup=args.mixup,
        cutmix=args.cutmix,
        loss_scale=args.loss_scale,
        grad_clip=args.grad_clip,
        adapt_grad_norm=args.adapt_grad_norm,
    )

    # Evaluation Data loading code
    if args.eval_batch_size <= 0:
        args.eval_batch_size = args.batch_size
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={
                              'datasets_path': args.datasets_dir,
                              'name': args.dataset,
                              'split': 'val',
                              'augment': False,
                              'input_size': args.input_size,
                              'batch_size': args.eval_batch_size,
                              'shuffle': False,
                              'num_workers': args.workers,
                              'pin_memory': True,
                              'drop_last': False
                          })

    # Training Data loading code
    train_data_defaults = {
        'datasets_path': args.datasets_dir,
        'name': args.dataset,
        'split': 'train',
        'augment': True,
        'input_size': args.input_size,
        'batch_size': args.batch_size,
        'shuffle': True,
        'num_workers': args.workers,
        'pin_memory': True,
        'drop_last': True,
        'distributed': args.distributed,
        'duplicates': args.duplicates,
        'autoaugment': args.autoaugment,
        'cutout': {
            'holes': 1,
            'length': 16
        } if args.cutout else None
    }

    if hasattr(model, 'sampled_data_regime'):
        sampled_data_regime = model.sampled_data_regime
        probs, regime_configs = zip(*sampled_data_regime)
        regimes = []
        for config in regime_configs:
            defaults = {**train_data_defaults}
            defaults.update(config)
            regimes.append(DataRegime(None, defaults=defaults))
        train_data = SampledDataRegime(regimes, probs)
    else:
        train_data = DataRegime(getattr(model, 'data_regime', None),
                                defaults=train_data_defaults)

    logging.info('optimization regime: %s', optim_regime)
    logging.info('data regime: %s', train_data)
    args.start_epoch = max(args.start_epoch, 0)
    trainer.training_steps = args.start_epoch * len(train_data)

    if 'zeroBN' in model_config:  # hot start
        num_steps = int(len(train_data.get_loader()) * 0.5)
        trainer.train(train_data.get_loader(),
                      chunk_batch=args.chunk_batch,
                      num_steps=num_steps)

        for m in model.modules():
            if isinstance(m, ZeroBN):
                m.max_sparsity = args.max_sparsity
                m.max_cos_sim = args.max_cos_sim
                if args.preserve_cosine:
                    if args.layers_cos_sim1 in m.fullName:
                        m.preserve_cosine = args.preserve_cosine
                        m.cos_sim = args.cos_sim1
                    if args.layers_cos_sim2 in m.fullName:
                        m.preserve_cosine = args.preserve_cosine
                        m.cos_sim = args.cos_sim2
                    if args.layers_cos_sim3 in m.fullName:
                        m.preserve_cosine = args.preserve_cosine
                        m.cos_sim = args.cos_sim3

                if args.min_cos_sim:
                    if args.layers_min_cos_sim1 in m.fullName:
                        m.min_cos_sim = args.min_cos_sim
                        m.cos_sim_min = args.cos_sim_min1
                    if args.layers_min_cos_sim2 in m.fullName:
                        m.min_cos_sim = args.min_cos_sim
                        m.cos_sim_min = args.cos_sim_min2

    for epoch in range(args.start_epoch, args.epochs):
        trainer.epoch = epoch
        train_data.set_epoch(epoch)
        val_data.set_epoch(epoch)
        logging.info('\nStarting Epoch: {0}\n'.format(epoch + 1))

        if 'zeroBN' in model_config:
            trainer.collectStat(train_data.get_loader(),
                                num_steps=1,
                                prunRatio=args.stochasticPrunning,
                                cos_sim=args.cos_sim,
                                cos_sim_max=args.cos_sim_max)
            trainer.collectStat(train_data.get_loader(),
                                num_steps=1,
                                prunRatio=args.stochasticPrunning,
                                cos_sim=args.cos_sim,
                                cos_sim_max=args.cos_sim_max)
        # torch.cuda.empty_cache()
        train_results = trainer.train(train_data.get_loader(),
                                      ml_logger,
                                      chunk_batch=args.chunk_batch)

        # evaluate on validation set
        val_results = trainer.validate(val_data.get_loader())
        ml_logger.log_metric('Val Acc1', val_results['prec1'], step='auto')
        # torch.cuda.empty_cache()
        if args.distributed and args.local_rank > 0:
            continue

        # remember best prec@1 and save checkpoint
        is_best = val_results['prec1'] > best_prec1
        best_prec1 = max(val_results['prec1'], best_prec1)

        if args.drop_optim_state:
            optim_state_dict = None
        else:
            optim_state_dict = optimizer.state_dict()

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'config': args.model_config,
                'state_dict': model.state_dict(),
                'optim_state_dict': optim_state_dict,
                'best_prec1': best_prec1
            },
            is_best,
            path=args.save_path,
            save_all=args.save_all)

        logging.info('\nResults - Epoch: {0}\n'
                     'Training Loss {train[loss]:.4f} \t'
                     'Training Prec@1 {train[prec1]:.3f} \t'
                     'Training Prec@5 {train[prec5]:.3f} \t'
                     'Validation Loss {val[loss]:.4f} \t'
                     'Validation Prec@1 {val[prec1]:.3f} \t'
                     'Validation Prec@5 {val[prec5]:.3f} \t\n'.format(
                         epoch + 1, train=train_results, val=val_results))

        values = dict(epoch=epoch + 1, steps=trainer.training_steps)
        values.update({'training ' + k: v for k, v in train_results.items()})
        values.update({'validation ' + k: v for k, v in val_results.items()})
        results.add(**values)

        results.plot(x='epoch',
                     y=['training loss', 'validation loss'],
                     legend=['training', 'validation'],
                     title='Loss',
                     ylabel='loss')
        results.plot(x='epoch',
                     y=['training error1', 'validation error1'],
                     legend=['training', 'validation'],
                     title='Error@1',
                     ylabel='error %')
        results.plot(x='epoch',
                     y=['training error5', 'validation error5'],
                     legend=['training', 'validation'],
                     title='Error@5',
                     ylabel='error %')
        if 'grad' in train_results.keys():
            results.plot(x='epoch',
                         y=['training grad'],
                         legend=['gradient L2 norm'],
                         title='Gradient Norm',
                         ylabel='value')
        results.save()
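
The training defaults above repeatedly enable 'cutout': {'holes': 1, 'length': 16}; a minimal sketch of a cutout transform with that interface (assumed semantics of zeroing square patches; the repository's transform may differ):

import torch

class Cutout:
    def __init__(self, holes=1, length=16):
        self.holes = holes
        self.length = length

    def __call__(self, img):
        # img: (C, H, W) tensor; zero out `holes` square patches
        h, w = img.shape[-2:]
        for _ in range(self.holes):
            y = torch.randint(h, (1,)).item()
            x = torch.randint(w, (1,)).item()
            y1, y2 = max(0, y - self.length // 2), min(h, y + self.length // 2)
            x1, x2 = max(0, x - self.length // 2), min(w, x + self.length // 2)
            img[..., y1:y2, x1:x2] = 0
        return img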
Example #7
def main_worker(args):
    global best_prec1, dtype
    acc = -1
    loss = -1
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save == '':
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_init,
                                world_size=args.world_size, rank=args.local_rank)
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            args.device_ids = [args.local_rank]

    if not os.path.exists(save_path) and not (args.distributed and args.local_rank > 0):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=args.distributed and args.local_rank > 0)

    results_path = os.path.join(save_path, 'results')
    results = ResultsLog(
        results_path, title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    model = models.__dict__[args.model]
    dataset_type = 'imagenet' if 'imagenet_calib' in args.dataset else args.dataset
    model_config = {'dataset': dataset_type}

    if args.model_config != '':
        if isinstance(args.model_config, dict):
            for k, v in args.model_config.items():
                if k not in model_config.keys():
                    model_config[k] = v
        else:
            args_dict = literal_eval(args.model_config)
            for k, v in args_dict.items():
                model_config[k] = v
    if (args.absorb_bn or args.load_from_vision or args.pretrained) and not args.batch_norn_tuning:
        if args.load_from_vision:
            import torchvision
            model = getattr(torchvision.models,
                            args.load_from_vision)(pretrained=True)
        else:
            if not os.path.isfile(args.absorb_bn):
                parser.error('invalid checkpoint: {}'.format(args.absorb_bn))
            model = model(**model_config)
            checkpoint = torch.load(args.absorb_bn,
                                    map_location=lambda storage, loc: storage)
            checkpoint = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
            # strip the 'module.1.' prefix left by wrapper modules
            sd = {}
            for key in checkpoint.keys():
                key_clean = key.split('module.1.')[1]
                sd[key_clean] = checkpoint[key]
            checkpoint = sd
            model.load_state_dict(checkpoint, strict=False)
        if args.load_from_vision or ('batch_norm' in model_config and model_config['batch_norm']):
            logging.info('Creating absorb_bn state dict')
            search_absorbe_bn(model)
            filename_ab = (args.absorb_bn + '.absorb_bn' if args.absorb_bn
                           else save_path + '/' + args.model + '.absorb_bn')
            torch.save(model.state_dict(), filename_ab)
            if not args.load_from_vision:
                return
        else:
            filename_bn = save_path + '/' + args.model + '.with_bn'
            torch.save(model.state_dict(), filename_bn)
        if args.load_from_vision:
            return

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)
    
    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate, map_location="cpu")    
        # Overrride configuration with checkpoint info
        args.model = checkpoint.get('model', args.model)
        #args.model_config = checkpoint.get('config', args.model_config)
        if not model_config['batch_norm']:
            search_absorbe_fake_bn(model)
        # load checkpoint
        if 'state_dict' in checkpoint.keys():
            if any('module.1.' in key for key in checkpoint['state_dict'].keys()):
                # strip the 'module.1.' prefix left by wrapper modules
                sd = {}
                for key in checkpoint['state_dict'].keys():
                    key_clean = key.split('module.1.')[1]
                    sd[key_clean] = checkpoint['state_dict'][key]
                model.load_state_dict(sd, strict=False)
            else:
                model.load_state_dict(checkpoint['state_dict'], strict=False)
                logging.info("loaded checkpoint '%s'", args.evaluate)
        else:
            model.load_state_dict(checkpoint, strict=False)
            logging.info("loaded checkpoint '%s'", args.evaluate)

    if args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(
                checkpoint_file, 'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            if args.start_epoch < 0:  # not explicitly set
                args.start_epoch = checkpoint['epoch'] - 1 if 'epoch' in checkpoint else 0
            best_prec1 = checkpoint['best_prec1'] if 'best_prec1' in checkpoint else -1
            sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
            model.load_state_dict(sd, strict=False)
            logging.info("loaded checkpoint '%s' (epoch %s)",
                         checkpoint_file, args.start_epoch)
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration
    optim_regime = getattr(model, 'regime', [{'epoch': 0,
                                              'optimizer': args.optimizer,
                                              'lr': args.lr,
                                              'momentum': args.momentum,
                                              'weight_decay': args.weight_decay}])
    if args.fine_tune or args.prune:
        if not args.resume:
            args.start_epoch = 0
        if args.update_only_th:
            # optim_regime = [
            #     {'epoch': 0, 'optimizer': 'Adam', 'lr': 1e-4}]
            optim_regime = [
                {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1},
                {'epoch': 10, 'lr': 1e-2},
                {'epoch': 15, 'lr': 1e-3}]
        else:
            optim_regime = [
                {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-4, 'momentum': 0.9},
                {'epoch': 2, 'lr': 1e-5, 'momentum': 0.9},
                {'epoch': 10, 'lr': 1e-6, 'momentum': 0.9}]
    optimizer = optim_regime if isinstance(optim_regime, OptimRegime) \
        else OptimRegime(model, optim_regime, use_float_copy='half' in args.dtype)

    # Training Data loading code
    train_data = DataRegime(getattr(model, 'data_regime', None),
                            defaults={
                                'datasets_path': args.datasets_dir,
                                'name': args.dataset,
                                'split': 'train',
                                'augment': False,
                                'input_size': args.input_size,
                                'batch_size': args.batch_size,
                                'shuffle': True,
                                'num_workers': args.workers,
                                'pin_memory': True,
                                'drop_last': True,
                                'distributed': args.distributed,
                                'duplicates': args.duplicates,
                                'autoaugment': args.autoaugment,
                                'cutout': {
                                    'holes': 1,
                                    'length': 16
                                } if args.cutout else None
                            })

    pruner = None
    trainer = Trainer(model, pruner, criterion, optimizer,
                      device_ids=args.device_ids,
                      device=args.device,
                      dtype=dtype,
                      distributed=args.distributed,
                      local_rank=args.local_rank,
                      mixup=args.mixup,
                      loss_scale=args.loss_scale,
                      grad_clip=args.grad_clip,
                      print_freq=args.print_freq,
                      adapt_grad_norm=args.adapt_grad_norm,
                      epoch=args.start_epoch)

    
    # Evaluation Data loading code
    args.eval_batch_size = args.eval_batch_size if args.eval_batch_size > 0 else args.batch_size
    dataset_type = 'imagenet' if 'imagenet_calib' in args.dataset else args.dataset
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={'datasets_path': args.datasets_dir, 'name': dataset_type, 'split': 'val', 'augment': False,
                                    'input_size': args.input_size, 'batch_size': args.eval_batch_size, 'shuffle': True,
                                    'num_workers': args.workers, 'pin_memory': True, 'drop_last': False})


    cached_input_output = {}
    cached_layer_names = []
    if args.adaprune:
        generate_masks_from_model(model)
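        # AdaPrune flow as wired below: forward hooks cache every
        # Conv2d/Linear layer's (input, output) pair from one calibration
        # batch, each layer is then optimized independently to minimize the
        # reconstruction MSE of its cached outputs under a sparsity mask,
        # and finally the model is re-validated with recalibrated BN stats.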
        def hook(module, input, output):
            if module not in cached_input_output:
                cached_input_output[module] = []
            # Meanwhile store data in the RAM.
            cached_input_output[module].append((input[0].detach().cpu(), output.detach().cpu()))
            print(module.__str__()[:70])

        handlers = []
        count = 0
        global_val = calc_global_prune_val(model, args.sparsity_level) if args.global_pruning else None
        for name, m in model.named_modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                m.quantize = False
                if count < 1000:
                    handlers.append(m.register_forward_hook(hook))
                    count += 1
                    cached_layer_names.append(name)

        # Store input/output for all quantizable layers
        trainer.validate(train_data.get_loader(), num_steps=1)
        print("Input/outputs cached")

        for handler in handlers:
            handler.remove()

        mse_df = pd.DataFrame(index=np.arange(len(cached_input_output)), columns=['name', 'shape', 'mse_before', 'mse_after'])
        print_freq = 100
        masks_dict = {}
        for i, layer in enumerate(cached_input_output):
            layer.name = cached_layer_names[i]
            print("\nOptimize {}:{} for shape of {}".format(i, layer.name, layer.weight.shape))
            # optionally keep the first and last layers dense
            first_or_last = i == 0 or i == len(cached_input_output) - 1
            sparsity_level = 1 if args.keep_first_last and first_or_last else args.sparsity_level
            prune_topk = args.prune_bs if args.keep_first_last and first_or_last else args.prune_topk

            mse_before, mse_after, snr_before, snr_after, kurt_in, kurt_w, mask = \
                optimize_layer(layer, cached_input_output[layer], args.optimize_weights,
                               bs=args.prune_bs, topk=prune_topk,
                               extract_topk=args.prune_extract_topk,
                               unstructured=args.unstructured, sparsity_level=sparsity_level,
                               global_val=global_val, conf_level=args.conf_level)
            masks_dict[layer.name] = mask
            print("\nMSE before optimization: {}".format(mse_before))
            print("MSE after optimization:  {}".format(mse_after))
            mse_df.loc[i, 'name'] = layer.name 
            mse_df.loc[i, 'shape'] = str(layer.weight.shape)
            mse_df.loc[i, 'mse_before'] = mse_before
            mse_df.loc[i, 'mse_after'] = mse_after
            mse_df.loc[i, 'snr_before'] = snr_before
            mse_df.loc[i, 'snr_after'] = snr_after
            mse_df.loc[i, 'kurt_in'] = kurt_in
            mse_df.loc[i, 'kurt_w'] = kurt_w
            if i > 0 and i % print_freq == 0:
                print('\n')
                val_results = trainer.validate(val_data.get_loader())
                logging.info(val_results)
        
        total_sparsity = calc_masks_sparsity(masks_dict)
        logging.info("total sparsity: %s", total_sparsity)
        mse_csv = args.evaluate + '.mse.csv'
        mse_df.to_csv(mse_csv)

        filename = args.evaluate + '.adaprune'
        torch.save(model.state_dict(), filename)

        
        cached_input_output = None
        val_results = trainer.validate(val_data.get_loader())
        logging.info(val_results)
        trainer.cal_bn_stats(train_data.get_loader())
        val_results = trainer.validate(val_data.get_loader())
        logging.info(val_results)

    elif args.batch_norn_tuning:
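        # Batch-norm tuning as wired below: copy the gamma/beta absorbed into
        # the pretrained torchvision convolutions, re-insert explicit BN
        # layers, re-estimate their running statistics over a few passes of
        # training data, then absorb the tuned BN back into the model.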

        for m in model.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                m.quantize = False

        model_orig = getattr(torchvision.models, args.load_from_vision)(pretrained=True)
        model_orig.to(args.device, dtype)
        search_copy_bn_params(model_orig)
    
        layers_orig = dict([(n, m) for n, m in model_orig.named_modules() if isinstance(m, nn.Conv2d)])
        layers_q = dict([(n, m) for n, m in model.named_modules() if isinstance(m, nn.Conv2d)])
        for l in layers_orig:
            conv_orig = layers_orig[l]
            conv_q = layers_q[l]
            conv_q.register_parameter('gamma', nn.Parameter(conv_orig.gamma.clone()))
            conv_q.register_parameter('beta', nn.Parameter(conv_orig.beta.clone()))
    
        del model_orig
    
        search_add_bn(model)
    
        print("Run BN tuning")
        for tt in range(args.tuning_iter):
            print(tt)
            trainer.cal_bn_stats(train_data.get_loader())
    
        search_absorbe_tuning_bn(model)
    
        filename = args.evaluate + '.bn_tuning'
        print("Save model to: {}".format(filename))
        torch.save(model.state_dict(), filename)
        val_results = trainer.validate(val_data.get_loader())
        logging.info(val_results)
    
        if args.res_log is not None:
            if not os.path.exists(args.res_log):
                df = pd.DataFrame()
            else:
                df = pd.read_csv(args.res_log, index_col=0)
    
            ckp = ntpath.basename(args.evaluate)
            df.loc[ckp, 'acc_bn_tuning'] = val_results['prec1']
            df.loc[ckp, 'loss_bn_tuning'] = val_results['loss']
            df.to_csv(args.res_log)
    else:
        if model_config['measure']:
            val_results = trainer.validate(train_data.get_loader(), rec=args.rec)
        elif args.evaluate_init_configuration:
            val_results = trainer.validate(val_data.get_loader())
            if args.res_log is not None:
                if not os.path.exists(args.res_log):
                    df = pd.DataFrame()
                else:
                    df = pd.read_csv(args.res_log, index_col=0)

                ckp = ntpath.basename(args.evaluate)
                df.loc[ckp, 'acc_base'] = val_results['prec1']
                df.loc[ckp, 'loss_base'] = val_results['loss']
                df.to_csv(args.res_log)

        if args.evaluate_init_configuration:
            logging.info(val_results)
    # assumes one of the branches above produced val_results
    return val_results['prec1'], val_results['loss']
Example #8
0
def post_main():
    global args, best_prec1, dtype
    best_prec1 = 0
    args = parser.parse_args()
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save == '':
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'),
                  resume=args.resume != '')
    results_path = os.path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    logging.info("creating model %s", args.model)
    model = models.__dict__[args.model]
    model_config = {'input_size': args.input_size, 'dataset': args.dataset}

    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)", args.evaluate,
                     checkpoint['epoch'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(checkpoint_file,
                                           'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            args.start_epoch = checkpoint['epoch'] - 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info("loaded checkpoint '%s' (epoch %s)", checkpoint_file,
                         checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info("number of parameters: %d", num_parameters)
    # Data loading code
    regime = getattr(model, 'regime', [{
        'epoch': 0,
        'optimizer': args.optimizer,
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }])
    kwargs = {'num_workers': 1, 'pin_memory': True}
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    val_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                             batch_size=1000,
                                             shuffle=True,
                                             **kwargs)

    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # pruning the model
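    # weight_prune and set_mask are assumed helpers here: weight_prune
    # returns per-layer binary masks zeroing the lowest-magnitude 50% of
    # weights, and set_mask registers a mask so pruned weights stay zero.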
    masks = weight_prune(model, 50)
    print('Start pruning')
    for i, layer in enumerate(get_layer_list(model)):
        layer.set_mask(masks[i])

    #if args.evaluate:
    validate(val_loader, model, criterion, 0)
    #return

    optimizer = OptimRegime(model.parameters(), regime)
    logging.info('training regime: %s', regime)

    for i in range(len(get_layer_list(model))):
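        # sketch_layer / fix_layer (assumed helpers): compress layer i (here
        # to half its size) via sketching, freeze it, then retrain so the
        # remaining layers recover accuracy before the next layer is sketched.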
        sketch_layer(model, i, 0.5)
        fix_layer(model, i)
        model.cuda()
        logging.info(str(i) + '-Sketch: ')
        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion,
                                                  0)
        logging.info('Start Retraining ...')

        for epoch in range(1, 100 + 1):
            # train for one epoch
            train_loss, train_prec1, train_prec5 = train(
                train_loader, model, criterion, epoch, optimizer)
            val_loss, val_prec1, val_prec5 = validate(val_loader, model,
                                                      criterion, 0)
Example #9
0
def main_worker(args):
    global best_prec1, dtype
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'


    model_config = {'dataset': args.dataset, 'batch': args.batch_size}

    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    # automatically name the run from its arguments
    fname = auto_name(args, model_config)
    args.save = fname


    monitor = args.monitor

    print(fname)

    save_path = path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_init,
                                world_size=args.world_size, rank=args.local_rank)
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            args.device_ids = [args.local_rank]

    if not (args.distributed and args.local_rank > 0):
        if not args.dry:
            if not path.exists(save_path):
                makedirs(save_path)
            export_args_namespace(args, path.join(save_path, 'config.json'))


    if monitor > 0 and not args.dry: 

        events_path = "runs/%s" % fname
        my_file = Path(events_path)
        if my_file.is_file():
            os.remove(events_path) 

        writer = SummaryWriter(log_dir=events_path, comment=str(args))
        model_config['writer'] = writer
        model_config['monitor'] = monitor
    else:
        monitor = 0
        writer = None

    if args.dry:
        model = models.__dict__[args.model]
        model = model(**model_config)
        print("created model with configuration: %s" % model_config)
        num_parameters = sum([l.nelement() for l in model.parameters()])
        print("number of parameters: %d" % num_parameters)
        return

    setup_logging(path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=args.distributed and args.local_rank > 0)

    results_path = path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    model = models.__dict__[args.model]
    model = model(**model_config)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info("number of parameters: %d", num_parameters)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate, map_location="cpu")
        # Override configuration with checkpoint info
        args.model = checkpoint.get('model', args.model)
        args.model_config = checkpoint.get('config', args.model_config)
        # load checkpoint
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)",
                     args.evaluate, checkpoint['epoch'])

    if args.resume:
        checkpoint_file = args.resume
        if path.isdir(checkpoint_file):
            results.load(path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = path.join(
                checkpoint_file, 'model_best.pth.tar')
        if path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file, map_location="cpu")
            if args.start_epoch < 0:  # not explicitly set
                args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optim_state_dict = checkpoint.get('optim_state_dict', None)
            logging.info("loaded checkpoint '%s' (epoch %s)",
                         checkpoint_file, checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)
    else:
        optim_state_dict = None

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration
    optim_regime = getattr(model, 'regime', [{'epoch': 0,
                                              'optimizer': args.optimizer,
                                              'lr': args.lr,
                                              'momentum': args.momentum,
                                              'weight_decay': args.weight_decay}])

    optimizer = optim_regime if isinstance(optim_regime, OptimRegime) \
        else OptimRegime(model, optim_regime, use_float_copy='half' in args.dtype)

    if optim_state_dict is not None:
        optimizer.load_state_dict(optim_state_dict)

    trainer = Trainer(model, criterion, optimizer,
                      device_ids=args.device_ids, device=args.device, dtype=dtype, print_freq=args.print_freq,
                      distributed=args.distributed, local_rank=args.local_rank, mixup=args.mixup, cutmix=args.cutmix,
                      loss_scale=args.loss_scale, grad_clip=args.grad_clip,  adapt_grad_norm=args.adapt_grad_norm, writer = writer, monitor = monitor)
    if args.tensorwatch:
        trainer.set_watcher(filename=path.abspath(path.join(save_path, 'tensorwatch.log')),
                            port=args.tensorwatch_port)

    # Evaluation Data loading code
    args.eval_batch_size = args.eval_batch_size if args.eval_batch_size > 0 else args.batch_size
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={'datasets_path': args.datasets_dir, 'name': args.dataset, 'split': 'val', 'augment': False,
                                    'input_size': args.input_size, 'batch_size': args.eval_batch_size, 'shuffle': False,
                                    'num_workers': args.workers, 'pin_memory': True, 'drop_last': False})


    if args.evaluate:
        results = trainer.validate(val_data.get_loader())
        logging.info(results)
        return

    # Training Data loading code
    train_data_defaults = {'datasets_path': args.datasets_dir, 'name': args.dataset, 'split': 'train', 'augment': True,
                           'input_size': args.input_size,  'batch_size': args.batch_size, 'shuffle': True,
                           'num_workers': args.workers, 'pin_memory': True, 'drop_last': True,
                           'distributed': args.distributed, 'duplicates': args.duplicates, 'autoaugment': args.autoaugment,
                           'cutout': {'holes': 1, 'length': 16} if args.cutout else None}

    if hasattr(model, 'sampled_data_regime'):
        sampled_data_regime = model.sampled_data_regime
        probs, regime_configs = zip(*sampled_data_regime)
        regimes = []
        for config in regime_configs:
            defaults = {**train_data_defaults}
            defaults.update(config)
            regimes.append(DataRegime(None, defaults=defaults))
        train_data = SampledDataRegime(regimes, probs)
    else:
        train_data = DataRegime(
            getattr(model, 'data_regime', None), defaults=train_data_defaults)

    logging.info('optimization regime: %s', optim_regime)
    logging.info('data regime: %s', train_data)
    args.start_epoch = max(args.start_epoch, 0)
    trainer.training_steps = args.start_epoch * len(train_data)

    if not args.covmat == "":
        try: 
            int_covmat = int(args.covmat)
            if int_covmat < 0:
                total_layers = len([name for name, layer in model.named_children()])
                int_covmat = total_layers + int_covmat
            child_cnt = 0
        except ValueError:
            int_covmat = None

        def calc_covmat(x_, partitions=64):
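            # computes the Gram matrix x_ @ x_.T in (partitions x partitions)
            # tiles so the full N x N matrix is never materialized at once;
            # diagonal tiles are split into self-products (diag) and cross
            # terms (non_diag) to compare their variances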

            L = x_.shape[0] // partitions

            non_diags = []
            diags = []
            for p1 in range(partitions):
                for p2 in range(partitions):
                    x = x_[p1*L:(p1+1)*L]
                    y = x_[p2*L:(p2+1)*L]
                    X = torch.matmul(x, y.transpose(0, 1))
                    if p1 == p2:
                        mask = torch.eye(X.shape[0], dtype=torch.bool)
                        non_diag = X[~mask].reshape(-1).cpu()
                        diag = X[mask].reshape(-1).cpu()
                        non_diags.append(non_diag)
                        diags.append(diag)
                    else:
                        non_diag = X.reshape(-1).cpu()
                        non_diags.append(non_diag)
            diags = torch.cat(diags)
            non_diags = torch.cat(non_diags)


            diag_var = diags.var()
            non_diag_var = non_diags.var()


            diags = diags - diags.mean()
            non_diags = non_diags - non_diags.mean()

            diag_small_ratio = (diags < -diags.std()).to(dtype=torch.float).mean()
            non_diag_small_ratio = (non_diags < -non_diags.std()).to(dtype=torch.float).mean()

            return diag_var, non_diag_var, diag_small_ratio, non_diag_small_ratio


        global diag_var_mean 
        global non_diag_var_mean 
        global var_count 
        var_count = 0
        diag_var_mean = 0 
        non_diag_var_mean = 0

        def report_covmat_hook(module, input, output):
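            # accumulates a running mean of the diagonal / off-diagonal
            # variances across forward passes of the selected layer and
            # prints the averages every 10 calls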
            global diag_var_mean 
            global non_diag_var_mean 
            global var_count 

            flatten_output = output.reshape([-1,1]).detach()
            diag_var, non_diag_var, diag_small_ratio, non_diag_small_ratio = calc_covmat(flatten_output)

            diag_var_mean = diag_var_mean + diag_var
            
            non_diag_var_mean = non_diag_var_mean + non_diag_var
            
            var_count = var_count + 1
            if var_count % 10 == 1:
                print("diag_var = %.02f (%.02f), ratio: %.02f , non_diag_var = %0.2f (%.02f), ratio: %.02f" % (diag_var, diag_var_mean/var_count, diag_small_ratio , non_diag_var, non_diag_var_mean/var_count, non_diag_small_ratio ))

        for name, layer in model.named_children():
            if int_covmat is None:
                condition = (name == args.covmat)
            else:
                condition = (child_cnt == int_covmat)
                child_cnt = child_cnt + 1
            if condition:
                layer.register_forward_hook(report_covmat_hook)
                


    for epoch in range(args.start_epoch, args.epochs):
        trainer.epoch = epoch
        train_data.set_epoch(epoch)
        val_data.set_epoch(epoch)
        logging.info('\nStarting Epoch: {0}\n'.format(epoch + 1))

        # train for one epoch
        train_results = trainer.train(train_data.get_loader(),
                                      chunk_batch=args.chunk_batch)

        # evaluate on validation set
        val_results = trainer.validate(val_data.get_loader())

        if args.distributed and args.local_rank > 0:
            continue

        # remember best prec@1 and save checkpoint
        is_best = val_results['prec1'] > best_prec1
        best_prec1 = max(val_results['prec1'], best_prec1)

        if args.drop_optim_state:
            optim_state_dict = None
        else:
            optim_state_dict = optimizer.state_dict()

        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.model,
            'config': args.model_config,
            'state_dict': model.state_dict(),
            'optim_state_dict': optim_state_dict,
            'best_prec1': best_prec1
        }, is_best, path=save_path, save_all=args.save_all)

        logging.info('\nResults - Epoch: {0}\n'
                     'Training Loss {train[loss]:.4f} \t'
                     'Training Prec@1 {train[prec1]:.3f} \t'
                     'Training Prec@5 {train[prec5]:.3f} \t'
                     'Validation Loss {val[loss]:.4f} \t'
                     'Validation Prec@1 {val[prec1]:.3f} \t'
                     'Validation Prec@5 {val[prec5]:.3f} \t\n'
                     .format(epoch + 1, train=train_results, val=val_results))

        if writer is not None:
            writer.add_scalar('Train/Loss', train_results['loss'], epoch)
            writer.add_scalar('Train/Prec@1', train_results['prec1'], epoch)
            writer.add_scalar('Train/Prec@5', train_results['prec5'], epoch)
            writer.add_scalar('Val/Loss', val_results['loss'], epoch)
            writer.add_scalar('Val/Prec@1', val_results['prec1'], epoch)
            writer.add_scalar('Val/Prec@5', val_results['prec5'], epoch)
            # tmplr = optimizer.get_lr()
            # writer.add_scalar('HyperParameters/learning-rate',  tmplr, epoch)

        values = dict(epoch=epoch + 1, steps=trainer.training_steps)
        values.update({'training ' + k: v for k, v in train_results.items()})
        values.update({'validation ' + k: v for k, v in val_results.items()})
        results.add(**values)

        results.plot(x='epoch', y=['training loss', 'validation loss'],
                     legend=['training', 'validation'],
                     title='Loss', ylabel='loss')
        results.plot(x='epoch', y=['training error1', 'validation error1'],
                     legend=['training', 'validation'],
                     title='Error@1', ylabel='error %')
        results.plot(x='epoch', y=['training error5', 'validation error5'],
                     legend=['training', 'validation'],
                     title='Error@5', ylabel='error %')
        if 'grad' in train_results.keys():
            results.plot(x='epoch', y=['training grad'],
                         legend=['gradient L2 norm'],
                         title='Gradient Norm', ylabel='value')
        results.save()
    logging.info(f'\nBest Validation Accuracy (top1): {best_prec1}')

    if writer:
        writer.close()
Example #10
0
def main():
    global args, best_prec1, dtype
    best_prec1 = 0
    args = parser.parse_args()
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save == '':
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    else:
        print('***************************************\n'
              'Warning: PATH exists - override warning\n'
              '***************************************')

    args.distributed = args.local_rank >= 0 or args.world_size > 1
    setup_logging(os.path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=args.distributed and args.local_rank > 0)

    if args.deterministic:
        logging.info('Deterministic Run Set')
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    results_path = os.path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    if args.distributed:
        args.device_ids = [args.local_rank]
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_init,
                                world_size=args.world_size,
                                rank=args.local_rank)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    set_global_seeds(args.seed)
    model = models.__dict__[args.model]
    model_config = {'dataset': args.dataset}

    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info("number of parameters: %d", num_parameters)

    # optionally resume from a checkpoint
    shards = None
    x = None
    checkpoint = None
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        x = dict()
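        # checkpoint keys come from a DataParallel-wrapped model, so each
        # name carries a 'module.' prefix (7 characters); strip it to match
        # the bare model's parameter names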
        for name, val in checkpoint['server_state_dict'].items():
            x[name[7:]] = val
        model.load_state_dict(x)
        shards = checkpoint['server_weight_shards']
        logging.info("loaded checkpoint '%s' (epoch %s)", args.evaluate,
                     checkpoint['epoch'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(checkpoint_file,
                                           'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file,
                                    map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch'] - 1
            best_prec1 = checkpoint['best_prec1']
            # model_dict = {'.'.join(k.split('.')[1:]): v for k, v in checkpoint['server_state_dict'].items()}
            # model.load_state_dict(model_dict)
            model.load_state_dict(checkpoint['server_state_dict'])
            logging.info("loaded checkpoint '%s' (epoch %s)", checkpoint_file,
                         checkpoint['epoch'])
            shards = checkpoint['server_weight_shards']
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # optimizer configuration
    optim_regime = getattr(model, 'regime', [{
        'epoch': 0,
        'optimizer': args.optimizer,
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }])
    cpu_store = args.dataset == 'imagenet' and args.workers_num > 32
    args.server = args.server if args.delay > 0 else 'ssgd'
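    # ParameterServer.get_server (project helper) selects the server by name:
    # plain synchronous SGD ('ssgd') when no delay is simulated, otherwise a
    # delayed variant that keeps per-worker weight shards; cpu_store moves
    # those shards to host memory for large worker counts.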
    server = ParameterServer.get_server(args.server,
                                        args.delay,
                                        model=model,
                                        shards=shards,
                                        optimizer_regime=optim_regime,
                                        device_ids=args.device_ids,
                                        device=args.device,
                                        dtype=dtype,
                                        distributed=args.distributed,
                                        local_rank=args.local_rank,
                                        grad_clip=args.grad_clip,
                                        workers_num=args.workers_num,
                                        cpu_store=cpu_store)
    del shards, x, checkpoint
    torch.cuda.empty_cache()

    trainer = Trainer(model,
                      server,
                      criterion,
                      device_ids=args.device_ids,
                      device=args.device,
                      dtype=dtype,
                      distributed=args.distributed,
                      local_rank=args.local_rank,
                      workers_number=args.workers_num,
                      grad_clip=args.grad_clip,
                      print_freq=args.print_freq,
                      schedule=args.schedule)

    # Evaluation Data loading code
    args.eval_batch_size = args.eval_batch_size if args.eval_batch_size > 0 else args.batch_size
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={
                              'datasets_path': args.datasets_dir,
                              'name': args.dataset,
                              'split': 'val',
                              'augment': False,
                              'input_size': args.input_size,
                              'batch_size': args.eval_batch_size,
                              'shuffle': False,
                              'num_workers': args.workers,
                              'pin_memory': True,
                              'drop_last': True
                          })

    # Training Data loading code
    train_data = DataRegime(getattr(model, 'data_regime', None),
                            defaults={
                                'datasets_path': args.datasets_dir,
                                'name': args.dataset,
                                'split': 'train',
                                'augment': args.augment,
                                'input_size': args.input_size,
                                'batch_size': args.batch_size,
                                'shuffle': True,
                                'num_workers': args.workers,
                                'pin_memory': True,
                                'drop_last': True,
                                'distributed': args.distributed,
                                'duplicates': args.duplicates,
                                'cutout': {
                                    'holes': 1,
                                    'length': 16
                                } if args.cutout else None
                            })

    if args.evaluate:
        trainer.forward_pass(train_data.get_loader(),
                             duplicates=args.duplicates)
        results = trainer.validate(val_data.get_loader())
        logging.info(results)
        return

    logging.info('optimization regime: %s', optim_regime)
    trainer.training_steps = args.start_epoch * len(train_data)
    args.iterations_steps = trainer.training_steps

    with open(os.path.join(save_path, 'args.txt'), 'w') as file:
        file.write(dict_to_table(vars(args)))
    tb.init(path=save_path,
            title='Training Results',
            params=args,
            res_iterations=args.resolution)

    for epoch in range(args.start_epoch, args.epochs):
        trainer.epoch = epoch
        train_data.set_epoch(epoch)
        val_data.set_epoch(epoch)
        logging.info('\nStarting Epoch: {0}\n'.format(epoch + 1))

        # train for one epoch
        train_results = trainer.train(train_data.get_loader(),
                                      duplicates=args.duplicates)
        # evaluate on validation set
        val_results = trainer.validate(val_data.get_loader())
        if args.distributed and args.local_rank > 0:
            continue

        # remember best prec@1 and save checkpoint
        is_best = val_results['prec1'] > best_prec1
        best_prec1 = max(val_results['prec1'], best_prec1)
        if (epoch + 1) % args.save_freq == 0:
            tb.tboard.set_resume_step(epoch)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model': args.model,
                    'server_state_dict': server._model.state_dict(),
                    'server_weight_shards': server._shards_weights,
                    'config': args.model_config,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best,
                path=save_path)
        errors = {
            'error1_train': 100 - train_results['prec1'],
            'error5_train': 100 - train_results['prec5'],
            'error1_val': 100 - val_results['prec1'],
            'error5_val': 100 - val_results['prec5'],
            'epochs': epoch
        }
        logging.info('\nResults - Epoch: {0}\n'
                     'Training Loss {train[loss]:.4f} \t'
                     'Training Error@1 {errors[error1_train]:.3f} \t'
                     'Training Error@5 {errors[error5_train]:.3f} \t'
                     'Validation Loss {val[loss]:.4f} \t'
                     'Validation Error@1 {errors[error1_val]:.3f} \t'
                     'Validation Error@5 {errors[error5_val]:.3f} \t\n'.format(
                         epoch + 1,
                         train=train_results,
                         val=val_results,
                         errors=errors))

        values = dict(epoch=epoch + 1, steps=trainer.training_steps)
        values.update({'training ' + k: v for k, v in train_results.items()})
        values.update({'validation ' + k: v for k, v in val_results.items()})
        tb.tboard.log_results(epoch, **values)
        tb.tboard.log_model(server, epoch)
        if args.delay > 0:
            tb.tboard.log_delay(trainer.delay_hist, epoch)

    tb.tboard.close()
    return errors, args
Example #11
0
def main_worker(args):
    global best_prec1, dtype
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save == '':
        args.save = time_stamp
    save_path = path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_init,
                                world_size=args.world_size,
                                rank=args.local_rank)
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            args.device_ids = [args.local_rank]

    if not (args.distributed and args.local_rank > 0):
        if not path.exists(save_path):
            makedirs(save_path)
        export_args_namespace(args, path.join(save_path, 'config.json'))

    setup_logging(path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=args.distributed and args.local_rank > 0)

    results_path = path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    grad_stats_path = path.join(save_path, 'grad_stats')
    grad_stats = ResultsLog(grad_stats_path,
                            title='collect grad stats - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    model = models.__dict__[args.model]
    model_config = {'dataset': args.dataset}

    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    if args.enable_scheduler:
        model_config['fp8_dynamic'] = True
    if args.smart_loss_scale_only:
        model_config['smart_loss_scale_only'] = True
    if args.smart_loss_scale_and_exp_bits:
        model_config['smart_loss_scale_and_exp_bits'] = True
    model = model(**model_config)
    quantize_modules_name = [
        n for n, m in model.named_modules() if isinstance(m, nn.Conv2d)
    ]
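    # FP8TrainingScheduler (project helper, constructed here with
    # enable_scheduler=False) can collect per-layer gradient statistics on a
    # schedule and use them to update the loss scale and, optionally, the FP8
    # exponent bit-widths of the listed Conv2d modules; stats_path points at
    # a previously recorded statistics CSV.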
    fp8_scheduler = FP8TrainingScheduler(
        model,
        model_config,
        args,
        collect_stats_online=False,
        start_to_collect_stats_in_epoch=3,
        collect_stats_every_epochs=10,
        online_update=False,
        first_update_with_stats_from_epoch=4,
        start_online_update_in_epoch=3,
        update_every_epochs=1,
        update_loss_scale=True,
        update_exp_bit_width=args.smart_loss_scale_and_exp_bits,
        stats_path=
        "/data/moran/ConvNet_lowp_0/convNet.pytorch/results/2020-05-16_01-44-22/results.csv",  # ResNet18- cifar10
        # stats_path = "/data/moran/ConvNet_lowp_0/convNet.pytorch/results/2020-05-19_01-27-57/results.csv",  # ResNet18- ImageNet
        quantize_modules_name=quantize_modules_name,
        enable_scheduler=False)

    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info("number of parameters: %d", num_parameters)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate, map_location="cpu")
        # Override configuration with checkpoint info
        args.model = checkpoint.get('model', args.model)
        args.model_config = checkpoint.get('config', args.model_config)
        # load checkpoint
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)", args.evaluate,
                     checkpoint['epoch'])

    if args.resume:
        checkpoint_file = args.resume
        if path.isdir(checkpoint_file):
            results.load(path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = path.join(checkpoint_file, 'model_best.pth.tar')
        if path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file, map_location="cpu")
            if args.start_epoch < 0:  # not explicitly set
                args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optim_state_dict = checkpoint.get('optim_state_dict', None)
            logging.info("loaded checkpoint '%s' (epoch %s)", checkpoint_file,
                         checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)
    else:
        optim_state_dict = None

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration
    optim_regime = getattr(model, 'regime', [{
        'epoch': 0,
        'optimizer': args.optimizer,
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }])

    optimizer = optim_regime if isinstance(optim_regime, OptimRegime) \
        else OptimRegime(model, optim_regime, use_float_copy='half' in args.dtype)

    if optim_state_dict is not None:
        optimizer.load_state_dict(optim_state_dict)

    trainer = Trainer(model,
                      criterion,
                      optimizer,
                      device_ids=args.device_ids,
                      device=args.device,
                      dtype=dtype,
                      print_freq=args.print_freq,
                      distributed=args.distributed,
                      local_rank=args.local_rank,
                      mixup=args.mixup,
                      cutmix=args.cutmix,
                      loss_scale=args.loss_scale,
                      grad_clip=args.grad_clip,
                      adapt_grad_norm=args.adapt_grad_norm,
                      enable_input_grad_statistics=True,
                      exp_bits=args.exp_bits,
                      fp_bits=args.fp_bits)
    if args.tensorwatch:
        trainer.set_watcher(filename=path.abspath(
            path.join(save_path, 'tensorwatch.log')),
                            port=args.tensorwatch_port)

    # Evaluation Data loading code
    args.eval_batch_size = args.eval_batch_size if args.eval_batch_size > 0 else args.batch_size
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={
                              'datasets_path': args.datasets_dir,
                              'name': args.dataset,
                              'split': 'val',
                              'augment': False,
                              'input_size': args.input_size,
                              'batch_size': args.eval_batch_size,
                              'shuffle': False,
                              'num_workers': args.workers,
                              'pin_memory': True,
                              'drop_last': False
                          })

    if args.evaluate:
        results = trainer.validate(val_data.get_loader())
        logging.info(results)
        return

    # Training Data loading code
    train_data_defaults = {
        'datasets_path': args.datasets_dir,
        'name': args.dataset,
        'split': 'train',
        'augment': True,
        'input_size': args.input_size,
        'batch_size': args.batch_size,
        'shuffle': True,
        'num_workers': args.workers,
        'pin_memory': True,
        'drop_last': True,
        'distributed': args.distributed,
        'duplicates': args.duplicates,
        'autoaugment': args.autoaugment,
        'cutout': {
            'holes': 1,
            'length': 16
        } if args.cutout else None
    }

    if hasattr(model, 'sampled_data_regime'):
        sampled_data_regime = model.sampled_data_regime
        probs, regime_configs = zip(*sampled_data_regime)
        regimes = []
        for config in regime_configs:
            defaults = {**train_data_defaults}
            defaults.update(config)
            regimes.append(DataRegime(None, defaults=defaults))
        train_data = SampledDataRegime(regimes, probs)
    else:
        train_data = DataRegime(getattr(model, 'data_regime', None),
                                defaults=train_data_defaults)

    logging.info('optimization regime: %s', optim_regime)
    logging.info('data regime: %s', train_data)
    args.start_epoch = max(args.start_epoch, 0)
    trainer.training_steps = args.start_epoch * len(train_data)
    for epoch in range(args.start_epoch, args.epochs):
        trainer.epoch = epoch
        train_data.set_epoch(epoch)
        val_data.set_epoch(epoch)
        logging.info('\nStarting Epoch: {0}\n'.format(epoch + 1))

        fp8_scheduler.schedule_before_epoch(epoch)
        # train for one epoch
        train_results, meters_grad = trainer.train(
            train_data.get_loader(),
            chunk_batch=args.chunk_batch,
            scheduled_instructions=fp8_scheduler.scheduled_instructions)

        # evaluate on validation set

        if args.calibrate_bn:
            train_data = DataRegime(None,
                                    defaults={
                                        'datasets_path': args.datasets_dir,
                                        'name': args.dataset,
                                        'split': 'train',
                                        'augment': True,
                                        'input_size': args.input_size,
                                        'batch_size': args.batch_size,
                                        'shuffle': True,
                                        'num_workers': args.workers,
                                        'pin_memory': True,
                                        'drop_last': False
                                    })
            trainer.calibrate_bn(train_data.get_loader(), num_steps=200)

        val_results, _ = trainer.validate(val_data.get_loader())

        if args.distributed and args.local_rank > 0:
            continue

        # remember best prec@1 and save checkpoint
        is_best = val_results['prec1'] > best_prec1
        best_prec1 = max(val_results['prec1'], best_prec1)

        if args.drop_optim_state:
            optim_state_dict = None
        else:
            optim_state_dict = optimizer.state_dict()

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'config': args.model_config,
                'state_dict': model.state_dict(),
                'optim_state_dict': optim_state_dict,
                'best_prec1': best_prec1
            },
            is_best,
            path=save_path,
            save_all=args.save_all)

        logging.info('\nResults - Epoch: {0}\n'
                     'Training Loss {train[loss]:.4f} \t'
                     'Training Prec@1 {train[prec1]:.3f} \t'
                     'Training Prec@5 {train[prec5]:.3f} \t'
                     'Validation Loss {val[loss]:.4f} \t'
                     'Validation Prec@1 {val[prec1]:.3f} \t'
                     'Validation Prec@5 {val[prec5]:.3f} \t\n'.format(
                         epoch + 1, train=train_results, val=val_results))

        values = dict(epoch=epoch + 1, steps=trainer.training_steps)

        values.update({'training ' + k: v for k, v in train_results.items()})
        values.update({'validation ' + k: v for k, v in val_results.items()})

        values.update(
            {'grad mean ' + k: v['mean'].avg
             for k, v in meters_grad.items()})
        values.update(
            {'grad std ' + k: v['std'].avg
             for k, v in meters_grad.items()})

        results.add(**values)

        # stats was collected
        if fp8_scheduler.scheduled_instructions['collect_stat']:
            grad_stats_values = dict(epoch=epoch + 1)
            grad_stats_values.update({
                'grad mean ' + k: v['mean'].avg
                for k, v in meters_grad.items()
            })
            grad_stats_values.update({
                'grad std ' + k: v['std'].avg
                for k, v in meters_grad.items()
            })

            grad_stats.add(**grad_stats_values)
            fp8_scheduler.update_stats(grad_stats)

        results.plot(x='epoch',
                     y=['training loss', 'validation loss'],
                     legend=['training', 'validation'],
                     title='Loss',
                     ylabel='loss')
        results.plot(x='epoch',
                     y=['training error1', 'validation error1'],
                     legend=['training', 'validation'],
                     title='Error@1',
                     ylabel='error %')
        results.plot(x='epoch',
                     y=['training error5', 'validation error5'],
                     legend=['training', 'validation'],
                     title='Error@5',
                     ylabel='error %')
        if 'grad' in train_results.keys():
            results.plot(x='epoch',
                         y=['training grad'],
                         legend=['gradient L2 norm'],
                         title='Gradient Norm',
                         ylabel='value')
        results.save()
        grad_stats.save()
Example #12
0
def main():
    global args, best_prec1, dtype
    best_prec1 = 0
    args = parser.parse_args()
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save == '':
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'),
                  resume=args.resume != '')
    results_path = os.path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    logging.info("creating model %s", args.model)
    model_builder = models.__dict__[args.model]

    model_config = {
        'input_size': args.input_size,
        'dataset': args.dataset if args.dataset != 'imaginet' else 'imagenet'
    }
    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))
    model = model_builder(**model_config)
    model.to(args.device, dtype)

    # Data loading code
    default_transform = {
        'train':
        get_transform(args.dataset, input_size=args.input_size, augment=True),
        'eval':
        get_transform(args.dataset, input_size=args.input_size, augment=False)
    }
    logging.info("created model with configuration: %s", model_config)

    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info("number of parameters: %d", num_parameters)

    transform = getattr(model, 'input_transform', default_transform)
    regime = getattr(model, 'regime', [{
        'epoch': 0,
        'optimizer': args.optimizer,
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }])

    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
    criterion.to(args.device, dtype)
    train_data = get_dataset(args.dataset, 'train', transform['train'])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)
    val_data = get_dataset(args.dataset, 'val', transform['eval'])
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    def load_maybe_calibrate(checkpoint):
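        # try a direct state-dict load; if it fails for a quantized model,
        # fall back to a previously measured '.measure' checkpoint, or absorb
        # batch-norm into the weights, run one measurement pass on the
        # validation set to calibrate quantization parameters, and save the
        # result as a new measured checkpoint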
        try:
            model.load_state_dict(checkpoint)
        except Exception as e:
            if model_config.get('quantize'):
                measure_name = '{}-{}.measure'.format(args.model,
                                                      model_config['depth'])
                measure_path = os.path.join(save_path, measure_name)
                if os.path.exists(measure_path):
                    logging.info("loading checkpoint '%s'", args.resume)
                    checkpoint = torch.load(measure_path)
                    if 'state_dict' in checkpoint:
                        best_prec1 = checkpoint['best_prec1']
                        checkpoint = checkpoint['state_dict']
                        logging.info(
                            f"Measured checkpoint loaded, reference score top1 {best_prec1:.3f}"
                        )
                    model.load_state_dict(checkpoint)
                else:
                    if model_config.get('absorb_bn'):
                        from utils.absorb_bn import search_absorbe_bn
                        logging.info('absorbing batch normalization')
                        model_config.update({
                            'absorb_bn': False,
                            'quantize': False
                        })
                        model_bn = model_builder(**model_config)
                        model_bn.load_state_dict(checkpoint)
                        search_absorbe_bn(model_bn, verbose=True)
                        model_config.update({
                            'absorb_bn': True,
                            'quantize': True
                        })
                        checkpoint = model_bn.state_dict()
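                        # Folding BN into the preceding conv rescales the conv
                        # weights by gamma / sqrt(running_var + eps) and shifts
                        # the bias accordingly, so the quantized model can run
                        # without explicit BatchNorm layers.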
                    model.load_state_dict(checkpoint, strict=False)
                    logging.info("set model measure mode")
                    # set_bn_is_train(model,False)
                    set_measure_mode(model, True, logger=logging)
                    logging.info(
                        "calibrating apprentice model to get quant params")
                    model.to(args.device, dtype)
                    with torch.no_grad():
                        losses_avg, top1_avg, top5_avg = forward(
                            val_loader,
                            model,
                            criterion,
                            0,
                            training=False,
                            optimizer=None)
                    logging.info('Measured float results:\nLoss {loss:.4f}\t'
                                 'Prec@1 {top1:.3f}\t'
                                 'Prec@5 {top5:.3f}'.format(loss=losses_avg,
                                                            top1=top1_avg,
                                                            top5=top5_avg))
                    set_measure_mode(model, False, logger=logging)
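                    # Measure mode let each quantizer record activation ranges
                    # during the pass above; with it disabled, those ranges now
                    # serve as the initial quantization parameters.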
                    # logging.info("test quant model accuracy")
                    # losses_avg, top1_avg, top5_avg = validate(val_loader, model, criterion, 0)
                    # logging.info('Quantized results:\nLoss {loss:.4f}\t'
                    #              'Prec@1 {top1:.3f}\t'
                    #              'Prec@5 {top5:.3f}'.format(loss=losses_avg, top1=top1_avg, top5=top5_avg))

                    save_checkpoint(
                        {
                            'epoch': 0,
                            'model': args.model,
                            'config': args.model_config,
                            'state_dict': model.state_dict(),
                            'best_prec1': top1_avg,
                            'regime': regime
                        },
                        True,
                        path=save_path,
                        save_all=True,
                        filename=measure_name)

            else:
                raise

    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        # model.load_state_dict(checkpoint['state_dict'])
        # logging.info("loaded checkpoint '%s' (epoch %s)",
        #              args.evaluate, checkpoint['epoch'])
        load_maybe_calibrate(checkpoint)
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(checkpoint_file,
                                           'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            if 'state_dict' in checkpoint:
                if checkpoint['epoch'] > 0:
                    args.start_epoch = checkpoint['epoch'] - 1
                best_prec1 = checkpoint['best_prec1']
                checkpoint = checkpoint['state_dict']

            try:
                model.load_state_dict(checkpoint)
            except Exception as e:
                if model_config.get('quantize'):
                    if model_config.get('absorb_bn'):
                        from utils.absorb_bn import search_absorbe_bn
                        logging.info('absorbing batch normalization')
                        model_config.update({
                            'absorb_bn': False,
                            'quantize': False
                        })
                        model_bn = model_builder(**model_config)
                        model_bn.load_state_dict(checkpoint)
                        search_absorbe_bn(model_bn, verbose=True)
                        model_config.update({
                            'absorb_bn': True,
                            'quantize': True
                        })
                        checkpoint = model_bn.state_dict()
                    model.load_state_dict(checkpoint, strict=False)
                    model.to(args.device, dtype)
                    logging.info("set model measure mode")
                    # set_bn_is_train(model,False)
                    set_measure_mode(model, True, logger=logging)
                    logging.info(
                        "calibrating apprentice model to get quant params")
                    model.to(args.device, dtype)
                    with torch.no_grad():
                        losses_avg, top1_avg, top5_avg = forward(
                            val_loader,
                            model,
                            criterion,
                            0,
                            training=False,
                            optimizer=None)
                    logging.info('Measured float results:\nLoss {loss:.4f}\t'
                                 'Prec@1 {top1:.3f}\t'
                                 'Prec@5 {top5:.3f}'.format(loss=losses_avg,
                                                            top1=top1_avg,
                                                            top5=top5_avg))
                    set_measure_mode(model, False, logger=logging)
                    logging.info("test quant model accuracy")
                    losses_avg, top1_avg, top5_avg = validate(
                        val_loader, model, criterion, 0)
                    logging.info('Quantized results:\nLoss {loss:.4f}\t'
                                 'Prec@1 {top1:.3f}\t'
                                 'Prec@5 {top5:.3f}'.format(loss=losses_avg,
                                                            top1=top1_avg,
                                                            top5=top5_avg))
                    save_checkpoint(
                        {
                            'epoch': 0,
                            'model': args.model,
                            'config': args.model_config,
                            'state_dict': model.state_dict(),
                            'best_prec1': top1_avg,
                            'regime': regime
                        },
                        True,
                        path=save_path,
                        save_freq=5)
                    #save_checkpoint(model.state_dict(), is_best=True, path=save_path, save_all=True)
                    logging.info(
                        f'overwriting quantization method with {args.q_method}'
                    )
                    set_global_quantization_method(model, args.q_method)
                else:
                    raise

            logging.info("loaded checkpoint '%s' (epoch %s)", checkpoint_file,
                         args.start_epoch)
        else:
            logging.error("no checkpoint found at '%s'", args.resume)
    if args.evaluate:
        if model_config.get('quantize'):
            logging.info(
                f'overwriting quantization method with {args.q_method}')
            set_global_quantization_method(model, args.q_method)
        losses_avg, top1_avg, top5_avg = validate(val_loader, model, criterion,
                                                  0)
        logging.info('Evaluation results:\nLoss {loss:.4f}\t'
                     'Prec@1 {top1:.3f}\t'
                     'Prec@5 {top5:.3f}'.format(loss=losses_avg,
                                                top1=top1_avg,
                                                top5=top5_avg))
        return

    optimizer = OptimRegime(model, regime)
    logging.info('training regime: %s', regime)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train_loss, train_prec1, train_prec5 = train(train_loader, model,
                                                     criterion, epoch,
                                                     optimizer)

        # evaluate on validation set
        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion,
                                                  epoch)

        # remember best prec@1 and save checkpoint
        is_best = val_prec1 > best_prec1
        best_prec1 = max(val_prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'config': args.model_config,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'regime': regime
            },
            is_best,
            path=save_path)
        logging.info('\n Epoch: {0}\t'
                     'Training Loss {train_loss:.4f} \t'
                     'Training Prec@1 {train_prec1:.3f} \t'
                     'Training Prec@5 {train_prec5:.3f} \t'
                     'Validation Loss {val_loss:.4f} \t'
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \n'.format(
                         epoch + 1,
                         train_loss=train_loss,
                         val_loss=val_loss,
                         train_prec1=train_prec1,
                         val_prec1=val_prec1,
                         train_prec5=train_prec5,
                         val_prec5=val_prec5))

        results.add(epoch=epoch + 1,
                    train_loss=train_loss,
                    val_loss=val_loss,
                    train_error1=100 - train_prec1,
                    val_error1=100 - val_prec1,
                    train_error5=100 - train_prec5,
                    val_error5=100 - val_prec5)
        results.plot(x='epoch',
                     y=['train_loss', 'val_loss'],
                     legend=['training', 'validation'],
                     title='Loss',
                     ylabel='loss')
        results.plot(x='epoch',
                     y=['train_error1', 'val_error1'],
                     legend=['training', 'validation'],
                     title='Error@1',
                     ylabel='error %')
        results.plot(x='epoch',
                     y=['train_error5', 'val_error5'],
                     legend=['training', 'validation'],
                     title='Error@5',
                     ylabel='error %')
        results.save()
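
A note on the BN-absorption step used in the quantization recovery path above:
folding a BatchNorm into the preceding convolution rescales the weights by
gamma / sqrt(running_var + eps) and shifts the bias accordingly. A minimal,
self-contained sketch of that arithmetic (illustrative names only; the
repository's own implementation is utils.absorb_bn.search_absorbe_bn):

import torch
import torch.nn as nn


def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Return a conv whose output matches bn(conv(x)) in eval mode."""
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # gamma / sigma
    folded = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                       conv.stride, conv.padding, bias=True)
    folded.weight.data = conv.weight.data * scale.view(-1, 1, 1, 1)
    bias = conv.bias.data if conv.bias is not None else torch.zeros_like(scale)
    folded.bias.data = (bias - bn.running_mean) * scale + bn.bias.data
    return folded


# quick equivalence check with non-trivial BN statistics
conv, bn = nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
x = torch.randn(1, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fold_bn_into_conv(conv, bn)(x), atol=1e-5)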
Example #13
def main(args):
    set_global_seeds(args.seed)

    device = args.device
    dtype = torch_dtypes.get(args.dtype)

    if 'cuda' in args.device:
        device_id = args.device_ids
        device = torch.device(device, device_id)
        torch.cuda.set_device(device_id)

    save_path = os.path.join(args.results_dir, args.model)
    if not os.path.exists(args.results_dir):
        os.mkdir(args.results_dir)
    log_file = open(os.path.join(args.results_dir, args.model + ".log"), "w")

    regime = literal_eval(args.optimization_config)
    model_config = literal_eval(args.model_config)

    vocab, rev_vocab = pickle.load(open(args.vocab, 'rb'))

    model_config.setdefault('encoder', {})
    model_config.setdefault('decoder', {})
    model_config['encoder']['vocab_size'] = len(vocab)
    model_config['decoder']['vocab_size'] = len(vocab)
    model_config['vocab_size'] = model_config['decoder']['vocab_size']
    args.model_config = model_config
    model = transformer.Transformer(**model_config)
    model.to(device)

    criterion = nn.NLLLoss(ignore_index=PAD)
    params = model.parameters()

    optimizer = optim.Adam(params, lr=regime['lr'])

    # load data, word vocab, and parse vocab
    h5f_train = h5py.File(args.train_data, 'r')
    inp_train = h5f_train['inputs']
    out_train = h5f_train['outputs']
    input_lens_train = h5f_train['input_lens']
    output_lens_train = h5f_train['output_lens']
    inp_order_train = h5f_train['reordering_input']
    out_order_train = h5f_train['reordering_output']
    print("training samples: %d" % len(inp_train))
    log_file.write("training samples: %d \n" % len(inp_train))

    batch_size = args.batch_size
    h5f_dev = h5py.File(args.dev_data, 'r')
    inp_dev = h5f_dev['inputs'][0:500]
    out_dev = h5f_dev['outputs'][0:500]
    input_lens_dev = h5f_dev['input_lens'][0:500]
    output_lens_dev = h5f_dev['output_lens'][0:500]
    inp_order_dev = h5f_dev['reordering_input'][0:500]

    include_coverage_loss = False
    include_reorder_information = args.include_reorder_information

    train_minibatches = [(start, start + batch_size)
                         for start in range(0, inp_train.shape[0], batch_size)
                         ][:-1]
    dev_minibatches = [(start, start + batch_size)
                       for start in range(0, inp_dev.shape[0], batch_size)
                       ][:-1]
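    # the trailing [:-1] drops the final, possibly smaller, minibatch so every
    # batch has a fixed size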
    random.shuffle(train_minibatches)

    log_file.write("num training batches: %d \n \n" % len(train_minibatches))

    coverage_coef = 0.5
    for ep in range(args.epochs):
        random.shuffle(train_minibatches)
        ep_loss = 0.
        start_time = time.time()
        num_batches = 0
        cov_loss = 0.

        for b_idx, (start, end) in enumerate(train_minibatches):
            inp = inp_train[start:end]
            out = out_train[start:end]
            in_len = input_lens_train[start:end]
            out_len = output_lens_train[start:end]
            in_order = inp_order_train[start:end]
            out_order = out_order_train[start:end]

            # chop input based on length of last instance (for encoder efficiency)
            max_in_len = int(np.amax(in_len))
            inp = inp[:, :max_in_len]
            in_order = in_order[:, :max_in_len]

            # compute max output length and chop output (for decoder efficiency)
            max_out_len = int(np.amax(out_len))
            out = out[:, :max_out_len]
            out_order = out_order[:, :max_out_len]

            in_order = np.asarray(in_order)

            # skip batches whose longest input sentence is too short
            if max_in_len < args.min_sent_length:
                continue

            swap = random.random() > 0.5
            if swap:
                inp, out = out, inp
                in_order, out_order = out_order, in_order

            out_x = np.concatenate(
                [out[:, 1:], np.zeros((out.shape[0], 1))], axis=1)
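            # teacher-forcing targets: out_x is out shifted left by one token
            # and zero-padded, so the prediction at position t is scored
            # against token t + 1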

            # torchify input
            curr_inp = Variable(
                torch.from_numpy(inp.astype('int32')).long().cuda())
            curr_out = Variable(
                torch.from_numpy(out.astype('int32')).long().cuda())
            curr_out_x = Variable(
                torch.from_numpy(out_x.astype('int32')).long().cuda())
            curr_in_order = Variable(
                torch.from_numpy(in_order.astype('int32')).long().cuda())
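            # note: torch.autograd.Variable has been a no-op wrapper since
            # PyTorch 0.4; the tensors could be passed to the model directly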

            # forward prop
            if include_reorder_information:
                preds, attention = model(curr_inp,
                                         curr_out,
                                         curr_in_order,
                                         get_attention=True)
            else:
                preds, attention = model(curr_inp,
                                         curr_out,
                                         None,
                                         None,
                                         get_attention=True)
            preds = preds.view(-1, len(vocab))
            preds = nn.functional.log_softmax(preds, -1)

            num_batches += 1
            # compute masked loss
            loss = criterion(preds, curr_out_x.view(-1))

            if include_coverage_loss:
                # coverage penalty in the style of See et al. (2017):
                # sum_i min(a_t[i], c_t[i]) punishes attention that revisits
                # source positions already covered by earlier decoder steps
                coverage_loss = 0
                attention = attention[1]  # batch_size x max_out_len x max_in_len
                coverage = torch.zeros(
                    (attention.shape[0], attention.shape[2])).cuda()
                for att_idx in range(attention.shape[1]):
                    # c_t accumulates attention over all previous decoder steps
                    c_t = coverage
                    x = torch.min(attention[:, att_idx, :], c_t)
                    coverage_loss += torch.mean(torch.sum(x, 1))
                    coverage = coverage + attention[:, att_idx, :]

                coverage_loss = coverage_loss / attention.shape[1]
                loss_total = loss + coverage_coef * coverage_loss
                cov_loss += coverage_loss.item()

            else:
                loss_total = loss

            optimizer.zero_grad()
            loss_total.backward(retain_graph=False)
            torch.nn.utils.clip_grad_norm_(params, args.grad_clip)
            optimizer.step()
            ep_loss += loss.data.item()

            if b_idx % (args.save_freq) == 0:

                to_print = random.randint(0, len(dev_minibatches) - 1)
                dev_nll = 0.
                for b_dev_idx, (start, end) in enumerate(dev_minibatches):

                    inp = inp_dev[start:end]
                    out = out_dev[start:end]
                    in_len = input_lens_dev[start:end]
                    out_len = output_lens_dev[start:end]
                    in_order = inp_order_dev[start:end]
                    curr_bsz = inp.shape[0]

                    max_in_len = int(np.amax(in_len))
                    inp = inp[:, :max_in_len]
                    in_order = in_order[:, :max_in_len]

                    max_out_len = int(np.amax(out_len))
                    out = out[:, :max_out_len]
                    out_x = np.concatenate(
                        [out[:, 1:], np.zeros((out.shape[0], 1))], axis=1)

                    curr_inp = Variable(
                        torch.from_numpy(inp.astype('int32')).long().cuda())
                    curr_out = Variable(
                        torch.from_numpy(out.astype('int32')).long().cuda())
                    curr_out_x = Variable(
                        torch.from_numpy(out_x.astype('int32')).long().cuda())
                    curr_in_order = Variable(
                        torch.from_numpy(
                            in_order.astype('int32')).long().cuda())

                    if include_reorder_information:
                        preds, _ = model(curr_inp, curr_out, curr_in_order)
                    else:
                        preds, _ = model(curr_inp, curr_out, None, None)
                    preds = preds.view((-1, len(vocab)))
                    preds = nn.functional.log_softmax(preds, -1)

                    bos = Variable(
                        torch.from_numpy(
                            np.asarray([vocab["BOS"]
                                        ]).astype('int32')).long().cuda())
                    loss_dev = criterion(preds, curr_out_x.view(-1))
                    dev_nll += loss_dev.item()

                    preds = preds.view(curr_bsz, max_out_len,
                                       -1).cpu().data.numpy()

                    if b_dev_idx == to_print:
                        for i in range(min(3, curr_bsz)):

                            print('input: %s' % ' '.join([rev_vocab[w] for (j, w) in enumerate(inp[i]) \
                                                          if j < in_len[i]]))
                            print('gt output: %s' % ' '.join([rev_vocab[w] for (j, w) in enumerate(out[i]) \
                                                              if j < out_len[i]]))

                            if include_reorder_information:
                                x = model.generate(
                                    curr_inp[i].unsqueeze(0), [list(bos)],
                                    curr_in_order[i].unsqueeze(0),
                                    beam_size=5,
                                    max_sequence_length=50)[0]
                            else:
                                x = model.generate(curr_inp[i].unsqueeze(0),
                                                   [list(bos)],
                                                   None,
                                                   beam_size=5,
                                                   max_sequence_length=50)[0]
                            preds = [s.output for s in x]
                            print([
                                ' '.join(
                                    [rev_vocab[int(w.data.cpu())] for w in p])
                                for p in preds
                            ][0])
                            print("\n")

                print('dev nll per token: %f' %
                      (dev_nll / float(len(dev_minibatches))))

                print('done with batch %d / %d in epoch %d, loss: %f, cov loss: %f, time:%d' \
                      % (b_idx, len(train_minibatches), ep,
                         ep_loss / num_batches, cov_loss / num_batches, time.time() - start_time))
                print('train nll per token : %f \n' %
                      (float(ep_loss) / float(num_batches)))

                torch.save(
                    {
                        'state_dict': model.state_dict(),
                        'ep_loss': ep_loss / num_batches,
                        'train_minibatches': train_minibatches,
                        'config_args': args
                    }, save_path)

                log_file.write("epoch : %d , batch : %d\n" % (ep, num_batches))
                log_file.write("dev nll: %f \n" %
                               (dev_nll / float(len(dev_minibatches))))
                log_file.write("train nll: %f \n \n" %
                               (float(ep_loss) / float(num_batches)))

                ep_loss = 0.
                num_batches = 0
                start_time = time.time()
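
A runnable miniature of the shifted-target / masked-loss pattern used in the
training loop above, assuming PAD = 0 to match the zero-padded targets (all
names here are illustrative):

import torch
import torch.nn as nn

PAD = 0  # assumed padding index, matching the zero-padding of out_x above
vocab_size, batch, seq = 10, 2, 5

logits = torch.randn(batch, seq, vocab_size)       # stand-in decoder outputs
gold = torch.randint(1, vocab_size, (batch, seq))  # stand-in gold sequences
gold[1, 3:] = PAD                                  # second sentence is shorter

# shift targets left by one token so position t is scored against token t + 1
targets = torch.cat(
    [gold[:, 1:], torch.full((batch, 1), PAD, dtype=torch.long)], dim=1)

log_probs = nn.functional.log_softmax(logits, -1)
loss = nn.NLLLoss(ignore_index=PAD)(log_probs.view(-1, vocab_size),
                                    targets.view(-1))
print(loss.item())  # PAD positions contribute nothing to the loss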
Example #14
def main_worker(args):
    global best_prec1, dtype
    acc = -1
    loss = -1
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save == '':
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_init,
                                world_size=args.world_size,
                                rank=args.local_rank)
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            args.device_ids = [args.local_rank]

    if not os.path.exists(save_path) and not (args.distributed
                                              and args.local_rank > 0):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=args.distributed and args.local_rank > 0)

    results_path = os.path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    model = models.__dict__[args.model]
    dataset_type = 'imagenet' if args.dataset == 'imagenet_calib' else args.dataset
    model_config = {'dataset': dataset_type}

    if args.model_config != '':
        if isinstance(args.model_config, dict):
            for k, v in args.model_config.items():
                if k not in model_config.keys():
                    model_config[k] = v
        else:
            args_dict = literal_eval(args.model_config)
            for k, v in args_dict.items():
                model_config[k] = v
    if (args.absorb_bn or args.load_from_vision
            or args.pretrained) and not args.batch_norn_tuning:
        if args.load_from_vision:
            import torchvision
            model = getattr(torchvision.models,
                            args.load_from_vision)(pretrained=True)
            if 'pytcv' in args.model:
                from pytorchcv.model_provider import get_model as ptcv_get_model
                model_pytcv = ptcv_get_model(args.load_from_vision,
                                             pretrained=True)
                model = convert_pytcv_model(model, model_pytcv)
        else:
            if not os.path.isfile(args.absorb_bn):
                parser.error('invalid checkpoint: {}'.format(args.absorb_bn))
            model = model(**model_config)
            checkpoint = torch.load(args.absorb_bn,
                                    map_location=lambda storage, loc: storage)
            checkpoint = checkpoint.get('state_dict', checkpoint)
            model.load_state_dict(checkpoint, strict=False)
        if 'batch_norm' in model_config and not model_config['batch_norm']:
            logging.info('Creating absorb_bn state dict')
            search_absorbe_bn(model)
            filename_ab = args.absorb_bn + '.absorb_bn' if args.absorb_bn else save_path + '/' + args.model + '.absorb_bn'
            torch.save(model.state_dict(), filename_ab)
        else:
            filename_bn = save_path + '/' + args.model + '.with_bn'
            torch.save(model.state_dict(), filename_bn)
        if (args.load_from_vision
                or args.absorb_bn) and not args.evaluate_init_configuration:
            return

    if 'inception' in args.model:
        model = model(init_weights=False, **model_config)
    else:
        model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)

    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate, map_location="cpu")
        # Override configuration with checkpoint info
        args.model = checkpoint.get('model', args.model)
        args.model_config = checkpoint.get('config', args.model_config)
        if not model_config.get('batch_norm', True):
            search_absorbe_fake_bn(model)
        # load checkpoint
        if 'state_dict' in checkpoint.keys():
            model.load_state_dict(checkpoint['state_dict'])
            logging.info("loaded checkpoint '%s'", args.evaluate)
        else:
            model.load_state_dict(checkpoint, strict=False)
            logging.info("loaded checkpoint '%s'", args.evaluate)

    if args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(checkpoint_file,
                                           'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            if args.start_epoch < 0:  # not explicitly set
                args.start_epoch = checkpoint.get('epoch', 1) - 1
            best_prec1 = checkpoint.get('best_prec1', -1)
            sd = checkpoint.get('state_dict', checkpoint)
            model.load_state_dict(sd, strict=False)
            logging.info("loaded checkpoint '%s' (epoch %s)", checkpoint_file,
                         args.start_epoch)
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    if args.kld_loss:
        criterion = nn.KLDivLoss(reduction='mean')
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration
    optim_regime = getattr(model, 'regime', [{
        'epoch': 0,
        'optimizer': args.optimizer,
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }])
    if args.fine_tune or args.prune:
        if not args.resume:
            args.start_epoch = 0
        if args.update_only_th:
            #optim_regime = [
            #    {'epoch': 0, 'optimizer': 'Adam', 'lr': 1e-4}]
            optim_regime = [{
                'epoch': 0,
                'optimizer': 'SGD',
                'lr': 1e-1
            }, {
                'epoch': 10,
                'lr': 1e-2
            }, {
                'epoch': 15,
                'lr': 1e-3
            }]
        else:
            optim_regime = [{
                'epoch': 0,
                'optimizer': 'SGD',
                'lr': 1e-4,
                'momentum': 0.9
            }, {
                'epoch': 2,
                'lr': 1e-5,
                'momentum': 0.9
            }, {
                'epoch': 10,
                'lr': 1e-6,
                'momentum': 0.9
            }]
    optimizer = optim_regime if isinstance(optim_regime, OptimRegime) \
        else OptimRegime(model, optim_regime, use_float_copy='half' in args.dtype)

    # Training Data loading code

    train_data = DataRegime(getattr(model, 'data_regime', None),
                            defaults={
                                'datasets_path': args.datasets_dir,
                                'name': args.dataset,
                                'split': 'train',
                                'augment': False,
                                'input_size': args.input_size,
                                'batch_size': args.batch_size,
                                'shuffle': not args.seq_adaquant,
                                'num_workers': args.workers,
                                'pin_memory': True,
                                'drop_last': True,
                                'distributed': args.distributed,
                                'duplicates': args.duplicates,
                                'autoaugment': args.autoaugment,
                                'cutout': {
                                    'holes': 1,
                                    'length': 16
                                } if args.cutout else None,
                                'inception_prep': 'inception' in args.model
                            })
    if args.names_sp_layers is None and args.layers_precision_dict is None:
        args.names_sp_layers = [
            key[:-7] for key in model.state_dict().keys()
            if 'weight' in key and 'running' not in key and (
                'conv' in key or 'downsample.0' in key or 'fc' in key)
        ]
        if args.keep_first_last:
            args.names_sp_layers = [
                name for name in args.names_sp_layers if name != 'conv1'
                and name != 'fc' and name != 'Conv2d_1a_3x3.conv'
            ]
        args.names_sp_layers = [
            k for k in args.names_sp_layers if 'downsample' not in k
        ] if args.ignore_downsample else args.names_sp_layers
        if args.num_sp_layers == 0 and not args.keep_first_last:
            args.names_sp_layers = []

    if args.layers_precision_dict is not None:
        print(args.layers_precision_dict)

    pruner = None
    trainer = Trainer(model,
                      pruner,
                      criterion,
                      optimizer,
                      device_ids=args.device_ids,
                      device=args.device,
                      dtype=dtype,
                      distributed=args.distributed,
                      local_rank=args.local_rank,
                      mixup=args.mixup,
                      loss_scale=args.loss_scale,
                      grad_clip=args.grad_clip,
                      print_freq=args.print_freq,
                      adapt_grad_norm=args.adapt_grad_norm,
                      epoch=args.start_epoch,
                      update_only_th=args.update_only_th,
                      optimize_rounding=args.optimize_rounding)

    # Evaluation Data loading code
    args.eval_batch_size = args.eval_batch_size if args.eval_batch_size > 0 else args.batch_size
    dataset_type = 'imagenet' if args.dataset == 'imagenet_calib' else args.dataset
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={
                              'datasets_path': args.datasets_dir,
                              'name': dataset_type,
                              'split': 'val',
                              'augment': False,
                              'input_size': args.input_size,
                              'batch_size': args.eval_batch_size,
                              'shuffle': True,
                              'num_workers': args.workers,
                              'pin_memory': True,
                              'drop_last': False
                          })

    if args.evaluate or args.resume:
        from utils.layer_sensativity import search_replace_layer, extract_save_quant_state_dict, search_replace_layer_from_dict
        if args.layers_precision_dict is not None:
            model = search_replace_layer_from_dict(
                model, ast.literal_eval(args.layers_precision_dict))
        else:
            model = search_replace_layer(model,
                                         args.names_sp_layers,
                                         num_bits_activation=args.nbits_act,
                                         num_bits_weight=args.nbits_weight)

    cached_input_output = {}
    quant_keys = [
        '.weight', '.bias', '.equ_scale', '.quantize_input.running_zero_point',
        '.quantize_input.running_range', '.quantize_weight.running_zero_point',
        '.quantize_weight.running_range',
        '.quantize_input1.running_zero_point', '.quantize_input1.running_range',
        '.quantize_input2.running_zero_point', '.quantize_input2.running_range'
    ]
    if args.adaquant:

        def Qhook(name, module, input, output):
            if module not in cached_qinput:
                cached_qinput[module] = []
            # buffer the data in RAM for now
            cached_qinput[module].append(input[0].detach().cpu())
            # print(name)

        def hook(name, module, input, output):
            if module not in cached_input_output:
                cached_input_output[module] = []
            # buffer the data in RAM for now
            cached_input_output[module].append(
                (input[0].detach().cpu(), output.detach().cpu()))
            # print(name)
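
        # The hooks above cache each quantizable layer's float input/output
        # pair on the CPU during a single calibration pass; the per-layer
        # optimization below fits quantization parameters to this cached data.
        # (cached_qinput, used by Qhook, is only defined later, right before
        # the hook fires; Python resolves closure names at call time.)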

        from models.modules.quantize import QConv2d, QLinear
        handlers = []
        count = 0
        for name, m in model.named_modules():
            if isinstance(m, QConv2d) or isinstance(m, QLinear):
                m.quantize = False
                if count < 1000:
                    # if (isinstance(m, QConv2d) and m.groups == 1) or isinstance(m, QLinear):
                    handlers.append(
                        m.register_forward_hook(partial(hook, name)))
                    count += 1

        # Store input/output for all quantizable layers
        trainer.validate(train_data.get_loader())
        print("Input/outputs cached")

        for handler in handlers:
            handler.remove()

        for m in model.modules():
            if isinstance(m, QConv2d) or isinstance(m, QLinear):
                m.quantize = True
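
        # With the float activations cached and quantization re-enabled, each
        # layer is now tuned independently to minimize the MSE between its
        # quantized output and the cached float output.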

        mse_df = pd.DataFrame(
            index=np.arange(len(cached_input_output)),
            columns=['name', 'bit', 'shape', 'mse_before', 'mse_after'])
        print_freq = 100
        for i, layer in enumerate(cached_input_output):
            if i > 0 and args.seq_adaquant:
                count = 0
                cached_qinput = {}
                for name, m in model.named_modules():
                    if layer.name == name:
                        if count < 1000:
                            handler = m.register_forward_hook(
                                partial(Qhook, name))
                            count += 1
                # Store input/output for all quantizable layers
                trainer.validate(train_data.get_loader())
                print("cashed quant Input%s" % layer.name)
                cached_input_output[layer][0] = (
                    cached_qinput[layer][0], cached_input_output[layer][0][1])
                handler.remove()
            print("\nOptimize {}:{} for {} bit of shape {}".format(
                i, layer.name, layer.num_bits, layer.weight.shape))
            mse_before, mse_after, snr_before, snr_after, kurt_in, kurt_w = \
                optimize_layer(layer, cached_input_output[layer], args.optimize_weights, batch_size=args.batch_size, model_name=args.model)
            print("\nMSE before optimization: {}".format(mse_before))
            print("MSE after optimization:  {}".format(mse_after))
            mse_df.loc[i, 'name'] = layer.name
            mse_df.loc[i, 'bit'] = layer.num_bits
            mse_df.loc[i, 'shape'] = str(layer.weight.shape)
            mse_df.loc[i, 'mse_before'] = mse_before
            mse_df.loc[i, 'mse_after'] = mse_after
            mse_df.loc[i, 'snr_before'] = snr_before
            mse_df.loc[i, 'snr_after'] = snr_after
            mse_df.loc[i, 'kurt_in'] = kurt_in
            mse_df.loc[i, 'kurt_w'] = kurt_w

        mse_csv = args.evaluate + '.mse.csv'
        mse_df.to_csv(mse_csv)

        filename = args.evaluate + '.adaquant'
        torch.save(model.state_dict(), filename)

        train_data = None
        cached_input_output = None
        val_results = trainer.validate(val_data.get_loader())
        logging.info(val_results)

        if args.res_log is not None:
            if not os.path.exists(args.res_log):
                df = pd.DataFrame()
            else:
                df = pd.read_csv(args.res_log, index_col=0)

            ckp = ntpath.basename(args.evaluate)
            if args.cmp is not None:
                ckp += '_{}'.format(args.cmp)
            adaquant_type = 'adaquant_seq' if args.seq_adaquant else 'adaquant_parallel'
            df.loc[ckp, 'acc_' + adaquant_type] = val_results['prec1']
            df.to_csv(args.res_log)
            # print(df)

    elif args.per_layer:
        # Store input/output for all quantizable layers
        calib_all_8_results = trainer.validate(train_data.get_loader())
        print('########## All 8bit results ###########', calib_all_8_results)
        int8_opt_model_state_dict = torch.load(args.int8_opt_model_path)
        int4_opt_model_state_dict = torch.load(args.int4_opt_model_path)

        per_layer_results = {}
        args.names_sp_layers = [
            key[:-7] for key in model.state_dict().keys()
            if 'weight' in key and 'running' not in key and 'quantize' not in
            key and ('conv' in key or 'downsample.0' in key or 'fc' in key)
        ]
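        # per-layer sensitivity sweep: starting from the all-8-bit model, drop
        # one layer at a time to the low-precision parameters, measure accuracy,
        # then restore it to 8 bits before testing the next layer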
        for layer_idx, layer in enumerate(args.names_sp_layers):
            model.load_state_dict(int8_opt_model_state_dict, strict=False)
            model = search_replace_layer(model, [layer],
                                         num_bits_activation=args.nbits_act,
                                         num_bits_weight=args.nbits_weight)
            layer_keys = [
                key for key in int8_opt_model_state_dict
                for qpkey in quant_keys if layer + qpkey == key
            ]
            for key in layer_keys:
                model.state_dict()[key].copy_(int4_opt_model_state_dict[key])
            calib_results = trainer.validate(train_data.get_loader())
            model = search_replace_layer(model, [layer],
                                         num_bits_activation=8,
                                         num_bits_weight=8)
            print('finished %d out of %d' %
                  (layer_idx, len(args.names_sp_layers)))
            logging.info(layer)
            logging.info(calib_results)
            per_layer_results[layer] = {
                'base precision': 8,
                'replaced precision': args.nbits_act,
                'replaced layer': layer,
                'accuracy': calib_results['prec1'],
                'loss': calib_results['loss'],
                'Parameters Size [Elements]':
                model.state_dict()[layer + '.weight'].numel(),
                'MACs': '-'
            }

        torch.save(
            per_layer_results, args.evaluate + '.per_layer_accuracy.A' +
            str(args.nbits_act) + '.W' + str(args.nbits_weight))
        all_8_dict = {
            'base precision': 8,
            'replaced precision': args.nbits_act,
            'replaced layer': '-',
            'accuracy': calib_all_8_results['prec1'],
            'loss': calib_all_8_results['loss'],
            'Parameters Size [Elements]': '-',
            'MACs': '-'
        }
        columns = [key for key in all_8_dict]
        with open(
                args.evaluate + '.per_layer_accuracy.A' + str(args.nbits_act) +
                '.W' + str(args.nbits_weight) + '.csv', "w") as f:
            f.write(",".join(columns) + "\n")
            col = [str(all_8_dict[c]) for c in all_8_dict.keys()]
            f.write(",".join(col) + "\n")
            for layer in per_layer_results:
                r = per_layer_results[layer]
                col = [str(r[c]) for c in r.keys()]
                f.write(",".join(col) + "\n")
    elif args.mixed_builder:
        if isinstance(args.names_sp_layers, list):
            print('loading int8 model" ', args.int8_opt_model_path)
            int8_opt_model_state_dict = torch.load(args.int8_opt_model_path)
            print('loading int4 model" ', args.int4_opt_model_path)
            int4_opt_model_state_dict = torch.load(args.int4_opt_model_path)

            model.load_state_dict(int8_opt_model_state_dict, strict=False)
            model = search_replace_layer(model,
                                         args.names_sp_layers,
                                         num_bits_activation=args.nbits_act,
                                         num_bits_weight=args.nbits_weight)
            for layer_idx, layer in enumerate(args.names_sp_layers):
                layer_keys = [
                    key for key in int8_opt_model_state_dict
                    for qpkey in quant_keys if layer + qpkey == key
                ]
                for key in layer_keys:
                    model.state_dict()[key].copy_(
                        int4_opt_model_state_dict[key])
                print('switched layer %s to 4 bit' % (layer))
        elif isinstance(args.names_sp_layers, dict):
            quant_models = {}
            base_precision = args.precisions[0]
            for m, prec in zip(args.opt_model_paths, args.precisions):
                print('For precision={}, loading {}'.format(prec, m))
                quant_models[prec] = torch.load(m)
            model.load_state_dict(quant_models[base_precision], strict=False)
            for layer_name, nbits_list in args.names_sp_layers.items():
                model = search_replace_layer(model, [layer_name],
                                             num_bits_activation=nbits_list[0],
                                             num_bits_weight=nbits_list[0])
                layer_keys = [
                    key for key in quant_models[base_precision]
                    for qpkey in quant_keys if layer_name + qpkey == key
                ]
                for key in layer_keys:
                    model.state_dict()[key].copy_(
                        quant_models[nbits_list[0]][key])
                print('switched layer {} to {} bit'.format(
                    layer_name, nbits_list[0]))
        if os.environ.get('DEBUG') == 'True':
            from utils.layer_sensativity import check_quantized_model
            fp_names = check_quantized_model(trainer.model)
            if len(fp_names) > 0:
                logging.info('Found FP32 layers in the model:')
                logging.info(fp_names)
        if args.eval_on_train:
            mixedIP_results = trainer.validate(train_data.get_loader())
        else:
            mixedIP_results = trainer.validate(val_data.get_loader())
        torch.save(
            {
                'state_dict': model.state_dict(),
                'config-ip': args.names_sp_layers
            }, args.evaluate + '.mixed-ip-results.' + args.suffix)
        logging.info(mixedIP_results)
        acc = mixedIP_results['prec1']
        loss = mixedIP_results['loss']
    elif args.batch_norn_tuning:
        from utils.layer_sensativity import search_replace_layer, extract_save_quant_state_dict, search_replace_layer_from_dict
        from models.modules.quantize import QConv2d
        if args.layers_precision_dict is not None:
            model = search_replace_layer_from_dict(
                model, literal_eval(args.layers_precision_dict))
        else:
            model = search_replace_layer(model,
                                         args.names_sp_layers,
                                         num_bits_activation=args.nbits_act,
                                         num_bits_weight=args.nbits_weight)
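        # BN tuning: rebuild the float reference model, copy its BN gamma/beta
        # onto the corresponding quantized convolutions, re-attach BN layers,
        # re-estimate their statistics on calibration batches, and finally
        # absorb the tuned BN back into the quantized weights.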

        import torchvision
        model_orig = getattr(torchvision.models,
                             args.load_from_vision)(pretrained=True)
        model_orig.to(args.device, dtype)
        search_copy_bn_params(model_orig)

        layers_orig = dict([(n, m) for n, m in model_orig.named_modules()
                            if isinstance(m, nn.Conv2d)])
        layers_q = dict([(n, m) for n, m in model.named_modules()
                         if isinstance(m, QConv2d)])
        for l in layers_orig:
            conv_orig = layers_orig[l]
            conv_q = layers_q[l]
            conv_q.register_parameter('gamma',
                                      nn.Parameter(conv_orig.gamma.clone()))
            conv_q.register_parameter('beta',
                                      nn.Parameter(conv_orig.beta.clone()))

        del model_orig

        search_add_bn(model)

        print("Run BN tuning")
        for tt in range(args.tuning_iter):
            print(tt)
            trainer.cal_bn_stats(train_data.get_loader())

        search_absorbe_tuning_bn(model)

        filename = args.evaluate + '.bn_tuning'
        print("Save model to: {}".format(filename))
        torch.save(model.state_dict(), filename)

        val_results = trainer.validate(val_data.get_loader())
        logging.info(val_results)

        if args.res_log is not None:
            if not os.path.exists(args.res_log):
                df = pd.DataFrame()
            else:
                df = pd.read_csv(args.res_log, index_col=0)

            ckp = ntpath.basename(args.evaluate)
            df.loc[ckp, 'acc_bn_tuning'] = val_results['prec1']
            df.loc[ckp, 'loss_bn_tuning'] = val_results['loss']
            df.to_csv(args.res_log)
            # print(df)

    elif args.bias_tuning:
        for epoch in range(args.epochs):
            trainer.epoch = epoch
            train_data.set_epoch(epoch)
            val_data.set_epoch(epoch)
            logging.info('\nStarting Epoch: {0}\n'.format(epoch + 1))
            # train for one epoch
            repeat_train = 20 if args.update_only_th else 1
            for tt in range(repeat_train):
                print(tt)
                train_results = trainer.train(
                    train_data.get_loader(),
                    duplicates=train_data.get('duplicates'),
                    chunk_batch=args.chunk_batch)
                logging.info(train_results)

        val_results = trainer.validate(val_data.get_loader())
        logging.info(val_results)
        if args.res_log is not None:
            if not os.path.exists(args.res_log):
                df = pd.DataFrame()
            else:
                df = pd.read_csv(args.res_log, index_col=0)

            ckp = ntpath.basename(args.evaluate)
            if 'bn_tuning' in ckp:
                ckp = ckp.replace('.bn_tuning', '')
            df.loc[ckp, 'acc_bias_tuning'] = val_results['prec1']
            df.to_csv(args.res_log)
    else:
        #print('Please Choose one of the following ....')
        if model_config['measure']:
            results = trainer.validate(train_data.get_loader(), rec=args.rec)
            # results = trainer.validate(val_data.get_loader())
            # print(results)
        else:
            if args.evaluate_init_configuration:
                results = trainer.validate(val_data.get_loader())
                if args.res_log is not None:
                    if not os.path.exists(args.res_log):
                        df = pd.DataFrame()
                    else:
                        df = pd.read_csv(args.res_log, index_col=0)

                    ckp = ntpath.basename(args.evaluate)
                    if args.cmp is not None:
                        ckp += '_{}'.format(args.cmp)
                    df.loc[ckp, 'acc_base'] = results['prec1']
                    df.to_csv(args.res_log)

        if args.extract_bias_mean:
            file_name = ('bias_mean_measure'
                         if model_config['measure'] else 'bias_mean_quant')
            torch.save(trainer.bias_mean, file_name)
        if model_config['measure']:
            filename = args.evaluate + '.measure'
            if 'perC' in args.model_config: filename += '_perC'
            torch.save(model.state_dict(), filename)
            logging.info(results)
        else:
            if args.evaluate_init_configuration:
                logging.info(results)
    return acc, loss
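
A self-contained miniature of the hook-based activation caching used in the
adaquant branch above (module and variable names are illustrative, not the
repository's API):

from functools import partial

import torch
import torch.nn as nn

cached_io = {}


def cache_hook(name, module, inputs, output):
    # buffer each layer's float input/output pair on the CPU for later tuning
    cached_io.setdefault(name, []).append(
        (inputs[0].detach().cpu(), output.detach().cpu()))


model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
handles = [
    m.register_forward_hook(partial(cache_hook, name))
    for name, m in model.named_modules() if isinstance(m, nn.Conv2d)
]
with torch.no_grad():
    model(torch.randn(2, 3, 16, 16))
for h in handles:
    h.remove()  # always detach hooks once calibration data is collected
print({k: len(v) for k, v in cached_io.items()})  # one cached pair per conv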
Example #15
def main_worker(args):
    global best_prec1, dtype
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.eval_path:
        args.results_dir = os.path.join(args.results_dir, 'evaluating_results')
    else:
        args.results_dir = os.path.join(args.results_dir, 'training_results')
    if not os.path.exists(args.results_dir):
        os.mkdir(args.results_dir)

    if args.save == '':
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if not os.path.exists(save_path) and not (args.distributed
                                              and args.local_rank > 0):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=args.distributed and args.local_rank > 0)

    results_path = os.path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    dump_args(args, os.path.join(save_path, 'args.txt'))

    # Pruning and evaluating
    if len(args.pruning_percs) > 0:
        results = []
        checkpoint = torch.load(args.eval_path, map_location="cpu")
        for perc in args.pruning_percs:
            model, criterion = create_model_and_criterion(args)
            model = register_stats_collectors(model)
            model = prune_model(model,
                                checkpoint['state_dict'],
                                prune_perc=perc)
            res = eval_checkpoint(args, model, criterion)['prec1']
            results.append(res)

            # After we gathered min/max statistics we can gather also histograms.
            if args.gather_histograms:
                model = register_hist_collectors(model)
                _ = eval_checkpoint(args, model, criterion)['prec1']

            dump_buffers(model, save_path)
        for perc, res in zip(args.pruning_percs, results):
            print("prune%:", perc, " acc:", res)