def set_model(opt):
    """Build the pretrained encoder, a linear classifier and the CE loss.

    The encoder weights are restored from the checkpoint at ``opt.ckpt``,
    whose ``'model'`` entry is used as the state dict.

    Args:
        opt: options namespace; reads ``opt.model`` (backbone name passed to
            both networks), ``opt.n_cls`` (number of classifier outputs) and
            ``opt.ckpt`` (checkpoint path).

    Returns:
        ``(model, classifier, criterion)``; all three are moved to the GPU
        when CUDA is available.
    """
    model = SupConResNet(name=opt.model)
    criterion = torch.nn.CrossEntropyLoss()
    classifier = LinearClassifier(name=opt.model, num_classes=opt.n_cls)

    ckpt = torch.load(opt.ckpt, map_location='cpu')
    # DataParallel inserts a ``module`` path component into parameter names
    # (e.g. ``encoder.module.conv1.weight``). Drop it so the keys match the
    # unwrapped model no matter how many GPUs the checkpoint was saved from.
    # Whole dotted components are compared, so names that merely contain the
    # substring "module" are left intact.
    state_dict = {
        '.'.join(part for part in key.split('.') if part != 'module'): value
        for key, value in ckpt['model'].items()
    }
    # Load into the plain model before any DataParallel wrapping, and on
    # every device, so the pretrained weights are also restored on CPU-only
    # machines.
    model.load_state_dict(state_dict)

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            # Only the encoder is replicated across GPUs; wrapping after
            # loading keeps the already-loaded parameters.
            model.encoder = torch.nn.DataParallel(model.encoder)
        model = model.cuda()
        classifier = classifier.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    return model, classifier, criterion
def set_model(opt):
    """Create the SupCon network together with its contrastive loss.

    Args:
        opt: options namespace; reads ``opt.model`` (backbone name),
            ``opt.temp`` (temperature handed to ``SupConLoss``) and
            ``opt.syncBN`` (when truthy, BatchNorm layers are converted with
            ``apex.parallel.convert_syncbn_model``).

    Returns:
        ``(model, criterion)``; moved to the GPU when CUDA is available.
    """
    net = SupConResNet(name=opt.model)
    loss_fn = SupConLoss(temperature=opt.temp)

    # Optionally replace BatchNorm layers with apex's synchronized variant.
    if opt.syncBN:
        net = apex.parallel.convert_syncbn_model(net)

    # Nothing else to do without a GPU.
    if not torch.cuda.is_available():
        return net, loss_fn

    # Only the encoder sub-module is replicated across multiple GPUs.
    if torch.cuda.device_count() > 1:
        net.encoder = torch.nn.DataParallel(net.encoder)
    net = net.cuda()
    loss_fn = loss_fn.cuda()
    cudnn.benchmark = True
    return net, loss_fn
def set_model(opt):
    """Create the SupCon network, a linear classifier and both losses.

    Args:
        opt: options namespace; reads ``opt.model`` (backbone name),
            ``opt.n_cls`` (number of classifier outputs), ``opt.temp``
            (temperature handed to ``SupConLoss``) and ``opt.syncBN`` (when
            truthy, the network's BatchNorm layers are converted with
            ``apex.parallel.convert_syncbn_model``).

    Returns:
        ``(model, classifier, criterions)`` where ``criterions`` maps
        ``'SupConLoss'`` and ``'CrossEntropyLoss'`` to their loss modules.
        Everything is moved to the GPU when CUDA is available.
    """
    net = SupConResNet(name=opt.model)
    linear_head = LinearClassifier(name=opt.model, num_classes=opt.n_cls)
    losses = {
        'SupConLoss': SupConLoss(temperature=opt.temp),
        'CrossEntropyLoss': torch.nn.CrossEntropyLoss(),
    }

    # Only the backbone network is converted; the classifier is left as-is.
    if opt.syncBN:
        net = apex.parallel.convert_syncbn_model(net)

    if torch.cuda.is_available():
        # Only the encoder sub-module is replicated across multiple GPUs.
        if torch.cuda.device_count() > 1:
            net.encoder = torch.nn.DataParallel(net.encoder)
        net = net.cuda()
        linear_head = linear_head.cuda()
        losses = {key: loss.cuda() for key, loss in losses.items()}
        cudnn.benchmark = True

    return net, linear_head, losses
def set_models(opt):
    """Build the BYOL online encoder, target encoder, predictor and loss.

    Args:
        opt: options namespace; reads ``opt.model`` (backbone name) and
            ``opt.syncBN`` (when truthy, the online encoder's BatchNorm
            layers are converted with ``apex.parallel.convert_syncbn_model``).

    Returns:
        ``(models, criterion)`` where ``models`` maps ``'online_encoder'``,
        ``'target_encoder'`` and ``'online_predictor'`` to their modules.
        Modules are moved to the GPU when CUDA is available.
    """
    encoder = SupConResNet(name=opt.model)
    predictor = MLP()
    loss_fn = BYOLLoss()

    if opt.syncBN:
        encoder = apex.parallel.convert_syncbn_model(encoder)

    if torch.cuda.is_available():
        # Only the encoder sub-module is replicated across multiple GPUs.
        if torch.cuda.device_count() > 1:
            encoder.encoder = torch.nn.DataParallel(encoder.encoder)
        encoder, predictor, loss_fn = encoder.cuda(), predictor.cuda(), loss_fn.cuda()
        cudnn.benchmark = True

    # The target network is a deep copy of the online encoder taken after
    # device placement and any DataParallel wrapping, so it starts with
    # identical weights on the same device.
    # NOTE(review): requires_grad is not disabled on the copy here; presumably
    # the training loop stops gradients / optimizer updates from reaching the
    # target network — confirm.
    return {
        'online_encoder': encoder,
        'target_encoder': copy.deepcopy(encoder),
        'online_predictor': predictor,
    }, loss_fn