Пример #1
0
# Mirror every log record into a file under ../../results/<args.model>/.
# NOTE: `log_path`, `log_format` and `args` are not defined in this fragment;
# they are expected to exist earlier in the script.
fh = logging.FileHandler(
    os.path.join('../../results/{}'.format(args.model), log_path))
fh.setFormatter(logging.Formatter(log_format))
# Attach to the root logger so all module-level logging.* calls reach the file.
logging.getLogger().addHandler(fh)

# Banner marking the start of this run in the log, followed by the parsed args.
for _banner_line in (
    "*************************************************************************************************",
    "                                         add_noise.py                                             ",
    "*************************************************************************************************",
):
    logging.info(_banner_line)
logging.info("args = %s", args)

# Build the continual-learning driver and its initial network.
method = continualNN.ContinualNN()
method.initial_single_network(init_weights=True)

# -----------------------------------------
# Prepare dataset
# -----------------------------------------

# create_task() yields two values here; only the first (the task list) is kept.
task_list, _ = method.create_task()
logging.info('Task list: %s', task_list)

# Parse the comma-separated division string, e.g. "5,5" -> [5, 5].
task_division = [int(item) for item in args.task_division.split(",")]
total_task = len(task_division)
logging.info('task_division %s', task_division)
if args.dataset == 'cifar10':
# Select which GPU(s) this process may use.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# Ask cuDNN for deterministic algorithms and seed every RNG the script uses.
torch.backends.cudnn.deterministic = True
np.random.seed(args.seed)  # NumPy global RNG.
random.seed(args.seed)     # Python random module.
torch.manual_seed(args.seed)

# Log file name, joined onto the results directory below.
log_path = './log.txt'

# Log to stdout (with a short date format) and to ../results/log.txt.
# NOTE(review): the file handler's Formatter has no datefmt, so timestamps in
# the file use logging's default format, unlike stdout — confirm intended.
log_format = '%(asctime)s   %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M%p')
fh = logging.FileHandler(os.path.join('../results/', log_path))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
logging.info("args = %s", args)

secondTask = continualNN.ContinualNN()

# NOTE(review): elsewhere in this file create_task() is unpacked into two values
# (`task_list, _ = method.create_task()`), but here its result is indexed
# directly. If this API returns a pair, task_list[1] below is the second
# returned value rather than task 1 — confirm against continualNN.
task_list = secondTask.create_task()

# Get KD targets: build the task-0 network, then call initialization() with the
# mutant learning-rate / step-size / weight-decay arguments (presumably to set
# up optimisation for that network — verify in continualNN).
logging.info("========================================== Get KD targets ==============================================")
model_NA = secondTask.initial_network(task_id=0)
secondTask.initialization(
    args.lr_mutant,
    args.lr_mutant_step_size,
    args.weight_decay_2,
)
# Disabled in the original script:
# secondTask.load_mutant(0, 0, model_NA)

# Load the data for task 1; the second argument is the negated
# classes_per_task value (its meaning is defined by get_dataset_cifar).
task_id = 1
current_task = task_list[task_id]
partial_trainsetLoader, partial_testsetLoader = get_dataset_cifar(
    current_task, -1 * args.classes_per_task)

# Sanity check: print the first ten labels of the first training batch only.
for batch_idx, (data, target) in enumerate(partial_trainsetLoader):
    print('batch {} Current task:{}'.format(batch_idx, target[0:10]))
    break