def n_training_step(self, step, training_set: TrainingSet):
    """Run one self-training generation: train N sub-models, pseudo-label.

    For each of ``self.rparams.num_models`` sub-runners: pull that runner's
    training split from ``training_set``, train it under a MetaRunner with
    periodic eval/save, then score ``self.unlabeled_examples``. The stacked
    per-model logits are logged, converted into pseudo-labels, and folded
    into a fresh TrainingSet for the next generation.

    Args:
        step: index of the current self-training iteration (used only to
            namespace the per-runner log writers and tag logged objects).
        training_set: provides per-runner training splits and the labeled /
            unlabeled bookkeeping carried forward.

    Returns:
        Tuple of (new TrainingSet carrying ``chosen_examples`` and
        ``chosen_preds``, list of the trained sub-runners).
    """
    per_model_logits = []
    trained_runners = []
    meta_params = self.meta_runner_parameters
    for model_i in tqdm.trange(self.rparams.num_models):
        sub_train_examples = training_set.get_training_examples(model_i)
        sub_log_writer = self.get_sub_log_writer(f"runner{step}__{model_i}")
        with sub_log_writer.log_context():
            sub_runner = self.runner_creator.create(
                num_train_examples=len(sub_train_examples),
                log_writer=sub_log_writer,
            )
            # quick and dirty: evaluate on a prefix of the val set only
            val_subset = self.task.get_val_examples()[
                :meta_params.partial_eval_number]
            metarunner.MetaRunner(
                runner=sub_runner,
                train_examples=sub_train_examples,
                val_examples=val_subset,
                should_save_func=metarunner.get_should_save_func(
                    meta_params.save_every_steps),
                should_eval_func=metarunner.get_should_eval_func(
                    meta_params.eval_every_steps),
                output_dir=meta_params.output_dir,
                verbose=True,
                save_best_model=meta_params.do_save,
                load_best_model=True,
                log_writer=self.log_writer,
            ).train_val_save_every()
            sub_logits = sub_runner.run_test(self.unlabeled_examples)
            # Free what the runner no longer needs before the next iteration.
            runner_save_memory(sub_runner)
            per_model_logits.append(sub_logits)
            trained_runners.append(sub_runner)
    # Stack per-model logits along a new axis 1 (the "model" axis).
    all_logits = np.stack(per_model_logits, axis=1)
    self.log_writer.write_obj("sub_runner_logits", all_logits, {
        "step": step,
    })
    self.log_writer.flush()
    chosen_examples, chosen_preds = get_n_training_pseudolabels(
        all_logits=all_logits,
        with_disagreement=self.rparams.with_disagreement,
        confidence_threshold=self.rparams.confidence_threshold,
    )
    new_training_set = TrainingSet(
        task=training_set.task,
        labeled_examples=training_set.labeled_examples,
        unlabeled_examples=training_set.unlabeled_examples,
        labeled_indices=training_set.labeled_indices,
        chosen_examples=chosen_examples,
        chosen_preds=chosen_preds,
    )
    return new_training_set, trained_runners
def main(args):
    """Single-task entry point: build task + model, then train/save/val/test.

    Pipeline: initialize logging and devices, construct the task and model,
    derive the training schedule and optimizer, then execute whichever of
    the ``args.do_train`` / ``args.do_save`` / ``args.do_val`` /
    ``args.do_test`` phases are enabled.
    """
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    with quick_init_out.log_writer.log_context():
        task = tasks.create_task_from_config_path(
            config_path=args.task_config_path,
            verbose=True,
        )
        # Load the model (first process only in distributed settings).
        with distributed.only_first_process(local_rank=args.local_rank):
            model_class_spec = model_resolution.resolve_model_setup_classes(
                model_type=args.model_type,
                task_type=task.TASK_TYPE,
            )
            model_wrapper = model_setup.simple_model_setup(
                model_type=args.model_type,
                model_class_spec=model_class_spec,
                config_path=args.model_config_path,
                tokenizer_path=args.model_tokenizer_path,
                task=task,
            )
            model_setup.simple_load_model_path(
                model=model_wrapper.model,
                model_load_mode=args.model_load_mode,
                model_path=args.model_path,
                verbose=True,
            )
            model_wrapper.model.to(quick_init_out.device)

        # Optionally subsample the training set, then derive the schedule.
        train_examples, _ = train_setup.maybe_subsample_train(
            train_examples=task.get_train_examples(),
            train_examples_number=args.train_examples_number,
            train_examples_fraction=args.train_examples_fraction,
        )
        train_schedule = train_setup.get_train_schedule(
            num_train_examples=len(train_examples),
            max_steps=args.max_steps,
            num_train_epochs=args.num_train_epochs,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            per_gpu_train_batch_size=args.train_batch_size,
            n_gpu=quick_init_out.n_gpu,
        )
        quick_init_out.log_writer.write_entry(
            "text", f"t_total: {train_schedule.t_total}", do_print=True)

        loss_criterion = train_setup.resolve_loss_function(
            task_type=task.TASK_TYPE)
        optimizer_scheduler = model_setup.create_optimizer(
            model=model_wrapper.model,
            learning_rate=args.learning_rate,
            t_total=train_schedule.t_total,
            warmup_steps=args.warmup_steps,
            warmup_proportion=args.warmup_proportion,
            optimizer_type=args.optimizer_type,
            verbose=True,
        )
        # fp16 / multi-GPU / distributed wrapping, as configured.
        model_setup.special_model_setup(
            model_wrapper=model_wrapper,
            optimizer_scheduler=optimizer_scheduler,
            fp16=args.fp16,
            fp16_opt_level=args.fp16_opt_level,
            n_gpu=quick_init_out.n_gpu,
            local_rank=args.local_rank,
        )
        rparams = simple_runner.RunnerParameters(
            feat_spec=model_resolution.build_featurization_spec(
                model_type=args.model_type,
                max_seq_length=args.max_seq_length,
            ),
            local_rank=args.local_rank,
            n_gpu=quick_init_out.n_gpu,
            fp16=args.fp16,
            learning_rate=args.learning_rate,
            eval_batch_size=args.eval_batch_size,
            max_grad_norm=args.max_grad_norm,
        )
        runner = simple_runner.SimpleTaskRunner(
            task=task,
            model_wrapper=model_wrapper,
            optimizer_scheduler=optimizer_scheduler,
            loss_criterion=loss_criterion,
            device=quick_init_out.device,
            rparams=rparams,
            train_schedule=train_schedule,
            log_writer=quick_init_out.log_writer,
        )

        if args.do_train:
            # quick and dirty: validate on a prefix of the val set only
            partial_val_examples = task.get_val_examples()[
                :args.partial_eval_number]
            meta = metarunner.MetaRunner(
                runner=runner,
                train_examples=train_examples,
                val_examples=partial_val_examples,
                should_save_func=metarunner.get_should_save_func(
                    args.save_every_steps),
                should_eval_func=metarunner.get_should_eval_func(
                    args.eval_every_steps),
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            meta.train_val_save_every()

        if args.do_save:
            torch.save(
                model_wrapper.model.state_dict(),
                os.path.join(args.output_dir, "model.p"),
            )

        if args.do_val:
            val_results = runner.run_val(task.get_val_examples())
            evaluate.write_val_results(
                results=val_results,
                output_dir=args.output_dir,
                verbose=True,
            )

        if args.do_test:
            test_logits = runner.run_test(task.get_test_examples())
            evaluate.write_preds(
                logits=test_logits,
                output_path=os.path.join(args.output_dir, "test_preds.csv"),
            )
def main(args):
    """Multitask entry point: build a task dict + shared model, then run.

    Mirrors the single-task driver but over a dictionary of tasks: one
    training-example dict, one loss-criterion dict, and a MultiTaskRunner.
    Executes the phases enabled by ``args.do_train`` / ``args.do_save`` /
    ``args.do_val``; ``args.do_test`` is not implemented.
    """
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    with quick_init_out.log_writer.log_context():
        task_dict = create_task_dict(
            multitask_config_path=args.multitask_config_path,
            task_name_ls=args.task_name_ls,
        )
        # Load the model (first process only in distributed settings).
        with distributed.only_first_process(local_rank=args.local_rank):
            model_wrapper = multitask_model_setup.setup_multitask_ptt_model(
                model_type=args.model_type,
                config_path=args.model_config_path,
                tokenizer_path=args.model_tokenizer_path,
                task_dict=task_dict,
            )
            # NOTE(review): weights are loaded via the first task's entry in
            # model_dict — presumably the per-task submodules share the
            # encoder so one load suffices; confirm against the model setup.
            first_task_model = model_wrapper.model.model_dict[
                list(task_dict.keys())[0]]
            model_setup.simple_load_model_path(
                model=first_task_model,
                model_load_mode=args.model_load_mode,
                model_path=args.model_path,
                verbose=True,
            )
            model_wrapper.model.to(quick_init_out.device)

        # Per-task training examples, each optionally subsampled.
        train_examples_dict = {}
        for task_name, task in task_dict.items():
            subsampled_examples, _ = train_setup.maybe_subsample_train(
                train_examples=task.get_train_examples(),
                train_examples_number=args.train_examples_number,
                train_examples_fraction=args.train_examples_fraction,
            )
            train_examples_dict[task_name] = subsampled_examples

        # TODO: Tweak the schedule
        total_num_train_examples = sum(
            len(examples) for examples in train_examples_dict.values())
        train_schedule = train_setup.get_train_schedule(
            num_train_examples=total_num_train_examples,
            max_steps=args.max_steps,
            num_train_epochs=args.num_train_epochs,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            per_gpu_train_batch_size=args.train_batch_size,
            n_gpu=quick_init_out.n_gpu,
        )
        quick_init_out.log_writer.write_entry(
            "text", f"t_total: {train_schedule.t_total}", do_print=True)

        loss_criterion_dict = {
            task_name: train_setup.resolve_loss_function(
                task_type=task.TASK_TYPE)
            for task_name, task in task_dict.items()
        }
        optimizer_scheduler = model_setup.create_optimizer(
            model=model_wrapper.model,
            learning_rate=args.learning_rate,
            t_total=train_schedule.t_total,
            warmup_steps=args.warmup_steps,
            warmup_proportion=args.warmup_proportion,
            optimizer_type=args.optimizer_type,
            verbose=True,
        )
        # fp16 / multi-GPU / distributed wrapping, as configured.
        model_setup.special_model_setup(
            model_wrapper=model_wrapper,
            optimizer_scheduler=optimizer_scheduler,
            fp16=args.fp16,
            fp16_opt_level=args.fp16_opt_level,
            n_gpu=quick_init_out.n_gpu,
            local_rank=args.local_rank,
        )
        rparams = simple_runner.RunnerParameters(
            feat_spec=model_resolution.build_featurization_spec(
                model_type=args.model_type,
                max_seq_length=args.max_seq_length,
            ),
            local_rank=args.local_rank,
            n_gpu=quick_init_out.n_gpu,
            fp16=args.fp16,
            learning_rate=args.learning_rate,
            eval_batch_size=args.eval_batch_size,
            max_grad_norm=args.max_grad_norm,
        )
        runner = multitask_runner.MultiTaskRunner(
            task_dict=task_dict,
            model_wrapper=model_wrapper,
            optimizer_scheduler=optimizer_scheduler,
            loss_criterion_dict=loss_criterion_dict,
            device=quick_init_out.device,
            rparams=rparams,
            train_schedule=train_schedule,
            log_writer=quick_init_out.log_writer,
        )

        if args.do_train:
            # quick and dirty: validate each task on a prefix of its val set
            val_examples_dict = {
                task_name: task.get_val_examples()[:args.partial_eval_number]
                for task_name, task in task_dict.items()
            }
            meta = metarunner.MetaRunner(
                runner=runner,
                train_examples=train_examples_dict,
                val_examples=val_examples_dict,
                should_save_func=metarunner.get_should_save_func(
                    args.save_every_steps),
                should_eval_func=metarunner.get_should_eval_func(
                    args.eval_every_steps),
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            meta.train_val_save_every()

        if args.do_save:
            torch.save(
                model_wrapper.model.state_dict(),
                os.path.join(args.output_dir, "model.p"),
            )

        if args.do_val:
            val_examples_dict = {
                task_name: task.get_val_examples()[:args.partial_eval_number]
                for task_name, task in task_dict.items()
            }
            val_results = runner.run_val(val_examples_dict)
            evaluate.write_metrics(
                results=val_results,
                output_path=os.path.join(args.output_dir,
                                         "val_metrics.json"),
                verbose=True,
            )

        if args.do_test:
            raise NotImplementedError()