Example #1
def load_model_weight(model, model_weight_file):
    assert osp.exists(model_weight_file), "model_weight_file {} does not exist!".format(model_weight_file)
    assert osp.isfile(model_weight_file), "model_weight_file {} is not file!".format(model_weight_file)
    model_weight = torch.load(model_weight_file, map_location=(lambda storage, loc: storage))
    load_state_dict(model, model_weight)
    msg = '=> Loaded model_weight from {}'.format(model_weight_file)
    print(msg)
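
Note: load_state_dict here is a project-specific helper, not the built-in nn.Module method. A minimal sketch of such a tolerant loader, assuming it strips DataParallel prefixes and skips keys whose name or shape does not match, could look like:

def load_state_dict(model, state_dict):
    # Strip the 'module.' prefix left behind by nn.DataParallel, if present.
    state_dict = {k[len('module.'):] if k.startswith('module.') else k: v
                  for k, v in state_dict.items()}
    own_state = model.state_dict()
    # Keep only entries whose name and shape match the target model.
    filtered = {k: v for k, v in state_dict.items()
                if k in own_state and own_state[k].shape == v.shape}
    own_state.update(filtered)
    model.load_state_dict(own_state)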
Example #2
def get_generator(args, config):
    # Import the model.
    model = __import__(config["model"])
    G = model.Generator(**config).to(config["device"])
    utils.count_parameters(G)

    utils.load_state_dict(G,
                          torch.load(args.weights_path),
                          strict=not args.not_strict)
    G.eval()
    return G
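
Note: __import__(config["model"]) returns the top-level module object; importlib.import_module is the clearer modern equivalent. A sketch, assuming config["model"] names an importable module that defines a Generator class:

import importlib

model = importlib.import_module(config["model"])
G = model.Generator(**config)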
Example #3
def export_onnx(args):
    model_name = args.model
    if model_name in models.model_zoo:
        model, args = models.get_model(args)
    else:
        print("model(%s) not support, available models: %r" %
              (model_name, models.model_zoo))
        return

    if utils.check_file(args.old):
        print("load pretrained from %s" % args.old)
        if torch.cuda.is_available():
            checkpoint = torch.load(args.old)
        else:  # force cpu mode
            checkpoint = torch.load(args.old, map_location='cpu')
        print("load pretrained ==> last epoch: %d" %
              checkpoint.get('epoch', 0))
        print("load pretrained ==> last best_acc: %f" %
              checkpoint.get('best_acc', 0))
        print("load pretrained ==> last learning_rate: %f" %
              checkpoint.get('learning_rate', 0))
        try:
            utils.load_state_dict(model, checkpoint.get('state_dict', None))
        except RuntimeError:
            print("Loading pretrained model failed")
    else:
        print(
            "no pretrained file exists({}), init model with default initializer".
            format(args.old))

    onnx_model = torch.nn.Sequential(
        OrderedDict([
            ('network', model),
            ('softmax', torch.nn.Softmax(dim=1)),  # explicit dim; the implicit-dim form is deprecated
        ]))

    onnx_path = "onnx/" + model_name
    if not os.path.exists(onnx_path):
        os.makedirs(onnx_path)
    onnx_save = onnx_path + "/" + model_name + '.onnx'

    input_names = ["input"]
    dummy_input = torch.zeros((1, 3, args.input_size, args.input_size))
    output_names = ['prob']
    torch.onnx.export(onnx_model,
                      dummy_input,
                      onnx_save,
                      verbose=True,
                      input_names=input_names,
                      output_names=output_names,
                      opset_version=7,
                      keep_initializers_as_inputs=True)
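
Note: after exporting, it is worth validating the graph and running it once. A quick sanity check, assuming onnx and onnxruntime are installed (this is not part of the original script):

import numpy as np
import onnx
import onnxruntime

onnx.checker.check_model(onnx.load(onnx_save))
sess = onnxruntime.InferenceSession(onnx_save, providers=['CPUExecutionProvider'])
dummy = np.zeros((1, 3, args.input_size, args.input_size), dtype=np.float32)
prob = sess.run(['prob'], {'input': dummy})[0]  # names match the export call above
print(prob.shape)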
Example #4
    def _resume_checkpoint(self, resume_path, state_dict_only=False):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        :param state_dict_only: If True, load only the model weights and skip
            epoch, criterion and optimizer state
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)

        if not state_dict_only:
            if 'epoch' in checkpoint:
                self.start_epoch = checkpoint['epoch'] + 1

            if 'monitor_best' in checkpoint:
                self.mnt_best = checkpoint['monitor_best']

            # load architecture params from checkpoint.
            if checkpoint['config']['arch'] != self.config['arch']:
                self.logger.warning(
                    "Warning: Architecture configuration given in config file is different from that of "
                    "checkpoint. This may yield an exception while state_dict is being loaded."
                )

        state_dict = checkpoint['state_dict']
        if state_dict_only:
            rename_parallel_state_dict(state_dict)

        # self.model.load_state_dict(state_dict)
        load_state_dict(self.model, state_dict)

        if not state_dict_only:
            if 'criterion' in checkpoint:
                load_state_dict(self.criterion, checkpoint['criterion'])
                self.logger.info("Criterion state dict is loaded")
            else:
                self.logger.info(
                    "Criterion state dict is not found, so it's not loaded.")

            # load optimizer state from checkpoint only when optimizer type is not changed.
            if 'optimizer' in checkpoint:
                if checkpoint['config']['optimizer']['type'] != self.config[
                        'optimizer']['type']:
                    self.logger.warning(
                        "Warning: Optimizer type given in config file is different from that of checkpoint. "
                        "Optimizer parameters not being resumed.")
                else:
                    self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(
                self.start_epoch))
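
Note: the method above implies a particular checkpoint layout. A sketch of the dict a matching save routine would write, with the key names taken from the reads above:

torch.save({
    'epoch': epoch,
    'monitor_best': self.mnt_best,
    'config': self.config,                    # must contain 'arch' and 'optimizer' sections
    'state_dict': self.model.state_dict(),
    'criterion': self.criterion.state_dict(),
    'optimizer': self.optimizer.state_dict(),
}, resume_path)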
Example #5
    def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader,
                 valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config

        self.distill = config._config.get('distill', False)
        
        # add_extra_info will return info about the individual experts, which is crucial for the
        # individual losses. If it is false, we can only get the final mean logits.
        self.add_extra_info = config._config.get('add_extra_info', False)

        if self.distill:
            print("** Distill is on, please double check distill_checkpoint in config **")
            self.teacher_model = config.init_obj('distill_arch', module_arch)
            teacher_checkpoint = torch.load(config['distill_checkpoint'], map_location="cpu")

            self.teacher_model = self.teacher_model.to(self.device)

            teacher_state_dict = teacher_checkpoint["state_dict"]

            rename_parallel_state_dict(teacher_state_dict)
            
            if len(self.device_ids) > 1:
                print("Using multiple GPUs for teacher model")
                self.teacher_model = torch.nn.DataParallel(self.teacher_model, device_ids=self.device_ids)
                load_state_dict(self.teacher_model, {"module." + k: v for k, v in teacher_state_dict.items()}, no_ignore=True)
            else:
                load_state_dict(self.teacher_model, teacher_state_dict, no_ignore=True)

        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch

        if use_fp16:
            self.logger.warning("FP16 is enabled. Use this option with caution: make sure it actually works for your setup, as no guarantee is provided.")
            from torch.cuda.amp import GradScaler
            self.scaler = GradScaler()
        else:
            self.scaler = None

        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
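
Note: rename_parallel_state_dict is project-specific. A minimal in-place version that strips the 'module.' prefix written by nn.DataParallel (the prefix the multi-GPU branch above then re-adds) could be:

def rename_parallel_state_dict(state_dict):
    # Mutates state_dict in place: 'module.layer.weight' -> 'layer.weight'
    for k in list(state_dict.keys()):
        if k.startswith('module.'):
            state_dict[k[len('module.'):]] = state_dict.pop(k)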
Example #6
def inception_v3(pretrained=False, **kwargs):
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
        model = Inception3(**kwargs)
        utils.load_state_dict(
            model, model_zoo.load_url(model_urls['inception_v3_google']))
        return model

    return Inception3(**kwargs)
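
Note: torch.utils.model_zoo.load_url is the older name for what recent PyTorch exposes as torch.hub.load_state_dict_from_url; both download the file to a local cache and return the state dict:

from torch.hub import load_state_dict_from_url

state = load_state_dict_from_url(model_urls['inception_v3_google'], progress=True)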
Example #7
    def load(self, path, loc='cpu', verbose=False):
        state_dict = load_state_dict(path, loc)
        results = self.model.load_state_dict(state_dict, strict=False)
        if verbose:
            print('Loaded from ', path)
            print(results)
        self.start_epoch = int(args.load.split('Epoch')[-1])
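
Note: the last line assumes a global args whose load path embeds the epoch number, e.g. (hypothetical path):

trainer.load('snap/pretrain/Epoch30')  # with args.load == 'snap/pretrain/Epoch30', start_epoch becomes 30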
Example #8
def main(args, trainloader, testloader):
    net = utils.get_model(args, target_name, target_mode)
    net = torch.nn.DataParallel(net).to(device)

    utils.load_state_dict(net, target_path)
    net.eval()

    ece_train = eval_ece(net, trainloader, n_bins=n_bins)
    ece_test = eval_ece(net, testloader, n_bins=n_bins)

    print("The Expected Calibration Accuracy on Training Set is %.03f." %
          (ece_train))
    print("The Expected Calibration Accuracy on Test Set is %.03f." %
          (ece_test))
    print(
        "=============================================================================="
    )
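
Note: eval_ece is project-specific. ECE (expected calibration error) bins predictions by confidence and averages |accuracy - confidence| weighted by bin mass; a minimal sketch over precomputed probabilities:

import torch

def ece_sketch(probs, labels, n_bins=15):
    # probs: (N, C) softmax outputs; labels: (N,) integer targets
    conf, pred = probs.max(dim=1)
    edges = torch.linspace(0, 1, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (conf > lo) & (conf <= hi)
        if mask.any():
            acc = (pred[mask] == labels[mask]).float().mean()
            ece += mask.float().mean() * (acc - conf[mask].mean()).abs()
    return float(ece)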
Example #9
def get_pretrained_model(include_top=False, pretrain_kind='imagenet', model_name='resnet50'):
    if pretrain_kind == 'vggface2':
        N_IDENTITY = 8631  # the number of identities in VGGFace2 for which ResNet and SENet are trained
        resnet50_weight_file = 'weights/resnet50_ft_weight.pkl'
        senet_weight_file = 'weights/senet50_ft_weight.pkl'

        if model_name == 'resnet50':
            model = ResNet.resnet50(num_classes=N_IDENTITY, include_top=include_top).eval()
            utils.load_state_dict(model, resnet50_weight_file)
            return model

        elif model_name == 'senet50':
            model = SeNet.senet50(num_classes=N_IDENTITY, include_top=include_top).eval()
            utils.load_state_dict(model, senet_weight_file)
            return model
    elif pretrain_kind == 'imagenet':
        return nn.Sequential(*list(models.resnet50(pretrained=True).children())[:-2])
    return None
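
Note: the VGGFace2 .pkl weight files hold a pickled dict of numpy arrays rather than a torch state dict, so utils.load_state_dict has to convert. A sketch, assuming latin1-pickled numpy weights keyed by parameter name:

import pickle
import torch

def load_state_dict(model, fname):
    with open(fname, 'rb') as f:
        weights = pickle.load(f, encoding='latin1')  # dict: name -> numpy array
    own_state = model.state_dict()
    for name, arr in weights.items():
        if name in own_state:
            own_state[name].copy_(torch.from_numpy(arr))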
Example #10
    def load_checkpoint(self):
        checkpoint_file = self.config.checkpoint
        resume_optim = self.config.resume_optim
        if osp.isfile(checkpoint_file):
            loc_func = None if self.config.cuda else lambda storage, loc: storage
            # map_location: specify how to remap storage
            checkpoint = torch.load(checkpoint_file, map_location=loc_func)
            self.best_model = checkpoint
            load_state_dict(self.model, checkpoint["model_state_dict"])

            self.start_epoch = checkpoint['epoch']

            # Is this meaningful !?
            if 'criterion_state_dict' in checkpoint:  # dict.has_key() was removed in Python 3
                c_state = checkpoint['criterion_state_dict']
                # retrieve key in train_criterion
                append_dict = {
                    k: torch.Tensor([0, 0])
                    for k, _ in self.train_criterion.named_parameters()
                    if k not in c_state
                }
                # load zeros into state_dict
                c_state.update(append_dict)
                self.train_criterion.load_state_dict(c_state)

            print("Loaded checkpoint {:s} epoch {:d}".format(
                checkpoint_file, checkpoint['epoch']
            ))
            print("Loss of loaded model = {}".format(checkpoint['loss']))

            if resume_optim:
                print("Load parameters in optimizer")
                self.optimizer.load_state_dict(checkpoint["optim_state_dict"])
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.cuda()
            else:
                print("Notice: load checkpoint but didn't load optimizer.")
        else:
            print("Can't find specified checkpoint.!")
            exit(-1)
Example #11
def load_lxmert_qa(args, path, model, label2ans, verbose=False, loc='cpu'):
    """
    Load model weights from LXMERT pre-training.
    The answers in the fine-tuned QA task (indicated by label2ans)
        would also be properly initialized with LXMERT pre-trained
        QA heads.

    :param path: Path to LXMERT snapshot.
    :param model: LXRT model instance.
    :param label2ans: The label2ans dict of fine-tuned QA datasets, like
        {0: 'cat', 1: 'dog', ...}
    :return:
    """
    if verbose:
        print("Load QA pre-trained LXMERT from %s " % path)

    loaded_state_dict = load_state_dict(path, loc)
    model_state_dict = model.state_dict()

    # Do surgery on answer state dict
    ans_weight = loaded_state_dict['answer_head.logit_fc.3.weight']
    ans_bias = loaded_state_dict['answer_head.logit_fc.3.bias']

    new_answer_weight = copy.deepcopy(
        model_state_dict['answer_head.logit_fc.3.weight'])
    new_answer_bias = copy.deepcopy(
        model_state_dict['answer_head.logit_fc.3.bias'])
    answer_table = AnswerTable(args)
    loaded = 0
    unload = 0
    if type(label2ans) is list:
        label2ans = {label: ans for label, ans in enumerate(label2ans)}
    for label, ans in label2ans.items():
        new_ans = answer_table.convert_ans(ans)
        if answer_table.used(new_ans):
            ans_id_9500 = answer_table.ans2id(new_ans)
            new_answer_weight[label] = ans_weight[ans_id_9500]
            new_answer_bias[label] = ans_bias[ans_id_9500]
            loaded += 1
        else:
            new_answer_weight[label] = 0.
            new_answer_bias[label] = 0.
            unload += 1
    if verbose:
        print("Loaded %d answers from LXRTQA pre-training and %d not" %
              (loaded, unload))
        print()
    loaded_state_dict['answer_head.logit_fc.3.weight'] = new_answer_weight
    loaded_state_dict['answer_head.logit_fc.3.bias'] = new_answer_bias

    result = model.load_state_dict(loaded_state_dict, strict=False)
    if verbose:
        print(result)
Example #12
def main(args):
    # create model
    if "resnet" in args.backbone or "resnext" in args.backbone:
        print('resnet', args.att)
        model = ResNet(args)
    elif 'b' in args.backbone:
        model = EfficientNet.from_pretrained(f'efficientnet-{args.backbone}',
                                             8)

    if args.input_level == 'per-study':
        # add decoder if train per-study
        if args.conv_lstm:
            decoder = ConvDecoder(args)
        else:
            decoder = Decoder(args)

        encoder = model
        model = (encoder, decoder)

    if args.input_level == 'per-study':
        model[0].cuda(), model[1].cuda()
    else:
        model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))

            checkpoint = torch.load(args.resume, "cpu")
            input_level = checkpoint['input_level']
            assert input_level == args.input_level
            if args.input_level == 'per-study':
                encoder, decoder = model
                load_state_dict(checkpoint.pop('encoder'), encoder)
                load_state_dict(checkpoint.pop('decoder'), decoder)
            else:
                load_state_dict(checkpoint.pop('state_dict'), model)

            # load_state_dict(checkpoint.pop('state_dict'), model)
            epoch = checkpoint['epoch']
            best_loss = checkpoint['best_loss']
            print(
                f"=> loaded checkpoint '{args.resume}' (loss {best_loss:.4f}@{epoch})"
            )
        else:
            raise ValueError("=> no checkpoint found at '{}'".format(
                args.resume))

    # if args.to_stack:
    #     loader = get_test_dl(args)
    #     to_submit(args, model, loader)
    # else:
    if args.val:
        val_dl = get_val_dl(args)
        to_stacking_on_val(args, model, val_dl)
    else:
        test_dl = get_test_dl(args)
        to_stacking_on_test(args, model, test_dl)
Example #13
    def _load_crt(self, cRT_pretrain):
        """
        Load from cRT (classifier re-training) pre-training
        :param cRT_pretrain: path to the checkpoint of the cRT pre-training
        """
        state_dict = torch.load(cRT_pretrain)['state_dict']
        ignore_linear = True

        rename_parallel_state_dict(state_dict)

        if ignore_linear:
            for k in list(state_dict.keys()):
                if k.startswith('backbone.linear'):
                    state_dict.pop(k)
                    print("Popped", k)
        load_state_dict(self.real_model, state_dict)
        for name, param in self.real_model.named_parameters():
            if not name.startswith('backbone.linear'):
                param.requires_grad_(False)
            else:
                print("Allow gradient on:", name)
        print("** Please check the list of allowed gradient to confirm **")
Example #14
File: trainer_base.py, Project: j-min/VL-T5
    def load_checkpoint(self, ckpt_path):
        state_dict = load_state_dict(ckpt_path, 'cpu')

        original_keys = list(state_dict.keys())
        for key in original_keys:
            if key.startswith("vis_encoder."):
                new_key = 'encoder.' + key[len("vis_encoder."):]
                state_dict[new_key] = state_dict.pop(key)

            if key.startswith("model.vis_encoder."):
                new_key = 'model.encoder.' + key[len("model.vis_encoder."):]
                state_dict[new_key] = state_dict.pop(key)

        results = self.model.load_state_dict(state_dict, strict=False)
        if self.verbose:
            print('Model loaded from ', ckpt_path)
            pprint(results)
Example #15
import utils
import torch

torch_state_dict_path = '/local/mnt2/workspace2/tkuai/cnn_audio_denoiser/pytorch_model3/model2/AudioDenoiser_512_0.001_49.pt'
model = utils.load_state_dict(torch_state_dict_path)
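
Note: the one-argument utils.load_state_dict here apparently returns a ready model object. With a plain state-dict file, the stock PyTorch equivalent would be (a sketch; the AudioDenoiser constructor is an assumption):

model = AudioDenoiser()  # hypothetical model class
model.load_state_dict(torch.load(torch_state_dict_path, map_location='cpu'))
model.eval()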
Example #16
def main():
    parser = argparse.ArgumentParser("PyTorch Face Recognizer")
    parser.add_argument('--cmd',
                        default='extract',
                        type=str,
                        choices=['train', 'test', 'extract'],
                        help='train, test or extract')

    parser.add_argument('--arch_type',
                        type=str,
                        default='senet50_ft',
                        help='model type',
                        choices=[
                            'resnet50_ft', 'senet50_ft', 'resnet50_scratch',
                            'senet50_scratch'
                        ])

    parser.add_argument('--dataset_dir',
                        type=str,
                        default='/tmp/Datasets/3Dto2D/squared/uniques',
                        help='dataset directory')

    parser.add_argument('--log_file',
                        type=str,
                        default='/path/to/log_file',
                        help='log file')
    parser.add_argument(
        '--train_img_list_file',
        type=str,
        default='/path/to/train_image_list.txt',
        help='text file containing image files used for training')
    parser.add_argument(
        '--test_img_list_file',
        type=str,
        default='/path/to/test_image_list.txt',
        help=
        'text file containing image files used for validation, test or feature extraction'
    )
    parser.add_argument(
        '--meta_file',
        type=str,
        default='/tmp/face-hallucination/style/vgg-face/identity_meta.csv',
        help='meta file')
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='/path/to/checkpoint_directory',
                        help='checkpoints directory')
    parser.add_argument('--feature_dir',
                        type=str,
                        default='/path/to/feature_directory',
                        help='directory where extracted features are saved')
    parser.add_argument(
        '-c',
        '--config',
        type=int,
        default=1,
        choices=configurations.keys(),
        help='the number of settings and hyperparameters used in training')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='batch size')
    parser.add_argument('--resume',
                        type=str,
                        default='',
                        help='checkpoint file')
    parser.add_argument(
        '--weight_file',
        type=str,
        default=
        '/tmp/face-hallucination/style/vgg-face/models/senet50_ft_weight.pkl',
        help='weight file')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('-j',
                        '--workers',
                        default=4,
                        type=int,
                        metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument(
        '--horizontal_flip',
        action='store_true',
        help='horizontally flip images specified in test_img_list_file')
    args = parser.parse_args()
    print(args)

    if args.cmd == "extract":
        utils.create_dir(args.feature_dir)

    if args.cmd == 'train':
        utils.create_dir(args.checkpoint_dir)
        cfg = configurations[args.config]

    log_file = args.log_file
    resume = args.resume

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()
    if cuda:
        print("torch.backends.cudnn.version: {}".format(
            torch.backends.cudnn.version()))

    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)

    # 0. id label map
    meta_file = args.meta_file
    id_label_dict = utils.get_id_label_map(meta_file)

    # 1. data loader
    root = args.dataset_dir
    train_img_list_file = args.train_img_list_file
    test_img_list_file = args.test_img_list_file

    kwargs = {'num_workers': args.workers, 'pin_memory': True} if cuda else {}

    if args.cmd == 'train':
        dt = datasets.VGG_Faces2(root,
                                 train_img_list_file,
                                 id_label_dict,
                                 split='train')
        train_loader = torch.utils.data.DataLoader(dt,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)

    dv = datasets.VGG_Faces2(root,
                             test_img_list_file,
                             id_label_dict,
                             split='valid',
                             horizontal_flip=args.horizontal_flip)
    val_loader = torch.utils.data.DataLoader(dv,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             **kwargs)

    # 2. model
    include_top = args.cmd != 'extract'
    if 'resnet' in args.arch_type:
        model = ResNet.resnet50(num_classes=N_IDENTITY,
                                include_top=include_top)
    else:
        model = SENet.senet50(num_classes=N_IDENTITY, include_top=include_top)
    # print(model)

    start_epoch = 0
    start_iteration = 0
    if resume:
        checkpoint = torch.load(resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        start_epoch = checkpoint['epoch']
        start_iteration = checkpoint['iteration']
        assert checkpoint['arch'] == args.arch_type
        print("Resume from epoch: {}, iteration: {}".format(
            start_epoch, start_iteration))
    else:
        utils.load_state_dict(model, args.weight_file)
        if args.cmd == 'train':
            model.fc.reset_parameters()

    if cuda:
        model = model.cuda()

    criterion = nn.CrossEntropyLoss()
    if cuda:
        criterion = criterion.cuda()

    # 3. optimizer
    if args.cmd == 'train':
        optim = torch.optim.SGD([
            {
                'params': get_parameters(model, bias=False)
            },
            {
                'params': get_parameters(model, bias=True),
                'lr': cfg['lr'] * 2,
                'weight_decay': 0
            },
        ],
                                lr=cfg['lr'],
                                momentum=cfg['momentum'],
                                weight_decay=cfg['weight_decay'])
        if resume:
            optim.load_state_dict(checkpoint['optim_state_dict'])

        # lr_policy: step
        last_epoch = start_iteration if resume else -1
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optim,
                                                       cfg['step_size'],
                                                       gamma=cfg['gamma'],
                                                       last_epoch=last_epoch)

    if args.cmd == 'train':
        trainer = Trainer(
            cmd=args.cmd,
            cuda=cuda,
            model=model,
            criterion=criterion,
            optimizer=optim,
            lr_scheduler=lr_scheduler,
            train_loader=train_loader,
            val_loader=val_loader,
            log_file=log_file,
            max_iter=cfg['max_iteration'],
            checkpoint_dir=args.checkpoint_dir,
            print_freq=1,
        )
        trainer.epoch = start_epoch
        trainer.iteration = start_iteration
        trainer.train()
    elif args.cmd == 'test':
        validator = Validator(
            cmd=args.cmd,
            cuda=cuda,
            model=model,
            criterion=criterion,
            val_loader=val_loader,
            log_file=log_file,
            print_freq=1,
        )
        validator.validate()
    elif args.cmd == 'extract':
        extractor = Extractor(
            cuda=cuda,
            model=model,
            val_loader=val_loader,
            log_file=log_file,
            feature_dir=args.feature_dir,
            flatten_feature=True,
            print_freq=1,
        )
        extractor.extract()
Example #17
File: search.py, Project: Frizy-up/pt.darts
def main():
    logger.info("Logger is set - training start")

    # set default gpu device id
    torch.cuda.set_device(config.gpus[0])

    # set seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    torch.backends.cudnn.benchmark = True

    # get data with meta info
    input_size, input_channels, n_classes, train_data = utils.get_data(
        config.dataset, config.data_path, cutout_length=0, validation=False)

    net_crit = nn.CrossEntropyLoss().to(device)
    model = SearchCNNController(input_channels,
                                config.init_channels,
                                n_classes,
                                config.layers,
                                net_crit,
                                device_ids=config.gpus)
    model = model.to(device)

    # weights optimizer
    w_optim = torch.optim.SGD(model.weights(),
                              config.w_lr,
                              momentum=config.w_momentum,
                              weight_decay=config.w_weight_decay)
    # alphas optimizer
    alpha_optim = torch.optim.Adam(model.alphas(),
                                   config.alpha_lr,
                                   betas=(0.5, 0.999),
                                   weight_decay=config.alpha_weight_decay)

    # split data to train/validation
    n_train = len(train_data)
    split = n_train // 2
    indices = list(range(n_train))
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        indices[:split])
    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        indices[split:])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=train_sampler,
                                               num_workers=config.workers,
                                               pin_memory=False)
    valid_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=valid_sampler,
                                               num_workers=config.workers,
                                               pin_memory=False)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, config.epochs, eta_min=config.w_lr_min)
    architect = Architect(model, config.w_momentum, config.w_weight_decay)

    # training loop
    best_top1 = -1.0
    best_epoch = 0
    ################################ restore from last time #############################################
    epoch_restore = config.epoch_restore
    if config.restore:
        utils.load_state_dict(model,
                              config.path,
                              extra='model',
                              parallel=(len(config.gpus) > 1))
        if not config.model_only:
            utils.load_state_dict(w_optim,
                                  config.path,
                                  extra='w_optim',
                                  parallel=False)
            utils.load_state_dict(alpha_optim,
                                  config.path,
                                  extra='alpha_optim',
                                  parallel=False)
            utils.load_state_dict(lr_scheduler,
                                  config.path,
                                  extra='lr_scheduler',
                                  parallel=False)
            utils.load_state_dict(epoch_restore,
                                  config.path,
                                  extra='epoch_restore',
                                  parallel=False)
    #####################################################################################################
    for epoch in range(epoch_restore, config.epochs):
        lr_scheduler.step()
        lr = lr_scheduler.get_lr()[0]

        model.print_alphas(logger)

        # training
        train(train_loader, valid_loader, model, architect, w_optim,
              alpha_optim, lr, epoch)

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        top1 = validate(valid_loader, model, epoch, cur_step)
        # top1 = 0.0

        # log
        # genotype
        genotype = model.genotype()
        logger.info("genotype = {}".format(genotype))

        # genotype as a image
        plot_path = os.path.join(config.plot_path,
                                 "EP{:02d}".format(epoch + 1))
        caption = "Epoch {}".format(epoch + 1)
        plot(genotype.normal, plot_path + "-normal", caption)
        plot(genotype.reduce, plot_path + "-reduce", caption)

        # save
        if best_top1 < top1:
            best_top1 = top1
            best_genotype = genotype
            is_best = True
            best_epoch = epoch + 1
        else:
            is_best = False
        utils.save_checkpoint(model, config.path, is_best)

        ######################################## save all state ###################################################
        utils.save_state_dict(model,
                              config.path,
                              extra='model',
                              is_best=is_best,
                              parallel=(len(config.gpus) > 1),
                              epoch=epoch + 1,
                              acc=top1,
                              last_state=((epoch + 1) >= config.epochs))
        utils.save_state_dict(lr_scheduler,
                              config.path,
                              extra='lr_scheduler',
                              is_best=is_best,
                              parallel=False,
                              epoch=epoch + 1,
                              acc=top1,
                              last_state=((epoch + 1) >= config.epochs))
        utils.save_state_dict(alpha_optim,
                              config.path,
                              extra='alpha_optim',
                              is_best=is_best,
                              parallel=False,
                              epoch=epoch + 1,
                              acc=top1,
                              last_state=((epoch + 1) >= config.epochs))
        utils.save_state_dict(w_optim,
                              config.path,
                              extra='w_optim',
                              is_best=is_best,
                              parallel=False,
                              epoch=epoch + 1,
                              acc=top1,
                              last_state=((epoch + 1) >= config.epochs))
        ############################################################################################################
        print("")
    logger.info("Best Genotype at {} epch.".format(best_epoch))
    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    logger.info("Best Genotype = {}".format(best_genotype))
Example #18
def main(args, trainloader, testloader, eval_model):
    target_net = utils.get_model(args, target_name, target_mode)
    target_net = torch.nn.DataParallel(target_net).to(device)
    utils.load_state_dict(target_net, target_path)
    target_net.eval()
    eval_model.eval()

    attack_net = utils.get_attack_model(args, attack_name)
    attack_net = torch.nn.DataParallel(attack_net).to(device)

    optimizer = torch.optim.Adam(attack_net.parameters(),
                                 lr=args[attack_name]['learning_rate'],
                                 betas=(0.5, 0.999),
                                 amsgrad=True)

    best_acc_train, best_l2_train = 0, 1e9
    best_acc_test, best_l2_test = 0, 1e9
    n_epochs = args[attack_name]['epochs']
    noise = noise_vector.to(device)

    # n_epochs = 1

    print("Start Training!")
    for e in range(n_epochs):
        tf = time.time()
        attack_net.train()
        for img, __ in trainloader:
            img = img.to(device)
            #utils.save_tensor_images(img, os.path.join(save_img_path, 'sample.png'), nrow=8)
            if target_mode.startswith('vib'):
                out_prob = target_net(img)[0].detach().cpu().numpy()
            else:
                out_prob = target_net(img).detach().cpu().numpy()
            in_prob = torch.from_numpy(out_prob).to(device)
            in_prob = (1 - w) * in_prob + w * noise
            rec_img = attack_net(in_prob)
            loss = F.mse_loss(img, rec_img)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_l2, train_acc = eval_net(attack_net, target_net, trainloader,
                                       "train", e, eval_model)
        test_l2, test_acc = eval_net(attack_net, target_net, testloader,
                                     "test", e, eval_model)

        if test_acc > best_acc_test:
            best_acc_test = test_acc
            best_model = deepcopy(attack_net)
        if test_l2 < best_l2_test:
            best_l2_test = test_l2

        if train_acc > best_acc_train:
            best_acc_train = train_acc
        if train_l2 < best_l2_train:
            best_l2_train = train_l2

        interval = time.time() - tf
        print(
            "Epoch:{}\tTime:{:.2f}\tTrain L2:{:.4f}\tTrain Acc:{:.2f}\tTest L2:{:.4f}\tTest Acc:{:.2f}"
            .format(e, interval, train_l2, train_acc, test_l2, test_acc))

    torch.save({'state_dict': best_model.state_dict()},
               os.path.join(
                   save_model_path,
                   "attack_{}_{}.tar".format(attack_name, target_mode)))

    print("The best train acc is {:.2f}".format(train_acc))
    print("The best train l2 is {:.3f}".format(train_l2))
    print("The best test acc is {:.2f}".format(test_acc))
    print("The best test l2 is {:.3f}".format(test_l2))
    print(
        "=============================================================================="
    )
Example #19
    print("The best test l2 is {:.3f}".format(test_l2))
    print(
        "=============================================================================="
    )


if __name__ == '__main__':
    file = dataset_name + ".json"
    args = utils.load_params(file)
    if w > 0:
        log_file = "attack" + '_' + target_name + '_{}_{}.txt'.format(
            target_mode, w)
    else:
        log_file = "attack" + '_' + target_name + '_{}.txt'.format(target_mode)
    logger = utils.Tee(os.path.join(save_log_path, log_file), 'w')
    utils.print_params(args)

    train_file = args['dataset']['test_file']
    test_file = args['dataset']['train_file']
    trainloader = utils.init_dataloader(args, train_file, mode="train")
    testloader = utils.init_dataloader(args, test_file, mode="test")

    eval_model = utils.get_model(args, "VGG16", "reg")
    eval_model = torch.nn.DataParallel(eval_model).to(device)
    utils.load_state_dict(eval_model, eval_path)

    save_img_path = os.path.join(
        save_img_path, "attack_{}_{}".format(target_name, target_mode))
    os.makedirs(save_img_path, exist_ok=True)
    main(args, trainloader, testloader, eval_model)
Example #20
def main():
    warnings.filterwarnings("ignore")
    # parse args
    args, opt = parser.parse_known_args()
    opt = parse_unknown_args(opt)

    # setup gpu and distributed training
    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    if not torch.distributed.is_initialized():
        dist.init()
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(dist.local_rank())

    # setup path
    os.makedirs(args.path, exist_ok=True)

    # setup random seed
    if args.resume:
        args.manual_seed = int(time.time())
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed_all(args.manual_seed)

    # load config
    exp_config = yaml.safe_load(open(args.config, "r"))
    partial_update_config(exp_config, opt)
    # save config to run directory
    yaml.dump(exp_config,
              open(os.path.join(args.path, "config.yaml"), "w"),
              sort_keys=False)

    # build data_loader
    image_size = exp_config["data_provider"]["image_size"]
    data_provider, n_classes = build_data_loader(
        exp_config["data_provider"]["dataset"],
        image_size,
        exp_config["data_provider"]["base_batch_size"],
        exp_config["data_provider"]["n_worker"],
        exp_config["data_provider"]["data_path"],
        dist.size(),
        dist.rank(),
    )

    # build model
    model = build_model(
        exp_config["model"]["name"],
        n_classes,
        exp_config["model"]["dropout_rate"],
    )
    print(model)

    # netaug
    if exp_config.get("netaug", None) is not None:
        use_netaug = True
        model = augemnt_model(model, exp_config["netaug"], n_classes,
                              exp_config["model"]["dropout_rate"])
        model.set_active(mode="min")
    else:
        use_netaug = False

    # load init
    if args.init_from is not None:
        init = load_state_dict_from_file(args.init_from)
        load_state_dict(model, init, strict=False)
        print("Loaded init from %s" % args.init_from)
    else:
        init_modules(model, init_type=exp_config["run_config"]["init_type"])
        print("Random Init")

    # profile
    profile_model = copy.deepcopy(model)
    # during inference, bn will be fused into conv
    remove_bn(profile_model)
    print(f"Params: {trainable_param_num(profile_model)}M")
    print(
        f"MACs: {inference_macs(profile_model, data_shape=(1, 3, image_size, image_size))}M"
    )

    # train
    exp_config["generator"] = torch.Generator()
    exp_config["generator"].manual_seed(args.manual_seed)
    model = nn.parallel.DistributedDataParallel(model.cuda(),
                                                device_ids=[dist.local_rank()])
    train(model, data_provider, exp_config, args.path, args.resume, use_netaug)
Example #21
    def load(self, path, loc='cpu', verbose=False):
        state_dict = load_state_dict(path, loc)
        results = self.model.load_state_dict(state_dict, strict=False)
        if verbose:
            print('Loaded from ', path)
            print(results)
Example #22
N_IDENTITY = 8631
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
transform = transforms.Compose(
    [transforms.Resize((256, 256)),
     transforms.ToTensor()])

if __name__ == '__main__':
    parser = argparse.ArgumentParser("face feature extractor")
    parser.add_argument('--weight',
                        type=str,
                        default='./models/resnet50_scratch_weight.pkl')
    parser.add_argument('--source', type=str, default='./')
    args = parser.parse_args()
    print(args)
    model = resnet50(num_classes=N_IDENTITY, include_top=False)
    utils.load_state_dict(model, args.weight)
    model = model.to(device)
    model.eval()
    image = args.source
    image = cv2.imread(image)
    image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    image = transform(image)
    image = image.unsqueeze(0)
    image = image.to(device)
    tic = time.time()
    out = model(image)
    print(time.time() - tic)
    feature = out.view(out.shape[0], -1)
    print(feature.shape)
    feature = feature[0].data.cpu().numpy()
    feature_file = './feature/feature.npy'
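
Note: the snippet ends before the feature is written; given feature_file above, the natural final step would be:

import os
import numpy as np

os.makedirs(os.path.dirname(feature_file), exist_ok=True)
np.save(feature_file, feature)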
Example #23
def main(args=None):
    if args is None:
        args = get_parameter()

    if args.dataset == 'dali' and not dali_enable:
        args.case = args.case.replace('dali', 'imagenet')
        args.dataset = 'imagenet'
        args.workers = 12

    # log_dir
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    model_arch = args.model
    model_name = model_arch
    if args.evaluate:
        log_suffix = 'eval-' + model_arch + '-' + args.case
    else:
        log_suffix = model_arch + '-' + args.case
    utils.setup_logging(os.path.join(args.log_dir, log_suffix + '.txt'),
                        resume=args.resume)

    logging.info("current folder: %r", os.getcwd())
    logging.info("alqnet plugins: %r", plugin_enable)
    logging.info("apex available: %r", apex_enable)
    logging.info("dali available: %r", dali_enable)
    for x in vars(args):
        logging.info("config %s: %r", x, getattr(args, x))

    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and len(args.device_ids) > 0:
        args.device_ids = [
            x for x in args.device_ids
            if x < torch.cuda.device_count() and x >= 0
        ]
        if len(args.device_ids) == 0:
            args.device_ids = None
        else:
            logging.info("training on %d gpu", len(args.device_ids))
    else:
        args.device_ids = None

    if args.device_ids is not None:
        torch.cuda.manual_seed_all(args.seed)
        cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True  #https://github.com/pytorch/pytorch/issues/8019
    else:
        logging.info(
            "no GPU available, falling back to the CPU version; many functions are limited")
        #return

    if model_name in models.model_zoo:
        model, args = models.get_model(args)
    else:
        logging.error("model(%s) not support, available models: %r" %
                      (model_name, models.model_zoo))
        return
    criterion = nn.CrossEntropyLoss()
    if 'label-smooth' in args.keyword:
        criterion_smooth = utils.CrossEntropyLabelSmooth(
            args.num_classes, args.label_smooth)

    # load policy for initial phase
    models.policy.deploy_on_init(model, getattr(args, 'policy', ''))
    # load policy for epoch updating
    epoch_policies = models.policy.read_policy(getattr(args, 'policy', ''),
                                               section='epoch')
    # print model
    logging.info("models: %r" % model)
    logging.info("epoch_policies: %r" % epoch_policies)

    utils.check_folder(args.weights_dir)
    args.weights_dir = os.path.join(args.weights_dir, model_name)
    utils.check_folder(args.weights_dir)
    args.resume_file = os.path.join(args.weights_dir,
                                    args.case + "-" + args.resume_file)
    args.pretrained = os.path.join(args.weights_dir, args.pretrained)
    epoch = 0
    lr = args.lr
    best_acc = 0
    scheduler = None
    checkpoint = None
    # resume training
    if args.resume:
        if utils.check_file(args.resume_file):
            logging.info("resuming from %s" % args.resume_file)
            if torch.cuda.is_available():
                checkpoint = torch.load(args.resume_file)
            else:
                checkpoint = torch.load(args.resume_file, map_location='cpu')
            if 'epoch' in checkpoint:
                epoch = checkpoint['epoch']
                logging.info("resuming ==> last epoch: %d" % epoch)
                epoch = epoch + 1
                logging.info("updating ==> epoch: %d" % epoch)
            if 'best_acc' in checkpoint:
                best_acc = checkpoint['best_acc']
                logging.info("resuming ==> best_acc: %f" % best_acc)
            if 'learning_rate' in checkpoint:
                lr = checkpoint['learning_rate']
                logging.info("resuming ==> learning_rate: %f" % lr)
            if 'state_dict' in checkpoint:
                utils.load_state_dict(model, checkpoint['state_dict'])
                logging.info("resumed from %s" % args.resume_file)
        else:
            logging.info("warning: *** resume file not exists({})".format(
                args.resume_file))
            args.resume = False
    else:
        if utils.check_file(args.pretrained):
            logging.info("load pretrained from %s" % args.pretrained)
            if torch.cuda.is_available():
                checkpoint = torch.load(args.pretrained)
            else:
                checkpoint = torch.load(args.pretrained, map_location='cpu')
            logging.info("load pretrained ==> last epoch: %d" %
                         checkpoint.get('epoch', 0))
            logging.info("load pretrained ==> last best_acc: %f" %
                         checkpoint.get('best_acc', 0))
            logging.info("load pretrained ==> last learning_rate: %f" %
                         checkpoint.get('learning_rate', 0))
            #if 'learning_rate' in checkpoint:
            #    lr = checkpoint['learning_rate']
            #    logging.info("resuming ==> learning_rate: %f" % lr)
            try:
                utils.load_state_dict(
                    model,
                    checkpoint.get('state_dict',
                                   checkpoint.get('model', checkpoint)))
            except RuntimeError as err:
                logging.info("Loading pretrained model failed %r" % err)
        else:
            logging.info(
                "no pretrained file exists({}), init model with default initializer"
                .format(args.pretrained))

    if args.device_ids is not None:
        torch.cuda.set_device(args.device_ids[0])
        if not isinstance(model, nn.DataParallel) and len(args.device_ids) > 1:
            model = nn.DataParallel(model, args.device_ids).cuda()
        else:
            model = model.cuda()
        criterion = criterion.cuda()
        if 'label-smooth' in args.keyword:
            criterion_smooth = criterion_smooth.cuda()

    if 'label-smooth' in args.keyword:
        train_criterion = criterion_smooth
    else:
        train_criterion = criterion

    # move after to_cuda() for speedup
    if args.re_init and not args.resume:
        for m in model.modules():
            if hasattr(m, 'init_after_load_pretrain'):
                m.init_after_load_pretrain()

    # dataset
    data_path = args.root
    dataset = args.dataset
    logging.info("loading dataset with batch_size {} and val-batch-size {}. "
                 "dataset: {}, resolution: {}, path: {}".format(
                     args.batch_size, args.val_batch_size, dataset,
                     args.input_size, data_path))

    if args.val_batch_size < 1:
        val_loader = None
    else:
        if args.evaluate:
            val_batch_size = (args.batch_size // 100) * 100
            if val_batch_size > 0:
                args.val_batch_size = val_batch_size
            logging.info("update val_batch_size to %d in evaluate mode" %
                         args.val_batch_size)
        val_loader = datasets.data_loader(args.dataset)('val', args)

    if args.evaluate and val_loader is not None:
        if args.fp16 and torch.backends.cudnn.enabled and apex_enable and args.device_ids is not None:
            logging.info("training with apex fp16 at opt_level {}".format(
                args.opt_level))
        else:
            args.fp16 = False
            logging.info("training without apex")

        if args.fp16:
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)  #
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.opt_level)

        logging.info("evaluate the dataset on pretrained model...")
        result = validate(val_loader, model, criterion, args)
        top1, top5, loss = result
        logging.info('evaluate accuracy on dataset: top1(%f) top5(%f)' %
                     (top1, top5))
        return

    train_loader = datasets.data_loader(args.dataset)('train', args)
    if isinstance(train_loader, torch.utils.data.dataloader.DataLoader):
        train_length = len(train_loader)
    else:
        train_length = getattr(train_loader, '_size', 0) / getattr(
            train_loader, 'batch_size', 1)

    # sample several iteration / epoch to calculate the initial value of quantization parameters
    if args.stable_epoch > 0 and args.stable <= 0:
        args.stable = train_length * args.stable_epoch
        logging.info("update stable: %d" % args.stable)

    # fix learning rate at the beginning to warmup
    if args.warmup_epoch > 0 and args.warmup <= 0:
        args.warmup = train_length * args.warmup_epoch
        logging.info("update warmup: %d" % args.warmup)

    params_dict = dict(model.named_parameters())
    params = []
    quant_wrapper = []
    for key, value in params_dict.items():
        #print(key)
        if 'quant_weight' in key and 'quant_weight' in args.custom_lr_list:
            to_be_quant = key.split('.quant_weight')[0] + '.weight'
            if to_be_quant not in quant_wrapper:
                quant_wrapper += [to_be_quant]
    if len(quant_wrapper) > 0 and args.verbose:
        logging.info("quant_wrapper: {}".format(quant_wrapper))

    for key, value in params_dict.items():
        shape = value.shape
        custom_hyper = dict()
        custom_hyper['params'] = value
        if value.requires_grad == False:
            continue

        found = False
        for i in args.custom_decay_list:
            if i in key and len(i) > 0:
                found = True
                break
        if found:
            custom_hyper['weight_decay'] = args.custom_decay
        elif (not args.decay_small and args.no_decay_small) and (
            (len(shape) == 4 and shape[1] == 1) or (len(shape) == 1)):
            custom_hyper['weight_decay'] = 0.0

        found = False
        for i in args.custom_lr_list:
            if i in key and len(i) > 0:
                found = True
                break
        if found:
            #custom_hyper.setdefault('lr_constant', args.custom_lr) # 2019.11.25
            custom_hyper['lr'] = args.custom_lr
        elif key in quant_wrapper:
            custom_hyper.setdefault('lr_constant', args.custom_lr)
            custom_hyper['lr'] = args.custom_lr

        params += [custom_hyper]

        if 'debug' in args.keyword:
            logging.info("{}, decay {}, lr {}, constant {}".format(
                key, custom_hyper.get('weight_decay', "default"),
                custom_hyper.get('lr', "default"),
                custom_hyper.get('lr_constant', "No")))

    optimizer = None
    if args.optimizer == "ADAM":
        optimizer = torch.optim.Adam(params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)

    if args.optimizer == "SGD":
        optimizer = torch.optim.SGD(params,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

    if args.resume and checkpoint is not None:
        try:
            optimizer.load_state_dict(checkpoint['optimizer'])
        except RuntimeError as error:
            logging.info("Restore optimizer state failed %r" % error)

    if args.fp16 and torch.backends.cudnn.enabled and apex_enable and args.device_ids is not None:
        logging.info("training with apex fp16 at opt_level {}".format(
            args.opt_level))
    else:
        args.fp16 = False
        logging.info("training without apex")

    if args.sync_bn:
        logging.info("sync_bn to be supported, currently not yet")

    if args.fp16:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level)
        if args.resume and checkpoint is not None:
            try:
                amp.load_state_dict(checkpoint['amp'])
            except RuntimeError as error:
                logging.info("Restore amp state failed %r" % error)

    # start tensorboard as late as possible
    if args.tensorboard and not args.evaluate:
        tb_log = os.path.join(args.log_dir, log_suffix)
        args.tensorboard = SummaryWriter(tb_log,
                                         filename_suffix='.' + log_suffix)
    else:
        args.tensorboard = None

    logging.info("start to train network " + model_name + ' with case ' +
                 args.case)
    while epoch < (args.epochs + args.extra_epoch):
        if 'proxquant' in args.keyword:
            if args.proxquant_step < 10:
                if args.lr_policy in ['sgdr', 'sgdr_step', 'custom_step']:
                    index = len([x for x in args.lr_custom_step if x <= epoch])
                    for m in model.modules():
                        if hasattr(m, 'prox'):
                            m.prox = 1.0 - 1.0 / args.proxquant_step * (index +
                                                                        1)
            else:
                for m in model.modules():
                    if hasattr(m, 'prox'):
                        m.prox = 1.0 - 1.0 / args.proxquant_step * epoch
                        if m.prox < 0:
                            m.prox = 0
        if epoch < args.epochs:
            lr, scheduler = utils.setting_learning_rate(
                optimizer, epoch, train_length, checkpoint, args, scheduler)
        if lr is None:
            logging.info('lr is invalid at epoch %d' % epoch)
            return
        else:
            logging.info('[epoch %d]: lr %e', epoch, lr)

        loss = 0
        top1, top5, eloss = 0, 0, 0
        is_best = top1 > best_acc
        # leverage policies on epoch
        models.policy.deploy_on_epoch(model,
                                      epoch_policies,
                                      epoch,
                                      optimizer=optimizer,
                                      verbose=logging.info)

        if 'lr-test' not in args.keyword:  # otherwise only print the learning rate in each epoch
            # training
            loss = train(train_loader, model, train_criterion, optimizer, args,
                         scheduler, epoch, lr)
            #for i in range(train_length):
            #  scheduler.step()
            logging.info('[epoch %d]: train_loss %.3f' % (epoch, loss))

            # validate
            top1, top5, eloss = 0, 0, 0
            top1, top5, eloss = validate(val_loader, model, criterion, args)
            is_best = top1 > best_acc
            if is_best:
                best_acc = top1
            logging.info('[epoch %d]: test_acc %f %f, best top1: %f, loss: %f',
                         epoch, top1, top5, best_acc, eloss)

        if args.tensorboard is not None:
            args.tensorboard.add_scalar(log_suffix + '/train-loss', loss,
                                        epoch)
            args.tensorboard.add_scalar(log_suffix + '/eval-top1', top1, epoch)
            args.tensorboard.add_scalar(log_suffix + '/eval-top5', top5, epoch)
            args.tensorboard.add_scalar(log_suffix + '/lr', lr, epoch)

        utils.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler':
                None if scheduler is None else scheduler.state_dict(),
                'best_acc': best_acc,
                'learning_rate': lr,
                'amp': None if not args.fp16 else amp.state_dict(),
            }, is_best, args)

        epoch = epoch + 1
        if epoch == 1:
            logging.info(utils.gpu_info())
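
The loop above checkpoints the apex `amp` state alongside the model and optimizer so that mixed-precision training can resume with the same loss-scale bookkeeping. A minimal sketch of that save/restore pattern, assuming NVIDIA apex is installed and `amp.initialize` has already wrapped the model and optimizer:

# Save/restore sketch for apex mixed-precision state (assumes
# `from apex import amp` and a prior amp.initialize(model, optimizer)).
checkpoint = {
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict(),  # loss-scale bookkeeping
}
torch.save(checkpoint, 'ckpt.pth')

# ... later, after amp.initialize has run again on fresh model/optimizer:
checkpoint = torch.load('ckpt.pth', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])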
Example #24
def main(cfg):
    global best_loss
    best_loss = 100.

    # make dirs
    for dirs in [cfg["MODELS_DIR"], cfg["OUTPUT_DIR"], cfg["LOGS_DIR"]]:
        if not os.path.exists(dirs):
            os.makedirs(dirs)

    # create dataset
    train_ds = RSNAHemorrhageDS3d(cfg, mode="train")
    valid_ds = RSNAHemorrhageDS3d(cfg, mode="valid")
    test_ds = RSNAHemorrhageDS3d(cfg, mode="test")

    # create model
    extra_model_args = {
        "attention": cfg["ATTENTION"],
        "dropout": cfg["DROPOUT"],
        "num_layers": cfg["NUM_LAYERS"],
        "recur_type": cfg["RECUR_TYPE"],
        "num_heads": cfg["NUM_HEADS"],
        "dim_ffw": cfg["DIM_FFW"]
    }
    if cfg["MODEL_NAME"].startswith("tf_efficient"):
        model = GenericEfficientNet3d(cfg["MODEL_NAME"],
                                      input_channels=cfg["NUM_INP_CHAN"],
                                      num_classes=cfg["NUM_CLASSES"],
                                      **extra_model_args)
    elif "res" in cfg["MODEL_NAME"]:
        model = ResNet3d(cfg["MODEL_NAME"],
                         input_channels=cfg["NUM_INP_CHAN"],
                         num_classes=cfg["NUM_CLASSES"],
                         **extra_model_args)
    # print(model)

    # define loss function & optimizer
    class_weight = torch.FloatTensor(cfg["BCE_W"])
    criterion = nn.BCEWithLogitsLoss(weight=class_weight, reduction='none')
    kd_criterion = KnowledgeDistillationLoss(temperature=cfg["TAU"])
    valid_criterion = nn.BCEWithLogitsLoss(weight=class_weight,
                                           reduction='none')
    optimizer = make_optimizer(cfg, model)

    if cfg["CUDA"]:
        model = model.cuda()
        criterion = criterion.cuda()
        kd_criterion.cuda()
        valid_criterion = valid_criterion.cuda()

    if args.dtype == 'float16':
        if args.opt_level == "O1":
            keep_batchnorm_fp32 = None
        else:
            keep_batchnorm_fp32 = True
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level=args.opt_level,
            keep_batchnorm_fp32=keep_batchnorm_fp32)

    start_epoch = 0
    # optionally resume from a checkpoint
    if cfg["RESUME"]:
        if os.path.isfile(cfg["RESUME"]):
            logger.info("=> Loading checkpoint '{}'".format(cfg["RESUME"]))
            checkpoint = torch.load(cfg["RESUME"], "cpu")
            load_state_dict(checkpoint.pop('state_dict'), model)
            if not args.finetune:
                start_epoch = checkpoint['epoch']
                optimizer.load_state_dict(checkpoint.pop('optimizer'))
                best_loss = checkpoint['best_loss']
            logger.info("=> Loaded checkpoint '{}' (epoch {})".format(
                cfg["RESUME"], checkpoint['epoch']))
        else:
            logger.info("=> No checkpoint found at '{}'".format(cfg["RESUME"]))

    if cfg["MULTI_GPU"]:
        model = nn.DataParallel(model)

    # create data loaders & lr scheduler
    train_loader = DataLoader(train_ds,
                              cfg["BATCH_SIZE"],
                              pin_memory=False,
                              shuffle=True,
                              drop_last=False,
                              num_workers=cfg['NUM_WORKERS'])
    valid_loader = DataLoader(valid_ds,
                              pin_memory=False,
                              shuffle=False,
                              drop_last=False,
                              num_workers=cfg['NUM_WORKERS'])
    test_loader = DataLoader(test_ds,
                             pin_memory=False,
                             collate_fn=test_collate_fn,
                             shuffle=False,
                             drop_last=False,
                             num_workers=cfg['NUM_WORKERS'])
    scheduler = WarmupCyclicalLR("cos",
                                 cfg["BASE_LR"],
                                 cfg["EPOCHS"],
                                 iters_per_epoch=len(train_loader),
                                 warmup_epochs=cfg["WARMUP_EPOCHS"])
    logger.info("Using {} lr scheduler\n".format(scheduler.mode))

    if args.eval:
        _, prob = validate(cfg, valid_loader, model, valid_criterion)
        imgids = pd.read_csv(cfg["DATA_DIR"] + "valid_{}_df_fold{}.csv" \
            .format(cfg["SPLIT"], cfg["FOLD"]))["image"]
        save_df = pd.concat([imgids, pd.DataFrame(prob.numpy())], axis=1)
        save_df.columns = [
            "image", "any", "intraparenchymal", "intraventricular",
            "subarachnoid", "subdural", "epidural"
        ]
        save_df.to_csv(os.path.join(cfg["OUTPUT_DIR"],
                                    "val_" + cfg["SESS_NAME"] + '.csv'),
                       index=False)
        return

    if args.eval_test:
        if not os.path.exists(cfg["OUTPUT_DIR"]):
            os.makedirs(cfg["OUTPUT_DIR"])
        submit_fpath = os.path.join(cfg["OUTPUT_DIR"],
                                    "test_" + cfg["SESS_NAME"] + '.csv')
        submit_df = test(cfg, test_loader, model)
        submit_df.to_csv(submit_fpath, index=False)
        return

    for epoch in range(start_epoch, cfg["EPOCHS"]):
        logger.info("Epoch {}\n".format(str(epoch + 1)))
        random.seed(epoch)
        torch.manual_seed(epoch)
        # train for one epoch
        train(cfg, train_loader, model, criterion, kd_criterion, optimizer,
              scheduler, epoch)
        # evaluate
        loss, _ = validate(cfg, valid_loader, model, valid_criterion)
        # remember best loss and save checkpoint
        is_best = loss < best_loss
        best_loss = min(loss, best_loss)
        if cfg["MULTI_GPU"]:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': cfg["MODEL_NAME"],
                    'state_dict': model.module.state_dict(),
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                root=cfg['MODELS_DIR'],
                filename=f"{cfg['SESS_NAME']}.pth")
        else:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': cfg["MODEL_NAME"],
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                root=cfg['MODELS_DIR'],
                filename=f"{cfg['SESS_NAME']}.pth")
Example #25
        train_attack(attack_net, train_data, train_label, test_data,
                     test_label, optimizer, criterion)
        attack_acc += eval_attack(attack_net, test_data, test_label)
    attack_acc /= 10
    print("Attack acc:{:.2f}".format(attack_acc))
    return attack_acc


if __name__ == "__main__":
    if model_v.startswith("vib"):
        shadow_net = models.SimpleCNN_VIB(n_classes)
    else:
        shadow_net = models.SimpleCNN(n_classes)
    shadow_net = nn.DataParallel(shadow_net).cuda()
    shadow_ckp = torch.load(shadow_path)['state_dict']
    load_state_dict(shadow_net, shadow_ckp)

    if model_v.startswith("vib"):
        target_net = models.SimpleCNN_VIB(n_classes)
    else:
        target_net = models.SimpleCNN(n_classes)
    target_net = nn.DataParallel(target_net).cuda()
    target_ckp = torch.load(target_path)['state_dict']
    load_state_dict(target_net, target_ckp)
    criterion = nn.CrossEntropyLoss().cuda()

    __, __, D_s_u_loader, D_u_loader, probe_imgs, probe_labels, __ = load_data(
        batch_size=batch_size)
    probe_imgs, probe_labels = probe_imgs.to(device), probe_labels.to(device)

    train_in, train_label = get_dataset(shadow_net, D_s_u_loader, probe_imgs,
Example #26
def main():
    args = get_parameter()
    args.weights_dir = os.path.join(args.weights_dir, args.model)
    utils.check_folder(args.weights_dir)

    if os.path.exists(args.log_dir):
        utils.setup_logging(os.path.join(args.log_dir, 'tools.txt'), resume=True)

    config = dict()
    for i in args.keyword:
        config[i] = True

    if 'export_onnx' in config.keys():
        export_onnx(args)

    if 'inference' in config.keys():
        inference(args)

    if 'verbose' in config.keys():
        if torch.cuda.is_available():
            checkpoint = torch.load(args.old)
        else:  # force cpu mode
            checkpoint = torch.load(args.old, map_location='cpu')
        if 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']
        if 'model' in checkpoint:
            checkpoint = checkpoint['model']
        for name, value in checkpoint.items():
            if ('quant_activation' in name or 'quant_weight' in name) and name.split('.')[-1] in args.verbose_list:
                print(name, value.shape, value.requires_grad)
                print(value.data)
            elif "all" in args.verbose_list:
                if 'num_batches_tracked' not in name:
                    if isinstance(value, torch.Tensor):
                        print(name, value.shape, value.requires_grad)
                    elif isinstance(value, int) or isinstance(value, float) or isinstance(value, str):
                        print(name, value, type(value))
                    else:
                        print(name, type(value))

    if 'load' in config.keys() or 'save' in config.keys():
        model_name = args.model
        if model_name in models.model_zoo:
            model, args = models.get_model(args)
        else:
            print("model(%s) not support, available models: %r" % (model_name, models.model_zoo))
            return
        if utils.check_file(args.old):
            raw = 'raw' in config.keys()
            if torch.cuda.is_available():
                checkpoint = torch.load(args.old)
            else:  # force cpu mode
                checkpoint = torch.load(args.old, map_location='cpu')
            try:
                utils.load_state_dict(model, checkpoint.get('state_dict', None) if not raw else checkpoint, verbose=False)
                print("Loading pretrained model OK")
            except RuntimeError:
                print("Loading pretrained model failed")

            if 'save' in config.keys() and args.new != '':
                torch.save(model.state_dict(), args.new)
                print("Save pretrained model into %s" % args.new)
        else:
            print("file not exist %s" % args.old)

    if 'update' in config.keys():
        mapping_from = []
        mapping_to = []
        if os.path.isfile(args.mapping_from):
            with open(args.mapping_from) as f:
                mapping_from = f.readlines()
        if os.path.isfile(args.mapping_to):
            with open(args.mapping_to) as f:
                mapping_to = f.readlines()
        mapping_from = [i.strip().strip('"').strip("'") for i in mapping_from]
        mapping_from = [i for i in mapping_from if len(i) > 0 and i[0] != '#']
        mapping_to = [i.strip().strip('"').strip("'") for i in mapping_to]
        mapping_to = [i for i in mapping_to if len(i) > 0 and i[0] != '#']
        if len(mapping_to) != len(mapping_from) or len(mapping_to) == 0 or len(mapping_from) == 0:
            mapping = None
            logging.info('no valid mapping')
        else:
            mapping = dict()
            for i, k in enumerate(mapping_from):
                if '{' in k and '}' in k and '{' in mapping_to[i] and '}' in mapping_to[i]:
                    item = k.split('{')
                    for v in item[1].strip('}').split(","):
                        v = v.strip()
                        mapping[item[0] + v] = mapping_to[i].split('{')[0] + v
                else:
                    mapping[k] = mapping_to[i] 

        raw = 'raw' in config.keys()
        if not os.path.isfile(args.old):
            args.old = args.pretrained
        utils.import_state_dict(args.old, args.new, mapping, raw, raw_prefix=args.case)

    if 'det-load' in  config.keys():
        from third_party.checkpoint import DetectionCheckpointer
        model_name = args.model
        if model_name in models.model_zoo:
            model, args = models.get_model(args)
        else:
            print("model(%s) not support, available models: %r" % (model_name, models.model_zoo))
            return
        split = os.path.split(args.old)
        checkpointer = DetectionCheckpointer(model, split[0], save_to_disk=True)
        checkpointer.resume_or_load(args.old, resume=True)
        checkpointer.save(split[1])

    if 'swap' in config.keys():
        mapping_from = []
        if os.path.isfile(args.mapping_from):
            with open(args.mapping_from) as f:
                mapping_from = f.readlines()
            mapping_from = [i.strip().strip('"').strip("'") for i in mapping_from]
            mapping_from = [i for i in mapping_from if len(i) > 0 and i[0] != '#']
            # default to the identity permutation so mapping_to is always defined
            mapping_to = mapping_from.copy()
            for i in args.verbose_list:
                item = i.split('/')
                interval = int(item[0])
                index = [int(x) for x in item[1].split('-')]
                if len(mapping_from) % interval == 0 and len(index) <= interval:
                    mapping_to = mapping_from.copy()
                    for j, k in enumerate(index):
                        k = k % interval
                        mapping_to[j::interval] = mapping_from[k::interval]

            mapping_to = [i + '\n' for i in mapping_to]
            with open(args.mapping_from + "-swap", 'w') as f:
                f.writelines(mapping_to)

    if 'sort' in config.keys():
        if os.path.isfile(args.mapping_from):
            with open(args.mapping_from) as f:
                mapping_from = f.readlines()
            mapping_from.sort()
            with open(args.mapping_from + "-sort", 'w') as f:
                f.writelines(mapping_from)

    if 'verify-data' in config.keys() or 'verify-image' in config.keys():
        if 'verify-image' in config.keys():
            lists = args.verbose_list
        else:
            with open(os.path.join(args.root, 'train.txt')) as f:
                lists = f.readlines()
        from PIL import Image
        from threading import Thread
        print("going to check %d files" % len(lists))
        def check(lists, start, end, index):
            for i, item in enumerate(lists[start:end]):
                try:
                    items = item.split()
                    if len(items) >= 1:
                        path = items[0].strip()
                    else:
                        print("skip line %s" % i)
                        continue
                    path = os.path.join(args.root, "train", path)
                    # opening and resizing is enough to verify the image decodes
                    imgs = Image.open(path)
                    imgs.resize((256, 256))
                    if index == 0:
                        print(i, end="\r", file=sys.stderr)
                except (RuntimeError, IOError):
                    print("\nError when reading image %s" % path)
            print("\nFinished checking", index)
        num = min(len(lists), 20)
        for i in range(num):
            start = len(lists) // num * i
            end = min(start + len(lists) // num, len(lists))
            th = Thread(target=check, args=(lists, start, end, i))
            th.start()
Example #27
File: sa_resnet.py  Project: zlwdghh/SA-Net
def _sanet(arch, block, layers, pretrained, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict(model_urls[arch])
        model.load_state_dict(state_dict, strict=False)
    return model
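
Here `load_state_dict` is given a URL rather than a local file, so it presumably downloads the weights before returning them. Stock PyTorch ships an equivalent helper, `torch.hub.load_state_dict_from_url`; a minimal sketch under that assumption:

# Sketch with PyTorch's built-in download helper (assumes model_urls[arch]
# points at a .pth file); downloads are cached under ~/.cache/torch.
from torch.hub import load_state_dict_from_url

state_dict = load_state_dict_from_url(model_urls[arch], progress=True)
model.load_state_dict(state_dict, strict=False)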
Example #28
    return sent

if __name__ == '__main__':
    args = parse_args()
    args.n_grids = args.grid_size**2
    # args.gpu = torch.cuda.current_device()

    from time import time

    start = time()

    # 1) Load X-LXMERT
    model = ImggenModel.from_pretrained("bert-base-uncased")

    ckpt_path = Path(__file__).resolve().parents[2].joinpath('snap/pretrained/x_lxmert/Epoch20_LXRT.pth')
    state_dict = load_state_dict(ckpt_path, 'cpu')

    results = model.load_state_dict(state_dict, strict=False)
    print(results)
    print(f'Loaded X-LXMERT | {time() - start:.2f}s')

    # 2) Load Visual embedding
    clustering_dir = args.datasets_dir.joinpath('clustering')
    centroid_path = clustering_dir.joinpath(f'{args.encoder}_{args.cluster_src}_centroids{args.n_centroids}_iter{args.n_iter}_d{args.feat_dim}_grid{args.grid_size}.npy')
    centroids = np.load(centroid_path)
    model.set_visual_embedding(centroids)

    # 3) Load Generator
    code_dim = 256
    SN = True
    base_dim = 32
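
The `results` object printed after the `strict=False` load above is, in current PyTorch, a named tuple listing the parameter names that did not match; checking it explicitly confirms that only the expected keys were skipped. A minimal sketch:

# load_state_dict(strict=False) returns (missing_keys, unexpected_keys)
# as a NamedTuple; neither case raises, so inspect them explicitly.
results = model.load_state_dict(state_dict, strict=False)
if results.missing_keys:
    print('not initialized from checkpoint:', results.missing_keys)
if results.unexpected_keys:
    print('in checkpoint but unused by the model:', results.unexpected_keys)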
Example #29
    return output


filepath = opt.test_hr_folder
if filepath.split('/')[-2] in ('Set5', 'Set14'):
    ext = '.bmp'
else:
    ext = '.png'

filelist = utils.get_list(filepath, ext=ext)
psnr_list = np.zeros(len(filelist))
ssim_list = np.zeros(len(filelist))
time_list = np.zeros(len(filelist))

model = architecture.IMDN_AS()
model_dict = utils.load_state_dict(opt.checkpoint)
model.load_state_dict(model_dict, strict=True)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
i = 0
for imname in filelist:
    im_gt = sio.imread(imname)
    im_l = sio.imread(opt.test_lr_folder + imname.split('/')[-1])
    if len(im_gt.shape) < 3:
        im_gt = im_gt[..., np.newaxis]
        im_gt = np.concatenate([im_gt] * 3, 2)
        im_l = im_l[..., np.newaxis]
        im_l = np.concatenate([im_l] * 3, 2)
    im_input = im_l / 255.0
    im_input = np.transpose(im_input, (2, 0, 1))
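
The example breaks off after rearranging the image to CHW layout; the CUDA events created earlier suggest the forward pass is timed. A plausible continuation under that assumption (hypothetical, since the original tail is truncated and the model is presumed to be on GPU):

# Hypothetical continuation: CHW numpy array -> NCHW float tensor, with the
# forward pass timed via the CUDA events created above.
im_input = torch.from_numpy(im_input).float().unsqueeze(0).cuda()
with torch.no_grad():
    start.record()
    out = model(im_input)
    end.record()
    torch.cuda.synchronize()
    time_list[i] = start.elapsed_time(end)  # milliseconds
i += 1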
Example #30
def test_predict():
    VOC_CLASSES = [
        'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
        'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
        'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
        'tvmonitor'
    ]

    img = cv2.imread('./dummyImgs/4.jpg', 1)
    img = cv2.resize(img, (300, 300))
    img_float = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_float = img_float.astype('float32')
    # img_float /= 255.0
    img_float -= np.array([123, 117, 104])
    img_tensor = torch.from_numpy(img_float).permute(2, 0, 1).unsqueeze(0)

    # Step 1: build the model
    ssd_model = build_vgg_ssd(cfg)
    ssd_model.init()
    state_dict = load_state_dict('./weights/zhuque_best.pth')
    ssd_model.load_state_dict(state_dict)
    ssd_model = ssd_model.eval()
    prior_bboxes = generator_prior_bboxes(cfg)
    with torch.no_grad():
        pred_conf, pred_loc = ssd_model(img_tensor)
        pred_loc = convert_offset_to_center(pred_loc, prior_bboxes)
        pred_loc = convert_center_to_corner(pred_loc)

        # print(pred_loc.shape, pred_conf.shape)
        nms_bboxes, nms_probs, nms_labels = nms1(pred_loc[0],
                                                 pred_conf[0],
                                                 prob_threshold=0.1,
                                                 iou_threshold=0.5,
                                                 topN=200)

    nms_bboxes_np = nms_bboxes.data.numpy()
    nms_labels_np = nms_labels.data.numpy()
    nms_probs_np = nms_probs.data.numpy()
    num = nms_bboxes_np.shape[0]
    for i in range(num):
        bbox = nms_bboxes_np[i]
        label = nms_labels_np[i]
        prob = nms_probs_np[i]
        print(">>>")
        print("    predict category: {}, pro: {:.4f}".format(
            VOC_CLASSES[label], prob))
        print("    predict bbox: {}".format(bbox))
        xmin, ymin, xmax, ymax = [math.floor(b) for b in bbox]

        color = (random.randint(0, 255), random.randint(0, 255),
                 random.randint(0, 255))
        cv2.rectangle(img, (xmin, ymin), (xmax, ymax),
                      color=color,
                      thickness=2)
        cv2.putText(img,
                    text=VOC_CLASSES[label],
                    org=(xmin, ymin + 10),
                    fontFace=cv2.FONT_HERSHEY_PLAIN,
                    fontScale=1,
                    thickness=2,
                    color=color)
    cv2.imshow('test', img)
    cv2.waitKey(0)