Example #1
0
        # One iteration of a learning-rate sweep: build a fresh contour
        # integration model, collect the training hyper-parameters for this
        # `lr`, and hand them to main(). `lr` and `random_seed` come from the
        # enclosing scope, which is outside this view (presumably a loop over
        # candidate learning rates — confirm).
        print("Processing learning rate = {} {}".format(lr, '*' * 40))

        # Per-learning-rate results directory, e.g. ./results/analyze_lr_rate/lr_0.001.
        # NOTE(review): presumably passed on to main(); its use is not visible here.
        base_results_dir = './results/analyze_lr_rate/lr_{}'.format(lr)

        # Contour integration layer under test. The commented-out constructors
        # below are alternative layer variants that can be swapped in.
        cont_int_layer = new_piech_models.CurrentSubtractInhibitLayer(
            lateral_e_size=15,
            lateral_i_size=15,
            n_iters=5,
            use_recurrent_batch_norm=True)
        # cont_int_layer = new_piech_models.CurrentDivisiveInhibitLayer(
        #     lateral_e_size=15, lateral_i_size=15, n_iters=5, use_recurrent_batch_norm=True)

        # cont_int_layer = new_control_models.ControlMatchParametersLayer(
        #      lateral_e_size=15, lateral_i_size=15)

        # Wrap the layer in the ResNet-50-based contour integration model.
        model = new_piech_models.ContourIntegrationResnet50(cont_int_layer)

        # Training hyper-parameters. 'learning_rate' is the swept value. Every
        # other entry is a literal here, except 'random_seed', which comes from
        # the enclosing scope.
        # NOTE(review): the lr_sched_* pair presumably configures a step
        # scheduler (x0.5 every 80 epochs) — confirm inside main().
        train_parameters = {
            'random_seed': random_seed,
            'train_batch_size': 32,
            'test_batch_size': 32,
            'learning_rate': lr,
            'num_epochs': 100,
            'lateral_w_reg_weight': 0.0001,
            'lateral_w_reg_gaussian_sigma': 10,
            'clip_negative_lateral_weights': True,
            'lr_sched_step_size': 80,
            'lr_sched_gamma': 0.5
        }

        # NOTE(review): this call's argument list is never closed in this view.
        # The following lines dedent into a different scope. The remaining
        # arguments (presumably train_parameters and base_results_dir) appear
        # to have been lost and must be restored before this can run.
        main(model,
    # Example checkpoint path, disabled. When saved_contour_integration_model
    # is set (it comes from the enclosing scope here), its weights are loaded
    # into cont_int_model further down. The concatenated literals yield a
    # double slash after 'Old/', which is harmless on POSIX paths.
    # saved_contour_integration_model = \
    #     './results/new_model_resnet_based/Old/' \
    #     '/ContourIntegrationResnet50_CurrentSubtractInhibitLayer_20200816_222302_baseline' \
    #     '/best_accuracy.pth'

    # -----------------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------------
    print(">>> Building the model {}".format('.' * 80))

    # # Check that the code works with standard models as well
    # net = torchvision.models.resnet50(pretrained=True)

    # Edge Extract + Contour Integration layers.
    # cont_int_layer is not defined in this scope; it presumably arrives as a
    # parameter of the enclosing function — confirm.
    # NOTE(review): pre_trained_edge_extract=True presumably initialises the
    # edge-extraction stage from pretrained weights. DummyHead stands in for the
    # classifier because this model is embedded into a ResNet-50 below rather
    # than used on its own — confirm in new_piech_models.
    cont_int_model = new_piech_models.ContourIntegrationResnet50(
        cont_int_layer,
        pre_trained_edge_extract=True,
        classifier=new_piech_models.DummyHead)

    if saved_contour_integration_model is not None:
        # Optionally initialise cont_int_model from a saved state dict.
        #
        # map_location='cpu': a checkpoint written from a GPU run would
        # otherwise fail to deserialize on a CPU-only host. load_state_dict()
        # copies tensors into the model's own parameters, so where the model
        # lives afterwards is unaffected.
        #
        # strict=False: checkpoint entries with no matching key (e.g. the
        # saved classifier, which DummyHead replaces) and model keys absent
        # from the checkpoint are skipped instead of raising. The skipped
        # keys are reported, so an unintended mismatch does not go unnoticed.
        incompatible_keys = cont_int_model.load_state_dict(
            torch.load(saved_contour_integration_model, map_location='cpu'),
            strict=False)
        if incompatible_keys.missing_keys or incompatible_keys.unexpected_keys:
            print("Keys not loaded from {}: missing={}, unexpected={}".format(
                saved_contour_integration_model,
                incompatible_keys.missing_keys,
                incompatible_keys.unexpected_keys))

    # Embed the edge-extraction + contour-integration stage into a ResNet-50.
    # NOTE(review): pretrained=True presumably loads pretrained weights for the
    # remaining ResNet-50 layers — confirm in new_piech_models.embed_into_resnet50.
    # `net` is used by code beyond this view.
    net = new_piech_models.embed_into_resnet50(
        edge_extract_and_contour_integration_layers=cont_int_model,
        pretrained=True)

    # Disabled debugging hooks.
    # check_requires_grad(net)
    # import pdb
    # pdb.set_trace()