Example #1
def train_with_vgg(lr,
                   route_iter,
                   train_file_pre,
                   test_file_pre,
                   out_dir_pre,
                   n_classes,
                   folds=[4, 9],
                   model_name='vgg_capsule_disfa',
                   epoch_stuff=[30, 60],
                   res=False,
                   reconstruct=False,
                   loss_weights=None,
                   exp=False,
                   dropout=0,
                   gpu_id=0,
                   aug_more='flip',
                   model_to_test=None,
                   save_after=10,
                   batch_size=32,
                   batch_size_val=32,
                   criterion='marginmulti'):
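    """Train and test a VGG-capsule model over the given cross-validation folds.

    For each fold this builds the output directory name, assembles the
    augmentation pipeline from `aug_more`, wraps the per-fold train/test
    files in datasets, writes the training parameters to params.txt, and
    runs train_model_recon followed by test_model_recon.
    """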

    # torch.set_default_tensor_type('torch.FloatTensor')

    num_epochs = epoch_stuff[1]

    if model_to_test is None:
        model_to_test = num_epochs - 1

    epoch_start = 0
    if exp:
        dec_after = ['exp', 0.96, epoch_stuff[0], 1e-6]
    else:
        dec_after = ['step', epoch_stuff[0], 0.1]

    im_resize = 256
    im_size = 224
    model_file = None
    margin_params = None

    for split_num in folds:
        out_dir_train = get_out_dir_train_name(out_dir_pre, lr, route_iter,
                                               split_num, epoch_stuff,
                                               reconstruct, exp, dropout,
                                               aug_more)

        print out_dir_train

        final_model_file = os.path.join(out_dir_train,
                                        'model_' + str(num_epochs - 1) + '.pt')
        if os.path.exists(final_model_file):
            # note: the `continue` is disabled, so finished folds are only
            # reported here, not actually skipped
            print 'skipping', final_model_file
        else:
            print 'not skipping', final_model_file

        train_file = train_file_pre + str(split_num) + '.txt'
        test_file = test_file_pre + str(split_num) + '.txt'

        class_weights = util.get_class_weights_au(
            util.readLinesFromFile(train_file))

        mean_std = np.array([[93.5940, 104.7624, 129.1863],
                             [1., 1., 1.]])  # bgr
        std_div = np.array([0.225 * 255, 0.224 * 255, 0.229 * 255])
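        # note (interpretation): these appear to be the standard torchvision
        # stds (0.229, 0.224, 0.225) scaled to [0, 255] and reversed to match
        # the BGR channel order used for mean_std above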
        bgr = True

        list_of_to_dos = aug_more
        print list_of_to_dos

        data_transforms = {}
        train_resize = None
        list_transforms = []
        if 'hs' in list_of_to_dos:
            print '**********HS!!!!!!!'
            list_transforms.append(
                lambda x: augmenters.random_crop(x, im_size))
            list_transforms.append(lambda x: augmenters.hide_and_seek(x))
            if 'flip' in list_of_to_dos:
                list_transforms.append(lambda x: augmenters.horizontal_flip(x))
            list_transforms.append(transforms.ToTensor())
        elif 'flip' in list_of_to_dos and len(list_of_to_dos) == 1:
            train_resize = im_size
            list_transforms.extend([
                lambda x: augmenters.horizontal_flip(x),
                transforms.ToTensor()
            ])
        elif 'none' in list_of_to_dos:
            train_resize = im_size
            list_transforms.append(transforms.ToTensor())
        else:
            list_transforms.append(
                lambda x: augmenters.random_crop(x, im_size))
            list_transforms.append(lambda x: augmenters.augment_image(
                x, list_of_to_dos, color=True, im_size=im_size))
            list_transforms.append(transforms.ToTensor())

        list_transforms_val = [transforms.ToTensor()]

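        # note (interpretation): the CUDA-version check below seems to act as
        # a proxy for the installed torch version; on the newer setup only a
        # float cast is needed, while otherwise ToTensor's [0,1] output is
        # rescaled back to [0,255]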
        if torch.version.cuda.startswith('9.1'):
            list_transforms.append(lambda x: x.float())
        else:
            list_transforms.append(lambda x: x * 255.)

        data_transforms['train'] = transforms.Compose(list_transforms)
        data_transforms['val'] = transforms.Compose(list_transforms_val)

        train_data = dataset.Bp4d_Dataset_with_mean_std_val(
            train_file,
            bgr=bgr,
            binarize=False,
            mean_std=mean_std,
            transform=data_transforms['train'],
            resize=train_resize)
        test_data = dataset.Bp4d_Dataset_with_mean_std_val(
            test_file,
            bgr=bgr,
            binarize=False,
            mean_std=mean_std,
            transform=data_transforms['val'],
            resize=im_size)

        network_params = dict(n_classes=n_classes,
                              pool_type='max',
                              r=route_iter,
                              init=False,
                              class_weights=class_weights,
                              reconstruct=reconstruct,
                              loss_weights=loss_weights,
                              std_div=std_div,
                              dropout=dropout)

        util.makedirs(out_dir_train)

        train_params = dict(out_dir_train=out_dir_train,
                            train_data=train_data,
                            test_data=test_data,
                            batch_size=batch_size,
                            batch_size_val=batch_size_val,
                            num_epochs=num_epochs,
                            save_after=save_after,
                            disp_after=1,
                            plot_after=100,
                            test_after=10,
                            lr=lr,
                            dec_after=dec_after,
                            model_name=model_name,
                            criterion=criterion,
                            gpu_id=gpu_id,
                            num_workers=0,
                            model_file=model_file,
                            epoch_start=epoch_start,
                            margin_params=margin_params,
                            network_params=network_params,
                            weight_decay=0)
        test_params = dict(out_dir_train=out_dir_train,
                           model_num=model_to_test,
                           train_data=train_data,
                           test_data=test_data,
                           gpu_id=gpu_id,
                           model_name=model_name,
                           batch_size_val=batch_size_val,
                           criterion=criterion,
                           margin_params=margin_params,
                           network_params=network_params,
                           post_pend='',
                           barebones=True)

        print train_params
        param_file = os.path.join(out_dir_train, 'params.txt')
        all_lines = []
        for k in train_params.keys():
            str_print = '%s: %s' % (k, train_params[k])
            print str_print
            all_lines.append(str_print)
        util.writeFile(param_file, all_lines)

        train_model_recon(**train_params)
        test_model_recon(**test_params)
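
# Example invocation (illustrative only; the prefixes and values below are
# assumptions about the data layout, not taken from the original source):
# train_with_vgg(lr=[0.001, 0.001],
#                route_iter=3,
#                train_file_pre='../data/bp4d/train_test_files/train_',
#                test_file_pre='../data/bp4d/train_test_files/test_',
#                out_dir_pre='../experiments/vgg_capsule_bp4d',
#                n_classes=12,
#                folds=[0],
#                aug_more=['flip'])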
Example #2
def train_khorrami_aug_mmi(wdecay,
                           lr,
                           route_iter,
                           folds=[4, 9],
                           model_name='vgg_capsule_disfa',
                           epoch_stuff=[30, 60],
                           res=False,
                           class_weights=False,
                           reconstruct=False,
                           oulu=False,
                           meta_data_dir=None,
                           loss_weights=None,
                           exp=False,
                           non_peak=False,
                           model_to_test=None):
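    """Evaluate a capsule model on the MMI per-fold train/test splits.

    Builds the experiment directory name from the hyperparameters, loads the
    per-fold file lists, then runs test_model_recon on the side-view and
    front-view ('_easy') test sets (the training call itself is commented
    out), and finally prints accuracies and loss curves across folds.
    """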
    out_dirs = []

    out_dir_meta = '../experiments/' + model_name + str(route_iter)
    num_epochs = epoch_stuff[1]
    if model_to_test is None:
        model_to_test = num_epochs - 1

    epoch_start = 0
    if exp:
        dec_after = ['exp', 0.96, epoch_stuff[0], 1e-6]
    else:
        dec_after = ['step', epoch_stuff[0], 0.1]

    im_resize = 110
    im_size = 96
    save_after = 10

    type_data = 'train_test_files'
    n_classes = 6
    train_pre = os.path.join('../data/mmi', type_data)
    test_pre = train_pre
    pre_pend = 'mmi_96_'

    criterion = 'margin'
    criterion_str = criterion

    init = False
    strs_append_list = [
        'reconstruct', reconstruct, class_weights, 'all_aug', criterion_str,
        init, 'wdecay', wdecay, num_epochs
    ] + dec_after + lr

    if loss_weights is not None:
        strs_append_list = strs_append_list + ['lossweights'] + loss_weights
    strs_append = '_' + '_'.join([str(val) for val in strs_append_list])
    
    lr_p = lr[:]
    for split_num in folds:

        # note: the original `if res: pass` branch left model_file undefined
        # when res was True; resuming is not implemented here, so always
        # start fresh
        model_file = None

        margin_params = None
        
        out_dir_train = os.path.join(out_dir_meta,
                                     pre_pend + str(split_num) + strs_append)
        final_model_file = os.path.join(out_dir_train,
                                        'model_' + str(num_epochs - 1) + '.pt')
        if os.path.exists(final_model_file):
            print 'skipping', final_model_file
        else:
            print 'not skipping', final_model_file
            raw_input()

        train_file = os.path.join(train_pre,
                                  'train_' + str(split_num) + '.txt')
        test_file_easy = os.path.join(train_pre,
                                      'test_front_' + str(split_num) + '.txt')
        test_file = os.path.join(test_pre,
                                 'test_side_' + str(split_num) + '.txt')
        mean_file = os.path.join(train_pre,
                                 'train_' + str(split_num) + '_mean.png')
        std_file = os.path.join(train_pre,
                                'train_' + str(split_num) + '_std.png')


        class_weights = util.get_class_weights(
            util.readLinesFromFile(train_file))

        # 'pixel_augment' is left out of the augmentation list here
        list_of_to_dos = ['flip', 'rotate', 'scale_translate']
        
        data_transforms = {}
        data_transforms['train'] = transforms.Compose([
            lambda x: augmenters.random_crop(x, im_size),
            lambda x: augmenters.augment_image(x, list_of_to_dos),
            transforms.ToTensor(),
            lambda x: x * 255.
        ])
        data_transforms['val'] = transforms.Compose([
            transforms.ToTensor(),
            lambda x: x * 255.
        ])

        print train_file
        print test_file
        print std_file
        print mean_file

        train_data = dataset.CK_96_Dataset_with_rs(train_file, mean_file,
                                                   std_file,
                                                   data_transforms['train'])
        train_data_no_t = dataset.CK_96_Dataset_with_rs(test_file_easy,
                                                        mean_file, std_file,
                                                        data_transforms['val'],
                                                        resize=im_size)
        test_data = dataset.CK_96_Dataset_with_rs(test_file, mean_file,
                                                  std_file,
                                                  data_transforms['val'],
                                                  resize=im_size)

        network_params = dict(n_classes=n_classes,
                              pool_type='max',
                              r=route_iter,
                              init=init,
                              class_weights=class_weights,
                              reconstruct=reconstruct,
                              loss_weights=loss_weights)

        batch_size = 128
        batch_size_val = 128

        util.makedirs(out_dir_train)
        
        train_params = dict(out_dir_train=out_dir_train,
                            train_data=train_data,
                            test_data=test_data,
                            batch_size=batch_size,
                            batch_size_val=batch_size_val,
                            num_epochs=num_epochs,
                            save_after=save_after,
                            disp_after=1,
                            plot_after=100,
                            test_after=1,
                            lr=lr,
                            dec_after=dec_after,
                            model_name=model_name,
                            criterion=criterion,
                            gpu_id=0,
                            num_workers=2,
                            model_file=model_file,
                            epoch_start=epoch_start,
                            margin_params=margin_params,
                            network_params=network_params,
                            weight_decay=wdecay)
        test_params = dict(out_dir_train=out_dir_train,
                           model_num=model_to_test,
                           train_data=train_data,
                           test_data=test_data,
                           gpu_id=0,
                           model_name=model_name,
                           batch_size_val=batch_size_val,
                           criterion=criterion,
                           margin_params=margin_params,
                           network_params=network_params)
        test_params_train = dict(**test_params)
        test_params_train['test_data'] = train_data_no_t
        test_params_train['post_pend'] = '_easy'


        print train_params
        param_file = os.path.join(out_dir_train, 'params.txt')
        all_lines = []
        for k in train_params.keys():
            str_print = '%s: %s' % (k, train_params[k])
            print str_print
            all_lines.append(str_print)
        util.writeFile(param_file, all_lines)

        # training is disabled here; only evaluation is run
        # train_model_recon(**train_params)
        test_model_recon(**test_params)
        test_model_recon(**test_params_train)

    getting_accuracy.print_accuracy(out_dir_meta, pre_pend, strs_append,
                                    folds, log='log.txt')
    getting_accuracy.view_loss_curves(out_dir_meta, pre_pend, strs_append,
                                      folds, num_epochs - 1)
Example #3
def main():
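    """Overfit a small matrix-capsule network on a single CK+ batch.

    Builds CK+ train/test loaders, constructs a small CapsNet via
    models.get('pytorch_mat_capsules', ...), grabs one batch, and then
    repeatedly trains and tests on that same batch; a sanity-check loop
    rather than a full training run.
    """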
    split_num = 0

    train_file = '../data/ck_96/train_test_files/train_' + str(
        split_num) + '.txt'
    test_file = '../data/ck_96/train_test_files/test_' + str(
        split_num) + '.txt'
    mean_file = '../data/ck_96/train_test_files/train_' + str(
        split_num) + '_mean.png'
    std_file = '../data/ck_96/train_test_files/train_' + str(
        split_num) + '_std.png'

    list_of_to_dos = ['flip', 'rotate']
    mean_im = scipy.misc.imread(mean_file).astype(np.float32)
    std_im = scipy.misc.imread(std_file).astype(np.float32)

    batch_size = 6
    clip = 5
    disable_cuda = False
    gpu = 2
    lr = 0.2
    num_epochs = 10
    disp_after = 1
    r = 1
    use_cuda = True

    batch_size_val = 64
    save_after = 1
    test_after = num_epochs - 1

    plot_after = 10

    lambda_ = 1e-2  #TODO:find a good schedule to increase lambda and m
    m = 0.2

    data_transforms = {}

    data_transforms['train'] = transforms.Compose([
        lambda x: augmenters.random_crop(x, 32),
        lambda x: augmenters.horizontal_flip(x),
        transforms.ToTensor(),
        lambda x: x * 255.
    ])

    data_transforms['val'] = transforms.Compose([
        lambda x: augmenters.crop_center(x, 32, 32),
        transforms.ToTensor(),
        lambda x: x * 255.
    ])

    our_data = True
    train_data = dataset.CK_48_Dataset(train_file, mean_file, std_file,
                                       data_transforms['train'])
    test_data = dataset.CK_48_Dataset(test_file, mean_file, std_file,
                                      data_transforms['val'])

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=0)

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=batch_size_val,
                                              shuffle=False,
                                              num_workers=0)

    steps = len(train_loader.dataset) // batch_size
    print 'steps', steps

    A, B, C, D, E = 32, 8, 16, 16, 8  # a small CapsNet
    import models
    params = dict(A=A, B=B, C=C, D=D, E=E, r=r)
    net = models.get('pytorch_mat_capsules', params)
    model = net.model

    with torch.cuda.device(gpu):
        if use_cuda:
            print("activating cuda")
            model.cuda()

        optimizer = torch.optim.Adam(net.get_lr_list(0.02))

        # grab a single batch; the rest of the script trains and tests on it
        for data in train_loader:
            if our_data:
                imgs = data['image']
                labels = data['label']
            else:
                imgs, labels = data  # b,1,28,28; b

            imgs, labels = Variable(imgs), Variable(labels)
            if use_cuda:
                imgs = imgs.cuda()
                labels = labels.cuda()

            print imgs.size()
            print labels.size()
            break

        for epoch in range(num_epochs):
            m = 0.2
            correct = 0

            optimizer.zero_grad()
            out = model(imgs)
            out_poses, out_labels = out[:, :-8], out[:, -8:]  # b,16*8; b,8
            loss = model.spread_loss(out_labels, labels, m)
            loss.backward()
            # clip gradients after backward and before the optimizer step
            torch.nn.utils.clip_grad_norm(model.parameters(), clip)
            optimizer.step()
            # stats
            pred = out_labels.max(1)[1]  # b
            acc = pred.eq(labels).cpu().sum().data[0]
            correct += acc
            print("epoch:{}, loss:{:.4f}, acc:{}/{}".format(
                epoch, loss.data[0], acc, batch_size))

            acc = correct / float(len(train_loader.dataset))
            if epoch % save_after == 0:
                torch.save(model.state_dict(), "./model_{}.pth".format(epoch))

            # Test (on the same single batch; the test_loader loop is disabled)
            if epoch % test_after == 0:
                print('Testing...')
                correct = 0
                out = model(imgs)
                out_poses, out_labels = out[:, :-8], out[:, -8:]  # b,16*8; b,8
                print labels
                print out_labels

                loss = model.loss(out_labels, labels, m)
                # stats
                pred = out_labels.max(1)[1]  # b
                acc = pred.eq(labels).cpu().sum().data[0]
                correct += acc

                acc = correct / float(len(test_loader.dataset))
                print("Epoch{} Test acc:{:4}".format(epoch, acc))
Example #4
def khorrami_bl_exp(mmi=False, model_to_test=None):
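    """Evaluate the Khorrami baseline on CK+ non-peak or MMI folds.

    Rebuilds the per-fold experiment directory names, loads the matching
    train/test files and mean/std images, runs test_model on both the hard
    and '_easy' test sets (the train_model call is commented out), and
    prints aggregate accuracies and loss curves.
    """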

    out_dir_meta = '../experiments/khorrami_ck_96_caps_bl/'

    num_epochs = 300
    epoch_start = 0
    dec_after = ['exp', 0.96, 350, 1e-6]
    lr = [0.001, 0.001]

    im_size = 96
    model_name = 'khorrami_ck_96'
    save_after = 10
    model_file = None
    if not mmi:
        strs_append = '_'.join([
            str(val) for val in
            ['train_test_files_non_peak_one_third', model_name, num_epochs] +
            dec_after + lr
        ])
        strs_append = '_' + strs_append
        pre_pend = 'ck_'
        folds = range(10)
    else:
        pre_pend = 'mmi_96_'
        folds = range(2)
        strs_append = '_'.join([
            str(val) for val in ['train_test_files', model_name, num_epochs] +
            dec_after + lr
        ])
        strs_append = '_' + strs_append

    if model_to_test is None:
        model_to_test = num_epochs - 1

    for split_num in folds:
        out_dir_train = os.path.join(out_dir_meta,
                                     pre_pend + str(split_num) + strs_append)
        print out_dir_train

        out_file_model = os.path.join(out_dir_train,
                                      'model_' + str(num_epochs - 1) + '.pt')
        if os.path.exists(out_file_model):
            print 'skipping', out_file_model
        else:
            print 'not done', out_file_model
            raw_input()

        if not mmi:
            train_file = '../data/ck_96/train_test_files_non_peak_one_third/train_' + str(
                split_num) + '.txt'
            test_file = '../data/ck_96/train_test_files/test_' + str(
                split_num) + '.txt'
            test_file_easy = '../data/ck_96/train_test_files_non_peak_one_third/test_' + str(
                split_num) + '.txt'
            mean_file = '../data/ck_96/train_test_files_non_peak_one_third/train_' + str(
                split_num) + '_mean.png'
            std_file = '../data/ck_96/train_test_files_non_peak_one_third/train_' + str(
                split_num) + '_std.png'
        else:
            type_data = 'train_test_files'
            n_classes = 6
            train_pre = os.path.join('../data/mmi', type_data)
            test_pre = train_pre
            train_file = os.path.join(train_pre,
                                      'train_' + str(split_num) + '.txt')
            test_file_easy = os.path.join(
                train_pre, 'test_front_' + str(split_num) + '.txt')
            test_file = os.path.join(test_pre,
                                     'test_side_' + str(split_num) + '.txt')
            mean_file = os.path.join(train_pre,
                                     'train_' + str(split_num) + '_mean.png')
            std_file = os.path.join(train_pre,
                                    'train_' + str(split_num) + '_std.png')


        mean_im = scipy.misc.imread(mean_file).astype(np.float32)
        std_im = scipy.misc.imread(std_file).astype(np.float32)
        std_im[std_im == 0] = 1.

        if not mmi:
            list_of_to_dos = [
                'pixel_augment', 'flip', 'rotate', 'scale_translate'
            ]
            data_transforms = {}
            data_transforms['train'] = transforms.Compose([
                lambda x: augmenters.augment_image(x, list_of_to_dos, mean_im,
                                                   std_im, im_size),
                transforms.ToTensor(), lambda x: x * 255.
            ])
            data_transforms['val'] = transforms.Compose(
                [transforms.ToTensor(), lambda x: x * 255.])

            train_data = dataset.CK_96_Dataset(train_file, mean_file, std_file,
                                               data_transforms['train'])
            test_data = dataset.CK_96_Dataset(test_file, mean_file, std_file,
                                              data_transforms['val'])
            test_data_easy = dataset.CK_96_Dataset(test_file_easy, mean_file,
                                                   std_file,
                                                   data_transforms['val'])
        else:
            list_of_to_dos = ['flip', 'rotate', 'scale_translate']
            data_transforms = {}
            data_transforms['train'] = transforms.Compose([
                lambda x: augmenters.random_crop(x, im_size),
                lambda x: augmenters.augment_image(x, list_of_to_dos),
                transforms.ToTensor(), lambda x: x * 255.
            ])
            data_transforms['val'] = transforms.Compose(
                [transforms.ToTensor(), lambda x: x * 255.])

            print train_file
            print test_file
            print std_file
            print mean_file
            # raw_input()

            train_data = dataset.CK_96_Dataset_with_rs(
                train_file, mean_file, std_file, data_transforms['train'])
            test_data_easy = dataset.CK_96_Dataset_with_rs(
                test_file_easy,
                mean_file,
                std_file,
                data_transforms['val'],
                resize=im_size)
            test_data = dataset.CK_96_Dataset_with_rs(test_file,
                                                      mean_file,
                                                      std_file,
                                                      data_transforms['val'],
                                                      resize=im_size)

        network_params = dict(n_classes=8, bn=False)

        batch_size = 128
        batch_size_val = 128

        util.makedirs(out_dir_train)

        train_params = dict(out_dir_train=out_dir_train,
                            train_data=train_data,
                            test_data=test_data,
                            batch_size=batch_size,
                            batch_size_val=batch_size_val,
                            num_epochs=num_epochs,
                            save_after=save_after,
                            disp_after=1,
                            plot_after=10,
                            test_after=1,
                            lr=lr,
                            dec_after=dec_after,
                            model_name=model_name,
                            criterion=nn.CrossEntropyLoss(),
                            gpu_id=1,
                            num_workers=0,
                            model_file=model_file,
                            epoch_start=epoch_start,
                            network_params=network_params)

        test_params = dict(out_dir_train=out_dir_train,
                           model_num=model_to_test,
                           train_data=train_data,
                           test_data=test_data,
                           gpu_id=1,
                           model_name=model_name,
                           batch_size_val=batch_size_val,
                           criterion=nn.CrossEntropyLoss(),
                           margin_params=None,
                           network_params=network_params,
                           post_pend='',
                           model_nums=None)

        test_params_easy = dict(out_dir_train=out_dir_train,
                                model_num=model_to_test,
                                train_data=train_data,
                                test_data=test_data_easy,
                                gpu_id=1,
                                model_name=model_name,
                                batch_size_val=batch_size_val,
                                criterion=nn.CrossEntropyLoss(),
                                margin_params=None,
                                network_params=network_params,
                                post_pend='_easy',
                                model_nums=None)

        print train_params
        param_file = os.path.join(out_dir_train, 'params.txt')
        all_lines = []
        for k in train_params.keys():
            str_print = '%s: %s' % (k, train_params[k])
            print str_print
            all_lines.append(str_print)
        util.writeFile(param_file, all_lines)

        # train_model(**train_params)  # training disabled; evaluation only
        test_model(**test_params)
        test_model(**test_params_easy)

    getting_accuracy.print_accuracy(out_dir_meta,
                                    pre_pend,
                                    strs_append,
                                    folds,
                                    log='log.txt')
    getting_accuracy.view_loss_curves(out_dir_meta, pre_pend, strs_append,
                                      folds, num_epochs - 1)
Example #5
def train_gray(wdecay,
               lr,
               route_iter,
               folds=[4, 9],
               model_name='vgg_capsule_bp4d',
               epoch_stuff=[30, 60],
               res=False,
               class_weights=False,
               reconstruct=False,
               loss_weights=None,
               exp=False,
               disfa=False,
               vgg_base_file=None,
               vgg_base_file_str=None,
               mean_file=None,
               std_file=None,
               aug_more=False,
               align=True):
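    """Train and test a grayscale capsule model on BP4D or DISFA AU data.

    Selects the 110px grayscale (aligned or nodetect) split files, builds an
    augmentation pipeline according to `aug_more` ('cropkhAugNoColor',
    'cropFlip' or 'NONE'), optionally resumes from epoch 10 of a matching
    run, then runs train_model_recon and test_model_recon for each fold and
    prints aggregate accuracy.
    """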
    out_dirs = []

    out_dir_meta = '../experiments/' + model_name + str(route_iter)
    num_epochs = epoch_stuff[1]
    epoch_start = 0
    if exp:
        dec_after = ['exp', 0.96, epoch_stuff[0], 1e-6]
    else:
        dec_after = ['step', epoch_stuff[0], 0.1]

    im_resize = 110
    im_size = 96
    save_after = 1
    if disfa:
        dir_files = '../data/disfa'
        type_data = 'train_test_8_au_all_method_110_gray_align'
        n_classes = 8
        pre_pend = 'disfa_' + type_data + '_'
        binarize = True
    else:
        dir_files = '../data/bp4d'
        if align:
            type_data = 'train_test_files_110_gray_align'
            n_classes = 12
        else:
            type_data = 'train_test_files_110_gray_nodetect'
            n_classes = 12
        pre_pend = 'bp4d_' + type_data + '_'
        binarize = False

    criterion = 'marginmulti'
    criterion_str = criterion

    init = False
    aug_str = aug_more

    strs_append = '_' + '_'.join([
        str(val) for val in [
            'reconstruct', reconstruct, class_weights, aug_str, criterion_str,
            init, 'wdecay', wdecay, num_epochs
        ] + dec_after + lr + ['lossweights'] + loss_weights +
        [vgg_base_file_str]
    ])

    lr_p = lr[:]
    for split_num in folds:

        if res:
            # resume from epoch 10 of a matching earlier run
            strs_appendc = '_' + '_'.join([
                str(val) for val in [
                    'reconstruct', reconstruct, True, aug_str, criterion_str,
                    init, 'wdecay', wdecay, 10
                ] + dec_after + lr + ['lossweights'] + loss_weights +
                [vgg_base_file_str]
            ])

            out_dir_train = os.path.join(
                out_dir_meta, pre_pend + str(split_num) + strs_appendc)
            model_file = os.path.join(out_dir_train, 'model_9.pt')
            epoch_start = 10
        else:
            model_file = None

        margin_params = None

        out_dir_train = os.path.join(out_dir_meta,
                                     pre_pend + str(split_num) + strs_append)
        final_model_file = os.path.join(out_dir_train,
                                        'model_' + str(num_epochs - 1) + '.pt')
        if os.path.exists(final_model_file):
            print 'skipping', final_model_file
        else:
            print 'not skipping', final_model_file

        train_file = os.path.join(dir_files, type_data,
                                  'train_' + str(split_num) + '.txt')
        test_file = os.path.join(dir_files, type_data,
                                 'test_' + str(split_num) + '.txt')
        if vgg_base_file is None:
            mean_file = os.path.join(dir_files, type_data,
                                     'train_' + str(split_num) + '_mean.png')
            std_file = os.path.join(dir_files, type_data,
                                    'train_' + str(split_num) + '_std.png')

        print train_file
        print test_file
        print mean_file
        print std_file
        # raw_input()

        class_weights = util.get_class_weights_au(
            util.readLinesFromFile(train_file))

        data_transforms = {}
        if aug_more == 'cropkhAugNoColor':
            train_resize = None
            print 'AUGING MORE'
            list_of_todos = ['flip', 'rotate', 'scale_translate']

            data_transforms['train'] = transforms.Compose([
                lambda x: augmenters.random_crop(x, im_size),
                lambda x: augmenters.augment_image(x, list_of_todos),
                # lambda x: augmenters.horizontal_flip(x),
                transforms.ToTensor(),
                lambda x: x * 255,
            ])
        elif aug_more == 'cropFlip':
            train_resize = None
            data_transforms['train'] = transforms.Compose([
                lambda x: augmenters.random_crop(x, im_size),
                lambda x: augmenters.horizontal_flip(x),
                transforms.ToTensor(),
                lambda x: x * 255,
            ])
        elif aug_more == 'NONE':
            train_resize = im_size
            data_transforms['train'] = transforms.Compose([
                transforms.ToTensor(),
                lambda x: x * 255,
            ])
        else:
            raise ValueError('aug_more is problematic')

        data_transforms['val'] = transforms.Compose([
            transforms.ToTensor(),
            lambda x: x * 255,
        ])

        train_data = dataset.Bp4d_Dataset_Mean_Std_Im(
            train_file,
            mean_file,
            std_file,
            transform=data_transforms['train'],
            binarize=binarize,
            resize=train_resize)
        test_data = dataset.Bp4d_Dataset_Mean_Std_Im(
            test_file,
            mean_file,
            std_file,
            resize=im_size,
            transform=data_transforms['val'],
            binarize=binarize)

        network_params = dict(n_classes=n_classes,
                              pool_type='max',
                              r=route_iter,
                              init=init,
                              class_weights=class_weights,
                              reconstruct=reconstruct,
                              loss_weights=loss_weights,
                              vgg_base_file=vgg_base_file)

        batch_size = 128
        batch_size_val = 128

        util.makedirs(out_dir_train)

        train_params = dict(out_dir_train=out_dir_train,
                            train_data=train_data,
                            test_data=test_data,
                            batch_size=batch_size,
                            batch_size_val=batch_size_val,
                            num_epochs=num_epochs,
                            save_after=save_after,
                            disp_after=1,
                            plot_after=10,
                            test_after=1,
                            lr=lr,
                            dec_after=dec_after,
                            model_name=model_name,
                            criterion=criterion,
                            gpu_id=0,
                            num_workers=0,
                            model_file=model_file,
                            epoch_start=epoch_start,
                            margin_params=margin_params,
                            network_params=network_params,
                            weight_decay=wdecay)
        test_params = dict(out_dir_train=out_dir_train,
                           model_num=num_epochs - 1,
                           train_data=train_data,
                           test_data=test_data,
                           gpu_id=0,
                           model_name=model_name,
                           batch_size_val=batch_size_val,
                           criterion=criterion,
                           margin_params=margin_params,
                           network_params=network_params,
                           barebones=True)

        print train_params
        param_file = os.path.join(out_dir_train, 'params.txt')
        all_lines = []
        for k in train_params.keys():
            str_print = '%s: %s' % (k, train_params[k])
            print str_print
            all_lines.append(str_print)
        util.writeFile(param_file, all_lines)

        train_model_recon(**train_params)
        test_model_recon(**test_params)

    getting_accuracy.print_accuracy(out_dir_meta,
                                    pre_pend,
                                    strs_append,
                                    folds,
                                    log='log.txt')
Example #6
def save_test_results(wdecay,
                      lr,
                      route_iter,
                      folds=[4, 9],
                      model_name='vgg_capsule_bp4d',
                      epoch_stuff=[30, 60],
                      res=False,
                      class_weights=False,
                      reconstruct=False,
                      loss_weights=None,
                      models_to_test=None,
                      exp=False,
                      disfa=False):
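    """Re-run evaluation for saved checkpoints and dump their results.

    For every fold and every model number in models_to_test, this rebuilds
    the matching experiment directory name, reconstructs the evaluation data
    pipeline, and calls test_model_recon with barebones=True to write out
    results for that checkpoint.
    """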
    out_dirs = []

    out_dir_meta = '../experiments/' + model_name + str(route_iter)
    num_epochs = epoch_stuff[1]
    epoch_start = 0
    if exp:
        dec_after = ['exp', 0.96, epoch_stuff[0], 1e-6]
    else:
        dec_after = ['step', epoch_stuff[0], 0.1]

    im_resize = 110
    im_size = 96

    if disfa:
        dir_files = '../data/disfa'
        type_data = 'train_test_8_au_all_method_110_gray_align'
        n_classes = 8
        pre_pend = 'disfa_' + type_data + '_'
        binarize = True
    else:
        dir_files = '../data/bp4d'
        type_data = 'train_test_files_110_gray_align'
        n_classes = 12
        pre_pend = 'bp4d_' + type_data + '_'
        binarize = False

    criterion = 'marginmulti'
    criterion_str = criterion

    init = False

    strs_append = '_' + '_'.join([
        str(val) for val in [
            'reconstruct', reconstruct, class_weights, 'flipCrop',
            criterion_str, init, 'wdecay', wdecay, num_epochs
        ] + dec_after + lr + ['lossweights'] + loss_weights
    ])

    lr_p = lr[:]
    for split_num in folds:
        for model_num_curr in models_to_test:
            margin_params = None
            out_dir_train = os.path.join(
                out_dir_meta, pre_pend + str(split_num) + strs_append)
            results_dir = os.path.join(out_dir_train,
                                       'results_model_' + str(model_num_curr))
            if os.path.exists(results_dir):
                print 'exists', model_num_curr, split_num
                print out_dir_train
            else:
                print 'does not exist', model_num_curr, split_num
                print out_dir_train

            train_file = os.path.join(dir_files, type_data,
                                      'train_' + str(split_num) + '.txt')
            test_file = os.path.join(dir_files, type_data,
                                     'test_' + str(split_num) + '.txt')
            mean_file = os.path.join(dir_files, type_data,
                                     'train_' + str(split_num) + '_mean.png')
            std_file = os.path.join(dir_files, type_data,
                                    'train_' + str(split_num) + '_std.png')


            if model_name.startswith('vgg'):
                mean_std = np.array([[93.5940, 104.7624, 129.1863],
                                     [1., 1., 1.]])  # bgr
                bgr = True
            else:
                mean_std = np.array([[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                     [0.229 * 255, 0.224 * 255, 0.225 * 255]])
                bgr = False

            class_weights = util.get_class_weights_au(
                util.readLinesFromFile(train_file))
            data_transforms = {}
            data_transforms['train'] = transforms.Compose([
                lambda x: augmenters.random_crop(x, im_size),
                lambda x: augmenters.horizontal_flip(x),
                transforms.ToTensor(),
                lambda x: x * 255,
            ])
            data_transforms['val'] = transforms.Compose([
                transforms.ToTensor(),
                lambda x: x * 255,
            ])

            train_data = dataset.Bp4d_Dataset_Mean_Std_Im(
                train_file,
                mean_file,
                std_file,
                transform=data_transforms['train'],
                binarize=binarize)
            test_data = dataset.Bp4d_Dataset_Mean_Std_Im(
                test_file,
                mean_file,
                std_file,
                resize=im_size,
                transform=data_transforms['val'],
                binarize=binarize)

            network_params = dict(n_classes=n_classes,
                                  pool_type='max',
                                  r=route_iter,
                                  init=init,
                                  class_weights=class_weights,
                                  reconstruct=reconstruct,
                                  loss_weights=loss_weights)

            batch_size = 96
            batch_size_val = 96

            util.makedirs(out_dir_train)

            test_params = dict(out_dir_train=out_dir_train,
                               model_num=model_num_curr,
                               train_data=train_data,
                               test_data=test_data,
                               gpu_id=0,
                               model_name=model_name,
                               batch_size_val=batch_size_val,
                               criterion=criterion,
                               margin_params=margin_params,
                               network_params=network_params,
                               barebones=True)
            test_model_recon(**test_params)
Example #7
def train_vgg(wdecay,
              lr,
              route_iter,
              folds=[4, 9],
              model_name='vgg_capsule_bp4d',
              epoch_stuff=[30, 60],
              res=False,
              class_weights=False,
              reconstruct=False,
              loss_weights=None,
              exp=False,
              align=False,
              disfa=False,
              more_aug=False,
              dropout=None,
              model_to_test=None,
              gpu_id=0,
              test_mode=False):
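    """Train and test a color capsule/VGG model on BP4D or DISFA AU data.

    Resolves image sizes, split files and normalization per model family,
    builds an augmentation pipeline according to `more_aug` ('MORE', 'LESS'
    or 'NONE'), optionally resumes from epoch 5, and skips folds whose final
    checkpoint already exists unless test_mode is set.
    """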
    out_dirs = []

    out_dir_meta = '../experiments/' + model_name + str(route_iter)
    num_epochs = epoch_stuff[1]

    if model_to_test is None:
        model_to_test = num_epochs - 1

    epoch_start = 0
    if exp:
        dec_after = ['exp', 0.96, epoch_stuff[0], 1e-6]
    else:
        dec_after = ['step', epoch_stuff[0], 0.1]


    if model_name.startswith('vgg'):
        im_resize = 256
        im_size = 224
        if not disfa:
            dir_files = '../data/bp4d'
            if align:
                type_data = 'train_test_files_256_color_align'
                n_classes = 12
            else:
                type_data = 'train_test_files_256_color_nodetect'
                n_classes = 12
            pre_pend = 'bp4d_256_' + type_data + '_'
            binarize = False
        else:
            dir_files = '../data/disfa'
            type_data = 'train_test_8_au_all_method_256_color_align'
            n_classes = 8
            pre_pend = 'disfa_' + type_data + '_'
            binarize = True
            pre_pend = 'disfa_256_' + type_data + '_'
    else:
        if not disfa:
            im_resize = 110
            im_size = 96
            binarize = False
            dir_files = '../data/bp4d'
            type_data = 'train_test_files_110_color_align'
            n_classes = 12
            pre_pend = 'bp4d_110_'
        else:
            im_resize = 110
            im_size = 96
            dir_files = '../data/disfa'
            type_data = 'train_test_8_au_all_method_110_color_align'
            n_classes = 8
            binarize = True
            pre_pend = 'disfa_110_' + type_data + '_'

    save_after = 1
    criterion = 'marginmulti'
    criterion_str = criterion

    init = False

    strs_append_list = [
        'reconstruct', reconstruct, class_weights, 'all_aug', criterion_str,
        init, 'wdecay', wdecay, num_epochs
    ] + dec_after + lr + [more_aug] + [dropout]
    if loss_weights is not None:
        strs_append_list = strs_append_list + ['lossweights'] + loss_weights
    strs_append = '_' + '_'.join([str(val) for val in strs_append_list])

    lr_p = lr[:]
    for split_num in folds:

        if res:

            strs_append_list_c = [
                'reconstruct', reconstruct, False, 'all_aug', criterion_str,
                init, 'wdecay', wdecay, 10
            ] + ['step', 10, 0.1] + lr + [more_aug] + [dropout]
            if loss_weights is not None:
                strs_append_list_c = strs_append_list_c + ['lossweights'
                                                           ] + loss_weights

            strs_append_c = '_' + '_'.join(
                [str(val) for val in strs_append_list_c])
            out_dir_train = os.path.join(
                out_dir_meta, pre_pend + str(split_num) + strs_append_c)

            model_file = os.path.join(out_dir_train, 'model_4.pt')
            epoch_start = 5
            lr = [val * 0.1 for val in lr]
            print 'FILE EXISTS', os.path.exists(
                model_file), model_file, epoch_start

            raw_input()

        else:
            model_file = None

        margin_params = None

        out_dir_train = os.path.join(out_dir_meta,
                                     pre_pend + str(split_num) + strs_append)
        final_model_file = os.path.join(out_dir_train,
                                        'model_' + str(num_epochs - 1) + '.pt')
        if os.path.exists(final_model_file) and not test_mode:
            print 'skipping', final_model_file
            continue
        else:
            print 'not skipping', final_model_file

        train_file = os.path.join(dir_files, type_data,
                                  'train_' + str(split_num) + '.txt')
        test_file = os.path.join(dir_files, type_data,
                                 'test_' + str(split_num) + '.txt')

        data_transforms = None
        if model_name.startswith('vgg_capsule_7_3_imagenet'
                                 ) or model_name.startswith('scratch_'):
            mean_std = np.array([[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]])
            bgr = False
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
            std_div = None

            data_transforms = {}
            data_transforms['train'] = [
                transforms.ToPILImage(),
                transforms.RandomCrop(im_size),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(15),
                transforms.ColorJitter(),
                transforms.ToTensor(), normalize
            ]
            data_transforms['val'] = [
                transforms.ToPILImage(),
                transforms.Resize((im_size, im_size)),
                transforms.ToTensor(), normalize
            ]

            if torch.version.cuda.startswith('9'):
                data_transforms['train'].append(lambda x: x.float())
                data_transforms['val'].append(lambda x: x.float())

            data_transforms['train'] = transforms.Compose(
                data_transforms['train'])
            data_transforms['val'] = transforms.Compose(data_transforms['val'])

            train_data = dataset.Bp4d_Dataset(
                train_file,
                bgr=bgr,
                binarize=binarize,
                transform=data_transforms['train'])
            test_data = dataset.Bp4d_Dataset(test_file,
                                             bgr=bgr,
                                             binarize=binarize,
                                             transform=data_transforms['val'])

        elif model_name.startswith('vgg'):
            mean_std = np.array([[93.5940, 104.7624, 129.1863],
                                 [1., 1., 1.]])  # bgr
            std_div = np.array([0.225 * 255, 0.224 * 255, 0.229 * 255])
            print std_div
            bgr = True
        else:
            mean_std = np.array([[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                 [0.229 * 255, 0.224 * 255, 0.225 * 255]])
            bgr = False

        print mean_std

        class_weights = util.get_class_weights_au(
            util.readLinesFromFile(train_file))

        if data_transforms is None:
            data_transforms = {}
            if more_aug == 'MORE':
                print more_aug
                list_of_to_dos = ['flip', 'rotate', 'scale_translate']

                if torch.version.cuda.startswith('9'):
                    data_transforms['train'] = transforms.Compose([
                        lambda x: augmenters.random_crop(x, im_size),
                        lambda x: augmenters.augment_image(
                            x, list_of_to_dos, color=True, im_size=im_size),
                        transforms.ToTensor(), lambda x: x.float()
                    ])
                    data_transforms['val'] = transforms.Compose(
                        [transforms.ToTensor(), lambda x: x.float()])
                else:
                    data_transforms['train'] = transforms.Compose([
                        lambda x: augmenters.random_crop(x, im_size),
                        lambda x: augmenters.augment_image(
                            x, list_of_to_dos, color=True, im_size=im_size),
                        transforms.ToTensor(),
                        lambda x: x * 255,
                    ])
                    data_transforms['val'] = transforms.Compose([
                        transforms.ToTensor(),
                        lambda x: x * 255,
                    ])

                train_data = dataset.Bp4d_Dataset_with_mean_std_val(
                    train_file,
                    bgr=bgr,
                    binarize=binarize,
                    mean_std=mean_std,
                    transform=data_transforms['train'])
                test_data = dataset.Bp4d_Dataset_with_mean_std_val(
                    test_file,
                    bgr=bgr,
                    binarize=binarize,
                    mean_std=mean_std,
                    transform=data_transforms['val'],
                    resize=im_size)
            elif more_aug == 'LESS':
                data_transforms['train'] = transforms.Compose([
                    transforms.ToPILImage(),
                    # transforms.Resize((im_resize,im_resize)),
                    transforms.RandomCrop(im_size),
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomRotation(15),
                    transforms.ColorJitter(),
                    transforms.ToTensor(),
                    lambda x: x * 255,
                    transforms.Normalize(mean_std[0, :], mean_std[1, :]),
                ])
                data_transforms['val'] = transforms.Compose([
                    transforms.ToPILImage(),
                    transforms.Resize((im_size, im_size)),
                    transforms.ToTensor(),
                    lambda x: x * 255,
                    transforms.Normalize(mean_std[0, :], mean_std[1, :]),
                ])

                train_data = dataset.Bp4d_Dataset(
                    train_file,
                    bgr=bgr,
                    binarize=binarize,
                    transform=data_transforms['train'])
                test_data = dataset.Bp4d_Dataset(
                    test_file,
                    bgr=bgr,
                    binarize=binarize,
                    transform=data_transforms['val'])
            elif more_aug == 'NONE':
                print 'NO AUGING'
                data_transforms['train'] = transforms.Compose(
                    [transforms.ToTensor(), lambda x: x * 255])
                data_transforms['val'] = transforms.Compose(
                    [transforms.ToTensor(), lambda x: x * 255])
                train_data = dataset.Bp4d_Dataset_with_mean_std_val(
                    train_file,
                    bgr=bgr,
                    binarize=binarize,
                    mean_std=mean_std,
                    transform=data_transforms['train'],
                    resize=im_size)
                test_data = dataset.Bp4d_Dataset_with_mean_std_val(
                    test_file,
                    bgr=bgr,
                    binarize=binarize,
                    mean_std=mean_std,
                    transform=data_transforms['val'],
                    resize=im_size)
            else:
                raise ValueError('more_aug not valid')

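        # Arguments forwarded to the capsule-network constructor; dropout is
        # only included when explicitly set.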
        network_params = dict(n_classes=n_classes,
                              pool_type='max',
                              r=route_iter,
                              init=init,
                              class_weights=class_weights,
                              reconstruct=reconstruct,
                              loss_weights=loss_weights,
                              std_div=std_div)
        if dropout is not None:
            print 'RECONS', reconstruct
            network_params['dropout'] = dropout

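        # NOTE: batch sizes are hard-coded here; if the enclosing function
        # takes batch_size arguments, they are overridden at this point.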
        batch_size = 32
        batch_size_val = 32

        util.makedirs(out_dir_train)

        train_params = dict(out_dir_train=out_dir_train,
                            train_data=train_data,
                            test_data=test_data,
                            batch_size=batch_size,
                            batch_size_val=batch_size_val,
                            num_epochs=num_epochs,
                            save_after=save_after,
                            disp_after=1,
                            plot_after=100,
                            test_after=1,
                            lr=lr,
                            dec_after=dec_after,
                            model_name=model_name,
                            criterion=criterion,
                            gpu_id=gpu_id,
                            num_workers=0,
                            model_file=model_file,
                            epoch_start=epoch_start,
                            margin_params=margin_params,
                            network_params=network_params,
                            weight_decay=wdecay)
        test_params = dict(out_dir_train=out_dir_train,
                           model_num=model_to_test,
                           train_data=train_data,
                           test_data=test_data,
                           gpu_id=gpu_id,
                           model_name=model_name,
                           batch_size_val=batch_size_val,
                           criterion=criterion,
                           margin_params=margin_params,
                           network_params=network_params,
                           barebones=True)

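        # Print all hyperparameters for the run; persisting them to
        # params.txt is currently disabled (util.writeFile is commented out).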
        print train_params
        param_file = os.path.join(out_dir_train, 'params.txt')
        all_lines = []
        for k in train_params.keys():
            str_print = '%s: %s' % (k, train_params[k])
            print str_print
            all_lines.append(str_print)
        # util.writeFile(param_file,all_lines)

        # Skip training in test mode; evaluation always runs.
        if not test_mode:
            train_model_recon(**train_params)

        test_model_recon(**test_params)


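    # Summarize accuracy across all folds from each run's log.txt.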
    getting_accuracy.print_accuracy(out_dir_meta,
                                    pre_pend,
                                    strs_append,
                                    folds,
                                    log='log.txt')