def copy_weights_between_models(m1, m2):
    """
    Copy weights for layers common between m1 and m2. From m1 => m2
    """

    # Load state dictionaries for m1 model and m2 model
    m1_state_dict = m1.state_dict()
    m2_state_dict = m2.state_dict()

    # Set the m2 model's weights with trained m1 model weights
    for name, param in m1_state_dict.items():
        if name not in m2_state_dict:
            continue
        m2_state_dict[name] = param.data
    m2.load_state_dict(m2_state_dict)

    # Test that model m2 **really** has got updated weights
    return test_copy_weights(m1, m2)


if __name__ == '__main__':
    pr = pirl_resnet('res18')
    cr = classifier_resnet('res18', num_classes=10)
    copy_success = copy_weights_between_models(pr, cr)
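# A minimal sketch of what test_copy_weights might check -- the real helper is
# defined elsewhere in the repo, so this assumed version only verifies that
# every parameter shared by name between m1 and m2 is now numerically equal.
import torch

def test_copy_weights(m1, m2):
    m1_state_dict = m1.state_dict()
    m2_state_dict = m2.state_dict()
    for name, param in m1_state_dict.items():
        if name in m2_state_dict and not torch.equal(param, m2_state_dict[name]):
            return False
    return True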
def copy_weights_between_models(m1, m2):
    # Load state dictionaries for m1 model and m2 model
    m1_state_dict = m1.state_dict()
    m2_state_dict = m2.state_dict()

    # Get m1 and m2 layer names
    m1_layer_names = list(m1_state_dict.keys())
    m2_layer_names = list(m2_state_dict.keys())

    # Copy weights for m1's resnet module into the position-aligned layers of m2
    cnt = 0
    for ind in range(len(m1_layer_names)):
        if m1_layer_names[ind][:6] == 'resnet':
            cnt += 1
            m2_state_dict[m2_layer_names[ind]] = m1_state_dict[m1_layer_names[ind]].data
    m2.load_state_dict(m2_state_dict)
    print('Count of layers whose weights were copied between two models', cnt)

    return m2


if __name__ == '__main__':
    pr = pirl_resnet('res18', non_linear_head=False)
    cr = classifier_resnet('res18', num_classes=10)
    copy_success = copy_weights_between_models(pr, cr)
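# The loop above pairs keys by position, which assumes both models register
# their resnet submodule layers in the same order. A quick sketch (assumed
# helper, for illustration) to eyeball that assumption before trusting the copy:
def print_key_alignment(m1, m2, n=5):
    for n1, n2 in list(zip(m1.state_dict(), m2.state_dict()))[:n]:
        print(n1, '->', n2)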
# Print sample batches that would be returned by the train_data_loader
dataiter = iter(train_loader)
X, y = next(dataiter)
print(X.size())
print(y.size())

# Train required model using data loaders defined above
num_outputs = 10
epochs = args.epochs
lr = args.lr
weight_decay_const = args.weight_decay

# Define model_to_train and inherit weights from pre-trained SSL model
model_to_train = classifier_resnet(args.model_type, num_classes=num_outputs)
pirl_model = pirl_resnet(args.model_type)
pirl_model.load_state_dict(torch.load(pirl_file_path, map_location=device))
weight_copy_success = copy_weights_between_models(pirl_model, model_to_train)
if not weight_copy_success:
    print('Weight copy between SSL and classification net failed. Please check!')
    exit()

# Freeze all layers except the fully connected ones in the classification net
for name, param in model_to_train.named_parameters():
    if name[:7] == 'resnet_':
        param.requires_grad = False
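# With the backbone frozen above, only the classifier head should be updated.
# A minimal sketch of passing just the trainable parameters to the optimizer
# (illustrative settings, assuming torch.optim is imported as optim):
trainable_params = [p for p in model_to_train.parameters() if p.requires_grad]
sgd_optimizer = optim.SGD(trainable_params, lr=lr, momentum=0.9,
                          weight_decay=weight_decay_const)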
aux_model = CombineAndUpSample(n_feature=64)
if args.fcn_type == 'fcn32s':
    main_model = fcn_resnet(args.model_type, 1)
else:  # fcn_type == 'fcn8s'
    main_model = fcn_resnet8s(args.model_type, 1)

# Set device on which training is done.
aux_model.to(device)
main_model.to(device)

# Inherit weights for aux model from pre-trained SSL aux model
aux_file_path = os.path.join(PAR_WEIGHTS_DIR, args.ssl_trained_aux_file)
aux_model.load_state_dict(torch.load(aux_file_path, map_location=device))

# Inherit weights for main model from pre-trained SSL main model
pirl_model = pirl_resnet(args.model_type, non_linear_head=False)
pirl_model_file_path = os.path.join(PAR_WEIGHTS_DIR, args.ssl_trained_main_file)
pirl_model.load_state_dict(torch.load(pirl_model_file_path, map_location=device))
pirl_model.to(device)
main_model = copy_weights_between_models(pirl_model, main_model)
test_copy_weights_resnet_module(pirl_model, main_model)
del pirl_model

# Freeze all layers in the aux model and the main model's resnet module
for param in aux_model.parameters():
    param.requires_grad = False
for name, param in main_model.named_parameters():
    if name[:7] == 'resnet_':
        param.requires_grad = False
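# Quick sanity-check sketch (illustrative only): after the freeze above, the
# trainable parameters of main_model should be just its FCN decoder layers.
n_trainable = sum(p.numel() for p in main_model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in main_model.parameters())
print('main_model trainable params: {} / {}'.format(n_trainable, n_total))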
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size,
                                           sampler=train_sampler, num_workers=8)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size,
                                         sampler=val_sampler, num_workers=8)

# Train required model using data loaders defined above
epochs = args.epochs
lr = args.lr
weight_decay_const = args.weight_decay

# Define the PIRL model to train (ResNet backbone chosen via args.model_type)
model_to_train = pirl_resnet(args.model_type, args.non_linear_head)

# Set device on which training is done. Plus optimizer to use.
model_to_train.to(device)
sgd_optimizer = optim.SGD(model_to_train.parameters(), lr=lr, momentum=0.9,
                          weight_decay=weight_decay_const)
scheduler = CosineAnnealingLR(sgd_optimizer, args.tmax_for_cos_decay, eta_min=1e-4,
                              last_epoch=-1)

# Initialize model weights with a previously trained model if using warm start
if args.warm_start and os.path.exists(model_file_path):
    model_to_train.load_state_dict(torch.load(model_file_path, map_location=device))

# Start training
all_images_mem = np.random.randn(len_train_set, 128)
model_train_test_obj = PIRLModelTrainTest(
    model_to_train, device, model_file_path, all_images_mem, train_indices, val_indices,
    args.count_negatives, args.temp_parameter, args.beta
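# Aside on all_images_mem: it is the PIRL memory bank, one 128-d vector per
# training image, seeded with Gaussian noise above. L2-normalizing each row
# first is an assumption here (not confirmed by this repo), but it keeps the
# initial bank on the same unit sphere as the normalized features it stores:
all_images_mem = all_images_mem / np.linalg.norm(all_images_mem, axis=1, keepdims=True)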
np.save(all_samples_mem_file, all_samples_mem)


if __name__ == '__main__':
    # Identify device for holding tensors and carrying out computations
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Define model file paths
    aux_model_file = PAR_WEIGHTS_DIR + '/e1_pirl_auto_aux_epoch_90'
    main_model_file = PAR_WEIGHTS_DIR + '/e1_pirl_auto_main_epoch_90'
    all_samples_mem_file = PAR_ACTIVATIONS_DIR + '/e1_pirl_auto_activ_val_epoch_90.npy'

    # Get the initial random weight models
    aux_model = CombineAndUpSample(n_feature=64)
    main_model = pirl_resnet('res34', non_linear_head=False)

    # Initialize model weights with a previously trained model
    aux_model.load_state_dict(torch.load(aux_model_file, map_location=device))
    main_model.load_state_dict(torch.load(main_model_file, map_location=device))

    # Get data loader
    base_images_dir = '../data'
    scene_indices = np.arange(111, 134)
    sample_data_set = UnlabeledDataset(base_images_dir, scene_indices, first_dim='sample',
                                       transform=def_train_transform)
    sample_data_loader = torch.utils.data.DataLoader(sample_data_set, batch_size=37,
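    # Sketch of how all_samples_mem (saved via np.save at the top of this
    # fragment) might be produced: run the pre-trained model over the loader
    # and collect each sample's representation. The forward signature below
    # is an assumption for illustration, not this repo's confirmed API.
    main_model.to(device)
    main_model.eval()
    feats_batches = []
    with torch.no_grad():
        for X in sample_data_loader:
            feats = main_model(X.to(device))  # assumed: returns a feature tensor
            feats_batches.append(feats.cpu().numpy())
    all_samples_mem = np.concatenate(feats_batches, axis=0)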
# Print sample batches that would be returned by the train_data_loader
dataiter = iter(train_loader)
X, y1, y2 = next(dataiter)
print(len(train_set))
print(X.size())
print(y1.size())
print(y2.size())

# Train required model using data loaders defined above
epochs = args.epochs
lr = args.lr
weight_decay_const = args.weight_decay

# Define model(s) to train
aux_model = CombineAndUpSample(n_feature=64)
main_model = pirl_resnet(args.model_type, args.non_linear_head)

# Set device on which training is done. Plus optimizer to use.
aux_model.to(device)
main_model.to(device)
params = list(aux_model.parameters()) + list(main_model.parameters())
sgd_optimizer = optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay_const)
scheduler = CosineAnnealingLR(sgd_optimizer, args.tmax_for_cos_decay, eta_min=1e-4,
                              last_epoch=-1)

# Initialize model weights with a previously trained model if using warm start
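# Sketch of the warm-start step the trailing comment introduces, mirroring the
# single-model version used elsewhere in the repo. The file-path variable
# names here are assumptions for illustration:
if args.warm_start and os.path.exists(aux_model_file_path) and os.path.exists(main_model_file_path):
    aux_model.load_state_dict(torch.load(aux_model_file_path, map_location=device))
    main_model.load_state_dict(torch.load(main_model_file_path, map_location=device))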