def _train(args):
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] => %(message)s',
        handlers=[
            logging.FileHandler(filename=args['prefix'] + '_{}_{}_{}_{}.log'.format(
                args['model_name'], args['convnet_type'],
                args['init_cls'], args['increment'])),
            logging.StreamHandler(sys.stdout)
        ])

    logging.info('Seed: {}'.format(args['seed']))
    logging.info('Model: {}'.format(args['model_name']))
    logging.info('Convnet: {}'.format(args['convnet_type']))

    _set_device(args)
    data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'],
                               args['init_cls'], args['increment'])
    model = factory.get_model(args['model_name'], args)

    curve = []
    for task in range(data_manager.nb_tasks):
        logging.info('All params: {}'.format(count_parameters(model._network)))
        logging.info('Trainable params: {}'.format(
            count_parameters(model._network, True)))

        model.incremental_train(data_manager)
        accy = model.eval_task()
        model.after_task()

        logging.info(accy)
        curve.append(accy['total'])
        logging.info('Curve: {}\n'.format(curve))
def _train(args):
    logfilename = '{}_{}_{}_{}_{}_{}_{}'.format(
        args['prefix'], args['seed'], args['model_name'], args['convnet_type'],
        args['dataset'], args['init_cls'], args['increment'])
    '''@Author:defeng
    Example config:
    {
        "prefix": "reproduce",
        "dataset": "cifar100",
        "memory_size": 2000,
        "memory_per_class": 20,
        "fixed_memory": true,
        "shuffle": true,
        "init_cls": 50,
        "increment": 10,  # add $increment classes each task; see "# Grouped accuracy" in toolkit.py
        "model_name": "UCIR",
        "convnet_type": "cosine_resnet32",
        "device": ["0"],
        "seed": [30]
    }
    '''
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s [%(filename)s] => %(message)s',
                        handlers=[
                            logging.FileHandler(filename=logfilename + '.log'),
                            logging.StreamHandler(sys.stdout)
                        ])
    '''@Author:defeng
    See for details: https://www.cnblogs.com/xianyulouie/p/11041777.html
    26 May 2021 (Wednesday)
    The %(filename)s format directive enables output like
    "2021-05-26 22:01:34,371 [ucir.py]", so we can tell which file a given
    log line comes from.
    '''

    '''@Author:defeng
    Set the random seed and CUDA devices.
    '''
    _set_random()
    _set_device(args)
    print_args(args)

    '''@Author:defeng
    Set up the dataset and model.
    '''
    data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'],
                               args['init_cls'], args['increment'])
    model = factory.get_model(args['model_name'], args)
    '''@Author:defeng
    The actual work of getting the model ready is done by the .py files in
    the "models" folder.
    '''

    '''@Author:defeng
    CNN: softmax prediction.
    NME: nearest-mean-of-exemplars prediction.
    See the "Baselines" section of the UCIR paper for details.
    '''
    cnn_curve, nme_curve = {'top1': [], 'top5': []}, {'top1': [], 'top5': []}
    for task in range(data_manager.nb_tasks):
        logging.info('All params: {}'.format(count_parameters(model._network)))
        logging.info('Trainable params: {}'.format(
            count_parameters(model._network, True)))

        model.incremental_train(data_manager)  # train
        cnn_accy, nme_accy = model.eval_task()  # val
        model.after_task()  # post-processing

        if nme_accy is not None:
            logging.info('CNN: {}'.format(cnn_accy['grouped']))
            logging.info('NME: {}'.format(nme_accy['grouped']))

            cnn_curve['top1'].append(cnn_accy['top1'])
            cnn_curve['top5'].append(cnn_accy['top5'])
            nme_curve['top1'].append(nme_accy['top1'])
            nme_curve['top5'].append(nme_accy['top5'])

            logging.info('CNN top1 curve: {}'.format(cnn_curve['top1']))
            logging.info('CNN top5 curve: {}'.format(cnn_curve['top5']))
            logging.info('NME top1 curve: {}'.format(nme_curve['top1']))
            logging.info('NME top5 curve: {}\n'.format(nme_curve['top5']))
        else:
            logging.info('No NME accuracy.')
            logging.info('CNN: {}'.format(cnn_accy['grouped']))

            cnn_curve['top1'].append(cnn_accy['top1'])
            cnn_curve['top5'].append(cnn_accy['top5'])

            logging.info('CNN top1 curve: {}'.format(cnn_curve['top1']))
            logging.info('CNN top5 curve: {}\n'.format(cnn_curve['top5']))
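# The example config above stores "seed" as a list, while _train expects a
# single scalar seed per run. A minimal sketch of how a driver might expand
# that config into per-seed args dicts before calling _train; the
# expand_seeds helper is a hypothetical name, not part of this file.

```python
import copy

# Config mirrors the example JSON shown in the comment above.
CONFIG = {
    "prefix": "reproduce",
    "dataset": "cifar100",
    "memory_size": 2000,
    "memory_per_class": 20,
    "fixed_memory": True,
    "shuffle": True,
    "init_cls": 50,
    "increment": 10,
    "model_name": "UCIR",
    "convnet_type": "cosine_resnet32",
    "device": ["0"],
    "seed": [30],
}

def expand_seeds(config):
    """Yield one flat args dict per seed, leaving the original config intact."""
    for seed in config["seed"]:
        args = copy.deepcopy(config)
        args["seed"] = seed  # _train formats this scalar into the log filename
        yield args

runs = list(expand_seeds(CONFIG))
```

Each dict in `runs` can then be passed to `_train` directly, so multi-seed reproduction is just a loop over `runs`.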
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.net = get_model(self.cfg.model)
        self.loss = get_loss(self.cfg.loss)
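# get_model and get_loss are factory lookups whose implementations are not
# part of this snippet. A minimal registry-based sketch of what such a
# factory might look like; all names below are illustrative assumptions,
# not the codebase's actual API.

```python
# Registry mapping lowercase model names to classes.
_MODEL_REGISTRY = {}

def register_model(name):
    """Decorator that records a model class under a string key."""
    def decorator(cls):
        _MODEL_REGISTRY[name.lower()] = cls
        return cls
    return decorator

def get_model(name):
    """Look up a registered model class by name and instantiate it."""
    try:
        return _MODEL_REGISTRY[name.lower()]()
    except KeyError:
        raise ValueError('Unknown model: {!r}'.format(name))

@register_model('toy')
class ToyModel:
    """Placeholder model used only to demonstrate the registry."""
    pass
```

A registry keeps model selection driven by the config string (`cfg.model`) without the caller importing every model class.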