def main_worker(args):
    """Train (or only evaluate) the model described by ``args``.

    The function seeds torch, optionally initialises the distributed process
    group, builds the model named by ``args.model``, optionally restores a
    checkpoint, and then either runs a single validation pass
    (``args.evaluate``) or the epoch loop from ``args.start_epoch`` up to
    ``args.epochs``.

    Processes that are distributed workers with ``local_rank > 0`` still
    train and validate, but skip directory creation, checkpointing and
    result logging.

    Args:
        args: Parsed command-line namespace. It is mutated in place
            (``results_dir``, ``save``, ``distributed``, ``local_rank``,
            ``world_size``, ``device_ids``, ``model``, ``model_config``,
            ``start_epoch`` and ``eval_batch_size`` may be overwritten).

    Returns:
        None.

    Side effects:
        Rebinds the module globals ``best_prec1`` and ``dtype``, writes
        logs/results/checkpoints under ``save_path`` and prints a final
        "Target reached" summary line.
    """
    global best_prec1, dtype
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    # Strings are compared by value; identity comparison against a literal
    # is implementation-dependent and warns on Python 3.8+.
    if args.save == '':
        args.save = time_stamp
    save_path = path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_init,
                                world_size=args.world_size,
                                rank=args.local_rank)
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            args.device_ids = [args.local_rank]

    if not (args.distributed and args.local_rank > 0):
        if not path.exists(save_path):
            makedirs(save_path)
        export_args_namespace(args, path.join(save_path, 'config.json'))

    setup_logging(path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=args.distributed and args.local_rank > 0)

    results_path = path.join(save_path)
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    model = models.__dict__[args.model]
    model_config = {'dataset': args.dataset}

    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # optionally resume from a checkpoint
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    if args.evaluate:
        if not path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate, map_location="cpu")
        # Override configuration with checkpoint info
        args.model = checkpoint.get('model', args.model)
        args.model_config = checkpoint.get('config', args.model_config)
        # load checkpoint
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)",
                     args.evaluate, checkpoint['epoch'])

    # Bound up front so that a missing resume checkpoint (which is only
    # logged as an error) does not lead to a NameError further down.
    optim_state_dict = None
    if args.resume:
        checkpoint_file = args.resume
        if path.isdir(checkpoint_file):
            results.load(path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = path.join(checkpoint_file, 'model_best.pth.tar')
        if path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file, map_location="cpu")
            if args.start_epoch < 0:  # not explicitly set
                args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optim_state_dict = checkpoint.get('optim_state_dict', None)
            logging.info("loaded checkpoint '%s' (epoch %s)",
                         checkpoint_file, checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration
    optim_regime = getattr(model, 'regime', [{
        'epoch': 0,
        'optimizer': args.optimizer,
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }])

    optimizer = optim_regime if isinstance(optim_regime, OptimRegime) \
        else OptimRegime(model, optim_regime,
                         use_float_copy='half' in args.dtype)

    if optim_state_dict is not None:
        optimizer.load_state_dict(optim_state_dict)

    trainer = Trainer(model, criterion, optimizer,
                      device_ids=args.device_ids, device=args.device,
                      dtype=dtype, print_freq=args.print_freq,
                      distributed=args.distributed,
                      local_rank=args.local_rank,
                      mixup=args.mixup, cutmix=args.cutmix,
                      loss_scale=args.loss_scale, grad_clip=args.grad_clip,
                      adapt_grad_norm=args.adapt_grad_norm)
    if args.tensorwatch:
        trainer.set_watcher(
            filename=path.abspath(path.join(save_path, 'tensorwatch.log')),
            port=args.tensorwatch_port)

    # Evaluation Data loading code
    args.eval_batch_size = args.eval_batch_size \
        if args.eval_batch_size > 0 else args.batch_size
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={
                              'datasets_path': args.datasets_dir,
                              'name': args.dataset,
                              'split': 'val',
                              'augment': False,
                              'input_size': args.input_size,
                              'batch_size': args.eval_batch_size,
                              'shuffle': False,
                              'num_workers': args.workers,
                              'pin_memory': True,
                              'drop_last': False
                          })

    if args.evaluate:
        results = trainer.validate(val_data.get_loader())
        logging.info(results)
        return

    # Training Data loading code
    train_data_defaults = {
        'datasets_path': args.datasets_dir,
        'name': args.dataset,
        'split': 'train',
        'augment': True,
        'input_size': args.input_size,
        'batch_size': args.batch_size,
        'shuffle': True,
        'num_workers': args.workers,
        'pin_memory': True,
        'drop_last': True,
        'distributed': args.distributed,
        'duplicates': args.duplicates,
        'autoaugment': args.autoaugment,
        'cutout': {'holes': 1, 'length': 16} if args.cutout else None
    }

    if hasattr(model, 'sampled_data_regime'):
        sampled_data_regime = model.sampled_data_regime
        probs, regime_configs = zip(*sampled_data_regime)
        regimes = []
        for config in regime_configs:
            defaults = {**train_data_defaults}
            defaults.update(config)
            regimes.append(DataRegime(None, defaults=defaults))
        train_data = SampledDataRegime(regimes, probs)
    else:
        train_data = DataRegime(getattr(model, 'data_regime', None),
                                defaults=train_data_defaults)

    logging.info('optimization regime: %s', optim_regime)
    logging.info('data regime: %s', train_data)
    args.start_epoch = max(args.start_epoch, 0)
    trainer.training_steps = args.start_epoch * len(train_data)

    # Validation prec@1 (in percent) that counts as "target reached".
    target_prec1 = 94
    start_time = time.time()
    end_time = None
    end_epoch = None
    found = False
    # Tracked explicitly so the summary below also works when the epoch
    # range is empty (start_epoch >= epochs) and `epoch` is never bound.
    epochs_done = args.start_epoch

    for epoch in range(args.start_epoch, args.epochs):
        epochs_done = epoch + 1
        trainer.epoch = epoch
        train_data.set_epoch(epoch)
        val_data.set_epoch(epoch)
        logging.info('\nStarting Epoch: {0}\n'.format(epoch + 1))

        # train for one epoch
        train_results = trainer.train(train_data.get_loader(),
                                      chunk_batch=args.chunk_batch)

        # evaluate on validation set
        val_results = trainer.validate(val_data.get_loader())

        # Non-primary distributed workers do not checkpoint or log results.
        if args.distributed and args.local_rank > 0:
            continue

        # remember best prec@1 and save checkpoint
        is_best = val_results['prec1'] > best_prec1
        best_prec1 = max(val_results['prec1'], best_prec1)

        if args.drop_optim_state:
            optim_state_dict = None
        else:
            optim_state_dict = optimizer.state_dict()

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'config': args.model_config,
                'state_dict': model.state_dict(),
                'optim_state_dict': optim_state_dict,
                'best_prec1': best_prec1
            },
            is_best, path=save_path, save_all=args.save_all)

        logging.info('\nResults - Epoch: {0}\n'
                     'Training Loss {train[loss]:.4f} \t'
                     'Training Prec@1 {train[prec1]:.3f} \t'
                     'Training Prec@5 {train[prec5]:.3f} \t'
                     'Validation Loss {val[loss]:.4f} \t'
                     'Validation Prec@1 {val[prec1]:.3f} \t'
                     'Validation Prec@5 {val[prec5]:.3f} \t\n'.format(
                         epoch + 1, train=train_results, val=val_results))

        values = dict(epoch=epoch + 1, steps=trainer.training_steps)
        values.update({'training ' + k: v for k, v in train_results.items()})
        values.update({'validation ' + k: v for k, v in val_results.items()})
        results.add(**values)

        results.plot(x='epoch', y=['training loss', 'validation loss'],
                     legend=['training', 'validation'],
                     title='Loss', ylabel='loss')
        results.plot(x='epoch', y=['training error1', 'validation error1'],
                     legend=['training', 'validation'],
                     title='Error@1', ylabel='error %')
        results.plot(x='epoch', y=['training error5', 'validation error5'],
                     legend=['training', 'validation'],
                     title='Error@5', ylabel='error %')
        if 'grad' in train_results:
            results.plot(x='epoch', y=['training grad'],
                         legend=['gradient L2 norm'],
                         title='Gradient Norm', ylabel='value')
        results.save()

        # Record the first epoch at which the target accuracy is exceeded.
        if not found and val_results['prec1'] > target_prec1:
            found = True
            end_time = time.time() - start_time
            end_epoch = epoch + 1

    if not found:
        end_time = time.time() - start_time
        end_epoch = epochs_done

    print("Target reached: {}, minutes: {}, epochs: {}".format(
        found, round(end_time / 60, 3), end_epoch))
def main_worker(args):
    """Train (or only evaluate) the model described by ``args``.

    Compared with a plain training run this variant derives the run name
    from ``auto_name(args, model_config)``, can create a ``SummaryWriter``
    under ``runs/<name>`` when ``args.monitor > 0``, supports a dry run
    (``args.dry``) that only builds the model and prints its size, and can
    attach a forward hook (``args.covmat``) that prints statistics of the
    outer product of a chosen child module's flattened output.

    Processes that are distributed workers with ``local_rank > 0`` still
    train and validate, but skip directory creation, checkpointing and
    result logging.

    Args:
        args: Parsed command-line namespace. It is mutated in place
            (``results_dir``, ``save``, ``distributed``, ``local_rank``,
            ``world_size``, ``device_ids``, ``model``, ``model_config``,
            ``start_epoch`` and ``eval_batch_size`` may be overwritten).

    Returns:
        None.

    Side effects:
        Rebinds the module globals ``best_prec1`` and ``dtype`` (and, when
        the covariance hook is enabled, ``diag_var_mean``,
        ``non_diag_var_mean`` and ``var_count``); writes
        logs/results/checkpoints under ``save_path``.
    """
    global best_prec1, dtype
    global diag_var_mean, non_diag_var_mean, var_count
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'

    model_config = {'dataset': args.dataset, 'batch': args.batch_size}

    # Strings are compared by value; identity comparison against a literal
    # is implementation-dependent and warns on Python 3.8+.
    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    # The run name is always generated, overriding any user-provided value.
    fname = auto_name(args, model_config)
    args.save = fname

    monitor = args.monitor

    print(fname)

    save_path = path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_init,
                                world_size=args.world_size,
                                rank=args.local_rank)
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            args.device_ids = [args.local_rank]

    if not (args.distributed and args.local_rank > 0):
        if not args.dry:
            if not path.exists(save_path):
                makedirs(save_path)
            export_args_namespace(args, path.join(save_path, 'config.json'))

    if monitor > 0 and not args.dry:
        events_path = "runs/%s" % fname
        my_file = Path(events_path)
        if my_file.is_file():
            os.remove(events_path)
        writer = SummaryWriter(log_dir=events_path, comment=str(args))
        # The model constructor receives the writer through its config.
        model_config['writer'] = writer
        model_config['monitor'] = monitor
    else:
        monitor = 0
        writer = None

    if args.dry:
        # Dry run: build the model, report its size, and stop.
        model = models.__dict__[args.model]
        model = model(**model_config)
        print("created model with configuration: %s" % model_config)
        num_parameters = sum(p.nelement() for p in model.parameters())
        print("number of parameters: %d" % num_parameters)
        return

    setup_logging(path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=args.distributed and args.local_rank > 0)

    results_path = path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    model = models.__dict__[args.model]
    model = model(**model_config)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # optionally resume from a checkpoint
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    if args.evaluate:
        if not path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate, map_location="cpu")
        # Override configuration with checkpoint info
        args.model = checkpoint.get('model', args.model)
        args.model_config = checkpoint.get('config', args.model_config)
        # load checkpoint
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)",
                     args.evaluate, checkpoint['epoch'])

    # Bound up front so that a missing resume checkpoint (which is only
    # logged as an error) does not lead to a NameError further down.
    optim_state_dict = None
    if args.resume:
        checkpoint_file = args.resume
        if path.isdir(checkpoint_file):
            results.load(path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = path.join(
                checkpoint_file, 'model_best.pth.tar')
        if path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file, map_location="cpu")
            if args.start_epoch < 0:  # not explicitly set
                args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optim_state_dict = checkpoint.get('optim_state_dict', None)
            logging.info("loaded checkpoint '%s' (epoch %s)",
                         checkpoint_file, checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration
    optim_regime = getattr(model, 'regime', [{'epoch': 0,
                                              'optimizer': args.optimizer,
                                              'lr': args.lr,
                                              'momentum': args.momentum,
                                              'weight_decay': args.weight_decay}])

    optimizer = optim_regime if isinstance(optim_regime, OptimRegime) \
        else OptimRegime(model, optim_regime,
                         use_float_copy='half' in args.dtype)

    if optim_state_dict is not None:
        optimizer.load_state_dict(optim_state_dict)

    trainer = Trainer(model, criterion, optimizer,
                      device_ids=args.device_ids, device=args.device,
                      dtype=dtype, print_freq=args.print_freq,
                      distributed=args.distributed,
                      local_rank=args.local_rank,
                      mixup=args.mixup, cutmix=args.cutmix,
                      loss_scale=args.loss_scale, grad_clip=args.grad_clip,
                      adapt_grad_norm=args.adapt_grad_norm,
                      writer=writer, monitor=monitor)
    if args.tensorwatch:
        trainer.set_watcher(
            filename=path.abspath(path.join(save_path, 'tensorwatch.log')),
            port=args.tensorwatch_port)

    # Evaluation Data loading code
    args.eval_batch_size = args.eval_batch_size \
        if args.eval_batch_size > 0 else args.batch_size
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={'datasets_path': args.datasets_dir,
                                    'name': args.dataset,
                                    'split': 'val',
                                    'augment': False,
                                    'input_size': args.input_size,
                                    'batch_size': args.eval_batch_size,
                                    'shuffle': False,
                                    'num_workers': args.workers,
                                    'pin_memory': True,
                                    'drop_last': False})

    if args.evaluate:
        results = trainer.validate(val_data.get_loader())
        logging.info(results)
        return

    # Training Data loading code
    train_data_defaults = {'datasets_path': args.datasets_dir,
                           'name': args.dataset,
                           'split': 'train',
                           'augment': True,
                           'input_size': args.input_size,
                           'batch_size': args.batch_size,
                           'shuffle': True,
                           'num_workers': args.workers,
                           'pin_memory': True,
                           'drop_last': True,
                           'distributed': args.distributed,
                           'duplicates': args.duplicates,
                           'autoaugment': args.autoaugment,
                           'cutout': {'holes': 1, 'length': 16}
                           if args.cutout else None}

    if hasattr(model, 'sampled_data_regime'):
        sampled_data_regime = model.sampled_data_regime
        probs, regime_configs = zip(*sampled_data_regime)
        regimes = []
        for config in regime_configs:
            defaults = {**train_data_defaults}
            defaults.update(config)
            regimes.append(DataRegime(None, defaults=defaults))
        train_data = SampledDataRegime(regimes, probs)
    else:
        train_data = DataRegime(
            getattr(model, 'data_regime', None), defaults=train_data_defaults)

    logging.info('optimization regime: %s', optim_regime)
    logging.info('data regime: %s', train_data)
    args.start_epoch = max(args.start_epoch, 0)
    trainer.training_steps = args.start_epoch * len(train_data)

    if args.covmat != "":
        # ``args.covmat`` selects a direct child of the model either by
        # position (an integer; negative values count from the end) or,
        # when it does not parse as an integer, by child name.
        try:
            int_covmat = int(args.covmat)
        except ValueError:
            int_covmat = None
        else:
            if int_covmat < 0:
                total_layers = len(list(model.named_children()))
                int_covmat = total_layers + int_covmat

        def calc_covmat(x_, partitions=64):
            """Return variance and low-tail ratio of the diagonal and
            off-diagonal entries of ``x_ @ x_.T``.

            The product is evaluated block by block over ``partitions``
            equal row slices of ``x_`` (trailing rows that do not fill a
            slice are ignored), and every block is moved to the CPU before
            the statistics are computed.
            """
            L = x_.shape[0] // partitions
            non_diags = []
            diags = []
            for p1 in range(partitions):
                x = x_[p1 * L:(p1 + 1) * L]
                for p2 in range(partitions):
                    y = x_[p2 * L:(p2 + 1) * L]
                    X = torch.matmul(x, y.transpose(0, 1))
                    if p1 == p2:
                        # Only diagonal blocks contain diagonal entries.
                        mask = torch.eye(X.shape[0], dtype=torch.bool)
                        non_diag = X[~mask].reshape(-1).cpu()
                        diag = X[mask].reshape(-1).cpu()
                        non_diags.append(non_diag)
                        diags.append(diag)
                    else:
                        non_diag = X.reshape(-1).cpu()
                        # Collect the block just computed (the previous
                        # code appended the stale ``diag`` tensor here).
                        non_diags.append(non_diag)
            diags = torch.cat(diags)
            non_diags = torch.cat(non_diags)
            diag_var = diags.var()
            non_diag_var = non_diags.var()
            diags = diags - diags.mean()
            non_diags = non_diags - non_diags.mean()
            # Fraction of centred entries more than one std below zero.
            diag_small_ratio = \
                (diags < -diags.std()).to(dtype=torch.float).mean()
            non_diag_small_ratio = \
                (non_diags < -non_diags.std()).to(dtype=torch.float).mean()
            return diag_var, non_diag_var, \
                diag_small_ratio, non_diag_small_ratio

        var_count = 0
        diag_var_mean = 0
        non_diag_var_mean = 0

        def report_covmat_hook(module, input, output):
            """Forward hook: accumulate running sums of the variances and
            print the current and mean values on every 10th call."""
            global diag_var_mean
            global non_diag_var_mean
            global var_count
            flatten_output = output.reshape([-1, 1]).detach()
            diag_var, non_diag_var, diag_small_ratio, non_diag_small_ratio = \
                calc_covmat(flatten_output)
            diag_var_mean = diag_var_mean + diag_var
            non_diag_var_mean = non_diag_var_mean + non_diag_var
            var_count = var_count + 1
            if var_count % 10 == 1:
                print("diag_var = %.02f (%.02f), ratio: %.02f , non_diag_var = %0.2f (%.02f), ratio: %.02f" % (
                    diag_var, diag_var_mean / var_count, diag_small_ratio,
                    non_diag_var, non_diag_var_mean / var_count,
                    non_diag_small_ratio))

        for child_idx, (name, layer) in enumerate(model.named_children()):
            if int_covmat is None:
                selected = name == args.covmat
            else:
                selected = child_idx == int_covmat
            if selected:
                layer.register_forward_hook(report_covmat_hook)

    for epoch in range(args.start_epoch, args.epochs):
        trainer.epoch = epoch
        train_data.set_epoch(epoch)
        val_data.set_epoch(epoch)
        logging.info('\nStarting Epoch: {0}\n'.format(epoch + 1))

        # train for one epoch
        train_results = trainer.train(train_data.get_loader(),
                                      chunk_batch=args.chunk_batch)

        # evaluate on validation set
        val_results = trainer.validate(val_data.get_loader())

        # Non-primary distributed workers do not checkpoint or log results.
        if args.distributed and args.local_rank > 0:
            continue

        # remember best prec@1 and save checkpoint
        is_best = val_results['prec1'] > best_prec1
        best_prec1 = max(val_results['prec1'], best_prec1)

        if args.drop_optim_state:
            optim_state_dict = None
        else:
            optim_state_dict = optimizer.state_dict()

        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.model,
            'config': args.model_config,
            'state_dict': model.state_dict(),
            'optim_state_dict': optim_state_dict,
            'best_prec1': best_prec1
        }, is_best, path=save_path, save_all=args.save_all)

        logging.info('\nResults - Epoch: {0}\n'
                     'Training Loss {train[loss]:.4f} \t'
                     'Training Prec@1 {train[prec1]:.3f} \t'
                     'Training Prec@5 {train[prec5]:.3f} \t'
                     'Validation Loss {val[loss]:.4f} \t'
                     'Validation Prec@1 {val[prec1]:.3f} \t'
                     'Validation Prec@5 {val[prec5]:.3f} \t\n'
                     .format(epoch + 1, train=train_results, val=val_results))

        if writer is not None:
            writer.add_scalar('Train/Loss', train_results['loss'], epoch)
            writer.add_scalar('Train/Prec@1', train_results['prec1'], epoch)
            writer.add_scalar('Train/Prec@5', train_results['prec5'], epoch)
            writer.add_scalar('Val/Loss', val_results['loss'], epoch)
            writer.add_scalar('Val/Prec@1', val_results['prec1'], epoch)
            writer.add_scalar('Val/Prec@5', val_results['prec5'], epoch)

        values = dict(epoch=epoch + 1, steps=trainer.training_steps)
        values.update({'training ' + k: v for k, v in train_results.items()})
        values.update({'validation ' + k: v for k, v in val_results.items()})
        results.add(**values)

        results.plot(x='epoch', y=['training loss', 'validation loss'],
                     legend=['training', 'validation'],
                     title='Loss', ylabel='loss')
        results.plot(x='epoch', y=['training error1', 'validation error1'],
                     legend=['training', 'validation'],
                     title='Error@1', ylabel='error %')
        results.plot(x='epoch', y=['training error5', 'validation error5'],
                     legend=['training', 'validation'],
                     title='Error@5', ylabel='error %')
        if 'grad' in train_results:
            results.plot(x='epoch', y=['training grad'],
                         legend=['gradient L2 norm'],
                         title='Gradient Norm', ylabel='value')
        results.save()

    logging.info(f'\nBest Validation Accuracy (top1): {best_prec1}')

    if writer is not None:
        writer.close()