def train_epochs(epochs_to_train, iteration, args=args, model=model_init,
                 optimizer=optimizer_init, scheduler=scheduler_init,
                 train_queue=train_queue, valid_queue=valid_queue,
                 train_transform=train_transform,
                 valid_transform=valid_transform, architect=architect_init,
                 criterion=criterion, primitives=primitives,
                 analyser=analyser_init, la_tracker=la_tracker,
                 errors_dict=errors_dict, start_epoch=-1):
    """Run the architecture search loop from ``start_epoch + 1`` up to
    ``epochs_to_train``, with eigenvalue-based early stopping.

    NOTE(review): every keyword default is bound at *definition* time to a
    module-level object (``args``, ``model_init``, ...); callers (including
    the recursive call below) override only a subset of them.

    Depending on ``args.early_stop`` the function either:
      1 - stops the search at the epoch flagged by ``la_tracker``;
      2 - only *simulates* early stopping (logs/saves the genotype) and
          continues searching;
      3 - rolls back to a previous checkpoint, increases the L2
          regularization by ``args.mul_factor`` and recurses with
          ``iteration + 1``.

    Returns:
        (genotype, valid_acc) tuple; ``(model.genotype(), -1)`` if there is
        nothing left to train.
    """
    logging.info('STARTING ITERATION: %d', iteration)
    logging.info('EPOCHS TO TRAIN: %d', epochs_to_train - start_epoch - 1)
    la_tracker.stop_search = False
    # Nothing left to train for this iteration (e.g. after a rollback that
    # landed on the last epoch): return the current genotype immediately.
    if epochs_to_train - start_epoch - 1 <= 0:
        return model.genotype(), -1
    for epoch in range(start_epoch + 1, epochs_to_train):
        # set the epoch to the right one
        #epoch += args.epochs - epochs_to_train
        scheduler.step(epoch)
        lr = scheduler.get_lr()[0]
        if args.drop_path_prob != 0:
            # Linearly ramp drop-path and cutout probabilities over the full
            # args.epochs budget (not just epochs_to_train).
            model.drop_path_prob = args.drop_path_prob * epoch / (
                args.epochs - 1)
            # assumes the *last* transform in the train pipeline is the
            # cutout transform — TODO confirm against helper.get_train_val_loaders
            train_transform.transforms[
                -1].cutout_prob = args.cutout_prob * epoch / (args.epochs - 1)
            logging.info('epoch %d lr %e drop_prob %e cutout_prob %e', epoch,
                         lr, model.drop_path_prob,
                         train_transform.transforms[-1].cutout_prob)
        else:
            logging.info('epoch %d lr %e', epoch, lr)
        # training
        train_acc, train_obj = train(epoch, primitives, train_queue,
                                     valid_queue, model, architect, criterion,
                                     optimizer, lr, analyser, la_tracker,
                                     iteration)
        logging.info('train_acc %f', train_acc)
        # validation
        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)
        # update the errors dictionary (stored as error = 100 - accuracy)
        errors_dict['train_acc'].append(100 - train_acc)
        errors_dict['train_loss'].append(train_obj)
        errors_dict['valid_acc'].append(100 - valid_acc)
        errors_dict['valid_loss'].append(valid_obj)
        genotype = model.genotype()
        logging.info('genotype = %s', genotype)
        print(F.softmax(model.alphas_normal, dim=-1))
        print(F.softmax(model.alphas_reduce, dim=-1))
        # Checkpoint everything needed to resume/rollback: weights, both
        # optimizers, architecture parameters and the eigenvalue tracker.
        state = {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'alphas_normal': model.alphas_normal.data,
            'alphas_reduce': model.alphas_reduce.data,
            'arch_optimizer': architect.optimizer.state_dict(),
            'lr': lr,
            'ev': la_tracker.ev,
            'ev_local_avg': la_tracker.ev_local_avg,
            'genotypes': la_tracker.genotypes,
            'la_epochs': la_tracker.la_epochs,
            'la_start_idx': la_tracker.la_start_idx,
            'la_end_idx': la_tracker.la_end_idx,
            #'scheduler': scheduler.state_dict(),
        }
        utils.save_checkpoint(state, False, args.save, epoch, args.task_id)
        # ev = dominant Hessian eigenvalue of this epoch (-1 if not computed)
        if not args.compute_hessian:
            ev = -1
        else:
            ev = la_tracker.ev[-1]
        params = {
            'iteration': iteration,
            'epoch': epoch,
            'wd': args.weight_decay,
            'ev': ev,
        }
        # NOTE(review): schedule_of_params is a module-level list — verify it
        # is initialized before the first call to train_epochs.
        schedule_of_params.append(params)
        # limit the number of iterations based on the maximum regularization
        # value predefined by the user
        final_iteration = round(
            np.log(args.max_weight_decay) / np.log(args.weight_decay),
            1) == 1.
        # (translated) stop once the regularization has decayed to its limit
        if la_tracker.stop_search and not final_iteration:
            if args.early_stop == 1:
                # set the following to the values they had at stop_epoch
                errors_dict['valid_acc'] = errors_dict[
                    'valid_acc'][:la_tracker.stop_epoch + 1]
                genotype = la_tracker.stop_genotype
                valid_acc = 100 - errors_dict['valid_acc'][
                    la_tracker.stop_epoch]
                logging.info(
                    'Decided to stop the search at epoch %d (Current epoch: %d)',
                    la_tracker.stop_epoch, epoch)
                logging.info('Validation accuracy at stop epoch: %f',
                             valid_acc)
                logging.info('Genotype at stop epoch: %s', genotype)
                break
            elif args.early_stop == 2:
                # simulate early stopping and continue search afterwards
                simulated_errors_dict = errors_dict[
                    'valid_acc'][:la_tracker.stop_epoch + 1]
                simulated_genotype = la_tracker.stop_genotype
                simulated_valid_acc = 100 - simulated_errors_dict[
                    la_tracker.stop_epoch]
                logging.info(
                    '(SIM) Decided to stop the search at epoch %d (Current epoch: %d)',
                    la_tracker.stop_epoch, epoch)
                logging.info('(SIM) Validation accuracy at stop epoch: %f',
                             simulated_valid_acc)
                logging.info('(SIM) Genotype at stop epoch: %s',
                             simulated_genotype)
                with open(
                        os.path.join(args.save,
                                     'arch_early_{}'.format(args.task_id)),
                        'w') as file:
                    file.write(str(simulated_genotype))
                utils.write_yaml_results(args,
                                         'early_' + args.results_file_arch,
                                         str(simulated_genotype))
                utils.write_yaml_results(args, 'early_stop_epochs',
                                         la_tracker.stop_epoch)
                # only simulate once per run
                args.early_stop = 0
            elif args.early_stop == 3:
                # adjust regularization: roll back a few epochs, raise the
                # weight decay, and restart the search from the checkpoint.
                simulated_errors_dict = errors_dict[
                    'valid_acc'][:la_tracker.stop_epoch + 1]
                simulated_genotype = la_tracker.stop_genotype
                simulated_valid_acc = 100 - simulated_errors_dict[
                    la_tracker.stop_epoch]
                stop_epoch = la_tracker.stop_epoch
                start_again_epoch = stop_epoch - args.extra_rollback_epochs
                logging.info(
                    '(ADA) Decided to increase regularization at epoch %d (Current epoch: %d)',
                    stop_epoch, epoch)
                logging.info('(ADA) Rolling back to epoch %d',
                             start_again_epoch)
                logging.info(
                    '(ADA) Restoring model parameters and continuing for %d epochs',
                    epochs_to_train - start_again_epoch - 1)
                if iteration == 1:
                    logging.info(
                        '(ADA) Saving the architecture at the early stop epoch and '
                        'continuing with the adaptive regularization strategy'
                    )
                    utils.write_yaml_results(
                        args, 'early_' + args.results_file_arch,
                        str(simulated_genotype))
                # free GPU memory before rebuilding everything from scratch
                del model
                del architect
                del optimizer
                del scheduler
                del analyser
                model_new = Network(args.init_channels, args.n_classes,
                                    args.layers, criterion, primitives,
                                    steps=args.nodes)
                model_new = model_new.cuda()
                optimizer_new = torch.optim.SGD(
                    model_new.parameters(), args.learning_rate,
                    momentum=args.momentum, weight_decay=args.weight_decay)
                architect_new = Architect(model_new, args)
                analyser_new = Analyzer(args, model_new)
                # fresh eigenvalue tracker for the new iteration
                la_tracker = utils.EVLocalAvg(args.window,
                                              args.report_freq_hessian,
                                              args.epochs)
                lr = utils.load_checkpoint(model_new, optimizer_new, None,
                                           architect_new, args.save,
                                           la_tracker, start_again_epoch,
                                           args.task_id)
                # NOTE(review): mutates the shared args object — the new
                # weight decay persists across recursive calls by design.
                args.weight_decay *= args.mul_factor
                for param_group in optimizer_new.param_groups:
                    param_group['weight_decay'] = args.weight_decay
                scheduler_new = CosineAnnealingLR(
                    optimizer_new, float(args.epochs),
                    eta_min=args.learning_rate_min)
                logging.info('(ADA) Validation accuracy at stop epoch: %f',
                             simulated_valid_acc)
                logging.info('(ADA) Genotype at stop epoch: %s',
                             simulated_genotype)
                logging.info(
                    '(ADA) Adjusting L2 regularization to the new value: %f',
                    args.weight_decay)
                # NOTE(review): the recursion passes start_epoch (this call's
                # own start), not start_again_epoch (the checkpoint that was
                # just loaded) — confirm this is intentional.
                genotype, valid_acc = train_epochs(args.epochs, iteration + 1,
                                                   model=model_new,
                                                   optimizer=optimizer_new,
                                                   architect=architect_new,
                                                   scheduler=scheduler_new,
                                                   analyser=analyser_new,
                                                   start_epoch=start_epoch)
                args.early_stop = 0
                break
    return genotype, valid_acc
def main(primitives):
    """Reload the search checkpoints epoch by epoch and, when requested,
    recompute the Hessian of the validation loss w.r.t. the architecture
    parameters, appending the results to ``derivatives_{task_id}.json``.

    Args:
        primitives: candidate-operation set used to rebuild the search
            network with the same topology as the saved checkpoints.

    Side effects: builds model/optimizer/architect from module-level
    ``args``, loads checkpoints from ``args.save``, writes JSON derivative
    records, and logs the dominant eigenvalue per inspected epoch.
    """
    if not torch.cuda.is_available() or args.disable_cuda:
        logging.info('no gpu device available or disabling cuda')
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if not args.disable_cuda:
        torch.cuda.set_device(args.gpu)
        logging.info('gpu device = %d' % args.gpu)
        cudnn.benchmark = True
        cudnn.enabled = True
        torch.cuda.manual_seed(args.seed)
    criterion = nn.CrossEntropyLoss()
    if not args.disable_cuda:
        criterion = criterion.cuda()
    model_init = Network(args.init_channels, args.n_classes, args.layers,
                         criterion, primitives, steps=args.nodes, args=args)
    if not args.disable_cuda:
        model_init = model_init.cuda()
    logging.info("param size = %fMB",
                 utils.count_parameters_in_MB(model_init))
    optimizer_init = torch.optim.SGD(model_init.parameters(),
                                     args.learning_rate,
                                     momentum=args.momentum,
                                     weight_decay=args.weight_decay)
    architect_init = Architect(model_init, args)
    scheduler_init = CosineAnnealingLR(optimizer_init, float(args.epochs),
                                       eta_min=args.learning_rate_min)
    analyser = Analyzer(args, model_init)
    la_tracker = utils.EVLocalAvg(args.window, args.report_freq_hessian,
                                  args.epochs)
    train_queue, valid_queue, train_transform, valid_transform = \
        helper.get_train_val_loaders()

    def valid_generator():
        # Endless iterator over the validation loader so successive epochs
        # keep drawing fresh batches instead of restarting the loader.
        while True:
            for x, t in valid_queue:
                yield x, t

    valid_gen = valid_generator()
    for epoch in range(args.ev_start_epoch - 1, args.epochs):
        beta_decay_scheduler.step(epoch)
        logging.info("EPOCH %d SKIP BETA DECAY RATE: %e", epoch,
                     beta_decay_scheduler.decay_rate)
        # Only inspect checkpointed epochs (every report_freq_hessian epochs
        # plus the final one).
        if (epoch % args.report_freq_hessian == 0) or (epoch ==
                                                       (args.epochs - 1)):
            lr = utils.load_checkpoint(model_init, optimizer_init, None,
                                       architect_init, args.save, la_tracker,
                                       epoch, args.task_id)
            logging.info("Loaded %d-th checkpoint." % epoch)
            if args.test_infer:
                valid_acc, valid_obj = infer(valid_queue, model_init,
                                             criterion)
                logging.info('valid_acc %f', valid_acc)
            if args.compute_hessian:
                # One train batch and one validation batch are enough for the
                # (bi-level) Hessian/gradient estimates.
                train_input, train_target = next(iter(train_queue))
                train_input = Variable(train_input, requires_grad=False)
                train_target = Variable(train_target, requires_grad=False)
                input_search, target_search = next(
                    valid_gen)  #next(iter(valid_queue))
                input_search = Variable(input_search, requires_grad=False)
                target_search = Variable(target_search, requires_grad=False)
                if not args.disable_cuda:
                    train_input = train_input.cuda()
                    # BUG FIX: `.cuda(async=True)` is a SyntaxError on
                    # Python >= 3.7 (`async` became a keyword); the PyTorch
                    # replacement argument is `non_blocking`.
                    train_target = train_target.cuda(non_blocking=True)
                    input_search = input_search.cuda()
                    target_search = target_search.cuda(non_blocking=True)
                if not args.debug:
                    H = analyser.compute_Hw(train_input, train_target,
                                            input_search, target_search, lr,
                                            optimizer_init, False)
                    g = analyser.compute_dw(train_input, train_target,
                                            input_search, target_search, lr,
                                            optimizer_init, False)
                    g = torch.cat([x.view(-1) for x in g])
                    state = {
                        'epoch': epoch,
                        'H': H.cpu().data.numpy().tolist(),
                        'g': g.cpu().data.numpy().tolist(),
                        #'g_train': float(grad_norm),
                        #'eig_train': eigenvalue,
                    }
                    # Append one JSON record per line.
                    with codecs.open(os.path.join(
                            args.save,
                            'derivatives_{}.json'.format(args.task_id)),
                                     'a', encoding='utf-8') as fout:
                        json.dump(state, fout, separators=(',', ':'))
                        fout.write('\n')
                    # early stopping signal: dominant Hessian eigenvalue.
                    # NOTE(review): LA.eigvals can return complex values;
                    # max() over them relies on numpy's (deprecated) complex
                    # ordering — confirm H is symmetric here.
                    ev = max(LA.eigvals(H.cpu().data.numpy()))
                    logging.info('CURRENT EV: %f', ev)
def main():
    """Evaluate/train an ImageNet model for a fixed genotype.

    Builds the network from ``args.arch``, trains with label smoothing on
    the train split and reports top-1/top-5 accuracy on the val split,
    checkpointing the best model. Exits with status 1 if no GPU is present.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)
    # NOTE(review): eval() on a CLI-provided string; acceptable for trusted
    # local use, but args.arch should name an attribute of `genotypes`.
    genotype = eval("genotypes.%s" % args.arch)
    model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary,
                    genotype)
    if args.parallel:
        model = nn.DataParallel(model).cuda()
    else:
        model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    # Label smoothing is used for training; plain CE for validation.
    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    traindir = os.path.join(args.data, 'train')
    validdir = os.path.join(args.data, 'val')
    # standard ImageNet normalization constants
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_data = dset.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                   saturation=0.4, hue=0.2),
            transforms.ToTensor(),
            normalize,
        ]))
    valid_data = dset.ImageFolder(
        validdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              shuffle=True, pin_memory=True,
                                              num_workers=4)
    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=args.batch_size,
                                              shuffle=False, pin_memory=True,
                                              num_workers=4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.decay_period,
                                                gamma=args.gamma)
    if args.load:
        model, optimizer, start_epoch, best_acc_top1 = utils.load_checkpoint(
            model, optimizer,
            '../../experiments/sota/imagenet/eval/EXP-20200210-143540-c10_s3_pgd-0-auxiliary-0.4-2753'
        )
    else:
        best_acc_top1 = 0
        # BUG FIX: start_epoch was only defined in the args.load branch,
        # raising NameError on a fresh (non-resumed) run.
        start_epoch = 0
    for epoch in range(start_epoch, args.epochs):
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
        train_acc, train_obj = train(train_queue, model, criterion_smooth,
                                     optimizer)
        logging.info('train_acc %f', train_acc)
        writer.add_scalar('Acc/train', train_acc, epoch)
        writer.add_scalar('Obj/train', train_obj, epoch)
        scheduler.step()
        valid_acc_top1, valid_acc_top5, valid_obj = infer(
            valid_queue, model, criterion)
        logging.info('valid_acc_top1 %f', valid_acc_top1)
        logging.info('valid_acc_top5 %f', valid_acc_top5)
        writer.add_scalar('Acc/valid_top1', valid_acc_top1, epoch)
        writer.add_scalar('Acc/valid_top5', valid_acc_top5, epoch)
        is_best = False
        if valid_acc_top1 > best_acc_top1:
            best_acc_top1 = valid_acc_top1
            is_best = True
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_acc_top1': best_acc_top1,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save)