def _report_unmatched(unmatched, model_dict, pretrained_dict):
    """Print a diagnostic line for every checkpoint key that was not loaded.

    Args:
        unmatched: iterable of keys rejected by ``validate_ckpt``.
        model_dict (dict): the model's current ``state_dict()``.
        pretrained_dict (dict): the object returned by ``torch.load``.
    """
    print('unmatched key ({}): '.format(len(unmatched)))
    for key in unmatched:
        in_model = key in model_dict
        in_ckpt = key in pretrained_dict
        if in_model and in_ckpt:
            # Present on both sides, so presumably rejected for a shape mismatch.
            print(
                f'model_dict[{key}].shape={model_dict[key].shape} vs '
                f'pretrained[{key}].shape={pretrained_dict[key].shape}'
            )
        elif in_model:
            # Guard: indexing pretrained_dict here would raise KeyError and abort
            # the whole load just to print a diagnostic.
            print(f'key \'{key}\' is not in pretrained checkpoint')
        else:
            print(f'key \'{key}\' is not in model_dict')


def _load_pretrained(model, pretrained):
    """Copy compatible weights from a checkpoint file into ``model`` in place.

    The checkpoint is reconciled against ``model.state_dict()`` by
    ``validate_ckpt``. Only the entries it returns as matched overwrite the
    model's values; every other parameter keeps the value it had before the
    call. Keys that could not be matched are printed.

    Args:
        model: object exposing ``state_dict()`` / ``load_state_dict()``
            (e.g. an ``nn.Module``).
        pretrained (str): checkpoint path (or file-like object) accepted by
            ``torch.load``.

    Returns:
        The same ``model`` instance.
    """
    model_dict = model.state_dict()
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
    # load_state_dict copies into the model's existing tensors, so the model's
    # own device placement is unaffected.
    # NOTE(security): torch.load unpickles its input -- only load trusted files.
    pretrained_dict = torch.load(pretrained, map_location='cpu')
    # Assumed contract (from usage): returns (mapping of loadable entries,
    # sequence of rejected keys) -- confirm against validate_ckpt.
    matched_dict, unmatched = validate_ckpt(model_dict, pretrained_dict)
    if len(unmatched) > 0:
        _report_unmatched(unmatched, model_dict, pretrained_dict)
    model_dict.update(matched_dict)
    model.load_state_dict(model_dict)
    return model


def ResNet34(num_classes, pretrained=None):
    """ResNet-34 for classification.

    Args:
        num_classes (int): the number of output classes.
        pretrained (str | None): checkpoint path passed to ``torch.load``.
            Matching entries are loaded into the model. When None, the model
            keeps the weights ``ResNet_CLS`` initialised it with.

    Returns:
        ResNet_CLS: the constructed model.
    """
    model = ResNet_CLS(num_classes, depth=34)
    if pretrained is not None:
        _load_pretrained(model, pretrained)
    return model


def ResNeXt101_32x8d(num_classes, pretrained=None):
    """ResNeXt-101 (32 groups, width 8 per group) for classification.

    Args:
        num_classes (int): the number of output classes.
        pretrained (str | None): checkpoint path passed to ``torch.load``.
            Matching entries are loaded into the model. When None, the model
            keeps the weights ``ResNeXt_CLS`` initialised it with.

    Returns:
        ResNeXt_CLS: the constructed model.
    """
    model = ResNeXt_CLS(num_classes, depth=101, groups=32, width_per_group=8)
    if pretrained is not None:
        _load_pretrained(model, pretrained)
    return model