def hp_search_optuna(trial: optuna.Trial):
    global gargs
    args = gargs
    # set config
    config = load_config(args)
    config['args'] = args
    logger.info("%s", config)

    # set path
    set_path(config)

    # create accelerator
    accelerator = Accelerator()
    config['accelerator'] = accelerator
    args.device = accelerator.device

    # set search spaces
    lr = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
    bsz = trial.suggest_categorical('batch_size', [8, 16, 32, 64])
    seed = trial.suggest_int('seed', 17, 42)
    epochs = trial.suggest_int('epochs', 1, args.epoch)

    # prepare train, valid dataset
    train_loader, valid_loader = prepare_datasets(config, hp_search_bsz=bsz)

    with temp_seed(seed):
        # prepare model
        model = prepare_model(config)

        # create optimizer, scheduler, summary writer
        model, optimizer, scheduler, writer = prepare_others(config, model, train_loader, lr=lr)
        # create secondary optimizer, scheduler
        _, optimizer_2nd, scheduler_2nd, _ = prepare_others(config, model, train_loader, lr=args.bert_lr_during_freezing)

        train_loader = accelerator.prepare(train_loader)
        valid_loader = accelerator.prepare(valid_loader)

        config['optimizer'] = optimizer
        config['scheduler'] = scheduler
        config['optimizer_2nd'] = optimizer_2nd
        config['scheduler_2nd'] = scheduler_2nd
        config['writer'] = writer

        total_batch_size = args.batch_size * accelerator.num_processes * args.gradient_accumulation_steps
        logger.info("***** Running training *****")
        logger.info(f" Num examples = {len(train_loader)}")
        logger.info(f" Num Epochs = {args.epoch}")
        logger.info(f" Instantaneous batch size per device = {args.batch_size}")
        logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f" Total optimization steps = {args.max_train_steps}")

        early_stopping = EarlyStopping(logger, patience=args.patience, measure='f1', verbose=1)
        best_eval_f1 = -float('inf')
        for epoch in range(epochs):
            eval_loss, eval_f1, best_eval_f1 = train_epoch(model, config, train_loader, valid_loader, epoch, best_eval_f1)
            # early stopping
            if early_stopping.validate(eval_f1, measure='f1'):
                break
            if eval_f1 == best_eval_f1:
                early_stopping.reset(best_eval_f1)
            early_stopping.status()
            trial.report(eval_f1, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()
        return eval_f1

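# --- hedged sketch (not part of the original file) ---------------------------
# One way the objective above could be driven by an Optuna study. The helper
# name `run_hp_search` and the `args.hp_trials` attribute are assumptions; the
# pruner choice is illustrative. Since `hp_search_optuna` reads its arguments
# through the module-level `gargs`, it must be set before `study.optimize()`.
def run_hp_search(args):
    global gargs
    gargs = args
    # maximize the eval_f1 returned by hp_search_optuna; MedianPruner acts on
    # the trial.report() / trial.should_prune() calls inside the objective
    study = optuna.create_study(direction='maximize',
                                pruner=optuna.pruners.MedianPruner())
    study.optimize(hp_search_optuna, n_trials=args.hp_trials)
    logger.info("[study.best_params] %s", study.best_params)
    logger.info("[study.best_value] %s", study.best_value)
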
def train(args):
    # set etc
    torch.autograd.set_detect_anomaly(False)

    # set config
    config = load_config(args)
    config['args'] = args
    logger.info("%s", config)

    # set path
    set_path(config)

    # create accelerator
    accelerator = Accelerator()
    config['accelerator'] = accelerator
    args.device = accelerator.device

    # prepare train, valid dataset
    train_loader, valid_loader = prepare_datasets(config)

    with temp_seed(args.seed):
        # prepare model
        model = prepare_model(config)

        # create optimizer, scheduler, summary writer
        model, optimizer, scheduler, writer = prepare_others(config, model, train_loader)
        # create secondary optimizer, scheduler
        _, optimizer_2nd, scheduler_2nd, _ = prepare_others(config, model, train_loader, lr=args.bert_lr_during_freezing)

        train_loader = accelerator.prepare(train_loader)
        valid_loader = accelerator.prepare(valid_loader)

        config['optimizer'] = optimizer
        config['scheduler'] = scheduler
        config['optimizer_2nd'] = optimizer_2nd
        config['scheduler_2nd'] = scheduler_2nd
        config['writer'] = writer

        total_batch_size = args.batch_size * accelerator.num_processes * args.gradient_accumulation_steps
        logger.info("***** Running training *****")
        logger.info(f" Num examples = {len(train_loader)}")
        logger.info(f" Num Epochs = {args.epoch}")
        logger.info(f" Instantaneous batch size per device = {args.batch_size}")
        logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f" Total optimization steps = {args.max_train_steps}")

        # training
        early_stopping = EarlyStopping(logger, patience=args.patience, measure='f1', verbose=1)
        local_worse_epoch = 0
        best_eval_f1 = -float('inf')
        for epoch_i in range(args.epoch):
            epoch_st_time = time.time()
            eval_loss, eval_f1, best_eval_f1 = train_epoch(model, config, train_loader, valid_loader, epoch_i, best_eval_f1)
            # early stopping
            if early_stopping.validate(eval_f1, measure='f1'):
                break
            if eval_f1 == best_eval_f1:
                early_stopping.reset(best_eval_f1)
            early_stopping.status()
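
# --- hedged sketch (not part of the original file) ---------------------------
# `temp_seed` used above is assumed to be a context manager that fixes the RNG
# seeds for the duration of the block and restores the previous state on exit.
# A minimal version could look like the following (the real helper may also
# seed Python's `random` module and CUDA); it is named `temp_seed_sketch` here
# to avoid clashing with the actual implementation.
import contextlib
import numpy as np

@contextlib.contextmanager
def temp_seed_sketch(seed):
    np_state = np.random.get_state()
    torch_state = torch.get_rng_state()
    np.random.seed(seed)
    torch.manual_seed(seed)
    try:
        yield
    finally:
        # restore whatever RNG state was active before entering the block
        np.random.set_state(np_state)
        torch.set_rng_state(torch_state)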