def main(primitives):
    """Run the architecture search and write its results to ``args.save``.

    Builds the search network, its weight optimizer, the architect, the
    cosine LR scheduler, the analyser and the eigenvalue tracker, optionally
    restores them from ``args.resume_file``, runs ``train_epochs`` and
    finally dumps the error curves, the selected genotype and the schedule
    of (iteration, epoch, weight decay, eigenvalue) records.

    Args:
        primitives: passed through unchanged to ``Network`` and ``train``.

    Exits the process with status 1 when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    model_init = Network(args.init_channels,
                         args.n_classes,
                         args.layers,
                         criterion,
                         primitives,
                         steps=args.nodes)
    model_init = model_init.cuda()

    optimizer_init = torch.optim.SGD(model_init.parameters(),
                                     args.learning_rate,
                                     momentum=args.momentum,
                                     weight_decay=args.weight_decay)

    architect_init = Architect(model_init, args)

    scheduler_init = CosineAnnealingLR(optimizer_init,
                                       float(args.epochs),
                                       eta_min=args.learning_rate_min)

    analyser_init = Analyzer(args, model_init)
    la_tracker = utils.EVLocalAvg(args.window, args.report_freq_hessian,
                                  args.epochs)

    # Index of the last epoch considered done; training starts at
    # start_epoch + 1, so -1 means "start from epoch 0".
    start_epoch = -1

    if args.resume:
        if os.path.isfile(args.resume_file):
            print("=> loading checkpoint '{}'".format(args.resume_file))
            # NOTE: torch.load unpickles the file; only resume from
            # checkpoints you trust.
            checkpoint = torch.load(args.resume_file)
            # The per-epoch state saved below does not store the epoch
            # index, so it cannot be read back from the checkpoint. Keep the
            # historical value (27) as the default but let an optional
            # ``args.resume_epoch`` override it.
            start_epoch = getattr(args, 'resume_epoch', 27)
            print('start_epoch', start_epoch)
            model_init.load_state_dict(checkpoint['state_dict'])
            model_init.alphas_normal.data = checkpoint['alphas_normal']
            model_init.alphas_reduce.data = checkpoint['alphas_reduce']
            model_init = model_init.cuda()
            logging.info("param size = %fMB",
                         utils.count_parameters_in_MB(model_init))
            optimizer_init.load_state_dict(checkpoint['optimizer'])
            architect_init.optimizer.load_state_dict(
                checkpoint['arch_optimizer'])
            # Rebuild the helpers on top of the restored model/optimizer.
            scheduler_init = CosineAnnealingLR(optimizer_init,
                                               float(args.epochs),
                                               eta_min=args.learning_rate_min)
            analyser_init = Analyzer(args, model_init)
            la_tracker = utils.EVLocalAvg(args.window,
                                          args.report_freq_hessian,
                                          args.epochs)
            la_tracker.ev = checkpoint['ev']
            la_tracker.ev_local_avg = checkpoint['ev_local_avg']
            la_tracker.genotypes = checkpoint['genotypes']
            la_tracker.la_epochs = checkpoint['la_epochs']
            la_tracker.la_start_idx = checkpoint['la_start_idx']
            la_tracker.la_end_idx = checkpoint['la_end_idx']
        else:
            # Previously ignored silently; training still proceeds from
            # scratch, but the operator is told why nothing was restored.
            logging.warning("=> no checkpoint found at '%s'",
                            args.resume_file)

    train_queue, valid_queue, train_transform, valid_transform = \
        helper.get_train_val_loaders()

    errors_dict = {
        'train_acc': [],
        'train_loss': [],
        'valid_acc': [],
        'valid_loss': []
    }

    def train_epochs(epochs_to_train,
                     iteration,
                     args=args,
                     model=model_init,
                     optimizer=optimizer_init,
                     scheduler=scheduler_init,
                     train_queue=train_queue,
                     valid_queue=valid_queue,
                     train_transform=train_transform,
                     valid_transform=valid_transform,
                     architect=architect_init,
                     criterion=criterion,
                     primitives=primitives,
                     analyser=analyser_init,
                     la_tracker=la_tracker,
                     errors_dict=errors_dict,
                     start_epoch=-1):
        """Train epochs ``start_epoch + 1 .. epochs_to_train - 1``.

        The defaults bind the objects built in ``main`` at definition time;
        ``errors_dict`` is deliberately shared so every (recursive) call
        appends to the same record.

        Returns:
            ``(genotype, valid_acc)``; ``(model.genotype(), -1)`` when there
            is no epoch left to train.
        """
        logging.info('STARTING ITERATION: %d', iteration)
        logging.info('EPOCHS TO TRAIN: %d', epochs_to_train - start_epoch - 1)

        la_tracker.stop_search = False

        if epochs_to_train - start_epoch - 1 <= 0:
            return model.genotype(), -1

        for epoch in range(start_epoch + 1, epochs_to_train):
            scheduler.step(epoch)
            lr = scheduler.get_lr()[0]
            if args.drop_path_prob != 0:
                # Both probabilities are scaled linearly with the epoch.
                model.drop_path_prob = args.drop_path_prob * epoch / (
                    args.epochs - 1)
                train_transform.transforms[
                    -1].cutout_prob = args.cutout_prob * epoch / (
                        args.epochs - 1)
                logging.info('epoch %d lr %e drop_prob %e cutout_prob %e',
                             epoch, lr, model.drop_path_prob,
                             train_transform.transforms[-1].cutout_prob)
            else:
                logging.info('epoch %d lr %e', epoch, lr)

            # training
            train_acc, train_obj = train(epoch, primitives, train_queue,
                                         valid_queue, model, architect,
                                         criterion, optimizer, lr, analyser,
                                         la_tracker, iteration)
            logging.info('train_acc %f', train_acc)

            # validation
            valid_acc, valid_obj = infer(valid_queue, model, criterion)
            logging.info('valid_acc %f', valid_acc)

            # The '*_acc' lists store errors (100 - accuracy).
            errors_dict['train_acc'].append(100 - train_acc)
            errors_dict['train_loss'].append(train_obj)
            errors_dict['valid_acc'].append(100 - valid_acc)
            errors_dict['valid_loss'].append(valid_obj)

            genotype = model.genotype()
            logging.info('genotype = %s', genotype)

            print(F.softmax(model.alphas_normal, dim=-1))
            print(F.softmax(model.alphas_reduce, dim=-1))

            state = {
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'alphas_normal': model.alphas_normal.data,
                'alphas_reduce': model.alphas_reduce.data,
                'arch_optimizer': architect.optimizer.state_dict(),
                'lr': lr,
                'ev': la_tracker.ev,
                'ev_local_avg': la_tracker.ev_local_avg,
                'genotypes': la_tracker.genotypes,
                'la_epochs': la_tracker.la_epochs,
                'la_start_idx': la_tracker.la_start_idx,
                'la_end_idx': la_tracker.la_end_idx,
            }

            utils.save_checkpoint(state, False, args.save, epoch,
                                  args.task_id)

            if not args.compute_hessian:
                ev = -1
            else:
                ev = la_tracker.ev[-1]
            params = {
                'iteration': iteration,
                'epoch': epoch,
                'wd': args.weight_decay,
                'ev': ev,
            }
            schedule_of_params.append(params)

            # Limit the number of iterations based on the maximum
            # regularization value predefined by the user: true once the
            # ratio of the two logs rounds to 1.0 at one decimal place.
            final_iteration = round(
                np.log(args.max_weight_decay) / np.log(args.weight_decay),
                1) == 1.

            # Original note (translated): stop once the lr decay has reached
            # a certain degree. The condition tested here is the tracker's
            # ``stop_search`` flag.
            if la_tracker.stop_search and not final_iteration:
                if args.early_stop == 1:
                    # set the following to the values they had at stop_epoch
                    errors_dict['valid_acc'] = errors_dict[
                        'valid_acc'][:la_tracker.stop_epoch + 1]
                    genotype = la_tracker.stop_genotype
                    valid_acc = 100 - errors_dict['valid_acc'][
                        la_tracker.stop_epoch]
                    logging.info(
                        'Decided to stop the search at epoch %d (Current epoch: %d)',
                        la_tracker.stop_epoch, epoch)
                    logging.info('Validation accuracy at stop epoch: %f',
                                 valid_acc)
                    logging.info('Genotype at stop epoch: %s', genotype)
                    break

                elif args.early_stop == 2:
                    # simulate early stopping and continue search afterwards
                    simulated_errors_dict = errors_dict[
                        'valid_acc'][:la_tracker.stop_epoch + 1]
                    simulated_genotype = la_tracker.stop_genotype
                    simulated_valid_acc = 100 - simulated_errors_dict[
                        la_tracker.stop_epoch]
                    logging.info(
                        '(SIM) Decided to stop the search at epoch %d (Current epoch: %d)',
                        la_tracker.stop_epoch, epoch)
                    logging.info('(SIM) Validation accuracy at stop epoch: %f',
                                 simulated_valid_acc)
                    logging.info('(SIM) Genotype at stop epoch: %s',
                                 simulated_genotype)

                    with open(
                            os.path.join(args.save,
                                         'arch_early_{}'.format(args.task_id)),
                            'w') as file:
                        file.write(str(simulated_genotype))

                    utils.write_yaml_results(args,
                                             'early_' + args.results_file_arch,
                                             str(simulated_genotype))
                    utils.write_yaml_results(args, 'early_stop_epochs',
                                             la_tracker.stop_epoch)

                    # Only simulate once; later stop signals are ignored.
                    args.early_stop = 0

                elif args.early_stop == 3:
                    # adjust regularization
                    simulated_errors_dict = errors_dict[
                        'valid_acc'][:la_tracker.stop_epoch + 1]
                    simulated_genotype = la_tracker.stop_genotype
                    simulated_valid_acc = 100 - simulated_errors_dict[
                        la_tracker.stop_epoch]
                    stop_epoch = la_tracker.stop_epoch
                    start_again_epoch = stop_epoch - args.extra_rollback_epochs
                    logging.info(
                        '(ADA) Decided to increase regularization at epoch %d (Current epoch: %d)',
                        stop_epoch, epoch)
                    logging.info('(ADA) Rolling back to epoch %d',
                                 start_again_epoch)
                    logging.info(
                        '(ADA) Restoring model parameters and continuing for %d epochs',
                        epochs_to_train - start_again_epoch - 1)

                    if iteration == 1:
                        logging.info(
                            '(ADA) Saving the architecture at the early stop epoch and '
                            'continuing with the adaptive regularization strategy'
                        )
                        utils.write_yaml_results(
                            args, 'early_' + args.results_file_arch,
                            str(simulated_genotype))

                    # Drop the local references before building replacements.
                    del model
                    del architect
                    del optimizer
                    del scheduler
                    del analyser

                    model_new = Network(args.init_channels,
                                        args.n_classes,
                                        args.layers,
                                        criterion,
                                        primitives,
                                        steps=args.nodes)
                    model_new = model_new.cuda()

                    optimizer_new = torch.optim.SGD(
                        model_new.parameters(),
                        args.learning_rate,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)

                    architect_new = Architect(model_new, args)

                    analyser_new = Analyzer(args, model_new)

                    # NOTE(review): this restored tracker is rebound locally
                    # but not handed to the recursive call below, which
                    # therefore keeps using its default tracker. Behaviour is
                    # left as-is; confirm whether ``la_tracker=la_tracker``
                    # should be passed.
                    la_tracker = utils.EVLocalAvg(args.window,
                                                  args.report_freq_hessian,
                                                  args.epochs)

                    lr = utils.load_checkpoint(model_new, optimizer_new, None,
                                               architect_new, args.save,
                                               la_tracker, start_again_epoch,
                                               args.task_id)

                    args.weight_decay *= args.mul_factor
                    for param_group in optimizer_new.param_groups:
                        param_group['weight_decay'] = args.weight_decay

                    scheduler_new = CosineAnnealingLR(
                        optimizer_new,
                        float(args.epochs),
                        eta_min=args.learning_rate_min)

                    logging.info('(ADA) Validation accuracy at stop epoch: %f',
                                 simulated_valid_acc)
                    logging.info('(ADA) Genotype at stop epoch: %s',
                                 simulated_genotype)
                    logging.info(
                        '(ADA) Adjusting L2 regularization to the new value: %f',
                        args.weight_decay)

                    # Continue right after the epoch whose checkpoint was
                    # just restored. Passing the caller's own ``start_epoch``
                    # would restart the epoch counter (and LR schedule) from
                    # the beginning, contradicting the rollback logged above.
                    genotype, valid_acc = train_epochs(
                        args.epochs,
                        iteration + 1,
                        model=model_new,
                        optimizer=optimizer_new,
                        architect=architect_new,
                        scheduler=scheduler_new,
                        analyser=analyser_new,
                        start_epoch=start_again_epoch)
                    args.early_stop = 0
                    break

        return genotype, valid_acc

    # call train_epochs recursively; on resume, continue after start_epoch
    # instead of retraining from epoch 0 on top of restored state.
    genotype, valid_acc = train_epochs(args.epochs, 1,
                                       start_epoch=start_epoch)

    with codecs.open(os.path.join(args.save,
                                  'errors_{}.json'.format(args.task_id)),
                     'w',
                     encoding='utf-8') as file:
        json.dump(errors_dict, file, separators=(',', ':'))

    with open(os.path.join(args.save, 'arch_{}'.format(args.task_id)),
              'w') as file:
        file.write(str(genotype))

    utils.write_yaml_results(args, args.results_file_arch, str(genotype))
    utils.write_yaml_results(args, args.results_file_perf, 100 - valid_acc)

    # Opened in append mode: repeated runs add further pickled records.
    with open(
            os.path.join(args.save,
                         'schedule_{}.pickle'.format(args.task_id)),
            'ab') as file:
        pickle.dump(schedule_of_params, file, pickle.HIGHEST_PROTOCOL)
def train_epochs(epochs_to_train,
                 iteration,
                 args=args,
                 model=model_init,
                 optimizer=optimizer_init,
                 scheduler=scheduler_init,
                 train_queue=train_queue,
                 valid_queue=valid_queue,
                 train_transform=train_transform,
                 valid_transform=valid_transform,
                 architect=architect_init,
                 criterion=criterion,
                 primitives=primitives,
                 analyser=analyser_init,
                 la_tracker=la_tracker,
                 errors_dict=errors_dict,
                 start_epoch=-1):
    """Train epochs ``start_epoch + 1 .. epochs_to_train - 1``.

    Each epoch: steps the scheduler, runs ``train`` and ``infer``, appends
    the errors to ``errors_dict``, saves a checkpoint and records the
    schedule entry. When the tracker raises ``stop_search`` (and the weight
    decay has not reached its maximum), the action depends on
    ``args.early_stop``: 1 stops, 2 records the early-stop architecture and
    continues, 3 rolls back to an earlier checkpoint, multiplies the weight
    decay by ``args.mul_factor`` and recurses.

    The default arguments are evaluated when this ``def`` executes, so the
    referenced names must already exist in the enclosing scope.
    ``errors_dict`` is shared across (recursive) calls on purpose.

    Returns:
        ``(genotype, valid_acc)``; ``(model.genotype(), -1)`` when there is
        no epoch left to train.
    """
    logging.info('STARTING ITERATION: %d', iteration)
    logging.info('EPOCHS TO TRAIN: %d', epochs_to_train - start_epoch - 1)

    la_tracker.stop_search = False

    if epochs_to_train - start_epoch - 1 <= 0:
        return model.genotype(), -1

    for epoch in range(start_epoch + 1, epochs_to_train):
        scheduler.step(epoch)
        lr = scheduler.get_lr()[0]
        if args.drop_path_prob != 0:
            # Both probabilities are scaled linearly with the epoch.
            model.drop_path_prob = args.drop_path_prob * epoch / (
                args.epochs - 1)
            train_transform.transforms[
                -1].cutout_prob = args.cutout_prob * epoch / (args.epochs - 1)
            logging.info('epoch %d lr %e drop_prob %e cutout_prob %e', epoch,
                         lr, model.drop_path_prob,
                         train_transform.transforms[-1].cutout_prob)
        else:
            logging.info('epoch %d lr %e', epoch, lr)

        # training
        train_acc, train_obj = train(epoch, primitives, train_queue,
                                     valid_queue, model, architect, criterion,
                                     optimizer, lr, analyser, la_tracker,
                                     iteration)
        logging.info('train_acc %f', train_acc)

        # validation
        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        # The '*_acc' lists store errors (100 - accuracy).
        errors_dict['train_acc'].append(100 - train_acc)
        errors_dict['train_loss'].append(train_obj)
        errors_dict['valid_acc'].append(100 - valid_acc)
        errors_dict['valid_loss'].append(valid_obj)

        genotype = model.genotype()
        logging.info('genotype = %s', genotype)

        print(F.softmax(model.alphas_normal, dim=-1))
        print(F.softmax(model.alphas_reduce, dim=-1))

        state = {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'alphas_normal': model.alphas_normal.data,
            'alphas_reduce': model.alphas_reduce.data,
            'arch_optimizer': architect.optimizer.state_dict(),
            'lr': lr,
            'ev': la_tracker.ev,
            'ev_local_avg': la_tracker.ev_local_avg,
            'genotypes': la_tracker.genotypes,
            'la_epochs': la_tracker.la_epochs,
            'la_start_idx': la_tracker.la_start_idx,
            'la_end_idx': la_tracker.la_end_idx,
        }

        utils.save_checkpoint(state, False, args.save, epoch, args.task_id)

        if not args.compute_hessian:
            ev = -1
        else:
            ev = la_tracker.ev[-1]
        params = {
            'iteration': iteration,
            'epoch': epoch,
            'wd': args.weight_decay,
            'ev': ev,
        }
        schedule_of_params.append(params)

        # Limit the number of iterations based on the maximum regularization
        # value predefined by the user: true once the ratio of the two logs
        # rounds to 1.0 at one decimal place.
        final_iteration = round(
            np.log(args.max_weight_decay) / np.log(args.weight_decay),
            1) == 1.

        # Original note (translated): stop once the lr decay has reached a
        # certain degree. The condition tested here is the tracker's
        # ``stop_search`` flag.
        if la_tracker.stop_search and not final_iteration:
            if args.early_stop == 1:
                # set the following to the values they had at stop_epoch
                errors_dict['valid_acc'] = errors_dict[
                    'valid_acc'][:la_tracker.stop_epoch + 1]
                genotype = la_tracker.stop_genotype
                valid_acc = 100 - errors_dict['valid_acc'][
                    la_tracker.stop_epoch]
                logging.info(
                    'Decided to stop the search at epoch %d (Current epoch: %d)',
                    la_tracker.stop_epoch, epoch)
                logging.info('Validation accuracy at stop epoch: %f',
                             valid_acc)
                logging.info('Genotype at stop epoch: %s', genotype)
                break

            elif args.early_stop == 2:
                # simulate early stopping and continue search afterwards
                simulated_errors_dict = errors_dict[
                    'valid_acc'][:la_tracker.stop_epoch + 1]
                simulated_genotype = la_tracker.stop_genotype
                simulated_valid_acc = 100 - simulated_errors_dict[
                    la_tracker.stop_epoch]
                logging.info(
                    '(SIM) Decided to stop the search at epoch %d (Current epoch: %d)',
                    la_tracker.stop_epoch, epoch)
                logging.info('(SIM) Validation accuracy at stop epoch: %f',
                             simulated_valid_acc)
                logging.info('(SIM) Genotype at stop epoch: %s',
                             simulated_genotype)

                with open(
                        os.path.join(args.save,
                                     'arch_early_{}'.format(args.task_id)),
                        'w') as file:
                    file.write(str(simulated_genotype))

                utils.write_yaml_results(args,
                                         'early_' + args.results_file_arch,
                                         str(simulated_genotype))
                utils.write_yaml_results(args, 'early_stop_epochs',
                                         la_tracker.stop_epoch)

                # Only simulate once; later stop signals are ignored.
                args.early_stop = 0

            elif args.early_stop == 3:
                # adjust regularization
                simulated_errors_dict = errors_dict[
                    'valid_acc'][:la_tracker.stop_epoch + 1]
                simulated_genotype = la_tracker.stop_genotype
                simulated_valid_acc = 100 - simulated_errors_dict[
                    la_tracker.stop_epoch]
                stop_epoch = la_tracker.stop_epoch
                start_again_epoch = stop_epoch - args.extra_rollback_epochs
                logging.info(
                    '(ADA) Decided to increase regularization at epoch %d (Current epoch: %d)',
                    stop_epoch, epoch)
                logging.info('(ADA) Rolling back to epoch %d',
                             start_again_epoch)
                logging.info(
                    '(ADA) Restoring model parameters and continuing for %d epochs',
                    epochs_to_train - start_again_epoch - 1)

                if iteration == 1:
                    logging.info(
                        '(ADA) Saving the architecture at the early stop epoch and '
                        'continuing with the adaptive regularization strategy'
                    )
                    utils.write_yaml_results(
                        args, 'early_' + args.results_file_arch,
                        str(simulated_genotype))

                # Drop the local references before building replacements.
                del model
                del architect
                del optimizer
                del scheduler
                del analyser

                model_new = Network(args.init_channels,
                                    args.n_classes,
                                    args.layers,
                                    criterion,
                                    primitives,
                                    steps=args.nodes)
                model_new = model_new.cuda()

                optimizer_new = torch.optim.SGD(
                    model_new.parameters(),
                    args.learning_rate,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)

                architect_new = Architect(model_new, args)

                analyser_new = Analyzer(args, model_new)

                # NOTE(review): this restored tracker is rebound locally but
                # not handed to the recursive call below, which therefore
                # keeps using its default tracker. Behaviour is left as-is;
                # confirm whether ``la_tracker=la_tracker`` should be passed.
                la_tracker = utils.EVLocalAvg(args.window,
                                              args.report_freq_hessian,
                                              args.epochs)

                lr = utils.load_checkpoint(model_new, optimizer_new, None,
                                           architect_new, args.save,
                                           la_tracker, start_again_epoch,
                                           args.task_id)

                args.weight_decay *= args.mul_factor
                for param_group in optimizer_new.param_groups:
                    param_group['weight_decay'] = args.weight_decay

                scheduler_new = CosineAnnealingLR(
                    optimizer_new,
                    float(args.epochs),
                    eta_min=args.learning_rate_min)

                logging.info('(ADA) Validation accuracy at stop epoch: %f',
                             simulated_valid_acc)
                logging.info('(ADA) Genotype at stop epoch: %s',
                             simulated_genotype)
                logging.info(
                    '(ADA) Adjusting L2 regularization to the new value: %f',
                    args.weight_decay)

                # Continue right after the epoch whose checkpoint was just
                # restored. Passing the caller's own ``start_epoch`` would
                # restart the epoch counter (and LR schedule) from the
                # beginning, contradicting the rollback logged above.
                genotype, valid_acc = train_epochs(
                    args.epochs,
                    iteration + 1,
                    model=model_new,
                    optimizer=optimizer_new,
                    architect=architect_new,
                    scheduler=scheduler_new,
                    analyser=analyser_new,
                    start_epoch=start_again_epoch)
                args.early_stop = 0
                break

    return genotype, valid_acc