Example #1
def main(argv=None):
    ''' Main entry point '''
    args = parse_args(argv)
    print(f'Running torch {torch.version.__version__}')

    profile_cuda_memory = args.config.cuda.profile_cuda_memory
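    # Pinned host memory speeds up host-to-device copies; skip it while
    # profiling CUDA memory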
    pin_memory = 'cuda' in args.device.type and not profile_cuda_memory
    dataloader = get_dataloader(args.config.data,
                                args.seed_fn,
                                pin_memory,
                                args.num_devices,
                                shuffle=args.shuffle)
    print(dataloader.dataset.stats)

    model = args.model(args.config.model, dataloader.dataset)
    action = args.action(args.action_config, model, dataloader, args.device)
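    # Early stopping needs its own validation dataloader: reuse the data
    # config with the split swapped to 'valid'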
    if args.action_type == 'train' and args.action_config.early_stopping:
        args.config.data.split = 'valid'
        args.config.data.max_examples = 0
        action.validation_dataloader = get_dataloader(args.config.data,
                                                      args.seed_fn,
                                                      pin_memory,
                                                      args.num_devices,
                                                      shuffle=args.shuffle)

    if profile_cuda_memory:
        print('Profiling CUDA memory')
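        # The profiler is installed as a trace function: sys.settrace covers
        # the main thread, threading.settrace covers threads started afterwards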
        memory_profiler = profile.CUDAMemoryProfiler(
            action.modules.values(), filename=profile_cuda_memory)

        sys.settrace(memory_profiler)
        threading.settrace(memory_profiler)

    step = 0
    epoch = 0
    if args.restore:
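        # Restore every module except those whose parameters are reset below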
        restore_modules = {
            module_name: module
            for module_name, module in action.modules.items()
            if module_name not in args.reset_parameters
        }

        epoch, step = restore(args.restore,
                              restore_modules,
                              num_checkpoints=args.average_checkpoints,
                              map_location=args.device.type,
                              strict=not args.reset_parameters)

        model.reset_named_parameters(args.reset_parameters)
        if 'step' in args.reset_parameters:
            step = 0
            epoch = 0

    args.experiment.set_step(step)

    with ExitStack() as stack:
        stack.enter_context(profiler.emit_nvtx(args.config.cuda.profile_cuda))
        stack.enter_context(set_detect_anomaly(args.detect_anomalies))
        action(epoch, args.experiment, args.verbose)
Example #2
def prepare_data(args):
    train_transform_S = get_transform(train=True,
                                      dataset_name=cfg.DATASET.SOURCE)
    train_transform_T = get_transform(train=True,
                                      dataset_name=cfg.DATASET.TARGET)
    val_transform = get_transform(train=False, dataset_name=cfg.DATASET.VAL)

    # Look up dataset classes by name (safer than eval)
    train_dataset_S = getattr(Dataset, cfg.DATASET.SOURCE)(
        cfg.DATASET.DATAROOT_S,
        cfg.DATASET.TRAIN_SPLIT_S,
        transform=train_transform_S)

    train_dataset_T = getattr(Dataset, cfg.DATASET.TARGET)(
        cfg.DATASET.DATAROOT_T,
        cfg.DATASET.TRAIN_SPLIT_T,
        transform=train_transform_T)

    val_dataset = getattr(Dataset, cfg.DATASET.VAL)(
        cfg.DATASET.DATAROOT_VAL,
        cfg.DATASET.VAL_SPLIT,
        transform=val_transform)

    # construct dataloaders
    train_dataloader_S = data_utils.get_dataloader(
        train_dataset_S,
        cfg.TRAIN.TRAIN_BATCH_SIZE,
        cfg.NUM_WORKERS,
        train=True,
        distributed=args.distributed,
        world_size=gen_utils.get_world_size())

    train_dataloader_T = data_utils.get_dataloader(
        train_dataset_T,
        cfg.TRAIN.TRAIN_BATCH_SIZE,
        cfg.NUM_WORKERS,
        train=True,
        distributed=args.distributed,
        world_size=gen_utils.get_world_size())

    val_dataloader = data_utils.get_dataloader(
        val_dataset,
        cfg.TRAIN.VAL_BATCH_SIZE,
        cfg.NUM_WORKERS,
        train=False,
        distributed=args.distributed,
        world_size=gen_utils.get_world_size())

    dataloaders = {'train_S': train_dataloader_S,
                   'train_T': train_dataloader_T,
                   'val': val_dataloader}

    return dataloaders
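Examples #2 and #4 call a shared `data_utils.get_dataloader` helper whose implementation is not shown. The following is a minimal sketch of what such a helper might look like, assuming it wraps `torch.utils.data.DataLoader` and attaches a `DistributedSampler` when `distributed=True`; the signature here is inferred from the call sites above, not taken from the actual module.

import torch.utils.data as data


def get_dataloader(dataset, batch_size, num_workers,
                   train=True, distributed=False, world_size=1):
    # Hypothetical sketch: in the distributed case, shuffling is delegated
    # to a DistributedSampler (rank is taken from the initialized process
    # group); otherwise the DataLoader shuffles directly.
    sampler = None
    if distributed:
        sampler = data.DistributedSampler(
            dataset, num_replicas=world_size, shuffle=train)
    return data.DataLoader(dataset,
                           batch_size=batch_size,
                           num_workers=num_workers,
                           shuffle=(train and sampler is None),
                           sampler=sampler,
                           drop_last=train,
                           pin_memory=True)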
Example #3
    def __call__(self) -> float:
        """
        Run the evaluation!
        """
        dataloader = get_dataloader(self.args.data,
                                    self.dataset,
                                    num_devices=len(self.model.device_ids))

        def get_description():
            return f"Eval {self.metric_store}"

        batch_iterator = tqdm(
            dataloader,
            unit="batch",
            initial=1,
            dynamic_ncols=True,
            desc=get_description(),
            file=sys.stdout,  # needed to make tqdm_wrap_stdout work
        )

        with ExitStack() as stack:
            # pylint:disable=no-member
            stack.enter_context(tqdm_wrap_stdout())
            stack.enter_context(chunked_scattering())
            # pylint:enable=no-member

            for batch in batch_iterator:
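                # Treat CUDA OOM on a single batch as recoverable: count it
                # and move on; anything else is re-raised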
                try:
                    self.eval_step(batch)
                except RuntimeError as rte:
                    if "out of memory" in str(rte):
                        self.metric_store["oom"].update(1)
                        logging.warning(str(rte))
                    else:
                        batch_iterator.close()
                        raise rte

                batch_iterator.set_description_str(get_description())

            batch_iterator.close()

        return self.metric_store["nll"].average
Example #4
def prepare_data(args):
    if cfg.TEST.DOMAIN == 'source':
        dataset_name = cfg.DATASET.SOURCE
        dataset_root = cfg.DATASET.DATAROOT_S
    else:
        dataset_name = cfg.DATASET.TARGET
        dataset_root = cfg.DATASET.DATAROOT_T

    # Eval-time transform; get_transform takes (train, dataset_name) as in Example #2
    test_transform = get_transform(train=False, dataset_name=dataset_name)

    dataset_split = cfg.DATASET.TEST_SPLIT
    test_dataset = getattr(Dataset, dataset_name)(dataset_root,
                                                  dataset_split,
                                                  transform=test_transform)

    # construct dataloaders
    test_dataloader = data_utils.get_dataloader(test_dataset,
                                                cfg.TEST.BATCH_SIZE,
                                                cfg.NUM_WORKERS,
                                                train=False,
                                                distributed=args.distributed,
                                                world_size=args.world_size)

    return test_dataset, test_dataloader
Example #5
def main():
    # max_length needs to be a multiple of span_size
    # mp.set_start_method('spawn')
    args = get_cl_args()
    print(args)
    print("Number of GPUs:", torch.cuda.device_count())
    config = {
        'max_length': args.max_length,
        'span_size': args.span_size,
        'learning_rate': args.learning_rate,
        'weight_decay': args.weight_decay,
        'print_every': args.print_every,
        'save_path': args.save_path,
        'restore_path': args.restore,
        'best_save_path': args.best_model,
        'minibatch_size': args.minibatch_size,
        'num_epochs': args.num_epochs,
        'num_evaluate': args.num_evaluate,
        'hidden_size': args.hidden_size,
        'optimizer': args.optimizer,
        'dataset': args.dataset,
        'mode': args.mode,
        'evaluate_path': args.evaluate_path,
        'seed': args.seed,
        'shuffle': args.shuffle,
        'batch_size_buffer': args.batch_size_buffer,
        'batch_method': args.batch_method,
        'lr_decay': args.lr_decay,
        'experiment_path': args.experiment_path,
        'save_loss_every': args.save_loss_every,
        'save_checkpoint_every': args.save_checkpoint_every,
        'length_penalty': args.length_penalty,
        'drop_last': args.drop_last,
        'beam_width': args.beam_width,
        'beam_search_all': args.beam_search_all,
        'clip': args.clip,
        'search_method': args.search_method,
        'eval_when_train': args.eval_when_train,
        'filter': args.filter,
        'detokenize': args.detokenize,
        'rnn_type': args.rnn_type,
        'num_layers': args.num_layers,
        'teacher_forcing_ratio': args.teacher_forcing_ratio,
        'num_directions': args.num_directions,
        'trim': args.trim,
        'init_rnn': args.init_rnn,
        'lr_milestone': args.lr_milestone,
        'lr_scheduler_type': args.lr_scheduler_type,
        'eps': args.eps,
        'label_smoothing': args.label_smoothing,
        'more_decoder_layers': args.more_decoder_layers,
        'new_lr_scheduler': args.new_lr_scheduler,
        'average_checkpoints': args.average_checkpoints,
        'start_epoch': args.start_epoch,
        'end_epoch': args.end_epoch,
        'restore': args.restore,
        'accumulate_steps': args.accumulate_steps,
        'reverse': args.reverse,
        'preprocess_directory': args.preprocess_directory,
        'preprocess_buffer_size': args.preprocess_buffer_size
    }

    # config dataloader

    datasets = {"WMT": WMTDataset, "IWSLT": IWSLTDataset}
    dataset = datasets[args.dataset]
    profile_cuda_memory = args.profile_cuda_memory
    pin_memory = 'cuda' in DEVICE.type and not profile_cuda_memory

    if args.seed is not None:
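        # Seed all RNGs up front; the same seed_fn is also handed to each
        # dataloader (presumably to re-seed worker processes)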
        args.seed_fn = get_random_seed_fn(args.seed)
        args.seed_fn()
    else:
        args.seed_fn = None

    dataloader_train = get_dataloader(dataset,
                                      config,
                                      "train",
                                      args.seed_fn,
                                      pin_memory,
                                      NUM_DEVICES,
                                      shuffle=args.shuffle)

    dataloader_valid = get_dataloader(dataset,
                                      config,
                                      "valid",
                                      args.seed_fn,
                                      pin_memory,
                                      NUM_DEVICES,
                                      shuffle=args.shuffle)

    dataloader_test = get_dataloader(dataset,
                                     config,
                                     "test",
                                     args.seed_fn,
                                     pin_memory,
                                     NUM_DEVICES,
                                     shuffle=args.shuffle)

    # define the models

    torch.cuda.empty_cache()

    encoder1 = RNMTPlusEncoderRNN(
        dataloader_train.dataset.num_words,
        args.hidden_size,
        num_layers=args.num_layers,
        dropout_p=args.dropout,
        # max_length=args.max_length,
        rnn_type=args.rnn_type,
        num_directions=args.num_directions).to(DEVICE)
    attn_decoder1 = RNMTPlusDecoderRNN(
        args.hidden_size,
        dataloader_train.dataset.num_words,
        num_layers=args.num_layers,
        dropout_p=args.dropout,
        # max_length=args.max_length,
        span_size=args.span_size,
        rnn_type=args.rnn_type,
        num_directions=args.num_directions).to(DEVICE)
    if args.init_rnn:
        encoder1.init_rnn()
        attn_decoder1.init_rnn()

    models = {'encoder': encoder1, 'decoder': attn_decoder1}

    if args.track:
        experiment = Experiment(
            project_name="rnn-nmt-syntax",
            workspace="umass-nlp",
            auto_metric_logging=False,
            auto_output_logging=None,
            auto_param_logging=False,
            log_git_metadata=False,
            log_git_patch=False,
            log_env_details=False,
            log_graph=False,
            log_code=False,
            parse_args=False,
        )

        experiment.log_parameters(config)
    else:
        experiment = None

    if args.mode == "train":
        trainer = Trainer(config=config,
                          models=models,
                          dataloader=dataloader_train,
                          dataloader_valid=dataloader_valid,
                          experiment=experiment)
        if args.restore is not None:
            trainer.restore_checkpoint(args.experiment_path + args.restore)
        trainer.train()
    elif args.mode == "evaluate":
        evaluator = Evaluator(config=config,
                              models=models,
                              dataloader=dataloader_valid,
                              experiment=experiment)
        if args.restore is not None:
            evaluator.restore_checkpoint(args.experiment_path + args.restore)
        preds = evaluator.evaluate(args.search_method)
        save_predictions(preds, args.evaluate_path, args.detokenize)
    elif args.mode == "evaluate_train":
        evaluator = Evaluator(config=config,
                              models=models,
                              dataloader=dataloader_train,
                              experiment=experiment)
        if args.restore is not None:
            evaluator.restore_checkpoint(args.experiment_path + args.restore)
        preds = evaluator.evaluate(args.search_method)
        save_predictions(preds, args.evaluate_path, args.detokenize)
    elif args.mode == "test":
        evaluator = Evaluator(config=config,
                              models=models,
                              dataloader=dataloader_test,
                              experiment=experiment)
        if args.restore is not None:
            evaluator.restore_checkpoint(args.experiment_path + args.restore)
        preds = evaluator.evaluate(args.search_method)
        save_predictions(preds, args.evaluate_path, args.detokenize)
Example #6
# -*- coding: utf-8 -*-
"""
Created on Tue May 26 19:29:34 2020

@author: vws
"""
from data.utils import get_dataloader
from data.Alphabet import Alphabet


# %%
alphabet = Alphabet("data/english_alphabet.txt")

dataloader = get_dataloader("librispeech", 
                            batch_size=8, 
                            use_cuda=False, 
                            alphabet=alphabet,
                            n_features=128,
                            split="train")

# %%
X, y, X_lens, y_lens = next(iter(dataloader))

# %%
import matplotlib.pyplot as plt
from data.utils import get_dataset
data_dir = "librispeech"
dataset = get_dataset(data_dir, download=True, split="train", as_audiodataset=True)

# %%
sample = dataset[0]
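Since `matplotlib` is imported above but never used in the snippet, a natural next cell is a quick look at one batch. A minimal sketch, assuming `X` unpacks to padded feature tensors shaped roughly (time, n_features) per example; the layout is a guess, not confirmed by the source.

# %%
# Sanity-check plot of the first utterance's features (hypothetical layout)
plt.imshow(X[0].squeeze().T, origin="lower", aspect="auto")
plt.xlabel("frames")
plt.ylabel("feature bins")
plt.title(f"first utterance, true length {X_lens[0]}")
plt.show()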
Example #7
    def __call__(self):
        """
        Run the training!
        """
        # Must be called first
        self.try_init_amp()

        model = self.modules["model"]
        optimizer = self.modules["optimizer"]
        scheduler = self.modules["scheduler"]

        if self.args.optim.use_gradient_checkpointing:
            model.enable_gradient_checkpointing()

        model = nn.DataParallel(model)
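        # Replicate the model across available GPUs and tell the dataloader
        # how many devices it must feed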
        dataloader = get_dataloader(
            self.args.data,
            self.dataset,
            num_devices=len(model.device_ids),
            shuffle=True,
        )

        def get_description():
            return f"Train {self.metric_store}"

        max_steps = self.args.optim.max_steps
        accumulation_steps = self.args.optim.gradient_accumulation_steps
        progress = tqdm(
            unit="step",
            initial=self.step,
            dynamic_ncols=True,
            desc=get_description(),
            total=max_steps,
            file=sys.stdout,  # needed to make tqdm_wrap_stdout work
        )

        with ExitStack() as stack:
            # pylint:disable=no-member
            stack.enter_context(tqdm_wrap_stdout())
            stack.enter_context(chunked_scattering())
            stack.enter_context(self.experiment.train())
            # pylint:enable=no-member

            if self.args.optim.early_stopping:
                # If using early stopping, must evaluate regularly to determine
                # if training should stop early, so setup an Evaluator
                eval_args = copy.deepcopy(self.args)
                eval_args.data.batch_size = self.args.optim.eval_batch_size

                evaluator = Evaluator(eval_args)
                evaluator.model = model
                evaluator.load_dataset("validation")
                evaluator.initialize_experiment(experiment=self.experiment)

                # Make sure we are tracking validation nll
                self.metric_store.add(
                    metrics.Metric("vnll", "format_float", "g(m)"))

                # And store a local variable for easy access
                vnll_metric = self.metric_store["vnll"]

            loss = 0
            num_tokens = 0
            for step, batch in enumerate(cycle(dataloader), 1):
                try:
                    step_loss = self.compute_gradients_and_loss(
                        batch, model, optimizer)
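                    # With gradient accumulation, the optimizer only steps
                    # every accumulation_steps batches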
                    run_optimizer = (step % accumulation_steps) == 0

                    if run_optimizer:
                        # Run an optimization step
                        optimizer.step()
                        scheduler.step()  # Update learning rate schedule
                        model.zero_grad()

                    # Update loss and num tokens after running an optimization
                    # step, in case it results in an out of memory error
                    loss += step_loss
                    num_tokens += batch["num_tokens"]

                    if run_optimizer:
                        # Since we ran the optimizer, increment current step
                        self.step += 1
                        self.experiment.set_step(self.step)
                        progress.update()

                        # update our metrics as well
                        self.update_metrics(
                            loss / accumulation_steps,
                            num_tokens,
                            scheduler.get_lr()[0],
                        )
                        num_tokens = 0
                        loss = 0

                        # and finally check if we should save
                        if (self.args.save_steps > 0
                                and self.step % self.args.save_steps == 0):
                            # First save the current checkpoint
                            self.save()

                            # Then if we are implementing early stopping, see
                            # if we achieved a new best
                            if self.args.optim.early_stopping:
                                evaluator.reset_metrics()
                                with ExitStack() as eval_stack:
                                    # pylint:disable=no-member
                                    eval_stack.enter_context(
                                        tqdm_unwrap_stdout())
                                    eval_stack.enter_context(
                                        release_cuda_memory(
                                            collect_tensors(optimizer.state)))
                                    # pylint:enable=no-member

                                    vnll = evaluator()
                                    vnll_metric.update(vnll)

                                    # Save the updated metrics
                                    self.save_metrics()

                                    if vnll == vnll_metric.min:
                                        self.on_new_best()

                                # Try to combat OOM errors caused by doing evaluation
                                # in the same loop with training. This manifests in out
                                # of memory errors after the first or second evaluation
                                # run.
                                refresh_cuda_memory()

                            if not self.prune_checkpoints():
                                logging.info("Stopping early")
                                break

                            if self.step >= max_steps:
                                logging.info("Finished training")
                                break

                except RuntimeError as rte:
                    if "out of memory" in str(rte):
                        self.metric_store["oom"].update(1)
                        logging.warning(str(rte))
                    else:
                        progress.close()
                        raise rte

                progress.set_description_str(get_description())

            progress.close()
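The loop above iterates `cycle(dataloader)` so that epoch boundaries never terminate training; it stops on `max_steps` or early stopping instead. `cycle` itself is not shown in the snippet; a minimal sketch, assuming it is a custom generator rather than `itertools.cycle` (which would cache and replay the first epoch, defeating per-epoch reshuffling):

def cycle(iterable):
    # Restart iteration on exhaustion so a DataLoader re-shuffles and
    # re-spawns its workers each epoch
    while True:
        for item in iterable:
            yield item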