def set_model(opt):
    """Build the encoder, linear classifier and loss, then load encoder weights.

    Args:
        opt: options object. Reads ``opt.model`` (architecture name passed to
            ``SupConResNet`` and ``LinearClassifier``), ``opt.n_cls`` (number
            of classes for the classifier) and ``opt.ckpt`` (path of the
            checkpoint file, which must contain a ``'model'`` entry holding
            the state dict).

    Returns:
        Tuple ``(model, classifier, criterion)``. All three are moved to the
        GPU when CUDA is available; otherwise they are left on the CPU. In
        both cases ``model`` has the checkpoint weights loaded.
    """
    model = SupConResNet(name=opt.model)
    criterion = torch.nn.CrossEntropyLoss()

    classifier = LinearClassifier(name=opt.model, num_classes=opt.n_cls)

    # NOTE: torch.load is pickle-based; only load checkpoints from a trusted
    # source. map_location='cpu' keeps loading independent of the device the
    # checkpoint was saved from.
    ckpt = torch.load(opt.ckpt, map_location='cpu')
    state_dict = ckpt['model']

    use_cuda = torch.cuda.is_available()
    if use_cuda and torch.cuda.device_count() > 1:
        # Wrapped encoder: parameter names gain a "module." component, so the
        # checkpoint keys are used as they are.
        model.encoder = torch.nn.DataParallel(model.encoder)
    else:
        # Unwrapped encoder (single GPU or CPU): drop the "module." component
        # so the checkpoint keys match the plain model.
        state_dict = {
            key.replace("module.", ""): value
            for key, value in state_dict.items()
        }

    if use_cuda:
        model = model.cuda()
        classifier = classifier.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    # Load the weights regardless of device; previously this only happened
    # on the CUDA path, so CPU runs silently kept random weights.
    model.load_state_dict(state_dict)

    return model, classifier, criterion
# Example #2
# 0
def set_model(opt):
    """Create the SupCon network together with its contrastive loss.

    Reads ``opt.model`` (architecture name), ``opt.temp`` (temperature passed
    to ``SupConLoss``) and ``opt.syncBN`` (whether to convert the model with
    apex's synchronized Batch Normalization).

    Returns:
        Tuple ``(model, criterion)``; both are moved to the GPU when CUDA is
        available.
    """
    net = SupConResNet(name=opt.model)
    loss_fn = SupConLoss(temperature=opt.temp)

    # optionally switch to synchronized Batch Normalization
    if opt.syncBN:
        net = apex.parallel.convert_syncbn_model(net)

    # Without CUDA there is nothing to move or wrap.
    if not torch.cuda.is_available():
        return net, loss_fn

    # Only the encoder is wrapped for multi-GPU execution.
    if torch.cuda.device_count() > 1:
        net.encoder = torch.nn.DataParallel(net.encoder)
    net, loss_fn = net.cuda(), loss_fn.cuda()
    cudnn.benchmark = True

    return net, loss_fn
# Example #3
# 0
def set_model(opt):
    """Create the SupCon network, a linear classifier and both losses.

    Reads ``opt.model`` (architecture name), ``opt.n_cls`` (number of classes
    for the classifier), ``opt.temp`` (temperature passed to ``SupConLoss``)
    and ``opt.syncBN`` (whether to convert the model with apex's synchronized
    Batch Normalization).

    Returns:
        Tuple ``(model, classifier, criterions)`` where ``criterions`` is a
        dict keyed by ``'SupConLoss'`` and ``'CrossEntropyLoss'``. Everything
        is moved to the GPU when CUDA is available.
    """
    net = SupConResNet(name=opt.model)
    linear_head = LinearClassifier(name=opt.model, num_classes=opt.n_cls)

    losses = {
        'SupConLoss': SupConLoss(temperature=opt.temp),
        'CrossEntropyLoss': torch.nn.CrossEntropyLoss(),
    }

    # optionally switch to synchronized Batch Normalization
    if opt.syncBN:
        net = apex.parallel.convert_syncbn_model(net)

    # Without CUDA there is nothing to move or wrap.
    if not torch.cuda.is_available():
        return net, linear_head, losses

    # Only the encoder is wrapped for multi-GPU execution.
    if torch.cuda.device_count() > 1:
        net.encoder = torch.nn.DataParallel(net.encoder)
    net = net.cuda()
    linear_head = linear_head.cuda()
    losses = {key: loss.cuda() for key, loss in losses.items()}
    cudnn.benchmark = True

    return net, linear_head, losses
# Example #4
# 0
def set_models(opt):
    """Create the online/target encoders, the online predictor and the loss.

    Reads ``opt.model`` (architecture name) and ``opt.syncBN`` (whether to
    convert the online encoder with apex's synchronized Batch Normalization).

    Returns:
        Tuple ``(models, criterion)`` where ``models`` is a dict with keys
        ``'online_encoder'``, ``'target_encoder'`` and ``'online_predictor'``.
        The target encoder is a deep copy of the fully configured online
        encoder.
    """
    encoder = SupConResNet(name=opt.model)
    predictor = MLP()
    criterion = BYOLLoss()

    # optionally switch to synchronized Batch Normalization
    if opt.syncBN:
        encoder = apex.parallel.convert_syncbn_model(encoder)

    use_cuda = torch.cuda.is_available()
    # Only the inner encoder is wrapped for multi-GPU execution.
    if use_cuda and torch.cuda.device_count() > 1:
        encoder.encoder = torch.nn.DataParallel(encoder.encoder)
    if use_cuda:
        encoder, predictor, criterion = (
            encoder.cuda(),
            predictor.cuda(),
            criterion.cuda(),
        )
        cudnn.benchmark = True

    # The copy is taken after device placement and wrapping, so the target
    # starts out identical to the online encoder.
    models = {
        'online_encoder': encoder,
        'target_encoder': copy.deepcopy(encoder),
        'online_predictor': predictor,
    }

    return models, criterion