Example #1
def train_t(config):
    seed = config.pop('seed')
    static_params = config.pop('static_params')

    torch.backends.cudnn.enabled = True
    if static_params['t_id'] == 0:
        torch.backends.cudnn.deterministic = True
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
    else:
        torch.backends.cudnn.deterministic = False

    if 'PSSN' in tune.get_trial_name() or static_params['t_id'] == 0:
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.benchmark = True

    if 'learner' in config:
        learner = config.pop('learner')
    else:
        learner_path = config.pop('learner_path')
        learner = torch.load(learner_path)

    rescaled, t, metrics, b_state_dict, stats = train_single_task(config=config, learner=learner, **static_params)

    learner_save_path = os.path.join(tune.get_trial_dir(), 'learner.pth')
    # raise ValueError(learner_save_path)
    torch.save(learner, learner_save_path)
Example #2
    def maybe_start_communicator(self):
        if self.config.get('run_communicator', False):
            name = tune.get_trial_name()
            if name is None:
                name = str(uuid1())
                print(f"Selecting name {name}")
            if not ray.is_initialized():
                ray_kwargs = self.ray_kwargs if hasattr(self, 'ray_kwargs') else {}
                ray.init(**ray_kwargs)
            logging.info(f"Starting parameter communicator with name {name}")
            self.communicator = run_communicator(name)
Example #3
def train_transformer(config, checkpoint_dir=None):
    data_args = DataTrainingArguments(task_name=config["task_name"],
                                      data_dir=config["data_dir"])
    tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
    train_dataset = GlueDataset(data_args,
                                tokenizer=tokenizer,
                                mode="train",
                                cache_dir=config["data_dir"])
    eval_dataset = GlueDataset(data_args,
                               tokenizer=tokenizer,
                               mode="dev",
                               cache_dir=config["data_dir"])
    eval_dataset = eval_dataset[:len(eval_dataset) // 2]
    training_args = TrainingArguments(
        output_dir=tune.get_trial_dir(),
        learning_rate=config["learning_rate"],
        do_train=True,
        do_eval=True,
        evaluate_during_training=True,
        eval_steps=(len(train_dataset) // config["per_gpu_train_batch_size"]) +
        1,
        # We explicitly set save to 0, and do saving in evaluate instead
        save_steps=0,
        num_train_epochs=config["num_epochs"],
        max_steps=config["max_steps"],
        per_device_train_batch_size=config["per_gpu_train_batch_size"],
        per_device_eval_batch_size=config["per_gpu_val_batch_size"],
        warmup_steps=0,
        weight_decay=config["weight_decay"],
        logging_dir="./logs",
    )

    # Arguments for W&B.
    name = tune.get_trial_name()
    wandb_args = {
        "project_name": "transformers_pbt",
        "watch": "false",  # Either set to gradient, false, or all
        "run_name": name,
    }

    tune_trainer = get_trainer(recover_checkpoint(checkpoint_dir,
                                                  config["model_name"]),
                               train_dataset,
                               eval_dataset,
                               config["task_name"],
                               training_args,
                               wandb_args=wandb_args)
    tune_trainer.train(recover_checkpoint(checkpoint_dir,
                                          config["model_name"]))
Example #4
    def collect_initial(self, do_tqdm=True):
        do_ascii = tune.get_trial_name() is not None

        if self.n_collectors == 0:
            # resetting the local buffer
            info = self.replay_buffer.reset()
            logging.info(f"Buffer reset response: {info}")

            for _ in tqdm(range(self.future_batch_size),
                          disable=not do_tqdm,
                          desc="Initial buffer fill [local]",
                          ascii=do_ascii,
                          colour='red'):
                self.replay_buffer.collect_local(self.rl_context)
        else:
            # resetting the remote buffer
            info = ray.get(self.replay_buffer.reset.remote())
            logging.info(f"Buffer reset response: {info}")

            target = self.config.get('collect_initial_steps', 1000)
            collected = 0

            with tqdm(initial=collected,
                      total=target,
                      disable=not do_tqdm,
                      desc="Initial buffer fill",
                      ascii=do_ascii,
                      colour='red') as pbar:
                while collected < target:
                    stats = ray.get(
                        self.replay_buffer.collect.remote(min_batches=0,
                                                          enable_wait=False))
                    delta = stats['steps_collected_now']

                    if delta:
                        pbar.update(delta)
                        pbar.set_postfix(**stats)
                        collected += delta
                    sleep(0.1)
Example #5
def track_train(config):
    tune.report(name=tune.get_trial_name(),
                trial_id=tune.get_trial_id())
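The trainable above simply echoes its own trial identity back as metrics. A minimal sketch of launching it, assuming the Ray Tune 1.x API, to inspect the reported name and id per trial:

from ray import tune

# Run a couple of trials; each reports its own trial name and id once.
analysis = tune.run(track_train, num_samples=2)

for trial in analysis.trials:
    print(trial.last_result["name"], trial.last_result["trial_id"])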
Example #6
def train(
    net: RNN_MODEL1,
    train_loader: RNNIterator,
    valid_loader: RNNIterator,
    patience: int,
    args: object,
    dtype: torch.dtype,
    device: torch.device,
    savedir: str,
    neptune: neptune,
):
    """
    Train CNN on provided data set.
    Args:
        net: initialized neural network
        train_loader: DataLoader containing training set
        parameters: dictionary containing parameters to be passed to the optimizer.
            - lr: default (0.001)
            - momentum: default (0.0)
            - weight_decay: default (0.0)
            - num_epochs: default (1)
        dtype: torch dtype
        device: torch device
    Returns:
        nn.Module: trained CNN.
    """
    # Initialize network
    net.to(device)  # pyre-ignore [28]
    # Define loss and optimizer
    criterion = nn.NLLLoss()
    trial_name = tune.get_trial_name()
    print(trial_name, "Train")
    if args.optimizer == 'Adam':
        print("Adam optimizer")
        optimizer = optim.Adam(net.parameters(), lr=args.lr)
    elif args.optimizer == 'sgd':
        print("sgd optimizer")
        optimizer = optim.SGD(net.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'RMSprop':
        print("RMSprop optimizer")
        optimizer = optim.RMSprop(net.parameters(),
                                  lr=args.lr,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
    else:
        # Unknown optimizer name: fail loudly with a descriptive error.
        raise ValueError(f"Unsupported optimizer: {args.optimizer}")

    print("=" * 90)
    print('data_dim: {:d} | data_len {:d}'.format(train_loader.in_n,
                                                  train_loader.data_len))
    print('data_dim: {:d} | data_len {:d}'.format(valid_loader.in_n,
                                                  valid_loader.data_len))
    print("=" * 90)
    print(net)
    print("=" * 90)
    print(optimizer)
    print("=" * 90)

    num_epochs = 1000
    bc = 0
    best_val = 0
    best_net = None

    # Train Network
    print("Training start", num_epochs)
    for epoch in range(num_epochs):
        if epoch >= args.max_epoch:
            break

        net.train()
        train_loss = 0.0  # reset the running training loss each epoch
        for iloop, (inputs, labels, end_of_data) in enumerate(train_loader):
            if end_of_data == 1:
                break

            # move data to proper dtype and device
            inputs = inputs.to(dtype=dtype, device=device)
            labels = labels.to(dtype=torch.int64, device=device).squeeze()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)[-1, :, :]  # for RNN
            loss = criterion(outputs, labels)
            loss.backward()

            optimizer.step()
            train_loss += loss.item()

        # check if validation improves
        print("Train end", iloop)
        train_loss = train_loss / iloop
        print("Train loss:", train_loss)

        print("Test start")
        acc, prec, rec, f1 = test(net, valid_loader, device)

        tune.report(acc=acc, prec=prec, rec=rec, f1=f1)
        #val = evaluate_(net, valid_loader, criterion, device)
        #print('epoch: {:d} | train: {:.4f} | val: {:.4f}'.format(epoch+1, train_loss, val))
        #neptune.log_metric('tr loss', epoch, train_loss)
        #neptune.log_metric('val loss', epoch, val)
        #if epoch==0 or val < best_val:
        # if yes, save model, init bc, and best_val
        #    torch.save(net, savedir)
        #    bc = 0
        #    best_val = val
        #else:
        # if no, bc++, check if patience is over
        #    bc += 1
        #    if bc > patience:
        #        break

    #print('training over')
    #best_net = torch.load(savedir)
    #return best_net

    return
Example #7
    def tune_train(args,
                   model_class,
                   task_info: TaskInfo,
                   build_method=default_build_method,
                   model_kwargs: dict = None,
                   tune_config=None):
        if model_kwargs is None:
            model_kwargs = {}
        this_time = time.strftime("%m-%d_%H:%M:%S", time.localtime())
        experiment_name = f'{task_info.task_name}_{this_time}'

        if tune_config is None:
            config = {
                # 3e-4 for Small, 1e-4 for Base, 5e-5 for Large
                "lr":
                tune.loguniform(args.tune_min_lr, args.tune_max_lr),

                # -1 for disable, 0.8 for Base/Small, 0.9 for Large
                "layerwise_lr_decay_power":
                tune.choice([0.8, 0.9]),

                # lr scheduler
                "lr_scheduler":
                tune.choice([
                    'linear_schedule_with_warmup',
                    'polynomial_decay_schedule_with_warmup'
                ]),
            }
        else:
            config = tune_config
        if torch.cuda.is_available():
            resources_per_trial = {
                "cpu": args.tune_cpus_per_trial,
                "gpu": args.tune_gpus_per_trial
            }
        else:
            resources_per_trial = {"cpu": args.tune_cpus_per_trial}
        print("resources_per_trial", resources_per_trial)

        tune_dir = os.path.abspath('tune_lightning_logs')

        analysis = tune.run(
            tune.with_parameters(
                tune_train_once,
                args=args,
                task_info=task_info,
                model_class=model_class,
                build_method=build_method,
                model_kwargs=model_kwargs,
                resume=args.tune_resume,
                group=experiment_name,
                log_dir=tune_dir,
            ),
            mode="max",
            config=config,
            num_samples=args.tune_num_samples,
            metric=f'tune_{task_info.metric_name}',
            name=experiment_name,
            progress_reporter=CLIReporter(
                parameter_columns=list(config.keys()),
                metric_columns=[
                    "loss", f'tune_{task_info.metric_name}',
                    "training_iteration"
                ]),
            callbacks=[TBXLoggerCallback(),
                       CSVLoggerCallback()],
            resources_per_trial=resources_per_trial,
            scheduler=ASHAScheduler(
                max_t=args.max_epochs + 1,  # for test
                grace_period=args.min_epochs),
            queue_trials=True,
            keep_checkpoints_num=args.tune_keep_checkpoints_num,
            checkpoint_score_attr=f'tune_{task_info.metric_name}',
            local_dir=tune_dir,
        )
        print("Best hyperparameters found were: ", analysis.best_config)
        print("Best checkpoint: ", analysis.best_checkpoint)

        args_vars = vars(args)
        args_vars.update(analysis.best_config)
        model = model_class.load_from_checkpoint(os.path.join(
            analysis.best_checkpoint, "tune.ckpt"),
                                                 hparams=args,
                                                 **model_kwargs)

        pl_loggers = [
            loggers.CSVLogger(save_dir=tune.get_trial_dir(),
                              name="",
                              version="."),
            loggers.TensorBoardLogger(save_dir=tune.get_trial_dir(),
                                      name="",
                                      version=".",
                                      default_hp_metric=False),
        ]

        try:
            import wandb
            pl_loggers.append(
                loggers.WandbLogger(save_dir=tune_dir,
                                    project=args.project,
                                    name=tune.get_trial_name(),
                                    id=tune.get_trial_id(),
                                    offline=args.offline,
                                    group=experiment_name))
        except Exception:
            pass

        trainer: Trainer = Trainer.from_argparse_args(args, logger=pl_loggers)
        build_method(model, task_info)
        trainer.test(model)
Example #8
    def tune_train_once(config,
                        checkpoint_dir=None,
                        args: argparse.Namespace = None,
                        model_class: type = None,
                        build_method=None,
                        task_info: TaskInfo = None,
                        model_kwargs: dict = None,
                        resume: str = None,
                        group: str = None,
                        log_dir: str = None,
                        **kwargs):
        if resume is None:
            resume = 'all'
        args_vars = vars(args)
        args_vars.update(config)

        pl.seed_everything(args.seed)
        pl_loggers = [
            loggers.CSVLogger(save_dir=tune.get_trial_dir(),
                              name="",
                              version="."),
            loggers.TensorBoardLogger(save_dir=tune.get_trial_dir(),
                                      name="",
                                      version=".",
                                      default_hp_metric=False),
        ]

        try:
            import wandb
            pl_loggers.append(
                loggers.WandbLogger(save_dir=log_dir or 'tune_lightning_logs',
                                    project=args.project,
                                    name=tune.get_trial_name(),
                                    id=tune.get_trial_id(),
                                    offline=args.offline,
                                    group=group))
        except Exception:
            pass

        trainer_args = dict(
            logger=pl_loggers,
            progress_bar_refresh_rate=0,
            callbacks=[
                TuneReportCheckpointCallback(metrics={
                    f'tune_{task_info.metric_name}':
                    f'{task_info.task_name}/val_{task_info.metric_name}'
                },
                                             filename="tune.ckpt",
                                             on="validation_end")
            ])
        if checkpoint_dir and resume == 'all':
            trainer_args['resume_from_checkpoint'] = os.path.join(
                checkpoint_dir, "tune.ckpt")

        # fix slurm trainer
        os.environ["SLURM_JOB_NAME"] = "bash"
        model = model_class(args, **model_kwargs)
        build_method(model, task_info)
        trainer: Trainer = Trainer.from_argparse_args(args, **trainer_args)
        if checkpoint_dir and resume == 'model':
            ckpt = pl_load(os.path.join(checkpoint_dir, "tune.ckpt"),
                           map_location=lambda storage, loc: storage)
            model = model._load_model_state(ckpt)
            trainer.current_epoch = ckpt["epoch"]

        trainer.fit(model)