Example #1
def train_cifar(args, nb_epochs, trainloader, testloader, net, optimizer, criterion, logging_freq=2000):
    # TODO: also track test loss
    for epoch in range(nb_epochs):  # loop over the dataset multiple times
        running_train_loss = 0.0
        #running_test_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data
            if args.enable_cuda:
                inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
            else:
                inputs, labels = Variable(inputs), Variable(labels)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # accumulate statistics; .item() replaces the deprecated loss.data[0]
            running_train_loss += loss.item()
            if i % logging_freq == logging_freq-1:    # print every logging_freq mini-batches
                # divide by logging_freq because logging_freq mini-batch losses were summed
                print(f'monitoring during training: epoch={epoch+1}, batch_index={i+1}, loss={running_train_loss/logging_freq}')
                running_train_loss = 0.0
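A minimal, self-contained usage sketch for train_cifar above, assuming synthetic tensors stand in for the CIFAR loaders and a SimpleNamespace stands in for the repo's argparse args (hypothetical stand-ins, not the repo's actual setup):
import types

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable  # needed by train_cifar above
from torch.utils.data import DataLoader, TensorDataset

x = torch.randn(64, 3, 32, 32)                   # fake CIFAR-shaped images
y = torch.randint(0, 10, (64,))                  # fake labels for 10 classes
trainloader = DataLoader(TensorDataset(x, y), batch_size=16)
testloader = DataLoader(TensorDataset(x, y), batch_size=16)

net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
optimizer = optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
args = types.SimpleNamespace(enable_cuda=False)  # minimal stand-in for the parsed args

train_cifar(args, nb_epochs=1, trainloader=trainloader, testloader=testloader,
            net=net, optimizer=optimizer, criterion=criterion, logging_freq=2)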
Example #2
def test_report_times(self):
    eps = 0.01
    ##
    start = time.time()
    time.sleep(0.5)
    msg, h = timeSince(start)
    ##
    start = time.time()
    time.sleep(0.5)
    msg, seconds, _, _ = report_times(start)
    ##
    diff = abs(seconds - h*60*60)
    self.assertTrue(diff <= eps)
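The test above fixes the helpers' shapes: timeSince returns a (message, hours) pair and report_times a 4-tuple. A minimal sketch consistent with those call sites (note the training code elsewhere in these examples unpacks report_times into just seconds, minutes, hours, so the repo's actual version likely differs):
import time

def timeSince(start):
    # sketch: (message, hours elapsed), matching `msg, h = timeSince(start)` above
    hours = (time.time() - start) / 3600.0
    return f'time elapsed: {hours:.6f} hours', hours

def report_times(start):
    # sketch: (message, seconds, minutes, hours), matching the 4-way unpacking above
    seconds = time.time() - start
    minutes, hours = seconds / 60.0, seconds / 3600.0
    return f'time passed: {hours}h, {minutes}m, {seconds}s', seconds, minutes, hours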
Example #3
def main(plot=True):
    if args.means != '':
        means = [float(x.strip()) for x in args.means.strip('[').strip(']').split(',')]
    else:
        means = []
    if args.stds != '':
        stds = [float(x.strip()) for x in args.stds.strip('[').strip(']').split(',')]
    else:
        stds = []
    ##
    hostname = utils.get_hostname()
    ''' cuda '''
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f'device = {device}')
    ''' '''
    store_net = True
    other_stats = dict({'sj':sj,'satid':satid,'hostname':hostname,'label_corrupt_prob':args.label_corrupt_prob})
    ''' reproducibility setup/params'''
    #num_workers = 2 # how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
    githash = subprocess.check_output(["git", "describe", "--always"]).strip()
    seed = args.seed
    if seed is None: # if seed is None it has not been set, so get a random seed, else use the seed that was set
        seed = int.from_bytes(os.urandom(7), byteorder="big")
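        # 7 random bytes give an unsigned 56-bit integer, comfortably inside torch.manual_seed's accepted range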
    print(f'seed: {seed}')
    ## SET SEED/determinism
    num_workers = 3
    torch.manual_seed(seed)
    #torch.backends.cudnn.deterministic=True
    ''' date parameters setup'''
    today_obj = date.today() # contains datetime.date(year, month, day); accessible via .day etc
    day = today_obj.day
    month = calendar.month_name[today_obj.month]
    setup_time = time.time()
    ''' filenames '''
    ## folder names
    results_root = './test_runs_flatness5_ProperOriginalExpt'
    expt_folder = f'flatness_{month}_label_corrupt_prob_{args.label_corrupt_prob}_exptlabel_{args.exptlabel}_' \
                  f'only_1st_layer_BIAS_{args.only_1st_layer_bias}_data_set_{args.data_set}_reg_param_{args.reg_param}'
    ## filenames
    matlab_file_name = f'flatness_{day}_{month}_sj_{sj}_staid_{satid}_seed_{seed}_{hostname}'
    net_file_name = f'net_{day}_{month}_sj_{sj}_staid_{satid}_seed_{seed}_{hostname}'
    ## folder to hold all nets
    all_nets_folder = f'nets_folder_{day}_{month}_sj_{sj}_staid_{satid}_seed_{seed}_{hostname}'
    ## experiment path
    expt_path = os.path.join(results_root,expt_folder)
    ''' data set '''
    data_path = './data'
    standardize = not args.dont_standardize_data # (x - mu) / std, roughly [-1,+1]
    trainset, testset, classes = data_class.get_data_processors(data_path,args.label_corrupt_prob,dataset_type=args.data_set,standardize=standardize,type_standardize=args.type_standardize)
    ''' experiment params '''
    evalaute_mdl_data_set = get_function_evaluation_from_name(args.evalaute_mdl_data_set)
    suffle_test = False
    shuffle_train = True
    nb_epochs = 4 if args.epochs is None else args.epochs
    batch_size = 256
    #batch_size_train,batch_size_test = batch_size,batch_size
    batch_size_train = batch_size
    batch_size_test = 256
    ''' get NN '''
    nets = []
    mdl = args.mdl
    do_bn = args.use_bn
    other_stats = dict({'mdl':mdl,'do_bn':do_bn, 'type_standardize':args.type_standardize},**other_stats)
    print(f'model = {mdl}')
    if mdl == 'cifar_10_tutorial_net':
        suffle_test = False
        net = nn_mdls.Net()
        nets.append(net)
    elif mdl == 'debug':
        suffle_test = False
        nb_conv_layers=1
        ## conv params
        Fs = [3]*nb_conv_layers
        Ks = [2]*nb_conv_layers
        ## fc params
        FC = len(classes)
        C,H,W = 3,32,32
        net = nn_mdls.LiaoNet(C,H,W,Fs,Ks,FC,do_bn)
        nets.append(net)
    elif mdl == 'sequential':
        batch_size_train = 256
        batch_size_test = 256
        ##
        batch_size = batch_size_train
        suffle_test = False
        ##
        FC = [10,10]
        C,H,W = 3, 32, 32
        # net = torch.nn.Sequential(OrderedDict([
        #     ('Flatten',Flatten()),
        #     ('FC1', torch.nn.Linear(C*H*W,FC[0])),
        #     ('FC2', torch.nn.Linear(FC[0],FC[1]))
        # ]))
        # net = torch.nn.Sequential(OrderedDict([
        #     ('Flatten',Flatten()),
        #     ('FC1', torch.nn.Linear(C*H*W,FC[0])),
        #     ('relu1', torch.nn.ReLU()),
        #     ('FC2', torch.nn.Linear(FC[0],FC[1]))
        # ]))
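        # Shape check for the net below: 32x32 -> conv0 (5x5, no padding) -> 28x28 -> conv1 (5x5) -> 24x24,
        # so FC1 sees 50 channels * 24 * 24 = 28800 inputs; `Flatten` is presumably the repo's own module,
        # since this vintage of PyTorch predates nn.Flatten.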
        net = torch.nn.Sequential(OrderedDict([
            ('conv0', torch.nn.Conv2d(3,420,5,bias=True)),
            ('relu0', torch.nn.ReLU()),
            ('conv1', torch.nn.Conv2d(420,50,5, bias=True)),
            ('relu1', torch.nn.ReLU()),
            ('Flatten',Flatten()),
            ('FC1', torch.nn.Linear(28800,50,bias=True)),
            ('relu2', torch.nn.ReLU()),
            ('FC2', torch.nn.Linear(50, 10, bias=True))
        ]))
        ##
        nets.append(net)
    elif mdl == 'BoixNet':
        batch_size_train = 256
        batch_size_test = 256
        ##
        batch_size = batch_size_train
        suffle_test = False
        ## conv params
        nb_filters1,nb_filters2 = 32, 32
        kernel_size1,kernel_size2 = 5,5
        ## fc params
        nb_units_fc1,nb_units_fc2,nb_units_fc3 = 512,256,len(classes)
        C,H,W = 3,32,32
        net = nn_mdls.BoixNet(C,H,W,nb_filters1,nb_filters2, kernel_size1,kernel_size2, nb_units_fc1,nb_units_fc2,nb_units_fc3,do_bn)
        nets.append(net)
    elif mdl == 'LiaoNet':
        suffle_test = False
        nb_conv_layers=5
        ## conv params
        Fs = [32]*nb_conv_layers
        Ks = [10]*nb_conv_layers
        ## fc params
        FC = len(classes)
        C,H,W = 3,32,32
        net = nn_mdls.LiaoNet(C,H,W,Fs,Ks,FC,do_bn)
        nets.append(net)
    elif mdl == 'GBoixNet':
        #batch_size_train = 16384 # 2**14
        #batch_size_test = 16384
        batch_size_train = 2**10
        batch_size_test = 2**10
        ##
        batch_size = batch_size_train
        suffle_test = False
        ## conv params
        nb_conv_layers=2
        Fs = [34]*nb_conv_layers
        Ks = [5]*nb_conv_layers
        #nb_conv_layers = 4
        #Fs = [60] * nb_conv_layers
        #Ks = [5] * nb_conv_layers
        ## fc params
        FCs = [len(classes)]
        ##
        print(f'------> FCs = {FCs}')
        if args.data_set == 'mnist':
            CHW = (1, 28, 28)
        else:
            CHW = (3,32,32)
        net = nn_mdls.GBoixNet(CHW,Fs,Ks,FCs,do_bn,only_1st_layer_bias=args.only_1st_layer_bias)
        print(f'net = {net}')
        ##
        if len(means) != 0 and len(stds) != 0:
            ## re-initialize each parameter from the given means/stds, skipping conv0.bias
            i = 0
            for name, param in net.named_parameters():
                print(name)
                if name != 'conv0.bias':
                    mu, s = means[i], stds[i]
                    param.data.normal_(mean=mu, std=s)
                    i += 1
        ##
        expt_path = f'{expt_path}_means_{args.means}_stds_{args.stds}'
        other_stats = dict({'means': means, 'stds': stds}, **other_stats)
        ##
        nets.append(net)
        other_stats = dict({'only_1st_layer_bias': args.only_1st_layer_bias}, **other_stats)
    elif mdl == 'AllConvNetStefOe':
        #batch_size_train = 16384 # 2**14
        #batch_size_test = 16384
        #batch_size_train = 2**10
        # batch_size_train = 2**10
        # batch_size_test = 2**10
        batch_size_train = 32
        batch_size_test = 124
        ##
        batch_size = batch_size_train
        suffle_test = False
        ## AllConvNet
        only_1st_layer_bias = args.only_1st_layer_bias
        CHW = (3,32,32)
        dropout = args.use_dropout
        net = nn_mdls.AllConvNetStefOe(nc=len(CHW),dropout=dropout,only_1st_layer_bias=only_1st_layer_bias)
        ##
        nets.append(net)
        other_stats = dict({'only_1st_layer_bias': args.only_1st_layer_bias,'dropout':dropout}, **other_stats)
        expt_path = f'{expt_path}_dropout_{dropout}'
    elif mdl == 'AndyNet':
        #batch_size_train = 16384 # 2**14
        #batch_size_test = 16384
        #batch_size_train = 2**10
        batch_size_train = 2**10
        batch_size_test = 2**10
        # batch_size_train = 32
        # batch_size_test = 124
        ##
        batch_size = batch_size_train
        suffle_test = False
        ## AndyNet
        #only_1st_layer_bias = args.only_1st_layer_bias ## TODO fix
        only_1st_layer_bias = args.only_1st_layer_bias
        CHW = (3,32,32)
        net = nn_mdls.get_AndyNet()
        ##
        nets.append(net)
        other_stats = dict({'only_1st_layer_bias': args.only_1st_layer_bias}, **other_stats)
        expt_path = f'{expt_path}'
    elif mdl == 'interpolate':
        suffle_test = True
        batch_size = 2**10
        batch_size_train, batch_size_test = batch_size, batch_size
        iterations = inf # caps how many mini-batches are evaluated before returning the data set error
        #iterations = 1 # caps how many mini-batches are evaluated before returning the data set error
        ''' '''
        path_nl = os.path.join(results_root,'flatness_27_April_label_corrupt_prob_0.0_exptlabel_GB_24_24_10_2C1FC_momentum_NL_polestar/net_27_April_sj_343_staid_1_seed_56134200848018679')
        path_rl_nl = os.path.join(results_root,'flatness_27_April_label_corrupt_prob_0.0_exptlabel_GB_24_24_10_2C1FC_momentum_RLNL_polestar/net_27_April_sj_345_staid_1_seed_57700439347820897')
        ''' restore nets'''
        net_nl = utils.restore_entire_mdl(path_nl)
        net_rlnl = utils.restore_entire_mdl(path_rl_nl)
        nets.append(net_nl)
        nets.append(net_rlnl)
    elif mdl == 'radius_flatness':
        suffle_test = True
        batch_size = 2**10
        batch_size_train, batch_size_test = batch_size, batch_size
        iterations = 11 # caps how many mini-batches are evaluated before returning the data set error
        #iterations = inf  # caps how many mini-batches are evaluated before returning the data set error
        other_stats = dict({'iterations':iterations},**other_stats)
        ''' load net '''
        if args.net_name == 'NL':
            #path = os.path.join(results_root,'flatness_28_March_label_corrupt_prob_0.0_exptlabel_BoixNet_polestar_300_stand_natural_labels/net_28_March_206')
            path = os.path.join(results_root,'flatness_27_April_label_corrupt_prob_0.0_exptlabel_GB_24_24_10_2C1FC_momentum_NL_polestar/net_27_April_sj_343_staid_1_seed_56134200848018679')
        else: # RLNL
            #path = os.path.join(results_root,'flatness_28_March_label_corrupt_prob_0.0_exptlabel_re_train_RLBoixNet_noBN_polestar_150/net_28_March_18')
            path = os.path.join(results_root,'flatness_27_April_label_corrupt_prob_0.0_exptlabel_GB_24_24_10_2C1FC_momentum_RLNL_polestar/net_27_April_sj_345_staid_1_seed_57700439347820897')
        ''' restore nets'''
        net = utils.restore_entire_mdl(path)
        nets.append(net)
        store_net = False
    elif mdl == 'sharpness':
        suffle_test=False #doesn't matter
        ''' load net '''
        if args.net_name == 'NL':
            #path = os.path.join(results_root,'flatness_28_March_label_corrupt_prob_0.0_exptlabel_BoixNet_polestar_300_stand_natural_labels/net_28_March_206')
            path = os.path.join(results_root,'flatness_27_April_label_corrupt_prob_0.0_exptlabel_GB_24_24_10_2C1FC_momentum_NL_polestar/net_27_April_sj_343_staid_1_seed_56134200848018679')
            path_adverserial_data = os.path.join('./data/sharpness_data_NL/','sdata_NL_net_27_April_sj_343_staid_1_seed_56134200848018679.npz')
        else: # RLNL
            #path = os.path.join(results_root,'flatness_28_March_label_corrupt_prob_0.0_exptlabel_re_train_RLBoixNet_noBN_polestar_150/net_28_March_18')
            path = os.path.join(results_root,'flatness_27_April_label_corrupt_prob_0.0_exptlabel_GB_24_24_10_2C1FC_momentum_RLNL_polestar/net_27_April_sj_345_staid_1_seed_57700439347820897')
            path_adverserial_data = os.path.join('./data/sharpness_data_RLNL/','sdata_RLNL_net_27_April_sj_345_staid_1_seed_57700439347820897.npz')
        ''' restore nets'''
        net = torch.load(path)
        nets.append(net)
        store_net = False
    elif mdl == 'divide_constant':
        ''' ''' # both false because we want low variation on the output of the error
        iterations = inf # caps how many mini-batches are evaluated before returning the data set error
        #iterations = 11 # caps how many mini-batches are evaluated before returning the data set error
        batch_size = 2**10
        batch_size_train, batch_size_test = batch_size, batch_size
        shuffle_train = True
        suffle_test = False
        ''' load net '''
        ## NL
        #path_nl = os.path.join(results_root,'flatness_27_April_label_corrupt_prob_0.0_exptlabel_GB_24_24_10_2C1FC_momentum_NL_polestar/net_27_April_sj_343_staid_1_seed_56134200848018679')
        #path_nl = os.path.join(results_root,'flatness_May_label_corrupt_prob_0.0_exptlabel_SGD_ManyRuns_Momentum0.9/net_17_May_sj_641_staid_5_seed_31866864409272026_polestar-old')
        path_nl = os.path.join(results_root,'flatness_May_label_corrupt_prob_0.0_exptlabel_MovieNL_lr_0.01_momentum_0.9/net_22_May_sj_1168_staid_1_seed_59937023958974481_polestar-old_epoch_173')
        ## RLNL
        #path_rlnl = os.path.join(results_root,'flatness_27_April_label_corrupt_prob_0.0_exptlabel_GB_24_24_10_2C1FC_momentum_RLNL_polestar/net_27_April_sj_345_staid_1_seed_57700439347820897')
        path_rlnl = os.path.join(results_root,'flatness_May_label_corrupt_prob_0.0_exptlabel_MovieRLNLmdls_label_corruption0.5_lr_0.01_momentum_0.9/net_22_May_sj_1172_staid_1_seed_38150714758131256_polestar-old_epoch_148')
        ##
        net_nl = torch.load(path_nl)
        net_rlnl = torch.load(path_rlnl)
        ''' '''
        print('NL')
        l2_norm_all_params(net_nl)
        print('RLNL')
        l2_norm_all_params(net_rlnl)
        ''' modify nets '''
        W_nl = 1
        W_rlnl = (get_norm(net_rlnl, l=2)/get_norm(net_nl, l=2)) # 2.284937620162964
        W_rlnl = (10)**(1.0/3.0) # overrides the norm ratio above; the last assignment wins
        #W_rlnl = 1/0.57775
        #W_rlnl = 1/0.7185
        #W_rlnl = 1/0.85925
        #W_rlnl = 1
        print(f'W_rlnl = {W_rlnl}')
        print(f'norm of weight BEFORE division: get_norm(net_nl,l=2)={get_norm(net_nl,l=2)}, get_norm(net_rlnl,l=2)={get_norm(net_rlnl,l=2)}')
        #net_nl = divide_params_by(W_nl, net_nl)
        #net_rlnl = divide_params_by(W_rlnl, net_rlnl)
        net_rlnl = divide_params_by_taking_bias_into_account(W=W_rlnl,net=net_rlnl)
        print(f'norm of weight AFTER division: get_norm(net_nl,l=2)={get_norm(net_nl,l=2)}, get_norm(net_rlnl,l=2)={get_norm(net_rlnl,l=2)}')
        nets.append(net_nl)
        nets.append(net_rlnl)
        other_stats = dict({'W_rlnl':W_rlnl,'W_nl':W_nl}, **other_stats)
    elif mdl == 'load_nl_and_rlnl':
        ''' load net '''
        # NL
        #path = os.path.join(results_root,'flatness_27_April_label_corrupt_prob_0.0_exptlabel_GB_24_24_10_2C1FC_momentum_NL_polestar/net_27_April_sj_343_staid_1_seed_56134200848018679')
        path = os.path.join(results_root,'flatness_May_label_corrupt_prob_0.0_exptlabel_MovieNL_lr_0.01_momentum_0.9/net_22_May_sj_1168_staid_1_seed_59937023958974481_polestar-old_epoch_173')
        net = torch.load(path)
        nets.append(net)
        # RLNL
        #path_rlnl = os.path.join(results_root,'flatness_27_April_label_corrupt_prob_0.0_exptlabel_GB_24_24_10_2C1FC_momentum_RLNL_polestar/net_27_April_sj_345_staid_1_seed_57700439347820897')
        path_rlnl = os.path.join(results_root,'flatness_May_label_corrupt_prob_0.0_exptlabel_MovieRLNLmdls_label_corruption0.5_lr_0.01_momentum_0.9/net_22_May_sj_1172_staid_1_seed_38150714758131256_polestar-old_epoch_148')
        net_rlnl = torch.load(path_rlnl)
        nets.append(net_rlnl)
        other_stats = dict({'path': path, 'path_rlnl': path_rlnl}, **other_stats)
    elif mdl == 'load_one_net':
        ''' load net: each `path` assignment below overwrites the previous, so the last uncommented one (label_corrupt_prob 1.0) is used '''
        ## 0.0001
        path = os.path.join(results_root,'flatness_June_label_corrupt_prob_0.0001_exptlabel_RLInits_only_1st_layer_BIAS_True_batch_size_train_1024_lr_0.01_momentum_0.9_scheduler_milestones_200,250,300_gamma_1.0/net_21_June_sj_974_staid_1_seed_44940314088747654_polestar-old')
        ## 0.001
        path = os.path.join(results_root,'flatness_June_label_corrupt_prob_0.001_exptlabel_RLInits_only_1st_layer_BIAS_True_batch_size_train_1024_lr_0.01_momentum_0.9_scheduler_milestones_200,250,300_gamma_1.0/net_21_June_sj_967_staid_1_seed_1986409594254668_polestar-old')
        ## 0.01
        path = os.path.join(results_root, 'flatness_June_label_corrupt_prob_0.01_exptlabel_RLInits_only_1st_layer_BIAS_True_batch_size_train_1024_lr_0.01_momentum_0.9_scheduler_milestones_200,250,300_gamma_1.0/net_21_June_sj_976_staid_1_seed_34669758900780265_polestar-old')
        ## 0.1
        path = os.path.join(results_root,'flatness_June_label_corrupt_prob_0.1_exptlabel_RLInits_only_1st_layer_BIAS_True_batch_size_train_1024_lr_0.01_momentum_0.9_scheduler_milestones_200,250,300_gamma_1.0/net_21_June_sj_977_staid_1_seed_57003505407221650_polestar-old')
        ## 0.2
        path = os.path.join(results_root, 'flatness_June_label_corrupt_prob_0.2_exptlabel_RLInits_only_1st_layer_BIAS_True_batch_size_train_1024_lr_0.01_momentum_0.9_scheduler_milestones_200,250,300_gamma_1.0/net_21_June_sj_978_staid_1_seed_63479113068450657_polestar-old')
        ## 0.5
        path = os.path.join(results_root, 'flatness_June_label_corrupt_prob_0.5_exptlabel_RLInits_only_1st_layer_BIAS_True_batch_size_train_1024_lr_0.01_momentum_0.9_scheduler_milestones_200,250,300_gamma_1.0/net_21_June_sj_979_staid_1_seed_51183371945505111_polestar-old')
        ## 0.75
        path = os.path.join(results_root, 'flatness_June_label_corrupt_prob_0.75_exptlabel_RLInits_only_1st_layer_BIAS_True_batch_size_train_1024_lr_0.01_momentum_0.9_scheduler_milestones_200,250,300_gamma_1.0/net_21_June_sj_980_staid_1_seed_63292262317939652_polestar-old')
        ## 1.0
        path = os.path.join(results_root, 'flatness_June_label_corrupt_prob_1.0_exptlabel_RLInits_only_1st_layer_BIAS_True_batch_size_train_1024_lr_0.01_momentum_0.9_scheduler_milestones_200,250,300_gamma_1.0/net_21_June_sj_981_staid_1_seed_34295360820373818_polestar-old')
        ''' load net '''
        net = torch.load(path)
        nets.append(net)
        other_stats = dict({'path': path}, **other_stats)
    elif mdl == 'l2_norm_all_params':
        ''' load net '''
        # path = os.path.join(results_root,'flatness_June_label_corrupt_sqprob_0.0_exptlabel_WeightDecay_lambda100_lr_0.1_momentum_0.0/net_1_June_sj_2833_staid_2_seed_45828051420330772_polestar-old')
        # path = os.path.join(results_root,'flatness_June_label_corrupt_prob_0.0_exptlabel_WeightDecay_lambda1_lr_0.1_momentum_0.0/net_1_June_sj_2830_staid_1_seed_53714812690274511_polestar-old')
        # path = os.path.join(results_root,'flatness_June_label_corrupt_prob_0.0_exptlabel_WeightDecay_lambda0.1_lr_0.1_momentum_0.0/net_1_June_sj_2835_staid_2_seed_66755608399194708_polestar-old')
        # path = os.path.join(results_root,'flatness_June_label_corrupt_prob_0.0_exptlabel_WeightDecay_lambda0.01_lr_0.1_momentum_0.0/net_1_June_sj_2832_staid_1_seed_47715620118836168_polestar-old')

        #path = os.path.join(results_root,'flatness_May_label_corrupt_prob_0.0_exptlabel_WeightDecay_lambda0.1_lr_0.01_momentum_0.9/net_31_May_sj_2784_staid_1_seed_59165331201064855_polestar-old')
        #path = os.path.join(results_root,'flatness_May_label_corrupt_prob_0.0_exptlabel_WeightDecay_lambda0.01_lr_0.01_momentum_0.9/net_31_May_sj_2792_staid_1_seed_42391375291583068_polestar-old')
        #path = os.path.join(results_root,'flatness_May_label_corrupt_prob_0.0_exptlabel_WeightDecay_lambda0.001_lr_0.01_momentum_0.9/net_31_May_sj_2793_staid_2_seed_47559284752010338_polestar-old')

        #path = os.path.join(results_root,'flatness_June_label_corrupt_prob_0.0_exptlabel_L2_squared_lambda1_lr_0.1_momentum_0.0/net_1_June_sj_2841_staid_2_seed_29441453139027048_polestar-old')
        #path = os.path.join(results_root,'flatness_June_label_corrupt_prob_0.0_exptlabel_L2_squared_lambda0.1_lr_0.1_momentum_0.0/net_1_June_sj_2839_staid_2_seed_35447208985369634_polestar-old')
        #path = os.path.join(results_root,'flatness_June_label_corrupt_prob_0.0_exptlabel_L2_squared_lambda0.01_lr_0.1_momentum_0.0/net_1_June_sj_2837_staid_2_seed_57556488720733908_polestar-old')
        #path = os.path.join(results_root,'flatness_June_label_corrupt_prob_0.0_exptlabel_L2_squared_lambda0.001_lr_0.1_momentum_0.0/net_1_June_sj_2848_staid_1_seed_48943421305461120_polestar-old')
        #path = os.path.join(results_root,'flatness_June_label_corrupt_prob_0.0_exptlabel_L2_squared_lambda0.0001_lr_0.1_momentum_0.0/net_1_June_sj_2850_staid_1_seed_2881772832480048_polestar-old')
        #path = os.path.join(results_root,'flatness_June_label_corrupt_prob_0.0_exptlabel_L2_squared_lambda0.00001_lr_0.1_momentum_0.0/net_1_June_sj_2852_staid_1_seed_24293440492629928_polestar-old')
        ## exactly one of the commented-out paths above must be uncommented for `path` to be defined here
        print(f'path = {path}')
        net = torch.load(path)
        ''' l2_norm_all_params '''
        l2_norm_all_params(net)
        ''' evaluate data set '''
        standardize = not args.dont_standardize_data  # (x - mu) / std, roughly [-1,+1]
        error_criterion = metrics.error_criterion
        criterion = torch.nn.CrossEntropyLoss()
        trainset, testset, classes = data_class.get_data_processors(data_path, args.label_corrupt_prob,dataset_type=args.data_set,standardize=standardize)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size_train, shuffle=shuffle_train,num_workers=num_workers)
        testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size_test, shuffle=suffle_test,num_workers=num_workers)
        train_loss_epoch, train_error_epoch = evalaute_mdl_data_set(criterion, error_criterion, net,trainloader,device)
        test_loss_epoch, test_error_epoch = evalaute_mdl_data_set(criterion, error_criterion, net,testloader,device)
        print(f'[-1, -1], (train_loss: {train_loss_epoch}, train error: {train_error_epoch}) , (test loss: {test_loss_epoch}, test error: {test_error_epoch})')
        ''' end '''
        nets.append(net)
        sys.exit()
    else:
        print('RESTORED FROM PRE-TRAINED NET')
        suffle_test = False
        ''' RESTORED PRE-TRAINED NET '''
        # example name of file, os.path.join(results_root,expt_path,f'net_{day}_{month}_{seed}')
        # args.net_path = 'flatness_27_March_label_corrupt_prob_0_exptlabel_BoixNet_stand_600_OM/net_27_Match_64'
        path_to_mdl = args.mdl
        path = os.path.join(results_root,path_to_mdl)
        # net = utils.restore_entire_mdl(path)
        net = torch.load(path)
        nets.append(net)
    print(f'nets = {nets}')
    ''' cuda/gpu '''
    for net in nets:
        net.to(device)
    nb_params = nn_mdls.count_nb_params(net)
    ''' stats collector '''
    stats_collector = StatsCollector(net)
    ''' get data set '''
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size_train, shuffle=shuffle_train, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size_test, shuffle=suffle_test, num_workers=num_workers)
    ''' Cross Entropy + Optmizer '''
    lr = args.lr
    momentum = 0.9
    ## Error/Loss criterions
    error_criterion = metrics.error_criterion
    criterion = torch.nn.CrossEntropyLoss()
    #criterion = torch.nn.MultiMarginLoss()
    #criterion = torch.nn.MSELoss(size_average=True)
    print(f'Training Algorithm = {args.train_alg}')
    if args.train_alg == 'SGD':
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    elif args.train_alg == 'Adam':
        optimizer = optim.Adam(net.parameters(), lr=lr)
    else:
        raise ValueError(f'Training alg not existent: {args.train_alg}')
    other_stats = dict({'nb_epochs':nb_epochs,'batch_size':batch_size,'mdl':mdl,'lr':lr,'momentum':momentum, 'seed':seed,'githash':githash},**other_stats)
    expt_path = f'{expt_path}_args.train_alg_{args.train_alg}_batch_train_{batch_size_train}_lr_{lr}_moment_{momentum}_epochs_{nb_epochs}'
    ''' scheduler '''
    #milestones = [20, 30, 40]
    milestones = [200, 250, 300]
    #milestones = [700, 800, 900]
    #milestones = [1700, 1800, 1900]
    scheduler_gamma = args.decay_rate
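    ## MultiStepLR multiplies the learning rate by `gamma` each time the epoch count passes a milestone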
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=scheduler_gamma)
    other_stats = dict({'milestones': milestones, 'scheduler_gamma': scheduler_gamma}, **other_stats)
    milestones_str = ','.join(str(m) for m in milestones)
    #expt_path = f'{expt_path}_scheduler_milestones_{milestones_str}_gamma_{gamma}'
    expt_path = f'{expt_path}_scheduler_gamma_{scheduler_gamma}'
    print(f'scheduler_gamma = {scheduler_gamma}')
    ''' Verify model you got has the right error'''
    train_loss_epoch, train_error_epoch = evalaute_mdl_data_set(criterion, error_criterion, net, trainloader, device)
    test_loss_epoch, test_error_epoch = evalaute_mdl_data_set(criterion, error_criterion, net, testloader, device)
    print(f'train_loss_epoch, train_error_epoch  = {train_loss_epoch}, {train_error_epoch} \n test_loss_epoch, test_error_epoch  = {test_loss_epoch}, {test_error_epoch}')
    ''' Is it over parametrized?'''
    overparametrized = len(trainset)<nb_params # N < W ?
    print(f'Model overparametrized? N, W = {len(trainset)} vs {nb_params}')
    print(f'Model overparametrized? N < W = {overparametrized}')
    other_stats = dict({'overparametrized':overparametrized,'nb_params':nb_params}, **other_stats)
    ''' report time for setup'''
    seconds_setup,minutes_setup,hours_setup = utils.report_times(setup_time,'setup')
    other_stats = dict({'seconds_setup': seconds_setup, 'minutes_setup': minutes_setup, 'hours_setup': hours_setup}, **other_stats)
    ''' Start Training '''
    training_time = time.time()
    print(f'----\nSTART training: label_corrupt_prob={args.label_corrupt_prob},nb_epochs={nb_epochs},batch_size={batch_size},lr={lr},momentum={momentum},mdl={mdl},batch-norm={do_bn},nb_params={nb_params}')
    ## START TRAIN
    if args.train_alg == 'SGD' or args.train_alg == 'Adam':
        #iterations = 4 # number of mini-batches used to estimate the test error: smaller is faster, larger is more accurate (though accuracy only improves as sqrt(n))
        iterations = inf
        ''' set up Trainer '''
        if args.save_every_epoch:
            save_every_epoch = args.save_every_epoch
            trainer = Trainer(trainloader, testloader, optimizer, scheduler, criterion, error_criterion, stats_collector,
                              device, expt_path,net_file_name,all_nets_folder,save_every_epoch,args.evalaute_mdl_data_set,
                              reg_param=args.reg_param,p=args.Lp_norm)
        else:
            trainer = Trainer(trainloader,testloader, optimizer, scheduler, criterion,error_criterion, stats_collector,
                              device,evalaute_mdl_data_set=args.evalaute_mdl_data_set,reg_param=args.reg_param,p=args.Lp_norm)
        last_errors = trainer.train_and_track_stats(net, nb_epochs,iterations)
        ''' Test the Network on the test data '''
        train_loss_epoch, train_error_epoch, test_loss_epoch, test_error_epoch = last_errors
        print(f'train_loss_epoch={train_loss_epoch} \ntrain_error_epoch={train_error_epoch} \ntest_loss_epoch={test_loss_epoch} \ntest_error_epoch={test_error_epoch}')
    elif args.train_alg == 'pert':
        ''' batch sizes '''
        batch_size_train, batch_size_test = 50*10**3, 10*10**3
        ''' number of repetitions '''
        nb_perturbation_trials = nb_epochs
        ''' noise level '''
        nb_layers = len(list(net.parameters()))
        noise_level = args.noise_level
        perturbation_magnitudes = nb_layers*[noise_level]
        print(f'noise_level={noise_level}')
        ''' locate where to save it '''
        folder_name_noise = f'noise_{perturbation_magnitudes[0]}'
        expt_path = os.path.join(expt_path,folder_name_noise)
        matlab_file_name = f'noise_{perturbation_magnitudes}_{matlab_file_name}'
        ## TODO collect by perturbing the current model X times with the current perturbation_magnitudes
        use_w_norm2 = args.not_pert_w_norm2
        train_loss,train_error,test_loss,test_error = get_errors_for_all_perturbations(net,perturbation_magnitudes,use_w_norm2,device,nb_perturbation_trials,stats_collector,criterion,error_criterion,trainloader,testloader)
        print(f'noise_level={noise_level},train_loss,train_error,test_loss,test_error={train_loss},{train_error},{test_loss},{test_error}')
        other_stats = dict({'perturbation_magnitudes':perturbation_magnitudes}, **other_stats)
    elif args.train_alg == 'interpolate':
        ''' prints stats before interpolation'''
        print_evaluation_of_nets(net_nl, net_rlnl, criterion, error_criterion, trainloader, testloader, device, iterations)
        ''' do interpolation of nets'''
        nb_interpolations = nb_epochs
        interpolations = np.linspace(0,1,nb_interpolations)
        get_landscapes_stats_between_nets(net_nl,net_rlnl,interpolations, device,stats_collector,criterion,error_criterion,trainloader,testloader,iterations)
        ''' print stats of the net '''
        other_stats = dict({'interpolations':interpolations},**other_stats)
        #print_evaluation_of_nets(net_nl, net_rlnl, criterion, error_criterion, trainloader, testloader, device, iterations)
    elif args.train_alg == 'brando_chiyuan_radius_inter':
        r_large = args.r_large ## check if this number is good
        nb_radius_samples = nb_epochs
        interpolations = np.linspace(0,1,nb_radius_samples)
        expt_path = os.path.join(expt_path+f'_RLarge_{r_large}')
        ''' '''
        nb_dirs = args.nb_dirs
        stats_collector = StatsCollector(net,nb_dirs,nb_epochs)
        get_all_radius_errors_loss_list_interpolate(nb_dirs,net,r_large,interpolations,device,stats_collector,criterion,error_criterion,trainloader,testloader,iterations)
        other_stats = dict({'nb_dirs':nb_dirs,'interpolations':interpolations,'nb_radius_samples':nb_radius_samples,'r_large':r_large},**other_stats)
    elif args.train_alg == 'sharpness':
        ''' load the data set '''
        print('About to load the data set')
        shuffle_train = True
        #batch_size = 2**10
        batch_size = 2**5
        batch_size_train, batch_size_test = batch_size, batch_size
        iterations = inf  # caps how many mini-batches are evaluated before returning the data set error
        #eps = 2500/50000
        eps = 1 / 50000
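        # eps is expressed as a fraction of the 50k CIFAR-10 training examples (cf. the 2500/50000 alternative above)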
        other_stats = dict({'iterations':iterations,'eps':eps},**other_stats)
        trainset,trainloader = data_class.load_only_train(path_adverserial_data,eps,batch_size_train,shuffle_train,num_workers)
        ''' three musketeers '''
        print('Preparing the three musketeers')
        net_pert = copy.deepcopy(net)
        #nn_mdls.reset_parameters(net_pert)
        net_original = dont_train(net)
        #net_original = net
        initialize_to_zero(net_original)
        debug=False
        if debug:
            ## conv params
            nb_conv_layers=3
            Fs = [24]*nb_conv_layers
            Ks = [5]*nb_conv_layers
            ## fc params
            FCs = [len(classes)]
            CHW = (3,32,32)
            net_pert = nn_mdls.GBoixNet(CHW,Fs,Ks,FCs,do_bn).to(device)
        print('Musketeers are prepared')
        ''' optimizer + criterion stuff '''
        optimizer = optim.SGD(net_pert.parameters(), lr=lr, momentum=momentum)
        #optimizer = optim.Adam(net_pert.parameters(), lr=lr)
        error_criterion = metrics.error_criterion
        criterion = torch.nn.CrossEntropyLoss()
        #criterion = torch.nn.MultiMarginLoss()
        #criterion = torch.nn.MultiLabelMarginLoss()
        ''' Landscape Inspector '''
        save_all_learning_curves = True
        save_all_perts = False
        nb_lambdas = 1
        lambdas = np.linspace(1,10,nb_lambdas)
        print('Do Sharpness expt!')
        sharpness_inspector = LandscapeInspector(net_original,net_pert, nb_epochs,iterations, trainloader,testloader, optimizer,
            criterion,error_criterion, device, lambdas,save_all_learning_curves=save_all_learning_curves,save_all_perts=save_all_perts)
        sharpness_inspector.do_sharpness_experiment()
    elif args.train_alg == 'flatness_bs':
        ''' BS params '''
        r_initial = 50
        epsilon = args.epsilon ## check if this number is good
        # nb_radius_samples = nb_epochs could be used as a cap on the number of BS iterations
        expt_path = os.path.join(expt_path+f'_BS')
        ''' Do BS '''
        precision = 0.001
        nb_dirs = args.nb_dirs
        # stats_collector = StatsCollector(net,nb_dirs,nb_epochs) TODO
        rand_inspector = RandLandscapeInspector(epsilon,net,r_initial,device,criterion,error_criterion,trainloader,testloader,iterations)
        rand_inspector.get_faltness_radii_for_isotropic_directions(nb_dirs=nb_dirs,precision=precision)
        other_stats = dict({'nb_dirs':nb_dirs,'flatness_radii':rand_inspector.flatness_radii},**other_stats)
    elif args.train_alg == 'evaluate_nets':
        plot = False
        print('')
        iterations = inf
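        ## note: W_nl, W_rlnl, net_nl and net_rlnl are only defined if the 'divide_constant' branch ran above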
        print(f'W_nl = {W_nl}')
        print(f'W_rlnl = {W_rlnl}')
        ''' train errors '''
        loss_nl_train, error_nl_train = evalaute_mdl_data_set(criterion, error_criterion, net_nl, trainloader, device, iterations)
        loss_rlnl_train, error_rlnl_train = evalaute_mdl_data_set(criterion,error_criterion,net_rlnl,trainloader,device,iterations)
        print(f'loss_nl_train, error_nl_train = {loss_nl_train, error_nl_train}')
        print(f'loss_rlnl_train, error_rlnl_train = {loss_rlnl_train, error_rlnl_train}')
        ''' test errors '''
        loss_nl_test, error_nl_test = evalaute_mdl_data_set(criterion, error_criterion, net_nl, testloader, device, iterations)
        loss_rlnl_test, error_rlnl_test = evalaute_mdl_data_set(criterion,error_criterion,net_rlnl,testloader,device,iterations)
        print(f'loss_nl_test, error_nl_test = {loss_nl_test, error_nl_test}')
        print(f'loss_rlnl_test, error_rlnl_test = {loss_rlnl_test, error_rlnl_test}')
        ''' '''
        store_results = False
        store_net = False
    # elif args.train_alg == 'reach_target_loss':
    #     iterations = inf
    #     precision = 0.00001
    #     ''' set target loss '''
    #     loss_rlnl_train, error_rlnl_train = evalaute_mdl_data_set(criterion, error_criterion, net_rlnl, trainloader,device, iterations)
    #     target_train_loss = loss_rlnl_train
    #     ''' do SGD '''
    #     trainer = Trainer(trainloader,testloader, optimizer,criterion,error_criterion, stats_collector, device)
    #     last_errors = trainer.train_and_track_stats(net,nb_epochs,iterations=iterations,target_train_loss=target_train_loss,precision=precision)
    #     ''' Test the Network on the test data '''
    #     train_loss_epoch, train_error_epoch, test_loss_epoch, test_error_epoch = last_errors
    #     print(f'train_loss_epoch={train_loss_epoch} train_error_epoch={train_error_epoch}')
    #     print(f'test_loss_epoch={test_loss_epoch} test_error_epoch={test_error_epoch}')
    #     st()
    elif args.train_alg == 'no_train':
        print('NO TRAIN BRANCH')
    print(f'expt_path={expt_path}')
    utils.make_and_check_dir(expt_path)
    ''' save times '''
    seconds_training, minutes_training, hours_training = utils.report_times(training_time,meta_str='training')
    other_stats = dict({'seconds_training': seconds_training, 'minutes_training': minutes_training, 'hours_training': hours_training}, **other_stats)
    seconds, minutes, hours = seconds_training+seconds_setup, minutes_training+minutes_setup, hours_training+hours_setup
    other_stats = dict({'seconds':seconds,'minutes':minutes,'hours':hours}, **other_stats)
    print(f'nb_epochs = {nb_epochs}')
    print(f'Finished Training, hours={hours}')
    print(f'seed = {seed}, githash = {githash}')
    ''' save results from experiment '''
    store_results = not args.dont_save_expt_results
    print(f'ALL other_stats={other_stats}')
    if store_results:
        print(f'storing results!')
        matlab_path_to_filename = os.path.join(expt_path,matlab_file_name)
        save2matlab.save2matlab_flatness_expt(matlab_path_to_filename, stats_collector,other_stats=other_stats)
    ''' save net model '''
    if store_net:
        print(f'saving final net mdl!')
        net_path_to_filename = os.path.join(expt_path,net_file_name)
        torch.save(net,net_path_to_filename)
        ''' check the error of net saved '''
        loss_original, error_original = evalaute_mdl_data_set(criterion, error_criterion, net, trainloader,device)
        restored_net = utils.restore_entire_mdl(net_path_to_filename)
        loss_restored,error_restored = evalaute_mdl_data_set(criterion,error_criterion,restored_net,trainloader,device)
        print()
        print(f'net_path_to_filename = {net_path_to_filename}')
        print(f'loss_original={loss_original},error_original={error_original}\a')
        print(f'loss_restored={loss_restored},error_restored={error_restored}\a')
    ''' send e-mail '''
    if hostname == 'polestar' or args.email:
        message = f'SLURM Job_id=MANUAL Name=flatness_expts.py Ended, ' \
                  f'Total Run time hours:{hours},minutes:{minutes},seconds:{seconds} COMPLETED, ExitCode [0-0]'
        utils.send_email(message,destination='*****@*****.**')
    ''' plot '''
    if sj == 0 and plot:
        #TODO
        plot_utils.plot_loss_and_accuracies(stats_collector)
        plt.show()
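Many branches above call evalaute_mdl_data_set(criterion, error_criterion, net, loader, device[, iterations]). A minimal sketch of what such an evaluator plausibly computes, assuming error_criterion returns a per-batch scalar error (the repo's actual implementation may differ):
import math

import torch

def evaluate_data_set_sketch(criterion, error_criterion, net, dataloader, device, iterations=math.inf):
    net.eval()
    total_loss, total_error, nb_batches = 0.0, 0.0, 0
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloader):
            if i >= iterations:  # cap the number of mini-batches evaluated
                break
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            total_loss += criterion(outputs, labels).item()
            total_error += float(error_criterion(outputs, labels))
            nb_batches += 1
    return total_loss / nb_batches, total_error / nb_batches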
def main(plot):
    start_time = time.time()
    ''' Directory names '''
    path = '../pytorch_experiments/test_runs_flatness/keras_expt_April_19'
    filename = f'chance_plateau_{satid}'
    utils.make_and_check_dir(path)
    ''' experiment type '''
    #expt = 'BoixNet'
    #expt = 'LiaoNet'
    expt = 'GBoixNet'
    #expt = 'debug'
    ''' declare variables '''
    batch_size = 2**10
    num_classes = 10
    nb_epochs = 300
    lr = 0.1
    ''' load cifar '''
    standardize = True
    (x_train, y_train), (x_test, y_test) = load_cifar10(num_classes, standardize=standardize)
    print(x_train[0, :])
    ''' params for cnn'''
    print(f'expt = {expt}')
    if expt == 'BoixNet':
        nb_conv_layers = 2
        nb_conv_filters = [32] * nb_conv_layers
        kernels = [(5, 5)] * nb_conv_layers
        nb_fc_layers = 3
        nb_units_fcs = [512, 256, num_classes]
    elif expt == 'LiaoNet':
        nb_conv_layers, nb_fc_layers = 3, 1
        nb_conv_filters = [32] * nb_conv_layers
        kernels = [(5, 5)] * nb_conv_layers
        nb_units_fcs = [num_classes]
    elif expt == 'GBoixNet':
        cnn_filename = f'keras_net_{satid}'
        nb_conv_layers, nb_fc_layers = 1, 2
        nb_conv_filters = [22] * nb_conv_layers
        kernels = [(5, 5)] * nb_conv_layers
        nb_units_fcs = [30, num_classes]
    elif expt == 'debug':
        nb_conv_layers, nb_fc_layers = 1, 1
        nb_conv_filters = [2] * nb_conv_layers
        kernels = [(10, 10)] * nb_conv_layers
        nb_units_fcs = [2, num_classes]
    CHW = x_train.shape[1:]  # (3, 32, 32)
    ''' get model '''
    cnn_n = model_convs_FCs(CHW, nb_conv_layers, nb_fc_layers, nb_conv_filters,
                            kernels, nb_units_fcs)
    cnn_n.summary()
    compile_mdl_with_sgd(cnn_n, lr, weight_decay=0, momentum=0, nesterov=False)
    ''' Fit model '''
    cnn = cnn_n.fit(x_train,
                    y_train,
                    batch_size=batch_size,
                    epochs=nb_epochs,
                    validation_data=(x_test, y_test),
                    shuffle=True)
    seconds, minutes, hours = utils.report_times(start_time)
    print(f'\nFinished Training, hours={hours}\a')
    ''' save history and mdl '''
    path_2_save = os.path.join(path, filename)
    print(f'path_2_save = {path_2_save}')
    print(f'does dir exist? {os.path.isdir(path)}')
    # save history
    with open(path_2_save, 'wb+') as file_pi:
        history = dict(
            {
                'batch_size': batch_size,
                'nb_epochs': nb_epochs,
                'lr': lr,
                'expt': expt
            }, **cnn.history)
        pickle.dump(history, file_pi)
    # save model
    cnn_n.save(os.path.join(path, cnn_filename))
    ''' Plots '''
    if plot:
        # Plots for training and testing process: loss and accuracy
        plt.figure(0)
        plt.plot(cnn.history['acc'], 'r')
        plt.plot(cnn.history['val_acc'], 'g')
        plt.xticks(np.arange(0, nb_epochs + 1, 2.0))
        plt.rcParams['figure.figsize'] = (8, 6)
        plt.xlabel("Num of Epochs")
        plt.ylabel("Accuracy")
        plt.title("Training Accuracy vs Validation Accuracy")
        plt.legend(['train', 'validation'])

        plt.figure(1)
        plt.plot(cnn.history['loss'], 'r')
        plt.plot(cnn.history['val_loss'], 'g')
        plt.xticks(np.arange(0, nb_epochs + 1, 2.0))
        plt.rcParams['figure.figsize'] = (8, 6)
        plt.xlabel("Num of Epochs")
        plt.ylabel("Loss")
        plt.title("Training Loss vs Validation Loss")
        plt.legend(['train', 'validation'])

        plt.show()
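compile_mdl_with_sgd above is a repo helper; the 'acc'/'val_acc' history keys suggest the Keras 2 API, under which its body is plausibly just (a sketch; mapping weight_decay onto SGD's per-update lr decay is an assumption):
from keras import optimizers

def compile_mdl_with_sgd_sketch(model, lr, weight_decay=0, momentum=0, nesterov=False):
    sgd = optimizers.SGD(lr=lr, decay=weight_decay, momentum=momentum, nesterov=nesterov)
    model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])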
def main(plot=False):
    ''' date parameters setup'''
    today_obj = date.today() # contains datetime.date(year, month, day); accessible via .day etc
    day = today_obj.day
    month = calendar.month_name[today_obj.month]
    start_time = time.time()
    ''' '''
    label_corrupt_prob = 0
    results_root = './test_runs_flatness'
    expt_path = f'flatness_label_corrupt_prob_{label_corrupt_prob}_debug2'
    matlab_file_name = f'flatness_{day}_{month}'
    ''' '''
    nb_epochs = 200
    batch_size = 256
    #batch_size_train,batch_size_test = batch_size,batch_size
    batch_size_train = batch_size
    batch_size_test = 256
    data_path = './data'
    num_workers = 2 # how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
    ''' get data set '''
    standardize = True
    trainset,trainloader, testset,testloader, classes = data_class.get_cifer_data_processors(data_path,batch_size_train,batch_size_test,num_workers,label_corrupt_prob,standardize=standardize)
    ''' get NN '''
    mdl = 'cifar_10_tutorial_net'
    mdl = 'BoixNet'
    mdl = 'LiaoNet' # successive assignments: the last one wins
    ##
    print(f'model = {mdl}')
    if mdl == 'cifar_10_tutorial_net':
        do_bn = False
        net = nn_mdls.Net()
    elif mdl == 'BoixNet':
        do_bn=False
        ## conv params
        nb_filters1,nb_filters2 = 32, 32
        kernel_size1,kernel_size2 = 5,5
        ## fc params
        nb_units_fc1,nb_units_fc2,nb_units_fc3 = 512,256,len(classes)
        C,H,W = 3,32,32
        net = nn_mdls.BoixNet(C,H,W,nb_filters1,nb_filters2, kernel_size1,kernel_size2, nb_units_fc1,nb_units_fc2,nb_units_fc3,do_bn)
    elif mdl == 'LiaoNet':
        do_bn=False
        nb_conv_layers=3
        ## conv params
        Fs = [32]*nb_conv_layers
        Ks = [5]*nb_conv_layers
        ## fc params
        FC = len(classes)
        C,H,W = 3,32,32
        net = nn_mdls.LiaoNet(C,H,W,Fs,Ks,FC,do_bn)
    # elif mdl == 'MMNISTNet':
    #     net = MMNISTNet()
    if args.enable_cuda:
        net.cuda()
    nb_params = nn_mdls.count_nb_params(net)
    ''' Cross Entropy + Optmizer'''
    lr = 0.01
    momentum = 0.0
    #error_criterion = metrics.error_criterion
    error_criterion = metrics.error_criterion2
    #criterion = torch.nn.CrossEntropyLoss()
    criterion = torch.nn.MultiMarginLoss()
    #loss = torch.nn.MSELoss(size_average=True)
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    ''' stats collector '''
    stats_collector = tr_alg.StatsCollector(net,None,None)
    #dynamic_stats = ['final_act',(np.zeros((10, nb_epochs)), tr_alg.)]  # unfinished: the tr_alg collector callable was never specified; unused below
    ''' Train the Network '''
    print(f'----\nSTART training: label_corrupt_prob={label_corrupt_prob},nb_epochs={nb_epochs},batch_size={batch_size},lr={lr},mdl={mdl},batch-norm={do_bn},nb_params={nb_params}')
    overparametrized = len(trainset)<nb_params # N < W ?
    print(f'Model over parametrized? N, W = {len(trainset)} vs {nb_params}')
    print(f'Model over parametrized? N < W = {overparametrized}')
    # We simply have to loop over our data iterator, and feed the inputs to the network and optimize.
    #tr_alg.train_cifar(args, nb_epochs, trainloader,testloader, net,optimizer,criterion)
    train_loss_epoch, train_error_epoch, test_loss_epoch, test_error_epoch = tr_alg.train_and_track_stats2(args, nb_epochs, trainloader,testloader, net,optimizer,criterion,error_criterion, stats_collector)
    seconds,minutes,hours = utils.report_times(start_time)
    print(f'Finished Training, hours={hours}')
    ''' Test the Network on the test data '''
    print(f'train_loss_epoch={train_loss_epoch} \ntrain_error_epoch={train_error_epoch} \ntest_loss_epoch={test_loss_epoch} \ntest_error_epoch={test_error_epoch}')
    ''' save results from experiment '''
    other_stats = {'nb_epochs':nb_epochs,'batch_size':batch_size,'mdl':mdl,'lr':lr,'momentum':momentum}
    save2matlab.save2matlab_flatness_expt(results_root,expt_path,matlab_file_name, stats_collector,other_stats=other_stats)
    ''' save net model '''
    path = os.path.join(results_root,expt_path,f'net_{day}_{month}')
    utils.save_entire_mdl(path,net)
    restored_net = utils.restore_entire_mdl(path)
    loss_restored,error_restored = tr_alg.evalaute_mdl_data_set(criterion,error_criterion,restored_net,testloader,args.enable_cuda)
    print(f'\nloss_restored={loss_restored},error_restored={error_restored}\a')
    ''' plot '''
    if plot:
        #TODO
        plot_utils.plot_loss_and_accuracies(stats_collector)
        plt.show()
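utils.save_entire_mdl / utils.restore_entire_mdl are used interchangeably with torch.save / torch.load elsewhere in these examples (e.g. the commented-out restore next to torch.load in the pre-trained branch), so a faithful sketch is simply whole-module pickling:
import torch

def save_entire_mdl(path, net):
    # sketch of the repo helper, assuming whole-module (pickle) serialization
    torch.save(net, path)

def restore_entire_mdl(path):
    # counterpart; torch.load(path) is used as a drop-in replacement above
    return torch.load(path)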
Example #6
    target_loss = 0.0044
    normalizer = Normalizer(list_names,
                            data_path,
                            normalization_scheme,
                            p,
                            division_constant,
                            data_set_type,
                            type_standardize=type_standardize)
    #results = normalizer.extract_all_results_vs_test_errors(path_all_expts, target_loss)
    results = normalizer.get_hist_from_single_net(path_all_expts)
    ''' '''
    path = os.path.join(
        path_all_expts,
        f'{RL_str}loss_vs_gen_errors_norm_{norm}_data_set_{data_set_type}')
    #path = os.path.join(path_all_expts, f'RL_corruption_1.0_loss_vs_gen_errors_norm_{norm}')
    #path = os.path.join(path_all_expts,f'loss_vs_gen_errors_norm_{norm}_final')
    scipy.io.savemat(path, results)
    ''' plot '''
    #plt.scatter(train_all_losses,gen_all_errors)
    #plt.show()
    ''' cry '''
    print('\a')


if __name__ == '__main__':
    time_taken = time.time()
    main()
    seconds, minutes, hours = utils.report_times(time_taken)
    print('end')
    print('\a')
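scipy.io.savemat, as used above, writes a dict of arrays to a MATLAB .mat file. A self-contained illustration (the key names are hypothetical, echoing the commented-out plt.scatter):
import numpy as np
import scipy.io

results = {'train_all_losses': np.array([0.10, 0.05]),
           'gen_all_errors': np.array([0.20, 0.18])}
scipy.io.savemat('./loss_vs_gen_errors_example.mat', results)
loaded = scipy.io.loadmat('./loss_vs_gen_errors_example.mat')
print(loaded['gen_all_errors'])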
def main(plotting=False,save=False):
    ''' setup'''
    start_time = time.time()
    np.set_printoptions(suppress=True) # suppress scientific notation when printing small floats
    ##dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU
    ''' pytorch dtype setup '''
    # dtype_y = torch.LongTensor
    dtype_x = torch.FloatTensor
    dtype_y = torch.FloatTensor
    # dtype_x = torch.cuda.FloatTensor
    # dtype_y = torch.cuda.FloatTensor
    ''' date parameters setup'''
    today_obj = date.today() # contains datetime.date(year, month, day); accessible via .day etc
    day = today_obj.day
    month = calendar.month_name[today_obj.month]
    ''' Model to train setup param '''
    #MDL_2_TRAIN='logistic_regression_vec_mdl'
    #MDL_2_TRAIN='logistic_regression_poly_mdl'
    MDL_2_TRAIN = 'regression_poly_mdl'
    #MDL_2_TRAIN = 'HBF'
    ''' data file names '''
    truth_filename=''
    data_filename=''
    #data_filename = 'classification_manual'
    data_filename = 'regression_manual'
    ''' Folder for experiment '''
    experiment_name = 'RedoFig5_Cheby'
    ##########
    ''' Regularization '''
    ##
    #reg_type = 'VW'
    #reg_type = 'V2W_D3'
    reg_type = ''
    reg = 0
    ''' Experiment LAMBDA experiment params '''
    # expt_type = 'LAMBDAS'
    # N_lambdas = 50
    # lb,ub = 0.01,10000
    # one_over_lambdas = np.linspace(lb,ub,N_lambdas)
    # lambdas = list( 1/one_over_lambdas )
    # lambdas = N_lambdas*[0.0]
    # nb_iterations = [int(1.4*10**6)]
    # repetitions = len(lambdas)*[15]
    ''' Experiment ITERATIONS experiment params '''
    # expt_type = 'ITERATIONS'
    # N_iterations = 30
    # lb,ub = 1,60*10**4
    # lambdas = [0]
    # nb_iterations = [ int(i) for i in np.linspace(lb,ub,N_iterations)]
    # repetitions = len(nb_iterations)*[10]
    ''' Experiment DEGREE/MONOMIALS '''
    expt_type='DEGREES'
    step_deg=1
    lb_deg,ub_deg = 39,39
    degrees = list(range(lb_deg,ub_deg+1,step_deg))
    lambdas = [0]
    #nb_iterations = [int(2500000)]
    #nb_iterations = [int(1000000)]
    #nb_iterations = [int(5 * 10**6)]
    #nb_iterations = [int(1.1 * 10 ** 7)]
    repetitions = len(degrees)*[30]
    ''' Experiment Number of vector elements'''
    expt_type='NB_VEC_ELEMENTS' # overrides the DEGREES setup above; the last expt_type assignment wins
    step=1
    lb_vec,ub_vec = 30,30
    nb_elements_vecs = list(range(lb_vec,ub_vec+1,step))
    lambdas = [0]
    nb_iterations = [int(250000)]
    #nb_iterations = [int(2500)]
    repetitions = len(nb_elements_vecs)*[1]
    ''' Get setup for process to run '''
    ps_params = NamedDict() # process params
    if expt_type == 'LAMBDAS':
        ps_params.degrees=[]
        ps_params.reg_lambda = dispatcher_code.get_hp_to_run(hyper_params=lambdas,repetitions=repetitions,satid=satid)
        ps_params.nb_iter = nb_iterations[0]
        #ps_params.prefix_experiment = f'it_{nb_iter}/lambda_{reg_lambda}_reg_{reg_type}'
    elif expt_type == 'ITERATIONS':
        ps_params.degrees=[]
        ps_params.reg_lambda = lambdas[0]
        ps_params.nb_iter = dispatcher_code.get_hp_to_run(hyper_params=nb_iterations,repetitions=repetitions,satid=satid)
        #ps_params.prefix_experiment = f'lambda_{reg_lambda}/it_{nb_iter}_reg_{reg_type}'
    elif expt_type == 'DEGREES':
        ps_params.reg_lambda = lambdas[0]
        ps_params.degree_mdl = dispatcher_code.get_hp_to_run(hyper_params=degrees,repetitions=repetitions,satid=satid)
        #ps_params.prefix_experiment = f'fig4_expt_lambda_{reg_lambda}_it_{nb_iter}/deg_{Degree_mdl}'
        hp_param = ps_params.degree_mdl
    elif expt_type == 'NB_VEC_ELEMENTS':
        ps_params.reg_lambda = lambdas[0]
        ps_params.nb_elements_vec = dispatcher_code.get_hp_to_run(hyper_params=nb_elements_vecs,repetitions=repetitions,satid=satid)
        ps_params.nb_iter = nb_iterations[0]
        #ps_params.prefix_experiment = f'it_{ps_params.nb_iter}/lambda_{ps_params.reg_lambda}_reg_{reg_type}'
    else:
        raise ValueError(f'Experiment type expt_type={expt_type} does not exist, try a different expt_type.')
    print(f'ps_params={ps_params}')
    ######## data set
    ''' Get data set'''
    if data_filename == 'classification_manual':
        N_train,N_val,N_test = 81,100,500
        lb,ub = -1,1
        w_target = np.array([1,1])
        f_target = lambda x: np.int64( (np.dot(w_target,x) > 0).astype(int) )
        Xtr,Ytr, Xv,Yv, Xt,Yt = data_class.get_2D_classification_data(N_train,N_val,N_test,lb,ub,f_target)
    elif data_filename == 'regression_manual':
        N_train,N_val,N_test = 9,81,100
        lb,ub = -1,1
        f_target = lambda x: np.sin(2*np.pi*4*x)
        Xtr,Ytr, Xv,Yv, Xt,Yt = data_reg.get_2D_regression_data_chebyshev_nodes(N_train,N_val,N_test,lb,ub,f_target)
    else:
        data = np.load( './data/{}'.format(data_filename) )
        if 'lb' in data and 'ub' in data:
            data_lb, data_ub = data['lb'],data['ub']
        else:
            raise ValueError('Error, go to code and fix lb and ub')
    N_train,N_test = Xtr.shape[0], Xt.shape[0]
    print(f'N_train={N_train}, N_test={N_test}')
    ########
    ''' SGD params '''
    #optimizer_mode = 'SGD_AND_PERTURB'
    optimizer_mode = 'SGD_train_then_pert'
    M = int(Xtr.shape[0])
    #M = int(81)
    eta = 0.2
    momentum = 0.0
    nb_iter = nb_iterations[0]
    A = 0.0
    ##
    logging_freq = 1
    ''' MODEL '''
    if MDL_2_TRAIN=='logistic_regression_vec_mdl':
        in_features=31
        n_classes=1
        bias=False
        mdl = mdl_lreg.get_logistic_regression_mdl(in_features,n_classes,bias)
        loss = torch.nn.CrossEntropyLoss(size_average=True)
        ''' stats collector '''
        loss_collector = lambda mdl,X,Y: calc_loss(mdl,loss,X,Y)
        acc_collector = calc_accuracy
        acc_collector = calc_error # use error rather than accuracy; the last assignment wins
        stats_collector = tr_alg.StatsCollector(mdl, loss_collector,acc_collector)
        ''' make features for data '''
        poly = PolynomialFeatures(in_features-1)
        Xtr,Xv,Xt = poly.fit_transform(Xtr), poly.fit_transform(Xv), poly.fit_transform(Xt)
    elif MDL_2_TRAIN == 'regression_poly_mdl':
        in_features = degrees[0]+1
        mdl = mdl_lreg.get_logistic_regression_mdl(in_features, 1, bias=False)
        loss = torch.nn.MSELoss(size_average=True)
        ''' stats collector '''
        loss_collector = lambda mdl, X, Y: calc_loss(mdl, loss, X, Y)
        acc_collector = loss_collector
        stats_collector = tr_alg.StatsCollector(mdl, loss_collector, acc_collector)
        ''' make features for data '''
        poly = PolynomialFeatures(in_features - 1)
        Xtr, Xv, Xt = poly.fit_transform(Xtr), poly.fit_transform(Xv), poly.fit_transform(Xt)
    elif MDL_2_TRAIN=='HBF':
        bias=True
        D_in, D_out = Xtr.shape[0], Ytr.shape[1]
        ## RBF
        std = (Xtr[1] - Xtr[0])/ 0.8 # less than half the avg distance #TODO use np.mean
        centers=Xtr
        mdl = hkm.OneLayerHBF(D_in,D_out, centers=centers,std=std, train_centers=False,train_std=False)
        mdl[0].weight.data.fill_(0)
        mdl[0].bias.data.fill_(0)
        loss = torch.nn.MSELoss(size_average=True)
        ''' stats collector '''
        loss_collector = lambda mdl,X,Y: tr_alg.calc_loss(mdl,loss,X,Y)
        acc_collector = loss_collector
        ''' dynamic stats collector '''
        c_pinv = hkm.get_rbf_coefficients(X=Xtr,centers=Xtr,Y=Ytr,std=std)
        def diff_GD_vs_PINV(storer, i, mdl, Xtr,Ytr,Xv,Yv,Xt,Yt):
            c_pinv_torch = torch.FloatTensor( c_pinv )
            diff_GD_pinv = (mdl.C.weight.data.t() - c_pinv_torch).norm(2)
            storer.append(diff_GD_pinv)
        dynamic_stats = NamedDict(diff_GD_vs_PINV=([],diff_GD_vs_PINV))
        ##
        stats_collector = tr_alg.StatsCollector(mdl, loss_collector,acc_collector,dynamic_stats=dynamic_stats)
    else:
        raise ValueError(f'MDL_2_TRAIN={MDL_2_TRAIN}')
    ''' TRAIN '''
    perturbfreq = 1.1 * 10**5
    perturb_magnitude = 0.45
    if optimizer_mode =='SGD_AND_PERTURB':
        ##
        momentum = 0.0
        optim = torch.optim.SGD(mdl.parameters(), lr=eta, momentum=momentum)
        ##
        reg_lambda = ps_params.reg_lambda
        tr_alg.SGD_perturb(mdl, Xtr,Ytr,Xv,Yv,Xt,Yt, optim,loss, M,eta,nb_iter,A ,logging_freq,
            dtype_x,dtype_y, perturbfreq,perturb_magnitude,
            reg=reg,reg_lambda=reg_lambda,
            stats_collector=stats_collector)
    elif optimizer_mode == 'SGD_train_then_pert':
        iterations_switch_mode = 1 # never perturb
        #iterations_switch_mode = nb_iter # always perturb
        iterations_switch_mode = nb_iter/2 # perturb for half the run (the last assignment wins)
        print(f'iterations_switch_mode={iterations_switch_mode}')
        ##
        optimizer = torch.optim.SGD(mdl.parameters(), lr=eta, momentum=momentum)
        ##
        reg_lambda = ps_params.reg_lambda
        tr_alg.SGD_pert_then_train(mdl, Xtr,Ytr,Xv,Yv,Xt,Yt, optimizer,loss, M,nb_iter ,logging_freq ,dtype_x,dtype_y,
                                   perturbfreq,perturb_magnitude, iterations_switch_mode, reg,reg_lambda, stats_collector)
    else:
        raise ValueError(f'MDL_2_TRAIN={MDL_2_TRAIN} not implemented')
    seconds,minutes,hours = utils.report_times(start_time)
    ''' Plots and Print statements'''
    print('\n----\a\a')
    print(f'some SGD params: batch_size={M}, eta={eta}, nb_iterations={nb_iter}')
    if save:
        ''' save experiment results to maltab '''
        experiment_results=stats_collector.get_stats_dict()
        experiment_results=NamedDict(seconds=seconds,minutes=minutes,hours=hours,**experiment_results)
        save2matlab.save_experiment_results_2_matlab(experiment_results=experiment_results,
            root_path=f'./test_runs_flatness3',
            experiment_name=experiment_name,
            training_config_name=f'nb_iterations_{nb_iterations[0]}_N_train_{Xtr.shape[0]}_N_test_{Xt.shape[0]}_batch_size_{M}_perturb_freq_{perturbfreq}_perturb_magnitude_{perturb_magnitude}_momentum_{momentum}_iterations_switch_mode_{iterations_switch_mode}',
            main_experiment_params=f'{expt_type}_lambda_{ps_params.reg_lambda}_it_{nb_iter}_reg_{reg_type}',
            expt_type=f'expt_type_{expt_type}_{hp_param}', # NB: hp_param is only set in the DEGREES branch
            matlab_file_name=f'satid_{satid}_sid_{sj}_{month}_{day}'
            )
    if MDL_2_TRAIN=='HBF':
        ''' print statements R/HBF'''
        print(f'distance_btw_data_points={Xtr[1] - Xtr[0]}')
        print(f'std={std}')
        print(f'less than half the average distance?={(std < (Xtr[1] - Xtr[0])/2)}')
        beta = (1.0/std)**2
        rank = np.linalg.matrix_rank( np.exp( -beta*hkm.euclidean_distances_manual(x=Xtr,W=centers.T) ) )
        print(f'rank of Kernel matrix = Rank(K) = {rank}')
        ''' plots for R/HBF'''
        f_mdl = lambda x: mdl( Variable(torch.FloatTensor(x),requires_grad=False) ).data.numpy()
        f_pinv = lambda x: hkm.f_rbf(x,c=c_pinv,centers=Xtr,std=std)
        iterations = np.array(range(0,nb_iter))
        N_denseness = 1000
        legend_hyper_params=f'N_train={Xtr.shape[0]},N_test={Xt.shape[0]},batch-size={M},learning step={eta},# iterations = {nb_iter} momentum={momentum}, Model=Gaussian, # centers={centers.shape[0]}, std={std[0]}'
        ''' PLOT '''
        ## plots
        plot_utils.plot_loss_errors(iterations,stats_collector,test_error_pinv=data_utils.l2_np_loss(f_pinv(Xt),Yt),legend_hyper_params=legend_hyper_params)
        plot_utils.visualize_1D_reconstruction(lb,ub,N_denseness, f_mdl,f_target=f_target,f_pinv=f_pinv,X=Xtr,Y=Ytr,legend_data_set='Training data points')
        plot_utils.plot_sgd_vs_pinv_distance_during_training(iterations,stats_collector)
        #plot_utils.print_gd_vs_pinv_params(mdl,c_pinv)
        plt.show()
    elif MDL_2_TRAIN=='logistic_regression_vec_mdl':
        ''' arguments for plotting things '''
        f_mdl = lambda x: mdl( Variable(torch.FloatTensor(x),requires_grad=False) ).data.numpy()
        f_target = lambda x: -1*(w_target[0]/w_target[1])*x
        iterations = np.array(range(0,nb_iter))
        N_denseness = 1000
        legend_hyper_params=f'N_train={Xtr.shape[0]},N_test={Xt.shape[0]},batch-size={M},learning step={eta},# iterations = {nb_iter} momentum={momentum}, Model=Logistic Regression'
        ''' PLOT '''
        ## plots
        plot_utils.plot_loss_errors(iterations,stats_collector,legend_hyper_params=legend_hyper_params)
        plot_utils.plot_loss_classification_errors(iterations,stats_collector,legend_hyper_params=legend_hyper_params)
        plot_utils.visualize_classification_data_learned_planes_2D(lb,ub,N_denseness,Xtr,Ytr,f_mdl,f_target)
        plot_utils.plot_weight_norm_vs_iterations(iterations,stats_collector.w_norms[0])
        plt.show()
    if plotting:
        legend_hyper_params = f'N_train={Xtr.shape[0]},N_test={Xt.shape[0]},batch-size={M},learning step={eta},# iterations = {nb_iter} momentum={momentum}, Model=Regression'
        iterations = np.array(range(0, nb_iter))
        plot_utils.plot_loss_errors(iterations, stats_collector, legend_hyper_params=legend_hyper_params)
        plot_utils.plot_weight_norm_vs_iterations(iterations, stats_collector.w_norms[0])
        plt.show()
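For reference, the PolynomialFeatures expansion used in the 'regression_poly_mdl' branch turns a 1-D input into monomial features; a small self-contained run:
import numpy as np
from sklearn.preprocessing import PolynomialFeatures

Xtr = np.linspace(-1, 1, 9).reshape(-1, 1)  # 9 training inputs, as in the 'regression_manual' branch
poly = PolynomialFeatures(degree=3)         # the code above uses degree = in_features - 1
Xtr_feat = poly.fit_transform(Xtr)          # columns: [1, x, x^2, x^3]
print(Xtr_feat.shape)                       # (9, 4)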