示例#1
0
def ResNet34(num_classes, pretrained=None):
    """
    Build a ResNet-34 classifier, optionally initialised from a checkpoint.

    Args:
        num_classes (int): the number of output classes, forwarded to
            ``ResNet_CLS``.
        pretrained (str | None): path to a checkpoint file readable by
            ``torch.load``. When ``None`` (the default) the freshly
            constructed weights are kept.

    Returns:
        The ``ResNet_CLS(num_classes, depth=34)`` model. When ``pretrained``
        is given, only the entries returned as matched by ``validate_ckpt``
        are copied in; every other entry keeps its constructed value, and
        each rejected key is reported on stdout.
    """
    model = ResNet_CLS(num_classes, depth=34)
    if pretrained is None:
        return model

    model_dict = model.state_dict()
    # Map to CPU so checkpoints saved from a GPU also load on CPU-only hosts;
    # load_state_dict copies the values into the model's existing tensors.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    pretrained_dict = torch.load(pretrained, map_location='cpu')
    # validate_ckpt returns (mapping of entries to load, sequence of rejected keys).
    matched_dict, unmatched = validate_ckpt(model_dict, pretrained_dict)
    if unmatched:
        print('unmatched key ({}): '.format(len(unmatched)))
        for unmatched_key in unmatched:
            if unmatched_key not in model_dict:
                print(f'key \'{unmatched_key}\' is not in model_dict')
            elif unmatched_key not in pretrained_dict:
                # Guard: indexing pretrained_dict here would raise KeyError.
                print(f'key \'{unmatched_key}\' is not in pretrained_dict')
            else:
                print(
                    f'model_dict[{unmatched_key}].shape={model_dict[unmatched_key].shape} vs '
                    f'pretrained[{unmatched_key}].shape={pretrained_dict[unmatched_key].shape}'
                )
    # Start from the model's own state so keys not in matched_dict keep their
    # constructed values and load_state_dict sees a complete key set.
    model_dict.update(matched_dict)
    model.load_state_dict(model_dict)
    return model
示例#2
0
def ResNeXt101_32x8d(num_classes, pretrained=None):
    """
    Build a ResNeXt-101 (32x8d) classifier, optionally initialised from a checkpoint.

    Args:
        num_classes (int): the number of output classes, forwarded to
            ``ResNeXt_CLS``.
        pretrained (str | None): path to a checkpoint file readable by
            ``torch.load``. When ``None`` (the default) the freshly
            constructed weights are kept.

    Returns:
        The ``ResNeXt_CLS(num_classes, depth=101, groups=32, width_per_group=8)``
        model. When ``pretrained`` is given, only the entries returned as
        matched by ``validate_ckpt`` are copied in; every other entry keeps its
        constructed value, and each rejected key is reported on stdout.
    """
    model = ResNeXt_CLS(num_classes, depth=101, groups=32, width_per_group=8)
    if pretrained is None:
        return model

    model_dict = model.state_dict()
    # Map to CPU so checkpoints saved from a GPU also load on CPU-only hosts;
    # load_state_dict copies the values into the model's existing tensors.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    pretrained_dict = torch.load(pretrained, map_location='cpu')
    # validate_ckpt returns (mapping of entries to load, sequence of rejected keys).
    matched_dict, unmatched = validate_ckpt(model_dict, pretrained_dict)
    if unmatched:
        print('unmatched key ({}): '.format(len(unmatched)))
        for unmatched_key in unmatched:
            if unmatched_key not in model_dict:
                print(f'key \'{unmatched_key}\' is not in model_dict')
            elif unmatched_key not in pretrained_dict:
                # Guard: indexing pretrained_dict here would raise KeyError.
                print(f'key \'{unmatched_key}\' is not in pretrained_dict')
            else:
                print(
                    f'model_dict[{unmatched_key}].shape={model_dict[unmatched_key].shape} vs '
                    f'pretrained[{unmatched_key}].shape={pretrained_dict[unmatched_key].shape}'
                )
    # Start from the model's own state so keys not in matched_dict keep their
    # constructed values and load_state_dict sees a complete key set.
    model_dict.update(matched_dict)
    model.load_state_dict(model_dict)
    return model