示例#1
0
def _train(args):
    """Run one incremental-learning experiment and log its accuracy curve.

    For every task exposed by the ``DataManager`` the model is trained
    (``incremental_train``), evaluated (``eval_task``) and post-processed
    (``after_task``); the ``'total'`` accuracy of each evaluation is appended
    to a running curve that is logged after every task.

    Args:
        args: dict of experiment settings. Keys read directly here:
            ``prefix``, ``model_name``, ``convnet_type``, ``init_cls``,
            ``increment``, ``seed``, ``dataset`` and ``shuffle``; the whole
            dict is also passed to ``_set_device`` and ``factory.get_model``.
    """
    logfilename = args['prefix'] + '_{}_{}_{}_{}.log'.format(
        args['model_name'], args['convnet_type'], args['init_cls'],
        args['increment'])

    # logging.basicConfig() is a silent no-op when the root logger already has
    # handlers, but the FileHandler passed to it is still constructed (opening
    # its file) because call arguments are evaluated first. If _train() is
    # called more than once in a process, that handler would be dropped and
    # its file handle leaked. Detach and close the previous handlers first
    # (same effect as basicConfig(force=True) on Python 3.8+).
    # StreamHandler.close() does not close sys.stdout.
    root_logger = logging.getLogger()
    for handler in list(root_logger.handlers):
        root_logger.removeHandler(handler)
        handler.close()
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] => %(message)s',
        handlers=[
            logging.FileHandler(filename=logfilename),
            logging.StreamHandler(sys.stdout)
        ])

    logging.info('Seed: {}'.format(args['seed']))
    logging.info('Model: {}'.format(args['model_name']))
    logging.info('Convnet: {}'.format(args['convnet_type']))
    _set_device(args)
    data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'],
                               args['init_cls'], args['increment'])
    model = factory.get_model(args['model_name'], args)

    curve = []
    for _ in range(data_manager.nb_tasks):
        logging.info('All params: {}'.format(count_parameters(model._network)))
        logging.info('Trainable params: {}'.format(
            count_parameters(model._network, True)))
        model.incremental_train(data_manager)
        # eval_task() returns a mapping; only its 'total' entry feeds the curve.
        accy = model.eval_task()
        model.after_task()

        logging.info(accy)
        curve.append(accy['total'])
        logging.info('Curve: {}\n'.format(curve))
示例#2
0
def _train(args):
    """Run one class-incremental experiment and log CNN/NME accuracy curves.

    For every task exposed by the ``DataManager`` the model is trained,
    evaluated and post-processed. ``model.eval_task()`` returns a pair
    ``(cnn_accy, nme_accy)``; ``nme_accy`` may be ``None``, in which case only
    the CNN curves are recorded and logged.

    Args:
        args: dict of experiment settings. Keys read directly here:
            ``prefix``, ``seed``, ``model_name``, ``convnet_type``,
            ``dataset``, ``init_cls``, ``increment`` and ``shuffle``; the
            whole dict is also passed to ``_set_device``, ``print_args`` and
            ``factory.get_model``.
    """
    logfilename = '{}_{}_{}_{}_{}_{}_{}'.format(
        args['prefix'], args['seed'], args['model_name'], args['convnet_type'],
        args['dataset'], args['init_cls'], args['increment'])

    # logging.basicConfig() is a silent no-op when the root logger already has
    # handlers, yet the FileHandler below is still constructed (and its file
    # opened) because call arguments are evaluated first. On a repeated call
    # in the same process this run's records would go to the previous run's
    # log file, leaving this run's file empty and its handle leaked. Detach
    # and close old handlers first (same effect as basicConfig(force=True) on
    # Python 3.8+); StreamHandler.close() does not close sys.stdout.
    root_logger = logging.getLogger()
    for handler in list(root_logger.handlers):
        root_logger.removeHandler(handler)
        handler.close()
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s [%(filename)s] => %(message)s',
                        handlers=[
                            logging.FileHandler(filename=logfilename + '.log'),
                            logging.StreamHandler(sys.stdout)
                        ])

    _set_random()
    _set_device(args)
    print_args(args)
    data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'],
                               args['init_cls'], args['increment'])
    model = factory.get_model(args['model_name'], args)

    # Per-task accuracy history for the CNN and NME evaluations.
    cnn_curve, nme_curve = {'top1': [], 'top5': []}, {'top1': [], 'top5': []}
    for _ in range(data_manager.nb_tasks):
        logging.info('All params: {}'.format(count_parameters(model._network)))
        logging.info('Trainable params: {}'.format(
            count_parameters(model._network, True)))
        model.incremental_train(data_manager)
        cnn_accy, nme_accy = model.eval_task()
        model.after_task()

        if nme_accy is not None:
            logging.info('CNN: {}'.format(cnn_accy['grouped']))
            logging.info('NME: {}'.format(nme_accy['grouped']))

            cnn_curve['top1'].append(cnn_accy['top1'])
            cnn_curve['top5'].append(cnn_accy['top5'])

            nme_curve['top1'].append(nme_accy['top1'])
            nme_curve['top5'].append(nme_accy['top5'])

            logging.info('CNN top1 curve: {}'.format(cnn_curve['top1']))
            logging.info('CNN top5 curve: {}'.format(cnn_curve['top5']))
            logging.info('NME top1 curve: {}'.format(nme_curve['top1']))
            logging.info('NME top5 curve: {}\n'.format(nme_curve['top5']))
        else:
            # The model produced no NME evaluation for this task.
            logging.info('No NME accuracy.')
            logging.info('CNN: {}'.format(cnn_accy['grouped']))

            cnn_curve['top1'].append(cnn_accy['top1'])
            cnn_curve['top5'].append(cnn_accy['top5'])

            logging.info('CNN top1 curve: {}'.format(cnn_curve['top1']))
            logging.info('CNN top5 curve: {}\n'.format(cnn_curve['top5']))
示例#3
0
def _train(args):
    """Run one class-incremental experiment and log CNN/NME accuracy curves.

    For every task exposed by the ``DataManager`` the model is trained
    (``incremental_train``), evaluated (``eval_task``) and post-processed
    (``after_task``). ``eval_task()`` returns ``(cnn_accy, nme_accy)``;
    ``nme_accy`` may be ``None``, in which case only CNN curves are logged.

    Args:
        args: dict of experiment settings. Keys read directly here:
            ``prefix``, ``seed``, ``model_name``, ``convnet_type``,
            ``dataset``, ``init_cls``, ``increment`` and ``shuffle``; the
            whole dict is also passed to ``_set_device``, ``print_args`` and
            ``factory.get_model``.

    Example settings (original note by @Author:defeng)::

        {
        "prefix": "reproduce",
        "dataset": "cifar100",
        "memory_size": 2000,
        "memory_per_class": 20,
        "fixed_memory": true,
        "shuffle": true,
        "init_cls": 50,
        "increment": 10,  # add `increment` new classes at each task;
                          # see "# Grouped accuracy" in toolkit.py
        "model_name": "UCIR",
        "convnet_type": "cosine_resnet32",
        "device": ["0"],
        "seed": [30]
        }

    NOTE(review): the example lists ``seed`` as a list, while this function
    formats ``args['seed']`` into the log-file name and passes it to
    ``DataManager``. Presumably the caller substitutes a single seed before
    calling ``_train``; confirm.
    """
    logfilename = '{}_{}_{}_{}_{}_{}_{}'.format(
        args['prefix'], args['seed'], args['model_name'], args['convnet_type'],
        args['dataset'], args['init_cls'], args['increment'])

    # logging.basicConfig() is a silent no-op when the root logger already has
    # handlers, yet the FileHandler below is still constructed (and its file
    # opened) because call arguments are evaluated first. On a repeated call
    # in the same process this run's records would go to the previous run's
    # log file, leaving this run's file empty and its handle leaked. Detach
    # and close old handlers first (same effect as basicConfig(force=True) on
    # Python 3.8+); StreamHandler.close() does not close sys.stdout.
    root_logger = logging.getLogger()
    for handler in list(root_logger.handlers):
        root_logger.removeHandler(handler)
        handler.close()
    # @Author:defeng (26 May 2021): '%(filename)s' prints the source file that
    # emitted each record, e.g. "2021-05-26 22:01:34,371 [ucir.py]", so you
    # can tell which file a given output line came from.
    # See https://www.cnblogs.com/xianyulouie/p/11041777.html for details.
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s [%(filename)s] => %(message)s',
                        handlers=[
                            logging.FileHandler(filename=logfilename + '.log'),
                            logging.StreamHandler(sys.stdout)
                        ])

    # Set the random seed and the CUDA devices (per the original author's note).
    _set_random()
    _set_device(args)
    print_args(args)

    # Set up the dataset and the model. Per the original author's note, the
    # actual work of getting the model ready is done by the .py files in the
    # "models" folder.
    data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'],
                               args['init_cls'], args['increment'])
    model = factory.get_model(args['model_name'], args)

    # cnn: softmax (classifier-output) prediction.
    # nme: nearest-mean-of-exemplars prediction.
    # See the "Baselines" section of the UCIR paper for details.
    cnn_curve, nme_curve = {'top1': [], 'top5': []}, {'top1': [], 'top5': []}
    for _ in range(data_manager.nb_tasks):
        logging.info('All params: {}'.format(count_parameters(model._network)))
        logging.info('Trainable params: {}'.format(
            count_parameters(model._network, True)))

        model.incremental_train(data_manager)  # train
        cnn_accy, nme_accy = model.eval_task()  # evaluate
        model.after_task()  # post-processing

        if nme_accy is not None:
            logging.info('CNN: {}'.format(cnn_accy['grouped']))
            logging.info('NME: {}'.format(nme_accy['grouped']))

            cnn_curve['top1'].append(cnn_accy['top1'])
            cnn_curve['top5'].append(cnn_accy['top5'])

            nme_curve['top1'].append(nme_accy['top1'])
            nme_curve['top5'].append(nme_accy['top5'])

            logging.info('CNN top1 curve: {}'.format(cnn_curve['top1']))
            logging.info('CNN top5 curve: {}'.format(cnn_curve['top5']))
            logging.info('NME top1 curve: {}'.format(nme_curve['top1']))
            logging.info('NME top5 curve: {}\n'.format(nme_curve['top5']))
        else:
            # The model produced no NME evaluation for this task.
            logging.info('No NME accuracy.')
            logging.info('CNN: {}'.format(cnn_accy['grouped']))

            cnn_curve['top1'].append(cnn_accy['top1'])
            cnn_curve['top5'].append(cnn_accy['top5'])

            logging.info('CNN top1 curve: {}'.format(cnn_curve['top1']))
            logging.info('CNN top5 curve: {}\n'.format(cnn_curve['top5']))
示例#4
0
 def __init__(self, cfg):
     """Store the config and build the network and loss it describes.

     Args:
         cfg: configuration object whose ``model`` attribute is passed to
             ``get_model`` and whose ``loss`` attribute is passed to
             ``get_loss``. NOTE(review): the expected types of these
             attributes are defined outside this view; confirm.
     """
     # Initialise the base class (not visible here) before assigning
     # attributes; presumably required if it is e.g. torch.nn.Module. Confirm.
     super().__init__()
     self.cfg = cfg
     self.net = get_model(self.cfg.model)  # network built from cfg.model
     self.loss = get_loss(self.cfg.loss)  # loss built from cfg.loss