# Example no. 1
# 0
def main(_):
    """Pre-train the network, fit classifiers on its features, and evaluate.

    Steps, in order:

    1. Run ``CINTrainer`` on the ``'train_unlabeled'`` STL10 generator.
    2. Wait for the checkpoint left in that trainer's save directory.
    3. Run ``ClassifierTrainer`` (``train_scopes='classifier'``) on the
       ``'train'`` generator, starting from that checkpoint.
    4. Run ``ClassifierTester`` on the ``'test'`` generator and hand the
       result to ``write_experiments_multi``.

    The single positional argument is ignored.
    """
    # --- Shapes and data pre-processors -----------------------------------
    src_shape = (96, 96, 3)
    load_shape = [80, 80, 3]
    shape_transfer = [64, 64, 3]
    crop_sz = (64, 64)
    ssl_pre = Preprocessor(target_shape=load_shape, src_shape=src_shape)
    lin_pre = Preprocessor(target_shape=shape_transfer, src_shape=src_shape)

    # --- Data generators, one per STL10 split -----------------------------
    unlabeled_gen = STL10('train_unlabeled')
    labeled_gen = STL10('train')
    test_gen = STL10('test')

    # --- Network and SSL training -----------------------------------------
    ssl_net = TRCNet(
        batch_size=FLAGS.batch_size,
        im_shape=load_shape,
        n_tr_classes=6,
        tag=FLAGS.tag,
        lci_patch_sz=42,
        lci_crop_sz=48,
        n_layers_lci=4,
        ae_dim=48,
        enc_params={'padding': 'SAME'},
    )
    ssl_trainer = CINTrainer(
        model=ssl_net,
        data_generator=unlabeled_gen,
        pre_processor=ssl_pre,
        crop_sz=crop_sz,
        wd_class=FLAGS.wd,
        init_lr_class=FLAGS.pre_lr,
        num_epochs=FLAGS.n_eps_pre,
        num_gpus=FLAGS.num_gpus,
        # The next three are parameters for LCI training only.
        optimizer='adam',
        init_lr=0.0002,
        momentum=0.5,
        train_scopes='features',
    )
    ssl_trainer.train_model(None)

    # --- Pick up the final checkpoint -------------------------------------
    ssl_ckpt_dir = ssl_trainer.get_save_dir()
    ckpt = wait_for_new_checkpoint(ssl_ckpt_dir, last_checkpoint=None)
    print('Found checkpoint: {}'.format(ckpt))
    # The id is whatever follows the last '-' in the checkpoint path.
    ckpt_id = ckpt.rsplit('-', 1)[-1]

    # --- Linear classifiers on frozen features ----------------------------
    tag_class = '{}_classifier_ckpt_{}'.format(FLAGS.tag, ckpt_id)
    clf_net = TRCNet(
        batch_size=FLAGS.batch_size_ftune,
        im_shape=shape_transfer,
        tag=tag_class,
        feats_ids=['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'],
        enc_params={'use_fc': False, 'padding': 'SAME'},
    )
    clf_trainer = ClassifierTrainer(
        model=clf_net,
        data_generator=labeled_gen,
        pre_processor=lin_pre,
        optimizer='momentum',
        init_lr=FLAGS.ftune_lr,
        momentum=0.9,
        num_epochs=FLAGS.n_eps_ftune,
        num_gpus=1,
        train_scopes='classifier',
    )
    clf_trainer.train_model(ckpt)
    clf_ckpt_dir = clf_trainer.get_save_dir()

    # --- Evaluation on the test set ---------------------------------------
    # The classifier network is reused for testing with a batch size of 100.
    clf_net.batch_size = 100
    tester = ClassifierTester(model=clf_net, data_generator=test_gen, pre_processor=lin_pre)
    acc = tester.test_classifier(clf_ckpt_dir)
    write_experiments_multi(acc, tag_class, FLAGS.tag)
# Example no. 2
# 0
def build_dataset(is_train, args):
    """Build the dataset selected by ``args.data_set`` and report its class count.

    Args:
        is_train: Truthy to build the training split; falsy for the
            evaluation split (``train=False`` for the CIFAR variants,
            ``'test'`` for STL10, ``'val'`` for TinyImageNet).
        args: Namespace-like object. Reads ``data_set``, ``dataset_location``
            and ``training_mode``; ``num_imgs_per_cat`` is read for every
            dataset except STL10. It is also forwarded to ``build_transform``.

    Returns:
        A ``(dataset, nb_classes)`` tuple.

    Raises:
        ValueError: If ``args.data_set`` is not one of ``'CIFAR10'``,
            ``'CIFAR100'``, ``'STL10'`` or ``'TinyImageNet'``.
    """
    supported = ('CIFAR10', 'CIFAR100', 'STL10', 'TinyImageNet')
    # Reject unknown names up front. Without this check an unrecognised name
    # skipped every branch and only failed at the ``return`` statement with
    # an UnboundLocalError that did not mention the offending value.
    if args.data_set not in supported:
        raise ValueError(
            'Unsupported data_set {!r}; expected one of: {}'.format(
                args.data_set, ', '.join(supported)))

    transform = build_transform(is_train, args)

    if args.data_set == 'CIFAR10':
        dataset = CIFAR10(os.path.join(args.dataset_location, 'CIFAR10_dataset'),
                          download=True, train=is_train, transform=transform,
                          num_imgs_per_cat=args.num_imgs_per_cat,
                          training_mode=args.training_mode)
        nb_classes = 10

    elif args.data_set == 'CIFAR100':
        dataset = CIFAR100(os.path.join(args.dataset_location, 'CIFAR100_dataset'),
                           download=True, train=is_train, transform=transform,
                           num_imgs_per_cat=args.num_imgs_per_cat,
                           training_mode=args.training_mode)
        nb_classes = 100

    elif args.data_set == 'STL10':
        # Note: num_imgs_per_cat is not implemented in this dataset as it
        # has unlabeled data.
        # Evaluation always uses 'test'; training adds the unlabeled images
        # only when the training mode is 'SSL'.
        if not is_train:
            split = 'test'
        elif args.training_mode == 'SSL':
            split = 'train+unlabeled'
        else:
            split = 'train'

        dataset = STL10(root=os.path.join(args.dataset_location, 'STL10'),
                        download=True, split=split, transform=transform,
                        training_mode=args.training_mode)
        nb_classes = 10

    else:
        # Only 'TinyImageNet' can reach this branch thanks to the check above.
        mode = 'train' if is_train else 'val'
        root_dir = os.path.join(args.dataset_location, 'TinyImageNet/tiny-imagenet-200/')
        dataset = TinyImageNetDataset(root_dir=root_dir, download=True, mode=mode,
                                      transform=transform,
                                      num_imgs_per_cat=args.num_imgs_per_cat,
                                      training_mode=args.training_mode)
        nb_classes = 200

    return dataset, nb_classes
# Example no. 3
# 0
from Preprocessor import Preprocessor
from train.SDNetTrainer import SDNetTrainer
from datasets.STL10 import STL10
from models.SDNet import SDNet
from utils import get_checkpoint_path

# Shape handed to both the network and the preprocessor.
target_shape = [96, 96, 3]

# One fine-tuning run per fold index, 0 through 9. A fresh model, dataset,
# preprocessor and trainer are built on every iteration.
for fold in range(10):
    model = SDNet(num_layers=4, batch_size=200, target_shape=target_shape, pool5=False)
    data = STL10()
    preprocessor = Preprocessor(target_shape=target_shape)

    trainer = SDNetTrainer(
        model=model,
        dataset=data,
        pre_processor=preprocessor,
        num_epochs=400,
        tag='baseline',
        lr_policy='linear',
        optimizer='adam',
    )

    # Resolve a checkpoint path from the trainer's save directory and
    # fine-tune from it for the current fold.
    chpt_path = get_checkpoint_path(trainer.get_save_dir())
    trainer.finetune_cv(chpt_path, num_conv2train=5, num_conv2init=5, fold=fold)