Example #1
    def _save_checkpoint(self, epoch, rank1, save_dir, is_best=False):
        save_checkpoint({
            'state_dict': self.model.state_dict(),
            'epoch': epoch + 1,
            'rank1': rank1,
            'optimizer': self.optimizer.state_dict(),
        }, save_dir, is_best=is_best)
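The dictionary written in Example #1 can be restored later to resume training. Below is a minimal round-trip sketch with a stand-in nn.Linear model and a made-up file name, just to show the checkpoint layout; it is not the library's own resume logic.

import torch
import torch.nn as nn

# Stand-ins for the real re-ID model and its optimizer.
model = nn.Linear(512, 751)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Write a checkpoint with the same layout as in Example #1 ...
torch.save({
    'state_dict': model.state_dict(),
    'epoch': 10,
    'rank1': 0.85,
    'optimizer': optimizer.state_dict(),
}, 'demo-checkpoint.pth.tar')

# ... and restore it to resume training.
checkpoint = torch.load('demo-checkpoint.pth.tar', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']
print('resuming at epoch', start_epoch, 'rank1', checkpoint['rank1'])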
Example #2
    def save_model(self, epoch, rank1, save_dir, is_best=False):
        names = self.get_model_names()

        for name in names:
            save_checkpoint(
                {
                    'state_dict': self._models[name].state_dict(),
                    'epoch': epoch + 1,
                    'rank1': rank1,
                    'optimizer': self._optims[name].state_dict(),
                    'scheduler': self._scheds[name].state_dict()
                },
                osp.join(save_dir, name),
                is_best=is_best)
Example #3
    def _save_checkpoint(self, epoch, rank1, save_dir, is_best=False):
        try:
            save_checkpoint(
                {
                    'state_dict': self.model.state_dict(),
                    'epoch': epoch + 1,
                    'rank1': rank1,
                    'optimizer': self.optimizer.state_dict(),
                    'scheduler': self.scheduler.state_dict(),
                },
                save_dir,
                is_best=is_best)
        except Exception:
            print('could not save checkpoint for epoch: ' + str(epoch))
Example #4
def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                     description='The script adds the initial LR value '
                                                 'to deep-object-reid checkpoints '
                                                 '-- this allows using them for NNCF, '
                                                 'since the NNCF part will initialize its LR '
                                                 "from the checkpoint's LR")
    parser.add_argument('--lr', type=float, required=True,
                        help='initial LR value to store in the checkpoint')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='path to the src checkpoint file')
    parser.add_argument('--dst-folder', type=str, required=True,
                        help='path to the dst folder to store dst checkpoint file')
    args = parser.parse_args()

    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    if not isinstance(checkpoint, dict):
        raise RuntimeError('Wrong format of checkpoint -- it is not the result of deep-object-reid training')
    if checkpoint.get('initial_lr'):
        raise RuntimeError(f'Checkpoint {args.checkpoint} already has initial_lr')

    if not os.path.isdir(args.dst_folder):
        raise RuntimeError(f'The dst folder {args.dst_folder} is NOT present')

    checkpoint['initial_lr'] = float(args.lr)
    res_path = save_checkpoint(checkpoint, args.dst_folder)
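Once 'initial_lr' has been added, downstream code can read it back when rebuilding an optimizer. A short sketch, assuming a checkpoint produced by the script above (the path and stand-in model are hypothetical):

import torch
import torch.nn as nn

ckpt = torch.load('dst_folder/checkpoint.pth', map_location='cpu')  # hypothetical path
lr = float(ckpt.get('initial_lr', 0.01))   # fall back to a default if the field is absent
model = nn.Linear(512, 751)                # stand-in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
print('optimizer rebuilt with initial LR', lr)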
Example #5
    def save_model(self,
                   epoch,
                   save_dir,
                   is_best=False,
                   should_save_ema_model=False):
        def create_sym_link(path, name):
            if osp.lexists(name):
                os.remove(name)
            os.symlink(path, name)

        names = self.get_model_names()
        for name in names:
            if should_save_ema_model and name == self.main_model_name:
                assert self.use_ema_decay
                model_state_dict = self.ema_model.module.state_dict()
            else:
                model_state_dict = self.models[name].state_dict()

            checkpoint = {
                'state_dict': model_state_dict,
                'epoch': epoch + 1,
                'optimizer': self.optims[name].state_dict(),
                'scheduler': self.scheds[name].state_dict(),
                'num_classes': self.datamanager.num_train_pids,
                'classes_map': self.datamanager.train_loader.dataset.classes,
                'initial_lr': self.initial_lr,
            }

            if self.compression_ctrl is not None:
                checkpoint['compression_state'] = self.compression_ctrl.get_compression_state()
                checkpoint['nncf_metainfo'] = self.nncf_metainfo

            ckpt_path = save_checkpoint(checkpoint,
                                        osp.join(save_dir, name),
                                        is_best=is_best,
                                        name=name)

            if name == self.main_model_name:
                latest_ckpt_filename = 'latest.pth'
                best_ckpt_filename = 'best.pth'
            else:
                latest_ckpt_filename = f'latest_{name}.pth'
                best_ckpt_filename = f'best_{name}.pth'

            latest_name = osp.join(save_dir, latest_ckpt_filename)
            create_sym_link(ckpt_path, latest_name)
            if is_best:
                best_model = osp.join(save_dir, best_ckpt_filename)
                create_sym_link(ckpt_path, best_model)
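Because of the latest.pth / best.pth symlinks, consumers never need to know the epoch-numbered file name. A small reading-side sketch, assuming the directory layout created by save_model above (the save_dir value is hypothetical):

import os.path as osp
import torch

save_dir = 'output/experiment1'            # hypothetical
latest = osp.join(save_dir, 'latest.pth')  # symlink to the newest main-model checkpoint
if osp.lexists(latest):
    ckpt = torch.load(latest, map_location='cpu')   # torch.load follows the symlink
    print('latest checkpoint: epoch', ckpt['epoch'], 'num_classes', ckpt['num_classes'])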
Example #6
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='The script adds the default int8 quantization NNCF metainfo '
                    'to NNCF deep-object-reid checkpoints '
                    'that were trained when NNCF metainfo was not '
                    'stored in NNCF checkpoints')
    parser.add_argument('--config-file',
                        type=str,
                        required=True,
                        help='path to config file')
    parser.add_argument('--checkpoint',
                        type=str,
                        required=True,
                        help='path to the src checkpoint file')
    parser.add_argument(
        '--dst-folder',
        type=str,
        required=True,
        help='path to the dst folder to store dst checkpoint file')
    args = parser.parse_args()

    cfg = get_default_config()
    merge_from_files_with_base(cfg, args.config_file)
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            'Wrong format of checkpoint -- it is not the result of deep-object-reid training'
        )
    if checkpoint.get('nncf_metainfo'):
        raise RuntimeError(
            f'Checkpoint {args.checkpoint} already has nncf_metainfo')

    if not os.path.isdir(args.dst_folder):
        raise RuntimeError(f'The dst folder {args.dst_folder} is NOT present')

    # default nncf config
    h, w = cfg.data.height, cfg.data.width
    nncf_config_data = get_default_nncf_compression_config(h, w)

    nncf_metainfo = {
        'nncf_compression_enabled': True,
        'nncf_config': nncf_config_data
    }
    checkpoint['nncf_metainfo'] = nncf_metainfo
    res_path = save_checkpoint(checkpoint, args.dst_folder)
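A quick way to sanity-check the patched checkpoint, assuming it was written by the script above (the path is hypothetical):

import torch

ckpt = torch.load('dst_folder/checkpoint.pth', map_location='cpu')  # hypothetical path
meta = ckpt['nncf_metainfo']
print('compression enabled:', meta['nncf_compression_enabled'])
print('nncf_config keys:', sorted(meta['nncf_config'].keys()))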
Example #7
def main():
    global args

    set_random_seed(args.seed)
    use_gpu = torch.cuda.is_available() and not args.use_cpu
    log_name = 'test.log' if args.evaluate else 'train.log'
    sys.stdout = Logger(osp.join(args.save_dir, log_name))

    print('** Arguments **')
    arg_keys = list(args.__dict__.keys())
    arg_keys.sort()
    for key in arg_keys:
        print('{}: {}'.format(key, args.__dict__[key]))
    print('\n')
    print('Collecting env info ...')
    print('** System info **\n{}\n'.format(collect_env_info()))

    if use_gpu:
        torch.backends.cudnn.benchmark = True
    else:
        warnings.warn(
            'Currently using CPU, however, GPU is highly recommended')

    dataset_vars = init_dataset(use_gpu)
    trainloader, valloader, testloader, num_attrs, attr_dict = dataset_vars

    if args.weighted_bce:
        print('Use weighted binary cross entropy')
        print('Computing the weights ...')
        bce_weights = torch.zeros(num_attrs, dtype=torch.float)
        for _, attrs, _ in trainloader:
            bce_weights += attrs.sum(0)  # sum along the batch dim
        bce_weights /= len(trainloader) * args.batch_size
        print('Sample ratio for each attribute: {}'.format(bce_weights))
        bce_weights = torch.exp(-1 * bce_weights)
        print('BCE weights: {}'.format(bce_weights))
        bce_weights = bce_weights.expand(args.batch_size, num_attrs)
        criterion = nn.BCEWithLogitsLoss(weight=bce_weights)

    else:
        print('Use plain binary cross entropy')
        criterion = nn.BCEWithLogitsLoss()

    print('Building model: {}'.format(args.arch))
    model = models.build_model(args.arch,
                               num_attrs,
                               pretrained=not args.no_pretrained,
                               use_gpu=use_gpu)
    num_params, flops = compute_model_complexity(
        model, (1, 3, args.height, args.width))
    print('Model complexity: params={:,} flops={:,}'.format(num_params, flops))

    if args.load_weights and check_isfile(args.load_weights):
        load_pretrained_weights(model, args.load_weights)

    if use_gpu:
        model = nn.DataParallel(model).cuda()
        criterion = criterion.cuda()

    if args.evaluate:
        test(model, testloader, attr_dict, use_gpu)
        return

    optimizer = torchreid.optim.build_optimizer(model,
                                                **optimizer_kwargs(args))
    scheduler = torchreid.optim.build_lr_scheduler(optimizer,
                                                   **lr_scheduler_kwargs(args))

    start_epoch = args.start_epoch
    best_result = -np.inf
    if args.resume and check_isfile(args.resume):
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        best_result = checkpoint['label_mA']
        print('Loaded checkpoint from "{}"'.format(args.resume))
        print('- start epoch: {}'.format(start_epoch))
        print('- label_mA: {}'.format(best_result))

    time_start = time.time()

    for epoch in range(start_epoch, args.max_epoch):
        train(epoch, model, criterion, optimizer, scheduler, trainloader,
              use_gpu)
        test_outputs = test(model, testloader, attr_dict, use_gpu)
        label_mA = test_outputs[0]
        is_best = label_mA > best_result
        if is_best:
            best_result = label_mA

        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'epoch': epoch + 1,
                'label_mA': label_mA,
                'optimizer': optimizer.state_dict(),
            },
            args.save_dir,
            is_best=is_best)

    elapsed = round(time.time() - time_start)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print('Elapsed {}'.format(elapsed))
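The weighted-BCE branch in Example #7 turns each attribute's positive-sample ratio p into a weight exp(-p), so rare attributes keep weights close to 1 while very common ones are down-weighted. A tiny standalone illustration of that transform with toy ratios (not real dataset statistics):

import torch

# Toy positive-sample ratios for four attributes.
ratios = torch.tensor([0.05, 0.30, 0.50, 0.90])
weights = torch.exp(-ratios)     # same transform as bce_weights above
print(weights)                   # tensor([0.9512, 0.7408, 0.6065, 0.4066])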