def copy_weights_between_models(m1, m2):
    """
    Copy weights from m1 into m2 for every state-dict entry whose name
    exists in both models (m1 => m2). Entries that only m1 has are ignored;
    entries that only m2 has keep their current values.

    m2 is updated in place through load_state_dict. The return value is
    whatever test_copy_weights(m1, m2) returns, used as a check that m2
    really picked up the copied weights.
    """
    source_state = m1.state_dict()
    target_state = m2.state_dict()

    # Collect only the entries whose names m2 also knows about
    shared_entries = {
        key: tensor.data
        for key, tensor in source_state.items()
        if key in target_state
    }
    target_state.update(shared_entries)
    m2.load_state_dict(target_state)

    # Verify that m2 **really** holds the updated weights
    return test_copy_weights(m1, m2)


# NOTE(review): this entry point is not runnable as written. After the copy
# call it refers to m1 and m2, which are never defined in this scope, and the
# trailing `return m2` is a SyntaxError outside a function. Those lines look
# like the body of a different (index-based) variant of
# copy_weights_between_models pasted into the main guard — TODO confirm, then
# move them into a function or delete them.
if __name__ == '__main__':

    # Build a source model and a 10-class classifier model (both 'res18'),
    # then copy the weights they share by name from pr into cr.
    pr = pirl_resnet('res18')
    cr = classifier_resnet('res18', num_classes=10)

    copy_success = copy_weights_between_models(pr, cr)

    # Load state dictionaries for m1 model and m2 model
    m1_state_dict = m1.state_dict()
    m2_state_dict = m2.state_dict()

    # Get m1 and m2 layer names
    m1_layer_names, m2_layer_names = [], []
    for name, param in m1_state_dict.items():
        m1_layer_names.append(name)
    for name, param in m2_state_dict.items():
        m2_layer_names.append(name)

    # Copy by POSITION, not by name: when the ind-th entry of m1 starts with
    # 'resnet', it overwrites the ind-th entry of m2, whatever m2 calls it.
    # NOTE(review): assumes both state dicts list layers in the same order
    # with matching shapes, and that m2 has at least as many entries as m1
    # (otherwise m2_layer_names[ind] raises IndexError) — TODO confirm.
    cnt = 0
    for ind in range(len(m1_layer_names)):
        if m1_layer_names[ind][:6] == 'resnet':
            cnt += 1
            m2_state_dict[m2_layer_names[ind]] = m1_state_dict[
                m1_layer_names[ind]].data

    m2.load_state_dict(m2_state_dict)

    print('Count of layers whose weights were copied between two models', cnt)
    return m2


if __name__ == '__main__':
    # Local import: only the script path needs sys (for a failing exit code).
    import sys

    # Smoke test: copy the weights shared by name from a 'res18' PIRL model
    # (linear head) into a 10-class 'res18' classifier.
    pr = pirl_resnet('res18', non_linear_head=False)
    cr = classifier_resnet('res18', num_classes=10)

    copy_success = copy_weights_between_models(pr, cr)
    # Print sample batches that would be returned by the train_data_loader
    dataiter = iter(train_loader)
    X, y = next(dataiter)
    print(X.size())
    print(y.size())

    # Train required model using data loaders defined above
    num_outputs = 10
    epochs = args.epochs
    lr = args.lr
    weight_decay_const = args.weight_decay

    # Define model_to_train and inherit weights from pre-trained SSL model
    model_to_train = classifier_resnet(args.model_type,
                                       num_classes=num_outputs)
    pirl_model = pirl_resnet(args.model_type)
    pirl_model.load_state_dict(torch.load(pirl_file_path, map_location=device))
    weight_copy_success = copy_weights_between_models(pirl_model,
                                                      model_to_train)

    if not weight_copy_success:
        print(
            'Weight copy between SSL and classification net failed. Pls check !!'
        )
        # sys.exit(1) instead of the site-provided exit(): exit() is meant
        # for the interactive shell and, with no argument, reports success
        # (status 0) even though the weight copy failed.
        sys.exit(1)

    # Freeze every parameter whose name starts with 'resnet_' (presumably
    # the shared backbone); parameters outside that prefix, such as the
    # classification head, stay trainable.
    for name, param in model_to_train.named_parameters():
        if name.startswith('resnet_'):
            param.requires_grad = False
# Example #4 (0)
    # Build the aux model and the FCN main model (fcn32s or fcn8s variant).
    aux_model = CombineAndUpSample(n_feature=64)
    if args.fcn_type == 'fcn32s':
        main_model = fcn_resnet(args.model_type, 1)
    else:  # fcn type = fcn8s
        main_model = fcn_resnet8s(args.model_type, 1)

    # Set device on which training is done.
    aux_model.to(device)
    main_model.to(device)

    # Inherit weights for aux model from pre-trained SSL aux model
    aux_file_path = os.path.join(PAR_WEIGHTS_DIR, args.ssl_trained_aux_file)
    aux_model.load_state_dict(torch.load(aux_file_path, map_location=device))

    # Inherit weights for main model from pre-trained SSL main model
    pirl_model = pirl_resnet(args.model_type, non_linear_head=False)
    pirl_model_file_path = os.path.join(PAR_WEIGHTS_DIR,
                                        args.ssl_trained_main_file)
    pirl_model.load_state_dict(
        torch.load(pirl_model_file_path, map_location=device))
    pirl_model.to(device)
    # copy_weights_between_models updates main_model in place (it ends with
    # main_model.load_state_dict(...)). Do NOT rebind main_model to its return
    # value: in this file that is the result of test_copy_weights, not the
    # model, and rebinding would break named_parameters() below.
    copy_weights_between_models(pirl_model, main_model)
    test_copy_weights_resnet_module(pirl_model, main_model)
    del pirl_model

    # Freeze all layers in the aux model and main model's resnet module
    for param in aux_model.parameters():
        param.requires_grad = False
    for name, param in main_model.named_parameters():
        if name.startswith('resnet_'):
            param.requires_grad = False
# Example #5 (0)
    # Random samplers restricted to the train / validation index subsets.
    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, sampler=train_sampler,
                                               num_workers=8)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, sampler=val_sampler,
                                             num_workers=8)

    # Train required model using data loaders defined above
    epochs = args.epochs
    lr = args.lr
    weight_decay_const = args.weight_decay

    # Build the PIRL model for the backbone named by args.model_type
    # (not necessarily ResNet-18).
    model_to_train = pirl_resnet(args.model_type, args.non_linear_head)

    # Set device on which training is done. Plus optimizer to use.
    # SGD with momentum 0.9, cosine LR schedule with
    # T_max=args.tmax_for_cos_decay and a floor of eta_min=1e-4.
    model_to_train.to(device)
    sgd_optimizer = optim.SGD(model_to_train.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay_const)
    scheduler = CosineAnnealingLR(sgd_optimizer, args.tmax_for_cos_decay, eta_min=1e-4, last_epoch=-1)

    # Initialize model weights with a previously trained model if using warm start
    if args.warm_start and os.path.exists(model_file_path):
        model_to_train.load_state_dict(torch.load(model_file_path, map_location=device))

    # Start training
    # Randomly initialised memory bank: len_train_set rows of width 128
    # (presumably one row per training image at the embedding size — TODO
    # confirm len_train_set and the 128 against the model's output dim).
    all_images_mem = np.random.randn(len_train_set, 128)
    # NOTE(review): this snippet is truncated. The PIRLModelTrainTest( call
    # below has no closing ')' after args.beta, so the file does not parse
    # here. The np.save line uses all_samples_mem / all_samples_mem_file,
    # which this fragment never defines; it looks like residue from another
    # snippet — TODO restore the original call and training loop.
    model_train_test_obj = PIRLModelTrainTest(
        model_to_train, device, model_file_path, all_images_mem, train_indices, val_indices, args.count_negatives,
        args.temp_parameter, args.beta
    np.save(all_samples_mem_file, all_samples_mem)


if __name__ == '__main__':

    # Identify device for holding tensors and carrying out computations
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Define model file_paths
    # (file names suggest epoch-90 checkpoints for the aux and main models,
    # plus a .npy file for saved activations)
    aux_model_file = PAR_WEIGHTS_DIR + '/e1_pirl_auto_aux_epoch_90'
    main_model_file = PAR_WEIGHTS_DIR + '/e1_pirl_auto_main_epoch_90'
    all_samples_mem_file = PAR_ACTIVATIONS_DIR + '/e1_pirl_auto_activ_val_epoch_90.npy'

    # Get the initial random weight models
    aux_model = CombineAndUpSample(n_feature=64)
    main_model = pirl_resnet('res34', non_linear_head=False)

    # Initialize model weights with a previously trained model
    # map_location only controls where the loaded tensors are placed; the
    # models themselves are not moved to `device` anywhere in this span.
    aux_model.load_state_dict(torch.load(aux_model_file, map_location=device))
    main_model.load_state_dict(torch.load(main_model_file,
                                          map_location=device))

    # Get data loader
    # Scenes 111..133 inclusive (np.arange excludes the stop value 134).
    base_images_dir = '../data'
    scene_indices = np.arange(111, 134)
    sample_data_set = UnlabeledDataset(base_images_dir,
                                       scene_indices,
                                       first_dim='sample',
                                       transform=def_train_transform)
    # NOTE(review): this DataLoader call is cut off after batch_size=37; the
    # remaining arguments and the closing ')' are missing from this snippet.
    sample_data_loader = torch.utils.data.DataLoader(sample_data_set,
                                                     batch_size=37,
# Example #7 (0)
    # Print sample batches that would be returned by the train_data_loader
    # Each batch unpacks into three items: X plus y1 and y2 (presumably two
    # target tensors — TODO confirm against the dataset's __getitem__).
    dataiter = iter(train_loader)
    X, y1, y2 = dataiter.__next__()
    print(len(train_set))
    print(X.size())
    print(y1.size())
    print(y2.size())

    # Train required model using data loaders defined above
    epochs = args.epochs
    lr = args.lr
    weight_decay_const = args.weight_decay

    # Define model(s) to train
    aux_model = CombineAndUpSample(n_feature=64)
    main_model = pirl_resnet(args.model_type, args.non_linear_head)

    # Set device on which training is done. Plus optimizer to use.
    # One SGD optimizer (momentum 0.9) and one cosine LR schedule
    # (T_max=args.tmax_for_cos_decay, eta_min=1e-4) drive the parameters of
    # both models jointly.
    aux_model.to(device)
    main_model.to(device)
    params = list(aux_model.parameters()) + list(main_model.parameters())
    sgd_optimizer = optim.SGD(params,
                              lr=lr,
                              momentum=0.9,
                              weight_decay=weight_decay_const)
    scheduler = CosineAnnealingLR(sgd_optimizer,
                                  args.tmax_for_cos_decay,
                                  eta_min=1e-4,
                                  last_epoch=-1)

    # Initialize model weights with a previously trained model if using warm start