Example #1
    @classmethod
    def setUpClass(cls):
        master_seed(SEED)

        Model = get_model('MnistCnnV2')
        model = Model()
        logger.info('Starting %s data container...', NAME)
        dc = DataContainer(DATASET_LIST[NAME], get_data_path())
        dc()
        mc = ModelContainerPT(model, dc)
        mc.load(MODEL_FILE)
        accuracy = mc.evaluate(dc.x_test, dc.y_test)
        logger.info('Accuracy on test set: %f', accuracy)

        cls.distillation = DistillationContainer(mc,
                                                 Model(),
                                                 temperature=TEMPERATURE,
                                                 pretrained=False)

        filename = get_pt_model_filename(
            model.__class__.__name__, NAME,
            str(MAX_EPOCHS) + 't' + str(int(TEMPERATURE * 10)))
        filename = os.path.join('test', 'distill_' + filename)
        file_path = os.path.join('save', filename)
        if not os.path.exists(file_path):
            # Expected initial loss = -log(1/num_classes) = 2.3025850929940455
            cls.distillation.fit(max_epochs=MAX_EPOCHS, batch_size=BATCH_SIZE)
            cls.distillation.save(filename, overwrite=True)
        else:
            cls.distillation.load(file_path)

        smooth_mc = cls.distillation.get_def_model_container()
        accuracy = smooth_mc.evaluate(dc.x_test, dc.y_test)
        logger.info('Accuracy on test set: %f', accuracy)
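
The DistillationContainer above wraps the trained model, a fresh copy of the same architecture, and a temperature; get_def_model_container() later returns the defended ('smooth') model. Assuming it follows the standard defensive-distillation recipe, the role of TEMPERATURE can be illustrated with a temperature-scaled softmax (a minimal sketch; softmax_t is a hypothetical helper, not part of the library):

import numpy as np

def softmax_t(logits, temperature=1.0):
    # Higher temperatures flatten the output distribution, producing the
    # "soft" targets that the distilled (smooth) model is trained on.
    z = np.asarray(logits, dtype=float) / temperature
    z -= z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

logits = [4.0, 1.0, 0.5]
print(softmax_t(logits, temperature=1.0))   # sharp, close to one-hot
print(softmax_t(logits, temperature=10.0))  # much softer targets
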
Example #2
def main():
    master_seed(SEED)

    logger.info('Starting %s data container...', NAME)
    dc = DataContainer(DATASET_LIST[NAME], get_data_path())
    dc(shuffle=True, normalize=True)

    num_features = dc.dim_data[0]
    num_classes = dc.num_classes
    print('Features:', num_features)
    print('Classes:', num_classes)
    model = BCNN(num_features, num_classes)
    filename = get_pt_model_filename(BCNN.__name__, NAME, MAX_EPOCHS)
    logger.debug('File name: %s', filename)

    mc = ModelContainerPT(model, dc)

    file_path = os.path.join('save', filename)
    if not os.path.exists(file_path):
        logger.debug('Expected initial loss: %f', np.log(dc.num_classes))
        mc.fit(max_epochs=MAX_EPOCHS, batch_size=BATCH_SIZE)
        mc.save(filename, overwrite=True)
    else:
        logger.info('Using saved parameters from %s', filename)
        mc.load(file_path)

    accuracy = mc.evaluate(dc.x_test, dc.y_test)
    logger.info('Accuracy on test set: %f', accuracy)
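
Both this example and Example #1 mention the expected initial loss. The value comes from a simple fact: an untrained classifier that outputs a roughly uniform distribution over K classes has cross-entropy loss -log(1/K) = log(K), which is exactly what the logger.debug call above prints. A quick standalone check:

import numpy as np

for k in (2, 10):
    # cross-entropy of a uniform prediction over k classes
    print(k, -np.log(1.0 / k))
# k=10 (MNIST) gives 2.302585..., matching the comment in Example #1;
# k=2 gives 0.693147..., the corresponding value for a two-class problem.
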
Example #3
    @classmethod
    def setUpClass(cls):
        master_seed(SEED)

        logger.info('Starting %s data container...', NAME)
        cls.dc = DataContainer(DATASET_LIST[NAME], get_data_path())
        cls.dc(shuffle=True)

        model = MnistCnnV2()
        logger.info('Using model: %s', model.__class__.__name__)

        cls.mc = ModelContainerPT(model, cls.dc)

        filename = get_pt_model_filename(MnistCnnV2.__name__, NAME, MAX_EPOCHS)
        file_path = os.path.join('save', filename)
        if not os.path.exists(file_path):
            cls.mc.fit(max_epochs=MAX_EPOCHS, batch_size=BATCH_SIZE)
            cls.mc.save(filename, overwrite=True)
        else:
            logger.info('Using saved parameters from %s', filename)
            cls.mc.load(file_path)

        accuracy = cls.mc.evaluate(cls.dc.x_test, cls.dc.y_test)
        logger.info('Accuracy on test set: %f', accuracy)

        hidden_model = model.hidden_model

        logger.info('sample_ratio: %f', SAMPLE_RATIO)
        cls.ad = ApplicabilityDomainContainer(
            cls.mc,
            hidden_model=hidden_model,
            k2=9,
            reliability=1.6,
            sample_ratio=SAMPLE_RATIO,
            kappa=10,
            confidence=0.9,
        )
        cls.ad.fit()

        # randomly sample NUM_ADV examples from the test set
        x_test = cls.dc.x_test
        y_test = cls.dc.y_test
        shuffled_indices = np.random.permutation(len(x_test))[:NUM_ADV]
        cls.x = x_test[shuffled_indices]
        cls.y = y_test[shuffled_indices]
        logger.info('# of test samples: %d', len(cls.x))
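
The random subset drawn above is reproducible across runs only because master_seed(SEED) is called at the top of the method. The library's actual implementation is not shown in these examples; a minimal sketch of what such a helper typically does (my_master_seed is a hypothetical stand-in, not the library function):

import random
import numpy as np
import torch

def my_master_seed(seed):
    # Seed the Python, NumPy and PyTorch RNGs so calls such as
    # np.random.permutation above return the same indices every run.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
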
Example #4
    @classmethod
    def setUpClass(cls):
        master_seed(SEED)

        logger.info('Starting %s data container...', NAME)
        cls.dc = DataContainer(DATASET_LIST[NAME], get_data_path())
        # the data is ordered by label, so it must be shuffled
        cls.dc(shuffle=True, normalize=True)

        num_features = cls.dc.dim_data[0]
        num_classes = cls.dc.num_classes
        model = BCNN(num_features, num_classes)
        logger.info('Using model: %s', model.__class__.__name__)

        cls.mc = ModelContainerPT(model, cls.dc)

        filename = get_pt_model_filename(BCNN.__name__, NAME, MAX_EPOCHS)
        file_path = os.path.join('save', filename)
        if not os.path.exists(file_path):
            cls.mc.fit(max_epochs=MAX_EPOCHS, batch_size=BATCH_SIZE)
            cls.mc.save(filename, overwrite=True)
        else:
            logger.info('Using saved parameters from %s', filename)
            cls.mc.load(file_path)

        accuracy = cls.mc.evaluate(cls.dc.x_test, cls.dc.y_test)
        logger.info('Accuracy on test set: %f', accuracy)

        hidden_model = model.hidden_model
        cls.ad = ApplicabilityDomainContainer(
            cls.mc,
            hidden_model=hidden_model,
            k2=6,
            reliability=1.6,
            sample_ratio=SAMPLE_RATIO,
            kappa=10,
            confidence=0.9,
        )
        cls.ad.fit()
Example #5
def main():
    parser = ap.ArgumentParser()
    parser.add_argument('-d',
                        '--dataset',
                        type=str,
                        required=True,
                        choices=get_dataset_list(),
                        help='the dataset to train on')
    parser.add_argument(
        '-o',
        '--ofile',
        type=str,
        help='the filename used to store the model parameters')
    parser.add_argument('-e',
                        '--epoch',
                        type=int,
                        default=5,
                        help='the maximum number of training epochs')
    parser.add_argument('-b',
                        '--batchsize',
                        type=int,
                        default=128,
                        help='batch size')
    parser.add_argument('-s',
                        '--seed',
                        type=int,
                        default=4096,
                        help='the seed for random number generator')
    parser.add_argument('-H',
                        '--shuffle',
                        type=bool,
                        default=True,
                        help='shuffle the dataset')
    parser.add_argument(
        '-n',
        '--normalize',
        type=bool,
        default=True,
        help=
        'apply zero mean and scaling to the dataset (for numeric datasets only)'
    )
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        choices=AVALIABLE_MODELS,
                        help='select a model to train on the data')
    parser.add_argument('-v',
                        '--verbose',
                        action='store_true',
                        default=False,
                        help='set logger level to debug')
    parser.add_argument('-l',
                        '--savelog',
                        action='store_true',
                        default=False,
                        help='save logging file')
    parser.add_argument('-w',
                        '--overwrite',
                        action='store_true',
                        default=False,
                        help='overwrite the existing file')
    args = parser.parse_args()
    dname = args.dataset
    filename = args.ofile
    max_epochs = args.epoch
    batch_size = args.batchsize
    seed = args.seed
    use_shuffle = args.shuffle
    use_normalize = args.normalize
    model_name = args.model
    verbose = args.verbose
    save_log = args.savelog
    overwrite = args.overwrite

    # set logging config. Run this before logging anything!
    set_logging('train', dname, verbose, save_log)

    # show parameters
    print('[train] Start training {} model...'.format(model_name))
    logger.info('Start at      : %s', get_time_str())
    logger.info('RECEIVED PARAMETERS:')
    logger.info('dataset       :%s', dname)
    logger.info('filename      :%s', filename)
    logger.info('max_epochs    :%d', max_epochs)
    logger.info('batch_size    :%d', batch_size)
    logger.info('seed          :%d', seed)
    logger.info('use_shuffle   :%r', use_shuffle)
    logger.info('use_normalize :%r', use_normalize)
    logger.info('model_name    :%s', model_name)
    logger.info('verbose       :%r', verbose)
    logger.info('save_log      :%r', save_log)
    logger.info('overwrite     :%r', overwrite)

    master_seed(seed)

    # set DataContainer
    dc = get_data_container(
        dname,
        use_shuffle=use_shuffle,
        use_normalize=use_normalize,
    )

    # select a model
    model = None
    if model_name is not None:
        Model = models.get_model(model_name)
        model = Model()
    else:
        if dname == 'MNIST':
            model = models.MnistCnnV2()
        elif dname == 'CIFAR10':
            model = models.CifarCnn()
        elif dname == 'BreastCancerWisconsin':
            model = models.BCNN()
        elif dname in ('BankNote', 'HTRU2', 'Iris', 'WheatSeed'):
            num_classes = dc.num_classes
            num_features = dc.dim_data[0]
            model = models.IrisNN(num_features=num_features,
                                  hidden_nodes=num_features * 4,
                                  num_classes=num_classes)

    if model is None:
        raise AttributeError('Cannot find model!')
    modelname = model.__class__.__name__
    logger.info('Selected %s model', modelname)

    # set ModelContainer and train the model
    mc = models.ModelContainerPT(model, dc)
    mc.fit(max_epochs=max_epochs, batch_size=batch_size)

    # save
    if not os.path.exists('save'):
        os.makedirs('save')
    if filename is None:
        filename = get_pt_model_filename(modelname, dname, max_epochs)
    logger.debug('File name: %s', filename)
    mc.save(filename, overwrite=overwrite)

    # test result
    file_path = os.path.join('save', filename)
    logger.debug('Using saved parameters from %s', filename)
    mc.load(file_path)
    accuracy = mc.evaluate(dc.x_test, dc.y_test)
    logger.info('Accuracy on test set: %f', accuracy)
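
One caveat in the argument parsing above: for -H/--shuffle and -n/--normalize, argparse applies type=bool to the raw command-line string, so '-H False' still evaluates to True; only an empty string is falsy. A small standalone demonstration:

import argparse

p = argparse.ArgumentParser()
p.add_argument('-H', '--shuffle', type=bool, default=True)
print(p.parse_args([]).shuffle)               # True (the default)
print(p.parse_args(['-H', 'False']).shuffle)  # True, because bool('False') is True
print(p.parse_args(['-H', '']).shuffle)       # False, only the empty string is falsy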